diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..326c945019085a8d7e0921da847e81cbefd01a47 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/__pycache__/_typing.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/__pycache__/_typing.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..adfff30af128e9794f71ab4a501af233f7e32f5b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/__pycache__/_typing.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/__pycache__/_version.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/__pycache__/_version.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef45b2e4bb306e2953b8214ce5ee9f31009a214d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/__pycache__/_version.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/__pycache__/_version_meson.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/__pycache__/_version_meson.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56423935ebd8c55447d22d20e6528cf16a968d56 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/__pycache__/_version_meson.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/__pycache__/conftest.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/__pycache__/conftest.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fcedf8dab2ad5a033756f35c75d9f58412651eb Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/__pycache__/conftest.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/__pycache__/testing.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/__pycache__/testing.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6df4cd8af262e3ad3d129942241ec6490c4ba5e4 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/__pycache__/testing.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/_testing/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/_testing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..87d419e2db8dd017688415dc4d97d0d940914123 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/_testing/__init__.py @@ -0,0 +1,639 @@ +from __future__ import annotations + +from decimal import Decimal +import operator +import os +from sys import byteorder +from typing import ( + TYPE_CHECKING, + Callable, + ContextManager, + cast, +) +import warnings + +import numpy as np + +from pandas._config.localization import ( + can_set_locale, + get_locales, + set_locale, +) + +from pandas.compat import pa_version_under10p1 + +from pandas.core.dtypes.common import is_string_dtype + +import pandas as pd +from pandas import ( + ArrowDtype, + DataFrame, + Index, + MultiIndex, + RangeIndex, + Series, +) +from pandas._testing._io import ( + round_trip_localpath, + round_trip_pathlib, + round_trip_pickle, + write_to_compressed, +) +from pandas._testing._warnings import ( + assert_produces_warning, + maybe_produces_warning, +) +from pandas._testing.asserters import ( + assert_almost_equal, + assert_attr_equal, + assert_categorical_equal, + assert_class_equal, + assert_contains_all, + assert_copy, + assert_datetime_array_equal, + assert_dict_equal, + assert_equal, + assert_extension_array_equal, + assert_frame_equal, + assert_index_equal, + assert_indexing_slices_equivalent, + assert_interval_array_equal, + assert_is_sorted, + assert_is_valid_plot_return_object, + assert_metadata_equivalent, + assert_numpy_array_equal, + assert_period_array_equal, + assert_series_equal, + assert_sp_array_equal, + assert_timedelta_array_equal, + raise_assert_detail, +) +from pandas._testing.compat import ( + get_dtype, + get_obj, +) +from pandas._testing.contexts import ( + assert_cow_warning, + decompress_file, + ensure_clean, + raises_chained_assignment_error, + set_timezone, + use_numexpr, + with_csv_dialect, +) +from pandas.core.arrays import ( + BaseMaskedArray, + ExtensionArray, + NumpyExtensionArray, +) +from pandas.core.arrays._mixins import NDArrayBackedExtensionArray +from pandas.core.construction import extract_array + +if TYPE_CHECKING: + from pandas._typing import ( + Dtype, + NpDtype, + ) + + from pandas.core.arrays import ArrowExtensionArray + +UNSIGNED_INT_NUMPY_DTYPES: list[NpDtype] = ["uint8", "uint16", "uint32", "uint64"] +UNSIGNED_INT_EA_DTYPES: list[Dtype] = ["UInt8", "UInt16", "UInt32", "UInt64"] +SIGNED_INT_NUMPY_DTYPES: list[NpDtype] = [int, "int8", "int16", "int32", "int64"] +SIGNED_INT_EA_DTYPES: list[Dtype] = ["Int8", "Int16", "Int32", "Int64"] +ALL_INT_NUMPY_DTYPES = UNSIGNED_INT_NUMPY_DTYPES + SIGNED_INT_NUMPY_DTYPES +ALL_INT_EA_DTYPES = UNSIGNED_INT_EA_DTYPES + SIGNED_INT_EA_DTYPES +ALL_INT_DTYPES: list[Dtype] = [*ALL_INT_NUMPY_DTYPES, *ALL_INT_EA_DTYPES] + +FLOAT_NUMPY_DTYPES: list[NpDtype] = [float, "float32", "float64"] +FLOAT_EA_DTYPES: list[Dtype] = ["Float32", "Float64"] +ALL_FLOAT_DTYPES: list[Dtype] = [*FLOAT_NUMPY_DTYPES, *FLOAT_EA_DTYPES] + +COMPLEX_DTYPES: list[Dtype] = [complex, "complex64", "complex128"] +STRING_DTYPES: list[Dtype] = [str, "str", "U"] +COMPLEX_FLOAT_DTYPES: list[Dtype] = [*COMPLEX_DTYPES, *FLOAT_NUMPY_DTYPES] + +DATETIME64_DTYPES: list[Dtype] = ["datetime64[ns]", "M8[ns]"] +TIMEDELTA64_DTYPES: list[Dtype] = ["timedelta64[ns]", "m8[ns]"] + +BOOL_DTYPES: list[Dtype] = [bool, "bool"] +BYTES_DTYPES: list[Dtype] = [bytes, "bytes"] +OBJECT_DTYPES: list[Dtype] = [object, "object"] + +ALL_REAL_NUMPY_DTYPES = FLOAT_NUMPY_DTYPES + ALL_INT_NUMPY_DTYPES +ALL_REAL_EXTENSION_DTYPES = FLOAT_EA_DTYPES + ALL_INT_EA_DTYPES +ALL_REAL_DTYPES: list[Dtype] = [*ALL_REAL_NUMPY_DTYPES, *ALL_REAL_EXTENSION_DTYPES] +ALL_NUMERIC_DTYPES: list[Dtype] = [*ALL_REAL_DTYPES, *COMPLEX_DTYPES] + +ALL_NUMPY_DTYPES = ( + ALL_REAL_NUMPY_DTYPES + + COMPLEX_DTYPES + + STRING_DTYPES + + DATETIME64_DTYPES + + TIMEDELTA64_DTYPES + + BOOL_DTYPES + + OBJECT_DTYPES + + BYTES_DTYPES +) + +NARROW_NP_DTYPES = [ + np.float16, + np.float32, + np.int8, + np.int16, + np.int32, + np.uint8, + np.uint16, + np.uint32, +] + +PYTHON_DATA_TYPES = [ + str, + int, + float, + complex, + list, + tuple, + range, + dict, + set, + frozenset, + bool, + bytes, + bytearray, + memoryview, +] + +ENDIAN = {"little": "<", "big": ">"}[byteorder] + +NULL_OBJECTS = [None, np.nan, pd.NaT, float("nan"), pd.NA, Decimal("NaN")] +NP_NAT_OBJECTS = [ + cls("NaT", unit) + for cls in [np.datetime64, np.timedelta64] + for unit in [ + "Y", + "M", + "W", + "D", + "h", + "m", + "s", + "ms", + "us", + "ns", + "ps", + "fs", + "as", + ] +] + +if not pa_version_under10p1: + import pyarrow as pa + + UNSIGNED_INT_PYARROW_DTYPES = [pa.uint8(), pa.uint16(), pa.uint32(), pa.uint64()] + SIGNED_INT_PYARROW_DTYPES = [pa.int8(), pa.int16(), pa.int32(), pa.int64()] + ALL_INT_PYARROW_DTYPES = UNSIGNED_INT_PYARROW_DTYPES + SIGNED_INT_PYARROW_DTYPES + ALL_INT_PYARROW_DTYPES_STR_REPR = [ + str(ArrowDtype(typ)) for typ in ALL_INT_PYARROW_DTYPES + ] + + # pa.float16 doesn't seem supported + # https://github.com/apache/arrow/blob/master/python/pyarrow/src/arrow/python/helpers.cc#L86 + FLOAT_PYARROW_DTYPES = [pa.float32(), pa.float64()] + FLOAT_PYARROW_DTYPES_STR_REPR = [ + str(ArrowDtype(typ)) for typ in FLOAT_PYARROW_DTYPES + ] + DECIMAL_PYARROW_DTYPES = [pa.decimal128(7, 3)] + STRING_PYARROW_DTYPES = [pa.string()] + BINARY_PYARROW_DTYPES = [pa.binary()] + + TIME_PYARROW_DTYPES = [ + pa.time32("s"), + pa.time32("ms"), + pa.time64("us"), + pa.time64("ns"), + ] + DATE_PYARROW_DTYPES = [pa.date32(), pa.date64()] + DATETIME_PYARROW_DTYPES = [ + pa.timestamp(unit=unit, tz=tz) + for unit in ["s", "ms", "us", "ns"] + for tz in [None, "UTC", "US/Pacific", "US/Eastern"] + ] + TIMEDELTA_PYARROW_DTYPES = [pa.duration(unit) for unit in ["s", "ms", "us", "ns"]] + + BOOL_PYARROW_DTYPES = [pa.bool_()] + + # TODO: Add container like pyarrow types: + # https://arrow.apache.org/docs/python/api/datatypes.html#factory-functions + ALL_PYARROW_DTYPES = ( + ALL_INT_PYARROW_DTYPES + + FLOAT_PYARROW_DTYPES + + DECIMAL_PYARROW_DTYPES + + STRING_PYARROW_DTYPES + + BINARY_PYARROW_DTYPES + + TIME_PYARROW_DTYPES + + DATE_PYARROW_DTYPES + + DATETIME_PYARROW_DTYPES + + TIMEDELTA_PYARROW_DTYPES + + BOOL_PYARROW_DTYPES + ) + ALL_REAL_PYARROW_DTYPES_STR_REPR = ( + ALL_INT_PYARROW_DTYPES_STR_REPR + FLOAT_PYARROW_DTYPES_STR_REPR + ) +else: + FLOAT_PYARROW_DTYPES_STR_REPR = [] + ALL_INT_PYARROW_DTYPES_STR_REPR = [] + ALL_PYARROW_DTYPES = [] + ALL_REAL_PYARROW_DTYPES_STR_REPR = [] + +ALL_REAL_NULLABLE_DTYPES = ( + FLOAT_NUMPY_DTYPES + ALL_REAL_EXTENSION_DTYPES + ALL_REAL_PYARROW_DTYPES_STR_REPR +) + +arithmetic_dunder_methods = [ + "__add__", + "__radd__", + "__sub__", + "__rsub__", + "__mul__", + "__rmul__", + "__floordiv__", + "__rfloordiv__", + "__truediv__", + "__rtruediv__", + "__pow__", + "__rpow__", + "__mod__", + "__rmod__", +] + +comparison_dunder_methods = ["__eq__", "__ne__", "__le__", "__lt__", "__ge__", "__gt__"] + + +# ----------------------------------------------------------------------------- +# Comparators + + +def box_expected(expected, box_cls, transpose: bool = True): + """ + Helper function to wrap the expected output of a test in a given box_class. + + Parameters + ---------- + expected : np.ndarray, Index, Series + box_cls : {Index, Series, DataFrame} + + Returns + ------- + subclass of box_cls + """ + if box_cls is pd.array: + if isinstance(expected, RangeIndex): + # pd.array would return an IntegerArray + expected = NumpyExtensionArray(np.asarray(expected._values)) + else: + expected = pd.array(expected, copy=False) + elif box_cls is Index: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "Dtype inference", category=FutureWarning) + expected = Index(expected) + elif box_cls is Series: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "Dtype inference", category=FutureWarning) + expected = Series(expected) + elif box_cls is DataFrame: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "Dtype inference", category=FutureWarning) + expected = Series(expected).to_frame() + if transpose: + # for vector operations, we need a DataFrame to be a single-row, + # not a single-column, in order to operate against non-DataFrame + # vectors of the same length. But convert to two rows to avoid + # single-row special cases in datetime arithmetic + expected = expected.T + expected = pd.concat([expected] * 2, ignore_index=True) + elif box_cls is np.ndarray or box_cls is np.array: + expected = np.array(expected) + elif box_cls is to_array: + expected = to_array(expected) + else: + raise NotImplementedError(box_cls) + return expected + + +def to_array(obj): + """ + Similar to pd.array, but does not cast numpy dtypes to nullable dtypes. + """ + # temporary implementation until we get pd.array in place + dtype = getattr(obj, "dtype", None) + + if dtype is None: + return np.asarray(obj) + + return extract_array(obj, extract_numpy=True) + + +class SubclassedSeries(Series): + _metadata = ["testattr", "name"] + + @property + def _constructor(self): + # For testing, those properties return a generic callable, and not + # the actual class. In this case that is equivalent, but it is to + # ensure we don't rely on the property returning a class + # See https://github.com/pandas-dev/pandas/pull/46018 and + # https://github.com/pandas-dev/pandas/issues/32638 and linked issues + return lambda *args, **kwargs: SubclassedSeries(*args, **kwargs) + + @property + def _constructor_expanddim(self): + return lambda *args, **kwargs: SubclassedDataFrame(*args, **kwargs) + + +class SubclassedDataFrame(DataFrame): + _metadata = ["testattr"] + + @property + def _constructor(self): + return lambda *args, **kwargs: SubclassedDataFrame(*args, **kwargs) + + @property + def _constructor_sliced(self): + return lambda *args, **kwargs: SubclassedSeries(*args, **kwargs) + + +def convert_rows_list_to_csv_str(rows_list: list[str]) -> str: + """ + Convert list of CSV rows to single CSV-formatted string for current OS. + + This method is used for creating expected value of to_csv() method. + + Parameters + ---------- + rows_list : List[str] + Each element represents the row of csv. + + Returns + ------- + str + Expected output of to_csv() in current OS. + """ + sep = os.linesep + return sep.join(rows_list) + sep + + +def external_error_raised(expected_exception: type[Exception]) -> ContextManager: + """ + Helper function to mark pytest.raises that have an external error message. + + Parameters + ---------- + expected_exception : Exception + Expected error to raise. + + Returns + ------- + Callable + Regular `pytest.raises` function with `match` equal to `None`. + """ + import pytest + + return pytest.raises(expected_exception, match=None) + + +cython_table = pd.core.common._cython_table.items() + + +def get_cython_table_params(ndframe, func_names_and_expected): + """ + Combine frame, functions from com._cython_table + keys and expected result. + + Parameters + ---------- + ndframe : DataFrame or Series + func_names_and_expected : Sequence of two items + The first item is a name of a NDFrame method ('sum', 'prod') etc. + The second item is the expected return value. + + Returns + ------- + list + List of three items (DataFrame, function, expected result) + """ + results = [] + for func_name, expected in func_names_and_expected: + results.append((ndframe, func_name, expected)) + results += [ + (ndframe, func, expected) + for func, name in cython_table + if name == func_name + ] + return results + + +def get_op_from_name(op_name: str) -> Callable: + """ + The operator function for a given op name. + + Parameters + ---------- + op_name : str + The op name, in form of "add" or "__add__". + + Returns + ------- + function + A function performing the operation. + """ + short_opname = op_name.strip("_") + try: + op = getattr(operator, short_opname) + except AttributeError: + # Assume it is the reverse operator + rop = getattr(operator, short_opname[1:]) + op = lambda x, y: rop(y, x) + + return op + + +# ----------------------------------------------------------------------------- +# Indexing test helpers + + +def getitem(x): + return x + + +def setitem(x): + return x + + +def loc(x): + return x.loc + + +def iloc(x): + return x.iloc + + +def at(x): + return x.at + + +def iat(x): + return x.iat + + +# ----------------------------------------------------------------------------- + +_UNITS = ["s", "ms", "us", "ns"] + + +def get_finest_unit(left: str, right: str): + """ + Find the higher of two datetime64 units. + """ + if _UNITS.index(left) >= _UNITS.index(right): + return left + return right + + +def shares_memory(left, right) -> bool: + """ + Pandas-compat for np.shares_memory. + """ + if isinstance(left, np.ndarray) and isinstance(right, np.ndarray): + return np.shares_memory(left, right) + elif isinstance(left, np.ndarray): + # Call with reversed args to get to unpacking logic below. + return shares_memory(right, left) + + if isinstance(left, RangeIndex): + return False + if isinstance(left, MultiIndex): + return shares_memory(left._codes, right) + if isinstance(left, (Index, Series)): + return shares_memory(left._values, right) + + if isinstance(left, NDArrayBackedExtensionArray): + return shares_memory(left._ndarray, right) + if isinstance(left, pd.core.arrays.SparseArray): + return shares_memory(left.sp_values, right) + if isinstance(left, pd.core.arrays.IntervalArray): + return shares_memory(left._left, right) or shares_memory(left._right, right) + + if ( + isinstance(left, ExtensionArray) + and is_string_dtype(left.dtype) + and left.dtype.storage in ("pyarrow", "pyarrow_numpy") # type: ignore[attr-defined] + ): + # https://github.com/pandas-dev/pandas/pull/43930#discussion_r736862669 + left = cast("ArrowExtensionArray", left) + if ( + isinstance(right, ExtensionArray) + and is_string_dtype(right.dtype) + and right.dtype.storage in ("pyarrow", "pyarrow_numpy") # type: ignore[attr-defined] + ): + right = cast("ArrowExtensionArray", right) + left_pa_data = left._pa_array + right_pa_data = right._pa_array + left_buf1 = left_pa_data.chunk(0).buffers()[1] + right_buf1 = right_pa_data.chunk(0).buffers()[1] + return left_buf1 == right_buf1 + + if isinstance(left, BaseMaskedArray) and isinstance(right, BaseMaskedArray): + # By convention, we'll say these share memory if they share *either* + # the _data or the _mask + return np.shares_memory(left._data, right._data) or np.shares_memory( + left._mask, right._mask + ) + + if isinstance(left, DataFrame) and len(left._mgr.arrays) == 1: + arr = left._mgr.arrays[0] + return shares_memory(arr, right) + + raise NotImplementedError(type(left), type(right)) + + +__all__ = [ + "ALL_INT_EA_DTYPES", + "ALL_INT_NUMPY_DTYPES", + "ALL_NUMPY_DTYPES", + "ALL_REAL_NUMPY_DTYPES", + "assert_almost_equal", + "assert_attr_equal", + "assert_categorical_equal", + "assert_class_equal", + "assert_contains_all", + "assert_copy", + "assert_datetime_array_equal", + "assert_dict_equal", + "assert_equal", + "assert_extension_array_equal", + "assert_frame_equal", + "assert_index_equal", + "assert_indexing_slices_equivalent", + "assert_interval_array_equal", + "assert_is_sorted", + "assert_is_valid_plot_return_object", + "assert_metadata_equivalent", + "assert_numpy_array_equal", + "assert_period_array_equal", + "assert_produces_warning", + "assert_series_equal", + "assert_sp_array_equal", + "assert_timedelta_array_equal", + "assert_cow_warning", + "at", + "BOOL_DTYPES", + "box_expected", + "BYTES_DTYPES", + "can_set_locale", + "COMPLEX_DTYPES", + "convert_rows_list_to_csv_str", + "DATETIME64_DTYPES", + "decompress_file", + "ENDIAN", + "ensure_clean", + "external_error_raised", + "FLOAT_EA_DTYPES", + "FLOAT_NUMPY_DTYPES", + "get_cython_table_params", + "get_dtype", + "getitem", + "get_locales", + "get_finest_unit", + "get_obj", + "get_op_from_name", + "iat", + "iloc", + "loc", + "maybe_produces_warning", + "NARROW_NP_DTYPES", + "NP_NAT_OBJECTS", + "NULL_OBJECTS", + "OBJECT_DTYPES", + "raise_assert_detail", + "raises_chained_assignment_error", + "round_trip_localpath", + "round_trip_pathlib", + "round_trip_pickle", + "setitem", + "set_locale", + "set_timezone", + "shares_memory", + "SIGNED_INT_EA_DTYPES", + "SIGNED_INT_NUMPY_DTYPES", + "STRING_DTYPES", + "SubclassedDataFrame", + "SubclassedSeries", + "TIMEDELTA64_DTYPES", + "to_array", + "UNSIGNED_INT_EA_DTYPES", + "UNSIGNED_INT_NUMPY_DTYPES", + "use_numexpr", + "with_csv_dialect", + "write_to_compressed", +] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/_testing/_hypothesis.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/_testing/_hypothesis.py new file mode 100644 index 0000000000000000000000000000000000000000..084ca9c306d192a2543108249dbc345d1259be01 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/_testing/_hypothesis.py @@ -0,0 +1,93 @@ +""" +Hypothesis data generator helpers. +""" +from datetime import datetime + +from hypothesis import strategies as st +from hypothesis.extra.dateutil import timezones as dateutil_timezones +from hypothesis.extra.pytz import timezones as pytz_timezones + +from pandas.compat import is_platform_windows + +import pandas as pd + +from pandas.tseries.offsets import ( + BMonthBegin, + BMonthEnd, + BQuarterBegin, + BQuarterEnd, + BYearBegin, + BYearEnd, + MonthBegin, + MonthEnd, + QuarterBegin, + QuarterEnd, + YearBegin, + YearEnd, +) + +OPTIONAL_INTS = st.lists(st.one_of(st.integers(), st.none()), max_size=10, min_size=3) + +OPTIONAL_FLOATS = st.lists(st.one_of(st.floats(), st.none()), max_size=10, min_size=3) + +OPTIONAL_TEXT = st.lists(st.one_of(st.none(), st.text()), max_size=10, min_size=3) + +OPTIONAL_DICTS = st.lists( + st.one_of(st.none(), st.dictionaries(st.text(), st.integers())), + max_size=10, + min_size=3, +) + +OPTIONAL_LISTS = st.lists( + st.one_of(st.none(), st.lists(st.text(), max_size=10, min_size=3)), + max_size=10, + min_size=3, +) + +OPTIONAL_ONE_OF_ALL = st.one_of( + OPTIONAL_DICTS, OPTIONAL_FLOATS, OPTIONAL_INTS, OPTIONAL_LISTS, OPTIONAL_TEXT +) + +if is_platform_windows(): + DATETIME_NO_TZ = st.datetimes(min_value=datetime(1900, 1, 1)) +else: + DATETIME_NO_TZ = st.datetimes() + +DATETIME_JAN_1_1900_OPTIONAL_TZ = st.datetimes( + min_value=pd.Timestamp( + 1900, 1, 1 + ).to_pydatetime(), # pyright: ignore[reportGeneralTypeIssues] + max_value=pd.Timestamp( + 1900, 1, 1 + ).to_pydatetime(), # pyright: ignore[reportGeneralTypeIssues] + timezones=st.one_of(st.none(), dateutil_timezones(), pytz_timezones()), +) + +DATETIME_IN_PD_TIMESTAMP_RANGE_NO_TZ = st.datetimes( + min_value=pd.Timestamp.min.to_pydatetime(warn=False), + max_value=pd.Timestamp.max.to_pydatetime(warn=False), +) + +INT_NEG_999_TO_POS_999 = st.integers(-999, 999) + +# The strategy for each type is registered in conftest.py, as they don't carry +# enough runtime information (e.g. type hints) to infer how to build them. +YQM_OFFSET = st.one_of( + *map( + st.from_type, + [ + MonthBegin, + MonthEnd, + BMonthBegin, + BMonthEnd, + QuarterBegin, + QuarterEnd, + BQuarterBegin, + BQuarterEnd, + YearBegin, + YearEnd, + BYearBegin, + BYearEnd, + ], + ) +) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/_testing/_io.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/_testing/_io.py new file mode 100644 index 0000000000000000000000000000000000000000..95977edb600ade42a8f8a1fada2b5085cee1da56 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/_testing/_io.py @@ -0,0 +1,170 @@ +from __future__ import annotations + +import gzip +import io +import pathlib +import tarfile +from typing import ( + TYPE_CHECKING, + Any, + Callable, +) +import uuid +import zipfile + +from pandas.compat import ( + get_bz2_file, + get_lzma_file, +) +from pandas.compat._optional import import_optional_dependency + +import pandas as pd +from pandas._testing.contexts import ensure_clean + +if TYPE_CHECKING: + from pandas._typing import ( + FilePath, + ReadPickleBuffer, + ) + + from pandas import ( + DataFrame, + Series, + ) + +# ------------------------------------------------------------------ +# File-IO + + +def round_trip_pickle( + obj: Any, path: FilePath | ReadPickleBuffer | None = None +) -> DataFrame | Series: + """ + Pickle an object and then read it again. + + Parameters + ---------- + obj : any object + The object to pickle and then re-read. + path : str, path object or file-like object, default None + The path where the pickled object is written and then read. + + Returns + ------- + pandas object + The original object that was pickled and then re-read. + """ + _path = path + if _path is None: + _path = f"__{uuid.uuid4()}__.pickle" + with ensure_clean(_path) as temp_path: + pd.to_pickle(obj, temp_path) + return pd.read_pickle(temp_path) + + +def round_trip_pathlib(writer, reader, path: str | None = None): + """ + Write an object to file specified by a pathlib.Path and read it back + + Parameters + ---------- + writer : callable bound to pandas object + IO writing function (e.g. DataFrame.to_csv ) + reader : callable + IO reading function (e.g. pd.read_csv ) + path : str, default None + The path where the object is written and then read. + + Returns + ------- + pandas object + The original object that was serialized and then re-read. + """ + Path = pathlib.Path + if path is None: + path = "___pathlib___" + with ensure_clean(path) as path: + writer(Path(path)) # type: ignore[arg-type] + obj = reader(Path(path)) # type: ignore[arg-type] + return obj + + +def round_trip_localpath(writer, reader, path: str | None = None): + """ + Write an object to file specified by a py.path LocalPath and read it back. + + Parameters + ---------- + writer : callable bound to pandas object + IO writing function (e.g. DataFrame.to_csv ) + reader : callable + IO reading function (e.g. pd.read_csv ) + path : str, default None + The path where the object is written and then read. + + Returns + ------- + pandas object + The original object that was serialized and then re-read. + """ + import pytest + + LocalPath = pytest.importorskip("py.path").local + if path is None: + path = "___localpath___" + with ensure_clean(path) as path: + writer(LocalPath(path)) + obj = reader(LocalPath(path)) + return obj + + +def write_to_compressed(compression, path, data, dest: str = "test") -> None: + """ + Write data to a compressed file. + + Parameters + ---------- + compression : {'gzip', 'bz2', 'zip', 'xz', 'zstd'} + The compression type to use. + path : str + The file path to write the data. + data : str + The data to write. + dest : str, default "test" + The destination file (for ZIP only) + + Raises + ------ + ValueError : An invalid compression value was passed in. + """ + args: tuple[Any, ...] = (data,) + mode = "wb" + method = "write" + compress_method: Callable + + if compression == "zip": + compress_method = zipfile.ZipFile + mode = "w" + args = (dest, data) + method = "writestr" + elif compression == "tar": + compress_method = tarfile.TarFile + mode = "w" + file = tarfile.TarInfo(name=dest) + bytes = io.BytesIO(data) + file.size = len(data) + args = (file, bytes) + method = "addfile" + elif compression == "gzip": + compress_method = gzip.GzipFile + elif compression == "bz2": + compress_method = get_bz2_file() + elif compression == "zstd": + compress_method = import_optional_dependency("zstandard").open + elif compression == "xz": + compress_method = get_lzma_file() + else: + raise ValueError(f"Unrecognized compression type: {compression}") + + with compress_method(path, mode=mode) as f: + getattr(f, method)(*args) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/_testing/_warnings.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/_testing/_warnings.py new file mode 100644 index 0000000000000000000000000000000000000000..c9a287942f2dac5ddbaf49168db280ec2ba3f2c4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/_testing/_warnings.py @@ -0,0 +1,232 @@ +from __future__ import annotations + +from contextlib import ( + contextmanager, + nullcontext, +) +import inspect +import re +import sys +from typing import ( + TYPE_CHECKING, + Literal, + cast, +) +import warnings + +from pandas.compat import PY311 + +if TYPE_CHECKING: + from collections.abc import ( + Generator, + Sequence, + ) + + +@contextmanager +def assert_produces_warning( + expected_warning: type[Warning] | bool | tuple[type[Warning], ...] | None = Warning, + filter_level: Literal[ + "error", "ignore", "always", "default", "module", "once" + ] = "always", + check_stacklevel: bool = True, + raise_on_extra_warnings: bool = True, + match: str | None = None, +) -> Generator[list[warnings.WarningMessage], None, None]: + """ + Context manager for running code expected to either raise a specific warning, + multiple specific warnings, or not raise any warnings. Verifies that the code + raises the expected warning(s), and that it does not raise any other unexpected + warnings. It is basically a wrapper around ``warnings.catch_warnings``. + + Parameters + ---------- + expected_warning : {Warning, False, tuple[Warning, ...], None}, default Warning + The type of Exception raised. ``exception.Warning`` is the base + class for all warnings. To raise multiple types of exceptions, + pass them as a tuple. To check that no warning is returned, + specify ``False`` or ``None``. + filter_level : str or None, default "always" + Specifies whether warnings are ignored, displayed, or turned + into errors. + Valid values are: + + * "error" - turns matching warnings into exceptions + * "ignore" - discard the warning + * "always" - always emit a warning + * "default" - print the warning the first time it is generated + from each location + * "module" - print the warning the first time it is generated + from each module + * "once" - print the warning the first time it is generated + + check_stacklevel : bool, default True + If True, displays the line that called the function containing + the warning to show were the function is called. Otherwise, the + line that implements the function is displayed. + raise_on_extra_warnings : bool, default True + Whether extra warnings not of the type `expected_warning` should + cause the test to fail. + match : str, optional + Match warning message. + + Examples + -------- + >>> import warnings + >>> with assert_produces_warning(): + ... warnings.warn(UserWarning()) + ... + >>> with assert_produces_warning(False): + ... warnings.warn(RuntimeWarning()) + ... + Traceback (most recent call last): + ... + AssertionError: Caused unexpected warning(s): ['RuntimeWarning']. + >>> with assert_produces_warning(UserWarning): + ... warnings.warn(RuntimeWarning()) + Traceback (most recent call last): + ... + AssertionError: Did not see expected warning of class 'UserWarning'. + + ..warn:: This is *not* thread-safe. + """ + __tracebackhide__ = True + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter(filter_level) + try: + yield w + finally: + if expected_warning: + expected_warning = cast(type[Warning], expected_warning) + _assert_caught_expected_warning( + caught_warnings=w, + expected_warning=expected_warning, + match=match, + check_stacklevel=check_stacklevel, + ) + if raise_on_extra_warnings: + _assert_caught_no_extra_warnings( + caught_warnings=w, + expected_warning=expected_warning, + ) + + +def maybe_produces_warning(warning: type[Warning], condition: bool, **kwargs): + """ + Return a context manager that possibly checks a warning based on the condition + """ + if condition: + return assert_produces_warning(warning, **kwargs) + else: + return nullcontext() + + +def _assert_caught_expected_warning( + *, + caught_warnings: Sequence[warnings.WarningMessage], + expected_warning: type[Warning], + match: str | None, + check_stacklevel: bool, +) -> None: + """Assert that there was the expected warning among the caught warnings.""" + saw_warning = False + matched_message = False + unmatched_messages = [] + + for actual_warning in caught_warnings: + if issubclass(actual_warning.category, expected_warning): + saw_warning = True + + if check_stacklevel: + _assert_raised_with_correct_stacklevel(actual_warning) + + if match is not None: + if re.search(match, str(actual_warning.message)): + matched_message = True + else: + unmatched_messages.append(actual_warning.message) + + if not saw_warning: + raise AssertionError( + f"Did not see expected warning of class " + f"{repr(expected_warning.__name__)}" + ) + + if match and not matched_message: + raise AssertionError( + f"Did not see warning {repr(expected_warning.__name__)} " + f"matching '{match}'. The emitted warning messages are " + f"{unmatched_messages}" + ) + + +def _assert_caught_no_extra_warnings( + *, + caught_warnings: Sequence[warnings.WarningMessage], + expected_warning: type[Warning] | bool | tuple[type[Warning], ...] | None, +) -> None: + """Assert that no extra warnings apart from the expected ones are caught.""" + extra_warnings = [] + + for actual_warning in caught_warnings: + if _is_unexpected_warning(actual_warning, expected_warning): + # GH#38630 pytest.filterwarnings does not suppress these. + if actual_warning.category == ResourceWarning: + # GH 44732: Don't make the CI flaky by filtering SSL-related + # ResourceWarning from dependencies + if "unclosed bool: + """Check if the actual warning issued is unexpected.""" + if actual_warning and not expected_warning: + return True + expected_warning = cast(type[Warning], expected_warning) + return bool(not issubclass(actual_warning.category, expected_warning)) + + +def _assert_raised_with_correct_stacklevel( + actual_warning: warnings.WarningMessage, +) -> None: + # https://stackoverflow.com/questions/17407119/python-inspect-stack-is-slow + frame = inspect.currentframe() + for _ in range(4): + frame = frame.f_back # type: ignore[union-attr] + try: + caller_filename = inspect.getfile(frame) # type: ignore[arg-type] + finally: + # See note in + # https://docs.python.org/3/library/inspect.html#inspect.Traceback + del frame + msg = ( + "Warning not set with correct stacklevel. " + f"File where warning is raised: {actual_warning.filename} != " + f"{caller_filename}. Warning message: {actual_warning.message}" + ) + assert actual_warning.filename == caller_filename, msg diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/_testing/asserters.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/_testing/asserters.py new file mode 100644 index 0000000000000000000000000000000000000000..41d2a7344a4edf2e05664eb599b0049d2c696e4c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/_testing/asserters.py @@ -0,0 +1,1435 @@ +from __future__ import annotations + +import operator +from typing import ( + TYPE_CHECKING, + Literal, + NoReturn, + cast, +) + +import numpy as np + +from pandas._libs import lib +from pandas._libs.missing import is_matching_na +from pandas._libs.sparse import SparseIndex +import pandas._libs.testing as _testing +from pandas._libs.tslibs.np_datetime import compare_mismatched_resolutions + +from pandas.core.dtypes.common import ( + is_bool, + is_float_dtype, + is_integer_dtype, + is_number, + is_numeric_dtype, + needs_i8_conversion, +) +from pandas.core.dtypes.dtypes import ( + CategoricalDtype, + DatetimeTZDtype, + ExtensionDtype, + NumpyEADtype, +) +from pandas.core.dtypes.missing import array_equivalent + +import pandas as pd +from pandas import ( + Categorical, + DataFrame, + DatetimeIndex, + Index, + IntervalDtype, + IntervalIndex, + MultiIndex, + PeriodIndex, + RangeIndex, + Series, + TimedeltaIndex, +) +from pandas.core.arrays import ( + DatetimeArray, + ExtensionArray, + IntervalArray, + PeriodArray, + TimedeltaArray, +) +from pandas.core.arrays.datetimelike import DatetimeLikeArrayMixin +from pandas.core.arrays.string_ import StringDtype +from pandas.core.indexes.api import safe_sort_index + +from pandas.io.formats.printing import pprint_thing + +if TYPE_CHECKING: + from pandas._typing import DtypeObj + + +def assert_almost_equal( + left, + right, + check_dtype: bool | Literal["equiv"] = "equiv", + rtol: float = 1.0e-5, + atol: float = 1.0e-8, + **kwargs, +) -> None: + """ + Check that the left and right objects are approximately equal. + + By approximately equal, we refer to objects that are numbers or that + contain numbers which may be equivalent to specific levels of precision. + + Parameters + ---------- + left : object + right : object + check_dtype : bool or {'equiv'}, default 'equiv' + Check dtype if both a and b are the same type. If 'equiv' is passed in, + then `RangeIndex` and `Index` with int64 dtype are also considered + equivalent when doing type checking. + rtol : float, default 1e-5 + Relative tolerance. + atol : float, default 1e-8 + Absolute tolerance. + """ + if isinstance(left, Index): + assert_index_equal( + left, + right, + check_exact=False, + exact=check_dtype, + rtol=rtol, + atol=atol, + **kwargs, + ) + + elif isinstance(left, Series): + assert_series_equal( + left, + right, + check_exact=False, + check_dtype=check_dtype, + rtol=rtol, + atol=atol, + **kwargs, + ) + + elif isinstance(left, DataFrame): + assert_frame_equal( + left, + right, + check_exact=False, + check_dtype=check_dtype, + rtol=rtol, + atol=atol, + **kwargs, + ) + + else: + # Other sequences. + if check_dtype: + if is_number(left) and is_number(right): + # Do not compare numeric classes, like np.float64 and float. + pass + elif is_bool(left) and is_bool(right): + # Do not compare bool classes, like np.bool_ and bool. + pass + else: + if isinstance(left, np.ndarray) or isinstance(right, np.ndarray): + obj = "numpy array" + else: + obj = "Input" + assert_class_equal(left, right, obj=obj) + + # if we have "equiv", this becomes True + _testing.assert_almost_equal( + left, right, check_dtype=bool(check_dtype), rtol=rtol, atol=atol, **kwargs + ) + + +def _check_isinstance(left, right, cls) -> None: + """ + Helper method for our assert_* methods that ensures that + the two objects being compared have the right type before + proceeding with the comparison. + + Parameters + ---------- + left : The first object being compared. + right : The second object being compared. + cls : The class type to check against. + + Raises + ------ + AssertionError : Either `left` or `right` is not an instance of `cls`. + """ + cls_name = cls.__name__ + + if not isinstance(left, cls): + raise AssertionError( + f"{cls_name} Expected type {cls}, found {type(left)} instead" + ) + if not isinstance(right, cls): + raise AssertionError( + f"{cls_name} Expected type {cls}, found {type(right)} instead" + ) + + +def assert_dict_equal(left, right, compare_keys: bool = True) -> None: + _check_isinstance(left, right, dict) + _testing.assert_dict_equal(left, right, compare_keys=compare_keys) + + +def assert_index_equal( + left: Index, + right: Index, + exact: bool | str = "equiv", + check_names: bool = True, + check_exact: bool = True, + check_categorical: bool = True, + check_order: bool = True, + rtol: float = 1.0e-5, + atol: float = 1.0e-8, + obj: str = "Index", +) -> None: + """ + Check that left and right Index are equal. + + Parameters + ---------- + left : Index + right : Index + exact : bool or {'equiv'}, default 'equiv' + Whether to check the Index class, dtype and inferred_type + are identical. If 'equiv', then RangeIndex can be substituted for + Index with an int64 dtype as well. + check_names : bool, default True + Whether to check the names attribute. + check_exact : bool, default True + Whether to compare number exactly. + check_categorical : bool, default True + Whether to compare internal Categorical exactly. + check_order : bool, default True + Whether to compare the order of index entries as well as their values. + If True, both indexes must contain the same elements, in the same order. + If False, both indexes must contain the same elements, but in any order. + rtol : float, default 1e-5 + Relative tolerance. Only used when check_exact is False. + atol : float, default 1e-8 + Absolute tolerance. Only used when check_exact is False. + obj : str, default 'Index' + Specify object name being compared, internally used to show appropriate + assertion message. + + Examples + -------- + >>> from pandas import testing as tm + >>> a = pd.Index([1, 2, 3]) + >>> b = pd.Index([1, 2, 3]) + >>> tm.assert_index_equal(a, b) + """ + __tracebackhide__ = True + + def _check_types(left, right, obj: str = "Index") -> None: + if not exact: + return + + assert_class_equal(left, right, exact=exact, obj=obj) + assert_attr_equal("inferred_type", left, right, obj=obj) + + # Skip exact dtype checking when `check_categorical` is False + if isinstance(left.dtype, CategoricalDtype) and isinstance( + right.dtype, CategoricalDtype + ): + if check_categorical: + assert_attr_equal("dtype", left, right, obj=obj) + assert_index_equal(left.categories, right.categories, exact=exact) + return + + assert_attr_equal("dtype", left, right, obj=obj) + + # instance validation + _check_isinstance(left, right, Index) + + # class / dtype comparison + _check_types(left, right, obj=obj) + + # level comparison + if left.nlevels != right.nlevels: + msg1 = f"{obj} levels are different" + msg2 = f"{left.nlevels}, {left}" + msg3 = f"{right.nlevels}, {right}" + raise_assert_detail(obj, msg1, msg2, msg3) + + # length comparison + if len(left) != len(right): + msg1 = f"{obj} length are different" + msg2 = f"{len(left)}, {left}" + msg3 = f"{len(right)}, {right}" + raise_assert_detail(obj, msg1, msg2, msg3) + + # If order doesn't matter then sort the index entries + if not check_order: + left = safe_sort_index(left) + right = safe_sort_index(right) + + # MultiIndex special comparison for little-friendly error messages + if isinstance(left, MultiIndex): + right = cast(MultiIndex, right) + + for level in range(left.nlevels): + lobj = f"MultiIndex level [{level}]" + try: + # try comparison on levels/codes to avoid densifying MultiIndex + assert_index_equal( + left.levels[level], + right.levels[level], + exact=exact, + check_names=check_names, + check_exact=check_exact, + check_categorical=check_categorical, + rtol=rtol, + atol=atol, + obj=lobj, + ) + assert_numpy_array_equal(left.codes[level], right.codes[level]) + except AssertionError: + llevel = left.get_level_values(level) + rlevel = right.get_level_values(level) + + assert_index_equal( + llevel, + rlevel, + exact=exact, + check_names=check_names, + check_exact=check_exact, + check_categorical=check_categorical, + rtol=rtol, + atol=atol, + obj=lobj, + ) + # get_level_values may change dtype + _check_types(left.levels[level], right.levels[level], obj=obj) + + # skip exact index checking when `check_categorical` is False + elif check_exact and check_categorical: + if not left.equals(right): + mismatch = left._values != right._values + + if not isinstance(mismatch, np.ndarray): + mismatch = cast("ExtensionArray", mismatch).fillna(True) + + diff = np.sum(mismatch.astype(int)) * 100.0 / len(left) + msg = f"{obj} values are different ({np.round(diff, 5)} %)" + raise_assert_detail(obj, msg, left, right) + else: + # if we have "equiv", this becomes True + exact_bool = bool(exact) + _testing.assert_almost_equal( + left.values, + right.values, + rtol=rtol, + atol=atol, + check_dtype=exact_bool, + obj=obj, + lobj=left, + robj=right, + ) + + # metadata comparison + if check_names: + assert_attr_equal("names", left, right, obj=obj) + if isinstance(left, PeriodIndex) or isinstance(right, PeriodIndex): + assert_attr_equal("dtype", left, right, obj=obj) + if isinstance(left, IntervalIndex) or isinstance(right, IntervalIndex): + assert_interval_array_equal(left._values, right._values) + + if check_categorical: + if isinstance(left.dtype, CategoricalDtype) or isinstance( + right.dtype, CategoricalDtype + ): + assert_categorical_equal(left._values, right._values, obj=f"{obj} category") + + +def assert_class_equal( + left, right, exact: bool | str = True, obj: str = "Input" +) -> None: + """ + Checks classes are equal. + """ + __tracebackhide__ = True + + def repr_class(x): + if isinstance(x, Index): + # return Index as it is to include values in the error message + return x + + return type(x).__name__ + + def is_class_equiv(idx: Index) -> bool: + """Classes that are a RangeIndex (sub-)instance or exactly an `Index` . + + This only checks class equivalence. There is a separate check that the + dtype is int64. + """ + return type(idx) is Index or isinstance(idx, RangeIndex) + + if type(left) == type(right): + return + + if exact == "equiv": + if is_class_equiv(left) and is_class_equiv(right): + return + + msg = f"{obj} classes are different" + raise_assert_detail(obj, msg, repr_class(left), repr_class(right)) + + +def assert_attr_equal(attr: str, left, right, obj: str = "Attributes") -> None: + """ + Check attributes are equal. Both objects must have attribute. + + Parameters + ---------- + attr : str + Attribute name being compared. + left : object + right : object + obj : str, default 'Attributes' + Specify object name being compared, internally used to show appropriate + assertion message + """ + __tracebackhide__ = True + + left_attr = getattr(left, attr) + right_attr = getattr(right, attr) + + if left_attr is right_attr or is_matching_na(left_attr, right_attr): + # e.g. both np.nan, both NaT, both pd.NA, ... + return None + + try: + result = left_attr == right_attr + except TypeError: + # datetimetz on rhs may raise TypeError + result = False + if (left_attr is pd.NA) ^ (right_attr is pd.NA): + result = False + elif not isinstance(result, bool): + result = result.all() + + if not result: + msg = f'Attribute "{attr}" are different' + raise_assert_detail(obj, msg, left_attr, right_attr) + return None + + +def assert_is_valid_plot_return_object(objs) -> None: + from matplotlib.artist import Artist + from matplotlib.axes import Axes + + if isinstance(objs, (Series, np.ndarray)): + if isinstance(objs, Series): + objs = objs._values + for el in objs.ravel(): + msg = ( + "one of 'objs' is not a matplotlib Axes instance, " + f"type encountered {repr(type(el).__name__)}" + ) + assert isinstance(el, (Axes, dict)), msg + else: + msg = ( + "objs is neither an ndarray of Artist instances nor a single " + "ArtistArtist instance, tuple, or dict, 'objs' is a " + f"{repr(type(objs).__name__)}" + ) + assert isinstance(objs, (Artist, tuple, dict)), msg + + +def assert_is_sorted(seq) -> None: + """Assert that the sequence is sorted.""" + if isinstance(seq, (Index, Series)): + seq = seq.values + # sorting does not change precisions + if isinstance(seq, np.ndarray): + assert_numpy_array_equal(seq, np.sort(np.array(seq))) + else: + assert_extension_array_equal(seq, seq[seq.argsort()]) + + +def assert_categorical_equal( + left, + right, + check_dtype: bool = True, + check_category_order: bool = True, + obj: str = "Categorical", +) -> None: + """ + Test that Categoricals are equivalent. + + Parameters + ---------- + left : Categorical + right : Categorical + check_dtype : bool, default True + Check that integer dtype of the codes are the same. + check_category_order : bool, default True + Whether the order of the categories should be compared, which + implies identical integer codes. If False, only the resulting + values are compared. The ordered attribute is + checked regardless. + obj : str, default 'Categorical' + Specify object name being compared, internally used to show appropriate + assertion message. + """ + _check_isinstance(left, right, Categorical) + + exact: bool | str + if isinstance(left.categories, RangeIndex) or isinstance( + right.categories, RangeIndex + ): + exact = "equiv" + else: + # We still want to require exact matches for Index + exact = True + + if check_category_order: + assert_index_equal( + left.categories, right.categories, obj=f"{obj}.categories", exact=exact + ) + assert_numpy_array_equal( + left.codes, right.codes, check_dtype=check_dtype, obj=f"{obj}.codes" + ) + else: + try: + lc = left.categories.sort_values() + rc = right.categories.sort_values() + except TypeError: + # e.g. '<' not supported between instances of 'int' and 'str' + lc, rc = left.categories, right.categories + assert_index_equal(lc, rc, obj=f"{obj}.categories", exact=exact) + assert_index_equal( + left.categories.take(left.codes), + right.categories.take(right.codes), + obj=f"{obj}.values", + exact=exact, + ) + + assert_attr_equal("ordered", left, right, obj=obj) + + +def assert_interval_array_equal( + left, right, exact: bool | Literal["equiv"] = "equiv", obj: str = "IntervalArray" +) -> None: + """ + Test that two IntervalArrays are equivalent. + + Parameters + ---------- + left, right : IntervalArray + The IntervalArrays to compare. + exact : bool or {'equiv'}, default 'equiv' + Whether to check the Index class, dtype and inferred_type + are identical. If 'equiv', then RangeIndex can be substituted for + Index with an int64 dtype as well. + obj : str, default 'IntervalArray' + Specify object name being compared, internally used to show appropriate + assertion message + """ + _check_isinstance(left, right, IntervalArray) + + kwargs = {} + if left._left.dtype.kind in "mM": + # We have a DatetimeArray or TimedeltaArray + kwargs["check_freq"] = False + + assert_equal(left._left, right._left, obj=f"{obj}.left", **kwargs) + assert_equal(left._right, right._right, obj=f"{obj}.left", **kwargs) + + assert_attr_equal("closed", left, right, obj=obj) + + +def assert_period_array_equal(left, right, obj: str = "PeriodArray") -> None: + _check_isinstance(left, right, PeriodArray) + + assert_numpy_array_equal(left._ndarray, right._ndarray, obj=f"{obj}._ndarray") + assert_attr_equal("dtype", left, right, obj=obj) + + +def assert_datetime_array_equal( + left, right, obj: str = "DatetimeArray", check_freq: bool = True +) -> None: + __tracebackhide__ = True + _check_isinstance(left, right, DatetimeArray) + + assert_numpy_array_equal(left._ndarray, right._ndarray, obj=f"{obj}._ndarray") + if check_freq: + assert_attr_equal("freq", left, right, obj=obj) + assert_attr_equal("tz", left, right, obj=obj) + + +def assert_timedelta_array_equal( + left, right, obj: str = "TimedeltaArray", check_freq: bool = True +) -> None: + __tracebackhide__ = True + _check_isinstance(left, right, TimedeltaArray) + assert_numpy_array_equal(left._ndarray, right._ndarray, obj=f"{obj}._ndarray") + if check_freq: + assert_attr_equal("freq", left, right, obj=obj) + + +def raise_assert_detail( + obj, message, left, right, diff=None, first_diff=None, index_values=None +) -> NoReturn: + __tracebackhide__ = True + + msg = f"""{obj} are different + +{message}""" + + if isinstance(index_values, Index): + index_values = np.asarray(index_values) + + if isinstance(index_values, np.ndarray): + msg += f"\n[index]: {pprint_thing(index_values)}" + + if isinstance(left, np.ndarray): + left = pprint_thing(left) + elif isinstance(left, (CategoricalDtype, NumpyEADtype, StringDtype)): + left = repr(left) + + if isinstance(right, np.ndarray): + right = pprint_thing(right) + elif isinstance(right, (CategoricalDtype, NumpyEADtype, StringDtype)): + right = repr(right) + + msg += f""" +[left]: {left} +[right]: {right}""" + + if diff is not None: + msg += f"\n[diff]: {diff}" + + if first_diff is not None: + msg += f"\n{first_diff}" + + raise AssertionError(msg) + + +def assert_numpy_array_equal( + left, + right, + strict_nan: bool = False, + check_dtype: bool | Literal["equiv"] = True, + err_msg=None, + check_same=None, + obj: str = "numpy array", + index_values=None, +) -> None: + """ + Check that 'np.ndarray' is equivalent. + + Parameters + ---------- + left, right : numpy.ndarray or iterable + The two arrays to be compared. + strict_nan : bool, default False + If True, consider NaN and None to be different. + check_dtype : bool, default True + Check dtype if both a and b are np.ndarray. + err_msg : str, default None + If provided, used as assertion message. + check_same : None|'copy'|'same', default None + Ensure left and right refer/do not refer to the same memory area. + obj : str, default 'numpy array' + Specify object name being compared, internally used to show appropriate + assertion message. + index_values : Index | numpy.ndarray, default None + optional index (shared by both left and right), used in output. + """ + __tracebackhide__ = True + + # instance validation + # Show a detailed error message when classes are different + assert_class_equal(left, right, obj=obj) + # both classes must be an np.ndarray + _check_isinstance(left, right, np.ndarray) + + def _get_base(obj): + return obj.base if getattr(obj, "base", None) is not None else obj + + left_base = _get_base(left) + right_base = _get_base(right) + + if check_same == "same": + if left_base is not right_base: + raise AssertionError(f"{repr(left_base)} is not {repr(right_base)}") + elif check_same == "copy": + if left_base is right_base: + raise AssertionError(f"{repr(left_base)} is {repr(right_base)}") + + def _raise(left, right, err_msg) -> NoReturn: + if err_msg is None: + if left.shape != right.shape: + raise_assert_detail( + obj, f"{obj} shapes are different", left.shape, right.shape + ) + + diff = 0 + for left_arr, right_arr in zip(left, right): + # count up differences + if not array_equivalent(left_arr, right_arr, strict_nan=strict_nan): + diff += 1 + + diff = diff * 100.0 / left.size + msg = f"{obj} values are different ({np.round(diff, 5)} %)" + raise_assert_detail(obj, msg, left, right, index_values=index_values) + + raise AssertionError(err_msg) + + # compare shape and values + if not array_equivalent(left, right, strict_nan=strict_nan): + _raise(left, right, err_msg) + + if check_dtype: + if isinstance(left, np.ndarray) and isinstance(right, np.ndarray): + assert_attr_equal("dtype", left, right, obj=obj) + + +def assert_extension_array_equal( + left, + right, + check_dtype: bool | Literal["equiv"] = True, + index_values=None, + check_exact: bool | lib.NoDefault = lib.no_default, + rtol: float | lib.NoDefault = lib.no_default, + atol: float | lib.NoDefault = lib.no_default, + obj: str = "ExtensionArray", +) -> None: + """ + Check that left and right ExtensionArrays are equal. + + Parameters + ---------- + left, right : ExtensionArray + The two arrays to compare. + check_dtype : bool, default True + Whether to check if the ExtensionArray dtypes are identical. + index_values : Index | numpy.ndarray, default None + Optional index (shared by both left and right), used in output. + check_exact : bool, default False + Whether to compare number exactly. + + .. versionchanged:: 2.2.0 + + Defaults to True for integer dtypes if none of + ``check_exact``, ``rtol`` and ``atol`` are specified. + rtol : float, default 1e-5 + Relative tolerance. Only used when check_exact is False. + atol : float, default 1e-8 + Absolute tolerance. Only used when check_exact is False. + obj : str, default 'ExtensionArray' + Specify object name being compared, internally used to show appropriate + assertion message. + + .. versionadded:: 2.0.0 + + Notes + ----- + Missing values are checked separately from valid values. + A mask of missing values is computed for each and checked to match. + The remaining all-valid values are cast to object dtype and checked. + + Examples + -------- + >>> from pandas import testing as tm + >>> a = pd.Series([1, 2, 3, 4]) + >>> b, c = a.array, a.array + >>> tm.assert_extension_array_equal(b, c) + """ + if ( + check_exact is lib.no_default + and rtol is lib.no_default + and atol is lib.no_default + ): + check_exact = ( + is_numeric_dtype(left.dtype) + and not is_float_dtype(left.dtype) + or is_numeric_dtype(right.dtype) + and not is_float_dtype(right.dtype) + ) + elif check_exact is lib.no_default: + check_exact = False + + rtol = rtol if rtol is not lib.no_default else 1.0e-5 + atol = atol if atol is not lib.no_default else 1.0e-8 + + assert isinstance(left, ExtensionArray), "left is not an ExtensionArray" + assert isinstance(right, ExtensionArray), "right is not an ExtensionArray" + if check_dtype: + assert_attr_equal("dtype", left, right, obj=f"Attributes of {obj}") + + if ( + isinstance(left, DatetimeLikeArrayMixin) + and isinstance(right, DatetimeLikeArrayMixin) + and type(right) == type(left) + ): + # GH 52449 + if not check_dtype and left.dtype.kind in "mM": + if not isinstance(left.dtype, np.dtype): + l_unit = cast(DatetimeTZDtype, left.dtype).unit + else: + l_unit = np.datetime_data(left.dtype)[0] + if not isinstance(right.dtype, np.dtype): + r_unit = cast(DatetimeTZDtype, right.dtype).unit + else: + r_unit = np.datetime_data(right.dtype)[0] + if ( + l_unit != r_unit + and compare_mismatched_resolutions( + left._ndarray, right._ndarray, operator.eq + ).all() + ): + return + # Avoid slow object-dtype comparisons + # np.asarray for case where we have a np.MaskedArray + assert_numpy_array_equal( + np.asarray(left.asi8), + np.asarray(right.asi8), + index_values=index_values, + obj=obj, + ) + return + + left_na = np.asarray(left.isna()) + right_na = np.asarray(right.isna()) + assert_numpy_array_equal( + left_na, right_na, obj=f"{obj} NA mask", index_values=index_values + ) + + left_valid = left[~left_na].to_numpy(dtype=object) + right_valid = right[~right_na].to_numpy(dtype=object) + if check_exact: + assert_numpy_array_equal( + left_valid, right_valid, obj=obj, index_values=index_values + ) + else: + _testing.assert_almost_equal( + left_valid, + right_valid, + check_dtype=bool(check_dtype), + rtol=rtol, + atol=atol, + obj=obj, + index_values=index_values, + ) + + +# This could be refactored to use the NDFrame.equals method +def assert_series_equal( + left, + right, + check_dtype: bool | Literal["equiv"] = True, + check_index_type: bool | Literal["equiv"] = "equiv", + check_series_type: bool = True, + check_names: bool = True, + check_exact: bool | lib.NoDefault = lib.no_default, + check_datetimelike_compat: bool = False, + check_categorical: bool = True, + check_category_order: bool = True, + check_freq: bool = True, + check_flags: bool = True, + rtol: float | lib.NoDefault = lib.no_default, + atol: float | lib.NoDefault = lib.no_default, + obj: str = "Series", + *, + check_index: bool = True, + check_like: bool = False, +) -> None: + """ + Check that left and right Series are equal. + + Parameters + ---------- + left : Series + right : Series + check_dtype : bool, default True + Whether to check the Series dtype is identical. + check_index_type : bool or {'equiv'}, default 'equiv' + Whether to check the Index class, dtype and inferred_type + are identical. + check_series_type : bool, default True + Whether to check the Series class is identical. + check_names : bool, default True + Whether to check the Series and Index names attribute. + check_exact : bool, default False + Whether to compare number exactly. + + .. versionchanged:: 2.2.0 + + Defaults to True for integer dtypes if none of + ``check_exact``, ``rtol`` and ``atol`` are specified. + check_datetimelike_compat : bool, default False + Compare datetime-like which is comparable ignoring dtype. + check_categorical : bool, default True + Whether to compare internal Categorical exactly. + check_category_order : bool, default True + Whether to compare category order of internal Categoricals. + check_freq : bool, default True + Whether to check the `freq` attribute on a DatetimeIndex or TimedeltaIndex. + check_flags : bool, default True + Whether to check the `flags` attribute. + rtol : float, default 1e-5 + Relative tolerance. Only used when check_exact is False. + atol : float, default 1e-8 + Absolute tolerance. Only used when check_exact is False. + obj : str, default 'Series' + Specify object name being compared, internally used to show appropriate + assertion message. + check_index : bool, default True + Whether to check index equivalence. If False, then compare only values. + + .. versionadded:: 1.3.0 + check_like : bool, default False + If True, ignore the order of the index. Must be False if check_index is False. + Note: same labels must be with the same data. + + .. versionadded:: 1.5.0 + + Examples + -------- + >>> from pandas import testing as tm + >>> a = pd.Series([1, 2, 3, 4]) + >>> b = pd.Series([1, 2, 3, 4]) + >>> tm.assert_series_equal(a, b) + """ + __tracebackhide__ = True + check_exact_index = False if check_exact is lib.no_default else check_exact + if ( + check_exact is lib.no_default + and rtol is lib.no_default + and atol is lib.no_default + ): + check_exact = ( + is_numeric_dtype(left.dtype) + and not is_float_dtype(left.dtype) + or is_numeric_dtype(right.dtype) + and not is_float_dtype(right.dtype) + ) + elif check_exact is lib.no_default: + check_exact = False + + rtol = rtol if rtol is not lib.no_default else 1.0e-5 + atol = atol if atol is not lib.no_default else 1.0e-8 + + if not check_index and check_like: + raise ValueError("check_like must be False if check_index is False") + + # instance validation + _check_isinstance(left, right, Series) + + if check_series_type: + assert_class_equal(left, right, obj=obj) + + # length comparison + if len(left) != len(right): + msg1 = f"{len(left)}, {left.index}" + msg2 = f"{len(right)}, {right.index}" + raise_assert_detail(obj, "Series length are different", msg1, msg2) + + if check_flags: + assert left.flags == right.flags, f"{repr(left.flags)} != {repr(right.flags)}" + + if check_index: + # GH #38183 + assert_index_equal( + left.index, + right.index, + exact=check_index_type, + check_names=check_names, + check_exact=check_exact_index, + check_categorical=check_categorical, + check_order=not check_like, + rtol=rtol, + atol=atol, + obj=f"{obj}.index", + ) + + if check_like: + left = left.reindex_like(right) + + if check_freq and isinstance(left.index, (DatetimeIndex, TimedeltaIndex)): + lidx = left.index + ridx = right.index + assert lidx.freq == ridx.freq, (lidx.freq, ridx.freq) + + if check_dtype: + # We want to skip exact dtype checking when `check_categorical` + # is False. We'll still raise if only one is a `Categorical`, + # regardless of `check_categorical` + if ( + isinstance(left.dtype, CategoricalDtype) + and isinstance(right.dtype, CategoricalDtype) + and not check_categorical + ): + pass + else: + assert_attr_equal("dtype", left, right, obj=f"Attributes of {obj}") + if check_exact: + left_values = left._values + right_values = right._values + # Only check exact if dtype is numeric + if isinstance(left_values, ExtensionArray) and isinstance( + right_values, ExtensionArray + ): + assert_extension_array_equal( + left_values, + right_values, + check_dtype=check_dtype, + index_values=left.index, + obj=str(obj), + ) + else: + # convert both to NumPy if not, check_dtype would raise earlier + lv, rv = left_values, right_values + if isinstance(left_values, ExtensionArray): + lv = left_values.to_numpy() + if isinstance(right_values, ExtensionArray): + rv = right_values.to_numpy() + assert_numpy_array_equal( + lv, + rv, + check_dtype=check_dtype, + obj=str(obj), + index_values=left.index, + ) + elif check_datetimelike_compat and ( + needs_i8_conversion(left.dtype) or needs_i8_conversion(right.dtype) + ): + # we want to check only if we have compat dtypes + # e.g. integer and M|m are NOT compat, but we can simply check + # the values in that case + + # datetimelike may have different objects (e.g. datetime.datetime + # vs Timestamp) but will compare equal + if not Index(left._values).equals(Index(right._values)): + msg = ( + f"[datetimelike_compat=True] {left._values} " + f"is not equal to {right._values}." + ) + raise AssertionError(msg) + elif isinstance(left.dtype, IntervalDtype) and isinstance( + right.dtype, IntervalDtype + ): + assert_interval_array_equal(left.array, right.array) + elif isinstance(left.dtype, CategoricalDtype) or isinstance( + right.dtype, CategoricalDtype + ): + _testing.assert_almost_equal( + left._values, + right._values, + rtol=rtol, + atol=atol, + check_dtype=bool(check_dtype), + obj=str(obj), + index_values=left.index, + ) + elif isinstance(left.dtype, ExtensionDtype) and isinstance( + right.dtype, ExtensionDtype + ): + assert_extension_array_equal( + left._values, + right._values, + rtol=rtol, + atol=atol, + check_dtype=check_dtype, + index_values=left.index, + obj=str(obj), + ) + elif is_extension_array_dtype_and_needs_i8_conversion( + left.dtype, right.dtype + ) or is_extension_array_dtype_and_needs_i8_conversion(right.dtype, left.dtype): + assert_extension_array_equal( + left._values, + right._values, + check_dtype=check_dtype, + index_values=left.index, + obj=str(obj), + ) + elif needs_i8_conversion(left.dtype) and needs_i8_conversion(right.dtype): + # DatetimeArray or TimedeltaArray + assert_extension_array_equal( + left._values, + right._values, + check_dtype=check_dtype, + index_values=left.index, + obj=str(obj), + ) + else: + _testing.assert_almost_equal( + left._values, + right._values, + rtol=rtol, + atol=atol, + check_dtype=bool(check_dtype), + obj=str(obj), + index_values=left.index, + ) + + # metadata comparison + if check_names: + assert_attr_equal("name", left, right, obj=obj) + + if check_categorical: + if isinstance(left.dtype, CategoricalDtype) or isinstance( + right.dtype, CategoricalDtype + ): + assert_categorical_equal( + left._values, + right._values, + obj=f"{obj} category", + check_category_order=check_category_order, + ) + + +# This could be refactored to use the NDFrame.equals method +def assert_frame_equal( + left, + right, + check_dtype: bool | Literal["equiv"] = True, + check_index_type: bool | Literal["equiv"] = "equiv", + check_column_type: bool | Literal["equiv"] = "equiv", + check_frame_type: bool = True, + check_names: bool = True, + by_blocks: bool = False, + check_exact: bool | lib.NoDefault = lib.no_default, + check_datetimelike_compat: bool = False, + check_categorical: bool = True, + check_like: bool = False, + check_freq: bool = True, + check_flags: bool = True, + rtol: float | lib.NoDefault = lib.no_default, + atol: float | lib.NoDefault = lib.no_default, + obj: str = "DataFrame", +) -> None: + """ + Check that left and right DataFrame are equal. + + This function is intended to compare two DataFrames and output any + differences. It is mostly intended for use in unit tests. + Additional parameters allow varying the strictness of the + equality checks performed. + + Parameters + ---------- + left : DataFrame + First DataFrame to compare. + right : DataFrame + Second DataFrame to compare. + check_dtype : bool, default True + Whether to check the DataFrame dtype is identical. + check_index_type : bool or {'equiv'}, default 'equiv' + Whether to check the Index class, dtype and inferred_type + are identical. + check_column_type : bool or {'equiv'}, default 'equiv' + Whether to check the columns class, dtype and inferred_type + are identical. Is passed as the ``exact`` argument of + :func:`assert_index_equal`. + check_frame_type : bool, default True + Whether to check the DataFrame class is identical. + check_names : bool, default True + Whether to check that the `names` attribute for both the `index` + and `column` attributes of the DataFrame is identical. + by_blocks : bool, default False + Specify how to compare internal data. If False, compare by columns. + If True, compare by blocks. + check_exact : bool, default False + Whether to compare number exactly. + + .. versionchanged:: 2.2.0 + + Defaults to True for integer dtypes if none of + ``check_exact``, ``rtol`` and ``atol`` are specified. + check_datetimelike_compat : bool, default False + Compare datetime-like which is comparable ignoring dtype. + check_categorical : bool, default True + Whether to compare internal Categorical exactly. + check_like : bool, default False + If True, ignore the order of index & columns. + Note: index labels must match their respective rows + (same as in columns) - same labels must be with the same data. + check_freq : bool, default True + Whether to check the `freq` attribute on a DatetimeIndex or TimedeltaIndex. + check_flags : bool, default True + Whether to check the `flags` attribute. + rtol : float, default 1e-5 + Relative tolerance. Only used when check_exact is False. + atol : float, default 1e-8 + Absolute tolerance. Only used when check_exact is False. + obj : str, default 'DataFrame' + Specify object name being compared, internally used to show appropriate + assertion message. + + See Also + -------- + assert_series_equal : Equivalent method for asserting Series equality. + DataFrame.equals : Check DataFrame equality. + + Examples + -------- + This example shows comparing two DataFrames that are equal + but with columns of differing dtypes. + + >>> from pandas.testing import assert_frame_equal + >>> df1 = pd.DataFrame({'a': [1, 2], 'b': [3, 4]}) + >>> df2 = pd.DataFrame({'a': [1, 2], 'b': [3.0, 4.0]}) + + df1 equals itself. + + >>> assert_frame_equal(df1, df1) + + df1 differs from df2 as column 'b' is of a different type. + + >>> assert_frame_equal(df1, df2) + Traceback (most recent call last): + ... + AssertionError: Attributes of DataFrame.iloc[:, 1] (column name="b") are different + + Attribute "dtype" are different + [left]: int64 + [right]: float64 + + Ignore differing dtypes in columns with check_dtype. + + >>> assert_frame_equal(df1, df2, check_dtype=False) + """ + __tracebackhide__ = True + _rtol = rtol if rtol is not lib.no_default else 1.0e-5 + _atol = atol if atol is not lib.no_default else 1.0e-8 + _check_exact = check_exact if check_exact is not lib.no_default else False + + # instance validation + _check_isinstance(left, right, DataFrame) + + if check_frame_type: + assert isinstance(left, type(right)) + # assert_class_equal(left, right, obj=obj) + + # shape comparison + if left.shape != right.shape: + raise_assert_detail( + obj, f"{obj} shape mismatch", f"{repr(left.shape)}", f"{repr(right.shape)}" + ) + + if check_flags: + assert left.flags == right.flags, f"{repr(left.flags)} != {repr(right.flags)}" + + # index comparison + assert_index_equal( + left.index, + right.index, + exact=check_index_type, + check_names=check_names, + check_exact=_check_exact, + check_categorical=check_categorical, + check_order=not check_like, + rtol=_rtol, + atol=_atol, + obj=f"{obj}.index", + ) + + # column comparison + assert_index_equal( + left.columns, + right.columns, + exact=check_column_type, + check_names=check_names, + check_exact=_check_exact, + check_categorical=check_categorical, + check_order=not check_like, + rtol=_rtol, + atol=_atol, + obj=f"{obj}.columns", + ) + + if check_like: + left = left.reindex_like(right) + + # compare by blocks + if by_blocks: + rblocks = right._to_dict_of_blocks() + lblocks = left._to_dict_of_blocks() + for dtype in list(set(list(lblocks.keys()) + list(rblocks.keys()))): + assert dtype in lblocks + assert dtype in rblocks + assert_frame_equal( + lblocks[dtype], rblocks[dtype], check_dtype=check_dtype, obj=obj + ) + + # compare by columns + else: + for i, col in enumerate(left.columns): + # We have already checked that columns match, so we can do + # fast location-based lookups + lcol = left._ixs(i, axis=1) + rcol = right._ixs(i, axis=1) + + # GH #38183 + # use check_index=False, because we do not want to run + # assert_index_equal for each column, + # as we already checked it for the whole dataframe before. + assert_series_equal( + lcol, + rcol, + check_dtype=check_dtype, + check_index_type=check_index_type, + check_exact=check_exact, + check_names=check_names, + check_datetimelike_compat=check_datetimelike_compat, + check_categorical=check_categorical, + check_freq=check_freq, + obj=f'{obj}.iloc[:, {i}] (column name="{col}")', + rtol=rtol, + atol=atol, + check_index=False, + check_flags=False, + ) + + +def assert_equal(left, right, **kwargs) -> None: + """ + Wrapper for tm.assert_*_equal to dispatch to the appropriate test function. + + Parameters + ---------- + left, right : Index, Series, DataFrame, ExtensionArray, or np.ndarray + The two items to be compared. + **kwargs + All keyword arguments are passed through to the underlying assert method. + """ + __tracebackhide__ = True + + if isinstance(left, Index): + assert_index_equal(left, right, **kwargs) + if isinstance(left, (DatetimeIndex, TimedeltaIndex)): + assert left.freq == right.freq, (left.freq, right.freq) + elif isinstance(left, Series): + assert_series_equal(left, right, **kwargs) + elif isinstance(left, DataFrame): + assert_frame_equal(left, right, **kwargs) + elif isinstance(left, IntervalArray): + assert_interval_array_equal(left, right, **kwargs) + elif isinstance(left, PeriodArray): + assert_period_array_equal(left, right, **kwargs) + elif isinstance(left, DatetimeArray): + assert_datetime_array_equal(left, right, **kwargs) + elif isinstance(left, TimedeltaArray): + assert_timedelta_array_equal(left, right, **kwargs) + elif isinstance(left, ExtensionArray): + assert_extension_array_equal(left, right, **kwargs) + elif isinstance(left, np.ndarray): + assert_numpy_array_equal(left, right, **kwargs) + elif isinstance(left, str): + assert kwargs == {} + assert left == right + else: + assert kwargs == {} + assert_almost_equal(left, right) + + +def assert_sp_array_equal(left, right) -> None: + """ + Check that the left and right SparseArray are equal. + + Parameters + ---------- + left : SparseArray + right : SparseArray + """ + _check_isinstance(left, right, pd.arrays.SparseArray) + + assert_numpy_array_equal(left.sp_values, right.sp_values) + + # SparseIndex comparison + assert isinstance(left.sp_index, SparseIndex) + assert isinstance(right.sp_index, SparseIndex) + + left_index = left.sp_index + right_index = right.sp_index + + if not left_index.equals(right_index): + raise_assert_detail( + "SparseArray.index", "index are not equal", left_index, right_index + ) + else: + # Just ensure a + pass + + assert_attr_equal("fill_value", left, right) + assert_attr_equal("dtype", left, right) + assert_numpy_array_equal(left.to_dense(), right.to_dense()) + + +def assert_contains_all(iterable, dic) -> None: + for k in iterable: + assert k in dic, f"Did not contain item: {repr(k)}" + + +def assert_copy(iter1, iter2, **eql_kwargs) -> None: + """ + iter1, iter2: iterables that produce elements + comparable with assert_almost_equal + + Checks that the elements are equal, but not + the same object. (Does not check that items + in sequences are also not the same object) + """ + for elem1, elem2 in zip(iter1, iter2): + assert_almost_equal(elem1, elem2, **eql_kwargs) + msg = ( + f"Expected object {repr(type(elem1))} and object {repr(type(elem2))} to be " + "different objects, but they were the same object." + ) + assert elem1 is not elem2, msg + + +def is_extension_array_dtype_and_needs_i8_conversion( + left_dtype: DtypeObj, right_dtype: DtypeObj +) -> bool: + """ + Checks that we have the combination of an ExtensionArraydtype and + a dtype that should be converted to int64 + + Returns + ------- + bool + + Related to issue #37609 + """ + return isinstance(left_dtype, ExtensionDtype) and needs_i8_conversion(right_dtype) + + +def assert_indexing_slices_equivalent(ser: Series, l_slc: slice, i_slc: slice) -> None: + """ + Check that ser.iloc[i_slc] matches ser.loc[l_slc] and, if applicable, + ser[l_slc]. + """ + expected = ser.iloc[i_slc] + + assert_series_equal(ser.loc[l_slc], expected) + + if not is_integer_dtype(ser.index): + # For integer indices, .loc and plain getitem are position-based. + assert_series_equal(ser[l_slc], expected) + + +def assert_metadata_equivalent( + left: DataFrame | Series, right: DataFrame | Series | None = None +) -> None: + """ + Check that ._metadata attributes are equivalent. + """ + for attr in left._metadata: + val = getattr(left, attr, None) + if right is None: + assert val is None + else: + assert val == getattr(right, attr, None) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/_testing/compat.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/_testing/compat.py new file mode 100644 index 0000000000000000000000000000000000000000..cc352ba7b8f2f5a5548d4d5749d3b48ac838aced --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/_testing/compat.py @@ -0,0 +1,29 @@ +""" +Helpers for sharing tests between DataFrame/Series +""" +from __future__ import annotations + +from typing import TYPE_CHECKING + +from pandas import DataFrame + +if TYPE_CHECKING: + from pandas._typing import DtypeObj + + +def get_dtype(obj) -> DtypeObj: + if isinstance(obj, DataFrame): + # Note: we are assuming only one column + return obj.dtypes.iat[0] + else: + return obj.dtype + + +def get_obj(df: DataFrame, klass): + """ + For sharing tests using frame_or_series, either return the DataFrame + unchanged or return it's first column as a Series. + """ + if klass is DataFrame: + return df + return df._ixs(0, axis=1) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/_testing/contexts.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/_testing/contexts.py new file mode 100644 index 0000000000000000000000000000000000000000..eb6e4a917889aef221b2fc08eb2723c4fe568e04 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/_testing/contexts.py @@ -0,0 +1,257 @@ +from __future__ import annotations + +from contextlib import contextmanager +import os +from pathlib import Path +import tempfile +from typing import ( + IO, + TYPE_CHECKING, + Any, +) +import uuid + +from pandas._config import using_copy_on_write + +from pandas.compat import PYPY +from pandas.errors import ChainedAssignmentError + +from pandas import set_option + +from pandas.io.common import get_handle + +if TYPE_CHECKING: + from collections.abc import Generator + + from pandas._typing import ( + BaseBuffer, + CompressionOptions, + FilePath, + ) + + +@contextmanager +def decompress_file( + path: FilePath | BaseBuffer, compression: CompressionOptions +) -> Generator[IO[bytes], None, None]: + """ + Open a compressed file and return a file object. + + Parameters + ---------- + path : str + The path where the file is read from. + + compression : {'gzip', 'bz2', 'zip', 'xz', 'zstd', None} + Name of the decompression to use + + Returns + ------- + file object + """ + with get_handle(path, "rb", compression=compression, is_text=False) as handle: + yield handle.handle + + +@contextmanager +def set_timezone(tz: str) -> Generator[None, None, None]: + """ + Context manager for temporarily setting a timezone. + + Parameters + ---------- + tz : str + A string representing a valid timezone. + + Examples + -------- + >>> from datetime import datetime + >>> from dateutil.tz import tzlocal + >>> tzlocal().tzname(datetime(2021, 1, 1)) # doctest: +SKIP + 'IST' + + >>> with set_timezone('US/Eastern'): + ... tzlocal().tzname(datetime(2021, 1, 1)) + ... + 'EST' + """ + import time + + def setTZ(tz) -> None: + if tz is None: + try: + del os.environ["TZ"] + except KeyError: + pass + else: + os.environ["TZ"] = tz + time.tzset() + + orig_tz = os.environ.get("TZ") + setTZ(tz) + try: + yield + finally: + setTZ(orig_tz) + + +@contextmanager +def ensure_clean( + filename=None, return_filelike: bool = False, **kwargs: Any +) -> Generator[Any, None, None]: + """ + Gets a temporary path and agrees to remove on close. + + This implementation does not use tempfile.mkstemp to avoid having a file handle. + If the code using the returned path wants to delete the file itself, windows + requires that no program has a file handle to it. + + Parameters + ---------- + filename : str (optional) + suffix of the created file. + return_filelike : bool (default False) + if True, returns a file-like which is *always* cleaned. Necessary for + savefig and other functions which want to append extensions. + **kwargs + Additional keywords are passed to open(). + + """ + folder = Path(tempfile.gettempdir()) + + if filename is None: + filename = "" + filename = str(uuid.uuid4()) + filename + path = folder / filename + + path.touch() + + handle_or_str: str | IO = str(path) + encoding = kwargs.pop("encoding", None) + if return_filelike: + kwargs.setdefault("mode", "w+b") + if encoding is None and "b" not in kwargs["mode"]: + encoding = "utf-8" + handle_or_str = open(path, encoding=encoding, **kwargs) + + try: + yield handle_or_str + finally: + if not isinstance(handle_or_str, str): + handle_or_str.close() + if path.is_file(): + path.unlink() + + +@contextmanager +def with_csv_dialect(name: str, **kwargs) -> Generator[None, None, None]: + """ + Context manager to temporarily register a CSV dialect for parsing CSV. + + Parameters + ---------- + name : str + The name of the dialect. + kwargs : mapping + The parameters for the dialect. + + Raises + ------ + ValueError : the name of the dialect conflicts with a builtin one. + + See Also + -------- + csv : Python's CSV library. + """ + import csv + + _BUILTIN_DIALECTS = {"excel", "excel-tab", "unix"} + + if name in _BUILTIN_DIALECTS: + raise ValueError("Cannot override builtin dialect.") + + csv.register_dialect(name, **kwargs) + try: + yield + finally: + csv.unregister_dialect(name) + + +@contextmanager +def use_numexpr(use, min_elements=None) -> Generator[None, None, None]: + from pandas.core.computation import expressions as expr + + if min_elements is None: + min_elements = expr._MIN_ELEMENTS + + olduse = expr.USE_NUMEXPR + oldmin = expr._MIN_ELEMENTS + set_option("compute.use_numexpr", use) + expr._MIN_ELEMENTS = min_elements + try: + yield + finally: + expr._MIN_ELEMENTS = oldmin + set_option("compute.use_numexpr", olduse) + + +def raises_chained_assignment_error(warn=True, extra_warnings=(), extra_match=()): + from pandas._testing import assert_produces_warning + + if not warn: + from contextlib import nullcontext + + return nullcontext() + + if PYPY and not extra_warnings: + from contextlib import nullcontext + + return nullcontext() + elif PYPY and extra_warnings: + return assert_produces_warning( + extra_warnings, + match="|".join(extra_match), + ) + else: + if using_copy_on_write(): + warning = ChainedAssignmentError + match = ( + "A value is trying to be set on a copy of a DataFrame or Series " + "through chained assignment" + ) + else: + warning = FutureWarning # type: ignore[assignment] + # TODO update match + match = "ChainedAssignmentError" + if extra_warnings: + warning = (warning, *extra_warnings) # type: ignore[assignment] + return assert_produces_warning( + warning, + match="|".join((match, *extra_match)), + ) + + +def assert_cow_warning(warn=True, match=None, **kwargs): + """ + Assert that a warning is raised in the CoW warning mode. + + Parameters + ---------- + warn : bool, default True + By default, check that a warning is raised. Can be turned off by passing False. + match : str + The warning message to match against, if different from the default. + kwargs + Passed through to assert_produces_warning + """ + from pandas._testing import assert_produces_warning + + if not warn: + from contextlib import nullcontext + + return nullcontext() + + if not match: + match = "Setting a value on a view" + + return assert_produces_warning(FutureWarning, match=match, **kwargs) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/arrays/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/arrays/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a11755275d00e070bea6ab73a881b98d0b976551 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/arrays/__init__.py @@ -0,0 +1,53 @@ +""" +All of pandas' ExtensionArrays. + +See :ref:`extending.extension-types` for more. +""" +from pandas.core.arrays import ( + ArrowExtensionArray, + ArrowStringArray, + BooleanArray, + Categorical, + DatetimeArray, + FloatingArray, + IntegerArray, + IntervalArray, + NumpyExtensionArray, + PeriodArray, + SparseArray, + StringArray, + TimedeltaArray, +) + +__all__ = [ + "ArrowExtensionArray", + "ArrowStringArray", + "BooleanArray", + "Categorical", + "DatetimeArray", + "FloatingArray", + "IntegerArray", + "IntervalArray", + "NumpyExtensionArray", + "PeriodArray", + "SparseArray", + "StringArray", + "TimedeltaArray", +] + + +def __getattr__(name: str) -> type[NumpyExtensionArray]: + if name == "PandasArray": + # GH#53694 + import warnings + + from pandas.util._exceptions import find_stack_level + + warnings.warn( + "PandasArray has been renamed NumpyExtensionArray. Use that " + "instead. This alias will be removed in a future version.", + FutureWarning, + stacklevel=find_stack_level(), + ) + return NumpyExtensionArray + raise AttributeError(f"module 'pandas.arrays' has no attribute '{name}'") diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/errors/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/errors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..01094ba36b9dd5f3414c32a9a4f832b85902e021 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/errors/__init__.py @@ -0,0 +1,850 @@ +""" +Expose public exceptions & warnings +""" +from __future__ import annotations + +import ctypes + +from pandas._config.config import OptionError + +from pandas._libs.tslibs import ( + OutOfBoundsDatetime, + OutOfBoundsTimedelta, +) + +from pandas.util.version import InvalidVersion + + +class IntCastingNaNError(ValueError): + """ + Exception raised when converting (``astype``) an array with NaN to an integer type. + + Examples + -------- + >>> pd.DataFrame(np.array([[1, np.nan], [2, 3]]), dtype="i8") + Traceback (most recent call last): + IntCastingNaNError: Cannot convert non-finite values (NA or inf) to integer + """ + + +class NullFrequencyError(ValueError): + """ + Exception raised when a ``freq`` cannot be null. + + Particularly ``DatetimeIndex.shift``, ``TimedeltaIndex.shift``, + ``PeriodIndex.shift``. + + Examples + -------- + >>> df = pd.DatetimeIndex(["2011-01-01 10:00", "2011-01-01"], freq=None) + >>> df.shift(2) + Traceback (most recent call last): + NullFrequencyError: Cannot shift with no freq + """ + + +class PerformanceWarning(Warning): + """ + Warning raised when there is a possible performance impact. + + Examples + -------- + >>> df = pd.DataFrame({"jim": [0, 0, 1, 1], + ... "joe": ["x", "x", "z", "y"], + ... "jolie": [1, 2, 3, 4]}) + >>> df = df.set_index(["jim", "joe"]) + >>> df + jolie + jim joe + 0 x 1 + x 2 + 1 z 3 + y 4 + >>> df.loc[(1, 'z')] # doctest: +SKIP + # PerformanceWarning: indexing past lexsort depth may impact performance. + df.loc[(1, 'z')] + jolie + jim joe + 1 z 3 + """ + + +class UnsupportedFunctionCall(ValueError): + """ + Exception raised when attempting to call a unsupported numpy function. + + For example, ``np.cumsum(groupby_object)``. + + Examples + -------- + >>> df = pd.DataFrame({"A": [0, 0, 1, 1], + ... "B": ["x", "x", "z", "y"], + ... "C": [1, 2, 3, 4]} + ... ) + >>> np.cumsum(df.groupby(["A"])) + Traceback (most recent call last): + UnsupportedFunctionCall: numpy operations are not valid with groupby. + Use .groupby(...).cumsum() instead + """ + + +class UnsortedIndexError(KeyError): + """ + Error raised when slicing a MultiIndex which has not been lexsorted. + + Subclass of `KeyError`. + + Examples + -------- + >>> df = pd.DataFrame({"cat": [0, 0, 1, 1], + ... "color": ["white", "white", "brown", "black"], + ... "lives": [4, 4, 3, 7]}, + ... ) + >>> df = df.set_index(["cat", "color"]) + >>> df + lives + cat color + 0 white 4 + white 4 + 1 brown 3 + black 7 + >>> df.loc[(0, "black"):(1, "white")] + Traceback (most recent call last): + UnsortedIndexError: 'Key length (2) was greater + than MultiIndex lexsort depth (1)' + """ + + +class ParserError(ValueError): + """ + Exception that is raised by an error encountered in parsing file contents. + + This is a generic error raised for errors encountered when functions like + `read_csv` or `read_html` are parsing contents of a file. + + See Also + -------- + read_csv : Read CSV (comma-separated) file into a DataFrame. + read_html : Read HTML table into a DataFrame. + + Examples + -------- + >>> data = '''a,b,c + ... cat,foo,bar + ... dog,foo,"baz''' + >>> from io import StringIO + >>> pd.read_csv(StringIO(data), skipfooter=1, engine='python') + Traceback (most recent call last): + ParserError: ',' expected after '"'. Error could possibly be due + to parsing errors in the skipped footer rows + """ + + +class DtypeWarning(Warning): + """ + Warning raised when reading different dtypes in a column from a file. + + Raised for a dtype incompatibility. This can happen whenever `read_csv` + or `read_table` encounter non-uniform dtypes in a column(s) of a given + CSV file. + + See Also + -------- + read_csv : Read CSV (comma-separated) file into a DataFrame. + read_table : Read general delimited file into a DataFrame. + + Notes + ----- + This warning is issued when dealing with larger files because the dtype + checking happens per chunk read. + + Despite the warning, the CSV file is read with mixed types in a single + column which will be an object type. See the examples below to better + understand this issue. + + Examples + -------- + This example creates and reads a large CSV file with a column that contains + `int` and `str`. + + >>> df = pd.DataFrame({'a': (['1'] * 100000 + ['X'] * 100000 + + ... ['1'] * 100000), + ... 'b': ['b'] * 300000}) # doctest: +SKIP + >>> df.to_csv('test.csv', index=False) # doctest: +SKIP + >>> df2 = pd.read_csv('test.csv') # doctest: +SKIP + ... # DtypeWarning: Columns (0) have mixed types + + Important to notice that ``df2`` will contain both `str` and `int` for the + same input, '1'. + + >>> df2.iloc[262140, 0] # doctest: +SKIP + '1' + >>> type(df2.iloc[262140, 0]) # doctest: +SKIP + + >>> df2.iloc[262150, 0] # doctest: +SKIP + 1 + >>> type(df2.iloc[262150, 0]) # doctest: +SKIP + + + One way to solve this issue is using the `dtype` parameter in the + `read_csv` and `read_table` functions to explicit the conversion: + + >>> df2 = pd.read_csv('test.csv', sep=',', dtype={'a': str}) # doctest: +SKIP + + No warning was issued. + """ + + +class EmptyDataError(ValueError): + """ + Exception raised in ``pd.read_csv`` when empty data or header is encountered. + + Examples + -------- + >>> from io import StringIO + >>> empty = StringIO() + >>> pd.read_csv(empty) + Traceback (most recent call last): + EmptyDataError: No columns to parse from file + """ + + +class ParserWarning(Warning): + """ + Warning raised when reading a file that doesn't use the default 'c' parser. + + Raised by `pd.read_csv` and `pd.read_table` when it is necessary to change + parsers, generally from the default 'c' parser to 'python'. + + It happens due to a lack of support or functionality for parsing a + particular attribute of a CSV file with the requested engine. + + Currently, 'c' unsupported options include the following parameters: + + 1. `sep` other than a single character (e.g. regex separators) + 2. `skipfooter` higher than 0 + 3. `sep=None` with `delim_whitespace=False` + + The warning can be avoided by adding `engine='python'` as a parameter in + `pd.read_csv` and `pd.read_table` methods. + + See Also + -------- + pd.read_csv : Read CSV (comma-separated) file into DataFrame. + pd.read_table : Read general delimited file into DataFrame. + + Examples + -------- + Using a `sep` in `pd.read_csv` other than a single character: + + >>> import io + >>> csv = '''a;b;c + ... 1;1,8 + ... 1;2,1''' + >>> df = pd.read_csv(io.StringIO(csv), sep='[;,]') # doctest: +SKIP + ... # ParserWarning: Falling back to the 'python' engine... + + Adding `engine='python'` to `pd.read_csv` removes the Warning: + + >>> df = pd.read_csv(io.StringIO(csv), sep='[;,]', engine='python') + """ + + +class MergeError(ValueError): + """ + Exception raised when merging data. + + Subclass of ``ValueError``. + + Examples + -------- + >>> left = pd.DataFrame({"a": ["a", "b", "b", "d"], + ... "b": ["cat", "dog", "weasel", "horse"]}, + ... index=range(4)) + >>> right = pd.DataFrame({"a": ["a", "b", "c", "d"], + ... "c": ["meow", "bark", "chirp", "nay"]}, + ... index=range(4)).set_index("a") + >>> left.join(right, on="a", validate="one_to_one",) + Traceback (most recent call last): + MergeError: Merge keys are not unique in left dataset; not a one-to-one merge + """ + + +class AbstractMethodError(NotImplementedError): + """ + Raise this error instead of NotImplementedError for abstract methods. + + Examples + -------- + >>> class Foo: + ... @classmethod + ... def classmethod(cls): + ... raise pd.errors.AbstractMethodError(cls, methodtype="classmethod") + ... def method(self): + ... raise pd.errors.AbstractMethodError(self) + >>> test = Foo.classmethod() + Traceback (most recent call last): + AbstractMethodError: This classmethod must be defined in the concrete class Foo + + >>> test2 = Foo().method() + Traceback (most recent call last): + AbstractMethodError: This classmethod must be defined in the concrete class Foo + """ + + def __init__(self, class_instance, methodtype: str = "method") -> None: + types = {"method", "classmethod", "staticmethod", "property"} + if methodtype not in types: + raise ValueError( + f"methodtype must be one of {methodtype}, got {types} instead." + ) + self.methodtype = methodtype + self.class_instance = class_instance + + def __str__(self) -> str: + if self.methodtype == "classmethod": + name = self.class_instance.__name__ + else: + name = type(self.class_instance).__name__ + return f"This {self.methodtype} must be defined in the concrete class {name}" + + +class NumbaUtilError(Exception): + """ + Error raised for unsupported Numba engine routines. + + Examples + -------- + >>> df = pd.DataFrame({"key": ["a", "a", "b", "b"], "data": [1, 2, 3, 4]}, + ... columns=["key", "data"]) + >>> def incorrect_function(x): + ... return sum(x) * 2.7 + >>> df.groupby("key").agg(incorrect_function, engine="numba") + Traceback (most recent call last): + NumbaUtilError: The first 2 arguments to incorrect_function + must be ['values', 'index'] + """ + + +class DuplicateLabelError(ValueError): + """ + Error raised when an operation would introduce duplicate labels. + + Examples + -------- + >>> s = pd.Series([0, 1, 2], index=['a', 'b', 'c']).set_flags( + ... allows_duplicate_labels=False + ... ) + >>> s.reindex(['a', 'a', 'b']) + Traceback (most recent call last): + ... + DuplicateLabelError: Index has duplicates. + positions + label + a [0, 1] + """ + + +class InvalidIndexError(Exception): + """ + Exception raised when attempting to use an invalid index key. + + Examples + -------- + >>> idx = pd.MultiIndex.from_product([["x", "y"], [0, 1]]) + >>> df = pd.DataFrame([[1, 1, 2, 2], + ... [3, 3, 4, 4]], columns=idx) + >>> df + x y + 0 1 0 1 + 0 1 1 2 2 + 1 3 3 4 4 + >>> df[:, 0] + Traceback (most recent call last): + InvalidIndexError: (slice(None, None, None), 0) + """ + + +class DataError(Exception): + """ + Exceptionn raised when performing an operation on non-numerical data. + + For example, calling ``ohlc`` on a non-numerical column or a function + on a rolling window. + + Examples + -------- + >>> ser = pd.Series(['a', 'b', 'c']) + >>> ser.rolling(2).sum() + Traceback (most recent call last): + DataError: No numeric types to aggregate + """ + + +class SpecificationError(Exception): + """ + Exception raised by ``agg`` when the functions are ill-specified. + + The exception raised in two scenarios. + + The first way is calling ``agg`` on a + Dataframe or Series using a nested renamer (dict-of-dict). + + The second way is calling ``agg`` on a Dataframe with duplicated functions + names without assigning column name. + + Examples + -------- + >>> df = pd.DataFrame({'A': [1, 1, 1, 2, 2], + ... 'B': range(5), + ... 'C': range(5)}) + >>> df.groupby('A').B.agg({'foo': 'count'}) # doctest: +SKIP + ... # SpecificationError: nested renamer is not supported + + >>> df.groupby('A').agg({'B': {'foo': ['sum', 'max']}}) # doctest: +SKIP + ... # SpecificationError: nested renamer is not supported + + >>> df.groupby('A').agg(['min', 'min']) # doctest: +SKIP + ... # SpecificationError: nested renamer is not supported + """ + + +class SettingWithCopyError(ValueError): + """ + Exception raised when trying to set on a copied slice from a ``DataFrame``. + + The ``mode.chained_assignment`` needs to be set to set to 'raise.' This can + happen unintentionally when chained indexing. + + For more information on evaluation order, + see :ref:`the user guide`. + + For more information on view vs. copy, + see :ref:`the user guide`. + + Examples + -------- + >>> pd.options.mode.chained_assignment = 'raise' + >>> df = pd.DataFrame({'A': [1, 1, 1, 2, 2]}, columns=['A']) + >>> df.loc[0:3]['A'] = 'a' # doctest: +SKIP + ... # SettingWithCopyError: A value is trying to be set on a copy of a... + """ + + +class SettingWithCopyWarning(Warning): + """ + Warning raised when trying to set on a copied slice from a ``DataFrame``. + + The ``mode.chained_assignment`` needs to be set to set to 'warn.' + 'Warn' is the default option. This can happen unintentionally when + chained indexing. + + For more information on evaluation order, + see :ref:`the user guide`. + + For more information on view vs. copy, + see :ref:`the user guide`. + + Examples + -------- + >>> df = pd.DataFrame({'A': [1, 1, 1, 2, 2]}, columns=['A']) + >>> df.loc[0:3]['A'] = 'a' # doctest: +SKIP + ... # SettingWithCopyWarning: A value is trying to be set on a copy of a... + """ + + +class ChainedAssignmentError(Warning): + """ + Warning raised when trying to set using chained assignment. + + When the ``mode.copy_on_write`` option is enabled, chained assignment can + never work. In such a situation, we are always setting into a temporary + object that is the result of an indexing operation (getitem), which under + Copy-on-Write always behaves as a copy. Thus, assigning through a chain + can never update the original Series or DataFrame. + + For more information on view vs. copy, + see :ref:`the user guide`. + + Examples + -------- + >>> pd.options.mode.copy_on_write = True + >>> df = pd.DataFrame({'A': [1, 1, 1, 2, 2]}, columns=['A']) + >>> df["A"][0:3] = 10 # doctest: +SKIP + ... # ChainedAssignmentError: ... + >>> pd.options.mode.copy_on_write = False + """ + + +_chained_assignment_msg = ( + "A value is trying to be set on a copy of a DataFrame or Series " + "through chained assignment.\n" + "When using the Copy-on-Write mode, such chained assignment never works " + "to update the original DataFrame or Series, because the intermediate " + "object on which we are setting values always behaves as a copy.\n\n" + "Try using '.loc[row_indexer, col_indexer] = value' instead, to perform " + "the assignment in a single step.\n\n" + "See the caveats in the documentation: " + "https://pandas.pydata.org/pandas-docs/stable/user_guide/" + "indexing.html#returning-a-view-versus-a-copy" +) + + +_chained_assignment_method_msg = ( + "A value is trying to be set on a copy of a DataFrame or Series " + "through chained assignment using an inplace method.\n" + "When using the Copy-on-Write mode, such inplace method never works " + "to update the original DataFrame or Series, because the intermediate " + "object on which we are setting values always behaves as a copy.\n\n" + "For example, when doing 'df[col].method(value, inplace=True)', try " + "using 'df.method({col: value}, inplace=True)' instead, to perform " + "the operation inplace on the original object.\n\n" +) + + +_chained_assignment_warning_msg = ( + "ChainedAssignmentError: behaviour will change in pandas 3.0!\n" + "You are setting values through chained assignment. Currently this works " + "in certain cases, but when using Copy-on-Write (which will become the " + "default behaviour in pandas 3.0) this will never work to update the " + "original DataFrame or Series, because the intermediate object on which " + "we are setting values will behave as a copy.\n" + "A typical example is when you are setting values in a column of a " + "DataFrame, like:\n\n" + 'df["col"][row_indexer] = value\n\n' + 'Use `df.loc[row_indexer, "col"] = values` instead, to perform the ' + "assignment in a single step and ensure this keeps updating the original `df`.\n\n" + "See the caveats in the documentation: " + "https://pandas.pydata.org/pandas-docs/stable/user_guide/" + "indexing.html#returning-a-view-versus-a-copy\n" +) + + +_chained_assignment_warning_method_msg = ( + "A value is trying to be set on a copy of a DataFrame or Series " + "through chained assignment using an inplace method.\n" + "The behavior will change in pandas 3.0. This inplace method will " + "never work because the intermediate object on which we are setting " + "values always behaves as a copy.\n\n" + "For example, when doing 'df[col].method(value, inplace=True)', try " + "using 'df.method({col: value}, inplace=True)' or " + "df[col] = df[col].method(value) instead, to perform " + "the operation inplace on the original object.\n\n" +) + + +def _check_cacher(obj): + # This is a mess, selection paths that return a view set the _cacher attribute + # on the Series; most of them also set _item_cache which adds 1 to our relevant + # reference count, but iloc does not, so we have to check if we are actually + # in the item cache + if hasattr(obj, "_cacher"): + parent = obj._cacher[1]() + # parent could be dead + if parent is None: + return False + if hasattr(parent, "_item_cache"): + if obj._cacher[0] in parent._item_cache: + # Check if we are actually the item from item_cache, iloc creates a + # new object + return obj is parent._item_cache[obj._cacher[0]] + return False + + +class NumExprClobberingError(NameError): + """ + Exception raised when trying to use a built-in numexpr name as a variable name. + + ``eval`` or ``query`` will throw the error if the engine is set + to 'numexpr'. 'numexpr' is the default engine value for these methods if the + numexpr package is installed. + + Examples + -------- + >>> df = pd.DataFrame({'abs': [1, 1, 1]}) + >>> df.query("abs > 2") # doctest: +SKIP + ... # NumExprClobberingError: Variables in expression "(abs) > (2)" overlap... + >>> sin, a = 1, 2 + >>> pd.eval("sin + a", engine='numexpr') # doctest: +SKIP + ... # NumExprClobberingError: Variables in expression "(sin) + (a)" overlap... + """ + + +class UndefinedVariableError(NameError): + """ + Exception raised by ``query`` or ``eval`` when using an undefined variable name. + + It will also specify whether the undefined variable is local or not. + + Examples + -------- + >>> df = pd.DataFrame({'A': [1, 1, 1]}) + >>> df.query("A > x") # doctest: +SKIP + ... # UndefinedVariableError: name 'x' is not defined + >>> df.query("A > @y") # doctest: +SKIP + ... # UndefinedVariableError: local variable 'y' is not defined + >>> pd.eval('x + 1') # doctest: +SKIP + ... # UndefinedVariableError: name 'x' is not defined + """ + + def __init__(self, name: str, is_local: bool | None = None) -> None: + base_msg = f"{repr(name)} is not defined" + if is_local: + msg = f"local variable {base_msg}" + else: + msg = f"name {base_msg}" + super().__init__(msg) + + +class IndexingError(Exception): + """ + Exception is raised when trying to index and there is a mismatch in dimensions. + + Examples + -------- + >>> df = pd.DataFrame({'A': [1, 1, 1]}) + >>> df.loc[..., ..., 'A'] # doctest: +SKIP + ... # IndexingError: indexer may only contain one '...' entry + >>> df = pd.DataFrame({'A': [1, 1, 1]}) + >>> df.loc[1, ..., ...] # doctest: +SKIP + ... # IndexingError: Too many indexers + >>> df[pd.Series([True], dtype=bool)] # doctest: +SKIP + ... # IndexingError: Unalignable boolean Series provided as indexer... + >>> s = pd.Series(range(2), + ... index = pd.MultiIndex.from_product([["a", "b"], ["c"]])) + >>> s.loc["a", "c", "d"] # doctest: +SKIP + ... # IndexingError: Too many indexers + """ + + +class PyperclipException(RuntimeError): + """ + Exception raised when clipboard functionality is unsupported. + + Raised by ``to_clipboard()`` and ``read_clipboard()``. + """ + + +class PyperclipWindowsException(PyperclipException): + """ + Exception raised when clipboard functionality is unsupported by Windows. + + Access to the clipboard handle would be denied due to some other + window process is accessing it. + """ + + def __init__(self, message: str) -> None: + # attr only exists on Windows, so typing fails on other platforms + message += f" ({ctypes.WinError()})" # type: ignore[attr-defined] + super().__init__(message) + + +class CSSWarning(UserWarning): + """ + Warning is raised when converting css styling fails. + + This can be due to the styling not having an equivalent value or because the + styling isn't properly formatted. + + Examples + -------- + >>> df = pd.DataFrame({'A': [1, 1, 1]}) + >>> df.style.applymap( + ... lambda x: 'background-color: blueGreenRed;' + ... ).to_excel('styled.xlsx') # doctest: +SKIP + CSSWarning: Unhandled color format: 'blueGreenRed' + >>> df.style.applymap( + ... lambda x: 'border: 1px solid red red;' + ... ).to_excel('styled.xlsx') # doctest: +SKIP + CSSWarning: Unhandled color format: 'blueGreenRed' + """ + + +class PossibleDataLossError(Exception): + """ + Exception raised when trying to open a HDFStore file when already opened. + + Examples + -------- + >>> store = pd.HDFStore('my-store', 'a') # doctest: +SKIP + >>> store.open("w") # doctest: +SKIP + ... # PossibleDataLossError: Re-opening the file [my-store] with mode [a]... + """ + + +class ClosedFileError(Exception): + """ + Exception is raised when trying to perform an operation on a closed HDFStore file. + + Examples + -------- + >>> store = pd.HDFStore('my-store', 'a') # doctest: +SKIP + >>> store.close() # doctest: +SKIP + >>> store.keys() # doctest: +SKIP + ... # ClosedFileError: my-store file is not open! + """ + + +class IncompatibilityWarning(Warning): + """ + Warning raised when trying to use where criteria on an incompatible HDF5 file. + """ + + +class AttributeConflictWarning(Warning): + """ + Warning raised when index attributes conflict when using HDFStore. + + Occurs when attempting to append an index with a different + name than the existing index on an HDFStore or attempting to append an index with a + different frequency than the existing index on an HDFStore. + + Examples + -------- + >>> idx1 = pd.Index(['a', 'b'], name='name1') + >>> df1 = pd.DataFrame([[1, 2], [3, 4]], index=idx1) + >>> df1.to_hdf('file', 'data', 'w', append=True) # doctest: +SKIP + >>> idx2 = pd.Index(['c', 'd'], name='name2') + >>> df2 = pd.DataFrame([[5, 6], [7, 8]], index=idx2) + >>> df2.to_hdf('file', 'data', 'a', append=True) # doctest: +SKIP + AttributeConflictWarning: the [index_name] attribute of the existing index is + [name1] which conflicts with the new [name2]... + """ + + +class DatabaseError(OSError): + """ + Error is raised when executing sql with bad syntax or sql that throws an error. + + Examples + -------- + >>> from sqlite3 import connect + >>> conn = connect(':memory:') + >>> pd.read_sql('select * test', conn) # doctest: +SKIP + ... # DatabaseError: Execution failed on sql 'test': near "test": syntax error + """ + + +class PossiblePrecisionLoss(Warning): + """ + Warning raised by to_stata on a column with a value outside or equal to int64. + + When the column value is outside or equal to the int64 value the column is + converted to a float64 dtype. + + Examples + -------- + >>> df = pd.DataFrame({"s": pd.Series([1, 2**53], dtype=np.int64)}) + >>> df.to_stata('test') # doctest: +SKIP + ... # PossiblePrecisionLoss: Column converted from int64 to float64... + """ + + +class ValueLabelTypeMismatch(Warning): + """ + Warning raised by to_stata on a category column that contains non-string values. + + Examples + -------- + >>> df = pd.DataFrame({"categories": pd.Series(["a", 2], dtype="category")}) + >>> df.to_stata('test') # doctest: +SKIP + ... # ValueLabelTypeMismatch: Stata value labels (pandas categories) must be str... + """ + + +class InvalidColumnName(Warning): + """ + Warning raised by to_stata the column contains a non-valid stata name. + + Because the column name is an invalid Stata variable, the name needs to be + converted. + + Examples + -------- + >>> df = pd.DataFrame({"0categories": pd.Series([2, 2])}) + >>> df.to_stata('test') # doctest: +SKIP + ... # InvalidColumnName: Not all pandas column names were valid Stata variable... + """ + + +class CategoricalConversionWarning(Warning): + """ + Warning is raised when reading a partial labeled Stata file using a iterator. + + Examples + -------- + >>> from pandas.io.stata import StataReader + >>> with StataReader('dta_file', chunksize=2) as reader: # doctest: +SKIP + ... for i, block in enumerate(reader): + ... print(i, block) + ... # CategoricalConversionWarning: One or more series with value labels... + """ + + +class LossySetitemError(Exception): + """ + Raised when trying to do a __setitem__ on an np.ndarray that is not lossless. + + Notes + ----- + This is an internal error. + """ + + +class NoBufferPresent(Exception): + """ + Exception is raised in _get_data_buffer to signal that there is no requested buffer. + """ + + +class InvalidComparison(Exception): + """ + Exception is raised by _validate_comparison_value to indicate an invalid comparison. + + Notes + ----- + This is an internal error. + """ + + +__all__ = [ + "AbstractMethodError", + "AttributeConflictWarning", + "CategoricalConversionWarning", + "ClosedFileError", + "CSSWarning", + "DatabaseError", + "DataError", + "DtypeWarning", + "DuplicateLabelError", + "EmptyDataError", + "IncompatibilityWarning", + "IntCastingNaNError", + "InvalidColumnName", + "InvalidComparison", + "InvalidIndexError", + "InvalidVersion", + "IndexingError", + "LossySetitemError", + "MergeError", + "NoBufferPresent", + "NullFrequencyError", + "NumbaUtilError", + "NumExprClobberingError", + "OptionError", + "OutOfBoundsDatetime", + "OutOfBoundsTimedelta", + "ParserError", + "ParserWarning", + "PerformanceWarning", + "PossibleDataLossError", + "PossiblePrecisionLoss", + "PyperclipException", + "PyperclipWindowsException", + "SettingWithCopyError", + "SettingWithCopyWarning", + "SpecificationError", + "UndefinedVariableError", + "UnsortedIndexError", + "UnsupportedFunctionCall", + "ValueLabelTypeMismatch", +] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/io/feather_format.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/io/feather_format.py new file mode 100644 index 0000000000000000000000000000000000000000..d0aaf83b84cb241ebdd872c1c8b7982fadc9acdb --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/io/feather_format.py @@ -0,0 +1,143 @@ +""" feather-format compat """ +from __future__ import annotations + +from typing import ( + TYPE_CHECKING, + Any, +) + +from pandas._config import using_pyarrow_string_dtype + +from pandas._libs import lib +from pandas.compat._optional import import_optional_dependency +from pandas.util._decorators import doc +from pandas.util._validators import check_dtype_backend + +import pandas as pd +from pandas.core.api import DataFrame +from pandas.core.shared_docs import _shared_docs + +from pandas.io._util import arrow_string_types_mapper +from pandas.io.common import get_handle + +if TYPE_CHECKING: + from collections.abc import ( + Hashable, + Sequence, + ) + + from pandas._typing import ( + DtypeBackend, + FilePath, + ReadBuffer, + StorageOptions, + WriteBuffer, + ) + + +@doc(storage_options=_shared_docs["storage_options"]) +def to_feather( + df: DataFrame, + path: FilePath | WriteBuffer[bytes], + storage_options: StorageOptions | None = None, + **kwargs: Any, +) -> None: + """ + Write a DataFrame to the binary Feather format. + + Parameters + ---------- + df : DataFrame + path : str, path object, or file-like object + {storage_options} + **kwargs : + Additional keywords passed to `pyarrow.feather.write_feather`. + + """ + import_optional_dependency("pyarrow") + from pyarrow import feather + + if not isinstance(df, DataFrame): + raise ValueError("feather only support IO with DataFrames") + + with get_handle( + path, "wb", storage_options=storage_options, is_text=False + ) as handles: + feather.write_feather(df, handles.handle, **kwargs) + + +@doc(storage_options=_shared_docs["storage_options"]) +def read_feather( + path: FilePath | ReadBuffer[bytes], + columns: Sequence[Hashable] | None = None, + use_threads: bool = True, + storage_options: StorageOptions | None = None, + dtype_backend: DtypeBackend | lib.NoDefault = lib.no_default, +) -> DataFrame: + """ + Load a feather-format object from the file path. + + Parameters + ---------- + path : str, path object, or file-like object + String, path object (implementing ``os.PathLike[str]``), or file-like + object implementing a binary ``read()`` function. The string could be a URL. + Valid URL schemes include http, ftp, s3, and file. For file URLs, a host is + expected. A local file could be: ``file://localhost/path/to/table.feather``. + columns : sequence, default None + If not provided, all columns are read. + use_threads : bool, default True + Whether to parallelize reading using multiple threads. + {storage_options} + + dtype_backend : {{'numpy_nullable', 'pyarrow'}}, default 'numpy_nullable' + Back-end data type applied to the resultant :class:`DataFrame` + (still experimental). Behaviour is as follows: + + * ``"numpy_nullable"``: returns nullable-dtype-backed :class:`DataFrame` + (default). + * ``"pyarrow"``: returns pyarrow-backed nullable :class:`ArrowDtype` + DataFrame. + + .. versionadded:: 2.0 + + Returns + ------- + type of object stored in file + + Examples + -------- + >>> df = pd.read_feather("path/to/file.feather") # doctest: +SKIP + """ + import_optional_dependency("pyarrow") + from pyarrow import feather + + # import utils to register the pyarrow extension types + import pandas.core.arrays.arrow.extension_types # pyright: ignore[reportUnusedImport] # noqa: F401 + + check_dtype_backend(dtype_backend) + + with get_handle( + path, "rb", storage_options=storage_options, is_text=False + ) as handles: + if dtype_backend is lib.no_default and not using_pyarrow_string_dtype(): + return feather.read_feather( + handles.handle, columns=columns, use_threads=bool(use_threads) + ) + + pa_table = feather.read_table( + handles.handle, columns=columns, use_threads=bool(use_threads) + ) + + if dtype_backend == "numpy_nullable": + from pandas.io._util import _arrow_dtype_mapping + + return pa_table.to_pandas(types_mapper=_arrow_dtype_mapping().get) + + elif dtype_backend == "pyarrow": + return pa_table.to_pandas(types_mapper=pd.ArrowDtype) + + elif using_pyarrow_string_dtype(): + return pa_table.to_pandas(types_mapper=arrow_string_types_mapper()) + else: + raise NotImplementedError diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/io/gbq.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/io/gbq.py new file mode 100644 index 0000000000000000000000000000000000000000..24e4e0b7cef0a5fa66a70fa0ad70b52364b02091 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/io/gbq.py @@ -0,0 +1,255 @@ +""" Google BigQuery support """ +from __future__ import annotations + +from typing import ( + TYPE_CHECKING, + Any, +) +import warnings + +from pandas.compat._optional import import_optional_dependency +from pandas.util._exceptions import find_stack_level + +if TYPE_CHECKING: + from google.auth.credentials import Credentials + + from pandas import DataFrame + + +def _try_import(): + # since pandas is a dependency of pandas-gbq + # we need to import on first use + msg = ( + "pandas-gbq is required to load data from Google BigQuery. " + "See the docs: https://pandas-gbq.readthedocs.io." + ) + pandas_gbq = import_optional_dependency("pandas_gbq", extra=msg) + return pandas_gbq + + +def read_gbq( + query: str, + project_id: str | None = None, + index_col: str | None = None, + col_order: list[str] | None = None, + reauth: bool = False, + auth_local_webserver: bool = True, + dialect: str | None = None, + location: str | None = None, + configuration: dict[str, Any] | None = None, + credentials: Credentials | None = None, + use_bqstorage_api: bool | None = None, + max_results: int | None = None, + progress_bar_type: str | None = None, +) -> DataFrame: + """ + Load data from Google BigQuery. + + .. deprecated:: 2.2.0 + + Please use ``pandas_gbq.read_gbq`` instead. + + This function requires the `pandas-gbq package + `__. + + See the `How to authenticate with Google BigQuery + `__ + guide for authentication instructions. + + Parameters + ---------- + query : str + SQL-Like Query to return data values. + project_id : str, optional + Google BigQuery Account project ID. Optional when available from + the environment. + index_col : str, optional + Name of result column to use for index in results DataFrame. + col_order : list(str), optional + List of BigQuery column names in the desired order for results + DataFrame. + reauth : bool, default False + Force Google BigQuery to re-authenticate the user. This is useful + if multiple accounts are used. + auth_local_webserver : bool, default True + Use the `local webserver flow`_ instead of the `console flow`_ + when getting user credentials. + + .. _local webserver flow: + https://google-auth-oauthlib.readthedocs.io/en/latest/reference/google_auth_oauthlib.flow.html#google_auth_oauthlib.flow.InstalledAppFlow.run_local_server + .. _console flow: + https://google-auth-oauthlib.readthedocs.io/en/latest/reference/google_auth_oauthlib.flow.html#google_auth_oauthlib.flow.InstalledAppFlow.run_console + + *New in version 0.2.0 of pandas-gbq*. + + .. versionchanged:: 1.5.0 + Default value is changed to ``True``. Google has deprecated the + ``auth_local_webserver = False`` `"out of band" (copy-paste) + flow + `_. + dialect : str, default 'legacy' + Note: The default value is changing to 'standard' in a future version. + + SQL syntax dialect to use. Value can be one of: + + ``'legacy'`` + Use BigQuery's legacy SQL dialect. For more information see + `BigQuery Legacy SQL Reference + `__. + ``'standard'`` + Use BigQuery's standard SQL, which is + compliant with the SQL 2011 standard. For more information + see `BigQuery Standard SQL Reference + `__. + location : str, optional + Location where the query job should run. See the `BigQuery locations + documentation + `__ for a + list of available locations. The location must match that of any + datasets used in the query. + + *New in version 0.5.0 of pandas-gbq*. + configuration : dict, optional + Query config parameters for job processing. + For example: + + configuration = {'query': {'useQueryCache': False}} + + For more information see `BigQuery REST API Reference + `__. + credentials : google.auth.credentials.Credentials, optional + Credentials for accessing Google APIs. Use this parameter to override + default credentials, such as to use Compute Engine + :class:`google.auth.compute_engine.Credentials` or Service Account + :class:`google.oauth2.service_account.Credentials` directly. + + *New in version 0.8.0 of pandas-gbq*. + use_bqstorage_api : bool, default False + Use the `BigQuery Storage API + `__ to + download query results quickly, but at an increased cost. To use this + API, first `enable it in the Cloud Console + `__. + You must also have the `bigquery.readsessions.create + `__ + permission on the project you are billing queries to. + + This feature requires version 0.10.0 or later of the ``pandas-gbq`` + package. It also requires the ``google-cloud-bigquery-storage`` and + ``fastavro`` packages. + + max_results : int, optional + If set, limit the maximum number of rows to fetch from the query + results. + + progress_bar_type : Optional, str + If set, use the `tqdm `__ library to + display a progress bar while the data downloads. Install the + ``tqdm`` package to use this feature. + + Possible values of ``progress_bar_type`` include: + + ``None`` + No progress bar. + ``'tqdm'`` + Use the :func:`tqdm.tqdm` function to print a progress bar + to :data:`sys.stderr`. + ``'tqdm_notebook'`` + Use the :func:`tqdm.tqdm_notebook` function to display a + progress bar as a Jupyter notebook widget. + ``'tqdm_gui'`` + Use the :func:`tqdm.tqdm_gui` function to display a + progress bar as a graphical dialog box. + + Returns + ------- + df: DataFrame + DataFrame representing results of query. + + See Also + -------- + pandas_gbq.read_gbq : This function in the pandas-gbq library. + DataFrame.to_gbq : Write a DataFrame to Google BigQuery. + + Examples + -------- + Example taken from `Google BigQuery documentation + `_ + + >>> sql = "SELECT name FROM table_name WHERE state = 'TX' LIMIT 100;" + >>> df = pd.read_gbq(sql, dialect="standard") # doctest: +SKIP + >>> project_id = "your-project-id" # doctest: +SKIP + >>> df = pd.read_gbq(sql, + ... project_id=project_id, + ... dialect="standard" + ... ) # doctest: +SKIP + """ + warnings.warn( + "read_gbq is deprecated and will be removed in a future version. " + "Please use pandas_gbq.read_gbq instead: " + "https://pandas-gbq.readthedocs.io/en/latest/api.html#pandas_gbq.read_gbq", + FutureWarning, + stacklevel=find_stack_level(), + ) + pandas_gbq = _try_import() + + kwargs: dict[str, str | bool | int | None] = {} + + # START: new kwargs. Don't populate unless explicitly set. + if use_bqstorage_api is not None: + kwargs["use_bqstorage_api"] = use_bqstorage_api + if max_results is not None: + kwargs["max_results"] = max_results + + kwargs["progress_bar_type"] = progress_bar_type + # END: new kwargs + + return pandas_gbq.read_gbq( + query, + project_id=project_id, + index_col=index_col, + col_order=col_order, + reauth=reauth, + auth_local_webserver=auth_local_webserver, + dialect=dialect, + location=location, + configuration=configuration, + credentials=credentials, + **kwargs, + ) + + +def to_gbq( + dataframe: DataFrame, + destination_table: str, + project_id: str | None = None, + chunksize: int | None = None, + reauth: bool = False, + if_exists: str = "fail", + auth_local_webserver: bool = True, + table_schema: list[dict[str, str]] | None = None, + location: str | None = None, + progress_bar: bool = True, + credentials: Credentials | None = None, +) -> None: + warnings.warn( + "to_gbq is deprecated and will be removed in a future version. " + "Please use pandas_gbq.to_gbq instead: " + "https://pandas-gbq.readthedocs.io/en/latest/api.html#pandas_gbq.to_gbq", + FutureWarning, + stacklevel=find_stack_level(), + ) + pandas_gbq = _try_import() + pandas_gbq.to_gbq( + dataframe, + destination_table, + project_id=project_id, + chunksize=chunksize, + reauth=reauth, + if_exists=if_exists, + auth_local_webserver=auth_local_webserver, + table_schema=table_schema, + location=location, + progress_bar=progress_bar, + credentials=credentials, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/io/html.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/io/html.py new file mode 100644 index 0000000000000000000000000000000000000000..4eeeb1b655f8ac55309edeacd593f5a5c2516678 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/io/html.py @@ -0,0 +1,1259 @@ +""" +:mod:`pandas.io.html` is a module containing functionality for dealing with +HTML IO. + +""" + +from __future__ import annotations + +from collections import abc +import numbers +import re +from re import Pattern +from typing import ( + TYPE_CHECKING, + Literal, + cast, +) +import warnings + +from pandas._libs import lib +from pandas.compat._optional import import_optional_dependency +from pandas.errors import ( + AbstractMethodError, + EmptyDataError, +) +from pandas.util._decorators import doc +from pandas.util._exceptions import find_stack_level +from pandas.util._validators import check_dtype_backend + +from pandas.core.dtypes.common import is_list_like + +from pandas import isna +from pandas.core.indexes.base import Index +from pandas.core.indexes.multi import MultiIndex +from pandas.core.series import Series +from pandas.core.shared_docs import _shared_docs + +from pandas.io.common import ( + file_exists, + get_handle, + is_file_like, + is_fsspec_url, + is_url, + stringify_path, + validate_header_arg, +) +from pandas.io.formats.printing import pprint_thing +from pandas.io.parsers import TextParser + +if TYPE_CHECKING: + from collections.abc import ( + Iterable, + Sequence, + ) + + from pandas._typing import ( + BaseBuffer, + DtypeBackend, + FilePath, + HTMLFlavors, + ReadBuffer, + StorageOptions, + ) + + from pandas import DataFrame + +############# +# READ HTML # +############# +_RE_WHITESPACE = re.compile(r"[\r\n]+|\s{2,}") + + +def _remove_whitespace(s: str, regex: Pattern = _RE_WHITESPACE) -> str: + """ + Replace extra whitespace inside of a string with a single space. + + Parameters + ---------- + s : str or unicode + The string from which to remove extra whitespace. + regex : re.Pattern + The regular expression to use to remove extra whitespace. + + Returns + ------- + subd : str or unicode + `s` with all extra whitespace replaced with a single space. + """ + return regex.sub(" ", s.strip()) + + +def _get_skiprows(skiprows: int | Sequence[int] | slice | None) -> int | Sequence[int]: + """ + Get an iterator given an integer, slice or container. + + Parameters + ---------- + skiprows : int, slice, container + The iterator to use to skip rows; can also be a slice. + + Raises + ------ + TypeError + * If `skiprows` is not a slice, integer, or Container + + Returns + ------- + it : iterable + A proper iterator to use to skip rows of a DataFrame. + """ + if isinstance(skiprows, slice): + start, step = skiprows.start or 0, skiprows.step or 1 + return list(range(start, skiprows.stop, step)) + elif isinstance(skiprows, numbers.Integral) or is_list_like(skiprows): + return cast("int | Sequence[int]", skiprows) + elif skiprows is None: + return 0 + raise TypeError(f"{type(skiprows).__name__} is not a valid type for skipping rows") + + +def _read( + obj: FilePath | BaseBuffer, + encoding: str | None, + storage_options: StorageOptions | None, +) -> str | bytes: + """ + Try to read from a url, file or string. + + Parameters + ---------- + obj : str, unicode, path object, or file-like object + + Returns + ------- + raw_text : str + """ + text: str | bytes + if ( + is_url(obj) + or hasattr(obj, "read") + or (isinstance(obj, str) and file_exists(obj)) + ): + with get_handle( + obj, "r", encoding=encoding, storage_options=storage_options + ) as handles: + text = handles.handle.read() + elif isinstance(obj, (str, bytes)): + text = obj + else: + raise TypeError(f"Cannot read object of type '{type(obj).__name__}'") + return text + + +class _HtmlFrameParser: + """ + Base class for parsers that parse HTML into DataFrames. + + Parameters + ---------- + io : str or file-like + This can be either a string of raw HTML, a valid URL using the HTTP, + FTP, or FILE protocols or a file-like object. + + match : str or regex + The text to match in the document. + + attrs : dict + List of HTML element attributes to match. + + encoding : str + Encoding to be used by parser + + displayed_only : bool + Whether or not items with "display:none" should be ignored + + extract_links : {None, "all", "header", "body", "footer"} + Table elements in the specified section(s) with tags will have their + href extracted. + + .. versionadded:: 1.5.0 + + Attributes + ---------- + io : str or file-like + raw HTML, URL, or file-like object + + match : regex + The text to match in the raw HTML + + attrs : dict-like + A dictionary of valid table attributes to use to search for table + elements. + + encoding : str + Encoding to be used by parser + + displayed_only : bool + Whether or not items with "display:none" should be ignored + + extract_links : {None, "all", "header", "body", "footer"} + Table elements in the specified section(s) with tags will have their + href extracted. + + .. versionadded:: 1.5.0 + + Notes + ----- + To subclass this class effectively you must override the following methods: + * :func:`_build_doc` + * :func:`_attr_getter` + * :func:`_href_getter` + * :func:`_text_getter` + * :func:`_parse_td` + * :func:`_parse_thead_tr` + * :func:`_parse_tbody_tr` + * :func:`_parse_tfoot_tr` + * :func:`_parse_tables` + * :func:`_equals_tag` + See each method's respective documentation for details on their + functionality. + """ + + def __init__( + self, + io: FilePath | ReadBuffer[str] | ReadBuffer[bytes], + match: str | Pattern, + attrs: dict[str, str] | None, + encoding: str, + displayed_only: bool, + extract_links: Literal[None, "header", "footer", "body", "all"], + storage_options: StorageOptions = None, + ) -> None: + self.io = io + self.match = match + self.attrs = attrs + self.encoding = encoding + self.displayed_only = displayed_only + self.extract_links = extract_links + self.storage_options = storage_options + + def parse_tables(self): + """ + Parse and return all tables from the DOM. + + Returns + ------- + list of parsed (header, body, footer) tuples from tables. + """ + tables = self._parse_tables(self._build_doc(), self.match, self.attrs) + return (self._parse_thead_tbody_tfoot(table) for table in tables) + + def _attr_getter(self, obj, attr): + """ + Return the attribute value of an individual DOM node. + + Parameters + ---------- + obj : node-like + A DOM node. + + attr : str or unicode + The attribute, such as "colspan" + + Returns + ------- + str or unicode + The attribute value. + """ + # Both lxml and BeautifulSoup have the same implementation: + return obj.get(attr) + + def _href_getter(self, obj) -> str | None: + """ + Return a href if the DOM node contains a child or None. + + Parameters + ---------- + obj : node-like + A DOM node. + + Returns + ------- + href : str or unicode + The href from the child of the DOM node. + """ + raise AbstractMethodError(self) + + def _text_getter(self, obj): + """ + Return the text of an individual DOM node. + + Parameters + ---------- + obj : node-like + A DOM node. + + Returns + ------- + text : str or unicode + The text from an individual DOM node. + """ + raise AbstractMethodError(self) + + def _parse_td(self, obj): + """ + Return the td elements from a row element. + + Parameters + ---------- + obj : node-like + A DOM node. + + Returns + ------- + list of node-like + These are the elements of each row, i.e., the columns. + """ + raise AbstractMethodError(self) + + def _parse_thead_tr(self, table): + """ + Return the list of thead row elements from the parsed table element. + + Parameters + ---------- + table : a table element that contains zero or more thead elements. + + Returns + ------- + list of node-like + These are the row elements of a table. + """ + raise AbstractMethodError(self) + + def _parse_tbody_tr(self, table): + """ + Return the list of tbody row elements from the parsed table element. + + HTML5 table bodies consist of either 0 or more elements (which + only contain elements) or 0 or more elements. This method + checks for both structures. + + Parameters + ---------- + table : a table element that contains row elements. + + Returns + ------- + list of node-like + These are the row elements of a table. + """ + raise AbstractMethodError(self) + + def _parse_tfoot_tr(self, table): + """ + Return the list of tfoot row elements from the parsed table element. + + Parameters + ---------- + table : a table element that contains row elements. + + Returns + ------- + list of node-like + These are the row elements of a table. + """ + raise AbstractMethodError(self) + + def _parse_tables(self, document, match, attrs): + """ + Return all tables from the parsed DOM. + + Parameters + ---------- + document : the DOM from which to parse the table element. + + match : str or regular expression + The text to search for in the DOM tree. + + attrs : dict + A dictionary of table attributes that can be used to disambiguate + multiple tables on a page. + + Raises + ------ + ValueError : `match` does not match any text in the document. + + Returns + ------- + list of node-like + HTML
elements to be parsed into raw data. + """ + raise AbstractMethodError(self) + + def _equals_tag(self, obj, tag) -> bool: + """ + Return whether an individual DOM node matches a tag + + Parameters + ---------- + obj : node-like + A DOM node. + + tag : str + Tag name to be checked for equality. + + Returns + ------- + boolean + Whether `obj`'s tag name is `tag` + """ + raise AbstractMethodError(self) + + def _build_doc(self): + """ + Return a tree-like object that can be used to iterate over the DOM. + + Returns + ------- + node-like + The DOM from which to parse the table element. + """ + raise AbstractMethodError(self) + + def _parse_thead_tbody_tfoot(self, table_html): + """ + Given a table, return parsed header, body, and foot. + + Parameters + ---------- + table_html : node-like + + Returns + ------- + tuple of (header, body, footer), each a list of list-of-text rows. + + Notes + ----- + Header and body are lists-of-lists. Top level list is a list of + rows. Each row is a list of str text. + + Logic: Use , , elements to identify + header, body, and footer, otherwise: + - Put all rows into body + - Move rows from top of body to header only if + all elements inside row are . Move the top all- or + while body_rows and row_is_all_th(body_rows[0]): + header_rows.append(body_rows.pop(0)) + + header = self._expand_colspan_rowspan(header_rows, section="header") + body = self._expand_colspan_rowspan(body_rows, section="body") + footer = self._expand_colspan_rowspan(footer_rows, section="footer") + + return header, body, footer + + def _expand_colspan_rowspan( + self, rows, section: Literal["header", "footer", "body"] + ): + """ + Given a list of s, return a list of text rows. + + Parameters + ---------- + rows : list of node-like + List of s + section : the section that the rows belong to (header, body or footer). + + Returns + ------- + list of list + Each returned row is a list of str text, or tuple (text, link) + if extract_links is not None. + + Notes + ----- + Any cell with ``rowspan`` or ``colspan`` will have its contents copied + to subsequent cells. + """ + all_texts = [] # list of rows, each a list of str + text: str | tuple + remainder: list[ + tuple[int, str | tuple, int] + ] = [] # list of (index, text, nrows) + + for tr in rows: + texts = [] # the output for this row + next_remainder = [] + + index = 0 + tds = self._parse_td(tr) + for td in tds: + # Append texts from previous rows with rowspan>1 that come + # before this or (see _parse_thead_tr). + return row.xpath("./td|./th") + + def _parse_tables(self, document, match, kwargs): + pattern = match.pattern + + # 1. check all descendants for the given pattern and only search tables + # GH 49929 + xpath_expr = f"//table[.//text()[re:test(., {repr(pattern)})]]" + + # if any table attributes were given build an xpath expression to + # search for them + if kwargs: + xpath_expr += _build_xpath_expr(kwargs) + + tables = document.xpath(xpath_expr, namespaces=_re_namespace) + + tables = self._handle_hidden_tables(tables, "attrib") + if self.displayed_only: + for table in tables: + # lxml utilizes XPATH 1.0 which does not have regex + # support. As a result, we find all elements with a style + # attribute and iterate them to check for display:none + for elem in table.xpath(".//style"): + elem.drop_tree() + for elem in table.xpath(".//*[@style]"): + if "display:none" in elem.attrib.get("style", "").replace(" ", ""): + elem.drop_tree() + if not tables: + raise ValueError(f"No tables found matching regex {repr(pattern)}") + return tables + + def _equals_tag(self, obj, tag) -> bool: + return obj.tag == tag + + def _build_doc(self): + """ + Raises + ------ + ValueError + * If a URL that lxml cannot parse is passed. + + Exception + * Any other ``Exception`` thrown. For example, trying to parse a + URL that is syntactically correct on a machine with no internet + connection will fail. + + See Also + -------- + pandas.io.html._HtmlFrameParser._build_doc + """ + from lxml.etree import XMLSyntaxError + from lxml.html import ( + HTMLParser, + fromstring, + parse, + ) + + parser = HTMLParser(recover=True, encoding=self.encoding) + + try: + if is_url(self.io): + with get_handle( + self.io, "r", storage_options=self.storage_options + ) as f: + r = parse(f.handle, parser=parser) + else: + # try to parse the input in the simplest way + r = parse(self.io, parser=parser) + try: + r = r.getroot() + except AttributeError: + pass + except (UnicodeDecodeError, OSError) as e: + # if the input is a blob of html goop + if not is_url(self.io): + r = fromstring(self.io, parser=parser) + + try: + r = r.getroot() + except AttributeError: + pass + else: + raise e + else: + if not hasattr(r, "text_content"): + raise XMLSyntaxError("no text parsed from document", 0, 0, 0) + + for br in r.xpath("*//br"): + br.tail = "\n" + (br.tail or "") + + return r + + def _parse_thead_tr(self, table): + rows = [] + + for thead in table.xpath(".//thead"): + rows.extend(thead.xpath("./tr")) + + # HACK: lxml does not clean up the clearly-erroneous + # . (Missing ). Add + # the and _pretend_ it's a ; _parse_td() will find its + # children as though it's a . + # + # Better solution would be to use html5lib. + elements_at_root = thead.xpath("./td|./th") + if elements_at_root: + rows.append(thead) + + return rows + + def _parse_tbody_tr(self, table): + from_tbody = table.xpath(".//tbody//tr") + from_root = table.xpath("./tr") + # HTML spec: at most one of these lists has content + return from_tbody + from_root + + def _parse_tfoot_tr(self, table): + return table.xpath(".//tfoot//tr") + + +def _expand_elements(body) -> None: + data = [len(elem) for elem in body] + lens = Series(data) + lens_max = lens.max() + not_max = lens[lens != lens_max] + + empty = [""] + for ind, length in not_max.items(): + body[ind] += empty * (lens_max - length) + + +def _data_to_frame(**kwargs): + head, body, foot = kwargs.pop("data") + header = kwargs.pop("header") + kwargs["skiprows"] = _get_skiprows(kwargs["skiprows"]) + if head: + body = head + body + + # Infer header when there is a or top
+ - Move rows from bottom of body to footer only if + all elements inside row are + """ + header_rows = self._parse_thead_tr(table_html) + body_rows = self._parse_tbody_tr(table_html) + footer_rows = self._parse_tfoot_tr(table_html) + + def row_is_all_th(row): + return all(self._equals_tag(t, "th") for t in self._parse_td(row)) + + if not header_rows: + # The table has no
rows from + # body_rows to header_rows. (This is a common case because many + # tables in the wild have no
+ while remainder and remainder[0][0] <= index: + prev_i, prev_text, prev_rowspan = remainder.pop(0) + texts.append(prev_text) + if prev_rowspan > 1: + next_remainder.append((prev_i, prev_text, prev_rowspan - 1)) + index += 1 + + # Append the text from this , colspan times + text = _remove_whitespace(self._text_getter(td)) + if self.extract_links in ("all", section): + href = self._href_getter(td) + text = (text, href) + rowspan = int(self._attr_getter(td, "rowspan") or 1) + colspan = int(self._attr_getter(td, "colspan") or 1) + + for _ in range(colspan): + texts.append(text) + if rowspan > 1: + next_remainder.append((index, text, rowspan - 1)) + index += 1 + + # Append texts from previous rows at the final position + for prev_i, prev_text, prev_rowspan in remainder: + texts.append(prev_text) + if prev_rowspan > 1: + next_remainder.append((prev_i, prev_text, prev_rowspan - 1)) + + all_texts.append(texts) + remainder = next_remainder + + # Append rows that only appear because the previous row had non-1 + # rowspan + while remainder: + next_remainder = [] + texts = [] + for prev_i, prev_text, prev_rowspan in remainder: + texts.append(prev_text) + if prev_rowspan > 1: + next_remainder.append((prev_i, prev_text, prev_rowspan - 1)) + all_texts.append(texts) + remainder = next_remainder + + return all_texts + + def _handle_hidden_tables(self, tbl_list, attr_name: str): + """ + Return list of tables, potentially removing hidden elements + + Parameters + ---------- + tbl_list : list of node-like + Type of list elements will vary depending upon parser used + attr_name : str + Name of the accessor for retrieving HTML attributes + + Returns + ------- + list of node-like + Return type matches `tbl_list` + """ + if not self.displayed_only: + return tbl_list + + return [ + x + for x in tbl_list + if "display:none" + not in getattr(x, attr_name).get("style", "").replace(" ", "") + ] + + +class _BeautifulSoupHtml5LibFrameParser(_HtmlFrameParser): + """ + HTML to DataFrame parser that uses BeautifulSoup under the hood. + + See Also + -------- + pandas.io.html._HtmlFrameParser + pandas.io.html._LxmlFrameParser + + Notes + ----- + Documentation strings for this class are in the base class + :class:`pandas.io.html._HtmlFrameParser`. + """ + + def _parse_tables(self, document, match, attrs): + element_name = "table" + tables = document.find_all(element_name, attrs=attrs) + if not tables: + raise ValueError("No tables found") + + result = [] + unique_tables = set() + tables = self._handle_hidden_tables(tables, "attrs") + + for table in tables: + if self.displayed_only: + for elem in table.find_all("style"): + elem.decompose() + + for elem in table.find_all(style=re.compile(r"display:\s*none")): + elem.decompose() + + if table not in unique_tables and table.find(string=match) is not None: + result.append(table) + unique_tables.add(table) + if not result: + raise ValueError(f"No tables found matching pattern {repr(match.pattern)}") + return result + + def _href_getter(self, obj) -> str | None: + a = obj.find("a", href=True) + return None if not a else a["href"] + + def _text_getter(self, obj): + return obj.text + + def _equals_tag(self, obj, tag) -> bool: + return obj.name == tag + + def _parse_td(self, row): + return row.find_all(("td", "th"), recursive=False) + + def _parse_thead_tr(self, table): + return table.select("thead tr") + + def _parse_tbody_tr(self, table): + from_tbody = table.select("tbody tr") + from_root = table.find_all("tr", recursive=False) + # HTML spec: at most one of these lists has content + return from_tbody + from_root + + def _parse_tfoot_tr(self, table): + return table.select("tfoot tr") + + def _setup_build_doc(self): + raw_text = _read(self.io, self.encoding, self.storage_options) + if not raw_text: + raise ValueError(f"No text parsed from document: {self.io}") + return raw_text + + def _build_doc(self): + from bs4 import BeautifulSoup + + bdoc = self._setup_build_doc() + if isinstance(bdoc, bytes) and self.encoding is not None: + udoc = bdoc.decode(self.encoding) + from_encoding = None + else: + udoc = bdoc + from_encoding = self.encoding + + soup = BeautifulSoup(udoc, features="html5lib", from_encoding=from_encoding) + + for br in soup.find_all("br"): + br.replace_with("\n" + br.text) + + return soup + + +def _build_xpath_expr(attrs) -> str: + """ + Build an xpath expression to simulate bs4's ability to pass in kwargs to + search for attributes when using the lxml parser. + + Parameters + ---------- + attrs : dict + A dict of HTML attributes. These are NOT checked for validity. + + Returns + ------- + expr : unicode + An XPath expression that checks for the given HTML attributes. + """ + # give class attribute as class_ because class is a python keyword + if "class_" in attrs: + attrs["class"] = attrs.pop("class_") + + s = " and ".join([f"@{k}={repr(v)}" for k, v in attrs.items()]) + return f"[{s}]" + + +_re_namespace = {"re": "http://exslt.org/regular-expressions"} + + +class _LxmlFrameParser(_HtmlFrameParser): + """ + HTML to DataFrame parser that uses lxml under the hood. + + Warning + ------- + This parser can only handle HTTP, FTP, and FILE urls. + + See Also + -------- + _HtmlFrameParser + _BeautifulSoupLxmlFrameParser + + Notes + ----- + Documentation strings for this class are in the base class + :class:`_HtmlFrameParser`. + """ + + def _href_getter(self, obj) -> str | None: + href = obj.xpath(".//a/@href") + return None if not href else href[0] + + def _text_getter(self, obj): + return obj.text_content() + + def _parse_td(self, row): + # Look for direct children only: the "row" element here may be a + #
foobar
-only rows + if header is None: + if len(head) == 1: + header = 0 + else: + # ignore all-empty-text rows + header = [i for i, row in enumerate(head) if any(text for text in row)] + + if foot: + body += foot + + # fill out elements of body that are "ragged" + _expand_elements(body) + with TextParser(body, header=header, **kwargs) as tp: + return tp.read() + + +_valid_parsers = { + "lxml": _LxmlFrameParser, + None: _LxmlFrameParser, + "html5lib": _BeautifulSoupHtml5LibFrameParser, + "bs4": _BeautifulSoupHtml5LibFrameParser, +} + + +def _parser_dispatch(flavor: HTMLFlavors | None) -> type[_HtmlFrameParser]: + """ + Choose the parser based on the input flavor. + + Parameters + ---------- + flavor : {{"lxml", "html5lib", "bs4"}} or None + The type of parser to use. This must be a valid backend. + + Returns + ------- + cls : _HtmlFrameParser subclass + The parser class based on the requested input flavor. + + Raises + ------ + ValueError + * If `flavor` is not a valid backend. + ImportError + * If you do not have the requested `flavor` + """ + valid_parsers = list(_valid_parsers.keys()) + if flavor not in valid_parsers: + raise ValueError( + f"{repr(flavor)} is not a valid flavor, valid flavors are {valid_parsers}" + ) + + if flavor in ("bs4", "html5lib"): + import_optional_dependency("html5lib") + import_optional_dependency("bs4") + else: + import_optional_dependency("lxml.etree") + return _valid_parsers[flavor] + + +def _print_as_set(s) -> str: + arg = ", ".join([pprint_thing(el) for el in s]) + return f"{{{arg}}}" + + +def _validate_flavor(flavor): + if flavor is None: + flavor = "lxml", "bs4" + elif isinstance(flavor, str): + flavor = (flavor,) + elif isinstance(flavor, abc.Iterable): + if not all(isinstance(flav, str) for flav in flavor): + raise TypeError( + f"Object of type {repr(type(flavor).__name__)} " + f"is not an iterable of strings" + ) + else: + msg = repr(flavor) if isinstance(flavor, str) else str(flavor) + msg += " is not a valid flavor" + raise ValueError(msg) + + flavor = tuple(flavor) + valid_flavors = set(_valid_parsers) + flavor_set = set(flavor) + + if not flavor_set & valid_flavors: + raise ValueError( + f"{_print_as_set(flavor_set)} is not a valid set of flavors, valid " + f"flavors are {_print_as_set(valid_flavors)}" + ) + return flavor + + +def _parse( + flavor, + io, + match, + attrs, + encoding, + displayed_only, + extract_links, + storage_options, + **kwargs, +): + flavor = _validate_flavor(flavor) + compiled_match = re.compile(match) # you can pass a compiled regex here + + retained = None + for flav in flavor: + parser = _parser_dispatch(flav) + p = parser( + io, + compiled_match, + attrs, + encoding, + displayed_only, + extract_links, + storage_options, + ) + + try: + tables = p.parse_tables() + except ValueError as caught: + # if `io` is an io-like object, check if it's seekable + # and try to rewind it before trying the next parser + if hasattr(io, "seekable") and io.seekable(): + io.seek(0) + elif hasattr(io, "seekable") and not io.seekable(): + # if we couldn't rewind it, let the user know + raise ValueError( + f"The flavor {flav} failed to parse your input. " + "Since you passed a non-rewindable file " + "object, we can't rewind it to try " + "another parser. Try read_html() with a different flavor." + ) from caught + + retained = caught + else: + break + else: + assert retained is not None # for mypy + raise retained + + ret = [] + for table in tables: + try: + df = _data_to_frame(data=table, **kwargs) + # Cast MultiIndex header to an Index of tuples when extracting header + # links and replace nan with None (therefore can't use mi.to_flat_index()). + # This maintains consistency of selection (e.g. df.columns.str[1]) + if extract_links in ("all", "header") and isinstance( + df.columns, MultiIndex + ): + df.columns = Index( + ((col[0], None if isna(col[1]) else col[1]) for col in df.columns), + tupleize_cols=False, + ) + + ret.append(df) + except EmptyDataError: # empty table + continue + return ret + + +@doc(storage_options=_shared_docs["storage_options"]) +def read_html( + io: FilePath | ReadBuffer[str], + *, + match: str | Pattern = ".+", + flavor: HTMLFlavors | Sequence[HTMLFlavors] | None = None, + header: int | Sequence[int] | None = None, + index_col: int | Sequence[int] | None = None, + skiprows: int | Sequence[int] | slice | None = None, + attrs: dict[str, str] | None = None, + parse_dates: bool = False, + thousands: str | None = ",", + encoding: str | None = None, + decimal: str = ".", + converters: dict | None = None, + na_values: Iterable[object] | None = None, + keep_default_na: bool = True, + displayed_only: bool = True, + extract_links: Literal[None, "header", "footer", "body", "all"] = None, + dtype_backend: DtypeBackend | lib.NoDefault = lib.no_default, + storage_options: StorageOptions = None, +) -> list[DataFrame]: + r""" + Read HTML tables into a ``list`` of ``DataFrame`` objects. + + Parameters + ---------- + io : str, path object, or file-like object + String, path object (implementing ``os.PathLike[str]``), or file-like + object implementing a string ``read()`` function. + The string can represent a URL or the HTML itself. Note that + lxml only accepts the http, ftp and file url protocols. If you have a + URL that starts with ``'https'`` you might try removing the ``'s'``. + + .. deprecated:: 2.1.0 + Passing html literal strings is deprecated. + Wrap literal string/bytes input in ``io.StringIO``/``io.BytesIO`` instead. + + match : str or compiled regular expression, optional + The set of tables containing text matching this regex or string will be + returned. Unless the HTML is extremely simple you will probably need to + pass a non-empty string here. Defaults to '.+' (match any non-empty + string). The default value will return all tables contained on a page. + This value is converted to a regular expression so that there is + consistent behavior between Beautiful Soup and lxml. + + flavor : {{"lxml", "html5lib", "bs4"}} or list-like, optional + The parsing engine (or list of parsing engines) to use. 'bs4' and + 'html5lib' are synonymous with each other, they are both there for + backwards compatibility. The default of ``None`` tries to use ``lxml`` + to parse and if that fails it falls back on ``bs4`` + ``html5lib``. + + header : int or list-like, optional + The row (or list of rows for a :class:`~pandas.MultiIndex`) to use to + make the columns headers. + + index_col : int or list-like, optional + The column (or list of columns) to use to create the index. + + skiprows : int, list-like or slice, optional + Number of rows to skip after parsing the column integer. 0-based. If a + sequence of integers or a slice is given, will skip the rows indexed by + that sequence. Note that a single element sequence means 'skip the nth + row' whereas an integer means 'skip n rows'. + + attrs : dict, optional + This is a dictionary of attributes that you can pass to use to identify + the table in the HTML. These are not checked for validity before being + passed to lxml or Beautiful Soup. However, these attributes must be + valid HTML table attributes to work correctly. For example, :: + + attrs = {{'id': 'table'}} + + is a valid attribute dictionary because the 'id' HTML tag attribute is + a valid HTML attribute for *any* HTML tag as per `this document + `__. :: + + attrs = {{'asdf': 'table'}} + + is *not* a valid attribute dictionary because 'asdf' is not a valid + HTML attribute even if it is a valid XML attribute. Valid HTML 4.01 + table attributes can be found `here + `__. A + working draft of the HTML 5 spec can be found `here + `__. It contains the + latest information on table attributes for the modern web. + + parse_dates : bool, optional + See :func:`~read_csv` for more details. + + thousands : str, optional + Separator to use to parse thousands. Defaults to ``','``. + + encoding : str, optional + The encoding used to decode the web page. Defaults to ``None``.``None`` + preserves the previous encoding behavior, which depends on the + underlying parser library (e.g., the parser library will try to use + the encoding provided by the document). + + decimal : str, default '.' + Character to recognize as decimal point (e.g. use ',' for European + data). + + converters : dict, default None + Dict of functions for converting values in certain columns. Keys can + either be integers or column labels, values are functions that take one + input argument, the cell (not column) content, and return the + transformed content. + + na_values : iterable, default None + Custom NA values. + + keep_default_na : bool, default True + If na_values are specified and keep_default_na is False the default NaN + values are overridden, otherwise they're appended to. + + displayed_only : bool, default True + Whether elements with "display: none" should be parsed. + + extract_links : {{None, "all", "header", "body", "footer"}} + Table elements in the specified section(s) with tags will have their + href extracted. + + .. versionadded:: 1.5.0 + + dtype_backend : {{'numpy_nullable', 'pyarrow'}}, default 'numpy_nullable' + Back-end data type applied to the resultant :class:`DataFrame` + (still experimental). Behaviour is as follows: + + * ``"numpy_nullable"``: returns nullable-dtype-backed :class:`DataFrame` + (default). + * ``"pyarrow"``: returns pyarrow-backed nullable :class:`ArrowDtype` + DataFrame. + + .. versionadded:: 2.0 + + {storage_options} + + .. versionadded:: 2.1.0 + + Returns + ------- + dfs + A list of DataFrames. + + See Also + -------- + read_csv : Read a comma-separated values (csv) file into DataFrame. + + Notes + ----- + Before using this function you should read the :ref:`gotchas about the + HTML parsing libraries `. + + Expect to do some cleanup after you call this function. For example, you + might need to manually assign column names if the column names are + converted to NaN when you pass the `header=0` argument. We try to assume as + little as possible about the structure of the table and push the + idiosyncrasies of the HTML contained in the table to the user. + + This function searches for ```` elements and only for ```` + and ```` or ```` argument, it is used to construct + the header, otherwise the function attempts to find the header within + the body (by putting rows with only ``
`` rows and ```` elements within each ``
`` + element in the table. ```` stands for "table data". This function + attempts to properly handle ``colspan`` and ``rowspan`` attributes. + If the function has a ``
`` elements into the header). + + Similar to :func:`~read_csv` the `header` argument is applied + **after** `skiprows` is applied. + + This function will *always* return a list of :class:`DataFrame` *or* + it will fail, e.g., it will *not* return an empty list. + + Examples + -------- + See the :ref:`read_html documentation in the IO section of the docs + ` for some examples of reading in HTML tables. + """ + # Type check here. We don't want to parse only to fail because of an + # invalid value of an integer skiprows. + if isinstance(skiprows, numbers.Integral) and skiprows < 0: + raise ValueError( + "cannot skip rows starting from the end of the " + "data (you passed a negative value)" + ) + if extract_links not in [None, "header", "footer", "body", "all"]: + raise ValueError( + "`extract_links` must be one of " + '{None, "header", "footer", "body", "all"}, got ' + f'"{extract_links}"' + ) + + validate_header_arg(header) + check_dtype_backend(dtype_backend) + + io = stringify_path(io) + + if isinstance(io, str) and not any( + [ + is_file_like(io), + file_exists(io), + is_url(io), + is_fsspec_url(io), + ] + ): + warnings.warn( + "Passing literal html to 'read_html' is deprecated and " + "will be removed in a future version. To read from a " + "literal string, wrap it in a 'StringIO' object.", + FutureWarning, + stacklevel=find_stack_level(), + ) + + return _parse( + flavor=flavor, + io=io, + match=match, + header=header, + index_col=index_col, + skiprows=skiprows, + parse_dates=parse_dates, + thousands=thousands, + attrs=attrs, + encoding=encoding, + decimal=decimal, + converters=converters, + na_values=na_values, + keep_default_na=keep_default_na, + displayed_only=displayed_only, + extract_links=extract_links, + dtype_backend=dtype_backend, + storage_options=storage_options, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/io/pickle.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/io/pickle.py new file mode 100644 index 0000000000000000000000000000000000000000..0dae0e7106b69a471f0c2702158cfe0f11f0389c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/io/pickle.py @@ -0,0 +1,210 @@ +""" pickle compat """ +from __future__ import annotations + +import pickle +from typing import ( + TYPE_CHECKING, + Any, +) +import warnings + +from pandas.compat import pickle_compat as pc +from pandas.util._decorators import doc + +from pandas.core.shared_docs import _shared_docs + +from pandas.io.common import get_handle + +if TYPE_CHECKING: + from pandas._typing import ( + CompressionOptions, + FilePath, + ReadPickleBuffer, + StorageOptions, + WriteBuffer, + ) + + from pandas import ( + DataFrame, + Series, + ) + + +@doc( + storage_options=_shared_docs["storage_options"], + compression_options=_shared_docs["compression_options"] % "filepath_or_buffer", +) +def to_pickle( + obj: Any, + filepath_or_buffer: FilePath | WriteBuffer[bytes], + compression: CompressionOptions = "infer", + protocol: int = pickle.HIGHEST_PROTOCOL, + storage_options: StorageOptions | None = None, +) -> None: + """ + Pickle (serialize) object to file. + + Parameters + ---------- + obj : any object + Any python object. + filepath_or_buffer : str, path object, or file-like object + String, path object (implementing ``os.PathLike[str]``), or file-like + object implementing a binary ``write()`` function. + Also accepts URL. URL has to be of S3 or GCS. + {compression_options} + + .. versionchanged:: 1.4.0 Zstandard support. + + protocol : int + Int which indicates which protocol should be used by the pickler, + default HIGHEST_PROTOCOL (see [1], paragraph 12.1.2). The possible + values for this parameter depend on the version of Python. For Python + 2.x, possible values are 0, 1, 2. For Python>=3.0, 3 is a valid value. + For Python >= 3.4, 4 is a valid value. A negative value for the + protocol parameter is equivalent to setting its value to + HIGHEST_PROTOCOL. + + {storage_options} + + .. [1] https://docs.python.org/3/library/pickle.html + + See Also + -------- + read_pickle : Load pickled pandas object (or any object) from file. + DataFrame.to_hdf : Write DataFrame to an HDF5 file. + DataFrame.to_sql : Write DataFrame to a SQL database. + DataFrame.to_parquet : Write a DataFrame to the binary parquet format. + + Examples + -------- + >>> original_df = pd.DataFrame({{"foo": range(5), "bar": range(5, 10)}}) # doctest: +SKIP + >>> original_df # doctest: +SKIP + foo bar + 0 0 5 + 1 1 6 + 2 2 7 + 3 3 8 + 4 4 9 + >>> pd.to_pickle(original_df, "./dummy.pkl") # doctest: +SKIP + + >>> unpickled_df = pd.read_pickle("./dummy.pkl") # doctest: +SKIP + >>> unpickled_df # doctest: +SKIP + foo bar + 0 0 5 + 1 1 6 + 2 2 7 + 3 3 8 + 4 4 9 + """ # noqa: E501 + if protocol < 0: + protocol = pickle.HIGHEST_PROTOCOL + + with get_handle( + filepath_or_buffer, + "wb", + compression=compression, + is_text=False, + storage_options=storage_options, + ) as handles: + # letting pickle write directly to the buffer is more memory-efficient + pickle.dump(obj, handles.handle, protocol=protocol) + + +@doc( + storage_options=_shared_docs["storage_options"], + decompression_options=_shared_docs["decompression_options"] % "filepath_or_buffer", +) +def read_pickle( + filepath_or_buffer: FilePath | ReadPickleBuffer, + compression: CompressionOptions = "infer", + storage_options: StorageOptions | None = None, +) -> DataFrame | Series: + """ + Load pickled pandas object (or any object) from file. + + .. warning:: + + Loading pickled data received from untrusted sources can be + unsafe. See `here `__. + + Parameters + ---------- + filepath_or_buffer : str, path object, or file-like object + String, path object (implementing ``os.PathLike[str]``), or file-like + object implementing a binary ``readlines()`` function. + Also accepts URL. URL is not limited to S3 and GCS. + + {decompression_options} + + .. versionchanged:: 1.4.0 Zstandard support. + + {storage_options} + + Returns + ------- + same type as object stored in file + + See Also + -------- + DataFrame.to_pickle : Pickle (serialize) DataFrame object to file. + Series.to_pickle : Pickle (serialize) Series object to file. + read_hdf : Read HDF5 file into a DataFrame. + read_sql : Read SQL query or database table into a DataFrame. + read_parquet : Load a parquet object, returning a DataFrame. + + Notes + ----- + read_pickle is only guaranteed to be backwards compatible to pandas 0.20.3 + provided the object was serialized with to_pickle. + + Examples + -------- + >>> original_df = pd.DataFrame( + ... {{"foo": range(5), "bar": range(5, 10)}} + ... ) # doctest: +SKIP + >>> original_df # doctest: +SKIP + foo bar + 0 0 5 + 1 1 6 + 2 2 7 + 3 3 8 + 4 4 9 + >>> pd.to_pickle(original_df, "./dummy.pkl") # doctest: +SKIP + + >>> unpickled_df = pd.read_pickle("./dummy.pkl") # doctest: +SKIP + >>> unpickled_df # doctest: +SKIP + foo bar + 0 0 5 + 1 1 6 + 2 2 7 + 3 3 8 + 4 4 9 + """ + excs_to_catch = (AttributeError, ImportError, ModuleNotFoundError, TypeError) + with get_handle( + filepath_or_buffer, + "rb", + compression=compression, + is_text=False, + storage_options=storage_options, + ) as handles: + # 1) try standard library Pickle + # 2) try pickle_compat (older pandas version) to handle subclass changes + # 3) try pickle_compat with latin-1 encoding upon a UnicodeDecodeError + + try: + # TypeError for Cython complaints about object.__new__ vs Tick.__new__ + try: + with warnings.catch_warnings(record=True): + # We want to silence any warnings about, e.g. moved modules. + warnings.simplefilter("ignore", Warning) + return pickle.load(handles.handle) + except excs_to_catch: + # e.g. + # "No module named 'pandas.core.sparse.series'" + # "Can't get attribute '__nat_unpickle' on str: + # set the encoding if we need + if encoding is None: + encoding = _default_encoding + + return encoding + + +def _ensure_str(name): + """ + Ensure that an index / column name is a str (python 3); otherwise they + may be np.string dtype. Non-string dtypes are passed through unchanged. + + https://github.com/pandas-dev/pandas/issues/13492 + """ + if isinstance(name, str): + name = str(name) + return name + + +Term = PyTablesExpr + + +def _ensure_term(where, scope_level: int): + """ + Ensure that the where is a Term or a list of Term. + + This makes sure that we are capturing the scope of variables that are + passed create the terms here with a frame_level=2 (we are 2 levels down) + """ + # only consider list/tuple here as an ndarray is automatically a coordinate + # list + level = scope_level + 1 + if isinstance(where, (list, tuple)): + where = [ + Term(term, scope_level=level + 1) if maybe_expression(term) else term + for term in where + if term is not None + ] + elif maybe_expression(where): + where = Term(where, scope_level=level) + return where if where is None or len(where) else None + + +incompatibility_doc: Final = """ +where criteria is being ignored as this version [%s] is too old (or +not-defined), read the file in and write it out to a new file to upgrade (with +the copy_to method) +""" + +attribute_conflict_doc: Final = """ +the [%s] attribute of the existing index is [%s] which conflicts with the new +[%s], resetting the attribute to None +""" + +performance_doc: Final = """ +your performance may suffer as PyTables will pickle object types that it cannot +map directly to c-types [inferred_type->%s,key->%s] [items->%s] +""" + +# formats +_FORMAT_MAP = {"f": "fixed", "fixed": "fixed", "t": "table", "table": "table"} + +# axes map +_AXES_MAP = {DataFrame: [0]} + +# register our configuration options +dropna_doc: Final = """ +: boolean + drop ALL nan rows when appending to a table +""" +format_doc: Final = """ +: format + default format writing format, if None, then + put will default to 'fixed' and append will default to 'table' +""" + +with config.config_prefix("io.hdf"): + config.register_option("dropna_table", False, dropna_doc, validator=config.is_bool) + config.register_option( + "default_format", + None, + format_doc, + validator=config.is_one_of_factory(["fixed", "table", None]), + ) + +# oh the troubles to reduce import time +_table_mod = None +_table_file_open_policy_is_strict = False + + +def _tables(): + global _table_mod + global _table_file_open_policy_is_strict + if _table_mod is None: + import tables + + _table_mod = tables + + # set the file open policy + # return the file open policy; this changes as of pytables 3.1 + # depending on the HDF5 version + with suppress(AttributeError): + _table_file_open_policy_is_strict = ( + tables.file._FILE_OPEN_POLICY == "strict" + ) + + return _table_mod + + +# interface to/from ### + + +def to_hdf( + path_or_buf: FilePath | HDFStore, + key: str, + value: DataFrame | Series, + mode: str = "a", + complevel: int | None = None, + complib: str | None = None, + append: bool = False, + format: str | None = None, + index: bool = True, + min_itemsize: int | dict[str, int] | None = None, + nan_rep=None, + dropna: bool | None = None, + data_columns: Literal[True] | list[str] | None = None, + errors: str = "strict", + encoding: str = "UTF-8", +) -> None: + """store this object, close it if we opened it""" + if append: + f = lambda store: store.append( + key, + value, + format=format, + index=index, + min_itemsize=min_itemsize, + nan_rep=nan_rep, + dropna=dropna, + data_columns=data_columns, + errors=errors, + encoding=encoding, + ) + else: + # NB: dropna is not passed to `put` + f = lambda store: store.put( + key, + value, + format=format, + index=index, + min_itemsize=min_itemsize, + nan_rep=nan_rep, + data_columns=data_columns, + errors=errors, + encoding=encoding, + dropna=dropna, + ) + + path_or_buf = stringify_path(path_or_buf) + if isinstance(path_or_buf, str): + with HDFStore( + path_or_buf, mode=mode, complevel=complevel, complib=complib + ) as store: + f(store) + else: + f(path_or_buf) + + +def read_hdf( + path_or_buf: FilePath | HDFStore, + key=None, + mode: str = "r", + errors: str = "strict", + where: str | list | None = None, + start: int | None = None, + stop: int | None = None, + columns: list[str] | None = None, + iterator: bool = False, + chunksize: int | None = None, + **kwargs, +): + """ + Read from the store, close it if we opened it. + + Retrieve pandas object stored in file, optionally based on where + criteria. + + .. warning:: + + Pandas uses PyTables for reading and writing HDF5 files, which allows + serializing object-dtype data with pickle when using the "fixed" format. + Loading pickled data received from untrusted sources can be unsafe. + + See: https://docs.python.org/3/library/pickle.html for more. + + Parameters + ---------- + path_or_buf : str, path object, pandas.HDFStore + Any valid string path is acceptable. Only supports the local file system, + remote URLs and file-like objects are not supported. + + If you want to pass in a path object, pandas accepts any + ``os.PathLike``. + + Alternatively, pandas accepts an open :class:`pandas.HDFStore` object. + + key : object, optional + The group identifier in the store. Can be omitted if the HDF file + contains a single pandas object. + mode : {'r', 'r+', 'a'}, default 'r' + Mode to use when opening the file. Ignored if path_or_buf is a + :class:`pandas.HDFStore`. Default is 'r'. + errors : str, default 'strict' + Specifies how encoding and decoding errors are to be handled. + See the errors argument for :func:`open` for a full list + of options. + where : list, optional + A list of Term (or convertible) objects. + start : int, optional + Row number to start selection. + stop : int, optional + Row number to stop selection. + columns : list, optional + A list of columns names to return. + iterator : bool, optional + Return an iterator object. + chunksize : int, optional + Number of rows to include in an iteration when using an iterator. + **kwargs + Additional keyword arguments passed to HDFStore. + + Returns + ------- + object + The selected object. Return type depends on the object stored. + + See Also + -------- + DataFrame.to_hdf : Write a HDF file from a DataFrame. + HDFStore : Low-level access to HDF files. + + Examples + -------- + >>> df = pd.DataFrame([[1, 1.0, 'a']], columns=['x', 'y', 'z']) # doctest: +SKIP + >>> df.to_hdf('./store.h5', 'data') # doctest: +SKIP + >>> reread = pd.read_hdf('./store.h5') # doctest: +SKIP + """ + if mode not in ["r", "r+", "a"]: + raise ValueError( + f"mode {mode} is not allowed while performing a read. " + f"Allowed modes are r, r+ and a." + ) + # grab the scope + if where is not None: + where = _ensure_term(where, scope_level=1) + + if isinstance(path_or_buf, HDFStore): + if not path_or_buf.is_open: + raise OSError("The HDFStore must be open for reading.") + + store = path_or_buf + auto_close = False + else: + path_or_buf = stringify_path(path_or_buf) + if not isinstance(path_or_buf, str): + raise NotImplementedError( + "Support for generic buffers has not been implemented." + ) + try: + exists = os.path.exists(path_or_buf) + + # if filepath is too long + except (TypeError, ValueError): + exists = False + + if not exists: + raise FileNotFoundError(f"File {path_or_buf} does not exist") + + store = HDFStore(path_or_buf, mode=mode, errors=errors, **kwargs) + # can't auto open/close if we are using an iterator + # so delegate to the iterator + auto_close = True + + try: + if key is None: + groups = store.groups() + if len(groups) == 0: + raise ValueError( + "Dataset(s) incompatible with Pandas data types, " + "not table, or no datasets found in HDF5 file." + ) + candidate_only_group = groups[0] + + # For the HDF file to have only one dataset, all other groups + # should then be metadata groups for that candidate group. (This + # assumes that the groups() method enumerates parent groups + # before their children.) + for group_to_check in groups[1:]: + if not _is_metadata_of(group_to_check, candidate_only_group): + raise ValueError( + "key must be provided when HDF5 " + "file contains multiple datasets." + ) + key = candidate_only_group._v_pathname + return store.select( + key, + where=where, + start=start, + stop=stop, + columns=columns, + iterator=iterator, + chunksize=chunksize, + auto_close=auto_close, + ) + except (ValueError, TypeError, LookupError): + if not isinstance(path_or_buf, HDFStore): + # if there is an error, close the store if we opened it. + with suppress(AttributeError): + store.close() + + raise + + +def _is_metadata_of(group: Node, parent_group: Node) -> bool: + """Check if a given group is a metadata group for a given parent_group.""" + if group._v_depth <= parent_group._v_depth: + return False + + current = group + while current._v_depth > 1: + parent = current._v_parent + if parent == parent_group and current._v_name == "meta": + return True + current = current._v_parent + return False + + +class HDFStore: + """ + Dict-like IO interface for storing pandas objects in PyTables. + + Either Fixed or Table format. + + .. warning:: + + Pandas uses PyTables for reading and writing HDF5 files, which allows + serializing object-dtype data with pickle when using the "fixed" format. + Loading pickled data received from untrusted sources can be unsafe. + + See: https://docs.python.org/3/library/pickle.html for more. + + Parameters + ---------- + path : str + File path to HDF5 file. + mode : {'a', 'w', 'r', 'r+'}, default 'a' + + ``'r'`` + Read-only; no data can be modified. + ``'w'`` + Write; a new file is created (an existing file with the same + name would be deleted). + ``'a'`` + Append; an existing file is opened for reading and writing, + and if the file does not exist it is created. + ``'r+'`` + It is similar to ``'a'``, but the file must already exist. + complevel : int, 0-9, default None + Specifies a compression level for data. + A value of 0 or None disables compression. + complib : {'zlib', 'lzo', 'bzip2', 'blosc'}, default 'zlib' + Specifies the compression library to be used. + These additional compressors for Blosc are supported + (default if no compressor specified: 'blosc:blosclz'): + {'blosc:blosclz', 'blosc:lz4', 'blosc:lz4hc', 'blosc:snappy', + 'blosc:zlib', 'blosc:zstd'}. + Specifying a compression library which is not available issues + a ValueError. + fletcher32 : bool, default False + If applying compression use the fletcher32 checksum. + **kwargs + These parameters will be passed to the PyTables open_file method. + + Examples + -------- + >>> bar = pd.DataFrame(np.random.randn(10, 4)) + >>> store = pd.HDFStore('test.h5') + >>> store['foo'] = bar # write to HDF5 + >>> bar = store['foo'] # retrieve + >>> store.close() + + **Create or load HDF5 file in-memory** + + When passing the `driver` option to the PyTables open_file method through + **kwargs, the HDF5 file is loaded or created in-memory and will only be + written when closed: + + >>> bar = pd.DataFrame(np.random.randn(10, 4)) + >>> store = pd.HDFStore('test.h5', driver='H5FD_CORE') + >>> store['foo'] = bar + >>> store.close() # only now, data is written to disk + """ + + _handle: File | None + _mode: str + + def __init__( + self, + path, + mode: str = "a", + complevel: int | None = None, + complib=None, + fletcher32: bool = False, + **kwargs, + ) -> None: + if "format" in kwargs: + raise ValueError("format is not a defined argument for HDFStore") + + tables = import_optional_dependency("tables") + + if complib is not None and complib not in tables.filters.all_complibs: + raise ValueError( + f"complib only supports {tables.filters.all_complibs} compression." + ) + + if complib is None and complevel is not None: + complib = tables.filters.default_complib + + self._path = stringify_path(path) + if mode is None: + mode = "a" + self._mode = mode + self._handle = None + self._complevel = complevel if complevel else 0 + self._complib = complib + self._fletcher32 = fletcher32 + self._filters = None + self.open(mode=mode, **kwargs) + + def __fspath__(self) -> str: + return self._path + + @property + def root(self): + """return the root node""" + self._check_if_open() + assert self._handle is not None # for mypy + return self._handle.root + + @property + def filename(self) -> str: + return self._path + + def __getitem__(self, key: str): + return self.get(key) + + def __setitem__(self, key: str, value) -> None: + self.put(key, value) + + def __delitem__(self, key: str) -> None: + return self.remove(key) + + def __getattr__(self, name: str): + """allow attribute access to get stores""" + try: + return self.get(name) + except (KeyError, ClosedFileError): + pass + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'" + ) + + def __contains__(self, key: str) -> bool: + """ + check for existence of this key + can match the exact pathname or the pathnm w/o the leading '/' + """ + node = self.get_node(key) + if node is not None: + name = node._v_pathname + if key in (name, name[1:]): + return True + return False + + def __len__(self) -> int: + return len(self.groups()) + + def __repr__(self) -> str: + pstr = pprint_thing(self._path) + return f"{type(self)}\nFile path: {pstr}\n" + + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + self.close() + + def keys(self, include: str = "pandas") -> list[str]: + """ + Return a list of keys corresponding to objects stored in HDFStore. + + Parameters + ---------- + + include : str, default 'pandas' + When kind equals 'pandas' return pandas objects. + When kind equals 'native' return native HDF5 Table objects. + + Returns + ------- + list + List of ABSOLUTE path-names (e.g. have the leading '/'). + + Raises + ------ + raises ValueError if kind has an illegal value + + Examples + -------- + >>> df = pd.DataFrame([[1, 2], [3, 4]], columns=['A', 'B']) + >>> store = pd.HDFStore("store.h5", 'w') # doctest: +SKIP + >>> store.put('data', df) # doctest: +SKIP + >>> store.get('data') # doctest: +SKIP + >>> print(store.keys()) # doctest: +SKIP + ['/data1', '/data2'] + >>> store.close() # doctest: +SKIP + """ + if include == "pandas": + return [n._v_pathname for n in self.groups()] + + elif include == "native": + assert self._handle is not None # mypy + return [ + n._v_pathname for n in self._handle.walk_nodes("/", classname="Table") + ] + raise ValueError( + f"`include` should be either 'pandas' or 'native' but is '{include}'" + ) + + def __iter__(self) -> Iterator[str]: + return iter(self.keys()) + + def items(self) -> Iterator[tuple[str, list]]: + """ + iterate on key->group + """ + for g in self.groups(): + yield g._v_pathname, g + + def open(self, mode: str = "a", **kwargs) -> None: + """ + Open the file in the specified mode + + Parameters + ---------- + mode : {'a', 'w', 'r', 'r+'}, default 'a' + See HDFStore docstring or tables.open_file for info about modes + **kwargs + These parameters will be passed to the PyTables open_file method. + """ + tables = _tables() + + if self._mode != mode: + # if we are changing a write mode to read, ok + if self._mode in ["a", "w"] and mode in ["r", "r+"]: + pass + elif mode in ["w"]: + # this would truncate, raise here + if self.is_open: + raise PossibleDataLossError( + f"Re-opening the file [{self._path}] with mode [{self._mode}] " + "will delete the current file!" + ) + + self._mode = mode + + # close and reopen the handle + if self.is_open: + self.close() + + if self._complevel and self._complevel > 0: + self._filters = _tables().Filters( + self._complevel, self._complib, fletcher32=self._fletcher32 + ) + + if _table_file_open_policy_is_strict and self.is_open: + msg = ( + "Cannot open HDF5 file, which is already opened, " + "even in read-only mode." + ) + raise ValueError(msg) + + self._handle = tables.open_file(self._path, self._mode, **kwargs) + + def close(self) -> None: + """ + Close the PyTables file handle + """ + if self._handle is not None: + self._handle.close() + self._handle = None + + @property + def is_open(self) -> bool: + """ + return a boolean indicating whether the file is open + """ + if self._handle is None: + return False + return bool(self._handle.isopen) + + def flush(self, fsync: bool = False) -> None: + """ + Force all buffered modifications to be written to disk. + + Parameters + ---------- + fsync : bool (default False) + call ``os.fsync()`` on the file handle to force writing to disk. + + Notes + ----- + Without ``fsync=True``, flushing may not guarantee that the OS writes + to disk. With fsync, the operation will block until the OS claims the + file has been written; however, other caching layers may still + interfere. + """ + if self._handle is not None: + self._handle.flush() + if fsync: + with suppress(OSError): + os.fsync(self._handle.fileno()) + + def get(self, key: str): + """ + Retrieve pandas object stored in file. + + Parameters + ---------- + key : str + + Returns + ------- + object + Same type as object stored in file. + + Examples + -------- + >>> df = pd.DataFrame([[1, 2], [3, 4]], columns=['A', 'B']) + >>> store = pd.HDFStore("store.h5", 'w') # doctest: +SKIP + >>> store.put('data', df) # doctest: +SKIP + >>> store.get('data') # doctest: +SKIP + >>> store.close() # doctest: +SKIP + """ + with patch_pickle(): + # GH#31167 Without this patch, pickle doesn't know how to unpickle + # old DateOffset objects now that they are cdef classes. + group = self.get_node(key) + if group is None: + raise KeyError(f"No object named {key} in the file") + return self._read_group(group) + + def select( + self, + key: str, + where=None, + start=None, + stop=None, + columns=None, + iterator: bool = False, + chunksize: int | None = None, + auto_close: bool = False, + ): + """ + Retrieve pandas object stored in file, optionally based on where criteria. + + .. warning:: + + Pandas uses PyTables for reading and writing HDF5 files, which allows + serializing object-dtype data with pickle when using the "fixed" format. + Loading pickled data received from untrusted sources can be unsafe. + + See: https://docs.python.org/3/library/pickle.html for more. + + Parameters + ---------- + key : str + Object being retrieved from file. + where : list or None + List of Term (or convertible) objects, optional. + start : int or None + Row number to start selection. + stop : int, default None + Row number to stop selection. + columns : list or None + A list of columns that if not None, will limit the return columns. + iterator : bool or False + Returns an iterator. + chunksize : int or None + Number or rows to include in iteration, return an iterator. + auto_close : bool or False + Should automatically close the store when finished. + + Returns + ------- + object + Retrieved object from file. + + Examples + -------- + >>> df = pd.DataFrame([[1, 2], [3, 4]], columns=['A', 'B']) + >>> store = pd.HDFStore("store.h5", 'w') # doctest: +SKIP + >>> store.put('data', df) # doctest: +SKIP + >>> store.get('data') # doctest: +SKIP + >>> print(store.keys()) # doctest: +SKIP + ['/data1', '/data2'] + >>> store.select('/data1') # doctest: +SKIP + A B + 0 1 2 + 1 3 4 + >>> store.select('/data1', where='columns == A') # doctest: +SKIP + A + 0 1 + 1 3 + >>> store.close() # doctest: +SKIP + """ + group = self.get_node(key) + if group is None: + raise KeyError(f"No object named {key} in the file") + + # create the storer and axes + where = _ensure_term(where, scope_level=1) + s = self._create_storer(group) + s.infer_axes() + + # function to call on iteration + def func(_start, _stop, _where): + return s.read(start=_start, stop=_stop, where=_where, columns=columns) + + # create the iterator + it = TableIterator( + self, + s, + func, + where=where, + nrows=s.nrows, + start=start, + stop=stop, + iterator=iterator, + chunksize=chunksize, + auto_close=auto_close, + ) + + return it.get_result() + + def select_as_coordinates( + self, + key: str, + where=None, + start: int | None = None, + stop: int | None = None, + ): + """ + return the selection as an Index + + .. warning:: + + Pandas uses PyTables for reading and writing HDF5 files, which allows + serializing object-dtype data with pickle when using the "fixed" format. + Loading pickled data received from untrusted sources can be unsafe. + + See: https://docs.python.org/3/library/pickle.html for more. + + + Parameters + ---------- + key : str + where : list of Term (or convertible) objects, optional + start : integer (defaults to None), row number to start selection + stop : integer (defaults to None), row number to stop selection + """ + where = _ensure_term(where, scope_level=1) + tbl = self.get_storer(key) + if not isinstance(tbl, Table): + raise TypeError("can only read_coordinates with a table") + return tbl.read_coordinates(where=where, start=start, stop=stop) + + def select_column( + self, + key: str, + column: str, + start: int | None = None, + stop: int | None = None, + ): + """ + return a single column from the table. This is generally only useful to + select an indexable + + .. warning:: + + Pandas uses PyTables for reading and writing HDF5 files, which allows + serializing object-dtype data with pickle when using the "fixed" format. + Loading pickled data received from untrusted sources can be unsafe. + + See: https://docs.python.org/3/library/pickle.html for more. + + Parameters + ---------- + key : str + column : str + The column of interest. + start : int or None, default None + stop : int or None, default None + + Raises + ------ + raises KeyError if the column is not found (or key is not a valid + store) + raises ValueError if the column can not be extracted individually (it + is part of a data block) + + """ + tbl = self.get_storer(key) + if not isinstance(tbl, Table): + raise TypeError("can only read_column with a table") + return tbl.read_column(column=column, start=start, stop=stop) + + def select_as_multiple( + self, + keys, + where=None, + selector=None, + columns=None, + start=None, + stop=None, + iterator: bool = False, + chunksize: int | None = None, + auto_close: bool = False, + ): + """ + Retrieve pandas objects from multiple tables. + + .. warning:: + + Pandas uses PyTables for reading and writing HDF5 files, which allows + serializing object-dtype data with pickle when using the "fixed" format. + Loading pickled data received from untrusted sources can be unsafe. + + See: https://docs.python.org/3/library/pickle.html for more. + + Parameters + ---------- + keys : a list of the tables + selector : the table to apply the where criteria (defaults to keys[0] + if not supplied) + columns : the columns I want back + start : integer (defaults to None), row number to start selection + stop : integer (defaults to None), row number to stop selection + iterator : bool, return an iterator, default False + chunksize : nrows to include in iteration, return an iterator + auto_close : bool, default False + Should automatically close the store when finished. + + Raises + ------ + raises KeyError if keys or selector is not found or keys is empty + raises TypeError if keys is not a list or tuple + raises ValueError if the tables are not ALL THE SAME DIMENSIONS + """ + # default to single select + where = _ensure_term(where, scope_level=1) + if isinstance(keys, (list, tuple)) and len(keys) == 1: + keys = keys[0] + if isinstance(keys, str): + return self.select( + key=keys, + where=where, + columns=columns, + start=start, + stop=stop, + iterator=iterator, + chunksize=chunksize, + auto_close=auto_close, + ) + + if not isinstance(keys, (list, tuple)): + raise TypeError("keys must be a list/tuple") + + if not len(keys): + raise ValueError("keys must have a non-zero length") + + if selector is None: + selector = keys[0] + + # collect the tables + tbls = [self.get_storer(k) for k in keys] + s = self.get_storer(selector) + + # validate rows + nrows = None + for t, k in itertools.chain([(s, selector)], zip(tbls, keys)): + if t is None: + raise KeyError(f"Invalid table [{k}]") + if not t.is_table: + raise TypeError( + f"object [{t.pathname}] is not a table, and cannot be used in all " + "select as multiple" + ) + + if nrows is None: + nrows = t.nrows + elif t.nrows != nrows: + raise ValueError("all tables must have exactly the same nrows!") + + # The isinstance checks here are redundant with the check above, + # but necessary for mypy; see GH#29757 + _tbls = [x for x in tbls if isinstance(x, Table)] + + # axis is the concentration axes + axis = {t.non_index_axes[0][0] for t in _tbls}.pop() + + def func(_start, _stop, _where): + # retrieve the objs, _where is always passed as a set of + # coordinates here + objs = [ + t.read(where=_where, columns=columns, start=_start, stop=_stop) + for t in tbls + ] + + # concat and return + return concat(objs, axis=axis, verify_integrity=False)._consolidate() + + # create the iterator + it = TableIterator( + self, + s, + func, + where=where, + nrows=nrows, + start=start, + stop=stop, + iterator=iterator, + chunksize=chunksize, + auto_close=auto_close, + ) + + return it.get_result(coordinates=True) + + def put( + self, + key: str, + value: DataFrame | Series, + format=None, + index: bool = True, + append: bool = False, + complib=None, + complevel: int | None = None, + min_itemsize: int | dict[str, int] | None = None, + nan_rep=None, + data_columns: Literal[True] | list[str] | None = None, + encoding=None, + errors: str = "strict", + track_times: bool = True, + dropna: bool = False, + ) -> None: + """ + Store object in HDFStore. + + Parameters + ---------- + key : str + value : {Series, DataFrame} + format : 'fixed(f)|table(t)', default is 'fixed' + Format to use when storing object in HDFStore. Value can be one of: + + ``'fixed'`` + Fixed format. Fast writing/reading. Not-appendable, nor searchable. + ``'table'`` + Table format. Write as a PyTables Table structure which may perform + worse but allow more flexible operations like searching / selecting + subsets of the data. + index : bool, default True + Write DataFrame index as a column. + append : bool, default False + This will force Table format, append the input data to the existing. + data_columns : list of columns or True, default None + List of columns to create as data columns, or True to use all columns. + See `here + `__. + encoding : str, default None + Provide an encoding for strings. + track_times : bool, default True + Parameter is propagated to 'create_table' method of 'PyTables'. + If set to False it enables to have the same h5 files (same hashes) + independent on creation time. + dropna : bool, default False, optional + Remove missing values. + + Examples + -------- + >>> df = pd.DataFrame([[1, 2], [3, 4]], columns=['A', 'B']) + >>> store = pd.HDFStore("store.h5", 'w') # doctest: +SKIP + >>> store.put('data', df) # doctest: +SKIP + """ + if format is None: + format = get_option("io.hdf.default_format") or "fixed" + format = self._validate_format(format) + self._write_to_group( + key, + value, + format=format, + index=index, + append=append, + complib=complib, + complevel=complevel, + min_itemsize=min_itemsize, + nan_rep=nan_rep, + data_columns=data_columns, + encoding=encoding, + errors=errors, + track_times=track_times, + dropna=dropna, + ) + + def remove(self, key: str, where=None, start=None, stop=None) -> None: + """ + Remove pandas object partially by specifying the where condition + + Parameters + ---------- + key : str + Node to remove or delete rows from + where : list of Term (or convertible) objects, optional + start : integer (defaults to None), row number to start selection + stop : integer (defaults to None), row number to stop selection + + Returns + ------- + number of rows removed (or None if not a Table) + + Raises + ------ + raises KeyError if key is not a valid store + + """ + where = _ensure_term(where, scope_level=1) + try: + s = self.get_storer(key) + except KeyError: + # the key is not a valid store, re-raising KeyError + raise + except AssertionError: + # surface any assertion errors for e.g. debugging + raise + except Exception as err: + # In tests we get here with ClosedFileError, TypeError, and + # _table_mod.NoSuchNodeError. TODO: Catch only these? + + if where is not None: + raise ValueError( + "trying to remove a node with a non-None where clause!" + ) from err + + # we are actually trying to remove a node (with children) + node = self.get_node(key) + if node is not None: + node._f_remove(recursive=True) + return None + + # remove the node + if com.all_none(where, start, stop): + s.group._f_remove(recursive=True) + + # delete from the table + else: + if not s.is_table: + raise ValueError( + "can only remove with where on objects written as tables" + ) + return s.delete(where=where, start=start, stop=stop) + + def append( + self, + key: str, + value: DataFrame | Series, + format=None, + axes=None, + index: bool | list[str] = True, + append: bool = True, + complib=None, + complevel: int | None = None, + columns=None, + min_itemsize: int | dict[str, int] | None = None, + nan_rep=None, + chunksize: int | None = None, + expectedrows=None, + dropna: bool | None = None, + data_columns: Literal[True] | list[str] | None = None, + encoding=None, + errors: str = "strict", + ) -> None: + """ + Append to Table in file. + + Node must already exist and be Table format. + + Parameters + ---------- + key : str + value : {Series, DataFrame} + format : 'table' is the default + Format to use when storing object in HDFStore. Value can be one of: + + ``'table'`` + Table format. Write as a PyTables Table structure which may perform + worse but allow more flexible operations like searching / selecting + subsets of the data. + index : bool, default True + Write DataFrame index as a column. + append : bool, default True + Append the input data to the existing. + data_columns : list of columns, or True, default None + List of columns to create as indexed data columns for on-disk + queries, or True to use all columns. By default only the axes + of the object are indexed. See `here + `__. + min_itemsize : dict of columns that specify minimum str sizes + nan_rep : str to use as str nan representation + chunksize : size to chunk the writing + expectedrows : expected TOTAL row size of this table + encoding : default None, provide an encoding for str + dropna : bool, default False, optional + Do not write an ALL nan row to the store settable + by the option 'io.hdf.dropna_table'. + + Notes + ----- + Does *not* check if data being appended overlaps with existing + data in the table, so be careful + + Examples + -------- + >>> df1 = pd.DataFrame([[1, 2], [3, 4]], columns=['A', 'B']) + >>> store = pd.HDFStore("store.h5", 'w') # doctest: +SKIP + >>> store.put('data', df1, format='table') # doctest: +SKIP + >>> df2 = pd.DataFrame([[5, 6], [7, 8]], columns=['A', 'B']) + >>> store.append('data', df2) # doctest: +SKIP + >>> store.close() # doctest: +SKIP + A B + 0 1 2 + 1 3 4 + 0 5 6 + 1 7 8 + """ + if columns is not None: + raise TypeError( + "columns is not a supported keyword in append, try data_columns" + ) + + if dropna is None: + dropna = get_option("io.hdf.dropna_table") + if format is None: + format = get_option("io.hdf.default_format") or "table" + format = self._validate_format(format) + self._write_to_group( + key, + value, + format=format, + axes=axes, + index=index, + append=append, + complib=complib, + complevel=complevel, + min_itemsize=min_itemsize, + nan_rep=nan_rep, + chunksize=chunksize, + expectedrows=expectedrows, + dropna=dropna, + data_columns=data_columns, + encoding=encoding, + errors=errors, + ) + + def append_to_multiple( + self, + d: dict, + value, + selector, + data_columns=None, + axes=None, + dropna: bool = False, + **kwargs, + ) -> None: + """ + Append to multiple tables + + Parameters + ---------- + d : a dict of table_name to table_columns, None is acceptable as the + values of one node (this will get all the remaining columns) + value : a pandas object + selector : a string that designates the indexable table; all of its + columns will be designed as data_columns, unless data_columns is + passed, in which case these are used + data_columns : list of columns to create as data columns, or True to + use all columns + dropna : if evaluates to True, drop rows from all tables if any single + row in each table has all NaN. Default False. + + Notes + ----- + axes parameter is currently not accepted + + """ + if axes is not None: + raise TypeError( + "axes is currently not accepted as a parameter to append_to_multiple; " + "you can create the tables independently instead" + ) + + if not isinstance(d, dict): + raise ValueError( + "append_to_multiple must have a dictionary specified as the " + "way to split the value" + ) + + if selector not in d: + raise ValueError( + "append_to_multiple requires a selector that is in passed dict" + ) + + # figure out the splitting axis (the non_index_axis) + axis = next(iter(set(range(value.ndim)) - set(_AXES_MAP[type(value)]))) + + # figure out how to split the value + remain_key = None + remain_values: list = [] + for k, v in d.items(): + if v is None: + if remain_key is not None: + raise ValueError( + "append_to_multiple can only have one value in d that is None" + ) + remain_key = k + else: + remain_values.extend(v) + if remain_key is not None: + ordered = value.axes[axis] + ordd = ordered.difference(Index(remain_values)) + ordd = sorted(ordered.get_indexer(ordd)) + d[remain_key] = ordered.take(ordd) + + # data_columns + if data_columns is None: + data_columns = d[selector] + + # ensure rows are synchronized across the tables + if dropna: + idxs = (value[cols].dropna(how="all").index for cols in d.values()) + valid_index = next(idxs) + for index in idxs: + valid_index = valid_index.intersection(index) + value = value.loc[valid_index] + + min_itemsize = kwargs.pop("min_itemsize", None) + + # append + for k, v in d.items(): + dc = data_columns if k == selector else None + + # compute the val + val = value.reindex(v, axis=axis) + + filtered = ( + {key: value for (key, value) in min_itemsize.items() if key in v} + if min_itemsize is not None + else None + ) + self.append(k, val, data_columns=dc, min_itemsize=filtered, **kwargs) + + def create_table_index( + self, + key: str, + columns=None, + optlevel: int | None = None, + kind: str | None = None, + ) -> None: + """ + Create a pytables index on the table. + + Parameters + ---------- + key : str + columns : None, bool, or listlike[str] + Indicate which columns to create an index on. + + * False : Do not create any indexes. + * True : Create indexes on all columns. + * None : Create indexes on all columns. + * listlike : Create indexes on the given columns. + + optlevel : int or None, default None + Optimization level, if None, pytables defaults to 6. + kind : str or None, default None + Kind of index, if None, pytables defaults to "medium". + + Raises + ------ + TypeError: raises if the node is not a table + """ + # version requirements + _tables() + s = self.get_storer(key) + if s is None: + return + + if not isinstance(s, Table): + raise TypeError("cannot create table index on a Fixed format store") + s.create_index(columns=columns, optlevel=optlevel, kind=kind) + + def groups(self) -> list: + """ + Return a list of all the top-level nodes. + + Each node returned is not a pandas storage object. + + Returns + ------- + list + List of objects. + + Examples + -------- + >>> df = pd.DataFrame([[1, 2], [3, 4]], columns=['A', 'B']) + >>> store = pd.HDFStore("store.h5", 'w') # doctest: +SKIP + >>> store.put('data', df) # doctest: +SKIP + >>> print(store.groups()) # doctest: +SKIP + >>> store.close() # doctest: +SKIP + [/data (Group) '' + children := ['axis0' (Array), 'axis1' (Array), 'block0_values' (Array), + 'block0_items' (Array)]] + """ + _tables() + self._check_if_open() + assert self._handle is not None # for mypy + assert _table_mod is not None # for mypy + return [ + g + for g in self._handle.walk_groups() + if ( + not isinstance(g, _table_mod.link.Link) + and ( + getattr(g._v_attrs, "pandas_type", None) + or getattr(g, "table", None) + or (isinstance(g, _table_mod.table.Table) and g._v_name != "table") + ) + ) + ] + + def walk(self, where: str = "/") -> Iterator[tuple[str, list[str], list[str]]]: + """ + Walk the pytables group hierarchy for pandas objects. + + This generator will yield the group path, subgroups and pandas object + names for each group. + + Any non-pandas PyTables objects that are not a group will be ignored. + + The `where` group itself is listed first (preorder), then each of its + child groups (following an alphanumerical order) is also traversed, + following the same procedure. + + Parameters + ---------- + where : str, default "/" + Group where to start walking. + + Yields + ------ + path : str + Full path to a group (without trailing '/'). + groups : list + Names (strings) of the groups contained in `path`. + leaves : list + Names (strings) of the pandas objects contained in `path`. + + Examples + -------- + >>> df1 = pd.DataFrame([[1, 2], [3, 4]], columns=['A', 'B']) + >>> store = pd.HDFStore("store.h5", 'w') # doctest: +SKIP + >>> store.put('data', df1, format='table') # doctest: +SKIP + >>> df2 = pd.DataFrame([[5, 6], [7, 8]], columns=['A', 'B']) + >>> store.append('data', df2) # doctest: +SKIP + >>> store.close() # doctest: +SKIP + >>> for group in store.walk(): # doctest: +SKIP + ... print(group) # doctest: +SKIP + >>> store.close() # doctest: +SKIP + """ + _tables() + self._check_if_open() + assert self._handle is not None # for mypy + assert _table_mod is not None # for mypy + + for g in self._handle.walk_groups(where): + if getattr(g._v_attrs, "pandas_type", None) is not None: + continue + + groups = [] + leaves = [] + for child in g._v_children.values(): + pandas_type = getattr(child._v_attrs, "pandas_type", None) + if pandas_type is None: + if isinstance(child, _table_mod.group.Group): + groups.append(child._v_name) + else: + leaves.append(child._v_name) + + yield (g._v_pathname.rstrip("/"), groups, leaves) + + def get_node(self, key: str) -> Node | None: + """return the node with the key or None if it does not exist""" + self._check_if_open() + if not key.startswith("/"): + key = "/" + key + + assert self._handle is not None + assert _table_mod is not None # for mypy + try: + node = self._handle.get_node(self.root, key) + except _table_mod.exceptions.NoSuchNodeError: + return None + + assert isinstance(node, _table_mod.Node), type(node) + return node + + def get_storer(self, key: str) -> GenericFixed | Table: + """return the storer object for a key, raise if not in the file""" + group = self.get_node(key) + if group is None: + raise KeyError(f"No object named {key} in the file") + + s = self._create_storer(group) + s.infer_axes() + return s + + def copy( + self, + file, + mode: str = "w", + propindexes: bool = True, + keys=None, + complib=None, + complevel: int | None = None, + fletcher32: bool = False, + overwrite: bool = True, + ) -> HDFStore: + """ + Copy the existing store to a new file, updating in place. + + Parameters + ---------- + propindexes : bool, default True + Restore indexes in copied file. + keys : list, optional + List of keys to include in the copy (defaults to all). + overwrite : bool, default True + Whether to overwrite (remove and replace) existing nodes in the new store. + mode, complib, complevel, fletcher32 same as in HDFStore.__init__ + + Returns + ------- + open file handle of the new store + """ + new_store = HDFStore( + file, mode=mode, complib=complib, complevel=complevel, fletcher32=fletcher32 + ) + if keys is None: + keys = list(self.keys()) + if not isinstance(keys, (tuple, list)): + keys = [keys] + for k in keys: + s = self.get_storer(k) + if s is not None: + if k in new_store: + if overwrite: + new_store.remove(k) + + data = self.select(k) + if isinstance(s, Table): + index: bool | list[str] = False + if propindexes: + index = [a.name for a in s.axes if a.is_indexed] + new_store.append( + k, + data, + index=index, + data_columns=getattr(s, "data_columns", None), + encoding=s.encoding, + ) + else: + new_store.put(k, data, encoding=s.encoding) + + return new_store + + def info(self) -> str: + """ + Print detailed information on the store. + + Returns + ------- + str + + Examples + -------- + >>> df = pd.DataFrame([[1, 2], [3, 4]], columns=['A', 'B']) + >>> store = pd.HDFStore("store.h5", 'w') # doctest: +SKIP + >>> store.put('data', df) # doctest: +SKIP + >>> print(store.info()) # doctest: +SKIP + >>> store.close() # doctest: +SKIP + + File path: store.h5 + /data frame (shape->[2,2]) + """ + path = pprint_thing(self._path) + output = f"{type(self)}\nFile path: {path}\n" + + if self.is_open: + lkeys = sorted(self.keys()) + if len(lkeys): + keys = [] + values = [] + + for k in lkeys: + try: + s = self.get_storer(k) + if s is not None: + keys.append(pprint_thing(s.pathname or k)) + values.append(pprint_thing(s or "invalid_HDFStore node")) + except AssertionError: + # surface any assertion errors for e.g. debugging + raise + except Exception as detail: + keys.append(k) + dstr = pprint_thing(detail) + values.append(f"[invalid_HDFStore node: {dstr}]") + + output += adjoin(12, keys, values) + else: + output += "Empty" + else: + output += "File is CLOSED" + + return output + + # ------------------------------------------------------------------------ + # private methods + + def _check_if_open(self) -> None: + if not self.is_open: + raise ClosedFileError(f"{self._path} file is not open!") + + def _validate_format(self, format: str) -> str: + """validate / deprecate formats""" + # validate + try: + format = _FORMAT_MAP[format.lower()] + except KeyError as err: + raise TypeError(f"invalid HDFStore format specified [{format}]") from err + + return format + + def _create_storer( + self, + group, + format=None, + value: DataFrame | Series | None = None, + encoding: str = "UTF-8", + errors: str = "strict", + ) -> GenericFixed | Table: + """return a suitable class to operate""" + cls: type[GenericFixed | Table] + + if value is not None and not isinstance(value, (Series, DataFrame)): + raise TypeError("value must be None, Series, or DataFrame") + + pt = _ensure_decoded(getattr(group._v_attrs, "pandas_type", None)) + tt = _ensure_decoded(getattr(group._v_attrs, "table_type", None)) + + # infer the pt from the passed value + if pt is None: + if value is None: + _tables() + assert _table_mod is not None # for mypy + if getattr(group, "table", None) or isinstance( + group, _table_mod.table.Table + ): + pt = "frame_table" + tt = "generic_table" + else: + raise TypeError( + "cannot create a storer if the object is not existing " + "nor a value are passed" + ) + else: + if isinstance(value, Series): + pt = "series" + else: + pt = "frame" + + # we are actually a table + if format == "table": + pt += "_table" + + # a storer node + if "table" not in pt: + _STORER_MAP = {"series": SeriesFixed, "frame": FrameFixed} + try: + cls = _STORER_MAP[pt] + except KeyError as err: + raise TypeError( + f"cannot properly create the storer for: [_STORER_MAP] [group->" + f"{group},value->{type(value)},format->{format}" + ) from err + return cls(self, group, encoding=encoding, errors=errors) + + # existing node (and must be a table) + if tt is None: + # if we are a writer, determine the tt + if value is not None: + if pt == "series_table": + index = getattr(value, "index", None) + if index is not None: + if index.nlevels == 1: + tt = "appendable_series" + elif index.nlevels > 1: + tt = "appendable_multiseries" + elif pt == "frame_table": + index = getattr(value, "index", None) + if index is not None: + if index.nlevels == 1: + tt = "appendable_frame" + elif index.nlevels > 1: + tt = "appendable_multiframe" + + _TABLE_MAP = { + "generic_table": GenericTable, + "appendable_series": AppendableSeriesTable, + "appendable_multiseries": AppendableMultiSeriesTable, + "appendable_frame": AppendableFrameTable, + "appendable_multiframe": AppendableMultiFrameTable, + "worm": WORMTable, + } + try: + cls = _TABLE_MAP[tt] + except KeyError as err: + raise TypeError( + f"cannot properly create the storer for: [_TABLE_MAP] [group->" + f"{group},value->{type(value)},format->{format}" + ) from err + + return cls(self, group, encoding=encoding, errors=errors) + + def _write_to_group( + self, + key: str, + value: DataFrame | Series, + format, + axes=None, + index: bool | list[str] = True, + append: bool = False, + complib=None, + complevel: int | None = None, + fletcher32=None, + min_itemsize: int | dict[str, int] | None = None, + chunksize: int | None = None, + expectedrows=None, + dropna: bool = False, + nan_rep=None, + data_columns=None, + encoding=None, + errors: str = "strict", + track_times: bool = True, + ) -> None: + # we don't want to store a table node at all if our object is 0-len + # as there are not dtypes + if getattr(value, "empty", None) and (format == "table" or append): + return + + group = self._identify_group(key, append) + + s = self._create_storer(group, format, value, encoding=encoding, errors=errors) + if append: + # raise if we are trying to append to a Fixed format, + # or a table that exists (and we are putting) + if not s.is_table or (s.is_table and format == "fixed" and s.is_exists): + raise ValueError("Can only append to Tables") + if not s.is_exists: + s.set_object_info() + else: + s.set_object_info() + + if not s.is_table and complib: + raise ValueError("Compression not supported on Fixed format stores") + + # write the object + s.write( + obj=value, + axes=axes, + append=append, + complib=complib, + complevel=complevel, + fletcher32=fletcher32, + min_itemsize=min_itemsize, + chunksize=chunksize, + expectedrows=expectedrows, + dropna=dropna, + nan_rep=nan_rep, + data_columns=data_columns, + track_times=track_times, + ) + + if isinstance(s, Table) and index: + s.create_index(columns=index) + + def _read_group(self, group: Node): + s = self._create_storer(group) + s.infer_axes() + return s.read() + + def _identify_group(self, key: str, append: bool) -> Node: + """Identify HDF5 group based on key, delete/create group if needed.""" + group = self.get_node(key) + + # we make this assertion for mypy; the get_node call will already + # have raised if this is incorrect + assert self._handle is not None + + # remove the node if we are not appending + if group is not None and not append: + self._handle.remove_node(group, recursive=True) + group = None + + if group is None: + group = self._create_nodes_and_group(key) + + return group + + def _create_nodes_and_group(self, key: str) -> Node: + """Create nodes from key and return group name.""" + # assertion for mypy + assert self._handle is not None + + paths = key.split("/") + # recursively create the groups + path = "/" + for p in paths: + if not len(p): + continue + new_path = path + if not path.endswith("/"): + new_path += "/" + new_path += p + group = self.get_node(new_path) + if group is None: + group = self._handle.create_group(path, p) + path = new_path + return group + + +class TableIterator: + """ + Define the iteration interface on a table + + Parameters + ---------- + store : HDFStore + s : the referred storer + func : the function to execute the query + where : the where of the query + nrows : the rows to iterate on + start : the passed start value (default is None) + stop : the passed stop value (default is None) + iterator : bool, default False + Whether to use the default iterator. + chunksize : the passed chunking value (default is 100000) + auto_close : bool, default False + Whether to automatically close the store at the end of iteration. + """ + + chunksize: int | None + store: HDFStore + s: GenericFixed | Table + + def __init__( + self, + store: HDFStore, + s: GenericFixed | Table, + func, + where, + nrows, + start=None, + stop=None, + iterator: bool = False, + chunksize: int | None = None, + auto_close: bool = False, + ) -> None: + self.store = store + self.s = s + self.func = func + self.where = where + + # set start/stop if they are not set if we are a table + if self.s.is_table: + if nrows is None: + nrows = 0 + if start is None: + start = 0 + if stop is None: + stop = nrows + stop = min(nrows, stop) + + self.nrows = nrows + self.start = start + self.stop = stop + + self.coordinates = None + if iterator or chunksize is not None: + if chunksize is None: + chunksize = 100000 + self.chunksize = int(chunksize) + else: + self.chunksize = None + + self.auto_close = auto_close + + def __iter__(self) -> Iterator: + # iterate + current = self.start + if self.coordinates is None: + raise ValueError("Cannot iterate until get_result is called.") + while current < self.stop: + stop = min(current + self.chunksize, self.stop) + value = self.func(None, None, self.coordinates[current:stop]) + current = stop + if value is None or not len(value): + continue + + yield value + + self.close() + + def close(self) -> None: + if self.auto_close: + self.store.close() + + def get_result(self, coordinates: bool = False): + # return the actual iterator + if self.chunksize is not None: + if not isinstance(self.s, Table): + raise TypeError("can only use an iterator or chunksize on a table") + + self.coordinates = self.s.read_coordinates(where=self.where) + + return self + + # if specified read via coordinates (necessary for multiple selections + if coordinates: + if not isinstance(self.s, Table): + raise TypeError("can only read_coordinates on a table") + where = self.s.read_coordinates( + where=self.where, start=self.start, stop=self.stop + ) + else: + where = self.where + + # directly return the result + results = self.func(self.start, self.stop, where) + self.close() + return results + + +class IndexCol: + """ + an index column description class + + Parameters + ---------- + axis : axis which I reference + values : the ndarray like converted values + kind : a string description of this type + typ : the pytables type + pos : the position in the pytables + + """ + + is_an_indexable: bool = True + is_data_indexable: bool = True + _info_fields = ["freq", "tz", "index_name"] + + def __init__( + self, + name: str, + values=None, + kind=None, + typ=None, + cname: str | None = None, + axis=None, + pos=None, + freq=None, + tz=None, + index_name=None, + ordered=None, + table=None, + meta=None, + metadata=None, + ) -> None: + if not isinstance(name, str): + raise ValueError("`name` must be a str.") + + self.values = values + self.kind = kind + self.typ = typ + self.name = name + self.cname = cname or name + self.axis = axis + self.pos = pos + self.freq = freq + self.tz = tz + self.index_name = index_name + self.ordered = ordered + self.table = table + self.meta = meta + self.metadata = metadata + + if pos is not None: + self.set_pos(pos) + + # These are ensured as long as the passed arguments match the + # constructor annotations. + assert isinstance(self.name, str) + assert isinstance(self.cname, str) + + @property + def itemsize(self) -> int: + # Assumes self.typ has already been initialized + return self.typ.itemsize + + @property + def kind_attr(self) -> str: + return f"{self.name}_kind" + + def set_pos(self, pos: int) -> None: + """set the position of this column in the Table""" + self.pos = pos + if pos is not None and self.typ is not None: + self.typ._v_pos = pos + + def __repr__(self) -> str: + temp = tuple( + map(pprint_thing, (self.name, self.cname, self.axis, self.pos, self.kind)) + ) + return ",".join( + [ + f"{key}->{value}" + for key, value in zip(["name", "cname", "axis", "pos", "kind"], temp) + ] + ) + + def __eq__(self, other: object) -> bool: + """compare 2 col items""" + return all( + getattr(self, a, None) == getattr(other, a, None) + for a in ["name", "cname", "axis", "pos"] + ) + + def __ne__(self, other) -> bool: + return not self.__eq__(other) + + @property + def is_indexed(self) -> bool: + """return whether I am an indexed column""" + if not hasattr(self.table, "cols"): + # e.g. if infer hasn't been called yet, self.table will be None. + return False + return getattr(self.table.cols, self.cname).is_indexed + + def convert( + self, values: np.ndarray, nan_rep, encoding: str, errors: str + ) -> tuple[np.ndarray, np.ndarray] | tuple[Index, Index]: + """ + Convert the data from this selection to the appropriate pandas type. + """ + assert isinstance(values, np.ndarray), type(values) + + # values is a recarray + if values.dtype.fields is not None: + # Copy, otherwise values will be a view + # preventing the original recarry from being free'ed + values = values[self.cname].copy() + + val_kind = _ensure_decoded(self.kind) + values = _maybe_convert(values, val_kind, encoding, errors) + kwargs = {} + kwargs["name"] = _ensure_decoded(self.index_name) + + if self.freq is not None: + kwargs["freq"] = _ensure_decoded(self.freq) + + factory: type[Index | DatetimeIndex] = Index + if lib.is_np_dtype(values.dtype, "M") or isinstance( + values.dtype, DatetimeTZDtype + ): + factory = DatetimeIndex + elif values.dtype == "i8" and "freq" in kwargs: + # PeriodIndex data is stored as i8 + # error: Incompatible types in assignment (expression has type + # "Callable[[Any, KwArg(Any)], PeriodIndex]", variable has type + # "Union[Type[Index], Type[DatetimeIndex]]") + factory = lambda x, **kwds: PeriodIndex.from_ordinals( # type: ignore[assignment] + x, freq=kwds.get("freq", None) + )._rename( + kwds["name"] + ) + + # making an Index instance could throw a number of different errors + try: + new_pd_index = factory(values, **kwargs) + except ValueError: + # if the output freq is different that what we recorded, + # it should be None (see also 'doc example part 2') + if "freq" in kwargs: + kwargs["freq"] = None + new_pd_index = factory(values, **kwargs) + final_pd_index = _set_tz(new_pd_index, self.tz) + return final_pd_index, final_pd_index + + def take_data(self): + """return the values""" + return self.values + + @property + def attrs(self): + return self.table._v_attrs + + @property + def description(self): + return self.table.description + + @property + def col(self): + """return my current col description""" + return getattr(self.description, self.cname, None) + + @property + def cvalues(self): + """return my cython values""" + return self.values + + def __iter__(self) -> Iterator: + return iter(self.values) + + def maybe_set_size(self, min_itemsize=None) -> None: + """ + maybe set a string col itemsize: + min_itemsize can be an integer or a dict with this columns name + with an integer size + """ + if _ensure_decoded(self.kind) == "string": + if isinstance(min_itemsize, dict): + min_itemsize = min_itemsize.get(self.name) + + if min_itemsize is not None and self.typ.itemsize < min_itemsize: + self.typ = _tables().StringCol(itemsize=min_itemsize, pos=self.pos) + + def validate_names(self) -> None: + pass + + def validate_and_set(self, handler: AppendableTable, append: bool) -> None: + self.table = handler.table + self.validate_col() + self.validate_attr(append) + self.validate_metadata(handler) + self.write_metadata(handler) + self.set_attr() + + def validate_col(self, itemsize=None): + """validate this column: return the compared against itemsize""" + # validate this column for string truncation (or reset to the max size) + if _ensure_decoded(self.kind) == "string": + c = self.col + if c is not None: + if itemsize is None: + itemsize = self.itemsize + if c.itemsize < itemsize: + raise ValueError( + f"Trying to store a string with len [{itemsize}] in " + f"[{self.cname}] column but\nthis column has a limit of " + f"[{c.itemsize}]!\nConsider using min_itemsize to " + "preset the sizes on these columns" + ) + return c.itemsize + + return None + + def validate_attr(self, append: bool) -> None: + # check for backwards incompatibility + if append: + existing_kind = getattr(self.attrs, self.kind_attr, None) + if existing_kind is not None and existing_kind != self.kind: + raise TypeError( + f"incompatible kind in col [{existing_kind} - {self.kind}]" + ) + + def update_info(self, info) -> None: + """ + set/update the info for this indexable with the key/value + if there is a conflict raise/warn as needed + """ + for key in self._info_fields: + value = getattr(self, key, None) + idx = info.setdefault(self.name, {}) + + existing_value = idx.get(key) + if key in idx and value is not None and existing_value != value: + # frequency/name just warn + if key in ["freq", "index_name"]: + ws = attribute_conflict_doc % (key, existing_value, value) + warnings.warn( + ws, AttributeConflictWarning, stacklevel=find_stack_level() + ) + + # reset + idx[key] = None + setattr(self, key, None) + + else: + raise ValueError( + f"invalid info for [{self.name}] for [{key}], " + f"existing_value [{existing_value}] conflicts with " + f"new value [{value}]" + ) + elif value is not None or existing_value is not None: + idx[key] = value + + def set_info(self, info) -> None: + """set my state from the passed info""" + idx = info.get(self.name) + if idx is not None: + self.__dict__.update(idx) + + def set_attr(self) -> None: + """set the kind for this column""" + setattr(self.attrs, self.kind_attr, self.kind) + + def validate_metadata(self, handler: AppendableTable) -> None: + """validate that kind=category does not change the categories""" + if self.meta == "category": + new_metadata = self.metadata + cur_metadata = handler.read_metadata(self.cname) + if ( + new_metadata is not None + and cur_metadata is not None + and not array_equivalent( + new_metadata, cur_metadata, strict_nan=True, dtype_equal=True + ) + ): + raise ValueError( + "cannot append a categorical with " + "different categories to the existing" + ) + + def write_metadata(self, handler: AppendableTable) -> None: + """set the meta data""" + if self.metadata is not None: + handler.write_metadata(self.cname, self.metadata) + + +class GenericIndexCol(IndexCol): + """an index which is not represented in the data of the table""" + + @property + def is_indexed(self) -> bool: + return False + + def convert( + self, values: np.ndarray, nan_rep, encoding: str, errors: str + ) -> tuple[Index, Index]: + """ + Convert the data from this selection to the appropriate pandas type. + + Parameters + ---------- + values : np.ndarray + nan_rep : str + encoding : str + errors : str + """ + assert isinstance(values, np.ndarray), type(values) + + index = RangeIndex(len(values)) + return index, index + + def set_attr(self) -> None: + pass + + +class DataCol(IndexCol): + """ + a data holding column, by definition this is not indexable + + Parameters + ---------- + data : the actual data + cname : the column name in the table to hold the data (typically + values) + meta : a string description of the metadata + metadata : the actual metadata + """ + + is_an_indexable = False + is_data_indexable = False + _info_fields = ["tz", "ordered"] + + def __init__( + self, + name: str, + values=None, + kind=None, + typ=None, + cname: str | None = None, + pos=None, + tz=None, + ordered=None, + table=None, + meta=None, + metadata=None, + dtype: DtypeArg | None = None, + data=None, + ) -> None: + super().__init__( + name=name, + values=values, + kind=kind, + typ=typ, + pos=pos, + cname=cname, + tz=tz, + ordered=ordered, + table=table, + meta=meta, + metadata=metadata, + ) + self.dtype = dtype + self.data = data + + @property + def dtype_attr(self) -> str: + return f"{self.name}_dtype" + + @property + def meta_attr(self) -> str: + return f"{self.name}_meta" + + def __repr__(self) -> str: + temp = tuple( + map( + pprint_thing, (self.name, self.cname, self.dtype, self.kind, self.shape) + ) + ) + return ",".join( + [ + f"{key}->{value}" + for key, value in zip(["name", "cname", "dtype", "kind", "shape"], temp) + ] + ) + + def __eq__(self, other: object) -> bool: + """compare 2 col items""" + return all( + getattr(self, a, None) == getattr(other, a, None) + for a in ["name", "cname", "dtype", "pos"] + ) + + def set_data(self, data: ArrayLike) -> None: + assert data is not None + assert self.dtype is None + + data, dtype_name = _get_data_and_dtype_name(data) + + self.data = data + self.dtype = dtype_name + self.kind = _dtype_to_kind(dtype_name) + + def take_data(self): + """return the data""" + return self.data + + @classmethod + def _get_atom(cls, values: ArrayLike) -> Col: + """ + Get an appropriately typed and shaped pytables.Col object for values. + """ + dtype = values.dtype + # error: Item "ExtensionDtype" of "Union[ExtensionDtype, dtype[Any]]" has no + # attribute "itemsize" + itemsize = dtype.itemsize # type: ignore[union-attr] + + shape = values.shape + if values.ndim == 1: + # EA, use block shape pretending it is 2D + # TODO(EA2D): not necessary with 2D EAs + shape = (1, values.size) + + if isinstance(values, Categorical): + codes = values.codes + atom = cls.get_atom_data(shape, kind=codes.dtype.name) + elif lib.is_np_dtype(dtype, "M") or isinstance(dtype, DatetimeTZDtype): + atom = cls.get_atom_datetime64(shape) + elif lib.is_np_dtype(dtype, "m"): + atom = cls.get_atom_timedelta64(shape) + elif is_complex_dtype(dtype): + atom = _tables().ComplexCol(itemsize=itemsize, shape=shape[0]) + elif is_string_dtype(dtype): + atom = cls.get_atom_string(shape, itemsize) + else: + atom = cls.get_atom_data(shape, kind=dtype.name) + + return atom + + @classmethod + def get_atom_string(cls, shape, itemsize): + return _tables().StringCol(itemsize=itemsize, shape=shape[0]) + + @classmethod + def get_atom_coltype(cls, kind: str) -> type[Col]: + """return the PyTables column class for this column""" + if kind.startswith("uint"): + k4 = kind[4:] + col_name = f"UInt{k4}Col" + elif kind.startswith("period"): + # we store as integer + col_name = "Int64Col" + else: + kcap = kind.capitalize() + col_name = f"{kcap}Col" + + return getattr(_tables(), col_name) + + @classmethod + def get_atom_data(cls, shape, kind: str) -> Col: + return cls.get_atom_coltype(kind=kind)(shape=shape[0]) + + @classmethod + def get_atom_datetime64(cls, shape): + return _tables().Int64Col(shape=shape[0]) + + @classmethod + def get_atom_timedelta64(cls, shape): + return _tables().Int64Col(shape=shape[0]) + + @property + def shape(self): + return getattr(self.data, "shape", None) + + @property + def cvalues(self): + """return my cython values""" + return self.data + + def validate_attr(self, append) -> None: + """validate that we have the same order as the existing & same dtype""" + if append: + existing_fields = getattr(self.attrs, self.kind_attr, None) + if existing_fields is not None and existing_fields != list(self.values): + raise ValueError("appended items do not match existing items in table!") + + existing_dtype = getattr(self.attrs, self.dtype_attr, None) + if existing_dtype is not None and existing_dtype != self.dtype: + raise ValueError( + "appended items dtype do not match existing items dtype in table!" + ) + + def convert(self, values: np.ndarray, nan_rep, encoding: str, errors: str): + """ + Convert the data from this selection to the appropriate pandas type. + + Parameters + ---------- + values : np.ndarray + nan_rep : + encoding : str + errors : str + + Returns + ------- + index : listlike to become an Index + data : ndarraylike to become a column + """ + assert isinstance(values, np.ndarray), type(values) + + # values is a recarray + if values.dtype.fields is not None: + values = values[self.cname] + + assert self.typ is not None + if self.dtype is None: + # Note: in tests we never have timedelta64 or datetime64, + # so the _get_data_and_dtype_name may be unnecessary + converted, dtype_name = _get_data_and_dtype_name(values) + kind = _dtype_to_kind(dtype_name) + else: + converted = values + dtype_name = self.dtype + kind = self.kind + + assert isinstance(converted, np.ndarray) # for mypy + + # use the meta if needed + meta = _ensure_decoded(self.meta) + metadata = self.metadata + ordered = self.ordered + tz = self.tz + + assert dtype_name is not None + # convert to the correct dtype + dtype = _ensure_decoded(dtype_name) + + # reverse converts + if dtype.startswith("datetime64"): + # recreate with tz if indicated + converted = _set_tz(converted, tz, coerce=True) + + elif dtype == "timedelta64": + converted = np.asarray(converted, dtype="m8[ns]") + elif dtype == "date": + try: + converted = np.asarray( + [date.fromordinal(v) for v in converted], dtype=object + ) + except ValueError: + converted = np.asarray( + [date.fromtimestamp(v) for v in converted], dtype=object + ) + + elif meta == "category": + # we have a categorical + categories = metadata + codes = converted.ravel() + + # if we have stored a NaN in the categories + # then strip it; in theory we could have BOTH + # -1s in the codes and nulls :< + if categories is None: + # Handle case of NaN-only categorical columns in which case + # the categories are an empty array; when this is stored, + # pytables cannot write a zero-len array, so on readback + # the categories would be None and `read_hdf()` would fail. + categories = Index([], dtype=np.float64) + else: + mask = isna(categories) + if mask.any(): + categories = categories[~mask] + codes[codes != -1] -= mask.astype(int).cumsum()._values + + converted = Categorical.from_codes( + codes, categories=categories, ordered=ordered, validate=False + ) + + else: + try: + converted = converted.astype(dtype, copy=False) + except TypeError: + converted = converted.astype("O", copy=False) + + # convert nans / decode + if _ensure_decoded(kind) == "string": + converted = _unconvert_string_array( + converted, nan_rep=nan_rep, encoding=encoding, errors=errors + ) + + return self.values, converted + + def set_attr(self) -> None: + """set the data for this column""" + setattr(self.attrs, self.kind_attr, self.values) + setattr(self.attrs, self.meta_attr, self.meta) + assert self.dtype is not None + setattr(self.attrs, self.dtype_attr, self.dtype) + + +class DataIndexableCol(DataCol): + """represent a data column that can be indexed""" + + is_data_indexable = True + + def validate_names(self) -> None: + if not is_string_dtype(Index(self.values).dtype): + # TODO: should the message here be more specifically non-str? + raise ValueError("cannot have non-object label DataIndexableCol") + + @classmethod + def get_atom_string(cls, shape, itemsize): + return _tables().StringCol(itemsize=itemsize) + + @classmethod + def get_atom_data(cls, shape, kind: str) -> Col: + return cls.get_atom_coltype(kind=kind)() + + @classmethod + def get_atom_datetime64(cls, shape): + return _tables().Int64Col() + + @classmethod + def get_atom_timedelta64(cls, shape): + return _tables().Int64Col() + + +class GenericDataIndexableCol(DataIndexableCol): + """represent a generic pytables data column""" + + +class Fixed: + """ + represent an object in my store + facilitate read/write of various types of objects + this is an abstract base class + + Parameters + ---------- + parent : HDFStore + group : Node + The group node where the table resides. + """ + + pandas_kind: str + format_type: str = "fixed" # GH#30962 needed by dask + obj_type: type[DataFrame | Series] + ndim: int + parent: HDFStore + is_table: bool = False + + def __init__( + self, + parent: HDFStore, + group: Node, + encoding: str | None = "UTF-8", + errors: str = "strict", + ) -> None: + assert isinstance(parent, HDFStore), type(parent) + assert _table_mod is not None # needed for mypy + assert isinstance(group, _table_mod.Node), type(group) + self.parent = parent + self.group = group + self.encoding = _ensure_encoding(encoding) + self.errors = errors + + @property + def is_old_version(self) -> bool: + return self.version[0] <= 0 and self.version[1] <= 10 and self.version[2] < 1 + + @property + def version(self) -> tuple[int, int, int]: + """compute and set our version""" + version = _ensure_decoded(getattr(self.group._v_attrs, "pandas_version", None)) + try: + version = tuple(int(x) for x in version.split(".")) + if len(version) == 2: + version = version + (0,) + except AttributeError: + version = (0, 0, 0) + return version + + @property + def pandas_type(self): + return _ensure_decoded(getattr(self.group._v_attrs, "pandas_type", None)) + + def __repr__(self) -> str: + """return a pretty representation of myself""" + self.infer_axes() + s = self.shape + if s is not None: + if isinstance(s, (list, tuple)): + jshape = ",".join([pprint_thing(x) for x in s]) + s = f"[{jshape}]" + return f"{self.pandas_type:12.12} (shape->{s})" + return self.pandas_type + + def set_object_info(self) -> None: + """set my pandas type & version""" + self.attrs.pandas_type = str(self.pandas_kind) + self.attrs.pandas_version = str(_version) + + def copy(self) -> Fixed: + new_self = copy.copy(self) + return new_self + + @property + def shape(self): + return self.nrows + + @property + def pathname(self): + return self.group._v_pathname + + @property + def _handle(self): + return self.parent._handle + + @property + def _filters(self): + return self.parent._filters + + @property + def _complevel(self) -> int: + return self.parent._complevel + + @property + def _fletcher32(self) -> bool: + return self.parent._fletcher32 + + @property + def attrs(self): + return self.group._v_attrs + + def set_attrs(self) -> None: + """set our object attributes""" + + def get_attrs(self) -> None: + """get our object attributes""" + + @property + def storable(self): + """return my storable""" + return self.group + + @property + def is_exists(self) -> bool: + return False + + @property + def nrows(self): + return getattr(self.storable, "nrows", None) + + def validate(self, other) -> Literal[True] | None: + """validate against an existing storable""" + if other is None: + return None + return True + + def validate_version(self, where=None) -> None: + """are we trying to operate on an old version?""" + + def infer_axes(self) -> bool: + """ + infer the axes of my storer + return a boolean indicating if we have a valid storer or not + """ + s = self.storable + if s is None: + return False + self.get_attrs() + return True + + def read( + self, + where=None, + columns=None, + start: int | None = None, + stop: int | None = None, + ): + raise NotImplementedError( + "cannot read on an abstract storer: subclasses should implement" + ) + + def write(self, obj, **kwargs) -> None: + raise NotImplementedError( + "cannot write on an abstract storer: subclasses should implement" + ) + + def delete( + self, where=None, start: int | None = None, stop: int | None = None + ) -> None: + """ + support fully deleting the node in its entirety (only) - where + specification must be None + """ + if com.all_none(where, start, stop): + self._handle.remove_node(self.group, recursive=True) + return None + + raise TypeError("cannot delete on an abstract storer") + + +class GenericFixed(Fixed): + """a generified fixed version""" + + _index_type_map = {DatetimeIndex: "datetime", PeriodIndex: "period"} + _reverse_index_map = {v: k for k, v in _index_type_map.items()} + attributes: list[str] = [] + + # indexer helpers + def _class_to_alias(self, cls) -> str: + return self._index_type_map.get(cls, "") + + def _alias_to_class(self, alias): + if isinstance(alias, type): # pragma: no cover + # compat: for a short period of time master stored types + return alias + return self._reverse_index_map.get(alias, Index) + + def _get_index_factory(self, attrs): + index_class = self._alias_to_class( + _ensure_decoded(getattr(attrs, "index_class", "")) + ) + + factory: Callable + + if index_class == DatetimeIndex: + + def f(values, freq=None, tz=None): + # data are already in UTC, localize and convert if tz present + dta = DatetimeArray._simple_new( + values.values, dtype=values.dtype, freq=freq + ) + result = DatetimeIndex._simple_new(dta, name=None) + if tz is not None: + result = result.tz_localize("UTC").tz_convert(tz) + return result + + factory = f + elif index_class == PeriodIndex: + + def f(values, freq=None, tz=None): + dtype = PeriodDtype(freq) + parr = PeriodArray._simple_new(values, dtype=dtype) + return PeriodIndex._simple_new(parr, name=None) + + factory = f + else: + factory = index_class + + kwargs = {} + if "freq" in attrs: + kwargs["freq"] = attrs["freq"] + if index_class is Index: + # DTI/PI would be gotten by _alias_to_class + factory = TimedeltaIndex + + if "tz" in attrs: + if isinstance(attrs["tz"], bytes): + # created by python2 + kwargs["tz"] = attrs["tz"].decode("utf-8") + else: + # created by python3 + kwargs["tz"] = attrs["tz"] + assert index_class is DatetimeIndex # just checking + + return factory, kwargs + + def validate_read(self, columns, where) -> None: + """ + raise if any keywords are passed which are not-None + """ + if columns is not None: + raise TypeError( + "cannot pass a column specification when reading " + "a Fixed format store. this store must be selected in its entirety" + ) + if where is not None: + raise TypeError( + "cannot pass a where specification when reading " + "from a Fixed format store. this store must be selected in its entirety" + ) + + @property + def is_exists(self) -> bool: + return True + + def set_attrs(self) -> None: + """set our object attributes""" + self.attrs.encoding = self.encoding + self.attrs.errors = self.errors + + def get_attrs(self) -> None: + """retrieve our attributes""" + self.encoding = _ensure_encoding(getattr(self.attrs, "encoding", None)) + self.errors = _ensure_decoded(getattr(self.attrs, "errors", "strict")) + for n in self.attributes: + setattr(self, n, _ensure_decoded(getattr(self.attrs, n, None))) + + def write(self, obj, **kwargs) -> None: + self.set_attrs() + + def read_array(self, key: str, start: int | None = None, stop: int | None = None): + """read an array for the specified node (off of group""" + import tables + + node = getattr(self.group, key) + attrs = node._v_attrs + + transposed = getattr(attrs, "transposed", False) + + if isinstance(node, tables.VLArray): + ret = node[0][start:stop] + else: + dtype = _ensure_decoded(getattr(attrs, "value_type", None)) + shape = getattr(attrs, "shape", None) + + if shape is not None: + # length 0 axis + ret = np.empty(shape, dtype=dtype) + else: + ret = node[start:stop] + + if dtype and dtype.startswith("datetime64"): + # reconstruct a timezone if indicated + tz = getattr(attrs, "tz", None) + ret = _set_tz(ret, tz, coerce=True) + + elif dtype == "timedelta64": + ret = np.asarray(ret, dtype="m8[ns]") + + if transposed: + return ret.T + else: + return ret + + def read_index( + self, key: str, start: int | None = None, stop: int | None = None + ) -> Index: + variety = _ensure_decoded(getattr(self.attrs, f"{key}_variety")) + + if variety == "multi": + return self.read_multi_index(key, start=start, stop=stop) + elif variety == "regular": + node = getattr(self.group, key) + index = self.read_index_node(node, start=start, stop=stop) + return index + else: # pragma: no cover + raise TypeError(f"unrecognized index variety: {variety}") + + def write_index(self, key: str, index: Index) -> None: + if isinstance(index, MultiIndex): + setattr(self.attrs, f"{key}_variety", "multi") + self.write_multi_index(key, index) + else: + setattr(self.attrs, f"{key}_variety", "regular") + converted = _convert_index("index", index, self.encoding, self.errors) + + self.write_array(key, converted.values) + + node = getattr(self.group, key) + node._v_attrs.kind = converted.kind + node._v_attrs.name = index.name + + if isinstance(index, (DatetimeIndex, PeriodIndex)): + node._v_attrs.index_class = self._class_to_alias(type(index)) + + if isinstance(index, (DatetimeIndex, PeriodIndex, TimedeltaIndex)): + node._v_attrs.freq = index.freq + + if isinstance(index, DatetimeIndex) and index.tz is not None: + node._v_attrs.tz = _get_tz(index.tz) + + def write_multi_index(self, key: str, index: MultiIndex) -> None: + setattr(self.attrs, f"{key}_nlevels", index.nlevels) + + for i, (lev, level_codes, name) in enumerate( + zip(index.levels, index.codes, index.names) + ): + # write the level + if isinstance(lev.dtype, ExtensionDtype): + raise NotImplementedError( + "Saving a MultiIndex with an extension dtype is not supported." + ) + level_key = f"{key}_level{i}" + conv_level = _convert_index(level_key, lev, self.encoding, self.errors) + self.write_array(level_key, conv_level.values) + node = getattr(self.group, level_key) + node._v_attrs.kind = conv_level.kind + node._v_attrs.name = name + + # write the name + setattr(node._v_attrs, f"{key}_name{name}", name) + + # write the labels + label_key = f"{key}_label{i}" + self.write_array(label_key, level_codes) + + def read_multi_index( + self, key: str, start: int | None = None, stop: int | None = None + ) -> MultiIndex: + nlevels = getattr(self.attrs, f"{key}_nlevels") + + levels = [] + codes = [] + names: list[Hashable] = [] + for i in range(nlevels): + level_key = f"{key}_level{i}" + node = getattr(self.group, level_key) + lev = self.read_index_node(node, start=start, stop=stop) + levels.append(lev) + names.append(lev.name) + + label_key = f"{key}_label{i}" + level_codes = self.read_array(label_key, start=start, stop=stop) + codes.append(level_codes) + + return MultiIndex( + levels=levels, codes=codes, names=names, verify_integrity=True + ) + + def read_index_node( + self, node: Node, start: int | None = None, stop: int | None = None + ) -> Index: + data = node[start:stop] + # If the index was an empty array write_array_empty() will + # have written a sentinel. Here we replace it with the original. + if "shape" in node._v_attrs and np.prod(node._v_attrs.shape) == 0: + data = np.empty(node._v_attrs.shape, dtype=node._v_attrs.value_type) + kind = _ensure_decoded(node._v_attrs.kind) + name = None + + if "name" in node._v_attrs: + name = _ensure_str(node._v_attrs.name) + name = _ensure_decoded(name) + + attrs = node._v_attrs + factory, kwargs = self._get_index_factory(attrs) + + if kind in ("date", "object"): + index = factory( + _unconvert_index( + data, kind, encoding=self.encoding, errors=self.errors + ), + dtype=object, + **kwargs, + ) + else: + index = factory( + _unconvert_index( + data, kind, encoding=self.encoding, errors=self.errors + ), + **kwargs, + ) + + index.name = name + + return index + + def write_array_empty(self, key: str, value: ArrayLike) -> None: + """write a 0-len array""" + # ugly hack for length 0 axes + arr = np.empty((1,) * value.ndim) + self._handle.create_array(self.group, key, arr) + node = getattr(self.group, key) + node._v_attrs.value_type = str(value.dtype) + node._v_attrs.shape = value.shape + + def write_array( + self, key: str, obj: AnyArrayLike, items: Index | None = None + ) -> None: + # TODO: we only have a few tests that get here, the only EA + # that gets passed is DatetimeArray, and we never have + # both self._filters and EA + + value = extract_array(obj, extract_numpy=True) + + if key in self.group: + self._handle.remove_node(self.group, key) + + # Transform needed to interface with pytables row/col notation + empty_array = value.size == 0 + transposed = False + + if isinstance(value.dtype, CategoricalDtype): + raise NotImplementedError( + "Cannot store a category dtype in a HDF5 dataset that uses format=" + '"fixed". Use format="table".' + ) + if not empty_array: + if hasattr(value, "T"): + # ExtensionArrays (1d) may not have transpose. + value = value.T + transposed = True + + atom = None + if self._filters is not None: + with suppress(ValueError): + # get the atom for this datatype + atom = _tables().Atom.from_dtype(value.dtype) + + if atom is not None: + # We only get here if self._filters is non-None and + # the Atom.from_dtype call succeeded + + # create an empty chunked array and fill it from value + if not empty_array: + ca = self._handle.create_carray( + self.group, key, atom, value.shape, filters=self._filters + ) + ca[:] = value + + else: + self.write_array_empty(key, value) + + elif value.dtype.type == np.object_: + # infer the type, warn if we have a non-string type here (for + # performance) + inferred_type = lib.infer_dtype(value, skipna=False) + if empty_array: + pass + elif inferred_type == "string": + pass + else: + ws = performance_doc % (inferred_type, key, items) + warnings.warn(ws, PerformanceWarning, stacklevel=find_stack_level()) + + vlarr = self._handle.create_vlarray(self.group, key, _tables().ObjectAtom()) + vlarr.append(value) + + elif lib.is_np_dtype(value.dtype, "M"): + self._handle.create_array(self.group, key, value.view("i8")) + getattr(self.group, key)._v_attrs.value_type = str(value.dtype) + elif isinstance(value.dtype, DatetimeTZDtype): + # store as UTC + # with a zone + + # error: Item "ExtensionArray" of "Union[Any, ExtensionArray]" has no + # attribute "asi8" + self._handle.create_array( + self.group, key, value.asi8 # type: ignore[union-attr] + ) + + node = getattr(self.group, key) + # error: Item "ExtensionArray" of "Union[Any, ExtensionArray]" has no + # attribute "tz" + node._v_attrs.tz = _get_tz(value.tz) # type: ignore[union-attr] + node._v_attrs.value_type = f"datetime64[{value.dtype.unit}]" + elif lib.is_np_dtype(value.dtype, "m"): + self._handle.create_array(self.group, key, value.view("i8")) + getattr(self.group, key)._v_attrs.value_type = "timedelta64" + elif empty_array: + self.write_array_empty(key, value) + else: + self._handle.create_array(self.group, key, value) + + getattr(self.group, key)._v_attrs.transposed = transposed + + +class SeriesFixed(GenericFixed): + pandas_kind = "series" + attributes = ["name"] + + name: Hashable + + @property + def shape(self): + try: + return (len(self.group.values),) + except (TypeError, AttributeError): + return None + + def read( + self, + where=None, + columns=None, + start: int | None = None, + stop: int | None = None, + ) -> Series: + self.validate_read(columns, where) + index = self.read_index("index", start=start, stop=stop) + values = self.read_array("values", start=start, stop=stop) + result = Series(values, index=index, name=self.name, copy=False) + if using_pyarrow_string_dtype() and is_string_array(values, skipna=True): + result = result.astype("string[pyarrow_numpy]") + return result + + def write(self, obj, **kwargs) -> None: + super().write(obj, **kwargs) + self.write_index("index", obj.index) + self.write_array("values", obj) + self.attrs.name = obj.name + + +class BlockManagerFixed(GenericFixed): + attributes = ["ndim", "nblocks"] + + nblocks: int + + @property + def shape(self) -> Shape | None: + try: + ndim = self.ndim + + # items + items = 0 + for i in range(self.nblocks): + node = getattr(self.group, f"block{i}_items") + shape = getattr(node, "shape", None) + if shape is not None: + items += shape[0] + + # data shape + node = self.group.block0_values + shape = getattr(node, "shape", None) + if shape is not None: + shape = list(shape[0 : (ndim - 1)]) + else: + shape = [] + + shape.append(items) + + return shape + except AttributeError: + return None + + def read( + self, + where=None, + columns=None, + start: int | None = None, + stop: int | None = None, + ) -> DataFrame: + # start, stop applied to rows, so 0th axis only + self.validate_read(columns, where) + select_axis = self.obj_type()._get_block_manager_axis(0) + + axes = [] + for i in range(self.ndim): + _start, _stop = (start, stop) if i == select_axis else (None, None) + ax = self.read_index(f"axis{i}", start=_start, stop=_stop) + axes.append(ax) + + items = axes[0] + dfs = [] + + for i in range(self.nblocks): + blk_items = self.read_index(f"block{i}_items") + values = self.read_array(f"block{i}_values", start=_start, stop=_stop) + + columns = items[items.get_indexer(blk_items)] + df = DataFrame(values.T, columns=columns, index=axes[1], copy=False) + if using_pyarrow_string_dtype() and is_string_array(values, skipna=True): + df = df.astype("string[pyarrow_numpy]") + dfs.append(df) + + if len(dfs) > 0: + out = concat(dfs, axis=1, copy=True) + if using_copy_on_write(): + # with CoW, concat ignores the copy keyword. Here, we still want + # to copy to enforce optimized column-major layout + out = out.copy() + out = out.reindex(columns=items, copy=False) + return out + + return DataFrame(columns=axes[0], index=axes[1]) + + def write(self, obj, **kwargs) -> None: + super().write(obj, **kwargs) + + # TODO(ArrayManager) HDFStore relies on accessing the blocks + if isinstance(obj._mgr, ArrayManager): + obj = obj._as_manager("block") + + data = obj._mgr + if not data.is_consolidated(): + data = data.consolidate() + + self.attrs.ndim = data.ndim + for i, ax in enumerate(data.axes): + if i == 0 and (not ax.is_unique): + raise ValueError("Columns index has to be unique for fixed format") + self.write_index(f"axis{i}", ax) + + # Supporting mixed-type DataFrame objects...nontrivial + self.attrs.nblocks = len(data.blocks) + for i, blk in enumerate(data.blocks): + # I have no idea why, but writing values before items fixed #2299 + blk_items = data.items.take(blk.mgr_locs) + self.write_array(f"block{i}_values", blk.values, items=blk_items) + self.write_index(f"block{i}_items", blk_items) + + +class FrameFixed(BlockManagerFixed): + pandas_kind = "frame" + obj_type = DataFrame + + +class Table(Fixed): + """ + represent a table: + facilitate read/write of various types of tables + + Attrs in Table Node + ------------------- + These are attributes that are store in the main table node, they are + necessary to recreate these tables when read back in. + + index_axes : a list of tuples of the (original indexing axis and + index column) + non_index_axes: a list of tuples of the (original index axis and + columns on a non-indexing axis) + values_axes : a list of the columns which comprise the data of this + table + data_columns : a list of the columns that we are allowing indexing + (these become single columns in values_axes) + nan_rep : the string to use for nan representations for string + objects + levels : the names of levels + metadata : the names of the metadata columns + """ + + pandas_kind = "wide_table" + format_type: str = "table" # GH#30962 needed by dask + table_type: str + levels: int | list[Hashable] = 1 + is_table = True + + metadata: list + + def __init__( + self, + parent: HDFStore, + group: Node, + encoding: str | None = None, + errors: str = "strict", + index_axes: list[IndexCol] | None = None, + non_index_axes: list[tuple[AxisInt, Any]] | None = None, + values_axes: list[DataCol] | None = None, + data_columns: list | None = None, + info: dict | None = None, + nan_rep=None, + ) -> None: + super().__init__(parent, group, encoding=encoding, errors=errors) + self.index_axes = index_axes or [] + self.non_index_axes = non_index_axes or [] + self.values_axes = values_axes or [] + self.data_columns = data_columns or [] + self.info = info or {} + self.nan_rep = nan_rep + + @property + def table_type_short(self) -> str: + return self.table_type.split("_")[0] + + def __repr__(self) -> str: + """return a pretty representation of myself""" + self.infer_axes() + jdc = ",".join(self.data_columns) if len(self.data_columns) else "" + dc = f",dc->[{jdc}]" + + ver = "" + if self.is_old_version: + jver = ".".join([str(x) for x in self.version]) + ver = f"[{jver}]" + + jindex_axes = ",".join([a.name for a in self.index_axes]) + return ( + f"{self.pandas_type:12.12}{ver} " + f"(typ->{self.table_type_short},nrows->{self.nrows}," + f"ncols->{self.ncols},indexers->[{jindex_axes}]{dc})" + ) + + def __getitem__(self, c: str): + """return the axis for c""" + for a in self.axes: + if c == a.name: + return a + return None + + def validate(self, other) -> None: + """validate against an existing table""" + if other is None: + return + + if other.table_type != self.table_type: + raise TypeError( + "incompatible table_type with existing " + f"[{other.table_type} - {self.table_type}]" + ) + + for c in ["index_axes", "non_index_axes", "values_axes"]: + sv = getattr(self, c, None) + ov = getattr(other, c, None) + if sv != ov: + # show the error for the specific axes + # Argument 1 to "enumerate" has incompatible type + # "Optional[Any]"; expected "Iterable[Any]" [arg-type] + for i, sax in enumerate(sv): # type: ignore[arg-type] + # Value of type "Optional[Any]" is not indexable [index] + oax = ov[i] # type: ignore[index] + if sax != oax: + raise ValueError( + f"invalid combination of [{c}] on appending data " + f"[{sax}] vs current table [{oax}]" + ) + + # should never get here + raise Exception( + f"invalid combination of [{c}] on appending data [{sv}] vs " + f"current table [{ov}]" + ) + + @property + def is_multi_index(self) -> bool: + """the levels attribute is 1 or a list in the case of a multi-index""" + return isinstance(self.levels, list) + + def validate_multiindex( + self, obj: DataFrame | Series + ) -> tuple[DataFrame, list[Hashable]]: + """ + validate that we can store the multi-index; reset and return the + new object + """ + levels = com.fill_missing_names(obj.index.names) + try: + reset_obj = obj.reset_index() + except ValueError as err: + raise ValueError( + "duplicate names/columns in the multi-index when storing as a table" + ) from err + assert isinstance(reset_obj, DataFrame) # for mypy + return reset_obj, levels + + @property + def nrows_expected(self) -> int: + """based on our axes, compute the expected nrows""" + return np.prod([i.cvalues.shape[0] for i in self.index_axes]) + + @property + def is_exists(self) -> bool: + """has this table been created""" + return "table" in self.group + + @property + def storable(self): + return getattr(self.group, "table", None) + + @property + def table(self): + """return the table group (this is my storable)""" + return self.storable + + @property + def dtype(self): + return self.table.dtype + + @property + def description(self): + return self.table.description + + @property + def axes(self) -> itertools.chain[IndexCol]: + return itertools.chain(self.index_axes, self.values_axes) + + @property + def ncols(self) -> int: + """the number of total columns in the values axes""" + return sum(len(a.values) for a in self.values_axes) + + @property + def is_transposed(self) -> bool: + return False + + @property + def data_orientation(self) -> tuple[int, ...]: + """return a tuple of my permutated axes, non_indexable at the front""" + return tuple( + itertools.chain( + [int(a[0]) for a in self.non_index_axes], + [int(a.axis) for a in self.index_axes], + ) + ) + + def queryables(self) -> dict[str, Any]: + """return a dict of the kinds allowable columns for this object""" + # mypy doesn't recognize DataFrame._AXIS_NAMES, so we re-write it here + axis_names = {0: "index", 1: "columns"} + + # compute the values_axes queryables + d1 = [(a.cname, a) for a in self.index_axes] + d2 = [(axis_names[axis], None) for axis, values in self.non_index_axes] + d3 = [ + (v.cname, v) for v in self.values_axes if v.name in set(self.data_columns) + ] + + return dict(d1 + d2 + d3) + + def index_cols(self): + """return a list of my index cols""" + # Note: each `i.cname` below is assured to be a str. + return [(i.axis, i.cname) for i in self.index_axes] + + def values_cols(self) -> list[str]: + """return a list of my values cols""" + return [i.cname for i in self.values_axes] + + def _get_metadata_path(self, key: str) -> str: + """return the metadata pathname for this key""" + group = self.group._v_pathname + return f"{group}/meta/{key}/meta" + + def write_metadata(self, key: str, values: np.ndarray) -> None: + """ + Write out a metadata array to the key as a fixed-format Series. + + Parameters + ---------- + key : str + values : ndarray + """ + self.parent.put( + self._get_metadata_path(key), + Series(values, copy=False), + format="table", + encoding=self.encoding, + errors=self.errors, + nan_rep=self.nan_rep, + ) + + def read_metadata(self, key: str): + """return the meta data array for this key""" + if getattr(getattr(self.group, "meta", None), key, None) is not None: + return self.parent.select(self._get_metadata_path(key)) + return None + + def set_attrs(self) -> None: + """set our table type & indexables""" + self.attrs.table_type = str(self.table_type) + self.attrs.index_cols = self.index_cols() + self.attrs.values_cols = self.values_cols() + self.attrs.non_index_axes = self.non_index_axes + self.attrs.data_columns = self.data_columns + self.attrs.nan_rep = self.nan_rep + self.attrs.encoding = self.encoding + self.attrs.errors = self.errors + self.attrs.levels = self.levels + self.attrs.info = self.info + + def get_attrs(self) -> None: + """retrieve our attributes""" + self.non_index_axes = getattr(self.attrs, "non_index_axes", None) or [] + self.data_columns = getattr(self.attrs, "data_columns", None) or [] + self.info = getattr(self.attrs, "info", None) or {} + self.nan_rep = getattr(self.attrs, "nan_rep", None) + self.encoding = _ensure_encoding(getattr(self.attrs, "encoding", None)) + self.errors = _ensure_decoded(getattr(self.attrs, "errors", "strict")) + self.levels: list[Hashable] = getattr(self.attrs, "levels", None) or [] + self.index_axes = [a for a in self.indexables if a.is_an_indexable] + self.values_axes = [a for a in self.indexables if not a.is_an_indexable] + + def validate_version(self, where=None) -> None: + """are we trying to operate on an old version?""" + if where is not None: + if self.is_old_version: + ws = incompatibility_doc % ".".join([str(x) for x in self.version]) + warnings.warn( + ws, + IncompatibilityWarning, + stacklevel=find_stack_level(), + ) + + def validate_min_itemsize(self, min_itemsize) -> None: + """ + validate the min_itemsize doesn't contain items that are not in the + axes this needs data_columns to be defined + """ + if min_itemsize is None: + return + if not isinstance(min_itemsize, dict): + return + + q = self.queryables() + for k in min_itemsize: + # ok, apply generally + if k == "values": + continue + if k not in q: + raise ValueError( + f"min_itemsize has the key [{k}] which is not an axis or " + "data_column" + ) + + @cache_readonly + def indexables(self): + """create/cache the indexables if they don't exist""" + _indexables = [] + + desc = self.description + table_attrs = self.table.attrs + + # Note: each of the `name` kwargs below are str, ensured + # by the definition in index_cols. + # index columns + for i, (axis, name) in enumerate(self.attrs.index_cols): + atom = getattr(desc, name) + md = self.read_metadata(name) + meta = "category" if md is not None else None + + kind_attr = f"{name}_kind" + kind = getattr(table_attrs, kind_attr, None) + + index_col = IndexCol( + name=name, + axis=axis, + pos=i, + kind=kind, + typ=atom, + table=self.table, + meta=meta, + metadata=md, + ) + _indexables.append(index_col) + + # values columns + dc = set(self.data_columns) + base_pos = len(_indexables) + + def f(i, c): + assert isinstance(c, str) + klass = DataCol + if c in dc: + klass = DataIndexableCol + + atom = getattr(desc, c) + adj_name = _maybe_adjust_name(c, self.version) + + # TODO: why kind_attr here? + values = getattr(table_attrs, f"{adj_name}_kind", None) + dtype = getattr(table_attrs, f"{adj_name}_dtype", None) + # Argument 1 to "_dtype_to_kind" has incompatible type + # "Optional[Any]"; expected "str" [arg-type] + kind = _dtype_to_kind(dtype) # type: ignore[arg-type] + + md = self.read_metadata(c) + # TODO: figure out why these two versions of `meta` dont always match. + # meta = "category" if md is not None else None + meta = getattr(table_attrs, f"{adj_name}_meta", None) + + obj = klass( + name=adj_name, + cname=c, + values=values, + kind=kind, + pos=base_pos + i, + typ=atom, + table=self.table, + meta=meta, + metadata=md, + dtype=dtype, + ) + return obj + + # Note: the definition of `values_cols` ensures that each + # `c` below is a str. + _indexables.extend([f(i, c) for i, c in enumerate(self.attrs.values_cols)]) + + return _indexables + + def create_index( + self, columns=None, optlevel=None, kind: str | None = None + ) -> None: + """ + Create a pytables index on the specified columns. + + Parameters + ---------- + columns : None, bool, or listlike[str] + Indicate which columns to create an index on. + + * False : Do not create any indexes. + * True : Create indexes on all columns. + * None : Create indexes on all columns. + * listlike : Create indexes on the given columns. + + optlevel : int or None, default None + Optimization level, if None, pytables defaults to 6. + kind : str or None, default None + Kind of index, if None, pytables defaults to "medium". + + Raises + ------ + TypeError if trying to create an index on a complex-type column. + + Notes + ----- + Cannot index Time64Col or ComplexCol. + Pytables must be >= 3.0. + """ + if not self.infer_axes(): + return + if columns is False: + return + + # index all indexables and data_columns + if columns is None or columns is True: + columns = [a.cname for a in self.axes if a.is_data_indexable] + if not isinstance(columns, (tuple, list)): + columns = [columns] + + kw = {} + if optlevel is not None: + kw["optlevel"] = optlevel + if kind is not None: + kw["kind"] = kind + + table = self.table + for c in columns: + v = getattr(table.cols, c, None) + if v is not None: + # remove the index if the kind/optlevel have changed + if v.is_indexed: + index = v.index + cur_optlevel = index.optlevel + cur_kind = index.kind + + if kind is not None and cur_kind != kind: + v.remove_index() + else: + kw["kind"] = cur_kind + + if optlevel is not None and cur_optlevel != optlevel: + v.remove_index() + else: + kw["optlevel"] = cur_optlevel + + # create the index + if not v.is_indexed: + if v.type.startswith("complex"): + raise TypeError( + "Columns containing complex values can be stored but " + "cannot be indexed when using table format. Either use " + "fixed format, set index=False, or do not include " + "the columns containing complex values to " + "data_columns when initializing the table." + ) + v.create_index(**kw) + elif c in self.non_index_axes[0][1]: + # GH 28156 + raise AttributeError( + f"column {c} is not a data_column.\n" + f"In order to read column {c} you must reload the dataframe \n" + f"into HDFStore and include {c} with the data_columns argument." + ) + + def _read_axes( + self, where, start: int | None = None, stop: int | None = None + ) -> list[tuple[np.ndarray, np.ndarray] | tuple[Index, Index]]: + """ + Create the axes sniffed from the table. + + Parameters + ---------- + where : ??? + start : int or None, default None + stop : int or None, default None + + Returns + ------- + List[Tuple[index_values, column_values]] + """ + # create the selection + selection = Selection(self, where=where, start=start, stop=stop) + values = selection.select() + + results = [] + # convert the data + for a in self.axes: + a.set_info(self.info) + res = a.convert( + values, + nan_rep=self.nan_rep, + encoding=self.encoding, + errors=self.errors, + ) + results.append(res) + + return results + + @classmethod + def get_object(cls, obj, transposed: bool): + """return the data for this obj""" + return obj + + def validate_data_columns(self, data_columns, min_itemsize, non_index_axes): + """ + take the input data_columns and min_itemize and create a data + columns spec + """ + if not len(non_index_axes): + return [] + + axis, axis_labels = non_index_axes[0] + info = self.info.get(axis, {}) + if info.get("type") == "MultiIndex" and data_columns: + raise ValueError( + f"cannot use a multi-index on axis [{axis}] with " + f"data_columns {data_columns}" + ) + + # evaluate the passed data_columns, True == use all columns + # take only valid axis labels + if data_columns is True: + data_columns = list(axis_labels) + elif data_columns is None: + data_columns = [] + + # if min_itemsize is a dict, add the keys (exclude 'values') + if isinstance(min_itemsize, dict): + existing_data_columns = set(data_columns) + data_columns = list(data_columns) # ensure we do not modify + data_columns.extend( + [ + k + for k in min_itemsize.keys() + if k != "values" and k not in existing_data_columns + ] + ) + + # return valid columns in the order of our axis + return [c for c in data_columns if c in axis_labels] + + def _create_axes( + self, + axes, + obj: DataFrame, + validate: bool = True, + nan_rep=None, + data_columns=None, + min_itemsize=None, + ): + """ + Create and return the axes. + + Parameters + ---------- + axes: list or None + The names or numbers of the axes to create. + obj : DataFrame + The object to create axes on. + validate: bool, default True + Whether to validate the obj against an existing object already written. + nan_rep : + A value to use for string column nan_rep. + data_columns : List[str], True, or None, default None + Specify the columns that we want to create to allow indexing on. + + * True : Use all available columns. + * None : Use no columns. + * List[str] : Use the specified columns. + + min_itemsize: Dict[str, int] or None, default None + The min itemsize for a column in bytes. + """ + if not isinstance(obj, DataFrame): + group = self.group._v_name + raise TypeError( + f"cannot properly create the storer for: [group->{group}," + f"value->{type(obj)}]" + ) + + # set the default axes if needed + if axes is None: + axes = [0] + + # map axes to numbers + axes = [obj._get_axis_number(a) for a in axes] + + # do we have an existing table (if so, use its axes & data_columns) + if self.infer_axes(): + table_exists = True + axes = [a.axis for a in self.index_axes] + data_columns = list(self.data_columns) + nan_rep = self.nan_rep + # TODO: do we always have validate=True here? + else: + table_exists = False + + new_info = self.info + + assert self.ndim == 2 # with next check, we must have len(axes) == 1 + # currently support on ndim-1 axes + if len(axes) != self.ndim - 1: + raise ValueError( + "currently only support ndim-1 indexers in an AppendableTable" + ) + + # create according to the new data + new_non_index_axes: list = [] + + # nan_representation + if nan_rep is None: + nan_rep = "nan" + + # We construct the non-index-axis first, since that alters new_info + idx = next(x for x in [0, 1] if x not in axes) + + a = obj.axes[idx] + # we might be able to change the axes on the appending data if necessary + append_axis = list(a) + if table_exists: + indexer = len(new_non_index_axes) # i.e. 0 + exist_axis = self.non_index_axes[indexer][1] + if not array_equivalent( + np.array(append_axis), + np.array(exist_axis), + strict_nan=True, + dtype_equal=True, + ): + # ahah! -> reindex + if array_equivalent( + np.array(sorted(append_axis)), + np.array(sorted(exist_axis)), + strict_nan=True, + dtype_equal=True, + ): + append_axis = exist_axis + + # the non_index_axes info + info = new_info.setdefault(idx, {}) + info["names"] = list(a.names) + info["type"] = type(a).__name__ + + new_non_index_axes.append((idx, append_axis)) + + # Now we can construct our new index axis + idx = axes[0] + a = obj.axes[idx] + axis_name = obj._get_axis_name(idx) + new_index = _convert_index(axis_name, a, self.encoding, self.errors) + new_index.axis = idx + + # Because we are always 2D, there is only one new_index, so + # we know it will have pos=0 + new_index.set_pos(0) + new_index.update_info(new_info) + new_index.maybe_set_size(min_itemsize) # check for column conflicts + + new_index_axes = [new_index] + j = len(new_index_axes) # i.e. 1 + assert j == 1 + + # reindex by our non_index_axes & compute data_columns + assert len(new_non_index_axes) == 1 + for a in new_non_index_axes: + obj = _reindex_axis(obj, a[0], a[1]) + + transposed = new_index.axis == 1 + + # figure out data_columns and get out blocks + data_columns = self.validate_data_columns( + data_columns, min_itemsize, new_non_index_axes + ) + + frame = self.get_object(obj, transposed)._consolidate() + + blocks, blk_items = self._get_blocks_and_items( + frame, table_exists, new_non_index_axes, self.values_axes, data_columns + ) + + # add my values + vaxes = [] + for i, (blk, b_items) in enumerate(zip(blocks, blk_items)): + # shape of the data column are the indexable axes + klass = DataCol + name = None + + # we have a data_column + if data_columns and len(b_items) == 1 and b_items[0] in data_columns: + klass = DataIndexableCol + name = b_items[0] + if not (name is None or isinstance(name, str)): + # TODO: should the message here be more specifically non-str? + raise ValueError("cannot have non-object label DataIndexableCol") + + # make sure that we match up the existing columns + # if we have an existing table + existing_col: DataCol | None + + if table_exists and validate: + try: + existing_col = self.values_axes[i] + except (IndexError, KeyError) as err: + raise ValueError( + f"Incompatible appended table [{blocks}]" + f"with existing table [{self.values_axes}]" + ) from err + else: + existing_col = None + + new_name = name or f"values_block_{i}" + data_converted = _maybe_convert_for_string_atom( + new_name, + blk.values, + existing_col=existing_col, + min_itemsize=min_itemsize, + nan_rep=nan_rep, + encoding=self.encoding, + errors=self.errors, + columns=b_items, + ) + adj_name = _maybe_adjust_name(new_name, self.version) + + typ = klass._get_atom(data_converted) + kind = _dtype_to_kind(data_converted.dtype.name) + tz = None + if getattr(data_converted, "tz", None) is not None: + tz = _get_tz(data_converted.tz) + + meta = metadata = ordered = None + if isinstance(data_converted.dtype, CategoricalDtype): + ordered = data_converted.ordered + meta = "category" + metadata = np.asarray(data_converted.categories).ravel() + + data, dtype_name = _get_data_and_dtype_name(data_converted) + + col = klass( + name=adj_name, + cname=new_name, + values=list(b_items), + typ=typ, + pos=j, + kind=kind, + tz=tz, + ordered=ordered, + meta=meta, + metadata=metadata, + dtype=dtype_name, + data=data, + ) + col.update_info(new_info) + + vaxes.append(col) + + j += 1 + + dcs = [col.name for col in vaxes if col.is_data_indexable] + + new_table = type(self)( + parent=self.parent, + group=self.group, + encoding=self.encoding, + errors=self.errors, + index_axes=new_index_axes, + non_index_axes=new_non_index_axes, + values_axes=vaxes, + data_columns=dcs, + info=new_info, + nan_rep=nan_rep, + ) + if hasattr(self, "levels"): + # TODO: get this into constructor, only for appropriate subclass + new_table.levels = self.levels + + new_table.validate_min_itemsize(min_itemsize) + + if validate and table_exists: + new_table.validate(self) + + return new_table + + @staticmethod + def _get_blocks_and_items( + frame: DataFrame, + table_exists: bool, + new_non_index_axes, + values_axes, + data_columns, + ): + # Helper to clarify non-state-altering parts of _create_axes + + # TODO(ArrayManager) HDFStore relies on accessing the blocks + if isinstance(frame._mgr, ArrayManager): + frame = frame._as_manager("block") + + def get_blk_items(mgr): + return [mgr.items.take(blk.mgr_locs) for blk in mgr.blocks] + + mgr = frame._mgr + mgr = cast(BlockManager, mgr) + blocks: list[Block] = list(mgr.blocks) + blk_items: list[Index] = get_blk_items(mgr) + + if len(data_columns): + # TODO: prove that we only get here with axis == 1? + # It is the case in all extant tests, but NOT the case + # outside this `if len(data_columns)` check. + + axis, axis_labels = new_non_index_axes[0] + new_labels = Index(axis_labels).difference(Index(data_columns)) + mgr = frame.reindex(new_labels, axis=axis)._mgr + mgr = cast(BlockManager, mgr) + + blocks = list(mgr.blocks) + blk_items = get_blk_items(mgr) + for c in data_columns: + # This reindex would raise ValueError if we had a duplicate + # index, so we can infer that (as long as axis==1) we + # get a single column back, so a single block. + mgr = frame.reindex([c], axis=axis)._mgr + mgr = cast(BlockManager, mgr) + blocks.extend(mgr.blocks) + blk_items.extend(get_blk_items(mgr)) + + # reorder the blocks in the same order as the existing table if we can + if table_exists: + by_items = { + tuple(b_items.tolist()): (b, b_items) + for b, b_items in zip(blocks, blk_items) + } + new_blocks: list[Block] = [] + new_blk_items = [] + for ea in values_axes: + items = tuple(ea.values) + try: + b, b_items = by_items.pop(items) + new_blocks.append(b) + new_blk_items.append(b_items) + except (IndexError, KeyError) as err: + jitems = ",".join([pprint_thing(item) for item in items]) + raise ValueError( + f"cannot match existing table structure for [{jitems}] " + "on appending data" + ) from err + blocks = new_blocks + blk_items = new_blk_items + + return blocks, blk_items + + def process_axes(self, obj, selection: Selection, columns=None) -> DataFrame: + """process axes filters""" + # make a copy to avoid side effects + if columns is not None: + columns = list(columns) + + # make sure to include levels if we have them + if columns is not None and self.is_multi_index: + assert isinstance(self.levels, list) # assured by is_multi_index + for n in self.levels: + if n not in columns: + columns.insert(0, n) + + # reorder by any non_index_axes & limit to the select columns + for axis, labels in self.non_index_axes: + obj = _reindex_axis(obj, axis, labels, columns) + + def process_filter(field, filt, op): + for axis_name in obj._AXIS_ORDERS: + axis_number = obj._get_axis_number(axis_name) + axis_values = obj._get_axis(axis_name) + assert axis_number is not None + + # see if the field is the name of an axis + if field == axis_name: + # if we have a multi-index, then need to include + # the levels + if self.is_multi_index: + filt = filt.union(Index(self.levels)) + + takers = op(axis_values, filt) + return obj.loc(axis=axis_number)[takers] + + # this might be the name of a file IN an axis + elif field in axis_values: + # we need to filter on this dimension + values = ensure_index(getattr(obj, field).values) + filt = ensure_index(filt) + + # hack until we support reversed dim flags + if isinstance(obj, DataFrame): + axis_number = 1 - axis_number + + takers = op(values, filt) + return obj.loc(axis=axis_number)[takers] + + raise ValueError(f"cannot find the field [{field}] for filtering!") + + # apply the selection filters (but keep in the same order) + if selection.filter is not None: + for field, op, filt in selection.filter.format(): + obj = process_filter(field, filt, op) + + return obj + + def create_description( + self, + complib, + complevel: int | None, + fletcher32: bool, + expectedrows: int | None, + ) -> dict[str, Any]: + """create the description of the table from the axes & values""" + # provided expected rows if its passed + if expectedrows is None: + expectedrows = max(self.nrows_expected, 10000) + + d = {"name": "table", "expectedrows": expectedrows} + + # description from the axes & values + d["description"] = {a.cname: a.typ for a in self.axes} + + if complib: + if complevel is None: + complevel = self._complevel or 9 + filters = _tables().Filters( + complevel=complevel, + complib=complib, + fletcher32=fletcher32 or self._fletcher32, + ) + d["filters"] = filters + elif self._filters is not None: + d["filters"] = self._filters + + return d + + def read_coordinates( + self, where=None, start: int | None = None, stop: int | None = None + ): + """ + select coordinates (row numbers) from a table; return the + coordinates object + """ + # validate the version + self.validate_version(where) + + # infer the data kind + if not self.infer_axes(): + return False + + # create the selection + selection = Selection(self, where=where, start=start, stop=stop) + coords = selection.select_coords() + if selection.filter is not None: + for field, op, filt in selection.filter.format(): + data = self.read_column( + field, start=coords.min(), stop=coords.max() + 1 + ) + coords = coords[op(data.iloc[coords - coords.min()], filt).values] + + return Index(coords) + + def read_column( + self, + column: str, + where=None, + start: int | None = None, + stop: int | None = None, + ): + """ + return a single column from the table, generally only indexables + are interesting + """ + # validate the version + self.validate_version() + + # infer the data kind + if not self.infer_axes(): + return False + + if where is not None: + raise TypeError("read_column does not currently accept a where clause") + + # find the axes + for a in self.axes: + if column == a.name: + if not a.is_data_indexable: + raise ValueError( + f"column [{column}] can not be extracted individually; " + "it is not data indexable" + ) + + # column must be an indexable or a data column + c = getattr(self.table.cols, column) + a.set_info(self.info) + col_values = a.convert( + c[start:stop], + nan_rep=self.nan_rep, + encoding=self.encoding, + errors=self.errors, + ) + return Series(_set_tz(col_values[1], a.tz), name=column, copy=False) + + raise KeyError(f"column [{column}] not found in the table") + + +class WORMTable(Table): + """ + a write-once read-many table: this format DOES NOT ALLOW appending to a + table. writing is a one-time operation the data are stored in a format + that allows for searching the data on disk + """ + + table_type = "worm" + + def read( + self, + where=None, + columns=None, + start: int | None = None, + stop: int | None = None, + ): + """ + read the indices and the indexing array, calculate offset rows and return + """ + raise NotImplementedError("WORMTable needs to implement read") + + def write(self, obj, **kwargs) -> None: + """ + write in a format that we can search later on (but cannot append + to): write out the indices and the values using _write_array + (e.g. a CArray) create an indexing table so that we can search + """ + raise NotImplementedError("WORMTable needs to implement write") + + +class AppendableTable(Table): + """support the new appendable table formats""" + + table_type = "appendable" + + # error: Signature of "write" incompatible with supertype "Fixed" + def write( # type: ignore[override] + self, + obj, + axes=None, + append: bool = False, + complib=None, + complevel=None, + fletcher32=None, + min_itemsize=None, + chunksize: int | None = None, + expectedrows=None, + dropna: bool = False, + nan_rep=None, + data_columns=None, + track_times: bool = True, + ) -> None: + if not append and self.is_exists: + self._handle.remove_node(self.group, "table") + + # create the axes + table = self._create_axes( + axes=axes, + obj=obj, + validate=append, + min_itemsize=min_itemsize, + nan_rep=nan_rep, + data_columns=data_columns, + ) + + for a in table.axes: + a.validate_names() + + if not table.is_exists: + # create the table + options = table.create_description( + complib=complib, + complevel=complevel, + fletcher32=fletcher32, + expectedrows=expectedrows, + ) + + # set the table attributes + table.set_attrs() + + options["track_times"] = track_times + + # create the table + table._handle.create_table(table.group, **options) + + # update my info + table.attrs.info = table.info + + # validate the axes and set the kinds + for a in table.axes: + a.validate_and_set(table, append) + + # add the rows + table.write_data(chunksize, dropna=dropna) + + def write_data(self, chunksize: int | None, dropna: bool = False) -> None: + """ + we form the data into a 2-d including indexes,values,mask write chunk-by-chunk + """ + names = self.dtype.names + nrows = self.nrows_expected + + # if dropna==True, then drop ALL nan rows + masks = [] + if dropna: + for a in self.values_axes: + # figure the mask: only do if we can successfully process this + # column, otherwise ignore the mask + mask = isna(a.data).all(axis=0) + if isinstance(mask, np.ndarray): + masks.append(mask.astype("u1", copy=False)) + + # consolidate masks + if len(masks): + mask = masks[0] + for m in masks[1:]: + mask = mask & m + mask = mask.ravel() + else: + mask = None + + # broadcast the indexes if needed + indexes = [a.cvalues for a in self.index_axes] + nindexes = len(indexes) + assert nindexes == 1, nindexes # ensures we dont need to broadcast + + # transpose the values so first dimension is last + # reshape the values if needed + values = [a.take_data() for a in self.values_axes] + values = [v.transpose(np.roll(np.arange(v.ndim), v.ndim - 1)) for v in values] + bvalues = [] + for i, v in enumerate(values): + new_shape = (nrows,) + self.dtype[names[nindexes + i]].shape + bvalues.append(v.reshape(new_shape)) + + # write the chunks + if chunksize is None: + chunksize = 100000 + + rows = np.empty(min(chunksize, nrows), dtype=self.dtype) + chunks = nrows // chunksize + 1 + for i in range(chunks): + start_i = i * chunksize + end_i = min((i + 1) * chunksize, nrows) + if start_i >= end_i: + break + + self.write_data_chunk( + rows, + indexes=[a[start_i:end_i] for a in indexes], + mask=mask[start_i:end_i] if mask is not None else None, + values=[v[start_i:end_i] for v in bvalues], + ) + + def write_data_chunk( + self, + rows: np.ndarray, + indexes: list[np.ndarray], + mask: npt.NDArray[np.bool_] | None, + values: list[np.ndarray], + ) -> None: + """ + Parameters + ---------- + rows : an empty memory space where we are putting the chunk + indexes : an array of the indexes + mask : an array of the masks + values : an array of the values + """ + # 0 len + for v in values: + if not np.prod(v.shape): + return + + nrows = indexes[0].shape[0] + if nrows != len(rows): + rows = np.empty(nrows, dtype=self.dtype) + names = self.dtype.names + nindexes = len(indexes) + + # indexes + for i, idx in enumerate(indexes): + rows[names[i]] = idx + + # values + for i, v in enumerate(values): + rows[names[i + nindexes]] = v + + # mask + if mask is not None: + m = ~mask.ravel().astype(bool, copy=False) + if not m.all(): + rows = rows[m] + + if len(rows): + self.table.append(rows) + self.table.flush() + + def delete(self, where=None, start: int | None = None, stop: int | None = None): + # delete all rows (and return the nrows) + if where is None or not len(where): + if start is None and stop is None: + nrows = self.nrows + self._handle.remove_node(self.group, recursive=True) + else: + # pytables<3.0 would remove a single row with stop=None + if stop is None: + stop = self.nrows + nrows = self.table.remove_rows(start=start, stop=stop) + self.table.flush() + return nrows + + # infer the data kind + if not self.infer_axes(): + return None + + # create the selection + table = self.table + selection = Selection(self, where, start=start, stop=stop) + values = selection.select_coords() + + # delete the rows in reverse order + sorted_series = Series(values, copy=False).sort_values() + ln = len(sorted_series) + + if ln: + # construct groups of consecutive rows + diff = sorted_series.diff() + groups = list(diff[diff > 1].index) + + # 1 group + if not len(groups): + groups = [0] + + # final element + if groups[-1] != ln: + groups.append(ln) + + # initial element + if groups[0] != 0: + groups.insert(0, 0) + + # we must remove in reverse order! + pg = groups.pop() + for g in reversed(groups): + rows = sorted_series.take(range(g, pg)) + table.remove_rows( + start=rows[rows.index[0]], stop=rows[rows.index[-1]] + 1 + ) + pg = g + + self.table.flush() + + # return the number of rows removed + return ln + + +class AppendableFrameTable(AppendableTable): + """support the new appendable table formats""" + + pandas_kind = "frame_table" + table_type = "appendable_frame" + ndim = 2 + obj_type: type[DataFrame | Series] = DataFrame + + @property + def is_transposed(self) -> bool: + return self.index_axes[0].axis == 1 + + @classmethod + def get_object(cls, obj, transposed: bool): + """these are written transposed""" + if transposed: + obj = obj.T + return obj + + def read( + self, + where=None, + columns=None, + start: int | None = None, + stop: int | None = None, + ): + # validate the version + self.validate_version(where) + + # infer the data kind + if not self.infer_axes(): + return None + + result = self._read_axes(where=where, start=start, stop=stop) + + info = ( + self.info.get(self.non_index_axes[0][0], {}) + if len(self.non_index_axes) + else {} + ) + + inds = [i for i, ax in enumerate(self.axes) if ax is self.index_axes[0]] + assert len(inds) == 1 + ind = inds[0] + + index = result[ind][0] + + frames = [] + for i, a in enumerate(self.axes): + if a not in self.values_axes: + continue + index_vals, cvalues = result[i] + + # we could have a multi-index constructor here + # ensure_index doesn't recognized our list-of-tuples here + if info.get("type") != "MultiIndex": + cols = Index(index_vals) + else: + cols = MultiIndex.from_tuples(index_vals) + + names = info.get("names") + if names is not None: + cols.set_names(names, inplace=True) + + if self.is_transposed: + values = cvalues + index_ = cols + cols_ = Index(index, name=getattr(index, "name", None)) + else: + values = cvalues.T + index_ = Index(index, name=getattr(index, "name", None)) + cols_ = cols + + # if we have a DataIndexableCol, its shape will only be 1 dim + if values.ndim == 1 and isinstance(values, np.ndarray): + values = values.reshape((1, values.shape[0])) + + if isinstance(values, np.ndarray): + df = DataFrame(values.T, columns=cols_, index=index_, copy=False) + elif isinstance(values, Index): + df = DataFrame(values, columns=cols_, index=index_) + else: + # Categorical + df = DataFrame._from_arrays([values], columns=cols_, index=index_) + if not (using_pyarrow_string_dtype() and values.dtype.kind == "O"): + assert (df.dtypes == values.dtype).all(), (df.dtypes, values.dtype) + if using_pyarrow_string_dtype() and is_string_array( + values, # type: ignore[arg-type] + skipna=True, + ): + df = df.astype("string[pyarrow_numpy]") + frames.append(df) + + if len(frames) == 1: + df = frames[0] + else: + df = concat(frames, axis=1) + + selection = Selection(self, where=where, start=start, stop=stop) + # apply the selection filters & axis orderings + df = self.process_axes(df, selection=selection, columns=columns) + return df + + +class AppendableSeriesTable(AppendableFrameTable): + """support the new appendable table formats""" + + pandas_kind = "series_table" + table_type = "appendable_series" + ndim = 2 + obj_type = Series + + @property + def is_transposed(self) -> bool: + return False + + @classmethod + def get_object(cls, obj, transposed: bool): + return obj + + # error: Signature of "write" incompatible with supertype "Fixed" + def write(self, obj, data_columns=None, **kwargs) -> None: # type: ignore[override] + """we are going to write this as a frame table""" + if not isinstance(obj, DataFrame): + name = obj.name or "values" + obj = obj.to_frame(name) + super().write(obj=obj, data_columns=obj.columns.tolist(), **kwargs) + + def read( + self, + where=None, + columns=None, + start: int | None = None, + stop: int | None = None, + ) -> Series: + is_multi_index = self.is_multi_index + if columns is not None and is_multi_index: + assert isinstance(self.levels, list) # needed for mypy + for n in self.levels: + if n not in columns: + columns.insert(0, n) + s = super().read(where=where, columns=columns, start=start, stop=stop) + if is_multi_index: + s.set_index(self.levels, inplace=True) + + s = s.iloc[:, 0] + + # remove the default name + if s.name == "values": + s.name = None + return s + + +class AppendableMultiSeriesTable(AppendableSeriesTable): + """support the new appendable table formats""" + + pandas_kind = "series_table" + table_type = "appendable_multiseries" + + # error: Signature of "write" incompatible with supertype "Fixed" + def write(self, obj, **kwargs) -> None: # type: ignore[override] + """we are going to write this as a frame table""" + name = obj.name or "values" + newobj, self.levels = self.validate_multiindex(obj) + assert isinstance(self.levels, list) # for mypy + cols = list(self.levels) + cols.append(name) + newobj.columns = Index(cols) + super().write(obj=newobj, **kwargs) + + +class GenericTable(AppendableFrameTable): + """a table that read/writes the generic pytables table format""" + + pandas_kind = "frame_table" + table_type = "generic_table" + ndim = 2 + obj_type = DataFrame + levels: list[Hashable] + + @property + def pandas_type(self) -> str: + return self.pandas_kind + + @property + def storable(self): + return getattr(self.group, "table", None) or self.group + + def get_attrs(self) -> None: + """retrieve our attributes""" + self.non_index_axes = [] + self.nan_rep = None + self.levels = [] + + self.index_axes = [a for a in self.indexables if a.is_an_indexable] + self.values_axes = [a for a in self.indexables if not a.is_an_indexable] + self.data_columns = [a.name for a in self.values_axes] + + @cache_readonly + def indexables(self): + """create the indexables from the table description""" + d = self.description + + # TODO: can we get a typ for this? AFAICT it is the only place + # where we aren't passing one + # the index columns is just a simple index + md = self.read_metadata("index") + meta = "category" if md is not None else None + index_col = GenericIndexCol( + name="index", axis=0, table=self.table, meta=meta, metadata=md + ) + + _indexables: list[GenericIndexCol | GenericDataIndexableCol] = [index_col] + + for i, n in enumerate(d._v_names): + assert isinstance(n, str) + + atom = getattr(d, n) + md = self.read_metadata(n) + meta = "category" if md is not None else None + dc = GenericDataIndexableCol( + name=n, + pos=i, + values=[n], + typ=atom, + table=self.table, + meta=meta, + metadata=md, + ) + _indexables.append(dc) + + return _indexables + + # error: Signature of "write" incompatible with supertype "AppendableTable" + def write(self, **kwargs) -> None: # type: ignore[override] + raise NotImplementedError("cannot write on an generic table") + + +class AppendableMultiFrameTable(AppendableFrameTable): + """a frame with a multi-index""" + + table_type = "appendable_multiframe" + obj_type = DataFrame + ndim = 2 + _re_levels = re.compile(r"^level_\d+$") + + @property + def table_type_short(self) -> str: + return "appendable_multi" + + # error: Signature of "write" incompatible with supertype "Fixed" + def write(self, obj, data_columns=None, **kwargs) -> None: # type: ignore[override] + if data_columns is None: + data_columns = [] + elif data_columns is True: + data_columns = obj.columns.tolist() + obj, self.levels = self.validate_multiindex(obj) + assert isinstance(self.levels, list) # for mypy + for n in self.levels: + if n not in data_columns: + data_columns.insert(0, n) + super().write(obj=obj, data_columns=data_columns, **kwargs) + + def read( + self, + where=None, + columns=None, + start: int | None = None, + stop: int | None = None, + ): + df = super().read(where=where, columns=columns, start=start, stop=stop) + df = df.set_index(self.levels) + + # remove names for 'level_%d' + df.index = df.index.set_names( + [None if self._re_levels.search(name) else name for name in df.index.names] + ) + + return df + + +def _reindex_axis( + obj: DataFrame, axis: AxisInt, labels: Index, other=None +) -> DataFrame: + ax = obj._get_axis(axis) + labels = ensure_index(labels) + + # try not to reindex even if other is provided + # if it equals our current index + if other is not None: + other = ensure_index(other) + if (other is None or labels.equals(other)) and labels.equals(ax): + return obj + + labels = ensure_index(labels.unique()) + if other is not None: + labels = ensure_index(other.unique()).intersection(labels, sort=False) + if not labels.equals(ax): + slicer: list[slice | Index] = [slice(None, None)] * obj.ndim + slicer[axis] = labels + obj = obj.loc[tuple(slicer)] + return obj + + +# tz to/from coercion + + +def _get_tz(tz: tzinfo) -> str | tzinfo: + """for a tz-aware type, return an encoded zone""" + zone = timezones.get_timezone(tz) + return zone + + +@overload +def _set_tz( + values: np.ndarray | Index, tz: str | tzinfo, coerce: bool = False +) -> DatetimeIndex: + ... + + +@overload +def _set_tz(values: np.ndarray | Index, tz: None, coerce: bool = False) -> np.ndarray: + ... + + +def _set_tz( + values: np.ndarray | Index, tz: str | tzinfo | None, coerce: bool = False +) -> np.ndarray | DatetimeIndex: + """ + coerce the values to a DatetimeIndex if tz is set + preserve the input shape if possible + + Parameters + ---------- + values : ndarray or Index + tz : str or tzinfo + coerce : if we do not have a passed timezone, coerce to M8[ns] ndarray + """ + if isinstance(values, DatetimeIndex): + # If values is tzaware, the tz gets dropped in the values.ravel() + # call below (which returns an ndarray). So we are only non-lossy + # if `tz` matches `values.tz`. + assert values.tz is None or values.tz == tz + if values.tz is not None: + return values + + if tz is not None: + if isinstance(values, DatetimeIndex): + name = values.name + else: + name = None + values = values.ravel() + + tz = _ensure_decoded(tz) + values = DatetimeIndex(values, name=name) + values = values.tz_localize("UTC").tz_convert(tz) + elif coerce: + values = np.asarray(values, dtype="M8[ns]") + + # error: Incompatible return value type (got "Union[ndarray, Index]", + # expected "Union[ndarray, DatetimeIndex]") + return values # type: ignore[return-value] + + +def _convert_index(name: str, index: Index, encoding: str, errors: str) -> IndexCol: + assert isinstance(name, str) + + index_name = index.name + # error: Argument 1 to "_get_data_and_dtype_name" has incompatible type "Index"; + # expected "Union[ExtensionArray, ndarray]" + converted, dtype_name = _get_data_and_dtype_name(index) # type: ignore[arg-type] + kind = _dtype_to_kind(dtype_name) + atom = DataIndexableCol._get_atom(converted) + + if ( + lib.is_np_dtype(index.dtype, "iu") + or needs_i8_conversion(index.dtype) + or is_bool_dtype(index.dtype) + ): + # Includes Index, RangeIndex, DatetimeIndex, TimedeltaIndex, PeriodIndex, + # in which case "kind" is "integer", "integer", "datetime64", + # "timedelta64", and "integer", respectively. + return IndexCol( + name, + values=converted, + kind=kind, + typ=atom, + freq=getattr(index, "freq", None), + tz=getattr(index, "tz", None), + index_name=index_name, + ) + + if isinstance(index, MultiIndex): + raise TypeError("MultiIndex not supported here!") + + inferred_type = lib.infer_dtype(index, skipna=False) + # we won't get inferred_type of "datetime64" or "timedelta64" as these + # would go through the DatetimeIndex/TimedeltaIndex paths above + + values = np.asarray(index) + + if inferred_type == "date": + converted = np.asarray([v.toordinal() for v in values], dtype=np.int32) + return IndexCol( + name, converted, "date", _tables().Time32Col(), index_name=index_name + ) + elif inferred_type == "string": + converted = _convert_string_array(values, encoding, errors) + itemsize = converted.dtype.itemsize + return IndexCol( + name, + converted, + "string", + _tables().StringCol(itemsize), + index_name=index_name, + ) + + elif inferred_type in ["integer", "floating"]: + return IndexCol( + name, values=converted, kind=kind, typ=atom, index_name=index_name + ) + else: + assert isinstance(converted, np.ndarray) and converted.dtype == object + assert kind == "object", kind + atom = _tables().ObjectAtom() + return IndexCol(name, converted, kind, atom, index_name=index_name) + + +def _unconvert_index(data, kind: str, encoding: str, errors: str) -> np.ndarray | Index: + index: Index | np.ndarray + + if kind.startswith("datetime64"): + if kind == "datetime64": + # created before we stored resolution information + index = DatetimeIndex(data) + else: + index = DatetimeIndex(data.view(kind)) + elif kind == "timedelta64": + index = TimedeltaIndex(data) + elif kind == "date": + try: + index = np.asarray([date.fromordinal(v) for v in data], dtype=object) + except ValueError: + index = np.asarray([date.fromtimestamp(v) for v in data], dtype=object) + elif kind in ("integer", "float", "bool"): + index = np.asarray(data) + elif kind in ("string"): + index = _unconvert_string_array( + data, nan_rep=None, encoding=encoding, errors=errors + ) + elif kind == "object": + index = np.asarray(data[0]) + else: # pragma: no cover + raise ValueError(f"unrecognized index type {kind}") + return index + + +def _maybe_convert_for_string_atom( + name: str, + bvalues: ArrayLike, + existing_col, + min_itemsize, + nan_rep, + encoding, + errors, + columns: list[str], +): + if bvalues.dtype != object: + return bvalues + + bvalues = cast(np.ndarray, bvalues) + + dtype_name = bvalues.dtype.name + inferred_type = lib.infer_dtype(bvalues, skipna=False) + + if inferred_type == "date": + raise TypeError("[date] is not implemented as a table column") + if inferred_type == "datetime": + # after GH#8260 + # this only would be hit for a multi-timezone dtype which is an error + raise TypeError( + "too many timezones in this block, create separate data columns" + ) + + if not (inferred_type == "string" or dtype_name == "object"): + return bvalues + + mask = isna(bvalues) + data = bvalues.copy() + data[mask] = nan_rep + + # see if we have a valid string type + inferred_type = lib.infer_dtype(data, skipna=False) + if inferred_type != "string": + # we cannot serialize this data, so report an exception on a column + # by column basis + + # expected behaviour: + # search block for a non-string object column by column + for i in range(data.shape[0]): + col = data[i] + inferred_type = lib.infer_dtype(col, skipna=False) + if inferred_type != "string": + error_column_label = columns[i] if len(columns) > i else f"No.{i}" + raise TypeError( + f"Cannot serialize the column [{error_column_label}]\n" + f"because its data contents are not [string] but " + f"[{inferred_type}] object dtype" + ) + + # itemsize is the maximum length of a string (along any dimension) + + data_converted = _convert_string_array(data, encoding, errors).reshape(data.shape) + itemsize = data_converted.itemsize + + # specified min_itemsize? + if isinstance(min_itemsize, dict): + min_itemsize = int(min_itemsize.get(name) or min_itemsize.get("values") or 0) + itemsize = max(min_itemsize or 0, itemsize) + + # check for column in the values conflicts + if existing_col is not None: + eci = existing_col.validate_col(itemsize) + if eci is not None and eci > itemsize: + itemsize = eci + + data_converted = data_converted.astype(f"|S{itemsize}", copy=False) + return data_converted + + +def _convert_string_array(data: np.ndarray, encoding: str, errors: str) -> np.ndarray: + """ + Take a string-like that is object dtype and coerce to a fixed size string type. + + Parameters + ---------- + data : np.ndarray[object] + encoding : str + errors : str + Handler for encoding errors. + + Returns + ------- + np.ndarray[fixed-length-string] + """ + # encode if needed + if len(data): + data = ( + Series(data.ravel(), copy=False) + .str.encode(encoding, errors) + ._values.reshape(data.shape) + ) + + # create the sized dtype + ensured = ensure_object(data.ravel()) + itemsize = max(1, libwriters.max_len_string_array(ensured)) + + data = np.asarray(data, dtype=f"S{itemsize}") + return data + + +def _unconvert_string_array( + data: np.ndarray, nan_rep, encoding: str, errors: str +) -> np.ndarray: + """ + Inverse of _convert_string_array. + + Parameters + ---------- + data : np.ndarray[fixed-length-string] + nan_rep : the storage repr of NaN + encoding : str + errors : str + Handler for encoding errors. + + Returns + ------- + np.ndarray[object] + Decoded data. + """ + shape = data.shape + data = np.asarray(data.ravel(), dtype=object) + + if len(data): + itemsize = libwriters.max_len_string_array(ensure_object(data)) + dtype = f"U{itemsize}" + + if isinstance(data[0], bytes): + data = Series(data, copy=False).str.decode(encoding, errors=errors)._values + else: + data = data.astype(dtype, copy=False).astype(object, copy=False) + + if nan_rep is None: + nan_rep = "nan" + + libwriters.string_array_replace_from_nan_rep(data, nan_rep) + return data.reshape(shape) + + +def _maybe_convert(values: np.ndarray, val_kind: str, encoding: str, errors: str): + assert isinstance(val_kind, str), type(val_kind) + if _need_convert(val_kind): + conv = _get_converter(val_kind, encoding, errors) + values = conv(values) + return values + + +def _get_converter(kind: str, encoding: str, errors: str): + if kind == "datetime64": + return lambda x: np.asarray(x, dtype="M8[ns]") + elif "datetime64" in kind: + return lambda x: np.asarray(x, dtype=kind) + elif kind == "string": + return lambda x: _unconvert_string_array( + x, nan_rep=None, encoding=encoding, errors=errors + ) + else: # pragma: no cover + raise ValueError(f"invalid kind {kind}") + + +def _need_convert(kind: str) -> bool: + if kind in ("datetime64", "string") or "datetime64" in kind: + return True + return False + + +def _maybe_adjust_name(name: str, version: Sequence[int]) -> str: + """ + Prior to 0.10.1, we named values blocks like: values_block_0 an the + name values_0, adjust the given name if necessary. + + Parameters + ---------- + name : str + version : Tuple[int, int, int] + + Returns + ------- + str + """ + if isinstance(version, str) or len(version) < 3: + raise ValueError("Version is incorrect, expected sequence of 3 integers.") + + if version[0] == 0 and version[1] <= 10 and version[2] == 0: + m = re.search(r"values_block_(\d+)", name) + if m: + grp = m.groups()[0] + name = f"values_{grp}" + return name + + +def _dtype_to_kind(dtype_str: str) -> str: + """ + Find the "kind" string describing the given dtype name. + """ + dtype_str = _ensure_decoded(dtype_str) + + if dtype_str.startswith(("string", "bytes")): + kind = "string" + elif dtype_str.startswith("float"): + kind = "float" + elif dtype_str.startswith("complex"): + kind = "complex" + elif dtype_str.startswith(("int", "uint")): + kind = "integer" + elif dtype_str.startswith("datetime64"): + kind = dtype_str + elif dtype_str.startswith("timedelta"): + kind = "timedelta64" + elif dtype_str.startswith("bool"): + kind = "bool" + elif dtype_str.startswith("category"): + kind = "category" + elif dtype_str.startswith("period"): + # We store the `freq` attr so we can restore from integers + kind = "integer" + elif dtype_str == "object": + kind = "object" + else: + raise ValueError(f"cannot interpret dtype of [{dtype_str}]") + + return kind + + +def _get_data_and_dtype_name(data: ArrayLike): + """ + Convert the passed data into a storable form and a dtype string. + """ + if isinstance(data, Categorical): + data = data.codes + + if isinstance(data.dtype, DatetimeTZDtype): + # For datetime64tz we need to drop the TZ in tests TODO: why? + dtype_name = f"datetime64[{data.dtype.unit}]" + else: + dtype_name = data.dtype.name + + if data.dtype.kind in "mM": + data = np.asarray(data.view("i8")) + # TODO: we used to reshape for the dt64tz case, but no longer + # doing that doesn't seem to break anything. why? + + elif isinstance(data, PeriodIndex): + data = data.asi8 + + data = np.asarray(data) + return data, dtype_name + + +class Selection: + """ + Carries out a selection operation on a tables.Table object. + + Parameters + ---------- + table : a Table object + where : list of Terms (or convertible to) + start, stop: indices to start and/or stop selection + + """ + + def __init__( + self, + table: Table, + where=None, + start: int | None = None, + stop: int | None = None, + ) -> None: + self.table = table + self.where = where + self.start = start + self.stop = stop + self.condition = None + self.filter = None + self.terms = None + self.coordinates = None + + if is_list_like(where): + # see if we have a passed coordinate like + with suppress(ValueError): + inferred = lib.infer_dtype(where, skipna=False) + if inferred in ("integer", "boolean"): + where = np.asarray(where) + if where.dtype == np.bool_: + start, stop = self.start, self.stop + if start is None: + start = 0 + if stop is None: + stop = self.table.nrows + self.coordinates = np.arange(start, stop)[where] + elif issubclass(where.dtype.type, np.integer): + if (self.start is not None and (where < self.start).any()) or ( + self.stop is not None and (where >= self.stop).any() + ): + raise ValueError( + "where must have index locations >= start and < stop" + ) + self.coordinates = where + + if self.coordinates is None: + self.terms = self.generate(where) + + # create the numexpr & the filter + if self.terms is not None: + self.condition, self.filter = self.terms.evaluate() + + def generate(self, where): + """where can be a : dict,list,tuple,string""" + if where is None: + return None + + q = self.table.queryables() + try: + return PyTablesExpr(where, queryables=q, encoding=self.table.encoding) + except NameError as err: + # raise a nice message, suggesting that the user should use + # data_columns + qkeys = ",".join(q.keys()) + msg = dedent( + f"""\ + The passed where expression: {where} + contains an invalid variable reference + all of the variable references must be a reference to + an axis (e.g. 'index' or 'columns'), or a data_column + The currently defined references are: {qkeys} + """ + ) + raise ValueError(msg) from err + + def select(self): + """ + generate the selection + """ + if self.condition is not None: + return self.table.table.read_where( + self.condition.format(), start=self.start, stop=self.stop + ) + elif self.coordinates is not None: + return self.table.table.read_coordinates(self.coordinates) + return self.table.table.read(start=self.start, stop=self.stop) + + def select_coords(self): + """ + generate the selection + """ + start, stop = self.start, self.stop + nrows = self.table.nrows + if start is None: + start = 0 + elif start < 0: + start += nrows + if stop is None: + stop = nrows + elif stop < 0: + stop += nrows + + if self.condition is not None: + return self.table.table.get_where_list( + self.condition.format(), start=start, stop=stop, sort=True + ) + elif self.coordinates is not None: + return self.coordinates + + return np.arange(start, stop) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/io/sql.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/io/sql.py new file mode 100644 index 0000000000000000000000000000000000000000..3e17175167f25a4bfc7eb559070927f56dc84eae --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/io/sql.py @@ -0,0 +1,2926 @@ +""" +Collection of query wrappers / abstractions to both facilitate data +retrieval and to reduce dependency on DB-specific API. +""" + +from __future__ import annotations + +from abc import ( + ABC, + abstractmethod, +) +from contextlib import ( + ExitStack, + contextmanager, +) +from datetime import ( + date, + datetime, + time, +) +from functools import partial +import re +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Literal, + cast, + overload, +) +import warnings + +import numpy as np + +from pandas._config import using_pyarrow_string_dtype + +from pandas._libs import lib +from pandas.compat._optional import import_optional_dependency +from pandas.errors import ( + AbstractMethodError, + DatabaseError, +) +from pandas.util._exceptions import find_stack_level +from pandas.util._validators import check_dtype_backend + +from pandas.core.dtypes.common import ( + is_dict_like, + is_list_like, +) +from pandas.core.dtypes.dtypes import ( + ArrowDtype, + DatetimeTZDtype, +) +from pandas.core.dtypes.missing import isna + +from pandas import get_option +from pandas.core.api import ( + DataFrame, + Series, +) +from pandas.core.arrays import ArrowExtensionArray +from pandas.core.base import PandasObject +import pandas.core.common as com +from pandas.core.common import maybe_make_list +from pandas.core.internals.construction import convert_object_array +from pandas.core.tools.datetimes import to_datetime + +if TYPE_CHECKING: + from collections.abc import ( + Iterator, + Mapping, + ) + + from sqlalchemy import Table + from sqlalchemy.sql.expression import ( + Select, + TextClause, + ) + + from pandas._typing import ( + DateTimeErrorChoices, + DtypeArg, + DtypeBackend, + IndexLabel, + Self, + ) + + from pandas import Index + +# ----------------------------------------------------------------------------- +# -- Helper functions + + +def _process_parse_dates_argument(parse_dates): + """Process parse_dates argument for read_sql functions""" + # handle non-list entries for parse_dates gracefully + if parse_dates is True or parse_dates is None or parse_dates is False: + parse_dates = [] + + elif not hasattr(parse_dates, "__iter__"): + parse_dates = [parse_dates] + return parse_dates + + +def _handle_date_column( + col, utc: bool = False, format: str | dict[str, Any] | None = None +): + if isinstance(format, dict): + # GH35185 Allow custom error values in parse_dates argument of + # read_sql like functions. + # Format can take on custom to_datetime argument values such as + # {"errors": "coerce"} or {"dayfirst": True} + error: DateTimeErrorChoices = format.pop("errors", None) or "ignore" + if error == "ignore": + try: + return to_datetime(col, **format) + except (TypeError, ValueError): + # TODO: not reached 2023-10-27; needed? + return col + return to_datetime(col, errors=error, **format) + else: + # Allow passing of formatting string for integers + # GH17855 + if format is None and ( + issubclass(col.dtype.type, np.floating) + or issubclass(col.dtype.type, np.integer) + ): + format = "s" + if format in ["D", "d", "h", "m", "s", "ms", "us", "ns"]: + return to_datetime(col, errors="coerce", unit=format, utc=utc) + elif isinstance(col.dtype, DatetimeTZDtype): + # coerce to UTC timezone + # GH11216 + return to_datetime(col, utc=True) + else: + return to_datetime(col, errors="coerce", format=format, utc=utc) + + +def _parse_date_columns(data_frame, parse_dates): + """ + Force non-datetime columns to be read as such. + Supports both string formatted and integer timestamp columns. + """ + parse_dates = _process_parse_dates_argument(parse_dates) + + # we want to coerce datetime64_tz dtypes for now to UTC + # we could in theory do a 'nice' conversion from a FixedOffset tz + # GH11216 + for i, (col_name, df_col) in enumerate(data_frame.items()): + if isinstance(df_col.dtype, DatetimeTZDtype) or col_name in parse_dates: + try: + fmt = parse_dates[col_name] + except (KeyError, TypeError): + fmt = None + data_frame.isetitem(i, _handle_date_column(df_col, format=fmt)) + + return data_frame + + +def _convert_arrays_to_dataframe( + data, + columns, + coerce_float: bool = True, + dtype_backend: DtypeBackend | Literal["numpy"] = "numpy", +) -> DataFrame: + content = lib.to_object_array_tuples(data) + arrays = convert_object_array( + list(content.T), + dtype=None, + coerce_float=coerce_float, + dtype_backend=dtype_backend, + ) + if dtype_backend == "pyarrow": + pa = import_optional_dependency("pyarrow") + + result_arrays = [] + for arr in arrays: + pa_array = pa.array(arr, from_pandas=True) + if arr.dtype == "string": + # TODO: Arrow still infers strings arrays as regular strings instead + # of large_string, which is what we preserver everywhere else for + # dtype_backend="pyarrow". We may want to reconsider this + pa_array = pa_array.cast(pa.string()) + result_arrays.append(ArrowExtensionArray(pa_array)) + arrays = result_arrays # type: ignore[assignment] + if arrays: + df = DataFrame(dict(zip(list(range(len(columns))), arrays))) + df.columns = columns + return df + else: + return DataFrame(columns=columns) + + +def _wrap_result( + data, + columns, + index_col=None, + coerce_float: bool = True, + parse_dates=None, + dtype: DtypeArg | None = None, + dtype_backend: DtypeBackend | Literal["numpy"] = "numpy", +): + """Wrap result set of a SQLAlchemy query in a DataFrame.""" + frame = _convert_arrays_to_dataframe(data, columns, coerce_float, dtype_backend) + + if dtype: + frame = frame.astype(dtype) + + frame = _parse_date_columns(frame, parse_dates) + + if index_col is not None: + frame = frame.set_index(index_col) + + return frame + + +def _wrap_result_adbc( + df: DataFrame, + *, + index_col=None, + parse_dates=None, + dtype: DtypeArg | None = None, + dtype_backend: DtypeBackend | Literal["numpy"] = "numpy", +) -> DataFrame: + """Wrap result set of a SQLAlchemy query in a DataFrame.""" + if dtype: + df = df.astype(dtype) + + df = _parse_date_columns(df, parse_dates) + + if index_col is not None: + df = df.set_index(index_col) + + return df + + +def execute(sql, con, params=None): + """ + Execute the given SQL query using the provided connection object. + + Parameters + ---------- + sql : string + SQL query to be executed. + con : SQLAlchemy connection or sqlite3 connection + If a DBAPI2 object, only sqlite3 is supported. + params : list or tuple, optional, default: None + List of parameters to pass to execute method. + + Returns + ------- + Results Iterable + """ + warnings.warn( + "`pandas.io.sql.execute` is deprecated and " + "will be removed in the future version.", + FutureWarning, + stacklevel=find_stack_level(), + ) # GH50185 + sqlalchemy = import_optional_dependency("sqlalchemy", errors="ignore") + + if sqlalchemy is not None and isinstance(con, (str, sqlalchemy.engine.Engine)): + raise TypeError("pandas.io.sql.execute requires a connection") # GH50185 + with pandasSQL_builder(con, need_transaction=True) as pandas_sql: + return pandas_sql.execute(sql, params) + + +# ----------------------------------------------------------------------------- +# -- Read and write to DataFrames + + +@overload +def read_sql_table( + table_name: str, + con, + schema=..., + index_col: str | list[str] | None = ..., + coerce_float=..., + parse_dates: list[str] | dict[str, str] | None = ..., + columns: list[str] | None = ..., + chunksize: None = ..., + dtype_backend: DtypeBackend | lib.NoDefault = ..., +) -> DataFrame: + ... + + +@overload +def read_sql_table( + table_name: str, + con, + schema=..., + index_col: str | list[str] | None = ..., + coerce_float=..., + parse_dates: list[str] | dict[str, str] | None = ..., + columns: list[str] | None = ..., + chunksize: int = ..., + dtype_backend: DtypeBackend | lib.NoDefault = ..., +) -> Iterator[DataFrame]: + ... + + +def read_sql_table( + table_name: str, + con, + schema: str | None = None, + index_col: str | list[str] | None = None, + coerce_float: bool = True, + parse_dates: list[str] | dict[str, str] | None = None, + columns: list[str] | None = None, + chunksize: int | None = None, + dtype_backend: DtypeBackend | lib.NoDefault = lib.no_default, +) -> DataFrame | Iterator[DataFrame]: + """ + Read SQL database table into a DataFrame. + + Given a table name and a SQLAlchemy connectable, returns a DataFrame. + This function does not support DBAPI connections. + + Parameters + ---------- + table_name : str + Name of SQL table in database. + con : SQLAlchemy connectable or str + A database URI could be provided as str. + SQLite DBAPI connection mode not supported. + schema : str, default None + Name of SQL schema in database to query (if database flavor + supports this). Uses default schema if None (default). + index_col : str or list of str, optional, default: None + Column(s) to set as index(MultiIndex). + coerce_float : bool, default True + Attempts to convert values of non-string, non-numeric objects (like + decimal.Decimal) to floating point. Can result in loss of Precision. + parse_dates : list or dict, default None + - List of column names to parse as dates. + - Dict of ``{column_name: format string}`` where format string is + strftime compatible in case of parsing string times or is one of + (D, s, ns, ms, us) in case of parsing integer timestamps. + - Dict of ``{column_name: arg dict}``, where the arg dict corresponds + to the keyword arguments of :func:`pandas.to_datetime` + Especially useful with databases without native Datetime support, + such as SQLite. + columns : list, default None + List of column names to select from SQL table. + chunksize : int, default None + If specified, returns an iterator where `chunksize` is the number of + rows to include in each chunk. + dtype_backend : {'numpy_nullable', 'pyarrow'}, default 'numpy_nullable' + Back-end data type applied to the resultant :class:`DataFrame` + (still experimental). Behaviour is as follows: + + * ``"numpy_nullable"``: returns nullable-dtype-backed :class:`DataFrame` + (default). + * ``"pyarrow"``: returns pyarrow-backed nullable :class:`ArrowDtype` + DataFrame. + + .. versionadded:: 2.0 + + Returns + ------- + DataFrame or Iterator[DataFrame] + A SQL table is returned as two-dimensional data structure with labeled + axes. + + See Also + -------- + read_sql_query : Read SQL query into a DataFrame. + read_sql : Read SQL query or database table into a DataFrame. + + Notes + ----- + Any datetime values with time zone information will be converted to UTC. + + Examples + -------- + >>> pd.read_sql_table('table_name', 'postgres:///db_name') # doctest:+SKIP + """ + + check_dtype_backend(dtype_backend) + if dtype_backend is lib.no_default: + dtype_backend = "numpy" # type: ignore[assignment] + assert dtype_backend is not lib.no_default + + with pandasSQL_builder(con, schema=schema, need_transaction=True) as pandas_sql: + if not pandas_sql.has_table(table_name): + raise ValueError(f"Table {table_name} not found") + + table = pandas_sql.read_table( + table_name, + index_col=index_col, + coerce_float=coerce_float, + parse_dates=parse_dates, + columns=columns, + chunksize=chunksize, + dtype_backend=dtype_backend, + ) + + if table is not None: + return table + else: + raise ValueError(f"Table {table_name} not found", con) + + +@overload +def read_sql_query( + sql, + con, + index_col: str | list[str] | None = ..., + coerce_float=..., + params: list[Any] | Mapping[str, Any] | None = ..., + parse_dates: list[str] | dict[str, str] | None = ..., + chunksize: None = ..., + dtype: DtypeArg | None = ..., + dtype_backend: DtypeBackend | lib.NoDefault = ..., +) -> DataFrame: + ... + + +@overload +def read_sql_query( + sql, + con, + index_col: str | list[str] | None = ..., + coerce_float=..., + params: list[Any] | Mapping[str, Any] | None = ..., + parse_dates: list[str] | dict[str, str] | None = ..., + chunksize: int = ..., + dtype: DtypeArg | None = ..., + dtype_backend: DtypeBackend | lib.NoDefault = ..., +) -> Iterator[DataFrame]: + ... + + +def read_sql_query( + sql, + con, + index_col: str | list[str] | None = None, + coerce_float: bool = True, + params: list[Any] | Mapping[str, Any] | None = None, + parse_dates: list[str] | dict[str, str] | None = None, + chunksize: int | None = None, + dtype: DtypeArg | None = None, + dtype_backend: DtypeBackend | lib.NoDefault = lib.no_default, +) -> DataFrame | Iterator[DataFrame]: + """ + Read SQL query into a DataFrame. + + Returns a DataFrame corresponding to the result set of the query + string. Optionally provide an `index_col` parameter to use one of the + columns as the index, otherwise default integer index will be used. + + Parameters + ---------- + sql : str SQL query or SQLAlchemy Selectable (select or text object) + SQL query to be executed. + con : SQLAlchemy connectable, str, or sqlite3 connection + Using SQLAlchemy makes it possible to use any DB supported by that + library. If a DBAPI2 object, only sqlite3 is supported. + index_col : str or list of str, optional, default: None + Column(s) to set as index(MultiIndex). + coerce_float : bool, default True + Attempts to convert values of non-string, non-numeric objects (like + decimal.Decimal) to floating point. Useful for SQL result sets. + params : list, tuple or mapping, optional, default: None + List of parameters to pass to execute method. The syntax used + to pass parameters is database driver dependent. Check your + database driver documentation for which of the five syntax styles, + described in PEP 249's paramstyle, is supported. + Eg. for psycopg2, uses %(name)s so use params={'name' : 'value'}. + parse_dates : list or dict, default: None + - List of column names to parse as dates. + - Dict of ``{column_name: format string}`` where format string is + strftime compatible in case of parsing string times, or is one of + (D, s, ns, ms, us) in case of parsing integer timestamps. + - Dict of ``{column_name: arg dict}``, where the arg dict corresponds + to the keyword arguments of :func:`pandas.to_datetime` + Especially useful with databases without native Datetime support, + such as SQLite. + chunksize : int, default None + If specified, return an iterator where `chunksize` is the number of + rows to include in each chunk. + dtype : Type name or dict of columns + Data type for data or columns. E.g. np.float64 or + {'a': np.float64, 'b': np.int32, 'c': 'Int64'}. + + .. versionadded:: 1.3.0 + dtype_backend : {'numpy_nullable', 'pyarrow'}, default 'numpy_nullable' + Back-end data type applied to the resultant :class:`DataFrame` + (still experimental). Behaviour is as follows: + + * ``"numpy_nullable"``: returns nullable-dtype-backed :class:`DataFrame` + (default). + * ``"pyarrow"``: returns pyarrow-backed nullable :class:`ArrowDtype` + DataFrame. + + .. versionadded:: 2.0 + + Returns + ------- + DataFrame or Iterator[DataFrame] + + See Also + -------- + read_sql_table : Read SQL database table into a DataFrame. + read_sql : Read SQL query or database table into a DataFrame. + + Notes + ----- + Any datetime values with time zone information parsed via the `parse_dates` + parameter will be converted to UTC. + + Examples + -------- + >>> from sqlalchemy import create_engine # doctest: +SKIP + >>> engine = create_engine("sqlite:///database.db") # doctest: +SKIP + >>> with engine.connect() as conn, conn.begin(): # doctest: +SKIP + ... data = pd.read_sql_table("data", conn) # doctest: +SKIP + """ + + check_dtype_backend(dtype_backend) + if dtype_backend is lib.no_default: + dtype_backend = "numpy" # type: ignore[assignment] + assert dtype_backend is not lib.no_default + + with pandasSQL_builder(con) as pandas_sql: + return pandas_sql.read_query( + sql, + index_col=index_col, + params=params, + coerce_float=coerce_float, + parse_dates=parse_dates, + chunksize=chunksize, + dtype=dtype, + dtype_backend=dtype_backend, + ) + + +@overload +def read_sql( + sql, + con, + index_col: str | list[str] | None = ..., + coerce_float=..., + params=..., + parse_dates=..., + columns: list[str] = ..., + chunksize: None = ..., + dtype_backend: DtypeBackend | lib.NoDefault = ..., + dtype: DtypeArg | None = None, +) -> DataFrame: + ... + + +@overload +def read_sql( + sql, + con, + index_col: str | list[str] | None = ..., + coerce_float=..., + params=..., + parse_dates=..., + columns: list[str] = ..., + chunksize: int = ..., + dtype_backend: DtypeBackend | lib.NoDefault = ..., + dtype: DtypeArg | None = None, +) -> Iterator[DataFrame]: + ... + + +def read_sql( + sql, + con, + index_col: str | list[str] | None = None, + coerce_float: bool = True, + params=None, + parse_dates=None, + columns: list[str] | None = None, + chunksize: int | None = None, + dtype_backend: DtypeBackend | lib.NoDefault = lib.no_default, + dtype: DtypeArg | None = None, +) -> DataFrame | Iterator[DataFrame]: + """ + Read SQL query or database table into a DataFrame. + + This function is a convenience wrapper around ``read_sql_table`` and + ``read_sql_query`` (for backward compatibility). It will delegate + to the specific function depending on the provided input. A SQL query + will be routed to ``read_sql_query``, while a database table name will + be routed to ``read_sql_table``. Note that the delegated function might + have more specific notes about their functionality not listed here. + + Parameters + ---------- + sql : str or SQLAlchemy Selectable (select or text object) + SQL query to be executed or a table name. + con : ADBC Connection, SQLAlchemy connectable, str, or sqlite3 connection + ADBC provides high performance I/O with native type support, where available. + Using SQLAlchemy makes it possible to use any DB supported by that + library. If a DBAPI2 object, only sqlite3 is supported. The user is responsible + for engine disposal and connection closure for the ADBC connection and + SQLAlchemy connectable; str connections are closed automatically. See + `here `_. + index_col : str or list of str, optional, default: None + Column(s) to set as index(MultiIndex). + coerce_float : bool, default True + Attempts to convert values of non-string, non-numeric objects (like + decimal.Decimal) to floating point, useful for SQL result sets. + params : list, tuple or dict, optional, default: None + List of parameters to pass to execute method. The syntax used + to pass parameters is database driver dependent. Check your + database driver documentation for which of the five syntax styles, + described in PEP 249's paramstyle, is supported. + Eg. for psycopg2, uses %(name)s so use params={'name' : 'value'}. + parse_dates : list or dict, default: None + - List of column names to parse as dates. + - Dict of ``{column_name: format string}`` where format string is + strftime compatible in case of parsing string times, or is one of + (D, s, ns, ms, us) in case of parsing integer timestamps. + - Dict of ``{column_name: arg dict}``, where the arg dict corresponds + to the keyword arguments of :func:`pandas.to_datetime` + Especially useful with databases without native Datetime support, + such as SQLite. + columns : list, default: None + List of column names to select from SQL table (only used when reading + a table). + chunksize : int, default None + If specified, return an iterator where `chunksize` is the + number of rows to include in each chunk. + dtype_backend : {'numpy_nullable', 'pyarrow'}, default 'numpy_nullable' + Back-end data type applied to the resultant :class:`DataFrame` + (still experimental). Behaviour is as follows: + + * ``"numpy_nullable"``: returns nullable-dtype-backed :class:`DataFrame` + (default). + * ``"pyarrow"``: returns pyarrow-backed nullable :class:`ArrowDtype` + DataFrame. + + .. versionadded:: 2.0 + dtype : Type name or dict of columns + Data type for data or columns. E.g. np.float64 or + {'a': np.float64, 'b': np.int32, 'c': 'Int64'}. + The argument is ignored if a table is passed instead of a query. + + .. versionadded:: 2.0.0 + + Returns + ------- + DataFrame or Iterator[DataFrame] + + See Also + -------- + read_sql_table : Read SQL database table into a DataFrame. + read_sql_query : Read SQL query into a DataFrame. + + Examples + -------- + Read data from SQL via either a SQL query or a SQL tablename. + When using a SQLite database only SQL queries are accepted, + providing only the SQL tablename will result in an error. + + >>> from sqlite3 import connect + >>> conn = connect(':memory:') + >>> df = pd.DataFrame(data=[[0, '10/11/12'], [1, '12/11/10']], + ... columns=['int_column', 'date_column']) + >>> df.to_sql(name='test_data', con=conn) + 2 + + >>> pd.read_sql('SELECT int_column, date_column FROM test_data', conn) + int_column date_column + 0 0 10/11/12 + 1 1 12/11/10 + + >>> pd.read_sql('test_data', 'postgres:///db_name') # doctest:+SKIP + + Apply date parsing to columns through the ``parse_dates`` argument + The ``parse_dates`` argument calls ``pd.to_datetime`` on the provided columns. + Custom argument values for applying ``pd.to_datetime`` on a column are specified + via a dictionary format: + + >>> pd.read_sql('SELECT int_column, date_column FROM test_data', + ... conn, + ... parse_dates={"date_column": {"format": "%d/%m/%y"}}) + int_column date_column + 0 0 2012-11-10 + 1 1 2010-11-12 + + .. versionadded:: 2.2.0 + + pandas now supports reading via ADBC drivers + + >>> from adbc_driver_postgresql import dbapi # doctest:+SKIP + >>> with dbapi.connect('postgres:///db_name') as conn: # doctest:+SKIP + ... pd.read_sql('SELECT int_column FROM test_data', conn) + int_column + 0 0 + 1 1 + """ + + check_dtype_backend(dtype_backend) + if dtype_backend is lib.no_default: + dtype_backend = "numpy" # type: ignore[assignment] + assert dtype_backend is not lib.no_default + + with pandasSQL_builder(con) as pandas_sql: + if isinstance(pandas_sql, SQLiteDatabase): + return pandas_sql.read_query( + sql, + index_col=index_col, + params=params, + coerce_float=coerce_float, + parse_dates=parse_dates, + chunksize=chunksize, + dtype_backend=dtype_backend, + dtype=dtype, + ) + + try: + _is_table_name = pandas_sql.has_table(sql) + except Exception: + # using generic exception to catch errors from sql drivers (GH24988) + _is_table_name = False + + if _is_table_name: + return pandas_sql.read_table( + sql, + index_col=index_col, + coerce_float=coerce_float, + parse_dates=parse_dates, + columns=columns, + chunksize=chunksize, + dtype_backend=dtype_backend, + ) + else: + return pandas_sql.read_query( + sql, + index_col=index_col, + params=params, + coerce_float=coerce_float, + parse_dates=parse_dates, + chunksize=chunksize, + dtype_backend=dtype_backend, + dtype=dtype, + ) + + +def to_sql( + frame, + name: str, + con, + schema: str | None = None, + if_exists: Literal["fail", "replace", "append"] = "fail", + index: bool = True, + index_label: IndexLabel | None = None, + chunksize: int | None = None, + dtype: DtypeArg | None = None, + method: Literal["multi"] | Callable | None = None, + engine: str = "auto", + **engine_kwargs, +) -> int | None: + """ + Write records stored in a DataFrame to a SQL database. + + Parameters + ---------- + frame : DataFrame, Series + name : str + Name of SQL table. + con : ADBC Connection, SQLAlchemy connectable, str, or sqlite3 connection + or sqlite3 DBAPI2 connection + ADBC provides high performance I/O with native type support, where available. + Using SQLAlchemy makes it possible to use any DB supported by that + library. + If a DBAPI2 object, only sqlite3 is supported. + schema : str, optional + Name of SQL schema in database to write to (if database flavor + supports this). If None, use default schema (default). + if_exists : {'fail', 'replace', 'append'}, default 'fail' + - fail: If table exists, do nothing. + - replace: If table exists, drop it, recreate it, and insert data. + - append: If table exists, insert data. Create if does not exist. + index : bool, default True + Write DataFrame index as a column. + index_label : str or sequence, optional + Column label for index column(s). If None is given (default) and + `index` is True, then the index names are used. + A sequence should be given if the DataFrame uses MultiIndex. + chunksize : int, optional + Specify the number of rows in each batch to be written at a time. + By default, all rows will be written at once. + dtype : dict or scalar, optional + Specifying the datatype for columns. If a dictionary is used, the + keys should be the column names and the values should be the + SQLAlchemy types or strings for the sqlite3 fallback mode. If a + scalar is provided, it will be applied to all columns. + method : {None, 'multi', callable}, optional + Controls the SQL insertion clause used: + + - None : Uses standard SQL ``INSERT`` clause (one per row). + - ``'multi'``: Pass multiple values in a single ``INSERT`` clause. + - callable with signature ``(pd_table, conn, keys, data_iter) -> int | None``. + + Details and a sample callable implementation can be found in the + section :ref:`insert method `. + engine : {'auto', 'sqlalchemy'}, default 'auto' + SQL engine library to use. If 'auto', then the option + ``io.sql.engine`` is used. The default ``io.sql.engine`` + behavior is 'sqlalchemy' + + .. versionadded:: 1.3.0 + + **engine_kwargs + Any additional kwargs are passed to the engine. + + Returns + ------- + None or int + Number of rows affected by to_sql. None is returned if the callable + passed into ``method`` does not return an integer number of rows. + + .. versionadded:: 1.4.0 + + Notes + ----- + The returned rows affected is the sum of the ``rowcount`` attribute of ``sqlite3.Cursor`` + or SQLAlchemy connectable. If using ADBC the returned rows are the result + of ``Cursor.adbc_ingest``. The returned value may not reflect the exact number of written + rows as stipulated in the + `sqlite3 `__ or + `SQLAlchemy `__ + """ # noqa: E501 + if if_exists not in ("fail", "replace", "append"): + raise ValueError(f"'{if_exists}' is not valid for if_exists") + + if isinstance(frame, Series): + frame = frame.to_frame() + elif not isinstance(frame, DataFrame): + raise NotImplementedError( + "'frame' argument should be either a Series or a DataFrame" + ) + + with pandasSQL_builder(con, schema=schema, need_transaction=True) as pandas_sql: + return pandas_sql.to_sql( + frame, + name, + if_exists=if_exists, + index=index, + index_label=index_label, + schema=schema, + chunksize=chunksize, + dtype=dtype, + method=method, + engine=engine, + **engine_kwargs, + ) + + +def has_table(table_name: str, con, schema: str | None = None) -> bool: + """ + Check if DataBase has named table. + + Parameters + ---------- + table_name: string + Name of SQL table. + con: ADBC Connection, SQLAlchemy connectable, str, or sqlite3 connection + ADBC provides high performance I/O with native type support, where available. + Using SQLAlchemy makes it possible to use any DB supported by that + library. + If a DBAPI2 object, only sqlite3 is supported. + schema : string, default None + Name of SQL schema in database to write to (if database flavor supports + this). If None, use default schema (default). + + Returns + ------- + boolean + """ + with pandasSQL_builder(con, schema=schema) as pandas_sql: + return pandas_sql.has_table(table_name) + + +table_exists = has_table + + +def pandasSQL_builder( + con, + schema: str | None = None, + need_transaction: bool = False, +) -> PandasSQL: + """ + Convenience function to return the correct PandasSQL subclass based on the + provided parameters. Also creates a sqlalchemy connection and transaction + if necessary. + """ + import sqlite3 + + if isinstance(con, sqlite3.Connection) or con is None: + return SQLiteDatabase(con) + + sqlalchemy = import_optional_dependency("sqlalchemy", errors="ignore") + + if isinstance(con, str) and sqlalchemy is None: + raise ImportError("Using URI string without sqlalchemy installed.") + + if sqlalchemy is not None and isinstance(con, (str, sqlalchemy.engine.Connectable)): + return SQLDatabase(con, schema, need_transaction) + + adbc = import_optional_dependency("adbc_driver_manager.dbapi", errors="ignore") + if adbc and isinstance(con, adbc.Connection): + return ADBCDatabase(con) + + warnings.warn( + "pandas only supports SQLAlchemy connectable (engine/connection) or " + "database string URI or sqlite3 DBAPI2 connection. Other DBAPI2 " + "objects are not tested. Please consider using SQLAlchemy.", + UserWarning, + stacklevel=find_stack_level(), + ) + return SQLiteDatabase(con) + + +class SQLTable(PandasObject): + """ + For mapping Pandas tables to SQL tables. + Uses fact that table is reflected by SQLAlchemy to + do better type conversions. + Also holds various flags needed to avoid having to + pass them between functions all the time. + """ + + # TODO: support for multiIndex + + def __init__( + self, + name: str, + pandas_sql_engine, + frame=None, + index: bool | str | list[str] | None = True, + if_exists: Literal["fail", "replace", "append"] = "fail", + prefix: str = "pandas", + index_label=None, + schema=None, + keys=None, + dtype: DtypeArg | None = None, + ) -> None: + self.name = name + self.pd_sql = pandas_sql_engine + self.prefix = prefix + self.frame = frame + self.index = self._index_name(index, index_label) + self.schema = schema + self.if_exists = if_exists + self.keys = keys + self.dtype = dtype + + if frame is not None: + # We want to initialize based on a dataframe + self.table = self._create_table_setup() + else: + # no data provided, read-only mode + self.table = self.pd_sql.get_table(self.name, self.schema) + + if self.table is None: + raise ValueError(f"Could not init table '{name}'") + + if not len(self.name): + raise ValueError("Empty table name specified") + + def exists(self): + return self.pd_sql.has_table(self.name, self.schema) + + def sql_schema(self) -> str: + from sqlalchemy.schema import CreateTable + + return str(CreateTable(self.table).compile(self.pd_sql.con)) + + def _execute_create(self) -> None: + # Inserting table into database, add to MetaData object + self.table = self.table.to_metadata(self.pd_sql.meta) + with self.pd_sql.run_transaction(): + self.table.create(bind=self.pd_sql.con) + + def create(self) -> None: + if self.exists(): + if self.if_exists == "fail": + raise ValueError(f"Table '{self.name}' already exists.") + if self.if_exists == "replace": + self.pd_sql.drop_table(self.name, self.schema) + self._execute_create() + elif self.if_exists == "append": + pass + else: + raise ValueError(f"'{self.if_exists}' is not valid for if_exists") + else: + self._execute_create() + + def _execute_insert(self, conn, keys: list[str], data_iter) -> int: + """ + Execute SQL statement inserting data + + Parameters + ---------- + conn : sqlalchemy.engine.Engine or sqlalchemy.engine.Connection + keys : list of str + Column names + data_iter : generator of list + Each item contains a list of values to be inserted + """ + data = [dict(zip(keys, row)) for row in data_iter] + result = conn.execute(self.table.insert(), data) + return result.rowcount + + def _execute_insert_multi(self, conn, keys: list[str], data_iter) -> int: + """ + Alternative to _execute_insert for DBs support multi-value INSERT. + + Note: multi-value insert is usually faster for analytics DBs + and tables containing a few columns + but performance degrades quickly with increase of columns. + + """ + + from sqlalchemy import insert + + data = [dict(zip(keys, row)) for row in data_iter] + stmt = insert(self.table).values(data) + result = conn.execute(stmt) + return result.rowcount + + def insert_data(self) -> tuple[list[str], list[np.ndarray]]: + if self.index is not None: + temp = self.frame.copy() + temp.index.names = self.index + try: + temp.reset_index(inplace=True) + except ValueError as err: + raise ValueError(f"duplicate name in index/columns: {err}") from err + else: + temp = self.frame + + column_names = list(map(str, temp.columns)) + ncols = len(column_names) + # this just pre-allocates the list: None's will be replaced with ndarrays + # error: List item 0 has incompatible type "None"; expected "ndarray" + data_list: list[np.ndarray] = [None] * ncols # type: ignore[list-item] + + for i, (_, ser) in enumerate(temp.items()): + if ser.dtype.kind == "M": + if isinstance(ser._values, ArrowExtensionArray): + import pyarrow as pa + + if pa.types.is_date(ser.dtype.pyarrow_dtype): + # GH#53854 to_pydatetime not supported for pyarrow date dtypes + d = ser._values.to_numpy(dtype=object) + else: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=FutureWarning) + # GH#52459 to_pydatetime will return Index[object] + d = np.asarray(ser.dt.to_pydatetime(), dtype=object) + else: + d = ser._values.to_pydatetime() + elif ser.dtype.kind == "m": + vals = ser._values + if isinstance(vals, ArrowExtensionArray): + vals = vals.to_numpy(dtype=np.dtype("m8[ns]")) + # store as integers, see GH#6921, GH#7076 + d = vals.view("i8").astype(object) + else: + d = ser._values.astype(object) + + assert isinstance(d, np.ndarray), type(d) + + if ser._can_hold_na: + # Note: this will miss timedeltas since they are converted to int + mask = isna(d) + d[mask] = None + + data_list[i] = d + + return column_names, data_list + + def insert( + self, + chunksize: int | None = None, + method: Literal["multi"] | Callable | None = None, + ) -> int | None: + # set insert method + if method is None: + exec_insert = self._execute_insert + elif method == "multi": + exec_insert = self._execute_insert_multi + elif callable(method): + exec_insert = partial(method, self) + else: + raise ValueError(f"Invalid parameter `method`: {method}") + + keys, data_list = self.insert_data() + + nrows = len(self.frame) + + if nrows == 0: + return 0 + + if chunksize is None: + chunksize = nrows + elif chunksize == 0: + raise ValueError("chunksize argument should be non-zero") + + chunks = (nrows // chunksize) + 1 + total_inserted = None + with self.pd_sql.run_transaction() as conn: + for i in range(chunks): + start_i = i * chunksize + end_i = min((i + 1) * chunksize, nrows) + if start_i >= end_i: + break + + chunk_iter = zip(*(arr[start_i:end_i] for arr in data_list)) + num_inserted = exec_insert(conn, keys, chunk_iter) + # GH 46891 + if num_inserted is not None: + if total_inserted is None: + total_inserted = num_inserted + else: + total_inserted += num_inserted + return total_inserted + + def _query_iterator( + self, + result, + exit_stack: ExitStack, + chunksize: int | None, + columns, + coerce_float: bool = True, + parse_dates=None, + dtype_backend: DtypeBackend | Literal["numpy"] = "numpy", + ): + """Return generator through chunked result set.""" + has_read_data = False + with exit_stack: + while True: + data = result.fetchmany(chunksize) + if not data: + if not has_read_data: + yield DataFrame.from_records( + [], columns=columns, coerce_float=coerce_float + ) + break + + has_read_data = True + self.frame = _convert_arrays_to_dataframe( + data, columns, coerce_float, dtype_backend + ) + + self._harmonize_columns( + parse_dates=parse_dates, dtype_backend=dtype_backend + ) + + if self.index is not None: + self.frame.set_index(self.index, inplace=True) + + yield self.frame + + def read( + self, + exit_stack: ExitStack, + coerce_float: bool = True, + parse_dates=None, + columns=None, + chunksize: int | None = None, + dtype_backend: DtypeBackend | Literal["numpy"] = "numpy", + ) -> DataFrame | Iterator[DataFrame]: + from sqlalchemy import select + + if columns is not None and len(columns) > 0: + cols = [self.table.c[n] for n in columns] + if self.index is not None: + for idx in self.index[::-1]: + cols.insert(0, self.table.c[idx]) + sql_select = select(*cols) + else: + sql_select = select(self.table) + result = self.pd_sql.execute(sql_select) + column_names = result.keys() + + if chunksize is not None: + return self._query_iterator( + result, + exit_stack, + chunksize, + column_names, + coerce_float=coerce_float, + parse_dates=parse_dates, + dtype_backend=dtype_backend, + ) + else: + data = result.fetchall() + self.frame = _convert_arrays_to_dataframe( + data, column_names, coerce_float, dtype_backend + ) + + self._harmonize_columns( + parse_dates=parse_dates, dtype_backend=dtype_backend + ) + + if self.index is not None: + self.frame.set_index(self.index, inplace=True) + + return self.frame + + def _index_name(self, index, index_label): + # for writing: index=True to include index in sql table + if index is True: + nlevels = self.frame.index.nlevels + # if index_label is specified, set this as index name(s) + if index_label is not None: + if not isinstance(index_label, list): + index_label = [index_label] + if len(index_label) != nlevels: + raise ValueError( + "Length of 'index_label' should match number of " + f"levels, which is {nlevels}" + ) + return index_label + # return the used column labels for the index columns + if ( + nlevels == 1 + and "index" not in self.frame.columns + and self.frame.index.name is None + ): + return ["index"] + else: + return com.fill_missing_names(self.frame.index.names) + + # for reading: index=(list of) string to specify column to set as index + elif isinstance(index, str): + return [index] + elif isinstance(index, list): + return index + else: + return None + + def _get_column_names_and_types(self, dtype_mapper): + column_names_and_types = [] + if self.index is not None: + for i, idx_label in enumerate(self.index): + idx_type = dtype_mapper(self.frame.index._get_level_values(i)) + column_names_and_types.append((str(idx_label), idx_type, True)) + + column_names_and_types += [ + (str(self.frame.columns[i]), dtype_mapper(self.frame.iloc[:, i]), False) + for i in range(len(self.frame.columns)) + ] + + return column_names_and_types + + def _create_table_setup(self): + from sqlalchemy import ( + Column, + PrimaryKeyConstraint, + Table, + ) + from sqlalchemy.schema import MetaData + + column_names_and_types = self._get_column_names_and_types(self._sqlalchemy_type) + + columns: list[Any] = [ + Column(name, typ, index=is_index) + for name, typ, is_index in column_names_and_types + ] + + if self.keys is not None: + if not is_list_like(self.keys): + keys = [self.keys] + else: + keys = self.keys + pkc = PrimaryKeyConstraint(*keys, name=self.name + "_pk") + columns.append(pkc) + + schema = self.schema or self.pd_sql.meta.schema + + # At this point, attach to new metadata, only attach to self.meta + # once table is created. + meta = MetaData() + return Table(self.name, meta, *columns, schema=schema) + + def _harmonize_columns( + self, + parse_dates=None, + dtype_backend: DtypeBackend | Literal["numpy"] = "numpy", + ) -> None: + """ + Make the DataFrame's column types align with the SQL table + column types. + Need to work around limited NA value support. Floats are always + fine, ints must always be floats if there are Null values. + Booleans are hard because converting bool column with None replaces + all Nones with false. Therefore only convert bool if there are no + NA values. + Datetimes should already be converted to np.datetime64 if supported, + but here we also force conversion if required. + """ + parse_dates = _process_parse_dates_argument(parse_dates) + + for sql_col in self.table.columns: + col_name = sql_col.name + try: + df_col = self.frame[col_name] + + # Handle date parsing upfront; don't try to convert columns + # twice + if col_name in parse_dates: + try: + fmt = parse_dates[col_name] + except TypeError: + fmt = None + self.frame[col_name] = _handle_date_column(df_col, format=fmt) + continue + + # the type the dataframe column should have + col_type = self._get_dtype(sql_col.type) + + if ( + col_type is datetime + or col_type is date + or col_type is DatetimeTZDtype + ): + # Convert tz-aware Datetime SQL columns to UTC + utc = col_type is DatetimeTZDtype + self.frame[col_name] = _handle_date_column(df_col, utc=utc) + elif dtype_backend == "numpy" and col_type is float: + # floats support NA, can always convert! + self.frame[col_name] = df_col.astype(col_type, copy=False) + + elif dtype_backend == "numpy" and len(df_col) == df_col.count(): + # No NA values, can convert ints and bools + if col_type is np.dtype("int64") or col_type is bool: + self.frame[col_name] = df_col.astype(col_type, copy=False) + except KeyError: + pass # this column not in results + + def _sqlalchemy_type(self, col: Index | Series): + dtype: DtypeArg = self.dtype or {} + if is_dict_like(dtype): + dtype = cast(dict, dtype) + if col.name in dtype: + return dtype[col.name] + + # Infer type of column, while ignoring missing values. + # Needed for inserting typed data containing NULLs, GH 8778. + col_type = lib.infer_dtype(col, skipna=True) + + from sqlalchemy.types import ( + TIMESTAMP, + BigInteger, + Boolean, + Date, + DateTime, + Float, + Integer, + SmallInteger, + Text, + Time, + ) + + if col_type in ("datetime64", "datetime"): + # GH 9086: TIMESTAMP is the suggested type if the column contains + # timezone information + try: + # error: Item "Index" of "Union[Index, Series]" has no attribute "dt" + if col.dt.tz is not None: # type: ignore[union-attr] + return TIMESTAMP(timezone=True) + except AttributeError: + # The column is actually a DatetimeIndex + # GH 26761 or an Index with date-like data e.g. 9999-01-01 + if getattr(col, "tz", None) is not None: + return TIMESTAMP(timezone=True) + return DateTime + if col_type == "timedelta64": + warnings.warn( + "the 'timedelta' type is not supported, and will be " + "written as integer values (ns frequency) to the database.", + UserWarning, + stacklevel=find_stack_level(), + ) + return BigInteger + elif col_type == "floating": + if col.dtype == "float32": + return Float(precision=23) + else: + return Float(precision=53) + elif col_type == "integer": + # GH35076 Map pandas integer to optimal SQLAlchemy integer type + if col.dtype.name.lower() in ("int8", "uint8", "int16"): + return SmallInteger + elif col.dtype.name.lower() in ("uint16", "int32"): + return Integer + elif col.dtype.name.lower() == "uint64": + raise ValueError("Unsigned 64 bit integer datatype is not supported") + else: + return BigInteger + elif col_type == "boolean": + return Boolean + elif col_type == "date": + return Date + elif col_type == "time": + return Time + elif col_type == "complex": + raise ValueError("Complex datatypes not supported") + + return Text + + def _get_dtype(self, sqltype): + from sqlalchemy.types import ( + TIMESTAMP, + Boolean, + Date, + DateTime, + Float, + Integer, + ) + + if isinstance(sqltype, Float): + return float + elif isinstance(sqltype, Integer): + # TODO: Refine integer size. + return np.dtype("int64") + elif isinstance(sqltype, TIMESTAMP): + # we have a timezone capable type + if not sqltype.timezone: + return datetime + return DatetimeTZDtype + elif isinstance(sqltype, DateTime): + # Caution: np.datetime64 is also a subclass of np.number. + return datetime + elif isinstance(sqltype, Date): + return date + elif isinstance(sqltype, Boolean): + return bool + return object + + +class PandasSQL(PandasObject, ABC): + """ + Subclasses Should define read_query and to_sql. + """ + + def __enter__(self) -> Self: + return self + + def __exit__(self, *args) -> None: + pass + + def read_table( + self, + table_name: str, + index_col: str | list[str] | None = None, + coerce_float: bool = True, + parse_dates=None, + columns=None, + schema: str | None = None, + chunksize: int | None = None, + dtype_backend: DtypeBackend | Literal["numpy"] = "numpy", + ) -> DataFrame | Iterator[DataFrame]: + raise NotImplementedError + + @abstractmethod + def read_query( + self, + sql: str, + index_col: str | list[str] | None = None, + coerce_float: bool = True, + parse_dates=None, + params=None, + chunksize: int | None = None, + dtype: DtypeArg | None = None, + dtype_backend: DtypeBackend | Literal["numpy"] = "numpy", + ) -> DataFrame | Iterator[DataFrame]: + pass + + @abstractmethod + def to_sql( + self, + frame, + name: str, + if_exists: Literal["fail", "replace", "append"] = "fail", + index: bool = True, + index_label=None, + schema=None, + chunksize: int | None = None, + dtype: DtypeArg | None = None, + method: Literal["multi"] | Callable | None = None, + engine: str = "auto", + **engine_kwargs, + ) -> int | None: + pass + + @abstractmethod + def execute(self, sql: str | Select | TextClause, params=None): + pass + + @abstractmethod + def has_table(self, name: str, schema: str | None = None) -> bool: + pass + + @abstractmethod + def _create_sql_schema( + self, + frame: DataFrame, + table_name: str, + keys: list[str] | None = None, + dtype: DtypeArg | None = None, + schema: str | None = None, + ) -> str: + pass + + +class BaseEngine: + def insert_records( + self, + table: SQLTable, + con, + frame, + name: str, + index: bool | str | list[str] | None = True, + schema=None, + chunksize: int | None = None, + method=None, + **engine_kwargs, + ) -> int | None: + """ + Inserts data into already-prepared table + """ + raise AbstractMethodError(self) + + +class SQLAlchemyEngine(BaseEngine): + def __init__(self) -> None: + import_optional_dependency( + "sqlalchemy", extra="sqlalchemy is required for SQL support." + ) + + def insert_records( + self, + table: SQLTable, + con, + frame, + name: str, + index: bool | str | list[str] | None = True, + schema=None, + chunksize: int | None = None, + method=None, + **engine_kwargs, + ) -> int | None: + from sqlalchemy import exc + + try: + return table.insert(chunksize=chunksize, method=method) + except exc.StatementError as err: + # GH34431 + # https://stackoverflow.com/a/67358288/6067848 + msg = r"""(\(1054, "Unknown column 'inf(e0)?' in 'field list'"\))(?# + )|inf can not be used with MySQL""" + err_text = str(err.orig) + if re.search(msg, err_text): + raise ValueError("inf cannot be used with MySQL") from err + raise err + + +def get_engine(engine: str) -> BaseEngine: + """return our implementation""" + if engine == "auto": + engine = get_option("io.sql.engine") + + if engine == "auto": + # try engines in this order + engine_classes = [SQLAlchemyEngine] + + error_msgs = "" + for engine_class in engine_classes: + try: + return engine_class() + except ImportError as err: + error_msgs += "\n - " + str(err) + + raise ImportError( + "Unable to find a usable engine; " + "tried using: 'sqlalchemy'.\n" + "A suitable version of " + "sqlalchemy is required for sql I/O " + "support.\n" + "Trying to import the above resulted in these errors:" + f"{error_msgs}" + ) + + if engine == "sqlalchemy": + return SQLAlchemyEngine() + + raise ValueError("engine must be one of 'auto', 'sqlalchemy'") + + +class SQLDatabase(PandasSQL): + """ + This class enables conversion between DataFrame and SQL databases + using SQLAlchemy to handle DataBase abstraction. + + Parameters + ---------- + con : SQLAlchemy Connectable or URI string. + Connectable to connect with the database. Using SQLAlchemy makes it + possible to use any DB supported by that library. + schema : string, default None + Name of SQL schema in database to write to (if database flavor + supports this). If None, use default schema (default). + need_transaction : bool, default False + If True, SQLDatabase will create a transaction. + + """ + + def __init__( + self, con, schema: str | None = None, need_transaction: bool = False + ) -> None: + from sqlalchemy import create_engine + from sqlalchemy.engine import Engine + from sqlalchemy.schema import MetaData + + # self.exit_stack cleans up the Engine and Connection and commits the + # transaction if any of those objects was created below. + # Cleanup happens either in self.__exit__ or at the end of the iterator + # returned by read_sql when chunksize is not None. + self.exit_stack = ExitStack() + if isinstance(con, str): + con = create_engine(con) + self.exit_stack.callback(con.dispose) + if isinstance(con, Engine): + con = self.exit_stack.enter_context(con.connect()) + if need_transaction and not con.in_transaction(): + self.exit_stack.enter_context(con.begin()) + self.con = con + self.meta = MetaData(schema=schema) + self.returns_generator = False + + def __exit__(self, *args) -> None: + if not self.returns_generator: + self.exit_stack.close() + + @contextmanager + def run_transaction(self): + if not self.con.in_transaction(): + with self.con.begin(): + yield self.con + else: + yield self.con + + def execute(self, sql: str | Select | TextClause, params=None): + """Simple passthrough to SQLAlchemy connectable""" + args = [] if params is None else [params] + if isinstance(sql, str): + return self.con.exec_driver_sql(sql, *args) + return self.con.execute(sql, *args) + + def read_table( + self, + table_name: str, + index_col: str | list[str] | None = None, + coerce_float: bool = True, + parse_dates=None, + columns=None, + schema: str | None = None, + chunksize: int | None = None, + dtype_backend: DtypeBackend | Literal["numpy"] = "numpy", + ) -> DataFrame | Iterator[DataFrame]: + """ + Read SQL database table into a DataFrame. + + Parameters + ---------- + table_name : str + Name of SQL table in database. + index_col : string, optional, default: None + Column to set as index. + coerce_float : bool, default True + Attempts to convert values of non-string, non-numeric objects + (like decimal.Decimal) to floating point. This can result in + loss of precision. + parse_dates : list or dict, default: None + - List of column names to parse as dates. + - Dict of ``{column_name: format string}`` where format string is + strftime compatible in case of parsing string times, or is one of + (D, s, ns, ms, us) in case of parsing integer timestamps. + - Dict of ``{column_name: arg}``, where the arg corresponds + to the keyword arguments of :func:`pandas.to_datetime`. + Especially useful with databases without native Datetime support, + such as SQLite. + columns : list, default: None + List of column names to select from SQL table. + schema : string, default None + Name of SQL schema in database to query (if database flavor + supports this). If specified, this overwrites the default + schema of the SQL database object. + chunksize : int, default None + If specified, return an iterator where `chunksize` is the number + of rows to include in each chunk. + dtype_backend : {'numpy_nullable', 'pyarrow'}, default 'numpy_nullable' + Back-end data type applied to the resultant :class:`DataFrame` + (still experimental). Behaviour is as follows: + + * ``"numpy_nullable"``: returns nullable-dtype-backed :class:`DataFrame` + (default). + * ``"pyarrow"``: returns pyarrow-backed nullable :class:`ArrowDtype` + DataFrame. + + .. versionadded:: 2.0 + + Returns + ------- + DataFrame + + See Also + -------- + pandas.read_sql_table + SQLDatabase.read_query + + """ + self.meta.reflect(bind=self.con, only=[table_name], views=True) + table = SQLTable(table_name, self, index=index_col, schema=schema) + if chunksize is not None: + self.returns_generator = True + return table.read( + self.exit_stack, + coerce_float=coerce_float, + parse_dates=parse_dates, + columns=columns, + chunksize=chunksize, + dtype_backend=dtype_backend, + ) + + @staticmethod + def _query_iterator( + result, + exit_stack: ExitStack, + chunksize: int, + columns, + index_col=None, + coerce_float: bool = True, + parse_dates=None, + dtype: DtypeArg | None = None, + dtype_backend: DtypeBackend | Literal["numpy"] = "numpy", + ): + """Return generator through chunked result set""" + has_read_data = False + with exit_stack: + while True: + data = result.fetchmany(chunksize) + if not data: + if not has_read_data: + yield _wrap_result( + [], + columns, + index_col=index_col, + coerce_float=coerce_float, + parse_dates=parse_dates, + dtype=dtype, + dtype_backend=dtype_backend, + ) + break + + has_read_data = True + yield _wrap_result( + data, + columns, + index_col=index_col, + coerce_float=coerce_float, + parse_dates=parse_dates, + dtype=dtype, + dtype_backend=dtype_backend, + ) + + def read_query( + self, + sql: str, + index_col: str | list[str] | None = None, + coerce_float: bool = True, + parse_dates=None, + params=None, + chunksize: int | None = None, + dtype: DtypeArg | None = None, + dtype_backend: DtypeBackend | Literal["numpy"] = "numpy", + ) -> DataFrame | Iterator[DataFrame]: + """ + Read SQL query into a DataFrame. + + Parameters + ---------- + sql : str + SQL query to be executed. + index_col : string, optional, default: None + Column name to use as index for the returned DataFrame object. + coerce_float : bool, default True + Attempt to convert values of non-string, non-numeric objects (like + decimal.Decimal) to floating point, useful for SQL result sets. + params : list, tuple or dict, optional, default: None + List of parameters to pass to execute method. The syntax used + to pass parameters is database driver dependent. Check your + database driver documentation for which of the five syntax styles, + described in PEP 249's paramstyle, is supported. + Eg. for psycopg2, uses %(name)s so use params={'name' : 'value'} + parse_dates : list or dict, default: None + - List of column names to parse as dates. + - Dict of ``{column_name: format string}`` where format string is + strftime compatible in case of parsing string times, or is one of + (D, s, ns, ms, us) in case of parsing integer timestamps. + - Dict of ``{column_name: arg dict}``, where the arg dict + corresponds to the keyword arguments of + :func:`pandas.to_datetime` Especially useful with databases + without native Datetime support, such as SQLite. + chunksize : int, default None + If specified, return an iterator where `chunksize` is the number + of rows to include in each chunk. + dtype : Type name or dict of columns + Data type for data or columns. E.g. np.float64 or + {'a': np.float64, 'b': np.int32, 'c': 'Int64'} + + .. versionadded:: 1.3.0 + + Returns + ------- + DataFrame + + See Also + -------- + read_sql_table : Read SQL database table into a DataFrame. + read_sql + + """ + result = self.execute(sql, params) + columns = result.keys() + + if chunksize is not None: + self.returns_generator = True + return self._query_iterator( + result, + self.exit_stack, + chunksize, + columns, + index_col=index_col, + coerce_float=coerce_float, + parse_dates=parse_dates, + dtype=dtype, + dtype_backend=dtype_backend, + ) + else: + data = result.fetchall() + frame = _wrap_result( + data, + columns, + index_col=index_col, + coerce_float=coerce_float, + parse_dates=parse_dates, + dtype=dtype, + dtype_backend=dtype_backend, + ) + return frame + + read_sql = read_query + + def prep_table( + self, + frame, + name: str, + if_exists: Literal["fail", "replace", "append"] = "fail", + index: bool | str | list[str] | None = True, + index_label=None, + schema=None, + dtype: DtypeArg | None = None, + ) -> SQLTable: + """ + Prepares table in the database for data insertion. Creates it if needed, etc. + """ + if dtype: + if not is_dict_like(dtype): + # error: Value expression in dictionary comprehension has incompatible + # type "Union[ExtensionDtype, str, dtype[Any], Type[object], + # Dict[Hashable, Union[ExtensionDtype, Union[str, dtype[Any]], + # Type[str], Type[float], Type[int], Type[complex], Type[bool], + # Type[object]]]]"; expected type "Union[ExtensionDtype, str, + # dtype[Any], Type[object]]" + dtype = {col_name: dtype for col_name in frame} # type: ignore[misc] + else: + dtype = cast(dict, dtype) + + from sqlalchemy.types import TypeEngine + + for col, my_type in dtype.items(): + if isinstance(my_type, type) and issubclass(my_type, TypeEngine): + pass + elif isinstance(my_type, TypeEngine): + pass + else: + raise ValueError(f"The type of {col} is not a SQLAlchemy type") + + table = SQLTable( + name, + self, + frame=frame, + index=index, + if_exists=if_exists, + index_label=index_label, + schema=schema, + dtype=dtype, + ) + table.create() + return table + + def check_case_sensitive( + self, + name: str, + schema: str | None, + ) -> None: + """ + Checks table name for issues with case-sensitivity. + Method is called after data is inserted. + """ + if not name.isdigit() and not name.islower(): + # check for potentially case sensitivity issues (GH7815) + # Only check when name is not a number and name is not lower case + from sqlalchemy import inspect as sqlalchemy_inspect + + insp = sqlalchemy_inspect(self.con) + table_names = insp.get_table_names(schema=schema or self.meta.schema) + if name not in table_names: + msg = ( + f"The provided table name '{name}' is not found exactly as " + "such in the database after writing the table, possibly " + "due to case sensitivity issues. Consider using lower " + "case table names." + ) + warnings.warn( + msg, + UserWarning, + stacklevel=find_stack_level(), + ) + + def to_sql( + self, + frame, + name: str, + if_exists: Literal["fail", "replace", "append"] = "fail", + index: bool = True, + index_label=None, + schema: str | None = None, + chunksize: int | None = None, + dtype: DtypeArg | None = None, + method: Literal["multi"] | Callable | None = None, + engine: str = "auto", + **engine_kwargs, + ) -> int | None: + """ + Write records stored in a DataFrame to a SQL database. + + Parameters + ---------- + frame : DataFrame + name : string + Name of SQL table. + if_exists : {'fail', 'replace', 'append'}, default 'fail' + - fail: If table exists, do nothing. + - replace: If table exists, drop it, recreate it, and insert data. + - append: If table exists, insert data. Create if does not exist. + index : boolean, default True + Write DataFrame index as a column. + index_label : string or sequence, default None + Column label for index column(s). If None is given (default) and + `index` is True, then the index names are used. + A sequence should be given if the DataFrame uses MultiIndex. + schema : string, default None + Name of SQL schema in database to write to (if database flavor + supports this). If specified, this overwrites the default + schema of the SQLDatabase object. + chunksize : int, default None + If not None, then rows will be written in batches of this size at a + time. If None, all rows will be written at once. + dtype : single type or dict of column name to SQL type, default None + Optional specifying the datatype for columns. The SQL type should + be a SQLAlchemy type. If all columns are of the same type, one + single value can be used. + method : {None', 'multi', callable}, default None + Controls the SQL insertion clause used: + + * None : Uses standard SQL ``INSERT`` clause (one per row). + * 'multi': Pass multiple values in a single ``INSERT`` clause. + * callable with signature ``(pd_table, conn, keys, data_iter)``. + + Details and a sample callable implementation can be found in the + section :ref:`insert method `. + engine : {'auto', 'sqlalchemy'}, default 'auto' + SQL engine library to use. If 'auto', then the option + ``io.sql.engine`` is used. The default ``io.sql.engine`` + behavior is 'sqlalchemy' + + .. versionadded:: 1.3.0 + + **engine_kwargs + Any additional kwargs are passed to the engine. + """ + sql_engine = get_engine(engine) + + table = self.prep_table( + frame=frame, + name=name, + if_exists=if_exists, + index=index, + index_label=index_label, + schema=schema, + dtype=dtype, + ) + + total_inserted = sql_engine.insert_records( + table=table, + con=self.con, + frame=frame, + name=name, + index=index, + schema=schema, + chunksize=chunksize, + method=method, + **engine_kwargs, + ) + + self.check_case_sensitive(name=name, schema=schema) + return total_inserted + + @property + def tables(self): + return self.meta.tables + + def has_table(self, name: str, schema: str | None = None) -> bool: + from sqlalchemy import inspect as sqlalchemy_inspect + + insp = sqlalchemy_inspect(self.con) + return insp.has_table(name, schema or self.meta.schema) + + def get_table(self, table_name: str, schema: str | None = None) -> Table: + from sqlalchemy import ( + Numeric, + Table, + ) + + schema = schema or self.meta.schema + tbl = Table(table_name, self.meta, autoload_with=self.con, schema=schema) + for column in tbl.columns: + if isinstance(column.type, Numeric): + column.type.asdecimal = False + return tbl + + def drop_table(self, table_name: str, schema: str | None = None) -> None: + schema = schema or self.meta.schema + if self.has_table(table_name, schema): + self.meta.reflect( + bind=self.con, only=[table_name], schema=schema, views=True + ) + with self.run_transaction(): + self.get_table(table_name, schema).drop(bind=self.con) + self.meta.clear() + + def _create_sql_schema( + self, + frame: DataFrame, + table_name: str, + keys: list[str] | None = None, + dtype: DtypeArg | None = None, + schema: str | None = None, + ) -> str: + table = SQLTable( + table_name, + self, + frame=frame, + index=False, + keys=keys, + dtype=dtype, + schema=schema, + ) + return str(table.sql_schema()) + + +# ---- SQL without SQLAlchemy --- + + +class ADBCDatabase(PandasSQL): + """ + This class enables conversion between DataFrame and SQL databases + using ADBC to handle DataBase abstraction. + + Parameters + ---------- + con : adbc_driver_manager.dbapi.Connection + """ + + def __init__(self, con) -> None: + self.con = con + + @contextmanager + def run_transaction(self): + with self.con.cursor() as cur: + try: + yield cur + except Exception: + self.con.rollback() + raise + self.con.commit() + + def execute(self, sql: str | Select | TextClause, params=None): + if not isinstance(sql, str): + raise TypeError("Query must be a string unless using sqlalchemy.") + args = [] if params is None else [params] + cur = self.con.cursor() + try: + cur.execute(sql, *args) + return cur + except Exception as exc: + try: + self.con.rollback() + except Exception as inner_exc: # pragma: no cover + ex = DatabaseError( + f"Execution failed on sql: {sql}\n{exc}\nunable to rollback" + ) + raise ex from inner_exc + + ex = DatabaseError(f"Execution failed on sql '{sql}': {exc}") + raise ex from exc + + def read_table( + self, + table_name: str, + index_col: str | list[str] | None = None, + coerce_float: bool = True, + parse_dates=None, + columns=None, + schema: str | None = None, + chunksize: int | None = None, + dtype_backend: DtypeBackend | Literal["numpy"] = "numpy", + ) -> DataFrame | Iterator[DataFrame]: + """ + Read SQL database table into a DataFrame. + + Parameters + ---------- + table_name : str + Name of SQL table in database. + coerce_float : bool, default True + Raises NotImplementedError + parse_dates : list or dict, default: None + - List of column names to parse as dates. + - Dict of ``{column_name: format string}`` where format string is + strftime compatible in case of parsing string times, or is one of + (D, s, ns, ms, us) in case of parsing integer timestamps. + - Dict of ``{column_name: arg}``, where the arg corresponds + to the keyword arguments of :func:`pandas.to_datetime`. + Especially useful with databases without native Datetime support, + such as SQLite. + columns : list, default: None + List of column names to select from SQL table. + schema : string, default None + Name of SQL schema in database to query (if database flavor + supports this). If specified, this overwrites the default + schema of the SQL database object. + chunksize : int, default None + Raises NotImplementedError + dtype_backend : {'numpy_nullable', 'pyarrow'}, default 'numpy_nullable' + Back-end data type applied to the resultant :class:`DataFrame` + (still experimental). Behaviour is as follows: + + * ``"numpy_nullable"``: returns nullable-dtype-backed :class:`DataFrame` + (default). + * ``"pyarrow"``: returns pyarrow-backed nullable :class:`ArrowDtype` + DataFrame. + + .. versionadded:: 2.0 + + Returns + ------- + DataFrame + + See Also + -------- + pandas.read_sql_table + SQLDatabase.read_query + + """ + if coerce_float is not True: + raise NotImplementedError( + "'coerce_float' is not implemented for ADBC drivers" + ) + if chunksize: + raise NotImplementedError("'chunksize' is not implemented for ADBC drivers") + + if columns: + if index_col: + index_select = maybe_make_list(index_col) + else: + index_select = [] + to_select = index_select + columns + select_list = ", ".join(f'"{x}"' for x in to_select) + else: + select_list = "*" + if schema: + stmt = f"SELECT {select_list} FROM {schema}.{table_name}" + else: + stmt = f"SELECT {select_list} FROM {table_name}" + + mapping: type[ArrowDtype] | None | Callable + if dtype_backend == "pyarrow": + mapping = ArrowDtype + elif dtype_backend == "numpy_nullable": + from pandas.io._util import _arrow_dtype_mapping + + mapping = _arrow_dtype_mapping().get + elif using_pyarrow_string_dtype(): + from pandas.io._util import arrow_string_types_mapper + + arrow_string_types_mapper() + else: + mapping = None + + with self.con.cursor() as cur: + cur.execute(stmt) + df = cur.fetch_arrow_table().to_pandas(types_mapper=mapping) + + return _wrap_result_adbc( + df, + index_col=index_col, + parse_dates=parse_dates, + ) + + def read_query( + self, + sql: str, + index_col: str | list[str] | None = None, + coerce_float: bool = True, + parse_dates=None, + params=None, + chunksize: int | None = None, + dtype: DtypeArg | None = None, + dtype_backend: DtypeBackend | Literal["numpy"] = "numpy", + ) -> DataFrame | Iterator[DataFrame]: + """ + Read SQL query into a DataFrame. + + Parameters + ---------- + sql : str + SQL query to be executed. + index_col : string, optional, default: None + Column name to use as index for the returned DataFrame object. + coerce_float : bool, default True + Raises NotImplementedError + params : list, tuple or dict, optional, default: None + Raises NotImplementedError + parse_dates : list or dict, default: None + - List of column names to parse as dates. + - Dict of ``{column_name: format string}`` where format string is + strftime compatible in case of parsing string times, or is one of + (D, s, ns, ms, us) in case of parsing integer timestamps. + - Dict of ``{column_name: arg dict}``, where the arg dict + corresponds to the keyword arguments of + :func:`pandas.to_datetime` Especially useful with databases + without native Datetime support, such as SQLite. + chunksize : int, default None + Raises NotImplementedError + dtype : Type name or dict of columns + Data type for data or columns. E.g. np.float64 or + {'a': np.float64, 'b': np.int32, 'c': 'Int64'} + + .. versionadded:: 1.3.0 + + Returns + ------- + DataFrame + + See Also + -------- + read_sql_table : Read SQL database table into a DataFrame. + read_sql + + """ + if coerce_float is not True: + raise NotImplementedError( + "'coerce_float' is not implemented for ADBC drivers" + ) + if params: + raise NotImplementedError("'params' is not implemented for ADBC drivers") + if chunksize: + raise NotImplementedError("'chunksize' is not implemented for ADBC drivers") + + mapping: type[ArrowDtype] | None | Callable + if dtype_backend == "pyarrow": + mapping = ArrowDtype + elif dtype_backend == "numpy_nullable": + from pandas.io._util import _arrow_dtype_mapping + + mapping = _arrow_dtype_mapping().get + else: + mapping = None + + with self.con.cursor() as cur: + cur.execute(sql) + df = cur.fetch_arrow_table().to_pandas(types_mapper=mapping) + + return _wrap_result_adbc( + df, + index_col=index_col, + parse_dates=parse_dates, + dtype=dtype, + ) + + read_sql = read_query + + def to_sql( + self, + frame, + name: str, + if_exists: Literal["fail", "replace", "append"] = "fail", + index: bool = True, + index_label=None, + schema: str | None = None, + chunksize: int | None = None, + dtype: DtypeArg | None = None, + method: Literal["multi"] | Callable | None = None, + engine: str = "auto", + **engine_kwargs, + ) -> int | None: + """ + Write records stored in a DataFrame to a SQL database. + + Parameters + ---------- + frame : DataFrame + name : string + Name of SQL table. + if_exists : {'fail', 'replace', 'append'}, default 'fail' + - fail: If table exists, do nothing. + - replace: If table exists, drop it, recreate it, and insert data. + - append: If table exists, insert data. Create if does not exist. + index : boolean, default True + Write DataFrame index as a column. + index_label : string or sequence, default None + Raises NotImplementedError + schema : string, default None + Name of SQL schema in database to write to (if database flavor + supports this). If specified, this overwrites the default + schema of the SQLDatabase object. + chunksize : int, default None + Raises NotImplementedError + dtype : single type or dict of column name to SQL type, default None + Raises NotImplementedError + method : {None', 'multi', callable}, default None + Raises NotImplementedError + engine : {'auto', 'sqlalchemy'}, default 'auto' + Raises NotImplementedError if not set to 'auto' + """ + if index_label: + raise NotImplementedError( + "'index_label' is not implemented for ADBC drivers" + ) + if chunksize: + raise NotImplementedError("'chunksize' is not implemented for ADBC drivers") + if dtype: + raise NotImplementedError("'dtype' is not implemented for ADBC drivers") + if method: + raise NotImplementedError("'method' is not implemented for ADBC drivers") + if engine != "auto": + raise NotImplementedError( + "engine != 'auto' not implemented for ADBC drivers" + ) + + if schema: + table_name = f"{schema}.{name}" + else: + table_name = name + + # pandas if_exists="append" will still create the + # table if it does not exist; ADBC is more explicit with append/create + # as applicable modes, so the semantics get blurred across + # the libraries + mode = "create" + if self.has_table(name, schema): + if if_exists == "fail": + raise ValueError(f"Table '{table_name}' already exists.") + elif if_exists == "replace": + with self.con.cursor() as cur: + cur.execute(f"DROP TABLE {table_name}") + elif if_exists == "append": + mode = "append" + + import pyarrow as pa + + try: + tbl = pa.Table.from_pandas(frame, preserve_index=index) + except pa.ArrowNotImplementedError as exc: + raise ValueError("datatypes not supported") from exc + + with self.con.cursor() as cur: + total_inserted = cur.adbc_ingest( + table_name=name, data=tbl, mode=mode, db_schema_name=schema + ) + + self.con.commit() + return total_inserted + + def has_table(self, name: str, schema: str | None = None) -> bool: + meta = self.con.adbc_get_objects( + db_schema_filter=schema, table_name_filter=name + ).read_all() + + for catalog_schema in meta["catalog_db_schemas"].to_pylist(): + if not catalog_schema: + continue + for schema_record in catalog_schema: + if not schema_record: + continue + + for table_record in schema_record["db_schema_tables"]: + if table_record["table_name"] == name: + return True + + return False + + def _create_sql_schema( + self, + frame: DataFrame, + table_name: str, + keys: list[str] | None = None, + dtype: DtypeArg | None = None, + schema: str | None = None, + ) -> str: + raise NotImplementedError("not implemented for adbc") + + +# sqlite-specific sql strings and handler class +# dictionary used for readability purposes +_SQL_TYPES = { + "string": "TEXT", + "floating": "REAL", + "integer": "INTEGER", + "datetime": "TIMESTAMP", + "date": "DATE", + "time": "TIME", + "boolean": "INTEGER", +} + + +def _get_unicode_name(name: object): + try: + uname = str(name).encode("utf-8", "strict").decode("utf-8") + except UnicodeError as err: + raise ValueError(f"Cannot convert identifier to UTF-8: '{name}'") from err + return uname + + +def _get_valid_sqlite_name(name: object): + # See https://stackoverflow.com/questions/6514274/how-do-you-escape-strings\ + # -for-sqlite-table-column-names-in-python + # Ensure the string can be encoded as UTF-8. + # Ensure the string does not include any NUL characters. + # Replace all " with "". + # Wrap the entire thing in double quotes. + + uname = _get_unicode_name(name) + if not len(uname): + raise ValueError("Empty table or column name specified") + + nul_index = uname.find("\x00") + if nul_index >= 0: + raise ValueError("SQLite identifier cannot contain NULs") + return '"' + uname.replace('"', '""') + '"' + + +class SQLiteTable(SQLTable): + """ + Patch the SQLTable for fallback support. + Instead of a table variable just use the Create Table statement. + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + self._register_date_adapters() + + def _register_date_adapters(self) -> None: + # GH 8341 + # register an adapter callable for datetime.time object + import sqlite3 + + # this will transform time(12,34,56,789) into '12:34:56.000789' + # (this is what sqlalchemy does) + def _adapt_time(t) -> str: + # This is faster than strftime + return f"{t.hour:02d}:{t.minute:02d}:{t.second:02d}.{t.microsecond:06d}" + + # Also register adapters for date/datetime and co + # xref https://docs.python.org/3.12/library/sqlite3.html#adapter-and-converter-recipes + # Python 3.12+ doesn't auto-register adapters for us anymore + + adapt_date_iso = lambda val: val.isoformat() + adapt_datetime_iso = lambda val: val.isoformat(" ") + + sqlite3.register_adapter(time, _adapt_time) + + sqlite3.register_adapter(date, adapt_date_iso) + sqlite3.register_adapter(datetime, adapt_datetime_iso) + + convert_date = lambda val: date.fromisoformat(val.decode()) + convert_timestamp = lambda val: datetime.fromisoformat(val.decode()) + + sqlite3.register_converter("date", convert_date) + sqlite3.register_converter("timestamp", convert_timestamp) + + def sql_schema(self) -> str: + return str(";\n".join(self.table)) + + def _execute_create(self) -> None: + with self.pd_sql.run_transaction() as conn: + for stmt in self.table: + conn.execute(stmt) + + def insert_statement(self, *, num_rows: int) -> str: + names = list(map(str, self.frame.columns)) + wld = "?" # wildcard char + escape = _get_valid_sqlite_name + + if self.index is not None: + for idx in self.index[::-1]: + names.insert(0, idx) + + bracketed_names = [escape(column) for column in names] + col_names = ",".join(bracketed_names) + + row_wildcards = ",".join([wld] * len(names)) + wildcards = ",".join([f"({row_wildcards})" for _ in range(num_rows)]) + insert_statement = ( + f"INSERT INTO {escape(self.name)} ({col_names}) VALUES {wildcards}" + ) + return insert_statement + + def _execute_insert(self, conn, keys, data_iter) -> int: + data_list = list(data_iter) + conn.executemany(self.insert_statement(num_rows=1), data_list) + return conn.rowcount + + def _execute_insert_multi(self, conn, keys, data_iter) -> int: + data_list = list(data_iter) + flattened_data = [x for row in data_list for x in row] + conn.execute(self.insert_statement(num_rows=len(data_list)), flattened_data) + return conn.rowcount + + def _create_table_setup(self): + """ + Return a list of SQL statements that creates a table reflecting the + structure of a DataFrame. The first entry will be a CREATE TABLE + statement while the rest will be CREATE INDEX statements. + """ + column_names_and_types = self._get_column_names_and_types(self._sql_type_name) + escape = _get_valid_sqlite_name + + create_tbl_stmts = [ + escape(cname) + " " + ctype for cname, ctype, _ in column_names_and_types + ] + + if self.keys is not None and len(self.keys): + if not is_list_like(self.keys): + keys = [self.keys] + else: + keys = self.keys + cnames_br = ", ".join([escape(c) for c in keys]) + create_tbl_stmts.append( + f"CONSTRAINT {self.name}_pk PRIMARY KEY ({cnames_br})" + ) + if self.schema: + schema_name = self.schema + "." + else: + schema_name = "" + create_stmts = [ + "CREATE TABLE " + + schema_name + + escape(self.name) + + " (\n" + + ",\n ".join(create_tbl_stmts) + + "\n)" + ] + + ix_cols = [cname for cname, _, is_index in column_names_and_types if is_index] + if len(ix_cols): + cnames = "_".join(ix_cols) + cnames_br = ",".join([escape(c) for c in ix_cols]) + create_stmts.append( + "CREATE INDEX " + + escape("ix_" + self.name + "_" + cnames) + + "ON " + + escape(self.name) + + " (" + + cnames_br + + ")" + ) + + return create_stmts + + def _sql_type_name(self, col): + dtype: DtypeArg = self.dtype or {} + if is_dict_like(dtype): + dtype = cast(dict, dtype) + if col.name in dtype: + return dtype[col.name] + + # Infer type of column, while ignoring missing values. + # Needed for inserting typed data containing NULLs, GH 8778. + col_type = lib.infer_dtype(col, skipna=True) + + if col_type == "timedelta64": + warnings.warn( + "the 'timedelta' type is not supported, and will be " + "written as integer values (ns frequency) to the database.", + UserWarning, + stacklevel=find_stack_level(), + ) + col_type = "integer" + + elif col_type == "datetime64": + col_type = "datetime" + + elif col_type == "empty": + col_type = "string" + + elif col_type == "complex": + raise ValueError("Complex datatypes not supported") + + if col_type not in _SQL_TYPES: + col_type = "string" + + return _SQL_TYPES[col_type] + + +class SQLiteDatabase(PandasSQL): + """ + Version of SQLDatabase to support SQLite connections (fallback without + SQLAlchemy). This should only be used internally. + + Parameters + ---------- + con : sqlite connection object + + """ + + def __init__(self, con) -> None: + self.con = con + + @contextmanager + def run_transaction(self): + cur = self.con.cursor() + try: + yield cur + self.con.commit() + except Exception: + self.con.rollback() + raise + finally: + cur.close() + + def execute(self, sql: str | Select | TextClause, params=None): + if not isinstance(sql, str): + raise TypeError("Query must be a string unless using sqlalchemy.") + args = [] if params is None else [params] + cur = self.con.cursor() + try: + cur.execute(sql, *args) + return cur + except Exception as exc: + try: + self.con.rollback() + except Exception as inner_exc: # pragma: no cover + ex = DatabaseError( + f"Execution failed on sql: {sql}\n{exc}\nunable to rollback" + ) + raise ex from inner_exc + + ex = DatabaseError(f"Execution failed on sql '{sql}': {exc}") + raise ex from exc + + @staticmethod + def _query_iterator( + cursor, + chunksize: int, + columns, + index_col=None, + coerce_float: bool = True, + parse_dates=None, + dtype: DtypeArg | None = None, + dtype_backend: DtypeBackend | Literal["numpy"] = "numpy", + ): + """Return generator through chunked result set""" + has_read_data = False + while True: + data = cursor.fetchmany(chunksize) + if type(data) == tuple: + data = list(data) + if not data: + cursor.close() + if not has_read_data: + result = DataFrame.from_records( + [], columns=columns, coerce_float=coerce_float + ) + if dtype: + result = result.astype(dtype) + yield result + break + + has_read_data = True + yield _wrap_result( + data, + columns, + index_col=index_col, + coerce_float=coerce_float, + parse_dates=parse_dates, + dtype=dtype, + dtype_backend=dtype_backend, + ) + + def read_query( + self, + sql, + index_col=None, + coerce_float: bool = True, + parse_dates=None, + params=None, + chunksize: int | None = None, + dtype: DtypeArg | None = None, + dtype_backend: DtypeBackend | Literal["numpy"] = "numpy", + ) -> DataFrame | Iterator[DataFrame]: + cursor = self.execute(sql, params) + columns = [col_desc[0] for col_desc in cursor.description] + + if chunksize is not None: + return self._query_iterator( + cursor, + chunksize, + columns, + index_col=index_col, + coerce_float=coerce_float, + parse_dates=parse_dates, + dtype=dtype, + dtype_backend=dtype_backend, + ) + else: + data = self._fetchall_as_list(cursor) + cursor.close() + + frame = _wrap_result( + data, + columns, + index_col=index_col, + coerce_float=coerce_float, + parse_dates=parse_dates, + dtype=dtype, + dtype_backend=dtype_backend, + ) + return frame + + def _fetchall_as_list(self, cur): + result = cur.fetchall() + if not isinstance(result, list): + result = list(result) + return result + + def to_sql( + self, + frame, + name: str, + if_exists: str = "fail", + index: bool = True, + index_label=None, + schema=None, + chunksize: int | None = None, + dtype: DtypeArg | None = None, + method: Literal["multi"] | Callable | None = None, + engine: str = "auto", + **engine_kwargs, + ) -> int | None: + """ + Write records stored in a DataFrame to a SQL database. + + Parameters + ---------- + frame: DataFrame + name: string + Name of SQL table. + if_exists: {'fail', 'replace', 'append'}, default 'fail' + fail: If table exists, do nothing. + replace: If table exists, drop it, recreate it, and insert data. + append: If table exists, insert data. Create if it does not exist. + index : bool, default True + Write DataFrame index as a column + index_label : string or sequence, default None + Column label for index column(s). If None is given (default) and + `index` is True, then the index names are used. + A sequence should be given if the DataFrame uses MultiIndex. + schema : string, default None + Ignored parameter included for compatibility with SQLAlchemy + version of ``to_sql``. + chunksize : int, default None + If not None, then rows will be written in batches of this + size at a time. If None, all rows will be written at once. + dtype : single type or dict of column name to SQL type, default None + Optional specifying the datatype for columns. The SQL type should + be a string. If all columns are of the same type, one single value + can be used. + method : {None, 'multi', callable}, default None + Controls the SQL insertion clause used: + + * None : Uses standard SQL ``INSERT`` clause (one per row). + * 'multi': Pass multiple values in a single ``INSERT`` clause. + * callable with signature ``(pd_table, conn, keys, data_iter)``. + + Details and a sample callable implementation can be found in the + section :ref:`insert method `. + """ + if dtype: + if not is_dict_like(dtype): + # error: Value expression in dictionary comprehension has incompatible + # type "Union[ExtensionDtype, str, dtype[Any], Type[object], + # Dict[Hashable, Union[ExtensionDtype, Union[str, dtype[Any]], + # Type[str], Type[float], Type[int], Type[complex], Type[bool], + # Type[object]]]]"; expected type "Union[ExtensionDtype, str, + # dtype[Any], Type[object]]" + dtype = {col_name: dtype for col_name in frame} # type: ignore[misc] + else: + dtype = cast(dict, dtype) + + for col, my_type in dtype.items(): + if not isinstance(my_type, str): + raise ValueError(f"{col} ({my_type}) not a string") + + table = SQLiteTable( + name, + self, + frame=frame, + index=index, + if_exists=if_exists, + index_label=index_label, + dtype=dtype, + ) + table.create() + return table.insert(chunksize, method) + + def has_table(self, name: str, schema: str | None = None) -> bool: + wld = "?" + query = f""" + SELECT + name + FROM + sqlite_master + WHERE + type IN ('table', 'view') + AND name={wld}; + """ + + return len(self.execute(query, [name]).fetchall()) > 0 + + def get_table(self, table_name: str, schema: str | None = None) -> None: + return None # not supported in fallback mode + + def drop_table(self, name: str, schema: str | None = None) -> None: + drop_sql = f"DROP TABLE {_get_valid_sqlite_name(name)}" + self.execute(drop_sql) + + def _create_sql_schema( + self, + frame, + table_name: str, + keys=None, + dtype: DtypeArg | None = None, + schema: str | None = None, + ) -> str: + table = SQLiteTable( + table_name, + self, + frame=frame, + index=False, + keys=keys, + dtype=dtype, + schema=schema, + ) + return str(table.sql_schema()) + + +def get_schema( + frame, + name: str, + keys=None, + con=None, + dtype: DtypeArg | None = None, + schema: str | None = None, +) -> str: + """ + Get the SQL db table schema for the given frame. + + Parameters + ---------- + frame : DataFrame + name : str + name of SQL table + keys : string or sequence, default: None + columns to use a primary key + con: ADBC Connection, SQLAlchemy connectable, sqlite3 connection, default: None + ADBC provides high performance I/O with native type support, where available. + Using SQLAlchemy makes it possible to use any DB supported by that + library + If a DBAPI2 object, only sqlite3 is supported. + dtype : dict of column name to SQL type, default None + Optional specifying the datatype for columns. The SQL type should + be a SQLAlchemy type, or a string for sqlite3 fallback connection. + schema: str, default: None + Optional specifying the schema to be used in creating the table. + """ + with pandasSQL_builder(con=con) as pandas_sql: + return pandas_sql._create_sql_schema( + frame, name, keys=keys, dtype=dtype, schema=schema + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..330538c696c28659cf14b92325487b985b718c31 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/__pycache__/test_aggregation.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/__pycache__/test_aggregation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..248e10c97ff13920842374a767b98755815ae94e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/__pycache__/test_aggregation.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/__pycache__/test_common.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/__pycache__/test_common.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cca993c8e5b5e93dd69b10cd066a8ed51a9b970d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/__pycache__/test_common.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/__pycache__/test_downstream.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/__pycache__/test_downstream.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01a88b1b5663f6ad81f27c0f115b06fa6e25c729 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/__pycache__/test_downstream.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/__pycache__/test_errors.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/__pycache__/test_errors.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d61a0508e8d1157453cbc1d1849692ef81bc9aa Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/__pycache__/test_errors.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/__pycache__/test_expressions.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/__pycache__/test_expressions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68dd7a0d22a166445bd3e60a60a7e011315a46ee Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/__pycache__/test_expressions.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/__pycache__/test_flags.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/__pycache__/test_flags.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48ba32e3d62c35f7733dff6b0aa75764725dfb0f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/__pycache__/test_flags.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/__pycache__/test_multilevel.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/__pycache__/test_multilevel.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba26c2bc1b6e25ff733b54253afe4e35415e8d29 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/__pycache__/test_multilevel.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/__pycache__/test_nanops.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/__pycache__/test_nanops.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe6606d9c224a757b735cc60732b13c6ab845c5e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/__pycache__/test_nanops.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/__pycache__/test_optional_dependency.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/__pycache__/test_optional_dependency.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de84d7e987e8bb02c4a527c7729f6861eb66c3ea Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/__pycache__/test_optional_dependency.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/__pycache__/test_register_accessor.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/__pycache__/test_register_accessor.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cbbc7e9d5648b91030a77bf8fc8ea188927742ca Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/__pycache__/test_register_accessor.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/__pycache__/test_sorting.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/__pycache__/test_sorting.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..223f90c0258f05cd831a9d5e5465af949adecdd0 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/__pycache__/test_sorting.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/__pycache__/test_take.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/__pycache__/test_take.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..306e93c8c6e81fd6f6c1270f6c3e649eb7a32604 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/__pycache__/test_take.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/arrays/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/arrays/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/arrays/masked_shared.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/arrays/masked_shared.py new file mode 100644 index 0000000000000000000000000000000000000000..3e74402263cf9c119ec344c5da48dd8598970f69 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/arrays/masked_shared.py @@ -0,0 +1,154 @@ +""" +Tests shared by MaskedArray subclasses. +""" +import numpy as np +import pytest + +import pandas as pd +import pandas._testing as tm +from pandas.tests.extension.base import BaseOpsUtil + + +class ComparisonOps(BaseOpsUtil): + def _compare_other(self, data, op, other): + # array + result = pd.Series(op(data, other)) + expected = pd.Series(op(data._data, other), dtype="boolean") + + # fill the nan locations + expected[data._mask] = pd.NA + + tm.assert_series_equal(result, expected) + + # series + ser = pd.Series(data) + result = op(ser, other) + + # Set nullable dtype here to avoid upcasting when setting to pd.NA below + expected = op(pd.Series(data._data), other).astype("boolean") + + # fill the nan locations + expected[data._mask] = pd.NA + + tm.assert_series_equal(result, expected) + + # subclass will override to parametrize 'other' + def test_scalar(self, other, comparison_op, dtype): + op = comparison_op + left = pd.array([1, 0, None], dtype=dtype) + + result = op(left, other) + + if other is pd.NA: + expected = pd.array([None, None, None], dtype="boolean") + else: + values = op(left._data, other) + expected = pd.arrays.BooleanArray(values, left._mask, copy=True) + tm.assert_extension_array_equal(result, expected) + + # ensure we haven't mutated anything inplace + result[0] = pd.NA + tm.assert_extension_array_equal(left, pd.array([1, 0, None], dtype=dtype)) + + +class NumericOps: + # Shared by IntegerArray and FloatingArray, not BooleanArray + + def test_searchsorted_nan(self, dtype): + # The base class casts to object dtype, for which searchsorted returns + # 0 from the left and 10 from the right. + arr = pd.array(range(10), dtype=dtype) + + assert arr.searchsorted(np.nan, side="left") == 10 + assert arr.searchsorted(np.nan, side="right") == 10 + + def test_no_shared_mask(self, data): + result = data + 1 + assert not tm.shares_memory(result, data) + + def test_array(self, comparison_op, dtype): + op = comparison_op + + left = pd.array([0, 1, 2, None, None, None], dtype=dtype) + right = pd.array([0, 1, None, 0, 1, None], dtype=dtype) + + result = op(left, right) + values = op(left._data, right._data) + mask = left._mask | right._mask + + expected = pd.arrays.BooleanArray(values, mask) + tm.assert_extension_array_equal(result, expected) + + # ensure we haven't mutated anything inplace + result[0] = pd.NA + tm.assert_extension_array_equal( + left, pd.array([0, 1, 2, None, None, None], dtype=dtype) + ) + tm.assert_extension_array_equal( + right, pd.array([0, 1, None, 0, 1, None], dtype=dtype) + ) + + def test_compare_with_booleanarray(self, comparison_op, dtype): + op = comparison_op + + left = pd.array([True, False, None] * 3, dtype="boolean") + right = pd.array([0] * 3 + [1] * 3 + [None] * 3, dtype=dtype) + other = pd.array([False] * 3 + [True] * 3 + [None] * 3, dtype="boolean") + + expected = op(left, other) + result = op(left, right) + tm.assert_extension_array_equal(result, expected) + + # reversed op + expected = op(other, left) + result = op(right, left) + tm.assert_extension_array_equal(result, expected) + + def test_compare_to_string(self, dtype): + # GH#28930 + ser = pd.Series([1, None], dtype=dtype) + result = ser == "a" + expected = pd.Series([False, pd.NA], dtype="boolean") + + tm.assert_series_equal(result, expected) + + def test_ufunc_with_out(self, dtype): + arr = pd.array([1, 2, 3], dtype=dtype) + arr2 = pd.array([1, 2, pd.NA], dtype=dtype) + + mask = arr == arr + mask2 = arr2 == arr2 + + result = np.zeros(3, dtype=bool) + result |= mask + # If MaskedArray.__array_ufunc__ handled "out" appropriately, + # `result` should still be an ndarray. + assert isinstance(result, np.ndarray) + assert result.all() + + # result |= mask worked because mask could be cast losslessly to + # boolean ndarray. mask2 can't, so this raises + result = np.zeros(3, dtype=bool) + msg = "Specify an appropriate 'na_value' for this dtype" + with pytest.raises(ValueError, match=msg): + result |= mask2 + + # addition + res = np.add(arr, arr2) + expected = pd.array([2, 4, pd.NA], dtype=dtype) + tm.assert_extension_array_equal(res, expected) + + # when passing out=arr, we will modify 'arr' inplace. + res = np.add(arr, arr2, out=arr) + assert res is arr + tm.assert_extension_array_equal(res, expected) + tm.assert_extension_array_equal(arr, expected) + + def test_mul_td64_array(self, dtype): + # GH#45622 + arr = pd.array([1, 2, pd.NA], dtype=dtype) + other = np.arange(3, dtype=np.int64).view("m8[ns]") + + result = arr * other + expected = pd.array([pd.Timedelta(0), pd.Timedelta(2), pd.NaT]) + tm.assert_extension_array_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/arrays/test_array.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/arrays/test_array.py new file mode 100644 index 0000000000000000000000000000000000000000..96263f498935b0d975b12c74b7cd98c6c4853670 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/arrays/test_array.py @@ -0,0 +1,478 @@ +import datetime +import decimal +import re + +import numpy as np +import pytest +import pytz + +import pandas as pd +import pandas._testing as tm +from pandas.api.extensions import register_extension_dtype +from pandas.arrays import ( + BooleanArray, + DatetimeArray, + FloatingArray, + IntegerArray, + IntervalArray, + SparseArray, + TimedeltaArray, +) +from pandas.core.arrays import ( + NumpyExtensionArray, + period_array, +) +from pandas.tests.extension.decimal import ( + DecimalArray, + DecimalDtype, + to_decimal, +) + + +@pytest.mark.parametrize("dtype_unit", ["M8[h]", "M8[m]", "m8[h]", "M8[m]"]) +def test_dt64_array(dtype_unit): + # PR 53817 + dtype_var = np.dtype(dtype_unit) + msg = ( + r"datetime64 and timedelta64 dtype resolutions other than " + r"'s', 'ms', 'us', and 'ns' are deprecated. " + r"In future releases passing unsupported resolutions will " + r"raise an exception." + ) + with tm.assert_produces_warning(FutureWarning, match=re.escape(msg)): + pd.array([], dtype=dtype_var) + + +@pytest.mark.parametrize( + "data, dtype, expected", + [ + # Basic NumPy defaults. + ([], None, FloatingArray._from_sequence([], dtype="Float64")), + ([1, 2], None, IntegerArray._from_sequence([1, 2], dtype="Int64")), + ([1, 2], object, NumpyExtensionArray(np.array([1, 2], dtype=object))), + ( + [1, 2], + np.dtype("float32"), + NumpyExtensionArray(np.array([1.0, 2.0], dtype=np.dtype("float32"))), + ), + ( + np.array([], dtype=object), + None, + NumpyExtensionArray(np.array([], dtype=object)), + ), + ( + np.array([1, 2], dtype="int64"), + None, + IntegerArray._from_sequence([1, 2], dtype="Int64"), + ), + ( + np.array([1.0, 2.0], dtype="float64"), + None, + FloatingArray._from_sequence([1.0, 2.0], dtype="Float64"), + ), + # String alias passes through to NumPy + ([1, 2], "float32", NumpyExtensionArray(np.array([1, 2], dtype="float32"))), + ([1, 2], "int64", NumpyExtensionArray(np.array([1, 2], dtype=np.int64))), + # GH#44715 FloatingArray does not support float16, so fall + # back to NumpyExtensionArray + ( + np.array([1, 2], dtype=np.float16), + None, + NumpyExtensionArray(np.array([1, 2], dtype=np.float16)), + ), + # idempotency with e.g. pd.array(pd.array([1, 2], dtype="int64")) + ( + NumpyExtensionArray(np.array([1, 2], dtype=np.int32)), + None, + NumpyExtensionArray(np.array([1, 2], dtype=np.int32)), + ), + # Period alias + ( + [pd.Period("2000", "D"), pd.Period("2001", "D")], + "Period[D]", + period_array(["2000", "2001"], freq="D"), + ), + # Period dtype + ( + [pd.Period("2000", "D")], + pd.PeriodDtype("D"), + period_array(["2000"], freq="D"), + ), + # Datetime (naive) + ( + [1, 2], + np.dtype("datetime64[ns]"), + DatetimeArray._from_sequence( + np.array([1, 2], dtype="M8[ns]"), dtype="M8[ns]" + ), + ), + ( + [1, 2], + np.dtype("datetime64[s]"), + DatetimeArray._from_sequence( + np.array([1, 2], dtype="M8[s]"), dtype="M8[s]" + ), + ), + ( + np.array([1, 2], dtype="datetime64[ns]"), + None, + DatetimeArray._from_sequence( + np.array([1, 2], dtype="M8[ns]"), dtype="M8[ns]" + ), + ), + ( + pd.DatetimeIndex(["2000", "2001"]), + np.dtype("datetime64[ns]"), + DatetimeArray._from_sequence(["2000", "2001"], dtype="M8[ns]"), + ), + ( + pd.DatetimeIndex(["2000", "2001"]), + None, + DatetimeArray._from_sequence(["2000", "2001"], dtype="M8[ns]"), + ), + ( + ["2000", "2001"], + np.dtype("datetime64[ns]"), + DatetimeArray._from_sequence(["2000", "2001"], dtype="M8[ns]"), + ), + # Datetime (tz-aware) + ( + ["2000", "2001"], + pd.DatetimeTZDtype(tz="CET"), + DatetimeArray._from_sequence( + ["2000", "2001"], dtype=pd.DatetimeTZDtype(tz="CET") + ), + ), + # Timedelta + ( + ["1h", "2h"], + np.dtype("timedelta64[ns]"), + TimedeltaArray._from_sequence(["1h", "2h"], dtype="m8[ns]"), + ), + ( + pd.TimedeltaIndex(["1h", "2h"]), + np.dtype("timedelta64[ns]"), + TimedeltaArray._from_sequence(["1h", "2h"], dtype="m8[ns]"), + ), + ( + np.array([1, 2], dtype="m8[s]"), + np.dtype("timedelta64[s]"), + TimedeltaArray._from_sequence( + np.array([1, 2], dtype="m8[s]"), dtype="m8[s]" + ), + ), + ( + pd.TimedeltaIndex(["1h", "2h"]), + None, + TimedeltaArray._from_sequence(["1h", "2h"], dtype="m8[ns]"), + ), + ( + # preserve non-nano, i.e. don't cast to NumpyExtensionArray + TimedeltaArray._simple_new( + np.arange(5, dtype=np.int64).view("m8[s]"), dtype=np.dtype("m8[s]") + ), + None, + TimedeltaArray._simple_new( + np.arange(5, dtype=np.int64).view("m8[s]"), dtype=np.dtype("m8[s]") + ), + ), + ( + # preserve non-nano, i.e. don't cast to NumpyExtensionArray + TimedeltaArray._simple_new( + np.arange(5, dtype=np.int64).view("m8[s]"), dtype=np.dtype("m8[s]") + ), + np.dtype("m8[s]"), + TimedeltaArray._simple_new( + np.arange(5, dtype=np.int64).view("m8[s]"), dtype=np.dtype("m8[s]") + ), + ), + # Category + (["a", "b"], "category", pd.Categorical(["a", "b"])), + ( + ["a", "b"], + pd.CategoricalDtype(None, ordered=True), + pd.Categorical(["a", "b"], ordered=True), + ), + # Interval + ( + [pd.Interval(1, 2), pd.Interval(3, 4)], + "interval", + IntervalArray.from_tuples([(1, 2), (3, 4)]), + ), + # Sparse + ([0, 1], "Sparse[int64]", SparseArray([0, 1], dtype="int64")), + # IntegerNA + ([1, None], "Int16", pd.array([1, None], dtype="Int16")), + ( + pd.Series([1, 2]), + None, + NumpyExtensionArray(np.array([1, 2], dtype=np.int64)), + ), + # String + ( + ["a", None], + "string", + pd.StringDtype() + .construct_array_type() + ._from_sequence(["a", None], dtype=pd.StringDtype()), + ), + ( + ["a", None], + pd.StringDtype(), + pd.StringDtype() + .construct_array_type() + ._from_sequence(["a", None], dtype=pd.StringDtype()), + ), + # Boolean + ( + [True, None], + "boolean", + BooleanArray._from_sequence([True, None], dtype="boolean"), + ), + ( + [True, None], + pd.BooleanDtype(), + BooleanArray._from_sequence([True, None], dtype="boolean"), + ), + # Index + (pd.Index([1, 2]), None, NumpyExtensionArray(np.array([1, 2], dtype=np.int64))), + # Series[EA] returns the EA + ( + pd.Series(pd.Categorical(["a", "b"], categories=["a", "b", "c"])), + None, + pd.Categorical(["a", "b"], categories=["a", "b", "c"]), + ), + # "3rd party" EAs work + ([decimal.Decimal(0), decimal.Decimal(1)], "decimal", to_decimal([0, 1])), + # pass an ExtensionArray, but a different dtype + ( + period_array(["2000", "2001"], freq="D"), + "category", + pd.Categorical([pd.Period("2000", "D"), pd.Period("2001", "D")]), + ), + ], +) +def test_array(data, dtype, expected): + result = pd.array(data, dtype=dtype) + tm.assert_equal(result, expected) + + +def test_array_copy(): + a = np.array([1, 2]) + # default is to copy + b = pd.array(a, dtype=a.dtype) + assert not tm.shares_memory(a, b) + + # copy=True + b = pd.array(a, dtype=a.dtype, copy=True) + assert not tm.shares_memory(a, b) + + # copy=False + b = pd.array(a, dtype=a.dtype, copy=False) + assert tm.shares_memory(a, b) + + +cet = pytz.timezone("CET") + + +@pytest.mark.parametrize( + "data, expected", + [ + # period + ( + [pd.Period("2000", "D"), pd.Period("2001", "D")], + period_array(["2000", "2001"], freq="D"), + ), + # interval + ([pd.Interval(0, 1), pd.Interval(1, 2)], IntervalArray.from_breaks([0, 1, 2])), + # datetime + ( + [pd.Timestamp("2000"), pd.Timestamp("2001")], + DatetimeArray._from_sequence(["2000", "2001"], dtype="M8[ns]"), + ), + ( + [datetime.datetime(2000, 1, 1), datetime.datetime(2001, 1, 1)], + DatetimeArray._from_sequence(["2000", "2001"], dtype="M8[ns]"), + ), + ( + np.array([1, 2], dtype="M8[ns]"), + DatetimeArray._from_sequence(np.array([1, 2], dtype="M8[ns]")), + ), + ( + np.array([1, 2], dtype="M8[us]"), + DatetimeArray._simple_new( + np.array([1, 2], dtype="M8[us]"), dtype=np.dtype("M8[us]") + ), + ), + # datetimetz + ( + [pd.Timestamp("2000", tz="CET"), pd.Timestamp("2001", tz="CET")], + DatetimeArray._from_sequence( + ["2000", "2001"], dtype=pd.DatetimeTZDtype(tz="CET", unit="ns") + ), + ), + ( + [ + datetime.datetime(2000, 1, 1, tzinfo=cet), + datetime.datetime(2001, 1, 1, tzinfo=cet), + ], + DatetimeArray._from_sequence( + ["2000", "2001"], dtype=pd.DatetimeTZDtype(tz=cet, unit="ns") + ), + ), + # timedelta + ( + [pd.Timedelta("1h"), pd.Timedelta("2h")], + TimedeltaArray._from_sequence(["1h", "2h"], dtype="m8[ns]"), + ), + ( + np.array([1, 2], dtype="m8[ns]"), + TimedeltaArray._from_sequence(np.array([1, 2], dtype="m8[ns]")), + ), + ( + np.array([1, 2], dtype="m8[us]"), + TimedeltaArray._from_sequence(np.array([1, 2], dtype="m8[us]")), + ), + # integer + ([1, 2], IntegerArray._from_sequence([1, 2], dtype="Int64")), + ([1, None], IntegerArray._from_sequence([1, None], dtype="Int64")), + ([1, pd.NA], IntegerArray._from_sequence([1, pd.NA], dtype="Int64")), + ([1, np.nan], IntegerArray._from_sequence([1, np.nan], dtype="Int64")), + # float + ([0.1, 0.2], FloatingArray._from_sequence([0.1, 0.2], dtype="Float64")), + ([0.1, None], FloatingArray._from_sequence([0.1, pd.NA], dtype="Float64")), + ([0.1, np.nan], FloatingArray._from_sequence([0.1, pd.NA], dtype="Float64")), + ([0.1, pd.NA], FloatingArray._from_sequence([0.1, pd.NA], dtype="Float64")), + # integer-like float + ([1.0, 2.0], FloatingArray._from_sequence([1.0, 2.0], dtype="Float64")), + ([1.0, None], FloatingArray._from_sequence([1.0, pd.NA], dtype="Float64")), + ([1.0, np.nan], FloatingArray._from_sequence([1.0, pd.NA], dtype="Float64")), + ([1.0, pd.NA], FloatingArray._from_sequence([1.0, pd.NA], dtype="Float64")), + # mixed-integer-float + ([1, 2.0], FloatingArray._from_sequence([1.0, 2.0], dtype="Float64")), + ( + [1, np.nan, 2.0], + FloatingArray._from_sequence([1.0, None, 2.0], dtype="Float64"), + ), + # string + ( + ["a", "b"], + pd.StringDtype() + .construct_array_type() + ._from_sequence(["a", "b"], dtype=pd.StringDtype()), + ), + ( + ["a", None], + pd.StringDtype() + .construct_array_type() + ._from_sequence(["a", None], dtype=pd.StringDtype()), + ), + # Boolean + ([True, False], BooleanArray._from_sequence([True, False], dtype="boolean")), + ([True, None], BooleanArray._from_sequence([True, None], dtype="boolean")), + ], +) +def test_array_inference(data, expected): + result = pd.array(data) + tm.assert_equal(result, expected) + + +@pytest.mark.parametrize( + "data", + [ + # mix of frequencies + [pd.Period("2000", "D"), pd.Period("2001", "Y")], + # mix of closed + [pd.Interval(0, 1, closed="left"), pd.Interval(1, 2, closed="right")], + # Mix of timezones + [pd.Timestamp("2000", tz="CET"), pd.Timestamp("2000", tz="UTC")], + # Mix of tz-aware and tz-naive + [pd.Timestamp("2000", tz="CET"), pd.Timestamp("2000")], + np.array([pd.Timestamp("2000"), pd.Timestamp("2000", tz="CET")]), + ], +) +def test_array_inference_fails(data): + result = pd.array(data) + expected = NumpyExtensionArray(np.array(data, dtype=object)) + tm.assert_extension_array_equal(result, expected) + + +@pytest.mark.parametrize("data", [np.array(0)]) +def test_nd_raises(data): + with pytest.raises(ValueError, match="NumpyExtensionArray must be 1-dimensional"): + pd.array(data, dtype="int64") + + +def test_scalar_raises(): + with pytest.raises(ValueError, match="Cannot pass scalar '1'"): + pd.array(1) + + +def test_dataframe_raises(): + # GH#51167 don't accidentally cast to StringArray by doing inference on columns + df = pd.DataFrame([[1, 2], [3, 4]], columns=["A", "B"]) + msg = "Cannot pass DataFrame to 'pandas.array'" + with pytest.raises(TypeError, match=msg): + pd.array(df) + + +def test_bounds_check(): + # GH21796 + with pytest.raises( + TypeError, match=r"cannot safely cast non-equivalent int(32|64) to uint16" + ): + pd.array([-1, 2, 3], dtype="UInt16") + + +# --------------------------------------------------------------------------- +# A couple dummy classes to ensure that Series and Indexes are unboxed before +# getting to the EA classes. + + +@register_extension_dtype +class DecimalDtype2(DecimalDtype): + name = "decimal2" + + @classmethod + def construct_array_type(cls): + """ + Return the array type associated with this dtype. + + Returns + ------- + type + """ + return DecimalArray2 + + +class DecimalArray2(DecimalArray): + @classmethod + def _from_sequence(cls, scalars, *, dtype=None, copy=False): + if isinstance(scalars, (pd.Series, pd.Index)): + raise TypeError("scalars should not be of type pd.Series or pd.Index") + + return super()._from_sequence(scalars, dtype=dtype, copy=copy) + + +def test_array_unboxes(index_or_series): + box = index_or_series + + data = box([decimal.Decimal("1"), decimal.Decimal("2")]) + dtype = DecimalDtype2() + # make sure it works + with pytest.raises( + TypeError, match="scalars should not be of type pd.Series or pd.Index" + ): + DecimalArray2._from_sequence(data, dtype=dtype) + + result = pd.array(data, dtype="decimal2") + expected = DecimalArray2._from_sequence(data.values, dtype=dtype) + tm.assert_equal(result, expected) + + +def test_array_to_numpy_na(): + # GH#40638 + arr = pd.array([pd.NA, 1], dtype="string[python]") + result = arr.to_numpy(na_value=True, dtype=bool) + expected = np.array([True, True]) + tm.assert_numpy_array_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/arrays/test_datetimelike.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/arrays/test_datetimelike.py new file mode 100644 index 0000000000000000000000000000000000000000..4961123a7ca0794aec3a880537cfbd25017207ae --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/arrays/test_datetimelike.py @@ -0,0 +1,1344 @@ +from __future__ import annotations + +import re +import warnings + +import numpy as np +import pytest + +from pandas._libs import ( + NaT, + OutOfBoundsDatetime, + Timestamp, +) +from pandas._libs.tslibs.dtypes import freq_to_period_freqstr +from pandas.compat.numpy import np_version_gt2 + +import pandas as pd +from pandas import ( + DatetimeIndex, + Period, + PeriodIndex, + TimedeltaIndex, +) +import pandas._testing as tm +from pandas.core.arrays import ( + DatetimeArray, + NumpyExtensionArray, + PeriodArray, + TimedeltaArray, +) + + +# TODO: more freq variants +@pytest.fixture(params=["D", "B", "W", "ME", "QE", "YE"]) +def freqstr(request): + """Fixture returning parametrized frequency in string format.""" + return request.param + + +@pytest.fixture +def period_index(freqstr): + """ + A fixture to provide PeriodIndex objects with different frequencies. + + Most PeriodArray behavior is already tested in PeriodIndex tests, + so here we just test that the PeriodArray behavior matches + the PeriodIndex behavior. + """ + # TODO: non-monotone indexes; NaTs, different start dates + with warnings.catch_warnings(): + # suppress deprecation of Period[B] + warnings.filterwarnings( + "ignore", message="Period with BDay freq", category=FutureWarning + ) + freqstr = freq_to_period_freqstr(1, freqstr) + pi = pd.period_range(start=Timestamp("2000-01-01"), periods=100, freq=freqstr) + return pi + + +@pytest.fixture +def datetime_index(freqstr): + """ + A fixture to provide DatetimeIndex objects with different frequencies. + + Most DatetimeArray behavior is already tested in DatetimeIndex tests, + so here we just test that the DatetimeArray behavior matches + the DatetimeIndex behavior. + """ + # TODO: non-monotone indexes; NaTs, different start dates, timezones + dti = pd.date_range(start=Timestamp("2000-01-01"), periods=100, freq=freqstr) + return dti + + +@pytest.fixture +def timedelta_index(): + """ + A fixture to provide TimedeltaIndex objects with different frequencies. + Most TimedeltaArray behavior is already tested in TimedeltaIndex tests, + so here we just test that the TimedeltaArray behavior matches + the TimedeltaIndex behavior. + """ + # TODO: flesh this out + return TimedeltaIndex(["1 Day", "3 Hours", "NaT"]) + + +class SharedTests: + index_cls: type[DatetimeIndex | PeriodIndex | TimedeltaIndex] + + @pytest.fixture + def arr1d(self): + """Fixture returning DatetimeArray with daily frequency.""" + data = np.arange(10, dtype="i8") * 24 * 3600 * 10**9 + if self.array_cls is PeriodArray: + arr = self.array_cls(data, freq="D") + else: + arr = self.index_cls(data, freq="D")._data + return arr + + def test_compare_len1_raises(self, arr1d): + # make sure we raise when comparing with different lengths, specific + # to the case where one has length-1, which numpy would broadcast + arr = arr1d + idx = self.index_cls(arr) + + with pytest.raises(ValueError, match="Lengths must match"): + arr == arr[:1] + + # test the index classes while we're at it, GH#23078 + with pytest.raises(ValueError, match="Lengths must match"): + idx <= idx[[0]] + + @pytest.mark.parametrize( + "result", + [ + pd.date_range("2020", periods=3), + pd.date_range("2020", periods=3, tz="UTC"), + pd.timedelta_range("0 days", periods=3), + pd.period_range("2020Q1", periods=3, freq="Q"), + ], + ) + def test_compare_with_Categorical(self, result): + expected = pd.Categorical(result) + assert all(result == expected) + assert not any(result != expected) + + @pytest.mark.parametrize("reverse", [True, False]) + @pytest.mark.parametrize("as_index", [True, False]) + def test_compare_categorical_dtype(self, arr1d, as_index, reverse, ordered): + other = pd.Categorical(arr1d, ordered=ordered) + if as_index: + other = pd.CategoricalIndex(other) + + left, right = arr1d, other + if reverse: + left, right = right, left + + ones = np.ones(arr1d.shape, dtype=bool) + zeros = ~ones + + result = left == right + tm.assert_numpy_array_equal(result, ones) + + result = left != right + tm.assert_numpy_array_equal(result, zeros) + + if not reverse and not as_index: + # Otherwise Categorical raises TypeError bc it is not ordered + # TODO: we should probably get the same behavior regardless? + result = left < right + tm.assert_numpy_array_equal(result, zeros) + + result = left <= right + tm.assert_numpy_array_equal(result, ones) + + result = left > right + tm.assert_numpy_array_equal(result, zeros) + + result = left >= right + tm.assert_numpy_array_equal(result, ones) + + def test_take(self): + data = np.arange(100, dtype="i8") * 24 * 3600 * 10**9 + np.random.default_rng(2).shuffle(data) + + if self.array_cls is PeriodArray: + arr = PeriodArray(data, dtype="period[D]") + else: + arr = self.index_cls(data)._data + idx = self.index_cls._simple_new(arr) + + takers = [1, 4, 94] + result = arr.take(takers) + expected = idx.take(takers) + + tm.assert_index_equal(self.index_cls(result), expected) + + takers = np.array([1, 4, 94]) + result = arr.take(takers) + expected = idx.take(takers) + + tm.assert_index_equal(self.index_cls(result), expected) + + @pytest.mark.parametrize("fill_value", [2, 2.0, Timestamp(2021, 1, 1, 12).time]) + def test_take_fill_raises(self, fill_value, arr1d): + msg = f"value should be a '{arr1d._scalar_type.__name__}' or 'NaT'. Got" + with pytest.raises(TypeError, match=msg): + arr1d.take([0, 1], allow_fill=True, fill_value=fill_value) + + def test_take_fill(self, arr1d): + arr = arr1d + + result = arr.take([-1, 1], allow_fill=True, fill_value=None) + assert result[0] is NaT + + result = arr.take([-1, 1], allow_fill=True, fill_value=np.nan) + assert result[0] is NaT + + result = arr.take([-1, 1], allow_fill=True, fill_value=NaT) + assert result[0] is NaT + + @pytest.mark.filterwarnings( + "ignore:Period with BDay freq is deprecated:FutureWarning" + ) + def test_take_fill_str(self, arr1d): + # Cast str fill_value matching other fill_value-taking methods + result = arr1d.take([-1, 1], allow_fill=True, fill_value=str(arr1d[-1])) + expected = arr1d[[-1, 1]] + tm.assert_equal(result, expected) + + msg = f"value should be a '{arr1d._scalar_type.__name__}' or 'NaT'. Got" + with pytest.raises(TypeError, match=msg): + arr1d.take([-1, 1], allow_fill=True, fill_value="foo") + + def test_concat_same_type(self, arr1d): + arr = arr1d + idx = self.index_cls(arr) + idx = idx.insert(0, NaT) + arr = arr1d + + result = arr._concat_same_type([arr[:-1], arr[1:], arr]) + arr2 = arr.astype(object) + expected = self.index_cls(np.concatenate([arr2[:-1], arr2[1:], arr2])) + + tm.assert_index_equal(self.index_cls(result), expected) + + def test_unbox_scalar(self, arr1d): + result = arr1d._unbox_scalar(arr1d[0]) + expected = arr1d._ndarray.dtype.type + assert isinstance(result, expected) + + result = arr1d._unbox_scalar(NaT) + assert isinstance(result, expected) + + msg = f"'value' should be a {self.scalar_type.__name__}." + with pytest.raises(ValueError, match=msg): + arr1d._unbox_scalar("foo") + + def test_check_compatible_with(self, arr1d): + arr1d._check_compatible_with(arr1d[0]) + arr1d._check_compatible_with(arr1d[:1]) + arr1d._check_compatible_with(NaT) + + def test_scalar_from_string(self, arr1d): + result = arr1d._scalar_from_string(str(arr1d[0])) + assert result == arr1d[0] + + def test_reduce_invalid(self, arr1d): + msg = "does not support reduction 'not a method'" + with pytest.raises(TypeError, match=msg): + arr1d._reduce("not a method") + + @pytest.mark.parametrize("method", ["pad", "backfill"]) + def test_fillna_method_doesnt_change_orig(self, method): + data = np.arange(10, dtype="i8") * 24 * 3600 * 10**9 + if self.array_cls is PeriodArray: + arr = self.array_cls(data, dtype="period[D]") + else: + arr = self.array_cls._from_sequence(data) + arr[4] = NaT + + fill_value = arr[3] if method == "pad" else arr[5] + + result = arr._pad_or_backfill(method=method) + assert result[4] == fill_value + + # check that the original was not changed + assert arr[4] is NaT + + def test_searchsorted(self): + data = np.arange(10, dtype="i8") * 24 * 3600 * 10**9 + if self.array_cls is PeriodArray: + arr = self.array_cls(data, dtype="period[D]") + else: + arr = self.array_cls._from_sequence(data) + + # scalar + result = arr.searchsorted(arr[1]) + assert result == 1 + + result = arr.searchsorted(arr[2], side="right") + assert result == 3 + + # own-type + result = arr.searchsorted(arr[1:3]) + expected = np.array([1, 2], dtype=np.intp) + tm.assert_numpy_array_equal(result, expected) + + result = arr.searchsorted(arr[1:3], side="right") + expected = np.array([2, 3], dtype=np.intp) + tm.assert_numpy_array_equal(result, expected) + + # GH#29884 match numpy convention on whether NaT goes + # at the end or the beginning + result = arr.searchsorted(NaT) + assert result == 10 + + @pytest.mark.parametrize("box", [None, "index", "series"]) + def test_searchsorted_castable_strings(self, arr1d, box, string_storage): + arr = arr1d + if box is None: + pass + elif box == "index": + # Test the equivalent Index.searchsorted method while we're here + arr = self.index_cls(arr) + else: + # Test the equivalent Series.searchsorted method while we're here + arr = pd.Series(arr) + + # scalar + result = arr.searchsorted(str(arr[1])) + assert result == 1 + + result = arr.searchsorted(str(arr[2]), side="right") + assert result == 3 + + result = arr.searchsorted([str(x) for x in arr[1:3]]) + expected = np.array([1, 2], dtype=np.intp) + tm.assert_numpy_array_equal(result, expected) + + with pytest.raises( + TypeError, + match=re.escape( + f"value should be a '{arr1d._scalar_type.__name__}', 'NaT', " + "or array of those. Got 'str' instead." + ), + ): + arr.searchsorted("foo") + + with pd.option_context("string_storage", string_storage): + with pytest.raises( + TypeError, + match=re.escape( + f"value should be a '{arr1d._scalar_type.__name__}', 'NaT', " + "or array of those. Got string array instead." + ), + ): + arr.searchsorted([str(arr[1]), "baz"]) + + def test_getitem_near_implementation_bounds(self): + # We only check tz-naive for DTA bc the bounds are slightly different + # for other tzs + i8vals = np.asarray([NaT._value + n for n in range(1, 5)], dtype="i8") + if self.array_cls is PeriodArray: + arr = self.array_cls(i8vals, dtype="period[ns]") + else: + arr = self.index_cls(i8vals, freq="ns")._data + arr[0] # should not raise OutOfBoundsDatetime + + index = pd.Index(arr) + index[0] # should not raise OutOfBoundsDatetime + + ser = pd.Series(arr) + ser[0] # should not raise OutOfBoundsDatetime + + def test_getitem_2d(self, arr1d): + # 2d slicing on a 1D array + expected = type(arr1d)._simple_new( + arr1d._ndarray[:, np.newaxis], dtype=arr1d.dtype + ) + result = arr1d[:, np.newaxis] + tm.assert_equal(result, expected) + + # Lookup on a 2D array + arr2d = expected + expected = type(arr2d)._simple_new(arr2d._ndarray[:3, 0], dtype=arr2d.dtype) + result = arr2d[:3, 0] + tm.assert_equal(result, expected) + + # Scalar lookup + result = arr2d[-1, 0] + expected = arr1d[-1] + assert result == expected + + def test_iter_2d(self, arr1d): + data2d = arr1d._ndarray[:3, np.newaxis] + arr2d = type(arr1d)._simple_new(data2d, dtype=arr1d.dtype) + result = list(arr2d) + assert len(result) == 3 + for x in result: + assert isinstance(x, type(arr1d)) + assert x.ndim == 1 + assert x.dtype == arr1d.dtype + + def test_repr_2d(self, arr1d): + data2d = arr1d._ndarray[:3, np.newaxis] + arr2d = type(arr1d)._simple_new(data2d, dtype=arr1d.dtype) + + result = repr(arr2d) + + if isinstance(arr2d, TimedeltaArray): + expected = ( + f"<{type(arr2d).__name__}>\n" + "[\n" + f"['{arr1d[0]._repr_base()}'],\n" + f"['{arr1d[1]._repr_base()}'],\n" + f"['{arr1d[2]._repr_base()}']\n" + "]\n" + f"Shape: (3, 1), dtype: {arr1d.dtype}" + ) + else: + expected = ( + f"<{type(arr2d).__name__}>\n" + "[\n" + f"['{arr1d[0]}'],\n" + f"['{arr1d[1]}'],\n" + f"['{arr1d[2]}']\n" + "]\n" + f"Shape: (3, 1), dtype: {arr1d.dtype}" + ) + + assert result == expected + + def test_setitem(self): + data = np.arange(10, dtype="i8") * 24 * 3600 * 10**9 + if self.array_cls is PeriodArray: + arr = self.array_cls(data, dtype="period[D]") + else: + arr = self.index_cls(data, freq="D")._data + + arr[0] = arr[1] + expected = np.arange(10, dtype="i8") * 24 * 3600 * 10**9 + expected[0] = expected[1] + + tm.assert_numpy_array_equal(arr.asi8, expected) + + arr[:2] = arr[-2:] + expected[:2] = expected[-2:] + tm.assert_numpy_array_equal(arr.asi8, expected) + + @pytest.mark.parametrize( + "box", + [ + pd.Index, + pd.Series, + np.array, + list, + NumpyExtensionArray, + ], + ) + def test_setitem_object_dtype(self, box, arr1d): + expected = arr1d.copy()[::-1] + if expected.dtype.kind in ["m", "M"]: + expected = expected._with_freq(None) + + vals = expected + if box is list: + vals = list(vals) + elif box is np.array: + # if we do np.array(x).astype(object) then dt64 and td64 cast to ints + vals = np.array(vals.astype(object)) + elif box is NumpyExtensionArray: + vals = box(np.asarray(vals, dtype=object)) + else: + vals = box(vals).astype(object) + + arr1d[:] = vals + + tm.assert_equal(arr1d, expected) + + def test_setitem_strs(self, arr1d): + # Check that we parse strs in both scalar and listlike + + # Setting list-like of strs + expected = arr1d.copy() + expected[[0, 1]] = arr1d[-2:] + + result = arr1d.copy() + result[:2] = [str(x) for x in arr1d[-2:]] + tm.assert_equal(result, expected) + + # Same thing but now for just a scalar str + expected = arr1d.copy() + expected[0] = arr1d[-1] + + result = arr1d.copy() + result[0] = str(arr1d[-1]) + tm.assert_equal(result, expected) + + @pytest.mark.parametrize("as_index", [True, False]) + def test_setitem_categorical(self, arr1d, as_index): + expected = arr1d.copy()[::-1] + if not isinstance(expected, PeriodArray): + expected = expected._with_freq(None) + + cat = pd.Categorical(arr1d) + if as_index: + cat = pd.CategoricalIndex(cat) + + arr1d[:] = cat[::-1] + + tm.assert_equal(arr1d, expected) + + def test_setitem_raises(self, arr1d): + arr = arr1d[:10] + val = arr[0] + + with pytest.raises(IndexError, match="index 12 is out of bounds"): + arr[12] = val + + with pytest.raises(TypeError, match="value should be a.* 'object'"): + arr[0] = object() + + msg = "cannot set using a list-like indexer with a different length" + with pytest.raises(ValueError, match=msg): + # GH#36339 + arr[[]] = [arr[1]] + + msg = "cannot set using a slice indexer with a different length than" + with pytest.raises(ValueError, match=msg): + # GH#36339 + arr[1:1] = arr[:3] + + @pytest.mark.parametrize("box", [list, np.array, pd.Index, pd.Series]) + def test_setitem_numeric_raises(self, arr1d, box): + # We dont case e.g. int64 to our own dtype for setitem + + msg = ( + f"value should be a '{arr1d._scalar_type.__name__}', " + "'NaT', or array of those. Got" + ) + with pytest.raises(TypeError, match=msg): + arr1d[:2] = box([0, 1]) + + with pytest.raises(TypeError, match=msg): + arr1d[:2] = box([0.0, 1.0]) + + def test_inplace_arithmetic(self): + # GH#24115 check that iadd and isub are actually in-place + data = np.arange(10, dtype="i8") * 24 * 3600 * 10**9 + if self.array_cls is PeriodArray: + arr = self.array_cls(data, dtype="period[D]") + else: + arr = self.index_cls(data, freq="D")._data + + expected = arr + pd.Timedelta(days=1) + arr += pd.Timedelta(days=1) + tm.assert_equal(arr, expected) + + expected = arr - pd.Timedelta(days=1) + arr -= pd.Timedelta(days=1) + tm.assert_equal(arr, expected) + + def test_shift_fill_int_deprecated(self, arr1d): + # GH#31971, enforced in 2.0 + with pytest.raises(TypeError, match="value should be a"): + arr1d.shift(1, fill_value=1) + + def test_median(self, arr1d): + arr = arr1d + if len(arr) % 2 == 0: + # make it easier to define `expected` + arr = arr[:-1] + + expected = arr[len(arr) // 2] + + result = arr.median() + assert type(result) is type(expected) + assert result == expected + + arr[len(arr) // 2] = NaT + if not isinstance(expected, Period): + expected = arr[len(arr) // 2 - 1 : len(arr) // 2 + 2].mean() + + assert arr.median(skipna=False) is NaT + + result = arr.median() + assert type(result) is type(expected) + assert result == expected + + assert arr[:0].median() is NaT + assert arr[:0].median(skipna=False) is NaT + + # 2d Case + arr2 = arr.reshape(-1, 1) + + result = arr2.median(axis=None) + assert type(result) is type(expected) + assert result == expected + + assert arr2.median(axis=None, skipna=False) is NaT + + result = arr2.median(axis=0) + expected2 = type(arr)._from_sequence([expected], dtype=arr.dtype) + tm.assert_equal(result, expected2) + + result = arr2.median(axis=0, skipna=False) + expected2 = type(arr)._from_sequence([NaT], dtype=arr.dtype) + tm.assert_equal(result, expected2) + + result = arr2.median(axis=1) + tm.assert_equal(result, arr) + + result = arr2.median(axis=1, skipna=False) + tm.assert_equal(result, arr) + + def test_from_integer_array(self): + arr = np.array([1, 2, 3], dtype=np.int64) + data = pd.array(arr, dtype="Int64") + if self.array_cls is PeriodArray: + expected = self.array_cls(arr, dtype=self.example_dtype) + result = self.array_cls(data, dtype=self.example_dtype) + else: + expected = self.array_cls._from_sequence(arr, dtype=self.example_dtype) + result = self.array_cls._from_sequence(data, dtype=self.example_dtype) + + tm.assert_extension_array_equal(result, expected) + + +class TestDatetimeArray(SharedTests): + index_cls = DatetimeIndex + array_cls = DatetimeArray + scalar_type = Timestamp + example_dtype = "M8[ns]" + + @pytest.fixture + def arr1d(self, tz_naive_fixture, freqstr): + """ + Fixture returning DatetimeArray with parametrized frequency and + timezones + """ + tz = tz_naive_fixture + dti = pd.date_range("2016-01-01 01:01:00", periods=5, freq=freqstr, tz=tz) + dta = dti._data + return dta + + def test_round(self, arr1d): + # GH#24064 + dti = self.index_cls(arr1d) + + result = dti.round(freq="2min") + expected = dti - pd.Timedelta(minutes=1) + expected = expected._with_freq(None) + tm.assert_index_equal(result, expected) + + dta = dti._data + result = dta.round(freq="2min") + expected = expected._data._with_freq(None) + tm.assert_datetime_array_equal(result, expected) + + def test_array_interface(self, datetime_index): + arr = datetime_index._data + copy_false = None if np_version_gt2 else False + + # default asarray gives the same underlying data (for tz naive) + result = np.asarray(arr) + expected = arr._ndarray + assert result is expected + tm.assert_numpy_array_equal(result, expected) + result = np.array(arr, copy=copy_false) + assert result is expected + tm.assert_numpy_array_equal(result, expected) + + # specifying M8[ns] gives the same result as default + result = np.asarray(arr, dtype="datetime64[ns]") + expected = arr._ndarray + assert result is expected + tm.assert_numpy_array_equal(result, expected) + result = np.array(arr, dtype="datetime64[ns]", copy=copy_false) + assert result is expected + tm.assert_numpy_array_equal(result, expected) + result = np.array(arr, dtype="datetime64[ns]") + if not np_version_gt2: + # TODO: GH 57739 + assert result is not expected + tm.assert_numpy_array_equal(result, expected) + + # to object dtype + result = np.asarray(arr, dtype=object) + expected = np.array(list(arr), dtype=object) + tm.assert_numpy_array_equal(result, expected) + + # to other dtype always copies + result = np.asarray(arr, dtype="int64") + assert result is not arr.asi8 + assert not np.may_share_memory(arr, result) + expected = arr.asi8.copy() + tm.assert_numpy_array_equal(result, expected) + + # other dtypes handled by numpy + for dtype in ["float64", str]: + result = np.asarray(arr, dtype=dtype) + expected = np.asarray(arr).astype(dtype) + tm.assert_numpy_array_equal(result, expected) + + def test_array_object_dtype(self, arr1d): + # GH#23524 + arr = arr1d + dti = self.index_cls(arr1d) + + expected = np.array(list(dti)) + + result = np.array(arr, dtype=object) + tm.assert_numpy_array_equal(result, expected) + + # also test the DatetimeIndex method while we're at it + result = np.array(dti, dtype=object) + tm.assert_numpy_array_equal(result, expected) + + def test_array_tz(self, arr1d): + # GH#23524 + arr = arr1d + dti = self.index_cls(arr1d) + copy_false = None if np_version_gt2 else False + + expected = dti.asi8.view("M8[ns]") + result = np.array(arr, dtype="M8[ns]") + tm.assert_numpy_array_equal(result, expected) + + result = np.array(arr, dtype="datetime64[ns]") + tm.assert_numpy_array_equal(result, expected) + + # check that we are not making copies when setting copy=copy_false + result = np.array(arr, dtype="M8[ns]", copy=copy_false) + assert result.base is expected.base + assert result.base is not None + result = np.array(arr, dtype="datetime64[ns]", copy=copy_false) + assert result.base is expected.base + assert result.base is not None + + def test_array_i8_dtype(self, arr1d): + arr = arr1d + dti = self.index_cls(arr1d) + copy_false = None if np_version_gt2 else False + + expected = dti.asi8 + result = np.array(arr, dtype="i8") + tm.assert_numpy_array_equal(result, expected) + + result = np.array(arr, dtype=np.int64) + tm.assert_numpy_array_equal(result, expected) + + # check that we are still making copies when setting copy=copy_false + result = np.array(arr, dtype="i8", copy=copy_false) + assert result.base is not expected.base + assert result.base is None + + def test_from_array_keeps_base(self): + # Ensure that DatetimeArray._ndarray.base isn't lost. + arr = np.array(["2000-01-01", "2000-01-02"], dtype="M8[ns]") + dta = DatetimeArray._from_sequence(arr) + + assert dta._ndarray is arr + dta = DatetimeArray._from_sequence(arr[:0]) + assert dta._ndarray.base is arr + + def test_from_dti(self, arr1d): + arr = arr1d + dti = self.index_cls(arr1d) + assert list(dti) == list(arr) + + # Check that Index.__new__ knows what to do with DatetimeArray + dti2 = pd.Index(arr) + assert isinstance(dti2, DatetimeIndex) + assert list(dti2) == list(arr) + + def test_astype_object(self, arr1d): + arr = arr1d + dti = self.index_cls(arr1d) + + asobj = arr.astype("O") + assert isinstance(asobj, np.ndarray) + assert asobj.dtype == "O" + assert list(asobj) == list(dti) + + @pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning") + def test_to_period(self, datetime_index, freqstr): + dti = datetime_index + arr = dti._data + + freqstr = freq_to_period_freqstr(1, freqstr) + expected = dti.to_period(freq=freqstr) + result = arr.to_period(freq=freqstr) + assert isinstance(result, PeriodArray) + + tm.assert_equal(result, expected._data) + + def test_to_period_2d(self, arr1d): + arr2d = arr1d.reshape(1, -1) + + warn = None if arr1d.tz is None else UserWarning + with tm.assert_produces_warning(warn): + result = arr2d.to_period("D") + expected = arr1d.to_period("D").reshape(1, -1) + tm.assert_period_array_equal(result, expected) + + @pytest.mark.parametrize("propname", DatetimeArray._bool_ops) + def test_bool_properties(self, arr1d, propname): + # in this case _bool_ops is just `is_leap_year` + dti = self.index_cls(arr1d) + arr = arr1d + assert dti.freq == arr.freq + + result = getattr(arr, propname) + expected = np.array(getattr(dti, propname), dtype=result.dtype) + + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize("propname", DatetimeArray._field_ops) + def test_int_properties(self, arr1d, propname): + dti = self.index_cls(arr1d) + arr = arr1d + + result = getattr(arr, propname) + expected = np.array(getattr(dti, propname), dtype=result.dtype) + + tm.assert_numpy_array_equal(result, expected) + + def test_take_fill_valid(self, arr1d, fixed_now_ts): + arr = arr1d + dti = self.index_cls(arr1d) + + now = fixed_now_ts.tz_localize(dti.tz) + result = arr.take([-1, 1], allow_fill=True, fill_value=now) + assert result[0] == now + + msg = f"value should be a '{arr1d._scalar_type.__name__}' or 'NaT'. Got" + with pytest.raises(TypeError, match=msg): + # fill_value Timedelta invalid + arr.take([-1, 1], allow_fill=True, fill_value=now - now) + + with pytest.raises(TypeError, match=msg): + # fill_value Period invalid + arr.take([-1, 1], allow_fill=True, fill_value=Period("2014Q1")) + + tz = None if dti.tz is not None else "US/Eastern" + now = fixed_now_ts.tz_localize(tz) + msg = "Cannot compare tz-naive and tz-aware datetime-like objects" + with pytest.raises(TypeError, match=msg): + # Timestamp with mismatched tz-awareness + arr.take([-1, 1], allow_fill=True, fill_value=now) + + value = NaT._value + msg = f"value should be a '{arr1d._scalar_type.__name__}' or 'NaT'. Got" + with pytest.raises(TypeError, match=msg): + # require NaT, not iNaT, as it could be confused with an integer + arr.take([-1, 1], allow_fill=True, fill_value=value) + + value = np.timedelta64("NaT", "ns") + with pytest.raises(TypeError, match=msg): + # require appropriate-dtype if we have a NA value + arr.take([-1, 1], allow_fill=True, fill_value=value) + + if arr.tz is not None: + # GH#37356 + # Assuming here that arr1d fixture does not include Australia/Melbourne + value = fixed_now_ts.tz_localize("Australia/Melbourne") + result = arr.take([-1, 1], allow_fill=True, fill_value=value) + + expected = arr.take( + [-1, 1], + allow_fill=True, + fill_value=value.tz_convert(arr.dtype.tz), + ) + tm.assert_equal(result, expected) + + def test_concat_same_type_invalid(self, arr1d): + # different timezones + arr = arr1d + + if arr.tz is None: + other = arr.tz_localize("UTC") + else: + other = arr.tz_localize(None) + + with pytest.raises(ValueError, match="to_concat must have the same"): + arr._concat_same_type([arr, other]) + + def test_concat_same_type_different_freq(self, unit): + # we *can* concatenate DTI with different freqs. + a = pd.date_range("2000", periods=2, freq="D", tz="US/Central", unit=unit)._data + b = pd.date_range("2000", periods=2, freq="h", tz="US/Central", unit=unit)._data + result = DatetimeArray._concat_same_type([a, b]) + expected = ( + pd.to_datetime( + [ + "2000-01-01 00:00:00", + "2000-01-02 00:00:00", + "2000-01-01 00:00:00", + "2000-01-01 01:00:00", + ] + ) + .tz_localize("US/Central") + .as_unit(unit) + ._data + ) + + tm.assert_datetime_array_equal(result, expected) + + def test_strftime(self, arr1d): + arr = arr1d + + result = arr.strftime("%Y %b") + expected = np.array([ts.strftime("%Y %b") for ts in arr], dtype=object) + tm.assert_numpy_array_equal(result, expected) + + def test_strftime_nat(self): + # GH 29578 + arr = DatetimeIndex(["2019-01-01", NaT])._data + + result = arr.strftime("%Y-%m-%d") + expected = np.array(["2019-01-01", np.nan], dtype=object) + tm.assert_numpy_array_equal(result, expected) + + +class TestTimedeltaArray(SharedTests): + index_cls = TimedeltaIndex + array_cls = TimedeltaArray + scalar_type = pd.Timedelta + example_dtype = "m8[ns]" + + def test_from_tdi(self): + tdi = TimedeltaIndex(["1 Day", "3 Hours"]) + arr = tdi._data + assert list(arr) == list(tdi) + + # Check that Index.__new__ knows what to do with TimedeltaArray + tdi2 = pd.Index(arr) + assert isinstance(tdi2, TimedeltaIndex) + assert list(tdi2) == list(arr) + + def test_astype_object(self): + tdi = TimedeltaIndex(["1 Day", "3 Hours"]) + arr = tdi._data + asobj = arr.astype("O") + assert isinstance(asobj, np.ndarray) + assert asobj.dtype == "O" + assert list(asobj) == list(tdi) + + def test_to_pytimedelta(self, timedelta_index): + tdi = timedelta_index + arr = tdi._data + + expected = tdi.to_pytimedelta() + result = arr.to_pytimedelta() + + tm.assert_numpy_array_equal(result, expected) + + def test_total_seconds(self, timedelta_index): + tdi = timedelta_index + arr = tdi._data + + expected = tdi.total_seconds() + result = arr.total_seconds() + + tm.assert_numpy_array_equal(result, expected.values) + + @pytest.mark.parametrize("propname", TimedeltaArray._field_ops) + def test_int_properties(self, timedelta_index, propname): + tdi = timedelta_index + arr = tdi._data + + result = getattr(arr, propname) + expected = np.array(getattr(tdi, propname), dtype=result.dtype) + + tm.assert_numpy_array_equal(result, expected) + + def test_array_interface(self, timedelta_index): + arr = timedelta_index._data + copy_false = None if np_version_gt2 else False + + # default asarray gives the same underlying data + result = np.asarray(arr) + expected = arr._ndarray + assert result is expected + tm.assert_numpy_array_equal(result, expected) + result = np.array(arr, copy=copy_false) + assert result is expected + tm.assert_numpy_array_equal(result, expected) + + # specifying m8[ns] gives the same result as default + result = np.asarray(arr, dtype="timedelta64[ns]") + expected = arr._ndarray + assert result is expected + tm.assert_numpy_array_equal(result, expected) + result = np.array(arr, dtype="timedelta64[ns]", copy=copy_false) + assert result is expected + tm.assert_numpy_array_equal(result, expected) + result = np.array(arr, dtype="timedelta64[ns]") + if not np_version_gt2: + # TODO: GH 57739 + assert result is not expected + tm.assert_numpy_array_equal(result, expected) + + # to object dtype + result = np.asarray(arr, dtype=object) + expected = np.array(list(arr), dtype=object) + tm.assert_numpy_array_equal(result, expected) + + # to other dtype always copies + result = np.asarray(arr, dtype="int64") + assert result is not arr.asi8 + assert not np.may_share_memory(arr, result) + expected = arr.asi8.copy() + tm.assert_numpy_array_equal(result, expected) + + # other dtypes handled by numpy + for dtype in ["float64", str]: + result = np.asarray(arr, dtype=dtype) + expected = np.asarray(arr).astype(dtype) + tm.assert_numpy_array_equal(result, expected) + + def test_take_fill_valid(self, timedelta_index, fixed_now_ts): + tdi = timedelta_index + arr = tdi._data + + td1 = pd.Timedelta(days=1) + result = arr.take([-1, 1], allow_fill=True, fill_value=td1) + assert result[0] == td1 + + value = fixed_now_ts + msg = f"value should be a '{arr._scalar_type.__name__}' or 'NaT'. Got" + with pytest.raises(TypeError, match=msg): + # fill_value Timestamp invalid + arr.take([0, 1], allow_fill=True, fill_value=value) + + value = fixed_now_ts.to_period("D") + with pytest.raises(TypeError, match=msg): + # fill_value Period invalid + arr.take([0, 1], allow_fill=True, fill_value=value) + + value = np.datetime64("NaT", "ns") + with pytest.raises(TypeError, match=msg): + # require appropriate-dtype if we have a NA value + arr.take([-1, 1], allow_fill=True, fill_value=value) + + +@pytest.mark.filterwarnings(r"ignore:Period with BDay freq is deprecated:FutureWarning") +@pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning") +class TestPeriodArray(SharedTests): + index_cls = PeriodIndex + array_cls = PeriodArray + scalar_type = Period + example_dtype = PeriodIndex([], freq="W").dtype + + @pytest.fixture + def arr1d(self, period_index): + """ + Fixture returning DatetimeArray from parametrized PeriodIndex objects + """ + return period_index._data + + def test_from_pi(self, arr1d): + pi = self.index_cls(arr1d) + arr = arr1d + assert list(arr) == list(pi) + + # Check that Index.__new__ knows what to do with PeriodArray + pi2 = pd.Index(arr) + assert isinstance(pi2, PeriodIndex) + assert list(pi2) == list(arr) + + def test_astype_object(self, arr1d): + pi = self.index_cls(arr1d) + arr = arr1d + asobj = arr.astype("O") + assert isinstance(asobj, np.ndarray) + assert asobj.dtype == "O" + assert list(asobj) == list(pi) + + def test_take_fill_valid(self, arr1d): + arr = arr1d + + value = NaT._value + msg = f"value should be a '{arr1d._scalar_type.__name__}' or 'NaT'. Got" + with pytest.raises(TypeError, match=msg): + # require NaT, not iNaT, as it could be confused with an integer + arr.take([-1, 1], allow_fill=True, fill_value=value) + + value = np.timedelta64("NaT", "ns") + with pytest.raises(TypeError, match=msg): + # require appropriate-dtype if we have a NA value + arr.take([-1, 1], allow_fill=True, fill_value=value) + + @pytest.mark.parametrize("how", ["S", "E"]) + def test_to_timestamp(self, how, arr1d): + pi = self.index_cls(arr1d) + arr = arr1d + + expected = DatetimeIndex(pi.to_timestamp(how=how))._data + result = arr.to_timestamp(how=how) + assert isinstance(result, DatetimeArray) + + tm.assert_equal(result, expected) + + def test_to_timestamp_roundtrip_bday(self): + # Case where infer_freq inside would choose "D" instead of "B" + dta = pd.date_range("2021-10-18", periods=3, freq="B")._data + parr = dta.to_period() + result = parr.to_timestamp() + assert result.freq == "B" + tm.assert_extension_array_equal(result, dta) + + dta2 = dta[::2] + parr2 = dta2.to_period() + result2 = parr2.to_timestamp() + assert result2.freq == "2B" + tm.assert_extension_array_equal(result2, dta2) + + parr3 = dta.to_period("2B") + result3 = parr3.to_timestamp() + assert result3.freq == "B" + tm.assert_extension_array_equal(result3, dta) + + def test_to_timestamp_out_of_bounds(self): + # GH#19643 previously overflowed silently + pi = pd.period_range("1500", freq="Y", periods=3) + msg = "Out of bounds nanosecond timestamp: 1500-01-01 00:00:00" + with pytest.raises(OutOfBoundsDatetime, match=msg): + pi.to_timestamp() + + with pytest.raises(OutOfBoundsDatetime, match=msg): + pi._data.to_timestamp() + + @pytest.mark.parametrize("propname", PeriodArray._bool_ops) + def test_bool_properties(self, arr1d, propname): + # in this case _bool_ops is just `is_leap_year` + pi = self.index_cls(arr1d) + arr = arr1d + + result = getattr(arr, propname) + expected = np.array(getattr(pi, propname)) + + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize("propname", PeriodArray._field_ops) + def test_int_properties(self, arr1d, propname): + pi = self.index_cls(arr1d) + arr = arr1d + + result = getattr(arr, propname) + expected = np.array(getattr(pi, propname)) + + tm.assert_numpy_array_equal(result, expected) + + def test_array_interface(self, arr1d): + arr = arr1d + + # default asarray gives objects + result = np.asarray(arr) + expected = np.array(list(arr), dtype=object) + tm.assert_numpy_array_equal(result, expected) + + # to object dtype (same as default) + result = np.asarray(arr, dtype=object) + tm.assert_numpy_array_equal(result, expected) + + result = np.asarray(arr, dtype="int64") + tm.assert_numpy_array_equal(result, arr.asi8) + + # to other dtypes + msg = r"float\(\) argument must be a string or a( real)? number, not 'Period'" + with pytest.raises(TypeError, match=msg): + np.asarray(arr, dtype="float64") + + result = np.asarray(arr, dtype="S20") + expected = np.asarray(arr).astype("S20") + tm.assert_numpy_array_equal(result, expected) + + def test_strftime(self, arr1d): + arr = arr1d + + result = arr.strftime("%Y") + expected = np.array([per.strftime("%Y") for per in arr], dtype=object) + tm.assert_numpy_array_equal(result, expected) + + def test_strftime_nat(self): + # GH 29578 + arr = PeriodArray(PeriodIndex(["2019-01-01", NaT], dtype="period[D]")) + + result = arr.strftime("%Y-%m-%d") + expected = np.array(["2019-01-01", np.nan], dtype=object) + tm.assert_numpy_array_equal(result, expected) + + +@pytest.mark.parametrize( + "arr,casting_nats", + [ + ( + TimedeltaIndex(["1 Day", "3 Hours", "NaT"])._data, + (NaT, np.timedelta64("NaT", "ns")), + ), + ( + pd.date_range("2000-01-01", periods=3, freq="D")._data, + (NaT, np.datetime64("NaT", "ns")), + ), + (pd.period_range("2000-01-01", periods=3, freq="D")._data, (NaT,)), + ], + ids=lambda x: type(x).__name__, +) +def test_casting_nat_setitem_array(arr, casting_nats): + expected = type(arr)._from_sequence([NaT, arr[1], arr[2]], dtype=arr.dtype) + + for nat in casting_nats: + arr = arr.copy() + arr[0] = nat + tm.assert_equal(arr, expected) + + +@pytest.mark.parametrize( + "arr,non_casting_nats", + [ + ( + TimedeltaIndex(["1 Day", "3 Hours", "NaT"])._data, + (np.datetime64("NaT", "ns"), NaT._value), + ), + ( + pd.date_range("2000-01-01", periods=3, freq="D")._data, + (np.timedelta64("NaT", "ns"), NaT._value), + ), + ( + pd.period_range("2000-01-01", periods=3, freq="D")._data, + (np.datetime64("NaT", "ns"), np.timedelta64("NaT", "ns"), NaT._value), + ), + ], + ids=lambda x: type(x).__name__, +) +def test_invalid_nat_setitem_array(arr, non_casting_nats): + msg = ( + "value should be a '(Timestamp|Timedelta|Period)', 'NaT', or array of those. " + "Got '(timedelta64|datetime64|int)' instead." + ) + + for nat in non_casting_nats: + with pytest.raises(TypeError, match=msg): + arr[0] = nat + + +@pytest.mark.parametrize( + "arr", + [ + pd.date_range("2000", periods=4).array, + pd.timedelta_range("2000", periods=4).array, + ], +) +def test_to_numpy_extra(arr): + arr[0] = NaT + original = arr.copy() + + result = arr.to_numpy() + assert np.isnan(result[0]) + + result = arr.to_numpy(dtype="int64") + assert result[0] == -9223372036854775808 + + result = arr.to_numpy(dtype="int64", na_value=0) + assert result[0] == 0 + + result = arr.to_numpy(na_value=arr[1].to_numpy()) + assert result[0] == result[1] + + result = arr.to_numpy(na_value=arr[1].to_numpy(copy=False)) + assert result[0] == result[1] + + tm.assert_equal(arr, original) + + +@pytest.mark.parametrize("as_index", [True, False]) +@pytest.mark.parametrize( + "values", + [ + pd.to_datetime(["2020-01-01", "2020-02-01"]), + pd.to_timedelta([1, 2], unit="D"), + PeriodIndex(["2020-01-01", "2020-02-01"], freq="D"), + ], +) +@pytest.mark.parametrize( + "klass", + [ + list, + np.array, + pd.array, + pd.Series, + pd.Index, + pd.Categorical, + pd.CategoricalIndex, + ], +) +def test_searchsorted_datetimelike_with_listlike(values, klass, as_index): + # https://github.com/pandas-dev/pandas/issues/32762 + if not as_index: + values = values._data + + result = values.searchsorted(klass(values)) + expected = np.array([0, 1], dtype=result.dtype) + + tm.assert_numpy_array_equal(result, expected) + + +@pytest.mark.parametrize( + "values", + [ + pd.to_datetime(["2020-01-01", "2020-02-01"]), + pd.to_timedelta([1, 2], unit="D"), + PeriodIndex(["2020-01-01", "2020-02-01"], freq="D"), + ], +) +@pytest.mark.parametrize( + "arg", [[1, 2], ["a", "b"], [Timestamp("2020-01-01", tz="Europe/London")] * 2] +) +def test_searchsorted_datetimelike_with_listlike_invalid_dtype(values, arg): + # https://github.com/pandas-dev/pandas/issues/32762 + msg = "[Unexpected type|Cannot compare]" + with pytest.raises(TypeError, match=msg): + values.searchsorted(arg) + + +@pytest.mark.parametrize("klass", [list, tuple, np.array, pd.Series]) +def test_period_index_construction_from_strings(klass): + # https://github.com/pandas-dev/pandas/issues/26109 + strings = ["2020Q1", "2020Q2"] * 2 + data = klass(strings) + result = PeriodIndex(data, freq="Q") + expected = PeriodIndex([Period(s) for s in strings]) + tm.assert_index_equal(result, expected) + + +@pytest.mark.parametrize("dtype", ["M8[ns]", "m8[ns]"]) +def test_from_pandas_array(dtype): + # GH#24615 + data = np.array([1, 2, 3], dtype=dtype) + arr = NumpyExtensionArray(data) + + cls = {"M8[ns]": DatetimeArray, "m8[ns]": TimedeltaArray}[dtype] + + depr_msg = f"{cls.__name__}.__init__ is deprecated" + with tm.assert_produces_warning(FutureWarning, match=depr_msg): + result = cls(arr) + expected = cls(data) + tm.assert_extension_array_equal(result, expected) + + result = cls._from_sequence(arr, dtype=dtype) + expected = cls._from_sequence(data, dtype=dtype) + tm.assert_extension_array_equal(result, expected) + + func = {"M8[ns]": pd.to_datetime, "m8[ns]": pd.to_timedelta}[dtype] + result = func(arr).array + expected = func(data).array + tm.assert_equal(result, expected) + + # Let's check the Indexes while we're here + idx_cls = {"M8[ns]": DatetimeIndex, "m8[ns]": TimedeltaIndex}[dtype] + result = idx_cls(arr) + expected = idx_cls(data) + tm.assert_index_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/arrays/test_datetimes.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/arrays/test_datetimes.py new file mode 100644 index 0000000000000000000000000000000000000000..8f0576cc65a2787edacdb1e377a02287d1caaff1 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/arrays/test_datetimes.py @@ -0,0 +1,840 @@ +""" +Tests for DatetimeArray +""" +from __future__ import annotations + +from datetime import timedelta +import operator + +try: + from zoneinfo import ZoneInfo +except ImportError: + # Cannot assign to a type + ZoneInfo = None # type: ignore[misc, assignment] + +import numpy as np +import pytest + +from pandas._libs.tslibs import tz_compare + +from pandas.core.dtypes.dtypes import DatetimeTZDtype + +import pandas as pd +import pandas._testing as tm +from pandas.core.arrays import ( + DatetimeArray, + TimedeltaArray, +) + + +class TestNonNano: + @pytest.fixture(params=["s", "ms", "us"]) + def unit(self, request): + """Fixture returning parametrized time units""" + return request.param + + @pytest.fixture + def dtype(self, unit, tz_naive_fixture): + tz = tz_naive_fixture + if tz is None: + return np.dtype(f"datetime64[{unit}]") + else: + return DatetimeTZDtype(unit=unit, tz=tz) + + @pytest.fixture + def dta_dti(self, unit, dtype): + tz = getattr(dtype, "tz", None) + + dti = pd.date_range("2016-01-01", periods=55, freq="D", tz=tz) + if tz is None: + arr = np.asarray(dti).astype(f"M8[{unit}]") + else: + arr = np.asarray(dti.tz_convert("UTC").tz_localize(None)).astype( + f"M8[{unit}]" + ) + + dta = DatetimeArray._simple_new(arr, dtype=dtype) + return dta, dti + + @pytest.fixture + def dta(self, dta_dti): + dta, dti = dta_dti + return dta + + def test_non_nano(self, unit, dtype): + arr = np.arange(5, dtype=np.int64).view(f"M8[{unit}]") + dta = DatetimeArray._simple_new(arr, dtype=dtype) + + assert dta.dtype == dtype + assert dta[0].unit == unit + assert tz_compare(dta.tz, dta[0].tz) + assert (dta[0] == dta[:1]).all() + + @pytest.mark.parametrize( + "field", DatetimeArray._field_ops + DatetimeArray._bool_ops + ) + def test_fields(self, unit, field, dtype, dta_dti): + dta, dti = dta_dti + + assert (dti == dta).all() + + res = getattr(dta, field) + expected = getattr(dti._data, field) + tm.assert_numpy_array_equal(res, expected) + + def test_normalize(self, unit): + dti = pd.date_range("2016-01-01 06:00:00", periods=55, freq="D") + arr = np.asarray(dti).astype(f"M8[{unit}]") + + dta = DatetimeArray._simple_new(arr, dtype=arr.dtype) + + assert not dta.is_normalized + + # TODO: simplify once we can just .astype to other unit + exp = np.asarray(dti.normalize()).astype(f"M8[{unit}]") + expected = DatetimeArray._simple_new(exp, dtype=exp.dtype) + + res = dta.normalize() + tm.assert_extension_array_equal(res, expected) + + def test_simple_new_requires_match(self, unit): + arr = np.arange(5, dtype=np.int64).view(f"M8[{unit}]") + dtype = DatetimeTZDtype(unit, "UTC") + + dta = DatetimeArray._simple_new(arr, dtype=dtype) + assert dta.dtype == dtype + + wrong = DatetimeTZDtype("ns", "UTC") + with pytest.raises(AssertionError, match=""): + DatetimeArray._simple_new(arr, dtype=wrong) + + def test_std_non_nano(self, unit): + dti = pd.date_range("2016-01-01", periods=55, freq="D") + arr = np.asarray(dti).astype(f"M8[{unit}]") + + dta = DatetimeArray._simple_new(arr, dtype=arr.dtype) + + # we should match the nano-reso std, but floored to our reso. + res = dta.std() + assert res._creso == dta._creso + assert res == dti.std().floor(unit) + + @pytest.mark.filterwarnings("ignore:Converting to PeriodArray.*:UserWarning") + def test_to_period(self, dta_dti): + dta, dti = dta_dti + result = dta.to_period("D") + expected = dti._data.to_period("D") + + tm.assert_extension_array_equal(result, expected) + + def test_iter(self, dta): + res = next(iter(dta)) + expected = dta[0] + + assert type(res) is pd.Timestamp + assert res._value == expected._value + assert res._creso == expected._creso + assert res == expected + + def test_astype_object(self, dta): + result = dta.astype(object) + assert all(x._creso == dta._creso for x in result) + assert all(x == y for x, y in zip(result, dta)) + + def test_to_pydatetime(self, dta_dti): + dta, dti = dta_dti + + result = dta.to_pydatetime() + expected = dti.to_pydatetime() + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize("meth", ["time", "timetz", "date"]) + def test_time_date(self, dta_dti, meth): + dta, dti = dta_dti + + result = getattr(dta, meth) + expected = getattr(dti, meth) + tm.assert_numpy_array_equal(result, expected) + + def test_format_native_types(self, unit, dtype, dta_dti): + # In this case we should get the same formatted values with our nano + # version dti._data as we do with the non-nano dta + dta, dti = dta_dti + + res = dta._format_native_types() + exp = dti._data._format_native_types() + tm.assert_numpy_array_equal(res, exp) + + def test_repr(self, dta_dti, unit): + dta, dti = dta_dti + + assert repr(dta) == repr(dti._data).replace("[ns", f"[{unit}") + + # TODO: tests with td64 + def test_compare_mismatched_resolutions(self, comparison_op): + # comparison that numpy gets wrong bc of silent overflows + op = comparison_op + + iinfo = np.iinfo(np.int64) + vals = np.array([iinfo.min, iinfo.min + 1, iinfo.max], dtype=np.int64) + + # Construct so that arr2[1] < arr[1] < arr[2] < arr2[2] + arr = np.array(vals).view("M8[ns]") + arr2 = arr.view("M8[s]") + + left = DatetimeArray._simple_new(arr, dtype=arr.dtype) + right = DatetimeArray._simple_new(arr2, dtype=arr2.dtype) + + if comparison_op is operator.eq: + expected = np.array([False, False, False]) + elif comparison_op is operator.ne: + expected = np.array([True, True, True]) + elif comparison_op in [operator.lt, operator.le]: + expected = np.array([False, False, True]) + else: + expected = np.array([False, True, False]) + + result = op(left, right) + tm.assert_numpy_array_equal(result, expected) + + result = op(left[1], right) + tm.assert_numpy_array_equal(result, expected) + + if op not in [operator.eq, operator.ne]: + # check that numpy still gets this wrong; if it is fixed we may be + # able to remove compare_mismatched_resolutions + np_res = op(left._ndarray, right._ndarray) + tm.assert_numpy_array_equal(np_res[1:], ~expected[1:]) + + def test_add_mismatched_reso_doesnt_downcast(self): + # https://github.com/pandas-dev/pandas/pull/48748#issuecomment-1260181008 + td = pd.Timedelta(microseconds=1) + dti = pd.date_range("2016-01-01", periods=3) - td + dta = dti._data.as_unit("us") + + res = dta + td.as_unit("us") + # even though the result is an even number of days + # (so we _could_ downcast to unit="s"), we do not. + assert res.unit == "us" + + @pytest.mark.parametrize( + "scalar", + [ + timedelta(hours=2), + pd.Timedelta(hours=2), + np.timedelta64(2, "h"), + np.timedelta64(2 * 3600 * 1000, "ms"), + pd.offsets.Minute(120), + pd.offsets.Hour(2), + ], + ) + def test_add_timedeltalike_scalar_mismatched_reso(self, dta_dti, scalar): + dta, dti = dta_dti + + td = pd.Timedelta(scalar) + exp_unit = tm.get_finest_unit(dta.unit, td.unit) + + expected = (dti + td)._data.as_unit(exp_unit) + result = dta + scalar + tm.assert_extension_array_equal(result, expected) + + result = scalar + dta + tm.assert_extension_array_equal(result, expected) + + expected = (dti - td)._data.as_unit(exp_unit) + result = dta - scalar + tm.assert_extension_array_equal(result, expected) + + def test_sub_datetimelike_scalar_mismatch(self): + dti = pd.date_range("2016-01-01", periods=3) + dta = dti._data.as_unit("us") + + ts = dta[0].as_unit("s") + + result = dta - ts + expected = (dti - dti[0])._data.as_unit("us") + assert result.dtype == "m8[us]" + tm.assert_extension_array_equal(result, expected) + + def test_sub_datetime64_reso_mismatch(self): + dti = pd.date_range("2016-01-01", periods=3) + left = dti._data.as_unit("s") + right = left.as_unit("ms") + + result = left - right + exp_values = np.array([0, 0, 0], dtype="m8[ms]") + expected = TimedeltaArray._simple_new( + exp_values, + dtype=exp_values.dtype, + ) + tm.assert_extension_array_equal(result, expected) + result2 = right - left + tm.assert_extension_array_equal(result2, expected) + + +class TestDatetimeArrayComparisons: + # TODO: merge this into tests/arithmetic/test_datetime64 once it is + # sufficiently robust + + def test_cmp_dt64_arraylike_tznaive(self, comparison_op): + # arbitrary tz-naive DatetimeIndex + op = comparison_op + + dti = pd.date_range("2016-01-1", freq="MS", periods=9, tz=None) + arr = dti._data + assert arr.freq == dti.freq + assert arr.tz == dti.tz + + right = dti + + expected = np.ones(len(arr), dtype=bool) + if comparison_op.__name__ in ["ne", "gt", "lt"]: + # for these the comparisons should be all-False + expected = ~expected + + result = op(arr, arr) + tm.assert_numpy_array_equal(result, expected) + for other in [ + right, + np.array(right), + list(right), + tuple(right), + right.astype(object), + ]: + result = op(arr, other) + tm.assert_numpy_array_equal(result, expected) + + result = op(other, arr) + tm.assert_numpy_array_equal(result, expected) + + +class TestDatetimeArray: + def test_astype_ns_to_ms_near_bounds(self): + # GH#55979 + ts = pd.Timestamp("1677-09-21 00:12:43.145225") + target = ts.as_unit("ms") + + dta = DatetimeArray._from_sequence([ts], dtype="M8[ns]") + assert (dta.view("i8") == ts.as_unit("ns").value).all() + + result = dta.astype("M8[ms]") + assert result[0] == target + + expected = DatetimeArray._from_sequence([ts], dtype="M8[ms]") + assert (expected.view("i8") == target._value).all() + + tm.assert_datetime_array_equal(result, expected) + + def test_astype_non_nano_tznaive(self): + dti = pd.date_range("2016-01-01", periods=3) + + res = dti.astype("M8[s]") + assert res.dtype == "M8[s]" + + dta = dti._data + res = dta.astype("M8[s]") + assert res.dtype == "M8[s]" + assert isinstance(res, pd.core.arrays.DatetimeArray) # used to be ndarray + + def test_astype_non_nano_tzaware(self): + dti = pd.date_range("2016-01-01", periods=3, tz="UTC") + + res = dti.astype("M8[s, US/Pacific]") + assert res.dtype == "M8[s, US/Pacific]" + + dta = dti._data + res = dta.astype("M8[s, US/Pacific]") + assert res.dtype == "M8[s, US/Pacific]" + + # from non-nano to non-nano, preserving reso + res2 = res.astype("M8[s, UTC]") + assert res2.dtype == "M8[s, UTC]" + assert not tm.shares_memory(res2, res) + + res3 = res.astype("M8[s, UTC]", copy=False) + assert res2.dtype == "M8[s, UTC]" + assert tm.shares_memory(res3, res) + + def test_astype_to_same(self): + arr = DatetimeArray._from_sequence( + ["2000"], dtype=DatetimeTZDtype(tz="US/Central") + ) + result = arr.astype(DatetimeTZDtype(tz="US/Central"), copy=False) + assert result is arr + + @pytest.mark.parametrize("dtype", ["datetime64[ns]", "datetime64[ns, UTC]"]) + @pytest.mark.parametrize( + "other", ["datetime64[ns]", "datetime64[ns, UTC]", "datetime64[ns, CET]"] + ) + def test_astype_copies(self, dtype, other): + # https://github.com/pandas-dev/pandas/pull/32490 + ser = pd.Series([1, 2], dtype=dtype) + orig = ser.copy() + + err = False + if (dtype == "datetime64[ns]") ^ (other == "datetime64[ns]"): + # deprecated in favor of tz_localize + err = True + + if err: + if dtype == "datetime64[ns]": + msg = "Use obj.tz_localize instead or series.dt.tz_localize instead" + else: + msg = "from timezone-aware dtype to timezone-naive dtype" + with pytest.raises(TypeError, match=msg): + ser.astype(other) + else: + t = ser.astype(other) + t[:] = pd.NaT + tm.assert_series_equal(ser, orig) + + @pytest.mark.parametrize("dtype", [int, np.int32, np.int64, "uint32", "uint64"]) + def test_astype_int(self, dtype): + arr = DatetimeArray._from_sequence( + [pd.Timestamp("2000"), pd.Timestamp("2001")], dtype="M8[ns]" + ) + + if np.dtype(dtype) != np.int64: + with pytest.raises(TypeError, match=r"Do obj.astype\('int64'\)"): + arr.astype(dtype) + return + + result = arr.astype(dtype) + expected = arr._ndarray.view("i8") + tm.assert_numpy_array_equal(result, expected) + + def test_astype_to_sparse_dt64(self): + # GH#50082 + dti = pd.date_range("2016-01-01", periods=4) + dta = dti._data + result = dta.astype("Sparse[datetime64[ns]]") + + assert result.dtype == "Sparse[datetime64[ns]]" + assert (result == dta).all() + + def test_tz_setter_raises(self): + arr = DatetimeArray._from_sequence( + ["2000"], dtype=DatetimeTZDtype(tz="US/Central") + ) + with pytest.raises(AttributeError, match="tz_localize"): + arr.tz = "UTC" + + def test_setitem_str_impute_tz(self, tz_naive_fixture): + # Like for getitem, if we are passed a naive-like string, we impute + # our own timezone. + tz = tz_naive_fixture + + data = np.array([1, 2, 3], dtype="M8[ns]") + dtype = data.dtype if tz is None else DatetimeTZDtype(tz=tz) + arr = DatetimeArray._from_sequence(data, dtype=dtype) + expected = arr.copy() + + ts = pd.Timestamp("2020-09-08 16:50").tz_localize(tz) + setter = str(ts.tz_localize(None)) + + # Setting a scalar tznaive string + expected[0] = ts + arr[0] = setter + tm.assert_equal(arr, expected) + + # Setting a listlike of tznaive strings + expected[1] = ts + arr[:2] = [setter, setter] + tm.assert_equal(arr, expected) + + def test_setitem_different_tz_raises(self): + # pre-2.0 we required exact tz match, in 2.0 we require only + # tzawareness-match + data = np.array([1, 2, 3], dtype="M8[ns]") + arr = DatetimeArray._from_sequence( + data, copy=False, dtype=DatetimeTZDtype(tz="US/Central") + ) + with pytest.raises(TypeError, match="Cannot compare tz-naive and tz-aware"): + arr[0] = pd.Timestamp("2000") + + ts = pd.Timestamp("2000", tz="US/Eastern") + arr[0] = ts + assert arr[0] == ts.tz_convert("US/Central") + + def test_setitem_clears_freq(self): + a = pd.date_range("2000", periods=2, freq="D", tz="US/Central")._data + a[0] = pd.Timestamp("2000", tz="US/Central") + assert a.freq is None + + @pytest.mark.parametrize( + "obj", + [ + pd.Timestamp("2021-01-01"), + pd.Timestamp("2021-01-01").to_datetime64(), + pd.Timestamp("2021-01-01").to_pydatetime(), + ], + ) + def test_setitem_objects(self, obj): + # make sure we accept datetime64 and datetime in addition to Timestamp + dti = pd.date_range("2000", periods=2, freq="D") + arr = dti._data + + arr[0] = obj + assert arr[0] == obj + + def test_repeat_preserves_tz(self): + dti = pd.date_range("2000", periods=2, freq="D", tz="US/Central") + arr = dti._data + + repeated = arr.repeat([1, 1]) + + # preserves tz and values, but not freq + expected = DatetimeArray._from_sequence(arr.asi8, dtype=arr.dtype) + tm.assert_equal(repeated, expected) + + def test_value_counts_preserves_tz(self): + dti = pd.date_range("2000", periods=2, freq="D", tz="US/Central") + arr = dti._data.repeat([4, 3]) + + result = arr.value_counts() + + # Note: not tm.assert_index_equal, since `freq`s do not match + assert result.index.equals(dti) + + arr[-2] = pd.NaT + result = arr.value_counts(dropna=False) + expected = pd.Series([4, 2, 1], index=[dti[0], dti[1], pd.NaT], name="count") + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("method", ["pad", "backfill"]) + def test_fillna_preserves_tz(self, method): + dti = pd.date_range("2000-01-01", periods=5, freq="D", tz="US/Central") + arr = DatetimeArray._from_sequence(dti, copy=True) + arr[2] = pd.NaT + + fill_val = dti[1] if method == "pad" else dti[3] + expected = DatetimeArray._from_sequence( + [dti[0], dti[1], fill_val, dti[3], dti[4]], + dtype=DatetimeTZDtype(tz="US/Central"), + ) + + result = arr._pad_or_backfill(method=method) + tm.assert_extension_array_equal(result, expected) + + # assert that arr and dti were not modified in-place + assert arr[2] is pd.NaT + assert dti[2] == pd.Timestamp("2000-01-03", tz="US/Central") + + def test_fillna_2d(self): + dti = pd.date_range("2016-01-01", periods=6, tz="US/Pacific") + dta = dti._data.reshape(3, 2).copy() + dta[0, 1] = pd.NaT + dta[1, 0] = pd.NaT + + res1 = dta._pad_or_backfill(method="pad") + expected1 = dta.copy() + expected1[1, 0] = dta[0, 0] + tm.assert_extension_array_equal(res1, expected1) + + res2 = dta._pad_or_backfill(method="backfill") + expected2 = dta.copy() + expected2 = dta.copy() + expected2[1, 0] = dta[2, 0] + expected2[0, 1] = dta[1, 1] + tm.assert_extension_array_equal(res2, expected2) + + # with different ordering for underlying ndarray; behavior should + # be unchanged + dta2 = dta._from_backing_data(dta._ndarray.copy(order="F")) + assert dta2._ndarray.flags["F_CONTIGUOUS"] + assert not dta2._ndarray.flags["C_CONTIGUOUS"] + tm.assert_extension_array_equal(dta, dta2) + + res3 = dta2._pad_or_backfill(method="pad") + tm.assert_extension_array_equal(res3, expected1) + + res4 = dta2._pad_or_backfill(method="backfill") + tm.assert_extension_array_equal(res4, expected2) + + # test the DataFrame method while we're here + df = pd.DataFrame(dta) + res = df.ffill() + expected = pd.DataFrame(expected1) + tm.assert_frame_equal(res, expected) + + res = df.bfill() + expected = pd.DataFrame(expected2) + tm.assert_frame_equal(res, expected) + + def test_array_interface_tz(self): + tz = "US/Central" + data = pd.date_range("2017", periods=2, tz=tz)._data + result = np.asarray(data) + + expected = np.array( + [ + pd.Timestamp("2017-01-01T00:00:00", tz=tz), + pd.Timestamp("2017-01-02T00:00:00", tz=tz), + ], + dtype=object, + ) + tm.assert_numpy_array_equal(result, expected) + + result = np.asarray(data, dtype=object) + tm.assert_numpy_array_equal(result, expected) + + result = np.asarray(data, dtype="M8[ns]") + + expected = np.array( + ["2017-01-01T06:00:00", "2017-01-02T06:00:00"], dtype="M8[ns]" + ) + tm.assert_numpy_array_equal(result, expected) + + def test_array_interface(self): + data = pd.date_range("2017", periods=2)._data + expected = np.array( + ["2017-01-01T00:00:00", "2017-01-02T00:00:00"], dtype="datetime64[ns]" + ) + + result = np.asarray(data) + tm.assert_numpy_array_equal(result, expected) + + result = np.asarray(data, dtype=object) + expected = np.array( + [pd.Timestamp("2017-01-01T00:00:00"), pd.Timestamp("2017-01-02T00:00:00")], + dtype=object, + ) + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize("index", [True, False]) + def test_searchsorted_different_tz(self, index): + data = np.arange(10, dtype="i8") * 24 * 3600 * 10**9 + arr = pd.DatetimeIndex(data, freq="D")._data.tz_localize("Asia/Tokyo") + if index: + arr = pd.Index(arr) + + expected = arr.searchsorted(arr[2]) + result = arr.searchsorted(arr[2].tz_convert("UTC")) + assert result == expected + + expected = arr.searchsorted(arr[2:6]) + result = arr.searchsorted(arr[2:6].tz_convert("UTC")) + tm.assert_equal(result, expected) + + @pytest.mark.parametrize("index", [True, False]) + def test_searchsorted_tzawareness_compat(self, index): + data = np.arange(10, dtype="i8") * 24 * 3600 * 10**9 + arr = pd.DatetimeIndex(data, freq="D")._data + if index: + arr = pd.Index(arr) + + mismatch = arr.tz_localize("Asia/Tokyo") + + msg = "Cannot compare tz-naive and tz-aware datetime-like objects" + with pytest.raises(TypeError, match=msg): + arr.searchsorted(mismatch[0]) + with pytest.raises(TypeError, match=msg): + arr.searchsorted(mismatch) + + with pytest.raises(TypeError, match=msg): + mismatch.searchsorted(arr[0]) + with pytest.raises(TypeError, match=msg): + mismatch.searchsorted(arr) + + @pytest.mark.parametrize( + "other", + [ + 1, + np.int64(1), + 1.0, + np.timedelta64("NaT"), + pd.Timedelta(days=2), + "invalid", + np.arange(10, dtype="i8") * 24 * 3600 * 10**9, + np.arange(10).view("timedelta64[ns]") * 24 * 3600 * 10**9, + pd.Timestamp("2021-01-01").to_period("D"), + ], + ) + @pytest.mark.parametrize("index", [True, False]) + def test_searchsorted_invalid_types(self, other, index): + data = np.arange(10, dtype="i8") * 24 * 3600 * 10**9 + arr = pd.DatetimeIndex(data, freq="D")._data + if index: + arr = pd.Index(arr) + + msg = "|".join( + [ + "searchsorted requires compatible dtype or scalar", + "value should be a 'Timestamp', 'NaT', or array of those. Got", + ] + ) + with pytest.raises(TypeError, match=msg): + arr.searchsorted(other) + + def test_shift_fill_value(self): + dti = pd.date_range("2016-01-01", periods=3) + + dta = dti._data + expected = DatetimeArray._from_sequence(np.roll(dta._ndarray, 1)) + + fv = dta[-1] + for fill_value in [fv, fv.to_pydatetime(), fv.to_datetime64()]: + result = dta.shift(1, fill_value=fill_value) + tm.assert_datetime_array_equal(result, expected) + + dta = dta.tz_localize("UTC") + expected = expected.tz_localize("UTC") + fv = dta[-1] + for fill_value in [fv, fv.to_pydatetime()]: + result = dta.shift(1, fill_value=fill_value) + tm.assert_datetime_array_equal(result, expected) + + def test_shift_value_tzawareness_mismatch(self): + dti = pd.date_range("2016-01-01", periods=3) + + dta = dti._data + + fv = dta[-1].tz_localize("UTC") + for invalid in [fv, fv.to_pydatetime()]: + with pytest.raises(TypeError, match="Cannot compare"): + dta.shift(1, fill_value=invalid) + + dta = dta.tz_localize("UTC") + fv = dta[-1].tz_localize(None) + for invalid in [fv, fv.to_pydatetime(), fv.to_datetime64()]: + with pytest.raises(TypeError, match="Cannot compare"): + dta.shift(1, fill_value=invalid) + + def test_shift_requires_tzmatch(self): + # pre-2.0 we required exact tz match, in 2.0 we require just + # matching tzawareness + dti = pd.date_range("2016-01-01", periods=3, tz="UTC") + dta = dti._data + + fill_value = pd.Timestamp("2020-10-18 18:44", tz="US/Pacific") + + result = dta.shift(1, fill_value=fill_value) + expected = dta.shift(1, fill_value=fill_value.tz_convert("UTC")) + tm.assert_equal(result, expected) + + def test_tz_localize_t2d(self): + dti = pd.date_range("1994-05-12", periods=12, tz="US/Pacific") + dta = dti._data.reshape(3, 4) + result = dta.tz_localize(None) + + expected = dta.ravel().tz_localize(None).reshape(dta.shape) + tm.assert_datetime_array_equal(result, expected) + + roundtrip = expected.tz_localize("US/Pacific") + tm.assert_datetime_array_equal(roundtrip, dta) + + easts = ["US/Eastern", "dateutil/US/Eastern"] + if ZoneInfo is not None: + try: + tz = ZoneInfo("US/Eastern") + except KeyError: + # no tzdata + pass + else: + # Argument 1 to "append" of "list" has incompatible type "ZoneInfo"; + # expected "str" + easts.append(tz) # type: ignore[arg-type] + + @pytest.mark.parametrize("tz", easts) + def test_iter_zoneinfo_fold(self, tz): + # GH#49684 + utc_vals = np.array( + [1320552000, 1320555600, 1320559200, 1320562800], dtype=np.int64 + ) + utc_vals *= 1_000_000_000 + + dta = DatetimeArray._from_sequence(utc_vals).tz_localize("UTC").tz_convert(tz) + + left = dta[2] + right = list(dta)[2] + assert str(left) == str(right) + # previously there was a bug where with non-pytz right would be + # Timestamp('2011-11-06 01:00:00-0400', tz='US/Eastern') + # while left would be + # Timestamp('2011-11-06 01:00:00-0500', tz='US/Eastern') + # The .value's would match (so they would compare as equal), + # but the folds would not + assert left.utcoffset() == right.utcoffset() + + # The same bug in ints_to_pydatetime affected .astype, so we test + # that here. + right2 = dta.astype(object)[2] + assert str(left) == str(right2) + assert left.utcoffset() == right2.utcoffset() + + @pytest.mark.parametrize( + "freq, freq_depr", + [ + ("2ME", "2M"), + ("2SME", "2SM"), + ("2SME", "2sm"), + ("2QE", "2Q"), + ("2QE-SEP", "2Q-SEP"), + ("1YE", "1Y"), + ("2YE-MAR", "2Y-MAR"), + ("1YE", "1A"), + ("2YE-MAR", "2A-MAR"), + ("2ME", "2m"), + ("2QE-SEP", "2q-sep"), + ("2YE-MAR", "2a-mar"), + ("2YE", "2y"), + ], + ) + def test_date_range_frequency_M_Q_Y_A_deprecated(self, freq, freq_depr): + # GH#9586, GH#54275 + depr_msg = f"'{freq_depr[1:]}' is deprecated and will be removed " + f"in a future version, please use '{freq[1:]}' instead." + + expected = pd.date_range("1/1/2000", periods=4, freq=freq) + with tm.assert_produces_warning(FutureWarning, match=depr_msg): + result = pd.date_range("1/1/2000", periods=4, freq=freq_depr) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize("freq_depr", ["2H", "2CBH", "2MIN", "2S", "2mS", "2Us"]) + def test_date_range_uppercase_frequency_deprecated(self, freq_depr): + # GH#9586, GH#54939 + depr_msg = f"'{freq_depr[1:]}' is deprecated and will be removed in a " + f"future version. Please use '{freq_depr.lower()[1:]}' instead." + + expected = pd.date_range("1/1/2000", periods=4, freq=freq_depr.lower()) + with tm.assert_produces_warning(FutureWarning, match=depr_msg): + result = pd.date_range("1/1/2000", periods=4, freq=freq_depr) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "freq_depr", + [ + "2ye-mar", + "2ys", + "2qe", + "2qs-feb", + "2bqs", + "2sms", + "2bms", + "2cbme", + "2me", + "2w", + ], + ) + def test_date_range_lowercase_frequency_deprecated(self, freq_depr): + # GH#9586, GH#54939 + depr_msg = f"'{freq_depr[1:]}' is deprecated and will be removed in a " + f"future version, please use '{freq_depr.upper()[1:]}' instead." + + expected = pd.date_range("1/1/2000", periods=4, freq=freq_depr.upper()) + with tm.assert_produces_warning(FutureWarning, match=depr_msg): + result = pd.date_range("1/1/2000", periods=4, freq=freq_depr) + tm.assert_index_equal(result, expected) + + +def test_factorize_sort_without_freq(): + dta = DatetimeArray._from_sequence([0, 2, 1], dtype="M8[ns]") + + msg = r"call pd.factorize\(obj, sort=True\) instead" + with pytest.raises(NotImplementedError, match=msg): + dta.factorize(sort=True) + + # Do TimedeltaArray while we're here + tda = dta - dta[0] + with pytest.raises(NotImplementedError, match=msg): + tda.factorize(sort=True) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/arrays/test_ndarray_backed.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/arrays/test_ndarray_backed.py new file mode 100644 index 0000000000000000000000000000000000000000..1fe7cc9b03e8a6cef04558958ed949a0239a96cc --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/arrays/test_ndarray_backed.py @@ -0,0 +1,75 @@ +""" +Tests for subclasses of NDArrayBackedExtensionArray +""" +import numpy as np + +from pandas import ( + CategoricalIndex, + date_range, +) +from pandas.core.arrays import ( + Categorical, + DatetimeArray, + NumpyExtensionArray, + TimedeltaArray, +) + + +class TestEmpty: + def test_empty_categorical(self): + ci = CategoricalIndex(["a", "b", "c"], ordered=True) + dtype = ci.dtype + + # case with int8 codes + shape = (4,) + result = Categorical._empty(shape, dtype=dtype) + assert isinstance(result, Categorical) + assert result.shape == shape + assert result._ndarray.dtype == np.int8 + + # case where repr would segfault if we didn't override base implementation + result = Categorical._empty((4096,), dtype=dtype) + assert isinstance(result, Categorical) + assert result.shape == (4096,) + assert result._ndarray.dtype == np.int8 + repr(result) + + # case with int16 codes + ci = CategoricalIndex(list(range(512)) * 4, ordered=False) + dtype = ci.dtype + result = Categorical._empty(shape, dtype=dtype) + assert isinstance(result, Categorical) + assert result.shape == shape + assert result._ndarray.dtype == np.int16 + + def test_empty_dt64tz(self): + dti = date_range("2016-01-01", periods=2, tz="Asia/Tokyo") + dtype = dti.dtype + + shape = (0,) + result = DatetimeArray._empty(shape, dtype=dtype) + assert result.dtype == dtype + assert isinstance(result, DatetimeArray) + assert result.shape == shape + + def test_empty_dt64(self): + shape = (3, 9) + result = DatetimeArray._empty(shape, dtype="datetime64[ns]") + assert isinstance(result, DatetimeArray) + assert result.shape == shape + + def test_empty_td64(self): + shape = (3, 9) + result = TimedeltaArray._empty(shape, dtype="m8[ns]") + assert isinstance(result, TimedeltaArray) + assert result.shape == shape + + def test_empty_pandas_array(self): + arr = NumpyExtensionArray(np.array([1, 2])) + dtype = arr.dtype + + shape = (3, 9) + result = NumpyExtensionArray._empty(shape, dtype=dtype) + assert isinstance(result, NumpyExtensionArray) + assert result.dtype == dtype + assert result.shape == shape diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/arrays/test_period.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/arrays/test_period.py new file mode 100644 index 0000000000000000000000000000000000000000..48453ba19e9a1f6971a2e56872ec42f1856d1dd0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/arrays/test_period.py @@ -0,0 +1,184 @@ +import numpy as np +import pytest + +from pandas._libs.tslibs import iNaT +from pandas._libs.tslibs.period import IncompatibleFrequency + +from pandas.core.dtypes.base import _registry as registry +from pandas.core.dtypes.dtypes import PeriodDtype + +import pandas as pd +import pandas._testing as tm +from pandas.core.arrays import PeriodArray + +# ---------------------------------------------------------------------------- +# Dtype + + +def test_registered(): + assert PeriodDtype in registry.dtypes + result = registry.find("Period[D]") + expected = PeriodDtype("D") + assert result == expected + + +# ---------------------------------------------------------------------------- +# period_array + + +def test_asi8(): + result = PeriodArray._from_sequence(["2000", "2001", None], dtype="period[D]").asi8 + expected = np.array([10957, 11323, iNaT]) + tm.assert_numpy_array_equal(result, expected) + + +def test_take_raises(): + arr = PeriodArray._from_sequence(["2000", "2001"], dtype="period[D]") + with pytest.raises(IncompatibleFrequency, match="freq"): + arr.take([0, -1], allow_fill=True, fill_value=pd.Period("2000", freq="W")) + + msg = "value should be a 'Period' or 'NaT'. Got 'str' instead" + with pytest.raises(TypeError, match=msg): + arr.take([0, -1], allow_fill=True, fill_value="foo") + + +def test_fillna_raises(): + arr = PeriodArray._from_sequence(["2000", "2001", "2002"], dtype="period[D]") + with pytest.raises(ValueError, match="Length"): + arr.fillna(arr[:2]) + + +def test_fillna_copies(): + arr = PeriodArray._from_sequence(["2000", "2001", "2002"], dtype="period[D]") + result = arr.fillna(pd.Period("2000", "D")) + assert result is not arr + + +# ---------------------------------------------------------------------------- +# setitem + + +@pytest.mark.parametrize( + "key, value, expected", + [ + ([0], pd.Period("2000", "D"), [10957, 1, 2]), + ([0], None, [iNaT, 1, 2]), + ([0], np.nan, [iNaT, 1, 2]), + ([0, 1, 2], pd.Period("2000", "D"), [10957] * 3), + ( + [0, 1, 2], + [pd.Period("2000", "D"), pd.Period("2001", "D"), pd.Period("2002", "D")], + [10957, 11323, 11688], + ), + ], +) +def test_setitem(key, value, expected): + arr = PeriodArray(np.arange(3), dtype="period[D]") + expected = PeriodArray(expected, dtype="period[D]") + arr[key] = value + tm.assert_period_array_equal(arr, expected) + + +def test_setitem_raises_incompatible_freq(): + arr = PeriodArray(np.arange(3), dtype="period[D]") + with pytest.raises(IncompatibleFrequency, match="freq"): + arr[0] = pd.Period("2000", freq="Y") + + other = PeriodArray._from_sequence(["2000", "2001"], dtype="period[Y]") + with pytest.raises(IncompatibleFrequency, match="freq"): + arr[[0, 1]] = other + + +def test_setitem_raises_length(): + arr = PeriodArray(np.arange(3), dtype="period[D]") + with pytest.raises(ValueError, match="length"): + arr[[0, 1]] = [pd.Period("2000", freq="D")] + + +def test_setitem_raises_type(): + arr = PeriodArray(np.arange(3), dtype="period[D]") + with pytest.raises(TypeError, match="int"): + arr[0] = 1 + + +# ---------------------------------------------------------------------------- +# Ops + + +def test_sub_period(): + arr = PeriodArray._from_sequence(["2000", "2001"], dtype="period[D]") + other = pd.Period("2000", freq="M") + with pytest.raises(IncompatibleFrequency, match="freq"): + arr - other + + +def test_sub_period_overflow(): + # GH#47538 + dti = pd.date_range("1677-09-22", periods=2, freq="D") + pi = dti.to_period("ns") + + per = pd.Period._from_ordinal(10**14, pi.freq) + + with pytest.raises(OverflowError, match="Overflow in int64 addition"): + pi - per + + with pytest.raises(OverflowError, match="Overflow in int64 addition"): + per - pi + + +# ---------------------------------------------------------------------------- +# Methods + + +@pytest.mark.parametrize( + "other", + [ + pd.Period("2000", freq="h"), + PeriodArray._from_sequence(["2000", "2001", "2000"], dtype="period[h]"), + ], +) +def test_where_different_freq_raises(other): + # GH#45768 The PeriodArray method raises, the Series method coerces + ser = pd.Series( + PeriodArray._from_sequence(["2000", "2001", "2002"], dtype="period[D]") + ) + cond = np.array([True, False, True]) + + with pytest.raises(IncompatibleFrequency, match="freq"): + ser.array._where(cond, other) + + res = ser.where(cond, other) + expected = ser.astype(object).where(cond, other) + tm.assert_series_equal(res, expected) + + +# ---------------------------------------------------------------------------- +# Printing + + +def test_repr_small(): + arr = PeriodArray._from_sequence(["2000", "2001"], dtype="period[D]") + result = str(arr) + expected = ( + "\n['2000-01-01', '2001-01-01']\nLength: 2, dtype: period[D]" + ) + assert result == expected + + +def test_repr_large(): + arr = PeriodArray._from_sequence(["2000", "2001"] * 500, dtype="period[D]") + result = str(arr) + expected = ( + "\n" + "['2000-01-01', '2001-01-01', '2000-01-01', '2001-01-01', " + "'2000-01-01',\n" + " '2001-01-01', '2000-01-01', '2001-01-01', '2000-01-01', " + "'2001-01-01',\n" + " ...\n" + " '2000-01-01', '2001-01-01', '2000-01-01', '2001-01-01', " + "'2000-01-01',\n" + " '2001-01-01', '2000-01-01', '2001-01-01', '2000-01-01', " + "'2001-01-01']\n" + "Length: 1000, dtype: period[D]" + ) + assert result == expected diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/arrays/test_timedeltas.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/arrays/test_timedeltas.py new file mode 100644 index 0000000000000000000000000000000000000000..a3f15467feb144ee21883a0a2a777e3b5e0cdf42 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/arrays/test_timedeltas.py @@ -0,0 +1,313 @@ +from datetime import timedelta + +import numpy as np +import pytest + +import pandas as pd +from pandas import Timedelta +import pandas._testing as tm +from pandas.core.arrays import ( + DatetimeArray, + TimedeltaArray, +) + + +class TestNonNano: + @pytest.fixture(params=["s", "ms", "us"]) + def unit(self, request): + return request.param + + @pytest.fixture + def tda(self, unit): + arr = np.arange(5, dtype=np.int64).view(f"m8[{unit}]") + return TimedeltaArray._simple_new(arr, dtype=arr.dtype) + + def test_non_nano(self, unit): + arr = np.arange(5, dtype=np.int64).view(f"m8[{unit}]") + tda = TimedeltaArray._simple_new(arr, dtype=arr.dtype) + + assert tda.dtype == arr.dtype + assert tda[0].unit == unit + + def test_as_unit_raises(self, tda): + # GH#50616 + with pytest.raises(ValueError, match="Supported units"): + tda.as_unit("D") + + tdi = pd.Index(tda) + with pytest.raises(ValueError, match="Supported units"): + tdi.as_unit("D") + + @pytest.mark.parametrize("field", TimedeltaArray._field_ops) + def test_fields(self, tda, field): + as_nano = tda._ndarray.astype("m8[ns]") + tda_nano = TimedeltaArray._simple_new(as_nano, dtype=as_nano.dtype) + + result = getattr(tda, field) + expected = getattr(tda_nano, field) + tm.assert_numpy_array_equal(result, expected) + + def test_to_pytimedelta(self, tda): + as_nano = tda._ndarray.astype("m8[ns]") + tda_nano = TimedeltaArray._simple_new(as_nano, dtype=as_nano.dtype) + + result = tda.to_pytimedelta() + expected = tda_nano.to_pytimedelta() + tm.assert_numpy_array_equal(result, expected) + + def test_total_seconds(self, unit, tda): + as_nano = tda._ndarray.astype("m8[ns]") + tda_nano = TimedeltaArray._simple_new(as_nano, dtype=as_nano.dtype) + + result = tda.total_seconds() + expected = tda_nano.total_seconds() + tm.assert_numpy_array_equal(result, expected) + + def test_timedelta_array_total_seconds(self): + # GH34290 + expected = Timedelta("2 min").total_seconds() + + result = pd.array([Timedelta("2 min")]).total_seconds()[0] + assert result == expected + + def test_total_seconds_nanoseconds(self): + # issue #48521 + start_time = pd.Series(["2145-11-02 06:00:00"]).astype("datetime64[ns]") + end_time = pd.Series(["2145-11-02 07:06:00"]).astype("datetime64[ns]") + expected = (end_time - start_time).values / np.timedelta64(1, "s") + result = (end_time - start_time).dt.total_seconds().values + assert result == expected + + @pytest.mark.parametrize( + "nat", [np.datetime64("NaT", "ns"), np.datetime64("NaT", "us")] + ) + def test_add_nat_datetimelike_scalar(self, nat, tda): + result = tda + nat + assert isinstance(result, DatetimeArray) + assert result._creso == tda._creso + assert result.isna().all() + + result = nat + tda + assert isinstance(result, DatetimeArray) + assert result._creso == tda._creso + assert result.isna().all() + + def test_add_pdnat(self, tda): + result = tda + pd.NaT + assert isinstance(result, TimedeltaArray) + assert result._creso == tda._creso + assert result.isna().all() + + result = pd.NaT + tda + assert isinstance(result, TimedeltaArray) + assert result._creso == tda._creso + assert result.isna().all() + + # TODO: 2022-07-11 this is the only test that gets to DTA.tz_convert + # or tz_localize with non-nano; implement tests specific to that. + def test_add_datetimelike_scalar(self, tda, tz_naive_fixture): + ts = pd.Timestamp("2016-01-01", tz=tz_naive_fixture).as_unit("ns") + + expected = tda.as_unit("ns") + ts + res = tda + ts + tm.assert_extension_array_equal(res, expected) + res = ts + tda + tm.assert_extension_array_equal(res, expected) + + ts += Timedelta(1) # case where we can't cast losslessly + + exp_values = tda._ndarray + ts.asm8 + expected = ( + DatetimeArray._simple_new(exp_values, dtype=exp_values.dtype) + .tz_localize("UTC") + .tz_convert(ts.tz) + ) + + result = tda + ts + tm.assert_extension_array_equal(result, expected) + + result = ts + tda + tm.assert_extension_array_equal(result, expected) + + def test_mul_scalar(self, tda): + other = 2 + result = tda * other + expected = TimedeltaArray._simple_new(tda._ndarray * other, dtype=tda.dtype) + tm.assert_extension_array_equal(result, expected) + assert result._creso == tda._creso + + def test_mul_listlike(self, tda): + other = np.arange(len(tda)) + result = tda * other + expected = TimedeltaArray._simple_new(tda._ndarray * other, dtype=tda.dtype) + tm.assert_extension_array_equal(result, expected) + assert result._creso == tda._creso + + def test_mul_listlike_object(self, tda): + other = np.arange(len(tda)) + result = tda * other.astype(object) + expected = TimedeltaArray._simple_new(tda._ndarray * other, dtype=tda.dtype) + tm.assert_extension_array_equal(result, expected) + assert result._creso == tda._creso + + def test_div_numeric_scalar(self, tda): + other = 2 + result = tda / other + expected = TimedeltaArray._simple_new(tda._ndarray / other, dtype=tda.dtype) + tm.assert_extension_array_equal(result, expected) + assert result._creso == tda._creso + + def test_div_td_scalar(self, tda): + other = timedelta(seconds=1) + result = tda / other + expected = tda._ndarray / np.timedelta64(1, "s") + tm.assert_numpy_array_equal(result, expected) + + def test_div_numeric_array(self, tda): + other = np.arange(len(tda)) + result = tda / other + expected = TimedeltaArray._simple_new(tda._ndarray / other, dtype=tda.dtype) + tm.assert_extension_array_equal(result, expected) + assert result._creso == tda._creso + + def test_div_td_array(self, tda): + other = tda._ndarray + tda._ndarray[-1] + result = tda / other + expected = tda._ndarray / other + tm.assert_numpy_array_equal(result, expected) + + def test_add_timedeltaarraylike(self, tda): + tda_nano = tda.astype("m8[ns]") + + expected = tda_nano * 2 + res = tda_nano + tda + tm.assert_extension_array_equal(res, expected) + res = tda + tda_nano + tm.assert_extension_array_equal(res, expected) + + expected = tda_nano * 0 + res = tda - tda_nano + tm.assert_extension_array_equal(res, expected) + + res = tda_nano - tda + tm.assert_extension_array_equal(res, expected) + + +class TestTimedeltaArray: + @pytest.mark.parametrize("dtype", [int, np.int32, np.int64, "uint32", "uint64"]) + def test_astype_int(self, dtype): + arr = TimedeltaArray._from_sequence( + [Timedelta("1h"), Timedelta("2h")], dtype="m8[ns]" + ) + + if np.dtype(dtype) != np.int64: + with pytest.raises(TypeError, match=r"Do obj.astype\('int64'\)"): + arr.astype(dtype) + return + + result = arr.astype(dtype) + expected = arr._ndarray.view("i8") + tm.assert_numpy_array_equal(result, expected) + + def test_setitem_clears_freq(self): + a = pd.timedelta_range("1h", periods=2, freq="h")._data + a[0] = Timedelta("1h") + assert a.freq is None + + @pytest.mark.parametrize( + "obj", + [ + Timedelta(seconds=1), + Timedelta(seconds=1).to_timedelta64(), + Timedelta(seconds=1).to_pytimedelta(), + ], + ) + def test_setitem_objects(self, obj): + # make sure we accept timedelta64 and timedelta in addition to Timedelta + tdi = pd.timedelta_range("2 Days", periods=4, freq="h") + arr = tdi._data + + arr[0] = obj + assert arr[0] == Timedelta(seconds=1) + + @pytest.mark.parametrize( + "other", + [ + 1, + np.int64(1), + 1.0, + np.datetime64("NaT"), + pd.Timestamp("2021-01-01"), + "invalid", + np.arange(10, dtype="i8") * 24 * 3600 * 10**9, + (np.arange(10) * 24 * 3600 * 10**9).view("datetime64[ns]"), + pd.Timestamp("2021-01-01").to_period("D"), + ], + ) + @pytest.mark.parametrize("index", [True, False]) + def test_searchsorted_invalid_types(self, other, index): + data = np.arange(10, dtype="i8") * 24 * 3600 * 10**9 + arr = pd.TimedeltaIndex(data, freq="D")._data + if index: + arr = pd.Index(arr) + + msg = "|".join( + [ + "searchsorted requires compatible dtype or scalar", + "value should be a 'Timedelta', 'NaT', or array of those. Got", + ] + ) + with pytest.raises(TypeError, match=msg): + arr.searchsorted(other) + + +class TestUnaryOps: + def test_abs(self): + vals = np.array([-3600 * 10**9, "NaT", 7200 * 10**9], dtype="m8[ns]") + arr = TimedeltaArray._from_sequence(vals) + + evals = np.array([3600 * 10**9, "NaT", 7200 * 10**9], dtype="m8[ns]") + expected = TimedeltaArray._from_sequence(evals) + + result = abs(arr) + tm.assert_timedelta_array_equal(result, expected) + + result2 = np.abs(arr) + tm.assert_timedelta_array_equal(result2, expected) + + def test_pos(self): + vals = np.array([-3600 * 10**9, "NaT", 7200 * 10**9], dtype="m8[ns]") + arr = TimedeltaArray._from_sequence(vals) + + result = +arr + tm.assert_timedelta_array_equal(result, arr) + assert not tm.shares_memory(result, arr) + + result2 = np.positive(arr) + tm.assert_timedelta_array_equal(result2, arr) + assert not tm.shares_memory(result2, arr) + + def test_neg(self): + vals = np.array([-3600 * 10**9, "NaT", 7200 * 10**9], dtype="m8[ns]") + arr = TimedeltaArray._from_sequence(vals) + + evals = np.array([3600 * 10**9, "NaT", -7200 * 10**9], dtype="m8[ns]") + expected = TimedeltaArray._from_sequence(evals) + + result = -arr + tm.assert_timedelta_array_equal(result, expected) + + result2 = np.negative(arr) + tm.assert_timedelta_array_equal(result2, expected) + + def test_neg_freq(self): + tdi = pd.timedelta_range("2 Days", periods=4, freq="h") + arr = tdi._data + + expected = -tdi._data + + result = -arr + tm.assert_timedelta_array_equal(result, expected) + + result2 = np.negative(arr) + tm.assert_timedelta_array_equal(result2, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/base/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/base/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/base/common.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/base/common.py new file mode 100644 index 0000000000000000000000000000000000000000..ad0b394105742ca5de92a03a3da2c569c38da469 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/base/common.py @@ -0,0 +1,9 @@ +from typing import Any + +from pandas import Index + + +def allow_na_ops(obj: Any) -> bool: + """Whether to skip test cases including NaN""" + is_bool_index = isinstance(obj, Index) and obj.inferred_type == "boolean" + return not is_bool_index and obj._can_hold_na diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/base/test_constructors.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/base/test_constructors.py new file mode 100644 index 0000000000000000000000000000000000000000..f3ac60f672ee1221a9b1b43faf7c2e023d4b9d3b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/base/test_constructors.py @@ -0,0 +1,179 @@ +from datetime import datetime +import sys + +import numpy as np +import pytest + +from pandas.compat import PYPY + +import pandas as pd +from pandas import ( + DataFrame, + Index, + Series, +) +import pandas._testing as tm +from pandas.core.accessor import PandasDelegate +from pandas.core.base import ( + NoNewAttributesMixin, + PandasObject, +) + + +def series_via_frame_from_dict(x, **kwargs): + return DataFrame({"a": x}, **kwargs)["a"] + + +def series_via_frame_from_scalar(x, **kwargs): + return DataFrame(x, **kwargs)[0] + + +@pytest.fixture( + params=[ + Series, + series_via_frame_from_dict, + series_via_frame_from_scalar, + Index, + ], + ids=["Series", "DataFrame-dict", "DataFrame-array", "Index"], +) +def constructor(request): + return request.param + + +class TestPandasDelegate: + class Delegator: + _properties = ["prop"] + _methods = ["test_method"] + + def _set_prop(self, value): + self.prop = value + + def _get_prop(self): + return self.prop + + prop = property(_get_prop, _set_prop, doc="foo property") + + def test_method(self, *args, **kwargs): + """a test method""" + + class Delegate(PandasDelegate, PandasObject): + def __init__(self, obj) -> None: + self.obj = obj + + def test_invalid_delegation(self): + # these show that in order for the delegation to work + # the _delegate_* methods need to be overridden to not raise + # a TypeError + + self.Delegate._add_delegate_accessors( + delegate=self.Delegator, + accessors=self.Delegator._properties, + typ="property", + ) + self.Delegate._add_delegate_accessors( + delegate=self.Delegator, accessors=self.Delegator._methods, typ="method" + ) + + delegate = self.Delegate(self.Delegator()) + + msg = "You cannot access the property prop" + with pytest.raises(TypeError, match=msg): + delegate.prop + + msg = "The property prop cannot be set" + with pytest.raises(TypeError, match=msg): + delegate.prop = 5 + + msg = "You cannot access the property prop" + with pytest.raises(TypeError, match=msg): + delegate.prop + + @pytest.mark.skipif(PYPY, reason="not relevant for PyPy") + def test_memory_usage(self): + # Delegate does not implement memory_usage. + # Check that we fall back to in-built `__sizeof__` + # GH 12924 + delegate = self.Delegate(self.Delegator()) + sys.getsizeof(delegate) + + +class TestNoNewAttributesMixin: + def test_mixin(self): + class T(NoNewAttributesMixin): + pass + + t = T() + assert not hasattr(t, "__frozen") + + t.a = "test" + assert t.a == "test" + + t._freeze() + assert "__frozen" in dir(t) + assert getattr(t, "__frozen") + msg = "You cannot add any new attribute" + with pytest.raises(AttributeError, match=msg): + t.b = "test" + + assert not hasattr(t, "b") + + +class TestConstruction: + # test certain constructor behaviours on dtype inference across Series, + # Index and DataFrame + + @pytest.mark.parametrize( + "a", + [ + np.array(["2263-01-01"], dtype="datetime64[D]"), + np.array([datetime(2263, 1, 1)], dtype=object), + np.array([np.datetime64("2263-01-01", "D")], dtype=object), + np.array(["2263-01-01"], dtype=object), + ], + ids=[ + "datetime64[D]", + "object-datetime.datetime", + "object-numpy-scalar", + "object-string", + ], + ) + def test_constructor_datetime_outofbound( + self, a, constructor, request, using_infer_string + ): + # GH-26853 (+ bug GH-26206 out of bound non-ns unit) + + # No dtype specified (dtype inference) + # datetime64[non-ns] raise error, other cases result in object dtype + # and preserve original data + if a.dtype.kind == "M": + # Can't fit in nanosecond bounds -> get the nearest supported unit + result = constructor(a) + assert result.dtype == "M8[s]" + else: + result = constructor(a) + if using_infer_string and "object-string" in request.node.callspec.id: + assert result.dtype == "string" + else: + assert result.dtype == "object" + tm.assert_numpy_array_equal(result.to_numpy(), a) + + # Explicit dtype specified + # Forced conversion fails for all -> all cases raise error + msg = "Out of bounds|Out of bounds .* present at position 0" + with pytest.raises(pd.errors.OutOfBoundsDatetime, match=msg): + constructor(a, dtype="datetime64[ns]") + + def test_constructor_datetime_nonns(self, constructor): + arr = np.array(["2020-01-01T00:00:00.000000"], dtype="datetime64[us]") + dta = pd.core.arrays.DatetimeArray._simple_new(arr, dtype=arr.dtype) + expected = constructor(dta) + assert expected.dtype == arr.dtype + + result = constructor(arr) + tm.assert_equal(result, expected) + + # https://github.com/pandas-dev/pandas/issues/34843 + arr.flags.writeable = False + result = constructor(arr) + tm.assert_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/base/test_conversion.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/base/test_conversion.py new file mode 100644 index 0000000000000000000000000000000000000000..fe0f1f1454a55018ad5167d958c73dd6a3bc7018 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/base/test_conversion.py @@ -0,0 +1,562 @@ +import numpy as np +import pytest + +from pandas.core.dtypes.dtypes import DatetimeTZDtype + +import pandas as pd +from pandas import ( + CategoricalIndex, + Series, + Timedelta, + Timestamp, + date_range, +) +import pandas._testing as tm +from pandas.core.arrays import ( + DatetimeArray, + IntervalArray, + NumpyExtensionArray, + PeriodArray, + SparseArray, + TimedeltaArray, +) +from pandas.core.arrays.string_arrow import ArrowStringArrayNumpySemantics + + +class TestToIterable: + # test that we convert an iterable to python types + + dtypes = [ + ("int8", int), + ("int16", int), + ("int32", int), + ("int64", int), + ("uint8", int), + ("uint16", int), + ("uint32", int), + ("uint64", int), + ("float16", float), + ("float32", float), + ("float64", float), + ("datetime64[ns]", Timestamp), + ("datetime64[ns, US/Eastern]", Timestamp), + ("timedelta64[ns]", Timedelta), + ] + + @pytest.mark.parametrize("dtype, rdtype", dtypes) + @pytest.mark.parametrize( + "method", + [ + lambda x: x.tolist(), + lambda x: x.to_list(), + lambda x: list(x), + lambda x: list(x.__iter__()), + ], + ids=["tolist", "to_list", "list", "iter"], + ) + def test_iterable(self, index_or_series, method, dtype, rdtype): + # gh-10904 + # gh-13258 + # coerce iteration to underlying python / pandas types + typ = index_or_series + if dtype == "float16" and issubclass(typ, pd.Index): + with pytest.raises(NotImplementedError, match="float16 indexes are not "): + typ([1], dtype=dtype) + return + s = typ([1], dtype=dtype) + result = method(s)[0] + assert isinstance(result, rdtype) + + @pytest.mark.parametrize( + "dtype, rdtype, obj", + [ + ("object", object, "a"), + ("object", int, 1), + ("category", object, "a"), + ("category", int, 1), + ], + ) + @pytest.mark.parametrize( + "method", + [ + lambda x: x.tolist(), + lambda x: x.to_list(), + lambda x: list(x), + lambda x: list(x.__iter__()), + ], + ids=["tolist", "to_list", "list", "iter"], + ) + def test_iterable_object_and_category( + self, index_or_series, method, dtype, rdtype, obj + ): + # gh-10904 + # gh-13258 + # coerce iteration to underlying python / pandas types + typ = index_or_series + s = typ([obj], dtype=dtype) + result = method(s)[0] + assert isinstance(result, rdtype) + + @pytest.mark.parametrize("dtype, rdtype", dtypes) + def test_iterable_items(self, dtype, rdtype): + # gh-13258 + # test if items yields the correct boxed scalars + # this only applies to series + s = Series([1], dtype=dtype) + _, result = next(iter(s.items())) + assert isinstance(result, rdtype) + + _, result = next(iter(s.items())) + assert isinstance(result, rdtype) + + @pytest.mark.parametrize( + "dtype, rdtype", dtypes + [("object", int), ("category", int)] + ) + def test_iterable_map(self, index_or_series, dtype, rdtype): + # gh-13236 + # coerce iteration to underlying python / pandas types + typ = index_or_series + if dtype == "float16" and issubclass(typ, pd.Index): + with pytest.raises(NotImplementedError, match="float16 indexes are not "): + typ([1], dtype=dtype) + return + s = typ([1], dtype=dtype) + result = s.map(type)[0] + if not isinstance(rdtype, tuple): + rdtype = (rdtype,) + assert result in rdtype + + @pytest.mark.parametrize( + "method", + [ + lambda x: x.tolist(), + lambda x: x.to_list(), + lambda x: list(x), + lambda x: list(x.__iter__()), + ], + ids=["tolist", "to_list", "list", "iter"], + ) + def test_categorial_datetimelike(self, method): + i = CategoricalIndex([Timestamp("1999-12-31"), Timestamp("2000-12-31")]) + + result = method(i)[0] + assert isinstance(result, Timestamp) + + def test_iter_box_dt64(self, unit): + vals = [Timestamp("2011-01-01"), Timestamp("2011-01-02")] + ser = Series(vals).dt.as_unit(unit) + assert ser.dtype == f"datetime64[{unit}]" + for res, exp in zip(ser, vals): + assert isinstance(res, Timestamp) + assert res.tz is None + assert res == exp + assert res.unit == unit + + def test_iter_box_dt64tz(self, unit): + vals = [ + Timestamp("2011-01-01", tz="US/Eastern"), + Timestamp("2011-01-02", tz="US/Eastern"), + ] + ser = Series(vals).dt.as_unit(unit) + + assert ser.dtype == f"datetime64[{unit}, US/Eastern]" + for res, exp in zip(ser, vals): + assert isinstance(res, Timestamp) + assert res.tz == exp.tz + assert res == exp + assert res.unit == unit + + def test_iter_box_timedelta64(self, unit): + # timedelta + vals = [Timedelta("1 days"), Timedelta("2 days")] + ser = Series(vals).dt.as_unit(unit) + assert ser.dtype == f"timedelta64[{unit}]" + for res, exp in zip(ser, vals): + assert isinstance(res, Timedelta) + assert res == exp + assert res.unit == unit + + def test_iter_box_period(self): + # period + vals = [pd.Period("2011-01-01", freq="M"), pd.Period("2011-01-02", freq="M")] + s = Series(vals) + assert s.dtype == "Period[M]" + for res, exp in zip(s, vals): + assert isinstance(res, pd.Period) + assert res.freq == "ME" + assert res == exp + + +@pytest.mark.parametrize( + "arr, expected_type, dtype", + [ + (np.array([0, 1], dtype=np.int64), np.ndarray, "int64"), + (np.array(["a", "b"]), np.ndarray, "object"), + (pd.Categorical(["a", "b"]), pd.Categorical, "category"), + ( + pd.DatetimeIndex(["2017", "2018"], tz="US/Central"), + DatetimeArray, + "datetime64[ns, US/Central]", + ), + ( + pd.PeriodIndex([2018, 2019], freq="Y"), + PeriodArray, + pd.core.dtypes.dtypes.PeriodDtype("Y-DEC"), + ), + (pd.IntervalIndex.from_breaks([0, 1, 2]), IntervalArray, "interval"), + ( + pd.DatetimeIndex(["2017", "2018"]), + DatetimeArray, + "datetime64[ns]", + ), + ( + pd.TimedeltaIndex([10**10]), + TimedeltaArray, + "m8[ns]", + ), + ], +) +def test_values_consistent(arr, expected_type, dtype, using_infer_string): + if using_infer_string and dtype == "object": + expected_type = ArrowStringArrayNumpySemantics + l_values = Series(arr)._values + r_values = pd.Index(arr)._values + assert type(l_values) is expected_type + assert type(l_values) is type(r_values) + + tm.assert_equal(l_values, r_values) + + +@pytest.mark.parametrize("arr", [np.array([1, 2, 3])]) +def test_numpy_array(arr): + ser = Series(arr) + result = ser.array + expected = NumpyExtensionArray(arr) + tm.assert_extension_array_equal(result, expected) + + +def test_numpy_array_all_dtypes(any_numpy_dtype): + ser = Series(dtype=any_numpy_dtype) + result = ser.array + if np.dtype(any_numpy_dtype).kind == "M": + assert isinstance(result, DatetimeArray) + elif np.dtype(any_numpy_dtype).kind == "m": + assert isinstance(result, TimedeltaArray) + else: + assert isinstance(result, NumpyExtensionArray) + + +@pytest.mark.parametrize( + "arr, attr", + [ + (pd.Categorical(["a", "b"]), "_codes"), + (PeriodArray._from_sequence(["2000", "2001"], dtype="period[D]"), "_ndarray"), + (pd.array([0, np.nan], dtype="Int64"), "_data"), + (IntervalArray.from_breaks([0, 1]), "_left"), + (SparseArray([0, 1]), "_sparse_values"), + ( + DatetimeArray._from_sequence(np.array([1, 2], dtype="datetime64[ns]")), + "_ndarray", + ), + # tz-aware Datetime + ( + DatetimeArray._from_sequence( + np.array( + ["2000-01-01T12:00:00", "2000-01-02T12:00:00"], dtype="M8[ns]" + ), + dtype=DatetimeTZDtype(tz="US/Central"), + ), + "_ndarray", + ), + ], +) +def test_array(arr, attr, index_or_series, request): + box = index_or_series + + result = box(arr, copy=False).array + + if attr: + arr = getattr(arr, attr) + result = getattr(result, attr) + + assert result is arr + + +def test_array_multiindex_raises(): + idx = pd.MultiIndex.from_product([["A"], ["a", "b"]]) + msg = "MultiIndex has no single backing array" + with pytest.raises(ValueError, match=msg): + idx.array + + +@pytest.mark.parametrize( + "arr, expected", + [ + (np.array([1, 2], dtype=np.int64), np.array([1, 2], dtype=np.int64)), + (pd.Categorical(["a", "b"]), np.array(["a", "b"], dtype=object)), + ( + pd.core.arrays.period_array(["2000", "2001"], freq="D"), + np.array([pd.Period("2000", freq="D"), pd.Period("2001", freq="D")]), + ), + (pd.array([0, np.nan], dtype="Int64"), np.array([0, np.nan])), + ( + IntervalArray.from_breaks([0, 1, 2]), + np.array([pd.Interval(0, 1), pd.Interval(1, 2)], dtype=object), + ), + (SparseArray([0, 1]), np.array([0, 1], dtype=np.int64)), + # tz-naive datetime + ( + DatetimeArray._from_sequence(np.array(["2000", "2001"], dtype="M8[ns]")), + np.array(["2000", "2001"], dtype="M8[ns]"), + ), + # tz-aware stays tz`-aware + ( + DatetimeArray._from_sequence( + np.array(["2000-01-01T06:00:00", "2000-01-02T06:00:00"], dtype="M8[ns]") + ) + .tz_localize("UTC") + .tz_convert("US/Central"), + np.array( + [ + Timestamp("2000-01-01", tz="US/Central"), + Timestamp("2000-01-02", tz="US/Central"), + ] + ), + ), + # Timedelta + ( + TimedeltaArray._from_sequence( + np.array([0, 3600000000000], dtype="i8").view("m8[ns]") + ), + np.array([0, 3600000000000], dtype="m8[ns]"), + ), + # GH#26406 tz is preserved in Categorical[dt64tz] + ( + pd.Categorical(date_range("2016-01-01", periods=2, tz="US/Pacific")), + np.array( + [ + Timestamp("2016-01-01", tz="US/Pacific"), + Timestamp("2016-01-02", tz="US/Pacific"), + ] + ), + ), + ], +) +def test_to_numpy(arr, expected, index_or_series_or_array, request): + box = index_or_series_or_array + + with tm.assert_produces_warning(None): + thing = box(arr) + + result = thing.to_numpy() + tm.assert_numpy_array_equal(result, expected) + + result = np.asarray(thing) + tm.assert_numpy_array_equal(result, expected) + + +@pytest.mark.parametrize("as_series", [True, False]) +@pytest.mark.parametrize( + "arr", [np.array([1, 2, 3], dtype="int64"), np.array(["a", "b", "c"], dtype=object)] +) +def test_to_numpy_copy(arr, as_series, using_infer_string): + obj = pd.Index(arr, copy=False) + if as_series: + obj = Series(obj.values, copy=False) + + # no copy by default + result = obj.to_numpy() + if using_infer_string and arr.dtype == object: + assert np.shares_memory(arr, result) is False + else: + assert np.shares_memory(arr, result) is True + + result = obj.to_numpy(copy=False) + if using_infer_string and arr.dtype == object: + assert np.shares_memory(arr, result) is False + else: + assert np.shares_memory(arr, result) is True + + # copy=True + result = obj.to_numpy(copy=True) + assert np.shares_memory(arr, result) is False + + +@pytest.mark.parametrize("as_series", [True, False]) +def test_to_numpy_dtype(as_series, unit): + tz = "US/Eastern" + obj = pd.DatetimeIndex(["2000", "2001"], tz=tz) + if as_series: + obj = Series(obj) + + # preserve tz by default + result = obj.to_numpy() + expected = np.array( + [Timestamp("2000", tz=tz), Timestamp("2001", tz=tz)], dtype=object + ) + tm.assert_numpy_array_equal(result, expected) + + result = obj.to_numpy(dtype="object") + tm.assert_numpy_array_equal(result, expected) + + result = obj.to_numpy(dtype="M8[ns]") + expected = np.array(["2000-01-01T05", "2001-01-01T05"], dtype="M8[ns]") + tm.assert_numpy_array_equal(result, expected) + + +@pytest.mark.parametrize( + "values, dtype, na_value, expected", + [ + ([1, 2, None], "float64", 0, [1.0, 2.0, 0.0]), + ( + [Timestamp("2000"), Timestamp("2000"), pd.NaT], + None, + Timestamp("2000"), + [np.datetime64("2000-01-01T00:00:00.000000000")] * 3, + ), + ], +) +def test_to_numpy_na_value_numpy_dtype( + index_or_series, values, dtype, na_value, expected +): + obj = index_or_series(values) + result = obj.to_numpy(dtype=dtype, na_value=na_value) + expected = np.array(expected) + tm.assert_numpy_array_equal(result, expected) + + +@pytest.mark.parametrize( + "data, multiindex, dtype, na_value, expected", + [ + ( + [1, 2, None, 4], + [(0, "a"), (0, "b"), (1, "b"), (1, "c")], + float, + None, + [1.0, 2.0, np.nan, 4.0], + ), + ( + [1, 2, None, 4], + [(0, "a"), (0, "b"), (1, "b"), (1, "c")], + float, + np.nan, + [1.0, 2.0, np.nan, 4.0], + ), + ( + [1.0, 2.0, np.nan, 4.0], + [("a", 0), ("a", 1), ("a", 2), ("b", 0)], + int, + 0, + [1, 2, 0, 4], + ), + ( + [Timestamp("2000"), Timestamp("2000"), pd.NaT], + [(0, Timestamp("2021")), (0, Timestamp("2022")), (1, Timestamp("2000"))], + None, + Timestamp("2000"), + [np.datetime64("2000-01-01T00:00:00.000000000")] * 3, + ), + ], +) +def test_to_numpy_multiindex_series_na_value( + data, multiindex, dtype, na_value, expected +): + index = pd.MultiIndex.from_tuples(multiindex) + series = Series(data, index=index) + result = series.to_numpy(dtype=dtype, na_value=na_value) + expected = np.array(expected) + tm.assert_numpy_array_equal(result, expected) + + +def test_to_numpy_kwargs_raises(): + # numpy + s = Series([1, 2, 3]) + msg = r"to_numpy\(\) got an unexpected keyword argument 'foo'" + with pytest.raises(TypeError, match=msg): + s.to_numpy(foo=True) + + # extension + s = Series([1, 2, 3], dtype="Int64") + with pytest.raises(TypeError, match=msg): + s.to_numpy(foo=True) + + +@pytest.mark.parametrize( + "data", + [ + {"a": [1, 2, 3], "b": [1, 2, None]}, + {"a": np.array([1, 2, 3]), "b": np.array([1, 2, np.nan])}, + {"a": pd.array([1, 2, 3]), "b": pd.array([1, 2, None])}, + ], +) +@pytest.mark.parametrize("dtype, na_value", [(float, np.nan), (object, None)]) +def test_to_numpy_dataframe_na_value(data, dtype, na_value): + # https://github.com/pandas-dev/pandas/issues/33820 + df = pd.DataFrame(data) + result = df.to_numpy(dtype=dtype, na_value=na_value) + expected = np.array([[1, 1], [2, 2], [3, na_value]], dtype=dtype) + tm.assert_numpy_array_equal(result, expected) + + +@pytest.mark.parametrize( + "data, expected", + [ + ( + {"a": pd.array([1, 2, None])}, + np.array([[1.0], [2.0], [np.nan]], dtype=float), + ), + ( + {"a": [1, 2, 3], "b": [1, 2, 3]}, + np.array([[1, 1], [2, 2], [3, 3]], dtype=float), + ), + ], +) +def test_to_numpy_dataframe_single_block(data, expected): + # https://github.com/pandas-dev/pandas/issues/33820 + df = pd.DataFrame(data) + result = df.to_numpy(dtype=float, na_value=np.nan) + tm.assert_numpy_array_equal(result, expected) + + +def test_to_numpy_dataframe_single_block_no_mutate(): + # https://github.com/pandas-dev/pandas/issues/33820 + result = pd.DataFrame(np.array([1.0, 2.0, np.nan])) + expected = pd.DataFrame(np.array([1.0, 2.0, np.nan])) + result.to_numpy(na_value=0.0) + tm.assert_frame_equal(result, expected) + + +class TestAsArray: + @pytest.mark.parametrize("tz", [None, "US/Central"]) + def test_asarray_object_dt64(self, tz): + ser = Series(date_range("2000", periods=2, tz=tz)) + + with tm.assert_produces_warning(None): + # Future behavior (for tzaware case) with no warning + result = np.asarray(ser, dtype=object) + + expected = np.array( + [Timestamp("2000-01-01", tz=tz), Timestamp("2000-01-02", tz=tz)] + ) + tm.assert_numpy_array_equal(result, expected) + + def test_asarray_tz_naive(self): + # This shouldn't produce a warning. + ser = Series(date_range("2000", periods=2)) + expected = np.array(["2000-01-01", "2000-01-02"], dtype="M8[ns]") + result = np.asarray(ser) + + tm.assert_numpy_array_equal(result, expected) + + def test_asarray_tz_aware(self): + tz = "US/Central" + ser = Series(date_range("2000", periods=2, tz=tz)) + expected = np.array(["2000-01-01T06", "2000-01-02T06"], dtype="M8[ns]") + result = np.asarray(ser, dtype="datetime64[ns]") + + tm.assert_numpy_array_equal(result, expected) + + # Old behavior with no warning + result = np.asarray(ser, dtype="M8[ns]") + + tm.assert_numpy_array_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/base/test_fillna.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/base/test_fillna.py new file mode 100644 index 0000000000000000000000000000000000000000..7300d3013305a7ca08312ae85cc42ae8950acf23 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/base/test_fillna.py @@ -0,0 +1,60 @@ +""" +Though Index.fillna and Series.fillna has separate impl, +test here to confirm these works as the same +""" + +import numpy as np +import pytest + +from pandas import MultiIndex +import pandas._testing as tm +from pandas.tests.base.common import allow_na_ops + + +def test_fillna(index_or_series_obj): + # GH 11343 + obj = index_or_series_obj + + if isinstance(obj, MultiIndex): + msg = "isna is not defined for MultiIndex" + with pytest.raises(NotImplementedError, match=msg): + obj.fillna(0) + return + + # values will not be changed + fill_value = obj.values[0] if len(obj) > 0 else 0 + result = obj.fillna(fill_value) + + tm.assert_equal(obj, result) + + # check shallow_copied + assert obj is not result + + +@pytest.mark.parametrize("null_obj", [np.nan, None]) +def test_fillna_null(null_obj, index_or_series_obj): + # GH 11343 + obj = index_or_series_obj + klass = type(obj) + + if not allow_na_ops(obj): + pytest.skip(f"{klass} doesn't allow for NA operations") + elif len(obj) < 1: + pytest.skip("Test doesn't make sense on empty data") + elif isinstance(obj, MultiIndex): + pytest.skip(f"MultiIndex can't hold '{null_obj}'") + + values = obj._values + fill_value = values[0] + expected = values.copy() + values[0:2] = null_obj + expected[0:2] = fill_value + + expected = klass(expected) + obj = klass(values) + + result = obj.fillna(fill_value) + tm.assert_equal(result, expected) + + # check shallow_copied + assert obj is not result diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/base/test_misc.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/base/test_misc.py new file mode 100644 index 0000000000000000000000000000000000000000..65e234e799353844bab2a63df582adfa5842d2cd --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/base/test_misc.py @@ -0,0 +1,191 @@ +import sys + +import numpy as np +import pytest + +from pandas._config import using_pyarrow_string_dtype + +from pandas.compat import PYPY + +from pandas.core.dtypes.common import ( + is_dtype_equal, + is_object_dtype, +) + +import pandas as pd +from pandas import ( + Index, + Series, +) +import pandas._testing as tm + + +def test_isnull_notnull_docstrings(): + # GH#41855 make sure its clear these are aliases + doc = pd.DataFrame.notnull.__doc__ + assert doc.startswith("\nDataFrame.notnull is an alias for DataFrame.notna.\n") + doc = pd.DataFrame.isnull.__doc__ + assert doc.startswith("\nDataFrame.isnull is an alias for DataFrame.isna.\n") + + doc = Series.notnull.__doc__ + assert doc.startswith("\nSeries.notnull is an alias for Series.notna.\n") + doc = Series.isnull.__doc__ + assert doc.startswith("\nSeries.isnull is an alias for Series.isna.\n") + + +@pytest.mark.parametrize( + "op_name, op", + [ + ("add", "+"), + ("sub", "-"), + ("mul", "*"), + ("mod", "%"), + ("pow", "**"), + ("truediv", "/"), + ("floordiv", "//"), + ], +) +def test_binary_ops_docstring(frame_or_series, op_name, op): + # not using the all_arithmetic_functions fixture with _get_opstr + # as _get_opstr is used internally in the dynamic implementation of the docstring + klass = frame_or_series + + operand1 = klass.__name__.lower() + operand2 = "other" + expected_str = " ".join([operand1, op, operand2]) + assert expected_str in getattr(klass, op_name).__doc__ + + # reverse version of the binary ops + expected_str = " ".join([operand2, op, operand1]) + assert expected_str in getattr(klass, "r" + op_name).__doc__ + + +def test_ndarray_compat_properties(index_or_series_obj): + obj = index_or_series_obj + + # Check that we work. + for p in ["shape", "dtype", "T", "nbytes"]: + assert getattr(obj, p, None) is not None + + # deprecated properties + for p in ["strides", "itemsize", "base", "data"]: + assert not hasattr(obj, p) + + msg = "can only convert an array of size 1 to a Python scalar" + with pytest.raises(ValueError, match=msg): + obj.item() # len > 1 + + assert obj.ndim == 1 + assert obj.size == len(obj) + + assert Index([1]).item() == 1 + assert Series([1]).item() == 1 + + +@pytest.mark.skipif( + PYPY or using_pyarrow_string_dtype(), + reason="not relevant for PyPy doesn't work properly for arrow strings", +) +def test_memory_usage(index_or_series_memory_obj): + obj = index_or_series_memory_obj + # Clear index caches so that len(obj) == 0 report 0 memory usage + if isinstance(obj, Series): + is_ser = True + obj.index._engine.clear_mapping() + else: + is_ser = False + obj._engine.clear_mapping() + + res = obj.memory_usage() + res_deep = obj.memory_usage(deep=True) + + is_object = is_object_dtype(obj) or (is_ser and is_object_dtype(obj.index)) + is_categorical = isinstance(obj.dtype, pd.CategoricalDtype) or ( + is_ser and isinstance(obj.index.dtype, pd.CategoricalDtype) + ) + is_object_string = is_dtype_equal(obj, "string[python]") or ( + is_ser and is_dtype_equal(obj.index.dtype, "string[python]") + ) + + if len(obj) == 0: + expected = 0 + assert res_deep == res == expected + elif is_object or is_categorical or is_object_string: + # only deep will pick them up + assert res_deep > res + else: + assert res == res_deep + + # sys.getsizeof will call the .memory_usage with + # deep=True, and add on some GC overhead + diff = res_deep - sys.getsizeof(obj) + assert abs(diff) < 100 + + +def test_memory_usage_components_series(series_with_simple_index): + series = series_with_simple_index + total_usage = series.memory_usage(index=True) + non_index_usage = series.memory_usage(index=False) + index_usage = series.index.memory_usage() + assert total_usage == non_index_usage + index_usage + + +@pytest.mark.parametrize("dtype", tm.NARROW_NP_DTYPES) +def test_memory_usage_components_narrow_series(dtype): + series = Series(range(5), dtype=dtype, index=[f"i-{i}" for i in range(5)], name="a") + total_usage = series.memory_usage(index=True) + non_index_usage = series.memory_usage(index=False) + index_usage = series.index.memory_usage() + assert total_usage == non_index_usage + index_usage + + +def test_searchsorted(request, index_or_series_obj): + # numpy.searchsorted calls obj.searchsorted under the hood. + # See gh-12238 + obj = index_or_series_obj + + if isinstance(obj, pd.MultiIndex): + # See gh-14833 + request.applymarker( + pytest.mark.xfail( + reason="np.searchsorted doesn't work on pd.MultiIndex: GH 14833" + ) + ) + elif obj.dtype.kind == "c" and isinstance(obj, Index): + # TODO: Should Series cases also raise? Looks like they use numpy + # comparison semantics https://github.com/numpy/numpy/issues/15981 + mark = pytest.mark.xfail(reason="complex objects are not comparable") + request.applymarker(mark) + + max_obj = max(obj, default=0) + index = np.searchsorted(obj, max_obj) + assert 0 <= index <= len(obj) + + index = np.searchsorted(obj, max_obj, sorter=range(len(obj))) + assert 0 <= index <= len(obj) + + +def test_access_by_position(index_flat): + index = index_flat + + if len(index) == 0: + pytest.skip("Test doesn't make sense on empty data") + + series = Series(index) + assert index[0] == series.iloc[0] + assert index[5] == series.iloc[5] + assert index[-1] == series.iloc[-1] + + size = len(index) + assert index[-1] == index[size - 1] + + msg = f"index {size} is out of bounds for axis 0 with size {size}" + if is_dtype_equal(index.dtype, "string[pyarrow]") or is_dtype_equal( + index.dtype, "string[pyarrow_numpy]" + ): + msg = "index out of bounds" + with pytest.raises(IndexError, match=msg): + index[size] + msg = "single positional indexer is out-of-bounds" + with pytest.raises(IndexError, match=msg): + series.iloc[size] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/base/test_transpose.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/base/test_transpose.py new file mode 100644 index 0000000000000000000000000000000000000000..246f33d27476cb419620fb8571984619785f9b62 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/base/test_transpose.py @@ -0,0 +1,56 @@ +import numpy as np +import pytest + +from pandas import ( + CategoricalDtype, + DataFrame, +) +import pandas._testing as tm + + +def test_transpose(index_or_series_obj): + obj = index_or_series_obj + tm.assert_equal(obj.transpose(), obj) + + +def test_transpose_non_default_axes(index_or_series_obj): + msg = "the 'axes' parameter is not supported" + obj = index_or_series_obj + with pytest.raises(ValueError, match=msg): + obj.transpose(1) + with pytest.raises(ValueError, match=msg): + obj.transpose(axes=1) + + +def test_numpy_transpose(index_or_series_obj): + msg = "the 'axes' parameter is not supported" + obj = index_or_series_obj + tm.assert_equal(np.transpose(obj), obj) + + with pytest.raises(ValueError, match=msg): + np.transpose(obj, axes=1) + + +@pytest.mark.parametrize( + "data, transposed_data, index, columns, dtype", + [ + ([[1], [2]], [[1, 2]], ["a", "a"], ["b"], int), + ([[1], [2]], [[1, 2]], ["a", "a"], ["b"], CategoricalDtype([1, 2])), + ([[1, 2]], [[1], [2]], ["b"], ["a", "a"], int), + ([[1, 2]], [[1], [2]], ["b"], ["a", "a"], CategoricalDtype([1, 2])), + ([[1, 2], [3, 4]], [[1, 3], [2, 4]], ["a", "a"], ["b", "b"], int), + ( + [[1, 2], [3, 4]], + [[1, 3], [2, 4]], + ["a", "a"], + ["b", "b"], + CategoricalDtype([1, 2, 3, 4]), + ), + ], +) +def test_duplicate_labels(data, transposed_data, index, columns, dtype): + # GH 42380 + df = DataFrame(data, index=index, columns=columns, dtype=dtype) + result = df.T + expected = DataFrame(transposed_data, index=columns, columns=index, dtype=dtype) + tm.assert_frame_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/base/test_unique.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/base/test_unique.py new file mode 100644 index 0000000000000000000000000000000000000000..d3fe144f70cfc2b54d978ab80ba23ef896948b9c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/base/test_unique.py @@ -0,0 +1,124 @@ +import numpy as np +import pytest + +from pandas._config import using_pyarrow_string_dtype + +import pandas as pd +import pandas._testing as tm +from pandas.tests.base.common import allow_na_ops + + +@pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning") +def test_unique(index_or_series_obj): + obj = index_or_series_obj + obj = np.repeat(obj, range(1, len(obj) + 1)) + result = obj.unique() + + # dict.fromkeys preserves the order + unique_values = list(dict.fromkeys(obj.values)) + if isinstance(obj, pd.MultiIndex): + expected = pd.MultiIndex.from_tuples(unique_values) + expected.names = obj.names + tm.assert_index_equal(result, expected, exact=True) + elif isinstance(obj, pd.Index): + expected = pd.Index(unique_values, dtype=obj.dtype) + if isinstance(obj.dtype, pd.DatetimeTZDtype): + expected = expected.normalize() + tm.assert_index_equal(result, expected, exact=True) + else: + expected = np.array(unique_values) + tm.assert_numpy_array_equal(result, expected) + + +@pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning") +@pytest.mark.parametrize("null_obj", [np.nan, None]) +def test_unique_null(null_obj, index_or_series_obj): + obj = index_or_series_obj + + if not allow_na_ops(obj): + pytest.skip("type doesn't allow for NA operations") + elif len(obj) < 1: + pytest.skip("Test doesn't make sense on empty data") + elif isinstance(obj, pd.MultiIndex): + pytest.skip(f"MultiIndex can't hold '{null_obj}'") + + values = obj._values + values[0:2] = null_obj + + klass = type(obj) + repeated_values = np.repeat(values, range(1, len(values) + 1)) + obj = klass(repeated_values, dtype=obj.dtype) + result = obj.unique() + + unique_values_raw = dict.fromkeys(obj.values) + # because np.nan == np.nan is False, but None == None is True + # np.nan would be duplicated, whereas None wouldn't + unique_values_not_null = [val for val in unique_values_raw if not pd.isnull(val)] + unique_values = [null_obj] + unique_values_not_null + + if isinstance(obj, pd.Index): + expected = pd.Index(unique_values, dtype=obj.dtype) + if isinstance(obj.dtype, pd.DatetimeTZDtype): + result = result.normalize() + expected = expected.normalize() + tm.assert_index_equal(result, expected, exact=True) + else: + expected = np.array(unique_values, dtype=obj.dtype) + tm.assert_numpy_array_equal(result, expected) + + +def test_nunique(index_or_series_obj): + obj = index_or_series_obj + obj = np.repeat(obj, range(1, len(obj) + 1)) + expected = len(obj.unique()) + assert obj.nunique(dropna=False) == expected + + +@pytest.mark.parametrize("null_obj", [np.nan, None]) +def test_nunique_null(null_obj, index_or_series_obj): + obj = index_or_series_obj + + if not allow_na_ops(obj): + pytest.skip("type doesn't allow for NA operations") + elif isinstance(obj, pd.MultiIndex): + pytest.skip(f"MultiIndex can't hold '{null_obj}'") + + values = obj._values + values[0:2] = null_obj + + klass = type(obj) + repeated_values = np.repeat(values, range(1, len(values) + 1)) + obj = klass(repeated_values, dtype=obj.dtype) + + if isinstance(obj, pd.CategoricalIndex): + assert obj.nunique() == len(obj.categories) + assert obj.nunique(dropna=False) == len(obj.categories) + 1 + else: + num_unique_values = len(obj.unique()) + assert obj.nunique() == max(0, num_unique_values - 1) + assert obj.nunique(dropna=False) == max(0, num_unique_values) + + +@pytest.mark.single_cpu +@pytest.mark.xfail(using_pyarrow_string_dtype(), reason="decoding fails") +def test_unique_bad_unicode(index_or_series): + # regression test for #34550 + uval = "\ud83d" # smiley emoji + + obj = index_or_series([uval] * 2) + result = obj.unique() + + if isinstance(obj, pd.Index): + expected = pd.Index(["\ud83d"], dtype=object) + tm.assert_index_equal(result, expected, exact=True) + else: + expected = np.array(["\ud83d"], dtype=object) + tm.assert_numpy_array_equal(result, expected) + + +@pytest.mark.parametrize("dropna", [True, False]) +def test_nunique_dropna(dropna): + # GH37566 + ser = pd.Series(["yes", "yes", pd.NA, np.nan, None, pd.NaT]) + res = ser.nunique(dropna) + assert res == 1 if dropna else 5 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/base/test_value_counts.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/base/test_value_counts.py new file mode 100644 index 0000000000000000000000000000000000000000..27296663988774ff01af6497060b63019c7deeb9 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/base/test_value_counts.py @@ -0,0 +1,356 @@ +import collections +from datetime import timedelta + +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + DatetimeIndex, + Index, + Interval, + IntervalIndex, + MultiIndex, + Series, + Timedelta, + TimedeltaIndex, + array, +) +import pandas._testing as tm +from pandas.tests.base.common import allow_na_ops + + +@pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning") +def test_value_counts(index_or_series_obj): + obj = index_or_series_obj + obj = np.repeat(obj, range(1, len(obj) + 1)) + result = obj.value_counts() + + counter = collections.Counter(obj) + expected = Series(dict(counter.most_common()), dtype=np.int64, name="count") + + if obj.dtype != np.float16: + expected.index = expected.index.astype(obj.dtype) + else: + with pytest.raises(NotImplementedError, match="float16 indexes are not "): + expected.index.astype(obj.dtype) + return + if isinstance(expected.index, MultiIndex): + expected.index.names = obj.names + else: + expected.index.name = obj.name + + if not isinstance(result.dtype, np.dtype): + if getattr(obj.dtype, "storage", "") == "pyarrow": + expected = expected.astype("int64[pyarrow]") + else: + # i.e IntegerDtype + expected = expected.astype("Int64") + + # TODO(GH#32514): Order of entries with the same count is inconsistent + # on CI (gh-32449) + if obj.duplicated().any(): + result = result.sort_index() + expected = expected.sort_index() + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("null_obj", [np.nan, None]) +@pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning") +def test_value_counts_null(null_obj, index_or_series_obj): + orig = index_or_series_obj + obj = orig.copy() + + if not allow_na_ops(obj): + pytest.skip("type doesn't allow for NA operations") + elif len(obj) < 1: + pytest.skip("Test doesn't make sense on empty data") + elif isinstance(orig, MultiIndex): + pytest.skip(f"MultiIndex can't hold '{null_obj}'") + + values = obj._values + values[0:2] = null_obj + + klass = type(obj) + repeated_values = np.repeat(values, range(1, len(values) + 1)) + obj = klass(repeated_values, dtype=obj.dtype) + + # because np.nan == np.nan is False, but None == None is True + # np.nan would be duplicated, whereas None wouldn't + counter = collections.Counter(obj.dropna()) + expected = Series(dict(counter.most_common()), dtype=np.int64, name="count") + + if obj.dtype != np.float16: + expected.index = expected.index.astype(obj.dtype) + else: + with pytest.raises(NotImplementedError, match="float16 indexes are not "): + expected.index.astype(obj.dtype) + return + expected.index.name = obj.name + + result = obj.value_counts() + if obj.duplicated().any(): + # TODO(GH#32514): + # Order of entries with the same count is inconsistent on CI (gh-32449) + expected = expected.sort_index() + result = result.sort_index() + + if not isinstance(result.dtype, np.dtype): + if getattr(obj.dtype, "storage", "") == "pyarrow": + expected = expected.astype("int64[pyarrow]") + else: + # i.e IntegerDtype + expected = expected.astype("Int64") + tm.assert_series_equal(result, expected) + + expected[null_obj] = 3 + + result = obj.value_counts(dropna=False) + if obj.duplicated().any(): + # TODO(GH#32514): + # Order of entries with the same count is inconsistent on CI (gh-32449) + expected = expected.sort_index() + result = result.sort_index() + tm.assert_series_equal(result, expected) + + +def test_value_counts_inferred(index_or_series, using_infer_string): + klass = index_or_series + s_values = ["a", "b", "b", "b", "b", "c", "d", "d", "a", "a"] + s = klass(s_values) + expected = Series([4, 3, 2, 1], index=["b", "a", "d", "c"], name="count") + tm.assert_series_equal(s.value_counts(), expected) + + if isinstance(s, Index): + exp = Index(np.unique(np.array(s_values, dtype=np.object_))) + tm.assert_index_equal(s.unique(), exp) + else: + exp = np.unique(np.array(s_values, dtype=np.object_)) + if using_infer_string: + exp = array(exp) + tm.assert_equal(s.unique(), exp) + + assert s.nunique() == 4 + # don't sort, have to sort after the fact as not sorting is + # platform-dep + hist = s.value_counts(sort=False).sort_values() + expected = Series([3, 1, 4, 2], index=list("acbd"), name="count").sort_values() + tm.assert_series_equal(hist, expected) + + # sort ascending + hist = s.value_counts(ascending=True) + expected = Series([1, 2, 3, 4], index=list("cdab"), name="count") + tm.assert_series_equal(hist, expected) + + # relative histogram. + hist = s.value_counts(normalize=True) + expected = Series( + [0.4, 0.3, 0.2, 0.1], index=["b", "a", "d", "c"], name="proportion" + ) + tm.assert_series_equal(hist, expected) + + +def test_value_counts_bins(index_or_series, using_infer_string): + klass = index_or_series + s_values = ["a", "b", "b", "b", "b", "c", "d", "d", "a", "a"] + s = klass(s_values) + + # bins + msg = "bins argument only works with numeric data" + with pytest.raises(TypeError, match=msg): + s.value_counts(bins=1) + + s1 = Series([1, 1, 2, 3]) + res1 = s1.value_counts(bins=1) + exp1 = Series({Interval(0.997, 3.0): 4}, name="count") + tm.assert_series_equal(res1, exp1) + res1n = s1.value_counts(bins=1, normalize=True) + exp1n = Series({Interval(0.997, 3.0): 1.0}, name="proportion") + tm.assert_series_equal(res1n, exp1n) + + if isinstance(s1, Index): + tm.assert_index_equal(s1.unique(), Index([1, 2, 3])) + else: + exp = np.array([1, 2, 3], dtype=np.int64) + tm.assert_numpy_array_equal(s1.unique(), exp) + + assert s1.nunique() == 3 + + # these return the same + res4 = s1.value_counts(bins=4, dropna=True) + intervals = IntervalIndex.from_breaks([0.997, 1.5, 2.0, 2.5, 3.0]) + exp4 = Series([2, 1, 1, 0], index=intervals.take([0, 1, 3, 2]), name="count") + tm.assert_series_equal(res4, exp4) + + res4 = s1.value_counts(bins=4, dropna=False) + intervals = IntervalIndex.from_breaks([0.997, 1.5, 2.0, 2.5, 3.0]) + exp4 = Series([2, 1, 1, 0], index=intervals.take([0, 1, 3, 2]), name="count") + tm.assert_series_equal(res4, exp4) + + res4n = s1.value_counts(bins=4, normalize=True) + exp4n = Series( + [0.5, 0.25, 0.25, 0], index=intervals.take([0, 1, 3, 2]), name="proportion" + ) + tm.assert_series_equal(res4n, exp4n) + + # handle NA's properly + s_values = ["a", "b", "b", "b", np.nan, np.nan, "d", "d", "a", "a", "b"] + s = klass(s_values) + expected = Series([4, 3, 2], index=["b", "a", "d"], name="count") + tm.assert_series_equal(s.value_counts(), expected) + + if isinstance(s, Index): + exp = Index(["a", "b", np.nan, "d"]) + tm.assert_index_equal(s.unique(), exp) + else: + exp = np.array(["a", "b", np.nan, "d"], dtype=object) + if using_infer_string: + exp = array(exp) + tm.assert_equal(s.unique(), exp) + assert s.nunique() == 3 + + s = klass({}) if klass is dict else klass({}, dtype=object) + expected = Series([], dtype=np.int64, name="count") + tm.assert_series_equal(s.value_counts(), expected, check_index_type=False) + # returned dtype differs depending on original + if isinstance(s, Index): + tm.assert_index_equal(s.unique(), Index([]), exact=False) + else: + tm.assert_numpy_array_equal(s.unique(), np.array([]), check_dtype=False) + + assert s.nunique() == 0 + + +def test_value_counts_datetime64(index_or_series, unit): + klass = index_or_series + + # GH 3002, datetime64[ns] + # don't test names though + df = pd.DataFrame( + { + "person_id": ["xxyyzz", "xxyyzz", "xxyyzz", "xxyyww", "foofoo", "foofoo"], + "dt": pd.to_datetime( + [ + "2010-01-01", + "2010-01-01", + "2010-01-01", + "2009-01-01", + "2008-09-09", + "2008-09-09", + ] + ).as_unit(unit), + "food": ["PIE", "GUM", "EGG", "EGG", "PIE", "GUM"], + } + ) + + s = klass(df["dt"].copy()) + s.name = None + idx = pd.to_datetime( + ["2010-01-01 00:00:00", "2008-09-09 00:00:00", "2009-01-01 00:00:00"] + ).as_unit(unit) + expected_s = Series([3, 2, 1], index=idx, name="count") + tm.assert_series_equal(s.value_counts(), expected_s) + + expected = array( + np.array( + ["2010-01-01 00:00:00", "2009-01-01 00:00:00", "2008-09-09 00:00:00"], + dtype=f"datetime64[{unit}]", + ) + ) + result = s.unique() + if isinstance(s, Index): + tm.assert_index_equal(result, DatetimeIndex(expected)) + else: + tm.assert_extension_array_equal(result, expected) + + assert s.nunique() == 3 + + # with NaT + s = df["dt"].copy() + s = klass(list(s.values) + [pd.NaT] * 4) + if klass is Series: + s = s.dt.as_unit(unit) + else: + s = s.as_unit(unit) + + result = s.value_counts() + assert result.index.dtype == f"datetime64[{unit}]" + tm.assert_series_equal(result, expected_s) + + result = s.value_counts(dropna=False) + expected_s = pd.concat( + [ + Series([4], index=DatetimeIndex([pd.NaT]).as_unit(unit), name="count"), + expected_s, + ] + ) + tm.assert_series_equal(result, expected_s) + + assert s.dtype == f"datetime64[{unit}]" + unique = s.unique() + assert unique.dtype == f"datetime64[{unit}]" + + # numpy_array_equal cannot compare pd.NaT + if isinstance(s, Index): + exp_idx = DatetimeIndex(expected.tolist() + [pd.NaT]).as_unit(unit) + tm.assert_index_equal(unique, exp_idx) + else: + tm.assert_extension_array_equal(unique[:3], expected) + assert pd.isna(unique[3]) + + assert s.nunique() == 3 + assert s.nunique(dropna=False) == 4 + + +def test_value_counts_timedelta64(index_or_series, unit): + # timedelta64[ns] + klass = index_or_series + + day = Timedelta(timedelta(1)).as_unit(unit) + tdi = TimedeltaIndex([day], name="dt").as_unit(unit) + + tdvals = np.zeros(6, dtype=f"m8[{unit}]") + day + td = klass(tdvals, name="dt") + + result = td.value_counts() + expected_s = Series([6], index=tdi, name="count") + tm.assert_series_equal(result, expected_s) + + expected = tdi + result = td.unique() + if isinstance(td, Index): + tm.assert_index_equal(result, expected) + else: + tm.assert_extension_array_equal(result, expected._values) + + td2 = day + np.zeros(6, dtype=f"m8[{unit}]") + td2 = klass(td2, name="dt") + result2 = td2.value_counts() + tm.assert_series_equal(result2, expected_s) + + +@pytest.mark.parametrize("dropna", [True, False]) +def test_value_counts_with_nan(dropna, index_or_series): + # GH31944 + klass = index_or_series + values = [True, pd.NA, np.nan] + obj = klass(values) + res = obj.value_counts(dropna=dropna) + if dropna is True: + expected = Series([1], index=Index([True], dtype=obj.dtype), name="count") + else: + expected = Series([1, 1, 1], index=[True, pd.NA, np.nan], name="count") + tm.assert_series_equal(res, expected) + + +def test_value_counts_object_inference_deprecated(): + # GH#56161 + dti = pd.date_range("2016-01-01", periods=3, tz="UTC") + + idx = dti.astype(object) + msg = "The behavior of value_counts with object-dtype is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + res = idx.value_counts() + + exp = dti.value_counts() + tm.assert_series_equal(res, exp) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/computation/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/computation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/computation/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/computation/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c9818ab949fd2bd527b87a65b2c7f2a721a6b81 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/computation/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/computation/test_compat.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/computation/test_compat.py new file mode 100644 index 0000000000000000000000000000000000000000..856a5b3a22a95d35cc577050f52d762b065e3ddf --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/computation/test_compat.py @@ -0,0 +1,32 @@ +import pytest + +from pandas.compat._optional import VERSIONS + +import pandas as pd +from pandas.core.computation import expr +from pandas.core.computation.engines import ENGINES +from pandas.util.version import Version + + +def test_compat(): + # test we have compat with our version of numexpr + + from pandas.core.computation.check import NUMEXPR_INSTALLED + + ne = pytest.importorskip("numexpr") + + ver = ne.__version__ + if Version(ver) < Version(VERSIONS["numexpr"]): + assert not NUMEXPR_INSTALLED + else: + assert NUMEXPR_INSTALLED + + +@pytest.mark.parametrize("engine", ENGINES) +@pytest.mark.parametrize("parser", expr.PARSERS) +def test_invalid_numexpr_version(engine, parser): + if engine == "numexpr": + pytest.importorskip("numexpr") + a, b = 1, 2 # noqa: F841 + res = pd.eval("a + b", engine=engine, parser=parser) + assert res == 3 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/computation/test_eval.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/computation/test_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..e8fad6b8cbd63a042d69de8a6a08c1ca1f35b12e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/computation/test_eval.py @@ -0,0 +1,2001 @@ +from __future__ import annotations + +from functools import reduce +from itertools import product +import operator + +import numpy as np +import pytest + +from pandas.compat import PY312 +from pandas.errors import ( + NumExprClobberingError, + PerformanceWarning, + UndefinedVariableError, +) +import pandas.util._test_decorators as td + +from pandas.core.dtypes.common import ( + is_bool, + is_float, + is_list_like, + is_scalar, +) + +import pandas as pd +from pandas import ( + DataFrame, + Index, + Series, + date_range, + period_range, + timedelta_range, +) +import pandas._testing as tm +from pandas.core.computation import ( + expr, + pytables, +) +from pandas.core.computation.engines import ENGINES +from pandas.core.computation.expr import ( + BaseExprVisitor, + PandasExprVisitor, + PythonExprVisitor, +) +from pandas.core.computation.expressions import ( + NUMEXPR_INSTALLED, + USE_NUMEXPR, +) +from pandas.core.computation.ops import ( + ARITH_OPS_SYMS, + SPECIAL_CASE_ARITH_OPS_SYMS, + _binary_math_ops, + _binary_ops_dict, + _unary_math_ops, +) +from pandas.core.computation.scope import DEFAULT_GLOBALS + + +@pytest.fixture( + params=( + pytest.param( + engine, + marks=[ + pytest.mark.skipif( + engine == "numexpr" and not USE_NUMEXPR, + reason=f"numexpr enabled->{USE_NUMEXPR}, " + f"installed->{NUMEXPR_INSTALLED}", + ), + td.skip_if_no("numexpr"), + ], + ) + for engine in ENGINES + ) +) +def engine(request): + return request.param + + +@pytest.fixture(params=expr.PARSERS) +def parser(request): + return request.param + + +def _eval_single_bin(lhs, cmp1, rhs, engine): + c = _binary_ops_dict[cmp1] + if ENGINES[engine].has_neg_frac: + try: + return c(lhs, rhs) + except ValueError as e: + if str(e).startswith( + "negative number cannot be raised to a fractional power" + ): + return np.nan + raise + return c(lhs, rhs) + + +# TODO: using range(5) here is a kludge +@pytest.fixture( + params=list(range(5)), + ids=["DataFrame", "Series", "SeriesNaN", "DataFrameNaN", "float"], +) +def lhs(request): + nan_df1 = DataFrame(np.random.default_rng(2).standard_normal((10, 5))) + nan_df1[nan_df1 > 0.5] = np.nan + + opts = ( + DataFrame(np.random.default_rng(2).standard_normal((10, 5))), + Series(np.random.default_rng(2).standard_normal(5)), + Series([1, 2, np.nan, np.nan, 5]), + nan_df1, + np.random.default_rng(2).standard_normal(), + ) + return opts[request.param] + + +rhs = lhs +midhs = lhs + + +@pytest.fixture +def idx_func_dict(): + return { + "i": lambda n: Index(np.arange(n), dtype=np.int64), + "f": lambda n: Index(np.arange(n), dtype=np.float64), + "s": lambda n: Index([f"{i}_{chr(i)}" for i in range(97, 97 + n)]), + "dt": lambda n: date_range("2020-01-01", periods=n), + "td": lambda n: timedelta_range("1 day", periods=n), + "p": lambda n: period_range("2020-01-01", periods=n, freq="D"), + } + + +class TestEval: + @pytest.mark.parametrize( + "cmp1", + ["!=", "==", "<=", ">=", "<", ">"], + ids=["ne", "eq", "le", "ge", "lt", "gt"], + ) + @pytest.mark.parametrize("cmp2", [">", "<"], ids=["gt", "lt"]) + @pytest.mark.parametrize("binop", expr.BOOL_OPS_SYMS) + def test_complex_cmp_ops(self, cmp1, cmp2, binop, lhs, rhs, engine, parser): + if parser == "python" and binop in ["and", "or"]: + msg = "'BoolOp' nodes are not implemented" + with pytest.raises(NotImplementedError, match=msg): + ex = f"(lhs {cmp1} rhs) {binop} (lhs {cmp2} rhs)" + pd.eval(ex, engine=engine, parser=parser) + return + + lhs_new = _eval_single_bin(lhs, cmp1, rhs, engine) + rhs_new = _eval_single_bin(lhs, cmp2, rhs, engine) + expected = _eval_single_bin(lhs_new, binop, rhs_new, engine) + + ex = f"(lhs {cmp1} rhs) {binop} (lhs {cmp2} rhs)" + result = pd.eval(ex, engine=engine, parser=parser) + tm.assert_equal(result, expected) + + @pytest.mark.parametrize("cmp_op", expr.CMP_OPS_SYMS) + def test_simple_cmp_ops(self, cmp_op, lhs, rhs, engine, parser): + lhs = lhs < 0 + rhs = rhs < 0 + + if parser == "python" and cmp_op in ["in", "not in"]: + msg = "'(In|NotIn)' nodes are not implemented" + + with pytest.raises(NotImplementedError, match=msg): + ex = f"lhs {cmp_op} rhs" + pd.eval(ex, engine=engine, parser=parser) + return + + ex = f"lhs {cmp_op} rhs" + msg = "|".join( + [ + r"only list-like( or dict-like)? objects are allowed to be " + r"passed to (DataFrame\.)?isin\(\), you passed a " + r"(`|')bool(`|')", + "argument of type 'bool' is not iterable", + ] + ) + if cmp_op in ("in", "not in") and not is_list_like(rhs): + with pytest.raises(TypeError, match=msg): + pd.eval( + ex, + engine=engine, + parser=parser, + local_dict={"lhs": lhs, "rhs": rhs}, + ) + else: + expected = _eval_single_bin(lhs, cmp_op, rhs, engine) + result = pd.eval(ex, engine=engine, parser=parser) + tm.assert_equal(result, expected) + + @pytest.mark.parametrize("op", expr.CMP_OPS_SYMS) + def test_compound_invert_op(self, op, lhs, rhs, request, engine, parser): + if parser == "python" and op in ["in", "not in"]: + msg = "'(In|NotIn)' nodes are not implemented" + with pytest.raises(NotImplementedError, match=msg): + ex = f"~(lhs {op} rhs)" + pd.eval(ex, engine=engine, parser=parser) + return + + if ( + is_float(lhs) + and not is_float(rhs) + and op in ["in", "not in"] + and engine == "python" + and parser == "pandas" + ): + mark = pytest.mark.xfail( + reason="Looks like expected is negative, unclear whether " + "expected is incorrect or result is incorrect" + ) + request.applymarker(mark) + skip_these = ["in", "not in"] + ex = f"~(lhs {op} rhs)" + + msg = "|".join( + [ + r"only list-like( or dict-like)? objects are allowed to be " + r"passed to (DataFrame\.)?isin\(\), you passed a " + r"(`|')float(`|')", + "argument of type 'float' is not iterable", + ] + ) + if is_scalar(rhs) and op in skip_these: + with pytest.raises(TypeError, match=msg): + pd.eval( + ex, + engine=engine, + parser=parser, + local_dict={"lhs": lhs, "rhs": rhs}, + ) + else: + # compound + if is_scalar(lhs) and is_scalar(rhs): + lhs, rhs = (np.array([x]) for x in (lhs, rhs)) + expected = _eval_single_bin(lhs, op, rhs, engine) + if is_scalar(expected): + expected = not expected + else: + expected = ~expected + result = pd.eval(ex, engine=engine, parser=parser) + tm.assert_almost_equal(expected, result) + + @pytest.mark.parametrize("cmp1", ["<", ">"]) + @pytest.mark.parametrize("cmp2", ["<", ">"]) + def test_chained_cmp_op(self, cmp1, cmp2, lhs, midhs, rhs, engine, parser): + mid = midhs + if parser == "python": + ex1 = f"lhs {cmp1} mid {cmp2} rhs" + msg = "'BoolOp' nodes are not implemented" + with pytest.raises(NotImplementedError, match=msg): + pd.eval(ex1, engine=engine, parser=parser) + return + + lhs_new = _eval_single_bin(lhs, cmp1, mid, engine) + rhs_new = _eval_single_bin(mid, cmp2, rhs, engine) + + if lhs_new is not None and rhs_new is not None: + ex1 = f"lhs {cmp1} mid {cmp2} rhs" + ex2 = f"lhs {cmp1} mid and mid {cmp2} rhs" + ex3 = f"(lhs {cmp1} mid) & (mid {cmp2} rhs)" + expected = _eval_single_bin(lhs_new, "&", rhs_new, engine) + + for ex in (ex1, ex2, ex3): + result = pd.eval(ex, engine=engine, parser=parser) + + tm.assert_almost_equal(result, expected) + + @pytest.mark.parametrize( + "arith1", sorted(set(ARITH_OPS_SYMS).difference(SPECIAL_CASE_ARITH_OPS_SYMS)) + ) + def test_binary_arith_ops(self, arith1, lhs, rhs, engine, parser): + ex = f"lhs {arith1} rhs" + result = pd.eval(ex, engine=engine, parser=parser) + expected = _eval_single_bin(lhs, arith1, rhs, engine) + + tm.assert_almost_equal(result, expected) + ex = f"lhs {arith1} rhs {arith1} rhs" + result = pd.eval(ex, engine=engine, parser=parser) + nlhs = _eval_single_bin(lhs, arith1, rhs, engine) + try: + nlhs, ghs = nlhs.align(rhs) + except (ValueError, TypeError, AttributeError): + # ValueError: series frame or frame series align + # TypeError, AttributeError: series or frame with scalar align + return + else: + if engine == "numexpr": + import numexpr as ne + + # direct numpy comparison + expected = ne.evaluate(f"nlhs {arith1} ghs") + # Update assert statement due to unreliable numerical + # precision component (GH37328) + # TODO: update testing code so that assert_almost_equal statement + # can be replaced again by the assert_numpy_array_equal statement + tm.assert_almost_equal(result.values, expected) + else: + expected = eval(f"nlhs {arith1} ghs") + tm.assert_almost_equal(result, expected) + + # modulus, pow, and floor division require special casing + + def test_modulus(self, lhs, rhs, engine, parser): + ex = r"lhs % rhs" + result = pd.eval(ex, engine=engine, parser=parser) + expected = lhs % rhs + tm.assert_almost_equal(result, expected) + + if engine == "numexpr": + import numexpr as ne + + expected = ne.evaluate(r"expected % rhs") + if isinstance(result, (DataFrame, Series)): + tm.assert_almost_equal(result.values, expected) + else: + tm.assert_almost_equal(result, expected.item()) + else: + expected = _eval_single_bin(expected, "%", rhs, engine) + tm.assert_almost_equal(result, expected) + + def test_floor_division(self, lhs, rhs, engine, parser): + ex = "lhs // rhs" + + if engine == "python": + res = pd.eval(ex, engine=engine, parser=parser) + expected = lhs // rhs + tm.assert_equal(res, expected) + else: + msg = ( + r"unsupported operand type\(s\) for //: 'VariableNode' and " + "'VariableNode'" + ) + with pytest.raises(TypeError, match=msg): + pd.eval( + ex, + local_dict={"lhs": lhs, "rhs": rhs}, + engine=engine, + parser=parser, + ) + + @td.skip_if_windows + def test_pow(self, lhs, rhs, engine, parser): + # odd failure on win32 platform, so skip + ex = "lhs ** rhs" + expected = _eval_single_bin(lhs, "**", rhs, engine) + result = pd.eval(ex, engine=engine, parser=parser) + + if ( + is_scalar(lhs) + and is_scalar(rhs) + and isinstance(expected, (complex, np.complexfloating)) + and np.isnan(result) + ): + msg = "(DataFrame.columns|numpy array) are different" + with pytest.raises(AssertionError, match=msg): + tm.assert_numpy_array_equal(result, expected) + else: + tm.assert_almost_equal(result, expected) + + ex = "(lhs ** rhs) ** rhs" + result = pd.eval(ex, engine=engine, parser=parser) + + middle = _eval_single_bin(lhs, "**", rhs, engine) + expected = _eval_single_bin(middle, "**", rhs, engine) + tm.assert_almost_equal(result, expected) + + def test_check_single_invert_op(self, lhs, engine, parser): + # simple + try: + elb = lhs.astype(bool) + except AttributeError: + elb = np.array([bool(lhs)]) + expected = ~elb + result = pd.eval("~elb", engine=engine, parser=parser) + tm.assert_almost_equal(expected, result) + + def test_frame_invert(self, engine, parser): + expr = "~lhs" + + # ~ ## + # frame + # float always raises + lhs = DataFrame(np.random.default_rng(2).standard_normal((5, 2))) + if engine == "numexpr": + msg = "couldn't find matching opcode for 'invert_dd'" + with pytest.raises(NotImplementedError, match=msg): + pd.eval(expr, engine=engine, parser=parser) + else: + msg = "ufunc 'invert' not supported for the input types" + with pytest.raises(TypeError, match=msg): + pd.eval(expr, engine=engine, parser=parser) + + # int raises on numexpr + lhs = DataFrame(np.random.default_rng(2).integers(5, size=(5, 2))) + if engine == "numexpr": + msg = "couldn't find matching opcode for 'invert" + with pytest.raises(NotImplementedError, match=msg): + pd.eval(expr, engine=engine, parser=parser) + else: + expect = ~lhs + result = pd.eval(expr, engine=engine, parser=parser) + tm.assert_frame_equal(expect, result) + + # bool always works + lhs = DataFrame(np.random.default_rng(2).standard_normal((5, 2)) > 0.5) + expect = ~lhs + result = pd.eval(expr, engine=engine, parser=parser) + tm.assert_frame_equal(expect, result) + + # object raises + lhs = DataFrame( + {"b": ["a", 1, 2.0], "c": np.random.default_rng(2).standard_normal(3) > 0.5} + ) + if engine == "numexpr": + with pytest.raises(ValueError, match="unknown type object"): + pd.eval(expr, engine=engine, parser=parser) + else: + msg = "bad operand type for unary ~: 'str'" + with pytest.raises(TypeError, match=msg): + pd.eval(expr, engine=engine, parser=parser) + + def test_series_invert(self, engine, parser): + # ~ #### + expr = "~lhs" + + # series + # float raises + lhs = Series(np.random.default_rng(2).standard_normal(5)) + if engine == "numexpr": + msg = "couldn't find matching opcode for 'invert_dd'" + with pytest.raises(NotImplementedError, match=msg): + result = pd.eval(expr, engine=engine, parser=parser) + else: + msg = "ufunc 'invert' not supported for the input types" + with pytest.raises(TypeError, match=msg): + pd.eval(expr, engine=engine, parser=parser) + + # int raises on numexpr + lhs = Series(np.random.default_rng(2).integers(5, size=5)) + if engine == "numexpr": + msg = "couldn't find matching opcode for 'invert" + with pytest.raises(NotImplementedError, match=msg): + pd.eval(expr, engine=engine, parser=parser) + else: + expect = ~lhs + result = pd.eval(expr, engine=engine, parser=parser) + tm.assert_series_equal(expect, result) + + # bool + lhs = Series(np.random.default_rng(2).standard_normal(5) > 0.5) + expect = ~lhs + result = pd.eval(expr, engine=engine, parser=parser) + tm.assert_series_equal(expect, result) + + # float + # int + # bool + + # object + lhs = Series(["a", 1, 2.0]) + if engine == "numexpr": + with pytest.raises(ValueError, match="unknown type object"): + pd.eval(expr, engine=engine, parser=parser) + else: + msg = "bad operand type for unary ~: 'str'" + with pytest.raises(TypeError, match=msg): + pd.eval(expr, engine=engine, parser=parser) + + def test_frame_negate(self, engine, parser): + expr = "-lhs" + + # float + lhs = DataFrame(np.random.default_rng(2).standard_normal((5, 2))) + expect = -lhs + result = pd.eval(expr, engine=engine, parser=parser) + tm.assert_frame_equal(expect, result) + + # int + lhs = DataFrame(np.random.default_rng(2).integers(5, size=(5, 2))) + expect = -lhs + result = pd.eval(expr, engine=engine, parser=parser) + tm.assert_frame_equal(expect, result) + + # bool doesn't work with numexpr but works elsewhere + lhs = DataFrame(np.random.default_rng(2).standard_normal((5, 2)) > 0.5) + if engine == "numexpr": + msg = "couldn't find matching opcode for 'neg_bb'" + with pytest.raises(NotImplementedError, match=msg): + pd.eval(expr, engine=engine, parser=parser) + else: + expect = -lhs + result = pd.eval(expr, engine=engine, parser=parser) + tm.assert_frame_equal(expect, result) + + def test_series_negate(self, engine, parser): + expr = "-lhs" + + # float + lhs = Series(np.random.default_rng(2).standard_normal(5)) + expect = -lhs + result = pd.eval(expr, engine=engine, parser=parser) + tm.assert_series_equal(expect, result) + + # int + lhs = Series(np.random.default_rng(2).integers(5, size=5)) + expect = -lhs + result = pd.eval(expr, engine=engine, parser=parser) + tm.assert_series_equal(expect, result) + + # bool doesn't work with numexpr but works elsewhere + lhs = Series(np.random.default_rng(2).standard_normal(5) > 0.5) + if engine == "numexpr": + msg = "couldn't find matching opcode for 'neg_bb'" + with pytest.raises(NotImplementedError, match=msg): + pd.eval(expr, engine=engine, parser=parser) + else: + expect = -lhs + result = pd.eval(expr, engine=engine, parser=parser) + tm.assert_series_equal(expect, result) + + @pytest.mark.parametrize( + "lhs", + [ + # Float + DataFrame(np.random.default_rng(2).standard_normal((5, 2))), + # Int + DataFrame(np.random.default_rng(2).integers(5, size=(5, 2))), + # bool doesn't work with numexpr but works elsewhere + DataFrame(np.random.default_rng(2).standard_normal((5, 2)) > 0.5), + ], + ) + def test_frame_pos(self, lhs, engine, parser): + expr = "+lhs" + expect = lhs + + result = pd.eval(expr, engine=engine, parser=parser) + tm.assert_frame_equal(expect, result) + + @pytest.mark.parametrize( + "lhs", + [ + # Float + Series(np.random.default_rng(2).standard_normal(5)), + # Int + Series(np.random.default_rng(2).integers(5, size=5)), + # bool doesn't work with numexpr but works elsewhere + Series(np.random.default_rng(2).standard_normal(5) > 0.5), + ], + ) + def test_series_pos(self, lhs, engine, parser): + expr = "+lhs" + expect = lhs + + result = pd.eval(expr, engine=engine, parser=parser) + tm.assert_series_equal(expect, result) + + def test_scalar_unary(self, engine, parser): + msg = "bad operand type for unary ~: 'float'" + warn = None + if PY312 and not (engine == "numexpr" and parser == "pandas"): + warn = DeprecationWarning + with pytest.raises(TypeError, match=msg): + pd.eval("~1.0", engine=engine, parser=parser) + + assert pd.eval("-1.0", parser=parser, engine=engine) == -1.0 + assert pd.eval("+1.0", parser=parser, engine=engine) == +1.0 + assert pd.eval("~1", parser=parser, engine=engine) == ~1 + assert pd.eval("-1", parser=parser, engine=engine) == -1 + assert pd.eval("+1", parser=parser, engine=engine) == +1 + with tm.assert_produces_warning( + warn, match="Bitwise inversion", check_stacklevel=False + ): + assert pd.eval("~True", parser=parser, engine=engine) == ~True + with tm.assert_produces_warning( + warn, match="Bitwise inversion", check_stacklevel=False + ): + assert pd.eval("~False", parser=parser, engine=engine) == ~False + assert pd.eval("-True", parser=parser, engine=engine) == -True + assert pd.eval("-False", parser=parser, engine=engine) == -False + assert pd.eval("+True", parser=parser, engine=engine) == +True + assert pd.eval("+False", parser=parser, engine=engine) == +False + + def test_unary_in_array(self): + # GH 11235 + # TODO: 2022-01-29: result return list with numexpr 2.7.3 in CI + # but cannot reproduce locally + result = np.array( + pd.eval("[-True, True, +True, -False, False, +False, -37, 37, ~37, +37]"), + dtype=np.object_, + ) + expected = np.array( + [ + -True, + True, + +True, + -False, + False, + +False, + -37, + 37, + ~37, + +37, + ], + dtype=np.object_, + ) + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize("dtype", [np.float32, np.float64]) + @pytest.mark.parametrize("expr", ["x < -0.1", "-5 > x"]) + def test_float_comparison_bin_op(self, dtype, expr): + # GH 16363 + df = DataFrame({"x": np.array([0], dtype=dtype)}) + res = df.eval(expr) + assert res.values == np.array([False]) + + def test_unary_in_function(self): + # GH 46471 + df = DataFrame({"x": [0, 1, np.nan]}) + + result = df.eval("x.fillna(-1)") + expected = df.x.fillna(-1) + # column name becomes None if using numexpr + # only check names when the engine is not numexpr + tm.assert_series_equal(result, expected, check_names=not USE_NUMEXPR) + + result = df.eval("x.shift(1, fill_value=-1)") + expected = df.x.shift(1, fill_value=-1) + tm.assert_series_equal(result, expected, check_names=not USE_NUMEXPR) + + @pytest.mark.parametrize( + "ex", + ( + "1 or 2", + "1 and 2", + "a and b", + "a or b", + "1 or 2 and (3 + 2) > 3", + "2 * x > 2 or 1 and 2", + "2 * df > 3 and 1 or a", + ), + ) + def test_disallow_scalar_bool_ops(self, ex, engine, parser): + x, a, b = np.random.default_rng(2).standard_normal(3), 1, 2 # noqa: F841 + df = DataFrame(np.random.default_rng(2).standard_normal((3, 2))) # noqa: F841 + + msg = "cannot evaluate scalar only bool ops|'BoolOp' nodes are not" + with pytest.raises(NotImplementedError, match=msg): + pd.eval(ex, engine=engine, parser=parser) + + def test_identical(self, engine, parser): + # see gh-10546 + x = 1 + result = pd.eval("x", engine=engine, parser=parser) + assert result == 1 + assert is_scalar(result) + + x = 1.5 + result = pd.eval("x", engine=engine, parser=parser) + assert result == 1.5 + assert is_scalar(result) + + x = False + result = pd.eval("x", engine=engine, parser=parser) + assert not result + assert is_bool(result) + assert is_scalar(result) + + x = np.array([1]) + result = pd.eval("x", engine=engine, parser=parser) + tm.assert_numpy_array_equal(result, np.array([1])) + assert result.shape == (1,) + + x = np.array([1.5]) + result = pd.eval("x", engine=engine, parser=parser) + tm.assert_numpy_array_equal(result, np.array([1.5])) + assert result.shape == (1,) + + x = np.array([False]) # noqa: F841 + result = pd.eval("x", engine=engine, parser=parser) + tm.assert_numpy_array_equal(result, np.array([False])) + assert result.shape == (1,) + + def test_line_continuation(self, engine, parser): + # GH 11149 + exp = """1 + 2 * \ + 5 - 1 + 2 """ + result = pd.eval(exp, engine=engine, parser=parser) + assert result == 12 + + def test_float_truncation(self, engine, parser): + # GH 14241 + exp = "1000000000.006" + result = pd.eval(exp, engine=engine, parser=parser) + expected = np.float64(exp) + assert result == expected + + df = DataFrame({"A": [1000000000.0009, 1000000000.0011, 1000000000.0015]}) + cutoff = 1000000000.0006 + result = df.query(f"A < {cutoff:.4f}") + assert result.empty + + cutoff = 1000000000.0010 + result = df.query(f"A > {cutoff:.4f}") + expected = df.loc[[1, 2], :] + tm.assert_frame_equal(expected, result) + + exact = 1000000000.0011 + result = df.query(f"A == {exact:.4f}") + expected = df.loc[[1], :] + tm.assert_frame_equal(expected, result) + + def test_disallow_python_keywords(self): + # GH 18221 + df = DataFrame([[0, 0, 0]], columns=["foo", "bar", "class"]) + msg = "Python keyword not valid identifier in numexpr query" + with pytest.raises(SyntaxError, match=msg): + df.query("class == 0") + + df = DataFrame() + df.index.name = "lambda" + with pytest.raises(SyntaxError, match=msg): + df.query("lambda == 0") + + def test_true_false_logic(self): + # GH 25823 + # This behavior is deprecated in Python 3.12 + with tm.maybe_produces_warning( + DeprecationWarning, PY312, check_stacklevel=False + ): + assert pd.eval("not True") == -2 + assert pd.eval("not False") == -1 + assert pd.eval("True and not True") == 0 + + def test_and_logic_string_match(self): + # GH 25823 + event = Series({"a": "hello"}) + assert pd.eval(f"{event.str.match('hello').a}") + assert pd.eval(f"{event.str.match('hello').a and event.str.match('hello').a}") + + +# ------------------------------------- +# gh-12388: Typecasting rules consistency with python + + +class TestTypeCasting: + @pytest.mark.parametrize("op", ["+", "-", "*", "**", "/"]) + # maybe someday... numexpr has too many upcasting rules now + # chain(*(np.core.sctypes[x] for x in ['uint', 'int', 'float'])) + @pytest.mark.parametrize("left_right", [("df", "3"), ("3", "df")]) + def test_binop_typecasting( + self, engine, parser, op, complex_or_float_dtype, left_right, request + ): + # GH#21374 + dtype = complex_or_float_dtype + df = DataFrame(np.random.default_rng(2).standard_normal((5, 3)), dtype=dtype) + left, right = left_right + s = f"{left} {op} {right}" + res = pd.eval(s, engine=engine, parser=parser) + if dtype == "complex64" and engine == "numexpr": + mark = pytest.mark.xfail( + reason="numexpr issue with complex that are upcast " + "to complex 128 " + "https://github.com/pydata/numexpr/issues/492" + ) + request.applymarker(mark) + assert df.values.dtype == dtype + assert res.values.dtype == dtype + tm.assert_frame_equal(res, eval(s), check_exact=False) + + +# ------------------------------------- +# Basic and complex alignment + + +def should_warn(*args): + not_mono = not any(map(operator.attrgetter("is_monotonic_increasing"), args)) + only_one_dt = reduce( + operator.xor, (issubclass(x.dtype.type, np.datetime64) for x in args) + ) + return not_mono and only_one_dt + + +class TestAlignment: + index_types = ["i", "s", "dt"] + lhs_index_types = index_types + ["s"] # 'p' + + def test_align_nested_unary_op(self, engine, parser): + s = "df * ~2" + df = DataFrame(np.random.default_rng(2).standard_normal((5, 3))) + res = pd.eval(s, engine=engine, parser=parser) + tm.assert_frame_equal(res, df * ~2) + + @pytest.mark.filterwarnings("always::RuntimeWarning") + @pytest.mark.parametrize("lr_idx_type", lhs_index_types) + @pytest.mark.parametrize("rr_idx_type", index_types) + @pytest.mark.parametrize("c_idx_type", index_types) + def test_basic_frame_alignment( + self, engine, parser, lr_idx_type, rr_idx_type, c_idx_type, idx_func_dict + ): + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 10)), + index=idx_func_dict[lr_idx_type](10), + columns=idx_func_dict[c_idx_type](10), + ) + df2 = DataFrame( + np.random.default_rng(2).standard_normal((20, 10)), + index=idx_func_dict[rr_idx_type](20), + columns=idx_func_dict[c_idx_type](10), + ) + # only warns if not monotonic and not sortable + if should_warn(df.index, df2.index): + with tm.assert_produces_warning(RuntimeWarning): + res = pd.eval("df + df2", engine=engine, parser=parser) + else: + res = pd.eval("df + df2", engine=engine, parser=parser) + tm.assert_frame_equal(res, df + df2) + + @pytest.mark.parametrize("r_idx_type", lhs_index_types) + @pytest.mark.parametrize("c_idx_type", lhs_index_types) + def test_frame_comparison( + self, engine, parser, r_idx_type, c_idx_type, idx_func_dict + ): + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 10)), + index=idx_func_dict[r_idx_type](10), + columns=idx_func_dict[c_idx_type](10), + ) + res = pd.eval("df < 2", engine=engine, parser=parser) + tm.assert_frame_equal(res, df < 2) + + df3 = DataFrame( + np.random.default_rng(2).standard_normal(df.shape), + index=df.index, + columns=df.columns, + ) + res = pd.eval("df < df3", engine=engine, parser=parser) + tm.assert_frame_equal(res, df < df3) + + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + @pytest.mark.parametrize("r1", lhs_index_types) + @pytest.mark.parametrize("c1", index_types) + @pytest.mark.parametrize("r2", index_types) + @pytest.mark.parametrize("c2", index_types) + def test_medium_complex_frame_alignment( + self, engine, parser, r1, c1, r2, c2, idx_func_dict + ): + df = DataFrame( + np.random.default_rng(2).standard_normal((3, 2)), + index=idx_func_dict[r1](3), + columns=idx_func_dict[c1](2), + ) + df2 = DataFrame( + np.random.default_rng(2).standard_normal((4, 2)), + index=idx_func_dict[r2](4), + columns=idx_func_dict[c2](2), + ) + df3 = DataFrame( + np.random.default_rng(2).standard_normal((5, 2)), + index=idx_func_dict[r2](5), + columns=idx_func_dict[c2](2), + ) + if should_warn(df.index, df2.index, df3.index): + with tm.assert_produces_warning(RuntimeWarning): + res = pd.eval("df + df2 + df3", engine=engine, parser=parser) + else: + res = pd.eval("df + df2 + df3", engine=engine, parser=parser) + tm.assert_frame_equal(res, df + df2 + df3) + + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + @pytest.mark.parametrize("index_name", ["index", "columns"]) + @pytest.mark.parametrize("c_idx_type", index_types) + @pytest.mark.parametrize("r_idx_type", lhs_index_types) + def test_basic_frame_series_alignment( + self, engine, parser, index_name, r_idx_type, c_idx_type, idx_func_dict + ): + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 10)), + index=idx_func_dict[r_idx_type](10), + columns=idx_func_dict[c_idx_type](10), + ) + index = getattr(df, index_name) + s = Series(np.random.default_rng(2).standard_normal(5), index[:5]) + + if should_warn(df.index, s.index): + with tm.assert_produces_warning(RuntimeWarning): + res = pd.eval("df + s", engine=engine, parser=parser) + else: + res = pd.eval("df + s", engine=engine, parser=parser) + + if r_idx_type == "dt" or c_idx_type == "dt": + expected = df.add(s) if engine == "numexpr" else df + s + else: + expected = df + s + tm.assert_frame_equal(res, expected) + + @pytest.mark.parametrize("index_name", ["index", "columns"]) + @pytest.mark.parametrize( + "r_idx_type, c_idx_type", + list(product(["i", "s"], ["i", "s"])) + [("dt", "dt")], + ) + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + def test_basic_series_frame_alignment( + self, request, engine, parser, index_name, r_idx_type, c_idx_type, idx_func_dict + ): + if ( + engine == "numexpr" + and parser in ("pandas", "python") + and index_name == "index" + and r_idx_type == "i" + and c_idx_type == "s" + ): + reason = ( + f"Flaky column ordering when engine={engine}, " + f"parser={parser}, index_name={index_name}, " + f"r_idx_type={r_idx_type}, c_idx_type={c_idx_type}" + ) + request.applymarker(pytest.mark.xfail(reason=reason, strict=False)) + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 7)), + index=idx_func_dict[r_idx_type](10), + columns=idx_func_dict[c_idx_type](7), + ) + index = getattr(df, index_name) + s = Series(np.random.default_rng(2).standard_normal(5), index[:5]) + if should_warn(s.index, df.index): + with tm.assert_produces_warning(RuntimeWarning): + res = pd.eval("s + df", engine=engine, parser=parser) + else: + res = pd.eval("s + df", engine=engine, parser=parser) + + if r_idx_type == "dt" or c_idx_type == "dt": + expected = df.add(s) if engine == "numexpr" else s + df + else: + expected = s + df + tm.assert_frame_equal(res, expected) + + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + @pytest.mark.parametrize("c_idx_type", index_types) + @pytest.mark.parametrize("r_idx_type", lhs_index_types) + @pytest.mark.parametrize("index_name", ["index", "columns"]) + @pytest.mark.parametrize("op", ["+", "*"]) + def test_series_frame_commutativity( + self, engine, parser, index_name, op, r_idx_type, c_idx_type, idx_func_dict + ): + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 10)), + index=idx_func_dict[r_idx_type](10), + columns=idx_func_dict[c_idx_type](10), + ) + index = getattr(df, index_name) + s = Series(np.random.default_rng(2).standard_normal(5), index[:5]) + + lhs = f"s {op} df" + rhs = f"df {op} s" + if should_warn(df.index, s.index): + with tm.assert_produces_warning(RuntimeWarning): + a = pd.eval(lhs, engine=engine, parser=parser) + with tm.assert_produces_warning(RuntimeWarning): + b = pd.eval(rhs, engine=engine, parser=parser) + else: + a = pd.eval(lhs, engine=engine, parser=parser) + b = pd.eval(rhs, engine=engine, parser=parser) + + if r_idx_type != "dt" and c_idx_type != "dt": + if engine == "numexpr": + tm.assert_frame_equal(a, b) + + @pytest.mark.filterwarnings("always::RuntimeWarning") + @pytest.mark.parametrize("r1", lhs_index_types) + @pytest.mark.parametrize("c1", index_types) + @pytest.mark.parametrize("r2", index_types) + @pytest.mark.parametrize("c2", index_types) + def test_complex_series_frame_alignment( + self, engine, parser, r1, c1, r2, c2, idx_func_dict + ): + n = 3 + m1 = 5 + m2 = 2 * m1 + df = DataFrame( + np.random.default_rng(2).standard_normal((m1, n)), + index=idx_func_dict[r1](m1), + columns=idx_func_dict[c1](n), + ) + df2 = DataFrame( + np.random.default_rng(2).standard_normal((m2, n)), + index=idx_func_dict[r2](m2), + columns=idx_func_dict[c2](n), + ) + index = df2.columns + ser = Series(np.random.default_rng(2).standard_normal(n), index[:n]) + + if r2 == "dt" or c2 == "dt": + if engine == "numexpr": + expected2 = df2.add(ser) + else: + expected2 = df2 + ser + else: + expected2 = df2 + ser + + if r1 == "dt" or c1 == "dt": + if engine == "numexpr": + expected = expected2.add(df) + else: + expected = expected2 + df + else: + expected = expected2 + df + + if should_warn(df2.index, ser.index, df.index): + with tm.assert_produces_warning(RuntimeWarning): + res = pd.eval("df2 + ser + df", engine=engine, parser=parser) + else: + res = pd.eval("df2 + ser + df", engine=engine, parser=parser) + assert res.shape == expected.shape + tm.assert_frame_equal(res, expected) + + def test_performance_warning_for_poor_alignment(self, engine, parser): + df = DataFrame(np.random.default_rng(2).standard_normal((1000, 10))) + s = Series(np.random.default_rng(2).standard_normal(10000)) + if engine == "numexpr": + seen = PerformanceWarning + else: + seen = False + + with tm.assert_produces_warning(seen): + pd.eval("df + s", engine=engine, parser=parser) + + s = Series(np.random.default_rng(2).standard_normal(1000)) + with tm.assert_produces_warning(False): + pd.eval("df + s", engine=engine, parser=parser) + + df = DataFrame(np.random.default_rng(2).standard_normal((10, 10000))) + s = Series(np.random.default_rng(2).standard_normal(10000)) + with tm.assert_produces_warning(False): + pd.eval("df + s", engine=engine, parser=parser) + + df = DataFrame(np.random.default_rng(2).standard_normal((10, 10))) + s = Series(np.random.default_rng(2).standard_normal(10000)) + + is_python_engine = engine == "python" + + if not is_python_engine: + wrn = PerformanceWarning + else: + wrn = False + + with tm.assert_produces_warning(wrn) as w: + pd.eval("df + s", engine=engine, parser=parser) + + if not is_python_engine: + assert len(w) == 1 + msg = str(w[0].message) + logged = np.log10(s.size - df.shape[1]) + expected = ( + f"Alignment difference on axis 1 is larger " + f"than an order of magnitude on term 'df', " + f"by more than {logged:.4g}; performance may suffer." + ) + assert msg == expected + + +# ------------------------------------ +# Slightly more complex ops + + +class TestOperations: + def eval(self, *args, **kwargs): + kwargs["level"] = kwargs.pop("level", 0) + 1 + return pd.eval(*args, **kwargs) + + def test_simple_arith_ops(self, engine, parser): + exclude_arith = [] + if parser == "python": + exclude_arith = ["in", "not in"] + + arith_ops = [ + op + for op in expr.ARITH_OPS_SYMS + expr.CMP_OPS_SYMS + if op not in exclude_arith + ] + + ops = (op for op in arith_ops if op != "//") + + for op in ops: + ex = f"1 {op} 1" + ex2 = f"x {op} 1" + ex3 = f"1 {op} (x + 1)" + + if op in ("in", "not in"): + msg = "argument of type 'int' is not iterable" + with pytest.raises(TypeError, match=msg): + pd.eval(ex, engine=engine, parser=parser) + else: + expec = _eval_single_bin(1, op, 1, engine) + x = self.eval(ex, engine=engine, parser=parser) + assert x == expec + + expec = _eval_single_bin(x, op, 1, engine) + y = self.eval(ex2, local_dict={"x": x}, engine=engine, parser=parser) + assert y == expec + + expec = _eval_single_bin(1, op, x + 1, engine) + y = self.eval(ex3, local_dict={"x": x}, engine=engine, parser=parser) + assert y == expec + + @pytest.mark.parametrize("rhs", [True, False]) + @pytest.mark.parametrize("lhs", [True, False]) + @pytest.mark.parametrize("op", expr.BOOL_OPS_SYMS) + def test_simple_bool_ops(self, rhs, lhs, op): + ex = f"{lhs} {op} {rhs}" + + if parser == "python" and op in ["and", "or"]: + msg = "'BoolOp' nodes are not implemented" + with pytest.raises(NotImplementedError, match=msg): + self.eval(ex) + return + + res = self.eval(ex) + exp = eval(ex) + assert res == exp + + @pytest.mark.parametrize("rhs", [True, False]) + @pytest.mark.parametrize("lhs", [True, False]) + @pytest.mark.parametrize("op", expr.BOOL_OPS_SYMS) + def test_bool_ops_with_constants(self, rhs, lhs, op): + ex = f"{lhs} {op} {rhs}" + + if parser == "python" and op in ["and", "or"]: + msg = "'BoolOp' nodes are not implemented" + with pytest.raises(NotImplementedError, match=msg): + self.eval(ex) + return + + res = self.eval(ex) + exp = eval(ex) + assert res == exp + + def test_4d_ndarray_fails(self): + x = np.random.default_rng(2).standard_normal((3, 4, 5, 6)) + y = Series(np.random.default_rng(2).standard_normal(10)) + msg = "N-dimensional objects, where N > 2, are not supported with eval" + with pytest.raises(NotImplementedError, match=msg): + self.eval("x + y", local_dict={"x": x, "y": y}) + + def test_constant(self): + x = self.eval("1") + assert x == 1 + + def test_single_variable(self): + df = DataFrame(np.random.default_rng(2).standard_normal((10, 2))) + df2 = self.eval("df", local_dict={"df": df}) + tm.assert_frame_equal(df, df2) + + def test_failing_subscript_with_name_error(self): + df = DataFrame(np.random.default_rng(2).standard_normal((5, 3))) # noqa: F841 + with pytest.raises(NameError, match="name 'x' is not defined"): + self.eval("df[x > 2] > 2") + + def test_lhs_expression_subscript(self): + df = DataFrame(np.random.default_rng(2).standard_normal((5, 3))) + result = self.eval("(df + 1)[df > 2]", local_dict={"df": df}) + expected = (df + 1)[df > 2] + tm.assert_frame_equal(result, expected) + + def test_attr_expression(self): + df = DataFrame( + np.random.default_rng(2).standard_normal((5, 3)), columns=list("abc") + ) + expr1 = "df.a < df.b" + expec1 = df.a < df.b + expr2 = "df.a + df.b + df.c" + expec2 = df.a + df.b + df.c + expr3 = "df.a + df.b + df.c[df.b < 0]" + expec3 = df.a + df.b + df.c[df.b < 0] + exprs = expr1, expr2, expr3 + expecs = expec1, expec2, expec3 + for e, expec in zip(exprs, expecs): + tm.assert_series_equal(expec, self.eval(e, local_dict={"df": df})) + + def test_assignment_fails(self): + df = DataFrame( + np.random.default_rng(2).standard_normal((5, 3)), columns=list("abc") + ) + df2 = DataFrame(np.random.default_rng(2).standard_normal((5, 3))) + expr1 = "df = df2" + msg = "cannot assign without a target object" + with pytest.raises(ValueError, match=msg): + self.eval(expr1, local_dict={"df": df, "df2": df2}) + + def test_assignment_column_multiple_raise(self): + df = DataFrame( + np.random.default_rng(2).standard_normal((5, 2)), columns=list("ab") + ) + # multiple assignees + with pytest.raises(SyntaxError, match="invalid syntax"): + df.eval("d c = a + b") + + def test_assignment_column_invalid_assign(self): + df = DataFrame( + np.random.default_rng(2).standard_normal((5, 2)), columns=list("ab") + ) + # invalid assignees + msg = "left hand side of an assignment must be a single name" + with pytest.raises(SyntaxError, match=msg): + df.eval("d,c = a + b") + + def test_assignment_column_invalid_assign_function_call(self): + df = DataFrame( + np.random.default_rng(2).standard_normal((5, 2)), columns=list("ab") + ) + msg = "cannot assign to function call" + with pytest.raises(SyntaxError, match=msg): + df.eval('Timestamp("20131001") = a + b') + + def test_assignment_single_assign_existing(self): + df = DataFrame( + np.random.default_rng(2).standard_normal((5, 2)), columns=list("ab") + ) + # single assignment - existing variable + expected = df.copy() + expected["a"] = expected["a"] + expected["b"] + df.eval("a = a + b", inplace=True) + tm.assert_frame_equal(df, expected) + + def test_assignment_single_assign_new(self): + df = DataFrame( + np.random.default_rng(2).standard_normal((5, 2)), columns=list("ab") + ) + # single assignment - new variable + expected = df.copy() + expected["c"] = expected["a"] + expected["b"] + df.eval("c = a + b", inplace=True) + tm.assert_frame_equal(df, expected) + + def test_assignment_single_assign_local_overlap(self): + df = DataFrame( + np.random.default_rng(2).standard_normal((5, 2)), columns=list("ab") + ) + df = df.copy() + a = 1 # noqa: F841 + df.eval("a = 1 + b", inplace=True) + + expected = df.copy() + expected["a"] = 1 + expected["b"] + tm.assert_frame_equal(df, expected) + + def test_assignment_single_assign_name(self): + df = DataFrame( + np.random.default_rng(2).standard_normal((5, 2)), columns=list("ab") + ) + + a = 1 # noqa: F841 + old_a = df.a.copy() + df.eval("a = a + b", inplace=True) + result = old_a + df.b + tm.assert_series_equal(result, df.a, check_names=False) + assert result.name is None + + def test_assignment_multiple_raises(self): + df = DataFrame( + np.random.default_rng(2).standard_normal((5, 2)), columns=list("ab") + ) + # multiple assignment + df.eval("c = a + b", inplace=True) + msg = "can only assign a single expression" + with pytest.raises(SyntaxError, match=msg): + df.eval("c = a = b") + + def test_assignment_explicit(self): + df = DataFrame( + np.random.default_rng(2).standard_normal((5, 2)), columns=list("ab") + ) + # explicit targets + self.eval("c = df.a + df.b", local_dict={"df": df}, target=df, inplace=True) + expected = df.copy() + expected["c"] = expected["a"] + expected["b"] + tm.assert_frame_equal(df, expected) + + def test_column_in(self): + # GH 11235 + df = DataFrame({"a": [11], "b": [-32]}) + result = df.eval("a in [11, -32]") + expected = Series([True]) + # TODO: 2022-01-29: Name check failed with numexpr 2.7.3 in CI + # but cannot reproduce locally + tm.assert_series_equal(result, expected, check_names=False) + + @pytest.mark.xfail(reason="Unknown: Omitted test_ in name prior.") + def test_assignment_not_inplace(self): + # see gh-9297 + df = DataFrame( + np.random.default_rng(2).standard_normal((5, 2)), columns=list("ab") + ) + + actual = df.eval("c = a + b", inplace=False) + assert actual is not None + + expected = df.copy() + expected["c"] = expected["a"] + expected["b"] + tm.assert_frame_equal(df, expected) + + def test_multi_line_expression(self, warn_copy_on_write): + # GH 11149 + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + expected = df.copy() + + expected["c"] = expected["a"] + expected["b"] + expected["d"] = expected["c"] + expected["b"] + answer = df.eval( + """ + c = a + b + d = c + b""", + inplace=True, + ) + tm.assert_frame_equal(expected, df) + assert answer is None + + expected["a"] = expected["a"] - 1 + expected["e"] = expected["a"] + 2 + answer = df.eval( + """ + a = a - 1 + e = a + 2""", + inplace=True, + ) + tm.assert_frame_equal(expected, df) + assert answer is None + + # multi-line not valid if not all assignments + msg = "Multi-line expressions are only valid if all expressions contain" + with pytest.raises(ValueError, match=msg): + df.eval( + """ + a = b + 2 + b - 2""", + inplace=False, + ) + + def test_multi_line_expression_not_inplace(self): + # GH 11149 + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + expected = df.copy() + + expected["c"] = expected["a"] + expected["b"] + expected["d"] = expected["c"] + expected["b"] + df = df.eval( + """ + c = a + b + d = c + b""", + inplace=False, + ) + tm.assert_frame_equal(expected, df) + + expected["a"] = expected["a"] - 1 + expected["e"] = expected["a"] + 2 + df = df.eval( + """ + a = a - 1 + e = a + 2""", + inplace=False, + ) + tm.assert_frame_equal(expected, df) + + def test_multi_line_expression_local_variable(self): + # GH 15342 + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + expected = df.copy() + + local_var = 7 + expected["c"] = expected["a"] * local_var + expected["d"] = expected["c"] + local_var + answer = df.eval( + """ + c = a * @local_var + d = c + @local_var + """, + inplace=True, + ) + tm.assert_frame_equal(expected, df) + assert answer is None + + def test_multi_line_expression_callable_local_variable(self): + # 26426 + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + + def local_func(a, b): + return b + + expected = df.copy() + expected["c"] = expected["a"] * local_func(1, 7) + expected["d"] = expected["c"] + local_func(1, 7) + answer = df.eval( + """ + c = a * @local_func(1, 7) + d = c + @local_func(1, 7) + """, + inplace=True, + ) + tm.assert_frame_equal(expected, df) + assert answer is None + + def test_multi_line_expression_callable_local_variable_with_kwargs(self): + # 26426 + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + + def local_func(a, b): + return b + + expected = df.copy() + expected["c"] = expected["a"] * local_func(b=7, a=1) + expected["d"] = expected["c"] + local_func(b=7, a=1) + answer = df.eval( + """ + c = a * @local_func(b=7, a=1) + d = c + @local_func(b=7, a=1) + """, + inplace=True, + ) + tm.assert_frame_equal(expected, df) + assert answer is None + + def test_assignment_in_query(self): + # GH 8664 + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + df_orig = df.copy() + msg = "cannot assign without a target object" + with pytest.raises(ValueError, match=msg): + df.query("a = 1") + tm.assert_frame_equal(df, df_orig) + + def test_query_inplace(self): + # see gh-11149 + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + expected = df.copy() + expected = expected[expected["a"] == 2] + df.query("a == 2", inplace=True) + tm.assert_frame_equal(expected, df) + + df = {} + expected = {"a": 3} + + self.eval("a = 1 + 2", target=df, inplace=True) + tm.assert_dict_equal(df, expected) + + @pytest.mark.parametrize("invalid_target", [1, "cat", [1, 2], np.array([]), (1, 3)]) + def test_cannot_item_assign(self, invalid_target): + msg = "Cannot assign expression output to target" + expression = "a = 1 + 2" + + with pytest.raises(ValueError, match=msg): + self.eval(expression, target=invalid_target, inplace=True) + + if hasattr(invalid_target, "copy"): + with pytest.raises(ValueError, match=msg): + self.eval(expression, target=invalid_target, inplace=False) + + @pytest.mark.parametrize("invalid_target", [1, "cat", (1, 3)]) + def test_cannot_copy_item(self, invalid_target): + msg = "Cannot return a copy of the target" + expression = "a = 1 + 2" + + with pytest.raises(ValueError, match=msg): + self.eval(expression, target=invalid_target, inplace=False) + + @pytest.mark.parametrize("target", [1, "cat", [1, 2], np.array([]), (1, 3), {1: 2}]) + def test_inplace_no_assignment(self, target): + expression = "1 + 2" + + assert self.eval(expression, target=target, inplace=False) == 3 + + msg = "Cannot operate inplace if there is no assignment" + with pytest.raises(ValueError, match=msg): + self.eval(expression, target=target, inplace=True) + + def test_basic_period_index_boolean_expression(self): + df = DataFrame( + np.random.default_rng(2).standard_normal((2, 2)), + columns=period_range("2020-01-01", freq="D", periods=2), + ) + e = df < 2 + r = self.eval("df < 2", local_dict={"df": df}) + x = df < 2 + + tm.assert_frame_equal(r, e) + tm.assert_frame_equal(x, e) + + def test_basic_period_index_subscript_expression(self): + df = DataFrame( + np.random.default_rng(2).standard_normal((2, 2)), + columns=period_range("2020-01-01", freq="D", periods=2), + ) + r = self.eval("df[df < 2 + 3]", local_dict={"df": df}) + e = df[df < 2 + 3] + tm.assert_frame_equal(r, e) + + def test_nested_period_index_subscript_expression(self): + df = DataFrame( + np.random.default_rng(2).standard_normal((2, 2)), + columns=period_range("2020-01-01", freq="D", periods=2), + ) + r = self.eval("df[df[df < 2] < 2] + df * 2", local_dict={"df": df}) + e = df[df[df < 2] < 2] + df * 2 + tm.assert_frame_equal(r, e) + + def test_date_boolean(self, engine, parser): + df = DataFrame(np.random.default_rng(2).standard_normal((5, 3))) + df["dates1"] = date_range("1/1/2012", periods=5) + res = self.eval( + "df.dates1 < 20130101", + local_dict={"df": df}, + engine=engine, + parser=parser, + ) + expec = df.dates1 < "20130101" + tm.assert_series_equal(res, expec, check_names=False) + + def test_simple_in_ops(self, engine, parser): + if parser != "python": + res = pd.eval("1 in [1, 2]", engine=engine, parser=parser) + assert res + + res = pd.eval("2 in (1, 2)", engine=engine, parser=parser) + assert res + + res = pd.eval("3 in (1, 2)", engine=engine, parser=parser) + assert not res + + res = pd.eval("3 not in (1, 2)", engine=engine, parser=parser) + assert res + + res = pd.eval("[3] not in (1, 2)", engine=engine, parser=parser) + assert res + + res = pd.eval("[3] in ([3], 2)", engine=engine, parser=parser) + assert res + + res = pd.eval("[[3]] in [[[3]], 2]", engine=engine, parser=parser) + assert res + + res = pd.eval("(3,) in [(3,), 2]", engine=engine, parser=parser) + assert res + + res = pd.eval("(3,) not in [(3,), 2]", engine=engine, parser=parser) + assert not res + + res = pd.eval("[(3,)] in [[(3,)], 2]", engine=engine, parser=parser) + assert res + else: + msg = "'In' nodes are not implemented" + with pytest.raises(NotImplementedError, match=msg): + pd.eval("1 in [1, 2]", engine=engine, parser=parser) + with pytest.raises(NotImplementedError, match=msg): + pd.eval("2 in (1, 2)", engine=engine, parser=parser) + with pytest.raises(NotImplementedError, match=msg): + pd.eval("3 in (1, 2)", engine=engine, parser=parser) + with pytest.raises(NotImplementedError, match=msg): + pd.eval("[(3,)] in (1, 2, [(3,)])", engine=engine, parser=parser) + msg = "'NotIn' nodes are not implemented" + with pytest.raises(NotImplementedError, match=msg): + pd.eval("3 not in (1, 2)", engine=engine, parser=parser) + with pytest.raises(NotImplementedError, match=msg): + pd.eval("[3] not in (1, 2, [[3]])", engine=engine, parser=parser) + + def test_check_many_exprs(self, engine, parser): + a = 1 # noqa: F841 + expr = " * ".join("a" * 33) + expected = 1 + res = pd.eval(expr, engine=engine, parser=parser) + assert res == expected + + @pytest.mark.parametrize( + "expr", + [ + "df > 2 and df > 3", + "df > 2 or df > 3", + "not df > 2", + ], + ) + def test_fails_and_or_not(self, expr, engine, parser): + df = DataFrame(np.random.default_rng(2).standard_normal((5, 3))) + if parser == "python": + msg = "'BoolOp' nodes are not implemented" + if "not" in expr: + msg = "'Not' nodes are not implemented" + + with pytest.raises(NotImplementedError, match=msg): + pd.eval( + expr, + local_dict={"df": df}, + parser=parser, + engine=engine, + ) + else: + # smoke-test, should not raise + pd.eval( + expr, + local_dict={"df": df}, + parser=parser, + engine=engine, + ) + + @pytest.mark.parametrize("char", ["|", "&"]) + def test_fails_ampersand_pipe(self, char, engine, parser): + df = DataFrame(np.random.default_rng(2).standard_normal((5, 3))) # noqa: F841 + ex = f"(df + 2)[df > 1] > 0 {char} (df > 0)" + if parser == "python": + msg = "cannot evaluate scalar only bool ops" + with pytest.raises(NotImplementedError, match=msg): + pd.eval(ex, parser=parser, engine=engine) + else: + # smoke-test, should not raise + pd.eval(ex, parser=parser, engine=engine) + + +class TestMath: + def eval(self, *args, **kwargs): + kwargs["level"] = kwargs.pop("level", 0) + 1 + return pd.eval(*args, **kwargs) + + @pytest.mark.skipif( + not NUMEXPR_INSTALLED, reason="Unary ops only implemented for numexpr" + ) + @pytest.mark.parametrize("fn", _unary_math_ops) + def test_unary_functions(self, fn): + df = DataFrame({"a": np.random.default_rng(2).standard_normal(10)}) + a = df.a + + expr = f"{fn}(a)" + got = self.eval(expr) + with np.errstate(all="ignore"): + expect = getattr(np, fn)(a) + tm.assert_series_equal(got, expect, check_names=False) + + @pytest.mark.parametrize("fn", _binary_math_ops) + def test_binary_functions(self, fn): + df = DataFrame( + { + "a": np.random.default_rng(2).standard_normal(10), + "b": np.random.default_rng(2).standard_normal(10), + } + ) + a = df.a + b = df.b + + expr = f"{fn}(a, b)" + got = self.eval(expr) + with np.errstate(all="ignore"): + expect = getattr(np, fn)(a, b) + tm.assert_almost_equal(got, expect, check_names=False) + + def test_df_use_case(self, engine, parser): + df = DataFrame( + { + "a": np.random.default_rng(2).standard_normal(10), + "b": np.random.default_rng(2).standard_normal(10), + } + ) + df.eval( + "e = arctan2(sin(a), b)", + engine=engine, + parser=parser, + inplace=True, + ) + got = df.e + expect = np.arctan2(np.sin(df.a), df.b) + tm.assert_series_equal(got, expect, check_names=False) + + def test_df_arithmetic_subexpression(self, engine, parser): + df = DataFrame( + { + "a": np.random.default_rng(2).standard_normal(10), + "b": np.random.default_rng(2).standard_normal(10), + } + ) + df.eval("e = sin(a + b)", engine=engine, parser=parser, inplace=True) + got = df.e + expect = np.sin(df.a + df.b) + tm.assert_series_equal(got, expect, check_names=False) + + @pytest.mark.parametrize( + "dtype, expect_dtype", + [ + (np.int32, np.float64), + (np.int64, np.float64), + (np.float32, np.float32), + (np.float64, np.float64), + pytest.param(np.complex128, np.complex128, marks=td.skip_if_windows), + ], + ) + def test_result_types(self, dtype, expect_dtype, engine, parser): + # xref https://github.com/pandas-dev/pandas/issues/12293 + # this fails on Windows, apparently a floating point precision issue + + # Did not test complex64 because DataFrame is converting it to + # complex128. Due to https://github.com/pandas-dev/pandas/issues/10952 + df = DataFrame( + {"a": np.random.default_rng(2).standard_normal(10).astype(dtype)} + ) + assert df.a.dtype == dtype + df.eval("b = sin(a)", engine=engine, parser=parser, inplace=True) + got = df.b + expect = np.sin(df.a) + assert expect.dtype == got.dtype + assert expect_dtype == got.dtype + tm.assert_series_equal(got, expect, check_names=False) + + def test_undefined_func(self, engine, parser): + df = DataFrame({"a": np.random.default_rng(2).standard_normal(10)}) + msg = '"mysin" is not a supported function' + + with pytest.raises(ValueError, match=msg): + df.eval("mysin(a)", engine=engine, parser=parser) + + def test_keyword_arg(self, engine, parser): + df = DataFrame({"a": np.random.default_rng(2).standard_normal(10)}) + msg = 'Function "sin" does not support keyword arguments' + + with pytest.raises(TypeError, match=msg): + df.eval("sin(x=a)", engine=engine, parser=parser) + + +_var_s = np.random.default_rng(2).standard_normal(10) + + +class TestScope: + def test_global_scope(self, engine, parser): + e = "_var_s * 2" + tm.assert_numpy_array_equal( + _var_s * 2, pd.eval(e, engine=engine, parser=parser) + ) + + def test_no_new_locals(self, engine, parser): + x = 1 + lcls = locals().copy() + pd.eval("x + 1", local_dict=lcls, engine=engine, parser=parser) + lcls2 = locals().copy() + lcls2.pop("lcls") + assert lcls == lcls2 + + def test_no_new_globals(self, engine, parser): + x = 1 # noqa: F841 + gbls = globals().copy() + pd.eval("x + 1", engine=engine, parser=parser) + gbls2 = globals().copy() + assert gbls == gbls2 + + def test_empty_locals(self, engine, parser): + # GH 47084 + x = 1 # noqa: F841 + msg = "name 'x' is not defined" + with pytest.raises(UndefinedVariableError, match=msg): + pd.eval("x + 1", engine=engine, parser=parser, local_dict={}) + + def test_empty_globals(self, engine, parser): + # GH 47084 + msg = "name '_var_s' is not defined" + e = "_var_s * 2" + with pytest.raises(UndefinedVariableError, match=msg): + pd.eval(e, engine=engine, parser=parser, global_dict={}) + + +@td.skip_if_no("numexpr") +def test_invalid_engine(): + msg = "Invalid engine 'asdf' passed" + with pytest.raises(KeyError, match=msg): + pd.eval("x + y", local_dict={"x": 1, "y": 2}, engine="asdf") + + +@td.skip_if_no("numexpr") +@pytest.mark.parametrize( + ("use_numexpr", "expected"), + ( + (True, "numexpr"), + (False, "python"), + ), +) +def test_numexpr_option_respected(use_numexpr, expected): + # GH 32556 + from pandas.core.computation.eval import _check_engine + + with pd.option_context("compute.use_numexpr", use_numexpr): + result = _check_engine(None) + assert result == expected + + +@td.skip_if_no("numexpr") +def test_numexpr_option_incompatible_op(): + # GH 32556 + with pd.option_context("compute.use_numexpr", False): + df = DataFrame( + {"A": [True, False, True, False, None, None], "B": [1, 2, 3, 4, 5, 6]} + ) + result = df.query("A.isnull()") + expected = DataFrame({"A": [None, None], "B": [5, 6]}, index=[4, 5]) + tm.assert_frame_equal(result, expected) + + +@td.skip_if_no("numexpr") +def test_invalid_parser(): + msg = "Invalid parser 'asdf' passed" + with pytest.raises(KeyError, match=msg): + pd.eval("x + y", local_dict={"x": 1, "y": 2}, parser="asdf") + + +_parsers: dict[str, type[BaseExprVisitor]] = { + "python": PythonExprVisitor, + "pytables": pytables.PyTablesExprVisitor, + "pandas": PandasExprVisitor, +} + + +@pytest.mark.parametrize("engine", ENGINES) +@pytest.mark.parametrize("parser", _parsers) +def test_disallowed_nodes(engine, parser): + VisitorClass = _parsers[parser] + inst = VisitorClass("x + 1", engine, parser) + + for ops in VisitorClass.unsupported_nodes: + msg = "nodes are not implemented" + with pytest.raises(NotImplementedError, match=msg): + getattr(inst, ops)() + + +def test_syntax_error_exprs(engine, parser): + e = "s +" + with pytest.raises(SyntaxError, match="invalid syntax"): + pd.eval(e, engine=engine, parser=parser) + + +def test_name_error_exprs(engine, parser): + e = "s + t" + msg = "name 's' is not defined" + with pytest.raises(NameError, match=msg): + pd.eval(e, engine=engine, parser=parser) + + +@pytest.mark.parametrize("express", ["a + @b", "@a + b", "@a + @b"]) +def test_invalid_local_variable_reference(engine, parser, express): + a, b = 1, 2 # noqa: F841 + + if parser != "pandas": + with pytest.raises(SyntaxError, match="The '@' prefix is only"): + pd.eval(express, engine=engine, parser=parser) + else: + with pytest.raises(SyntaxError, match="The '@' prefix is not"): + pd.eval(express, engine=engine, parser=parser) + + +def test_numexpr_builtin_raises(engine, parser): + sin, dotted_line = 1, 2 + if engine == "numexpr": + msg = "Variables in expression .+" + with pytest.raises(NumExprClobberingError, match=msg): + pd.eval("sin + dotted_line", engine=engine, parser=parser) + else: + res = pd.eval("sin + dotted_line", engine=engine, parser=parser) + assert res == sin + dotted_line + + +def test_bad_resolver_raises(engine, parser): + cannot_resolve = 42, 3.0 + with pytest.raises(TypeError, match="Resolver of type .+"): + pd.eval("1 + 2", resolvers=cannot_resolve, engine=engine, parser=parser) + + +def test_empty_string_raises(engine, parser): + # GH 13139 + with pytest.raises(ValueError, match="expr cannot be an empty string"): + pd.eval("", engine=engine, parser=parser) + + +def test_more_than_one_expression_raises(engine, parser): + with pytest.raises(SyntaxError, match="only a single expression is allowed"): + pd.eval("1 + 1; 2 + 2", engine=engine, parser=parser) + + +@pytest.mark.parametrize("cmp", ("and", "or")) +@pytest.mark.parametrize("lhs", (int, float)) +@pytest.mark.parametrize("rhs", (int, float)) +def test_bool_ops_fails_on_scalars(lhs, cmp, rhs, engine, parser): + gen = { + int: lambda: np.random.default_rng(2).integers(10), + float: np.random.default_rng(2).standard_normal, + } + + mid = gen[lhs]() # noqa: F841 + lhs = gen[lhs]() + rhs = gen[rhs]() + + ex1 = f"lhs {cmp} mid {cmp} rhs" + ex2 = f"lhs {cmp} mid and mid {cmp} rhs" + ex3 = f"(lhs {cmp} mid) & (mid {cmp} rhs)" + for ex in (ex1, ex2, ex3): + msg = "cannot evaluate scalar only bool ops|'BoolOp' nodes are not" + with pytest.raises(NotImplementedError, match=msg): + pd.eval(ex, engine=engine, parser=parser) + + +@pytest.mark.parametrize( + "other", + [ + "'x'", + "...", + ], +) +def test_equals_various(other): + df = DataFrame({"A": ["a", "b", "c"]}, dtype=object) + result = df.eval(f"A == {other}") + expected = Series([False, False, False], name="A") + if USE_NUMEXPR: + # https://github.com/pandas-dev/pandas/issues/10239 + # lose name with numexpr engine. Remove when that's fixed. + expected.name = None + tm.assert_series_equal(result, expected) + + +def test_inf(engine, parser): + s = "inf + 1" + expected = np.inf + result = pd.eval(s, engine=engine, parser=parser) + assert result == expected + + +@pytest.mark.parametrize("column", ["Temp(°C)", "Capacitance(μF)"]) +def test_query_token(engine, column): + # See: https://github.com/pandas-dev/pandas/pull/42826 + df = DataFrame( + np.random.default_rng(2).standard_normal((5, 2)), columns=[column, "b"] + ) + expected = df[df[column] > 5] + query_string = f"`{column}` > 5" + result = df.query(query_string, engine=engine) + tm.assert_frame_equal(result, expected) + + +def test_negate_lt_eq_le(engine, parser): + df = DataFrame([[0, 10], [1, 20]], columns=["cat", "count"]) + expected = df[~(df.cat > 0)] + + result = df.query("~(cat > 0)", engine=engine, parser=parser) + tm.assert_frame_equal(result, expected) + + if parser == "python": + msg = "'Not' nodes are not implemented" + with pytest.raises(NotImplementedError, match=msg): + df.query("not (cat > 0)", engine=engine, parser=parser) + else: + result = df.query("not (cat > 0)", engine=engine, parser=parser) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "column", + DEFAULT_GLOBALS.keys(), +) +def test_eval_no_support_column_name(request, column): + # GH 44603 + if column in ["True", "False", "inf", "Inf"]: + request.applymarker( + pytest.mark.xfail( + raises=KeyError, + reason=f"GH 47859 DataFrame eval not supported with {column}", + ) + ) + + df = DataFrame( + np.random.default_rng(2).integers(0, 100, size=(10, 2)), + columns=[column, "col1"], + ) + expected = df[df[column] > 6] + result = df.query(f"{column}>6") + + tm.assert_frame_equal(result, expected) + + +def test_set_inplace(using_copy_on_write, warn_copy_on_write): + # https://github.com/pandas-dev/pandas/issues/47449 + # Ensure we don't only update the DataFrame inplace, but also the actual + # column values, such that references to this column also get updated + df = DataFrame({"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]}) + result_view = df[:] + ser = df["A"] + with tm.assert_cow_warning(warn_copy_on_write): + df.eval("A = B + C", inplace=True) + expected = DataFrame({"A": [11, 13, 15], "B": [4, 5, 6], "C": [7, 8, 9]}) + tm.assert_frame_equal(df, expected) + if not using_copy_on_write: + tm.assert_series_equal(ser, expected["A"]) + tm.assert_series_equal(result_view["A"], expected["A"]) + else: + expected = Series([1, 2, 3], name="A") + tm.assert_series_equal(ser, expected) + tm.assert_series_equal(result_view["A"], expected) + + +class TestValidate: + @pytest.mark.parametrize("value", [1, "True", [1, 2, 3], 5.0]) + def test_validate_bool_args(self, value): + msg = 'For argument "inplace" expected type bool, received type' + with pytest.raises(ValueError, match=msg): + pd.eval("2+2", inplace=value) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/construction/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/construction/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/construction/test_extract_array.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/construction/test_extract_array.py new file mode 100644 index 0000000000000000000000000000000000000000..4dd3eda8c995ce022e9d46b907323e79bcd679f8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/construction/test_extract_array.py @@ -0,0 +1,18 @@ +from pandas import Index +import pandas._testing as tm +from pandas.core.construction import extract_array + + +def test_extract_array_rangeindex(): + ri = Index(range(5)) + + expected = ri._values + res = extract_array(ri, extract_numpy=True, extract_range=True) + tm.assert_numpy_array_equal(res, expected) + res = extract_array(ri, extract_numpy=False, extract_range=True) + tm.assert_numpy_array_equal(res, expected) + + res = extract_array(ri, extract_numpy=True, extract_range=False) + tm.assert_index_equal(res, ri) + res = extract_array(ri, extract_numpy=False, extract_range=False) + tm.assert_index_equal(res, ri) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..551131866e14e9040098161025dfbb5a6ebd430a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/__pycache__/conftest.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/__pycache__/conftest.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..934999a5176e011f964c30a49e74dcd1b469e61b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/__pycache__/conftest.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/__pycache__/test_categorical.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/__pycache__/test_categorical.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11c42b793dbcca1cf26323a87191c0490643acdb Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/__pycache__/test_categorical.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/__pycache__/test_common.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/__pycache__/test_common.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf3765781586cb50307ffe5402838c5035d84564 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/__pycache__/test_common.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/__pycache__/test_datetime.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/__pycache__/test_datetime.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be6f7e616493f7d7efa65d3e54680612e38a3033 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/__pycache__/test_datetime.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/__pycache__/test_extension.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/__pycache__/test_extension.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d3f80a9f333660a0819a0c25bbda2204e136aa4 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/__pycache__/test_extension.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/__pycache__/test_interval.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/__pycache__/test_interval.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ffa7ce07fe98694a195f8356114e79bee01cf0b0 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/__pycache__/test_interval.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/__pycache__/test_masked.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/__pycache__/test_masked.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e88f37b0c6d9cc81651c2c1f6adfc4c278f8f2dc Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/__pycache__/test_masked.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/__pycache__/test_numpy.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/__pycache__/test_numpy.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..414f455f937cc7bed59dccf4dc6c327e633fe85b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/__pycache__/test_numpy.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/__pycache__/test_period.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/__pycache__/test_period.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e96aaf3299b78a361f9f161b17e29abc6735ae58 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/__pycache__/test_period.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/__pycache__/test_sparse.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/__pycache__/test_sparse.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec91d7f47a60b55a49ec77255e1bc4c2478779d8 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/__pycache__/test_sparse.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/__pycache__/test_string.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/__pycache__/test_string.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1032d1997ddf354f53478ba004bd64f753cff040 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/__pycache__/test_string.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/array_with_attr/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/array_with_attr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..49da6af024a31726743815ba1e36d66c03daafe5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/array_with_attr/__init__.py @@ -0,0 +1,6 @@ +from pandas.tests.extension.array_with_attr.array import ( + FloatAttrArray, + FloatAttrDtype, +) + +__all__ = ["FloatAttrArray", "FloatAttrDtype"] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/array_with_attr/array.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/array_with_attr/array.py new file mode 100644 index 0000000000000000000000000000000000000000..2789d51ec2ce3096c64b41af90b5a416ceef9f5b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/array_with_attr/array.py @@ -0,0 +1,89 @@ +""" +Test extension array that has custom attribute information (not stored on the dtype). + +""" +from __future__ import annotations + +import numbers +from typing import TYPE_CHECKING + +import numpy as np + +from pandas.core.dtypes.base import ExtensionDtype + +import pandas as pd +from pandas.core.arrays import ExtensionArray + +if TYPE_CHECKING: + from pandas._typing import type_t + + +class FloatAttrDtype(ExtensionDtype): + type = float + name = "float_attr" + na_value = np.nan + + @classmethod + def construct_array_type(cls) -> type_t[FloatAttrArray]: + """ + Return the array type associated with this dtype. + + Returns + ------- + type + """ + return FloatAttrArray + + +class FloatAttrArray(ExtensionArray): + dtype = FloatAttrDtype() + __array_priority__ = 1000 + + def __init__(self, values, attr=None) -> None: + if not isinstance(values, np.ndarray): + raise TypeError("Need to pass a numpy array of float64 dtype as values") + if not values.dtype == "float64": + raise TypeError("Need to pass a numpy array of float64 dtype as values") + self.data = values + self.attr = attr + + @classmethod + def _from_sequence(cls, scalars, *, dtype=None, copy=False): + if not copy: + data = np.asarray(scalars, dtype="float64") + else: + data = np.array(scalars, dtype="float64", copy=copy) + return cls(data) + + def __getitem__(self, item): + if isinstance(item, numbers.Integral): + return self.data[item] + else: + # slice, list-like, mask + item = pd.api.indexers.check_array_indexer(self, item) + return type(self)(self.data[item], self.attr) + + def __len__(self) -> int: + return len(self.data) + + def isna(self): + return np.isnan(self.data) + + def take(self, indexer, allow_fill=False, fill_value=None): + from pandas.api.extensions import take + + data = self.data + if allow_fill and fill_value is None: + fill_value = self.dtype.na_value + + result = take(data, indexer, fill_value=fill_value, allow_fill=allow_fill) + return type(self)(result, self.attr) + + def copy(self): + return type(self)(self.data.copy(), self.attr) + + @classmethod + def _concat_same_type(cls, to_concat): + data = np.concatenate([x.data for x in to_concat]) + attr = to_concat[0].attr if len(to_concat) else None + return cls(data, attr) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/array_with_attr/test_array_with_attr.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/array_with_attr/test_array_with_attr.py new file mode 100644 index 0000000000000000000000000000000000000000..3735fe40a0d67784b3603a177b6694e56e26d479 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/array_with_attr/test_array_with_attr.py @@ -0,0 +1,33 @@ +import numpy as np + +import pandas as pd +import pandas._testing as tm +from pandas.tests.extension.array_with_attr import FloatAttrArray + + +def test_concat_with_all_na(): + # https://github.com/pandas-dev/pandas/pull/47762 + # ensure that attribute of the column array is preserved (when it gets + # preserved in reindexing the array) during merge/concat + arr = FloatAttrArray(np.array([np.nan, np.nan], dtype="float64"), attr="test") + + df1 = pd.DataFrame({"col": arr, "key": [0, 1]}) + df2 = pd.DataFrame({"key": [0, 1], "col2": [1, 2]}) + result = pd.merge(df1, df2, on="key") + expected = pd.DataFrame({"col": arr, "key": [0, 1], "col2": [1, 2]}) + tm.assert_frame_equal(result, expected) + assert result["col"].array.attr == "test" + + df1 = pd.DataFrame({"col": arr, "key": [0, 1]}) + df2 = pd.DataFrame({"key": [0, 2], "col2": [1, 2]}) + result = pd.merge(df1, df2, on="key") + expected = pd.DataFrame({"col": arr.take([0]), "key": [0], "col2": [1]}) + tm.assert_frame_equal(result, expected) + assert result["col"].array.attr == "test" + + result = pd.concat([df1.set_index("key"), df2.set_index("key")], axis=1) + expected = pd.DataFrame( + {"col": arr.take([0, 1, -1]), "col2": [1, np.nan, 2], "key": [0, 1, 2]} + ).set_index("key") + tm.assert_frame_equal(result, expected) + assert result["col"].array.attr == "test" diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6efaa95aef1b51c33df668db870eaa8741010a59 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/__init__.py @@ -0,0 +1,131 @@ +""" +Base test suite for extension arrays. + +These tests are intended for third-party libraries to subclass to validate +that their extension arrays and dtypes satisfy the interface. Moving or +renaming the tests should not be done lightly. + +Libraries are expected to implement a few pytest fixtures to provide data +for the tests. The fixtures may be located in either + +* The same module as your test class. +* A ``conftest.py`` in the same directory as your test class. + +The full list of fixtures may be found in the ``conftest.py`` next to this +file. + +.. code-block:: python + + import pytest + from pandas.tests.extension.base import BaseDtypeTests + + + @pytest.fixture + def dtype(): + return MyDtype() + + + class TestMyDtype(BaseDtypeTests): + pass + + +Your class ``TestDtype`` will inherit all the tests defined on +``BaseDtypeTests``. pytest's fixture discover will supply your ``dtype`` +wherever the test requires it. You're free to implement additional tests. + +""" +from pandas.tests.extension.base.accumulate import BaseAccumulateTests +from pandas.tests.extension.base.casting import BaseCastingTests +from pandas.tests.extension.base.constructors import BaseConstructorsTests +from pandas.tests.extension.base.dim2 import ( # noqa: F401 + Dim2CompatTests, + NDArrayBacked2DTests, +) +from pandas.tests.extension.base.dtype import BaseDtypeTests +from pandas.tests.extension.base.getitem import BaseGetitemTests +from pandas.tests.extension.base.groupby import BaseGroupbyTests +from pandas.tests.extension.base.index import BaseIndexTests +from pandas.tests.extension.base.interface import BaseInterfaceTests +from pandas.tests.extension.base.io import BaseParsingTests +from pandas.tests.extension.base.methods import BaseMethodsTests +from pandas.tests.extension.base.missing import BaseMissingTests +from pandas.tests.extension.base.ops import ( # noqa: F401 + BaseArithmeticOpsTests, + BaseComparisonOpsTests, + BaseOpsUtil, + BaseUnaryOpsTests, +) +from pandas.tests.extension.base.printing import BasePrintingTests +from pandas.tests.extension.base.reduce import BaseReduceTests +from pandas.tests.extension.base.reshaping import BaseReshapingTests +from pandas.tests.extension.base.setitem import BaseSetitemTests + + +# One test class that you can inherit as an alternative to inheriting all the +# test classes above. +# Note 1) this excludes Dim2CompatTests and NDArrayBacked2DTests. +# Note 2) this uses BaseReduceTests and and _not_ BaseBooleanReduceTests, +# BaseNoReduceTests, or BaseNumericReduceTests +class ExtensionTests( + BaseAccumulateTests, + BaseCastingTests, + BaseConstructorsTests, + BaseDtypeTests, + BaseGetitemTests, + BaseGroupbyTests, + BaseIndexTests, + BaseInterfaceTests, + BaseParsingTests, + BaseMethodsTests, + BaseMissingTests, + BaseArithmeticOpsTests, + BaseComparisonOpsTests, + BaseUnaryOpsTests, + BasePrintingTests, + BaseReduceTests, + BaseReshapingTests, + BaseSetitemTests, + Dim2CompatTests, +): + pass + + +def __getattr__(name: str): + import warnings + + if name == "BaseNoReduceTests": + warnings.warn( + "BaseNoReduceTests is deprecated and will be removed in a " + "future version. Use BaseReduceTests and override " + "`_supports_reduction` instead.", + FutureWarning, + ) + from pandas.tests.extension.base.reduce import BaseNoReduceTests + + return BaseNoReduceTests + + elif name == "BaseNumericReduceTests": + warnings.warn( + "BaseNumericReduceTests is deprecated and will be removed in a " + "future version. Use BaseReduceTests and override " + "`_supports_reduction` instead.", + FutureWarning, + ) + from pandas.tests.extension.base.reduce import BaseNumericReduceTests + + return BaseNumericReduceTests + + elif name == "BaseBooleanReduceTests": + warnings.warn( + "BaseBooleanReduceTests is deprecated and will be removed in a " + "future version. Use BaseReduceTests and override " + "`_supports_reduction` instead.", + FutureWarning, + ) + from pandas.tests.extension.base.reduce import BaseBooleanReduceTests + + return BaseBooleanReduceTests + + raise AttributeError( + f"module 'pandas.tests.extension.base' has no attribute '{name}'" + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/accumulate.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/accumulate.py new file mode 100644 index 0000000000000000000000000000000000000000..9a41a3a582c4a5a140262aed5d0ea812129870f4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/accumulate.py @@ -0,0 +1,39 @@ +import pytest + +import pandas as pd +import pandas._testing as tm + + +class BaseAccumulateTests: + """ + Accumulation specific tests. Generally these only + make sense for numeric/boolean operations. + """ + + def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool: + # Do we expect this accumulation to be supported for this dtype? + # We default to assuming "no"; subclass authors should override here. + return False + + def check_accumulate(self, ser: pd.Series, op_name: str, skipna: bool): + try: + alt = ser.astype("float64") + except TypeError: + # e.g. Period can't be cast to float64 + alt = ser.astype(object) + + result = getattr(ser, op_name)(skipna=skipna) + expected = getattr(alt, op_name)(skipna=skipna) + tm.assert_series_equal(result, expected, check_dtype=False) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_accumulate_series(self, data, all_numeric_accumulations, skipna): + op_name = all_numeric_accumulations + ser = pd.Series(data) + + if self._supports_accumulation(ser, op_name): + self.check_accumulate(ser, op_name, skipna) + else: + with pytest.raises((NotImplementedError, TypeError)): + # TODO: require TypeError for things that will _never_ work? + getattr(ser, op_name)(skipna=skipna) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/base.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/base.py new file mode 100644 index 0000000000000000000000000000000000000000..747ebee738c1ee5cf9bbcc858aedd55dea39b38c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/base.py @@ -0,0 +1,2 @@ +class BaseExtensionTests: + pass diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/casting.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/casting.py new file mode 100644 index 0000000000000000000000000000000000000000..2bfe801c48a7794b86fa6c75f5edda9f25269caa --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/casting.py @@ -0,0 +1,87 @@ +import numpy as np +import pytest + +import pandas.util._test_decorators as td + +import pandas as pd +import pandas._testing as tm +from pandas.core.internals.blocks import NumpyBlock + + +class BaseCastingTests: + """Casting to and from ExtensionDtypes""" + + def test_astype_object_series(self, all_data): + ser = pd.Series(all_data, name="A") + result = ser.astype(object) + assert result.dtype == np.dtype(object) + if hasattr(result._mgr, "blocks"): + blk = result._mgr.blocks[0] + assert isinstance(blk, NumpyBlock) + assert blk.is_object + assert isinstance(result._mgr.array, np.ndarray) + assert result._mgr.array.dtype == np.dtype(object) + + def test_astype_object_frame(self, all_data): + df = pd.DataFrame({"A": all_data}) + + result = df.astype(object) + if hasattr(result._mgr, "blocks"): + blk = result._mgr.blocks[0] + assert isinstance(blk, NumpyBlock), type(blk) + assert blk.is_object + assert isinstance(result._mgr.arrays[0], np.ndarray) + assert result._mgr.arrays[0].dtype == np.dtype(object) + + # check that we can compare the dtypes + comp = result.dtypes == df.dtypes + assert not comp.any() + + def test_tolist(self, data): + result = pd.Series(data).tolist() + expected = list(data) + assert result == expected + + def test_astype_str(self, data): + result = pd.Series(data[:5]).astype(str) + expected = pd.Series([str(x) for x in data[:5]], dtype=str) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "nullable_string_dtype", + [ + "string[python]", + pytest.param("string[pyarrow]", marks=td.skip_if_no("pyarrow")), + ], + ) + def test_astype_string(self, data, nullable_string_dtype): + # GH-33465, GH#45326 as of 2.0 we decode bytes instead of calling str(obj) + result = pd.Series(data[:5]).astype(nullable_string_dtype) + expected = pd.Series( + [str(x) if not isinstance(x, bytes) else x.decode() for x in data[:5]], + dtype=nullable_string_dtype, + ) + tm.assert_series_equal(result, expected) + + def test_to_numpy(self, data): + expected = np.asarray(data) + + result = data.to_numpy() + tm.assert_equal(result, expected) + + result = pd.Series(data).to_numpy() + tm.assert_equal(result, expected) + + def test_astype_empty_dataframe(self, dtype): + # https://github.com/pandas-dev/pandas/issues/33113 + df = pd.DataFrame() + result = df.astype(dtype) + tm.assert_frame_equal(result, df) + + @pytest.mark.parametrize("copy", [True, False]) + def test_astype_own_type(self, data, copy): + # ensure that astype returns the original object for equal dtype and copy=False + # https://github.com/pandas-dev/pandas/issues/28488 + result = data.astype(data.dtype, copy=copy) + assert (result is data) is (not copy) + tm.assert_extension_array_equal(result, data) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/constructors.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/constructors.py new file mode 100644 index 0000000000000000000000000000000000000000..c32a6a6a115ac992a9b85b6a35e77e4d446fbd07 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/constructors.py @@ -0,0 +1,142 @@ +import numpy as np +import pytest + +import pandas as pd +import pandas._testing as tm +from pandas.api.extensions import ExtensionArray +from pandas.core.internals.blocks import EABackedBlock + + +class BaseConstructorsTests: + def test_from_sequence_from_cls(self, data): + result = type(data)._from_sequence(data, dtype=data.dtype) + tm.assert_extension_array_equal(result, data) + + data = data[:0] + result = type(data)._from_sequence(data, dtype=data.dtype) + tm.assert_extension_array_equal(result, data) + + def test_array_from_scalars(self, data): + scalars = [data[0], data[1], data[2]] + result = data._from_sequence(scalars, dtype=data.dtype) + assert isinstance(result, type(data)) + + def test_series_constructor(self, data): + result = pd.Series(data, copy=False) + assert result.dtype == data.dtype + assert len(result) == len(data) + if hasattr(result._mgr, "blocks"): + assert isinstance(result._mgr.blocks[0], EABackedBlock) + assert result._mgr.array is data + + # Series[EA] is unboxed / boxed correctly + result2 = pd.Series(result) + assert result2.dtype == data.dtype + if hasattr(result._mgr, "blocks"): + assert isinstance(result2._mgr.blocks[0], EABackedBlock) + + def test_series_constructor_no_data_with_index(self, dtype, na_value): + result = pd.Series(index=[1, 2, 3], dtype=dtype) + expected = pd.Series([na_value] * 3, index=[1, 2, 3], dtype=dtype) + tm.assert_series_equal(result, expected) + + # GH 33559 - empty index + result = pd.Series(index=[], dtype=dtype) + expected = pd.Series([], index=pd.Index([], dtype="object"), dtype=dtype) + tm.assert_series_equal(result, expected) + + def test_series_constructor_scalar_na_with_index(self, dtype, na_value): + result = pd.Series(na_value, index=[1, 2, 3], dtype=dtype) + expected = pd.Series([na_value] * 3, index=[1, 2, 3], dtype=dtype) + tm.assert_series_equal(result, expected) + + def test_series_constructor_scalar_with_index(self, data, dtype): + scalar = data[0] + result = pd.Series(scalar, index=[1, 2, 3], dtype=dtype) + expected = pd.Series([scalar] * 3, index=[1, 2, 3], dtype=dtype) + tm.assert_series_equal(result, expected) + + result = pd.Series(scalar, index=["foo"], dtype=dtype) + expected = pd.Series([scalar], index=["foo"], dtype=dtype) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("from_series", [True, False]) + def test_dataframe_constructor_from_dict(self, data, from_series): + if from_series: + data = pd.Series(data) + result = pd.DataFrame({"A": data}) + assert result.dtypes["A"] == data.dtype + assert result.shape == (len(data), 1) + if hasattr(result._mgr, "blocks"): + assert isinstance(result._mgr.blocks[0], EABackedBlock) + assert isinstance(result._mgr.arrays[0], ExtensionArray) + + def test_dataframe_from_series(self, data): + result = pd.DataFrame(pd.Series(data)) + assert result.dtypes[0] == data.dtype + assert result.shape == (len(data), 1) + if hasattr(result._mgr, "blocks"): + assert isinstance(result._mgr.blocks[0], EABackedBlock) + assert isinstance(result._mgr.arrays[0], ExtensionArray) + + def test_series_given_mismatched_index_raises(self, data): + msg = r"Length of values \(3\) does not match length of index \(5\)" + with pytest.raises(ValueError, match=msg): + pd.Series(data[:3], index=[0, 1, 2, 3, 4]) + + def test_from_dtype(self, data): + # construct from our dtype & string dtype + dtype = data.dtype + + expected = pd.Series(data) + result = pd.Series(list(data), dtype=dtype) + tm.assert_series_equal(result, expected) + + result = pd.Series(list(data), dtype=str(dtype)) + tm.assert_series_equal(result, expected) + + # gh-30280 + + expected = pd.DataFrame(data).astype(dtype) + result = pd.DataFrame(list(data), dtype=dtype) + tm.assert_frame_equal(result, expected) + + result = pd.DataFrame(list(data), dtype=str(dtype)) + tm.assert_frame_equal(result, expected) + + def test_pandas_array(self, data): + # pd.array(extension_array) should be idempotent... + result = pd.array(data) + tm.assert_extension_array_equal(result, data) + + def test_pandas_array_dtype(self, data): + # ... but specifying dtype will override idempotency + result = pd.array(data, dtype=np.dtype(object)) + expected = pd.arrays.NumpyExtensionArray(np.asarray(data, dtype=object)) + tm.assert_equal(result, expected) + + def test_construct_empty_dataframe(self, dtype): + # GH 33623 + result = pd.DataFrame(columns=["a"], dtype=dtype) + expected = pd.DataFrame( + {"a": pd.array([], dtype=dtype)}, index=pd.RangeIndex(0) + ) + tm.assert_frame_equal(result, expected) + + def test_empty(self, dtype): + cls = dtype.construct_array_type() + result = cls._empty((4,), dtype=dtype) + assert isinstance(result, cls) + assert result.dtype == dtype + assert result.shape == (4,) + + # GH#19600 method on ExtensionDtype + result2 = dtype.empty((4,)) + assert isinstance(result2, cls) + assert result2.dtype == dtype + assert result2.shape == (4,) + + result2 = dtype.empty(4) + assert isinstance(result2, cls) + assert result2.dtype == dtype + assert result2.shape == (4,) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/dim2.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/dim2.py new file mode 100644 index 0000000000000000000000000000000000000000..132cda5a94ed00e74b7f36869bfe0c789dd5ccd7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/dim2.py @@ -0,0 +1,345 @@ +""" +Tests for 2D compatibility. +""" +import numpy as np +import pytest + +from pandas._libs.missing import is_matching_na + +from pandas.core.dtypes.common import ( + is_bool_dtype, + is_integer_dtype, +) + +import pandas as pd +import pandas._testing as tm +from pandas.core.arrays.integer import NUMPY_INT_TO_DTYPE + + +class Dim2CompatTests: + # Note: these are ONLY for ExtensionArray subclasses that support 2D arrays. + # i.e. not for pyarrow-backed EAs. + + @pytest.fixture(autouse=True) + def skip_if_doesnt_support_2d(self, dtype, request): + if not dtype._supports_2d: + node = request.node + # In cases where we are mixed in to ExtensionTests, we only want to + # skip tests that are defined in Dim2CompatTests + test_func = node._obj + if test_func.__qualname__.startswith("Dim2CompatTests"): + # TODO: is there a less hacky way of checking this? + pytest.skip(f"{dtype} does not support 2D.") + + def test_transpose(self, data): + arr2d = data.repeat(2).reshape(-1, 2) + shape = arr2d.shape + assert shape[0] != shape[-1] # otherwise the rest of the test is useless + + assert arr2d.T.shape == shape[::-1] + + def test_frame_from_2d_array(self, data): + arr2d = data.repeat(2).reshape(-1, 2) + + df = pd.DataFrame(arr2d) + expected = pd.DataFrame({0: arr2d[:, 0], 1: arr2d[:, 1]}) + tm.assert_frame_equal(df, expected) + + def test_swapaxes(self, data): + arr2d = data.repeat(2).reshape(-1, 2) + + result = arr2d.swapaxes(0, 1) + expected = arr2d.T + tm.assert_extension_array_equal(result, expected) + + def test_delete_2d(self, data): + arr2d = data.repeat(3).reshape(-1, 3) + + # axis = 0 + result = arr2d.delete(1, axis=0) + expected = data.delete(1).repeat(3).reshape(-1, 3) + tm.assert_extension_array_equal(result, expected) + + # axis = 1 + result = arr2d.delete(1, axis=1) + expected = data.repeat(2).reshape(-1, 2) + tm.assert_extension_array_equal(result, expected) + + def test_take_2d(self, data): + arr2d = data.reshape(-1, 1) + + result = arr2d.take([0, 0, -1], axis=0) + + expected = data.take([0, 0, -1]).reshape(-1, 1) + tm.assert_extension_array_equal(result, expected) + + def test_repr_2d(self, data): + # this could fail in a corner case where an element contained the name + res = repr(data.reshape(1, -1)) + assert res.count(f"<{type(data).__name__}") == 1 + + res = repr(data.reshape(-1, 1)) + assert res.count(f"<{type(data).__name__}") == 1 + + def test_reshape(self, data): + arr2d = data.reshape(-1, 1) + assert arr2d.shape == (data.size, 1) + assert len(arr2d) == len(data) + + arr2d = data.reshape((-1, 1)) + assert arr2d.shape == (data.size, 1) + assert len(arr2d) == len(data) + + with pytest.raises(ValueError): + data.reshape((data.size, 2)) + with pytest.raises(ValueError): + data.reshape(data.size, 2) + + def test_getitem_2d(self, data): + arr2d = data.reshape(1, -1) + + result = arr2d[0] + tm.assert_extension_array_equal(result, data) + + with pytest.raises(IndexError): + arr2d[1] + + with pytest.raises(IndexError): + arr2d[-2] + + result = arr2d[:] + tm.assert_extension_array_equal(result, arr2d) + + result = arr2d[:, :] + tm.assert_extension_array_equal(result, arr2d) + + result = arr2d[:, 0] + expected = data[[0]] + tm.assert_extension_array_equal(result, expected) + + # dimension-expanding getitem on 1D + result = data[:, np.newaxis] + tm.assert_extension_array_equal(result, arr2d.T) + + def test_iter_2d(self, data): + arr2d = data.reshape(1, -1) + + objs = list(iter(arr2d)) + assert len(objs) == arr2d.shape[0] + + for obj in objs: + assert isinstance(obj, type(data)) + assert obj.dtype == data.dtype + assert obj.ndim == 1 + assert len(obj) == arr2d.shape[1] + + def test_tolist_2d(self, data): + arr2d = data.reshape(1, -1) + + result = arr2d.tolist() + expected = [data.tolist()] + + assert isinstance(result, list) + assert all(isinstance(x, list) for x in result) + + assert result == expected + + def test_concat_2d(self, data): + left = type(data)._concat_same_type([data, data]).reshape(-1, 2) + right = left.copy() + + # axis=0 + result = left._concat_same_type([left, right], axis=0) + expected = data._concat_same_type([data] * 4).reshape(-1, 2) + tm.assert_extension_array_equal(result, expected) + + # axis=1 + result = left._concat_same_type([left, right], axis=1) + assert result.shape == (len(data), 4) + tm.assert_extension_array_equal(result[:, :2], left) + tm.assert_extension_array_equal(result[:, 2:], right) + + # axis > 1 -> invalid + msg = "axis 2 is out of bounds for array of dimension 2" + with pytest.raises(ValueError, match=msg): + left._concat_same_type([left, right], axis=2) + + @pytest.mark.parametrize("method", ["backfill", "pad"]) + def test_fillna_2d_method(self, data_missing, method): + # pad_or_backfill is always along axis=0 + arr = data_missing.repeat(2).reshape(2, 2) + assert arr[0].isna().all() + assert not arr[1].isna().any() + + result = arr._pad_or_backfill(method=method, limit=None) + + expected = data_missing._pad_or_backfill(method=method).repeat(2).reshape(2, 2) + tm.assert_extension_array_equal(result, expected) + + # Reverse so that backfill is not a no-op. + arr2 = arr[::-1] + assert not arr2[0].isna().any() + assert arr2[1].isna().all() + + result2 = arr2._pad_or_backfill(method=method, limit=None) + + expected2 = ( + data_missing[::-1]._pad_or_backfill(method=method).repeat(2).reshape(2, 2) + ) + tm.assert_extension_array_equal(result2, expected2) + + @pytest.mark.parametrize("method", ["mean", "median", "var", "std", "sum", "prod"]) + def test_reductions_2d_axis_none(self, data, method): + arr2d = data.reshape(1, -1) + + err_expected = None + err_result = None + try: + expected = getattr(data, method)() + except Exception as err: + # if the 1D reduction is invalid, the 2D reduction should be as well + err_expected = err + try: + result = getattr(arr2d, method)(axis=None) + except Exception as err2: + err_result = err2 + + else: + result = getattr(arr2d, method)(axis=None) + + if err_result is not None or err_expected is not None: + assert type(err_result) == type(err_expected) + return + + assert is_matching_na(result, expected) or result == expected + + @pytest.mark.parametrize("method", ["mean", "median", "var", "std", "sum", "prod"]) + @pytest.mark.parametrize("min_count", [0, 1]) + def test_reductions_2d_axis0(self, data, method, min_count): + if min_count == 1 and method not in ["sum", "prod"]: + pytest.skip(f"min_count not relevant for {method}") + + arr2d = data.reshape(1, -1) + + kwargs = {} + if method in ["std", "var"]: + # pass ddof=0 so we get all-zero std instead of all-NA std + kwargs["ddof"] = 0 + elif method in ["prod", "sum"]: + kwargs["min_count"] = min_count + + try: + result = getattr(arr2d, method)(axis=0, **kwargs) + except Exception as err: + try: + getattr(data, method)() + except Exception as err2: + assert type(err) == type(err2) + return + else: + raise AssertionError("Both reductions should raise or neither") + + def get_reduction_result_dtype(dtype): + # windows and 32bit builds will in some cases have int32/uint32 + # where other builds will have int64/uint64. + if dtype.itemsize == 8: + return dtype + elif dtype.kind in "ib": + return NUMPY_INT_TO_DTYPE[np.dtype(int)] + else: + # i.e. dtype.kind == "u" + return NUMPY_INT_TO_DTYPE[np.dtype("uint")] + + if method in ["sum", "prod"]: + # std and var are not dtype-preserving + expected = data + if data.dtype.kind in "iub": + dtype = get_reduction_result_dtype(data.dtype) + expected = data.astype(dtype) + assert dtype == expected.dtype + + if min_count == 0: + fill_value = 1 if method == "prod" else 0 + expected = expected.fillna(fill_value) + + tm.assert_extension_array_equal(result, expected) + elif method == "median": + # std and var are not dtype-preserving + expected = data + tm.assert_extension_array_equal(result, expected) + elif method in ["mean", "std", "var"]: + if is_integer_dtype(data) or is_bool_dtype(data): + data = data.astype("Float64") + if method == "mean": + tm.assert_extension_array_equal(result, data) + else: + tm.assert_extension_array_equal(result, data - data) + + @pytest.mark.parametrize("method", ["mean", "median", "var", "std", "sum", "prod"]) + def test_reductions_2d_axis1(self, data, method): + arr2d = data.reshape(1, -1) + + try: + result = getattr(arr2d, method)(axis=1) + except Exception as err: + try: + getattr(data, method)() + except Exception as err2: + assert type(err) == type(err2) + return + else: + raise AssertionError("Both reductions should raise or neither") + + # not necessarily type/dtype-preserving, so weaker assertions + assert result.shape == (1,) + expected_scalar = getattr(data, method)() + res = result[0] + assert is_matching_na(res, expected_scalar) or res == expected_scalar + + +class NDArrayBacked2DTests(Dim2CompatTests): + # More specific tests for NDArrayBackedExtensionArray subclasses + + def test_copy_order(self, data): + # We should be matching numpy semantics for the "order" keyword in 'copy' + arr2d = data.repeat(2).reshape(-1, 2) + assert arr2d._ndarray.flags["C_CONTIGUOUS"] + + res = arr2d.copy() + assert res._ndarray.flags["C_CONTIGUOUS"] + + res = arr2d[::2, ::2].copy() + assert res._ndarray.flags["C_CONTIGUOUS"] + + res = arr2d.copy("F") + assert not res._ndarray.flags["C_CONTIGUOUS"] + assert res._ndarray.flags["F_CONTIGUOUS"] + + res = arr2d.copy("K") + assert res._ndarray.flags["C_CONTIGUOUS"] + + res = arr2d.T.copy("K") + assert not res._ndarray.flags["C_CONTIGUOUS"] + assert res._ndarray.flags["F_CONTIGUOUS"] + + # order not accepted by numpy + msg = r"order must be one of 'C', 'F', 'A', or 'K' \(got 'Q'\)" + with pytest.raises(ValueError, match=msg): + arr2d.copy("Q") + + # neither contiguity + arr_nc = arr2d[::2] + assert not arr_nc._ndarray.flags["C_CONTIGUOUS"] + assert not arr_nc._ndarray.flags["F_CONTIGUOUS"] + + assert arr_nc.copy()._ndarray.flags["C_CONTIGUOUS"] + assert not arr_nc.copy()._ndarray.flags["F_CONTIGUOUS"] + + assert arr_nc.copy("C")._ndarray.flags["C_CONTIGUOUS"] + assert not arr_nc.copy("C")._ndarray.flags["F_CONTIGUOUS"] + + assert not arr_nc.copy("F")._ndarray.flags["C_CONTIGUOUS"] + assert arr_nc.copy("F")._ndarray.flags["F_CONTIGUOUS"] + + assert arr_nc.copy("K")._ndarray.flags["C_CONTIGUOUS"] + assert not arr_nc.copy("K")._ndarray.flags["F_CONTIGUOUS"] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/dtype.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/dtype.py new file mode 100644 index 0000000000000000000000000000000000000000..c7b768f6e3c88f32a7f9f5c945642e4d69a17c66 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/dtype.py @@ -0,0 +1,123 @@ +import numpy as np +import pytest + +import pandas as pd +import pandas._testing as tm +from pandas.api.types import ( + infer_dtype, + is_object_dtype, + is_string_dtype, +) + + +class BaseDtypeTests: + """Base class for ExtensionDtype classes""" + + def test_name(self, dtype): + assert isinstance(dtype.name, str) + + def test_kind(self, dtype): + valid = set("biufcmMOSUV") + assert dtype.kind in valid + + def test_is_dtype_from_name(self, dtype): + result = type(dtype).is_dtype(dtype.name) + assert result is True + + def test_is_dtype_unboxes_dtype(self, data, dtype): + assert dtype.is_dtype(data) is True + + def test_is_dtype_from_self(self, dtype): + result = type(dtype).is_dtype(dtype) + assert result is True + + def test_is_dtype_other_input(self, dtype): + assert dtype.is_dtype([1, 2, 3]) is False + + def test_is_not_string_type(self, dtype): + assert not is_string_dtype(dtype) + + def test_is_not_object_type(self, dtype): + assert not is_object_dtype(dtype) + + def test_eq_with_str(self, dtype): + assert dtype == dtype.name + assert dtype != dtype.name + "-suffix" + + def test_eq_with_numpy_object(self, dtype): + assert dtype != np.dtype("object") + + def test_eq_with_self(self, dtype): + assert dtype == dtype + assert dtype != object() + + def test_array_type(self, data, dtype): + assert dtype.construct_array_type() is type(data) + + def test_check_dtype(self, data): + dtype = data.dtype + + # check equivalency for using .dtypes + df = pd.DataFrame( + { + "A": pd.Series(data, dtype=dtype), + "B": data, + "C": pd.Series(["foo"] * len(data), dtype=object), + "D": 1, + } + ) + result = df.dtypes == str(dtype) + assert np.dtype("int64") != "Int64" + + expected = pd.Series([True, True, False, False], index=list("ABCD")) + + tm.assert_series_equal(result, expected) + + expected = pd.Series([True, True, False, False], index=list("ABCD")) + result = df.dtypes.apply(str) == str(dtype) + tm.assert_series_equal(result, expected) + + def test_hashable(self, dtype): + hash(dtype) # no error + + def test_str(self, dtype): + assert str(dtype) == dtype.name + + def test_eq(self, dtype): + assert dtype == dtype.name + assert dtype != "anonther_type" + + def test_construct_from_string_own_name(self, dtype): + result = dtype.construct_from_string(dtype.name) + assert type(result) is type(dtype) + + # check OK as classmethod + result = type(dtype).construct_from_string(dtype.name) + assert type(result) is type(dtype) + + def test_construct_from_string_another_type_raises(self, dtype): + msg = f"Cannot construct a '{type(dtype).__name__}' from 'another_type'" + with pytest.raises(TypeError, match=msg): + type(dtype).construct_from_string("another_type") + + def test_construct_from_string_wrong_type_raises(self, dtype): + with pytest.raises( + TypeError, + match="'construct_from_string' expects a string, got ", + ): + type(dtype).construct_from_string(0) + + def test_get_common_dtype(self, dtype): + # in practice we will not typically call this with a 1-length list + # (we shortcut to just use that dtype as the common dtype), but + # still testing as good practice to have this working (and it is the + # only case we can test in general) + assert dtype._get_common_dtype([dtype]) == dtype + + @pytest.mark.parametrize("skipna", [True, False]) + def test_infer_dtype(self, data, data_missing, skipna): + # only testing that this works without raising an error + res = infer_dtype(data, skipna=skipna) + assert isinstance(res, str) + res = infer_dtype(data_missing, skipna=skipna) + assert isinstance(res, str) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/getitem.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/getitem.py new file mode 100644 index 0000000000000000000000000000000000000000..5f0c1b960a4758e7cd5423188d5f190922b4eee4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/getitem.py @@ -0,0 +1,469 @@ +import numpy as np +import pytest + +import pandas as pd +import pandas._testing as tm + + +class BaseGetitemTests: + """Tests for ExtensionArray.__getitem__.""" + + def test_iloc_series(self, data): + ser = pd.Series(data) + result = ser.iloc[:4] + expected = pd.Series(data[:4]) + tm.assert_series_equal(result, expected) + + result = ser.iloc[[0, 1, 2, 3]] + tm.assert_series_equal(result, expected) + + def test_iloc_frame(self, data): + df = pd.DataFrame({"A": data, "B": np.arange(len(data), dtype="int64")}) + expected = pd.DataFrame({"A": data[:4]}) + + # slice -> frame + result = df.iloc[:4, [0]] + tm.assert_frame_equal(result, expected) + + # sequence -> frame + result = df.iloc[[0, 1, 2, 3], [0]] + tm.assert_frame_equal(result, expected) + + expected = pd.Series(data[:4], name="A") + + # slice -> series + result = df.iloc[:4, 0] + tm.assert_series_equal(result, expected) + + # sequence -> series + result = df.iloc[:4, 0] + tm.assert_series_equal(result, expected) + + # GH#32959 slice columns with step + result = df.iloc[:, ::2] + tm.assert_frame_equal(result, df[["A"]]) + result = df[["B", "A"]].iloc[:, ::2] + tm.assert_frame_equal(result, df[["B"]]) + + def test_iloc_frame_single_block(self, data): + # GH#32959 null slice along index, slice along columns with single-block + df = pd.DataFrame({"A": data}) + + result = df.iloc[:, :] + tm.assert_frame_equal(result, df) + + result = df.iloc[:, :1] + tm.assert_frame_equal(result, df) + + result = df.iloc[:, :2] + tm.assert_frame_equal(result, df) + + result = df.iloc[:, ::2] + tm.assert_frame_equal(result, df) + + result = df.iloc[:, 1:2] + tm.assert_frame_equal(result, df.iloc[:, :0]) + + result = df.iloc[:, -1:] + tm.assert_frame_equal(result, df) + + def test_loc_series(self, data): + ser = pd.Series(data) + result = ser.loc[:3] + expected = pd.Series(data[:4]) + tm.assert_series_equal(result, expected) + + result = ser.loc[[0, 1, 2, 3]] + tm.assert_series_equal(result, expected) + + def test_loc_frame(self, data): + df = pd.DataFrame({"A": data, "B": np.arange(len(data), dtype="int64")}) + expected = pd.DataFrame({"A": data[:4]}) + + # slice -> frame + result = df.loc[:3, ["A"]] + tm.assert_frame_equal(result, expected) + + # sequence -> frame + result = df.loc[[0, 1, 2, 3], ["A"]] + tm.assert_frame_equal(result, expected) + + expected = pd.Series(data[:4], name="A") + + # slice -> series + result = df.loc[:3, "A"] + tm.assert_series_equal(result, expected) + + # sequence -> series + result = df.loc[:3, "A"] + tm.assert_series_equal(result, expected) + + def test_loc_iloc_frame_single_dtype(self, data): + # GH#27110 bug in ExtensionBlock.iget caused df.iloc[n] to incorrectly + # return a scalar + df = pd.DataFrame({"A": data}) + expected = pd.Series([data[2]], index=["A"], name=2, dtype=data.dtype) + + result = df.loc[2] + tm.assert_series_equal(result, expected) + + expected = pd.Series( + [data[-1]], index=["A"], name=len(data) - 1, dtype=data.dtype + ) + result = df.iloc[-1] + tm.assert_series_equal(result, expected) + + def test_getitem_scalar(self, data): + result = data[0] + assert isinstance(result, data.dtype.type) + + result = pd.Series(data)[0] + assert isinstance(result, data.dtype.type) + + def test_getitem_invalid(self, data): + # TODO: box over scalar, [scalar], (scalar,)? + + msg = ( + r"only integers, slices \(`:`\), ellipsis \(`...`\), numpy.newaxis " + r"\(`None`\) and integer or boolean arrays are valid indices" + ) + with pytest.raises(IndexError, match=msg): + data["foo"] + with pytest.raises(IndexError, match=msg): + data[2.5] + + ub = len(data) + msg = "|".join( + [ + "list index out of range", # json + "index out of bounds", # pyarrow + "Out of bounds access", # Sparse + f"loc must be an integer between -{ub} and {ub}", # Sparse + f"index {ub+1} is out of bounds for axis 0 with size {ub}", + f"index -{ub+1} is out of bounds for axis 0 with size {ub}", + ] + ) + with pytest.raises(IndexError, match=msg): + data[ub + 1] + with pytest.raises(IndexError, match=msg): + data[-ub - 1] + + def test_getitem_scalar_na(self, data_missing, na_cmp, na_value): + result = data_missing[0] + assert na_cmp(result, na_value) + + def test_getitem_empty(self, data): + # Indexing with empty list + result = data[[]] + assert len(result) == 0 + assert isinstance(result, type(data)) + + expected = data[np.array([], dtype="int64")] + tm.assert_extension_array_equal(result, expected) + + def test_getitem_mask(self, data): + # Empty mask, raw array + mask = np.zeros(len(data), dtype=bool) + result = data[mask] + assert len(result) == 0 + assert isinstance(result, type(data)) + + # Empty mask, in series + mask = np.zeros(len(data), dtype=bool) + result = pd.Series(data)[mask] + assert len(result) == 0 + assert result.dtype == data.dtype + + # non-empty mask, raw array + mask[0] = True + result = data[mask] + assert len(result) == 1 + assert isinstance(result, type(data)) + + # non-empty mask, in series + result = pd.Series(data)[mask] + assert len(result) == 1 + assert result.dtype == data.dtype + + def test_getitem_mask_raises(self, data): + mask = np.array([True, False]) + msg = f"Boolean index has wrong length: 2 instead of {len(data)}" + with pytest.raises(IndexError, match=msg): + data[mask] + + mask = pd.array(mask, dtype="boolean") + with pytest.raises(IndexError, match=msg): + data[mask] + + def test_getitem_boolean_array_mask(self, data): + mask = pd.array(np.zeros(data.shape, dtype="bool"), dtype="boolean") + result = data[mask] + assert len(result) == 0 + assert isinstance(result, type(data)) + + result = pd.Series(data)[mask] + assert len(result) == 0 + assert result.dtype == data.dtype + + mask[:5] = True + expected = data.take([0, 1, 2, 3, 4]) + result = data[mask] + tm.assert_extension_array_equal(result, expected) + + expected = pd.Series(expected) + result = pd.Series(data)[mask] + tm.assert_series_equal(result, expected) + + def test_getitem_boolean_na_treated_as_false(self, data): + # https://github.com/pandas-dev/pandas/issues/31503 + mask = pd.array(np.zeros(data.shape, dtype="bool"), dtype="boolean") + mask[:2] = pd.NA + mask[2:4] = True + + result = data[mask] + expected = data[mask.fillna(False)] + + tm.assert_extension_array_equal(result, expected) + + s = pd.Series(data) + + result = s[mask] + expected = s[mask.fillna(False)] + + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "idx", + [[0, 1, 2], pd.array([0, 1, 2], dtype="Int64"), np.array([0, 1, 2])], + ids=["list", "integer-array", "numpy-array"], + ) + def test_getitem_integer_array(self, data, idx): + result = data[idx] + assert len(result) == 3 + assert isinstance(result, type(data)) + expected = data.take([0, 1, 2]) + tm.assert_extension_array_equal(result, expected) + + expected = pd.Series(expected) + result = pd.Series(data)[idx] + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "idx", + [[0, 1, 2, pd.NA], pd.array([0, 1, 2, pd.NA], dtype="Int64")], + ids=["list", "integer-array"], + ) + def test_getitem_integer_with_missing_raises(self, data, idx): + msg = "Cannot index with an integer indexer containing NA values" + with pytest.raises(ValueError, match=msg): + data[idx] + + @pytest.mark.xfail( + reason="Tries label-based and raises KeyError; " + "in some cases raises when calling np.asarray" + ) + @pytest.mark.parametrize( + "idx", + [[0, 1, 2, pd.NA], pd.array([0, 1, 2, pd.NA], dtype="Int64")], + ids=["list", "integer-array"], + ) + def test_getitem_series_integer_with_missing_raises(self, data, idx): + msg = "Cannot index with an integer indexer containing NA values" + # TODO: this raises KeyError about labels not found (it tries label-based) + + ser = pd.Series(data, index=[chr(100 + i) for i in range(len(data))]) + with pytest.raises(ValueError, match=msg): + ser[idx] + + def test_getitem_slice(self, data): + # getitem[slice] should return an array + result = data[slice(0)] # empty + assert isinstance(result, type(data)) + + result = data[slice(1)] # scalar + assert isinstance(result, type(data)) + + def test_getitem_ellipsis_and_slice(self, data): + # GH#40353 this is called from slice_block_rows + result = data[..., :] + tm.assert_extension_array_equal(result, data) + + result = data[:, ...] + tm.assert_extension_array_equal(result, data) + + result = data[..., :3] + tm.assert_extension_array_equal(result, data[:3]) + + result = data[:3, ...] + tm.assert_extension_array_equal(result, data[:3]) + + result = data[..., ::2] + tm.assert_extension_array_equal(result, data[::2]) + + result = data[::2, ...] + tm.assert_extension_array_equal(result, data[::2]) + + def test_get(self, data): + # GH 20882 + s = pd.Series(data, index=[2 * i for i in range(len(data))]) + assert s.get(4) == s.iloc[2] + + result = s.get([4, 6]) + expected = s.iloc[[2, 3]] + tm.assert_series_equal(result, expected) + + result = s.get(slice(2)) + expected = s.iloc[[0, 1]] + tm.assert_series_equal(result, expected) + + assert s.get(-1) is None + assert s.get(s.index.max() + 1) is None + + s = pd.Series(data[:6], index=list("abcdef")) + assert s.get("c") == s.iloc[2] + + result = s.get(slice("b", "d")) + expected = s.iloc[[1, 2, 3]] + tm.assert_series_equal(result, expected) + + result = s.get("Z") + assert result is None + + msg = "Series.__getitem__ treating keys as positions is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + assert s.get(4) == s.iloc[4] + assert s.get(-1) == s.iloc[-1] + assert s.get(len(s)) is None + + # GH 21257 + s = pd.Series(data) + with tm.assert_produces_warning(None): + # GH#45324 make sure we aren't giving a spurious FutureWarning + s2 = s[::2] + assert s2.get(1) is None + + def test_take_sequence(self, data): + result = pd.Series(data)[[0, 1, 3]] + assert result.iloc[0] == data[0] + assert result.iloc[1] == data[1] + assert result.iloc[2] == data[3] + + def test_take(self, data, na_value, na_cmp): + result = data.take([0, -1]) + assert result.dtype == data.dtype + assert result[0] == data[0] + assert result[1] == data[-1] + + result = data.take([0, -1], allow_fill=True, fill_value=na_value) + assert result[0] == data[0] + assert na_cmp(result[1], na_value) + + with pytest.raises(IndexError, match="out of bounds"): + data.take([len(data) + 1]) + + def test_take_empty(self, data, na_value, na_cmp): + empty = data[:0] + + result = empty.take([-1], allow_fill=True) + assert na_cmp(result[0], na_value) + + msg = "cannot do a non-empty take from an empty axes|out of bounds" + + with pytest.raises(IndexError, match=msg): + empty.take([-1]) + + with pytest.raises(IndexError, match="cannot do a non-empty take"): + empty.take([0, 1]) + + def test_take_negative(self, data): + # https://github.com/pandas-dev/pandas/issues/20640 + n = len(data) + result = data.take([0, -n, n - 1, -1]) + expected = data.take([0, 0, n - 1, n - 1]) + tm.assert_extension_array_equal(result, expected) + + def test_take_non_na_fill_value(self, data_missing): + fill_value = data_missing[1] # valid + na = data_missing[0] + + arr = data_missing._from_sequence( + [na, fill_value, na], dtype=data_missing.dtype + ) + result = arr.take([-1, 1], fill_value=fill_value, allow_fill=True) + expected = arr.take([1, 1]) + tm.assert_extension_array_equal(result, expected) + + def test_take_pandas_style_negative_raises(self, data, na_value): + with pytest.raises(ValueError, match=""): + data.take([0, -2], fill_value=na_value, allow_fill=True) + + @pytest.mark.parametrize("allow_fill", [True, False]) + def test_take_out_of_bounds_raises(self, data, allow_fill): + arr = data[:3] + + with pytest.raises(IndexError, match="out of bounds|out-of-bounds"): + arr.take(np.asarray([0, 3]), allow_fill=allow_fill) + + def test_take_series(self, data): + s = pd.Series(data) + result = s.take([0, -1]) + expected = pd.Series( + data._from_sequence([data[0], data[len(data) - 1]], dtype=s.dtype), + index=[0, len(data) - 1], + ) + tm.assert_series_equal(result, expected) + + def test_reindex(self, data, na_value): + s = pd.Series(data) + result = s.reindex([0, 1, 3]) + expected = pd.Series(data.take([0, 1, 3]), index=[0, 1, 3]) + tm.assert_series_equal(result, expected) + + n = len(data) + result = s.reindex([-1, 0, n]) + expected = pd.Series( + data._from_sequence([na_value, data[0], na_value], dtype=s.dtype), + index=[-1, 0, n], + ) + tm.assert_series_equal(result, expected) + + result = s.reindex([n, n + 1]) + expected = pd.Series( + data._from_sequence([na_value, na_value], dtype=s.dtype), index=[n, n + 1] + ) + tm.assert_series_equal(result, expected) + + def test_reindex_non_na_fill_value(self, data_missing): + valid = data_missing[1] + na = data_missing[0] + + arr = data_missing._from_sequence([na, valid], dtype=data_missing.dtype) + ser = pd.Series(arr) + result = ser.reindex([0, 1, 2], fill_value=valid) + expected = pd.Series( + data_missing._from_sequence([na, valid, valid], dtype=data_missing.dtype) + ) + + tm.assert_series_equal(result, expected) + + def test_loc_len1(self, data): + # see GH-27785 take_nd with indexer of len 1 resulting in wrong ndim + df = pd.DataFrame({"A": data}) + res = df.loc[[0], "A"] + assert res.ndim == 1 + assert res._mgr.arrays[0].ndim == 1 + if hasattr(res._mgr, "blocks"): + assert res._mgr._block.ndim == 1 + + def test_item(self, data): + # https://github.com/pandas-dev/pandas/pull/30175 + s = pd.Series(data) + result = s[:1].item() + assert result == data[0] + + msg = "can only convert an array of size 1 to a Python scalar" + with pytest.raises(ValueError, match=msg): + s[:0].item() + + with pytest.raises(ValueError, match=msg): + s.item() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/groupby.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/groupby.py new file mode 100644 index 0000000000000000000000000000000000000000..414683b02dcba24d9f490017c1e5c915ee0551f7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/groupby.py @@ -0,0 +1,174 @@ +import re + +import pytest + +from pandas.core.dtypes.common import ( + is_bool_dtype, + is_numeric_dtype, + is_object_dtype, + is_string_dtype, +) + +import pandas as pd +import pandas._testing as tm + + +@pytest.mark.filterwarnings( + "ignore:The default of observed=False is deprecated:FutureWarning" +) +class BaseGroupbyTests: + """Groupby-specific tests.""" + + def test_grouping_grouper(self, data_for_grouping): + df = pd.DataFrame( + { + "A": pd.Series( + ["B", "B", None, None, "A", "A", "B", "C"], dtype=object + ), + "B": data_for_grouping, + } + ) + gr1 = df.groupby("A")._grouper.groupings[0] + gr2 = df.groupby("B")._grouper.groupings[0] + + tm.assert_numpy_array_equal(gr1.grouping_vector, df.A.values) + tm.assert_extension_array_equal(gr2.grouping_vector, data_for_grouping) + + @pytest.mark.parametrize("as_index", [True, False]) + def test_groupby_extension_agg(self, as_index, data_for_grouping): + df = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1, 4], "B": data_for_grouping}) + + is_bool = data_for_grouping.dtype._is_boolean + if is_bool: + # only 2 unique values, and the final entry has c==b + # (see data_for_grouping docstring) + df = df.iloc[:-1] + + result = df.groupby("B", as_index=as_index).A.mean() + _, uniques = pd.factorize(data_for_grouping, sort=True) + + exp_vals = [3.0, 1.0, 4.0] + if is_bool: + exp_vals = exp_vals[:-1] + if as_index: + index = pd.Index(uniques, name="B") + expected = pd.Series(exp_vals, index=index, name="A") + tm.assert_series_equal(result, expected) + else: + expected = pd.DataFrame({"B": uniques, "A": exp_vals}) + tm.assert_frame_equal(result, expected) + + def test_groupby_agg_extension(self, data_for_grouping): + # GH#38980 groupby agg on extension type fails for non-numeric types + df = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1, 4], "B": data_for_grouping}) + + expected = df.iloc[[0, 2, 4, 7]] + expected = expected.set_index("A") + + result = df.groupby("A").agg({"B": "first"}) + tm.assert_frame_equal(result, expected) + + result = df.groupby("A").agg("first") + tm.assert_frame_equal(result, expected) + + result = df.groupby("A").first() + tm.assert_frame_equal(result, expected) + + def test_groupby_extension_no_sort(self, data_for_grouping): + df = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1, 4], "B": data_for_grouping}) + + is_bool = data_for_grouping.dtype._is_boolean + if is_bool: + # only 2 unique values, and the final entry has c==b + # (see data_for_grouping docstring) + df = df.iloc[:-1] + + result = df.groupby("B", sort=False).A.mean() + _, index = pd.factorize(data_for_grouping, sort=False) + + index = pd.Index(index, name="B") + exp_vals = [1.0, 3.0, 4.0] + if is_bool: + exp_vals = exp_vals[:-1] + expected = pd.Series(exp_vals, index=index, name="A") + tm.assert_series_equal(result, expected) + + def test_groupby_extension_transform(self, data_for_grouping): + is_bool = data_for_grouping.dtype._is_boolean + + valid = data_for_grouping[~data_for_grouping.isna()] + df = pd.DataFrame({"A": [1, 1, 3, 3, 1, 4], "B": valid}) + is_bool = data_for_grouping.dtype._is_boolean + if is_bool: + # only 2 unique values, and the final entry has c==b + # (see data_for_grouping docstring) + df = df.iloc[:-1] + + result = df.groupby("B").A.transform(len) + expected = pd.Series([3, 3, 2, 2, 3, 1], name="A") + if is_bool: + expected = expected[:-1] + + tm.assert_series_equal(result, expected) + + def test_groupby_extension_apply(self, data_for_grouping, groupby_apply_op): + df = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1, 4], "B": data_for_grouping}) + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + df.groupby("B", group_keys=False, observed=False).apply(groupby_apply_op) + df.groupby("B", group_keys=False, observed=False).A.apply(groupby_apply_op) + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + df.groupby("A", group_keys=False, observed=False).apply(groupby_apply_op) + df.groupby("A", group_keys=False, observed=False).B.apply(groupby_apply_op) + + def test_groupby_apply_identity(self, data_for_grouping): + df = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1, 4], "B": data_for_grouping}) + result = df.groupby("A").B.apply(lambda x: x.array) + expected = pd.Series( + [ + df.B.iloc[[0, 1, 6]].array, + df.B.iloc[[2, 3]].array, + df.B.iloc[[4, 5]].array, + df.B.iloc[[7]].array, + ], + index=pd.Index([1, 2, 3, 4], name="A"), + name="B", + ) + tm.assert_series_equal(result, expected) + + def test_in_numeric_groupby(self, data_for_grouping): + df = pd.DataFrame( + { + "A": [1, 1, 2, 2, 3, 3, 1, 4], + "B": data_for_grouping, + "C": [1, 1, 1, 1, 1, 1, 1, 1], + } + ) + + dtype = data_for_grouping.dtype + if ( + is_numeric_dtype(dtype) + or is_bool_dtype(dtype) + or dtype.name == "decimal" + or is_string_dtype(dtype) + or is_object_dtype(dtype) + or dtype.kind == "m" # in particular duration[*][pyarrow] + ): + expected = pd.Index(["B", "C"]) + result = df.groupby("A").sum().columns + else: + expected = pd.Index(["C"]) + + msg = "|".join( + [ + # period/datetime + "does not support sum operations", + # all others + re.escape(f"agg function failed [how->sum,dtype->{dtype}"), + ] + ) + with pytest.raises(TypeError, match=msg): + df.groupby("A").sum() + result = df.groupby("A").sum(numeric_only=True).columns + tm.assert_index_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/index.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/index.py new file mode 100644 index 0000000000000000000000000000000000000000..72c4ebfb5d84ae82f74a4c814e72372a651ae6b8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/index.py @@ -0,0 +1,19 @@ +""" +Tests for Indexes backed by arbitrary ExtensionArrays. +""" +import pandas as pd + + +class BaseIndexTests: + """Tests for Index object backed by an ExtensionArray""" + + def test_index_from_array(self, data): + idx = pd.Index(data) + assert data.dtype == idx.dtype + + def test_index_from_listlike_with_dtype(self, data): + idx = pd.Index(data, dtype=data.dtype) + assert idx.dtype == data.dtype + + idx = pd.Index(list(data), dtype=data.dtype) + assert idx.dtype == data.dtype diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/interface.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/interface.py new file mode 100644 index 0000000000000000000000000000000000000000..6683c87e2b8fc93d71b84918b3638d0879ce3372 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/interface.py @@ -0,0 +1,137 @@ +import numpy as np +import pytest + +from pandas.core.dtypes.cast import construct_1d_object_array_from_listlike +from pandas.core.dtypes.common import is_extension_array_dtype +from pandas.core.dtypes.dtypes import ExtensionDtype + +import pandas as pd +import pandas._testing as tm + + +class BaseInterfaceTests: + """Tests that the basic interface is satisfied.""" + + # ------------------------------------------------------------------------ + # Interface + # ------------------------------------------------------------------------ + + def test_len(self, data): + assert len(data) == 100 + + def test_size(self, data): + assert data.size == 100 + + def test_ndim(self, data): + assert data.ndim == 1 + + def test_can_hold_na_valid(self, data): + # GH-20761 + assert data._can_hold_na is True + + def test_contains(self, data, data_missing): + # GH-37867 + # Tests for membership checks. Membership checks for nan-likes is tricky and + # the settled on rule is: `nan_like in arr` is True if nan_like is + # arr.dtype.na_value and arr.isna().any() is True. Else the check returns False. + + na_value = data.dtype.na_value + # ensure data without missing values + data = data[~data.isna()] + + # first elements are non-missing + assert data[0] in data + assert data_missing[0] in data_missing + + # check the presence of na_value + assert na_value in data_missing + assert na_value not in data + + # the data can never contain other nan-likes than na_value + for na_value_obj in tm.NULL_OBJECTS: + if na_value_obj is na_value or type(na_value_obj) == type(na_value): + # type check for e.g. two instances of Decimal("NAN") + continue + assert na_value_obj not in data + assert na_value_obj not in data_missing + + def test_memory_usage(self, data): + s = pd.Series(data) + result = s.memory_usage(index=False) + assert result == s.nbytes + + def test_array_interface(self, data): + result = np.array(data) + assert result[0] == data[0] + + result = np.array(data, dtype=object) + expected = np.array(list(data), dtype=object) + if expected.ndim > 1: + # nested data, explicitly construct as 1D + expected = construct_1d_object_array_from_listlike(list(data)) + tm.assert_numpy_array_equal(result, expected) + + def test_is_extension_array_dtype(self, data): + assert is_extension_array_dtype(data) + assert is_extension_array_dtype(data.dtype) + assert is_extension_array_dtype(pd.Series(data)) + assert isinstance(data.dtype, ExtensionDtype) + + def test_no_values_attribute(self, data): + # GH-20735: EA's with .values attribute give problems with internal + # code, disallowing this for now until solved + assert not hasattr(data, "values") + assert not hasattr(data, "_values") + + def test_is_numeric_honored(self, data): + result = pd.Series(data) + if hasattr(result._mgr, "blocks"): + assert result._mgr.blocks[0].is_numeric is data.dtype._is_numeric + + def test_isna_extension_array(self, data_missing): + # If your `isna` returns an ExtensionArray, you must also implement + # _reduce. At the *very* least, you must implement any and all + na = data_missing.isna() + if is_extension_array_dtype(na): + assert na._reduce("any") + assert na.any() + + assert not na._reduce("all") + assert not na.all() + + assert na.dtype._is_boolean + + def test_copy(self, data): + # GH#27083 removing deep keyword from EA.copy + assert data[0] != data[1] + result = data.copy() + + if data.dtype._is_immutable: + pytest.skip(f"test_copy assumes mutability and {data.dtype} is immutable") + + data[1] = data[0] + assert result[1] != result[0] + + def test_view(self, data): + # view with no dtype should return a shallow copy, *not* the same + # object + assert data[1] != data[0] + + result = data.view() + assert result is not data + assert type(result) == type(data) + + if data.dtype._is_immutable: + pytest.skip(f"test_view assumes mutability and {data.dtype} is immutable") + + result[1] = result[0] + assert data[1] == data[0] + + # check specifically that the `dtype` kwarg is accepted + data.view(dtype=None) + + def test_tolist(self, data): + result = data.tolist() + expected = list(data) + assert isinstance(result, list) + assert result == expected diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/io.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/io.py new file mode 100644 index 0000000000000000000000000000000000000000..3a6f2eb5ba8b1854ac48a22efa02e89672b2f2ac --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/io.py @@ -0,0 +1,39 @@ +from io import StringIO + +import numpy as np +import pytest + +import pandas as pd +import pandas._testing as tm +from pandas.core.arrays import ExtensionArray + + +class BaseParsingTests: + @pytest.mark.parametrize("engine", ["c", "python"]) + def test_EA_types(self, engine, data, request): + if isinstance(data.dtype, pd.CategoricalDtype): + # in parsers.pyx _convert_with_dtype there is special-casing for + # Categorical that pre-empts _from_sequence_of_strings + pass + elif isinstance(data.dtype, pd.core.dtypes.dtypes.NumpyEADtype): + # These get unwrapped internally so are treated as numpy dtypes + # in the parsers.pyx code + pass + elif ( + type(data)._from_sequence_of_strings.__func__ + is ExtensionArray._from_sequence_of_strings.__func__ + ): + # i.e. the EA hasn't overridden _from_sequence_of_strings + mark = pytest.mark.xfail( + reason="_from_sequence_of_strings not implemented", + raises=NotImplementedError, + ) + request.node.add_marker(mark) + + df = pd.DataFrame({"with_dtype": pd.Series(data, dtype=str(data.dtype))}) + csv_output = df.to_csv(index=False, na_rep=np.nan) + result = pd.read_csv( + StringIO(csv_output), dtype={"with_dtype": str(data.dtype)}, engine=engine + ) + expected = df + tm.assert_frame_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/methods.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/methods.py new file mode 100644 index 0000000000000000000000000000000000000000..c803a8113b4a46be4dd99b06cb1026317786a241 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/methods.py @@ -0,0 +1,720 @@ +import inspect +import operator + +import numpy as np +import pytest + +from pandas._typing import Dtype + +from pandas.core.dtypes.common import is_bool_dtype +from pandas.core.dtypes.dtypes import NumpyEADtype +from pandas.core.dtypes.missing import na_value_for_dtype + +import pandas as pd +import pandas._testing as tm +from pandas.core.sorting import nargsort + + +class BaseMethodsTests: + """Various Series and DataFrame methods.""" + + def test_hash_pandas_object(self, data): + # _hash_pandas_object should return a uint64 ndarray of the same length + # as the data + from pandas.core.util.hashing import _default_hash_key + + res = data._hash_pandas_object( + encoding="utf-8", hash_key=_default_hash_key, categorize=False + ) + assert res.dtype == np.uint64 + assert res.shape == data.shape + + def test_value_counts_default_dropna(self, data): + # make sure we have consistent default dropna kwarg + if not hasattr(data, "value_counts"): + pytest.skip(f"value_counts is not implemented for {type(data)}") + sig = inspect.signature(data.value_counts) + kwarg = sig.parameters["dropna"] + assert kwarg.default is True + + @pytest.mark.parametrize("dropna", [True, False]) + def test_value_counts(self, all_data, dropna): + all_data = all_data[:10] + if dropna: + other = all_data[~all_data.isna()] + else: + other = all_data + + result = pd.Series(all_data).value_counts(dropna=dropna).sort_index() + expected = pd.Series(other).value_counts(dropna=dropna).sort_index() + + tm.assert_series_equal(result, expected) + + def test_value_counts_with_normalize(self, data): + # GH 33172 + data = data[:10].unique() + values = np.array(data[~data.isna()]) + ser = pd.Series(data, dtype=data.dtype) + + result = ser.value_counts(normalize=True).sort_index() + + if not isinstance(data, pd.Categorical): + expected = pd.Series( + [1 / len(values)] * len(values), index=result.index, name="proportion" + ) + else: + expected = pd.Series(0.0, index=result.index, name="proportion") + expected[result > 0] = 1 / len(values) + + if getattr(data.dtype, "storage", "") == "pyarrow" or isinstance( + data.dtype, pd.ArrowDtype + ): + # TODO: avoid special-casing + expected = expected.astype("double[pyarrow]") + elif getattr(data.dtype, "storage", "") == "pyarrow_numpy": + # TODO: avoid special-casing + expected = expected.astype("float64") + elif na_value_for_dtype(data.dtype) is pd.NA: + # TODO(GH#44692): avoid special-casing + expected = expected.astype("Float64") + + tm.assert_series_equal(result, expected) + + def test_count(self, data_missing): + df = pd.DataFrame({"A": data_missing}) + result = df.count(axis="columns") + expected = pd.Series([0, 1]) + tm.assert_series_equal(result, expected) + + def test_series_count(self, data_missing): + # GH#26835 + ser = pd.Series(data_missing) + result = ser.count() + expected = 1 + assert result == expected + + def test_apply_simple_series(self, data): + result = pd.Series(data).apply(id) + assert isinstance(result, pd.Series) + + @pytest.mark.parametrize("na_action", [None, "ignore"]) + def test_map(self, data_missing, na_action): + result = data_missing.map(lambda x: x, na_action=na_action) + expected = data_missing.to_numpy() + tm.assert_numpy_array_equal(result, expected) + + def test_argsort(self, data_for_sorting): + result = pd.Series(data_for_sorting).argsort() + # argsort result gets passed to take, so should be np.intp + expected = pd.Series(np.array([2, 0, 1], dtype=np.intp)) + tm.assert_series_equal(result, expected) + + def test_argsort_missing_array(self, data_missing_for_sorting): + result = data_missing_for_sorting.argsort() + # argsort result gets passed to take, so should be np.intp + expected = np.array([2, 0, 1], dtype=np.intp) + tm.assert_numpy_array_equal(result, expected) + + def test_argsort_missing(self, data_missing_for_sorting): + msg = "The behavior of Series.argsort in the presence of NA values" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = pd.Series(data_missing_for_sorting).argsort() + expected = pd.Series(np.array([1, -1, 0], dtype=np.intp)) + tm.assert_series_equal(result, expected) + + def test_argmin_argmax(self, data_for_sorting, data_missing_for_sorting, na_value): + # GH 24382 + is_bool = data_for_sorting.dtype._is_boolean + + exp_argmax = 1 + exp_argmax_repeated = 3 + if is_bool: + # See data_for_sorting docstring + exp_argmax = 0 + exp_argmax_repeated = 1 + + # data_for_sorting -> [B, C, A] with A < B < C + assert data_for_sorting.argmax() == exp_argmax + assert data_for_sorting.argmin() == 2 + + # with repeated values -> first occurrence + data = data_for_sorting.take([2, 0, 0, 1, 1, 2]) + assert data.argmax() == exp_argmax_repeated + assert data.argmin() == 0 + + # with missing values + # data_missing_for_sorting -> [B, NA, A] with A < B and NA missing. + assert data_missing_for_sorting.argmax() == 0 + assert data_missing_for_sorting.argmin() == 2 + + @pytest.mark.parametrize("method", ["argmax", "argmin"]) + def test_argmin_argmax_empty_array(self, method, data): + # GH 24382 + err_msg = "attempt to get" + with pytest.raises(ValueError, match=err_msg): + getattr(data[:0], method)() + + @pytest.mark.parametrize("method", ["argmax", "argmin"]) + def test_argmin_argmax_all_na(self, method, data, na_value): + # all missing with skipna=True is the same as empty + err_msg = "attempt to get" + data_na = type(data)._from_sequence([na_value, na_value], dtype=data.dtype) + with pytest.raises(ValueError, match=err_msg): + getattr(data_na, method)() + + @pytest.mark.parametrize( + "op_name, skipna, expected", + [ + ("idxmax", True, 0), + ("idxmin", True, 2), + ("argmax", True, 0), + ("argmin", True, 2), + ("idxmax", False, np.nan), + ("idxmin", False, np.nan), + ("argmax", False, -1), + ("argmin", False, -1), + ], + ) + def test_argreduce_series( + self, data_missing_for_sorting, op_name, skipna, expected + ): + # data_missing_for_sorting -> [B, NA, A] with A < B and NA missing. + warn = None + msg = "The behavior of Series.argmax/argmin" + if op_name.startswith("arg") and expected == -1: + warn = FutureWarning + if op_name.startswith("idx") and np.isnan(expected): + warn = FutureWarning + msg = f"The behavior of Series.{op_name}" + ser = pd.Series(data_missing_for_sorting) + with tm.assert_produces_warning(warn, match=msg): + result = getattr(ser, op_name)(skipna=skipna) + tm.assert_almost_equal(result, expected) + + def test_argmax_argmin_no_skipna_notimplemented(self, data_missing_for_sorting): + # GH#38733 + data = data_missing_for_sorting + + with pytest.raises(NotImplementedError, match=""): + data.argmin(skipna=False) + + with pytest.raises(NotImplementedError, match=""): + data.argmax(skipna=False) + + @pytest.mark.parametrize( + "na_position, expected", + [ + ("last", np.array([2, 0, 1], dtype=np.dtype("intp"))), + ("first", np.array([1, 2, 0], dtype=np.dtype("intp"))), + ], + ) + def test_nargsort(self, data_missing_for_sorting, na_position, expected): + # GH 25439 + result = nargsort(data_missing_for_sorting, na_position=na_position) + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize("ascending", [True, False]) + def test_sort_values(self, data_for_sorting, ascending, sort_by_key): + ser = pd.Series(data_for_sorting) + result = ser.sort_values(ascending=ascending, key=sort_by_key) + expected = ser.iloc[[2, 0, 1]] + if not ascending: + # GH 35922. Expect stable sort + if ser.nunique() == 2: + expected = ser.iloc[[0, 1, 2]] + else: + expected = ser.iloc[[1, 0, 2]] + + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("ascending", [True, False]) + def test_sort_values_missing( + self, data_missing_for_sorting, ascending, sort_by_key + ): + ser = pd.Series(data_missing_for_sorting) + result = ser.sort_values(ascending=ascending, key=sort_by_key) + if ascending: + expected = ser.iloc[[2, 0, 1]] + else: + expected = ser.iloc[[0, 2, 1]] + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("ascending", [True, False]) + def test_sort_values_frame(self, data_for_sorting, ascending): + df = pd.DataFrame({"A": [1, 2, 1], "B": data_for_sorting}) + result = df.sort_values(["A", "B"]) + expected = pd.DataFrame( + {"A": [1, 1, 2], "B": data_for_sorting.take([2, 0, 1])}, index=[2, 0, 1] + ) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("keep", ["first", "last", False]) + def test_duplicated(self, data, keep): + arr = data.take([0, 1, 0, 1]) + result = arr.duplicated(keep=keep) + if keep == "first": + expected = np.array([False, False, True, True]) + elif keep == "last": + expected = np.array([True, True, False, False]) + else: + expected = np.array([True, True, True, True]) + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize("box", [pd.Series, lambda x: x]) + @pytest.mark.parametrize("method", [lambda x: x.unique(), pd.unique]) + def test_unique(self, data, box, method): + duplicated = box(data._from_sequence([data[0], data[0]], dtype=data.dtype)) + + result = method(duplicated) + + assert len(result) == 1 + assert isinstance(result, type(data)) + assert result[0] == duplicated[0] + + def test_factorize(self, data_for_grouping): + codes, uniques = pd.factorize(data_for_grouping, use_na_sentinel=True) + + is_bool = data_for_grouping.dtype._is_boolean + if is_bool: + # only 2 unique values + expected_codes = np.array([0, 0, -1, -1, 1, 1, 0, 0], dtype=np.intp) + expected_uniques = data_for_grouping.take([0, 4]) + else: + expected_codes = np.array([0, 0, -1, -1, 1, 1, 0, 2], dtype=np.intp) + expected_uniques = data_for_grouping.take([0, 4, 7]) + + tm.assert_numpy_array_equal(codes, expected_codes) + tm.assert_extension_array_equal(uniques, expected_uniques) + + def test_factorize_equivalence(self, data_for_grouping): + codes_1, uniques_1 = pd.factorize(data_for_grouping, use_na_sentinel=True) + codes_2, uniques_2 = data_for_grouping.factorize(use_na_sentinel=True) + + tm.assert_numpy_array_equal(codes_1, codes_2) + tm.assert_extension_array_equal(uniques_1, uniques_2) + assert len(uniques_1) == len(pd.unique(uniques_1)) + assert uniques_1.dtype == data_for_grouping.dtype + + def test_factorize_empty(self, data): + codes, uniques = pd.factorize(data[:0]) + expected_codes = np.array([], dtype=np.intp) + expected_uniques = type(data)._from_sequence([], dtype=data[:0].dtype) + + tm.assert_numpy_array_equal(codes, expected_codes) + tm.assert_extension_array_equal(uniques, expected_uniques) + + def test_fillna_copy_frame(self, data_missing): + arr = data_missing.take([1, 1]) + df = pd.DataFrame({"A": arr}) + df_orig = df.copy() + + filled_val = df.iloc[0, 0] + result = df.fillna(filled_val) + + result.iloc[0, 0] = filled_val + + tm.assert_frame_equal(df, df_orig) + + def test_fillna_copy_series(self, data_missing): + arr = data_missing.take([1, 1]) + ser = pd.Series(arr, copy=False) + ser_orig = ser.copy() + + filled_val = ser[0] + result = ser.fillna(filled_val) + result.iloc[0] = filled_val + + tm.assert_series_equal(ser, ser_orig) + + def test_fillna_length_mismatch(self, data_missing): + msg = "Length of 'value' does not match." + with pytest.raises(ValueError, match=msg): + data_missing.fillna(data_missing.take([1])) + + # Subclasses can override if we expect e.g Sparse[bool], boolean, pyarrow[bool] + _combine_le_expected_dtype: Dtype = NumpyEADtype("bool") + + def test_combine_le(self, data_repeated): + # GH 20825 + # Test that combine works when doing a <= (le) comparison + orig_data1, orig_data2 = data_repeated(2) + s1 = pd.Series(orig_data1) + s2 = pd.Series(orig_data2) + result = s1.combine(s2, lambda x1, x2: x1 <= x2) + expected = pd.Series( + pd.array( + [a <= b for (a, b) in zip(list(orig_data1), list(orig_data2))], + dtype=self._combine_le_expected_dtype, + ) + ) + tm.assert_series_equal(result, expected) + + val = s1.iloc[0] + result = s1.combine(val, lambda x1, x2: x1 <= x2) + expected = pd.Series( + pd.array( + [a <= val for a in list(orig_data1)], + dtype=self._combine_le_expected_dtype, + ) + ) + tm.assert_series_equal(result, expected) + + def test_combine_add(self, data_repeated): + # GH 20825 + orig_data1, orig_data2 = data_repeated(2) + s1 = pd.Series(orig_data1) + s2 = pd.Series(orig_data2) + + # Check if the operation is supported pointwise for our scalars. If not, + # we will expect Series.combine to raise as well. + try: + with np.errstate(over="ignore"): + expected = pd.Series( + orig_data1._from_sequence( + [a + b for (a, b) in zip(list(orig_data1), list(orig_data2))] + ) + ) + except TypeError: + # If the operation is not supported pointwise for our scalars, + # then Series.combine should also raise + with pytest.raises(TypeError): + s1.combine(s2, lambda x1, x2: x1 + x2) + return + + result = s1.combine(s2, lambda x1, x2: x1 + x2) + tm.assert_series_equal(result, expected) + + val = s1.iloc[0] + result = s1.combine(val, lambda x1, x2: x1 + x2) + expected = pd.Series( + orig_data1._from_sequence([a + val for a in list(orig_data1)]) + ) + tm.assert_series_equal(result, expected) + + def test_combine_first(self, data): + # https://github.com/pandas-dev/pandas/issues/24147 + a = pd.Series(data[:3]) + b = pd.Series(data[2:5], index=[2, 3, 4]) + result = a.combine_first(b) + expected = pd.Series(data[:5]) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("frame", [True, False]) + @pytest.mark.parametrize( + "periods, indices", + [(-2, [2, 3, 4, -1, -1]), (0, [0, 1, 2, 3, 4]), (2, [-1, -1, 0, 1, 2])], + ) + def test_container_shift(self, data, frame, periods, indices): + # https://github.com/pandas-dev/pandas/issues/22386 + subset = data[:5] + data = pd.Series(subset, name="A") + expected = pd.Series(subset.take(indices, allow_fill=True), name="A") + + if frame: + result = data.to_frame(name="A").assign(B=1).shift(periods) + expected = pd.concat( + [expected, pd.Series([1] * 5, name="B").shift(periods)], axis=1 + ) + compare = tm.assert_frame_equal + else: + result = data.shift(periods) + compare = tm.assert_series_equal + + compare(result, expected) + + def test_shift_0_periods(self, data): + # GH#33856 shifting with periods=0 should return a copy, not same obj + result = data.shift(0) + assert data[0] != data[1] # otherwise below is invalid + data[0] = data[1] + assert result[0] != result[1] # i.e. not the same object/view + + @pytest.mark.parametrize("periods", [1, -2]) + def test_diff(self, data, periods): + data = data[:5] + if is_bool_dtype(data.dtype): + op = operator.xor + else: + op = operator.sub + try: + # does this array implement ops? + op(data, data) + except Exception: + pytest.skip(f"{type(data)} does not support diff") + s = pd.Series(data) + result = s.diff(periods) + expected = pd.Series(op(data, data.shift(periods))) + tm.assert_series_equal(result, expected) + + df = pd.DataFrame({"A": data, "B": [1.0] * 5}) + result = df.diff(periods) + if periods == 1: + b = [np.nan, 0, 0, 0, 0] + else: + b = [0, 0, 0, np.nan, np.nan] + expected = pd.DataFrame({"A": expected, "B": b}) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "periods, indices", + [[-4, [-1, -1]], [-1, [1, -1]], [0, [0, 1]], [1, [-1, 0]], [4, [-1, -1]]], + ) + def test_shift_non_empty_array(self, data, periods, indices): + # https://github.com/pandas-dev/pandas/issues/23911 + subset = data[:2] + result = subset.shift(periods) + expected = subset.take(indices, allow_fill=True) + tm.assert_extension_array_equal(result, expected) + + @pytest.mark.parametrize("periods", [-4, -1, 0, 1, 4]) + def test_shift_empty_array(self, data, periods): + # https://github.com/pandas-dev/pandas/issues/23911 + empty = data[:0] + result = empty.shift(periods) + expected = empty + tm.assert_extension_array_equal(result, expected) + + def test_shift_zero_copies(self, data): + # GH#31502 + result = data.shift(0) + assert result is not data + + result = data[:0].shift(2) + assert result is not data + + def test_shift_fill_value(self, data): + arr = data[:4] + fill_value = data[0] + result = arr.shift(1, fill_value=fill_value) + expected = data.take([0, 0, 1, 2]) + tm.assert_extension_array_equal(result, expected) + + result = arr.shift(-2, fill_value=fill_value) + expected = data.take([2, 3, 0, 0]) + tm.assert_extension_array_equal(result, expected) + + def test_not_hashable(self, data): + # We are in general mutable, so not hashable + with pytest.raises(TypeError, match="unhashable type"): + hash(data) + + def test_hash_pandas_object_works(self, data, as_frame): + # https://github.com/pandas-dev/pandas/issues/23066 + data = pd.Series(data) + if as_frame: + data = data.to_frame() + a = pd.util.hash_pandas_object(data) + b = pd.util.hash_pandas_object(data) + tm.assert_equal(a, b) + + def test_searchsorted(self, data_for_sorting, as_series): + if data_for_sorting.dtype._is_boolean: + return self._test_searchsorted_bool_dtypes(data_for_sorting, as_series) + + b, c, a = data_for_sorting + arr = data_for_sorting.take([2, 0, 1]) # to get [a, b, c] + + if as_series: + arr = pd.Series(arr) + assert arr.searchsorted(a) == 0 + assert arr.searchsorted(a, side="right") == 1 + + assert arr.searchsorted(b) == 1 + assert arr.searchsorted(b, side="right") == 2 + + assert arr.searchsorted(c) == 2 + assert arr.searchsorted(c, side="right") == 3 + + result = arr.searchsorted(arr.take([0, 2])) + expected = np.array([0, 2], dtype=np.intp) + + tm.assert_numpy_array_equal(result, expected) + + # sorter + sorter = np.array([1, 2, 0]) + assert data_for_sorting.searchsorted(a, sorter=sorter) == 0 + + def _test_searchsorted_bool_dtypes(self, data_for_sorting, as_series): + # We call this from test_searchsorted in cases where we have a + # boolean-like dtype. The non-bool test assumes we have more than 2 + # unique values. + dtype = data_for_sorting.dtype + data_for_sorting = pd.array([True, False], dtype=dtype) + b, a = data_for_sorting + arr = type(data_for_sorting)._from_sequence([a, b]) + + if as_series: + arr = pd.Series(arr) + assert arr.searchsorted(a) == 0 + assert arr.searchsorted(a, side="right") == 1 + + assert arr.searchsorted(b) == 1 + assert arr.searchsorted(b, side="right") == 2 + + result = arr.searchsorted(arr.take([0, 1])) + expected = np.array([0, 1], dtype=np.intp) + + tm.assert_numpy_array_equal(result, expected) + + # sorter + sorter = np.array([1, 0]) + assert data_for_sorting.searchsorted(a, sorter=sorter) == 0 + + def test_where_series(self, data, na_value, as_frame): + assert data[0] != data[1] + cls = type(data) + a, b = data[:2] + + orig = pd.Series(cls._from_sequence([a, a, b, b], dtype=data.dtype)) + ser = orig.copy() + cond = np.array([True, True, False, False]) + + if as_frame: + ser = ser.to_frame(name="a") + cond = cond.reshape(-1, 1) + + result = ser.where(cond) + expected = pd.Series( + cls._from_sequence([a, a, na_value, na_value], dtype=data.dtype) + ) + + if as_frame: + expected = expected.to_frame(name="a") + tm.assert_equal(result, expected) + + ser.mask(~cond, inplace=True) + tm.assert_equal(ser, expected) + + # array other + ser = orig.copy() + if as_frame: + ser = ser.to_frame(name="a") + cond = np.array([True, False, True, True]) + other = cls._from_sequence([a, b, a, b], dtype=data.dtype) + if as_frame: + other = pd.DataFrame({"a": other}) + cond = pd.DataFrame({"a": cond}) + result = ser.where(cond, other) + expected = pd.Series(cls._from_sequence([a, b, b, b], dtype=data.dtype)) + if as_frame: + expected = expected.to_frame(name="a") + tm.assert_equal(result, expected) + + ser.mask(~cond, other, inplace=True) + tm.assert_equal(ser, expected) + + @pytest.mark.parametrize("repeats", [0, 1, 2, [1, 2, 3]]) + def test_repeat(self, data, repeats, as_series, use_numpy): + arr = type(data)._from_sequence(data[:3], dtype=data.dtype) + if as_series: + arr = pd.Series(arr) + + result = np.repeat(arr, repeats) if use_numpy else arr.repeat(repeats) + + repeats = [repeats] * 3 if isinstance(repeats, int) else repeats + expected = [x for x, n in zip(arr, repeats) for _ in range(n)] + expected = type(data)._from_sequence(expected, dtype=data.dtype) + if as_series: + expected = pd.Series(expected, index=arr.index.repeat(repeats)) + + tm.assert_equal(result, expected) + + @pytest.mark.parametrize( + "repeats, kwargs, error, msg", + [ + (2, {"axis": 1}, ValueError, "axis"), + (-1, {}, ValueError, "negative"), + ([1, 2], {}, ValueError, "shape"), + (2, {"foo": "bar"}, TypeError, "'foo'"), + ], + ) + def test_repeat_raises(self, data, repeats, kwargs, error, msg, use_numpy): + with pytest.raises(error, match=msg): + if use_numpy: + np.repeat(data, repeats, **kwargs) + else: + data.repeat(repeats, **kwargs) + + def test_delete(self, data): + result = data.delete(0) + expected = data[1:] + tm.assert_extension_array_equal(result, expected) + + result = data.delete([1, 3]) + expected = data._concat_same_type([data[[0]], data[[2]], data[4:]]) + tm.assert_extension_array_equal(result, expected) + + def test_insert(self, data): + # insert at the beginning + result = data[1:].insert(0, data[0]) + tm.assert_extension_array_equal(result, data) + + result = data[1:].insert(-len(data[1:]), data[0]) + tm.assert_extension_array_equal(result, data) + + # insert at the middle + result = data[:-1].insert(4, data[-1]) + + taker = np.arange(len(data)) + taker[5:] = taker[4:-1] + taker[4] = len(data) - 1 + expected = data.take(taker) + tm.assert_extension_array_equal(result, expected) + + def test_insert_invalid(self, data, invalid_scalar): + item = invalid_scalar + + with pytest.raises((TypeError, ValueError)): + data.insert(0, item) + + with pytest.raises((TypeError, ValueError)): + data.insert(4, item) + + with pytest.raises((TypeError, ValueError)): + data.insert(len(data) - 1, item) + + def test_insert_invalid_loc(self, data): + ub = len(data) + + with pytest.raises(IndexError): + data.insert(ub + 1, data[0]) + + with pytest.raises(IndexError): + data.insert(-ub - 1, data[0]) + + with pytest.raises(TypeError): + # we expect TypeError here instead of IndexError to match np.insert + data.insert(1.5, data[0]) + + @pytest.mark.parametrize("box", [pd.array, pd.Series, pd.DataFrame]) + def test_equals(self, data, na_value, as_series, box): + data2 = type(data)._from_sequence([data[0]] * len(data), dtype=data.dtype) + data_na = type(data)._from_sequence([na_value] * len(data), dtype=data.dtype) + + data = tm.box_expected(data, box, transpose=False) + data2 = tm.box_expected(data2, box, transpose=False) + data_na = tm.box_expected(data_na, box, transpose=False) + + # we are asserting with `is True/False` explicitly, to test that the + # result is an actual Python bool, and not something "truthy" + + assert data.equals(data) is True + assert data.equals(data.copy()) is True + + # unequal other data + assert data.equals(data2) is False + assert data.equals(data_na) is False + + # different length + assert data[:2].equals(data[:3]) is False + + # empty are equal + assert data[:0].equals(data[:0]) is True + + # other types + assert data.equals(None) is False + assert data[[0]].equals(data[0]) is False + + def test_equals_same_data_different_object(self, data): + # https://github.com/pandas-dev/pandas/issues/34660 + assert pd.Series(data).equals(pd.Series(data)) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/missing.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/missing.py new file mode 100644 index 0000000000000000000000000000000000000000..fb15b2dec869c7d311ddf0b52d13de4e66078dbc --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/missing.py @@ -0,0 +1,190 @@ +import numpy as np +import pytest + +import pandas as pd +import pandas._testing as tm + + +class BaseMissingTests: + def test_isna(self, data_missing): + expected = np.array([True, False]) + + result = pd.isna(data_missing) + tm.assert_numpy_array_equal(result, expected) + + result = pd.Series(data_missing).isna() + expected = pd.Series(expected) + tm.assert_series_equal(result, expected) + + # GH 21189 + result = pd.Series(data_missing).drop([0, 1]).isna() + expected = pd.Series([], dtype=bool) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("na_func", ["isna", "notna"]) + def test_isna_returns_copy(self, data_missing, na_func): + result = pd.Series(data_missing) + expected = result.copy() + mask = getattr(result, na_func)() + if isinstance(mask.dtype, pd.SparseDtype): + # TODO: GH 57739 + mask = np.array(mask) + mask.flags.writeable = True + + mask[:] = True + tm.assert_series_equal(result, expected) + + def test_dropna_array(self, data_missing): + result = data_missing.dropna() + expected = data_missing[[1]] + tm.assert_extension_array_equal(result, expected) + + def test_dropna_series(self, data_missing): + ser = pd.Series(data_missing) + result = ser.dropna() + expected = ser.iloc[[1]] + tm.assert_series_equal(result, expected) + + def test_dropna_frame(self, data_missing): + df = pd.DataFrame({"A": data_missing}, columns=pd.Index(["A"], dtype=object)) + + # defaults + result = df.dropna() + expected = df.iloc[[1]] + tm.assert_frame_equal(result, expected) + + # axis = 1 + result = df.dropna(axis="columns") + expected = pd.DataFrame(index=pd.RangeIndex(2), columns=pd.Index([])) + tm.assert_frame_equal(result, expected) + + # multiple + df = pd.DataFrame({"A": data_missing, "B": [1, np.nan]}) + result = df.dropna() + expected = df.iloc[:0] + tm.assert_frame_equal(result, expected) + + def test_fillna_scalar(self, data_missing): + valid = data_missing[1] + result = data_missing.fillna(valid) + expected = data_missing.fillna(valid) + tm.assert_extension_array_equal(result, expected) + + @pytest.mark.filterwarnings( + "ignore:Series.fillna with 'method' is deprecated:FutureWarning" + ) + def test_fillna_limit_pad(self, data_missing): + arr = data_missing.take([1, 0, 0, 0, 1]) + result = pd.Series(arr).ffill(limit=2) + expected = pd.Series(data_missing.take([1, 1, 1, 0, 1])) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "limit_area, input_ilocs, expected_ilocs", + [ + ("outside", [1, 0, 0, 0, 1], [1, 0, 0, 0, 1]), + ("outside", [1, 0, 1, 0, 1], [1, 0, 1, 0, 1]), + ("outside", [0, 1, 1, 1, 0], [0, 1, 1, 1, 1]), + ("outside", [0, 1, 0, 1, 0], [0, 1, 0, 1, 1]), + ("inside", [1, 0, 0, 0, 1], [1, 1, 1, 1, 1]), + ("inside", [1, 0, 1, 0, 1], [1, 1, 1, 1, 1]), + ("inside", [0, 1, 1, 1, 0], [0, 1, 1, 1, 0]), + ("inside", [0, 1, 0, 1, 0], [0, 1, 1, 1, 0]), + ], + ) + def test_ffill_limit_area( + self, data_missing, limit_area, input_ilocs, expected_ilocs + ): + # GH#56616 + arr = data_missing.take(input_ilocs) + result = pd.Series(arr).ffill(limit_area=limit_area) + expected = pd.Series(data_missing.take(expected_ilocs)) + tm.assert_series_equal(result, expected) + + @pytest.mark.filterwarnings( + "ignore:Series.fillna with 'method' is deprecated:FutureWarning" + ) + def test_fillna_limit_backfill(self, data_missing): + arr = data_missing.take([1, 0, 0, 0, 1]) + result = pd.Series(arr).fillna(method="backfill", limit=2) + expected = pd.Series(data_missing.take([1, 0, 1, 1, 1])) + tm.assert_series_equal(result, expected) + + def test_fillna_no_op_returns_copy(self, data): + data = data[~data.isna()] + + valid = data[0] + result = data.fillna(valid) + assert result is not data + tm.assert_extension_array_equal(result, data) + + result = data._pad_or_backfill(method="backfill") + assert result is not data + tm.assert_extension_array_equal(result, data) + + def test_fillna_series(self, data_missing): + fill_value = data_missing[1] + ser = pd.Series(data_missing) + + result = ser.fillna(fill_value) + expected = pd.Series( + data_missing._from_sequence( + [fill_value, fill_value], dtype=data_missing.dtype + ) + ) + tm.assert_series_equal(result, expected) + + # Fill with a series + result = ser.fillna(expected) + tm.assert_series_equal(result, expected) + + # Fill with a series not affecting the missing values + result = ser.fillna(ser) + tm.assert_series_equal(result, ser) + + def test_fillna_series_method(self, data_missing, fillna_method): + fill_value = data_missing[1] + + if fillna_method == "ffill": + data_missing = data_missing[::-1] + + result = getattr(pd.Series(data_missing), fillna_method)() + expected = pd.Series( + data_missing._from_sequence( + [fill_value, fill_value], dtype=data_missing.dtype + ) + ) + + tm.assert_series_equal(result, expected) + + def test_fillna_frame(self, data_missing): + fill_value = data_missing[1] + + result = pd.DataFrame({"A": data_missing, "B": [1, 2]}).fillna(fill_value) + + expected = pd.DataFrame( + { + "A": data_missing._from_sequence( + [fill_value, fill_value], dtype=data_missing.dtype + ), + "B": [1, 2], + } + ) + + tm.assert_frame_equal(result, expected) + + def test_fillna_fill_other(self, data): + result = pd.DataFrame({"A": data, "B": [np.nan] * len(data)}).fillna({"B": 0.0}) + + expected = pd.DataFrame({"A": data, "B": [0.0] * len(result)}) + + tm.assert_frame_equal(result, expected) + + def test_use_inf_as_na_no_effect(self, data_missing): + ser = pd.Series(data_missing) + expected = ser.isna() + msg = "use_inf_as_na option is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + with pd.option_context("mode.use_inf_as_na", True): + result = ser.isna() + tm.assert_series_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/ops.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..5cd66d8a874c70af8ce26dcbd45fbf653086f861 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/ops.py @@ -0,0 +1,299 @@ +from __future__ import annotations + +from typing import final + +import numpy as np +import pytest + +from pandas._config import using_pyarrow_string_dtype + +from pandas.core.dtypes.common import is_string_dtype + +import pandas as pd +import pandas._testing as tm +from pandas.core import ops + + +class BaseOpsUtil: + series_scalar_exc: type[Exception] | None = TypeError + frame_scalar_exc: type[Exception] | None = TypeError + series_array_exc: type[Exception] | None = TypeError + divmod_exc: type[Exception] | None = TypeError + + def _get_expected_exception( + self, op_name: str, obj, other + ) -> type[Exception] | None: + # Find the Exception, if any we expect to raise calling + # obj.__op_name__(other) + + # The self.obj_bar_exc pattern isn't great in part because it can depend + # on op_name or dtypes, but we use it here for backward-compatibility. + if op_name in ["__divmod__", "__rdivmod__"]: + result = self.divmod_exc + elif isinstance(obj, pd.Series) and isinstance(other, pd.Series): + result = self.series_array_exc + elif isinstance(obj, pd.Series): + result = self.series_scalar_exc + else: + result = self.frame_scalar_exc + + if using_pyarrow_string_dtype() and result is not None: + import pyarrow as pa + + result = ( # type: ignore[assignment] + result, + pa.lib.ArrowNotImplementedError, + NotImplementedError, + ) + return result + + def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result): + # In _check_op we check that the result of a pointwise operation + # (found via _combine) matches the result of the vectorized + # operation obj.__op_name__(other). + # In some cases pandas dtype inference on the scalar result may not + # give a matching dtype even if both operations are behaving "correctly". + # In these cases, do extra required casting here. + return pointwise_result + + def get_op_from_name(self, op_name: str): + return tm.get_op_from_name(op_name) + + # Subclasses are not expected to need to override check_opname, _check_op, + # _check_divmod_op, or _combine. + # Ideally any relevant overriding can be done in _cast_pointwise_result, + # get_op_from_name, and the specification of `exc`. If you find a use + # case that still requires overriding _check_op or _combine, please let + # us know at github.com/pandas-dev/pandas/issues + @final + def check_opname(self, ser: pd.Series, op_name: str, other): + exc = self._get_expected_exception(op_name, ser, other) + op = self.get_op_from_name(op_name) + + self._check_op(ser, op, other, op_name, exc) + + # see comment on check_opname + @final + def _combine(self, obj, other, op): + if isinstance(obj, pd.DataFrame): + if len(obj.columns) != 1: + raise NotImplementedError + expected = obj.iloc[:, 0].combine(other, op).to_frame() + else: + expected = obj.combine(other, op) + return expected + + # see comment on check_opname + @final + def _check_op( + self, ser: pd.Series, op, other, op_name: str, exc=NotImplementedError + ): + # Check that the Series/DataFrame arithmetic/comparison method matches + # the pointwise result from _combine. + + if exc is None: + result = op(ser, other) + expected = self._combine(ser, other, op) + expected = self._cast_pointwise_result(op_name, ser, other, expected) + assert isinstance(result, type(ser)) + tm.assert_equal(result, expected) + else: + with pytest.raises(exc): + op(ser, other) + + # see comment on check_opname + @final + def _check_divmod_op(self, ser: pd.Series, op, other): + # check that divmod behavior matches behavior of floordiv+mod + if op is divmod: + exc = self._get_expected_exception("__divmod__", ser, other) + else: + exc = self._get_expected_exception("__rdivmod__", ser, other) + if exc is None: + result_div, result_mod = op(ser, other) + if op is divmod: + expected_div, expected_mod = ser // other, ser % other + else: + expected_div, expected_mod = other // ser, other % ser + tm.assert_series_equal(result_div, expected_div) + tm.assert_series_equal(result_mod, expected_mod) + else: + with pytest.raises(exc): + divmod(ser, other) + + +class BaseArithmeticOpsTests(BaseOpsUtil): + """ + Various Series and DataFrame arithmetic ops methods. + + Subclasses supporting various ops should set the class variables + to indicate that they support ops of that kind + + * series_scalar_exc = TypeError + * frame_scalar_exc = TypeError + * series_array_exc = TypeError + * divmod_exc = TypeError + """ + + series_scalar_exc: type[Exception] | None = TypeError + frame_scalar_exc: type[Exception] | None = TypeError + series_array_exc: type[Exception] | None = TypeError + divmod_exc: type[Exception] | None = TypeError + + def test_arith_series_with_scalar(self, data, all_arithmetic_operators): + # series & scalar + if all_arithmetic_operators == "__rmod__" and is_string_dtype(data.dtype): + pytest.skip("Skip testing Python string formatting") + + op_name = all_arithmetic_operators + ser = pd.Series(data) + self.check_opname(ser, op_name, ser.iloc[0]) + + def test_arith_frame_with_scalar(self, data, all_arithmetic_operators): + # frame & scalar + if all_arithmetic_operators == "__rmod__" and is_string_dtype(data.dtype): + pytest.skip("Skip testing Python string formatting") + + op_name = all_arithmetic_operators + df = pd.DataFrame({"A": data}) + self.check_opname(df, op_name, data[0]) + + def test_arith_series_with_array(self, data, all_arithmetic_operators): + # ndarray & other series + op_name = all_arithmetic_operators + ser = pd.Series(data) + self.check_opname(ser, op_name, pd.Series([ser.iloc[0]] * len(ser))) + + def test_divmod(self, data): + ser = pd.Series(data) + self._check_divmod_op(ser, divmod, 1) + self._check_divmod_op(1, ops.rdivmod, ser) + + def test_divmod_series_array(self, data, data_for_twos): + ser = pd.Series(data) + self._check_divmod_op(ser, divmod, data) + + other = data_for_twos + self._check_divmod_op(other, ops.rdivmod, ser) + + other = pd.Series(other) + self._check_divmod_op(other, ops.rdivmod, ser) + + def test_add_series_with_extension_array(self, data): + # Check adding an ExtensionArray to a Series of the same dtype matches + # the behavior of adding the arrays directly and then wrapping in a + # Series. + + ser = pd.Series(data) + + exc = self._get_expected_exception("__add__", ser, data) + if exc is not None: + with pytest.raises(exc): + ser + data + return + + result = ser + data + expected = pd.Series(data + data) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("box", [pd.Series, pd.DataFrame, pd.Index]) + @pytest.mark.parametrize( + "op_name", + [ + x + for x in tm.arithmetic_dunder_methods + tm.comparison_dunder_methods + if not x.startswith("__r") + ], + ) + def test_direct_arith_with_ndframe_returns_not_implemented( + self, data, box, op_name + ): + # EAs should return NotImplemented for ops with Series/DataFrame/Index + # Pandas takes care of unboxing the series and calling the EA's op. + other = box(data) + + if hasattr(data, op_name): + result = getattr(data, op_name)(other) + assert result is NotImplemented + + +class BaseComparisonOpsTests(BaseOpsUtil): + """Various Series and DataFrame comparison ops methods.""" + + def _compare_other(self, ser: pd.Series, data, op, other): + if op.__name__ in ["eq", "ne"]: + # comparison should match point-wise comparisons + result = op(ser, other) + expected = ser.combine(other, op) + expected = self._cast_pointwise_result(op.__name__, ser, other, expected) + tm.assert_series_equal(result, expected) + + else: + exc = None + try: + result = op(ser, other) + except Exception as err: + exc = err + + if exc is None: + # Didn't error, then should match pointwise behavior + expected = ser.combine(other, op) + expected = self._cast_pointwise_result( + op.__name__, ser, other, expected + ) + tm.assert_series_equal(result, expected) + else: + with pytest.raises(type(exc)): + ser.combine(other, op) + + def test_compare_scalar(self, data, comparison_op): + ser = pd.Series(data) + self._compare_other(ser, data, comparison_op, 0) + + def test_compare_array(self, data, comparison_op): + ser = pd.Series(data) + other = pd.Series([data[0]] * len(data), dtype=data.dtype) + self._compare_other(ser, data, comparison_op, other) + + +class BaseUnaryOpsTests(BaseOpsUtil): + def test_invert(self, data): + ser = pd.Series(data, name="name") + try: + # 10 is an arbitrary choice here, just avoid iterating over + # the whole array to trim test runtime + [~x for x in data[:10]] + except TypeError: + # scalars don't support invert -> we don't expect the vectorized + # operation to succeed + with pytest.raises(TypeError): + ~ser + with pytest.raises(TypeError): + ~data + else: + # Note we do not reuse the pointwise result to construct expected + # because python semantics for negating bools are weird see GH#54569 + result = ~ser + expected = pd.Series(~data, name="name") + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("ufunc", [np.positive, np.negative, np.abs]) + def test_unary_ufunc_dunder_equivalence(self, data, ufunc): + # the dunder __pos__ works if and only if np.positive works, + # same for __neg__/np.negative and __abs__/np.abs + attr = {np.positive: "__pos__", np.negative: "__neg__", np.abs: "__abs__"}[ + ufunc + ] + + exc = None + try: + result = getattr(data, attr)() + except Exception as err: + exc = err + + # if __pos__ raised, then so should the ufunc + with pytest.raises((type(exc), TypeError)): + ufunc(data) + else: + alt = ufunc(data) + tm.assert_extension_array_equal(result, alt) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/printing.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/printing.py new file mode 100644 index 0000000000000000000000000000000000000000..b20236ec107b04a09238c472a1d7172256334d3b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/printing.py @@ -0,0 +1,41 @@ +import io + +import pytest + +import pandas as pd + + +class BasePrintingTests: + """Tests checking the formatting of your EA when printed.""" + + @pytest.mark.parametrize("size", ["big", "small"]) + def test_array_repr(self, data, size): + if size == "small": + data = data[:5] + else: + data = type(data)._concat_same_type([data] * 5) + + result = repr(data) + assert type(data).__name__ in result + assert f"Length: {len(data)}" in result + assert str(data.dtype) in result + if size == "big": + assert "..." in result + + def test_array_repr_unicode(self, data): + result = str(data) + assert isinstance(result, str) + + def test_series_repr(self, data): + ser = pd.Series(data) + assert data.dtype.name in repr(ser) + + def test_dataframe_repr(self, data): + df = pd.DataFrame({"A": data}) + repr(df) + + def test_dtype_name_in_info(self, data): + buf = io.StringIO() + pd.DataFrame({"A": data}).info(buf=buf) + result = buf.getvalue() + assert data.dtype.name in result diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/reduce.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/reduce.py new file mode 100644 index 0000000000000000000000000000000000000000..6ea1b3a6fbe9da36e430c26b7ce7bcb706464108 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/reduce.py @@ -0,0 +1,153 @@ +from typing import final + +import pytest + +import pandas as pd +import pandas._testing as tm +from pandas.api.types import is_numeric_dtype + + +class BaseReduceTests: + """ + Reduction specific tests. Generally these only + make sense for numeric/boolean operations. + """ + + def _supports_reduction(self, ser: pd.Series, op_name: str) -> bool: + # Specify if we expect this reduction to succeed. + return False + + def check_reduce(self, ser: pd.Series, op_name: str, skipna: bool): + # We perform the same operation on the np.float64 data and check + # that the results match. Override if you need to cast to something + # other than float64. + res_op = getattr(ser, op_name) + + try: + alt = ser.astype("float64") + except (TypeError, ValueError): + # e.g. Interval can't cast (TypeError), StringArray can't cast + # (ValueError), so let's cast to object and do + # the reduction pointwise + alt = ser.astype(object) + + exp_op = getattr(alt, op_name) + if op_name == "count": + result = res_op() + expected = exp_op() + else: + result = res_op(skipna=skipna) + expected = exp_op(skipna=skipna) + tm.assert_almost_equal(result, expected) + + def _get_expected_reduction_dtype(self, arr, op_name: str, skipna: bool): + # Find the expected dtype when the given reduction is done on a DataFrame + # column with this array. The default assumes float64-like behavior, + # i.e. retains the dtype. + return arr.dtype + + # We anticipate that authors should not need to override check_reduce_frame, + # but should be able to do any necessary overriding in + # _get_expected_reduction_dtype. If you have a use case where this + # does not hold, please let us know at github.com/pandas-dev/pandas/issues. + @final + def check_reduce_frame(self, ser: pd.Series, op_name: str, skipna: bool): + # Check that the 2D reduction done in a DataFrame reduction "looks like" + # a wrapped version of the 1D reduction done by Series. + arr = ser.array + df = pd.DataFrame({"a": arr}) + + kwargs = {"ddof": 1} if op_name in ["var", "std"] else {} + + cmp_dtype = self._get_expected_reduction_dtype(arr, op_name, skipna) + + # The DataFrame method just calls arr._reduce with keepdims=True, + # so this first check is perfunctory. + result1 = arr._reduce(op_name, skipna=skipna, keepdims=True, **kwargs) + result2 = getattr(df, op_name)(skipna=skipna, **kwargs).array + tm.assert_extension_array_equal(result1, result2) + + # Check that the 2D reduction looks like a wrapped version of the + # 1D reduction + if not skipna and ser.isna().any(): + expected = pd.array([pd.NA], dtype=cmp_dtype) + else: + exp_value = getattr(ser.dropna(), op_name)() + expected = pd.array([exp_value], dtype=cmp_dtype) + + tm.assert_extension_array_equal(result1, expected) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_reduce_series_boolean(self, data, all_boolean_reductions, skipna): + op_name = all_boolean_reductions + ser = pd.Series(data) + + if not self._supports_reduction(ser, op_name): + # TODO: the message being checked here isn't actually checking anything + msg = ( + "[Cc]annot perform|Categorical is not ordered for operation|" + "does not support reduction|" + ) + + with pytest.raises(TypeError, match=msg): + getattr(ser, op_name)(skipna=skipna) + + else: + self.check_reduce(ser, op_name, skipna) + + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + @pytest.mark.parametrize("skipna", [True, False]) + def test_reduce_series_numeric(self, data, all_numeric_reductions, skipna): + op_name = all_numeric_reductions + ser = pd.Series(data) + + if not self._supports_reduction(ser, op_name): + # TODO: the message being checked here isn't actually checking anything + msg = ( + "[Cc]annot perform|Categorical is not ordered for operation|" + "does not support reduction|" + ) + + with pytest.raises(TypeError, match=msg): + getattr(ser, op_name)(skipna=skipna) + + else: + # min/max with empty produce numpy warnings + self.check_reduce(ser, op_name, skipna) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_reduce_frame(self, data, all_numeric_reductions, skipna): + op_name = all_numeric_reductions + ser = pd.Series(data) + if not is_numeric_dtype(ser.dtype): + pytest.skip(f"{ser.dtype} is not numeric dtype") + + if op_name in ["count", "kurt", "sem"]: + pytest.skip(f"{op_name} not an array method") + + if not self._supports_reduction(ser, op_name): + pytest.skip(f"Reduction {op_name} not supported for this dtype") + + self.check_reduce_frame(ser, op_name, skipna) + + +# TODO(3.0): remove BaseNoReduceTests, BaseNumericReduceTests, +# BaseBooleanReduceTests +class BaseNoReduceTests(BaseReduceTests): + """we don't define any reductions""" + + +class BaseNumericReduceTests(BaseReduceTests): + # For backward compatibility only, this only runs the numeric reductions + def _supports_reduction(self, ser: pd.Series, op_name: str) -> bool: + if op_name in ["any", "all"]: + pytest.skip("These are tested in BaseBooleanReduceTests") + return True + + +class BaseBooleanReduceTests(BaseReduceTests): + # For backward compatibility only, this only runs the numeric reductions + def _supports_reduction(self, ser: pd.Series, op_name: str) -> bool: + if op_name not in ["any", "all"]: + pytest.skip("These are tested in BaseNumericReduceTests") + return True diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/reshaping.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/reshaping.py new file mode 100644 index 0000000000000000000000000000000000000000..4550e3b055cfeaea60f9d0b44c97e099e8e4d47c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/reshaping.py @@ -0,0 +1,379 @@ +import itertools + +import numpy as np +import pytest + +import pandas as pd +import pandas._testing as tm +from pandas.api.extensions import ExtensionArray +from pandas.core.internals.blocks import EABackedBlock + + +class BaseReshapingTests: + """Tests for reshaping and concatenation.""" + + @pytest.mark.parametrize("in_frame", [True, False]) + def test_concat(self, data, in_frame): + wrapped = pd.Series(data) + if in_frame: + wrapped = pd.DataFrame(wrapped) + result = pd.concat([wrapped, wrapped], ignore_index=True) + + assert len(result) == len(data) * 2 + + if in_frame: + dtype = result.dtypes[0] + else: + dtype = result.dtype + + assert dtype == data.dtype + if hasattr(result._mgr, "blocks"): + assert isinstance(result._mgr.blocks[0], EABackedBlock) + assert isinstance(result._mgr.arrays[0], ExtensionArray) + + @pytest.mark.parametrize("in_frame", [True, False]) + def test_concat_all_na_block(self, data_missing, in_frame): + valid_block = pd.Series(data_missing.take([1, 1]), index=[0, 1]) + na_block = pd.Series(data_missing.take([0, 0]), index=[2, 3]) + if in_frame: + valid_block = pd.DataFrame({"a": valid_block}) + na_block = pd.DataFrame({"a": na_block}) + result = pd.concat([valid_block, na_block]) + if in_frame: + expected = pd.DataFrame({"a": data_missing.take([1, 1, 0, 0])}) + tm.assert_frame_equal(result, expected) + else: + expected = pd.Series(data_missing.take([1, 1, 0, 0])) + tm.assert_series_equal(result, expected) + + def test_concat_mixed_dtypes(self, data): + # https://github.com/pandas-dev/pandas/issues/20762 + df1 = pd.DataFrame({"A": data[:3]}) + df2 = pd.DataFrame({"A": [1, 2, 3]}) + df3 = pd.DataFrame({"A": ["a", "b", "c"]}).astype("category") + dfs = [df1, df2, df3] + + # dataframes + result = pd.concat(dfs) + expected = pd.concat([x.astype(object) for x in dfs]) + tm.assert_frame_equal(result, expected) + + # series + result = pd.concat([x["A"] for x in dfs]) + expected = pd.concat([x["A"].astype(object) for x in dfs]) + tm.assert_series_equal(result, expected) + + # simple test for just EA and one other + result = pd.concat([df1, df2.astype(object)]) + expected = pd.concat([df1.astype("object"), df2.astype("object")]) + tm.assert_frame_equal(result, expected) + + result = pd.concat([df1["A"], df2["A"].astype(object)]) + expected = pd.concat([df1["A"].astype("object"), df2["A"].astype("object")]) + tm.assert_series_equal(result, expected) + + def test_concat_columns(self, data, na_value): + df1 = pd.DataFrame({"A": data[:3]}) + df2 = pd.DataFrame({"B": [1, 2, 3]}) + + expected = pd.DataFrame({"A": data[:3], "B": [1, 2, 3]}) + result = pd.concat([df1, df2], axis=1) + tm.assert_frame_equal(result, expected) + result = pd.concat([df1["A"], df2["B"]], axis=1) + tm.assert_frame_equal(result, expected) + + # non-aligned + df2 = pd.DataFrame({"B": [1, 2, 3]}, index=[1, 2, 3]) + expected = pd.DataFrame( + { + "A": data._from_sequence(list(data[:3]) + [na_value], dtype=data.dtype), + "B": [np.nan, 1, 2, 3], + } + ) + + result = pd.concat([df1, df2], axis=1) + tm.assert_frame_equal(result, expected) + result = pd.concat([df1["A"], df2["B"]], axis=1) + tm.assert_frame_equal(result, expected) + + def test_concat_extension_arrays_copy_false(self, data, na_value): + # GH 20756 + df1 = pd.DataFrame({"A": data[:3]}) + df2 = pd.DataFrame({"B": data[3:7]}) + expected = pd.DataFrame( + { + "A": data._from_sequence(list(data[:3]) + [na_value], dtype=data.dtype), + "B": data[3:7], + } + ) + result = pd.concat([df1, df2], axis=1, copy=False) + tm.assert_frame_equal(result, expected) + + def test_concat_with_reindex(self, data): + # GH-33027 + a = pd.DataFrame({"a": data[:5]}) + b = pd.DataFrame({"b": data[:5]}) + result = pd.concat([a, b], ignore_index=True) + expected = pd.DataFrame( + { + "a": data.take(list(range(5)) + ([-1] * 5), allow_fill=True), + "b": data.take(([-1] * 5) + list(range(5)), allow_fill=True), + } + ) + tm.assert_frame_equal(result, expected) + + def test_align(self, data, na_value): + a = data[:3] + b = data[2:5] + r1, r2 = pd.Series(a).align(pd.Series(b, index=[1, 2, 3])) + + # Assumes that the ctor can take a list of scalars of the type + e1 = pd.Series(data._from_sequence(list(a) + [na_value], dtype=data.dtype)) + e2 = pd.Series(data._from_sequence([na_value] + list(b), dtype=data.dtype)) + tm.assert_series_equal(r1, e1) + tm.assert_series_equal(r2, e2) + + def test_align_frame(self, data, na_value): + a = data[:3] + b = data[2:5] + r1, r2 = pd.DataFrame({"A": a}).align(pd.DataFrame({"A": b}, index=[1, 2, 3])) + + # Assumes that the ctor can take a list of scalars of the type + e1 = pd.DataFrame( + {"A": data._from_sequence(list(a) + [na_value], dtype=data.dtype)} + ) + e2 = pd.DataFrame( + {"A": data._from_sequence([na_value] + list(b), dtype=data.dtype)} + ) + tm.assert_frame_equal(r1, e1) + tm.assert_frame_equal(r2, e2) + + def test_align_series_frame(self, data, na_value): + # https://github.com/pandas-dev/pandas/issues/20576 + ser = pd.Series(data, name="a") + df = pd.DataFrame({"col": np.arange(len(ser) + 1)}) + r1, r2 = ser.align(df) + + e1 = pd.Series( + data._from_sequence(list(data) + [na_value], dtype=data.dtype), + name=ser.name, + ) + + tm.assert_series_equal(r1, e1) + tm.assert_frame_equal(r2, df) + + def test_set_frame_expand_regular_with_extension(self, data): + df = pd.DataFrame({"A": [1] * len(data)}) + df["B"] = data + expected = pd.DataFrame({"A": [1] * len(data), "B": data}) + tm.assert_frame_equal(df, expected) + + def test_set_frame_expand_extension_with_regular(self, data): + df = pd.DataFrame({"A": data}) + df["B"] = [1] * len(data) + expected = pd.DataFrame({"A": data, "B": [1] * len(data)}) + tm.assert_frame_equal(df, expected) + + def test_set_frame_overwrite_object(self, data): + # https://github.com/pandas-dev/pandas/issues/20555 + df = pd.DataFrame({"A": [1] * len(data)}, dtype=object) + df["A"] = data + assert df.dtypes["A"] == data.dtype + + def test_merge(self, data, na_value): + # GH-20743 + df1 = pd.DataFrame({"ext": data[:3], "int1": [1, 2, 3], "key": [0, 1, 2]}) + df2 = pd.DataFrame({"int2": [1, 2, 3, 4], "key": [0, 0, 1, 3]}) + + res = pd.merge(df1, df2) + exp = pd.DataFrame( + { + "int1": [1, 1, 2], + "int2": [1, 2, 3], + "key": [0, 0, 1], + "ext": data._from_sequence( + [data[0], data[0], data[1]], dtype=data.dtype + ), + } + ) + tm.assert_frame_equal(res, exp[["ext", "int1", "key", "int2"]]) + + res = pd.merge(df1, df2, how="outer") + exp = pd.DataFrame( + { + "int1": [1, 1, 2, 3, np.nan], + "int2": [1, 2, 3, np.nan, 4], + "key": [0, 0, 1, 2, 3], + "ext": data._from_sequence( + [data[0], data[0], data[1], data[2], na_value], dtype=data.dtype + ), + } + ) + tm.assert_frame_equal(res, exp[["ext", "int1", "key", "int2"]]) + + def test_merge_on_extension_array(self, data): + # GH 23020 + a, b = data[:2] + key = type(data)._from_sequence([a, b], dtype=data.dtype) + + df = pd.DataFrame({"key": key, "val": [1, 2]}) + result = pd.merge(df, df, on="key") + expected = pd.DataFrame({"key": key, "val_x": [1, 2], "val_y": [1, 2]}) + tm.assert_frame_equal(result, expected) + + # order + result = pd.merge(df.iloc[[1, 0]], df, on="key") + expected = expected.iloc[[1, 0]].reset_index(drop=True) + tm.assert_frame_equal(result, expected) + + def test_merge_on_extension_array_duplicates(self, data): + # GH 23020 + a, b = data[:2] + key = type(data)._from_sequence([a, b, a], dtype=data.dtype) + df1 = pd.DataFrame({"key": key, "val": [1, 2, 3]}) + df2 = pd.DataFrame({"key": key, "val": [1, 2, 3]}) + + result = pd.merge(df1, df2, on="key") + expected = pd.DataFrame( + { + "key": key.take([0, 0, 1, 2, 2]), + "val_x": [1, 1, 2, 3, 3], + "val_y": [1, 3, 2, 1, 3], + } + ) + tm.assert_frame_equal(result, expected) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + @pytest.mark.parametrize( + "columns", + [ + ["A", "B"], + pd.MultiIndex.from_tuples( + [("A", "a"), ("A", "b")], names=["outer", "inner"] + ), + ], + ) + @pytest.mark.parametrize("future_stack", [True, False]) + def test_stack(self, data, columns, future_stack): + df = pd.DataFrame({"A": data[:5], "B": data[:5]}) + df.columns = columns + result = df.stack(future_stack=future_stack) + expected = df.astype(object).stack(future_stack=future_stack) + # we need a second astype(object), in case the constructor inferred + # object -> specialized, as is done for period. + expected = expected.astype(object) + + if isinstance(expected, pd.Series): + assert result.dtype == df.iloc[:, 0].dtype + else: + assert all(result.dtypes == df.iloc[:, 0].dtype) + + result = result.astype(object) + tm.assert_equal(result, expected) + + @pytest.mark.parametrize( + "index", + [ + # Two levels, uniform. + pd.MultiIndex.from_product(([["A", "B"], ["a", "b"]]), names=["a", "b"]), + # non-uniform + pd.MultiIndex.from_tuples([("A", "a"), ("A", "b"), ("B", "b")]), + # three levels, non-uniform + pd.MultiIndex.from_product([("A", "B"), ("a", "b", "c"), (0, 1, 2)]), + pd.MultiIndex.from_tuples( + [ + ("A", "a", 1), + ("A", "b", 0), + ("A", "a", 0), + ("B", "a", 0), + ("B", "c", 1), + ] + ), + ], + ) + @pytest.mark.parametrize("obj", ["series", "frame"]) + def test_unstack(self, data, index, obj): + data = data[: len(index)] + if obj == "series": + ser = pd.Series(data, index=index) + else: + ser = pd.DataFrame({"A": data, "B": data}, index=index) + + n = index.nlevels + levels = list(range(n)) + # [0, 1, 2] + # [(0,), (1,), (2,), (0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)] + combinations = itertools.chain.from_iterable( + itertools.permutations(levels, i) for i in range(1, n) + ) + + for level in combinations: + result = ser.unstack(level=level) + assert all( + isinstance(result[col].array, type(data)) for col in result.columns + ) + + if obj == "series": + # We should get the same result with to_frame+unstack+droplevel + df = ser.to_frame() + + alt = df.unstack(level=level).droplevel(0, axis=1) + tm.assert_frame_equal(result, alt) + + obj_ser = ser.astype(object) + + expected = obj_ser.unstack(level=level, fill_value=data.dtype.na_value) + if obj == "series": + assert (expected.dtypes == object).all() + + result = result.astype(object) + tm.assert_frame_equal(result, expected) + + def test_ravel(self, data): + # as long as EA is 1D-only, ravel is a no-op + result = data.ravel() + assert type(result) == type(data) + + if data.dtype._is_immutable: + pytest.skip(f"test_ravel assumes mutability and {data.dtype} is immutable") + + # Check that we have a view, not a copy + result[0] = result[1] + assert data[0] == data[1] + + def test_transpose(self, data): + result = data.transpose() + assert type(result) == type(data) + + # check we get a new object + assert result is not data + + # If we ever _did_ support 2D, shape should be reversed + assert result.shape == data.shape[::-1] + + if data.dtype._is_immutable: + pytest.skip( + f"test_transpose assumes mutability and {data.dtype} is immutable" + ) + + # Check that we have a view, not a copy + result[0] = result[1] + assert data[0] == data[1] + + def test_transpose_frame(self, data): + df = pd.DataFrame({"A": data[:4], "B": data[:4]}, index=["a", "b", "c", "d"]) + result = df.T + expected = pd.DataFrame( + { + "a": type(data)._from_sequence([data[0]] * 2, dtype=data.dtype), + "b": type(data)._from_sequence([data[1]] * 2, dtype=data.dtype), + "c": type(data)._from_sequence([data[2]] * 2, dtype=data.dtype), + "d": type(data)._from_sequence([data[3]] * 2, dtype=data.dtype), + }, + index=["A", "B"], + ) + tm.assert_frame_equal(result, expected) + tm.assert_frame_equal(np.transpose(np.transpose(df)), df) + tm.assert_frame_equal(np.transpose(np.transpose(df[["A"]])), df[["A"]]) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/setitem.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/setitem.py new file mode 100644 index 0000000000000000000000000000000000000000..ca19845041e231f141d480ad57f668e4d6fcd5fc --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/base/setitem.py @@ -0,0 +1,451 @@ +import numpy as np +import pytest + +import pandas as pd +import pandas._testing as tm + + +class BaseSetitemTests: + @pytest.fixture( + params=[ + lambda x: x.index, + lambda x: list(x.index), + lambda x: slice(None), + lambda x: slice(0, len(x)), + lambda x: range(len(x)), + lambda x: list(range(len(x))), + lambda x: np.ones(len(x), dtype=bool), + ], + ids=[ + "index", + "list[index]", + "null_slice", + "full_slice", + "range", + "list(range)", + "mask", + ], + ) + def full_indexer(self, request): + """ + Fixture for an indexer to pass to obj.loc to get/set the full length of the + object. + + In some cases, assumes that obj.index is the default RangeIndex. + """ + return request.param + + @pytest.fixture(autouse=True) + def skip_if_immutable(self, dtype, request): + if dtype._is_immutable: + node = request.node + if node.name.split("[")[0] == "test_is_immutable": + # This fixture is auto-used, but we want to not-skip + # test_is_immutable. + return + + # When BaseSetitemTests is mixed into ExtensionTests, we only + # want this fixture to operate on the tests defined in this + # class/file. + defined_in = node.function.__qualname__.split(".")[0] + if defined_in == "BaseSetitemTests": + pytest.skip("__setitem__ test not applicable with immutable dtype") + + def test_is_immutable(self, data): + if data.dtype._is_immutable: + with pytest.raises(TypeError): + data[0] = data[0] + else: + data[0] = data[1] + assert data[0] == data[1] + + def test_setitem_scalar_series(self, data, box_in_series): + if box_in_series: + data = pd.Series(data) + data[0] = data[1] + assert data[0] == data[1] + + def test_setitem_sequence(self, data, box_in_series): + if box_in_series: + data = pd.Series(data) + original = data.copy() + + data[[0, 1]] = [data[1], data[0]] + assert data[0] == original[1] + assert data[1] == original[0] + + def test_setitem_sequence_mismatched_length_raises(self, data, as_array): + ser = pd.Series(data) + original = ser.copy() + value = [data[0]] + if as_array: + value = data._from_sequence(value, dtype=data.dtype) + + xpr = "cannot set using a {} indexer with a different length" + with pytest.raises(ValueError, match=xpr.format("list-like")): + ser[[0, 1]] = value + # Ensure no modifications made before the exception + tm.assert_series_equal(ser, original) + + with pytest.raises(ValueError, match=xpr.format("slice")): + ser[slice(3)] = value + tm.assert_series_equal(ser, original) + + def test_setitem_empty_indexer(self, data, box_in_series): + if box_in_series: + data = pd.Series(data) + original = data.copy() + data[np.array([], dtype=int)] = [] + tm.assert_equal(data, original) + + def test_setitem_sequence_broadcasts(self, data, box_in_series): + if box_in_series: + data = pd.Series(data) + data[[0, 1]] = data[2] + assert data[0] == data[2] + assert data[1] == data[2] + + @pytest.mark.parametrize("setter", ["loc", "iloc"]) + def test_setitem_scalar(self, data, setter): + arr = pd.Series(data) + setter = getattr(arr, setter) + setter[0] = data[1] + assert arr[0] == data[1] + + def test_setitem_loc_scalar_mixed(self, data): + df = pd.DataFrame({"A": np.arange(len(data)), "B": data}) + df.loc[0, "B"] = data[1] + assert df.loc[0, "B"] == data[1] + + def test_setitem_loc_scalar_single(self, data): + df = pd.DataFrame({"B": data}) + df.loc[10, "B"] = data[1] + assert df.loc[10, "B"] == data[1] + + def test_setitem_loc_scalar_multiple_homogoneous(self, data): + df = pd.DataFrame({"A": data, "B": data}) + df.loc[10, "B"] = data[1] + assert df.loc[10, "B"] == data[1] + + def test_setitem_iloc_scalar_mixed(self, data): + df = pd.DataFrame({"A": np.arange(len(data)), "B": data}) + df.iloc[0, 1] = data[1] + assert df.loc[0, "B"] == data[1] + + def test_setitem_iloc_scalar_single(self, data): + df = pd.DataFrame({"B": data}) + df.iloc[10, 0] = data[1] + assert df.loc[10, "B"] == data[1] + + def test_setitem_iloc_scalar_multiple_homogoneous(self, data): + df = pd.DataFrame({"A": data, "B": data}) + df.iloc[10, 1] = data[1] + assert df.loc[10, "B"] == data[1] + + @pytest.mark.parametrize( + "mask", + [ + np.array([True, True, True, False, False]), + pd.array([True, True, True, False, False], dtype="boolean"), + pd.array([True, True, True, pd.NA, pd.NA], dtype="boolean"), + ], + ids=["numpy-array", "boolean-array", "boolean-array-na"], + ) + def test_setitem_mask(self, data, mask, box_in_series): + arr = data[:5].copy() + expected = arr.take([0, 0, 0, 3, 4]) + if box_in_series: + arr = pd.Series(arr) + expected = pd.Series(expected) + arr[mask] = data[0] + tm.assert_equal(expected, arr) + + def test_setitem_mask_raises(self, data, box_in_series): + # wrong length + mask = np.array([True, False]) + + if box_in_series: + data = pd.Series(data) + + with pytest.raises(IndexError, match="wrong length"): + data[mask] = data[0] + + mask = pd.array(mask, dtype="boolean") + with pytest.raises(IndexError, match="wrong length"): + data[mask] = data[0] + + def test_setitem_mask_boolean_array_with_na(self, data, box_in_series): + mask = pd.array(np.zeros(data.shape, dtype="bool"), dtype="boolean") + mask[:3] = True + mask[3:5] = pd.NA + + if box_in_series: + data = pd.Series(data) + + data[mask] = data[0] + + assert (data[:3] == data[0]).all() + + @pytest.mark.parametrize( + "idx", + [[0, 1, 2], pd.array([0, 1, 2], dtype="Int64"), np.array([0, 1, 2])], + ids=["list", "integer-array", "numpy-array"], + ) + def test_setitem_integer_array(self, data, idx, box_in_series): + arr = data[:5].copy() + expected = data.take([0, 0, 0, 3, 4]) + + if box_in_series: + arr = pd.Series(arr) + expected = pd.Series(expected) + + arr[idx] = arr[0] + tm.assert_equal(arr, expected) + + @pytest.mark.parametrize( + "idx, box_in_series", + [ + ([0, 1, 2, pd.NA], False), + pytest.param( + [0, 1, 2, pd.NA], True, marks=pytest.mark.xfail(reason="GH-31948") + ), + (pd.array([0, 1, 2, pd.NA], dtype="Int64"), False), + (pd.array([0, 1, 2, pd.NA], dtype="Int64"), False), + ], + ids=["list-False", "list-True", "integer-array-False", "integer-array-True"], + ) + def test_setitem_integer_with_missing_raises(self, data, idx, box_in_series): + arr = data.copy() + + # TODO(xfail) this raises KeyError about labels not found (it tries label-based) + # for list of labels with Series + if box_in_series: + arr = pd.Series(data, index=[chr(100 + i) for i in range(len(data))]) + + msg = "Cannot index with an integer indexer containing NA values" + with pytest.raises(ValueError, match=msg): + arr[idx] = arr[0] + + @pytest.mark.parametrize("as_callable", [True, False]) + @pytest.mark.parametrize("setter", ["loc", None]) + def test_setitem_mask_aligned(self, data, as_callable, setter): + ser = pd.Series(data) + mask = np.zeros(len(data), dtype=bool) + mask[:2] = True + + if as_callable: + mask2 = lambda x: mask + else: + mask2 = mask + + if setter: + # loc + target = getattr(ser, setter) + else: + # Series.__setitem__ + target = ser + + target[mask2] = data[5:7] + + ser[mask2] = data[5:7] + assert ser[0] == data[5] + assert ser[1] == data[6] + + @pytest.mark.parametrize("setter", ["loc", None]) + def test_setitem_mask_broadcast(self, data, setter): + ser = pd.Series(data) + mask = np.zeros(len(data), dtype=bool) + mask[:2] = True + + if setter: # loc + target = getattr(ser, setter) + else: # __setitem__ + target = ser + + target[mask] = data[10] + assert ser[0] == data[10] + assert ser[1] == data[10] + + def test_setitem_expand_columns(self, data): + df = pd.DataFrame({"A": data}) + result = df.copy() + result["B"] = 1 + expected = pd.DataFrame({"A": data, "B": [1] * len(data)}) + tm.assert_frame_equal(result, expected) + + result = df.copy() + result.loc[:, "B"] = 1 + tm.assert_frame_equal(result, expected) + + # overwrite with new type + result["B"] = data + expected = pd.DataFrame({"A": data, "B": data}) + tm.assert_frame_equal(result, expected) + + def test_setitem_expand_with_extension(self, data): + df = pd.DataFrame({"A": [1] * len(data)}) + result = df.copy() + result["B"] = data + expected = pd.DataFrame({"A": [1] * len(data), "B": data}) + tm.assert_frame_equal(result, expected) + + result = df.copy() + result.loc[:, "B"] = data + tm.assert_frame_equal(result, expected) + + def test_setitem_frame_invalid_length(self, data): + df = pd.DataFrame({"A": [1] * len(data)}) + xpr = ( + rf"Length of values \({len(data[:5])}\) " + rf"does not match length of index \({len(df)}\)" + ) + with pytest.raises(ValueError, match=xpr): + df["B"] = data[:5] + + def test_setitem_tuple_index(self, data): + ser = pd.Series(data[:2], index=[(0, 0), (0, 1)]) + expected = pd.Series(data.take([1, 1]), index=ser.index) + ser[(0, 0)] = data[1] + tm.assert_series_equal(ser, expected) + + def test_setitem_slice(self, data, box_in_series): + arr = data[:5].copy() + expected = data.take([0, 0, 0, 3, 4]) + if box_in_series: + arr = pd.Series(arr) + expected = pd.Series(expected) + + arr[:3] = data[0] + tm.assert_equal(arr, expected) + + def test_setitem_loc_iloc_slice(self, data): + arr = data[:5].copy() + s = pd.Series(arr, index=["a", "b", "c", "d", "e"]) + expected = pd.Series(data.take([0, 0, 0, 3, 4]), index=s.index) + + result = s.copy() + result.iloc[:3] = data[0] + tm.assert_equal(result, expected) + + result = s.copy() + result.loc[:"c"] = data[0] + tm.assert_equal(result, expected) + + def test_setitem_slice_mismatch_length_raises(self, data): + arr = data[:5] + with pytest.raises(ValueError): + arr[:1] = arr[:2] + + def test_setitem_slice_array(self, data): + arr = data[:5].copy() + arr[:5] = data[-5:] + tm.assert_extension_array_equal(arr, data[-5:]) + + def test_setitem_scalar_key_sequence_raise(self, data): + arr = data[:5].copy() + with pytest.raises(ValueError): + arr[0] = arr[[0, 1]] + + def test_setitem_preserves_views(self, data): + # GH#28150 setitem shouldn't swap the underlying data + view1 = data.view() + view2 = data[:] + + data[0] = data[1] + assert view1[0] == data[1] + assert view2[0] == data[1] + + def test_setitem_with_expansion_dataframe_column(self, data, full_indexer): + # https://github.com/pandas-dev/pandas/issues/32395 + df = expected = pd.DataFrame({0: pd.Series(data)}) + result = pd.DataFrame(index=df.index) + + key = full_indexer(df) + result.loc[key, 0] = df[0] + + tm.assert_frame_equal(result, expected) + + def test_setitem_with_expansion_row(self, data, na_value): + df = pd.DataFrame({"data": data[:1]}) + + df.loc[1, "data"] = data[1] + expected = pd.DataFrame({"data": data[:2]}) + tm.assert_frame_equal(df, expected) + + # https://github.com/pandas-dev/pandas/issues/47284 + df.loc[2, "data"] = na_value + expected = pd.DataFrame( + {"data": pd.Series([data[0], data[1], na_value], dtype=data.dtype)} + ) + tm.assert_frame_equal(df, expected) + + def test_setitem_series(self, data, full_indexer): + # https://github.com/pandas-dev/pandas/issues/32395 + ser = pd.Series(data, name="data") + result = pd.Series(index=ser.index, dtype=object, name="data") + + # because result has object dtype, the attempt to do setting inplace + # is successful, and object dtype is retained + key = full_indexer(ser) + result.loc[key] = ser + + expected = pd.Series( + data.astype(object), index=ser.index, name="data", dtype=object + ) + tm.assert_series_equal(result, expected) + + def test_setitem_frame_2d_values(self, data): + # GH#44514 + df = pd.DataFrame({"A": data}) + + # Avoiding using_array_manager fixture + # https://github.com/pandas-dev/pandas/pull/44514#discussion_r754002410 + using_array_manager = isinstance(df._mgr, pd.core.internals.ArrayManager) + using_copy_on_write = pd.options.mode.copy_on_write + + blk_data = df._mgr.arrays[0] + + orig = df.copy() + + df.iloc[:] = df.copy() + tm.assert_frame_equal(df, orig) + + df.iloc[:-1] = df.iloc[:-1].copy() + tm.assert_frame_equal(df, orig) + + df.iloc[:] = df.values + tm.assert_frame_equal(df, orig) + if not using_array_manager and not using_copy_on_write: + # GH#33457 Check that this setting occurred in-place + # FIXME(ArrayManager): this should work there too + assert df._mgr.arrays[0] is blk_data + + df.iloc[:-1] = df.values[:-1] + tm.assert_frame_equal(df, orig) + + def test_delitem_series(self, data): + # GH#40763 + ser = pd.Series(data, name="data") + + taker = np.arange(len(ser)) + taker = np.delete(taker, 1) + + expected = ser[taker] + del ser[1] + tm.assert_series_equal(ser, expected) + + def test_setitem_invalid(self, data, invalid_scalar): + msg = "" # messages vary by subclass, so we do not test it + with pytest.raises((ValueError, TypeError), match=msg): + data[0] = invalid_scalar + + with pytest.raises((ValueError, TypeError), match=msg): + data[:] = invalid_scalar + + def test_setitem_2d_values(self, data): + # GH50085 + original = data.copy() + df = pd.DataFrame({"a": data, "b": data}) + df.loc[[0, 1], :] = df.loc[[1, 0], :].values + assert (df.loc[0, :] == original[1]).all() + assert (df.loc[1, :] == original[0]).all() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/date/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/date/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2a8c7e9f57a5da982530b8db854edd37baf13b6b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/date/__init__.py @@ -0,0 +1,6 @@ +from pandas.tests.extension.date.array import ( + DateArray, + DateDtype, +) + +__all__ = ["DateArray", "DateDtype"] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/date/array.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/date/array.py new file mode 100644 index 0000000000000000000000000000000000000000..2306f5974ba186587dedb1159d64374601f55c86 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/date/array.py @@ -0,0 +1,188 @@ +from __future__ import annotations + +import datetime as dt +from typing import ( + TYPE_CHECKING, + Any, + cast, +) + +import numpy as np + +from pandas.core.dtypes.dtypes import register_extension_dtype + +from pandas.api.extensions import ( + ExtensionArray, + ExtensionDtype, +) +from pandas.api.types import pandas_dtype + +if TYPE_CHECKING: + from collections.abc import Sequence + + from pandas._typing import ( + Dtype, + PositionalIndexer, + ) + + +@register_extension_dtype +class DateDtype(ExtensionDtype): + @property + def type(self): + return dt.date + + @property + def name(self): + return "DateDtype" + + @classmethod + def construct_from_string(cls, string: str): + if not isinstance(string, str): + raise TypeError( + f"'construct_from_string' expects a string, got {type(string)}" + ) + + if string == cls.__name__: + return cls() + else: + raise TypeError(f"Cannot construct a '{cls.__name__}' from '{string}'") + + @classmethod + def construct_array_type(cls): + return DateArray + + @property + def na_value(self): + return dt.date.min + + def __repr__(self) -> str: + return self.name + + +class DateArray(ExtensionArray): + def __init__( + self, + dates: ( + dt.date + | Sequence[dt.date] + | tuple[np.ndarray, np.ndarray, np.ndarray] + | np.ndarray + ), + ) -> None: + if isinstance(dates, dt.date): + self._year = np.array([dates.year]) + self._month = np.array([dates.month]) + self._day = np.array([dates.year]) + return + + ldates = len(dates) + if isinstance(dates, list): + # pre-allocate the arrays since we know the size before hand + self._year = np.zeros(ldates, dtype=np.uint16) # 65535 (0, 9999) + self._month = np.zeros(ldates, dtype=np.uint8) # 255 (1, 31) + self._day = np.zeros(ldates, dtype=np.uint8) # 255 (1, 12) + # populate them + for i, (y, m, d) in enumerate( + (date.year, date.month, date.day) for date in dates + ): + self._year[i] = y + self._month[i] = m + self._day[i] = d + + elif isinstance(dates, tuple): + # only support triples + if ldates != 3: + raise ValueError("only triples are valid") + # check if all elements have the same type + if any(not isinstance(x, np.ndarray) for x in dates): + raise TypeError("invalid type") + ly, lm, ld = (len(cast(np.ndarray, d)) for d in dates) + if not ly == lm == ld: + raise ValueError( + f"tuple members must have the same length: {(ly, lm, ld)}" + ) + self._year = dates[0].astype(np.uint16) + self._month = dates[1].astype(np.uint8) + self._day = dates[2].astype(np.uint8) + + elif isinstance(dates, np.ndarray) and dates.dtype == "U10": + self._year = np.zeros(ldates, dtype=np.uint16) # 65535 (0, 9999) + self._month = np.zeros(ldates, dtype=np.uint8) # 255 (1, 31) + self._day = np.zeros(ldates, dtype=np.uint8) # 255 (1, 12) + + # error: "object_" object is not iterable + obj = np.char.split(dates, sep="-") + for (i,), (y, m, d) in np.ndenumerate(obj): # type: ignore[misc] + self._year[i] = int(y) + self._month[i] = int(m) + self._day[i] = int(d) + + else: + raise TypeError(f"{type(dates)} is not supported") + + @property + def dtype(self) -> ExtensionDtype: + return DateDtype() + + def astype(self, dtype, copy=True): + dtype = pandas_dtype(dtype) + + if isinstance(dtype, DateDtype): + data = self.copy() if copy else self + else: + data = self.to_numpy(dtype=dtype, copy=copy, na_value=dt.date.min) + + return data + + @property + def nbytes(self) -> int: + return self._year.nbytes + self._month.nbytes + self._day.nbytes + + def __len__(self) -> int: + return len(self._year) # all 3 arrays are enforced to have the same length + + def __getitem__(self, item: PositionalIndexer): + if isinstance(item, int): + return dt.date(self._year[item], self._month[item], self._day[item]) + else: + raise NotImplementedError("only ints are supported as indexes") + + def __setitem__(self, key: int | slice | np.ndarray, value: Any) -> None: + if not isinstance(key, int): + raise NotImplementedError("only ints are supported as indexes") + + if not isinstance(value, dt.date): + raise TypeError("you can only set datetime.date types") + + self._year[key] = value.year + self._month[key] = value.month + self._day[key] = value.day + + def __repr__(self) -> str: + return f"DateArray{list(zip(self._year, self._month, self._day))}" + + def copy(self) -> DateArray: + return DateArray((self._year.copy(), self._month.copy(), self._day.copy())) + + def isna(self) -> np.ndarray: + return np.logical_and( + np.logical_and( + self._year == dt.date.min.year, self._month == dt.date.min.month + ), + self._day == dt.date.min.day, + ) + + @classmethod + def _from_sequence(cls, scalars, *, dtype: Dtype | None = None, copy=False): + if isinstance(scalars, dt.date): + raise TypeError + elif isinstance(scalars, DateArray): + if dtype is not None: + return scalars.astype(dtype, copy=copy) + if copy: + return scalars.copy() + return scalars[:] + elif isinstance(scalars, np.ndarray): + scalars = scalars.astype("U10") # 10 chars for yyyy-mm-dd + return DateArray(scalars) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/decimal/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/decimal/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..34727b43a7b0fb325143dfedee4db25c4b56f5db --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/decimal/__init__.py @@ -0,0 +1,8 @@ +from pandas.tests.extension.decimal.array import ( + DecimalArray, + DecimalDtype, + make_data, + to_decimal, +) + +__all__ = ["DecimalArray", "DecimalDtype", "to_decimal", "make_data"] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/decimal/array.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/decimal/array.py new file mode 100644 index 0000000000000000000000000000000000000000..521c1ff0b96bc12672b64be0fa191e153692f6da --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/decimal/array.py @@ -0,0 +1,311 @@ +from __future__ import annotations + +import decimal +import numbers +import sys +from typing import TYPE_CHECKING + +import numpy as np + +from pandas.core.dtypes.base import ExtensionDtype +from pandas.core.dtypes.common import ( + is_dtype_equal, + is_float, + is_integer, + pandas_dtype, +) + +import pandas as pd +from pandas.api.extensions import ( + no_default, + register_extension_dtype, +) +from pandas.api.types import ( + is_list_like, + is_scalar, +) +from pandas.core import arraylike +from pandas.core.algorithms import value_counts_internal as value_counts +from pandas.core.arraylike import OpsMixin +from pandas.core.arrays import ( + ExtensionArray, + ExtensionScalarOpsMixin, +) +from pandas.core.indexers import check_array_indexer + +if TYPE_CHECKING: + from pandas._typing import type_t + + +@register_extension_dtype +class DecimalDtype(ExtensionDtype): + type = decimal.Decimal + name = "decimal" + na_value = decimal.Decimal("NaN") + _metadata = ("context",) + + def __init__(self, context=None) -> None: + self.context = context or decimal.getcontext() + + def __repr__(self) -> str: + return f"DecimalDtype(context={self.context})" + + @classmethod + def construct_array_type(cls) -> type_t[DecimalArray]: + """ + Return the array type associated with this dtype. + + Returns + ------- + type + """ + return DecimalArray + + @property + def _is_numeric(self) -> bool: + return True + + +class DecimalArray(OpsMixin, ExtensionScalarOpsMixin, ExtensionArray): + __array_priority__ = 1000 + + def __init__(self, values, dtype=None, copy=False, context=None) -> None: + for i, val in enumerate(values): + if is_float(val) or is_integer(val): + if np.isnan(val): + values[i] = DecimalDtype.na_value + else: + # error: Argument 1 has incompatible type "float | int | + # integer[Any]"; expected "Decimal | float | str | tuple[int, + # Sequence[int], int]" + values[i] = DecimalDtype.type(val) # type: ignore[arg-type] + elif not isinstance(val, decimal.Decimal): + raise TypeError("All values must be of type " + str(decimal.Decimal)) + values = np.asarray(values, dtype=object) + + self._data = values + # Some aliases for common attribute names to ensure pandas supports + # these + self._items = self.data = self._data + # those aliases are currently not working due to assumptions + # in internal code (GH-20735) + # self._values = self.values = self.data + self._dtype = DecimalDtype(context) + + @property + def dtype(self): + return self._dtype + + @classmethod + def _from_sequence(cls, scalars, *, dtype=None, copy=False): + return cls(scalars) + + @classmethod + def _from_sequence_of_strings(cls, strings, dtype=None, copy=False): + return cls._from_sequence( + [decimal.Decimal(x) for x in strings], dtype=dtype, copy=copy + ) + + @classmethod + def _from_factorized(cls, values, original): + return cls(values) + + _HANDLED_TYPES = (decimal.Decimal, numbers.Number, np.ndarray) + + def to_numpy( + self, + dtype=None, + copy: bool = False, + na_value: object = no_default, + decimals=None, + ) -> np.ndarray: + result = np.asarray(self, dtype=dtype) + if decimals is not None: + result = np.asarray([round(x, decimals) for x in result]) + return result + + def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs): + # + if not all( + isinstance(t, self._HANDLED_TYPES + (DecimalArray,)) for t in inputs + ): + return NotImplemented + + result = arraylike.maybe_dispatch_ufunc_to_dunder_op( + self, ufunc, method, *inputs, **kwargs + ) + if result is not NotImplemented: + # e.g. test_array_ufunc_series_scalar_other + return result + + if "out" in kwargs: + return arraylike.dispatch_ufunc_with_out( + self, ufunc, method, *inputs, **kwargs + ) + + inputs = tuple(x._data if isinstance(x, DecimalArray) else x for x in inputs) + result = getattr(ufunc, method)(*inputs, **kwargs) + + if method == "reduce": + result = arraylike.dispatch_reduction_ufunc( + self, ufunc, method, *inputs, **kwargs + ) + if result is not NotImplemented: + return result + + def reconstruct(x): + if isinstance(x, (decimal.Decimal, numbers.Number)): + return x + else: + return type(self)._from_sequence(x, dtype=self.dtype) + + if ufunc.nout > 1: + return tuple(reconstruct(x) for x in result) + else: + return reconstruct(result) + + def __getitem__(self, item): + if isinstance(item, numbers.Integral): + return self._data[item] + else: + # array, slice. + item = pd.api.indexers.check_array_indexer(self, item) + return type(self)(self._data[item]) + + def take(self, indexer, allow_fill=False, fill_value=None): + from pandas.api.extensions import take + + data = self._data + if allow_fill and fill_value is None: + fill_value = self.dtype.na_value + + result = take(data, indexer, fill_value=fill_value, allow_fill=allow_fill) + return self._from_sequence(result, dtype=self.dtype) + + def copy(self): + return type(self)(self._data.copy(), dtype=self.dtype) + + def astype(self, dtype, copy=True): + if is_dtype_equal(dtype, self._dtype): + if not copy: + return self + dtype = pandas_dtype(dtype) + if isinstance(dtype, type(self.dtype)): + return type(self)(self._data, copy=copy, context=dtype.context) + + return super().astype(dtype, copy=copy) + + def __setitem__(self, key, value) -> None: + if is_list_like(value): + if is_scalar(key): + raise ValueError("setting an array element with a sequence.") + value = [decimal.Decimal(v) for v in value] + else: + value = decimal.Decimal(value) + + key = check_array_indexer(self, key) + self._data[key] = value + + def __len__(self) -> int: + return len(self._data) + + def __contains__(self, item) -> bool | np.bool_: + if not isinstance(item, decimal.Decimal): + return False + elif item.is_nan(): + return self.isna().any() + else: + return super().__contains__(item) + + @property + def nbytes(self) -> int: + n = len(self) + if n: + return n * sys.getsizeof(self[0]) + return 0 + + def isna(self): + return np.array([x.is_nan() for x in self._data], dtype=bool) + + @property + def _na_value(self): + return decimal.Decimal("NaN") + + def _formatter(self, boxed=False): + if boxed: + return "Decimal: {}".format + return repr + + @classmethod + def _concat_same_type(cls, to_concat): + return cls(np.concatenate([x._data for x in to_concat])) + + def _reduce( + self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs + ): + if skipna and self.isna().any(): + # If we don't have any NAs, we can ignore skipna + other = self[~self.isna()] + result = other._reduce(name, **kwargs) + elif name == "sum" and len(self) == 0: + # GH#29630 avoid returning int 0 or np.bool_(False) on old numpy + result = decimal.Decimal(0) + else: + try: + op = getattr(self.data, name) + except AttributeError as err: + raise NotImplementedError( + f"decimal does not support the {name} operation" + ) from err + result = op(axis=0) + + if keepdims: + return type(self)([result]) + else: + return result + + def _cmp_method(self, other, op): + # For use with OpsMixin + def convert_values(param): + if isinstance(param, ExtensionArray) or is_list_like(param): + ovalues = param + else: + # Assume it's an object + ovalues = [param] * len(self) + return ovalues + + lvalues = self + rvalues = convert_values(other) + + # If the operator is not defined for the underlying objects, + # a TypeError should be raised + res = [op(a, b) for (a, b) in zip(lvalues, rvalues)] + + return np.asarray(res, dtype=bool) + + def value_counts(self, dropna: bool = True): + return value_counts(self.to_numpy(), dropna=dropna) + + # We override fillna here to simulate a 3rd party EA that has done so. This + # lets us test the deprecation telling authors to implement _pad_or_backfill + # Simulate a 3rd-party EA that has not yet updated to include a "copy" + # keyword in its fillna method. + # error: Signature of "fillna" incompatible with supertype "ExtensionArray" + def fillna( # type: ignore[override] + self, + value=None, + method=None, + limit: int | None = None, + ): + return super().fillna(value=value, method=method, limit=limit, copy=True) + + +def to_decimal(values, context=None): + return DecimalArray([decimal.Decimal(x) for x in values], context=context) + + +def make_data(): + return [decimal.Decimal(val) for val in np.random.default_rng(2).random(100)] + + +DecimalArray._add_arithmetic_ops() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/decimal/test_decimal.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/decimal/test_decimal.py new file mode 100644 index 0000000000000000000000000000000000000000..9907e345ada63e54c5e80eb8b2cb11a9417799a4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/decimal/test_decimal.py @@ -0,0 +1,567 @@ +from __future__ import annotations + +import decimal +import operator + +import numpy as np +import pytest + +import pandas as pd +import pandas._testing as tm +from pandas.tests.extension import base +from pandas.tests.extension.decimal.array import ( + DecimalArray, + DecimalDtype, + make_data, + to_decimal, +) + + +@pytest.fixture +def dtype(): + return DecimalDtype() + + +@pytest.fixture +def data(): + return DecimalArray(make_data()) + + +@pytest.fixture +def data_for_twos(): + return DecimalArray([decimal.Decimal(2) for _ in range(100)]) + + +@pytest.fixture +def data_missing(): + return DecimalArray([decimal.Decimal("NaN"), decimal.Decimal(1)]) + + +@pytest.fixture +def data_for_sorting(): + return DecimalArray( + [decimal.Decimal("1"), decimal.Decimal("2"), decimal.Decimal("0")] + ) + + +@pytest.fixture +def data_missing_for_sorting(): + return DecimalArray( + [decimal.Decimal("1"), decimal.Decimal("NaN"), decimal.Decimal("0")] + ) + + +@pytest.fixture +def na_cmp(): + return lambda x, y: x.is_nan() and y.is_nan() + + +@pytest.fixture +def data_for_grouping(): + b = decimal.Decimal("1.0") + a = decimal.Decimal("0.0") + c = decimal.Decimal("2.0") + na = decimal.Decimal("NaN") + return DecimalArray([b, b, na, na, a, a, b, c]) + + +class TestDecimalArray(base.ExtensionTests): + def _get_expected_exception( + self, op_name: str, obj, other + ) -> type[Exception] | None: + return None + + def _supports_reduction(self, ser: pd.Series, op_name: str) -> bool: + return True + + def check_reduce(self, ser: pd.Series, op_name: str, skipna: bool): + if op_name == "count": + return super().check_reduce(ser, op_name, skipna) + else: + result = getattr(ser, op_name)(skipna=skipna) + expected = getattr(np.asarray(ser), op_name)() + tm.assert_almost_equal(result, expected) + + def test_reduce_series_numeric(self, data, all_numeric_reductions, skipna, request): + if all_numeric_reductions in ["kurt", "skew", "sem", "median"]: + mark = pytest.mark.xfail(raises=NotImplementedError) + request.applymarker(mark) + super().test_reduce_series_numeric(data, all_numeric_reductions, skipna) + + def test_reduce_frame(self, data, all_numeric_reductions, skipna, request): + op_name = all_numeric_reductions + if op_name in ["skew", "median"]: + mark = pytest.mark.xfail(raises=NotImplementedError) + request.applymarker(mark) + + return super().test_reduce_frame(data, all_numeric_reductions, skipna) + + def test_compare_scalar(self, data, comparison_op): + ser = pd.Series(data) + self._compare_other(ser, data, comparison_op, 0.5) + + def test_compare_array(self, data, comparison_op): + ser = pd.Series(data) + + alter = np.random.default_rng(2).choice([-1, 0, 1], len(data)) + # Randomly double, halve or keep same value + other = pd.Series(data) * [decimal.Decimal(pow(2.0, i)) for i in alter] + self._compare_other(ser, data, comparison_op, other) + + def test_arith_series_with_array(self, data, all_arithmetic_operators): + op_name = all_arithmetic_operators + ser = pd.Series(data) + + context = decimal.getcontext() + divbyzerotrap = context.traps[decimal.DivisionByZero] + invalidoptrap = context.traps[decimal.InvalidOperation] + context.traps[decimal.DivisionByZero] = 0 + context.traps[decimal.InvalidOperation] = 0 + + # Decimal supports ops with int, but not float + other = pd.Series([int(d * 100) for d in data]) + self.check_opname(ser, op_name, other) + + if "mod" not in op_name: + self.check_opname(ser, op_name, ser * 2) + + self.check_opname(ser, op_name, 0) + self.check_opname(ser, op_name, 5) + context.traps[decimal.DivisionByZero] = divbyzerotrap + context.traps[decimal.InvalidOperation] = invalidoptrap + + def test_fillna_frame(self, data_missing): + msg = "ExtensionArray.fillna added a 'copy' keyword" + with tm.assert_produces_warning( + DeprecationWarning, match=msg, check_stacklevel=False + ): + super().test_fillna_frame(data_missing) + + def test_fillna_limit_pad(self, data_missing): + msg = "ExtensionArray.fillna 'method' keyword is deprecated" + with tm.assert_produces_warning( + DeprecationWarning, + match=msg, + check_stacklevel=False, + raise_on_extra_warnings=False, + ): + super().test_fillna_limit_pad(data_missing) + + msg = "The 'method' keyword in DecimalArray.fillna is deprecated" + with tm.assert_produces_warning( + FutureWarning, + match=msg, + check_stacklevel=False, + raise_on_extra_warnings=False, + ): + super().test_fillna_limit_pad(data_missing) + + @pytest.mark.parametrize( + "limit_area, input_ilocs, expected_ilocs", + [ + ("outside", [1, 0, 0, 0, 1], [1, 0, 0, 0, 1]), + ("outside", [1, 0, 1, 0, 1], [1, 0, 1, 0, 1]), + ("outside", [0, 1, 1, 1, 0], [0, 1, 1, 1, 1]), + ("outside", [0, 1, 0, 1, 0], [0, 1, 0, 1, 1]), + ("inside", [1, 0, 0, 0, 1], [1, 1, 1, 1, 1]), + ("inside", [1, 0, 1, 0, 1], [1, 1, 1, 1, 1]), + ("inside", [0, 1, 1, 1, 0], [0, 1, 1, 1, 0]), + ("inside", [0, 1, 0, 1, 0], [0, 1, 1, 1, 0]), + ], + ) + def test_ffill_limit_area( + self, data_missing, limit_area, input_ilocs, expected_ilocs + ): + # GH#56616 + msg = "ExtensionArray.fillna 'method' keyword is deprecated" + with tm.assert_produces_warning( + DeprecationWarning, + match=msg, + check_stacklevel=False, + raise_on_extra_warnings=False, + ): + msg = "DecimalArray does not implement limit_area" + with pytest.raises(NotImplementedError, match=msg): + super().test_ffill_limit_area( + data_missing, limit_area, input_ilocs, expected_ilocs + ) + + def test_fillna_limit_backfill(self, data_missing): + msg = "Series.fillna with 'method' is deprecated" + with tm.assert_produces_warning( + FutureWarning, + match=msg, + check_stacklevel=False, + raise_on_extra_warnings=False, + ): + super().test_fillna_limit_backfill(data_missing) + + msg = "ExtensionArray.fillna 'method' keyword is deprecated" + with tm.assert_produces_warning( + DeprecationWarning, + match=msg, + check_stacklevel=False, + raise_on_extra_warnings=False, + ): + super().test_fillna_limit_backfill(data_missing) + + msg = "The 'method' keyword in DecimalArray.fillna is deprecated" + with tm.assert_produces_warning( + FutureWarning, + match=msg, + check_stacklevel=False, + raise_on_extra_warnings=False, + ): + super().test_fillna_limit_backfill(data_missing) + + def test_fillna_no_op_returns_copy(self, data): + msg = "|".join( + [ + "ExtensionArray.fillna 'method' keyword is deprecated", + "The 'method' keyword in DecimalArray.fillna is deprecated", + ] + ) + with tm.assert_produces_warning( + (FutureWarning, DeprecationWarning), match=msg, check_stacklevel=False + ): + super().test_fillna_no_op_returns_copy(data) + + def test_fillna_series(self, data_missing): + msg = "ExtensionArray.fillna added a 'copy' keyword" + with tm.assert_produces_warning( + DeprecationWarning, match=msg, check_stacklevel=False + ): + super().test_fillna_series(data_missing) + + def test_fillna_series_method(self, data_missing, fillna_method): + msg = "|".join( + [ + "ExtensionArray.fillna 'method' keyword is deprecated", + "The 'method' keyword in DecimalArray.fillna is deprecated", + ] + ) + with tm.assert_produces_warning( + (FutureWarning, DeprecationWarning), match=msg, check_stacklevel=False + ): + super().test_fillna_series_method(data_missing, fillna_method) + + def test_fillna_copy_frame(self, data_missing, using_copy_on_write): + warn = DeprecationWarning if not using_copy_on_write else None + msg = "ExtensionArray.fillna added a 'copy' keyword" + with tm.assert_produces_warning(warn, match=msg, check_stacklevel=False): + super().test_fillna_copy_frame(data_missing) + + def test_fillna_copy_series(self, data_missing, using_copy_on_write): + warn = DeprecationWarning if not using_copy_on_write else None + msg = "ExtensionArray.fillna added a 'copy' keyword" + with tm.assert_produces_warning(warn, match=msg, check_stacklevel=False): + super().test_fillna_copy_series(data_missing) + + @pytest.mark.parametrize("dropna", [True, False]) + def test_value_counts(self, all_data, dropna, request): + all_data = all_data[:10] + if dropna: + other = np.array(all_data[~all_data.isna()]) + else: + other = all_data + + vcs = pd.Series(all_data).value_counts(dropna=dropna) + vcs_ex = pd.Series(other).value_counts(dropna=dropna) + + with decimal.localcontext() as ctx: + # avoid raising when comparing Decimal("NAN") < Decimal(2) + ctx.traps[decimal.InvalidOperation] = False + + result = vcs.sort_index() + expected = vcs_ex.sort_index() + + tm.assert_series_equal(result, expected) + + def test_series_repr(self, data): + # Overriding this base test to explicitly test that + # the custom _formatter is used + ser = pd.Series(data) + assert data.dtype.name in repr(ser) + assert "Decimal: " in repr(ser) + + @pytest.mark.xfail(reason="Inconsistent array-vs-scalar behavior") + @pytest.mark.parametrize("ufunc", [np.positive, np.negative, np.abs]) + def test_unary_ufunc_dunder_equivalence(self, data, ufunc): + super().test_unary_ufunc_dunder_equivalence(data, ufunc) + + +def test_take_na_value_other_decimal(): + arr = DecimalArray([decimal.Decimal("1.0"), decimal.Decimal("2.0")]) + result = arr.take([0, -1], allow_fill=True, fill_value=decimal.Decimal("-1.0")) + expected = DecimalArray([decimal.Decimal("1.0"), decimal.Decimal("-1.0")]) + tm.assert_extension_array_equal(result, expected) + + +def test_series_constructor_coerce_data_to_extension_dtype(): + dtype = DecimalDtype() + ser = pd.Series([0, 1, 2], dtype=dtype) + + arr = DecimalArray( + [decimal.Decimal(0), decimal.Decimal(1), decimal.Decimal(2)], + dtype=dtype, + ) + exp = pd.Series(arr) + tm.assert_series_equal(ser, exp) + + +def test_series_constructor_with_dtype(): + arr = DecimalArray([decimal.Decimal("10.0")]) + result = pd.Series(arr, dtype=DecimalDtype()) + expected = pd.Series(arr) + tm.assert_series_equal(result, expected) + + result = pd.Series(arr, dtype="int64") + expected = pd.Series([10]) + tm.assert_series_equal(result, expected) + + +def test_dataframe_constructor_with_dtype(): + arr = DecimalArray([decimal.Decimal("10.0")]) + + result = pd.DataFrame({"A": arr}, dtype=DecimalDtype()) + expected = pd.DataFrame({"A": arr}) + tm.assert_frame_equal(result, expected) + + arr = DecimalArray([decimal.Decimal("10.0")]) + result = pd.DataFrame({"A": arr}, dtype="int64") + expected = pd.DataFrame({"A": [10]}) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("frame", [True, False]) +def test_astype_dispatches(frame): + # This is a dtype-specific test that ensures Series[decimal].astype + # gets all the way through to ExtensionArray.astype + # Designing a reliable smoke test that works for arbitrary data types + # is difficult. + data = pd.Series(DecimalArray([decimal.Decimal(2)]), name="a") + ctx = decimal.Context() + ctx.prec = 5 + + if frame: + data = data.to_frame() + + result = data.astype(DecimalDtype(ctx)) + + if frame: + result = result["a"] + + assert result.dtype.context.prec == ctx.prec + + +class DecimalArrayWithoutFromSequence(DecimalArray): + """Helper class for testing error handling in _from_sequence.""" + + @classmethod + def _from_sequence(cls, scalars, *, dtype=None, copy=False): + raise KeyError("For the test") + + +class DecimalArrayWithoutCoercion(DecimalArrayWithoutFromSequence): + @classmethod + def _create_arithmetic_method(cls, op): + return cls._create_method(op, coerce_to_dtype=False) + + +DecimalArrayWithoutCoercion._add_arithmetic_ops() + + +def test_combine_from_sequence_raises(monkeypatch): + # https://github.com/pandas-dev/pandas/issues/22850 + cls = DecimalArrayWithoutFromSequence + + @classmethod + def construct_array_type(cls): + return DecimalArrayWithoutFromSequence + + monkeypatch.setattr(DecimalDtype, "construct_array_type", construct_array_type) + + arr = cls([decimal.Decimal("1.0"), decimal.Decimal("2.0")]) + ser = pd.Series(arr) + result = ser.combine(ser, operator.add) + + # note: object dtype + expected = pd.Series( + [decimal.Decimal("2.0"), decimal.Decimal("4.0")], dtype="object" + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "class_", [DecimalArrayWithoutFromSequence, DecimalArrayWithoutCoercion] +) +def test_scalar_ops_from_sequence_raises(class_): + # op(EA, EA) should return an EA, or an ndarray if it's not possible + # to return an EA with the return values. + arr = class_([decimal.Decimal("1.0"), decimal.Decimal("2.0")]) + result = arr + arr + expected = np.array( + [decimal.Decimal("2.0"), decimal.Decimal("4.0")], dtype="object" + ) + tm.assert_numpy_array_equal(result, expected) + + +@pytest.mark.parametrize( + "reverse, expected_div, expected_mod", + [(False, [0, 1, 1, 2], [1, 0, 1, 0]), (True, [2, 1, 0, 0], [0, 0, 2, 2])], +) +def test_divmod_array(reverse, expected_div, expected_mod): + # https://github.com/pandas-dev/pandas/issues/22930 + arr = to_decimal([1, 2, 3, 4]) + if reverse: + div, mod = divmod(2, arr) + else: + div, mod = divmod(arr, 2) + expected_div = to_decimal(expected_div) + expected_mod = to_decimal(expected_mod) + + tm.assert_extension_array_equal(div, expected_div) + tm.assert_extension_array_equal(mod, expected_mod) + + +def test_ufunc_fallback(data): + a = data[:5] + s = pd.Series(a, index=range(3, 8)) + result = np.abs(s) + expected = pd.Series(np.abs(a), index=range(3, 8)) + tm.assert_series_equal(result, expected) + + +def test_array_ufunc(): + a = to_decimal([1, 2, 3]) + result = np.exp(a) + expected = to_decimal(np.exp(a._data)) + tm.assert_extension_array_equal(result, expected) + + +def test_array_ufunc_series(): + a = to_decimal([1, 2, 3]) + s = pd.Series(a) + result = np.exp(s) + expected = pd.Series(to_decimal(np.exp(a._data))) + tm.assert_series_equal(result, expected) + + +def test_array_ufunc_series_scalar_other(): + # check _HANDLED_TYPES + a = to_decimal([1, 2, 3]) + s = pd.Series(a) + result = np.add(s, decimal.Decimal(1)) + expected = pd.Series(np.add(a, decimal.Decimal(1))) + tm.assert_series_equal(result, expected) + + +def test_array_ufunc_series_defer(): + a = to_decimal([1, 2, 3]) + s = pd.Series(a) + + expected = pd.Series(to_decimal([2, 4, 6])) + r1 = np.add(s, a) + r2 = np.add(a, s) + + tm.assert_series_equal(r1, expected) + tm.assert_series_equal(r2, expected) + + +def test_groupby_agg(): + # Ensure that the result of agg is inferred to be decimal dtype + # https://github.com/pandas-dev/pandas/issues/29141 + + data = make_data()[:5] + df = pd.DataFrame( + {"id1": [0, 0, 0, 1, 1], "id2": [0, 1, 0, 1, 1], "decimals": DecimalArray(data)} + ) + + # single key, selected column + expected = pd.Series(to_decimal([data[0], data[3]])) + result = df.groupby("id1")["decimals"].agg(lambda x: x.iloc[0]) + tm.assert_series_equal(result, expected, check_names=False) + result = df["decimals"].groupby(df["id1"]).agg(lambda x: x.iloc[0]) + tm.assert_series_equal(result, expected, check_names=False) + + # multiple keys, selected column + expected = pd.Series( + to_decimal([data[0], data[1], data[3]]), + index=pd.MultiIndex.from_tuples([(0, 0), (0, 1), (1, 1)]), + ) + result = df.groupby(["id1", "id2"])["decimals"].agg(lambda x: x.iloc[0]) + tm.assert_series_equal(result, expected, check_names=False) + result = df["decimals"].groupby([df["id1"], df["id2"]]).agg(lambda x: x.iloc[0]) + tm.assert_series_equal(result, expected, check_names=False) + + # multiple columns + expected = pd.DataFrame({"id2": [0, 1], "decimals": to_decimal([data[0], data[3]])}) + result = df.groupby("id1").agg(lambda x: x.iloc[0]) + tm.assert_frame_equal(result, expected, check_names=False) + + +def test_groupby_agg_ea_method(monkeypatch): + # Ensure that the result of agg is inferred to be decimal dtype + # https://github.com/pandas-dev/pandas/issues/29141 + + def DecimalArray__my_sum(self): + return np.sum(np.array(self)) + + monkeypatch.setattr(DecimalArray, "my_sum", DecimalArray__my_sum, raising=False) + + data = make_data()[:5] + df = pd.DataFrame({"id": [0, 0, 0, 1, 1], "decimals": DecimalArray(data)}) + expected = pd.Series(to_decimal([data[0] + data[1] + data[2], data[3] + data[4]])) + + result = df.groupby("id")["decimals"].agg(lambda x: x.values.my_sum()) + tm.assert_series_equal(result, expected, check_names=False) + s = pd.Series(DecimalArray(data)) + grouper = np.array([0, 0, 0, 1, 1], dtype=np.int64) + result = s.groupby(grouper).agg(lambda x: x.values.my_sum()) + tm.assert_series_equal(result, expected, check_names=False) + + +def test_indexing_no_materialize(monkeypatch): + # See https://github.com/pandas-dev/pandas/issues/29708 + # Ensure that indexing operations do not materialize (convert to a numpy + # array) the ExtensionArray unnecessary + + def DecimalArray__array__(self, dtype=None): + raise Exception("tried to convert a DecimalArray to a numpy array") + + monkeypatch.setattr(DecimalArray, "__array__", DecimalArray__array__, raising=False) + + data = make_data() + s = pd.Series(DecimalArray(data)) + df = pd.DataFrame({"a": s, "b": range(len(s))}) + + # ensure the following operations do not raise an error + s[s > 0.5] + df[s > 0.5] + s.at[0] + df.at[0, "a"] + + +def test_to_numpy_keyword(): + # test the extra keyword + values = [decimal.Decimal("1.1111"), decimal.Decimal("2.2222")] + expected = np.array( + [decimal.Decimal("1.11"), decimal.Decimal("2.22")], dtype="object" + ) + a = pd.array(values, dtype="decimal") + result = a.to_numpy(decimals=2) + tm.assert_numpy_array_equal(result, expected) + + result = pd.Series(a).to_numpy(decimals=2) + tm.assert_numpy_array_equal(result, expected) + + +def test_array_copy_on_write(using_copy_on_write): + df = pd.DataFrame({"a": [decimal.Decimal(2), decimal.Decimal(3)]}, dtype="object") + df2 = df.astype(DecimalDtype()) + df.iloc[0, 0] = 0 + if using_copy_on_write: + expected = pd.DataFrame( + {"a": [decimal.Decimal(2), decimal.Decimal(3)]}, dtype=DecimalDtype() + ) + tm.assert_equal(df2.values, expected.values) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/json/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/json/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7ebfd54a5b0d6bf1ff2c4602ed72f5214e32608f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/json/__init__.py @@ -0,0 +1,7 @@ +from pandas.tests.extension.json.array import ( + JSONArray, + JSONDtype, + make_data, +) + +__all__ = ["JSONArray", "JSONDtype", "make_data"] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/json/array.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/json/array.py new file mode 100644 index 0000000000000000000000000000000000000000..e43b50322bb925778c0f8cf7734d83aa7b51add7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/json/array.py @@ -0,0 +1,256 @@ +""" +Test extension array for storing nested data in a pandas container. + +The JSONArray stores lists of dictionaries. The storage mechanism is a list, +not an ndarray. + +Note +---- +We currently store lists of UserDicts. Pandas has a few places +internally that specifically check for dicts, and does non-scalar things +in that case. We *want* the dictionaries to be treated as scalars, so we +hack around pandas by using UserDicts. +""" +from __future__ import annotations + +from collections import ( + UserDict, + abc, +) +import itertools +import numbers +import string +import sys +from typing import ( + TYPE_CHECKING, + Any, +) + +import numpy as np + +from pandas.core.dtypes.cast import construct_1d_object_array_from_listlike +from pandas.core.dtypes.common import ( + is_bool_dtype, + is_list_like, + pandas_dtype, +) + +import pandas as pd +from pandas.api.extensions import ( + ExtensionArray, + ExtensionDtype, +) +from pandas.core.indexers import unpack_tuple_and_ellipses + +if TYPE_CHECKING: + from collections.abc import Mapping + + from pandas._typing import type_t + + +class JSONDtype(ExtensionDtype): + type = abc.Mapping + name = "json" + na_value: Mapping[str, Any] = UserDict() + + @classmethod + def construct_array_type(cls) -> type_t[JSONArray]: + """ + Return the array type associated with this dtype. + + Returns + ------- + type + """ + return JSONArray + + +class JSONArray(ExtensionArray): + dtype = JSONDtype() + __array_priority__ = 1000 + + def __init__(self, values, dtype=None, copy=False) -> None: + for val in values: + if not isinstance(val, self.dtype.type): + raise TypeError("All values must be of type " + str(self.dtype.type)) + self.data = values + + # Some aliases for common attribute names to ensure pandas supports + # these + self._items = self._data = self.data + # those aliases are currently not working due to assumptions + # in internal code (GH-20735) + # self._values = self.values = self.data + + @classmethod + def _from_sequence(cls, scalars, *, dtype=None, copy=False): + return cls(scalars) + + @classmethod + def _from_factorized(cls, values, original): + return cls([UserDict(x) for x in values if x != ()]) + + def __getitem__(self, item): + if isinstance(item, tuple): + item = unpack_tuple_and_ellipses(item) + + if isinstance(item, numbers.Integral): + return self.data[item] + elif isinstance(item, slice) and item == slice(None): + # Make sure we get a view + return type(self)(self.data) + elif isinstance(item, slice): + # slice + return type(self)(self.data[item]) + elif not is_list_like(item): + # e.g. "foo" or 2.5 + # exception message copied from numpy + raise IndexError( + r"only integers, slices (`:`), ellipsis (`...`), numpy.newaxis " + r"(`None`) and integer or boolean arrays are valid indices" + ) + else: + item = pd.api.indexers.check_array_indexer(self, item) + if is_bool_dtype(item.dtype): + return type(self)._from_sequence( + [x for x, m in zip(self, item) if m], dtype=self.dtype + ) + # integer + return type(self)([self.data[i] for i in item]) + + def __setitem__(self, key, value) -> None: + if isinstance(key, numbers.Integral): + self.data[key] = value + else: + if not isinstance(value, (type(self), abc.Sequence)): + # broadcast value + value = itertools.cycle([value]) + + if isinstance(key, np.ndarray) and key.dtype == "bool": + # masking + for i, (k, v) in enumerate(zip(key, value)): + if k: + assert isinstance(v, self.dtype.type) + self.data[i] = v + else: + for k, v in zip(key, value): + assert isinstance(v, self.dtype.type) + self.data[k] = v + + def __len__(self) -> int: + return len(self.data) + + def __eq__(self, other): + return NotImplemented + + def __ne__(self, other): + return NotImplemented + + def __array__(self, dtype=None, copy=None): + if dtype is None: + dtype = object + if dtype == object: + # on py38 builds it looks like numpy is inferring to a non-1D array + return construct_1d_object_array_from_listlike(list(self)) + return np.asarray(self.data, dtype=dtype) + + @property + def nbytes(self) -> int: + return sys.getsizeof(self.data) + + def isna(self): + return np.array([x == self.dtype.na_value for x in self.data], dtype=bool) + + def take(self, indexer, allow_fill=False, fill_value=None): + # re-implement here, since NumPy has trouble setting + # sized objects like UserDicts into scalar slots of + # an ndarary. + indexer = np.asarray(indexer) + msg = ( + "Index is out of bounds or cannot do a " + "non-empty take from an empty array." + ) + + if allow_fill: + if fill_value is None: + fill_value = self.dtype.na_value + # bounds check + if (indexer < -1).any(): + raise ValueError + try: + output = [ + self.data[loc] if loc != -1 else fill_value for loc in indexer + ] + except IndexError as err: + raise IndexError(msg) from err + else: + try: + output = [self.data[loc] for loc in indexer] + except IndexError as err: + raise IndexError(msg) from err + + return type(self)._from_sequence(output, dtype=self.dtype) + + def copy(self): + return type(self)(self.data[:]) + + def astype(self, dtype, copy=True): + # NumPy has issues when all the dicts are the same length. + # np.array([UserDict(...), UserDict(...)]) fails, + # but np.array([{...}, {...}]) works, so cast. + from pandas.core.arrays.string_ import StringDtype + + dtype = pandas_dtype(dtype) + # needed to add this check for the Series constructor + if isinstance(dtype, type(self.dtype)) and dtype == self.dtype: + if copy: + return self.copy() + return self + elif isinstance(dtype, StringDtype): + value = self.astype(str) # numpy doesn't like nested dicts + arr_cls = dtype.construct_array_type() + return arr_cls._from_sequence(value, dtype=dtype, copy=False) + elif not copy: + return np.asarray([dict(x) for x in self], dtype=dtype) + else: + return np.array([dict(x) for x in self], dtype=dtype, copy=copy) + + def unique(self): + # Parent method doesn't work since np.array will try to infer + # a 2-dim object. + return type(self)([dict(x) for x in {tuple(d.items()) for d in self.data}]) + + @classmethod + def _concat_same_type(cls, to_concat): + data = list(itertools.chain.from_iterable(x.data for x in to_concat)) + return cls(data) + + def _values_for_factorize(self): + frozen = self._values_for_argsort() + if len(frozen) == 0: + # factorize_array expects 1-d array, this is a len-0 2-d array. + frozen = frozen.ravel() + return frozen, () + + def _values_for_argsort(self): + # Bypass NumPy's shape inference to get a (N,) array of tuples. + frozen = [tuple(x.items()) for x in self] + return construct_1d_object_array_from_listlike(frozen) + + def _pad_or_backfill(self, *, method, limit=None, copy=True): + # GH#56616 - test EA method without limit_area argument + return super()._pad_or_backfill(method=method, limit=limit, copy=copy) + + +def make_data(): + # TODO: Use a regular dict. See _NDFrameIndexer._setitem_with_indexer + rng = np.random.default_rng(2) + return [ + UserDict( + [ + (rng.choice(list(string.ascii_letters)), rng.integers(0, 100)) + for _ in range(rng.integers(0, 10)) + ] + ) + for _ in range(100) + ] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/json/test_json.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/json/test_json.py new file mode 100644 index 0000000000000000000000000000000000000000..a18edac9aef93804bd02698dd0b44d5b31f6b887 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/json/test_json.py @@ -0,0 +1,490 @@ +import collections +import operator +import sys + +import numpy as np +import pytest + +import pandas as pd +import pandas._testing as tm +from pandas.tests.extension import base +from pandas.tests.extension.json.array import ( + JSONArray, + JSONDtype, + make_data, +) + +# We intentionally don't run base.BaseSetitemTests because pandas' +# internals has trouble setting sequences of values into scalar positions. +unhashable = pytest.mark.xfail(reason="Unhashable") + + +@pytest.fixture +def dtype(): + return JSONDtype() + + +@pytest.fixture +def data(): + """Length-100 PeriodArray for semantics test.""" + data = make_data() + + # Why the while loop? NumPy is unable to construct an ndarray from + # equal-length ndarrays. Many of our operations involve coercing the + # EA to an ndarray of objects. To avoid random test failures, we ensure + # that our data is coercible to an ndarray. Several tests deal with only + # the first two elements, so that's what we'll check. + + while len(data[0]) == len(data[1]): + data = make_data() + + return JSONArray(data) + + +@pytest.fixture +def data_missing(): + """Length 2 array with [NA, Valid]""" + return JSONArray([{}, {"a": 10}]) + + +@pytest.fixture +def data_for_sorting(): + return JSONArray([{"b": 1}, {"c": 4}, {"a": 2, "c": 3}]) + + +@pytest.fixture +def data_missing_for_sorting(): + return JSONArray([{"b": 1}, {}, {"a": 4}]) + + +@pytest.fixture +def na_cmp(): + return operator.eq + + +@pytest.fixture +def data_for_grouping(): + return JSONArray( + [ + {"b": 1}, + {"b": 1}, + {}, + {}, + {"a": 0, "c": 2}, + {"a": 0, "c": 2}, + {"b": 1}, + {"c": 2}, + ] + ) + + +class TestJSONArray(base.ExtensionTests): + @pytest.mark.xfail( + reason="comparison method not implemented for JSONArray (GH-37867)" + ) + def test_contains(self, data): + # GH-37867 + super().test_contains(data) + + @pytest.mark.xfail(reason="not implemented constructor from dtype") + def test_from_dtype(self, data): + # construct from our dtype & string dtype + super().test_from_dtype(data) + + @pytest.mark.xfail(reason="RecursionError, GH-33900") + def test_series_constructor_no_data_with_index(self, dtype, na_value): + # RecursionError: maximum recursion depth exceeded in comparison + rec_limit = sys.getrecursionlimit() + try: + # Limit to avoid stack overflow on Windows CI + sys.setrecursionlimit(100) + super().test_series_constructor_no_data_with_index(dtype, na_value) + finally: + sys.setrecursionlimit(rec_limit) + + @pytest.mark.xfail(reason="RecursionError, GH-33900") + def test_series_constructor_scalar_na_with_index(self, dtype, na_value): + # RecursionError: maximum recursion depth exceeded in comparison + rec_limit = sys.getrecursionlimit() + try: + # Limit to avoid stack overflow on Windows CI + sys.setrecursionlimit(100) + super().test_series_constructor_scalar_na_with_index(dtype, na_value) + finally: + sys.setrecursionlimit(rec_limit) + + @pytest.mark.xfail(reason="collection as scalar, GH-33901") + def test_series_constructor_scalar_with_index(self, data, dtype): + # TypeError: All values must be of type + rec_limit = sys.getrecursionlimit() + try: + # Limit to avoid stack overflow on Windows CI + sys.setrecursionlimit(100) + super().test_series_constructor_scalar_with_index(data, dtype) + finally: + sys.setrecursionlimit(rec_limit) + + @pytest.mark.xfail(reason="Different definitions of NA") + def test_stack(self): + """ + The test does .astype(object).stack(future_stack=True). If we happen to have + any missing values in `data`, then we'll end up with different + rows since we consider `{}` NA, but `.astype(object)` doesn't. + """ + super().test_stack() + + @pytest.mark.xfail(reason="dict for NA") + def test_unstack(self, data, index): + # The base test has NaN for the expected NA value. + # this matches otherwise + return super().test_unstack(data, index) + + @pytest.mark.xfail(reason="Setting a dict as a scalar") + def test_fillna_series(self): + """We treat dictionaries as a mapping in fillna, not a scalar.""" + super().test_fillna_series() + + @pytest.mark.xfail(reason="Setting a dict as a scalar") + def test_fillna_frame(self): + """We treat dictionaries as a mapping in fillna, not a scalar.""" + super().test_fillna_frame() + + @pytest.mark.parametrize( + "limit_area, input_ilocs, expected_ilocs", + [ + ("outside", [1, 0, 0, 0, 1], [1, 0, 0, 0, 1]), + ("outside", [1, 0, 1, 0, 1], [1, 0, 1, 0, 1]), + ("outside", [0, 1, 1, 1, 0], [0, 1, 1, 1, 1]), + ("outside", [0, 1, 0, 1, 0], [0, 1, 0, 1, 1]), + ("inside", [1, 0, 0, 0, 1], [1, 1, 1, 1, 1]), + ("inside", [1, 0, 1, 0, 1], [1, 1, 1, 1, 1]), + ("inside", [0, 1, 1, 1, 0], [0, 1, 1, 1, 0]), + ("inside", [0, 1, 0, 1, 0], [0, 1, 1, 1, 0]), + ], + ) + def test_ffill_limit_area( + self, data_missing, limit_area, input_ilocs, expected_ilocs + ): + # GH#56616 + msg = "JSONArray does not implement limit_area" + with pytest.raises(NotImplementedError, match=msg): + super().test_ffill_limit_area( + data_missing, limit_area, input_ilocs, expected_ilocs + ) + + @unhashable + def test_value_counts(self, all_data, dropna): + super().test_value_counts(all_data, dropna) + + @unhashable + def test_value_counts_with_normalize(self, data): + super().test_value_counts_with_normalize(data) + + @unhashable + def test_sort_values_frame(self): + # TODO (EA.factorize): see if _values_for_factorize allows this. + super().test_sort_values_frame() + + @pytest.mark.parametrize("ascending", [True, False]) + def test_sort_values(self, data_for_sorting, ascending, sort_by_key): + super().test_sort_values(data_for_sorting, ascending, sort_by_key) + + @pytest.mark.parametrize("ascending", [True, False]) + def test_sort_values_missing( + self, data_missing_for_sorting, ascending, sort_by_key + ): + super().test_sort_values_missing( + data_missing_for_sorting, ascending, sort_by_key + ) + + @pytest.mark.xfail(reason="combine for JSONArray not supported") + def test_combine_le(self, data_repeated): + super().test_combine_le(data_repeated) + + @pytest.mark.xfail( + reason="combine for JSONArray not supported - " + "may pass depending on random data", + strict=False, + raises=AssertionError, + ) + def test_combine_first(self, data): + super().test_combine_first(data) + + @pytest.mark.xfail(reason="broadcasting error") + def test_where_series(self, data, na_value): + # Fails with + # *** ValueError: operands could not be broadcast together + # with shapes (4,) (4,) (0,) + super().test_where_series(data, na_value) + + @pytest.mark.xfail(reason="Can't compare dicts.") + def test_searchsorted(self, data_for_sorting): + super().test_searchsorted(data_for_sorting) + + @pytest.mark.xfail(reason="Can't compare dicts.") + def test_equals(self, data, na_value, as_series): + super().test_equals(data, na_value, as_series) + + @pytest.mark.skip("fill-value is interpreted as a dict of values") + def test_fillna_copy_frame(self, data_missing): + super().test_fillna_copy_frame(data_missing) + + def test_equals_same_data_different_object( + self, data, using_copy_on_write, request + ): + if using_copy_on_write: + mark = pytest.mark.xfail(reason="Fails with CoW") + request.applymarker(mark) + super().test_equals_same_data_different_object(data) + + @pytest.mark.xfail(reason="failing on np.array(self, dtype=str)") + def test_astype_str(self): + """This currently fails in NumPy on np.array(self, dtype=str) with + + *** ValueError: setting an array element with a sequence + """ + super().test_astype_str() + + @unhashable + def test_groupby_extension_transform(self): + """ + This currently fails in Series.name.setter, since the + name must be hashable, but the value is a dictionary. + I think this is what we want, i.e. `.name` should be the original + values, and not the values for factorization. + """ + super().test_groupby_extension_transform() + + @unhashable + def test_groupby_extension_apply(self): + """ + This fails in Index._do_unique_check with + + > hash(val) + E TypeError: unhashable type: 'UserDict' with + + I suspect that once we support Index[ExtensionArray], + we'll be able to dispatch unique. + """ + super().test_groupby_extension_apply() + + @unhashable + def test_groupby_extension_agg(self): + """ + This fails when we get to tm.assert_series_equal when left.index + contains dictionaries, which are not hashable. + """ + super().test_groupby_extension_agg() + + @unhashable + def test_groupby_extension_no_sort(self): + """ + This fails when we get to tm.assert_series_equal when left.index + contains dictionaries, which are not hashable. + """ + super().test_groupby_extension_no_sort() + + def test_arith_frame_with_scalar(self, data, all_arithmetic_operators, request): + if len(data[0]) != 1: + mark = pytest.mark.xfail(reason="raises in coercing to Series") + request.applymarker(mark) + super().test_arith_frame_with_scalar(data, all_arithmetic_operators) + + def test_compare_array(self, data, comparison_op, request): + if comparison_op.__name__ in ["eq", "ne"]: + mark = pytest.mark.xfail(reason="Comparison methods not implemented") + request.applymarker(mark) + super().test_compare_array(data, comparison_op) + + @pytest.mark.xfail(reason="ValueError: Must have equal len keys and value") + def test_setitem_loc_scalar_mixed(self, data): + super().test_setitem_loc_scalar_mixed(data) + + @pytest.mark.xfail(reason="ValueError: Must have equal len keys and value") + def test_setitem_loc_scalar_multiple_homogoneous(self, data): + super().test_setitem_loc_scalar_multiple_homogoneous(data) + + @pytest.mark.xfail(reason="ValueError: Must have equal len keys and value") + def test_setitem_iloc_scalar_mixed(self, data): + super().test_setitem_iloc_scalar_mixed(data) + + @pytest.mark.xfail(reason="ValueError: Must have equal len keys and value") + def test_setitem_iloc_scalar_multiple_homogoneous(self, data): + super().test_setitem_iloc_scalar_multiple_homogoneous(data) + + @pytest.mark.parametrize( + "mask", + [ + np.array([True, True, True, False, False]), + pd.array([True, True, True, False, False], dtype="boolean"), + pd.array([True, True, True, pd.NA, pd.NA], dtype="boolean"), + ], + ids=["numpy-array", "boolean-array", "boolean-array-na"], + ) + def test_setitem_mask(self, data, mask, box_in_series, request): + if box_in_series: + mark = pytest.mark.xfail( + reason="cannot set using a list-like indexer with a different length" + ) + request.applymarker(mark) + elif not isinstance(mask, np.ndarray): + mark = pytest.mark.xfail(reason="Issues unwanted DeprecationWarning") + request.applymarker(mark) + super().test_setitem_mask(data, mask, box_in_series) + + def test_setitem_mask_raises(self, data, box_in_series, request): + if not box_in_series: + mark = pytest.mark.xfail(reason="Fails to raise") + request.applymarker(mark) + + super().test_setitem_mask_raises(data, box_in_series) + + @pytest.mark.xfail( + reason="cannot set using a list-like indexer with a different length" + ) + def test_setitem_mask_boolean_array_with_na(self, data, box_in_series): + super().test_setitem_mask_boolean_array_with_na(data, box_in_series) + + @pytest.mark.parametrize( + "idx", + [[0, 1, 2], pd.array([0, 1, 2], dtype="Int64"), np.array([0, 1, 2])], + ids=["list", "integer-array", "numpy-array"], + ) + def test_setitem_integer_array(self, data, idx, box_in_series, request): + if box_in_series: + mark = pytest.mark.xfail( + reason="cannot set using a list-like indexer with a different length" + ) + request.applymarker(mark) + super().test_setitem_integer_array(data, idx, box_in_series) + + @pytest.mark.xfail(reason="list indices must be integers or slices, not NAType") + @pytest.mark.parametrize( + "idx, box_in_series", + [ + ([0, 1, 2, pd.NA], False), + pytest.param( + [0, 1, 2, pd.NA], True, marks=pytest.mark.xfail(reason="GH-31948") + ), + (pd.array([0, 1, 2, pd.NA], dtype="Int64"), False), + (pd.array([0, 1, 2, pd.NA], dtype="Int64"), False), + ], + ids=["list-False", "list-True", "integer-array-False", "integer-array-True"], + ) + def test_setitem_integer_with_missing_raises(self, data, idx, box_in_series): + super().test_setitem_integer_with_missing_raises(data, idx, box_in_series) + + @pytest.mark.xfail(reason="Fails to raise") + def test_setitem_scalar_key_sequence_raise(self, data): + super().test_setitem_scalar_key_sequence_raise(data) + + def test_setitem_with_expansion_dataframe_column(self, data, full_indexer, request): + if "full_slice" in request.node.name: + mark = pytest.mark.xfail(reason="slice is not iterable") + request.applymarker(mark) + super().test_setitem_with_expansion_dataframe_column(data, full_indexer) + + @pytest.mark.xfail(reason="slice is not iterable") + def test_setitem_frame_2d_values(self, data): + super().test_setitem_frame_2d_values(data) + + @pytest.mark.xfail( + reason="cannot set using a list-like indexer with a different length" + ) + @pytest.mark.parametrize("setter", ["loc", None]) + def test_setitem_mask_broadcast(self, data, setter): + super().test_setitem_mask_broadcast(data, setter) + + @pytest.mark.xfail( + reason="cannot set using a slice indexer with a different length" + ) + def test_setitem_slice(self, data, box_in_series): + super().test_setitem_slice(data, box_in_series) + + @pytest.mark.xfail(reason="slice object is not iterable") + def test_setitem_loc_iloc_slice(self, data): + super().test_setitem_loc_iloc_slice(data) + + @pytest.mark.xfail(reason="slice object is not iterable") + def test_setitem_slice_mismatch_length_raises(self, data): + super().test_setitem_slice_mismatch_length_raises(data) + + @pytest.mark.xfail(reason="slice object is not iterable") + def test_setitem_slice_array(self, data): + super().test_setitem_slice_array(data) + + @pytest.mark.xfail(reason="Fail to raise") + def test_setitem_invalid(self, data, invalid_scalar): + super().test_setitem_invalid(data, invalid_scalar) + + @pytest.mark.xfail(reason="only integer scalar arrays can be converted") + def test_setitem_2d_values(self, data): + super().test_setitem_2d_values(data) + + @pytest.mark.xfail(reason="data type 'json' not understood") + @pytest.mark.parametrize("engine", ["c", "python"]) + def test_EA_types(self, engine, data, request): + super().test_EA_types(engine, data, request) + + +def custom_assert_series_equal(left, right, *args, **kwargs): + # NumPy doesn't handle an array of equal-length UserDicts. + # The default assert_series_equal eventually does a + # Series.values, which raises. We work around it by + # converting the UserDicts to dicts. + if left.dtype.name == "json": + assert left.dtype == right.dtype + left = pd.Series( + JSONArray(left.values.astype(object)), index=left.index, name=left.name + ) + right = pd.Series( + JSONArray(right.values.astype(object)), + index=right.index, + name=right.name, + ) + tm.assert_series_equal(left, right, *args, **kwargs) + + +def custom_assert_frame_equal(left, right, *args, **kwargs): + obj_type = kwargs.get("obj", "DataFrame") + tm.assert_index_equal( + left.columns, + right.columns, + exact=kwargs.get("check_column_type", "equiv"), + check_names=kwargs.get("check_names", True), + check_exact=kwargs.get("check_exact", False), + check_categorical=kwargs.get("check_categorical", True), + obj=f"{obj_type}.columns", + ) + + jsons = (left.dtypes == "json").index + + for col in jsons: + custom_assert_series_equal(left[col], right[col], *args, **kwargs) + + left = left.drop(columns=jsons) + right = right.drop(columns=jsons) + tm.assert_frame_equal(left, right, *args, **kwargs) + + +def test_custom_asserts(): + # This would always trigger the KeyError from trying to put + # an array of equal-length UserDicts inside an ndarray. + data = JSONArray( + [ + collections.UserDict({"a": 1}), + collections.UserDict({"b": 2}), + collections.UserDict({"c": 3}), + ] + ) + a = pd.Series(data) + custom_assert_series_equal(a, a) + custom_assert_frame_equal(a.to_frame(), a.to_frame()) + + b = pd.Series(data.take([0, 0, 1])) + msg = r"Series are different" + with pytest.raises(AssertionError, match=msg): + custom_assert_series_equal(a, b) + + with pytest.raises(AssertionError, match=msg): + custom_assert_frame_equal(a.to_frame(), b.to_frame()) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/list/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/list/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0f3f2f35377882a0fae603edfc8edb46371429fe --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/list/__init__.py @@ -0,0 +1,7 @@ +from pandas.tests.extension.list.array import ( + ListArray, + ListDtype, + make_data, +) + +__all__ = ["ListArray", "ListDtype", "make_data"] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/list/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/list/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b76d870097614e199fb68ec21e3f833e1f22fba Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/list/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/list/__pycache__/test_list.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/list/__pycache__/test_list.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c72c06631b6294922a53e928a70db671365898f1 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/list/__pycache__/test_list.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/list/array.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/list/array.py new file mode 100644 index 0000000000000000000000000000000000000000..b3bb35c9396f4d1748fff37b7334c68a0b055daf --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/list/array.py @@ -0,0 +1,137 @@ +""" +Test extension array for storing nested data in a pandas container. + +The ListArray stores an ndarray of lists. +""" +from __future__ import annotations + +import numbers +import string +from typing import TYPE_CHECKING + +import numpy as np + +from pandas.core.dtypes.base import ExtensionDtype + +import pandas as pd +from pandas.api.types import ( + is_object_dtype, + is_string_dtype, +) +from pandas.core.arrays import ExtensionArray + +if TYPE_CHECKING: + from pandas._typing import type_t + + +class ListDtype(ExtensionDtype): + type = list + name = "list" + na_value = np.nan + + @classmethod + def construct_array_type(cls) -> type_t[ListArray]: + """ + Return the array type associated with this dtype. + + Returns + ------- + type + """ + return ListArray + + +class ListArray(ExtensionArray): + dtype = ListDtype() + __array_priority__ = 1000 + + def __init__(self, values, dtype=None, copy=False) -> None: + if not isinstance(values, np.ndarray): + raise TypeError("Need to pass a numpy array as values") + for val in values: + if not isinstance(val, self.dtype.type) and not pd.isna(val): + raise TypeError("All values must be of type " + str(self.dtype.type)) + self.data = values + + @classmethod + def _from_sequence(cls, scalars, *, dtype=None, copy=False): + data = np.empty(len(scalars), dtype=object) + data[:] = scalars + return cls(data) + + def __getitem__(self, item): + if isinstance(item, numbers.Integral): + return self.data[item] + else: + # slice, list-like, mask + return type(self)(self.data[item]) + + def __len__(self) -> int: + return len(self.data) + + def isna(self): + return np.array( + [not isinstance(x, list) and np.isnan(x) for x in self.data], dtype=bool + ) + + def take(self, indexer, allow_fill=False, fill_value=None): + # re-implement here, since NumPy has trouble setting + # sized objects like UserDicts into scalar slots of + # an ndarary. + indexer = np.asarray(indexer) + msg = ( + "Index is out of bounds or cannot do a " + "non-empty take from an empty array." + ) + + if allow_fill: + if fill_value is None: + fill_value = self.dtype.na_value + # bounds check + if (indexer < -1).any(): + raise ValueError + try: + output = [ + self.data[loc] if loc != -1 else fill_value for loc in indexer + ] + except IndexError as err: + raise IndexError(msg) from err + else: + try: + output = [self.data[loc] for loc in indexer] + except IndexError as err: + raise IndexError(msg) from err + + return self._from_sequence(output) + + def copy(self): + return type(self)(self.data[:]) + + def astype(self, dtype, copy=True): + if isinstance(dtype, type(self.dtype)) and dtype == self.dtype: + if copy: + return self.copy() + return self + elif is_string_dtype(dtype) and not is_object_dtype(dtype): + # numpy has problems with astype(str) for nested elements + return np.array([str(x) for x in self.data], dtype=dtype) + elif not copy: + return np.asarray(self.data, dtype=dtype) + else: + return np.array(self.data, dtype=dtype, copy=copy) + + @classmethod + def _concat_same_type(cls, to_concat): + data = np.concatenate([x.data for x in to_concat]) + return cls(data) + + +def make_data(): + # TODO: Use a regular dict. See _NDFrameIndexer._setitem_with_indexer + rng = np.random.default_rng(2) + data = np.empty(100, dtype=object) + data[:] = [ + [rng.choice(list(string.ascii_letters)) for _ in range(rng.integers(0, 10))] + for _ in range(100) + ] + return data diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/list/test_list.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/list/test_list.py new file mode 100644 index 0000000000000000000000000000000000000000..ac396cd3c60d435d34f95d5027d80d116d4560d5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/extension/list/test_list.py @@ -0,0 +1,33 @@ +import pytest + +import pandas as pd +from pandas.tests.extension.list.array import ( + ListArray, + ListDtype, + make_data, +) + + +@pytest.fixture +def dtype(): + return ListDtype() + + +@pytest.fixture +def data(): + """Length-100 ListArray for semantics test.""" + data = make_data() + + while len(data[0]) == len(data[1]): + data = make_data() + + return ListArray(data) + + +def test_to_csv(data): + # https://github.com/pandas-dev/pandas/issues/28840 + # array with list-likes fail when doing astype(str) on the numpy array + # which was done in get_values_for_csv + df = pd.DataFrame({"a": data}) + res = df.to_csv() + assert str(data[0]) in res diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/common.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/common.py new file mode 100644 index 0000000000000000000000000000000000000000..fc41d7907a240f0dd9dc19e0ae1296bee86be421 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/common.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from pandas import ( + DataFrame, + concat, +) + +if TYPE_CHECKING: + from pandas._typing import AxisInt + + +def _check_mixed_float(df, dtype=None): + # float16 are most likely to be upcasted to float32 + dtypes = {"A": "float32", "B": "float32", "C": "float16", "D": "float64"} + if isinstance(dtype, str): + dtypes = {k: dtype for k, v in dtypes.items()} + elif isinstance(dtype, dict): + dtypes.update(dtype) + if dtypes.get("A"): + assert df.dtypes["A"] == dtypes["A"] + if dtypes.get("B"): + assert df.dtypes["B"] == dtypes["B"] + if dtypes.get("C"): + assert df.dtypes["C"] == dtypes["C"] + if dtypes.get("D"): + assert df.dtypes["D"] == dtypes["D"] + + +def _check_mixed_int(df, dtype=None): + dtypes = {"A": "int32", "B": "uint64", "C": "uint8", "D": "int64"} + if isinstance(dtype, str): + dtypes = {k: dtype for k, v in dtypes.items()} + elif isinstance(dtype, dict): + dtypes.update(dtype) + if dtypes.get("A"): + assert df.dtypes["A"] == dtypes["A"] + if dtypes.get("B"): + assert df.dtypes["B"] == dtypes["B"] + if dtypes.get("C"): + assert df.dtypes["C"] == dtypes["C"] + if dtypes.get("D"): + assert df.dtypes["D"] == dtypes["D"] + + +def zip_frames(frames: list[DataFrame], axis: AxisInt = 1) -> DataFrame: + """ + take a list of frames, zip them together under the + assumption that these all have the first frames' index/columns. + + Returns + ------- + new_frame : DataFrame + """ + if axis == 1: + columns = frames[0].columns + zipped = [f.loc[:, c] for c in columns for f in frames] + return concat(zipped, axis=1) + else: + index = frames[0].index + zipped = [f.loc[i, :] for i in index for f in frames] + return DataFrame(zipped) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/conftest.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..e07024b2e2a097a8442db2039323f45aa18598de --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/conftest.py @@ -0,0 +1,100 @@ +import numpy as np +import pytest + +from pandas import ( + DataFrame, + Index, + NaT, + date_range, +) + + +@pytest.fixture +def datetime_frame() -> DataFrame: + """ + Fixture for DataFrame of floats with DatetimeIndex + + Columns are ['A', 'B', 'C', 'D'] + """ + return DataFrame( + np.random.default_rng(2).standard_normal((100, 4)), + columns=Index(list("ABCD"), dtype=object), + index=date_range("2000-01-01", periods=100, freq="B"), + ) + + +@pytest.fixture +def float_string_frame(): + """ + Fixture for DataFrame of floats and strings with index of unique strings + + Columns are ['A', 'B', 'C', 'D', 'foo']. + """ + df = DataFrame( + np.random.default_rng(2).standard_normal((30, 4)), + index=Index([f"foo_{i}" for i in range(30)], dtype=object), + columns=Index(list("ABCD"), dtype=object), + ) + df["foo"] = "bar" + return df + + +@pytest.fixture +def mixed_float_frame(): + """ + Fixture for DataFrame of different float types with index of unique strings + + Columns are ['A', 'B', 'C', 'D']. + """ + df = DataFrame( + { + col: np.random.default_rng(2).random(30, dtype=dtype) + for col, dtype in zip( + list("ABCD"), ["float32", "float32", "float32", "float64"] + ) + }, + index=Index([f"foo_{i}" for i in range(30)], dtype=object), + ) + # not supported by numpy random + df["C"] = df["C"].astype("float16") + return df + + +@pytest.fixture +def mixed_int_frame(): + """ + Fixture for DataFrame of different int types with index of unique strings + + Columns are ['A', 'B', 'C', 'D']. + """ + return DataFrame( + { + col: np.ones(30, dtype=dtype) + for col, dtype in zip(list("ABCD"), ["int32", "uint64", "uint8", "int64"]) + }, + index=Index([f"foo_{i}" for i in range(30)], dtype=object), + ) + + +@pytest.fixture +def timezone_frame(): + """ + Fixture for DataFrame of date_range Series with different time zones + + Columns are ['A', 'B', 'C']; some entries are missing + + A B C + 0 2013-01-01 2013-01-01 00:00:00-05:00 2013-01-01 00:00:00+01:00 + 1 2013-01-02 NaT NaT + 2 2013-01-03 2013-01-03 00:00:00-05:00 2013-01-03 00:00:00+01:00 + """ + df = DataFrame( + { + "A": date_range("20130101", periods=3), + "B": date_range("20130101", periods=3, tz="US/Eastern"), + "C": date_range("20130101", periods=3, tz="CET"), + } + ) + df.iloc[1, 1] = NaT + df.iloc[1, 2] = NaT + return df diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_alter_axes.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_alter_axes.py new file mode 100644 index 0000000000000000000000000000000000000000..c68171ab254c7c8582a206a8e9b44b3845c47efc --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_alter_axes.py @@ -0,0 +1,30 @@ +from datetime import datetime + +import pytz + +from pandas import DataFrame +import pandas._testing as tm + + +class TestDataFrameAlterAxes: + # Tests for setting index/columns attributes directly (i.e. __setattr__) + + def test_set_axis_setattr_index(self): + # GH 6785 + # set the index manually + + df = DataFrame([{"ts": datetime(2014, 4, 1, tzinfo=pytz.utc), "foo": 1}]) + expected = df.set_index("ts") + df.index = df["ts"] + df.pop("ts") + tm.assert_frame_equal(df, expected) + + # Renaming + + def test_assign_columns(self, float_frame): + float_frame["hi"] = "there" + + df = float_frame.copy() + df.columns = ["foo", "bar", "baz", "quux", "foo2"] + tm.assert_series_equal(float_frame["C"], df["baz"], check_names=False) + tm.assert_series_equal(float_frame["hi"], df["foo2"], check_names=False) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_api.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_api.py new file mode 100644 index 0000000000000000000000000000000000000000..c7b444045a0f23ea9d7b9ad94a1244b0b320fee6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_api.py @@ -0,0 +1,392 @@ +from copy import deepcopy +import inspect +import pydoc + +import numpy as np +import pytest + +from pandas._config import using_pyarrow_string_dtype +from pandas._config.config import option_context + +import pandas as pd +from pandas import ( + DataFrame, + Series, + date_range, + timedelta_range, +) +import pandas._testing as tm + + +class TestDataFrameMisc: + def test_getitem_pop_assign_name(self, float_frame): + s = float_frame["A"] + assert s.name == "A" + + s = float_frame.pop("A") + assert s.name == "A" + + s = float_frame.loc[:, "B"] + assert s.name == "B" + + s2 = s.loc[:] + assert s2.name == "B" + + def test_get_axis(self, float_frame): + f = float_frame + assert f._get_axis_number(0) == 0 + assert f._get_axis_number(1) == 1 + assert f._get_axis_number("index") == 0 + assert f._get_axis_number("rows") == 0 + assert f._get_axis_number("columns") == 1 + + assert f._get_axis_name(0) == "index" + assert f._get_axis_name(1) == "columns" + assert f._get_axis_name("index") == "index" + assert f._get_axis_name("rows") == "index" + assert f._get_axis_name("columns") == "columns" + + assert f._get_axis(0) is f.index + assert f._get_axis(1) is f.columns + + with pytest.raises(ValueError, match="No axis named"): + f._get_axis_number(2) + + with pytest.raises(ValueError, match="No axis.*foo"): + f._get_axis_name("foo") + + with pytest.raises(ValueError, match="No axis.*None"): + f._get_axis_name(None) + + with pytest.raises(ValueError, match="No axis named"): + f._get_axis_number(None) + + def test_column_contains_raises(self, float_frame): + with pytest.raises(TypeError, match="unhashable type: 'Index'"): + float_frame.columns in float_frame + + def test_tab_completion(self): + # DataFrame whose columns are identifiers shall have them in __dir__. + df = DataFrame([list("abcd"), list("efgh")], columns=list("ABCD")) + for key in list("ABCD"): + assert key in dir(df) + assert isinstance(df.__getitem__("A"), Series) + + # DataFrame whose first-level columns are identifiers shall have + # them in __dir__. + df = DataFrame( + [list("abcd"), list("efgh")], + columns=pd.MultiIndex.from_tuples(list(zip("ABCD", "EFGH"))), + ) + for key in list("ABCD"): + assert key in dir(df) + for key in list("EFGH"): + assert key not in dir(df) + assert isinstance(df.__getitem__("A"), DataFrame) + + def test_display_max_dir_items(self): + # display.max_dir_items increaes the number of columns that are in __dir__. + columns = ["a" + str(i) for i in range(420)] + values = [range(420), range(420)] + df = DataFrame(values, columns=columns) + + # The default value for display.max_dir_items is 100 + assert "a99" in dir(df) + assert "a100" not in dir(df) + + with option_context("display.max_dir_items", 300): + df = DataFrame(values, columns=columns) + assert "a299" in dir(df) + assert "a300" not in dir(df) + + with option_context("display.max_dir_items", None): + df = DataFrame(values, columns=columns) + assert "a419" in dir(df) + + def test_not_hashable(self): + empty_frame = DataFrame() + + df = DataFrame([1]) + msg = "unhashable type: 'DataFrame'" + with pytest.raises(TypeError, match=msg): + hash(df) + with pytest.raises(TypeError, match=msg): + hash(empty_frame) + + @pytest.mark.xfail(using_pyarrow_string_dtype(), reason="surrogates not allowed") + def test_column_name_contains_unicode_surrogate(self): + # GH 25509 + colname = "\ud83d" + df = DataFrame({colname: []}) + # this should not crash + assert colname not in dir(df) + assert df.columns[0] == colname + + def test_new_empty_index(self): + df1 = DataFrame(np.random.default_rng(2).standard_normal((0, 3))) + df2 = DataFrame(np.random.default_rng(2).standard_normal((0, 3))) + df1.index.name = "foo" + assert df2.index.name is None + + def test_get_agg_axis(self, float_frame): + cols = float_frame._get_agg_axis(0) + assert cols is float_frame.columns + + idx = float_frame._get_agg_axis(1) + assert idx is float_frame.index + + msg = r"Axis must be 0 or 1 \(got 2\)" + with pytest.raises(ValueError, match=msg): + float_frame._get_agg_axis(2) + + def test_empty(self, float_frame, float_string_frame): + empty_frame = DataFrame() + assert empty_frame.empty + + assert not float_frame.empty + assert not float_string_frame.empty + + # corner case + df = DataFrame({"A": [1.0, 2.0, 3.0], "B": ["a", "b", "c"]}, index=np.arange(3)) + del df["A"] + assert not df.empty + + def test_len(self, float_frame): + assert len(float_frame) == len(float_frame.index) + + # single block corner case + arr = float_frame[["A", "B"]].values + expected = float_frame.reindex(columns=["A", "B"]).values + tm.assert_almost_equal(arr, expected) + + def test_axis_aliases(self, float_frame): + f = float_frame + + # reg name + expected = f.sum(axis=0) + result = f.sum(axis="index") + tm.assert_series_equal(result, expected) + + expected = f.sum(axis=1) + result = f.sum(axis="columns") + tm.assert_series_equal(result, expected) + + def test_class_axis(self): + # GH 18147 + # no exception and no empty docstring + assert pydoc.getdoc(DataFrame.index) + assert pydoc.getdoc(DataFrame.columns) + + def test_series_put_names(self, float_string_frame): + series = float_string_frame._series + for k, v in series.items(): + assert v.name == k + + def test_empty_nonzero(self): + df = DataFrame([1, 2, 3]) + assert not df.empty + df = DataFrame(index=[1], columns=[1]) + assert not df.empty + df = DataFrame(index=["a", "b"], columns=["c", "d"]).dropna() + assert df.empty + assert df.T.empty + + @pytest.mark.parametrize( + "df", + [ + DataFrame(), + DataFrame(index=[1]), + DataFrame(columns=[1]), + DataFrame({1: []}), + ], + ) + def test_empty_like(self, df): + assert df.empty + assert df.T.empty + + def test_with_datetimelikes(self): + df = DataFrame( + { + "A": date_range("20130101", periods=10), + "B": timedelta_range("1 day", periods=10), + } + ) + t = df.T + + result = t.dtypes.value_counts() + expected = Series({np.dtype("object"): 10}, name="count") + tm.assert_series_equal(result, expected) + + def test_deepcopy(self, float_frame): + cp = deepcopy(float_frame) + cp.loc[0, "A"] = 10 + assert not float_frame.equals(cp) + + def test_inplace_return_self(self): + # GH 1893 + + data = DataFrame( + {"a": ["foo", "bar", "baz", "qux"], "b": [0, 0, 1, 1], "c": [1, 2, 3, 4]} + ) + + def _check_f(base, f): + result = f(base) + assert result is None + + # -----DataFrame----- + + # set_index + f = lambda x: x.set_index("a", inplace=True) + _check_f(data.copy(), f) + + # reset_index + f = lambda x: x.reset_index(inplace=True) + _check_f(data.set_index("a"), f) + + # drop_duplicates + f = lambda x: x.drop_duplicates(inplace=True) + _check_f(data.copy(), f) + + # sort + f = lambda x: x.sort_values("b", inplace=True) + _check_f(data.copy(), f) + + # sort_index + f = lambda x: x.sort_index(inplace=True) + _check_f(data.copy(), f) + + # fillna + f = lambda x: x.fillna(0, inplace=True) + _check_f(data.copy(), f) + + # replace + f = lambda x: x.replace(1, 0, inplace=True) + _check_f(data.copy(), f) + + # rename + f = lambda x: x.rename({1: "foo"}, inplace=True) + _check_f(data.copy(), f) + + # -----Series----- + d = data.copy()["c"] + + # reset_index + f = lambda x: x.reset_index(inplace=True, drop=True) + _check_f(data.set_index("a")["c"], f) + + # fillna + f = lambda x: x.fillna(0, inplace=True) + _check_f(d.copy(), f) + + # replace + f = lambda x: x.replace(1, 0, inplace=True) + _check_f(d.copy(), f) + + # rename + f = lambda x: x.rename({1: "foo"}, inplace=True) + _check_f(d.copy(), f) + + def test_tab_complete_warning(self, ip, frame_or_series): + # GH 16409 + pytest.importorskip("IPython", minversion="6.0.0") + from IPython.core.completer import provisionalcompleter + + if frame_or_series is DataFrame: + code = "from pandas import DataFrame; obj = DataFrame()" + else: + code = "from pandas import Series; obj = Series(dtype=object)" + + ip.run_cell(code) + # GH 31324 newer jedi version raises Deprecation warning; + # appears resolved 2021-02-02 + with tm.assert_produces_warning(None, raise_on_extra_warnings=False): + with provisionalcompleter("ignore"): + list(ip.Completer.completions("obj.", 1)) + + def test_attrs(self): + df = DataFrame({"A": [2, 3]}) + assert df.attrs == {} + df.attrs["version"] = 1 + + result = df.rename(columns=str) + assert result.attrs == {"version": 1} + + def test_attrs_deepcopy(self): + df = DataFrame({"A": [2, 3]}) + assert df.attrs == {} + df.attrs["tags"] = {"spam", "ham"} + + result = df.rename(columns=str) + assert result.attrs == df.attrs + assert result.attrs["tags"] is not df.attrs["tags"] + + @pytest.mark.parametrize("allows_duplicate_labels", [True, False, None]) + def test_set_flags( + self, + allows_duplicate_labels, + frame_or_series, + using_copy_on_write, + warn_copy_on_write, + ): + obj = DataFrame({"A": [1, 2]}) + key = (0, 0) + if frame_or_series is Series: + obj = obj["A"] + key = 0 + + result = obj.set_flags(allows_duplicate_labels=allows_duplicate_labels) + + if allows_duplicate_labels is None: + # We don't update when it's not provided + assert result.flags.allows_duplicate_labels is True + else: + assert result.flags.allows_duplicate_labels is allows_duplicate_labels + + # We made a copy + assert obj is not result + + # We didn't mutate obj + assert obj.flags.allows_duplicate_labels is True + + # But we didn't copy data + if frame_or_series is Series: + assert np.may_share_memory(obj.values, result.values) + else: + assert np.may_share_memory(obj["A"].values, result["A"].values) + + with tm.assert_cow_warning(warn_copy_on_write): + result.iloc[key] = 0 + if using_copy_on_write: + assert obj.iloc[key] == 1 + else: + assert obj.iloc[key] == 0 + # set back to 1 for test below + with tm.assert_cow_warning(warn_copy_on_write): + result.iloc[key] = 1 + + # Now we do copy. + result = obj.set_flags( + copy=True, allows_duplicate_labels=allows_duplicate_labels + ) + result.iloc[key] = 10 + assert obj.iloc[key] == 1 + + def test_constructor_expanddim(self): + # GH#33628 accessing _constructor_expanddim should not raise NotImplementedError + # GH38782 pandas has no container higher than DataFrame (two-dim), so + # DataFrame._constructor_expand_dim, doesn't make sense, so is removed. + df = DataFrame() + + msg = "'DataFrame' object has no attribute '_constructor_expanddim'" + with pytest.raises(AttributeError, match=msg): + df._constructor_expanddim(np.arange(27).reshape(3, 3, 3)) + + def test_inspect_getmembers(self): + # GH38740 + pytest.importorskip("jinja2") + df = DataFrame() + msg = "DataFrame._data is deprecated" + with tm.assert_produces_warning( + DeprecationWarning, match=msg, check_stacklevel=False + ): + inspect.getmembers(df) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_arithmetic.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_arithmetic.py new file mode 100644 index 0000000000000000000000000000000000000000..0593de7556406ab143088a0f450e620653db83a6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_arithmetic.py @@ -0,0 +1,2136 @@ +from collections import deque +from datetime import ( + datetime, + timezone, +) +from enum import Enum +import functools +import operator +import re + +import numpy as np +import pytest + +from pandas._config import using_pyarrow_string_dtype + +import pandas.util._test_decorators as td + +import pandas as pd +from pandas import ( + DataFrame, + Index, + MultiIndex, + Series, +) +import pandas._testing as tm +from pandas.core.computation import expressions as expr +from pandas.tests.frame.common import ( + _check_mixed_float, + _check_mixed_int, +) + + +@pytest.fixture +def simple_frame(): + """ + Fixture for simple 3x3 DataFrame + + Columns are ['one', 'two', 'three'], index is ['a', 'b', 'c']. + + one two three + a 1.0 2.0 3.0 + b 4.0 5.0 6.0 + c 7.0 8.0 9.0 + """ + arr = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) + + return DataFrame(arr, columns=["one", "two", "three"], index=["a", "b", "c"]) + + +@pytest.fixture(autouse=True, params=[0, 100], ids=["numexpr", "python"]) +def switch_numexpr_min_elements(request, monkeypatch): + with monkeypatch.context() as m: + m.setattr(expr, "_MIN_ELEMENTS", request.param) + yield request.param + + +class DummyElement: + def __init__(self, value, dtype) -> None: + self.value = value + self.dtype = np.dtype(dtype) + + def __array__(self, dtype=None, copy=None): + return np.array(self.value, dtype=self.dtype) + + def __str__(self) -> str: + return f"DummyElement({self.value}, {self.dtype})" + + def __repr__(self) -> str: + return str(self) + + def astype(self, dtype, copy=False): + self.dtype = dtype + return self + + def view(self, dtype): + return type(self)(self.value.view(dtype), dtype) + + def any(self, axis=None): + return bool(self.value) + + +# ------------------------------------------------------------------- +# Comparisons + + +class TestFrameComparisons: + # Specifically _not_ flex-comparisons + + def test_comparison_with_categorical_dtype(self): + # GH#12564 + + df = DataFrame({"A": ["foo", "bar", "baz"]}) + exp = DataFrame({"A": [True, False, False]}) + + res = df == "foo" + tm.assert_frame_equal(res, exp) + + # casting to categorical shouldn't affect the result + df["A"] = df["A"].astype("category") + + res = df == "foo" + tm.assert_frame_equal(res, exp) + + def test_frame_in_list(self): + # GH#12689 this should raise at the DataFrame level, not blocks + df = DataFrame( + np.random.default_rng(2).standard_normal((6, 4)), columns=list("ABCD") + ) + msg = "The truth value of a DataFrame is ambiguous" + with pytest.raises(ValueError, match=msg): + df in [None] + + @pytest.mark.parametrize( + "arg, arg2", + [ + [ + { + "a": np.random.default_rng(2).integers(10, size=10), + "b": pd.date_range("20010101", periods=10), + }, + { + "a": np.random.default_rng(2).integers(10, size=10), + "b": np.random.default_rng(2).integers(10, size=10), + }, + ], + [ + { + "a": np.random.default_rng(2).integers(10, size=10), + "b": np.random.default_rng(2).integers(10, size=10), + }, + { + "a": np.random.default_rng(2).integers(10, size=10), + "b": pd.date_range("20010101", periods=10), + }, + ], + [ + { + "a": pd.date_range("20010101", periods=10), + "b": pd.date_range("20010101", periods=10), + }, + { + "a": np.random.default_rng(2).integers(10, size=10), + "b": np.random.default_rng(2).integers(10, size=10), + }, + ], + [ + { + "a": np.random.default_rng(2).integers(10, size=10), + "b": pd.date_range("20010101", periods=10), + }, + { + "a": pd.date_range("20010101", periods=10), + "b": pd.date_range("20010101", periods=10), + }, + ], + ], + ) + def test_comparison_invalid(self, arg, arg2): + # GH4968 + # invalid date/int comparisons + x = DataFrame(arg) + y = DataFrame(arg2) + # we expect the result to match Series comparisons for + # == and !=, inequalities should raise + result = x == y + expected = DataFrame( + {col: x[col] == y[col] for col in x.columns}, + index=x.index, + columns=x.columns, + ) + tm.assert_frame_equal(result, expected) + + result = x != y + expected = DataFrame( + {col: x[col] != y[col] for col in x.columns}, + index=x.index, + columns=x.columns, + ) + tm.assert_frame_equal(result, expected) + + msgs = [ + r"Invalid comparison between dtype=datetime64\[ns\] and ndarray", + "invalid type promotion", + ( + # npdev 1.20.0 + r"The DTypes and " + r" do not have a common DType." + ), + ] + msg = "|".join(msgs) + with pytest.raises(TypeError, match=msg): + x >= y + with pytest.raises(TypeError, match=msg): + x > y + with pytest.raises(TypeError, match=msg): + x < y + with pytest.raises(TypeError, match=msg): + x <= y + + @pytest.mark.parametrize( + "left, right", + [ + ("gt", "lt"), + ("lt", "gt"), + ("ge", "le"), + ("le", "ge"), + ("eq", "eq"), + ("ne", "ne"), + ], + ) + def test_timestamp_compare(self, left, right): + # make sure we can compare Timestamps on the right AND left hand side + # GH#4982 + df = DataFrame( + { + "dates1": pd.date_range("20010101", periods=10), + "dates2": pd.date_range("20010102", periods=10), + "intcol": np.random.default_rng(2).integers(1000000000, size=10), + "floatcol": np.random.default_rng(2).standard_normal(10), + "stringcol": [chr(100 + i) for i in range(10)], + } + ) + df.loc[np.random.default_rng(2).random(len(df)) > 0.5, "dates2"] = pd.NaT + left_f = getattr(operator, left) + right_f = getattr(operator, right) + + # no nats + if left in ["eq", "ne"]: + expected = left_f(df, pd.Timestamp("20010109")) + result = right_f(pd.Timestamp("20010109"), df) + tm.assert_frame_equal(result, expected) + else: + msg = ( + "'(<|>)=?' not supported between " + "instances of 'numpy.ndarray' and 'Timestamp'" + ) + with pytest.raises(TypeError, match=msg): + left_f(df, pd.Timestamp("20010109")) + with pytest.raises(TypeError, match=msg): + right_f(pd.Timestamp("20010109"), df) + # nats + if left in ["eq", "ne"]: + expected = left_f(df, pd.Timestamp("nat")) + result = right_f(pd.Timestamp("nat"), df) + tm.assert_frame_equal(result, expected) + else: + msg = ( + "'(<|>)=?' not supported between " + "instances of 'numpy.ndarray' and 'NaTType'" + ) + with pytest.raises(TypeError, match=msg): + left_f(df, pd.Timestamp("nat")) + with pytest.raises(TypeError, match=msg): + right_f(pd.Timestamp("nat"), df) + + @pytest.mark.xfail( + using_pyarrow_string_dtype(), reason="can't compare string and int" + ) + def test_mixed_comparison(self): + # GH#13128, GH#22163 != datetime64 vs non-dt64 should be False, + # not raise TypeError + # (this appears to be fixed before GH#22163, not sure when) + df = DataFrame([["1989-08-01", 1], ["1989-08-01", 2]]) + other = DataFrame([["a", "b"], ["c", "d"]]) + + result = df == other + assert not result.any().any() + + result = df != other + assert result.all().all() + + def test_df_boolean_comparison_error(self): + # GH#4576, GH#22880 + # comparing DataFrame against list/tuple with len(obj) matching + # len(df.columns) is supported as of GH#22800 + df = DataFrame(np.arange(6).reshape((3, 2))) + + expected = DataFrame([[False, False], [True, False], [False, False]]) + + result = df == (2, 2) + tm.assert_frame_equal(result, expected) + + result = df == [2, 2] + tm.assert_frame_equal(result, expected) + + def test_df_float_none_comparison(self): + df = DataFrame( + np.random.default_rng(2).standard_normal((8, 3)), + index=range(8), + columns=["A", "B", "C"], + ) + + result = df.__eq__(None) + assert not result.any().any() + + def test_df_string_comparison(self): + df = DataFrame([{"a": 1, "b": "foo"}, {"a": 2, "b": "bar"}]) + mask_a = df.a > 1 + tm.assert_frame_equal(df[mask_a], df.loc[1:1, :]) + tm.assert_frame_equal(df[-mask_a], df.loc[0:0, :]) + + mask_b = df.b == "foo" + tm.assert_frame_equal(df[mask_b], df.loc[0:0, :]) + tm.assert_frame_equal(df[-mask_b], df.loc[1:1, :]) + + +class TestFrameFlexComparisons: + # TODO: test_bool_flex_frame needs a better name + @pytest.mark.parametrize("op", ["eq", "ne", "gt", "lt", "ge", "le"]) + def test_bool_flex_frame(self, op): + data = np.random.default_rng(2).standard_normal((5, 3)) + other_data = np.random.default_rng(2).standard_normal((5, 3)) + df = DataFrame(data) + other = DataFrame(other_data) + ndim_5 = np.ones(df.shape + (1, 3)) + + # DataFrame + assert df.eq(df).values.all() + assert not df.ne(df).values.any() + f = getattr(df, op) + o = getattr(operator, op) + # No NAs + tm.assert_frame_equal(f(other), o(df, other)) + # Unaligned + part_o = other.loc[3:, 1:].copy() + rs = f(part_o) + xp = o(df, part_o.reindex(index=df.index, columns=df.columns)) + tm.assert_frame_equal(rs, xp) + # ndarray + tm.assert_frame_equal(f(other.values), o(df, other.values)) + # scalar + tm.assert_frame_equal(f(0), o(df, 0)) + # NAs + msg = "Unable to coerce to Series/DataFrame" + tm.assert_frame_equal(f(np.nan), o(df, np.nan)) + with pytest.raises(ValueError, match=msg): + f(ndim_5) + + @pytest.mark.parametrize("box", [np.array, Series]) + def test_bool_flex_series(self, box): + # Series + # list/tuple + data = np.random.default_rng(2).standard_normal((5, 3)) + df = DataFrame(data) + idx_ser = box(np.random.default_rng(2).standard_normal(5)) + col_ser = box(np.random.default_rng(2).standard_normal(3)) + + idx_eq = df.eq(idx_ser, axis=0) + col_eq = df.eq(col_ser) + idx_ne = df.ne(idx_ser, axis=0) + col_ne = df.ne(col_ser) + tm.assert_frame_equal(col_eq, df == Series(col_ser)) + tm.assert_frame_equal(col_eq, -col_ne) + tm.assert_frame_equal(idx_eq, -idx_ne) + tm.assert_frame_equal(idx_eq, df.T.eq(idx_ser).T) + tm.assert_frame_equal(col_eq, df.eq(list(col_ser))) + tm.assert_frame_equal(idx_eq, df.eq(Series(idx_ser), axis=0)) + tm.assert_frame_equal(idx_eq, df.eq(list(idx_ser), axis=0)) + + idx_gt = df.gt(idx_ser, axis=0) + col_gt = df.gt(col_ser) + idx_le = df.le(idx_ser, axis=0) + col_le = df.le(col_ser) + + tm.assert_frame_equal(col_gt, df > Series(col_ser)) + tm.assert_frame_equal(col_gt, -col_le) + tm.assert_frame_equal(idx_gt, -idx_le) + tm.assert_frame_equal(idx_gt, df.T.gt(idx_ser).T) + + idx_ge = df.ge(idx_ser, axis=0) + col_ge = df.ge(col_ser) + idx_lt = df.lt(idx_ser, axis=0) + col_lt = df.lt(col_ser) + tm.assert_frame_equal(col_ge, df >= Series(col_ser)) + tm.assert_frame_equal(col_ge, -col_lt) + tm.assert_frame_equal(idx_ge, -idx_lt) + tm.assert_frame_equal(idx_ge, df.T.ge(idx_ser).T) + + idx_ser = Series(np.random.default_rng(2).standard_normal(5)) + col_ser = Series(np.random.default_rng(2).standard_normal(3)) + + def test_bool_flex_frame_na(self): + df = DataFrame(np.random.default_rng(2).standard_normal((5, 3))) + # NA + df.loc[0, 0] = np.nan + rs = df.eq(df) + assert not rs.loc[0, 0] + rs = df.ne(df) + assert rs.loc[0, 0] + rs = df.gt(df) + assert not rs.loc[0, 0] + rs = df.lt(df) + assert not rs.loc[0, 0] + rs = df.ge(df) + assert not rs.loc[0, 0] + rs = df.le(df) + assert not rs.loc[0, 0] + + def test_bool_flex_frame_complex_dtype(self): + # complex + arr = np.array([np.nan, 1, 6, np.nan]) + arr2 = np.array([2j, np.nan, 7, None]) + df = DataFrame({"a": arr}) + df2 = DataFrame({"a": arr2}) + + msg = "|".join( + [ + "'>' not supported between instances of '.*' and 'complex'", + r"unorderable types: .*complex\(\)", # PY35 + ] + ) + with pytest.raises(TypeError, match=msg): + # inequalities are not well-defined for complex numbers + df.gt(df2) + with pytest.raises(TypeError, match=msg): + # regression test that we get the same behavior for Series + df["a"].gt(df2["a"]) + with pytest.raises(TypeError, match=msg): + # Check that we match numpy behavior here + df.values > df2.values + + rs = df.ne(df2) + assert rs.values.all() + + arr3 = np.array([2j, np.nan, None]) + df3 = DataFrame({"a": arr3}) + + with pytest.raises(TypeError, match=msg): + # inequalities are not well-defined for complex numbers + df3.gt(2j) + with pytest.raises(TypeError, match=msg): + # regression test that we get the same behavior for Series + df3["a"].gt(2j) + with pytest.raises(TypeError, match=msg): + # Check that we match numpy behavior here + df3.values > 2j + + def test_bool_flex_frame_object_dtype(self): + # corner, dtype=object + df1 = DataFrame({"col": ["foo", np.nan, "bar"]}, dtype=object) + df2 = DataFrame({"col": ["foo", datetime.now(), "bar"]}, dtype=object) + result = df1.ne(df2) + exp = DataFrame({"col": [False, True, False]}) + tm.assert_frame_equal(result, exp) + + def test_flex_comparison_nat(self): + # GH 15697, GH 22163 df.eq(pd.NaT) should behave like df == pd.NaT, + # and _definitely_ not be NaN + df = DataFrame([pd.NaT]) + + result = df == pd.NaT + # result.iloc[0, 0] is a np.bool_ object + assert result.iloc[0, 0].item() is False + + result = df.eq(pd.NaT) + assert result.iloc[0, 0].item() is False + + result = df != pd.NaT + assert result.iloc[0, 0].item() is True + + result = df.ne(pd.NaT) + assert result.iloc[0, 0].item() is True + + @pytest.mark.parametrize("opname", ["eq", "ne", "gt", "lt", "ge", "le"]) + def test_df_flex_cmp_constant_return_types(self, opname): + # GH 15077, non-empty DataFrame + df = DataFrame({"x": [1, 2, 3], "y": [1.0, 2.0, 3.0]}) + const = 2 + + result = getattr(df, opname)(const).dtypes.value_counts() + tm.assert_series_equal( + result, Series([2], index=[np.dtype(bool)], name="count") + ) + + @pytest.mark.parametrize("opname", ["eq", "ne", "gt", "lt", "ge", "le"]) + def test_df_flex_cmp_constant_return_types_empty(self, opname): + # GH 15077 empty DataFrame + df = DataFrame({"x": [1, 2, 3], "y": [1.0, 2.0, 3.0]}) + const = 2 + + empty = df.iloc[:0] + result = getattr(empty, opname)(const).dtypes.value_counts() + tm.assert_series_equal( + result, Series([2], index=[np.dtype(bool)], name="count") + ) + + def test_df_flex_cmp_ea_dtype_with_ndarray_series(self): + ii = pd.IntervalIndex.from_breaks([1, 2, 3]) + df = DataFrame({"A": ii, "B": ii}) + + ser = Series([0, 0]) + res = df.eq(ser, axis=0) + + expected = DataFrame({"A": [False, False], "B": [False, False]}) + tm.assert_frame_equal(res, expected) + + ser2 = Series([1, 2], index=["A", "B"]) + res2 = df.eq(ser2, axis=1) + tm.assert_frame_equal(res2, expected) + + +# ------------------------------------------------------------------- +# Arithmetic + + +class TestFrameFlexArithmetic: + def test_floordiv_axis0(self): + # make sure we df.floordiv(ser, axis=0) matches column-wise result + arr = np.arange(3) + ser = Series(arr) + df = DataFrame({"A": ser, "B": ser}) + + result = df.floordiv(ser, axis=0) + + expected = DataFrame({col: df[col] // ser for col in df.columns}) + + tm.assert_frame_equal(result, expected) + + result2 = df.floordiv(ser.values, axis=0) + tm.assert_frame_equal(result2, expected) + + def test_df_add_td64_columnwise(self): + # GH 22534 Check that column-wise addition broadcasts correctly + dti = pd.date_range("2016-01-01", periods=10) + tdi = pd.timedelta_range("1", periods=10) + tser = Series(tdi) + df = DataFrame({0: dti, 1: tdi}) + + result = df.add(tser, axis=0) + expected = DataFrame({0: dti + tdi, 1: tdi + tdi}) + tm.assert_frame_equal(result, expected) + + def test_df_add_flex_filled_mixed_dtypes(self): + # GH 19611 + dti = pd.date_range("2016-01-01", periods=3) + ser = Series(["1 Day", "NaT", "2 Days"], dtype="timedelta64[ns]") + df = DataFrame({"A": dti, "B": ser}) + other = DataFrame({"A": ser, "B": ser}) + fill = pd.Timedelta(days=1).to_timedelta64() + result = df.add(other, fill_value=fill) + + expected = DataFrame( + { + "A": Series( + ["2016-01-02", "2016-01-03", "2016-01-05"], dtype="datetime64[ns]" + ), + "B": ser * 2, + } + ) + tm.assert_frame_equal(result, expected) + + def test_arith_flex_frame( + self, all_arithmetic_operators, float_frame, mixed_float_frame + ): + # one instance of parametrized fixture + op = all_arithmetic_operators + + def f(x, y): + # r-versions not in operator-stdlib; get op without "r" and invert + if op.startswith("__r"): + return getattr(operator, op.replace("__r", "__"))(y, x) + return getattr(operator, op)(x, y) + + result = getattr(float_frame, op)(2 * float_frame) + expected = f(float_frame, 2 * float_frame) + tm.assert_frame_equal(result, expected) + + # vs mix float + result = getattr(mixed_float_frame, op)(2 * mixed_float_frame) + expected = f(mixed_float_frame, 2 * mixed_float_frame) + tm.assert_frame_equal(result, expected) + _check_mixed_float(result, dtype={"C": None}) + + @pytest.mark.parametrize("op", ["__add__", "__sub__", "__mul__"]) + def test_arith_flex_frame_mixed( + self, + op, + int_frame, + mixed_int_frame, + mixed_float_frame, + switch_numexpr_min_elements, + ): + f = getattr(operator, op) + + # vs mix int + result = getattr(mixed_int_frame, op)(2 + mixed_int_frame) + expected = f(mixed_int_frame, 2 + mixed_int_frame) + + # no overflow in the uint + dtype = None + if op in ["__sub__"]: + dtype = {"B": "uint64", "C": None} + elif op in ["__add__", "__mul__"]: + dtype = {"C": None} + if expr.USE_NUMEXPR and switch_numexpr_min_elements == 0: + # when using numexpr, the casting rules are slightly different: + # in the `2 + mixed_int_frame` operation, int32 column becomes + # and int64 column (not preserving dtype in operation with Python + # scalar), and then the int32/int64 combo results in int64 result + dtype["A"] = (2 + mixed_int_frame)["A"].dtype + tm.assert_frame_equal(result, expected) + _check_mixed_int(result, dtype=dtype) + + # vs mix float + result = getattr(mixed_float_frame, op)(2 * mixed_float_frame) + expected = f(mixed_float_frame, 2 * mixed_float_frame) + tm.assert_frame_equal(result, expected) + _check_mixed_float(result, dtype={"C": None}) + + # vs plain int + result = getattr(int_frame, op)(2 * int_frame) + expected = f(int_frame, 2 * int_frame) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("dim", range(3, 6)) + def test_arith_flex_frame_raise(self, all_arithmetic_operators, float_frame, dim): + # one instance of parametrized fixture + op = all_arithmetic_operators + + # Check that arrays with dim >= 3 raise + arr = np.ones((1,) * dim) + msg = "Unable to coerce to Series/DataFrame" + with pytest.raises(ValueError, match=msg): + getattr(float_frame, op)(arr) + + def test_arith_flex_frame_corner(self, float_frame): + const_add = float_frame.add(1) + tm.assert_frame_equal(const_add, float_frame + 1) + + # corner cases + result = float_frame.add(float_frame[:0]) + expected = float_frame.sort_index() * np.nan + tm.assert_frame_equal(result, expected) + + result = float_frame[:0].add(float_frame) + expected = float_frame.sort_index() * np.nan + tm.assert_frame_equal(result, expected) + + with pytest.raises(NotImplementedError, match="fill_value"): + float_frame.add(float_frame.iloc[0], fill_value=3) + + with pytest.raises(NotImplementedError, match="fill_value"): + float_frame.add(float_frame.iloc[0], axis="index", fill_value=3) + + @pytest.mark.parametrize("op", ["add", "sub", "mul", "mod"]) + def test_arith_flex_series_ops(self, simple_frame, op): + # after arithmetic refactor, add truediv here + df = simple_frame + + row = df.xs("a") + col = df["two"] + f = getattr(df, op) + op = getattr(operator, op) + tm.assert_frame_equal(f(row), op(df, row)) + tm.assert_frame_equal(f(col, axis=0), op(df.T, col).T) + + def test_arith_flex_series(self, simple_frame): + df = simple_frame + + row = df.xs("a") + col = df["two"] + # special case for some reason + tm.assert_frame_equal(df.add(row, axis=None), df + row) + + # cases which will be refactored after big arithmetic refactor + tm.assert_frame_equal(df.div(row), df / row) + tm.assert_frame_equal(df.div(col, axis=0), (df.T / col).T) + + @pytest.mark.parametrize("dtype", ["int64", "float64"]) + def test_arith_flex_series_broadcasting(self, dtype): + # broadcasting issue in GH 7325 + df = DataFrame(np.arange(3 * 2).reshape((3, 2)), dtype=dtype) + expected = DataFrame([[np.nan, np.inf], [1.0, 1.5], [1.0, 1.25]]) + result = df.div(df[0], axis="index") + tm.assert_frame_equal(result, expected) + + def test_arith_flex_zero_len_raises(self): + # GH 19522 passing fill_value to frame flex arith methods should + # raise even in the zero-length special cases + ser_len0 = Series([], dtype=object) + df_len0 = DataFrame(columns=["A", "B"]) + df = DataFrame([[1, 2], [3, 4]], columns=["A", "B"]) + + with pytest.raises(NotImplementedError, match="fill_value"): + df.add(ser_len0, fill_value="E") + + with pytest.raises(NotImplementedError, match="fill_value"): + df_len0.sub(df["A"], axis=None, fill_value=3) + + def test_flex_add_scalar_fill_value(self): + # GH#12723 + dat = np.array([0, 1, np.nan, 3, 4, 5], dtype="float") + df = DataFrame({"foo": dat}, index=range(6)) + + exp = df.fillna(0).add(2) + res = df.add(2, fill_value=0) + tm.assert_frame_equal(res, exp) + + def test_sub_alignment_with_duplicate_index(self): + # GH#5185 dup aligning operations should work + df1 = DataFrame([1, 2, 3, 4, 5], index=[1, 2, 1, 2, 3]) + df2 = DataFrame([1, 2, 3], index=[1, 2, 3]) + expected = DataFrame([0, 2, 0, 2, 2], index=[1, 1, 2, 2, 3]) + result = df1.sub(df2) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("op", ["__add__", "__mul__", "__sub__", "__truediv__"]) + def test_arithmetic_with_duplicate_columns(self, op): + # operations + df = DataFrame({"A": np.arange(10), "B": np.random.default_rng(2).random(10)}) + expected = getattr(df, op)(df) + expected.columns = ["A", "A"] + df.columns = ["A", "A"] + result = getattr(df, op)(df) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("level", [0, None]) + def test_broadcast_multiindex(self, level): + # GH34388 + df1 = DataFrame({"A": [0, 1, 2], "B": [1, 2, 3]}) + df1.columns = df1.columns.set_names("L1") + + df2 = DataFrame({("A", "C"): [0, 0, 0], ("A", "D"): [0, 0, 0]}) + df2.columns = df2.columns.set_names(["L1", "L2"]) + + result = df1.add(df2, level=level) + expected = DataFrame({("A", "C"): [0, 1, 2], ("A", "D"): [0, 1, 2]}) + expected.columns = expected.columns.set_names(["L1", "L2"]) + + tm.assert_frame_equal(result, expected) + + def test_frame_multiindex_operations(self): + # GH 43321 + df = DataFrame( + {2010: [1, 2, 3], 2020: [3, 4, 5]}, + index=MultiIndex.from_product( + [["a"], ["b"], [0, 1, 2]], names=["scen", "mod", "id"] + ), + ) + + series = Series( + [0.4], + index=MultiIndex.from_product([["b"], ["a"]], names=["mod", "scen"]), + ) + + expected = DataFrame( + {2010: [1.4, 2.4, 3.4], 2020: [3.4, 4.4, 5.4]}, + index=MultiIndex.from_product( + [["a"], ["b"], [0, 1, 2]], names=["scen", "mod", "id"] + ), + ) + result = df.add(series, axis=0) + + tm.assert_frame_equal(result, expected) + + def test_frame_multiindex_operations_series_index_to_frame_index(self): + # GH 43321 + df = DataFrame( + {2010: [1], 2020: [3]}, + index=MultiIndex.from_product([["a"], ["b"]], names=["scen", "mod"]), + ) + + series = Series( + [10.0, 20.0, 30.0], + index=MultiIndex.from_product( + [["a"], ["b"], [0, 1, 2]], names=["scen", "mod", "id"] + ), + ) + + expected = DataFrame( + {2010: [11.0, 21, 31.0], 2020: [13.0, 23.0, 33.0]}, + index=MultiIndex.from_product( + [["a"], ["b"], [0, 1, 2]], names=["scen", "mod", "id"] + ), + ) + result = df.add(series, axis=0) + + tm.assert_frame_equal(result, expected) + + def test_frame_multiindex_operations_no_align(self): + df = DataFrame( + {2010: [1, 2, 3], 2020: [3, 4, 5]}, + index=MultiIndex.from_product( + [["a"], ["b"], [0, 1, 2]], names=["scen", "mod", "id"] + ), + ) + + series = Series( + [0.4], + index=MultiIndex.from_product([["c"], ["a"]], names=["mod", "scen"]), + ) + + expected = DataFrame( + {2010: np.nan, 2020: np.nan}, + index=MultiIndex.from_tuples( + [ + ("a", "b", 0), + ("a", "b", 1), + ("a", "b", 2), + ("a", "c", np.nan), + ], + names=["scen", "mod", "id"], + ), + ) + result = df.add(series, axis=0) + + tm.assert_frame_equal(result, expected) + + def test_frame_multiindex_operations_part_align(self): + df = DataFrame( + {2010: [1, 2, 3], 2020: [3, 4, 5]}, + index=MultiIndex.from_tuples( + [ + ("a", "b", 0), + ("a", "b", 1), + ("a", "c", 2), + ], + names=["scen", "mod", "id"], + ), + ) + + series = Series( + [0.4], + index=MultiIndex.from_product([["b"], ["a"]], names=["mod", "scen"]), + ) + + expected = DataFrame( + {2010: [1.4, 2.4, np.nan], 2020: [3.4, 4.4, np.nan]}, + index=MultiIndex.from_tuples( + [ + ("a", "b", 0), + ("a", "b", 1), + ("a", "c", 2), + ], + names=["scen", "mod", "id"], + ), + ) + result = df.add(series, axis=0) + + tm.assert_frame_equal(result, expected) + + +class TestFrameArithmetic: + def test_td64_op_nat_casting(self): + # Make sure we don't accidentally treat timedelta64(NaT) as datetime64 + # when calling dispatch_to_series in DataFrame arithmetic + ser = Series(["NaT", "NaT"], dtype="timedelta64[ns]") + df = DataFrame([[1, 2], [3, 4]]) + + result = df * ser + expected = DataFrame({0: ser, 1: ser}) + tm.assert_frame_equal(result, expected) + + def test_df_add_2d_array_rowlike_broadcasts(self): + # GH#23000 + arr = np.arange(6).reshape(3, 2) + df = DataFrame(arr, columns=[True, False], index=["A", "B", "C"]) + + rowlike = arr[[1], :] # shape --> (1, ncols) + assert rowlike.shape == (1, df.shape[1]) + + expected = DataFrame( + [[2, 4], [4, 6], [6, 8]], + columns=df.columns, + index=df.index, + # specify dtype explicitly to avoid failing + # on 32bit builds + dtype=arr.dtype, + ) + result = df + rowlike + tm.assert_frame_equal(result, expected) + result = rowlike + df + tm.assert_frame_equal(result, expected) + + def test_df_add_2d_array_collike_broadcasts(self): + # GH#23000 + arr = np.arange(6).reshape(3, 2) + df = DataFrame(arr, columns=[True, False], index=["A", "B", "C"]) + + collike = arr[:, [1]] # shape --> (nrows, 1) + assert collike.shape == (df.shape[0], 1) + + expected = DataFrame( + [[1, 2], [5, 6], [9, 10]], + columns=df.columns, + index=df.index, + # specify dtype explicitly to avoid failing + # on 32bit builds + dtype=arr.dtype, + ) + result = df + collike + tm.assert_frame_equal(result, expected) + result = collike + df + tm.assert_frame_equal(result, expected) + + def test_df_arith_2d_array_rowlike_broadcasts( + self, request, all_arithmetic_operators, using_array_manager + ): + # GH#23000 + opname = all_arithmetic_operators + + if using_array_manager and opname in ("__rmod__", "__rfloordiv__"): + # TODO(ArrayManager) decide on dtypes + td.mark_array_manager_not_yet_implemented(request) + + arr = np.arange(6).reshape(3, 2) + df = DataFrame(arr, columns=[True, False], index=["A", "B", "C"]) + + rowlike = arr[[1], :] # shape --> (1, ncols) + assert rowlike.shape == (1, df.shape[1]) + + exvals = [ + getattr(df.loc["A"], opname)(rowlike.squeeze()), + getattr(df.loc["B"], opname)(rowlike.squeeze()), + getattr(df.loc["C"], opname)(rowlike.squeeze()), + ] + + expected = DataFrame(exvals, columns=df.columns, index=df.index) + + result = getattr(df, opname)(rowlike) + tm.assert_frame_equal(result, expected) + + def test_df_arith_2d_array_collike_broadcasts( + self, request, all_arithmetic_operators, using_array_manager + ): + # GH#23000 + opname = all_arithmetic_operators + + if using_array_manager and opname in ("__rmod__", "__rfloordiv__"): + # TODO(ArrayManager) decide on dtypes + td.mark_array_manager_not_yet_implemented(request) + + arr = np.arange(6).reshape(3, 2) + df = DataFrame(arr, columns=[True, False], index=["A", "B", "C"]) + + collike = arr[:, [1]] # shape --> (nrows, 1) + assert collike.shape == (df.shape[0], 1) + + exvals = { + True: getattr(df[True], opname)(collike.squeeze()), + False: getattr(df[False], opname)(collike.squeeze()), + } + + dtype = None + if opname in ["__rmod__", "__rfloordiv__"]: + # Series ops may return mixed int/float dtypes in cases where + # DataFrame op will return all-float. So we upcast `expected` + dtype = np.common_type(*(x.values for x in exvals.values())) + + expected = DataFrame(exvals, columns=df.columns, index=df.index, dtype=dtype) + + result = getattr(df, opname)(collike) + tm.assert_frame_equal(result, expected) + + def test_df_bool_mul_int(self): + # GH 22047, GH 22163 multiplication by 1 should result in int dtype, + # not object dtype + df = DataFrame([[False, True], [False, False]]) + result = df * 1 + + # On appveyor this comes back as np.int32 instead of np.int64, + # so we check dtype.kind instead of just dtype + kinds = result.dtypes.apply(lambda x: x.kind) + assert (kinds == "i").all() + + result = 1 * df + kinds = result.dtypes.apply(lambda x: x.kind) + assert (kinds == "i").all() + + def test_arith_mixed(self): + left = DataFrame({"A": ["a", "b", "c"], "B": [1, 2, 3]}) + + result = left + left + expected = DataFrame({"A": ["aa", "bb", "cc"], "B": [2, 4, 6]}) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("col", ["A", "B"]) + def test_arith_getitem_commute(self, all_arithmetic_functions, col): + df = DataFrame({"A": [1.1, 3.3], "B": [2.5, -3.9]}) + result = all_arithmetic_functions(df, 1)[col] + expected = all_arithmetic_functions(df[col], 1) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "values", [[1, 2], (1, 2), np.array([1, 2]), range(1, 3), deque([1, 2])] + ) + def test_arith_alignment_non_pandas_object(self, values): + # GH#17901 + df = DataFrame({"A": [1, 1], "B": [1, 1]}) + expected = DataFrame({"A": [2, 2], "B": [3, 3]}) + result = df + values + tm.assert_frame_equal(result, expected) + + def test_arith_non_pandas_object(self): + df = DataFrame( + np.arange(1, 10, dtype="f8").reshape(3, 3), + columns=["one", "two", "three"], + index=["a", "b", "c"], + ) + + val1 = df.xs("a").values + added = DataFrame(df.values + val1, index=df.index, columns=df.columns) + tm.assert_frame_equal(df + val1, added) + + added = DataFrame((df.values.T + val1).T, index=df.index, columns=df.columns) + tm.assert_frame_equal(df.add(val1, axis=0), added) + + val2 = list(df["two"]) + + added = DataFrame(df.values + val2, index=df.index, columns=df.columns) + tm.assert_frame_equal(df + val2, added) + + added = DataFrame((df.values.T + val2).T, index=df.index, columns=df.columns) + tm.assert_frame_equal(df.add(val2, axis="index"), added) + + val3 = np.random.default_rng(2).random(df.shape) + added = DataFrame(df.values + val3, index=df.index, columns=df.columns) + tm.assert_frame_equal(df.add(val3), added) + + def test_operations_with_interval_categories_index(self, all_arithmetic_operators): + # GH#27415 + op = all_arithmetic_operators + ind = pd.CategoricalIndex(pd.interval_range(start=0.0, end=2.0)) + data = [1, 2] + df = DataFrame([data], columns=ind) + num = 10 + result = getattr(df, op)(num) + expected = DataFrame([[getattr(n, op)(num) for n in data]], columns=ind) + tm.assert_frame_equal(result, expected) + + def test_frame_with_frame_reindex(self): + # GH#31623 + df = DataFrame( + { + "foo": [pd.Timestamp("2019"), pd.Timestamp("2020")], + "bar": [pd.Timestamp("2018"), pd.Timestamp("2021")], + }, + columns=["foo", "bar"], + dtype="M8[ns]", + ) + df2 = df[["foo"]] + + result = df - df2 + + expected = DataFrame( + {"foo": [pd.Timedelta(0), pd.Timedelta(0)], "bar": [np.nan, np.nan]}, + columns=["bar", "foo"], + ) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "value, dtype", + [ + (1, "i8"), + (1.0, "f8"), + (2**63, "f8"), + (1j, "complex128"), + (2**63, "complex128"), + (True, "bool"), + (np.timedelta64(20, "ns"), " b + tm.assert_frame_equal(result, expected) + + result = df.values > b + tm.assert_numpy_array_equal(result, expected.values) + + msg1d = "Unable to coerce to Series, length must be 2: given 3" + msg2d = "Unable to coerce to DataFrame, shape must be" + msg2db = "operands could not be broadcast together with shapes" + with pytest.raises(ValueError, match=msg1d): + # wrong shape + df > lst + + with pytest.raises(ValueError, match=msg1d): + # wrong shape + df > tup + + # broadcasts like ndarray (GH#23000) + result = df > b_r + tm.assert_frame_equal(result, expected) + + result = df.values > b_r + tm.assert_numpy_array_equal(result, expected.values) + + with pytest.raises(ValueError, match=msg2d): + df > b_c + + with pytest.raises(ValueError, match=msg2db): + df.values > b_c + + # == + expected = DataFrame([[False, False], [True, False], [False, False]]) + result = df == b + tm.assert_frame_equal(result, expected) + + with pytest.raises(ValueError, match=msg1d): + df == lst + + with pytest.raises(ValueError, match=msg1d): + df == tup + + # broadcasts like ndarray (GH#23000) + result = df == b_r + tm.assert_frame_equal(result, expected) + + result = df.values == b_r + tm.assert_numpy_array_equal(result, expected.values) + + with pytest.raises(ValueError, match=msg2d): + df == b_c + + assert df.values.shape != b_c.shape + + # with alignment + df = DataFrame( + np.arange(6).reshape((3, 2)), columns=list("AB"), index=list("abc") + ) + expected.index = df.index + expected.columns = df.columns + + with pytest.raises(ValueError, match=msg1d): + df == lst + + with pytest.raises(ValueError, match=msg1d): + df == tup + + def test_inplace_ops_alignment(self): + # inplace ops / ops alignment + # GH 8511 + + columns = list("abcdefg") + X_orig = DataFrame( + np.arange(10 * len(columns)).reshape(-1, len(columns)), + columns=columns, + index=range(10), + ) + Z = 100 * X_orig.iloc[:, 1:-1].copy() + block1 = list("bedcf") + subs = list("bcdef") + + # add + X = X_orig.copy() + result1 = (X[block1] + Z).reindex(columns=subs) + + X[block1] += Z + result2 = X.reindex(columns=subs) + + X = X_orig.copy() + result3 = (X[block1] + Z[block1]).reindex(columns=subs) + + X[block1] += Z[block1] + result4 = X.reindex(columns=subs) + + tm.assert_frame_equal(result1, result2) + tm.assert_frame_equal(result1, result3) + tm.assert_frame_equal(result1, result4) + + # sub + X = X_orig.copy() + result1 = (X[block1] - Z).reindex(columns=subs) + + X[block1] -= Z + result2 = X.reindex(columns=subs) + + X = X_orig.copy() + result3 = (X[block1] - Z[block1]).reindex(columns=subs) + + X[block1] -= Z[block1] + result4 = X.reindex(columns=subs) + + tm.assert_frame_equal(result1, result2) + tm.assert_frame_equal(result1, result3) + tm.assert_frame_equal(result1, result4) + + def test_inplace_ops_identity(self): + # GH 5104 + # make sure that we are actually changing the object + s_orig = Series([1, 2, 3]) + df_orig = DataFrame( + np.random.default_rng(2).integers(0, 5, size=10).reshape(-1, 5) + ) + + # no dtype change + s = s_orig.copy() + s2 = s + s += 1 + tm.assert_series_equal(s, s2) + tm.assert_series_equal(s_orig + 1, s) + assert s is s2 + assert s._mgr is s2._mgr + + df = df_orig.copy() + df2 = df + df += 1 + tm.assert_frame_equal(df, df2) + tm.assert_frame_equal(df_orig + 1, df) + assert df is df2 + assert df._mgr is df2._mgr + + # dtype change + s = s_orig.copy() + s2 = s + s += 1.5 + tm.assert_series_equal(s, s2) + tm.assert_series_equal(s_orig + 1.5, s) + + df = df_orig.copy() + df2 = df + df += 1.5 + tm.assert_frame_equal(df, df2) + tm.assert_frame_equal(df_orig + 1.5, df) + assert df is df2 + assert df._mgr is df2._mgr + + # mixed dtype + arr = np.random.default_rng(2).integers(0, 10, size=5) + df_orig = DataFrame({"A": arr.copy(), "B": "foo"}) + df = df_orig.copy() + df2 = df + df["A"] += 1 + expected = DataFrame({"A": arr.copy() + 1, "B": "foo"}) + tm.assert_frame_equal(df, expected) + tm.assert_frame_equal(df2, expected) + assert df._mgr is df2._mgr + + df = df_orig.copy() + df2 = df + df["A"] += 1.5 + expected = DataFrame({"A": arr.copy() + 1.5, "B": "foo"}) + tm.assert_frame_equal(df, expected) + tm.assert_frame_equal(df2, expected) + assert df._mgr is df2._mgr + + @pytest.mark.parametrize( + "op", + [ + "add", + "and", + pytest.param( + "div", + marks=pytest.mark.xfail( + raises=AttributeError, reason="__idiv__ not implemented" + ), + ), + "floordiv", + "mod", + "mul", + "or", + "pow", + "sub", + "truediv", + "xor", + ], + ) + def test_inplace_ops_identity2(self, op): + df = DataFrame({"a": [1.0, 2.0, 3.0], "b": [1, 2, 3]}) + + operand = 2 + if op in ("and", "or", "xor"): + # cannot use floats for boolean ops + df["a"] = [True, False, True] + + df_copy = df.copy() + iop = f"__i{op}__" + op = f"__{op}__" + + # no id change and value is correct + getattr(df, iop)(operand) + expected = getattr(df_copy, op)(operand) + tm.assert_frame_equal(df, expected) + expected = id(df) + assert id(df) == expected + + @pytest.mark.parametrize( + "val", + [ + [1, 2, 3], + (1, 2, 3), + np.array([1, 2, 3], dtype=np.int64), + range(1, 4), + ], + ) + def test_alignment_non_pandas(self, val): + index = ["A", "B", "C"] + columns = ["X", "Y", "Z"] + df = DataFrame( + np.random.default_rng(2).standard_normal((3, 3)), + index=index, + columns=columns, + ) + + align = DataFrame._align_for_op + + expected = DataFrame({"X": val, "Y": val, "Z": val}, index=df.index) + tm.assert_frame_equal(align(df, val, axis=0)[1], expected) + + expected = DataFrame( + {"X": [1, 1, 1], "Y": [2, 2, 2], "Z": [3, 3, 3]}, index=df.index + ) + tm.assert_frame_equal(align(df, val, axis=1)[1], expected) + + @pytest.mark.parametrize("val", [[1, 2], (1, 2), np.array([1, 2]), range(1, 3)]) + def test_alignment_non_pandas_length_mismatch(self, val): + index = ["A", "B", "C"] + columns = ["X", "Y", "Z"] + df = DataFrame( + np.random.default_rng(2).standard_normal((3, 3)), + index=index, + columns=columns, + ) + + align = DataFrame._align_for_op + # length mismatch + msg = "Unable to coerce to Series, length must be 3: given 2" + with pytest.raises(ValueError, match=msg): + align(df, val, axis=0) + + with pytest.raises(ValueError, match=msg): + align(df, val, axis=1) + + def test_alignment_non_pandas_index_columns(self): + index = ["A", "B", "C"] + columns = ["X", "Y", "Z"] + df = DataFrame( + np.random.default_rng(2).standard_normal((3, 3)), + index=index, + columns=columns, + ) + + align = DataFrame._align_for_op + val = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + tm.assert_frame_equal( + align(df, val, axis=0)[1], + DataFrame(val, index=df.index, columns=df.columns), + ) + tm.assert_frame_equal( + align(df, val, axis=1)[1], + DataFrame(val, index=df.index, columns=df.columns), + ) + + # shape mismatch + msg = "Unable to coerce to DataFrame, shape must be" + val = np.array([[1, 2, 3], [4, 5, 6]]) + with pytest.raises(ValueError, match=msg): + align(df, val, axis=0) + + with pytest.raises(ValueError, match=msg): + align(df, val, axis=1) + + val = np.zeros((3, 3, 3)) + msg = re.escape( + "Unable to coerce to Series/DataFrame, dimension must be <= 2: (3, 3, 3)" + ) + with pytest.raises(ValueError, match=msg): + align(df, val, axis=0) + with pytest.raises(ValueError, match=msg): + align(df, val, axis=1) + + def test_no_warning(self, all_arithmetic_operators): + df = DataFrame({"A": [0.0, 0.0], "B": [0.0, None]}) + b = df["B"] + with tm.assert_produces_warning(None): + getattr(df, all_arithmetic_operators)(b) + + def test_dunder_methods_binary(self, all_arithmetic_operators): + # GH#??? frame.__foo__ should only accept one argument + df = DataFrame({"A": [0.0, 0.0], "B": [0.0, None]}) + b = df["B"] + with pytest.raises(TypeError, match="takes 2 positional arguments"): + getattr(df, all_arithmetic_operators)(b, 0) + + def test_align_int_fill_bug(self): + # GH#910 + X = np.arange(10 * 10, dtype="float64").reshape(10, 10) + Y = np.ones((10, 1), dtype=int) + + df1 = DataFrame(X) + df1["0.X"] = Y.squeeze() + + df2 = df1.astype(float) + + result = df1 - df1.mean() + expected = df2 - df2.mean() + tm.assert_frame_equal(result, expected) + + +def test_pow_with_realignment(): + # GH#32685 pow has special semantics for operating with null values + left = DataFrame({"A": [0, 1, 2]}) + right = DataFrame(index=[0, 1, 2]) + + result = left**right + expected = DataFrame({"A": [np.nan, 1.0, np.nan]}) + tm.assert_frame_equal(result, expected) + + +def test_dataframe_series_extension_dtypes(): + # https://github.com/pandas-dev/pandas/issues/34311 + df = DataFrame( + np.random.default_rng(2).integers(0, 100, (10, 3)), columns=["a", "b", "c"] + ) + ser = Series([1, 2, 3], index=["a", "b", "c"]) + + expected = df.to_numpy("int64") + ser.to_numpy("int64").reshape(-1, 3) + expected = DataFrame(expected, columns=df.columns, dtype="Int64") + + df_ea = df.astype("Int64") + result = df_ea + ser + tm.assert_frame_equal(result, expected) + result = df_ea + ser.astype("Int64") + tm.assert_frame_equal(result, expected) + + +def test_dataframe_blockwise_slicelike(): + # GH#34367 + arr = np.random.default_rng(2).integers(0, 1000, (100, 10)) + df1 = DataFrame(arr) + # Explicit cast to float to avoid implicit cast when setting nan + df2 = df1.copy().astype({1: "float", 3: "float", 7: "float"}) + df2.iloc[0, [1, 3, 7]] = np.nan + + # Explicit cast to float to avoid implicit cast when setting nan + df3 = df1.copy().astype({5: "float"}) + df3.iloc[0, [5]] = np.nan + + # Explicit cast to float to avoid implicit cast when setting nan + df4 = df1.copy().astype({2: "float", 3: "float", 4: "float"}) + df4.iloc[0, np.arange(2, 5)] = np.nan + # Explicit cast to float to avoid implicit cast when setting nan + df5 = df1.copy().astype({4: "float", 5: "float", 6: "float"}) + df5.iloc[0, np.arange(4, 7)] = np.nan + + for left, right in [(df1, df2), (df2, df3), (df4, df5)]: + res = left + right + + expected = DataFrame({i: left[i] + right[i] for i in left.columns}) + tm.assert_frame_equal(res, expected) + + +@pytest.mark.parametrize( + "df, col_dtype", + [ + (DataFrame([[1.0, 2.0], [4.0, 5.0]], columns=list("ab")), "float64"), + ( + DataFrame([[1.0, "b"], [4.0, "b"]], columns=list("ab")).astype( + {"b": object} + ), + "object", + ), + ], +) +def test_dataframe_operation_with_non_numeric_types(df, col_dtype): + # GH #22663 + expected = DataFrame([[0.0, np.nan], [3.0, np.nan]], columns=list("ab")) + expected = expected.astype({"b": col_dtype}) + result = df + Series([-1.0], index=list("a")) + tm.assert_frame_equal(result, expected) + + +def test_arith_reindex_with_duplicates(): + # https://github.com/pandas-dev/pandas/issues/35194 + df1 = DataFrame(data=[[0]], columns=["second"]) + df2 = DataFrame(data=[[0, 0, 0]], columns=["first", "second", "second"]) + result = df1 + df2 + expected = DataFrame([[np.nan, 0, 0]], columns=["first", "second", "second"]) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("to_add", [[Series([1, 1])], [Series([1, 1]), Series([1, 1])]]) +def test_arith_list_of_arraylike_raise(to_add): + # GH 36702. Raise when trying to add list of array-like to DataFrame + df = DataFrame({"x": [1, 2], "y": [1, 2]}) + + msg = f"Unable to coerce list of {type(to_add[0])} to Series/DataFrame" + with pytest.raises(ValueError, match=msg): + df + to_add + with pytest.raises(ValueError, match=msg): + to_add + df + + +def test_inplace_arithmetic_series_update(using_copy_on_write, warn_copy_on_write): + # https://github.com/pandas-dev/pandas/issues/36373 + df = DataFrame({"A": [1, 2, 3]}) + df_orig = df.copy() + series = df["A"] + vals = series._values + + with tm.assert_cow_warning(warn_copy_on_write): + series += 1 + if using_copy_on_write: + assert series._values is not vals + tm.assert_frame_equal(df, df_orig) + else: + assert series._values is vals + + expected = DataFrame({"A": [2, 3, 4]}) + tm.assert_frame_equal(df, expected) + + +def test_arithmetic_multiindex_align(): + """ + Regression test for: https://github.com/pandas-dev/pandas/issues/33765 + """ + df1 = DataFrame( + [[1]], + index=["a"], + columns=MultiIndex.from_product([[0], [1]], names=["a", "b"]), + ) + df2 = DataFrame([[1]], index=["a"], columns=Index([0], name="a")) + expected = DataFrame( + [[0]], + index=["a"], + columns=MultiIndex.from_product([[0], [1]], names=["a", "b"]), + ) + result = df1 - df2 + tm.assert_frame_equal(result, expected) + + +def test_bool_frame_mult_float(): + # GH 18549 + df = DataFrame(True, list("ab"), list("cd")) + result = df * 1.0 + expected = DataFrame(np.ones((2, 2)), list("ab"), list("cd")) + tm.assert_frame_equal(result, expected) + + +def test_frame_sub_nullable_int(any_int_ea_dtype): + # GH 32822 + series1 = Series([1, 2, None], dtype=any_int_ea_dtype) + series2 = Series([1, 2, 3], dtype=any_int_ea_dtype) + expected = DataFrame([0, 0, None], dtype=any_int_ea_dtype) + result = series1.to_frame() - series2.to_frame() + tm.assert_frame_equal(result, expected) + + +@pytest.mark.filterwarnings( + "ignore:Passing a BlockManager|Passing a SingleBlockManager:DeprecationWarning" +) +def test_frame_op_subclass_nonclass_constructor(): + # GH#43201 subclass._constructor is a function, not the subclass itself + + class SubclassedSeries(Series): + @property + def _constructor(self): + return SubclassedSeries + + @property + def _constructor_expanddim(self): + return SubclassedDataFrame + + class SubclassedDataFrame(DataFrame): + _metadata = ["my_extra_data"] + + def __init__(self, my_extra_data, *args, **kwargs) -> None: + self.my_extra_data = my_extra_data + super().__init__(*args, **kwargs) + + @property + def _constructor(self): + return functools.partial(type(self), self.my_extra_data) + + @property + def _constructor_sliced(self): + return SubclassedSeries + + sdf = SubclassedDataFrame("some_data", {"A": [1, 2, 3], "B": [4, 5, 6]}) + result = sdf * 2 + expected = SubclassedDataFrame("some_data", {"A": [2, 4, 6], "B": [8, 10, 12]}) + tm.assert_frame_equal(result, expected) + + result = sdf + sdf + tm.assert_frame_equal(result, expected) + + +def test_enum_column_equality(): + Cols = Enum("Cols", "col1 col2") + + q1 = DataFrame({Cols.col1: [1, 2, 3]}) + q2 = DataFrame({Cols.col1: [1, 2, 3]}) + + result = q1[Cols.col1] == q2[Cols.col1] + expected = Series([True, True, True], name=Cols.col1) + + tm.assert_series_equal(result, expected) + + +def test_mixed_col_index_dtype(): + # GH 47382 + df1 = DataFrame(columns=list("abc"), data=1.0, index=[0]) + df2 = DataFrame(columns=list("abc"), data=0.0, index=[0]) + df1.columns = df2.columns.astype("string") + result = df1 + df2 + expected = DataFrame(columns=list("abc"), data=1.0, index=[0]) + tm.assert_frame_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_arrow_interface.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_arrow_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..098d1829b973cedab334cb5992b6a74f7d2c7766 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_arrow_interface.py @@ -0,0 +1,45 @@ +import ctypes + +import pytest + +import pandas.util._test_decorators as td + +import pandas as pd + +pa = pytest.importorskip("pyarrow") + + +@td.skip_if_no("pyarrow", min_version="14.0") +def test_dataframe_arrow_interface(): + df = pd.DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]}) + + capsule = df.__arrow_c_stream__() + assert ( + ctypes.pythonapi.PyCapsule_IsValid( + ctypes.py_object(capsule), b"arrow_array_stream" + ) + == 1 + ) + + table = pa.table(df) + expected = pa.table({"a": [1, 2, 3], "b": ["a", "b", "c"]}) + assert table.equals(expected) + + schema = pa.schema([("a", pa.int8()), ("b", pa.string())]) + table = pa.table(df, schema=schema) + expected = expected.cast(schema) + assert table.equals(expected) + + +@td.skip_if_no("pyarrow", min_version="15.0") +def test_dataframe_to_arrow(): + df = pd.DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]}) + + table = pa.RecordBatchReader.from_stream(df).read_all() + expected = pa.table({"a": [1, 2, 3], "b": ["a", "b", "c"]}) + assert table.equals(expected) + + schema = pa.schema([("a", pa.int8()), ("b", pa.string())]) + table = pa.RecordBatchReader.from_stream(df, schema=schema).read_all() + expected = expected.cast(schema) + assert table.equals(expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_block_internals.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_block_internals.py new file mode 100644 index 0000000000000000000000000000000000000000..712494ef15f972044adbf28ca459f942b0c237d6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_block_internals.py @@ -0,0 +1,457 @@ +from datetime import ( + datetime, + timedelta, +) +import itertools + +import numpy as np +import pytest + +from pandas.errors import PerformanceWarning +import pandas.util._test_decorators as td + +import pandas as pd +from pandas import ( + Categorical, + DataFrame, + Series, + Timestamp, + date_range, + option_context, +) +import pandas._testing as tm +from pandas.core.internals.blocks import NumpyBlock + +# Segregated collection of methods that require the BlockManager internal data +# structure + + +# TODO(ArrayManager) check which of those tests need to be rewritten to test the +# equivalent for ArrayManager +pytestmark = td.skip_array_manager_invalid_test + + +class TestDataFrameBlockInternals: + def test_setitem_invalidates_datetime_index_freq(self): + # GH#24096 altering a datetime64tz column inplace invalidates the + # `freq` attribute on the underlying DatetimeIndex + + dti = date_range("20130101", periods=3, tz="US/Eastern") + ts = dti[1] + + df = DataFrame({"B": dti}) + assert df["B"]._values.freq is None + + df.iloc[1, 0] = pd.NaT + assert df["B"]._values.freq is None + + # check that the DatetimeIndex was not altered in place + assert dti.freq == "D" + assert dti[1] == ts + + def test_cast_internals(self, float_frame): + msg = "Passing a BlockManager to DataFrame" + with tm.assert_produces_warning( + DeprecationWarning, match=msg, check_stacklevel=False + ): + casted = DataFrame(float_frame._mgr, dtype=int) + expected = DataFrame(float_frame._series, dtype=int) + tm.assert_frame_equal(casted, expected) + + with tm.assert_produces_warning( + DeprecationWarning, match=msg, check_stacklevel=False + ): + casted = DataFrame(float_frame._mgr, dtype=np.int32) + expected = DataFrame(float_frame._series, dtype=np.int32) + tm.assert_frame_equal(casted, expected) + + def test_consolidate(self, float_frame): + float_frame["E"] = 7.0 + consolidated = float_frame._consolidate() + assert len(consolidated._mgr.blocks) == 1 + + # Ensure copy, do I want this? + recons = consolidated._consolidate() + assert recons is not consolidated + tm.assert_frame_equal(recons, consolidated) + + float_frame["F"] = 8.0 + assert len(float_frame._mgr.blocks) == 3 + + return_value = float_frame._consolidate_inplace() + assert return_value is None + assert len(float_frame._mgr.blocks) == 1 + + def test_consolidate_inplace(self, float_frame): + # triggers in-place consolidation + for letter in range(ord("A"), ord("Z")): + float_frame[chr(letter)] = chr(letter) + + def test_modify_values(self, float_frame, using_copy_on_write): + if using_copy_on_write: + with pytest.raises(ValueError, match="read-only"): + float_frame.values[5] = 5 + assert (float_frame.values[5] != 5).all() + return + + float_frame.values[5] = 5 + assert (float_frame.values[5] == 5).all() + + # unconsolidated + float_frame["E"] = 7.0 + col = float_frame["E"] + float_frame.values[6] = 6 + # as of 2.0 .values does not consolidate, so subsequent calls to .values + # does not share data + assert not (float_frame.values[6] == 6).all() + + assert (col == 7).all() + + def test_boolean_set_uncons(self, float_frame): + float_frame["E"] = 7.0 + + expected = float_frame.values.copy() + expected[expected > 1] = 2 + + float_frame[float_frame > 1] = 2 + tm.assert_almost_equal(expected, float_frame.values) + + def test_constructor_with_convert(self): + # this is actually mostly a test of lib.maybe_convert_objects + # #2845 + df = DataFrame({"A": [2**63 - 1]}) + result = df["A"] + expected = Series(np.asarray([2**63 - 1], np.int64), name="A") + tm.assert_series_equal(result, expected) + + df = DataFrame({"A": [2**63]}) + result = df["A"] + expected = Series(np.asarray([2**63], np.uint64), name="A") + tm.assert_series_equal(result, expected) + + df = DataFrame({"A": [datetime(2005, 1, 1), True]}) + result = df["A"] + expected = Series( + np.asarray([datetime(2005, 1, 1), True], np.object_), name="A" + ) + tm.assert_series_equal(result, expected) + + df = DataFrame({"A": [None, 1]}) + result = df["A"] + expected = Series(np.asarray([np.nan, 1], np.float64), name="A") + tm.assert_series_equal(result, expected) + + df = DataFrame({"A": [1.0, 2]}) + result = df["A"] + expected = Series(np.asarray([1.0, 2], np.float64), name="A") + tm.assert_series_equal(result, expected) + + df = DataFrame({"A": [1.0 + 2.0j, 3]}) + result = df["A"] + expected = Series(np.asarray([1.0 + 2.0j, 3], np.complex128), name="A") + tm.assert_series_equal(result, expected) + + df = DataFrame({"A": [1.0 + 2.0j, 3.0]}) + result = df["A"] + expected = Series(np.asarray([1.0 + 2.0j, 3.0], np.complex128), name="A") + tm.assert_series_equal(result, expected) + + df = DataFrame({"A": [1.0 + 2.0j, True]}) + result = df["A"] + expected = Series(np.asarray([1.0 + 2.0j, True], np.object_), name="A") + tm.assert_series_equal(result, expected) + + df = DataFrame({"A": [1.0, None]}) + result = df["A"] + expected = Series(np.asarray([1.0, np.nan], np.float64), name="A") + tm.assert_series_equal(result, expected) + + df = DataFrame({"A": [1.0 + 2.0j, None]}) + result = df["A"] + expected = Series(np.asarray([1.0 + 2.0j, np.nan], np.complex128), name="A") + tm.assert_series_equal(result, expected) + + df = DataFrame({"A": [2.0, 1, True, None]}) + result = df["A"] + expected = Series(np.asarray([2.0, 1, True, None], np.object_), name="A") + tm.assert_series_equal(result, expected) + + df = DataFrame({"A": [2.0, 1, datetime(2006, 1, 1), None]}) + result = df["A"] + expected = Series( + np.asarray([2.0, 1, datetime(2006, 1, 1), None], np.object_), name="A" + ) + tm.assert_series_equal(result, expected) + + def test_construction_with_mixed(self, float_string_frame, using_infer_string): + # test construction edge cases with mixed types + + # f7u12, this does not work without extensive workaround + data = [ + [datetime(2001, 1, 5), np.nan, datetime(2001, 1, 2)], + [datetime(2000, 1, 2), datetime(2000, 1, 3), datetime(2000, 1, 1)], + ] + df = DataFrame(data) + + # check dtypes + result = df.dtypes + expected = Series({"datetime64[us]": 3}) + + # mixed-type frames + float_string_frame["datetime"] = datetime.now() + float_string_frame["timedelta"] = timedelta(days=1, seconds=1) + assert float_string_frame["datetime"].dtype == "M8[us]" + assert float_string_frame["timedelta"].dtype == "m8[us]" + result = float_string_frame.dtypes + expected = Series( + [np.dtype("float64")] * 4 + + [ + np.dtype("object") if not using_infer_string else "string", + np.dtype("datetime64[us]"), + np.dtype("timedelta64[us]"), + ], + index=list("ABCD") + ["foo", "datetime", "timedelta"], + ) + tm.assert_series_equal(result, expected) + + def test_construction_with_conversions(self): + # convert from a numpy array of non-ns timedelta64; as of 2.0 this does + # *not* convert + arr = np.array([1, 2, 3], dtype="timedelta64[s]") + df = DataFrame(index=range(3)) + df["A"] = arr + expected = DataFrame( + {"A": pd.timedelta_range("00:00:01", periods=3, freq="s")}, index=range(3) + ) + tm.assert_numpy_array_equal(df["A"].to_numpy(), arr) + + expected = DataFrame( + { + "dt1": Timestamp("20130101"), + "dt2": date_range("20130101", periods=3).astype("M8[s]"), + # 'dt3' : date_range('20130101 00:00:01',periods=3,freq='s'), + # FIXME: don't leave commented-out + }, + index=range(3), + ) + assert expected.dtypes["dt1"] == "M8[s]" + assert expected.dtypes["dt2"] == "M8[s]" + + df = DataFrame(index=range(3)) + df["dt1"] = np.datetime64("2013-01-01") + df["dt2"] = np.array( + ["2013-01-01", "2013-01-02", "2013-01-03"], dtype="datetime64[D]" + ) + + # df['dt3'] = np.array(['2013-01-01 00:00:01','2013-01-01 + # 00:00:02','2013-01-01 00:00:03'],dtype='datetime64[s]') + # FIXME: don't leave commented-out + + tm.assert_frame_equal(df, expected) + + def test_constructor_compound_dtypes(self): + # GH 5191 + # compound dtypes should raise not-implementederror + + def f(dtype): + data = list(itertools.repeat((datetime(2001, 1, 1), "aa", 20), 9)) + return DataFrame(data=data, columns=["A", "B", "C"], dtype=dtype) + + msg = "compound dtypes are not implemented in the DataFrame constructor" + with pytest.raises(NotImplementedError, match=msg): + f([("A", "datetime64[h]"), ("B", "str"), ("C", "int32")]) + + # pre-2.0 these used to work (though results may be unexpected) + with pytest.raises(TypeError, match="argument must be"): + f("int64") + with pytest.raises(TypeError, match="argument must be"): + f("float64") + + # 10822 + msg = "^Unknown datetime string format, unable to parse: aa, at position 0$" + with pytest.raises(ValueError, match=msg): + f("M8[ns]") + + def test_pickle(self, float_string_frame, timezone_frame): + empty_frame = DataFrame() + + unpickled = tm.round_trip_pickle(float_string_frame) + tm.assert_frame_equal(float_string_frame, unpickled) + + # buglet + float_string_frame._mgr.ndim + + # empty + unpickled = tm.round_trip_pickle(empty_frame) + repr(unpickled) + + # tz frame + unpickled = tm.round_trip_pickle(timezone_frame) + tm.assert_frame_equal(timezone_frame, unpickled) + + def test_consolidate_datetime64(self): + # numpy vstack bug + + df = DataFrame( + { + "starting": pd.to_datetime( + [ + "2012-06-21 00:00", + "2012-06-23 07:00", + "2012-06-23 16:30", + "2012-06-25 08:00", + "2012-06-26 12:00", + ] + ), + "ending": pd.to_datetime( + [ + "2012-06-23 07:00", + "2012-06-23 16:30", + "2012-06-25 08:00", + "2012-06-26 12:00", + "2012-06-27 08:00", + ] + ), + "measure": [77, 65, 77, 0, 77], + } + ) + + ser_starting = df.starting + ser_starting.index = ser_starting.values + ser_starting = ser_starting.tz_localize("US/Eastern") + ser_starting = ser_starting.tz_convert("UTC") + ser_starting.index.name = "starting" + + ser_ending = df.ending + ser_ending.index = ser_ending.values + ser_ending = ser_ending.tz_localize("US/Eastern") + ser_ending = ser_ending.tz_convert("UTC") + ser_ending.index.name = "ending" + + df.starting = ser_starting.index + df.ending = ser_ending.index + + tm.assert_index_equal(pd.DatetimeIndex(df.starting), ser_starting.index) + tm.assert_index_equal(pd.DatetimeIndex(df.ending), ser_ending.index) + + def test_is_mixed_type(self, float_frame, float_string_frame): + assert not float_frame._is_mixed_type + assert float_string_frame._is_mixed_type + + def test_stale_cached_series_bug_473(self, using_copy_on_write, warn_copy_on_write): + # this is chained, but ok + with option_context("chained_assignment", None): + Y = DataFrame( + np.random.default_rng(2).random((4, 4)), + index=("a", "b", "c", "d"), + columns=("e", "f", "g", "h"), + ) + repr(Y) + Y["e"] = Y["e"].astype("object") + with tm.raises_chained_assignment_error(): + Y["g"]["c"] = np.nan + repr(Y) + Y.sum() + Y["g"].sum() + if using_copy_on_write: + assert not pd.isna(Y["g"]["c"]) + else: + assert pd.isna(Y["g"]["c"]) + + @pytest.mark.filterwarnings("ignore:Setting a value on a view:FutureWarning") + def test_strange_column_corruption_issue(self, using_copy_on_write): + # TODO(wesm): Unclear how exactly this is related to internal matters + df = DataFrame(index=[0, 1]) + df[0] = np.nan + wasCol = {} + + with tm.assert_produces_warning( + PerformanceWarning, raise_on_extra_warnings=False + ): + for i, dt in enumerate(df.index): + for col in range(100, 200): + if col not in wasCol: + wasCol[col] = 1 + df[col] = np.nan + if using_copy_on_write: + df.loc[dt, col] = i + else: + df[col][dt] = i + + myid = 100 + + first = len(df.loc[pd.isna(df[myid]), [myid]]) + second = len(df.loc[pd.isna(df[myid]), [myid]]) + assert first == second == 0 + + def test_constructor_no_pandas_array(self): + # Ensure that NumpyExtensionArray isn't allowed inside Series + # See https://github.com/pandas-dev/pandas/issues/23995 for more. + arr = Series([1, 2, 3]).array + result = DataFrame({"A": arr}) + expected = DataFrame({"A": [1, 2, 3]}) + tm.assert_frame_equal(result, expected) + assert isinstance(result._mgr.blocks[0], NumpyBlock) + assert result._mgr.blocks[0].is_numeric + + def test_add_column_with_pandas_array(self): + # GH 26390 + df = DataFrame({"a": [1, 2, 3, 4], "b": ["a", "b", "c", "d"]}) + df["c"] = pd.arrays.NumpyExtensionArray(np.array([1, 2, None, 3], dtype=object)) + df2 = DataFrame( + { + "a": [1, 2, 3, 4], + "b": ["a", "b", "c", "d"], + "c": pd.arrays.NumpyExtensionArray( + np.array([1, 2, None, 3], dtype=object) + ), + } + ) + assert type(df["c"]._mgr.blocks[0]) == NumpyBlock + assert df["c"]._mgr.blocks[0].is_object + assert type(df2["c"]._mgr.blocks[0]) == NumpyBlock + assert df2["c"]._mgr.blocks[0].is_object + tm.assert_frame_equal(df, df2) + + +def test_update_inplace_sets_valid_block_values(using_copy_on_write): + # https://github.com/pandas-dev/pandas/issues/33457 + df = DataFrame({"a": Series([1, 2, None], dtype="category")}) + + # inplace update of a single column + if using_copy_on_write: + with tm.raises_chained_assignment_error(): + df["a"].fillna(1, inplace=True) + else: + with tm.assert_produces_warning(FutureWarning, match="inplace method"): + df["a"].fillna(1, inplace=True) + + # check we haven't put a Series into any block.values + assert isinstance(df._mgr.blocks[0].values, Categorical) + + if not using_copy_on_write: + # smoketest for OP bug from GH#35731 + assert df.isnull().sum().sum() == 0 + + +def test_nonconsolidated_item_cache_take(): + # https://github.com/pandas-dev/pandas/issues/35521 + + # create non-consolidated dataframe with object dtype columns + df = DataFrame() + df["col1"] = Series(["a"], dtype=object) + df["col2"] = Series([0], dtype=object) + + # access column (item cache) + df["col1"] == "A" + # take operation + # (regression was that this consolidated but didn't reset item cache, + # resulting in an invalid cache and the .at operation not working properly) + df[df["col2"] == 0] + + # now setting value should update actual dataframe + df.at[0, "col1"] = "A" + + expected = DataFrame({"col1": ["A"], "col2": [0]}, dtype=object) + tm.assert_frame_equal(df, expected) + assert df.at[0, "col1"] == "A" diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_constructors.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_constructors.py new file mode 100644 index 0000000000000000000000000000000000000000..cae2f6e81d384149a5d52b1fed7f3c58f2413365 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_constructors.py @@ -0,0 +1,3348 @@ +import array +from collections import ( + OrderedDict, + abc, + defaultdict, + namedtuple, +) +from collections.abc import Iterator +from dataclasses import make_dataclass +from datetime import ( + date, + datetime, + timedelta, +) +import functools +import re + +import numpy as np +from numpy import ma +from numpy.ma import mrecords +import pytest +import pytz + +from pandas._config import using_pyarrow_string_dtype + +from pandas._libs import lib +from pandas.compat.numpy import np_version_gt2 +from pandas.errors import IntCastingNaNError +import pandas.util._test_decorators as td + +from pandas.core.dtypes.common import is_integer_dtype +from pandas.core.dtypes.dtypes import ( + DatetimeTZDtype, + IntervalDtype, + NumpyEADtype, + PeriodDtype, +) + +import pandas as pd +from pandas import ( + Categorical, + CategoricalIndex, + DataFrame, + DatetimeIndex, + Index, + Interval, + MultiIndex, + Period, + RangeIndex, + Series, + Timedelta, + Timestamp, + cut, + date_range, + isna, +) +import pandas._testing as tm +from pandas.arrays import ( + DatetimeArray, + IntervalArray, + PeriodArray, + SparseArray, + TimedeltaArray, +) + +MIXED_FLOAT_DTYPES = ["float16", "float32", "float64"] +MIXED_INT_DTYPES = [ + "uint8", + "uint16", + "uint32", + "uint64", + "int8", + "int16", + "int32", + "int64", +] + + +class TestDataFrameConstructors: + def test_constructor_from_ndarray_with_str_dtype(self): + # If we don't ravel/reshape around ensure_str_array, we end up + # with an array of strings each of which is e.g. "[0 1 2]" + arr = np.arange(12).reshape(4, 3) + df = DataFrame(arr, dtype=str) + expected = DataFrame(arr.astype(str), dtype=object) + tm.assert_frame_equal(df, expected) + + def test_constructor_from_2d_datetimearray(self, using_array_manager): + dti = date_range("2016-01-01", periods=6, tz="US/Pacific") + dta = dti._data.reshape(3, 2) + + df = DataFrame(dta) + expected = DataFrame({0: dta[:, 0], 1: dta[:, 1]}) + tm.assert_frame_equal(df, expected) + if not using_array_manager: + # GH#44724 big performance hit if we de-consolidate + assert len(df._mgr.blocks) == 1 + + def test_constructor_dict_with_tzaware_scalar(self): + # GH#42505 + dt = Timestamp("2019-11-03 01:00:00-0700").tz_convert("America/Los_Angeles") + dt = dt.as_unit("ns") + + df = DataFrame({"dt": dt}, index=[0]) + expected = DataFrame({"dt": [dt]}) + tm.assert_frame_equal(df, expected) + + # Non-homogeneous + df = DataFrame({"dt": dt, "value": [1]}) + expected = DataFrame({"dt": [dt], "value": [1]}) + tm.assert_frame_equal(df, expected) + + def test_construct_ndarray_with_nas_and_int_dtype(self): + # GH#26919 match Series by not casting np.nan to meaningless int + arr = np.array([[1, np.nan], [2, 3]]) + msg = r"Cannot convert non-finite values \(NA or inf\) to integer" + with pytest.raises(IntCastingNaNError, match=msg): + DataFrame(arr, dtype="i8") + + # check this matches Series behavior + with pytest.raises(IntCastingNaNError, match=msg): + Series(arr[0], dtype="i8", name=0) + + def test_construct_from_list_of_datetimes(self): + df = DataFrame([datetime.now(), datetime.now()]) + assert df[0].dtype == np.dtype("M8[ns]") + + def test_constructor_from_tzaware_datetimeindex(self): + # don't cast a DatetimeIndex WITH a tz, leave as object + # GH#6032 + naive = DatetimeIndex(["2013-1-1 13:00", "2013-1-2 14:00"], name="B") + idx = naive.tz_localize("US/Pacific") + + expected = Series(np.array(idx.tolist(), dtype="object"), name="B") + assert expected.dtype == idx.dtype + + # convert index to series + result = Series(idx) + tm.assert_series_equal(result, expected) + + def test_columns_with_leading_underscore_work_with_to_dict(self): + col_underscore = "_b" + df = DataFrame({"a": [1, 2], col_underscore: [3, 4]}) + d = df.to_dict(orient="records") + + ref_d = [{"a": 1, col_underscore: 3}, {"a": 2, col_underscore: 4}] + + assert ref_d == d + + def test_columns_with_leading_number_and_underscore_work_with_to_dict(self): + col_with_num = "1_b" + df = DataFrame({"a": [1, 2], col_with_num: [3, 4]}) + d = df.to_dict(orient="records") + + ref_d = [{"a": 1, col_with_num: 3}, {"a": 2, col_with_num: 4}] + + assert ref_d == d + + def test_array_of_dt64_nat_with_td64dtype_raises(self, frame_or_series): + # GH#39462 + nat = np.datetime64("NaT", "ns") + arr = np.array([nat], dtype=object) + if frame_or_series is DataFrame: + arr = arr.reshape(1, 1) + + msg = "Invalid type for timedelta scalar: " + with pytest.raises(TypeError, match=msg): + frame_or_series(arr, dtype="m8[ns]") + + @pytest.mark.parametrize("kind", ["m", "M"]) + def test_datetimelike_values_with_object_dtype(self, kind, frame_or_series): + # with dtype=object, we should cast dt64 values to Timestamps, not pydatetimes + if kind == "M": + dtype = "M8[ns]" + scalar_type = Timestamp + else: + dtype = "m8[ns]" + scalar_type = Timedelta + + arr = np.arange(6, dtype="i8").view(dtype).reshape(3, 2) + if frame_or_series is Series: + arr = arr[:, 0] + + obj = frame_or_series(arr, dtype=object) + assert obj._mgr.arrays[0].dtype == object + assert isinstance(obj._mgr.arrays[0].ravel()[0], scalar_type) + + # go through a different path in internals.construction + obj = frame_or_series(frame_or_series(arr), dtype=object) + assert obj._mgr.arrays[0].dtype == object + assert isinstance(obj._mgr.arrays[0].ravel()[0], scalar_type) + + obj = frame_or_series(frame_or_series(arr), dtype=NumpyEADtype(object)) + assert obj._mgr.arrays[0].dtype == object + assert isinstance(obj._mgr.arrays[0].ravel()[0], scalar_type) + + if frame_or_series is DataFrame: + # other paths through internals.construction + sers = [Series(x) for x in arr] + obj = frame_or_series(sers, dtype=object) + assert obj._mgr.arrays[0].dtype == object + assert isinstance(obj._mgr.arrays[0].ravel()[0], scalar_type) + + def test_series_with_name_not_matching_column(self): + # GH#9232 + x = Series(range(5), name=1) + y = Series(range(5), name=0) + + result = DataFrame(x, columns=[0]) + expected = DataFrame([], columns=[0]) + tm.assert_frame_equal(result, expected) + + result = DataFrame(y, columns=[1]) + expected = DataFrame([], columns=[1]) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "constructor", + [ + lambda: DataFrame(), + lambda: DataFrame(None), + lambda: DataFrame(()), + lambda: DataFrame([]), + lambda: DataFrame(_ for _ in []), + lambda: DataFrame(range(0)), + lambda: DataFrame(data=None), + lambda: DataFrame(data=()), + lambda: DataFrame(data=[]), + lambda: DataFrame(data=(_ for _ in [])), + lambda: DataFrame(data=range(0)), + ], + ) + def test_empty_constructor(self, constructor): + expected = DataFrame() + result = constructor() + assert len(result.index) == 0 + assert len(result.columns) == 0 + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "constructor", + [ + lambda: DataFrame({}), + lambda: DataFrame(data={}), + ], + ) + def test_empty_constructor_object_index(self, constructor): + expected = DataFrame(index=RangeIndex(0), columns=RangeIndex(0)) + result = constructor() + assert len(result.index) == 0 + assert len(result.columns) == 0 + tm.assert_frame_equal(result, expected, check_index_type=True) + + @pytest.mark.parametrize( + "emptylike,expected_index,expected_columns", + [ + ([[]], RangeIndex(1), RangeIndex(0)), + ([[], []], RangeIndex(2), RangeIndex(0)), + ([(_ for _ in [])], RangeIndex(1), RangeIndex(0)), + ], + ) + def test_emptylike_constructor(self, emptylike, expected_index, expected_columns): + expected = DataFrame(index=expected_index, columns=expected_columns) + result = DataFrame(emptylike) + tm.assert_frame_equal(result, expected) + + def test_constructor_mixed(self, float_string_frame, using_infer_string): + dtype = "string" if using_infer_string else np.object_ + assert float_string_frame["foo"].dtype == dtype + + def test_constructor_cast_failure(self): + # as of 2.0, we raise if we can't respect "dtype", previously we + # silently ignored + msg = "could not convert string to float" + with pytest.raises(ValueError, match=msg): + DataFrame({"a": ["a", "b", "c"]}, dtype=np.float64) + + # GH 3010, constructing with odd arrays + df = DataFrame(np.ones((4, 2))) + + # this is ok + df["foo"] = np.ones((4, 2)).tolist() + + # this is not ok + msg = "Expected a 1D array, got an array with shape \\(4, 2\\)" + with pytest.raises(ValueError, match=msg): + df["test"] = np.ones((4, 2)) + + # this is ok + df["foo2"] = np.ones((4, 2)).tolist() + + def test_constructor_dtype_copy(self): + orig_df = DataFrame({"col1": [1.0], "col2": [2.0], "col3": [3.0]}) + + new_df = DataFrame(orig_df, dtype=float, copy=True) + + new_df["col1"] = 200.0 + assert orig_df["col1"][0] == 1.0 + + def test_constructor_dtype_nocast_view_dataframe( + self, using_copy_on_write, warn_copy_on_write + ): + df = DataFrame([[1, 2]]) + should_be_view = DataFrame(df, dtype=df[0].dtype) + if using_copy_on_write: + should_be_view.iloc[0, 0] = 99 + assert df.values[0, 0] == 1 + else: + with tm.assert_cow_warning(warn_copy_on_write): + should_be_view.iloc[0, 0] = 99 + assert df.values[0, 0] == 99 + + def test_constructor_dtype_nocast_view_2d_array( + self, using_array_manager, using_copy_on_write, warn_copy_on_write + ): + df = DataFrame([[1, 2], [3, 4]], dtype="int64") + if not using_array_manager and not using_copy_on_write: + should_be_view = DataFrame(df.values, dtype=df[0].dtype) + # TODO(CoW-warn) this should warn + # with tm.assert_cow_warning(warn_copy_on_write): + should_be_view.iloc[0, 0] = 97 + assert df.values[0, 0] == 97 + else: + # INFO(ArrayManager) DataFrame(ndarray) doesn't necessarily preserve + # a view on the array to ensure contiguous 1D arrays + df2 = DataFrame(df.values, dtype=df[0].dtype) + assert df2._mgr.arrays[0].flags.c_contiguous + + @td.skip_array_manager_invalid_test + @pytest.mark.xfail(using_pyarrow_string_dtype(), reason="conversion copies") + def test_1d_object_array_does_not_copy(self): + # https://github.com/pandas-dev/pandas/issues/39272 + arr = np.array(["a", "b"], dtype="object") + df = DataFrame(arr, copy=False) + assert np.shares_memory(df.values, arr) + + @td.skip_array_manager_invalid_test + @pytest.mark.xfail(using_pyarrow_string_dtype(), reason="conversion copies") + def test_2d_object_array_does_not_copy(self): + # https://github.com/pandas-dev/pandas/issues/39272 + arr = np.array([["a", "b"], ["c", "d"]], dtype="object") + df = DataFrame(arr, copy=False) + assert np.shares_memory(df.values, arr) + + def test_constructor_dtype_list_data(self): + df = DataFrame([[1, "2"], [None, "a"]], dtype=object) + assert df.loc[1, 0] is None + assert df.loc[0, 1] == "2" + + def test_constructor_list_of_2d_raises(self): + # https://github.com/pandas-dev/pandas/issues/32289 + a = DataFrame() + b = np.empty((0, 0)) + with pytest.raises(ValueError, match=r"shape=\(1, 0, 0\)"): + DataFrame([a]) + + with pytest.raises(ValueError, match=r"shape=\(1, 0, 0\)"): + DataFrame([b]) + + a = DataFrame({"A": [1, 2]}) + with pytest.raises(ValueError, match=r"shape=\(2, 2, 1\)"): + DataFrame([a, a]) + + @pytest.mark.parametrize( + "typ, ad", + [ + # mixed floating and integer coexist in the same frame + ["float", {}], + # add lots of types + ["float", {"A": 1, "B": "foo", "C": "bar"}], + # GH 622 + ["int", {}], + ], + ) + def test_constructor_mixed_dtypes(self, typ, ad): + if typ == "int": + dtypes = MIXED_INT_DTYPES + arrays = [ + np.array(np.random.default_rng(2).random(10), dtype=d) for d in dtypes + ] + elif typ == "float": + dtypes = MIXED_FLOAT_DTYPES + arrays = [ + np.array(np.random.default_rng(2).integers(10, size=10), dtype=d) + for d in dtypes + ] + + for d, a in zip(dtypes, arrays): + assert a.dtype == d + ad.update(dict(zip(dtypes, arrays))) + df = DataFrame(ad) + + dtypes = MIXED_FLOAT_DTYPES + MIXED_INT_DTYPES + for d in dtypes: + if d in df: + assert df.dtypes[d] == d + + def test_constructor_complex_dtypes(self): + # GH10952 + a = np.random.default_rng(2).random(10).astype(np.complex64) + b = np.random.default_rng(2).random(10).astype(np.complex128) + + df = DataFrame({"a": a, "b": b}) + assert a.dtype == df.a.dtype + assert b.dtype == df.b.dtype + + def test_constructor_dtype_str_na_values(self, string_dtype): + # https://github.com/pandas-dev/pandas/issues/21083 + df = DataFrame({"A": ["x", None]}, dtype=string_dtype) + result = df.isna() + expected = DataFrame({"A": [False, True]}) + tm.assert_frame_equal(result, expected) + assert df.iloc[1, 0] is None + + df = DataFrame({"A": ["x", np.nan]}, dtype=string_dtype) + assert np.isnan(df.iloc[1, 0]) + + def test_constructor_rec(self, float_frame): + rec = float_frame.to_records(index=False) + rec.dtype.names = list(rec.dtype.names)[::-1] + + index = float_frame.index + + df = DataFrame(rec) + tm.assert_index_equal(df.columns, Index(rec.dtype.names)) + + df2 = DataFrame(rec, index=index) + tm.assert_index_equal(df2.columns, Index(rec.dtype.names)) + tm.assert_index_equal(df2.index, index) + + # case with columns != the ones we would infer from the data + rng = np.arange(len(rec))[::-1] + df3 = DataFrame(rec, index=rng, columns=["C", "B"]) + expected = DataFrame(rec, index=rng).reindex(columns=["C", "B"]) + tm.assert_frame_equal(df3, expected) + + def test_constructor_bool(self): + df = DataFrame({0: np.ones(10, dtype=bool), 1: np.zeros(10, dtype=bool)}) + assert df.values.dtype == np.bool_ + + def test_constructor_overflow_int64(self): + # see gh-14881 + values = np.array([2**64 - i for i in range(1, 10)], dtype=np.uint64) + + result = DataFrame({"a": values}) + assert result["a"].dtype == np.uint64 + + # see gh-2355 + data_scores = [ + (6311132704823138710, 273), + (2685045978526272070, 23), + (8921811264899370420, 45), + (17019687244989530680, 270), + (9930107427299601010, 273), + ] + dtype = [("uid", "u8"), ("score", "u8")] + data = np.zeros((len(data_scores),), dtype=dtype) + data[:] = data_scores + df_crawls = DataFrame(data) + assert df_crawls["uid"].dtype == np.uint64 + + @pytest.mark.parametrize( + "values", + [ + np.array([2**64], dtype=object), + np.array([2**65]), + [2**64 + 1], + np.array([-(2**63) - 4], dtype=object), + np.array([-(2**64) - 1]), + [-(2**65) - 2], + ], + ) + def test_constructor_int_overflow(self, values): + # see gh-18584 + value = values[0] + result = DataFrame(values) + + assert result[0].dtype == object + assert result[0][0] == value + + @pytest.mark.parametrize( + "values", + [ + np.array([1], dtype=np.uint16), + np.array([1], dtype=np.uint32), + np.array([1], dtype=np.uint64), + [np.uint16(1)], + [np.uint32(1)], + [np.uint64(1)], + ], + ) + def test_constructor_numpy_uints(self, values): + # GH#47294 + value = values[0] + result = DataFrame(values) + + assert result[0].dtype == value.dtype + assert result[0][0] == value + + def test_constructor_ordereddict(self): + nitems = 100 + nums = list(range(nitems)) + np.random.default_rng(2).shuffle(nums) + expected = [f"A{i:d}" for i in nums] + df = DataFrame(OrderedDict(zip(expected, [[0]] * nitems))) + assert expected == list(df.columns) + + def test_constructor_dict(self): + datetime_series = Series( + np.arange(30, dtype=np.float64), index=date_range("2020-01-01", periods=30) + ) + # test expects index shifted by 5 + datetime_series_short = datetime_series[5:] + + frame = DataFrame({"col1": datetime_series, "col2": datetime_series_short}) + + # col2 is padded with NaN + assert len(datetime_series) == 30 + assert len(datetime_series_short) == 25 + + tm.assert_series_equal(frame["col1"], datetime_series.rename("col1")) + + exp = Series( + np.concatenate([[np.nan] * 5, datetime_series_short.values]), + index=datetime_series.index, + name="col2", + ) + tm.assert_series_equal(exp, frame["col2"]) + + frame = DataFrame( + {"col1": datetime_series, "col2": datetime_series_short}, + columns=["col2", "col3", "col4"], + ) + + assert len(frame) == len(datetime_series_short) + assert "col1" not in frame + assert isna(frame["col3"]).all() + + # Corner cases + assert len(DataFrame()) == 0 + + # mix dict and array, wrong size - no spec for which error should raise + # first + msg = "Mixing dicts with non-Series may lead to ambiguous ordering." + with pytest.raises(ValueError, match=msg): + DataFrame({"A": {"a": "a", "b": "b"}, "B": ["a", "b", "c"]}) + + def test_constructor_dict_length1(self): + # Length-one dict micro-optimization + frame = DataFrame({"A": {"1": 1, "2": 2}}) + tm.assert_index_equal(frame.index, Index(["1", "2"])) + + def test_constructor_dict_with_index(self): + # empty dict plus index + idx = Index([0, 1, 2]) + frame = DataFrame({}, index=idx) + assert frame.index is idx + + def test_constructor_dict_with_index_and_columns(self): + # empty dict with index and columns + idx = Index([0, 1, 2]) + frame = DataFrame({}, index=idx, columns=idx) + assert frame.index is idx + assert frame.columns is idx + assert len(frame._series) == 3 + + def test_constructor_dict_of_empty_lists(self): + # with dict of empty list and Series + frame = DataFrame({"A": [], "B": []}, columns=["A", "B"]) + tm.assert_index_equal(frame.index, RangeIndex(0), exact=True) + + def test_constructor_dict_with_none(self): + # GH 14381 + # Dict with None value + frame_none = DataFrame({"a": None}, index=[0]) + frame_none_list = DataFrame({"a": [None]}, index=[0]) + assert frame_none._get_value(0, "a") is None + assert frame_none_list._get_value(0, "a") is None + tm.assert_frame_equal(frame_none, frame_none_list) + + def test_constructor_dict_errors(self): + # GH10856 + # dict with scalar values should raise error, even if columns passed + msg = "If using all scalar values, you must pass an index" + with pytest.raises(ValueError, match=msg): + DataFrame({"a": 0.7}) + + with pytest.raises(ValueError, match=msg): + DataFrame({"a": 0.7}, columns=["a"]) + + @pytest.mark.parametrize("scalar", [2, np.nan, None, "D"]) + def test_constructor_invalid_items_unused(self, scalar): + # No error if invalid (scalar) value is in fact not used: + result = DataFrame({"a": scalar}, columns=["b"]) + expected = DataFrame(columns=["b"]) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("value", [2, np.nan, None, float("nan")]) + def test_constructor_dict_nan_key(self, value): + # GH 18455 + cols = [1, value, 3] + idx = ["a", value] + values = [[0, 3], [1, 4], [2, 5]] + data = {cols[c]: Series(values[c], index=idx) for c in range(3)} + result = DataFrame(data).sort_values(1).sort_values("a", axis=1) + expected = DataFrame( + np.arange(6, dtype="int64").reshape(2, 3), index=idx, columns=cols + ) + tm.assert_frame_equal(result, expected) + + result = DataFrame(data, index=idx).sort_values("a", axis=1) + tm.assert_frame_equal(result, expected) + + result = DataFrame(data, index=idx, columns=cols) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("value", [np.nan, None, float("nan")]) + def test_constructor_dict_nan_tuple_key(self, value): + # GH 18455 + cols = Index([(11, 21), (value, 22), (13, value)]) + idx = Index([("a", value), (value, 2)]) + values = [[0, 3], [1, 4], [2, 5]] + data = {cols[c]: Series(values[c], index=idx) for c in range(3)} + result = DataFrame(data).sort_values((11, 21)).sort_values(("a", value), axis=1) + expected = DataFrame( + np.arange(6, dtype="int64").reshape(2, 3), index=idx, columns=cols + ) + tm.assert_frame_equal(result, expected) + + result = DataFrame(data, index=idx).sort_values(("a", value), axis=1) + tm.assert_frame_equal(result, expected) + + result = DataFrame(data, index=idx, columns=cols) + tm.assert_frame_equal(result, expected) + + def test_constructor_dict_order_insertion(self): + datetime_series = Series( + np.arange(10, dtype=np.float64), index=date_range("2020-01-01", periods=10) + ) + datetime_series_short = datetime_series[:5] + + # GH19018 + # initialization ordering: by insertion order if python>= 3.6 + d = {"b": datetime_series_short, "a": datetime_series} + frame = DataFrame(data=d) + expected = DataFrame(data=d, columns=list("ba")) + tm.assert_frame_equal(frame, expected) + + def test_constructor_dict_nan_key_and_columns(self): + # GH 16894 + result = DataFrame({np.nan: [1, 2], 2: [2, 3]}, columns=[np.nan, 2]) + expected = DataFrame([[1, 2], [2, 3]], columns=[np.nan, 2]) + tm.assert_frame_equal(result, expected) + + def test_constructor_multi_index(self): + # GH 4078 + # construction error with mi and all-nan frame + tuples = [(2, 3), (3, 3), (3, 3)] + mi = MultiIndex.from_tuples(tuples) + df = DataFrame(index=mi, columns=mi) + assert isna(df).values.ravel().all() + + tuples = [(3, 3), (2, 3), (3, 3)] + mi = MultiIndex.from_tuples(tuples) + df = DataFrame(index=mi, columns=mi) + assert isna(df).values.ravel().all() + + def test_constructor_2d_index(self): + # GH 25416 + # handling of 2d index in construction + df = DataFrame([[1]], columns=[[1]], index=[1, 2]) + expected = DataFrame( + [1, 1], + index=Index([1, 2], dtype="int64"), + columns=MultiIndex(levels=[[1]], codes=[[0]]), + ) + tm.assert_frame_equal(df, expected) + + df = DataFrame([[1]], columns=[[1]], index=[[1, 2]]) + expected = DataFrame( + [1, 1], + index=MultiIndex(levels=[[1, 2]], codes=[[0, 1]]), + columns=MultiIndex(levels=[[1]], codes=[[0]]), + ) + tm.assert_frame_equal(df, expected) + + def test_constructor_error_msgs(self): + msg = "Empty data passed with indices specified." + # passing an empty array with columns specified. + with pytest.raises(ValueError, match=msg): + DataFrame(np.empty(0), index=[1]) + + msg = "Mixing dicts with non-Series may lead to ambiguous ordering." + # mix dict and array, wrong size + with pytest.raises(ValueError, match=msg): + DataFrame({"A": {"a": "a", "b": "b"}, "B": ["a", "b", "c"]}) + + # wrong size ndarray, GH 3105 + msg = r"Shape of passed values is \(4, 3\), indices imply \(3, 3\)" + with pytest.raises(ValueError, match=msg): + DataFrame( + np.arange(12).reshape((4, 3)), + columns=["foo", "bar", "baz"], + index=date_range("2000-01-01", periods=3), + ) + + arr = np.array([[4, 5, 6]]) + msg = r"Shape of passed values is \(1, 3\), indices imply \(1, 4\)" + with pytest.raises(ValueError, match=msg): + DataFrame(index=[0], columns=range(4), data=arr) + + arr = np.array([4, 5, 6]) + msg = r"Shape of passed values is \(3, 1\), indices imply \(1, 4\)" + with pytest.raises(ValueError, match=msg): + DataFrame(index=[0], columns=range(4), data=arr) + + # higher dim raise exception + with pytest.raises(ValueError, match="Must pass 2-d input"): + DataFrame(np.zeros((3, 3, 3)), columns=["A", "B", "C"], index=[1]) + + # wrong size axis labels + msg = r"Shape of passed values is \(2, 3\), indices imply \(1, 3\)" + with pytest.raises(ValueError, match=msg): + DataFrame( + np.random.default_rng(2).random((2, 3)), + columns=["A", "B", "C"], + index=[1], + ) + + msg = r"Shape of passed values is \(2, 3\), indices imply \(2, 2\)" + with pytest.raises(ValueError, match=msg): + DataFrame( + np.random.default_rng(2).random((2, 3)), + columns=["A", "B"], + index=[1, 2], + ) + + # gh-26429 + msg = "2 columns passed, passed data had 10 columns" + with pytest.raises(ValueError, match=msg): + DataFrame((range(10), range(10, 20)), columns=("ones", "twos")) + + msg = "If using all scalar values, you must pass an index" + with pytest.raises(ValueError, match=msg): + DataFrame({"a": False, "b": True}) + + def test_constructor_subclass_dict(self, dict_subclass): + # Test for passing dict subclass to constructor + data = { + "col1": dict_subclass((x, 10.0 * x) for x in range(10)), + "col2": dict_subclass((x, 20.0 * x) for x in range(10)), + } + df = DataFrame(data) + refdf = DataFrame({col: dict(val.items()) for col, val in data.items()}) + tm.assert_frame_equal(refdf, df) + + data = dict_subclass(data.items()) + df = DataFrame(data) + tm.assert_frame_equal(refdf, df) + + def test_constructor_defaultdict(self, float_frame): + # try with defaultdict + data = {} + float_frame.loc[: float_frame.index[10], "B"] = np.nan + + for k, v in float_frame.items(): + dct = defaultdict(dict) + dct.update(v.to_dict()) + data[k] = dct + frame = DataFrame(data) + expected = frame.reindex(index=float_frame.index) + tm.assert_frame_equal(float_frame, expected) + + def test_constructor_dict_block(self): + expected = np.array([[4.0, 3.0, 2.0, 1.0]]) + df = DataFrame( + {"d": [4.0], "c": [3.0], "b": [2.0], "a": [1.0]}, + columns=["d", "c", "b", "a"], + ) + tm.assert_numpy_array_equal(df.values, expected) + + def test_constructor_dict_cast(self, using_infer_string): + # cast float tests + test_data = {"A": {"1": 1, "2": 2}, "B": {"1": "1", "2": "2", "3": "3"}} + frame = DataFrame(test_data, dtype=float) + assert len(frame) == 3 + assert frame["B"].dtype == np.float64 + assert frame["A"].dtype == np.float64 + + frame = DataFrame(test_data) + assert len(frame) == 3 + assert frame["B"].dtype == np.object_ if not using_infer_string else "string" + assert frame["A"].dtype == np.float64 + + def test_constructor_dict_cast2(self): + # can't cast to float + test_data = { + "A": dict(zip(range(20), [f"word_{i}" for i in range(20)])), + "B": dict(zip(range(15), np.random.default_rng(2).standard_normal(15))), + } + with pytest.raises(ValueError, match="could not convert string"): + DataFrame(test_data, dtype=float) + + def test_constructor_dict_dont_upcast(self): + d = {"Col1": {"Row1": "A String", "Row2": np.nan}} + df = DataFrame(d) + assert isinstance(df["Col1"]["Row2"], float) + + def test_constructor_dict_dont_upcast2(self): + dm = DataFrame([[1, 2], ["a", "b"]], index=[1, 2], columns=[1, 2]) + assert isinstance(dm[1][1], int) + + def test_constructor_dict_of_tuples(self): + # GH #1491 + data = {"a": (1, 2, 3), "b": (4, 5, 6)} + + result = DataFrame(data) + expected = DataFrame({k: list(v) for k, v in data.items()}) + tm.assert_frame_equal(result, expected, check_dtype=False) + + def test_constructor_dict_of_ranges(self): + # GH 26356 + data = {"a": range(3), "b": range(3, 6)} + + result = DataFrame(data) + expected = DataFrame({"a": [0, 1, 2], "b": [3, 4, 5]}) + tm.assert_frame_equal(result, expected) + + def test_constructor_dict_of_iterators(self): + # GH 26349 + data = {"a": iter(range(3)), "b": reversed(range(3))} + + result = DataFrame(data) + expected = DataFrame({"a": [0, 1, 2], "b": [2, 1, 0]}) + tm.assert_frame_equal(result, expected) + + def test_constructor_dict_of_generators(self): + # GH 26349 + data = {"a": (i for i in (range(3))), "b": (i for i in reversed(range(3)))} + result = DataFrame(data) + expected = DataFrame({"a": [0, 1, 2], "b": [2, 1, 0]}) + tm.assert_frame_equal(result, expected) + + def test_constructor_dict_multiindex(self): + d = { + ("a", "a"): {("i", "i"): 0, ("i", "j"): 1, ("j", "i"): 2}, + ("b", "a"): {("i", "i"): 6, ("i", "j"): 5, ("j", "i"): 4}, + ("b", "c"): {("i", "i"): 7, ("i", "j"): 8, ("j", "i"): 9}, + } + _d = sorted(d.items()) + df = DataFrame(d) + expected = DataFrame( + [x[1] for x in _d], index=MultiIndex.from_tuples([x[0] for x in _d]) + ).T + expected.index = MultiIndex.from_tuples(expected.index) + tm.assert_frame_equal( + df, + expected, + ) + + d["z"] = {"y": 123.0, ("i", "i"): 111, ("i", "j"): 111, ("j", "i"): 111} + _d.insert(0, ("z", d["z"])) + expected = DataFrame( + [x[1] for x in _d], index=Index([x[0] for x in _d], tupleize_cols=False) + ).T + expected.index = Index(expected.index, tupleize_cols=False) + df = DataFrame(d) + df = df.reindex(columns=expected.columns, index=expected.index) + tm.assert_frame_equal(df, expected) + + def test_constructor_dict_datetime64_index(self): + # GH 10160 + dates_as_str = ["1984-02-19", "1988-11-06", "1989-12-03", "1990-03-15"] + + def create_data(constructor): + return {i: {constructor(s): 2 * i} for i, s in enumerate(dates_as_str)} + + data_datetime64 = create_data(np.datetime64) + data_datetime = create_data(lambda x: datetime.strptime(x, "%Y-%m-%d")) + data_Timestamp = create_data(Timestamp) + + expected = DataFrame( + [ + {0: 0, 1: None, 2: None, 3: None}, + {0: None, 1: 2, 2: None, 3: None}, + {0: None, 1: None, 2: 4, 3: None}, + {0: None, 1: None, 2: None, 3: 6}, + ], + index=[Timestamp(dt) for dt in dates_as_str], + ) + + result_datetime64 = DataFrame(data_datetime64) + result_datetime = DataFrame(data_datetime) + result_Timestamp = DataFrame(data_Timestamp) + tm.assert_frame_equal(result_datetime64, expected) + tm.assert_frame_equal(result_datetime, expected) + tm.assert_frame_equal(result_Timestamp, expected) + + @pytest.mark.parametrize( + "klass,name", + [ + (lambda x: np.timedelta64(x, "D"), "timedelta64"), + (lambda x: timedelta(days=x), "pytimedelta"), + (lambda x: Timedelta(x, "D"), "Timedelta[ns]"), + (lambda x: Timedelta(x, "D").as_unit("s"), "Timedelta[s]"), + ], + ) + def test_constructor_dict_timedelta64_index(self, klass, name): + # GH 10160 + td_as_int = [1, 2, 3, 4] + + data = {i: {klass(s): 2 * i} for i, s in enumerate(td_as_int)} + + expected = DataFrame( + [ + {0: 0, 1: None, 2: None, 3: None}, + {0: None, 1: 2, 2: None, 3: None}, + {0: None, 1: None, 2: 4, 3: None}, + {0: None, 1: None, 2: None, 3: 6}, + ], + index=[Timedelta(td, "D") for td in td_as_int], + ) + + result = DataFrame(data) + + tm.assert_frame_equal(result, expected) + + def test_constructor_period_dict(self): + # PeriodIndex + a = pd.PeriodIndex(["2012-01", "NaT", "2012-04"], freq="M") + b = pd.PeriodIndex(["2012-02-01", "2012-03-01", "NaT"], freq="D") + df = DataFrame({"a": a, "b": b}) + assert df["a"].dtype == a.dtype + assert df["b"].dtype == b.dtype + + # list of periods + df = DataFrame({"a": a.astype(object).tolist(), "b": b.astype(object).tolist()}) + assert df["a"].dtype == a.dtype + assert df["b"].dtype == b.dtype + + def test_constructor_dict_extension_scalar(self, ea_scalar_and_dtype): + ea_scalar, ea_dtype = ea_scalar_and_dtype + df = DataFrame({"a": ea_scalar}, index=[0]) + assert df["a"].dtype == ea_dtype + + expected = DataFrame(index=[0], columns=["a"], data=ea_scalar) + + tm.assert_frame_equal(df, expected) + + @pytest.mark.parametrize( + "data,dtype", + [ + (Period("2020-01"), PeriodDtype("M")), + (Interval(left=0, right=5), IntervalDtype("int64", "right")), + ( + Timestamp("2011-01-01", tz="US/Eastern"), + DatetimeTZDtype(unit="s", tz="US/Eastern"), + ), + ], + ) + def test_constructor_extension_scalar_data(self, data, dtype): + # GH 34832 + df = DataFrame(index=[0, 1], columns=["a", "b"], data=data) + + assert df["a"].dtype == dtype + assert df["b"].dtype == dtype + + arr = pd.array([data] * 2, dtype=dtype) + expected = DataFrame({"a": arr, "b": arr}) + + tm.assert_frame_equal(df, expected) + + def test_nested_dict_frame_constructor(self): + rng = pd.period_range("1/1/2000", periods=5) + df = DataFrame(np.random.default_rng(2).standard_normal((10, 5)), columns=rng) + + data = {} + for col in df.columns: + for row in df.index: + data.setdefault(col, {})[row] = df._get_value(row, col) + + result = DataFrame(data, columns=rng) + tm.assert_frame_equal(result, df) + + data = {} + for col in df.columns: + for row in df.index: + data.setdefault(row, {})[col] = df._get_value(row, col) + + result = DataFrame(data, index=rng).T + tm.assert_frame_equal(result, df) + + def _check_basic_constructor(self, empty): + # mat: 2d matrix with shape (3, 2) to input. empty - makes sized + # objects + mat = empty((2, 3), dtype=float) + # 2-D input + frame = DataFrame(mat, columns=["A", "B", "C"], index=[1, 2]) + + assert len(frame.index) == 2 + assert len(frame.columns) == 3 + + # 1-D input + frame = DataFrame(empty((3,)), columns=["A"], index=[1, 2, 3]) + assert len(frame.index) == 3 + assert len(frame.columns) == 1 + + if empty is not np.ones: + msg = r"Cannot convert non-finite values \(NA or inf\) to integer" + with pytest.raises(IntCastingNaNError, match=msg): + DataFrame(mat, columns=["A", "B", "C"], index=[1, 2], dtype=np.int64) + return + else: + frame = DataFrame( + mat, columns=["A", "B", "C"], index=[1, 2], dtype=np.int64 + ) + assert frame.values.dtype == np.int64 + + # wrong size axis labels + msg = r"Shape of passed values is \(2, 3\), indices imply \(1, 3\)" + with pytest.raises(ValueError, match=msg): + DataFrame(mat, columns=["A", "B", "C"], index=[1]) + msg = r"Shape of passed values is \(2, 3\), indices imply \(2, 2\)" + with pytest.raises(ValueError, match=msg): + DataFrame(mat, columns=["A", "B"], index=[1, 2]) + + # higher dim raise exception + with pytest.raises(ValueError, match="Must pass 2-d input"): + DataFrame(empty((3, 3, 3)), columns=["A", "B", "C"], index=[1]) + + # automatic labeling + frame = DataFrame(mat) + tm.assert_index_equal(frame.index, Index(range(2)), exact=True) + tm.assert_index_equal(frame.columns, Index(range(3)), exact=True) + + frame = DataFrame(mat, index=[1, 2]) + tm.assert_index_equal(frame.columns, Index(range(3)), exact=True) + + frame = DataFrame(mat, columns=["A", "B", "C"]) + tm.assert_index_equal(frame.index, Index(range(2)), exact=True) + + # 0-length axis + frame = DataFrame(empty((0, 3))) + assert len(frame.index) == 0 + + frame = DataFrame(empty((3, 0))) + assert len(frame.columns) == 0 + + def test_constructor_ndarray(self): + self._check_basic_constructor(np.ones) + + frame = DataFrame(["foo", "bar"], index=[0, 1], columns=["A"]) + assert len(frame) == 2 + + def test_constructor_maskedarray(self): + self._check_basic_constructor(ma.masked_all) + + # Check non-masked values + mat = ma.masked_all((2, 3), dtype=float) + mat[0, 0] = 1.0 + mat[1, 2] = 2.0 + frame = DataFrame(mat, columns=["A", "B", "C"], index=[1, 2]) + assert 1.0 == frame["A"][1] + assert 2.0 == frame["C"][2] + + # what is this even checking?? + mat = ma.masked_all((2, 3), dtype=float) + frame = DataFrame(mat, columns=["A", "B", "C"], index=[1, 2]) + assert np.all(~np.asarray(frame == frame)) + + @pytest.mark.filterwarnings( + "ignore:elementwise comparison failed:DeprecationWarning" + ) + def test_constructor_maskedarray_nonfloat(self): + # masked int promoted to float + mat = ma.masked_all((2, 3), dtype=int) + # 2-D input + frame = DataFrame(mat, columns=["A", "B", "C"], index=[1, 2]) + + assert len(frame.index) == 2 + assert len(frame.columns) == 3 + assert np.all(~np.asarray(frame == frame)) + + # cast type + frame = DataFrame(mat, columns=["A", "B", "C"], index=[1, 2], dtype=np.float64) + assert frame.values.dtype == np.float64 + + # Check non-masked values + mat2 = ma.copy(mat) + mat2[0, 0] = 1 + mat2[1, 2] = 2 + frame = DataFrame(mat2, columns=["A", "B", "C"], index=[1, 2]) + assert 1 == frame["A"][1] + assert 2 == frame["C"][2] + + # masked np.datetime64 stays (use NaT as null) + mat = ma.masked_all((2, 3), dtype="M8[ns]") + # 2-D input + frame = DataFrame(mat, columns=["A", "B", "C"], index=[1, 2]) + + assert len(frame.index) == 2 + assert len(frame.columns) == 3 + assert isna(frame).values.all() + + # cast type + msg = r"datetime64\[ns\] values and dtype=int64 is not supported" + with pytest.raises(TypeError, match=msg): + DataFrame(mat, columns=["A", "B", "C"], index=[1, 2], dtype=np.int64) + + # Check non-masked values + mat2 = ma.copy(mat) + mat2[0, 0] = 1 + mat2[1, 2] = 2 + frame = DataFrame(mat2, columns=["A", "B", "C"], index=[1, 2]) + assert 1 == frame["A"].astype("i8")[1] + assert 2 == frame["C"].astype("i8")[2] + + # masked bool promoted to object + mat = ma.masked_all((2, 3), dtype=bool) + # 2-D input + frame = DataFrame(mat, columns=["A", "B", "C"], index=[1, 2]) + + assert len(frame.index) == 2 + assert len(frame.columns) == 3 + assert np.all(~np.asarray(frame == frame)) + + # cast type + frame = DataFrame(mat, columns=["A", "B", "C"], index=[1, 2], dtype=object) + assert frame.values.dtype == object + + # Check non-masked values + mat2 = ma.copy(mat) + mat2[0, 0] = True + mat2[1, 2] = False + frame = DataFrame(mat2, columns=["A", "B", "C"], index=[1, 2]) + assert frame["A"][1] is True + assert frame["C"][2] is False + + def test_constructor_maskedarray_hardened(self): + # Check numpy masked arrays with hard masks -- from GH24574 + mat_hard = ma.masked_all((2, 2), dtype=float).harden_mask() + result = DataFrame(mat_hard, columns=["A", "B"], index=[1, 2]) + expected = DataFrame( + {"A": [np.nan, np.nan], "B": [np.nan, np.nan]}, + columns=["A", "B"], + index=[1, 2], + dtype=float, + ) + tm.assert_frame_equal(result, expected) + # Check case where mask is hard but no data are masked + mat_hard = ma.ones((2, 2), dtype=float).harden_mask() + result = DataFrame(mat_hard, columns=["A", "B"], index=[1, 2]) + expected = DataFrame( + {"A": [1.0, 1.0], "B": [1.0, 1.0]}, + columns=["A", "B"], + index=[1, 2], + dtype=float, + ) + tm.assert_frame_equal(result, expected) + + def test_constructor_maskedrecarray_dtype(self): + # Ensure constructor honors dtype + data = np.ma.array( + np.ma.zeros(5, dtype=[("date", " None: + self._lst = lst + + def __getitem__(self, n): + return self._lst.__getitem__(n) + + def __len__(self) -> int: + return self._lst.__len__() + + lst_containers = [DummyContainer([1, "a"]), DummyContainer([2, "b"])] + columns = ["num", "str"] + result = DataFrame(lst_containers, columns=columns) + expected = DataFrame([[1, "a"], [2, "b"]], columns=columns) + tm.assert_frame_equal(result, expected, check_dtype=False) + + def test_constructor_stdlib_array(self): + # GH 4297 + # support Array + result = DataFrame({"A": array.array("i", range(10))}) + expected = DataFrame({"A": list(range(10))}) + tm.assert_frame_equal(result, expected, check_dtype=False) + + expected = DataFrame([list(range(10)), list(range(10))]) + result = DataFrame([array.array("i", range(10)), array.array("i", range(10))]) + tm.assert_frame_equal(result, expected, check_dtype=False) + + def test_constructor_range(self): + # GH26342 + result = DataFrame(range(10)) + expected = DataFrame(list(range(10))) + tm.assert_frame_equal(result, expected) + + def test_constructor_list_of_ranges(self): + result = DataFrame([range(10), range(10)]) + expected = DataFrame([list(range(10)), list(range(10))]) + tm.assert_frame_equal(result, expected) + + def test_constructor_iterable(self): + # GH 21987 + class Iter: + def __iter__(self) -> Iterator: + for i in range(10): + yield [1, 2, 3] + + expected = DataFrame([[1, 2, 3]] * 10) + result = DataFrame(Iter()) + tm.assert_frame_equal(result, expected) + + def test_constructor_iterator(self): + result = DataFrame(iter(range(10))) + expected = DataFrame(list(range(10))) + tm.assert_frame_equal(result, expected) + + def test_constructor_list_of_iterators(self): + result = DataFrame([iter(range(10)), iter(range(10))]) + expected = DataFrame([list(range(10)), list(range(10))]) + tm.assert_frame_equal(result, expected) + + def test_constructor_generator(self): + # related #2305 + + gen1 = (i for i in range(10)) + gen2 = (i for i in range(10)) + + expected = DataFrame([list(range(10)), list(range(10))]) + result = DataFrame([gen1, gen2]) + tm.assert_frame_equal(result, expected) + + gen = ([i, "a"] for i in range(10)) + result = DataFrame(gen) + expected = DataFrame({0: range(10), 1: "a"}) + tm.assert_frame_equal(result, expected, check_dtype=False) + + def test_constructor_list_of_dicts(self): + result = DataFrame([{}]) + expected = DataFrame(index=RangeIndex(1), columns=[]) + tm.assert_frame_equal(result, expected) + + def test_constructor_ordered_dict_nested_preserve_order(self): + # see gh-18166 + nested1 = OrderedDict([("b", 1), ("a", 2)]) + nested2 = OrderedDict([("b", 2), ("a", 5)]) + data = OrderedDict([("col2", nested1), ("col1", nested2)]) + result = DataFrame(data) + data = {"col2": [1, 2], "col1": [2, 5]} + expected = DataFrame(data=data, index=["b", "a"]) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("dict_type", [dict, OrderedDict]) + def test_constructor_ordered_dict_preserve_order(self, dict_type): + # see gh-13304 + expected = DataFrame([[2, 1]], columns=["b", "a"]) + + data = dict_type() + data["b"] = [2] + data["a"] = [1] + + result = DataFrame(data) + tm.assert_frame_equal(result, expected) + + data = dict_type() + data["b"] = 2 + data["a"] = 1 + + result = DataFrame([data]) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("dict_type", [dict, OrderedDict]) + def test_constructor_ordered_dict_conflicting_orders(self, dict_type): + # the first dict element sets the ordering for the DataFrame, + # even if there are conflicting orders from subsequent ones + row_one = dict_type() + row_one["b"] = 2 + row_one["a"] = 1 + + row_two = dict_type() + row_two["a"] = 1 + row_two["b"] = 2 + + row_three = {"b": 2, "a": 1} + + expected = DataFrame([[2, 1], [2, 1]], columns=["b", "a"]) + result = DataFrame([row_one, row_two]) + tm.assert_frame_equal(result, expected) + + expected = DataFrame([[2, 1], [2, 1], [2, 1]], columns=["b", "a"]) + result = DataFrame([row_one, row_two, row_three]) + tm.assert_frame_equal(result, expected) + + def test_constructor_list_of_series_aligned_index(self): + series = [Series(i, index=["b", "a", "c"], name=str(i)) for i in range(3)] + result = DataFrame(series) + expected = DataFrame( + {"b": [0, 1, 2], "a": [0, 1, 2], "c": [0, 1, 2]}, + columns=["b", "a", "c"], + index=["0", "1", "2"], + ) + tm.assert_frame_equal(result, expected) + + def test_constructor_list_of_derived_dicts(self): + class CustomDict(dict): + pass + + d = {"a": 1.5, "b": 3} + + data_custom = [CustomDict(d)] + data = [d] + + result_custom = DataFrame(data_custom) + result = DataFrame(data) + tm.assert_frame_equal(result, result_custom) + + def test_constructor_ragged(self): + data = { + "A": np.random.default_rng(2).standard_normal(10), + "B": np.random.default_rng(2).standard_normal(8), + } + with pytest.raises(ValueError, match="All arrays must be of the same length"): + DataFrame(data) + + def test_constructor_scalar(self): + idx = Index(range(3)) + df = DataFrame({"a": 0}, index=idx) + expected = DataFrame({"a": [0, 0, 0]}, index=idx) + tm.assert_frame_equal(df, expected, check_dtype=False) + + def test_constructor_Series_copy_bug(self, float_frame): + df = DataFrame(float_frame["A"], index=float_frame.index, columns=["A"]) + df.copy() + + def test_constructor_mixed_dict_and_Series(self): + data = {} + data["A"] = {"foo": 1, "bar": 2, "baz": 3} + data["B"] = Series([4, 3, 2, 1], index=["bar", "qux", "baz", "foo"]) + + result = DataFrame(data) + assert result.index.is_monotonic_increasing + + # ordering ambiguous, raise exception + with pytest.raises(ValueError, match="ambiguous ordering"): + DataFrame({"A": ["a", "b"], "B": {"a": "a", "b": "b"}}) + + # this is OK though + result = DataFrame({"A": ["a", "b"], "B": Series(["a", "b"], index=["a", "b"])}) + expected = DataFrame({"A": ["a", "b"], "B": ["a", "b"]}, index=["a", "b"]) + tm.assert_frame_equal(result, expected) + + def test_constructor_mixed_type_rows(self): + # Issue 25075 + data = [[1, 2], (3, 4)] + result = DataFrame(data) + expected = DataFrame([[1, 2], [3, 4]]) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "tuples,lists", + [ + ((), []), + ((()), []), + (((), ()), [(), ()]), + (((), ()), [[], []]), + (([], []), [[], []]), + (([1], [2]), [[1], [2]]), # GH 32776 + (([1, 2, 3], [4, 5, 6]), [[1, 2, 3], [4, 5, 6]]), + ], + ) + def test_constructor_tuple(self, tuples, lists): + # GH 25691 + result = DataFrame(tuples) + expected = DataFrame(lists) + tm.assert_frame_equal(result, expected) + + def test_constructor_list_of_tuples(self): + result = DataFrame({"A": [(1, 2), (3, 4)]}) + expected = DataFrame({"A": Series([(1, 2), (3, 4)])}) + tm.assert_frame_equal(result, expected) + + def test_constructor_list_of_namedtuples(self): + # GH11181 + named_tuple = namedtuple("Pandas", list("ab")) + tuples = [named_tuple(1, 3), named_tuple(2, 4)] + expected = DataFrame({"a": [1, 2], "b": [3, 4]}) + result = DataFrame(tuples) + tm.assert_frame_equal(result, expected) + + # with columns + expected = DataFrame({"y": [1, 2], "z": [3, 4]}) + result = DataFrame(tuples, columns=["y", "z"]) + tm.assert_frame_equal(result, expected) + + def test_constructor_list_of_dataclasses(self): + # GH21910 + Point = make_dataclass("Point", [("x", int), ("y", int)]) + + data = [Point(0, 3), Point(1, 3)] + expected = DataFrame({"x": [0, 1], "y": [3, 3]}) + result = DataFrame(data) + tm.assert_frame_equal(result, expected) + + def test_constructor_list_of_dataclasses_with_varying_types(self): + # GH21910 + # varying types + Point = make_dataclass("Point", [("x", int), ("y", int)]) + HLine = make_dataclass("HLine", [("x0", int), ("x1", int), ("y", int)]) + + data = [Point(0, 3), HLine(1, 3, 3)] + + expected = DataFrame( + {"x": [0, np.nan], "y": [3, 3], "x0": [np.nan, 1], "x1": [np.nan, 3]} + ) + result = DataFrame(data) + tm.assert_frame_equal(result, expected) + + def test_constructor_list_of_dataclasses_error_thrown(self): + # GH21910 + Point = make_dataclass("Point", [("x", int), ("y", int)]) + + # expect TypeError + msg = "asdict() should be called on dataclass instances" + with pytest.raises(TypeError, match=re.escape(msg)): + DataFrame([Point(0, 0), {"x": 1, "y": 0}]) + + def test_constructor_list_of_dict_order(self): + # GH10056 + data = [ + {"First": 1, "Second": 4, "Third": 7, "Fourth": 10}, + {"Second": 5, "First": 2, "Fourth": 11, "Third": 8}, + {"Second": 6, "First": 3, "Fourth": 12, "Third": 9, "YYY": 14, "XXX": 13}, + ] + expected = DataFrame( + { + "First": [1, 2, 3], + "Second": [4, 5, 6], + "Third": [7, 8, 9], + "Fourth": [10, 11, 12], + "YYY": [None, None, 14], + "XXX": [None, None, 13], + } + ) + result = DataFrame(data) + tm.assert_frame_equal(result, expected) + + def test_constructor_Series_named(self): + a = Series([1, 2, 3], index=["a", "b", "c"], name="x") + df = DataFrame(a) + assert df.columns[0] == "x" + tm.assert_index_equal(df.index, a.index) + + # ndarray like + arr = np.random.default_rng(2).standard_normal(10) + s = Series(arr, name="x") + df = DataFrame(s) + expected = DataFrame({"x": s}) + tm.assert_frame_equal(df, expected) + + s = Series(arr, index=range(3, 13)) + df = DataFrame(s) + expected = DataFrame({0: s}) + tm.assert_frame_equal(df, expected) + + msg = r"Shape of passed values is \(10, 1\), indices imply \(10, 2\)" + with pytest.raises(ValueError, match=msg): + DataFrame(s, columns=[1, 2]) + + # #2234 + a = Series([], name="x", dtype=object) + df = DataFrame(a) + assert df.columns[0] == "x" + + # series with name and w/o + s1 = Series(arr, name="x") + df = DataFrame([s1, arr]).T + expected = DataFrame({"x": s1, "Unnamed 0": arr}, columns=["x", "Unnamed 0"]) + tm.assert_frame_equal(df, expected) + + # this is a bit non-intuitive here; the series collapse down to arrays + df = DataFrame([arr, s1]).T + expected = DataFrame({1: s1, 0: arr}, columns=[0, 1]) + tm.assert_frame_equal(df, expected) + + def test_constructor_Series_named_and_columns(self): + # GH 9232 validation + + s0 = Series(range(5), name=0) + s1 = Series(range(5), name=1) + + # matching name and column gives standard frame + tm.assert_frame_equal(DataFrame(s0, columns=[0]), s0.to_frame()) + tm.assert_frame_equal(DataFrame(s1, columns=[1]), s1.to_frame()) + + # non-matching produces empty frame + assert DataFrame(s0, columns=[1]).empty + assert DataFrame(s1, columns=[0]).empty + + def test_constructor_Series_differently_indexed(self): + # name + s1 = Series([1, 2, 3], index=["a", "b", "c"], name="x") + + # no name + s2 = Series([1, 2, 3], index=["a", "b", "c"]) + + other_index = Index(["a", "b"]) + + df1 = DataFrame(s1, index=other_index) + exp1 = DataFrame(s1.reindex(other_index)) + assert df1.columns[0] == "x" + tm.assert_frame_equal(df1, exp1) + + df2 = DataFrame(s2, index=other_index) + exp2 = DataFrame(s2.reindex(other_index)) + assert df2.columns[0] == 0 + tm.assert_index_equal(df2.index, other_index) + tm.assert_frame_equal(df2, exp2) + + @pytest.mark.parametrize( + "name_in1,name_in2,name_in3,name_out", + [ + ("idx", "idx", "idx", "idx"), + ("idx", "idx", None, None), + ("idx", None, None, None), + ("idx1", "idx2", None, None), + ("idx1", "idx1", "idx2", None), + ("idx1", "idx2", "idx3", None), + (None, None, None, None), + ], + ) + def test_constructor_index_names(self, name_in1, name_in2, name_in3, name_out): + # GH13475 + indices = [ + Index(["a", "b", "c"], name=name_in1), + Index(["b", "c", "d"], name=name_in2), + Index(["c", "d", "e"], name=name_in3), + ] + series = { + c: Series([0, 1, 2], index=i) for i, c in zip(indices, ["x", "y", "z"]) + } + result = DataFrame(series) + + exp_ind = Index(["a", "b", "c", "d", "e"], name=name_out) + expected = DataFrame( + { + "x": [0, 1, 2, np.nan, np.nan], + "y": [np.nan, 0, 1, 2, np.nan], + "z": [np.nan, np.nan, 0, 1, 2], + }, + index=exp_ind, + ) + + tm.assert_frame_equal(result, expected) + + def test_constructor_manager_resize(self, float_frame): + index = list(float_frame.index[:5]) + columns = list(float_frame.columns[:3]) + + msg = "Passing a BlockManager to DataFrame" + with tm.assert_produces_warning( + DeprecationWarning, match=msg, check_stacklevel=False + ): + result = DataFrame(float_frame._mgr, index=index, columns=columns) + tm.assert_index_equal(result.index, Index(index)) + tm.assert_index_equal(result.columns, Index(columns)) + + def test_constructor_mix_series_nonseries(self, float_frame): + df = DataFrame( + {"A": float_frame["A"], "B": list(float_frame["B"])}, columns=["A", "B"] + ) + tm.assert_frame_equal(df, float_frame.loc[:, ["A", "B"]]) + + msg = "does not match index length" + with pytest.raises(ValueError, match=msg): + DataFrame({"A": float_frame["A"], "B": list(float_frame["B"])[:-2]}) + + def test_constructor_miscast_na_int_dtype(self): + msg = r"Cannot convert non-finite values \(NA or inf\) to integer" + + with pytest.raises(IntCastingNaNError, match=msg): + DataFrame([[np.nan, 1], [1, 0]], dtype=np.int64) + + def test_constructor_column_duplicates(self): + # it works! #2079 + df = DataFrame([[8, 5]], columns=["a", "a"]) + edf = DataFrame([[8, 5]]) + edf.columns = ["a", "a"] + + tm.assert_frame_equal(df, edf) + + idf = DataFrame.from_records([(8, 5)], columns=["a", "a"]) + + tm.assert_frame_equal(idf, edf) + + def test_constructor_empty_with_string_dtype(self): + # GH 9428 + expected = DataFrame(index=[0, 1], columns=[0, 1], dtype=object) + + df = DataFrame(index=[0, 1], columns=[0, 1], dtype=str) + tm.assert_frame_equal(df, expected) + df = DataFrame(index=[0, 1], columns=[0, 1], dtype=np.str_) + tm.assert_frame_equal(df, expected) + df = DataFrame(index=[0, 1], columns=[0, 1], dtype="U5") + tm.assert_frame_equal(df, expected) + + def test_constructor_empty_with_string_extension(self, nullable_string_dtype): + # GH 34915 + expected = DataFrame(columns=["c1"], dtype=nullable_string_dtype) + df = DataFrame(columns=["c1"], dtype=nullable_string_dtype) + tm.assert_frame_equal(df, expected) + + def test_constructor_single_value(self): + # expecting single value upcasting here + df = DataFrame(0.0, index=[1, 2, 3], columns=["a", "b", "c"]) + tm.assert_frame_equal( + df, DataFrame(np.zeros(df.shape).astype("float64"), df.index, df.columns) + ) + + df = DataFrame(0, index=[1, 2, 3], columns=["a", "b", "c"]) + tm.assert_frame_equal( + df, DataFrame(np.zeros(df.shape).astype("int64"), df.index, df.columns) + ) + + df = DataFrame("a", index=[1, 2], columns=["a", "c"]) + tm.assert_frame_equal( + df, + DataFrame( + np.array([["a", "a"], ["a", "a"]], dtype=object), + index=[1, 2], + columns=["a", "c"], + ), + ) + + msg = "DataFrame constructor not properly called!" + with pytest.raises(ValueError, match=msg): + DataFrame("a", [1, 2]) + with pytest.raises(ValueError, match=msg): + DataFrame("a", columns=["a", "c"]) + + msg = "incompatible data and dtype" + with pytest.raises(TypeError, match=msg): + DataFrame("a", [1, 2], ["a", "c"], float) + + def test_constructor_with_datetimes(self, using_infer_string): + intname = np.dtype(int).name + floatname = np.dtype(np.float64).name + objectname = np.dtype(np.object_).name + + # single item + df = DataFrame( + { + "A": 1, + "B": "foo", + "C": "bar", + "D": Timestamp("20010101"), + "E": datetime(2001, 1, 2, 0, 0), + }, + index=np.arange(10), + ) + result = df.dtypes + expected = Series( + [np.dtype("int64")] + + [np.dtype(objectname) if not using_infer_string else "string"] * 2 + + [np.dtype("M8[s]"), np.dtype("M8[us]")], + index=list("ABCDE"), + ) + tm.assert_series_equal(result, expected) + + # check with ndarray construction ndim==0 (e.g. we are passing a ndim 0 + # ndarray with a dtype specified) + df = DataFrame( + { + "a": 1.0, + "b": 2, + "c": "foo", + floatname: np.array(1.0, dtype=floatname), + intname: np.array(1, dtype=intname), + }, + index=np.arange(10), + ) + result = df.dtypes + expected = Series( + [np.dtype("float64")] + + [np.dtype("int64")] + + [np.dtype("object") if not using_infer_string else "string"] + + [np.dtype("float64")] + + [np.dtype(intname)], + index=["a", "b", "c", floatname, intname], + ) + tm.assert_series_equal(result, expected) + + # check with ndarray construction ndim>0 + df = DataFrame( + { + "a": 1.0, + "b": 2, + "c": "foo", + floatname: np.array([1.0] * 10, dtype=floatname), + intname: np.array([1] * 10, dtype=intname), + }, + index=np.arange(10), + ) + result = df.dtypes + expected = Series( + [np.dtype("float64")] + + [np.dtype("int64")] + + [np.dtype("object") if not using_infer_string else "string"] + + [np.dtype("float64")] + + [np.dtype(intname)], + index=["a", "b", "c", floatname, intname], + ) + tm.assert_series_equal(result, expected) + + def test_constructor_with_datetimes1(self): + # GH 2809 + ind = date_range(start="2000-01-01", freq="D", periods=10) + datetimes = [ts.to_pydatetime() for ts in ind] + datetime_s = Series(datetimes) + assert datetime_s.dtype == "M8[ns]" + + def test_constructor_with_datetimes2(self): + # GH 2810 + ind = date_range(start="2000-01-01", freq="D", periods=10) + datetimes = [ts.to_pydatetime() for ts in ind] + dates = [ts.date() for ts in ind] + df = DataFrame(datetimes, columns=["datetimes"]) + df["dates"] = dates + result = df.dtypes + expected = Series( + [np.dtype("datetime64[ns]"), np.dtype("object")], + index=["datetimes", "dates"], + ) + tm.assert_series_equal(result, expected) + + def test_constructor_with_datetimes3(self): + # GH 7594 + # don't coerce tz-aware + tz = pytz.timezone("US/Eastern") + dt = tz.localize(datetime(2012, 1, 1)) + + df = DataFrame({"End Date": dt}, index=[0]) + assert df.iat[0, 0] == dt + tm.assert_series_equal( + df.dtypes, Series({"End Date": "datetime64[us, US/Eastern]"}, dtype=object) + ) + + df = DataFrame([{"End Date": dt}]) + assert df.iat[0, 0] == dt + tm.assert_series_equal( + df.dtypes, Series({"End Date": "datetime64[ns, US/Eastern]"}, dtype=object) + ) + + def test_constructor_with_datetimes4(self): + # tz-aware (UTC and other tz's) + # GH 8411 + dr = date_range("20130101", periods=3) + df = DataFrame({"value": dr}) + assert df.iat[0, 0].tz is None + dr = date_range("20130101", periods=3, tz="UTC") + df = DataFrame({"value": dr}) + assert str(df.iat[0, 0].tz) == "UTC" + dr = date_range("20130101", periods=3, tz="US/Eastern") + df = DataFrame({"value": dr}) + assert str(df.iat[0, 0].tz) == "US/Eastern" + + def test_constructor_with_datetimes5(self): + # GH 7822 + # preserver an index with a tz on dict construction + i = date_range("1/1/2011", periods=5, freq="10s", tz="US/Eastern") + + expected = DataFrame({"a": i.to_series().reset_index(drop=True)}) + df = DataFrame() + df["a"] = i + tm.assert_frame_equal(df, expected) + + df = DataFrame({"a": i}) + tm.assert_frame_equal(df, expected) + + def test_constructor_with_datetimes6(self): + # multiples + i = date_range("1/1/2011", periods=5, freq="10s", tz="US/Eastern") + i_no_tz = date_range("1/1/2011", periods=5, freq="10s") + df = DataFrame({"a": i, "b": i_no_tz}) + expected = DataFrame({"a": i.to_series().reset_index(drop=True), "b": i_no_tz}) + tm.assert_frame_equal(df, expected) + + @pytest.mark.parametrize( + "arr", + [ + np.array([None, None, None, None, datetime.now(), None]), + np.array([None, None, datetime.now(), None]), + [[np.datetime64("NaT")], [None]], + [[np.datetime64("NaT")], [pd.NaT]], + [[None], [np.datetime64("NaT")]], + [[None], [pd.NaT]], + [[pd.NaT], [np.datetime64("NaT")]], + [[pd.NaT], [None]], + ], + ) + def test_constructor_datetimes_with_nulls(self, arr): + # gh-15869, GH#11220 + result = DataFrame(arr).dtypes + expected = Series([np.dtype("datetime64[ns]")]) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("order", ["K", "A", "C", "F"]) + @pytest.mark.parametrize( + "unit", + ["M", "D", "h", "m", "s", "ms", "us", "ns"], + ) + def test_constructor_datetimes_non_ns(self, order, unit): + dtype = f"datetime64[{unit}]" + na = np.array( + [ + ["2015-01-01", "2015-01-02", "2015-01-03"], + ["2017-01-01", "2017-01-02", "2017-02-03"], + ], + dtype=dtype, + order=order, + ) + df = DataFrame(na) + expected = DataFrame(na.astype("M8[ns]")) + if unit in ["M", "D", "h", "m"]: + with pytest.raises(TypeError, match="Cannot cast"): + expected.astype(dtype) + + # instead the constructor casts to the closest supported reso, i.e. "s" + expected = expected.astype("datetime64[s]") + else: + expected = expected.astype(dtype=dtype) + + tm.assert_frame_equal(df, expected) + + @pytest.mark.parametrize("order", ["K", "A", "C", "F"]) + @pytest.mark.parametrize( + "unit", + [ + "D", + "h", + "m", + "s", + "ms", + "us", + "ns", + ], + ) + def test_constructor_timedelta_non_ns(self, order, unit): + dtype = f"timedelta64[{unit}]" + na = np.array( + [ + [np.timedelta64(1, "D"), np.timedelta64(2, "D")], + [np.timedelta64(4, "D"), np.timedelta64(5, "D")], + ], + dtype=dtype, + order=order, + ) + df = DataFrame(na) + if unit in ["D", "h", "m"]: + # we get the nearest supported unit, i.e. "s" + exp_unit = "s" + else: + exp_unit = unit + exp_dtype = np.dtype(f"m8[{exp_unit}]") + expected = DataFrame( + [ + [Timedelta(1, "D"), Timedelta(2, "D")], + [Timedelta(4, "D"), Timedelta(5, "D")], + ], + dtype=exp_dtype, + ) + # TODO(2.0): ideally we should get the same 'expected' without passing + # dtype=exp_dtype. + tm.assert_frame_equal(df, expected) + + def test_constructor_for_list_with_dtypes(self, using_infer_string): + # test list of lists/ndarrays + df = DataFrame([np.arange(5) for x in range(5)]) + result = df.dtypes + expected = Series([np.dtype("int")] * 5) + tm.assert_series_equal(result, expected) + + df = DataFrame([np.array(np.arange(5), dtype="int32") for x in range(5)]) + result = df.dtypes + expected = Series([np.dtype("int32")] * 5) + tm.assert_series_equal(result, expected) + + # overflow issue? (we always expected int64 upcasting here) + df = DataFrame({"a": [2**31, 2**31 + 1]}) + assert df.dtypes.iloc[0] == np.dtype("int64") + + # GH #2751 (construction with no index specified), make sure we cast to + # platform values + df = DataFrame([1, 2]) + assert df.dtypes.iloc[0] == np.dtype("int64") + + df = DataFrame([1.0, 2.0]) + assert df.dtypes.iloc[0] == np.dtype("float64") + + df = DataFrame({"a": [1, 2]}) + assert df.dtypes.iloc[0] == np.dtype("int64") + + df = DataFrame({"a": [1.0, 2.0]}) + assert df.dtypes.iloc[0] == np.dtype("float64") + + df = DataFrame({"a": 1}, index=range(3)) + assert df.dtypes.iloc[0] == np.dtype("int64") + + df = DataFrame({"a": 1.0}, index=range(3)) + assert df.dtypes.iloc[0] == np.dtype("float64") + + # with object list + df = DataFrame( + { + "a": [1, 2, 4, 7], + "b": [1.2, 2.3, 5.1, 6.3], + "c": list("abcd"), + "d": [datetime(2000, 1, 1) for i in range(4)], + "e": [1.0, 2, 4.0, 7], + } + ) + result = df.dtypes + expected = Series( + [ + np.dtype("int64"), + np.dtype("float64"), + np.dtype("object") if not using_infer_string else "string", + np.dtype("datetime64[ns]"), + np.dtype("float64"), + ], + index=list("abcde"), + ) + tm.assert_series_equal(result, expected) + + def test_constructor_frame_copy(self, float_frame): + cop = DataFrame(float_frame, copy=True) + cop["A"] = 5 + assert (cop["A"] == 5).all() + assert not (float_frame["A"] == 5).all() + + def test_constructor_frame_shallow_copy(self, float_frame): + # constructing a DataFrame from DataFrame with copy=False should still + # give a "shallow" copy (share data, not attributes) + # https://github.com/pandas-dev/pandas/issues/49523 + orig = float_frame.copy() + cop = DataFrame(float_frame) + assert cop._mgr is not float_frame._mgr + # Overwriting index of copy doesn't change original + cop.index = np.arange(len(cop)) + tm.assert_frame_equal(float_frame, orig) + + def test_constructor_ndarray_copy( + self, float_frame, using_array_manager, using_copy_on_write + ): + if not using_array_manager: + arr = float_frame.values.copy() + df = DataFrame(arr) + + arr[5] = 5 + if using_copy_on_write: + assert not (df.values[5] == 5).all() + else: + assert (df.values[5] == 5).all() + + df = DataFrame(arr, copy=True) + arr[6] = 6 + assert not (df.values[6] == 6).all() + else: + arr = float_frame.values.copy() + # default: copy to ensure contiguous arrays + df = DataFrame(arr) + assert df._mgr.arrays[0].flags.c_contiguous + arr[0, 0] = 100 + assert df.iloc[0, 0] != 100 + + # manually specify copy=False + df = DataFrame(arr, copy=False) + assert not df._mgr.arrays[0].flags.c_contiguous + arr[0, 0] = 1000 + assert df.iloc[0, 0] == 1000 + + def test_constructor_series_copy(self, float_frame): + series = float_frame._series + + df = DataFrame({"A": series["A"]}, copy=True) + # TODO can be replaced with `df.loc[:, "A"] = 5` after deprecation about + # inplace mutation is enforced + df.loc[df.index[0] : df.index[-1], "A"] = 5 + + assert not (series["A"] == 5).all() + + @pytest.mark.parametrize( + "df", + [ + DataFrame([[1, 2, 3], [4, 5, 6]], index=[1, np.nan]), + DataFrame([[1, 2, 3], [4, 5, 6]], columns=[1.1, 2.2, np.nan]), + DataFrame([[0, 1, 2, 3], [4, 5, 6, 7]], columns=[np.nan, 1.1, 2.2, np.nan]), + DataFrame( + [[0.0, 1, 2, 3.0], [4, 5, 6, 7]], columns=[np.nan, 1.1, 2.2, np.nan] + ), + DataFrame([[0.0, 1, 2, 3.0], [4, 5, 6, 7]], columns=[np.nan, 1, 2, 2]), + ], + ) + def test_constructor_with_nas(self, df): + # GH 5016 + # na's in indices + # GH 21428 (non-unique columns) + + for i in range(len(df.columns)): + df.iloc[:, i] + + indexer = np.arange(len(df.columns))[isna(df.columns)] + + # No NaN found -> error + if len(indexer) == 0: + with pytest.raises(KeyError, match="^nan$"): + df.loc[:, np.nan] + # single nan should result in Series + elif len(indexer) == 1: + tm.assert_series_equal(df.iloc[:, indexer[0]], df.loc[:, np.nan]) + # multiple nans should result in DataFrame + else: + tm.assert_frame_equal(df.iloc[:, indexer], df.loc[:, np.nan]) + + def test_constructor_lists_to_object_dtype(self): + # from #1074 + d = DataFrame({"a": [np.nan, False]}) + assert d["a"].dtype == np.object_ + assert not d["a"][1] + + def test_constructor_ndarray_categorical_dtype(self): + cat = Categorical(["A", "B", "C"]) + arr = np.array(cat).reshape(-1, 1) + arr = np.broadcast_to(arr, (3, 4)) + + result = DataFrame(arr, dtype=cat.dtype) + + expected = DataFrame({0: cat, 1: cat, 2: cat, 3: cat}) + tm.assert_frame_equal(result, expected) + + def test_constructor_categorical(self): + # GH8626 + + # dict creation + df = DataFrame({"A": list("abc")}, dtype="category") + expected = Series(list("abc"), dtype="category", name="A") + tm.assert_series_equal(df["A"], expected) + + # to_frame + s = Series(list("abc"), dtype="category") + result = s.to_frame() + expected = Series(list("abc"), dtype="category", name=0) + tm.assert_series_equal(result[0], expected) + result = s.to_frame(name="foo") + expected = Series(list("abc"), dtype="category", name="foo") + tm.assert_series_equal(result["foo"], expected) + + # list-like creation + df = DataFrame(list("abc"), dtype="category") + expected = Series(list("abc"), dtype="category", name=0) + tm.assert_series_equal(df[0], expected) + + def test_construct_from_1item_list_of_categorical(self): + # pre-2.0 this behaved as DataFrame({0: cat}), in 2.0 we remove + # Categorical special case + # ndim != 1 + cat = Categorical(list("abc")) + df = DataFrame([cat]) + expected = DataFrame([cat.astype(object)]) + tm.assert_frame_equal(df, expected) + + def test_construct_from_list_of_categoricals(self): + # pre-2.0 this behaved as DataFrame({0: cat}), in 2.0 we remove + # Categorical special case + + df = DataFrame([Categorical(list("abc")), Categorical(list("abd"))]) + expected = DataFrame([["a", "b", "c"], ["a", "b", "d"]]) + tm.assert_frame_equal(df, expected) + + def test_from_nested_listlike_mixed_types(self): + # pre-2.0 this behaved as DataFrame({0: cat}), in 2.0 we remove + # Categorical special case + # mixed + df = DataFrame([Categorical(list("abc")), list("def")]) + expected = DataFrame([["a", "b", "c"], ["d", "e", "f"]]) + tm.assert_frame_equal(df, expected) + + def test_construct_from_listlikes_mismatched_lengths(self): + df = DataFrame([Categorical(list("abc")), Categorical(list("abdefg"))]) + expected = DataFrame([list("abc"), list("abdefg")]) + tm.assert_frame_equal(df, expected) + + def test_constructor_categorical_series(self): + items = [1, 2, 3, 1] + exp = Series(items).astype("category") + res = Series(items, dtype="category") + tm.assert_series_equal(res, exp) + + items = ["a", "b", "c", "a"] + exp = Series(items).astype("category") + res = Series(items, dtype="category") + tm.assert_series_equal(res, exp) + + # insert into frame with different index + # GH 8076 + index = date_range("20000101", periods=3) + expected = Series( + Categorical(values=[np.nan, np.nan, np.nan], categories=["a", "b", "c"]) + ) + expected.index = index + + expected = DataFrame({"x": expected}) + df = DataFrame({"x": Series(["a", "b", "c"], dtype="category")}, index=index) + tm.assert_frame_equal(df, expected) + + @pytest.mark.parametrize( + "dtype", + tm.ALL_NUMERIC_DTYPES + + tm.DATETIME64_DTYPES + + tm.TIMEDELTA64_DTYPES + + tm.BOOL_DTYPES, + ) + def test_check_dtype_empty_numeric_column(self, dtype): + # GH24386: Ensure dtypes are set correctly for an empty DataFrame. + # Empty DataFrame is generated via dictionary data with non-overlapping columns. + data = DataFrame({"a": [1, 2]}, columns=["b"], dtype=dtype) + + assert data.b.dtype == dtype + + @pytest.mark.parametrize( + "dtype", tm.STRING_DTYPES + tm.BYTES_DTYPES + tm.OBJECT_DTYPES + ) + def test_check_dtype_empty_string_column(self, request, dtype, using_array_manager): + # GH24386: Ensure dtypes are set correctly for an empty DataFrame. + # Empty DataFrame is generated via dictionary data with non-overlapping columns. + data = DataFrame({"a": [1, 2]}, columns=["b"], dtype=dtype) + + if using_array_manager and dtype in tm.BYTES_DTYPES: + # TODO(ArrayManager) astype to bytes dtypes does not yet give object dtype + td.mark_array_manager_not_yet_implemented(request) + + assert data.b.dtype.name == "object" + + def test_to_frame_with_falsey_names(self): + # GH 16114 + result = Series(name=0, dtype=object).to_frame().dtypes + expected = Series({0: object}) + tm.assert_series_equal(result, expected) + + result = DataFrame(Series(name=0, dtype=object)).dtypes + tm.assert_series_equal(result, expected) + + @pytest.mark.arm_slow + @pytest.mark.parametrize("dtype", [None, "uint8", "category"]) + def test_constructor_range_dtype(self, dtype): + expected = DataFrame({"A": [0, 1, 2, 3, 4]}, dtype=dtype or "int64") + + # GH 26342 + result = DataFrame(range(5), columns=["A"], dtype=dtype) + tm.assert_frame_equal(result, expected) + + # GH 16804 + result = DataFrame({"A": range(5)}, dtype=dtype) + tm.assert_frame_equal(result, expected) + + def test_frame_from_list_subclass(self): + # GH21226 + class List(list): + pass + + expected = DataFrame([[1, 2, 3], [4, 5, 6]]) + result = DataFrame(List([List([1, 2, 3]), List([4, 5, 6])])) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "extension_arr", + [ + Categorical(list("aabbc")), + SparseArray([1, np.nan, np.nan, np.nan]), + IntervalArray([Interval(0, 1), Interval(1, 5)]), + PeriodArray(pd.period_range(start="1/1/2017", end="1/1/2018", freq="M")), + ], + ) + def test_constructor_with_extension_array(self, extension_arr): + # GH11363 + expected = DataFrame(Series(extension_arr)) + result = DataFrame(extension_arr) + tm.assert_frame_equal(result, expected) + + def test_datetime_date_tuple_columns_from_dict(self): + # GH 10863 + v = date.today() + tup = v, v + result = DataFrame({tup: Series(range(3), index=range(3))}, columns=[tup]) + expected = DataFrame([0, 1, 2], columns=Index(Series([tup]))) + tm.assert_frame_equal(result, expected) + + def test_construct_with_two_categoricalindex_series(self): + # GH 14600 + s1 = Series([39, 6, 4], index=CategoricalIndex(["female", "male", "unknown"])) + s2 = Series( + [2, 152, 2, 242, 150], + index=CategoricalIndex(["f", "female", "m", "male", "unknown"]), + ) + result = DataFrame([s1, s2]) + expected = DataFrame( + np.array([[39, 6, 4, np.nan, np.nan], [152.0, 242.0, 150.0, 2.0, 2.0]]), + columns=["female", "male", "unknown", "f", "m"], + ) + tm.assert_frame_equal(result, expected) + + def test_constructor_series_nonexact_categoricalindex(self): + # GH 42424 + ser = Series(range(100)) + ser1 = cut(ser, 10).value_counts().head(5) + ser2 = cut(ser, 10).value_counts().tail(5) + result = DataFrame({"1": ser1, "2": ser2}) + index = CategoricalIndex( + [ + Interval(-0.099, 9.9, closed="right"), + Interval(9.9, 19.8, closed="right"), + Interval(19.8, 29.7, closed="right"), + Interval(29.7, 39.6, closed="right"), + Interval(39.6, 49.5, closed="right"), + Interval(49.5, 59.4, closed="right"), + Interval(59.4, 69.3, closed="right"), + Interval(69.3, 79.2, closed="right"), + Interval(79.2, 89.1, closed="right"), + Interval(89.1, 99, closed="right"), + ], + ordered=True, + ) + expected = DataFrame( + {"1": [10] * 5 + [np.nan] * 5, "2": [np.nan] * 5 + [10] * 5}, index=index + ) + tm.assert_frame_equal(expected, result) + + def test_from_M8_structured(self): + dates = [(datetime(2012, 9, 9, 0, 0), datetime(2012, 9, 8, 15, 10))] + arr = np.array(dates, dtype=[("Date", "M8[us]"), ("Forecasting", "M8[us]")]) + df = DataFrame(arr) + + assert df["Date"][0] == dates[0][0] + assert df["Forecasting"][0] == dates[0][1] + + s = Series(arr["Date"]) + assert isinstance(s[0], Timestamp) + assert s[0] == dates[0][0] + + def test_from_datetime_subclass(self): + # GH21142 Verify whether Datetime subclasses are also of dtype datetime + class DatetimeSubclass(datetime): + pass + + data = DataFrame({"datetime": [DatetimeSubclass(2020, 1, 1, 1, 1)]}) + assert data.datetime.dtype == "datetime64[ns]" + + def test_with_mismatched_index_length_raises(self): + # GH#33437 + dti = date_range("2016-01-01", periods=3, tz="US/Pacific") + msg = "Shape of passed values|Passed arrays should have the same length" + with pytest.raises(ValueError, match=msg): + DataFrame(dti, index=range(4)) + + def test_frame_ctor_datetime64_column(self): + rng = date_range("1/1/2000 00:00:00", "1/1/2000 1:59:50", freq="10s") + dates = np.asarray(rng) + + df = DataFrame( + {"A": np.random.default_rng(2).standard_normal(len(rng)), "B": dates} + ) + assert np.issubdtype(df["B"].dtype, np.dtype("M8[ns]")) + + def test_dataframe_constructor_infer_multiindex(self): + index_lists = [["a", "a", "b", "b"], ["x", "y", "x", "y"]] + + multi = DataFrame( + np.random.default_rng(2).standard_normal((4, 4)), + index=[np.array(x) for x in index_lists], + ) + assert isinstance(multi.index, MultiIndex) + assert not isinstance(multi.columns, MultiIndex) + + multi = DataFrame( + np.random.default_rng(2).standard_normal((4, 4)), columns=index_lists + ) + assert isinstance(multi.columns, MultiIndex) + + @pytest.mark.parametrize( + "input_vals", + [ + ([1, 2]), + (["1", "2"]), + (list(date_range("1/1/2011", periods=2, freq="h"))), + (list(date_range("1/1/2011", periods=2, freq="h", tz="US/Eastern"))), + ([Interval(left=0, right=5)]), + ], + ) + def test_constructor_list_str(self, input_vals, string_dtype): + # GH#16605 + # Ensure that data elements are converted to strings when + # dtype is str, 'str', or 'U' + + result = DataFrame({"A": input_vals}, dtype=string_dtype) + expected = DataFrame({"A": input_vals}).astype({"A": string_dtype}) + tm.assert_frame_equal(result, expected) + + def test_constructor_list_str_na(self, string_dtype): + result = DataFrame({"A": [1.0, 2.0, None]}, dtype=string_dtype) + expected = DataFrame({"A": ["1.0", "2.0", None]}, dtype=object) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("copy", [False, True]) + def test_dict_nocopy( + self, + request, + copy, + any_numeric_ea_dtype, + any_numpy_dtype, + using_array_manager, + using_copy_on_write, + ): + if ( + using_array_manager + and not copy + and any_numpy_dtype not in tm.STRING_DTYPES + tm.BYTES_DTYPES + ): + # TODO(ArrayManager) properly honor copy keyword for dict input + td.mark_array_manager_not_yet_implemented(request) + + a = np.array([1, 2], dtype=any_numpy_dtype) + b = np.array([3, 4], dtype=any_numpy_dtype) + if b.dtype.kind in ["S", "U"]: + # These get cast, making the checks below more cumbersome + pytest.skip(f"{b.dtype} get cast, making the checks below more cumbersome") + + c = pd.array([1, 2], dtype=any_numeric_ea_dtype) + c_orig = c.copy() + df = DataFrame({"a": a, "b": b, "c": c}, copy=copy) + + def get_base(obj): + if isinstance(obj, np.ndarray): + return obj.base + elif isinstance(obj.dtype, np.dtype): + # i.e. DatetimeArray, TimedeltaArray + return obj._ndarray.base + else: + raise TypeError + + def check_views(c_only: bool = False): + # written to work for either BlockManager or ArrayManager + + # Check that the underlying data behind df["c"] is still `c` + # after setting with iloc. Since we don't know which entry in + # df._mgr.arrays corresponds to df["c"], we just check that exactly + # one of these arrays is `c`. GH#38939 + assert sum(x is c for x in df._mgr.arrays) == 1 + if c_only: + # If we ever stop consolidating in setitem_with_indexer, + # this will become unnecessary. + return + + assert ( + sum( + get_base(x) is a + for x in df._mgr.arrays + if isinstance(x.dtype, np.dtype) + ) + == 1 + ) + assert ( + sum( + get_base(x) is b + for x in df._mgr.arrays + if isinstance(x.dtype, np.dtype) + ) + == 1 + ) + + if not copy: + # constructor preserves views + check_views() + + # TODO: most of the rest of this test belongs in indexing tests + if lib.is_np_dtype(df.dtypes.iloc[0], "fciuO"): + warn = None + else: + warn = FutureWarning + with tm.assert_produces_warning(warn, match="incompatible dtype"): + df.iloc[0, 0] = 0 + df.iloc[0, 1] = 0 + if not copy: + check_views(True) + + # FIXME(GH#35417): until GH#35417, iloc.setitem into EA values does not preserve + # view, so we have to check in the other direction + df.iloc[:, 2] = pd.array([45, 46], dtype=c.dtype) + assert df.dtypes.iloc[2] == c.dtype + if not copy and not using_copy_on_write: + check_views(True) + + if copy: + if a.dtype.kind == "M": + assert a[0] == a.dtype.type(1, "ns") + assert b[0] == b.dtype.type(3, "ns") + else: + assert a[0] == a.dtype.type(1) + assert b[0] == b.dtype.type(3) + # FIXME(GH#35417): enable after GH#35417 + assert c[0] == c_orig[0] # i.e. df.iloc[0, 2]=45 did *not* update c + elif not using_copy_on_write: + # TODO: we can call check_views if we stop consolidating + # in setitem_with_indexer + assert c[0] == 45 # i.e. df.iloc[0, 2]=45 *did* update c + # TODO: we can check b[0] == 0 if we stop consolidating in + # setitem_with_indexer (except for datetimelike?) + + def test_construct_from_dict_ea_series(self): + # GH#53744 - default of copy=True should also apply for Series with + # extension dtype + ser = Series([1, 2, 3], dtype="Int64") + df = DataFrame({"a": ser}) + assert not np.shares_memory(ser.values._data, df["a"].values._data) + + def test_from_series_with_name_with_columns(self): + # GH 7893 + result = DataFrame(Series(1, name="foo"), columns=["bar"]) + expected = DataFrame(columns=["bar"]) + tm.assert_frame_equal(result, expected) + + def test_nested_list_columns(self): + # GH 14467 + result = DataFrame( + [[1, 2, 3], [4, 5, 6]], columns=[["A", "A", "A"], ["a", "b", "c"]] + ) + expected = DataFrame( + [[1, 2, 3], [4, 5, 6]], + columns=MultiIndex.from_tuples([("A", "a"), ("A", "b"), ("A", "c")]), + ) + tm.assert_frame_equal(result, expected) + + def test_from_2d_object_array_of_periods_or_intervals(self): + # Period analogue to GH#26825 + pi = pd.period_range("2016-04-05", periods=3) + data = pi._data.astype(object).reshape(1, -1) + df = DataFrame(data) + assert df.shape == (1, 3) + assert (df.dtypes == pi.dtype).all() + assert (df == pi).all().all() + + ii = pd.IntervalIndex.from_breaks([3, 4, 5, 6]) + data2 = ii._data.astype(object).reshape(1, -1) + df2 = DataFrame(data2) + assert df2.shape == (1, 3) + assert (df2.dtypes == ii.dtype).all() + assert (df2 == ii).all().all() + + # mixed + data3 = np.r_[data, data2, data, data2].T + df3 = DataFrame(data3) + expected = DataFrame({0: pi, 1: ii, 2: pi, 3: ii}) + tm.assert_frame_equal(df3, expected) + + @pytest.mark.parametrize( + "col_a, col_b", + [ + ([[1], [2]], np.array([[1], [2]])), + (np.array([[1], [2]]), [[1], [2]]), + (np.array([[1], [2]]), np.array([[1], [2]])), + ], + ) + def test_error_from_2darray(self, col_a, col_b): + msg = "Per-column arrays must each be 1-dimensional" + with pytest.raises(ValueError, match=msg): + DataFrame({"a": col_a, "b": col_b}) + + def test_from_dict_with_missing_copy_false(self): + # GH#45369 filled columns should not be views of one another + df = DataFrame(index=[1, 2, 3], columns=["a", "b", "c"], copy=False) + assert not np.shares_memory(df["a"]._values, df["b"]._values) + + df.iloc[0, 0] = 0 + expected = DataFrame( + { + "a": [0, np.nan, np.nan], + "b": [np.nan, np.nan, np.nan], + "c": [np.nan, np.nan, np.nan], + }, + index=[1, 2, 3], + dtype=object, + ) + tm.assert_frame_equal(df, expected) + + def test_construction_empty_array_multi_column_raises(self): + # GH#46822 + msg = r"Shape of passed values is \(0, 1\), indices imply \(0, 2\)" + with pytest.raises(ValueError, match=msg): + DataFrame(data=np.array([]), columns=["a", "b"]) + + def test_construct_with_strings_and_none(self): + # GH#32218 + df = DataFrame(["1", "2", None], columns=["a"], dtype="str") + expected = DataFrame({"a": ["1", "2", None]}, dtype="str") + tm.assert_frame_equal(df, expected) + + def test_frame_string_inference(self): + # GH#54430 + pytest.importorskip("pyarrow") + dtype = "string[pyarrow_numpy]" + expected = DataFrame( + {"a": ["a", "b"]}, dtype=dtype, columns=Index(["a"], dtype=dtype) + ) + with pd.option_context("future.infer_string", True): + df = DataFrame({"a": ["a", "b"]}) + tm.assert_frame_equal(df, expected) + + expected = DataFrame( + {"a": ["a", "b"]}, + dtype=dtype, + columns=Index(["a"], dtype=dtype), + index=Index(["x", "y"], dtype=dtype), + ) + with pd.option_context("future.infer_string", True): + df = DataFrame({"a": ["a", "b"]}, index=["x", "y"]) + tm.assert_frame_equal(df, expected) + + expected = DataFrame( + {"a": ["a", 1]}, dtype="object", columns=Index(["a"], dtype=dtype) + ) + with pd.option_context("future.infer_string", True): + df = DataFrame({"a": ["a", 1]}) + tm.assert_frame_equal(df, expected) + + expected = DataFrame( + {"a": ["a", "b"]}, dtype="object", columns=Index(["a"], dtype=dtype) + ) + with pd.option_context("future.infer_string", True): + df = DataFrame({"a": ["a", "b"]}, dtype="object") + tm.assert_frame_equal(df, expected) + + def test_frame_string_inference_array_string_dtype(self): + # GH#54496 + pytest.importorskip("pyarrow") + dtype = "string[pyarrow_numpy]" + expected = DataFrame( + {"a": ["a", "b"]}, dtype=dtype, columns=Index(["a"], dtype=dtype) + ) + with pd.option_context("future.infer_string", True): + df = DataFrame({"a": np.array(["a", "b"])}) + tm.assert_frame_equal(df, expected) + + expected = DataFrame({0: ["a", "b"], 1: ["c", "d"]}, dtype=dtype) + with pd.option_context("future.infer_string", True): + df = DataFrame(np.array([["a", "c"], ["b", "d"]])) + tm.assert_frame_equal(df, expected) + + expected = DataFrame( + {"a": ["a", "b"], "b": ["c", "d"]}, + dtype=dtype, + columns=Index(["a", "b"], dtype=dtype), + ) + with pd.option_context("future.infer_string", True): + df = DataFrame(np.array([["a", "c"], ["b", "d"]]), columns=["a", "b"]) + tm.assert_frame_equal(df, expected) + + def test_frame_string_inference_block_dim(self): + # GH#55363 + pytest.importorskip("pyarrow") + with pd.option_context("future.infer_string", True): + df = DataFrame(np.array([["hello", "goodbye"], ["hello", "Hello"]])) + assert df._mgr.blocks[0].ndim == 2 + + def test_inference_on_pandas_objects(self): + # GH#56012 + idx = Index([Timestamp("2019-12-31")], dtype=object) + with tm.assert_produces_warning(FutureWarning, match="Dtype inference"): + result = DataFrame(idx, columns=["a"]) + assert result.dtypes.iloc[0] != np.object_ + result = DataFrame({"a": idx}) + assert result.dtypes.iloc[0] == np.object_ + + ser = Series([Timestamp("2019-12-31")], dtype=object) + + with tm.assert_produces_warning(FutureWarning, match="Dtype inference"): + result = DataFrame(ser, columns=["a"]) + assert result.dtypes.iloc[0] != np.object_ + result = DataFrame({"a": ser}) + assert result.dtypes.iloc[0] == np.object_ + + +class TestDataFrameConstructorIndexInference: + def test_frame_from_dict_of_series_overlapping_monthly_period_indexes(self): + rng1 = pd.period_range("1/1/1999", "1/1/2012", freq="M") + s1 = Series(np.random.default_rng(2).standard_normal(len(rng1)), rng1) + + rng2 = pd.period_range("1/1/1980", "12/1/2001", freq="M") + s2 = Series(np.random.default_rng(2).standard_normal(len(rng2)), rng2) + df = DataFrame({"s1": s1, "s2": s2}) + + exp = pd.period_range("1/1/1980", "1/1/2012", freq="M") + tm.assert_index_equal(df.index, exp) + + def test_frame_from_dict_with_mixed_tzaware_indexes(self): + # GH#44091 + dti = date_range("2016-01-01", periods=3) + + ser1 = Series(range(3), index=dti) + ser2 = Series(range(3), index=dti.tz_localize("UTC")) + ser3 = Series(range(3), index=dti.tz_localize("US/Central")) + ser4 = Series(range(3)) + + # no tz-naive, but we do have mixed tzs and a non-DTI + df1 = DataFrame({"A": ser2, "B": ser3, "C": ser4}) + exp_index = Index( + list(ser2.index) + list(ser3.index) + list(ser4.index), dtype=object + ) + tm.assert_index_equal(df1.index, exp_index) + + df2 = DataFrame({"A": ser2, "C": ser4, "B": ser3}) + exp_index3 = Index( + list(ser2.index) + list(ser4.index) + list(ser3.index), dtype=object + ) + tm.assert_index_equal(df2.index, exp_index3) + + df3 = DataFrame({"B": ser3, "A": ser2, "C": ser4}) + exp_index3 = Index( + list(ser3.index) + list(ser2.index) + list(ser4.index), dtype=object + ) + tm.assert_index_equal(df3.index, exp_index3) + + df4 = DataFrame({"C": ser4, "B": ser3, "A": ser2}) + exp_index4 = Index( + list(ser4.index) + list(ser3.index) + list(ser2.index), dtype=object + ) + tm.assert_index_equal(df4.index, exp_index4) + + # TODO: not clear if these raising is desired (no extant tests), + # but this is de facto behavior 2021-12-22 + msg = "Cannot join tz-naive with tz-aware DatetimeIndex" + with pytest.raises(TypeError, match=msg): + DataFrame({"A": ser2, "B": ser3, "C": ser4, "D": ser1}) + with pytest.raises(TypeError, match=msg): + DataFrame({"A": ser2, "B": ser3, "D": ser1}) + with pytest.raises(TypeError, match=msg): + DataFrame({"D": ser1, "A": ser2, "B": ser3}) + + @pytest.mark.parametrize( + "key_val, col_vals, col_type", + [ + ["3", ["3", "4"], "utf8"], + [3, [3, 4], "int8"], + ], + ) + def test_dict_data_arrow_column_expansion(self, key_val, col_vals, col_type): + # GH 53617 + pa = pytest.importorskip("pyarrow") + cols = pd.arrays.ArrowExtensionArray( + pa.array(col_vals, type=pa.dictionary(pa.int8(), getattr(pa, col_type)())) + ) + result = DataFrame({key_val: [1, 2]}, columns=cols) + expected = DataFrame([[1, np.nan], [2, np.nan]], columns=cols) + expected.isetitem(1, expected.iloc[:, 1].astype(object)) + tm.assert_frame_equal(result, expected) + + +class TestDataFrameConstructorWithDtypeCoercion: + def test_floating_values_integer_dtype(self): + # GH#40110 make DataFrame behavior with arraylike floating data and + # inty dtype match Series behavior + + arr = np.random.default_rng(2).standard_normal((10, 5)) + + # GH#49599 in 2.0 we raise instead of either + # a) silently ignoring dtype and returningfloat (the old Series behavior) or + # b) rounding (the old DataFrame behavior) + msg = "Trying to coerce float values to integers" + with pytest.raises(ValueError, match=msg): + DataFrame(arr, dtype="i8") + + df = DataFrame(arr.round(), dtype="i8") + assert (df.dtypes == "i8").all() + + # with NaNs, we go through a different path with a different warning + arr[0, 0] = np.nan + msg = r"Cannot convert non-finite values \(NA or inf\) to integer" + with pytest.raises(IntCastingNaNError, match=msg): + DataFrame(arr, dtype="i8") + with pytest.raises(IntCastingNaNError, match=msg): + Series(arr[0], dtype="i8") + # The future (raising) behavior matches what we would get via astype: + msg = r"Cannot convert non-finite values \(NA or inf\) to integer" + with pytest.raises(IntCastingNaNError, match=msg): + DataFrame(arr).astype("i8") + with pytest.raises(IntCastingNaNError, match=msg): + Series(arr[0]).astype("i8") + + +class TestDataFrameConstructorWithDatetimeTZ: + @pytest.mark.parametrize("tz", ["US/Eastern", "dateutil/US/Eastern"]) + def test_construction_preserves_tzaware_dtypes(self, tz): + # after GH#7822 + # these retain the timezones on dict construction + dr = date_range("2011/1/1", "2012/1/1", freq="W-FRI") + dr_tz = dr.tz_localize(tz) + df = DataFrame({"A": "foo", "B": dr_tz}, index=dr) + tz_expected = DatetimeTZDtype("ns", dr_tz.tzinfo) + assert df["B"].dtype == tz_expected + + # GH#2810 (with timezones) + datetimes_naive = [ts.to_pydatetime() for ts in dr] + datetimes_with_tz = [ts.to_pydatetime() for ts in dr_tz] + df = DataFrame({"dr": dr}) + df["dr_tz"] = dr_tz + df["datetimes_naive"] = datetimes_naive + df["datetimes_with_tz"] = datetimes_with_tz + result = df.dtypes + expected = Series( + [ + np.dtype("datetime64[ns]"), + DatetimeTZDtype(tz=tz), + np.dtype("datetime64[ns]"), + DatetimeTZDtype(tz=tz), + ], + index=["dr", "dr_tz", "datetimes_naive", "datetimes_with_tz"], + ) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("pydt", [True, False]) + def test_constructor_data_aware_dtype_naive(self, tz_aware_fixture, pydt): + # GH#25843, GH#41555, GH#33401 + tz = tz_aware_fixture + ts = Timestamp("2019", tz=tz) + if pydt: + ts = ts.to_pydatetime() + + msg = ( + "Cannot convert timezone-aware data to timezone-naive dtype. " + r"Use pd.Series\(values\).dt.tz_localize\(None\) instead." + ) + with pytest.raises(ValueError, match=msg): + DataFrame({0: [ts]}, dtype="datetime64[ns]") + + msg2 = "Cannot unbox tzaware Timestamp to tznaive dtype" + with pytest.raises(TypeError, match=msg2): + DataFrame({0: ts}, index=[0], dtype="datetime64[ns]") + + with pytest.raises(ValueError, match=msg): + DataFrame([ts], dtype="datetime64[ns]") + + with pytest.raises(ValueError, match=msg): + DataFrame(np.array([ts], dtype=object), dtype="datetime64[ns]") + + with pytest.raises(TypeError, match=msg2): + DataFrame(ts, index=[0], columns=[0], dtype="datetime64[ns]") + + with pytest.raises(ValueError, match=msg): + DataFrame([Series([ts])], dtype="datetime64[ns]") + + with pytest.raises(ValueError, match=msg): + DataFrame([[ts]], columns=[0], dtype="datetime64[ns]") + + def test_from_dict(self): + # 8260 + # support datetime64 with tz + + idx = Index(date_range("20130101", periods=3, tz="US/Eastern"), name="foo") + dr = date_range("20130110", periods=3) + + # construction + df = DataFrame({"A": idx, "B": dr}) + assert df["A"].dtype, "M8[ns, US/Eastern" + assert df["A"].name == "A" + tm.assert_series_equal(df["A"], Series(idx, name="A")) + tm.assert_series_equal(df["B"], Series(dr, name="B")) + + def test_from_index(self): + # from index + idx2 = date_range("20130101", periods=3, tz="US/Eastern", name="foo") + df2 = DataFrame(idx2) + tm.assert_series_equal(df2["foo"], Series(idx2, name="foo")) + df2 = DataFrame(Series(idx2)) + tm.assert_series_equal(df2["foo"], Series(idx2, name="foo")) + + idx2 = date_range("20130101", periods=3, tz="US/Eastern") + df2 = DataFrame(idx2) + tm.assert_series_equal(df2[0], Series(idx2, name=0)) + df2 = DataFrame(Series(idx2)) + tm.assert_series_equal(df2[0], Series(idx2, name=0)) + + def test_frame_dict_constructor_datetime64_1680(self): + dr = date_range("1/1/2012", periods=10) + s = Series(dr, index=dr) + + # it works! + DataFrame({"a": "foo", "b": s}, index=dr) + DataFrame({"a": "foo", "b": s.values}, index=dr) + + def test_frame_datetime64_mixed_index_ctor_1681(self): + dr = date_range("2011/1/1", "2012/1/1", freq="W-FRI") + ts = Series(dr) + + # it works! + d = DataFrame({"A": "foo", "B": ts}, index=dr) + assert d["B"].isna().all() + + def test_frame_timeseries_column(self): + # GH19157 + dr = date_range( + start="20130101T10:00:00", periods=3, freq="min", tz="US/Eastern" + ) + result = DataFrame(dr, columns=["timestamps"]) + expected = DataFrame( + { + "timestamps": [ + Timestamp("20130101T10:00:00", tz="US/Eastern"), + Timestamp("20130101T10:01:00", tz="US/Eastern"), + Timestamp("20130101T10:02:00", tz="US/Eastern"), + ] + } + ) + tm.assert_frame_equal(result, expected) + + def test_nested_dict_construction(self): + # GH22227 + columns = ["Nevada", "Ohio"] + pop = { + "Nevada": {2001: 2.4, 2002: 2.9}, + "Ohio": {2000: 1.5, 2001: 1.7, 2002: 3.6}, + } + result = DataFrame(pop, index=[2001, 2002, 2003], columns=columns) + expected = DataFrame( + [(2.4, 1.7), (2.9, 3.6), (np.nan, np.nan)], + columns=columns, + index=Index([2001, 2002, 2003]), + ) + tm.assert_frame_equal(result, expected) + + def test_from_tzaware_object_array(self): + # GH#26825 2D object array of tzaware timestamps should not raise + dti = date_range("2016-04-05 04:30", periods=3, tz="UTC") + data = dti._data.astype(object).reshape(1, -1) + df = DataFrame(data) + assert df.shape == (1, 3) + assert (df.dtypes == dti.dtype).all() + assert (df == dti).all().all() + + def test_from_tzaware_mixed_object_array(self): + # GH#26825 + arr = np.array( + [ + [ + Timestamp("2013-01-01 00:00:00"), + Timestamp("2013-01-02 00:00:00"), + Timestamp("2013-01-03 00:00:00"), + ], + [ + Timestamp("2013-01-01 00:00:00-0500", tz="US/Eastern"), + pd.NaT, + Timestamp("2013-01-03 00:00:00-0500", tz="US/Eastern"), + ], + [ + Timestamp("2013-01-01 00:00:00+0100", tz="CET"), + pd.NaT, + Timestamp("2013-01-03 00:00:00+0100", tz="CET"), + ], + ], + dtype=object, + ).T + res = DataFrame(arr, columns=["A", "B", "C"]) + + expected_dtypes = [ + "datetime64[ns]", + "datetime64[ns, US/Eastern]", + "datetime64[ns, CET]", + ] + assert (res.dtypes == expected_dtypes).all() + + def test_from_2d_ndarray_with_dtype(self): + # GH#12513 + array_dim2 = np.arange(10).reshape((5, 2)) + df = DataFrame(array_dim2, dtype="datetime64[ns, UTC]") + + expected = DataFrame(array_dim2).astype("datetime64[ns, UTC]") + tm.assert_frame_equal(df, expected) + + @pytest.mark.parametrize("typ", [set, frozenset]) + def test_construction_from_set_raises(self, typ): + # https://github.com/pandas-dev/pandas/issues/32582 + values = typ({1, 2, 3}) + msg = f"'{typ.__name__}' type is unordered" + with pytest.raises(TypeError, match=msg): + DataFrame({"a": values}) + + with pytest.raises(TypeError, match=msg): + Series(values) + + def test_construction_from_ndarray_datetimelike(self): + # ensure the underlying arrays are properly wrapped as EA when + # constructed from 2D ndarray + arr = np.arange(0, 12, dtype="datetime64[ns]").reshape(4, 3) + df = DataFrame(arr) + assert all(isinstance(arr, DatetimeArray) for arr in df._mgr.arrays) + + def test_construction_from_ndarray_with_eadtype_mismatched_columns(self): + arr = np.random.default_rng(2).standard_normal((10, 2)) + dtype = pd.array([2.0]).dtype + msg = r"len\(arrays\) must match len\(columns\)" + with pytest.raises(ValueError, match=msg): + DataFrame(arr, columns=["foo"], dtype=dtype) + + arr2 = pd.array([2.0, 3.0, 4.0]) + with pytest.raises(ValueError, match=msg): + DataFrame(arr2, columns=["foo", "bar"]) + + def test_columns_indexes_raise_on_sets(self): + # GH 47215 + data = [[1, 2, 3], [4, 5, 6]] + with pytest.raises(ValueError, match="index cannot be a set"): + DataFrame(data, index={"a", "b"}) + with pytest.raises(ValueError, match="columns cannot be a set"): + DataFrame(data, columns={"a", "b", "c"}) + + # TODO: make this not cast to object in pandas 3.0 + @pytest.mark.skipif( + not np_version_gt2, reason="StringDType only available in numpy 2 and above" + ) + @pytest.mark.parametrize( + "data", + [ + {"a": ["a", "b", "c"], "b": [1.0, 2.0, 3.0], "c": ["d", "e", "f"]}, + ], + ) + def test_np_string_array_object_cast(self, data): + from numpy.dtypes import StringDType + + data["a"] = np.array(data["a"], dtype=StringDType()) + res = DataFrame(data) + assert res["a"].dtype == np.object_ + assert (res["a"] == data["a"]).all() + + +def get1(obj): # TODO: make a helper in tm? + if isinstance(obj, Series): + return obj.iloc[0] + else: + return obj.iloc[0, 0] + + +class TestFromScalar: + @pytest.fixture(params=[list, dict, None]) + def box(self, request): + return request.param + + @pytest.fixture + def constructor(self, frame_or_series, box): + extra = {"index": range(2)} + if frame_or_series is DataFrame: + extra["columns"] = ["A"] + + if box is None: + return functools.partial(frame_or_series, **extra) + + elif box is dict: + if frame_or_series is Series: + return lambda x, **kwargs: frame_or_series( + {0: x, 1: x}, **extra, **kwargs + ) + else: + return lambda x, **kwargs: frame_or_series({"A": x}, **extra, **kwargs) + elif frame_or_series is Series: + return lambda x, **kwargs: frame_or_series([x, x], **extra, **kwargs) + else: + return lambda x, **kwargs: frame_or_series({"A": [x, x]}, **extra, **kwargs) + + @pytest.mark.parametrize("dtype", ["M8[ns]", "m8[ns]"]) + def test_from_nat_scalar(self, dtype, constructor): + obj = constructor(pd.NaT, dtype=dtype) + assert np.all(obj.dtypes == dtype) + assert np.all(obj.isna()) + + def test_from_timedelta_scalar_preserves_nanos(self, constructor): + td = Timedelta(1) + + obj = constructor(td, dtype="m8[ns]") + assert get1(obj) == td + + def test_from_timestamp_scalar_preserves_nanos(self, constructor, fixed_now_ts): + ts = fixed_now_ts + Timedelta(1) + + obj = constructor(ts, dtype="M8[ns]") + assert get1(obj) == ts + + def test_from_timedelta64_scalar_object(self, constructor): + td = Timedelta(1) + td64 = td.to_timedelta64() + + obj = constructor(td64, dtype=object) + assert isinstance(get1(obj), np.timedelta64) + + @pytest.mark.parametrize("cls", [np.datetime64, np.timedelta64]) + def test_from_scalar_datetimelike_mismatched(self, constructor, cls): + scalar = cls("NaT", "ns") + dtype = {np.datetime64: "m8[ns]", np.timedelta64: "M8[ns]"}[cls] + + if cls is np.datetime64: + msg1 = "Invalid type for timedelta scalar: " + else: + msg1 = " is not convertible to datetime" + msg = "|".join(["Cannot cast", msg1]) + + with pytest.raises(TypeError, match=msg): + constructor(scalar, dtype=dtype) + + scalar = cls(4, "ns") + with pytest.raises(TypeError, match=msg): + constructor(scalar, dtype=dtype) + + @pytest.mark.parametrize("cls", [datetime, np.datetime64]) + def test_from_out_of_bounds_ns_datetime( + self, constructor, cls, request, box, frame_or_series + ): + # scalar that won't fit in nanosecond dt64, but will fit in microsecond + if box is list or (frame_or_series is Series and box is dict): + mark = pytest.mark.xfail( + reason="Timestamp constructor has been updated to cast dt64 to " + "non-nano, but DatetimeArray._from_sequence has not", + strict=True, + ) + request.applymarker(mark) + + scalar = datetime(9999, 1, 1) + exp_dtype = "M8[us]" # pydatetime objects default to this reso + + if cls is np.datetime64: + scalar = np.datetime64(scalar, "D") + exp_dtype = "M8[s]" # closest reso to input + result = constructor(scalar) + + item = get1(result) + dtype = tm.get_dtype(result) + + assert type(item) is Timestamp + assert item.asm8.dtype == exp_dtype + assert dtype == exp_dtype + + @pytest.mark.skip_ubsan + def test_out_of_s_bounds_datetime64(self, constructor): + scalar = np.datetime64(np.iinfo(np.int64).max, "D") + result = constructor(scalar) + item = get1(result) + assert type(item) is np.datetime64 + dtype = tm.get_dtype(result) + assert dtype == object + + @pytest.mark.parametrize("cls", [timedelta, np.timedelta64]) + def test_from_out_of_bounds_ns_timedelta( + self, constructor, cls, request, box, frame_or_series + ): + # scalar that won't fit in nanosecond td64, but will fit in microsecond + if box is list or (frame_or_series is Series and box is dict): + mark = pytest.mark.xfail( + reason="TimedeltaArray constructor has been updated to cast td64 " + "to non-nano, but TimedeltaArray._from_sequence has not", + strict=True, + ) + request.applymarker(mark) + + scalar = datetime(9999, 1, 1) - datetime(1970, 1, 1) + exp_dtype = "m8[us]" # smallest reso that fits + if cls is np.timedelta64: + scalar = np.timedelta64(scalar, "D") + exp_dtype = "m8[s]" # closest reso to input + result = constructor(scalar) + + item = get1(result) + dtype = tm.get_dtype(result) + + assert type(item) is Timedelta + assert item.asm8.dtype == exp_dtype + assert dtype == exp_dtype + + @pytest.mark.skip_ubsan + @pytest.mark.parametrize("cls", [np.datetime64, np.timedelta64]) + def test_out_of_s_bounds_timedelta64(self, constructor, cls): + scalar = cls(np.iinfo(np.int64).max, "D") + result = constructor(scalar) + item = get1(result) + assert type(item) is cls + dtype = tm.get_dtype(result) + assert dtype == object + + def test_tzaware_data_tznaive_dtype(self, constructor, box, frame_or_series): + tz = "US/Eastern" + ts = Timestamp("2019", tz=tz) + + if box is None or (frame_or_series is DataFrame and box is dict): + msg = "Cannot unbox tzaware Timestamp to tznaive dtype" + err = TypeError + else: + msg = ( + "Cannot convert timezone-aware data to timezone-naive dtype. " + r"Use pd.Series\(values\).dt.tz_localize\(None\) instead." + ) + err = ValueError + + with pytest.raises(err, match=msg): + constructor(ts, dtype="M8[ns]") + + +# TODO: better location for this test? +class TestAllowNonNano: + # Until 2.0, we do not preserve non-nano dt64/td64 when passed as ndarray, + # but do preserve it when passed as DTA/TDA + + @pytest.fixture(params=[True, False]) + def as_td(self, request): + return request.param + + @pytest.fixture + def arr(self, as_td): + values = np.arange(5).astype(np.int64).view("M8[s]") + if as_td: + values = values - values[0] + return TimedeltaArray._simple_new(values, dtype=values.dtype) + else: + return DatetimeArray._simple_new(values, dtype=values.dtype) + + def test_index_allow_non_nano(self, arr): + idx = Index(arr) + assert idx.dtype == arr.dtype + + def test_dti_tdi_allow_non_nano(self, arr, as_td): + if as_td: + idx = pd.TimedeltaIndex(arr) + else: + idx = DatetimeIndex(arr) + assert idx.dtype == arr.dtype + + def test_series_allow_non_nano(self, arr): + ser = Series(arr) + assert ser.dtype == arr.dtype + + def test_frame_allow_non_nano(self, arr): + df = DataFrame(arr) + assert df.dtypes[0] == arr.dtype + + def test_frame_from_dict_allow_non_nano(self, arr): + df = DataFrame({0: arr}) + assert df.dtypes[0] == arr.dtype diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_cumulative.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_cumulative.py new file mode 100644 index 0000000000000000000000000000000000000000..5bd9c426123159fcfcf6bf5289fd08a60dfd91b2 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_cumulative.py @@ -0,0 +1,81 @@ +""" +Tests for DataFrame cumulative operations + +See also +-------- +tests.series.test_cumulative +""" + +import numpy as np +import pytest + +from pandas import ( + DataFrame, + Series, +) +import pandas._testing as tm + + +class TestDataFrameCumulativeOps: + # --------------------------------------------------------------------- + # Cumulative Operations - cumsum, cummax, ... + + def test_cumulative_ops_smoke(self): + # it works + df = DataFrame({"A": np.arange(20)}, index=np.arange(20)) + df.cummax() + df.cummin() + df.cumsum() + + dm = DataFrame(np.arange(20).reshape(4, 5), index=range(4), columns=range(5)) + # TODO(wesm): do something with this? + dm.cumsum() + + def test_cumprod_smoke(self, datetime_frame): + datetime_frame.iloc[5:10, 0] = np.nan + datetime_frame.iloc[10:15, 1] = np.nan + datetime_frame.iloc[15:, 2] = np.nan + + # ints + df = datetime_frame.fillna(0).astype(int) + df.cumprod(0) + df.cumprod(1) + + # ints32 + df = datetime_frame.fillna(0).astype(np.int32) + df.cumprod(0) + df.cumprod(1) + + @pytest.mark.parametrize("method", ["cumsum", "cumprod", "cummin", "cummax"]) + def test_cumulative_ops_match_series_apply(self, datetime_frame, method): + datetime_frame.iloc[5:10, 0] = np.nan + datetime_frame.iloc[10:15, 1] = np.nan + datetime_frame.iloc[15:, 2] = np.nan + + # axis = 0 + result = getattr(datetime_frame, method)() + expected = datetime_frame.apply(getattr(Series, method)) + tm.assert_frame_equal(result, expected) + + # axis = 1 + result = getattr(datetime_frame, method)(axis=1) + expected = datetime_frame.apply(getattr(Series, method), axis=1) + tm.assert_frame_equal(result, expected) + + # fix issue TODO: GH ref? + assert np.shape(result) == np.shape(datetime_frame) + + def test_cumsum_preserve_dtypes(self): + # GH#19296 dont incorrectly upcast to object + df = DataFrame({"A": [1, 2, 3], "B": [1, 2, 3.0], "C": [True, False, False]}) + + result = df.cumsum() + + expected = DataFrame( + { + "A": Series([1, 3, 6], dtype=np.int64), + "B": Series([1, 3, 6], dtype=np.float64), + "C": df["C"].cumsum(), + } + ) + tm.assert_frame_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_iteration.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_iteration.py new file mode 100644 index 0000000000000000000000000000000000000000..a1c23ff05f3e19aca490444216ec295453483e80 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_iteration.py @@ -0,0 +1,160 @@ +import datetime + +import numpy as np +import pytest + +from pandas.compat import ( + IS64, + is_platform_windows, +) + +from pandas import ( + Categorical, + DataFrame, + Series, + date_range, +) +import pandas._testing as tm + + +class TestIteration: + def test_keys(self, float_frame): + assert float_frame.keys() is float_frame.columns + + def test_iteritems(self): + df = DataFrame([[1, 2, 3], [4, 5, 6]], columns=["a", "a", "b"]) + for k, v in df.items(): + assert isinstance(v, DataFrame._constructor_sliced) + + def test_items(self): + # GH#17213, GH#13918 + cols = ["a", "b", "c"] + df = DataFrame([[1, 2, 3], [4, 5, 6]], columns=cols) + for c, (k, v) in zip(cols, df.items()): + assert c == k + assert isinstance(v, Series) + assert (df[k] == v).all() + + def test_items_names(self, float_string_frame): + for k, v in float_string_frame.items(): + assert v.name == k + + def test_iter(self, float_frame): + assert list(float_frame) == list(float_frame.columns) + + def test_iterrows(self, float_frame, float_string_frame): + for k, v in float_frame.iterrows(): + exp = float_frame.loc[k] + tm.assert_series_equal(v, exp) + + for k, v in float_string_frame.iterrows(): + exp = float_string_frame.loc[k] + tm.assert_series_equal(v, exp) + + def test_iterrows_iso8601(self): + # GH#19671 + s = DataFrame( + { + "non_iso8601": ["M1701", "M1802", "M1903", "M2004"], + "iso8601": date_range("2000-01-01", periods=4, freq="ME"), + } + ) + for k, v in s.iterrows(): + exp = s.loc[k] + tm.assert_series_equal(v, exp) + + def test_iterrows_corner(self): + # GH#12222 + df = DataFrame( + { + "a": [datetime.datetime(2015, 1, 1)], + "b": [None], + "c": [None], + "d": [""], + "e": [[]], + "f": [set()], + "g": [{}], + } + ) + expected = Series( + [datetime.datetime(2015, 1, 1), None, None, "", [], set(), {}], + index=list("abcdefg"), + name=0, + dtype="object", + ) + _, result = next(df.iterrows()) + tm.assert_series_equal(result, expected) + + def test_itertuples(self, float_frame): + for i, tup in enumerate(float_frame.itertuples()): + ser = DataFrame._constructor_sliced(tup[1:]) + ser.name = tup[0] + expected = float_frame.iloc[i, :].reset_index(drop=True) + tm.assert_series_equal(ser, expected) + + def test_itertuples_index_false(self): + df = DataFrame( + {"floats": np.random.default_rng(2).standard_normal(5), "ints": range(5)}, + columns=["floats", "ints"], + ) + + for tup in df.itertuples(index=False): + assert isinstance(tup[1], int) + + def test_itertuples_duplicate_cols(self): + df = DataFrame(data={"a": [1, 2, 3], "b": [4, 5, 6]}) + dfaa = df[["a", "a"]] + + assert list(dfaa.itertuples()) == [(0, 1, 1), (1, 2, 2), (2, 3, 3)] + + # repr with int on 32-bit/windows + if not (is_platform_windows() or not IS64): + assert ( + repr(list(df.itertuples(name=None))) + == "[(0, 1, 4), (1, 2, 5), (2, 3, 6)]" + ) + + def test_itertuples_tuple_name(self): + df = DataFrame(data={"a": [1, 2, 3], "b": [4, 5, 6]}) + tup = next(df.itertuples(name="TestName")) + assert tup._fields == ("Index", "a", "b") + assert (tup.Index, tup.a, tup.b) == tup + assert type(tup).__name__ == "TestName" + + def test_itertuples_disallowed_col_labels(self): + df = DataFrame(data={"def": [1, 2, 3], "return": [4, 5, 6]}) + tup2 = next(df.itertuples(name="TestName")) + assert tup2 == (0, 1, 4) + assert tup2._fields == ("Index", "_1", "_2") + + @pytest.mark.parametrize("limit", [254, 255, 1024]) + @pytest.mark.parametrize("index", [True, False]) + def test_itertuples_py2_3_field_limit_namedtuple(self, limit, index): + # GH#28282 + df = DataFrame([{f"foo_{i}": f"bar_{i}" for i in range(limit)}]) + result = next(df.itertuples(index=index)) + assert isinstance(result, tuple) + assert hasattr(result, "_fields") + + def test_sequence_like_with_categorical(self): + # GH#7839 + # make sure can iterate + df = DataFrame( + {"id": [1, 2, 3, 4, 5, 6], "raw_grade": ["a", "b", "b", "a", "a", "e"]} + ) + df["grade"] = Categorical(df["raw_grade"]) + + # basic sequencing testing + result = list(df.grade.values) + expected = np.array(df.grade.values).tolist() + tm.assert_almost_equal(result, expected) + + # iteration + for t in df.itertuples(index=False): + str(t) + + for row, s in df.iterrows(): + str(s) + + for c, col in df.items(): + str(col) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_logical_ops.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_logical_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..16ca3a202f1e0242104d9516b4dd2318599cefef --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_logical_ops.py @@ -0,0 +1,218 @@ +import operator +import re + +import numpy as np +import pytest + +from pandas import ( + CategoricalIndex, + DataFrame, + Interval, + Series, + isnull, +) +import pandas._testing as tm + + +class TestDataFrameLogicalOperators: + # &, |, ^ + + @pytest.mark.parametrize( + "left, right, op, expected", + [ + ( + [True, False, np.nan], + [True, False, True], + operator.and_, + [True, False, False], + ), + ( + [True, False, True], + [True, False, np.nan], + operator.and_, + [True, False, False], + ), + ( + [True, False, np.nan], + [True, False, True], + operator.or_, + [True, False, False], + ), + ( + [True, False, True], + [True, False, np.nan], + operator.or_, + [True, False, True], + ), + ], + ) + def test_logical_operators_nans(self, left, right, op, expected, frame_or_series): + # GH#13896 + result = op(frame_or_series(left), frame_or_series(right)) + expected = frame_or_series(expected) + + tm.assert_equal(result, expected) + + def test_logical_ops_empty_frame(self): + # GH#5808 + # empty frames, non-mixed dtype + df = DataFrame(index=[1]) + + result = df & df + tm.assert_frame_equal(result, df) + + result = df | df + tm.assert_frame_equal(result, df) + + df2 = DataFrame(index=[1, 2]) + result = df & df2 + tm.assert_frame_equal(result, df2) + + dfa = DataFrame(index=[1], columns=["A"]) + + result = dfa & dfa + expected = DataFrame(False, index=[1], columns=["A"]) + tm.assert_frame_equal(result, expected) + + def test_logical_ops_bool_frame(self): + # GH#5808 + df1a_bool = DataFrame(True, index=[1], columns=["A"]) + + result = df1a_bool & df1a_bool + tm.assert_frame_equal(result, df1a_bool) + + result = df1a_bool | df1a_bool + tm.assert_frame_equal(result, df1a_bool) + + def test_logical_ops_int_frame(self): + # GH#5808 + df1a_int = DataFrame(1, index=[1], columns=["A"]) + df1a_bool = DataFrame(True, index=[1], columns=["A"]) + + result = df1a_int | df1a_bool + tm.assert_frame_equal(result, df1a_bool) + + # Check that this matches Series behavior + res_ser = df1a_int["A"] | df1a_bool["A"] + tm.assert_series_equal(res_ser, df1a_bool["A"]) + + def test_logical_ops_invalid(self, using_infer_string): + # GH#5808 + + df1 = DataFrame(1.0, index=[1], columns=["A"]) + df2 = DataFrame(True, index=[1], columns=["A"]) + msg = re.escape("unsupported operand type(s) for |: 'float' and 'bool'") + with pytest.raises(TypeError, match=msg): + df1 | df2 + + df1 = DataFrame("foo", index=[1], columns=["A"]) + df2 = DataFrame(True, index=[1], columns=["A"]) + msg = re.escape("unsupported operand type(s) for |: 'str' and 'bool'") + if using_infer_string: + import pyarrow as pa + + with pytest.raises(pa.lib.ArrowNotImplementedError, match="|has no kernel"): + df1 | df2 + else: + with pytest.raises(TypeError, match=msg): + df1 | df2 + + def test_logical_operators(self): + def _check_bin_op(op): + result = op(df1, df2) + expected = DataFrame( + op(df1.values, df2.values), index=df1.index, columns=df1.columns + ) + assert result.values.dtype == np.bool_ + tm.assert_frame_equal(result, expected) + + def _check_unary_op(op): + result = op(df1) + expected = DataFrame(op(df1.values), index=df1.index, columns=df1.columns) + assert result.values.dtype == np.bool_ + tm.assert_frame_equal(result, expected) + + df1 = { + "a": {"a": True, "b": False, "c": False, "d": True, "e": True}, + "b": {"a": False, "b": True, "c": False, "d": False, "e": False}, + "c": {"a": False, "b": False, "c": True, "d": False, "e": False}, + "d": {"a": True, "b": False, "c": False, "d": True, "e": True}, + "e": {"a": True, "b": False, "c": False, "d": True, "e": True}, + } + + df2 = { + "a": {"a": True, "b": False, "c": True, "d": False, "e": False}, + "b": {"a": False, "b": True, "c": False, "d": False, "e": False}, + "c": {"a": True, "b": False, "c": True, "d": False, "e": False}, + "d": {"a": False, "b": False, "c": False, "d": True, "e": False}, + "e": {"a": False, "b": False, "c": False, "d": False, "e": True}, + } + + df1 = DataFrame(df1) + df2 = DataFrame(df2) + + _check_bin_op(operator.and_) + _check_bin_op(operator.or_) + _check_bin_op(operator.xor) + + _check_unary_op(operator.inv) # TODO: belongs elsewhere + + @pytest.mark.filterwarnings("ignore:Downcasting object dtype arrays:FutureWarning") + def test_logical_with_nas(self): + d = DataFrame({"a": [np.nan, False], "b": [True, True]}) + + # GH4947 + # bool comparisons should return bool + result = d["a"] | d["b"] + expected = Series([False, True]) + tm.assert_series_equal(result, expected) + + # GH4604, automatic casting here + result = d["a"].fillna(False) | d["b"] + expected = Series([True, True]) + tm.assert_series_equal(result, expected) + + msg = "The 'downcast' keyword in fillna is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = d["a"].fillna(False, downcast=False) | d["b"] + expected = Series([True, True]) + tm.assert_series_equal(result, expected) + + def test_logical_ops_categorical_columns(self): + # GH#38367 + intervals = [Interval(1, 2), Interval(3, 4)] + data = DataFrame( + [[1, np.nan], [2, np.nan]], + columns=CategoricalIndex( + intervals, categories=intervals + [Interval(5, 6)] + ), + ) + mask = DataFrame( + [[False, False], [False, False]], columns=data.columns, dtype=bool + ) + result = mask | isnull(data) + expected = DataFrame( + [[False, True], [False, True]], + columns=CategoricalIndex( + intervals, categories=intervals + [Interval(5, 6)] + ), + ) + tm.assert_frame_equal(result, expected) + + def test_int_dtype_different_index_not_bool(self): + # GH 52500 + df1 = DataFrame([1, 2, 3], index=[10, 11, 23], columns=["a"]) + df2 = DataFrame([10, 20, 30], index=[11, 10, 23], columns=["a"]) + result = np.bitwise_xor(df1, df2) + expected = DataFrame([21, 8, 29], index=[10, 11, 23], columns=["a"]) + tm.assert_frame_equal(result, expected) + + result = df1 ^ df2 + tm.assert_frame_equal(result, expected) + + def test_different_dtypes_different_index_raises(self): + # GH 52538 + df1 = DataFrame([1, 2], index=["a", "b"]) + df2 = DataFrame([3, 4], index=["b", "c"]) + with pytest.raises(TypeError, match="unsupported operand type"): + df1 & df2 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_nonunique_indexes.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_nonunique_indexes.py new file mode 100644 index 0000000000000000000000000000000000000000..34f172e900ab7e16d66afe32aed3d5b87be301c3 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_nonunique_indexes.py @@ -0,0 +1,337 @@ +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + DataFrame, + Series, + date_range, +) +import pandas._testing as tm + + +class TestDataFrameNonuniqueIndexes: + def test_setattr_columns_vs_construct_with_columns(self): + # assignment + # GH 3687 + arr = np.random.default_rng(2).standard_normal((3, 2)) + idx = list(range(2)) + df = DataFrame(arr, columns=["A", "A"]) + df.columns = idx + expected = DataFrame(arr, columns=idx) + tm.assert_frame_equal(df, expected) + + def test_setattr_columns_vs_construct_with_columns_datetimeindx(self): + idx = date_range("20130101", periods=4, freq="QE-NOV") + df = DataFrame( + [[1, 1, 1, 5], [1, 1, 2, 5], [2, 1, 3, 5]], columns=["a", "a", "a", "a"] + ) + df.columns = idx + expected = DataFrame([[1, 1, 1, 5], [1, 1, 2, 5], [2, 1, 3, 5]], columns=idx) + tm.assert_frame_equal(df, expected) + + def test_insert_with_duplicate_columns(self): + # insert + df = DataFrame( + [[1, 1, 1, 5], [1, 1, 2, 5], [2, 1, 3, 5]], + columns=["foo", "bar", "foo", "hello"], + ) + df["string"] = "bah" + expected = DataFrame( + [[1, 1, 1, 5, "bah"], [1, 1, 2, 5, "bah"], [2, 1, 3, 5, "bah"]], + columns=["foo", "bar", "foo", "hello", "string"], + ) + tm.assert_frame_equal(df, expected) + with pytest.raises(ValueError, match="Length of value"): + df.insert(0, "AnotherColumn", range(len(df.index) - 1)) + + # insert same dtype + df["foo2"] = 3 + expected = DataFrame( + [[1, 1, 1, 5, "bah", 3], [1, 1, 2, 5, "bah", 3], [2, 1, 3, 5, "bah", 3]], + columns=["foo", "bar", "foo", "hello", "string", "foo2"], + ) + tm.assert_frame_equal(df, expected) + + # set (non-dup) + df["foo2"] = 4 + expected = DataFrame( + [[1, 1, 1, 5, "bah", 4], [1, 1, 2, 5, "bah", 4], [2, 1, 3, 5, "bah", 4]], + columns=["foo", "bar", "foo", "hello", "string", "foo2"], + ) + tm.assert_frame_equal(df, expected) + df["foo2"] = 3 + + # delete (non dup) + del df["bar"] + expected = DataFrame( + [[1, 1, 5, "bah", 3], [1, 2, 5, "bah", 3], [2, 3, 5, "bah", 3]], + columns=["foo", "foo", "hello", "string", "foo2"], + ) + tm.assert_frame_equal(df, expected) + + # try to delete again (its not consolidated) + del df["hello"] + expected = DataFrame( + [[1, 1, "bah", 3], [1, 2, "bah", 3], [2, 3, "bah", 3]], + columns=["foo", "foo", "string", "foo2"], + ) + tm.assert_frame_equal(df, expected) + + # consolidate + df = df._consolidate() + expected = DataFrame( + [[1, 1, "bah", 3], [1, 2, "bah", 3], [2, 3, "bah", 3]], + columns=["foo", "foo", "string", "foo2"], + ) + tm.assert_frame_equal(df, expected) + + # insert + df.insert(2, "new_col", 5.0) + expected = DataFrame( + [[1, 1, 5.0, "bah", 3], [1, 2, 5.0, "bah", 3], [2, 3, 5.0, "bah", 3]], + columns=["foo", "foo", "new_col", "string", "foo2"], + ) + tm.assert_frame_equal(df, expected) + + # insert a dup + with pytest.raises(ValueError, match="cannot insert"): + df.insert(2, "new_col", 4.0) + + df.insert(2, "new_col", 4.0, allow_duplicates=True) + expected = DataFrame( + [ + [1, 1, 4.0, 5.0, "bah", 3], + [1, 2, 4.0, 5.0, "bah", 3], + [2, 3, 4.0, 5.0, "bah", 3], + ], + columns=["foo", "foo", "new_col", "new_col", "string", "foo2"], + ) + tm.assert_frame_equal(df, expected) + + # delete (dup) + del df["foo"] + expected = DataFrame( + [[4.0, 5.0, "bah", 3], [4.0, 5.0, "bah", 3], [4.0, 5.0, "bah", 3]], + columns=["new_col", "new_col", "string", "foo2"], + ) + tm.assert_frame_equal(df, expected) + + def test_dup_across_dtypes(self): + # dup across dtypes + df = DataFrame( + [[1, 1, 1.0, 5], [1, 1, 2.0, 5], [2, 1, 3.0, 5]], + columns=["foo", "bar", "foo", "hello"], + ) + + df["foo2"] = 7.0 + expected = DataFrame( + [[1, 1, 1.0, 5, 7.0], [1, 1, 2.0, 5, 7.0], [2, 1, 3.0, 5, 7.0]], + columns=["foo", "bar", "foo", "hello", "foo2"], + ) + tm.assert_frame_equal(df, expected) + + result = df["foo"] + expected = DataFrame([[1, 1.0], [1, 2.0], [2, 3.0]], columns=["foo", "foo"]) + tm.assert_frame_equal(result, expected) + + # multiple replacements + df["foo"] = "string" + expected = DataFrame( + [ + ["string", 1, "string", 5, 7.0], + ["string", 1, "string", 5, 7.0], + ["string", 1, "string", 5, 7.0], + ], + columns=["foo", "bar", "foo", "hello", "foo2"], + ) + tm.assert_frame_equal(df, expected) + + del df["foo"] + expected = DataFrame( + [[1, 5, 7.0], [1, 5, 7.0], [1, 5, 7.0]], columns=["bar", "hello", "foo2"] + ) + tm.assert_frame_equal(df, expected) + + def test_column_dups_indexes(self): + # check column dups with index equal and not equal to df's index + df = DataFrame( + np.random.default_rng(2).standard_normal((5, 3)), + index=["a", "b", "c", "d", "e"], + columns=["A", "B", "A"], + ) + for index in [df.index, pd.Index(list("edcba"))]: + this_df = df.copy() + expected_ser = Series(index.values, index=this_df.index) + expected_df = DataFrame( + {"A": expected_ser, "B": this_df["B"]}, + columns=["A", "B", "A"], + ) + this_df["A"] = index + tm.assert_frame_equal(this_df, expected_df) + + def test_changing_dtypes_with_duplicate_columns(self): + # multiple assignments that change dtypes + # the location indexer is a slice + # GH 6120 + df = DataFrame( + np.random.default_rng(2).standard_normal((5, 2)), columns=["that", "that"] + ) + expected = DataFrame(1.0, index=range(5), columns=["that", "that"]) + + df["that"] = 1.0 + tm.assert_frame_equal(df, expected) + + df = DataFrame( + np.random.default_rng(2).random((5, 2)), columns=["that", "that"] + ) + expected = DataFrame(1, index=range(5), columns=["that", "that"]) + + df["that"] = 1 + tm.assert_frame_equal(df, expected) + + def test_dup_columns_comparisons(self): + # equality + df1 = DataFrame([[1, 2], [2, np.nan], [3, 4], [4, 4]], columns=["A", "B"]) + df2 = DataFrame([[0, 1], [2, 4], [2, np.nan], [4, 5]], columns=["A", "A"]) + + # not-comparing like-labelled + msg = ( + r"Can only compare identically-labeled \(both index and columns\) " + "DataFrame objects" + ) + with pytest.raises(ValueError, match=msg): + df1 == df2 + + df1r = df1.reindex_like(df2) + result = df1r == df2 + expected = DataFrame( + [[False, True], [True, False], [False, False], [True, False]], + columns=["A", "A"], + ) + tm.assert_frame_equal(result, expected) + + def test_mixed_column_selection(self): + # mixed column selection + # GH 5639 + dfbool = DataFrame( + { + "one": Series([True, True, False], index=["a", "b", "c"]), + "two": Series([False, False, True, False], index=["a", "b", "c", "d"]), + "three": Series([False, True, True, True], index=["a", "b", "c", "d"]), + } + ) + expected = pd.concat([dfbool["one"], dfbool["three"], dfbool["one"]], axis=1) + result = dfbool[["one", "three", "one"]] + tm.assert_frame_equal(result, expected) + + def test_multi_axis_dups(self): + # multi-axis dups + # GH 6121 + df = DataFrame( + np.arange(25.0).reshape(5, 5), + index=["a", "b", "c", "d", "e"], + columns=["A", "B", "C", "D", "E"], + ) + z = df[["A", "C", "A"]].copy() + expected = z.loc[["a", "c", "a"]] + + df = DataFrame( + np.arange(25.0).reshape(5, 5), + index=["a", "b", "c", "d", "e"], + columns=["A", "B", "C", "D", "E"], + ) + z = df[["A", "C", "A"]] + result = z.loc[["a", "c", "a"]] + tm.assert_frame_equal(result, expected) + + def test_columns_with_dups(self): + # GH 3468 related + + # basic + df = DataFrame([[1, 2]], columns=["a", "a"]) + df.columns = ["a", "a.1"] + expected = DataFrame([[1, 2]], columns=["a", "a.1"]) + tm.assert_frame_equal(df, expected) + + df = DataFrame([[1, 2, 3]], columns=["b", "a", "a"]) + df.columns = ["b", "a", "a.1"] + expected = DataFrame([[1, 2, 3]], columns=["b", "a", "a.1"]) + tm.assert_frame_equal(df, expected) + + def test_columns_with_dup_index(self): + # with a dup index + df = DataFrame([[1, 2]], columns=["a", "a"]) + df.columns = ["b", "b"] + expected = DataFrame([[1, 2]], columns=["b", "b"]) + tm.assert_frame_equal(df, expected) + + def test_multi_dtype(self): + # multi-dtype + df = DataFrame( + [[1, 2, 1.0, 2.0, 3.0, "foo", "bar"]], + columns=["a", "a", "b", "b", "d", "c", "c"], + ) + df.columns = list("ABCDEFG") + expected = DataFrame( + [[1, 2, 1.0, 2.0, 3.0, "foo", "bar"]], columns=list("ABCDEFG") + ) + tm.assert_frame_equal(df, expected) + + def test_multi_dtype2(self): + df = DataFrame([[1, 2, "foo", "bar"]], columns=["a", "a", "a", "a"]) + df.columns = ["a", "a.1", "a.2", "a.3"] + expected = DataFrame([[1, 2, "foo", "bar"]], columns=["a", "a.1", "a.2", "a.3"]) + tm.assert_frame_equal(df, expected) + + def test_dups_across_blocks(self, using_array_manager): + # dups across blocks + df_float = DataFrame( + np.random.default_rng(2).standard_normal((10, 3)), dtype="float64" + ) + df_int = DataFrame( + np.random.default_rng(2).standard_normal((10, 3)).astype("int64") + ) + df_bool = DataFrame(True, index=df_float.index, columns=df_float.columns) + df_object = DataFrame("foo", index=df_float.index, columns=df_float.columns) + df_dt = DataFrame( + pd.Timestamp("20010101"), index=df_float.index, columns=df_float.columns + ) + df = pd.concat([df_float, df_int, df_bool, df_object, df_dt], axis=1) + + if not using_array_manager: + assert len(df._mgr.blknos) == len(df.columns) + assert len(df._mgr.blklocs) == len(df.columns) + + # testing iloc + for i in range(len(df.columns)): + df.iloc[:, i] + + def test_dup_columns_across_dtype(self): + # dup columns across dtype GH 2079/2194 + vals = [[1, -1, 2.0], [2, -2, 3.0]] + rs = DataFrame(vals, columns=["A", "A", "B"]) + xp = DataFrame(vals) + xp.columns = ["A", "A", "B"] + tm.assert_frame_equal(rs, xp) + + def test_set_value_by_index(self): + # See gh-12344 + warn = None + msg = "will attempt to set the values inplace" + + df = DataFrame(np.arange(9).reshape(3, 3).T) + df.columns = list("AAA") + expected = df.iloc[:, 2].copy() + + with tm.assert_produces_warning(warn, match=msg): + df.iloc[:, 0] = 3 + tm.assert_series_equal(df.iloc[:, 2], expected) + + df = DataFrame(np.arange(9).reshape(3, 3).T) + df.columns = [2, float(2), str(2)] + expected = df.iloc[:, 1].copy() + + with tm.assert_produces_warning(warn, match=msg): + df.iloc[:, 0] = 3 + tm.assert_series_equal(df.iloc[:, 1], expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_npfuncs.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_npfuncs.py new file mode 100644 index 0000000000000000000000000000000000000000..afb53bf2de93aa591ca9d7b99af185bc0c4083ee --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_npfuncs.py @@ -0,0 +1,89 @@ +""" +Tests for np.foo applied to DataFrame, not necessarily ufuncs. +""" +import numpy as np + +from pandas import ( + Categorical, + DataFrame, +) +import pandas._testing as tm + + +class TestAsArray: + def test_asarray_homogeneous(self): + df = DataFrame({"A": Categorical([1, 2]), "B": Categorical([1, 2])}) + result = np.asarray(df) + # may change from object in the future + expected = np.array([[1, 1], [2, 2]], dtype="object") + tm.assert_numpy_array_equal(result, expected) + + def test_np_sqrt(self, float_frame): + with np.errstate(all="ignore"): + result = np.sqrt(float_frame) + assert isinstance(result, type(float_frame)) + assert result.index.is_(float_frame.index) + assert result.columns.is_(float_frame.columns) + + tm.assert_frame_equal(result, float_frame.apply(np.sqrt)) + + def test_sum_deprecated_axis_behavior(self): + # GH#52042 deprecated behavior of df.sum(axis=None), which gets + # called when we do np.sum(df) + + arr = np.random.default_rng(2).standard_normal((4, 3)) + df = DataFrame(arr) + + msg = "The behavior of DataFrame.sum with axis=None is deprecated" + with tm.assert_produces_warning( + FutureWarning, match=msg, check_stacklevel=False + ): + res = np.sum(df) + + with tm.assert_produces_warning(FutureWarning, match=msg): + expected = df.sum(axis=None) + tm.assert_series_equal(res, expected) + + def test_np_ravel(self): + # GH26247 + arr = np.array( + [ + [0.11197053, 0.44361564, -0.92589452], + [0.05883648, -0.00948922, -0.26469934], + ] + ) + + result = np.ravel([DataFrame(batch.reshape(1, 3)) for batch in arr]) + expected = np.array( + [ + 0.11197053, + 0.44361564, + -0.92589452, + 0.05883648, + -0.00948922, + -0.26469934, + ] + ) + tm.assert_numpy_array_equal(result, expected) + + result = np.ravel(DataFrame(arr[0].reshape(1, 3), columns=["x1", "x2", "x3"])) + expected = np.array([0.11197053, 0.44361564, -0.92589452]) + tm.assert_numpy_array_equal(result, expected) + + result = np.ravel( + [ + DataFrame(batch.reshape(1, 3), columns=["x1", "x2", "x3"]) + for batch in arr + ] + ) + expected = np.array( + [ + 0.11197053, + 0.44361564, + -0.92589452, + 0.05883648, + -0.00948922, + -0.26469934, + ] + ) + tm.assert_numpy_array_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_query_eval.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_query_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..2c807c72582c52d748d043263a736b5c9dbdfffa --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_query_eval.py @@ -0,0 +1,1425 @@ +import operator + +import numpy as np +import pytest + +from pandas.errors import ( + NumExprClobberingError, + UndefinedVariableError, +) +import pandas.util._test_decorators as td + +import pandas as pd +from pandas import ( + DataFrame, + Index, + MultiIndex, + Series, + date_range, +) +import pandas._testing as tm +from pandas.core.computation.check import NUMEXPR_INSTALLED + + +@pytest.fixture(params=["python", "pandas"], ids=lambda x: x) +def parser(request): + return request.param + + +@pytest.fixture( + params=["python", pytest.param("numexpr", marks=td.skip_if_no("numexpr"))], + ids=lambda x: x, +) +def engine(request): + return request.param + + +def skip_if_no_pandas_parser(parser): + if parser != "pandas": + pytest.skip(f"cannot evaluate with parser={parser}") + + +class TestCompat: + @pytest.fixture + def df(self): + return DataFrame({"A": [1, 2, 3]}) + + @pytest.fixture + def expected1(self, df): + return df[df.A > 0] + + @pytest.fixture + def expected2(self, df): + return df.A + 1 + + def test_query_default(self, df, expected1, expected2): + # GH 12749 + # this should always work, whether NUMEXPR_INSTALLED or not + result = df.query("A>0") + tm.assert_frame_equal(result, expected1) + result = df.eval("A+1") + tm.assert_series_equal(result, expected2, check_names=False) + + def test_query_None(self, df, expected1, expected2): + result = df.query("A>0", engine=None) + tm.assert_frame_equal(result, expected1) + result = df.eval("A+1", engine=None) + tm.assert_series_equal(result, expected2, check_names=False) + + def test_query_python(self, df, expected1, expected2): + result = df.query("A>0", engine="python") + tm.assert_frame_equal(result, expected1) + result = df.eval("A+1", engine="python") + tm.assert_series_equal(result, expected2, check_names=False) + + def test_query_numexpr(self, df, expected1, expected2): + if NUMEXPR_INSTALLED: + result = df.query("A>0", engine="numexpr") + tm.assert_frame_equal(result, expected1) + result = df.eval("A+1", engine="numexpr") + tm.assert_series_equal(result, expected2, check_names=False) + else: + msg = ( + r"'numexpr' is not installed or an unsupported version. " + r"Cannot use engine='numexpr' for query/eval if 'numexpr' is " + r"not installed" + ) + with pytest.raises(ImportError, match=msg): + df.query("A>0", engine="numexpr") + with pytest.raises(ImportError, match=msg): + df.eval("A+1", engine="numexpr") + + +class TestDataFrameEval: + # smaller hits python, larger hits numexpr + @pytest.mark.parametrize("n", [4, 4000]) + @pytest.mark.parametrize( + "op_str,op,rop", + [ + ("+", "__add__", "__radd__"), + ("-", "__sub__", "__rsub__"), + ("*", "__mul__", "__rmul__"), + ("/", "__truediv__", "__rtruediv__"), + ], + ) + def test_ops(self, op_str, op, rop, n): + # tst ops and reversed ops in evaluation + # GH7198 + + df = DataFrame(1, index=range(n), columns=list("abcd")) + df.iloc[0] = 2 + m = df.mean() + + base = DataFrame( # noqa: F841 + np.tile(m.values, n).reshape(n, -1), columns=list("abcd") + ) + + expected = eval(f"base {op_str} df") + + # ops as strings + result = eval(f"m {op_str} df") + tm.assert_frame_equal(result, expected) + + # these are commutative + if op in ["+", "*"]: + result = getattr(df, op)(m) + tm.assert_frame_equal(result, expected) + + # these are not + elif op in ["-", "/"]: + result = getattr(df, rop)(m) + tm.assert_frame_equal(result, expected) + + def test_dataframe_sub_numexpr_path(self): + # GH7192: Note we need a large number of rows to ensure this + # goes through the numexpr path + df = DataFrame({"A": np.random.default_rng(2).standard_normal(25000)}) + df.iloc[0:5] = np.nan + expected = 1 - np.isnan(df.iloc[0:25]) + result = (1 - np.isnan(df)).iloc[0:25] + tm.assert_frame_equal(result, expected) + + def test_query_non_str(self): + # GH 11485 + df = DataFrame({"A": [1, 2, 3], "B": ["a", "b", "b"]}) + + msg = "expr must be a string to be evaluated" + with pytest.raises(ValueError, match=msg): + df.query(lambda x: x.B == "b") + + with pytest.raises(ValueError, match=msg): + df.query(111) + + def test_query_empty_string(self): + # GH 13139 + df = DataFrame({"A": [1, 2, 3]}) + + msg = "expr cannot be an empty string" + with pytest.raises(ValueError, match=msg): + df.query("") + + def test_eval_resolvers_as_list(self): + # GH 14095 + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 2)), columns=list("ab") + ) + dict1 = {"a": 1} + dict2 = {"b": 2} + assert df.eval("a + b", resolvers=[dict1, dict2]) == dict1["a"] + dict2["b"] + assert pd.eval("a + b", resolvers=[dict1, dict2]) == dict1["a"] + dict2["b"] + + def test_eval_resolvers_combined(self): + # GH 34966 + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 2)), columns=list("ab") + ) + dict1 = {"c": 2} + + # Both input and default index/column resolvers should be usable + result = df.eval("a + b * c", resolvers=[dict1]) + + expected = df["a"] + df["b"] * dict1["c"] + tm.assert_series_equal(result, expected) + + def test_eval_object_dtype_binop(self): + # GH#24883 + df = DataFrame({"a1": ["Y", "N"]}) + res = df.eval("c = ((a1 == 'Y') & True)") + expected = DataFrame({"a1": ["Y", "N"], "c": [True, False]}) + tm.assert_frame_equal(res, expected) + + +class TestDataFrameQueryWithMultiIndex: + def test_query_with_named_multiindex(self, parser, engine): + skip_if_no_pandas_parser(parser) + a = np.random.default_rng(2).choice(["red", "green"], size=10) + b = np.random.default_rng(2).choice(["eggs", "ham"], size=10) + index = MultiIndex.from_arrays([a, b], names=["color", "food"]) + df = DataFrame(np.random.default_rng(2).standard_normal((10, 2)), index=index) + ind = Series( + df.index.get_level_values("color").values, index=index, name="color" + ) + + # equality + res1 = df.query('color == "red"', parser=parser, engine=engine) + res2 = df.query('"red" == color', parser=parser, engine=engine) + exp = df[ind == "red"] + tm.assert_frame_equal(res1, exp) + tm.assert_frame_equal(res2, exp) + + # inequality + res1 = df.query('color != "red"', parser=parser, engine=engine) + res2 = df.query('"red" != color', parser=parser, engine=engine) + exp = df[ind != "red"] + tm.assert_frame_equal(res1, exp) + tm.assert_frame_equal(res2, exp) + + # list equality (really just set membership) + res1 = df.query('color == ["red"]', parser=parser, engine=engine) + res2 = df.query('["red"] == color', parser=parser, engine=engine) + exp = df[ind.isin(["red"])] + tm.assert_frame_equal(res1, exp) + tm.assert_frame_equal(res2, exp) + + res1 = df.query('color != ["red"]', parser=parser, engine=engine) + res2 = df.query('["red"] != color', parser=parser, engine=engine) + exp = df[~ind.isin(["red"])] + tm.assert_frame_equal(res1, exp) + tm.assert_frame_equal(res2, exp) + + # in/not in ops + res1 = df.query('["red"] in color', parser=parser, engine=engine) + res2 = df.query('"red" in color', parser=parser, engine=engine) + exp = df[ind.isin(["red"])] + tm.assert_frame_equal(res1, exp) + tm.assert_frame_equal(res2, exp) + + res1 = df.query('["red"] not in color', parser=parser, engine=engine) + res2 = df.query('"red" not in color', parser=parser, engine=engine) + exp = df[~ind.isin(["red"])] + tm.assert_frame_equal(res1, exp) + tm.assert_frame_equal(res2, exp) + + def test_query_with_unnamed_multiindex(self, parser, engine): + skip_if_no_pandas_parser(parser) + a = np.random.default_rng(2).choice(["red", "green"], size=10) + b = np.random.default_rng(2).choice(["eggs", "ham"], size=10) + index = MultiIndex.from_arrays([a, b]) + df = DataFrame(np.random.default_rng(2).standard_normal((10, 2)), index=index) + ind = Series(df.index.get_level_values(0).values, index=index) + + res1 = df.query('ilevel_0 == "red"', parser=parser, engine=engine) + res2 = df.query('"red" == ilevel_0', parser=parser, engine=engine) + exp = df[ind == "red"] + tm.assert_frame_equal(res1, exp) + tm.assert_frame_equal(res2, exp) + + # inequality + res1 = df.query('ilevel_0 != "red"', parser=parser, engine=engine) + res2 = df.query('"red" != ilevel_0', parser=parser, engine=engine) + exp = df[ind != "red"] + tm.assert_frame_equal(res1, exp) + tm.assert_frame_equal(res2, exp) + + # list equality (really just set membership) + res1 = df.query('ilevel_0 == ["red"]', parser=parser, engine=engine) + res2 = df.query('["red"] == ilevel_0', parser=parser, engine=engine) + exp = df[ind.isin(["red"])] + tm.assert_frame_equal(res1, exp) + tm.assert_frame_equal(res2, exp) + + res1 = df.query('ilevel_0 != ["red"]', parser=parser, engine=engine) + res2 = df.query('["red"] != ilevel_0', parser=parser, engine=engine) + exp = df[~ind.isin(["red"])] + tm.assert_frame_equal(res1, exp) + tm.assert_frame_equal(res2, exp) + + # in/not in ops + res1 = df.query('["red"] in ilevel_0', parser=parser, engine=engine) + res2 = df.query('"red" in ilevel_0', parser=parser, engine=engine) + exp = df[ind.isin(["red"])] + tm.assert_frame_equal(res1, exp) + tm.assert_frame_equal(res2, exp) + + res1 = df.query('["red"] not in ilevel_0', parser=parser, engine=engine) + res2 = df.query('"red" not in ilevel_0', parser=parser, engine=engine) + exp = df[~ind.isin(["red"])] + tm.assert_frame_equal(res1, exp) + tm.assert_frame_equal(res2, exp) + + # ## LEVEL 1 + ind = Series(df.index.get_level_values(1).values, index=index) + res1 = df.query('ilevel_1 == "eggs"', parser=parser, engine=engine) + res2 = df.query('"eggs" == ilevel_1', parser=parser, engine=engine) + exp = df[ind == "eggs"] + tm.assert_frame_equal(res1, exp) + tm.assert_frame_equal(res2, exp) + + # inequality + res1 = df.query('ilevel_1 != "eggs"', parser=parser, engine=engine) + res2 = df.query('"eggs" != ilevel_1', parser=parser, engine=engine) + exp = df[ind != "eggs"] + tm.assert_frame_equal(res1, exp) + tm.assert_frame_equal(res2, exp) + + # list equality (really just set membership) + res1 = df.query('ilevel_1 == ["eggs"]', parser=parser, engine=engine) + res2 = df.query('["eggs"] == ilevel_1', parser=parser, engine=engine) + exp = df[ind.isin(["eggs"])] + tm.assert_frame_equal(res1, exp) + tm.assert_frame_equal(res2, exp) + + res1 = df.query('ilevel_1 != ["eggs"]', parser=parser, engine=engine) + res2 = df.query('["eggs"] != ilevel_1', parser=parser, engine=engine) + exp = df[~ind.isin(["eggs"])] + tm.assert_frame_equal(res1, exp) + tm.assert_frame_equal(res2, exp) + + # in/not in ops + res1 = df.query('["eggs"] in ilevel_1', parser=parser, engine=engine) + res2 = df.query('"eggs" in ilevel_1', parser=parser, engine=engine) + exp = df[ind.isin(["eggs"])] + tm.assert_frame_equal(res1, exp) + tm.assert_frame_equal(res2, exp) + + res1 = df.query('["eggs"] not in ilevel_1', parser=parser, engine=engine) + res2 = df.query('"eggs" not in ilevel_1', parser=parser, engine=engine) + exp = df[~ind.isin(["eggs"])] + tm.assert_frame_equal(res1, exp) + tm.assert_frame_equal(res2, exp) + + def test_query_with_partially_named_multiindex(self, parser, engine): + skip_if_no_pandas_parser(parser) + a = np.random.default_rng(2).choice(["red", "green"], size=10) + b = np.arange(10) + index = MultiIndex.from_arrays([a, b]) + index.names = [None, "rating"] + df = DataFrame(np.random.default_rng(2).standard_normal((10, 2)), index=index) + res = df.query("rating == 1", parser=parser, engine=engine) + ind = Series( + df.index.get_level_values("rating").values, index=index, name="rating" + ) + exp = df[ind == 1] + tm.assert_frame_equal(res, exp) + + res = df.query("rating != 1", parser=parser, engine=engine) + ind = Series( + df.index.get_level_values("rating").values, index=index, name="rating" + ) + exp = df[ind != 1] + tm.assert_frame_equal(res, exp) + + res = df.query('ilevel_0 == "red"', parser=parser, engine=engine) + ind = Series(df.index.get_level_values(0).values, index=index) + exp = df[ind == "red"] + tm.assert_frame_equal(res, exp) + + res = df.query('ilevel_0 != "red"', parser=parser, engine=engine) + ind = Series(df.index.get_level_values(0).values, index=index) + exp = df[ind != "red"] + tm.assert_frame_equal(res, exp) + + def test_query_multiindex_get_index_resolvers(self): + df = DataFrame( + np.ones((10, 3)), + index=MultiIndex.from_arrays( + [range(10) for _ in range(2)], names=["spam", "eggs"] + ), + ) + resolvers = df._get_index_resolvers() + + def to_series(mi, level): + level_values = mi.get_level_values(level) + s = level_values.to_series() + s.index = mi + return s + + col_series = df.columns.to_series() + expected = { + "index": df.index, + "columns": col_series, + "spam": to_series(df.index, "spam"), + "eggs": to_series(df.index, "eggs"), + "clevel_0": col_series, + } + for k, v in resolvers.items(): + if isinstance(v, Index): + assert v.is_(expected[k]) + elif isinstance(v, Series): + tm.assert_series_equal(v, expected[k]) + else: + raise AssertionError("object must be a Series or Index") + + +@td.skip_if_no("numexpr") +class TestDataFrameQueryNumExprPandas: + @pytest.fixture + def engine(self): + return "numexpr" + + @pytest.fixture + def parser(self): + return "pandas" + + def test_date_query_with_attribute_access(self, engine, parser): + skip_if_no_pandas_parser(parser) + df = DataFrame(np.random.default_rng(2).standard_normal((5, 3))) + df["dates1"] = date_range("1/1/2012", periods=5) + df["dates2"] = date_range("1/1/2013", periods=5) + df["dates3"] = date_range("1/1/2014", periods=5) + res = df.query( + "@df.dates1 < 20130101 < @df.dates3", engine=engine, parser=parser + ) + expec = df[(df.dates1 < "20130101") & ("20130101" < df.dates3)] + tm.assert_frame_equal(res, expec) + + def test_date_query_no_attribute_access(self, engine, parser): + df = DataFrame(np.random.default_rng(2).standard_normal((5, 3))) + df["dates1"] = date_range("1/1/2012", periods=5) + df["dates2"] = date_range("1/1/2013", periods=5) + df["dates3"] = date_range("1/1/2014", periods=5) + res = df.query("dates1 < 20130101 < dates3", engine=engine, parser=parser) + expec = df[(df.dates1 < "20130101") & ("20130101" < df.dates3)] + tm.assert_frame_equal(res, expec) + + def test_date_query_with_NaT(self, engine, parser): + n = 10 + df = DataFrame(np.random.default_rng(2).standard_normal((n, 3))) + df["dates1"] = date_range("1/1/2012", periods=n) + df["dates2"] = date_range("1/1/2013", periods=n) + df["dates3"] = date_range("1/1/2014", periods=n) + df.loc[np.random.default_rng(2).random(n) > 0.5, "dates1"] = pd.NaT + df.loc[np.random.default_rng(2).random(n) > 0.5, "dates3"] = pd.NaT + res = df.query("dates1 < 20130101 < dates3", engine=engine, parser=parser) + expec = df[(df.dates1 < "20130101") & ("20130101" < df.dates3)] + tm.assert_frame_equal(res, expec) + + def test_date_index_query(self, engine, parser): + n = 10 + df = DataFrame(np.random.default_rng(2).standard_normal((n, 3))) + df["dates1"] = date_range("1/1/2012", periods=n) + df["dates3"] = date_range("1/1/2014", periods=n) + return_value = df.set_index("dates1", inplace=True, drop=True) + assert return_value is None + res = df.query("index < 20130101 < dates3", engine=engine, parser=parser) + expec = df[(df.index < "20130101") & ("20130101" < df.dates3)] + tm.assert_frame_equal(res, expec) + + def test_date_index_query_with_NaT(self, engine, parser): + n = 10 + # Cast to object to avoid implicit cast when setting entry to pd.NaT below + df = DataFrame(np.random.default_rng(2).standard_normal((n, 3))).astype( + {0: object} + ) + df["dates1"] = date_range("1/1/2012", periods=n) + df["dates3"] = date_range("1/1/2014", periods=n) + df.iloc[0, 0] = pd.NaT + return_value = df.set_index("dates1", inplace=True, drop=True) + assert return_value is None + res = df.query("index < 20130101 < dates3", engine=engine, parser=parser) + expec = df[(df.index < "20130101") & ("20130101" < df.dates3)] + tm.assert_frame_equal(res, expec) + + def test_date_index_query_with_NaT_duplicates(self, engine, parser): + n = 10 + d = {} + d["dates1"] = date_range("1/1/2012", periods=n) + d["dates3"] = date_range("1/1/2014", periods=n) + df = DataFrame(d) + df.loc[np.random.default_rng(2).random(n) > 0.5, "dates1"] = pd.NaT + return_value = df.set_index("dates1", inplace=True, drop=True) + assert return_value is None + res = df.query("dates1 < 20130101 < dates3", engine=engine, parser=parser) + expec = df[(df.index.to_series() < "20130101") & ("20130101" < df.dates3)] + tm.assert_frame_equal(res, expec) + + def test_date_query_with_non_date(self, engine, parser): + n = 10 + df = DataFrame( + {"dates": date_range("1/1/2012", periods=n), "nondate": np.arange(n)} + ) + + result = df.query("dates == nondate", parser=parser, engine=engine) + assert len(result) == 0 + + result = df.query("dates != nondate", parser=parser, engine=engine) + tm.assert_frame_equal(result, df) + + msg = r"Invalid comparison between dtype=datetime64\[ns\] and ndarray" + for op in ["<", ">", "<=", ">="]: + with pytest.raises(TypeError, match=msg): + df.query(f"dates {op} nondate", parser=parser, engine=engine) + + def test_query_syntax_error(self, engine, parser): + df = DataFrame({"i": range(10), "+": range(3, 13), "r": range(4, 14)}) + msg = "invalid syntax" + with pytest.raises(SyntaxError, match=msg): + df.query("i - +", engine=engine, parser=parser) + + def test_query_scope(self, engine, parser): + skip_if_no_pandas_parser(parser) + + df = DataFrame( + np.random.default_rng(2).standard_normal((20, 2)), columns=list("ab") + ) + + a, b = 1, 2 # noqa: F841 + res = df.query("a > b", engine=engine, parser=parser) + expected = df[df.a > df.b] + tm.assert_frame_equal(res, expected) + + res = df.query("@a > b", engine=engine, parser=parser) + expected = df[a > df.b] + tm.assert_frame_equal(res, expected) + + # no local variable c + with pytest.raises( + UndefinedVariableError, match="local variable 'c' is not defined" + ): + df.query("@a > b > @c", engine=engine, parser=parser) + + # no column named 'c' + with pytest.raises(UndefinedVariableError, match="name 'c' is not defined"): + df.query("@a > b > c", engine=engine, parser=parser) + + def test_query_doesnt_pickup_local(self, engine, parser): + n = m = 10 + df = DataFrame( + np.random.default_rng(2).integers(m, size=(n, 3)), columns=list("abc") + ) + + # we don't pick up the local 'sin' + with pytest.raises(UndefinedVariableError, match="name 'sin' is not defined"): + df.query("sin > 5", engine=engine, parser=parser) + + def test_query_builtin(self, engine, parser): + n = m = 10 + df = DataFrame( + np.random.default_rng(2).integers(m, size=(n, 3)), columns=list("abc") + ) + + df.index.name = "sin" + msg = "Variables in expression.+" + with pytest.raises(NumExprClobberingError, match=msg): + df.query("sin > 5", engine=engine, parser=parser) + + def test_query(self, engine, parser): + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 3)), columns=["a", "b", "c"] + ) + + tm.assert_frame_equal( + df.query("a < b", engine=engine, parser=parser), df[df.a < df.b] + ) + tm.assert_frame_equal( + df.query("a + b > b * c", engine=engine, parser=parser), + df[df.a + df.b > df.b * df.c], + ) + + def test_query_index_with_name(self, engine, parser): + df = DataFrame( + np.random.default_rng(2).integers(10, size=(10, 3)), + index=Index(range(10), name="blob"), + columns=["a", "b", "c"], + ) + res = df.query("(blob < 5) & (a < b)", engine=engine, parser=parser) + expec = df[(df.index < 5) & (df.a < df.b)] + tm.assert_frame_equal(res, expec) + + res = df.query("blob < b", engine=engine, parser=parser) + expec = df[df.index < df.b] + + tm.assert_frame_equal(res, expec) + + def test_query_index_without_name(self, engine, parser): + df = DataFrame( + np.random.default_rng(2).integers(10, size=(10, 3)), + index=range(10), + columns=["a", "b", "c"], + ) + + # "index" should refer to the index + res = df.query("index < b", engine=engine, parser=parser) + expec = df[df.index < df.b] + tm.assert_frame_equal(res, expec) + + # test against a scalar + res = df.query("index < 5", engine=engine, parser=parser) + expec = df[df.index < 5] + tm.assert_frame_equal(res, expec) + + def test_nested_scope(self, engine, parser): + skip_if_no_pandas_parser(parser) + + df = DataFrame(np.random.default_rng(2).standard_normal((5, 3))) + df2 = DataFrame(np.random.default_rng(2).standard_normal((5, 3))) + expected = df[(df > 0) & (df2 > 0)] + + result = df.query("(@df > 0) & (@df2 > 0)", engine=engine, parser=parser) + tm.assert_frame_equal(result, expected) + + result = pd.eval("df[df > 0 and df2 > 0]", engine=engine, parser=parser) + tm.assert_frame_equal(result, expected) + + result = pd.eval( + "df[df > 0 and df2 > 0 and df[df > 0] > 0]", engine=engine, parser=parser + ) + expected = df[(df > 0) & (df2 > 0) & (df[df > 0] > 0)] + tm.assert_frame_equal(result, expected) + + result = pd.eval("df[(df>0) & (df2>0)]", engine=engine, parser=parser) + expected = df.query("(@df>0) & (@df2>0)", engine=engine, parser=parser) + tm.assert_frame_equal(result, expected) + + def test_nested_raises_on_local_self_reference(self, engine, parser): + df = DataFrame(np.random.default_rng(2).standard_normal((5, 3))) + + # can't reference ourself b/c we're a local so @ is necessary + with pytest.raises(UndefinedVariableError, match="name 'df' is not defined"): + df.query("df > 0", engine=engine, parser=parser) + + def test_local_syntax(self, engine, parser): + skip_if_no_pandas_parser(parser) + + df = DataFrame( + np.random.default_rng(2).standard_normal((100, 10)), + columns=list("abcdefghij"), + ) + b = 1 + expect = df[df.a < b] + result = df.query("a < @b", engine=engine, parser=parser) + tm.assert_frame_equal(result, expect) + + expect = df[df.a < df.b] + result = df.query("a < b", engine=engine, parser=parser) + tm.assert_frame_equal(result, expect) + + def test_chained_cmp_and_in(self, engine, parser): + skip_if_no_pandas_parser(parser) + cols = list("abc") + df = DataFrame( + np.random.default_rng(2).standard_normal((100, len(cols))), columns=cols + ) + res = df.query( + "a < b < c and a not in b not in c", engine=engine, parser=parser + ) + ind = (df.a < df.b) & (df.b < df.c) & ~df.b.isin(df.a) & ~df.c.isin(df.b) + expec = df[ind] + tm.assert_frame_equal(res, expec) + + def test_local_variable_with_in(self, engine, parser): + skip_if_no_pandas_parser(parser) + a = Series(np.random.default_rng(2).integers(3, size=15), name="a") + b = Series(np.random.default_rng(2).integers(10, size=15), name="b") + df = DataFrame({"a": a, "b": b}) + + expected = df.loc[(df.b - 1).isin(a)] + result = df.query("b - 1 in a", engine=engine, parser=parser) + tm.assert_frame_equal(expected, result) + + b = Series(np.random.default_rng(2).integers(10, size=15), name="b") + expected = df.loc[(b - 1).isin(a)] + result = df.query("@b - 1 in a", engine=engine, parser=parser) + tm.assert_frame_equal(expected, result) + + def test_at_inside_string(self, engine, parser): + skip_if_no_pandas_parser(parser) + c = 1 # noqa: F841 + df = DataFrame({"a": ["a", "a", "b", "b", "@c", "@c"]}) + result = df.query('a == "@c"', engine=engine, parser=parser) + expected = df[df.a == "@c"] + tm.assert_frame_equal(result, expected) + + def test_query_undefined_local(self): + engine, parser = self.engine, self.parser + skip_if_no_pandas_parser(parser) + + df = DataFrame(np.random.default_rng(2).random((10, 2)), columns=list("ab")) + with pytest.raises( + UndefinedVariableError, match="local variable 'c' is not defined" + ): + df.query("a == @c", engine=engine, parser=parser) + + def test_index_resolvers_come_after_columns_with_the_same_name( + self, engine, parser + ): + n = 1 # noqa: F841 + a = np.r_[20:101:20] + + df = DataFrame( + {"index": a, "b": np.random.default_rng(2).standard_normal(a.size)} + ) + df.index.name = "index" + result = df.query("index > 5", engine=engine, parser=parser) + expected = df[df["index"] > 5] + tm.assert_frame_equal(result, expected) + + df = DataFrame( + {"index": a, "b": np.random.default_rng(2).standard_normal(a.size)} + ) + result = df.query("ilevel_0 > 5", engine=engine, parser=parser) + expected = df.loc[df.index[df.index > 5]] + tm.assert_frame_equal(result, expected) + + df = DataFrame({"a": a, "b": np.random.default_rng(2).standard_normal(a.size)}) + df.index.name = "a" + result = df.query("a > 5", engine=engine, parser=parser) + expected = df[df.a > 5] + tm.assert_frame_equal(result, expected) + + result = df.query("index > 5", engine=engine, parser=parser) + expected = df.loc[df.index[df.index > 5]] + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("op, f", [["==", operator.eq], ["!=", operator.ne]]) + def test_inf(self, op, f, engine, parser): + n = 10 + df = DataFrame( + { + "a": np.random.default_rng(2).random(n), + "b": np.random.default_rng(2).random(n), + } + ) + df.loc[::2, 0] = np.inf + q = f"a {op} inf" + expected = df[f(df.a, np.inf)] + result = df.query(q, engine=engine, parser=parser) + tm.assert_frame_equal(result, expected) + + def test_check_tz_aware_index_query(self, tz_aware_fixture): + # https://github.com/pandas-dev/pandas/issues/29463 + tz = tz_aware_fixture + df_index = date_range( + start="2019-01-01", freq="1d", periods=10, tz=tz, name="time" + ) + expected = DataFrame(index=df_index) + df = DataFrame(index=df_index) + result = df.query('"2018-01-03 00:00:00+00" < time') + tm.assert_frame_equal(result, expected) + + expected = DataFrame(df_index) + result = df.reset_index().query('"2018-01-03 00:00:00+00" < time') + tm.assert_frame_equal(result, expected) + + def test_method_calls_in_query(self, engine, parser): + # https://github.com/pandas-dev/pandas/issues/22435 + n = 10 + df = DataFrame( + { + "a": 2 * np.random.default_rng(2).random(n), + "b": np.random.default_rng(2).random(n), + } + ) + expected = df[df["a"].astype("int") == 0] + result = df.query("a.astype('int') == 0", engine=engine, parser=parser) + tm.assert_frame_equal(result, expected) + + df = DataFrame( + { + "a": np.where( + np.random.default_rng(2).random(n) < 0.5, + np.nan, + np.random.default_rng(2).standard_normal(n), + ), + "b": np.random.default_rng(2).standard_normal(n), + } + ) + expected = df[df["a"].notnull()] + result = df.query("a.notnull()", engine=engine, parser=parser) + tm.assert_frame_equal(result, expected) + + +@td.skip_if_no("numexpr") +class TestDataFrameQueryNumExprPython(TestDataFrameQueryNumExprPandas): + @pytest.fixture + def engine(self): + return "numexpr" + + @pytest.fixture + def parser(self): + return "python" + + def test_date_query_no_attribute_access(self, engine, parser): + df = DataFrame(np.random.default_rng(2).standard_normal((5, 3))) + df["dates1"] = date_range("1/1/2012", periods=5) + df["dates2"] = date_range("1/1/2013", periods=5) + df["dates3"] = date_range("1/1/2014", periods=5) + res = df.query( + "(dates1 < 20130101) & (20130101 < dates3)", engine=engine, parser=parser + ) + expec = df[(df.dates1 < "20130101") & ("20130101" < df.dates3)] + tm.assert_frame_equal(res, expec) + + def test_date_query_with_NaT(self, engine, parser): + n = 10 + df = DataFrame(np.random.default_rng(2).standard_normal((n, 3))) + df["dates1"] = date_range("1/1/2012", periods=n) + df["dates2"] = date_range("1/1/2013", periods=n) + df["dates3"] = date_range("1/1/2014", periods=n) + df.loc[np.random.default_rng(2).random(n) > 0.5, "dates1"] = pd.NaT + df.loc[np.random.default_rng(2).random(n) > 0.5, "dates3"] = pd.NaT + res = df.query( + "(dates1 < 20130101) & (20130101 < dates3)", engine=engine, parser=parser + ) + expec = df[(df.dates1 < "20130101") & ("20130101" < df.dates3)] + tm.assert_frame_equal(res, expec) + + def test_date_index_query(self, engine, parser): + n = 10 + df = DataFrame(np.random.default_rng(2).standard_normal((n, 3))) + df["dates1"] = date_range("1/1/2012", periods=n) + df["dates3"] = date_range("1/1/2014", periods=n) + return_value = df.set_index("dates1", inplace=True, drop=True) + assert return_value is None + res = df.query( + "(index < 20130101) & (20130101 < dates3)", engine=engine, parser=parser + ) + expec = df[(df.index < "20130101") & ("20130101" < df.dates3)] + tm.assert_frame_equal(res, expec) + + def test_date_index_query_with_NaT(self, engine, parser): + n = 10 + # Cast to object to avoid implicit cast when setting entry to pd.NaT below + df = DataFrame(np.random.default_rng(2).standard_normal((n, 3))).astype( + {0: object} + ) + df["dates1"] = date_range("1/1/2012", periods=n) + df["dates3"] = date_range("1/1/2014", periods=n) + df.iloc[0, 0] = pd.NaT + return_value = df.set_index("dates1", inplace=True, drop=True) + assert return_value is None + res = df.query( + "(index < 20130101) & (20130101 < dates3)", engine=engine, parser=parser + ) + expec = df[(df.index < "20130101") & ("20130101" < df.dates3)] + tm.assert_frame_equal(res, expec) + + def test_date_index_query_with_NaT_duplicates(self, engine, parser): + n = 10 + df = DataFrame(np.random.default_rng(2).standard_normal((n, 3))) + df["dates1"] = date_range("1/1/2012", periods=n) + df["dates3"] = date_range("1/1/2014", periods=n) + df.loc[np.random.default_rng(2).random(n) > 0.5, "dates1"] = pd.NaT + return_value = df.set_index("dates1", inplace=True, drop=True) + assert return_value is None + msg = r"'BoolOp' nodes are not implemented" + with pytest.raises(NotImplementedError, match=msg): + df.query("index < 20130101 < dates3", engine=engine, parser=parser) + + def test_nested_scope(self, engine, parser): + # smoke test + x = 1 # noqa: F841 + result = pd.eval("x + 1", engine=engine, parser=parser) + assert result == 2 + + df = DataFrame(np.random.default_rng(2).standard_normal((5, 3))) + df2 = DataFrame(np.random.default_rng(2).standard_normal((5, 3))) + + # don't have the pandas parser + msg = r"The '@' prefix is only supported by the pandas parser" + with pytest.raises(SyntaxError, match=msg): + df.query("(@df>0) & (@df2>0)", engine=engine, parser=parser) + + with pytest.raises(UndefinedVariableError, match="name 'df' is not defined"): + df.query("(df>0) & (df2>0)", engine=engine, parser=parser) + + expected = df[(df > 0) & (df2 > 0)] + result = pd.eval("df[(df > 0) & (df2 > 0)]", engine=engine, parser=parser) + tm.assert_frame_equal(expected, result) + + expected = df[(df > 0) & (df2 > 0) & (df[df > 0] > 0)] + result = pd.eval( + "df[(df > 0) & (df2 > 0) & (df[df > 0] > 0)]", engine=engine, parser=parser + ) + tm.assert_frame_equal(expected, result) + + def test_query_numexpr_with_min_and_max_columns(self): + df = DataFrame({"min": [1, 2, 3], "max": [4, 5, 6]}) + regex_to_match = ( + r"Variables in expression \"\(min\) == \(1\)\" " + r"overlap with builtins: \('min'\)" + ) + with pytest.raises(NumExprClobberingError, match=regex_to_match): + df.query("min == 1") + + regex_to_match = ( + r"Variables in expression \"\(max\) == \(1\)\" " + r"overlap with builtins: \('max'\)" + ) + with pytest.raises(NumExprClobberingError, match=regex_to_match): + df.query("max == 1") + + +class TestDataFrameQueryPythonPandas(TestDataFrameQueryNumExprPandas): + @pytest.fixture + def engine(self): + return "python" + + @pytest.fixture + def parser(self): + return "pandas" + + def test_query_builtin(self, engine, parser): + n = m = 10 + df = DataFrame( + np.random.default_rng(2).integers(m, size=(n, 3)), columns=list("abc") + ) + + df.index.name = "sin" + expected = df[df.index > 5] + result = df.query("sin > 5", engine=engine, parser=parser) + tm.assert_frame_equal(expected, result) + + +class TestDataFrameQueryPythonPython(TestDataFrameQueryNumExprPython): + @pytest.fixture + def engine(self): + return "python" + + @pytest.fixture + def parser(self): + return "python" + + def test_query_builtin(self, engine, parser): + n = m = 10 + df = DataFrame( + np.random.default_rng(2).integers(m, size=(n, 3)), columns=list("abc") + ) + + df.index.name = "sin" + expected = df[df.index > 5] + result = df.query("sin > 5", engine=engine, parser=parser) + tm.assert_frame_equal(expected, result) + + +class TestDataFrameQueryStrings: + def test_str_query_method(self, parser, engine): + df = DataFrame(np.random.default_rng(2).standard_normal((10, 1)), columns=["b"]) + df["strings"] = Series(list("aabbccddee")) + expect = df[df.strings == "a"] + + if parser != "pandas": + col = "strings" + lst = '"a"' + + lhs = [col] * 2 + [lst] * 2 + rhs = lhs[::-1] + + eq, ne = "==", "!=" + ops = 2 * ([eq] + [ne]) + msg = r"'(Not)?In' nodes are not implemented" + + for lhs, op, rhs in zip(lhs, ops, rhs): + ex = f"{lhs} {op} {rhs}" + with pytest.raises(NotImplementedError, match=msg): + df.query( + ex, + engine=engine, + parser=parser, + local_dict={"strings": df.strings}, + ) + else: + res = df.query('"a" == strings', engine=engine, parser=parser) + tm.assert_frame_equal(res, expect) + + res = df.query('strings == "a"', engine=engine, parser=parser) + tm.assert_frame_equal(res, expect) + tm.assert_frame_equal(res, df[df.strings.isin(["a"])]) + + expect = df[df.strings != "a"] + res = df.query('strings != "a"', engine=engine, parser=parser) + tm.assert_frame_equal(res, expect) + + res = df.query('"a" != strings', engine=engine, parser=parser) + tm.assert_frame_equal(res, expect) + tm.assert_frame_equal(res, df[~df.strings.isin(["a"])]) + + def test_str_list_query_method(self, parser, engine): + df = DataFrame(np.random.default_rng(2).standard_normal((10, 1)), columns=["b"]) + df["strings"] = Series(list("aabbccddee")) + expect = df[df.strings.isin(["a", "b"])] + + if parser != "pandas": + col = "strings" + lst = '["a", "b"]' + + lhs = [col] * 2 + [lst] * 2 + rhs = lhs[::-1] + + eq, ne = "==", "!=" + ops = 2 * ([eq] + [ne]) + msg = r"'(Not)?In' nodes are not implemented" + + for lhs, op, rhs in zip(lhs, ops, rhs): + ex = f"{lhs} {op} {rhs}" + with pytest.raises(NotImplementedError, match=msg): + df.query(ex, engine=engine, parser=parser) + else: + res = df.query('strings == ["a", "b"]', engine=engine, parser=parser) + tm.assert_frame_equal(res, expect) + + res = df.query('["a", "b"] == strings', engine=engine, parser=parser) + tm.assert_frame_equal(res, expect) + + expect = df[~df.strings.isin(["a", "b"])] + + res = df.query('strings != ["a", "b"]', engine=engine, parser=parser) + tm.assert_frame_equal(res, expect) + + res = df.query('["a", "b"] != strings', engine=engine, parser=parser) + tm.assert_frame_equal(res, expect) + + def test_query_with_string_columns(self, parser, engine): + df = DataFrame( + { + "a": list("aaaabbbbcccc"), + "b": list("aabbccddeeff"), + "c": np.random.default_rng(2).integers(5, size=12), + "d": np.random.default_rng(2).integers(9, size=12), + } + ) + if parser == "pandas": + res = df.query("a in b", parser=parser, engine=engine) + expec = df[df.a.isin(df.b)] + tm.assert_frame_equal(res, expec) + + res = df.query("a in b and c < d", parser=parser, engine=engine) + expec = df[df.a.isin(df.b) & (df.c < df.d)] + tm.assert_frame_equal(res, expec) + else: + msg = r"'(Not)?In' nodes are not implemented" + with pytest.raises(NotImplementedError, match=msg): + df.query("a in b", parser=parser, engine=engine) + + msg = r"'BoolOp' nodes are not implemented" + with pytest.raises(NotImplementedError, match=msg): + df.query("a in b and c < d", parser=parser, engine=engine) + + def test_object_array_eq_ne(self, parser, engine, using_infer_string): + df = DataFrame( + { + "a": list("aaaabbbbcccc"), + "b": list("aabbccddeeff"), + "c": np.random.default_rng(2).integers(5, size=12), + "d": np.random.default_rng(2).integers(9, size=12), + } + ) + warning = RuntimeWarning if using_infer_string and engine == "numexpr" else None + with tm.assert_produces_warning(warning): + res = df.query("a == b", parser=parser, engine=engine) + exp = df[df.a == df.b] + tm.assert_frame_equal(res, exp) + + with tm.assert_produces_warning(warning): + res = df.query("a != b", parser=parser, engine=engine) + exp = df[df.a != df.b] + tm.assert_frame_equal(res, exp) + + def test_query_with_nested_strings(self, parser, engine): + skip_if_no_pandas_parser(parser) + events = [ + f"page {n} {act}" for n in range(1, 4) for act in ["load", "exit"] + ] * 2 + stamps1 = date_range("2014-01-01 0:00:01", freq="30s", periods=6) + stamps2 = date_range("2014-02-01 1:00:01", freq="30s", periods=6) + df = DataFrame( + { + "id": np.arange(1, 7).repeat(2), + "event": events, + "timestamp": stamps1.append(stamps2), + } + ) + + expected = df[df.event == '"page 1 load"'] + res = df.query("""'"page 1 load"' in event""", parser=parser, engine=engine) + tm.assert_frame_equal(expected, res) + + def test_query_with_nested_special_character(self, parser, engine): + skip_if_no_pandas_parser(parser) + df = DataFrame({"a": ["a", "b", "test & test"], "b": [1, 2, 3]}) + res = df.query('a == "test & test"', parser=parser, engine=engine) + expec = df[df.a == "test & test"] + tm.assert_frame_equal(res, expec) + + @pytest.mark.parametrize( + "op, func", + [ + ["<", operator.lt], + [">", operator.gt], + ["<=", operator.le], + [">=", operator.ge], + ], + ) + def test_query_lex_compare_strings( + self, parser, engine, op, func, using_infer_string + ): + a = Series(np.random.default_rng(2).choice(list("abcde"), 20)) + b = Series(np.arange(a.size)) + df = DataFrame({"X": a, "Y": b}) + + warning = RuntimeWarning if using_infer_string and engine == "numexpr" else None + with tm.assert_produces_warning(warning): + res = df.query(f'X {op} "d"', engine=engine, parser=parser) + expected = df[func(df.X, "d")] + tm.assert_frame_equal(res, expected) + + def test_query_single_element_booleans(self, parser, engine): + columns = "bid", "bidsize", "ask", "asksize" + data = np.random.default_rng(2).integers(2, size=(1, len(columns))).astype(bool) + df = DataFrame(data, columns=columns) + res = df.query("bid & ask", engine=engine, parser=parser) + expected = df[df.bid & df.ask] + tm.assert_frame_equal(res, expected) + + def test_query_string_scalar_variable(self, parser, engine): + skip_if_no_pandas_parser(parser) + df = DataFrame( + { + "Symbol": ["BUD US", "BUD US", "IBM US", "IBM US"], + "Price": [109.70, 109.72, 183.30, 183.35], + } + ) + e = df[df.Symbol == "BUD US"] + symb = "BUD US" # noqa: F841 + r = df.query("Symbol == @symb", parser=parser, engine=engine) + tm.assert_frame_equal(e, r) + + @pytest.mark.parametrize( + "in_list", + [ + [None, "asdf", "ghjk"], + ["asdf", None, "ghjk"], + ["asdf", "ghjk", None], + [None, None, "asdf"], + ["asdf", None, None], + [None, None, None], + ], + ) + def test_query_string_null_elements(self, in_list): + # GITHUB ISSUE #31516 + parser = "pandas" + engine = "python" + expected = {i: value for i, value in enumerate(in_list) if value == "asdf"} + + df_expected = DataFrame({"a": expected}, dtype="string") + df_expected.index = df_expected.index.astype("int64") + df = DataFrame({"a": in_list}, dtype="string") + res1 = df.query("a == 'asdf'", parser=parser, engine=engine) + res2 = df[df["a"] == "asdf"] + res3 = df.query("a <= 'asdf'", parser=parser, engine=engine) + tm.assert_frame_equal(res1, df_expected) + tm.assert_frame_equal(res1, res2) + tm.assert_frame_equal(res1, res3) + tm.assert_frame_equal(res2, res3) + + +class TestDataFrameEvalWithFrame: + @pytest.fixture + def frame(self): + return DataFrame( + np.random.default_rng(2).standard_normal((10, 3)), columns=list("abc") + ) + + def test_simple_expr(self, frame, parser, engine): + res = frame.eval("a + b", engine=engine, parser=parser) + expect = frame.a + frame.b + tm.assert_series_equal(res, expect) + + def test_bool_arith_expr(self, frame, parser, engine): + res = frame.eval("a[a < 1] + b", engine=engine, parser=parser) + expect = frame.a[frame.a < 1] + frame.b + tm.assert_series_equal(res, expect) + + @pytest.mark.parametrize("op", ["+", "-", "*", "/"]) + def test_invalid_type_for_operator_raises(self, parser, engine, op): + df = DataFrame({"a": [1, 2], "b": ["c", "d"]}) + msg = r"unsupported operand type\(s\) for .+: '.+' and '.+'|Cannot" + + with pytest.raises(TypeError, match=msg): + df.eval(f"a {op} b", engine=engine, parser=parser) + + +class TestDataFrameQueryBacktickQuoting: + @pytest.fixture + def df(self): + """ + Yields a dataframe with strings that may or may not need escaping + by backticks. The last two columns cannot be escaped by backticks + and should raise a ValueError. + """ + yield DataFrame( + { + "A": [1, 2, 3], + "B B": [3, 2, 1], + "C C": [4, 5, 6], + "C C": [7, 4, 3], + "C_C": [8, 9, 10], + "D_D D": [11, 1, 101], + "E.E": [6, 3, 5], + "F-F": [8, 1, 10], + "1e1": [2, 4, 8], + "def": [10, 11, 2], + "A (x)": [4, 1, 3], + "B(x)": [1, 1, 5], + "B (x)": [2, 7, 4], + " &^ :!€$?(} > <++*'' ": [2, 5, 6], + "": [10, 11, 1], + " A": [4, 7, 9], + " ": [1, 2, 1], + "it's": [6, 3, 1], + "that's": [9, 1, 8], + "☺": [8, 7, 6], + "foo#bar": [2, 4, 5], + 1: [5, 7, 9], + } + ) + + def test_single_backtick_variable_query(self, df): + res = df.query("1 < `B B`") + expect = df[1 < df["B B"]] + tm.assert_frame_equal(res, expect) + + def test_two_backtick_variables_query(self, df): + res = df.query("1 < `B B` and 4 < `C C`") + expect = df[(1 < df["B B"]) & (4 < df["C C"])] + tm.assert_frame_equal(res, expect) + + def test_single_backtick_variable_expr(self, df): + res = df.eval("A + `B B`") + expect = df["A"] + df["B B"] + tm.assert_series_equal(res, expect) + + def test_two_backtick_variables_expr(self, df): + res = df.eval("`B B` + `C C`") + expect = df["B B"] + df["C C"] + tm.assert_series_equal(res, expect) + + def test_already_underscore_variable(self, df): + res = df.eval("`C_C` + A") + expect = df["C_C"] + df["A"] + tm.assert_series_equal(res, expect) + + def test_same_name_but_underscores(self, df): + res = df.eval("C_C + `C C`") + expect = df["C_C"] + df["C C"] + tm.assert_series_equal(res, expect) + + def test_mixed_underscores_and_spaces(self, df): + res = df.eval("A + `D_D D`") + expect = df["A"] + df["D_D D"] + tm.assert_series_equal(res, expect) + + def test_backtick_quote_name_with_no_spaces(self, df): + res = df.eval("A + `C_C`") + expect = df["A"] + df["C_C"] + tm.assert_series_equal(res, expect) + + def test_special_characters(self, df): + res = df.eval("`E.E` + `F-F` - A") + expect = df["E.E"] + df["F-F"] - df["A"] + tm.assert_series_equal(res, expect) + + def test_start_with_digit(self, df): + res = df.eval("A + `1e1`") + expect = df["A"] + df["1e1"] + tm.assert_series_equal(res, expect) + + def test_keyword(self, df): + res = df.eval("A + `def`") + expect = df["A"] + df["def"] + tm.assert_series_equal(res, expect) + + def test_unneeded_quoting(self, df): + res = df.query("`A` > 2") + expect = df[df["A"] > 2] + tm.assert_frame_equal(res, expect) + + def test_parenthesis(self, df): + res = df.query("`A (x)` > 2") + expect = df[df["A (x)"] > 2] + tm.assert_frame_equal(res, expect) + + def test_empty_string(self, df): + res = df.query("`` > 5") + expect = df[df[""] > 5] + tm.assert_frame_equal(res, expect) + + def test_multiple_spaces(self, df): + res = df.query("`C C` > 5") + expect = df[df["C C"] > 5] + tm.assert_frame_equal(res, expect) + + def test_start_with_spaces(self, df): + res = df.eval("` A` + ` `") + expect = df[" A"] + df[" "] + tm.assert_series_equal(res, expect) + + def test_lots_of_operators_string(self, df): + res = df.query("` &^ :!€$?(} > <++*'' ` > 4") + expect = df[df[" &^ :!€$?(} > <++*'' "] > 4] + tm.assert_frame_equal(res, expect) + + def test_missing_attribute(self, df): + message = "module 'pandas' has no attribute 'thing'" + with pytest.raises(AttributeError, match=message): + df.eval("@pd.thing") + + def test_failing_quote(self, df): + msg = r"(Could not convert ).*( to a valid Python identifier.)" + with pytest.raises(SyntaxError, match=msg): + df.query("`it's` > `that's`") + + def test_failing_character_outside_range(self, df): + msg = r"(Could not convert ).*( to a valid Python identifier.)" + with pytest.raises(SyntaxError, match=msg): + df.query("`☺` > 4") + + def test_failing_hashtag(self, df): + msg = "Failed to parse backticks" + with pytest.raises(SyntaxError, match=msg): + df.query("`foo#bar` > 4") + + def test_call_non_named_expression(self, df): + """ + Only attributes and variables ('named functions') can be called. + .__call__() is not an allowed attribute because that would allow + calling anything. + https://github.com/pandas-dev/pandas/pull/32460 + """ + + def func(*_): + return 1 + + funcs = [func] # noqa: F841 + + df.eval("@func()") + + with pytest.raises(TypeError, match="Only named functions are supported"): + df.eval("@funcs[0]()") + + with pytest.raises(TypeError, match="Only named functions are supported"): + df.eval("@funcs[0].__call__()") + + def test_ea_dtypes(self, any_numeric_ea_and_arrow_dtype): + # GH#29618 + df = DataFrame( + [[1, 2], [3, 4]], columns=["a", "b"], dtype=any_numeric_ea_and_arrow_dtype + ) + warning = RuntimeWarning if NUMEXPR_INSTALLED else None + with tm.assert_produces_warning(warning): + result = df.eval("c = b - a") + expected = DataFrame( + [[1, 2, 1], [3, 4, 1]], + columns=["a", "b", "c"], + dtype=any_numeric_ea_and_arrow_dtype, + ) + tm.assert_frame_equal(result, expected) + + def test_ea_dtypes_and_scalar(self): + # GH#29618 + df = DataFrame([[1, 2], [3, 4]], columns=["a", "b"], dtype="Float64") + warning = RuntimeWarning if NUMEXPR_INSTALLED else None + with tm.assert_produces_warning(warning): + result = df.eval("c = b - 1") + expected = DataFrame( + [[1, 2, 1], [3, 4, 3]], columns=["a", "b", "c"], dtype="Float64" + ) + tm.assert_frame_equal(result, expected) + + def test_ea_dtypes_and_scalar_operation(self, any_numeric_ea_and_arrow_dtype): + # GH#29618 + df = DataFrame( + [[1, 2], [3, 4]], columns=["a", "b"], dtype=any_numeric_ea_and_arrow_dtype + ) + result = df.eval("c = 2 - 1") + expected = DataFrame( + { + "a": Series([1, 3], dtype=any_numeric_ea_and_arrow_dtype), + "b": Series([2, 4], dtype=any_numeric_ea_and_arrow_dtype), + "c": Series([1, 1], dtype=result["c"].dtype), + } + ) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("dtype", ["int64", "Int64", "int64[pyarrow]"]) + def test_query_ea_dtypes(self, dtype): + if dtype == "int64[pyarrow]": + pytest.importorskip("pyarrow") + # GH#50261 + df = DataFrame({"a": Series([1, 2], dtype=dtype)}) + ref = {2} # noqa: F841 + warning = RuntimeWarning if dtype == "Int64" and NUMEXPR_INSTALLED else None + with tm.assert_produces_warning(warning): + result = df.query("a in @ref") + expected = DataFrame({"a": Series([2], dtype=dtype, index=[1])}) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("engine", ["python", "numexpr"]) + @pytest.mark.parametrize("dtype", ["int64", "Int64", "int64[pyarrow]"]) + def test_query_ea_equality_comparison(self, dtype, engine): + # GH#50261 + warning = RuntimeWarning if engine == "numexpr" else None + if engine == "numexpr" and not NUMEXPR_INSTALLED: + pytest.skip("numexpr not installed") + if dtype == "int64[pyarrow]": + pytest.importorskip("pyarrow") + df = DataFrame( + {"A": Series([1, 1, 2], dtype="Int64"), "B": Series([1, 2, 2], dtype=dtype)} + ) + with tm.assert_produces_warning(warning): + result = df.query("A == B", engine=engine) + expected = DataFrame( + { + "A": Series([1, 2], dtype="Int64", index=[0, 2]), + "B": Series([1, 2], dtype=dtype, index=[0, 2]), + } + ) + tm.assert_frame_equal(result, expected) + + def test_all_nat_in_object(self): + # GH#57068 + now = pd.Timestamp.now("UTC") # noqa: F841 + df = DataFrame({"a": pd.to_datetime([None, None], utc=True)}, dtype=object) + result = df.query("a > @now") + expected = DataFrame({"a": []}, dtype=object) + tm.assert_frame_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_reductions.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_reductions.py new file mode 100644 index 0000000000000000000000000000000000000000..66145c32c18d772f4316b8b32ddb9965a22b2abc --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_reductions.py @@ -0,0 +1,2157 @@ +from datetime import timedelta +from decimal import Decimal +import re + +from dateutil.tz import tzlocal +import numpy as np +import pytest + +from pandas._config import using_pyarrow_string_dtype + +from pandas.compat import ( + IS64, + is_platform_windows, +) +from pandas.compat.numpy import np_version_gt2 +import pandas.util._test_decorators as td + +import pandas as pd +from pandas import ( + Categorical, + CategoricalDtype, + DataFrame, + DatetimeIndex, + Index, + PeriodIndex, + RangeIndex, + Series, + Timestamp, + date_range, + isna, + notna, + to_datetime, + to_timedelta, +) +import pandas._testing as tm +from pandas.core import ( + algorithms, + nanops, +) + +is_windows_np2_or_is32 = (is_platform_windows() and not np_version_gt2) or not IS64 +is_windows_or_is32 = is_platform_windows() or not IS64 + + +def make_skipna_wrapper(alternative, skipna_alternative=None): + """ + Create a function for calling on an array. + + Parameters + ---------- + alternative : function + The function to be called on the array with no NaNs. + Only used when 'skipna_alternative' is None. + skipna_alternative : function + The function to be called on the original array + + Returns + ------- + function + """ + if skipna_alternative: + + def skipna_wrapper(x): + return skipna_alternative(x.values) + + else: + + def skipna_wrapper(x): + nona = x.dropna() + if len(nona) == 0: + return np.nan + return alternative(nona) + + return skipna_wrapper + + +def assert_stat_op_calc( + opname, + alternative, + frame, + has_skipna=True, + check_dtype=True, + check_dates=False, + rtol=1e-5, + atol=1e-8, + skipna_alternative=None, +): + """ + Check that operator opname works as advertised on frame + + Parameters + ---------- + opname : str + Name of the operator to test on frame + alternative : function + Function that opname is tested against; i.e. "frame.opname()" should + equal "alternative(frame)". + frame : DataFrame + The object that the tests are executed on + has_skipna : bool, default True + Whether the method "opname" has the kwarg "skip_na" + check_dtype : bool, default True + Whether the dtypes of the result of "frame.opname()" and + "alternative(frame)" should be checked. + check_dates : bool, default false + Whether opname should be tested on a Datetime Series + rtol : float, default 1e-5 + Relative tolerance. + atol : float, default 1e-8 + Absolute tolerance. + skipna_alternative : function, default None + NaN-safe version of alternative + """ + f = getattr(frame, opname) + + if check_dates: + df = DataFrame({"b": date_range("1/1/2001", periods=2)}) + with tm.assert_produces_warning(None): + result = getattr(df, opname)() + assert isinstance(result, Series) + + df["a"] = range(len(df)) + with tm.assert_produces_warning(None): + result = getattr(df, opname)() + assert isinstance(result, Series) + assert len(result) + + if has_skipna: + + def wrapper(x): + return alternative(x.values) + + skipna_wrapper = make_skipna_wrapper(alternative, skipna_alternative) + result0 = f(axis=0, skipna=False) + result1 = f(axis=1, skipna=False) + tm.assert_series_equal( + result0, frame.apply(wrapper), check_dtype=check_dtype, rtol=rtol, atol=atol + ) + tm.assert_series_equal( + result1, + frame.apply(wrapper, axis=1), + rtol=rtol, + atol=atol, + ) + else: + skipna_wrapper = alternative + + result0 = f(axis=0) + result1 = f(axis=1) + tm.assert_series_equal( + result0, + frame.apply(skipna_wrapper), + check_dtype=check_dtype, + rtol=rtol, + atol=atol, + ) + + if opname in ["sum", "prod"]: + expected = frame.apply(skipna_wrapper, axis=1) + tm.assert_series_equal( + result1, expected, check_dtype=False, rtol=rtol, atol=atol + ) + + # check dtypes + if check_dtype: + lcd_dtype = frame.values.dtype + assert lcd_dtype == result0.dtype + assert lcd_dtype == result1.dtype + + # bad axis + with pytest.raises(ValueError, match="No axis named 2"): + f(axis=2) + + # all NA case + if has_skipna: + all_na = frame * np.nan + r0 = getattr(all_na, opname)(axis=0) + r1 = getattr(all_na, opname)(axis=1) + if opname in ["sum", "prod"]: + unit = 1 if opname == "prod" else 0 # result for empty sum/prod + expected = Series(unit, index=r0.index, dtype=r0.dtype) + tm.assert_series_equal(r0, expected) + expected = Series(unit, index=r1.index, dtype=r1.dtype) + tm.assert_series_equal(r1, expected) + + +@pytest.fixture +def bool_frame_with_na(): + """ + Fixture for DataFrame of booleans with index of unique strings + + Columns are ['A', 'B', 'C', 'D']; some entries are missing + """ + df = DataFrame( + np.concatenate( + [np.ones((15, 4), dtype=bool), np.zeros((15, 4), dtype=bool)], axis=0 + ), + index=Index([f"foo_{i}" for i in range(30)], dtype=object), + columns=Index(list("ABCD"), dtype=object), + dtype=object, + ) + # set some NAs + df.iloc[5:10] = np.nan + df.iloc[15:20, -2:] = np.nan + return df + + +@pytest.fixture +def float_frame_with_na(): + """ + Fixture for DataFrame of floats with index of unique strings + + Columns are ['A', 'B', 'C', 'D']; some entries are missing + """ + df = DataFrame( + np.random.default_rng(2).standard_normal((30, 4)), + index=Index([f"foo_{i}" for i in range(30)], dtype=object), + columns=Index(list("ABCD"), dtype=object), + ) + # set some NAs + df.iloc[5:10] = np.nan + df.iloc[15:20, -2:] = np.nan + return df + + +class TestDataFrameAnalytics: + # --------------------------------------------------------------------- + # Reductions + @pytest.mark.parametrize("axis", [0, 1]) + @pytest.mark.parametrize( + "opname", + [ + "count", + "sum", + "mean", + "product", + "median", + "min", + "max", + "nunique", + "var", + "std", + "sem", + pytest.param("skew", marks=td.skip_if_no("scipy")), + pytest.param("kurt", marks=td.skip_if_no("scipy")), + ], + ) + def test_stat_op_api_float_string_frame( + self, float_string_frame, axis, opname, using_infer_string + ): + if ( + (opname in ("sum", "min", "max") and axis == 0) + or opname + in ( + "count", + "nunique", + ) + ) and not (using_infer_string and opname == "sum"): + getattr(float_string_frame, opname)(axis=axis) + else: + if opname in ["var", "std", "sem", "skew", "kurt"]: + msg = "could not convert string to float: 'bar'" + elif opname == "product": + if axis == 1: + msg = "can't multiply sequence by non-int of type 'float'" + else: + msg = "can't multiply sequence by non-int of type 'str'" + elif opname == "sum": + msg = r"unsupported operand type\(s\) for \+: 'float' and 'str'" + elif opname == "mean": + if axis == 0: + # different message on different builds + msg = "|".join( + [ + r"Could not convert \['.*'\] to numeric", + "Could not convert string '(bar){30}' to numeric", + ] + ) + else: + msg = r"unsupported operand type\(s\) for \+: 'float' and 'str'" + elif opname in ["min", "max"]: + msg = "'[><]=' not supported between instances of 'float' and 'str'" + elif opname == "median": + msg = re.compile( + r"Cannot convert \[.*\] to numeric|does not support", flags=re.S + ) + if not isinstance(msg, re.Pattern): + msg = msg + "|does not support" + with pytest.raises(TypeError, match=msg): + getattr(float_string_frame, opname)(axis=axis) + if opname != "nunique": + getattr(float_string_frame, opname)(axis=axis, numeric_only=True) + + @pytest.mark.parametrize("axis", [0, 1]) + @pytest.mark.parametrize( + "opname", + [ + "count", + "sum", + "mean", + "product", + "median", + "min", + "max", + "var", + "std", + "sem", + pytest.param("skew", marks=td.skip_if_no("scipy")), + pytest.param("kurt", marks=td.skip_if_no("scipy")), + ], + ) + def test_stat_op_api_float_frame(self, float_frame, axis, opname): + getattr(float_frame, opname)(axis=axis, numeric_only=False) + + def test_stat_op_calc(self, float_frame_with_na, mixed_float_frame): + def count(s): + return notna(s).sum() + + def nunique(s): + return len(algorithms.unique1d(s.dropna())) + + def var(x): + return np.var(x, ddof=1) + + def std(x): + return np.std(x, ddof=1) + + def sem(x): + return np.std(x, ddof=1) / np.sqrt(len(x)) + + assert_stat_op_calc( + "nunique", + nunique, + float_frame_with_na, + has_skipna=False, + check_dtype=False, + check_dates=True, + ) + + # GH#32571: rol needed for flaky CI builds + # mixed types (with upcasting happening) + assert_stat_op_calc( + "sum", + np.sum, + mixed_float_frame.astype("float32"), + check_dtype=False, + rtol=1e-3, + ) + + assert_stat_op_calc( + "sum", np.sum, float_frame_with_na, skipna_alternative=np.nansum + ) + assert_stat_op_calc("mean", np.mean, float_frame_with_na, check_dates=True) + assert_stat_op_calc( + "product", np.prod, float_frame_with_na, skipna_alternative=np.nanprod + ) + + assert_stat_op_calc("var", var, float_frame_with_na) + assert_stat_op_calc("std", std, float_frame_with_na) + assert_stat_op_calc("sem", sem, float_frame_with_na) + + assert_stat_op_calc( + "count", + count, + float_frame_with_na, + has_skipna=False, + check_dtype=False, + check_dates=True, + ) + + def test_stat_op_calc_skew_kurtosis(self, float_frame_with_na): + sp_stats = pytest.importorskip("scipy.stats") + + def skewness(x): + if len(x) < 3: + return np.nan + return sp_stats.skew(x, bias=False) + + def kurt(x): + if len(x) < 4: + return np.nan + return sp_stats.kurtosis(x, bias=False) + + assert_stat_op_calc("skew", skewness, float_frame_with_na) + assert_stat_op_calc("kurt", kurt, float_frame_with_na) + + def test_median(self, float_frame_with_na, int_frame): + def wrapper(x): + if isna(x).any(): + return np.nan + return np.median(x) + + assert_stat_op_calc("median", wrapper, float_frame_with_na, check_dates=True) + assert_stat_op_calc( + "median", wrapper, int_frame, check_dtype=False, check_dates=True + ) + + @pytest.mark.parametrize( + "method", ["sum", "mean", "prod", "var", "std", "skew", "min", "max"] + ) + @pytest.mark.parametrize( + "df", + [ + DataFrame( + { + "a": [ + -0.00049987540199591344, + -0.0016467257772919831, + 0.00067695870775883013, + ], + "b": [-0, -0, 0.0], + "c": [ + 0.00031111847529610595, + 0.0014902627951905339, + -0.00094099200035979691, + ], + }, + index=["foo", "bar", "baz"], + dtype="O", + ), + DataFrame({0: [np.nan, 2], 1: [np.nan, 3], 2: [np.nan, 4]}, dtype=object), + ], + ) + @pytest.mark.filterwarnings("ignore:Mismatched null-like values:FutureWarning") + def test_stat_operators_attempt_obj_array(self, method, df, axis): + # GH#676 + assert df.values.dtype == np.object_ + result = getattr(df, method)(axis=axis) + expected = getattr(df.astype("f8"), method)(axis=axis).astype(object) + if axis in [1, "columns"] and method in ["min", "max"]: + expected[expected.isna()] = None + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("op", ["mean", "std", "var", "skew", "kurt", "sem"]) + def test_mixed_ops(self, op): + # GH#16116 + df = DataFrame( + { + "int": [1, 2, 3, 4], + "float": [1.0, 2.0, 3.0, 4.0], + "str": ["a", "b", "c", "d"], + } + ) + msg = "|".join( + [ + "Could not convert", + "could not convert", + "can't multiply sequence by non-int", + "does not support", + ] + ) + with pytest.raises(TypeError, match=msg): + getattr(df, op)() + + with pd.option_context("use_bottleneck", False): + msg = "|".join( + [ + "Could not convert", + "could not convert", + "can't multiply sequence by non-int", + "does not support", + ] + ) + with pytest.raises(TypeError, match=msg): + getattr(df, op)() + + @pytest.mark.xfail( + using_pyarrow_string_dtype(), reason="sum doesn't work for arrow strings" + ) + def test_reduce_mixed_frame(self): + # GH 6806 + df = DataFrame( + { + "bool_data": [True, True, False, False, False], + "int_data": [10, 20, 30, 40, 50], + "string_data": ["a", "b", "c", "d", "e"], + } + ) + df.reindex(columns=["bool_data", "int_data", "string_data"]) + test = df.sum(axis=0) + tm.assert_numpy_array_equal( + test.values, np.array([2, 150, "abcde"], dtype=object) + ) + alt = df.T.sum(axis=1) + tm.assert_series_equal(test, alt) + + def test_nunique(self): + df = DataFrame({"A": [1, 1, 1], "B": [1, 2, 3], "C": [1, np.nan, 3]}) + tm.assert_series_equal(df.nunique(), Series({"A": 1, "B": 3, "C": 2})) + tm.assert_series_equal( + df.nunique(dropna=False), Series({"A": 1, "B": 3, "C": 3}) + ) + tm.assert_series_equal(df.nunique(axis=1), Series({0: 1, 1: 2, 2: 2})) + tm.assert_series_equal( + df.nunique(axis=1, dropna=False), Series({0: 1, 1: 3, 2: 2}) + ) + + @pytest.mark.parametrize("tz", [None, "UTC"]) + def test_mean_mixed_datetime_numeric(self, tz): + # https://github.com/pandas-dev/pandas/issues/24752 + df = DataFrame({"A": [1, 1], "B": [Timestamp("2000", tz=tz)] * 2}) + result = df.mean() + expected = Series([1.0, Timestamp("2000", tz=tz)], index=["A", "B"]) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("tz", [None, "UTC"]) + def test_mean_includes_datetimes(self, tz): + # https://github.com/pandas-dev/pandas/issues/24752 + # Behavior in 0.24.0rc1 was buggy. + # As of 2.0 with numeric_only=None we do *not* drop datetime columns + df = DataFrame({"A": [Timestamp("2000", tz=tz)] * 2}) + result = df.mean() + + expected = Series([Timestamp("2000", tz=tz)], index=["A"]) + tm.assert_series_equal(result, expected) + + def test_mean_mixed_string_decimal(self): + # GH 11670 + # possible bug when calculating mean of DataFrame? + + d = [ + {"A": 2, "B": None, "C": Decimal("628.00")}, + {"A": 1, "B": None, "C": Decimal("383.00")}, + {"A": 3, "B": None, "C": Decimal("651.00")}, + {"A": 2, "B": None, "C": Decimal("575.00")}, + {"A": 4, "B": None, "C": Decimal("1114.00")}, + {"A": 1, "B": "TEST", "C": Decimal("241.00")}, + {"A": 2, "B": None, "C": Decimal("572.00")}, + {"A": 4, "B": None, "C": Decimal("609.00")}, + {"A": 3, "B": None, "C": Decimal("820.00")}, + {"A": 5, "B": None, "C": Decimal("1223.00")}, + ] + + df = DataFrame(d) + + with pytest.raises( + TypeError, match="unsupported operand type|does not support" + ): + df.mean() + result = df[["A", "C"]].mean() + expected = Series([2.7, 681.6], index=["A", "C"], dtype=object) + tm.assert_series_equal(result, expected) + + def test_var_std(self, datetime_frame): + result = datetime_frame.std(ddof=4) + expected = datetime_frame.apply(lambda x: x.std(ddof=4)) + tm.assert_almost_equal(result, expected) + + result = datetime_frame.var(ddof=4) + expected = datetime_frame.apply(lambda x: x.var(ddof=4)) + tm.assert_almost_equal(result, expected) + + arr = np.repeat(np.random.default_rng(2).random((1, 1000)), 1000, 0) + result = nanops.nanvar(arr, axis=0) + assert not (result < 0).any() + + with pd.option_context("use_bottleneck", False): + result = nanops.nanvar(arr, axis=0) + assert not (result < 0).any() + + @pytest.mark.parametrize("meth", ["sem", "var", "std"]) + def test_numeric_only_flag(self, meth): + # GH 9201 + df1 = DataFrame( + np.random.default_rng(2).standard_normal((5, 3)), + columns=["foo", "bar", "baz"], + ) + # Cast to object to avoid implicit cast when setting entry to "100" below + df1 = df1.astype({"foo": object}) + # set one entry to a number in str format + df1.loc[0, "foo"] = "100" + + df2 = DataFrame( + np.random.default_rng(2).standard_normal((5, 3)), + columns=["foo", "bar", "baz"], + ) + # Cast to object to avoid implicit cast when setting entry to "a" below + df2 = df2.astype({"foo": object}) + # set one entry to a non-number str + df2.loc[0, "foo"] = "a" + + result = getattr(df1, meth)(axis=1, numeric_only=True) + expected = getattr(df1[["bar", "baz"]], meth)(axis=1) + tm.assert_series_equal(expected, result) + + result = getattr(df2, meth)(axis=1, numeric_only=True) + expected = getattr(df2[["bar", "baz"]], meth)(axis=1) + tm.assert_series_equal(expected, result) + + # df1 has all numbers, df2 has a letter inside + msg = r"unsupported operand type\(s\) for -: 'float' and 'str'" + with pytest.raises(TypeError, match=msg): + getattr(df1, meth)(axis=1, numeric_only=False) + msg = "could not convert string to float: 'a'" + with pytest.raises(TypeError, match=msg): + getattr(df2, meth)(axis=1, numeric_only=False) + + def test_sem(self, datetime_frame): + result = datetime_frame.sem(ddof=4) + expected = datetime_frame.apply(lambda x: x.std(ddof=4) / np.sqrt(len(x))) + tm.assert_almost_equal(result, expected) + + arr = np.repeat(np.random.default_rng(2).random((1, 1000)), 1000, 0) + result = nanops.nansem(arr, axis=0) + assert not (result < 0).any() + + with pd.option_context("use_bottleneck", False): + result = nanops.nansem(arr, axis=0) + assert not (result < 0).any() + + @pytest.mark.parametrize( + "dropna, expected", + [ + ( + True, + { + "A": [12], + "B": [10.0], + "C": [1.0], + "D": ["a"], + "E": Categorical(["a"], categories=["a"]), + "F": DatetimeIndex(["2000-01-02"], dtype="M8[ns]"), + "G": to_timedelta(["1 days"]), + }, + ), + ( + False, + { + "A": [12], + "B": [10.0], + "C": [np.nan], + "D": np.array([np.nan], dtype=object), + "E": Categorical([np.nan], categories=["a"]), + "F": DatetimeIndex([pd.NaT], dtype="M8[ns]"), + "G": to_timedelta([pd.NaT]), + }, + ), + ( + True, + { + "H": [8, 9, np.nan, np.nan], + "I": [8, 9, np.nan, np.nan], + "J": [1, np.nan, np.nan, np.nan], + "K": Categorical(["a", np.nan, np.nan, np.nan], categories=["a"]), + "L": DatetimeIndex( + ["2000-01-02", "NaT", "NaT", "NaT"], dtype="M8[ns]" + ), + "M": to_timedelta(["1 days", "nan", "nan", "nan"]), + "N": [0, 1, 2, 3], + }, + ), + ( + False, + { + "H": [8, 9, np.nan, np.nan], + "I": [8, 9, np.nan, np.nan], + "J": [1, np.nan, np.nan, np.nan], + "K": Categorical([np.nan, "a", np.nan, np.nan], categories=["a"]), + "L": DatetimeIndex( + ["NaT", "2000-01-02", "NaT", "NaT"], dtype="M8[ns]" + ), + "M": to_timedelta(["nan", "1 days", "nan", "nan"]), + "N": [0, 1, 2, 3], + }, + ), + ], + ) + def test_mode_dropna(self, dropna, expected): + df = DataFrame( + { + "A": [12, 12, 19, 11], + "B": [10, 10, np.nan, 3], + "C": [1, np.nan, np.nan, np.nan], + "D": Series([np.nan, np.nan, "a", np.nan], dtype=object), + "E": Categorical([np.nan, np.nan, "a", np.nan]), + "F": DatetimeIndex(["NaT", "2000-01-02", "NaT", "NaT"], dtype="M8[ns]"), + "G": to_timedelta(["1 days", "nan", "nan", "nan"]), + "H": [8, 8, 9, 9], + "I": [9, 9, 8, 8], + "J": [1, 1, np.nan, np.nan], + "K": Categorical(["a", np.nan, "a", np.nan]), + "L": DatetimeIndex( + ["2000-01-02", "2000-01-02", "NaT", "NaT"], dtype="M8[ns]" + ), + "M": to_timedelta(["1 days", "nan", "1 days", "nan"]), + "N": np.arange(4, dtype="int64"), + } + ) + + result = df[sorted(expected.keys())].mode(dropna=dropna) + expected = DataFrame(expected) + tm.assert_frame_equal(result, expected) + + def test_mode_sortwarning(self, using_infer_string): + # Check for the warning that is raised when the mode + # results cannot be sorted + + df = DataFrame({"A": [np.nan, np.nan, "a", "a"]}) + expected = DataFrame({"A": ["a", np.nan]}) + + warning = None if using_infer_string else UserWarning + with tm.assert_produces_warning(warning): + result = df.mode(dropna=False) + result = result.sort_values(by="A").reset_index(drop=True) + + tm.assert_frame_equal(result, expected) + + def test_mode_empty_df(self): + df = DataFrame([], columns=["a", "b"]) + result = df.mode() + expected = DataFrame([], columns=["a", "b"], index=Index([], dtype=np.int64)) + tm.assert_frame_equal(result, expected) + + def test_operators_timedelta64(self): + df = DataFrame( + { + "A": date_range("2012-1-1", periods=3, freq="D"), + "B": date_range("2012-1-2", periods=3, freq="D"), + "C": Timestamp("20120101") - timedelta(minutes=5, seconds=5), + } + ) + + diffs = DataFrame({"A": df["A"] - df["C"], "B": df["A"] - df["B"]}) + + # min + result = diffs.min() + assert result.iloc[0] == diffs.loc[0, "A"] + assert result.iloc[1] == diffs.loc[0, "B"] + + result = diffs.min(axis=1) + assert (result == diffs.loc[0, "B"]).all() + + # max + result = diffs.max() + assert result.iloc[0] == diffs.loc[2, "A"] + assert result.iloc[1] == diffs.loc[2, "B"] + + result = diffs.max(axis=1) + assert (result == diffs["A"]).all() + + # abs + result = diffs.abs() + result2 = abs(diffs) + expected = DataFrame({"A": df["A"] - df["C"], "B": df["B"] - df["A"]}) + tm.assert_frame_equal(result, expected) + tm.assert_frame_equal(result2, expected) + + # mixed frame + mixed = diffs.copy() + mixed["C"] = "foo" + mixed["D"] = 1 + mixed["E"] = 1.0 + mixed["F"] = Timestamp("20130101") + + # results in an object array + result = mixed.min() + expected = Series( + [ + pd.Timedelta(timedelta(seconds=5 * 60 + 5)), + pd.Timedelta(timedelta(days=-1)), + "foo", + 1, + 1.0, + Timestamp("20130101"), + ], + index=mixed.columns, + ) + tm.assert_series_equal(result, expected) + + # excludes non-numeric + result = mixed.min(axis=1, numeric_only=True) + expected = Series([1, 1, 1.0], index=[0, 1, 2]) + tm.assert_series_equal(result, expected) + + # works when only those columns are selected + result = mixed[["A", "B"]].min(1) + expected = Series([timedelta(days=-1)] * 3) + tm.assert_series_equal(result, expected) + + result = mixed[["A", "B"]].min() + expected = Series( + [timedelta(seconds=5 * 60 + 5), timedelta(days=-1)], index=["A", "B"] + ) + tm.assert_series_equal(result, expected) + + # GH 3106 + df = DataFrame( + { + "time": date_range("20130102", periods=5), + "time2": date_range("20130105", periods=5), + } + ) + df["off1"] = df["time2"] - df["time"] + assert df["off1"].dtype == "timedelta64[ns]" + + df["off2"] = df["time"] - df["time2"] + df._consolidate_inplace() + assert df["off1"].dtype == "timedelta64[ns]" + assert df["off2"].dtype == "timedelta64[ns]" + + def test_std_timedelta64_skipna_false(self): + # GH#37392 + tdi = pd.timedelta_range("1 Day", periods=10) + df = DataFrame({"A": tdi, "B": tdi}, copy=True) + df.iloc[-2, -1] = pd.NaT + + result = df.std(skipna=False) + expected = Series( + [df["A"].std(), pd.NaT], index=["A", "B"], dtype="timedelta64[ns]" + ) + tm.assert_series_equal(result, expected) + + result = df.std(axis=1, skipna=False) + expected = Series([pd.Timedelta(0)] * 8 + [pd.NaT, pd.Timedelta(0)]) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "values", [["2022-01-01", "2022-01-02", pd.NaT, "2022-01-03"], 4 * [pd.NaT]] + ) + def test_std_datetime64_with_nat( + self, values, skipna, using_array_manager, request, unit + ): + # GH#51335 + if using_array_manager and ( + not skipna or all(value is pd.NaT for value in values) + ): + mark = pytest.mark.xfail( + reason="GH#51446: Incorrect type inference on NaT in reduction result" + ) + request.applymarker(mark) + dti = to_datetime(values).as_unit(unit) + df = DataFrame({"a": dti}) + result = df.std(skipna=skipna) + if not skipna or all(value is pd.NaT for value in values): + expected = Series({"a": pd.NaT}, dtype=f"timedelta64[{unit}]") + else: + # 86400000000000ns == 1 day + expected = Series({"a": 86400000000000}, dtype=f"timedelta64[{unit}]") + tm.assert_series_equal(result, expected) + + def test_sum_corner(self): + empty_frame = DataFrame() + + axis0 = empty_frame.sum(0) + axis1 = empty_frame.sum(1) + assert isinstance(axis0, Series) + assert isinstance(axis1, Series) + assert len(axis0) == 0 + assert len(axis1) == 0 + + @pytest.mark.parametrize( + "index", + [ + RangeIndex(0), + DatetimeIndex([]), + Index([], dtype=np.int64), + Index([], dtype=np.float64), + DatetimeIndex([], freq="ME"), + PeriodIndex([], freq="D"), + ], + ) + def test_axis_1_empty(self, all_reductions, index): + df = DataFrame(columns=["a"], index=index) + result = getattr(df, all_reductions)(axis=1) + if all_reductions in ("any", "all"): + expected_dtype = "bool" + elif all_reductions == "count": + expected_dtype = "int64" + else: + expected_dtype = "object" + expected = Series([], index=index, dtype=expected_dtype) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("method, unit", [("sum", 0), ("prod", 1)]) + @pytest.mark.parametrize("numeric_only", [None, True, False]) + def test_sum_prod_nanops(self, method, unit, numeric_only): + idx = ["a", "b", "c"] + df = DataFrame({"a": [unit, unit], "b": [unit, np.nan], "c": [np.nan, np.nan]}) + # The default + result = getattr(df, method)(numeric_only=numeric_only) + expected = Series([unit, unit, unit], index=idx, dtype="float64") + tm.assert_series_equal(result, expected) + + # min_count=1 + result = getattr(df, method)(numeric_only=numeric_only, min_count=1) + expected = Series([unit, unit, np.nan], index=idx) + tm.assert_series_equal(result, expected) + + # min_count=0 + result = getattr(df, method)(numeric_only=numeric_only, min_count=0) + expected = Series([unit, unit, unit], index=idx, dtype="float64") + tm.assert_series_equal(result, expected) + + result = getattr(df.iloc[1:], method)(numeric_only=numeric_only, min_count=1) + expected = Series([unit, np.nan, np.nan], index=idx) + tm.assert_series_equal(result, expected) + + # min_count > 1 + df = DataFrame({"A": [unit] * 10, "B": [unit] * 5 + [np.nan] * 5}) + result = getattr(df, method)(numeric_only=numeric_only, min_count=5) + expected = Series(result, index=["A", "B"]) + tm.assert_series_equal(result, expected) + + result = getattr(df, method)(numeric_only=numeric_only, min_count=6) + expected = Series(result, index=["A", "B"]) + tm.assert_series_equal(result, expected) + + def test_sum_nanops_timedelta(self): + # prod isn't defined on timedeltas + idx = ["a", "b", "c"] + df = DataFrame({"a": [0, 0], "b": [0, np.nan], "c": [np.nan, np.nan]}) + + df2 = df.apply(to_timedelta) + + # 0 by default + result = df2.sum() + expected = Series([0, 0, 0], dtype="m8[ns]", index=idx) + tm.assert_series_equal(result, expected) + + # min_count=0 + result = df2.sum(min_count=0) + tm.assert_series_equal(result, expected) + + # min_count=1 + result = df2.sum(min_count=1) + expected = Series([0, 0, np.nan], dtype="m8[ns]", index=idx) + tm.assert_series_equal(result, expected) + + def test_sum_nanops_min_count(self): + # https://github.com/pandas-dev/pandas/issues/39738 + df = DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]}) + result = df.sum(min_count=10) + expected = Series([np.nan, np.nan], index=["x", "y"]) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("float_type", ["float16", "float32", "float64"]) + @pytest.mark.parametrize( + "kwargs, expected_result", + [ + ({"axis": 1, "min_count": 2}, [3.2, 5.3, np.nan]), + ({"axis": 1, "min_count": 3}, [np.nan, np.nan, np.nan]), + ({"axis": 1, "skipna": False}, [3.2, 5.3, np.nan]), + ], + ) + def test_sum_nanops_dtype_min_count(self, float_type, kwargs, expected_result): + # GH#46947 + df = DataFrame({"a": [1.0, 2.3, 4.4], "b": [2.2, 3, np.nan]}, dtype=float_type) + result = df.sum(**kwargs) + expected = Series(expected_result).astype(float_type) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("float_type", ["float16", "float32", "float64"]) + @pytest.mark.parametrize( + "kwargs, expected_result", + [ + ({"axis": 1, "min_count": 2}, [2.0, 4.0, np.nan]), + ({"axis": 1, "min_count": 3}, [np.nan, np.nan, np.nan]), + ({"axis": 1, "skipna": False}, [2.0, 4.0, np.nan]), + ], + ) + def test_prod_nanops_dtype_min_count(self, float_type, kwargs, expected_result): + # GH#46947 + df = DataFrame( + {"a": [1.0, 2.0, 4.4], "b": [2.0, 2.0, np.nan]}, dtype=float_type + ) + result = df.prod(**kwargs) + expected = Series(expected_result).astype(float_type) + tm.assert_series_equal(result, expected) + + def test_sum_object(self, float_frame): + values = float_frame.values.astype(int) + frame = DataFrame(values, index=float_frame.index, columns=float_frame.columns) + deltas = frame * timedelta(1) + deltas.sum() + + def test_sum_bool(self, float_frame): + # ensure this works, bug report + bools = np.isnan(float_frame) + bools.sum(1) + bools.sum(0) + + def test_sum_mixed_datetime(self): + # GH#30886 + df = DataFrame({"A": date_range("2000", periods=4), "B": [1, 2, 3, 4]}).reindex( + [2, 3, 4] + ) + with pytest.raises(TypeError, match="does not support reduction 'sum'"): + df.sum() + + def test_mean_corner(self, float_frame, float_string_frame): + # unit test when have object data + msg = "Could not convert|does not support" + with pytest.raises(TypeError, match=msg): + float_string_frame.mean(axis=0) + + # xs sum mixed type, just want to know it works... + with pytest.raises(TypeError, match="unsupported operand type"): + float_string_frame.mean(axis=1) + + # take mean of boolean column + float_frame["bool"] = float_frame["A"] > 0 + means = float_frame.mean(0) + assert means["bool"] == float_frame["bool"].values.mean() + + def test_mean_datetimelike(self): + # GH#24757 check that datetimelike are excluded by default, handled + # correctly with numeric_only=True + # As of 2.0, datetimelike are *not* excluded with numeric_only=None + + df = DataFrame( + { + "A": np.arange(3), + "B": date_range("2016-01-01", periods=3), + "C": pd.timedelta_range("1D", periods=3), + "D": pd.period_range("2016", periods=3, freq="Y"), + } + ) + result = df.mean(numeric_only=True) + expected = Series({"A": 1.0}) + tm.assert_series_equal(result, expected) + + with pytest.raises(TypeError, match="mean is not implemented for PeriodArray"): + df.mean() + + def test_mean_datetimelike_numeric_only_false(self): + df = DataFrame( + { + "A": np.arange(3), + "B": date_range("2016-01-01", periods=3), + "C": pd.timedelta_range("1D", periods=3), + } + ) + + # datetime(tz) and timedelta work + result = df.mean(numeric_only=False) + expected = Series({"A": 1, "B": df.loc[1, "B"], "C": df.loc[1, "C"]}) + tm.assert_series_equal(result, expected) + + # mean of period is not allowed + df["D"] = pd.period_range("2016", periods=3, freq="Y") + + with pytest.raises(TypeError, match="mean is not implemented for Period"): + df.mean(numeric_only=False) + + def test_mean_extensionarray_numeric_only_true(self): + # https://github.com/pandas-dev/pandas/issues/33256 + arr = np.random.default_rng(2).integers(1000, size=(10, 5)) + df = DataFrame(arr, dtype="Int64") + result = df.mean(numeric_only=True) + expected = DataFrame(arr).mean().astype("Float64") + tm.assert_series_equal(result, expected) + + def test_stats_mixed_type(self, float_string_frame): + with pytest.raises(TypeError, match="could not convert"): + float_string_frame.std(1) + with pytest.raises(TypeError, match="could not convert"): + float_string_frame.var(1) + with pytest.raises(TypeError, match="unsupported operand type"): + float_string_frame.mean(1) + with pytest.raises(TypeError, match="could not convert"): + float_string_frame.skew(1) + + def test_sum_bools(self): + df = DataFrame(index=range(1), columns=range(10)) + bools = isna(df) + assert bools.sum(axis=1)[0] == 10 + + # ---------------------------------------------------------------------- + # Index of max / min + + @pytest.mark.parametrize("skipna", [True, False]) + @pytest.mark.parametrize("axis", [0, 1]) + def test_idxmin(self, float_frame, int_frame, skipna, axis): + frame = float_frame + frame.iloc[5:10] = np.nan + frame.iloc[15:20, -2:] = np.nan + for df in [frame, int_frame]: + warn = None + if skipna is False or axis == 1: + warn = None if df is int_frame else FutureWarning + msg = "The behavior of DataFrame.idxmin with all-NA values" + with tm.assert_produces_warning(warn, match=msg): + result = df.idxmin(axis=axis, skipna=skipna) + + msg2 = "The behavior of Series.idxmin" + with tm.assert_produces_warning(warn, match=msg2): + expected = df.apply(Series.idxmin, axis=axis, skipna=skipna) + expected = expected.astype(df.index.dtype) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("axis", [0, 1]) + @pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning") + def test_idxmin_empty(self, index, skipna, axis): + # GH53265 + if axis == 0: + frame = DataFrame(index=index) + else: + frame = DataFrame(columns=index) + + result = frame.idxmin(axis=axis, skipna=skipna) + expected = Series(dtype=index.dtype) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("numeric_only", [True, False]) + def test_idxmin_numeric_only(self, numeric_only): + df = DataFrame({"a": [2, 3, 1], "b": [2, 1, 1], "c": list("xyx")}) + result = df.idxmin(numeric_only=numeric_only) + if numeric_only: + expected = Series([2, 1], index=["a", "b"]) + else: + expected = Series([2, 1, 0], index=["a", "b", "c"]) + tm.assert_series_equal(result, expected) + + def test_idxmin_axis_2(self, float_frame): + frame = float_frame + msg = "No axis named 2 for object type DataFrame" + with pytest.raises(ValueError, match=msg): + frame.idxmin(axis=2) + + @pytest.mark.parametrize("skipna", [True, False]) + @pytest.mark.parametrize("axis", [0, 1]) + def test_idxmax(self, float_frame, int_frame, skipna, axis): + frame = float_frame + frame.iloc[5:10] = np.nan + frame.iloc[15:20, -2:] = np.nan + for df in [frame, int_frame]: + warn = None + if skipna is False or axis == 1: + warn = None if df is int_frame else FutureWarning + msg = "The behavior of DataFrame.idxmax with all-NA values" + with tm.assert_produces_warning(warn, match=msg): + result = df.idxmax(axis=axis, skipna=skipna) + + msg2 = "The behavior of Series.idxmax" + with tm.assert_produces_warning(warn, match=msg2): + expected = df.apply(Series.idxmax, axis=axis, skipna=skipna) + expected = expected.astype(df.index.dtype) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("axis", [0, 1]) + @pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning") + def test_idxmax_empty(self, index, skipna, axis): + # GH53265 + if axis == 0: + frame = DataFrame(index=index) + else: + frame = DataFrame(columns=index) + + result = frame.idxmax(axis=axis, skipna=skipna) + expected = Series(dtype=index.dtype) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("numeric_only", [True, False]) + def test_idxmax_numeric_only(self, numeric_only): + df = DataFrame({"a": [2, 3, 1], "b": [2, 1, 1], "c": list("xyx")}) + result = df.idxmax(numeric_only=numeric_only) + if numeric_only: + expected = Series([1, 0], index=["a", "b"]) + else: + expected = Series([1, 0, 1], index=["a", "b", "c"]) + tm.assert_series_equal(result, expected) + + def test_idxmax_arrow_types(self): + # GH#55368 + pytest.importorskip("pyarrow") + + df = DataFrame({"a": [2, 3, 1], "b": [2, 1, 1]}, dtype="int64[pyarrow]") + result = df.idxmax() + expected = Series([1, 0], index=["a", "b"]) + tm.assert_series_equal(result, expected) + + result = df.idxmin() + expected = Series([2, 1], index=["a", "b"]) + tm.assert_series_equal(result, expected) + + df = DataFrame({"a": ["b", "c", "a"]}, dtype="string[pyarrow]") + result = df.idxmax(numeric_only=False) + expected = Series([1], index=["a"]) + tm.assert_series_equal(result, expected) + + result = df.idxmin(numeric_only=False) + expected = Series([2], index=["a"]) + tm.assert_series_equal(result, expected) + + def test_idxmax_axis_2(self, float_frame): + frame = float_frame + msg = "No axis named 2 for object type DataFrame" + with pytest.raises(ValueError, match=msg): + frame.idxmax(axis=2) + + def test_idxmax_mixed_dtype(self): + # don't cast to object, which would raise in nanops + dti = date_range("2016-01-01", periods=3) + + # Copying dti is needed for ArrayManager otherwise when we set + # df.loc[0, 3] = pd.NaT below it edits dti + df = DataFrame({1: [0, 2, 1], 2: range(3)[::-1], 3: dti.copy(deep=True)}) + + result = df.idxmax() + expected = Series([1, 0, 2], index=[1, 2, 3]) + tm.assert_series_equal(result, expected) + + result = df.idxmin() + expected = Series([0, 2, 0], index=[1, 2, 3]) + tm.assert_series_equal(result, expected) + + # with NaTs + df.loc[0, 3] = pd.NaT + result = df.idxmax() + expected = Series([1, 0, 2], index=[1, 2, 3]) + tm.assert_series_equal(result, expected) + + result = df.idxmin() + expected = Series([0, 2, 1], index=[1, 2, 3]) + tm.assert_series_equal(result, expected) + + # with multi-column dt64 block + df[4] = dti[::-1] + df._consolidate_inplace() + + result = df.idxmax() + expected = Series([1, 0, 2, 0], index=[1, 2, 3, 4]) + tm.assert_series_equal(result, expected) + + result = df.idxmin() + expected = Series([0, 2, 1, 2], index=[1, 2, 3, 4]) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "op, expected_value", + [("idxmax", [0, 4]), ("idxmin", [0, 5])], + ) + def test_idxmax_idxmin_convert_dtypes(self, op, expected_value): + # GH 40346 + df = DataFrame( + { + "ID": [100, 100, 100, 200, 200, 200], + "value": [0, 0, 0, 1, 2, 0], + }, + dtype="Int64", + ) + df = df.groupby("ID") + + result = getattr(df, op)() + expected = DataFrame( + {"value": expected_value}, + index=Index([100, 200], name="ID", dtype="Int64"), + ) + tm.assert_frame_equal(result, expected) + + def test_idxmax_dt64_multicolumn_axis1(self): + dti = date_range("2016-01-01", periods=3) + df = DataFrame({3: dti, 4: dti[::-1]}, copy=True) + df.iloc[0, 0] = pd.NaT + + df._consolidate_inplace() + + result = df.idxmax(axis=1) + expected = Series([4, 3, 3]) + tm.assert_series_equal(result, expected) + + result = df.idxmin(axis=1) + expected = Series([4, 3, 4]) + tm.assert_series_equal(result, expected) + + # ---------------------------------------------------------------------- + # Logical reductions + + @pytest.mark.parametrize("opname", ["any", "all"]) + @pytest.mark.parametrize("axis", [0, 1]) + @pytest.mark.parametrize("bool_only", [False, True]) + def test_any_all_mixed_float(self, opname, axis, bool_only, float_string_frame): + # make sure op works on mixed-type frame + mixed = float_string_frame + mixed["_bool_"] = np.random.default_rng(2).standard_normal(len(mixed)) > 0.5 + + getattr(mixed, opname)(axis=axis, bool_only=bool_only) + + @pytest.mark.parametrize("opname", ["any", "all"]) + @pytest.mark.parametrize("axis", [0, 1]) + def test_any_all_bool_with_na(self, opname, axis, bool_frame_with_na): + getattr(bool_frame_with_na, opname)(axis=axis, bool_only=False) + + @pytest.mark.filterwarnings("ignore:Downcasting object dtype arrays:FutureWarning") + @pytest.mark.parametrize("opname", ["any", "all"]) + def test_any_all_bool_frame(self, opname, bool_frame_with_na): + # GH#12863: numpy gives back non-boolean data for object type + # so fill NaNs to compare with pandas behavior + frame = bool_frame_with_na.fillna(True) + alternative = getattr(np, opname) + f = getattr(frame, opname) + + def skipna_wrapper(x): + nona = x.dropna().values + return alternative(nona) + + def wrapper(x): + return alternative(x.values) + + result0 = f(axis=0, skipna=False) + result1 = f(axis=1, skipna=False) + + tm.assert_series_equal(result0, frame.apply(wrapper)) + tm.assert_series_equal(result1, frame.apply(wrapper, axis=1)) + + result0 = f(axis=0) + result1 = f(axis=1) + + tm.assert_series_equal(result0, frame.apply(skipna_wrapper)) + tm.assert_series_equal( + result1, frame.apply(skipna_wrapper, axis=1), check_dtype=False + ) + + # bad axis + with pytest.raises(ValueError, match="No axis named 2"): + f(axis=2) + + # all NA case + all_na = frame * np.nan + r0 = getattr(all_na, opname)(axis=0) + r1 = getattr(all_na, opname)(axis=1) + if opname == "any": + assert not r0.any() + assert not r1.any() + else: + assert r0.all() + assert r1.all() + + def test_any_all_extra(self): + df = DataFrame( + { + "A": [True, False, False], + "B": [True, True, False], + "C": [True, True, True], + }, + index=["a", "b", "c"], + ) + result = df[["A", "B"]].any(axis=1) + expected = Series([True, True, False], index=["a", "b", "c"]) + tm.assert_series_equal(result, expected) + + result = df[["A", "B"]].any(axis=1, bool_only=True) + tm.assert_series_equal(result, expected) + + result = df.all(1) + expected = Series([True, False, False], index=["a", "b", "c"]) + tm.assert_series_equal(result, expected) + + result = df.all(1, bool_only=True) + tm.assert_series_equal(result, expected) + + # Axis is None + result = df.all(axis=None).item() + assert result is False + + result = df.any(axis=None).item() + assert result is True + + result = df[["C"]].all(axis=None).item() + assert result is True + + @pytest.mark.parametrize("axis", [0, 1]) + @pytest.mark.parametrize("bool_agg_func", ["any", "all"]) + @pytest.mark.parametrize("skipna", [True, False]) + def test_any_all_object_dtype( + self, axis, bool_agg_func, skipna, using_infer_string + ): + # GH#35450 + df = DataFrame( + data=[ + [1, np.nan, np.nan, True], + [np.nan, 2, np.nan, True], + [np.nan, np.nan, np.nan, True], + [np.nan, np.nan, "5", np.nan], + ] + ) + if using_infer_string: + # na in object is True while in string pyarrow numpy it's false + val = not axis == 0 and not skipna and bool_agg_func == "all" + else: + val = True + result = getattr(df, bool_agg_func)(axis=axis, skipna=skipna) + expected = Series([True, True, val, True]) + tm.assert_series_equal(result, expected) + + # GH#50947 deprecates this but it is not emitting a warning in some builds. + @pytest.mark.filterwarnings( + "ignore:'any' with datetime64 dtypes is deprecated.*:FutureWarning" + ) + def test_any_datetime(self): + # GH 23070 + float_data = [1, np.nan, 3, np.nan] + datetime_data = [ + Timestamp("1960-02-15"), + Timestamp("1960-02-16"), + pd.NaT, + pd.NaT, + ] + df = DataFrame({"A": float_data, "B": datetime_data}) + + result = df.any(axis=1) + + expected = Series([True, True, True, False]) + tm.assert_series_equal(result, expected) + + def test_any_all_bool_only(self): + # GH 25101 + df = DataFrame( + {"col1": [1, 2, 3], "col2": [4, 5, 6], "col3": [None, None, None]}, + columns=Index(["col1", "col2", "col3"], dtype=object), + ) + + result = df.all(bool_only=True) + expected = Series(dtype=np.bool_, index=[]) + tm.assert_series_equal(result, expected) + + df = DataFrame( + { + "col1": [1, 2, 3], + "col2": [4, 5, 6], + "col3": [None, None, None], + "col4": [False, False, True], + } + ) + + result = df.all(bool_only=True) + expected = Series({"col4": False}) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "func, data, expected", + [ + (np.any, {}, False), + (np.all, {}, True), + (np.any, {"A": []}, False), + (np.all, {"A": []}, True), + (np.any, {"A": [False, False]}, False), + (np.all, {"A": [False, False]}, False), + (np.any, {"A": [True, False]}, True), + (np.all, {"A": [True, False]}, False), + (np.any, {"A": [True, True]}, True), + (np.all, {"A": [True, True]}, True), + (np.any, {"A": [False], "B": [False]}, False), + (np.all, {"A": [False], "B": [False]}, False), + (np.any, {"A": [False, False], "B": [False, True]}, True), + (np.all, {"A": [False, False], "B": [False, True]}, False), + # other types + (np.all, {"A": Series([0.0, 1.0], dtype="float")}, False), + (np.any, {"A": Series([0.0, 1.0], dtype="float")}, True), + (np.all, {"A": Series([0, 1], dtype=int)}, False), + (np.any, {"A": Series([0, 1], dtype=int)}, True), + pytest.param(np.all, {"A": Series([0, 1], dtype="M8[ns]")}, False), + pytest.param(np.all, {"A": Series([0, 1], dtype="M8[ns, UTC]")}, False), + pytest.param(np.any, {"A": Series([0, 1], dtype="M8[ns]")}, True), + pytest.param(np.any, {"A": Series([0, 1], dtype="M8[ns, UTC]")}, True), + pytest.param(np.all, {"A": Series([1, 2], dtype="M8[ns]")}, True), + pytest.param(np.all, {"A": Series([1, 2], dtype="M8[ns, UTC]")}, True), + pytest.param(np.any, {"A": Series([1, 2], dtype="M8[ns]")}, True), + pytest.param(np.any, {"A": Series([1, 2], dtype="M8[ns, UTC]")}, True), + pytest.param(np.all, {"A": Series([0, 1], dtype="m8[ns]")}, False), + pytest.param(np.any, {"A": Series([0, 1], dtype="m8[ns]")}, True), + pytest.param(np.all, {"A": Series([1, 2], dtype="m8[ns]")}, True), + pytest.param(np.any, {"A": Series([1, 2], dtype="m8[ns]")}, True), + # np.all on Categorical raises, so the reduction drops the + # column, so all is being done on an empty Series, so is True + (np.all, {"A": Series([0, 1], dtype="category")}, True), + (np.any, {"A": Series([0, 1], dtype="category")}, False), + (np.all, {"A": Series([1, 2], dtype="category")}, True), + (np.any, {"A": Series([1, 2], dtype="category")}, False), + # Mix GH#21484 + pytest.param( + np.all, + { + "A": Series([10, 20], dtype="M8[ns]"), + "B": Series([10, 20], dtype="m8[ns]"), + }, + True, + ), + ], + ) + def test_any_all_np_func(self, func, data, expected): + # GH 19976 + data = DataFrame(data) + + if any(isinstance(x, CategoricalDtype) for x in data.dtypes): + with pytest.raises( + TypeError, match="dtype category does not support reduction" + ): + func(data) + + # method version + with pytest.raises( + TypeError, match="dtype category does not support reduction" + ): + getattr(DataFrame(data), func.__name__)(axis=None) + else: + msg = "'(any|all)' with datetime64 dtypes is deprecated" + if data.dtypes.apply(lambda x: x.kind == "M").any(): + warn = FutureWarning + else: + warn = None + + with tm.assert_produces_warning(warn, match=msg, check_stacklevel=False): + # GH#34479 + result = func(data) + assert isinstance(result, np.bool_) + assert result.item() is expected + + # method version + with tm.assert_produces_warning(warn, match=msg): + # GH#34479 + result = getattr(DataFrame(data), func.__name__)(axis=None) + assert isinstance(result, np.bool_) + assert result.item() is expected + + def test_any_all_object(self): + # GH 19976 + result = np.all(DataFrame(columns=["a", "b"])).item() + assert result is True + + result = np.any(DataFrame(columns=["a", "b"])).item() + assert result is False + + def test_any_all_object_bool_only(self): + df = DataFrame({"A": ["foo", 2], "B": [True, False]}).astype(object) + df._consolidate_inplace() + df["C"] = Series([True, True]) + + # Categorical of bools is _not_ considered booly + df["D"] = df["C"].astype("category") + + # The underlying bug is in DataFrame._get_bool_data, so we check + # that while we're here + res = df._get_bool_data() + expected = df[["C"]] + tm.assert_frame_equal(res, expected) + + res = df.all(bool_only=True, axis=0) + expected = Series([True], index=["C"]) + tm.assert_series_equal(res, expected) + + # operating on a subset of columns should not produce a _larger_ Series + res = df[["B", "C"]].all(bool_only=True, axis=0) + tm.assert_series_equal(res, expected) + + assert df.all(bool_only=True, axis=None) + + res = df.any(bool_only=True, axis=0) + expected = Series([True], index=["C"]) + tm.assert_series_equal(res, expected) + + # operating on a subset of columns should not produce a _larger_ Series + res = df[["C"]].any(bool_only=True, axis=0) + tm.assert_series_equal(res, expected) + + assert df.any(bool_only=True, axis=None) + + # --------------------------------------------------------------------- + # Unsorted + + def test_series_broadcasting(self): + # smoke test for numpy warnings + # GH 16378, GH 16306 + df = DataFrame([1.0, 1.0, 1.0]) + df_nan = DataFrame({"A": [np.nan, 2.0, np.nan]}) + s = Series([1, 1, 1]) + s_nan = Series([np.nan, np.nan, 1]) + + with tm.assert_produces_warning(None): + df_nan.clip(lower=s, axis=0) + for op in ["lt", "le", "gt", "ge", "eq", "ne"]: + getattr(df, op)(s_nan, axis=0) + + +class TestDataFrameReductions: + def test_min_max_dt64_with_NaT(self): + # Both NaT and Timestamp are in DataFrame. + df = DataFrame({"foo": [pd.NaT, pd.NaT, Timestamp("2012-05-01")]}) + + res = df.min() + exp = Series([Timestamp("2012-05-01")], index=["foo"]) + tm.assert_series_equal(res, exp) + + res = df.max() + exp = Series([Timestamp("2012-05-01")], index=["foo"]) + tm.assert_series_equal(res, exp) + + # GH12941, only NaTs are in DataFrame. + df = DataFrame({"foo": [pd.NaT, pd.NaT]}) + + res = df.min() + exp = Series([pd.NaT], index=["foo"]) + tm.assert_series_equal(res, exp) + + res = df.max() + exp = Series([pd.NaT], index=["foo"]) + tm.assert_series_equal(res, exp) + + def test_min_max_dt64_with_NaT_skipna_false(self, request, tz_naive_fixture): + # GH#36907 + tz = tz_naive_fixture + if isinstance(tz, tzlocal) and is_platform_windows(): + pytest.skip( + "GH#37659 OSError raised within tzlocal bc Windows " + "chokes in times before 1970-01-01" + ) + + df = DataFrame( + { + "a": [ + Timestamp("2020-01-01 08:00:00", tz=tz), + Timestamp("1920-02-01 09:00:00", tz=tz), + ], + "b": [Timestamp("2020-02-01 08:00:00", tz=tz), pd.NaT], + } + ) + res = df.min(axis=1, skipna=False) + expected = Series([df.loc[0, "a"], pd.NaT]) + assert expected.dtype == df["a"].dtype + + tm.assert_series_equal(res, expected) + + res = df.max(axis=1, skipna=False) + expected = Series([df.loc[0, "b"], pd.NaT]) + assert expected.dtype == df["a"].dtype + + tm.assert_series_equal(res, expected) + + def test_min_max_dt64_api_consistency_with_NaT(self): + # Calling the following sum functions returned an error for dataframes but + # returned NaT for series. These tests check that the API is consistent in + # min/max calls on empty Series/DataFrames. See GH:33704 for more + # information + df = DataFrame({"x": to_datetime([])}) + expected_dt_series = Series(to_datetime([])) + # check axis 0 + assert (df.min(axis=0).x is pd.NaT) == (expected_dt_series.min() is pd.NaT) + assert (df.max(axis=0).x is pd.NaT) == (expected_dt_series.max() is pd.NaT) + + # check axis 1 + tm.assert_series_equal(df.min(axis=1), expected_dt_series) + tm.assert_series_equal(df.max(axis=1), expected_dt_series) + + def test_min_max_dt64_api_consistency_empty_df(self): + # check DataFrame/Series api consistency when calling min/max on an empty + # DataFrame/Series. + df = DataFrame({"x": []}) + expected_float_series = Series([], dtype=float) + # check axis 0 + assert np.isnan(df.min(axis=0).x) == np.isnan(expected_float_series.min()) + assert np.isnan(df.max(axis=0).x) == np.isnan(expected_float_series.max()) + # check axis 1 + tm.assert_series_equal(df.min(axis=1), expected_float_series) + tm.assert_series_equal(df.min(axis=1), expected_float_series) + + @pytest.mark.parametrize( + "initial", + ["2018-10-08 13:36:45+00:00", "2018-10-08 13:36:45+03:00"], # Non-UTC timezone + ) + @pytest.mark.parametrize("method", ["min", "max"]) + def test_preserve_timezone(self, initial: str, method): + # GH 28552 + initial_dt = to_datetime(initial) + expected = Series([initial_dt]) + df = DataFrame([expected]) + result = getattr(df, method)(axis=1) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("method", ["min", "max"]) + def test_minmax_tzaware_skipna_axis_1(self, method, skipna): + # GH#51242 + val = to_datetime("1900-01-01", utc=True) + df = DataFrame( + {"a": Series([pd.NaT, pd.NaT, val]), "b": Series([pd.NaT, val, val])} + ) + op = getattr(df, method) + result = op(axis=1, skipna=skipna) + if skipna: + expected = Series([pd.NaT, val, val]) + else: + expected = Series([pd.NaT, pd.NaT, val]) + tm.assert_series_equal(result, expected) + + def test_frame_any_with_timedelta(self): + # GH#17667 + df = DataFrame( + { + "a": Series([0, 0]), + "t": Series([to_timedelta(0, "s"), to_timedelta(1, "ms")]), + } + ) + + result = df.any(axis=0) + expected = Series(data=[False, True], index=["a", "t"]) + tm.assert_series_equal(result, expected) + + result = df.any(axis=1) + expected = Series(data=[False, True]) + tm.assert_series_equal(result, expected) + + def test_reductions_skipna_none_raises( + self, request, frame_or_series, all_reductions + ): + if all_reductions == "count": + request.applymarker( + pytest.mark.xfail(reason="Count does not accept skipna") + ) + obj = frame_or_series([1, 2, 3]) + msg = 'For argument "skipna" expected type bool, received type NoneType.' + with pytest.raises(ValueError, match=msg): + getattr(obj, all_reductions)(skipna=None) + + @td.skip_array_manager_invalid_test + def test_reduction_timestamp_smallest_unit(self): + # GH#52524 + df = DataFrame( + { + "a": Series([Timestamp("2019-12-31")], dtype="datetime64[s]"), + "b": Series( + [Timestamp("2019-12-31 00:00:00.123")], dtype="datetime64[ms]" + ), + } + ) + result = df.max() + expected = Series( + [Timestamp("2019-12-31"), Timestamp("2019-12-31 00:00:00.123")], + dtype="datetime64[ms]", + index=["a", "b"], + ) + tm.assert_series_equal(result, expected) + + @td.skip_array_manager_not_yet_implemented + def test_reduction_timedelta_smallest_unit(self): + # GH#52524 + df = DataFrame( + { + "a": Series([pd.Timedelta("1 days")], dtype="timedelta64[s]"), + "b": Series([pd.Timedelta("1 days")], dtype="timedelta64[ms]"), + } + ) + result = df.max() + expected = Series( + [pd.Timedelta("1 days"), pd.Timedelta("1 days")], + dtype="timedelta64[ms]", + index=["a", "b"], + ) + tm.assert_series_equal(result, expected) + + +class TestNuisanceColumns: + @pytest.mark.parametrize("method", ["any", "all"]) + def test_any_all_categorical_dtype_nuisance_column(self, method): + # GH#36076 DataFrame should match Series behavior + ser = Series([0, 1], dtype="category", name="A") + df = ser.to_frame() + + # Double-check the Series behavior is to raise + with pytest.raises(TypeError, match="does not support reduction"): + getattr(ser, method)() + + with pytest.raises(TypeError, match="does not support reduction"): + getattr(np, method)(ser) + + with pytest.raises(TypeError, match="does not support reduction"): + getattr(df, method)(bool_only=False) + + with pytest.raises(TypeError, match="does not support reduction"): + getattr(df, method)(bool_only=None) + + with pytest.raises(TypeError, match="does not support reduction"): + getattr(np, method)(df, axis=0) + + def test_median_categorical_dtype_nuisance_column(self): + # GH#21020 DataFrame.median should match Series.median + df = DataFrame({"A": Categorical([1, 2, 2, 2, 3])}) + ser = df["A"] + + # Double-check the Series behavior is to raise + with pytest.raises(TypeError, match="does not support reduction"): + ser.median() + + with pytest.raises(TypeError, match="does not support reduction"): + df.median(numeric_only=False) + + with pytest.raises(TypeError, match="does not support reduction"): + df.median() + + # same thing, but with an additional non-categorical column + df["B"] = df["A"].astype(int) + + with pytest.raises(TypeError, match="does not support reduction"): + df.median(numeric_only=False) + + with pytest.raises(TypeError, match="does not support reduction"): + df.median() + + # TODO: np.median(df, axis=0) gives np.array([2.0, 2.0]) instead + # of expected.values + + @pytest.mark.parametrize("method", ["min", "max"]) + def test_min_max_categorical_dtype_non_ordered_nuisance_column(self, method): + # GH#28949 DataFrame.min should behave like Series.min + cat = Categorical(["a", "b", "c", "b"], ordered=False) + ser = Series(cat) + df = ser.to_frame("A") + + # Double-check the Series behavior + with pytest.raises(TypeError, match="is not ordered for operation"): + getattr(ser, method)() + + with pytest.raises(TypeError, match="is not ordered for operation"): + getattr(np, method)(ser) + + with pytest.raises(TypeError, match="is not ordered for operation"): + getattr(df, method)(numeric_only=False) + + with pytest.raises(TypeError, match="is not ordered for operation"): + getattr(df, method)() + + with pytest.raises(TypeError, match="is not ordered for operation"): + getattr(np, method)(df, axis=0) + + # same thing, but with an additional non-categorical column + df["B"] = df["A"].astype(object) + with pytest.raises(TypeError, match="is not ordered for operation"): + getattr(df, method)() + + with pytest.raises(TypeError, match="is not ordered for operation"): + getattr(np, method)(df, axis=0) + + +class TestEmptyDataFrameReductions: + @pytest.mark.parametrize( + "opname, dtype, exp_value, exp_dtype", + [ + ("sum", np.int8, 0, np.int64), + ("prod", np.int8, 1, np.int_), + ("sum", np.int64, 0, np.int64), + ("prod", np.int64, 1, np.int64), + ("sum", np.uint8, 0, np.uint64), + ("prod", np.uint8, 1, np.uint), + ("sum", np.uint64, 0, np.uint64), + ("prod", np.uint64, 1, np.uint64), + ("sum", np.float32, 0, np.float32), + ("prod", np.float32, 1, np.float32), + ("sum", np.float64, 0, np.float64), + ], + ) + def test_df_empty_min_count_0(self, opname, dtype, exp_value, exp_dtype): + df = DataFrame({0: [], 1: []}, dtype=dtype) + result = getattr(df, opname)(min_count=0) + + expected = Series([exp_value, exp_value], dtype=exp_dtype) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "opname, dtype, exp_dtype", + [ + ("sum", np.int8, np.float64), + ("prod", np.int8, np.float64), + ("sum", np.int64, np.float64), + ("prod", np.int64, np.float64), + ("sum", np.uint8, np.float64), + ("prod", np.uint8, np.float64), + ("sum", np.uint64, np.float64), + ("prod", np.uint64, np.float64), + ("sum", np.float32, np.float32), + ("prod", np.float32, np.float32), + ("sum", np.float64, np.float64), + ], + ) + def test_df_empty_min_count_1(self, opname, dtype, exp_dtype): + df = DataFrame({0: [], 1: []}, dtype=dtype) + result = getattr(df, opname)(min_count=1) + + expected = Series([np.nan, np.nan], dtype=exp_dtype) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "opname, dtype, exp_value, exp_dtype", + [ + ("sum", "Int8", 0, ("Int32" if is_windows_np2_or_is32 else "Int64")), + ("prod", "Int8", 1, ("Int32" if is_windows_np2_or_is32 else "Int64")), + ("prod", "Int8", 1, ("Int32" if is_windows_np2_or_is32 else "Int64")), + ("sum", "Int64", 0, "Int64"), + ("prod", "Int64", 1, "Int64"), + ("sum", "UInt8", 0, ("UInt32" if is_windows_np2_or_is32 else "UInt64")), + ("prod", "UInt8", 1, ("UInt32" if is_windows_np2_or_is32 else "UInt64")), + ("sum", "UInt64", 0, "UInt64"), + ("prod", "UInt64", 1, "UInt64"), + ("sum", "Float32", 0, "Float32"), + ("prod", "Float32", 1, "Float32"), + ("sum", "Float64", 0, "Float64"), + ], + ) + def test_df_empty_nullable_min_count_0(self, opname, dtype, exp_value, exp_dtype): + df = DataFrame({0: [], 1: []}, dtype=dtype) + result = getattr(df, opname)(min_count=0) + + expected = Series([exp_value, exp_value], dtype=exp_dtype) + tm.assert_series_equal(result, expected) + + # TODO: why does min_count=1 impact the resulting Windows dtype + # differently than min_count=0? + @pytest.mark.parametrize( + "opname, dtype, exp_dtype", + [ + ("sum", "Int8", ("Int32" if is_windows_or_is32 else "Int64")), + ("prod", "Int8", ("Int32" if is_windows_or_is32 else "Int64")), + ("sum", "Int64", "Int64"), + ("prod", "Int64", "Int64"), + ("sum", "UInt8", ("UInt32" if is_windows_or_is32 else "UInt64")), + ("prod", "UInt8", ("UInt32" if is_windows_or_is32 else "UInt64")), + ("sum", "UInt64", "UInt64"), + ("prod", "UInt64", "UInt64"), + ("sum", "Float32", "Float32"), + ("prod", "Float32", "Float32"), + ("sum", "Float64", "Float64"), + ], + ) + def test_df_empty_nullable_min_count_1(self, opname, dtype, exp_dtype): + df = DataFrame({0: [], 1: []}, dtype=dtype) + result = getattr(df, opname)(min_count=1) + + expected = Series([pd.NA, pd.NA], dtype=exp_dtype) + tm.assert_series_equal(result, expected) + + +def test_sum_timedelta64_skipna_false(using_array_manager, request): + # GH#17235 + if using_array_manager: + mark = pytest.mark.xfail( + reason="Incorrect type inference on NaT in reduction result" + ) + request.applymarker(mark) + + arr = np.arange(8).astype(np.int64).view("m8[s]").reshape(4, 2) + arr[-1, -1] = "Nat" + + df = DataFrame(arr) + assert (df.dtypes == arr.dtype).all() + + result = df.sum(skipna=False) + expected = Series([pd.Timedelta(seconds=12), pd.NaT], dtype="m8[s]") + tm.assert_series_equal(result, expected) + + result = df.sum(axis=0, skipna=False) + tm.assert_series_equal(result, expected) + + result = df.sum(axis=1, skipna=False) + expected = Series( + [ + pd.Timedelta(seconds=1), + pd.Timedelta(seconds=5), + pd.Timedelta(seconds=9), + pd.NaT, + ], + dtype="m8[s]", + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.xfail( + using_pyarrow_string_dtype(), reason="sum doesn't work with arrow strings" +) +def test_mixed_frame_with_integer_sum(): + # https://github.com/pandas-dev/pandas/issues/34520 + df = DataFrame([["a", 1]], columns=list("ab")) + df = df.astype({"b": "Int64"}) + result = df.sum() + expected = Series(["a", 1], index=["a", "b"]) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("numeric_only", [True, False, None]) +@pytest.mark.parametrize("method", ["min", "max"]) +def test_minmax_extensionarray(method, numeric_only): + # https://github.com/pandas-dev/pandas/issues/32651 + int64_info = np.iinfo("int64") + ser = Series([int64_info.max, None, int64_info.min], dtype=pd.Int64Dtype()) + df = DataFrame({"Int64": ser}) + result = getattr(df, method)(numeric_only=numeric_only) + expected = Series( + [getattr(int64_info, method)], + dtype="Int64", + index=Index(["Int64"]), + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("ts_value", [Timestamp("2000-01-01"), pd.NaT]) +def test_frame_mixed_numeric_object_with_timestamp(ts_value): + # GH 13912 + df = DataFrame({"a": [1], "b": [1.1], "c": ["foo"], "d": [ts_value]}) + with pytest.raises(TypeError, match="does not support reduction"): + df.sum() + + +def test_prod_sum_min_count_mixed_object(): + # https://github.com/pandas-dev/pandas/issues/41074 + df = DataFrame([1, "a", True]) + + result = df.prod(axis=0, min_count=1, numeric_only=False) + expected = Series(["a"], dtype=object) + tm.assert_series_equal(result, expected) + + msg = re.escape("unsupported operand type(s) for +: 'int' and 'str'") + with pytest.raises(TypeError, match=msg): + df.sum(axis=0, min_count=1, numeric_only=False) + + +@pytest.mark.parametrize("method", ["min", "max", "mean", "median", "skew", "kurt"]) +@pytest.mark.parametrize("numeric_only", [True, False]) +@pytest.mark.parametrize("dtype", ["float64", "Float64"]) +def test_reduction_axis_none_returns_scalar(method, numeric_only, dtype): + # GH#21597 As of 2.0, axis=None reduces over all axes. + + df = DataFrame(np.random.default_rng(2).standard_normal((4, 4)), dtype=dtype) + + result = getattr(df, method)(axis=None, numeric_only=numeric_only) + np_arr = df.to_numpy(dtype=np.float64) + if method in {"skew", "kurt"}: + comp_mod = pytest.importorskip("scipy.stats") + if method == "kurt": + method = "kurtosis" + expected = getattr(comp_mod, method)(np_arr, bias=False, axis=None) + tm.assert_almost_equal(result, expected) + else: + expected = getattr(np, method)(np_arr, axis=None) + assert result == expected + + +@pytest.mark.parametrize( + "kernel", + [ + "corr", + "corrwith", + "cov", + "idxmax", + "idxmin", + "kurt", + "max", + "mean", + "median", + "min", + "prod", + "quantile", + "sem", + "skew", + "std", + "sum", + "var", + ], +) +def test_fails_on_non_numeric(kernel): + # GH#46852 + df = DataFrame({"a": [1, 2, 3], "b": object}) + args = (df,) if kernel == "corrwith" else () + msg = "|".join( + [ + "not allowed for this dtype", + "argument must be a string or a number", + "not supported between instances of", + "unsupported operand type", + "argument must be a string or a real number", + ] + ) + if kernel == "median": + # slightly different message on different builds + msg1 = ( + r"Cannot convert \[\[ " + r"\]\] to numeric" + ) + msg2 = ( + r"Cannot convert \[ " + r"\] to numeric" + ) + msg = "|".join([msg1, msg2]) + with pytest.raises(TypeError, match=msg): + getattr(df, kernel)(*args) + + +@pytest.mark.parametrize( + "method", + [ + "all", + "any", + "count", + "idxmax", + "idxmin", + "kurt", + "kurtosis", + "max", + "mean", + "median", + "min", + "nunique", + "prod", + "product", + "sem", + "skew", + "std", + "sum", + "var", + ], +) +@pytest.mark.parametrize("min_count", [0, 2]) +def test_numeric_ea_axis_1(method, skipna, min_count, any_numeric_ea_dtype): + # GH 54341 + df = DataFrame( + { + "a": Series([0, 1, 2, 3], dtype=any_numeric_ea_dtype), + "b": Series([0, 1, pd.NA, 3], dtype=any_numeric_ea_dtype), + }, + ) + expected_df = DataFrame( + { + "a": [0.0, 1.0, 2.0, 3.0], + "b": [0.0, 1.0, np.nan, 3.0], + }, + ) + if method in ("count", "nunique"): + expected_dtype = "int64" + elif method in ("all", "any"): + expected_dtype = "boolean" + elif method in ( + "kurt", + "kurtosis", + "mean", + "median", + "sem", + "skew", + "std", + "var", + ) and not any_numeric_ea_dtype.startswith("Float"): + expected_dtype = "Float64" + else: + expected_dtype = any_numeric_ea_dtype + + kwargs = {} + if method not in ("count", "nunique", "quantile"): + kwargs["skipna"] = skipna + if method in ("prod", "product", "sum"): + kwargs["min_count"] = min_count + + warn = None + msg = None + if not skipna and method in ("idxmax", "idxmin"): + warn = FutureWarning + msg = f"The behavior of DataFrame.{method} with all-NA values" + with tm.assert_produces_warning(warn, match=msg): + result = getattr(df, method)(axis=1, **kwargs) + with tm.assert_produces_warning(warn, match=msg): + expected = getattr(expected_df, method)(axis=1, **kwargs) + if method not in ("idxmax", "idxmin"): + expected = expected.astype(expected_dtype) + tm.assert_series_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_repr.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_repr.py new file mode 100644 index 0000000000000000000000000000000000000000..776007fb9691d3f6faa185ed037bdd2ab1fd47fe --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_repr.py @@ -0,0 +1,521 @@ +from datetime import ( + datetime, + timedelta, +) +from io import StringIO + +import numpy as np +import pytest + +from pandas._config import using_pyarrow_string_dtype + +from pandas import ( + NA, + Categorical, + CategoricalIndex, + DataFrame, + IntervalIndex, + MultiIndex, + NaT, + PeriodIndex, + Series, + Timestamp, + date_range, + option_context, + period_range, +) +import pandas._testing as tm + + +class TestDataFrameRepr: + def test_repr_should_return_str(self): + # https://docs.python.org/3/reference/datamodel.html#object.__repr__ + # "...The return value must be a string object." + + # (str on py2.x, str (unicode) on py3) + + data = [8, 5, 3, 5] + index1 = ["\u03c3", "\u03c4", "\u03c5", "\u03c6"] + cols = ["\u03c8"] + df = DataFrame(data, columns=cols, index=index1) + assert type(df.__repr__()) is str # noqa: E721 + + ser = df[cols[0]] + assert type(ser.__repr__()) is str # noqa: E721 + + def test_repr_bytes_61_lines(self): + # GH#12857 + lets = list("ACDEFGHIJKLMNOP") + words = np.random.default_rng(2).choice(lets, (1000, 50)) + df = DataFrame(words).astype("U1") + assert (df.dtypes == object).all() + + # smoke tests; at one point this raised with 61 but not 60 + repr(df) + repr(df.iloc[:60, :]) + repr(df.iloc[:61, :]) + + def test_repr_unicode_level_names(self, frame_or_series): + index = MultiIndex.from_tuples([(0, 0), (1, 1)], names=["\u0394", "i1"]) + + obj = DataFrame(np.random.default_rng(2).standard_normal((2, 4)), index=index) + obj = tm.get_obj(obj, frame_or_series) + repr(obj) + + def test_assign_index_sequences(self): + # GH#2200 + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}).set_index( + ["a", "b"] + ) + index = list(df.index) + index[0] = ("faz", "boo") + df.index = index + repr(df) + + # this travels an improper code path + index[0] = ["faz", "boo"] + df.index = index + repr(df) + + def test_repr_with_mi_nat(self): + df = DataFrame({"X": [1, 2]}, index=[[NaT, Timestamp("20130101")], ["a", "b"]]) + result = repr(df) + expected = " X\nNaT a 1\n2013-01-01 b 2" + assert result == expected + + def test_repr_with_different_nulls(self): + # GH45263 + df = DataFrame([1, 2, 3, 4], [True, None, np.nan, NaT]) + result = repr(df) + expected = """ 0 +True 1 +None 2 +NaN 3 +NaT 4""" + assert result == expected + + def test_repr_with_different_nulls_cols(self): + # GH45263 + d = {np.nan: [1, 2], None: [3, 4], NaT: [6, 7], True: [8, 9]} + df = DataFrame(data=d) + result = repr(df) + expected = """ NaN None NaT True +0 1 3 6 8 +1 2 4 7 9""" + assert result == expected + + def test_multiindex_na_repr(self): + # only an issue with long columns + df3 = DataFrame( + { + "A" * 30: {("A", "A0006000", "nuit"): "A0006000"}, + "B" * 30: {("A", "A0006000", "nuit"): np.nan}, + "C" * 30: {("A", "A0006000", "nuit"): np.nan}, + "D" * 30: {("A", "A0006000", "nuit"): np.nan}, + "E" * 30: {("A", "A0006000", "nuit"): "A"}, + "F" * 30: {("A", "A0006000", "nuit"): np.nan}, + } + ) + + idf = df3.set_index(["A" * 30, "C" * 30]) + repr(idf) + + def test_repr_name_coincide(self): + index = MultiIndex.from_tuples( + [("a", 0, "foo"), ("b", 1, "bar")], names=["a", "b", "c"] + ) + + df = DataFrame({"value": [0, 1]}, index=index) + + lines = repr(df).split("\n") + assert lines[2].startswith("a 0 foo") + + def test_repr_to_string( + self, + multiindex_year_month_day_dataframe_random_data, + multiindex_dataframe_random_data, + ): + ymd = multiindex_year_month_day_dataframe_random_data + frame = multiindex_dataframe_random_data + + repr(frame) + repr(ymd) + repr(frame.T) + repr(ymd.T) + + buf = StringIO() + frame.to_string(buf=buf) + ymd.to_string(buf=buf) + frame.T.to_string(buf=buf) + ymd.T.to_string(buf=buf) + + def test_repr_empty(self): + # empty + repr(DataFrame()) + + # empty with index + frame = DataFrame(index=np.arange(1000)) + repr(frame) + + def test_repr_mixed(self, float_string_frame): + # mixed + repr(float_string_frame) + + @pytest.mark.slow + def test_repr_mixed_big(self): + # big mixed + biggie = DataFrame( + { + "A": np.random.default_rng(2).standard_normal(200), + "B": [str(i) for i in range(200)], + }, + index=range(200), + ) + biggie.loc[:20, "A"] = np.nan + biggie.loc[:20, "B"] = np.nan + + repr(biggie) + + @pytest.mark.xfail(using_pyarrow_string_dtype(), reason="/r in") + def test_repr(self): + # columns but no index + no_index = DataFrame(columns=[0, 1, 3]) + repr(no_index) + + df = DataFrame(["a\n\r\tb"], columns=["a\n\r\td"], index=["a\n\r\tf"]) + assert "\t" not in repr(df) + assert "\r" not in repr(df) + assert "a\n" not in repr(df) + + def test_repr_dimensions(self): + df = DataFrame([[1, 2], [3, 4]]) + with option_context("display.show_dimensions", True): + assert "2 rows x 2 columns" in repr(df) + + with option_context("display.show_dimensions", False): + assert "2 rows x 2 columns" not in repr(df) + + with option_context("display.show_dimensions", "truncate"): + assert "2 rows x 2 columns" not in repr(df) + + @pytest.mark.slow + def test_repr_big(self): + # big one + biggie = DataFrame(np.zeros((200, 4)), columns=range(4), index=range(200)) + repr(biggie) + + def test_repr_unsortable(self): + # columns are not sortable + + unsortable = DataFrame( + { + "foo": [1] * 50, + datetime.today(): [1] * 50, + "bar": ["bar"] * 50, + datetime.today() + timedelta(1): ["bar"] * 50, + }, + index=np.arange(50), + ) + repr(unsortable) + + def test_repr_float_frame_options(self, float_frame): + repr(float_frame) + + with option_context("display.precision", 3): + repr(float_frame) + + with option_context("display.max_rows", 10, "display.max_columns", 2): + repr(float_frame) + + with option_context("display.max_rows", 1000, "display.max_columns", 1000): + repr(float_frame) + + def test_repr_unicode(self): + uval = "\u03c3\u03c3\u03c3\u03c3" + + df = DataFrame({"A": [uval, uval]}) + + result = repr(df) + ex_top = " A" + assert result.split("\n")[0].rstrip() == ex_top + + df = DataFrame({"A": [uval, uval]}) + result = repr(df) + assert result.split("\n")[0].rstrip() == ex_top + + def test_unicode_string_with_unicode(self): + df = DataFrame({"A": ["\u05d0"]}) + str(df) + + def test_repr_unicode_columns(self): + df = DataFrame({"\u05d0": [1, 2, 3], "\u05d1": [4, 5, 6], "c": [7, 8, 9]}) + repr(df.columns) # should not raise UnicodeDecodeError + + def test_str_to_bytes_raises(self): + # GH 26447 + df = DataFrame({"A": ["abc"]}) + msg = "^'str' object cannot be interpreted as an integer$" + with pytest.raises(TypeError, match=msg): + bytes(df) + + def test_very_wide_repr(self): + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 20)), + columns=np.array(["a" * 10] * 20, dtype=object), + ) + repr(df) + + def test_repr_column_name_unicode_truncation_bug(self): + # #1906 + df = DataFrame( + { + "Id": [7117434], + "StringCol": ( + "Is it possible to modify drop plot code" + "so that the output graph is displayed " + "in iphone simulator, Is it possible to " + "modify drop plot code so that the " + "output graph is \xe2\x80\xa8displayed " + "in iphone simulator.Now we are adding " + "the CSV file externally. I want to Call " + "the File through the code.." + ), + } + ) + + with option_context("display.max_columns", 20): + assert "StringCol" in repr(df) + + def test_latex_repr(self): + pytest.importorskip("jinja2") + expected = r"""\begin{tabular}{llll} +\toprule + & 0 & 1 & 2 \\ +\midrule +0 & $\alpha$ & b & c \\ +1 & 1 & 2 & 3 \\ +\bottomrule +\end{tabular} +""" + with option_context( + "styler.format.escape", None, "styler.render.repr", "latex" + ): + df = DataFrame([[r"$\alpha$", "b", "c"], [1, 2, 3]]) + result = df._repr_latex_() + assert result == expected + + # GH 12182 + assert df._repr_latex_() is None + + def test_repr_with_datetimeindex(self): + df = DataFrame({"A": [1, 2, 3]}, index=date_range("2000", periods=3)) + result = repr(df) + expected = " A\n2000-01-01 1\n2000-01-02 2\n2000-01-03 3" + assert result == expected + + def test_repr_with_intervalindex(self): + # https://github.com/pandas-dev/pandas/pull/24134/files + df = DataFrame( + {"A": [1, 2, 3, 4]}, index=IntervalIndex.from_breaks([0, 1, 2, 3, 4]) + ) + result = repr(df) + expected = " A\n(0, 1] 1\n(1, 2] 2\n(2, 3] 3\n(3, 4] 4" + assert result == expected + + def test_repr_with_categorical_index(self): + df = DataFrame({"A": [1, 2, 3]}, index=CategoricalIndex(["a", "b", "c"])) + result = repr(df) + expected = " A\na 1\nb 2\nc 3" + assert result == expected + + def test_repr_categorical_dates_periods(self): + # normal DataFrame + dt = date_range("2011-01-01 09:00", freq="h", periods=5, tz="US/Eastern") + p = period_range("2011-01", freq="M", periods=5) + df = DataFrame({"dt": dt, "p": p}) + exp = """ dt p +0 2011-01-01 09:00:00-05:00 2011-01 +1 2011-01-01 10:00:00-05:00 2011-02 +2 2011-01-01 11:00:00-05:00 2011-03 +3 2011-01-01 12:00:00-05:00 2011-04 +4 2011-01-01 13:00:00-05:00 2011-05""" + + assert repr(df) == exp + + df2 = DataFrame({"dt": Categorical(dt), "p": Categorical(p)}) + assert repr(df2) == exp + + @pytest.mark.parametrize("arg", [np.datetime64, np.timedelta64]) + @pytest.mark.parametrize( + "box, expected", + [[Series, "0 NaT\ndtype: object"], [DataFrame, " 0\n0 NaT"]], + ) + def test_repr_np_nat_with_object(self, arg, box, expected): + # GH 25445 + result = repr(box([arg("NaT")], dtype=object)) + assert result == expected + + def test_frame_datetime64_pre1900_repr(self): + df = DataFrame({"year": date_range("1/1/1700", periods=50, freq="YE-DEC")}) + # it works! + repr(df) + + def test_frame_to_string_with_periodindex(self): + index = PeriodIndex(["2011-1", "2011-2", "2011-3"], freq="M") + frame = DataFrame(np.random.default_rng(2).standard_normal((3, 4)), index=index) + + # it works! + frame.to_string() + + def test_to_string_ea_na_in_multiindex(self): + # GH#47986 + df = DataFrame( + {"a": [1, 2]}, + index=MultiIndex.from_arrays([Series([NA, 1], dtype="Int64")]), + ) + + result = df.to_string() + expected = """ a + 1 +1 2""" + assert result == expected + + def test_datetime64tz_slice_non_truncate(self): + # GH 30263 + df = DataFrame({"x": date_range("2019", periods=10, tz="UTC")}) + expected = repr(df) + df = df.iloc[:, :5] + result = repr(df) + assert result == expected + + def test_to_records_no_typeerror_in_repr(self): + # GH 48526 + df = DataFrame([["a", "b"], ["c", "d"], ["e", "f"]], columns=["left", "right"]) + df["record"] = df[["left", "right"]].to_records() + expected = """ left right record +0 a b [0, a, b] +1 c d [1, c, d] +2 e f [2, e, f]""" + result = repr(df) + assert result == expected + + def test_to_records_with_na_record_value(self): + # GH 48526 + df = DataFrame( + [["a", np.nan], ["c", "d"], ["e", "f"]], columns=["left", "right"] + ) + df["record"] = df[["left", "right"]].to_records() + expected = """ left right record +0 a NaN [0, a, nan] +1 c d [1, c, d] +2 e f [2, e, f]""" + result = repr(df) + assert result == expected + + def test_to_records_with_na_record(self): + # GH 48526 + df = DataFrame( + [["a", "b"], [np.nan, np.nan], ["e", "f"]], columns=[np.nan, "right"] + ) + df["record"] = df[[np.nan, "right"]].to_records() + expected = """ NaN right record +0 a b [0, a, b] +1 NaN NaN [1, nan, nan] +2 e f [2, e, f]""" + result = repr(df) + assert result == expected + + def test_to_records_with_inf_as_na_record(self): + # GH 48526 + expected = """ NaN inf record +0 inf b [0, inf, b] +1 NaN NaN [1, nan, nan] +2 e f [2, e, f]""" + msg = "use_inf_as_na option is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + with option_context("use_inf_as_na", True): + df = DataFrame( + [[np.inf, "b"], [np.nan, np.nan], ["e", "f"]], + columns=[np.nan, np.inf], + ) + df["record"] = df[[np.nan, np.inf]].to_records() + result = repr(df) + assert result == expected + + def test_to_records_with_inf_record(self): + # GH 48526 + expected = """ NaN inf record +0 inf b [0, inf, b] +1 NaN NaN [1, nan, nan] +2 e f [2, e, f]""" + msg = "use_inf_as_na option is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + with option_context("use_inf_as_na", False): + df = DataFrame( + [[np.inf, "b"], [np.nan, np.nan], ["e", "f"]], + columns=[np.nan, np.inf], + ) + df["record"] = df[[np.nan, np.inf]].to_records() + result = repr(df) + assert result == expected + + def test_masked_ea_with_formatter(self): + # GH#39336 + df = DataFrame( + { + "a": Series([0.123456789, 1.123456789], dtype="Float64"), + "b": Series([1, 2], dtype="Int64"), + } + ) + result = df.to_string(formatters=["{:.2f}".format, "{:.2f}".format]) + expected = """ a b +0 0.12 1.00 +1 1.12 2.00""" + assert result == expected + + def test_repr_ea_columns(self, any_string_dtype): + # GH#54797 + pytest.importorskip("pyarrow") + df = DataFrame({"long_column_name": [1, 2, 3], "col2": [4, 5, 6]}) + df.columns = df.columns.astype(any_string_dtype) + expected = """ long_column_name col2 +0 1 4 +1 2 5 +2 3 6""" + assert repr(df) == expected + + +@pytest.mark.parametrize( + "data,output", + [ + ([2, complex("nan"), 1], [" 2.0+0.0j", " NaN+0.0j", " 1.0+0.0j"]), + ([2, complex("nan"), -1], [" 2.0+0.0j", " NaN+0.0j", "-1.0+0.0j"]), + ([-2, complex("nan"), -1], ["-2.0+0.0j", " NaN+0.0j", "-1.0+0.0j"]), + ([-1.23j, complex("nan"), -1], ["-0.00-1.23j", " NaN+0.00j", "-1.00+0.00j"]), + ([1.23j, complex("nan"), 1.23], [" 0.00+1.23j", " NaN+0.00j", " 1.23+0.00j"]), + ( + [-1.23j, complex(np.nan, np.nan), 1], + ["-0.00-1.23j", " NaN+ NaNj", " 1.00+0.00j"], + ), + ( + [-1.23j, complex(1.2, np.nan), 1], + ["-0.00-1.23j", " 1.20+ NaNj", " 1.00+0.00j"], + ), + ( + [-1.23j, complex(np.nan, -1.2), 1], + ["-0.00-1.23j", " NaN-1.20j", " 1.00+0.00j"], + ), + ], +) +@pytest.mark.parametrize("as_frame", [True, False]) +def test_repr_with_complex_nans(data, output, as_frame): + # GH#53762, GH#53841 + obj = Series(np.array(data)) + if as_frame: + obj = obj.to_frame(name="val") + reprs = [f"{i} {val}" for i, val in enumerate(output)] + expected = f"{'val': >{len(reprs[0])}}\n" + "\n".join(reprs) + else: + reprs = [f"{i} {val}" for i, val in enumerate(output)] + expected = "\n".join(reprs) + "\ndtype: complex128" + assert str(obj) == expected, f"\n{str(obj)}\n\n{expected}" diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_stack_unstack.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_stack_unstack.py new file mode 100644 index 0000000000000000000000000000000000000000..d8b92091260a3b0fbde9ddd99e46d31d1fdbf320 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_stack_unstack.py @@ -0,0 +1,2678 @@ +from datetime import datetime +import itertools +import re + +import numpy as np +import pytest + +from pandas._libs import lib +from pandas.errors import PerformanceWarning + +import pandas as pd +from pandas import ( + DataFrame, + Index, + MultiIndex, + Period, + Series, + Timedelta, + date_range, +) +import pandas._testing as tm +from pandas.core.reshape import reshape as reshape_lib + + +@pytest.fixture(params=[True, False]) +def future_stack(request): + return request.param + + +class TestDataFrameReshape: + def test_stack_unstack(self, float_frame, future_stack): + df = float_frame.copy() + df[:] = np.arange(np.prod(df.shape)).reshape(df.shape) + + stacked = df.stack(future_stack=future_stack) + stacked_df = DataFrame({"foo": stacked, "bar": stacked}) + + unstacked = stacked.unstack() + unstacked_df = stacked_df.unstack() + + tm.assert_frame_equal(unstacked, df) + tm.assert_frame_equal(unstacked_df["bar"], df) + + unstacked_cols = stacked.unstack(0) + unstacked_cols_df = stacked_df.unstack(0) + tm.assert_frame_equal(unstacked_cols.T, df) + tm.assert_frame_equal(unstacked_cols_df["bar"].T, df) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_stack_mixed_level(self, future_stack): + # GH 18310 + levels = [range(3), [3, "a", "b"], [1, 2]] + + # flat columns: + df = DataFrame(1, index=levels[0], columns=levels[1]) + result = df.stack(future_stack=future_stack) + expected = Series(1, index=MultiIndex.from_product(levels[:2])) + tm.assert_series_equal(result, expected) + + # MultiIndex columns: + df = DataFrame(1, index=levels[0], columns=MultiIndex.from_product(levels[1:])) + result = df.stack(1, future_stack=future_stack) + expected = DataFrame( + 1, index=MultiIndex.from_product([levels[0], levels[2]]), columns=levels[1] + ) + tm.assert_frame_equal(result, expected) + + # as above, but used labels in level are actually of homogeneous type + result = df[["a", "b"]].stack(1, future_stack=future_stack) + expected = expected[["a", "b"]] + tm.assert_frame_equal(result, expected) + + def test_unstack_not_consolidated(self, using_array_manager): + # Gh#34708 + df = DataFrame({"x": [1, 2, np.nan], "y": [3.0, 4, np.nan]}) + df2 = df[["x"]] + df2["y"] = df["y"] + if not using_array_manager: + assert len(df2._mgr.blocks) == 2 + + res = df2.unstack() + expected = df.unstack() + tm.assert_series_equal(res, expected) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_unstack_fill(self, future_stack): + # GH #9746: fill_value keyword argument for Series + # and DataFrame unstack + + # From a series + data = Series([1, 2, 4, 5], dtype=np.int16) + data.index = MultiIndex.from_tuples( + [("x", "a"), ("x", "b"), ("y", "b"), ("z", "a")] + ) + + result = data.unstack(fill_value=-1) + expected = DataFrame( + {"a": [1, -1, 5], "b": [2, 4, -1]}, index=["x", "y", "z"], dtype=np.int16 + ) + tm.assert_frame_equal(result, expected) + + # From a series with incorrect data type for fill_value + result = data.unstack(fill_value=0.5) + expected = DataFrame( + {"a": [1, 0.5, 5], "b": [2, 4, 0.5]}, index=["x", "y", "z"], dtype=float + ) + tm.assert_frame_equal(result, expected) + + # GH #13971: fill_value when unstacking multiple levels: + df = DataFrame( + {"x": ["a", "a", "b"], "y": ["j", "k", "j"], "z": [0, 1, 2], "w": [0, 1, 2]} + ).set_index(["x", "y", "z"]) + unstacked = df.unstack(["x", "y"], fill_value=0) + key = ("w", "b", "j") + expected = unstacked[key] + result = Series([0, 0, 2], index=unstacked.index, name=key) + tm.assert_series_equal(result, expected) + + stacked = unstacked.stack(["x", "y"], future_stack=future_stack) + stacked.index = stacked.index.reorder_levels(df.index.names) + # Workaround for GH #17886 (unnecessarily casts to float): + stacked = stacked.astype(np.int64) + result = stacked.loc[df.index] + tm.assert_frame_equal(result, df) + + # From a series + s = df["w"] + result = s.unstack(["x", "y"], fill_value=0) + expected = unstacked["w"] + tm.assert_frame_equal(result, expected) + + def test_unstack_fill_frame(self): + # From a dataframe + rows = [[1, 2], [3, 4], [5, 6], [7, 8]] + df = DataFrame(rows, columns=list("AB"), dtype=np.int32) + df.index = MultiIndex.from_tuples( + [("x", "a"), ("x", "b"), ("y", "b"), ("z", "a")] + ) + + result = df.unstack(fill_value=-1) + + rows = [[1, 3, 2, 4], [-1, 5, -1, 6], [7, -1, 8, -1]] + expected = DataFrame(rows, index=list("xyz"), dtype=np.int32) + expected.columns = MultiIndex.from_tuples( + [("A", "a"), ("A", "b"), ("B", "a"), ("B", "b")] + ) + tm.assert_frame_equal(result, expected) + + # From a mixed type dataframe + df["A"] = df["A"].astype(np.int16) + df["B"] = df["B"].astype(np.float64) + + result = df.unstack(fill_value=-1) + expected["A"] = expected["A"].astype(np.int16) + expected["B"] = expected["B"].astype(np.float64) + tm.assert_frame_equal(result, expected) + + # From a dataframe with incorrect data type for fill_value + result = df.unstack(fill_value=0.5) + + rows = [[1, 3, 2, 4], [0.5, 5, 0.5, 6], [7, 0.5, 8, 0.5]] + expected = DataFrame(rows, index=list("xyz"), dtype=float) + expected.columns = MultiIndex.from_tuples( + [("A", "a"), ("A", "b"), ("B", "a"), ("B", "b")] + ) + tm.assert_frame_equal(result, expected) + + def test_unstack_fill_frame_datetime(self): + # Test unstacking with date times + dv = date_range("2012-01-01", periods=4).values + data = Series(dv) + data.index = MultiIndex.from_tuples( + [("x", "a"), ("x", "b"), ("y", "b"), ("z", "a")] + ) + + result = data.unstack() + expected = DataFrame( + {"a": [dv[0], pd.NaT, dv[3]], "b": [dv[1], dv[2], pd.NaT]}, + index=["x", "y", "z"], + ) + tm.assert_frame_equal(result, expected) + + result = data.unstack(fill_value=dv[0]) + expected = DataFrame( + {"a": [dv[0], dv[0], dv[3]], "b": [dv[1], dv[2], dv[0]]}, + index=["x", "y", "z"], + ) + tm.assert_frame_equal(result, expected) + + def test_unstack_fill_frame_timedelta(self): + # Test unstacking with time deltas + td = [Timedelta(days=i) for i in range(4)] + data = Series(td) + data.index = MultiIndex.from_tuples( + [("x", "a"), ("x", "b"), ("y", "b"), ("z", "a")] + ) + + result = data.unstack() + expected = DataFrame( + {"a": [td[0], pd.NaT, td[3]], "b": [td[1], td[2], pd.NaT]}, + index=["x", "y", "z"], + ) + tm.assert_frame_equal(result, expected) + + result = data.unstack(fill_value=td[1]) + expected = DataFrame( + {"a": [td[0], td[1], td[3]], "b": [td[1], td[2], td[1]]}, + index=["x", "y", "z"], + ) + tm.assert_frame_equal(result, expected) + + def test_unstack_fill_frame_period(self): + # Test unstacking with period + periods = [ + Period("2012-01"), + Period("2012-02"), + Period("2012-03"), + Period("2012-04"), + ] + data = Series(periods) + data.index = MultiIndex.from_tuples( + [("x", "a"), ("x", "b"), ("y", "b"), ("z", "a")] + ) + + result = data.unstack() + expected = DataFrame( + {"a": [periods[0], None, periods[3]], "b": [periods[1], periods[2], None]}, + index=["x", "y", "z"], + ) + tm.assert_frame_equal(result, expected) + + result = data.unstack(fill_value=periods[1]) + expected = DataFrame( + { + "a": [periods[0], periods[1], periods[3]], + "b": [periods[1], periods[2], periods[1]], + }, + index=["x", "y", "z"], + ) + tm.assert_frame_equal(result, expected) + + def test_unstack_fill_frame_categorical(self): + # Test unstacking with categorical + data = Series(["a", "b", "c", "a"], dtype="category") + data.index = MultiIndex.from_tuples( + [("x", "a"), ("x", "b"), ("y", "b"), ("z", "a")] + ) + + # By default missing values will be NaN + result = data.unstack() + expected = DataFrame( + { + "a": pd.Categorical(list("axa"), categories=list("abc")), + "b": pd.Categorical(list("bcx"), categories=list("abc")), + }, + index=list("xyz"), + ) + tm.assert_frame_equal(result, expected) + + # Fill with non-category results in a ValueError + msg = r"Cannot setitem on a Categorical with a new category \(d\)" + with pytest.raises(TypeError, match=msg): + data.unstack(fill_value="d") + + # Fill with category value replaces missing values as expected + result = data.unstack(fill_value="c") + expected = DataFrame( + { + "a": pd.Categorical(list("aca"), categories=list("abc")), + "b": pd.Categorical(list("bcc"), categories=list("abc")), + }, + index=list("xyz"), + ) + tm.assert_frame_equal(result, expected) + + def test_unstack_tuplename_in_multiindex(self): + # GH 19966 + idx = MultiIndex.from_product( + [["a", "b", "c"], [1, 2, 3]], names=[("A", "a"), ("B", "b")] + ) + df = DataFrame({"d": [1] * 9, "e": [2] * 9}, index=idx) + result = df.unstack(("A", "a")) + + expected = DataFrame( + [[1, 1, 1, 2, 2, 2], [1, 1, 1, 2, 2, 2], [1, 1, 1, 2, 2, 2]], + columns=MultiIndex.from_tuples( + [ + ("d", "a"), + ("d", "b"), + ("d", "c"), + ("e", "a"), + ("e", "b"), + ("e", "c"), + ], + names=[None, ("A", "a")], + ), + index=Index([1, 2, 3], name=("B", "b")), + ) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "unstack_idx, expected_values, expected_index, expected_columns", + [ + ( + ("A", "a"), + [[1, 1, 2, 2], [1, 1, 2, 2], [1, 1, 2, 2], [1, 1, 2, 2]], + MultiIndex.from_tuples( + [(1, 3), (1, 4), (2, 3), (2, 4)], names=["B", "C"] + ), + MultiIndex.from_tuples( + [("d", "a"), ("d", "b"), ("e", "a"), ("e", "b")], + names=[None, ("A", "a")], + ), + ), + ( + (("A", "a"), "B"), + [[1, 1, 1, 1, 2, 2, 2, 2], [1, 1, 1, 1, 2, 2, 2, 2]], + Index([3, 4], name="C"), + MultiIndex.from_tuples( + [ + ("d", "a", 1), + ("d", "a", 2), + ("d", "b", 1), + ("d", "b", 2), + ("e", "a", 1), + ("e", "a", 2), + ("e", "b", 1), + ("e", "b", 2), + ], + names=[None, ("A", "a"), "B"], + ), + ), + ], + ) + def test_unstack_mixed_type_name_in_multiindex( + self, unstack_idx, expected_values, expected_index, expected_columns + ): + # GH 19966 + idx = MultiIndex.from_product( + [["a", "b"], [1, 2], [3, 4]], names=[("A", "a"), "B", "C"] + ) + df = DataFrame({"d": [1] * 8, "e": [2] * 8}, index=idx) + result = df.unstack(unstack_idx) + + expected = DataFrame( + expected_values, columns=expected_columns, index=expected_index + ) + tm.assert_frame_equal(result, expected) + + def test_unstack_preserve_dtypes(self): + # Checks fix for #11847 + df = DataFrame( + { + "state": ["IL", "MI", "NC"], + "index": ["a", "b", "c"], + "some_categories": Series(["a", "b", "c"]).astype("category"), + "A": np.random.default_rng(2).random(3), + "B": 1, + "C": "foo", + "D": pd.Timestamp("20010102"), + "E": Series([1.0, 50.0, 100.0]).astype("float32"), + "F": Series([3.0, 4.0, 5.0]).astype("float64"), + "G": False, + "H": Series([1, 200, 923442]).astype("int8"), + } + ) + + def unstack_and_compare(df, column_name): + unstacked1 = df.unstack([column_name]) + unstacked2 = df.unstack(column_name) + tm.assert_frame_equal(unstacked1, unstacked2) + + df1 = df.set_index(["state", "index"]) + unstack_and_compare(df1, "index") + + df1 = df.set_index(["state", "some_categories"]) + unstack_and_compare(df1, "some_categories") + + df1 = df.set_index(["F", "C"]) + unstack_and_compare(df1, "F") + + df1 = df.set_index(["G", "B", "state"]) + unstack_and_compare(df1, "B") + + df1 = df.set_index(["E", "A"]) + unstack_and_compare(df1, "E") + + df1 = df.set_index(["state", "index"]) + s = df1["A"] + unstack_and_compare(s, "index") + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_stack_ints(self, future_stack): + columns = MultiIndex.from_tuples(list(itertools.product(range(3), repeat=3))) + df = DataFrame( + np.random.default_rng(2).standard_normal((30, 27)), columns=columns + ) + + tm.assert_frame_equal( + df.stack(level=[1, 2], future_stack=future_stack), + df.stack(level=1, future_stack=future_stack).stack( + level=1, future_stack=future_stack + ), + ) + tm.assert_frame_equal( + df.stack(level=[-2, -1], future_stack=future_stack), + df.stack(level=1, future_stack=future_stack).stack( + level=1, future_stack=future_stack + ), + ) + + df_named = df.copy() + return_value = df_named.columns.set_names(range(3), inplace=True) + assert return_value is None + + tm.assert_frame_equal( + df_named.stack(level=[1, 2], future_stack=future_stack), + df_named.stack(level=1, future_stack=future_stack).stack( + level=1, future_stack=future_stack + ), + ) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_stack_mixed_levels(self, future_stack): + columns = MultiIndex.from_tuples( + [ + ("A", "cat", "long"), + ("B", "cat", "long"), + ("A", "dog", "short"), + ("B", "dog", "short"), + ], + names=["exp", "animal", "hair_length"], + ) + df = DataFrame( + np.random.default_rng(2).standard_normal((4, 4)), columns=columns + ) + + animal_hair_stacked = df.stack( + level=["animal", "hair_length"], future_stack=future_stack + ) + exp_hair_stacked = df.stack( + level=["exp", "hair_length"], future_stack=future_stack + ) + + # GH #8584: Need to check that stacking works when a number + # is passed that is both a level name and in the range of + # the level numbers + df2 = df.copy() + df2.columns.names = ["exp", "animal", 1] + tm.assert_frame_equal( + df2.stack(level=["animal", 1], future_stack=future_stack), + animal_hair_stacked, + check_names=False, + ) + tm.assert_frame_equal( + df2.stack(level=["exp", 1], future_stack=future_stack), + exp_hair_stacked, + check_names=False, + ) + + # When mixed types are passed and the ints are not level + # names, raise + msg = ( + "level should contain all level names or all level numbers, not " + "a mixture of the two" + ) + with pytest.raises(ValueError, match=msg): + df2.stack(level=["animal", 0], future_stack=future_stack) + + # GH #8584: Having 0 in the level names could raise a + # strange error about lexsort depth + df3 = df.copy() + df3.columns.names = ["exp", "animal", 0] + tm.assert_frame_equal( + df3.stack(level=["animal", 0], future_stack=future_stack), + animal_hair_stacked, + check_names=False, + ) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_stack_int_level_names(self, future_stack): + columns = MultiIndex.from_tuples( + [ + ("A", "cat", "long"), + ("B", "cat", "long"), + ("A", "dog", "short"), + ("B", "dog", "short"), + ], + names=["exp", "animal", "hair_length"], + ) + df = DataFrame( + np.random.default_rng(2).standard_normal((4, 4)), columns=columns + ) + + exp_animal_stacked = df.stack( + level=["exp", "animal"], future_stack=future_stack + ) + animal_hair_stacked = df.stack( + level=["animal", "hair_length"], future_stack=future_stack + ) + exp_hair_stacked = df.stack( + level=["exp", "hair_length"], future_stack=future_stack + ) + + df2 = df.copy() + df2.columns.names = [0, 1, 2] + tm.assert_frame_equal( + df2.stack(level=[1, 2], future_stack=future_stack), + animal_hair_stacked, + check_names=False, + ) + tm.assert_frame_equal( + df2.stack(level=[0, 1], future_stack=future_stack), + exp_animal_stacked, + check_names=False, + ) + tm.assert_frame_equal( + df2.stack(level=[0, 2], future_stack=future_stack), + exp_hair_stacked, + check_names=False, + ) + + # Out-of-order int column names + df3 = df.copy() + df3.columns.names = [2, 0, 1] + tm.assert_frame_equal( + df3.stack(level=[0, 1], future_stack=future_stack), + animal_hair_stacked, + check_names=False, + ) + tm.assert_frame_equal( + df3.stack(level=[2, 0], future_stack=future_stack), + exp_animal_stacked, + check_names=False, + ) + tm.assert_frame_equal( + df3.stack(level=[2, 1], future_stack=future_stack), + exp_hair_stacked, + check_names=False, + ) + + def test_unstack_bool(self): + df = DataFrame( + [False, False], + index=MultiIndex.from_arrays([["a", "b"], ["c", "l"]]), + columns=["col"], + ) + rs = df.unstack() + xp = DataFrame( + np.array([[False, np.nan], [np.nan, False]], dtype=object), + index=["a", "b"], + columns=MultiIndex.from_arrays([["col", "col"], ["c", "l"]]), + ) + tm.assert_frame_equal(rs, xp) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_unstack_level_binding(self, future_stack): + # GH9856 + mi = MultiIndex( + levels=[["foo", "bar"], ["one", "two"], ["a", "b"]], + codes=[[0, 0, 1, 1], [0, 1, 0, 1], [1, 0, 1, 0]], + names=["first", "second", "third"], + ) + s = Series(0, index=mi) + result = s.unstack([1, 2]).stack(0, future_stack=future_stack) + + expected_mi = MultiIndex( + levels=[["foo", "bar"], ["one", "two"]], + codes=[[0, 0, 1, 1], [0, 1, 0, 1]], + names=["first", "second"], + ) + + expected = DataFrame( + np.array( + [[0, np.nan], [np.nan, 0], [0, np.nan], [np.nan, 0]], dtype=np.float64 + ), + index=expected_mi, + columns=Index(["b", "a"], name="third"), + ) + + tm.assert_frame_equal(result, expected) + + def test_unstack_to_series(self, float_frame): + # check reversibility + data = float_frame.unstack() + + assert isinstance(data, Series) + undo = data.unstack().T + tm.assert_frame_equal(undo, float_frame) + + # check NA handling + data = DataFrame({"x": [1, 2, np.nan], "y": [3.0, 4, np.nan]}) + data.index = Index(["a", "b", "c"]) + result = data.unstack() + + midx = MultiIndex( + levels=[["x", "y"], ["a", "b", "c"]], + codes=[[0, 0, 0, 1, 1, 1], [0, 1, 2, 0, 1, 2]], + ) + expected = Series([1, 2, np.nan, 3, 4, np.nan], index=midx) + + tm.assert_series_equal(result, expected) + + # check composability of unstack + old_data = data.copy() + for _ in range(4): + data = data.unstack() + tm.assert_frame_equal(old_data, data) + + def test_unstack_dtypes(self, using_infer_string): + # GH 2929 + rows = [[1, 1, 3, 4], [1, 2, 3, 4], [2, 1, 3, 4], [2, 2, 3, 4]] + + df = DataFrame(rows, columns=list("ABCD")) + result = df.dtypes + expected = Series([np.dtype("int64")] * 4, index=list("ABCD")) + tm.assert_series_equal(result, expected) + + # single dtype + df2 = df.set_index(["A", "B"]) + df3 = df2.unstack("B") + result = df3.dtypes + expected = Series( + [np.dtype("int64")] * 4, + index=MultiIndex.from_arrays( + [["C", "C", "D", "D"], [1, 2, 1, 2]], names=(None, "B") + ), + ) + tm.assert_series_equal(result, expected) + + # mixed + df2 = df.set_index(["A", "B"]) + df2["C"] = 3.0 + df3 = df2.unstack("B") + result = df3.dtypes + expected = Series( + [np.dtype("float64")] * 2 + [np.dtype("int64")] * 2, + index=MultiIndex.from_arrays( + [["C", "C", "D", "D"], [1, 2, 1, 2]], names=(None, "B") + ), + ) + tm.assert_series_equal(result, expected) + df2["D"] = "foo" + df3 = df2.unstack("B") + result = df3.dtypes + dtype = "string" if using_infer_string else np.dtype("object") + expected = Series( + [np.dtype("float64")] * 2 + [dtype] * 2, + index=MultiIndex.from_arrays( + [["C", "C", "D", "D"], [1, 2, 1, 2]], names=(None, "B") + ), + ) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "c, d", + ( + (np.zeros(5), np.zeros(5)), + (np.arange(5, dtype="f8"), np.arange(5, 10, dtype="f8")), + ), + ) + def test_unstack_dtypes_mixed_date(self, c, d): + # GH7405 + df = DataFrame( + { + "A": ["a"] * 5, + "C": c, + "D": d, + "B": date_range("2012-01-01", periods=5), + } + ) + + right = df.iloc[:3].copy(deep=True) + + df = df.set_index(["A", "B"]) + df["D"] = df["D"].astype("int64") + + left = df.iloc[:3].unstack(0) + right = right.set_index(["A", "B"]).unstack(0) + right[("D", "a")] = right[("D", "a")].astype("int64") + + assert left.shape == (3, 2) + tm.assert_frame_equal(left, right) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_unstack_non_unique_index_names(self, future_stack): + idx = MultiIndex.from_tuples([("a", "b"), ("c", "d")], names=["c1", "c1"]) + df = DataFrame([1, 2], index=idx) + msg = "The name c1 occurs multiple times, use a level number" + with pytest.raises(ValueError, match=msg): + df.unstack("c1") + + with pytest.raises(ValueError, match=msg): + df.T.stack("c1", future_stack=future_stack) + + def test_unstack_unused_levels(self): + # GH 17845: unused codes in index make unstack() cast int to float + idx = MultiIndex.from_product([["a"], ["A", "B", "C", "D"]])[:-1] + df = DataFrame([[1, 0]] * 3, index=idx) + + result = df.unstack() + exp_col = MultiIndex.from_product([[0, 1], ["A", "B", "C"]]) + expected = DataFrame([[1, 1, 1, 0, 0, 0]], index=["a"], columns=exp_col) + tm.assert_frame_equal(result, expected) + assert (result.columns.levels[1] == idx.levels[1]).all() + + # Unused items on both levels + levels = [[0, 1, 7], [0, 1, 2, 3]] + codes = [[0, 0, 1, 1], [0, 2, 0, 2]] + idx = MultiIndex(levels, codes) + block = np.arange(4).reshape(2, 2) + df = DataFrame(np.concatenate([block, block + 4]), index=idx) + result = df.unstack() + expected = DataFrame( + np.concatenate([block * 2, block * 2 + 1], axis=1), columns=idx + ) + tm.assert_frame_equal(result, expected) + assert (result.columns.levels[1] == idx.levels[1]).all() + + @pytest.mark.parametrize( + "level, idces, col_level, idx_level", + ( + (0, [13, 16, 6, 9, 2, 5, 8, 11], [np.nan, "a", 2], [np.nan, 5, 1]), + (1, [8, 11, 1, 4, 12, 15, 13, 16], [np.nan, 5, 1], [np.nan, "a", 2]), + ), + ) + def test_unstack_unused_levels_mixed_with_nan( + self, level, idces, col_level, idx_level + ): + # With mixed dtype and NaN + levels = [["a", 2, "c"], [1, 3, 5, 7]] + codes = [[0, -1, 1, 1], [0, 2, -1, 2]] + idx = MultiIndex(levels, codes) + data = np.arange(8) + df = DataFrame(data.reshape(4, 2), index=idx) + + result = df.unstack(level=level) + exp_data = np.zeros(18) * np.nan + exp_data[idces] = data + cols = MultiIndex.from_product([[0, 1], col_level]) + expected = DataFrame(exp_data.reshape(3, 6), index=idx_level, columns=cols) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("cols", [["A", "C"], slice(None)]) + def test_unstack_unused_level(self, cols): + # GH 18562 : unused codes on the unstacked level + df = DataFrame([[2010, "a", "I"], [2011, "b", "II"]], columns=["A", "B", "C"]) + + ind = df.set_index(["A", "B", "C"], drop=False) + selection = ind.loc[(slice(None), slice(None), "I"), cols] + result = selection.unstack() + + expected = ind.iloc[[0]][cols] + expected.columns = MultiIndex.from_product( + [expected.columns, ["I"]], names=[None, "C"] + ) + expected.index = expected.index.droplevel("C") + tm.assert_frame_equal(result, expected) + + def test_unstack_long_index(self): + # PH 32624: Error when using a lot of indices to unstack. + # The error occurred only, if a lot of indices are used. + df = DataFrame( + [[1]], + columns=MultiIndex.from_tuples([[0]], names=["c1"]), + index=MultiIndex.from_tuples( + [[0, 0, 1, 0, 0, 0, 1]], + names=["i1", "i2", "i3", "i4", "i5", "i6", "i7"], + ), + ) + result = df.unstack(["i2", "i3", "i4", "i5", "i6", "i7"]) + expected = DataFrame( + [[1]], + columns=MultiIndex.from_tuples( + [[0, 0, 1, 0, 0, 0, 1]], + names=["c1", "i2", "i3", "i4", "i5", "i6", "i7"], + ), + index=Index([0], name="i1"), + ) + tm.assert_frame_equal(result, expected) + + def test_unstack_multi_level_cols(self): + # PH 24729: Unstack a df with multi level columns + df = DataFrame( + [[0.0, 0.0], [0.0, 0.0]], + columns=MultiIndex.from_tuples( + [["B", "C"], ["B", "D"]], names=["c1", "c2"] + ), + index=MultiIndex.from_tuples( + [[10, 20, 30], [10, 20, 40]], names=["i1", "i2", "i3"] + ), + ) + assert df.unstack(["i2", "i1"]).columns.names[-2:] == ["i2", "i1"] + + def test_unstack_multi_level_rows_and_cols(self): + # PH 28306: Unstack df with multi level cols and rows + df = DataFrame( + [[1, 2], [3, 4], [-1, -2], [-3, -4]], + columns=MultiIndex.from_tuples([["a", "b", "c"], ["d", "e", "f"]]), + index=MultiIndex.from_tuples( + [ + ["m1", "P3", 222], + ["m1", "A5", 111], + ["m2", "P3", 222], + ["m2", "A5", 111], + ], + names=["i1", "i2", "i3"], + ), + ) + result = df.unstack(["i3", "i2"]) + expected = df.unstack(["i3"]).unstack(["i2"]) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("idx", [("jim", "joe"), ("joe", "jim")]) + @pytest.mark.parametrize("lev", list(range(2))) + def test_unstack_nan_index1(self, idx, lev): + # GH7466 + def cast(val): + val_str = "" if val != val else val + return f"{val_str:1}" + + df = DataFrame( + { + "jim": ["a", "b", np.nan, "d"], + "joe": ["w", "x", "y", "z"], + "jolie": ["a.w", "b.x", " .y", "d.z"], + } + ) + + left = df.set_index(["jim", "joe"]).unstack()["jolie"] + right = df.set_index(["joe", "jim"]).unstack()["jolie"].T + tm.assert_frame_equal(left, right) + + mi = df.set_index(list(idx)) + udf = mi.unstack(level=lev) + assert udf.notna().values.sum() == len(df) + mk_list = lambda a: list(a) if isinstance(a, tuple) else [a] + rows, cols = udf["jolie"].notna().values.nonzero() + for i, j in zip(rows, cols): + left = sorted(udf["jolie"].iloc[i, j].split(".")) + right = mk_list(udf["jolie"].index[i]) + mk_list(udf["jolie"].columns[j]) + right = sorted(map(cast, right)) + assert left == right + + @pytest.mark.parametrize("idx", itertools.permutations(["1st", "2nd", "3rd"])) + @pytest.mark.parametrize("lev", list(range(3))) + @pytest.mark.parametrize("col", ["4th", "5th"]) + def test_unstack_nan_index_repeats(self, idx, lev, col): + def cast(val): + val_str = "" if val != val else val + return f"{val_str:1}" + + df = DataFrame( + { + "1st": ["d"] * 3 + + [np.nan] * 5 + + ["a"] * 2 + + ["c"] * 3 + + ["e"] * 2 + + ["b"] * 5, + "2nd": ["y"] * 2 + + ["w"] * 3 + + [np.nan] * 3 + + ["z"] * 4 + + [np.nan] * 3 + + ["x"] * 3 + + [np.nan] * 2, + "3rd": [ + 67, + 39, + 53, + 72, + 57, + 80, + 31, + 18, + 11, + 30, + 59, + 50, + 62, + 59, + 76, + 52, + 14, + 53, + 60, + 51, + ], + } + ) + + df["4th"], df["5th"] = ( + df.apply(lambda r: ".".join(map(cast, r)), axis=1), + df.apply(lambda r: ".".join(map(cast, r.iloc[::-1])), axis=1), + ) + + mi = df.set_index(list(idx)) + udf = mi.unstack(level=lev) + assert udf.notna().values.sum() == 2 * len(df) + mk_list = lambda a: list(a) if isinstance(a, tuple) else [a] + rows, cols = udf[col].notna().values.nonzero() + for i, j in zip(rows, cols): + left = sorted(udf[col].iloc[i, j].split(".")) + right = mk_list(udf[col].index[i]) + mk_list(udf[col].columns[j]) + right = sorted(map(cast, right)) + assert left == right + + def test_unstack_nan_index2(self): + # GH7403 + df = DataFrame({"A": list("aaaabbbb"), "B": range(8), "C": range(8)}) + # Explicit cast to avoid implicit cast when setting to np.nan + df = df.astype({"B": "float"}) + df.iloc[3, 1] = np.nan + left = df.set_index(["A", "B"]).unstack(0) + + vals = [ + [3, 0, 1, 2, np.nan, np.nan, np.nan, np.nan], + [np.nan, np.nan, np.nan, np.nan, 4, 5, 6, 7], + ] + vals = list(map(list, zip(*vals))) + idx = Index([np.nan, 0, 1, 2, 4, 5, 6, 7], name="B") + cols = MultiIndex( + levels=[["C"], ["a", "b"]], codes=[[0, 0], [0, 1]], names=[None, "A"] + ) + + right = DataFrame(vals, columns=cols, index=idx) + tm.assert_frame_equal(left, right) + + df = DataFrame({"A": list("aaaabbbb"), "B": list(range(4)) * 2, "C": range(8)}) + # Explicit cast to avoid implicit cast when setting to np.nan + df = df.astype({"B": "float"}) + df.iloc[2, 1] = np.nan + left = df.set_index(["A", "B"]).unstack(0) + + vals = [[2, np.nan], [0, 4], [1, 5], [np.nan, 6], [3, 7]] + cols = MultiIndex( + levels=[["C"], ["a", "b"]], codes=[[0, 0], [0, 1]], names=[None, "A"] + ) + idx = Index([np.nan, 0, 1, 2, 3], name="B") + right = DataFrame(vals, columns=cols, index=idx) + tm.assert_frame_equal(left, right) + + df = DataFrame({"A": list("aaaabbbb"), "B": list(range(4)) * 2, "C": range(8)}) + # Explicit cast to avoid implicit cast when setting to np.nan + df = df.astype({"B": "float"}) + df.iloc[3, 1] = np.nan + left = df.set_index(["A", "B"]).unstack(0) + + vals = [[3, np.nan], [0, 4], [1, 5], [2, 6], [np.nan, 7]] + cols = MultiIndex( + levels=[["C"], ["a", "b"]], codes=[[0, 0], [0, 1]], names=[None, "A"] + ) + idx = Index([np.nan, 0, 1, 2, 3], name="B") + right = DataFrame(vals, columns=cols, index=idx) + tm.assert_frame_equal(left, right) + + def test_unstack_nan_index3(self, using_array_manager): + # GH7401 + df = DataFrame( + { + "A": list("aaaaabbbbb"), + "B": (date_range("2012-01-01", periods=5).tolist() * 2), + "C": np.arange(10), + } + ) + + df.iloc[3, 1] = np.nan + left = df.set_index(["A", "B"]).unstack() + + vals = np.array([[3, 0, 1, 2, np.nan, 4], [np.nan, 5, 6, 7, 8, 9]]) + idx = Index(["a", "b"], name="A") + cols = MultiIndex( + levels=[["C"], date_range("2012-01-01", periods=5)], + codes=[[0, 0, 0, 0, 0, 0], [-1, 0, 1, 2, 3, 4]], + names=[None, "B"], + ) + + right = DataFrame(vals, columns=cols, index=idx) + if using_array_manager: + # INFO(ArrayManager) with ArrayManager preserve dtype where possible + cols = right.columns[[1, 2, 3, 5]] + right[cols] = right[cols].astype(df["C"].dtype) + tm.assert_frame_equal(left, right) + + def test_unstack_nan_index4(self): + # GH4862 + vals = [ + ["Hg", np.nan, np.nan, 680585148], + ["U", 0.0, np.nan, 680585148], + ["Pb", 7.07e-06, np.nan, 680585148], + ["Sn", 2.3614e-05, 0.0133, 680607017], + ["Ag", 0.0, 0.0133, 680607017], + ["Hg", -0.00015, 0.0133, 680607017], + ] + df = DataFrame( + vals, + columns=["agent", "change", "dosage", "s_id"], + index=[17263, 17264, 17265, 17266, 17267, 17268], + ) + + left = df.copy().set_index(["s_id", "dosage", "agent"]).unstack() + + vals = [ + [np.nan, np.nan, 7.07e-06, np.nan, 0.0], + [0.0, -0.00015, np.nan, 2.3614e-05, np.nan], + ] + + idx = MultiIndex( + levels=[[680585148, 680607017], [0.0133]], + codes=[[0, 1], [-1, 0]], + names=["s_id", "dosage"], + ) + + cols = MultiIndex( + levels=[["change"], ["Ag", "Hg", "Pb", "Sn", "U"]], + codes=[[0, 0, 0, 0, 0], [0, 1, 2, 3, 4]], + names=[None, "agent"], + ) + + right = DataFrame(vals, columns=cols, index=idx) + tm.assert_frame_equal(left, right) + + left = df.loc[17264:].copy().set_index(["s_id", "dosage", "agent"]) + tm.assert_frame_equal(left.unstack(), right) + + def test_unstack_nan_index5(self): + # GH9497 - multiple unstack with nulls + df = DataFrame( + { + "1st": [1, 2, 1, 2, 1, 2], + "2nd": date_range("2014-02-01", periods=6, freq="D"), + "jim": 100 + np.arange(6), + "joe": (np.random.default_rng(2).standard_normal(6) * 10).round(2), + } + ) + + df["3rd"] = df["2nd"] - pd.Timestamp("2014-02-02") + df.loc[1, "2nd"] = df.loc[3, "2nd"] = np.nan + df.loc[1, "3rd"] = df.loc[4, "3rd"] = np.nan + + left = df.set_index(["1st", "2nd", "3rd"]).unstack(["2nd", "3rd"]) + assert left.notna().values.sum() == 2 * len(df) + + for col in ["jim", "joe"]: + for _, r in df.iterrows(): + key = r["1st"], (col, r["2nd"], r["3rd"]) + assert r[col] == left.loc[key] + + def test_stack_datetime_column_multiIndex(self, future_stack): + # GH 8039 + t = datetime(2014, 1, 1) + df = DataFrame([1, 2, 3, 4], columns=MultiIndex.from_tuples([(t, "A", "B")])) + warn = None if future_stack else FutureWarning + msg = "The previous implementation of stack is deprecated" + with tm.assert_produces_warning(warn, match=msg): + result = df.stack(future_stack=future_stack) + + eidx = MultiIndex.from_product([(0, 1, 2, 3), ("B",)]) + ecols = MultiIndex.from_tuples([(t, "A")]) + expected = DataFrame([1, 2, 3, 4], index=eidx, columns=ecols) + tm.assert_frame_equal(result, expected) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + @pytest.mark.parametrize( + "multiindex_columns", + [ + [0, 1, 2, 3, 4], + [0, 1, 2, 3], + [0, 1, 2, 4], + [0, 1, 2], + [1, 2, 3], + [2, 3, 4], + [0, 1], + [0, 2], + [0, 3], + [0], + [2], + [4], + [4, 3, 2, 1, 0], + [3, 2, 1, 0], + [4, 2, 1, 0], + [2, 1, 0], + [3, 2, 1], + [4, 3, 2], + [1, 0], + [2, 0], + [3, 0], + ], + ) + @pytest.mark.parametrize("level", (-1, 0, 1, [0, 1], [1, 0])) + def test_stack_partial_multiIndex(self, multiindex_columns, level, future_stack): + # GH 8844 + dropna = False if not future_stack else lib.no_default + full_multiindex = MultiIndex.from_tuples( + [("B", "x"), ("B", "z"), ("A", "y"), ("C", "x"), ("C", "u")], + names=["Upper", "Lower"], + ) + multiindex = full_multiindex[multiindex_columns] + df = DataFrame( + np.arange(3 * len(multiindex)).reshape(3, len(multiindex)), + columns=multiindex, + ) + result = df.stack(level=level, dropna=dropna, future_stack=future_stack) + + if isinstance(level, int) and not future_stack: + # Stacking a single level should not make any all-NaN rows, + # so df.stack(level=level, dropna=False) should be the same + # as df.stack(level=level, dropna=True). + expected = df.stack(level=level, dropna=True, future_stack=future_stack) + if isinstance(expected, Series): + tm.assert_series_equal(result, expected) + else: + tm.assert_frame_equal(result, expected) + + df.columns = MultiIndex.from_tuples( + df.columns.to_numpy(), names=df.columns.names + ) + expected = df.stack(level=level, dropna=dropna, future_stack=future_stack) + if isinstance(expected, Series): + tm.assert_series_equal(result, expected) + else: + tm.assert_frame_equal(result, expected) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_stack_full_multiIndex(self, future_stack): + # GH 8844 + full_multiindex = MultiIndex.from_tuples( + [("B", "x"), ("B", "z"), ("A", "y"), ("C", "x"), ("C", "u")], + names=["Upper", "Lower"], + ) + df = DataFrame(np.arange(6).reshape(2, 3), columns=full_multiindex[[0, 1, 3]]) + dropna = False if not future_stack else lib.no_default + result = df.stack(dropna=dropna, future_stack=future_stack) + expected = DataFrame( + [[0, 2], [1, np.nan], [3, 5], [4, np.nan]], + index=MultiIndex( + levels=[[0, 1], ["u", "x", "y", "z"]], + codes=[[0, 0, 1, 1], [1, 3, 1, 3]], + names=[None, "Lower"], + ), + columns=Index(["B", "C"], name="Upper"), + ) + expected["B"] = expected["B"].astype(df.dtypes.iloc[0]) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("ordered", [False, True]) + def test_stack_preserve_categorical_dtype(self, ordered, future_stack): + # GH13854 + cidx = pd.CategoricalIndex(list("yxz"), categories=list("xyz"), ordered=ordered) + df = DataFrame([[10, 11, 12]], columns=cidx) + result = df.stack(future_stack=future_stack) + + # `MultiIndex.from_product` preserves categorical dtype - + # it's tested elsewhere. + midx = MultiIndex.from_product([df.index, cidx]) + expected = Series([10, 11, 12], index=midx) + + tm.assert_series_equal(result, expected) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + @pytest.mark.parametrize("ordered", [False, True]) + @pytest.mark.parametrize( + "labels,data", + [ + (list("xyz"), [10, 11, 12, 13, 14, 15]), + (list("zyx"), [14, 15, 12, 13, 10, 11]), + ], + ) + def test_stack_multi_preserve_categorical_dtype( + self, ordered, labels, data, future_stack + ): + # GH-36991 + cidx = pd.CategoricalIndex(labels, categories=sorted(labels), ordered=ordered) + cidx2 = pd.CategoricalIndex(["u", "v"], ordered=ordered) + midx = MultiIndex.from_product([cidx, cidx2]) + df = DataFrame([sorted(data)], columns=midx) + result = df.stack([0, 1], future_stack=future_stack) + + labels = labels if future_stack else sorted(labels) + s_cidx = pd.CategoricalIndex(labels, ordered=ordered) + expected_data = sorted(data) if future_stack else data + expected = Series( + expected_data, index=MultiIndex.from_product([[0], s_cidx, cidx2]) + ) + + tm.assert_series_equal(result, expected) + + def test_stack_preserve_categorical_dtype_values(self, future_stack): + # GH-23077 + cat = pd.Categorical(["a", "a", "b", "c"]) + df = DataFrame({"A": cat, "B": cat}) + result = df.stack(future_stack=future_stack) + index = MultiIndex.from_product([[0, 1, 2, 3], ["A", "B"]]) + expected = Series( + pd.Categorical(["a", "a", "a", "a", "b", "b", "c", "c"]), index=index + ) + tm.assert_series_equal(result, expected) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + @pytest.mark.filterwarnings("ignore:Downcasting object dtype arrays:FutureWarning") + @pytest.mark.parametrize( + "index, columns", + [ + ([0, 0, 1, 1], MultiIndex.from_product([[1, 2], ["a", "b"]])), + ([0, 0, 2, 3], MultiIndex.from_product([[1, 2], ["a", "b"]])), + ([0, 1, 2, 3], MultiIndex.from_product([[1, 2], ["a", "b"]])), + ], + ) + def test_stack_multi_columns_non_unique_index(self, index, columns, future_stack): + # GH-28301 + + df = DataFrame(index=index, columns=columns).fillna(1) + stacked = df.stack(future_stack=future_stack) + new_index = MultiIndex.from_tuples(stacked.index.to_numpy()) + expected = DataFrame( + stacked.to_numpy(), index=new_index, columns=stacked.columns + ) + tm.assert_frame_equal(stacked, expected) + stacked_codes = np.asarray(stacked.index.codes) + expected_codes = np.asarray(new_index.codes) + tm.assert_numpy_array_equal(stacked_codes, expected_codes) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + @pytest.mark.parametrize( + "vals1, vals2, dtype1, dtype2, expected_dtype", + [ + ([1, 2], [3.0, 4.0], "Int64", "Float64", "Float64"), + ([1, 2], ["foo", "bar"], "Int64", "string", "object"), + ], + ) + def test_stack_multi_columns_mixed_extension_types( + self, vals1, vals2, dtype1, dtype2, expected_dtype, future_stack + ): + # GH45740 + df = DataFrame( + { + ("A", 1): Series(vals1, dtype=dtype1), + ("A", 2): Series(vals2, dtype=dtype2), + } + ) + result = df.stack(future_stack=future_stack) + expected = ( + df.astype(object).stack(future_stack=future_stack).astype(expected_dtype) + ) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("level", [0, 1]) + def test_unstack_mixed_extension_types(self, level): + index = MultiIndex.from_tuples([("A", 0), ("A", 1), ("B", 1)], names=["a", "b"]) + df = DataFrame( + { + "A": pd.array([0, 1, None], dtype="Int64"), + "B": pd.Categorical(["a", "a", "b"]), + }, + index=index, + ) + + result = df.unstack(level=level) + expected = df.astype(object).unstack(level=level) + if level == 0: + expected[("A", "B")] = expected[("A", "B")].fillna(pd.NA) + else: + expected[("A", 0)] = expected[("A", 0)].fillna(pd.NA) + + expected_dtypes = Series( + [df.A.dtype] * 2 + [df.B.dtype] * 2, index=result.columns + ) + tm.assert_series_equal(result.dtypes, expected_dtypes) + tm.assert_frame_equal(result.astype(object), expected) + + @pytest.mark.parametrize("level", [0, "baz"]) + def test_unstack_swaplevel_sortlevel(self, level): + # GH 20994 + mi = MultiIndex.from_product([[0], ["d", "c"]], names=["bar", "baz"]) + df = DataFrame([[0, 2], [1, 3]], index=mi, columns=["B", "A"]) + df.columns.name = "foo" + + expected = DataFrame( + [[3, 1, 2, 0]], + columns=MultiIndex.from_tuples( + [("c", "A"), ("c", "B"), ("d", "A"), ("d", "B")], names=["baz", "foo"] + ), + ) + expected.index.name = "bar" + + result = df.unstack().swaplevel(axis=1).sort_index(axis=1, level=level) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("dtype", ["float64", "Float64"]) +def test_unstack_sort_false(frame_or_series, dtype): + # GH 15105 + index = MultiIndex.from_tuples( + [("two", "z", "b"), ("two", "y", "a"), ("one", "z", "b"), ("one", "y", "a")] + ) + obj = frame_or_series(np.arange(1.0, 5.0), index=index, dtype=dtype) + result = obj.unstack(level=-1, sort=False) + + if frame_or_series is DataFrame: + expected_columns = MultiIndex.from_tuples([(0, "b"), (0, "a")]) + else: + expected_columns = ["b", "a"] + expected = DataFrame( + [[1.0, np.nan], [np.nan, 2.0], [3.0, np.nan], [np.nan, 4.0]], + columns=expected_columns, + index=MultiIndex.from_tuples( + [("two", "z"), ("two", "y"), ("one", "z"), ("one", "y")] + ), + dtype=dtype, + ) + tm.assert_frame_equal(result, expected) + + result = obj.unstack(level=[1, 2], sort=False) + + if frame_or_series is DataFrame: + expected_columns = MultiIndex.from_tuples([(0, "z", "b"), (0, "y", "a")]) + else: + expected_columns = MultiIndex.from_tuples([("z", "b"), ("y", "a")]) + expected = DataFrame( + [[1.0, 2.0], [3.0, 4.0]], + index=["two", "one"], + columns=expected_columns, + dtype=dtype, + ) + tm.assert_frame_equal(result, expected) + + +def test_unstack_fill_frame_object(): + # GH12815 Test unstacking with object. + data = Series(["a", "b", "c", "a"], dtype="object") + data.index = MultiIndex.from_tuples( + [("x", "a"), ("x", "b"), ("y", "b"), ("z", "a")] + ) + + # By default missing values will be NaN + result = data.unstack() + expected = DataFrame( + {"a": ["a", np.nan, "a"], "b": ["b", "c", np.nan]}, + index=list("xyz"), + dtype=object, + ) + tm.assert_frame_equal(result, expected) + + # Fill with any value replaces missing values as expected + result = data.unstack(fill_value="d") + expected = DataFrame( + {"a": ["a", "d", "a"], "b": ["b", "c", "d"]}, index=list("xyz"), dtype=object + ) + tm.assert_frame_equal(result, expected) + + +def test_unstack_timezone_aware_values(): + # GH 18338 + df = DataFrame( + { + "timestamp": [pd.Timestamp("2017-08-27 01:00:00.709949+0000", tz="UTC")], + "a": ["a"], + "b": ["b"], + "c": ["c"], + }, + columns=["timestamp", "a", "b", "c"], + ) + result = df.set_index(["a", "b"]).unstack() + expected = DataFrame( + [[pd.Timestamp("2017-08-27 01:00:00.709949+0000", tz="UTC"), "c"]], + index=Index(["a"], name="a"), + columns=MultiIndex( + levels=[["timestamp", "c"], ["b"]], + codes=[[0, 1], [0, 0]], + names=[None, "b"], + ), + ) + tm.assert_frame_equal(result, expected) + + +def test_stack_timezone_aware_values(future_stack): + # GH 19420 + ts = date_range(freq="D", start="20180101", end="20180103", tz="America/New_York") + df = DataFrame({"A": ts}, index=["a", "b", "c"]) + result = df.stack(future_stack=future_stack) + expected = Series( + ts, + index=MultiIndex(levels=[["a", "b", "c"], ["A"]], codes=[[0, 1, 2], [0, 0, 0]]), + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.filterwarnings("ignore:The previous implementation of stack is deprecated") +@pytest.mark.parametrize("dropna", [True, False, lib.no_default]) +def test_stack_empty_frame(dropna, future_stack): + # GH 36113 + levels = [np.array([], dtype=np.int64), np.array([], dtype=np.int64)] + expected = Series(dtype=np.float64, index=MultiIndex(levels=levels, codes=[[], []])) + if future_stack and dropna is not lib.no_default: + with pytest.raises(ValueError, match="dropna must be unspecified"): + DataFrame(dtype=np.float64).stack(dropna=dropna, future_stack=future_stack) + else: + result = DataFrame(dtype=np.float64).stack( + dropna=dropna, future_stack=future_stack + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.filterwarnings("ignore:The previous implementation of stack is deprecated") +@pytest.mark.parametrize("dropna", [True, False, lib.no_default]) +@pytest.mark.parametrize("fill_value", [None, 0]) +def test_stack_unstack_empty_frame(dropna, fill_value, future_stack): + # GH 36113 + if future_stack and dropna is not lib.no_default: + with pytest.raises(ValueError, match="dropna must be unspecified"): + DataFrame(dtype=np.int64).stack( + dropna=dropna, future_stack=future_stack + ).unstack(fill_value=fill_value) + else: + result = ( + DataFrame(dtype=np.int64) + .stack(dropna=dropna, future_stack=future_stack) + .unstack(fill_value=fill_value) + ) + expected = DataFrame(dtype=np.int64) + tm.assert_frame_equal(result, expected) + + +def test_unstack_single_index_series(): + # GH 36113 + msg = r"index must be a MultiIndex to unstack.*" + with pytest.raises(ValueError, match=msg): + Series(dtype=np.int64).unstack() + + +def test_unstacking_multi_index_df(): + # see gh-30740 + df = DataFrame( + { + "name": ["Alice", "Bob"], + "score": [9.5, 8], + "employed": [False, True], + "kids": [0, 0], + "gender": ["female", "male"], + } + ) + df = df.set_index(["name", "employed", "kids", "gender"]) + df = df.unstack(["gender"], fill_value=0) + expected = df.unstack("employed", fill_value=0).unstack("kids", fill_value=0) + result = df.unstack(["employed", "kids"], fill_value=0) + expected = DataFrame( + [[9.5, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 8.0]], + index=Index(["Alice", "Bob"], name="name"), + columns=MultiIndex.from_tuples( + [ + ("score", "female", False, 0), + ("score", "female", True, 0), + ("score", "male", False, 0), + ("score", "male", True, 0), + ], + names=[None, "gender", "employed", "kids"], + ), + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.filterwarnings("ignore:The previous implementation of stack is deprecated") +def test_stack_positional_level_duplicate_column_names(future_stack): + # https://github.com/pandas-dev/pandas/issues/36353 + columns = MultiIndex.from_product([("x", "y"), ("y", "z")], names=["a", "a"]) + df = DataFrame([[1, 1, 1, 1]], columns=columns) + result = df.stack(0, future_stack=future_stack) + + new_columns = Index(["y", "z"], name="a") + new_index = MultiIndex.from_tuples([(0, "x"), (0, "y")], names=[None, "a"]) + expected = DataFrame([[1, 1], [1, 1]], index=new_index, columns=new_columns) + + tm.assert_frame_equal(result, expected) + + +def test_unstack_non_slice_like_blocks(using_array_manager): + # Case where the mgr_locs of a DataFrame's underlying blocks are not slice-like + + mi = MultiIndex.from_product([range(5), ["A", "B", "C"]]) + df = DataFrame( + { + 0: np.random.default_rng(2).standard_normal(15), + 1: np.random.default_rng(2).standard_normal(15).astype(np.int64), + 2: np.random.default_rng(2).standard_normal(15), + 3: np.random.default_rng(2).standard_normal(15), + }, + index=mi, + ) + if not using_array_manager: + assert any(not x.mgr_locs.is_slice_like for x in df._mgr.blocks) + + res = df.unstack() + + expected = pd.concat([df[n].unstack() for n in range(4)], keys=range(4), axis=1) + tm.assert_frame_equal(res, expected) + + +@pytest.mark.filterwarnings("ignore:The previous implementation of stack is deprecated") +def test_stack_sort_false(future_stack): + # GH 15105 + data = [[1, 2, 3.0, 4.0], [2, 3, 4.0, 5.0], [3, 4, np.nan, np.nan]] + df = DataFrame( + data, + columns=MultiIndex( + levels=[["B", "A"], ["x", "y"]], codes=[[0, 0, 1, 1], [0, 1, 0, 1]] + ), + ) + kwargs = {} if future_stack else {"sort": False} + result = df.stack(level=0, future_stack=future_stack, **kwargs) + if future_stack: + expected = DataFrame( + { + "x": [1.0, 3.0, 2.0, 4.0, 3.0, np.nan], + "y": [2.0, 4.0, 3.0, 5.0, 4.0, np.nan], + }, + index=MultiIndex.from_arrays( + [[0, 0, 1, 1, 2, 2], ["B", "A", "B", "A", "B", "A"]] + ), + ) + else: + expected = DataFrame( + {"x": [1.0, 3.0, 2.0, 4.0, 3.0], "y": [2.0, 4.0, 3.0, 5.0, 4.0]}, + index=MultiIndex.from_arrays([[0, 0, 1, 1, 2], ["B", "A", "B", "A", "B"]]), + ) + tm.assert_frame_equal(result, expected) + + # Codes sorted in this call + df = DataFrame( + data, + columns=MultiIndex.from_arrays([["B", "B", "A", "A"], ["x", "y", "x", "y"]]), + ) + kwargs = {} if future_stack else {"sort": False} + result = df.stack(level=0, future_stack=future_stack, **kwargs) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.filterwarnings("ignore:The previous implementation of stack is deprecated") +def test_stack_sort_false_multi_level(future_stack): + # GH 15105 + idx = MultiIndex.from_tuples([("weight", "kg"), ("height", "m")]) + df = DataFrame([[1.0, 2.0], [3.0, 4.0]], index=["cat", "dog"], columns=idx) + kwargs = {} if future_stack else {"sort": False} + result = df.stack([0, 1], future_stack=future_stack, **kwargs) + expected_index = MultiIndex.from_tuples( + [ + ("cat", "weight", "kg"), + ("cat", "height", "m"), + ("dog", "weight", "kg"), + ("dog", "height", "m"), + ] + ) + expected = Series([1.0, 2.0, 3.0, 4.0], index=expected_index) + tm.assert_series_equal(result, expected) + + +class TestStackUnstackMultiLevel: + def test_unstack(self, multiindex_year_month_day_dataframe_random_data): + # just check that it works for now + ymd = multiindex_year_month_day_dataframe_random_data + + unstacked = ymd.unstack() + unstacked.unstack() + + # test that ints work + ymd.astype(int).unstack() + + # test that int32 work + ymd.astype(np.int32).unstack() + + @pytest.mark.parametrize( + "result_rows,result_columns,index_product,expected_row", + [ + ( + [[1, 1, None, None, 30.0, None], [2, 2, None, None, 30.0, None]], + ["ix1", "ix2", "col1", "col2", "col3", "col4"], + 2, + [None, None, 30.0, None], + ), + ( + [[1, 1, None, None, 30.0], [2, 2, None, None, 30.0]], + ["ix1", "ix2", "col1", "col2", "col3"], + 2, + [None, None, 30.0], + ), + ( + [[1, 1, None, None, 30.0], [2, None, None, None, 30.0]], + ["ix1", "ix2", "col1", "col2", "col3"], + None, + [None, None, 30.0], + ), + ], + ) + def test_unstack_partial( + self, result_rows, result_columns, index_product, expected_row + ): + # check for regressions on this issue: + # https://github.com/pandas-dev/pandas/issues/19351 + # make sure DataFrame.unstack() works when its run on a subset of the DataFrame + # and the Index levels contain values that are not present in the subset + result = DataFrame(result_rows, columns=result_columns).set_index( + ["ix1", "ix2"] + ) + result = result.iloc[1:2].unstack("ix2") + expected = DataFrame( + [expected_row], + columns=MultiIndex.from_product( + [result_columns[2:], [index_product]], names=[None, "ix2"] + ), + index=Index([2], name="ix1"), + ) + tm.assert_frame_equal(result, expected) + + def test_unstack_multiple_no_empty_columns(self): + index = MultiIndex.from_tuples( + [(0, "foo", 0), (0, "bar", 0), (1, "baz", 1), (1, "qux", 1)] + ) + + s = Series(np.random.default_rng(2).standard_normal(4), index=index) + + unstacked = s.unstack([1, 2]) + expected = unstacked.dropna(axis=1, how="all") + tm.assert_frame_equal(unstacked, expected) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_stack(self, multiindex_year_month_day_dataframe_random_data, future_stack): + ymd = multiindex_year_month_day_dataframe_random_data + + # regular roundtrip + unstacked = ymd.unstack() + restacked = unstacked.stack(future_stack=future_stack) + if future_stack: + # NA values in unstacked persist to restacked in version 3 + restacked = restacked.dropna(how="all") + tm.assert_frame_equal(restacked, ymd) + + unlexsorted = ymd.sort_index(level=2) + + unstacked = unlexsorted.unstack(2) + restacked = unstacked.stack(future_stack=future_stack) + if future_stack: + # NA values in unstacked persist to restacked in version 3 + restacked = restacked.dropna(how="all") + tm.assert_frame_equal(restacked.sort_index(level=0), ymd) + + unlexsorted = unlexsorted[::-1] + unstacked = unlexsorted.unstack(1) + restacked = unstacked.stack(future_stack=future_stack).swaplevel(1, 2) + if future_stack: + # NA values in unstacked persist to restacked in version 3 + restacked = restacked.dropna(how="all") + tm.assert_frame_equal(restacked.sort_index(level=0), ymd) + + unlexsorted = unlexsorted.swaplevel(0, 1) + unstacked = unlexsorted.unstack(0).swaplevel(0, 1, axis=1) + restacked = unstacked.stack(0, future_stack=future_stack).swaplevel(1, 2) + if future_stack: + # NA values in unstacked persist to restacked in version 3 + restacked = restacked.dropna(how="all") + tm.assert_frame_equal(restacked.sort_index(level=0), ymd) + + # columns unsorted + unstacked = ymd.unstack() + restacked = unstacked.stack(future_stack=future_stack) + if future_stack: + # NA values in unstacked persist to restacked in version 3 + restacked = restacked.dropna(how="all") + tm.assert_frame_equal(restacked, ymd) + + # more than 2 levels in the columns + unstacked = ymd.unstack(1).unstack(1) + + result = unstacked.stack(1, future_stack=future_stack) + expected = ymd.unstack() + tm.assert_frame_equal(result, expected) + + result = unstacked.stack(2, future_stack=future_stack) + expected = ymd.unstack(1) + tm.assert_frame_equal(result, expected) + + result = unstacked.stack(0, future_stack=future_stack) + expected = ymd.stack(future_stack=future_stack).unstack(1).unstack(1) + tm.assert_frame_equal(result, expected) + + # not all levels present in each echelon + unstacked = ymd.unstack(2).loc[:, ::3] + stacked = unstacked.stack(future_stack=future_stack).stack( + future_stack=future_stack + ) + ymd_stacked = ymd.stack(future_stack=future_stack) + if future_stack: + # NA values in unstacked persist to restacked in version 3 + stacked = stacked.dropna(how="all") + ymd_stacked = ymd_stacked.dropna(how="all") + tm.assert_series_equal(stacked, ymd_stacked.reindex(stacked.index)) + + # stack with negative number + result = ymd.unstack(0).stack(-2, future_stack=future_stack) + expected = ymd.unstack(0).stack(0, future_stack=future_stack) + tm.assert_equal(result, expected) + + @pytest.mark.parametrize( + "idx, columns, exp_idx", + [ + [ + list("abab"), + ["1st", "2nd", "1st"], + MultiIndex( + levels=[["a", "b"], ["1st", "2nd"]], + codes=[np.tile(np.arange(2).repeat(3), 2), np.tile([0, 1, 0], 4)], + ), + ], + [ + MultiIndex.from_tuples((("a", 2), ("b", 1), ("a", 1), ("b", 2))), + ["1st", "2nd", "1st"], + MultiIndex( + levels=[["a", "b"], [1, 2], ["1st", "2nd"]], + codes=[ + np.tile(np.arange(2).repeat(3), 2), + np.repeat([1, 0, 1], [3, 6, 3]), + np.tile([0, 1, 0], 4), + ], + ), + ], + ], + ) + def test_stack_duplicate_index(self, idx, columns, exp_idx, future_stack): + # GH10417 + df = DataFrame( + np.arange(12).reshape(4, 3), + index=idx, + columns=columns, + ) + if future_stack: + msg = "Columns with duplicate values are not supported in stack" + with pytest.raises(ValueError, match=msg): + df.stack(future_stack=future_stack) + else: + result = df.stack(future_stack=future_stack) + expected = Series(np.arange(12), index=exp_idx) + tm.assert_series_equal(result, expected) + assert result.index.is_unique is False + li, ri = result.index, expected.index + tm.assert_index_equal(li, ri) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_unstack_odd_failure(self, future_stack): + mi = MultiIndex.from_arrays( + [ + ["Fri"] * 4 + ["Sat"] * 2 + ["Sun"] * 2 + ["Thu"] * 3, + ["Dinner"] * 2 + ["Lunch"] * 2 + ["Dinner"] * 5 + ["Lunch"] * 2, + ["No", "Yes"] * 4 + ["No", "No", "Yes"], + ], + names=["day", "time", "smoker"], + ) + df = DataFrame( + { + "sum": np.arange(11, dtype="float64"), + "len": np.arange(11, dtype="float64"), + }, + index=mi, + ) + # it works, #2100 + result = df.unstack(2) + + recons = result.stack(future_stack=future_stack) + if future_stack: + # NA values in unstacked persist to restacked in version 3 + recons = recons.dropna(how="all") + tm.assert_frame_equal(recons, df) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_stack_mixed_dtype(self, multiindex_dataframe_random_data, future_stack): + frame = multiindex_dataframe_random_data + + df = frame.T + df["foo", "four"] = "foo" + df = df.sort_index(level=1, axis=1) + + stacked = df.stack(future_stack=future_stack) + result = df["foo"].stack(future_stack=future_stack).sort_index() + tm.assert_series_equal(stacked["foo"], result, check_names=False) + assert result.name is None + assert stacked["bar"].dtype == np.float64 + + def test_unstack_bug(self, future_stack): + df = DataFrame( + { + "state": ["naive", "naive", "naive", "active", "active", "active"], + "exp": ["a", "b", "b", "b", "a", "a"], + "barcode": [1, 2, 3, 4, 1, 3], + "v": ["hi", "hi", "bye", "bye", "bye", "peace"], + "extra": np.arange(6.0), + } + ) + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = df.groupby(["state", "exp", "barcode", "v"]).apply(len) + + unstacked = result.unstack() + restacked = unstacked.stack(future_stack=future_stack) + tm.assert_series_equal(restacked, result.reindex(restacked.index).astype(float)) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_stack_unstack_preserve_names( + self, multiindex_dataframe_random_data, future_stack + ): + frame = multiindex_dataframe_random_data + + unstacked = frame.unstack() + assert unstacked.index.name == "first" + assert unstacked.columns.names == ["exp", "second"] + + restacked = unstacked.stack(future_stack=future_stack) + assert restacked.index.names == frame.index.names + + @pytest.mark.parametrize("method", ["stack", "unstack"]) + def test_stack_unstack_wrong_level_name( + self, method, multiindex_dataframe_random_data, future_stack + ): + # GH 18303 - wrong level name should raise + frame = multiindex_dataframe_random_data + + # A DataFrame with flat axes: + df = frame.loc["foo"] + + kwargs = {"future_stack": future_stack} if method == "stack" else {} + with pytest.raises(KeyError, match="does not match index name"): + getattr(df, method)("mistake", **kwargs) + + if method == "unstack": + # Same on a Series: + s = df.iloc[:, 0] + with pytest.raises(KeyError, match="does not match index name"): + getattr(s, method)("mistake", **kwargs) + + def test_unstack_level_name(self, multiindex_dataframe_random_data): + frame = multiindex_dataframe_random_data + + result = frame.unstack("second") + expected = frame.unstack(level=1) + tm.assert_frame_equal(result, expected) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_stack_level_name(self, multiindex_dataframe_random_data, future_stack): + frame = multiindex_dataframe_random_data + + unstacked = frame.unstack("second") + result = unstacked.stack("exp", future_stack=future_stack) + expected = frame.unstack().stack(0, future_stack=future_stack) + tm.assert_frame_equal(result, expected) + + result = frame.stack("exp", future_stack=future_stack) + expected = frame.stack(future_stack=future_stack) + tm.assert_series_equal(result, expected) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_stack_unstack_multiple( + self, multiindex_year_month_day_dataframe_random_data, future_stack + ): + ymd = multiindex_year_month_day_dataframe_random_data + + unstacked = ymd.unstack(["year", "month"]) + expected = ymd.unstack("year").unstack("month") + tm.assert_frame_equal(unstacked, expected) + assert unstacked.columns.names == expected.columns.names + + # series + s = ymd["A"] + s_unstacked = s.unstack(["year", "month"]) + tm.assert_frame_equal(s_unstacked, expected["A"]) + + restacked = unstacked.stack(["year", "month"], future_stack=future_stack) + if future_stack: + # NA values in unstacked persist to restacked in version 3 + restacked = restacked.dropna(how="all") + restacked = restacked.swaplevel(0, 1).swaplevel(1, 2) + restacked = restacked.sort_index(level=0) + + tm.assert_frame_equal(restacked, ymd) + assert restacked.index.names == ymd.index.names + + # GH #451 + unstacked = ymd.unstack([1, 2]) + expected = ymd.unstack(1).unstack(1).dropna(axis=1, how="all") + tm.assert_frame_equal(unstacked, expected) + + unstacked = ymd.unstack([2, 1]) + expected = ymd.unstack(2).unstack(1).dropna(axis=1, how="all") + tm.assert_frame_equal(unstacked, expected.loc[:, unstacked.columns]) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_stack_names_and_numbers( + self, multiindex_year_month_day_dataframe_random_data, future_stack + ): + ymd = multiindex_year_month_day_dataframe_random_data + + unstacked = ymd.unstack(["year", "month"]) + + # Can't use mixture of names and numbers to stack + with pytest.raises(ValueError, match="level should contain"): + unstacked.stack([0, "month"], future_stack=future_stack) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_stack_multiple_out_of_bounds( + self, multiindex_year_month_day_dataframe_random_data, future_stack + ): + # nlevels == 3 + ymd = multiindex_year_month_day_dataframe_random_data + + unstacked = ymd.unstack(["year", "month"]) + + with pytest.raises(IndexError, match="Too many levels"): + unstacked.stack([2, 3], future_stack=future_stack) + with pytest.raises(IndexError, match="not a valid level number"): + unstacked.stack([-4, -3], future_stack=future_stack) + + def test_unstack_period_series(self): + # GH4342 + idx1 = pd.PeriodIndex( + ["2013-01", "2013-01", "2013-02", "2013-02", "2013-03", "2013-03"], + freq="M", + name="period", + ) + idx2 = Index(["A", "B"] * 3, name="str") + value = [1, 2, 3, 4, 5, 6] + + idx = MultiIndex.from_arrays([idx1, idx2]) + s = Series(value, index=idx) + + result1 = s.unstack() + result2 = s.unstack(level=1) + result3 = s.unstack(level=0) + + e_idx = pd.PeriodIndex( + ["2013-01", "2013-02", "2013-03"], freq="M", name="period" + ) + expected = DataFrame( + {"A": [1, 3, 5], "B": [2, 4, 6]}, index=e_idx, columns=["A", "B"] + ) + expected.columns.name = "str" + + tm.assert_frame_equal(result1, expected) + tm.assert_frame_equal(result2, expected) + tm.assert_frame_equal(result3, expected.T) + + idx1 = pd.PeriodIndex( + ["2013-01", "2013-01", "2013-02", "2013-02", "2013-03", "2013-03"], + freq="M", + name="period1", + ) + + idx2 = pd.PeriodIndex( + ["2013-12", "2013-11", "2013-10", "2013-09", "2013-08", "2013-07"], + freq="M", + name="period2", + ) + idx = MultiIndex.from_arrays([idx1, idx2]) + s = Series(value, index=idx) + + result1 = s.unstack() + result2 = s.unstack(level=1) + result3 = s.unstack(level=0) + + e_idx = pd.PeriodIndex( + ["2013-01", "2013-02", "2013-03"], freq="M", name="period1" + ) + e_cols = pd.PeriodIndex( + ["2013-07", "2013-08", "2013-09", "2013-10", "2013-11", "2013-12"], + freq="M", + name="period2", + ) + expected = DataFrame( + [ + [np.nan, np.nan, np.nan, np.nan, 2, 1], + [np.nan, np.nan, 4, 3, np.nan, np.nan], + [6, 5, np.nan, np.nan, np.nan, np.nan], + ], + index=e_idx, + columns=e_cols, + ) + + tm.assert_frame_equal(result1, expected) + tm.assert_frame_equal(result2, expected) + tm.assert_frame_equal(result3, expected.T) + + def test_unstack_period_frame(self): + # GH4342 + idx1 = pd.PeriodIndex( + ["2014-01", "2014-02", "2014-02", "2014-02", "2014-01", "2014-01"], + freq="M", + name="period1", + ) + idx2 = pd.PeriodIndex( + ["2013-12", "2013-12", "2014-02", "2013-10", "2013-10", "2014-02"], + freq="M", + name="period2", + ) + value = {"A": [1, 2, 3, 4, 5, 6], "B": [6, 5, 4, 3, 2, 1]} + idx = MultiIndex.from_arrays([idx1, idx2]) + df = DataFrame(value, index=idx) + + result1 = df.unstack() + result2 = df.unstack(level=1) + result3 = df.unstack(level=0) + + e_1 = pd.PeriodIndex(["2014-01", "2014-02"], freq="M", name="period1") + e_2 = pd.PeriodIndex( + ["2013-10", "2013-12", "2014-02", "2013-10", "2013-12", "2014-02"], + freq="M", + name="period2", + ) + e_cols = MultiIndex.from_arrays(["A A A B B B".split(), e_2]) + expected = DataFrame( + [[5, 1, 6, 2, 6, 1], [4, 2, 3, 3, 5, 4]], index=e_1, columns=e_cols + ) + + tm.assert_frame_equal(result1, expected) + tm.assert_frame_equal(result2, expected) + + e_1 = pd.PeriodIndex( + ["2014-01", "2014-02", "2014-01", "2014-02"], freq="M", name="period1" + ) + e_2 = pd.PeriodIndex( + ["2013-10", "2013-12", "2014-02"], freq="M", name="period2" + ) + e_cols = MultiIndex.from_arrays(["A A B B".split(), e_1]) + expected = DataFrame( + [[5, 4, 2, 3], [1, 2, 6, 5], [6, 3, 1, 4]], index=e_2, columns=e_cols + ) + + tm.assert_frame_equal(result3, expected) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_stack_multiple_bug(self, future_stack): + # bug when some uniques are not present in the data GH#3170 + id_col = ([1] * 3) + ([2] * 3) + name = (["a"] * 3) + (["b"] * 3) + date = pd.to_datetime(["2013-01-03", "2013-01-04", "2013-01-05"] * 2) + var1 = np.random.default_rng(2).integers(0, 100, 6) + df = DataFrame({"ID": id_col, "NAME": name, "DATE": date, "VAR1": var1}) + + multi = df.set_index(["DATE", "ID"]) + multi.columns.name = "Params" + unst = multi.unstack("ID") + msg = re.escape("agg function failed [how->mean,dtype->") + with pytest.raises(TypeError, match=msg): + unst.resample("W-THU").mean() + down = unst.resample("W-THU").mean(numeric_only=True) + rs = down.stack("ID", future_stack=future_stack) + xp = ( + unst.loc[:, ["VAR1"]] + .resample("W-THU") + .mean() + .stack("ID", future_stack=future_stack) + ) + xp.columns.name = "Params" + tm.assert_frame_equal(rs, xp) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_stack_dropna(self, future_stack): + # GH#3997 + df = DataFrame({"A": ["a1", "a2"], "B": ["b1", "b2"], "C": [1, 1]}) + df = df.set_index(["A", "B"]) + + dropna = False if not future_stack else lib.no_default + stacked = df.unstack().stack(dropna=dropna, future_stack=future_stack) + assert len(stacked) > len(stacked.dropna()) + + if future_stack: + with pytest.raises(ValueError, match="dropna must be unspecified"): + df.unstack().stack(dropna=True, future_stack=future_stack) + else: + stacked = df.unstack().stack(dropna=True, future_stack=future_stack) + tm.assert_frame_equal(stacked, stacked.dropna()) + + def test_unstack_multiple_hierarchical(self, future_stack): + df = DataFrame( + index=[ + [0, 0, 0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 0, 0, 1, 1], + [0, 1, 0, 1, 0, 1, 0, 1], + ], + columns=[[0, 0, 1, 1], [0, 1, 0, 1]], + ) + + df.index.names = ["a", "b", "c"] + df.columns.names = ["d", "e"] + + # it works! + df.unstack(["b", "c"]) + + def test_unstack_sparse_keyspace(self): + # memory problems with naive impl GH#2278 + # Generate Long File & Test Pivot + NUM_ROWS = 1000 + + df = DataFrame( + { + "A": np.random.default_rng(2).integers(100, size=NUM_ROWS), + "B": np.random.default_rng(3).integers(300, size=NUM_ROWS), + "C": np.random.default_rng(4).integers(-7, 7, size=NUM_ROWS), + "D": np.random.default_rng(5).integers(-19, 19, size=NUM_ROWS), + "E": np.random.default_rng(6).integers(3000, size=NUM_ROWS), + "F": np.random.default_rng(7).standard_normal(NUM_ROWS), + } + ) + + idf = df.set_index(["A", "B", "C", "D", "E"]) + + # it works! is sufficient + idf.unstack("E") + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_unstack_unobserved_keys(self, future_stack): + # related to GH#2278 refactoring + levels = [[0, 1], [0, 1, 2, 3]] + codes = [[0, 0, 1, 1], [0, 2, 0, 2]] + + index = MultiIndex(levels, codes) + + df = DataFrame(np.random.default_rng(2).standard_normal((4, 2)), index=index) + + result = df.unstack() + assert len(result.columns) == 4 + + recons = result.stack(future_stack=future_stack) + tm.assert_frame_equal(recons, df) + + @pytest.mark.slow + def test_unstack_number_of_levels_larger_than_int32(self, monkeypatch): + # GH#20601 + # GH 26314: Change ValueError to PerformanceWarning + + class MockUnstacker(reshape_lib._Unstacker): + def __init__(self, *args, **kwargs) -> None: + # __init__ will raise the warning + super().__init__(*args, **kwargs) + raise Exception("Don't compute final result.") + + with monkeypatch.context() as m: + m.setattr(reshape_lib, "_Unstacker", MockUnstacker) + df = DataFrame( + np.zeros((2**16, 2)), + index=[np.arange(2**16), np.arange(2**16)], + ) + msg = "The following operation may generate" + with tm.assert_produces_warning(PerformanceWarning, match=msg): + with pytest.raises(Exception, match="Don't compute final result."): + df.unstack() + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + @pytest.mark.parametrize( + "levels", + itertools.chain.from_iterable( + itertools.product(itertools.permutations([0, 1, 2], width), repeat=2) + for width in [2, 3] + ), + ) + @pytest.mark.parametrize("stack_lev", range(2)) + @pytest.mark.parametrize("sort", [True, False]) + def test_stack_order_with_unsorted_levels( + self, levels, stack_lev, sort, future_stack + ): + # GH#16323 + # deep check for 1-row case + columns = MultiIndex(levels=levels, codes=[[0, 0, 1, 1], [0, 1, 0, 1]]) + df = DataFrame(columns=columns, data=[range(4)]) + kwargs = {} if future_stack else {"sort": sort} + df_stacked = df.stack(stack_lev, future_stack=future_stack, **kwargs) + for row in df.index: + for col in df.columns: + expected = df.loc[row, col] + result_row = row, col[stack_lev] + result_col = col[1 - stack_lev] + result = df_stacked.loc[result_row, result_col] + assert result == expected + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_stack_order_with_unsorted_levels_multi_row(self, future_stack): + # GH#16323 + + # check multi-row case + mi = MultiIndex( + levels=[["A", "C", "B"], ["B", "A", "C"]], + codes=[np.repeat(range(3), 3), np.tile(range(3), 3)], + ) + df = DataFrame( + columns=mi, index=range(5), data=np.arange(5 * len(mi)).reshape(5, -1) + ) + assert all( + df.loc[row, col] + == df.stack(0, future_stack=future_stack).loc[(row, col[0]), col[1]] + for row in df.index + for col in df.columns + ) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_stack_order_with_unsorted_levels_multi_row_2(self, future_stack): + # GH#53636 + levels = ((0, 1), (1, 0)) + stack_lev = 1 + columns = MultiIndex(levels=levels, codes=[[0, 0, 1, 1], [0, 1, 0, 1]]) + df = DataFrame(columns=columns, data=[range(4)], index=[1, 0, 2, 3]) + kwargs = {} if future_stack else {"sort": True} + result = df.stack(stack_lev, future_stack=future_stack, **kwargs) + expected_index = MultiIndex( + levels=[[0, 1, 2, 3], [0, 1]], + codes=[[1, 1, 0, 0, 2, 2, 3, 3], [1, 0, 1, 0, 1, 0, 1, 0]], + ) + expected = DataFrame( + { + 0: [0, 1, 0, 1, 0, 1, 0, 1], + 1: [2, 3, 2, 3, 2, 3, 2, 3], + }, + index=expected_index, + ) + tm.assert_frame_equal(result, expected) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_stack_unstack_unordered_multiindex(self, future_stack): + # GH# 18265 + values = np.arange(5) + data = np.vstack( + [ + [f"b{x}" for x in values], # b0, b1, .. + [f"a{x}" for x in values], # a0, a1, .. + ] + ) + df = DataFrame(data.T, columns=["b", "a"]) + df.columns.name = "first" + second_level_dict = {"x": df} + multi_level_df = pd.concat(second_level_dict, axis=1) + multi_level_df.columns.names = ["second", "first"] + df = multi_level_df.reindex(sorted(multi_level_df.columns), axis=1) + result = df.stack(["first", "second"], future_stack=future_stack).unstack( + ["first", "second"] + ) + expected = DataFrame( + [["a0", "b0"], ["a1", "b1"], ["a2", "b2"], ["a3", "b3"], ["a4", "b4"]], + index=[0, 1, 2, 3, 4], + columns=MultiIndex.from_tuples( + [("a", "x"), ("b", "x")], names=["first", "second"] + ), + ) + tm.assert_frame_equal(result, expected) + + def test_unstack_preserve_types( + self, multiindex_year_month_day_dataframe_random_data, using_infer_string + ): + # GH#403 + ymd = multiindex_year_month_day_dataframe_random_data + ymd["E"] = "foo" + ymd["F"] = 2 + + unstacked = ymd.unstack("month") + assert unstacked["A", 1].dtype == np.float64 + assert ( + unstacked["E", 1].dtype == np.object_ + if not using_infer_string + else "string" + ) + assert unstacked["F", 1].dtype == np.float64 + + def test_unstack_group_index_overflow(self, future_stack): + codes = np.tile(np.arange(500), 2) + level = np.arange(500) + + index = MultiIndex( + levels=[level] * 8 + [[0, 1]], + codes=[codes] * 8 + [np.arange(2).repeat(500)], + ) + + s = Series(np.arange(1000), index=index) + result = s.unstack() + assert result.shape == (500, 2) + + # test roundtrip + stacked = result.stack(future_stack=future_stack) + tm.assert_series_equal(s, stacked.reindex(s.index)) + + # put it at beginning + index = MultiIndex( + levels=[[0, 1]] + [level] * 8, + codes=[np.arange(2).repeat(500)] + [codes] * 8, + ) + + s = Series(np.arange(1000), index=index) + result = s.unstack(0) + assert result.shape == (500, 2) + + # put it in middle + index = MultiIndex( + levels=[level] * 4 + [[0, 1]] + [level] * 4, + codes=([codes] * 4 + [np.arange(2).repeat(500)] + [codes] * 4), + ) + + s = Series(np.arange(1000), index=index) + result = s.unstack(4) + assert result.shape == (500, 2) + + def test_unstack_with_missing_int_cast_to_float(self, using_array_manager): + # https://github.com/pandas-dev/pandas/issues/37115 + df = DataFrame( + { + "a": ["A", "A", "B"], + "b": ["ca", "cb", "cb"], + "v": [10] * 3, + } + ).set_index(["a", "b"]) + + # add another int column to get 2 blocks + df["is_"] = 1 + if not using_array_manager: + assert len(df._mgr.blocks) == 2 + + result = df.unstack("b") + result[("is_", "ca")] = result[("is_", "ca")].fillna(0) + + expected = DataFrame( + [[10.0, 10.0, 1.0, 1.0], [np.nan, 10.0, 0.0, 1.0]], + index=Index(["A", "B"], name="a"), + columns=MultiIndex.from_tuples( + [("v", "ca"), ("v", "cb"), ("is_", "ca"), ("is_", "cb")], + names=[None, "b"], + ), + ) + if using_array_manager: + # INFO(ArrayManager) with ArrayManager preserve dtype where possible + expected[("v", "cb")] = expected[("v", "cb")].astype("int64") + expected[("is_", "cb")] = expected[("is_", "cb")].astype("int64") + tm.assert_frame_equal(result, expected) + + def test_unstack_with_level_has_nan(self): + # GH 37510 + df1 = DataFrame( + { + "L1": [1, 2, 3, 4], + "L2": [3, 4, 1, 2], + "L3": [1, 1, 1, 1], + "x": [1, 2, 3, 4], + } + ) + df1 = df1.set_index(["L1", "L2", "L3"]) + new_levels = ["n1", "n2", "n3", None] + df1.index = df1.index.set_levels(levels=new_levels, level="L1") + df1.index = df1.index.set_levels(levels=new_levels, level="L2") + + result = df1.unstack("L3")[("x", 1)].sort_index().index + expected = MultiIndex( + levels=[["n1", "n2", "n3", None], ["n1", "n2", "n3", None]], + codes=[[0, 1, 2, 3], [2, 3, 0, 1]], + names=["L1", "L2"], + ) + + tm.assert_index_equal(result, expected) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_stack_nan_in_multiindex_columns(self, future_stack): + # GH#39481 + df = DataFrame( + np.zeros([1, 5]), + columns=MultiIndex.from_tuples( + [ + (0, None, None), + (0, 2, 0), + (0, 2, 1), + (0, 3, 0), + (0, 3, 1), + ], + ), + ) + result = df.stack(2, future_stack=future_stack) + if future_stack: + index = MultiIndex(levels=[[0], [0.0, 1.0]], codes=[[0, 0, 0], [-1, 0, 1]]) + columns = MultiIndex(levels=[[0], [2, 3]], codes=[[0, 0, 0], [-1, 0, 1]]) + else: + index = Index([(0, None), (0, 0), (0, 1)]) + columns = Index([(0, None), (0, 2), (0, 3)]) + expected = DataFrame( + [[0.0, np.nan, np.nan], [np.nan, 0.0, 0.0], [np.nan, 0.0, 0.0]], + index=index, + columns=columns, + ) + tm.assert_frame_equal(result, expected) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_multi_level_stack_categorical(self, future_stack): + # GH 15239 + midx = MultiIndex.from_arrays( + [ + ["A"] * 2 + ["B"] * 2, + pd.Categorical(list("abab")), + pd.Categorical(list("ccdd")), + ] + ) + df = DataFrame(np.arange(8).reshape(2, 4), columns=midx) + result = df.stack([1, 2], future_stack=future_stack) + if future_stack: + expected = DataFrame( + [ + [0, np.nan], + [1, np.nan], + [np.nan, 2], + [np.nan, 3], + [4, np.nan], + [5, np.nan], + [np.nan, 6], + [np.nan, 7], + ], + columns=["A", "B"], + index=MultiIndex.from_arrays( + [ + [0] * 4 + [1] * 4, + pd.Categorical(list("abababab")), + pd.Categorical(list("ccddccdd")), + ] + ), + ) + else: + expected = DataFrame( + [ + [0, np.nan], + [np.nan, 2], + [1, np.nan], + [np.nan, 3], + [4, np.nan], + [np.nan, 6], + [5, np.nan], + [np.nan, 7], + ], + columns=["A", "B"], + index=MultiIndex.from_arrays( + [ + [0] * 4 + [1] * 4, + pd.Categorical(list("aabbaabb")), + pd.Categorical(list("cdcdcdcd")), + ] + ), + ) + tm.assert_frame_equal(result, expected) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_stack_nan_level(self, future_stack): + # GH 9406 + df_nan = DataFrame( + np.arange(4).reshape(2, 2), + columns=MultiIndex.from_tuples( + [("A", np.nan), ("B", "b")], names=["Upper", "Lower"] + ), + index=Index([0, 1], name="Num"), + dtype=np.float64, + ) + result = df_nan.stack(future_stack=future_stack) + if future_stack: + index = MultiIndex( + levels=[[0, 1], [np.nan, "b"]], + codes=[[0, 0, 1, 1], [0, 1, 0, 1]], + names=["Num", "Lower"], + ) + else: + index = MultiIndex.from_tuples( + [(0, np.nan), (0, "b"), (1, np.nan), (1, "b")], names=["Num", "Lower"] + ) + expected = DataFrame( + [[0.0, np.nan], [np.nan, 1], [2.0, np.nan], [np.nan, 3.0]], + columns=Index(["A", "B"], name="Upper"), + index=index, + ) + tm.assert_frame_equal(result, expected) + + def test_unstack_categorical_columns(self): + # GH 14018 + idx = MultiIndex.from_product([["A"], [0, 1]]) + df = DataFrame({"cat": pd.Categorical(["a", "b"])}, index=idx) + result = df.unstack() + expected = DataFrame( + { + 0: pd.Categorical(["a"], categories=["a", "b"]), + 1: pd.Categorical(["b"], categories=["a", "b"]), + }, + index=["A"], + ) + expected.columns = MultiIndex.from_tuples([("cat", 0), ("cat", 1)]) + tm.assert_frame_equal(result, expected) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_stack_unsorted(self, future_stack): + # GH 16925 + PAE = ["ITA", "FRA"] + VAR = ["A1", "A2"] + TYP = ["CRT", "DBT", "NET"] + MI = MultiIndex.from_product([PAE, VAR, TYP], names=["PAE", "VAR", "TYP"]) + + V = list(range(len(MI))) + DF = DataFrame(data=V, index=MI, columns=["VALUE"]) + + DF = DF.unstack(["VAR", "TYP"]) + DF.columns = DF.columns.droplevel(0) + DF.loc[:, ("A0", "NET")] = 9999 + + result = DF.stack(["VAR", "TYP"], future_stack=future_stack).sort_index() + expected = ( + DF.sort_index(axis=1) + .stack(["VAR", "TYP"], future_stack=future_stack) + .sort_index() + ) + tm.assert_series_equal(result, expected) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_stack_nullable_dtype(self, future_stack): + # GH#43561 + columns = MultiIndex.from_product( + [["54511", "54515"], ["r", "t_mean"]], names=["station", "element"] + ) + index = Index([1, 2, 3], name="time") + + arr = np.array([[50, 226, 10, 215], [10, 215, 9, 220], [305, 232, 111, 220]]) + df = DataFrame(arr, columns=columns, index=index, dtype=pd.Int64Dtype()) + + result = df.stack("station", future_stack=future_stack) + + expected = ( + df.astype(np.int64) + .stack("station", future_stack=future_stack) + .astype(pd.Int64Dtype()) + ) + tm.assert_frame_equal(result, expected) + + # non-homogeneous case + df[df.columns[0]] = df[df.columns[0]].astype(pd.Float64Dtype()) + result = df.stack("station", future_stack=future_stack) + + expected = DataFrame( + { + "r": pd.array( + [50.0, 10.0, 10.0, 9.0, 305.0, 111.0], dtype=pd.Float64Dtype() + ), + "t_mean": pd.array( + [226, 215, 215, 220, 232, 220], dtype=pd.Int64Dtype() + ), + }, + index=MultiIndex.from_product([index, columns.levels[0]]), + ) + expected.columns.name = "element" + tm.assert_frame_equal(result, expected) + + def test_unstack_mixed_level_names(self): + # GH#48763 + arrays = [["a", "a"], [1, 2], ["red", "blue"]] + idx = MultiIndex.from_arrays(arrays, names=("x", 0, "y")) + df = DataFrame({"m": [1, 2]}, index=idx) + result = df.unstack("x") + expected = DataFrame( + [[1], [2]], + columns=MultiIndex.from_tuples([("m", "a")], names=[None, "x"]), + index=MultiIndex.from_tuples([(1, "red"), (2, "blue")], names=[0, "y"]), + ) + tm.assert_frame_equal(result, expected) + + +def test_stack_tuple_columns(future_stack): + # GH#54948 - test stack when the input has a non-MultiIndex with tuples + df = DataFrame( + [[1, 2, 3], [4, 5, 6], [7, 8, 9]], columns=[("a", 1), ("a", 2), ("b", 1)] + ) + result = df.stack(future_stack=future_stack) + expected = Series( + [1, 2, 3, 4, 5, 6, 7, 8, 9], + index=MultiIndex( + levels=[[0, 1, 2], [("a", 1), ("a", 2), ("b", 1)]], + codes=[[0, 0, 0, 1, 1, 1, 2, 2, 2], [0, 1, 2, 0, 1, 2, 0, 1, 2]], + ), + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "dtype, na_value", + [ + ("float64", np.nan), + ("Float64", np.nan), + ("Float64", pd.NA), + ("Int64", pd.NA), + ], +) +@pytest.mark.parametrize("test_multiindex", [True, False]) +def test_stack_preserves_na(dtype, na_value, test_multiindex): + # GH#56573 + if test_multiindex: + index = MultiIndex.from_arrays(2 * [Index([na_value], dtype=dtype)]) + else: + index = Index([na_value], dtype=dtype) + df = DataFrame({"a": [1]}, index=index) + result = df.stack(future_stack=True) + + if test_multiindex: + expected_index = MultiIndex.from_arrays( + [ + Index([na_value], dtype=dtype), + Index([na_value], dtype=dtype), + Index(["a"]), + ] + ) + else: + expected_index = MultiIndex.from_arrays( + [ + Index([na_value], dtype=dtype), + Index(["a"]), + ] + ) + expected = Series(1, index=expected_index) + tm.assert_series_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_subclass.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_subclass.py new file mode 100644 index 0000000000000000000000000000000000000000..855b58229cbdb5819e83e3abe39a938bbb8658eb --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_subclass.py @@ -0,0 +1,825 @@ +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + DataFrame, + Index, + MultiIndex, + Series, +) +import pandas._testing as tm + +pytestmark = pytest.mark.filterwarnings( + "ignore:Passing a BlockManager|Passing a SingleBlockManager:DeprecationWarning" +) + + +@pytest.fixture() +def gpd_style_subclass_df(): + class SubclassedDataFrame(DataFrame): + @property + def _constructor(self): + return SubclassedDataFrame + + return SubclassedDataFrame({"a": [1, 2, 3]}) + + +class TestDataFrameSubclassing: + def test_no_warning_on_mgr(self): + # GH#57032 + df = tm.SubclassedDataFrame( + {"X": [1, 2, 3], "Y": [1, 2, 3]}, index=["a", "b", "c"] + ) + with tm.assert_produces_warning(None): + # df.isna() goes through _constructor_from_mgr, which we want to + # *not* pass a Manager do __init__ + df.isna() + df["X"].isna() + + def test_frame_subclassing_and_slicing(self): + # Subclass frame and ensure it returns the right class on slicing it + # In reference to PR 9632 + + class CustomSeries(Series): + @property + def _constructor(self): + return CustomSeries + + def custom_series_function(self): + return "OK" + + class CustomDataFrame(DataFrame): + """ + Subclasses pandas DF, fills DF with simulation results, adds some + custom plotting functions. + """ + + def __init__(self, *args, **kw) -> None: + super().__init__(*args, **kw) + + @property + def _constructor(self): + return CustomDataFrame + + _constructor_sliced = CustomSeries + + def custom_frame_function(self): + return "OK" + + data = {"col1": range(10), "col2": range(10)} + cdf = CustomDataFrame(data) + + # Did we get back our own DF class? + assert isinstance(cdf, CustomDataFrame) + + # Do we get back our own Series class after selecting a column? + cdf_series = cdf.col1 + assert isinstance(cdf_series, CustomSeries) + assert cdf_series.custom_series_function() == "OK" + + # Do we get back our own DF class after slicing row-wise? + cdf_rows = cdf[1:5] + assert isinstance(cdf_rows, CustomDataFrame) + assert cdf_rows.custom_frame_function() == "OK" + + # Make sure sliced part of multi-index frame is custom class + mcol = MultiIndex.from_tuples([("A", "A"), ("A", "B")]) + cdf_multi = CustomDataFrame([[0, 1], [2, 3]], columns=mcol) + assert isinstance(cdf_multi["A"], CustomDataFrame) + + mcol = MultiIndex.from_tuples([("A", ""), ("B", "")]) + cdf_multi2 = CustomDataFrame([[0, 1], [2, 3]], columns=mcol) + assert isinstance(cdf_multi2["A"], CustomSeries) + + def test_dataframe_metadata(self): + df = tm.SubclassedDataFrame( + {"X": [1, 2, 3], "Y": [1, 2, 3]}, index=["a", "b", "c"] + ) + df.testattr = "XXX" + + assert df.testattr == "XXX" + assert df[["X"]].testattr == "XXX" + assert df.loc[["a", "b"], :].testattr == "XXX" + assert df.iloc[[0, 1], :].testattr == "XXX" + + # see gh-9776 + assert df.iloc[0:1, :].testattr == "XXX" + + # see gh-10553 + unpickled = tm.round_trip_pickle(df) + tm.assert_frame_equal(df, unpickled) + assert df._metadata == unpickled._metadata + assert df.testattr == unpickled.testattr + + def test_indexing_sliced(self): + # GH 11559 + df = tm.SubclassedDataFrame( + {"X": [1, 2, 3], "Y": [4, 5, 6], "Z": [7, 8, 9]}, index=["a", "b", "c"] + ) + res = df.loc[:, "X"] + exp = tm.SubclassedSeries([1, 2, 3], index=list("abc"), name="X") + tm.assert_series_equal(res, exp) + assert isinstance(res, tm.SubclassedSeries) + + res = df.iloc[:, 1] + exp = tm.SubclassedSeries([4, 5, 6], index=list("abc"), name="Y") + tm.assert_series_equal(res, exp) + assert isinstance(res, tm.SubclassedSeries) + + res = df.loc[:, "Z"] + exp = tm.SubclassedSeries([7, 8, 9], index=list("abc"), name="Z") + tm.assert_series_equal(res, exp) + assert isinstance(res, tm.SubclassedSeries) + + res = df.loc["a", :] + exp = tm.SubclassedSeries([1, 4, 7], index=list("XYZ"), name="a") + tm.assert_series_equal(res, exp) + assert isinstance(res, tm.SubclassedSeries) + + res = df.iloc[1, :] + exp = tm.SubclassedSeries([2, 5, 8], index=list("XYZ"), name="b") + tm.assert_series_equal(res, exp) + assert isinstance(res, tm.SubclassedSeries) + + res = df.loc["c", :] + exp = tm.SubclassedSeries([3, 6, 9], index=list("XYZ"), name="c") + tm.assert_series_equal(res, exp) + assert isinstance(res, tm.SubclassedSeries) + + def test_subclass_attr_err_propagation(self): + # GH 11808 + class A(DataFrame): + @property + def nonexistence(self): + return self.i_dont_exist + + with pytest.raises(AttributeError, match=".*i_dont_exist.*"): + A().nonexistence + + def test_subclass_align(self): + # GH 12983 + df1 = tm.SubclassedDataFrame( + {"a": [1, 3, 5], "b": [1, 3, 5]}, index=list("ACE") + ) + df2 = tm.SubclassedDataFrame( + {"c": [1, 2, 4], "d": [1, 2, 4]}, index=list("ABD") + ) + + res1, res2 = df1.align(df2, axis=0) + exp1 = tm.SubclassedDataFrame( + {"a": [1, np.nan, 3, np.nan, 5], "b": [1, np.nan, 3, np.nan, 5]}, + index=list("ABCDE"), + ) + exp2 = tm.SubclassedDataFrame( + {"c": [1, 2, np.nan, 4, np.nan], "d": [1, 2, np.nan, 4, np.nan]}, + index=list("ABCDE"), + ) + assert isinstance(res1, tm.SubclassedDataFrame) + tm.assert_frame_equal(res1, exp1) + assert isinstance(res2, tm.SubclassedDataFrame) + tm.assert_frame_equal(res2, exp2) + + res1, res2 = df1.a.align(df2.c) + assert isinstance(res1, tm.SubclassedSeries) + tm.assert_series_equal(res1, exp1.a) + assert isinstance(res2, tm.SubclassedSeries) + tm.assert_series_equal(res2, exp2.c) + + def test_subclass_align_combinations(self): + # GH 12983 + df = tm.SubclassedDataFrame({"a": [1, 3, 5], "b": [1, 3, 5]}, index=list("ACE")) + s = tm.SubclassedSeries([1, 2, 4], index=list("ABD"), name="x") + + # frame + series + res1, res2 = df.align(s, axis=0) + exp1 = tm.SubclassedDataFrame( + {"a": [1, np.nan, 3, np.nan, 5], "b": [1, np.nan, 3, np.nan, 5]}, + index=list("ABCDE"), + ) + # name is lost when + exp2 = tm.SubclassedSeries( + [1, 2, np.nan, 4, np.nan], index=list("ABCDE"), name="x" + ) + + assert isinstance(res1, tm.SubclassedDataFrame) + tm.assert_frame_equal(res1, exp1) + assert isinstance(res2, tm.SubclassedSeries) + tm.assert_series_equal(res2, exp2) + + # series + frame + res1, res2 = s.align(df) + assert isinstance(res1, tm.SubclassedSeries) + tm.assert_series_equal(res1, exp2) + assert isinstance(res2, tm.SubclassedDataFrame) + tm.assert_frame_equal(res2, exp1) + + def test_subclass_iterrows(self): + # GH 13977 + df = tm.SubclassedDataFrame({"a": [1]}) + for i, row in df.iterrows(): + assert isinstance(row, tm.SubclassedSeries) + tm.assert_series_equal(row, df.loc[i]) + + def test_subclass_stack(self): + # GH 15564 + df = tm.SubclassedDataFrame( + [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + index=["a", "b", "c"], + columns=["X", "Y", "Z"], + ) + + res = df.stack(future_stack=True) + exp = tm.SubclassedSeries( + [1, 2, 3, 4, 5, 6, 7, 8, 9], index=[list("aaabbbccc"), list("XYZXYZXYZ")] + ) + + tm.assert_series_equal(res, exp) + + def test_subclass_stack_multi(self): + # GH 15564 + df = tm.SubclassedDataFrame( + [[10, 11, 12, 13], [20, 21, 22, 23], [30, 31, 32, 33], [40, 41, 42, 43]], + index=MultiIndex.from_tuples( + list(zip(list("AABB"), list("cdcd"))), names=["aaa", "ccc"] + ), + columns=MultiIndex.from_tuples( + list(zip(list("WWXX"), list("yzyz"))), names=["www", "yyy"] + ), + ) + + exp = tm.SubclassedDataFrame( + [ + [10, 12], + [11, 13], + [20, 22], + [21, 23], + [30, 32], + [31, 33], + [40, 42], + [41, 43], + ], + index=MultiIndex.from_tuples( + list(zip(list("AAAABBBB"), list("ccddccdd"), list("yzyzyzyz"))), + names=["aaa", "ccc", "yyy"], + ), + columns=Index(["W", "X"], name="www"), + ) + + res = df.stack(future_stack=True) + tm.assert_frame_equal(res, exp) + + res = df.stack("yyy", future_stack=True) + tm.assert_frame_equal(res, exp) + + exp = tm.SubclassedDataFrame( + [ + [10, 11], + [12, 13], + [20, 21], + [22, 23], + [30, 31], + [32, 33], + [40, 41], + [42, 43], + ], + index=MultiIndex.from_tuples( + list(zip(list("AAAABBBB"), list("ccddccdd"), list("WXWXWXWX"))), + names=["aaa", "ccc", "www"], + ), + columns=Index(["y", "z"], name="yyy"), + ) + + res = df.stack("www", future_stack=True) + tm.assert_frame_equal(res, exp) + + def test_subclass_stack_multi_mixed(self): + # GH 15564 + df = tm.SubclassedDataFrame( + [ + [10, 11, 12.0, 13.0], + [20, 21, 22.0, 23.0], + [30, 31, 32.0, 33.0], + [40, 41, 42.0, 43.0], + ], + index=MultiIndex.from_tuples( + list(zip(list("AABB"), list("cdcd"))), names=["aaa", "ccc"] + ), + columns=MultiIndex.from_tuples( + list(zip(list("WWXX"), list("yzyz"))), names=["www", "yyy"] + ), + ) + + exp = tm.SubclassedDataFrame( + [ + [10, 12.0], + [11, 13.0], + [20, 22.0], + [21, 23.0], + [30, 32.0], + [31, 33.0], + [40, 42.0], + [41, 43.0], + ], + index=MultiIndex.from_tuples( + list(zip(list("AAAABBBB"), list("ccddccdd"), list("yzyzyzyz"))), + names=["aaa", "ccc", "yyy"], + ), + columns=Index(["W", "X"], name="www"), + ) + + res = df.stack(future_stack=True) + tm.assert_frame_equal(res, exp) + + res = df.stack("yyy", future_stack=True) + tm.assert_frame_equal(res, exp) + + exp = tm.SubclassedDataFrame( + [ + [10.0, 11.0], + [12.0, 13.0], + [20.0, 21.0], + [22.0, 23.0], + [30.0, 31.0], + [32.0, 33.0], + [40.0, 41.0], + [42.0, 43.0], + ], + index=MultiIndex.from_tuples( + list(zip(list("AAAABBBB"), list("ccddccdd"), list("WXWXWXWX"))), + names=["aaa", "ccc", "www"], + ), + columns=Index(["y", "z"], name="yyy"), + ) + + res = df.stack("www", future_stack=True) + tm.assert_frame_equal(res, exp) + + def test_subclass_unstack(self): + # GH 15564 + df = tm.SubclassedDataFrame( + [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + index=["a", "b", "c"], + columns=["X", "Y", "Z"], + ) + + res = df.unstack() + exp = tm.SubclassedSeries( + [1, 4, 7, 2, 5, 8, 3, 6, 9], index=[list("XXXYYYZZZ"), list("abcabcabc")] + ) + + tm.assert_series_equal(res, exp) + + def test_subclass_unstack_multi(self): + # GH 15564 + df = tm.SubclassedDataFrame( + [[10, 11, 12, 13], [20, 21, 22, 23], [30, 31, 32, 33], [40, 41, 42, 43]], + index=MultiIndex.from_tuples( + list(zip(list("AABB"), list("cdcd"))), names=["aaa", "ccc"] + ), + columns=MultiIndex.from_tuples( + list(zip(list("WWXX"), list("yzyz"))), names=["www", "yyy"] + ), + ) + + exp = tm.SubclassedDataFrame( + [[10, 20, 11, 21, 12, 22, 13, 23], [30, 40, 31, 41, 32, 42, 33, 43]], + index=Index(["A", "B"], name="aaa"), + columns=MultiIndex.from_tuples( + list(zip(list("WWWWXXXX"), list("yyzzyyzz"), list("cdcdcdcd"))), + names=["www", "yyy", "ccc"], + ), + ) + + res = df.unstack() + tm.assert_frame_equal(res, exp) + + res = df.unstack("ccc") + tm.assert_frame_equal(res, exp) + + exp = tm.SubclassedDataFrame( + [[10, 30, 11, 31, 12, 32, 13, 33], [20, 40, 21, 41, 22, 42, 23, 43]], + index=Index(["c", "d"], name="ccc"), + columns=MultiIndex.from_tuples( + list(zip(list("WWWWXXXX"), list("yyzzyyzz"), list("ABABABAB"))), + names=["www", "yyy", "aaa"], + ), + ) + + res = df.unstack("aaa") + tm.assert_frame_equal(res, exp) + + def test_subclass_unstack_multi_mixed(self): + # GH 15564 + df = tm.SubclassedDataFrame( + [ + [10, 11, 12.0, 13.0], + [20, 21, 22.0, 23.0], + [30, 31, 32.0, 33.0], + [40, 41, 42.0, 43.0], + ], + index=MultiIndex.from_tuples( + list(zip(list("AABB"), list("cdcd"))), names=["aaa", "ccc"] + ), + columns=MultiIndex.from_tuples( + list(zip(list("WWXX"), list("yzyz"))), names=["www", "yyy"] + ), + ) + + exp = tm.SubclassedDataFrame( + [ + [10, 20, 11, 21, 12.0, 22.0, 13.0, 23.0], + [30, 40, 31, 41, 32.0, 42.0, 33.0, 43.0], + ], + index=Index(["A", "B"], name="aaa"), + columns=MultiIndex.from_tuples( + list(zip(list("WWWWXXXX"), list("yyzzyyzz"), list("cdcdcdcd"))), + names=["www", "yyy", "ccc"], + ), + ) + + res = df.unstack() + tm.assert_frame_equal(res, exp) + + res = df.unstack("ccc") + tm.assert_frame_equal(res, exp) + + exp = tm.SubclassedDataFrame( + [ + [10, 30, 11, 31, 12.0, 32.0, 13.0, 33.0], + [20, 40, 21, 41, 22.0, 42.0, 23.0, 43.0], + ], + index=Index(["c", "d"], name="ccc"), + columns=MultiIndex.from_tuples( + list(zip(list("WWWWXXXX"), list("yyzzyyzz"), list("ABABABAB"))), + names=["www", "yyy", "aaa"], + ), + ) + + res = df.unstack("aaa") + tm.assert_frame_equal(res, exp) + + def test_subclass_pivot(self): + # GH 15564 + df = tm.SubclassedDataFrame( + { + "index": ["A", "B", "C", "C", "B", "A"], + "columns": ["One", "One", "One", "Two", "Two", "Two"], + "values": [1.0, 2.0, 3.0, 3.0, 2.0, 1.0], + } + ) + + pivoted = df.pivot(index="index", columns="columns", values="values") + + expected = tm.SubclassedDataFrame( + { + "One": {"A": 1.0, "B": 2.0, "C": 3.0}, + "Two": {"A": 1.0, "B": 2.0, "C": 3.0}, + } + ) + + expected.index.name, expected.columns.name = "index", "columns" + + tm.assert_frame_equal(pivoted, expected) + + def test_subclassed_melt(self): + # GH 15564 + cheese = tm.SubclassedDataFrame( + { + "first": ["John", "Mary"], + "last": ["Doe", "Bo"], + "height": [5.5, 6.0], + "weight": [130, 150], + } + ) + + melted = pd.melt(cheese, id_vars=["first", "last"]) + + expected = tm.SubclassedDataFrame( + [ + ["John", "Doe", "height", 5.5], + ["Mary", "Bo", "height", 6.0], + ["John", "Doe", "weight", 130], + ["Mary", "Bo", "weight", 150], + ], + columns=["first", "last", "variable", "value"], + ) + + tm.assert_frame_equal(melted, expected) + + def test_subclassed_wide_to_long(self): + # GH 9762 + + x = np.random.default_rng(2).standard_normal(3) + df = tm.SubclassedDataFrame( + { + "A1970": {0: "a", 1: "b", 2: "c"}, + "A1980": {0: "d", 1: "e", 2: "f"}, + "B1970": {0: 2.5, 1: 1.2, 2: 0.7}, + "B1980": {0: 3.2, 1: 1.3, 2: 0.1}, + "X": dict(zip(range(3), x)), + } + ) + + df["id"] = df.index + exp_data = { + "X": x.tolist() + x.tolist(), + "A": ["a", "b", "c", "d", "e", "f"], + "B": [2.5, 1.2, 0.7, 3.2, 1.3, 0.1], + "year": [1970, 1970, 1970, 1980, 1980, 1980], + "id": [0, 1, 2, 0, 1, 2], + } + expected = tm.SubclassedDataFrame(exp_data) + expected = expected.set_index(["id", "year"])[["X", "A", "B"]] + long_frame = pd.wide_to_long(df, ["A", "B"], i="id", j="year") + + tm.assert_frame_equal(long_frame, expected) + + def test_subclassed_apply(self): + # GH 19822 + + def check_row_subclass(row): + assert isinstance(row, tm.SubclassedSeries) + + def stretch(row): + if row["variable"] == "height": + row["value"] += 0.5 + return row + + df = tm.SubclassedDataFrame( + [ + ["John", "Doe", "height", 5.5], + ["Mary", "Bo", "height", 6.0], + ["John", "Doe", "weight", 130], + ["Mary", "Bo", "weight", 150], + ], + columns=["first", "last", "variable", "value"], + ) + + df.apply(lambda x: check_row_subclass(x)) + df.apply(lambda x: check_row_subclass(x), axis=1) + + expected = tm.SubclassedDataFrame( + [ + ["John", "Doe", "height", 6.0], + ["Mary", "Bo", "height", 6.5], + ["John", "Doe", "weight", 130], + ["Mary", "Bo", "weight", 150], + ], + columns=["first", "last", "variable", "value"], + ) + + result = df.apply(lambda x: stretch(x), axis=1) + assert isinstance(result, tm.SubclassedDataFrame) + tm.assert_frame_equal(result, expected) + + expected = tm.SubclassedDataFrame([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]) + + result = df.apply(lambda x: tm.SubclassedSeries([1, 2, 3]), axis=1) + assert isinstance(result, tm.SubclassedDataFrame) + tm.assert_frame_equal(result, expected) + + result = df.apply(lambda x: [1, 2, 3], axis=1, result_type="expand") + assert isinstance(result, tm.SubclassedDataFrame) + tm.assert_frame_equal(result, expected) + + expected = tm.SubclassedSeries([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]) + + result = df.apply(lambda x: [1, 2, 3], axis=1) + assert not isinstance(result, tm.SubclassedDataFrame) + tm.assert_series_equal(result, expected) + + def test_subclassed_reductions(self, all_reductions): + # GH 25596 + + df = tm.SubclassedDataFrame({"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]}) + result = getattr(df, all_reductions)() + assert isinstance(result, tm.SubclassedSeries) + + def test_subclassed_count(self): + df = tm.SubclassedDataFrame( + { + "Person": ["John", "Myla", "Lewis", "John", "Myla"], + "Age": [24.0, np.nan, 21.0, 33, 26], + "Single": [False, True, True, True, False], + } + ) + result = df.count() + assert isinstance(result, tm.SubclassedSeries) + + df = tm.SubclassedDataFrame({"A": [1, 0, 3], "B": [0, 5, 6], "C": [7, 8, 0]}) + result = df.count() + assert isinstance(result, tm.SubclassedSeries) + + df = tm.SubclassedDataFrame( + [[10, 11, 12, 13], [20, 21, 22, 23], [30, 31, 32, 33], [40, 41, 42, 43]], + index=MultiIndex.from_tuples( + list(zip(list("AABB"), list("cdcd"))), names=["aaa", "ccc"] + ), + columns=MultiIndex.from_tuples( + list(zip(list("WWXX"), list("yzyz"))), names=["www", "yyy"] + ), + ) + result = df.count() + assert isinstance(result, tm.SubclassedSeries) + + df = tm.SubclassedDataFrame() + result = df.count() + assert isinstance(result, tm.SubclassedSeries) + + def test_isin(self): + df = tm.SubclassedDataFrame( + {"num_legs": [2, 4], "num_wings": [2, 0]}, index=["falcon", "dog"] + ) + result = df.isin([0, 2]) + assert isinstance(result, tm.SubclassedDataFrame) + + def test_duplicated(self): + df = tm.SubclassedDataFrame({"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]}) + result = df.duplicated() + assert isinstance(result, tm.SubclassedSeries) + + df = tm.SubclassedDataFrame() + result = df.duplicated() + assert isinstance(result, tm.SubclassedSeries) + + @pytest.mark.parametrize("idx_method", ["idxmax", "idxmin"]) + def test_idx(self, idx_method): + df = tm.SubclassedDataFrame({"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]}) + result = getattr(df, idx_method)() + assert isinstance(result, tm.SubclassedSeries) + + def test_dot(self): + df = tm.SubclassedDataFrame([[0, 1, -2, -1], [1, 1, 1, 1]]) + s = tm.SubclassedSeries([1, 1, 2, 1]) + result = df.dot(s) + assert isinstance(result, tm.SubclassedSeries) + + df = tm.SubclassedDataFrame([[0, 1, -2, -1], [1, 1, 1, 1]]) + s = tm.SubclassedDataFrame([1, 1, 2, 1]) + result = df.dot(s) + assert isinstance(result, tm.SubclassedDataFrame) + + def test_memory_usage(self): + df = tm.SubclassedDataFrame({"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]}) + result = df.memory_usage() + assert isinstance(result, tm.SubclassedSeries) + + result = df.memory_usage(index=False) + assert isinstance(result, tm.SubclassedSeries) + + def test_corrwith(self): + pytest.importorskip("scipy") + index = ["a", "b", "c", "d", "e"] + columns = ["one", "two", "three", "four"] + df1 = tm.SubclassedDataFrame( + np.random.default_rng(2).standard_normal((5, 4)), + index=index, + columns=columns, + ) + df2 = tm.SubclassedDataFrame( + np.random.default_rng(2).standard_normal((4, 4)), + index=index[:4], + columns=columns, + ) + correls = df1.corrwith(df2, axis=1, drop=True, method="kendall") + + assert isinstance(correls, (tm.SubclassedSeries)) + + def test_asof(self): + N = 3 + rng = pd.date_range("1/1/1990", periods=N, freq="53s") + df = tm.SubclassedDataFrame( + { + "A": [np.nan, np.nan, np.nan], + "B": [np.nan, np.nan, np.nan], + "C": [np.nan, np.nan, np.nan], + }, + index=rng, + ) + + result = df.asof(rng[-2:]) + assert isinstance(result, tm.SubclassedDataFrame) + + result = df.asof(rng[-2]) + assert isinstance(result, tm.SubclassedSeries) + + result = df.asof("1989-12-31") + assert isinstance(result, tm.SubclassedSeries) + + def test_idxmin_preserves_subclass(self): + # GH 28330 + + df = tm.SubclassedDataFrame({"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]}) + result = df.idxmin() + assert isinstance(result, tm.SubclassedSeries) + + def test_idxmax_preserves_subclass(self): + # GH 28330 + + df = tm.SubclassedDataFrame({"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]}) + result = df.idxmax() + assert isinstance(result, tm.SubclassedSeries) + + def test_convert_dtypes_preserves_subclass(self, gpd_style_subclass_df): + # GH 43668 + df = tm.SubclassedDataFrame({"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]}) + result = df.convert_dtypes() + assert isinstance(result, tm.SubclassedDataFrame) + + result = gpd_style_subclass_df.convert_dtypes() + assert isinstance(result, type(gpd_style_subclass_df)) + + def test_astype_preserves_subclass(self): + # GH#40810 + df = tm.SubclassedDataFrame({"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]}) + + result = df.astype({"A": np.int64, "B": np.int32, "C": np.float64}) + assert isinstance(result, tm.SubclassedDataFrame) + + def test_equals_subclass(self): + # https://github.com/pandas-dev/pandas/pull/34402 + # allow subclass in both directions + df1 = DataFrame({"a": [1, 2, 3]}) + df2 = tm.SubclassedDataFrame({"a": [1, 2, 3]}) + assert df1.equals(df2) + assert df2.equals(df1) + + def test_replace_list_method(self): + # https://github.com/pandas-dev/pandas/pull/46018 + df = tm.SubclassedDataFrame({"A": [0, 1, 2]}) + msg = "The 'method' keyword in SubclassedDataFrame.replace is deprecated" + with tm.assert_produces_warning( + FutureWarning, match=msg, raise_on_extra_warnings=False + ): + result = df.replace([1, 2], method="ffill") + expected = tm.SubclassedDataFrame({"A": [0, 0, 0]}) + assert isinstance(result, tm.SubclassedDataFrame) + tm.assert_frame_equal(result, expected) + + +class MySubclassWithMetadata(DataFrame): + _metadata = ["my_metadata"] + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + my_metadata = kwargs.pop("my_metadata", None) + if args and isinstance(args[0], MySubclassWithMetadata): + my_metadata = args[0].my_metadata # type: ignore[has-type] + self.my_metadata = my_metadata + + @property + def _constructor(self): + return MySubclassWithMetadata + + +def test_constructor_with_metadata(): + # https://github.com/pandas-dev/pandas/pull/54922 + # https://github.com/pandas-dev/pandas/issues/55120 + df = MySubclassWithMetadata( + np.random.default_rng(2).random((5, 3)), columns=["A", "B", "C"] + ) + subset = df[["A", "B"]] + assert isinstance(subset, MySubclassWithMetadata) + + +class SimpleDataFrameSubClass(DataFrame): + """A subclass of DataFrame that does not define a constructor.""" + + +class SimpleSeriesSubClass(Series): + """A subclass of Series that does not define a constructor.""" + + +class TestSubclassWithoutConstructor: + def test_copy_df(self): + expected = DataFrame({"a": [1, 2, 3]}) + result = SimpleDataFrameSubClass(expected).copy() + + assert ( + type(result) is DataFrame + ) # assert_frame_equal only checks isinstance(lhs, type(rhs)) + tm.assert_frame_equal(result, expected) + + def test_copy_series(self): + expected = Series([1, 2, 3]) + result = SimpleSeriesSubClass(expected).copy() + + tm.assert_series_equal(result, expected) + + def test_series_to_frame(self): + orig = Series([1, 2, 3]) + expected = orig.to_frame() + result = SimpleSeriesSubClass(orig).to_frame() + + assert ( + type(result) is DataFrame + ) # assert_frame_equal only checks isinstance(lhs, type(rhs)) + tm.assert_frame_equal(result, expected) + + def test_groupby(self): + df = SimpleDataFrameSubClass(DataFrame({"a": [1, 2, 3]})) + + for _, v in df.groupby("a"): + assert type(v) is DataFrame diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_ufunc.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_ufunc.py new file mode 100644 index 0000000000000000000000000000000000000000..88c62da2b0a735b103f7a6b03634aa185fc46d2c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_ufunc.py @@ -0,0 +1,311 @@ +from functools import partial +import re + +import numpy as np +import pytest + +import pandas as pd +import pandas._testing as tm +from pandas.api.types import is_extension_array_dtype + +dtypes = [ + "int64", + "Int64", + {"A": "int64", "B": "Int64"}, +] + + +@pytest.mark.parametrize("dtype", dtypes) +def test_unary_unary(dtype): + # unary input, unary output + values = np.array([[-1, -1], [1, 1]], dtype="int64") + df = pd.DataFrame(values, columns=["A", "B"], index=["a", "b"]).astype(dtype=dtype) + result = np.positive(df) + expected = pd.DataFrame( + np.positive(values), index=df.index, columns=df.columns + ).astype(dtype) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("dtype", dtypes) +def test_unary_binary(request, dtype): + # unary input, binary output + if is_extension_array_dtype(dtype) or isinstance(dtype, dict): + request.applymarker( + pytest.mark.xfail( + reason="Extension / mixed with multiple outputs not implemented." + ) + ) + + values = np.array([[-1, -1], [1, 1]], dtype="int64") + df = pd.DataFrame(values, columns=["A", "B"], index=["a", "b"]).astype(dtype=dtype) + result_pandas = np.modf(df) + assert isinstance(result_pandas, tuple) + assert len(result_pandas) == 2 + expected_numpy = np.modf(values) + + for result, b in zip(result_pandas, expected_numpy): + expected = pd.DataFrame(b, index=df.index, columns=df.columns) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("dtype", dtypes) +def test_binary_input_dispatch_binop(dtype): + # binop ufuncs are dispatched to our dunder methods. + values = np.array([[-1, -1], [1, 1]], dtype="int64") + df = pd.DataFrame(values, columns=["A", "B"], index=["a", "b"]).astype(dtype=dtype) + result = np.add(df, df) + expected = pd.DataFrame( + np.add(values, values), index=df.index, columns=df.columns + ).astype(dtype) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "func,arg,expected", + [ + (np.add, 1, [2, 3, 4, 5]), + ( + partial(np.add, where=[[False, True], [True, False]]), + np.array([[1, 1], [1, 1]]), + [0, 3, 4, 0], + ), + (np.power, np.array([[1, 1], [2, 2]]), [1, 2, 9, 16]), + (np.subtract, 2, [-1, 0, 1, 2]), + ( + partial(np.negative, where=np.array([[False, True], [True, False]])), + None, + [0, -2, -3, 0], + ), + ], +) +def test_ufunc_passes_args(func, arg, expected): + # GH#40662 + arr = np.array([[1, 2], [3, 4]]) + df = pd.DataFrame(arr) + result_inplace = np.zeros_like(arr) + # 1-argument ufunc + if arg is None: + result = func(df, out=result_inplace) + else: + result = func(df, arg, out=result_inplace) + + expected = np.array(expected).reshape(2, 2) + tm.assert_numpy_array_equal(result_inplace, expected) + + expected = pd.DataFrame(expected) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("dtype_a", dtypes) +@pytest.mark.parametrize("dtype_b", dtypes) +def test_binary_input_aligns_columns(request, dtype_a, dtype_b): + if ( + is_extension_array_dtype(dtype_a) + or isinstance(dtype_a, dict) + or is_extension_array_dtype(dtype_b) + or isinstance(dtype_b, dict) + ): + request.applymarker( + pytest.mark.xfail( + reason="Extension / mixed with multiple inputs not implemented." + ) + ) + + df1 = pd.DataFrame({"A": [1, 2], "B": [3, 4]}).astype(dtype_a) + + if isinstance(dtype_a, dict) and isinstance(dtype_b, dict): + dtype_b = dtype_b.copy() + dtype_b["C"] = dtype_b.pop("B") + df2 = pd.DataFrame({"A": [1, 2], "C": [3, 4]}).astype(dtype_b) + # As of 2.0, align first before applying the ufunc + result = np.heaviside(df1, df2) + expected = np.heaviside( + np.array([[1, 3, np.nan], [2, 4, np.nan]]), + np.array([[1, np.nan, 3], [2, np.nan, 4]]), + ) + expected = pd.DataFrame(expected, index=[0, 1], columns=["A", "B", "C"]) + tm.assert_frame_equal(result, expected) + + result = np.heaviside(df1, df2.values) + expected = pd.DataFrame([[1.0, 1.0], [1.0, 1.0]], columns=["A", "B"]) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("dtype", dtypes) +def test_binary_input_aligns_index(request, dtype): + if is_extension_array_dtype(dtype) or isinstance(dtype, dict): + request.applymarker( + pytest.mark.xfail( + reason="Extension / mixed with multiple inputs not implemented." + ) + ) + df1 = pd.DataFrame({"A": [1, 2], "B": [3, 4]}, index=["a", "b"]).astype(dtype) + df2 = pd.DataFrame({"A": [1, 2], "B": [3, 4]}, index=["a", "c"]).astype(dtype) + result = np.heaviside(df1, df2) + expected = np.heaviside( + np.array([[1, 3], [3, 4], [np.nan, np.nan]]), + np.array([[1, 3], [np.nan, np.nan], [3, 4]]), + ) + # TODO(FloatArray): this will be Float64Dtype. + expected = pd.DataFrame(expected, index=["a", "b", "c"], columns=["A", "B"]) + tm.assert_frame_equal(result, expected) + + result = np.heaviside(df1, df2.values) + expected = pd.DataFrame( + [[1.0, 1.0], [1.0, 1.0]], columns=["A", "B"], index=["a", "b"] + ) + tm.assert_frame_equal(result, expected) + + +def test_binary_frame_series_raises(): + # We don't currently implement + df = pd.DataFrame({"A": [1, 2]}) + with pytest.raises(NotImplementedError, match="logaddexp"): + np.logaddexp(df, df["A"]) + + with pytest.raises(NotImplementedError, match="logaddexp"): + np.logaddexp(df["A"], df) + + +def test_unary_accumulate_axis(): + # https://github.com/pandas-dev/pandas/issues/39259 + df = pd.DataFrame({"a": [1, 3, 2, 4]}) + result = np.maximum.accumulate(df) + expected = pd.DataFrame({"a": [1, 3, 3, 4]}) + tm.assert_frame_equal(result, expected) + + df = pd.DataFrame({"a": [1, 3, 2, 4], "b": [0.1, 4.0, 3.0, 2.0]}) + result = np.maximum.accumulate(df) + # in theory could preserve int dtype for default axis=0 + expected = pd.DataFrame({"a": [1.0, 3.0, 3.0, 4.0], "b": [0.1, 4.0, 4.0, 4.0]}) + tm.assert_frame_equal(result, expected) + + result = np.maximum.accumulate(df, axis=0) + tm.assert_frame_equal(result, expected) + + result = np.maximum.accumulate(df, axis=1) + expected = pd.DataFrame({"a": [1.0, 3.0, 2.0, 4.0], "b": [1.0, 4.0, 3.0, 4.0]}) + tm.assert_frame_equal(result, expected) + + +def test_frame_outer_disallowed(): + df = pd.DataFrame({"A": [1, 2]}) + with pytest.raises(NotImplementedError, match=""): + # deprecation enforced in 2.0 + np.subtract.outer(df, df) + + +def test_alignment_deprecation_enforced(): + # Enforced in 2.0 + # https://github.com/pandas-dev/pandas/issues/39184 + df1 = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + df2 = pd.DataFrame({"b": [1, 2, 3], "c": [4, 5, 6]}) + s1 = pd.Series([1, 2], index=["a", "b"]) + s2 = pd.Series([1, 2], index=["b", "c"]) + + # binary dataframe / dataframe + expected = pd.DataFrame({"a": [2, 4, 6], "b": [8, 10, 12]}) + + with tm.assert_produces_warning(None): + # aligned -> no warning! + result = np.add(df1, df1) + tm.assert_frame_equal(result, expected) + + result = np.add(df1, df2.values) + tm.assert_frame_equal(result, expected) + + result = np.add(df1, df2) + expected = pd.DataFrame({"a": [np.nan] * 3, "b": [5, 7, 9], "c": [np.nan] * 3}) + tm.assert_frame_equal(result, expected) + + result = np.add(df1.values, df2) + expected = pd.DataFrame({"b": [2, 4, 6], "c": [8, 10, 12]}) + tm.assert_frame_equal(result, expected) + + # binary dataframe / series + expected = pd.DataFrame({"a": [2, 3, 4], "b": [6, 7, 8]}) + + with tm.assert_produces_warning(None): + # aligned -> no warning! + result = np.add(df1, s1) + tm.assert_frame_equal(result, expected) + + result = np.add(df1, s2.values) + tm.assert_frame_equal(result, expected) + + expected = pd.DataFrame( + {"a": [np.nan] * 3, "b": [5.0, 6.0, 7.0], "c": [np.nan] * 3} + ) + result = np.add(df1, s2) + tm.assert_frame_equal(result, expected) + + msg = "Cannot apply ufunc to mixed DataFrame and Series inputs." + with pytest.raises(NotImplementedError, match=msg): + np.add(s2, df1) + + +def test_alignment_deprecation_many_inputs_enforced(): + # Enforced in 2.0 + # https://github.com/pandas-dev/pandas/issues/39184 + # test that the deprecation also works with > 2 inputs -> using a numba + # written ufunc for this because numpy itself doesn't have such ufuncs + numba = pytest.importorskip("numba") + + @numba.vectorize([numba.float64(numba.float64, numba.float64, numba.float64)]) + def my_ufunc(x, y, z): + return x + y + z + + df1 = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + df2 = pd.DataFrame({"b": [1, 2, 3], "c": [4, 5, 6]}) + df3 = pd.DataFrame({"a": [1, 2, 3], "c": [4, 5, 6]}) + + result = my_ufunc(df1, df2, df3) + expected = pd.DataFrame(np.full((3, 3), np.nan), columns=["a", "b", "c"]) + tm.assert_frame_equal(result, expected) + + # all aligned -> no warning + with tm.assert_produces_warning(None): + result = my_ufunc(df1, df1, df1) + expected = pd.DataFrame([[3.0, 12.0], [6.0, 15.0], [9.0, 18.0]], columns=["a", "b"]) + tm.assert_frame_equal(result, expected) + + # mixed frame / arrays + msg = ( + r"operands could not be broadcast together with shapes \(3,3\) \(3,3\) \(3,2\)" + ) + with pytest.raises(ValueError, match=msg): + my_ufunc(df1, df2, df3.values) + + # single frame -> no warning + with tm.assert_produces_warning(None): + result = my_ufunc(df1, df2.values, df3.values) + tm.assert_frame_equal(result, expected) + + # takes indices of first frame + msg = ( + r"operands could not be broadcast together with shapes \(3,2\) \(3,3\) \(3,3\)" + ) + with pytest.raises(ValueError, match=msg): + my_ufunc(df1.values, df2, df3) + + +def test_array_ufuncs_for_many_arguments(): + # GH39853 + def add3(x, y, z): + return x + y + z + + ufunc = np.frompyfunc(add3, 3, 1) + df = pd.DataFrame([[1, 2], [3, 4]]) + + result = ufunc(df, df, 1) + expected = pd.DataFrame([[3, 5], [7, 9]], dtype=object) + tm.assert_frame_equal(result, expected) + + ser = pd.Series([1, 2]) + msg = ( + "Cannot apply ufunc " + "to mixed DataFrame and Series inputs." + ) + with pytest.raises(NotImplementedError, match=re.escape(msg)): + ufunc(df, df, ser) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_unary.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_unary.py new file mode 100644 index 0000000000000000000000000000000000000000..850c92013694fa6724d0b3450a1f378600eaadab --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_unary.py @@ -0,0 +1,204 @@ +from decimal import Decimal + +import numpy as np +import pytest + +from pandas.compat.numpy import np_version_gte1p25 + +import pandas as pd +import pandas._testing as tm + + +class TestDataFrameUnaryOperators: + # __pos__, __neg__, __invert__ + + @pytest.mark.parametrize( + "df,expected", + [ + (pd.DataFrame({"a": [-1, 1]}), pd.DataFrame({"a": [1, -1]})), + (pd.DataFrame({"a": [False, True]}), pd.DataFrame({"a": [True, False]})), + ( + pd.DataFrame({"a": pd.Series(pd.to_timedelta([-1, 1]))}), + pd.DataFrame({"a": pd.Series(pd.to_timedelta([1, -1]))}), + ), + ], + ) + def test_neg_numeric(self, df, expected): + tm.assert_frame_equal(-df, expected) + tm.assert_series_equal(-df["a"], expected["a"]) + + @pytest.mark.parametrize( + "df, expected", + [ + (np.array([1, 2], dtype=object), np.array([-1, -2], dtype=object)), + ([Decimal("1.0"), Decimal("2.0")], [Decimal("-1.0"), Decimal("-2.0")]), + ], + ) + def test_neg_object(self, df, expected): + # GH#21380 + df = pd.DataFrame({"a": df}) + expected = pd.DataFrame({"a": expected}) + tm.assert_frame_equal(-df, expected) + tm.assert_series_equal(-df["a"], expected["a"]) + + @pytest.mark.parametrize( + "df", + [ + pd.DataFrame({"a": ["a", "b"]}), + pd.DataFrame({"a": pd.to_datetime(["2017-01-22", "1970-01-01"])}), + ], + ) + def test_neg_raises(self, df, using_infer_string): + msg = ( + "bad operand type for unary -: 'str'|" + r"bad operand type for unary -: 'DatetimeArray'" + ) + if using_infer_string and df.dtypes.iloc[0] == "string": + import pyarrow as pa + + msg = "has no kernel" + with pytest.raises(pa.lib.ArrowNotImplementedError, match=msg): + (-df) + with pytest.raises(pa.lib.ArrowNotImplementedError, match=msg): + (-df["a"]) + + else: + with pytest.raises(TypeError, match=msg): + (-df) + with pytest.raises(TypeError, match=msg): + (-df["a"]) + + def test_invert(self, float_frame): + df = float_frame + + tm.assert_frame_equal(-(df < 0), ~(df < 0)) + + def test_invert_mixed(self): + shape = (10, 5) + df = pd.concat( + [ + pd.DataFrame(np.zeros(shape, dtype="bool")), + pd.DataFrame(np.zeros(shape, dtype=int)), + ], + axis=1, + ignore_index=True, + ) + result = ~df + expected = pd.concat( + [ + pd.DataFrame(np.ones(shape, dtype="bool")), + pd.DataFrame(-np.ones(shape, dtype=int)), + ], + axis=1, + ignore_index=True, + ) + tm.assert_frame_equal(result, expected) + + def test_invert_empty_not_input(self): + # GH#51032 + df = pd.DataFrame() + result = ~df + tm.assert_frame_equal(df, result) + assert df is not result + + @pytest.mark.parametrize( + "df", + [ + pd.DataFrame({"a": [-1, 1]}), + pd.DataFrame({"a": [False, True]}), + pd.DataFrame({"a": pd.Series(pd.to_timedelta([-1, 1]))}), + ], + ) + def test_pos_numeric(self, df): + # GH#16073 + tm.assert_frame_equal(+df, df) + tm.assert_series_equal(+df["a"], df["a"]) + + @pytest.mark.parametrize( + "df", + [ + pd.DataFrame({"a": np.array([-1, 2], dtype=object)}), + pd.DataFrame({"a": [Decimal("-1.0"), Decimal("2.0")]}), + ], + ) + def test_pos_object(self, df): + # GH#21380 + tm.assert_frame_equal(+df, df) + tm.assert_series_equal(+df["a"], df["a"]) + + @pytest.mark.parametrize( + "df", + [ + pytest.param( + pd.DataFrame({"a": ["a", "b"]}), + # filterwarnings removable once min numpy version is 1.25 + marks=[ + pytest.mark.filterwarnings("ignore:Applying:DeprecationWarning") + ], + ), + ], + ) + def test_pos_object_raises(self, df): + # GH#21380 + if np_version_gte1p25: + with pytest.raises( + TypeError, match=r"^bad operand type for unary \+: \'str\'$" + ): + tm.assert_frame_equal(+df, df) + else: + tm.assert_series_equal(+df["a"], df["a"]) + + @pytest.mark.parametrize( + "df", [pd.DataFrame({"a": pd.to_datetime(["2017-01-22", "1970-01-01"])})] + ) + def test_pos_raises(self, df): + msg = r"bad operand type for unary \+: 'DatetimeArray'" + with pytest.raises(TypeError, match=msg): + (+df) + with pytest.raises(TypeError, match=msg): + (+df["a"]) + + def test_unary_nullable(self): + df = pd.DataFrame( + { + "a": pd.array([1, -2, 3, pd.NA], dtype="Int64"), + "b": pd.array([4.0, -5.0, 6.0, pd.NA], dtype="Float32"), + "c": pd.array([True, False, False, pd.NA], dtype="boolean"), + # include numpy bool to make sure bool-vs-boolean behavior + # is consistent in non-NA locations + "d": np.array([True, False, False, True]), + } + ) + + result = +df + res_ufunc = np.positive(df) + expected = df + # TODO: assert that we have copies? + tm.assert_frame_equal(result, expected) + tm.assert_frame_equal(res_ufunc, expected) + + result = -df + res_ufunc = np.negative(df) + expected = pd.DataFrame( + { + "a": pd.array([-1, 2, -3, pd.NA], dtype="Int64"), + "b": pd.array([-4.0, 5.0, -6.0, pd.NA], dtype="Float32"), + "c": pd.array([False, True, True, pd.NA], dtype="boolean"), + "d": np.array([False, True, True, False]), + } + ) + tm.assert_frame_equal(result, expected) + tm.assert_frame_equal(res_ufunc, expected) + + result = abs(df) + res_ufunc = np.abs(df) + expected = pd.DataFrame( + { + "a": pd.array([1, 2, 3, pd.NA], dtype="Int64"), + "b": pd.array([4.0, 5.0, 6.0, pd.NA], dtype="Float32"), + "c": pd.array([True, False, False, pd.NA], dtype="boolean"), + "d": np.array([True, False, False, True]), + } + ) + tm.assert_frame_equal(result, expected) + tm.assert_frame_equal(res_ufunc, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_validate.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_validate.py new file mode 100644 index 0000000000000000000000000000000000000000..e99e0a686384883d570feef949597d08da7e8ff9 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/frame/test_validate.py @@ -0,0 +1,41 @@ +import pytest + +from pandas.core.frame import DataFrame + + +@pytest.fixture +def dataframe(): + return DataFrame({"a": [1, 2], "b": [3, 4]}) + + +class TestDataFrameValidate: + """Tests for error handling related to data types of method arguments.""" + + @pytest.mark.parametrize( + "func", + [ + "query", + "eval", + "set_index", + "reset_index", + "dropna", + "drop_duplicates", + "sort_values", + ], + ) + @pytest.mark.parametrize("inplace", [1, "True", [1, 2, 3], 5.0]) + def test_validate_bool_args(self, dataframe, func, inplace): + msg = 'For argument "inplace" expected type bool' + kwargs = {"inplace": inplace} + + if func == "query": + kwargs["expr"] = "a > b" + elif func == "eval": + kwargs["expr"] = "a + b" + elif func == "set_index": + kwargs["keys"] = ["a"] + elif func == "sort_values": + kwargs["by"] = ["a"] + + with pytest.raises(ValueError, match=msg): + getattr(dataframe, func)(**kwargs) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..446d9da4377712b073d76dac7672dcf1de00cf04 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/__init__.py @@ -0,0 +1,25 @@ +def get_groupby_method_args(name, obj): + """ + Get required arguments for a groupby method. + + When parametrizing a test over groupby methods (e.g. "sum", "mean", "fillna"), + it is often the case that arguments are required for certain methods. + + Parameters + ---------- + name: str + Name of the method. + obj: Series or DataFrame + pandas object that is being grouped. + + Returns + ------- + A tuple of required arguments for the method. + """ + if name in ("nth", "fillna", "take"): + return (0,) + if name == "quantile": + return (0.5,) + if name == "corrwith": + return (obj,) + return () diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/__pycache__/conftest.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/__pycache__/conftest.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db3cf68208f56cb89305ad9991465aa2b2d7ff89 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/__pycache__/conftest.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/__pycache__/test_grouping.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/__pycache__/test_grouping.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..73cd775b9a8c64172cb283ec303a0f36567fcd97 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/__pycache__/test_grouping.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/__pycache__/test_index_as_string.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/__pycache__/test_index_as_string.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09cf6ffc875976436c0a3233487591f7284dba3c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/__pycache__/test_index_as_string.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/__pycache__/test_numba.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/__pycache__/test_numba.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40345d4360bc0151b598ea1373e950ab1a5b4576 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/__pycache__/test_numba.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/__pycache__/test_reductions.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/__pycache__/test_reductions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a325179c5ef702fc1a79ff4939423fe53c265e65 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/__pycache__/test_reductions.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/aggregate/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/aggregate/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/aggregate/test_aggregate.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/aggregate/test_aggregate.py new file mode 100644 index 0000000000000000000000000000000000000000..6223a153df3588840021210208b7250e48825552 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/aggregate/test_aggregate.py @@ -0,0 +1,1672 @@ +""" +test .agg behavior / note that .apply is tested generally in test_groupby.py +""" +import datetime +import functools +from functools import partial +import re + +import numpy as np +import pytest + +from pandas.errors import SpecificationError + +from pandas.core.dtypes.common import is_integer_dtype + +import pandas as pd +from pandas import ( + DataFrame, + Index, + MultiIndex, + Series, + concat, + to_datetime, +) +import pandas._testing as tm +from pandas.core.groupby.grouper import Grouping + + +def test_groupby_agg_no_extra_calls(): + # GH#31760 + df = DataFrame({"key": ["a", "b", "c", "c"], "value": [1, 2, 3, 4]}) + gb = df.groupby("key")["value"] + + def dummy_func(x): + assert len(x) != 0 + return x.sum() + + gb.agg(dummy_func) + + +def test_agg_regression1(tsframe): + grouped = tsframe.groupby([lambda x: x.year, lambda x: x.month]) + result = grouped.agg("mean") + expected = grouped.mean() + tm.assert_frame_equal(result, expected) + + +def test_agg_must_agg(df): + grouped = df.groupby("A")["C"] + + msg = "Must produce aggregated value" + with pytest.raises(Exception, match=msg): + grouped.agg(lambda x: x.describe()) + with pytest.raises(Exception, match=msg): + grouped.agg(lambda x: x.index[:2]) + + +def test_agg_ser_multi_key(df): + f = lambda x: x.sum() + results = df.C.groupby([df.A, df.B]).aggregate(f) + expected = df.groupby(["A", "B"]).sum()["C"] + tm.assert_series_equal(results, expected) + + +def test_groupby_aggregation_mixed_dtype(): + # GH 6212 + expected = DataFrame( + { + "v1": [5, 5, 7, np.nan, 3, 3, 4, 1], + "v2": [55, 55, 77, np.nan, 33, 33, 44, 11], + }, + index=MultiIndex.from_tuples( + [ + (1, 95), + (1, 99), + (2, 95), + (2, 99), + ("big", "damp"), + ("blue", "dry"), + ("red", "red"), + ("red", "wet"), + ], + names=["by1", "by2"], + ), + ) + + df = DataFrame( + { + "v1": [1, 3, 5, 7, 8, 3, 5, np.nan, 4, 5, 7, 9], + "v2": [11, 33, 55, 77, 88, 33, 55, np.nan, 44, 55, 77, 99], + "by1": ["red", "blue", 1, 2, np.nan, "big", 1, 2, "red", 1, np.nan, 12], + "by2": [ + "wet", + "dry", + 99, + 95, + np.nan, + "damp", + 95, + 99, + "red", + 99, + np.nan, + np.nan, + ], + } + ) + + g = df.groupby(["by1", "by2"]) + result = g[["v1", "v2"]].mean() + tm.assert_frame_equal(result, expected) + + +def test_groupby_aggregation_multi_level_column(): + # GH 29772 + lst = [ + [True, True, True, False], + [True, False, np.nan, False], + [True, True, np.nan, False], + [True, True, np.nan, False], + ] + df = DataFrame( + data=lst, + columns=MultiIndex.from_tuples([("A", 0), ("A", 1), ("B", 0), ("B", 1)]), + ) + + msg = "DataFrame.groupby with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + gb = df.groupby(level=1, axis=1) + result = gb.sum(numeric_only=False) + expected = DataFrame({0: [2.0, True, True, True], 1: [1, 0, 1, 1]}) + + tm.assert_frame_equal(result, expected) + + +def test_agg_apply_corner(ts, tsframe): + # nothing to group, all NA + grouped = ts.groupby(ts * np.nan, group_keys=False) + assert ts.dtype == np.float64 + + # groupby float64 values results in a float64 Index + exp = Series([], dtype=np.float64, index=Index([], dtype=np.float64)) + tm.assert_series_equal(grouped.sum(), exp) + tm.assert_series_equal(grouped.agg("sum"), exp) + tm.assert_series_equal(grouped.apply("sum"), exp, check_index_type=False) + + # DataFrame + grouped = tsframe.groupby(tsframe["A"] * np.nan, group_keys=False) + exp_df = DataFrame( + columns=tsframe.columns, + dtype=float, + index=Index([], name="A", dtype=np.float64), + ) + tm.assert_frame_equal(grouped.sum(), exp_df) + tm.assert_frame_equal(grouped.agg("sum"), exp_df) + + msg = "The behavior of DataFrame.sum with axis=None is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg, check_stacklevel=False): + res = grouped.apply(np.sum) + tm.assert_frame_equal(res, exp_df) + + +def test_agg_grouping_is_list_tuple(ts): + df = DataFrame( + np.random.default_rng(2).standard_normal((30, 4)), + columns=Index(list("ABCD"), dtype=object), + index=pd.date_range("2000-01-01", periods=30, freq="B"), + ) + + grouped = df.groupby(lambda x: x.year) + grouper = grouped._grouper.groupings[0].grouping_vector + grouped._grouper.groupings[0] = Grouping(ts.index, list(grouper)) + + result = grouped.agg("mean") + expected = grouped.mean() + tm.assert_frame_equal(result, expected) + + grouped._grouper.groupings[0] = Grouping(ts.index, tuple(grouper)) + + result = grouped.agg("mean") + expected = grouped.mean() + tm.assert_frame_equal(result, expected) + + +def test_agg_python_multiindex(multiindex_dataframe_random_data): + grouped = multiindex_dataframe_random_data.groupby(["A", "B"]) + + result = grouped.agg("mean") + expected = grouped.mean() + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "groupbyfunc", [lambda x: x.weekday(), [lambda x: x.month, lambda x: x.weekday()]] +) +def test_aggregate_str_func(tsframe, groupbyfunc): + grouped = tsframe.groupby(groupbyfunc) + + # single series + result = grouped["A"].agg("std") + expected = grouped["A"].std() + tm.assert_series_equal(result, expected) + + # group frame by function name + result = grouped.aggregate("var") + expected = grouped.var() + tm.assert_frame_equal(result, expected) + + # group frame by function dict + result = grouped.agg({"A": "var", "B": "std", "C": "mean", "D": "sem"}) + expected = DataFrame( + { + "A": grouped["A"].var(), + "B": grouped["B"].std(), + "C": grouped["C"].mean(), + "D": grouped["D"].sem(), + } + ) + tm.assert_frame_equal(result, expected) + + +def test_std_masked_dtype(any_numeric_ea_dtype): + # GH#35516 + df = DataFrame( + { + "a": [2, 1, 1, 1, 2, 2, 1], + "b": Series([pd.NA, 1, 2, 1, 1, 1, 2], dtype="Float64"), + } + ) + result = df.groupby("a").std() + expected = DataFrame( + {"b": [0.57735, 0]}, index=Index([1, 2], name="a"), dtype="Float64" + ) + tm.assert_frame_equal(result, expected) + + +def test_agg_str_with_kwarg_axis_1_raises(df, reduction_func): + gb = df.groupby(level=0) + warn_msg = f"DataFrameGroupBy.{reduction_func} with axis=1 is deprecated" + if reduction_func in ("idxmax", "idxmin"): + error = TypeError + msg = "'[<>]' not supported between instances of 'float' and 'str'" + warn = FutureWarning + else: + error = ValueError + msg = f"Operation {reduction_func} does not support axis=1" + warn = None + with pytest.raises(error, match=msg): + with tm.assert_produces_warning(warn, match=warn_msg): + gb.agg(reduction_func, axis=1) + + +@pytest.mark.parametrize( + "func, expected, dtype, result_dtype_dict", + [ + ("sum", [5, 7, 9], "int64", {}), + ("std", [4.5**0.5] * 3, int, {"i": float, "j": float, "k": float}), + ("var", [4.5] * 3, int, {"i": float, "j": float, "k": float}), + ("sum", [5, 7, 9], "Int64", {"j": "int64"}), + ("std", [4.5**0.5] * 3, "Int64", {"i": float, "j": float, "k": float}), + ("var", [4.5] * 3, "Int64", {"i": "float64", "j": "float64", "k": "float64"}), + ], +) +def test_multiindex_groupby_mixed_cols_axis1(func, expected, dtype, result_dtype_dict): + # GH#43209 + df = DataFrame( + [[1, 2, 3, 4, 5, 6]] * 3, + columns=MultiIndex.from_product([["a", "b"], ["i", "j", "k"]]), + ).astype({("a", "j"): dtype, ("b", "j"): dtype}) + + msg = "DataFrame.groupby with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + gb = df.groupby(level=1, axis=1) + result = gb.agg(func) + expected = DataFrame([expected] * 3, columns=["i", "j", "k"]).astype( + result_dtype_dict + ) + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "func, expected_data, result_dtype_dict", + [ + ("sum", [[2, 4], [10, 12], [18, 20]], {10: "int64", 20: "int64"}), + # std should ideally return Int64 / Float64 #43330 + ("std", [[2**0.5] * 2] * 3, "float64"), + ("var", [[2] * 2] * 3, {10: "float64", 20: "float64"}), + ], +) +def test_groupby_mixed_cols_axis1(func, expected_data, result_dtype_dict): + # GH#43209 + df = DataFrame( + np.arange(12).reshape(3, 4), + index=Index([0, 1, 0], name="y"), + columns=Index([10, 20, 10, 20], name="x"), + dtype="int64", + ).astype({10: "Int64"}) + + msg = "DataFrame.groupby with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + gb = df.groupby("x", axis=1) + result = gb.agg(func) + expected = DataFrame( + data=expected_data, + index=Index([0, 1, 0], name="y"), + columns=Index([10, 20], name="x"), + ).astype(result_dtype_dict) + tm.assert_frame_equal(result, expected) + + +def test_aggregate_item_by_item(df): + grouped = df.groupby("A") + + aggfun_0 = lambda ser: ser.size + result = grouped.agg(aggfun_0) + foosum = (df.A == "foo").sum() + barsum = (df.A == "bar").sum() + K = len(result.columns) + + # GH5782 + exp = Series(np.array([foosum] * K), index=list("BCD"), name="foo") + tm.assert_series_equal(result.xs("foo"), exp) + + exp = Series(np.array([barsum] * K), index=list("BCD"), name="bar") + tm.assert_almost_equal(result.xs("bar"), exp) + + def aggfun_1(ser): + return ser.size + + result = DataFrame().groupby(df.A).agg(aggfun_1) + assert isinstance(result, DataFrame) + assert len(result) == 0 + + +def test_wrap_agg_out(three_group): + grouped = three_group.groupby(["A", "B"]) + + def func(ser): + if ser.dtype == object: + raise TypeError("Test error message") + return ser.sum() + + with pytest.raises(TypeError, match="Test error message"): + grouped.aggregate(func) + result = grouped[["D", "E", "F"]].aggregate(func) + exp_grouped = three_group.loc[:, ["A", "B", "D", "E", "F"]] + expected = exp_grouped.groupby(["A", "B"]).aggregate(func) + tm.assert_frame_equal(result, expected) + + +def test_agg_multiple_functions_maintain_order(df): + # GH #610 + funcs = [("mean", np.mean), ("max", np.max), ("min", np.min)] + msg = "is currently using SeriesGroupBy.mean" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = df.groupby("A")["C"].agg(funcs) + exp_cols = Index(["mean", "max", "min"]) + + tm.assert_index_equal(result.columns, exp_cols) + + +def test_series_index_name(df): + grouped = df.loc[:, ["C"]].groupby(df["A"]) + result = grouped.agg(lambda x: x.mean()) + assert result.index.name == "A" + + +def test_agg_multiple_functions_same_name(): + # GH 30880 + df = DataFrame( + np.random.default_rng(2).standard_normal((1000, 3)), + index=pd.date_range("1/1/2012", freq="s", periods=1000), + columns=["A", "B", "C"], + ) + result = df.resample("3min").agg( + {"A": [partial(np.quantile, q=0.9999), partial(np.quantile, q=0.1111)]} + ) + expected_index = pd.date_range("1/1/2012", freq="3min", periods=6) + expected_columns = MultiIndex.from_tuples([("A", "quantile"), ("A", "quantile")]) + expected_values = np.array( + [df.resample("3min").A.quantile(q=q).values for q in [0.9999, 0.1111]] + ).T + expected = DataFrame( + expected_values, columns=expected_columns, index=expected_index + ) + tm.assert_frame_equal(result, expected) + + +def test_agg_multiple_functions_same_name_with_ohlc_present(): + # GH 30880 + # ohlc expands dimensions, so different test to the above is required. + df = DataFrame( + np.random.default_rng(2).standard_normal((1000, 3)), + index=pd.date_range("1/1/2012", freq="s", periods=1000, name="dti"), + columns=Index(["A", "B", "C"], name="alpha"), + ) + result = df.resample("3min").agg( + {"A": ["ohlc", partial(np.quantile, q=0.9999), partial(np.quantile, q=0.1111)]} + ) + expected_index = pd.date_range("1/1/2012", freq="3min", periods=6, name="dti") + expected_columns = MultiIndex.from_tuples( + [ + ("A", "ohlc", "open"), + ("A", "ohlc", "high"), + ("A", "ohlc", "low"), + ("A", "ohlc", "close"), + ("A", "quantile", "A"), + ("A", "quantile", "A"), + ], + names=["alpha", None, None], + ) + non_ohlc_expected_values = np.array( + [df.resample("3min").A.quantile(q=q).values for q in [0.9999, 0.1111]] + ).T + expected_values = np.hstack( + [df.resample("3min").A.ohlc(), non_ohlc_expected_values] + ) + expected = DataFrame( + expected_values, columns=expected_columns, index=expected_index + ) + tm.assert_frame_equal(result, expected) + + +def test_multiple_functions_tuples_and_non_tuples(df): + # #1359 + # Columns B and C would cause partial failure + df = df.drop(columns=["B", "C"]) + + funcs = [("foo", "mean"), "std"] + ex_funcs = [("foo", "mean"), ("std", "std")] + + result = df.groupby("A")["D"].agg(funcs) + expected = df.groupby("A")["D"].agg(ex_funcs) + tm.assert_frame_equal(result, expected) + + result = df.groupby("A").agg(funcs) + expected = df.groupby("A").agg(ex_funcs) + tm.assert_frame_equal(result, expected) + + +def test_more_flexible_frame_multi_function(df): + grouped = df.groupby("A") + + exmean = grouped.agg({"C": "mean", "D": "mean"}) + exstd = grouped.agg({"C": "std", "D": "std"}) + + expected = concat([exmean, exstd], keys=["mean", "std"], axis=1) + expected = expected.swaplevel(0, 1, axis=1).sort_index(level=0, axis=1) + + d = {"C": ["mean", "std"], "D": ["mean", "std"]} + result = grouped.aggregate(d) + + tm.assert_frame_equal(result, expected) + + # be careful + result = grouped.aggregate({"C": "mean", "D": ["mean", "std"]}) + expected = grouped.aggregate({"C": "mean", "D": ["mean", "std"]}) + tm.assert_frame_equal(result, expected) + + def numpymean(x): + return np.mean(x) + + def numpystd(x): + return np.std(x, ddof=1) + + # this uses column selection & renaming + msg = r"nested renamer is not supported" + with pytest.raises(SpecificationError, match=msg): + d = {"C": "mean", "D": {"foo": "mean", "bar": "std"}} + grouped.aggregate(d) + + # But without renaming, these functions are OK + d = {"C": ["mean"], "D": [numpymean, numpystd]} + grouped.aggregate(d) + + +def test_multi_function_flexible_mix(df): + # GH #1268 + grouped = df.groupby("A") + + # Expected + d = {"C": {"foo": "mean", "bar": "std"}, "D": {"sum": "sum"}} + # this uses column selection & renaming + msg = r"nested renamer is not supported" + with pytest.raises(SpecificationError, match=msg): + grouped.aggregate(d) + + # Test 1 + d = {"C": {"foo": "mean", "bar": "std"}, "D": "sum"} + # this uses column selection & renaming + with pytest.raises(SpecificationError, match=msg): + grouped.aggregate(d) + + # Test 2 + d = {"C": {"foo": "mean", "bar": "std"}, "D": "sum"} + # this uses column selection & renaming + with pytest.raises(SpecificationError, match=msg): + grouped.aggregate(d) + + +def test_groupby_agg_coercing_bools(): + # issue 14873 + dat = DataFrame({"a": [1, 1, 2, 2], "b": [0, 1, 2, 3], "c": [None, None, 1, 1]}) + gp = dat.groupby("a") + + index = Index([1, 2], name="a") + + result = gp["b"].aggregate(lambda x: (x != 0).all()) + expected = Series([False, True], index=index, name="b") + tm.assert_series_equal(result, expected) + + result = gp["c"].aggregate(lambda x: x.isnull().all()) + expected = Series([True, False], index=index, name="c") + tm.assert_series_equal(result, expected) + + +def test_groupby_agg_dict_with_getitem(): + # issue 25471 + dat = DataFrame({"A": ["A", "A", "B", "B", "B"], "B": [1, 2, 1, 1, 2]}) + result = dat.groupby("A")[["B"]].agg({"B": "sum"}) + + expected = DataFrame({"B": [3, 4]}, index=["A", "B"]).rename_axis("A", axis=0) + + tm.assert_frame_equal(result, expected) + + +def test_groupby_agg_dict_dup_columns(): + # GH#55006 + df = DataFrame( + [[1, 2, 3, 4], [1, 3, 4, 5], [2, 4, 5, 6]], + columns=["a", "b", "c", "c"], + ) + gb = df.groupby("a") + result = gb.agg({"b": "sum"}) + expected = DataFrame({"b": [5, 4]}, index=Index([1, 2], name="a")) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "op", + [ + lambda x: x.sum(), + lambda x: x.cumsum(), + lambda x: x.transform("sum"), + lambda x: x.transform("cumsum"), + lambda x: x.agg("sum"), + lambda x: x.agg("cumsum"), + ], +) +def test_bool_agg_dtype(op): + # GH 7001 + # Bool sum aggregations result in int + df = DataFrame({"a": [1, 1], "b": [False, True]}) + s = df.set_index("a")["b"] + + result = op(df.groupby("a"))["b"].dtype + assert is_integer_dtype(result) + + result = op(s.groupby("a")).dtype + assert is_integer_dtype(result) + + +@pytest.mark.parametrize( + "keys, agg_index", + [ + (["a"], Index([1], name="a")), + (["a", "b"], MultiIndex([[1], [2]], [[0], [0]], names=["a", "b"])), + ], +) +@pytest.mark.parametrize( + "input_dtype", ["bool", "int32", "int64", "float32", "float64"] +) +@pytest.mark.parametrize( + "result_dtype", ["bool", "int32", "int64", "float32", "float64"] +) +@pytest.mark.parametrize("method", ["apply", "aggregate", "transform"]) +def test_callable_result_dtype_frame( + keys, agg_index, input_dtype, result_dtype, method +): + # GH 21240 + df = DataFrame({"a": [1], "b": [2], "c": [True]}) + df["c"] = df["c"].astype(input_dtype) + op = getattr(df.groupby(keys)[["c"]], method) + result = op(lambda x: x.astype(result_dtype).iloc[0]) + expected_index = pd.RangeIndex(0, 1) if method == "transform" else agg_index + expected = DataFrame({"c": [df["c"].iloc[0]]}, index=expected_index).astype( + result_dtype + ) + if method == "apply": + expected.columns.names = [0] + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "keys, agg_index", + [ + (["a"], Index([1], name="a")), + (["a", "b"], MultiIndex([[1], [2]], [[0], [0]], names=["a", "b"])), + ], +) +@pytest.mark.parametrize("input", [True, 1, 1.0]) +@pytest.mark.parametrize("dtype", [bool, int, float]) +@pytest.mark.parametrize("method", ["apply", "aggregate", "transform"]) +def test_callable_result_dtype_series(keys, agg_index, input, dtype, method): + # GH 21240 + df = DataFrame({"a": [1], "b": [2], "c": [input]}) + op = getattr(df.groupby(keys)["c"], method) + result = op(lambda x: x.astype(dtype).iloc[0]) + expected_index = pd.RangeIndex(0, 1) if method == "transform" else agg_index + expected = Series([df["c"].iloc[0]], index=expected_index, name="c").astype(dtype) + tm.assert_series_equal(result, expected) + + +def test_order_aggregate_multiple_funcs(): + # GH 25692 + df = DataFrame({"A": [1, 1, 2, 2], "B": [1, 2, 3, 4]}) + + res = df.groupby("A").agg(["sum", "max", "mean", "ohlc", "min"]) + result = res.columns.levels[1] + + expected = Index(["sum", "max", "mean", "ohlc", "min"]) + + tm.assert_index_equal(result, expected) + + +def test_ohlc_ea_dtypes(any_numeric_ea_dtype): + # GH#37493 + df = DataFrame( + {"a": [1, 1, 2, 3, 4, 4], "b": [22, 11, pd.NA, 10, 20, pd.NA]}, + dtype=any_numeric_ea_dtype, + ) + gb = df.groupby("a") + result = gb.ohlc() + expected = DataFrame( + [[22, 22, 11, 11], [pd.NA] * 4, [10] * 4, [20] * 4], + columns=MultiIndex.from_product([["b"], ["open", "high", "low", "close"]]), + index=Index([1, 2, 3, 4], dtype=any_numeric_ea_dtype, name="a"), + dtype=any_numeric_ea_dtype, + ) + tm.assert_frame_equal(result, expected) + + gb2 = df.groupby("a", as_index=False) + result2 = gb2.ohlc() + expected2 = expected.reset_index() + tm.assert_frame_equal(result2, expected2) + + +@pytest.mark.parametrize("dtype", [np.int64, np.uint64]) +@pytest.mark.parametrize("how", ["first", "last", "min", "max", "mean", "median"]) +def test_uint64_type_handling(dtype, how): + # GH 26310 + df = DataFrame({"x": 6903052872240755750, "y": [1, 2]}) + expected = df.groupby("y").agg({"x": how}) + df.x = df.x.astype(dtype) + result = df.groupby("y").agg({"x": how}) + if how not in ("mean", "median"): + # mean and median always result in floats + result.x = result.x.astype(np.int64) + tm.assert_frame_equal(result, expected, check_exact=True) + + +def test_func_duplicates_raises(): + # GH28426 + msg = "Function names" + df = DataFrame({"A": [0, 0, 1, 1], "B": [1, 2, 3, 4]}) + with pytest.raises(SpecificationError, match=msg): + df.groupby("A").agg(["min", "min"]) + + +@pytest.mark.parametrize( + "index", + [ + pd.CategoricalIndex(list("abc")), + pd.interval_range(0, 3), + pd.period_range("2020", periods=3, freq="D"), + MultiIndex.from_tuples([("a", 0), ("a", 1), ("b", 0)]), + ], +) +def test_agg_index_has_complex_internals(index): + # GH 31223 + df = DataFrame({"group": [1, 1, 2], "value": [0, 1, 0]}, index=index) + result = df.groupby("group").agg({"value": Series.nunique}) + expected = DataFrame({"group": [1, 2], "value": [2, 1]}).set_index("group") + tm.assert_frame_equal(result, expected) + + +def test_agg_split_block(): + # https://github.com/pandas-dev/pandas/issues/31522 + df = DataFrame( + { + "key1": ["a", "a", "b", "b", "a"], + "key2": ["one", "two", "one", "two", "one"], + "key3": ["three", "three", "three", "six", "six"], + } + ) + result = df.groupby("key1").min() + expected = DataFrame( + {"key2": ["one", "one"], "key3": ["six", "six"]}, + index=Index(["a", "b"], name="key1"), + ) + tm.assert_frame_equal(result, expected) + + +def test_agg_split_object_part_datetime(): + # https://github.com/pandas-dev/pandas/pull/31616 + df = DataFrame( + { + "A": pd.date_range("2000", periods=4), + "B": ["a", "b", "c", "d"], + "C": [1, 2, 3, 4], + "D": ["b", "c", "d", "e"], + "E": pd.date_range("2000", periods=4), + "F": [1, 2, 3, 4], + } + ).astype(object) + result = df.groupby([0, 0, 0, 0]).min() + expected = DataFrame( + { + "A": [pd.Timestamp("2000")], + "B": ["a"], + "C": [1], + "D": ["b"], + "E": [pd.Timestamp("2000")], + "F": [1], + }, + index=np.array([0]), + dtype=object, + ) + tm.assert_frame_equal(result, expected) + + +class TestNamedAggregationSeries: + def test_series_named_agg(self): + df = Series([1, 2, 3, 4]) + gr = df.groupby([0, 0, 1, 1]) + result = gr.agg(a="sum", b="min") + expected = DataFrame( + {"a": [3, 7], "b": [1, 3]}, columns=["a", "b"], index=np.array([0, 1]) + ) + tm.assert_frame_equal(result, expected) + + result = gr.agg(b="min", a="sum") + expected = expected[["b", "a"]] + tm.assert_frame_equal(result, expected) + + def test_no_args_raises(self): + gr = Series([1, 2]).groupby([0, 1]) + with pytest.raises(TypeError, match="Must provide"): + gr.agg() + + # but we do allow this + result = gr.agg([]) + expected = DataFrame(columns=[]) + tm.assert_frame_equal(result, expected) + + def test_series_named_agg_duplicates_no_raises(self): + # GH28426 + gr = Series([1, 2, 3]).groupby([0, 0, 1]) + grouped = gr.agg(a="sum", b="sum") + expected = DataFrame({"a": [3, 3], "b": [3, 3]}, index=np.array([0, 1])) + tm.assert_frame_equal(expected, grouped) + + def test_mangled(self): + gr = Series([1, 2, 3]).groupby([0, 0, 1]) + result = gr.agg(a=lambda x: 0, b=lambda x: 1) + expected = DataFrame({"a": [0, 0], "b": [1, 1]}, index=np.array([0, 1])) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "inp", + [ + pd.NamedAgg(column="anything", aggfunc="min"), + ("anything", "min"), + ["anything", "min"], + ], + ) + def test_named_agg_nametuple(self, inp): + # GH34422 + s = Series([1, 1, 2, 2, 3, 3, 4, 5]) + msg = f"func is expected but received {type(inp).__name__}" + with pytest.raises(TypeError, match=msg): + s.groupby(s.values).agg(a=inp) + + +class TestNamedAggregationDataFrame: + def test_agg_relabel(self): + df = DataFrame( + {"group": ["a", "a", "b", "b"], "A": [0, 1, 2, 3], "B": [5, 6, 7, 8]} + ) + result = df.groupby("group").agg(a_max=("A", "max"), b_max=("B", "max")) + expected = DataFrame( + {"a_max": [1, 3], "b_max": [6, 8]}, + index=Index(["a", "b"], name="group"), + columns=["a_max", "b_max"], + ) + tm.assert_frame_equal(result, expected) + + # order invariance + p98 = functools.partial(np.percentile, q=98) + result = df.groupby("group").agg( + b_min=("B", "min"), + a_min=("A", "min"), + a_mean=("A", "mean"), + a_max=("A", "max"), + b_max=("B", "max"), + a_98=("A", p98), + ) + expected = DataFrame( + { + "b_min": [5, 7], + "a_min": [0, 2], + "a_mean": [0.5, 2.5], + "a_max": [1, 3], + "b_max": [6, 8], + "a_98": [0.98, 2.98], + }, + index=Index(["a", "b"], name="group"), + columns=["b_min", "a_min", "a_mean", "a_max", "b_max", "a_98"], + ) + tm.assert_frame_equal(result, expected) + + def test_agg_relabel_non_identifier(self): + df = DataFrame( + {"group": ["a", "a", "b", "b"], "A": [0, 1, 2, 3], "B": [5, 6, 7, 8]} + ) + + result = df.groupby("group").agg(**{"my col": ("A", "max")}) + expected = DataFrame({"my col": [1, 3]}, index=Index(["a", "b"], name="group")) + tm.assert_frame_equal(result, expected) + + def test_duplicate_no_raises(self): + # GH 28426, if use same input function on same column, + # no error should raise + df = DataFrame({"A": [0, 0, 1, 1], "B": [1, 2, 3, 4]}) + + grouped = df.groupby("A").agg(a=("B", "min"), b=("B", "min")) + expected = DataFrame({"a": [1, 3], "b": [1, 3]}, index=Index([0, 1], name="A")) + tm.assert_frame_equal(grouped, expected) + + quant50 = functools.partial(np.percentile, q=50) + quant70 = functools.partial(np.percentile, q=70) + quant50.__name__ = "quant50" + quant70.__name__ = "quant70" + + test = DataFrame({"col1": ["a", "a", "b", "b", "b"], "col2": [1, 2, 3, 4, 5]}) + + grouped = test.groupby("col1").agg( + quantile_50=("col2", quant50), quantile_70=("col2", quant70) + ) + expected = DataFrame( + {"quantile_50": [1.5, 4.0], "quantile_70": [1.7, 4.4]}, + index=Index(["a", "b"], name="col1"), + ) + tm.assert_frame_equal(grouped, expected) + + def test_agg_relabel_with_level(self): + df = DataFrame( + {"A": [0, 0, 1, 1], "B": [1, 2, 3, 4]}, + index=MultiIndex.from_product([["A", "B"], ["a", "b"]]), + ) + result = df.groupby(level=0).agg( + aa=("A", "max"), bb=("A", "min"), cc=("B", "mean") + ) + expected = DataFrame( + {"aa": [0, 1], "bb": [0, 1], "cc": [1.5, 3.5]}, index=["A", "B"] + ) + tm.assert_frame_equal(result, expected) + + def test_agg_relabel_other_raises(self): + df = DataFrame({"A": [0, 0, 1], "B": [1, 2, 3]}) + grouped = df.groupby("A") + match = "Must provide" + with pytest.raises(TypeError, match=match): + grouped.agg(foo=1) + + with pytest.raises(TypeError, match=match): + grouped.agg() + + with pytest.raises(TypeError, match=match): + grouped.agg(a=("B", "max"), b=(1, 2, 3)) + + def test_missing_raises(self): + df = DataFrame({"A": [0, 1], "B": [1, 2]}) + match = re.escape("Column(s) ['C'] do not exist") + with pytest.raises(KeyError, match=match): + df.groupby("A").agg(c=("C", "sum")) + + def test_agg_namedtuple(self): + df = DataFrame({"A": [0, 1], "B": [1, 2]}) + result = df.groupby("A").agg( + b=pd.NamedAgg("B", "sum"), c=pd.NamedAgg(column="B", aggfunc="count") + ) + expected = df.groupby("A").agg(b=("B", "sum"), c=("B", "count")) + tm.assert_frame_equal(result, expected) + + def test_mangled(self): + df = DataFrame({"A": [0, 1], "B": [1, 2], "C": [3, 4]}) + result = df.groupby("A").agg(b=("B", lambda x: 0), c=("C", lambda x: 1)) + expected = DataFrame({"b": [0, 0], "c": [1, 1]}, index=Index([0, 1], name="A")) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "agg_col1, agg_col2, agg_col3, agg_result1, agg_result2, agg_result3", + [ + ( + (("y", "A"), "max"), + (("y", "A"), np.mean), + (("y", "B"), "mean"), + [1, 3], + [0.5, 2.5], + [5.5, 7.5], + ), + ( + (("y", "A"), lambda x: max(x)), + (("y", "A"), lambda x: 1), + (("y", "B"), np.mean), + [1, 3], + [1, 1], + [5.5, 7.5], + ), + ( + pd.NamedAgg(("y", "A"), "max"), + pd.NamedAgg(("y", "B"), np.mean), + pd.NamedAgg(("y", "A"), lambda x: 1), + [1, 3], + [5.5, 7.5], + [1, 1], + ), + ], +) +def test_agg_relabel_multiindex_column( + agg_col1, agg_col2, agg_col3, agg_result1, agg_result2, agg_result3 +): + # GH 29422, add tests for multiindex column cases + df = DataFrame( + {"group": ["a", "a", "b", "b"], "A": [0, 1, 2, 3], "B": [5, 6, 7, 8]} + ) + df.columns = MultiIndex.from_tuples([("x", "group"), ("y", "A"), ("y", "B")]) + idx = Index(["a", "b"], name=("x", "group")) + + result = df.groupby(("x", "group")).agg(a_max=(("y", "A"), "max")) + expected = DataFrame({"a_max": [1, 3]}, index=idx) + tm.assert_frame_equal(result, expected) + + msg = "is currently using SeriesGroupBy.mean" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = df.groupby(("x", "group")).agg( + col_1=agg_col1, col_2=agg_col2, col_3=agg_col3 + ) + expected = DataFrame( + {"col_1": agg_result1, "col_2": agg_result2, "col_3": agg_result3}, index=idx + ) + tm.assert_frame_equal(result, expected) + + +def test_agg_relabel_multiindex_raises_not_exist(): + # GH 29422, add test for raises scenario when aggregate column does not exist + df = DataFrame( + {"group": ["a", "a", "b", "b"], "A": [0, 1, 2, 3], "B": [5, 6, 7, 8]} + ) + df.columns = MultiIndex.from_tuples([("x", "group"), ("y", "A"), ("y", "B")]) + + with pytest.raises(KeyError, match="do not exist"): + df.groupby(("x", "group")).agg(a=(("Y", "a"), "max")) + + +def test_agg_relabel_multiindex_duplicates(): + # GH29422, add test for raises scenario when getting duplicates + # GH28426, after this change, duplicates should also work if the relabelling is + # different + df = DataFrame( + {"group": ["a", "a", "b", "b"], "A": [0, 1, 2, 3], "B": [5, 6, 7, 8]} + ) + df.columns = MultiIndex.from_tuples([("x", "group"), ("y", "A"), ("y", "B")]) + + result = df.groupby(("x", "group")).agg( + a=(("y", "A"), "min"), b=(("y", "A"), "min") + ) + idx = Index(["a", "b"], name=("x", "group")) + expected = DataFrame({"a": [0, 2], "b": [0, 2]}, index=idx) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("kwargs", [{"c": ["min"]}, {"b": [], "c": ["min"]}]) +def test_groupby_aggregate_empty_key(kwargs): + # GH: 32580 + df = DataFrame({"a": [1, 1, 2], "b": [1, 2, 3], "c": [1, 2, 4]}) + result = df.groupby("a").agg(kwargs) + expected = DataFrame( + [1, 4], + index=Index([1, 2], dtype="int64", name="a"), + columns=MultiIndex.from_tuples([["c", "min"]]), + ) + tm.assert_frame_equal(result, expected) + + +def test_groupby_aggregate_empty_key_empty_return(): + # GH: 32580 Check if everything works, when return is empty + df = DataFrame({"a": [1, 1, 2], "b": [1, 2, 3], "c": [1, 2, 4]}) + result = df.groupby("a").agg({"b": []}) + expected = DataFrame(columns=MultiIndex(levels=[["b"], []], codes=[[], []])) + tm.assert_frame_equal(result, expected) + + +def test_groupby_aggregate_empty_with_multiindex_frame(): + # GH 39178 + df = DataFrame(columns=["a", "b", "c"]) + result = df.groupby(["a", "b"], group_keys=False).agg(d=("c", list)) + expected = DataFrame( + columns=["d"], index=MultiIndex([[], []], [[], []], names=["a", "b"]) + ) + tm.assert_frame_equal(result, expected) + + +def test_grouby_agg_loses_results_with_as_index_false_relabel(): + # GH 32240: When the aggregate function relabels column names and + # as_index=False is specified, the results are dropped. + + df = DataFrame( + {"key": ["x", "y", "z", "x", "y", "z"], "val": [1.0, 0.8, 2.0, 3.0, 3.6, 0.75]} + ) + + grouped = df.groupby("key", as_index=False) + result = grouped.agg(min_val=pd.NamedAgg(column="val", aggfunc="min")) + expected = DataFrame({"key": ["x", "y", "z"], "min_val": [1.0, 0.8, 0.75]}) + tm.assert_frame_equal(result, expected) + + +def test_grouby_agg_loses_results_with_as_index_false_relabel_multiindex(): + # GH 32240: When the aggregate function relabels column names and + # as_index=False is specified, the results are dropped. Check if + # multiindex is returned in the right order + + df = DataFrame( + { + "key": ["x", "y", "x", "y", "x", "x"], + "key1": ["a", "b", "c", "b", "a", "c"], + "val": [1.0, 0.8, 2.0, 3.0, 3.6, 0.75], + } + ) + + grouped = df.groupby(["key", "key1"], as_index=False) + result = grouped.agg(min_val=pd.NamedAgg(column="val", aggfunc="min")) + expected = DataFrame( + {"key": ["x", "x", "y"], "key1": ["a", "c", "b"], "min_val": [1.0, 0.75, 0.8]} + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "func", [lambda s: s.mean(), lambda s: np.mean(s), lambda s: np.nanmean(s)] +) +def test_multiindex_custom_func(func): + # GH 31777 + data = [[1, 4, 2], [5, 7, 1]] + df = DataFrame( + data, + columns=MultiIndex.from_arrays( + [[1, 1, 2], [3, 4, 3]], names=["Sisko", "Janeway"] + ), + ) + result = df.groupby(np.array([0, 1])).agg(func) + expected_dict = { + (1, 3): {0: 1.0, 1: 5.0}, + (1, 4): {0: 4.0, 1: 7.0}, + (2, 3): {0: 2.0, 1: 1.0}, + } + expected = DataFrame(expected_dict, index=np.array([0, 1]), columns=df.columns) + tm.assert_frame_equal(result, expected) + + +def myfunc(s): + return np.percentile(s, q=0.90) + + +@pytest.mark.parametrize("func", [lambda s: np.percentile(s, q=0.90), myfunc]) +def test_lambda_named_agg(func): + # see gh-28467 + animals = DataFrame( + { + "kind": ["cat", "dog", "cat", "dog"], + "height": [9.1, 6.0, 9.5, 34.0], + "weight": [7.9, 7.5, 9.9, 198.0], + } + ) + + result = animals.groupby("kind").agg( + mean_height=("height", "mean"), perc90=("height", func) + ) + expected = DataFrame( + [[9.3, 9.1036], [20.0, 6.252]], + columns=["mean_height", "perc90"], + index=Index(["cat", "dog"], name="kind"), + ) + + tm.assert_frame_equal(result, expected) + + +def test_aggregate_mixed_types(): + # GH 16916 + df = DataFrame( + data=np.array([0] * 9).reshape(3, 3), columns=list("XYZ"), index=list("abc") + ) + df["grouping"] = ["group 1", "group 1", 2] + result = df.groupby("grouping").aggregate(lambda x: x.tolist()) + expected_data = [[[0], [0], [0]], [[0, 0], [0, 0], [0, 0]]] + expected = DataFrame( + expected_data, + index=Index([2, "group 1"], dtype="object", name="grouping"), + columns=Index(["X", "Y", "Z"], dtype="object"), + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.xfail(reason="Not implemented;see GH 31256") +def test_aggregate_udf_na_extension_type(): + # https://github.com/pandas-dev/pandas/pull/31359 + # This is currently failing to cast back to Int64Dtype. + # The presence of the NA causes two problems + # 1. NA is not an instance of Int64Dtype.type (numpy.int64) + # 2. The presence of an NA forces object type, so the non-NA values is + # a Python int rather than a NumPy int64. Python ints aren't + # instances of numpy.int64. + def aggfunc(x): + if all(x > 2): + return 1 + else: + return pd.NA + + df = DataFrame({"A": pd.array([1, 2, 3])}) + result = df.groupby([1, 1, 2]).agg(aggfunc) + expected = DataFrame({"A": pd.array([1, pd.NA], dtype="Int64")}, index=[1, 2]) + tm.assert_frame_equal(result, expected) + + +class TestLambdaMangling: + def test_basic(self): + df = DataFrame({"A": [0, 0, 1, 1], "B": [1, 2, 3, 4]}) + result = df.groupby("A").agg({"B": [lambda x: 0, lambda x: 1]}) + + expected = DataFrame( + {("B", ""): [0, 0], ("B", ""): [1, 1]}, + index=Index([0, 1], name="A"), + ) + tm.assert_frame_equal(result, expected) + + def test_mangle_series_groupby(self): + gr = Series([1, 2, 3, 4]).groupby([0, 0, 1, 1]) + result = gr.agg([lambda x: 0, lambda x: 1]) + exp_data = {"": [0, 0], "": [1, 1]} + expected = DataFrame(exp_data, index=np.array([0, 1])) + tm.assert_frame_equal(result, expected) + + @pytest.mark.xfail(reason="GH-26611. kwargs for multi-agg.") + def test_with_kwargs(self): + f1 = lambda x, y, b=1: x.sum() + y + b + f2 = lambda x, y, b=2: x.sum() + y * b + result = Series([1, 2]).groupby([0, 0]).agg([f1, f2], 0) + expected = DataFrame({"": [4], "": [6]}) + tm.assert_frame_equal(result, expected) + + result = Series([1, 2]).groupby([0, 0]).agg([f1, f2], 0, b=10) + expected = DataFrame({"": [13], "": [30]}) + tm.assert_frame_equal(result, expected) + + def test_agg_with_one_lambda(self): + # GH 25719, write tests for DataFrameGroupby.agg with only one lambda + df = DataFrame( + { + "kind": ["cat", "dog", "cat", "dog"], + "height": [9.1, 6.0, 9.5, 34.0], + "weight": [7.9, 7.5, 9.9, 198.0], + } + ) + + columns = ["height_sqr_min", "height_max", "weight_max"] + expected = DataFrame( + { + "height_sqr_min": [82.81, 36.00], + "height_max": [9.5, 34.0], + "weight_max": [9.9, 198.0], + }, + index=Index(["cat", "dog"], name="kind"), + columns=columns, + ) + + # check pd.NameAgg case + result1 = df.groupby(by="kind").agg( + height_sqr_min=pd.NamedAgg( + column="height", aggfunc=lambda x: np.min(x**2) + ), + height_max=pd.NamedAgg(column="height", aggfunc="max"), + weight_max=pd.NamedAgg(column="weight", aggfunc="max"), + ) + tm.assert_frame_equal(result1, expected) + + # check agg(key=(col, aggfunc)) case + result2 = df.groupby(by="kind").agg( + height_sqr_min=("height", lambda x: np.min(x**2)), + height_max=("height", "max"), + weight_max=("weight", "max"), + ) + tm.assert_frame_equal(result2, expected) + + def test_agg_multiple_lambda(self): + # GH25719, test for DataFrameGroupby.agg with multiple lambdas + # with mixed aggfunc + df = DataFrame( + { + "kind": ["cat", "dog", "cat", "dog"], + "height": [9.1, 6.0, 9.5, 34.0], + "weight": [7.9, 7.5, 9.9, 198.0], + } + ) + columns = [ + "height_sqr_min", + "height_max", + "weight_max", + "height_max_2", + "weight_min", + ] + expected = DataFrame( + { + "height_sqr_min": [82.81, 36.00], + "height_max": [9.5, 34.0], + "weight_max": [9.9, 198.0], + "height_max_2": [9.5, 34.0], + "weight_min": [7.9, 7.5], + }, + index=Index(["cat", "dog"], name="kind"), + columns=columns, + ) + + # check agg(key=(col, aggfunc)) case + result1 = df.groupby(by="kind").agg( + height_sqr_min=("height", lambda x: np.min(x**2)), + height_max=("height", "max"), + weight_max=("weight", "max"), + height_max_2=("height", lambda x: np.max(x)), + weight_min=("weight", lambda x: np.min(x)), + ) + tm.assert_frame_equal(result1, expected) + + # check pd.NamedAgg case + result2 = df.groupby(by="kind").agg( + height_sqr_min=pd.NamedAgg( + column="height", aggfunc=lambda x: np.min(x**2) + ), + height_max=pd.NamedAgg(column="height", aggfunc="max"), + weight_max=pd.NamedAgg(column="weight", aggfunc="max"), + height_max_2=pd.NamedAgg(column="height", aggfunc=lambda x: np.max(x)), + weight_min=pd.NamedAgg(column="weight", aggfunc=lambda x: np.min(x)), + ) + tm.assert_frame_equal(result2, expected) + + +def test_groupby_get_by_index(): + # GH 33439 + df = DataFrame({"A": ["S", "W", "W"], "B": [1.0, 1.0, 2.0]}) + res = df.groupby("A").agg({"B": lambda x: x.get(x.index[-1])}) + expected = DataFrame({"A": ["S", "W"], "B": [1.0, 2.0]}).set_index("A") + tm.assert_frame_equal(res, expected) + + +@pytest.mark.parametrize( + "grp_col_dict, exp_data", + [ + ({"nr": "min", "cat_ord": "min"}, {"nr": [1, 5], "cat_ord": ["a", "c"]}), + ({"cat_ord": "min"}, {"cat_ord": ["a", "c"]}), + ({"nr": "min"}, {"nr": [1, 5]}), + ], +) +def test_groupby_single_agg_cat_cols(grp_col_dict, exp_data): + # test single aggregations on ordered categorical cols GHGH27800 + + # create the result dataframe + input_df = DataFrame( + { + "nr": [1, 2, 3, 4, 5, 6, 7, 8], + "cat_ord": list("aabbccdd"), + "cat": list("aaaabbbb"), + } + ) + + input_df = input_df.astype({"cat": "category", "cat_ord": "category"}) + input_df["cat_ord"] = input_df["cat_ord"].cat.as_ordered() + result_df = input_df.groupby("cat", observed=False).agg(grp_col_dict) + + # create expected dataframe + cat_index = pd.CategoricalIndex( + ["a", "b"], categories=["a", "b"], ordered=False, name="cat", dtype="category" + ) + + expected_df = DataFrame(data=exp_data, index=cat_index) + + if "cat_ord" in expected_df: + # ordered categorical columns should be preserved + dtype = input_df["cat_ord"].dtype + expected_df["cat_ord"] = expected_df["cat_ord"].astype(dtype) + + tm.assert_frame_equal(result_df, expected_df) + + +@pytest.mark.parametrize( + "grp_col_dict, exp_data", + [ + ({"nr": ["min", "max"], "cat_ord": "min"}, [(1, 4, "a"), (5, 8, "c")]), + ({"nr": "min", "cat_ord": ["min", "max"]}, [(1, "a", "b"), (5, "c", "d")]), + ({"cat_ord": ["min", "max"]}, [("a", "b"), ("c", "d")]), + ], +) +def test_groupby_combined_aggs_cat_cols(grp_col_dict, exp_data): + # test combined aggregations on ordered categorical cols GH27800 + + # create the result dataframe + input_df = DataFrame( + { + "nr": [1, 2, 3, 4, 5, 6, 7, 8], + "cat_ord": list("aabbccdd"), + "cat": list("aaaabbbb"), + } + ) + + input_df = input_df.astype({"cat": "category", "cat_ord": "category"}) + input_df["cat_ord"] = input_df["cat_ord"].cat.as_ordered() + result_df = input_df.groupby("cat", observed=False).agg(grp_col_dict) + + # create expected dataframe + cat_index = pd.CategoricalIndex( + ["a", "b"], categories=["a", "b"], ordered=False, name="cat", dtype="category" + ) + + # unpack the grp_col_dict to create the multi-index tuple + # this tuple will be used to create the expected dataframe index + multi_index_list = [] + for k, v in grp_col_dict.items(): + if isinstance(v, list): + multi_index_list.extend([k, value] for value in v) + else: + multi_index_list.append([k, v]) + multi_index = MultiIndex.from_tuples(tuple(multi_index_list)) + + expected_df = DataFrame(data=exp_data, columns=multi_index, index=cat_index) + for col in expected_df.columns: + if isinstance(col, tuple) and "cat_ord" in col: + # ordered categorical should be preserved + expected_df[col] = expected_df[col].astype(input_df["cat_ord"].dtype) + + tm.assert_frame_equal(result_df, expected_df) + + +def test_nonagg_agg(): + # GH 35490 - Single/Multiple agg of non-agg function give same results + # TODO: agg should raise for functions that don't aggregate + df = DataFrame({"a": [1, 1, 2, 2], "b": [1, 2, 2, 1]}) + g = df.groupby("a") + + result = g.agg(["cumsum"]) + result.columns = result.columns.droplevel(-1) + expected = g.agg("cumsum") + + tm.assert_frame_equal(result, expected) + + +def test_aggregate_datetime_objects(): + # https://github.com/pandas-dev/pandas/issues/36003 + # ensure we don't raise an error but keep object dtype for out-of-bounds + # datetimes + df = DataFrame( + { + "A": ["X", "Y"], + "B": [ + datetime.datetime(2005, 1, 1, 10, 30, 23, 540000), + datetime.datetime(3005, 1, 1, 10, 30, 23, 540000), + ], + } + ) + result = df.groupby("A").B.max() + expected = df.set_index("A")["B"] + tm.assert_series_equal(result, expected) + + +def test_groupby_index_object_dtype(): + # GH 40014 + df = DataFrame({"c0": ["x", "x", "x"], "c1": ["x", "x", "y"], "p": [0, 1, 2]}) + df.index = df.index.astype("O") + grouped = df.groupby(["c0", "c1"]) + res = grouped.p.agg(lambda x: all(x > 0)) + # Check that providing a user-defined function in agg() + # produces the correct index shape when using an object-typed index. + expected_index = MultiIndex.from_tuples( + [("x", "x"), ("x", "y")], names=("c0", "c1") + ) + expected = Series([False, True], index=expected_index, name="p") + tm.assert_series_equal(res, expected) + + +def test_timeseries_groupby_agg(): + # GH#43290 + + def func(ser): + if ser.isna().all(): + return None + return np.sum(ser) + + df = DataFrame([1.0], index=[pd.Timestamp("2018-01-16 00:00:00+00:00")]) + res = df.groupby(lambda x: 1).agg(func) + + expected = DataFrame([[1.0]], index=[1]) + tm.assert_frame_equal(res, expected) + + +def test_groupby_agg_precision(any_real_numeric_dtype): + if any_real_numeric_dtype in tm.ALL_INT_NUMPY_DTYPES: + max_value = np.iinfo(any_real_numeric_dtype).max + if any_real_numeric_dtype in tm.FLOAT_NUMPY_DTYPES: + max_value = np.finfo(any_real_numeric_dtype).max + if any_real_numeric_dtype in tm.FLOAT_EA_DTYPES: + max_value = np.finfo(any_real_numeric_dtype.lower()).max + if any_real_numeric_dtype in tm.ALL_INT_EA_DTYPES: + max_value = np.iinfo(any_real_numeric_dtype.lower()).max + + df = DataFrame( + { + "key1": ["a"], + "key2": ["b"], + "key3": pd.array([max_value], dtype=any_real_numeric_dtype), + } + ) + arrays = [["a"], ["b"]] + index = MultiIndex.from_arrays(arrays, names=("key1", "key2")) + + expected = DataFrame( + {"key3": pd.array([max_value], dtype=any_real_numeric_dtype)}, index=index + ) + result = df.groupby(["key1", "key2"]).agg(lambda x: x) + tm.assert_frame_equal(result, expected) + + +def test_groupby_aggregate_directory(reduction_func): + # GH#32793 + if reduction_func in ["corrwith", "nth"]: + return None + + obj = DataFrame([[0, 1], [0, np.nan]]) + + result_reduced_series = obj.groupby(0).agg(reduction_func) + result_reduced_frame = obj.groupby(0).agg({1: reduction_func}) + + if reduction_func in ["size", "ngroup"]: + # names are different: None / 1 + tm.assert_series_equal( + result_reduced_series, result_reduced_frame[1], check_names=False + ) + else: + tm.assert_frame_equal(result_reduced_series, result_reduced_frame) + tm.assert_series_equal( + result_reduced_series.dtypes, result_reduced_frame.dtypes + ) + + +def test_group_mean_timedelta_nat(): + # GH43132 + data = Series(["1 day", "3 days", "NaT"], dtype="timedelta64[ns]") + expected = Series(["2 days"], dtype="timedelta64[ns]", index=np.array([0])) + + result = data.groupby([0, 0, 0]).mean() + + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "input_data, expected_output", + [ + ( # no timezone + ["2021-01-01T00:00", "NaT", "2021-01-01T02:00"], + ["2021-01-01T01:00"], + ), + ( # timezone + ["2021-01-01T00:00-0100", "NaT", "2021-01-01T02:00-0100"], + ["2021-01-01T01:00-0100"], + ), + ], +) +def test_group_mean_datetime64_nat(input_data, expected_output): + # GH43132 + data = to_datetime(Series(input_data)) + expected = to_datetime(Series(expected_output, index=np.array([0]))) + + result = data.groupby([0, 0, 0]).mean() + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "func, output", [("mean", [8 + 18j, 10 + 22j]), ("sum", [40 + 90j, 50 + 110j])] +) +def test_groupby_complex(func, output): + # GH#43701 + data = Series(np.arange(20).reshape(10, 2).dot([1, 2j])) + result = data.groupby(data.index % 2).agg(func) + expected = Series(output) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("func", ["min", "max", "var"]) +def test_groupby_complex_raises(func): + # GH#43701 + data = Series(np.arange(20).reshape(10, 2).dot([1, 2j])) + msg = "No matching signature found" + with pytest.raises(TypeError, match=msg): + data.groupby(data.index % 2).agg(func) + + +@pytest.mark.parametrize( + "func", [["min"], ["mean", "max"], {"b": "sum"}, {"b": "prod", "c": "median"}] +) +def test_multi_axis_1_raises(func): + # GH#46995 + df = DataFrame({"a": [1, 1, 2], "b": [3, 4, 5], "c": [6, 7, 8]}) + msg = "DataFrame.groupby with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + gb = df.groupby("a", axis=1) + with pytest.raises(NotImplementedError, match="axis other than 0 is not supported"): + gb.agg(func) + + +@pytest.mark.parametrize( + "test, constant", + [ + ([[20, "A"], [20, "B"], [10, "C"]], {0: [10, 20], 1: ["C", ["A", "B"]]}), + ([[20, "A"], [20, "B"], [30, "C"]], {0: [20, 30], 1: [["A", "B"], "C"]}), + ([["a", 1], ["a", 1], ["b", 2], ["b", 3]], {0: ["a", "b"], 1: [1, [2, 3]]}), + pytest.param( + [["a", 1], ["a", 2], ["b", 3], ["b", 3]], + {0: ["a", "b"], 1: [[1, 2], 3]}, + marks=pytest.mark.xfail, + ), + ], +) +def test_agg_of_mode_list(test, constant): + # GH#25581 + df1 = DataFrame(test) + result = df1.groupby(0).agg(Series.mode) + # Mode usually only returns 1 value, but can return a list in the case of a tie. + + expected = DataFrame(constant) + expected = expected.set_index(0) + + tm.assert_frame_equal(result, expected) + + +def test_dataframe_groupy_agg_list_like_func_with_args(): + # GH#50624 + df = DataFrame({"x": [1, 2, 3], "y": ["a", "b", "c"]}) + gb = df.groupby("y") + + def foo1(x, a=1, c=0): + return x.sum() + a + c + + def foo2(x, b=2, c=0): + return x.sum() + b + c + + msg = r"foo1\(\) got an unexpected keyword argument 'b'" + with pytest.raises(TypeError, match=msg): + gb.agg([foo1, foo2], 3, b=3, c=4) + + result = gb.agg([foo1, foo2], 3, c=4) + expected = DataFrame( + [[8, 8], [9, 9], [10, 10]], + index=Index(["a", "b", "c"], name="y"), + columns=MultiIndex.from_tuples([("x", "foo1"), ("x", "foo2")]), + ) + tm.assert_frame_equal(result, expected) + + +def test_series_groupy_agg_list_like_func_with_args(): + # GH#50624 + s = Series([1, 2, 3]) + sgb = s.groupby(s) + + def foo1(x, a=1, c=0): + return x.sum() + a + c + + def foo2(x, b=2, c=0): + return x.sum() + b + c + + msg = r"foo1\(\) got an unexpected keyword argument 'b'" + with pytest.raises(TypeError, match=msg): + sgb.agg([foo1, foo2], 3, b=3, c=4) + + result = sgb.agg([foo1, foo2], 3, c=4) + expected = DataFrame( + [[8, 8], [9, 9], [10, 10]], index=Index([1, 2, 3]), columns=["foo1", "foo2"] + ) + tm.assert_frame_equal(result, expected) + + +def test_agg_groupings_selection(): + # GH#51186 - a selected grouping should be in the output of agg + df = DataFrame({"a": [1, 1, 2], "b": [3, 3, 4], "c": [5, 6, 7]}) + gb = df.groupby(["a", "b"]) + selected_gb = gb[["b", "c"]] + result = selected_gb.agg(lambda x: x.sum()) + index = MultiIndex( + levels=[[1, 2], [3, 4]], codes=[[0, 1], [0, 1]], names=["a", "b"] + ) + expected = DataFrame({"b": [6, 4], "c": [11, 7]}, index=index) + tm.assert_frame_equal(result, expected) + + +def test_agg_multiple_with_as_index_false_subset_to_a_single_column(): + # GH#50724 + df = DataFrame({"a": [1, 1, 2], "b": [3, 4, 5]}) + gb = df.groupby("a", as_index=False)["b"] + result = gb.agg(["sum", "mean"]) + expected = DataFrame({"a": [1, 2], "sum": [7, 5], "mean": [3.5, 5.0]}) + tm.assert_frame_equal(result, expected) + + +def test_agg_with_as_index_false_with_list(): + # GH#52849 + df = DataFrame({"a1": [0, 0, 1], "a2": [2, 3, 3], "b": [4, 5, 6]}) + gb = df.groupby(by=["a1", "a2"], as_index=False) + result = gb.agg(["sum"]) + + expected = DataFrame( + data=[[0, 2, 4], [0, 3, 5], [1, 3, 6]], + columns=MultiIndex.from_tuples([("a1", ""), ("a2", ""), ("b", "sum")]), + ) + tm.assert_frame_equal(result, expected) + + +def test_groupby_agg_extension_timedelta_cumsum_with_named_aggregation(): + # GH#41720 + expected = DataFrame( + { + "td": { + 0: pd.Timedelta("0 days 01:00:00"), + 1: pd.Timedelta("0 days 01:15:00"), + 2: pd.Timedelta("0 days 01:15:00"), + } + } + ) + df = DataFrame( + { + "td": Series( + ["0 days 01:00:00", "0 days 00:15:00", "0 days 01:15:00"], + dtype="timedelta64[ns]", + ), + "grps": ["a", "a", "b"], + } + ) + gb = df.groupby("grps") + result = gb.agg(td=("td", "cumsum")) + tm.assert_frame_equal(result, expected) + + +def test_groupby_aggregation_empty_group(): + # https://github.com/pandas-dev/pandas/issues/18869 + def func(x): + if len(x) == 0: + raise ValueError("length must not be 0") + return len(x) + + df = DataFrame( + {"A": pd.Categorical(["a", "a"], categories=["a", "b", "c"]), "B": [1, 1]} + ) + msg = "length must not be 0" + with pytest.raises(ValueError, match=msg): + df.groupby("A", observed=False).agg(func) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/aggregate/test_cython.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/aggregate/test_cython.py new file mode 100644 index 0000000000000000000000000000000000000000..5c99882cef6d2393278df5879ea4af75aa14f60c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/aggregate/test_cython.py @@ -0,0 +1,435 @@ +""" +test cython .agg behavior +""" + +import numpy as np +import pytest + +from pandas.core.dtypes.common import ( + is_float_dtype, + is_integer_dtype, +) + +import pandas as pd +from pandas import ( + DataFrame, + Index, + NaT, + Series, + Timedelta, + Timestamp, + bdate_range, +) +import pandas._testing as tm +import pandas.core.common as com + + +@pytest.mark.parametrize( + "op_name", + [ + "count", + "sum", + "std", + "var", + "sem", + "mean", + pytest.param( + "median", + # ignore mean of empty slice + # and all-NaN + marks=[pytest.mark.filterwarnings("ignore::RuntimeWarning")], + ), + "prod", + "min", + "max", + ], +) +def test_cythonized_aggers(op_name): + data = { + "A": [0, 0, 0, 0, 1, 1, 1, 1, 1, 1.0, np.nan, np.nan], + "B": ["A", "B"] * 6, + "C": np.random.default_rng(2).standard_normal(12), + } + df = DataFrame(data) + df.loc[2:10:2, "C"] = np.nan + + op = lambda x: getattr(x, op_name)() + + # single column + grouped = df.drop(["B"], axis=1).groupby("A") + exp = {cat: op(group["C"]) for cat, group in grouped} + exp = DataFrame({"C": exp}) + exp.index.name = "A" + result = op(grouped) + tm.assert_frame_equal(result, exp) + + # multiple columns + grouped = df.groupby(["A", "B"]) + expd = {} + for (cat1, cat2), group in grouped: + expd.setdefault(cat1, {})[cat2] = op(group["C"]) + exp = DataFrame(expd).T.stack(future_stack=True) + exp.index.names = ["A", "B"] + exp.name = "C" + + result = op(grouped)["C"] + if op_name in ["sum", "prod"]: + tm.assert_series_equal(result, exp) + + +def test_cython_agg_boolean(): + frame = DataFrame( + { + "a": np.random.default_rng(2).integers(0, 5, 50), + "b": np.random.default_rng(2).integers(0, 2, 50).astype("bool"), + } + ) + result = frame.groupby("a")["b"].mean() + msg = "using SeriesGroupBy.mean" + with tm.assert_produces_warning(FutureWarning, match=msg): + # GH#53425 + expected = frame.groupby("a")["b"].agg(np.mean) + + tm.assert_series_equal(result, expected) + + +def test_cython_agg_nothing_to_agg(): + frame = DataFrame( + {"a": np.random.default_rng(2).integers(0, 5, 50), "b": ["foo", "bar"] * 25} + ) + + msg = "Cannot use numeric_only=True with SeriesGroupBy.mean and non-numeric dtypes" + with pytest.raises(TypeError, match=msg): + frame.groupby("a")["b"].mean(numeric_only=True) + + frame = DataFrame( + {"a": np.random.default_rng(2).integers(0, 5, 50), "b": ["foo", "bar"] * 25} + ) + + result = frame[["b"]].groupby(frame["a"]).mean(numeric_only=True) + expected = DataFrame( + [], index=frame["a"].sort_values().drop_duplicates(), columns=[] + ) + tm.assert_frame_equal(result, expected) + + +def test_cython_agg_nothing_to_agg_with_dates(): + frame = DataFrame( + { + "a": np.random.default_rng(2).integers(0, 5, 50), + "b": ["foo", "bar"] * 25, + "dates": pd.date_range("now", periods=50, freq="min"), + } + ) + msg = "Cannot use numeric_only=True with SeriesGroupBy.mean and non-numeric dtypes" + with pytest.raises(TypeError, match=msg): + frame.groupby("b").dates.mean(numeric_only=True) + + +def test_cython_agg_frame_columns(): + # #2113 + df = DataFrame({"x": [1, 2, 3], "y": [3, 4, 5]}) + + msg = "DataFrame.groupby with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + df.groupby(level=0, axis="columns").mean() + with tm.assert_produces_warning(FutureWarning, match=msg): + df.groupby(level=0, axis="columns").mean() + with tm.assert_produces_warning(FutureWarning, match=msg): + df.groupby(level=0, axis="columns").mean() + with tm.assert_produces_warning(FutureWarning, match=msg): + df.groupby(level=0, axis="columns").mean() + + +def test_cython_agg_return_dict(): + # GH 16741 + df = DataFrame( + { + "A": ["foo", "bar", "foo", "bar", "foo", "bar", "foo", "foo"], + "B": ["one", "one", "two", "three", "two", "two", "one", "three"], + "C": np.random.default_rng(2).standard_normal(8), + "D": np.random.default_rng(2).standard_normal(8), + } + ) + + ts = df.groupby("A")["B"].agg(lambda x: x.value_counts().to_dict()) + expected = Series( + [{"two": 1, "one": 1, "three": 1}, {"two": 2, "one": 2, "three": 1}], + index=Index(["bar", "foo"], name="A"), + name="B", + ) + tm.assert_series_equal(ts, expected) + + +def test_cython_fail_agg(): + dr = bdate_range("1/1/2000", periods=50) + ts = Series(["A", "B", "C", "D", "E"] * 10, index=dr) + + grouped = ts.groupby(lambda x: x.month) + summed = grouped.sum() + msg = "using SeriesGroupBy.sum" + with tm.assert_produces_warning(FutureWarning, match=msg): + # GH#53425 + expected = grouped.agg(np.sum) + tm.assert_series_equal(summed, expected) + + +@pytest.mark.parametrize( + "op, targop", + [ + ("mean", np.mean), + ("median", np.median), + ("var", np.var), + ("sum", np.sum), + ("prod", np.prod), + ("min", np.min), + ("max", np.max), + ("first", lambda x: x.iloc[0]), + ("last", lambda x: x.iloc[-1]), + ], +) +def test__cython_agg_general(op, targop): + df = DataFrame(np.random.default_rng(2).standard_normal(1000)) + labels = np.random.default_rng(2).integers(0, 50, size=1000).astype(float) + + result = df.groupby(labels)._cython_agg_general(op, alt=None, numeric_only=True) + warn = FutureWarning if targop in com._cython_table else None + msg = f"using DataFrameGroupBy.{op}" + with tm.assert_produces_warning(warn, match=msg): + # GH#53425 + expected = df.groupby(labels).agg(targop) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "op, targop", + [ + ("mean", np.mean), + ("median", lambda x: np.median(x) if len(x) > 0 else np.nan), + ("var", lambda x: np.var(x, ddof=1)), + ("min", np.min), + ("max", np.max), + ], +) +def test_cython_agg_empty_buckets(op, targop, observed): + df = DataFrame([11, 12, 13]) + grps = range(0, 55, 5) + + # calling _cython_agg_general directly, instead of via the user API + # which sets different values for min_count, so do that here. + g = df.groupby(pd.cut(df[0], grps), observed=observed) + result = g._cython_agg_general(op, alt=None, numeric_only=True) + + g = df.groupby(pd.cut(df[0], grps), observed=observed) + expected = g.agg(lambda x: targop(x)) + tm.assert_frame_equal(result, expected) + + +def test_cython_agg_empty_buckets_nanops(observed): + # GH-18869 can't call nanops on empty groups, so hardcode expected + # for these + df = DataFrame([11, 12, 13], columns=["a"]) + grps = np.arange(0, 25, 5, dtype=int) + # add / sum + result = df.groupby(pd.cut(df["a"], grps), observed=observed)._cython_agg_general( + "sum", alt=None, numeric_only=True + ) + intervals = pd.interval_range(0, 20, freq=5) + expected = DataFrame( + {"a": [0, 0, 36, 0]}, + index=pd.CategoricalIndex(intervals, name="a", ordered=True), + ) + if observed: + expected = expected[expected.a != 0] + + tm.assert_frame_equal(result, expected) + + # prod + result = df.groupby(pd.cut(df["a"], grps), observed=observed)._cython_agg_general( + "prod", alt=None, numeric_only=True + ) + expected = DataFrame( + {"a": [1, 1, 1716, 1]}, + index=pd.CategoricalIndex(intervals, name="a", ordered=True), + ) + if observed: + expected = expected[expected.a != 1] + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("op", ["first", "last", "max", "min"]) +@pytest.mark.parametrize( + "data", [Timestamp("2016-10-14 21:00:44.557"), Timedelta("17088 days 21:00:44.557")] +) +def test_cython_with_timestamp_and_nat(op, data): + # https://github.com/pandas-dev/pandas/issues/19526 + df = DataFrame({"a": [0, 1], "b": [data, NaT]}) + index = Index([0, 1], name="a") + + # We will group by a and test the cython aggregations + expected = DataFrame({"b": [data, NaT]}, index=index) + + result = df.groupby("a").aggregate(op) + tm.assert_frame_equal(expected, result) + + +@pytest.mark.parametrize( + "agg", + [ + "min", + "max", + "count", + "sum", + "prod", + "var", + "mean", + "median", + "ohlc", + "cumprod", + "cumsum", + "shift", + "any", + "all", + "quantile", + "first", + "last", + "rank", + "cummin", + "cummax", + ], +) +def test_read_only_buffer_source_agg(agg): + # https://github.com/pandas-dev/pandas/issues/36014 + df = DataFrame( + { + "sepal_length": [5.1, 4.9, 4.7, 4.6, 5.0], + "species": ["setosa", "setosa", "setosa", "setosa", "setosa"], + } + ) + df._mgr.arrays[0].flags.writeable = False + + result = df.groupby(["species"]).agg({"sepal_length": agg}) + expected = df.copy().groupby(["species"]).agg({"sepal_length": agg}) + + tm.assert_equal(result, expected) + + +@pytest.mark.parametrize( + "op_name", + [ + "count", + "sum", + "std", + "var", + "sem", + "mean", + "median", + "prod", + "min", + "max", + ], +) +def test_cython_agg_nullable_int(op_name): + # ensure that the cython-based aggregations don't fail for nullable dtype + # (eg https://github.com/pandas-dev/pandas/issues/37415) + df = DataFrame( + { + "A": ["A", "B"] * 5, + "B": pd.array([1, 2, 3, 4, 5, 6, 7, 8, 9, pd.NA], dtype="Int64"), + } + ) + result = getattr(df.groupby("A")["B"], op_name)() + df2 = df.assign(B=df["B"].astype("float64")) + expected = getattr(df2.groupby("A")["B"], op_name)() + if op_name in ("mean", "median"): + convert_integer = False + else: + convert_integer = True + expected = expected.convert_dtypes(convert_integer=convert_integer) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("dtype", ["Int64", "Float64", "boolean"]) +def test_count_masked_returns_masked_dtype(dtype): + df = DataFrame( + { + "A": [1, 1], + "B": pd.array([1, pd.NA], dtype=dtype), + "C": pd.array([1, 1], dtype=dtype), + } + ) + result = df.groupby("A").count() + expected = DataFrame( + [[1, 2]], index=Index([1], name="A"), columns=["B", "C"], dtype="Int64" + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("with_na", [True, False]) +@pytest.mark.parametrize( + "op_name, action", + [ + # ("count", "always_int"), + ("sum", "large_int"), + # ("std", "always_float"), + ("var", "always_float"), + # ("sem", "always_float"), + ("mean", "always_float"), + ("median", "always_float"), + ("prod", "large_int"), + ("min", "preserve"), + ("max", "preserve"), + ("first", "preserve"), + ("last", "preserve"), + ], +) +@pytest.mark.parametrize( + "data", + [ + pd.array([1, 2, 3, 4], dtype="Int64"), + pd.array([1, 2, 3, 4], dtype="Int8"), + pd.array([0.1, 0.2, 0.3, 0.4], dtype="Float32"), + pd.array([0.1, 0.2, 0.3, 0.4], dtype="Float64"), + pd.array([True, True, False, False], dtype="boolean"), + ], +) +def test_cython_agg_EA_known_dtypes(data, op_name, action, with_na): + if with_na: + data[3] = pd.NA + + df = DataFrame({"key": ["a", "a", "b", "b"], "col": data}) + grouped = df.groupby("key") + + if action == "always_int": + # always Int64 + expected_dtype = pd.Int64Dtype() + elif action == "large_int": + # for any int/bool use Int64, for float preserve dtype + if is_float_dtype(data.dtype): + expected_dtype = data.dtype + elif is_integer_dtype(data.dtype): + # match the numpy dtype we'd get with the non-nullable analogue + expected_dtype = data.dtype + else: + expected_dtype = pd.Int64Dtype() + elif action == "always_float": + # for any int/bool use Float64, for float preserve dtype + if is_float_dtype(data.dtype): + expected_dtype = data.dtype + else: + expected_dtype = pd.Float64Dtype() + elif action == "preserve": + expected_dtype = data.dtype + + result = getattr(grouped, op_name)() + assert result["col"].dtype == expected_dtype + + result = grouped.aggregate(op_name) + assert result["col"].dtype == expected_dtype + + result = getattr(grouped["col"], op_name)() + assert result.dtype == expected_dtype + + result = grouped["col"].aggregate(op_name) + assert result.dtype == expected_dtype diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/aggregate/test_numba.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/aggregate/test_numba.py new file mode 100644 index 0000000000000000000000000000000000000000..ee694129f71183294dc780783d3b9ccdeae73bf4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/aggregate/test_numba.py @@ -0,0 +1,392 @@ +import numpy as np +import pytest + +from pandas.errors import NumbaUtilError + +from pandas import ( + DataFrame, + Index, + NamedAgg, + Series, + option_context, +) +import pandas._testing as tm + +pytestmark = pytest.mark.single_cpu + + +def test_correct_function_signature(): + pytest.importorskip("numba") + + def incorrect_function(x): + return sum(x) * 2.7 + + data = DataFrame( + {"key": ["a", "a", "b", "b", "a"], "data": [1.0, 2.0, 3.0, 4.0, 5.0]}, + columns=["key", "data"], + ) + with pytest.raises(NumbaUtilError, match="The first 2"): + data.groupby("key").agg(incorrect_function, engine="numba") + + with pytest.raises(NumbaUtilError, match="The first 2"): + data.groupby("key")["data"].agg(incorrect_function, engine="numba") + + +def test_check_nopython_kwargs(): + pytest.importorskip("numba") + + def incorrect_function(values, index): + return sum(values) * 2.7 + + data = DataFrame( + {"key": ["a", "a", "b", "b", "a"], "data": [1.0, 2.0, 3.0, 4.0, 5.0]}, + columns=["key", "data"], + ) + with pytest.raises(NumbaUtilError, match="numba does not support"): + data.groupby("key").agg(incorrect_function, engine="numba", a=1) + + with pytest.raises(NumbaUtilError, match="numba does not support"): + data.groupby("key")["data"].agg(incorrect_function, engine="numba", a=1) + + +@pytest.mark.filterwarnings("ignore") +# Filter warnings when parallel=True and the function can't be parallelized by Numba +@pytest.mark.parametrize("jit", [True, False]) +@pytest.mark.parametrize("pandas_obj", ["Series", "DataFrame"]) +@pytest.mark.parametrize("as_index", [True, False]) +def test_numba_vs_cython(jit, pandas_obj, nogil, parallel, nopython, as_index): + pytest.importorskip("numba") + + def func_numba(values, index): + return np.mean(values) * 2.7 + + if jit: + # Test accepted jitted functions + import numba + + func_numba = numba.jit(func_numba) + + data = DataFrame( + {0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1] + ) + engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython} + grouped = data.groupby(0, as_index=as_index) + if pandas_obj == "Series": + grouped = grouped[1] + + result = grouped.agg(func_numba, engine="numba", engine_kwargs=engine_kwargs) + expected = grouped.agg(lambda x: np.mean(x) * 2.7, engine="cython") + + tm.assert_equal(result, expected) + + +@pytest.mark.filterwarnings("ignore") +# Filter warnings when parallel=True and the function can't be parallelized by Numba +@pytest.mark.parametrize("jit", [True, False]) +@pytest.mark.parametrize("pandas_obj", ["Series", "DataFrame"]) +def test_cache(jit, pandas_obj, nogil, parallel, nopython): + # Test that the functions are cached correctly if we switch functions + pytest.importorskip("numba") + + def func_1(values, index): + return np.mean(values) - 3.4 + + def func_2(values, index): + return np.mean(values) * 2.7 + + if jit: + import numba + + func_1 = numba.jit(func_1) + func_2 = numba.jit(func_2) + + data = DataFrame( + {0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1] + ) + engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython} + grouped = data.groupby(0) + if pandas_obj == "Series": + grouped = grouped[1] + + result = grouped.agg(func_1, engine="numba", engine_kwargs=engine_kwargs) + expected = grouped.agg(lambda x: np.mean(x) - 3.4, engine="cython") + tm.assert_equal(result, expected) + + # Add func_2 to the cache + result = grouped.agg(func_2, engine="numba", engine_kwargs=engine_kwargs) + expected = grouped.agg(lambda x: np.mean(x) * 2.7, engine="cython") + tm.assert_equal(result, expected) + + # Retest func_1 which should use the cache + result = grouped.agg(func_1, engine="numba", engine_kwargs=engine_kwargs) + expected = grouped.agg(lambda x: np.mean(x) - 3.4, engine="cython") + tm.assert_equal(result, expected) + + +def test_use_global_config(): + pytest.importorskip("numba") + + def func_1(values, index): + return np.mean(values) - 3.4 + + data = DataFrame( + {0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1] + ) + grouped = data.groupby(0) + expected = grouped.agg(func_1, engine="numba") + with option_context("compute.use_numba", True): + result = grouped.agg(func_1, engine=None) + tm.assert_frame_equal(expected, result) + + +@pytest.mark.parametrize( + "agg_kwargs", + [ + {"func": ["min", "max"]}, + {"func": "min"}, + {"func": {1: ["min", "max"], 2: "sum"}}, + {"bmin": NamedAgg(column=1, aggfunc="min")}, + ], +) +def test_multifunc_numba_vs_cython_frame(agg_kwargs): + pytest.importorskip("numba") + data = DataFrame( + { + 0: ["a", "a", "b", "b", "a"], + 1: [1.0, 2.0, 3.0, 4.0, 5.0], + 2: [1, 2, 3, 4, 5], + }, + columns=[0, 1, 2], + ) + grouped = data.groupby(0) + result = grouped.agg(**agg_kwargs, engine="numba") + expected = grouped.agg(**agg_kwargs, engine="cython") + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "agg_kwargs,expected_func", + [ + ({"func": lambda values, index: values.sum()}, "sum"), + # FIXME + pytest.param( + { + "func": [ + lambda values, index: values.sum(), + lambda values, index: values.min(), + ] + }, + ["sum", "min"], + marks=pytest.mark.xfail( + reason="This doesn't work yet! Fails in nopython pipeline!" + ), + ), + ], +) +def test_multifunc_numba_udf_frame(agg_kwargs, expected_func): + pytest.importorskip("numba") + data = DataFrame( + { + 0: ["a", "a", "b", "b", "a"], + 1: [1.0, 2.0, 3.0, 4.0, 5.0], + 2: [1, 2, 3, 4, 5], + }, + columns=[0, 1, 2], + ) + grouped = data.groupby(0) + result = grouped.agg(**agg_kwargs, engine="numba") + expected = grouped.agg(expected_func, engine="cython") + # check_dtype can be removed if GH 44952 is addressed + # Currently, UDFs still always return float64 while reductions can preserve dtype + tm.assert_frame_equal(result, expected, check_dtype=False) + + +@pytest.mark.parametrize( + "agg_kwargs", + [{"func": ["min", "max"]}, {"func": "min"}, {"min_val": "min", "max_val": "max"}], +) +def test_multifunc_numba_vs_cython_series(agg_kwargs): + pytest.importorskip("numba") + labels = ["a", "a", "b", "b", "a"] + data = Series([1.0, 2.0, 3.0, 4.0, 5.0]) + grouped = data.groupby(labels) + agg_kwargs["engine"] = "numba" + result = grouped.agg(**agg_kwargs) + agg_kwargs["engine"] = "cython" + expected = grouped.agg(**agg_kwargs) + if isinstance(expected, DataFrame): + tm.assert_frame_equal(result, expected) + else: + tm.assert_series_equal(result, expected) + + +@pytest.mark.single_cpu +@pytest.mark.parametrize( + "data,agg_kwargs", + [ + (Series([1.0, 2.0, 3.0, 4.0, 5.0]), {"func": ["min", "max"]}), + (Series([1.0, 2.0, 3.0, 4.0, 5.0]), {"func": "min"}), + ( + DataFrame( + {1: [1.0, 2.0, 3.0, 4.0, 5.0], 2: [1, 2, 3, 4, 5]}, columns=[1, 2] + ), + {"func": ["min", "max"]}, + ), + ( + DataFrame( + {1: [1.0, 2.0, 3.0, 4.0, 5.0], 2: [1, 2, 3, 4, 5]}, columns=[1, 2] + ), + {"func": "min"}, + ), + ( + DataFrame( + {1: [1.0, 2.0, 3.0, 4.0, 5.0], 2: [1, 2, 3, 4, 5]}, columns=[1, 2] + ), + {"func": {1: ["min", "max"], 2: "sum"}}, + ), + ( + DataFrame( + {1: [1.0, 2.0, 3.0, 4.0, 5.0], 2: [1, 2, 3, 4, 5]}, columns=[1, 2] + ), + {"min_col": NamedAgg(column=1, aggfunc="min")}, + ), + ], +) +def test_multifunc_numba_kwarg_propagation(data, agg_kwargs): + pytest.importorskip("numba") + labels = ["a", "a", "b", "b", "a"] + grouped = data.groupby(labels) + result = grouped.agg(**agg_kwargs, engine="numba", engine_kwargs={"parallel": True}) + expected = grouped.agg(**agg_kwargs, engine="numba") + if isinstance(expected, DataFrame): + tm.assert_frame_equal(result, expected) + else: + tm.assert_series_equal(result, expected) + + +def test_args_not_cached(): + # GH 41647 + pytest.importorskip("numba") + + def sum_last(values, index, n): + return values[-n:].sum() + + df = DataFrame({"id": [0, 0, 1, 1], "x": [1, 1, 1, 1]}) + grouped_x = df.groupby("id")["x"] + result = grouped_x.agg(sum_last, 1, engine="numba") + expected = Series([1.0] * 2, name="x", index=Index([0, 1], name="id")) + tm.assert_series_equal(result, expected) + + result = grouped_x.agg(sum_last, 2, engine="numba") + expected = Series([2.0] * 2, name="x", index=Index([0, 1], name="id")) + tm.assert_series_equal(result, expected) + + +def test_index_data_correctly_passed(): + # GH 43133 + pytest.importorskip("numba") + + def f(values, index): + return np.mean(index) + + df = DataFrame({"group": ["A", "A", "B"], "v": [4, 5, 6]}, index=[-1, -2, -3]) + result = df.groupby("group").aggregate(f, engine="numba") + expected = DataFrame( + [-1.5, -3.0], columns=["v"], index=Index(["A", "B"], name="group") + ) + tm.assert_frame_equal(result, expected) + + +def test_engine_kwargs_not_cached(): + # If the user passes a different set of engine_kwargs don't return the same + # jitted function + pytest.importorskip("numba") + nogil = True + parallel = False + nopython = True + + def func_kwargs(values, index): + return nogil + parallel + nopython + + engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel} + df = DataFrame({"value": [0, 0, 0]}) + result = df.groupby(level=0).aggregate( + func_kwargs, engine="numba", engine_kwargs=engine_kwargs + ) + expected = DataFrame({"value": [2.0, 2.0, 2.0]}) + tm.assert_frame_equal(result, expected) + + nogil = False + engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel} + result = df.groupby(level=0).aggregate( + func_kwargs, engine="numba", engine_kwargs=engine_kwargs + ) + expected = DataFrame({"value": [1.0, 1.0, 1.0]}) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.filterwarnings("ignore") +def test_multiindex_one_key(nogil, parallel, nopython): + pytest.importorskip("numba") + + def numba_func(values, index): + return 1 + + df = DataFrame([{"A": 1, "B": 2, "C": 3}]).set_index(["A", "B"]) + engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel} + result = df.groupby("A").agg( + numba_func, engine="numba", engine_kwargs=engine_kwargs + ) + expected = DataFrame([1.0], index=Index([1], name="A"), columns=["C"]) + tm.assert_frame_equal(result, expected) + + +def test_multiindex_multi_key_not_supported(nogil, parallel, nopython): + pytest.importorskip("numba") + + def numba_func(values, index): + return 1 + + df = DataFrame([{"A": 1, "B": 2, "C": 3}]).set_index(["A", "B"]) + engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel} + with pytest.raises(NotImplementedError, match="more than 1 grouping labels"): + df.groupby(["A", "B"]).agg( + numba_func, engine="numba", engine_kwargs=engine_kwargs + ) + + +def test_multilabel_numba_vs_cython(numba_supported_reductions): + pytest.importorskip("numba") + reduction, kwargs = numba_supported_reductions + df = DataFrame( + { + "A": ["foo", "bar", "foo", "bar", "foo", "bar", "foo", "foo"], + "B": ["one", "one", "two", "three", "two", "two", "one", "three"], + "C": np.random.default_rng(2).standard_normal(8), + "D": np.random.default_rng(2).standard_normal(8), + } + ) + gb = df.groupby(["A", "B"]) + res_agg = gb.agg(reduction, engine="numba", **kwargs) + expected_agg = gb.agg(reduction, engine="cython", **kwargs) + tm.assert_frame_equal(res_agg, expected_agg) + # Test that calling the aggregation directly also works + direct_res = getattr(gb, reduction)(engine="numba", **kwargs) + direct_expected = getattr(gb, reduction)(engine="cython", **kwargs) + tm.assert_frame_equal(direct_res, direct_expected) + + +def test_multilabel_udf_numba_vs_cython(): + pytest.importorskip("numba") + df = DataFrame( + { + "A": ["foo", "bar", "foo", "bar", "foo", "bar", "foo", "foo"], + "B": ["one", "one", "two", "three", "two", "two", "one", "three"], + "C": np.random.default_rng(2).standard_normal(8), + "D": np.random.default_rng(2).standard_normal(8), + } + ) + gb = df.groupby(["A", "B"]) + result = gb.agg(lambda values, index: values.min(), engine="numba") + expected = gb.agg(lambda x: x.min(), engine="cython") + tm.assert_frame_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/aggregate/test_other.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/aggregate/test_other.py new file mode 100644 index 0000000000000000000000000000000000000000..00136e572288e9858412fd9d84e3ee48dc52a09c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/aggregate/test_other.py @@ -0,0 +1,675 @@ +""" +test all other .agg behavior +""" + +import datetime as dt +from functools import partial + +import numpy as np +import pytest + +from pandas.errors import SpecificationError + +import pandas as pd +from pandas import ( + DataFrame, + Index, + MultiIndex, + PeriodIndex, + Series, + date_range, + period_range, +) +import pandas._testing as tm + +from pandas.io.formats.printing import pprint_thing + + +def test_agg_partial_failure_raises(): + # GH#43741 + + df = DataFrame( + { + "data1": np.random.default_rng(2).standard_normal(5), + "data2": np.random.default_rng(2).standard_normal(5), + "key1": ["a", "a", "b", "b", "a"], + "key2": ["one", "two", "one", "two", "one"], + } + ) + grouped = df.groupby("key1") + + def peak_to_peak(arr): + return arr.max() - arr.min() + + with pytest.raises(TypeError, match="unsupported operand type"): + grouped.agg([peak_to_peak]) + + with pytest.raises(TypeError, match="unsupported operand type"): + grouped.agg(peak_to_peak) + + +def test_agg_datetimes_mixed(): + data = [[1, "2012-01-01", 1.0], [2, "2012-01-02", 2.0], [3, None, 3.0]] + + df1 = DataFrame( + { + "key": [x[0] for x in data], + "date": [x[1] for x in data], + "value": [x[2] for x in data], + } + ) + + data = [ + [ + row[0], + (dt.datetime.strptime(row[1], "%Y-%m-%d").date() if row[1] else None), + row[2], + ] + for row in data + ] + + df2 = DataFrame( + { + "key": [x[0] for x in data], + "date": [x[1] for x in data], + "value": [x[2] for x in data], + } + ) + + df1["weights"] = df1["value"] / df1["value"].sum() + gb1 = df1.groupby("date").aggregate("sum") + + df2["weights"] = df1["value"] / df1["value"].sum() + gb2 = df2.groupby("date").aggregate("sum") + + assert len(gb1) == len(gb2) + + +def test_agg_period_index(): + prng = period_range("2012-1-1", freq="M", periods=3) + df = DataFrame(np.random.default_rng(2).standard_normal((3, 2)), index=prng) + rs = df.groupby(level=0).sum() + assert isinstance(rs.index, PeriodIndex) + + # GH 3579 + index = period_range(start="1999-01", periods=5, freq="M") + s1 = Series(np.random.default_rng(2).random(len(index)), index=index) + s2 = Series(np.random.default_rng(2).random(len(index)), index=index) + df = DataFrame.from_dict({"s1": s1, "s2": s2}) + grouped = df.groupby(df.index.month) + list(grouped) + + +def test_agg_dict_parameter_cast_result_dtypes(): + # GH 12821 + + df = DataFrame( + { + "class": ["A", "A", "B", "B", "C", "C", "D", "D"], + "time": date_range("1/1/2011", periods=8, freq="h"), + } + ) + df.loc[[0, 1, 2, 5], "time"] = None + + # test for `first` function + exp = df.loc[[0, 3, 4, 6]].set_index("class") + grouped = df.groupby("class") + tm.assert_frame_equal(grouped.first(), exp) + tm.assert_frame_equal(grouped.agg("first"), exp) + tm.assert_frame_equal(grouped.agg({"time": "first"}), exp) + tm.assert_series_equal(grouped.time.first(), exp["time"]) + tm.assert_series_equal(grouped.time.agg("first"), exp["time"]) + + # test for `last` function + exp = df.loc[[0, 3, 4, 7]].set_index("class") + grouped = df.groupby("class") + tm.assert_frame_equal(grouped.last(), exp) + tm.assert_frame_equal(grouped.agg("last"), exp) + tm.assert_frame_equal(grouped.agg({"time": "last"}), exp) + tm.assert_series_equal(grouped.time.last(), exp["time"]) + tm.assert_series_equal(grouped.time.agg("last"), exp["time"]) + + # count + exp = Series([2, 2, 2, 2], index=Index(list("ABCD"), name="class"), name="time") + tm.assert_series_equal(grouped.time.agg(len), exp) + tm.assert_series_equal(grouped.time.size(), exp) + + exp = Series([0, 1, 1, 2], index=Index(list("ABCD"), name="class"), name="time") + tm.assert_series_equal(grouped.time.count(), exp) + + +def test_agg_cast_results_dtypes(): + # similar to GH12821 + # xref #11444 + u = [dt.datetime(2015, x + 1, 1) for x in range(12)] + v = list("aaabbbbbbccd") + df = DataFrame({"X": v, "Y": u}) + + result = df.groupby("X")["Y"].agg(len) + expected = df.groupby("X")["Y"].count() + tm.assert_series_equal(result, expected) + + +def test_aggregate_float64_no_int64(): + # see gh-11199 + df = DataFrame({"a": [1, 2, 3, 4, 5], "b": [1, 2, 2, 4, 5], "c": [1, 2, 3, 4, 5]}) + + expected = DataFrame({"a": [1, 2.5, 4, 5]}, index=[1, 2, 4, 5]) + expected.index.name = "b" + + result = df.groupby("b")[["a"]].mean() + tm.assert_frame_equal(result, expected) + + expected = DataFrame({"a": [1, 2.5, 4, 5], "c": [1, 2.5, 4, 5]}, index=[1, 2, 4, 5]) + expected.index.name = "b" + + result = df.groupby("b")[["a", "c"]].mean() + tm.assert_frame_equal(result, expected) + + +def test_aggregate_api_consistency(): + # GH 9052 + # make sure that the aggregates via dict + # are consistent + df = DataFrame( + { + "A": ["foo", "bar", "foo", "bar", "foo", "bar", "foo", "foo"], + "B": ["one", "one", "two", "two", "two", "two", "one", "two"], + "C": np.random.default_rng(2).standard_normal(8) + 1.0, + "D": np.arange(8), + } + ) + + grouped = df.groupby(["A", "B"]) + c_mean = grouped["C"].mean() + c_sum = grouped["C"].sum() + d_mean = grouped["D"].mean() + d_sum = grouped["D"].sum() + + result = grouped["D"].agg(["sum", "mean"]) + expected = pd.concat([d_sum, d_mean], axis=1) + expected.columns = ["sum", "mean"] + tm.assert_frame_equal(result, expected, check_like=True) + + result = grouped.agg(["sum", "mean"]) + expected = pd.concat([c_sum, c_mean, d_sum, d_mean], axis=1) + expected.columns = MultiIndex.from_product([["C", "D"], ["sum", "mean"]]) + tm.assert_frame_equal(result, expected, check_like=True) + + result = grouped[["D", "C"]].agg(["sum", "mean"]) + expected = pd.concat([d_sum, d_mean, c_sum, c_mean], axis=1) + expected.columns = MultiIndex.from_product([["D", "C"], ["sum", "mean"]]) + tm.assert_frame_equal(result, expected, check_like=True) + + result = grouped.agg({"C": "mean", "D": "sum"}) + expected = pd.concat([d_sum, c_mean], axis=1) + tm.assert_frame_equal(result, expected, check_like=True) + + result = grouped.agg({"C": ["mean", "sum"], "D": ["mean", "sum"]}) + expected = pd.concat([c_mean, c_sum, d_mean, d_sum], axis=1) + expected.columns = MultiIndex.from_product([["C", "D"], ["mean", "sum"]]) + + msg = r"Column\(s\) \['r', 'r2'\] do not exist" + with pytest.raises(KeyError, match=msg): + grouped[["D", "C"]].agg({"r": "sum", "r2": "mean"}) + + +def test_agg_dict_renaming_deprecation(): + # 15931 + df = DataFrame({"A": [1, 1, 1, 2, 2], "B": range(5), "C": range(5)}) + + msg = r"nested renamer is not supported" + with pytest.raises(SpecificationError, match=msg): + df.groupby("A").agg( + {"B": {"foo": ["sum", "max"]}, "C": {"bar": ["count", "min"]}} + ) + + msg = r"Column\(s\) \['ma'\] do not exist" + with pytest.raises(KeyError, match=msg): + df.groupby("A")[["B", "C"]].agg({"ma": "max"}) + + msg = r"nested renamer is not supported" + with pytest.raises(SpecificationError, match=msg): + df.groupby("A").B.agg({"foo": "count"}) + + +def test_agg_compat(): + # GH 12334 + df = DataFrame( + { + "A": ["foo", "bar", "foo", "bar", "foo", "bar", "foo", "foo"], + "B": ["one", "one", "two", "two", "two", "two", "one", "two"], + "C": np.random.default_rng(2).standard_normal(8) + 1.0, + "D": np.arange(8), + } + ) + + g = df.groupby(["A", "B"]) + + msg = r"nested renamer is not supported" + with pytest.raises(SpecificationError, match=msg): + g["D"].agg({"C": ["sum", "std"]}) + + with pytest.raises(SpecificationError, match=msg): + g["D"].agg({"C": "sum", "D": "std"}) + + +def test_agg_nested_dicts(): + # API change for disallowing these types of nested dicts + df = DataFrame( + { + "A": ["foo", "bar", "foo", "bar", "foo", "bar", "foo", "foo"], + "B": ["one", "one", "two", "two", "two", "two", "one", "two"], + "C": np.random.default_rng(2).standard_normal(8) + 1.0, + "D": np.arange(8), + } + ) + + g = df.groupby(["A", "B"]) + + msg = r"nested renamer is not supported" + with pytest.raises(SpecificationError, match=msg): + g.aggregate({"r1": {"C": ["mean", "sum"]}, "r2": {"D": ["mean", "sum"]}}) + + with pytest.raises(SpecificationError, match=msg): + g.agg({"C": {"ra": ["mean", "std"]}, "D": {"rb": ["mean", "std"]}}) + + # same name as the original column + # GH9052 + with pytest.raises(SpecificationError, match=msg): + g["D"].agg({"result1": np.sum, "result2": np.mean}) + + with pytest.raises(SpecificationError, match=msg): + g["D"].agg({"D": np.sum, "result2": np.mean}) + + +def test_agg_item_by_item_raise_typeerror(): + df = DataFrame(np.random.default_rng(2).integers(10, size=(20, 10))) + + def raiseException(df): + pprint_thing("----------------------------------------") + pprint_thing(df.to_string()) + raise TypeError("test") + + with pytest.raises(TypeError, match="test"): + df.groupby(0).agg(raiseException) + + +def test_series_agg_multikey(): + ts = Series( + np.arange(10, dtype=np.float64), index=date_range("2020-01-01", periods=10) + ) + grouped = ts.groupby([lambda x: x.year, lambda x: x.month]) + + result = grouped.agg("sum") + expected = grouped.sum() + tm.assert_series_equal(result, expected) + + +def test_series_agg_multi_pure_python(): + data = DataFrame( + { + "A": [ + "foo", + "foo", + "foo", + "foo", + "bar", + "bar", + "bar", + "bar", + "foo", + "foo", + "foo", + ], + "B": [ + "one", + "one", + "one", + "two", + "one", + "one", + "one", + "two", + "two", + "two", + "one", + ], + "C": [ + "dull", + "dull", + "shiny", + "dull", + "dull", + "shiny", + "shiny", + "dull", + "shiny", + "shiny", + "shiny", + ], + "D": np.random.default_rng(2).standard_normal(11), + "E": np.random.default_rng(2).standard_normal(11), + "F": np.random.default_rng(2).standard_normal(11), + } + ) + + def bad(x): + assert len(x.values.base) > 0 + return "foo" + + result = data.groupby(["A", "B"]).agg(bad) + expected = data.groupby(["A", "B"]).agg(lambda x: "foo") + tm.assert_frame_equal(result, expected) + + +def test_agg_consistency(): + # agg with ([]) and () not consistent + # GH 6715 + def P1(a): + return np.percentile(a.dropna(), q=1) + + df = DataFrame( + { + "col1": [1, 2, 3, 4], + "col2": [10, 25, 26, 31], + "date": [ + dt.date(2013, 2, 10), + dt.date(2013, 2, 10), + dt.date(2013, 2, 11), + dt.date(2013, 2, 11), + ], + } + ) + + g = df.groupby("date") + + expected = g.agg([P1]) + expected.columns = expected.columns.levels[0] + + result = g.agg(P1) + tm.assert_frame_equal(result, expected) + + +def test_agg_callables(): + # GH 7929 + df = DataFrame({"foo": [1, 2], "bar": [3, 4]}).astype(np.int64) + + class fn_class: + def __call__(self, x): + return sum(x) + + equiv_callables = [ + sum, + np.sum, + lambda x: sum(x), + lambda x: x.sum(), + partial(sum), + fn_class(), + ] + + expected = df.groupby("foo").agg("sum") + for ecall in equiv_callables: + warn = FutureWarning if ecall is sum or ecall is np.sum else None + msg = "using DataFrameGroupBy.sum" + with tm.assert_produces_warning(warn, match=msg): + result = df.groupby("foo").agg(ecall) + tm.assert_frame_equal(result, expected) + + +def test_agg_over_numpy_arrays(): + # GH 3788 + df = DataFrame( + [ + [1, np.array([10, 20, 30])], + [1, np.array([40, 50, 60])], + [2, np.array([20, 30, 40])], + ], + columns=["category", "arraydata"], + ) + gb = df.groupby("category") + + expected_data = [[np.array([50, 70, 90])], [np.array([20, 30, 40])]] + expected_index = Index([1, 2], name="category") + expected_column = ["arraydata"] + expected = DataFrame(expected_data, index=expected_index, columns=expected_column) + + alt = gb.sum(numeric_only=False) + tm.assert_frame_equal(alt, expected) + + result = gb.agg("sum", numeric_only=False) + tm.assert_frame_equal(result, expected) + + # FIXME: the original version of this test called `gb.agg(sum)` + # and that raises TypeError if `numeric_only=False` is passed + + +@pytest.mark.parametrize("as_period", [True, False]) +def test_agg_tzaware_non_datetime_result(as_period): + # discussed in GH#29589, fixed in GH#29641, operating on tzaware values + # with function that is not dtype-preserving + dti = date_range("2012-01-01", periods=4, tz="UTC") + if as_period: + dti = dti.tz_localize(None).to_period("D") + + df = DataFrame({"a": [0, 0, 1, 1], "b": dti}) + gb = df.groupby("a") + + # Case that _does_ preserve the dtype + result = gb["b"].agg(lambda x: x.iloc[0]) + expected = Series(dti[::2], name="b") + expected.index.name = "a" + tm.assert_series_equal(result, expected) + + # Cases that do _not_ preserve the dtype + result = gb["b"].agg(lambda x: x.iloc[0].year) + expected = Series([2012, 2012], name="b") + expected.index.name = "a" + tm.assert_series_equal(result, expected) + + result = gb["b"].agg(lambda x: x.iloc[-1] - x.iloc[0]) + expected = Series([pd.Timedelta(days=1), pd.Timedelta(days=1)], name="b") + expected.index.name = "a" + if as_period: + expected = Series([pd.offsets.Day(1), pd.offsets.Day(1)], name="b") + expected.index.name = "a" + tm.assert_series_equal(result, expected) + + +def test_agg_timezone_round_trip(): + # GH 15426 + ts = pd.Timestamp("2016-01-01 12:00:00", tz="US/Pacific") + df = DataFrame({"a": 1, "b": [ts + dt.timedelta(minutes=nn) for nn in range(10)]}) + + result1 = df.groupby("a")["b"].agg("min").iloc[0] + result2 = df.groupby("a")["b"].agg(lambda x: np.min(x)).iloc[0] + result3 = df.groupby("a")["b"].min().iloc[0] + + assert result1 == ts + assert result2 == ts + assert result3 == ts + + dates = [ + pd.Timestamp(f"2016-01-0{i:d} 12:00:00", tz="US/Pacific") for i in range(1, 5) + ] + df = DataFrame({"A": ["a", "b"] * 2, "B": dates}) + grouped = df.groupby("A") + + ts = df["B"].iloc[0] + assert ts == grouped.nth(0)["B"].iloc[0] + assert ts == grouped.head(1)["B"].iloc[0] + assert ts == grouped.first()["B"].iloc[0] + + # GH#27110 applying iloc should return a DataFrame + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + assert ts == grouped.apply(lambda x: x.iloc[0]).iloc[0, 1] + + ts = df["B"].iloc[2] + assert ts == grouped.last()["B"].iloc[0] + + # GH#27110 applying iloc should return a DataFrame + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + assert ts == grouped.apply(lambda x: x.iloc[-1]).iloc[0, 1] + + +def test_sum_uint64_overflow(): + # see gh-14758 + # Convert to uint64 and don't overflow + df = DataFrame([[1, 2], [3, 4], [5, 6]], dtype=object) + df = df + 9223372036854775807 + + index = Index( + [9223372036854775808, 9223372036854775810, 9223372036854775812], dtype=np.uint64 + ) + expected = DataFrame( + {1: [9223372036854775809, 9223372036854775811, 9223372036854775813]}, + index=index, + dtype=object, + ) + + expected.index.name = 0 + result = df.groupby(0).sum(numeric_only=False) + tm.assert_frame_equal(result, expected) + + # out column is non-numeric, so with numeric_only=True it is dropped + result2 = df.groupby(0).sum(numeric_only=True) + expected2 = expected[[]] + tm.assert_frame_equal(result2, expected2) + + +@pytest.mark.parametrize( + "structure, expected", + [ + (tuple, DataFrame({"C": {(1, 1): (1, 1, 1), (3, 4): (3, 4, 4)}})), + (list, DataFrame({"C": {(1, 1): [1, 1, 1], (3, 4): [3, 4, 4]}})), + ( + lambda x: tuple(x), + DataFrame({"C": {(1, 1): (1, 1, 1), (3, 4): (3, 4, 4)}}), + ), + ( + lambda x: list(x), + DataFrame({"C": {(1, 1): [1, 1, 1], (3, 4): [3, 4, 4]}}), + ), + ], +) +def test_agg_structs_dataframe(structure, expected): + df = DataFrame( + {"A": [1, 1, 1, 3, 3, 3], "B": [1, 1, 1, 4, 4, 4], "C": [1, 1, 1, 3, 4, 4]} + ) + + result = df.groupby(["A", "B"]).aggregate(structure) + expected.index.names = ["A", "B"] + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "structure, expected", + [ + (tuple, Series([(1, 1, 1), (3, 4, 4)], index=[1, 3], name="C")), + (list, Series([[1, 1, 1], [3, 4, 4]], index=[1, 3], name="C")), + (lambda x: tuple(x), Series([(1, 1, 1), (3, 4, 4)], index=[1, 3], name="C")), + (lambda x: list(x), Series([[1, 1, 1], [3, 4, 4]], index=[1, 3], name="C")), + ], +) +def test_agg_structs_series(structure, expected): + # Issue #18079 + df = DataFrame( + {"A": [1, 1, 1, 3, 3, 3], "B": [1, 1, 1, 4, 4, 4], "C": [1, 1, 1, 3, 4, 4]} + ) + + result = df.groupby("A")["C"].aggregate(structure) + expected.index.name = "A" + tm.assert_series_equal(result, expected) + + +def test_agg_category_nansum(observed): + categories = ["a", "b", "c"] + df = DataFrame( + {"A": pd.Categorical(["a", "a", "b"], categories=categories), "B": [1, 2, 3]} + ) + msg = "using SeriesGroupBy.sum" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = df.groupby("A", observed=observed).B.agg(np.nansum) + expected = Series( + [3, 3, 0], + index=pd.CategoricalIndex(["a", "b", "c"], categories=categories, name="A"), + name="B", + ) + if observed: + expected = expected[expected != 0] + tm.assert_series_equal(result, expected) + + +def test_agg_list_like_func(): + # GH 18473 + df = DataFrame({"A": [str(x) for x in range(3)], "B": [str(x) for x in range(3)]}) + grouped = df.groupby("A", as_index=False, sort=False) + result = grouped.agg({"B": lambda x: list(x)}) + expected = DataFrame( + {"A": [str(x) for x in range(3)], "B": [[str(x)] for x in range(3)]} + ) + tm.assert_frame_equal(result, expected) + + +def test_agg_lambda_with_timezone(): + # GH 23683 + df = DataFrame( + { + "tag": [1, 1], + "date": [ + pd.Timestamp("2018-01-01", tz="UTC"), + pd.Timestamp("2018-01-02", tz="UTC"), + ], + } + ) + result = df.groupby("tag").agg({"date": lambda e: e.head(1)}) + expected = DataFrame( + [pd.Timestamp("2018-01-01", tz="UTC")], + index=Index([1], name="tag"), + columns=["date"], + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "err_cls", + [ + NotImplementedError, + RuntimeError, + KeyError, + IndexError, + OSError, + ValueError, + ArithmeticError, + AttributeError, + ], +) +def test_groupby_agg_err_catching(err_cls): + # make sure we suppress anything other than TypeError or AssertionError + # in _python_agg_general + + # Use a non-standard EA to make sure we don't go down ndarray paths + from pandas.tests.extension.decimal.array import ( + DecimalArray, + make_data, + to_decimal, + ) + + data = make_data()[:5] + df = DataFrame( + {"id1": [0, 0, 0, 1, 1], "id2": [0, 1, 0, 1, 1], "decimals": DecimalArray(data)} + ) + + expected = Series(to_decimal([data[0], data[3]])) + + def weird_func(x): + # weird function that raise something other than TypeError or IndexError + # in _python_agg_general + if len(x) == 0: + raise err_cls + return x.iloc[0] + + result = df["decimals"].groupby(df["id1"]).agg(weird_func) + tm.assert_series_equal(result, expected, check_names=False) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/conftest.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..dce3f072ed903ace4cb014f63d60ffde84c9bf4c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/conftest.py @@ -0,0 +1,208 @@ +import numpy as np +import pytest + +from pandas import ( + DataFrame, + Index, + Series, + date_range, +) +from pandas.core.groupby.base import ( + reduction_kernels, + transformation_kernels, +) + + +@pytest.fixture(params=[True, False]) +def sort(request): + return request.param + + +@pytest.fixture(params=[True, False]) +def as_index(request): + return request.param + + +@pytest.fixture(params=[True, False]) +def dropna(request): + return request.param + + +@pytest.fixture(params=[True, False]) +def observed(request): + return request.param + + +@pytest.fixture +def df(): + return DataFrame( + { + "A": ["foo", "bar", "foo", "bar", "foo", "bar", "foo", "foo"], + "B": ["one", "one", "two", "three", "two", "two", "one", "three"], + "C": np.random.default_rng(2).standard_normal(8), + "D": np.random.default_rng(2).standard_normal(8), + } + ) + + +@pytest.fixture +def ts(): + return Series( + np.random.default_rng(2).standard_normal(30), + index=date_range("2000-01-01", periods=30, freq="B"), + ) + + +@pytest.fixture +def tsframe(): + return DataFrame( + np.random.default_rng(2).standard_normal((30, 4)), + columns=Index(list("ABCD"), dtype=object), + index=date_range("2000-01-01", periods=30, freq="B"), + ) + + +@pytest.fixture +def three_group(): + return DataFrame( + { + "A": [ + "foo", + "foo", + "foo", + "foo", + "bar", + "bar", + "bar", + "bar", + "foo", + "foo", + "foo", + ], + "B": [ + "one", + "one", + "one", + "two", + "one", + "one", + "one", + "two", + "two", + "two", + "one", + ], + "C": [ + "dull", + "dull", + "shiny", + "dull", + "dull", + "shiny", + "shiny", + "dull", + "shiny", + "shiny", + "shiny", + ], + "D": np.random.default_rng(2).standard_normal(11), + "E": np.random.default_rng(2).standard_normal(11), + "F": np.random.default_rng(2).standard_normal(11), + } + ) + + +@pytest.fixture() +def slice_test_df(): + data = [ + [0, "a", "a0_at_0"], + [1, "b", "b0_at_1"], + [2, "a", "a1_at_2"], + [3, "b", "b1_at_3"], + [4, "c", "c0_at_4"], + [5, "a", "a2_at_5"], + [6, "a", "a3_at_6"], + [7, "a", "a4_at_7"], + ] + df = DataFrame(data, columns=["Index", "Group", "Value"]) + return df.set_index("Index") + + +@pytest.fixture() +def slice_test_grouped(slice_test_df): + return slice_test_df.groupby("Group", as_index=False) + + +@pytest.fixture(params=sorted(reduction_kernels)) +def reduction_func(request): + """ + yields the string names of all groupby reduction functions, one at a time. + """ + return request.param + + +@pytest.fixture(params=sorted(transformation_kernels)) +def transformation_func(request): + """yields the string names of all groupby transformation functions.""" + return request.param + + +@pytest.fixture(params=sorted(reduction_kernels) + sorted(transformation_kernels)) +def groupby_func(request): + """yields both aggregation and transformation functions.""" + return request.param + + +@pytest.fixture(params=[True, False]) +def parallel(request): + """parallel keyword argument for numba.jit""" + return request.param + + +# Can parameterize nogil & nopython over True | False, but limiting per +# https://github.com/pandas-dev/pandas/pull/41971#issuecomment-860607472 + + +@pytest.fixture(params=[False]) +def nogil(request): + """nogil keyword argument for numba.jit""" + return request.param + + +@pytest.fixture(params=[True]) +def nopython(request): + """nopython keyword argument for numba.jit""" + return request.param + + +@pytest.fixture( + params=[ + ("mean", {}), + ("var", {"ddof": 1}), + ("var", {"ddof": 0}), + ("std", {"ddof": 1}), + ("std", {"ddof": 0}), + ("sum", {}), + ("min", {}), + ("max", {}), + ("sum", {"min_count": 2}), + ("min", {"min_count": 2}), + ("max", {"min_count": 2}), + ], + ids=[ + "mean", + "var_1", + "var_0", + "std_1", + "std_0", + "sum", + "min", + "max", + "sum-min_count", + "min-min_count", + "max-min_count", + ], +) +def numba_supported_reductions(request): + """reductions supported with engine='numba'""" + return request.param diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/methods/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/methods/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/methods/test_corrwith.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/methods/test_corrwith.py new file mode 100644 index 0000000000000000000000000000000000000000..53e8bdc4534dc66dc1b68e603b2af431d0c0b209 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/methods/test_corrwith.py @@ -0,0 +1,24 @@ +import numpy as np + +from pandas import ( + DataFrame, + Index, + Series, +) +import pandas._testing as tm + + +def test_corrwith_with_1_axis(): + # GH 47723 + df = DataFrame({"a": [1, 1, 2], "b": [3, 7, 4]}) + gb = df.groupby("a") + + msg = "DataFrameGroupBy.corrwith with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = gb.corrwith(df, axis=1) + index = Index( + data=[(1, 0), (1, 1), (1, 2), (2, 2), (2, 0), (2, 1)], + name=("a", None), + ) + expected = Series([np.nan] * 6, index=index) + tm.assert_series_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/methods/test_describe.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/methods/test_describe.py new file mode 100644 index 0000000000000000000000000000000000000000..a2440e09dfc02436140e94cd689b39a1a9d35189 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/methods/test_describe.py @@ -0,0 +1,297 @@ +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + DataFrame, + Index, + MultiIndex, + Series, + Timestamp, + date_range, +) +import pandas._testing as tm + + +def test_apply_describe_bug(multiindex_dataframe_random_data): + grouped = multiindex_dataframe_random_data.groupby(level="first") + grouped.describe() # it works! + + +def test_series_describe_multikey(): + ts = Series( + np.arange(10, dtype=np.float64), index=date_range("2020-01-01", periods=10) + ) + grouped = ts.groupby([lambda x: x.year, lambda x: x.month]) + result = grouped.describe() + tm.assert_series_equal(result["mean"], grouped.mean(), check_names=False) + tm.assert_series_equal(result["std"], grouped.std(), check_names=False) + tm.assert_series_equal(result["min"], grouped.min(), check_names=False) + + +def test_series_describe_single(): + ts = Series( + np.arange(10, dtype=np.float64), index=date_range("2020-01-01", periods=10) + ) + grouped = ts.groupby(lambda x: x.month) + result = grouped.apply(lambda x: x.describe()) + expected = grouped.describe().stack(future_stack=True) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("keys", ["key1", ["key1", "key2"]]) +def test_series_describe_as_index(as_index, keys): + # GH#49256 + df = DataFrame( + { + "key1": ["one", "two", "two", "three", "two"], + "key2": ["one", "two", "two", "three", "two"], + "foo2": [1, 2, 4, 4, 6], + } + ) + gb = df.groupby(keys, as_index=as_index)["foo2"] + result = gb.describe() + expected = DataFrame( + { + "key1": ["one", "three", "two"], + "count": [1.0, 1.0, 3.0], + "mean": [1.0, 4.0, 4.0], + "std": [np.nan, np.nan, 2.0], + "min": [1.0, 4.0, 2.0], + "25%": [1.0, 4.0, 3.0], + "50%": [1.0, 4.0, 4.0], + "75%": [1.0, 4.0, 5.0], + "max": [1.0, 4.0, 6.0], + } + ) + if len(keys) == 2: + expected.insert(1, "key2", expected["key1"]) + if as_index: + expected = expected.set_index(keys) + tm.assert_frame_equal(result, expected) + + +def test_frame_describe_multikey(tsframe): + grouped = tsframe.groupby([lambda x: x.year, lambda x: x.month]) + result = grouped.describe() + desc_groups = [] + for col in tsframe: + group = grouped[col].describe() + # GH 17464 - Remove duplicate MultiIndex levels + group_col = MultiIndex( + levels=[[col], group.columns], + codes=[[0] * len(group.columns), range(len(group.columns))], + ) + group = DataFrame(group.values, columns=group_col, index=group.index) + desc_groups.append(group) + expected = pd.concat(desc_groups, axis=1) + tm.assert_frame_equal(result, expected) + + msg = "DataFrame.groupby with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + groupedT = tsframe.groupby({"A": 0, "B": 0, "C": 1, "D": 1}, axis=1) + result = groupedT.describe() + expected = tsframe.describe().T + # reverting the change from https://github.com/pandas-dev/pandas/pull/35441/ + expected.index = MultiIndex( + levels=[[0, 1], expected.index], + codes=[[0, 0, 1, 1], range(len(expected.index))], + ) + tm.assert_frame_equal(result, expected) + + +def test_frame_describe_tupleindex(): + # GH 14848 - regression from 0.19.0 to 0.19.1 + df1 = DataFrame( + { + "x": [1, 2, 3, 4, 5] * 3, + "y": [10, 20, 30, 40, 50] * 3, + "z": [100, 200, 300, 400, 500] * 3, + } + ) + df1["k"] = [(0, 0, 1), (0, 1, 0), (1, 0, 0)] * 5 + df2 = df1.rename(columns={"k": "key"}) + msg = "Names should be list-like for a MultiIndex" + with pytest.raises(ValueError, match=msg): + df1.groupby("k").describe() + with pytest.raises(ValueError, match=msg): + df2.groupby("key").describe() + + +def test_frame_describe_unstacked_format(): + # GH 4792 + prices = { + Timestamp("2011-01-06 10:59:05", tz=None): 24990, + Timestamp("2011-01-06 12:43:33", tz=None): 25499, + Timestamp("2011-01-06 12:54:09", tz=None): 25499, + } + volumes = { + Timestamp("2011-01-06 10:59:05", tz=None): 1500000000, + Timestamp("2011-01-06 12:43:33", tz=None): 5000000000, + Timestamp("2011-01-06 12:54:09", tz=None): 100000000, + } + df = DataFrame({"PRICE": prices, "VOLUME": volumes}) + result = df.groupby("PRICE").VOLUME.describe() + data = [ + df[df.PRICE == 24990].VOLUME.describe().values.tolist(), + df[df.PRICE == 25499].VOLUME.describe().values.tolist(), + ] + expected = DataFrame( + data, + index=Index([24990, 25499], name="PRICE"), + columns=["count", "mean", "std", "min", "25%", "50%", "75%", "max"], + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.filterwarnings( + "ignore:" + "indexing past lexsort depth may impact performance:" + "pandas.errors.PerformanceWarning" +) +@pytest.mark.parametrize("as_index", [True, False]) +@pytest.mark.parametrize("keys", [["a1"], ["a1", "a2"]]) +def test_describe_with_duplicate_output_column_names(as_index, keys): + # GH 35314 + df = DataFrame( + { + "a1": [99, 99, 99, 88, 88, 88], + "a2": [99, 99, 99, 88, 88, 88], + "b": [1, 2, 3, 4, 5, 6], + "c": [10, 20, 30, 40, 50, 60], + }, + columns=["a1", "a2", "b", "b"], + copy=False, + ) + if keys == ["a1"]: + df = df.drop(columns="a2") + + expected = ( + DataFrame.from_records( + [ + ("b", "count", 3.0, 3.0), + ("b", "mean", 5.0, 2.0), + ("b", "std", 1.0, 1.0), + ("b", "min", 4.0, 1.0), + ("b", "25%", 4.5, 1.5), + ("b", "50%", 5.0, 2.0), + ("b", "75%", 5.5, 2.5), + ("b", "max", 6.0, 3.0), + ("b", "count", 3.0, 3.0), + ("b", "mean", 5.0, 2.0), + ("b", "std", 1.0, 1.0), + ("b", "min", 4.0, 1.0), + ("b", "25%", 4.5, 1.5), + ("b", "50%", 5.0, 2.0), + ("b", "75%", 5.5, 2.5), + ("b", "max", 6.0, 3.0), + ], + ) + .set_index([0, 1]) + .T + ) + expected.columns.names = [None, None] + if len(keys) == 2: + expected.index = MultiIndex( + levels=[[88, 99], [88, 99]], codes=[[0, 1], [0, 1]], names=["a1", "a2"] + ) + else: + expected.index = Index([88, 99], name="a1") + + if not as_index: + expected = expected.reset_index() + + result = df.groupby(keys, as_index=as_index).describe() + + tm.assert_frame_equal(result, expected) + + +def test_describe_duplicate_columns(): + # GH#50806 + df = DataFrame([[0, 1, 2, 3]]) + df.columns = [0, 1, 2, 0] + gb = df.groupby(df[1]) + result = gb.describe(percentiles=[]) + + columns = ["count", "mean", "std", "min", "50%", "max"] + frames = [ + DataFrame([[1.0, val, np.nan, val, val, val]], index=[1], columns=columns) + for val in (0.0, 2.0, 3.0) + ] + expected = pd.concat(frames, axis=1) + expected.columns = MultiIndex( + levels=[[0, 2], columns], + codes=[6 * [0] + 6 * [1] + 6 * [0], 3 * list(range(6))], + ) + expected.index.names = [1] + tm.assert_frame_equal(result, expected) + + +class TestGroupByNonCythonPaths: + # GH#5610 non-cython calls should not include the grouper + # Tests for code not expected to go through cython paths. + + @pytest.fixture + def df(self): + df = DataFrame( + [[1, 2, "foo"], [1, np.nan, "bar"], [3, np.nan, "baz"]], + columns=["A", "B", "C"], + ) + return df + + @pytest.fixture + def gb(self, df): + gb = df.groupby("A") + return gb + + @pytest.fixture + def gni(self, df): + gni = df.groupby("A", as_index=False) + return gni + + def test_describe(self, df, gb, gni): + # describe + expected_index = Index([1, 3], name="A") + expected_col = MultiIndex( + levels=[["B"], ["count", "mean", "std", "min", "25%", "50%", "75%", "max"]], + codes=[[0] * 8, list(range(8))], + ) + expected = DataFrame( + [ + [1.0, 2.0, np.nan, 2.0, 2.0, 2.0, 2.0, 2.0], + [0.0, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan], + ], + index=expected_index, + columns=expected_col, + ) + result = gb.describe() + tm.assert_frame_equal(result, expected) + + expected = expected.reset_index() + result = gni.describe() + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("dtype", [int, float, object]) +@pytest.mark.parametrize( + "kwargs", + [ + {"percentiles": [0.10, 0.20, 0.30], "include": "all", "exclude": None}, + {"percentiles": [0.10, 0.20, 0.30], "include": None, "exclude": ["int"]}, + {"percentiles": [0.10, 0.20, 0.30], "include": ["int"], "exclude": None}, + ], +) +def test_groupby_empty_dataset(dtype, kwargs): + # GH#41575 + df = DataFrame([[1, 2, 3]], columns=["A", "B", "C"], dtype=dtype) + df["B"] = df["B"].astype(int) + df["C"] = df["C"].astype(float) + + result = df.iloc[:0].groupby("A").describe(**kwargs) + expected = df.groupby("A").describe(**kwargs).reset_index(drop=True).iloc[:0] + tm.assert_frame_equal(result, expected) + + result = df.iloc[:0].groupby("A").B.describe(**kwargs) + expected = df.groupby("A").B.describe(**kwargs).reset_index(drop=True).iloc[:0] + expected.index = Index([]) + tm.assert_frame_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/methods/test_groupby_shift_diff.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/methods/test_groupby_shift_diff.py new file mode 100644 index 0000000000000000000000000000000000000000..94e672d4892feb513f75d9a3d3376e261e2c0f36 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/methods/test_groupby_shift_diff.py @@ -0,0 +1,255 @@ +import numpy as np +import pytest + +from pandas import ( + DataFrame, + NaT, + Series, + Timedelta, + Timestamp, + date_range, +) +import pandas._testing as tm + + +def test_group_shift_with_null_key(): + # This test is designed to replicate the segfault in issue #13813. + n_rows = 1200 + + # Generate a moderately large dataframe with occasional missing + # values in column `B`, and then group by [`A`, `B`]. This should + # force `-1` in `labels` array of `g._grouper.group_info` exactly + # at those places, where the group-by key is partially missing. + df = DataFrame( + [(i % 12, i % 3 if i % 3 else np.nan, i) for i in range(n_rows)], + dtype=float, + columns=["A", "B", "Z"], + index=None, + ) + g = df.groupby(["A", "B"]) + + expected = DataFrame( + [(i + 12 if i % 3 and i < n_rows - 12 else np.nan) for i in range(n_rows)], + dtype=float, + columns=["Z"], + index=None, + ) + result = g.shift(-1) + + tm.assert_frame_equal(result, expected) + + +def test_group_shift_with_fill_value(): + # GH #24128 + n_rows = 24 + df = DataFrame( + [(i % 12, i % 3, i) for i in range(n_rows)], + dtype=float, + columns=["A", "B", "Z"], + index=None, + ) + g = df.groupby(["A", "B"]) + + expected = DataFrame( + [(i + 12 if i < n_rows - 12 else 0) for i in range(n_rows)], + dtype=float, + columns=["Z"], + index=None, + ) + result = g.shift(-1, fill_value=0) + + tm.assert_frame_equal(result, expected) + + +def test_group_shift_lose_timezone(): + # GH 30134 + now_dt = Timestamp.utcnow().as_unit("ns") + df = DataFrame({"a": [1, 1], "date": now_dt}) + result = df.groupby("a").shift(0).iloc[0] + expected = Series({"date": now_dt}, name=result.name) + tm.assert_series_equal(result, expected) + + +def test_group_diff_real_series(any_real_numpy_dtype): + df = DataFrame( + {"a": [1, 2, 3, 3, 2], "b": [1, 2, 3, 4, 5]}, + dtype=any_real_numpy_dtype, + ) + result = df.groupby("a")["b"].diff() + exp_dtype = "float" + if any_real_numpy_dtype in ["int8", "int16", "float32"]: + exp_dtype = "float32" + expected = Series([np.nan, np.nan, np.nan, 1.0, 3.0], dtype=exp_dtype, name="b") + tm.assert_series_equal(result, expected) + + +def test_group_diff_real_frame(any_real_numpy_dtype): + df = DataFrame( + { + "a": [1, 2, 3, 3, 2], + "b": [1, 2, 3, 4, 5], + "c": [1, 2, 3, 4, 6], + }, + dtype=any_real_numpy_dtype, + ) + result = df.groupby("a").diff() + exp_dtype = "float" + if any_real_numpy_dtype in ["int8", "int16", "float32"]: + exp_dtype = "float32" + expected = DataFrame( + { + "b": [np.nan, np.nan, np.nan, 1.0, 3.0], + "c": [np.nan, np.nan, np.nan, 1.0, 4.0], + }, + dtype=exp_dtype, + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "data", + [ + [ + Timestamp("2013-01-01"), + Timestamp("2013-01-02"), + Timestamp("2013-01-03"), + ], + [Timedelta("5 days"), Timedelta("6 days"), Timedelta("7 days")], + ], +) +def test_group_diff_datetimelike(data, unit): + df = DataFrame({"a": [1, 2, 2], "b": data}) + df["b"] = df["b"].dt.as_unit(unit) + result = df.groupby("a")["b"].diff() + expected = Series([NaT, NaT, Timedelta("1 days")], name="b").dt.as_unit(unit) + tm.assert_series_equal(result, expected) + + +def test_group_diff_bool(): + df = DataFrame({"a": [1, 2, 3, 3, 2], "b": [True, True, False, False, True]}) + result = df.groupby("a")["b"].diff() + expected = Series([np.nan, np.nan, np.nan, False, False], name="b") + tm.assert_series_equal(result, expected) + + +def test_group_diff_object_raises(object_dtype): + df = DataFrame( + {"a": ["foo", "bar", "bar"], "b": ["baz", "foo", "foo"]}, dtype=object_dtype + ) + with pytest.raises(TypeError, match=r"unsupported operand type\(s\) for -"): + df.groupby("a")["b"].diff() + + +def test_empty_shift_with_fill(): + # GH 41264, single-index check + df = DataFrame(columns=["a", "b", "c"]) + shifted = df.groupby(["a"]).shift(1) + shifted_with_fill = df.groupby(["a"]).shift(1, fill_value=0) + tm.assert_frame_equal(shifted, shifted_with_fill) + tm.assert_index_equal(shifted.index, shifted_with_fill.index) + + +def test_multindex_empty_shift_with_fill(): + # GH 41264, multi-index check + df = DataFrame(columns=["a", "b", "c"]) + shifted = df.groupby(["a", "b"]).shift(1) + shifted_with_fill = df.groupby(["a", "b"]).shift(1, fill_value=0) + tm.assert_frame_equal(shifted, shifted_with_fill) + tm.assert_index_equal(shifted.index, shifted_with_fill.index) + + +def test_shift_periods_freq(): + # GH 54093 + data = {"a": [1, 2, 3, 4, 5, 6], "b": [0, 0, 0, 1, 1, 1]} + df = DataFrame(data, index=date_range(start="20100101", periods=6)) + result = df.groupby(df.index).shift(periods=-2, freq="D") + expected = DataFrame(data, index=date_range(start="2009-12-30", periods=6)) + tm.assert_frame_equal(result, expected) + + +def test_shift_deprecate_freq_and_fill_value(): + # GH 53832 + data = {"a": [1, 2, 3, 4, 5, 6], "b": [0, 0, 0, 1, 1, 1]} + df = DataFrame(data, index=date_range(start="20100101", periods=6)) + msg = ( + "Passing a 'freq' together with a 'fill_value' silently ignores the fill_value" + ) + with tm.assert_produces_warning(FutureWarning, match=msg): + df.groupby(df.index).shift(periods=-2, freq="D", fill_value="1") + + +def test_shift_disallow_suffix_if_periods_is_int(): + # GH#44424 + data = {"a": [1, 2, 3, 4, 5, 6], "b": [0, 0, 0, 1, 1, 1]} + df = DataFrame(data) + msg = "Cannot specify `suffix` if `periods` is an int." + with pytest.raises(ValueError, match=msg): + df.groupby("b").shift(1, suffix="fails") + + +def test_group_shift_with_multiple_periods(): + # GH#44424 + df = DataFrame({"a": [1, 2, 3, 3, 2], "b": [True, True, False, False, True]}) + + shifted_df = df.groupby("b")[["a"]].shift([0, 1]) + expected_df = DataFrame( + {"a_0": [1, 2, 3, 3, 2], "a_1": [np.nan, 1.0, np.nan, 3.0, 2.0]} + ) + tm.assert_frame_equal(shifted_df, expected_df) + + # series + shifted_series = df.groupby("b")["a"].shift([0, 1]) + tm.assert_frame_equal(shifted_series, expected_df) + + +def test_group_shift_with_multiple_periods_and_freq(): + # GH#44424 + df = DataFrame( + {"a": [1, 2, 3, 4, 5], "b": [True, True, False, False, True]}, + index=date_range("1/1/2000", periods=5, freq="h"), + ) + shifted_df = df.groupby("b")[["a"]].shift( + [0, 1], + freq="h", + ) + expected_df = DataFrame( + { + "a_0": [1.0, 2.0, 3.0, 4.0, 5.0, np.nan], + "a_1": [ + np.nan, + 1.0, + 2.0, + 3.0, + 4.0, + 5.0, + ], + }, + index=date_range("1/1/2000", periods=6, freq="h"), + ) + tm.assert_frame_equal(shifted_df, expected_df) + + +def test_group_shift_with_multiple_periods_and_fill_value(): + # GH#44424 + df = DataFrame( + {"a": [1, 2, 3, 4, 5], "b": [True, True, False, False, True]}, + ) + shifted_df = df.groupby("b")[["a"]].shift([0, 1], fill_value=-1) + expected_df = DataFrame( + {"a_0": [1, 2, 3, 4, 5], "a_1": [-1, 1, -1, 3, 2]}, + ) + tm.assert_frame_equal(shifted_df, expected_df) + + +def test_group_shift_with_multiple_periods_and_both_fill_and_freq_deprecated(): + # GH#44424 + df = DataFrame( + {"a": [1, 2, 3, 4, 5], "b": [True, True, False, False, True]}, + index=date_range("1/1/2000", periods=5, freq="h"), + ) + msg = ( + "Passing a 'freq' together with a 'fill_value' silently ignores the " + "fill_value" + ) + with tm.assert_produces_warning(FutureWarning, match=msg): + df.groupby("b")[["a"]].shift([1, 2], fill_value=1, freq="h") diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/methods/test_is_monotonic.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/methods/test_is_monotonic.py new file mode 100644 index 0000000000000000000000000000000000000000..3428fc90f6e51a0bde0aba9c8ea08ebf414e5556 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/methods/test_is_monotonic.py @@ -0,0 +1,78 @@ +import numpy as np +import pytest + +from pandas import ( + DataFrame, + Index, + Series, +) +import pandas._testing as tm + + +@pytest.mark.parametrize( + "in_vals, out_vals", + [ + # Basics: strictly increasing (T), strictly decreasing (F), + # abs val increasing (F), non-strictly increasing (T) + ([1, 2, 5, 3, 2, 0, 4, 5, -6, 1, 1], [True, False, False, True]), + # Test with inf vals + ( + [1, 2.1, np.inf, 3, 2, np.inf, -np.inf, 5, 11, 1, -np.inf], + [True, False, True, False], + ), + # Test with nan vals; should always be False + ( + [1, 2, np.nan, 3, 2, np.nan, np.nan, 5, -np.inf, 1, np.nan], + [False, False, False, False], + ), + ], +) +def test_is_monotonic_increasing(in_vals, out_vals): + # GH 17015 + source_dict = { + "A": ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11"], + "B": ["a", "a", "a", "b", "b", "b", "c", "c", "c", "d", "d"], + "C": in_vals, + } + df = DataFrame(source_dict) + result = df.groupby("B").C.is_monotonic_increasing + index = Index(list("abcd"), name="B") + expected = Series(index=index, data=out_vals, name="C") + tm.assert_series_equal(result, expected) + + # Also check result equal to manually taking x.is_monotonic_increasing. + expected = df.groupby(["B"]).C.apply(lambda x: x.is_monotonic_increasing) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "in_vals, out_vals", + [ + # Basics: strictly decreasing (T), strictly increasing (F), + # abs val decreasing (F), non-strictly increasing (T) + ([10, 9, 7, 3, 4, 5, -3, 2, 0, 1, 1], [True, False, False, True]), + # Test with inf vals + ( + [np.inf, 1, -np.inf, np.inf, 2, -3, -np.inf, 5, -3, -np.inf, -np.inf], + [True, True, False, True], + ), + # Test with nan vals; should always be False + ( + [1, 2, np.nan, 3, 2, np.nan, np.nan, 5, -np.inf, 1, np.nan], + [False, False, False, False], + ), + ], +) +def test_is_monotonic_decreasing(in_vals, out_vals): + # GH 17015 + source_dict = { + "A": ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11"], + "B": ["a", "a", "a", "b", "b", "b", "c", "c", "c", "d", "d"], + "C": in_vals, + } + + df = DataFrame(source_dict) + result = df.groupby("B").C.is_monotonic_decreasing + index = Index(list("abcd"), name="B") + expected = Series(index=index, data=out_vals, name="C") + tm.assert_series_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/methods/test_nlargest_nsmallest.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/methods/test_nlargest_nsmallest.py new file mode 100644 index 0000000000000000000000000000000000000000..bf983f04a3f3f17566299bafe756e95e2727f6ad --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/methods/test_nlargest_nsmallest.py @@ -0,0 +1,115 @@ +import numpy as np +import pytest + +from pandas import ( + MultiIndex, + Series, + date_range, +) +import pandas._testing as tm + + +def test_nlargest(): + a = Series([1, 3, 5, 7, 2, 9, 0, 4, 6, 10]) + b = Series(list("a" * 5 + "b" * 5)) + gb = a.groupby(b) + r = gb.nlargest(3) + e = Series( + [7, 5, 3, 10, 9, 6], + index=MultiIndex.from_arrays([list("aaabbb"), [3, 2, 1, 9, 5, 8]]), + ) + tm.assert_series_equal(r, e) + + a = Series([1, 1, 3, 2, 0, 3, 3, 2, 1, 0]) + gb = a.groupby(b) + e = Series( + [3, 2, 1, 3, 3, 2], + index=MultiIndex.from_arrays([list("aaabbb"), [2, 3, 1, 6, 5, 7]]), + ) + tm.assert_series_equal(gb.nlargest(3, keep="last"), e) + + +def test_nlargest_mi_grouper(): + # see gh-21411 + npr = np.random.default_rng(2) + + dts = date_range("20180101", periods=10) + iterables = [dts, ["one", "two"]] + + idx = MultiIndex.from_product(iterables, names=["first", "second"]) + s = Series(npr.standard_normal(20), index=idx) + + result = s.groupby("first").nlargest(1) + + exp_idx = MultiIndex.from_tuples( + [ + (dts[0], dts[0], "one"), + (dts[1], dts[1], "one"), + (dts[2], dts[2], "one"), + (dts[3], dts[3], "two"), + (dts[4], dts[4], "one"), + (dts[5], dts[5], "one"), + (dts[6], dts[6], "one"), + (dts[7], dts[7], "one"), + (dts[8], dts[8], "one"), + (dts[9], dts[9], "one"), + ], + names=["first", "first", "second"], + ) + + exp_values = [ + 0.18905338179353307, + -0.41306354339189344, + 1.799707382720902, + 0.7738065867276614, + 0.28121066979764925, + 0.9775674511260357, + -0.3288239040579627, + 0.45495807124085547, + 0.5452887139646817, + 0.12682784711186987, + ] + + expected = Series(exp_values, index=exp_idx) + tm.assert_series_equal(result, expected, check_exact=False, rtol=1e-3) + + +def test_nsmallest(): + a = Series([1, 3, 5, 7, 2, 9, 0, 4, 6, 10]) + b = Series(list("a" * 5 + "b" * 5)) + gb = a.groupby(b) + r = gb.nsmallest(3) + e = Series( + [1, 2, 3, 0, 4, 6], + index=MultiIndex.from_arrays([list("aaabbb"), [0, 4, 1, 6, 7, 8]]), + ) + tm.assert_series_equal(r, e) + + a = Series([1, 1, 3, 2, 0, 3, 3, 2, 1, 0]) + gb = a.groupby(b) + e = Series( + [0, 1, 1, 0, 1, 2], + index=MultiIndex.from_arrays([list("aaabbb"), [4, 1, 0, 9, 8, 7]]), + ) + tm.assert_series_equal(gb.nsmallest(3, keep="last"), e) + + +@pytest.mark.parametrize( + "data, groups", + [([0, 1, 2, 3], [0, 0, 1, 1]), ([0], [0])], +) +@pytest.mark.parametrize("dtype", [None, *tm.ALL_INT_NUMPY_DTYPES]) +@pytest.mark.parametrize("method", ["nlargest", "nsmallest"]) +def test_nlargest_and_smallest_noop(data, groups, dtype, method): + # GH 15272, GH 16345, GH 29129 + # Test nlargest/smallest when it results in a noop, + # i.e. input is sorted and group size <= n + if dtype is not None: + data = np.array(data, dtype=dtype) + if method == "nlargest": + data = list(reversed(data)) + ser = Series(data, name="a") + result = getattr(ser.groupby(groups), method)(n=2) + expidx = np.array(groups, dtype=int) if isinstance(groups, list) else groups + expected = Series(data, index=MultiIndex.from_arrays([expidx, ser.index]), name="a") + tm.assert_series_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/methods/test_nth.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/methods/test_nth.py new file mode 100644 index 0000000000000000000000000000000000000000..a8ed9e9d5202173b25b8dc47598e49672e0c8a31 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/methods/test_nth.py @@ -0,0 +1,921 @@ +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + DataFrame, + Index, + MultiIndex, + Series, + Timestamp, + isna, +) +import pandas._testing as tm + + +def test_first_last_nth(df): + # tests for first / last / nth + grouped = df.groupby("A") + first = grouped.first() + expected = df.loc[[1, 0], ["B", "C", "D"]] + expected.index = Index(["bar", "foo"], name="A") + expected = expected.sort_index() + tm.assert_frame_equal(first, expected) + + nth = grouped.nth(0) + expected = df.loc[[0, 1]] + tm.assert_frame_equal(nth, expected) + + last = grouped.last() + expected = df.loc[[5, 7], ["B", "C", "D"]] + expected.index = Index(["bar", "foo"], name="A") + tm.assert_frame_equal(last, expected) + + nth = grouped.nth(-1) + expected = df.iloc[[5, 7]] + tm.assert_frame_equal(nth, expected) + + nth = grouped.nth(1) + expected = df.iloc[[2, 3]] + tm.assert_frame_equal(nth, expected) + + # it works! + grouped["B"].first() + grouped["B"].last() + grouped["B"].nth(0) + + df = df.copy() + df.loc[df["A"] == "foo", "B"] = np.nan + grouped = df.groupby("A") + assert isna(grouped["B"].first()["foo"]) + assert isna(grouped["B"].last()["foo"]) + assert isna(grouped["B"].nth(0).iloc[0]) + + # v0.14.0 whatsnew + df = DataFrame([[1, np.nan], [1, 4], [5, 6]], columns=["A", "B"]) + g = df.groupby("A") + result = g.first() + expected = df.iloc[[1, 2]].set_index("A") + tm.assert_frame_equal(result, expected) + + expected = df.iloc[[1, 2]] + result = g.nth(0, dropna="any") + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("method", ["first", "last"]) +def test_first_last_with_na_object(method, nulls_fixture): + # https://github.com/pandas-dev/pandas/issues/32123 + groups = DataFrame({"a": [1, 1, 2, 2], "b": [1, 2, 3, nulls_fixture]}).groupby("a") + result = getattr(groups, method)() + + if method == "first": + values = [1, 3] + else: + values = [2, 3] + + values = np.array(values, dtype=result["b"].dtype) + idx = Index([1, 2], name="a") + expected = DataFrame({"b": values}, index=idx) + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("index", [0, -1]) +def test_nth_with_na_object(index, nulls_fixture): + # https://github.com/pandas-dev/pandas/issues/32123 + df = DataFrame({"a": [1, 1, 2, 2], "b": [1, 2, 3, nulls_fixture]}) + groups = df.groupby("a") + result = groups.nth(index) + expected = df.iloc[[0, 2]] if index == 0 else df.iloc[[1, 3]] + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("method", ["first", "last"]) +def test_first_last_with_None(method): + # https://github.com/pandas-dev/pandas/issues/32800 + # None should be preserved as object dtype + df = DataFrame.from_dict({"id": ["a"], "value": [None]}) + groups = df.groupby("id", as_index=False) + result = getattr(groups, method)() + + tm.assert_frame_equal(result, df) + + +@pytest.mark.parametrize("method", ["first", "last"]) +@pytest.mark.parametrize( + "df, expected", + [ + ( + DataFrame({"id": "a", "value": [None, "foo", np.nan]}), + DataFrame({"value": ["foo"]}, index=Index(["a"], name="id")), + ), + ( + DataFrame({"id": "a", "value": [np.nan]}, dtype=object), + DataFrame({"value": [None]}, index=Index(["a"], name="id")), + ), + ], +) +def test_first_last_with_None_expanded(method, df, expected): + # GH 32800, 38286 + result = getattr(df.groupby("id"), method)() + tm.assert_frame_equal(result, expected) + + +def test_first_last_nth_dtypes(): + df = DataFrame( + { + "A": ["foo", "bar", "foo", "bar", "foo", "bar", "foo", "foo"], + "B": ["one", "one", "two", "three", "two", "two", "one", "three"], + "C": np.random.default_rng(2).standard_normal(8), + "D": np.array(np.random.default_rng(2).standard_normal(8), dtype="float32"), + } + ) + df["E"] = True + df["F"] = 1 + + # tests for first / last / nth + grouped = df.groupby("A") + first = grouped.first() + expected = df.loc[[1, 0], ["B", "C", "D", "E", "F"]] + expected.index = Index(["bar", "foo"], name="A") + expected = expected.sort_index() + tm.assert_frame_equal(first, expected) + + last = grouped.last() + expected = df.loc[[5, 7], ["B", "C", "D", "E", "F"]] + expected.index = Index(["bar", "foo"], name="A") + expected = expected.sort_index() + tm.assert_frame_equal(last, expected) + + nth = grouped.nth(1) + expected = df.iloc[[2, 3]] + tm.assert_frame_equal(nth, expected) + + +def test_first_last_nth_dtypes2(): + # GH 2763, first/last shifting dtypes + idx = list(range(10)) + idx.append(9) + ser = Series(data=range(11), index=idx, name="IntCol") + assert ser.dtype == "int64" + f = ser.groupby(level=0).first() + assert f.dtype == "int64" + + +def test_first_last_nth_nan_dtype(): + # GH 33591 + df = DataFrame({"data": ["A"], "nans": Series([None], dtype=object)}) + grouped = df.groupby("data") + + expected = df.set_index("data").nans + tm.assert_series_equal(grouped.nans.first(), expected) + tm.assert_series_equal(grouped.nans.last(), expected) + + expected = df.nans + tm.assert_series_equal(grouped.nans.nth(-1), expected) + tm.assert_series_equal(grouped.nans.nth(0), expected) + + +def test_first_strings_timestamps(): + # GH 11244 + test = DataFrame( + { + Timestamp("2012-01-01 00:00:00"): ["a", "b"], + Timestamp("2012-01-02 00:00:00"): ["c", "d"], + "name": ["e", "e"], + "aaaa": ["f", "g"], + } + ) + result = test.groupby("name").first() + expected = DataFrame( + [["a", "c", "f"]], + columns=Index([Timestamp("2012-01-01"), Timestamp("2012-01-02"), "aaaa"]), + index=Index(["e"], name="name"), + ) + tm.assert_frame_equal(result, expected) + + +def test_nth(): + df = DataFrame([[1, np.nan], [1, 4], [5, 6]], columns=["A", "B"]) + gb = df.groupby("A") + + tm.assert_frame_equal(gb.nth(0), df.iloc[[0, 2]]) + tm.assert_frame_equal(gb.nth(1), df.iloc[[1]]) + tm.assert_frame_equal(gb.nth(2), df.loc[[]]) + tm.assert_frame_equal(gb.nth(-1), df.iloc[[1, 2]]) + tm.assert_frame_equal(gb.nth(-2), df.iloc[[0]]) + tm.assert_frame_equal(gb.nth(-3), df.loc[[]]) + tm.assert_series_equal(gb.B.nth(0), df.B.iloc[[0, 2]]) + tm.assert_series_equal(gb.B.nth(1), df.B.iloc[[1]]) + tm.assert_frame_equal(gb[["B"]].nth(0), df[["B"]].iloc[[0, 2]]) + + tm.assert_frame_equal(gb.nth(0, dropna="any"), df.iloc[[1, 2]]) + tm.assert_frame_equal(gb.nth(-1, dropna="any"), df.iloc[[1, 2]]) + + tm.assert_frame_equal(gb.nth(7, dropna="any"), df.iloc[:0]) + tm.assert_frame_equal(gb.nth(2, dropna="any"), df.iloc[:0]) + + +def test_nth2(): + # out of bounds, regression from 0.13.1 + # GH 6621 + df = DataFrame( + { + "color": {0: "green", 1: "green", 2: "red", 3: "red", 4: "red"}, + "food": {0: "ham", 1: "eggs", 2: "eggs", 3: "ham", 4: "pork"}, + "two": { + 0: 1.5456590000000001, + 1: -0.070345000000000005, + 2: -2.4004539999999999, + 3: 0.46206000000000003, + 4: 0.52350799999999997, + }, + "one": { + 0: 0.56573799999999996, + 1: -0.9742360000000001, + 2: 1.033801, + 3: -0.78543499999999999, + 4: 0.70422799999999997, + }, + } + ).set_index(["color", "food"]) + + result = df.groupby(level=0, as_index=False).nth(2) + expected = df.iloc[[-1]] + tm.assert_frame_equal(result, expected) + + result = df.groupby(level=0, as_index=False).nth(3) + expected = df.loc[[]] + tm.assert_frame_equal(result, expected) + + +def test_nth3(): + # GH 7559 + # from the vbench + df = DataFrame(np.random.default_rng(2).integers(1, 10, (100, 2)), dtype="int64") + ser = df[1] + gb = df[0] + expected = ser.groupby(gb).first() + expected2 = ser.groupby(gb).apply(lambda x: x.iloc[0]) + tm.assert_series_equal(expected2, expected, check_names=False) + assert expected.name == 1 + assert expected2.name == 1 + + # validate first + v = ser[gb == 1].iloc[0] + assert expected.iloc[0] == v + assert expected2.iloc[0] == v + + with pytest.raises(ValueError, match="For a DataFrame"): + ser.groupby(gb, sort=False).nth(0, dropna=True) + + +def test_nth4(): + # doc example + df = DataFrame([[1, np.nan], [1, 4], [5, 6]], columns=["A", "B"]) + gb = df.groupby("A") + result = gb.B.nth(0, dropna="all") + expected = df.B.iloc[[1, 2]] + tm.assert_series_equal(result, expected) + + +def test_nth5(): + # test multiple nth values + df = DataFrame([[1, np.nan], [1, 3], [1, 4], [5, 6], [5, 7]], columns=["A", "B"]) + gb = df.groupby("A") + + tm.assert_frame_equal(gb.nth(0), df.iloc[[0, 3]]) + tm.assert_frame_equal(gb.nth([0]), df.iloc[[0, 3]]) + tm.assert_frame_equal(gb.nth([0, 1]), df.iloc[[0, 1, 3, 4]]) + tm.assert_frame_equal(gb.nth([0, -1]), df.iloc[[0, 2, 3, 4]]) + tm.assert_frame_equal(gb.nth([0, 1, 2]), df.iloc[[0, 1, 2, 3, 4]]) + tm.assert_frame_equal(gb.nth([0, 1, -1]), df.iloc[[0, 1, 2, 3, 4]]) + tm.assert_frame_equal(gb.nth([2]), df.iloc[[2]]) + tm.assert_frame_equal(gb.nth([3, 4]), df.loc[[]]) + + +def test_nth_bdays(unit): + business_dates = pd.date_range( + start="4/1/2014", end="6/30/2014", freq="B", unit=unit + ) + df = DataFrame(1, index=business_dates, columns=["a", "b"]) + # get the first, fourth and last two business days for each month + key = [df.index.year, df.index.month] + result = df.groupby(key, as_index=False).nth([0, 3, -2, -1]) + expected_dates = pd.to_datetime( + [ + "2014/4/1", + "2014/4/4", + "2014/4/29", + "2014/4/30", + "2014/5/1", + "2014/5/6", + "2014/5/29", + "2014/5/30", + "2014/6/2", + "2014/6/5", + "2014/6/27", + "2014/6/30", + ] + ).as_unit(unit) + expected = DataFrame(1, columns=["a", "b"], index=expected_dates) + tm.assert_frame_equal(result, expected) + + +def test_nth_multi_grouper(three_group): + # PR 9090, related to issue 8979 + # test nth on multiple groupers + grouped = three_group.groupby(["A", "B"]) + result = grouped.nth(0) + expected = three_group.iloc[[0, 3, 4, 7]] + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "data, expected_first, expected_last", + [ + ( + { + "id": ["A"], + "time": Timestamp("2012-02-01 14:00:00", tz="US/Central"), + "foo": [1], + }, + { + "id": ["A"], + "time": Timestamp("2012-02-01 14:00:00", tz="US/Central"), + "foo": [1], + }, + { + "id": ["A"], + "time": Timestamp("2012-02-01 14:00:00", tz="US/Central"), + "foo": [1], + }, + ), + ( + { + "id": ["A", "B", "A"], + "time": [ + Timestamp("2012-01-01 13:00:00", tz="America/New_York"), + Timestamp("2012-02-01 14:00:00", tz="US/Central"), + Timestamp("2012-03-01 12:00:00", tz="Europe/London"), + ], + "foo": [1, 2, 3], + }, + { + "id": ["A", "B"], + "time": [ + Timestamp("2012-01-01 13:00:00", tz="America/New_York"), + Timestamp("2012-02-01 14:00:00", tz="US/Central"), + ], + "foo": [1, 2], + }, + { + "id": ["A", "B"], + "time": [ + Timestamp("2012-03-01 12:00:00", tz="Europe/London"), + Timestamp("2012-02-01 14:00:00", tz="US/Central"), + ], + "foo": [3, 2], + }, + ), + ], +) +def test_first_last_tz(data, expected_first, expected_last): + # GH15884 + # Test that the timezone is retained when calling first + # or last on groupby with as_index=False + + df = DataFrame(data) + + result = df.groupby("id", as_index=False).first() + expected = DataFrame(expected_first) + cols = ["id", "time", "foo"] + tm.assert_frame_equal(result[cols], expected[cols]) + + result = df.groupby("id", as_index=False)["time"].first() + tm.assert_frame_equal(result, expected[["id", "time"]]) + + result = df.groupby("id", as_index=False).last() + expected = DataFrame(expected_last) + cols = ["id", "time", "foo"] + tm.assert_frame_equal(result[cols], expected[cols]) + + result = df.groupby("id", as_index=False)["time"].last() + tm.assert_frame_equal(result, expected[["id", "time"]]) + + +@pytest.mark.parametrize( + "method, ts, alpha", + [ + ["first", Timestamp("2013-01-01", tz="US/Eastern"), "a"], + ["last", Timestamp("2013-01-02", tz="US/Eastern"), "b"], + ], +) +def test_first_last_tz_multi_column(method, ts, alpha, unit): + # GH 21603 + category_string = Series(list("abc")).astype("category") + dti = pd.date_range("20130101", periods=3, tz="US/Eastern", unit=unit) + df = DataFrame( + { + "group": [1, 1, 2], + "category_string": category_string, + "datetimetz": dti, + } + ) + result = getattr(df.groupby("group"), method)() + expected = DataFrame( + { + "category_string": pd.Categorical( + [alpha, "c"], dtype=category_string.dtype + ), + "datetimetz": [ts, Timestamp("2013-01-03", tz="US/Eastern")], + }, + index=Index([1, 2], name="group"), + ) + expected["datetimetz"] = expected["datetimetz"].dt.as_unit(unit) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "values", + [ + pd.array([True, False], dtype="boolean"), + pd.array([1, 2], dtype="Int64"), + pd.to_datetime(["2020-01-01", "2020-02-01"]), + pd.to_timedelta([1, 2], unit="D"), + ], +) +@pytest.mark.parametrize("function", ["first", "last", "min", "max"]) +def test_first_last_extension_array_keeps_dtype(values, function): + # https://github.com/pandas-dev/pandas/issues/33071 + # https://github.com/pandas-dev/pandas/issues/32194 + df = DataFrame({"a": [1, 2], "b": values}) + grouped = df.groupby("a") + idx = Index([1, 2], name="a") + expected_series = Series(values, name="b", index=idx) + expected_frame = DataFrame({"b": values}, index=idx) + + result_series = getattr(grouped["b"], function)() + tm.assert_series_equal(result_series, expected_series) + + result_frame = grouped.agg({"b": function}) + tm.assert_frame_equal(result_frame, expected_frame) + + +def test_nth_multi_index_as_expected(): + # PR 9090, related to issue 8979 + # test nth on MultiIndex + three_group = DataFrame( + { + "A": [ + "foo", + "foo", + "foo", + "foo", + "bar", + "bar", + "bar", + "bar", + "foo", + "foo", + "foo", + ], + "B": [ + "one", + "one", + "one", + "two", + "one", + "one", + "one", + "two", + "two", + "two", + "one", + ], + "C": [ + "dull", + "dull", + "shiny", + "dull", + "dull", + "shiny", + "shiny", + "dull", + "shiny", + "shiny", + "shiny", + ], + } + ) + grouped = three_group.groupby(["A", "B"]) + result = grouped.nth(0) + expected = three_group.iloc[[0, 3, 4, 7]] + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "op, n, expected_rows", + [ + ("head", -1, [0]), + ("head", 0, []), + ("head", 1, [0, 2]), + ("head", 7, [0, 1, 2]), + ("tail", -1, [1]), + ("tail", 0, []), + ("tail", 1, [1, 2]), + ("tail", 7, [0, 1, 2]), + ], +) +@pytest.mark.parametrize("columns", [None, [], ["A"], ["B"], ["A", "B"]]) +@pytest.mark.parametrize("as_index", [True, False]) +def test_groupby_head_tail(op, n, expected_rows, columns, as_index): + df = DataFrame([[1, 2], [1, 4], [5, 6]], columns=["A", "B"]) + g = df.groupby("A", as_index=as_index) + expected = df.iloc[expected_rows] + if columns is not None: + g = g[columns] + expected = expected[columns] + result = getattr(g, op)(n) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "op, n, expected_cols", + [ + ("head", -1, [0]), + ("head", 0, []), + ("head", 1, [0, 2]), + ("head", 7, [0, 1, 2]), + ("tail", -1, [1]), + ("tail", 0, []), + ("tail", 1, [1, 2]), + ("tail", 7, [0, 1, 2]), + ], +) +def test_groupby_head_tail_axis_1(op, n, expected_cols): + # GH 9772 + df = DataFrame( + [[1, 2, 3], [1, 4, 5], [2, 6, 7], [3, 8, 9]], columns=["A", "B", "C"] + ) + msg = "DataFrame.groupby with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + g = df.groupby([0, 0, 1], axis=1) + expected = df.iloc[:, expected_cols] + result = getattr(g, op)(n) + tm.assert_frame_equal(result, expected) + + +def test_group_selection_cache(): + # GH 12839 nth, head, and tail should return same result consistently + df = DataFrame([[1, 2], [1, 4], [5, 6]], columns=["A", "B"]) + expected = df.iloc[[0, 2]] + + g = df.groupby("A") + result1 = g.head(n=2) + result2 = g.nth(0) + tm.assert_frame_equal(result1, df) + tm.assert_frame_equal(result2, expected) + + g = df.groupby("A") + result1 = g.tail(n=2) + result2 = g.nth(0) + tm.assert_frame_equal(result1, df) + tm.assert_frame_equal(result2, expected) + + g = df.groupby("A") + result1 = g.nth(0) + result2 = g.head(n=2) + tm.assert_frame_equal(result1, expected) + tm.assert_frame_equal(result2, df) + + g = df.groupby("A") + result1 = g.nth(0) + result2 = g.tail(n=2) + tm.assert_frame_equal(result1, expected) + tm.assert_frame_equal(result2, df) + + +def test_nth_empty(): + # GH 16064 + df = DataFrame(index=[0], columns=["a", "b", "c"]) + result = df.groupby("a").nth(10) + expected = df.iloc[:0] + tm.assert_frame_equal(result, expected) + + result = df.groupby(["a", "b"]).nth(10) + expected = df.iloc[:0] + tm.assert_frame_equal(result, expected) + + +def test_nth_column_order(): + # GH 20760 + # Check that nth preserves column order + df = DataFrame( + [[1, "b", 100], [1, "a", 50], [1, "a", np.nan], [2, "c", 200], [2, "d", 150]], + columns=["A", "C", "B"], + ) + result = df.groupby("A").nth(0) + expected = df.iloc[[0, 3]] + tm.assert_frame_equal(result, expected) + + result = df.groupby("A").nth(-1, dropna="any") + expected = df.iloc[[1, 4]] + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("dropna", [None, "any", "all"]) +def test_nth_nan_in_grouper(dropna): + # GH 26011 + df = DataFrame( + { + "a": [np.nan, "a", np.nan, "b", np.nan], + "b": [0, 2, 4, 6, 8], + "c": [1, 3, 5, 7, 9], + } + ) + result = df.groupby("a").nth(0, dropna=dropna) + expected = df.iloc[[1, 3]] + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("dropna", [None, "any", "all"]) +def test_nth_nan_in_grouper_series(dropna): + # GH 26454 + df = DataFrame( + { + "a": [np.nan, "a", np.nan, "b", np.nan], + "b": [0, 2, 4, 6, 8], + } + ) + result = df.groupby("a")["b"].nth(0, dropna=dropna) + expected = df["b"].iloc[[1, 3]] + + tm.assert_series_equal(result, expected) + + +def test_first_categorical_and_datetime_data_nat(): + # GH 20520 + df = DataFrame( + { + "group": ["first", "first", "second", "third", "third"], + "time": 5 * [np.datetime64("NaT")], + "categories": Series(["a", "b", "c", "a", "b"], dtype="category"), + } + ) + result = df.groupby("group").first() + expected = DataFrame( + { + "time": 3 * [np.datetime64("NaT")], + "categories": Series(["a", "c", "a"]).astype( + pd.CategoricalDtype(["a", "b", "c"]) + ), + } + ) + expected.index = Index(["first", "second", "third"], name="group") + tm.assert_frame_equal(result, expected) + + +def test_first_multi_key_groupby_categorical(): + # GH 22512 + df = DataFrame( + { + "A": [1, 1, 1, 2, 2], + "B": [100, 100, 200, 100, 100], + "C": ["apple", "orange", "mango", "mango", "orange"], + "D": ["jupiter", "mercury", "mars", "venus", "venus"], + } + ) + df = df.astype({"D": "category"}) + result = df.groupby(by=["A", "B"]).first() + expected = DataFrame( + { + "C": ["apple", "mango", "mango"], + "D": Series(["jupiter", "mars", "venus"]).astype( + pd.CategoricalDtype(["jupiter", "mars", "mercury", "venus"]) + ), + } + ) + expected.index = MultiIndex.from_tuples( + [(1, 100), (1, 200), (2, 100)], names=["A", "B"] + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("method", ["first", "last", "nth"]) +def test_groupby_last_first_nth_with_none(method, nulls_fixture): + # GH29645 + expected = Series(["y"]) + data = Series( + [nulls_fixture, nulls_fixture, nulls_fixture, "y", nulls_fixture], + index=[0, 0, 0, 0, 0], + ).groupby(level=0) + + if method == "nth": + result = getattr(data, method)(3) + else: + result = getattr(data, method)() + + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "arg, expected_rows", + [ + [slice(None, 3, 2), [0, 1, 4, 5]], + [slice(None, -2), [0, 2, 5]], + [[slice(None, 2), slice(-2, None)], [0, 1, 2, 3, 4, 6, 7]], + [[0, 1, slice(-2, None)], [0, 1, 2, 3, 4, 6, 7]], + ], +) +def test_slice(slice_test_df, slice_test_grouped, arg, expected_rows): + # Test slices GH #42947 + + result = slice_test_grouped.nth[arg] + equivalent = slice_test_grouped.nth(arg) + expected = slice_test_df.iloc[expected_rows] + + tm.assert_frame_equal(result, expected) + tm.assert_frame_equal(equivalent, expected) + + +def test_nth_indexed(slice_test_df, slice_test_grouped): + # Test index notation GH #44688 + + result = slice_test_grouped.nth[0, 1, -2:] + equivalent = slice_test_grouped.nth([0, 1, slice(-2, None)]) + expected = slice_test_df.iloc[[0, 1, 2, 3, 4, 6, 7]] + + tm.assert_frame_equal(result, expected) + tm.assert_frame_equal(equivalent, expected) + + +def test_invalid_argument(slice_test_grouped): + # Test for error on invalid argument + + with pytest.raises(TypeError, match="Invalid index"): + slice_test_grouped.nth(3.14) + + +def test_negative_step(slice_test_grouped): + # Test for error on negative slice step + + with pytest.raises(ValueError, match="Invalid step"): + slice_test_grouped.nth(slice(None, None, -1)) + + +def test_np_ints(slice_test_df, slice_test_grouped): + # Test np ints work + + result = slice_test_grouped.nth(np.array([0, 1])) + expected = slice_test_df.iloc[[0, 1, 2, 3, 4]] + tm.assert_frame_equal(result, expected) + + +def test_groupby_nth_with_column_axis(): + # GH43926 + df = DataFrame( + [ + [4, 5, 6], + [8, 8, 7], + ], + index=["z", "y"], + columns=["C", "B", "A"], + ) + msg = "DataFrame.groupby with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + gb = df.groupby(df.iloc[1], axis=1) + result = gb.nth(0) + expected = df.iloc[:, [0, 2]] + tm.assert_frame_equal(result, expected) + + +def test_groupby_nth_interval(): + # GH#24205 + idx_result = MultiIndex( + [ + pd.CategoricalIndex([pd.Interval(0, 1), pd.Interval(1, 2)]), + pd.CategoricalIndex([pd.Interval(0, 10), pd.Interval(10, 20)]), + ], + [[0, 0, 0, 1, 1], [0, 1, 1, 0, -1]], + ) + df_result = DataFrame({"col": range(len(idx_result))}, index=idx_result) + result = df_result.groupby(level=[0, 1], observed=False).nth(0) + val_expected = [0, 1, 3] + idx_expected = MultiIndex( + [ + pd.CategoricalIndex([pd.Interval(0, 1), pd.Interval(1, 2)]), + pd.CategoricalIndex([pd.Interval(0, 10), pd.Interval(10, 20)]), + ], + [[0, 0, 1], [0, 1, 0]], + ) + expected = DataFrame(val_expected, index=idx_expected, columns=["col"]) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "start, stop, expected_values, expected_columns", + [ + (None, None, [0, 1, 2, 3, 4], list("ABCDE")), + (None, 1, [0, 3], list("AD")), + (None, 9, [0, 1, 2, 3, 4], list("ABCDE")), + (None, -1, [0, 1, 3], list("ABD")), + (1, None, [1, 2, 4], list("BCE")), + (1, -1, [1], list("B")), + (-1, None, [2, 4], list("CE")), + (-1, 2, [4], list("E")), + ], +) +@pytest.mark.parametrize("method", ["call", "index"]) +def test_nth_slices_with_column_axis( + start, stop, expected_values, expected_columns, method +): + df = DataFrame([range(5)], columns=[list("ABCDE")]) + msg = "DataFrame.groupby with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + gb = df.groupby([5, 5, 5, 6, 6], axis=1) + result = { + "call": lambda start, stop: gb.nth(slice(start, stop)), + "index": lambda start, stop: gb.nth[start:stop], + }[method](start, stop) + expected = DataFrame([expected_values], columns=[expected_columns]) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.filterwarnings( + "ignore:invalid value encountered in remainder:RuntimeWarning" +) +def test_head_tail_dropna_true(): + # GH#45089 + df = DataFrame( + [["a", "z"], ["b", np.nan], ["c", np.nan], ["c", np.nan]], columns=["X", "Y"] + ) + expected = DataFrame([["a", "z"]], columns=["X", "Y"]) + + result = df.groupby(["X", "Y"]).head(n=1) + tm.assert_frame_equal(result, expected) + + result = df.groupby(["X", "Y"]).tail(n=1) + tm.assert_frame_equal(result, expected) + + result = df.groupby(["X", "Y"]).nth(n=0) + tm.assert_frame_equal(result, expected) + + +def test_head_tail_dropna_false(): + # GH#45089 + df = DataFrame([["a", "z"], ["b", np.nan], ["c", np.nan]], columns=["X", "Y"]) + expected = DataFrame([["a", "z"], ["b", np.nan], ["c", np.nan]], columns=["X", "Y"]) + + result = df.groupby(["X", "Y"], dropna=False).head(n=1) + tm.assert_frame_equal(result, expected) + + result = df.groupby(["X", "Y"], dropna=False).tail(n=1) + tm.assert_frame_equal(result, expected) + + result = df.groupby(["X", "Y"], dropna=False).nth(n=0) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("selection", ("b", ["b"], ["b", "c"])) +@pytest.mark.parametrize("dropna", ["any", "all", None]) +def test_nth_after_selection(selection, dropna): + # GH#11038, GH#53518 + df = DataFrame( + { + "a": [1, 1, 2], + "b": [np.nan, 3, 4], + "c": [5, 6, 7], + } + ) + gb = df.groupby("a")[selection] + result = gb.nth(0, dropna=dropna) + if dropna == "any" or (dropna == "all" and selection != ["b", "c"]): + locs = [1, 2] + else: + locs = [0, 2] + expected = df.loc[locs, selection] + tm.assert_equal(result, expected) + + +@pytest.mark.parametrize( + "data", + [ + ( + Timestamp("2011-01-15 12:50:28.502376"), + Timestamp("2011-01-20 12:50:28.593448"), + ), + (24650000000000001, 24650000000000002), + ], +) +def test_groupby_nth_int_like_precision(data): + # GH#6620, GH#9311 + df = DataFrame({"a": [1, 1], "b": data}) + + grouped = df.groupby("a") + result = grouped.nth(0) + expected = DataFrame({"a": 1, "b": [data[0]]}) + + tm.assert_frame_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/methods/test_quantile.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/methods/test_quantile.py new file mode 100644 index 0000000000000000000000000000000000000000..361a8c27fbf9d6744a11d56cf228f06a53d4adaf --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/methods/test_quantile.py @@ -0,0 +1,496 @@ +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + DataFrame, + Index, +) +import pandas._testing as tm + + +@pytest.mark.parametrize( + "interpolation", ["linear", "lower", "higher", "nearest", "midpoint"] +) +@pytest.mark.parametrize( + "a_vals,b_vals", + [ + # Ints + ([1, 2, 3, 4, 5], [5, 4, 3, 2, 1]), + ([1, 2, 3, 4], [4, 3, 2, 1]), + ([1, 2, 3, 4, 5], [4, 3, 2, 1]), + # Floats + ([1.0, 2.0, 3.0, 4.0, 5.0], [5.0, 4.0, 3.0, 2.0, 1.0]), + # Missing data + ([1.0, np.nan, 3.0, np.nan, 5.0], [5.0, np.nan, 3.0, np.nan, 1.0]), + ([np.nan, 4.0, np.nan, 2.0, np.nan], [np.nan, 4.0, np.nan, 2.0, np.nan]), + # Timestamps + ( + pd.date_range("1/1/18", freq="D", periods=5), + pd.date_range("1/1/18", freq="D", periods=5)[::-1], + ), + ( + pd.date_range("1/1/18", freq="D", periods=5).as_unit("s"), + pd.date_range("1/1/18", freq="D", periods=5)[::-1].as_unit("s"), + ), + # All NA + ([np.nan] * 5, [np.nan] * 5), + ], +) +@pytest.mark.parametrize("q", [0, 0.25, 0.5, 0.75, 1]) +def test_quantile(interpolation, a_vals, b_vals, q, request): + if ( + interpolation == "nearest" + and q == 0.5 + and isinstance(b_vals, list) + and b_vals == [4, 3, 2, 1] + ): + request.applymarker( + pytest.mark.xfail( + reason="Unclear numpy expectation for nearest " + "result with equidistant data" + ) + ) + all_vals = pd.concat([pd.Series(a_vals), pd.Series(b_vals)]) + + a_expected = pd.Series(a_vals).quantile(q, interpolation=interpolation) + b_expected = pd.Series(b_vals).quantile(q, interpolation=interpolation) + + df = DataFrame({"key": ["a"] * len(a_vals) + ["b"] * len(b_vals), "val": all_vals}) + + expected = DataFrame( + [a_expected, b_expected], columns=["val"], index=Index(["a", "b"], name="key") + ) + if all_vals.dtype.kind == "M" and expected.dtypes.values[0].kind == "M": + # TODO(non-nano): this should be unnecessary once array_to_datetime + # correctly infers non-nano from Timestamp.unit + expected = expected.astype(all_vals.dtype) + result = df.groupby("key").quantile(q, interpolation=interpolation) + + tm.assert_frame_equal(result, expected) + + +def test_quantile_array(): + # https://github.com/pandas-dev/pandas/issues/27526 + df = DataFrame({"A": [0, 1, 2, 3, 4]}) + key = np.array([0, 0, 1, 1, 1], dtype=np.int64) + result = df.groupby(key).quantile([0.25]) + + index = pd.MultiIndex.from_product([[0, 1], [0.25]]) + expected = DataFrame({"A": [0.25, 2.50]}, index=index) + tm.assert_frame_equal(result, expected) + + df = DataFrame({"A": [0, 1, 2, 3], "B": [4, 5, 6, 7]}) + index = pd.MultiIndex.from_product([[0, 1], [0.25, 0.75]]) + + key = np.array([0, 0, 1, 1], dtype=np.int64) + result = df.groupby(key).quantile([0.25, 0.75]) + expected = DataFrame( + {"A": [0.25, 0.75, 2.25, 2.75], "B": [4.25, 4.75, 6.25, 6.75]}, index=index + ) + tm.assert_frame_equal(result, expected) + + +def test_quantile_array2(): + # https://github.com/pandas-dev/pandas/pull/28085#issuecomment-524066959 + arr = np.random.default_rng(2).integers(0, 5, size=(10, 3), dtype=np.int64) + df = DataFrame(arr, columns=list("ABC")) + result = df.groupby("A").quantile([0.3, 0.7]) + expected = DataFrame( + { + "B": [2.0, 2.0, 2.3, 2.7, 0.3, 0.7, 3.2, 4.0, 0.3, 0.7], + "C": [1.0, 1.0, 1.9, 3.0999999999999996, 0.3, 0.7, 2.6, 3.0, 1.2, 2.8], + }, + index=pd.MultiIndex.from_product( + [[0, 1, 2, 3, 4], [0.3, 0.7]], names=["A", None] + ), + ) + tm.assert_frame_equal(result, expected) + + +def test_quantile_array_no_sort(): + df = DataFrame({"A": [0, 1, 2], "B": [3, 4, 5]}) + key = np.array([1, 0, 1], dtype=np.int64) + result = df.groupby(key, sort=False).quantile([0.25, 0.5, 0.75]) + expected = DataFrame( + {"A": [0.5, 1.0, 1.5, 1.0, 1.0, 1.0], "B": [3.5, 4.0, 4.5, 4.0, 4.0, 4.0]}, + index=pd.MultiIndex.from_product([[1, 0], [0.25, 0.5, 0.75]]), + ) + tm.assert_frame_equal(result, expected) + + result = df.groupby(key, sort=False).quantile([0.75, 0.25]) + expected = DataFrame( + {"A": [1.5, 0.5, 1.0, 1.0], "B": [4.5, 3.5, 4.0, 4.0]}, + index=pd.MultiIndex.from_product([[1, 0], [0.75, 0.25]]), + ) + tm.assert_frame_equal(result, expected) + + +def test_quantile_array_multiple_levels(): + df = DataFrame( + {"A": [0, 1, 2], "B": [3, 4, 5], "c": ["a", "a", "a"], "d": ["a", "a", "b"]} + ) + result = df.groupby(["c", "d"]).quantile([0.25, 0.75]) + index = pd.MultiIndex.from_tuples( + [("a", "a", 0.25), ("a", "a", 0.75), ("a", "b", 0.25), ("a", "b", 0.75)], + names=["c", "d", None], + ) + expected = DataFrame( + {"A": [0.25, 0.75, 2.0, 2.0], "B": [3.25, 3.75, 5.0, 5.0]}, index=index + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("frame_size", [(2, 3), (100, 10)]) +@pytest.mark.parametrize("groupby", [[0], [0, 1]]) +@pytest.mark.parametrize("q", [[0.5, 0.6]]) +def test_groupby_quantile_with_arraylike_q_and_int_columns(frame_size, groupby, q): + # GH30289 + nrow, ncol = frame_size + df = DataFrame(np.array([ncol * [_ % 4] for _ in range(nrow)]), columns=range(ncol)) + + idx_levels = [np.arange(min(nrow, 4))] * len(groupby) + [q] + idx_codes = [[x for x in range(min(nrow, 4)) for _ in q]] * len(groupby) + [ + list(range(len(q))) * min(nrow, 4) + ] + expected_index = pd.MultiIndex( + levels=idx_levels, codes=idx_codes, names=groupby + [None] + ) + expected_values = [ + [float(x)] * (ncol - len(groupby)) for x in range(min(nrow, 4)) for _ in q + ] + expected_columns = [x for x in range(ncol) if x not in groupby] + expected = DataFrame( + expected_values, index=expected_index, columns=expected_columns + ) + result = df.groupby(groupby).quantile(q) + + tm.assert_frame_equal(result, expected) + + +def test_quantile_raises(): + df = DataFrame([["foo", "a"], ["foo", "b"], ["foo", "c"]], columns=["key", "val"]) + + with pytest.raises(TypeError, match="cannot be performed against 'object' dtypes"): + df.groupby("key").quantile() + + +def test_quantile_out_of_bounds_q_raises(): + # https://github.com/pandas-dev/pandas/issues/27470 + df = DataFrame({"a": [0, 0, 0, 1, 1, 1], "b": range(6)}) + g = df.groupby([0, 0, 0, 1, 1, 1]) + with pytest.raises(ValueError, match="Got '50.0' instead"): + g.quantile(50) + + with pytest.raises(ValueError, match="Got '-1.0' instead"): + g.quantile(-1) + + +def test_quantile_missing_group_values_no_segfaults(): + # GH 28662 + data = np.array([1.0, np.nan, 1.0]) + df = DataFrame({"key": data, "val": range(3)}) + + # Random segfaults; would have been guaranteed in loop + grp = df.groupby("key") + for _ in range(100): + grp.quantile() + + +@pytest.mark.parametrize( + "key, val, expected_key, expected_val", + [ + ([1.0, np.nan, 3.0, np.nan], range(4), [1.0, 3.0], [0.0, 2.0]), + ([1.0, np.nan, 2.0, 2.0], range(4), [1.0, 2.0], [0.0, 2.5]), + (["a", "b", "b", np.nan], range(4), ["a", "b"], [0, 1.5]), + ([0], [42], [0], [42.0]), + ([], [], np.array([], dtype="float64"), np.array([], dtype="float64")), + ], +) +def test_quantile_missing_group_values_correct_results( + key, val, expected_key, expected_val +): + # GH 28662, GH 33200, GH 33569 + df = DataFrame({"key": key, "val": val}) + + expected = DataFrame( + expected_val, index=Index(expected_key, name="key"), columns=["val"] + ) + + grp = df.groupby("key") + + result = grp.quantile(0.5) + tm.assert_frame_equal(result, expected) + + result = grp.quantile() + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "values", + [ + pd.array([1, 0, None] * 2, dtype="Int64"), + pd.array([True, False, None] * 2, dtype="boolean"), + ], +) +@pytest.mark.parametrize("q", [0.5, [0.0, 0.5, 1.0]]) +def test_groupby_quantile_nullable_array(values, q): + # https://github.com/pandas-dev/pandas/issues/33136 + df = DataFrame({"a": ["x"] * 3 + ["y"] * 3, "b": values}) + result = df.groupby("a")["b"].quantile(q) + + if isinstance(q, list): + idx = pd.MultiIndex.from_product((["x", "y"], q), names=["a", None]) + true_quantiles = [0.0, 0.5, 1.0] + else: + idx = Index(["x", "y"], name="a") + true_quantiles = [0.5] + + expected = pd.Series(true_quantiles * 2, index=idx, name="b", dtype="Float64") + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("q", [0.5, [0.0, 0.5, 1.0]]) +@pytest.mark.parametrize("numeric_only", [True, False]) +def test_groupby_quantile_raises_on_invalid_dtype(q, numeric_only): + df = DataFrame({"a": [1], "b": [2.0], "c": ["x"]}) + if numeric_only: + result = df.groupby("a").quantile(q, numeric_only=numeric_only) + expected = df.groupby("a")[["b"]].quantile(q) + tm.assert_frame_equal(result, expected) + else: + with pytest.raises( + TypeError, match="'quantile' cannot be performed against 'object' dtypes!" + ): + df.groupby("a").quantile(q, numeric_only=numeric_only) + + +def test_groupby_quantile_NA_float(any_float_dtype): + # GH#42849 + df = DataFrame({"x": [1, 1], "y": [0.2, np.nan]}, dtype=any_float_dtype) + result = df.groupby("x")["y"].quantile(0.5) + exp_index = Index([1.0], dtype=any_float_dtype, name="x") + + if any_float_dtype in ["Float32", "Float64"]: + expected_dtype = any_float_dtype + else: + expected_dtype = None + + expected = pd.Series([0.2], dtype=expected_dtype, index=exp_index, name="y") + tm.assert_series_equal(result, expected) + + result = df.groupby("x")["y"].quantile([0.5, 0.75]) + expected = pd.Series( + [0.2] * 2, + index=pd.MultiIndex.from_product((exp_index, [0.5, 0.75]), names=["x", None]), + name="y", + dtype=expected_dtype, + ) + tm.assert_series_equal(result, expected) + + +def test_groupby_quantile_NA_int(any_int_ea_dtype): + # GH#42849 + df = DataFrame({"x": [1, 1], "y": [2, 5]}, dtype=any_int_ea_dtype) + result = df.groupby("x")["y"].quantile(0.5) + expected = pd.Series( + [3.5], + dtype="Float64", + index=Index([1], name="x", dtype=any_int_ea_dtype), + name="y", + ) + tm.assert_series_equal(expected, result) + + result = df.groupby("x").quantile(0.5) + expected = DataFrame( + {"y": 3.5}, dtype="Float64", index=Index([1], name="x", dtype=any_int_ea_dtype) + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "interpolation, val1, val2", [("lower", 2, 2), ("higher", 2, 3), ("nearest", 2, 2)] +) +def test_groupby_quantile_all_na_group_masked( + interpolation, val1, val2, any_numeric_ea_dtype +): + # GH#37493 + df = DataFrame( + {"a": [1, 1, 1, 2], "b": [1, 2, 3, pd.NA]}, dtype=any_numeric_ea_dtype + ) + result = df.groupby("a").quantile(q=[0.5, 0.7], interpolation=interpolation) + expected = DataFrame( + {"b": [val1, val2, pd.NA, pd.NA]}, + dtype=any_numeric_ea_dtype, + index=pd.MultiIndex.from_arrays( + [pd.Series([1, 1, 2, 2], dtype=any_numeric_ea_dtype), [0.5, 0.7, 0.5, 0.7]], + names=["a", None], + ), + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("interpolation", ["midpoint", "linear"]) +def test_groupby_quantile_all_na_group_masked_interp( + interpolation, any_numeric_ea_dtype +): + # GH#37493 + df = DataFrame( + {"a": [1, 1, 1, 2], "b": [1, 2, 3, pd.NA]}, dtype=any_numeric_ea_dtype + ) + result = df.groupby("a").quantile(q=[0.5, 0.75], interpolation=interpolation) + + if any_numeric_ea_dtype == "Float32": + expected_dtype = any_numeric_ea_dtype + else: + expected_dtype = "Float64" + + expected = DataFrame( + {"b": [2.0, 2.5, pd.NA, pd.NA]}, + dtype=expected_dtype, + index=pd.MultiIndex.from_arrays( + [ + pd.Series([1, 1, 2, 2], dtype=any_numeric_ea_dtype), + [0.5, 0.75, 0.5, 0.75], + ], + names=["a", None], + ), + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("dtype", ["Float64", "Float32"]) +def test_groupby_quantile_allNA_column(dtype): + # GH#42849 + df = DataFrame({"x": [1, 1], "y": [pd.NA] * 2}, dtype=dtype) + result = df.groupby("x")["y"].quantile(0.5) + expected = pd.Series( + [np.nan], dtype=dtype, index=Index([1.0], dtype=dtype), name="y" + ) + expected.index.name = "x" + tm.assert_series_equal(expected, result) + + +def test_groupby_timedelta_quantile(): + # GH: 29485 + df = DataFrame( + {"value": pd.to_timedelta(np.arange(4), unit="s"), "group": [1, 1, 2, 2]} + ) + result = df.groupby("group").quantile(0.99) + expected = DataFrame( + { + "value": [ + pd.Timedelta("0 days 00:00:00.990000"), + pd.Timedelta("0 days 00:00:02.990000"), + ] + }, + index=Index([1, 2], name="group"), + ) + tm.assert_frame_equal(result, expected) + + +def test_columns_groupby_quantile(): + # GH 33795 + df = DataFrame( + np.arange(12).reshape(3, -1), + index=list("XYZ"), + columns=pd.Series(list("ABAB"), name="col"), + ) + msg = "DataFrame.groupby with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + gb = df.groupby("col", axis=1) + result = gb.quantile(q=[0.8, 0.2]) + expected = DataFrame( + [ + [1.6, 0.4, 2.6, 1.4], + [5.6, 4.4, 6.6, 5.4], + [9.6, 8.4, 10.6, 9.4], + ], + index=list("XYZ"), + columns=pd.MultiIndex.from_tuples( + [("A", 0.8), ("A", 0.2), ("B", 0.8), ("B", 0.2)], names=["col", None] + ), + ) + + tm.assert_frame_equal(result, expected) + + +def test_timestamp_groupby_quantile(unit): + # GH 33168 + dti = pd.date_range( + start="2020-04-19 00:00:00", freq="1min", periods=100, tz="UTC", unit=unit + ).floor("1h") + df = DataFrame( + { + "timestamp": dti, + "category": list(range(1, 101)), + "value": list(range(101, 201)), + } + ) + + result = df.groupby("timestamp").quantile([0.2, 0.8]) + + mi = pd.MultiIndex.from_product([dti[::99], [0.2, 0.8]], names=("timestamp", None)) + expected = DataFrame( + [ + {"category": 12.8, "value": 112.8}, + {"category": 48.2, "value": 148.2}, + {"category": 68.8, "value": 168.8}, + {"category": 92.2, "value": 192.2}, + ], + index=mi, + ) + + tm.assert_frame_equal(result, expected) + + +def test_groupby_quantile_dt64tz_period(): + # GH#51373 + dti = pd.date_range("2016-01-01", periods=1000) + df = pd.Series(dti).to_frame().copy() + df[1] = dti.tz_localize("US/Pacific") + df[2] = dti.to_period("D") + df[3] = dti - dti[0] + df.iloc[-1] = pd.NaT + + by = np.tile(np.arange(5), 200) + gb = df.groupby(by) + + result = gb.quantile(0.5) + + # Check that we match the group-by-group result + exp = {i: df.iloc[i::5].quantile(0.5) for i in range(5)} + expected = DataFrame(exp).T.infer_objects() + expected.index = expected.index.astype(int) + + tm.assert_frame_equal(result, expected) + + +def test_groupby_quantile_nonmulti_levels_order(): + # Non-regression test for GH #53009 + ind = pd.MultiIndex.from_tuples( + [ + (0, "a", "B"), + (0, "a", "A"), + (0, "b", "B"), + (0, "b", "A"), + (1, "a", "B"), + (1, "a", "A"), + (1, "b", "B"), + (1, "b", "A"), + ], + names=["sample", "cat0", "cat1"], + ) + ser = pd.Series(range(8), index=ind) + result = ser.groupby(level="cat1", sort=False).quantile([0.2, 0.8]) + + qind = pd.MultiIndex.from_tuples( + [("B", 0.2), ("B", 0.8), ("A", 0.2), ("A", 0.8)], names=["cat1", None] + ) + expected = pd.Series([1.2, 4.8, 2.2, 5.8], index=qind) + + tm.assert_series_equal(result, expected) + + # We need to check that index levels are not sorted + expected_levels = pd.core.indexes.frozen.FrozenList([["B", "A"], [0.2, 0.8]]) + tm.assert_equal(result.index.levels, expected_levels) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/methods/test_rank.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/methods/test_rank.py new file mode 100644 index 0000000000000000000000000000000000000000..a3b7da3fa836c955d8d0e4e17754d7834e5c05f1 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/methods/test_rank.py @@ -0,0 +1,721 @@ +from datetime import datetime + +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + DataFrame, + NaT, + Series, + concat, +) +import pandas._testing as tm + + +def test_rank_unordered_categorical_typeerror(): + # GH#51034 should be TypeError, not NotImplementedError + cat = pd.Categorical([], ordered=False) + ser = Series(cat) + df = ser.to_frame() + + msg = "Cannot perform rank with non-ordered Categorical" + + gb = ser.groupby(cat, observed=False) + with pytest.raises(TypeError, match=msg): + gb.rank() + + gb2 = df.groupby(cat, observed=False) + with pytest.raises(TypeError, match=msg): + gb2.rank() + + +def test_rank_apply(): + lev1 = np.array(["a" * 10] * 100, dtype=object) + lev2 = np.array(["b" * 10] * 130, dtype=object) + lab1 = np.random.default_rng(2).integers(0, 100, size=500, dtype=int) + lab2 = np.random.default_rng(2).integers(0, 130, size=500, dtype=int) + + df = DataFrame( + { + "value": np.random.default_rng(2).standard_normal(500), + "key1": lev1.take(lab1), + "key2": lev2.take(lab2), + } + ) + + result = df.groupby(["key1", "key2"]).value.rank() + + expected = [piece.value.rank() for key, piece in df.groupby(["key1", "key2"])] + expected = concat(expected, axis=0) + expected = expected.reindex(result.index) + tm.assert_series_equal(result, expected) + + result = df.groupby(["key1", "key2"]).value.rank(pct=True) + + expected = [ + piece.value.rank(pct=True) for key, piece in df.groupby(["key1", "key2"]) + ] + expected = concat(expected, axis=0) + expected = expected.reindex(result.index) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("grps", [["qux"], ["qux", "quux"]]) +@pytest.mark.parametrize( + "vals", + [ + np.array([2, 2, 8, 2, 6], dtype=dtype) + for dtype in ["i8", "i4", "i2", "i1", "u8", "u4", "u2", "u1", "f8", "f4", "f2"] + ] + + [ + [ + pd.Timestamp("2018-01-02"), + pd.Timestamp("2018-01-02"), + pd.Timestamp("2018-01-08"), + pd.Timestamp("2018-01-02"), + pd.Timestamp("2018-01-06"), + ], + [ + pd.Timestamp("2018-01-02", tz="US/Pacific"), + pd.Timestamp("2018-01-02", tz="US/Pacific"), + pd.Timestamp("2018-01-08", tz="US/Pacific"), + pd.Timestamp("2018-01-02", tz="US/Pacific"), + pd.Timestamp("2018-01-06", tz="US/Pacific"), + ], + [ + pd.Timestamp("2018-01-02") - pd.Timestamp(0), + pd.Timestamp("2018-01-02") - pd.Timestamp(0), + pd.Timestamp("2018-01-08") - pd.Timestamp(0), + pd.Timestamp("2018-01-02") - pd.Timestamp(0), + pd.Timestamp("2018-01-06") - pd.Timestamp(0), + ], + [ + pd.Timestamp("2018-01-02").to_period("D"), + pd.Timestamp("2018-01-02").to_period("D"), + pd.Timestamp("2018-01-08").to_period("D"), + pd.Timestamp("2018-01-02").to_period("D"), + pd.Timestamp("2018-01-06").to_period("D"), + ], + ], + ids=lambda x: type(x[0]), +) +@pytest.mark.parametrize( + "ties_method,ascending,pct,exp", + [ + ("average", True, False, [2.0, 2.0, 5.0, 2.0, 4.0]), + ("average", True, True, [0.4, 0.4, 1.0, 0.4, 0.8]), + ("average", False, False, [4.0, 4.0, 1.0, 4.0, 2.0]), + ("average", False, True, [0.8, 0.8, 0.2, 0.8, 0.4]), + ("min", True, False, [1.0, 1.0, 5.0, 1.0, 4.0]), + ("min", True, True, [0.2, 0.2, 1.0, 0.2, 0.8]), + ("min", False, False, [3.0, 3.0, 1.0, 3.0, 2.0]), + ("min", False, True, [0.6, 0.6, 0.2, 0.6, 0.4]), + ("max", True, False, [3.0, 3.0, 5.0, 3.0, 4.0]), + ("max", True, True, [0.6, 0.6, 1.0, 0.6, 0.8]), + ("max", False, False, [5.0, 5.0, 1.0, 5.0, 2.0]), + ("max", False, True, [1.0, 1.0, 0.2, 1.0, 0.4]), + ("first", True, False, [1.0, 2.0, 5.0, 3.0, 4.0]), + ("first", True, True, [0.2, 0.4, 1.0, 0.6, 0.8]), + ("first", False, False, [3.0, 4.0, 1.0, 5.0, 2.0]), + ("first", False, True, [0.6, 0.8, 0.2, 1.0, 0.4]), + ("dense", True, False, [1.0, 1.0, 3.0, 1.0, 2.0]), + ("dense", True, True, [1.0 / 3.0, 1.0 / 3.0, 3.0 / 3.0, 1.0 / 3.0, 2.0 / 3.0]), + ("dense", False, False, [3.0, 3.0, 1.0, 3.0, 2.0]), + ("dense", False, True, [3.0 / 3.0, 3.0 / 3.0, 1.0 / 3.0, 3.0 / 3.0, 2.0 / 3.0]), + ], +) +def test_rank_args(grps, vals, ties_method, ascending, pct, exp): + key = np.repeat(grps, len(vals)) + + orig_vals = vals + vals = list(vals) * len(grps) + if isinstance(orig_vals, np.ndarray): + vals = np.array(vals, dtype=orig_vals.dtype) + + df = DataFrame({"key": key, "val": vals}) + result = df.groupby("key").rank(method=ties_method, ascending=ascending, pct=pct) + + exp_df = DataFrame(exp * len(grps), columns=["val"]) + tm.assert_frame_equal(result, exp_df) + + +@pytest.mark.parametrize("grps", [["qux"], ["qux", "quux"]]) +@pytest.mark.parametrize( + "vals", [[-np.inf, -np.inf, np.nan, 1.0, np.nan, np.inf, np.inf]] +) +@pytest.mark.parametrize( + "ties_method,ascending,na_option,exp", + [ + ("average", True, "keep", [1.5, 1.5, np.nan, 3, np.nan, 4.5, 4.5]), + ("average", True, "top", [3.5, 3.5, 1.5, 5.0, 1.5, 6.5, 6.5]), + ("average", True, "bottom", [1.5, 1.5, 6.5, 3.0, 6.5, 4.5, 4.5]), + ("average", False, "keep", [4.5, 4.5, np.nan, 3, np.nan, 1.5, 1.5]), + ("average", False, "top", [6.5, 6.5, 1.5, 5.0, 1.5, 3.5, 3.5]), + ("average", False, "bottom", [4.5, 4.5, 6.5, 3.0, 6.5, 1.5, 1.5]), + ("min", True, "keep", [1.0, 1.0, np.nan, 3.0, np.nan, 4.0, 4.0]), + ("min", True, "top", [3.0, 3.0, 1.0, 5.0, 1.0, 6.0, 6.0]), + ("min", True, "bottom", [1.0, 1.0, 6.0, 3.0, 6.0, 4.0, 4.0]), + ("min", False, "keep", [4.0, 4.0, np.nan, 3.0, np.nan, 1.0, 1.0]), + ("min", False, "top", [6.0, 6.0, 1.0, 5.0, 1.0, 3.0, 3.0]), + ("min", False, "bottom", [4.0, 4.0, 6.0, 3.0, 6.0, 1.0, 1.0]), + ("max", True, "keep", [2.0, 2.0, np.nan, 3.0, np.nan, 5.0, 5.0]), + ("max", True, "top", [4.0, 4.0, 2.0, 5.0, 2.0, 7.0, 7.0]), + ("max", True, "bottom", [2.0, 2.0, 7.0, 3.0, 7.0, 5.0, 5.0]), + ("max", False, "keep", [5.0, 5.0, np.nan, 3.0, np.nan, 2.0, 2.0]), + ("max", False, "top", [7.0, 7.0, 2.0, 5.0, 2.0, 4.0, 4.0]), + ("max", False, "bottom", [5.0, 5.0, 7.0, 3.0, 7.0, 2.0, 2.0]), + ("first", True, "keep", [1.0, 2.0, np.nan, 3.0, np.nan, 4.0, 5.0]), + ("first", True, "top", [3.0, 4.0, 1.0, 5.0, 2.0, 6.0, 7.0]), + ("first", True, "bottom", [1.0, 2.0, 6.0, 3.0, 7.0, 4.0, 5.0]), + ("first", False, "keep", [4.0, 5.0, np.nan, 3.0, np.nan, 1.0, 2.0]), + ("first", False, "top", [6.0, 7.0, 1.0, 5.0, 2.0, 3.0, 4.0]), + ("first", False, "bottom", [4.0, 5.0, 6.0, 3.0, 7.0, 1.0, 2.0]), + ("dense", True, "keep", [1.0, 1.0, np.nan, 2.0, np.nan, 3.0, 3.0]), + ("dense", True, "top", [2.0, 2.0, 1.0, 3.0, 1.0, 4.0, 4.0]), + ("dense", True, "bottom", [1.0, 1.0, 4.0, 2.0, 4.0, 3.0, 3.0]), + ("dense", False, "keep", [3.0, 3.0, np.nan, 2.0, np.nan, 1.0, 1.0]), + ("dense", False, "top", [4.0, 4.0, 1.0, 3.0, 1.0, 2.0, 2.0]), + ("dense", False, "bottom", [3.0, 3.0, 4.0, 2.0, 4.0, 1.0, 1.0]), + ], +) +def test_infs_n_nans(grps, vals, ties_method, ascending, na_option, exp): + # GH 20561 + key = np.repeat(grps, len(vals)) + vals = vals * len(grps) + df = DataFrame({"key": key, "val": vals}) + result = df.groupby("key").rank( + method=ties_method, ascending=ascending, na_option=na_option + ) + exp_df = DataFrame(exp * len(grps), columns=["val"]) + tm.assert_frame_equal(result, exp_df) + + +@pytest.mark.parametrize("grps", [["qux"], ["qux", "quux"]]) +@pytest.mark.parametrize( + "vals", + [ + np.array([2, 2, np.nan, 8, 2, 6, np.nan, np.nan], dtype=dtype) + for dtype in ["f8", "f4", "f2"] + ] + + [ + [ + pd.Timestamp("2018-01-02"), + pd.Timestamp("2018-01-02"), + np.nan, + pd.Timestamp("2018-01-08"), + pd.Timestamp("2018-01-02"), + pd.Timestamp("2018-01-06"), + np.nan, + np.nan, + ], + [ + pd.Timestamp("2018-01-02", tz="US/Pacific"), + pd.Timestamp("2018-01-02", tz="US/Pacific"), + np.nan, + pd.Timestamp("2018-01-08", tz="US/Pacific"), + pd.Timestamp("2018-01-02", tz="US/Pacific"), + pd.Timestamp("2018-01-06", tz="US/Pacific"), + np.nan, + np.nan, + ], + [ + pd.Timestamp("2018-01-02") - pd.Timestamp(0), + pd.Timestamp("2018-01-02") - pd.Timestamp(0), + np.nan, + pd.Timestamp("2018-01-08") - pd.Timestamp(0), + pd.Timestamp("2018-01-02") - pd.Timestamp(0), + pd.Timestamp("2018-01-06") - pd.Timestamp(0), + np.nan, + np.nan, + ], + [ + pd.Timestamp("2018-01-02").to_period("D"), + pd.Timestamp("2018-01-02").to_period("D"), + np.nan, + pd.Timestamp("2018-01-08").to_period("D"), + pd.Timestamp("2018-01-02").to_period("D"), + pd.Timestamp("2018-01-06").to_period("D"), + np.nan, + np.nan, + ], + ], + ids=lambda x: type(x[0]), +) +@pytest.mark.parametrize( + "ties_method,ascending,na_option,pct,exp", + [ + ( + "average", + True, + "keep", + False, + [2.0, 2.0, np.nan, 5.0, 2.0, 4.0, np.nan, np.nan], + ), + ( + "average", + True, + "keep", + True, + [0.4, 0.4, np.nan, 1.0, 0.4, 0.8, np.nan, np.nan], + ), + ( + "average", + False, + "keep", + False, + [4.0, 4.0, np.nan, 1.0, 4.0, 2.0, np.nan, np.nan], + ), + ( + "average", + False, + "keep", + True, + [0.8, 0.8, np.nan, 0.2, 0.8, 0.4, np.nan, np.nan], + ), + ("min", True, "keep", False, [1.0, 1.0, np.nan, 5.0, 1.0, 4.0, np.nan, np.nan]), + ("min", True, "keep", True, [0.2, 0.2, np.nan, 1.0, 0.2, 0.8, np.nan, np.nan]), + ( + "min", + False, + "keep", + False, + [3.0, 3.0, np.nan, 1.0, 3.0, 2.0, np.nan, np.nan], + ), + ("min", False, "keep", True, [0.6, 0.6, np.nan, 0.2, 0.6, 0.4, np.nan, np.nan]), + ("max", True, "keep", False, [3.0, 3.0, np.nan, 5.0, 3.0, 4.0, np.nan, np.nan]), + ("max", True, "keep", True, [0.6, 0.6, np.nan, 1.0, 0.6, 0.8, np.nan, np.nan]), + ( + "max", + False, + "keep", + False, + [5.0, 5.0, np.nan, 1.0, 5.0, 2.0, np.nan, np.nan], + ), + ("max", False, "keep", True, [1.0, 1.0, np.nan, 0.2, 1.0, 0.4, np.nan, np.nan]), + ( + "first", + True, + "keep", + False, + [1.0, 2.0, np.nan, 5.0, 3.0, 4.0, np.nan, np.nan], + ), + ( + "first", + True, + "keep", + True, + [0.2, 0.4, np.nan, 1.0, 0.6, 0.8, np.nan, np.nan], + ), + ( + "first", + False, + "keep", + False, + [3.0, 4.0, np.nan, 1.0, 5.0, 2.0, np.nan, np.nan], + ), + ( + "first", + False, + "keep", + True, + [0.6, 0.8, np.nan, 0.2, 1.0, 0.4, np.nan, np.nan], + ), + ( + "dense", + True, + "keep", + False, + [1.0, 1.0, np.nan, 3.0, 1.0, 2.0, np.nan, np.nan], + ), + ( + "dense", + True, + "keep", + True, + [ + 1.0 / 3.0, + 1.0 / 3.0, + np.nan, + 3.0 / 3.0, + 1.0 / 3.0, + 2.0 / 3.0, + np.nan, + np.nan, + ], + ), + ( + "dense", + False, + "keep", + False, + [3.0, 3.0, np.nan, 1.0, 3.0, 2.0, np.nan, np.nan], + ), + ( + "dense", + False, + "keep", + True, + [ + 3.0 / 3.0, + 3.0 / 3.0, + np.nan, + 1.0 / 3.0, + 3.0 / 3.0, + 2.0 / 3.0, + np.nan, + np.nan, + ], + ), + ("average", True, "bottom", False, [2.0, 2.0, 7.0, 5.0, 2.0, 4.0, 7.0, 7.0]), + ( + "average", + True, + "bottom", + True, + [0.25, 0.25, 0.875, 0.625, 0.25, 0.5, 0.875, 0.875], + ), + ("average", False, "bottom", False, [4.0, 4.0, 7.0, 1.0, 4.0, 2.0, 7.0, 7.0]), + ( + "average", + False, + "bottom", + True, + [0.5, 0.5, 0.875, 0.125, 0.5, 0.25, 0.875, 0.875], + ), + ("min", True, "bottom", False, [1.0, 1.0, 6.0, 5.0, 1.0, 4.0, 6.0, 6.0]), + ( + "min", + True, + "bottom", + True, + [0.125, 0.125, 0.75, 0.625, 0.125, 0.5, 0.75, 0.75], + ), + ("min", False, "bottom", False, [3.0, 3.0, 6.0, 1.0, 3.0, 2.0, 6.0, 6.0]), + ( + "min", + False, + "bottom", + True, + [0.375, 0.375, 0.75, 0.125, 0.375, 0.25, 0.75, 0.75], + ), + ("max", True, "bottom", False, [3.0, 3.0, 8.0, 5.0, 3.0, 4.0, 8.0, 8.0]), + ("max", True, "bottom", True, [0.375, 0.375, 1.0, 0.625, 0.375, 0.5, 1.0, 1.0]), + ("max", False, "bottom", False, [5.0, 5.0, 8.0, 1.0, 5.0, 2.0, 8.0, 8.0]), + ( + "max", + False, + "bottom", + True, + [0.625, 0.625, 1.0, 0.125, 0.625, 0.25, 1.0, 1.0], + ), + ("first", True, "bottom", False, [1.0, 2.0, 6.0, 5.0, 3.0, 4.0, 7.0, 8.0]), + ( + "first", + True, + "bottom", + True, + [0.125, 0.25, 0.75, 0.625, 0.375, 0.5, 0.875, 1.0], + ), + ("first", False, "bottom", False, [3.0, 4.0, 6.0, 1.0, 5.0, 2.0, 7.0, 8.0]), + ( + "first", + False, + "bottom", + True, + [0.375, 0.5, 0.75, 0.125, 0.625, 0.25, 0.875, 1.0], + ), + ("dense", True, "bottom", False, [1.0, 1.0, 4.0, 3.0, 1.0, 2.0, 4.0, 4.0]), + ("dense", True, "bottom", True, [0.25, 0.25, 1.0, 0.75, 0.25, 0.5, 1.0, 1.0]), + ("dense", False, "bottom", False, [3.0, 3.0, 4.0, 1.0, 3.0, 2.0, 4.0, 4.0]), + ("dense", False, "bottom", True, [0.75, 0.75, 1.0, 0.25, 0.75, 0.5, 1.0, 1.0]), + ], +) +def test_rank_args_missing(grps, vals, ties_method, ascending, na_option, pct, exp): + key = np.repeat(grps, len(vals)) + + orig_vals = vals + vals = list(vals) * len(grps) + if isinstance(orig_vals, np.ndarray): + vals = np.array(vals, dtype=orig_vals.dtype) + + df = DataFrame({"key": key, "val": vals}) + result = df.groupby("key").rank( + method=ties_method, ascending=ascending, na_option=na_option, pct=pct + ) + + exp_df = DataFrame(exp * len(grps), columns=["val"]) + tm.assert_frame_equal(result, exp_df) + + +@pytest.mark.parametrize( + "pct,exp", [(False, [3.0, 3.0, 3.0, 3.0, 3.0]), (True, [0.6, 0.6, 0.6, 0.6, 0.6])] +) +def test_rank_resets_each_group(pct, exp): + df = DataFrame( + {"key": ["a", "a", "a", "a", "a", "b", "b", "b", "b", "b"], "val": [1] * 10} + ) + result = df.groupby("key").rank(pct=pct) + exp_df = DataFrame(exp * 2, columns=["val"]) + tm.assert_frame_equal(result, exp_df) + + +@pytest.mark.parametrize( + "dtype", ["int64", "int32", "uint64", "uint32", "float64", "float32"] +) +@pytest.mark.parametrize("upper", [True, False]) +def test_rank_avg_even_vals(dtype, upper): + if upper: + # use IntegerDtype/FloatingDtype + dtype = dtype[0].upper() + dtype[1:] + dtype = dtype.replace("Ui", "UI") + df = DataFrame({"key": ["a"] * 4, "val": [1] * 4}) + df["val"] = df["val"].astype(dtype) + assert df["val"].dtype == dtype + + result = df.groupby("key").rank() + exp_df = DataFrame([2.5, 2.5, 2.5, 2.5], columns=["val"]) + if upper: + exp_df = exp_df.astype("Float64") + tm.assert_frame_equal(result, exp_df) + + +@pytest.mark.parametrize("ties_method", ["average", "min", "max", "first", "dense"]) +@pytest.mark.parametrize("ascending", [True, False]) +@pytest.mark.parametrize("na_option", ["keep", "top", "bottom"]) +@pytest.mark.parametrize("pct", [True, False]) +@pytest.mark.parametrize( + "vals", [["bar", "bar", "foo", "bar", "baz"], ["bar", np.nan, "foo", np.nan, "baz"]] +) +def test_rank_object_dtype(ties_method, ascending, na_option, pct, vals): + df = DataFrame({"key": ["foo"] * 5, "val": vals}) + mask = df["val"].isna() + + gb = df.groupby("key") + res = gb.rank(method=ties_method, ascending=ascending, na_option=na_option, pct=pct) + + # construct our expected by using numeric values with the same ordering + if mask.any(): + df2 = DataFrame({"key": ["foo"] * 5, "val": [0, np.nan, 2, np.nan, 1]}) + else: + df2 = DataFrame({"key": ["foo"] * 5, "val": [0, 0, 2, 0, 1]}) + + gb2 = df2.groupby("key") + alt = gb2.rank( + method=ties_method, ascending=ascending, na_option=na_option, pct=pct + ) + + tm.assert_frame_equal(res, alt) + + +@pytest.mark.parametrize("na_option", [True, "bad", 1]) +@pytest.mark.parametrize("ties_method", ["average", "min", "max", "first", "dense"]) +@pytest.mark.parametrize("ascending", [True, False]) +@pytest.mark.parametrize("pct", [True, False]) +@pytest.mark.parametrize( + "vals", + [ + ["bar", "bar", "foo", "bar", "baz"], + ["bar", np.nan, "foo", np.nan, "baz"], + [1, np.nan, 2, np.nan, 3], + ], +) +def test_rank_naoption_raises(ties_method, ascending, na_option, pct, vals): + df = DataFrame({"key": ["foo"] * 5, "val": vals}) + msg = "na_option must be one of 'keep', 'top', or 'bottom'" + + with pytest.raises(ValueError, match=msg): + df.groupby("key").rank( + method=ties_method, ascending=ascending, na_option=na_option, pct=pct + ) + + +def test_rank_empty_group(): + # see gh-22519 + column = "A" + df = DataFrame({"A": [0, 1, 0], "B": [1.0, np.nan, 2.0]}) + + result = df.groupby(column).B.rank(pct=True) + expected = Series([0.5, np.nan, 1.0], name="B") + tm.assert_series_equal(result, expected) + + result = df.groupby(column).rank(pct=True) + expected = DataFrame({"B": [0.5, np.nan, 1.0]}) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "input_key,input_value,output_value", + [ + ([1, 2], [1, 1], [1.0, 1.0]), + ([1, 1, 2, 2], [1, 2, 1, 2], [0.5, 1.0, 0.5, 1.0]), + ([1, 1, 2, 2], [1, 2, 1, np.nan], [0.5, 1.0, 1.0, np.nan]), + ([1, 1, 2], [1, 2, np.nan], [0.5, 1.0, np.nan]), + ], +) +def test_rank_zero_div(input_key, input_value, output_value): + # GH 23666 + df = DataFrame({"A": input_key, "B": input_value}) + + result = df.groupby("A").rank(method="dense", pct=True) + expected = DataFrame({"B": output_value}) + tm.assert_frame_equal(result, expected) + + +def test_rank_min_int(): + # GH-32859 + df = DataFrame( + { + "grp": [1, 1, 2], + "int_col": [ + np.iinfo(np.int64).min, + np.iinfo(np.int64).max, + np.iinfo(np.int64).min, + ], + "datetimelike": [NaT, datetime(2001, 1, 1), NaT], + } + ) + + result = df.groupby("grp").rank() + expected = DataFrame( + {"int_col": [1.0, 2.0, 1.0], "datetimelike": [np.nan, 1.0, np.nan]} + ) + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("use_nan", [True, False]) +def test_rank_pct_equal_values_on_group_transition(use_nan): + # GH#40518 + fill_value = np.nan if use_nan else 3 + df = DataFrame( + [ + [-1, 1], + [-1, 2], + [1, fill_value], + [-1, fill_value], + ], + columns=["group", "val"], + ) + result = df.groupby(["group"])["val"].rank( + method="dense", + pct=True, + ) + if use_nan: + expected = Series([0.5, 1, np.nan, np.nan], name="val") + else: + expected = Series([1 / 3, 2 / 3, 1, 1], name="val") + + tm.assert_series_equal(result, expected) + + +def test_rank_multiindex(): + # GH27721 + df = concat( + { + "a": DataFrame({"col1": [3, 4], "col2": [1, 2]}), + "b": DataFrame({"col3": [5, 6], "col4": [7, 8]}), + }, + axis=1, + ) + + msg = "DataFrame.groupby with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + gb = df.groupby(level=0, axis=1) + msg = "DataFrameGroupBy.rank with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = gb.rank(axis=1) + + expected = concat( + [ + df["a"].rank(axis=1), + df["b"].rank(axis=1), + ], + axis=1, + keys=["a", "b"], + ) + tm.assert_frame_equal(result, expected) + + +def test_groupby_axis0_rank_axis1(): + # GH#41320 + df = DataFrame( + {0: [1, 3, 5, 7], 1: [2, 4, 6, 8], 2: [1.5, 3.5, 5.5, 7.5]}, + index=["a", "a", "b", "b"], + ) + msg = "The 'axis' keyword in DataFrame.groupby is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + gb = df.groupby(level=0, axis=0) + + msg = "DataFrameGroupBy.rank with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + res = gb.rank(axis=1) + + # This should match what we get when "manually" operating group-by-group + expected = concat([df.loc["a"].rank(axis=1), df.loc["b"].rank(axis=1)], axis=0) + tm.assert_frame_equal(res, expected) + + # check that we haven't accidentally written a case that coincidentally + # matches rank(axis=0) + msg = "The 'axis' keyword in DataFrameGroupBy.rank" + with tm.assert_produces_warning(FutureWarning, match=msg): + alt = gb.rank(axis=0) + assert not alt.equals(expected) + + +def test_groupby_axis0_cummax_axis1(): + # case where groupby axis is 0 and axis keyword in transform is 1 + + # df has mixed dtype -> multiple blocks + df = DataFrame( + {0: [1, 3, 5, 7], 1: [2, 4, 6, 8], 2: [1.5, 3.5, 5.5, 7.5]}, + index=["a", "a", "b", "b"], + ) + msg = "The 'axis' keyword in DataFrame.groupby is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + gb = df.groupby(level=0, axis=0) + + msg = "DataFrameGroupBy.cummax with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + cmax = gb.cummax(axis=1) + expected = df[[0, 1]].astype(np.float64) + expected[2] = expected[1] + tm.assert_frame_equal(cmax, expected) + + +def test_non_unique_index(): + # GH 16577 + df = DataFrame( + {"A": [1.0, 2.0, 3.0, np.nan], "value": 1.0}, + index=[pd.Timestamp("20170101", tz="US/Eastern")] * 4, + ) + result = df.groupby([df.index, "A"]).value.rank(ascending=True, pct=True) + expected = Series( + [1.0, 1.0, 1.0, np.nan], + index=[pd.Timestamp("20170101", tz="US/Eastern")] * 4, + name="value", + ) + tm.assert_series_equal(result, expected) + + +def test_rank_categorical(): + cat = pd.Categorical(["a", "a", "b", np.nan, "c", "b"], ordered=True) + cat2 = pd.Categorical([1, 2, 3, np.nan, 4, 5], ordered=True) + + df = DataFrame({"col1": [0, 1, 0, 1, 0, 1], "col2": cat, "col3": cat2}) + + gb = df.groupby("col1") + + res = gb.rank() + + expected = df.astype(object).groupby("col1").rank() + tm.assert_frame_equal(res, expected) + + +@pytest.mark.parametrize("na_option", ["top", "bottom"]) +def test_groupby_op_with_nullables(na_option): + # GH 54206 + df = DataFrame({"x": [None]}, dtype="Float64") + result = df.groupby("x", dropna=False)["x"].rank(method="min", na_option=na_option) + expected = Series([1.0], dtype="Float64", name=result.name) + tm.assert_series_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/methods/test_sample.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/methods/test_sample.py new file mode 100644 index 0000000000000000000000000000000000000000..4dd474741740d4abdea1ebabf2b36c3b68d690ad --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/methods/test_sample.py @@ -0,0 +1,154 @@ +import pytest + +from pandas import ( + DataFrame, + Index, + Series, +) +import pandas._testing as tm + + +@pytest.mark.parametrize("n, frac", [(2, None), (None, 0.2)]) +def test_groupby_sample_balanced_groups_shape(n, frac): + values = [1] * 10 + [2] * 10 + df = DataFrame({"a": values, "b": values}) + + result = df.groupby("a").sample(n=n, frac=frac) + values = [1] * 2 + [2] * 2 + expected = DataFrame({"a": values, "b": values}, index=result.index) + tm.assert_frame_equal(result, expected) + + result = df.groupby("a")["b"].sample(n=n, frac=frac) + expected = Series(values, name="b", index=result.index) + tm.assert_series_equal(result, expected) + + +def test_groupby_sample_unbalanced_groups_shape(): + values = [1] * 10 + [2] * 20 + df = DataFrame({"a": values, "b": values}) + + result = df.groupby("a").sample(n=5) + values = [1] * 5 + [2] * 5 + expected = DataFrame({"a": values, "b": values}, index=result.index) + tm.assert_frame_equal(result, expected) + + result = df.groupby("a")["b"].sample(n=5) + expected = Series(values, name="b", index=result.index) + tm.assert_series_equal(result, expected) + + +def test_groupby_sample_index_value_spans_groups(): + values = [1] * 3 + [2] * 3 + df = DataFrame({"a": values, "b": values}, index=[1, 2, 2, 2, 2, 2]) + + result = df.groupby("a").sample(n=2) + values = [1] * 2 + [2] * 2 + expected = DataFrame({"a": values, "b": values}, index=result.index) + tm.assert_frame_equal(result, expected) + + result = df.groupby("a")["b"].sample(n=2) + expected = Series(values, name="b", index=result.index) + tm.assert_series_equal(result, expected) + + +def test_groupby_sample_n_and_frac_raises(): + df = DataFrame({"a": [1, 2], "b": [1, 2]}) + msg = "Please enter a value for `frac` OR `n`, not both" + + with pytest.raises(ValueError, match=msg): + df.groupby("a").sample(n=1, frac=1.0) + + with pytest.raises(ValueError, match=msg): + df.groupby("a")["b"].sample(n=1, frac=1.0) + + +def test_groupby_sample_frac_gt_one_without_replacement_raises(): + df = DataFrame({"a": [1, 2], "b": [1, 2]}) + msg = "Replace has to be set to `True` when upsampling the population `frac` > 1." + + with pytest.raises(ValueError, match=msg): + df.groupby("a").sample(frac=1.5, replace=False) + + with pytest.raises(ValueError, match=msg): + df.groupby("a")["b"].sample(frac=1.5, replace=False) + + +@pytest.mark.parametrize("n", [-1, 1.5]) +def test_groupby_sample_invalid_n_raises(n): + df = DataFrame({"a": [1, 2], "b": [1, 2]}) + + if n < 0: + msg = "A negative number of rows requested. Please provide `n` >= 0." + else: + msg = "Only integers accepted as `n` values" + + with pytest.raises(ValueError, match=msg): + df.groupby("a").sample(n=n) + + with pytest.raises(ValueError, match=msg): + df.groupby("a")["b"].sample(n=n) + + +def test_groupby_sample_oversample(): + values = [1] * 10 + [2] * 10 + df = DataFrame({"a": values, "b": values}) + + result = df.groupby("a").sample(frac=2.0, replace=True) + values = [1] * 20 + [2] * 20 + expected = DataFrame({"a": values, "b": values}, index=result.index) + tm.assert_frame_equal(result, expected) + + result = df.groupby("a")["b"].sample(frac=2.0, replace=True) + expected = Series(values, name="b", index=result.index) + tm.assert_series_equal(result, expected) + + +def test_groupby_sample_without_n_or_frac(): + values = [1] * 10 + [2] * 10 + df = DataFrame({"a": values, "b": values}) + + result = df.groupby("a").sample(n=None, frac=None) + expected = DataFrame({"a": [1, 2], "b": [1, 2]}, index=result.index) + tm.assert_frame_equal(result, expected) + + result = df.groupby("a")["b"].sample(n=None, frac=None) + expected = Series([1, 2], name="b", index=result.index) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "index, expected_index", + [(["w", "x", "y", "z"], ["w", "w", "y", "y"]), ([3, 4, 5, 6], [3, 3, 5, 5])], +) +def test_groupby_sample_with_weights(index, expected_index): + # GH 39927 - tests for integer index needed + values = [1] * 2 + [2] * 2 + df = DataFrame({"a": values, "b": values}, index=Index(index)) + + result = df.groupby("a").sample(n=2, replace=True, weights=[1, 0, 1, 0]) + expected = DataFrame({"a": values, "b": values}, index=Index(expected_index)) + tm.assert_frame_equal(result, expected) + + result = df.groupby("a")["b"].sample(n=2, replace=True, weights=[1, 0, 1, 0]) + expected = Series(values, name="b", index=Index(expected_index)) + tm.assert_series_equal(result, expected) + + +def test_groupby_sample_with_selections(): + # GH 39928 + values = [1] * 10 + [2] * 10 + df = DataFrame({"a": values, "b": values, "c": values}) + + result = df.groupby("a")[["b", "c"]].sample(n=None, frac=None) + expected = DataFrame({"b": [1, 2], "c": [1, 2]}, index=result.index) + tm.assert_frame_equal(result, expected) + + +def test_groupby_sample_with_empty_inputs(): + # GH48459 + df = DataFrame({"a": [], "b": []}) + groupby_df = df.groupby("a") + + result = groupby_df.sample() + expected = df + tm.assert_frame_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/methods/test_size.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/methods/test_size.py new file mode 100644 index 0000000000000000000000000000000000000000..93a4e743d0d71db1d2a1fcca4163e6db83eb4ffb --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/methods/test_size.py @@ -0,0 +1,130 @@ +import numpy as np +import pytest + +import pandas.util._test_decorators as td + +from pandas.core.dtypes.common import is_integer_dtype + +from pandas import ( + DataFrame, + Index, + PeriodIndex, + Series, +) +import pandas._testing as tm + + +@pytest.mark.parametrize("by", ["A", "B", ["A", "B"]]) +def test_size(df, by): + grouped = df.groupby(by=by) + result = grouped.size() + for key, group in grouped: + assert result[key] == len(group) + + +@pytest.mark.parametrize( + "by", + [ + [0, 0, 0, 0], + [0, 1, 1, 1], + [1, 0, 1, 1], + [0, None, None, None], + pytest.param([None, None, None, None], marks=pytest.mark.xfail), + ], +) +def test_size_axis_1(df, axis_1, by, sort, dropna): + # GH#45715 + counts = {key: sum(value == key for value in by) for key in dict.fromkeys(by)} + if dropna: + counts = {key: value for key, value in counts.items() if key is not None} + expected = Series(counts, dtype="int64") + if sort: + expected = expected.sort_index() + if is_integer_dtype(expected.index.dtype) and not any(x is None for x in by): + expected.index = expected.index.astype(int) + + msg = "DataFrame.groupby with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + grouped = df.groupby(by=by, axis=axis_1, sort=sort, dropna=dropna) + result = grouped.size() + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("by", ["A", "B", ["A", "B"]]) +@pytest.mark.parametrize("sort", [True, False]) +def test_size_sort(sort, by): + df = DataFrame(np.random.default_rng(2).choice(20, (1000, 3)), columns=list("ABC")) + left = df.groupby(by=by, sort=sort).size() + right = df.groupby(by=by, sort=sort)["C"].apply(lambda a: a.shape[0]) + tm.assert_series_equal(left, right, check_names=False) + + +def test_size_series_dataframe(): + # https://github.com/pandas-dev/pandas/issues/11699 + df = DataFrame(columns=["A", "B"]) + out = Series(dtype="int64", index=Index([], name="A")) + tm.assert_series_equal(df.groupby("A").size(), out) + + +def test_size_groupby_all_null(): + # https://github.com/pandas-dev/pandas/issues/23050 + # Assert no 'Value Error : Length of passed values is 2, index implies 0' + df = DataFrame({"A": [None, None]}) # all-null groups + result = df.groupby("A").size() + expected = Series(dtype="int64", index=Index([], name="A")) + tm.assert_series_equal(result, expected) + + +def test_size_period_index(): + # https://github.com/pandas-dev/pandas/issues/34010 + ser = Series([1], index=PeriodIndex(["2000"], name="A", freq="D")) + grp = ser.groupby(level="A") + result = grp.size() + tm.assert_series_equal(result, ser) + + +@pytest.mark.parametrize("as_index", [True, False]) +def test_size_on_categorical(as_index): + df = DataFrame([[1, 1], [2, 2]], columns=["A", "B"]) + df["A"] = df["A"].astype("category") + result = df.groupby(["A", "B"], as_index=as_index, observed=False).size() + + expected = DataFrame( + [[1, 1, 1], [1, 2, 0], [2, 1, 0], [2, 2, 1]], columns=["A", "B", "size"] + ) + expected["A"] = expected["A"].astype("category") + if as_index: + expected = expected.set_index(["A", "B"])["size"].rename(None) + + tm.assert_equal(result, expected) + + +@pytest.mark.parametrize("dtype", ["Int64", "Float64", "boolean"]) +def test_size_series_masked_type_returns_Int64(dtype): + # GH 54132 + ser = Series([1, 1, 1], index=["a", "a", "b"], dtype=dtype) + result = ser.groupby(level=0).size() + expected = Series([2, 1], dtype="Int64", index=["a", "b"]) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "dtype", + [ + object, + pytest.param("string[pyarrow_numpy]", marks=td.skip_if_no("pyarrow")), + pytest.param("string[pyarrow]", marks=td.skip_if_no("pyarrow")), + ], +) +def test_size_strings(dtype): + # GH#55627 + df = DataFrame({"a": ["a", "a", "b"], "b": "a"}, dtype=dtype) + result = df.groupby("a")["b"].size() + exp_dtype = "Int64" if dtype == "string[pyarrow]" else "int64" + expected = Series( + [2, 1], + index=Index(["a", "b"], name="a", dtype=dtype), + name="b", + dtype=exp_dtype, + ) + tm.assert_series_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/methods/test_skew.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/methods/test_skew.py new file mode 100644 index 0000000000000000000000000000000000000000..563da89b6ab24a898f042f0e21377ccc2709b072 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/methods/test_skew.py @@ -0,0 +1,27 @@ +import numpy as np + +import pandas as pd +import pandas._testing as tm + + +def test_groupby_skew_equivalence(): + # Test that that groupby skew method (which uses libgroupby.group_skew) + # matches the results of operating group-by-group (which uses nanops.nanskew) + nrows = 1000 + ngroups = 3 + ncols = 2 + nan_frac = 0.05 + + arr = np.random.default_rng(2).standard_normal((nrows, ncols)) + arr[np.random.default_rng(2).random(nrows) < nan_frac] = np.nan + + df = pd.DataFrame(arr) + grps = np.random.default_rng(2).integers(0, ngroups, size=nrows) + gb = df.groupby(grps) + + result = gb.skew() + + grpwise = [grp.skew().to_frame(i).T for i, grp in gb] + expected = pd.concat(grpwise, axis=0) + expected.index = expected.index.astype(result.index.dtype) # 32bit builds + tm.assert_frame_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/methods/test_value_counts.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/methods/test_value_counts.py new file mode 100644 index 0000000000000000000000000000000000000000..8e25177368d8b7bbc412c930dc6ef2a278aa29db --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/methods/test_value_counts.py @@ -0,0 +1,1241 @@ +""" +these are systematically testing all of the args to value_counts +with different size combinations. This is to ensure stability of the sorting +and proper parameter handling +""" + + +import numpy as np +import pytest + +import pandas.util._test_decorators as td + +from pandas import ( + Categorical, + CategoricalIndex, + DataFrame, + Grouper, + Index, + MultiIndex, + Series, + date_range, + to_datetime, +) +import pandas._testing as tm +from pandas.util.version import Version + + +def tests_value_counts_index_names_category_column(): + # GH44324 Missing name of index category column + df = DataFrame( + { + "gender": ["female"], + "country": ["US"], + } + ) + df["gender"] = df["gender"].astype("category") + result = df.groupby("country")["gender"].value_counts() + + # Construct expected, very specific multiindex + df_mi_expected = DataFrame([["US", "female"]], columns=["country", "gender"]) + df_mi_expected["gender"] = df_mi_expected["gender"].astype("category") + mi_expected = MultiIndex.from_frame(df_mi_expected) + expected = Series([1], index=mi_expected, name="count") + + tm.assert_series_equal(result, expected) + + +def seed_df(seed_nans, n, m): + days = date_range("2015-08-24", periods=10) + + frame = DataFrame( + { + "1st": np.random.default_rng(2).choice(list("abcd"), n), + "2nd": np.random.default_rng(2).choice(days, n), + "3rd": np.random.default_rng(2).integers(1, m + 1, n), + } + ) + + if seed_nans: + # Explicitly cast to float to avoid implicit cast when setting nan + frame["3rd"] = frame["3rd"].astype("float") + frame.loc[1::11, "1st"] = np.nan + frame.loc[3::17, "2nd"] = np.nan + frame.loc[7::19, "3rd"] = np.nan + frame.loc[8::19, "3rd"] = np.nan + frame.loc[9::19, "3rd"] = np.nan + + return frame + + +@pytest.mark.slow +@pytest.mark.parametrize("seed_nans", [True, False]) +@pytest.mark.parametrize("num_rows", [10, 50]) +@pytest.mark.parametrize("max_int", [5, 20]) +@pytest.mark.parametrize("keys", ["1st", "2nd", ["1st", "2nd"]], ids=repr) +@pytest.mark.parametrize("bins", [None, [0, 5]], ids=repr) +@pytest.mark.parametrize("isort", [True, False]) +@pytest.mark.parametrize("normalize, name", [(True, "proportion"), (False, "count")]) +@pytest.mark.parametrize("sort", [True, False]) +@pytest.mark.parametrize("ascending", [True, False]) +@pytest.mark.parametrize("dropna", [True, False]) +def test_series_groupby_value_counts( + seed_nans, + num_rows, + max_int, + keys, + bins, + isort, + normalize, + name, + sort, + ascending, + dropna, +): + df = seed_df(seed_nans, num_rows, max_int) + + def rebuild_index(df): + arr = list(map(df.index.get_level_values, range(df.index.nlevels))) + df.index = MultiIndex.from_arrays(arr, names=df.index.names) + return df + + kwargs = { + "normalize": normalize, + "sort": sort, + "ascending": ascending, + "dropna": dropna, + "bins": bins, + } + + gr = df.groupby(keys, sort=isort) + left = gr["3rd"].value_counts(**kwargs) + + gr = df.groupby(keys, sort=isort) + right = gr["3rd"].apply(Series.value_counts, **kwargs) + right.index.names = right.index.names[:-1] + ["3rd"] + # https://github.com/pandas-dev/pandas/issues/49909 + right = right.rename(name) + + # have to sort on index because of unstable sort on values + left, right = map(rebuild_index, (left, right)) # xref GH9212 + tm.assert_series_equal(left.sort_index(), right.sort_index()) + + +@pytest.mark.parametrize("utc", [True, False]) +def test_series_groupby_value_counts_with_grouper(utc): + # GH28479 + df = DataFrame( + { + "Timestamp": [ + 1565083561, + 1565083561 + 86400, + 1565083561 + 86500, + 1565083561 + 86400 * 2, + 1565083561 + 86400 * 3, + 1565083561 + 86500 * 3, + 1565083561 + 86400 * 4, + ], + "Food": ["apple", "apple", "banana", "banana", "orange", "orange", "pear"], + } + ).drop([3]) + + df["Datetime"] = to_datetime(df["Timestamp"], utc=utc, unit="s") + dfg = df.groupby(Grouper(freq="1D", key="Datetime")) + + # have to sort on index because of unstable sort on values xref GH9212 + result = dfg["Food"].value_counts().sort_index() + expected = dfg["Food"].apply(Series.value_counts).sort_index() + expected.index.names = result.index.names + # https://github.com/pandas-dev/pandas/issues/49909 + expected = expected.rename("count") + + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("columns", [["A", "B"], ["A", "B", "C"]]) +def test_series_groupby_value_counts_empty(columns): + # GH39172 + df = DataFrame(columns=columns) + dfg = df.groupby(columns[:-1]) + + result = dfg[columns[-1]].value_counts() + expected = Series([], dtype=result.dtype, name="count") + expected.index = MultiIndex.from_arrays([[]] * len(columns), names=columns) + + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("columns", [["A", "B"], ["A", "B", "C"]]) +def test_series_groupby_value_counts_one_row(columns): + # GH42618 + df = DataFrame(data=[range(len(columns))], columns=columns) + dfg = df.groupby(columns[:-1]) + + result = dfg[columns[-1]].value_counts() + expected = df.value_counts() + + tm.assert_series_equal(result, expected) + + +def test_series_groupby_value_counts_on_categorical(): + # GH38672 + + s = Series(Categorical(["a"], categories=["a", "b"])) + result = s.groupby([0]).value_counts() + + expected = Series( + data=[1, 0], + index=MultiIndex.from_arrays( + [ + np.array([0, 0]), + CategoricalIndex( + ["a", "b"], categories=["a", "b"], ordered=False, dtype="category" + ), + ] + ), + name="count", + ) + + # Expected: + # 0 a 1 + # b 0 + # dtype: int64 + + tm.assert_series_equal(result, expected) + + +def test_series_groupby_value_counts_no_sort(): + # GH#50482 + df = DataFrame( + { + "gender": ["male", "male", "female", "male", "female", "male"], + "education": ["low", "medium", "high", "low", "high", "low"], + "country": ["US", "FR", "US", "FR", "FR", "FR"], + } + ) + gb = df.groupby(["country", "gender"], sort=False)["education"] + result = gb.value_counts(sort=False) + index = MultiIndex( + levels=[["US", "FR"], ["male", "female"], ["low", "medium", "high"]], + codes=[[0, 1, 0, 1, 1], [0, 0, 1, 0, 1], [0, 1, 2, 0, 2]], + names=["country", "gender", "education"], + ) + expected = Series([1, 1, 1, 2, 1], index=index, name="count") + tm.assert_series_equal(result, expected) + + +@pytest.fixture +def education_df(): + return DataFrame( + { + "gender": ["male", "male", "female", "male", "female", "male"], + "education": ["low", "medium", "high", "low", "high", "low"], + "country": ["US", "FR", "US", "FR", "FR", "FR"], + } + ) + + +def test_axis(education_df): + msg = "DataFrame.groupby with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + gp = education_df.groupby("country", axis=1) + with pytest.raises(NotImplementedError, match="axis"): + gp.value_counts() + + +def test_bad_subset(education_df): + gp = education_df.groupby("country") + with pytest.raises(ValueError, match="subset"): + gp.value_counts(subset=["country"]) + + +def test_basic(education_df, request): + # gh43564 + if Version(np.__version__) >= Version("1.25"): + request.applymarker( + pytest.mark.xfail( + reason=( + "pandas default unstable sorting of duplicates" + "issue with numpy>=1.25 with AVX instructions" + ), + strict=False, + ) + ) + result = education_df.groupby("country")[["gender", "education"]].value_counts( + normalize=True + ) + expected = Series( + data=[0.5, 0.25, 0.25, 0.5, 0.5], + index=MultiIndex.from_tuples( + [ + ("FR", "male", "low"), + ("FR", "female", "high"), + ("FR", "male", "medium"), + ("US", "female", "high"), + ("US", "male", "low"), + ], + names=["country", "gender", "education"], + ), + name="proportion", + ) + tm.assert_series_equal(result, expected) + + +def _frame_value_counts(df, keys, normalize, sort, ascending): + return df[keys].value_counts(normalize=normalize, sort=sort, ascending=ascending) + + +@pytest.mark.parametrize("groupby", ["column", "array", "function"]) +@pytest.mark.parametrize("normalize, name", [(True, "proportion"), (False, "count")]) +@pytest.mark.parametrize( + "sort, ascending", + [ + (False, None), + (True, True), + (True, False), + ], +) +@pytest.mark.parametrize("as_index", [True, False]) +@pytest.mark.parametrize("frame", [True, False]) +def test_against_frame_and_seriesgroupby( + education_df, groupby, normalize, name, sort, ascending, as_index, frame, request +): + # test all parameters: + # - Use column, array or function as by= parameter + # - Whether or not to normalize + # - Whether or not to sort and how + # - Whether or not to use the groupby as an index + # - 3-way compare against: + # - apply with :meth:`~DataFrame.value_counts` + # - `~SeriesGroupBy.value_counts` + if Version(np.__version__) >= Version("1.25") and frame and sort and normalize: + request.applymarker( + pytest.mark.xfail( + reason=( + "pandas default unstable sorting of duplicates" + "issue with numpy>=1.25 with AVX instructions" + ), + strict=False, + ) + ) + by = { + "column": "country", + "array": education_df["country"].values, + "function": lambda x: education_df["country"][x] == "US", + }[groupby] + + gp = education_df.groupby(by=by, as_index=as_index) + result = gp[["gender", "education"]].value_counts( + normalize=normalize, sort=sort, ascending=ascending + ) + if frame: + # compare against apply with DataFrame value_counts + warn = DeprecationWarning if groupby == "column" else None + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(warn, match=msg): + expected = gp.apply( + _frame_value_counts, ["gender", "education"], normalize, sort, ascending + ) + + if as_index: + tm.assert_series_equal(result, expected) + else: + name = "proportion" if normalize else "count" + expected = expected.reset_index().rename({0: name}, axis=1) + if groupby == "column": + expected = expected.rename({"level_0": "country"}, axis=1) + expected["country"] = np.where(expected["country"], "US", "FR") + elif groupby == "function": + expected["level_0"] = expected["level_0"] == 1 + else: + expected["level_0"] = np.where(expected["level_0"], "US", "FR") + tm.assert_frame_equal(result, expected) + else: + # compare against SeriesGroupBy value_counts + education_df["both"] = education_df["gender"] + "-" + education_df["education"] + expected = gp["both"].value_counts( + normalize=normalize, sort=sort, ascending=ascending + ) + expected.name = name + if as_index: + index_frame = expected.index.to_frame(index=False) + index_frame["gender"] = index_frame["both"].str.split("-").str.get(0) + index_frame["education"] = index_frame["both"].str.split("-").str.get(1) + del index_frame["both"] + index_frame = index_frame.rename({0: None}, axis=1) + expected.index = MultiIndex.from_frame(index_frame) + tm.assert_series_equal(result, expected) + else: + expected.insert(1, "gender", expected["both"].str.split("-").str.get(0)) + expected.insert(2, "education", expected["both"].str.split("-").str.get(1)) + del expected["both"] + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "dtype", + [ + object, + pytest.param("string[pyarrow_numpy]", marks=td.skip_if_no("pyarrow")), + pytest.param("string[pyarrow]", marks=td.skip_if_no("pyarrow")), + ], +) +@pytest.mark.parametrize("normalize", [True, False]) +@pytest.mark.parametrize( + "sort, ascending, expected_rows, expected_count, expected_group_size", + [ + (False, None, [0, 1, 2, 3, 4], [1, 1, 1, 2, 1], [1, 3, 1, 3, 1]), + (True, False, [3, 0, 1, 2, 4], [2, 1, 1, 1, 1], [3, 1, 3, 1, 1]), + (True, True, [0, 1, 2, 4, 3], [1, 1, 1, 1, 2], [1, 3, 1, 1, 3]), + ], +) +def test_compound( + education_df, + normalize, + sort, + ascending, + expected_rows, + expected_count, + expected_group_size, + dtype, +): + education_df = education_df.astype(dtype) + education_df.columns = education_df.columns.astype(dtype) + # Multiple groupby keys and as_index=False + gp = education_df.groupby(["country", "gender"], as_index=False, sort=False) + result = gp["education"].value_counts( + normalize=normalize, sort=sort, ascending=ascending + ) + expected = DataFrame() + for column in ["country", "gender", "education"]: + expected[column] = [education_df[column][row] for row in expected_rows] + expected = expected.astype(dtype) + expected.columns = expected.columns.astype(dtype) + if normalize: + expected["proportion"] = expected_count + expected["proportion"] /= expected_group_size + if dtype == "string[pyarrow]": + expected["proportion"] = expected["proportion"].convert_dtypes() + else: + expected["count"] = expected_count + if dtype == "string[pyarrow]": + expected["count"] = expected["count"].convert_dtypes() + tm.assert_frame_equal(result, expected) + + +@pytest.fixture +def animals_df(): + return DataFrame( + {"key": [1, 1, 1, 1], "num_legs": [2, 4, 4, 6], "num_wings": [2, 0, 0, 0]}, + index=["falcon", "dog", "cat", "ant"], + ) + + +@pytest.mark.parametrize( + "sort, ascending, normalize, name, expected_data, expected_index", + [ + (False, None, False, "count", [1, 2, 1], [(1, 1, 1), (2, 4, 6), (2, 0, 0)]), + (True, True, False, "count", [1, 1, 2], [(1, 1, 1), (2, 6, 4), (2, 0, 0)]), + (True, False, False, "count", [2, 1, 1], [(1, 1, 1), (4, 2, 6), (0, 2, 0)]), + ( + True, + False, + True, + "proportion", + [0.5, 0.25, 0.25], + [(1, 1, 1), (4, 2, 6), (0, 2, 0)], + ), + ], +) +def test_data_frame_value_counts( + animals_df, sort, ascending, normalize, name, expected_data, expected_index +): + # 3-way compare with :meth:`~DataFrame.value_counts` + # Tests from frame/methods/test_value_counts.py + result_frame = animals_df.value_counts( + sort=sort, ascending=ascending, normalize=normalize + ) + expected = Series( + data=expected_data, + index=MultiIndex.from_arrays( + expected_index, names=["key", "num_legs", "num_wings"] + ), + name=name, + ) + tm.assert_series_equal(result_frame, expected) + + result_frame_groupby = animals_df.groupby("key").value_counts( + sort=sort, ascending=ascending, normalize=normalize + ) + + tm.assert_series_equal(result_frame_groupby, expected) + + +@pytest.fixture +def nulls_df(): + n = np.nan + return DataFrame( + { + "A": [1, 1, n, 4, n, 6, 6, 6, 6], + "B": [1, 1, 3, n, n, 6, 6, 6, 6], + "C": [1, 2, 3, 4, 5, 6, n, 8, n], + "D": [1, 2, 3, 4, 5, 6, 7, n, n], + } + ) + + +@pytest.mark.parametrize( + "group_dropna, count_dropna, expected_rows, expected_values", + [ + ( + False, + False, + [0, 1, 3, 5, 7, 6, 8, 2, 4], + [0.5, 0.5, 1.0, 0.25, 0.25, 0.25, 0.25, 1.0, 1.0], + ), + (False, True, [0, 1, 3, 5, 2, 4], [0.5, 0.5, 1.0, 1.0, 1.0, 1.0]), + (True, False, [0, 1, 5, 7, 6, 8], [0.5, 0.5, 0.25, 0.25, 0.25, 0.25]), + (True, True, [0, 1, 5], [0.5, 0.5, 1.0]), + ], +) +def test_dropna_combinations( + nulls_df, group_dropna, count_dropna, expected_rows, expected_values, request +): + if Version(np.__version__) >= Version("1.25") and not group_dropna: + request.applymarker( + pytest.mark.xfail( + reason=( + "pandas default unstable sorting of duplicates" + "issue with numpy>=1.25 with AVX instructions" + ), + strict=False, + ) + ) + gp = nulls_df.groupby(["A", "B"], dropna=group_dropna) + result = gp.value_counts(normalize=True, sort=True, dropna=count_dropna) + columns = DataFrame() + for column in nulls_df.columns: + columns[column] = [nulls_df[column][row] for row in expected_rows] + index = MultiIndex.from_frame(columns) + expected = Series(data=expected_values, index=index, name="proportion") + tm.assert_series_equal(result, expected) + + +@pytest.fixture +def names_with_nulls_df(nulls_fixture): + return DataFrame( + { + "key": [1, 1, 1, 1], + "first_name": ["John", "Anne", "John", "Beth"], + "middle_name": ["Smith", nulls_fixture, nulls_fixture, "Louise"], + }, + ) + + +@pytest.mark.parametrize( + "dropna, expected_data, expected_index", + [ + ( + True, + [1, 1], + MultiIndex.from_arrays( + [(1, 1), ("Beth", "John"), ("Louise", "Smith")], + names=["key", "first_name", "middle_name"], + ), + ), + ( + False, + [1, 1, 1, 1], + MultiIndex( + levels=[ + Index([1]), + Index(["Anne", "Beth", "John"]), + Index(["Louise", "Smith", np.nan]), + ], + codes=[[0, 0, 0, 0], [0, 1, 2, 2], [2, 0, 1, 2]], + names=["key", "first_name", "middle_name"], + ), + ), + ], +) +@pytest.mark.parametrize("normalize, name", [(False, "count"), (True, "proportion")]) +def test_data_frame_value_counts_dropna( + names_with_nulls_df, dropna, normalize, name, expected_data, expected_index +): + # GH 41334 + # 3-way compare with :meth:`~DataFrame.value_counts` + # Tests with nulls from frame/methods/test_value_counts.py + result_frame = names_with_nulls_df.value_counts(dropna=dropna, normalize=normalize) + expected = Series( + data=expected_data, + index=expected_index, + name=name, + ) + if normalize: + expected /= float(len(expected_data)) + + tm.assert_series_equal(result_frame, expected) + + result_frame_groupby = names_with_nulls_df.groupby("key").value_counts( + dropna=dropna, normalize=normalize + ) + + tm.assert_series_equal(result_frame_groupby, expected) + + +@pytest.mark.parametrize("as_index", [False, True]) +@pytest.mark.parametrize("observed", [False, True]) +@pytest.mark.parametrize( + "normalize, name, expected_data", + [ + ( + False, + "count", + np.array([2, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0], dtype=np.int64), + ), + ( + True, + "proportion", + np.array([0.5, 0.25, 0.25, 0.0, 0.0, 0.0, 0.5, 0.5, 0.0, 0.0, 0.0, 0.0]), + ), + ], +) +def test_categorical_single_grouper_with_only_observed_categories( + education_df, as_index, observed, normalize, name, expected_data, request +): + # Test single categorical grouper with only observed grouping categories + # when non-groupers are also categorical + if Version(np.__version__) >= Version("1.25"): + request.applymarker( + pytest.mark.xfail( + reason=( + "pandas default unstable sorting of duplicates" + "issue with numpy>=1.25 with AVX instructions" + ), + strict=False, + ) + ) + + gp = education_df.astype("category").groupby( + "country", as_index=as_index, observed=observed + ) + result = gp.value_counts(normalize=normalize) + + expected_index = MultiIndex.from_tuples( + [ + ("FR", "male", "low"), + ("FR", "female", "high"), + ("FR", "male", "medium"), + ("FR", "female", "low"), + ("FR", "female", "medium"), + ("FR", "male", "high"), + ("US", "female", "high"), + ("US", "male", "low"), + ("US", "female", "low"), + ("US", "female", "medium"), + ("US", "male", "high"), + ("US", "male", "medium"), + ], + names=["country", "gender", "education"], + ) + + expected_series = Series( + data=expected_data, + index=expected_index, + name=name, + ) + for i in range(3): + expected_series.index = expected_series.index.set_levels( + CategoricalIndex(expected_series.index.levels[i]), level=i + ) + + if as_index: + tm.assert_series_equal(result, expected_series) + else: + expected = expected_series.reset_index( + name="proportion" if normalize else "count" + ) + tm.assert_frame_equal(result, expected) + + +def assert_categorical_single_grouper( + education_df, as_index, observed, expected_index, normalize, name, expected_data +): + # Test single categorical grouper when non-groupers are also categorical + education_df = education_df.copy().astype("category") + + # Add non-observed grouping categories + education_df["country"] = education_df["country"].cat.add_categories(["ASIA"]) + + gp = education_df.groupby("country", as_index=as_index, observed=observed) + result = gp.value_counts(normalize=normalize) + + expected_series = Series( + data=expected_data, + index=MultiIndex.from_tuples( + expected_index, + names=["country", "gender", "education"], + ), + name=name, + ) + for i in range(3): + index_level = CategoricalIndex(expected_series.index.levels[i]) + if i == 0: + index_level = index_level.set_categories( + education_df["country"].cat.categories + ) + expected_series.index = expected_series.index.set_levels(index_level, level=i) + + if as_index: + tm.assert_series_equal(result, expected_series) + else: + expected = expected_series.reset_index(name=name) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("as_index", [True, False]) +@pytest.mark.parametrize( + "normalize, name, expected_data", + [ + ( + False, + "count", + np.array([2, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0], dtype=np.int64), + ), + ( + True, + "proportion", + np.array([0.5, 0.25, 0.25, 0.0, 0.0, 0.0, 0.5, 0.5, 0.0, 0.0, 0.0, 0.0]), + ), + ], +) +def test_categorical_single_grouper_observed_true( + education_df, as_index, normalize, name, expected_data, request +): + # GH#46357 + + if Version(np.__version__) >= Version("1.25"): + request.applymarker( + pytest.mark.xfail( + reason=( + "pandas default unstable sorting of duplicates" + "issue with numpy>=1.25 with AVX instructions" + ), + strict=False, + ) + ) + + expected_index = [ + ("FR", "male", "low"), + ("FR", "female", "high"), + ("FR", "male", "medium"), + ("FR", "female", "low"), + ("FR", "female", "medium"), + ("FR", "male", "high"), + ("US", "female", "high"), + ("US", "male", "low"), + ("US", "female", "low"), + ("US", "female", "medium"), + ("US", "male", "high"), + ("US", "male", "medium"), + ] + + assert_categorical_single_grouper( + education_df=education_df, + as_index=as_index, + observed=True, + expected_index=expected_index, + normalize=normalize, + name=name, + expected_data=expected_data, + ) + + +@pytest.mark.parametrize("as_index", [True, False]) +@pytest.mark.parametrize( + "normalize, name, expected_data", + [ + ( + False, + "count", + np.array( + [2, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=np.int64 + ), + ), + ( + True, + "proportion", + np.array( + [ + 0.5, + 0.25, + 0.25, + 0.0, + 0.0, + 0.0, + 0.5, + 0.5, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ] + ), + ), + ], +) +def test_categorical_single_grouper_observed_false( + education_df, as_index, normalize, name, expected_data, request +): + # GH#46357 + + if Version(np.__version__) >= Version("1.25"): + request.applymarker( + pytest.mark.xfail( + reason=( + "pandas default unstable sorting of duplicates" + "issue with numpy>=1.25 with AVX instructions" + ), + strict=False, + ) + ) + + expected_index = [ + ("FR", "male", "low"), + ("FR", "female", "high"), + ("FR", "male", "medium"), + ("FR", "female", "low"), + ("FR", "female", "medium"), + ("FR", "male", "high"), + ("US", "female", "high"), + ("US", "male", "low"), + ("US", "female", "low"), + ("US", "female", "medium"), + ("US", "male", "high"), + ("US", "male", "medium"), + ("ASIA", "female", "high"), + ("ASIA", "female", "low"), + ("ASIA", "female", "medium"), + ("ASIA", "male", "high"), + ("ASIA", "male", "low"), + ("ASIA", "male", "medium"), + ] + + assert_categorical_single_grouper( + education_df=education_df, + as_index=as_index, + observed=False, + expected_index=expected_index, + normalize=normalize, + name=name, + expected_data=expected_data, + ) + + +@pytest.mark.parametrize("as_index", [True, False]) +@pytest.mark.parametrize( + "observed, expected_index", + [ + ( + False, + [ + ("FR", "high", "female"), + ("FR", "high", "male"), + ("FR", "low", "male"), + ("FR", "low", "female"), + ("FR", "medium", "male"), + ("FR", "medium", "female"), + ("US", "high", "female"), + ("US", "high", "male"), + ("US", "low", "male"), + ("US", "low", "female"), + ("US", "medium", "female"), + ("US", "medium", "male"), + ], + ), + ( + True, + [ + ("FR", "high", "female"), + ("FR", "low", "male"), + ("FR", "medium", "male"), + ("US", "high", "female"), + ("US", "low", "male"), + ], + ), + ], +) +@pytest.mark.parametrize( + "normalize, name, expected_data", + [ + ( + False, + "count", + np.array([1, 0, 2, 0, 1, 0, 1, 0, 1, 0, 0, 0], dtype=np.int64), + ), + ( + True, + "proportion", + # NaN values corresponds to non-observed groups + np.array([1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0]), + ), + ], +) +def test_categorical_multiple_groupers( + education_df, as_index, observed, expected_index, normalize, name, expected_data +): + # GH#46357 + + # Test multiple categorical groupers when non-groupers are non-categorical + education_df = education_df.copy() + education_df["country"] = education_df["country"].astype("category") + education_df["education"] = education_df["education"].astype("category") + + gp = education_df.groupby( + ["country", "education"], as_index=as_index, observed=observed + ) + result = gp.value_counts(normalize=normalize) + + expected_series = Series( + data=expected_data[expected_data > 0.0] if observed else expected_data, + index=MultiIndex.from_tuples( + expected_index, + names=["country", "education", "gender"], + ), + name=name, + ) + for i in range(2): + expected_series.index = expected_series.index.set_levels( + CategoricalIndex(expected_series.index.levels[i]), level=i + ) + + if as_index: + tm.assert_series_equal(result, expected_series) + else: + expected = expected_series.reset_index( + name="proportion" if normalize else "count" + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("as_index", [False, True]) +@pytest.mark.parametrize("observed", [False, True]) +@pytest.mark.parametrize( + "normalize, name, expected_data", + [ + ( + False, + "count", + np.array([2, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0], dtype=np.int64), + ), + ( + True, + "proportion", + # NaN values corresponds to non-observed groups + np.array([0.5, 0.25, 0.25, 0.0, 0.0, 0.0, 0.5, 0.5, 0.0, 0.0, 0.0, 0.0]), + ), + ], +) +def test_categorical_non_groupers( + education_df, as_index, observed, normalize, name, expected_data, request +): + # GH#46357 Test non-observed categories are included in the result, + # regardless of `observed` + + if Version(np.__version__) >= Version("1.25"): + request.applymarker( + pytest.mark.xfail( + reason=( + "pandas default unstable sorting of duplicates" + "issue with numpy>=1.25 with AVX instructions" + ), + strict=False, + ) + ) + + education_df = education_df.copy() + education_df["gender"] = education_df["gender"].astype("category") + education_df["education"] = education_df["education"].astype("category") + + gp = education_df.groupby("country", as_index=as_index, observed=observed) + result = gp.value_counts(normalize=normalize) + + expected_index = [ + ("FR", "male", "low"), + ("FR", "female", "high"), + ("FR", "male", "medium"), + ("FR", "female", "low"), + ("FR", "female", "medium"), + ("FR", "male", "high"), + ("US", "female", "high"), + ("US", "male", "low"), + ("US", "female", "low"), + ("US", "female", "medium"), + ("US", "male", "high"), + ("US", "male", "medium"), + ] + expected_series = Series( + data=expected_data, + index=MultiIndex.from_tuples( + expected_index, + names=["country", "gender", "education"], + ), + name=name, + ) + for i in range(1, 3): + expected_series.index = expected_series.index.set_levels( + CategoricalIndex(expected_series.index.levels[i]), level=i + ) + + if as_index: + tm.assert_series_equal(result, expected_series) + else: + expected = expected_series.reset_index( + name="proportion" if normalize else "count" + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "normalize, expected_label, expected_values", + [ + (False, "count", [1, 1, 1]), + (True, "proportion", [0.5, 0.5, 1.0]), + ], +) +def test_mixed_groupings(normalize, expected_label, expected_values): + # Test multiple groupings + df = DataFrame({"A": [1, 2, 1], "B": [1, 2, 3]}) + gp = df.groupby([[4, 5, 4], "A", lambda i: 7 if i == 1 else 8], as_index=False) + result = gp.value_counts(sort=True, normalize=normalize) + expected = DataFrame( + { + "level_0": np.array([4, 4, 5], dtype=int), + "A": [1, 1, 2], + "level_2": [8, 8, 7], + "B": [1, 3, 2], + expected_label: expected_values, + } + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "test, columns, expected_names", + [ + ("repeat", list("abbde"), ["a", None, "d", "b", "b", "e"]), + ("level", list("abcd") + ["level_1"], ["a", None, "d", "b", "c", "level_1"]), + ], +) +@pytest.mark.parametrize("as_index", [False, True]) +def test_column_label_duplicates(test, columns, expected_names, as_index): + # GH 44992 + # Test for duplicate input column labels and generated duplicate labels + df = DataFrame([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]], columns=columns) + expected_data = [(1, 0, 7, 3, 5, 9), (2, 1, 8, 4, 6, 10)] + keys = ["a", np.array([0, 1], dtype=np.int64), "d"] + result = df.groupby(keys, as_index=as_index).value_counts() + if as_index: + expected = Series( + data=(1, 1), + index=MultiIndex.from_tuples( + expected_data, + names=expected_names, + ), + name="count", + ) + tm.assert_series_equal(result, expected) + else: + expected_data = [list(row) + [1] for row in expected_data] + expected_columns = list(expected_names) + expected_columns[1] = "level_1" + expected_columns.append("count") + expected = DataFrame(expected_data, columns=expected_columns) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "normalize, expected_label", + [ + (False, "count"), + (True, "proportion"), + ], +) +def test_result_label_duplicates(normalize, expected_label): + # Test for result column label duplicating an input column label + gb = DataFrame([[1, 2, 3]], columns=["a", "b", expected_label]).groupby( + "a", as_index=False + ) + msg = f"Column label '{expected_label}' is duplicate of result column" + with pytest.raises(ValueError, match=msg): + gb.value_counts(normalize=normalize) + + +def test_ambiguous_grouping(): + # Test that groupby is not confused by groupings length equal to row count + df = DataFrame({"a": [1, 1]}) + gb = df.groupby(np.array([1, 1], dtype=np.int64)) + result = gb.value_counts() + expected = Series( + [2], index=MultiIndex.from_tuples([[1, 1]], names=[None, "a"]), name="count" + ) + tm.assert_series_equal(result, expected) + + +def test_subset_overlaps_gb_key_raises(): + # GH 46383 + df = DataFrame({"c1": ["a", "b", "c"], "c2": ["x", "y", "y"]}, index=[0, 1, 1]) + msg = "Keys {'c1'} in subset cannot be in the groupby column keys." + with pytest.raises(ValueError, match=msg): + df.groupby("c1").value_counts(subset=["c1"]) + + +def test_subset_doesnt_exist_in_frame(): + # GH 46383 + df = DataFrame({"c1": ["a", "b", "c"], "c2": ["x", "y", "y"]}, index=[0, 1, 1]) + msg = "Keys {'c3'} in subset do not exist in the DataFrame." + with pytest.raises(ValueError, match=msg): + df.groupby("c1").value_counts(subset=["c3"]) + + +def test_subset(): + # GH 46383 + df = DataFrame({"c1": ["a", "b", "c"], "c2": ["x", "y", "y"]}, index=[0, 1, 1]) + result = df.groupby(level=0).value_counts(subset=["c2"]) + expected = Series( + [1, 2], + index=MultiIndex.from_arrays([[0, 1], ["x", "y"]], names=[None, "c2"]), + name="count", + ) + tm.assert_series_equal(result, expected) + + +def test_subset_duplicate_columns(): + # GH 46383 + df = DataFrame( + [["a", "x", "x"], ["b", "y", "y"], ["b", "y", "y"]], + index=[0, 1, 1], + columns=["c1", "c2", "c2"], + ) + result = df.groupby(level=0).value_counts(subset=["c2"]) + expected = Series( + [1, 2], + index=MultiIndex.from_arrays( + [[0, 1], ["x", "y"], ["x", "y"]], names=[None, "c2", "c2"] + ), + name="count", + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("utc", [True, False]) +def test_value_counts_time_grouper(utc, unit): + # GH#50486 + df = DataFrame( + { + "Timestamp": [ + 1565083561, + 1565083561 + 86400, + 1565083561 + 86500, + 1565083561 + 86400 * 2, + 1565083561 + 86400 * 3, + 1565083561 + 86500 * 3, + 1565083561 + 86400 * 4, + ], + "Food": ["apple", "apple", "banana", "banana", "orange", "orange", "pear"], + } + ).drop([3]) + + df["Datetime"] = to_datetime(df["Timestamp"], utc=utc, unit="s").dt.as_unit(unit) + gb = df.groupby(Grouper(freq="1D", key="Datetime")) + result = gb.value_counts() + dates = to_datetime( + ["2019-08-06", "2019-08-07", "2019-08-09", "2019-08-10"], utc=utc + ).as_unit(unit) + timestamps = df["Timestamp"].unique() + index = MultiIndex( + levels=[dates, timestamps, ["apple", "banana", "orange", "pear"]], + codes=[[0, 1, 1, 2, 2, 3], range(6), [0, 0, 1, 2, 2, 3]], + names=["Datetime", "Timestamp", "Food"], + ) + expected = Series(1, index=index, name="count") + tm.assert_series_equal(result, expected) + + +def test_value_counts_integer_columns(): + # GH#55627 + df = DataFrame({1: ["a", "a", "a"], 2: ["a", "a", "d"], 3: ["a", "b", "c"]}) + gp = df.groupby([1, 2], as_index=False, sort=False) + result = gp[3].value_counts() + expected = DataFrame( + {1: ["a", "a", "a"], 2: ["a", "a", "d"], 3: ["a", "b", "c"], "count": 1} + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("vc_sort", [True, False]) +@pytest.mark.parametrize("normalize", [True, False]) +def test_value_counts_sort(sort, vc_sort, normalize): + # GH#55951 + df = DataFrame({"a": [2, 1, 1, 1], 0: [3, 4, 3, 3]}) + gb = df.groupby("a", sort=sort) + result = gb.value_counts(sort=vc_sort, normalize=normalize) + + if normalize: + values = [2 / 3, 1 / 3, 1.0] + else: + values = [2, 1, 1] + index = MultiIndex( + levels=[[1, 2], [3, 4]], codes=[[0, 0, 1], [0, 1, 0]], names=["a", 0] + ) + expected = Series(values, index=index, name="proportion" if normalize else "count") + if sort and vc_sort: + taker = [0, 1, 2] + elif sort and not vc_sort: + taker = [0, 1, 2] + elif not sort and vc_sort: + taker = [0, 2, 1] + else: + taker = [2, 1, 0] + expected = expected.take(taker) + + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("vc_sort", [True, False]) +@pytest.mark.parametrize("normalize", [True, False]) +def test_value_counts_sort_categorical(sort, vc_sort, normalize): + # GH#55951 + df = DataFrame({"a": [2, 1, 1, 1], 0: [3, 4, 3, 3]}, dtype="category") + gb = df.groupby("a", sort=sort, observed=True) + result = gb.value_counts(sort=vc_sort, normalize=normalize) + + if normalize: + values = [2 / 3, 1 / 3, 1.0, 0.0] + else: + values = [2, 1, 1, 0] + name = "proportion" if normalize else "count" + expected = DataFrame( + { + "a": Categorical([1, 1, 2, 2]), + 0: Categorical([3, 4, 3, 4]), + name: values, + } + ).set_index(["a", 0])[name] + if sort and vc_sort: + taker = [0, 1, 2, 3] + elif sort and not vc_sort: + taker = [0, 1, 2, 3] + elif not sort and vc_sort: + taker = [0, 2, 1, 3] + else: + taker = [2, 3, 0, 1] + expected = expected.take(taker) + + tm.assert_series_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_all_methods.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_all_methods.py new file mode 100644 index 0000000000000000000000000000000000000000..ad35bec70f668f1df9808d1aebec2b1405424bc1 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_all_methods.py @@ -0,0 +1,83 @@ +""" +Tests that apply to all groupby operation methods. + +The only tests that should appear here are those that use the `groupby_func` fixture. +Even if it does use that fixture, prefer a more specific test file if it available +such as: + + - test_categorical + - test_groupby_dropna + - test_groupby_subclass + - test_raises +""" + +import pytest + +import pandas as pd +from pandas import DataFrame +import pandas._testing as tm +from pandas.tests.groupby import get_groupby_method_args + + +def test_multiindex_group_all_columns_when_empty(groupby_func): + # GH 32464 + df = DataFrame({"a": [], "b": [], "c": []}).set_index(["a", "b", "c"]) + gb = df.groupby(["a", "b", "c"], group_keys=False) + method = getattr(gb, groupby_func) + args = get_groupby_method_args(groupby_func, df) + + warn = FutureWarning if groupby_func == "fillna" else None + warn_msg = "DataFrameGroupBy.fillna is deprecated" + with tm.assert_produces_warning(warn, match=warn_msg): + result = method(*args).index + expected = df.index + tm.assert_index_equal(result, expected) + + +def test_duplicate_columns(request, groupby_func, as_index): + # GH#50806 + if groupby_func == "corrwith": + msg = "GH#50845 - corrwith fails when there are duplicate columns" + request.applymarker(pytest.mark.xfail(reason=msg)) + df = DataFrame([[1, 3, 6], [1, 4, 7], [2, 5, 8]], columns=list("abb")) + args = get_groupby_method_args(groupby_func, df) + gb = df.groupby("a", as_index=as_index) + warn = FutureWarning if groupby_func == "fillna" else None + warn_msg = "DataFrameGroupBy.fillna is deprecated" + with tm.assert_produces_warning(warn, match=warn_msg): + result = getattr(gb, groupby_func)(*args) + + expected_df = df.set_axis(["a", "b", "c"], axis=1) + expected_args = get_groupby_method_args(groupby_func, expected_df) + expected_gb = expected_df.groupby("a", as_index=as_index) + warn = FutureWarning if groupby_func == "fillna" else None + warn_msg = "DataFrameGroupBy.fillna is deprecated" + with tm.assert_produces_warning(warn, match=warn_msg): + expected = getattr(expected_gb, groupby_func)(*expected_args) + if groupby_func not in ("size", "ngroup", "cumcount"): + expected = expected.rename(columns={"c": "b"}) + tm.assert_equal(result, expected) + + +@pytest.mark.parametrize( + "idx", + [ + pd.Index(["a", "a"], name="foo"), + pd.MultiIndex.from_tuples((("a", "a"), ("a", "a")), names=["foo", "bar"]), + ], +) +def test_dup_labels_output_shape(groupby_func, idx): + if groupby_func in {"size", "ngroup", "cumcount"}: + pytest.skip(f"Not applicable for {groupby_func}") + + df = DataFrame([[1, 1]], columns=idx) + grp_by = df.groupby([0]) + + args = get_groupby_method_args(groupby_func, df) + warn = FutureWarning if groupby_func == "fillna" else None + warn_msg = "DataFrameGroupBy.fillna is deprecated" + with tm.assert_produces_warning(warn, match=warn_msg): + result = getattr(grp_by, groupby_func)(*args) + + assert result.shape == (1, 2) + tm.assert_index_equal(result.columns, idx) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_api.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_api.py new file mode 100644 index 0000000000000000000000000000000000000000..5c5982954de2f889d3f23d30273cb1a10089315f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_api.py @@ -0,0 +1,265 @@ +""" +Tests of the groupby API, including internal consistency and with other pandas objects. + +Tests in this file should only check the existence, names, and arguments of groupby +methods. It should not test the results of any groupby operation. +""" + +import inspect + +import pytest + +from pandas import ( + DataFrame, + Series, +) +from pandas.core.groupby.base import ( + groupby_other_methods, + reduction_kernels, + transformation_kernels, +) +from pandas.core.groupby.generic import ( + DataFrameGroupBy, + SeriesGroupBy, +) + + +def test_tab_completion(multiindex_dataframe_random_data): + grp = multiindex_dataframe_random_data.groupby(level="second") + results = {v for v in dir(grp) if not v.startswith("_")} + expected = { + "A", + "B", + "C", + "agg", + "aggregate", + "apply", + "boxplot", + "filter", + "first", + "get_group", + "groups", + "hist", + "indices", + "last", + "max", + "mean", + "median", + "min", + "ngroups", + "nth", + "ohlc", + "plot", + "prod", + "size", + "std", + "sum", + "transform", + "var", + "sem", + "count", + "nunique", + "head", + "describe", + "cummax", + "quantile", + "rank", + "cumprod", + "tail", + "resample", + "cummin", + "fillna", + "cumsum", + "cumcount", + "ngroup", + "all", + "shift", + "skew", + "take", + "pct_change", + "any", + "corr", + "corrwith", + "cov", + "dtypes", + "ndim", + "diff", + "idxmax", + "idxmin", + "ffill", + "bfill", + "rolling", + "expanding", + "pipe", + "sample", + "ewm", + "value_counts", + } + assert results == expected + + +def test_all_methods_categorized(multiindex_dataframe_random_data): + grp = multiindex_dataframe_random_data.groupby( + multiindex_dataframe_random_data.iloc[:, 0] + ) + names = {_ for _ in dir(grp) if not _.startswith("_")} - set( + multiindex_dataframe_random_data.columns + ) + new_names = set(names) + new_names -= reduction_kernels + new_names -= transformation_kernels + new_names -= groupby_other_methods + + assert not reduction_kernels & transformation_kernels + assert not reduction_kernels & groupby_other_methods + assert not transformation_kernels & groupby_other_methods + + # new public method? + if new_names: + msg = f""" +There are uncategorized methods defined on the Grouper class: +{new_names}. + +Was a new method recently added? + +Every public method On Grouper must appear in exactly one the +following three lists defined in pandas.core.groupby.base: +- `reduction_kernels` +- `transformation_kernels` +- `groupby_other_methods` +see the comments in pandas/core/groupby/base.py for guidance on +how to fix this test. + """ + raise AssertionError(msg) + + # removed a public method? + all_categorized = reduction_kernels | transformation_kernels | groupby_other_methods + if names != all_categorized: + msg = f""" +Some methods which are supposed to be on the Grouper class +are missing: +{all_categorized - names}. + +They're still defined in one of the lists that live in pandas/core/groupby/base.py. +If you removed a method, you should update them +""" + raise AssertionError(msg) + + +def test_frame_consistency(groupby_func): + # GH#48028 + if groupby_func in ("first", "last"): + msg = "first and last are entirely different between frame and groupby" + pytest.skip(reason=msg) + + if groupby_func in ("cumcount", "ngroup"): + assert not hasattr(DataFrame, groupby_func) + return + + frame_method = getattr(DataFrame, groupby_func) + gb_method = getattr(DataFrameGroupBy, groupby_func) + result = set(inspect.signature(gb_method).parameters) + if groupby_func == "size": + # "size" is a method on GroupBy but property on DataFrame: + expected = {"self"} + else: + expected = set(inspect.signature(frame_method).parameters) + + # Exclude certain arguments from result and expected depending on the operation + # Some of these may be purposeful inconsistencies between the APIs + exclude_expected, exclude_result = set(), set() + if groupby_func in ("any", "all"): + exclude_expected = {"kwargs", "bool_only", "axis"} + elif groupby_func in ("count",): + exclude_expected = {"numeric_only", "axis"} + elif groupby_func in ("nunique",): + exclude_expected = {"axis"} + elif groupby_func in ("max", "min"): + exclude_expected = {"axis", "kwargs", "skipna"} + exclude_result = {"min_count", "engine", "engine_kwargs"} + elif groupby_func in ("mean", "std", "sum", "var"): + exclude_expected = {"axis", "kwargs", "skipna"} + exclude_result = {"engine", "engine_kwargs"} + elif groupby_func in ("median", "prod", "sem"): + exclude_expected = {"axis", "kwargs", "skipna"} + elif groupby_func in ("backfill", "bfill", "ffill", "pad"): + exclude_expected = {"downcast", "inplace", "axis", "limit_area"} + elif groupby_func in ("cummax", "cummin"): + exclude_expected = {"skipna", "args"} + exclude_result = {"numeric_only"} + elif groupby_func in ("cumprod", "cumsum"): + exclude_expected = {"skipna"} + elif groupby_func in ("pct_change",): + exclude_expected = {"kwargs"} + exclude_result = {"axis"} + elif groupby_func in ("rank",): + exclude_expected = {"numeric_only"} + elif groupby_func in ("quantile",): + exclude_expected = {"method", "axis"} + + # Ensure excluded arguments are actually in the signatures + assert result & exclude_result == exclude_result + assert expected & exclude_expected == exclude_expected + + result -= exclude_result + expected -= exclude_expected + assert result == expected + + +def test_series_consistency(request, groupby_func): + # GH#48028 + if groupby_func in ("first", "last"): + pytest.skip("first and last are entirely different between Series and groupby") + + if groupby_func in ("cumcount", "corrwith", "ngroup"): + assert not hasattr(Series, groupby_func) + return + + series_method = getattr(Series, groupby_func) + gb_method = getattr(SeriesGroupBy, groupby_func) + result = set(inspect.signature(gb_method).parameters) + if groupby_func == "size": + # "size" is a method on GroupBy but property on Series + expected = {"self"} + else: + expected = set(inspect.signature(series_method).parameters) + + # Exclude certain arguments from result and expected depending on the operation + # Some of these may be purposeful inconsistencies between the APIs + exclude_expected, exclude_result = set(), set() + if groupby_func in ("any", "all"): + exclude_expected = {"kwargs", "bool_only", "axis"} + elif groupby_func in ("diff",): + exclude_result = {"axis"} + elif groupby_func in ("max", "min"): + exclude_expected = {"axis", "kwargs", "skipna"} + exclude_result = {"min_count", "engine", "engine_kwargs"} + elif groupby_func in ("mean", "std", "sum", "var"): + exclude_expected = {"axis", "kwargs", "skipna"} + exclude_result = {"engine", "engine_kwargs"} + elif groupby_func in ("median", "prod", "sem"): + exclude_expected = {"axis", "kwargs", "skipna"} + elif groupby_func in ("backfill", "bfill", "ffill", "pad"): + exclude_expected = {"downcast", "inplace", "axis", "limit_area"} + elif groupby_func in ("cummax", "cummin"): + exclude_expected = {"skipna", "args"} + exclude_result = {"numeric_only"} + elif groupby_func in ("cumprod", "cumsum"): + exclude_expected = {"skipna"} + elif groupby_func in ("pct_change",): + exclude_expected = {"kwargs"} + exclude_result = {"axis"} + elif groupby_func in ("rank",): + exclude_expected = {"numeric_only"} + elif groupby_func in ("idxmin", "idxmax"): + exclude_expected = {"args", "kwargs"} + elif groupby_func in ("quantile",): + exclude_result = {"numeric_only"} + + # Ensure excluded arguments are actually in the signatures + assert result & exclude_result == exclude_result + assert expected & exclude_expected == exclude_expected + + result -= exclude_result + expected -= exclude_expected + assert result == expected diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_apply.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_apply.py new file mode 100644 index 0000000000000000000000000000000000000000..0ddacfab8c1026324b1e0721aa80ef6b4535098b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_apply.py @@ -0,0 +1,1606 @@ +from datetime import ( + date, + datetime, +) + +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + DataFrame, + Index, + MultiIndex, + Series, + bdate_range, +) +import pandas._testing as tm +from pandas.tests.groupby import get_groupby_method_args + + +def test_apply_func_that_appends_group_to_list_without_copy(): + # GH: 17718 + + df = DataFrame(1, index=list(range(10)) * 10, columns=[0]).reset_index() + groups = [] + + def store(group): + groups.append(group) + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + df.groupby("index").apply(store) + expected_value = DataFrame( + {"index": [0] * 10, 0: [1] * 10}, index=pd.RangeIndex(0, 100, 10) + ) + + tm.assert_frame_equal(groups[0], expected_value) + + +def test_apply_index_date(using_infer_string): + # GH 5788 + ts = [ + "2011-05-16 00:00", + "2011-05-16 01:00", + "2011-05-16 02:00", + "2011-05-16 03:00", + "2011-05-17 02:00", + "2011-05-17 03:00", + "2011-05-17 04:00", + "2011-05-17 05:00", + "2011-05-18 02:00", + "2011-05-18 03:00", + "2011-05-18 04:00", + "2011-05-18 05:00", + ] + df = DataFrame( + { + "value": [ + 1.40893, + 1.40760, + 1.40750, + 1.40649, + 1.40893, + 1.40760, + 1.40750, + 1.40649, + 1.40893, + 1.40760, + 1.40750, + 1.40649, + ], + }, + index=Index(pd.to_datetime(ts), name="date_time"), + ) + expected = df.groupby(df.index.date).idxmax() + result = df.groupby(df.index.date).apply(lambda x: x.idxmax()) + tm.assert_frame_equal(result, expected) + + +def test_apply_index_date_object(using_infer_string): + # GH 5789 + # don't auto coerce dates + ts = [ + "2011-05-16 00:00", + "2011-05-16 01:00", + "2011-05-16 02:00", + "2011-05-16 03:00", + "2011-05-17 02:00", + "2011-05-17 03:00", + "2011-05-17 04:00", + "2011-05-17 05:00", + "2011-05-18 02:00", + "2011-05-18 03:00", + "2011-05-18 04:00", + "2011-05-18 05:00", + ] + df = DataFrame([row.split() for row in ts], columns=["date", "time"]) + df["value"] = [ + 1.40893, + 1.40760, + 1.40750, + 1.40649, + 1.40893, + 1.40760, + 1.40750, + 1.40649, + 1.40893, + 1.40760, + 1.40750, + 1.40649, + ] + dtype = "string[pyarrow_numpy]" if using_infer_string else object + exp_idx = Index( + ["2011-05-16", "2011-05-17", "2011-05-18"], dtype=dtype, name="date" + ) + expected = Series(["00:00", "02:00", "02:00"], index=exp_idx) + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = df.groupby("date", group_keys=False).apply( + lambda x: x["time"][x["value"].idxmax()] + ) + tm.assert_series_equal(result, expected) + + +def test_apply_trivial(using_infer_string): + # GH 20066 + # trivial apply: ignore input and return a constant dataframe. + df = DataFrame( + {"key": ["a", "a", "b", "b", "a"], "data": [1.0, 2.0, 3.0, 4.0, 5.0]}, + columns=["key", "data"], + ) + dtype = "string" if using_infer_string else "object" + expected = pd.concat([df.iloc[1:], df.iloc[1:]], axis=1, keys=["float64", dtype]) + + msg = "DataFrame.groupby with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + gb = df.groupby([str(x) for x in df.dtypes], axis=1) + result = gb.apply(lambda x: df.iloc[1:]) + + tm.assert_frame_equal(result, expected) + + +def test_apply_trivial_fail(using_infer_string): + # GH 20066 + df = DataFrame( + {"key": ["a", "a", "b", "b", "a"], "data": [1.0, 2.0, 3.0, 4.0, 5.0]}, + columns=["key", "data"], + ) + dtype = "string" if using_infer_string else "object" + expected = pd.concat([df, df], axis=1, keys=["float64", dtype]) + msg = "DataFrame.groupby with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + gb = df.groupby([str(x) for x in df.dtypes], axis=1, group_keys=True) + result = gb.apply(lambda x: df) + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "df, group_names", + [ + (DataFrame({"a": [1, 1, 1, 2, 3], "b": ["a", "a", "a", "b", "c"]}), [1, 2, 3]), + (DataFrame({"a": [0, 0, 1, 1], "b": [0, 1, 0, 1]}), [0, 1]), + (DataFrame({"a": [1]}), [1]), + (DataFrame({"a": [1, 1, 1, 2, 2, 1, 1, 2], "b": range(8)}), [1, 2]), + (DataFrame({"a": [1, 2, 3, 1, 2, 3], "two": [4, 5, 6, 7, 8, 9]}), [1, 2, 3]), + ( + DataFrame( + { + "a": list("aaabbbcccc"), + "B": [3, 4, 3, 6, 5, 2, 1, 9, 5, 4], + "C": [4, 0, 2, 2, 2, 7, 8, 6, 2, 8], + } + ), + ["a", "b", "c"], + ), + (DataFrame([[1, 2, 3], [2, 2, 3]], columns=["a", "b", "c"]), [1, 2]), + ], + ids=[ + "GH2936", + "GH7739 & GH10519", + "GH10519", + "GH2656", + "GH12155", + "GH20084", + "GH21417", + ], +) +def test_group_apply_once_per_group(df, group_names): + # GH2936, GH7739, GH10519, GH2656, GH12155, GH20084, GH21417 + + # This test should ensure that a function is only evaluated + # once per group. Previously the function has been evaluated twice + # on the first group to check if the Cython index slider is safe to use + # This test ensures that the side effect (append to list) is only triggered + # once per group + + names = [] + # cannot parameterize over the functions since they need external + # `names` to detect side effects + + def f_copy(group): + # this takes the fast apply path + names.append(group.name) + return group.copy() + + def f_nocopy(group): + # this takes the slow apply path + names.append(group.name) + return group + + def f_scalar(group): + # GH7739, GH2656 + names.append(group.name) + return 0 + + def f_none(group): + # GH10519, GH12155, GH21417 + names.append(group.name) + + def f_constant_df(group): + # GH2936, GH20084 + names.append(group.name) + return DataFrame({"a": [1], "b": [1]}) + + for func in [f_copy, f_nocopy, f_scalar, f_none, f_constant_df]: + del names[:] + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + df.groupby("a", group_keys=False).apply(func) + assert names == group_names + + +def test_group_apply_once_per_group2(capsys): + # GH: 31111 + # groupby-apply need to execute len(set(group_by_columns)) times + + expected = 2 # Number of times `apply` should call a function for the current test + + df = DataFrame( + { + "group_by_column": [0, 0, 0, 0, 1, 1, 1, 1], + "test_column": ["0", "2", "4", "6", "8", "10", "12", "14"], + }, + index=["0", "2", "4", "6", "8", "10", "12", "14"], + ) + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + df.groupby("group_by_column", group_keys=False).apply( + lambda df: print("function_called") + ) + + result = capsys.readouterr().out.count("function_called") + # If `groupby` behaves unexpectedly, this test will break + assert result == expected + + +def test_apply_fast_slow_identical(): + # GH 31613 + + df = DataFrame({"A": [0, 0, 1], "b": range(3)}) + + # For simple index structures we check for fast/slow apply using + # an identity check on in/output + def slow(group): + return group + + def fast(group): + return group.copy() + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + fast_df = df.groupby("A", group_keys=False).apply(fast) + with tm.assert_produces_warning(DeprecationWarning, match=msg): + slow_df = df.groupby("A", group_keys=False).apply(slow) + + tm.assert_frame_equal(fast_df, slow_df) + + +@pytest.mark.parametrize( + "func", + [ + lambda x: x, + lambda x: x[:], + lambda x: x.copy(deep=False), + lambda x: x.copy(deep=True), + ], +) +def test_groupby_apply_identity_maybecopy_index_identical(func): + # GH 14927 + # Whether the function returns a copy of the input data or not should not + # have an impact on the index structure of the result since this is not + # transparent to the user + + df = DataFrame({"g": [1, 2, 2, 2], "a": [1, 2, 3, 4], "b": [5, 6, 7, 8]}) + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = df.groupby("g", group_keys=False).apply(func) + tm.assert_frame_equal(result, df) + + +def test_apply_with_mixed_dtype(): + # GH3480, apply with mixed dtype on axis=1 breaks in 0.11 + df = DataFrame( + { + "foo1": np.random.default_rng(2).standard_normal(6), + "foo2": ["one", "two", "two", "three", "one", "two"], + } + ) + result = df.apply(lambda x: x, axis=1).dtypes + expected = df.dtypes + tm.assert_series_equal(result, expected) + + # GH 3610 incorrect dtype conversion with as_index=False + df = DataFrame({"c1": [1, 2, 6, 6, 8]}) + df["c2"] = df.c1 / 2.0 + result1 = df.groupby("c2").mean().reset_index().c2 + result2 = df.groupby("c2", as_index=False).mean().c2 + tm.assert_series_equal(result1, result2) + + +def test_groupby_as_index_apply(): + # GH #4648 and #3417 + df = DataFrame( + { + "item_id": ["b", "b", "a", "c", "a", "b"], + "user_id": [1, 2, 1, 1, 3, 1], + "time": range(6), + } + ) + + g_as = df.groupby("user_id", as_index=True) + g_not_as = df.groupby("user_id", as_index=False) + + res_as = g_as.head(2).index + res_not_as = g_not_as.head(2).index + exp = Index([0, 1, 2, 4]) + tm.assert_index_equal(res_as, exp) + tm.assert_index_equal(res_not_as, exp) + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + res_as_apply = g_as.apply(lambda x: x.head(2)).index + with tm.assert_produces_warning(DeprecationWarning, match=msg): + res_not_as_apply = g_not_as.apply(lambda x: x.head(2)).index + + # apply doesn't maintain the original ordering + # changed in GH5610 as the as_index=False returns a MI here + exp_not_as_apply = MultiIndex.from_tuples([(0, 0), (0, 2), (1, 1), (2, 4)]) + tp = [(1, 0), (1, 2), (2, 1), (3, 4)] + exp_as_apply = MultiIndex.from_tuples(tp, names=["user_id", None]) + + tm.assert_index_equal(res_as_apply, exp_as_apply) + tm.assert_index_equal(res_not_as_apply, exp_not_as_apply) + + ind = Index(list("abcde")) + df = DataFrame([[1, 2], [2, 3], [1, 4], [1, 5], [2, 6]], index=ind) + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + res = df.groupby(0, as_index=False, group_keys=False).apply(lambda x: x).index + tm.assert_index_equal(res, ind) + + +def test_apply_concat_preserve_names(three_group): + grouped = three_group.groupby(["A", "B"]) + + def desc(group): + result = group.describe() + result.index.name = "stat" + return result + + def desc2(group): + result = group.describe() + result.index.name = "stat" + result = result[: len(group)] + # weirdo + return result + + def desc3(group): + result = group.describe() + + # names are different + result.index.name = f"stat_{len(group):d}" + + result = result[: len(group)] + # weirdo + return result + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = grouped.apply(desc) + assert result.index.names == ("A", "B", "stat") + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result2 = grouped.apply(desc2) + assert result2.index.names == ("A", "B", "stat") + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result3 = grouped.apply(desc3) + assert result3.index.names == ("A", "B", None) + + +def test_apply_series_to_frame(): + def f(piece): + with np.errstate(invalid="ignore"): + logged = np.log(piece) + return DataFrame( + {"value": piece, "demeaned": piece - piece.mean(), "logged": logged} + ) + + dr = bdate_range("1/1/2000", periods=100) + ts = Series(np.random.default_rng(2).standard_normal(100), index=dr) + + grouped = ts.groupby(lambda x: x.month, group_keys=False) + result = grouped.apply(f) + + assert isinstance(result, DataFrame) + assert not hasattr(result, "name") # GH49907 + tm.assert_index_equal(result.index, ts.index) + + +def test_apply_series_yield_constant(df): + result = df.groupby(["A", "B"])["C"].apply(len) + assert result.index.names[:2] == ("A", "B") + + +def test_apply_frame_yield_constant(df): + # GH13568 + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = df.groupby(["A", "B"]).apply(len) + assert isinstance(result, Series) + assert result.name is None + + result = df.groupby(["A", "B"])[["C", "D"]].apply(len) + assert isinstance(result, Series) + assert result.name is None + + +def test_apply_frame_to_series(df): + grouped = df.groupby(["A", "B"]) + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = grouped.apply(len) + expected = grouped.count()["C"] + tm.assert_index_equal(result.index, expected.index) + tm.assert_numpy_array_equal(result.values, expected.values) + + +def test_apply_frame_not_as_index_column_name(df): + # GH 35964 - path within _wrap_applied_output not hit by a test + grouped = df.groupby(["A", "B"], as_index=False) + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = grouped.apply(len) + expected = grouped.count().rename(columns={"C": np.nan}).drop(columns="D") + # TODO(GH#34306): Use assert_frame_equal when column name is not np.nan + tm.assert_index_equal(result.index, expected.index) + tm.assert_numpy_array_equal(result.values, expected.values) + + +def test_apply_frame_concat_series(): + def trans(group): + return group.groupby("B")["C"].sum().sort_values().iloc[:2] + + def trans2(group): + grouped = group.groupby(df.reindex(group.index)["B"]) + return grouped.sum().sort_values().iloc[:2] + + df = DataFrame( + { + "A": np.random.default_rng(2).integers(0, 5, 1000), + "B": np.random.default_rng(2).integers(0, 5, 1000), + "C": np.random.default_rng(2).standard_normal(1000), + } + ) + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = df.groupby("A").apply(trans) + exp = df.groupby("A")["C"].apply(trans2) + tm.assert_series_equal(result, exp, check_names=False) + assert result.name == "C" + + +def test_apply_transform(ts): + grouped = ts.groupby(lambda x: x.month, group_keys=False) + result = grouped.apply(lambda x: x * 2) + expected = grouped.transform(lambda x: x * 2) + tm.assert_series_equal(result, expected) + + +def test_apply_multikey_corner(tsframe): + grouped = tsframe.groupby([lambda x: x.year, lambda x: x.month]) + + def f(group): + return group.sort_values("A")[-5:] + + result = grouped.apply(f) + for key, group in grouped: + tm.assert_frame_equal(result.loc[key], f(group)) + + +@pytest.mark.parametrize("group_keys", [True, False]) +def test_apply_chunk_view(group_keys): + # Low level tinkering could be unsafe, make sure not + df = DataFrame({"key": [1, 1, 1, 2, 2, 2, 3, 3, 3], "value": range(9)}) + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = df.groupby("key", group_keys=group_keys).apply(lambda x: x.iloc[:2]) + expected = df.take([0, 1, 3, 4, 6, 7]) + if group_keys: + expected.index = MultiIndex.from_arrays( + [[1, 1, 2, 2, 3, 3], expected.index], names=["key", None] + ) + + tm.assert_frame_equal(result, expected) + + +def test_apply_no_name_column_conflict(): + df = DataFrame( + { + "name": [1, 1, 1, 1, 1, 1, 2, 2, 2, 2], + "name2": [0, 0, 0, 1, 1, 1, 0, 0, 1, 1], + "value": range(9, -1, -1), + } + ) + + # it works! #2605 + grouped = df.groupby(["name", "name2"]) + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + grouped.apply(lambda x: x.sort_values("value", inplace=True)) + + +def test_apply_typecast_fail(): + df = DataFrame( + { + "d": [1.0, 1.0, 1.0, 2.0, 2.0, 2.0], + "c": np.tile(["a", "b", "c"], 2), + "v": np.arange(1.0, 7.0), + } + ) + + def f(group): + v = group["v"] + group["v2"] = (v - v.min()) / (v.max() - v.min()) + return group + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = df.groupby("d", group_keys=False).apply(f) + + expected = df.copy() + expected["v2"] = np.tile([0.0, 0.5, 1], 2) + + tm.assert_frame_equal(result, expected) + + +def test_apply_multiindex_fail(): + index = MultiIndex.from_arrays([[0, 0, 0, 1, 1, 1], [1, 2, 3, 1, 2, 3]]) + df = DataFrame( + { + "d": [1.0, 1.0, 1.0, 2.0, 2.0, 2.0], + "c": np.tile(["a", "b", "c"], 2), + "v": np.arange(1.0, 7.0), + }, + index=index, + ) + + def f(group): + v = group["v"] + group["v2"] = (v - v.min()) / (v.max() - v.min()) + return group + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = df.groupby("d", group_keys=False).apply(f) + + expected = df.copy() + expected["v2"] = np.tile([0.0, 0.5, 1], 2) + + tm.assert_frame_equal(result, expected) + + +def test_apply_corner(tsframe): + result = tsframe.groupby(lambda x: x.year, group_keys=False).apply(lambda x: x * 2) + expected = tsframe * 2 + tm.assert_frame_equal(result, expected) + + +def test_apply_without_copy(): + # GH 5545 + # returning a non-copy in an applied function fails + + data = DataFrame( + { + "id_field": [100, 100, 200, 300], + "category": ["a", "b", "c", "c"], + "value": [1, 2, 3, 4], + } + ) + + def filt1(x): + if x.shape[0] == 1: + return x.copy() + else: + return x[x.category == "c"] + + def filt2(x): + if x.shape[0] == 1: + return x + else: + return x[x.category == "c"] + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + expected = data.groupby("id_field").apply(filt1) + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = data.groupby("id_field").apply(filt2) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("test_series", [True, False]) +def test_apply_with_duplicated_non_sorted_axis(test_series): + # GH 30667 + df = DataFrame( + [["x", "p"], ["x", "p"], ["x", "o"]], columns=["X", "Y"], index=[1, 2, 2] + ) + if test_series: + ser = df.set_index("Y")["X"] + result = ser.groupby(level=0, group_keys=False).apply(lambda x: x) + + # not expecting the order to remain the same for duplicated axis + result = result.sort_index() + expected = ser.sort_index() + tm.assert_series_equal(result, expected) + else: + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = df.groupby("Y", group_keys=False).apply(lambda x: x) + + # not expecting the order to remain the same for duplicated axis + result = result.sort_values("Y") + expected = df.sort_values("Y") + tm.assert_frame_equal(result, expected) + + +def test_apply_reindex_values(): + # GH: 26209 + # reindexing from a single column of a groupby object with duplicate indices caused + # a ValueError (cannot reindex from duplicate axis) in 0.24.2, the problem was + # solved in #30679 + values = [1, 2, 3, 4] + indices = [1, 1, 2, 2] + df = DataFrame({"group": ["Group1", "Group2"] * 2, "value": values}, index=indices) + expected = Series(values, index=indices, name="value") + + def reindex_helper(x): + return x.reindex(np.arange(x.index.min(), x.index.max() + 1)) + + # the following group by raised a ValueError + result = df.groupby("group", group_keys=False).value.apply(reindex_helper) + tm.assert_series_equal(expected, result) + + +def test_apply_corner_cases(): + # #535, can't use sliding iterator + + N = 1000 + labels = np.random.default_rng(2).integers(0, 100, size=N) + df = DataFrame( + { + "key": labels, + "value1": np.random.default_rng(2).standard_normal(N), + "value2": ["foo", "bar", "baz", "qux"] * (N // 4), + } + ) + + grouped = df.groupby("key", group_keys=False) + + def f(g): + g["value3"] = g["value1"] * 2 + return g + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = grouped.apply(f) + assert "value3" in result + + +def test_apply_numeric_coercion_when_datetime(): + # In the past, group-by/apply operations have been over-eager + # in converting dtypes to numeric, in the presence of datetime + # columns. Various GH issues were filed, the reproductions + # for which are here. + + # GH 15670 + df = DataFrame( + {"Number": [1, 2], "Date": ["2017-03-02"] * 2, "Str": ["foo", "inf"]} + ) + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + expected = df.groupby(["Number"]).apply(lambda x: x.iloc[0]) + df.Date = pd.to_datetime(df.Date) + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = df.groupby(["Number"]).apply(lambda x: x.iloc[0]) + tm.assert_series_equal(result["Str"], expected["Str"]) + + # GH 15421 + df = DataFrame( + {"A": [10, 20, 30], "B": ["foo", "3", "4"], "T": [pd.Timestamp("12:31:22")] * 3} + ) + + def get_B(g): + return g.iloc[0][["B"]] + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = df.groupby("A").apply(get_B)["B"] + expected = df.B + expected.index = df.A + tm.assert_series_equal(result, expected) + + # GH 14423 + def predictions(tool): + out = Series(index=["p1", "p2", "useTime"], dtype=object) + if "step1" in list(tool.State): + out["p1"] = str(tool[tool.State == "step1"].Machine.values[0]) + if "step2" in list(tool.State): + out["p2"] = str(tool[tool.State == "step2"].Machine.values[0]) + out["useTime"] = str(tool[tool.State == "step2"].oTime.values[0]) + return out + + df1 = DataFrame( + { + "Key": ["B", "B", "A", "A"], + "State": ["step1", "step2", "step1", "step2"], + "oTime": ["", "2016-09-19 05:24:33", "", "2016-09-19 23:59:04"], + "Machine": ["23", "36L", "36R", "36R"], + } + ) + df2 = df1.copy() + df2.oTime = pd.to_datetime(df2.oTime) + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + expected = df1.groupby("Key").apply(predictions).p1 + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = df2.groupby("Key").apply(predictions).p1 + tm.assert_series_equal(expected, result) + + +def test_apply_aggregating_timedelta_and_datetime(): + # Regression test for GH 15562 + # The following groupby caused ValueErrors and IndexErrors pre 0.20.0 + + df = DataFrame( + { + "clientid": ["A", "B", "C"], + "datetime": [np.datetime64("2017-02-01 00:00:00")] * 3, + } + ) + df["time_delta_zero"] = df.datetime - df.datetime + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = df.groupby("clientid").apply( + lambda ddf: Series( + {"clientid_age": ddf.time_delta_zero.min(), "date": ddf.datetime.min()} + ) + ) + expected = DataFrame( + { + "clientid": ["A", "B", "C"], + "clientid_age": [np.timedelta64(0, "D")] * 3, + "date": [np.datetime64("2017-02-01 00:00:00")] * 3, + } + ).set_index("clientid") + + tm.assert_frame_equal(result, expected) + + +def test_apply_groupby_datetimeindex(): + # GH 26182 + # groupby apply failed on dataframe with DatetimeIndex + + data = [["A", 10], ["B", 20], ["B", 30], ["C", 40], ["C", 50]] + df = DataFrame( + data, columns=["Name", "Value"], index=pd.date_range("2020-09-01", "2020-09-05") + ) + + result = df.groupby("Name").sum() + + expected = DataFrame({"Name": ["A", "B", "C"], "Value": [10, 50, 90]}) + expected.set_index("Name", inplace=True) + + tm.assert_frame_equal(result, expected) + + +def test_time_field_bug(): + # Test a fix for the following error related to GH issue 11324 When + # non-key fields in a group-by dataframe contained time-based fields + # that were not returned by the apply function, an exception would be + # raised. + + df = DataFrame({"a": 1, "b": [datetime.now() for nn in range(10)]}) + + def func_with_no_date(batch): + return Series({"c": 2}) + + def func_with_date(batch): + return Series({"b": datetime(2015, 1, 1), "c": 2}) + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + dfg_no_conversion = df.groupby(by=["a"]).apply(func_with_no_date) + dfg_no_conversion_expected = DataFrame({"c": 2}, index=[1]) + dfg_no_conversion_expected.index.name = "a" + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + dfg_conversion = df.groupby(by=["a"]).apply(func_with_date) + dfg_conversion_expected = DataFrame( + {"b": pd.Timestamp(2015, 1, 1).as_unit("ns"), "c": 2}, index=[1] + ) + dfg_conversion_expected.index.name = "a" + + tm.assert_frame_equal(dfg_no_conversion, dfg_no_conversion_expected) + tm.assert_frame_equal(dfg_conversion, dfg_conversion_expected) + + +def test_gb_apply_list_of_unequal_len_arrays(): + # GH1738 + df = DataFrame( + { + "group1": ["a", "a", "a", "b", "b", "b", "a", "a", "a", "b", "b", "b"], + "group2": ["c", "c", "d", "d", "d", "e", "c", "c", "d", "d", "d", "e"], + "weight": [1.1, 2, 3, 4, 5, 6, 2, 4, 6, 8, 1, 2], + "value": [7.1, 8, 9, 10, 11, 12, 8, 7, 6, 5, 4, 3], + } + ) + df = df.set_index(["group1", "group2"]) + df_grouped = df.groupby(level=["group1", "group2"], sort=True) + + def noddy(value, weight): + out = np.array(value * weight).repeat(3) + return out + + # the kernel function returns arrays of unequal length + # pandas sniffs the first one, sees it's an array and not + # a list, and assumed the rest are of equal length + # and so tries a vstack + + # don't die + df_grouped.apply(lambda x: noddy(x.value, x.weight)) + + +def test_groupby_apply_all_none(): + # Tests to make sure no errors if apply function returns all None + # values. Issue 9684. + test_df = DataFrame({"groups": [0, 0, 1, 1], "random_vars": [8, 7, 4, 5]}) + + def test_func(x): + pass + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = test_df.groupby("groups").apply(test_func) + expected = DataFrame() + tm.assert_frame_equal(result, expected) + + +def test_groupby_apply_none_first(): + # GH 12824. Tests if apply returns None first. + test_df1 = DataFrame({"groups": [1, 1, 1, 2], "vars": [0, 1, 2, 3]}) + test_df2 = DataFrame({"groups": [1, 2, 2, 2], "vars": [0, 1, 2, 3]}) + + def test_func(x): + if x.shape[0] < 2: + return None + return x.iloc[[0, -1]] + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result1 = test_df1.groupby("groups").apply(test_func) + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result2 = test_df2.groupby("groups").apply(test_func) + index1 = MultiIndex.from_arrays([[1, 1], [0, 2]], names=["groups", None]) + index2 = MultiIndex.from_arrays([[2, 2], [1, 3]], names=["groups", None]) + expected1 = DataFrame({"groups": [1, 1], "vars": [0, 2]}, index=index1) + expected2 = DataFrame({"groups": [2, 2], "vars": [1, 3]}, index=index2) + tm.assert_frame_equal(result1, expected1) + tm.assert_frame_equal(result2, expected2) + + +def test_groupby_apply_return_empty_chunk(): + # GH 22221: apply filter which returns some empty groups + df = DataFrame({"value": [0, 1], "group": ["filled", "empty"]}) + groups = df.groupby("group") + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = groups.apply(lambda group: group[group.value != 1]["value"]) + expected = Series( + [0], + name="value", + index=MultiIndex.from_product( + [["empty", "filled"], [0]], names=["group", None] + ).drop("empty"), + ) + tm.assert_series_equal(result, expected) + + +def test_apply_with_mixed_types(): + # gh-20949 + df = DataFrame({"A": "a a b".split(), "B": [1, 2, 3], "C": [4, 6, 5]}) + g = df.groupby("A", group_keys=False) + + result = g.transform(lambda x: x / x.sum()) + expected = DataFrame({"B": [1 / 3.0, 2 / 3.0, 1], "C": [0.4, 0.6, 1.0]}) + tm.assert_frame_equal(result, expected) + + result = g.apply(lambda x: x / x.sum()) + tm.assert_frame_equal(result, expected) + + +def test_func_returns_object(): + # GH 28652 + df = DataFrame({"a": [1, 2]}, index=Index([1, 2])) + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = df.groupby("a").apply(lambda g: g.index) + expected = Series([Index([1]), Index([2])], index=Index([1, 2], name="a")) + + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "group_column_dtlike", + [datetime.today(), datetime.today().date(), datetime.today().time()], +) +def test_apply_datetime_issue(group_column_dtlike, using_infer_string): + # GH-28247 + # groupby-apply throws an error if one of the columns in the DataFrame + # is a datetime object and the column labels are different from + # standard int values in range(len(num_columns)) + + df = DataFrame({"a": ["foo"], "b": [group_column_dtlike]}) + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = df.groupby("a").apply(lambda x: Series(["spam"], index=[42])) + + dtype = "string" if using_infer_string else "object" + expected = DataFrame(["spam"], Index(["foo"], dtype=dtype, name="a"), columns=[42]) + tm.assert_frame_equal(result, expected) + + +def test_apply_series_return_dataframe_groups(): + # GH 10078 + tdf = DataFrame( + { + "day": { + 0: pd.Timestamp("2015-02-24 00:00:00"), + 1: pd.Timestamp("2015-02-24 00:00:00"), + 2: pd.Timestamp("2015-02-24 00:00:00"), + 3: pd.Timestamp("2015-02-24 00:00:00"), + 4: pd.Timestamp("2015-02-24 00:00:00"), + }, + "userAgent": { + 0: "some UA string", + 1: "some UA string", + 2: "some UA string", + 3: "another UA string", + 4: "some UA string", + }, + "userId": { + 0: "17661101", + 1: "17661101", + 2: "17661101", + 3: "17661101", + 4: "17661101", + }, + } + ) + + def most_common_values(df): + return Series({c: s.value_counts().index[0] for c, s in df.items()}) + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = tdf.groupby("day").apply(most_common_values)["userId"] + expected = Series( + ["17661101"], index=pd.DatetimeIndex(["2015-02-24"], name="day"), name="userId" + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("category", [False, True]) +def test_apply_multi_level_name(category): + # https://github.com/pandas-dev/pandas/issues/31068 + b = [1, 2] * 5 + if category: + b = pd.Categorical(b, categories=[1, 2, 3]) + expected_index = pd.CategoricalIndex([1, 2, 3], categories=[1, 2, 3], name="B") + expected_values = [20, 25, 0] + else: + expected_index = Index([1, 2], name="B") + expected_values = [20, 25] + expected = DataFrame( + {"C": expected_values, "D": expected_values}, index=expected_index + ) + + df = DataFrame( + {"A": np.arange(10), "B": b, "C": list(range(10)), "D": list(range(10))} + ).set_index(["A", "B"]) + result = df.groupby("B", observed=False).apply(lambda x: x.sum()) + tm.assert_frame_equal(result, expected) + assert df.index.names == ["A", "B"] + + +def test_groupby_apply_datetime_result_dtypes(using_infer_string): + # GH 14849 + data = DataFrame.from_records( + [ + (pd.Timestamp(2016, 1, 1), "red", "dark", 1, "8"), + (pd.Timestamp(2015, 1, 1), "green", "stormy", 2, "9"), + (pd.Timestamp(2014, 1, 1), "blue", "bright", 3, "10"), + (pd.Timestamp(2013, 1, 1), "blue", "calm", 4, "potato"), + ], + columns=["observation", "color", "mood", "intensity", "score"], + ) + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = data.groupby("color").apply(lambda g: g.iloc[0]).dtypes + dtype = "string" if using_infer_string else object + expected = Series( + [np.dtype("datetime64[ns]"), dtype, dtype, np.int64, dtype], + index=["observation", "color", "mood", "intensity", "score"], + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "index", + [ + pd.CategoricalIndex(list("abc")), + pd.interval_range(0, 3), + pd.period_range("2020", periods=3, freq="D"), + MultiIndex.from_tuples([("a", 0), ("a", 1), ("b", 0)]), + ], +) +def test_apply_index_has_complex_internals(index): + # GH 31248 + df = DataFrame({"group": [1, 1, 2], "value": [0, 1, 0]}, index=index) + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = df.groupby("group", group_keys=False).apply(lambda x: x) + tm.assert_frame_equal(result, df) + + +@pytest.mark.parametrize( + "function, expected_values", + [ + (lambda x: x.index.to_list(), [[0, 1], [2, 3]]), + (lambda x: set(x.index.to_list()), [{0, 1}, {2, 3}]), + (lambda x: tuple(x.index.to_list()), [(0, 1), (2, 3)]), + ( + lambda x: dict(enumerate(x.index.to_list())), + [{0: 0, 1: 1}, {0: 2, 1: 3}], + ), + ( + lambda x: [{n: i} for (n, i) in enumerate(x.index.to_list())], + [[{0: 0}, {1: 1}], [{0: 2}, {1: 3}]], + ), + ], +) +def test_apply_function_returns_non_pandas_non_scalar(function, expected_values): + # GH 31441 + df = DataFrame(["A", "A", "B", "B"], columns=["groups"]) + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = df.groupby("groups").apply(function) + expected = Series(expected_values, index=Index(["A", "B"], name="groups")) + tm.assert_series_equal(result, expected) + + +def test_apply_function_returns_numpy_array(): + # GH 31605 + def fct(group): + return group["B"].values.flatten() + + df = DataFrame({"A": ["a", "a", "b", "none"], "B": [1, 2, 3, np.nan]}) + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = df.groupby("A").apply(fct) + expected = Series( + [[1.0, 2.0], [3.0], [np.nan]], index=Index(["a", "b", "none"], name="A") + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("function", [lambda gr: gr.index, lambda gr: gr.index + 1 - 1]) +def test_apply_function_index_return(function): + # GH: 22541 + df = DataFrame([1, 2, 2, 2, 1, 2, 3, 1, 3, 1], columns=["id"]) + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = df.groupby("id").apply(function) + expected = Series( + [Index([0, 4, 7, 9]), Index([1, 2, 3, 5]), Index([6, 8])], + index=Index([1, 2, 3], name="id"), + ) + tm.assert_series_equal(result, expected) + + +def test_apply_function_with_indexing_return_column(): + # GH#7002, GH#41480, GH#49256 + df = DataFrame( + { + "foo1": ["one", "two", "two", "three", "one", "two"], + "foo2": [1, 2, 4, 4, 5, 6], + } + ) + result = df.groupby("foo1", as_index=False).apply(lambda x: x.mean()) + expected = DataFrame( + { + "foo1": ["one", "three", "two"], + "foo2": [3.0, 4.0, 4.0], + } + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "udf", + [(lambda x: x.copy()), (lambda x: x.copy().rename(lambda y: y + 1))], +) +@pytest.mark.parametrize("group_keys", [True, False]) +def test_apply_result_type(group_keys, udf): + # https://github.com/pandas-dev/pandas/issues/34809 + # We'd like to control whether the group keys end up in the index + # regardless of whether the UDF happens to be a transform. + df = DataFrame({"A": ["a", "b"], "B": [1, 2]}) + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + df_result = df.groupby("A", group_keys=group_keys).apply(udf) + series_result = df.B.groupby(df.A, group_keys=group_keys).apply(udf) + + if group_keys: + assert df_result.index.nlevels == 2 + assert series_result.index.nlevels == 2 + else: + assert df_result.index.nlevels == 1 + assert series_result.index.nlevels == 1 + + +def test_result_order_group_keys_false(): + # GH 34998 + # apply result order should not depend on whether index is the same or just equal + df = DataFrame({"A": [2, 1, 2], "B": [1, 2, 3]}) + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = df.groupby("A", group_keys=False).apply(lambda x: x) + with tm.assert_produces_warning(DeprecationWarning, match=msg): + expected = df.groupby("A", group_keys=False).apply(lambda x: x.copy()) + tm.assert_frame_equal(result, expected) + + +def test_apply_with_timezones_aware(): + # GH: 27212 + dates = ["2001-01-01"] * 2 + ["2001-01-02"] * 2 + ["2001-01-03"] * 2 + index_no_tz = pd.DatetimeIndex(dates) + index_tz = pd.DatetimeIndex(dates, tz="UTC") + df1 = DataFrame({"x": list(range(2)) * 3, "y": range(6), "t": index_no_tz}) + df2 = DataFrame({"x": list(range(2)) * 3, "y": range(6), "t": index_tz}) + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result1 = df1.groupby("x", group_keys=False).apply( + lambda df: df[["x", "y"]].copy() + ) + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result2 = df2.groupby("x", group_keys=False).apply( + lambda df: df[["x", "y"]].copy() + ) + + tm.assert_frame_equal(result1, result2) + + +def test_apply_is_unchanged_when_other_methods_are_called_first(reduction_func): + # GH #34656 + # GH #34271 + df = DataFrame( + { + "a": [99, 99, 99, 88, 88, 88], + "b": [1, 2, 3, 4, 5, 6], + "c": [10, 20, 30, 40, 50, 60], + } + ) + + expected = DataFrame( + {"b": [15, 6], "c": [150, 60]}, + index=Index([88, 99], name="a"), + ) + + # Check output when no other methods are called before .apply() + grp = df.groupby(by="a") + msg = "The behavior of DataFrame.sum with axis=None is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg, check_stacklevel=False): + result = grp.apply(sum, include_groups=False) + tm.assert_frame_equal(result, expected) + + # Check output when another method is called before .apply() + grp = df.groupby(by="a") + args = get_groupby_method_args(reduction_func, df) + _ = getattr(grp, reduction_func)(*args) + with tm.assert_produces_warning(FutureWarning, match=msg, check_stacklevel=False): + result = grp.apply(sum, include_groups=False) + tm.assert_frame_equal(result, expected) + + +def test_apply_with_date_in_multiindex_does_not_convert_to_timestamp(): + # GH 29617 + + df = DataFrame( + { + "A": ["a", "a", "a", "b"], + "B": [ + date(2020, 1, 10), + date(2020, 1, 10), + date(2020, 2, 10), + date(2020, 2, 10), + ], + "C": [1, 2, 3, 4], + }, + index=Index([100, 101, 102, 103], name="idx"), + ) + + grp = df.groupby(["A", "B"]) + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = grp.apply(lambda x: x.head(1)) + + expected = df.iloc[[0, 2, 3]] + expected = expected.reset_index() + expected.index = MultiIndex.from_frame(expected[["A", "B", "idx"]]) + expected = expected.drop(columns="idx") + + tm.assert_frame_equal(result, expected) + for val in result.index.levels[1]: + assert type(val) is date + + +def test_apply_by_cols_equals_apply_by_rows_transposed(): + # GH 16646 + # Operating on the columns, or transposing and operating on the rows + # should give the same result. There was previously a bug where the + # by_rows operation would work fine, but by_cols would throw a ValueError + + df = DataFrame( + np.random.default_rng(2).random([6, 4]), + columns=MultiIndex.from_product([["A", "B"], [1, 2]]), + ) + + msg = "The 'axis' keyword in DataFrame.groupby is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + gb = df.T.groupby(axis=0, level=0) + by_rows = gb.apply(lambda x: x.droplevel(axis=0, level=0)) + + msg = "DataFrame.groupby with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + gb2 = df.groupby(axis=1, level=0) + by_cols = gb2.apply(lambda x: x.droplevel(axis=1, level=0)) + + tm.assert_frame_equal(by_cols, by_rows.T) + tm.assert_frame_equal(by_cols, df) + + +@pytest.mark.parametrize("dropna", [True, False]) +def test_apply_dropna_with_indexed_same(dropna): + # GH 38227 + # GH#43205 + df = DataFrame( + { + "col": [1, 2, 3, 4, 5], + "group": ["a", np.nan, np.nan, "b", "b"], + }, + index=list("xxyxz"), + ) + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = df.groupby("group", dropna=dropna, group_keys=False).apply(lambda x: x) + expected = df.dropna() if dropna else df.iloc[[0, 3, 1, 2, 4]] + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "as_index, expected", + [ + [ + False, + DataFrame( + [[1, 1, 1], [2, 2, 1]], columns=Index(["a", "b", None], dtype=object) + ), + ], + [ + True, + Series( + [1, 1], index=MultiIndex.from_tuples([(1, 1), (2, 2)], names=["a", "b"]) + ), + ], + ], +) +def test_apply_as_index_constant_lambda(as_index, expected): + # GH 13217 + df = DataFrame({"a": [1, 1, 2, 2], "b": [1, 1, 2, 2], "c": [1, 1, 1, 1]}) + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = df.groupby(["a", "b"], as_index=as_index).apply(lambda x: 1) + tm.assert_equal(result, expected) + + +def test_sort_index_groups(): + # GH 20420 + df = DataFrame( + {"A": [1, 2, 3, 4, 5], "B": [6, 7, 8, 9, 0], "C": [1, 1, 1, 2, 2]}, + index=range(5), + ) + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = df.groupby("C").apply(lambda x: x.A.sort_index()) + expected = Series( + range(1, 6), + index=MultiIndex.from_tuples( + [(1, 0), (1, 1), (1, 2), (2, 3), (2, 4)], names=["C", None] + ), + name="A", + ) + tm.assert_series_equal(result, expected) + + +def test_positional_slice_groups_datetimelike(): + # GH 21651 + expected = DataFrame( + { + "date": pd.date_range("2010-01-01", freq="12h", periods=5), + "vals": range(5), + "let": list("abcde"), + } + ) + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = expected.groupby( + [expected.let, expected.date.dt.date], group_keys=False + ).apply(lambda x: x.iloc[0:]) + tm.assert_frame_equal(result, expected) + + +def test_groupby_apply_shape_cache_safety(): + # GH#42702 this fails if we cache_readonly Block.shape + df = DataFrame({"A": ["a", "a", "b"], "B": [1, 2, 3], "C": [4, 6, 5]}) + gb = df.groupby("A") + result = gb[["B", "C"]].apply(lambda x: x.astype(float).max() - x.min()) + + expected = DataFrame( + {"B": [1.0, 0.0], "C": [2.0, 0.0]}, index=Index(["a", "b"], name="A") + ) + tm.assert_frame_equal(result, expected) + + +def test_groupby_apply_to_series_name(): + # GH52444 + df = DataFrame.from_dict( + { + "a": ["a", "b", "a", "b"], + "b1": ["aa", "ac", "ac", "ad"], + "b2": ["aa", "aa", "aa", "ac"], + } + ) + grp = df.groupby("a")[["b1", "b2"]] + result = grp.apply(lambda x: x.unstack().value_counts()) + + expected_idx = MultiIndex.from_arrays( + arrays=[["a", "a", "b", "b", "b"], ["aa", "ac", "ac", "ad", "aa"]], + names=["a", None], + ) + expected = Series([3, 1, 2, 1, 1], index=expected_idx, name="count") + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("dropna", [True, False]) +def test_apply_na(dropna): + # GH#28984 + df = DataFrame( + {"grp": [1, 1, 2, 2], "y": [1, 0, 2, 5], "z": [1, 2, np.nan, np.nan]} + ) + dfgrp = df.groupby("grp", dropna=dropna) + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = dfgrp.apply(lambda grp_df: grp_df.nlargest(1, "z")) + with tm.assert_produces_warning(DeprecationWarning, match=msg): + expected = dfgrp.apply(lambda x: x.sort_values("z", ascending=False).head(1)) + tm.assert_frame_equal(result, expected) + + +def test_apply_empty_string_nan_coerce_bug(): + # GH#24903 + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = ( + DataFrame( + { + "a": [1, 1, 2, 2], + "b": ["", "", "", ""], + "c": pd.to_datetime([1, 2, 3, 4], unit="s"), + } + ) + .groupby(["a", "b"]) + .apply(lambda df: df.iloc[-1]) + ) + expected = DataFrame( + [[1, "", pd.to_datetime(2, unit="s")], [2, "", pd.to_datetime(4, unit="s")]], + columns=["a", "b", "c"], + index=MultiIndex.from_tuples([(1, ""), (2, "")], names=["a", "b"]), + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("index_values", [[1, 2, 3], [1.0, 2.0, 3.0]]) +def test_apply_index_key_error_bug(index_values): + # GH 44310 + result = DataFrame( + { + "a": ["aa", "a2", "a3"], + "b": [1, 2, 3], + }, + index=Index(index_values), + ) + expected = DataFrame( + { + "b_mean": [2.0, 3.0, 1.0], + }, + index=Index(["a2", "a3", "aa"], name="a"), + ) + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = result.groupby("a").apply( + lambda df: Series([df["b"].mean()], index=["b_mean"]) + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "arg,idx", + [ + [ + [ + 1, + 2, + 3, + ], + [ + 0.1, + 0.3, + 0.2, + ], + ], + [ + [ + 1, + 2, + 3, + ], + [ + 0.1, + 0.2, + 0.3, + ], + ], + [ + [ + 1, + 4, + 3, + ], + [ + 0.1, + 0.4, + 0.2, + ], + ], + ], +) +def test_apply_nonmonotonic_float_index(arg, idx): + # GH 34455 + expected = DataFrame({"col": arg}, index=idx) + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = expected.groupby("col", group_keys=False).apply(lambda x: x) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("args, kwargs", [([True], {}), ([], {"numeric_only": True})]) +def test_apply_str_with_args(df, args, kwargs): + # GH#46479 + gb = df.groupby("A") + result = gb.apply("sum", *args, **kwargs) + expected = gb.sum(numeric_only=True) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("name", ["some_name", None]) +def test_result_name_when_one_group(name): + # GH 46369 + ser = Series([1, 2], name=name) + result = ser.groupby(["a", "a"], group_keys=False).apply(lambda x: x) + expected = Series([1, 2], name=name) + + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "method, op", + [ + ("apply", lambda gb: gb.values[-1]), + ("apply", lambda gb: gb["b"].iloc[0]), + ("agg", "skew"), + ("agg", "prod"), + ("agg", "sum"), + ], +) +def test_empty_df(method, op): + # GH 47985 + empty_df = DataFrame({"a": [], "b": []}) + gb = empty_df.groupby("a", group_keys=True) + group = getattr(gb, "b") + + result = getattr(group, method)(op) + expected = Series( + [], name="b", dtype="float64", index=Index([], dtype="float64", name="a") + ) + + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("include_groups", [True, False]) +def test_include_groups(include_groups): + # GH#7155 + df = DataFrame({"a": [1, 1, 2], "b": [3, 4, 5]}) + gb = df.groupby("a") + warn = DeprecationWarning if include_groups else None + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(warn, match=msg): + result = gb.apply(lambda x: x.sum(), include_groups=include_groups) + expected = DataFrame({"a": [2, 2], "b": [7, 5]}, index=Index([1, 2], name="a")) + if not include_groups: + expected = expected[["b"]] + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("f", [max, min, sum]) +@pytest.mark.parametrize("keys", ["jim", ["jim", "joe"]]) # Single key # Multi-key +def test_builtins_apply(keys, f): + # see gh-8155 + rs = np.random.default_rng(2) + df = DataFrame(rs.integers(1, 7, (10, 2)), columns=["jim", "joe"]) + df["jolie"] = rs.standard_normal(10) + + gb = df.groupby(keys) + + fname = f.__name__ + + warn = None if f is not sum else FutureWarning + msg = "The behavior of DataFrame.sum with axis=None is deprecated" + with tm.assert_produces_warning( + warn, match=msg, check_stacklevel=False, raise_on_extra_warnings=False + ): + # Also warns on deprecation GH#53425 + result = gb.apply(f) + ngroups = len(df.drop_duplicates(subset=keys)) + + assert_msg = f"invalid frame shape: {result.shape} (expected ({ngroups}, 3))" + assert result.shape == (ngroups, 3), assert_msg + + npfunc = lambda x: getattr(np, fname)(x, axis=0) # numpy's equivalent function + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + expected = gb.apply(npfunc) + tm.assert_frame_equal(result, expected) + + with tm.assert_produces_warning(DeprecationWarning, match=msg): + expected2 = gb.apply(lambda x: npfunc(x)) + tm.assert_frame_equal(result, expected2) + + if f != sum: + expected = gb.agg(fname).reset_index() + expected.set_index(keys, inplace=True, drop=False) + tm.assert_frame_equal(result, expected, check_dtype=False) + + tm.assert_series_equal(getattr(result, fname)(axis=0), getattr(df, fname)(axis=0)) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_apply_mutate.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_apply_mutate.py new file mode 100644 index 0000000000000000000000000000000000000000..cfd1a4bca9d914d736a42d7665bc03fa6412b1a9 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_apply_mutate.py @@ -0,0 +1,163 @@ +import numpy as np + +import pandas as pd +import pandas._testing as tm + + +def test_group_by_copy(): + # GH#44803 + df = pd.DataFrame( + { + "name": ["Alice", "Bob", "Carl"], + "age": [20, 21, 20], + } + ).set_index("name") + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + grp_by_same_value = df.groupby(["age"], group_keys=False).apply( + lambda group: group + ) + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + grp_by_copy = df.groupby(["age"], group_keys=False).apply( + lambda group: group.copy() + ) + tm.assert_frame_equal(grp_by_same_value, grp_by_copy) + + +def test_mutate_groups(): + # GH3380 + + df = pd.DataFrame( + { + "cat1": ["a"] * 8 + ["b"] * 6, + "cat2": ["c"] * 2 + + ["d"] * 2 + + ["e"] * 2 + + ["f"] * 2 + + ["c"] * 2 + + ["d"] * 2 + + ["e"] * 2, + "cat3": [f"g{x}" for x in range(1, 15)], + "val": np.random.default_rng(2).integers(100, size=14), + } + ) + + def f_copy(x): + x = x.copy() + x["rank"] = x.val.rank(method="min") + return x.groupby("cat2")["rank"].min() + + def f_no_copy(x): + x["rank"] = x.val.rank(method="min") + return x.groupby("cat2")["rank"].min() + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + grpby_copy = df.groupby("cat1").apply(f_copy) + with tm.assert_produces_warning(DeprecationWarning, match=msg): + grpby_no_copy = df.groupby("cat1").apply(f_no_copy) + tm.assert_series_equal(grpby_copy, grpby_no_copy) + + +def test_no_mutate_but_looks_like(): + # GH 8467 + # first show's mutation indicator + # second does not, but should yield the same results + df = pd.DataFrame({"key": [1, 1, 1, 2, 2, 2, 3, 3, 3], "value": range(9)}) + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result1 = df.groupby("key", group_keys=True).apply(lambda x: x[:].key) + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result2 = df.groupby("key", group_keys=True).apply(lambda x: x.key) + tm.assert_series_equal(result1, result2) + + +def test_apply_function_with_indexing(warn_copy_on_write): + # GH: 33058 + df = pd.DataFrame( + {"col1": ["A", "A", "A", "B", "B", "B"], "col2": [1, 2, 3, 4, 5, 6]} + ) + + def fn(x): + x.loc[x.index[-1], "col2"] = 0 + return x.col2 + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning( + DeprecationWarning, match=msg, raise_on_extra_warnings=not warn_copy_on_write + ): + result = df.groupby(["col1"], as_index=False).apply(fn) + expected = pd.Series( + [1, 2, 0, 4, 5, 0], + index=pd.MultiIndex.from_tuples( + [(0, 0), (0, 1), (0, 2), (1, 3), (1, 4), (1, 5)] + ), + name="col2", + ) + tm.assert_series_equal(result, expected) + + +def test_apply_mutate_columns_multiindex(): + # GH 12652 + df = pd.DataFrame( + { + ("C", "julian"): [1, 2, 3], + ("B", "geoffrey"): [1, 2, 3], + ("A", "julian"): [1, 2, 3], + ("B", "julian"): [1, 2, 3], + ("A", "geoffrey"): [1, 2, 3], + ("C", "geoffrey"): [1, 2, 3], + }, + columns=pd.MultiIndex.from_tuples( + [ + ("A", "julian"), + ("A", "geoffrey"), + ("B", "julian"), + ("B", "geoffrey"), + ("C", "julian"), + ("C", "geoffrey"), + ] + ), + ) + + def add_column(grouped): + name = grouped.columns[0][1] + grouped["sum", name] = grouped.sum(axis=1) + return grouped + + msg = "DataFrame.groupby with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + gb = df.groupby(level=1, axis=1) + result = gb.apply(add_column) + expected = pd.DataFrame( + [ + [1, 1, 1, 3, 1, 1, 1, 3], + [2, 2, 2, 6, 2, 2, 2, 6], + [ + 3, + 3, + 3, + 9, + 3, + 3, + 3, + 9, + ], + ], + columns=pd.MultiIndex.from_tuples( + [ + ("geoffrey", "A", "geoffrey"), + ("geoffrey", "B", "geoffrey"), + ("geoffrey", "C", "geoffrey"), + ("geoffrey", "sum", "geoffrey"), + ("julian", "A", "julian"), + ("julian", "B", "julian"), + ("julian", "C", "julian"), + ("julian", "sum", "julian"), + ] + ), + ) + tm.assert_frame_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_bin_groupby.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_bin_groupby.py new file mode 100644 index 0000000000000000000000000000000000000000..49b2e621b7adc97947ec9d6c376a9d0f10e672fb --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_bin_groupby.py @@ -0,0 +1,65 @@ +import numpy as np +import pytest + +from pandas._libs import lib +import pandas.util._test_decorators as td + +import pandas as pd +import pandas._testing as tm + + +def assert_block_lengths(x): + assert len(x) == len(x._mgr.blocks[0].mgr_locs) + return 0 + + +def cumsum_max(x): + x.cumsum().max() + return 0 + + +@pytest.mark.parametrize( + "func", + [ + cumsum_max, + pytest.param(assert_block_lengths, marks=td.skip_array_manager_invalid_test), + ], +) +def test_mgr_locs_updated(func): + # https://github.com/pandas-dev/pandas/issues/31802 + # Some operations may require creating new blocks, which requires + # valid mgr_locs + df = pd.DataFrame({"A": ["a", "a", "a"], "B": ["a", "b", "b"], "C": [1, 1, 1]}) + result = df.groupby(["A", "B"]).agg(func) + expected = pd.DataFrame( + {"C": [0, 0]}, + index=pd.MultiIndex.from_product([["a"], ["a", "b"]], names=["A", "B"]), + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "binner,closed,expected", + [ + ( + np.array([0, 3, 6, 9], dtype=np.int64), + "left", + np.array([2, 5, 6], dtype=np.int64), + ), + ( + np.array([0, 3, 6, 9], dtype=np.int64), + "right", + np.array([3, 6, 6], dtype=np.int64), + ), + (np.array([0, 3, 6], dtype=np.int64), "left", np.array([2, 5], dtype=np.int64)), + ( + np.array([0, 3, 6], dtype=np.int64), + "right", + np.array([3, 6], dtype=np.int64), + ), + ], +) +def test_generate_bins(binner, closed, expected): + values = np.array([1, 2, 3, 4, 5, 6], dtype=np.int64) + result = lib.generate_bins_dt64(values, binner, closed=closed) + tm.assert_numpy_array_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_categorical.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_categorical.py new file mode 100644 index 0000000000000000000000000000000000000000..f60ff65536f20458220a763b946198842d9bf07e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_categorical.py @@ -0,0 +1,2169 @@ +from datetime import datetime + +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + Categorical, + CategoricalIndex, + DataFrame, + Index, + MultiIndex, + Series, + qcut, +) +import pandas._testing as tm +from pandas.api.typing import SeriesGroupBy +from pandas.tests.groupby import get_groupby_method_args + + +def cartesian_product_for_groupers(result, args, names, fill_value=np.nan): + """Reindex to a cartesian production for the groupers, + preserving the nature (Categorical) of each grouper + """ + + def f(a): + if isinstance(a, (CategoricalIndex, Categorical)): + categories = a.categories + a = Categorical.from_codes( + np.arange(len(categories)), categories=categories, ordered=a.ordered + ) + return a + + index = MultiIndex.from_product(map(f, args), names=names) + return result.reindex(index, fill_value=fill_value).sort_index() + + +_results_for_groupbys_with_missing_categories = { + # This maps the builtin groupby functions to their expected outputs for + # missing categories when they are called on a categorical grouper with + # observed=False. Some functions are expected to return NaN, some zero. + # These expected values can be used across several tests (i.e. they are + # the same for SeriesGroupBy and DataFrameGroupBy) but they should only be + # hardcoded in one place. + "all": np.nan, + "any": np.nan, + "count": 0, + "corrwith": np.nan, + "first": np.nan, + "idxmax": np.nan, + "idxmin": np.nan, + "last": np.nan, + "max": np.nan, + "mean": np.nan, + "median": np.nan, + "min": np.nan, + "nth": np.nan, + "nunique": 0, + "prod": np.nan, + "quantile": np.nan, + "sem": np.nan, + "size": 0, + "skew": np.nan, + "std": np.nan, + "sum": 0, + "var": np.nan, +} + + +def test_apply_use_categorical_name(df): + cats = qcut(df.C, 4) + + def get_stats(group): + return { + "min": group.min(), + "max": group.max(), + "count": group.count(), + "mean": group.mean(), + } + + result = df.groupby(cats, observed=False).D.apply(get_stats) + assert result.index.names[0] == "C" + + +def test_basic(using_infer_string): # TODO: split this test + cats = Categorical( + ["a", "a", "a", "b", "b", "b", "c", "c", "c"], + categories=["a", "b", "c", "d"], + ordered=True, + ) + data = DataFrame({"a": [1, 1, 1, 2, 2, 2, 3, 4, 5], "b": cats}) + + exp_index = CategoricalIndex(list("abcd"), name="b", ordered=True) + expected = DataFrame({"a": [1, 2, 4, np.nan]}, index=exp_index) + result = data.groupby("b", observed=False).mean() + tm.assert_frame_equal(result, expected) + + cat1 = Categorical(["a", "a", "b", "b"], categories=["a", "b", "z"], ordered=True) + cat2 = Categorical(["c", "d", "c", "d"], categories=["c", "d", "y"], ordered=True) + df = DataFrame({"A": cat1, "B": cat2, "values": [1, 2, 3, 4]}) + + # single grouper + gb = df.groupby("A", observed=False) + exp_idx = CategoricalIndex(["a", "b", "z"], name="A", ordered=True) + expected = DataFrame({"values": Series([3, 7, 0], index=exp_idx)}) + result = gb.sum(numeric_only=True) + tm.assert_frame_equal(result, expected) + + # GH 8623 + x = DataFrame( + [[1, "John P. Doe"], [2, "Jane Dove"], [1, "John P. Doe"]], + columns=["person_id", "person_name"], + ) + x["person_name"] = Categorical(x.person_name) + + g = x.groupby(["person_id"], observed=False) + result = g.transform(lambda x: x) + tm.assert_frame_equal(result, x[["person_name"]]) + + result = x.drop_duplicates("person_name") + expected = x.iloc[[0, 1]] + tm.assert_frame_equal(result, expected) + + def f(x): + return x.drop_duplicates("person_name").iloc[0] + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = g.apply(f) + expected = x.iloc[[0, 1]].copy() + expected.index = Index([1, 2], name="person_id") + dtype = "string[pyarrow_numpy]" if using_infer_string else object + expected["person_name"] = expected["person_name"].astype(dtype) + tm.assert_frame_equal(result, expected) + + # GH 9921 + # Monotonic + df = DataFrame({"a": [5, 15, 25]}) + c = pd.cut(df.a, bins=[0, 10, 20, 30, 40]) + + msg = "using SeriesGroupBy.sum" + with tm.assert_produces_warning(FutureWarning, match=msg): + # GH#53425 + result = df.a.groupby(c, observed=False).transform(sum) + tm.assert_series_equal(result, df["a"]) + + tm.assert_series_equal( + df.a.groupby(c, observed=False).transform(lambda xs: np.sum(xs)), df["a"] + ) + msg = "using DataFrameGroupBy.sum" + with tm.assert_produces_warning(FutureWarning, match=msg): + # GH#53425 + result = df.groupby(c, observed=False).transform(sum) + expected = df[["a"]] + tm.assert_frame_equal(result, expected) + + gbc = df.groupby(c, observed=False) + result = gbc.transform(lambda xs: np.max(xs, axis=0)) + tm.assert_frame_equal(result, df[["a"]]) + + result2 = gbc.transform(lambda xs: np.max(xs, axis=0)) + msg = "using DataFrameGroupBy.max" + with tm.assert_produces_warning(FutureWarning, match=msg): + # GH#53425 + result3 = gbc.transform(max) + result4 = gbc.transform(np.maximum.reduce) + result5 = gbc.transform(lambda xs: np.maximum.reduce(xs)) + tm.assert_frame_equal(result2, df[["a"]], check_dtype=False) + tm.assert_frame_equal(result3, df[["a"]], check_dtype=False) + tm.assert_frame_equal(result4, df[["a"]]) + tm.assert_frame_equal(result5, df[["a"]]) + + # Filter + tm.assert_series_equal(df.a.groupby(c, observed=False).filter(np.all), df["a"]) + tm.assert_frame_equal(df.groupby(c, observed=False).filter(np.all), df) + + # Non-monotonic + df = DataFrame({"a": [5, 15, 25, -5]}) + c = pd.cut(df.a, bins=[-10, 0, 10, 20, 30, 40]) + + msg = "using SeriesGroupBy.sum" + with tm.assert_produces_warning(FutureWarning, match=msg): + # GH#53425 + result = df.a.groupby(c, observed=False).transform(sum) + tm.assert_series_equal(result, df["a"]) + + tm.assert_series_equal( + df.a.groupby(c, observed=False).transform(lambda xs: np.sum(xs)), df["a"] + ) + msg = "using DataFrameGroupBy.sum" + with tm.assert_produces_warning(FutureWarning, match=msg): + # GH#53425 + result = df.groupby(c, observed=False).transform(sum) + expected = df[["a"]] + tm.assert_frame_equal(result, expected) + + tm.assert_frame_equal( + df.groupby(c, observed=False).transform(lambda xs: np.sum(xs)), df[["a"]] + ) + + # GH 9603 + df = DataFrame({"a": [1, 0, 0, 0]}) + c = pd.cut(df.a, [0, 1, 2, 3, 4], labels=Categorical(list("abcd"))) + result = df.groupby(c, observed=False).apply(len) + + exp_index = CategoricalIndex(c.values.categories, ordered=c.values.ordered) + expected = Series([1, 0, 0, 0], index=exp_index) + expected.index.name = "a" + tm.assert_series_equal(result, expected) + + # more basic + levels = ["foo", "bar", "baz", "qux"] + codes = np.random.default_rng(2).integers(0, 4, size=100) + + cats = Categorical.from_codes(codes, levels, ordered=True) + + data = DataFrame(np.random.default_rng(2).standard_normal((100, 4))) + + result = data.groupby(cats, observed=False).mean() + + expected = data.groupby(np.asarray(cats), observed=False).mean() + exp_idx = CategoricalIndex(levels, categories=cats.categories, ordered=True) + expected = expected.reindex(exp_idx) + + tm.assert_frame_equal(result, expected) + + grouped = data.groupby(cats, observed=False) + desc_result = grouped.describe() + + idx = cats.codes.argsort() + ord_labels = np.asarray(cats).take(idx) + ord_data = data.take(idx) + + exp_cats = Categorical( + ord_labels, ordered=True, categories=["foo", "bar", "baz", "qux"] + ) + expected = ord_data.groupby(exp_cats, sort=False, observed=False).describe() + tm.assert_frame_equal(desc_result, expected) + + # GH 10460 + expc = Categorical.from_codes(np.arange(4).repeat(8), levels, ordered=True) + exp = CategoricalIndex(expc) + tm.assert_index_equal( + (desc_result.stack(future_stack=True).index.get_level_values(0)), exp + ) + exp = Index(["count", "mean", "std", "min", "25%", "50%", "75%", "max"] * 4) + tm.assert_index_equal( + (desc_result.stack(future_stack=True).index.get_level_values(1)), exp + ) + + +def test_level_get_group(observed): + # GH15155 + df = DataFrame( + data=np.arange(2, 22, 2), + index=MultiIndex( + levels=[CategoricalIndex(["a", "b"]), range(10)], + codes=[[0] * 5 + [1] * 5, range(10)], + names=["Index1", "Index2"], + ), + ) + g = df.groupby(level=["Index1"], observed=observed) + + # expected should equal test.loc[["a"]] + # GH15166 + expected = DataFrame( + data=np.arange(2, 12, 2), + index=MultiIndex( + levels=[CategoricalIndex(["a", "b"]), range(5)], + codes=[[0] * 5, range(5)], + names=["Index1", "Index2"], + ), + ) + msg = "you will need to pass a length-1 tuple" + with tm.assert_produces_warning(FutureWarning, match=msg): + # GH#25971 - warn when not passing a length-1 tuple + result = g.get_group("a") + + tm.assert_frame_equal(result, expected) + + +def test_sorting_with_different_categoricals(): + # GH 24271 + df = DataFrame( + { + "group": ["A"] * 6 + ["B"] * 6, + "dose": ["high", "med", "low"] * 4, + "outcomes": np.arange(12.0), + } + ) + + df.dose = Categorical(df.dose, categories=["low", "med", "high"], ordered=True) + + result = df.groupby("group")["dose"].value_counts() + result = result.sort_index(level=0, sort_remaining=True) + index = ["low", "med", "high", "low", "med", "high"] + index = Categorical(index, categories=["low", "med", "high"], ordered=True) + index = [["A", "A", "A", "B", "B", "B"], CategoricalIndex(index)] + index = MultiIndex.from_arrays(index, names=["group", "dose"]) + expected = Series([2] * 6, index=index, name="count") + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("ordered", [True, False]) +def test_apply(ordered): + # GH 10138 + + dense = Categorical(list("abc"), ordered=ordered) + + # 'b' is in the categories but not in the list + missing = Categorical(list("aaa"), categories=["a", "b"], ordered=ordered) + values = np.arange(len(dense)) + df = DataFrame({"missing": missing, "dense": dense, "values": values}) + grouped = df.groupby(["missing", "dense"], observed=True) + + # missing category 'b' should still exist in the output index + idx = MultiIndex.from_arrays([missing, dense], names=["missing", "dense"]) + expected = DataFrame([0, 1, 2.0], index=idx, columns=["values"]) + + result = grouped.apply(lambda x: np.mean(x, axis=0)) + tm.assert_frame_equal(result, expected) + + result = grouped.mean() + tm.assert_frame_equal(result, expected) + + msg = "using DataFrameGroupBy.mean" + with tm.assert_produces_warning(FutureWarning, match=msg): + # GH#53425 + result = grouped.agg(np.mean) + tm.assert_frame_equal(result, expected) + + # but for transform we should still get back the original index + idx = MultiIndex.from_arrays([missing, dense], names=["missing", "dense"]) + expected = Series(1, index=idx) + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = grouped.apply(lambda x: 1) + tm.assert_series_equal(result, expected) + + +def test_observed(observed): + # multiple groupers, don't re-expand the output space + # of the grouper + # gh-14942 (implement) + # gh-10132 (back-compat) + # gh-8138 (back-compat) + # gh-8869 + + cat1 = Categorical(["a", "a", "b", "b"], categories=["a", "b", "z"], ordered=True) + cat2 = Categorical(["c", "d", "c", "d"], categories=["c", "d", "y"], ordered=True) + df = DataFrame({"A": cat1, "B": cat2, "values": [1, 2, 3, 4]}) + df["C"] = ["foo", "bar"] * 2 + + # multiple groupers with a non-cat + gb = df.groupby(["A", "B", "C"], observed=observed) + exp_index = MultiIndex.from_arrays( + [cat1, cat2, ["foo", "bar"] * 2], names=["A", "B", "C"] + ) + expected = DataFrame({"values": Series([1, 2, 3, 4], index=exp_index)}).sort_index() + result = gb.sum() + if not observed: + expected = cartesian_product_for_groupers( + expected, [cat1, cat2, ["foo", "bar"]], list("ABC"), fill_value=0 + ) + + tm.assert_frame_equal(result, expected) + + gb = df.groupby(["A", "B"], observed=observed) + exp_index = MultiIndex.from_arrays([cat1, cat2], names=["A", "B"]) + expected = DataFrame( + {"values": [1, 2, 3, 4], "C": ["foo", "bar", "foo", "bar"]}, index=exp_index + ) + result = gb.sum() + if not observed: + expected = cartesian_product_for_groupers( + expected, [cat1, cat2], list("AB"), fill_value=0 + ) + + tm.assert_frame_equal(result, expected) + + # https://github.com/pandas-dev/pandas/issues/8138 + d = { + "cat": Categorical( + ["a", "b", "a", "b"], categories=["a", "b", "c"], ordered=True + ), + "ints": [1, 1, 2, 2], + "val": [10, 20, 30, 40], + } + df = DataFrame(d) + + # Grouping on a single column + groups_single_key = df.groupby("cat", observed=observed) + result = groups_single_key.mean() + + exp_index = CategoricalIndex( + list("ab"), name="cat", categories=list("abc"), ordered=True + ) + expected = DataFrame({"ints": [1.5, 1.5], "val": [20.0, 30]}, index=exp_index) + if not observed: + index = CategoricalIndex( + list("abc"), name="cat", categories=list("abc"), ordered=True + ) + expected = expected.reindex(index) + + tm.assert_frame_equal(result, expected) + + # Grouping on two columns + groups_double_key = df.groupby(["cat", "ints"], observed=observed) + result = groups_double_key.agg("mean") + expected = DataFrame( + { + "val": [10.0, 30.0, 20.0, 40.0], + "cat": Categorical( + ["a", "a", "b", "b"], categories=["a", "b", "c"], ordered=True + ), + "ints": [1, 2, 1, 2], + } + ).set_index(["cat", "ints"]) + if not observed: + expected = cartesian_product_for_groupers( + expected, [df.cat.values, [1, 2]], ["cat", "ints"] + ) + + tm.assert_frame_equal(result, expected) + + # GH 10132 + for key in [("a", 1), ("b", 2), ("b", 1), ("a", 2)]: + c, i = key + result = groups_double_key.get_group(key) + expected = df[(df.cat == c) & (df.ints == i)] + tm.assert_frame_equal(result, expected) + + # gh-8869 + # with as_index + d = { + "foo": [10, 8, 4, 8, 4, 1, 1], + "bar": [10, 20, 30, 40, 50, 60, 70], + "baz": ["d", "c", "e", "a", "a", "d", "c"], + } + df = DataFrame(d) + cat = pd.cut(df["foo"], np.linspace(0, 10, 3)) + df["range"] = cat + groups = df.groupby(["range", "baz"], as_index=False, observed=observed) + result = groups.agg("mean") + + groups2 = df.groupby(["range", "baz"], as_index=True, observed=observed) + expected = groups2.agg("mean").reset_index() + tm.assert_frame_equal(result, expected) + + +def test_observed_codes_remap(observed): + d = {"C1": [3, 3, 4, 5], "C2": [1, 2, 3, 4], "C3": [10, 100, 200, 34]} + df = DataFrame(d) + values = pd.cut(df["C1"], [1, 2, 3, 6]) + values.name = "cat" + groups_double_key = df.groupby([values, "C2"], observed=observed) + + idx = MultiIndex.from_arrays([values, [1, 2, 3, 4]], names=["cat", "C2"]) + expected = DataFrame( + {"C1": [3.0, 3.0, 4.0, 5.0], "C3": [10.0, 100.0, 200.0, 34.0]}, index=idx + ) + if not observed: + expected = cartesian_product_for_groupers( + expected, [values.values, [1, 2, 3, 4]], ["cat", "C2"] + ) + + result = groups_double_key.agg("mean") + tm.assert_frame_equal(result, expected) + + +def test_observed_perf(): + # we create a cartesian product, so this is + # non-performant if we don't use observed values + # gh-14942 + df = DataFrame( + { + "cat": np.random.default_rng(2).integers(0, 255, size=30000), + "int_id": np.random.default_rng(2).integers(0, 255, size=30000), + "other_id": np.random.default_rng(2).integers(0, 10000, size=30000), + "foo": 0, + } + ) + df["cat"] = df.cat.astype(str).astype("category") + + grouped = df.groupby(["cat", "int_id", "other_id"], observed=True) + result = grouped.count() + assert result.index.levels[0].nunique() == df.cat.nunique() + assert result.index.levels[1].nunique() == df.int_id.nunique() + assert result.index.levels[2].nunique() == df.other_id.nunique() + + +def test_observed_groups(observed): + # gh-20583 + # test that we have the appropriate groups + + cat = Categorical(["a", "c", "a"], categories=["a", "b", "c"]) + df = DataFrame({"cat": cat, "vals": [1, 2, 3]}) + g = df.groupby("cat", observed=observed) + + result = g.groups + if observed: + expected = {"a": Index([0, 2], dtype="int64"), "c": Index([1], dtype="int64")} + else: + expected = { + "a": Index([0, 2], dtype="int64"), + "b": Index([], dtype="int64"), + "c": Index([1], dtype="int64"), + } + + tm.assert_dict_equal(result, expected) + + +@pytest.mark.parametrize( + "keys, expected_values, expected_index_levels", + [ + ("a", [15, 9, 0], CategoricalIndex([1, 2, 3], name="a")), + ( + ["a", "b"], + [7, 8, 0, 0, 0, 9, 0, 0, 0], + [CategoricalIndex([1, 2, 3], name="a"), Index([4, 5, 6])], + ), + ( + ["a", "a2"], + [15, 0, 0, 0, 9, 0, 0, 0, 0], + [ + CategoricalIndex([1, 2, 3], name="a"), + CategoricalIndex([1, 2, 3], name="a"), + ], + ), + ], +) +@pytest.mark.parametrize("test_series", [True, False]) +def test_unobserved_in_index(keys, expected_values, expected_index_levels, test_series): + # GH#49354 - ensure unobserved cats occur when grouping by index levels + df = DataFrame( + { + "a": Categorical([1, 1, 2], categories=[1, 2, 3]), + "a2": Categorical([1, 1, 2], categories=[1, 2, 3]), + "b": [4, 5, 6], + "c": [7, 8, 9], + } + ).set_index(["a", "a2"]) + if "b" not in keys: + # Only keep b when it is used for grouping for consistent columns in the result + df = df.drop(columns="b") + + gb = df.groupby(keys, observed=False) + if test_series: + gb = gb["c"] + result = gb.sum() + + if len(keys) == 1: + index = expected_index_levels + else: + codes = [[0, 0, 0, 1, 1, 1, 2, 2, 2], 3 * [0, 1, 2]] + index = MultiIndex( + expected_index_levels, + codes=codes, + names=keys, + ) + expected = DataFrame({"c": expected_values}, index=index) + if test_series: + expected = expected["c"] + tm.assert_equal(result, expected) + + +def test_observed_groups_with_nan(observed): + # GH 24740 + df = DataFrame( + { + "cat": Categorical(["a", np.nan, "a"], categories=["a", "b", "d"]), + "vals": [1, 2, 3], + } + ) + g = df.groupby("cat", observed=observed) + result = g.groups + if observed: + expected = {"a": Index([0, 2], dtype="int64")} + else: + expected = { + "a": Index([0, 2], dtype="int64"), + "b": Index([], dtype="int64"), + "d": Index([], dtype="int64"), + } + tm.assert_dict_equal(result, expected) + + +def test_observed_nth(): + # GH 26385 + cat = Categorical(["a", np.nan, np.nan], categories=["a", "b", "c"]) + ser = Series([1, 2, 3]) + df = DataFrame({"cat": cat, "ser": ser}) + + result = df.groupby("cat", observed=False)["ser"].nth(0) + expected = df["ser"].iloc[[0]] + tm.assert_series_equal(result, expected) + + +def test_dataframe_categorical_with_nan(observed): + # GH 21151 + s1 = Categorical([np.nan, "a", np.nan, "a"], categories=["a", "b", "c"]) + s2 = Series([1, 2, 3, 4]) + df = DataFrame({"s1": s1, "s2": s2}) + result = df.groupby("s1", observed=observed).first().reset_index() + if observed: + expected = DataFrame( + {"s1": Categorical(["a"], categories=["a", "b", "c"]), "s2": [2]} + ) + else: + expected = DataFrame( + { + "s1": Categorical(["a", "b", "c"], categories=["a", "b", "c"]), + "s2": [2, np.nan, np.nan], + } + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("ordered", [True, False]) +@pytest.mark.parametrize("observed", [True, False]) +@pytest.mark.parametrize("sort", [True, False]) +def test_dataframe_categorical_ordered_observed_sort(ordered, observed, sort): + # GH 25871: Fix groupby sorting on ordered Categoricals + # GH 25167: Groupby with observed=True doesn't sort + + # Build a dataframe with cat having one unobserved category ('missing'), + # and a Series with identical values + label = Categorical( + ["d", "a", "b", "a", "d", "b"], + categories=["a", "b", "missing", "d"], + ordered=ordered, + ) + val = Series(["d", "a", "b", "a", "d", "b"]) + df = DataFrame({"label": label, "val": val}) + + # aggregate on the Categorical + result = df.groupby("label", observed=observed, sort=sort)["val"].aggregate("first") + + # If ordering works, we expect index labels equal to aggregation results, + # except for 'observed=False': label 'missing' has aggregation None + label = Series(result.index.array, dtype="object") + aggr = Series(result.array) + if not observed: + aggr[aggr.isna()] = "missing" + if not all(label == aggr): + msg = ( + "Labels and aggregation results not consistently sorted\n" + f"for (ordered={ordered}, observed={observed}, sort={sort})\n" + f"Result:\n{result}" + ) + assert False, msg + + +def test_datetime(): + # GH9049: ensure backward compatibility + levels = pd.date_range("2014-01-01", periods=4) + codes = np.random.default_rng(2).integers(0, 4, size=100) + + cats = Categorical.from_codes(codes, levels, ordered=True) + + data = DataFrame(np.random.default_rng(2).standard_normal((100, 4))) + result = data.groupby(cats, observed=False).mean() + + expected = data.groupby(np.asarray(cats), observed=False).mean() + expected = expected.reindex(levels) + expected.index = CategoricalIndex( + expected.index, categories=expected.index, ordered=True + ) + + tm.assert_frame_equal(result, expected) + + grouped = data.groupby(cats, observed=False) + desc_result = grouped.describe() + + idx = cats.codes.argsort() + ord_labels = cats.take(idx) + ord_data = data.take(idx) + expected = ord_data.groupby(ord_labels, observed=False).describe() + tm.assert_frame_equal(desc_result, expected) + tm.assert_index_equal(desc_result.index, expected.index) + tm.assert_index_equal( + desc_result.index.get_level_values(0), expected.index.get_level_values(0) + ) + + # GH 10460 + expc = Categorical.from_codes(np.arange(4).repeat(8), levels, ordered=True) + exp = CategoricalIndex(expc) + tm.assert_index_equal( + (desc_result.stack(future_stack=True).index.get_level_values(0)), exp + ) + exp = Index(["count", "mean", "std", "min", "25%", "50%", "75%", "max"] * 4) + tm.assert_index_equal( + (desc_result.stack(future_stack=True).index.get_level_values(1)), exp + ) + + +def test_categorical_index(): + s = np.random.default_rng(2) + levels = ["foo", "bar", "baz", "qux"] + codes = s.integers(0, 4, size=20) + cats = Categorical.from_codes(codes, levels, ordered=True) + df = DataFrame(np.repeat(np.arange(20), 4).reshape(-1, 4), columns=list("abcd")) + df["cats"] = cats + + # with a cat index + result = df.set_index("cats").groupby(level=0, observed=False).sum() + expected = df[list("abcd")].groupby(cats.codes, observed=False).sum() + expected.index = CategoricalIndex( + Categorical.from_codes([0, 1, 2, 3], levels, ordered=True), name="cats" + ) + tm.assert_frame_equal(result, expected) + + # with a cat column, should produce a cat index + result = df.groupby("cats", observed=False).sum() + expected = df[list("abcd")].groupby(cats.codes, observed=False).sum() + expected.index = CategoricalIndex( + Categorical.from_codes([0, 1, 2, 3], levels, ordered=True), name="cats" + ) + tm.assert_frame_equal(result, expected) + + +def test_describe_categorical_columns(): + # GH 11558 + cats = CategoricalIndex( + ["qux", "foo", "baz", "bar"], + categories=["foo", "bar", "baz", "qux"], + ordered=True, + ) + df = DataFrame(np.random.default_rng(2).standard_normal((20, 4)), columns=cats) + result = df.groupby([1, 2, 3, 4] * 5).describe() + + tm.assert_index_equal(result.stack(future_stack=True).columns, cats) + tm.assert_categorical_equal( + result.stack(future_stack=True).columns.values, cats.values + ) + + +def test_unstack_categorical(): + # GH11558 (example is taken from the original issue) + df = DataFrame( + {"a": range(10), "medium": ["A", "B"] * 5, "artist": list("XYXXY") * 2} + ) + df["medium"] = df["medium"].astype("category") + + gcat = df.groupby(["artist", "medium"], observed=False)["a"].count().unstack() + result = gcat.describe() + + exp_columns = CategoricalIndex(["A", "B"], ordered=False, name="medium") + tm.assert_index_equal(result.columns, exp_columns) + tm.assert_categorical_equal(result.columns.values, exp_columns.values) + + result = gcat["A"] + gcat["B"] + expected = Series([6, 4], index=Index(["X", "Y"], name="artist")) + tm.assert_series_equal(result, expected) + + +def test_bins_unequal_len(): + # GH3011 + series = Series([np.nan, np.nan, 1, 1, 2, 2, 3, 3, 4, 4]) + bins = pd.cut(series.dropna().values, 4) + + # len(bins) != len(series) here + with pytest.raises(ValueError, match="Grouper and axis must be same length"): + series.groupby(bins).mean() + + +@pytest.mark.parametrize( + ["series", "data"], + [ + # Group a series with length and index equal to those of the grouper. + (Series(range(4)), {"A": [0, 3], "B": [1, 2]}), + # Group a series with length equal to that of the grouper and index unequal to + # that of the grouper. + (Series(range(4)).rename(lambda idx: idx + 1), {"A": [2], "B": [0, 1]}), + # GH44179: Group a series with length unequal to that of the grouper. + (Series(range(7)), {"A": [0, 3], "B": [1, 2]}), + ], +) +def test_categorical_series(series, data): + # Group the given series by a series with categorical data type such that group A + # takes indices 0 and 3 and group B indices 1 and 2, obtaining the values mapped in + # the given data. + groupby = series.groupby(Series(list("ABBA"), dtype="category"), observed=False) + result = groupby.aggregate(list) + expected = Series(data, index=CategoricalIndex(data.keys())) + tm.assert_series_equal(result, expected) + + +def test_as_index(): + # GH13204 + df = DataFrame( + { + "cat": Categorical([1, 2, 2], [1, 2, 3]), + "A": [10, 11, 11], + "B": [101, 102, 103], + } + ) + result = df.groupby(["cat", "A"], as_index=False, observed=True).sum() + expected = DataFrame( + { + "cat": Categorical([1, 2], categories=df.cat.cat.categories), + "A": [10, 11], + "B": [101, 205], + }, + columns=["cat", "A", "B"], + ) + tm.assert_frame_equal(result, expected) + + # function grouper + f = lambda r: df.loc[r, "A"] + msg = "A grouping .* was excluded from the result" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = df.groupby(["cat", f], as_index=False, observed=True).sum() + expected = DataFrame( + { + "cat": Categorical([1, 2], categories=df.cat.cat.categories), + "A": [10, 22], + "B": [101, 205], + }, + columns=["cat", "A", "B"], + ) + tm.assert_frame_equal(result, expected) + + # another not in-axis grouper (conflicting names in index) + s = Series(["a", "b", "b"], name="cat") + msg = "A grouping .* was excluded from the result" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = df.groupby(["cat", s], as_index=False, observed=True).sum() + tm.assert_frame_equal(result, expected) + + # is original index dropped? + group_columns = ["cat", "A"] + expected = DataFrame( + { + "cat": Categorical([1, 2], categories=df.cat.cat.categories), + "A": [10, 11], + "B": [101, 205], + }, + columns=["cat", "A", "B"], + ) + + for name in [None, "X", "B"]: + df.index = Index(list("abc"), name=name) + result = df.groupby(group_columns, as_index=False, observed=True).sum() + + tm.assert_frame_equal(result, expected) + + +def test_preserve_categories(): + # GH-13179 + categories = list("abc") + + # ordered=True + df = DataFrame({"A": Categorical(list("ba"), categories=categories, ordered=True)}) + sort_index = CategoricalIndex(categories, categories, ordered=True, name="A") + nosort_index = CategoricalIndex(list("bac"), categories, ordered=True, name="A") + tm.assert_index_equal( + df.groupby("A", sort=True, observed=False).first().index, sort_index + ) + # GH#42482 - don't sort result when sort=False, even when ordered=True + tm.assert_index_equal( + df.groupby("A", sort=False, observed=False).first().index, nosort_index + ) + + # ordered=False + df = DataFrame({"A": Categorical(list("ba"), categories=categories, ordered=False)}) + sort_index = CategoricalIndex(categories, categories, ordered=False, name="A") + # GH#48749 - don't change order of categories + # GH#42482 - don't sort result when sort=False, even when ordered=True + nosort_index = CategoricalIndex(list("bac"), list("abc"), ordered=False, name="A") + tm.assert_index_equal( + df.groupby("A", sort=True, observed=False).first().index, sort_index + ) + tm.assert_index_equal( + df.groupby("A", sort=False, observed=False).first().index, nosort_index + ) + + +def test_preserve_categorical_dtype(): + # GH13743, GH13854 + df = DataFrame( + { + "A": [1, 2, 1, 1, 2], + "B": [10, 16, 22, 28, 34], + "C1": Categorical(list("abaab"), categories=list("bac"), ordered=False), + "C2": Categorical(list("abaab"), categories=list("bac"), ordered=True), + } + ) + # single grouper + exp_full = DataFrame( + { + "A": [2.0, 1.0, np.nan], + "B": [25.0, 20.0, np.nan], + "C1": Categorical(list("bac"), categories=list("bac"), ordered=False), + "C2": Categorical(list("bac"), categories=list("bac"), ordered=True), + } + ) + for col in ["C1", "C2"]: + result1 = df.groupby(by=col, as_index=False, observed=False).mean( + numeric_only=True + ) + result2 = ( + df.groupby(by=col, as_index=True, observed=False) + .mean(numeric_only=True) + .reset_index() + ) + expected = exp_full.reindex(columns=result1.columns) + tm.assert_frame_equal(result1, expected) + tm.assert_frame_equal(result2, expected) + + +@pytest.mark.parametrize( + "func, values", + [ + ("first", ["second", "first"]), + ("last", ["fourth", "third"]), + ("min", ["fourth", "first"]), + ("max", ["second", "third"]), + ], +) +def test_preserve_on_ordered_ops(func, values): + # gh-18502 + # preserve the categoricals on ops + c = Categorical(["first", "second", "third", "fourth"], ordered=True) + df = DataFrame({"payload": [-1, -2, -1, -2], "col": c}) + g = df.groupby("payload") + result = getattr(g, func)() + expected = DataFrame( + {"payload": [-2, -1], "col": Series(values, dtype=c.dtype)} + ).set_index("payload") + tm.assert_frame_equal(result, expected) + + # we should also preserve categorical for SeriesGroupBy + sgb = df.groupby("payload")["col"] + result = getattr(sgb, func)() + expected = expected["col"] + tm.assert_series_equal(result, expected) + + +def test_categorical_no_compress(): + data = Series(np.random.default_rng(2).standard_normal(9)) + + codes = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2]) + cats = Categorical.from_codes(codes, [0, 1, 2], ordered=True) + + result = data.groupby(cats, observed=False).mean() + exp = data.groupby(codes, observed=False).mean() + + exp.index = CategoricalIndex( + exp.index, categories=cats.categories, ordered=cats.ordered + ) + tm.assert_series_equal(result, exp) + + codes = np.array([0, 0, 0, 1, 1, 1, 3, 3, 3]) + cats = Categorical.from_codes(codes, [0, 1, 2, 3], ordered=True) + + result = data.groupby(cats, observed=False).mean() + exp = data.groupby(codes, observed=False).mean().reindex(cats.categories) + exp.index = CategoricalIndex( + exp.index, categories=cats.categories, ordered=cats.ordered + ) + tm.assert_series_equal(result, exp) + + cats = Categorical( + ["a", "a", "a", "b", "b", "b", "c", "c", "c"], + categories=["a", "b", "c", "d"], + ordered=True, + ) + data = DataFrame({"a": [1, 1, 1, 2, 2, 2, 3, 4, 5], "b": cats}) + + result = data.groupby("b", observed=False).mean() + result = result["a"].values + exp = np.array([1, 2, 4, np.nan]) + tm.assert_numpy_array_equal(result, exp) + + +def test_groupby_empty_with_category(): + # GH-9614 + # test fix for when group by on None resulted in + # coercion of dtype categorical -> float + df = DataFrame({"A": [None] * 3, "B": Categorical(["train", "train", "test"])}) + result = df.groupby("A").first()["B"] + expected = Series( + Categorical([], categories=["test", "train"]), + index=Series([], dtype="object", name="A"), + name="B", + ) + tm.assert_series_equal(result, expected) + + +def test_sort(): + # https://stackoverflow.com/questions/23814368/sorting-pandas- + # categorical-labels-after-groupby + # This should result in a properly sorted Series so that the plot + # has a sorted x axis + # self.cat.groupby(['value_group'])['value_group'].count().plot(kind='bar') + + df = DataFrame({"value": np.random.default_rng(2).integers(0, 10000, 100)}) + labels = [f"{i} - {i+499}" for i in range(0, 10000, 500)] + cat_labels = Categorical(labels, labels) + + df = df.sort_values(by=["value"], ascending=True) + df["value_group"] = pd.cut( + df.value, range(0, 10500, 500), right=False, labels=cat_labels + ) + + res = df.groupby(["value_group"], observed=False)["value_group"].count() + exp = res[sorted(res.index, key=lambda x: float(x.split()[0]))] + exp.index = CategoricalIndex(exp.index, name=exp.index.name) + tm.assert_series_equal(res, exp) + + +@pytest.mark.parametrize("ordered", [True, False]) +def test_sort2(sort, ordered): + # dataframe groupby sort was being ignored # GH 8868 + # GH#48749 - don't change order of categories + # GH#42482 - don't sort result when sort=False, even when ordered=True + df = DataFrame( + [ + ["(7.5, 10]", 10, 10], + ["(7.5, 10]", 8, 20], + ["(2.5, 5]", 5, 30], + ["(5, 7.5]", 6, 40], + ["(2.5, 5]", 4, 50], + ["(0, 2.5]", 1, 60], + ["(5, 7.5]", 7, 70], + ], + columns=["range", "foo", "bar"], + ) + df["range"] = Categorical(df["range"], ordered=ordered) + result = df.groupby("range", sort=sort, observed=False).first() + + if sort: + data_values = [[1, 60], [5, 30], [6, 40], [10, 10]] + index_values = ["(0, 2.5]", "(2.5, 5]", "(5, 7.5]", "(7.5, 10]"] + else: + data_values = [[10, 10], [5, 30], [6, 40], [1, 60]] + index_values = ["(7.5, 10]", "(2.5, 5]", "(5, 7.5]", "(0, 2.5]"] + expected = DataFrame( + data_values, + columns=["foo", "bar"], + index=CategoricalIndex(index_values, name="range", ordered=ordered), + ) + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("ordered", [True, False]) +def test_sort_datetimelike(sort, ordered): + # GH10505 + # GH#42482 - don't sort result when sort=False, even when ordered=True + + # use same data as test_groupby_sort_categorical, which category is + # corresponding to datetime.month + df = DataFrame( + { + "dt": [ + datetime(2011, 7, 1), + datetime(2011, 7, 1), + datetime(2011, 2, 1), + datetime(2011, 5, 1), + datetime(2011, 2, 1), + datetime(2011, 1, 1), + datetime(2011, 5, 1), + ], + "foo": [10, 8, 5, 6, 4, 1, 7], + "bar": [10, 20, 30, 40, 50, 60, 70], + }, + columns=["dt", "foo", "bar"], + ) + + # ordered=True + df["dt"] = Categorical(df["dt"], ordered=ordered) + if sort: + data_values = [[1, 60], [5, 30], [6, 40], [10, 10]] + index_values = [ + datetime(2011, 1, 1), + datetime(2011, 2, 1), + datetime(2011, 5, 1), + datetime(2011, 7, 1), + ] + else: + data_values = [[10, 10], [5, 30], [6, 40], [1, 60]] + index_values = [ + datetime(2011, 7, 1), + datetime(2011, 2, 1), + datetime(2011, 5, 1), + datetime(2011, 1, 1), + ] + expected = DataFrame( + data_values, + columns=["foo", "bar"], + index=CategoricalIndex(index_values, name="dt", ordered=ordered), + ) + result = df.groupby("dt", sort=sort, observed=False).first() + tm.assert_frame_equal(result, expected) + + +def test_empty_sum(): + # https://github.com/pandas-dev/pandas/issues/18678 + df = DataFrame( + {"A": Categorical(["a", "a", "b"], categories=["a", "b", "c"]), "B": [1, 2, 1]} + ) + expected_idx = CategoricalIndex(["a", "b", "c"], name="A") + + # 0 by default + result = df.groupby("A", observed=False).B.sum() + expected = Series([3, 1, 0], expected_idx, name="B") + tm.assert_series_equal(result, expected) + + # min_count=0 + result = df.groupby("A", observed=False).B.sum(min_count=0) + expected = Series([3, 1, 0], expected_idx, name="B") + tm.assert_series_equal(result, expected) + + # min_count=1 + result = df.groupby("A", observed=False).B.sum(min_count=1) + expected = Series([3, 1, np.nan], expected_idx, name="B") + tm.assert_series_equal(result, expected) + + # min_count>1 + result = df.groupby("A", observed=False).B.sum(min_count=2) + expected = Series([3, np.nan, np.nan], expected_idx, name="B") + tm.assert_series_equal(result, expected) + + +def test_empty_prod(): + # https://github.com/pandas-dev/pandas/issues/18678 + df = DataFrame( + {"A": Categorical(["a", "a", "b"], categories=["a", "b", "c"]), "B": [1, 2, 1]} + ) + + expected_idx = CategoricalIndex(["a", "b", "c"], name="A") + + # 1 by default + result = df.groupby("A", observed=False).B.prod() + expected = Series([2, 1, 1], expected_idx, name="B") + tm.assert_series_equal(result, expected) + + # min_count=0 + result = df.groupby("A", observed=False).B.prod(min_count=0) + expected = Series([2, 1, 1], expected_idx, name="B") + tm.assert_series_equal(result, expected) + + # min_count=1 + result = df.groupby("A", observed=False).B.prod(min_count=1) + expected = Series([2, 1, np.nan], expected_idx, name="B") + tm.assert_series_equal(result, expected) + + +def test_groupby_multiindex_categorical_datetime(): + # https://github.com/pandas-dev/pandas/issues/21390 + + df = DataFrame( + { + "key1": Categorical(list("abcbabcba")), + "key2": Categorical( + list(pd.date_range("2018-06-01 00", freq="1min", periods=3)) * 3 + ), + "values": np.arange(9), + } + ) + result = df.groupby(["key1", "key2"], observed=False).mean() + + idx = MultiIndex.from_product( + [ + Categorical(["a", "b", "c"]), + Categorical(pd.date_range("2018-06-01 00", freq="1min", periods=3)), + ], + names=["key1", "key2"], + ) + expected = DataFrame({"values": [0, 4, 8, 3, 4, 5, 6, np.nan, 2]}, index=idx) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "as_index, expected", + [ + ( + True, + Series( + index=MultiIndex.from_arrays( + [Series([1, 1, 2], dtype="category"), [1, 2, 2]], names=["a", "b"] + ), + data=[1, 2, 3], + name="x", + ), + ), + ( + False, + DataFrame( + { + "a": Series([1, 1, 2], dtype="category"), + "b": [1, 2, 2], + "x": [1, 2, 3], + } + ), + ), + ], +) +def test_groupby_agg_observed_true_single_column(as_index, expected): + # GH-23970 + df = DataFrame( + {"a": Series([1, 1, 2], dtype="category"), "b": [1, 2, 2], "x": [1, 2, 3]} + ) + + result = df.groupby(["a", "b"], as_index=as_index, observed=True)["x"].sum() + + tm.assert_equal(result, expected) + + +@pytest.mark.parametrize("fill_value", [None, np.nan, pd.NaT]) +def test_shift(fill_value): + ct = Categorical( + ["a", "b", "c", "d"], categories=["a", "b", "c", "d"], ordered=False + ) + expected = Categorical( + [None, "a", "b", "c"], categories=["a", "b", "c", "d"], ordered=False + ) + res = ct.shift(1, fill_value=fill_value) + tm.assert_equal(res, expected) + + +@pytest.fixture +def df_cat(df): + """ + DataFrame with multiple categorical columns and a column of integers. + Shortened so as not to contain all possible combinations of categories. + Useful for testing `observed` kwarg functionality on GroupBy objects. + + Parameters + ---------- + df: DataFrame + Non-categorical, longer DataFrame from another fixture, used to derive + this one + + Returns + ------- + df_cat: DataFrame + """ + df_cat = df.copy()[:4] # leave out some groups + df_cat["A"] = df_cat["A"].astype("category") + df_cat["B"] = df_cat["B"].astype("category") + df_cat["C"] = Series([1, 2, 3, 4]) + df_cat = df_cat.drop(["D"], axis=1) + return df_cat + + +@pytest.mark.parametrize("operation", ["agg", "apply"]) +def test_seriesgroupby_observed_true(df_cat, operation): + # GH#24880 + # GH#49223 - order of results was wrong when grouping by index levels + lev_a = Index(["bar", "bar", "foo", "foo"], dtype=df_cat["A"].dtype, name="A") + lev_b = Index(["one", "three", "one", "two"], dtype=df_cat["B"].dtype, name="B") + index = MultiIndex.from_arrays([lev_a, lev_b]) + expected = Series(data=[2, 4, 1, 3], index=index, name="C").sort_index() + + grouped = df_cat.groupby(["A", "B"], observed=True)["C"] + msg = "using np.sum" if operation == "apply" else "using SeriesGroupBy.sum" + with tm.assert_produces_warning(FutureWarning, match=msg): + # GH#53425 + result = getattr(grouped, operation)(sum) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("operation", ["agg", "apply"]) +@pytest.mark.parametrize("observed", [False, None]) +def test_seriesgroupby_observed_false_or_none(df_cat, observed, operation): + # GH 24880 + # GH#49223 - order of results was wrong when grouping by index levels + index, _ = MultiIndex.from_product( + [ + CategoricalIndex(["bar", "foo"], ordered=False), + CategoricalIndex(["one", "three", "two"], ordered=False), + ], + names=["A", "B"], + ).sortlevel() + + expected = Series(data=[2, 4, np.nan, 1, np.nan, 3], index=index, name="C") + if operation == "agg": + msg = "The 'downcast' keyword in fillna is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + expected = expected.fillna(0, downcast="infer") + grouped = df_cat.groupby(["A", "B"], observed=observed)["C"] + msg = "using SeriesGroupBy.sum" if operation == "agg" else "using np.sum" + with tm.assert_produces_warning(FutureWarning, match=msg): + # GH#53425 + result = getattr(grouped, operation)(sum) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "observed, index, data", + [ + ( + True, + MultiIndex.from_arrays( + [ + Index(["bar"] * 4 + ["foo"] * 4, dtype="category", name="A"), + Index( + ["one", "one", "three", "three", "one", "one", "two", "two"], + dtype="category", + name="B", + ), + Index(["min", "max"] * 4), + ] + ), + [2, 2, 4, 4, 1, 1, 3, 3], + ), + ( + False, + MultiIndex.from_product( + [ + CategoricalIndex(["bar", "foo"], ordered=False), + CategoricalIndex(["one", "three", "two"], ordered=False), + Index(["min", "max"]), + ], + names=["A", "B", None], + ), + [2, 2, 4, 4, np.nan, np.nan, 1, 1, np.nan, np.nan, 3, 3], + ), + ( + None, + MultiIndex.from_product( + [ + CategoricalIndex(["bar", "foo"], ordered=False), + CategoricalIndex(["one", "three", "two"], ordered=False), + Index(["min", "max"]), + ], + names=["A", "B", None], + ), + [2, 2, 4, 4, np.nan, np.nan, 1, 1, np.nan, np.nan, 3, 3], + ), + ], +) +def test_seriesgroupby_observed_apply_dict(df_cat, observed, index, data): + # GH 24880 + expected = Series(data=data, index=index, name="C") + result = df_cat.groupby(["A", "B"], observed=observed)["C"].apply( + lambda x: {"min": x.min(), "max": x.max()} + ) + tm.assert_series_equal(result, expected) + + +def test_groupby_categorical_series_dataframe_consistent(df_cat): + # GH 20416 + expected = df_cat.groupby(["A", "B"], observed=False)["C"].mean() + result = df_cat.groupby(["A", "B"], observed=False).mean()["C"] + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("code", [([1, 0, 0]), ([0, 0, 0])]) +def test_groupby_categorical_axis_1(code): + # GH 13420 + df = DataFrame({"a": [1, 2, 3, 4], "b": [-1, -2, -3, -4], "c": [5, 6, 7, 8]}) + cat = Categorical.from_codes(code, categories=list("abc")) + msg = "DataFrame.groupby with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + gb = df.groupby(cat, axis=1, observed=False) + result = gb.mean() + msg = "The 'axis' keyword in DataFrame.groupby is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + gb2 = df.T.groupby(cat, axis=0, observed=False) + expected = gb2.mean().T + tm.assert_frame_equal(result, expected) + + +def test_groupby_cat_preserves_structure(observed, ordered): + # GH 28787 + df = DataFrame( + {"Name": Categorical(["Bob", "Greg"], ordered=ordered), "Item": [1, 2]}, + columns=["Name", "Item"], + ) + expected = df.copy() + + result = ( + df.groupby("Name", observed=observed) + .agg(DataFrame.sum, skipna=True) + .reset_index() + ) + + tm.assert_frame_equal(result, expected) + + +def test_get_nonexistent_category(): + # Accessing a Category that is not in the dataframe + df = DataFrame({"var": ["a", "a", "b", "b"], "val": range(4)}) + with pytest.raises(KeyError, match="'vau'"): + df.groupby("var").apply( + lambda rows: DataFrame( + {"var": [rows.iloc[-1]["var"]], "val": [rows.iloc[-1]["vau"]]} + ) + ) + + +def test_series_groupby_on_2_categoricals_unobserved(reduction_func, observed): + # GH 17605 + if reduction_func == "ngroup": + pytest.skip("ngroup is not truly a reduction") + + df = DataFrame( + { + "cat_1": Categorical(list("AABB"), categories=list("ABCD")), + "cat_2": Categorical(list("AB") * 2, categories=list("ABCD")), + "value": [0.1] * 4, + } + ) + args = get_groupby_method_args(reduction_func, df) + + expected_length = 4 if observed else 16 + + series_groupby = df.groupby(["cat_1", "cat_2"], observed=observed)["value"] + + if reduction_func == "corrwith": + # TODO: implemented SeriesGroupBy.corrwith. See GH 32293 + assert not hasattr(series_groupby, reduction_func) + return + + agg = getattr(series_groupby, reduction_func) + + if not observed and reduction_func in ["idxmin", "idxmax"]: + # idxmin and idxmax are designed to fail on empty inputs + with pytest.raises( + ValueError, match="empty group due to unobserved categories" + ): + agg(*args) + return + + result = agg(*args) + + assert len(result) == expected_length + + +def test_series_groupby_on_2_categoricals_unobserved_zeroes_or_nans( + reduction_func, request +): + # GH 17605 + # Tests whether the unobserved categories in the result contain 0 or NaN + + if reduction_func == "ngroup": + pytest.skip("ngroup is not truly a reduction") + + if reduction_func == "corrwith": # GH 32293 + mark = pytest.mark.xfail( + reason="TODO: implemented SeriesGroupBy.corrwith. See GH 32293" + ) + request.applymarker(mark) + + df = DataFrame( + { + "cat_1": Categorical(list("AABB"), categories=list("ABC")), + "cat_2": Categorical(list("AB") * 2, categories=list("ABC")), + "value": [0.1] * 4, + } + ) + unobserved = [tuple("AC"), tuple("BC"), tuple("CA"), tuple("CB"), tuple("CC")] + args = get_groupby_method_args(reduction_func, df) + + series_groupby = df.groupby(["cat_1", "cat_2"], observed=False)["value"] + agg = getattr(series_groupby, reduction_func) + + if reduction_func in ["idxmin", "idxmax"]: + # idxmin and idxmax are designed to fail on empty inputs + with pytest.raises( + ValueError, match="empty group due to unobserved categories" + ): + agg(*args) + return + + result = agg(*args) + + zero_or_nan = _results_for_groupbys_with_missing_categories[reduction_func] + + for idx in unobserved: + val = result.loc[idx] + assert (pd.isna(zero_or_nan) and pd.isna(val)) or (val == zero_or_nan) + + # If we expect unobserved values to be zero, we also expect the dtype to be int. + # Except for .sum(). If the observed categories sum to dtype=float (i.e. their + # sums have decimals), then the zeros for the missing categories should also be + # floats. + if zero_or_nan == 0 and reduction_func != "sum": + assert np.issubdtype(result.dtype, np.integer) + + +def test_dataframe_groupby_on_2_categoricals_when_observed_is_true(reduction_func): + # GH 23865 + # GH 27075 + # Ensure that df.groupby, when 'by' is two Categorical variables, + # does not return the categories that are not in df when observed=True + if reduction_func == "ngroup": + pytest.skip("ngroup does not return the Categories on the index") + + df = DataFrame( + { + "cat_1": Categorical(list("AABB"), categories=list("ABC")), + "cat_2": Categorical(list("1111"), categories=list("12")), + "value": [0.1, 0.1, 0.1, 0.1], + } + ) + unobserved_cats = [("A", "2"), ("B", "2"), ("C", "1"), ("C", "2")] + + df_grp = df.groupby(["cat_1", "cat_2"], observed=True) + + args = get_groupby_method_args(reduction_func, df) + res = getattr(df_grp, reduction_func)(*args) + + for cat in unobserved_cats: + assert cat not in res.index + + +@pytest.mark.parametrize("observed", [False, None]) +def test_dataframe_groupby_on_2_categoricals_when_observed_is_false( + reduction_func, observed +): + # GH 23865 + # GH 27075 + # Ensure that df.groupby, when 'by' is two Categorical variables, + # returns the categories that are not in df when observed=False/None + + if reduction_func == "ngroup": + pytest.skip("ngroup does not return the Categories on the index") + + df = DataFrame( + { + "cat_1": Categorical(list("AABB"), categories=list("ABC")), + "cat_2": Categorical(list("1111"), categories=list("12")), + "value": [0.1, 0.1, 0.1, 0.1], + } + ) + unobserved_cats = [("A", "2"), ("B", "2"), ("C", "1"), ("C", "2")] + + df_grp = df.groupby(["cat_1", "cat_2"], observed=observed) + + args = get_groupby_method_args(reduction_func, df) + + if not observed and reduction_func in ["idxmin", "idxmax"]: + # idxmin and idxmax are designed to fail on empty inputs + with pytest.raises( + ValueError, match="empty group due to unobserved categories" + ): + getattr(df_grp, reduction_func)(*args) + return + + res = getattr(df_grp, reduction_func)(*args) + + expected = _results_for_groupbys_with_missing_categories[reduction_func] + + if expected is np.nan: + assert res.loc[unobserved_cats].isnull().all().all() + else: + assert (res.loc[unobserved_cats] == expected).all().all() + + +def test_series_groupby_categorical_aggregation_getitem(): + # GH 8870 + d = {"foo": [10, 8, 4, 1], "bar": [10, 20, 30, 40], "baz": ["d", "c", "d", "c"]} + df = DataFrame(d) + cat = pd.cut(df["foo"], np.linspace(0, 20, 5)) + df["range"] = cat + groups = df.groupby(["range", "baz"], as_index=True, sort=True, observed=False) + result = groups["foo"].agg("mean") + expected = groups.agg("mean")["foo"] + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "func, expected_values", + [(Series.nunique, [1, 1, 2]), (Series.count, [1, 2, 2])], +) +def test_groupby_agg_categorical_columns(func, expected_values): + # 31256 + df = DataFrame( + { + "id": [0, 1, 2, 3, 4], + "groups": [0, 1, 1, 2, 2], + "value": Categorical([0, 0, 0, 0, 1]), + } + ).set_index("id") + result = df.groupby("groups").agg(func) + + expected = DataFrame( + {"value": expected_values}, index=Index([0, 1, 2], name="groups") + ) + tm.assert_frame_equal(result, expected) + + +def test_groupby_agg_non_numeric(): + df = DataFrame({"A": Categorical(["a", "a", "b"], categories=["a", "b", "c"])}) + expected = DataFrame({"A": [2, 1]}, index=np.array([1, 2])) + + result = df.groupby([1, 2, 1]).agg(Series.nunique) + tm.assert_frame_equal(result, expected) + + result = df.groupby([1, 2, 1]).nunique() + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("func", ["first", "last"]) +def test_groupby_first_returned_categorical_instead_of_dataframe(func): + # GH 28641: groupby drops index, when grouping over categorical column with + # first/last. Renamed Categorical instead of DataFrame previously. + df = DataFrame({"A": [1997], "B": Series(["b"], dtype="category").cat.as_ordered()}) + df_grouped = df.groupby("A")["B"] + result = getattr(df_grouped, func)() + + # ordered categorical dtype should be preserved + expected = Series( + ["b"], index=Index([1997], name="A"), name="B", dtype=df["B"].dtype + ) + tm.assert_series_equal(result, expected) + + +def test_read_only_category_no_sort(): + # GH33410 + cats = np.array([1, 2]) + cats.flags.writeable = False + df = DataFrame( + {"a": [1, 3, 5, 7], "b": Categorical([1, 1, 2, 2], categories=Index(cats))} + ) + expected = DataFrame(data={"a": [2.0, 6.0]}, index=CategoricalIndex(cats, name="b")) + result = df.groupby("b", sort=False, observed=False).mean() + tm.assert_frame_equal(result, expected) + + +def test_sorted_missing_category_values(): + # GH 28597 + df = DataFrame( + { + "foo": [ + "small", + "large", + "large", + "large", + "medium", + "large", + "large", + "medium", + ], + "bar": ["C", "A", "A", "C", "A", "C", "A", "C"], + } + ) + df["foo"] = ( + df["foo"] + .astype("category") + .cat.set_categories(["tiny", "small", "medium", "large"], ordered=True) + ) + + expected = DataFrame( + { + "tiny": {"A": 0, "C": 0}, + "small": {"A": 0, "C": 1}, + "medium": {"A": 1, "C": 1}, + "large": {"A": 3, "C": 2}, + } + ) + expected = expected.rename_axis("bar", axis="index") + expected.columns = CategoricalIndex( + ["tiny", "small", "medium", "large"], + categories=["tiny", "small", "medium", "large"], + ordered=True, + name="foo", + dtype="category", + ) + + result = df.groupby(["bar", "foo"], observed=False).size().unstack() + + tm.assert_frame_equal(result, expected) + + +def test_agg_cython_category_not_implemented_fallback(): + # https://github.com/pandas-dev/pandas/issues/31450 + df = DataFrame({"col_num": [1, 1, 2, 3]}) + df["col_cat"] = df["col_num"].astype("category") + + result = df.groupby("col_num").col_cat.first() + + # ordered categorical dtype should definitely be preserved; + # this is unordered, so is less-clear case (if anything, it should raise) + expected = Series( + [1, 2, 3], + index=Index([1, 2, 3], name="col_num"), + name="col_cat", + dtype=df["col_cat"].dtype, + ) + tm.assert_series_equal(result, expected) + + result = df.groupby("col_num").agg({"col_cat": "first"}) + expected = expected.to_frame() + tm.assert_frame_equal(result, expected) + + +def test_aggregate_categorical_with_isnan(): + # GH 29837 + df = DataFrame( + { + "A": [1, 1, 1, 1], + "B": [1, 2, 1, 2], + "numerical_col": [0.1, 0.2, np.nan, 0.3], + "object_col": ["foo", "bar", "foo", "fee"], + "categorical_col": ["foo", "bar", "foo", "fee"], + } + ) + + df = df.astype({"categorical_col": "category"}) + + result = df.groupby(["A", "B"]).agg(lambda df: df.isna().sum()) + index = MultiIndex.from_arrays([[1, 1], [1, 2]], names=("A", "B")) + expected = DataFrame( + data={ + "numerical_col": [1, 0], + "object_col": [0, 0], + "categorical_col": [0, 0], + }, + index=index, + ) + tm.assert_frame_equal(result, expected) + + +def test_categorical_transform(): + # GH 29037 + df = DataFrame( + { + "package_id": [1, 1, 1, 2, 2, 3], + "status": [ + "Waiting", + "OnTheWay", + "Delivered", + "Waiting", + "OnTheWay", + "Waiting", + ], + } + ) + + delivery_status_type = pd.CategoricalDtype( + categories=["Waiting", "OnTheWay", "Delivered"], ordered=True + ) + df["status"] = df["status"].astype(delivery_status_type) + msg = "using SeriesGroupBy.max" + with tm.assert_produces_warning(FutureWarning, match=msg): + # GH#53425 + df["last_status"] = df.groupby("package_id")["status"].transform(max) + result = df.copy() + + expected = DataFrame( + { + "package_id": [1, 1, 1, 2, 2, 3], + "status": [ + "Waiting", + "OnTheWay", + "Delivered", + "Waiting", + "OnTheWay", + "Waiting", + ], + "last_status": [ + "Delivered", + "Delivered", + "Delivered", + "OnTheWay", + "OnTheWay", + "Waiting", + ], + } + ) + + expected["status"] = expected["status"].astype(delivery_status_type) + + # .transform(max) should preserve ordered categoricals + expected["last_status"] = expected["last_status"].astype(delivery_status_type) + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("func", ["first", "last"]) +def test_series_groupby_first_on_categorical_col_grouped_on_2_categoricals( + func: str, observed: bool +): + # GH 34951 + cat = Categorical([0, 0, 1, 1]) + val = [0, 1, 1, 0] + df = DataFrame({"a": cat, "b": cat, "c": val}) + + cat2 = Categorical([0, 1]) + idx = MultiIndex.from_product([cat2, cat2], names=["a", "b"]) + expected_dict = { + "first": Series([0, np.nan, np.nan, 1], idx, name="c"), + "last": Series([1, np.nan, np.nan, 0], idx, name="c"), + } + + expected = expected_dict[func] + if observed: + expected = expected.dropna().astype(np.int64) + + srs_grp = df.groupby(["a", "b"], observed=observed)["c"] + result = getattr(srs_grp, func)() + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("func", ["first", "last"]) +def test_df_groupby_first_on_categorical_col_grouped_on_2_categoricals( + func: str, observed: bool +): + # GH 34951 + cat = Categorical([0, 0, 1, 1]) + val = [0, 1, 1, 0] + df = DataFrame({"a": cat, "b": cat, "c": val}) + + cat2 = Categorical([0, 1]) + idx = MultiIndex.from_product([cat2, cat2], names=["a", "b"]) + expected_dict = { + "first": Series([0, np.nan, np.nan, 1], idx, name="c"), + "last": Series([1, np.nan, np.nan, 0], idx, name="c"), + } + + expected = expected_dict[func].to_frame() + if observed: + expected = expected.dropna().astype(np.int64) + + df_grp = df.groupby(["a", "b"], observed=observed) + result = getattr(df_grp, func)() + tm.assert_frame_equal(result, expected) + + +def test_groupby_categorical_indices_unused_categories(): + # GH#38642 + df = DataFrame( + { + "key": Categorical(["b", "b", "a"], categories=["a", "b", "c"]), + "col": range(3), + } + ) + grouped = df.groupby("key", sort=False, observed=False) + result = grouped.indices + expected = { + "b": np.array([0, 1], dtype="intp"), + "a": np.array([2], dtype="intp"), + "c": np.array([], dtype="intp"), + } + assert result.keys() == expected.keys() + for key in result.keys(): + tm.assert_numpy_array_equal(result[key], expected[key]) + + +@pytest.mark.parametrize("func", ["first", "last"]) +def test_groupby_last_first_preserve_categoricaldtype(func): + # GH#33090 + df = DataFrame({"a": [1, 2, 3]}) + df["b"] = df["a"].astype("category") + result = getattr(df.groupby("a")["b"], func)() + expected = Series( + Categorical([1, 2, 3]), name="b", index=Index([1, 2, 3], name="a") + ) + tm.assert_series_equal(expected, result) + + +def test_groupby_categorical_observed_nunique(): + # GH#45128 + df = DataFrame({"a": [1, 2], "b": [1, 2], "c": [10, 11]}) + df = df.astype(dtype={"a": "category", "b": "category"}) + result = df.groupby(["a", "b"], observed=True).nunique()["c"] + expected = Series( + [1, 1], + index=MultiIndex.from_arrays( + [CategoricalIndex([1, 2], name="a"), CategoricalIndex([1, 2], name="b")] + ), + name="c", + ) + tm.assert_series_equal(result, expected) + + +def test_groupby_categorical_aggregate_functions(): + # GH#37275 + dtype = pd.CategoricalDtype(categories=["small", "big"], ordered=True) + df = DataFrame( + [[1, "small"], [1, "big"], [2, "small"]], columns=["grp", "description"] + ).astype({"description": dtype}) + + result = df.groupby("grp")["description"].max() + expected = Series( + ["big", "small"], + index=Index([1, 2], name="grp"), + name="description", + dtype=pd.CategoricalDtype(categories=["small", "big"], ordered=True), + ) + + tm.assert_series_equal(result, expected) + + +def test_groupby_categorical_dropna(observed, dropna): + # GH#48645 - dropna should have no impact on the result when there are no NA values + cat = Categorical([1, 2], categories=[1, 2, 3]) + df = DataFrame({"x": Categorical([1, 2], categories=[1, 2, 3]), "y": [3, 4]}) + gb = df.groupby("x", observed=observed, dropna=dropna) + result = gb.sum() + + if observed: + expected = DataFrame({"y": [3, 4]}, index=cat) + else: + index = CategoricalIndex([1, 2, 3], [1, 2, 3]) + expected = DataFrame({"y": [3, 4, 0]}, index=index) + expected.index.name = "x" + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("index_kind", ["range", "single", "multi"]) +@pytest.mark.parametrize("ordered", [True, False]) +def test_category_order_reducer( + request, as_index, sort, observed, reduction_func, index_kind, ordered +): + # GH#48749 + if reduction_func == "corrwith" and not as_index: + msg = "GH#49950 - corrwith with as_index=False may not have grouping column" + request.applymarker(pytest.mark.xfail(reason=msg)) + elif index_kind != "range" and not as_index: + pytest.skip(reason="Result doesn't have categories, nothing to test") + df = DataFrame( + { + "a": Categorical([2, 1, 2, 3], categories=[1, 4, 3, 2], ordered=ordered), + "b": range(4), + } + ) + if index_kind == "range": + keys = ["a"] + elif index_kind == "single": + keys = ["a"] + df = df.set_index(keys) + elif index_kind == "multi": + keys = ["a", "a2"] + df["a2"] = df["a"] + df = df.set_index(keys) + args = get_groupby_method_args(reduction_func, df) + gb = df.groupby(keys, as_index=as_index, sort=sort, observed=observed) + + if not observed and reduction_func in ["idxmin", "idxmax"]: + # idxmin and idxmax are designed to fail on empty inputs + with pytest.raises( + ValueError, match="empty group due to unobserved categories" + ): + getattr(gb, reduction_func)(*args) + return + + op_result = getattr(gb, reduction_func)(*args) + if as_index: + result = op_result.index.get_level_values("a").categories + else: + result = op_result["a"].cat.categories + expected = Index([1, 4, 3, 2]) + tm.assert_index_equal(result, expected) + + if index_kind == "multi": + result = op_result.index.get_level_values("a2").categories + tm.assert_index_equal(result, expected) + + +@pytest.mark.parametrize("index_kind", ["single", "multi"]) +@pytest.mark.parametrize("ordered", [True, False]) +def test_category_order_transformer( + as_index, sort, observed, transformation_func, index_kind, ordered +): + # GH#48749 + df = DataFrame( + { + "a": Categorical([2, 1, 2, 3], categories=[1, 4, 3, 2], ordered=ordered), + "b": range(4), + } + ) + if index_kind == "single": + keys = ["a"] + df = df.set_index(keys) + elif index_kind == "multi": + keys = ["a", "a2"] + df["a2"] = df["a"] + df = df.set_index(keys) + args = get_groupby_method_args(transformation_func, df) + gb = df.groupby(keys, as_index=as_index, sort=sort, observed=observed) + warn = FutureWarning if transformation_func == "fillna" else None + msg = "DataFrameGroupBy.fillna is deprecated" + with tm.assert_produces_warning(warn, match=msg): + op_result = getattr(gb, transformation_func)(*args) + result = op_result.index.get_level_values("a").categories + expected = Index([1, 4, 3, 2]) + tm.assert_index_equal(result, expected) + + if index_kind == "multi": + result = op_result.index.get_level_values("a2").categories + tm.assert_index_equal(result, expected) + + +@pytest.mark.parametrize("index_kind", ["range", "single", "multi"]) +@pytest.mark.parametrize("method", ["head", "tail"]) +@pytest.mark.parametrize("ordered", [True, False]) +def test_category_order_head_tail( + as_index, sort, observed, method, index_kind, ordered +): + # GH#48749 + df = DataFrame( + { + "a": Categorical([2, 1, 2, 3], categories=[1, 4, 3, 2], ordered=ordered), + "b": range(4), + } + ) + if index_kind == "range": + keys = ["a"] + elif index_kind == "single": + keys = ["a"] + df = df.set_index(keys) + elif index_kind == "multi": + keys = ["a", "a2"] + df["a2"] = df["a"] + df = df.set_index(keys) + gb = df.groupby(keys, as_index=as_index, sort=sort, observed=observed) + op_result = getattr(gb, method)() + if index_kind == "range": + result = op_result["a"].cat.categories + else: + result = op_result.index.get_level_values("a").categories + expected = Index([1, 4, 3, 2]) + tm.assert_index_equal(result, expected) + + if index_kind == "multi": + result = op_result.index.get_level_values("a2").categories + tm.assert_index_equal(result, expected) + + +@pytest.mark.parametrize("index_kind", ["range", "single", "multi"]) +@pytest.mark.parametrize("method", ["apply", "agg", "transform"]) +@pytest.mark.parametrize("ordered", [True, False]) +def test_category_order_apply(as_index, sort, observed, method, index_kind, ordered): + # GH#48749 + if (method == "transform" and index_kind == "range") or ( + not as_index and index_kind != "range" + ): + pytest.skip("No categories in result, nothing to test") + df = DataFrame( + { + "a": Categorical([2, 1, 2, 3], categories=[1, 4, 3, 2], ordered=ordered), + "b": range(4), + } + ) + if index_kind == "range": + keys = ["a"] + elif index_kind == "single": + keys = ["a"] + df = df.set_index(keys) + elif index_kind == "multi": + keys = ["a", "a2"] + df["a2"] = df["a"] + df = df.set_index(keys) + gb = df.groupby(keys, as_index=as_index, sort=sort, observed=observed) + warn = DeprecationWarning if method == "apply" and index_kind == "range" else None + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(warn, match=msg): + op_result = getattr(gb, method)(lambda x: x.sum(numeric_only=True)) + if (method == "transform" or not as_index) and index_kind == "range": + result = op_result["a"].cat.categories + else: + result = op_result.index.get_level_values("a").categories + expected = Index([1, 4, 3, 2]) + tm.assert_index_equal(result, expected) + + if index_kind == "multi": + result = op_result.index.get_level_values("a2").categories + tm.assert_index_equal(result, expected) + + +@pytest.mark.parametrize("index_kind", ["range", "single", "multi"]) +def test_many_categories(as_index, sort, index_kind, ordered): + # GH#48749 - Test when the grouper has many categories + if index_kind != "range" and not as_index: + pytest.skip(reason="Result doesn't have categories, nothing to test") + categories = np.arange(9999, -1, -1) + grouper = Categorical([2, 1, 2, 3], categories=categories, ordered=ordered) + df = DataFrame({"a": grouper, "b": range(4)}) + if index_kind == "range": + keys = ["a"] + elif index_kind == "single": + keys = ["a"] + df = df.set_index(keys) + elif index_kind == "multi": + keys = ["a", "a2"] + df["a2"] = df["a"] + df = df.set_index(keys) + gb = df.groupby(keys, as_index=as_index, sort=sort, observed=True) + result = gb.sum() + + # Test is setup so that data and index are the same values + data = [3, 2, 1] if sort else [2, 1, 3] + + index = CategoricalIndex( + data, categories=grouper.categories, ordered=ordered, name="a" + ) + if as_index: + expected = DataFrame({"b": data}) + if index_kind == "multi": + expected.index = MultiIndex.from_frame(DataFrame({"a": index, "a2": index})) + else: + expected.index = index + elif index_kind == "multi": + expected = DataFrame({"a": Series(index), "a2": Series(index), "b": data}) + else: + expected = DataFrame({"a": Series(index), "b": data}) + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("cat_columns", ["a", "b", ["a", "b"]]) +@pytest.mark.parametrize("keys", ["a", "b", ["a", "b"]]) +def test_groupby_default_depr(cat_columns, keys): + # GH#43999 + df = DataFrame({"a": [1, 1, 2, 3], "b": [4, 5, 6, 7]}) + df[cat_columns] = df[cat_columns].astype("category") + msg = "The default of observed=False is deprecated" + klass = FutureWarning if set(cat_columns) & set(keys) else None + with tm.assert_produces_warning(klass, match=msg): + df.groupby(keys) + + +@pytest.mark.parametrize("test_series", [True, False]) +@pytest.mark.parametrize("keys", [["a1"], ["a1", "a2"]]) +def test_agg_list(request, as_index, observed, reduction_func, test_series, keys): + # GH#52760 + if test_series and reduction_func == "corrwith": + assert not hasattr(SeriesGroupBy, "corrwith") + pytest.skip("corrwith not implemented for SeriesGroupBy") + elif reduction_func == "corrwith": + msg = "GH#32293: attempts to call SeriesGroupBy.corrwith" + request.applymarker(pytest.mark.xfail(reason=msg)) + elif ( + reduction_func == "nunique" + and not test_series + and len(keys) != 1 + and not observed + and not as_index + ): + msg = "GH#52848 - raises a ValueError" + request.applymarker(pytest.mark.xfail(reason=msg)) + + df = DataFrame({"a1": [0, 0, 1], "a2": [2, 3, 3], "b": [4, 5, 6]}) + df = df.astype({"a1": "category", "a2": "category"}) + if "a2" not in keys: + df = df.drop(columns="a2") + gb = df.groupby(by=keys, as_index=as_index, observed=observed) + if test_series: + gb = gb["b"] + args = get_groupby_method_args(reduction_func, df) + + if not observed and reduction_func in ["idxmin", "idxmax"] and keys == ["a1", "a2"]: + with pytest.raises( + ValueError, match="empty group due to unobserved categories" + ): + gb.agg([reduction_func], *args) + return + + result = gb.agg([reduction_func], *args) + expected = getattr(gb, reduction_func)(*args) + + if as_index and (test_series or reduction_func == "size"): + expected = expected.to_frame(reduction_func) + if not test_series: + expected.columns = MultiIndex.from_tuples( + [(ind, "") for ind in expected.columns[:-1]] + [("b", reduction_func)] + ) + elif not as_index: + expected.columns = keys + [reduction_func] + + tm.assert_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_counting.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_counting.py new file mode 100644 index 0000000000000000000000000000000000000000..2622895f9f8d21a9e568b0954681f4a7169659c1 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_counting.py @@ -0,0 +1,394 @@ +from itertools import product +from string import ascii_lowercase + +import numpy as np +import pytest + +from pandas import ( + DataFrame, + Index, + MultiIndex, + Period, + Series, + Timedelta, + Timestamp, + date_range, +) +import pandas._testing as tm + + +class TestCounting: + def test_cumcount(self): + df = DataFrame([["a"], ["a"], ["a"], ["b"], ["a"]], columns=["A"]) + g = df.groupby("A") + sg = g.A + + expected = Series([0, 1, 2, 0, 3]) + + tm.assert_series_equal(expected, g.cumcount()) + tm.assert_series_equal(expected, sg.cumcount()) + + def test_cumcount_empty(self): + ge = DataFrame().groupby(level=0) + se = Series(dtype=object).groupby(level=0) + + # edge case, as this is usually considered float + e = Series(dtype="int64") + + tm.assert_series_equal(e, ge.cumcount()) + tm.assert_series_equal(e, se.cumcount()) + + def test_cumcount_dupe_index(self): + df = DataFrame( + [["a"], ["a"], ["a"], ["b"], ["a"]], columns=["A"], index=[0] * 5 + ) + g = df.groupby("A") + sg = g.A + + expected = Series([0, 1, 2, 0, 3], index=[0] * 5) + + tm.assert_series_equal(expected, g.cumcount()) + tm.assert_series_equal(expected, sg.cumcount()) + + def test_cumcount_mi(self): + mi = MultiIndex.from_tuples([[0, 1], [1, 2], [2, 2], [2, 2], [1, 0]]) + df = DataFrame([["a"], ["a"], ["a"], ["b"], ["a"]], columns=["A"], index=mi) + g = df.groupby("A") + sg = g.A + + expected = Series([0, 1, 2, 0, 3], index=mi) + + tm.assert_series_equal(expected, g.cumcount()) + tm.assert_series_equal(expected, sg.cumcount()) + + def test_cumcount_groupby_not_col(self): + df = DataFrame( + [["a"], ["a"], ["a"], ["b"], ["a"]], columns=["A"], index=[0] * 5 + ) + g = df.groupby([0, 0, 0, 1, 0]) + sg = g.A + + expected = Series([0, 1, 2, 0, 3], index=[0] * 5) + + tm.assert_series_equal(expected, g.cumcount()) + tm.assert_series_equal(expected, sg.cumcount()) + + def test_ngroup(self): + df = DataFrame({"A": list("aaaba")}) + g = df.groupby("A") + sg = g.A + + expected = Series([0, 0, 0, 1, 0]) + + tm.assert_series_equal(expected, g.ngroup()) + tm.assert_series_equal(expected, sg.ngroup()) + + def test_ngroup_distinct(self): + df = DataFrame({"A": list("abcde")}) + g = df.groupby("A") + sg = g.A + + expected = Series(range(5), dtype="int64") + + tm.assert_series_equal(expected, g.ngroup()) + tm.assert_series_equal(expected, sg.ngroup()) + + def test_ngroup_one_group(self): + df = DataFrame({"A": [0] * 5}) + g = df.groupby("A") + sg = g.A + + expected = Series([0] * 5) + + tm.assert_series_equal(expected, g.ngroup()) + tm.assert_series_equal(expected, sg.ngroup()) + + def test_ngroup_empty(self): + ge = DataFrame().groupby(level=0) + se = Series(dtype=object).groupby(level=0) + + # edge case, as this is usually considered float + e = Series(dtype="int64") + + tm.assert_series_equal(e, ge.ngroup()) + tm.assert_series_equal(e, se.ngroup()) + + def test_ngroup_series_matches_frame(self): + df = DataFrame({"A": list("aaaba")}) + s = Series(list("aaaba")) + + tm.assert_series_equal(df.groupby(s).ngroup(), s.groupby(s).ngroup()) + + def test_ngroup_dupe_index(self): + df = DataFrame({"A": list("aaaba")}, index=[0] * 5) + g = df.groupby("A") + sg = g.A + + expected = Series([0, 0, 0, 1, 0], index=[0] * 5) + + tm.assert_series_equal(expected, g.ngroup()) + tm.assert_series_equal(expected, sg.ngroup()) + + def test_ngroup_mi(self): + mi = MultiIndex.from_tuples([[0, 1], [1, 2], [2, 2], [2, 2], [1, 0]]) + df = DataFrame({"A": list("aaaba")}, index=mi) + g = df.groupby("A") + sg = g.A + expected = Series([0, 0, 0, 1, 0], index=mi) + + tm.assert_series_equal(expected, g.ngroup()) + tm.assert_series_equal(expected, sg.ngroup()) + + def test_ngroup_groupby_not_col(self): + df = DataFrame({"A": list("aaaba")}, index=[0] * 5) + g = df.groupby([0, 0, 0, 1, 0]) + sg = g.A + + expected = Series([0, 0, 0, 1, 0], index=[0] * 5) + + tm.assert_series_equal(expected, g.ngroup()) + tm.assert_series_equal(expected, sg.ngroup()) + + def test_ngroup_descending(self): + df = DataFrame(["a", "a", "b", "a", "b"], columns=["A"]) + g = df.groupby(["A"]) + + ascending = Series([0, 0, 1, 0, 1]) + descending = Series([1, 1, 0, 1, 0]) + + tm.assert_series_equal(descending, (g.ngroups - 1) - ascending) + tm.assert_series_equal(ascending, g.ngroup(ascending=True)) + tm.assert_series_equal(descending, g.ngroup(ascending=False)) + + def test_ngroup_matches_cumcount(self): + # verify one manually-worked out case works + df = DataFrame( + [["a", "x"], ["a", "y"], ["b", "x"], ["a", "x"], ["b", "y"]], + columns=["A", "X"], + ) + g = df.groupby(["A", "X"]) + g_ngroup = g.ngroup() + g_cumcount = g.cumcount() + expected_ngroup = Series([0, 1, 2, 0, 3]) + expected_cumcount = Series([0, 0, 0, 1, 0]) + + tm.assert_series_equal(g_ngroup, expected_ngroup) + tm.assert_series_equal(g_cumcount, expected_cumcount) + + def test_ngroup_cumcount_pair(self): + # brute force comparison for all small series + for p in product(range(3), repeat=4): + df = DataFrame({"a": p}) + g = df.groupby(["a"]) + + order = sorted(set(p)) + ngroupd = [order.index(val) for val in p] + cumcounted = [p[:i].count(val) for i, val in enumerate(p)] + + tm.assert_series_equal(g.ngroup(), Series(ngroupd)) + tm.assert_series_equal(g.cumcount(), Series(cumcounted)) + + def test_ngroup_respects_groupby_order(self, sort): + df = DataFrame({"a": np.random.default_rng(2).choice(list("abcdef"), 100)}) + g = df.groupby("a", sort=sort) + df["group_id"] = -1 + df["group_index"] = -1 + + for i, (_, group) in enumerate(g): + df.loc[group.index, "group_id"] = i + for j, ind in enumerate(group.index): + df.loc[ind, "group_index"] = j + + tm.assert_series_equal(Series(df["group_id"].values), g.ngroup()) + tm.assert_series_equal(Series(df["group_index"].values), g.cumcount()) + + @pytest.mark.parametrize( + "datetimelike", + [ + [Timestamp(f"2016-05-{i:02d} 20:09:25+00:00") for i in range(1, 4)], + [Timestamp(f"2016-05-{i:02d} 20:09:25") for i in range(1, 4)], + [Timestamp(f"2016-05-{i:02d} 20:09:25", tz="UTC") for i in range(1, 4)], + [Timedelta(x, unit="h") for x in range(1, 4)], + [Period(freq="2W", year=2017, month=x) for x in range(1, 4)], + ], + ) + def test_count_with_datetimelike(self, datetimelike): + # test for #13393, where DataframeGroupBy.count() fails + # when counting a datetimelike column. + + df = DataFrame({"x": ["a", "a", "b"], "y": datetimelike}) + res = df.groupby("x").count() + expected = DataFrame({"y": [2, 1]}, index=["a", "b"]) + expected.index.name = "x" + tm.assert_frame_equal(expected, res) + + def test_count_with_only_nans_in_first_group(self): + # GH21956 + df = DataFrame({"A": [np.nan, np.nan], "B": ["a", "b"], "C": [1, 2]}) + result = df.groupby(["A", "B"]).C.count() + mi = MultiIndex(levels=[[], ["a", "b"]], codes=[[], []], names=["A", "B"]) + expected = Series([], index=mi, dtype=np.int64, name="C") + tm.assert_series_equal(result, expected, check_index_type=False) + + def test_count_groupby_column_with_nan_in_groupby_column(self): + # https://github.com/pandas-dev/pandas/issues/32841 + df = DataFrame({"A": [1, 1, 1, 1, 1], "B": [5, 4, np.nan, 3, 0]}) + res = df.groupby(["B"]).count() + expected = DataFrame( + index=Index([0.0, 3.0, 4.0, 5.0], name="B"), data={"A": [1, 1, 1, 1]} + ) + tm.assert_frame_equal(expected, res) + + def test_groupby_count_dateparseerror(self): + dr = date_range(start="1/1/2012", freq="5min", periods=10) + + # BAD Example, datetimes first + ser = Series(np.arange(10), index=[dr, np.arange(10)]) + grouped = ser.groupby(lambda x: x[1] % 2 == 0) + result = grouped.count() + + ser = Series(np.arange(10), index=[np.arange(10), dr]) + grouped = ser.groupby(lambda x: x[0] % 2 == 0) + expected = grouped.count() + + tm.assert_series_equal(result, expected) + + +def test_groupby_timedelta_cython_count(): + df = DataFrame( + {"g": list("ab" * 2), "delta": np.arange(4).astype("timedelta64[ns]")} + ) + expected = Series([2, 2], index=Index(["a", "b"], name="g"), name="delta") + result = df.groupby("g").delta.count() + tm.assert_series_equal(expected, result) + + +def test_count(): + n = 1 << 15 + dr = date_range("2015-08-30", periods=n // 10, freq="min") + + df = DataFrame( + { + "1st": np.random.default_rng(2).choice(list(ascii_lowercase), n), + "2nd": np.random.default_rng(2).integers(0, 5, n), + "3rd": np.random.default_rng(2).standard_normal(n).round(3), + "4th": np.random.default_rng(2).integers(-10, 10, n), + "5th": np.random.default_rng(2).choice(dr, n), + "6th": np.random.default_rng(2).standard_normal(n).round(3), + "7th": np.random.default_rng(2).standard_normal(n).round(3), + "8th": np.random.default_rng(2).choice(dr, n) + - np.random.default_rng(2).choice(dr, 1), + "9th": np.random.default_rng(2).choice(list(ascii_lowercase), n), + } + ) + + for col in df.columns.drop(["1st", "2nd", "4th"]): + df.loc[np.random.default_rng(2).choice(n, n // 10), col] = np.nan + + df["9th"] = df["9th"].astype("category") + + for key in ["1st", "2nd", ["1st", "2nd"]]: + left = df.groupby(key).count() + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + right = df.groupby(key).apply(DataFrame.count).drop(key, axis=1) + tm.assert_frame_equal(left, right) + + +def test_count_non_nulls(): + # GH#5610 + # count counts non-nulls + df = DataFrame( + [[1, 2, "foo"], [1, np.nan, "bar"], [3, np.nan, np.nan]], + columns=["A", "B", "C"], + ) + + count_as = df.groupby("A").count() + count_not_as = df.groupby("A", as_index=False).count() + + expected = DataFrame([[1, 2], [0, 0]], columns=["B", "C"], index=[1, 3]) + expected.index.name = "A" + tm.assert_frame_equal(count_not_as, expected.reset_index()) + tm.assert_frame_equal(count_as, expected) + + count_B = df.groupby("A")["B"].count() + tm.assert_series_equal(count_B, expected["B"]) + + +def test_count_object(): + df = DataFrame({"a": ["a"] * 3 + ["b"] * 3, "c": [2] * 3 + [3] * 3}) + result = df.groupby("c").a.count() + expected = Series([3, 3], index=Index([2, 3], name="c"), name="a") + tm.assert_series_equal(result, expected) + + df = DataFrame({"a": ["a", np.nan, np.nan] + ["b"] * 3, "c": [2] * 3 + [3] * 3}) + result = df.groupby("c").a.count() + expected = Series([1, 3], index=Index([2, 3], name="c"), name="a") + tm.assert_series_equal(result, expected) + + +def test_count_cross_type(): + # GH8169 + # Set float64 dtype to avoid upcast when setting nan below + vals = np.hstack( + ( + np.random.default_rng(2).integers(0, 5, (100, 2)), + np.random.default_rng(2).integers(0, 2, (100, 2)), + ) + ).astype("float64") + + df = DataFrame(vals, columns=["a", "b", "c", "d"]) + df[df == 2] = np.nan + expected = df.groupby(["c", "d"]).count() + + for t in ["float32", "object"]: + df["a"] = df["a"].astype(t) + df["b"] = df["b"].astype(t) + result = df.groupby(["c", "d"]).count() + tm.assert_frame_equal(result, expected) + + +def test_lower_int_prec_count(): + df = DataFrame( + { + "a": np.array([0, 1, 2, 100], np.int8), + "b": np.array([1, 2, 3, 6], np.uint32), + "c": np.array([4, 5, 6, 8], np.int16), + "grp": list("ab" * 2), + } + ) + result = df.groupby("grp").count() + expected = DataFrame( + {"a": [2, 2], "b": [2, 2], "c": [2, 2]}, index=Index(list("ab"), name="grp") + ) + tm.assert_frame_equal(result, expected) + + +def test_count_uses_size_on_exception(): + class RaisingObjectException(Exception): + pass + + class RaisingObject: + def __init__(self, msg="I will raise inside Cython") -> None: + super().__init__() + self.msg = msg + + def __eq__(self, other): + # gets called in Cython to check that raising calls the method + raise RaisingObjectException(self.msg) + + df = DataFrame({"a": [RaisingObject() for _ in range(4)], "grp": list("ab" * 2)}) + result = df.groupby("grp").count() + expected = DataFrame({"a": [2, 2]}, index=Index(list("ab"), name="grp")) + tm.assert_frame_equal(result, expected) + + +def test_count_arrow_string_array(any_string_dtype): + # GH#54751 + pytest.importorskip("pyarrow") + df = DataFrame( + {"a": [1, 2, 3], "b": Series(["a", "b", "a"], dtype=any_string_dtype)} + ) + result = df.groupby("a").count() + expected = DataFrame({"b": 1}, index=Index([1, 2, 3], name="a")) + tm.assert_frame_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_cumulative.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_cumulative.py new file mode 100644 index 0000000000000000000000000000000000000000..1bdbef6d50c4c23db86060493dcd4f6df4bc4728 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_cumulative.py @@ -0,0 +1,319 @@ +import numpy as np +import pytest + +from pandas.errors import UnsupportedFunctionCall +import pandas.util._test_decorators as td + +import pandas as pd +from pandas import ( + DataFrame, + Series, +) +import pandas._testing as tm + + +@pytest.fixture( + params=[np.int32, np.int64, np.float32, np.float64, "Int64", "Float64"], + ids=["np.int32", "np.int64", "np.float32", "np.float64", "Int64", "Float64"], +) +def dtypes_for_minmax(request): + """ + Fixture of dtypes with min and max values used for testing + cummin and cummax + """ + dtype = request.param + + np_type = dtype + if dtype == "Int64": + np_type = np.int64 + elif dtype == "Float64": + np_type = np.float64 + + min_val = ( + np.iinfo(np_type).min + if np.dtype(np_type).kind == "i" + else np.finfo(np_type).min + ) + max_val = ( + np.iinfo(np_type).max + if np.dtype(np_type).kind == "i" + else np.finfo(np_type).max + ) + + return (dtype, min_val, max_val) + + +def test_groupby_cumprod(): + # GH 4095 + df = DataFrame({"key": ["b"] * 10, "value": 2}) + + actual = df.groupby("key")["value"].cumprod() + expected = df.groupby("key", group_keys=False)["value"].apply(lambda x: x.cumprod()) + expected.name = "value" + tm.assert_series_equal(actual, expected) + + df = DataFrame({"key": ["b"] * 100, "value": 2}) + df["value"] = df["value"].astype(float) + actual = df.groupby("key")["value"].cumprod() + expected = df.groupby("key", group_keys=False)["value"].apply(lambda x: x.cumprod()) + expected.name = "value" + tm.assert_series_equal(actual, expected) + + +@pytest.mark.skip_ubsan +def test_groupby_cumprod_overflow(): + # GH#37493 if we overflow we return garbage consistent with numpy + df = DataFrame({"key": ["b"] * 4, "value": 100_000}) + actual = df.groupby("key")["value"].cumprod() + expected = Series( + [100_000, 10_000_000_000, 1_000_000_000_000_000, 7766279631452241920], + name="value", + ) + tm.assert_series_equal(actual, expected) + + numpy_result = df.groupby("key", group_keys=False)["value"].apply( + lambda x: x.cumprod() + ) + numpy_result.name = "value" + tm.assert_series_equal(actual, numpy_result) + + +def test_groupby_cumprod_nan_influences_other_columns(): + # GH#48064 + df = DataFrame( + { + "a": 1, + "b": [1, np.nan, 2], + "c": [1, 2, 3.0], + } + ) + result = df.groupby("a").cumprod(numeric_only=True, skipna=False) + expected = DataFrame({"b": [1, np.nan, np.nan], "c": [1, 2, 6.0]}) + tm.assert_frame_equal(result, expected) + + +def test_cummin(dtypes_for_minmax): + dtype = dtypes_for_minmax[0] + min_val = dtypes_for_minmax[1] + + # GH 15048 + base_df = DataFrame({"A": [1, 1, 1, 1, 2, 2, 2, 2], "B": [3, 4, 3, 2, 2, 3, 2, 1]}) + expected_mins = [3, 3, 3, 2, 2, 2, 2, 1] + + df = base_df.astype(dtype) + + expected = DataFrame({"B": expected_mins}).astype(dtype) + result = df.groupby("A").cummin() + tm.assert_frame_equal(result, expected) + result = df.groupby("A", group_keys=False).B.apply(lambda x: x.cummin()).to_frame() + tm.assert_frame_equal(result, expected) + + # Test w/ min value for dtype + df.loc[[2, 6], "B"] = min_val + df.loc[[1, 5], "B"] = min_val + 1 + expected.loc[[2, 3, 6, 7], "B"] = min_val + expected.loc[[1, 5], "B"] = min_val + 1 # should not be rounded to min_val + result = df.groupby("A").cummin() + tm.assert_frame_equal(result, expected, check_exact=True) + expected = ( + df.groupby("A", group_keys=False).B.apply(lambda x: x.cummin()).to_frame() + ) + tm.assert_frame_equal(result, expected, check_exact=True) + + # Test nan in some values + # Explicit cast to float to avoid implicit cast when setting nan + base_df = base_df.astype({"B": "float"}) + base_df.loc[[0, 2, 4, 6], "B"] = np.nan + expected = DataFrame({"B": [np.nan, 4, np.nan, 2, np.nan, 3, np.nan, 1]}) + result = base_df.groupby("A").cummin() + tm.assert_frame_equal(result, expected) + expected = ( + base_df.groupby("A", group_keys=False).B.apply(lambda x: x.cummin()).to_frame() + ) + tm.assert_frame_equal(result, expected) + + # GH 15561 + df = DataFrame({"a": [1], "b": pd.to_datetime(["2001"])}) + expected = Series(pd.to_datetime("2001"), index=[0], name="b") + + result = df.groupby("a")["b"].cummin() + tm.assert_series_equal(expected, result) + + # GH 15635 + df = DataFrame({"a": [1, 2, 1], "b": [1, 2, 2]}) + result = df.groupby("a").b.cummin() + expected = Series([1, 2, 1], name="b") + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("method", ["cummin", "cummax"]) +@pytest.mark.parametrize("dtype", ["UInt64", "Int64", "Float64", "float", "boolean"]) +def test_cummin_max_all_nan_column(method, dtype): + base_df = DataFrame({"A": [1, 1, 1, 1, 2, 2, 2, 2], "B": [np.nan] * 8}) + base_df["B"] = base_df["B"].astype(dtype) + grouped = base_df.groupby("A") + + expected = DataFrame({"B": [np.nan] * 8}, dtype=dtype) + result = getattr(grouped, method)() + tm.assert_frame_equal(expected, result) + + result = getattr(grouped["B"], method)().to_frame() + tm.assert_frame_equal(expected, result) + + +def test_cummax(dtypes_for_minmax): + dtype = dtypes_for_minmax[0] + max_val = dtypes_for_minmax[2] + + # GH 15048 + base_df = DataFrame({"A": [1, 1, 1, 1, 2, 2, 2, 2], "B": [3, 4, 3, 2, 2, 3, 2, 1]}) + expected_maxs = [3, 4, 4, 4, 2, 3, 3, 3] + + df = base_df.astype(dtype) + + expected = DataFrame({"B": expected_maxs}).astype(dtype) + result = df.groupby("A").cummax() + tm.assert_frame_equal(result, expected) + result = df.groupby("A", group_keys=False).B.apply(lambda x: x.cummax()).to_frame() + tm.assert_frame_equal(result, expected) + + # Test w/ max value for dtype + df.loc[[2, 6], "B"] = max_val + expected.loc[[2, 3, 6, 7], "B"] = max_val + result = df.groupby("A").cummax() + tm.assert_frame_equal(result, expected) + expected = ( + df.groupby("A", group_keys=False).B.apply(lambda x: x.cummax()).to_frame() + ) + tm.assert_frame_equal(result, expected) + + # Test nan in some values + # Explicit cast to float to avoid implicit cast when setting nan + base_df = base_df.astype({"B": "float"}) + base_df.loc[[0, 2, 4, 6], "B"] = np.nan + expected = DataFrame({"B": [np.nan, 4, np.nan, 4, np.nan, 3, np.nan, 3]}) + result = base_df.groupby("A").cummax() + tm.assert_frame_equal(result, expected) + expected = ( + base_df.groupby("A", group_keys=False).B.apply(lambda x: x.cummax()).to_frame() + ) + tm.assert_frame_equal(result, expected) + + # GH 15561 + df = DataFrame({"a": [1], "b": pd.to_datetime(["2001"])}) + expected = Series(pd.to_datetime("2001"), index=[0], name="b") + + result = df.groupby("a")["b"].cummax() + tm.assert_series_equal(expected, result) + + # GH 15635 + df = DataFrame({"a": [1, 2, 1], "b": [2, 1, 1]}) + result = df.groupby("a").b.cummax() + expected = Series([2, 1, 2], name="b") + tm.assert_series_equal(result, expected) + + +def test_cummax_i8_at_implementation_bound(): + # the minimum value used to be treated as NPY_NAT+1 instead of NPY_NAT + # for int64 dtype GH#46382 + ser = Series([pd.NaT._value + n for n in range(5)]) + df = DataFrame({"A": 1, "B": ser, "C": ser._values.view("M8[ns]")}) + gb = df.groupby("A") + + res = gb.cummax() + exp = df[["B", "C"]] + tm.assert_frame_equal(res, exp) + + +@pytest.mark.parametrize("method", ["cummin", "cummax"]) +@pytest.mark.parametrize("dtype", ["float", "Int64", "Float64"]) +@pytest.mark.parametrize( + "groups,expected_data", + [ + ([1, 1, 1], [1, None, None]), + ([1, 2, 3], [1, None, 2]), + ([1, 3, 3], [1, None, None]), + ], +) +def test_cummin_max_skipna(method, dtype, groups, expected_data): + # GH-34047 + df = DataFrame({"a": Series([1, None, 2], dtype=dtype)}) + orig = df.copy() + gb = df.groupby(groups)["a"] + + result = getattr(gb, method)(skipna=False) + expected = Series(expected_data, dtype=dtype, name="a") + + # check we didn't accidentally alter df + tm.assert_frame_equal(df, orig) + + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("method", ["cummin", "cummax"]) +def test_cummin_max_skipna_multiple_cols(method): + # Ensure missing value in "a" doesn't cause "b" to be nan-filled + df = DataFrame({"a": [np.nan, 2.0, 2.0], "b": [2.0, 2.0, 2.0]}) + gb = df.groupby([1, 1, 1])[["a", "b"]] + + result = getattr(gb, method)(skipna=False) + expected = DataFrame({"a": [np.nan, np.nan, np.nan], "b": [2.0, 2.0, 2.0]}) + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("func", ["cumprod", "cumsum"]) +def test_numpy_compat(func): + # see gh-12811 + df = DataFrame({"A": [1, 2, 1], "B": [1, 2, 3]}) + g = df.groupby("A") + + msg = "numpy operations are not valid with groupby" + + with pytest.raises(UnsupportedFunctionCall, match=msg): + getattr(g, func)(1, 2, 3) + with pytest.raises(UnsupportedFunctionCall, match=msg): + getattr(g, func)(foo=1) + + +@td.skip_if_32bit +@pytest.mark.parametrize("method", ["cummin", "cummax"]) +@pytest.mark.parametrize( + "dtype,val", [("UInt64", np.iinfo("uint64").max), ("Int64", 2**53 + 1)] +) +def test_nullable_int_not_cast_as_float(method, dtype, val): + data = [val, pd.NA] + df = DataFrame({"grp": [1, 1], "b": data}, dtype=dtype) + grouped = df.groupby("grp") + + result = grouped.transform(method) + expected = DataFrame({"b": data}, dtype=dtype) + + tm.assert_frame_equal(result, expected) + + +def test_cython_api2(): + # this takes the fast apply path + + # cumsum (GH5614) + df = DataFrame([[1, 2, np.nan], [1, np.nan, 9], [3, 4, 9]], columns=["A", "B", "C"]) + expected = DataFrame([[2, np.nan], [np.nan, 9], [4, 9]], columns=["B", "C"]) + result = df.groupby("A").cumsum() + tm.assert_frame_equal(result, expected) + + # GH 5755 - cumsum is a transformer and should ignore as_index + result = df.groupby("A", as_index=False).cumsum() + tm.assert_frame_equal(result, expected) + + # GH 13994 + msg = "DataFrameGroupBy.cumsum with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = df.groupby("A").cumsum(axis=1) + expected = df.cumsum(axis=1) + tm.assert_frame_equal(result, expected) + + msg = "DataFrameGroupBy.cumprod with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = df.groupby("A").cumprod(axis=1) + expected = df.cumprod(axis=1) + tm.assert_frame_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_filters.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_filters.py new file mode 100644 index 0000000000000000000000000000000000000000..309c4b7b57e84f68e13ed974790c87c16244aae7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_filters.py @@ -0,0 +1,636 @@ +from string import ascii_lowercase + +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + DataFrame, + Series, + Timestamp, +) +import pandas._testing as tm + + +def test_filter_series(): + s = Series([1, 3, 20, 5, 22, 24, 7]) + expected_odd = Series([1, 3, 5, 7], index=[0, 1, 3, 6]) + expected_even = Series([20, 22, 24], index=[2, 4, 5]) + grouper = s.apply(lambda x: x % 2) + grouped = s.groupby(grouper) + tm.assert_series_equal(grouped.filter(lambda x: x.mean() < 10), expected_odd) + tm.assert_series_equal(grouped.filter(lambda x: x.mean() > 10), expected_even) + # Test dropna=False. + tm.assert_series_equal( + grouped.filter(lambda x: x.mean() < 10, dropna=False), + expected_odd.reindex(s.index), + ) + tm.assert_series_equal( + grouped.filter(lambda x: x.mean() > 10, dropna=False), + expected_even.reindex(s.index), + ) + + +def test_filter_single_column_df(): + df = DataFrame([1, 3, 20, 5, 22, 24, 7]) + expected_odd = DataFrame([1, 3, 5, 7], index=[0, 1, 3, 6]) + expected_even = DataFrame([20, 22, 24], index=[2, 4, 5]) + grouper = df[0].apply(lambda x: x % 2) + grouped = df.groupby(grouper) + tm.assert_frame_equal(grouped.filter(lambda x: x.mean() < 10), expected_odd) + tm.assert_frame_equal(grouped.filter(lambda x: x.mean() > 10), expected_even) + # Test dropna=False. + tm.assert_frame_equal( + grouped.filter(lambda x: x.mean() < 10, dropna=False), + expected_odd.reindex(df.index), + ) + tm.assert_frame_equal( + grouped.filter(lambda x: x.mean() > 10, dropna=False), + expected_even.reindex(df.index), + ) + + +def test_filter_multi_column_df(): + df = DataFrame({"A": [1, 12, 12, 1], "B": [1, 1, 1, 1]}) + grouper = df["A"].apply(lambda x: x % 2) + grouped = df.groupby(grouper) + expected = DataFrame({"A": [12, 12], "B": [1, 1]}, index=[1, 2]) + tm.assert_frame_equal( + grouped.filter(lambda x: x["A"].sum() - x["B"].sum() > 10), expected + ) + + +def test_filter_mixed_df(): + df = DataFrame({"A": [1, 12, 12, 1], "B": "a b c d".split()}) + grouper = df["A"].apply(lambda x: x % 2) + grouped = df.groupby(grouper) + expected = DataFrame({"A": [12, 12], "B": ["b", "c"]}, index=[1, 2]) + tm.assert_frame_equal(grouped.filter(lambda x: x["A"].sum() > 10), expected) + + +def test_filter_out_all_groups(): + s = Series([1, 3, 20, 5, 22, 24, 7]) + grouper = s.apply(lambda x: x % 2) + grouped = s.groupby(grouper) + tm.assert_series_equal(grouped.filter(lambda x: x.mean() > 1000), s[[]]) + df = DataFrame({"A": [1, 12, 12, 1], "B": "a b c d".split()}) + grouper = df["A"].apply(lambda x: x % 2) + grouped = df.groupby(grouper) + tm.assert_frame_equal(grouped.filter(lambda x: x["A"].sum() > 1000), df.loc[[]]) + + +def test_filter_out_no_groups(): + s = Series([1, 3, 20, 5, 22, 24, 7]) + grouper = s.apply(lambda x: x % 2) + grouped = s.groupby(grouper) + filtered = grouped.filter(lambda x: x.mean() > 0) + tm.assert_series_equal(filtered, s) + df = DataFrame({"A": [1, 12, 12, 1], "B": "a b c d".split()}) + grouper = df["A"].apply(lambda x: x % 2) + grouped = df.groupby(grouper) + filtered = grouped.filter(lambda x: x["A"].mean() > 0) + tm.assert_frame_equal(filtered, df) + + +def test_filter_out_all_groups_in_df(): + # GH12768 + df = DataFrame({"a": [1, 1, 2], "b": [1, 2, 0]}) + res = df.groupby("a") + res = res.filter(lambda x: x["b"].sum() > 5, dropna=False) + expected = DataFrame({"a": [np.nan] * 3, "b": [np.nan] * 3}) + tm.assert_frame_equal(expected, res) + + df = DataFrame({"a": [1, 1, 2], "b": [1, 2, 0]}) + res = df.groupby("a") + res = res.filter(lambda x: x["b"].sum() > 5, dropna=True) + expected = DataFrame({"a": [], "b": []}, dtype="int64") + tm.assert_frame_equal(expected, res) + + +def test_filter_condition_raises(): + def raise_if_sum_is_zero(x): + if x.sum() == 0: + raise ValueError + return x.sum() > 0 + + s = Series([-1, 0, 1, 2]) + grouper = s.apply(lambda x: x % 2) + grouped = s.groupby(grouper) + msg = "the filter must return a boolean result" + with pytest.raises(TypeError, match=msg): + grouped.filter(raise_if_sum_is_zero) + + +def test_filter_with_axis_in_groupby(): + # issue 11041 + index = pd.MultiIndex.from_product([range(10), [0, 1]]) + data = DataFrame(np.arange(100).reshape(-1, 20), columns=index, dtype="int64") + + msg = "DataFrame.groupby with axis=1" + with tm.assert_produces_warning(FutureWarning, match=msg): + gb = data.groupby(level=0, axis=1) + result = gb.filter(lambda x: x.iloc[0, 0] > 10) + expected = data.iloc[:, 12:20] + tm.assert_frame_equal(result, expected) + + +def test_filter_bad_shapes(): + df = DataFrame({"A": np.arange(8), "B": list("aabbbbcc"), "C": np.arange(8)}) + s = df["B"] + g_df = df.groupby("B") + g_s = s.groupby(s) + + f = lambda x: x + msg = "filter function returned a DataFrame, but expected a scalar bool" + with pytest.raises(TypeError, match=msg): + g_df.filter(f) + msg = "the filter must return a boolean result" + with pytest.raises(TypeError, match=msg): + g_s.filter(f) + + f = lambda x: x == 1 + msg = "filter function returned a DataFrame, but expected a scalar bool" + with pytest.raises(TypeError, match=msg): + g_df.filter(f) + msg = "the filter must return a boolean result" + with pytest.raises(TypeError, match=msg): + g_s.filter(f) + + f = lambda x: np.outer(x, x) + msg = "can't multiply sequence by non-int of type 'str'" + with pytest.raises(TypeError, match=msg): + g_df.filter(f) + msg = "the filter must return a boolean result" + with pytest.raises(TypeError, match=msg): + g_s.filter(f) + + +def test_filter_nan_is_false(): + df = DataFrame({"A": np.arange(8), "B": list("aabbbbcc"), "C": np.arange(8)}) + s = df["B"] + g_df = df.groupby(df["B"]) + g_s = s.groupby(s) + + f = lambda x: np.nan + tm.assert_frame_equal(g_df.filter(f), df.loc[[]]) + tm.assert_series_equal(g_s.filter(f), s[[]]) + + +def test_filter_pdna_is_false(): + # in particular, dont raise in filter trying to call bool(pd.NA) + df = DataFrame({"A": np.arange(8), "B": list("aabbbbcc"), "C": np.arange(8)}) + ser = df["B"] + g_df = df.groupby(df["B"]) + g_s = ser.groupby(ser) + + func = lambda x: pd.NA + res = g_df.filter(func) + tm.assert_frame_equal(res, df.loc[[]]) + res = g_s.filter(func) + tm.assert_series_equal(res, ser[[]]) + + +def test_filter_against_workaround_ints(): + # Series of ints + s = Series(np.random.default_rng(2).integers(0, 100, 100)) + grouper = s.apply(lambda x: np.round(x, -1)) + grouped = s.groupby(grouper) + f = lambda x: x.mean() > 10 + + old_way = s[grouped.transform(f).astype("bool")] + new_way = grouped.filter(f) + tm.assert_series_equal(new_way.sort_values(), old_way.sort_values()) + + +def test_filter_against_workaround_floats(): + # Series of floats + s = 100 * Series(np.random.default_rng(2).random(100)) + grouper = s.apply(lambda x: np.round(x, -1)) + grouped = s.groupby(grouper) + f = lambda x: x.mean() > 10 + old_way = s[grouped.transform(f).astype("bool")] + new_way = grouped.filter(f) + tm.assert_series_equal(new_way.sort_values(), old_way.sort_values()) + + +def test_filter_against_workaround_dataframe(): + # Set up DataFrame of ints, floats, strings. + letters = np.array(list(ascii_lowercase)) + N = 100 + random_letters = letters.take( + np.random.default_rng(2).integers(0, 26, N, dtype=int) + ) + df = DataFrame( + { + "ints": Series(np.random.default_rng(2).integers(0, 100, N)), + "floats": N / 10 * Series(np.random.default_rng(2).random(N)), + "letters": Series(random_letters), + } + ) + + # Group by ints; filter on floats. + grouped = df.groupby("ints") + old_way = df[grouped.floats.transform(lambda x: x.mean() > N / 20).astype("bool")] + new_way = grouped.filter(lambda x: x["floats"].mean() > N / 20) + tm.assert_frame_equal(new_way, old_way) + + # Group by floats (rounded); filter on strings. + grouper = df.floats.apply(lambda x: np.round(x, -1)) + grouped = df.groupby(grouper) + old_way = df[grouped.letters.transform(lambda x: len(x) < N / 10).astype("bool")] + new_way = grouped.filter(lambda x: len(x.letters) < N / 10) + tm.assert_frame_equal(new_way, old_way) + + # Group by strings; filter on ints. + grouped = df.groupby("letters") + old_way = df[grouped.ints.transform(lambda x: x.mean() > N / 20).astype("bool")] + new_way = grouped.filter(lambda x: x["ints"].mean() > N / 20) + tm.assert_frame_equal(new_way, old_way) + + +def test_filter_using_len(): + # BUG GH4447 + df = DataFrame({"A": np.arange(8), "B": list("aabbbbcc"), "C": np.arange(8)}) + grouped = df.groupby("B") + actual = grouped.filter(lambda x: len(x) > 2) + expected = DataFrame( + {"A": np.arange(2, 6), "B": list("bbbb"), "C": np.arange(2, 6)}, + index=np.arange(2, 6, dtype=np.int64), + ) + tm.assert_frame_equal(actual, expected) + + actual = grouped.filter(lambda x: len(x) > 4) + expected = df.loc[[]] + tm.assert_frame_equal(actual, expected) + + # Series have always worked properly, but we'll test anyway. + s = df["B"] + grouped = s.groupby(s) + actual = grouped.filter(lambda x: len(x) > 2) + expected = Series(4 * ["b"], index=np.arange(2, 6, dtype=np.int64), name="B") + tm.assert_series_equal(actual, expected) + + actual = grouped.filter(lambda x: len(x) > 4) + expected = s[[]] + tm.assert_series_equal(actual, expected) + + +def test_filter_maintains_ordering(): + # Simple case: index is sequential. #4621 + df = DataFrame( + {"pid": [1, 1, 1, 2, 2, 3, 3, 3], "tag": [23, 45, 62, 24, 45, 34, 25, 62]} + ) + s = df["pid"] + grouped = df.groupby("tag") + actual = grouped.filter(lambda x: len(x) > 1) + expected = df.iloc[[1, 2, 4, 7]] + tm.assert_frame_equal(actual, expected) + + grouped = s.groupby(df["tag"]) + actual = grouped.filter(lambda x: len(x) > 1) + expected = s.iloc[[1, 2, 4, 7]] + tm.assert_series_equal(actual, expected) + + # Now index is sequentially decreasing. + df.index = np.arange(len(df) - 1, -1, -1) + s = df["pid"] + grouped = df.groupby("tag") + actual = grouped.filter(lambda x: len(x) > 1) + expected = df.iloc[[1, 2, 4, 7]] + tm.assert_frame_equal(actual, expected) + + grouped = s.groupby(df["tag"]) + actual = grouped.filter(lambda x: len(x) > 1) + expected = s.iloc[[1, 2, 4, 7]] + tm.assert_series_equal(actual, expected) + + # Index is shuffled. + SHUFFLED = [4, 6, 7, 2, 1, 0, 5, 3] + df.index = df.index[SHUFFLED] + s = df["pid"] + grouped = df.groupby("tag") + actual = grouped.filter(lambda x: len(x) > 1) + expected = df.iloc[[1, 2, 4, 7]] + tm.assert_frame_equal(actual, expected) + + grouped = s.groupby(df["tag"]) + actual = grouped.filter(lambda x: len(x) > 1) + expected = s.iloc[[1, 2, 4, 7]] + tm.assert_series_equal(actual, expected) + + +def test_filter_multiple_timestamp(): + # GH 10114 + df = DataFrame( + { + "A": np.arange(5, dtype="int64"), + "B": ["foo", "bar", "foo", "bar", "bar"], + "C": Timestamp("20130101"), + } + ) + + grouped = df.groupby(["B", "C"]) + + result = grouped["A"].filter(lambda x: True) + tm.assert_series_equal(df["A"], result) + + result = grouped["A"].transform(len) + expected = Series([2, 3, 2, 3, 3], name="A") + tm.assert_series_equal(result, expected) + + result = grouped.filter(lambda x: True) + tm.assert_frame_equal(df, result) + + result = grouped.transform("sum") + expected = DataFrame({"A": [2, 8, 2, 8, 8]}) + tm.assert_frame_equal(result, expected) + + result = grouped.transform(len) + expected = DataFrame({"A": [2, 3, 2, 3, 3]}) + tm.assert_frame_equal(result, expected) + + +def test_filter_and_transform_with_non_unique_int_index(): + # GH4620 + index = [1, 1, 1, 2, 1, 1, 0, 1] + df = DataFrame( + {"pid": [1, 1, 1, 2, 2, 3, 3, 3], "tag": [23, 45, 62, 24, 45, 34, 25, 62]}, + index=index, + ) + grouped_df = df.groupby("tag") + ser = df["pid"] + grouped_ser = ser.groupby(df["tag"]) + expected_indexes = [1, 2, 4, 7] + + # Filter DataFrame + actual = grouped_df.filter(lambda x: len(x) > 1) + expected = df.iloc[expected_indexes] + tm.assert_frame_equal(actual, expected) + + actual = grouped_df.filter(lambda x: len(x) > 1, dropna=False) + # Cast to avoid upcast when setting nan below + expected = df.copy().astype("float64") + expected.iloc[[0, 3, 5, 6]] = np.nan + tm.assert_frame_equal(actual, expected) + + # Filter Series + actual = grouped_ser.filter(lambda x: len(x) > 1) + expected = ser.take(expected_indexes) + tm.assert_series_equal(actual, expected) + + actual = grouped_ser.filter(lambda x: len(x) > 1, dropna=False) + expected = Series([np.nan, 1, 1, np.nan, 2, np.nan, np.nan, 3], index, name="pid") + # ^ made manually because this can get confusing! + tm.assert_series_equal(actual, expected) + + # Transform Series + actual = grouped_ser.transform(len) + expected = Series([1, 2, 2, 1, 2, 1, 1, 2], index, name="pid") + tm.assert_series_equal(actual, expected) + + # Transform (a column from) DataFrameGroupBy + actual = grouped_df.pid.transform(len) + tm.assert_series_equal(actual, expected) + + +def test_filter_and_transform_with_multiple_non_unique_int_index(): + # GH4620 + index = [1, 1, 1, 2, 0, 0, 0, 1] + df = DataFrame( + {"pid": [1, 1, 1, 2, 2, 3, 3, 3], "tag": [23, 45, 62, 24, 45, 34, 25, 62]}, + index=index, + ) + grouped_df = df.groupby("tag") + ser = df["pid"] + grouped_ser = ser.groupby(df["tag"]) + expected_indexes = [1, 2, 4, 7] + + # Filter DataFrame + actual = grouped_df.filter(lambda x: len(x) > 1) + expected = df.iloc[expected_indexes] + tm.assert_frame_equal(actual, expected) + + actual = grouped_df.filter(lambda x: len(x) > 1, dropna=False) + # Cast to avoid upcast when setting nan below + expected = df.copy().astype("float64") + expected.iloc[[0, 3, 5, 6]] = np.nan + tm.assert_frame_equal(actual, expected) + + # Filter Series + actual = grouped_ser.filter(lambda x: len(x) > 1) + expected = ser.take(expected_indexes) + tm.assert_series_equal(actual, expected) + + actual = grouped_ser.filter(lambda x: len(x) > 1, dropna=False) + expected = Series([np.nan, 1, 1, np.nan, 2, np.nan, np.nan, 3], index, name="pid") + # ^ made manually because this can get confusing! + tm.assert_series_equal(actual, expected) + + # Transform Series + actual = grouped_ser.transform(len) + expected = Series([1, 2, 2, 1, 2, 1, 1, 2], index, name="pid") + tm.assert_series_equal(actual, expected) + + # Transform (a column from) DataFrameGroupBy + actual = grouped_df.pid.transform(len) + tm.assert_series_equal(actual, expected) + + +def test_filter_and_transform_with_non_unique_float_index(): + # GH4620 + index = np.array([1, 1, 1, 2, 1, 1, 0, 1], dtype=float) + df = DataFrame( + {"pid": [1, 1, 1, 2, 2, 3, 3, 3], "tag": [23, 45, 62, 24, 45, 34, 25, 62]}, + index=index, + ) + grouped_df = df.groupby("tag") + ser = df["pid"] + grouped_ser = ser.groupby(df["tag"]) + expected_indexes = [1, 2, 4, 7] + + # Filter DataFrame + actual = grouped_df.filter(lambda x: len(x) > 1) + expected = df.iloc[expected_indexes] + tm.assert_frame_equal(actual, expected) + + actual = grouped_df.filter(lambda x: len(x) > 1, dropna=False) + # Cast to avoid upcast when setting nan below + expected = df.copy().astype("float64") + expected.iloc[[0, 3, 5, 6]] = np.nan + tm.assert_frame_equal(actual, expected) + + # Filter Series + actual = grouped_ser.filter(lambda x: len(x) > 1) + expected = ser.take(expected_indexes) + tm.assert_series_equal(actual, expected) + + actual = grouped_ser.filter(lambda x: len(x) > 1, dropna=False) + expected = Series([np.nan, 1, 1, np.nan, 2, np.nan, np.nan, 3], index, name="pid") + # ^ made manually because this can get confusing! + tm.assert_series_equal(actual, expected) + + # Transform Series + actual = grouped_ser.transform(len) + expected = Series([1, 2, 2, 1, 2, 1, 1, 2], index, name="pid") + tm.assert_series_equal(actual, expected) + + # Transform (a column from) DataFrameGroupBy + actual = grouped_df.pid.transform(len) + tm.assert_series_equal(actual, expected) + + +def test_filter_and_transform_with_non_unique_timestamp_index(): + # GH4620 + t0 = Timestamp("2013-09-30 00:05:00") + t1 = Timestamp("2013-10-30 00:05:00") + t2 = Timestamp("2013-11-30 00:05:00") + index = [t1, t1, t1, t2, t1, t1, t0, t1] + df = DataFrame( + {"pid": [1, 1, 1, 2, 2, 3, 3, 3], "tag": [23, 45, 62, 24, 45, 34, 25, 62]}, + index=index, + ) + grouped_df = df.groupby("tag") + ser = df["pid"] + grouped_ser = ser.groupby(df["tag"]) + expected_indexes = [1, 2, 4, 7] + + # Filter DataFrame + actual = grouped_df.filter(lambda x: len(x) > 1) + expected = df.iloc[expected_indexes] + tm.assert_frame_equal(actual, expected) + + actual = grouped_df.filter(lambda x: len(x) > 1, dropna=False) + # Cast to avoid upcast when setting nan below + expected = df.copy().astype("float64") + expected.iloc[[0, 3, 5, 6]] = np.nan + tm.assert_frame_equal(actual, expected) + + # Filter Series + actual = grouped_ser.filter(lambda x: len(x) > 1) + expected = ser.take(expected_indexes) + tm.assert_series_equal(actual, expected) + + actual = grouped_ser.filter(lambda x: len(x) > 1, dropna=False) + expected = Series([np.nan, 1, 1, np.nan, 2, np.nan, np.nan, 3], index, name="pid") + # ^ made manually because this can get confusing! + tm.assert_series_equal(actual, expected) + + # Transform Series + actual = grouped_ser.transform(len) + expected = Series([1, 2, 2, 1, 2, 1, 1, 2], index, name="pid") + tm.assert_series_equal(actual, expected) + + # Transform (a column from) DataFrameGroupBy + actual = grouped_df.pid.transform(len) + tm.assert_series_equal(actual, expected) + + +def test_filter_and_transform_with_non_unique_string_index(): + # GH4620 + index = list("bbbcbbab") + df = DataFrame( + {"pid": [1, 1, 1, 2, 2, 3, 3, 3], "tag": [23, 45, 62, 24, 45, 34, 25, 62]}, + index=index, + ) + grouped_df = df.groupby("tag") + ser = df["pid"] + grouped_ser = ser.groupby(df["tag"]) + expected_indexes = [1, 2, 4, 7] + + # Filter DataFrame + actual = grouped_df.filter(lambda x: len(x) > 1) + expected = df.iloc[expected_indexes] + tm.assert_frame_equal(actual, expected) + + actual = grouped_df.filter(lambda x: len(x) > 1, dropna=False) + # Cast to avoid upcast when setting nan below + expected = df.copy().astype("float64") + expected.iloc[[0, 3, 5, 6]] = np.nan + tm.assert_frame_equal(actual, expected) + + # Filter Series + actual = grouped_ser.filter(lambda x: len(x) > 1) + expected = ser.take(expected_indexes) + tm.assert_series_equal(actual, expected) + + actual = grouped_ser.filter(lambda x: len(x) > 1, dropna=False) + expected = Series([np.nan, 1, 1, np.nan, 2, np.nan, np.nan, 3], index, name="pid") + # ^ made manually because this can get confusing! + tm.assert_series_equal(actual, expected) + + # Transform Series + actual = grouped_ser.transform(len) + expected = Series([1, 2, 2, 1, 2, 1, 1, 2], index, name="pid") + tm.assert_series_equal(actual, expected) + + # Transform (a column from) DataFrameGroupBy + actual = grouped_df.pid.transform(len) + tm.assert_series_equal(actual, expected) + + +def test_filter_has_access_to_grouped_cols(): + df = DataFrame([[1, 2], [1, 3], [5, 6]], columns=["A", "B"]) + g = df.groupby("A") + # previously didn't have access to col A #???? + filt = g.filter(lambda x: x["A"].sum() == 2) + tm.assert_frame_equal(filt, df.iloc[[0, 1]]) + + +def test_filter_enforces_scalarness(): + df = DataFrame( + [ + ["best", "a", "x"], + ["worst", "b", "y"], + ["best", "c", "x"], + ["best", "d", "y"], + ["worst", "d", "y"], + ["worst", "d", "y"], + ["best", "d", "z"], + ], + columns=["a", "b", "c"], + ) + with pytest.raises(TypeError, match="filter function returned a.*"): + df.groupby("c").filter(lambda g: g["a"] == "best") + + +def test_filter_non_bool_raises(): + df = DataFrame( + [ + ["best", "a", 1], + ["worst", "b", 1], + ["best", "c", 1], + ["best", "d", 1], + ["worst", "d", 1], + ["worst", "d", 1], + ["best", "d", 1], + ], + columns=["a", "b", "c"], + ) + with pytest.raises(TypeError, match="filter function returned a.*"): + df.groupby("a").filter(lambda g: g.c.mean()) + + +def test_filter_dropna_with_empty_groups(): + # GH 10780 + data = Series(np.random.default_rng(2).random(9), index=np.repeat([1, 2, 3], 3)) + grouped = data.groupby(level=0) + result_false = grouped.filter(lambda x: x.mean() > 1, dropna=False) + expected_false = Series([np.nan] * 9, index=np.repeat([1, 2, 3], 3)) + tm.assert_series_equal(result_false, expected_false) + + result_true = grouped.filter(lambda x: x.mean() > 1, dropna=True) + expected_true = Series(index=pd.Index([], dtype=int), dtype=np.float64) + tm.assert_series_equal(result_true, expected_true) + + +def test_filter_consistent_result_before_after_agg_func(): + # GH 17091 + df = DataFrame({"data": range(6), "key": list("ABCABC")}) + grouper = df.groupby("key") + result = grouper.filter(lambda x: True) + expected = DataFrame({"data": range(6), "key": list("ABCABC")}) + tm.assert_frame_equal(result, expected) + + grouper.sum() + result = grouper.filter(lambda x: True) + tm.assert_frame_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_groupby.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_groupby.py new file mode 100644 index 0000000000000000000000000000000000000000..44d6340e55507275284a066f10950e97f795e699 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_groupby.py @@ -0,0 +1,3346 @@ +from datetime import datetime +import decimal +from decimal import Decimal +import re + +import numpy as np +import pytest + +from pandas.errors import ( + PerformanceWarning, + SpecificationError, +) +import pandas.util._test_decorators as td + +from pandas.core.dtypes.common import is_string_dtype + +import pandas as pd +from pandas import ( + Categorical, + DataFrame, + Grouper, + Index, + Interval, + MultiIndex, + RangeIndex, + Series, + Timedelta, + Timestamp, + date_range, + to_datetime, +) +import pandas._testing as tm +from pandas.core.arrays import BooleanArray +import pandas.core.common as com + +pytestmark = pytest.mark.filterwarnings("ignore:Mean of empty slice:RuntimeWarning") + + +def test_repr(): + # GH18203 + result = repr(Grouper(key="A", level="B")) + expected = "Grouper(key='A', level='B', axis=0, sort=False, dropna=True)" + assert result == expected + + +def test_groupby_std_datetimelike(warn_copy_on_write): + # GH#48481 + tdi = pd.timedelta_range("1 Day", periods=10000) + ser = Series(tdi) + ser[::5] *= 2 # get different std for different groups + + df = ser.to_frame("A").copy() + + df["B"] = ser + Timestamp(0) + df["C"] = ser + Timestamp(0, tz="UTC") + df.iloc[-1] = pd.NaT # last group includes NaTs + + gb = df.groupby(list(range(5)) * 2000) + + result = gb.std() + + # Note: this does not _exactly_ match what we would get if we did + # [gb.get_group(i).std() for i in gb.groups] + # but it _does_ match the floating point error we get doing the + # same operation on int64 data xref GH#51332 + td1 = Timedelta("2887 days 11:21:02.326710176") + td4 = Timedelta("2886 days 00:42:34.664668096") + exp_ser = Series([td1 * 2, td1, td1, td1, td4], index=np.arange(5)) + expected = DataFrame({"A": exp_ser, "B": exp_ser, "C": exp_ser}) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("dtype", ["int64", "int32", "float64", "float32"]) +def test_basic_aggregations(dtype): + data = Series(np.arange(9) // 3, index=np.arange(9), dtype=dtype) + + index = np.arange(9) + np.random.default_rng(2).shuffle(index) + data = data.reindex(index) + + grouped = data.groupby(lambda x: x // 3, group_keys=False) + + for k, v in grouped: + assert len(v) == 3 + + msg = "using SeriesGroupBy.mean" + with tm.assert_produces_warning(FutureWarning, match=msg): + agged = grouped.aggregate(np.mean) + assert agged[1] == 1 + + msg = "using SeriesGroupBy.mean" + with tm.assert_produces_warning(FutureWarning, match=msg): + expected = grouped.agg(np.mean) + tm.assert_series_equal(agged, expected) # shorthand + tm.assert_series_equal(agged, grouped.mean()) + result = grouped.sum() + msg = "using SeriesGroupBy.sum" + with tm.assert_produces_warning(FutureWarning, match=msg): + expected = grouped.agg(np.sum) + tm.assert_series_equal(result, expected) + + expected = grouped.apply(lambda x: x * x.sum()) + transformed = grouped.transform(lambda x: x * x.sum()) + assert transformed[7] == 12 + tm.assert_series_equal(transformed, expected) + + value_grouped = data.groupby(data) + msg = "using SeriesGroupBy.mean" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = value_grouped.aggregate(np.mean) + tm.assert_series_equal(result, agged, check_index_type=False) + + # complex agg + msg = "using SeriesGroupBy.[mean|std]" + with tm.assert_produces_warning(FutureWarning, match=msg): + agged = grouped.aggregate([np.mean, np.std]) + + msg = r"nested renamer is not supported" + with pytest.raises(SpecificationError, match=msg): + grouped.aggregate({"one": np.mean, "two": np.std}) + + group_constants = {0: 10, 1: 20, 2: 30} + msg = ( + "Pinning the groupby key to each group in SeriesGroupBy.agg is deprecated, " + "and cases that relied on it will raise in a future version" + ) + with tm.assert_produces_warning(FutureWarning, match=msg): + # GH#41090 + agged = grouped.agg(lambda x: group_constants[x.name] + x.mean()) + assert agged[1] == 21 + + # corner cases + msg = "Must produce aggregated value" + # exception raised is type Exception + with pytest.raises(Exception, match=msg): + grouped.aggregate(lambda x: x * 2) + + +def test_groupby_nonobject_dtype(multiindex_dataframe_random_data): + key = multiindex_dataframe_random_data.index.codes[0] + grouped = multiindex_dataframe_random_data.groupby(key) + result = grouped.sum() + + expected = multiindex_dataframe_random_data.groupby(key.astype("O")).sum() + assert result.index.dtype == np.int8 + assert expected.index.dtype == np.int64 + tm.assert_frame_equal(result, expected, check_index_type=False) + + +def test_groupby_nonobject_dtype_mixed(): + # GH 3911, mixed frame non-conversion + df = DataFrame( + { + "A": ["foo", "bar", "foo", "bar", "foo", "bar", "foo", "foo"], + "B": ["one", "one", "two", "three", "two", "two", "one", "three"], + "C": np.random.default_rng(2).standard_normal(8), + "D": np.array(np.random.default_rng(2).standard_normal(8), dtype="float32"), + } + ) + df["value"] = range(len(df)) + + def max_value(group): + return group.loc[group["value"].idxmax()] + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + applied = df.groupby("A").apply(max_value) + result = applied.dtypes + expected = df.dtypes + tm.assert_series_equal(result, expected) + + +def test_inconsistent_return_type(): + # GH5592 + # inconsistent return type + df = DataFrame( + { + "A": ["Tiger", "Tiger", "Tiger", "Lamb", "Lamb", "Pony", "Pony"], + "B": Series(np.arange(7), dtype="int64"), + "C": date_range("20130101", periods=7), + } + ) + + def f_0(grp): + return grp.iloc[0] + + expected = df.groupby("A").first()[["B"]] + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = df.groupby("A").apply(f_0)[["B"]] + tm.assert_frame_equal(result, expected) + + def f_1(grp): + if grp.name == "Tiger": + return None + return grp.iloc[0] + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = df.groupby("A").apply(f_1)[["B"]] + e = expected.copy() + e.loc["Tiger"] = np.nan + tm.assert_frame_equal(result, e) + + def f_2(grp): + if grp.name == "Pony": + return None + return grp.iloc[0] + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = df.groupby("A").apply(f_2)[["B"]] + e = expected.copy() + e.loc["Pony"] = np.nan + tm.assert_frame_equal(result, e) + + # 5592 revisited, with datetimes + def f_3(grp): + if grp.name == "Pony": + return None + return grp.iloc[0] + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = df.groupby("A").apply(f_3)[["C"]] + e = df.groupby("A").first()[["C"]] + e.loc["Pony"] = pd.NaT + tm.assert_frame_equal(result, e) + + # scalar outputs + def f_4(grp): + if grp.name == "Pony": + return None + return grp.iloc[0].loc["C"] + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = df.groupby("A").apply(f_4) + e = df.groupby("A").first()["C"].copy() + e.loc["Pony"] = np.nan + e.name = None + tm.assert_series_equal(result, e) + + +def test_pass_args_kwargs(ts, tsframe): + def f(x, q=None, axis=0): + return np.percentile(x, q, axis=axis) + + g = lambda x: np.percentile(x, 80, axis=0) + + # Series + ts_grouped = ts.groupby(lambda x: x.month) + agg_result = ts_grouped.agg(np.percentile, 80, axis=0) + apply_result = ts_grouped.apply(np.percentile, 80, axis=0) + trans_result = ts_grouped.transform(np.percentile, 80, axis=0) + + agg_expected = ts_grouped.quantile(0.8) + trans_expected = ts_grouped.transform(g) + + tm.assert_series_equal(apply_result, agg_expected) + tm.assert_series_equal(agg_result, agg_expected) + tm.assert_series_equal(trans_result, trans_expected) + + agg_result = ts_grouped.agg(f, q=80) + apply_result = ts_grouped.apply(f, q=80) + trans_result = ts_grouped.transform(f, q=80) + tm.assert_series_equal(agg_result, agg_expected) + tm.assert_series_equal(apply_result, agg_expected) + tm.assert_series_equal(trans_result, trans_expected) + + # DataFrame + for as_index in [True, False]: + df_grouped = tsframe.groupby(lambda x: x.month, as_index=as_index) + warn = None if as_index else FutureWarning + msg = "A grouping .* was excluded from the result" + with tm.assert_produces_warning(warn, match=msg): + agg_result = df_grouped.agg(np.percentile, 80, axis=0) + with tm.assert_produces_warning(warn, match=msg): + apply_result = df_grouped.apply(DataFrame.quantile, 0.8) + with tm.assert_produces_warning(warn, match=msg): + expected = df_grouped.quantile(0.8) + tm.assert_frame_equal(apply_result, expected, check_names=False) + tm.assert_frame_equal(agg_result, expected) + + apply_result = df_grouped.apply(DataFrame.quantile, [0.4, 0.8]) + with tm.assert_produces_warning(warn, match=msg): + expected_seq = df_grouped.quantile([0.4, 0.8]) + tm.assert_frame_equal(apply_result, expected_seq, check_names=False) + + with tm.assert_produces_warning(warn, match=msg): + agg_result = df_grouped.agg(f, q=80) + with tm.assert_produces_warning(warn, match=msg): + apply_result = df_grouped.apply(DataFrame.quantile, q=0.8) + tm.assert_frame_equal(agg_result, expected) + tm.assert_frame_equal(apply_result, expected, check_names=False) + + +@pytest.mark.parametrize("as_index", [True, False]) +def test_pass_args_kwargs_duplicate_columns(tsframe, as_index): + # go through _aggregate_frame with self.axis == 0 and duplicate columns + tsframe.columns = ["A", "B", "A", "C"] + gb = tsframe.groupby(lambda x: x.month, as_index=as_index) + + warn = None if as_index else FutureWarning + msg = "A grouping .* was excluded from the result" + with tm.assert_produces_warning(warn, match=msg): + res = gb.agg(np.percentile, 80, axis=0) + + ex_data = { + 1: tsframe[tsframe.index.month == 1].quantile(0.8), + 2: tsframe[tsframe.index.month == 2].quantile(0.8), + } + expected = DataFrame(ex_data).T + if not as_index: + # TODO: try to get this more consistent? + expected.index = Index(range(2)) + + tm.assert_frame_equal(res, expected) + + +def test_len(): + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 4)), + columns=Index(list("ABCD"), dtype=object), + index=date_range("2000-01-01", periods=10, freq="B"), + ) + grouped = df.groupby([lambda x: x.year, lambda x: x.month, lambda x: x.day]) + assert len(grouped) == len(df) + + grouped = df.groupby([lambda x: x.year, lambda x: x.month]) + expected = len({(x.year, x.month) for x in df.index}) + assert len(grouped) == expected + + +def test_len_nan_group(): + # issue 11016 + df = DataFrame({"a": [np.nan] * 3, "b": [1, 2, 3]}) + assert len(df.groupby("a")) == 0 + assert len(df.groupby("b")) == 3 + assert len(df.groupby(["a", "b"])) == 3 + + +def test_basic_regression(): + # regression + result = Series([1.0 * x for x in list(range(1, 10)) * 10]) + + data = np.random.default_rng(2).random(1100) * 10.0 + groupings = Series(data) + + grouped = result.groupby(groupings) + grouped.mean() + + +@pytest.mark.parametrize( + "dtype", ["float64", "float32", "int64", "int32", "int16", "int8"] +) +def test_with_na_groups(dtype): + index = Index(np.arange(10)) + values = Series(np.ones(10), index, dtype=dtype) + labels = Series( + [np.nan, "foo", "bar", "bar", np.nan, np.nan, "bar", "bar", np.nan, "foo"], + index=index, + ) + + # this SHOULD be an int + grouped = values.groupby(labels) + agged = grouped.agg(len) + expected = Series([4, 2], index=["bar", "foo"]) + + tm.assert_series_equal(agged, expected, check_dtype=False) + + # assert issubclass(agged.dtype.type, np.integer) + + # explicitly return a float from my function + def f(x): + return float(len(x)) + + agged = grouped.agg(f) + expected = Series([4.0, 2.0], index=["bar", "foo"]) + + tm.assert_series_equal(agged, expected) + + +def test_indices_concatenation_order(): + # GH 2808 + + def f1(x): + y = x[(x.b % 2) == 1] ** 2 + if y.empty: + multiindex = MultiIndex(levels=[[]] * 2, codes=[[]] * 2, names=["b", "c"]) + res = DataFrame(columns=["a"], index=multiindex) + return res + else: + y = y.set_index(["b", "c"]) + return y + + def f2(x): + y = x[(x.b % 2) == 1] ** 2 + if y.empty: + return DataFrame() + else: + y = y.set_index(["b", "c"]) + return y + + def f3(x): + y = x[(x.b % 2) == 1] ** 2 + if y.empty: + multiindex = MultiIndex( + levels=[[]] * 2, codes=[[]] * 2, names=["foo", "bar"] + ) + res = DataFrame(columns=["a", "b"], index=multiindex) + return res + else: + return y + + df = DataFrame({"a": [1, 2, 2, 2], "b": range(4), "c": range(5, 9)}) + + df2 = DataFrame({"a": [3, 2, 2, 2], "b": range(4), "c": range(5, 9)}) + + depr_msg = "The behavior of array concatenation with empty entries is deprecated" + + # correct result + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result1 = df.groupby("a").apply(f1) + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result2 = df2.groupby("a").apply(f1) + tm.assert_frame_equal(result1, result2) + + # should fail (not the same number of levels) + msg = "Cannot concat indices that do not have the same number of levels" + with pytest.raises(AssertionError, match=msg): + df.groupby("a").apply(f2) + with pytest.raises(AssertionError, match=msg): + df2.groupby("a").apply(f2) + + # should fail (incorrect shape) + with pytest.raises(AssertionError, match=msg): + df.groupby("a").apply(f3) + with pytest.raises(AssertionError, match=msg): + with tm.assert_produces_warning(FutureWarning, match=depr_msg): + df2.groupby("a").apply(f3) + + +def test_attr_wrapper(ts): + grouped = ts.groupby(lambda x: x.weekday()) + + result = grouped.std() + expected = grouped.agg(lambda x: np.std(x, ddof=1)) + tm.assert_series_equal(result, expected) + + # this is pretty cool + result = grouped.describe() + expected = {name: gp.describe() for name, gp in grouped} + expected = DataFrame(expected).T + tm.assert_frame_equal(result, expected) + + # get attribute + result = grouped.dtype + expected = grouped.agg(lambda x: x.dtype) + tm.assert_series_equal(result, expected) + + # make sure raises error + msg = "'SeriesGroupBy' object has no attribute 'foo'" + with pytest.raises(AttributeError, match=msg): + getattr(grouped, "foo") + + +def test_frame_groupby(tsframe): + grouped = tsframe.groupby(lambda x: x.weekday()) + + # aggregate + aggregated = grouped.aggregate("mean") + assert len(aggregated) == 5 + assert len(aggregated.columns) == 4 + + # by string + tscopy = tsframe.copy() + tscopy["weekday"] = [x.weekday() for x in tscopy.index] + stragged = tscopy.groupby("weekday").aggregate("mean") + tm.assert_frame_equal(stragged, aggregated, check_names=False) + + # transform + grouped = tsframe.head(30).groupby(lambda x: x.weekday()) + transformed = grouped.transform(lambda x: x - x.mean()) + assert len(transformed) == 30 + assert len(transformed.columns) == 4 + + # transform propagate + transformed = grouped.transform(lambda x: x.mean()) + for name, group in grouped: + mean = group.mean() + for idx in group.index: + tm.assert_series_equal(transformed.xs(idx), mean, check_names=False) + + # iterate + for weekday, group in grouped: + assert group.index[0].weekday() == weekday + + # groups / group_indices + groups = grouped.groups + indices = grouped.indices + + for k, v in groups.items(): + samething = tsframe.index.take(indices[k]) + assert (samething == v).all() + + +def test_frame_groupby_columns(tsframe): + mapping = {"A": 0, "B": 0, "C": 1, "D": 1} + msg = "DataFrame.groupby with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + grouped = tsframe.groupby(mapping, axis=1) + + # aggregate + aggregated = grouped.aggregate("mean") + assert len(aggregated) == len(tsframe) + assert len(aggregated.columns) == 2 + + # transform + tf = lambda x: x - x.mean() + msg = "The 'axis' keyword in DataFrame.groupby is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + groupedT = tsframe.T.groupby(mapping, axis=0) + tm.assert_frame_equal(groupedT.transform(tf).T, grouped.transform(tf)) + + # iterate + for k, v in grouped: + assert len(v.columns) == 2 + + +def test_frame_set_name_single(df): + grouped = df.groupby("A") + + result = grouped.mean(numeric_only=True) + assert result.index.name == "A" + + result = df.groupby("A", as_index=False).mean(numeric_only=True) + assert result.index.name != "A" + + result = grouped[["C", "D"]].agg("mean") + assert result.index.name == "A" + + result = grouped.agg({"C": "mean", "D": "std"}) + assert result.index.name == "A" + + result = grouped["C"].mean() + assert result.index.name == "A" + result = grouped["C"].agg("mean") + assert result.index.name == "A" + result = grouped["C"].agg(["mean", "std"]) + assert result.index.name == "A" + + msg = r"nested renamer is not supported" + with pytest.raises(SpecificationError, match=msg): + grouped["C"].agg({"foo": "mean", "bar": "std"}) + + +def test_multi_func(df): + col1 = df["A"] + col2 = df["B"] + + grouped = df.groupby([col1.get, col2.get]) + agged = grouped.mean(numeric_only=True) + expected = df.groupby(["A", "B"]).mean() + + # TODO groupby get drops names + tm.assert_frame_equal( + agged.loc[:, ["C", "D"]], expected.loc[:, ["C", "D"]], check_names=False + ) + + # some "groups" with no data + df = DataFrame( + { + "v1": np.random.default_rng(2).standard_normal(6), + "v2": np.random.default_rng(2).standard_normal(6), + "k1": np.array(["b", "b", "b", "a", "a", "a"]), + "k2": np.array(["1", "1", "1", "2", "2", "2"]), + }, + index=["one", "two", "three", "four", "five", "six"], + ) + # only verify that it works for now + grouped = df.groupby(["k1", "k2"]) + grouped.agg("sum") + + +def test_multi_key_multiple_functions(df): + grouped = df.groupby(["A", "B"])["C"] + + agged = grouped.agg(["mean", "std"]) + expected = DataFrame({"mean": grouped.agg("mean"), "std": grouped.agg("std")}) + tm.assert_frame_equal(agged, expected) + + +def test_frame_multi_key_function_list(): + data = DataFrame( + { + "A": [ + "foo", + "foo", + "foo", + "foo", + "bar", + "bar", + "bar", + "bar", + "foo", + "foo", + "foo", + ], + "B": [ + "one", + "one", + "one", + "two", + "one", + "one", + "one", + "two", + "two", + "two", + "one", + ], + "D": np.random.default_rng(2).standard_normal(11), + "E": np.random.default_rng(2).standard_normal(11), + "F": np.random.default_rng(2).standard_normal(11), + } + ) + + grouped = data.groupby(["A", "B"]) + funcs = ["mean", "std"] + agged = grouped.agg(funcs) + expected = pd.concat( + [grouped["D"].agg(funcs), grouped["E"].agg(funcs), grouped["F"].agg(funcs)], + keys=["D", "E", "F"], + axis=1, + ) + assert isinstance(agged.index, MultiIndex) + assert isinstance(expected.index, MultiIndex) + tm.assert_frame_equal(agged, expected) + + +def test_frame_multi_key_function_list_partial_failure(): + data = DataFrame( + { + "A": [ + "foo", + "foo", + "foo", + "foo", + "bar", + "bar", + "bar", + "bar", + "foo", + "foo", + "foo", + ], + "B": [ + "one", + "one", + "one", + "two", + "one", + "one", + "one", + "two", + "two", + "two", + "one", + ], + "C": [ + "dull", + "dull", + "shiny", + "dull", + "dull", + "shiny", + "shiny", + "dull", + "shiny", + "shiny", + "shiny", + ], + "D": np.random.default_rng(2).standard_normal(11), + "E": np.random.default_rng(2).standard_normal(11), + "F": np.random.default_rng(2).standard_normal(11), + } + ) + + grouped = data.groupby(["A", "B"]) + funcs = ["mean", "std"] + msg = re.escape("agg function failed [how->mean,dtype->") + with pytest.raises(TypeError, match=msg): + grouped.agg(funcs) + + +@pytest.mark.parametrize("op", [lambda x: x.sum(), lambda x: x.mean()]) +def test_groupby_multiple_columns(df, op): + data = df + grouped = data.groupby(["A", "B"]) + + result1 = op(grouped) + + keys = [] + values = [] + for n1, gp1 in data.groupby("A"): + for n2, gp2 in gp1.groupby("B"): + keys.append((n1, n2)) + values.append(op(gp2.loc[:, ["C", "D"]])) + + mi = MultiIndex.from_tuples(keys, names=["A", "B"]) + expected = pd.concat(values, axis=1).T + expected.index = mi + + # a little bit crude + for col in ["C", "D"]: + result_col = op(grouped[col]) + pivoted = result1[col] + exp = expected[col] + tm.assert_series_equal(result_col, exp) + tm.assert_series_equal(pivoted, exp) + + # test single series works the same + result = data["C"].groupby([data["A"], data["B"]]).mean() + expected = data.groupby(["A", "B"]).mean()["C"] + + tm.assert_series_equal(result, expected) + + +def test_as_index_select_column(): + # GH 5764 + df = DataFrame([[1, 2], [1, 4], [5, 6]], columns=["A", "B"]) + result = df.groupby("A", as_index=False)["B"].get_group(1) + expected = Series([2, 4], name="B") + tm.assert_series_equal(result, expected) + + result = df.groupby("A", as_index=False, group_keys=True)["B"].apply( + lambda x: x.cumsum() + ) + expected = Series( + [2, 6, 6], name="B", index=MultiIndex.from_tuples([(0, 0), (0, 1), (1, 2)]) + ) + tm.assert_series_equal(result, expected) + + +def test_obj_arg_get_group_deprecated(): + depr_msg = "obj is deprecated" + + df = DataFrame({"a": [1, 1, 2], "b": [3, 4, 5]}) + expected = df.iloc[df.groupby("b").indices.get(4)] + with tm.assert_produces_warning(FutureWarning, match=depr_msg): + result = df.groupby("b").get_group(4, obj=df) + tm.assert_frame_equal(result, expected) + + +def test_groupby_as_index_select_column_sum_empty_df(): + # GH 35246 + df = DataFrame(columns=Index(["A", "B", "C"], name="alpha")) + left = df.groupby(by="A", as_index=False)["B"].sum(numeric_only=False) + + expected = DataFrame(columns=df.columns[:2], index=range(0)) + # GH#50744 - Columns after selection shouldn't retain names + expected.columns.names = [None] + tm.assert_frame_equal(left, expected) + + +def test_groupby_as_index_agg(df): + grouped = df.groupby("A", as_index=False) + + # single-key + + result = grouped[["C", "D"]].agg("mean") + expected = grouped.mean(numeric_only=True) + tm.assert_frame_equal(result, expected) + + result2 = grouped.agg({"C": "mean", "D": "sum"}) + expected2 = grouped.mean(numeric_only=True) + expected2["D"] = grouped.sum()["D"] + tm.assert_frame_equal(result2, expected2) + + grouped = df.groupby("A", as_index=True) + + msg = r"nested renamer is not supported" + with pytest.raises(SpecificationError, match=msg): + grouped["C"].agg({"Q": "sum"}) + + # multi-key + + grouped = df.groupby(["A", "B"], as_index=False) + + result = grouped.agg("mean") + expected = grouped.mean() + tm.assert_frame_equal(result, expected) + + result2 = grouped.agg({"C": "mean", "D": "sum"}) + expected2 = grouped.mean() + expected2["D"] = grouped.sum()["D"] + tm.assert_frame_equal(result2, expected2) + + expected3 = grouped["C"].sum() + expected3 = DataFrame(expected3).rename(columns={"C": "Q"}) + msg = "Passing a dictionary to SeriesGroupBy.agg is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + result3 = grouped["C"].agg({"Q": "sum"}) + tm.assert_frame_equal(result3, expected3) + + # GH7115 & GH8112 & GH8582 + df = DataFrame( + np.random.default_rng(2).integers(0, 100, (50, 3)), + columns=["jim", "joe", "jolie"], + ) + ts = Series(np.random.default_rng(2).integers(5, 10, 50), name="jim") + + gr = df.groupby(ts) + gr.nth(0) # invokes set_selection_from_grouper internally + + msg = "The behavior of DataFrame.sum with axis=None is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg, check_stacklevel=False): + res = gr.apply(sum) + with tm.assert_produces_warning(FutureWarning, match=msg, check_stacklevel=False): + alt = df.groupby(ts).apply(sum) + tm.assert_frame_equal(res, alt) + + for attr in ["mean", "max", "count", "idxmax", "cumsum", "all"]: + gr = df.groupby(ts, as_index=False) + left = getattr(gr, attr)() + + gr = df.groupby(ts.values, as_index=True) + right = getattr(gr, attr)().reset_index(drop=True) + + tm.assert_frame_equal(left, right) + + +def test_ops_not_as_index(reduction_func): + # GH 10355, 21090 + # Using as_index=False should not modify grouped column + + if reduction_func in ("corrwith", "nth", "ngroup"): + pytest.skip(f"GH 5755: Test not applicable for {reduction_func}") + + df = DataFrame( + np.random.default_rng(2).integers(0, 5, size=(100, 2)), columns=["a", "b"] + ) + expected = getattr(df.groupby("a"), reduction_func)() + if reduction_func == "size": + expected = expected.rename("size") + expected = expected.reset_index() + + if reduction_func != "size": + # 32 bit compat -> groupby preserves dtype whereas reset_index casts to int64 + expected["a"] = expected["a"].astype(df["a"].dtype) + + g = df.groupby("a", as_index=False) + + result = getattr(g, reduction_func)() + tm.assert_frame_equal(result, expected) + + result = g.agg(reduction_func) + tm.assert_frame_equal(result, expected) + + result = getattr(g["b"], reduction_func)() + tm.assert_frame_equal(result, expected) + + result = g["b"].agg(reduction_func) + tm.assert_frame_equal(result, expected) + + +def test_as_index_series_return_frame(df): + grouped = df.groupby("A", as_index=False) + grouped2 = df.groupby(["A", "B"], as_index=False) + + result = grouped["C"].agg("sum") + expected = grouped.agg("sum").loc[:, ["A", "C"]] + assert isinstance(result, DataFrame) + tm.assert_frame_equal(result, expected) + + result2 = grouped2["C"].agg("sum") + expected2 = grouped2.agg("sum").loc[:, ["A", "B", "C"]] + assert isinstance(result2, DataFrame) + tm.assert_frame_equal(result2, expected2) + + result = grouped["C"].sum() + expected = grouped.sum().loc[:, ["A", "C"]] + assert isinstance(result, DataFrame) + tm.assert_frame_equal(result, expected) + + result2 = grouped2["C"].sum() + expected2 = grouped2.sum().loc[:, ["A", "B", "C"]] + assert isinstance(result2, DataFrame) + tm.assert_frame_equal(result2, expected2) + + +def test_as_index_series_column_slice_raises(df): + # GH15072 + grouped = df.groupby("A", as_index=False) + msg = r"Column\(s\) C already selected" + + with pytest.raises(IndexError, match=msg): + grouped["C"].__getitem__("D") + + +def test_groupby_as_index_cython(df): + data = df + + # single-key + grouped = data.groupby("A", as_index=False) + result = grouped.mean(numeric_only=True) + expected = data.groupby(["A"]).mean(numeric_only=True) + expected.insert(0, "A", expected.index) + expected.index = RangeIndex(len(expected)) + tm.assert_frame_equal(result, expected) + + # multi-key + grouped = data.groupby(["A", "B"], as_index=False) + result = grouped.mean() + expected = data.groupby(["A", "B"]).mean() + + arrays = list(zip(*expected.index.values)) + expected.insert(0, "A", arrays[0]) + expected.insert(1, "B", arrays[1]) + expected.index = RangeIndex(len(expected)) + tm.assert_frame_equal(result, expected) + + +def test_groupby_as_index_series_scalar(df): + grouped = df.groupby(["A", "B"], as_index=False) + + # GH #421 + + result = grouped["C"].agg(len) + expected = grouped.agg(len).loc[:, ["A", "B", "C"]] + tm.assert_frame_equal(result, expected) + + +def test_groupby_as_index_corner(df, ts): + msg = "as_index=False only valid with DataFrame" + with pytest.raises(TypeError, match=msg): + ts.groupby(lambda x: x.weekday(), as_index=False) + + msg = "as_index=False only valid for axis=0" + depr_msg = "DataFrame.groupby with axis=1 is deprecated" + with pytest.raises(ValueError, match=msg): + with tm.assert_produces_warning(FutureWarning, match=depr_msg): + df.groupby(lambda x: x.lower(), as_index=False, axis=1) + + +def test_groupby_multiple_key(): + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 4)), + columns=Index(list("ABCD"), dtype=object), + index=date_range("2000-01-01", periods=10, freq="B"), + ) + grouped = df.groupby([lambda x: x.year, lambda x: x.month, lambda x: x.day]) + agged = grouped.sum() + tm.assert_almost_equal(df.values, agged.values) + + depr_msg = "DataFrame.groupby with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=depr_msg): + grouped = df.T.groupby( + [lambda x: x.year, lambda x: x.month, lambda x: x.day], axis=1 + ) + + agged = grouped.agg(lambda x: x.sum()) + tm.assert_index_equal(agged.index, df.columns) + tm.assert_almost_equal(df.T.values, agged.values) + + agged = grouped.agg(lambda x: x.sum()) + tm.assert_almost_equal(df.T.values, agged.values) + + +def test_groupby_multi_corner(df): + # test that having an all-NA column doesn't mess you up + df = df.copy() + df["bad"] = np.nan + agged = df.groupby(["A", "B"]).mean() + + expected = df.groupby(["A", "B"]).mean() + expected["bad"] = np.nan + + tm.assert_frame_equal(agged, expected) + + +def test_raises_on_nuisance(df): + grouped = df.groupby("A") + msg = re.escape("agg function failed [how->mean,dtype->") + with pytest.raises(TypeError, match=msg): + grouped.agg("mean") + with pytest.raises(TypeError, match=msg): + grouped.mean() + + df = df.loc[:, ["A", "C", "D"]] + df["E"] = datetime.now() + grouped = df.groupby("A") + msg = "datetime64 type does not support sum operations" + with pytest.raises(TypeError, match=msg): + grouped.agg("sum") + with pytest.raises(TypeError, match=msg): + grouped.sum() + + # won't work with axis = 1 + depr_msg = "DataFrame.groupby with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=depr_msg): + grouped = df.groupby({"A": 0, "C": 0, "D": 1, "E": 1}, axis=1) + msg = "does not support reduction 'sum'" + with pytest.raises(TypeError, match=msg): + grouped.agg(lambda x: x.sum(0, numeric_only=False)) + + +@pytest.mark.parametrize( + "agg_function", + ["max", "min"], +) +def test_keep_nuisance_agg(df, agg_function): + # GH 38815 + grouped = df.groupby("A") + result = getattr(grouped, agg_function)() + expected = result.copy() + expected.loc["bar", "B"] = getattr(df.loc[df["A"] == "bar", "B"], agg_function)() + expected.loc["foo", "B"] = getattr(df.loc[df["A"] == "foo", "B"], agg_function)() + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "agg_function", + ["sum", "mean", "prod", "std", "var", "sem", "median"], +) +@pytest.mark.parametrize("numeric_only", [True, False]) +def test_omit_nuisance_agg(df, agg_function, numeric_only): + # GH 38774, GH 38815 + grouped = df.groupby("A") + + no_drop_nuisance = ("var", "std", "sem", "mean", "prod", "median") + if agg_function in no_drop_nuisance and not numeric_only: + # Added numeric_only as part of GH#46560; these do not drop nuisance + # columns when numeric_only is False + if agg_function in ("std", "sem"): + klass = ValueError + msg = "could not convert string to float: 'one'" + else: + klass = TypeError + msg = re.escape(f"agg function failed [how->{agg_function},dtype->") + with pytest.raises(klass, match=msg): + getattr(grouped, agg_function)(numeric_only=numeric_only) + else: + result = getattr(grouped, agg_function)(numeric_only=numeric_only) + if not numeric_only and agg_function == "sum": + # sum is successful on column B + columns = ["A", "B", "C", "D"] + else: + columns = ["A", "C", "D"] + expected = getattr(df.loc[:, columns].groupby("A"), agg_function)( + numeric_only=numeric_only + ) + tm.assert_frame_equal(result, expected) + + +def test_raise_on_nuisance_python_single(df): + # GH 38815 + grouped = df.groupby("A") + with pytest.raises(ValueError, match="could not convert"): + grouped.skew() + + +def test_raise_on_nuisance_python_multiple(three_group): + grouped = three_group.groupby(["A", "B"]) + msg = re.escape("agg function failed [how->mean,dtype->") + with pytest.raises(TypeError, match=msg): + grouped.agg("mean") + with pytest.raises(TypeError, match=msg): + grouped.mean() + + +def test_empty_groups_corner(multiindex_dataframe_random_data): + # handle empty groups + df = DataFrame( + { + "k1": np.array(["b", "b", "b", "a", "a", "a"]), + "k2": np.array(["1", "1", "1", "2", "2", "2"]), + "k3": ["foo", "bar"] * 3, + "v1": np.random.default_rng(2).standard_normal(6), + "v2": np.random.default_rng(2).standard_normal(6), + } + ) + + grouped = df.groupby(["k1", "k2"]) + result = grouped[["v1", "v2"]].agg("mean") + expected = grouped.mean(numeric_only=True) + tm.assert_frame_equal(result, expected) + + grouped = multiindex_dataframe_random_data[3:5].groupby(level=0) + agged = grouped.apply(lambda x: x.mean()) + agged_A = grouped["A"].apply("mean") + tm.assert_series_equal(agged["A"], agged_A) + assert agged.index.name == "first" + + +def test_nonsense_func(): + df = DataFrame([0]) + msg = r"unsupported operand type\(s\) for \+: 'int' and 'str'" + with pytest.raises(TypeError, match=msg): + df.groupby(lambda x: x + "foo") + + +def test_wrap_aggregated_output_multindex(multiindex_dataframe_random_data): + df = multiindex_dataframe_random_data.T + df["baz", "two"] = "peekaboo" + + keys = [np.array([0, 0, 1]), np.array([0, 0, 1])] + msg = re.escape("agg function failed [how->mean,dtype->") + with pytest.raises(TypeError, match=msg): + df.groupby(keys).agg("mean") + agged = df.drop(columns=("baz", "two")).groupby(keys).agg("mean") + assert isinstance(agged.columns, MultiIndex) + + def aggfun(ser): + if ser.name == ("foo", "one"): + raise TypeError("Test error message") + return ser.sum() + + with pytest.raises(TypeError, match="Test error message"): + df.groupby(keys).aggregate(aggfun) + + +def test_groupby_level_apply(multiindex_dataframe_random_data): + result = multiindex_dataframe_random_data.groupby(level=0).count() + assert result.index.name == "first" + result = multiindex_dataframe_random_data.groupby(level=1).count() + assert result.index.name == "second" + + result = multiindex_dataframe_random_data["A"].groupby(level=0).count() + assert result.index.name == "first" + + +def test_groupby_level_mapper(multiindex_dataframe_random_data): + deleveled = multiindex_dataframe_random_data.reset_index() + + mapper0 = {"foo": 0, "bar": 0, "baz": 1, "qux": 1} + mapper1 = {"one": 0, "two": 0, "three": 1} + + result0 = multiindex_dataframe_random_data.groupby(mapper0, level=0).sum() + result1 = multiindex_dataframe_random_data.groupby(mapper1, level=1).sum() + + mapped_level0 = np.array( + [mapper0.get(x) for x in deleveled["first"]], dtype=np.int64 + ) + mapped_level1 = np.array( + [mapper1.get(x) for x in deleveled["second"]], dtype=np.int64 + ) + expected0 = multiindex_dataframe_random_data.groupby(mapped_level0).sum() + expected1 = multiindex_dataframe_random_data.groupby(mapped_level1).sum() + expected0.index.name, expected1.index.name = "first", "second" + + tm.assert_frame_equal(result0, expected0) + tm.assert_frame_equal(result1, expected1) + + +def test_groupby_level_nonmulti(): + # GH 1313, GH 13901 + s = Series([1, 2, 3, 10, 4, 5, 20, 6], Index([1, 2, 3, 1, 4, 5, 2, 6], name="foo")) + expected = Series([11, 22, 3, 4, 5, 6], Index(range(1, 7), name="foo")) + + result = s.groupby(level=0).sum() + tm.assert_series_equal(result, expected) + result = s.groupby(level=[0]).sum() + tm.assert_series_equal(result, expected) + result = s.groupby(level=-1).sum() + tm.assert_series_equal(result, expected) + result = s.groupby(level=[-1]).sum() + tm.assert_series_equal(result, expected) + + msg = "level > 0 or level < -1 only valid with MultiIndex" + with pytest.raises(ValueError, match=msg): + s.groupby(level=1) + with pytest.raises(ValueError, match=msg): + s.groupby(level=-2) + msg = "No group keys passed!" + with pytest.raises(ValueError, match=msg): + s.groupby(level=[]) + msg = "multiple levels only valid with MultiIndex" + with pytest.raises(ValueError, match=msg): + s.groupby(level=[0, 0]) + with pytest.raises(ValueError, match=msg): + s.groupby(level=[0, 1]) + msg = "level > 0 or level < -1 only valid with MultiIndex" + with pytest.raises(ValueError, match=msg): + s.groupby(level=[1]) + + +def test_groupby_complex(): + # GH 12902 + a = Series(data=np.arange(4) * (1 + 2j), index=[0, 0, 1, 1]) + expected = Series((1 + 2j, 5 + 10j)) + + result = a.groupby(level=0).sum() + tm.assert_series_equal(result, expected) + + +def test_groupby_complex_mean(): + # GH 26475 + df = DataFrame( + [ + {"a": 2, "b": 1 + 2j}, + {"a": 1, "b": 1 + 1j}, + {"a": 1, "b": 1 + 2j}, + ] + ) + result = df.groupby("b").mean() + expected = DataFrame( + [[1.0], [1.5]], + index=Index([(1 + 1j), (1 + 2j)], name="b"), + columns=Index(["a"]), + ) + tm.assert_frame_equal(result, expected) + + +def test_groupby_complex_numbers(using_infer_string): + # GH 17927 + df = DataFrame( + [ + {"a": 1, "b": 1 + 1j}, + {"a": 1, "b": 1 + 2j}, + {"a": 4, "b": 1}, + ] + ) + dtype = "string[pyarrow_numpy]" if using_infer_string else object + expected = DataFrame( + np.array([1, 1, 1], dtype=np.int64), + index=Index([(1 + 1j), (1 + 2j), (1 + 0j)], name="b"), + columns=Index(["a"], dtype=dtype), + ) + result = df.groupby("b", sort=False).count() + tm.assert_frame_equal(result, expected) + + # Sorted by the magnitude of the complex numbers + expected.index = Index([(1 + 0j), (1 + 1j), (1 + 2j)], name="b") + result = df.groupby("b", sort=True).count() + tm.assert_frame_equal(result, expected) + + +def test_groupby_series_indexed_differently(): + s1 = Series( + [5.0, -9.0, 4.0, 100.0, -5.0, 55.0, 6.7], + index=Index(["a", "b", "c", "d", "e", "f", "g"]), + ) + s2 = Series( + [1.0, 1.0, 4.0, 5.0, 5.0, 7.0], index=Index(["a", "b", "d", "f", "g", "h"]) + ) + + grouped = s1.groupby(s2) + agged = grouped.mean() + exp = s1.groupby(s2.reindex(s1.index).get).mean() + tm.assert_series_equal(agged, exp) + + +def test_groupby_with_hier_columns(): + tuples = list( + zip( + *[ + ["bar", "bar", "baz", "baz", "foo", "foo", "qux", "qux"], + ["one", "two", "one", "two", "one", "two", "one", "two"], + ] + ) + ) + index = MultiIndex.from_tuples(tuples) + columns = MultiIndex.from_tuples( + [("A", "cat"), ("B", "dog"), ("B", "cat"), ("A", "dog")] + ) + df = DataFrame( + np.random.default_rng(2).standard_normal((8, 4)), index=index, columns=columns + ) + + result = df.groupby(level=0).mean() + tm.assert_index_equal(result.columns, columns) + + depr_msg = "DataFrame.groupby with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=depr_msg): + gb = df.groupby(level=0, axis=1) + result = gb.mean() + tm.assert_index_equal(result.index, df.index) + + result = df.groupby(level=0).agg("mean") + tm.assert_index_equal(result.columns, columns) + + result = df.groupby(level=0).apply(lambda x: x.mean()) + tm.assert_index_equal(result.columns, columns) + + with tm.assert_produces_warning(FutureWarning, match=depr_msg): + gb = df.groupby(level=0, axis=1) + result = gb.agg(lambda x: x.mean(1)) + tm.assert_index_equal(result.columns, Index(["A", "B"])) + tm.assert_index_equal(result.index, df.index) + + # add a nuisance column + sorted_columns, _ = columns.sortlevel(0) + df["A", "foo"] = "bar" + result = df.groupby(level=0).mean(numeric_only=True) + tm.assert_index_equal(result.columns, df.columns[:-1]) + + +def test_grouping_ndarray(df): + grouped = df.groupby(df["A"].values) + result = grouped.sum() + expected = df.groupby(df["A"].rename(None)).sum() + tm.assert_frame_equal(result, expected) + + +def test_groupby_wrong_multi_labels(): + index = Index([0, 1, 2, 3, 4], name="index") + data = DataFrame( + { + "foo": ["foo1", "foo1", "foo2", "foo1", "foo3"], + "bar": ["bar1", "bar2", "bar2", "bar1", "bar1"], + "baz": ["baz1", "baz1", "baz1", "baz2", "baz2"], + "spam": ["spam2", "spam3", "spam2", "spam1", "spam1"], + "data": [20, 30, 40, 50, 60], + }, + index=index, + ) + + grouped = data.groupby(["foo", "bar", "baz", "spam"]) + + result = grouped.agg("mean") + expected = grouped.mean() + tm.assert_frame_equal(result, expected) + + +def test_groupby_series_with_name(df): + result = df.groupby(df["A"]).mean(numeric_only=True) + result2 = df.groupby(df["A"], as_index=False).mean(numeric_only=True) + assert result.index.name == "A" + assert "A" in result2 + + result = df.groupby([df["A"], df["B"]]).mean() + result2 = df.groupby([df["A"], df["B"]], as_index=False).mean() + assert result.index.names == ("A", "B") + assert "A" in result2 + assert "B" in result2 + + +def test_seriesgroupby_name_attr(df): + # GH 6265 + result = df.groupby("A")["C"] + assert result.count().name == "C" + assert result.mean().name == "C" + + testFunc = lambda x: np.sum(x) * 2 + assert result.agg(testFunc).name == "C" + + +def test_consistency_name(): + # GH 12363 + + df = DataFrame( + { + "A": ["foo", "bar", "foo", "bar", "foo", "bar", "foo", "foo"], + "B": ["one", "one", "two", "two", "two", "two", "one", "two"], + "C": np.random.default_rng(2).standard_normal(8) + 1.0, + "D": np.arange(8), + } + ) + + expected = df.groupby(["A"]).B.count() + result = df.B.groupby(df.A).count() + tm.assert_series_equal(result, expected) + + +def test_groupby_name_propagation(df): + # GH 6124 + def summarize(df, name=None): + return Series({"count": 1, "mean": 2, "omissions": 3}, name=name) + + def summarize_random_name(df): + # Provide a different name for each Series. In this case, groupby + # should not attempt to propagate the Series name since they are + # inconsistent. + return Series({"count": 1, "mean": 2, "omissions": 3}, name=df.iloc[0]["A"]) + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + metrics = df.groupby("A").apply(summarize) + assert metrics.columns.name is None + with tm.assert_produces_warning(DeprecationWarning, match=msg): + metrics = df.groupby("A").apply(summarize, "metrics") + assert metrics.columns.name == "metrics" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + metrics = df.groupby("A").apply(summarize_random_name) + assert metrics.columns.name is None + + +def test_groupby_nonstring_columns(): + df = DataFrame([np.arange(10) for x in range(10)]) + grouped = df.groupby(0) + result = grouped.mean() + expected = df.groupby(df[0]).mean() + tm.assert_frame_equal(result, expected) + + +def test_groupby_mixed_type_columns(): + # GH 13432, unorderable types in py3 + df = DataFrame([[0, 1, 2]], columns=["A", "B", 0]) + expected = DataFrame([[1, 2]], columns=["B", 0], index=Index([0], name="A")) + + result = df.groupby("A").first() + tm.assert_frame_equal(result, expected) + + result = df.groupby("A").sum() + tm.assert_frame_equal(result, expected) + + +def test_cython_grouper_series_bug_noncontig(): + arr = np.empty((100, 100)) + arr.fill(np.nan) + obj = Series(arr[:, 0]) + inds = np.tile(range(10), 10) + + result = obj.groupby(inds).agg(Series.median) + assert result.isna().all() + + +def test_series_grouper_noncontig_index(): + index = Index(["a" * 10] * 100) + + values = Series(np.random.default_rng(2).standard_normal(50), index=index[::2]) + labels = np.random.default_rng(2).integers(0, 5, 50) + + # it works! + grouped = values.groupby(labels) + + # accessing the index elements causes segfault + f = lambda x: len(set(map(id, x.index))) + grouped.agg(f) + + +def test_convert_objects_leave_decimal_alone(): + s = Series(range(5)) + labels = np.array(["a", "b", "c", "d", "e"], dtype="O") + + def convert_fast(x): + return Decimal(str(x.mean())) + + def convert_force_pure(x): + # base will be length 0 + assert len(x.values.base) > 0 + return Decimal(str(x.mean())) + + grouped = s.groupby(labels) + + result = grouped.agg(convert_fast) + assert result.dtype == np.object_ + assert isinstance(result.iloc[0], Decimal) + + result = grouped.agg(convert_force_pure) + assert result.dtype == np.object_ + assert isinstance(result.iloc[0], Decimal) + + +def test_groupby_dtype_inference_empty(): + # GH 6733 + df = DataFrame({"x": [], "range": np.arange(0, dtype="int64")}) + assert df["x"].dtype == np.float64 + + result = df.groupby("x").first() + exp_index = Index([], name="x", dtype=np.float64) + expected = DataFrame({"range": Series([], index=exp_index, dtype="int64")}) + tm.assert_frame_equal(result, expected, by_blocks=True) + + +def test_groupby_unit64_float_conversion(): + # GH: 30859 groupby converts unit64 to floats sometimes + df = DataFrame({"first": [1], "second": [1], "value": [16148277970000000000]}) + result = df.groupby(["first", "second"])["value"].max() + expected = Series( + [16148277970000000000], + MultiIndex.from_product([[1], [1]], names=["first", "second"]), + name="value", + ) + tm.assert_series_equal(result, expected) + + +def test_groupby_list_infer_array_like(df): + result = df.groupby(list(df["A"])).mean(numeric_only=True) + expected = df.groupby(df["A"]).mean(numeric_only=True) + tm.assert_frame_equal(result, expected, check_names=False) + + with pytest.raises(KeyError, match=r"^'foo'$"): + df.groupby(list(df["A"][:-1])) + + # pathological case of ambiguity + df = DataFrame( + { + "foo": [0, 1], + "bar": [3, 4], + "val": np.random.default_rng(2).standard_normal(2), + } + ) + + result = df.groupby(["foo", "bar"]).mean() + expected = df.groupby([df["foo"], df["bar"]]).mean()[["val"]] + + +def test_groupby_keys_same_size_as_index(): + # GH 11185 + freq = "s" + index = date_range( + start=Timestamp("2015-09-29T11:34:44-0700"), periods=2, freq=freq + ) + df = DataFrame([["A", 10], ["B", 15]], columns=["metric", "values"], index=index) + result = df.groupby([Grouper(level=0, freq=freq), "metric"]).mean() + expected = df.set_index([df.index, "metric"]).astype(float) + + tm.assert_frame_equal(result, expected) + + +def test_groupby_one_row(): + # GH 11741 + msg = r"^'Z'$" + df1 = DataFrame( + np.random.default_rng(2).standard_normal((1, 4)), columns=list("ABCD") + ) + with pytest.raises(KeyError, match=msg): + df1.groupby("Z") + df2 = DataFrame( + np.random.default_rng(2).standard_normal((2, 4)), columns=list("ABCD") + ) + with pytest.raises(KeyError, match=msg): + df2.groupby("Z") + + +def test_groupby_nat_exclude(): + # GH 6992 + df = DataFrame( + { + "values": np.random.default_rng(2).standard_normal(8), + "dt": [ + np.nan, + Timestamp("2013-01-01"), + np.nan, + Timestamp("2013-02-01"), + np.nan, + Timestamp("2013-02-01"), + np.nan, + Timestamp("2013-01-01"), + ], + "str": [np.nan, "a", np.nan, "a", np.nan, "a", np.nan, "b"], + } + ) + grouped = df.groupby("dt") + + expected = [Index([1, 7]), Index([3, 5])] + keys = sorted(grouped.groups.keys()) + assert len(keys) == 2 + for k, e in zip(keys, expected): + # grouped.groups keys are np.datetime64 with system tz + # not to be affected by tz, only compare values + tm.assert_index_equal(grouped.groups[k], e) + + # confirm obj is not filtered + tm.assert_frame_equal(grouped._grouper.groupings[0].obj, df) + assert grouped.ngroups == 2 + + expected = { + Timestamp("2013-01-01 00:00:00"): np.array([1, 7], dtype=np.intp), + Timestamp("2013-02-01 00:00:00"): np.array([3, 5], dtype=np.intp), + } + + for k in grouped.indices: + tm.assert_numpy_array_equal(grouped.indices[k], expected[k]) + + tm.assert_frame_equal(grouped.get_group(Timestamp("2013-01-01")), df.iloc[[1, 7]]) + tm.assert_frame_equal(grouped.get_group(Timestamp("2013-02-01")), df.iloc[[3, 5]]) + + with pytest.raises(KeyError, match=r"^NaT$"): + grouped.get_group(pd.NaT) + + nan_df = DataFrame( + {"nan": [np.nan, np.nan, np.nan], "nat": [pd.NaT, pd.NaT, pd.NaT]} + ) + assert nan_df["nan"].dtype == "float64" + assert nan_df["nat"].dtype == "datetime64[ns]" + + for key in ["nan", "nat"]: + grouped = nan_df.groupby(key) + assert grouped.groups == {} + assert grouped.ngroups == 0 + assert grouped.indices == {} + with pytest.raises(KeyError, match=r"^nan$"): + grouped.get_group(np.nan) + with pytest.raises(KeyError, match=r"^NaT$"): + grouped.get_group(pd.NaT) + + +def test_groupby_two_group_keys_all_nan(): + # GH #36842: Grouping over two group keys shouldn't raise an error + df = DataFrame({"a": [np.nan, np.nan], "b": [np.nan, np.nan], "c": [1, 2]}) + result = df.groupby(["a", "b"]).indices + assert result == {} + + +def test_groupby_2d_malformed(): + d = DataFrame(index=range(2)) + d["group"] = ["g1", "g2"] + d["zeros"] = [0, 0] + d["ones"] = [1, 1] + d["label"] = ["l1", "l2"] + tmp = d.groupby(["group"]).mean(numeric_only=True) + res_values = np.array([[0.0, 1.0], [0.0, 1.0]]) + tm.assert_index_equal(tmp.columns, Index(["zeros", "ones"])) + tm.assert_numpy_array_equal(tmp.values, res_values) + + +def test_int32_overflow(): + B = np.concatenate((np.arange(10000), np.arange(10000), np.arange(5000))) + A = np.arange(25000) + df = DataFrame( + { + "A": A, + "B": B, + "C": A, + "D": B, + "E": np.random.default_rng(2).standard_normal(25000), + } + ) + + left = df.groupby(["A", "B", "C", "D"]).sum() + right = df.groupby(["D", "C", "B", "A"]).sum() + assert len(left) == len(right) + + +def test_groupby_sort_multi(): + df = DataFrame( + { + "a": ["foo", "bar", "baz"], + "b": [3, 2, 1], + "c": [0, 1, 2], + "d": np.random.default_rng(2).standard_normal(3), + } + ) + + tups = [tuple(row) for row in df[["a", "b", "c"]].values] + tups = com.asarray_tuplesafe(tups) + result = df.groupby(["a", "b", "c"], sort=True).sum() + tm.assert_numpy_array_equal(result.index.values, tups[[1, 2, 0]]) + + tups = [tuple(row) for row in df[["c", "a", "b"]].values] + tups = com.asarray_tuplesafe(tups) + result = df.groupby(["c", "a", "b"], sort=True).sum() + tm.assert_numpy_array_equal(result.index.values, tups) + + tups = [tuple(x) for x in df[["b", "c", "a"]].values] + tups = com.asarray_tuplesafe(tups) + result = df.groupby(["b", "c", "a"], sort=True).sum() + tm.assert_numpy_array_equal(result.index.values, tups[[2, 1, 0]]) + + df = DataFrame( + { + "a": [0, 1, 2, 0, 1, 2], + "b": [0, 0, 0, 1, 1, 1], + "d": np.random.default_rng(2).standard_normal(6), + } + ) + grouped = df.groupby(["a", "b"])["d"] + result = grouped.sum() + + def _check_groupby(df, result, keys, field, f=lambda x: x.sum()): + tups = [tuple(row) for row in df[keys].values] + tups = com.asarray_tuplesafe(tups) + expected = f(df.groupby(tups)[field]) + for k, v in expected.items(): + assert result[k] == v + + _check_groupby(df, result, ["a", "b"], "d") + + +def test_dont_clobber_name_column(): + df = DataFrame( + {"key": ["a", "a", "a", "b", "b", "b"], "name": ["foo", "bar", "baz"] * 2} + ) + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = df.groupby("key", group_keys=False).apply(lambda x: x) + tm.assert_frame_equal(result, df) + + +def test_skip_group_keys(): + tsf = DataFrame( + np.random.default_rng(2).standard_normal((10, 4)), + columns=Index(list("ABCD"), dtype=object), + index=date_range("2000-01-01", periods=10, freq="B"), + ) + + grouped = tsf.groupby(lambda x: x.month, group_keys=False) + result = grouped.apply(lambda x: x.sort_values(by="A")[:3]) + + pieces = [group.sort_values(by="A")[:3] for key, group in grouped] + + expected = pd.concat(pieces) + tm.assert_frame_equal(result, expected) + + grouped = tsf["A"].groupby(lambda x: x.month, group_keys=False) + result = grouped.apply(lambda x: x.sort_values()[:3]) + + pieces = [group.sort_values()[:3] for key, group in grouped] + + expected = pd.concat(pieces) + tm.assert_series_equal(result, expected) + + +def test_no_nonsense_name(float_frame): + # GH #995 + s = float_frame["C"].copy() + s.name = None + + result = s.groupby(float_frame["A"]).agg("sum") + assert result.name is None + + +def test_multifunc_sum_bug(): + # GH #1065 + x = DataFrame(np.arange(9).reshape(3, 3)) + x["test"] = 0 + x["fl"] = [1.3, 1.5, 1.6] + + grouped = x.groupby("test") + result = grouped.agg({"fl": "sum", 2: "size"}) + assert result["fl"].dtype == np.float64 + + +def test_handle_dict_return_value(df): + def f(group): + return {"max": group.max(), "min": group.min()} + + def g(group): + return Series({"max": group.max(), "min": group.min()}) + + result = df.groupby("A")["C"].apply(f) + expected = df.groupby("A")["C"].apply(g) + + assert isinstance(result, Series) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("grouper", ["A", ["A", "B"]]) +def test_set_group_name(df, grouper, using_infer_string): + def f(group): + assert group.name is not None + return group + + def freduce(group): + assert group.name is not None + if using_infer_string and grouper == "A" and is_string_dtype(group.dtype): + with pytest.raises(TypeError, match="does not support"): + group.sum() + else: + return group.sum() + + def freducex(x): + return freduce(x) + + grouped = df.groupby(grouper, group_keys=False) + + # make sure all these work + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + grouped.apply(f) + grouped.aggregate(freduce) + grouped.aggregate({"C": freduce, "D": freduce}) + grouped.transform(f) + + grouped["C"].apply(f) + grouped["C"].aggregate(freduce) + grouped["C"].aggregate([freduce, freducex]) + grouped["C"].transform(f) + + +def test_group_name_available_in_inference_pass(): + # gh-15062 + df = DataFrame({"a": [0, 0, 1, 1, 2, 2], "b": np.arange(6)}) + + names = [] + + def f(group): + names.append(group.name) + return group.copy() + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + df.groupby("a", sort=False, group_keys=False).apply(f) + + expected_names = [0, 1, 2] + assert names == expected_names + + +def test_no_dummy_key_names(df): + # see gh-1291 + result = df.groupby(df["A"].values).sum() + assert result.index.name is None + + result = df.groupby([df["A"].values, df["B"].values]).sum() + assert result.index.names == (None, None) + + +def test_groupby_sort_multiindex_series(): + # series multiindex groupby sort argument was not being passed through + # _compress_group_index + # GH 9444 + index = MultiIndex( + levels=[[1, 2], [1, 2]], + codes=[[0, 0, 0, 0, 1, 1], [1, 1, 0, 0, 0, 0]], + names=["a", "b"], + ) + mseries = Series([0, 1, 2, 3, 4, 5], index=index) + index = MultiIndex( + levels=[[1, 2], [1, 2]], codes=[[0, 0, 1], [1, 0, 0]], names=["a", "b"] + ) + mseries_result = Series([0, 2, 4], index=index) + + result = mseries.groupby(level=["a", "b"], sort=False).first() + tm.assert_series_equal(result, mseries_result) + result = mseries.groupby(level=["a", "b"], sort=True).first() + tm.assert_series_equal(result, mseries_result.sort_index()) + + +def test_groupby_reindex_inside_function(): + periods = 1000 + ind = date_range(start="2012/1/1", freq="5min", periods=periods) + df = DataFrame({"high": np.arange(periods), "low": np.arange(periods)}, index=ind) + + def agg_before(func, fix=False): + """ + Run an aggregate func on the subset of data. + """ + + def _func(data): + d = data.loc[data.index.map(lambda x: x.hour < 11)].dropna() + if fix: + data[data.index[0]] + if len(d) == 0: + return None + return func(d) + + return _func + + grouped = df.groupby(lambda x: datetime(x.year, x.month, x.day)) + closure_bad = grouped.agg({"high": agg_before(np.max)}) + closure_good = grouped.agg({"high": agg_before(np.max, True)}) + + tm.assert_frame_equal(closure_bad, closure_good) + + +def test_groupby_multiindex_missing_pair(): + # GH9049 + df = DataFrame( + { + "group1": ["a", "a", "a", "b"], + "group2": ["c", "c", "d", "c"], + "value": [1, 1, 1, 5], + } + ) + df = df.set_index(["group1", "group2"]) + df_grouped = df.groupby(level=["group1", "group2"], sort=True) + + res = df_grouped.agg("sum") + idx = MultiIndex.from_tuples( + [("a", "c"), ("a", "d"), ("b", "c")], names=["group1", "group2"] + ) + exp = DataFrame([[2], [1], [5]], index=idx, columns=["value"]) + + tm.assert_frame_equal(res, exp) + + +def test_groupby_multiindex_not_lexsorted(): + # GH 11640 + + # define the lexsorted version + lexsorted_mi = MultiIndex.from_tuples( + [("a", ""), ("b1", "c1"), ("b2", "c2")], names=["b", "c"] + ) + lexsorted_df = DataFrame([[1, 3, 4]], columns=lexsorted_mi) + assert lexsorted_df.columns._is_lexsorted() + + # define the non-lexsorted version + not_lexsorted_df = DataFrame( + columns=["a", "b", "c", "d"], data=[[1, "b1", "c1", 3], [1, "b2", "c2", 4]] + ) + not_lexsorted_df = not_lexsorted_df.pivot_table( + index="a", columns=["b", "c"], values="d" + ) + not_lexsorted_df = not_lexsorted_df.reset_index() + assert not not_lexsorted_df.columns._is_lexsorted() + + expected = lexsorted_df.groupby("a").mean() + with tm.assert_produces_warning(PerformanceWarning): + result = not_lexsorted_df.groupby("a").mean() + tm.assert_frame_equal(expected, result) + + # a transforming function should work regardless of sort + # GH 14776 + df = DataFrame( + {"x": ["a", "a", "b", "a"], "y": [1, 1, 2, 2], "z": [1, 2, 3, 4]} + ).set_index(["x", "y"]) + assert not df.index._is_lexsorted() + + for level in [0, 1, [0, 1]]: + for sort in [False, True]: + result = df.groupby(level=level, sort=sort, group_keys=False).apply( + DataFrame.drop_duplicates + ) + expected = df + tm.assert_frame_equal(expected, result) + + result = ( + df.sort_index() + .groupby(level=level, sort=sort, group_keys=False) + .apply(DataFrame.drop_duplicates) + ) + expected = df.sort_index() + tm.assert_frame_equal(expected, result) + + +def test_index_label_overlaps_location(): + # checking we don't have any label/location confusion in the + # wake of GH5375 + df = DataFrame(list("ABCDE"), index=[2, 0, 2, 1, 1]) + g = df.groupby(list("ababb")) + actual = g.filter(lambda x: len(x) > 2) + expected = df.iloc[[1, 3, 4]] + tm.assert_frame_equal(actual, expected) + + ser = df[0] + g = ser.groupby(list("ababb")) + actual = g.filter(lambda x: len(x) > 2) + expected = ser.take([1, 3, 4]) + tm.assert_series_equal(actual, expected) + + # and again, with a generic Index of floats + df.index = df.index.astype(float) + g = df.groupby(list("ababb")) + actual = g.filter(lambda x: len(x) > 2) + expected = df.iloc[[1, 3, 4]] + tm.assert_frame_equal(actual, expected) + + ser = df[0] + g = ser.groupby(list("ababb")) + actual = g.filter(lambda x: len(x) > 2) + expected = ser.take([1, 3, 4]) + tm.assert_series_equal(actual, expected) + + +def test_transform_doesnt_clobber_ints(): + # GH 7972 + n = 6 + x = np.arange(n) + df = DataFrame({"a": x // 2, "b": 2.0 * x, "c": 3.0 * x}) + df2 = DataFrame({"a": x // 2 * 1.0, "b": 2.0 * x, "c": 3.0 * x}) + + gb = df.groupby("a") + result = gb.transform("mean") + + gb2 = df2.groupby("a") + expected = gb2.transform("mean") + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "sort_column", + ["ints", "floats", "strings", ["ints", "floats"], ["ints", "strings"]], +) +@pytest.mark.parametrize( + "group_column", ["int_groups", "string_groups", ["int_groups", "string_groups"]] +) +def test_groupby_preserves_sort(sort_column, group_column): + # Test to ensure that groupby always preserves sort order of original + # object. Issue #8588 and #9651 + + df = DataFrame( + { + "int_groups": [3, 1, 0, 1, 0, 3, 3, 3], + "string_groups": ["z", "a", "z", "a", "a", "g", "g", "g"], + "ints": [8, 7, 4, 5, 2, 9, 1, 1], + "floats": [2.3, 5.3, 6.2, -2.4, 2.2, 1.1, 1.1, 5], + "strings": ["z", "d", "a", "e", "word", "word2", "42", "47"], + } + ) + + # Try sorting on different types and with different group types + + df = df.sort_values(by=sort_column) + g = df.groupby(group_column) + + def test_sort(x): + tm.assert_frame_equal(x, x.sort_values(by=sort_column)) + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + g.apply(test_sort) + + +def test_pivot_table_values_key_error(): + # This test is designed to replicate the error in issue #14938 + df = DataFrame( + { + "eventDate": date_range(datetime.today(), periods=20, freq="ME").tolist(), + "thename": range(20), + } + ) + + df["year"] = df.set_index("eventDate").index.year + df["month"] = df.set_index("eventDate").index.month + + with pytest.raises(KeyError, match="'badname'"): + df.reset_index().pivot_table( + index="year", columns="month", values="badname", aggfunc="count" + ) + + +@pytest.mark.parametrize("columns", ["C", ["C"]]) +@pytest.mark.parametrize("keys", [["A"], ["A", "B"]]) +@pytest.mark.parametrize( + "values", + [ + [True], + [0], + [0.0], + ["a"], + Categorical([0]), + [to_datetime(0)], + date_range(0, 1, 1, tz="US/Eastern"), + pd.period_range("2016-01-01", periods=3, freq="D"), + pd.array([0], dtype="Int64"), + pd.array([0], dtype="Float64"), + pd.array([False], dtype="boolean"), + ], + ids=[ + "bool", + "int", + "float", + "str", + "cat", + "dt64", + "dt64tz", + "period", + "Int64", + "Float64", + "boolean", + ], +) +@pytest.mark.parametrize("method", ["attr", "agg", "apply"]) +@pytest.mark.parametrize( + "op", ["idxmax", "idxmin", "min", "max", "sum", "prod", "skew"] +) +def test_empty_groupby( + columns, keys, values, method, op, using_array_manager, dropna, using_infer_string +): + # GH8093 & GH26411 + override_dtype = None + + if isinstance(values, BooleanArray) and op in ["sum", "prod"]: + # We expect to get Int64 back for these + override_dtype = "Int64" + + if isinstance(values[0], bool) and op in ("prod", "sum"): + # sum/product of bools is an integer + override_dtype = "int64" + + df = DataFrame({"A": values, "B": values, "C": values}, columns=list("ABC")) + + if hasattr(values, "dtype"): + # check that we did the construction right + assert (df.dtypes == values.dtype).all() + + df = df.iloc[:0] + + gb = df.groupby(keys, group_keys=False, dropna=dropna, observed=False)[columns] + + def get_result(**kwargs): + if method == "attr": + return getattr(gb, op)(**kwargs) + else: + return getattr(gb, method)(op, **kwargs) + + def get_categorical_invalid_expected(): + # Categorical is special without 'observed=True', we get an NaN entry + # corresponding to the unobserved group. If we passed observed=True + # to groupby, expected would just be 'df.set_index(keys)[columns]' + # as below + lev = Categorical([0], dtype=values.dtype) + if len(keys) != 1: + idx = MultiIndex.from_product([lev, lev], names=keys) + else: + # all columns are dropped, but we end up with one row + # Categorical is special without 'observed=True' + idx = Index(lev, name=keys[0]) + + if using_infer_string: + columns = Index([], dtype="string[pyarrow_numpy]") + else: + columns = [] + expected = DataFrame([], columns=columns, index=idx) + return expected + + is_per = isinstance(df.dtypes.iloc[0], pd.PeriodDtype) + is_dt64 = df.dtypes.iloc[0].kind == "M" + is_cat = isinstance(values, Categorical) + + if ( + isinstance(values, Categorical) + and not values.ordered + and op in ["min", "max", "idxmin", "idxmax"] + ): + if op in ["min", "max"]: + msg = f"Cannot perform {op} with non-ordered Categorical" + klass = TypeError + else: + msg = f"Can't get {op} of an empty group due to unobserved categories" + klass = ValueError + with pytest.raises(klass, match=msg): + get_result() + + if op in ["min", "max", "idxmin", "idxmax"] and isinstance(columns, list): + # i.e. DataframeGroupBy, not SeriesGroupBy + result = get_result(numeric_only=True) + expected = get_categorical_invalid_expected() + tm.assert_equal(result, expected) + return + + if op in ["prod", "sum", "skew"]: + # ops that require more than just ordered-ness + if is_dt64 or is_cat or is_per: + # GH#41291 + # datetime64 -> prod and sum are invalid + if is_dt64: + msg = "datetime64 type does not support" + elif is_per: + msg = "Period type does not support" + else: + msg = "category type does not support" + if op == "skew": + msg = "|".join([msg, "does not support reduction 'skew'"]) + with pytest.raises(TypeError, match=msg): + get_result() + + if not isinstance(columns, list): + # i.e. SeriesGroupBy + return + elif op == "skew": + # TODO: test the numeric_only=True case + return + else: + # i.e. op in ["prod", "sum"]: + # i.e. DataFrameGroupBy + # ops that require more than just ordered-ness + # GH#41291 + result = get_result(numeric_only=True) + + # with numeric_only=True, these are dropped, and we get + # an empty DataFrame back + expected = df.set_index(keys)[[]] + if is_cat: + expected = get_categorical_invalid_expected() + tm.assert_equal(result, expected) + return + + result = get_result() + expected = df.set_index(keys)[columns] + if op in ["idxmax", "idxmin"]: + expected = expected.astype(df.index.dtype) + if override_dtype is not None: + expected = expected.astype(override_dtype) + if len(keys) == 1: + expected.index.name = keys[0] + tm.assert_equal(result, expected) + + +def test_empty_groupby_apply_nonunique_columns(): + # GH#44417 + df = DataFrame(np.random.default_rng(2).standard_normal((0, 4))) + df[3] = df[3].astype(np.int64) + df.columns = [0, 1, 2, 0] + gb = df.groupby(df[1], group_keys=False) + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + res = gb.apply(lambda x: x) + assert (res.dtypes == df.dtypes).all() + + +def test_tuple_as_grouping(): + # https://github.com/pandas-dev/pandas/issues/18314 + df = DataFrame( + { + ("a", "b"): [1, 1, 1, 1], + "a": [2, 2, 2, 2], + "b": [2, 2, 2, 2], + "c": [1, 1, 1, 1], + } + ) + + with pytest.raises(KeyError, match=r"('a', 'b')"): + df[["a", "b", "c"]].groupby(("a", "b")) + + result = df.groupby(("a", "b"))["c"].sum() + expected = Series([4], name="c", index=Index([1], name=("a", "b"))) + tm.assert_series_equal(result, expected) + + +def test_tuple_correct_keyerror(): + # https://github.com/pandas-dev/pandas/issues/18798 + df = DataFrame(1, index=range(3), columns=MultiIndex.from_product([[1, 2], [3, 4]])) + with pytest.raises(KeyError, match=r"^\(7, 8\)$"): + df.groupby((7, 8)).mean() + + +def test_groupby_agg_ohlc_non_first(): + # GH 21716 + df = DataFrame( + [[1], [1]], + columns=Index(["foo"], name="mycols"), + index=date_range("2018-01-01", periods=2, freq="D", name="dti"), + ) + + expected = DataFrame( + [[1, 1, 1, 1, 1], [1, 1, 1, 1, 1]], + columns=MultiIndex.from_tuples( + ( + ("foo", "sum", "foo"), + ("foo", "ohlc", "open"), + ("foo", "ohlc", "high"), + ("foo", "ohlc", "low"), + ("foo", "ohlc", "close"), + ), + names=["mycols", None, None], + ), + index=date_range("2018-01-01", periods=2, freq="D", name="dti"), + ) + + result = df.groupby(Grouper(freq="D")).agg(["sum", "ohlc"]) + + tm.assert_frame_equal(result, expected) + + +def test_groupby_multiindex_nat(): + # GH 9236 + values = [ + (pd.NaT, "a"), + (datetime(2012, 1, 2), "a"), + (datetime(2012, 1, 2), "b"), + (datetime(2012, 1, 3), "a"), + ] + mi = MultiIndex.from_tuples(values, names=["date", None]) + ser = Series([3, 2, 2.5, 4], index=mi) + + result = ser.groupby(level=1).mean() + expected = Series([3.0, 2.5], index=["a", "b"]) + tm.assert_series_equal(result, expected) + + +def test_groupby_empty_list_raises(): + # GH 5289 + values = zip(range(10), range(10)) + df = DataFrame(values, columns=["apple", "b"]) + msg = "Grouper and axis must be same length" + with pytest.raises(ValueError, match=msg): + df.groupby([[]]) + + +def test_groupby_multiindex_series_keys_len_equal_group_axis(): + # GH 25704 + index_array = [["x", "x"], ["a", "b"], ["k", "k"]] + index_names = ["first", "second", "third"] + ri = MultiIndex.from_arrays(index_array, names=index_names) + s = Series(data=[1, 2], index=ri) + result = s.groupby(["first", "third"]).sum() + + index_array = [["x"], ["k"]] + index_names = ["first", "third"] + ei = MultiIndex.from_arrays(index_array, names=index_names) + expected = Series([3], index=ei) + + tm.assert_series_equal(result, expected) + + +def test_groupby_groups_in_BaseGrouper(): + # GH 26326 + # Test if DataFrame grouped with a pandas.Grouper has correct groups + mi = MultiIndex.from_product([["A", "B"], ["C", "D"]], names=["alpha", "beta"]) + df = DataFrame({"foo": [1, 2, 1, 2], "bar": [1, 2, 3, 4]}, index=mi) + result = df.groupby([Grouper(level="alpha"), "beta"]) + expected = df.groupby(["alpha", "beta"]) + assert result.groups == expected.groups + + result = df.groupby(["beta", Grouper(level="alpha")]) + expected = df.groupby(["beta", "alpha"]) + assert result.groups == expected.groups + + +@pytest.mark.parametrize("group_name", ["x", ["x"]]) +def test_groupby_axis_1(group_name): + # GH 27614 + df = DataFrame( + np.arange(12).reshape(3, 4), index=[0, 1, 0], columns=[10, 20, 10, 20] + ) + df.index.name = "y" + df.columns.name = "x" + + depr_msg = "DataFrame.groupby with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=depr_msg): + gb = df.groupby(group_name, axis=1) + + results = gb.sum() + expected = df.T.groupby(group_name).sum().T + tm.assert_frame_equal(results, expected) + + # test on MI column + iterables = [["bar", "baz", "foo"], ["one", "two"]] + mi = MultiIndex.from_product(iterables=iterables, names=["x", "x1"]) + df = DataFrame(np.arange(18).reshape(3, 6), index=[0, 1, 0], columns=mi) + with tm.assert_produces_warning(FutureWarning, match=depr_msg): + gb = df.groupby(group_name, axis=1) + results = gb.sum() + expected = df.T.groupby(group_name).sum().T + tm.assert_frame_equal(results, expected) + + +@pytest.mark.parametrize( + "op, expected", + [ + ( + "shift", + { + "time": [ + None, + None, + Timestamp("2019-01-01 12:00:00"), + Timestamp("2019-01-01 12:30:00"), + None, + None, + ] + }, + ), + ( + "bfill", + { + "time": [ + Timestamp("2019-01-01 12:00:00"), + Timestamp("2019-01-01 12:30:00"), + Timestamp("2019-01-01 14:00:00"), + Timestamp("2019-01-01 14:30:00"), + Timestamp("2019-01-01 14:00:00"), + Timestamp("2019-01-01 14:30:00"), + ] + }, + ), + ( + "ffill", + { + "time": [ + Timestamp("2019-01-01 12:00:00"), + Timestamp("2019-01-01 12:30:00"), + Timestamp("2019-01-01 12:00:00"), + Timestamp("2019-01-01 12:30:00"), + Timestamp("2019-01-01 14:00:00"), + Timestamp("2019-01-01 14:30:00"), + ] + }, + ), + ], +) +def test_shift_bfill_ffill_tz(tz_naive_fixture, op, expected): + # GH19995, GH27992: Check that timezone does not drop in shift, bfill, and ffill + tz = tz_naive_fixture + data = { + "id": ["A", "B", "A", "B", "A", "B"], + "time": [ + Timestamp("2019-01-01 12:00:00"), + Timestamp("2019-01-01 12:30:00"), + None, + None, + Timestamp("2019-01-01 14:00:00"), + Timestamp("2019-01-01 14:30:00"), + ], + } + df = DataFrame(data).assign(time=lambda x: x.time.dt.tz_localize(tz)) + + grouped = df.groupby("id") + result = getattr(grouped, op)() + expected = DataFrame(expected).assign(time=lambda x: x.time.dt.tz_localize(tz)) + tm.assert_frame_equal(result, expected) + + +def test_groupby_only_none_group(): + # see GH21624 + # this was crashing with "ValueError: Length of passed values is 1, index implies 0" + df = DataFrame({"g": [None], "x": 1}) + actual = df.groupby("g")["x"].transform("sum") + expected = Series([np.nan], name="x") + + tm.assert_series_equal(actual, expected) + + +def test_groupby_duplicate_index(): + # GH#29189 the groupby call here used to raise + ser = Series([2, 5, 6, 8], index=[2.0, 4.0, 4.0, 5.0]) + gb = ser.groupby(level=0) + + result = gb.mean() + expected = Series([2, 5.5, 8], index=[2.0, 4.0, 5.0]) + tm.assert_series_equal(result, expected) + + +def test_group_on_empty_multiindex(transformation_func, request): + # GH 47787 + # With one row, those are transforms so the schema should be the same + df = DataFrame( + data=[[1, Timestamp("today"), 3, 4]], + columns=["col_1", "col_2", "col_3", "col_4"], + ) + df["col_3"] = df["col_3"].astype(int) + df["col_4"] = df["col_4"].astype(int) + df = df.set_index(["col_1", "col_2"]) + if transformation_func == "fillna": + args = ("ffill",) + else: + args = () + warn = FutureWarning if transformation_func == "fillna" else None + warn_msg = "DataFrameGroupBy.fillna is deprecated" + with tm.assert_produces_warning(warn, match=warn_msg): + result = df.iloc[:0].groupby(["col_1"]).transform(transformation_func, *args) + with tm.assert_produces_warning(warn, match=warn_msg): + expected = df.groupby(["col_1"]).transform(transformation_func, *args).iloc[:0] + if transformation_func in ("diff", "shift"): + expected = expected.astype(int) + tm.assert_equal(result, expected) + + warn_msg = "SeriesGroupBy.fillna is deprecated" + with tm.assert_produces_warning(warn, match=warn_msg): + result = ( + df["col_3"] + .iloc[:0] + .groupby(["col_1"]) + .transform(transformation_func, *args) + ) + warn_msg = "SeriesGroupBy.fillna is deprecated" + with tm.assert_produces_warning(warn, match=warn_msg): + expected = ( + df["col_3"] + .groupby(["col_1"]) + .transform(transformation_func, *args) + .iloc[:0] + ) + if transformation_func in ("diff", "shift"): + expected = expected.astype(int) + tm.assert_equal(result, expected) + + +def test_groupby_crash_on_nunique(axis): + # Fix following 30253 + dti = date_range("2016-01-01", periods=2, name="foo") + df = DataFrame({("A", "B"): [1, 2], ("A", "C"): [1, 3], ("D", "B"): [0, 0]}) + df.columns.names = ("bar", "baz") + df.index = dti + + axis_number = df._get_axis_number(axis) + if not axis_number: + df = df.T + msg = "The 'axis' keyword in DataFrame.groupby is deprecated" + else: + msg = "DataFrame.groupby with axis=1 is deprecated" + + with tm.assert_produces_warning(FutureWarning, match=msg): + gb = df.groupby(axis=axis_number, level=0) + result = gb.nunique() + + expected = DataFrame({"A": [1, 2], "D": [1, 1]}, index=dti) + expected.columns.name = "bar" + if not axis_number: + expected = expected.T + + tm.assert_frame_equal(result, expected) + + if axis_number == 0: + # same thing, but empty columns + with tm.assert_produces_warning(FutureWarning, match=msg): + gb2 = df[[]].groupby(axis=axis_number, level=0) + exp = expected[[]] + else: + # same thing, but empty rows + with tm.assert_produces_warning(FutureWarning, match=msg): + gb2 = df.loc[[]].groupby(axis=axis_number, level=0) + # default for empty when we can't infer a dtype is float64 + exp = expected.loc[[]].astype(np.float64) + + res = gb2.nunique() + tm.assert_frame_equal(res, exp) + + +def test_groupby_list_level(): + # GH 9790 + expected = DataFrame(np.arange(0, 9).reshape(3, 3), dtype=float) + result = expected.groupby(level=[0]).mean() + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "max_seq_items, expected", + [ + (5, "{0: [0], 1: [1], 2: [2], 3: [3], 4: [4]}"), + (4, "{0: [0], 1: [1], 2: [2], 3: [3], ...}"), + (1, "{0: [0], ...}"), + ], +) +def test_groups_repr_truncates(max_seq_items, expected): + # GH 1135 + df = DataFrame(np.random.default_rng(2).standard_normal((5, 1))) + df["a"] = df.index + + with pd.option_context("display.max_seq_items", max_seq_items): + result = df.groupby("a").groups.__repr__() + assert result == expected + + result = df.groupby(np.array(df.a)).groups.__repr__() + assert result == expected + + +def test_group_on_two_row_multiindex_returns_one_tuple_key(): + # GH 18451 + df = DataFrame([{"a": 1, "b": 2, "c": 99}, {"a": 1, "b": 2, "c": 88}]) + df = df.set_index(["a", "b"]) + + grp = df.groupby(["a", "b"]) + result = grp.indices + expected = {(1, 2): np.array([0, 1], dtype=np.int64)} + + assert len(result) == 1 + key = (1, 2) + assert (result[key] == expected[key]).all() + + +@pytest.mark.parametrize( + "klass, attr, value", + [ + (DataFrame, "level", "a"), + (DataFrame, "as_index", False), + (DataFrame, "sort", False), + (DataFrame, "group_keys", False), + (DataFrame, "observed", True), + (DataFrame, "dropna", False), + (Series, "level", "a"), + (Series, "as_index", False), + (Series, "sort", False), + (Series, "group_keys", False), + (Series, "observed", True), + (Series, "dropna", False), + ], +) +def test_subsetting_columns_keeps_attrs(klass, attr, value): + # GH 9959 - When subsetting columns, don't drop attributes + df = DataFrame({"a": [1], "b": [2], "c": [3]}) + if attr != "axis": + df = df.set_index("a") + + expected = df.groupby("a", **{attr: value}) + result = expected[["b"]] if klass is DataFrame else expected["b"] + assert getattr(result, attr) == getattr(expected, attr) + + +def test_subsetting_columns_axis_1(): + # GH 37725 + df = DataFrame({"A": [1], "B": [2], "C": [3]}) + msg = "DataFrame.groupby with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + g = df.groupby([0, 0, 1], axis=1) + match = "Cannot subset columns when using axis=1" + with pytest.raises(ValueError, match=match): + g[["A", "B"]].sum() + + +@pytest.mark.parametrize("func", ["sum", "any", "shift"]) +def test_groupby_column_index_name_lost(func): + # GH: 29764 groupby loses index sometimes + expected = Index(["a"], name="idx") + df = DataFrame([[1]], columns=expected) + df_grouped = df.groupby([1]) + result = getattr(df_grouped, func)().columns + tm.assert_index_equal(result, expected) + + +@pytest.mark.parametrize( + "infer_string", + [ + False, + pytest.param(True, marks=td.skip_if_no("pyarrow")), + ], +) +def test_groupby_duplicate_columns(infer_string): + # GH: 31735 + if infer_string: + pytest.importorskip("pyarrow") + df = DataFrame( + {"A": ["f", "e", "g", "h"], "B": ["a", "b", "c", "d"], "C": [1, 2, 3, 4]} + ).astype(object) + df.columns = ["A", "B", "B"] + with pd.option_context("future.infer_string", infer_string): + result = df.groupby([0, 0, 0, 0]).min() + expected = DataFrame( + [["e", "a", 1]], index=np.array([0]), columns=["A", "B", "B"], dtype=object + ) + tm.assert_frame_equal(result, expected) + + +def test_groupby_series_with_tuple_name(): + # GH 37755 + ser = Series([1, 2, 3, 4], index=[1, 1, 2, 2], name=("a", "a")) + ser.index.name = ("b", "b") + result = ser.groupby(level=0).last() + expected = Series([2, 4], index=[1, 2], name=("a", "a")) + expected.index.name = ("b", "b") + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "func, values", [("sum", [97.0, 98.0]), ("mean", [24.25, 24.5])] +) +def test_groupby_numerical_stability_sum_mean(func, values): + # GH#38778 + data = [1e16, 1e16, 97, 98, -5e15, -5e15, -5e15, -5e15] + df = DataFrame({"group": [1, 2] * 4, "a": data, "b": data}) + result = getattr(df.groupby("group"), func)() + expected = DataFrame({"a": values, "b": values}, index=Index([1, 2], name="group")) + tm.assert_frame_equal(result, expected) + + +def test_groupby_numerical_stability_cumsum(): + # GH#38934 + data = [1e16, 1e16, 97, 98, -5e15, -5e15, -5e15, -5e15] + df = DataFrame({"group": [1, 2] * 4, "a": data, "b": data}) + result = df.groupby("group").cumsum() + exp_data = ( + [1e16] * 2 + [1e16 + 96, 1e16 + 98] + [5e15 + 97, 5e15 + 98] + [97.0, 98.0] + ) + expected = DataFrame({"a": exp_data, "b": exp_data}) + tm.assert_frame_equal(result, expected, check_exact=True) + + +def test_groupby_cumsum_skipna_false(): + # GH#46216 don't propagate np.nan above the diagonal + arr = np.random.default_rng(2).standard_normal((5, 5)) + df = DataFrame(arr) + for i in range(5): + df.iloc[i, i] = np.nan + + df["A"] = 1 + gb = df.groupby("A") + + res = gb.cumsum(skipna=False) + + expected = df[[0, 1, 2, 3, 4]].cumsum(skipna=False) + tm.assert_frame_equal(res, expected) + + +def test_groupby_cumsum_timedelta64(): + # GH#46216 don't ignore is_datetimelike in libgroupby.group_cumsum + dti = date_range("2016-01-01", periods=5) + ser = Series(dti) - dti[0] + ser[2] = pd.NaT + + df = DataFrame({"A": 1, "B": ser}) + gb = df.groupby("A") + + res = gb.cumsum(numeric_only=False, skipna=True) + exp = DataFrame({"B": [ser[0], ser[1], pd.NaT, ser[4], ser[4] * 2]}) + tm.assert_frame_equal(res, exp) + + res = gb.cumsum(numeric_only=False, skipna=False) + exp = DataFrame({"B": [ser[0], ser[1], pd.NaT, pd.NaT, pd.NaT]}) + tm.assert_frame_equal(res, exp) + + +def test_groupby_mean_duplicate_index(rand_series_with_duplicate_datetimeindex): + dups = rand_series_with_duplicate_datetimeindex + result = dups.groupby(level=0).mean() + expected = dups.groupby(dups.index).mean() + tm.assert_series_equal(result, expected) + + +def test_groupby_all_nan_groups_drop(): + # GH 15036 + s = Series([1, 2, 3], [np.nan, np.nan, np.nan]) + result = s.groupby(s.index).sum() + expected = Series([], index=Index([], dtype=np.float64), dtype=np.int64) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("numeric_only", [True, False]) +def test_groupby_empty_multi_column(as_index, numeric_only): + # GH 15106 & GH 41998 + df = DataFrame(data=[], columns=["A", "B", "C"]) + gb = df.groupby(["A", "B"], as_index=as_index) + result = gb.sum(numeric_only=numeric_only) + if as_index: + index = MultiIndex([[], []], [[], []], names=["A", "B"]) + columns = ["C"] if not numeric_only else [] + else: + index = RangeIndex(0) + columns = ["A", "B", "C"] if not numeric_only else ["A", "B"] + expected = DataFrame([], columns=columns, index=index) + tm.assert_frame_equal(result, expected) + + +def test_groupby_aggregation_non_numeric_dtype(): + # GH #43108 + df = DataFrame( + [["M", [1]], ["M", [1]], ["W", [10]], ["W", [20]]], columns=["MW", "v"] + ) + + expected = DataFrame( + { + "v": [[1, 1], [10, 20]], + }, + index=Index(["M", "W"], dtype="object", name="MW"), + ) + + gb = df.groupby(by=["MW"]) + result = gb.sum() + tm.assert_frame_equal(result, expected) + + +def test_groupby_aggregation_multi_non_numeric_dtype(): + # GH #42395 + df = DataFrame( + { + "x": [1, 0, 1, 1, 0], + "y": [Timedelta(i, "days") for i in range(1, 6)], + "z": [Timedelta(i * 10, "days") for i in range(1, 6)], + } + ) + + expected = DataFrame( + { + "y": [Timedelta(i, "days") for i in range(7, 9)], + "z": [Timedelta(i * 10, "days") for i in range(7, 9)], + }, + index=Index([0, 1], dtype="int64", name="x"), + ) + + gb = df.groupby(by=["x"]) + result = gb.sum() + tm.assert_frame_equal(result, expected) + + +def test_groupby_aggregation_numeric_with_non_numeric_dtype(): + # GH #43108 + df = DataFrame( + { + "x": [1, 0, 1, 1, 0], + "y": [Timedelta(i, "days") for i in range(1, 6)], + "z": list(range(1, 6)), + } + ) + + expected = DataFrame( + {"y": [Timedelta(7, "days"), Timedelta(8, "days")], "z": [7, 8]}, + index=Index([0, 1], dtype="int64", name="x"), + ) + + gb = df.groupby(by=["x"]) + result = gb.sum() + tm.assert_frame_equal(result, expected) + + +def test_groupby_filtered_df_std(): + # GH 16174 + dicts = [ + {"filter_col": False, "groupby_col": True, "bool_col": True, "float_col": 10.5}, + {"filter_col": True, "groupby_col": True, "bool_col": True, "float_col": 20.5}, + {"filter_col": True, "groupby_col": True, "bool_col": True, "float_col": 30.5}, + ] + df = DataFrame(dicts) + + df_filter = df[df["filter_col"] == True] # noqa: E712 + dfgb = df_filter.groupby("groupby_col") + result = dfgb.std() + expected = DataFrame( + [[0.0, 0.0, 7.071068]], + columns=["filter_col", "bool_col", "float_col"], + index=Index([True], name="groupby_col"), + ) + tm.assert_frame_equal(result, expected) + + +def test_datetime_categorical_multikey_groupby_indices(): + # GH 26859 + df = DataFrame( + { + "a": Series(list("abc")), + "b": Series( + to_datetime(["2018-01-01", "2018-02-01", "2018-03-01"]), + dtype="category", + ), + "c": Categorical.from_codes([-1, 0, 1], categories=[0, 1]), + } + ) + result = df.groupby(["a", "b"], observed=False).indices + expected = { + ("a", Timestamp("2018-01-01 00:00:00")): np.array([0]), + ("b", Timestamp("2018-02-01 00:00:00")): np.array([1]), + ("c", Timestamp("2018-03-01 00:00:00")): np.array([2]), + } + assert result == expected + + +def test_rolling_wrong_param_min_period(): + # GH34037 + name_l = ["Alice"] * 5 + ["Bob"] * 5 + val_l = [np.nan, np.nan, 1, 2, 3] + [np.nan, 1, 2, 3, 4] + test_df = DataFrame([name_l, val_l]).T + test_df.columns = ["name", "val"] + + result_error_msg = ( + r"^[a-zA-Z._]*\(\) got an unexpected keyword argument 'min_period'" + ) + with pytest.raises(TypeError, match=result_error_msg): + test_df.groupby("name")["val"].rolling(window=2, min_period=1).sum() + + +@pytest.mark.parametrize( + "dtype", + [ + object, + pytest.param("string[pyarrow_numpy]", marks=td.skip_if_no("pyarrow")), + ], +) +def test_by_column_values_with_same_starting_value(dtype): + # GH29635 + df = DataFrame( + { + "Name": ["Thomas", "Thomas", "Thomas John"], + "Credit": [1200, 1300, 900], + "Mood": Series(["sad", "happy", "happy"], dtype=dtype), + } + ) + aggregate_details = {"Mood": Series.mode, "Credit": "sum"} + + result = df.groupby(["Name"]).agg(aggregate_details) + expected_result = DataFrame( + { + "Mood": [["happy", "sad"], "happy"], + "Credit": [2500, 900], + "Name": ["Thomas", "Thomas John"], + } + ).set_index("Name") + + tm.assert_frame_equal(result, expected_result) + + +def test_groupby_none_in_first_mi_level(): + # GH#47348 + arr = [[None, 1, 0, 1], [2, 3, 2, 3]] + ser = Series(1, index=MultiIndex.from_arrays(arr, names=["a", "b"])) + result = ser.groupby(level=[0, 1]).sum() + expected = Series( + [1, 2], MultiIndex.from_tuples([(0.0, 2), (1.0, 3)], names=["a", "b"]) + ) + tm.assert_series_equal(result, expected) + + +def test_groupby_none_column_name(): + # GH#47348 + df = DataFrame({None: [1, 1, 2, 2], "b": [1, 1, 2, 3], "c": [4, 5, 6, 7]}) + result = df.groupby(by=[None]).sum() + expected = DataFrame({"b": [2, 5], "c": [9, 13]}, index=Index([1, 2], name=None)) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("selection", [None, "a", ["a"]]) +def test_single_element_list_grouping(selection): + # GH#42795, GH#53500 + df = DataFrame({"a": [1, 2], "b": [np.nan, 5], "c": [np.nan, 2]}, index=["x", "y"]) + grouped = df.groupby(["a"]) if selection is None else df.groupby(["a"])[selection] + result = [key for key, _ in grouped] + + expected = [(1,), (2,)] + assert result == expected + + +def test_groupby_string_dtype(): + # GH 40148 + df = DataFrame({"str_col": ["a", "b", "c", "a"], "num_col": [1, 2, 3, 2]}) + df["str_col"] = df["str_col"].astype("string") + expected = DataFrame( + { + "str_col": [ + "a", + "b", + "c", + ], + "num_col": [1.5, 2.0, 3.0], + } + ) + expected["str_col"] = expected["str_col"].astype("string") + grouped = df.groupby("str_col", as_index=False) + result = grouped.mean() + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "level_arg, multiindex", [([0], False), ((0,), False), ([0], True), ((0,), True)] +) +def test_single_element_listlike_level_grouping_deprecation(level_arg, multiindex): + # GH 51583 + df = DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]}, index=["x", "y"]) + if multiindex: + df = df.set_index(["a", "b"]) + depr_msg = ( + "Creating a Groupby object with a length-1 list-like " + "level parameter will yield indexes as tuples in a future version. " + "To keep indexes as scalars, create Groupby objects with " + "a scalar level parameter instead." + ) + with tm.assert_produces_warning(FutureWarning, match=depr_msg): + [key for key, _ in df.groupby(level=level_arg)] + + +@pytest.mark.parametrize("func", ["sum", "cumsum", "cumprod", "prod"]) +def test_groupby_avoid_casting_to_float(func): + # GH#37493 + val = 922337203685477580 + df = DataFrame({"a": 1, "b": [val]}) + result = getattr(df.groupby("a"), func)() - val + expected = DataFrame({"b": [0]}, index=Index([1], name="a")) + if func in ["cumsum", "cumprod"]: + expected = expected.reset_index(drop=True) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("func, val", [("sum", 3), ("prod", 2)]) +def test_groupby_sum_support_mask(any_numeric_ea_dtype, func, val): + # GH#37493 + df = DataFrame({"a": 1, "b": [1, 2, pd.NA]}, dtype=any_numeric_ea_dtype) + result = getattr(df.groupby("a"), func)() + expected = DataFrame( + {"b": [val]}, + index=Index([1], name="a", dtype=any_numeric_ea_dtype), + dtype=any_numeric_ea_dtype, + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("val, dtype", [(111, "int"), (222, "uint")]) +def test_groupby_overflow(val, dtype): + # GH#37493 + df = DataFrame({"a": 1, "b": [val, val]}, dtype=f"{dtype}8") + result = df.groupby("a").sum() + expected = DataFrame( + {"b": [val * 2]}, + index=Index([1], name="a", dtype=f"{dtype}8"), + dtype=f"{dtype}64", + ) + tm.assert_frame_equal(result, expected) + + result = df.groupby("a").cumsum() + expected = DataFrame({"b": [val, val * 2]}, dtype=f"{dtype}64") + tm.assert_frame_equal(result, expected) + + result = df.groupby("a").prod() + expected = DataFrame( + {"b": [val * val]}, + index=Index([1], name="a", dtype=f"{dtype}8"), + dtype=f"{dtype}64", + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("skipna, val", [(True, 3), (False, pd.NA)]) +def test_groupby_cumsum_mask(any_numeric_ea_dtype, skipna, val): + # GH#37493 + df = DataFrame({"a": 1, "b": [1, pd.NA, 2]}, dtype=any_numeric_ea_dtype) + result = df.groupby("a").cumsum(skipna=skipna) + expected = DataFrame( + {"b": [1, pd.NA, val]}, + dtype=any_numeric_ea_dtype, + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "val_in, index, val_out", + [ + ( + [1.0, 2.0, 3.0, 4.0, 5.0], + ["foo", "foo", "bar", "baz", "blah"], + [3.0, 4.0, 5.0, 3.0], + ), + ( + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], + ["foo", "foo", "bar", "baz", "blah", "blah"], + [3.0, 4.0, 11.0, 3.0], + ), + ], +) +def test_groupby_index_name_in_index_content(val_in, index, val_out): + # GH 48567 + series = Series(data=val_in, name="values", index=Index(index, name="blah")) + result = series.groupby("blah").sum() + expected = Series( + data=val_out, + name="values", + index=Index(["bar", "baz", "blah", "foo"], name="blah"), + ) + tm.assert_series_equal(result, expected) + + result = series.to_frame().groupby("blah").sum() + expected = expected.to_frame() + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("n", [1, 10, 32, 100, 1000]) +def test_sum_of_booleans(n): + # GH 50347 + df = DataFrame({"groupby_col": 1, "bool": [True] * n}) + df["bool"] = df["bool"].eq(True) + result = df.groupby("groupby_col").sum() + expected = DataFrame({"bool": [n]}, index=Index([1], name="groupby_col")) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.filterwarnings( + "ignore:invalid value encountered in remainder:RuntimeWarning" +) +@pytest.mark.parametrize("method", ["head", "tail", "nth", "first", "last"]) +def test_groupby_method_drop_na(method): + # GH 21755 + df = DataFrame({"A": ["a", np.nan, "b", np.nan, "c"], "B": range(5)}) + + if method == "nth": + result = getattr(df.groupby("A"), method)(n=0) + else: + result = getattr(df.groupby("A"), method)() + + if method in ["first", "last"]: + expected = DataFrame({"B": [0, 2, 4]}).set_index( + Series(["a", "b", "c"], name="A") + ) + else: + expected = DataFrame({"A": ["a", "b", "c"], "B": [0, 2, 4]}, index=[0, 2, 4]) + tm.assert_frame_equal(result, expected) + + +def test_groupby_reduce_period(): + # GH#51040 + pi = pd.period_range("2016-01-01", periods=100, freq="D") + grps = list(range(10)) * 10 + ser = pi.to_series() + gb = ser.groupby(grps) + + with pytest.raises(TypeError, match="Period type does not support sum operations"): + gb.sum() + with pytest.raises( + TypeError, match="Period type does not support cumsum operations" + ): + gb.cumsum() + with pytest.raises(TypeError, match="Period type does not support prod operations"): + gb.prod() + with pytest.raises( + TypeError, match="Period type does not support cumprod operations" + ): + gb.cumprod() + + res = gb.max() + expected = ser[-10:] + expected.index = Index(range(10), dtype=int) + tm.assert_series_equal(res, expected) + + res = gb.min() + expected = ser[:10] + expected.index = Index(range(10), dtype=int) + tm.assert_series_equal(res, expected) + + +def test_obj_with_exclusions_duplicate_columns(): + # GH#50806 + df = DataFrame([[0, 1, 2, 3]]) + df.columns = [0, 1, 2, 0] + gb = df.groupby(df[1]) + result = gb._obj_with_exclusions + expected = df.take([0, 2, 3], axis=1) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("numeric_only", [True, False]) +def test_groupby_numeric_only_std_no_result(numeric_only): + # GH 51080 + dicts_non_numeric = [{"a": "foo", "b": "bar"}, {"a": "car", "b": "dar"}] + df = DataFrame(dicts_non_numeric) + dfgb = df.groupby("a", as_index=False, sort=False) + + if numeric_only: + result = dfgb.std(numeric_only=True) + expected_df = DataFrame(["foo", "car"], columns=["a"]) + tm.assert_frame_equal(result, expected_df) + else: + with pytest.raises( + ValueError, match="could not convert string to float: 'bar'" + ): + dfgb.std(numeric_only=numeric_only) + + +def test_grouping_with_categorical_interval_columns(): + # GH#34164 + df = DataFrame({"x": [0.1, 0.2, 0.3, -0.4, 0.5], "w": ["a", "b", "a", "c", "a"]}) + qq = pd.qcut(df["x"], q=np.linspace(0, 1, 5)) + result = df.groupby([qq, "w"], observed=False)["x"].agg("mean") + categorical_index_level_1 = Categorical( + [ + Interval(-0.401, 0.1, closed="right"), + Interval(0.1, 0.2, closed="right"), + Interval(0.2, 0.3, closed="right"), + Interval(0.3, 0.5, closed="right"), + ], + ordered=True, + ) + index_level_2 = ["a", "b", "c"] + mi = MultiIndex.from_product( + [categorical_index_level_1, index_level_2], names=["x", "w"] + ) + expected = Series( + np.array( + [ + 0.1, + np.nan, + -0.4, + np.nan, + 0.2, + np.nan, + 0.3, + np.nan, + np.nan, + 0.5, + np.nan, + np.nan, + ] + ), + index=mi, + name="x", + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("bug_var", [1, "a"]) +def test_groupby_sum_on_nan_should_return_nan(bug_var): + # GH 24196 + df = DataFrame({"A": [bug_var, bug_var, bug_var, np.nan]}) + dfgb = df.groupby(lambda x: x) + result = dfgb.sum(min_count=1) + + expected_df = DataFrame([bug_var, bug_var, bug_var, None], columns=["A"]) + tm.assert_frame_equal(result, expected_df) + + +@pytest.mark.parametrize( + "method", + [ + "count", + "corr", + "cummax", + "cummin", + "cumprod", + "describe", + "rank", + "quantile", + "diff", + "shift", + "all", + "any", + "idxmin", + "idxmax", + "ffill", + "bfill", + "pct_change", + ], +) +def test_groupby_selection_with_methods(df, method): + # some methods which require DatetimeIndex + rng = date_range("2014", periods=len(df)) + df.index = rng + + g = df.groupby(["A"])[["C"]] + g_exp = df[["C"]].groupby(df["A"]) + # TODO check groupby with > 1 col ? + + res = getattr(g, method)() + exp = getattr(g_exp, method)() + + # should always be frames! + tm.assert_frame_equal(res, exp) + + +def test_groupby_selection_other_methods(df): + # some methods which require DatetimeIndex + rng = date_range("2014", periods=len(df)) + df.columns.name = "foo" + df.index = rng + + g = df.groupby(["A"])[["C"]] + g_exp = df[["C"]].groupby(df["A"]) + + # methods which aren't just .foo() + warn_msg = "DataFrameGroupBy.fillna is deprecated" + with tm.assert_produces_warning(FutureWarning, match=warn_msg): + tm.assert_frame_equal(g.fillna(0), g_exp.fillna(0)) + msg = "DataFrameGroupBy.dtypes is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + tm.assert_frame_equal(g.dtypes, g_exp.dtypes) + tm.assert_frame_equal(g.apply(lambda x: x.sum()), g_exp.apply(lambda x: x.sum())) + + tm.assert_frame_equal(g.resample("D").mean(), g_exp.resample("D").mean()) + tm.assert_frame_equal(g.resample("D").ohlc(), g_exp.resample("D").ohlc()) + + tm.assert_frame_equal( + g.filter(lambda x: len(x) == 3), g_exp.filter(lambda x: len(x) == 3) + ) + + +def test_groupby_with_Time_Grouper(unit): + idx2 = to_datetime( + [ + "2016-08-31 22:08:12.000", + "2016-08-31 22:09:12.200", + "2016-08-31 22:20:12.400", + ] + ).as_unit(unit) + + test_data = DataFrame( + {"quant": [1.0, 1.0, 3.0], "quant2": [1.0, 1.0, 3.0], "time2": idx2} + ) + + time2 = date_range("2016-08-31 22:08:00", periods=13, freq="1min", unit=unit) + expected_output = DataFrame( + { + "time2": time2, + "quant": [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], + "quant2": [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], + } + ) + + gb = test_data.groupby(Grouper(key="time2", freq="1min")) + result = gb.count().reset_index() + + tm.assert_frame_equal(result, expected_output) + + +def test_groupby_series_with_datetimeindex_month_name(): + # GH 48509 + s = Series([0, 1, 0], index=date_range("2022-01-01", periods=3), name="jan") + result = s.groupby(s).count() + expected = Series([2, 1], name="jan") + expected.index.name = "jan" + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("test_series", [True, False]) +@pytest.mark.parametrize( + "kwarg, value, name, warn", + [ + ("by", "a", 1, None), + ("by", ["a"], 1, FutureWarning), + ("by", ["a"], (1,), None), + ("level", 0, 1, None), + ("level", [0], 1, FutureWarning), + ("level", [0], (1,), None), + ], +) +def test_depr_get_group_len_1_list_likes(test_series, kwarg, value, name, warn): + # GH#25971 + obj = DataFrame({"b": [3, 4, 5]}, index=Index([1, 1, 2], name="a")) + if test_series: + obj = obj["b"] + gb = obj.groupby(**{kwarg: value}) + msg = "you will need to pass a length-1 tuple" + with tm.assert_produces_warning(warn, match=msg): + result = gb.get_group(name) + if test_series: + expected = Series([3, 4], index=Index([1, 1], name="a"), name="b") + else: + expected = DataFrame({"b": [3, 4]}, index=Index([1, 1], name="a")) + tm.assert_equal(result, expected) + + +def test_groupby_ngroup_with_nan(): + # GH#50100 + df = DataFrame({"a": Categorical([np.nan]), "b": [1]}) + result = df.groupby(["a", "b"], dropna=False, observed=False).ngroup() + expected = Series([0]) + tm.assert_series_equal(result, expected) + + +def test_get_group_axis_1(): + # GH#54858 + df = DataFrame( + { + "col1": [0, 3, 2, 3], + "col2": [4, 1, 6, 7], + "col3": [3, 8, 2, 10], + "col4": [1, 13, 6, 15], + "col5": [-4, 5, 6, -7], + } + ) + with tm.assert_produces_warning(FutureWarning, match="deprecated"): + grouped = df.groupby(axis=1, by=[1, 2, 3, 2, 1]) + result = grouped.get_group(1) + expected = DataFrame( + { + "col1": [0, 3, 2, 3], + "col5": [-4, 5, 6, -7], + } + ) + tm.assert_frame_equal(result, expected) + + +def test_groupby_ffill_with_duplicated_index(): + # GH#43412 + df = DataFrame({"a": [1, 2, 3, 4, np.nan, np.nan]}, index=[0, 1, 2, 0, 1, 2]) + + result = df.groupby(level=0).ffill() + expected = DataFrame({"a": [1, 2, 3, 4, 2, 3]}, index=[0, 1, 2, 0, 1, 2]) + tm.assert_frame_equal(result, expected, check_dtype=False) + + +@pytest.mark.parametrize("test_series", [True, False]) +def test_decimal_na_sort(test_series): + # GH#54847 + # We catch both TypeError and decimal.InvalidOperation exceptions in safe_sort. + # If this next assert raises, we can just catch TypeError + assert not isinstance(decimal.InvalidOperation, TypeError) + df = DataFrame( + { + "key": [Decimal(1), Decimal(1), None, None], + "value": [Decimal(2), Decimal(3), Decimal(4), Decimal(5)], + } + ) + gb = df.groupby("key", dropna=False) + if test_series: + gb = gb["value"] + result = gb._grouper.result_index + expected = Index([Decimal(1), None], name="key") + tm.assert_index_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_groupby_dropna.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_groupby_dropna.py new file mode 100644 index 0000000000000000000000000000000000000000..9155f2cccf1178e2b107621e7e3b78a5c87e9105 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_groupby_dropna.py @@ -0,0 +1,696 @@ +import numpy as np +import pytest + +from pandas.compat.pyarrow import pa_version_under10p1 + +from pandas.core.dtypes.missing import na_value_for_dtype + +import pandas as pd +import pandas._testing as tm +from pandas.tests.groupby import get_groupby_method_args + + +@pytest.mark.parametrize( + "dropna, tuples, outputs", + [ + ( + True, + [["A", "B"], ["B", "A"]], + {"c": [13.0, 123.23], "d": [13.0, 123.0], "e": [13.0, 1.0]}, + ), + ( + False, + [["A", "B"], ["A", np.nan], ["B", "A"]], + { + "c": [13.0, 12.3, 123.23], + "d": [13.0, 233.0, 123.0], + "e": [13.0, 12.0, 1.0], + }, + ), + ], +) +def test_groupby_dropna_multi_index_dataframe_nan_in_one_group( + dropna, tuples, outputs, nulls_fixture +): + # GH 3729 this is to test that NA is in one group + df_list = [ + ["A", "B", 12, 12, 12], + ["A", nulls_fixture, 12.3, 233.0, 12], + ["B", "A", 123.23, 123, 1], + ["A", "B", 1, 1, 1.0], + ] + df = pd.DataFrame(df_list, columns=["a", "b", "c", "d", "e"]) + grouped = df.groupby(["a", "b"], dropna=dropna).sum() + + mi = pd.MultiIndex.from_tuples(tuples, names=list("ab")) + + # Since right now, by default MI will drop NA from levels when we create MI + # via `from_*`, so we need to add NA for level manually afterwards. + if not dropna: + mi = mi.set_levels(["A", "B", np.nan], level="b") + expected = pd.DataFrame(outputs, index=mi) + + tm.assert_frame_equal(grouped, expected) + + +@pytest.mark.parametrize( + "dropna, tuples, outputs", + [ + ( + True, + [["A", "B"], ["B", "A"]], + {"c": [12.0, 123.23], "d": [12.0, 123.0], "e": [12.0, 1.0]}, + ), + ( + False, + [["A", "B"], ["A", np.nan], ["B", "A"], [np.nan, "B"]], + { + "c": [12.0, 13.3, 123.23, 1.0], + "d": [12.0, 234.0, 123.0, 1.0], + "e": [12.0, 13.0, 1.0, 1.0], + }, + ), + ], +) +def test_groupby_dropna_multi_index_dataframe_nan_in_two_groups( + dropna, tuples, outputs, nulls_fixture, nulls_fixture2 +): + # GH 3729 this is to test that NA in different groups with different representations + df_list = [ + ["A", "B", 12, 12, 12], + ["A", nulls_fixture, 12.3, 233.0, 12], + ["B", "A", 123.23, 123, 1], + [nulls_fixture2, "B", 1, 1, 1.0], + ["A", nulls_fixture2, 1, 1, 1.0], + ] + df = pd.DataFrame(df_list, columns=["a", "b", "c", "d", "e"]) + grouped = df.groupby(["a", "b"], dropna=dropna).sum() + + mi = pd.MultiIndex.from_tuples(tuples, names=list("ab")) + + # Since right now, by default MI will drop NA from levels when we create MI + # via `from_*`, so we need to add NA for level manually afterwards. + if not dropna: + mi = mi.set_levels([["A", "B", np.nan], ["A", "B", np.nan]]) + expected = pd.DataFrame(outputs, index=mi) + + tm.assert_frame_equal(grouped, expected) + + +@pytest.mark.parametrize( + "dropna, idx, outputs", + [ + (True, ["A", "B"], {"b": [123.23, 13.0], "c": [123.0, 13.0], "d": [1.0, 13.0]}), + ( + False, + ["A", "B", np.nan], + { + "b": [123.23, 13.0, 12.3], + "c": [123.0, 13.0, 233.0], + "d": [1.0, 13.0, 12.0], + }, + ), + ], +) +def test_groupby_dropna_normal_index_dataframe(dropna, idx, outputs): + # GH 3729 + df_list = [ + ["B", 12, 12, 12], + [None, 12.3, 233.0, 12], + ["A", 123.23, 123, 1], + ["B", 1, 1, 1.0], + ] + df = pd.DataFrame(df_list, columns=["a", "b", "c", "d"]) + grouped = df.groupby("a", dropna=dropna).sum() + + expected = pd.DataFrame(outputs, index=pd.Index(idx, dtype="object", name="a")) + + tm.assert_frame_equal(grouped, expected) + + +@pytest.mark.parametrize( + "dropna, idx, expected", + [ + (True, ["a", "a", "b", np.nan], pd.Series([3, 3], index=["a", "b"])), + ( + False, + ["a", "a", "b", np.nan], + pd.Series([3, 3, 3], index=["a", "b", np.nan]), + ), + ], +) +def test_groupby_dropna_series_level(dropna, idx, expected): + ser = pd.Series([1, 2, 3, 3], index=idx) + + result = ser.groupby(level=0, dropna=dropna).sum() + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "dropna, expected", + [ + (True, pd.Series([210.0, 350.0], index=["a", "b"], name="Max Speed")), + ( + False, + pd.Series([210.0, 350.0, 20.0], index=["a", "b", np.nan], name="Max Speed"), + ), + ], +) +def test_groupby_dropna_series_by(dropna, expected): + ser = pd.Series( + [390.0, 350.0, 30.0, 20.0], + index=["Falcon", "Falcon", "Parrot", "Parrot"], + name="Max Speed", + ) + + result = ser.groupby(["a", "b", "a", np.nan], dropna=dropna).mean() + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("dropna", (False, True)) +def test_grouper_dropna_propagation(dropna): + # GH 36604 + df = pd.DataFrame({"A": [0, 0, 1, None], "B": [1, 2, 3, None]}) + gb = df.groupby("A", dropna=dropna) + assert gb._grouper.dropna == dropna + + +@pytest.mark.parametrize( + "index", + [ + pd.RangeIndex(0, 4), + list("abcd"), + pd.MultiIndex.from_product([(1, 2), ("R", "B")], names=["num", "col"]), + ], +) +def test_groupby_dataframe_slice_then_transform(dropna, index): + # GH35014 & GH35612 + expected_data = {"B": [2, 2, 1, np.nan if dropna else 1]} + + df = pd.DataFrame({"A": [0, 0, 1, None], "B": [1, 2, 3, None]}, index=index) + gb = df.groupby("A", dropna=dropna) + + result = gb.transform(len) + expected = pd.DataFrame(expected_data, index=index) + tm.assert_frame_equal(result, expected) + + result = gb[["B"]].transform(len) + expected = pd.DataFrame(expected_data, index=index) + tm.assert_frame_equal(result, expected) + + result = gb["B"].transform(len) + expected = pd.Series(expected_data["B"], index=index, name="B") + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "dropna, tuples, outputs", + [ + ( + True, + [["A", "B"], ["B", "A"]], + {"c": [13.0, 123.23], "d": [12.0, 123.0], "e": [1.0, 1.0]}, + ), + ( + False, + [["A", "B"], ["A", np.nan], ["B", "A"]], + { + "c": [13.0, 12.3, 123.23], + "d": [12.0, 233.0, 123.0], + "e": [1.0, 12.0, 1.0], + }, + ), + ], +) +def test_groupby_dropna_multi_index_dataframe_agg(dropna, tuples, outputs): + # GH 3729 + df_list = [ + ["A", "B", 12, 12, 12], + ["A", None, 12.3, 233.0, 12], + ["B", "A", 123.23, 123, 1], + ["A", "B", 1, 1, 1.0], + ] + df = pd.DataFrame(df_list, columns=["a", "b", "c", "d", "e"]) + agg_dict = {"c": "sum", "d": "max", "e": "min"} + grouped = df.groupby(["a", "b"], dropna=dropna).agg(agg_dict) + + mi = pd.MultiIndex.from_tuples(tuples, names=list("ab")) + + # Since right now, by default MI will drop NA from levels when we create MI + # via `from_*`, so we need to add NA for level manually afterwards. + if not dropna: + mi = mi.set_levels(["A", "B", np.nan], level="b") + expected = pd.DataFrame(outputs, index=mi) + + tm.assert_frame_equal(grouped, expected) + + +@pytest.mark.arm_slow +@pytest.mark.parametrize( + "datetime1, datetime2", + [ + (pd.Timestamp("2020-01-01"), pd.Timestamp("2020-02-01")), + (pd.Timedelta("-2 days"), pd.Timedelta("-1 days")), + (pd.Period("2020-01-01"), pd.Period("2020-02-01")), + ], +) +@pytest.mark.parametrize("dropna, values", [(True, [12, 3]), (False, [12, 3, 6])]) +def test_groupby_dropna_datetime_like_data( + dropna, values, datetime1, datetime2, unique_nulls_fixture, unique_nulls_fixture2 +): + # 3729 + df = pd.DataFrame( + { + "values": [1, 2, 3, 4, 5, 6], + "dt": [ + datetime1, + unique_nulls_fixture, + datetime2, + unique_nulls_fixture2, + datetime1, + datetime1, + ], + } + ) + + if dropna: + indexes = [datetime1, datetime2] + else: + indexes = [datetime1, datetime2, np.nan] + + grouped = df.groupby("dt", dropna=dropna).agg({"values": "sum"}) + expected = pd.DataFrame({"values": values}, index=pd.Index(indexes, name="dt")) + + tm.assert_frame_equal(grouped, expected) + + +@pytest.mark.parametrize( + "dropna, data, selected_data, levels", + [ + pytest.param( + False, + {"groups": ["a", "a", "b", np.nan], "values": [10, 10, 20, 30]}, + {"values": [0, 1, 0, 0]}, + ["a", "b", np.nan], + id="dropna_false_has_nan", + ), + pytest.param( + True, + {"groups": ["a", "a", "b", np.nan], "values": [10, 10, 20, 30]}, + {"values": [0, 1, 0]}, + None, + id="dropna_true_has_nan", + ), + pytest.param( + # no nan in "groups"; dropna=True|False should be same. + False, + {"groups": ["a", "a", "b", "c"], "values": [10, 10, 20, 30]}, + {"values": [0, 1, 0, 0]}, + None, + id="dropna_false_no_nan", + ), + pytest.param( + # no nan in "groups"; dropna=True|False should be same. + True, + {"groups": ["a", "a", "b", "c"], "values": [10, 10, 20, 30]}, + {"values": [0, 1, 0, 0]}, + None, + id="dropna_true_no_nan", + ), + ], +) +def test_groupby_apply_with_dropna_for_multi_index(dropna, data, selected_data, levels): + # GH 35889 + + df = pd.DataFrame(data) + gb = df.groupby("groups", dropna=dropna) + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = gb.apply(lambda grp: pd.DataFrame({"values": range(len(grp))})) + + mi_tuples = tuple(zip(data["groups"], selected_data["values"])) + mi = pd.MultiIndex.from_tuples(mi_tuples, names=["groups", None]) + # Since right now, by default MI will drop NA from levels when we create MI + # via `from_*`, so we need to add NA for level manually afterwards. + if not dropna and levels: + mi = mi.set_levels(levels, level="groups") + + expected = pd.DataFrame(selected_data, index=mi) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("input_index", [None, ["a"], ["a", "b"]]) +@pytest.mark.parametrize("keys", [["a"], ["a", "b"]]) +@pytest.mark.parametrize("series", [True, False]) +def test_groupby_dropna_with_multiindex_input(input_index, keys, series): + # GH#46783 + obj = pd.DataFrame( + { + "a": [1, np.nan], + "b": [1, 1], + "c": [2, 3], + } + ) + + expected = obj.set_index(keys) + if series: + expected = expected["c"] + elif input_index == ["a", "b"] and keys == ["a"]: + # Column b should not be aggregated + expected = expected[["c"]] + + if input_index is not None: + obj = obj.set_index(input_index) + gb = obj.groupby(keys, dropna=False) + if series: + gb = gb["c"] + result = gb.sum() + + tm.assert_equal(result, expected) + + +def test_groupby_nan_included(): + # GH 35646 + data = {"group": ["g1", np.nan, "g1", "g2", np.nan], "B": [0, 1, 2, 3, 4]} + df = pd.DataFrame(data) + grouped = df.groupby("group", dropna=False) + result = grouped.indices + dtype = np.intp + expected = { + "g1": np.array([0, 2], dtype=dtype), + "g2": np.array([3], dtype=dtype), + np.nan: np.array([1, 4], dtype=dtype), + } + for result_values, expected_values in zip(result.values(), expected.values()): + tm.assert_numpy_array_equal(result_values, expected_values) + assert np.isnan(list(result.keys())[2]) + assert list(result.keys())[0:2] == ["g1", "g2"] + + +def test_groupby_drop_nan_with_multi_index(): + # GH 39895 + df = pd.DataFrame([[np.nan, 0, 1]], columns=["a", "b", "c"]) + df = df.set_index(["a", "b"]) + result = df.groupby(["a", "b"], dropna=False).first() + expected = df + tm.assert_frame_equal(result, expected) + + +# sequence_index enumerates all strings made up of x, y, z of length 4 +@pytest.mark.parametrize("sequence_index", range(3**4)) +@pytest.mark.parametrize( + "dtype", + [ + None, + "UInt8", + "Int8", + "UInt16", + "Int16", + "UInt32", + "Int32", + "UInt64", + "Int64", + "Float32", + "Int64", + "Float64", + "category", + "string", + pytest.param( + "string[pyarrow]", + marks=pytest.mark.skipif( + pa_version_under10p1, reason="pyarrow is not installed" + ), + ), + "datetime64[ns]", + "period[d]", + "Sparse[float]", + ], +) +@pytest.mark.parametrize("test_series", [True, False]) +def test_no_sort_keep_na(sequence_index, dtype, test_series, as_index): + # GH#46584, GH#48794 + + # Convert sequence_index into a string sequence, e.g. 5 becomes "xxyz" + # This sequence is used for the grouper. + sequence = "".join( + [{0: "x", 1: "y", 2: "z"}[sequence_index // (3**k) % 3] for k in range(4)] + ) + + # Unique values to use for grouper, depends on dtype + if dtype in ("string", "string[pyarrow]"): + uniques = {"x": "x", "y": "y", "z": pd.NA} + elif dtype in ("datetime64[ns]", "period[d]"): + uniques = {"x": "2016-01-01", "y": "2017-01-01", "z": pd.NA} + else: + uniques = {"x": 1, "y": 2, "z": np.nan} + + df = pd.DataFrame( + { + "key": pd.Series([uniques[label] for label in sequence], dtype=dtype), + "a": [0, 1, 2, 3], + } + ) + gb = df.groupby("key", dropna=False, sort=False, as_index=as_index, observed=False) + if test_series: + gb = gb["a"] + result = gb.sum() + + # Manually compute the groupby sum, use the labels "x", "y", and "z" to avoid + # issues with hashing np.nan + summed = {} + for idx, label in enumerate(sequence): + summed[label] = summed.get(label, 0) + idx + if dtype == "category": + index = pd.CategoricalIndex( + [uniques[e] for e in summed], + df["key"].cat.categories, + name="key", + ) + elif isinstance(dtype, str) and dtype.startswith("Sparse"): + index = pd.Index( + pd.array([uniques[label] for label in summed], dtype=dtype), name="key" + ) + else: + index = pd.Index([uniques[label] for label in summed], dtype=dtype, name="key") + expected = pd.Series(summed.values(), index=index, name="a", dtype=None) + if not test_series: + expected = expected.to_frame() + if not as_index: + expected = expected.reset_index() + if dtype is not None and dtype.startswith("Sparse"): + expected["key"] = expected["key"].astype(dtype) + + tm.assert_equal(result, expected) + + +@pytest.mark.parametrize("test_series", [True, False]) +@pytest.mark.parametrize("dtype", [object, None]) +def test_null_is_null_for_dtype( + sort, dtype, nulls_fixture, nulls_fixture2, test_series +): + # GH#48506 - groups should always result in using the null for the dtype + df = pd.DataFrame({"a": [1, 2]}) + groups = pd.Series([nulls_fixture, nulls_fixture2], dtype=dtype) + obj = df["a"] if test_series else df + gb = obj.groupby(groups, dropna=False, sort=sort) + result = gb.sum() + index = pd.Index([na_value_for_dtype(groups.dtype)]) + expected = pd.DataFrame({"a": [3]}, index=index) + if test_series: + tm.assert_series_equal(result, expected["a"]) + else: + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("index_kind", ["range", "single", "multi"]) +def test_categorical_reducers(reduction_func, observed, sort, as_index, index_kind): + # Ensure there is at least one null value by appending to the end + values = np.append(np.random.default_rng(2).choice([1, 2, None], size=19), None) + df = pd.DataFrame( + {"x": pd.Categorical(values, categories=[1, 2, 3]), "y": range(20)} + ) + + # Strategy: Compare to dropna=True by filling null values with a new code + df_filled = df.copy() + df_filled["x"] = pd.Categorical(values, categories=[1, 2, 3, 4]).fillna(4) + + if index_kind == "range": + keys = ["x"] + elif index_kind == "single": + keys = ["x"] + df = df.set_index("x") + df_filled = df_filled.set_index("x") + else: + keys = ["x", "x2"] + df["x2"] = df["x"] + df = df.set_index(["x", "x2"]) + df_filled["x2"] = df_filled["x"] + df_filled = df_filled.set_index(["x", "x2"]) + args = get_groupby_method_args(reduction_func, df) + args_filled = get_groupby_method_args(reduction_func, df_filled) + if reduction_func == "corrwith" and index_kind == "range": + # Don't include the grouping columns so we can call reset_index + args = (args[0].drop(columns=keys),) + args_filled = (args_filled[0].drop(columns=keys),) + + gb_keepna = df.groupby( + keys, dropna=False, observed=observed, sort=sort, as_index=as_index + ) + + if not observed and reduction_func in ["idxmin", "idxmax"]: + with pytest.raises( + ValueError, match="empty group due to unobserved categories" + ): + getattr(gb_keepna, reduction_func)(*args) + return + + gb_filled = df_filled.groupby(keys, observed=observed, sort=sort, as_index=True) + expected = getattr(gb_filled, reduction_func)(*args_filled).reset_index() + expected["x"] = expected["x"].cat.remove_categories([4]) + if index_kind == "multi": + expected["x2"] = expected["x2"].cat.remove_categories([4]) + if as_index: + if index_kind == "multi": + expected = expected.set_index(["x", "x2"]) + else: + expected = expected.set_index("x") + elif index_kind != "range" and reduction_func != "size": + # size, unlike other methods, has the desired behavior in GH#49519 + expected = expected.drop(columns="x") + if index_kind == "multi": + expected = expected.drop(columns="x2") + if reduction_func in ("idxmax", "idxmin") and index_kind != "range": + # expected was computed with a RangeIndex; need to translate to index values + values = expected["y"].values.tolist() + if index_kind == "single": + values = [np.nan if e == 4 else e for e in values] + expected["y"] = pd.Categorical(values, categories=[1, 2, 3]) + else: + values = [(np.nan, np.nan) if e == (4, 4) else e for e in values] + expected["y"] = values + if reduction_func == "size": + # size, unlike other methods, has the desired behavior in GH#49519 + expected = expected.rename(columns={0: "size"}) + if as_index: + expected = expected["size"].rename(None) + + if as_index or index_kind == "range" or reduction_func == "size": + warn = None + else: + warn = FutureWarning + msg = "A grouping .* was excluded from the result" + with tm.assert_produces_warning(warn, match=msg): + result = getattr(gb_keepna, reduction_func)(*args) + + # size will return a Series, others are DataFrame + tm.assert_equal(result, expected) + + +def test_categorical_transformers( + request, transformation_func, observed, sort, as_index +): + # GH#36327 + if transformation_func == "fillna": + msg = "GH#49651 fillna may incorrectly reorders results when dropna=False" + request.applymarker(pytest.mark.xfail(reason=msg, strict=False)) + + values = np.append(np.random.default_rng(2).choice([1, 2, None], size=19), None) + df = pd.DataFrame( + {"x": pd.Categorical(values, categories=[1, 2, 3]), "y": range(20)} + ) + args = get_groupby_method_args(transformation_func, df) + + # Compute result for null group + null_group_values = df[df["x"].isnull()]["y"] + if transformation_func == "cumcount": + null_group_data = list(range(len(null_group_values))) + elif transformation_func == "ngroup": + if sort: + if observed: + na_group = df["x"].nunique(dropna=False) - 1 + else: + # TODO: Should this be 3? + na_group = df["x"].nunique(dropna=False) - 1 + else: + na_group = df.iloc[: null_group_values.index[0]]["x"].nunique() + null_group_data = len(null_group_values) * [na_group] + else: + null_group_data = getattr(null_group_values, transformation_func)(*args) + null_group_result = pd.DataFrame({"y": null_group_data}) + + gb_keepna = df.groupby( + "x", dropna=False, observed=observed, sort=sort, as_index=as_index + ) + gb_dropna = df.groupby("x", dropna=True, observed=observed, sort=sort) + + msg = "The default fill_method='ffill' in DataFrameGroupBy.pct_change is deprecated" + if transformation_func == "pct_change": + with tm.assert_produces_warning(FutureWarning, match=msg): + result = getattr(gb_keepna, "pct_change")(*args) + else: + result = getattr(gb_keepna, transformation_func)(*args) + expected = getattr(gb_dropna, transformation_func)(*args) + + for iloc, value in zip( + df[df["x"].isnull()].index.tolist(), null_group_result.values.ravel() + ): + if expected.ndim == 1: + expected.iloc[iloc] = value + else: + expected.iloc[iloc, 0] = value + if transformation_func == "ngroup": + expected[df["x"].notnull() & expected.ge(na_group)] += 1 + if transformation_func not in ("rank", "diff", "pct_change", "shift"): + expected = expected.astype("int64") + + tm.assert_equal(result, expected) + + +@pytest.mark.parametrize("method", ["head", "tail"]) +def test_categorical_head_tail(method, observed, sort, as_index): + # GH#36327 + values = np.random.default_rng(2).choice([1, 2, None], 30) + df = pd.DataFrame( + {"x": pd.Categorical(values, categories=[1, 2, 3]), "y": range(len(values))} + ) + gb = df.groupby("x", dropna=False, observed=observed, sort=sort, as_index=as_index) + result = getattr(gb, method)() + + if method == "tail": + values = values[::-1] + # Take the top 5 values from each group + mask = ( + ((values == 1) & ((values == 1).cumsum() <= 5)) + | ((values == 2) & ((values == 2).cumsum() <= 5)) + # flake8 doesn't like the vectorized check for None, thinks we should use `is` + | ((values == None) & ((values == None).cumsum() <= 5)) # noqa: E711 + ) + if method == "tail": + mask = mask[::-1] + expected = df[mask] + + tm.assert_frame_equal(result, expected) + + +def test_categorical_agg(): + # GH#36327 + values = np.random.default_rng(2).choice([1, 2, None], 30) + df = pd.DataFrame( + {"x": pd.Categorical(values, categories=[1, 2, 3]), "y": range(len(values))} + ) + gb = df.groupby("x", dropna=False, observed=False) + result = gb.agg(lambda x: x.sum()) + expected = gb.sum() + tm.assert_frame_equal(result, expected) + + +def test_categorical_transform(): + # GH#36327 + values = np.random.default_rng(2).choice([1, 2, None], 30) + df = pd.DataFrame( + {"x": pd.Categorical(values, categories=[1, 2, 3]), "y": range(len(values))} + ) + gb = df.groupby("x", dropna=False, observed=False) + result = gb.transform(lambda x: x.sum()) + expected = gb.transform("sum") + tm.assert_frame_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_groupby_subclass.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_groupby_subclass.py new file mode 100644 index 0000000000000000000000000000000000000000..0832b67b38098fea8dd5ce4727f81051c3591ca3 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_groupby_subclass.py @@ -0,0 +1,135 @@ +from datetime import datetime + +import numpy as np +import pytest + +from pandas import ( + DataFrame, + Index, + Series, +) +import pandas._testing as tm +from pandas.tests.groupby import get_groupby_method_args + +pytestmark = pytest.mark.filterwarnings( + "ignore:Passing a BlockManager|Passing a SingleBlockManager:DeprecationWarning" +) + + +@pytest.mark.parametrize( + "obj", + [ + tm.SubclassedDataFrame({"A": np.arange(0, 10)}), + tm.SubclassedSeries(np.arange(0, 10), name="A"), + ], +) +def test_groupby_preserves_subclass(obj, groupby_func): + # GH28330 -- preserve subclass through groupby operations + + if isinstance(obj, Series) and groupby_func in {"corrwith"}: + pytest.skip(f"Not applicable for Series and {groupby_func}") + + grouped = obj.groupby(np.arange(0, 10)) + + # Groups should preserve subclass type + assert isinstance(grouped.get_group(0), type(obj)) + + args = get_groupby_method_args(groupby_func, obj) + + warn = FutureWarning if groupby_func == "fillna" else None + msg = f"{type(grouped).__name__}.fillna is deprecated" + with tm.assert_produces_warning(warn, match=msg, raise_on_extra_warnings=False): + result1 = getattr(grouped, groupby_func)(*args) + with tm.assert_produces_warning(warn, match=msg, raise_on_extra_warnings=False): + result2 = grouped.agg(groupby_func, *args) + + # Reduction or transformation kernels should preserve type + slices = {"ngroup", "cumcount", "size"} + if isinstance(obj, DataFrame) and groupby_func in slices: + assert isinstance(result1, tm.SubclassedSeries) + else: + assert isinstance(result1, type(obj)) + + # Confirm .agg() groupby operations return same results + if isinstance(result1, DataFrame): + tm.assert_frame_equal(result1, result2) + else: + tm.assert_series_equal(result1, result2) + + +def test_groupby_preserves_metadata(): + # GH-37343 + custom_df = tm.SubclassedDataFrame({"a": [1, 2, 3], "b": [1, 1, 2], "c": [7, 8, 9]}) + assert "testattr" in custom_df._metadata + custom_df.testattr = "hello" + for _, group_df in custom_df.groupby("c"): + assert group_df.testattr == "hello" + + # GH-45314 + def func(group): + assert isinstance(group, tm.SubclassedDataFrame) + assert hasattr(group, "testattr") + assert group.testattr == "hello" + return group.testattr + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning( + DeprecationWarning, + match=msg, + raise_on_extra_warnings=False, + check_stacklevel=False, + ): + result = custom_df.groupby("c").apply(func) + expected = tm.SubclassedSeries(["hello"] * 3, index=Index([7, 8, 9], name="c")) + tm.assert_series_equal(result, expected) + + result = custom_df.groupby("c").apply(func, include_groups=False) + tm.assert_series_equal(result, expected) + + # https://github.com/pandas-dev/pandas/pull/56761 + result = custom_df.groupby("c")[["a", "b"]].apply(func) + tm.assert_series_equal(result, expected) + + def func2(group): + assert isinstance(group, tm.SubclassedSeries) + assert hasattr(group, "testattr") + return group.testattr + + custom_series = tm.SubclassedSeries([1, 2, 3]) + custom_series.testattr = "hello" + result = custom_series.groupby(custom_df["c"]).apply(func2) + tm.assert_series_equal(result, expected) + result = custom_series.groupby(custom_df["c"]).agg(func2) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("obj", [DataFrame, tm.SubclassedDataFrame]) +def test_groupby_resample_preserves_subclass(obj): + # GH28330 -- preserve subclass through groupby.resample() + + df = obj( + { + "Buyer": "Carl Carl Carl Carl Joe Carl".split(), + "Quantity": [18, 3, 5, 1, 9, 3], + "Date": [ + datetime(2013, 9, 1, 13, 0), + datetime(2013, 9, 1, 13, 5), + datetime(2013, 10, 1, 20, 0), + datetime(2013, 10, 3, 10, 0), + datetime(2013, 12, 2, 12, 0), + datetime(2013, 9, 2, 14, 0), + ], + } + ) + df = df.set_index("Date") + + # Confirm groupby.resample() preserves dataframe type + msg = "DataFrameGroupBy.resample operated on the grouping columns" + with tm.assert_produces_warning( + DeprecationWarning, + match=msg, + raise_on_extra_warnings=False, + check_stacklevel=False, + ): + result = df.groupby("Buyer").resample("5D").sum() + assert isinstance(result, obj) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_grouping.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_grouping.py new file mode 100644 index 0000000000000000000000000000000000000000..d763b670593757c8f1a8b35a32f277566f648652 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_grouping.py @@ -0,0 +1,1236 @@ +""" +test where we are determining what we are grouping, or getting groups +""" +from datetime import ( + date, + timedelta, +) + +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + CategoricalIndex, + DataFrame, + Grouper, + Index, + MultiIndex, + Series, + Timestamp, + date_range, + period_range, +) +import pandas._testing as tm +from pandas.core.groupby.grouper import Grouping + +# selection +# -------------------------------- + + +class TestSelection: + def test_select_bad_cols(self): + df = DataFrame([[1, 2]], columns=["A", "B"]) + g = df.groupby("A") + with pytest.raises(KeyError, match="\"Columns not found: 'C'\""): + g[["C"]] + + with pytest.raises(KeyError, match="^[^A]+$"): + # A should not be referenced as a bad column... + # will have to rethink regex if you change message! + g[["A", "C"]] + + def test_groupby_duplicated_column_errormsg(self): + # GH7511 + df = DataFrame( + columns=["A", "B", "A", "C"], data=[range(4), range(2, 6), range(0, 8, 2)] + ) + + msg = "Grouper for 'A' not 1-dimensional" + with pytest.raises(ValueError, match=msg): + df.groupby("A") + with pytest.raises(ValueError, match=msg): + df.groupby(["A", "B"]) + + grouped = df.groupby("B") + c = grouped.count() + assert c.columns.nlevels == 1 + assert c.columns.size == 3 + + def test_column_select_via_attr(self, df): + result = df.groupby("A").C.sum() + expected = df.groupby("A")["C"].sum() + tm.assert_series_equal(result, expected) + + df["mean"] = 1.5 + result = df.groupby("A").mean(numeric_only=True) + expected = df.groupby("A")[["C", "D", "mean"]].agg("mean") + tm.assert_frame_equal(result, expected) + + def test_getitem_list_of_columns(self): + df = DataFrame( + { + "A": ["foo", "bar", "foo", "bar", "foo", "bar", "foo", "foo"], + "B": ["one", "one", "two", "three", "two", "two", "one", "three"], + "C": np.random.default_rng(2).standard_normal(8), + "D": np.random.default_rng(2).standard_normal(8), + "E": np.random.default_rng(2).standard_normal(8), + } + ) + + result = df.groupby("A")[["C", "D"]].mean() + result2 = df.groupby("A")[df.columns[2:4]].mean() + + expected = df.loc[:, ["A", "C", "D"]].groupby("A").mean() + + tm.assert_frame_equal(result, expected) + tm.assert_frame_equal(result2, expected) + + def test_getitem_numeric_column_names(self): + # GH #13731 + df = DataFrame( + { + 0: list("abcd") * 2, + 2: np.random.default_rng(2).standard_normal(8), + 4: np.random.default_rng(2).standard_normal(8), + 6: np.random.default_rng(2).standard_normal(8), + } + ) + result = df.groupby(0)[df.columns[1:3]].mean() + result2 = df.groupby(0)[[2, 4]].mean() + + expected = df.loc[:, [0, 2, 4]].groupby(0).mean() + + tm.assert_frame_equal(result, expected) + tm.assert_frame_equal(result2, expected) + + # per GH 23566 enforced deprecation raises a ValueError + with pytest.raises(ValueError, match="Cannot subset columns with a tuple"): + df.groupby(0)[2, 4].mean() + + def test_getitem_single_tuple_of_columns_raises(self, df): + # per GH 23566 enforced deprecation raises a ValueError + with pytest.raises(ValueError, match="Cannot subset columns with a tuple"): + df.groupby("A")["C", "D"].mean() + + def test_getitem_single_column(self): + df = DataFrame( + { + "A": ["foo", "bar", "foo", "bar", "foo", "bar", "foo", "foo"], + "B": ["one", "one", "two", "three", "two", "two", "one", "three"], + "C": np.random.default_rng(2).standard_normal(8), + "D": np.random.default_rng(2).standard_normal(8), + "E": np.random.default_rng(2).standard_normal(8), + } + ) + + result = df.groupby("A")["C"].mean() + + as_frame = df.loc[:, ["A", "C"]].groupby("A").mean() + as_series = as_frame.iloc[:, 0] + expected = as_series + + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "func", [lambda x: x.sum(), lambda x: x.agg(lambda y: y.sum())] + ) + def test_getitem_from_grouper(self, func): + # GH 50383 + df = DataFrame({"a": [1, 1, 2], "b": 3, "c": 4, "d": 5}) + gb = df.groupby(["a", "b"])[["a", "c"]] + + idx = MultiIndex.from_tuples([(1, 3), (2, 3)], names=["a", "b"]) + expected = DataFrame({"a": [2, 2], "c": [8, 4]}, index=idx) + result = func(gb) + + tm.assert_frame_equal(result, expected) + + def test_indices_grouped_by_tuple_with_lambda(self): + # GH 36158 + df = DataFrame( + { + "Tuples": ( + (x, y) + for x in [0, 1] + for y in np.random.default_rng(2).integers(3, 5, 5) + ) + } + ) + + gb = df.groupby("Tuples") + gb_lambda = df.groupby(lambda x: df.iloc[x, 0]) + + expected = gb.indices + result = gb_lambda.indices + + tm.assert_dict_equal(result, expected) + + +# grouping +# -------------------------------- + + +class TestGrouping: + @pytest.mark.parametrize( + "index", + [ + Index(list("abcde")), + Index(np.arange(5)), + Index(np.arange(5, dtype=float)), + date_range("2020-01-01", periods=5), + period_range("2020-01-01", periods=5), + ], + ) + def test_grouper_index_types(self, index): + # related GH5375 + # groupby misbehaving when using a Floatlike index + df = DataFrame(np.arange(10).reshape(5, 2), columns=list("AB"), index=index) + + df.groupby(list("abcde"), group_keys=False).apply(lambda x: x) + + df.index = df.index[::-1] + df.groupby(list("abcde"), group_keys=False).apply(lambda x: x) + + def test_grouper_multilevel_freq(self): + # GH 7885 + # with level and freq specified in a Grouper + d0 = date.today() - timedelta(days=14) + dates = date_range(d0, date.today()) + date_index = MultiIndex.from_product([dates, dates], names=["foo", "bar"]) + df = DataFrame(np.random.default_rng(2).integers(0, 100, 225), index=date_index) + + # Check string level + expected = ( + df.reset_index() + .groupby([Grouper(key="foo", freq="W"), Grouper(key="bar", freq="W")]) + .sum() + ) + # reset index changes columns dtype to object + expected.columns = Index([0], dtype="int64") + + result = df.groupby( + [Grouper(level="foo", freq="W"), Grouper(level="bar", freq="W")] + ).sum() + tm.assert_frame_equal(result, expected) + + # Check integer level + result = df.groupby( + [Grouper(level=0, freq="W"), Grouper(level=1, freq="W")] + ).sum() + tm.assert_frame_equal(result, expected) + + def test_grouper_creation_bug(self): + # GH 8795 + df = DataFrame({"A": [0, 0, 1, 1, 2, 2], "B": [1, 2, 3, 4, 5, 6]}) + g = df.groupby("A") + expected = g.sum() + + g = df.groupby(Grouper(key="A")) + result = g.sum() + tm.assert_frame_equal(result, expected) + + msg = "Grouper axis keyword is deprecated and will be removed" + with tm.assert_produces_warning(FutureWarning, match=msg): + gpr = Grouper(key="A", axis=0) + g = df.groupby(gpr) + result = g.sum() + tm.assert_frame_equal(result, expected) + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = g.apply(lambda x: x.sum()) + expected["A"] = [0, 2, 4] + expected = expected.loc[:, ["A", "B"]] + tm.assert_frame_equal(result, expected) + + def test_grouper_creation_bug2(self): + # GH14334 + # Grouper(key=...) may be passed in a list + df = DataFrame( + {"A": [0, 0, 0, 1, 1, 1], "B": [1, 1, 2, 2, 3, 3], "C": [1, 2, 3, 4, 5, 6]} + ) + # Group by single column + expected = df.groupby("A").sum() + g = df.groupby([Grouper(key="A")]) + result = g.sum() + tm.assert_frame_equal(result, expected) + + # Group by two columns + # using a combination of strings and Grouper objects + expected = df.groupby(["A", "B"]).sum() + + # Group with two Grouper objects + g = df.groupby([Grouper(key="A"), Grouper(key="B")]) + result = g.sum() + tm.assert_frame_equal(result, expected) + + # Group with a string and a Grouper object + g = df.groupby(["A", Grouper(key="B")]) + result = g.sum() + tm.assert_frame_equal(result, expected) + + # Group with a Grouper object and a string + g = df.groupby([Grouper(key="A"), "B"]) + result = g.sum() + tm.assert_frame_equal(result, expected) + + def test_grouper_creation_bug3(self, unit): + # GH8866 + dti = date_range("20130101", periods=2, unit=unit) + mi = MultiIndex.from_product( + [list("ab"), range(2), dti], + names=["one", "two", "three"], + ) + ser = Series( + np.arange(8, dtype="int64"), + index=mi, + ) + result = ser.groupby(Grouper(level="three", freq="ME")).sum() + exp_dti = pd.DatetimeIndex( + [Timestamp("2013-01-31")], freq="ME", name="three" + ).as_unit(unit) + expected = Series( + [28], + index=exp_dti, + ) + tm.assert_series_equal(result, expected) + + # just specifying a level breaks + result = ser.groupby(Grouper(level="one")).sum() + expected = ser.groupby(level="one").sum() + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("func", [False, True]) + def test_grouper_returning_tuples(self, func): + # GH 22257 , both with dict and with callable + df = DataFrame({"X": ["A", "B", "A", "B"], "Y": [1, 4, 3, 2]}) + mapping = dict(zip(range(4), [("C", 5), ("D", 6)] * 2)) + + if func: + gb = df.groupby(by=lambda idx: mapping[idx], sort=False) + else: + gb = df.groupby(by=mapping, sort=False) + + name, expected = next(iter(gb)) + assert name == ("C", 5) + result = gb.get_group(name) + + tm.assert_frame_equal(result, expected) + + def test_grouper_column_and_index(self): + # GH 14327 + + # Grouping a multi-index frame by a column and an index level should + # be equivalent to resetting the index and grouping by two columns + idx = MultiIndex.from_tuples( + [("a", 1), ("a", 2), ("a", 3), ("b", 1), ("b", 2), ("b", 3)] + ) + idx.names = ["outer", "inner"] + df_multi = DataFrame( + {"A": np.arange(6), "B": ["one", "one", "two", "two", "one", "one"]}, + index=idx, + ) + result = df_multi.groupby(["B", Grouper(level="inner")]).mean(numeric_only=True) + expected = ( + df_multi.reset_index().groupby(["B", "inner"]).mean(numeric_only=True) + ) + tm.assert_frame_equal(result, expected) + + # Test the reverse grouping order + result = df_multi.groupby([Grouper(level="inner"), "B"]).mean(numeric_only=True) + expected = ( + df_multi.reset_index().groupby(["inner", "B"]).mean(numeric_only=True) + ) + tm.assert_frame_equal(result, expected) + + # Grouping a single-index frame by a column and the index should + # be equivalent to resetting the index and grouping by two columns + df_single = df_multi.reset_index("outer") + result = df_single.groupby(["B", Grouper(level="inner")]).mean( + numeric_only=True + ) + expected = ( + df_single.reset_index().groupby(["B", "inner"]).mean(numeric_only=True) + ) + tm.assert_frame_equal(result, expected) + + # Test the reverse grouping order + result = df_single.groupby([Grouper(level="inner"), "B"]).mean( + numeric_only=True + ) + expected = ( + df_single.reset_index().groupby(["inner", "B"]).mean(numeric_only=True) + ) + tm.assert_frame_equal(result, expected) + + def test_groupby_levels_and_columns(self): + # GH9344, GH9049 + idx_names = ["x", "y"] + idx = MultiIndex.from_tuples([(1, 1), (1, 2), (3, 4), (5, 6)], names=idx_names) + df = DataFrame(np.arange(12).reshape(-1, 3), index=idx) + + by_levels = df.groupby(level=idx_names).mean() + # reset_index changes columns dtype to object + by_columns = df.reset_index().groupby(idx_names).mean() + + # without casting, by_columns.columns is object-dtype + by_columns.columns = by_columns.columns.astype(np.int64) + tm.assert_frame_equal(by_levels, by_columns) + + def test_groupby_categorical_index_and_columns(self, observed): + # GH18432, adapted for GH25871 + columns = ["A", "B", "A", "B"] + categories = ["B", "A"] + data = np.array( + [[1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 1, 2]], int + ) + cat_columns = CategoricalIndex(columns, categories=categories, ordered=True) + df = DataFrame(data=data, columns=cat_columns) + depr_msg = "DataFrame.groupby with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=depr_msg): + result = df.groupby(axis=1, level=0, observed=observed).sum() + expected_data = np.array([[4, 2], [4, 2], [4, 2], [4, 2], [4, 2]], int) + expected_columns = CategoricalIndex( + categories, categories=categories, ordered=True + ) + expected = DataFrame(data=expected_data, columns=expected_columns) + tm.assert_frame_equal(result, expected) + + # test transposed version + df = DataFrame(data.T, index=cat_columns) + msg = "The 'axis' keyword in DataFrame.groupby is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = df.groupby(axis=0, level=0, observed=observed).sum() + expected = DataFrame(data=expected_data.T, index=expected_columns) + tm.assert_frame_equal(result, expected) + + def test_grouper_getting_correct_binner(self): + # GH 10063 + # using a non-time-based grouper and a time-based grouper + # and specifying levels + df = DataFrame( + {"A": 1}, + index=MultiIndex.from_product( + [list("ab"), date_range("20130101", periods=80)], names=["one", "two"] + ), + ) + result = df.groupby( + [Grouper(level="one"), Grouper(level="two", freq="ME")] + ).sum() + expected = DataFrame( + {"A": [31, 28, 21, 31, 28, 21]}, + index=MultiIndex.from_product( + [list("ab"), date_range("20130101", freq="ME", periods=3)], + names=["one", "two"], + ), + ) + tm.assert_frame_equal(result, expected) + + def test_grouper_iter(self, df): + gb = df.groupby("A") + msg = "DataFrameGroupBy.grouper is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + grouper = gb.grouper + result = sorted(grouper) + expected = ["bar", "foo"] + assert result == expected + + def test_empty_groups(self, df): + # see gh-1048 + with pytest.raises(ValueError, match="No group keys passed!"): + df.groupby([]) + + def test_groupby_grouper(self, df): + grouped = df.groupby("A") + msg = "DataFrameGroupBy.grouper is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + grouper = grouped.grouper + result = df.groupby(grouper).mean(numeric_only=True) + expected = grouped.mean(numeric_only=True) + tm.assert_frame_equal(result, expected) + + def test_groupby_dict_mapping(self): + # GH #679 + s = Series({"T1": 5}) + result = s.groupby({"T1": "T2"}).agg("sum") + expected = s.groupby(["T2"]).agg("sum") + tm.assert_series_equal(result, expected) + + s = Series([1.0, 2.0, 3.0, 4.0], index=list("abcd")) + mapping = {"a": 0, "b": 0, "c": 1, "d": 1} + + result = s.groupby(mapping).mean() + result2 = s.groupby(mapping).agg("mean") + exp_key = np.array([0, 0, 1, 1], dtype=np.int64) + expected = s.groupby(exp_key).mean() + expected2 = s.groupby(exp_key).mean() + tm.assert_series_equal(result, expected) + tm.assert_series_equal(result, result2) + tm.assert_series_equal(result, expected2) + + @pytest.mark.parametrize( + "index", + [ + [0, 1, 2, 3], + ["a", "b", "c", "d"], + [Timestamp(2021, 7, 28 + i) for i in range(4)], + ], + ) + def test_groupby_series_named_with_tuple(self, frame_or_series, index): + # GH 42731 + obj = frame_or_series([1, 2, 3, 4], index=index) + groups = Series([1, 0, 1, 0], index=index, name=("a", "a")) + result = obj.groupby(groups).last() + expected = frame_or_series([4, 3]) + expected.index.name = ("a", "a") + tm.assert_equal(result, expected) + + def test_groupby_grouper_f_sanity_checked(self): + dates = date_range("01-Jan-2013", periods=12, freq="MS") + ts = Series(np.random.default_rng(2).standard_normal(12), index=dates) + + # GH51979 + # simple check that the passed function doesn't operates on the whole index + msg = "'Timestamp' object is not subscriptable" + with pytest.raises(TypeError, match=msg): + ts.groupby(lambda key: key[0:6]) + + result = ts.groupby(lambda x: x).sum() + expected = ts.groupby(ts.index).sum() + expected.index.freq = None + tm.assert_series_equal(result, expected) + + def test_groupby_with_datetime_key(self): + # GH 51158 + df = DataFrame( + { + "id": ["a", "b"] * 3, + "b": date_range("2000-01-01", "2000-01-03", freq="9h"), + } + ) + grouper = Grouper(key="b", freq="D") + gb = df.groupby([grouper, "id"]) + + # test number of groups + expected = { + (Timestamp("2000-01-01"), "a"): [0, 2], + (Timestamp("2000-01-01"), "b"): [1], + (Timestamp("2000-01-02"), "a"): [4], + (Timestamp("2000-01-02"), "b"): [3, 5], + } + tm.assert_dict_equal(gb.groups, expected) + + # test number of group keys + assert len(gb.groups.keys()) == 4 + + def test_grouping_error_on_multidim_input(self, df): + msg = "Grouper for '' not 1-dimensional" + with pytest.raises(ValueError, match=msg): + Grouping(df.index, df[["A", "A"]]) + + def test_multiindex_passthru(self): + # GH 7997 + # regression from 0.14.1 + df = DataFrame([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + df.columns = MultiIndex.from_tuples([(0, 1), (1, 1), (2, 1)]) + + depr_msg = "DataFrame.groupby with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=depr_msg): + gb = df.groupby(axis=1, level=[0, 1]) + result = gb.first() + tm.assert_frame_equal(result, df) + + def test_multiindex_negative_level(self, multiindex_dataframe_random_data): + # GH 13901 + result = multiindex_dataframe_random_data.groupby(level=-1).sum() + expected = multiindex_dataframe_random_data.groupby(level="second").sum() + tm.assert_frame_equal(result, expected) + + result = multiindex_dataframe_random_data.groupby(level=-2).sum() + expected = multiindex_dataframe_random_data.groupby(level="first").sum() + tm.assert_frame_equal(result, expected) + + result = multiindex_dataframe_random_data.groupby(level=[-2, -1]).sum() + expected = multiindex_dataframe_random_data.sort_index() + tm.assert_frame_equal(result, expected) + + result = multiindex_dataframe_random_data.groupby(level=[-1, "first"]).sum() + expected = multiindex_dataframe_random_data.groupby( + level=["second", "first"] + ).sum() + tm.assert_frame_equal(result, expected) + + def test_multifunc_select_col_integer_cols(self, df): + df.columns = np.arange(len(df.columns)) + + # it works! + msg = "Passing a dictionary to SeriesGroupBy.agg is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + df.groupby(1, as_index=False)[2].agg({"Q": np.mean}) + + def test_multiindex_columns_empty_level(self): + lst = [["count", "values"], ["to filter", ""]] + midx = MultiIndex.from_tuples(lst) + + df = DataFrame([[1, "A"]], columns=midx) + + grouped = df.groupby("to filter").groups + assert grouped["A"] == [0] + + grouped = df.groupby([("to filter", "")]).groups + assert grouped["A"] == [0] + + df = DataFrame([[1, "A"], [2, "B"]], columns=midx) + + expected = df.groupby("to filter").groups + result = df.groupby([("to filter", "")]).groups + assert result == expected + + df = DataFrame([[1, "A"], [2, "A"]], columns=midx) + + expected = df.groupby("to filter").groups + result = df.groupby([("to filter", "")]).groups + tm.assert_dict_equal(result, expected) + + def test_groupby_multiindex_tuple(self): + # GH 17979 + df = DataFrame( + [[1, 2, 3, 4], [3, 4, 5, 6], [1, 4, 2, 3]], + columns=MultiIndex.from_arrays([["a", "b", "b", "c"], [1, 1, 2, 2]]), + ) + expected = df.groupby([("b", 1)]).groups + result = df.groupby(("b", 1)).groups + tm.assert_dict_equal(expected, result) + + df2 = DataFrame( + df.values, + columns=MultiIndex.from_arrays( + [["a", "b", "b", "c"], ["d", "d", "e", "e"]] + ), + ) + expected = df2.groupby([("b", "d")]).groups + result = df.groupby(("b", 1)).groups + tm.assert_dict_equal(expected, result) + + df3 = DataFrame(df.values, columns=[("a", "d"), ("b", "d"), ("b", "e"), "c"]) + expected = df3.groupby([("b", "d")]).groups + result = df.groupby(("b", 1)).groups + tm.assert_dict_equal(expected, result) + + def test_groupby_multiindex_partial_indexing_equivalence(self): + # GH 17977 + df = DataFrame( + [[1, 2, 3, 4], [3, 4, 5, 6], [1, 4, 2, 3]], + columns=MultiIndex.from_arrays([["a", "b", "b", "c"], [1, 1, 2, 2]]), + ) + + expected_mean = df.groupby([("a", 1)])[[("b", 1), ("b", 2)]].mean() + result_mean = df.groupby([("a", 1)])["b"].mean() + tm.assert_frame_equal(expected_mean, result_mean) + + expected_sum = df.groupby([("a", 1)])[[("b", 1), ("b", 2)]].sum() + result_sum = df.groupby([("a", 1)])["b"].sum() + tm.assert_frame_equal(expected_sum, result_sum) + + expected_count = df.groupby([("a", 1)])[[("b", 1), ("b", 2)]].count() + result_count = df.groupby([("a", 1)])["b"].count() + tm.assert_frame_equal(expected_count, result_count) + + expected_min = df.groupby([("a", 1)])[[("b", 1), ("b", 2)]].min() + result_min = df.groupby([("a", 1)])["b"].min() + tm.assert_frame_equal(expected_min, result_min) + + expected_max = df.groupby([("a", 1)])[[("b", 1), ("b", 2)]].max() + result_max = df.groupby([("a", 1)])["b"].max() + tm.assert_frame_equal(expected_max, result_max) + + expected_groups = df.groupby([("a", 1)])[[("b", 1), ("b", 2)]].groups + result_groups = df.groupby([("a", 1)])["b"].groups + tm.assert_dict_equal(expected_groups, result_groups) + + @pytest.mark.parametrize("sort", [True, False]) + def test_groupby_level(self, sort, multiindex_dataframe_random_data, df): + # GH 17537 + frame = multiindex_dataframe_random_data + deleveled = frame.reset_index() + + result0 = frame.groupby(level=0, sort=sort).sum() + result1 = frame.groupby(level=1, sort=sort).sum() + + expected0 = frame.groupby(deleveled["first"].values, sort=sort).sum() + expected1 = frame.groupby(deleveled["second"].values, sort=sort).sum() + + expected0.index.name = "first" + expected1.index.name = "second" + + assert result0.index.name == "first" + assert result1.index.name == "second" + + tm.assert_frame_equal(result0, expected0) + tm.assert_frame_equal(result1, expected1) + assert result0.index.name == frame.index.names[0] + assert result1.index.name == frame.index.names[1] + + # groupby level name + result0 = frame.groupby(level="first", sort=sort).sum() + result1 = frame.groupby(level="second", sort=sort).sum() + tm.assert_frame_equal(result0, expected0) + tm.assert_frame_equal(result1, expected1) + + # axis=1 + msg = "DataFrame.groupby with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + result0 = frame.T.groupby(level=0, axis=1, sort=sort).sum() + result1 = frame.T.groupby(level=1, axis=1, sort=sort).sum() + tm.assert_frame_equal(result0, expected0.T) + tm.assert_frame_equal(result1, expected1.T) + + # raise exception for non-MultiIndex + msg = "level > 0 or level < -1 only valid with MultiIndex" + with pytest.raises(ValueError, match=msg): + df.groupby(level=1) + + def test_groupby_level_index_names(self, axis): + # GH4014 this used to raise ValueError since 'exp'>1 (in py2) + df = DataFrame({"exp": ["A"] * 3 + ["B"] * 3, "var1": range(6)}).set_index( + "exp" + ) + if axis in (1, "columns"): + df = df.T + depr_msg = "DataFrame.groupby with axis=1 is deprecated" + else: + depr_msg = "The 'axis' keyword in DataFrame.groupby is deprecated" + with tm.assert_produces_warning(FutureWarning, match=depr_msg): + df.groupby(level="exp", axis=axis) + msg = f"level name foo is not the name of the {df._get_axis_name(axis)}" + with pytest.raises(ValueError, match=msg): + with tm.assert_produces_warning(FutureWarning, match=depr_msg): + df.groupby(level="foo", axis=axis) + + @pytest.mark.parametrize("sort", [True, False]) + def test_groupby_level_with_nas(self, sort): + # GH 17537 + index = MultiIndex( + levels=[[1, 0], [0, 1, 2, 3]], + codes=[[1, 1, 1, 1, 0, 0, 0, 0], [0, 1, 2, 3, 0, 1, 2, 3]], + ) + + # factorizing doesn't confuse things + s = Series(np.arange(8.0), index=index) + result = s.groupby(level=0, sort=sort).sum() + expected = Series([6.0, 22.0], index=[0, 1]) + tm.assert_series_equal(result, expected) + + index = MultiIndex( + levels=[[1, 0], [0, 1, 2, 3]], + codes=[[1, 1, 1, 1, -1, 0, 0, 0], [0, 1, 2, 3, 0, 1, 2, 3]], + ) + + # factorizing doesn't confuse things + s = Series(np.arange(8.0), index=index) + result = s.groupby(level=0, sort=sort).sum() + expected = Series([6.0, 18.0], index=[0.0, 1.0]) + tm.assert_series_equal(result, expected) + + def test_groupby_args(self, multiindex_dataframe_random_data): + # PR8618 and issue 8015 + frame = multiindex_dataframe_random_data + + msg = "You have to supply one of 'by' and 'level'" + with pytest.raises(TypeError, match=msg): + frame.groupby() + + msg = "You have to supply one of 'by' and 'level'" + with pytest.raises(TypeError, match=msg): + frame.groupby(by=None, level=None) + + @pytest.mark.parametrize( + "sort,labels", + [ + [True, [2, 2, 2, 0, 0, 1, 1, 3, 3, 3]], + [False, [0, 0, 0, 1, 1, 2, 2, 3, 3, 3]], + ], + ) + def test_level_preserve_order(self, sort, labels, multiindex_dataframe_random_data): + # GH 17537 + grouped = multiindex_dataframe_random_data.groupby(level=0, sort=sort) + exp_labels = np.array(labels, np.intp) + tm.assert_almost_equal(grouped._grouper.codes[0], exp_labels) + + def test_grouping_labels(self, multiindex_dataframe_random_data): + grouped = multiindex_dataframe_random_data.groupby( + multiindex_dataframe_random_data.index.get_level_values(0) + ) + exp_labels = np.array([2, 2, 2, 0, 0, 1, 1, 3, 3, 3], dtype=np.intp) + tm.assert_almost_equal(grouped._grouper.codes[0], exp_labels) + + def test_list_grouper_with_nat(self): + # GH 14715 + df = DataFrame({"date": date_range("1/1/2011", periods=365, freq="D")}) + df.iloc[-1] = pd.NaT + grouper = Grouper(key="date", freq="YS") + + # Grouper in a list grouping + result = df.groupby([grouper]) + expected = {Timestamp("2011-01-01"): Index(list(range(364)))} + tm.assert_dict_equal(result.groups, expected) + + # Test case without a list + result = df.groupby(grouper) + expected = {Timestamp("2011-01-01"): 365} + tm.assert_dict_equal(result.groups, expected) + + @pytest.mark.parametrize( + "func,expected", + [ + ( + "transform", + Series(name=2, dtype=np.float64), + ), + ( + "agg", + Series( + name=2, dtype=np.float64, index=Index([], dtype=np.float64, name=1) + ), + ), + ( + "apply", + Series( + name=2, dtype=np.float64, index=Index([], dtype=np.float64, name=1) + ), + ), + ], + ) + def test_evaluate_with_empty_groups(self, func, expected): + # 26208 + # test transform'ing empty groups + # (not testing other agg fns, because they return + # different index objects. + df = DataFrame({1: [], 2: []}) + g = df.groupby(1, group_keys=False) + result = getattr(g[2], func)(lambda x: x) + tm.assert_series_equal(result, expected) + + def test_groupby_empty(self): + # https://github.com/pandas-dev/pandas/issues/27190 + s = Series([], name="name", dtype="float64") + gr = s.groupby([]) + + result = gr.mean() + expected = s.set_axis(Index([], dtype=np.intp)) + tm.assert_series_equal(result, expected) + + # check group properties + assert len(gr._grouper.groupings) == 1 + tm.assert_numpy_array_equal( + gr._grouper.group_info[0], np.array([], dtype=np.dtype(np.intp)) + ) + + tm.assert_numpy_array_equal( + gr._grouper.group_info[1], np.array([], dtype=np.dtype(np.intp)) + ) + + assert gr._grouper.group_info[2] == 0 + + # check name + gb = s.groupby(s) + msg = "SeriesGroupBy.grouper is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + grouper = gb.grouper + result = grouper.names + expected = ["name"] + assert result == expected + + def test_groupby_level_index_value_all_na(self): + # issue 20519 + df = DataFrame( + [["x", np.nan, 10], [None, np.nan, 20]], columns=["A", "B", "C"] + ).set_index(["A", "B"]) + result = df.groupby(level=["A", "B"]).sum() + expected = DataFrame( + data=[], + index=MultiIndex( + levels=[Index(["x"], dtype="object"), Index([], dtype="float64")], + codes=[[], []], + names=["A", "B"], + ), + columns=["C"], + dtype="int64", + ) + tm.assert_frame_equal(result, expected) + + def test_groupby_multiindex_level_empty(self): + # https://github.com/pandas-dev/pandas/issues/31670 + df = DataFrame( + [[123, "a", 1.0], [123, "b", 2.0]], columns=["id", "category", "value"] + ) + df = df.set_index(["id", "category"]) + empty = df[df.value < 0] + result = empty.groupby("id").sum() + expected = DataFrame( + dtype="float64", + columns=["value"], + index=Index([], dtype=np.int64, name="id"), + ) + tm.assert_frame_equal(result, expected) + + +# get_group +# -------------------------------- + + +class TestGetGroup: + def test_get_group(self): + # GH 5267 + # be datelike friendly + df = DataFrame( + { + "DATE": pd.to_datetime( + [ + "10-Oct-2013", + "10-Oct-2013", + "10-Oct-2013", + "11-Oct-2013", + "11-Oct-2013", + "11-Oct-2013", + ] + ), + "label": ["foo", "foo", "bar", "foo", "foo", "bar"], + "VAL": [1, 2, 3, 4, 5, 6], + } + ) + + g = df.groupby("DATE") + key = next(iter(g.groups)) + result1 = g.get_group(key) + result2 = g.get_group(Timestamp(key).to_pydatetime()) + result3 = g.get_group(str(Timestamp(key))) + tm.assert_frame_equal(result1, result2) + tm.assert_frame_equal(result1, result3) + + g = df.groupby(["DATE", "label"]) + + key = next(iter(g.groups)) + result1 = g.get_group(key) + result2 = g.get_group((Timestamp(key[0]).to_pydatetime(), key[1])) + result3 = g.get_group((str(Timestamp(key[0])), key[1])) + tm.assert_frame_equal(result1, result2) + tm.assert_frame_equal(result1, result3) + + # must pass a same-length tuple with multiple keys + msg = "must supply a tuple to get_group with multiple grouping keys" + with pytest.raises(ValueError, match=msg): + g.get_group("foo") + with pytest.raises(ValueError, match=msg): + g.get_group("foo") + msg = "must supply a same-length tuple to get_group with multiple grouping keys" + with pytest.raises(ValueError, match=msg): + g.get_group(("foo", "bar", "baz")) + + def test_get_group_empty_bins(self, observed): + d = DataFrame([3, 1, 7, 6]) + bins = [0, 5, 10, 15] + g = d.groupby(pd.cut(d[0], bins), observed=observed) + + # TODO: should prob allow a str of Interval work as well + # IOW '(0, 5]' + result = g.get_group(pd.Interval(0, 5)) + expected = DataFrame([3, 1], index=[0, 1]) + tm.assert_frame_equal(result, expected) + + msg = r"Interval\(10, 15, closed='right'\)" + with pytest.raises(KeyError, match=msg): + g.get_group(pd.Interval(10, 15)) + + def test_get_group_grouped_by_tuple(self): + # GH 8121 + df = DataFrame([[(1,), (1, 2), (1,), (1, 2)]], index=["ids"]).T + gr = df.groupby("ids") + expected = DataFrame({"ids": [(1,), (1,)]}, index=[0, 2]) + result = gr.get_group((1,)) + tm.assert_frame_equal(result, expected) + + dt = pd.to_datetime(["2010-01-01", "2010-01-02", "2010-01-01", "2010-01-02"]) + df = DataFrame({"ids": [(x,) for x in dt]}) + gr = df.groupby("ids") + result = gr.get_group(("2010-01-01",)) + expected = DataFrame({"ids": [(dt[0],), (dt[0],)]}, index=[0, 2]) + tm.assert_frame_equal(result, expected) + + def test_get_group_grouped_by_tuple_with_lambda(self): + # GH 36158 + df = DataFrame( + { + "Tuples": ( + (x, y) + for x in [0, 1] + for y in np.random.default_rng(2).integers(3, 5, 5) + ) + } + ) + + gb = df.groupby("Tuples") + gb_lambda = df.groupby(lambda x: df.iloc[x, 0]) + + expected = gb.get_group(next(iter(gb.groups.keys()))) + result = gb_lambda.get_group(next(iter(gb_lambda.groups.keys()))) + + tm.assert_frame_equal(result, expected) + + def test_groupby_with_empty(self): + index = pd.DatetimeIndex(()) + data = () + series = Series(data, index, dtype=object) + grouper = Grouper(freq="D") + grouped = series.groupby(grouper) + assert next(iter(grouped), None) is None + + def test_groupby_with_single_column(self): + df = DataFrame({"a": list("abssbab")}) + tm.assert_frame_equal(df.groupby("a").get_group("a"), df.iloc[[0, 5]]) + # GH 13530 + exp = DataFrame(index=Index(["a", "b", "s"], name="a"), columns=[]) + tm.assert_frame_equal(df.groupby("a").count(), exp) + tm.assert_frame_equal(df.groupby("a").sum(), exp) + + exp = df.iloc[[3, 4, 5]] + tm.assert_frame_equal(df.groupby("a").nth(1), exp) + + def test_gb_key_len_equal_axis_len(self): + # GH16843 + # test ensures that index and column keys are recognized correctly + # when number of keys equals axis length of groupby + df = DataFrame( + [["foo", "bar", "B", 1], ["foo", "bar", "B", 2], ["foo", "baz", "C", 3]], + columns=["first", "second", "third", "one"], + ) + df = df.set_index(["first", "second"]) + df = df.groupby(["first", "second", "third"]).size() + assert df.loc[("foo", "bar", "B")] == 2 + assert df.loc[("foo", "baz", "C")] == 1 + + +# groups & iteration +# -------------------------------- + + +class TestIteration: + def test_groups(self, df): + grouped = df.groupby(["A"]) + groups = grouped.groups + assert groups is grouped.groups # caching works + + for k, v in grouped.groups.items(): + assert (df.loc[v]["A"] == k).all() + + grouped = df.groupby(["A", "B"]) + groups = grouped.groups + assert groups is grouped.groups # caching works + + for k, v in grouped.groups.items(): + assert (df.loc[v]["A"] == k[0]).all() + assert (df.loc[v]["B"] == k[1]).all() + + def test_grouping_is_iterable(self, tsframe): + # this code path isn't used anywhere else + # not sure it's useful + grouped = tsframe.groupby([lambda x: x.weekday(), lambda x: x.year]) + + # test it works + for g in grouped._grouper.groupings[0]: + pass + + def test_multi_iter(self): + s = Series(np.arange(6)) + k1 = np.array(["a", "a", "a", "b", "b", "b"]) + k2 = np.array(["1", "2", "1", "2", "1", "2"]) + + grouped = s.groupby([k1, k2]) + + iterated = list(grouped) + expected = [ + ("a", "1", s[[0, 2]]), + ("a", "2", s[[1]]), + ("b", "1", s[[4]]), + ("b", "2", s[[3, 5]]), + ] + for i, ((one, two), three) in enumerate(iterated): + e1, e2, e3 = expected[i] + assert e1 == one + assert e2 == two + tm.assert_series_equal(three, e3) + + def test_multi_iter_frame(self, three_group): + k1 = np.array(["b", "b", "b", "a", "a", "a"]) + k2 = np.array(["1", "2", "1", "2", "1", "2"]) + df = DataFrame( + { + "v1": np.random.default_rng(2).standard_normal(6), + "v2": np.random.default_rng(2).standard_normal(6), + "k1": k1, + "k2": k2, + }, + index=["one", "two", "three", "four", "five", "six"], + ) + + grouped = df.groupby(["k1", "k2"]) + + # things get sorted! + iterated = list(grouped) + idx = df.index + expected = [ + ("a", "1", df.loc[idx[[4]]]), + ("a", "2", df.loc[idx[[3, 5]]]), + ("b", "1", df.loc[idx[[0, 2]]]), + ("b", "2", df.loc[idx[[1]]]), + ] + for i, ((one, two), three) in enumerate(iterated): + e1, e2, e3 = expected[i] + assert e1 == one + assert e2 == two + tm.assert_frame_equal(three, e3) + + # don't iterate through groups with no data + df["k1"] = np.array(["b", "b", "b", "a", "a", "a"]) + df["k2"] = np.array(["1", "1", "1", "2", "2", "2"]) + grouped = df.groupby(["k1", "k2"]) + # calling `dict` on a DataFrameGroupBy leads to a TypeError, + # we need to use a dictionary comprehension here + # pylint: disable-next=unnecessary-comprehension + groups = {key: gp for key, gp in grouped} # noqa: C416 + assert len(groups) == 2 + + # axis = 1 + three_levels = three_group.groupby(["A", "B", "C"]).mean() + depr_msg = "DataFrame.groupby with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=depr_msg): + grouped = three_levels.T.groupby(axis=1, level=(1, 2)) + for key, group in grouped: + pass + + def test_dictify(self, df): + dict(iter(df.groupby("A"))) + dict(iter(df.groupby(["A", "B"]))) + dict(iter(df["C"].groupby(df["A"]))) + dict(iter(df["C"].groupby([df["A"], df["B"]]))) + dict(iter(df.groupby("A")["C"])) + dict(iter(df.groupby(["A", "B"])["C"])) + + def test_groupby_with_small_elem(self): + # GH 8542 + # length=2 + df = DataFrame( + {"event": ["start", "start"], "change": [1234, 5678]}, + index=pd.DatetimeIndex(["2014-09-10", "2013-10-10"]), + ) + grouped = df.groupby([Grouper(freq="ME"), "event"]) + assert len(grouped.groups) == 2 + assert grouped.ngroups == 2 + assert (Timestamp("2014-09-30"), "start") in grouped.groups + assert (Timestamp("2013-10-31"), "start") in grouped.groups + + res = grouped.get_group((Timestamp("2014-09-30"), "start")) + tm.assert_frame_equal(res, df.iloc[[0], :]) + res = grouped.get_group((Timestamp("2013-10-31"), "start")) + tm.assert_frame_equal(res, df.iloc[[1], :]) + + df = DataFrame( + {"event": ["start", "start", "start"], "change": [1234, 5678, 9123]}, + index=pd.DatetimeIndex(["2014-09-10", "2013-10-10", "2014-09-15"]), + ) + grouped = df.groupby([Grouper(freq="ME"), "event"]) + assert len(grouped.groups) == 2 + assert grouped.ngroups == 2 + assert (Timestamp("2014-09-30"), "start") in grouped.groups + assert (Timestamp("2013-10-31"), "start") in grouped.groups + + res = grouped.get_group((Timestamp("2014-09-30"), "start")) + tm.assert_frame_equal(res, df.iloc[[0, 2], :]) + res = grouped.get_group((Timestamp("2013-10-31"), "start")) + tm.assert_frame_equal(res, df.iloc[[1], :]) + + # length=3 + df = DataFrame( + {"event": ["start", "start", "start"], "change": [1234, 5678, 9123]}, + index=pd.DatetimeIndex(["2014-09-10", "2013-10-10", "2014-08-05"]), + ) + grouped = df.groupby([Grouper(freq="ME"), "event"]) + assert len(grouped.groups) == 3 + assert grouped.ngroups == 3 + assert (Timestamp("2014-09-30"), "start") in grouped.groups + assert (Timestamp("2013-10-31"), "start") in grouped.groups + assert (Timestamp("2014-08-31"), "start") in grouped.groups + + res = grouped.get_group((Timestamp("2014-09-30"), "start")) + tm.assert_frame_equal(res, df.iloc[[0], :]) + res = grouped.get_group((Timestamp("2013-10-31"), "start")) + tm.assert_frame_equal(res, df.iloc[[1], :]) + res = grouped.get_group((Timestamp("2014-08-31"), "start")) + tm.assert_frame_equal(res, df.iloc[[2], :]) + + def test_grouping_string_repr(self): + # GH 13394 + mi = MultiIndex.from_arrays([list("AAB"), list("aba")]) + df = DataFrame([[1, 2, 3]], columns=mi) + gr = df.groupby(df[("A", "a")]) + + result = gr._grouper.groupings[0].__repr__() + expected = "Grouping(('A', 'a'))" + assert result == expected + + +def test_grouping_by_key_is_in_axis(): + # GH#50413 - Groupers specified by key are in-axis + df = DataFrame({"a": [1, 1, 2], "b": [1, 1, 2], "c": [3, 4, 5]}).set_index("a") + gb = df.groupby([Grouper(level="a"), Grouper(key="b")], as_index=False) + assert not gb._grouper.groupings[0].in_axis + assert gb._grouper.groupings[1].in_axis + + # Currently only in-axis groupings are including in the result when as_index=False; + # This is likely to change in the future. + msg = "A grouping .* was excluded from the result" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = gb.sum() + expected = DataFrame({"b": [1, 2], "c": [7, 5]}) + tm.assert_frame_equal(result, expected) + + +def test_grouper_groups(): + # GH#51182 check Grouper.groups does not raise AttributeError + df = DataFrame({"a": [1, 2, 3], "b": 1}) + grper = Grouper(key="a") + gb = df.groupby(grper) + + msg = "Use GroupBy.groups instead" + with tm.assert_produces_warning(FutureWarning, match=msg): + res = grper.groups + assert res is gb.groups + + msg = "Use GroupBy.grouper instead" + with tm.assert_produces_warning(FutureWarning, match=msg): + res = grper.grouper + assert res is gb._grouper + + msg = "Grouper.obj is deprecated and will be removed" + with tm.assert_produces_warning(FutureWarning, match=msg): + res = grper.obj + assert res is gb.obj + + msg = "Use Resampler.ax instead" + with tm.assert_produces_warning(FutureWarning, match=msg): + grper.ax + + msg = "Grouper.indexer is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + grper.indexer + + +@pytest.mark.parametrize("attr", ["group_index", "result_index", "group_arraylike"]) +def test_depr_grouping_attrs(attr): + # GH#56148 + df = DataFrame({"a": [1, 1, 2], "b": [3, 4, 5]}) + gb = df.groupby("a") + msg = f"{attr} is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + getattr(gb._grouper.groupings[0], attr) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_index_as_string.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_index_as_string.py new file mode 100644 index 0000000000000000000000000000000000000000..4aaf3de9a23b2416603947db312bb49eea343ba8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_index_as_string.py @@ -0,0 +1,85 @@ +import numpy as np +import pytest + +import pandas as pd +import pandas._testing as tm + + +@pytest.fixture(params=[["inner"], ["inner", "outer"]]) +def frame(request): + levels = request.param + df = pd.DataFrame( + { + "outer": ["a", "a", "a", "b", "b", "b"], + "inner": [1, 2, 3, 1, 2, 3], + "A": np.arange(6), + "B": ["one", "one", "two", "two", "one", "one"], + } + ) + if levels: + df = df.set_index(levels) + + return df + + +@pytest.fixture() +def series(): + df = pd.DataFrame( + { + "outer": ["a", "a", "a", "b", "b", "b"], + "inner": [1, 2, 3, 1, 2, 3], + "A": np.arange(6), + "B": ["one", "one", "two", "two", "one", "one"], + } + ) + s = df.set_index(["outer", "inner", "B"])["A"] + + return s + + +@pytest.mark.parametrize( + "key_strs,groupers", + [ + ("inner", pd.Grouper(level="inner")), # Index name + (["inner"], [pd.Grouper(level="inner")]), # List of index name + (["B", "inner"], ["B", pd.Grouper(level="inner")]), # Column and index + (["inner", "B"], [pd.Grouper(level="inner"), "B"]), # Index and column + ], +) +def test_grouper_index_level_as_string(frame, key_strs, groupers): + if "B" not in key_strs or "outer" in frame.columns: + result = frame.groupby(key_strs).mean(numeric_only=True) + expected = frame.groupby(groupers).mean(numeric_only=True) + else: + result = frame.groupby(key_strs).mean() + expected = frame.groupby(groupers).mean() + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "levels", + [ + "inner", + "outer", + "B", + ["inner"], + ["outer"], + ["B"], + ["inner", "outer"], + ["outer", "inner"], + ["inner", "outer", "B"], + ["B", "outer", "inner"], + ], +) +def test_grouper_index_level_as_string_series(series, levels): + # Compute expected result + if isinstance(levels, list): + groupers = [pd.Grouper(level=lv) for lv in levels] + else: + groupers = pd.Grouper(level=levels) + + expected = series.groupby(groupers).mean() + + # Compute and check result + result = series.groupby(levels).mean() + tm.assert_series_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_indexing.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_indexing.py new file mode 100644 index 0000000000000000000000000000000000000000..664c52babac1381f77f2e2ee7266a9d41031f15e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_indexing.py @@ -0,0 +1,333 @@ +# Test GroupBy._positional_selector positional grouped indexing GH#42864 + +import numpy as np +import pytest + +import pandas as pd +import pandas._testing as tm + + +@pytest.mark.parametrize( + "arg, expected_rows", + [ + [0, [0, 1, 4]], + [2, [5]], + [5, []], + [-1, [3, 4, 7]], + [-2, [1, 6]], + [-6, []], + ], +) +def test_int(slice_test_df, slice_test_grouped, arg, expected_rows): + # Test single integer + result = slice_test_grouped._positional_selector[arg] + expected = slice_test_df.iloc[expected_rows] + + tm.assert_frame_equal(result, expected) + + +def test_slice(slice_test_df, slice_test_grouped): + # Test single slice + result = slice_test_grouped._positional_selector[0:3:2] + expected = slice_test_df.iloc[[0, 1, 4, 5]] + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "arg, expected_rows", + [ + [[0, 2], [0, 1, 4, 5]], + [[0, 2, -1], [0, 1, 3, 4, 5, 7]], + [range(0, 3, 2), [0, 1, 4, 5]], + [{0, 2}, [0, 1, 4, 5]], + ], + ids=[ + "list", + "negative", + "range", + "set", + ], +) +def test_list(slice_test_df, slice_test_grouped, arg, expected_rows): + # Test lists of integers and integer valued iterables + result = slice_test_grouped._positional_selector[arg] + expected = slice_test_df.iloc[expected_rows] + + tm.assert_frame_equal(result, expected) + + +def test_ints(slice_test_df, slice_test_grouped): + # Test tuple of ints + result = slice_test_grouped._positional_selector[0, 2, -1] + expected = slice_test_df.iloc[[0, 1, 3, 4, 5, 7]] + + tm.assert_frame_equal(result, expected) + + +def test_slices(slice_test_df, slice_test_grouped): + # Test tuple of slices + result = slice_test_grouped._positional_selector[:2, -2:] + expected = slice_test_df.iloc[[0, 1, 2, 3, 4, 6, 7]] + + tm.assert_frame_equal(result, expected) + + +def test_mix(slice_test_df, slice_test_grouped): + # Test mixed tuple of ints and slices + result = slice_test_grouped._positional_selector[0, 1, -2:] + expected = slice_test_df.iloc[[0, 1, 2, 3, 4, 6, 7]] + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "arg, expected_rows", + [ + [0, [0, 1, 4]], + [[0, 2, -1], [0, 1, 3, 4, 5, 7]], + [(slice(None, 2), slice(-2, None)), [0, 1, 2, 3, 4, 6, 7]], + ], +) +def test_as_index(slice_test_df, arg, expected_rows): + # Test the default as_index behaviour + result = slice_test_df.groupby("Group", sort=False)._positional_selector[arg] + expected = slice_test_df.iloc[expected_rows] + + tm.assert_frame_equal(result, expected) + + +def test_doc_examples(): + # Test the examples in the documentation + df = pd.DataFrame( + [["a", 1], ["a", 2], ["a", 3], ["b", 4], ["b", 5]], columns=["A", "B"] + ) + + grouped = df.groupby("A", as_index=False) + + result = grouped._positional_selector[1:2] + expected = pd.DataFrame([["a", 2], ["b", 5]], columns=["A", "B"], index=[1, 4]) + + tm.assert_frame_equal(result, expected) + + result = grouped._positional_selector[1, -1] + expected = pd.DataFrame( + [["a", 2], ["a", 3], ["b", 5]], columns=["A", "B"], index=[1, 2, 4] + ) + + tm.assert_frame_equal(result, expected) + + +@pytest.fixture() +def multiindex_data(): + rng = np.random.default_rng(2) + ndates = 100 + nitems = 20 + dates = pd.date_range("20130101", periods=ndates, freq="D") + items = [f"item {i}" for i in range(nitems)] + + data = {} + for date in dates: + nitems_for_date = nitems - rng.integers(0, 12) + levels = [ + (item, rng.integers(0, 10000) / 100, rng.integers(0, 10000) / 100) + for item in items[:nitems_for_date] + ] + levels.sort(key=lambda x: x[1]) + data[date] = levels + + return data + + +def _make_df_from_data(data): + rows = {} + for date in data: + for level in data[date]: + rows[(date, level[0])] = {"A": level[1], "B": level[2]} + + df = pd.DataFrame.from_dict(rows, orient="index") + df.index.names = ("Date", "Item") + return df + + +def test_multiindex(multiindex_data): + # Test the multiindex mentioned as the use-case in the documentation + df = _make_df_from_data(multiindex_data) + result = df.groupby("Date", as_index=False).nth(slice(3, -3)) + + sliced = {date: multiindex_data[date][3:-3] for date in multiindex_data} + expected = _make_df_from_data(sliced) + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("arg", [1, 5, 30, 1000, -1, -5, -30, -1000]) +@pytest.mark.parametrize("method", ["head", "tail"]) +@pytest.mark.parametrize("simulated", [True, False]) +def test_against_head_and_tail(arg, method, simulated): + # Test gives the same results as grouped head and tail + n_groups = 100 + n_rows_per_group = 30 + + data = { + "group": [ + f"group {g}" for j in range(n_rows_per_group) for g in range(n_groups) + ], + "value": [ + f"group {g} row {j}" + for j in range(n_rows_per_group) + for g in range(n_groups) + ], + } + df = pd.DataFrame(data) + grouped = df.groupby("group", as_index=False) + size = arg if arg >= 0 else n_rows_per_group + arg + + if method == "head": + result = grouped._positional_selector[:arg] + + if simulated: + indices = [ + j * n_groups + i + for j in range(size) + for i in range(n_groups) + if j * n_groups + i < n_groups * n_rows_per_group + ] + expected = df.iloc[indices] + + else: + expected = grouped.head(arg) + + else: + result = grouped._positional_selector[-arg:] + + if simulated: + indices = [ + (n_rows_per_group + j - size) * n_groups + i + for j in range(size) + for i in range(n_groups) + if (n_rows_per_group + j - size) * n_groups + i >= 0 + ] + expected = df.iloc[indices] + + else: + expected = grouped.tail(arg) + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("start", [None, 0, 1, 10, -1, -10]) +@pytest.mark.parametrize("stop", [None, 0, 1, 10, -1, -10]) +@pytest.mark.parametrize("step", [None, 1, 5]) +def test_against_df_iloc(start, stop, step): + # Test that a single group gives the same results as DataFrame.iloc + n_rows = 30 + + data = { + "group": ["group 0"] * n_rows, + "value": list(range(n_rows)), + } + df = pd.DataFrame(data) + grouped = df.groupby("group", as_index=False) + + result = grouped._positional_selector[start:stop:step] + expected = df.iloc[start:stop:step] + + tm.assert_frame_equal(result, expected) + + +def test_series(): + # Test grouped Series + ser = pd.Series([1, 2, 3, 4, 5], index=["a", "a", "a", "b", "b"]) + grouped = ser.groupby(level=0) + result = grouped._positional_selector[1:2] + expected = pd.Series([2, 5], index=["a", "b"]) + + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("step", [1, 2, 3, 4, 5]) +def test_step(step): + # Test slice with various step values + data = [["x", f"x{i}"] for i in range(5)] + data += [["y", f"y{i}"] for i in range(4)] + data += [["z", f"z{i}"] for i in range(3)] + df = pd.DataFrame(data, columns=["A", "B"]) + + grouped = df.groupby("A", as_index=False) + + result = grouped._positional_selector[::step] + + data = [["x", f"x{i}"] for i in range(0, 5, step)] + data += [["y", f"y{i}"] for i in range(0, 4, step)] + data += [["z", f"z{i}"] for i in range(0, 3, step)] + + index = [0 + i for i in range(0, 5, step)] + index += [5 + i for i in range(0, 4, step)] + index += [9 + i for i in range(0, 3, step)] + + expected = pd.DataFrame(data, columns=["A", "B"], index=index) + + tm.assert_frame_equal(result, expected) + + +@pytest.fixture() +def column_group_df(): + return pd.DataFrame( + [[0, 1, 2, 3, 4, 5, 6], [0, 0, 1, 0, 1, 0, 2]], + columns=["A", "B", "C", "D", "E", "F", "G"], + ) + + +def test_column_axis(column_group_df): + msg = "DataFrame.groupby with axis=1" + with tm.assert_produces_warning(FutureWarning, match=msg): + g = column_group_df.groupby(column_group_df.iloc[1], axis=1) + result = g._positional_selector[1:-1] + expected = column_group_df.iloc[:, [1, 3]] + + tm.assert_frame_equal(result, expected) + + +def test_columns_on_iter(): + # GitHub issue #44821 + df = pd.DataFrame({k: range(10) for k in "ABC"}) + + # Group-by and select columns + cols = ["A", "B"] + for _, dg in df.groupby(df.A < 4)[cols]: + tm.assert_index_equal(dg.columns, pd.Index(cols)) + assert "C" not in dg.columns + + +@pytest.mark.parametrize("func", [list, pd.Index, pd.Series, np.array]) +def test_groupby_duplicated_columns(func): + # GH#44924 + df = pd.DataFrame( + { + "A": [1, 2], + "B": [3, 3], + "C": ["G", "G"], + } + ) + result = df.groupby("C")[func(["A", "B", "A"])].mean() + expected = pd.DataFrame( + [[1.5, 3.0, 1.5]], columns=["A", "B", "A"], index=pd.Index(["G"], name="C") + ) + tm.assert_frame_equal(result, expected) + + +def test_groupby_get_nonexisting_groups(): + # GH#32492 + df = pd.DataFrame( + data={ + "A": ["a1", "a2", None], + "B": ["b1", "b2", "b1"], + "val": [1, 2, 3], + } + ) + grps = df.groupby(by=["A", "B"]) + + msg = "('a2', 'b1')" + with pytest.raises(KeyError, match=msg): + grps.get_group(("a2", "b1")) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_libgroupby.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_libgroupby.py new file mode 100644 index 0000000000000000000000000000000000000000..35b8fa93b8e033b8dd9287bc7de8e1ca18ade439 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_libgroupby.py @@ -0,0 +1,331 @@ +import numpy as np +import pytest + +from pandas._libs import groupby as libgroupby +from pandas._libs.groupby import ( + group_cumprod, + group_cumsum, + group_mean, + group_sum, + group_var, +) + +from pandas.core.dtypes.common import ensure_platform_int + +from pandas import isna +import pandas._testing as tm + + +class GroupVarTestMixin: + def test_group_var_generic_1d(self): + prng = np.random.default_rng(2) + + out = (np.nan * np.ones((5, 1))).astype(self.dtype) + counts = np.zeros(5, dtype="int64") + values = 10 * prng.random((15, 1)).astype(self.dtype) + labels = np.tile(np.arange(5), (3,)).astype("intp") + + expected_out = ( + np.squeeze(values).reshape((5, 3), order="F").std(axis=1, ddof=1) ** 2 + )[:, np.newaxis] + expected_counts = counts + 3 + + self.algo(out, counts, values, labels) + assert np.allclose(out, expected_out, self.rtol) + tm.assert_numpy_array_equal(counts, expected_counts) + + def test_group_var_generic_1d_flat_labels(self): + prng = np.random.default_rng(2) + + out = (np.nan * np.ones((1, 1))).astype(self.dtype) + counts = np.zeros(1, dtype="int64") + values = 10 * prng.random((5, 1)).astype(self.dtype) + labels = np.zeros(5, dtype="intp") + + expected_out = np.array([[values.std(ddof=1) ** 2]]) + expected_counts = counts + 5 + + self.algo(out, counts, values, labels) + + assert np.allclose(out, expected_out, self.rtol) + tm.assert_numpy_array_equal(counts, expected_counts) + + def test_group_var_generic_2d_all_finite(self): + prng = np.random.default_rng(2) + + out = (np.nan * np.ones((5, 2))).astype(self.dtype) + counts = np.zeros(5, dtype="int64") + values = 10 * prng.random((10, 2)).astype(self.dtype) + labels = np.tile(np.arange(5), (2,)).astype("intp") + + expected_out = np.std(values.reshape(2, 5, 2), ddof=1, axis=0) ** 2 + expected_counts = counts + 2 + + self.algo(out, counts, values, labels) + assert np.allclose(out, expected_out, self.rtol) + tm.assert_numpy_array_equal(counts, expected_counts) + + def test_group_var_generic_2d_some_nan(self): + prng = np.random.default_rng(2) + + out = (np.nan * np.ones((5, 2))).astype(self.dtype) + counts = np.zeros(5, dtype="int64") + values = 10 * prng.random((10, 2)).astype(self.dtype) + values[:, 1] = np.nan + labels = np.tile(np.arange(5), (2,)).astype("intp") + + expected_out = np.vstack( + [ + values[:, 0].reshape(5, 2, order="F").std(ddof=1, axis=1) ** 2, + np.nan * np.ones(5), + ] + ).T.astype(self.dtype) + expected_counts = counts + 2 + + self.algo(out, counts, values, labels) + tm.assert_almost_equal(out, expected_out, rtol=0.5e-06) + tm.assert_numpy_array_equal(counts, expected_counts) + + def test_group_var_constant(self): + # Regression test from GH 10448. + + out = np.array([[np.nan]], dtype=self.dtype) + counts = np.array([0], dtype="int64") + values = 0.832845131556193 * np.ones((3, 1), dtype=self.dtype) + labels = np.zeros(3, dtype="intp") + + self.algo(out, counts, values, labels) + + assert counts[0] == 3 + assert out[0, 0] >= 0 + tm.assert_almost_equal(out[0, 0], 0.0) + + +class TestGroupVarFloat64(GroupVarTestMixin): + __test__ = True + + algo = staticmethod(group_var) + dtype = np.float64 + rtol = 1e-5 + + def test_group_var_large_inputs(self): + prng = np.random.default_rng(2) + + out = np.array([[np.nan]], dtype=self.dtype) + counts = np.array([0], dtype="int64") + values = (prng.random(10**6) + 10**12).astype(self.dtype) + values.shape = (10**6, 1) + labels = np.zeros(10**6, dtype="intp") + + self.algo(out, counts, values, labels) + + assert counts[0] == 10**6 + tm.assert_almost_equal(out[0, 0], 1.0 / 12, rtol=0.5e-3) + + +class TestGroupVarFloat32(GroupVarTestMixin): + __test__ = True + + algo = staticmethod(group_var) + dtype = np.float32 + rtol = 1e-2 + + +@pytest.mark.parametrize("dtype", ["float32", "float64"]) +def test_group_ohlc(dtype): + obj = np.array(np.random.default_rng(2).standard_normal(20), dtype=dtype) + + bins = np.array([6, 12, 20]) + out = np.zeros((3, 4), dtype) + counts = np.zeros(len(out), dtype=np.int64) + labels = ensure_platform_int(np.repeat(np.arange(3), np.diff(np.r_[0, bins]))) + + func = libgroupby.group_ohlc + func(out, counts, obj[:, None], labels) + + def _ohlc(group): + if isna(group).all(): + return np.repeat(np.nan, 4) + return [group[0], group.max(), group.min(), group[-1]] + + expected = np.array([_ohlc(obj[:6]), _ohlc(obj[6:12]), _ohlc(obj[12:])]) + + tm.assert_almost_equal(out, expected) + tm.assert_numpy_array_equal(counts, np.array([6, 6, 8], dtype=np.int64)) + + obj[:6] = np.nan + func(out, counts, obj[:, None], labels) + expected[0] = np.nan + tm.assert_almost_equal(out, expected) + + +def _check_cython_group_transform_cumulative(pd_op, np_op, dtype): + """ + Check a group transform that executes a cumulative function. + + Parameters + ---------- + pd_op : callable + The pandas cumulative function. + np_op : callable + The analogous one in NumPy. + dtype : type + The specified dtype of the data. + """ + is_datetimelike = False + + data = np.array([[1], [2], [3], [4]], dtype=dtype) + answer = np.zeros_like(data) + + labels = np.array([0, 0, 0, 0], dtype=np.intp) + ngroups = 1 + pd_op(answer, data, labels, ngroups, is_datetimelike) + + tm.assert_numpy_array_equal(np_op(data), answer[:, 0], check_dtype=False) + + +@pytest.mark.parametrize("np_dtype", ["int64", "uint64", "float32", "float64"]) +def test_cython_group_transform_cumsum(np_dtype): + # see gh-4095 + dtype = np.dtype(np_dtype).type + pd_op, np_op = group_cumsum, np.cumsum + _check_cython_group_transform_cumulative(pd_op, np_op, dtype) + + +def test_cython_group_transform_cumprod(): + # see gh-4095 + dtype = np.float64 + pd_op, np_op = group_cumprod, np.cumprod + _check_cython_group_transform_cumulative(pd_op, np_op, dtype) + + +def test_cython_group_transform_algos(): + # see gh-4095 + is_datetimelike = False + + # with nans + labels = np.array([0, 0, 0, 0, 0], dtype=np.intp) + ngroups = 1 + + data = np.array([[1], [2], [3], [np.nan], [4]], dtype="float64") + actual = np.zeros_like(data) + actual.fill(np.nan) + group_cumprod(actual, data, labels, ngroups, is_datetimelike) + expected = np.array([1, 2, 6, np.nan, 24], dtype="float64") + tm.assert_numpy_array_equal(actual[:, 0], expected) + + actual = np.zeros_like(data) + actual.fill(np.nan) + group_cumsum(actual, data, labels, ngroups, is_datetimelike) + expected = np.array([1, 3, 6, np.nan, 10], dtype="float64") + tm.assert_numpy_array_equal(actual[:, 0], expected) + + # timedelta + is_datetimelike = True + data = np.array([np.timedelta64(1, "ns")] * 5, dtype="m8[ns]")[:, None] + actual = np.zeros_like(data, dtype="int64") + group_cumsum(actual, data.view("int64"), labels, ngroups, is_datetimelike) + expected = np.array( + [ + np.timedelta64(1, "ns"), + np.timedelta64(2, "ns"), + np.timedelta64(3, "ns"), + np.timedelta64(4, "ns"), + np.timedelta64(5, "ns"), + ] + ) + tm.assert_numpy_array_equal(actual[:, 0].view("m8[ns]"), expected) + + +def test_cython_group_mean_datetimelike(): + actual = np.zeros(shape=(1, 1), dtype="float64") + counts = np.array([0], dtype="int64") + data = ( + np.array( + [np.timedelta64(2, "ns"), np.timedelta64(4, "ns"), np.timedelta64("NaT")], + dtype="m8[ns]", + )[:, None] + .view("int64") + .astype("float64") + ) + labels = np.zeros(len(data), dtype=np.intp) + + group_mean(actual, counts, data, labels, is_datetimelike=True) + + tm.assert_numpy_array_equal(actual[:, 0], np.array([3], dtype="float64")) + + +def test_cython_group_mean_wrong_min_count(): + actual = np.zeros(shape=(1, 1), dtype="float64") + counts = np.zeros(1, dtype="int64") + data = np.zeros(1, dtype="float64")[:, None] + labels = np.zeros(1, dtype=np.intp) + + with pytest.raises(AssertionError, match="min_count"): + group_mean(actual, counts, data, labels, is_datetimelike=True, min_count=0) + + +def test_cython_group_mean_not_datetimelike_but_has_NaT_values(): + actual = np.zeros(shape=(1, 1), dtype="float64") + counts = np.array([0], dtype="int64") + data = ( + np.array( + [np.timedelta64("NaT"), np.timedelta64("NaT")], + dtype="m8[ns]", + )[:, None] + .view("int64") + .astype("float64") + ) + labels = np.zeros(len(data), dtype=np.intp) + + group_mean(actual, counts, data, labels, is_datetimelike=False) + + tm.assert_numpy_array_equal( + actual[:, 0], np.array(np.divide(np.add(data[0], data[1]), 2), dtype="float64") + ) + + +def test_cython_group_mean_Inf_at_begining_and_end(): + # GH 50367 + actual = np.array([[np.nan, np.nan], [np.nan, np.nan]], dtype="float64") + counts = np.array([0, 0], dtype="int64") + data = np.array( + [[np.inf, 1.0], [1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0], [5, np.inf]], + dtype="float64", + ) + labels = np.array([0, 1, 0, 1, 0, 1], dtype=np.intp) + + group_mean(actual, counts, data, labels, is_datetimelike=False) + + expected = np.array([[np.inf, 3], [3, np.inf]], dtype="float64") + + tm.assert_numpy_array_equal( + actual, + expected, + ) + + +@pytest.mark.parametrize( + "values, out", + [ + ([[np.inf], [np.inf], [np.inf]], [[np.inf], [np.inf]]), + ([[np.inf], [np.inf], [-np.inf]], [[np.inf], [np.nan]]), + ([[np.inf], [-np.inf], [np.inf]], [[np.inf], [np.nan]]), + ([[np.inf], [-np.inf], [-np.inf]], [[np.inf], [-np.inf]]), + ], +) +def test_cython_group_sum_Inf_at_begining_and_end(values, out): + # GH #53606 + actual = np.array([[np.nan], [np.nan]], dtype="float64") + counts = np.array([0, 0], dtype="int64") + data = np.array(values, dtype="float64") + labels = np.array([0, 1, 1], dtype=np.intp) + + group_sum(actual, counts, data, labels, None, is_datetimelike=False) + + expected = np.array(out, dtype="float64") + + tm.assert_numpy_array_equal( + actual, + expected, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_missing.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_missing.py new file mode 100644 index 0000000000000000000000000000000000000000..3180a92be1236688e044758bf2334a0985e7aee1 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_missing.py @@ -0,0 +1,163 @@ +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + DataFrame, + Index, + date_range, +) +import pandas._testing as tm + + +@pytest.mark.parametrize("func", ["ffill", "bfill"]) +def test_groupby_column_index_name_lost_fill_funcs(func): + # GH: 29764 groupby loses index sometimes + df = DataFrame( + [[1, 1.0, -1.0], [1, np.nan, np.nan], [1, 2.0, -2.0]], + columns=Index(["type", "a", "b"], name="idx"), + ) + df_grouped = df.groupby(["type"])[["a", "b"]] + result = getattr(df_grouped, func)().columns + expected = Index(["a", "b"], name="idx") + tm.assert_index_equal(result, expected) + + +@pytest.mark.parametrize("func", ["ffill", "bfill"]) +def test_groupby_fill_duplicate_column_names(func): + # GH: 25610 ValueError with duplicate column names + df1 = DataFrame({"field1": [1, 3, 4], "field2": [1, 3, 4]}) + df2 = DataFrame({"field1": [1, np.nan, 4]}) + df_grouped = pd.concat([df1, df2], axis=1).groupby(by=["field2"]) + expected = DataFrame( + [[1, 1.0], [3, np.nan], [4, 4.0]], columns=["field1", "field1"] + ) + result = getattr(df_grouped, func)() + tm.assert_frame_equal(result, expected) + + +def test_ffill_missing_arguments(): + # GH 14955 + df = DataFrame({"a": [1, 2], "b": [1, 1]}) + msg = "DataFrameGroupBy.fillna is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + with pytest.raises(ValueError, match="Must specify a fill"): + df.groupby("b").fillna() + + +@pytest.mark.parametrize( + "method, expected", [("ffill", [None, "a", "a"]), ("bfill", ["a", "a", None])] +) +def test_fillna_with_string_dtype(method, expected): + # GH 40250 + df = DataFrame({"a": pd.array([None, "a", None], dtype="string"), "b": [0, 0, 0]}) + grp = df.groupby("b") + msg = "DataFrameGroupBy.fillna is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = grp.fillna(method=method) + expected = DataFrame({"a": pd.array(expected, dtype="string")}) + tm.assert_frame_equal(result, expected) + + +def test_fill_consistency(): + # GH9221 + # pass thru keyword arguments to the generated wrapper + # are set if the passed kw is None (only) + df = DataFrame( + index=pd.MultiIndex.from_product( + [["value1", "value2"], date_range("2014-01-01", "2014-01-06")] + ), + columns=Index(["1", "2"], name="id"), + ) + df["1"] = [ + np.nan, + 1, + np.nan, + np.nan, + 11, + np.nan, + np.nan, + 2, + np.nan, + np.nan, + 22, + np.nan, + ] + df["2"] = [ + np.nan, + 3, + np.nan, + np.nan, + 33, + np.nan, + np.nan, + 4, + np.nan, + np.nan, + 44, + np.nan, + ] + + msg = "The 'axis' keyword in DataFrame.groupby is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + expected = df.groupby(level=0, axis=0).fillna(method="ffill") + + msg = "DataFrame.groupby with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = df.T.groupby(level=0, axis=1).fillna(method="ffill").T + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("method", ["ffill", "bfill"]) +@pytest.mark.parametrize("dropna", [True, False]) +@pytest.mark.parametrize("has_nan_group", [True, False]) +def test_ffill_handles_nan_groups(dropna, method, has_nan_group): + # GH 34725 + + df_without_nan_rows = DataFrame([(1, 0.1), (2, 0.2)]) + + ridx = [-1, 0, -1, -1, 1, -1] + df = df_without_nan_rows.reindex(ridx).reset_index(drop=True) + + group_b = np.nan if has_nan_group else "b" + df["group_col"] = pd.Series(["a"] * 3 + [group_b] * 3) + + grouped = df.groupby(by="group_col", dropna=dropna) + result = getattr(grouped, method)(limit=None) + + expected_rows = { + ("ffill", True, True): [-1, 0, 0, -1, -1, -1], + ("ffill", True, False): [-1, 0, 0, -1, 1, 1], + ("ffill", False, True): [-1, 0, 0, -1, 1, 1], + ("ffill", False, False): [-1, 0, 0, -1, 1, 1], + ("bfill", True, True): [0, 0, -1, -1, -1, -1], + ("bfill", True, False): [0, 0, -1, 1, 1, -1], + ("bfill", False, True): [0, 0, -1, 1, 1, -1], + ("bfill", False, False): [0, 0, -1, 1, 1, -1], + } + + ridx = expected_rows.get((method, dropna, has_nan_group)) + expected = df_without_nan_rows.reindex(ridx).reset_index(drop=True) + # columns are a 'take' on df.columns, which are object dtype + expected.columns = expected.columns.astype(object) + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("min_count, value", [(2, np.nan), (-1, 1.0)]) +@pytest.mark.parametrize("func", ["first", "last", "max", "min"]) +def test_min_count(func, min_count, value): + # GH#37821 + df = DataFrame({"a": [1] * 3, "b": [1, np.nan, np.nan], "c": [np.nan] * 3}) + result = getattr(df.groupby("a"), func)(min_count=min_count) + expected = DataFrame({"b": [value], "c": [np.nan]}, index=Index([1], name="a")) + tm.assert_frame_equal(result, expected) + + +def test_indices_with_missing(): + # GH 9304 + df = DataFrame({"a": [1, 1, np.nan], "b": [2, 3, 4], "c": [5, 6, 7]}) + g = df.groupby(["a", "b"]) + result = g.indices + expected = {(1.0, 2): np.array([0]), (1.0, 3): np.array([1])} + assert result == expected diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_numba.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_numba.py new file mode 100644 index 0000000000000000000000000000000000000000..ee7d3424724932befa772e47162e032e28f2cd1d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_numba.py @@ -0,0 +1,80 @@ +import pytest + +from pandas import ( + DataFrame, + Series, + option_context, +) +import pandas._testing as tm + +pytestmark = pytest.mark.single_cpu + +pytest.importorskip("numba") + + +@pytest.mark.filterwarnings("ignore") +# Filter warnings when parallel=True and the function can't be parallelized by Numba +class TestEngine: + def test_cython_vs_numba_frame( + self, sort, nogil, parallel, nopython, numba_supported_reductions + ): + func, kwargs = numba_supported_reductions + df = DataFrame({"a": [3, 2, 3, 2], "b": range(4), "c": range(1, 5)}) + engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython} + gb = df.groupby("a", sort=sort) + result = getattr(gb, func)( + engine="numba", engine_kwargs=engine_kwargs, **kwargs + ) + expected = getattr(gb, func)(**kwargs) + tm.assert_frame_equal(result, expected) + + def test_cython_vs_numba_getitem( + self, sort, nogil, parallel, nopython, numba_supported_reductions + ): + func, kwargs = numba_supported_reductions + df = DataFrame({"a": [3, 2, 3, 2], "b": range(4), "c": range(1, 5)}) + engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython} + gb = df.groupby("a", sort=sort)["c"] + result = getattr(gb, func)( + engine="numba", engine_kwargs=engine_kwargs, **kwargs + ) + expected = getattr(gb, func)(**kwargs) + tm.assert_series_equal(result, expected) + + def test_cython_vs_numba_series( + self, sort, nogil, parallel, nopython, numba_supported_reductions + ): + func, kwargs = numba_supported_reductions + ser = Series(range(3), index=[1, 2, 1], name="foo") + engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython} + gb = ser.groupby(level=0, sort=sort) + result = getattr(gb, func)( + engine="numba", engine_kwargs=engine_kwargs, **kwargs + ) + expected = getattr(gb, func)(**kwargs) + tm.assert_series_equal(result, expected) + + def test_as_index_false_unsupported(self, numba_supported_reductions): + func, kwargs = numba_supported_reductions + df = DataFrame({"a": [3, 2, 3, 2], "b": range(4), "c": range(1, 5)}) + gb = df.groupby("a", as_index=False) + with pytest.raises(NotImplementedError, match="as_index=False"): + getattr(gb, func)(engine="numba", **kwargs) + + def test_axis_1_unsupported(self, numba_supported_reductions): + func, kwargs = numba_supported_reductions + df = DataFrame({"a": [3, 2, 3, 2], "b": range(4), "c": range(1, 5)}) + gb = df.groupby("a", axis=1) + with pytest.raises(NotImplementedError, match="axis=1"): + getattr(gb, func)(engine="numba", **kwargs) + + def test_no_engine_doesnt_raise(self): + # GH55520 + df = DataFrame({"a": [3, 2, 3, 2], "b": range(4), "c": range(1, 5)}) + gb = df.groupby("a") + # Make sure behavior of functions w/out engine argument don't raise + # when the global use_numba option is set + with option_context("compute.use_numba", True): + res = gb.agg({"b": "first"}) + expected = gb.agg({"b": "first"}) + tm.assert_frame_equal(res, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_numeric_only.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_numeric_only.py new file mode 100644 index 0000000000000000000000000000000000000000..ff4685b1e412d9b43503bc081686dc486fb5c62d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_numeric_only.py @@ -0,0 +1,521 @@ +import re + +import numpy as np +import pytest + +from pandas._libs import lib + +import pandas as pd +from pandas import ( + DataFrame, + Index, + Series, + Timestamp, + date_range, +) +import pandas._testing as tm +from pandas.tests.groupby import get_groupby_method_args + + +class TestNumericOnly: + # make sure that we are passing thru kwargs to our agg functions + + @pytest.fixture + def df(self): + # GH3668 + # GH5724 + df = DataFrame( + { + "group": [1, 1, 2], + "int": [1, 2, 3], + "float": [4.0, 5.0, 6.0], + "string": list("abc"), + "category_string": Series(list("abc")).astype("category"), + "category_int": [7, 8, 9], + "datetime": date_range("20130101", periods=3), + "datetimetz": date_range("20130101", periods=3, tz="US/Eastern"), + "timedelta": pd.timedelta_range("1 s", periods=3, freq="s"), + }, + columns=[ + "group", + "int", + "float", + "string", + "category_string", + "category_int", + "datetime", + "datetimetz", + "timedelta", + ], + ) + return df + + @pytest.mark.parametrize("method", ["mean", "median"]) + def test_averages(self, df, method): + # mean / median + expected_columns_numeric = Index(["int", "float", "category_int"]) + + gb = df.groupby("group") + expected = DataFrame( + { + "category_int": [7.5, 9], + "float": [4.5, 6.0], + "timedelta": [pd.Timedelta("1.5s"), pd.Timedelta("3s")], + "int": [1.5, 3], + "datetime": [ + Timestamp("2013-01-01 12:00:00"), + Timestamp("2013-01-03 00:00:00"), + ], + "datetimetz": [ + Timestamp("2013-01-01 12:00:00", tz="US/Eastern"), + Timestamp("2013-01-03 00:00:00", tz="US/Eastern"), + ], + }, + index=Index([1, 2], name="group"), + columns=[ + "int", + "float", + "category_int", + ], + ) + + result = getattr(gb, method)(numeric_only=True) + tm.assert_frame_equal(result.reindex_like(expected), expected) + + expected_columns = expected.columns + + self._check(df, method, expected_columns, expected_columns_numeric) + + @pytest.mark.parametrize("method", ["min", "max"]) + def test_extrema(self, df, method): + # TODO: min, max *should* handle + # categorical (ordered) dtype + + expected_columns = Index( + [ + "int", + "float", + "string", + "category_int", + "datetime", + "datetimetz", + "timedelta", + ] + ) + expected_columns_numeric = expected_columns + + self._check(df, method, expected_columns, expected_columns_numeric) + + @pytest.mark.parametrize("method", ["first", "last"]) + def test_first_last(self, df, method): + expected_columns = Index( + [ + "int", + "float", + "string", + "category_string", + "category_int", + "datetime", + "datetimetz", + "timedelta", + ] + ) + expected_columns_numeric = expected_columns + + self._check(df, method, expected_columns, expected_columns_numeric) + + @pytest.mark.parametrize("method", ["sum", "cumsum"]) + def test_sum_cumsum(self, df, method): + expected_columns_numeric = Index(["int", "float", "category_int"]) + expected_columns = Index( + ["int", "float", "string", "category_int", "timedelta"] + ) + if method == "cumsum": + # cumsum loses string + expected_columns = Index(["int", "float", "category_int", "timedelta"]) + + self._check(df, method, expected_columns, expected_columns_numeric) + + @pytest.mark.parametrize("method", ["prod", "cumprod"]) + def test_prod_cumprod(self, df, method): + expected_columns = Index(["int", "float", "category_int"]) + expected_columns_numeric = expected_columns + + self._check(df, method, expected_columns, expected_columns_numeric) + + @pytest.mark.parametrize("method", ["cummin", "cummax"]) + def test_cummin_cummax(self, df, method): + # like min, max, but don't include strings + expected_columns = Index( + ["int", "float", "category_int", "datetime", "datetimetz", "timedelta"] + ) + + # GH#15561: numeric_only=False set by default like min/max + expected_columns_numeric = expected_columns + + self._check(df, method, expected_columns, expected_columns_numeric) + + def _check(self, df, method, expected_columns, expected_columns_numeric): + gb = df.groupby("group") + + # object dtypes for transformations are not implemented in Cython and + # have no Python fallback + exception = NotImplementedError if method.startswith("cum") else TypeError + + if method in ("min", "max", "cummin", "cummax", "cumsum", "cumprod"): + # The methods default to numeric_only=False and raise TypeError + msg = "|".join( + [ + "Categorical is not ordered", + f"Cannot perform {method} with non-ordered Categorical", + re.escape(f"agg function failed [how->{method},dtype->object]"), + # cumsum/cummin/cummax/cumprod + "function is not implemented for this dtype", + ] + ) + with pytest.raises(exception, match=msg): + getattr(gb, method)() + elif method in ("sum", "mean", "median", "prod"): + msg = "|".join( + [ + "category type does not support sum operations", + re.escape(f"agg function failed [how->{method},dtype->object]"), + re.escape(f"agg function failed [how->{method},dtype->string]"), + ] + ) + with pytest.raises(exception, match=msg): + getattr(gb, method)() + else: + result = getattr(gb, method)() + tm.assert_index_equal(result.columns, expected_columns_numeric) + + if method not in ("first", "last"): + msg = "|".join( + [ + "Categorical is not ordered", + "category type does not support", + "function is not implemented for this dtype", + f"Cannot perform {method} with non-ordered Categorical", + re.escape(f"agg function failed [how->{method},dtype->object]"), + re.escape(f"agg function failed [how->{method},dtype->string]"), + ] + ) + with pytest.raises(exception, match=msg): + getattr(gb, method)(numeric_only=False) + else: + result = getattr(gb, method)(numeric_only=False) + tm.assert_index_equal(result.columns, expected_columns) + + +@pytest.mark.parametrize("numeric_only", [True, False, None]) +def test_axis1_numeric_only(request, groupby_func, numeric_only, using_infer_string): + if groupby_func in ("idxmax", "idxmin"): + pytest.skip("idxmax and idx_min tested in test_idxmin_idxmax_axis1") + if groupby_func in ("corrwith", "skew"): + msg = "GH#47723 groupby.corrwith and skew do not correctly implement axis=1" + request.applymarker(pytest.mark.xfail(reason=msg)) + + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 4)), columns=["A", "B", "C", "D"] + ) + df["E"] = "x" + groups = [1, 2, 3, 1, 2, 3, 1, 2, 3, 4] + gb = df.groupby(groups) + method = getattr(gb, groupby_func) + args = get_groupby_method_args(groupby_func, df) + kwargs = {"axis": 1} + if numeric_only is not None: + # when numeric_only is None we don't pass any argument + kwargs["numeric_only"] = numeric_only + + # Functions without numeric_only and axis args + no_args = ("cumprod", "cumsum", "diff", "fillna", "pct_change", "rank", "shift") + # Functions with axis args + has_axis = ( + "cumprod", + "cumsum", + "diff", + "pct_change", + "rank", + "shift", + "cummax", + "cummin", + "idxmin", + "idxmax", + "fillna", + ) + warn_msg = f"DataFrameGroupBy.{groupby_func} with axis=1 is deprecated" + if numeric_only is not None and groupby_func in no_args: + msg = "got an unexpected keyword argument 'numeric_only'" + if groupby_func in ["cumprod", "cumsum"]: + with pytest.raises(TypeError, match=msg): + with tm.assert_produces_warning(FutureWarning, match=warn_msg): + method(*args, **kwargs) + else: + with pytest.raises(TypeError, match=msg): + method(*args, **kwargs) + elif groupby_func not in has_axis: + msg = "got an unexpected keyword argument 'axis'" + with pytest.raises(TypeError, match=msg): + method(*args, **kwargs) + # fillna and shift are successful even on object dtypes + elif (numeric_only is None or not numeric_only) and groupby_func not in ( + "fillna", + "shift", + ): + msgs = ( + # cummax, cummin, rank + "not supported between instances of", + # cumprod + "can't multiply sequence by non-int of type 'float'", + # cumsum, diff, pct_change + "unsupported operand type", + "has no kernel", + ) + if using_infer_string: + import pyarrow as pa + + errs = (TypeError, pa.lib.ArrowNotImplementedError) + else: + errs = TypeError + with pytest.raises(errs, match=f"({'|'.join(msgs)})"): + with tm.assert_produces_warning(FutureWarning, match=warn_msg): + method(*args, **kwargs) + else: + with tm.assert_produces_warning(FutureWarning, match=warn_msg): + result = method(*args, **kwargs) + + df_expected = df.drop(columns="E").T if numeric_only else df.T + expected = getattr(df_expected, groupby_func)(*args).T + if groupby_func == "shift" and not numeric_only: + # shift with axis=1 leaves the leftmost column as numeric + # but transposing for expected gives us object dtype + expected = expected.astype(float) + + tm.assert_equal(result, expected) + + +@pytest.mark.parametrize( + "kernel, has_arg", + [ + ("all", False), + ("any", False), + ("bfill", False), + ("corr", True), + ("corrwith", True), + ("cov", True), + ("cummax", True), + ("cummin", True), + ("cumprod", True), + ("cumsum", True), + ("diff", False), + ("ffill", False), + ("fillna", False), + ("first", True), + ("idxmax", True), + ("idxmin", True), + ("last", True), + ("max", True), + ("mean", True), + ("median", True), + ("min", True), + ("nth", False), + ("nunique", False), + ("pct_change", False), + ("prod", True), + ("quantile", True), + ("sem", True), + ("skew", True), + ("std", True), + ("sum", True), + ("var", True), + ], +) +@pytest.mark.parametrize("numeric_only", [True, False, lib.no_default]) +@pytest.mark.parametrize("keys", [["a1"], ["a1", "a2"]]) +def test_numeric_only(kernel, has_arg, numeric_only, keys): + # GH#46072 + # drops_nuisance: Whether the op drops nuisance columns even when numeric_only=False + # has_arg: Whether the op has a numeric_only arg + df = DataFrame({"a1": [1, 1], "a2": [2, 2], "a3": [5, 6], "b": 2 * [object]}) + + args = get_groupby_method_args(kernel, df) + kwargs = {} if numeric_only is lib.no_default else {"numeric_only": numeric_only} + + gb = df.groupby(keys) + method = getattr(gb, kernel) + if has_arg and numeric_only is True: + # Cases where b does not appear in the result + result = method(*args, **kwargs) + assert "b" not in result.columns + elif ( + # kernels that work on any dtype and have numeric_only arg + kernel in ("first", "last") + or ( + # kernels that work on any dtype and don't have numeric_only arg + kernel in ("any", "all", "bfill", "ffill", "fillna", "nth", "nunique") + and numeric_only is lib.no_default + ) + ): + warn = FutureWarning if kernel == "fillna" else None + msg = "DataFrameGroupBy.fillna is deprecated" + with tm.assert_produces_warning(warn, match=msg): + result = method(*args, **kwargs) + assert "b" in result.columns + elif has_arg: + assert numeric_only is not True + # kernels that are successful on any dtype were above; this will fail + + # object dtypes for transformations are not implemented in Cython and + # have no Python fallback + exception = NotImplementedError if kernel.startswith("cum") else TypeError + + msg = "|".join( + [ + "not allowed for this dtype", + "cannot be performed against 'object' dtypes", + # On PY39 message is "a number"; on PY310 and after is "a real number" + "must be a string or a.* number", + "unsupported operand type", + "function is not implemented for this dtype", + re.escape(f"agg function failed [how->{kernel},dtype->object]"), + ] + ) + if kernel == "idxmin": + msg = "'<' not supported between instances of 'type' and 'type'" + elif kernel == "idxmax": + msg = "'>' not supported between instances of 'type' and 'type'" + with pytest.raises(exception, match=msg): + method(*args, **kwargs) + elif not has_arg and numeric_only is not lib.no_default: + with pytest.raises( + TypeError, match="got an unexpected keyword argument 'numeric_only'" + ): + method(*args, **kwargs) + else: + assert kernel in ("diff", "pct_change") + assert numeric_only is lib.no_default + # Doesn't have numeric_only argument and fails on nuisance columns + with pytest.raises(TypeError, match=r"unsupported operand type"): + method(*args, **kwargs) + + +@pytest.mark.filterwarnings("ignore:Downcasting object dtype arrays:FutureWarning") +@pytest.mark.parametrize("dtype", [bool, int, float, object]) +def test_deprecate_numeric_only_series(dtype, groupby_func, request): + # GH#46560 + grouper = [0, 0, 1] + + ser = Series([1, 0, 0], dtype=dtype) + gb = ser.groupby(grouper) + + if groupby_func == "corrwith": + # corrwith is not implemented on SeriesGroupBy + assert not hasattr(gb, groupby_func) + return + + method = getattr(gb, groupby_func) + + expected_ser = Series([1, 0, 0]) + expected_gb = expected_ser.groupby(grouper) + expected_method = getattr(expected_gb, groupby_func) + + args = get_groupby_method_args(groupby_func, ser) + + fails_on_numeric_object = ( + "corr", + "cov", + "cummax", + "cummin", + "cumprod", + "cumsum", + "quantile", + ) + # ops that give an object result on object input + obj_result = ( + "first", + "last", + "nth", + "bfill", + "ffill", + "shift", + "sum", + "diff", + "pct_change", + "var", + "mean", + "median", + "min", + "max", + "prod", + "skew", + ) + + # Test default behavior; kernels that fail may be enabled in the future but kernels + # that succeed should not be allowed to fail (without deprecation, at least) + if groupby_func in fails_on_numeric_object and dtype is object: + if groupby_func == "quantile": + msg = "cannot be performed against 'object' dtypes" + else: + msg = "is not supported for object dtype" + warn = FutureWarning if groupby_func == "fillna" else None + warn_msg = "DataFrameGroupBy.fillna is deprecated" + with tm.assert_produces_warning(warn, match=warn_msg): + with pytest.raises(TypeError, match=msg): + method(*args) + elif dtype is object: + warn = FutureWarning if groupby_func == "fillna" else None + warn_msg = "SeriesGroupBy.fillna is deprecated" + with tm.assert_produces_warning(warn, match=warn_msg): + result = method(*args) + with tm.assert_produces_warning(warn, match=warn_msg): + expected = expected_method(*args) + if groupby_func in obj_result: + expected = expected.astype(object) + tm.assert_series_equal(result, expected) + + has_numeric_only = ( + "first", + "last", + "max", + "mean", + "median", + "min", + "prod", + "quantile", + "sem", + "skew", + "std", + "sum", + "var", + "cummax", + "cummin", + "cumprod", + "cumsum", + ) + if groupby_func not in has_numeric_only: + msg = "got an unexpected keyword argument 'numeric_only'" + with pytest.raises(TypeError, match=msg): + method(*args, numeric_only=True) + elif dtype is object: + msg = "|".join( + [ + "SeriesGroupBy.sem called with numeric_only=True and dtype object", + "Series.skew does not allow numeric_only=True with non-numeric", + "cum(sum|prod|min|max) is not supported for object dtype", + r"Cannot use numeric_only=True with SeriesGroupBy\..* and non-numeric", + ] + ) + with pytest.raises(TypeError, match=msg): + method(*args, numeric_only=True) + elif dtype == bool and groupby_func == "quantile": + msg = "Allowing bool dtype in SeriesGroupBy.quantile" + with tm.assert_produces_warning(FutureWarning, match=msg): + # GH#51424 + result = method(*args, numeric_only=True) + expected = method(*args, numeric_only=False) + tm.assert_series_equal(result, expected) + else: + result = method(*args, numeric_only=True) + expected = method(*args, numeric_only=False) + tm.assert_series_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_pipe.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_pipe.py new file mode 100644 index 0000000000000000000000000000000000000000..7d5c1625b8ab466677280de30562eb13c53376d7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_pipe.py @@ -0,0 +1,80 @@ +import numpy as np + +import pandas as pd +from pandas import ( + DataFrame, + Index, +) +import pandas._testing as tm + + +def test_pipe(): + # Test the pipe method of DataFrameGroupBy. + # Issue #17871 + + random_state = np.random.default_rng(2) + + df = DataFrame( + { + "A": ["foo", "bar", "foo", "bar", "foo", "bar", "foo", "foo"], + "B": random_state.standard_normal(8), + "C": random_state.standard_normal(8), + } + ) + + def f(dfgb): + return dfgb.B.max() - dfgb.C.min().min() + + def square(srs): + return srs**2 + + # Note that the transformations are + # GroupBy -> Series + # Series -> Series + # This then chains the GroupBy.pipe and the + # NDFrame.pipe methods + result = df.groupby("A").pipe(f).pipe(square) + + index = Index(["bar", "foo"], dtype="object", name="A") + expected = pd.Series([3.749306591013693, 6.717707873081384], name="B", index=index) + + tm.assert_series_equal(expected, result) + + +def test_pipe_args(): + # Test passing args to the pipe method of DataFrameGroupBy. + # Issue #17871 + + df = DataFrame( + { + "group": ["A", "A", "B", "B", "C"], + "x": [1.0, 2.0, 3.0, 2.0, 5.0], + "y": [10.0, 100.0, 1000.0, -100.0, -1000.0], + } + ) + + def f(dfgb, arg1): + filtered = dfgb.filter(lambda grp: grp.y.mean() > arg1, dropna=False) + return filtered.groupby("group") + + def g(dfgb, arg2): + return dfgb.sum() / dfgb.sum().sum() + arg2 + + def h(df, arg3): + return df.x + df.y - arg3 + + result = df.groupby("group").pipe(f, 0).pipe(g, 10).pipe(h, 100) + + # Assert the results here + index = Index(["A", "B"], name="group") + expected = pd.Series([-79.5160891089, -78.4839108911], index=index) + + tm.assert_series_equal(result, expected) + + # test SeriesGroupby.pipe + ser = pd.Series([1, 1, 2, 2, 3, 3]) + result = ser.groupby(ser).pipe(lambda grp: grp.sum() * grp.count()) + + expected = pd.Series([4, 8, 12], index=Index([1, 2, 3], dtype=np.int64)) + + tm.assert_series_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_raises.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_raises.py new file mode 100644 index 0000000000000000000000000000000000000000..0b451ce73db898f02bce2d1432ca64c0011f3e71 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_raises.py @@ -0,0 +1,716 @@ +# Only tests that raise an error and have no better location should go here. +# Tests for specific groupby methods should go in their respective +# test file. + +import datetime +import re + +import numpy as np +import pytest + +from pandas import ( + Categorical, + DataFrame, + Grouper, + Series, +) +import pandas._testing as tm +from pandas.tests.groupby import get_groupby_method_args + + +@pytest.fixture( + params=[ + "a", + ["a"], + ["a", "b"], + Grouper(key="a"), + lambda x: x % 2, + [0, 0, 0, 1, 2, 2, 2, 3, 3], + np.array([0, 0, 0, 1, 2, 2, 2, 3, 3]), + dict(zip(range(9), [0, 0, 0, 1, 2, 2, 2, 3, 3])), + Series([1, 1, 1, 1, 1, 2, 2, 2, 2]), + [Series([1, 1, 1, 1, 1, 2, 2, 2, 2]), Series([3, 3, 4, 4, 4, 4, 4, 3, 3])], + ] +) +def by(request): + return request.param + + +@pytest.fixture(params=[True, False]) +def groupby_series(request): + return request.param + + +@pytest.fixture +def df_with_string_col(): + df = DataFrame( + { + "a": [1, 1, 1, 1, 1, 2, 2, 2, 2], + "b": [3, 3, 4, 4, 4, 4, 4, 3, 3], + "c": range(9), + "d": list("xyzwtyuio"), + } + ) + return df + + +@pytest.fixture +def df_with_datetime_col(): + df = DataFrame( + { + "a": [1, 1, 1, 1, 1, 2, 2, 2, 2], + "b": [3, 3, 4, 4, 4, 4, 4, 3, 3], + "c": range(9), + "d": datetime.datetime(2005, 1, 1, 10, 30, 23, 540000), + } + ) + return df + + +@pytest.fixture +def df_with_timedelta_col(): + df = DataFrame( + { + "a": [1, 1, 1, 1, 1, 2, 2, 2, 2], + "b": [3, 3, 4, 4, 4, 4, 4, 3, 3], + "c": range(9), + "d": datetime.timedelta(days=1), + } + ) + return df + + +@pytest.fixture +def df_with_cat_col(): + df = DataFrame( + { + "a": [1, 1, 1, 1, 1, 2, 2, 2, 2], + "b": [3, 3, 4, 4, 4, 4, 4, 3, 3], + "c": range(9), + "d": Categorical( + ["a", "a", "a", "a", "b", "b", "b", "b", "c"], + categories=["a", "b", "c", "d"], + ordered=True, + ), + } + ) + return df + + +def _call_and_check(klass, msg, how, gb, groupby_func, args, warn_msg=""): + warn_klass = None if warn_msg == "" else FutureWarning + with tm.assert_produces_warning(warn_klass, match=warn_msg): + if klass is None: + if how == "method": + getattr(gb, groupby_func)(*args) + elif how == "agg": + gb.agg(groupby_func, *args) + else: + gb.transform(groupby_func, *args) + else: + with pytest.raises(klass, match=msg): + if how == "method": + getattr(gb, groupby_func)(*args) + elif how == "agg": + gb.agg(groupby_func, *args) + else: + gb.transform(groupby_func, *args) + + +@pytest.mark.parametrize("how", ["method", "agg", "transform"]) +def test_groupby_raises_string( + how, by, groupby_series, groupby_func, df_with_string_col +): + df = df_with_string_col + args = get_groupby_method_args(groupby_func, df) + gb = df.groupby(by=by) + + if groupby_series: + gb = gb["d"] + + if groupby_func == "corrwith": + assert not hasattr(gb, "corrwith") + return + + klass, msg = { + "all": (None, ""), + "any": (None, ""), + "bfill": (None, ""), + "corrwith": (TypeError, "Could not convert"), + "count": (None, ""), + "cumcount": (None, ""), + "cummax": ( + (NotImplementedError, TypeError), + "(function|cummax) is not (implemented|supported) for (this|object) dtype", + ), + "cummin": ( + (NotImplementedError, TypeError), + "(function|cummin) is not (implemented|supported) for (this|object) dtype", + ), + "cumprod": ( + (NotImplementedError, TypeError), + "(function|cumprod) is not (implemented|supported) for (this|object) dtype", + ), + "cumsum": ( + (NotImplementedError, TypeError), + "(function|cumsum) is not (implemented|supported) for (this|object) dtype", + ), + "diff": (TypeError, "unsupported operand type"), + "ffill": (None, ""), + "fillna": (None, ""), + "first": (None, ""), + "idxmax": (None, ""), + "idxmin": (None, ""), + "last": (None, ""), + "max": (None, ""), + "mean": ( + TypeError, + re.escape("agg function failed [how->mean,dtype->object]"), + ), + "median": ( + TypeError, + re.escape("agg function failed [how->median,dtype->object]"), + ), + "min": (None, ""), + "ngroup": (None, ""), + "nunique": (None, ""), + "pct_change": (TypeError, "unsupported operand type"), + "prod": ( + TypeError, + re.escape("agg function failed [how->prod,dtype->object]"), + ), + "quantile": (TypeError, "cannot be performed against 'object' dtypes!"), + "rank": (None, ""), + "sem": (ValueError, "could not convert string to float"), + "shift": (None, ""), + "size": (None, ""), + "skew": (ValueError, "could not convert string to float"), + "std": (ValueError, "could not convert string to float"), + "sum": (None, ""), + "var": ( + TypeError, + re.escape("agg function failed [how->var,dtype->"), + ), + }[groupby_func] + + if groupby_func == "fillna": + kind = "Series" if groupby_series else "DataFrame" + warn_msg = f"{kind}GroupBy.fillna is deprecated" + else: + warn_msg = "" + _call_and_check(klass, msg, how, gb, groupby_func, args, warn_msg) + + +@pytest.mark.parametrize("how", ["agg", "transform"]) +def test_groupby_raises_string_udf(how, by, groupby_series, df_with_string_col): + df = df_with_string_col + gb = df.groupby(by=by) + + if groupby_series: + gb = gb["d"] + + def func(x): + raise TypeError("Test error message") + + with pytest.raises(TypeError, match="Test error message"): + getattr(gb, how)(func) + + +@pytest.mark.parametrize("how", ["agg", "transform"]) +@pytest.mark.parametrize("groupby_func_np", [np.sum, np.mean]) +def test_groupby_raises_string_np( + how, by, groupby_series, groupby_func_np, df_with_string_col +): + # GH#50749 + df = df_with_string_col + gb = df.groupby(by=by) + + if groupby_series: + gb = gb["d"] + + klass, msg = { + np.sum: (None, ""), + np.mean: ( + TypeError, + re.escape("agg function failed [how->mean,dtype->object]"), + ), + }[groupby_func_np] + + if groupby_series: + warn_msg = "using SeriesGroupBy.[sum|mean]" + else: + warn_msg = "using DataFrameGroupBy.[sum|mean]" + _call_and_check(klass, msg, how, gb, groupby_func_np, (), warn_msg=warn_msg) + + +@pytest.mark.parametrize("how", ["method", "agg", "transform"]) +def test_groupby_raises_datetime( + how, by, groupby_series, groupby_func, df_with_datetime_col +): + df = df_with_datetime_col + args = get_groupby_method_args(groupby_func, df) + gb = df.groupby(by=by) + + if groupby_series: + gb = gb["d"] + + if groupby_func == "corrwith": + assert not hasattr(gb, "corrwith") + return + + klass, msg = { + "all": (None, ""), + "any": (None, ""), + "bfill": (None, ""), + "corrwith": (TypeError, "cannot perform __mul__ with this index type"), + "count": (None, ""), + "cumcount": (None, ""), + "cummax": (None, ""), + "cummin": (None, ""), + "cumprod": (TypeError, "datetime64 type does not support cumprod operations"), + "cumsum": (TypeError, "datetime64 type does not support cumsum operations"), + "diff": (None, ""), + "ffill": (None, ""), + "fillna": (None, ""), + "first": (None, ""), + "idxmax": (None, ""), + "idxmin": (None, ""), + "last": (None, ""), + "max": (None, ""), + "mean": (None, ""), + "median": (None, ""), + "min": (None, ""), + "ngroup": (None, ""), + "nunique": (None, ""), + "pct_change": (TypeError, "cannot perform __truediv__ with this index type"), + "prod": (TypeError, "datetime64 type does not support prod"), + "quantile": (None, ""), + "rank": (None, ""), + "sem": (None, ""), + "shift": (None, ""), + "size": (None, ""), + "skew": ( + TypeError, + "|".join( + [ + r"dtype datetime64\[ns\] does not support reduction", + "datetime64 type does not support skew operations", + ] + ), + ), + "std": (None, ""), + "sum": (TypeError, "datetime64 type does not support sum operations"), + "var": (TypeError, "datetime64 type does not support var operations"), + }[groupby_func] + + if groupby_func in ["any", "all"]: + warn_msg = f"'{groupby_func}' with datetime64 dtypes is deprecated" + elif groupby_func == "fillna": + kind = "Series" if groupby_series else "DataFrame" + warn_msg = f"{kind}GroupBy.fillna is deprecated" + else: + warn_msg = "" + _call_and_check(klass, msg, how, gb, groupby_func, args, warn_msg=warn_msg) + + +@pytest.mark.parametrize("how", ["agg", "transform"]) +def test_groupby_raises_datetime_udf(how, by, groupby_series, df_with_datetime_col): + df = df_with_datetime_col + gb = df.groupby(by=by) + + if groupby_series: + gb = gb["d"] + + def func(x): + raise TypeError("Test error message") + + with pytest.raises(TypeError, match="Test error message"): + getattr(gb, how)(func) + + +@pytest.mark.parametrize("how", ["agg", "transform"]) +@pytest.mark.parametrize("groupby_func_np", [np.sum, np.mean]) +def test_groupby_raises_datetime_np( + how, by, groupby_series, groupby_func_np, df_with_datetime_col +): + # GH#50749 + df = df_with_datetime_col + gb = df.groupby(by=by) + + if groupby_series: + gb = gb["d"] + + klass, msg = { + np.sum: (TypeError, "datetime64 type does not support sum operations"), + np.mean: (None, ""), + }[groupby_func_np] + + if groupby_series: + warn_msg = "using SeriesGroupBy.[sum|mean]" + else: + warn_msg = "using DataFrameGroupBy.[sum|mean]" + _call_and_check(klass, msg, how, gb, groupby_func_np, (), warn_msg=warn_msg) + + +@pytest.mark.parametrize("func", ["prod", "cumprod", "skew", "var"]) +def test_groupby_raises_timedelta(func, df_with_timedelta_col): + df = df_with_timedelta_col + gb = df.groupby(by="a") + + _call_and_check( + TypeError, + "timedelta64 type does not support .* operations", + "method", + gb, + func, + [], + ) + + +@pytest.mark.parametrize("how", ["method", "agg", "transform"]) +def test_groupby_raises_category( + how, by, groupby_series, groupby_func, using_copy_on_write, df_with_cat_col +): + # GH#50749 + df = df_with_cat_col + args = get_groupby_method_args(groupby_func, df) + gb = df.groupby(by=by) + + if groupby_series: + gb = gb["d"] + + if groupby_func == "corrwith": + assert not hasattr(gb, "corrwith") + return + + klass, msg = { + "all": (None, ""), + "any": (None, ""), + "bfill": (None, ""), + "corrwith": ( + TypeError, + r"unsupported operand type\(s\) for \*: 'Categorical' and 'int'", + ), + "count": (None, ""), + "cumcount": (None, ""), + "cummax": ( + (NotImplementedError, TypeError), + "(category type does not support cummax operations|" + "category dtype not supported|" + "cummax is not supported for category dtype)", + ), + "cummin": ( + (NotImplementedError, TypeError), + "(category type does not support cummin operations|" + "category dtype not supported|" + "cummin is not supported for category dtype)", + ), + "cumprod": ( + (NotImplementedError, TypeError), + "(category type does not support cumprod operations|" + "category dtype not supported|" + "cumprod is not supported for category dtype)", + ), + "cumsum": ( + (NotImplementedError, TypeError), + "(category type does not support cumsum operations|" + "category dtype not supported|" + "cumsum is not supported for category dtype)", + ), + "diff": ( + TypeError, + r"unsupported operand type\(s\) for -: 'Categorical' and 'Categorical'", + ), + "ffill": (None, ""), + "fillna": ( + TypeError, + r"Cannot setitem on a Categorical with a new category \(0\), " + "set the categories first", + ) + if not using_copy_on_write + else (None, ""), # no-op with CoW + "first": (None, ""), + "idxmax": (None, ""), + "idxmin": (None, ""), + "last": (None, ""), + "max": (None, ""), + "mean": ( + TypeError, + "|".join( + [ + "'Categorical' .* does not support reduction 'mean'", + "category dtype does not support aggregation 'mean'", + ] + ), + ), + "median": ( + TypeError, + "|".join( + [ + "'Categorical' .* does not support reduction 'median'", + "category dtype does not support aggregation 'median'", + ] + ), + ), + "min": (None, ""), + "ngroup": (None, ""), + "nunique": (None, ""), + "pct_change": ( + TypeError, + r"unsupported operand type\(s\) for /: 'Categorical' and 'Categorical'", + ), + "prod": (TypeError, "category type does not support prod operations"), + "quantile": (TypeError, "No matching signature found"), + "rank": (None, ""), + "sem": ( + TypeError, + "|".join( + [ + "'Categorical' .* does not support reduction 'sem'", + "category dtype does not support aggregation 'sem'", + ] + ), + ), + "shift": (None, ""), + "size": (None, ""), + "skew": ( + TypeError, + "|".join( + [ + "dtype category does not support reduction 'skew'", + "category type does not support skew operations", + ] + ), + ), + "std": ( + TypeError, + "|".join( + [ + "'Categorical' .* does not support reduction 'std'", + "category dtype does not support aggregation 'std'", + ] + ), + ), + "sum": (TypeError, "category type does not support sum operations"), + "var": ( + TypeError, + "|".join( + [ + "'Categorical' .* does not support reduction 'var'", + "category dtype does not support aggregation 'var'", + ] + ), + ), + }[groupby_func] + + if groupby_func == "fillna": + kind = "Series" if groupby_series else "DataFrame" + warn_msg = f"{kind}GroupBy.fillna is deprecated" + else: + warn_msg = "" + _call_and_check(klass, msg, how, gb, groupby_func, args, warn_msg) + + +@pytest.mark.parametrize("how", ["agg", "transform"]) +def test_groupby_raises_category_udf(how, by, groupby_series, df_with_cat_col): + # GH#50749 + df = df_with_cat_col + gb = df.groupby(by=by) + + if groupby_series: + gb = gb["d"] + + def func(x): + raise TypeError("Test error message") + + with pytest.raises(TypeError, match="Test error message"): + getattr(gb, how)(func) + + +@pytest.mark.parametrize("how", ["agg", "transform"]) +@pytest.mark.parametrize("groupby_func_np", [np.sum, np.mean]) +def test_groupby_raises_category_np( + how, by, groupby_series, groupby_func_np, df_with_cat_col +): + # GH#50749 + df = df_with_cat_col + gb = df.groupby(by=by) + + if groupby_series: + gb = gb["d"] + + klass, msg = { + np.sum: (TypeError, "category type does not support sum operations"), + np.mean: ( + TypeError, + "category dtype does not support aggregation 'mean'", + ), + }[groupby_func_np] + + if groupby_series: + warn_msg = "using SeriesGroupBy.[sum|mean]" + else: + warn_msg = "using DataFrameGroupBy.[sum|mean]" + _call_and_check(klass, msg, how, gb, groupby_func_np, (), warn_msg=warn_msg) + + +@pytest.mark.parametrize("how", ["method", "agg", "transform"]) +def test_groupby_raises_category_on_category( + how, + by, + groupby_series, + groupby_func, + observed, + using_copy_on_write, + df_with_cat_col, +): + # GH#50749 + df = df_with_cat_col + df["a"] = Categorical( + ["a", "a", "a", "a", "b", "b", "b", "b", "c"], + categories=["a", "b", "c", "d"], + ordered=True, + ) + args = get_groupby_method_args(groupby_func, df) + gb = df.groupby(by=by, observed=observed) + + if groupby_series: + gb = gb["d"] + + if groupby_func == "corrwith": + assert not hasattr(gb, "corrwith") + return + + empty_groups = not observed and any(group.empty for group in gb.groups.values()) + if ( + not observed + and how != "transform" + and isinstance(by, list) + and isinstance(by[0], str) + and by == ["a", "b"] + ): + assert not empty_groups + # TODO: empty_groups should be true due to unobserved categorical combinations + empty_groups = True + if how == "transform": + # empty groups will be ignored + empty_groups = False + + klass, msg = { + "all": (None, ""), + "any": (None, ""), + "bfill": (None, ""), + "corrwith": ( + TypeError, + r"unsupported operand type\(s\) for \*: 'Categorical' and 'int'", + ), + "count": (None, ""), + "cumcount": (None, ""), + "cummax": ( + (NotImplementedError, TypeError), + "(cummax is not supported for category dtype|" + "category dtype not supported|" + "category type does not support cummax operations)", + ), + "cummin": ( + (NotImplementedError, TypeError), + "(cummin is not supported for category dtype|" + "category dtype not supported|" + "category type does not support cummin operations)", + ), + "cumprod": ( + (NotImplementedError, TypeError), + "(cumprod is not supported for category dtype|" + "category dtype not supported|" + "category type does not support cumprod operations)", + ), + "cumsum": ( + (NotImplementedError, TypeError), + "(cumsum is not supported for category dtype|" + "category dtype not supported|" + "category type does not support cumsum operations)", + ), + "diff": (TypeError, "unsupported operand type"), + "ffill": (None, ""), + "fillna": ( + TypeError, + r"Cannot setitem on a Categorical with a new category \(0\), " + "set the categories first", + ) + if not using_copy_on_write + else (None, ""), # no-op with CoW + "first": (None, ""), + "idxmax": (ValueError, "empty group due to unobserved categories") + if empty_groups + else (None, ""), + "idxmin": (ValueError, "empty group due to unobserved categories") + if empty_groups + else (None, ""), + "last": (None, ""), + "max": (None, ""), + "mean": (TypeError, "category dtype does not support aggregation 'mean'"), + "median": (TypeError, "category dtype does not support aggregation 'median'"), + "min": (None, ""), + "ngroup": (None, ""), + "nunique": (None, ""), + "pct_change": (TypeError, "unsupported operand type"), + "prod": (TypeError, "category type does not support prod operations"), + "quantile": (TypeError, ""), + "rank": (None, ""), + "sem": ( + TypeError, + "|".join( + [ + "'Categorical' .* does not support reduction 'sem'", + "category dtype does not support aggregation 'sem'", + ] + ), + ), + "shift": (None, ""), + "size": (None, ""), + "skew": ( + TypeError, + "|".join( + [ + "category type does not support skew operations", + "dtype category does not support reduction 'skew'", + ] + ), + ), + "std": ( + TypeError, + "|".join( + [ + "'Categorical' .* does not support reduction 'std'", + "category dtype does not support aggregation 'std'", + ] + ), + ), + "sum": (TypeError, "category type does not support sum operations"), + "var": ( + TypeError, + "|".join( + [ + "'Categorical' .* does not support reduction 'var'", + "category dtype does not support aggregation 'var'", + ] + ), + ), + }[groupby_func] + + if groupby_func == "fillna": + kind = "Series" if groupby_series else "DataFrame" + warn_msg = f"{kind}GroupBy.fillna is deprecated" + else: + warn_msg = "" + _call_and_check(klass, msg, how, gb, groupby_func, args, warn_msg) + + +def test_subsetting_columns_axis_1_raises(): + # GH 35443 + df = DataFrame({"a": [1], "b": [2], "c": [3]}) + msg = "DataFrame.groupby with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + gb = df.groupby("a", axis=1) + with pytest.raises(ValueError, match="Cannot subset columns when using axis=1"): + gb["b"] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_reductions.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_reductions.py new file mode 100644 index 0000000000000000000000000000000000000000..25b0f80639cff61bafe9ee13e5acef950f470e64 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_reductions.py @@ -0,0 +1,1176 @@ +import builtins +import datetime as dt +from string import ascii_lowercase + +import numpy as np +import pytest + +from pandas._libs.tslibs import iNaT + +from pandas.core.dtypes.common import pandas_dtype +from pandas.core.dtypes.missing import na_value_for_dtype + +import pandas as pd +from pandas import ( + DataFrame, + MultiIndex, + Series, + Timestamp, + date_range, + isna, +) +import pandas._testing as tm +from pandas.util import _test_decorators as td + + +@pytest.mark.parametrize("agg_func", ["any", "all"]) +@pytest.mark.parametrize( + "vals", + [ + ["foo", "bar", "baz"], + ["foo", "", ""], + ["", "", ""], + [1, 2, 3], + [1, 0, 0], + [0, 0, 0], + [1.0, 2.0, 3.0], + [1.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [True, True, True], + [True, False, False], + [False, False, False], + [np.nan, np.nan, np.nan], + ], +) +def test_groupby_bool_aggs(skipna, agg_func, vals): + df = DataFrame({"key": ["a"] * 3 + ["b"] * 3, "val": vals * 2}) + + # Figure out expectation using Python builtin + exp = getattr(builtins, agg_func)(vals) + + # edge case for missing data with skipna and 'any' + if skipna and all(isna(vals)) and agg_func == "any": + exp = False + + expected = DataFrame( + [exp] * 2, columns=["val"], index=pd.Index(["a", "b"], name="key") + ) + result = getattr(df.groupby("key"), agg_func)(skipna=skipna) + tm.assert_frame_equal(result, expected) + + +def test_any(): + df = DataFrame( + [[1, 2, "foo"], [1, np.nan, "bar"], [3, np.nan, "baz"]], + columns=["A", "B", "C"], + ) + expected = DataFrame( + [[True, True], [False, True]], columns=["B", "C"], index=[1, 3] + ) + expected.index.name = "A" + result = df.groupby("A").any() + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("bool_agg_func", ["any", "all"]) +def test_bool_aggs_dup_column_labels(bool_agg_func): + # GH#21668 + df = DataFrame([[True, True]], columns=["a", "a"]) + grp_by = df.groupby([0]) + result = getattr(grp_by, bool_agg_func)() + + expected = df.set_axis(np.array([0])) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("bool_agg_func", ["any", "all"]) +@pytest.mark.parametrize( + "data", + [ + [False, False, False], + [True, True, True], + [pd.NA, pd.NA, pd.NA], + [False, pd.NA, False], + [True, pd.NA, True], + [True, pd.NA, False], + ], +) +def test_masked_kleene_logic(bool_agg_func, skipna, data): + # GH#37506 + ser = Series(data, dtype="boolean") + + # The result should match aggregating on the whole series. Correctness + # there is verified in test_reductions.py::test_any_all_boolean_kleene_logic + expected_data = getattr(ser, bool_agg_func)(skipna=skipna) + expected = Series(expected_data, index=np.array([0]), dtype="boolean") + + result = ser.groupby([0, 0, 0]).agg(bool_agg_func, skipna=skipna) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "dtype1,dtype2,exp_col1,exp_col2", + [ + ( + "float", + "Float64", + np.array([True], dtype=bool), + pd.array([pd.NA], dtype="boolean"), + ), + ( + "Int64", + "float", + pd.array([pd.NA], dtype="boolean"), + np.array([True], dtype=bool), + ), + ( + "Int64", + "Int64", + pd.array([pd.NA], dtype="boolean"), + pd.array([pd.NA], dtype="boolean"), + ), + ( + "Float64", + "boolean", + pd.array([pd.NA], dtype="boolean"), + pd.array([pd.NA], dtype="boolean"), + ), + ], +) +def test_masked_mixed_types(dtype1, dtype2, exp_col1, exp_col2): + # GH#37506 + data = [1.0, np.nan] + df = DataFrame( + {"col1": pd.array(data, dtype=dtype1), "col2": pd.array(data, dtype=dtype2)} + ) + result = df.groupby([1, 1]).agg("all", skipna=False) + + expected = DataFrame({"col1": exp_col1, "col2": exp_col2}, index=np.array([1])) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("bool_agg_func", ["any", "all"]) +@pytest.mark.parametrize("dtype", ["Int64", "Float64", "boolean"]) +def test_masked_bool_aggs_skipna(bool_agg_func, dtype, skipna, frame_or_series): + # GH#40585 + obj = frame_or_series([pd.NA, 1], dtype=dtype) + expected_res = True + if not skipna and bool_agg_func == "all": + expected_res = pd.NA + expected = frame_or_series([expected_res], index=np.array([1]), dtype="boolean") + + result = obj.groupby([1, 1]).agg(bool_agg_func, skipna=skipna) + tm.assert_equal(result, expected) + + +@pytest.mark.parametrize( + "bool_agg_func,data,expected_res", + [ + ("any", [pd.NA, np.nan], False), + ("any", [pd.NA, 1, np.nan], True), + ("all", [pd.NA, pd.NaT], True), + ("all", [pd.NA, False, pd.NaT], False), + ], +) +def test_object_type_missing_vals(bool_agg_func, data, expected_res, frame_or_series): + # GH#37501 + obj = frame_or_series(data, dtype=object) + result = obj.groupby([1] * len(data)).agg(bool_agg_func) + expected = frame_or_series([expected_res], index=np.array([1]), dtype="bool") + tm.assert_equal(result, expected) + + +@pytest.mark.parametrize("bool_agg_func", ["any", "all"]) +def test_object_NA_raises_with_skipna_false(bool_agg_func): + # GH#37501 + ser = Series([pd.NA], dtype=object) + with pytest.raises(TypeError, match="boolean value of NA is ambiguous"): + ser.groupby([1]).agg(bool_agg_func, skipna=False) + + +@pytest.mark.parametrize("bool_agg_func", ["any", "all"]) +def test_empty(frame_or_series, bool_agg_func): + # GH 45231 + kwargs = {"columns": ["a"]} if frame_or_series is DataFrame else {"name": "a"} + obj = frame_or_series(**kwargs, dtype=object) + result = getattr(obj.groupby(obj.index), bool_agg_func)() + expected = frame_or_series(**kwargs, dtype=bool) + tm.assert_equal(result, expected) + + +@pytest.mark.parametrize("how", ["idxmin", "idxmax"]) +def test_idxmin_idxmax_extremes(how, any_real_numpy_dtype): + # GH#57040 + if any_real_numpy_dtype is int or any_real_numpy_dtype is float: + # No need to test + return + info = np.iinfo if "int" in any_real_numpy_dtype else np.finfo + min_value = info(any_real_numpy_dtype).min + max_value = info(any_real_numpy_dtype).max + df = DataFrame( + {"a": [2, 1, 1, 2], "b": [min_value, max_value, max_value, min_value]}, + dtype=any_real_numpy_dtype, + ) + gb = df.groupby("a") + result = getattr(gb, how)() + expected = DataFrame( + {"b": [1, 0]}, index=pd.Index([1, 2], name="a", dtype=any_real_numpy_dtype) + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("how", ["idxmin", "idxmax"]) +def test_idxmin_idxmax_extremes_skipna(skipna, how, float_numpy_dtype): + # GH#57040 + min_value = np.finfo(float_numpy_dtype).min + max_value = np.finfo(float_numpy_dtype).max + df = DataFrame( + { + "a": Series(np.repeat(range(1, 6), repeats=2), dtype="intp"), + "b": Series( + [ + np.nan, + min_value, + np.nan, + max_value, + min_value, + np.nan, + max_value, + np.nan, + np.nan, + np.nan, + ], + dtype=float_numpy_dtype, + ), + }, + ) + gb = df.groupby("a") + + warn = None if skipna else FutureWarning + msg = f"The behavior of DataFrameGroupBy.{how} with all-NA values" + with tm.assert_produces_warning(warn, match=msg): + result = getattr(gb, how)(skipna=skipna) + if skipna: + values = [1, 3, 4, 6, np.nan] + else: + values = np.nan + expected = DataFrame( + {"b": values}, index=pd.Index(range(1, 6), name="a", dtype="intp") + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "func, values", + [ + ("idxmin", {"c_int": [0, 2], "c_float": [1, 3], "c_date": [1, 2]}), + ("idxmax", {"c_int": [1, 3], "c_float": [0, 2], "c_date": [0, 3]}), + ], +) +@pytest.mark.parametrize("numeric_only", [True, False]) +def test_idxmin_idxmax_returns_int_types(func, values, numeric_only): + # GH 25444 + df = DataFrame( + { + "name": ["A", "A", "B", "B"], + "c_int": [1, 2, 3, 4], + "c_float": [4.02, 3.03, 2.04, 1.05], + "c_date": ["2019", "2018", "2016", "2017"], + } + ) + df["c_date"] = pd.to_datetime(df["c_date"]) + df["c_date_tz"] = df["c_date"].dt.tz_localize("US/Pacific") + df["c_timedelta"] = df["c_date"] - df["c_date"].iloc[0] + df["c_period"] = df["c_date"].dt.to_period("W") + df["c_Integer"] = df["c_int"].astype("Int64") + df["c_Floating"] = df["c_float"].astype("Float64") + + result = getattr(df.groupby("name"), func)(numeric_only=numeric_only) + + expected = DataFrame(values, index=pd.Index(["A", "B"], name="name")) + if numeric_only: + expected = expected.drop(columns=["c_date"]) + else: + expected["c_date_tz"] = expected["c_date"] + expected["c_timedelta"] = expected["c_date"] + expected["c_period"] = expected["c_date"] + expected["c_Integer"] = expected["c_int"] + expected["c_Floating"] = expected["c_float"] + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "data", + [ + ( + Timestamp("2011-01-15 12:50:28.502376"), + Timestamp("2011-01-20 12:50:28.593448"), + ), + (24650000000000001, 24650000000000002), + ], +) +@pytest.mark.parametrize("method", ["count", "min", "max", "first", "last"]) +def test_groupby_non_arithmetic_agg_int_like_precision(method, data): + # GH#6620, GH#9311 + df = DataFrame({"a": [1, 1], "b": data}) + + grouped = df.groupby("a") + result = getattr(grouped, method)() + if method == "count": + expected_value = 2 + elif method == "first": + expected_value = data[0] + elif method == "last": + expected_value = data[1] + else: + expected_value = getattr(df["b"], method)() + expected = DataFrame({"b": [expected_value]}, index=pd.Index([1], name="a")) + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("how", ["first", "last"]) +def test_first_last_skipna(any_real_nullable_dtype, sort, skipna, how): + # GH#57019 + na_value = na_value_for_dtype(pandas_dtype(any_real_nullable_dtype)) + df = DataFrame( + { + "a": [2, 1, 1, 2, 3, 3], + "b": [na_value, 3.0, na_value, 4.0, np.nan, np.nan], + "c": [na_value, 3.0, na_value, 4.0, np.nan, np.nan], + }, + dtype=any_real_nullable_dtype, + ) + gb = df.groupby("a", sort=sort) + method = getattr(gb, how) + result = method(skipna=skipna) + + ilocs = { + ("first", True): [3, 1, 4], + ("first", False): [0, 1, 4], + ("last", True): [3, 1, 5], + ("last", False): [3, 2, 5], + }[how, skipna] + expected = df.iloc[ilocs].set_index("a") + if sort: + expected = expected.sort_index() + tm.assert_frame_equal(result, expected) + + +def test_idxmin_idxmax_axis1(): + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 4)), columns=["A", "B", "C", "D"] + ) + df["A"] = [1, 2, 3, 1, 2, 3, 1, 2, 3, 4] + + gb = df.groupby("A") + + warn_msg = "DataFrameGroupBy.idxmax with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=warn_msg): + res = gb.idxmax(axis=1) + + alt = df.iloc[:, 1:].idxmax(axis=1) + indexer = res.index.get_level_values(1) + + tm.assert_series_equal(alt[indexer], res.droplevel("A")) + + df["E"] = date_range("2016-01-01", periods=10) + gb2 = df.groupby("A") + + msg = "'>' not supported between instances of 'Timestamp' and 'float'" + with pytest.raises(TypeError, match=msg): + with tm.assert_produces_warning(FutureWarning, match=warn_msg): + gb2.idxmax(axis=1) + + +def test_groupby_mean_no_overflow(): + # Regression test for (#22487) + df = DataFrame( + { + "user": ["A", "A", "A", "A", "A"], + "connections": [4970, 4749, 4719, 4704, 18446744073699999744], + } + ) + assert df.groupby("user")["connections"].mean()["A"] == 3689348814740003840 + + +def test_mean_on_timedelta(): + # GH 17382 + df = DataFrame({"time": pd.to_timedelta(range(10)), "cat": ["A", "B"] * 5}) + result = df.groupby("cat")["time"].mean() + expected = Series( + pd.to_timedelta([4, 5]), name="time", index=pd.Index(["A", "B"], name="cat") + ) + tm.assert_series_equal(result, expected) + + +def test_cython_median(): + arr = np.random.default_rng(2).standard_normal(1000) + arr[::2] = np.nan + df = DataFrame(arr) + + labels = np.random.default_rng(2).integers(0, 50, size=1000).astype(float) + labels[::17] = np.nan + + result = df.groupby(labels).median() + msg = "using DataFrameGroupBy.median" + with tm.assert_produces_warning(FutureWarning, match=msg): + exp = df.groupby(labels).agg(np.nanmedian) + tm.assert_frame_equal(result, exp) + + df = DataFrame(np.random.default_rng(2).standard_normal((1000, 5))) + msg = "using DataFrameGroupBy.median" + with tm.assert_produces_warning(FutureWarning, match=msg): + rs = df.groupby(labels).agg(np.median) + xp = df.groupby(labels).median() + tm.assert_frame_equal(rs, xp) + + +def test_median_empty_bins(observed): + df = DataFrame(np.random.default_rng(2).integers(0, 44, 500)) + + grps = range(0, 55, 5) + bins = pd.cut(df[0], grps) + + result = df.groupby(bins, observed=observed).median() + expected = df.groupby(bins, observed=observed).agg(lambda x: x.median()) + tm.assert_frame_equal(result, expected) + + +def test_max_min_non_numeric(): + # #2700 + aa = DataFrame({"nn": [11, 11, 22, 22], "ii": [1, 2, 3, 4], "ss": 4 * ["mama"]}) + + result = aa.groupby("nn").max() + assert "ss" in result + + result = aa.groupby("nn").max(numeric_only=False) + assert "ss" in result + + result = aa.groupby("nn").min() + assert "ss" in result + + result = aa.groupby("nn").min(numeric_only=False) + assert "ss" in result + + +def test_max_min_object_multiple_columns(using_array_manager): + # GH#41111 case where the aggregation is valid for some columns but not + # others; we split object blocks column-wise, consistent with + # DataFrame._reduce + + df = DataFrame( + { + "A": [1, 1, 2, 2, 3], + "B": [1, "foo", 2, "bar", False], + "C": ["a", "b", "c", "d", "e"], + } + ) + df._consolidate_inplace() # should already be consolidate, but double-check + if not using_array_manager: + assert len(df._mgr.blocks) == 2 + + gb = df.groupby("A") + + result = gb[["C"]].max() + # "max" is valid for column "C" but not for "B" + ei = pd.Index([1, 2, 3], name="A") + expected = DataFrame({"C": ["b", "d", "e"]}, index=ei) + tm.assert_frame_equal(result, expected) + + result = gb[["C"]].min() + # "min" is valid for column "C" but not for "B" + ei = pd.Index([1, 2, 3], name="A") + expected = DataFrame({"C": ["a", "c", "e"]}, index=ei) + tm.assert_frame_equal(result, expected) + + +def test_min_date_with_nans(): + # GH26321 + dates = pd.to_datetime( + Series(["2019-05-09", "2019-05-09", "2019-05-09"]), format="%Y-%m-%d" + ).dt.date + df = DataFrame({"a": [np.nan, "1", np.nan], "b": [0, 1, 1], "c": dates}) + + result = df.groupby("b", as_index=False)["c"].min()["c"] + expected = pd.to_datetime( + Series(["2019-05-09", "2019-05-09"], name="c"), format="%Y-%m-%d" + ).dt.date + tm.assert_series_equal(result, expected) + + result = df.groupby("b")["c"].min() + expected.index.name = "b" + tm.assert_series_equal(result, expected) + + +def test_max_inat(): + # GH#40767 dont interpret iNaT as NaN + ser = Series([1, iNaT]) + key = np.array([1, 1], dtype=np.int64) + gb = ser.groupby(key) + + result = gb.max(min_count=2) + expected = Series({1: 1}, dtype=np.int64) + tm.assert_series_equal(result, expected, check_exact=True) + + result = gb.min(min_count=2) + expected = Series({1: iNaT}, dtype=np.int64) + tm.assert_series_equal(result, expected, check_exact=True) + + # not enough entries -> gets masked to NaN + result = gb.min(min_count=3) + expected = Series({1: np.nan}) + tm.assert_series_equal(result, expected, check_exact=True) + + +def test_max_inat_not_all_na(): + # GH#40767 dont interpret iNaT as NaN + + # make sure we dont round iNaT+1 to iNaT + ser = Series([1, iNaT, 2, iNaT + 1]) + gb = ser.groupby([1, 2, 3, 3]) + result = gb.min(min_count=2) + + # Note: in converting to float64, the iNaT + 1 maps to iNaT, i.e. is lossy + expected = Series({1: np.nan, 2: np.nan, 3: iNaT + 1}) + expected.index = expected.index.astype(int) + tm.assert_series_equal(result, expected, check_exact=True) + + +@pytest.mark.parametrize("func", ["min", "max"]) +def test_groupby_aggregate_period_column(func): + # GH 31471 + groups = [1, 2] + periods = pd.period_range("2020", periods=2, freq="Y") + df = DataFrame({"a": groups, "b": periods}) + + result = getattr(df.groupby("a")["b"], func)() + idx = pd.Index([1, 2], name="a") + expected = Series(periods, index=idx, name="b") + + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("func", ["min", "max"]) +def test_groupby_aggregate_period_frame(func): + # GH 31471 + groups = [1, 2] + periods = pd.period_range("2020", periods=2, freq="Y") + df = DataFrame({"a": groups, "b": periods}) + + result = getattr(df.groupby("a"), func)() + idx = pd.Index([1, 2], name="a") + expected = DataFrame({"b": periods}, index=idx) + + tm.assert_frame_equal(result, expected) + + +def test_aggregate_numeric_object_dtype(): + # https://github.com/pandas-dev/pandas/issues/39329 + # simplified case: multiple object columns where one is all-NaN + # -> gets split as the all-NaN is inferred as float + df = DataFrame( + {"key": ["A", "A", "B", "B"], "col1": list("abcd"), "col2": [np.nan] * 4}, + ).astype(object) + result = df.groupby("key").min() + expected = ( + DataFrame( + {"key": ["A", "B"], "col1": ["a", "c"], "col2": [np.nan, np.nan]}, + ) + .set_index("key") + .astype(object) + ) + tm.assert_frame_equal(result, expected) + + # same but with numbers + df = DataFrame( + {"key": ["A", "A", "B", "B"], "col1": list("abcd"), "col2": range(4)}, + ).astype(object) + result = df.groupby("key").min() + expected = ( + DataFrame({"key": ["A", "B"], "col1": ["a", "c"], "col2": [0, 2]}) + .set_index("key") + .astype(object) + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("func", ["min", "max"]) +def test_aggregate_categorical_lost_index(func: str): + # GH: 28641 groupby drops index, when grouping over categorical column with min/max + ds = Series(["b"], dtype="category").cat.as_ordered() + df = DataFrame({"A": [1997], "B": ds}) + result = df.groupby("A").agg({"B": func}) + expected = DataFrame({"B": ["b"]}, index=pd.Index([1997], name="A")) + + # ordered categorical dtype should be preserved + expected["B"] = expected["B"].astype(ds.dtype) + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("dtype", ["Int64", "Int32", "Float64", "Float32", "boolean"]) +def test_groupby_min_max_nullable(dtype): + if dtype == "Int64": + # GH#41743 avoid precision loss + ts = 1618556707013635762 + elif dtype == "boolean": + ts = 0 + else: + ts = 4.0 + + df = DataFrame({"id": [2, 2], "ts": [ts, ts + 1]}) + df["ts"] = df["ts"].astype(dtype) + + gb = df.groupby("id") + + result = gb.min() + expected = df.iloc[:1].set_index("id") + tm.assert_frame_equal(result, expected) + + res_max = gb.max() + expected_max = df.iloc[1:].set_index("id") + tm.assert_frame_equal(res_max, expected_max) + + result2 = gb.min(min_count=3) + expected2 = DataFrame({"ts": [pd.NA]}, index=expected.index, dtype=dtype) + tm.assert_frame_equal(result2, expected2) + + res_max2 = gb.max(min_count=3) + tm.assert_frame_equal(res_max2, expected2) + + # Case with NA values + df2 = DataFrame({"id": [2, 2, 2], "ts": [ts, pd.NA, ts + 1]}) + df2["ts"] = df2["ts"].astype(dtype) + gb2 = df2.groupby("id") + + result3 = gb2.min() + tm.assert_frame_equal(result3, expected) + + res_max3 = gb2.max() + tm.assert_frame_equal(res_max3, expected_max) + + result4 = gb2.min(min_count=100) + tm.assert_frame_equal(result4, expected2) + + res_max4 = gb2.max(min_count=100) + tm.assert_frame_equal(res_max4, expected2) + + +def test_min_max_nullable_uint64_empty_group(): + # don't raise NotImplementedError from libgroupby + cat = pd.Categorical([0] * 10, categories=[0, 1]) + df = DataFrame({"A": cat, "B": pd.array(np.arange(10, dtype=np.uint64))}) + gb = df.groupby("A", observed=False) + + res = gb.min() + + idx = pd.CategoricalIndex([0, 1], dtype=cat.dtype, name="A") + expected = DataFrame({"B": pd.array([0, pd.NA], dtype="UInt64")}, index=idx) + tm.assert_frame_equal(res, expected) + + res = gb.max() + expected.iloc[0, 0] = 9 + tm.assert_frame_equal(res, expected) + + +@pytest.mark.parametrize("func", ["first", "last", "min", "max"]) +def test_groupby_min_max_categorical(func): + # GH: 52151 + df = DataFrame( + { + "col1": pd.Categorical(["A"], categories=list("AB"), ordered=True), + "col2": pd.Categorical([1], categories=[1, 2], ordered=True), + "value": 0.1, + } + ) + result = getattr(df.groupby("col1", observed=False), func)() + + idx = pd.CategoricalIndex(data=["A", "B"], name="col1", ordered=True) + expected = DataFrame( + { + "col2": pd.Categorical([1, None], categories=[1, 2], ordered=True), + "value": [0.1, None], + }, + index=idx, + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("func", ["min", "max"]) +def test_min_empty_string_dtype(func): + # GH#55619 + pytest.importorskip("pyarrow") + dtype = "string[pyarrow_numpy]" + df = DataFrame({"a": ["a"], "b": "a", "c": "a"}, dtype=dtype).iloc[:0] + result = getattr(df.groupby("a"), func)() + expected = DataFrame( + columns=["b", "c"], dtype=dtype, index=pd.Index([], dtype=dtype, name="a") + ) + tm.assert_frame_equal(result, expected) + + +def test_max_nan_bug(): + df = DataFrame( + { + "Unnamed: 0": ["-04-23", "-05-06", "-05-07"], + "Date": [ + "2013-04-23 00:00:00", + "2013-05-06 00:00:00", + "2013-05-07 00:00:00", + ], + "app": Series([np.nan, np.nan, "OE"]), + "File": ["log080001.log", "log.log", "xlsx"], + } + ) + gb = df.groupby("Date") + r = gb[["File"]].max() + e = gb["File"].max().to_frame() + tm.assert_frame_equal(r, e) + assert not r["File"].isna().any() + + +@pytest.mark.slow +@pytest.mark.parametrize("sort", [False, True]) +@pytest.mark.parametrize("dropna", [False, True]) +@pytest.mark.parametrize("as_index", [True, False]) +@pytest.mark.parametrize("with_nan", [True, False]) +@pytest.mark.parametrize("keys", [["joe"], ["joe", "jim"]]) +def test_series_groupby_nunique(sort, dropna, as_index, with_nan, keys): + n = 100 + m = 10 + days = date_range("2015-08-23", periods=10) + df = DataFrame( + { + "jim": np.random.default_rng(2).choice(list(ascii_lowercase), n), + "joe": np.random.default_rng(2).choice(days, n), + "julie": np.random.default_rng(2).integers(0, m, n), + } + ) + if with_nan: + df = df.astype({"julie": float}) # Explicit cast to avoid implicit cast below + df.loc[1::17, "jim"] = None + df.loc[3::37, "joe"] = None + df.loc[7::19, "julie"] = None + df.loc[8::19, "julie"] = None + df.loc[9::19, "julie"] = None + original_df = df.copy() + gr = df.groupby(keys, as_index=as_index, sort=sort) + left = gr["julie"].nunique(dropna=dropna) + + gr = df.groupby(keys, as_index=as_index, sort=sort) + right = gr["julie"].apply(Series.nunique, dropna=dropna) + if not as_index: + right = right.reset_index(drop=True) + + if as_index: + tm.assert_series_equal(left, right, check_names=False) + else: + tm.assert_frame_equal(left, right, check_names=False) + tm.assert_frame_equal(df, original_df) + + +def test_nunique(): + df = DataFrame({"A": list("abbacc"), "B": list("abxacc"), "C": list("abbacx")}) + + expected = DataFrame({"A": list("abc"), "B": [1, 2, 1], "C": [1, 1, 2]}) + result = df.groupby("A", as_index=False).nunique() + tm.assert_frame_equal(result, expected) + + # as_index + expected.index = list("abc") + expected.index.name = "A" + expected = expected.drop(columns="A") + result = df.groupby("A").nunique() + tm.assert_frame_equal(result, expected) + + # with na + result = df.replace({"x": None}).groupby("A").nunique(dropna=False) + tm.assert_frame_equal(result, expected) + + # dropna + expected = DataFrame({"B": [1] * 3, "C": [1] * 3}, index=list("abc")) + expected.index.name = "A" + result = df.replace({"x": None}).groupby("A").nunique() + tm.assert_frame_equal(result, expected) + + +def test_nunique_with_object(): + # GH 11077 + data = DataFrame( + [ + [100, 1, "Alice"], + [200, 2, "Bob"], + [300, 3, "Charlie"], + [-400, 4, "Dan"], + [500, 5, "Edith"], + ], + columns=["amount", "id", "name"], + ) + + result = data.groupby(["id", "amount"])["name"].nunique() + index = MultiIndex.from_arrays([data.id, data.amount]) + expected = Series([1] * 5, name="name", index=index) + tm.assert_series_equal(result, expected) + + +def test_nunique_with_empty_series(): + # GH 12553 + data = Series(name="name", dtype=object) + result = data.groupby(level=0).nunique() + expected = Series(name="name", dtype="int64") + tm.assert_series_equal(result, expected) + + +def test_nunique_with_timegrouper(): + # GH 13453 + test = DataFrame( + { + "time": [ + Timestamp("2016-06-28 09:35:35"), + Timestamp("2016-06-28 16:09:30"), + Timestamp("2016-06-28 16:46:28"), + ], + "data": ["1", "2", "3"], + } + ).set_index("time") + result = test.groupby(pd.Grouper(freq="h"))["data"].nunique() + expected = test.groupby(pd.Grouper(freq="h"))["data"].apply(Series.nunique) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "key, data, dropna, expected", + [ + ( + ["x", "x", "x"], + [Timestamp("2019-01-01"), pd.NaT, Timestamp("2019-01-01")], + True, + Series([1], index=pd.Index(["x"], name="key"), name="data"), + ), + ( + ["x", "x", "x"], + [dt.date(2019, 1, 1), pd.NaT, dt.date(2019, 1, 1)], + True, + Series([1], index=pd.Index(["x"], name="key"), name="data"), + ), + ( + ["x", "x", "x", "y", "y"], + [ + dt.date(2019, 1, 1), + pd.NaT, + dt.date(2019, 1, 1), + pd.NaT, + dt.date(2019, 1, 1), + ], + False, + Series([2, 2], index=pd.Index(["x", "y"], name="key"), name="data"), + ), + ( + ["x", "x", "x", "x", "y"], + [ + dt.date(2019, 1, 1), + pd.NaT, + dt.date(2019, 1, 1), + pd.NaT, + dt.date(2019, 1, 1), + ], + False, + Series([2, 1], index=pd.Index(["x", "y"], name="key"), name="data"), + ), + ], +) +def test_nunique_with_NaT(key, data, dropna, expected): + # GH 27951 + df = DataFrame({"key": key, "data": data}) + result = df.groupby(["key"])["data"].nunique(dropna=dropna) + tm.assert_series_equal(result, expected) + + +def test_nunique_preserves_column_level_names(): + # GH 23222 + test = DataFrame([1, 2, 2], columns=pd.Index(["A"], name="level_0")) + result = test.groupby([0, 0, 0]).nunique() + expected = DataFrame([2], index=np.array([0]), columns=test.columns) + tm.assert_frame_equal(result, expected) + + +def test_nunique_transform_with_datetime(): + # GH 35109 - transform with nunique on datetimes results in integers + df = DataFrame(date_range("2008-12-31", "2009-01-02"), columns=["date"]) + result = df.groupby([0, 0, 1])["date"].transform("nunique") + expected = Series([2, 2, 1], name="date") + tm.assert_series_equal(result, expected) + + +def test_empty_categorical(observed): + # GH#21334 + cat = Series([1]).astype("category") + ser = cat[:0] + gb = ser.groupby(ser, observed=observed) + result = gb.nunique() + if observed: + expected = Series([], index=cat[:0], dtype="int64") + else: + expected = Series([0], index=cat, dtype="int64") + tm.assert_series_equal(result, expected) + + +def test_intercept_builtin_sum(): + s = Series([1.0, 2.0, np.nan, 3.0]) + grouped = s.groupby([0, 1, 2, 2]) + + msg = "using SeriesGroupBy.sum" + with tm.assert_produces_warning(FutureWarning, match=msg): + # GH#53425 + result = grouped.agg(builtins.sum) + msg = "using np.sum" + with tm.assert_produces_warning(FutureWarning, match=msg): + # GH#53425 + result2 = grouped.apply(builtins.sum) + expected = grouped.sum() + tm.assert_series_equal(result, expected) + tm.assert_series_equal(result2, expected) + + +@pytest.mark.parametrize("min_count", [0, 10]) +def test_groupby_sum_mincount_boolean(min_count): + b = True + a = False + na = np.nan + dfg = pd.array([b, b, na, na, a, a, b], dtype="boolean") + + df = DataFrame({"A": [1, 1, 2, 2, 3, 3, 1], "B": dfg}) + result = df.groupby("A").sum(min_count=min_count) + if min_count == 0: + expected = DataFrame( + {"B": pd.array([3, 0, 0], dtype="Int64")}, + index=pd.Index([1, 2, 3], name="A"), + ) + tm.assert_frame_equal(result, expected) + else: + expected = DataFrame( + {"B": pd.array([pd.NA] * 3, dtype="Int64")}, + index=pd.Index([1, 2, 3], name="A"), + ) + tm.assert_frame_equal(result, expected) + + +def test_groupby_sum_below_mincount_nullable_integer(): + # https://github.com/pandas-dev/pandas/issues/32861 + df = DataFrame({"a": [0, 1, 2], "b": [0, 1, 2], "c": [0, 1, 2]}, dtype="Int64") + grouped = df.groupby("a") + idx = pd.Index([0, 1, 2], name="a", dtype="Int64") + + result = grouped["b"].sum(min_count=2) + expected = Series([pd.NA] * 3, dtype="Int64", index=idx, name="b") + tm.assert_series_equal(result, expected) + + result = grouped.sum(min_count=2) + expected = DataFrame({"b": [pd.NA] * 3, "c": [pd.NA] * 3}, dtype="Int64", index=idx) + tm.assert_frame_equal(result, expected) + + +def test_groupby_sum_timedelta_with_nat(): + # GH#42659 + df = DataFrame( + { + "a": [1, 1, 2, 2], + "b": [pd.Timedelta("1d"), pd.Timedelta("2d"), pd.Timedelta("3d"), pd.NaT], + } + ) + td3 = pd.Timedelta(days=3) + + gb = df.groupby("a") + + res = gb.sum() + expected = DataFrame({"b": [td3, td3]}, index=pd.Index([1, 2], name="a")) + tm.assert_frame_equal(res, expected) + + res = gb["b"].sum() + tm.assert_series_equal(res, expected["b"]) + + res = gb["b"].sum(min_count=2) + expected = Series([td3, pd.NaT], dtype="m8[ns]", name="b", index=expected.index) + tm.assert_series_equal(res, expected) + + +@pytest.mark.parametrize( + "dtype", ["int8", "int16", "int32", "int64", "float32", "float64", "uint64"] +) +@pytest.mark.parametrize( + "method,data", + [ + ("first", {"df": [{"a": 1, "b": 1}, {"a": 2, "b": 3}]}), + ("last", {"df": [{"a": 1, "b": 2}, {"a": 2, "b": 4}]}), + ("min", {"df": [{"a": 1, "b": 1}, {"a": 2, "b": 3}]}), + ("max", {"df": [{"a": 1, "b": 2}, {"a": 2, "b": 4}]}), + ("count", {"df": [{"a": 1, "b": 2}, {"a": 2, "b": 2}], "out_type": "int64"}), + ], +) +def test_groupby_non_arithmetic_agg_types(dtype, method, data): + # GH9311, GH6620 + df = DataFrame( + [{"a": 1, "b": 1}, {"a": 1, "b": 2}, {"a": 2, "b": 3}, {"a": 2, "b": 4}] + ) + + df["b"] = df.b.astype(dtype) + + if "args" not in data: + data["args"] = [] + + if "out_type" in data: + out_type = data["out_type"] + else: + out_type = dtype + + exp = data["df"] + df_out = DataFrame(exp) + + df_out["b"] = df_out.b.astype(out_type) + df_out.set_index("a", inplace=True) + + grpd = df.groupby("a") + t = getattr(grpd, method)(*data["args"]) + tm.assert_frame_equal(t, df_out) + + +def scipy_sem(*args, **kwargs): + from scipy.stats import sem + + return sem(*args, ddof=1, **kwargs) + + +@pytest.mark.parametrize( + "op,targop", + [ + ("mean", np.mean), + ("median", np.median), + ("std", np.std), + ("var", np.var), + ("sum", np.sum), + ("prod", np.prod), + ("min", np.min), + ("max", np.max), + ("first", lambda x: x.iloc[0]), + ("last", lambda x: x.iloc[-1]), + ("count", np.size), + pytest.param("sem", scipy_sem, marks=td.skip_if_no("scipy")), + ], +) +def test_ops_general(op, targop): + df = DataFrame(np.random.default_rng(2).standard_normal(1000)) + labels = np.random.default_rng(2).integers(0, 50, size=1000).astype(float) + + result = getattr(df.groupby(labels), op)() + warn = None if op in ("first", "last", "count", "sem") else FutureWarning + msg = f"using DataFrameGroupBy.{op}" + with tm.assert_produces_warning(warn, match=msg): + expected = df.groupby(labels).agg(targop) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "values", + [ + { + "a": [1, 1, 1, 2, 2, 2, 3, 3, 3], + "b": [1, pd.NA, 2, 1, pd.NA, 2, 1, pd.NA, 2], + }, + {"a": [1, 1, 2, 2, 3, 3], "b": [1, 2, 1, 2, 1, 2]}, + ], +) +@pytest.mark.parametrize("function", ["mean", "median", "var"]) +def test_apply_to_nullable_integer_returns_float(values, function): + # https://github.com/pandas-dev/pandas/issues/32219 + output = 0.5 if function == "var" else 1.5 + arr = np.array([output] * 3, dtype=float) + idx = pd.Index([1, 2, 3], name="a", dtype="Int64") + expected = DataFrame({"b": arr}, index=idx).astype("Float64") + + groups = DataFrame(values, dtype="Int64").groupby("a") + + result = getattr(groups, function)() + tm.assert_frame_equal(result, expected) + + result = groups.agg(function) + tm.assert_frame_equal(result, expected) + + result = groups.agg([function]) + expected.columns = MultiIndex.from_tuples([("b", function)]) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "op", + [ + "sum", + "prod", + "min", + "max", + "median", + "mean", + "skew", + "std", + "var", + "sem", + ], +) +@pytest.mark.parametrize("axis", [0, 1]) +@pytest.mark.parametrize("skipna", [True, False]) +@pytest.mark.parametrize("sort", [True, False]) +def test_regression_allowlist_methods(op, axis, skipna, sort): + # GH6944 + # GH 17537 + # explicitly test the allowlist methods + raw_frame = DataFrame([0]) + if axis == 0: + frame = raw_frame + msg = "The 'axis' keyword in DataFrame.groupby is deprecated and will be" + else: + frame = raw_frame.T + msg = "DataFrame.groupby with axis=1 is deprecated" + + with tm.assert_produces_warning(FutureWarning, match=msg): + grouped = frame.groupby(level=0, axis=axis, sort=sort) + + if op == "skew": + # skew has skipna + result = getattr(grouped, op)(skipna=skipna) + expected = frame.groupby(level=0).apply( + lambda h: getattr(h, op)(axis=axis, skipna=skipna) + ) + if sort: + expected = expected.sort_index(axis=axis) + tm.assert_frame_equal(result, expected) + else: + result = getattr(grouped, op)() + expected = frame.groupby(level=0).apply(lambda h: getattr(h, op)(axis=axis)) + if sort: + expected = expected.sort_index(axis=axis) + tm.assert_frame_equal(result, expected) + + +def test_groupby_prod_with_int64_dtype(): + # GH#46573 + data = [ + [1, 11], + [1, 41], + [1, 17], + [1, 37], + [1, 7], + [1, 29], + [1, 31], + [1, 2], + [1, 3], + [1, 43], + [1, 5], + [1, 47], + [1, 19], + [1, 88], + ] + df = DataFrame(data, columns=["A", "B"], dtype="int64") + result = df.groupby(["A"]).prod().reset_index() + expected = DataFrame({"A": [1], "B": [180970905912331920]}, dtype="int64") + tm.assert_frame_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_timegrouper.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_timegrouper.py new file mode 100644 index 0000000000000000000000000000000000000000..8ef7c2b8ce859d399abf4972aa040e347a7e91e1 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/test_timegrouper.py @@ -0,0 +1,963 @@ +""" +test with the TimeGrouper / grouping with datetimes +""" +from datetime import ( + datetime, + timedelta, +) + +import numpy as np +import pytest +import pytz + +import pandas as pd +from pandas import ( + DataFrame, + DatetimeIndex, + Index, + MultiIndex, + Series, + Timestamp, + date_range, + offsets, +) +import pandas._testing as tm +from pandas.core.groupby.grouper import Grouper +from pandas.core.groupby.ops import BinGrouper + + +@pytest.fixture +def frame_for_truncated_bingrouper(): + """ + DataFrame used by groupby_with_truncated_bingrouper, made into + a separate fixture for easier reuse in + test_groupby_apply_timegrouper_with_nat_apply_squeeze + """ + df = DataFrame( + { + "Quantity": [18, 3, 5, 1, 9, 3], + "Date": [ + Timestamp(2013, 9, 1, 13, 0), + Timestamp(2013, 9, 1, 13, 5), + Timestamp(2013, 10, 1, 20, 0), + Timestamp(2013, 10, 3, 10, 0), + pd.NaT, + Timestamp(2013, 9, 2, 14, 0), + ], + } + ) + return df + + +@pytest.fixture +def groupby_with_truncated_bingrouper(frame_for_truncated_bingrouper): + """ + GroupBy object such that gb._grouper is a BinGrouper and + len(gb._grouper.result_index) < len(gb._grouper.group_keys_seq) + + Aggregations on this groupby should have + + dti = date_range("2013-09-01", "2013-10-01", freq="5D", name="Date") + + As either the index or an index level. + """ + df = frame_for_truncated_bingrouper + + tdg = Grouper(key="Date", freq="5D") + gb = df.groupby(tdg) + + # check we're testing the case we're interested in + assert len(gb._grouper.result_index) != len(gb._grouper.group_keys_seq) + + return gb + + +class TestGroupBy: + def test_groupby_with_timegrouper(self): + # GH 4161 + # TimeGrouper requires a sorted index + # also verifies that the resultant index has the correct name + df_original = DataFrame( + { + "Buyer": "Carl Carl Carl Carl Joe Carl".split(), + "Quantity": [18, 3, 5, 1, 9, 3], + "Date": [ + datetime(2013, 9, 1, 13, 0), + datetime(2013, 9, 1, 13, 5), + datetime(2013, 10, 1, 20, 0), + datetime(2013, 10, 3, 10, 0), + datetime(2013, 12, 2, 12, 0), + datetime(2013, 9, 2, 14, 0), + ], + } + ) + + # GH 6908 change target column's order + df_reordered = df_original.sort_values(by="Quantity") + + for df in [df_original, df_reordered]: + df = df.set_index(["Date"]) + + exp_dti = date_range( + "20130901", + "20131205", + freq="5D", + name="Date", + inclusive="left", + unit=df.index.unit, + ) + expected = DataFrame( + {"Buyer": 0, "Quantity": 0}, + index=exp_dti, + ) + # Cast to object to avoid implicit cast when setting entry to "CarlCarlCarl" + expected = expected.astype({"Buyer": object}) + expected.iloc[0, 0] = "CarlCarlCarl" + expected.iloc[6, 0] = "CarlCarl" + expected.iloc[18, 0] = "Joe" + expected.iloc[[0, 6, 18], 1] = np.array([24, 6, 9], dtype="int64") + + result1 = df.resample("5D").sum() + tm.assert_frame_equal(result1, expected) + + df_sorted = df.sort_index() + result2 = df_sorted.groupby(Grouper(freq="5D")).sum() + tm.assert_frame_equal(result2, expected) + + result3 = df.groupby(Grouper(freq="5D")).sum() + tm.assert_frame_equal(result3, expected) + + @pytest.mark.parametrize("should_sort", [True, False]) + def test_groupby_with_timegrouper_methods(self, should_sort): + # GH 3881 + # make sure API of timegrouper conforms + + df = DataFrame( + { + "Branch": "A A A A A B".split(), + "Buyer": "Carl Mark Carl Joe Joe Carl".split(), + "Quantity": [1, 3, 5, 8, 9, 3], + "Date": [ + datetime(2013, 1, 1, 13, 0), + datetime(2013, 1, 1, 13, 5), + datetime(2013, 10, 1, 20, 0), + datetime(2013, 10, 2, 10, 0), + datetime(2013, 12, 2, 12, 0), + datetime(2013, 12, 2, 14, 0), + ], + } + ) + + if should_sort: + df = df.sort_values(by="Quantity", ascending=False) + + df = df.set_index("Date", drop=False) + g = df.groupby(Grouper(freq="6ME")) + assert g.group_keys + + assert isinstance(g._grouper, BinGrouper) + groups = g.groups + assert isinstance(groups, dict) + assert len(groups) == 3 + + def test_timegrouper_with_reg_groups(self): + # GH 3794 + # allow combination of timegrouper/reg groups + + df_original = DataFrame( + { + "Branch": "A A A A A A A B".split(), + "Buyer": "Carl Mark Carl Carl Joe Joe Joe Carl".split(), + "Quantity": [1, 3, 5, 1, 8, 1, 9, 3], + "Date": [ + datetime(2013, 1, 1, 13, 0), + datetime(2013, 1, 1, 13, 5), + datetime(2013, 10, 1, 20, 0), + datetime(2013, 10, 2, 10, 0), + datetime(2013, 10, 1, 20, 0), + datetime(2013, 10, 2, 10, 0), + datetime(2013, 12, 2, 12, 0), + datetime(2013, 12, 2, 14, 0), + ], + } + ).set_index("Date") + + df_sorted = df_original.sort_values(by="Quantity", ascending=False) + + for df in [df_original, df_sorted]: + expected = DataFrame( + { + "Buyer": "Carl Joe Mark".split(), + "Quantity": [10, 18, 3], + "Date": [ + datetime(2013, 12, 31, 0, 0), + datetime(2013, 12, 31, 0, 0), + datetime(2013, 12, 31, 0, 0), + ], + } + ).set_index(["Date", "Buyer"]) + + msg = "The default value of numeric_only" + result = df.groupby([Grouper(freq="YE"), "Buyer"]).sum(numeric_only=True) + tm.assert_frame_equal(result, expected) + + expected = DataFrame( + { + "Buyer": "Carl Mark Carl Joe".split(), + "Quantity": [1, 3, 9, 18], + "Date": [ + datetime(2013, 1, 1, 0, 0), + datetime(2013, 1, 1, 0, 0), + datetime(2013, 7, 1, 0, 0), + datetime(2013, 7, 1, 0, 0), + ], + } + ).set_index(["Date", "Buyer"]) + result = df.groupby([Grouper(freq="6MS"), "Buyer"]).sum(numeric_only=True) + tm.assert_frame_equal(result, expected) + + df_original = DataFrame( + { + "Branch": "A A A A A A A B".split(), + "Buyer": "Carl Mark Carl Carl Joe Joe Joe Carl".split(), + "Quantity": [1, 3, 5, 1, 8, 1, 9, 3], + "Date": [ + datetime(2013, 10, 1, 13, 0), + datetime(2013, 10, 1, 13, 5), + datetime(2013, 10, 1, 20, 0), + datetime(2013, 10, 2, 10, 0), + datetime(2013, 10, 1, 20, 0), + datetime(2013, 10, 2, 10, 0), + datetime(2013, 10, 2, 12, 0), + datetime(2013, 10, 2, 14, 0), + ], + } + ).set_index("Date") + + df_sorted = df_original.sort_values(by="Quantity", ascending=False) + for df in [df_original, df_sorted]: + expected = DataFrame( + { + "Buyer": "Carl Joe Mark Carl Joe".split(), + "Quantity": [6, 8, 3, 4, 10], + "Date": [ + datetime(2013, 10, 1, 0, 0), + datetime(2013, 10, 1, 0, 0), + datetime(2013, 10, 1, 0, 0), + datetime(2013, 10, 2, 0, 0), + datetime(2013, 10, 2, 0, 0), + ], + } + ).set_index(["Date", "Buyer"]) + + result = df.groupby([Grouper(freq="1D"), "Buyer"]).sum(numeric_only=True) + tm.assert_frame_equal(result, expected) + + result = df.groupby([Grouper(freq="1ME"), "Buyer"]).sum(numeric_only=True) + expected = DataFrame( + { + "Buyer": "Carl Joe Mark".split(), + "Quantity": [10, 18, 3], + "Date": [ + datetime(2013, 10, 31, 0, 0), + datetime(2013, 10, 31, 0, 0), + datetime(2013, 10, 31, 0, 0), + ], + } + ).set_index(["Date", "Buyer"]) + tm.assert_frame_equal(result, expected) + + # passing the name + df = df.reset_index() + result = df.groupby([Grouper(freq="1ME", key="Date"), "Buyer"]).sum( + numeric_only=True + ) + tm.assert_frame_equal(result, expected) + + with pytest.raises(KeyError, match="'The grouper name foo is not found'"): + df.groupby([Grouper(freq="1ME", key="foo"), "Buyer"]).sum() + + # passing the level + df = df.set_index("Date") + result = df.groupby([Grouper(freq="1ME", level="Date"), "Buyer"]).sum( + numeric_only=True + ) + tm.assert_frame_equal(result, expected) + result = df.groupby([Grouper(freq="1ME", level=0), "Buyer"]).sum( + numeric_only=True + ) + tm.assert_frame_equal(result, expected) + + with pytest.raises(ValueError, match="The level foo is not valid"): + df.groupby([Grouper(freq="1ME", level="foo"), "Buyer"]).sum() + + # multi names + df = df.copy() + df["Date"] = df.index + offsets.MonthEnd(2) + result = df.groupby([Grouper(freq="1ME", key="Date"), "Buyer"]).sum( + numeric_only=True + ) + expected = DataFrame( + { + "Buyer": "Carl Joe Mark".split(), + "Quantity": [10, 18, 3], + "Date": [ + datetime(2013, 11, 30, 0, 0), + datetime(2013, 11, 30, 0, 0), + datetime(2013, 11, 30, 0, 0), + ], + } + ).set_index(["Date", "Buyer"]) + tm.assert_frame_equal(result, expected) + + # error as we have both a level and a name! + msg = "The Grouper cannot specify both a key and a level!" + with pytest.raises(ValueError, match=msg): + df.groupby( + [Grouper(freq="1ME", key="Date", level="Date"), "Buyer"] + ).sum() + + # single groupers + expected = DataFrame( + [[31]], + columns=["Quantity"], + index=DatetimeIndex( + [datetime(2013, 10, 31, 0, 0)], freq=offsets.MonthEnd(), name="Date" + ), + ) + result = df.groupby(Grouper(freq="1ME")).sum(numeric_only=True) + tm.assert_frame_equal(result, expected) + + result = df.groupby([Grouper(freq="1ME")]).sum(numeric_only=True) + tm.assert_frame_equal(result, expected) + + expected.index = expected.index.shift(1) + assert expected.index.freq == offsets.MonthEnd() + result = df.groupby(Grouper(freq="1ME", key="Date")).sum(numeric_only=True) + tm.assert_frame_equal(result, expected) + + result = df.groupby([Grouper(freq="1ME", key="Date")]).sum( + numeric_only=True + ) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("freq", ["D", "ME", "YE", "QE-APR"]) + def test_timegrouper_with_reg_groups_freq(self, freq): + # GH 6764 multiple grouping with/without sort + df = DataFrame( + { + "date": pd.to_datetime( + [ + "20121002", + "20121007", + "20130130", + "20130202", + "20130305", + "20121002", + "20121207", + "20130130", + "20130202", + "20130305", + "20130202", + "20130305", + ] + ), + "user_id": [1, 1, 1, 1, 1, 3, 3, 3, 5, 5, 5, 5], + "whole_cost": [ + 1790, + 364, + 280, + 259, + 201, + 623, + 90, + 312, + 359, + 301, + 359, + 801, + ], + "cost1": [12, 15, 10, 24, 39, 1, 0, 90, 45, 34, 1, 12], + } + ).set_index("date") + + expected = ( + df.groupby("user_id")["whole_cost"] + .resample(freq) + .sum(min_count=1) # XXX + .dropna() + .reorder_levels(["date", "user_id"]) + .sort_index() + .astype("int64") + ) + expected.name = "whole_cost" + + result1 = ( + df.sort_index().groupby([Grouper(freq=freq), "user_id"])["whole_cost"].sum() + ) + tm.assert_series_equal(result1, expected) + + result2 = df.groupby([Grouper(freq=freq), "user_id"])["whole_cost"].sum() + tm.assert_series_equal(result2, expected) + + def test_timegrouper_get_group(self): + # GH 6914 + + df_original = DataFrame( + { + "Buyer": "Carl Joe Joe Carl Joe Carl".split(), + "Quantity": [18, 3, 5, 1, 9, 3], + "Date": [ + datetime(2013, 9, 1, 13, 0), + datetime(2013, 9, 1, 13, 5), + datetime(2013, 10, 1, 20, 0), + datetime(2013, 10, 3, 10, 0), + datetime(2013, 12, 2, 12, 0), + datetime(2013, 9, 2, 14, 0), + ], + } + ) + df_reordered = df_original.sort_values(by="Quantity") + + # single grouping + expected_list = [ + df_original.iloc[[0, 1, 5]], + df_original.iloc[[2, 3]], + df_original.iloc[[4]], + ] + dt_list = ["2013-09-30", "2013-10-31", "2013-12-31"] + + for df in [df_original, df_reordered]: + grouped = df.groupby(Grouper(freq="ME", key="Date")) + for t, expected in zip(dt_list, expected_list): + dt = Timestamp(t) + result = grouped.get_group(dt) + tm.assert_frame_equal(result, expected) + + # multiple grouping + expected_list = [ + df_original.iloc[[1]], + df_original.iloc[[3]], + df_original.iloc[[4]], + ] + g_list = [("Joe", "2013-09-30"), ("Carl", "2013-10-31"), ("Joe", "2013-12-31")] + + for df in [df_original, df_reordered]: + grouped = df.groupby(["Buyer", Grouper(freq="ME", key="Date")]) + for (b, t), expected in zip(g_list, expected_list): + dt = Timestamp(t) + result = grouped.get_group((b, dt)) + tm.assert_frame_equal(result, expected) + + # with index + df_original = df_original.set_index("Date") + df_reordered = df_original.sort_values(by="Quantity") + + expected_list = [ + df_original.iloc[[0, 1, 5]], + df_original.iloc[[2, 3]], + df_original.iloc[[4]], + ] + + for df in [df_original, df_reordered]: + grouped = df.groupby(Grouper(freq="ME")) + for t, expected in zip(dt_list, expected_list): + dt = Timestamp(t) + result = grouped.get_group(dt) + tm.assert_frame_equal(result, expected) + + def test_timegrouper_apply_return_type_series(self): + # Using `apply` with the `TimeGrouper` should give the + # same return type as an `apply` with a `Grouper`. + # Issue #11742 + df = DataFrame({"date": ["10/10/2000", "11/10/2000"], "value": [10, 13]}) + df_dt = df.copy() + df_dt["date"] = pd.to_datetime(df_dt["date"]) + + def sumfunc_series(x): + return Series([x["value"].sum()], ("sum",)) + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + expected = df.groupby(Grouper(key="date")).apply(sumfunc_series) + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = df_dt.groupby(Grouper(freq="ME", key="date")).apply(sumfunc_series) + tm.assert_frame_equal( + result.reset_index(drop=True), expected.reset_index(drop=True) + ) + + def test_timegrouper_apply_return_type_value(self): + # Using `apply` with the `TimeGrouper` should give the + # same return type as an `apply` with a `Grouper`. + # Issue #11742 + df = DataFrame({"date": ["10/10/2000", "11/10/2000"], "value": [10, 13]}) + df_dt = df.copy() + df_dt["date"] = pd.to_datetime(df_dt["date"]) + + def sumfunc_value(x): + return x.value.sum() + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + expected = df.groupby(Grouper(key="date")).apply(sumfunc_value) + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = df_dt.groupby(Grouper(freq="ME", key="date")).apply(sumfunc_value) + tm.assert_series_equal( + result.reset_index(drop=True), expected.reset_index(drop=True) + ) + + def test_groupby_groups_datetimeindex(self): + # GH#1430 + periods = 1000 + ind = date_range(start="2012/1/1", freq="5min", periods=periods) + df = DataFrame( + {"high": np.arange(periods), "low": np.arange(periods)}, index=ind + ) + grouped = df.groupby(lambda x: datetime(x.year, x.month, x.day)) + + # it works! + groups = grouped.groups + assert isinstance(next(iter(groups.keys())), datetime) + + def test_groupby_groups_datetimeindex2(self): + # GH#11442 + index = date_range("2015/01/01", periods=5, name="date") + df = DataFrame({"A": [5, 6, 7, 8, 9], "B": [1, 2, 3, 4, 5]}, index=index) + result = df.groupby(level="date").groups + dates = ["2015-01-05", "2015-01-04", "2015-01-03", "2015-01-02", "2015-01-01"] + expected = { + Timestamp(date): DatetimeIndex([date], name="date") for date in dates + } + tm.assert_dict_equal(result, expected) + + grouped = df.groupby(level="date") + for date in dates: + result = grouped.get_group(date) + data = [[df.loc[date, "A"], df.loc[date, "B"]]] + expected_index = DatetimeIndex( + [date], name="date", freq="D", dtype=index.dtype + ) + expected = DataFrame(data, columns=list("AB"), index=expected_index) + tm.assert_frame_equal(result, expected) + + def test_groupby_groups_datetimeindex_tz(self): + # GH 3950 + dates = [ + "2011-07-19 07:00:00", + "2011-07-19 08:00:00", + "2011-07-19 09:00:00", + "2011-07-19 07:00:00", + "2011-07-19 08:00:00", + "2011-07-19 09:00:00", + ] + df = DataFrame( + { + "label": ["a", "a", "a", "b", "b", "b"], + "datetime": dates, + "value1": np.arange(6, dtype="int64"), + "value2": [1, 2] * 3, + } + ) + df["datetime"] = df["datetime"].apply(lambda d: Timestamp(d, tz="US/Pacific")) + + exp_idx1 = DatetimeIndex( + [ + "2011-07-19 07:00:00", + "2011-07-19 07:00:00", + "2011-07-19 08:00:00", + "2011-07-19 08:00:00", + "2011-07-19 09:00:00", + "2011-07-19 09:00:00", + ], + tz="US/Pacific", + name="datetime", + ) + exp_idx2 = Index(["a", "b"] * 3, name="label") + exp_idx = MultiIndex.from_arrays([exp_idx1, exp_idx2]) + expected = DataFrame( + {"value1": [0, 3, 1, 4, 2, 5], "value2": [1, 2, 2, 1, 1, 2]}, + index=exp_idx, + columns=["value1", "value2"], + ) + + result = df.groupby(["datetime", "label"]).sum() + tm.assert_frame_equal(result, expected) + + # by level + didx = DatetimeIndex(dates, tz="Asia/Tokyo") + df = DataFrame( + {"value1": np.arange(6, dtype="int64"), "value2": [1, 2, 3, 1, 2, 3]}, + index=didx, + ) + + exp_idx = DatetimeIndex( + ["2011-07-19 07:00:00", "2011-07-19 08:00:00", "2011-07-19 09:00:00"], + tz="Asia/Tokyo", + ) + expected = DataFrame( + {"value1": [3, 5, 7], "value2": [2, 4, 6]}, + index=exp_idx, + columns=["value1", "value2"], + ) + + result = df.groupby(level=0).sum() + tm.assert_frame_equal(result, expected) + + def test_frame_datetime64_handling_groupby(self): + # it works! + df = DataFrame( + [(3, np.datetime64("2012-07-03")), (3, np.datetime64("2012-07-04"))], + columns=["a", "date"], + ) + result = df.groupby("a").first() + assert result["date"][3] == Timestamp("2012-07-03") + + def test_groupby_multi_timezone(self): + # combining multiple / different timezones yields UTC + df = DataFrame( + { + "value": range(5), + "date": [ + "2000-01-28 16:47:00", + "2000-01-29 16:48:00", + "2000-01-30 16:49:00", + "2000-01-31 16:50:00", + "2000-01-01 16:50:00", + ], + "tz": [ + "America/Chicago", + "America/Chicago", + "America/Los_Angeles", + "America/Chicago", + "America/New_York", + ], + } + ) + + result = df.groupby("tz", group_keys=False).date.apply( + lambda x: pd.to_datetime(x).dt.tz_localize(x.name) + ) + + expected = Series( + [ + Timestamp("2000-01-28 16:47:00-0600", tz="America/Chicago"), + Timestamp("2000-01-29 16:48:00-0600", tz="America/Chicago"), + Timestamp("2000-01-30 16:49:00-0800", tz="America/Los_Angeles"), + Timestamp("2000-01-31 16:50:00-0600", tz="America/Chicago"), + Timestamp("2000-01-01 16:50:00-0500", tz="America/New_York"), + ], + name="date", + dtype=object, + ) + tm.assert_series_equal(result, expected) + + tz = "America/Chicago" + res_values = df.groupby("tz").date.get_group(tz) + result = pd.to_datetime(res_values).dt.tz_localize(tz) + exp_values = Series( + ["2000-01-28 16:47:00", "2000-01-29 16:48:00", "2000-01-31 16:50:00"], + index=[0, 1, 3], + name="date", + ) + expected = pd.to_datetime(exp_values).dt.tz_localize(tz) + tm.assert_series_equal(result, expected) + + def test_groupby_groups_periods(self): + dates = [ + "2011-07-19 07:00:00", + "2011-07-19 08:00:00", + "2011-07-19 09:00:00", + "2011-07-19 07:00:00", + "2011-07-19 08:00:00", + "2011-07-19 09:00:00", + ] + df = DataFrame( + { + "label": ["a", "a", "a", "b", "b", "b"], + "period": [pd.Period(d, freq="h") for d in dates], + "value1": np.arange(6, dtype="int64"), + "value2": [1, 2] * 3, + } + ) + + exp_idx1 = pd.PeriodIndex( + [ + "2011-07-19 07:00:00", + "2011-07-19 07:00:00", + "2011-07-19 08:00:00", + "2011-07-19 08:00:00", + "2011-07-19 09:00:00", + "2011-07-19 09:00:00", + ], + freq="h", + name="period", + ) + exp_idx2 = Index(["a", "b"] * 3, name="label") + exp_idx = MultiIndex.from_arrays([exp_idx1, exp_idx2]) + expected = DataFrame( + {"value1": [0, 3, 1, 4, 2, 5], "value2": [1, 2, 2, 1, 1, 2]}, + index=exp_idx, + columns=["value1", "value2"], + ) + + result = df.groupby(["period", "label"]).sum() + tm.assert_frame_equal(result, expected) + + # by level + didx = pd.PeriodIndex(dates, freq="h") + df = DataFrame( + {"value1": np.arange(6, dtype="int64"), "value2": [1, 2, 3, 1, 2, 3]}, + index=didx, + ) + + exp_idx = pd.PeriodIndex( + ["2011-07-19 07:00:00", "2011-07-19 08:00:00", "2011-07-19 09:00:00"], + freq="h", + ) + expected = DataFrame( + {"value1": [3, 5, 7], "value2": [2, 4, 6]}, + index=exp_idx, + columns=["value1", "value2"], + ) + + result = df.groupby(level=0).sum() + tm.assert_frame_equal(result, expected) + + def test_groupby_first_datetime64(self): + df = DataFrame([(1, 1351036800000000000), (2, 1351036800000000000)]) + df[1] = df[1].astype("M8[ns]") + + assert issubclass(df[1].dtype.type, np.datetime64) + + result = df.groupby(level=0).first() + got_dt = result[1].dtype + assert issubclass(got_dt.type, np.datetime64) + + result = df[1].groupby(level=0).first() + got_dt = result.dtype + assert issubclass(got_dt.type, np.datetime64) + + def test_groupby_max_datetime64(self): + # GH 5869 + # datetimelike dtype conversion from int + df = DataFrame({"A": Timestamp("20130101"), "B": np.arange(5)}) + # TODO: can we retain second reso in .apply here? + expected = df.groupby("A")["A"].apply(lambda x: x.max()).astype("M8[s]") + result = df.groupby("A")["A"].max() + tm.assert_series_equal(result, expected) + + def test_groupby_datetime64_32_bit(self): + # GH 6410 / numpy 4328 + # 32-bit under 1.9-dev indexing issue + + df = DataFrame({"A": range(2), "B": [Timestamp("2000-01-1")] * 2}) + result = df.groupby("A")["B"].transform("min") + expected = Series([Timestamp("2000-01-1")] * 2, name="B") + tm.assert_series_equal(result, expected) + + def test_groupby_with_timezone_selection(self): + # GH 11616 + # Test that column selection returns output in correct timezone. + + df = DataFrame( + { + "factor": np.random.default_rng(2).integers(0, 3, size=60), + "time": date_range("01/01/2000 00:00", periods=60, freq="s", tz="UTC"), + } + ) + df1 = df.groupby("factor").max()["time"] + df2 = df.groupby("factor")["time"].max() + tm.assert_series_equal(df1, df2) + + def test_timezone_info(self): + # see gh-11682: Timezone info lost when broadcasting + # scalar datetime to DataFrame + + df = DataFrame({"a": [1], "b": [datetime.now(pytz.utc)]}) + assert df["b"][0].tzinfo == pytz.utc + df = DataFrame({"a": [1, 2, 3]}) + df["b"] = datetime.now(pytz.utc) + assert df["b"][0].tzinfo == pytz.utc + + def test_datetime_count(self): + df = DataFrame( + {"a": [1, 2, 3] * 2, "dates": date_range("now", periods=6, freq="min")} + ) + result = df.groupby("a").dates.count() + expected = Series([2, 2, 2], index=Index([1, 2, 3], name="a"), name="dates") + tm.assert_series_equal(result, expected) + + def test_first_last_max_min_on_time_data(self): + # GH 10295 + # Verify that NaT is not in the result of max, min, first and last on + # Dataframe with datetime or timedelta values. + df_test = DataFrame( + { + "dt": [ + np.nan, + "2015-07-24 10:10", + "2015-07-25 11:11", + "2015-07-23 12:12", + np.nan, + ], + "td": [ + np.nan, + timedelta(days=1), + timedelta(days=2), + timedelta(days=3), + np.nan, + ], + } + ) + df_test.dt = pd.to_datetime(df_test.dt) + df_test["group"] = "A" + df_ref = df_test[df_test.dt.notna()] + + grouped_test = df_test.groupby("group") + grouped_ref = df_ref.groupby("group") + + tm.assert_frame_equal(grouped_ref.max(), grouped_test.max()) + tm.assert_frame_equal(grouped_ref.min(), grouped_test.min()) + tm.assert_frame_equal(grouped_ref.first(), grouped_test.first()) + tm.assert_frame_equal(grouped_ref.last(), grouped_test.last()) + + def test_nunique_with_timegrouper_and_nat(self): + # GH 17575 + test = DataFrame( + { + "time": [ + Timestamp("2016-06-28 09:35:35"), + pd.NaT, + Timestamp("2016-06-28 16:46:28"), + ], + "data": ["1", "2", "3"], + } + ) + + grouper = Grouper(key="time", freq="h") + result = test.groupby(grouper)["data"].nunique() + expected = test[test.time.notnull()].groupby(grouper)["data"].nunique() + expected.index = expected.index._with_freq(None) + tm.assert_series_equal(result, expected) + + def test_scalar_call_versus_list_call(self): + # Issue: 17530 + data_frame = { + "location": ["shanghai", "beijing", "shanghai"], + "time": Series( + ["2017-08-09 13:32:23", "2017-08-11 23:23:15", "2017-08-11 22:23:15"], + dtype="datetime64[ns]", + ), + "value": [1, 2, 3], + } + data_frame = DataFrame(data_frame).set_index("time") + grouper = Grouper(freq="D") + + grouped = data_frame.groupby(grouper) + result = grouped.count() + grouped = data_frame.groupby([grouper]) + expected = grouped.count() + + tm.assert_frame_equal(result, expected) + + def test_grouper_period_index(self): + # GH 32108 + periods = 2 + index = pd.period_range( + start="2018-01", periods=periods, freq="M", name="Month" + ) + period_series = Series(range(periods), index=index) + result = period_series.groupby(period_series.index.month).sum() + + expected = Series( + range(periods), index=Index(range(1, periods + 1), name=index.name) + ) + tm.assert_series_equal(result, expected) + + def test_groupby_apply_timegrouper_with_nat_dict_returns( + self, groupby_with_truncated_bingrouper + ): + # GH#43500 case where gb._grouper.result_index and gb._grouper.group_keys_seq + # have different lengths that goes through the `isinstance(values[0], dict)` + # path + gb = groupby_with_truncated_bingrouper + + res = gb["Quantity"].apply(lambda x: {"foo": len(x)}) + + df = gb.obj + unit = df["Date"]._values.unit + dti = date_range("2013-09-01", "2013-10-01", freq="5D", name="Date", unit=unit) + mi = MultiIndex.from_arrays([dti, ["foo"] * len(dti)]) + expected = Series([3, 0, 0, 0, 0, 0, 2], index=mi, name="Quantity") + tm.assert_series_equal(res, expected) + + def test_groupby_apply_timegrouper_with_nat_scalar_returns( + self, groupby_with_truncated_bingrouper + ): + # GH#43500 Previously raised ValueError bc used index with incorrect + # length in wrap_applied_result + gb = groupby_with_truncated_bingrouper + + res = gb["Quantity"].apply(lambda x: x.iloc[0] if len(x) else np.nan) + + df = gb.obj + unit = df["Date"]._values.unit + dti = date_range("2013-09-01", "2013-10-01", freq="5D", name="Date", unit=unit) + expected = Series( + [18, np.nan, np.nan, np.nan, np.nan, np.nan, 5], + index=dti._with_freq(None), + name="Quantity", + ) + + tm.assert_series_equal(res, expected) + + def test_groupby_apply_timegrouper_with_nat_apply_squeeze( + self, frame_for_truncated_bingrouper + ): + df = frame_for_truncated_bingrouper + + # We need to create a GroupBy object with only one non-NaT group, + # so use a huge freq so that all non-NaT dates will be grouped together + tdg = Grouper(key="Date", freq="100YE") + gb = df.groupby(tdg) + + # check that we will go through the singular_series path + # in _wrap_applied_output_series + assert gb.ngroups == 1 + assert gb._selected_obj._get_axis(gb.axis).nlevels == 1 + + # function that returns a Series + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + res = gb.apply(lambda x: x["Quantity"] * 2) + + dti = Index([Timestamp("2013-12-31")], dtype=df["Date"].dtype, name="Date") + expected = DataFrame( + [[36, 6, 6, 10, 2]], + index=dti, + columns=Index([0, 1, 5, 2, 3], name="Quantity"), + ) + tm.assert_frame_equal(res, expected) + + @pytest.mark.single_cpu + def test_groupby_agg_numba_timegrouper_with_nat( + self, groupby_with_truncated_bingrouper + ): + pytest.importorskip("numba") + + # See discussion in GH#43487 + gb = groupby_with_truncated_bingrouper + + result = gb["Quantity"].aggregate( + lambda values, index: np.nanmean(values), engine="numba" + ) + + expected = gb["Quantity"].aggregate("mean") + tm.assert_series_equal(result, expected) + + result_df = gb[["Quantity"]].aggregate( + lambda values, index: np.nanmean(values), engine="numba" + ) + expected_df = gb[["Quantity"]].aggregate("mean") + tm.assert_frame_equal(result_df, expected_df) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/transform/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/transform/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/transform/test_numba.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/transform/test_numba.py new file mode 100644 index 0000000000000000000000000000000000000000..61fcc930f116a7e9a5fefde0885f92b9b489d343 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/transform/test_numba.py @@ -0,0 +1,284 @@ +import numpy as np +import pytest + +from pandas.errors import NumbaUtilError + +from pandas import ( + DataFrame, + Series, + option_context, +) +import pandas._testing as tm + +pytestmark = pytest.mark.single_cpu + + +def test_correct_function_signature(): + pytest.importorskip("numba") + + def incorrect_function(x): + return x + 1 + + data = DataFrame( + {"key": ["a", "a", "b", "b", "a"], "data": [1.0, 2.0, 3.0, 4.0, 5.0]}, + columns=["key", "data"], + ) + with pytest.raises(NumbaUtilError, match="The first 2"): + data.groupby("key").transform(incorrect_function, engine="numba") + + with pytest.raises(NumbaUtilError, match="The first 2"): + data.groupby("key")["data"].transform(incorrect_function, engine="numba") + + +def test_check_nopython_kwargs(): + pytest.importorskip("numba") + + def incorrect_function(values, index): + return values + 1 + + data = DataFrame( + {"key": ["a", "a", "b", "b", "a"], "data": [1.0, 2.0, 3.0, 4.0, 5.0]}, + columns=["key", "data"], + ) + with pytest.raises(NumbaUtilError, match="numba does not support"): + data.groupby("key").transform(incorrect_function, engine="numba", a=1) + + with pytest.raises(NumbaUtilError, match="numba does not support"): + data.groupby("key")["data"].transform(incorrect_function, engine="numba", a=1) + + +@pytest.mark.filterwarnings("ignore") +# Filter warnings when parallel=True and the function can't be parallelized by Numba +@pytest.mark.parametrize("jit", [True, False]) +@pytest.mark.parametrize("pandas_obj", ["Series", "DataFrame"]) +@pytest.mark.parametrize("as_index", [True, False]) +def test_numba_vs_cython(jit, pandas_obj, nogil, parallel, nopython, as_index): + pytest.importorskip("numba") + + def func(values, index): + return values + 1 + + if jit: + # Test accepted jitted functions + import numba + + func = numba.jit(func) + + data = DataFrame( + {0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1] + ) + engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython} + grouped = data.groupby(0, as_index=as_index) + if pandas_obj == "Series": + grouped = grouped[1] + + result = grouped.transform(func, engine="numba", engine_kwargs=engine_kwargs) + expected = grouped.transform(lambda x: x + 1, engine="cython") + + tm.assert_equal(result, expected) + + +@pytest.mark.filterwarnings("ignore") +# Filter warnings when parallel=True and the function can't be parallelized by Numba +@pytest.mark.parametrize("jit", [True, False]) +@pytest.mark.parametrize("pandas_obj", ["Series", "DataFrame"]) +def test_cache(jit, pandas_obj, nogil, parallel, nopython): + # Test that the functions are cached correctly if we switch functions + pytest.importorskip("numba") + + def func_1(values, index): + return values + 1 + + def func_2(values, index): + return values * 5 + + if jit: + import numba + + func_1 = numba.jit(func_1) + func_2 = numba.jit(func_2) + + data = DataFrame( + {0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1] + ) + engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython} + grouped = data.groupby(0) + if pandas_obj == "Series": + grouped = grouped[1] + + result = grouped.transform(func_1, engine="numba", engine_kwargs=engine_kwargs) + expected = grouped.transform(lambda x: x + 1, engine="cython") + tm.assert_equal(result, expected) + + result = grouped.transform(func_2, engine="numba", engine_kwargs=engine_kwargs) + expected = grouped.transform(lambda x: x * 5, engine="cython") + tm.assert_equal(result, expected) + + # Retest func_1 which should use the cache + result = grouped.transform(func_1, engine="numba", engine_kwargs=engine_kwargs) + expected = grouped.transform(lambda x: x + 1, engine="cython") + tm.assert_equal(result, expected) + + +def test_use_global_config(): + pytest.importorskip("numba") + + def func_1(values, index): + return values + 1 + + data = DataFrame( + {0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1] + ) + grouped = data.groupby(0) + expected = grouped.transform(func_1, engine="numba") + with option_context("compute.use_numba", True): + result = grouped.transform(func_1, engine=None) + tm.assert_frame_equal(expected, result) + + +# TODO: Test more than just reductions (e.g. actually test transformations once we have +@pytest.mark.parametrize( + "agg_func", [["min", "max"], "min", {"B": ["min", "max"], "C": "sum"}] +) +def test_string_cython_vs_numba(agg_func, numba_supported_reductions): + pytest.importorskip("numba") + agg_func, kwargs = numba_supported_reductions + data = DataFrame( + {0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1] + ) + grouped = data.groupby(0) + + result = grouped.transform(agg_func, engine="numba", **kwargs) + expected = grouped.transform(agg_func, engine="cython", **kwargs) + tm.assert_frame_equal(result, expected) + + result = grouped[1].transform(agg_func, engine="numba", **kwargs) + expected = grouped[1].transform(agg_func, engine="cython", **kwargs) + tm.assert_series_equal(result, expected) + + +def test_args_not_cached(): + # GH 41647 + pytest.importorskip("numba") + + def sum_last(values, index, n): + return values[-n:].sum() + + df = DataFrame({"id": [0, 0, 1, 1], "x": [1, 1, 1, 1]}) + grouped_x = df.groupby("id")["x"] + result = grouped_x.transform(sum_last, 1, engine="numba") + expected = Series([1.0] * 4, name="x") + tm.assert_series_equal(result, expected) + + result = grouped_x.transform(sum_last, 2, engine="numba") + expected = Series([2.0] * 4, name="x") + tm.assert_series_equal(result, expected) + + +def test_index_data_correctly_passed(): + # GH 43133 + pytest.importorskip("numba") + + def f(values, index): + return index - 1 + + df = DataFrame({"group": ["A", "A", "B"], "v": [4, 5, 6]}, index=[-1, -2, -3]) + result = df.groupby("group").transform(f, engine="numba") + expected = DataFrame([-4.0, -3.0, -2.0], columns=["v"], index=[-1, -2, -3]) + tm.assert_frame_equal(result, expected) + + +def test_engine_kwargs_not_cached(): + # If the user passes a different set of engine_kwargs don't return the same + # jitted function + pytest.importorskip("numba") + nogil = True + parallel = False + nopython = True + + def func_kwargs(values, index): + return nogil + parallel + nopython + + engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel} + df = DataFrame({"value": [0, 0, 0]}) + result = df.groupby(level=0).transform( + func_kwargs, engine="numba", engine_kwargs=engine_kwargs + ) + expected = DataFrame({"value": [2.0, 2.0, 2.0]}) + tm.assert_frame_equal(result, expected) + + nogil = False + engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel} + result = df.groupby(level=0).transform( + func_kwargs, engine="numba", engine_kwargs=engine_kwargs + ) + expected = DataFrame({"value": [1.0, 1.0, 1.0]}) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.filterwarnings("ignore") +def test_multiindex_one_key(nogil, parallel, nopython): + pytest.importorskip("numba") + + def numba_func(values, index): + return 1 + + df = DataFrame([{"A": 1, "B": 2, "C": 3}]).set_index(["A", "B"]) + engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel} + result = df.groupby("A").transform( + numba_func, engine="numba", engine_kwargs=engine_kwargs + ) + expected = DataFrame([{"A": 1, "B": 2, "C": 1.0}]).set_index(["A", "B"]) + tm.assert_frame_equal(result, expected) + + +def test_multiindex_multi_key_not_supported(nogil, parallel, nopython): + pytest.importorskip("numba") + + def numba_func(values, index): + return 1 + + df = DataFrame([{"A": 1, "B": 2, "C": 3}]).set_index(["A", "B"]) + engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel} + with pytest.raises(NotImplementedError, match="more than 1 grouping labels"): + df.groupby(["A", "B"]).transform( + numba_func, engine="numba", engine_kwargs=engine_kwargs + ) + + +def test_multilabel_numba_vs_cython(numba_supported_reductions): + pytest.importorskip("numba") + reduction, kwargs = numba_supported_reductions + df = DataFrame( + { + "A": ["foo", "bar", "foo", "bar", "foo", "bar", "foo", "foo"], + "B": ["one", "one", "two", "three", "two", "two", "one", "three"], + "C": np.random.default_rng(2).standard_normal(8), + "D": np.random.default_rng(2).standard_normal(8), + } + ) + gb = df.groupby(["A", "B"]) + res_agg = gb.transform(reduction, engine="numba", **kwargs) + expected_agg = gb.transform(reduction, engine="cython", **kwargs) + tm.assert_frame_equal(res_agg, expected_agg) + + +def test_multilabel_udf_numba_vs_cython(): + pytest.importorskip("numba") + df = DataFrame( + { + "A": ["foo", "bar", "foo", "bar", "foo", "bar", "foo", "foo"], + "B": ["one", "one", "two", "three", "two", "two", "one", "three"], + "C": np.random.default_rng(2).standard_normal(8), + "D": np.random.default_rng(2).standard_normal(8), + } + ) + gb = df.groupby(["A", "B"]) + result = gb.transform( + lambda values, index: (values - values.min()) / (values.max() - values.min()), + engine="numba", + ) + expected = gb.transform( + lambda x: (x - x.min()) / (x.max() - x.min()), engine="cython" + ) + tm.assert_frame_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/transform/test_transform.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/transform/test_transform.py new file mode 100644 index 0000000000000000000000000000000000000000..fd9bd5cc55538ff867faffb52ea9546e0aec53a9 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/groupby/transform/test_transform.py @@ -0,0 +1,1702 @@ +""" test with the .transform """ +import numpy as np +import pytest + +from pandas._libs import lib + +from pandas.core.dtypes.common import ensure_platform_int + +import pandas as pd +from pandas import ( + Categorical, + DataFrame, + Index, + MultiIndex, + Series, + Timestamp, + concat, + date_range, +) +import pandas._testing as tm +from pandas.tests.groupby import get_groupby_method_args + + +def assert_fp_equal(a, b): + assert (np.abs(a - b) < 1e-12).all() + + +def test_transform(): + data = Series(np.arange(9) // 3, index=np.arange(9)) + + index = np.arange(9) + np.random.default_rng(2).shuffle(index) + data = data.reindex(index) + + grouped = data.groupby(lambda x: x // 3) + + transformed = grouped.transform(lambda x: x * x.sum()) + assert transformed[7] == 12 + + # GH 8046 + # make sure that we preserve the input order + + df = DataFrame( + np.arange(6, dtype="int64").reshape(3, 2), columns=["a", "b"], index=[0, 2, 1] + ) + key = [0, 0, 1] + expected = ( + df.sort_index() + .groupby(key) + .transform(lambda x: x - x.mean()) + .groupby(key) + .mean() + ) + result = df.groupby(key).transform(lambda x: x - x.mean()).groupby(key).mean() + tm.assert_frame_equal(result, expected) + + def demean(arr): + return arr - arr.mean(axis=0) + + people = DataFrame( + np.random.default_rng(2).standard_normal((5, 5)), + columns=["a", "b", "c", "d", "e"], + index=["Joe", "Steve", "Wes", "Jim", "Travis"], + ) + key = ["one", "two", "one", "two", "one"] + result = people.groupby(key).transform(demean).groupby(key).mean() + expected = people.groupby(key, group_keys=False).apply(demean).groupby(key).mean() + tm.assert_frame_equal(result, expected) + + # GH 8430 + df = DataFrame( + np.random.default_rng(2).standard_normal((50, 4)), + columns=Index(list("ABCD"), dtype=object), + index=date_range("2000-01-01", periods=50, freq="B"), + ) + g = df.groupby(pd.Grouper(freq="ME")) + g.transform(lambda x: x - 1) + + # GH 9700 + df = DataFrame({"a": range(5, 10), "b": range(5)}) + msg = "using DataFrameGroupBy.max" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = df.groupby("a").transform(max) + expected = DataFrame({"b": range(5)}) + tm.assert_frame_equal(result, expected) + + +def test_transform_fast(): + df = DataFrame( + { + "id": np.arange(100000) / 3, + "val": np.random.default_rng(2).standard_normal(100000), + } + ) + + grp = df.groupby("id")["val"] + + values = np.repeat(grp.mean().values, ensure_platform_int(grp.count().values)) + expected = Series(values, index=df.index, name="val") + + msg = "using SeriesGroupBy.mean" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = grp.transform(np.mean) + tm.assert_series_equal(result, expected) + + result = grp.transform("mean") + tm.assert_series_equal(result, expected) + + +def test_transform_fast2(): + # GH 12737 + df = DataFrame( + { + "grouping": [0, 1, 1, 3], + "f": [1.1, 2.1, 3.1, 4.5], + "d": date_range("2014-1-1", "2014-1-4"), + "i": [1, 2, 3, 4], + }, + columns=["grouping", "f", "i", "d"], + ) + result = df.groupby("grouping").transform("first") + + dates = Index( + [ + Timestamp("2014-1-1"), + Timestamp("2014-1-2"), + Timestamp("2014-1-2"), + Timestamp("2014-1-4"), + ], + dtype="M8[ns]", + ) + expected = DataFrame( + {"f": [1.1, 2.1, 2.1, 4.5], "d": dates, "i": [1, 2, 2, 4]}, + columns=["f", "i", "d"], + ) + tm.assert_frame_equal(result, expected) + + # selection + result = df.groupby("grouping")[["f", "i"]].transform("first") + expected = expected[["f", "i"]] + tm.assert_frame_equal(result, expected) + + +def test_transform_fast3(): + # dup columns + df = DataFrame([[1, 2, 3], [4, 5, 6]], columns=["g", "a", "a"]) + result = df.groupby("g").transform("first") + expected = df.drop("g", axis=1) + tm.assert_frame_equal(result, expected) + + +def test_transform_broadcast(tsframe, ts): + grouped = ts.groupby(lambda x: x.month) + msg = "using SeriesGroupBy.mean" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = grouped.transform(np.mean) + + tm.assert_index_equal(result.index, ts.index) + for _, gp in grouped: + assert_fp_equal(result.reindex(gp.index), gp.mean()) + + grouped = tsframe.groupby(lambda x: x.month) + msg = "using DataFrameGroupBy.mean" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = grouped.transform(np.mean) + tm.assert_index_equal(result.index, tsframe.index) + for _, gp in grouped: + agged = gp.mean(axis=0) + res = result.reindex(gp.index) + for col in tsframe: + assert_fp_equal(res[col], agged[col]) + + # group columns + msg = "DataFrame.groupby with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + grouped = tsframe.groupby({"A": 0, "B": 0, "C": 1, "D": 1}, axis=1) + msg = "using DataFrameGroupBy.mean" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = grouped.transform(np.mean) + tm.assert_index_equal(result.index, tsframe.index) + tm.assert_index_equal(result.columns, tsframe.columns) + for _, gp in grouped: + agged = gp.mean(1) + res = result.reindex(columns=gp.columns) + for idx in gp.index: + assert_fp_equal(res.xs(idx), agged[idx]) + + +def test_transform_axis_1(request, transformation_func): + # GH 36308 + + df = DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]}, index=["x", "y"]) + args = get_groupby_method_args(transformation_func, df) + msg = "DataFrame.groupby with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + gb = df.groupby([0, 0, 1], axis=1) + warn = FutureWarning if transformation_func == "fillna" else None + msg = "DataFrameGroupBy.fillna is deprecated" + with tm.assert_produces_warning(warn, match=msg): + result = gb.transform(transformation_func, *args) + msg = "DataFrameGroupBy.fillna is deprecated" + with tm.assert_produces_warning(warn, match=msg): + expected = df.T.groupby([0, 0, 1]).transform(transformation_func, *args).T + + if transformation_func in ["diff", "shift"]: + # Result contains nans, so transpose coerces to float + expected["b"] = expected["b"].astype("int64") + + # cumcount returns Series; the rest are DataFrame + tm.assert_equal(result, expected) + + +def test_transform_axis_1_reducer(request, reduction_func): + # GH#45715 + if reduction_func in ( + "corrwith", + "ngroup", + "nth", + ): + marker = pytest.mark.xfail(reason="transform incorrectly fails - GH#45986") + request.applymarker(marker) + + df = DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]}, index=["x", "y"]) + msg = "DataFrame.groupby with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + gb = df.groupby([0, 0, 1], axis=1) + + result = gb.transform(reduction_func) + expected = df.T.groupby([0, 0, 1]).transform(reduction_func).T + tm.assert_equal(result, expected) + + +def test_transform_axis_ts(tsframe): + # make sure that we are setting the axes + # correctly when on axis=0 or 1 + # in the presence of a non-monotonic indexer + # GH12713 + + base = tsframe.iloc[0:5] + r = len(base.index) + c = len(base.columns) + tso = DataFrame( + np.random.default_rng(2).standard_normal((r, c)), + index=base.index, + columns=base.columns, + dtype="float64", + ) + # monotonic + ts = tso + grouped = ts.groupby(lambda x: x.weekday(), group_keys=False) + result = ts - grouped.transform("mean") + expected = grouped.apply(lambda x: x - x.mean(axis=0)) + tm.assert_frame_equal(result, expected) + + ts = ts.T + msg = "DataFrame.groupby with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + grouped = ts.groupby(lambda x: x.weekday(), axis=1, group_keys=False) + result = ts - grouped.transform("mean") + expected = grouped.apply(lambda x: (x.T - x.mean(1)).T) + tm.assert_frame_equal(result, expected) + + # non-monotonic + ts = tso.iloc[[1, 0] + list(range(2, len(base)))] + grouped = ts.groupby(lambda x: x.weekday(), group_keys=False) + result = ts - grouped.transform("mean") + expected = grouped.apply(lambda x: x - x.mean(axis=0)) + tm.assert_frame_equal(result, expected) + + ts = ts.T + msg = "DataFrame.groupby with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + grouped = ts.groupby(lambda x: x.weekday(), axis=1, group_keys=False) + result = ts - grouped.transform("mean") + expected = grouped.apply(lambda x: (x.T - x.mean(1)).T) + tm.assert_frame_equal(result, expected) + + +def test_transform_dtype(): + # GH 9807 + # Check transform dtype output is preserved + df = DataFrame([[1, 3], [2, 3]]) + result = df.groupby(1).transform("mean") + expected = DataFrame([[1.5], [1.5]]) + tm.assert_frame_equal(result, expected) + + +def test_transform_bug(): + # GH 5712 + # transforming on a datetime column + df = DataFrame({"A": Timestamp("20130101"), "B": np.arange(5)}) + result = df.groupby("A")["B"].transform(lambda x: x.rank(ascending=False)) + expected = Series(np.arange(5, 0, step=-1), name="B", dtype="float64") + tm.assert_series_equal(result, expected) + + +def test_transform_numeric_to_boolean(): + # GH 16875 + # inconsistency in transforming boolean values + expected = Series([True, True], name="A") + + df = DataFrame({"A": [1.1, 2.2], "B": [1, 2]}) + result = df.groupby("B").A.transform(lambda x: True) + tm.assert_series_equal(result, expected) + + df = DataFrame({"A": [1, 2], "B": [1, 2]}) + result = df.groupby("B").A.transform(lambda x: True) + tm.assert_series_equal(result, expected) + + +def test_transform_datetime_to_timedelta(): + # GH 15429 + # transforming a datetime to timedelta + df = DataFrame({"A": Timestamp("20130101"), "B": np.arange(5)}) + expected = Series( + Timestamp("20130101") - Timestamp("20130101"), index=range(5), name="A" + ) + + # this does date math without changing result type in transform + base_time = df["A"][0] + result = ( + df.groupby("A")["A"].transform(lambda x: x.max() - x.min() + base_time) + - base_time + ) + tm.assert_series_equal(result, expected) + + # this does date math and causes the transform to return timedelta + result = df.groupby("A")["A"].transform(lambda x: x.max() - x.min()) + tm.assert_series_equal(result, expected) + + +def test_transform_datetime_to_numeric(): + # GH 10972 + # convert dt to float + df = DataFrame({"a": 1, "b": date_range("2015-01-01", periods=2, freq="D")}) + result = df.groupby("a").b.transform( + lambda x: x.dt.dayofweek - x.dt.dayofweek.mean() + ) + + expected = Series([-0.5, 0.5], name="b") + tm.assert_series_equal(result, expected) + + # convert dt to int + df = DataFrame({"a": 1, "b": date_range("2015-01-01", periods=2, freq="D")}) + result = df.groupby("a").b.transform( + lambda x: x.dt.dayofweek - x.dt.dayofweek.min() + ) + + expected = Series([0, 1], dtype=np.int32, name="b") + tm.assert_series_equal(result, expected) + + +def test_transform_casting(): + # 13046 + times = [ + "13:43:27", + "14:26:19", + "14:29:01", + "18:39:34", + "18:40:18", + "18:44:30", + "18:46:00", + "18:52:15", + "18:59:59", + "19:17:48", + "19:21:38", + ] + df = DataFrame( + { + "A": [f"B-{i}" for i in range(11)], + "ID3": np.take( + ["a", "b", "c", "d", "e"], [0, 1, 2, 1, 3, 1, 1, 1, 4, 1, 1] + ), + "DATETIME": pd.to_datetime([f"2014-10-08 {time}" for time in times]), + }, + index=pd.RangeIndex(11, name="idx"), + ) + + result = df.groupby("ID3")["DATETIME"].transform(lambda x: x.diff()) + assert lib.is_np_dtype(result.dtype, "m") + + result = df[["ID3", "DATETIME"]].groupby("ID3").transform(lambda x: x.diff()) + assert lib.is_np_dtype(result.DATETIME.dtype, "m") + + +def test_transform_multiple(ts): + grouped = ts.groupby([lambda x: x.year, lambda x: x.month]) + + grouped.transform(lambda x: x * 2) + + msg = "using SeriesGroupBy.mean" + with tm.assert_produces_warning(FutureWarning, match=msg): + grouped.transform(np.mean) + + +def test_dispatch_transform(tsframe): + df = tsframe[::5].reindex(tsframe.index) + + grouped = df.groupby(lambda x: x.month) + + msg = "DataFrameGroupBy.fillna is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + filled = grouped.fillna(method="pad") + msg = "Series.fillna with 'method' is deprecated" + fillit = lambda x: x.fillna(method="pad") + with tm.assert_produces_warning(FutureWarning, match=msg): + expected = df.groupby(lambda x: x.month).transform(fillit) + tm.assert_frame_equal(filled, expected) + + +def test_transform_fillna_null(): + df = DataFrame( + { + "price": [10, 10, 20, 20, 30, 30], + "color": [10, 10, 20, 20, 30, 30], + "cost": (100, 200, 300, 400, 500, 600), + } + ) + msg = "DataFrameGroupBy.fillna is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + with pytest.raises(ValueError, match="Must specify a fill 'value' or 'method'"): + df.groupby(["price"]).transform("fillna") + with tm.assert_produces_warning(FutureWarning, match=msg): + with pytest.raises(ValueError, match="Must specify a fill 'value' or 'method'"): + df.groupby(["price"]).fillna() + + +def test_transform_transformation_func(transformation_func): + # GH 30918 + df = DataFrame( + { + "A": ["foo", "foo", "foo", "foo", "bar", "bar", "baz"], + "B": [1, 2, np.nan, 3, 3, np.nan, 4], + }, + index=date_range("2020-01-01", "2020-01-07"), + ) + if transformation_func == "cumcount": + test_op = lambda x: x.transform("cumcount") + mock_op = lambda x: Series(range(len(x)), x.index) + elif transformation_func == "fillna": + test_op = lambda x: x.transform("fillna", value=0) + mock_op = lambda x: x.fillna(value=0) + elif transformation_func == "ngroup": + test_op = lambda x: x.transform("ngroup") + counter = -1 + + def mock_op(x): + nonlocal counter + counter += 1 + return Series(counter, index=x.index) + + else: + test_op = lambda x: x.transform(transformation_func) + mock_op = lambda x: getattr(x, transformation_func)() + + if transformation_func == "pct_change": + msg = "The default fill_method='pad' in DataFrame.pct_change is deprecated" + groupby_msg = ( + "The default fill_method='ffill' in DataFrameGroupBy.pct_change " + "is deprecated" + ) + warn = FutureWarning + groupby_warn = FutureWarning + elif transformation_func == "fillna": + msg = "" + groupby_msg = "DataFrameGroupBy.fillna is deprecated" + warn = None + groupby_warn = FutureWarning + else: + msg = groupby_msg = "" + warn = groupby_warn = None + + with tm.assert_produces_warning(groupby_warn, match=groupby_msg): + result = test_op(df.groupby("A")) + + # pass the group in same order as iterating `for ... in df.groupby(...)` + # but reorder to match df's index since this is a transform + groups = [df[["B"]].iloc[4:6], df[["B"]].iloc[6:], df[["B"]].iloc[:4]] + with tm.assert_produces_warning(warn, match=msg): + expected = concat([mock_op(g) for g in groups]).sort_index() + # sort_index does not preserve the freq + expected = expected.set_axis(df.index) + + if transformation_func in ("cumcount", "ngroup"): + tm.assert_series_equal(result, expected) + else: + tm.assert_frame_equal(result, expected) + + +def test_transform_select_columns(df): + f = lambda x: x.mean() + result = df.groupby("A")[["C", "D"]].transform(f) + + selection = df[["C", "D"]] + expected = selection.groupby(df["A"]).transform(f) + + tm.assert_frame_equal(result, expected) + + +def test_transform_nuisance_raises(df): + # case that goes through _transform_item_by_item + + df.columns = ["A", "B", "B", "D"] + + # this also tests orderings in transform between + # series/frame to make sure it's consistent + grouped = df.groupby("A") + + gbc = grouped["B"] + with pytest.raises(TypeError, match="Could not convert"): + gbc.transform(lambda x: np.mean(x)) + + with pytest.raises(TypeError, match="Could not convert"): + df.groupby("A").transform(lambda x: np.mean(x)) + + +def test_transform_function_aliases(df): + result = df.groupby("A").transform("mean", numeric_only=True) + msg = "using DataFrameGroupBy.mean" + with tm.assert_produces_warning(FutureWarning, match=msg): + expected = df.groupby("A")[["C", "D"]].transform(np.mean) + tm.assert_frame_equal(result, expected) + + result = df.groupby("A")["C"].transform("mean") + msg = "using SeriesGroupBy.mean" + with tm.assert_produces_warning(FutureWarning, match=msg): + expected = df.groupby("A")["C"].transform(np.mean) + tm.assert_series_equal(result, expected) + + +def test_series_fast_transform_date(): + # GH 13191 + df = DataFrame( + {"grouping": [np.nan, 1, 1, 3], "d": date_range("2014-1-1", "2014-1-4")} + ) + result = df.groupby("grouping")["d"].transform("first") + dates = [ + pd.NaT, + Timestamp("2014-1-2"), + Timestamp("2014-1-2"), + Timestamp("2014-1-4"), + ] + expected = Series(dates, name="d", dtype="M8[ns]") + tm.assert_series_equal(result, expected) + + +def test_transform_length(): + # GH 9697 + df = DataFrame({"col1": [1, 1, 2, 2], "col2": [1, 2, 3, np.nan]}) + expected = Series([3.0] * 4) + + def nsum(x): + return np.nansum(x) + + msg = "using DataFrameGroupBy.sum" + with tm.assert_produces_warning(FutureWarning, match=msg): + results = [ + df.groupby("col1").transform(sum)["col2"], + df.groupby("col1")["col2"].transform(sum), + df.groupby("col1").transform(nsum)["col2"], + df.groupby("col1")["col2"].transform(nsum), + ] + for result in results: + tm.assert_series_equal(result, expected, check_names=False) + + +def test_transform_coercion(): + # 14457 + # when we are transforming be sure to not coerce + # via assignment + df = DataFrame({"A": ["a", "a", "b", "b"], "B": [0, 1, 3, 4]}) + g = df.groupby("A") + + msg = "using DataFrameGroupBy.mean" + with tm.assert_produces_warning(FutureWarning, match=msg): + expected = g.transform(np.mean) + + result = g.transform(lambda x: np.mean(x, axis=0)) + tm.assert_frame_equal(result, expected) + + +def test_groupby_transform_with_int(): + # GH 3740, make sure that we might upcast on item-by-item transform + + # floats + df = DataFrame( + { + "A": [1, 1, 1, 2, 2, 2], + "B": Series(1, dtype="float64"), + "C": Series([1, 2, 3, 1, 2, 3], dtype="float64"), + "D": "foo", + } + ) + with np.errstate(all="ignore"): + result = df.groupby("A")[["B", "C"]].transform( + lambda x: (x - x.mean()) / x.std() + ) + expected = DataFrame( + {"B": np.nan, "C": Series([-1, 0, 1, -1, 0, 1], dtype="float64")} + ) + tm.assert_frame_equal(result, expected) + + # int case + df = DataFrame( + { + "A": [1, 1, 1, 2, 2, 2], + "B": 1, + "C": [1, 2, 3, 1, 2, 3], + "D": "foo", + } + ) + with np.errstate(all="ignore"): + with pytest.raises(TypeError, match="Could not convert"): + df.groupby("A").transform(lambda x: (x - x.mean()) / x.std()) + result = df.groupby("A")[["B", "C"]].transform( + lambda x: (x - x.mean()) / x.std() + ) + expected = DataFrame({"B": np.nan, "C": [-1.0, 0.0, 1.0, -1.0, 0.0, 1.0]}) + tm.assert_frame_equal(result, expected) + + # int that needs float conversion + s = Series([2, 3, 4, 10, 5, -1]) + df = DataFrame({"A": [1, 1, 1, 2, 2, 2], "B": 1, "C": s, "D": "foo"}) + with np.errstate(all="ignore"): + with pytest.raises(TypeError, match="Could not convert"): + df.groupby("A").transform(lambda x: (x - x.mean()) / x.std()) + result = df.groupby("A")[["B", "C"]].transform( + lambda x: (x - x.mean()) / x.std() + ) + + s1 = s.iloc[0:3] + s1 = (s1 - s1.mean()) / s1.std() + s2 = s.iloc[3:6] + s2 = (s2 - s2.mean()) / s2.std() + expected = DataFrame({"B": np.nan, "C": concat([s1, s2])}) + tm.assert_frame_equal(result, expected) + + # int doesn't get downcasted + result = df.groupby("A")[["B", "C"]].transform(lambda x: x * 2 / 2) + expected = DataFrame({"B": 1.0, "C": [2.0, 3.0, 4.0, 10.0, 5.0, -1.0]}) + tm.assert_frame_equal(result, expected) + + +def test_groupby_transform_with_nan_group(): + # GH 9941 + df = DataFrame({"a": range(10), "b": [1, 1, 2, 3, np.nan, 4, 4, 5, 5, 5]}) + msg = "using SeriesGroupBy.max" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = df.groupby(df.b)["a"].transform(max) + expected = Series([1.0, 1.0, 2.0, 3.0, np.nan, 6.0, 6.0, 9.0, 9.0, 9.0], name="a") + tm.assert_series_equal(result, expected) + + +def test_transform_mixed_type(): + index = MultiIndex.from_arrays([[0, 0, 0, 1, 1, 1], [1, 2, 3, 1, 2, 3]]) + df = DataFrame( + { + "d": [1.0, 1.0, 1.0, 2.0, 2.0, 2.0], + "c": np.tile(["a", "b", "c"], 2), + "v": np.arange(1.0, 7.0), + }, + index=index, + ) + + def f(group): + group["g"] = group["d"] * 2 + return group[:1] + + grouped = df.groupby("c") + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = grouped.apply(f) + + assert result["d"].dtype == np.float64 + + # this is by definition a mutating operation! + with pd.option_context("mode.chained_assignment", None): + for key, group in grouped: + res = f(group) + tm.assert_frame_equal(res, result.loc[key]) + + +@pytest.mark.parametrize( + "op, args, targop", + [ + ("cumprod", (), lambda x: x.cumprod()), + ("cumsum", (), lambda x: x.cumsum()), + ("shift", (-1,), lambda x: x.shift(-1)), + ("shift", (1,), lambda x: x.shift()), + ], +) +def test_cython_transform_series(op, args, targop): + # GH 4095 + s = Series(np.random.default_rng(2).standard_normal(1000)) + s_missing = s.copy() + s_missing.iloc[2:10] = np.nan + labels = np.random.default_rng(2).integers(0, 50, size=1000).astype(float) + + # series + for data in [s, s_missing]: + # print(data.head()) + expected = data.groupby(labels).transform(targop) + + tm.assert_series_equal(expected, data.groupby(labels).transform(op, *args)) + tm.assert_series_equal(expected, getattr(data.groupby(labels), op)(*args)) + + +@pytest.mark.parametrize("op", ["cumprod", "cumsum"]) +@pytest.mark.parametrize("skipna", [False, True]) +@pytest.mark.parametrize( + "input, exp", + [ + # When everything is NaN + ({"key": ["b"] * 10, "value": np.nan}, Series([np.nan] * 10, name="value")), + # When there is a single NaN + ( + {"key": ["b"] * 10 + ["a"] * 2, "value": [3] * 3 + [np.nan] + [3] * 8}, + { + ("cumprod", False): [3.0, 9.0, 27.0] + [np.nan] * 7 + [3.0, 9.0], + ("cumprod", True): [ + 3.0, + 9.0, + 27.0, + np.nan, + 81.0, + 243.0, + 729.0, + 2187.0, + 6561.0, + 19683.0, + 3.0, + 9.0, + ], + ("cumsum", False): [3.0, 6.0, 9.0] + [np.nan] * 7 + [3.0, 6.0], + ("cumsum", True): [ + 3.0, + 6.0, + 9.0, + np.nan, + 12.0, + 15.0, + 18.0, + 21.0, + 24.0, + 27.0, + 3.0, + 6.0, + ], + }, + ), + ], +) +def test_groupby_cum_skipna(op, skipna, input, exp): + df = DataFrame(input) + result = df.groupby("key")["value"].transform(op, skipna=skipna) + if isinstance(exp, dict): + expected = exp[(op, skipna)] + else: + expected = exp + expected = Series(expected, name="value") + tm.assert_series_equal(expected, result) + + +@pytest.fixture +def frame(): + floating = Series(np.random.default_rng(2).standard_normal(10)) + floating_missing = floating.copy() + floating_missing.iloc[2:7] = np.nan + strings = list("abcde") * 2 + strings_missing = strings[:] + strings_missing[5] = np.nan + + df = DataFrame( + { + "float": floating, + "float_missing": floating_missing, + "int": [1, 1, 1, 1, 2] * 2, + "datetime": date_range("1990-1-1", periods=10), + "timedelta": pd.timedelta_range(1, freq="s", periods=10), + "string": strings, + "string_missing": strings_missing, + "cat": Categorical(strings), + }, + ) + return df + + +@pytest.fixture +def frame_mi(frame): + frame.index = MultiIndex.from_product([range(5), range(2)]) + return frame + + +@pytest.mark.slow +@pytest.mark.parametrize( + "op, args, targop", + [ + ("cumprod", (), lambda x: x.cumprod()), + ("cumsum", (), lambda x: x.cumsum()), + ("shift", (-1,), lambda x: x.shift(-1)), + ("shift", (1,), lambda x: x.shift()), + ], +) +@pytest.mark.parametrize("df_fix", ["frame", "frame_mi"]) +@pytest.mark.parametrize( + "gb_target", + [ + {"by": np.random.default_rng(2).integers(0, 50, size=10).astype(float)}, + {"level": 0}, + {"by": "string"}, + pytest.param({"by": "string_missing"}, marks=pytest.mark.xfail), + {"by": ["int", "string"]}, + ], +) +def test_cython_transform_frame(request, op, args, targop, df_fix, gb_target): + df = request.getfixturevalue(df_fix) + gb = df.groupby(group_keys=False, **gb_target) + + if op != "shift" and "int" not in gb_target: + # numeric apply fastpath promotes dtype so have + # to apply separately and concat + i = gb[["int"]].apply(targop) + f = gb[["float", "float_missing"]].apply(targop) + expected = concat([f, i], axis=1) + else: + if op != "shift" or not isinstance(gb_target.get("by"), (str, list)): + warn = None + else: + warn = DeprecationWarning + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(warn, match=msg): + expected = gb.apply(targop) + + expected = expected.sort_index(axis=1) + if op == "shift": + depr_msg = "The 'downcast' keyword in fillna is deprecated" + with tm.assert_produces_warning(FutureWarning, match=depr_msg): + expected["string_missing"] = expected["string_missing"].fillna( + np.nan, downcast=False + ) + expected["string"] = expected["string"].fillna(np.nan, downcast=False) + + result = gb[expected.columns].transform(op, *args).sort_index(axis=1) + tm.assert_frame_equal(result, expected) + result = getattr(gb[expected.columns], op)(*args).sort_index(axis=1) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.slow +@pytest.mark.parametrize( + "op, args, targop", + [ + ("cumprod", (), lambda x: x.cumprod()), + ("cumsum", (), lambda x: x.cumsum()), + ("shift", (-1,), lambda x: x.shift(-1)), + ("shift", (1,), lambda x: x.shift()), + ], +) +@pytest.mark.parametrize("df_fix", ["frame", "frame_mi"]) +@pytest.mark.parametrize( + "gb_target", + [ + {"by": np.random.default_rng(2).integers(0, 50, size=10).astype(float)}, + {"level": 0}, + {"by": "string"}, + # TODO: create xfail condition given other params + # {"by": 'string_missing'}, + {"by": ["int", "string"]}, + ], +) +@pytest.mark.parametrize( + "column", + [ + "float", + "float_missing", + "int", + "datetime", + "timedelta", + "string", + "string_missing", + ], +) +def test_cython_transform_frame_column( + request, op, args, targop, df_fix, gb_target, column +): + df = request.getfixturevalue(df_fix) + gb = df.groupby(group_keys=False, **gb_target) + c = column + if ( + c not in ["float", "int", "float_missing"] + and op != "shift" + and not (c == "timedelta" and op == "cumsum") + ): + msg = "|".join( + [ + "does not support .* operations", + ".* is not supported for object dtype", + "is not implemented for this dtype", + ] + ) + with pytest.raises(TypeError, match=msg): + gb[c].transform(op) + with pytest.raises(TypeError, match=msg): + getattr(gb[c], op)() + else: + expected = gb[c].apply(targop) + expected.name = c + if c in ["string_missing", "string"]: + depr_msg = "The 'downcast' keyword in fillna is deprecated" + with tm.assert_produces_warning(FutureWarning, match=depr_msg): + expected = expected.fillna(np.nan, downcast=False) + + res = gb[c].transform(op, *args) + tm.assert_series_equal(expected, res) + res2 = getattr(gb[c], op)(*args) + tm.assert_series_equal(expected, res2) + + +def test_transform_with_non_scalar_group(): + # GH 10165 + cols = MultiIndex.from_tuples( + [ + ("syn", "A"), + ("foo", "A"), + ("non", "A"), + ("syn", "C"), + ("foo", "C"), + ("non", "C"), + ("syn", "T"), + ("foo", "T"), + ("non", "T"), + ("syn", "G"), + ("foo", "G"), + ("non", "G"), + ] + ) + df = DataFrame( + np.random.default_rng(2).integers(1, 10, (4, 12)), + columns=cols, + index=["A", "C", "G", "T"], + ) + + msg = "DataFrame.groupby with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + gb = df.groupby(axis=1, level=1) + msg = "transform must return a scalar value for each group.*" + with pytest.raises(ValueError, match=msg): + gb.transform(lambda z: z.div(z.sum(axis=1), axis=0)) + + +@pytest.mark.parametrize( + "cols,expected", + [ + ("a", Series([1, 1, 1], name="a")), + ( + ["a", "c"], + DataFrame({"a": [1, 1, 1], "c": [1, 1, 1]}), + ), + ], +) +@pytest.mark.parametrize("agg_func", ["count", "rank", "size"]) +def test_transform_numeric_ret(cols, expected, agg_func): + # GH#19200 and GH#27469 + df = DataFrame( + {"a": date_range("2018-01-01", periods=3), "b": range(3), "c": range(7, 10)} + ) + result = df.groupby("b")[cols].transform(agg_func) + + if agg_func == "rank": + expected = expected.astype("float") + elif agg_func == "size" and cols == ["a", "c"]: + # transform("size") returns a Series + expected = expected["a"].rename(None) + tm.assert_equal(result, expected) + + +def test_transform_ffill(): + # GH 24211 + data = [["a", 0.0], ["a", float("nan")], ["b", 1.0], ["b", float("nan")]] + df = DataFrame(data, columns=["key", "values"]) + result = df.groupby("key").transform("ffill") + expected = DataFrame({"values": [0.0, 0.0, 1.0, 1.0]}) + tm.assert_frame_equal(result, expected) + result = df.groupby("key")["values"].transform("ffill") + expected = Series([0.0, 0.0, 1.0, 1.0], name="values") + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("mix_groupings", [True, False]) +@pytest.mark.parametrize("as_series", [True, False]) +@pytest.mark.parametrize("val1,val2", [("foo", "bar"), (1, 2), (1.0, 2.0)]) +@pytest.mark.parametrize( + "fill_method,limit,exp_vals", + [ + ( + "ffill", + None, + [np.nan, np.nan, "val1", "val1", "val1", "val2", "val2", "val2"], + ), + ("ffill", 1, [np.nan, np.nan, "val1", "val1", np.nan, "val2", "val2", np.nan]), + ( + "bfill", + None, + ["val1", "val1", "val1", "val2", "val2", "val2", np.nan, np.nan], + ), + ("bfill", 1, [np.nan, "val1", "val1", np.nan, "val2", "val2", np.nan, np.nan]), + ], +) +def test_group_fill_methods( + mix_groupings, as_series, val1, val2, fill_method, limit, exp_vals +): + vals = [np.nan, np.nan, val1, np.nan, np.nan, val2, np.nan, np.nan] + _exp_vals = list(exp_vals) + # Overwrite placeholder values + for index, exp_val in enumerate(_exp_vals): + if exp_val == "val1": + _exp_vals[index] = val1 + elif exp_val == "val2": + _exp_vals[index] = val2 + + # Need to modify values and expectations depending on the + # Series / DataFrame that we ultimately want to generate + if mix_groupings: # ['a', 'b', 'a, 'b', ...] + keys = ["a", "b"] * len(vals) + + def interweave(list_obj): + temp = [] + for x in list_obj: + temp.extend([x, x]) + + return temp + + _exp_vals = interweave(_exp_vals) + vals = interweave(vals) + else: # ['a', 'a', 'a', ... 'b', 'b', 'b'] + keys = ["a"] * len(vals) + ["b"] * len(vals) + _exp_vals = _exp_vals * 2 + vals = vals * 2 + + df = DataFrame({"key": keys, "val": vals}) + if as_series: + result = getattr(df.groupby("key")["val"], fill_method)(limit=limit) + exp = Series(_exp_vals, name="val") + tm.assert_series_equal(result, exp) + else: + result = getattr(df.groupby("key"), fill_method)(limit=limit) + exp = DataFrame({"val": _exp_vals}) + tm.assert_frame_equal(result, exp) + + +@pytest.mark.parametrize("fill_method", ["ffill", "bfill"]) +def test_pad_stable_sorting(fill_method): + # GH 21207 + x = [0] * 20 + y = [np.nan] * 10 + [1] * 10 + + if fill_method == "bfill": + y = y[::-1] + + df = DataFrame({"x": x, "y": y}) + expected = df.drop("x", axis=1) + + result = getattr(df.groupby("x"), fill_method)() + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "freq", + [ + None, + pytest.param( + "D", + marks=pytest.mark.xfail( + reason="GH#23918 before method uses freq in vectorized approach" + ), + ), + ], +) +@pytest.mark.parametrize("periods", [1, -1]) +@pytest.mark.parametrize("fill_method", ["ffill", "bfill", None]) +@pytest.mark.parametrize("limit", [None, 1]) +def test_pct_change(frame_or_series, freq, periods, fill_method, limit): + # GH 21200, 21621, 30463 + vals = [3, np.nan, np.nan, np.nan, 1, 2, 4, 10, np.nan, 4] + keys = ["a", "b"] + key_v = np.repeat(keys, len(vals)) + df = DataFrame({"key": key_v, "vals": vals * 2}) + + df_g = df + if fill_method is not None: + df_g = getattr(df.groupby("key"), fill_method)(limit=limit) + grp = df_g.groupby(df.key) + + expected = grp["vals"].obj / grp["vals"].shift(periods) - 1 + + gb = df.groupby("key") + + if frame_or_series is Series: + gb = gb["vals"] + else: + expected = expected.to_frame("vals") + + msg = ( + "The 'fill_method' keyword being not None and the 'limit' keyword in " + f"{type(gb).__name__}.pct_change are deprecated" + ) + with tm.assert_produces_warning(FutureWarning, match=msg): + result = gb.pct_change( + periods=periods, fill_method=fill_method, limit=limit, freq=freq + ) + tm.assert_equal(result, expected) + + +@pytest.mark.parametrize( + "func, expected_status", + [ + ("ffill", ["shrt", "shrt", "lng", np.nan, "shrt", "ntrl", "ntrl"]), + ("bfill", ["shrt", "lng", "lng", "shrt", "shrt", "ntrl", np.nan]), + ], +) +def test_ffill_bfill_non_unique_multilevel(func, expected_status): + # GH 19437 + date = pd.to_datetime( + [ + "2018-01-01", + "2018-01-01", + "2018-01-01", + "2018-01-01", + "2018-01-02", + "2018-01-01", + "2018-01-02", + ] + ) + symbol = ["MSFT", "MSFT", "MSFT", "AAPL", "AAPL", "TSLA", "TSLA"] + status = ["shrt", np.nan, "lng", np.nan, "shrt", "ntrl", np.nan] + + df = DataFrame({"date": date, "symbol": symbol, "status": status}) + df = df.set_index(["date", "symbol"]) + result = getattr(df.groupby("symbol")["status"], func)() + + index = MultiIndex.from_tuples( + tuples=list(zip(*[date, symbol])), names=["date", "symbol"] + ) + expected = Series(expected_status, index=index, name="status") + + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("func", [np.any, np.all]) +def test_any_all_np_func(func): + # GH 20653 + df = DataFrame( + [["foo", True], [np.nan, True], ["foo", True]], columns=["key", "val"] + ) + + exp = Series([True, np.nan, True], name="val") + + msg = "using SeriesGroupBy.[any|all]" + with tm.assert_produces_warning(FutureWarning, match=msg): + res = df.groupby("key")["val"].transform(func) + tm.assert_series_equal(res, exp) + + +def test_groupby_transform_rename(): + # https://github.com/pandas-dev/pandas/issues/23461 + def demean_rename(x): + result = x - x.mean() + + if isinstance(x, Series): + return result + + result = result.rename(columns={c: f"{c}_demeaned" for c in result.columns}) + + return result + + df = DataFrame({"group": list("ababa"), "value": [1, 1, 1, 2, 2]}) + expected = DataFrame({"value": [-1.0 / 3, -0.5, -1.0 / 3, 0.5, 2.0 / 3]}) + + result = df.groupby("group").transform(demean_rename) + tm.assert_frame_equal(result, expected) + result_single = df.groupby("group").value.transform(demean_rename) + tm.assert_series_equal(result_single, expected["value"]) + + +@pytest.mark.parametrize("func", [min, max, np.min, np.max, "first", "last"]) +def test_groupby_transform_timezone_column(func): + # GH 24198 + ts = pd.to_datetime("now", utc=True).tz_convert("Asia/Singapore") + result = DataFrame({"end_time": [ts], "id": [1]}) + warn = FutureWarning if not isinstance(func, str) else None + msg = "using SeriesGroupBy.[min|max]" + with tm.assert_produces_warning(warn, match=msg): + result["max_end_time"] = result.groupby("id").end_time.transform(func) + expected = DataFrame([[ts, 1, ts]], columns=["end_time", "id", "max_end_time"]) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "func, values", + [ + ("idxmin", ["1/1/2011"] * 2 + ["1/3/2011"] * 7 + ["1/10/2011"]), + ("idxmax", ["1/2/2011"] * 2 + ["1/9/2011"] * 7 + ["1/10/2011"]), + ], +) +def test_groupby_transform_with_datetimes(func, values): + # GH 15306 + dates = date_range("1/1/2011", periods=10, freq="D") + + stocks = DataFrame({"price": np.arange(10.0)}, index=dates) + stocks["week_id"] = dates.isocalendar().week + + result = stocks.groupby(stocks["week_id"])["price"].transform(func) + + expected = Series( + data=pd.to_datetime(values).as_unit("ns"), index=dates, name="price" + ) + + tm.assert_series_equal(result, expected) + + +def test_groupby_transform_dtype(): + # GH 22243 + df = DataFrame({"a": [1], "val": [1.35]}) + + result = df["val"].transform(lambda x: x.map(lambda y: f"+{y}")) + expected1 = Series(["+1.35"], name="val", dtype="object") + tm.assert_series_equal(result, expected1) + + result = df.groupby("a")["val"].transform(lambda x: x.map(lambda y: f"+{y}")) + tm.assert_series_equal(result, expected1) + + result = df.groupby("a")["val"].transform(lambda x: x.map(lambda y: f"+({y})")) + expected2 = Series(["+(1.35)"], name="val", dtype="object") + tm.assert_series_equal(result, expected2) + + df["val"] = df["val"].astype(object) + result = df.groupby("a")["val"].transform(lambda x: x.map(lambda y: f"+{y}")) + tm.assert_series_equal(result, expected1) + + +@pytest.mark.parametrize("func", ["cumsum", "cumprod", "cummin", "cummax"]) +def test_transform_absent_categories(func): + # GH 16771 + # cython transforms with more groups than rows + x_vals = [1] + x_cats = range(2) + y = [1] + df = DataFrame({"x": Categorical(x_vals, x_cats), "y": y}) + result = getattr(df.y.groupby(df.x, observed=False), func)() + expected = df.y + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("func", ["ffill", "bfill", "shift"]) +@pytest.mark.parametrize("key, val", [("level", 0), ("by", Series([0]))]) +def test_ffill_not_in_axis(func, key, val): + # GH 21521 + df = DataFrame([[np.nan]]) + result = getattr(df.groupby(**{key: val}), func)() + expected = df + + tm.assert_frame_equal(result, expected) + + +def test_transform_invalid_name_raises(): + # GH#27486 + df = DataFrame({"a": [0, 1, 1, 2]}) + g = df.groupby(["a", "b", "b", "c"]) + with pytest.raises(ValueError, match="not a valid function name"): + g.transform("some_arbitrary_name") + + # method exists on the object, but is not a valid transformation/agg + assert hasattr(g, "aggregate") # make sure the method exists + with pytest.raises(ValueError, match="not a valid function name"): + g.transform("aggregate") + + # Test SeriesGroupBy + g = df["a"].groupby(["a", "b", "b", "c"]) + with pytest.raises(ValueError, match="not a valid function name"): + g.transform("some_arbitrary_name") + + +def test_transform_agg_by_name(request, reduction_func, frame_or_series): + func = reduction_func + + obj = DataFrame( + {"a": [0, 0, 0, 1, 1, 1], "b": range(6)}, + index=["A", "B", "C", "D", "E", "F"], + ) + if frame_or_series is Series: + obj = obj["a"] + + g = obj.groupby(np.repeat([0, 1], 3)) + + if func == "corrwith" and isinstance(obj, Series): # GH#32293 + # TODO: implement SeriesGroupBy.corrwith + assert not hasattr(g, func) + return + + args = get_groupby_method_args(reduction_func, obj) + result = g.transform(func, *args) + + # this is the *definition* of a transformation + tm.assert_index_equal(result.index, obj.index) + + if func not in ("ngroup", "size") and obj.ndim == 2: + # size/ngroup return a Series, unlike other transforms + tm.assert_index_equal(result.columns, obj.columns) + + # verify that values were broadcasted across each group + assert len(set(DataFrame(result).iloc[-3:, -1])) == 1 + + +def test_transform_lambda_with_datetimetz(): + # GH 27496 + df = DataFrame( + { + "time": [ + Timestamp("2010-07-15 03:14:45"), + Timestamp("2010-11-19 18:47:06"), + ], + "timezone": ["Etc/GMT+4", "US/Eastern"], + } + ) + result = df.groupby(["timezone"])["time"].transform( + lambda x: x.dt.tz_localize(x.name) + ) + expected = Series( + [ + Timestamp("2010-07-15 03:14:45", tz="Etc/GMT+4"), + Timestamp("2010-11-19 18:47:06", tz="US/Eastern"), + ], + name="time", + ) + tm.assert_series_equal(result, expected) + + +def test_transform_fastpath_raises(): + # GH#29631 case where fastpath defined in groupby.generic _choose_path + # raises, but slow_path does not + + df = DataFrame({"A": [1, 1, 2, 2], "B": [1, -1, 1, 2]}) + gb = df.groupby("A") + + def func(grp): + # we want a function such that func(frame) fails but func.apply(frame) + # works + if grp.ndim == 2: + # Ensure that fast_path fails + raise NotImplementedError("Don't cross the streams") + return grp * 2 + + # Check that the fastpath raises, see _transform_general + obj = gb._obj_with_exclusions + gen = gb._grouper.get_iterator(obj, axis=gb.axis) + fast_path, slow_path = gb._define_paths(func) + _, group = next(gen) + + with pytest.raises(NotImplementedError, match="Don't cross the streams"): + fast_path(group) + + result = gb.transform(func) + + expected = DataFrame([2, -2, 2, 4], columns=["B"]) + tm.assert_frame_equal(result, expected) + + +def test_transform_lambda_indexing(): + # GH 7883 + df = DataFrame( + { + "A": ["foo", "bar", "foo", "bar", "foo", "flux", "foo", "flux"], + "B": ["one", "one", "two", "three", "two", "six", "five", "three"], + "C": range(8), + "D": range(8), + "E": range(8), + } + ) + df = df.set_index(["A", "B"]) + df = df.sort_index() + result = df.groupby(level="A").transform(lambda x: x.iloc[-1]) + expected = DataFrame( + { + "C": [3, 3, 7, 7, 4, 4, 4, 4], + "D": [3, 3, 7, 7, 4, 4, 4, 4], + "E": [3, 3, 7, 7, 4, 4, 4, 4], + }, + index=MultiIndex.from_tuples( + [ + ("bar", "one"), + ("bar", "three"), + ("flux", "six"), + ("flux", "three"), + ("foo", "five"), + ("foo", "one"), + ("foo", "two"), + ("foo", "two"), + ], + names=["A", "B"], + ), + ) + tm.assert_frame_equal(result, expected) + + +def test_categorical_and_not_categorical_key(observed): + # Checks that groupby-transform, when grouping by both a categorical + # and a non-categorical key, doesn't try to expand the output to include + # non-observed categories but instead matches the input shape. + # GH 32494 + df_with_categorical = DataFrame( + { + "A": Categorical(["a", "b", "a"], categories=["a", "b", "c"]), + "B": [1, 2, 3], + "C": ["a", "b", "a"], + } + ) + df_without_categorical = DataFrame( + {"A": ["a", "b", "a"], "B": [1, 2, 3], "C": ["a", "b", "a"]} + ) + + # DataFrame case + result = df_with_categorical.groupby(["A", "C"], observed=observed).transform("sum") + expected = df_without_categorical.groupby(["A", "C"]).transform("sum") + tm.assert_frame_equal(result, expected) + expected_explicit = DataFrame({"B": [4, 2, 4]}) + tm.assert_frame_equal(result, expected_explicit) + + # Series case + result = df_with_categorical.groupby(["A", "C"], observed=observed)["B"].transform( + "sum" + ) + expected = df_without_categorical.groupby(["A", "C"])["B"].transform("sum") + tm.assert_series_equal(result, expected) + expected_explicit = Series([4, 2, 4], name="B") + tm.assert_series_equal(result, expected_explicit) + + +def test_string_rank_grouping(): + # GH 19354 + df = DataFrame({"A": [1, 1, 2], "B": [1, 2, 3]}) + result = df.groupby("A").transform("rank") + expected = DataFrame({"B": [1.0, 2.0, 1.0]}) + tm.assert_frame_equal(result, expected) + + +def test_transform_cumcount(): + # GH 27472 + df = DataFrame({"a": [0, 0, 0, 1, 1, 1], "b": range(6)}) + grp = df.groupby(np.repeat([0, 1], 3)) + + result = grp.cumcount() + expected = Series([0, 1, 2, 0, 1, 2]) + tm.assert_series_equal(result, expected) + + result = grp.transform("cumcount") + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("keys", [["A1"], ["A1", "A2"]]) +def test_null_group_lambda_self(sort, dropna, keys): + # GH 17093 + size = 50 + nulls1 = np.random.default_rng(2).choice([False, True], size) + nulls2 = np.random.default_rng(2).choice([False, True], size) + # Whether a group contains a null value or not + nulls_grouper = nulls1 if len(keys) == 1 else nulls1 | nulls2 + + a1 = np.random.default_rng(2).integers(0, 5, size=size).astype(float) + a1[nulls1] = np.nan + a2 = np.random.default_rng(2).integers(0, 5, size=size).astype(float) + a2[nulls2] = np.nan + values = np.random.default_rng(2).integers(0, 5, size=a1.shape) + df = DataFrame({"A1": a1, "A2": a2, "B": values}) + + expected_values = values + if dropna and nulls_grouper.any(): + expected_values = expected_values.astype(float) + expected_values[nulls_grouper] = np.nan + expected = DataFrame(expected_values, columns=["B"]) + + gb = df.groupby(keys, dropna=dropna, sort=sort) + result = gb[["B"]].transform(lambda x: x) + tm.assert_frame_equal(result, expected) + + +def test_null_group_str_reducer(request, dropna, reduction_func): + # GH 17093 + if reduction_func == "corrwith": + msg = "incorrectly raises" + request.applymarker(pytest.mark.xfail(reason=msg)) + + index = [1, 2, 3, 4] # test transform preserves non-standard index + df = DataFrame({"A": [1, 1, np.nan, np.nan], "B": [1, 2, 2, 3]}, index=index) + gb = df.groupby("A", dropna=dropna) + + args = get_groupby_method_args(reduction_func, df) + + # Manually handle reducers that don't fit the generic pattern + # Set expected with dropna=False, then replace if necessary + if reduction_func == "first": + expected = DataFrame({"B": [1, 1, 2, 2]}, index=index) + elif reduction_func == "last": + expected = DataFrame({"B": [2, 2, 3, 3]}, index=index) + elif reduction_func == "nth": + expected = DataFrame({"B": [1, 1, 2, 2]}, index=index) + elif reduction_func == "size": + expected = Series([2, 2, 2, 2], index=index) + elif reduction_func == "corrwith": + expected = DataFrame({"B": [1.0, 1.0, 1.0, 1.0]}, index=index) + else: + expected_gb = df.groupby("A", dropna=False) + buffer = [] + for idx, group in expected_gb: + res = getattr(group["B"], reduction_func)() + buffer.append(Series(res, index=group.index)) + expected = concat(buffer).to_frame("B") + if dropna: + dtype = object if reduction_func in ("any", "all") else float + expected = expected.astype(dtype) + if expected.ndim == 2: + expected.iloc[[2, 3], 0] = np.nan + else: + expected.iloc[[2, 3]] = np.nan + + result = gb.transform(reduction_func, *args) + tm.assert_equal(result, expected) + + +def test_null_group_str_transformer(request, dropna, transformation_func): + # GH 17093 + df = DataFrame({"A": [1, 1, np.nan], "B": [1, 2, 2]}, index=[1, 2, 3]) + args = get_groupby_method_args(transformation_func, df) + gb = df.groupby("A", dropna=dropna) + + buffer = [] + for k, (idx, group) in enumerate(gb): + if transformation_func == "cumcount": + # DataFrame has no cumcount method + res = DataFrame({"B": range(len(group))}, index=group.index) + elif transformation_func == "ngroup": + res = DataFrame(len(group) * [k], index=group.index, columns=["B"]) + else: + res = getattr(group[["B"]], transformation_func)(*args) + buffer.append(res) + if dropna: + dtype = object if transformation_func in ("any", "all") else None + buffer.append(DataFrame([[np.nan]], index=[3], dtype=dtype, columns=["B"])) + expected = concat(buffer) + + if transformation_func in ("cumcount", "ngroup"): + # ngroup/cumcount always returns a Series as it counts the groups, not values + expected = expected["B"].rename(None) + + if transformation_func == "pct_change" and not dropna: + warn = FutureWarning + msg = ( + "The default fill_method='ffill' in DataFrameGroupBy.pct_change " + "is deprecated" + ) + elif transformation_func == "fillna": + warn = FutureWarning + msg = "DataFrameGroupBy.fillna is deprecated" + else: + warn = None + msg = "" + with tm.assert_produces_warning(warn, match=msg): + result = gb.transform(transformation_func, *args) + + tm.assert_equal(result, expected) + + +def test_null_group_str_reducer_series(request, dropna, reduction_func): + # GH 17093 + index = [1, 2, 3, 4] # test transform preserves non-standard index + ser = Series([1, 2, 2, 3], index=index) + gb = ser.groupby([1, 1, np.nan, np.nan], dropna=dropna) + + if reduction_func == "corrwith": + # corrwith not implemented for SeriesGroupBy + assert not hasattr(gb, reduction_func) + return + + args = get_groupby_method_args(reduction_func, ser) + + # Manually handle reducers that don't fit the generic pattern + # Set expected with dropna=False, then replace if necessary + if reduction_func == "first": + expected = Series([1, 1, 2, 2], index=index) + elif reduction_func == "last": + expected = Series([2, 2, 3, 3], index=index) + elif reduction_func == "nth": + expected = Series([1, 1, 2, 2], index=index) + elif reduction_func == "size": + expected = Series([2, 2, 2, 2], index=index) + elif reduction_func == "corrwith": + expected = Series([1, 1, 2, 2], index=index) + else: + expected_gb = ser.groupby([1, 1, np.nan, np.nan], dropna=False) + buffer = [] + for idx, group in expected_gb: + res = getattr(group, reduction_func)() + buffer.append(Series(res, index=group.index)) + expected = concat(buffer) + if dropna: + dtype = object if reduction_func in ("any", "all") else float + expected = expected.astype(dtype) + expected.iloc[[2, 3]] = np.nan + + result = gb.transform(reduction_func, *args) + tm.assert_series_equal(result, expected) + + +def test_null_group_str_transformer_series(dropna, transformation_func): + # GH 17093 + ser = Series([1, 2, 2], index=[1, 2, 3]) + args = get_groupby_method_args(transformation_func, ser) + gb = ser.groupby([1, 1, np.nan], dropna=dropna) + + buffer = [] + for k, (idx, group) in enumerate(gb): + if transformation_func == "cumcount": + # Series has no cumcount method + res = Series(range(len(group)), index=group.index) + elif transformation_func == "ngroup": + res = Series(k, index=group.index) + else: + res = getattr(group, transformation_func)(*args) + buffer.append(res) + if dropna: + dtype = object if transformation_func in ("any", "all") else None + buffer.append(Series([np.nan], index=[3], dtype=dtype)) + expected = concat(buffer) + + warn = FutureWarning if transformation_func == "fillna" else None + msg = "SeriesGroupBy.fillna is deprecated" + with tm.assert_produces_warning(warn, match=msg): + result = gb.transform(transformation_func, *args) + + tm.assert_equal(result, expected) + + +@pytest.mark.parametrize( + "func, expected_values", + [ + (Series.sort_values, [5, 4, 3, 2, 1]), + (lambda x: x.head(1), [5.0, np.nan, 3, 2, np.nan]), + ], +) +@pytest.mark.parametrize("keys", [["a1"], ["a1", "a2"]]) +@pytest.mark.parametrize("keys_in_index", [True, False]) +def test_transform_aligns(func, frame_or_series, expected_values, keys, keys_in_index): + # GH#45648 - transform should align with the input's index + df = DataFrame({"a1": [1, 1, 3, 2, 2], "b": [5, 4, 3, 2, 1]}) + if "a2" in keys: + df["a2"] = df["a1"] + if keys_in_index: + df = df.set_index(keys, append=True) + + gb = df.groupby(keys) + if frame_or_series is Series: + gb = gb["b"] + + result = gb.transform(func) + expected = DataFrame({"b": expected_values}, index=df.index) + if frame_or_series is Series: + expected = expected["b"] + tm.assert_equal(result, expected) + + +@pytest.mark.parametrize("keys", ["A", ["A", "B"]]) +def test_as_index_no_change(keys, df, groupby_func): + # GH#49834 - as_index should have no impact on DataFrameGroupBy.transform + if keys == "A": + # Column B is string dtype; will fail on some ops + df = df.drop(columns="B") + args = get_groupby_method_args(groupby_func, df) + gb_as_index_true = df.groupby(keys, as_index=True) + gb_as_index_false = df.groupby(keys, as_index=False) + warn = FutureWarning if groupby_func == "fillna" else None + msg = "DataFrameGroupBy.fillna is deprecated" + with tm.assert_produces_warning(warn, match=msg): + result = gb_as_index_true.transform(groupby_func, *args) + with tm.assert_produces_warning(warn, match=msg): + expected = gb_as_index_false.transform(groupby_func, *args) + tm.assert_equal(result, expected) + + +@pytest.mark.parametrize("how", ["idxmax", "idxmin"]) +@pytest.mark.parametrize("numeric_only", [True, False]) +def test_idxmin_idxmax_transform_args(how, skipna, numeric_only): + # GH#55268 - ensure *args are passed through when calling transform + df = DataFrame({"a": [1, 1, 1, 2], "b": [3.0, 4.0, np.nan, 6.0], "c": list("abcd")}) + gb = df.groupby("a") + msg = f"'axis' keyword in DataFrameGroupBy.{how} is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = gb.transform(how, 0, skipna, numeric_only) + warn = None if skipna else FutureWarning + msg = f"The behavior of DataFrameGroupBy.{how} with .* any-NA and skipna=False" + with tm.assert_produces_warning(warn, match=msg): + expected = gb.transform(how, skipna=skipna, numeric_only=numeric_only) + tm.assert_frame_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20eb7dd049236e98d93cb8d48b7538bd4ae66c5d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/__pycache__/conftest.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/__pycache__/conftest.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f983342c9c394a08e91ca7a10267d2d2813ed6f7 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/__pycache__/conftest.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/__pycache__/test_any_index.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/__pycache__/test_any_index.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eadade751a3d9d2b3112906ae190be06ed6a0f65 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/__pycache__/test_any_index.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/__pycache__/test_base.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/__pycache__/test_base.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c407c3f0b8189f332417112514cfb580f78d11f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/__pycache__/test_base.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/__pycache__/test_common.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/__pycache__/test_common.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27e0e4f12ef36588c7abbb64507e4ef4fe03f376 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/__pycache__/test_common.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/__pycache__/test_datetimelike.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/__pycache__/test_datetimelike.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9badd85425f3026144d05e9499133f88c82ad3d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/__pycache__/test_datetimelike.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/__pycache__/test_engines.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/__pycache__/test_engines.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be186af15e54d6380df048e35ff132530a6736e9 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/__pycache__/test_engines.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/__pycache__/test_frozen.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/__pycache__/test_frozen.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b039038a96dea85d6f60e5fa02ffb247e42c8cc Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/__pycache__/test_frozen.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/__pycache__/test_index_new.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/__pycache__/test_index_new.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..edde1a4b4c55bad2eebead56397f90cc9a902d60 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/__pycache__/test_index_new.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/__pycache__/test_indexing.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/__pycache__/test_indexing.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44c20d852eae84c301b62e787c2f70adf5afbb19 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/__pycache__/test_indexing.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/__pycache__/test_numpy_compat.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/__pycache__/test_numpy_compat.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f13e500f54b3c837a87ac053235411f1a98c0044 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/__pycache__/test_numpy_compat.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/__pycache__/test_old_base.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/__pycache__/test_old_base.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..584f2bc70c00bd63a02ba7e640688629e77e768c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/__pycache__/test_old_base.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/__pycache__/test_setops.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/__pycache__/test_setops.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c893e0ef35621260ca7c03282cb8103e2a9ee14a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/__pycache__/test_setops.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/__pycache__/test_subclass.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/__pycache__/test_subclass.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ee879b4a74b061f37376d9abbe7f02e35242acb Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/__pycache__/test_subclass.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/base_class/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/base_class/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/base_class/test_constructors.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/base_class/test_constructors.py new file mode 100644 index 0000000000000000000000000000000000000000..338509dd239e63f7fde17fe377a697429f1405e2 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/base_class/test_constructors.py @@ -0,0 +1,80 @@ +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + Index, + MultiIndex, + Series, +) +import pandas._testing as tm + + +class TestIndexConstructor: + # Tests for the Index constructor, specifically for cases that do + # not return a subclass + + @pytest.mark.parametrize("value", [1, np.int64(1)]) + def test_constructor_corner(self, value): + # corner case + msg = ( + r"Index\(\.\.\.\) must be called with a collection of some " + f"kind, {value} was passed" + ) + with pytest.raises(TypeError, match=msg): + Index(value) + + @pytest.mark.parametrize("index_vals", [[("A", 1), "B"], ["B", ("A", 1)]]) + def test_construction_list_mixed_tuples(self, index_vals): + # see gh-10697: if we are constructing from a mixed list of tuples, + # make sure that we are independent of the sorting order. + index = Index(index_vals) + assert isinstance(index, Index) + assert not isinstance(index, MultiIndex) + + def test_constructor_cast(self): + msg = "could not convert string to float" + with pytest.raises(ValueError, match=msg): + Index(["a", "b", "c"], dtype=float) + + @pytest.mark.parametrize("tuple_list", [[()], [(), ()]]) + def test_construct_empty_tuples(self, tuple_list): + # GH #45608 + result = Index(tuple_list) + expected = MultiIndex.from_tuples(tuple_list) + + tm.assert_index_equal(result, expected) + + def test_index_string_inference(self): + # GH#54430 + pytest.importorskip("pyarrow") + dtype = "string[pyarrow_numpy]" + expected = Index(["a", "b"], dtype=dtype) + with pd.option_context("future.infer_string", True): + ser = Index(["a", "b"]) + tm.assert_index_equal(ser, expected) + + expected = Index(["a", 1], dtype="object") + with pd.option_context("future.infer_string", True): + ser = Index(["a", 1]) + tm.assert_index_equal(ser, expected) + + def test_inference_on_pandas_objects(self): + # GH#56012 + idx = Index([pd.Timestamp("2019-12-31")], dtype=object) + with tm.assert_produces_warning(FutureWarning, match="Dtype inference"): + result = Index(idx) + assert result.dtype != np.object_ + + ser = Series([pd.Timestamp("2019-12-31")], dtype=object) + + with tm.assert_produces_warning(FutureWarning, match="Dtype inference"): + result = Index(ser) + assert result.dtype != np.object_ + + def test_constructor_not_read_only(self): + # GH#57130 + ser = Series([1, 2], dtype=object) + with pd.option_context("mode.copy_on_write", True): + idx = Index(ser) + assert idx._values.flags.writeable diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/base_class/test_formats.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/base_class/test_formats.py new file mode 100644 index 0000000000000000000000000000000000000000..f30b578cfcf566ac577d15a1bc15880bea5d7af0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/base_class/test_formats.py @@ -0,0 +1,163 @@ +import numpy as np +import pytest + +from pandas._config import using_pyarrow_string_dtype +import pandas._config.config as cf + +from pandas import Index +import pandas._testing as tm + + +class TestIndexRendering: + def test_repr_is_valid_construction_code(self): + # for the case of Index, where the repr is traditional rather than + # stylized + idx = Index(["a", "b"]) + res = eval(repr(idx)) + tm.assert_index_equal(res, idx) + + @pytest.mark.xfail(using_pyarrow_string_dtype(), reason="repr different") + @pytest.mark.parametrize( + "index,expected", + [ + # ASCII + # short + ( + Index(["a", "bb", "ccc"]), + """Index(['a', 'bb', 'ccc'], dtype='object')""", + ), + # multiple lines + ( + Index(["a", "bb", "ccc"] * 10), + "Index(['a', 'bb', 'ccc', 'a', 'bb', 'ccc', 'a', " + "'bb', 'ccc', 'a', 'bb', 'ccc',\n" + " 'a', 'bb', 'ccc', 'a', 'bb', 'ccc', 'a', " + "'bb', 'ccc', 'a', 'bb', 'ccc',\n" + " 'a', 'bb', 'ccc', 'a', 'bb', 'ccc'],\n" + " dtype='object')", + ), + # truncated + ( + Index(["a", "bb", "ccc"] * 100), + "Index(['a', 'bb', 'ccc', 'a', 'bb', 'ccc', 'a', 'bb', 'ccc', 'a',\n" + " ...\n" + " 'ccc', 'a', 'bb', 'ccc', 'a', 'bb', 'ccc', 'a', 'bb', 'ccc'],\n" + " dtype='object', length=300)", + ), + # Non-ASCII + # short + ( + Index(["あ", "いい", "ううう"]), + """Index(['あ', 'いい', 'ううう'], dtype='object')""", + ), + # multiple lines + ( + Index(["あ", "いい", "ううう"] * 10), + ( + "Index(['あ', 'いい', 'ううう', 'あ', 'いい', 'ううう', " + "'あ', 'いい', 'ううう', 'あ', 'いい', 'ううう',\n" + " 'あ', 'いい', 'ううう', 'あ', 'いい', 'ううう', " + "'あ', 'いい', 'ううう', 'あ', 'いい', 'ううう',\n" + " 'あ', 'いい', 'ううう', 'あ', 'いい', " + "'ううう'],\n" + " dtype='object')" + ), + ), + # truncated + ( + Index(["あ", "いい", "ううう"] * 100), + ( + "Index(['あ', 'いい', 'ううう', 'あ', 'いい', 'ううう', " + "'あ', 'いい', 'ううう', 'あ',\n" + " ...\n" + " 'ううう', 'あ', 'いい', 'ううう', 'あ', 'いい', " + "'ううう', 'あ', 'いい', 'ううう'],\n" + " dtype='object', length=300)" + ), + ), + ], + ) + def test_string_index_repr(self, index, expected): + result = repr(index) + assert result == expected + + @pytest.mark.xfail(using_pyarrow_string_dtype(), reason="repr different") + @pytest.mark.parametrize( + "index,expected", + [ + # short + ( + Index(["あ", "いい", "ううう"]), + ("Index(['あ', 'いい', 'ううう'], dtype='object')"), + ), + # multiple lines + ( + Index(["あ", "いい", "ううう"] * 10), + ( + "Index(['あ', 'いい', 'ううう', 'あ', 'いい', " + "'ううう', 'あ', 'いい', 'ううう',\n" + " 'あ', 'いい', 'ううう', 'あ', 'いい', " + "'ううう', 'あ', 'いい', 'ううう',\n" + " 'あ', 'いい', 'ううう', 'あ', 'いい', " + "'ううう', 'あ', 'いい', 'ううう',\n" + " 'あ', 'いい', 'ううう'],\n" + " dtype='object')" + "" + ), + ), + # truncated + ( + Index(["あ", "いい", "ううう"] * 100), + ( + "Index(['あ', 'いい', 'ううう', 'あ', 'いい', " + "'ううう', 'あ', 'いい', 'ううう',\n" + " 'あ',\n" + " ...\n" + " 'ううう', 'あ', 'いい', 'ううう', 'あ', " + "'いい', 'ううう', 'あ', 'いい',\n" + " 'ううう'],\n" + " dtype='object', length=300)" + ), + ), + ], + ) + def test_string_index_repr_with_unicode_option(self, index, expected): + # Enable Unicode option ----------------------------------------- + with cf.option_context("display.unicode.east_asian_width", True): + result = repr(index) + assert result == expected + + def test_repr_summary(self): + with cf.option_context("display.max_seq_items", 10): + result = repr(Index(np.arange(1000))) + assert len(result) < 200 + assert "..." in result + + def test_summary_bug(self): + # GH#3869 + ind = Index(["{other}%s", "~:{range}:0"], name="A") + result = ind._summary() + # shouldn't be formatted accidentally. + assert "~:{range}:0" in result + assert "{other}%s" in result + + def test_index_repr_bool_nan(self): + # GH32146 + arr = Index([True, False, np.nan], dtype=object) + msg = "Index.format is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + exp1 = arr.format() + out1 = ["True", "False", "NaN"] + assert out1 == exp1 + + exp2 = repr(arr) + out2 = "Index([True, False, nan], dtype='object')" + assert out2 == exp2 + + def test_format_different_scalar_lengths(self): + # GH#35439 + idx = Index(["aaaaaaaaa", "b"]) + expected = ["aaaaaaaaa", "b"] + msg = r"Index\.format is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + assert idx.format() == expected diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/base_class/test_indexing.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/base_class/test_indexing.py new file mode 100644 index 0000000000000000000000000000000000000000..2988fa7d1baa1e0bc0f6cc4b6dc32e5d12f332cf --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/base_class/test_indexing.py @@ -0,0 +1,104 @@ +import numpy as np +import pytest + +from pandas._libs import index as libindex + +import pandas as pd +from pandas import ( + Index, + NaT, +) +import pandas._testing as tm + + +class TestGetSliceBounds: + @pytest.mark.parametrize("side, expected", [("left", 4), ("right", 5)]) + def test_get_slice_bounds_within(self, side, expected): + index = Index(list("abcdef")) + result = index.get_slice_bound("e", side=side) + assert result == expected + + @pytest.mark.parametrize("side", ["left", "right"]) + @pytest.mark.parametrize( + "data, bound, expected", [(list("abcdef"), "x", 6), (list("bcdefg"), "a", 0)] + ) + def test_get_slice_bounds_outside(self, side, expected, data, bound): + index = Index(data) + result = index.get_slice_bound(bound, side=side) + assert result == expected + + def test_get_slice_bounds_invalid_side(self): + with pytest.raises(ValueError, match="Invalid value for side kwarg"): + Index([]).get_slice_bound("a", side="middle") + + +class TestGetIndexerNonUnique: + def test_get_indexer_non_unique_dtype_mismatch(self): + # GH#25459 + indexes, missing = Index(["A", "B"]).get_indexer_non_unique(Index([0])) + tm.assert_numpy_array_equal(np.array([-1], dtype=np.intp), indexes) + tm.assert_numpy_array_equal(np.array([0], dtype=np.intp), missing) + + @pytest.mark.parametrize( + "idx_values,idx_non_unique", + [ + ([np.nan, 100, 200, 100], [np.nan, 100]), + ([np.nan, 100.0, 200.0, 100.0], [np.nan, 100.0]), + ], + ) + def test_get_indexer_non_unique_int_index(self, idx_values, idx_non_unique): + indexes, missing = Index(idx_values).get_indexer_non_unique(Index([np.nan])) + tm.assert_numpy_array_equal(np.array([0], dtype=np.intp), indexes) + tm.assert_numpy_array_equal(np.array([], dtype=np.intp), missing) + + indexes, missing = Index(idx_values).get_indexer_non_unique( + Index(idx_non_unique) + ) + tm.assert_numpy_array_equal(np.array([0, 1, 3], dtype=np.intp), indexes) + tm.assert_numpy_array_equal(np.array([], dtype=np.intp), missing) + + +class TestGetLoc: + @pytest.mark.slow # to_flat_index takes a while + def test_get_loc_tuple_monotonic_above_size_cutoff(self, monkeypatch): + # Go through the libindex path for which using + # _bin_search vs ndarray.searchsorted makes a difference + + with monkeypatch.context(): + monkeypatch.setattr(libindex, "_SIZE_CUTOFF", 100) + lev = list("ABCD") + dti = pd.date_range("2016-01-01", periods=10) + + mi = pd.MultiIndex.from_product([lev, range(5), dti]) + oidx = mi.to_flat_index() + + loc = len(oidx) // 2 + tup = oidx[loc] + + res = oidx.get_loc(tup) + assert res == loc + + def test_get_loc_nan_object_dtype_nonmonotonic_nonunique(self): + # case that goes through _maybe_get_bool_indexer + idx = Index(["foo", np.nan, None, "foo", 1.0, None], dtype=object) + + # we dont raise KeyError on nan + res = idx.get_loc(np.nan) + assert res == 1 + + # we only match on None, not on np.nan + res = idx.get_loc(None) + expected = np.array([False, False, True, False, False, True]) + tm.assert_numpy_array_equal(res, expected) + + # we don't match at all on mismatched NA + with pytest.raises(KeyError, match="NaT"): + idx.get_loc(NaT) + + +def test_getitem_boolean_ea_indexer(): + # GH#45806 + ser = pd.Series([True, False, pd.NA], dtype="boolean") + result = ser.index[ser] + expected = Index([0]) + tm.assert_index_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/base_class/test_pickle.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/base_class/test_pickle.py new file mode 100644 index 0000000000000000000000000000000000000000..c670921decb78808fa54a35c45e3d2d15ab57a67 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/base_class/test_pickle.py @@ -0,0 +1,11 @@ +from pandas import Index +import pandas._testing as tm + + +def test_pickle_preserves_object_dtype(): + # GH#43188, GH#43155 don't infer numeric dtype + index = Index([1, 2, 3], dtype=object) + + result = tm.round_trip_pickle(index) + assert result.dtype == object + tm.assert_index_equal(index, result) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/base_class/test_reshape.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/base_class/test_reshape.py new file mode 100644 index 0000000000000000000000000000000000000000..814a6a516904b6c31d3ef38fa91f23f69fa6ed7e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/base_class/test_reshape.py @@ -0,0 +1,95 @@ +""" +Tests for ndarray-like method on the base Index class +""" +import numpy as np +import pytest + +from pandas import Index +import pandas._testing as tm + + +class TestReshape: + def test_repeat(self): + repeats = 2 + index = Index([1, 2, 3]) + expected = Index([1, 1, 2, 2, 3, 3]) + + result = index.repeat(repeats) + tm.assert_index_equal(result, expected) + + def test_insert(self): + # GH 7256 + # validate neg/pos inserts + result = Index(["b", "c", "d"]) + + # test 0th element + tm.assert_index_equal(Index(["a", "b", "c", "d"]), result.insert(0, "a")) + + # test Nth element that follows Python list behavior + tm.assert_index_equal(Index(["b", "c", "e", "d"]), result.insert(-1, "e")) + + # test loc +/- neq (0, -1) + tm.assert_index_equal(result.insert(1, "z"), result.insert(-2, "z")) + + # test empty + null_index = Index([]) + tm.assert_index_equal(Index(["a"], dtype=object), null_index.insert(0, "a")) + + def test_insert_missing(self, nulls_fixture, using_infer_string): + # GH#22295 + # test there is no mangling of NA values + expected = Index(["a", nulls_fixture, "b", "c"], dtype=object) + result = Index(list("abc"), dtype=object).insert( + 1, Index([nulls_fixture], dtype=object) + ) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "val", [(1, 2), np.datetime64("2019-12-31"), np.timedelta64(1, "D")] + ) + @pytest.mark.parametrize("loc", [-1, 2]) + def test_insert_datetime_into_object(self, loc, val): + # GH#44509 + idx = Index(["1", "2", "3"]) + result = idx.insert(loc, val) + expected = Index(["1", "2", val, "3"]) + tm.assert_index_equal(result, expected) + assert type(expected[2]) is type(val) + + def test_insert_none_into_string_numpy(self): + # GH#55365 + pytest.importorskip("pyarrow") + index = Index(["a", "b", "c"], dtype="string[pyarrow_numpy]") + result = index.insert(-1, None) + expected = Index(["a", "b", None, "c"], dtype="string[pyarrow_numpy]") + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "pos,expected", + [ + (0, Index(["b", "c", "d"], name="index")), + (-1, Index(["a", "b", "c"], name="index")), + ], + ) + def test_delete(self, pos, expected): + index = Index(["a", "b", "c", "d"], name="index") + result = index.delete(pos) + tm.assert_index_equal(result, expected) + assert result.name == expected.name + + def test_delete_raises(self): + index = Index(["a", "b", "c", "d"], name="index") + msg = "index 5 is out of bounds for axis 0 with size 4" + with pytest.raises(IndexError, match=msg): + index.delete(5) + + def test_append_multiple(self): + index = Index(["a", "b", "c", "d", "e", "f"]) + + foos = [index[:2], index[2:4], index[4:]] + result = foos[0].append(foos[1:]) + tm.assert_index_equal(result, index) + + # empty + result = index.append([]) + tm.assert_index_equal(result, index) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/base_class/test_setops.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/base_class/test_setops.py new file mode 100644 index 0000000000000000000000000000000000000000..3ef3f3ad4d3a20bd2e6303d781590396cbc00ae0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/base_class/test_setops.py @@ -0,0 +1,266 @@ +from datetime import datetime + +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + Index, + Series, +) +import pandas._testing as tm +from pandas.core.algorithms import safe_sort + + +def equal_contents(arr1, arr2) -> bool: + """ + Checks if the set of unique elements of arr1 and arr2 are equivalent. + """ + return frozenset(arr1) == frozenset(arr2) + + +class TestIndexSetOps: + @pytest.mark.parametrize( + "method", ["union", "intersection", "difference", "symmetric_difference"] + ) + def test_setops_sort_validation(self, method): + idx1 = Index(["a", "b"]) + idx2 = Index(["b", "c"]) + + with pytest.raises(ValueError, match="The 'sort' keyword only takes"): + getattr(idx1, method)(idx2, sort=2) + + # sort=True is supported as of GH#?? + getattr(idx1, method)(idx2, sort=True) + + def test_setops_preserve_object_dtype(self): + idx = Index([1, 2, 3], dtype=object) + result = idx.intersection(idx[1:]) + expected = idx[1:] + tm.assert_index_equal(result, expected) + + # if other is not monotonic increasing, intersection goes through + # a different route + result = idx.intersection(idx[1:][::-1]) + tm.assert_index_equal(result, expected) + + result = idx._union(idx[1:], sort=None) + expected = idx + tm.assert_numpy_array_equal(result, expected.values) + + result = idx.union(idx[1:], sort=None) + tm.assert_index_equal(result, expected) + + # if other is not monotonic increasing, _union goes through + # a different route + result = idx._union(idx[1:][::-1], sort=None) + tm.assert_numpy_array_equal(result, expected.values) + + result = idx.union(idx[1:][::-1], sort=None) + tm.assert_index_equal(result, expected) + + def test_union_base(self): + index = Index([0, "a", 1, "b", 2, "c"]) + first = index[3:] + second = index[:5] + + result = first.union(second) + + expected = Index([0, 1, 2, "a", "b", "c"]) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize("klass", [np.array, Series, list]) + def test_union_different_type_base(self, klass): + # GH 10149 + index = Index([0, "a", 1, "b", 2, "c"]) + first = index[3:] + second = index[:5] + + result = first.union(klass(second.values)) + + assert equal_contents(result, index) + + def test_union_sort_other_incomparable(self): + # https://github.com/pandas-dev/pandas/issues/24959 + idx = Index([1, pd.Timestamp("2000")]) + # default (sort=None) + with tm.assert_produces_warning(RuntimeWarning): + result = idx.union(idx[:1]) + + tm.assert_index_equal(result, idx) + + # sort=None + with tm.assert_produces_warning(RuntimeWarning): + result = idx.union(idx[:1], sort=None) + tm.assert_index_equal(result, idx) + + # sort=False + result = idx.union(idx[:1], sort=False) + tm.assert_index_equal(result, idx) + + def test_union_sort_other_incomparable_true(self): + idx = Index([1, pd.Timestamp("2000")]) + with pytest.raises(TypeError, match=".*"): + idx.union(idx[:1], sort=True) + + def test_intersection_equal_sort_true(self): + idx = Index(["c", "a", "b"]) + sorted_ = Index(["a", "b", "c"]) + tm.assert_index_equal(idx.intersection(idx, sort=True), sorted_) + + def test_intersection_base(self, sort): + # (same results for py2 and py3 but sortedness not tested elsewhere) + index = Index([0, "a", 1, "b", 2, "c"]) + first = index[:5] + second = index[:3] + + expected = Index([0, 1, "a"]) if sort is None else Index([0, "a", 1]) + result = first.intersection(second, sort=sort) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize("klass", [np.array, Series, list]) + def test_intersection_different_type_base(self, klass, sort): + # GH 10149 + index = Index([0, "a", 1, "b", 2, "c"]) + first = index[:5] + second = index[:3] + + result = first.intersection(klass(second.values), sort=sort) + assert equal_contents(result, second) + + def test_intersection_nosort(self): + result = Index(["c", "b", "a"]).intersection(["b", "a"]) + expected = Index(["b", "a"]) + tm.assert_index_equal(result, expected) + + def test_intersection_equal_sort(self): + idx = Index(["c", "a", "b"]) + tm.assert_index_equal(idx.intersection(idx, sort=False), idx) + tm.assert_index_equal(idx.intersection(idx, sort=None), idx) + + def test_intersection_str_dates(self, sort): + dt_dates = [datetime(2012, 2, 9), datetime(2012, 2, 22)] + + i1 = Index(dt_dates, dtype=object) + i2 = Index(["aa"], dtype=object) + result = i2.intersection(i1, sort=sort) + + assert len(result) == 0 + + @pytest.mark.parametrize( + "index2,expected_arr", + [(Index(["B", "D"]), ["B"]), (Index(["B", "D", "A"]), ["A", "B"])], + ) + def test_intersection_non_monotonic_non_unique(self, index2, expected_arr, sort): + # non-monotonic non-unique + index1 = Index(["A", "B", "A", "C"]) + expected = Index(expected_arr) + result = index1.intersection(index2, sort=sort) + if sort is None: + expected = expected.sort_values() + tm.assert_index_equal(result, expected) + + def test_difference_base(self, sort): + # (same results for py2 and py3 but sortedness not tested elsewhere) + index = Index([0, "a", 1, "b", 2, "c"]) + first = index[:4] + second = index[3:] + + result = first.difference(second, sort) + expected = Index([0, "a", 1]) + if sort is None: + expected = Index(safe_sort(expected)) + tm.assert_index_equal(result, expected) + + def test_symmetric_difference(self): + # (same results for py2 and py3 but sortedness not tested elsewhere) + index = Index([0, "a", 1, "b", 2, "c"]) + first = index[:4] + second = index[3:] + + result = first.symmetric_difference(second) + expected = Index([0, 1, 2, "a", "c"]) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "method,expected,sort", + [ + ( + "intersection", + np.array( + [(1, "A"), (2, "A"), (1, "B"), (2, "B")], + dtype=[("num", int), ("let", "S1")], + ), + False, + ), + ( + "intersection", + np.array( + [(1, "A"), (1, "B"), (2, "A"), (2, "B")], + dtype=[("num", int), ("let", "S1")], + ), + None, + ), + ( + "union", + np.array( + [(1, "A"), (1, "B"), (1, "C"), (2, "A"), (2, "B"), (2, "C")], + dtype=[("num", int), ("let", "S1")], + ), + None, + ), + ], + ) + def test_tuple_union_bug(self, method, expected, sort): + index1 = Index( + np.array( + [(1, "A"), (2, "A"), (1, "B"), (2, "B")], + dtype=[("num", int), ("let", "S1")], + ) + ) + index2 = Index( + np.array( + [(1, "A"), (2, "A"), (1, "B"), (2, "B"), (1, "C"), (2, "C")], + dtype=[("num", int), ("let", "S1")], + ) + ) + + result = getattr(index1, method)(index2, sort=sort) + assert result.ndim == 1 + + expected = Index(expected) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize("first_list", [["b", "a"], []]) + @pytest.mark.parametrize("second_list", [["a", "b"], []]) + @pytest.mark.parametrize( + "first_name, second_name, expected_name", + [("A", "B", None), (None, "B", None), ("A", None, None)], + ) + def test_union_name_preservation( + self, first_list, second_list, first_name, second_name, expected_name, sort + ): + first = Index(first_list, name=first_name) + second = Index(second_list, name=second_name) + union = first.union(second, sort=sort) + + vals = set(first_list).union(second_list) + + if sort is None and len(first_list) > 0 and len(second_list) > 0: + expected = Index(sorted(vals), name=expected_name) + tm.assert_index_equal(union, expected) + else: + expected = Index(vals, name=expected_name) + tm.assert_index_equal(union.sort_values(), expected.sort_values()) + + @pytest.mark.parametrize( + "diff_type, expected", + [["difference", [1, "B"]], ["symmetric_difference", [1, 2, "B", "C"]]], + ) + def test_difference_object_type(self, diff_type, expected): + # GH 13432 + idx1 = Index([0, 1, "A", "B"]) + idx2 = Index([0, 2, "A", "C"]) + result = getattr(idx1, diff_type)(idx2) + expected = Index(expected) + tm.assert_index_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/base_class/test_where.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/base_class/test_where.py new file mode 100644 index 0000000000000000000000000000000000000000..0c8969735e14e2741bc029b499024af3ec378a92 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/base_class/test_where.py @@ -0,0 +1,13 @@ +import numpy as np + +from pandas import Index +import pandas._testing as tm + + +class TestWhere: + def test_where_intlike_str_doesnt_cast_ints(self): + idx = Index(range(3)) + mask = np.array([True, False, True]) + res = idx.where(mask, "2") + expected = Index([0, "2", 2]) + tm.assert_index_equal(res, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/conftest.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..bfb7acdcf481273e50c18540c141017deb52e094 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/conftest.py @@ -0,0 +1,41 @@ +import numpy as np +import pytest + +from pandas import ( + Series, + array, +) + + +@pytest.fixture(params=[None, False]) +def sort(request): + """ + Valid values for the 'sort' parameter used in the Index + setops methods (intersection, union, etc.) + + Caution: + Don't confuse this one with the "sort" fixture used + for DataFrame.append or concat. That one has + parameters [True, False]. + + We can't combine them as sort=True is not permitted + in the Index setops methods. + """ + return request.param + + +@pytest.fixture(params=["D", "3D", "-3D", "h", "2h", "-2h", "min", "2min", "s", "-3s"]) +def freq_sample(request): + """ + Valid values for 'freq' parameter used to create date_range and + timedelta_range.. + """ + return request.param + + +@pytest.fixture(params=[list, tuple, np.array, array, Series]) +def listlike_box(request): + """ + Types that may be passed as the indexer to searchsorted. + """ + return request.param diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc9f1df2fc89a0a8d4fc8bbce0c3a3737cef00ab Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_arithmetic.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_arithmetic.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ab659045e366ee54149ec24d838d7662824d89d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_arithmetic.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_constructors.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_constructors.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4497944e04e15b34996834ec7432a35ae2e5617 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_constructors.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_date_range.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_date_range.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a16aa5d95e24683ee648b6b10f8f97d13a27a315 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_date_range.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_datetime.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_datetime.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e13b1e1dd269ac9f16392a5c0fcf782da1d51b38 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_datetime.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_formats.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_formats.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2103114fd069853014536fea51d8292aa31759d6 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_formats.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_freq_attr.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_freq_attr.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19bdb3373770715ebf53efa26c273ad2a456ebc5 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_freq_attr.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_indexing.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_indexing.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1ae9b536ddb9827019cdef27205fa26b329bc6c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_indexing.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_iter.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_iter.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5dddcb6931c84e33b22c423c4a6d6672fb47e9f3 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_iter.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_join.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_join.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8a59d16efb22fd26aa700a2bc6f59bd830f05a4 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_join.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_npfuncs.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_npfuncs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dbb638007a570c022a04634e309fc05182cfb919 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_npfuncs.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_ops.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_ops.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06e35daf29edfa756efd2ac4ba52f5918cb710f3 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_ops.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_partial_slicing.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_partial_slicing.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bbdad8cbdc9a3e90fd3554565368141a607052e4 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_partial_slicing.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_pickle.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_pickle.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f741f574209af1276826270458f031dc162ac8f5 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_pickle.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_reindex.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_reindex.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eca5506d28f7493dca2bf62da58458824c04372d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_reindex.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_scalar_compat.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_scalar_compat.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a2def1e82331665dd6688d892105babb96ce3c99 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_scalar_compat.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_setops.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_setops.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45462800faeeb630fe619f9b987e3d9e34ce4ea1 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_setops.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_timezones.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_timezones.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c61918ef6c3b92a9591e9e8ad85b83dbf193e47c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/__pycache__/test_timezones.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/__pycache__/test_asof.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/__pycache__/test_asof.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7230e1a5ec630f962ea36aeaac1115db2d57060 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/__pycache__/test_asof.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/__pycache__/test_astype.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/__pycache__/test_astype.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63974401bdecaa95dd0601c07cc4718b2426ae07 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/__pycache__/test_astype.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/__pycache__/test_fillna.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/__pycache__/test_fillna.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e55c8acde64599c017034eeb3fe1751ea7c80a04 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/__pycache__/test_fillna.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/__pycache__/test_resolution.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/__pycache__/test_resolution.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..00286f2c2c91ea37185d6cab7547bd48fd5ecffd Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/__pycache__/test_resolution.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/__pycache__/test_round.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/__pycache__/test_round.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ffdf0b250f33df66c372d6251f579e725bf20c2 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/__pycache__/test_round.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/__pycache__/test_snap.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/__pycache__/test_snap.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..143b6130a81a406c25249a5c21c3896ea23e2601 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/__pycache__/test_snap.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/__pycache__/test_to_julian_date.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/__pycache__/test_to_julian_date.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46b7a63b15e90f48f13d6ecc980d1005eddee0a0 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/__pycache__/test_to_julian_date.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/__pycache__/test_tz_convert.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/__pycache__/test_tz_convert.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a0efb777c7ea11ab45bcc402dea89651db6e867 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/__pycache__/test_tz_convert.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_asof.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_asof.py new file mode 100644 index 0000000000000000000000000000000000000000..dc92f533087bc3226727fac1810269520e1c4d1f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_asof.py @@ -0,0 +1,30 @@ +from datetime import timedelta + +from pandas import ( + Index, + Timestamp, + date_range, + isna, +) + + +class TestAsOf: + def test_asof_partial(self): + index = date_range("2010-01-01", periods=2, freq="ME") + expected = Timestamp("2010-02-28") + result = index.asof("2010-02") + assert result == expected + assert not isinstance(result, Index) + + def test_asof(self): + index = date_range("2020-01-01", periods=10) + + dt = index[0] + assert index.asof(dt) == dt + assert isna(index.asof(dt - timedelta(1))) + + dt = index[-1] + assert index.asof(dt + timedelta(1)) == dt + + dt = index[0].to_pydatetime() + assert isinstance(index.asof(dt), Timestamp) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_astype.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_astype.py new file mode 100644 index 0000000000000000000000000000000000000000..c0bc6601769b1dce195f0af7303e5015bcddf589 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_astype.py @@ -0,0 +1,335 @@ +from datetime import datetime + +import dateutil +import numpy as np +import pytest +import pytz + +import pandas as pd +from pandas import ( + DatetimeIndex, + Index, + NaT, + PeriodIndex, + Timestamp, + date_range, +) +import pandas._testing as tm + + +class TestDatetimeIndex: + @pytest.mark.parametrize("tzstr", ["US/Eastern", "dateutil/US/Eastern"]) + def test_dti_astype_asobject_around_dst_transition(self, tzstr): + # GH#1345 + + # dates around a dst transition + rng = date_range("2/13/2010", "5/6/2010", tz=tzstr) + + objs = rng.astype(object) + for i, x in enumerate(objs): + exval = rng[i] + assert x == exval + assert x.tzinfo == exval.tzinfo + + objs = rng.astype(object) + for i, x in enumerate(objs): + exval = rng[i] + assert x == exval + assert x.tzinfo == exval.tzinfo + + def test_astype(self): + # GH 13149, GH 13209 + idx = DatetimeIndex( + ["2016-05-16", "NaT", NaT, np.nan], dtype="M8[ns]", name="idx" + ) + + result = idx.astype(object) + expected = Index( + [Timestamp("2016-05-16")] + [NaT] * 3, dtype=object, name="idx" + ) + tm.assert_index_equal(result, expected) + + result = idx.astype(np.int64) + expected = Index( + [1463356800000000000] + [-9223372036854775808] * 3, + dtype=np.int64, + name="idx", + ) + tm.assert_index_equal(result, expected) + + def test_astype2(self): + rng = date_range("1/1/2000", periods=10, name="idx") + result = rng.astype("i8") + tm.assert_index_equal(result, Index(rng.asi8, name="idx")) + tm.assert_numpy_array_equal(result.values, rng.asi8) + + def test_astype_uint(self): + arr = date_range("2000", periods=2, name="idx") + + with pytest.raises(TypeError, match=r"Do obj.astype\('int64'\)"): + arr.astype("uint64") + with pytest.raises(TypeError, match=r"Do obj.astype\('int64'\)"): + arr.astype("uint32") + + def test_astype_with_tz(self): + # with tz + rng = date_range("1/1/2000", periods=10, tz="US/Eastern") + msg = "Cannot use .astype to convert from timezone-aware" + with pytest.raises(TypeError, match=msg): + # deprecated + rng.astype("datetime64[ns]") + with pytest.raises(TypeError, match=msg): + # check DatetimeArray while we're here deprecated + rng._data.astype("datetime64[ns]") + + def test_astype_tzaware_to_tzaware(self): + # GH 18951: tz-aware to tz-aware + idx = date_range("20170101", periods=4, tz="US/Pacific") + result = idx.astype("datetime64[ns, US/Eastern]") + expected = date_range("20170101 03:00:00", periods=4, tz="US/Eastern") + tm.assert_index_equal(result, expected) + assert result.freq == expected.freq + + def test_astype_tznaive_to_tzaware(self): + # GH 18951: tz-naive to tz-aware + idx = date_range("20170101", periods=4) + idx = idx._with_freq(None) # tz_localize does not preserve freq + msg = "Cannot use .astype to convert from timezone-naive" + with pytest.raises(TypeError, match=msg): + # dt64->dt64tz deprecated + idx.astype("datetime64[ns, US/Eastern]") + with pytest.raises(TypeError, match=msg): + # dt64->dt64tz deprecated + idx._data.astype("datetime64[ns, US/Eastern]") + + def test_astype_str_nat(self): + # GH 13149, GH 13209 + # verify that we are returning NaT as a string (and not unicode) + + idx = DatetimeIndex(["2016-05-16", "NaT", NaT, np.nan]) + result = idx.astype(str) + expected = Index(["2016-05-16", "NaT", "NaT", "NaT"], dtype=object) + tm.assert_index_equal(result, expected) + + def test_astype_str(self): + # test astype string - #10442 + dti = date_range("2012-01-01", periods=4, name="test_name") + result = dti.astype(str) + expected = Index( + ["2012-01-01", "2012-01-02", "2012-01-03", "2012-01-04"], + name="test_name", + dtype=object, + ) + tm.assert_index_equal(result, expected) + + def test_astype_str_tz_and_name(self): + # test astype string with tz and name + dti = date_range("2012-01-01", periods=3, name="test_name", tz="US/Eastern") + result = dti.astype(str) + expected = Index( + [ + "2012-01-01 00:00:00-05:00", + "2012-01-02 00:00:00-05:00", + "2012-01-03 00:00:00-05:00", + ], + name="test_name", + dtype=object, + ) + tm.assert_index_equal(result, expected) + + def test_astype_str_freq_and_name(self): + # test astype string with freqH and name + dti = date_range("1/1/2011", periods=3, freq="h", name="test_name") + result = dti.astype(str) + expected = Index( + ["2011-01-01 00:00:00", "2011-01-01 01:00:00", "2011-01-01 02:00:00"], + name="test_name", + dtype=object, + ) + tm.assert_index_equal(result, expected) + + def test_astype_str_freq_and_tz(self): + # test astype string with freqH and timezone + dti = date_range( + "3/6/2012 00:00", periods=2, freq="h", tz="Europe/London", name="test_name" + ) + result = dti.astype(str) + expected = Index( + ["2012-03-06 00:00:00+00:00", "2012-03-06 01:00:00+00:00"], + dtype=object, + name="test_name", + ) + tm.assert_index_equal(result, expected) + + def test_astype_datetime64(self): + # GH 13149, GH 13209 + idx = DatetimeIndex( + ["2016-05-16", "NaT", NaT, np.nan], dtype="M8[ns]", name="idx" + ) + + result = idx.astype("datetime64[ns]") + tm.assert_index_equal(result, idx) + assert result is not idx + + result = idx.astype("datetime64[ns]", copy=False) + tm.assert_index_equal(result, idx) + assert result is idx + + idx_tz = DatetimeIndex(["2016-05-16", "NaT", NaT, np.nan], tz="EST", name="idx") + msg = "Cannot use .astype to convert from timezone-aware" + with pytest.raises(TypeError, match=msg): + # dt64tz->dt64 deprecated + result = idx_tz.astype("datetime64[ns]") + + def test_astype_object(self): + rng = date_range("1/1/2000", periods=20) + + casted = rng.astype("O") + exp_values = list(rng) + + tm.assert_index_equal(casted, Index(exp_values, dtype=np.object_)) + assert casted.tolist() == exp_values + + @pytest.mark.parametrize("tz", [None, "Asia/Tokyo"]) + def test_astype_object_tz(self, tz): + idx = date_range(start="2013-01-01", periods=4, freq="ME", name="idx", tz=tz) + expected_list = [ + Timestamp("2013-01-31", tz=tz), + Timestamp("2013-02-28", tz=tz), + Timestamp("2013-03-31", tz=tz), + Timestamp("2013-04-30", tz=tz), + ] + expected = Index(expected_list, dtype=object, name="idx") + result = idx.astype(object) + tm.assert_index_equal(result, expected) + assert idx.tolist() == expected_list + + def test_astype_object_with_nat(self): + idx = DatetimeIndex( + [datetime(2013, 1, 1), datetime(2013, 1, 2), NaT, datetime(2013, 1, 4)], + name="idx", + ) + expected_list = [ + Timestamp("2013-01-01"), + Timestamp("2013-01-02"), + NaT, + Timestamp("2013-01-04"), + ] + expected = Index(expected_list, dtype=object, name="idx") + result = idx.astype(object) + tm.assert_index_equal(result, expected) + assert idx.tolist() == expected_list + + @pytest.mark.parametrize( + "dtype", + [float, "timedelta64", "timedelta64[ns]", "datetime64", "datetime64[D]"], + ) + def test_astype_raises(self, dtype): + # GH 13149, GH 13209 + idx = DatetimeIndex(["2016-05-16", "NaT", NaT, np.nan]) + msg = "Cannot cast DatetimeIndex to dtype" + if dtype == "datetime64": + msg = "Casting to unit-less dtype 'datetime64' is not supported" + with pytest.raises(TypeError, match=msg): + idx.astype(dtype) + + def test_index_convert_to_datetime_array(self): + def _check_rng(rng): + converted = rng.to_pydatetime() + assert isinstance(converted, np.ndarray) + for x, stamp in zip(converted, rng): + assert isinstance(x, datetime) + assert x == stamp.to_pydatetime() + assert x.tzinfo == stamp.tzinfo + + rng = date_range("20090415", "20090519") + rng_eastern = date_range("20090415", "20090519", tz="US/Eastern") + rng_utc = date_range("20090415", "20090519", tz="utc") + + _check_rng(rng) + _check_rng(rng_eastern) + _check_rng(rng_utc) + + def test_index_convert_to_datetime_array_explicit_pytz(self): + def _check_rng(rng): + converted = rng.to_pydatetime() + assert isinstance(converted, np.ndarray) + for x, stamp in zip(converted, rng): + assert isinstance(x, datetime) + assert x == stamp.to_pydatetime() + assert x.tzinfo == stamp.tzinfo + + rng = date_range("20090415", "20090519") + rng_eastern = date_range("20090415", "20090519", tz=pytz.timezone("US/Eastern")) + rng_utc = date_range("20090415", "20090519", tz=pytz.utc) + + _check_rng(rng) + _check_rng(rng_eastern) + _check_rng(rng_utc) + + def test_index_convert_to_datetime_array_dateutil(self): + def _check_rng(rng): + converted = rng.to_pydatetime() + assert isinstance(converted, np.ndarray) + for x, stamp in zip(converted, rng): + assert isinstance(x, datetime) + assert x == stamp.to_pydatetime() + assert x.tzinfo == stamp.tzinfo + + rng = date_range("20090415", "20090519") + rng_eastern = date_range("20090415", "20090519", tz="dateutil/US/Eastern") + rng_utc = date_range("20090415", "20090519", tz=dateutil.tz.tzutc()) + + _check_rng(rng) + _check_rng(rng_eastern) + _check_rng(rng_utc) + + @pytest.mark.parametrize( + "tz, dtype", + [["US/Pacific", "datetime64[ns, US/Pacific]"], [None, "datetime64[ns]"]], + ) + def test_integer_index_astype_datetime(self, tz, dtype): + # GH 20997, 20964, 24559 + val = [Timestamp("2018-01-01", tz=tz).as_unit("ns")._value] + result = Index(val, name="idx").astype(dtype) + expected = DatetimeIndex(["2018-01-01"], tz=tz, name="idx").as_unit("ns") + tm.assert_index_equal(result, expected) + + def test_dti_astype_period(self): + idx = DatetimeIndex([NaT, "2011-01-01", "2011-02-01"], name="idx") + + res = idx.astype("period[M]") + exp = PeriodIndex(["NaT", "2011-01", "2011-02"], freq="M", name="idx") + tm.assert_index_equal(res, exp) + + res = idx.astype("period[3M]") + exp = PeriodIndex(["NaT", "2011-01", "2011-02"], freq="3M", name="idx") + tm.assert_index_equal(res, exp) + + +class TestAstype: + @pytest.mark.parametrize("tz", [None, "US/Central"]) + def test_astype_category(self, tz): + obj = date_range("2000", periods=2, tz=tz, name="idx") + result = obj.astype("category") + dti = DatetimeIndex(["2000-01-01", "2000-01-02"], tz=tz).as_unit("ns") + expected = pd.CategoricalIndex( + dti, + name="idx", + ) + tm.assert_index_equal(result, expected) + + result = obj._data.astype("category") + expected = expected.values + tm.assert_categorical_equal(result, expected) + + @pytest.mark.parametrize("tz", [None, "US/Central"]) + def test_astype_array_fallback(self, tz): + obj = date_range("2000", periods=2, tz=tz, name="idx") + result = obj.astype(bool) + expected = Index(np.array([True, True]), name="idx") + tm.assert_index_equal(result, expected) + + result = obj._data.astype(bool) + expected = np.array([True, True]) + tm.assert_numpy_array_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_delete.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_delete.py new file mode 100644 index 0000000000000000000000000000000000000000..2341499977f2247dc42c30470795378515f49dc8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_delete.py @@ -0,0 +1,141 @@ +import pytest + +from pandas import ( + DatetimeIndex, + Series, + date_range, +) +import pandas._testing as tm + + +class TestDelete: + def test_delete(self, unit): + idx = date_range( + start="2000-01-01", periods=5, freq="ME", name="idx", unit=unit + ) + + # preserve freq + expected_0 = date_range( + start="2000-02-01", periods=4, freq="ME", name="idx", unit=unit + ) + expected_4 = date_range( + start="2000-01-01", periods=4, freq="ME", name="idx", unit=unit + ) + + # reset freq to None + expected_1 = DatetimeIndex( + ["2000-01-31", "2000-03-31", "2000-04-30", "2000-05-31"], + freq=None, + name="idx", + ).as_unit(unit) + + cases = { + 0: expected_0, + -5: expected_0, + -1: expected_4, + 4: expected_4, + 1: expected_1, + } + for n, expected in cases.items(): + result = idx.delete(n) + tm.assert_index_equal(result, expected) + assert result.name == expected.name + assert result.freq == expected.freq + + with pytest.raises((IndexError, ValueError), match="out of bounds"): + # either depending on numpy version + idx.delete(5) + + @pytest.mark.parametrize("tz", [None, "Asia/Tokyo", "US/Pacific"]) + def test_delete2(self, tz): + idx = date_range( + start="2000-01-01 09:00", periods=10, freq="h", name="idx", tz=tz + ) + + expected = date_range( + start="2000-01-01 10:00", periods=9, freq="h", name="idx", tz=tz + ) + result = idx.delete(0) + tm.assert_index_equal(result, expected) + assert result.name == expected.name + assert result.freqstr == "h" + assert result.tz == expected.tz + + expected = date_range( + start="2000-01-01 09:00", periods=9, freq="h", name="idx", tz=tz + ) + result = idx.delete(-1) + tm.assert_index_equal(result, expected) + assert result.name == expected.name + assert result.freqstr == "h" + assert result.tz == expected.tz + + def test_delete_slice(self, unit): + idx = date_range( + start="2000-01-01", periods=10, freq="D", name="idx", unit=unit + ) + + # preserve freq + expected_0_2 = date_range( + start="2000-01-04", periods=7, freq="D", name="idx", unit=unit + ) + expected_7_9 = date_range( + start="2000-01-01", periods=7, freq="D", name="idx", unit=unit + ) + + # reset freq to None + expected_3_5 = DatetimeIndex( + [ + "2000-01-01", + "2000-01-02", + "2000-01-03", + "2000-01-07", + "2000-01-08", + "2000-01-09", + "2000-01-10", + ], + freq=None, + name="idx", + ).as_unit(unit) + + cases = { + (0, 1, 2): expected_0_2, + (7, 8, 9): expected_7_9, + (3, 4, 5): expected_3_5, + } + for n, expected in cases.items(): + result = idx.delete(n) + tm.assert_index_equal(result, expected) + assert result.name == expected.name + assert result.freq == expected.freq + + result = idx.delete(slice(n[0], n[-1] + 1)) + tm.assert_index_equal(result, expected) + assert result.name == expected.name + assert result.freq == expected.freq + + # TODO: belongs in Series.drop tests? + @pytest.mark.parametrize("tz", [None, "Asia/Tokyo", "US/Pacific"]) + def test_delete_slice2(self, tz, unit): + dti = date_range( + "2000-01-01 09:00", periods=10, freq="h", name="idx", tz=tz, unit=unit + ) + ts = Series( + 1, + index=dti, + ) + # preserve freq + result = ts.drop(ts.index[:5]).index + expected = dti[5:] + tm.assert_index_equal(result, expected) + assert result.name == expected.name + assert result.freq == expected.freq + assert result.tz == expected.tz + + # reset freq to None + result = ts.drop(ts.index[[1, 3, 5, 7, 9]]).index + expected = dti[::2]._with_freq(None) + tm.assert_index_equal(result, expected) + assert result.name == expected.name + assert result.freq == expected.freq + assert result.tz == expected.tz diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_factorize.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_factorize.py new file mode 100644 index 0000000000000000000000000000000000000000..41ecf9ee6b82317137b1a6accee14ad8c1b5a35a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_factorize.py @@ -0,0 +1,125 @@ +import numpy as np +import pytest + +from pandas import ( + DatetimeIndex, + Index, + date_range, + factorize, +) +import pandas._testing as tm + + +class TestDatetimeIndexFactorize: + def test_factorize(self): + idx1 = DatetimeIndex( + ["2014-01", "2014-01", "2014-02", "2014-02", "2014-03", "2014-03"] + ) + + exp_arr = np.array([0, 0, 1, 1, 2, 2], dtype=np.intp) + exp_idx = DatetimeIndex(["2014-01", "2014-02", "2014-03"]) + + arr, idx = idx1.factorize() + tm.assert_numpy_array_equal(arr, exp_arr) + tm.assert_index_equal(idx, exp_idx) + assert idx.freq == exp_idx.freq + + arr, idx = idx1.factorize(sort=True) + tm.assert_numpy_array_equal(arr, exp_arr) + tm.assert_index_equal(idx, exp_idx) + assert idx.freq == exp_idx.freq + + # tz must be preserved + idx1 = idx1.tz_localize("Asia/Tokyo") + exp_idx = exp_idx.tz_localize("Asia/Tokyo") + + arr, idx = idx1.factorize() + tm.assert_numpy_array_equal(arr, exp_arr) + tm.assert_index_equal(idx, exp_idx) + assert idx.freq == exp_idx.freq + + idx2 = DatetimeIndex( + ["2014-03", "2014-03", "2014-02", "2014-01", "2014-03", "2014-01"] + ) + + exp_arr = np.array([2, 2, 1, 0, 2, 0], dtype=np.intp) + exp_idx = DatetimeIndex(["2014-01", "2014-02", "2014-03"]) + arr, idx = idx2.factorize(sort=True) + tm.assert_numpy_array_equal(arr, exp_arr) + tm.assert_index_equal(idx, exp_idx) + assert idx.freq == exp_idx.freq + + exp_arr = np.array([0, 0, 1, 2, 0, 2], dtype=np.intp) + exp_idx = DatetimeIndex(["2014-03", "2014-02", "2014-01"]) + arr, idx = idx2.factorize() + tm.assert_numpy_array_equal(arr, exp_arr) + tm.assert_index_equal(idx, exp_idx) + assert idx.freq == exp_idx.freq + + def test_factorize_preserves_freq(self): + # GH#38120 freq should be preserved + idx3 = date_range("2000-01", periods=4, freq="ME", tz="Asia/Tokyo") + exp_arr = np.array([0, 1, 2, 3], dtype=np.intp) + + arr, idx = idx3.factorize() + tm.assert_numpy_array_equal(arr, exp_arr) + tm.assert_index_equal(idx, idx3) + assert idx.freq == idx3.freq + + arr, idx = factorize(idx3) + tm.assert_numpy_array_equal(arr, exp_arr) + tm.assert_index_equal(idx, idx3) + assert idx.freq == idx3.freq + + def test_factorize_tz(self, tz_naive_fixture, index_or_series): + tz = tz_naive_fixture + # GH#13750 + base = date_range("2016-11-05", freq="h", periods=100, tz=tz) + idx = base.repeat(5) + + exp_arr = np.arange(100, dtype=np.intp).repeat(5) + + obj = index_or_series(idx) + + arr, res = obj.factorize() + tm.assert_numpy_array_equal(arr, exp_arr) + expected = base._with_freq(None) + tm.assert_index_equal(res, expected) + assert res.freq == expected.freq + + def test_factorize_dst(self, index_or_series): + # GH#13750 + idx = date_range("2016-11-06", freq="h", periods=12, tz="US/Eastern") + obj = index_or_series(idx) + + arr, res = obj.factorize() + tm.assert_numpy_array_equal(arr, np.arange(12, dtype=np.intp)) + tm.assert_index_equal(res, idx) + if index_or_series is Index: + assert res.freq == idx.freq + + idx = date_range("2016-06-13", freq="h", periods=12, tz="US/Eastern") + obj = index_or_series(idx) + + arr, res = obj.factorize() + tm.assert_numpy_array_equal(arr, np.arange(12, dtype=np.intp)) + tm.assert_index_equal(res, idx) + if index_or_series is Index: + assert res.freq == idx.freq + + @pytest.mark.parametrize("sort", [True, False]) + def test_factorize_no_freq_non_nano(self, tz_naive_fixture, sort): + # GH#51978 case that does not go through the fastpath based on + # non-None freq + tz = tz_naive_fixture + idx = date_range("2016-11-06", freq="h", periods=5, tz=tz)[[0, 4, 1, 3, 2]] + exp_codes, exp_uniques = idx.factorize(sort=sort) + + res_codes, res_uniques = idx.as_unit("s").factorize(sort=sort) + + tm.assert_numpy_array_equal(res_codes, exp_codes) + tm.assert_index_equal(res_uniques, exp_uniques.as_unit("s")) + + res_codes, res_uniques = idx.as_unit("s").to_series().factorize(sort=sort) + tm.assert_numpy_array_equal(res_codes, exp_codes) + tm.assert_index_equal(res_uniques, exp_uniques.as_unit("s")) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_fillna.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_fillna.py new file mode 100644 index 0000000000000000000000000000000000000000..5fbe60bb0c50f0b6ec36eb02b125e9e9bf0f81dd --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_fillna.py @@ -0,0 +1,62 @@ +import pytest + +import pandas as pd +import pandas._testing as tm + + +class TestDatetimeIndexFillNA: + @pytest.mark.parametrize("tz", ["US/Eastern", "Asia/Tokyo"]) + def test_fillna_datetime64(self, tz): + # GH 11343 + idx = pd.DatetimeIndex(["2011-01-01 09:00", pd.NaT, "2011-01-01 11:00"]) + + exp = pd.DatetimeIndex( + ["2011-01-01 09:00", "2011-01-01 10:00", "2011-01-01 11:00"] + ) + tm.assert_index_equal(idx.fillna(pd.Timestamp("2011-01-01 10:00")), exp) + + # tz mismatch + exp = pd.Index( + [ + pd.Timestamp("2011-01-01 09:00"), + pd.Timestamp("2011-01-01 10:00", tz=tz), + pd.Timestamp("2011-01-01 11:00"), + ], + dtype=object, + ) + tm.assert_index_equal(idx.fillna(pd.Timestamp("2011-01-01 10:00", tz=tz)), exp) + + # object + exp = pd.Index( + [pd.Timestamp("2011-01-01 09:00"), "x", pd.Timestamp("2011-01-01 11:00")], + dtype=object, + ) + tm.assert_index_equal(idx.fillna("x"), exp) + + idx = pd.DatetimeIndex(["2011-01-01 09:00", pd.NaT, "2011-01-01 11:00"], tz=tz) + + exp = pd.DatetimeIndex( + ["2011-01-01 09:00", "2011-01-01 10:00", "2011-01-01 11:00"], tz=tz + ) + tm.assert_index_equal(idx.fillna(pd.Timestamp("2011-01-01 10:00", tz=tz)), exp) + + exp = pd.Index( + [ + pd.Timestamp("2011-01-01 09:00", tz=tz), + pd.Timestamp("2011-01-01 10:00"), + pd.Timestamp("2011-01-01 11:00", tz=tz), + ], + dtype=object, + ) + tm.assert_index_equal(idx.fillna(pd.Timestamp("2011-01-01 10:00")), exp) + + # object + exp = pd.Index( + [ + pd.Timestamp("2011-01-01 09:00", tz=tz), + "x", + pd.Timestamp("2011-01-01 11:00", tz=tz), + ], + dtype=object, + ) + tm.assert_index_equal(idx.fillna("x"), exp) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_insert.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_insert.py new file mode 100644 index 0000000000000000000000000000000000000000..ebfe490e0e067807f7a38d3f8f285aee76718fcf --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_insert.py @@ -0,0 +1,265 @@ +from datetime import datetime + +import numpy as np +import pytest +import pytz + +from pandas import ( + NA, + DatetimeIndex, + Index, + NaT, + Timestamp, + date_range, +) +import pandas._testing as tm + + +class TestInsert: + @pytest.mark.parametrize("null", [None, np.nan, np.datetime64("NaT"), NaT, NA]) + @pytest.mark.parametrize("tz", [None, "UTC", "US/Eastern"]) + def test_insert_nat(self, tz, null): + # GH#16537, GH#18295 (test missing) + + idx = DatetimeIndex(["2017-01-01"], tz=tz) + expected = DatetimeIndex(["NaT", "2017-01-01"], tz=tz) + if tz is not None and isinstance(null, np.datetime64): + expected = Index([null, idx[0]], dtype=object) + + res = idx.insert(0, null) + tm.assert_index_equal(res, expected) + + @pytest.mark.parametrize("tz", [None, "UTC", "US/Eastern"]) + def test_insert_invalid_na(self, tz): + idx = DatetimeIndex(["2017-01-01"], tz=tz) + + item = np.timedelta64("NaT") + result = idx.insert(0, item) + expected = Index([item] + list(idx), dtype=object) + tm.assert_index_equal(result, expected) + + def test_insert_empty_preserves_freq(self, tz_naive_fixture): + # GH#33573 + tz = tz_naive_fixture + dti = DatetimeIndex([], tz=tz, freq="D") + item = Timestamp("2017-04-05").tz_localize(tz) + + result = dti.insert(0, item) + assert result.freq == dti.freq + + # But not when we insert an item that doesn't conform to freq + dti = DatetimeIndex([], tz=tz, freq="W-THU") + result = dti.insert(0, item) + assert result.freq is None + + def test_insert(self, unit): + idx = DatetimeIndex( + ["2000-01-04", "2000-01-01", "2000-01-02"], name="idx" + ).as_unit(unit) + + result = idx.insert(2, datetime(2000, 1, 5)) + exp = DatetimeIndex( + ["2000-01-04", "2000-01-01", "2000-01-05", "2000-01-02"], name="idx" + ).as_unit(unit) + tm.assert_index_equal(result, exp) + + # insertion of non-datetime should coerce to object index + result = idx.insert(1, "inserted") + expected = Index( + [ + datetime(2000, 1, 4), + "inserted", + datetime(2000, 1, 1), + datetime(2000, 1, 2), + ], + name="idx", + ) + assert not isinstance(result, DatetimeIndex) + tm.assert_index_equal(result, expected) + assert result.name == expected.name + + def test_insert2(self, unit): + idx = date_range("1/1/2000", periods=3, freq="ME", name="idx", unit=unit) + + # preserve freq + expected_0 = DatetimeIndex( + ["1999-12-31", "2000-01-31", "2000-02-29", "2000-03-31"], + name="idx", + freq="ME", + ).as_unit(unit) + expected_3 = DatetimeIndex( + ["2000-01-31", "2000-02-29", "2000-03-31", "2000-04-30"], + name="idx", + freq="ME", + ).as_unit(unit) + + # reset freq to None + expected_1_nofreq = DatetimeIndex( + ["2000-01-31", "2000-01-31", "2000-02-29", "2000-03-31"], + name="idx", + freq=None, + ).as_unit(unit) + expected_3_nofreq = DatetimeIndex( + ["2000-01-31", "2000-02-29", "2000-03-31", "2000-01-02"], + name="idx", + freq=None, + ).as_unit(unit) + + cases = [ + (0, datetime(1999, 12, 31), expected_0), + (-3, datetime(1999, 12, 31), expected_0), + (3, datetime(2000, 4, 30), expected_3), + (1, datetime(2000, 1, 31), expected_1_nofreq), + (3, datetime(2000, 1, 2), expected_3_nofreq), + ] + + for n, d, expected in cases: + result = idx.insert(n, d) + tm.assert_index_equal(result, expected) + assert result.name == expected.name + assert result.freq == expected.freq + + def test_insert3(self, unit): + idx = date_range("1/1/2000", periods=3, freq="ME", name="idx", unit=unit) + + # reset freq to None + result = idx.insert(3, datetime(2000, 1, 2)) + expected = DatetimeIndex( + ["2000-01-31", "2000-02-29", "2000-03-31", "2000-01-02"], + name="idx", + freq=None, + ).as_unit(unit) + tm.assert_index_equal(result, expected) + assert result.name == expected.name + assert result.freq is None + + def test_insert4(self, unit): + for tz in ["US/Pacific", "Asia/Singapore"]: + idx = date_range( + "1/1/2000 09:00", periods=6, freq="h", tz=tz, name="idx", unit=unit + ) + # preserve freq + expected = date_range( + "1/1/2000 09:00", periods=7, freq="h", tz=tz, name="idx", unit=unit + ) + for d in [ + Timestamp("2000-01-01 15:00", tz=tz), + pytz.timezone(tz).localize(datetime(2000, 1, 1, 15)), + ]: + result = idx.insert(6, d) + tm.assert_index_equal(result, expected) + assert result.name == expected.name + assert result.freq == expected.freq + assert result.tz == expected.tz + + expected = DatetimeIndex( + [ + "2000-01-01 09:00", + "2000-01-01 10:00", + "2000-01-01 11:00", + "2000-01-01 12:00", + "2000-01-01 13:00", + "2000-01-01 14:00", + "2000-01-01 10:00", + ], + name="idx", + tz=tz, + freq=None, + ).as_unit(unit) + # reset freq to None + for d in [ + Timestamp("2000-01-01 10:00", tz=tz), + pytz.timezone(tz).localize(datetime(2000, 1, 1, 10)), + ]: + result = idx.insert(6, d) + tm.assert_index_equal(result, expected) + assert result.name == expected.name + assert result.tz == expected.tz + assert result.freq is None + + # TODO: also changes DataFrame.__setitem__ with expansion + def test_insert_mismatched_tzawareness(self): + # see GH#7299 + idx = date_range("1/1/2000", periods=3, freq="D", tz="Asia/Tokyo", name="idx") + + # mismatched tz-awareness + item = Timestamp("2000-01-04") + result = idx.insert(3, item) + expected = Index( + list(idx[:3]) + [item] + list(idx[3:]), dtype=object, name="idx" + ) + tm.assert_index_equal(result, expected) + + # mismatched tz-awareness + item = datetime(2000, 1, 4) + result = idx.insert(3, item) + expected = Index( + list(idx[:3]) + [item] + list(idx[3:]), dtype=object, name="idx" + ) + tm.assert_index_equal(result, expected) + + # TODO: also changes DataFrame.__setitem__ with expansion + def test_insert_mismatched_tz(self): + # see GH#7299 + # pre-2.0 with mismatched tzs we would cast to object + idx = date_range("1/1/2000", periods=3, freq="D", tz="Asia/Tokyo", name="idx") + + # mismatched tz -> cast to object (could reasonably cast to same tz or UTC) + item = Timestamp("2000-01-04", tz="US/Eastern") + result = idx.insert(3, item) + expected = Index( + list(idx[:3]) + [item.tz_convert(idx.tz)] + list(idx[3:]), + name="idx", + ) + assert expected.dtype == idx.dtype + tm.assert_index_equal(result, expected) + + item = datetime(2000, 1, 4, tzinfo=pytz.timezone("US/Eastern")) + result = idx.insert(3, item) + expected = Index( + list(idx[:3]) + [item.astimezone(idx.tzinfo)] + list(idx[3:]), + name="idx", + ) + assert expected.dtype == idx.dtype + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "item", [0, np.int64(0), np.float64(0), np.array(0), np.timedelta64(456)] + ) + def test_insert_mismatched_types_raises(self, tz_aware_fixture, item): + # GH#33703 dont cast these to dt64 + tz = tz_aware_fixture + dti = date_range("2019-11-04", periods=9, freq="-1D", name=9, tz=tz) + + result = dti.insert(1, item) + + if isinstance(item, np.ndarray): + assert item.item() == 0 + expected = Index([dti[0], 0] + list(dti[1:]), dtype=object, name=9) + else: + expected = Index([dti[0], item] + list(dti[1:]), dtype=object, name=9) + + tm.assert_index_equal(result, expected) + + def test_insert_castable_str(self, tz_aware_fixture): + # GH#33703 + tz = tz_aware_fixture + dti = date_range("2019-11-04", periods=3, freq="-1D", name=9, tz=tz) + + value = "2019-11-05" + result = dti.insert(0, value) + + ts = Timestamp(value).tz_localize(tz) + expected = DatetimeIndex([ts] + list(dti), dtype=dti.dtype, name=9) + tm.assert_index_equal(result, expected) + + def test_insert_non_castable_str(self, tz_aware_fixture): + # GH#33703 + tz = tz_aware_fixture + dti = date_range("2019-11-04", periods=3, freq="-1D", name=9, tz=tz) + + value = "foo" + result = dti.insert(0, value) + + expected = Index(["foo"] + list(dti), dtype=object, name=9) + tm.assert_index_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_isocalendar.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_isocalendar.py new file mode 100644 index 0000000000000000000000000000000000000000..97f1003e0f43f7564434cbc8b3051e870143209c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_isocalendar.py @@ -0,0 +1,28 @@ +from pandas import ( + DataFrame, + DatetimeIndex, + date_range, +) +import pandas._testing as tm + + +def test_isocalendar_returns_correct_values_close_to_new_year_with_tz(): + # GH#6538: Check that DatetimeIndex and its TimeStamp elements + # return the same weekofyear accessor close to new year w/ tz + dates = ["2013/12/29", "2013/12/30", "2013/12/31"] + dates = DatetimeIndex(dates, tz="Europe/Brussels") + result = dates.isocalendar() + expected_data_frame = DataFrame( + [[2013, 52, 7], [2014, 1, 1], [2014, 1, 2]], + columns=["year", "week", "day"], + index=dates, + dtype="UInt32", + ) + tm.assert_frame_equal(result, expected_data_frame) + + +def test_dti_timestamp_isocalendar_fields(): + idx = date_range("2020-01-01", periods=10) + expected = tuple(idx.isocalendar().iloc[-1].to_list()) + result = idx[-1].isocalendar() + assert result == expected diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_map.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_map.py new file mode 100644 index 0000000000000000000000000000000000000000..f35f07bd32068f15fa8c4eb8d1ad8c2a6d43fc72 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_map.py @@ -0,0 +1,47 @@ +import pytest + +from pandas import ( + DatetimeIndex, + Index, + MultiIndex, + Period, + date_range, +) +import pandas._testing as tm + + +class TestMap: + def test_map(self): + rng = date_range("1/1/2000", periods=10) + + f = lambda x: x.strftime("%Y%m%d") + result = rng.map(f) + exp = Index([f(x) for x in rng]) + tm.assert_index_equal(result, exp) + + def test_map_fallthrough(self, capsys): + # GH#22067, check we don't get warnings about silently ignored errors + dti = date_range("2017-01-01", "2018-01-01", freq="B") + + dti.map(lambda x: Period(year=x.year, month=x.month, freq="M")) + + captured = capsys.readouterr() + assert captured.err == "" + + def test_map_bug_1677(self): + index = DatetimeIndex(["2012-04-25 09:30:00.393000"]) + f = index.asof + + result = index.map(f) + expected = Index([f(index[0])]) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize("name", [None, "name"]) + def test_index_map(self, name): + # see GH#20990 + count = 6 + index = date_range("2018-01-01", periods=count, freq="ME", name=name).map( + lambda x: (x.year, x.month) + ) + exp_index = MultiIndex.from_product(((2018,), range(1, 7)), names=[name, name]) + tm.assert_index_equal(index, exp_index) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_normalize.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_normalize.py new file mode 100644 index 0000000000000000000000000000000000000000..74711f67e64465c5592e562fcc94202666d0ad67 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_normalize.py @@ -0,0 +1,95 @@ +from dateutil.tz import tzlocal +import numpy as np +import pytest + +import pandas.util._test_decorators as td + +from pandas import ( + DatetimeIndex, + NaT, + Timestamp, + date_range, +) +import pandas._testing as tm + + +class TestNormalize: + def test_normalize(self): + rng = date_range("1/1/2000 9:30", periods=10, freq="D") + + result = rng.normalize() + expected = date_range("1/1/2000", periods=10, freq="D") + tm.assert_index_equal(result, expected) + + arr_ns = np.array([1380585623454345752, 1380585612343234312]).astype( + "datetime64[ns]" + ) + rng_ns = DatetimeIndex(arr_ns) + rng_ns_normalized = rng_ns.normalize() + + arr_ns = np.array([1380585600000000000, 1380585600000000000]).astype( + "datetime64[ns]" + ) + expected = DatetimeIndex(arr_ns) + tm.assert_index_equal(rng_ns_normalized, expected) + + assert result.is_normalized + assert not rng.is_normalized + + def test_normalize_nat(self): + dti = DatetimeIndex([NaT, Timestamp("2018-01-01 01:00:00")]) + result = dti.normalize() + expected = DatetimeIndex([NaT, Timestamp("2018-01-01")]) + tm.assert_index_equal(result, expected) + + def test_normalize_tz(self): + rng = date_range("1/1/2000 9:30", periods=10, freq="D", tz="US/Eastern") + + result = rng.normalize() # does not preserve freq + expected = date_range("1/1/2000", periods=10, freq="D", tz="US/Eastern") + tm.assert_index_equal(result, expected._with_freq(None)) + + assert result.is_normalized + assert not rng.is_normalized + + rng = date_range("1/1/2000 9:30", periods=10, freq="D", tz="UTC") + + result = rng.normalize() + expected = date_range("1/1/2000", periods=10, freq="D", tz="UTC") + tm.assert_index_equal(result, expected) + + assert result.is_normalized + assert not rng.is_normalized + + rng = date_range("1/1/2000 9:30", periods=10, freq="D", tz=tzlocal()) + result = rng.normalize() # does not preserve freq + expected = date_range("1/1/2000", periods=10, freq="D", tz=tzlocal()) + tm.assert_index_equal(result, expected._with_freq(None)) + + assert result.is_normalized + assert not rng.is_normalized + + @td.skip_if_windows + @pytest.mark.parametrize( + "timezone", + [ + "US/Pacific", + "US/Eastern", + "UTC", + "Asia/Kolkata", + "Asia/Shanghai", + "Australia/Canberra", + ], + ) + def test_normalize_tz_local(self, timezone): + # GH#13459 + with tm.set_timezone(timezone): + rng = date_range("1/1/2000 9:30", periods=10, freq="D", tz=tzlocal()) + + result = rng.normalize() + expected = date_range("1/1/2000", periods=10, freq="D", tz=tzlocal()) + expected = expected._with_freq(None) + tm.assert_index_equal(result, expected) + + assert result.is_normalized + assert not rng.is_normalized diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_repeat.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_repeat.py new file mode 100644 index 0000000000000000000000000000000000000000..92501755f8c5b3e943864c76a62cd712edc6dd51 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_repeat.py @@ -0,0 +1,83 @@ +import numpy as np +import pytest + +from pandas import ( + DatetimeIndex, + Timestamp, + date_range, +) +import pandas._testing as tm + + +class TestRepeat: + def test_repeat_range(self, tz_naive_fixture): + rng = date_range("1/1/2000", "1/1/2001") + + result = rng.repeat(5) + assert result.freq is None + assert len(result) == 5 * len(rng) + + def test_repeat_range2(self, tz_naive_fixture, unit): + tz = tz_naive_fixture + index = date_range("2001-01-01", periods=2, freq="D", tz=tz, unit=unit) + exp = DatetimeIndex( + ["2001-01-01", "2001-01-01", "2001-01-02", "2001-01-02"], tz=tz + ).as_unit(unit) + for res in [index.repeat(2), np.repeat(index, 2)]: + tm.assert_index_equal(res, exp) + assert res.freq is None + + def test_repeat_range3(self, tz_naive_fixture, unit): + tz = tz_naive_fixture + index = date_range("2001-01-01", periods=2, freq="2D", tz=tz, unit=unit) + exp = DatetimeIndex( + ["2001-01-01", "2001-01-01", "2001-01-03", "2001-01-03"], tz=tz + ).as_unit(unit) + for res in [index.repeat(2), np.repeat(index, 2)]: + tm.assert_index_equal(res, exp) + assert res.freq is None + + def test_repeat_range4(self, tz_naive_fixture, unit): + tz = tz_naive_fixture + index = DatetimeIndex(["2001-01-01", "NaT", "2003-01-01"], tz=tz).as_unit(unit) + exp = DatetimeIndex( + [ + "2001-01-01", + "2001-01-01", + "2001-01-01", + "NaT", + "NaT", + "NaT", + "2003-01-01", + "2003-01-01", + "2003-01-01", + ], + tz=tz, + ).as_unit(unit) + for res in [index.repeat(3), np.repeat(index, 3)]: + tm.assert_index_equal(res, exp) + assert res.freq is None + + def test_repeat(self, tz_naive_fixture, unit): + tz = tz_naive_fixture + reps = 2 + msg = "the 'axis' parameter is not supported" + + rng = date_range(start="2016-01-01", periods=2, freq="30Min", tz=tz, unit=unit) + + expected_rng = DatetimeIndex( + [ + Timestamp("2016-01-01 00:00:00", tz=tz), + Timestamp("2016-01-01 00:00:00", tz=tz), + Timestamp("2016-01-01 00:30:00", tz=tz), + Timestamp("2016-01-01 00:30:00", tz=tz), + ] + ).as_unit(unit) + + res = rng.repeat(reps) + tm.assert_index_equal(res, expected_rng) + assert res.freq is None + + tm.assert_index_equal(np.repeat(rng, reps), expected_rng) + with pytest.raises(ValueError, match=msg): + np.repeat(rng, reps, axis=1) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_resolution.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_resolution.py new file mode 100644 index 0000000000000000000000000000000000000000..8399fafbbaff20463901a8008555492bc8b5c5f5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_resolution.py @@ -0,0 +1,31 @@ +from dateutil.tz import tzlocal +import pytest + +from pandas.compat import IS64 + +from pandas import date_range + + +@pytest.mark.parametrize( + "freq,expected", + [ + ("YE", "day"), + ("QE", "day"), + ("ME", "day"), + ("D", "day"), + ("h", "hour"), + ("min", "minute"), + ("s", "second"), + ("ms", "millisecond"), + ("us", "microsecond"), + ], +) +def test_dti_resolution(request, tz_naive_fixture, freq, expected): + tz = tz_naive_fixture + if freq == "YE" and not IS64 and isinstance(tz, tzlocal): + request.applymarker( + pytest.mark.xfail(reason="OverflowError inside tzlocal past 2038") + ) + + idx = date_range(start="2013-04-01", periods=30, freq=freq, tz=tz) + assert idx.resolution == expected diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_round.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_round.py new file mode 100644 index 0000000000000000000000000000000000000000..cde4a3a65804df514dfa71ce3e724aaee7d413c0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_round.py @@ -0,0 +1,221 @@ +import pytest + +from pandas._libs.tslibs import to_offset +from pandas._libs.tslibs.offsets import INVALID_FREQ_ERR_MSG + +from pandas import ( + DatetimeIndex, + Timestamp, + date_range, +) +import pandas._testing as tm + + +class TestDatetimeIndexRound: + def test_round_daily(self): + dti = date_range("20130101 09:10:11", periods=5) + result = dti.round("D") + expected = date_range("20130101", periods=5) + tm.assert_index_equal(result, expected) + + dti = dti.tz_localize("UTC").tz_convert("US/Eastern") + result = dti.round("D") + expected = date_range("20130101", periods=5).tz_localize("US/Eastern") + tm.assert_index_equal(result, expected) + + result = dti.round("s") + tm.assert_index_equal(result, dti) + + @pytest.mark.parametrize( + "freq, error_msg", + [ + ("YE", " is a non-fixed frequency"), + ("ME", " is a non-fixed frequency"), + ("foobar", "Invalid frequency: foobar"), + ], + ) + def test_round_invalid(self, freq, error_msg): + dti = date_range("20130101 09:10:11", periods=5) + dti = dti.tz_localize("UTC").tz_convert("US/Eastern") + with pytest.raises(ValueError, match=error_msg): + dti.round(freq) + + def test_round(self, tz_naive_fixture, unit): + tz = tz_naive_fixture + rng = date_range(start="2016-01-01", periods=5, freq="30Min", tz=tz, unit=unit) + elt = rng[1] + + expected_rng = DatetimeIndex( + [ + Timestamp("2016-01-01 00:00:00", tz=tz), + Timestamp("2016-01-01 00:00:00", tz=tz), + Timestamp("2016-01-01 01:00:00", tz=tz), + Timestamp("2016-01-01 02:00:00", tz=tz), + Timestamp("2016-01-01 02:00:00", tz=tz), + ] + ).as_unit(unit) + expected_elt = expected_rng[1] + + result = rng.round(freq="h") + tm.assert_index_equal(result, expected_rng) + assert elt.round(freq="h") == expected_elt + + msg = INVALID_FREQ_ERR_MSG + with pytest.raises(ValueError, match=msg): + rng.round(freq="foo") + with pytest.raises(ValueError, match=msg): + elt.round(freq="foo") + + msg = " is a non-fixed frequency" + with pytest.raises(ValueError, match=msg): + rng.round(freq="ME") + with pytest.raises(ValueError, match=msg): + elt.round(freq="ME") + + def test_round2(self, tz_naive_fixture): + tz = tz_naive_fixture + # GH#14440 & GH#15578 + index = DatetimeIndex(["2016-10-17 12:00:00.0015"], tz=tz).as_unit("ns") + result = index.round("ms") + expected = DatetimeIndex(["2016-10-17 12:00:00.002000"], tz=tz).as_unit("ns") + tm.assert_index_equal(result, expected) + + for freq in ["us", "ns"]: + tm.assert_index_equal(index, index.round(freq)) + + def test_round3(self, tz_naive_fixture): + tz = tz_naive_fixture + index = DatetimeIndex(["2016-10-17 12:00:00.00149"], tz=tz).as_unit("ns") + result = index.round("ms") + expected = DatetimeIndex(["2016-10-17 12:00:00.001000"], tz=tz).as_unit("ns") + tm.assert_index_equal(result, expected) + + def test_round4(self, tz_naive_fixture): + index = DatetimeIndex(["2016-10-17 12:00:00.001501031"], dtype="M8[ns]") + result = index.round("10ns") + expected = DatetimeIndex(["2016-10-17 12:00:00.001501030"], dtype="M8[ns]") + tm.assert_index_equal(result, expected) + + ts = "2016-10-17 12:00:00.001501031" + dti = DatetimeIndex([ts], dtype="M8[ns]") + with tm.assert_produces_warning(False): + dti.round("1010ns") + + def test_no_rounding_occurs(self, tz_naive_fixture): + # GH 21262 + tz = tz_naive_fixture + rng = date_range(start="2016-01-01", periods=5, freq="2Min", tz=tz) + + expected_rng = DatetimeIndex( + [ + Timestamp("2016-01-01 00:00:00", tz=tz), + Timestamp("2016-01-01 00:02:00", tz=tz), + Timestamp("2016-01-01 00:04:00", tz=tz), + Timestamp("2016-01-01 00:06:00", tz=tz), + Timestamp("2016-01-01 00:08:00", tz=tz), + ] + ).as_unit("ns") + + result = rng.round(freq="2min") + tm.assert_index_equal(result, expected_rng) + + @pytest.mark.parametrize( + "test_input, rounder, freq, expected", + [ + (["2117-01-01 00:00:45"], "floor", "15s", ["2117-01-01 00:00:45"]), + (["2117-01-01 00:00:45"], "ceil", "15s", ["2117-01-01 00:00:45"]), + ( + ["2117-01-01 00:00:45.000000012"], + "floor", + "10ns", + ["2117-01-01 00:00:45.000000010"], + ), + ( + ["1823-01-01 00:00:01.000000012"], + "ceil", + "10ns", + ["1823-01-01 00:00:01.000000020"], + ), + (["1823-01-01 00:00:01"], "floor", "1s", ["1823-01-01 00:00:01"]), + (["1823-01-01 00:00:01"], "ceil", "1s", ["1823-01-01 00:00:01"]), + (["2018-01-01 00:15:00"], "ceil", "15min", ["2018-01-01 00:15:00"]), + (["2018-01-01 00:15:00"], "floor", "15min", ["2018-01-01 00:15:00"]), + (["1823-01-01 03:00:00"], "ceil", "3h", ["1823-01-01 03:00:00"]), + (["1823-01-01 03:00:00"], "floor", "3h", ["1823-01-01 03:00:00"]), + ( + ("NaT", "1823-01-01 00:00:01"), + "floor", + "1s", + ("NaT", "1823-01-01 00:00:01"), + ), + ( + ("NaT", "1823-01-01 00:00:01"), + "ceil", + "1s", + ("NaT", "1823-01-01 00:00:01"), + ), + ], + ) + def test_ceil_floor_edge(self, test_input, rounder, freq, expected): + dt = DatetimeIndex(list(test_input)) + func = getattr(dt, rounder) + result = func(freq) + expected = DatetimeIndex(list(expected)) + assert expected.equals(result) + + @pytest.mark.parametrize( + "start, index_freq, periods", + [("2018-01-01", "12h", 25), ("2018-01-01 0:0:0.124999", "1ns", 1000)], + ) + @pytest.mark.parametrize( + "round_freq", + [ + "2ns", + "3ns", + "4ns", + "5ns", + "6ns", + "7ns", + "250ns", + "500ns", + "750ns", + "1us", + "19us", + "250us", + "500us", + "750us", + "1s", + "2s", + "3s", + "12h", + "1D", + ], + ) + def test_round_int64(self, start, index_freq, periods, round_freq): + dt = date_range(start=start, freq=index_freq, periods=periods) + unit = to_offset(round_freq).nanos + + # test floor + result = dt.floor(round_freq) + diff = dt.asi8 - result.asi8 + mod = result.asi8 % unit + assert (mod == 0).all(), f"floor not a {round_freq} multiple" + assert (0 <= diff).all() and (diff < unit).all(), "floor error" + + # test ceil + result = dt.ceil(round_freq) + diff = result.asi8 - dt.asi8 + mod = result.asi8 % unit + assert (mod == 0).all(), f"ceil not a {round_freq} multiple" + assert (0 <= diff).all() and (diff < unit).all(), "ceil error" + + # test round + result = dt.round(round_freq) + diff = abs(result.asi8 - dt.asi8) + mod = result.asi8 % unit + assert (mod == 0).all(), f"round not a {round_freq} multiple" + assert (diff <= unit // 2).all(), "round error" + if unit % 2 == 0: + assert ( + result.asi8[diff == unit // 2] % 2 == 0 + ).all(), "round half to even error" diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_shift.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_shift.py new file mode 100644 index 0000000000000000000000000000000000000000..d8bdcc2a176851d92d8bf79bddb2669419e07b76 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_shift.py @@ -0,0 +1,169 @@ +from datetime import datetime + +import pytest +import pytz + +from pandas.errors import NullFrequencyError + +import pandas as pd +from pandas import ( + DatetimeIndex, + Series, + date_range, +) +import pandas._testing as tm + +START, END = datetime(2009, 1, 1), datetime(2010, 1, 1) + + +class TestDatetimeIndexShift: + # ------------------------------------------------------------- + # DatetimeIndex.shift is used in integer addition + + def test_dti_shift_tzaware(self, tz_naive_fixture, unit): + # GH#9903 + tz = tz_naive_fixture + idx = DatetimeIndex([], name="xxx", tz=tz).as_unit(unit) + tm.assert_index_equal(idx.shift(0, freq="h"), idx) + tm.assert_index_equal(idx.shift(3, freq="h"), idx) + + idx = DatetimeIndex( + ["2011-01-01 10:00", "2011-01-01 11:00", "2011-01-01 12:00"], + name="xxx", + tz=tz, + freq="h", + ).as_unit(unit) + tm.assert_index_equal(idx.shift(0, freq="h"), idx) + exp = DatetimeIndex( + ["2011-01-01 13:00", "2011-01-01 14:00", "2011-01-01 15:00"], + name="xxx", + tz=tz, + freq="h", + ).as_unit(unit) + tm.assert_index_equal(idx.shift(3, freq="h"), exp) + exp = DatetimeIndex( + ["2011-01-01 07:00", "2011-01-01 08:00", "2011-01-01 09:00"], + name="xxx", + tz=tz, + freq="h", + ).as_unit(unit) + tm.assert_index_equal(idx.shift(-3, freq="h"), exp) + + def test_dti_shift_freqs(self, unit): + # test shift for DatetimeIndex and non DatetimeIndex + # GH#8083 + drange = date_range("20130101", periods=5, unit=unit) + result = drange.shift(1) + expected = DatetimeIndex( + ["2013-01-02", "2013-01-03", "2013-01-04", "2013-01-05", "2013-01-06"], + dtype=f"M8[{unit}]", + freq="D", + ) + tm.assert_index_equal(result, expected) + + result = drange.shift(-1) + expected = DatetimeIndex( + ["2012-12-31", "2013-01-01", "2013-01-02", "2013-01-03", "2013-01-04"], + dtype=f"M8[{unit}]", + freq="D", + ) + tm.assert_index_equal(result, expected) + + result = drange.shift(3, freq="2D") + expected = DatetimeIndex( + ["2013-01-07", "2013-01-08", "2013-01-09", "2013-01-10", "2013-01-11"], + dtype=f"M8[{unit}]", + freq="D", + ) + tm.assert_index_equal(result, expected) + + def test_dti_shift_int(self, unit): + rng = date_range("1/1/2000", periods=20, unit=unit) + + result = rng + 5 * rng.freq + expected = rng.shift(5) + tm.assert_index_equal(result, expected) + + result = rng - 5 * rng.freq + expected = rng.shift(-5) + tm.assert_index_equal(result, expected) + + def test_dti_shift_no_freq(self, unit): + # GH#19147 + dti = DatetimeIndex(["2011-01-01 10:00", "2011-01-01"], freq=None).as_unit(unit) + with pytest.raises(NullFrequencyError, match="Cannot shift with no freq"): + dti.shift(2) + + @pytest.mark.parametrize("tzstr", ["US/Eastern", "dateutil/US/Eastern"]) + def test_dti_shift_localized(self, tzstr, unit): + dr = date_range("2011/1/1", "2012/1/1", freq="W-FRI", unit=unit) + dr_tz = dr.tz_localize(tzstr) + + result = dr_tz.shift(1, "10min") + assert result.tz == dr_tz.tz + + def test_dti_shift_across_dst(self, unit): + # GH 8616 + idx = date_range( + "2013-11-03", tz="America/Chicago", periods=7, freq="h", unit=unit + ) + ser = Series(index=idx[:-1], dtype=object) + result = ser.shift(freq="h") + expected = Series(index=idx[1:], dtype=object) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "shift, result_time", + [ + [0, "2014-11-14 00:00:00"], + [-1, "2014-11-13 23:00:00"], + [1, "2014-11-14 01:00:00"], + ], + ) + def test_dti_shift_near_midnight(self, shift, result_time, unit): + # GH 8616 + dt = datetime(2014, 11, 14, 0) + dt_est = pytz.timezone("EST").localize(dt) + idx = DatetimeIndex([dt_est]).as_unit(unit) + ser = Series(data=[1], index=idx) + result = ser.shift(shift, freq="h") + exp_index = DatetimeIndex([result_time], tz="EST").as_unit(unit) + expected = Series(1, index=exp_index) + tm.assert_series_equal(result, expected) + + def test_shift_periods(self, unit): + # GH#22458 : argument 'n' was deprecated in favor of 'periods' + idx = date_range(start=START, end=END, periods=3, unit=unit) + tm.assert_index_equal(idx.shift(periods=0), idx) + tm.assert_index_equal(idx.shift(0), idx) + + @pytest.mark.parametrize("freq", ["B", "C"]) + def test_shift_bday(self, freq, unit): + rng = date_range(START, END, freq=freq, unit=unit) + shifted = rng.shift(5) + assert shifted[0] == rng[5] + assert shifted.freq == rng.freq + + shifted = rng.shift(-5) + assert shifted[5] == rng[0] + assert shifted.freq == rng.freq + + shifted = rng.shift(0) + assert shifted[0] == rng[0] + assert shifted.freq == rng.freq + + def test_shift_bmonth(self, unit): + rng = date_range(START, END, freq=pd.offsets.BMonthEnd(), unit=unit) + shifted = rng.shift(1, freq=pd.offsets.BDay()) + assert shifted[0] == rng[0] + pd.offsets.BDay() + + rng = date_range(START, END, freq=pd.offsets.BMonthEnd(), unit=unit) + with tm.assert_produces_warning(pd.errors.PerformanceWarning): + shifted = rng.shift(1, freq=pd.offsets.CDay()) + assert shifted[0] == rng[0] + pd.offsets.CDay() + + def test_shift_empty(self, unit): + # GH#14811 + dti = date_range(start="2016-10-21", end="2016-10-21", freq="BME", unit=unit) + result = dti.shift(1) + tm.assert_index_equal(result, dti) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_snap.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_snap.py new file mode 100644 index 0000000000000000000000000000000000000000..7064e9e7993f8cd14420bb3101c084923c13c4e7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_snap.py @@ -0,0 +1,47 @@ +import pytest + +from pandas import ( + DatetimeIndex, + date_range, +) +import pandas._testing as tm + + +@pytest.mark.parametrize("tz", [None, "Asia/Shanghai", "Europe/Berlin"]) +@pytest.mark.parametrize("name", [None, "my_dti"]) +@pytest.mark.parametrize("unit", ["ns", "us", "ms", "s"]) +def test_dti_snap(name, tz, unit): + dti = DatetimeIndex( + [ + "1/1/2002", + "1/2/2002", + "1/3/2002", + "1/4/2002", + "1/5/2002", + "1/6/2002", + "1/7/2002", + ], + name=name, + tz=tz, + freq="D", + ) + dti = dti.as_unit(unit) + + result = dti.snap(freq="W-MON") + expected = date_range("12/31/2001", "1/7/2002", name=name, tz=tz, freq="w-mon") + expected = expected.repeat([3, 4]) + expected = expected.as_unit(unit) + tm.assert_index_equal(result, expected) + assert result.tz == expected.tz + assert result.freq is None + assert expected.freq is None + + result = dti.snap(freq="B") + + expected = date_range("1/1/2002", "1/7/2002", name=name, tz=tz, freq="b") + expected = expected.repeat([1, 1, 1, 2, 2]) + expected = expected.as_unit(unit) + tm.assert_index_equal(result, expected) + assert result.tz == expected.tz + assert result.freq is None + assert expected.freq is None diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_to_frame.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_to_frame.py new file mode 100644 index 0000000000000000000000000000000000000000..c829109d4e06c14dca160f1de8903432f844f4ef --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_to_frame.py @@ -0,0 +1,28 @@ +from pandas import ( + DataFrame, + Index, + date_range, +) +import pandas._testing as tm + + +class TestToFrame: + def test_to_frame_datetime_tz(self): + # GH#25809 + idx = date_range(start="2019-01-01", end="2019-01-30", freq="D", tz="UTC") + result = idx.to_frame() + expected = DataFrame(idx, index=idx) + tm.assert_frame_equal(result, expected) + + def test_to_frame_respects_none_name(self): + # GH#44212 if we explicitly pass name=None, then that should be respected, + # not changed to 0 + # GH-45448 this is first deprecated to only change in the future + idx = date_range(start="2019-01-01", end="2019-01-30", freq="D", tz="UTC") + result = idx.to_frame(name=None) + exp_idx = Index([None], dtype=object) + tm.assert_index_equal(exp_idx, result.columns) + + result = idx.rename("foo").to_frame(name=None) + exp_idx = Index([None], dtype=object) + tm.assert_index_equal(exp_idx, result.columns) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_to_julian_date.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_to_julian_date.py new file mode 100644 index 0000000000000000000000000000000000000000..fc1f0595c21c527816acedf6ef97839ce7d71713 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_to_julian_date.py @@ -0,0 +1,45 @@ +import numpy as np + +from pandas import ( + Index, + Timestamp, + date_range, +) +import pandas._testing as tm + + +class TestDateTimeIndexToJulianDate: + def test_1700(self): + dr = date_range(start=Timestamp("1710-10-01"), periods=5, freq="D") + r1 = Index([x.to_julian_date() for x in dr]) + r2 = dr.to_julian_date() + assert isinstance(r2, Index) and r2.dtype == np.float64 + tm.assert_index_equal(r1, r2) + + def test_2000(self): + dr = date_range(start=Timestamp("2000-02-27"), periods=5, freq="D") + r1 = Index([x.to_julian_date() for x in dr]) + r2 = dr.to_julian_date() + assert isinstance(r2, Index) and r2.dtype == np.float64 + tm.assert_index_equal(r1, r2) + + def test_hour(self): + dr = date_range(start=Timestamp("2000-02-27"), periods=5, freq="h") + r1 = Index([x.to_julian_date() for x in dr]) + r2 = dr.to_julian_date() + assert isinstance(r2, Index) and r2.dtype == np.float64 + tm.assert_index_equal(r1, r2) + + def test_minute(self): + dr = date_range(start=Timestamp("2000-02-27"), periods=5, freq="min") + r1 = Index([x.to_julian_date() for x in dr]) + r2 = dr.to_julian_date() + assert isinstance(r2, Index) and r2.dtype == np.float64 + tm.assert_index_equal(r1, r2) + + def test_second(self): + dr = date_range(start=Timestamp("2000-02-27"), periods=5, freq="s") + r1 = Index([x.to_julian_date() for x in dr]) + r2 = dr.to_julian_date() + assert isinstance(r2, Index) and r2.dtype == np.float64 + tm.assert_index_equal(r1, r2) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_to_period.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_to_period.py new file mode 100644 index 0000000000000000000000000000000000000000..de8d32f64cde26b2fa0a0720cbdacc56f6c2e983 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_to_period.py @@ -0,0 +1,225 @@ +import dateutil.tz +from dateutil.tz import tzlocal +import pytest +import pytz + +from pandas._libs.tslibs.ccalendar import MONTHS +from pandas._libs.tslibs.offsets import MonthEnd +from pandas._libs.tslibs.period import INVALID_FREQ_ERR_MSG + +from pandas import ( + DatetimeIndex, + Period, + PeriodIndex, + Timestamp, + date_range, + period_range, +) +import pandas._testing as tm + + +class TestToPeriod: + def test_dti_to_period(self): + dti = date_range(start="1/1/2005", end="12/1/2005", freq="ME") + pi1 = dti.to_period() + pi2 = dti.to_period(freq="D") + pi3 = dti.to_period(freq="3D") + + assert pi1[0] == Period("Jan 2005", freq="M") + assert pi2[0] == Period("1/31/2005", freq="D") + assert pi3[0] == Period("1/31/2005", freq="3D") + + assert pi1[-1] == Period("Nov 2005", freq="M") + assert pi2[-1] == Period("11/30/2005", freq="D") + assert pi3[-1], Period("11/30/2005", freq="3D") + + tm.assert_index_equal(pi1, period_range("1/1/2005", "11/1/2005", freq="M")) + tm.assert_index_equal( + pi2, period_range("1/1/2005", "11/1/2005", freq="M").asfreq("D") + ) + tm.assert_index_equal( + pi3, period_range("1/1/2005", "11/1/2005", freq="M").asfreq("3D") + ) + + @pytest.mark.parametrize("month", MONTHS) + def test_to_period_quarterly(self, month): + # make sure we can make the round trip + freq = f"Q-{month}" + rng = period_range("1989Q3", "1991Q3", freq=freq) + stamps = rng.to_timestamp() + result = stamps.to_period(freq) + tm.assert_index_equal(rng, result) + + @pytest.mark.parametrize("off", ["BQE", "QS", "BQS"]) + def test_to_period_quarterlyish(self, off): + rng = date_range("01-Jan-2012", periods=8, freq=off) + prng = rng.to_period() + assert prng.freq == "QE-DEC" + + @pytest.mark.parametrize("off", ["BYE", "YS", "BYS"]) + def test_to_period_annualish(self, off): + rng = date_range("01-Jan-2012", periods=8, freq=off) + prng = rng.to_period() + assert prng.freq == "YE-DEC" + + def test_to_period_monthish(self): + offsets = ["MS", "BME"] + for off in offsets: + rng = date_range("01-Jan-2012", periods=8, freq=off) + prng = rng.to_period() + assert prng.freqstr == "M" + + rng = date_range("01-Jan-2012", periods=8, freq="ME") + prng = rng.to_period() + assert prng.freqstr == "M" + + with pytest.raises(ValueError, match=INVALID_FREQ_ERR_MSG): + date_range("01-Jan-2012", periods=8, freq="EOM") + + @pytest.mark.parametrize( + "freq_offset, freq_period", + [ + ("2ME", "2M"), + (MonthEnd(2), MonthEnd(2)), + ], + ) + def test_dti_to_period_2monthish(self, freq_offset, freq_period): + dti = date_range("2020-01-01", periods=3, freq=freq_offset) + pi = dti.to_period() + + tm.assert_index_equal(pi, period_range("2020-01", "2020-05", freq=freq_period)) + + @pytest.mark.parametrize( + "freq, freq_depr", + [ + ("2ME", "2M"), + ("2QE", "2Q"), + ("2QE-SEP", "2Q-SEP"), + ("1YE", "1Y"), + ("2YE-MAR", "2Y-MAR"), + ("1YE", "1A"), + ("2YE-MAR", "2A-MAR"), + ], + ) + def test_to_period_frequency_M_Q_Y_A_deprecated(self, freq, freq_depr): + # GH#9586 + msg = f"'{freq_depr[1:]}' is deprecated and will be removed " + f"in a future version, please use '{freq[1:]}' instead." + + rng = date_range("01-Jan-2012", periods=8, freq=freq) + prng = rng.to_period() + with tm.assert_produces_warning(FutureWarning, match=msg): + assert prng.freq == freq_depr + + def test_to_period_infer(self): + # https://github.com/pandas-dev/pandas/issues/33358 + rng = date_range( + start="2019-12-22 06:40:00+00:00", + end="2019-12-22 08:45:00+00:00", + freq="5min", + ) + + with tm.assert_produces_warning(UserWarning): + pi1 = rng.to_period("5min") + + with tm.assert_produces_warning(UserWarning): + pi2 = rng.to_period() + + tm.assert_index_equal(pi1, pi2) + + @pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning") + def test_period_dt64_round_trip(self): + dti = date_range("1/1/2000", "1/7/2002", freq="B") + pi = dti.to_period() + tm.assert_index_equal(pi.to_timestamp(), dti) + + dti = date_range("1/1/2000", "1/7/2002", freq="B") + pi = dti.to_period(freq="h") + tm.assert_index_equal(pi.to_timestamp(), dti) + + def test_to_period_millisecond(self): + index = DatetimeIndex( + [ + Timestamp("2007-01-01 10:11:12.123456Z"), + Timestamp("2007-01-01 10:11:13.789123Z"), + ] + ) + + with tm.assert_produces_warning(UserWarning): + # warning that timezone info will be lost + period = index.to_period(freq="ms") + assert 2 == len(period) + assert period[0] == Period("2007-01-01 10:11:12.123Z", "ms") + assert period[1] == Period("2007-01-01 10:11:13.789Z", "ms") + + def test_to_period_microsecond(self): + index = DatetimeIndex( + [ + Timestamp("2007-01-01 10:11:12.123456Z"), + Timestamp("2007-01-01 10:11:13.789123Z"), + ] + ) + + with tm.assert_produces_warning(UserWarning): + # warning that timezone info will be lost + period = index.to_period(freq="us") + assert 2 == len(period) + assert period[0] == Period("2007-01-01 10:11:12.123456Z", "us") + assert period[1] == Period("2007-01-01 10:11:13.789123Z", "us") + + @pytest.mark.parametrize( + "tz", + ["US/Eastern", pytz.utc, tzlocal(), "dateutil/US/Eastern", dateutil.tz.tzutc()], + ) + def test_to_period_tz(self, tz): + ts = date_range("1/1/2000", "2/1/2000", tz=tz) + + with tm.assert_produces_warning(UserWarning): + # GH#21333 warning that timezone info will be lost + # filter warning about freq deprecation + + result = ts.to_period()[0] + expected = ts[0].to_period(ts.freq) + + assert result == expected + + expected = date_range("1/1/2000", "2/1/2000").to_period() + + with tm.assert_produces_warning(UserWarning): + # GH#21333 warning that timezone info will be lost + result = ts.to_period(ts.freq) + + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize("tz", ["Etc/GMT-1", "Etc/GMT+1"]) + def test_to_period_tz_utc_offset_consistency(self, tz): + # GH#22905 + ts = date_range("1/1/2000", "2/1/2000", tz="Etc/GMT-1") + with tm.assert_produces_warning(UserWarning): + result = ts.to_period()[0] + expected = ts[0].to_period(ts.freq) + assert result == expected + + def test_to_period_nofreq(self): + idx = DatetimeIndex(["2000-01-01", "2000-01-02", "2000-01-04"]) + msg = "You must pass a freq argument as current index has none." + with pytest.raises(ValueError, match=msg): + idx.to_period() + + idx = DatetimeIndex(["2000-01-01", "2000-01-02", "2000-01-03"], freq="infer") + assert idx.freqstr == "D" + expected = PeriodIndex(["2000-01-01", "2000-01-02", "2000-01-03"], freq="D") + tm.assert_index_equal(idx.to_period(), expected) + + # GH#7606 + idx = DatetimeIndex(["2000-01-01", "2000-01-02", "2000-01-03"]) + assert idx.freqstr is None + tm.assert_index_equal(idx.to_period(), expected) + + @pytest.mark.parametrize("freq", ["2BMS", "1SME-15"]) + def test_to_period_offsets_not_supported(self, freq): + # GH#56243 + msg = f"{freq[1:]} is not supported as period frequency" + ts = date_range("1/1/2012", periods=4, freq=freq) + with pytest.raises(ValueError, match=msg): + ts.to_period() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_to_pydatetime.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_to_pydatetime.py new file mode 100644 index 0000000000000000000000000000000000000000..fe97ff0cca8ebe6d04ce093077d6ee44d73a7e0b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_to_pydatetime.py @@ -0,0 +1,51 @@ +from datetime import ( + datetime, + timezone, +) + +import dateutil.parser +import dateutil.tz +from dateutil.tz import tzlocal +import numpy as np + +from pandas import ( + DatetimeIndex, + date_range, + to_datetime, +) +import pandas._testing as tm +from pandas.tests.indexes.datetimes.test_timezones import FixedOffset + +fixed_off = FixedOffset(-420, "-07:00") + + +class TestToPyDatetime: + def test_dti_to_pydatetime(self): + dt = dateutil.parser.parse("2012-06-13T01:39:00Z") + dt = dt.replace(tzinfo=tzlocal()) + + arr = np.array([dt], dtype=object) + + result = to_datetime(arr, utc=True) + assert result.tz is timezone.utc + + rng = date_range("2012-11-03 03:00", "2012-11-05 03:00", tz=tzlocal()) + arr = rng.to_pydatetime() + result = to_datetime(arr, utc=True) + assert result.tz is timezone.utc + + def test_dti_to_pydatetime_fizedtz(self): + dates = np.array( + [ + datetime(2000, 1, 1, tzinfo=fixed_off), + datetime(2000, 1, 2, tzinfo=fixed_off), + datetime(2000, 1, 3, tzinfo=fixed_off), + ] + ) + dti = DatetimeIndex(dates) + + result = dti.to_pydatetime() + tm.assert_numpy_array_equal(dates, result) + + result = dti._mpl_repr() + tm.assert_numpy_array_equal(dates, result) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_to_series.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_to_series.py new file mode 100644 index 0000000000000000000000000000000000000000..0c397c8ab2cd310a2d4fdf59992ea4d123370ee0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_to_series.py @@ -0,0 +1,18 @@ +import numpy as np + +from pandas import ( + DatetimeIndex, + Series, +) +import pandas._testing as tm + + +class TestToSeries: + def test_to_series(self): + naive = DatetimeIndex(["2013-1-1 13:00", "2013-1-2 14:00"], name="B") + idx = naive.tz_localize("US/Pacific") + + expected = Series(np.array(idx.tolist(), dtype="object"), name="B") + result = idx.to_series(index=[0, 1]) + assert expected.dtype == idx.dtype + tm.assert_series_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_tz_convert.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_tz_convert.py new file mode 100644 index 0000000000000000000000000000000000000000..b2cf488ac8313c527bd4eb489abc4a11ff820988 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_tz_convert.py @@ -0,0 +1,283 @@ +from datetime import datetime + +import dateutil.tz +from dateutil.tz import gettz +import numpy as np +import pytest +import pytz + +from pandas._libs.tslibs import timezones + +from pandas import ( + DatetimeIndex, + Index, + NaT, + Timestamp, + date_range, + offsets, +) +import pandas._testing as tm + + +class TestTZConvert: + def test_tz_convert_nat(self): + # GH#5546 + dates = [NaT] + idx = DatetimeIndex(dates) + idx = idx.tz_localize("US/Pacific") + tm.assert_index_equal(idx, DatetimeIndex(dates, tz="US/Pacific")) + idx = idx.tz_convert("US/Eastern") + tm.assert_index_equal(idx, DatetimeIndex(dates, tz="US/Eastern")) + idx = idx.tz_convert("UTC") + tm.assert_index_equal(idx, DatetimeIndex(dates, tz="UTC")) + + dates = ["2010-12-01 00:00", "2010-12-02 00:00", NaT] + idx = DatetimeIndex(dates) + idx = idx.tz_localize("US/Pacific") + tm.assert_index_equal(idx, DatetimeIndex(dates, tz="US/Pacific")) + idx = idx.tz_convert("US/Eastern") + expected = ["2010-12-01 03:00", "2010-12-02 03:00", NaT] + tm.assert_index_equal(idx, DatetimeIndex(expected, tz="US/Eastern")) + + idx = idx + offsets.Hour(5) + expected = ["2010-12-01 08:00", "2010-12-02 08:00", NaT] + tm.assert_index_equal(idx, DatetimeIndex(expected, tz="US/Eastern")) + idx = idx.tz_convert("US/Pacific") + expected = ["2010-12-01 05:00", "2010-12-02 05:00", NaT] + tm.assert_index_equal(idx, DatetimeIndex(expected, tz="US/Pacific")) + + idx = idx + np.timedelta64(3, "h") + expected = ["2010-12-01 08:00", "2010-12-02 08:00", NaT] + tm.assert_index_equal(idx, DatetimeIndex(expected, tz="US/Pacific")) + + idx = idx.tz_convert("US/Eastern") + expected = ["2010-12-01 11:00", "2010-12-02 11:00", NaT] + tm.assert_index_equal(idx, DatetimeIndex(expected, tz="US/Eastern")) + + @pytest.mark.parametrize("prefix", ["", "dateutil/"]) + def test_dti_tz_convert_compat_timestamp(self, prefix): + strdates = ["1/1/2012", "3/1/2012", "4/1/2012"] + idx = DatetimeIndex(strdates, tz=prefix + "US/Eastern") + + conv = idx[0].tz_convert(prefix + "US/Pacific") + expected = idx.tz_convert(prefix + "US/Pacific")[0] + + assert conv == expected + + def test_dti_tz_convert_hour_overflow_dst(self): + # Regression test for GH#13306 + + # sorted case US/Eastern -> UTC + ts = ["2008-05-12 09:50:00", "2008-12-12 09:50:35", "2009-05-12 09:50:32"] + tt = DatetimeIndex(ts).tz_localize("US/Eastern") + ut = tt.tz_convert("UTC") + expected = Index([13, 14, 13], dtype=np.int32) + tm.assert_index_equal(ut.hour, expected) + + # sorted case UTC -> US/Eastern + ts = ["2008-05-12 13:50:00", "2008-12-12 14:50:35", "2009-05-12 13:50:32"] + tt = DatetimeIndex(ts).tz_localize("UTC") + ut = tt.tz_convert("US/Eastern") + expected = Index([9, 9, 9], dtype=np.int32) + tm.assert_index_equal(ut.hour, expected) + + # unsorted case US/Eastern -> UTC + ts = ["2008-05-12 09:50:00", "2008-12-12 09:50:35", "2008-05-12 09:50:32"] + tt = DatetimeIndex(ts).tz_localize("US/Eastern") + ut = tt.tz_convert("UTC") + expected = Index([13, 14, 13], dtype=np.int32) + tm.assert_index_equal(ut.hour, expected) + + # unsorted case UTC -> US/Eastern + ts = ["2008-05-12 13:50:00", "2008-12-12 14:50:35", "2008-05-12 13:50:32"] + tt = DatetimeIndex(ts).tz_localize("UTC") + ut = tt.tz_convert("US/Eastern") + expected = Index([9, 9, 9], dtype=np.int32) + tm.assert_index_equal(ut.hour, expected) + + @pytest.mark.parametrize("tz", ["US/Eastern", "dateutil/US/Eastern"]) + def test_dti_tz_convert_hour_overflow_dst_timestamps(self, tz): + # Regression test for GH#13306 + + # sorted case US/Eastern -> UTC + ts = [ + Timestamp("2008-05-12 09:50:00", tz=tz), + Timestamp("2008-12-12 09:50:35", tz=tz), + Timestamp("2009-05-12 09:50:32", tz=tz), + ] + tt = DatetimeIndex(ts) + ut = tt.tz_convert("UTC") + expected = Index([13, 14, 13], dtype=np.int32) + tm.assert_index_equal(ut.hour, expected) + + # sorted case UTC -> US/Eastern + ts = [ + Timestamp("2008-05-12 13:50:00", tz="UTC"), + Timestamp("2008-12-12 14:50:35", tz="UTC"), + Timestamp("2009-05-12 13:50:32", tz="UTC"), + ] + tt = DatetimeIndex(ts) + ut = tt.tz_convert("US/Eastern") + expected = Index([9, 9, 9], dtype=np.int32) + tm.assert_index_equal(ut.hour, expected) + + # unsorted case US/Eastern -> UTC + ts = [ + Timestamp("2008-05-12 09:50:00", tz=tz), + Timestamp("2008-12-12 09:50:35", tz=tz), + Timestamp("2008-05-12 09:50:32", tz=tz), + ] + tt = DatetimeIndex(ts) + ut = tt.tz_convert("UTC") + expected = Index([13, 14, 13], dtype=np.int32) + tm.assert_index_equal(ut.hour, expected) + + # unsorted case UTC -> US/Eastern + ts = [ + Timestamp("2008-05-12 13:50:00", tz="UTC"), + Timestamp("2008-12-12 14:50:35", tz="UTC"), + Timestamp("2008-05-12 13:50:32", tz="UTC"), + ] + tt = DatetimeIndex(ts) + ut = tt.tz_convert("US/Eastern") + expected = Index([9, 9, 9], dtype=np.int32) + tm.assert_index_equal(ut.hour, expected) + + @pytest.mark.parametrize("freq, n", [("h", 1), ("min", 60), ("s", 3600)]) + def test_dti_tz_convert_trans_pos_plus_1__bug(self, freq, n): + # Regression test for tslib.tz_convert(vals, tz1, tz2). + # See GH#4496 for details. + idx = date_range(datetime(2011, 3, 26, 23), datetime(2011, 3, 27, 1), freq=freq) + idx = idx.tz_localize("UTC") + idx = idx.tz_convert("Europe/Moscow") + + expected = np.repeat(np.array([3, 4, 5]), np.array([n, n, 1])) + tm.assert_index_equal(idx.hour, Index(expected, dtype=np.int32)) + + def test_dti_tz_convert_dst(self): + for freq, n in [("h", 1), ("min", 60), ("s", 3600)]: + # Start DST + idx = date_range( + "2014-03-08 23:00", "2014-03-09 09:00", freq=freq, tz="UTC" + ) + idx = idx.tz_convert("US/Eastern") + expected = np.repeat( + np.array([18, 19, 20, 21, 22, 23, 0, 1, 3, 4, 5]), + np.array([n, n, n, n, n, n, n, n, n, n, 1]), + ) + tm.assert_index_equal(idx.hour, Index(expected, dtype=np.int32)) + + idx = date_range( + "2014-03-08 18:00", "2014-03-09 05:00", freq=freq, tz="US/Eastern" + ) + idx = idx.tz_convert("UTC") + expected = np.repeat( + np.array([23, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), + np.array([n, n, n, n, n, n, n, n, n, n, 1]), + ) + tm.assert_index_equal(idx.hour, Index(expected, dtype=np.int32)) + + # End DST + idx = date_range( + "2014-11-01 23:00", "2014-11-02 09:00", freq=freq, tz="UTC" + ) + idx = idx.tz_convert("US/Eastern") + expected = np.repeat( + np.array([19, 20, 21, 22, 23, 0, 1, 1, 2, 3, 4]), + np.array([n, n, n, n, n, n, n, n, n, n, 1]), + ) + tm.assert_index_equal(idx.hour, Index(expected, dtype=np.int32)) + + idx = date_range( + "2014-11-01 18:00", "2014-11-02 05:00", freq=freq, tz="US/Eastern" + ) + idx = idx.tz_convert("UTC") + expected = np.repeat( + np.array([22, 23, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), + np.array([n, n, n, n, n, n, n, n, n, n, n, n, 1]), + ) + tm.assert_index_equal(idx.hour, Index(expected, dtype=np.int32)) + + # daily + # Start DST + idx = date_range("2014-03-08 00:00", "2014-03-09 00:00", freq="D", tz="UTC") + idx = idx.tz_convert("US/Eastern") + tm.assert_index_equal(idx.hour, Index([19, 19], dtype=np.int32)) + + idx = date_range( + "2014-03-08 00:00", "2014-03-09 00:00", freq="D", tz="US/Eastern" + ) + idx = idx.tz_convert("UTC") + tm.assert_index_equal(idx.hour, Index([5, 5], dtype=np.int32)) + + # End DST + idx = date_range("2014-11-01 00:00", "2014-11-02 00:00", freq="D", tz="UTC") + idx = idx.tz_convert("US/Eastern") + tm.assert_index_equal(idx.hour, Index([20, 20], dtype=np.int32)) + + idx = date_range( + "2014-11-01 00:00", "2014-11-02 000:00", freq="D", tz="US/Eastern" + ) + idx = idx.tz_convert("UTC") + tm.assert_index_equal(idx.hour, Index([4, 4], dtype=np.int32)) + + def test_tz_convert_roundtrip(self, tz_aware_fixture): + tz = tz_aware_fixture + idx1 = date_range(start="2014-01-01", end="2014-12-31", freq="ME", tz="UTC") + exp1 = date_range(start="2014-01-01", end="2014-12-31", freq="ME") + + idx2 = date_range(start="2014-01-01", end="2014-12-31", freq="D", tz="UTC") + exp2 = date_range(start="2014-01-01", end="2014-12-31", freq="D") + + idx3 = date_range(start="2014-01-01", end="2014-03-01", freq="h", tz="UTC") + exp3 = date_range(start="2014-01-01", end="2014-03-01", freq="h") + + idx4 = date_range(start="2014-08-01", end="2014-10-31", freq="min", tz="UTC") + exp4 = date_range(start="2014-08-01", end="2014-10-31", freq="min") + + for idx, expected in [(idx1, exp1), (idx2, exp2), (idx3, exp3), (idx4, exp4)]: + converted = idx.tz_convert(tz) + reset = converted.tz_convert(None) + tm.assert_index_equal(reset, expected) + assert reset.tzinfo is None + expected = converted.tz_convert("UTC").tz_localize(None) + expected = expected._with_freq("infer") + tm.assert_index_equal(reset, expected) + + def test_dti_tz_convert_tzlocal(self): + # GH#13583 + # tz_convert doesn't affect to internal + dti = date_range(start="2001-01-01", end="2001-03-01", tz="UTC") + dti2 = dti.tz_convert(dateutil.tz.tzlocal()) + tm.assert_numpy_array_equal(dti2.asi8, dti.asi8) + + dti = date_range(start="2001-01-01", end="2001-03-01", tz=dateutil.tz.tzlocal()) + dti2 = dti.tz_convert(None) + tm.assert_numpy_array_equal(dti2.asi8, dti.asi8) + + @pytest.mark.parametrize( + "tz", + [ + "US/Eastern", + "dateutil/US/Eastern", + pytz.timezone("US/Eastern"), + gettz("US/Eastern"), + ], + ) + def test_dti_tz_convert_utc_to_local_no_modify(self, tz): + rng = date_range("3/11/2012", "3/12/2012", freq="h", tz="utc") + rng_eastern = rng.tz_convert(tz) + + # Values are unmodified + tm.assert_numpy_array_equal(rng.asi8, rng_eastern.asi8) + + assert timezones.tz_compare(rng_eastern.tz, timezones.maybe_get_tz(tz)) + + @pytest.mark.parametrize("tzstr", ["US/Eastern", "dateutil/US/Eastern"]) + def test_tz_convert_unsorted(self, tzstr): + dr = date_range("2012-03-09", freq="h", periods=100, tz="utc") + dr = dr.tz_convert(tzstr) + + result = dr[::-1].hour + exp = dr.hour[::-1] + tm.assert_almost_equal(result, exp) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_tz_localize.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_tz_localize.py new file mode 100644 index 0000000000000000000000000000000000000000..ad7769c6b96714b30fe4f3a1e1468de05ec1e6f2 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_tz_localize.py @@ -0,0 +1,402 @@ +from datetime import ( + datetime, + timedelta, +) + +import dateutil.tz +from dateutil.tz import gettz +import numpy as np +import pytest +import pytz + +from pandas import ( + DatetimeIndex, + Timestamp, + bdate_range, + date_range, + offsets, + to_datetime, +) +import pandas._testing as tm + +try: + from zoneinfo import ZoneInfo +except ImportError: + # Cannot assign to a type [misc] + ZoneInfo = None # type: ignore[misc, assignment] + + +easts = [pytz.timezone("US/Eastern"), gettz("US/Eastern")] +if ZoneInfo is not None: + try: + tz = ZoneInfo("US/Eastern") + except KeyError: + # no tzdata + pass + else: + easts.append(tz) + + +class TestTZLocalize: + def test_tz_localize_invalidates_freq(self): + # we only preserve freq in unambiguous cases + + # if localized to US/Eastern, this crosses a DST transition + dti = date_range("2014-03-08 23:00", "2014-03-09 09:00", freq="h") + assert dti.freq == "h" + + result = dti.tz_localize(None) # no-op + assert result.freq == "h" + + result = dti.tz_localize("UTC") # unambiguous freq preservation + assert result.freq == "h" + + result = dti.tz_localize("US/Eastern", nonexistent="shift_forward") + assert result.freq is None + assert result.inferred_freq is None # i.e. we are not _too_ strict here + + # Case where we _can_ keep freq because we're length==1 + dti2 = dti[:1] + result = dti2.tz_localize("US/Eastern") + assert result.freq == "h" + + def test_tz_localize_utc_copies(self, utc_fixture): + # GH#46460 + times = ["2015-03-08 01:00", "2015-03-08 02:00", "2015-03-08 03:00"] + index = DatetimeIndex(times) + + res = index.tz_localize(utc_fixture) + assert not tm.shares_memory(res, index) + + res2 = index._data.tz_localize(utc_fixture) + assert not tm.shares_memory(index._data, res2) + + def test_dti_tz_localize_nonexistent_raise_coerce(self): + # GH#13057 + times = ["2015-03-08 01:00", "2015-03-08 02:00", "2015-03-08 03:00"] + index = DatetimeIndex(times) + tz = "US/Eastern" + with pytest.raises(pytz.NonExistentTimeError, match="|".join(times)): + index.tz_localize(tz=tz) + + with pytest.raises(pytz.NonExistentTimeError, match="|".join(times)): + index.tz_localize(tz=tz, nonexistent="raise") + + result = index.tz_localize(tz=tz, nonexistent="NaT") + test_times = ["2015-03-08 01:00-05:00", "NaT", "2015-03-08 03:00-04:00"] + dti = to_datetime(test_times, utc=True) + expected = dti.tz_convert("US/Eastern") + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize("tz", easts) + def test_dti_tz_localize_ambiguous_infer(self, tz): + # November 6, 2011, fall back, repeat 2 AM hour + # With no repeated hours, we cannot infer the transition + dr = date_range(datetime(2011, 11, 6, 0), periods=5, freq=offsets.Hour()) + with pytest.raises(pytz.AmbiguousTimeError, match="Cannot infer dst time"): + dr.tz_localize(tz) + + @pytest.mark.parametrize("tz", easts) + def test_dti_tz_localize_ambiguous_infer2(self, tz, unit): + # With repeated hours, we can infer the transition + dr = date_range( + datetime(2011, 11, 6, 0), periods=5, freq=offsets.Hour(), tz=tz, unit=unit + ) + times = [ + "11/06/2011 00:00", + "11/06/2011 01:00", + "11/06/2011 01:00", + "11/06/2011 02:00", + "11/06/2011 03:00", + ] + di = DatetimeIndex(times).as_unit(unit) + result = di.tz_localize(tz, ambiguous="infer") + expected = dr._with_freq(None) + tm.assert_index_equal(result, expected) + result2 = DatetimeIndex(times, tz=tz, ambiguous="infer").as_unit(unit) + tm.assert_index_equal(result2, expected) + + @pytest.mark.parametrize("tz", easts) + def test_dti_tz_localize_ambiguous_infer3(self, tz): + # When there is no dst transition, nothing special happens + dr = date_range(datetime(2011, 6, 1, 0), periods=10, freq=offsets.Hour()) + localized = dr.tz_localize(tz) + localized_infer = dr.tz_localize(tz, ambiguous="infer") + tm.assert_index_equal(localized, localized_infer) + + @pytest.mark.parametrize("tz", easts) + def test_dti_tz_localize_ambiguous_times(self, tz): + # March 13, 2011, spring forward, skip from 2 AM to 3 AM + dr = date_range(datetime(2011, 3, 13, 1, 30), periods=3, freq=offsets.Hour()) + with pytest.raises(pytz.NonExistentTimeError, match="2011-03-13 02:30:00"): + dr.tz_localize(tz) + + # after dst transition, it works + dr = date_range( + datetime(2011, 3, 13, 3, 30), periods=3, freq=offsets.Hour(), tz=tz + ) + + # November 6, 2011, fall back, repeat 2 AM hour + dr = date_range(datetime(2011, 11, 6, 1, 30), periods=3, freq=offsets.Hour()) + with pytest.raises(pytz.AmbiguousTimeError, match="Cannot infer dst time"): + dr.tz_localize(tz) + + # UTC is OK + dr = date_range( + datetime(2011, 3, 13), periods=48, freq=offsets.Minute(30), tz=pytz.utc + ) + + @pytest.mark.parametrize("tzstr", ["US/Eastern", "dateutil/US/Eastern"]) + def test_dti_tz_localize_pass_dates_to_utc(self, tzstr): + strdates = ["1/1/2012", "3/1/2012", "4/1/2012"] + + idx = DatetimeIndex(strdates) + conv = idx.tz_localize(tzstr) + + fromdates = DatetimeIndex(strdates, tz=tzstr) + + assert conv.tz == fromdates.tz + tm.assert_numpy_array_equal(conv.values, fromdates.values) + + @pytest.mark.parametrize("prefix", ["", "dateutil/"]) + def test_dti_tz_localize(self, prefix): + tzstr = prefix + "US/Eastern" + dti = date_range(start="1/1/2005", end="1/1/2005 0:00:30.256", freq="ms") + dti2 = dti.tz_localize(tzstr) + + dti_utc = date_range( + start="1/1/2005 05:00", end="1/1/2005 5:00:30.256", freq="ms", tz="utc" + ) + + tm.assert_numpy_array_equal(dti2.values, dti_utc.values) + + dti3 = dti2.tz_convert(prefix + "US/Pacific") + tm.assert_numpy_array_equal(dti3.values, dti_utc.values) + + dti = date_range(start="11/6/2011 1:59", end="11/6/2011 2:00", freq="ms") + with pytest.raises(pytz.AmbiguousTimeError, match="Cannot infer dst time"): + dti.tz_localize(tzstr) + + dti = date_range(start="3/13/2011 1:59", end="3/13/2011 2:00", freq="ms") + with pytest.raises(pytz.NonExistentTimeError, match="2011-03-13 02:00:00"): + dti.tz_localize(tzstr) + + @pytest.mark.parametrize( + "tz", + [ + "US/Eastern", + "dateutil/US/Eastern", + pytz.timezone("US/Eastern"), + gettz("US/Eastern"), + ], + ) + def test_dti_tz_localize_utc_conversion(self, tz): + # Localizing to time zone should: + # 1) check for DST ambiguities + # 2) convert to UTC + + rng = date_range("3/10/2012", "3/11/2012", freq="30min") + + converted = rng.tz_localize(tz) + expected_naive = rng + offsets.Hour(5) + tm.assert_numpy_array_equal(converted.asi8, expected_naive.asi8) + + # DST ambiguity, this should fail + rng = date_range("3/11/2012", "3/12/2012", freq="30min") + # Is this really how it should fail?? + with pytest.raises(pytz.NonExistentTimeError, match="2012-03-11 02:00:00"): + rng.tz_localize(tz) + + def test_dti_tz_localize_roundtrip(self, tz_aware_fixture): + # note: this tz tests that a tz-naive index can be localized + # and de-localized successfully, when there are no DST transitions + # in the range. + idx = date_range(start="2014-06-01", end="2014-08-30", freq="15min") + tz = tz_aware_fixture + localized = idx.tz_localize(tz) + # can't localize a tz-aware object + with pytest.raises( + TypeError, match="Already tz-aware, use tz_convert to convert" + ): + localized.tz_localize(tz) + reset = localized.tz_localize(None) + assert reset.tzinfo is None + expected = idx._with_freq(None) + tm.assert_index_equal(reset, expected) + + def test_dti_tz_localize_naive(self): + rng = date_range("1/1/2011", periods=100, freq="h") + + conv = rng.tz_localize("US/Pacific") + exp = date_range("1/1/2011", periods=100, freq="h", tz="US/Pacific") + + tm.assert_index_equal(conv, exp._with_freq(None)) + + def test_dti_tz_localize_tzlocal(self): + # GH#13583 + offset = dateutil.tz.tzlocal().utcoffset(datetime(2011, 1, 1)) + offset = int(offset.total_seconds() * 1000000000) + + dti = date_range(start="2001-01-01", end="2001-03-01") + dti2 = dti.tz_localize(dateutil.tz.tzlocal()) + tm.assert_numpy_array_equal(dti2.asi8 + offset, dti.asi8) + + dti = date_range(start="2001-01-01", end="2001-03-01", tz=dateutil.tz.tzlocal()) + dti2 = dti.tz_localize(None) + tm.assert_numpy_array_equal(dti2.asi8 - offset, dti.asi8) + + @pytest.mark.parametrize("tz", easts) + def test_dti_tz_localize_ambiguous_nat(self, tz): + times = [ + "11/06/2011 00:00", + "11/06/2011 01:00", + "11/06/2011 01:00", + "11/06/2011 02:00", + "11/06/2011 03:00", + ] + di = DatetimeIndex(times) + localized = di.tz_localize(tz, ambiguous="NaT") + + times = [ + "11/06/2011 00:00", + np.nan, + np.nan, + "11/06/2011 02:00", + "11/06/2011 03:00", + ] + di_test = DatetimeIndex(times, tz="US/Eastern") + + # left dtype is datetime64[ns, US/Eastern] + # right is datetime64[ns, tzfile('/usr/share/zoneinfo/US/Eastern')] + tm.assert_numpy_array_equal(di_test.values, localized.values) + + @pytest.mark.parametrize("tz", easts) + def test_dti_tz_localize_ambiguous_flags(self, tz, unit): + # November 6, 2011, fall back, repeat 2 AM hour + + # Pass in flags to determine right dst transition + dr = date_range( + datetime(2011, 11, 6, 0), periods=5, freq=offsets.Hour(), tz=tz, unit=unit + ) + times = [ + "11/06/2011 00:00", + "11/06/2011 01:00", + "11/06/2011 01:00", + "11/06/2011 02:00", + "11/06/2011 03:00", + ] + + # Test tz_localize + di = DatetimeIndex(times).as_unit(unit) + is_dst = [1, 1, 0, 0, 0] + localized = di.tz_localize(tz, ambiguous=is_dst) + expected = dr._with_freq(None) + tm.assert_index_equal(expected, localized) + + result = DatetimeIndex(times, tz=tz, ambiguous=is_dst).as_unit(unit) + tm.assert_index_equal(result, expected) + + localized = di.tz_localize(tz, ambiguous=np.array(is_dst)) + tm.assert_index_equal(dr, localized) + + localized = di.tz_localize(tz, ambiguous=np.array(is_dst).astype("bool")) + tm.assert_index_equal(dr, localized) + + # Test constructor + localized = DatetimeIndex(times, tz=tz, ambiguous=is_dst).as_unit(unit) + tm.assert_index_equal(dr, localized) + + # Test duplicate times where inferring the dst fails + times += times + di = DatetimeIndex(times).as_unit(unit) + + # When the sizes are incompatible, make sure error is raised + msg = "Length of ambiguous bool-array must be the same size as vals" + with pytest.raises(Exception, match=msg): + di.tz_localize(tz, ambiguous=is_dst) + + # When sizes are compatible and there are repeats ('infer' won't work) + is_dst = np.hstack((is_dst, is_dst)) + localized = di.tz_localize(tz, ambiguous=is_dst) + dr = dr.append(dr) + tm.assert_index_equal(dr, localized) + + @pytest.mark.parametrize("tz", easts) + def test_dti_tz_localize_ambiguous_flags2(self, tz, unit): + # When there is no dst transition, nothing special happens + dr = date_range(datetime(2011, 6, 1, 0), periods=10, freq=offsets.Hour()) + is_dst = np.array([1] * 10) + localized = dr.tz_localize(tz) + localized_is_dst = dr.tz_localize(tz, ambiguous=is_dst) + tm.assert_index_equal(localized, localized_is_dst) + + def test_dti_tz_localize_bdate_range(self): + dr = bdate_range("1/1/2009", "1/1/2010") + dr_utc = bdate_range("1/1/2009", "1/1/2010", tz=pytz.utc) + localized = dr.tz_localize(pytz.utc) + tm.assert_index_equal(dr_utc, localized) + + @pytest.mark.parametrize( + "start_ts, tz, end_ts, shift", + [ + ["2015-03-29 02:20:00", "Europe/Warsaw", "2015-03-29 03:00:00", "forward"], + [ + "2015-03-29 02:20:00", + "Europe/Warsaw", + "2015-03-29 01:59:59.999999999", + "backward", + ], + [ + "2015-03-29 02:20:00", + "Europe/Warsaw", + "2015-03-29 03:20:00", + timedelta(hours=1), + ], + [ + "2015-03-29 02:20:00", + "Europe/Warsaw", + "2015-03-29 01:20:00", + timedelta(hours=-1), + ], + ["2018-03-11 02:33:00", "US/Pacific", "2018-03-11 03:00:00", "forward"], + [ + "2018-03-11 02:33:00", + "US/Pacific", + "2018-03-11 01:59:59.999999999", + "backward", + ], + [ + "2018-03-11 02:33:00", + "US/Pacific", + "2018-03-11 03:33:00", + timedelta(hours=1), + ], + [ + "2018-03-11 02:33:00", + "US/Pacific", + "2018-03-11 01:33:00", + timedelta(hours=-1), + ], + ], + ) + @pytest.mark.parametrize("tz_type", ["", "dateutil/"]) + def test_dti_tz_localize_nonexistent_shift( + self, start_ts, tz, end_ts, shift, tz_type, unit + ): + # GH#8917 + tz = tz_type + tz + if isinstance(shift, str): + shift = "shift_" + shift + dti = DatetimeIndex([Timestamp(start_ts)]).as_unit(unit) + result = dti.tz_localize(tz, nonexistent=shift) + expected = DatetimeIndex([Timestamp(end_ts)]).tz_localize(tz).as_unit(unit) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize("offset", [-1, 1]) + def test_dti_tz_localize_nonexistent_shift_invalid(self, offset, warsaw): + # GH#8917 + tz = warsaw + dti = DatetimeIndex([Timestamp("2015-03-29 02:20:00")]) + msg = "The provided timedelta will relocalize on a nonexistent time" + with pytest.raises(ValueError, match=msg): + dti.tz_localize(tz, nonexistent=timedelta(seconds=offset)) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_unique.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_unique.py new file mode 100644 index 0000000000000000000000000000000000000000..3c419b23c749a16e66458b334b3aec34521c2241 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/datetimes/methods/test_unique.py @@ -0,0 +1,77 @@ +from datetime import ( + datetime, + timedelta, +) + +from pandas import ( + DatetimeIndex, + NaT, + Timestamp, +) +import pandas._testing as tm + + +def test_unique(tz_naive_fixture): + idx = DatetimeIndex(["2017"] * 2, tz=tz_naive_fixture) + expected = idx[:1] + + result = idx.unique() + tm.assert_index_equal(result, expected) + # GH#21737 + # Ensure the underlying data is consistent + assert result[0] == expected[0] + + +def test_index_unique(rand_series_with_duplicate_datetimeindex): + dups = rand_series_with_duplicate_datetimeindex + index = dups.index + + uniques = index.unique() + expected = DatetimeIndex( + [ + datetime(2000, 1, 2), + datetime(2000, 1, 3), + datetime(2000, 1, 4), + datetime(2000, 1, 5), + ], + dtype=index.dtype, + ) + assert uniques.dtype == index.dtype # sanity + tm.assert_index_equal(uniques, expected) + assert index.nunique() == 4 + + # GH#2563 + assert isinstance(uniques, DatetimeIndex) + + dups_local = index.tz_localize("US/Eastern") + dups_local.name = "foo" + result = dups_local.unique() + expected = DatetimeIndex(expected, name="foo") + expected = expected.tz_localize("US/Eastern") + assert result.tz is not None + assert result.name == "foo" + tm.assert_index_equal(result, expected) + + +def test_index_unique2(): + # NaT, note this is excluded + arr = [1370745748 + t for t in range(20)] + [NaT._value] + idx = DatetimeIndex(arr * 3) + tm.assert_index_equal(idx.unique(), DatetimeIndex(arr)) + assert idx.nunique() == 20 + assert idx.nunique(dropna=False) == 21 + + +def test_index_unique3(): + arr = [ + Timestamp("2013-06-09 02:42:28") + timedelta(seconds=t) for t in range(20) + ] + [NaT] + idx = DatetimeIndex(arr * 3) + tm.assert_index_equal(idx.unique(), DatetimeIndex(arr)) + assert idx.nunique() == 20 + assert idx.nunique(dropna=False) == 21 + + +def test_is_unique_monotonic(rand_series_with_duplicate_datetimeindex): + index = rand_series_with_duplicate_datetimeindex.index + assert not index.is_unique diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/object/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/object/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4cededd8cf6917a43c9b2ba1a2965d1a6a8834f9 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/object/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/object/__pycache__/test_astype.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/object/__pycache__/test_astype.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a71cf2c7134554dc2f7b6e79c1a6ca6487c27dd0 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/object/__pycache__/test_astype.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/object/__pycache__/test_indexing.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/object/__pycache__/test_indexing.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..332bfe5edaee4ebdbdebf334d7dd3597422a9137 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/object/__pycache__/test_indexing.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/period/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/period/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/period/test_constructors.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/period/test_constructors.py new file mode 100644 index 0000000000000000000000000000000000000000..892eb7b4a00d1ffbd9477194466bf9f2a2c522ff --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/period/test_constructors.py @@ -0,0 +1,691 @@ +import numpy as np +import pytest + +from pandas._libs.tslibs.period import IncompatibleFrequency + +from pandas.core.dtypes.dtypes import PeriodDtype + +from pandas import ( + Index, + NaT, + Period, + PeriodIndex, + Series, + date_range, + offsets, + period_range, +) +import pandas._testing as tm +from pandas.core.arrays import PeriodArray + + +class TestPeriodIndexDisallowedFreqs: + @pytest.mark.parametrize( + "freq,freq_depr", + [ + ("2M", "2ME"), + ("2Q-MAR", "2QE-MAR"), + ("2Y-FEB", "2YE-FEB"), + ("2M", "2me"), + ("2Q-MAR", "2qe-MAR"), + ("2Y-FEB", "2yE-feb"), + ], + ) + def test_period_index_offsets_frequency_error_message(self, freq, freq_depr): + # GH#52064 + msg = f"for Period, please use '{freq[1:]}' instead of '{freq_depr[1:]}'" + + with pytest.raises(ValueError, match=msg): + PeriodIndex(["2020-01-01", "2020-01-02"], freq=freq_depr) + + with pytest.raises(ValueError, match=msg): + period_range(start="2020-01-01", end="2020-01-02", freq=freq_depr) + + @pytest.mark.parametrize("freq_depr", ["2SME", "2sme", "2CBME", "2BYE", "2Bye"]) + def test_period_index_frequency_invalid_freq(self, freq_depr): + # GH#9586 + msg = f"Invalid frequency: {freq_depr[1:]}" + + with pytest.raises(ValueError, match=msg): + period_range("2020-01", "2020-05", freq=freq_depr) + with pytest.raises(ValueError, match=msg): + PeriodIndex(["2020-01", "2020-05"], freq=freq_depr) + + @pytest.mark.parametrize("freq", ["2BQE-SEP", "2BYE-MAR", "2BME"]) + def test_period_index_from_datetime_index_invalid_freq(self, freq): + # GH#56899 + msg = f"Invalid frequency: {freq[1:]}" + + rng = date_range("01-Jan-2012", periods=8, freq=freq) + with pytest.raises(ValueError, match=msg): + rng.to_period() + + +class TestPeriodIndex: + def test_from_ordinals(self): + Period(ordinal=-1000, freq="Y") + Period(ordinal=0, freq="Y") + + msg = "The 'ordinal' keyword in PeriodIndex is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + idx1 = PeriodIndex(ordinal=[-1, 0, 1], freq="Y") + with tm.assert_produces_warning(FutureWarning, match=msg): + idx2 = PeriodIndex(ordinal=np.array([-1, 0, 1]), freq="Y") + tm.assert_index_equal(idx1, idx2) + + alt1 = PeriodIndex.from_ordinals([-1, 0, 1], freq="Y") + tm.assert_index_equal(alt1, idx1) + + alt2 = PeriodIndex.from_ordinals(np.array([-1, 0, 1]), freq="Y") + tm.assert_index_equal(alt2, idx2) + + def test_keyword_mismatch(self): + # GH#55961 we should get exactly one of data/ordinals/**fields + per = Period("2016-01-01", "D") + depr_msg1 = "The 'ordinal' keyword in PeriodIndex is deprecated" + depr_msg2 = "Constructing PeriodIndex from fields is deprecated" + + err_msg1 = "Cannot pass both data and ordinal" + with pytest.raises(ValueError, match=err_msg1): + with tm.assert_produces_warning(FutureWarning, match=depr_msg1): + PeriodIndex(data=[per], ordinal=[per.ordinal], freq=per.freq) + + err_msg2 = "Cannot pass both data and fields" + with pytest.raises(ValueError, match=err_msg2): + with tm.assert_produces_warning(FutureWarning, match=depr_msg2): + PeriodIndex(data=[per], year=[per.year], freq=per.freq) + + err_msg3 = "Cannot pass both ordinal and fields" + with pytest.raises(ValueError, match=err_msg3): + with tm.assert_produces_warning(FutureWarning, match=depr_msg2): + PeriodIndex(ordinal=[per.ordinal], year=[per.year], freq=per.freq) + + def test_construction_base_constructor(self): + # GH 13664 + arr = [Period("2011-01", freq="M"), NaT, Period("2011-03", freq="M")] + tm.assert_index_equal(Index(arr), PeriodIndex(arr)) + tm.assert_index_equal(Index(np.array(arr)), PeriodIndex(np.array(arr))) + + arr = [np.nan, NaT, Period("2011-03", freq="M")] + tm.assert_index_equal(Index(arr), PeriodIndex(arr)) + tm.assert_index_equal(Index(np.array(arr)), PeriodIndex(np.array(arr))) + + arr = [Period("2011-01", freq="M"), NaT, Period("2011-03", freq="D")] + tm.assert_index_equal(Index(arr), Index(arr, dtype=object)) + + tm.assert_index_equal(Index(np.array(arr)), Index(np.array(arr), dtype=object)) + + def test_base_constructor_with_period_dtype(self): + dtype = PeriodDtype("D") + values = ["2011-01-01", "2012-03-04", "2014-05-01"] + result = Index(values, dtype=dtype) + + expected = PeriodIndex(values, dtype=dtype) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "values_constructor", [list, np.array, PeriodIndex, PeriodArray._from_sequence] + ) + def test_index_object_dtype(self, values_constructor): + # Index(periods, dtype=object) is an Index (not an PeriodIndex) + periods = [ + Period("2011-01", freq="M"), + NaT, + Period("2011-03", freq="M"), + ] + values = values_constructor(periods) + result = Index(values, dtype=object) + + assert type(result) is Index + tm.assert_numpy_array_equal(result.values, np.array(values)) + + def test_constructor_use_start_freq(self): + # GH #1118 + msg1 = "Period with BDay freq is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg1): + p = Period("4/2/2012", freq="B") + msg2 = r"PeriodDtype\[B\] is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg2): + expected = period_range(start="4/2/2012", periods=10, freq="B") + + with tm.assert_produces_warning(FutureWarning, match=msg2): + index = period_range(start=p, periods=10) + tm.assert_index_equal(index, expected) + + def test_constructor_field_arrays(self): + # GH #1264 + + years = np.arange(1990, 2010).repeat(4)[2:-2] + quarters = np.tile(np.arange(1, 5), 20)[2:-2] + + depr_msg = "Constructing PeriodIndex from fields is deprecated" + with tm.assert_produces_warning(FutureWarning, match=depr_msg): + index = PeriodIndex(year=years, quarter=quarters, freq="Q-DEC") + expected = period_range("1990Q3", "2009Q2", freq="Q-DEC") + tm.assert_index_equal(index, expected) + + with tm.assert_produces_warning(FutureWarning, match=depr_msg): + index2 = PeriodIndex(year=years, quarter=quarters, freq="2Q-DEC") + tm.assert_numpy_array_equal(index.asi8, index2.asi8) + + with tm.assert_produces_warning(FutureWarning, match=depr_msg): + index = PeriodIndex(year=years, quarter=quarters) + tm.assert_index_equal(index, expected) + + years = [2007, 2007, 2007] + months = [1, 2] + + msg = "Mismatched Period array lengths" + with pytest.raises(ValueError, match=msg): + with tm.assert_produces_warning(FutureWarning, match=depr_msg): + PeriodIndex(year=years, month=months, freq="M") + with pytest.raises(ValueError, match=msg): + with tm.assert_produces_warning(FutureWarning, match=depr_msg): + PeriodIndex(year=years, month=months, freq="2M") + + years = [2007, 2007, 2007] + months = [1, 2, 3] + with tm.assert_produces_warning(FutureWarning, match=depr_msg): + idx = PeriodIndex(year=years, month=months, freq="M") + exp = period_range("2007-01", periods=3, freq="M") + tm.assert_index_equal(idx, exp) + + def test_constructor_nano(self): + idx = period_range( + start=Period(ordinal=1, freq="ns"), + end=Period(ordinal=4, freq="ns"), + freq="ns", + ) + exp = PeriodIndex( + [ + Period(ordinal=1, freq="ns"), + Period(ordinal=2, freq="ns"), + Period(ordinal=3, freq="ns"), + Period(ordinal=4, freq="ns"), + ], + freq="ns", + ) + tm.assert_index_equal(idx, exp) + + def test_constructor_arrays_negative_year(self): + years = np.arange(1960, 2000, dtype=np.int64).repeat(4) + quarters = np.tile(np.array([1, 2, 3, 4], dtype=np.int64), 40) + + msg = "Constructing PeriodIndex from fields is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + pindex = PeriodIndex(year=years, quarter=quarters) + + tm.assert_index_equal(pindex.year, Index(years)) + tm.assert_index_equal(pindex.quarter, Index(quarters)) + + alt = PeriodIndex.from_fields(year=years, quarter=quarters) + tm.assert_index_equal(alt, pindex) + + def test_constructor_invalid_quarters(self): + depr_msg = "Constructing PeriodIndex from fields is deprecated" + msg = "Quarter must be 1 <= q <= 4" + with pytest.raises(ValueError, match=msg): + with tm.assert_produces_warning(FutureWarning, match=depr_msg): + PeriodIndex( + year=range(2000, 2004), quarter=list(range(4)), freq="Q-DEC" + ) + + def test_period_range_fractional_period(self): + msg = "Non-integer 'periods' in pd.date_range, pd.timedelta_range" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = period_range("2007-01", periods=10.5, freq="M") + exp = period_range("2007-01", periods=10, freq="M") + tm.assert_index_equal(result, exp) + + def test_constructor_with_without_freq(self): + # GH53687 + start = Period("2002-01-01 00:00", freq="30min") + exp = period_range(start=start, periods=5, freq=start.freq) + result = period_range(start=start, periods=5) + tm.assert_index_equal(exp, result) + + def test_constructor_fromarraylike(self): + idx = period_range("2007-01", periods=20, freq="M") + + # values is an array of Period, thus can retrieve freq + tm.assert_index_equal(PeriodIndex(idx.values), idx) + tm.assert_index_equal(PeriodIndex(list(idx.values)), idx) + + msg = "freq not specified and cannot be inferred" + with pytest.raises(ValueError, match=msg): + PeriodIndex(idx.asi8) + with pytest.raises(ValueError, match=msg): + PeriodIndex(list(idx.asi8)) + + msg = "'Period' object is not iterable" + with pytest.raises(TypeError, match=msg): + PeriodIndex(data=Period("2007", freq="Y")) + + result = PeriodIndex(iter(idx)) + tm.assert_index_equal(result, idx) + + result = PeriodIndex(idx) + tm.assert_index_equal(result, idx) + + result = PeriodIndex(idx, freq="M") + tm.assert_index_equal(result, idx) + + result = PeriodIndex(idx, freq=offsets.MonthEnd()) + tm.assert_index_equal(result, idx) + assert result.freq == "ME" + + result = PeriodIndex(idx, freq="2M") + tm.assert_index_equal(result, idx.asfreq("2M")) + assert result.freq == "2ME" + + result = PeriodIndex(idx, freq=offsets.MonthEnd(2)) + tm.assert_index_equal(result, idx.asfreq("2M")) + assert result.freq == "2ME" + + result = PeriodIndex(idx, freq="D") + exp = idx.asfreq("D", "e") + tm.assert_index_equal(result, exp) + + def test_constructor_datetime64arr(self): + vals = np.arange(100000, 100000 + 10000, 100, dtype=np.int64) + vals = vals.view(np.dtype("M8[us]")) + + pi = PeriodIndex(vals, freq="D") + + expected = PeriodIndex(vals.astype("M8[ns]"), freq="D") + tm.assert_index_equal(pi, expected) + + @pytest.mark.parametrize("box", [None, "series", "index"]) + def test_constructor_datetime64arr_ok(self, box): + # https://github.com/pandas-dev/pandas/issues/23438 + data = date_range("2017", periods=4, freq="ME") + if box is None: + data = data._values + elif box == "series": + data = Series(data) + + result = PeriodIndex(data, freq="D") + expected = PeriodIndex( + ["2017-01-31", "2017-02-28", "2017-03-31", "2017-04-30"], freq="D" + ) + tm.assert_index_equal(result, expected) + + def test_constructor_dtype(self): + # passing a dtype with a tz should localize + idx = PeriodIndex(["2013-01", "2013-03"], dtype="period[M]") + exp = PeriodIndex(["2013-01", "2013-03"], freq="M") + tm.assert_index_equal(idx, exp) + assert idx.dtype == "period[M]" + + idx = PeriodIndex(["2013-01-05", "2013-03-05"], dtype="period[3D]") + exp = PeriodIndex(["2013-01-05", "2013-03-05"], freq="3D") + tm.assert_index_equal(idx, exp) + assert idx.dtype == "period[3D]" + + # if we already have a freq and its not the same, then asfreq + # (not changed) + idx = PeriodIndex(["2013-01-01", "2013-01-02"], freq="D") + + res = PeriodIndex(idx, dtype="period[M]") + exp = PeriodIndex(["2013-01", "2013-01"], freq="M") + tm.assert_index_equal(res, exp) + assert res.dtype == "period[M]" + + res = PeriodIndex(idx, freq="M") + tm.assert_index_equal(res, exp) + assert res.dtype == "period[M]" + + msg = "specified freq and dtype are different" + with pytest.raises(IncompatibleFrequency, match=msg): + PeriodIndex(["2011-01"], freq="M", dtype="period[D]") + + def test_constructor_empty(self): + idx = PeriodIndex([], freq="M") + assert isinstance(idx, PeriodIndex) + assert len(idx) == 0 + assert idx.freq == "ME" + + with pytest.raises(ValueError, match="freq not specified"): + PeriodIndex([]) + + def test_constructor_pi_nat(self): + idx = PeriodIndex( + [Period("2011-01", freq="M"), NaT, Period("2011-01", freq="M")] + ) + exp = PeriodIndex(["2011-01", "NaT", "2011-01"], freq="M") + tm.assert_index_equal(idx, exp) + + idx = PeriodIndex( + np.array([Period("2011-01", freq="M"), NaT, Period("2011-01", freq="M")]) + ) + tm.assert_index_equal(idx, exp) + + idx = PeriodIndex( + [NaT, NaT, Period("2011-01", freq="M"), Period("2011-01", freq="M")] + ) + exp = PeriodIndex(["NaT", "NaT", "2011-01", "2011-01"], freq="M") + tm.assert_index_equal(idx, exp) + + idx = PeriodIndex( + np.array( + [NaT, NaT, Period("2011-01", freq="M"), Period("2011-01", freq="M")] + ) + ) + tm.assert_index_equal(idx, exp) + + idx = PeriodIndex([NaT, NaT, "2011-01", "2011-01"], freq="M") + tm.assert_index_equal(idx, exp) + + with pytest.raises(ValueError, match="freq not specified"): + PeriodIndex([NaT, NaT]) + + with pytest.raises(ValueError, match="freq not specified"): + PeriodIndex(np.array([NaT, NaT])) + + with pytest.raises(ValueError, match="freq not specified"): + PeriodIndex(["NaT", "NaT"]) + + with pytest.raises(ValueError, match="freq not specified"): + PeriodIndex(np.array(["NaT", "NaT"])) + + def test_constructor_incompat_freq(self): + msg = "Input has different freq=D from PeriodIndex\\(freq=M\\)" + + with pytest.raises(IncompatibleFrequency, match=msg): + PeriodIndex([Period("2011-01", freq="M"), NaT, Period("2011-01", freq="D")]) + + with pytest.raises(IncompatibleFrequency, match=msg): + PeriodIndex( + np.array( + [Period("2011-01", freq="M"), NaT, Period("2011-01", freq="D")] + ) + ) + + # first element is NaT + with pytest.raises(IncompatibleFrequency, match=msg): + PeriodIndex([NaT, Period("2011-01", freq="M"), Period("2011-01", freq="D")]) + + with pytest.raises(IncompatibleFrequency, match=msg): + PeriodIndex( + np.array( + [NaT, Period("2011-01", freq="M"), Period("2011-01", freq="D")] + ) + ) + + def test_constructor_mixed(self): + idx = PeriodIndex(["2011-01", NaT, Period("2011-01", freq="M")]) + exp = PeriodIndex(["2011-01", "NaT", "2011-01"], freq="M") + tm.assert_index_equal(idx, exp) + + idx = PeriodIndex(["NaT", NaT, Period("2011-01", freq="M")]) + exp = PeriodIndex(["NaT", "NaT", "2011-01"], freq="M") + tm.assert_index_equal(idx, exp) + + idx = PeriodIndex([Period("2011-01-01", freq="D"), NaT, "2012-01-01"]) + exp = PeriodIndex(["2011-01-01", "NaT", "2012-01-01"], freq="D") + tm.assert_index_equal(idx, exp) + + @pytest.mark.parametrize("floats", [[1.1, 2.1], np.array([1.1, 2.1])]) + def test_constructor_floats(self, floats): + msg = "PeriodIndex does not allow floating point in construction" + with pytest.raises(TypeError, match=msg): + PeriodIndex(floats) + + def test_constructor_year_and_quarter(self): + year = Series([2001, 2002, 2003]) + quarter = year - 2000 + msg = "Constructing PeriodIndex from fields is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + idx = PeriodIndex(year=year, quarter=quarter) + strs = [f"{t[0]:d}Q{t[1]:d}" for t in zip(quarter, year)] + lops = list(map(Period, strs)) + p = PeriodIndex(lops) + tm.assert_index_equal(p, idx) + + def test_constructor_freq_mult(self): + # GH #7811 + pidx = period_range(start="2014-01", freq="2M", periods=4) + expected = PeriodIndex(["2014-01", "2014-03", "2014-05", "2014-07"], freq="2M") + tm.assert_index_equal(pidx, expected) + + pidx = period_range(start="2014-01-02", end="2014-01-15", freq="3D") + expected = PeriodIndex( + ["2014-01-02", "2014-01-05", "2014-01-08", "2014-01-11", "2014-01-14"], + freq="3D", + ) + tm.assert_index_equal(pidx, expected) + + pidx = period_range(end="2014-01-01 17:00", freq="4h", periods=3) + expected = PeriodIndex( + ["2014-01-01 09:00", "2014-01-01 13:00", "2014-01-01 17:00"], freq="4h" + ) + tm.assert_index_equal(pidx, expected) + + msg = "Frequency must be positive, because it represents span: -1M" + with pytest.raises(ValueError, match=msg): + PeriodIndex(["2011-01"], freq="-1M") + + msg = "Frequency must be positive, because it represents span: 0M" + with pytest.raises(ValueError, match=msg): + PeriodIndex(["2011-01"], freq="0M") + + msg = "Frequency must be positive, because it represents span: 0M" + with pytest.raises(ValueError, match=msg): + period_range("2011-01", periods=3, freq="0M") + + @pytest.mark.parametrize( + "freq_offset, freq_period", + [ + ("YE", "Y"), + ("ME", "M"), + ("D", "D"), + ("min", "min"), + ("s", "s"), + ], + ) + @pytest.mark.parametrize("mult", [1, 2, 3, 4, 5]) + def test_constructor_freq_mult_dti_compat(self, mult, freq_offset, freq_period): + freqstr_offset = str(mult) + freq_offset + freqstr_period = str(mult) + freq_period + pidx = period_range(start="2014-04-01", freq=freqstr_period, periods=10) + expected = date_range( + start="2014-04-01", freq=freqstr_offset, periods=10 + ).to_period(freqstr_period) + tm.assert_index_equal(pidx, expected) + + @pytest.mark.parametrize("mult", [1, 2, 3, 4, 5]) + def test_constructor_freq_mult_dti_compat_month(self, mult): + pidx = period_range(start="2014-04-01", freq=f"{mult}M", periods=10) + expected = date_range( + start="2014-04-01", freq=f"{mult}ME", periods=10 + ).to_period(f"{mult}M") + tm.assert_index_equal(pidx, expected) + + def test_constructor_freq_combined(self): + for freq in ["1D1h", "1h1D"]: + pidx = PeriodIndex(["2016-01-01", "2016-01-02"], freq=freq) + expected = PeriodIndex(["2016-01-01 00:00", "2016-01-02 00:00"], freq="25h") + for freq in ["1D1h", "1h1D"]: + pidx = period_range(start="2016-01-01", periods=2, freq=freq) + expected = PeriodIndex(["2016-01-01 00:00", "2016-01-02 01:00"], freq="25h") + tm.assert_index_equal(pidx, expected) + + def test_period_range_length(self): + pi = period_range(freq="Y", start="1/1/2001", end="12/1/2009") + assert len(pi) == 9 + + pi = period_range(freq="Q", start="1/1/2001", end="12/1/2009") + assert len(pi) == 4 * 9 + + pi = period_range(freq="M", start="1/1/2001", end="12/1/2009") + assert len(pi) == 12 * 9 + + pi = period_range(freq="D", start="1/1/2001", end="12/31/2009") + assert len(pi) == 365 * 9 + 2 + + msg = "Period with BDay freq is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + pi = period_range(freq="B", start="1/1/2001", end="12/31/2009") + assert len(pi) == 261 * 9 + + pi = period_range(freq="h", start="1/1/2001", end="12/31/2001 23:00") + assert len(pi) == 365 * 24 + + pi = period_range(freq="Min", start="1/1/2001", end="1/1/2001 23:59") + assert len(pi) == 24 * 60 + + pi = period_range(freq="s", start="1/1/2001", end="1/1/2001 23:59:59") + assert len(pi) == 24 * 60 * 60 + + with tm.assert_produces_warning(FutureWarning, match=msg): + start = Period("02-Apr-2005", "B") + i1 = period_range(start=start, periods=20) + assert len(i1) == 20 + assert i1.freq == start.freq + assert i1[0] == start + + end_intv = Period("2006-12-31", "W") + i1 = period_range(end=end_intv, periods=10) + assert len(i1) == 10 + assert i1.freq == end_intv.freq + assert i1[-1] == end_intv + + msg = "'w' is deprecated and will be removed in a future version." + with tm.assert_produces_warning(FutureWarning, match=msg): + end_intv = Period("2006-12-31", "1w") + i2 = period_range(end=end_intv, periods=10) + assert len(i1) == len(i2) + assert (i1 == i2).all() + assert i1.freq == i2.freq + + def test_infer_freq_from_first_element(self): + msg = "Period with BDay freq is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + start = Period("02-Apr-2005", "B") + end_intv = Period("2005-05-01", "B") + period_range(start=start, end=end_intv) + + # infer freq from first element + i2 = PeriodIndex([end_intv, Period("2005-05-05", "B")]) + assert len(i2) == 2 + assert i2[0] == end_intv + + with tm.assert_produces_warning(FutureWarning, match=msg): + i2 = PeriodIndex(np.array([end_intv, Period("2005-05-05", "B")])) + assert len(i2) == 2 + assert i2[0] == end_intv + + def test_mixed_freq_raises(self): + # Mixed freq should fail + msg = "Period with BDay freq is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + end_intv = Period("2005-05-01", "B") + + msg = "'w' is deprecated and will be removed in a future version." + with tm.assert_produces_warning(FutureWarning, match=msg): + vals = [end_intv, Period("2006-12-31", "w")] + msg = r"Input has different freq=W-SUN from PeriodIndex\(freq=B\)" + depr_msg = r"PeriodDtype\[B\] is deprecated" + with pytest.raises(IncompatibleFrequency, match=msg): + with tm.assert_produces_warning(FutureWarning, match=depr_msg): + PeriodIndex(vals) + vals = np.array(vals) + with pytest.raises(IncompatibleFrequency, match=msg): + with tm.assert_produces_warning(FutureWarning, match=depr_msg): + PeriodIndex(vals) + + @pytest.mark.parametrize( + "freq", ["M", "Q", "Y", "D", "B", "min", "s", "ms", "us", "ns", "h"] + ) + @pytest.mark.filterwarnings( + r"ignore:Period with BDay freq is deprecated:FutureWarning" + ) + @pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning") + def test_recreate_from_data(self, freq): + org = period_range(start="2001/04/01", freq=freq, periods=1) + idx = PeriodIndex(org.values, freq=freq) + tm.assert_index_equal(idx, org) + + def test_map_with_string_constructor(self): + raw = [2005, 2007, 2009] + index = PeriodIndex(raw, freq="Y") + + expected = Index([str(num) for num in raw]) + res = index.map(str) + + # should return an Index + assert isinstance(res, Index) + + # preserve element types + assert all(isinstance(resi, str) for resi in res) + + # lastly, values should compare equal + tm.assert_index_equal(res, expected) + + +class TestSimpleNew: + def test_constructor_simple_new(self): + idx = period_range("2007-01", name="p", periods=2, freq="M") + + with pytest.raises(AssertionError, match=""): + idx._simple_new(idx, name="p") + + result = idx._simple_new(idx._data, name="p") + tm.assert_index_equal(result, idx) + + msg = "Should be numpy array of type i8" + with pytest.raises(AssertionError, match=msg): + # Need ndarray, not int64 Index + type(idx._data)._simple_new(Index(idx.asi8), dtype=idx.dtype) + + arr = type(idx._data)._simple_new(idx.asi8, dtype=idx.dtype) + result = idx._simple_new(arr, name="p") + tm.assert_index_equal(result, idx) + + def test_constructor_simple_new_empty(self): + # GH13079 + idx = PeriodIndex([], freq="M", name="p") + with pytest.raises(AssertionError, match=""): + idx._simple_new(idx, name="p") + + result = idx._simple_new(idx._data, name="p") + tm.assert_index_equal(result, idx) + + @pytest.mark.parametrize("floats", [[1.1, 2.1], np.array([1.1, 2.1])]) + def test_period_index_simple_new_disallows_floats(self, floats): + with pytest.raises(AssertionError, match="= -1" + ) + with pytest.raises(ValueError, match=msg): + idx.take(np.array([1, 0, -2]), fill_value=True) + with pytest.raises(ValueError, match=msg): + idx.take(np.array([1, 0, -5]), fill_value=True) + + msg = "index -5 is out of bounds for( axis 0 with)? size 3" + with pytest.raises(IndexError, match=msg): + idx.take(np.array([1, -5])) + + +class TestGetValue: + @pytest.mark.parametrize("freq", ["h", "D"]) + def test_get_value_datetime_hourly(self, freq): + # get_loc and get_value should treat datetime objects symmetrically + # TODO: this test used to test get_value, which is removed in 2.0. + # should this test be moved somewhere, or is what's left redundant? + dti = date_range("2016-01-01", periods=3, freq="MS") + pi = dti.to_period(freq) + ser = Series(range(7, 10), index=pi) + + ts = dti[0] + + assert pi.get_loc(ts) == 0 + assert ser[ts] == 7 + assert ser.loc[ts] == 7 + + ts2 = ts + Timedelta(hours=3) + if freq == "h": + with pytest.raises(KeyError, match="2016-01-01 03:00"): + pi.get_loc(ts2) + with pytest.raises(KeyError, match="2016-01-01 03:00"): + ser[ts2] + with pytest.raises(KeyError, match="2016-01-01 03:00"): + ser.loc[ts2] + else: + assert pi.get_loc(ts2) == 0 + assert ser[ts2] == 7 + assert ser.loc[ts2] == 7 + + +class TestContains: + def test_contains(self): + # GH 17717 + p0 = Period("2017-09-01") + p1 = Period("2017-09-02") + p2 = Period("2017-09-03") + p3 = Period("2017-09-04") + + ps0 = [p0, p1, p2] + idx0 = PeriodIndex(ps0) + + for p in ps0: + assert p in idx0 + assert str(p) in idx0 + + # GH#31172 + # Higher-resolution period-like are _not_ considered as contained + key = "2017-09-01 00:00:01" + assert key not in idx0 + with pytest.raises(KeyError, match=key): + idx0.get_loc(key) + + assert "2017-09" in idx0 + + assert p3 not in idx0 + + def test_contains_freq_mismatch(self): + rng = period_range("2007-01", freq="M", periods=10) + + assert Period("2007-01", freq="M") in rng + assert Period("2007-01", freq="D") not in rng + assert Period("2007-01", freq="2M") not in rng + + def test_contains_nat(self): + # see gh-13582 + idx = period_range("2007-01", freq="M", periods=10) + assert NaT not in idx + assert None not in idx + assert float("nan") not in idx + assert np.nan not in idx + + idx = PeriodIndex(["2011-01", "NaT", "2011-02"], freq="M") + assert NaT in idx + assert None in idx + assert float("nan") in idx + assert np.nan in idx + + +class TestAsOfLocs: + def test_asof_locs_mismatched_type(self): + dti = date_range("2016-01-01", periods=3) + pi = dti.to_period("D") + pi2 = dti.to_period("h") + + mask = np.array([0, 1, 0], dtype=bool) + + msg = "must be DatetimeIndex or PeriodIndex" + with pytest.raises(TypeError, match=msg): + pi.asof_locs(pd.Index(pi.asi8, dtype=np.int64), mask) + + with pytest.raises(TypeError, match=msg): + pi.asof_locs(pd.Index(pi.asi8, dtype=np.float64), mask) + + with pytest.raises(TypeError, match=msg): + # TimedeltaIndex + pi.asof_locs(dti - dti, mask) + + msg = "Input has different freq=h" + with pytest.raises(libperiod.IncompatibleFrequency, match=msg): + pi.asof_locs(pi2, mask) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/period/test_join.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/period/test_join.py new file mode 100644 index 0000000000000000000000000000000000000000..3e659c1a632669c2b89d7ea0411de5c4c35108ad --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/period/test_join.py @@ -0,0 +1,58 @@ +import numpy as np +import pytest + +from pandas._libs.tslibs import IncompatibleFrequency + +from pandas import ( + DataFrame, + Index, + PeriodIndex, + date_range, + period_range, +) +import pandas._testing as tm + + +class TestJoin: + def test_join_outer_indexer(self): + pi = period_range("1/1/2000", "1/20/2000", freq="D") + + result = pi._outer_indexer(pi) + tm.assert_extension_array_equal(result[0], pi._values) + tm.assert_numpy_array_equal(result[1], np.arange(len(pi), dtype=np.intp)) + tm.assert_numpy_array_equal(result[2], np.arange(len(pi), dtype=np.intp)) + + def test_joins(self, join_type): + index = period_range("1/1/2000", "1/20/2000", freq="D") + + joined = index.join(index[:-5], how=join_type) + + assert isinstance(joined, PeriodIndex) + assert joined.freq == index.freq + + def test_join_self(self, join_type): + index = period_range("1/1/2000", "1/20/2000", freq="D") + + res = index.join(index, how=join_type) + assert index is res + + def test_join_does_not_recur(self): + df = DataFrame( + np.ones((3, 2)), + index=date_range("2020-01-01", periods=3), + columns=period_range("2020-01-01", periods=2), + ) + ser = df.iloc[:2, 0] + + res = ser.index.join(df.columns, how="outer") + expected = Index( + [ser.index[0], ser.index[1], df.columns[0], df.columns[1]], object + ) + tm.assert_index_equal(res, expected) + + def test_join_mismatched_freq_raises(self): + index = period_range("1/1/2000", "1/20/2000", freq="D") + index3 = period_range("1/1/2000", "1/20/2000", freq="2D") + msg = r".*Input has different freq=2D from Period\(freq=D\)" + with pytest.raises(IncompatibleFrequency, match=msg): + index.join(index3) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/period/test_monotonic.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/period/test_monotonic.py new file mode 100644 index 0000000000000000000000000000000000000000..15cb8f71cdcf3221800e6dca43390ae79114a9df --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/period/test_monotonic.py @@ -0,0 +1,42 @@ +from pandas import ( + Period, + PeriodIndex, +) + + +def test_is_monotonic_increasing(): + # GH#17717 + p0 = Period("2017-09-01") + p1 = Period("2017-09-02") + p2 = Period("2017-09-03") + + idx_inc0 = PeriodIndex([p0, p1, p2]) + idx_inc1 = PeriodIndex([p0, p1, p1]) + idx_dec0 = PeriodIndex([p2, p1, p0]) + idx_dec1 = PeriodIndex([p2, p1, p1]) + idx = PeriodIndex([p1, p2, p0]) + + assert idx_inc0.is_monotonic_increasing is True + assert idx_inc1.is_monotonic_increasing is True + assert idx_dec0.is_monotonic_increasing is False + assert idx_dec1.is_monotonic_increasing is False + assert idx.is_monotonic_increasing is False + + +def test_is_monotonic_decreasing(): + # GH#17717 + p0 = Period("2017-09-01") + p1 = Period("2017-09-02") + p2 = Period("2017-09-03") + + idx_inc0 = PeriodIndex([p0, p1, p2]) + idx_inc1 = PeriodIndex([p0, p1, p1]) + idx_dec0 = PeriodIndex([p2, p1, p0]) + idx_dec1 = PeriodIndex([p2, p1, p1]) + idx = PeriodIndex([p1, p2, p0]) + + assert idx_inc0.is_monotonic_decreasing is False + assert idx_inc1.is_monotonic_decreasing is False + assert idx_dec0.is_monotonic_decreasing is True + assert idx_dec1.is_monotonic_decreasing is True + assert idx.is_monotonic_decreasing is False diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/period/test_partial_slicing.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/period/test_partial_slicing.py new file mode 100644 index 0000000000000000000000000000000000000000..4fab12f195dc03d43e952d5ee424955330933c0a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/period/test_partial_slicing.py @@ -0,0 +1,198 @@ +import numpy as np +import pytest + +from pandas import ( + DataFrame, + PeriodIndex, + Series, + date_range, + period_range, +) +import pandas._testing as tm + + +class TestPeriodIndex: + def test_getitem_periodindex_duplicates_string_slice( + self, using_copy_on_write, warn_copy_on_write + ): + # monotonic + idx = PeriodIndex([2000, 2007, 2007, 2009, 2009], freq="Y-JUN") + ts = Series(np.random.default_rng(2).standard_normal(len(idx)), index=idx) + original = ts.copy() + + result = ts["2007"] + expected = ts[1:3] + tm.assert_series_equal(result, expected) + with tm.assert_cow_warning(warn_copy_on_write): + result[:] = 1 + if using_copy_on_write: + tm.assert_series_equal(ts, original) + else: + assert (ts[1:3] == 1).all() + + # not monotonic + idx = PeriodIndex([2000, 2007, 2007, 2009, 2007], freq="Y-JUN") + ts = Series(np.random.default_rng(2).standard_normal(len(idx)), index=idx) + + result = ts["2007"] + expected = ts[idx == "2007"] + tm.assert_series_equal(result, expected) + + def test_getitem_periodindex_quarter_string(self): + pi = PeriodIndex(["2Q05", "3Q05", "4Q05", "1Q06", "2Q06"], freq="Q") + ser = Series(np.random.default_rng(2).random(len(pi)), index=pi).cumsum() + # Todo: fix these accessors! + assert ser["05Q4"] == ser.iloc[2] + + def test_pindex_slice_index(self): + pi = period_range(start="1/1/10", end="12/31/12", freq="M") + s = Series(np.random.default_rng(2).random(len(pi)), index=pi) + res = s["2010"] + exp = s[0:12] + tm.assert_series_equal(res, exp) + res = s["2011"] + exp = s[12:24] + tm.assert_series_equal(res, exp) + + @pytest.mark.parametrize("make_range", [date_range, period_range]) + def test_range_slice_day(self, make_range): + # GH#6716 + idx = make_range(start="2013/01/01", freq="D", periods=400) + + msg = "slice indices must be integers or None or have an __index__ method" + # slices against index should raise IndexError + values = [ + "2014", + "2013/02", + "2013/01/02", + "2013/02/01 9H", + "2013/02/01 09:00", + ] + for v in values: + with pytest.raises(TypeError, match=msg): + idx[v:] + + s = Series(np.random.default_rng(2).random(len(idx)), index=idx) + + tm.assert_series_equal(s["2013/01/02":], s[1:]) + tm.assert_series_equal(s["2013/01/02":"2013/01/05"], s[1:5]) + tm.assert_series_equal(s["2013/02":], s[31:]) + tm.assert_series_equal(s["2014":], s[365:]) + + invalid = ["2013/02/01 9H", "2013/02/01 09:00"] + for v in invalid: + with pytest.raises(TypeError, match=msg): + idx[v:] + + @pytest.mark.parametrize("make_range", [date_range, period_range]) + def test_range_slice_seconds(self, make_range): + # GH#6716 + idx = make_range(start="2013/01/01 09:00:00", freq="s", periods=4000) + msg = "slice indices must be integers or None or have an __index__ method" + + # slices against index should raise IndexError + values = [ + "2014", + "2013/02", + "2013/01/02", + "2013/02/01 9H", + "2013/02/01 09:00", + ] + for v in values: + with pytest.raises(TypeError, match=msg): + idx[v:] + + s = Series(np.random.default_rng(2).random(len(idx)), index=idx) + + tm.assert_series_equal(s["2013/01/01 09:05":"2013/01/01 09:10"], s[300:660]) + tm.assert_series_equal(s["2013/01/01 10:00":"2013/01/01 10:05"], s[3600:3960]) + tm.assert_series_equal(s["2013/01/01 10H":], s[3600:]) + tm.assert_series_equal(s[:"2013/01/01 09:30"], s[:1860]) + for d in ["2013/01/01", "2013/01", "2013"]: + tm.assert_series_equal(s[d:], s) + + @pytest.mark.parametrize("make_range", [date_range, period_range]) + def test_range_slice_outofbounds(self, make_range): + # GH#5407 + idx = make_range(start="2013/10/01", freq="D", periods=10) + + df = DataFrame({"units": [100 + i for i in range(10)]}, index=idx) + empty = DataFrame(index=idx[:0], columns=["units"]) + empty["units"] = empty["units"].astype("int64") + + tm.assert_frame_equal(df["2013/09/01":"2013/09/30"], empty) + tm.assert_frame_equal(df["2013/09/30":"2013/10/02"], df.iloc[:2]) + tm.assert_frame_equal(df["2013/10/01":"2013/10/02"], df.iloc[:2]) + tm.assert_frame_equal(df["2013/10/02":"2013/09/30"], empty) + tm.assert_frame_equal(df["2013/10/15":"2013/10/17"], empty) + tm.assert_frame_equal(df["2013-06":"2013-09"], empty) + tm.assert_frame_equal(df["2013-11":"2013-12"], empty) + + @pytest.mark.parametrize("make_range", [date_range, period_range]) + def test_maybe_cast_slice_bound(self, make_range, frame_or_series): + idx = make_range(start="2013/10/01", freq="D", periods=10) + + obj = DataFrame({"units": [100 + i for i in range(10)]}, index=idx) + obj = tm.get_obj(obj, frame_or_series) + + msg = ( + f"cannot do slice indexing on {type(idx).__name__} with " + r"these indexers \[foo\] of type str" + ) + + # Check the lower-level calls are raising where expected. + with pytest.raises(TypeError, match=msg): + idx._maybe_cast_slice_bound("foo", "left") + with pytest.raises(TypeError, match=msg): + idx.get_slice_bound("foo", "left") + + with pytest.raises(TypeError, match=msg): + obj["2013/09/30":"foo"] + with pytest.raises(TypeError, match=msg): + obj["foo":"2013/09/30"] + with pytest.raises(TypeError, match=msg): + obj.loc["2013/09/30":"foo"] + with pytest.raises(TypeError, match=msg): + obj.loc["foo":"2013/09/30"] + + def test_partial_slice_doesnt_require_monotonicity(self): + # See also: DatetimeIndex test ofm the same name + dti = date_range("2014-01-01", periods=30, freq="30D") + pi = dti.to_period("D") + + ser_montonic = Series(np.arange(30), index=pi) + + shuffler = list(range(0, 30, 2)) + list(range(1, 31, 2)) + ser = ser_montonic.iloc[shuffler] + nidx = ser.index + + # Manually identified locations of year==2014 + indexer_2014 = np.array( + [0, 1, 2, 3, 4, 5, 6, 15, 16, 17, 18, 19, 20], dtype=np.intp + ) + assert (nidx[indexer_2014].year == 2014).all() + assert not (nidx[~indexer_2014].year == 2014).any() + + result = nidx.get_loc("2014") + tm.assert_numpy_array_equal(result, indexer_2014) + + expected = ser.iloc[indexer_2014] + result = ser.loc["2014"] + tm.assert_series_equal(result, expected) + + result = ser["2014"] + tm.assert_series_equal(result, expected) + + # Manually identified locations where ser.index is within Mat 2015 + indexer_may2015 = np.array([23], dtype=np.intp) + assert nidx[23].year == 2015 and nidx[23].month == 5 + + result = nidx.get_loc("May 2015") + tm.assert_numpy_array_equal(result, indexer_may2015) + + expected = ser.iloc[indexer_may2015] + result = ser.loc["May 2015"] + tm.assert_series_equal(result, expected) + + result = ser["May 2015"] + tm.assert_series_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/period/test_period.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/period/test_period.py new file mode 100644 index 0000000000000000000000000000000000000000..77b8e76894647f25ea94f8bf1dce460d0b2a165f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/period/test_period.py @@ -0,0 +1,231 @@ +import numpy as np +import pytest + +from pandas import ( + Index, + NaT, + Period, + PeriodIndex, + Series, + date_range, + offsets, + period_range, +) +import pandas._testing as tm + + +class TestPeriodIndex: + def test_view_asi8(self): + idx = PeriodIndex([], freq="M") + + exp = np.array([], dtype=np.int64) + tm.assert_numpy_array_equal(idx.view("i8"), exp) + tm.assert_numpy_array_equal(idx.asi8, exp) + + idx = PeriodIndex(["2011-01", NaT], freq="M") + + exp = np.array([492, -9223372036854775808], dtype=np.int64) + tm.assert_numpy_array_equal(idx.view("i8"), exp) + tm.assert_numpy_array_equal(idx.asi8, exp) + + exp = np.array([14975, -9223372036854775808], dtype=np.int64) + idx = PeriodIndex(["2011-01-01", NaT], freq="D") + tm.assert_numpy_array_equal(idx.view("i8"), exp) + tm.assert_numpy_array_equal(idx.asi8, exp) + + def test_values(self): + idx = PeriodIndex([], freq="M") + + exp = np.array([], dtype=object) + tm.assert_numpy_array_equal(idx.values, exp) + tm.assert_numpy_array_equal(idx.to_numpy(), exp) + + exp = np.array([], dtype=np.int64) + tm.assert_numpy_array_equal(idx.asi8, exp) + + idx = PeriodIndex(["2011-01", NaT], freq="M") + + exp = np.array([Period("2011-01", freq="M"), NaT], dtype=object) + tm.assert_numpy_array_equal(idx.values, exp) + tm.assert_numpy_array_equal(idx.to_numpy(), exp) + exp = np.array([492, -9223372036854775808], dtype=np.int64) + tm.assert_numpy_array_equal(idx.asi8, exp) + + idx = PeriodIndex(["2011-01-01", NaT], freq="D") + + exp = np.array([Period("2011-01-01", freq="D"), NaT], dtype=object) + tm.assert_numpy_array_equal(idx.values, exp) + tm.assert_numpy_array_equal(idx.to_numpy(), exp) + exp = np.array([14975, -9223372036854775808], dtype=np.int64) + tm.assert_numpy_array_equal(idx.asi8, exp) + + @pytest.mark.parametrize( + "field", + [ + "year", + "month", + "day", + "hour", + "minute", + "second", + "weekofyear", + "week", + "dayofweek", + "day_of_week", + "dayofyear", + "day_of_year", + "quarter", + "qyear", + "days_in_month", + ], + ) + @pytest.mark.parametrize( + "periodindex", + [ + period_range(freq="Y", start="1/1/2001", end="12/1/2005"), + period_range(freq="Q", start="1/1/2001", end="12/1/2002"), + period_range(freq="M", start="1/1/2001", end="1/1/2002"), + period_range(freq="D", start="12/1/2001", end="6/1/2001"), + period_range(freq="h", start="12/31/2001", end="1/1/2002 23:00"), + period_range(freq="Min", start="12/31/2001", end="1/1/2002 00:20"), + period_range( + freq="s", start="12/31/2001 00:00:00", end="12/31/2001 00:05:00" + ), + period_range(end=Period("2006-12-31", "W"), periods=10), + ], + ) + def test_fields(self, periodindex, field): + periods = list(periodindex) + ser = Series(periodindex) + + field_idx = getattr(periodindex, field) + assert len(periodindex) == len(field_idx) + for x, val in zip(periods, field_idx): + assert getattr(x, field) == val + + if len(ser) == 0: + return + + field_s = getattr(ser.dt, field) + assert len(periodindex) == len(field_s) + for x, val in zip(periods, field_s): + assert getattr(x, field) == val + + def test_is_(self): + create_index = lambda: period_range(freq="Y", start="1/1/2001", end="12/1/2009") + index = create_index() + assert index.is_(index) + assert not index.is_(create_index()) + assert index.is_(index.view()) + assert index.is_(index.view().view().view().view().view()) + assert index.view().is_(index) + ind2 = index.view() + index.name = "Apple" + assert ind2.is_(index) + assert not index.is_(index[:]) + assert not index.is_(index.asfreq("M")) + assert not index.is_(index.asfreq("Y")) + + assert not index.is_(index - 2) + assert not index.is_(index - 0) + + def test_index_unique(self): + idx = PeriodIndex([2000, 2007, 2007, 2009, 2009], freq="Y-JUN") + expected = PeriodIndex([2000, 2007, 2009], freq="Y-JUN") + tm.assert_index_equal(idx.unique(), expected) + assert idx.nunique() == 3 + + def test_pindex_fieldaccessor_nat(self): + idx = PeriodIndex( + ["2011-01", "2011-02", "NaT", "2012-03", "2012-04"], freq="D", name="name" + ) + + exp = Index([2011, 2011, -1, 2012, 2012], dtype=np.int64, name="name") + tm.assert_index_equal(idx.year, exp) + exp = Index([1, 2, -1, 3, 4], dtype=np.int64, name="name") + tm.assert_index_equal(idx.month, exp) + + def test_pindex_multiples(self): + expected = PeriodIndex( + ["2011-01", "2011-03", "2011-05", "2011-07", "2011-09", "2011-11"], + freq="2M", + ) + + pi = period_range(start="1/1/11", end="12/31/11", freq="2M") + tm.assert_index_equal(pi, expected) + assert pi.freq == offsets.MonthEnd(2) + assert pi.freqstr == "2M" + + pi = period_range(start="1/1/11", periods=6, freq="2M") + tm.assert_index_equal(pi, expected) + assert pi.freq == offsets.MonthEnd(2) + assert pi.freqstr == "2M" + + @pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning") + @pytest.mark.filterwarnings("ignore:Period with BDay freq:FutureWarning") + def test_iteration(self): + index = period_range(start="1/1/10", periods=4, freq="B") + + result = list(index) + assert isinstance(result[0], Period) + assert result[0].freq == index.freq + + def test_with_multi_index(self): + # #1705 + index = date_range("1/1/2012", periods=4, freq="12h") + index_as_arrays = [index.to_period(freq="D"), index.hour] + + s = Series([0, 1, 2, 3], index_as_arrays) + + assert isinstance(s.index.levels[0], PeriodIndex) + + assert isinstance(s.index.values[0][0], Period) + + def test_map(self): + # test_map_dictlike generally tests + + index = PeriodIndex([2005, 2007, 2009], freq="Y") + result = index.map(lambda x: x.ordinal) + exp = Index([x.ordinal for x in index]) + tm.assert_index_equal(result, exp) + + +def test_maybe_convert_timedelta(): + pi = PeriodIndex(["2000", "2001"], freq="D") + offset = offsets.Day(2) + assert pi._maybe_convert_timedelta(offset) == 2 + assert pi._maybe_convert_timedelta(2) == 2 + + offset = offsets.BusinessDay() + msg = r"Input has different freq=B from PeriodIndex\(freq=D\)" + with pytest.raises(ValueError, match=msg): + pi._maybe_convert_timedelta(offset) + + +@pytest.mark.parametrize("array", [True, False]) +def test_dunder_array(array): + obj = PeriodIndex(["2000-01-01", "2001-01-01"], freq="D") + if array: + obj = obj._data + + expected = np.array([obj[0], obj[1]], dtype=object) + result = np.array(obj) + tm.assert_numpy_array_equal(result, expected) + + result = np.asarray(obj) + tm.assert_numpy_array_equal(result, expected) + + expected = obj.asi8 + for dtype in ["i8", "int64", np.int64]: + result = np.array(obj, dtype=dtype) + tm.assert_numpy_array_equal(result, expected) + + result = np.asarray(obj, dtype=dtype) + tm.assert_numpy_array_equal(result, expected) + + for dtype in ["float64", "int32", "uint64"]: + msg = "argument must be" + with pytest.raises(TypeError, match=msg): + np.array(obj, dtype=dtype) + with pytest.raises(TypeError, match=msg): + np.array(obj, dtype=getattr(np, dtype)) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/period/test_period_range.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/period/test_period_range.py new file mode 100644 index 0000000000000000000000000000000000000000..6f8e6d07da8bf3c730ef1f82224388ba4b99ccb1 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/period/test_period_range.py @@ -0,0 +1,241 @@ +import numpy as np +import pytest + +from pandas import ( + NaT, + Period, + PeriodIndex, + date_range, + period_range, +) +import pandas._testing as tm + + +class TestPeriodRangeKeywords: + def test_required_arguments(self): + msg = ( + "Of the three parameters: start, end, and periods, exactly two " + "must be specified" + ) + with pytest.raises(ValueError, match=msg): + period_range("2011-1-1", "2012-1-1", "B") + + def test_required_arguments2(self): + start = Period("02-Apr-2005", "D") + msg = ( + "Of the three parameters: start, end, and periods, exactly two " + "must be specified" + ) + with pytest.raises(ValueError, match=msg): + period_range(start=start) + + def test_required_arguments3(self): + # not enough params + msg = ( + "Of the three parameters: start, end, and periods, " + "exactly two must be specified" + ) + with pytest.raises(ValueError, match=msg): + period_range(start="2017Q1") + + with pytest.raises(ValueError, match=msg): + period_range(end="2017Q1") + + with pytest.raises(ValueError, match=msg): + period_range(periods=5) + + with pytest.raises(ValueError, match=msg): + period_range() + + def test_required_arguments_too_many(self): + msg = ( + "Of the three parameters: start, end, and periods, " + "exactly two must be specified" + ) + with pytest.raises(ValueError, match=msg): + period_range(start="2017Q1", end="2018Q1", periods=8, freq="Q") + + def test_start_end_non_nat(self): + # start/end NaT + msg = "start and end must not be NaT" + with pytest.raises(ValueError, match=msg): + period_range(start=NaT, end="2018Q1") + with pytest.raises(ValueError, match=msg): + period_range(start=NaT, end="2018Q1", freq="Q") + + with pytest.raises(ValueError, match=msg): + period_range(start="2017Q1", end=NaT) + with pytest.raises(ValueError, match=msg): + period_range(start="2017Q1", end=NaT, freq="Q") + + def test_periods_requires_integer(self): + # invalid periods param + msg = "periods must be a number, got foo" + with pytest.raises(TypeError, match=msg): + period_range(start="2017Q1", periods="foo") + + +class TestPeriodRange: + @pytest.mark.parametrize( + "freq_offset, freq_period", + [ + ("D", "D"), + ("W", "W"), + ("QE", "Q"), + ("YE", "Y"), + ], + ) + def test_construction_from_string(self, freq_offset, freq_period): + # non-empty + expected = date_range( + start="2017-01-01", periods=5, freq=freq_offset, name="foo" + ).to_period() + start, end = str(expected[0]), str(expected[-1]) + + result = period_range(start=start, end=end, freq=freq_period, name="foo") + tm.assert_index_equal(result, expected) + + result = period_range(start=start, periods=5, freq=freq_period, name="foo") + tm.assert_index_equal(result, expected) + + result = period_range(end=end, periods=5, freq=freq_period, name="foo") + tm.assert_index_equal(result, expected) + + # empty + expected = PeriodIndex([], freq=freq_period, name="foo") + + result = period_range(start=start, periods=0, freq=freq_period, name="foo") + tm.assert_index_equal(result, expected) + + result = period_range(end=end, periods=0, freq=freq_period, name="foo") + tm.assert_index_equal(result, expected) + + result = period_range(start=end, end=start, freq=freq_period, name="foo") + tm.assert_index_equal(result, expected) + + def test_construction_from_string_monthly(self): + # non-empty + expected = date_range( + start="2017-01-01", periods=5, freq="ME", name="foo" + ).to_period() + start, end = str(expected[0]), str(expected[-1]) + + result = period_range(start=start, end=end, freq="M", name="foo") + tm.assert_index_equal(result, expected) + + result = period_range(start=start, periods=5, freq="M", name="foo") + tm.assert_index_equal(result, expected) + + result = period_range(end=end, periods=5, freq="M", name="foo") + tm.assert_index_equal(result, expected) + + # empty + expected = PeriodIndex([], freq="M", name="foo") + + result = period_range(start=start, periods=0, freq="M", name="foo") + tm.assert_index_equal(result, expected) + + result = period_range(end=end, periods=0, freq="M", name="foo") + tm.assert_index_equal(result, expected) + + result = period_range(start=end, end=start, freq="M", name="foo") + tm.assert_index_equal(result, expected) + + def test_construction_from_period(self): + # upsampling + start, end = Period("2017Q1", freq="Q"), Period("2018Q1", freq="Q") + expected = date_range( + start="2017-03-31", end="2018-03-31", freq="ME", name="foo" + ).to_period() + result = period_range(start=start, end=end, freq="M", name="foo") + tm.assert_index_equal(result, expected) + + # downsampling + start = Period("2017-1", freq="M") + end = Period("2019-12", freq="M") + expected = date_range( + start="2017-01-31", end="2019-12-31", freq="QE", name="foo" + ).to_period() + result = period_range(start=start, end=end, freq="Q", name="foo") + tm.assert_index_equal(result, expected) + + # test for issue # 21793 + start = Period("2017Q1", freq="Q") + end = Period("2018Q1", freq="Q") + idx = period_range(start=start, end=end, freq="Q", name="foo") + result = idx == idx.values + expected = np.array([True, True, True, True, True]) + tm.assert_numpy_array_equal(result, expected) + + # empty + expected = PeriodIndex([], freq="W", name="foo") + + result = period_range(start=start, periods=0, freq="W", name="foo") + tm.assert_index_equal(result, expected) + + result = period_range(end=end, periods=0, freq="W", name="foo") + tm.assert_index_equal(result, expected) + + result = period_range(start=end, end=start, freq="W", name="foo") + tm.assert_index_equal(result, expected) + + def test_mismatched_start_end_freq_raises(self): + depr_msg = "Period with BDay freq is deprecated" + msg = "'w' is deprecated and will be removed in a future version." + with tm.assert_produces_warning(FutureWarning, match=msg): + end_w = Period("2006-12-31", "1w") + + with tm.assert_produces_warning(FutureWarning, match=depr_msg): + start_b = Period("02-Apr-2005", "B") + end_b = Period("2005-05-01", "B") + + msg = "start and end must have same freq" + with pytest.raises(ValueError, match=msg): + with tm.assert_produces_warning(FutureWarning, match=depr_msg): + period_range(start=start_b, end=end_w) + + # without mismatch we are OK + with tm.assert_produces_warning(FutureWarning, match=depr_msg): + period_range(start=start_b, end=end_b) + + +class TestPeriodRangeDisallowedFreqs: + def test_constructor_U(self): + # U was used as undefined period + with pytest.raises(ValueError, match="Invalid frequency: X"): + period_range("2007-1-1", periods=500, freq="X") + + @pytest.mark.parametrize( + "freq,freq_depr", + [ + ("2Y", "2A"), + ("2Y", "2a"), + ("2Y-AUG", "2A-AUG"), + ("2Y-AUG", "2A-aug"), + ], + ) + def test_a_deprecated_from_time_series(self, freq, freq_depr): + # GH#52536 + msg = f"'{freq_depr[1:]}' is deprecated and will be removed in a " + f"future version. Please use '{freq[1:]}' instead." + + with tm.assert_produces_warning(FutureWarning, match=msg): + period_range(freq=freq_depr, start="1/1/2001", end="12/1/2009") + + @pytest.mark.parametrize("freq_depr", ["2H", "2MIN", "2S", "2US", "2NS"]) + def test_uppercase_freq_deprecated_from_time_series(self, freq_depr): + # GH#52536, GH#54939 + msg = f"'{freq_depr[1:]}' is deprecated and will be removed in a " + f"future version. Please use '{freq_depr.lower()[1:]}' instead." + + with tm.assert_produces_warning(FutureWarning, match=msg): + period_range("2020-01-01 00:00:00 00:00", periods=2, freq=freq_depr) + + @pytest.mark.parametrize("freq_depr", ["2m", "2q-sep", "2y", "2w"]) + def test_lowercase_freq_deprecated_from_time_series(self, freq_depr): + # GH#52536, GH#54939 + msg = f"'{freq_depr[1:]}' is deprecated and will be removed in a " + f"future version. Please use '{freq_depr.upper()[1:]}' instead." + + with tm.assert_produces_warning(FutureWarning, match=msg): + period_range(freq=freq_depr, start="1/1/2001", end="12/1/2009") diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/period/test_pickle.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/period/test_pickle.py new file mode 100644 index 0000000000000000000000000000000000000000..7d359fdabb6f1229e713e45452c6816d9f5743e9 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/period/test_pickle.py @@ -0,0 +1,26 @@ +import numpy as np +import pytest + +from pandas import ( + NaT, + PeriodIndex, + period_range, +) +import pandas._testing as tm + +from pandas.tseries import offsets + + +class TestPickle: + @pytest.mark.parametrize("freq", ["D", "M", "Y"]) + def test_pickle_round_trip(self, freq): + idx = PeriodIndex(["2016-05-16", "NaT", NaT, np.nan], freq=freq) + result = tm.round_trip_pickle(idx) + tm.assert_index_equal(result, idx) + + def test_pickle_freq(self): + # GH#2891 + prng = period_range("1/1/2011", "1/1/2012", freq="M") + new_prng = tm.round_trip_pickle(prng) + assert new_prng.freq == offsets.MonthEnd() + assert new_prng.freqstr == "M" diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/period/test_resolution.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/period/test_resolution.py new file mode 100644 index 0000000000000000000000000000000000000000..680bdaa2e2a44c9603c6465274e4f4cea35e8701 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/period/test_resolution.py @@ -0,0 +1,23 @@ +import pytest + +import pandas as pd + + +class TestResolution: + @pytest.mark.parametrize( + "freq,expected", + [ + ("Y", "year"), + ("Q", "quarter"), + ("M", "month"), + ("D", "day"), + ("h", "hour"), + ("min", "minute"), + ("s", "second"), + ("ms", "millisecond"), + ("us", "microsecond"), + ], + ) + def test_resolution(self, freq, expected): + idx = pd.period_range(start="2013-04-01", periods=30, freq=freq) + assert idx.resolution == expected diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/period/test_scalar_compat.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/period/test_scalar_compat.py new file mode 100644 index 0000000000000000000000000000000000000000..d8afd29ff31c558a7e99861852b08d86deaa9fac --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/period/test_scalar_compat.py @@ -0,0 +1,38 @@ +"""Tests for PeriodIndex behaving like a vectorized Period scalar""" + +import pytest + +from pandas import ( + Timedelta, + date_range, + period_range, +) +import pandas._testing as tm + + +class TestPeriodIndexOps: + def test_start_time(self): + # GH#17157 + index = period_range(freq="M", start="2016-01-01", end="2016-05-31") + expected_index = date_range("2016-01-01", end="2016-05-31", freq="MS") + tm.assert_index_equal(index.start_time, expected_index) + + def test_end_time(self): + # GH#17157 + index = period_range(freq="M", start="2016-01-01", end="2016-05-31") + expected_index = date_range("2016-01-01", end="2016-05-31", freq="ME") + expected_index += Timedelta(1, "D") - Timedelta(1, "ns") + tm.assert_index_equal(index.end_time, expected_index) + + @pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning") + @pytest.mark.filterwarnings( + "ignore:Period with BDay freq is deprecated:FutureWarning" + ) + def test_end_time_business_friday(self): + # GH#34449 + pi = period_range("1990-01-05", freq="B", periods=1) + result = pi.end_time + + dti = date_range("1990-01-05", freq="D", periods=1)._with_freq(None) + expected = dti + Timedelta(days=1, nanoseconds=-1) + tm.assert_index_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/period/test_searchsorted.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/period/test_searchsorted.py new file mode 100644 index 0000000000000000000000000000000000000000..9b02a2f35fd0193bbc8133373299a0ac2cea38ea --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/period/test_searchsorted.py @@ -0,0 +1,80 @@ +import numpy as np +import pytest + +from pandas._libs.tslibs import IncompatibleFrequency + +from pandas import ( + NaT, + Period, + PeriodIndex, +) +import pandas._testing as tm + + +class TestSearchsorted: + @pytest.mark.parametrize("freq", ["D", "2D"]) + def test_searchsorted(self, freq): + pidx = PeriodIndex( + ["2014-01-01", "2014-01-02", "2014-01-03", "2014-01-04", "2014-01-05"], + freq=freq, + ) + + p1 = Period("2014-01-01", freq=freq) + assert pidx.searchsorted(p1) == 0 + + p2 = Period("2014-01-04", freq=freq) + assert pidx.searchsorted(p2) == 3 + + assert pidx.searchsorted(NaT) == 5 + + msg = "Input has different freq=h from PeriodArray" + with pytest.raises(IncompatibleFrequency, match=msg): + pidx.searchsorted(Period("2014-01-01", freq="h")) + + msg = "Input has different freq=5D from PeriodArray" + with pytest.raises(IncompatibleFrequency, match=msg): + pidx.searchsorted(Period("2014-01-01", freq="5D")) + + def test_searchsorted_different_argument_classes(self, listlike_box): + pidx = PeriodIndex( + ["2014-01-01", "2014-01-02", "2014-01-03", "2014-01-04", "2014-01-05"], + freq="D", + ) + result = pidx.searchsorted(listlike_box(pidx)) + expected = np.arange(len(pidx), dtype=result.dtype) + tm.assert_numpy_array_equal(result, expected) + + result = pidx._data.searchsorted(listlike_box(pidx)) + tm.assert_numpy_array_equal(result, expected) + + def test_searchsorted_invalid(self): + pidx = PeriodIndex( + ["2014-01-01", "2014-01-02", "2014-01-03", "2014-01-04", "2014-01-05"], + freq="D", + ) + + other = np.array([0, 1], dtype=np.int64) + + msg = "|".join( + [ + "searchsorted requires compatible dtype or scalar", + "value should be a 'Period', 'NaT', or array of those. Got", + ] + ) + with pytest.raises(TypeError, match=msg): + pidx.searchsorted(other) + + with pytest.raises(TypeError, match=msg): + pidx.searchsorted(other.astype("timedelta64[ns]")) + + with pytest.raises(TypeError, match=msg): + pidx.searchsorted(np.timedelta64(4)) + + with pytest.raises(TypeError, match=msg): + pidx.searchsorted(np.timedelta64("NaT", "ms")) + + with pytest.raises(TypeError, match=msg): + pidx.searchsorted(np.datetime64(4, "ns")) + + with pytest.raises(TypeError, match=msg): + pidx.searchsorted(np.datetime64("NaT", "ns")) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/period/test_setops.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/period/test_setops.py new file mode 100644 index 0000000000000000000000000000000000000000..2fa7e8cd0d2df5982cc0c798fbfba4e0230df367 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/period/test_setops.py @@ -0,0 +1,363 @@ +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + PeriodIndex, + date_range, + period_range, +) +import pandas._testing as tm + + +def _permute(obj): + return obj.take(np.random.default_rng(2).permutation(len(obj))) + + +class TestPeriodIndex: + def test_union(self, sort): + # union + other1 = period_range("1/1/2000", freq="D", periods=5) + rng1 = period_range("1/6/2000", freq="D", periods=5) + expected1 = PeriodIndex( + [ + "2000-01-06", + "2000-01-07", + "2000-01-08", + "2000-01-09", + "2000-01-10", + "2000-01-01", + "2000-01-02", + "2000-01-03", + "2000-01-04", + "2000-01-05", + ], + freq="D", + ) + + rng2 = period_range("1/1/2000", freq="D", periods=5) + other2 = period_range("1/4/2000", freq="D", periods=5) + expected2 = period_range("1/1/2000", freq="D", periods=8) + + rng3 = period_range("1/1/2000", freq="D", periods=5) + other3 = PeriodIndex([], freq="D") + expected3 = period_range("1/1/2000", freq="D", periods=5) + + rng4 = period_range("2000-01-01 09:00", freq="h", periods=5) + other4 = period_range("2000-01-02 09:00", freq="h", periods=5) + expected4 = PeriodIndex( + [ + "2000-01-01 09:00", + "2000-01-01 10:00", + "2000-01-01 11:00", + "2000-01-01 12:00", + "2000-01-01 13:00", + "2000-01-02 09:00", + "2000-01-02 10:00", + "2000-01-02 11:00", + "2000-01-02 12:00", + "2000-01-02 13:00", + ], + freq="h", + ) + + rng5 = PeriodIndex( + ["2000-01-01 09:01", "2000-01-01 09:03", "2000-01-01 09:05"], freq="min" + ) + other5 = PeriodIndex( + ["2000-01-01 09:01", "2000-01-01 09:05", "2000-01-01 09:08"], freq="min" + ) + expected5 = PeriodIndex( + [ + "2000-01-01 09:01", + "2000-01-01 09:03", + "2000-01-01 09:05", + "2000-01-01 09:08", + ], + freq="min", + ) + + rng6 = period_range("2000-01-01", freq="M", periods=7) + other6 = period_range("2000-04-01", freq="M", periods=7) + expected6 = period_range("2000-01-01", freq="M", periods=10) + + rng7 = period_range("2003-01-01", freq="Y", periods=5) + other7 = period_range("1998-01-01", freq="Y", periods=8) + expected7 = PeriodIndex( + [ + "2003", + "2004", + "2005", + "2006", + "2007", + "1998", + "1999", + "2000", + "2001", + "2002", + ], + freq="Y", + ) + + rng8 = PeriodIndex( + ["1/3/2000", "1/2/2000", "1/1/2000", "1/5/2000", "1/4/2000"], freq="D" + ) + other8 = period_range("1/6/2000", freq="D", periods=5) + expected8 = PeriodIndex( + [ + "1/3/2000", + "1/2/2000", + "1/1/2000", + "1/5/2000", + "1/4/2000", + "1/6/2000", + "1/7/2000", + "1/8/2000", + "1/9/2000", + "1/10/2000", + ], + freq="D", + ) + + for rng, other, expected in [ + (rng1, other1, expected1), + (rng2, other2, expected2), + (rng3, other3, expected3), + (rng4, other4, expected4), + (rng5, other5, expected5), + (rng6, other6, expected6), + (rng7, other7, expected7), + (rng8, other8, expected8), + ]: + result_union = rng.union(other, sort=sort) + if sort is None: + expected = expected.sort_values() + tm.assert_index_equal(result_union, expected) + + def test_union_misc(self, sort): + index = period_range("1/1/2000", "1/20/2000", freq="D") + + result = index[:-5].union(index[10:], sort=sort) + tm.assert_index_equal(result, index) + + # not in order + result = _permute(index[:-5]).union(_permute(index[10:]), sort=sort) + if sort is False: + tm.assert_index_equal(result.sort_values(), index) + else: + tm.assert_index_equal(result, index) + + # cast if different frequencies + index = period_range("1/1/2000", "1/20/2000", freq="D") + index2 = period_range("1/1/2000", "1/20/2000", freq="W-WED") + result = index.union(index2, sort=sort) + expected = index.astype(object).union(index2.astype(object), sort=sort) + tm.assert_index_equal(result, expected) + + def test_intersection(self, sort): + index = period_range("1/1/2000", "1/20/2000", freq="D") + + result = index[:-5].intersection(index[10:], sort=sort) + tm.assert_index_equal(result, index[10:-5]) + + # not in order + left = _permute(index[:-5]) + right = _permute(index[10:]) + result = left.intersection(right, sort=sort) + if sort is False: + tm.assert_index_equal(result.sort_values(), index[10:-5]) + else: + tm.assert_index_equal(result, index[10:-5]) + + # cast if different frequencies + index = period_range("1/1/2000", "1/20/2000", freq="D") + index2 = period_range("1/1/2000", "1/20/2000", freq="W-WED") + + result = index.intersection(index2, sort=sort) + expected = pd.Index([], dtype=object) + tm.assert_index_equal(result, expected) + + index3 = period_range("1/1/2000", "1/20/2000", freq="2D") + result = index.intersection(index3, sort=sort) + tm.assert_index_equal(result, expected) + + def test_intersection_cases(self, sort): + base = period_range("6/1/2000", "6/30/2000", freq="D", name="idx") + + # if target has the same name, it is preserved + rng2 = period_range("5/15/2000", "6/20/2000", freq="D", name="idx") + expected2 = period_range("6/1/2000", "6/20/2000", freq="D", name="idx") + + # if target name is different, it will be reset + rng3 = period_range("5/15/2000", "6/20/2000", freq="D", name="other") + expected3 = period_range("6/1/2000", "6/20/2000", freq="D", name=None) + + rng4 = period_range("7/1/2000", "7/31/2000", freq="D", name="idx") + expected4 = PeriodIndex([], name="idx", freq="D") + + for rng, expected in [ + (rng2, expected2), + (rng3, expected3), + (rng4, expected4), + ]: + result = base.intersection(rng, sort=sort) + tm.assert_index_equal(result, expected) + assert result.name == expected.name + assert result.freq == expected.freq + + # non-monotonic + base = PeriodIndex( + ["2011-01-05", "2011-01-04", "2011-01-02", "2011-01-03"], + freq="D", + name="idx", + ) + + rng2 = PeriodIndex( + ["2011-01-04", "2011-01-02", "2011-02-02", "2011-02-03"], + freq="D", + name="idx", + ) + expected2 = PeriodIndex(["2011-01-04", "2011-01-02"], freq="D", name="idx") + + rng3 = PeriodIndex( + ["2011-01-04", "2011-01-02", "2011-02-02", "2011-02-03"], + freq="D", + name="other", + ) + expected3 = PeriodIndex(["2011-01-04", "2011-01-02"], freq="D", name=None) + + rng4 = period_range("7/1/2000", "7/31/2000", freq="D", name="idx") + expected4 = PeriodIndex([], freq="D", name="idx") + + for rng, expected in [ + (rng2, expected2), + (rng3, expected3), + (rng4, expected4), + ]: + result = base.intersection(rng, sort=sort) + if sort is None: + expected = expected.sort_values() + tm.assert_index_equal(result, expected) + assert result.name == expected.name + assert result.freq == "D" + + # empty same freq + rng = date_range("6/1/2000", "6/15/2000", freq="min") + result = rng[0:0].intersection(rng) + assert len(result) == 0 + + result = rng.intersection(rng[0:0]) + assert len(result) == 0 + + def test_difference(self, sort): + # diff + period_rng = ["1/3/2000", "1/2/2000", "1/1/2000", "1/5/2000", "1/4/2000"] + rng1 = PeriodIndex(period_rng, freq="D") + other1 = period_range("1/6/2000", freq="D", periods=5) + expected1 = rng1 + + rng2 = PeriodIndex(period_rng, freq="D") + other2 = period_range("1/4/2000", freq="D", periods=5) + expected2 = PeriodIndex(["1/3/2000", "1/2/2000", "1/1/2000"], freq="D") + + rng3 = PeriodIndex(period_rng, freq="D") + other3 = PeriodIndex([], freq="D") + expected3 = rng3 + + period_rng = [ + "2000-01-01 10:00", + "2000-01-01 09:00", + "2000-01-01 12:00", + "2000-01-01 11:00", + "2000-01-01 13:00", + ] + rng4 = PeriodIndex(period_rng, freq="h") + other4 = period_range("2000-01-02 09:00", freq="h", periods=5) + expected4 = rng4 + + rng5 = PeriodIndex( + ["2000-01-01 09:03", "2000-01-01 09:01", "2000-01-01 09:05"], freq="min" + ) + other5 = PeriodIndex(["2000-01-01 09:01", "2000-01-01 09:05"], freq="min") + expected5 = PeriodIndex(["2000-01-01 09:03"], freq="min") + + period_rng = [ + "2000-02-01", + "2000-01-01", + "2000-06-01", + "2000-07-01", + "2000-05-01", + "2000-03-01", + "2000-04-01", + ] + rng6 = PeriodIndex(period_rng, freq="M") + other6 = period_range("2000-04-01", freq="M", periods=7) + expected6 = PeriodIndex(["2000-02-01", "2000-01-01", "2000-03-01"], freq="M") + + period_rng = ["2003", "2007", "2006", "2005", "2004"] + rng7 = PeriodIndex(period_rng, freq="Y") + other7 = period_range("1998-01-01", freq="Y", periods=8) + expected7 = PeriodIndex(["2007", "2006"], freq="Y") + + for rng, other, expected in [ + (rng1, other1, expected1), + (rng2, other2, expected2), + (rng3, other3, expected3), + (rng4, other4, expected4), + (rng5, other5, expected5), + (rng6, other6, expected6), + (rng7, other7, expected7), + ]: + result_difference = rng.difference(other, sort=sort) + if sort is None and len(other): + # We dont sort (yet?) when empty GH#24959 + expected = expected.sort_values() + tm.assert_index_equal(result_difference, expected) + + def test_difference_freq(self, sort): + # GH14323: difference of Period MUST preserve frequency + # but the ability to union results must be preserved + + index = period_range("20160920", "20160925", freq="D") + + other = period_range("20160921", "20160924", freq="D") + expected = PeriodIndex(["20160920", "20160925"], freq="D") + idx_diff = index.difference(other, sort) + tm.assert_index_equal(idx_diff, expected) + tm.assert_attr_equal("freq", idx_diff, expected) + + other = period_range("20160922", "20160925", freq="D") + idx_diff = index.difference(other, sort) + expected = PeriodIndex(["20160920", "20160921"], freq="D") + tm.assert_index_equal(idx_diff, expected) + tm.assert_attr_equal("freq", idx_diff, expected) + + def test_intersection_equal_duplicates(self): + # GH#38302 + idx = period_range("2011-01-01", periods=2) + idx_dup = idx.append(idx) + result = idx_dup.intersection(idx_dup) + tm.assert_index_equal(result, idx) + + @pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning") + def test_union_duplicates(self): + # GH#36289 + idx = period_range("2011-01-01", periods=2) + idx_dup = idx.append(idx) + + idx2 = period_range("2011-01-02", periods=2) + idx2_dup = idx2.append(idx2) + result = idx_dup.union(idx2_dup) + + expected = PeriodIndex( + [ + "2011-01-01", + "2011-01-01", + "2011-01-02", + "2011-01-02", + "2011-01-03", + "2011-01-03", + ], + freq="D", + ) + tm.assert_index_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/period/test_tools.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/period/test_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..f507e64d88b06b5862de3e98c693ab9f85306116 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/period/test_tools.py @@ -0,0 +1,52 @@ +import numpy as np +import pytest + +from pandas import ( + Period, + PeriodIndex, + period_range, +) +import pandas._testing as tm + + +class TestPeriodRepresentation: + """ + Wish to match NumPy units + """ + + @pytest.mark.parametrize( + "freq, base_date", + [ + ("W-THU", "1970-01-01"), + ("D", "1970-01-01"), + ("B", "1970-01-01"), + ("h", "1970-01-01"), + ("min", "1970-01-01"), + ("s", "1970-01-01"), + ("ms", "1970-01-01"), + ("us", "1970-01-01"), + ("ns", "1970-01-01"), + ("M", "1970-01"), + ("Y", 1970), + ], + ) + @pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning") + @pytest.mark.filterwarnings( + "ignore:Period with BDay freq is deprecated:FutureWarning" + ) + def test_freq(self, freq, base_date): + rng = period_range(start=base_date, periods=10, freq=freq) + exp = np.arange(10, dtype=np.int64) + + tm.assert_numpy_array_equal(rng.asi8, exp) + + +class TestPeriodIndexConversion: + def test_tolist(self): + index = period_range(freq="Y", start="1/1/2001", end="12/1/2009") + rs = index.tolist() + for x in rs: + assert isinstance(x, Period) + + recon = PeriodIndex(rs) + tm.assert_index_equal(index, recon) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/ranges/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/ranges/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/ranges/test_constructors.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/ranges/test_constructors.py new file mode 100644 index 0000000000000000000000000000000000000000..5e6f16075ae636a3aa14e7443097f426bd6f998a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/ranges/test_constructors.py @@ -0,0 +1,164 @@ +from datetime import datetime + +import numpy as np +import pytest + +from pandas import ( + Index, + RangeIndex, + Series, +) +import pandas._testing as tm + + +class TestRangeIndexConstructors: + @pytest.mark.parametrize("name", [None, "foo"]) + @pytest.mark.parametrize( + "args, kwargs, start, stop, step", + [ + ((5,), {}, 0, 5, 1), + ((1, 5), {}, 1, 5, 1), + ((1, 5, 2), {}, 1, 5, 2), + ((0,), {}, 0, 0, 1), + ((0, 0), {}, 0, 0, 1), + ((), {"start": 0}, 0, 0, 1), + ((), {"stop": 0}, 0, 0, 1), + ], + ) + def test_constructor(self, args, kwargs, start, stop, step, name): + result = RangeIndex(*args, name=name, **kwargs) + expected = Index(np.arange(start, stop, step, dtype=np.int64), name=name) + assert isinstance(result, RangeIndex) + assert result.name is name + assert result._range == range(start, stop, step) + tm.assert_index_equal(result, expected, exact="equiv") + + def test_constructor_invalid_args(self): + msg = "RangeIndex\\(\\.\\.\\.\\) must be called with integers" + with pytest.raises(TypeError, match=msg): + RangeIndex() + + with pytest.raises(TypeError, match=msg): + RangeIndex(name="Foo") + + # we don't allow on a bare Index + msg = ( + r"Index\(\.\.\.\) must be called with a collection of some " + r"kind, 0 was passed" + ) + with pytest.raises(TypeError, match=msg): + Index(0) + + @pytest.mark.parametrize( + "args", + [ + Index(["a", "b"]), + Series(["a", "b"]), + np.array(["a", "b"]), + [], + np.arange(0, 10), + np.array([1]), + [1], + ], + ) + def test_constructor_additional_invalid_args(self, args): + msg = f"Value needs to be a scalar value, was type {type(args).__name__}" + with pytest.raises(TypeError, match=msg): + RangeIndex(args) + + @pytest.mark.parametrize("args", ["foo", datetime(2000, 1, 1, 0, 0)]) + def test_constructor_invalid_args_wrong_type(self, args): + msg = f"Wrong type {type(args)} for value {args}" + with pytest.raises(TypeError, match=msg): + RangeIndex(args) + + def test_constructor_same(self): + # pass thru w and w/o copy + index = RangeIndex(1, 5, 2) + result = RangeIndex(index, copy=False) + assert result.identical(index) + + result = RangeIndex(index, copy=True) + tm.assert_index_equal(result, index, exact=True) + + result = RangeIndex(index) + tm.assert_index_equal(result, index, exact=True) + + with pytest.raises( + ValueError, + match="Incorrect `dtype` passed: expected signed integer, received float64", + ): + RangeIndex(index, dtype="float64") + + def test_constructor_range_object(self): + result = RangeIndex(range(1, 5, 2)) + expected = RangeIndex(1, 5, 2) + tm.assert_index_equal(result, expected, exact=True) + + def test_constructor_range(self): + result = RangeIndex.from_range(range(1, 5, 2)) + expected = RangeIndex(1, 5, 2) + tm.assert_index_equal(result, expected, exact=True) + + result = RangeIndex.from_range(range(5, 6)) + expected = RangeIndex(5, 6, 1) + tm.assert_index_equal(result, expected, exact=True) + + # an invalid range + result = RangeIndex.from_range(range(5, 1)) + expected = RangeIndex(0, 0, 1) + tm.assert_index_equal(result, expected, exact=True) + + result = RangeIndex.from_range(range(5)) + expected = RangeIndex(0, 5, 1) + tm.assert_index_equal(result, expected, exact=True) + + result = Index(range(1, 5, 2)) + expected = RangeIndex(1, 5, 2) + tm.assert_index_equal(result, expected, exact=True) + + msg = ( + r"(RangeIndex.)?from_range\(\) got an unexpected keyword argument( 'copy')?" + ) + with pytest.raises(TypeError, match=msg): + RangeIndex.from_range(range(10), copy=True) + + def test_constructor_name(self): + # GH#12288 + orig = RangeIndex(10) + orig.name = "original" + + copy = RangeIndex(orig) + copy.name = "copy" + + assert orig.name == "original" + assert copy.name == "copy" + + new = Index(copy) + assert new.name == "copy" + + new.name = "new" + assert orig.name == "original" + assert copy.name == "copy" + assert new.name == "new" + + def test_constructor_corner(self): + arr = np.array([1, 2, 3, 4], dtype=object) + index = RangeIndex(1, 5) + assert index.values.dtype == np.int64 + expected = Index(arr).astype("int64") + + tm.assert_index_equal(index, expected, exact="equiv") + + # non-int raise Exception + with pytest.raises(TypeError, match=r"Wrong type \"): + RangeIndex("1", "10", "1") + with pytest.raises(TypeError, match=r"Wrong type \"): + RangeIndex(1.1, 10.2, 1.3) + + # invalid passed type + with pytest.raises( + ValueError, + match="Incorrect `dtype` passed: expected signed integer, received float64", + ): + RangeIndex(1, 5, dtype="float64") diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/ranges/test_indexing.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/ranges/test_indexing.py new file mode 100644 index 0000000000000000000000000000000000000000..6202074a11d7883c6f6aa984c23d7964e9042eb0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/ranges/test_indexing.py @@ -0,0 +1,137 @@ +import numpy as np +import pytest + +from pandas import ( + Index, + RangeIndex, +) +import pandas._testing as tm + + +class TestGetIndexer: + def test_get_indexer(self): + index = RangeIndex(start=0, stop=20, step=2) + target = RangeIndex(10) + indexer = index.get_indexer(target) + expected = np.array([0, -1, 1, -1, 2, -1, 3, -1, 4, -1], dtype=np.intp) + tm.assert_numpy_array_equal(indexer, expected) + + def test_get_indexer_pad(self): + index = RangeIndex(start=0, stop=20, step=2) + target = RangeIndex(10) + indexer = index.get_indexer(target, method="pad") + expected = np.array([0, 0, 1, 1, 2, 2, 3, 3, 4, 4], dtype=np.intp) + tm.assert_numpy_array_equal(indexer, expected) + + def test_get_indexer_backfill(self): + index = RangeIndex(start=0, stop=20, step=2) + target = RangeIndex(10) + indexer = index.get_indexer(target, method="backfill") + expected = np.array([0, 1, 1, 2, 2, 3, 3, 4, 4, 5], dtype=np.intp) + tm.assert_numpy_array_equal(indexer, expected) + + def test_get_indexer_limit(self): + # GH#28631 + idx = RangeIndex(4) + target = RangeIndex(6) + result = idx.get_indexer(target, method="pad", limit=1) + expected = np.array([0, 1, 2, 3, 3, -1], dtype=np.intp) + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize("stop", [0, -1, -2]) + def test_get_indexer_decreasing(self, stop): + # GH#28678 + index = RangeIndex(7, stop, -3) + result = index.get_indexer(range(9)) + expected = np.array([-1, 2, -1, -1, 1, -1, -1, 0, -1], dtype=np.intp) + tm.assert_numpy_array_equal(result, expected) + + +class TestTake: + def test_take_preserve_name(self): + index = RangeIndex(1, 5, name="foo") + taken = index.take([3, 0, 1]) + assert index.name == taken.name + + def test_take_fill_value(self): + # GH#12631 + idx = RangeIndex(1, 4, name="xxx") + result = idx.take(np.array([1, 0, -1])) + expected = Index([2, 1, 3], dtype=np.int64, name="xxx") + tm.assert_index_equal(result, expected) + + # fill_value + msg = "Unable to fill values because RangeIndex cannot contain NA" + with pytest.raises(ValueError, match=msg): + idx.take(np.array([1, 0, -1]), fill_value=True) + + # allow_fill=False + result = idx.take(np.array([1, 0, -1]), allow_fill=False, fill_value=True) + expected = Index([2, 1, 3], dtype=np.int64, name="xxx") + tm.assert_index_equal(result, expected) + + msg = "Unable to fill values because RangeIndex cannot contain NA" + with pytest.raises(ValueError, match=msg): + idx.take(np.array([1, 0, -2]), fill_value=True) + with pytest.raises(ValueError, match=msg): + idx.take(np.array([1, 0, -5]), fill_value=True) + + def test_take_raises_index_error(self): + idx = RangeIndex(1, 4, name="xxx") + + msg = "index -5 is out of bounds for (axis 0 with )?size 3" + with pytest.raises(IndexError, match=msg): + idx.take(np.array([1, -5])) + + msg = "index -4 is out of bounds for (axis 0 with )?size 3" + with pytest.raises(IndexError, match=msg): + idx.take(np.array([1, -4])) + + # no errors + result = idx.take(np.array([1, -3])) + expected = Index([2, 1], dtype=np.int64, name="xxx") + tm.assert_index_equal(result, expected) + + def test_take_accepts_empty_array(self): + idx = RangeIndex(1, 4, name="foo") + result = idx.take(np.array([])) + expected = Index([], dtype=np.int64, name="foo") + tm.assert_index_equal(result, expected) + + # empty index + idx = RangeIndex(0, name="foo") + result = idx.take(np.array([])) + expected = Index([], dtype=np.int64, name="foo") + tm.assert_index_equal(result, expected) + + def test_take_accepts_non_int64_array(self): + idx = RangeIndex(1, 4, name="foo") + result = idx.take(np.array([2, 1], dtype=np.uint32)) + expected = Index([3, 2], dtype=np.int64, name="foo") + tm.assert_index_equal(result, expected) + + def test_take_when_index_has_step(self): + idx = RangeIndex(1, 11, 3, name="foo") # [1, 4, 7, 10] + result = idx.take(np.array([1, 0, -1, -4])) + expected = Index([4, 1, 10, 1], dtype=np.int64, name="foo") + tm.assert_index_equal(result, expected) + + def test_take_when_index_has_negative_step(self): + idx = RangeIndex(11, -4, -2, name="foo") # [11, 9, 7, 5, 3, 1, -1, -3] + result = idx.take(np.array([1, 0, -1, -8])) + expected = Index([9, 11, -3, 11], dtype=np.int64, name="foo") + tm.assert_index_equal(result, expected) + + +class TestWhere: + def test_where_putmask_range_cast(self): + # GH#43240 + idx = RangeIndex(0, 5, name="test") + + mask = np.array([True, True, False, False, False]) + result = idx.putmask(mask, 10) + expected = Index([10, 10, 2, 3, 4], dtype=np.int64, name="test") + tm.assert_index_equal(result, expected) + + result = idx.where(~mask, 10) + tm.assert_index_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/ranges/test_join.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/ranges/test_join.py new file mode 100644 index 0000000000000000000000000000000000000000..682b5c8def9ff0e00b533610c1d45a093e7d7a8d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/ranges/test_join.py @@ -0,0 +1,177 @@ +import numpy as np + +from pandas import ( + Index, + RangeIndex, +) +import pandas._testing as tm + + +class TestJoin: + def test_join_outer(self): + # join with Index[int64] + index = RangeIndex(start=0, stop=20, step=2) + other = Index(np.arange(25, 14, -1, dtype=np.int64)) + + res, lidx, ridx = index.join(other, how="outer", return_indexers=True) + noidx_res = index.join(other, how="outer") + tm.assert_index_equal(res, noidx_res) + + eres = Index( + [0, 2, 4, 6, 8, 10, 12, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25] + ) + elidx = np.array( + [0, 1, 2, 3, 4, 5, 6, 7, -1, 8, -1, 9, -1, -1, -1, -1, -1, -1, -1], + dtype=np.intp, + ) + eridx = np.array( + [-1, -1, -1, -1, -1, -1, -1, -1, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], + dtype=np.intp, + ) + + assert isinstance(res, Index) and res.dtype == np.dtype(np.int64) + assert not isinstance(res, RangeIndex) + tm.assert_index_equal(res, eres, exact=True) + tm.assert_numpy_array_equal(lidx, elidx) + tm.assert_numpy_array_equal(ridx, eridx) + + # join with RangeIndex + other = RangeIndex(25, 14, -1) + + res, lidx, ridx = index.join(other, how="outer", return_indexers=True) + noidx_res = index.join(other, how="outer") + tm.assert_index_equal(res, noidx_res) + + assert isinstance(res, Index) and res.dtype == np.int64 + assert not isinstance(res, RangeIndex) + tm.assert_index_equal(res, eres) + tm.assert_numpy_array_equal(lidx, elidx) + tm.assert_numpy_array_equal(ridx, eridx) + + def test_join_inner(self): + # Join with non-RangeIndex + index = RangeIndex(start=0, stop=20, step=2) + other = Index(np.arange(25, 14, -1, dtype=np.int64)) + + res, lidx, ridx = index.join(other, how="inner", return_indexers=True) + + # no guarantee of sortedness, so sort for comparison purposes + ind = res.argsort() + res = res.take(ind) + lidx = lidx.take(ind) + ridx = ridx.take(ind) + + eres = Index([16, 18]) + elidx = np.array([8, 9], dtype=np.intp) + eridx = np.array([9, 7], dtype=np.intp) + + assert isinstance(res, Index) and res.dtype == np.int64 + tm.assert_index_equal(res, eres) + tm.assert_numpy_array_equal(lidx, elidx) + tm.assert_numpy_array_equal(ridx, eridx) + + # Join two RangeIndex + other = RangeIndex(25, 14, -1) + + res, lidx, ridx = index.join(other, how="inner", return_indexers=True) + + assert isinstance(res, RangeIndex) + tm.assert_index_equal(res, eres, exact="equiv") + tm.assert_numpy_array_equal(lidx, elidx) + tm.assert_numpy_array_equal(ridx, eridx) + + def test_join_left(self): + # Join with Index[int64] + index = RangeIndex(start=0, stop=20, step=2) + other = Index(np.arange(25, 14, -1, dtype=np.int64)) + + res, lidx, ridx = index.join(other, how="left", return_indexers=True) + eres = index + eridx = np.array([-1, -1, -1, -1, -1, -1, -1, -1, 9, 7], dtype=np.intp) + + assert isinstance(res, RangeIndex) + tm.assert_index_equal(res, eres) + assert lidx is None + tm.assert_numpy_array_equal(ridx, eridx) + + # Join withRangeIndex + other = Index(np.arange(25, 14, -1, dtype=np.int64)) + + res, lidx, ridx = index.join(other, how="left", return_indexers=True) + + assert isinstance(res, RangeIndex) + tm.assert_index_equal(res, eres) + assert lidx is None + tm.assert_numpy_array_equal(ridx, eridx) + + def test_join_right(self): + # Join with Index[int64] + index = RangeIndex(start=0, stop=20, step=2) + other = Index(np.arange(25, 14, -1, dtype=np.int64)) + + res, lidx, ridx = index.join(other, how="right", return_indexers=True) + eres = other + elidx = np.array([-1, -1, -1, -1, -1, -1, -1, 9, -1, 8, -1], dtype=np.intp) + + assert isinstance(other, Index) and other.dtype == np.int64 + tm.assert_index_equal(res, eres) + tm.assert_numpy_array_equal(lidx, elidx) + assert ridx is None + + # Join withRangeIndex + other = RangeIndex(25, 14, -1) + + res, lidx, ridx = index.join(other, how="right", return_indexers=True) + eres = other + + assert isinstance(other, RangeIndex) + tm.assert_index_equal(res, eres) + tm.assert_numpy_array_equal(lidx, elidx) + assert ridx is None + + def test_join_non_int_index(self): + index = RangeIndex(start=0, stop=20, step=2) + other = Index([3, 6, 7, 8, 10], dtype=object) + + outer = index.join(other, how="outer") + outer2 = other.join(index, how="outer") + expected = Index([0, 2, 3, 4, 6, 7, 8, 10, 12, 14, 16, 18]) + tm.assert_index_equal(outer, outer2) + tm.assert_index_equal(outer, expected) + + inner = index.join(other, how="inner") + inner2 = other.join(index, how="inner") + expected = Index([6, 8, 10]) + tm.assert_index_equal(inner, inner2) + tm.assert_index_equal(inner, expected) + + left = index.join(other, how="left") + tm.assert_index_equal(left, index.astype(object)) + + left2 = other.join(index, how="left") + tm.assert_index_equal(left2, other) + + right = index.join(other, how="right") + tm.assert_index_equal(right, other) + + right2 = other.join(index, how="right") + tm.assert_index_equal(right2, index.astype(object)) + + def test_join_non_unique(self): + index = RangeIndex(start=0, stop=20, step=2) + other = Index([4, 4, 3, 3]) + + res, lidx, ridx = index.join(other, return_indexers=True) + + eres = Index([0, 2, 4, 4, 6, 8, 10, 12, 14, 16, 18]) + elidx = np.array([0, 1, 2, 2, 3, 4, 5, 6, 7, 8, 9], dtype=np.intp) + eridx = np.array([-1, -1, 0, 1, -1, -1, -1, -1, -1, -1, -1], dtype=np.intp) + + tm.assert_index_equal(res, eres) + tm.assert_numpy_array_equal(lidx, elidx) + tm.assert_numpy_array_equal(ridx, eridx) + + def test_join_self(self, join_type): + index = RangeIndex(start=0, stop=20, step=2) + joined = index.join(index, how=join_type) + assert index is joined diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/ranges/test_range.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/ranges/test_range.py new file mode 100644 index 0000000000000000000000000000000000000000..06e19eeca67663318709772ff23f76675545e19b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/ranges/test_range.py @@ -0,0 +1,622 @@ +import numpy as np +import pytest + +from pandas.core.dtypes.common import ensure_platform_int + +import pandas as pd +from pandas import ( + Index, + RangeIndex, +) +import pandas._testing as tm + + +class TestRangeIndex: + @pytest.fixture + def simple_index(self): + return RangeIndex(start=0, stop=20, step=2) + + def test_constructor_unwraps_index(self): + result = RangeIndex(1, 3) + expected = np.array([1, 2], dtype=np.int64) + tm.assert_numpy_array_equal(result._data, expected) + + def test_can_hold_identifiers(self, simple_index): + idx = simple_index + key = idx[0] + assert idx._can_hold_identifiers_and_holds_name(key) is False + + def test_too_many_names(self, simple_index): + index = simple_index + with pytest.raises(ValueError, match="^Length"): + index.names = ["roger", "harold"] + + @pytest.mark.parametrize( + "index, start, stop, step", + [ + (RangeIndex(5), 0, 5, 1), + (RangeIndex(0, 5), 0, 5, 1), + (RangeIndex(5, step=2), 0, 5, 2), + (RangeIndex(1, 5, 2), 1, 5, 2), + ], + ) + def test_start_stop_step_attrs(self, index, start, stop, step): + # GH 25710 + assert index.start == start + assert index.stop == stop + assert index.step == step + + def test_copy(self): + i = RangeIndex(5, name="Foo") + i_copy = i.copy() + assert i_copy is not i + assert i_copy.identical(i) + assert i_copy._range == range(0, 5, 1) + assert i_copy.name == "Foo" + + def test_repr(self): + i = RangeIndex(5, name="Foo") + result = repr(i) + expected = "RangeIndex(start=0, stop=5, step=1, name='Foo')" + assert result == expected + + result = eval(result) + tm.assert_index_equal(result, i, exact=True) + + i = RangeIndex(5, 0, -1) + result = repr(i) + expected = "RangeIndex(start=5, stop=0, step=-1)" + assert result == expected + + result = eval(result) + tm.assert_index_equal(result, i, exact=True) + + def test_insert(self): + idx = RangeIndex(5, name="Foo") + result = idx[1:4] + + # test 0th element + tm.assert_index_equal(idx[0:4], result.insert(0, idx[0]), exact="equiv") + + # GH 18295 (test missing) + expected = Index([0, np.nan, 1, 2, 3, 4], dtype=np.float64) + for na in [np.nan, None, pd.NA]: + result = RangeIndex(5).insert(1, na) + tm.assert_index_equal(result, expected) + + result = RangeIndex(5).insert(1, pd.NaT) + expected = Index([0, pd.NaT, 1, 2, 3, 4], dtype=object) + tm.assert_index_equal(result, expected) + + def test_insert_edges_preserves_rangeindex(self): + idx = Index(range(4, 9, 2)) + + result = idx.insert(0, 2) + expected = Index(range(2, 9, 2)) + tm.assert_index_equal(result, expected, exact=True) + + result = idx.insert(3, 10) + expected = Index(range(4, 11, 2)) + tm.assert_index_equal(result, expected, exact=True) + + def test_insert_middle_preserves_rangeindex(self): + # insert in the middle + idx = Index(range(0, 3, 2)) + result = idx.insert(1, 1) + expected = Index(range(3)) + tm.assert_index_equal(result, expected, exact=True) + + idx = idx * 2 + result = idx.insert(1, 2) + expected = expected * 2 + tm.assert_index_equal(result, expected, exact=True) + + def test_delete(self): + idx = RangeIndex(5, name="Foo") + expected = idx[1:] + result = idx.delete(0) + tm.assert_index_equal(result, expected, exact=True) + assert result.name == expected.name + + expected = idx[:-1] + result = idx.delete(-1) + tm.assert_index_equal(result, expected, exact=True) + assert result.name == expected.name + + msg = "index 5 is out of bounds for axis 0 with size 5" + with pytest.raises((IndexError, ValueError), match=msg): + # either depending on numpy version + result = idx.delete(len(idx)) + + def test_delete_preserves_rangeindex(self): + idx = Index(range(2), name="foo") + + result = idx.delete([1]) + expected = Index(range(1), name="foo") + tm.assert_index_equal(result, expected, exact=True) + + result = idx.delete(1) + tm.assert_index_equal(result, expected, exact=True) + + def test_delete_preserves_rangeindex_middle(self): + idx = Index(range(3), name="foo") + result = idx.delete(1) + expected = idx[::2] + tm.assert_index_equal(result, expected, exact=True) + + result = idx.delete(-2) + tm.assert_index_equal(result, expected, exact=True) + + def test_delete_preserves_rangeindex_list_at_end(self): + idx = RangeIndex(0, 6, 1) + + loc = [2, 3, 4, 5] + result = idx.delete(loc) + expected = idx[:2] + tm.assert_index_equal(result, expected, exact=True) + + result = idx.delete(loc[::-1]) + tm.assert_index_equal(result, expected, exact=True) + + def test_delete_preserves_rangeindex_list_middle(self): + idx = RangeIndex(0, 6, 1) + + loc = [1, 2, 3, 4] + result = idx.delete(loc) + expected = RangeIndex(0, 6, 5) + tm.assert_index_equal(result, expected, exact=True) + + result = idx.delete(loc[::-1]) + tm.assert_index_equal(result, expected, exact=True) + + def test_delete_all_preserves_rangeindex(self): + idx = RangeIndex(0, 6, 1) + + loc = [0, 1, 2, 3, 4, 5] + result = idx.delete(loc) + expected = idx[:0] + tm.assert_index_equal(result, expected, exact=True) + + result = idx.delete(loc[::-1]) + tm.assert_index_equal(result, expected, exact=True) + + def test_delete_not_preserving_rangeindex(self): + idx = RangeIndex(0, 6, 1) + + loc = [0, 3, 5] + result = idx.delete(loc) + expected = Index([1, 2, 4]) + tm.assert_index_equal(result, expected, exact=True) + + result = idx.delete(loc[::-1]) + tm.assert_index_equal(result, expected, exact=True) + + def test_view(self): + i = RangeIndex(0, name="Foo") + i_view = i.view() + assert i_view.name == "Foo" + + i_view = i.view("i8") + tm.assert_numpy_array_equal(i.values, i_view) + + msg = "Passing a type in RangeIndex.view is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + i_view = i.view(RangeIndex) + tm.assert_index_equal(i, i_view) + + def test_dtype(self, simple_index): + index = simple_index + assert index.dtype == np.int64 + + def test_cache(self): + # GH 26565, GH26617, GH35432, GH53387 + # This test checks whether _cache has been set. + # Calling RangeIndex._cache["_data"] creates an int64 array of the same length + # as the RangeIndex and stores it in _cache. + idx = RangeIndex(0, 100, 10) + + assert idx._cache == {} + + repr(idx) + assert idx._cache == {} + + str(idx) + assert idx._cache == {} + + idx.get_loc(20) + assert idx._cache == {} + + 90 in idx # True + assert idx._cache == {} + + 91 in idx # False + assert idx._cache == {} + + idx.all() + assert idx._cache == {} + + idx.any() + assert idx._cache == {} + + for _ in idx: + pass + assert idx._cache == {} + + msg = "RangeIndex.format is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + idx.format() + assert idx._cache == {} + + df = pd.DataFrame({"a": range(10)}, index=idx) + + # df.__repr__ should not populate index cache + str(df) + assert idx._cache == {} + + df.loc[50] + assert idx._cache == {} + + with pytest.raises(KeyError, match="51"): + df.loc[51] + assert idx._cache == {} + + df.loc[10:50] + assert idx._cache == {} + + df.iloc[5:10] + assert idx._cache == {} + + # after calling take, _cache may contain other keys, but not "_data" + idx.take([3, 0, 1]) + assert "_data" not in idx._cache + + df.loc[[50]] + assert "_data" not in idx._cache + + df.iloc[[5, 6, 7, 8, 9]] + assert "_data" not in idx._cache + + # idx._cache should contain a _data entry after call to idx._data + idx._data + assert isinstance(idx._data, np.ndarray) + assert idx._data is idx._data # check cached value is reused + assert "_data" in idx._cache + expected = np.arange(0, 100, 10, dtype="int64") + tm.assert_numpy_array_equal(idx._cache["_data"], expected) + + def test_is_monotonic(self): + index = RangeIndex(0, 20, 2) + assert index.is_monotonic_increasing is True + assert index.is_monotonic_increasing is True + assert index.is_monotonic_decreasing is False + assert index._is_strictly_monotonic_increasing is True + assert index._is_strictly_monotonic_decreasing is False + + index = RangeIndex(4, 0, -1) + assert index.is_monotonic_increasing is False + assert index._is_strictly_monotonic_increasing is False + assert index.is_monotonic_decreasing is True + assert index._is_strictly_monotonic_decreasing is True + + index = RangeIndex(1, 2) + assert index.is_monotonic_increasing is True + assert index.is_monotonic_increasing is True + assert index.is_monotonic_decreasing is True + assert index._is_strictly_monotonic_increasing is True + assert index._is_strictly_monotonic_decreasing is True + + index = RangeIndex(2, 1) + assert index.is_monotonic_increasing is True + assert index.is_monotonic_increasing is True + assert index.is_monotonic_decreasing is True + assert index._is_strictly_monotonic_increasing is True + assert index._is_strictly_monotonic_decreasing is True + + index = RangeIndex(1, 1) + assert index.is_monotonic_increasing is True + assert index.is_monotonic_increasing is True + assert index.is_monotonic_decreasing is True + assert index._is_strictly_monotonic_increasing is True + assert index._is_strictly_monotonic_decreasing is True + + @pytest.mark.parametrize( + "left,right", + [ + (RangeIndex(0, 9, 2), RangeIndex(0, 10, 2)), + (RangeIndex(0), RangeIndex(1, -1, 3)), + (RangeIndex(1, 2, 3), RangeIndex(1, 3, 4)), + (RangeIndex(0, -9, -2), RangeIndex(0, -10, -2)), + ], + ) + def test_equals_range(self, left, right): + assert left.equals(right) + assert right.equals(left) + + def test_logical_compat(self, simple_index): + idx = simple_index + assert idx.all() == idx.values.all() + assert idx.any() == idx.values.any() + + def test_identical(self, simple_index): + index = simple_index + i = Index(index.copy()) + assert i.identical(index) + + # we don't allow object dtype for RangeIndex + if isinstance(index, RangeIndex): + return + + same_values_different_type = Index(i, dtype=object) + assert not i.identical(same_values_different_type) + + i = index.copy(dtype=object) + i = i.rename("foo") + same_values = Index(i, dtype=object) + assert same_values.identical(index.copy(dtype=object)) + + assert not i.identical(index) + assert Index(same_values, name="foo", dtype=object).identical(i) + + assert not index.copy(dtype=object).identical(index.copy(dtype="int64")) + + def test_nbytes(self): + # memory savings vs int index + idx = RangeIndex(0, 1000) + assert idx.nbytes < Index(idx._values).nbytes / 10 + + # constant memory usage + i2 = RangeIndex(0, 10) + assert idx.nbytes == i2.nbytes + + @pytest.mark.parametrize( + "start,stop,step", + [ + # can't + ("foo", "bar", "baz"), + # shouldn't + ("0", "1", "2"), + ], + ) + def test_cant_or_shouldnt_cast(self, start, stop, step): + msg = f"Wrong type {type(start)} for value {start}" + with pytest.raises(TypeError, match=msg): + RangeIndex(start, stop, step) + + def test_view_index(self, simple_index): + index = simple_index + msg = "Passing a type in RangeIndex.view is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + index.view(Index) + + def test_prevent_casting(self, simple_index): + index = simple_index + result = index.astype("O") + assert result.dtype == np.object_ + + def test_repr_roundtrip(self, simple_index): + index = simple_index + tm.assert_index_equal(eval(repr(index)), index) + + def test_slice_keep_name(self): + idx = RangeIndex(1, 2, name="asdf") + assert idx.name == idx[1:].name + + @pytest.mark.parametrize( + "index", + [ + RangeIndex(start=0, stop=20, step=2, name="foo"), + RangeIndex(start=18, stop=-1, step=-2, name="bar"), + ], + ids=["index_inc", "index_dec"], + ) + def test_has_duplicates(self, index): + assert index.is_unique + assert not index.has_duplicates + + def test_extended_gcd(self, simple_index): + index = simple_index + result = index._extended_gcd(6, 10) + assert result[0] == result[1] * 6 + result[2] * 10 + assert 2 == result[0] + + result = index._extended_gcd(10, 6) + assert 2 == result[1] * 10 + result[2] * 6 + assert 2 == result[0] + + def test_min_fitting_element(self): + result = RangeIndex(0, 20, 2)._min_fitting_element(1) + assert 2 == result + + result = RangeIndex(1, 6)._min_fitting_element(1) + assert 1 == result + + result = RangeIndex(18, -2, -2)._min_fitting_element(1) + assert 2 == result + + result = RangeIndex(5, 0, -1)._min_fitting_element(1) + assert 1 == result + + big_num = 500000000000000000000000 + + result = RangeIndex(5, big_num * 2, 1)._min_fitting_element(big_num) + assert big_num == result + + def test_slice_specialised(self, simple_index): + index = simple_index + index.name = "foo" + + # scalar indexing + res = index[1] + expected = 2 + assert res == expected + + res = index[-1] + expected = 18 + assert res == expected + + # slicing + # slice value completion + index_slice = index[:] + expected = index + tm.assert_index_equal(index_slice, expected) + + # positive slice values + index_slice = index[7:10:2] + expected = Index([14, 18], name="foo") + tm.assert_index_equal(index_slice, expected, exact="equiv") + + # negative slice values + index_slice = index[-1:-5:-2] + expected = Index([18, 14], name="foo") + tm.assert_index_equal(index_slice, expected, exact="equiv") + + # stop overshoot + index_slice = index[2:100:4] + expected = Index([4, 12], name="foo") + tm.assert_index_equal(index_slice, expected, exact="equiv") + + # reverse + index_slice = index[::-1] + expected = Index(index.values[::-1], name="foo") + tm.assert_index_equal(index_slice, expected, exact="equiv") + + index_slice = index[-8::-1] + expected = Index([4, 2, 0], name="foo") + tm.assert_index_equal(index_slice, expected, exact="equiv") + + index_slice = index[-40::-1] + expected = Index(np.array([], dtype=np.int64), name="foo") + tm.assert_index_equal(index_slice, expected, exact="equiv") + + index_slice = index[40::-1] + expected = Index(index.values[40::-1], name="foo") + tm.assert_index_equal(index_slice, expected, exact="equiv") + + index_slice = index[10::-1] + expected = Index(index.values[::-1], name="foo") + tm.assert_index_equal(index_slice, expected, exact="equiv") + + @pytest.mark.parametrize("step", set(range(-5, 6)) - {0}) + def test_len_specialised(self, step): + # make sure that our len is the same as np.arange calc + start, stop = (0, 5) if step > 0 else (5, 0) + + arr = np.arange(start, stop, step) + index = RangeIndex(start, stop, step) + assert len(index) == len(arr) + + index = RangeIndex(stop, start, step) + assert len(index) == 0 + + @pytest.mark.parametrize( + "indices, expected", + [ + ([RangeIndex(1, 12, 5)], RangeIndex(1, 12, 5)), + ([RangeIndex(0, 6, 4)], RangeIndex(0, 6, 4)), + ([RangeIndex(1, 3), RangeIndex(3, 7)], RangeIndex(1, 7)), + ([RangeIndex(1, 5, 2), RangeIndex(5, 6)], RangeIndex(1, 6, 2)), + ([RangeIndex(1, 3, 2), RangeIndex(4, 7, 3)], RangeIndex(1, 7, 3)), + ([RangeIndex(-4, 3, 2), RangeIndex(4, 7, 2)], RangeIndex(-4, 7, 2)), + ([RangeIndex(-4, -8), RangeIndex(-8, -12)], RangeIndex(0, 0)), + ([RangeIndex(-4, -8), RangeIndex(3, -4)], RangeIndex(0, 0)), + ([RangeIndex(-4, -8), RangeIndex(3, 5)], RangeIndex(3, 5)), + ([RangeIndex(-4, -2), RangeIndex(3, 5)], Index([-4, -3, 3, 4])), + ([RangeIndex(-2), RangeIndex(3, 5)], RangeIndex(3, 5)), + ([RangeIndex(2), RangeIndex(2)], Index([0, 1, 0, 1])), + ([RangeIndex(2), RangeIndex(2, 5), RangeIndex(5, 8, 4)], RangeIndex(0, 6)), + ( + [RangeIndex(2), RangeIndex(3, 5), RangeIndex(5, 8, 4)], + Index([0, 1, 3, 4, 5]), + ), + ( + [RangeIndex(-2, 2), RangeIndex(2, 5), RangeIndex(5, 8, 4)], + RangeIndex(-2, 6), + ), + ([RangeIndex(3), Index([-1, 3, 15])], Index([0, 1, 2, -1, 3, 15])), + ([RangeIndex(3), Index([-1, 3.1, 15.0])], Index([0, 1, 2, -1, 3.1, 15.0])), + ([RangeIndex(3), Index(["a", None, 14])], Index([0, 1, 2, "a", None, 14])), + ([RangeIndex(3, 1), Index(["a", None, 14])], Index(["a", None, 14])), + ], + ) + def test_append(self, indices, expected): + # GH16212 + result = indices[0].append(indices[1:]) + tm.assert_index_equal(result, expected, exact=True) + + if len(indices) == 2: + # Append single item rather than list + result2 = indices[0].append(indices[1]) + tm.assert_index_equal(result2, expected, exact=True) + + def test_engineless_lookup(self): + # GH 16685 + # Standard lookup on RangeIndex should not require the engine to be + # created + idx = RangeIndex(2, 10, 3) + + assert idx.get_loc(5) == 1 + tm.assert_numpy_array_equal( + idx.get_indexer([2, 8]), ensure_platform_int(np.array([0, 2])) + ) + with pytest.raises(KeyError, match="3"): + idx.get_loc(3) + + assert "_engine" not in idx._cache + + # Different types of scalars can be excluded immediately, no need to + # use the _engine + with pytest.raises(KeyError, match="'a'"): + idx.get_loc("a") + + assert "_engine" not in idx._cache + + def test_format_empty(self): + # GH35712 + empty_idx = RangeIndex(0) + msg = r"RangeIndex\.format is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + assert empty_idx.format() == [] + with tm.assert_produces_warning(FutureWarning, match=msg): + assert empty_idx.format(name=True) == [""] + + @pytest.mark.parametrize( + "ri", + [ + RangeIndex(0, -1, -1), + RangeIndex(0, 1, 1), + RangeIndex(1, 3, 2), + RangeIndex(0, -1, -2), + RangeIndex(-3, -5, -2), + ], + ) + def test_append_len_one(self, ri): + # GH39401 + result = ri.append([]) + tm.assert_index_equal(result, ri, exact=True) + + @pytest.mark.parametrize("base", [RangeIndex(0, 2), Index([0, 1])]) + def test_isin_range(self, base): + # GH#41151 + values = RangeIndex(0, 1) + result = base.isin(values) + expected = np.array([True, False]) + tm.assert_numpy_array_equal(result, expected) + + def test_sort_values_key(self): + # GH#43666, GH#52764 + sort_order = {8: 2, 6: 0, 4: 8, 2: 10, 0: 12} + values = RangeIndex(0, 10, 2) + result = values.sort_values(key=lambda x: x.map(sort_order)) + expected = Index([6, 8, 4, 2, 0], dtype="int64") + tm.assert_index_equal(result, expected, check_exact=True) + + # check this matches the Series.sort_values behavior + ser = values.to_series() + result2 = ser.sort_values(key=lambda x: x.map(sort_order)) + tm.assert_series_equal(result2, expected.to_series(), check_exact=True) + + def test_range_index_rsub_by_const(self): + # GH#53255 + result = 3 - RangeIndex(0, 4, 1) + expected = RangeIndex(3, -1, -1) + tm.assert_index_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/ranges/test_setops.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/ranges/test_setops.py new file mode 100644 index 0000000000000000000000000000000000000000..d417b8b743dc589bdf9d6acf5bde396a129ece23 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/ranges/test_setops.py @@ -0,0 +1,493 @@ +from datetime import ( + datetime, + timedelta, +) + +from hypothesis import ( + assume, + given, + strategies as st, +) +import numpy as np +import pytest + +from pandas import ( + Index, + RangeIndex, +) +import pandas._testing as tm + + +class TestRangeIndexSetOps: + @pytest.mark.parametrize("dtype", [None, "int64", "uint64"]) + def test_intersection_mismatched_dtype(self, dtype): + # check that we cast to float, not object + index = RangeIndex(start=0, stop=20, step=2, name="foo") + index = Index(index, dtype=dtype) + + flt = index.astype(np.float64) + + # bc index.equals(flt), we go through fastpath and get RangeIndex back + result = index.intersection(flt) + tm.assert_index_equal(result, index, exact=True) + + result = flt.intersection(index) + tm.assert_index_equal(result, flt, exact=True) + + # neither empty, not-equals + result = index.intersection(flt[1:]) + tm.assert_index_equal(result, flt[1:], exact=True) + + result = flt[1:].intersection(index) + tm.assert_index_equal(result, flt[1:], exact=True) + + # empty other + result = index.intersection(flt[:0]) + tm.assert_index_equal(result, flt[:0], exact=True) + + result = flt[:0].intersection(index) + tm.assert_index_equal(result, flt[:0], exact=True) + + def test_intersection_empty(self, sort, names): + # name retention on empty intersections + index = RangeIndex(start=0, stop=20, step=2, name=names[0]) + + # empty other + result = index.intersection(index[:0].rename(names[1]), sort=sort) + tm.assert_index_equal(result, index[:0].rename(names[2]), exact=True) + + # empty self + result = index[:0].intersection(index.rename(names[1]), sort=sort) + tm.assert_index_equal(result, index[:0].rename(names[2]), exact=True) + + def test_intersection(self, sort): + # intersect with Index with dtype int64 + index = RangeIndex(start=0, stop=20, step=2) + other = Index(np.arange(1, 6)) + result = index.intersection(other, sort=sort) + expected = Index(np.sort(np.intersect1d(index.values, other.values))) + tm.assert_index_equal(result, expected) + + result = other.intersection(index, sort=sort) + expected = Index( + np.sort(np.asarray(np.intersect1d(index.values, other.values))) + ) + tm.assert_index_equal(result, expected) + + # intersect with increasing RangeIndex + other = RangeIndex(1, 6) + result = index.intersection(other, sort=sort) + expected = Index(np.sort(np.intersect1d(index.values, other.values))) + tm.assert_index_equal(result, expected, exact="equiv") + + # intersect with decreasing RangeIndex + other = RangeIndex(5, 0, -1) + result = index.intersection(other, sort=sort) + expected = Index(np.sort(np.intersect1d(index.values, other.values))) + tm.assert_index_equal(result, expected, exact="equiv") + + # reversed (GH 17296) + result = other.intersection(index, sort=sort) + tm.assert_index_equal(result, expected, exact="equiv") + + # GH 17296: intersect two decreasing RangeIndexes + first = RangeIndex(10, -2, -2) + other = RangeIndex(5, -4, -1) + expected = first.astype(int).intersection(other.astype(int), sort=sort) + result = first.intersection(other, sort=sort).astype(int) + tm.assert_index_equal(result, expected) + + # reversed + result = other.intersection(first, sort=sort).astype(int) + tm.assert_index_equal(result, expected) + + index = RangeIndex(5, name="foo") + + # intersect of non-overlapping indices + other = RangeIndex(5, 10, 1, name="foo") + result = index.intersection(other, sort=sort) + expected = RangeIndex(0, 0, 1, name="foo") + tm.assert_index_equal(result, expected) + + other = RangeIndex(-1, -5, -1) + result = index.intersection(other, sort=sort) + expected = RangeIndex(0, 0, 1) + tm.assert_index_equal(result, expected) + + # intersection of empty indices + other = RangeIndex(0, 0, 1) + result = index.intersection(other, sort=sort) + expected = RangeIndex(0, 0, 1) + tm.assert_index_equal(result, expected) + + result = other.intersection(index, sort=sort) + tm.assert_index_equal(result, expected) + + def test_intersection_non_overlapping_gcd(self, sort, names): + # intersection of non-overlapping values based on start value and gcd + index = RangeIndex(1, 10, 2, name=names[0]) + other = RangeIndex(0, 10, 4, name=names[1]) + result = index.intersection(other, sort=sort) + expected = RangeIndex(0, 0, 1, name=names[2]) + tm.assert_index_equal(result, expected) + + def test_union_noncomparable(self, sort): + # corner case, Index with non-int64 dtype + index = RangeIndex(start=0, stop=20, step=2) + other = Index([datetime.now() + timedelta(i) for i in range(4)], dtype=object) + result = index.union(other, sort=sort) + expected = Index(np.concatenate((index, other))) + tm.assert_index_equal(result, expected) + + result = other.union(index, sort=sort) + expected = Index(np.concatenate((other, index))) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "idx1, idx2, expected_sorted, expected_notsorted", + [ + ( + RangeIndex(0, 10, 1), + RangeIndex(0, 10, 1), + RangeIndex(0, 10, 1), + RangeIndex(0, 10, 1), + ), + ( + RangeIndex(0, 10, 1), + RangeIndex(5, 20, 1), + RangeIndex(0, 20, 1), + RangeIndex(0, 20, 1), + ), + ( + RangeIndex(0, 10, 1), + RangeIndex(10, 20, 1), + RangeIndex(0, 20, 1), + RangeIndex(0, 20, 1), + ), + ( + RangeIndex(0, -10, -1), + RangeIndex(0, -10, -1), + RangeIndex(0, -10, -1), + RangeIndex(0, -10, -1), + ), + ( + RangeIndex(0, -10, -1), + RangeIndex(-10, -20, -1), + RangeIndex(-19, 1, 1), + RangeIndex(0, -20, -1), + ), + ( + RangeIndex(0, 10, 2), + RangeIndex(1, 10, 2), + RangeIndex(0, 10, 1), + Index(list(range(0, 10, 2)) + list(range(1, 10, 2))), + ), + ( + RangeIndex(0, 11, 2), + RangeIndex(1, 12, 2), + RangeIndex(0, 12, 1), + Index(list(range(0, 11, 2)) + list(range(1, 12, 2))), + ), + ( + RangeIndex(0, 21, 4), + RangeIndex(-2, 24, 4), + RangeIndex(-2, 24, 2), + Index(list(range(0, 21, 4)) + list(range(-2, 24, 4))), + ), + ( + RangeIndex(0, -20, -2), + RangeIndex(-1, -21, -2), + RangeIndex(-19, 1, 1), + Index(list(range(0, -20, -2)) + list(range(-1, -21, -2))), + ), + ( + RangeIndex(0, 100, 5), + RangeIndex(0, 100, 20), + RangeIndex(0, 100, 5), + RangeIndex(0, 100, 5), + ), + ( + RangeIndex(0, -100, -5), + RangeIndex(5, -100, -20), + RangeIndex(-95, 10, 5), + Index(list(range(0, -100, -5)) + [5]), + ), + ( + RangeIndex(0, -11, -1), + RangeIndex(1, -12, -4), + RangeIndex(-11, 2, 1), + Index(list(range(0, -11, -1)) + [1, -11]), + ), + (RangeIndex(0), RangeIndex(0), RangeIndex(0), RangeIndex(0)), + ( + RangeIndex(0, -10, -2), + RangeIndex(0), + RangeIndex(0, -10, -2), + RangeIndex(0, -10, -2), + ), + ( + RangeIndex(0, 100, 2), + RangeIndex(100, 150, 200), + RangeIndex(0, 102, 2), + RangeIndex(0, 102, 2), + ), + ( + RangeIndex(0, -100, -2), + RangeIndex(-100, 50, 102), + RangeIndex(-100, 4, 2), + Index(list(range(0, -100, -2)) + [-100, 2]), + ), + ( + RangeIndex(0, -100, -1), + RangeIndex(0, -50, -3), + RangeIndex(-99, 1, 1), + RangeIndex(0, -100, -1), + ), + ( + RangeIndex(0, 1, 1), + RangeIndex(5, 6, 10), + RangeIndex(0, 6, 5), + RangeIndex(0, 10, 5), + ), + ( + RangeIndex(0, 10, 5), + RangeIndex(-5, -6, -20), + RangeIndex(-5, 10, 5), + Index([0, 5, -5]), + ), + ( + RangeIndex(0, 3, 1), + RangeIndex(4, 5, 1), + Index([0, 1, 2, 4]), + Index([0, 1, 2, 4]), + ), + ( + RangeIndex(0, 10, 1), + Index([], dtype=np.int64), + RangeIndex(0, 10, 1), + RangeIndex(0, 10, 1), + ), + ( + RangeIndex(0), + Index([1, 5, 6]), + Index([1, 5, 6]), + Index([1, 5, 6]), + ), + # GH 43885 + ( + RangeIndex(0, 10), + RangeIndex(0, 5), + RangeIndex(0, 10), + RangeIndex(0, 10), + ), + ], + ids=lambda x: repr(x) if isinstance(x, RangeIndex) else x, + ) + def test_union_sorted(self, idx1, idx2, expected_sorted, expected_notsorted): + res1 = idx1.union(idx2, sort=None) + tm.assert_index_equal(res1, expected_sorted, exact=True) + + res1 = idx1.union(idx2, sort=False) + tm.assert_index_equal(res1, expected_notsorted, exact=True) + + res2 = idx2.union(idx1, sort=None) + res3 = Index(idx1._values, name=idx1.name).union(idx2, sort=None) + tm.assert_index_equal(res2, expected_sorted, exact=True) + tm.assert_index_equal(res3, expected_sorted, exact="equiv") + + def test_union_same_step_misaligned(self): + # GH#44019 + left = RangeIndex(range(0, 20, 4)) + right = RangeIndex(range(1, 21, 4)) + + result = left.union(right) + expected = Index([0, 1, 4, 5, 8, 9, 12, 13, 16, 17]) + tm.assert_index_equal(result, expected, exact=True) + + def test_difference(self): + # GH#12034 Cases where we operate against another RangeIndex and may + # get back another RangeIndex + obj = RangeIndex.from_range(range(1, 10), name="foo") + + result = obj.difference(obj) + expected = RangeIndex.from_range(range(0), name="foo") + tm.assert_index_equal(result, expected, exact=True) + + result = obj.difference(expected.rename("bar")) + tm.assert_index_equal(result, obj.rename(None), exact=True) + + result = obj.difference(obj[:3]) + tm.assert_index_equal(result, obj[3:], exact=True) + + result = obj.difference(obj[-3:]) + tm.assert_index_equal(result, obj[:-3], exact=True) + + # Flipping the step of 'other' doesn't affect the result, but + # flipping the stepof 'self' does when sort=None + result = obj[::-1].difference(obj[-3:]) + tm.assert_index_equal(result, obj[:-3], exact=True) + + result = obj[::-1].difference(obj[-3:], sort=False) + tm.assert_index_equal(result, obj[:-3][::-1], exact=True) + + result = obj[::-1].difference(obj[-3:][::-1]) + tm.assert_index_equal(result, obj[:-3], exact=True) + + result = obj[::-1].difference(obj[-3:][::-1], sort=False) + tm.assert_index_equal(result, obj[:-3][::-1], exact=True) + + result = obj.difference(obj[2:6]) + expected = Index([1, 2, 7, 8, 9], name="foo") + tm.assert_index_equal(result, expected, exact=True) + + def test_difference_sort(self): + # GH#44085 ensure we respect the sort keyword + + idx = Index(range(4))[::-1] + other = Index(range(3, 4)) + + result = idx.difference(other) + expected = Index(range(3)) + tm.assert_index_equal(result, expected, exact=True) + + result = idx.difference(other, sort=False) + expected = expected[::-1] + tm.assert_index_equal(result, expected, exact=True) + + # case where the intersection is empty + other = range(10, 12) + result = idx.difference(other, sort=None) + expected = idx[::-1] + tm.assert_index_equal(result, expected, exact=True) + + def test_difference_mismatched_step(self): + obj = RangeIndex.from_range(range(1, 10), name="foo") + + result = obj.difference(obj[::2]) + expected = obj[1::2] + tm.assert_index_equal(result, expected, exact=True) + + result = obj[::-1].difference(obj[::2], sort=False) + tm.assert_index_equal(result, expected[::-1], exact=True) + + result = obj.difference(obj[1::2]) + expected = obj[::2] + tm.assert_index_equal(result, expected, exact=True) + + result = obj[::-1].difference(obj[1::2], sort=False) + tm.assert_index_equal(result, expected[::-1], exact=True) + + def test_difference_interior_overlap_endpoints_preserved(self): + left = RangeIndex(range(4)) + right = RangeIndex(range(1, 3)) + + result = left.difference(right) + expected = RangeIndex(0, 4, 3) + assert expected.tolist() == [0, 3] + tm.assert_index_equal(result, expected, exact=True) + + def test_difference_endpoints_overlap_interior_preserved(self): + left = RangeIndex(-8, 20, 7) + right = RangeIndex(13, -9, -3) + + result = left.difference(right) + expected = RangeIndex(-1, 13, 7) + assert expected.tolist() == [-1, 6] + tm.assert_index_equal(result, expected, exact=True) + + def test_difference_interior_non_preserving(self): + # case with intersection of length 1 but RangeIndex is not preserved + idx = Index(range(10)) + + other = idx[3:4] + result = idx.difference(other) + expected = Index([0, 1, 2, 4, 5, 6, 7, 8, 9]) + tm.assert_index_equal(result, expected, exact=True) + + # case with other.step / self.step > 2 + other = idx[::3] + result = idx.difference(other) + expected = Index([1, 2, 4, 5, 7, 8]) + tm.assert_index_equal(result, expected, exact=True) + + # cases with only reaching one end of left + obj = Index(range(20)) + other = obj[:10:2] + result = obj.difference(other) + expected = Index([1, 3, 5, 7, 9] + list(range(10, 20))) + tm.assert_index_equal(result, expected, exact=True) + + other = obj[1:11:2] + result = obj.difference(other) + expected = Index([0, 2, 4, 6, 8, 10] + list(range(11, 20))) + tm.assert_index_equal(result, expected, exact=True) + + def test_symmetric_difference(self): + # GH#12034 Cases where we operate against another RangeIndex and may + # get back another RangeIndex + left = RangeIndex.from_range(range(1, 10), name="foo") + + result = left.symmetric_difference(left) + expected = RangeIndex.from_range(range(0), name="foo") + tm.assert_index_equal(result, expected) + + result = left.symmetric_difference(expected.rename("bar")) + tm.assert_index_equal(result, left.rename(None)) + + result = left[:-2].symmetric_difference(left[2:]) + expected = Index([1, 2, 8, 9], name="foo") + tm.assert_index_equal(result, expected, exact=True) + + right = RangeIndex.from_range(range(10, 15)) + + result = left.symmetric_difference(right) + expected = RangeIndex.from_range(range(1, 15)) + tm.assert_index_equal(result, expected) + + result = left.symmetric_difference(right[1:]) + expected = Index([1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14]) + tm.assert_index_equal(result, expected, exact=True) + + +def assert_range_or_not_is_rangelike(index): + """ + Check that we either have a RangeIndex or that this index *cannot* + be represented as a RangeIndex. + """ + if not isinstance(index, RangeIndex) and len(index) > 0: + diff = index[:-1] - index[1:] + assert not (diff == diff[0]).all() + + +@given( + st.integers(-20, 20), + st.integers(-20, 20), + st.integers(-20, 20), + st.integers(-20, 20), + st.integers(-20, 20), + st.integers(-20, 20), +) +def test_range_difference(start1, stop1, step1, start2, stop2, step2): + # test that + # a) we match Index[int64].difference and + # b) we return RangeIndex whenever it is possible to do so. + assume(step1 != 0) + assume(step2 != 0) + + left = RangeIndex(start1, stop1, step1) + right = RangeIndex(start2, stop2, step2) + + result = left.difference(right, sort=None) + assert_range_or_not_is_rangelike(result) + + left_int64 = Index(left.to_numpy()) + right_int64 = Index(right.to_numpy()) + + alt = left_int64.difference(right_int64, sort=None) + tm.assert_index_equal(result, alt, exact="equiv") + + result = left.difference(right, sort=False) + assert_range_or_not_is_rangelike(result) + + alt = left_int64.difference(right_int64, sort=False) + tm.assert_index_equal(result, alt, exact="equiv") diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/test_any_index.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/test_any_index.py new file mode 100644 index 0000000000000000000000000000000000000000..10204cfb78e8928dd69e0ea33ce40b02840959ed --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/test_any_index.py @@ -0,0 +1,172 @@ +""" +Tests that can be parametrized over _any_ Index object. +""" +import re + +import numpy as np +import pytest + +from pandas.errors import InvalidIndexError + +import pandas._testing as tm + + +def test_boolean_context_compat(index): + # GH#7897 + with pytest.raises(ValueError, match="The truth value of a"): + if index: + pass + + with pytest.raises(ValueError, match="The truth value of a"): + bool(index) + + +def test_sort(index): + msg = "cannot sort an Index object in-place, use sort_values instead" + with pytest.raises(TypeError, match=msg): + index.sort() + + +def test_hash_error(index): + with pytest.raises(TypeError, match=f"unhashable type: '{type(index).__name__}'"): + hash(index) + + +def test_mutability(index): + if not len(index): + pytest.skip("Test doesn't make sense for empty index") + msg = "Index does not support mutable operations" + with pytest.raises(TypeError, match=msg): + index[0] = index[0] + + +@pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning") +def test_map_identity_mapping(index, request): + # GH#12766 + + result = index.map(lambda x: x) + if index.dtype == object and result.dtype == bool: + assert (index == result).all() + # TODO: could work that into the 'exact="equiv"'? + return # FIXME: doesn't belong in this file anymore! + tm.assert_index_equal(result, index, exact="equiv") + + +def test_wrong_number_names(index): + names = index.nlevels * ["apple", "banana", "carrot"] + with pytest.raises(ValueError, match="^Length"): + index.names = names + + +def test_view_preserves_name(index): + assert index.view().name == index.name + + +def test_ravel(index): + # GH#19956 ravel returning ndarray is deprecated, in 2.0 returns a view on self + res = index.ravel() + tm.assert_index_equal(res, index) + + +class TestConversion: + def test_to_series(self, index): + # assert that we are creating a copy of the index + + ser = index.to_series() + assert ser.values is not index.values + assert ser.index is not index + assert ser.name == index.name + + def test_to_series_with_arguments(self, index): + # GH#18699 + + # index kwarg + ser = index.to_series(index=index) + + assert ser.values is not index.values + assert ser.index is index + assert ser.name == index.name + + # name kwarg + ser = index.to_series(name="__test") + + assert ser.values is not index.values + assert ser.index is not index + assert ser.name != index.name + + def test_tolist_matches_list(self, index): + assert index.tolist() == list(index) + + +class TestRoundTrips: + def test_pickle_roundtrip(self, index): + result = tm.round_trip_pickle(index) + tm.assert_index_equal(result, index, exact=True) + if result.nlevels > 1: + # GH#8367 round-trip with timezone + assert index.equal_levels(result) + + def test_pickle_preserves_name(self, index): + original_name, index.name = index.name, "foo" + unpickled = tm.round_trip_pickle(index) + assert index.equals(unpickled) + index.name = original_name + + +class TestIndexing: + def test_get_loc_listlike_raises_invalid_index_error(self, index): + # and never TypeError + key = np.array([0, 1], dtype=np.intp) + + with pytest.raises(InvalidIndexError, match=r"\[0 1\]"): + index.get_loc(key) + + with pytest.raises(InvalidIndexError, match=r"\[False True\]"): + index.get_loc(key.astype(bool)) + + def test_getitem_ellipsis(self, index): + # GH#21282 + result = index[...] + assert result.equals(index) + assert result is not index + + def test_slice_keeps_name(self, index): + assert index.name == index[1:].name + + @pytest.mark.parametrize("item", [101, "no_int", 2.5]) + def test_getitem_error(self, index, item): + msg = "|".join( + [ + r"index 101 is out of bounds for axis 0 with size [\d]+", + re.escape( + "only integers, slices (`:`), ellipsis (`...`), " + "numpy.newaxis (`None`) and integer or boolean arrays " + "are valid indices" + ), + "index out of bounds", # string[pyarrow] + ] + ) + with pytest.raises(IndexError, match=msg): + index[item] + + +class TestRendering: + def test_str(self, index): + # test the string repr + index.name = "foo" + assert "'foo'" in str(index) + assert type(index).__name__ in str(index) + + +class TestReductions: + def test_argmax_axis_invalid(self, index): + # GH#23081 + msg = r"`axis` must be fewer than the number of dimensions \(1\)" + with pytest.raises(ValueError, match=msg): + index.argmax(axis=1) + with pytest.raises(ValueError, match=msg): + index.argmin(axis=2) + with pytest.raises(ValueError, match=msg): + index.min(axis=-2) + with pytest.raises(ValueError, match=msg): + index.max(axis=-3) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/test_base.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/test_base.py new file mode 100644 index 0000000000000000000000000000000000000000..7eeb626d91dc8e19ec7768dc80bb9b08472c5b41 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/test_base.py @@ -0,0 +1,1737 @@ +from collections import defaultdict +from datetime import datetime +from functools import partial +import math +import operator +import re + +import numpy as np +import pytest + +from pandas.compat import IS64 +from pandas.errors import InvalidIndexError +import pandas.util._test_decorators as td + +from pandas.core.dtypes.common import ( + is_any_real_numeric_dtype, + is_numeric_dtype, + is_object_dtype, +) + +import pandas as pd +from pandas import ( + CategoricalIndex, + DataFrame, + DatetimeIndex, + IntervalIndex, + PeriodIndex, + RangeIndex, + Series, + TimedeltaIndex, + date_range, + period_range, + timedelta_range, +) +import pandas._testing as tm +from pandas.core.indexes.api import ( + Index, + MultiIndex, + _get_combined_index, + ensure_index, + ensure_index_from_sequences, +) + + +class TestIndex: + @pytest.fixture + def simple_index(self) -> Index: + return Index(list("abcde")) + + def test_can_hold_identifiers(self, simple_index): + index = simple_index + key = index[0] + assert index._can_hold_identifiers_and_holds_name(key) is True + + @pytest.mark.parametrize("index", ["datetime"], indirect=True) + def test_new_axis(self, index): + # TODO: a bunch of scattered tests check this deprecation is enforced. + # de-duplicate/centralize them. + with pytest.raises(ValueError, match="Multi-dimensional indexing"): + # GH#30588 multi-dimensional indexing deprecated + index[None, :] + + def test_constructor_regular(self, index): + tm.assert_contains_all(index, index) + + @pytest.mark.parametrize("index", ["string"], indirect=True) + def test_constructor_casting(self, index): + # casting + arr = np.array(index) + new_index = Index(arr) + tm.assert_contains_all(arr, new_index) + tm.assert_index_equal(index, new_index) + + def test_constructor_copy(self, using_infer_string): + index = Index(list("abc"), name="name") + arr = np.array(index) + new_index = Index(arr, copy=True, name="name") + assert isinstance(new_index, Index) + assert new_index.name == "name" + if using_infer_string: + tm.assert_extension_array_equal( + new_index.values, pd.array(arr, dtype="string[pyarrow_numpy]") + ) + else: + tm.assert_numpy_array_equal(arr, new_index.values) + arr[0] = "SOMEBIGLONGSTRING" + assert new_index[0] != "SOMEBIGLONGSTRING" + + @pytest.mark.parametrize("cast_as_obj", [True, False]) + @pytest.mark.parametrize( + "index", + [ + date_range( + "2015-01-01 10:00", + freq="D", + periods=3, + tz="US/Eastern", + name="Green Eggs & Ham", + ), # DTI with tz + date_range("2015-01-01 10:00", freq="D", periods=3), # DTI no tz + timedelta_range("1 days", freq="D", periods=3), # td + period_range("2015-01-01", freq="D", periods=3), # period + ], + ) + def test_constructor_from_index_dtlike(self, cast_as_obj, index): + if cast_as_obj: + with tm.assert_produces_warning(FutureWarning, match="Dtype inference"): + result = Index(index.astype(object)) + else: + result = Index(index) + + tm.assert_index_equal(result, index) + + if isinstance(index, DatetimeIndex): + assert result.tz == index.tz + if cast_as_obj: + # GH#23524 check that Index(dti, dtype=object) does not + # incorrectly raise ValueError, and that nanoseconds are not + # dropped + index += pd.Timedelta(nanoseconds=50) + result = Index(index, dtype=object) + assert result.dtype == np.object_ + assert list(result) == list(index) + + @pytest.mark.parametrize( + "index,has_tz", + [ + ( + date_range("2015-01-01 10:00", freq="D", periods=3, tz="US/Eastern"), + True, + ), # datetimetz + (timedelta_range("1 days", freq="D", periods=3), False), # td + (period_range("2015-01-01", freq="D", periods=3), False), # period + ], + ) + def test_constructor_from_series_dtlike(self, index, has_tz): + result = Index(Series(index)) + tm.assert_index_equal(result, index) + + if has_tz: + assert result.tz == index.tz + + def test_constructor_from_series_freq(self): + # GH 6273 + # create from a series, passing a freq + dts = ["1-1-1990", "2-1-1990", "3-1-1990", "4-1-1990", "5-1-1990"] + expected = DatetimeIndex(dts, freq="MS") + + s = Series(pd.to_datetime(dts)) + result = DatetimeIndex(s, freq="MS") + + tm.assert_index_equal(result, expected) + + def test_constructor_from_frame_series_freq(self, using_infer_string): + # GH 6273 + # create from a series, passing a freq + dts = ["1-1-1990", "2-1-1990", "3-1-1990", "4-1-1990", "5-1-1990"] + expected = DatetimeIndex(dts, freq="MS") + + df = DataFrame(np.random.default_rng(2).random((5, 3))) + df["date"] = dts + result = DatetimeIndex(df["date"], freq="MS") + dtype = object if not using_infer_string else "string" + assert df["date"].dtype == dtype + expected.name = "date" + tm.assert_index_equal(result, expected) + + expected = Series(dts, name="date") + tm.assert_series_equal(df["date"], expected) + + # GH 6274 + # infer freq of same + if not using_infer_string: + # Doesn't work with arrow strings + freq = pd.infer_freq(df["date"]) + assert freq == "MS" + + def test_constructor_int_dtype_nan(self): + # see gh-15187 + data = [np.nan] + expected = Index(data, dtype=np.float64) + result = Index(data, dtype="float") + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "klass,dtype,na_val", + [ + (Index, np.float64, np.nan), + (DatetimeIndex, "datetime64[ns]", pd.NaT), + ], + ) + def test_index_ctor_infer_nan_nat(self, klass, dtype, na_val): + # GH 13467 + na_list = [na_val, na_val] + expected = klass(na_list) + assert expected.dtype == dtype + + result = Index(na_list) + tm.assert_index_equal(result, expected) + + result = Index(np.array(na_list)) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "vals,dtype", + [ + ([1, 2, 3, 4, 5], "int"), + ([1.1, np.nan, 2.2, 3.0], "float"), + (["A", "B", "C", np.nan], "obj"), + ], + ) + def test_constructor_simple_new(self, vals, dtype): + index = Index(vals, name=dtype) + result = index._simple_new(index.values, dtype) + tm.assert_index_equal(result, index) + + @pytest.mark.parametrize("attr", ["values", "asi8"]) + @pytest.mark.parametrize("klass", [Index, DatetimeIndex]) + def test_constructor_dtypes_datetime(self, tz_naive_fixture, attr, klass): + # Test constructing with a datetimetz dtype + # .values produces numpy datetimes, so these are considered naive + # .asi8 produces integers, so these are considered epoch timestamps + # ^the above will be true in a later version. Right now we `.view` + # the i8 values as NS_DTYPE, effectively treating them as wall times. + index = date_range("2011-01-01", periods=5) + arg = getattr(index, attr) + index = index.tz_localize(tz_naive_fixture) + dtype = index.dtype + + # As of 2.0 astype raises on dt64.astype(dt64tz) + err = tz_naive_fixture is not None + msg = "Cannot use .astype to convert from timezone-naive dtype to" + + if attr == "asi8": + result = DatetimeIndex(arg).tz_localize(tz_naive_fixture) + tm.assert_index_equal(result, index) + elif klass is Index: + with pytest.raises(TypeError, match="unexpected keyword"): + klass(arg, tz=tz_naive_fixture) + else: + result = klass(arg, tz=tz_naive_fixture) + tm.assert_index_equal(result, index) + + if attr == "asi8": + if err: + with pytest.raises(TypeError, match=msg): + DatetimeIndex(arg).astype(dtype) + else: + result = DatetimeIndex(arg).astype(dtype) + tm.assert_index_equal(result, index) + else: + result = klass(arg, dtype=dtype) + tm.assert_index_equal(result, index) + + if attr == "asi8": + result = DatetimeIndex(list(arg)).tz_localize(tz_naive_fixture) + tm.assert_index_equal(result, index) + elif klass is Index: + with pytest.raises(TypeError, match="unexpected keyword"): + klass(arg, tz=tz_naive_fixture) + else: + result = klass(list(arg), tz=tz_naive_fixture) + tm.assert_index_equal(result, index) + + if attr == "asi8": + if err: + with pytest.raises(TypeError, match=msg): + DatetimeIndex(list(arg)).astype(dtype) + else: + result = DatetimeIndex(list(arg)).astype(dtype) + tm.assert_index_equal(result, index) + else: + result = klass(list(arg), dtype=dtype) + tm.assert_index_equal(result, index) + + @pytest.mark.parametrize("attr", ["values", "asi8"]) + @pytest.mark.parametrize("klass", [Index, TimedeltaIndex]) + def test_constructor_dtypes_timedelta(self, attr, klass): + index = timedelta_range("1 days", periods=5) + index = index._with_freq(None) # won't be preserved by constructors + dtype = index.dtype + + values = getattr(index, attr) + + result = klass(values, dtype=dtype) + tm.assert_index_equal(result, index) + + result = klass(list(values), dtype=dtype) + tm.assert_index_equal(result, index) + + @pytest.mark.parametrize("value", [[], iter([]), (_ for _ in [])]) + @pytest.mark.parametrize( + "klass", + [ + Index, + CategoricalIndex, + DatetimeIndex, + TimedeltaIndex, + ], + ) + def test_constructor_empty(self, value, klass): + empty = klass(value) + assert isinstance(empty, klass) + assert not len(empty) + + @pytest.mark.parametrize( + "empty,klass", + [ + (PeriodIndex([], freq="D"), PeriodIndex), + (PeriodIndex(iter([]), freq="D"), PeriodIndex), + (PeriodIndex((_ for _ in []), freq="D"), PeriodIndex), + (RangeIndex(step=1), RangeIndex), + (MultiIndex(levels=[[1, 2], ["blue", "red"]], codes=[[], []]), MultiIndex), + ], + ) + def test_constructor_empty_special(self, empty, klass): + assert isinstance(empty, klass) + assert not len(empty) + + @pytest.mark.parametrize( + "index", + [ + "datetime", + "float64", + "float32", + "int64", + "int32", + "period", + "range", + "repeats", + "timedelta", + "tuples", + "uint64", + "uint32", + ], + indirect=True, + ) + def test_view_with_args(self, index): + index.view("i8") + + @pytest.mark.parametrize( + "index", + [ + "string", + pytest.param("categorical", marks=pytest.mark.xfail(reason="gh-25464")), + "bool-object", + "bool-dtype", + "empty", + ], + indirect=True, + ) + def test_view_with_args_object_array_raises(self, index): + if index.dtype == bool: + msg = "When changing to a larger dtype" + with pytest.raises(ValueError, match=msg): + index.view("i8") + elif index.dtype == "string": + with pytest.raises(NotImplementedError, match="i8"): + index.view("i8") + else: + msg = ( + "Cannot change data-type for array of references|" + "Cannot change data-type for object array|" + ) + with pytest.raises(TypeError, match=msg): + index.view("i8") + + @pytest.mark.parametrize( + "index", + ["int64", "int32", "range"], + indirect=True, + ) + def test_astype(self, index): + casted = index.astype("i8") + + # it works! + casted.get_loc(5) + + # pass on name + index.name = "foobar" + casted = index.astype("i8") + assert casted.name == "foobar" + + def test_equals_object(self): + # same + assert Index(["a", "b", "c"]).equals(Index(["a", "b", "c"])) + + @pytest.mark.parametrize( + "comp", [Index(["a", "b"]), Index(["a", "b", "d"]), ["a", "b", "c"]] + ) + def test_not_equals_object(self, comp): + assert not Index(["a", "b", "c"]).equals(comp) + + def test_identical(self): + # index + i1 = Index(["a", "b", "c"]) + i2 = Index(["a", "b", "c"]) + + assert i1.identical(i2) + + i1 = i1.rename("foo") + assert i1.equals(i2) + assert not i1.identical(i2) + + i2 = i2.rename("foo") + assert i1.identical(i2) + + i3 = Index([("a", "a"), ("a", "b"), ("b", "a")]) + i4 = Index([("a", "a"), ("a", "b"), ("b", "a")], tupleize_cols=False) + assert not i3.identical(i4) + + def test_is_(self): + ind = Index(range(10)) + assert ind.is_(ind) + assert ind.is_(ind.view().view().view().view()) + assert not ind.is_(Index(range(10))) + assert not ind.is_(ind.copy()) + assert not ind.is_(ind.copy(deep=False)) + assert not ind.is_(ind[:]) + assert not ind.is_(np.array(range(10))) + + # quasi-implementation dependent + assert ind.is_(ind.view()) + ind2 = ind.view() + ind2.name = "bob" + assert ind.is_(ind2) + assert ind2.is_(ind) + # doesn't matter if Indices are *actually* views of underlying data, + assert not ind.is_(Index(ind.values)) + arr = np.array(range(1, 11)) + ind1 = Index(arr, copy=False) + ind2 = Index(arr, copy=False) + assert not ind1.is_(ind2) + + def test_asof_numeric_vs_bool_raises(self): + left = Index([1, 2, 3]) + right = Index([True, False], dtype=object) + + msg = "Cannot compare dtypes int64 and bool" + with pytest.raises(TypeError, match=msg): + left.asof(right[0]) + # TODO: should right.asof(left[0]) also raise? + + with pytest.raises(InvalidIndexError, match=re.escape(str(right))): + left.asof(right) + + with pytest.raises(InvalidIndexError, match=re.escape(str(left))): + right.asof(left) + + @pytest.mark.parametrize("index", ["string"], indirect=True) + def test_booleanindex(self, index): + bool_index = np.ones(len(index), dtype=bool) + bool_index[5:30:2] = False + + sub_index = index[bool_index] + + for i, val in enumerate(sub_index): + assert sub_index.get_loc(val) == i + + sub_index = index[list(bool_index)] + for i, val in enumerate(sub_index): + assert sub_index.get_loc(val) == i + + def test_fancy(self, simple_index): + index = simple_index + sl = index[[1, 2, 3]] + for i in sl: + assert i == sl[sl.get_loc(i)] + + @pytest.mark.parametrize( + "index", + ["string", "int64", "int32", "uint64", "uint32", "float64", "float32"], + indirect=True, + ) + @pytest.mark.parametrize("dtype", [int, np.bool_]) + def test_empty_fancy(self, index, dtype, request, using_infer_string): + if dtype is np.bool_ and using_infer_string and index.dtype == "string": + request.applymarker(pytest.mark.xfail(reason="numpy behavior is buggy")) + empty_arr = np.array([], dtype=dtype) + empty_index = type(index)([], dtype=index.dtype) + + assert index[[]].identical(empty_index) + if dtype == np.bool_: + with tm.assert_produces_warning(FutureWarning, match="is deprecated"): + assert index[empty_arr].identical(empty_index) + else: + assert index[empty_arr].identical(empty_index) + + @pytest.mark.parametrize( + "index", + ["string", "int64", "int32", "uint64", "uint32", "float64", "float32"], + indirect=True, + ) + def test_empty_fancy_raises(self, index): + # DatetimeIndex is excluded, because it overrides getitem and should + # be tested separately. + empty_farr = np.array([], dtype=np.float64) + empty_index = type(index)([], dtype=index.dtype) + + assert index[[]].identical(empty_index) + # np.ndarray only accepts ndarray of int & bool dtypes, so should Index + msg = r"arrays used as indices must be of integer" + with pytest.raises(IndexError, match=msg): + index[empty_farr] + + def test_union_dt_as_obj(self, simple_index): + # TODO: Replace with fixturesult + index = simple_index + date_index = date_range("2019-01-01", periods=10) + first_cat = index.union(date_index) + second_cat = index.union(index) + + appended = Index(np.append(index, date_index.astype("O"))) + + tm.assert_index_equal(first_cat, appended) + tm.assert_index_equal(second_cat, index) + tm.assert_contains_all(index, first_cat) + tm.assert_contains_all(index, second_cat) + tm.assert_contains_all(date_index, first_cat) + + def test_map_with_tuples(self): + # GH 12766 + + # Test that returning a single tuple from an Index + # returns an Index. + index = Index(np.arange(3), dtype=np.int64) + result = index.map(lambda x: (x,)) + expected = Index([(i,) for i in index]) + tm.assert_index_equal(result, expected) + + # Test that returning a tuple from a map of a single index + # returns a MultiIndex object. + result = index.map(lambda x: (x, x == 1)) + expected = MultiIndex.from_tuples([(i, i == 1) for i in index]) + tm.assert_index_equal(result, expected) + + def test_map_with_tuples_mi(self): + # Test that returning a single object from a MultiIndex + # returns an Index. + first_level = ["foo", "bar", "baz"] + multi_index = MultiIndex.from_tuples(zip(first_level, [1, 2, 3])) + reduced_index = multi_index.map(lambda x: x[0]) + tm.assert_index_equal(reduced_index, Index(first_level)) + + @pytest.mark.parametrize( + "index", + [ + date_range("2020-01-01", freq="D", periods=10), + period_range("2020-01-01", freq="D", periods=10), + timedelta_range("1 day", periods=10), + ], + ) + def test_map_tseries_indices_return_index(self, index): + expected = Index([1] * 10) + result = index.map(lambda x: 1) + tm.assert_index_equal(expected, result) + + def test_map_tseries_indices_accsr_return_index(self): + date_index = DatetimeIndex( + date_range("2020-01-01", periods=24, freq="h"), name="hourly" + ) + result = date_index.map(lambda x: x.hour) + expected = Index(np.arange(24, dtype="int64"), name="hourly") + tm.assert_index_equal(result, expected, exact=True) + + @pytest.mark.parametrize( + "mapper", + [ + lambda values, index: {i: e for e, i in zip(values, index)}, + lambda values, index: Series(values, index), + ], + ) + def test_map_dictlike_simple(self, mapper): + # GH 12756 + expected = Index(["foo", "bar", "baz"]) + index = Index(np.arange(3), dtype=np.int64) + result = index.map(mapper(expected.values, index)) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "mapper", + [ + lambda values, index: {i: e for e, i in zip(values, index)}, + lambda values, index: Series(values, index), + ], + ) + @pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning") + def test_map_dictlike(self, index, mapper, request): + # GH 12756 + if isinstance(index, CategoricalIndex): + pytest.skip("Tested in test_categorical") + elif not index.is_unique: + pytest.skip("Cannot map duplicated index") + + rng = np.arange(len(index), 0, -1, dtype=np.int64) + + if index.empty: + # to match proper result coercion for uints + expected = Index([]) + elif is_numeric_dtype(index.dtype): + expected = index._constructor(rng, dtype=index.dtype) + elif type(index) is Index and index.dtype != object: + # i.e. EA-backed, for now just Nullable + expected = Index(rng, dtype=index.dtype) + else: + expected = Index(rng) + + result = index.map(mapper(expected, index)) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "mapper", + [Series(["foo", 2.0, "baz"], index=[0, 2, -1]), {0: "foo", 2: 2.0, -1: "baz"}], + ) + def test_map_with_non_function_missing_values(self, mapper): + # GH 12756 + expected = Index([2.0, np.nan, "foo"]) + result = Index([2, 1, 0]).map(mapper) + + tm.assert_index_equal(expected, result) + + def test_map_na_exclusion(self): + index = Index([1.5, np.nan, 3, np.nan, 5]) + + result = index.map(lambda x: x * 2, na_action="ignore") + expected = index * 2 + tm.assert_index_equal(result, expected) + + def test_map_defaultdict(self): + index = Index([1, 2, 3]) + default_dict = defaultdict(lambda: "blank") + default_dict[1] = "stuff" + result = index.map(default_dict) + expected = Index(["stuff", "blank", "blank"]) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize("name,expected", [("foo", "foo"), ("bar", None)]) + def test_append_empty_preserve_name(self, name, expected): + left = Index([], name="foo") + right = Index([1, 2, 3], name=name) + + msg = "The behavior of array concatenation with empty entries is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = left.append(right) + assert result.name == expected + + @pytest.mark.parametrize( + "index, expected", + [ + ("string", False), + ("bool-object", False), + ("bool-dtype", False), + ("categorical", False), + ("int64", True), + ("int32", True), + ("uint64", True), + ("uint32", True), + ("datetime", False), + ("float64", True), + ("float32", True), + ], + indirect=["index"], + ) + def test_is_numeric(self, index, expected): + assert is_any_real_numeric_dtype(index) is expected + + @pytest.mark.parametrize( + "index, expected", + [ + ("string", True), + ("bool-object", True), + ("bool-dtype", False), + ("categorical", False), + ("int64", False), + ("int32", False), + ("uint64", False), + ("uint32", False), + ("datetime", False), + ("float64", False), + ("float32", False), + ], + indirect=["index"], + ) + def test_is_object(self, index, expected, using_infer_string): + if using_infer_string and index.dtype == "string" and expected: + expected = False + assert is_object_dtype(index) is expected + + def test_summary(self, index): + index._summary() + + def test_format_bug(self): + # GH 14626 + # windows has different precision on datetime.datetime.now (it doesn't + # include us since the default for Timestamp shows these but Index + # formatting does not we are skipping) + now = datetime.now() + msg = r"Index\.format is deprecated" + + if not str(now).endswith("000"): + index = Index([now]) + with tm.assert_produces_warning(FutureWarning, match=msg): + formatted = index.format() + expected = [str(index[0])] + assert formatted == expected + + with tm.assert_produces_warning(FutureWarning, match=msg): + Index([]).format() + + @pytest.mark.parametrize("vals", [[1, 2.0 + 3.0j, 4.0], ["a", "b", "c"]]) + def test_format_missing(self, vals, nulls_fixture): + # 2845 + vals = list(vals) # Copy for each iteration + vals.append(nulls_fixture) + index = Index(vals, dtype=object) + # TODO: case with complex dtype? + + msg = r"Index\.format is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + formatted = index.format() + null_repr = "NaN" if isinstance(nulls_fixture, float) else str(nulls_fixture) + expected = [str(index[0]), str(index[1]), str(index[2]), null_repr] + + assert formatted == expected + assert index[3] is nulls_fixture + + @pytest.mark.parametrize("op", ["any", "all"]) + def test_logical_compat(self, op, simple_index): + index = simple_index + left = getattr(index, op)() + assert left == getattr(index.values, op)() + right = getattr(index.to_series(), op)() + # left might not match right exactly in e.g. string cases where the + # because we use np.any/all instead of .any/all + assert bool(left) == bool(right) + + @pytest.mark.parametrize( + "index", ["string", "int64", "int32", "float64", "float32"], indirect=True + ) + def test_drop_by_str_label(self, index): + n = len(index) + drop = index[list(range(5, 10))] + dropped = index.drop(drop) + + expected = index[list(range(5)) + list(range(10, n))] + tm.assert_index_equal(dropped, expected) + + dropped = index.drop(index[0]) + expected = index[1:] + tm.assert_index_equal(dropped, expected) + + @pytest.mark.parametrize( + "index", ["string", "int64", "int32", "float64", "float32"], indirect=True + ) + @pytest.mark.parametrize("keys", [["foo", "bar"], ["1", "bar"]]) + def test_drop_by_str_label_raises_missing_keys(self, index, keys): + with pytest.raises(KeyError, match=""): + index.drop(keys) + + @pytest.mark.parametrize( + "index", ["string", "int64", "int32", "float64", "float32"], indirect=True + ) + def test_drop_by_str_label_errors_ignore(self, index): + n = len(index) + drop = index[list(range(5, 10))] + mixed = drop.tolist() + ["foo"] + dropped = index.drop(mixed, errors="ignore") + + expected = index[list(range(5)) + list(range(10, n))] + tm.assert_index_equal(dropped, expected) + + dropped = index.drop(["foo", "bar"], errors="ignore") + expected = index[list(range(n))] + tm.assert_index_equal(dropped, expected) + + def test_drop_by_numeric_label_loc(self): + # TODO: Parametrize numeric and str tests after self.strIndex fixture + index = Index([1, 2, 3]) + dropped = index.drop(1) + expected = Index([2, 3]) + + tm.assert_index_equal(dropped, expected) + + def test_drop_by_numeric_label_raises_missing_keys(self): + index = Index([1, 2, 3]) + with pytest.raises(KeyError, match=""): + index.drop([3, 4]) + + @pytest.mark.parametrize( + "key,expected", [(4, Index([1, 2, 3])), ([3, 4, 5], Index([1, 2]))] + ) + def test_drop_by_numeric_label_errors_ignore(self, key, expected): + index = Index([1, 2, 3]) + dropped = index.drop(key, errors="ignore") + + tm.assert_index_equal(dropped, expected) + + @pytest.mark.parametrize( + "values", + [["a", "b", ("c", "d")], ["a", ("c", "d"), "b"], [("c", "d"), "a", "b"]], + ) + @pytest.mark.parametrize("to_drop", [[("c", "d"), "a"], ["a", ("c", "d")]]) + def test_drop_tuple(self, values, to_drop): + # GH 18304 + index = Index(values) + expected = Index(["b"], dtype=object) + + result = index.drop(to_drop) + tm.assert_index_equal(result, expected) + + removed = index.drop(to_drop[0]) + for drop_me in to_drop[1], [to_drop[1]]: + result = removed.drop(drop_me) + tm.assert_index_equal(result, expected) + + removed = index.drop(to_drop[1]) + msg = rf"\"\[{re.escape(to_drop[1].__repr__())}\] not found in axis\"" + for drop_me in to_drop[1], [to_drop[1]]: + with pytest.raises(KeyError, match=msg): + removed.drop(drop_me) + + @pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning") + def test_drop_with_duplicates_in_index(self, index): + # GH38051 + if len(index) == 0 or isinstance(index, MultiIndex): + pytest.skip("Test doesn't make sense for empty MultiIndex") + if isinstance(index, IntervalIndex) and not IS64: + pytest.skip("Cannot test IntervalIndex with int64 dtype on 32 bit platform") + index = index.unique().repeat(2) + expected = index[2:] + result = index.drop(index[0]) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "attr", + [ + "is_monotonic_increasing", + "is_monotonic_decreasing", + "_is_strictly_monotonic_increasing", + "_is_strictly_monotonic_decreasing", + ], + ) + def test_is_monotonic_incomparable(self, attr): + index = Index([5, datetime.now(), 7]) + assert not getattr(index, attr) + + @pytest.mark.parametrize("values", [["foo", "bar", "quux"], {"foo", "bar", "quux"}]) + @pytest.mark.parametrize( + "index,expected", + [ + (Index(["qux", "baz", "foo", "bar"]), np.array([False, False, True, True])), + (Index([]), np.array([], dtype=bool)), # empty + ], + ) + def test_isin(self, values, index, expected): + result = index.isin(values) + tm.assert_numpy_array_equal(result, expected) + + def test_isin_nan_common_object( + self, nulls_fixture, nulls_fixture2, using_infer_string + ): + # Test cartesian product of null fixtures and ensure that we don't + # mangle the various types (save a corner case with PyPy) + idx = Index(["a", nulls_fixture]) + + # all nans are the same + if ( + isinstance(nulls_fixture, float) + and isinstance(nulls_fixture2, float) + and math.isnan(nulls_fixture) + and math.isnan(nulls_fixture2) + ): + tm.assert_numpy_array_equal( + idx.isin([nulls_fixture2]), + np.array([False, True]), + ) + + elif nulls_fixture is nulls_fixture2: # should preserve NA type + tm.assert_numpy_array_equal( + idx.isin([nulls_fixture2]), + np.array([False, True]), + ) + + elif using_infer_string and idx.dtype == "string": + tm.assert_numpy_array_equal( + idx.isin([nulls_fixture2]), + np.array([False, True]), + ) + + else: + tm.assert_numpy_array_equal( + idx.isin([nulls_fixture2]), + np.array([False, False]), + ) + + def test_isin_nan_common_float64(self, nulls_fixture, float_numpy_dtype): + dtype = float_numpy_dtype + + if nulls_fixture is pd.NaT or nulls_fixture is pd.NA: + # Check 1) that we cannot construct a float64 Index with this value + # and 2) that with an NaN we do not have .isin(nulls_fixture) + msg = ( + r"float\(\) argument must be a string or a (real )?number, " + f"not {repr(type(nulls_fixture).__name__)}" + ) + with pytest.raises(TypeError, match=msg): + Index([1.0, nulls_fixture], dtype=dtype) + + idx = Index([1.0, np.nan], dtype=dtype) + assert not idx.isin([nulls_fixture]).any() + return + + idx = Index([1.0, nulls_fixture], dtype=dtype) + res = idx.isin([np.nan]) + tm.assert_numpy_array_equal(res, np.array([False, True])) + + # we cannot compare NaT with NaN + res = idx.isin([pd.NaT]) + tm.assert_numpy_array_equal(res, np.array([False, False])) + + @pytest.mark.parametrize("level", [0, -1]) + @pytest.mark.parametrize( + "index", + [ + Index(["qux", "baz", "foo", "bar"]), + Index([1.0, 2.0, 3.0, 4.0], dtype=np.float64), + ], + ) + def test_isin_level_kwarg(self, level, index): + values = index.tolist()[-2:] + ["nonexisting"] + + expected = np.array([False, False, True, True]) + tm.assert_numpy_array_equal(expected, index.isin(values, level=level)) + + index.name = "foobar" + tm.assert_numpy_array_equal(expected, index.isin(values, level="foobar")) + + def test_isin_level_kwarg_bad_level_raises(self, index): + for level in [10, index.nlevels, -(index.nlevels + 1)]: + with pytest.raises(IndexError, match="Too many levels"): + index.isin([], level=level) + + @pytest.mark.parametrize("label", [1.0, "foobar", "xyzzy", np.nan]) + def test_isin_level_kwarg_bad_label_raises(self, label, index): + if isinstance(index, MultiIndex): + index = index.rename(["foo", "bar"] + index.names[2:]) + msg = f"'Level {label} not found'" + else: + index = index.rename("foo") + msg = rf"Requested level \({label}\) does not match index name \(foo\)" + with pytest.raises(KeyError, match=msg): + index.isin([], level=label) + + @pytest.mark.parametrize("empty", [[], Series(dtype=object), np.array([])]) + def test_isin_empty(self, empty): + # see gh-16991 + index = Index(["a", "b"]) + expected = np.array([False, False]) + + result = index.isin(empty) + tm.assert_numpy_array_equal(expected, result) + + @td.skip_if_no("pyarrow") + def test_isin_arrow_string_null(self): + # GH#55821 + index = Index(["a", "b"], dtype="string[pyarrow_numpy]") + result = index.isin([None]) + expected = np.array([False, False]) + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize( + "values", + [ + [1, 2, 3, 4], + [1.0, 2.0, 3.0, 4.0], + [True, True, True, True], + ["foo", "bar", "baz", "qux"], + date_range("2018-01-01", freq="D", periods=4), + ], + ) + def test_boolean_cmp(self, values): + index = Index(values) + result = index == values + expected = np.array([True, True, True, True], dtype=bool) + + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize("index", ["string"], indirect=True) + @pytest.mark.parametrize("name,level", [(None, 0), ("a", "a")]) + def test_get_level_values(self, index, name, level): + expected = index.copy() + if name: + expected.name = name + + result = expected.get_level_values(level) + tm.assert_index_equal(result, expected) + + def test_slice_keep_name(self): + index = Index(["a", "b"], name="asdf") + assert index.name == index[1:].name + + @pytest.mark.parametrize( + "index", + [ + "string", + "datetime", + "int64", + "int32", + "uint64", + "uint32", + "float64", + "float32", + ], + indirect=True, + ) + def test_join_self(self, index, join_type): + result = index.join(index, how=join_type) + expected = index + if join_type == "outer": + expected = expected.sort_values() + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize("method", ["strip", "rstrip", "lstrip"]) + def test_str_attribute(self, method): + # GH9068 + index = Index([" jack", "jill ", " jesse ", "frank"]) + expected = Index([getattr(str, method)(x) for x in index.values]) + + result = getattr(index.str, method)() + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "index", + [ + Index(range(5)), + date_range("2020-01-01", periods=10), + MultiIndex.from_tuples([("foo", "1"), ("bar", "3")]), + period_range(start="2000", end="2010", freq="Y"), + ], + ) + def test_str_attribute_raises(self, index): + with pytest.raises(AttributeError, match="only use .str accessor"): + index.str.repeat(2) + + @pytest.mark.parametrize( + "expand,expected", + [ + (None, Index([["a", "b", "c"], ["d", "e"], ["f"]])), + (False, Index([["a", "b", "c"], ["d", "e"], ["f"]])), + ( + True, + MultiIndex.from_tuples( + [("a", "b", "c"), ("d", "e", np.nan), ("f", np.nan, np.nan)] + ), + ), + ], + ) + def test_str_split(self, expand, expected): + index = Index(["a b c", "d e", "f"]) + if expand is not None: + result = index.str.split(expand=expand) + else: + result = index.str.split() + + tm.assert_index_equal(result, expected) + + def test_str_bool_return(self): + # test boolean case, should return np.array instead of boolean Index + index = Index(["a1", "a2", "b1", "b2"]) + result = index.str.startswith("a") + expected = np.array([True, True, False, False]) + + tm.assert_numpy_array_equal(result, expected) + assert isinstance(result, np.ndarray) + + def test_str_bool_series_indexing(self): + index = Index(["a1", "a2", "b1", "b2"]) + s = Series(range(4), index=index) + + result = s[s.index.str.startswith("a")] + expected = Series(range(2), index=["a1", "a2"]) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "index,expected", [(Index(list("abcd")), True), (Index(range(4)), False)] + ) + def test_tab_completion(self, index, expected): + # GH 9910 + result = "str" in dir(index) + assert result == expected + + def test_indexing_doesnt_change_class(self): + index = Index([1, 2, 3, "a", "b", "c"]) + + assert index[1:3].identical(Index([2, 3], dtype=np.object_)) + assert index[[0, 1]].identical(Index([1, 2], dtype=np.object_)) + + def test_outer_join_sort(self): + left_index = Index(np.random.default_rng(2).permutation(15)) + right_index = date_range("2020-01-01", periods=10) + + with tm.assert_produces_warning(RuntimeWarning): + result = left_index.join(right_index, how="outer") + + with tm.assert_produces_warning(RuntimeWarning): + expected = left_index.astype(object).union(right_index.astype(object)) + + tm.assert_index_equal(result, expected) + + def test_take_fill_value(self): + # GH 12631 + index = Index(list("ABC"), name="xxx") + result = index.take(np.array([1, 0, -1])) + expected = Index(list("BAC"), name="xxx") + tm.assert_index_equal(result, expected) + + # fill_value + result = index.take(np.array([1, 0, -1]), fill_value=True) + expected = Index(["B", "A", np.nan], name="xxx") + tm.assert_index_equal(result, expected) + + # allow_fill=False + result = index.take(np.array([1, 0, -1]), allow_fill=False, fill_value=True) + expected = Index(["B", "A", "C"], name="xxx") + tm.assert_index_equal(result, expected) + + def test_take_fill_value_none_raises(self): + index = Index(list("ABC"), name="xxx") + msg = ( + "When allow_fill=True and fill_value is not None, " + "all indices must be >= -1" + ) + + with pytest.raises(ValueError, match=msg): + index.take(np.array([1, 0, -2]), fill_value=True) + with pytest.raises(ValueError, match=msg): + index.take(np.array([1, 0, -5]), fill_value=True) + + def test_take_bad_bounds_raises(self): + index = Index(list("ABC"), name="xxx") + with pytest.raises(IndexError, match="out of bounds"): + index.take(np.array([1, -5])) + + @pytest.mark.parametrize("name", [None, "foobar"]) + @pytest.mark.parametrize( + "labels", + [ + [], + np.array([]), + ["A", "B", "C"], + ["C", "B", "A"], + np.array(["A", "B", "C"]), + np.array(["C", "B", "A"]), + # Must preserve name even if dtype changes + date_range("20130101", periods=3).values, + date_range("20130101", periods=3).tolist(), + ], + ) + def test_reindex_preserves_name_if_target_is_list_or_ndarray(self, name, labels): + # GH6552 + index = Index([0, 1, 2]) + index.name = name + assert index.reindex(labels)[0].name == name + + @pytest.mark.parametrize("labels", [[], np.array([]), np.array([], dtype=np.int64)]) + def test_reindex_preserves_type_if_target_is_empty_list_or_array(self, labels): + # GH7774 + index = Index(list("abc")) + assert index.reindex(labels)[0].dtype.type == index.dtype.type + + @pytest.mark.parametrize( + "labels,dtype", + [ + (DatetimeIndex([]), np.datetime64), + ], + ) + def test_reindex_doesnt_preserve_type_if_target_is_empty_index(self, labels, dtype): + # GH7774 + index = Index(list("abc")) + assert index.reindex(labels)[0].dtype.type == dtype + + def test_reindex_doesnt_preserve_type_if_target_is_empty_index_numeric( + self, any_real_numpy_dtype + ): + # GH7774 + dtype = any_real_numpy_dtype + index = Index(list("abc")) + labels = Index([], dtype=dtype) + assert index.reindex(labels)[0].dtype == dtype + + def test_reindex_no_type_preserve_target_empty_mi(self): + index = Index(list("abc")) + result = index.reindex( + MultiIndex([Index([], np.int64), Index([], np.float64)], [[], []]) + )[0] + assert result.levels[0].dtype.type == np.int64 + assert result.levels[1].dtype.type == np.float64 + + def test_reindex_ignoring_level(self): + # GH#35132 + idx = Index([1, 2, 3], name="x") + idx2 = Index([1, 2, 3, 4], name="x") + expected = Index([1, 2, 3, 4], name="x") + result, _ = idx.reindex(idx2, level="x") + tm.assert_index_equal(result, expected) + + def test_groupby(self): + index = Index(range(5)) + result = index.groupby(np.array([1, 1, 2, 2, 2])) + expected = {1: Index([0, 1]), 2: Index([2, 3, 4])} + + tm.assert_dict_equal(result, expected) + + @pytest.mark.parametrize( + "mi,expected", + [ + (MultiIndex.from_tuples([(1, 2), (4, 5)]), np.array([True, True])), + (MultiIndex.from_tuples([(1, 2), (4, 6)]), np.array([True, False])), + ], + ) + def test_equals_op_multiindex(self, mi, expected): + # GH9785 + # test comparisons of multiindex + df = DataFrame( + [3, 6], + columns=["c"], + index=MultiIndex.from_arrays([[1, 4], [2, 5]], names=["a", "b"]), + ) + + result = df.index == mi + tm.assert_numpy_array_equal(result, expected) + + def test_equals_op_multiindex_identify(self): + df = DataFrame( + [3, 6], + columns=["c"], + index=MultiIndex.from_arrays([[1, 4], [2, 5]], names=["a", "b"]), + ) + + result = df.index == df.index + expected = np.array([True, True]) + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize( + "index", + [ + MultiIndex.from_tuples([(1, 2), (4, 5), (8, 9)]), + Index(["foo", "bar", "baz"]), + ], + ) + def test_equals_op_mismatched_multiindex_raises(self, index): + df = DataFrame( + [3, 6], + columns=["c"], + index=MultiIndex.from_arrays([[1, 4], [2, 5]], names=["a", "b"]), + ) + + with pytest.raises(ValueError, match="Lengths must match"): + df.index == index + + def test_equals_op_index_vs_mi_same_length(self, using_infer_string): + mi = MultiIndex.from_tuples([(1, 2), (4, 5), (8, 9)]) + index = Index(["foo", "bar", "baz"]) + + result = mi == index + expected = np.array([False, False, False]) + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize( + "dt_conv, arg", + [ + (pd.to_datetime, ["2000-01-01", "2000-01-02"]), + (pd.to_timedelta, ["01:02:03", "01:02:04"]), + ], + ) + def test_dt_conversion_preserves_name(self, dt_conv, arg): + # GH 10875 + index = Index(arg, name="label") + assert index.name == dt_conv(index).name + + def test_cached_properties_not_settable(self): + index = Index([1, 2, 3]) + with pytest.raises(AttributeError, match="Can't set attribute"): + index.is_unique = False + + def test_tab_complete_warning(self, ip): + # https://github.com/pandas-dev/pandas/issues/16409 + pytest.importorskip("IPython", minversion="6.0.0") + from IPython.core.completer import provisionalcompleter + + code = "import pandas as pd; idx = pd.Index([1, 2])" + ip.run_cell(code) + + # GH 31324 newer jedi version raises Deprecation warning; + # appears resolved 2021-02-02 + with tm.assert_produces_warning(None, raise_on_extra_warnings=False): + with provisionalcompleter("ignore"): + list(ip.Completer.completions("idx.", 4)) + + def test_contains_method_removed(self, index): + # GH#30103 method removed for all types except IntervalIndex + if isinstance(index, IntervalIndex): + index.contains(1) + else: + msg = f"'{type(index).__name__}' object has no attribute 'contains'" + with pytest.raises(AttributeError, match=msg): + index.contains(1) + + def test_sortlevel(self): + index = Index([5, 4, 3, 2, 1]) + with pytest.raises(Exception, match="ascending must be a single bool value or"): + index.sortlevel(ascending="True") + + with pytest.raises( + Exception, match="ascending must be a list of bool values of length 1" + ): + index.sortlevel(ascending=[True, True]) + + with pytest.raises(Exception, match="ascending must be a bool value"): + index.sortlevel(ascending=["True"]) + + expected = Index([1, 2, 3, 4, 5]) + result = index.sortlevel(ascending=[True]) + tm.assert_index_equal(result[0], expected) + + expected = Index([1, 2, 3, 4, 5]) + result = index.sortlevel(ascending=True) + tm.assert_index_equal(result[0], expected) + + expected = Index([5, 4, 3, 2, 1]) + result = index.sortlevel(ascending=False) + tm.assert_index_equal(result[0], expected) + + def test_sortlevel_na_position(self): + # GH#51612 + idx = Index([1, np.nan]) + result = idx.sortlevel(na_position="first")[0] + expected = Index([np.nan, 1]) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "periods, expected_results", + [ + (1, [np.nan, 10, 10, 10, 10]), + (2, [np.nan, np.nan, 20, 20, 20]), + (3, [np.nan, np.nan, np.nan, 30, 30]), + ], + ) + def test_index_diff(self, periods, expected_results): + # GH#19708 + idx = Index([10, 20, 30, 40, 50]) + result = idx.diff(periods) + expected = Index(expected_results) + + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "decimals, expected_results", + [ + (0, [1.0, 2.0, 3.0]), + (1, [1.2, 2.3, 3.5]), + (2, [1.23, 2.35, 3.46]), + ], + ) + def test_index_round(self, decimals, expected_results): + # GH#19708 + idx = Index([1.234, 2.345, 3.456]) + result = idx.round(decimals) + expected = Index(expected_results) + + tm.assert_index_equal(result, expected) + + +class TestMixedIntIndex: + # Mostly the tests from common.py for which the results differ + # in py2 and py3 because ints and strings are uncomparable in py3 + # (GH 13514) + @pytest.fixture + def simple_index(self) -> Index: + return Index([0, "a", 1, "b", 2, "c"]) + + def test_argsort(self, simple_index): + index = simple_index + with pytest.raises(TypeError, match="'>|<' not supported"): + index.argsort() + + def test_numpy_argsort(self, simple_index): + index = simple_index + with pytest.raises(TypeError, match="'>|<' not supported"): + np.argsort(index) + + def test_copy_name(self, simple_index): + # Check that "name" argument passed at initialization is honoured + # GH12309 + index = simple_index + + first = type(index)(index, copy=True, name="mario") + second = type(first)(first, copy=False) + + # Even though "copy=False", we want a new object. + assert first is not second + tm.assert_index_equal(first, second) + + assert first.name == "mario" + assert second.name == "mario" + + s1 = Series(2, index=first) + s2 = Series(3, index=second[:-1]) + + s3 = s1 * s2 + + assert s3.index.name == "mario" + + def test_copy_name2(self): + # Check that adding a "name" parameter to the copy is honored + # GH14302 + index = Index([1, 2], name="MyName") + index1 = index.copy() + + tm.assert_index_equal(index, index1) + + index2 = index.copy(name="NewName") + tm.assert_index_equal(index, index2, check_names=False) + assert index.name == "MyName" + assert index2.name == "NewName" + + def test_unique_na(self): + idx = Index([2, np.nan, 2, 1], name="my_index") + expected = Index([2, np.nan, 1], name="my_index") + result = idx.unique() + tm.assert_index_equal(result, expected) + + def test_logical_compat(self, simple_index): + index = simple_index + assert index.all() == index.values.all() + assert index.any() == index.values.any() + + @pytest.mark.parametrize("how", ["any", "all"]) + @pytest.mark.parametrize("dtype", [None, object, "category"]) + @pytest.mark.parametrize( + "vals,expected", + [ + ([1, 2, 3], [1, 2, 3]), + ([1.0, 2.0, 3.0], [1.0, 2.0, 3.0]), + ([1.0, 2.0, np.nan, 3.0], [1.0, 2.0, 3.0]), + (["A", "B", "C"], ["A", "B", "C"]), + (["A", np.nan, "B", "C"], ["A", "B", "C"]), + ], + ) + def test_dropna(self, how, dtype, vals, expected): + # GH 6194 + index = Index(vals, dtype=dtype) + result = index.dropna(how=how) + expected = Index(expected, dtype=dtype) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize("how", ["any", "all"]) + @pytest.mark.parametrize( + "index,expected", + [ + ( + DatetimeIndex(["2011-01-01", "2011-01-02", "2011-01-03"]), + DatetimeIndex(["2011-01-01", "2011-01-02", "2011-01-03"]), + ), + ( + DatetimeIndex(["2011-01-01", "2011-01-02", "2011-01-03", pd.NaT]), + DatetimeIndex(["2011-01-01", "2011-01-02", "2011-01-03"]), + ), + ( + TimedeltaIndex(["1 days", "2 days", "3 days"]), + TimedeltaIndex(["1 days", "2 days", "3 days"]), + ), + ( + TimedeltaIndex([pd.NaT, "1 days", "2 days", "3 days", pd.NaT]), + TimedeltaIndex(["1 days", "2 days", "3 days"]), + ), + ( + PeriodIndex(["2012-02", "2012-04", "2012-05"], freq="M"), + PeriodIndex(["2012-02", "2012-04", "2012-05"], freq="M"), + ), + ( + PeriodIndex(["2012-02", "2012-04", "NaT", "2012-05"], freq="M"), + PeriodIndex(["2012-02", "2012-04", "2012-05"], freq="M"), + ), + ], + ) + def test_dropna_dt_like(self, how, index, expected): + result = index.dropna(how=how) + tm.assert_index_equal(result, expected) + + def test_dropna_invalid_how_raises(self): + msg = "invalid how option: xxx" + with pytest.raises(ValueError, match=msg): + Index([1, 2, 3]).dropna(how="xxx") + + @pytest.mark.parametrize( + "index", + [ + Index([np.nan]), + Index([np.nan, 1]), + Index([1, 2, np.nan]), + Index(["a", "b", np.nan]), + pd.to_datetime(["NaT"]), + pd.to_datetime(["NaT", "2000-01-01"]), + pd.to_datetime(["2000-01-01", "NaT", "2000-01-02"]), + pd.to_timedelta(["1 day", "NaT"]), + ], + ) + def test_is_monotonic_na(self, index): + assert index.is_monotonic_increasing is False + assert index.is_monotonic_decreasing is False + assert index._is_strictly_monotonic_increasing is False + assert index._is_strictly_monotonic_decreasing is False + + @pytest.mark.parametrize("dtype", ["f8", "m8[ns]", "M8[us]"]) + @pytest.mark.parametrize("unique_first", [True, False]) + def test_is_monotonic_unique_na(self, dtype, unique_first): + # GH 55755 + index = Index([None, 1, 1], dtype=dtype) + if unique_first: + assert index.is_unique is False + assert index.is_monotonic_increasing is False + assert index.is_monotonic_decreasing is False + else: + assert index.is_monotonic_increasing is False + assert index.is_monotonic_decreasing is False + assert index.is_unique is False + + def test_int_name_format(self, frame_or_series): + index = Index(["a", "b", "c"], name=0) + result = frame_or_series(list(range(3)), index=index) + assert "0" in repr(result) + + def test_str_to_bytes_raises(self): + # GH 26447 + index = Index([str(x) for x in range(10)]) + msg = "^'str' object cannot be interpreted as an integer$" + with pytest.raises(TypeError, match=msg): + bytes(index) + + @pytest.mark.filterwarnings("ignore:elementwise comparison failed:FutureWarning") + def test_index_with_tuple_bool(self): + # GH34123 + # TODO: also this op right now produces FutureWarning from numpy + # https://github.com/numpy/numpy/issues/11521 + idx = Index([("a", "b"), ("b", "c"), ("c", "a")]) + result = idx == ("c", "a") + expected = np.array([False, False, True]) + tm.assert_numpy_array_equal(result, expected) + + +class TestIndexUtils: + @pytest.mark.parametrize( + "data, names, expected", + [ + ([[1, 2, 3]], None, Index([1, 2, 3])), + ([[1, 2, 3]], ["name"], Index([1, 2, 3], name="name")), + ( + [["a", "a"], ["c", "d"]], + None, + MultiIndex([["a"], ["c", "d"]], [[0, 0], [0, 1]]), + ), + ( + [["a", "a"], ["c", "d"]], + ["L1", "L2"], + MultiIndex([["a"], ["c", "d"]], [[0, 0], [0, 1]], names=["L1", "L2"]), + ), + ], + ) + def test_ensure_index_from_sequences(self, data, names, expected): + result = ensure_index_from_sequences(data, names) + tm.assert_index_equal(result, expected) + + def test_ensure_index_mixed_closed_intervals(self): + # GH27172 + intervals = [ + pd.Interval(0, 1, closed="left"), + pd.Interval(1, 2, closed="right"), + pd.Interval(2, 3, closed="neither"), + pd.Interval(3, 4, closed="both"), + ] + result = ensure_index(intervals) + expected = Index(intervals, dtype=object) + tm.assert_index_equal(result, expected) + + def test_ensure_index_uint64(self): + # with both 0 and a large-uint64, np.array will infer to float64 + # https://github.com/numpy/numpy/issues/19146 + # but a more accurate choice would be uint64 + values = [0, np.iinfo(np.uint64).max] + + result = ensure_index(values) + assert list(result) == values + + expected = Index(values, dtype="uint64") + tm.assert_index_equal(result, expected) + + def test_get_combined_index(self): + result = _get_combined_index([]) + expected = Index([]) + tm.assert_index_equal(result, expected) + + +@pytest.mark.parametrize( + "opname", + [ + "eq", + "ne", + "le", + "lt", + "ge", + "gt", + "add", + "radd", + "sub", + "rsub", + "mul", + "rmul", + "truediv", + "rtruediv", + "floordiv", + "rfloordiv", + "pow", + "rpow", + "mod", + "divmod", + ], +) +def test_generated_op_names(opname, index): + opname = f"__{opname}__" + method = getattr(index, opname) + assert method.__name__ == opname + + +@pytest.mark.parametrize( + "klass", + [ + partial(CategoricalIndex, data=[1]), + partial(DatetimeIndex, data=["2020-01-01"]), + partial(PeriodIndex, data=["2020-01-01"]), + partial(TimedeltaIndex, data=["1 day"]), + partial(RangeIndex, data=range(1)), + partial(IntervalIndex, data=[pd.Interval(0, 1)]), + partial(Index, data=["a"], dtype=object), + partial(MultiIndex, levels=[1], codes=[0]), + ], +) +def test_index_subclass_constructor_wrong_kwargs(klass): + # GH #19348 + with pytest.raises(TypeError, match="unexpected keyword argument"): + klass(foo="bar") + + +def test_deprecated_fastpath(): + msg = "[Uu]nexpected keyword argument" + with pytest.raises(TypeError, match=msg): + Index(np.array(["a", "b"], dtype=object), name="test", fastpath=True) + + with pytest.raises(TypeError, match=msg): + Index(np.array([1, 2, 3], dtype="int64"), name="test", fastpath=True) + + with pytest.raises(TypeError, match=msg): + RangeIndex(0, 5, 2, name="test", fastpath=True) + + with pytest.raises(TypeError, match=msg): + CategoricalIndex(["a", "b", "c"], name="test", fastpath=True) + + +def test_shape_of_invalid_index(): + # Pre-2.0, it was possible to create "invalid" index objects backed by + # a multi-dimensional array (see https://github.com/pandas-dev/pandas/issues/27125 + # about this). However, as long as this is not solved in general,this test ensures + # that the returned shape is consistent with this underlying array for + # compat with matplotlib (see https://github.com/pandas-dev/pandas/issues/27775) + idx = Index([0, 1, 2, 3]) + with pytest.raises(ValueError, match="Multi-dimensional indexing"): + # GH#30588 multi-dimensional indexing deprecated + idx[:, None] + + +@pytest.mark.parametrize("dtype", [None, np.int64, np.uint64, np.float64]) +def test_validate_1d_input(dtype): + # GH#27125 check that we do not have >1-dimensional input + msg = "Index data must be 1-dimensional" + + arr = np.arange(8).reshape(2, 2, 2) + with pytest.raises(ValueError, match=msg): + Index(arr, dtype=dtype) + + df = DataFrame(arr.reshape(4, 2)) + with pytest.raises(ValueError, match=msg): + Index(df, dtype=dtype) + + # GH#13601 trying to assign a multi-dimensional array to an index is not allowed + ser = Series(0, range(4)) + with pytest.raises(ValueError, match=msg): + ser.index = np.array([[2, 3]] * 4, dtype=dtype) + + +@pytest.mark.parametrize( + "klass, extra_kwargs", + [ + [Index, {}], + *[[lambda x: Index(x, dtype=dtyp), {}] for dtyp in tm.ALL_REAL_NUMPY_DTYPES], + [DatetimeIndex, {}], + [TimedeltaIndex, {}], + [PeriodIndex, {"freq": "Y"}], + ], +) +def test_construct_from_memoryview(klass, extra_kwargs): + # GH 13120 + result = klass(memoryview(np.arange(2000, 2005)), **extra_kwargs) + expected = klass(list(range(2000, 2005)), **extra_kwargs) + tm.assert_index_equal(result, expected, exact=True) + + +@pytest.mark.parametrize("op", [operator.lt, operator.gt]) +def test_nan_comparison_same_object(op): + # GH#47105 + idx = Index([np.nan]) + expected = np.array([False]) + + result = op(idx, idx) + tm.assert_numpy_array_equal(result, expected) + + result = op(idx, idx.copy()) + tm.assert_numpy_array_equal(result, expected) + + +@td.skip_if_no("pyarrow") +def test_is_monotonic_pyarrow_list_type(): + # GH 57333 + import pyarrow as pa + + idx = Index([[1], [2, 3]], dtype=pd.ArrowDtype(pa.list_(pa.int64()))) + assert not idx.is_monotonic_increasing + assert not idx.is_monotonic_decreasing diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/test_common.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/test_common.py new file mode 100644 index 0000000000000000000000000000000000000000..05b2aa584674c58e38c649deb8e821ce2630672b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/test_common.py @@ -0,0 +1,512 @@ +""" +Collection of tests asserting things that should be true for +any index subclass except for MultiIndex. Makes use of the `index_flat` +fixture defined in pandas/conftest.py. +""" +from copy import ( + copy, + deepcopy, +) +import re + +import numpy as np +import pytest + +from pandas.compat import IS64 +from pandas.compat.numpy import np_version_gte1p25 + +from pandas.core.dtypes.common import ( + is_integer_dtype, + is_numeric_dtype, +) + +import pandas as pd +from pandas import ( + CategoricalIndex, + MultiIndex, + PeriodIndex, + RangeIndex, +) +import pandas._testing as tm + + +class TestCommon: + @pytest.mark.parametrize("name", [None, "new_name"]) + def test_to_frame(self, name, index_flat, using_copy_on_write): + # see GH#15230, GH#22580 + idx = index_flat + + if name: + idx_name = name + else: + idx_name = idx.name or 0 + + df = idx.to_frame(name=idx_name) + + assert df.index is idx + assert len(df.columns) == 1 + assert df.columns[0] == idx_name + if not using_copy_on_write: + assert df[idx_name].values is not idx.values + + df = idx.to_frame(index=False, name=idx_name) + assert df.index is not idx + + def test_droplevel(self, index_flat): + # GH 21115 + # MultiIndex is tested separately in test_multi.py + index = index_flat + + assert index.droplevel([]).equals(index) + + for level in [index.name, [index.name]]: + if isinstance(index.name, tuple) and level is index.name: + # GH 21121 : droplevel with tuple name + continue + msg = ( + "Cannot remove 1 levels from an index with 1 levels: at least one " + "level must be left." + ) + with pytest.raises(ValueError, match=msg): + index.droplevel(level) + + for level in "wrong", ["wrong"]: + with pytest.raises( + KeyError, + match=r"'Requested level \(wrong\) does not match index name \(None\)'", + ): + index.droplevel(level) + + def test_constructor_non_hashable_name(self, index_flat): + # GH 20527 + index = index_flat + + message = "Index.name must be a hashable type" + renamed = [["1"]] + + # With .rename() + with pytest.raises(TypeError, match=message): + index.rename(name=renamed) + + # With .set_names() + with pytest.raises(TypeError, match=message): + index.set_names(names=renamed) + + def test_constructor_unwraps_index(self, index_flat): + a = index_flat + # Passing dtype is necessary for Index([True, False], dtype=object) + # case. + b = type(a)(a, dtype=a.dtype) + tm.assert_equal(a._data, b._data) + + def test_to_flat_index(self, index_flat): + # 22866 + index = index_flat + + result = index.to_flat_index() + tm.assert_index_equal(result, index) + + def test_set_name_methods(self, index_flat): + # MultiIndex tested separately + index = index_flat + new_name = "This is the new name for this index" + + original_name = index.name + new_ind = index.set_names([new_name]) + assert new_ind.name == new_name + assert index.name == original_name + res = index.rename(new_name, inplace=True) + + # should return None + assert res is None + assert index.name == new_name + assert index.names == [new_name] + with pytest.raises(ValueError, match="Level must be None"): + index.set_names("a", level=0) + + # rename in place just leaves tuples and other containers alone + name = ("A", "B") + index.rename(name, inplace=True) + assert index.name == name + assert index.names == [name] + + @pytest.mark.xfail + def test_set_names_single_label_no_level(self, index_flat): + with pytest.raises(TypeError, match="list-like"): + # should still fail even if it would be the right length + index_flat.set_names("a") + + def test_copy_and_deepcopy(self, index_flat): + index = index_flat + + for func in (copy, deepcopy): + idx_copy = func(index) + assert idx_copy is not index + assert idx_copy.equals(index) + + new_copy = index.copy(deep=True, name="banana") + assert new_copy.name == "banana" + + def test_copy_name(self, index_flat): + # GH#12309: Check that the "name" argument + # passed at initialization is honored. + index = index_flat + + first = type(index)(index, copy=True, name="mario") + second = type(first)(first, copy=False) + + # Even though "copy=False", we want a new object. + assert first is not second + tm.assert_index_equal(first, second) + + # Not using tm.assert_index_equal() since names differ. + assert index.equals(first) + + assert first.name == "mario" + assert second.name == "mario" + + # TODO: belongs in series arithmetic tests? + s1 = pd.Series(2, index=first) + s2 = pd.Series(3, index=second[:-1]) + # See GH#13365 + s3 = s1 * s2 + assert s3.index.name == "mario" + + def test_copy_name2(self, index_flat): + # GH#35592 + index = index_flat + + assert index.copy(name="mario").name == "mario" + + with pytest.raises(ValueError, match="Length of new names must be 1, got 2"): + index.copy(name=["mario", "luigi"]) + + msg = f"{type(index).__name__}.name must be a hashable type" + with pytest.raises(TypeError, match=msg): + index.copy(name=[["mario"]]) + + def test_unique_level(self, index_flat): + # don't test a MultiIndex here (as its tested separated) + index = index_flat + + # GH 17896 + expected = index.drop_duplicates() + for level in [0, index.name, None]: + result = index.unique(level=level) + tm.assert_index_equal(result, expected) + + msg = "Too many levels: Index has only 1 level, not 4" + with pytest.raises(IndexError, match=msg): + index.unique(level=3) + + msg = ( + rf"Requested level \(wrong\) does not match index name " + rf"\({re.escape(index.name.__repr__())}\)" + ) + with pytest.raises(KeyError, match=msg): + index.unique(level="wrong") + + def test_unique(self, index_flat): + # MultiIndex tested separately + index = index_flat + if not len(index): + pytest.skip("Skip check for empty Index and MultiIndex") + + idx = index[[0] * 5] + idx_unique = index[[0]] + + # We test against `idx_unique`, so first we make sure it's unique + # and doesn't contain nans. + assert idx_unique.is_unique is True + try: + assert idx_unique.hasnans is False + except NotImplementedError: + pass + + result = idx.unique() + tm.assert_index_equal(result, idx_unique) + + # nans: + if not index._can_hold_na: + pytest.skip("Skip na-check if index cannot hold na") + + vals = index._values[[0] * 5] + vals[0] = np.nan + + vals_unique = vals[:2] + idx_nan = index._shallow_copy(vals) + idx_unique_nan = index._shallow_copy(vals_unique) + assert idx_unique_nan.is_unique is True + + assert idx_nan.dtype == index.dtype + assert idx_unique_nan.dtype == index.dtype + + expected = idx_unique_nan + for pos, i in enumerate([idx_nan, idx_unique_nan]): + result = i.unique() + tm.assert_index_equal(result, expected) + + @pytest.mark.filterwarnings("ignore:Period with BDay freq:FutureWarning") + @pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning") + def test_searchsorted_monotonic(self, index_flat, request): + # GH17271 + index = index_flat + # not implemented for tuple searches in MultiIndex + # or Intervals searches in IntervalIndex + if isinstance(index, pd.IntervalIndex): + mark = pytest.mark.xfail( + reason="IntervalIndex.searchsorted does not support Interval arg", + raises=NotImplementedError, + ) + request.applymarker(mark) + + # nothing to test if the index is empty + if index.empty: + pytest.skip("Skip check for empty Index") + value = index[0] + + # determine the expected results (handle dupes for 'right') + expected_left, expected_right = 0, (index == value).argmin() + if expected_right == 0: + # all values are the same, expected_right should be length + expected_right = len(index) + + # test _searchsorted_monotonic in all cases + # test searchsorted only for increasing + if index.is_monotonic_increasing: + ssm_left = index._searchsorted_monotonic(value, side="left") + assert expected_left == ssm_left + + ssm_right = index._searchsorted_monotonic(value, side="right") + assert expected_right == ssm_right + + ss_left = index.searchsorted(value, side="left") + assert expected_left == ss_left + + ss_right = index.searchsorted(value, side="right") + assert expected_right == ss_right + + elif index.is_monotonic_decreasing: + ssm_left = index._searchsorted_monotonic(value, side="left") + assert expected_left == ssm_left + + ssm_right = index._searchsorted_monotonic(value, side="right") + assert expected_right == ssm_right + else: + # non-monotonic should raise. + msg = "index must be monotonic increasing or decreasing" + with pytest.raises(ValueError, match=msg): + index._searchsorted_monotonic(value, side="left") + + @pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning") + def test_drop_duplicates(self, index_flat, keep): + # MultiIndex is tested separately + index = index_flat + if isinstance(index, RangeIndex): + pytest.skip( + "RangeIndex is tested in test_drop_duplicates_no_duplicates " + "as it cannot hold duplicates" + ) + if len(index) == 0: + pytest.skip( + "empty index is tested in test_drop_duplicates_no_duplicates " + "as it cannot hold duplicates" + ) + + # make unique index + holder = type(index) + unique_values = list(set(index)) + dtype = index.dtype if is_numeric_dtype(index) else None + unique_idx = holder(unique_values, dtype=dtype) + + # make duplicated index + n = len(unique_idx) + duplicated_selection = np.random.default_rng(2).choice(n, int(n * 1.5)) + idx = holder(unique_idx.values[duplicated_selection]) + + # Series.duplicated is tested separately + expected_duplicated = ( + pd.Series(duplicated_selection).duplicated(keep=keep).values + ) + tm.assert_numpy_array_equal(idx.duplicated(keep=keep), expected_duplicated) + + # Series.drop_duplicates is tested separately + expected_dropped = holder(pd.Series(idx).drop_duplicates(keep=keep)) + tm.assert_index_equal(idx.drop_duplicates(keep=keep), expected_dropped) + + @pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning") + def test_drop_duplicates_no_duplicates(self, index_flat): + # MultiIndex is tested separately + index = index_flat + + # make unique index + if isinstance(index, RangeIndex): + # RangeIndex cannot have duplicates + unique_idx = index + else: + holder = type(index) + unique_values = list(set(index)) + dtype = index.dtype if is_numeric_dtype(index) else None + unique_idx = holder(unique_values, dtype=dtype) + + # check on unique index + expected_duplicated = np.array([False] * len(unique_idx), dtype="bool") + tm.assert_numpy_array_equal(unique_idx.duplicated(), expected_duplicated) + result_dropped = unique_idx.drop_duplicates() + tm.assert_index_equal(result_dropped, unique_idx) + # validate shallow copy + assert result_dropped is not unique_idx + + def test_drop_duplicates_inplace(self, index): + msg = r"drop_duplicates\(\) got an unexpected keyword argument" + with pytest.raises(TypeError, match=msg): + index.drop_duplicates(inplace=True) + + @pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning") + def test_has_duplicates(self, index_flat): + # MultiIndex tested separately in: + # tests/indexes/multi/test_unique_and_duplicates. + index = index_flat + holder = type(index) + if not len(index) or isinstance(index, RangeIndex): + # MultiIndex tested separately in: + # tests/indexes/multi/test_unique_and_duplicates. + # RangeIndex is unique by definition. + pytest.skip("Skip check for empty Index, MultiIndex, and RangeIndex") + + idx = holder([index[0]] * 5) + assert idx.is_unique is False + assert idx.has_duplicates is True + + @pytest.mark.parametrize( + "dtype", + ["int64", "uint64", "float64", "category", "datetime64[ns]", "timedelta64[ns]"], + ) + def test_astype_preserves_name(self, index, dtype): + # https://github.com/pandas-dev/pandas/issues/32013 + if isinstance(index, MultiIndex): + index.names = ["idx" + str(i) for i in range(index.nlevels)] + else: + index.name = "idx" + + warn = None + if index.dtype.kind == "c" and dtype in ["float64", "int64", "uint64"]: + # imaginary components discarded + if np_version_gte1p25: + warn = np.exceptions.ComplexWarning + else: + warn = np.ComplexWarning + + is_pyarrow_str = str(index.dtype) == "string[pyarrow]" and dtype == "category" + try: + # Some of these conversions cannot succeed so we use a try / except + with tm.assert_produces_warning( + warn, + raise_on_extra_warnings=is_pyarrow_str, + check_stacklevel=False, + ): + result = index.astype(dtype) + except (ValueError, TypeError, NotImplementedError, SystemError): + return + + if isinstance(index, MultiIndex): + assert result.names == index.names + else: + assert result.name == index.name + + def test_hasnans_isnans(self, index_flat): + # GH#11343, added tests for hasnans / isnans + index = index_flat + + # cases in indices doesn't include NaN + idx = index.copy(deep=True) + expected = np.array([False] * len(idx), dtype=bool) + tm.assert_numpy_array_equal(idx._isnan, expected) + assert idx.hasnans is False + + idx = index.copy(deep=True) + values = idx._values + + if len(index) == 0: + return + elif is_integer_dtype(index.dtype): + return + elif index.dtype == bool: + # values[1] = np.nan below casts to True! + return + + values[1] = np.nan + + idx = type(index)(values) + + expected = np.array([False] * len(idx), dtype=bool) + expected[1] = True + tm.assert_numpy_array_equal(idx._isnan, expected) + assert idx.hasnans is True + + +@pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning") +@pytest.mark.parametrize("na_position", [None, "middle"]) +def test_sort_values_invalid_na_position(index_with_missing, na_position): + with pytest.raises(ValueError, match=f"invalid na_position: {na_position}"): + index_with_missing.sort_values(na_position=na_position) + + +@pytest.mark.fails_arm_wheels +@pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning") +@pytest.mark.parametrize("na_position", ["first", "last"]) +def test_sort_values_with_missing(index_with_missing, na_position, request): + # GH 35584. Test that sort_values works with missing values, + # sort non-missing and place missing according to na_position + + if isinstance(index_with_missing, CategoricalIndex): + request.applymarker( + pytest.mark.xfail( + reason="missing value sorting order not well-defined", strict=False + ) + ) + + missing_count = np.sum(index_with_missing.isna()) + not_na_vals = index_with_missing[index_with_missing.notna()].values + sorted_values = np.sort(not_na_vals) + if na_position == "first": + sorted_values = np.concatenate([[None] * missing_count, sorted_values]) + else: + sorted_values = np.concatenate([sorted_values, [None] * missing_count]) + + # Explicitly pass dtype needed for Index backed by EA e.g. IntegerArray + expected = type(index_with_missing)(sorted_values, dtype=index_with_missing.dtype) + + result = index_with_missing.sort_values(na_position=na_position) + tm.assert_index_equal(result, expected) + + +def test_ndarray_compat_properties(index): + if isinstance(index, PeriodIndex) and not IS64: + pytest.skip("Overflow") + idx = index + assert idx.T.equals(idx) + assert idx.transpose().equals(idx) + + values = idx.values + + assert idx.shape == values.shape + assert idx.ndim == values.ndim + assert idx.size == values.size + + if not isinstance(index, (RangeIndex, MultiIndex)): + # These two are not backed by an ndarray + assert idx.nbytes == values.nbytes + + # test for validity + idx.nbytes + idx.values.nbytes + + +def test_compare_read_only_array(): + # GH#57130 + arr = np.array([], dtype=object) + arr.flags.writeable = False + idx = pd.Index(arr) + result = idx > 69 + assert result.dtype == bool diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/test_datetimelike.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/test_datetimelike.py new file mode 100644 index 0000000000000000000000000000000000000000..21a686e8bc05b09729c6fe54e67c96405ee36bca --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/test_datetimelike.py @@ -0,0 +1,171 @@ +""" generic datetimelike tests """ + +import numpy as np +import pytest + +import pandas as pd +import pandas._testing as tm + + +class TestDatetimeLike: + @pytest.fixture( + params=[ + pd.period_range("20130101", periods=5, freq="D"), + pd.TimedeltaIndex( + [ + "0 days 01:00:00", + "1 days 01:00:00", + "2 days 01:00:00", + "3 days 01:00:00", + "4 days 01:00:00", + ], + dtype="timedelta64[ns]", + freq="D", + ), + pd.DatetimeIndex( + ["2013-01-01", "2013-01-02", "2013-01-03", "2013-01-04", "2013-01-05"], + dtype="datetime64[ns]", + freq="D", + ), + ] + ) + def simple_index(self, request): + return request.param + + def test_isin(self, simple_index): + index = simple_index[:4] + result = index.isin(index) + assert result.all() + + result = index.isin(list(index)) + assert result.all() + + result = index.isin([index[2], 5]) + expected = np.array([False, False, True, False]) + tm.assert_numpy_array_equal(result, expected) + + def test_argsort_matches_array(self, simple_index): + idx = simple_index + idx = idx.insert(1, pd.NaT) + + result = idx.argsort() + expected = idx._data.argsort() + tm.assert_numpy_array_equal(result, expected) + + def test_can_hold_identifiers(self, simple_index): + idx = simple_index + key = idx[0] + assert idx._can_hold_identifiers_and_holds_name(key) is False + + def test_shift_identity(self, simple_index): + idx = simple_index + tm.assert_index_equal(idx, idx.shift(0)) + + def test_shift_empty(self, simple_index): + # GH#14811 + idx = simple_index[:0] + tm.assert_index_equal(idx, idx.shift(1)) + + def test_str(self, simple_index): + # test the string repr + idx = simple_index.copy() + idx.name = "foo" + assert f"length={len(idx)}" not in str(idx) + assert "'foo'" in str(idx) + assert type(idx).__name__ in str(idx) + + if hasattr(idx, "tz"): + if idx.tz is not None: + assert idx.tz in str(idx) + if isinstance(idx, pd.PeriodIndex): + assert f"dtype='period[{idx.freqstr}]'" in str(idx) + else: + assert f"freq='{idx.freqstr}'" in str(idx) + + def test_view(self, simple_index): + idx = simple_index + + idx_view = idx.view("i8") + result = type(simple_index)(idx) + tm.assert_index_equal(result, idx) + + msg = "Passing a type in .*Index.view is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + idx_view = idx.view(type(simple_index)) + result = type(simple_index)(idx) + tm.assert_index_equal(result, idx_view) + + def test_map_callable(self, simple_index): + index = simple_index + expected = index + index.freq + result = index.map(lambda x: x + index.freq) + tm.assert_index_equal(result, expected) + + # map to NaT + result = index.map(lambda x: pd.NaT if x == index[0] else x) + expected = pd.Index([pd.NaT] + index[1:].tolist()) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "mapper", + [ + lambda values, index: {i: e for e, i in zip(values, index)}, + lambda values, index: pd.Series(values, index, dtype=object), + ], + ) + @pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning") + def test_map_dictlike(self, mapper, simple_index): + index = simple_index + expected = index + index.freq + + # don't compare the freqs + if isinstance(expected, (pd.DatetimeIndex, pd.TimedeltaIndex)): + expected = expected._with_freq(None) + + result = index.map(mapper(expected, index)) + tm.assert_index_equal(result, expected) + + expected = pd.Index([pd.NaT] + index[1:].tolist()) + result = index.map(mapper(expected, index)) + tm.assert_index_equal(result, expected) + + # empty map; these map to np.nan because we cannot know + # to re-infer things + expected = pd.Index([np.nan] * len(index)) + result = index.map(mapper([], [])) + tm.assert_index_equal(result, expected) + + def test_getitem_preserves_freq(self, simple_index): + index = simple_index + assert index.freq is not None + + result = index[:] + assert result.freq == index.freq + + def test_where_cast_str(self, simple_index): + index = simple_index + + mask = np.ones(len(index), dtype=bool) + mask[-1] = False + + result = index.where(mask, str(index[0])) + expected = index.where(mask, index[0]) + tm.assert_index_equal(result, expected) + + result = index.where(mask, [str(index[0])]) + tm.assert_index_equal(result, expected) + + expected = index.astype(object).where(mask, "foo") + result = index.where(mask, "foo") + tm.assert_index_equal(result, expected) + + result = index.where(mask, ["foo"]) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize("unit", ["ns", "us", "ms", "s"]) + def test_diff(self, unit): + # GH 55080 + dti = pd.to_datetime([10, 20, 30], unit=unit).as_unit(unit) + result = dti.diff(1) + expected = pd.to_timedelta([pd.NaT, 10, 10], unit=unit).as_unit(unit) + tm.assert_index_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/test_engines.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/test_engines.py new file mode 100644 index 0000000000000000000000000000000000000000..468c2240c8192098a6ff75a5a2d0210c8108a176 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/test_engines.py @@ -0,0 +1,192 @@ +import re + +import numpy as np +import pytest + +from pandas._libs import index as libindex + +import pandas as pd + + +@pytest.fixture( + params=[ + (libindex.Int64Engine, np.int64), + (libindex.Int32Engine, np.int32), + (libindex.Int16Engine, np.int16), + (libindex.Int8Engine, np.int8), + (libindex.UInt64Engine, np.uint64), + (libindex.UInt32Engine, np.uint32), + (libindex.UInt16Engine, np.uint16), + (libindex.UInt8Engine, np.uint8), + (libindex.Float64Engine, np.float64), + (libindex.Float32Engine, np.float32), + ], + ids=lambda x: x[0].__name__, +) +def numeric_indexing_engine_type_and_dtype(request): + return request.param + + +class TestDatetimeEngine: + @pytest.mark.parametrize( + "scalar", + [ + pd.Timedelta(pd.Timestamp("2016-01-01").asm8.view("m8[ns]")), + pd.Timestamp("2016-01-01")._value, + pd.Timestamp("2016-01-01").to_pydatetime(), + pd.Timestamp("2016-01-01").to_datetime64(), + ], + ) + def test_not_contains_requires_timestamp(self, scalar): + dti1 = pd.date_range("2016-01-01", periods=3) + dti2 = dti1.insert(1, pd.NaT) # non-monotonic + dti3 = dti1.insert(3, dti1[0]) # non-unique + dti4 = pd.date_range("2016-01-01", freq="ns", periods=2_000_000) + dti5 = dti4.insert(0, dti4[0]) # over size threshold, not unique + + msg = "|".join([re.escape(str(scalar)), re.escape(repr(scalar))]) + for dti in [dti1, dti2, dti3, dti4, dti5]: + with pytest.raises(TypeError, match=msg): + scalar in dti._engine + + with pytest.raises(KeyError, match=msg): + dti._engine.get_loc(scalar) + + +class TestTimedeltaEngine: + @pytest.mark.parametrize( + "scalar", + [ + pd.Timestamp(pd.Timedelta(days=42).asm8.view("datetime64[ns]")), + pd.Timedelta(days=42)._value, + pd.Timedelta(days=42).to_pytimedelta(), + pd.Timedelta(days=42).to_timedelta64(), + ], + ) + def test_not_contains_requires_timedelta(self, scalar): + tdi1 = pd.timedelta_range("42 days", freq="9h", periods=1234) + tdi2 = tdi1.insert(1, pd.NaT) # non-monotonic + tdi3 = tdi1.insert(3, tdi1[0]) # non-unique + tdi4 = pd.timedelta_range("42 days", freq="ns", periods=2_000_000) + tdi5 = tdi4.insert(0, tdi4[0]) # over size threshold, not unique + + msg = "|".join([re.escape(str(scalar)), re.escape(repr(scalar))]) + for tdi in [tdi1, tdi2, tdi3, tdi4, tdi5]: + with pytest.raises(TypeError, match=msg): + scalar in tdi._engine + + with pytest.raises(KeyError, match=msg): + tdi._engine.get_loc(scalar) + + +class TestNumericEngine: + def test_is_monotonic(self, numeric_indexing_engine_type_and_dtype): + engine_type, dtype = numeric_indexing_engine_type_and_dtype + num = 1000 + arr = np.array([1] * num + [2] * num + [3] * num, dtype=dtype) + + # monotonic increasing + engine = engine_type(arr) + assert engine.is_monotonic_increasing is True + assert engine.is_monotonic_decreasing is False + + # monotonic decreasing + engine = engine_type(arr[::-1]) + assert engine.is_monotonic_increasing is False + assert engine.is_monotonic_decreasing is True + + # neither monotonic increasing or decreasing + arr = np.array([1] * num + [2] * num + [1] * num, dtype=dtype) + engine = engine_type(arr[::-1]) + assert engine.is_monotonic_increasing is False + assert engine.is_monotonic_decreasing is False + + def test_is_unique(self, numeric_indexing_engine_type_and_dtype): + engine_type, dtype = numeric_indexing_engine_type_and_dtype + + # unique + arr = np.array([1, 3, 2], dtype=dtype) + engine = engine_type(arr) + assert engine.is_unique is True + + # not unique + arr = np.array([1, 2, 1], dtype=dtype) + engine = engine_type(arr) + assert engine.is_unique is False + + def test_get_loc(self, numeric_indexing_engine_type_and_dtype): + engine_type, dtype = numeric_indexing_engine_type_and_dtype + + # unique + arr = np.array([1, 2, 3], dtype=dtype) + engine = engine_type(arr) + assert engine.get_loc(2) == 1 + + # monotonic + num = 1000 + arr = np.array([1] * num + [2] * num + [3] * num, dtype=dtype) + engine = engine_type(arr) + assert engine.get_loc(2) == slice(1000, 2000) + + # not monotonic + arr = np.array([1, 2, 3] * num, dtype=dtype) + engine = engine_type(arr) + expected = np.array([False, True, False] * num, dtype=bool) + result = engine.get_loc(2) + assert (result == expected).all() + + +class TestObjectEngine: + engine_type = libindex.ObjectEngine + dtype = np.object_ + values = list("abc") + + def test_is_monotonic(self): + num = 1000 + arr = np.array(["a"] * num + ["a"] * num + ["c"] * num, dtype=self.dtype) + + # monotonic increasing + engine = self.engine_type(arr) + assert engine.is_monotonic_increasing is True + assert engine.is_monotonic_decreasing is False + + # monotonic decreasing + engine = self.engine_type(arr[::-1]) + assert engine.is_monotonic_increasing is False + assert engine.is_monotonic_decreasing is True + + # neither monotonic increasing or decreasing + arr = np.array(["a"] * num + ["b"] * num + ["a"] * num, dtype=self.dtype) + engine = self.engine_type(arr[::-1]) + assert engine.is_monotonic_increasing is False + assert engine.is_monotonic_decreasing is False + + def test_is_unique(self): + # unique + arr = np.array(self.values, dtype=self.dtype) + engine = self.engine_type(arr) + assert engine.is_unique is True + + # not unique + arr = np.array(["a", "b", "a"], dtype=self.dtype) + engine = self.engine_type(arr) + assert engine.is_unique is False + + def test_get_loc(self): + # unique + arr = np.array(self.values, dtype=self.dtype) + engine = self.engine_type(arr) + assert engine.get_loc("b") == 1 + + # monotonic + num = 1000 + arr = np.array(["a"] * num + ["b"] * num + ["c"] * num, dtype=self.dtype) + engine = self.engine_type(arr) + assert engine.get_loc("b") == slice(1000, 2000) + + # not monotonic + arr = np.array(self.values * num, dtype=self.dtype) + engine = self.engine_type(arr) + expected = np.array([False, True, False] * num, dtype=bool) + result = engine.get_loc("b") + assert (result == expected).all() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/test_frozen.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/test_frozen.py new file mode 100644 index 0000000000000000000000000000000000000000..ace66b5b06a51291d2cf229fdc446d070054836a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/test_frozen.py @@ -0,0 +1,113 @@ +import re + +import pytest + +from pandas.core.indexes.frozen import FrozenList + + +@pytest.fixture +def lst(): + return [1, 2, 3, 4, 5] + + +@pytest.fixture +def container(lst): + return FrozenList(lst) + + +@pytest.fixture +def unicode_container(): + return FrozenList(["\u05d0", "\u05d1", "c"]) + + +class TestFrozenList: + def check_mutable_error(self, *args, **kwargs): + # Pass whatever function you normally would to pytest.raises + # (after the Exception kind). + mutable_regex = re.compile("does not support mutable operations") + msg = "'(_s)?re.(SRE_)?Pattern' object is not callable" + with pytest.raises(TypeError, match=msg): + mutable_regex(*args, **kwargs) + + def test_no_mutable_funcs(self, container): + def setitem(): + container[0] = 5 + + self.check_mutable_error(setitem) + + def setslice(): + container[1:2] = 3 + + self.check_mutable_error(setslice) + + def delitem(): + del container[0] + + self.check_mutable_error(delitem) + + def delslice(): + del container[0:3] + + self.check_mutable_error(delslice) + + mutable_methods = ("extend", "pop", "remove", "insert") + + for meth in mutable_methods: + self.check_mutable_error(getattr(container, meth)) + + def test_slicing_maintains_type(self, container, lst): + result = container[1:2] + expected = lst[1:2] + self.check_result(result, expected) + + def check_result(self, result, expected): + assert isinstance(result, FrozenList) + assert result == expected + + def test_string_methods_dont_fail(self, container): + repr(container) + str(container) + bytes(container) + + def test_tricky_container(self, unicode_container): + repr(unicode_container) + str(unicode_container) + + def test_add(self, container, lst): + result = container + (1, 2, 3) + expected = FrozenList(lst + [1, 2, 3]) + self.check_result(result, expected) + + result = (1, 2, 3) + container + expected = FrozenList([1, 2, 3] + lst) + self.check_result(result, expected) + + def test_iadd(self, container, lst): + q = r = container + + q += [5] + self.check_result(q, lst + [5]) + + # Other shouldn't be mutated. + self.check_result(r, lst) + + def test_union(self, container, lst): + result = container.union((1, 2, 3)) + expected = FrozenList(lst + [1, 2, 3]) + self.check_result(result, expected) + + def test_difference(self, container): + result = container.difference([2]) + expected = FrozenList([1, 3, 4, 5]) + self.check_result(result, expected) + + def test_difference_dupe(self): + result = FrozenList([1, 2, 3, 2]).difference([2]) + expected = FrozenList([1, 3]) + self.check_result(result, expected) + + def test_tricky_container_to_bytes_raises(self, unicode_container): + # GH 26447 + msg = "^'str' object cannot be interpreted as an integer$" + with pytest.raises(TypeError, match=msg): + bytes(unicode_container) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/test_index_new.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/test_index_new.py new file mode 100644 index 0000000000000000000000000000000000000000..6042e5b9cc6793018ccf26f37aec236dfa353393 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/test_index_new.py @@ -0,0 +1,432 @@ +""" +Tests for the Index constructor conducting inference. +""" +from datetime import ( + datetime, + timedelta, + timezone, +) +from decimal import Decimal + +import numpy as np +import pytest + +from pandas._libs.tslibs.timezones import maybe_get_tz + +from pandas import ( + NA, + Categorical, + CategoricalIndex, + DatetimeIndex, + Index, + IntervalIndex, + MultiIndex, + NaT, + PeriodIndex, + Series, + TimedeltaIndex, + Timestamp, + array, + date_range, + period_range, + timedelta_range, +) +import pandas._testing as tm + + +class TestIndexConstructorInference: + def test_object_all_bools(self): + # GH#49594 match Series behavior on ndarray[object] of all bools + arr = np.array([True, False], dtype=object) + res = Index(arr) + assert res.dtype == object + + # since the point is matching Series behavior, let's double check + assert Series(arr).dtype == object + + def test_object_all_complex(self): + # GH#49594 match Series behavior on ndarray[object] of all complex + arr = np.array([complex(1), complex(2)], dtype=object) + res = Index(arr) + assert res.dtype == object + + # since the point is matching Series behavior, let's double check + assert Series(arr).dtype == object + + @pytest.mark.parametrize("val", [NaT, None, np.nan, float("nan")]) + def test_infer_nat(self, val): + # GH#49340 all NaT/None/nan and at least 1 NaT -> datetime64[ns], + # matching Series behavior + values = [NaT, val] + + idx = Index(values) + assert idx.dtype == "datetime64[ns]" and idx.isna().all() + + idx = Index(values[::-1]) + assert idx.dtype == "datetime64[ns]" and idx.isna().all() + + idx = Index(np.array(values, dtype=object)) + assert idx.dtype == "datetime64[ns]" and idx.isna().all() + + idx = Index(np.array(values, dtype=object)[::-1]) + assert idx.dtype == "datetime64[ns]" and idx.isna().all() + + @pytest.mark.parametrize("na_value", [None, np.nan]) + @pytest.mark.parametrize("vtype", [list, tuple, iter]) + def test_construction_list_tuples_nan(self, na_value, vtype): + # GH#18505 : valid tuples containing NaN + values = [(1, "two"), (3.0, na_value)] + result = Index(vtype(values)) + expected = MultiIndex.from_tuples(values) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "dtype", + [int, "int64", "int32", "int16", "int8", "uint64", "uint32", "uint16", "uint8"], + ) + def test_constructor_int_dtype_float(self, dtype): + # GH#18400 + expected = Index([0, 1, 2, 3], dtype=dtype) + result = Index([0.0, 1.0, 2.0, 3.0], dtype=dtype) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize("cast_index", [True, False]) + @pytest.mark.parametrize( + "vals", [[True, False, True], np.array([True, False, True], dtype=bool)] + ) + def test_constructor_dtypes_to_object(self, cast_index, vals): + if cast_index: + index = Index(vals, dtype=bool) + else: + index = Index(vals) + + assert type(index) is Index + assert index.dtype == bool + + def test_constructor_categorical_to_object(self): + # GH#32167 Categorical data and dtype=object should return object-dtype + ci = CategoricalIndex(range(5)) + result = Index(ci, dtype=object) + assert not isinstance(result, CategoricalIndex) + + def test_constructor_infer_periodindex(self): + xp = period_range("2012-1-1", freq="M", periods=3) + rs = Index(xp) + tm.assert_index_equal(rs, xp) + assert isinstance(rs, PeriodIndex) + + def test_from_list_of_periods(self): + rng = period_range("1/1/2000", periods=20, freq="D") + periods = list(rng) + + result = Index(periods) + assert isinstance(result, PeriodIndex) + + @pytest.mark.parametrize("pos", [0, 1]) + @pytest.mark.parametrize( + "klass,dtype,ctor", + [ + (DatetimeIndex, "datetime64[ns]", np.datetime64("nat")), + (TimedeltaIndex, "timedelta64[ns]", np.timedelta64("nat")), + ], + ) + def test_constructor_infer_nat_dt_like( + self, pos, klass, dtype, ctor, nulls_fixture, request + ): + if isinstance(nulls_fixture, Decimal): + # We dont cast these to datetime64/timedelta64 + pytest.skip( + f"We don't cast {type(nulls_fixture).__name__} to " + "datetime64/timedelta64" + ) + + expected = klass([NaT, NaT]) + assert expected.dtype == dtype + data = [ctor] + data.insert(pos, nulls_fixture) + + warn = None + if nulls_fixture is NA: + expected = Index([NA, NaT]) + mark = pytest.mark.xfail(reason="Broken with np.NaT ctor; see GH 31884") + request.applymarker(mark) + # GH#35942 numpy will emit a DeprecationWarning within the + # assert_index_equal calls. Since we can't do anything + # about it until GH#31884 is fixed, we suppress that warning. + warn = DeprecationWarning + + result = Index(data) + + with tm.assert_produces_warning(warn): + tm.assert_index_equal(result, expected) + + result = Index(np.array(data, dtype=object)) + + with tm.assert_produces_warning(warn): + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize("swap_objs", [True, False]) + def test_constructor_mixed_nat_objs_infers_object(self, swap_objs): + # mixed np.datetime64/timedelta64 nat results in object + data = [np.datetime64("nat"), np.timedelta64("nat")] + if swap_objs: + data = data[::-1] + + expected = Index(data, dtype=object) + tm.assert_index_equal(Index(data), expected) + tm.assert_index_equal(Index(np.array(data, dtype=object)), expected) + + @pytest.mark.parametrize("swap_objs", [True, False]) + def test_constructor_datetime_and_datetime64(self, swap_objs): + data = [Timestamp(2021, 6, 8, 9, 42), np.datetime64("now")] + if swap_objs: + data = data[::-1] + expected = DatetimeIndex(data) + + tm.assert_index_equal(Index(data), expected) + tm.assert_index_equal(Index(np.array(data, dtype=object)), expected) + + def test_constructor_datetimes_mixed_tzs(self): + # https://github.com/pandas-dev/pandas/pull/55793/files#r1383719998 + tz = maybe_get_tz("US/Central") + dt1 = datetime(2020, 1, 1, tzinfo=tz) + dt2 = datetime(2020, 1, 1, tzinfo=timezone.utc) + result = Index([dt1, dt2]) + expected = Index([dt1, dt2], dtype=object) + tm.assert_index_equal(result, expected) + + +class TestDtypeEnforced: + # check we don't silently ignore the dtype keyword + + def test_constructor_object_dtype_with_ea_data(self, any_numeric_ea_dtype): + # GH#45206 + arr = array([0], dtype=any_numeric_ea_dtype) + + idx = Index(arr, dtype=object) + assert idx.dtype == object + + @pytest.mark.parametrize("dtype", [object, "float64", "uint64", "category"]) + def test_constructor_range_values_mismatched_dtype(self, dtype): + rng = Index(range(5)) + + result = Index(rng, dtype=dtype) + assert result.dtype == dtype + + result = Index(range(5), dtype=dtype) + assert result.dtype == dtype + + @pytest.mark.parametrize("dtype", [object, "float64", "uint64", "category"]) + def test_constructor_categorical_values_mismatched_non_ea_dtype(self, dtype): + cat = Categorical([1, 2, 3]) + + result = Index(cat, dtype=dtype) + assert result.dtype == dtype + + def test_constructor_categorical_values_mismatched_dtype(self): + dti = date_range("2016-01-01", periods=3) + cat = Categorical(dti) + result = Index(cat, dti.dtype) + tm.assert_index_equal(result, dti) + + dti2 = dti.tz_localize("Asia/Tokyo") + cat2 = Categorical(dti2) + result = Index(cat2, dti2.dtype) + tm.assert_index_equal(result, dti2) + + ii = IntervalIndex.from_breaks(range(5)) + cat3 = Categorical(ii) + result = Index(cat3, dtype=ii.dtype) + tm.assert_index_equal(result, ii) + + def test_constructor_ea_values_mismatched_categorical_dtype(self): + dti = date_range("2016-01-01", periods=3) + result = Index(dti, dtype="category") + expected = CategoricalIndex(dti) + tm.assert_index_equal(result, expected) + + dti2 = date_range("2016-01-01", periods=3, tz="US/Pacific") + result = Index(dti2, dtype="category") + expected = CategoricalIndex(dti2) + tm.assert_index_equal(result, expected) + + def test_constructor_period_values_mismatched_dtype(self): + pi = period_range("2016-01-01", periods=3, freq="D") + result = Index(pi, dtype="category") + expected = CategoricalIndex(pi) + tm.assert_index_equal(result, expected) + + def test_constructor_timedelta64_values_mismatched_dtype(self): + # check we don't silently ignore the dtype keyword + tdi = timedelta_range("4 Days", periods=5) + result = Index(tdi, dtype="category") + expected = CategoricalIndex(tdi) + tm.assert_index_equal(result, expected) + + def test_constructor_interval_values_mismatched_dtype(self): + dti = date_range("2016-01-01", periods=3) + ii = IntervalIndex.from_breaks(dti) + result = Index(ii, dtype="category") + expected = CategoricalIndex(ii) + tm.assert_index_equal(result, expected) + + def test_constructor_datetime64_values_mismatched_period_dtype(self): + dti = date_range("2016-01-01", periods=3) + result = Index(dti, dtype="Period[D]") + expected = dti.to_period("D") + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize("dtype", ["int64", "uint64"]) + def test_constructor_int_dtype_nan_raises(self, dtype): + # see GH#15187 + data = [np.nan] + msg = "cannot convert" + with pytest.raises(ValueError, match=msg): + Index(data, dtype=dtype) + + @pytest.mark.parametrize( + "vals", + [ + [1, 2, 3], + np.array([1, 2, 3]), + np.array([1, 2, 3], dtype=int), + # below should coerce + [1.0, 2.0, 3.0], + np.array([1.0, 2.0, 3.0], dtype=float), + ], + ) + def test_constructor_dtypes_to_int(self, vals, any_int_numpy_dtype): + dtype = any_int_numpy_dtype + index = Index(vals, dtype=dtype) + assert index.dtype == dtype + + @pytest.mark.parametrize( + "vals", + [ + [1, 2, 3], + [1.0, 2.0, 3.0], + np.array([1.0, 2.0, 3.0]), + np.array([1, 2, 3], dtype=int), + np.array([1.0, 2.0, 3.0], dtype=float), + ], + ) + def test_constructor_dtypes_to_float(self, vals, float_numpy_dtype): + dtype = float_numpy_dtype + index = Index(vals, dtype=dtype) + assert index.dtype == dtype + + @pytest.mark.parametrize( + "vals", + [ + [1, 2, 3], + np.array([1, 2, 3], dtype=int), + np.array(["2011-01-01", "2011-01-02"], dtype="datetime64[ns]"), + [datetime(2011, 1, 1), datetime(2011, 1, 2)], + ], + ) + def test_constructor_dtypes_to_categorical(self, vals): + index = Index(vals, dtype="category") + assert isinstance(index, CategoricalIndex) + + @pytest.mark.parametrize("cast_index", [True, False]) + @pytest.mark.parametrize( + "vals", + [ + Index(np.array([np.datetime64("2011-01-01"), np.datetime64("2011-01-02")])), + Index([datetime(2011, 1, 1), datetime(2011, 1, 2)]), + ], + ) + def test_constructor_dtypes_to_datetime(self, cast_index, vals): + if cast_index: + index = Index(vals, dtype=object) + assert isinstance(index, Index) + assert index.dtype == object + else: + index = Index(vals) + assert isinstance(index, DatetimeIndex) + + @pytest.mark.parametrize("cast_index", [True, False]) + @pytest.mark.parametrize( + "vals", + [ + np.array([np.timedelta64(1, "D"), np.timedelta64(1, "D")]), + [timedelta(1), timedelta(1)], + ], + ) + def test_constructor_dtypes_to_timedelta(self, cast_index, vals): + if cast_index: + index = Index(vals, dtype=object) + assert isinstance(index, Index) + assert index.dtype == object + else: + index = Index(vals) + assert isinstance(index, TimedeltaIndex) + + def test_pass_timedeltaindex_to_index(self): + rng = timedelta_range("1 days", "10 days") + idx = Index(rng, dtype=object) + + expected = Index(rng.to_pytimedelta(), dtype=object) + + tm.assert_numpy_array_equal(idx.values, expected.values) + + def test_pass_datetimeindex_to_index(self): + # GH#1396 + rng = date_range("1/1/2000", "3/1/2000") + idx = Index(rng, dtype=object) + + expected = Index(rng.to_pydatetime(), dtype=object) + + tm.assert_numpy_array_equal(idx.values, expected.values) + + +class TestIndexConstructorUnwrapping: + # Test passing different arraylike values to pd.Index + + @pytest.mark.parametrize("klass", [Index, DatetimeIndex]) + def test_constructor_from_series_dt64(self, klass): + stamps = [Timestamp("20110101"), Timestamp("20120101"), Timestamp("20130101")] + expected = DatetimeIndex(stamps) + ser = Series(stamps) + result = klass(ser) + tm.assert_index_equal(result, expected) + + def test_constructor_no_pandas_array(self): + ser = Series([1, 2, 3]) + result = Index(ser.array) + expected = Index([1, 2, 3]) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "array", + [ + np.arange(5), + np.array(["a", "b", "c"]), + date_range("2000-01-01", periods=3).values, + ], + ) + def test_constructor_ndarray_like(self, array): + # GH#5460#issuecomment-44474502 + # it should be possible to convert any object that satisfies the numpy + # ndarray interface directly into an Index + class ArrayLike: + def __init__(self, array) -> None: + self.array = array + + def __array__(self, dtype=None, copy=None) -> np.ndarray: + return self.array + + expected = Index(array) + result = Index(ArrayLike(array)) + tm.assert_index_equal(result, expected) + + +class TestIndexConstructionErrors: + def test_constructor_overflow_int64(self): + # see GH#15832 + msg = ( + "The elements provided in the data cannot " + "all be casted to the dtype int64" + ) + with pytest.raises(OverflowError, match=msg): + Index([np.iinfo(np.uint64).max - 1], dtype="int64") diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/test_indexing.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/test_indexing.py new file mode 100644 index 0000000000000000000000000000000000000000..1ea47f636ac9b64346b21496fe25d4fe109cd711 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/test_indexing.py @@ -0,0 +1,357 @@ +""" +test_indexing tests the following Index methods: + __getitem__ + get_loc + get_value + __contains__ + take + where + get_indexer + get_indexer_for + slice_locs + asof_locs + +The corresponding tests.indexes.[index_type].test_indexing files +contain tests for the corresponding methods specific to those Index subclasses. +""" +import numpy as np +import pytest + +from pandas.errors import InvalidIndexError + +from pandas.core.dtypes.common import ( + is_float_dtype, + is_scalar, +) + +from pandas import ( + NA, + DatetimeIndex, + Index, + IntervalIndex, + MultiIndex, + NaT, + PeriodIndex, + TimedeltaIndex, +) +import pandas._testing as tm + + +class TestTake: + def test_take_invalid_kwargs(self, index): + indices = [1, 2] + + msg = r"take\(\) got an unexpected keyword argument 'foo'" + with pytest.raises(TypeError, match=msg): + index.take(indices, foo=2) + + msg = "the 'out' parameter is not supported" + with pytest.raises(ValueError, match=msg): + index.take(indices, out=indices) + + msg = "the 'mode' parameter is not supported" + with pytest.raises(ValueError, match=msg): + index.take(indices, mode="clip") + + def test_take(self, index): + indexer = [4, 3, 0, 2] + if len(index) < 5: + pytest.skip("Test doesn't make sense since not enough elements") + + result = index.take(indexer) + expected = index[indexer] + assert result.equals(expected) + + if not isinstance(index, (DatetimeIndex, PeriodIndex, TimedeltaIndex)): + # GH 10791 + msg = r"'(.*Index)' object has no attribute 'freq'" + with pytest.raises(AttributeError, match=msg): + index.freq + + def test_take_indexer_type(self): + # GH#42875 + integer_index = Index([0, 1, 2, 3]) + scalar_index = 1 + msg = "Expected indices to be array-like" + with pytest.raises(TypeError, match=msg): + integer_index.take(scalar_index) + + def test_take_minus1_without_fill(self, index): + # -1 does not get treated as NA unless allow_fill=True is passed + if len(index) == 0: + # Test is not applicable + pytest.skip("Test doesn't make sense for empty index") + + result = index.take([0, 0, -1]) + + expected = index.take([0, 0, len(index) - 1]) + tm.assert_index_equal(result, expected) + + +class TestContains: + @pytest.mark.parametrize( + "index,val", + [ + (Index([0, 1, 2]), 2), + (Index([0, 1, "2"]), "2"), + (Index([0, 1, 2, np.inf, 4]), 4), + (Index([0, 1, 2, np.nan, 4]), 4), + (Index([0, 1, 2, np.inf]), np.inf), + (Index([0, 1, 2, np.nan]), np.nan), + ], + ) + def test_index_contains(self, index, val): + assert val in index + + @pytest.mark.parametrize( + "index,val", + [ + (Index([0, 1, 2]), "2"), + (Index([0, 1, "2"]), 2), + (Index([0, 1, 2, np.inf]), 4), + (Index([0, 1, 2, np.nan]), 4), + (Index([0, 1, 2, np.inf]), np.nan), + (Index([0, 1, 2, np.nan]), np.inf), + # Checking if np.inf in int64 Index should not cause an OverflowError + # Related to GH 16957 + (Index([0, 1, 2], dtype=np.int64), np.inf), + (Index([0, 1, 2], dtype=np.int64), np.nan), + (Index([0, 1, 2], dtype=np.uint64), np.inf), + (Index([0, 1, 2], dtype=np.uint64), np.nan), + ], + ) + def test_index_not_contains(self, index, val): + assert val not in index + + @pytest.mark.parametrize( + "index,val", [(Index([0, 1, "2"]), 0), (Index([0, 1, "2"]), "2")] + ) + def test_mixed_index_contains(self, index, val): + # GH#19860 + assert val in index + + @pytest.mark.parametrize( + "index,val", [(Index([0, 1, "2"]), "1"), (Index([0, 1, "2"]), 2)] + ) + def test_mixed_index_not_contains(self, index, val): + # GH#19860 + assert val not in index + + def test_contains_with_float_index(self, any_real_numpy_dtype): + # GH#22085 + dtype = any_real_numpy_dtype + data = [0, 1, 2, 3] if not is_float_dtype(dtype) else [0.1, 1.1, 2.2, 3.3] + index = Index(data, dtype=dtype) + + if not is_float_dtype(index.dtype): + assert 1.1 not in index + assert 1.0 in index + assert 1 in index + else: + assert 1.1 in index + assert 1.0 not in index + assert 1 not in index + + def test_contains_requires_hashable_raises(self, index): + if isinstance(index, MultiIndex): + return # TODO: do we want this to raise? + + msg = "unhashable type: 'list'" + with pytest.raises(TypeError, match=msg): + [] in index + + msg = "|".join( + [ + r"unhashable type: 'dict'", + r"must be real number, not dict", + r"an integer is required", + r"\{\}", + r"pandas\._libs\.interval\.IntervalTree' is not iterable", + ] + ) + with pytest.raises(TypeError, match=msg): + {} in index._engine + + +class TestGetLoc: + def test_get_loc_non_hashable(self, index): + with pytest.raises(InvalidIndexError, match="[0, 1]"): + index.get_loc([0, 1]) + + def test_get_loc_non_scalar_hashable(self, index): + # GH52877 + from enum import Enum + + class E(Enum): + X1 = "x1" + + assert not is_scalar(E.X1) + + exc = KeyError + msg = "" + if isinstance( + index, + ( + DatetimeIndex, + TimedeltaIndex, + PeriodIndex, + IntervalIndex, + ), + ): + # TODO: make these more consistent? + exc = InvalidIndexError + msg = "E.X1" + with pytest.raises(exc, match=msg): + index.get_loc(E.X1) + + def test_get_loc_generator(self, index): + exc = KeyError + if isinstance( + index, + ( + DatetimeIndex, + TimedeltaIndex, + PeriodIndex, + IntervalIndex, + MultiIndex, + ), + ): + # TODO: make these more consistent? + exc = InvalidIndexError + with pytest.raises(exc, match="generator object"): + # MultiIndex specifically checks for generator; others for scalar + index.get_loc(x for x in range(5)) + + def test_get_loc_masked_duplicated_na(self): + # GH#48411 + idx = Index([1, 2, NA, NA], dtype="Int64") + result = idx.get_loc(NA) + expected = np.array([False, False, True, True]) + tm.assert_numpy_array_equal(result, expected) + + +class TestGetIndexer: + def test_get_indexer_base(self, index): + if index._index_as_unique: + expected = np.arange(index.size, dtype=np.intp) + actual = index.get_indexer(index) + tm.assert_numpy_array_equal(expected, actual) + else: + msg = "Reindexing only valid with uniquely valued Index objects" + with pytest.raises(InvalidIndexError, match=msg): + index.get_indexer(index) + + with pytest.raises(ValueError, match="Invalid fill method"): + index.get_indexer(index, method="invalid") + + def test_get_indexer_consistency(self, index): + # See GH#16819 + + if index._index_as_unique: + indexer = index.get_indexer(index[0:2]) + assert isinstance(indexer, np.ndarray) + assert indexer.dtype == np.intp + else: + msg = "Reindexing only valid with uniquely valued Index objects" + with pytest.raises(InvalidIndexError, match=msg): + index.get_indexer(index[0:2]) + + indexer, _ = index.get_indexer_non_unique(index[0:2]) + assert isinstance(indexer, np.ndarray) + assert indexer.dtype == np.intp + + def test_get_indexer_masked_duplicated_na(self): + # GH#48411 + idx = Index([1, 2, NA, NA], dtype="Int64") + result = idx.get_indexer_for(Index([1, NA], dtype="Int64")) + expected = np.array([0, 2, 3], dtype=result.dtype) + tm.assert_numpy_array_equal(result, expected) + + +class TestConvertSliceIndexer: + def test_convert_almost_null_slice(self, index): + # slice with None at both ends, but not step + + key = slice(None, None, "foo") + + if isinstance(index, IntervalIndex): + msg = "label-based slicing with step!=1 is not supported for IntervalIndex" + with pytest.raises(ValueError, match=msg): + index._convert_slice_indexer(key, "loc") + else: + msg = "'>=' not supported between instances of 'str' and 'int'" + with pytest.raises(TypeError, match=msg): + index._convert_slice_indexer(key, "loc") + + +class TestPutmask: + def test_putmask_with_wrong_mask(self, index): + # GH#18368 + if not len(index): + pytest.skip("Test doesn't make sense for empty index") + + fill = index[0] + + msg = "putmask: mask and data must be the same size" + with pytest.raises(ValueError, match=msg): + index.putmask(np.ones(len(index) + 1, np.bool_), fill) + + with pytest.raises(ValueError, match=msg): + index.putmask(np.ones(len(index) - 1, np.bool_), fill) + + with pytest.raises(ValueError, match=msg): + index.putmask("foo", fill) + + +@pytest.mark.parametrize( + "idx", [Index([1, 2, 3]), Index([0.1, 0.2, 0.3]), Index(["a", "b", "c"])] +) +def test_getitem_deprecated_float(idx): + # https://github.com/pandas-dev/pandas/issues/34191 + + msg = "Indexing with a float is no longer supported" + with pytest.raises(IndexError, match=msg): + idx[1.0] + + +@pytest.mark.parametrize( + "idx,target,expected", + [ + ([np.nan, "var1", np.nan], [np.nan], np.array([0, 2], dtype=np.intp)), + ( + [np.nan, "var1", np.nan], + [np.nan, "var1"], + np.array([0, 2, 1], dtype=np.intp), + ), + ( + np.array([np.nan, "var1", np.nan], dtype=object), + [np.nan], + np.array([0, 2], dtype=np.intp), + ), + ( + DatetimeIndex(["2020-08-05", NaT, NaT]), + [NaT], + np.array([1, 2], dtype=np.intp), + ), + (["a", "b", "a", np.nan], [np.nan], np.array([3], dtype=np.intp)), + ( + np.array(["b", np.nan, float("NaN"), "b"], dtype=object), + Index([np.nan], dtype=object), + np.array([1, 2], dtype=np.intp), + ), + ], +) +def test_get_indexer_non_unique_multiple_nans(idx, target, expected): + # GH 35392 + axis = Index(idx) + actual = axis.get_indexer_for(target) + tm.assert_numpy_array_equal(actual, expected) + + +def test_get_indexer_non_unique_nans_in_object_dtype_target(nulls_fixture): + idx = Index([1.0, 2.0]) + target = Index([1, nulls_fixture], dtype="object") + + result_idx, result_missing = idx.get_indexer_non_unique(target) + tm.assert_numpy_array_equal(result_idx, np.array([0, -1], dtype=np.intp)) + tm.assert_numpy_array_equal(result_missing, np.array([1], dtype=np.intp)) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/test_numpy_compat.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/test_numpy_compat.py new file mode 100644 index 0000000000000000000000000000000000000000..ace78d77350cbdc4ca3aa837720767a965443051 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/test_numpy_compat.py @@ -0,0 +1,189 @@ +import numpy as np +import pytest + +from pandas import ( + CategoricalIndex, + DatetimeIndex, + Index, + PeriodIndex, + TimedeltaIndex, + isna, +) +import pandas._testing as tm +from pandas.api.types import ( + is_complex_dtype, + is_numeric_dtype, +) +from pandas.core.arrays import BooleanArray +from pandas.core.indexes.datetimelike import DatetimeIndexOpsMixin + + +def test_numpy_ufuncs_out(index): + result = index == index + + out = np.empty(index.shape, dtype=bool) + np.equal(index, index, out=out) + tm.assert_numpy_array_equal(out, result) + + if not index._is_multi: + # same thing on the ExtensionArray + out = np.empty(index.shape, dtype=bool) + np.equal(index.array, index.array, out=out) + tm.assert_numpy_array_equal(out, result) + + +@pytest.mark.parametrize( + "func", + [ + np.exp, + np.exp2, + np.expm1, + np.log, + np.log2, + np.log10, + np.log1p, + np.sqrt, + np.sin, + np.cos, + np.tan, + np.arcsin, + np.arccos, + np.arctan, + np.sinh, + np.cosh, + np.tanh, + np.arcsinh, + np.arccosh, + np.arctanh, + np.deg2rad, + np.rad2deg, + ], + ids=lambda x: x.__name__, +) +def test_numpy_ufuncs_basic(index, func): + # test ufuncs of numpy, see: + # https://numpy.org/doc/stable/reference/ufuncs.html + + if isinstance(index, DatetimeIndexOpsMixin): + with tm.external_error_raised((TypeError, AttributeError)): + with np.errstate(all="ignore"): + func(index) + elif is_numeric_dtype(index) and not ( + is_complex_dtype(index) and func in [np.deg2rad, np.rad2deg] + ): + # coerces to float (e.g. np.sin) + with np.errstate(all="ignore"): + result = func(index) + arr_result = func(index.values) + if arr_result.dtype == np.float16: + arr_result = arr_result.astype(np.float32) + exp = Index(arr_result, name=index.name) + + tm.assert_index_equal(result, exp) + if isinstance(index.dtype, np.dtype) and is_numeric_dtype(index): + if is_complex_dtype(index): + assert result.dtype == index.dtype + elif index.dtype in ["bool", "int8", "uint8"]: + assert result.dtype in ["float16", "float32"] + elif index.dtype in ["int16", "uint16", "float32"]: + assert result.dtype == "float32" + else: + assert result.dtype == "float64" + else: + # e.g. np.exp with Int64 -> Float64 + assert type(result) is Index + # raise AttributeError or TypeError + elif len(index) == 0: + pass + else: + with tm.external_error_raised((TypeError, AttributeError)): + with np.errstate(all="ignore"): + func(index) + + +@pytest.mark.parametrize( + "func", [np.isfinite, np.isinf, np.isnan, np.signbit], ids=lambda x: x.__name__ +) +def test_numpy_ufuncs_other(index, func): + # test ufuncs of numpy, see: + # https://numpy.org/doc/stable/reference/ufuncs.html + if isinstance(index, (DatetimeIndex, TimedeltaIndex)): + if func in (np.isfinite, np.isinf, np.isnan): + # numpy 1.18 changed isinf and isnan to not raise on dt64/td64 + result = func(index) + assert isinstance(result, np.ndarray) + + out = np.empty(index.shape, dtype=bool) + func(index, out=out) + tm.assert_numpy_array_equal(out, result) + else: + with tm.external_error_raised(TypeError): + func(index) + + elif isinstance(index, PeriodIndex): + with tm.external_error_raised(TypeError): + func(index) + + elif is_numeric_dtype(index) and not ( + is_complex_dtype(index) and func is np.signbit + ): + # Results in bool array + result = func(index) + if not isinstance(index.dtype, np.dtype): + # e.g. Int64 we expect to get BooleanArray back + assert isinstance(result, BooleanArray) + else: + assert isinstance(result, np.ndarray) + + out = np.empty(index.shape, dtype=bool) + func(index, out=out) + + if not isinstance(index.dtype, np.dtype): + tm.assert_numpy_array_equal(out, result._data) + else: + tm.assert_numpy_array_equal(out, result) + + elif len(index) == 0: + pass + else: + with tm.external_error_raised(TypeError): + func(index) + + +@pytest.mark.parametrize("func", [np.maximum, np.minimum]) +def test_numpy_ufuncs_reductions(index, func, request): + # TODO: overlap with tests.series.test_ufunc.test_reductions + if len(index) == 0: + pytest.skip("Test doesn't make sense for empty index.") + + if isinstance(index, CategoricalIndex) and index.dtype.ordered is False: + with pytest.raises(TypeError, match="is not ordered for"): + func.reduce(index) + return + else: + result = func.reduce(index) + + if func is np.maximum: + expected = index.max(skipna=False) + else: + expected = index.min(skipna=False) + # TODO: do we have cases both with and without NAs? + + assert type(result) is type(expected) + if isna(result): + assert isna(expected) + else: + assert result == expected + + +@pytest.mark.parametrize("func", [np.bitwise_and, np.bitwise_or, np.bitwise_xor]) +def test_numpy_ufuncs_bitwise(func): + # https://github.com/pandas-dev/pandas/issues/46769 + idx1 = Index([1, 2, 3, 4], dtype="int64") + idx2 = Index([3, 4, 5, 6], dtype="int64") + + with tm.assert_produces_warning(None): + result = func(idx1, idx2) + + expected = Index(func(idx1.values, idx2.values)) + tm.assert_index_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/test_old_base.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/test_old_base.py new file mode 100644 index 0000000000000000000000000000000000000000..1787379b0faee68956932a451a1dfbbdc711204c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/test_old_base.py @@ -0,0 +1,1061 @@ +from __future__ import annotations + +from datetime import datetime +import weakref + +import numpy as np +import pytest + +from pandas._config import using_pyarrow_string_dtype + +from pandas._libs.tslibs import Timestamp + +from pandas.core.dtypes.common import ( + is_integer_dtype, + is_numeric_dtype, +) +from pandas.core.dtypes.dtypes import CategoricalDtype + +import pandas as pd +from pandas import ( + CategoricalIndex, + DatetimeIndex, + DatetimeTZDtype, + Index, + IntervalIndex, + MultiIndex, + PeriodIndex, + RangeIndex, + Series, + TimedeltaIndex, + isna, + period_range, +) +import pandas._testing as tm +import pandas.core.algorithms as algos +from pandas.core.arrays import BaseMaskedArray + + +class TestBase: + @pytest.fixture( + params=[ + RangeIndex(start=0, stop=20, step=2), + Index(np.arange(5, dtype=np.float64)), + Index(np.arange(5, dtype=np.float32)), + Index(np.arange(5, dtype=np.uint64)), + Index(range(0, 20, 2), dtype=np.int64), + Index(range(0, 20, 2), dtype=np.int32), + Index(range(0, 20, 2), dtype=np.int16), + Index(range(0, 20, 2), dtype=np.int8), + Index(list("abcde")), + Index([0, "a", 1, "b", 2, "c"]), + period_range("20130101", periods=5, freq="D"), + TimedeltaIndex( + [ + "0 days 01:00:00", + "1 days 01:00:00", + "2 days 01:00:00", + "3 days 01:00:00", + "4 days 01:00:00", + ], + dtype="timedelta64[ns]", + freq="D", + ), + DatetimeIndex( + ["2013-01-01", "2013-01-02", "2013-01-03", "2013-01-04", "2013-01-05"], + dtype="datetime64[ns]", + freq="D", + ), + IntervalIndex.from_breaks(range(11), closed="right"), + ] + ) + def simple_index(self, request): + return request.param + + def test_pickle_compat_construction(self, simple_index): + # need an object to create with + if isinstance(simple_index, RangeIndex): + pytest.skip("RangeIndex() is a valid constructor") + msg = "|".join( + [ + r"Index\(\.\.\.\) must be called with a collection of some " + r"kind, None was passed", + r"DatetimeIndex\(\) must be called with a collection of some " + r"kind, None was passed", + r"TimedeltaIndex\(\) must be called with a collection of some " + r"kind, None was passed", + r"__new__\(\) missing 1 required positional argument: 'data'", + r"__new__\(\) takes at least 2 arguments \(1 given\)", + ] + ) + with pytest.raises(TypeError, match=msg): + type(simple_index)() + + def test_shift(self, simple_index): + # GH8083 test the base class for shift + if isinstance(simple_index, (DatetimeIndex, TimedeltaIndex, PeriodIndex)): + pytest.skip("Tested in test_ops/test_arithmetic") + idx = simple_index + msg = ( + f"This method is only implemented for DatetimeIndex, PeriodIndex and " + f"TimedeltaIndex; Got type {type(idx).__name__}" + ) + with pytest.raises(NotImplementedError, match=msg): + idx.shift(1) + with pytest.raises(NotImplementedError, match=msg): + idx.shift(1, 2) + + def test_constructor_name_unhashable(self, simple_index): + # GH#29069 check that name is hashable + # See also same-named test in tests.series.test_constructors + idx = simple_index + with pytest.raises(TypeError, match="Index.name must be a hashable type"): + type(idx)(idx, name=[]) + + def test_create_index_existing_name(self, simple_index): + # GH11193, when an existing index is passed, and a new name is not + # specified, the new index should inherit the previous object name + expected = simple_index.copy() + if not isinstance(expected, MultiIndex): + expected.name = "foo" + result = Index(expected) + tm.assert_index_equal(result, expected) + + result = Index(expected, name="bar") + expected.name = "bar" + tm.assert_index_equal(result, expected) + else: + expected.names = ["foo", "bar"] + result = Index(expected) + tm.assert_index_equal( + result, + Index( + Index( + [ + ("foo", "one"), + ("foo", "two"), + ("bar", "one"), + ("baz", "two"), + ("qux", "one"), + ("qux", "two"), + ], + dtype="object", + ), + names=["foo", "bar"], + ), + ) + + result = Index(expected, names=["A", "B"]) + tm.assert_index_equal( + result, + Index( + Index( + [ + ("foo", "one"), + ("foo", "two"), + ("bar", "one"), + ("baz", "two"), + ("qux", "one"), + ("qux", "two"), + ], + dtype="object", + ), + names=["A", "B"], + ), + ) + + def test_numeric_compat(self, simple_index): + idx = simple_index + # Check that this doesn't cover MultiIndex case, if/when it does, + # we can remove multi.test_compat.test_numeric_compat + assert not isinstance(idx, MultiIndex) + if type(idx) is Index: + pytest.skip("Not applicable for Index") + if is_numeric_dtype(simple_index.dtype) or isinstance( + simple_index, TimedeltaIndex + ): + pytest.skip("Tested elsewhere.") + + typ = type(idx._data).__name__ + cls = type(idx).__name__ + lmsg = "|".join( + [ + rf"unsupported operand type\(s\) for \*: '{typ}' and 'int'", + "cannot perform (__mul__|__truediv__|__floordiv__) with " + f"this index type: ({cls}|{typ})", + ] + ) + with pytest.raises(TypeError, match=lmsg): + idx * 1 + rmsg = "|".join( + [ + rf"unsupported operand type\(s\) for \*: 'int' and '{typ}'", + "cannot perform (__rmul__|__rtruediv__|__rfloordiv__) with " + f"this index type: ({cls}|{typ})", + ] + ) + with pytest.raises(TypeError, match=rmsg): + 1 * idx + + div_err = lmsg.replace("*", "/") + with pytest.raises(TypeError, match=div_err): + idx / 1 + div_err = rmsg.replace("*", "/") + with pytest.raises(TypeError, match=div_err): + 1 / idx + + floordiv_err = lmsg.replace("*", "//") + with pytest.raises(TypeError, match=floordiv_err): + idx // 1 + floordiv_err = rmsg.replace("*", "//") + with pytest.raises(TypeError, match=floordiv_err): + 1 // idx + + def test_logical_compat(self, simple_index): + if simple_index.dtype in (object, "string"): + pytest.skip("Tested elsewhere.") + idx = simple_index + if idx.dtype.kind in "iufcbm": + assert idx.all() == idx._values.all() + assert idx.all() == idx.to_series().all() + assert idx.any() == idx._values.any() + assert idx.any() == idx.to_series().any() + else: + msg = "cannot perform (any|all)" + if isinstance(idx, IntervalIndex): + msg = ( + r"'IntervalArray' with dtype interval\[.*\] does " + "not support reduction '(any|all)'" + ) + with pytest.raises(TypeError, match=msg): + idx.all() + with pytest.raises(TypeError, match=msg): + idx.any() + + def test_repr_roundtrip(self, simple_index): + if isinstance(simple_index, IntervalIndex): + pytest.skip(f"Not a valid repr for {type(simple_index).__name__}") + idx = simple_index + tm.assert_index_equal(eval(repr(idx)), idx) + + def test_repr_max_seq_item_setting(self, simple_index): + # GH10182 + if isinstance(simple_index, IntervalIndex): + pytest.skip(f"Not a valid repr for {type(simple_index).__name__}") + idx = simple_index + idx = idx.repeat(50) + with pd.option_context("display.max_seq_items", None): + repr(idx) + assert "..." not in str(idx) + + @pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning") + def test_ensure_copied_data(self, index): + # Check the "copy" argument of each Index.__new__ is honoured + # GH12309 + init_kwargs = {} + if isinstance(index, PeriodIndex): + # Needs "freq" specification: + init_kwargs["freq"] = index.freq + elif isinstance(index, (RangeIndex, MultiIndex, CategoricalIndex)): + pytest.skip( + "RangeIndex cannot be initialized from data, " + "MultiIndex and CategoricalIndex are tested separately" + ) + elif index.dtype == object and index.inferred_type == "boolean": + init_kwargs["dtype"] = index.dtype + + index_type = type(index) + result = index_type(index.values, copy=True, **init_kwargs) + if isinstance(index.dtype, DatetimeTZDtype): + result = result.tz_localize("UTC").tz_convert(index.tz) + if isinstance(index, (DatetimeIndex, TimedeltaIndex)): + index = index._with_freq(None) + + tm.assert_index_equal(index, result) + + if isinstance(index, PeriodIndex): + # .values an object array of Period, thus copied + depr_msg = "The 'ordinal' keyword in PeriodIndex is deprecated" + with tm.assert_produces_warning(FutureWarning, match=depr_msg): + result = index_type(ordinal=index.asi8, copy=False, **init_kwargs) + tm.assert_numpy_array_equal(index.asi8, result.asi8, check_same="same") + elif isinstance(index, IntervalIndex): + # checked in test_interval.py + pass + elif type(index) is Index and not isinstance(index.dtype, np.dtype): + result = index_type(index.values, copy=False, **init_kwargs) + tm.assert_index_equal(result, index) + + if isinstance(index._values, BaseMaskedArray): + assert np.shares_memory(index._values._data, result._values._data) + tm.assert_numpy_array_equal( + index._values._data, result._values._data, check_same="same" + ) + assert np.shares_memory(index._values._mask, result._values._mask) + tm.assert_numpy_array_equal( + index._values._mask, result._values._mask, check_same="same" + ) + elif index.dtype == "string[python]": + assert np.shares_memory(index._values._ndarray, result._values._ndarray) + tm.assert_numpy_array_equal( + index._values._ndarray, result._values._ndarray, check_same="same" + ) + elif index.dtype in ("string[pyarrow]", "string[pyarrow_numpy]"): + assert tm.shares_memory(result._values, index._values) + else: + raise NotImplementedError(index.dtype) + else: + result = index_type(index.values, copy=False, **init_kwargs) + tm.assert_numpy_array_equal(index.values, result.values, check_same="same") + + def test_memory_usage(self, index): + index._engine.clear_mapping() + result = index.memory_usage() + if index.empty: + # we report 0 for no-length + assert result == 0 + return + + # non-zero length + index.get_loc(index[0]) + result2 = index.memory_usage() + result3 = index.memory_usage(deep=True) + + # RangeIndex, IntervalIndex + # don't have engines + # Index[EA] has engine but it does not have a Hashtable .mapping + if not isinstance(index, (RangeIndex, IntervalIndex)) and not ( + type(index) is Index and not isinstance(index.dtype, np.dtype) + ): + assert result2 > result + + if index.inferred_type == "object": + assert result3 > result2 + + def test_argsort(self, index): + if isinstance(index, CategoricalIndex): + pytest.skip(f"{type(self).__name__} separately tested") + + result = index.argsort() + expected = np.array(index).argsort() + tm.assert_numpy_array_equal(result, expected, check_dtype=False) + + def test_numpy_argsort(self, index): + result = np.argsort(index) + expected = index.argsort() + tm.assert_numpy_array_equal(result, expected) + + result = np.argsort(index, kind="mergesort") + expected = index.argsort(kind="mergesort") + tm.assert_numpy_array_equal(result, expected) + + # these are the only two types that perform + # pandas compatibility input validation - the + # rest already perform separate (or no) such + # validation via their 'values' attribute as + # defined in pandas.core.indexes/base.py - they + # cannot be changed at the moment due to + # backwards compatibility concerns + if isinstance(index, (CategoricalIndex, RangeIndex)): + msg = "the 'axis' parameter is not supported" + with pytest.raises(ValueError, match=msg): + np.argsort(index, axis=1) + + msg = "the 'order' parameter is not supported" + with pytest.raises(ValueError, match=msg): + np.argsort(index, order=("a", "b")) + + def test_repeat(self, simple_index): + rep = 2 + idx = simple_index.copy() + new_index_cls = idx._constructor + expected = new_index_cls(idx.values.repeat(rep), name=idx.name) + tm.assert_index_equal(idx.repeat(rep), expected) + + idx = simple_index + rep = np.arange(len(idx)) + expected = new_index_cls(idx.values.repeat(rep), name=idx.name) + tm.assert_index_equal(idx.repeat(rep), expected) + + def test_numpy_repeat(self, simple_index): + rep = 2 + idx = simple_index + expected = idx.repeat(rep) + tm.assert_index_equal(np.repeat(idx, rep), expected) + + msg = "the 'axis' parameter is not supported" + with pytest.raises(ValueError, match=msg): + np.repeat(idx, rep, axis=0) + + def test_where(self, listlike_box, simple_index): + if isinstance(simple_index, (IntervalIndex, PeriodIndex)) or is_numeric_dtype( + simple_index.dtype + ): + pytest.skip("Tested elsewhere.") + klass = listlike_box + + idx = simple_index + if isinstance(idx, (DatetimeIndex, TimedeltaIndex)): + # where does not preserve freq + idx = idx._with_freq(None) + + cond = [True] * len(idx) + result = idx.where(klass(cond)) + expected = idx + tm.assert_index_equal(result, expected) + + cond = [False] + [True] * len(idx[1:]) + expected = Index([idx._na_value] + idx[1:].tolist(), dtype=idx.dtype) + result = idx.where(klass(cond)) + tm.assert_index_equal(result, expected) + + def test_insert_base(self, index): + trimmed = index[1:4] + + if not len(index): + pytest.skip("Not applicable for empty index") + + # test 0th element + warn = None + if index.dtype == object and index.inferred_type == "boolean": + # GH#51363 + warn = FutureWarning + msg = "The behavior of Index.insert with object-dtype is deprecated" + with tm.assert_produces_warning(warn, match=msg): + result = trimmed.insert(0, index[0]) + assert index[0:4].equals(result) + + @pytest.mark.skipif( + using_pyarrow_string_dtype(), + reason="completely different behavior, tested elsewher", + ) + def test_insert_out_of_bounds(self, index): + # TypeError/IndexError matches what np.insert raises in these cases + + if len(index) > 0: + err = TypeError + else: + err = IndexError + if len(index) == 0: + # 0 vs 0.5 in error message varies with numpy version + msg = "index (0|0.5) is out of bounds for axis 0 with size 0" + else: + msg = "slice indices must be integers or None or have an __index__ method" + with pytest.raises(err, match=msg): + index.insert(0.5, "foo") + + msg = "|".join( + [ + r"index -?\d+ is out of bounds for axis 0 with size \d+", + "loc must be an integer between", + ] + ) + with pytest.raises(IndexError, match=msg): + index.insert(len(index) + 1, 1) + + with pytest.raises(IndexError, match=msg): + index.insert(-len(index) - 1, 1) + + def test_delete_base(self, index): + if not len(index): + pytest.skip("Not applicable for empty index") + + if isinstance(index, RangeIndex): + # tested in class + pytest.skip(f"{type(self).__name__} tested elsewhere") + + expected = index[1:] + result = index.delete(0) + assert result.equals(expected) + assert result.name == expected.name + + expected = index[:-1] + result = index.delete(-1) + assert result.equals(expected) + assert result.name == expected.name + + length = len(index) + msg = f"index {length} is out of bounds for axis 0 with size {length}" + with pytest.raises(IndexError, match=msg): + index.delete(length) + + @pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning") + def test_equals(self, index): + if isinstance(index, IntervalIndex): + pytest.skip(f"{type(index).__name__} tested elsewhere") + + is_ea_idx = type(index) is Index and not isinstance(index.dtype, np.dtype) + + assert index.equals(index) + assert index.equals(index.copy()) + if not is_ea_idx: + # doesn't hold for e.g. IntegerDtype + assert index.equals(index.astype(object)) + + assert not index.equals(list(index)) + assert not index.equals(np.array(index)) + + # Cannot pass in non-int64 dtype to RangeIndex + if not isinstance(index, RangeIndex) and not is_ea_idx: + same_values = Index(index, dtype=object) + assert index.equals(same_values) + assert same_values.equals(index) + + if index.nlevels == 1: + # do not test MultiIndex + assert not index.equals(Series(index)) + + def test_equals_op(self, simple_index): + # GH9947, GH10637 + index_a = simple_index + + n = len(index_a) + index_b = index_a[0:-1] + index_c = index_a[0:-1].append(index_a[-2:-1]) + index_d = index_a[0:1] + + msg = "Lengths must match|could not be broadcast" + with pytest.raises(ValueError, match=msg): + index_a == index_b + expected1 = np.array([True] * n) + expected2 = np.array([True] * (n - 1) + [False]) + tm.assert_numpy_array_equal(index_a == index_a, expected1) + tm.assert_numpy_array_equal(index_a == index_c, expected2) + + # test comparisons with numpy arrays + array_a = np.array(index_a) + array_b = np.array(index_a[0:-1]) + array_c = np.array(index_a[0:-1].append(index_a[-2:-1])) + array_d = np.array(index_a[0:1]) + with pytest.raises(ValueError, match=msg): + index_a == array_b + tm.assert_numpy_array_equal(index_a == array_a, expected1) + tm.assert_numpy_array_equal(index_a == array_c, expected2) + + # test comparisons with Series + series_a = Series(array_a) + series_b = Series(array_b) + series_c = Series(array_c) + series_d = Series(array_d) + with pytest.raises(ValueError, match=msg): + index_a == series_b + + tm.assert_numpy_array_equal(index_a == series_a, expected1) + tm.assert_numpy_array_equal(index_a == series_c, expected2) + + # cases where length is 1 for one of them + with pytest.raises(ValueError, match="Lengths must match"): + index_a == index_d + with pytest.raises(ValueError, match="Lengths must match"): + index_a == series_d + with pytest.raises(ValueError, match="Lengths must match"): + index_a == array_d + msg = "Can only compare identically-labeled Series objects" + with pytest.raises(ValueError, match=msg): + series_a == series_d + with pytest.raises(ValueError, match="Lengths must match"): + series_a == array_d + + # comparing with a scalar should broadcast; note that we are excluding + # MultiIndex because in this case each item in the index is a tuple of + # length 2, and therefore is considered an array of length 2 in the + # comparison instead of a scalar + if not isinstance(index_a, MultiIndex): + expected3 = np.array([False] * (len(index_a) - 2) + [True, False]) + # assuming the 2nd to last item is unique in the data + item = index_a[-2] + tm.assert_numpy_array_equal(index_a == item, expected3) + tm.assert_series_equal(series_a == item, Series(expected3)) + + def test_format(self, simple_index): + # GH35439 + if is_numeric_dtype(simple_index.dtype) or isinstance( + simple_index, DatetimeIndex + ): + pytest.skip("Tested elsewhere.") + idx = simple_index + expected = [str(x) for x in idx] + msg = r"Index\.format is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + assert idx.format() == expected + + def test_format_empty(self, simple_index): + # GH35712 + if isinstance(simple_index, (PeriodIndex, RangeIndex)): + pytest.skip("Tested elsewhere") + empty_idx = type(simple_index)([]) + msg = r"Index\.format is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + assert empty_idx.format() == [] + with tm.assert_produces_warning(FutureWarning, match=msg): + assert empty_idx.format(name=True) == [""] + + def test_fillna(self, index): + # GH 11343 + if len(index) == 0: + pytest.skip("Not relevant for empty index") + elif index.dtype == bool: + pytest.skip(f"{index.dtype} cannot hold NAs") + elif isinstance(index, Index) and is_integer_dtype(index.dtype): + pytest.skip(f"Not relevant for Index with {index.dtype}") + elif isinstance(index, MultiIndex): + idx = index.copy(deep=True) + msg = "isna is not defined for MultiIndex" + with pytest.raises(NotImplementedError, match=msg): + idx.fillna(idx[0]) + else: + idx = index.copy(deep=True) + result = idx.fillna(idx[0]) + tm.assert_index_equal(result, idx) + assert result is not idx + + msg = "'value' must be a scalar, passed: " + with pytest.raises(TypeError, match=msg): + idx.fillna([idx[0]]) + + idx = index.copy(deep=True) + values = idx._values + + values[1] = np.nan + + idx = type(index)(values) + + msg = "does not support 'downcast'" + msg2 = r"The 'downcast' keyword in .*Index\.fillna is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg2): + with pytest.raises(NotImplementedError, match=msg): + # For now at least, we only raise if there are NAs present + idx.fillna(idx[0], downcast="infer") + + expected = np.array([False] * len(idx), dtype=bool) + expected[1] = True + tm.assert_numpy_array_equal(idx._isnan, expected) + assert idx.hasnans is True + + def test_nulls(self, index): + # this is really a smoke test for the methods + # as these are adequately tested for function elsewhere + if len(index) == 0: + tm.assert_numpy_array_equal(index.isna(), np.array([], dtype=bool)) + elif isinstance(index, MultiIndex): + idx = index.copy() + msg = "isna is not defined for MultiIndex" + with pytest.raises(NotImplementedError, match=msg): + idx.isna() + elif not index.hasnans: + tm.assert_numpy_array_equal(index.isna(), np.zeros(len(index), dtype=bool)) + tm.assert_numpy_array_equal(index.notna(), np.ones(len(index), dtype=bool)) + else: + result = isna(index) + tm.assert_numpy_array_equal(index.isna(), result) + tm.assert_numpy_array_equal(index.notna(), ~result) + + def test_empty(self, simple_index): + # GH 15270 + idx = simple_index + assert not idx.empty + assert idx[:0].empty + + def test_join_self_unique(self, join_type, simple_index): + idx = simple_index + if idx.is_unique: + joined = idx.join(idx, how=join_type) + expected = simple_index + if join_type == "outer": + expected = algos.safe_sort(expected) + tm.assert_index_equal(joined, expected) + + def test_map(self, simple_index): + # callable + if isinstance(simple_index, (TimedeltaIndex, PeriodIndex)): + pytest.skip("Tested elsewhere.") + idx = simple_index + + result = idx.map(lambda x: x) + # RangeIndex are equivalent to the similar Index with int64 dtype + tm.assert_index_equal(result, idx, exact="equiv") + + @pytest.mark.parametrize( + "mapper", + [ + lambda values, index: {i: e for e, i in zip(values, index)}, + lambda values, index: Series(values, index), + ], + ) + @pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning") + def test_map_dictlike(self, mapper, simple_index, request): + idx = simple_index + if isinstance(idx, (DatetimeIndex, TimedeltaIndex, PeriodIndex)): + pytest.skip("Tested elsewhere.") + + identity = mapper(idx.values, idx) + + result = idx.map(identity) + # RangeIndex are equivalent to the similar Index with int64 dtype + tm.assert_index_equal(result, idx, exact="equiv") + + # empty mappable + dtype = None + if idx.dtype.kind == "f": + dtype = idx.dtype + + expected = Index([np.nan] * len(idx), dtype=dtype) + result = idx.map(mapper(expected, idx)) + tm.assert_index_equal(result, expected) + + def test_map_str(self, simple_index): + # GH 31202 + if isinstance(simple_index, CategoricalIndex): + pytest.skip("See test_map.py") + idx = simple_index + result = idx.map(str) + expected = Index([str(x) for x in idx]) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize("copy", [True, False]) + @pytest.mark.parametrize("name", [None, "foo"]) + @pytest.mark.parametrize("ordered", [True, False]) + def test_astype_category(self, copy, name, ordered, simple_index): + # GH 18630 + idx = simple_index + if name: + idx = idx.rename(name) + + # standard categories + dtype = CategoricalDtype(ordered=ordered) + result = idx.astype(dtype, copy=copy) + expected = CategoricalIndex(idx, name=name, ordered=ordered) + tm.assert_index_equal(result, expected, exact=True) + + # non-standard categories + dtype = CategoricalDtype(idx.unique().tolist()[:-1], ordered) + result = idx.astype(dtype, copy=copy) + expected = CategoricalIndex(idx, name=name, dtype=dtype) + tm.assert_index_equal(result, expected, exact=True) + + if ordered is False: + # dtype='category' defaults to ordered=False, so only test once + result = idx.astype("category", copy=copy) + expected = CategoricalIndex(idx, name=name) + tm.assert_index_equal(result, expected, exact=True) + + def test_is_unique(self, simple_index): + # initialize a unique index + index = simple_index.drop_duplicates() + assert index.is_unique is True + + # empty index should be unique + index_empty = index[:0] + assert index_empty.is_unique is True + + # test basic dupes + index_dup = index.insert(0, index[0]) + assert index_dup.is_unique is False + + # single NA should be unique + index_na = index.insert(0, np.nan) + assert index_na.is_unique is True + + # multiple NA should not be unique + index_na_dup = index_na.insert(0, np.nan) + assert index_na_dup.is_unique is False + + @pytest.mark.arm_slow + def test_engine_reference_cycle(self, simple_index): + # GH27585 + index = simple_index.copy() + ref = weakref.ref(index) + index._engine + del index + assert ref() is None + + def test_getitem_2d_deprecated(self, simple_index): + # GH#30588, GH#31479 + if isinstance(simple_index, IntervalIndex): + pytest.skip("Tested elsewhere") + idx = simple_index + msg = "Multi-dimensional indexing|too many|only" + with pytest.raises((ValueError, IndexError), match=msg): + idx[:, None] + + if not isinstance(idx, RangeIndex): + # GH#44051 RangeIndex already raised pre-2.0 with a different message + with pytest.raises((ValueError, IndexError), match=msg): + idx[True] + with pytest.raises((ValueError, IndexError), match=msg): + idx[False] + else: + msg = "only integers, slices" + with pytest.raises(IndexError, match=msg): + idx[True] + with pytest.raises(IndexError, match=msg): + idx[False] + + def test_copy_shares_cache(self, simple_index): + # GH32898, GH36840 + idx = simple_index + idx.get_loc(idx[0]) # populates the _cache. + copy = idx.copy() + + assert copy._cache is idx._cache + + def test_shallow_copy_shares_cache(self, simple_index): + # GH32669, GH36840 + idx = simple_index + idx.get_loc(idx[0]) # populates the _cache. + shallow_copy = idx._view() + + assert shallow_copy._cache is idx._cache + + shallow_copy = idx._shallow_copy(idx._data) + assert shallow_copy._cache is not idx._cache + assert shallow_copy._cache == {} + + def test_index_groupby(self, simple_index): + idx = simple_index[:5] + to_groupby = np.array([1, 2, np.nan, 2, 1]) + tm.assert_dict_equal( + idx.groupby(to_groupby), {1.0: idx[[0, 4]], 2.0: idx[[1, 3]]} + ) + + to_groupby = DatetimeIndex( + [ + datetime(2011, 11, 1), + datetime(2011, 12, 1), + pd.NaT, + datetime(2011, 12, 1), + datetime(2011, 11, 1), + ], + tz="UTC", + ).values + + ex_keys = [Timestamp("2011-11-01"), Timestamp("2011-12-01")] + expected = {ex_keys[0]: idx[[0, 4]], ex_keys[1]: idx[[1, 3]]} + tm.assert_dict_equal(idx.groupby(to_groupby), expected) + + def test_append_preserves_dtype(self, simple_index): + # In particular Index with dtype float32 + index = simple_index + N = len(index) + + result = index.append(index) + assert result.dtype == index.dtype + tm.assert_index_equal(result[:N], index, check_exact=True) + tm.assert_index_equal(result[N:], index, check_exact=True) + + alt = index.take(list(range(N)) * 2) + tm.assert_index_equal(result, alt, check_exact=True) + + def test_inv(self, simple_index, using_infer_string): + idx = simple_index + + if idx.dtype.kind in ["i", "u"]: + res = ~idx + expected = Index(~idx.values, name=idx.name) + tm.assert_index_equal(res, expected) + + # check that we are matching Series behavior + res2 = ~Series(idx) + tm.assert_series_equal(res2, Series(expected)) + else: + if idx.dtype.kind == "f": + err = TypeError + msg = "ufunc 'invert' not supported for the input types" + elif using_infer_string and idx.dtype == "string": + import pyarrow as pa + + err = pa.lib.ArrowNotImplementedError + msg = "has no kernel" + else: + err = TypeError + msg = "bad operand" + with pytest.raises(err, match=msg): + ~idx + + # check that we get the same behavior with Series + with pytest.raises(err, match=msg): + ~Series(idx) + + def test_is_boolean_is_deprecated(self, simple_index): + # GH50042 + idx = simple_index + with tm.assert_produces_warning(FutureWarning): + idx.is_boolean() + + def test_is_floating_is_deprecated(self, simple_index): + # GH50042 + idx = simple_index + with tm.assert_produces_warning(FutureWarning): + idx.is_floating() + + def test_is_integer_is_deprecated(self, simple_index): + # GH50042 + idx = simple_index + with tm.assert_produces_warning(FutureWarning): + idx.is_integer() + + def test_holds_integer_deprecated(self, simple_index): + # GH50243 + idx = simple_index + msg = f"{type(idx).__name__}.holds_integer is deprecated. " + with tm.assert_produces_warning(FutureWarning, match=msg): + idx.holds_integer() + + def test_is_numeric_is_deprecated(self, simple_index): + # GH50042 + idx = simple_index + with tm.assert_produces_warning( + FutureWarning, + match=f"{type(idx).__name__}.is_numeric is deprecated. ", + ): + idx.is_numeric() + + def test_is_categorical_is_deprecated(self, simple_index): + # GH50042 + idx = simple_index + with tm.assert_produces_warning( + FutureWarning, + match=r"Use pandas\.api\.types\.is_categorical_dtype instead", + ): + idx.is_categorical() + + def test_is_interval_is_deprecated(self, simple_index): + # GH50042 + idx = simple_index + with tm.assert_produces_warning(FutureWarning): + idx.is_interval() + + def test_is_object_is_deprecated(self, simple_index): + # GH50042 + idx = simple_index + with tm.assert_produces_warning(FutureWarning): + idx.is_object() + + +class TestNumericBase: + @pytest.fixture( + params=[ + RangeIndex(start=0, stop=20, step=2), + Index(np.arange(5, dtype=np.float64)), + Index(np.arange(5, dtype=np.float32)), + Index(np.arange(5, dtype=np.uint64)), + Index(range(0, 20, 2), dtype=np.int64), + Index(range(0, 20, 2), dtype=np.int32), + Index(range(0, 20, 2), dtype=np.int16), + Index(range(0, 20, 2), dtype=np.int8), + ] + ) + def simple_index(self, request): + return request.param + + def test_constructor_unwraps_index(self, simple_index): + if isinstance(simple_index, RangeIndex): + pytest.skip("Tested elsewhere.") + index_cls = type(simple_index) + dtype = simple_index.dtype + + idx = Index([1, 2], dtype=dtype) + result = index_cls(idx) + expected = np.array([1, 2], dtype=idx.dtype) + tm.assert_numpy_array_equal(result._data, expected) + + def test_can_hold_identifiers(self, simple_index): + idx = simple_index + key = idx[0] + assert idx._can_hold_identifiers_and_holds_name(key) is False + + def test_view(self, simple_index): + if isinstance(simple_index, RangeIndex): + pytest.skip("Tested elsewhere.") + index_cls = type(simple_index) + dtype = simple_index.dtype + + idx = index_cls([], dtype=dtype, name="Foo") + idx_view = idx.view() + assert idx_view.name == "Foo" + + idx_view = idx.view(dtype) + tm.assert_index_equal(idx, index_cls(idx_view, name="Foo"), exact=True) + + msg = "Passing a type in .*Index.view is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + idx_view = idx.view(index_cls) + tm.assert_index_equal(idx, index_cls(idx_view, name="Foo"), exact=True) + + def test_format(self, simple_index): + # GH35439 + if isinstance(simple_index, DatetimeIndex): + pytest.skip("Tested elsewhere") + idx = simple_index + max_width = max(len(str(x)) for x in idx) + expected = [str(x).ljust(max_width) for x in idx] + msg = r"Index\.format is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + assert idx.format() == expected + + def test_insert_non_na(self, simple_index): + # GH#43921 inserting an element that we know we can hold should + # not change dtype or type (except for RangeIndex) + index = simple_index + + result = index.insert(0, index[0]) + + expected = Index([index[0]] + list(index), dtype=index.dtype) + tm.assert_index_equal(result, expected, exact=True) + + def test_insert_na(self, nulls_fixture, simple_index): + # GH 18295 (test missing) + index = simple_index + na_val = nulls_fixture + + if na_val is pd.NaT: + expected = Index([index[0], pd.NaT] + list(index[1:]), dtype=object) + else: + expected = Index([index[0], np.nan] + list(index[1:])) + # GH#43921 we preserve float dtype + if index.dtype.kind == "f": + expected = Index(expected, dtype=index.dtype) + + result = index.insert(1, na_val) + tm.assert_index_equal(result, expected, exact=True) + + def test_arithmetic_explicit_conversions(self, simple_index): + # GH 8608 + # add/sub are overridden explicitly for Float/Int Index + index_cls = type(simple_index) + if index_cls is RangeIndex: + idx = RangeIndex(5) + else: + idx = index_cls(np.arange(5, dtype="int64")) + + # float conversions + arr = np.arange(5, dtype="int64") * 3.2 + expected = Index(arr, dtype=np.float64) + fidx = idx * 3.2 + tm.assert_index_equal(fidx, expected) + fidx = 3.2 * idx + tm.assert_index_equal(fidx, expected) + + # interops with numpy arrays + expected = Index(arr, dtype=np.float64) + a = np.zeros(5, dtype="float64") + result = fidx - a + tm.assert_index_equal(result, expected) + + expected = Index(-arr, dtype=np.float64) + a = np.zeros(5, dtype="float64") + result = a - fidx + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize("complex_dtype", [np.complex64, np.complex128]) + def test_astype_to_complex(self, complex_dtype, simple_index): + result = simple_index.astype(complex_dtype) + + assert type(result) is Index and result.dtype == complex_dtype + + def test_cast_string(self, simple_index): + if isinstance(simple_index, RangeIndex): + pytest.skip("casting of strings not relevant for RangeIndex") + result = type(simple_index)(["0", "1", "2"], dtype=simple_index.dtype) + expected = type(simple_index)([0, 1, 2], dtype=simple_index.dtype) + tm.assert_index_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/test_setops.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/test_setops.py new file mode 100644 index 0000000000000000000000000000000000000000..4a6982cf98670b43cfb87f5f3ea5a4edd5a7fd36 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/test_setops.py @@ -0,0 +1,959 @@ +""" +The tests in this package are to ensure the proper resultant dtypes of +set operations. +""" +from datetime import datetime +import operator + +import numpy as np +import pytest + +from pandas._libs import lib + +from pandas.core.dtypes.cast import find_common_type + +from pandas import ( + CategoricalDtype, + CategoricalIndex, + DatetimeTZDtype, + Index, + MultiIndex, + PeriodDtype, + RangeIndex, + Series, + Timestamp, +) +import pandas._testing as tm +from pandas.api.types import ( + is_signed_integer_dtype, + pandas_dtype, +) + + +def equal_contents(arr1, arr2) -> bool: + """ + Checks if the set of unique elements of arr1 and arr2 are equivalent. + """ + return frozenset(arr1) == frozenset(arr2) + + +@pytest.fixture( + params=tm.ALL_REAL_NUMPY_DTYPES + + [ + "object", + "category", + "datetime64[ns]", + "timedelta64[ns]", + ] +) +def any_dtype_for_small_pos_integer_indexes(request): + """ + Dtypes that can be given to an Index with small positive integers. + + This means that for any dtype `x` in the params list, `Index([1, 2, 3], dtype=x)` is + valid and gives the correct Index (sub-)class. + """ + return request.param + + +def test_union_same_types(index): + # Union with a non-unique, non-monotonic index raises error + # Only needed for bool index factory + idx1 = index.sort_values() + idx2 = index.sort_values() + assert idx1.union(idx2).dtype == idx1.dtype + + +def test_union_different_types(index_flat, index_flat2, request): + # This test only considers combinations of indices + # GH 23525 + idx1 = index_flat + idx2 = index_flat2 + + if ( + not idx1.is_unique + and not idx2.is_unique + and idx1.dtype.kind == "i" + and idx2.dtype.kind == "b" + ) or ( + not idx2.is_unique + and not idx1.is_unique + and idx2.dtype.kind == "i" + and idx1.dtype.kind == "b" + ): + # Each condition had idx[1|2].is_monotonic_decreasing + # but failed when e.g. + # idx1 = Index( + # [True, True, True, True, True, True, True, True, False, False], dtype='bool' + # ) + # idx2 = Index([0, 0, 1, 1, 2, 2], dtype='int64') + mark = pytest.mark.xfail( + reason="GH#44000 True==1", raises=ValueError, strict=False + ) + request.applymarker(mark) + + common_dtype = find_common_type([idx1.dtype, idx2.dtype]) + + warn = None + msg = "'<' not supported between" + if not len(idx1) or not len(idx2): + pass + elif (idx1.dtype.kind == "c" and (not lib.is_np_dtype(idx2.dtype, "iufc"))) or ( + idx2.dtype.kind == "c" and (not lib.is_np_dtype(idx1.dtype, "iufc")) + ): + # complex objects non-sortable + warn = RuntimeWarning + elif ( + isinstance(idx1.dtype, PeriodDtype) and isinstance(idx2.dtype, CategoricalDtype) + ) or ( + isinstance(idx2.dtype, PeriodDtype) and isinstance(idx1.dtype, CategoricalDtype) + ): + warn = FutureWarning + msg = r"PeriodDtype\[B\] is deprecated" + mark = pytest.mark.xfail( + reason="Warning not produced on all builds", + raises=AssertionError, + strict=False, + ) + request.applymarker(mark) + + any_uint64 = np.uint64 in (idx1.dtype, idx2.dtype) + idx1_signed = is_signed_integer_dtype(idx1.dtype) + idx2_signed = is_signed_integer_dtype(idx2.dtype) + + # Union with a non-unique, non-monotonic index raises error + # This applies to the boolean index + idx1 = idx1.sort_values() + idx2 = idx2.sort_values() + + with tm.assert_produces_warning(warn, match=msg): + res1 = idx1.union(idx2) + res2 = idx2.union(idx1) + + if any_uint64 and (idx1_signed or idx2_signed): + assert res1.dtype == np.dtype("O") + assert res2.dtype == np.dtype("O") + else: + assert res1.dtype == common_dtype + assert res2.dtype == common_dtype + + +@pytest.mark.parametrize( + "idx1,idx2", + [ + (Index(np.arange(5), dtype=np.int64), RangeIndex(5)), + (Index(np.arange(5), dtype=np.float64), Index(np.arange(5), dtype=np.int64)), + (Index(np.arange(5), dtype=np.float64), RangeIndex(5)), + (Index(np.arange(5), dtype=np.float64), Index(np.arange(5), dtype=np.uint64)), + ], +) +def test_compatible_inconsistent_pairs(idx1, idx2): + # GH 23525 + res1 = idx1.union(idx2) + res2 = idx2.union(idx1) + + assert res1.dtype in (idx1.dtype, idx2.dtype) + assert res2.dtype in (idx1.dtype, idx2.dtype) + + +@pytest.mark.parametrize( + "left, right, expected", + [ + ("int64", "int64", "int64"), + ("int64", "uint64", "object"), + ("int64", "float64", "float64"), + ("uint64", "float64", "float64"), + ("uint64", "uint64", "uint64"), + ("float64", "float64", "float64"), + ("datetime64[ns]", "int64", "object"), + ("datetime64[ns]", "uint64", "object"), + ("datetime64[ns]", "float64", "object"), + ("datetime64[ns, CET]", "int64", "object"), + ("datetime64[ns, CET]", "uint64", "object"), + ("datetime64[ns, CET]", "float64", "object"), + ("Period[D]", "int64", "object"), + ("Period[D]", "uint64", "object"), + ("Period[D]", "float64", "object"), + ], +) +@pytest.mark.parametrize("names", [("foo", "foo", "foo"), ("foo", "bar", None)]) +def test_union_dtypes(left, right, expected, names): + left = pandas_dtype(left) + right = pandas_dtype(right) + a = Index([], dtype=left, name=names[0]) + b = Index([], dtype=right, name=names[1]) + result = a.union(b) + assert result.dtype == expected + assert result.name == names[2] + + # Testing name retention + # TODO: pin down desired dtype; do we want it to be commutative? + result = a.intersection(b) + assert result.name == names[2] + + +@pytest.mark.parametrize("values", [[1, 2, 2, 3], [3, 3]]) +def test_intersection_duplicates(values): + # GH#31326 + a = Index(values) + b = Index([3, 3]) + result = a.intersection(b) + expected = Index([3]) + tm.assert_index_equal(result, expected) + + +class TestSetOps: + # Set operation tests shared by all indexes in the `index` fixture + @pytest.mark.parametrize("case", [0.5, "xxx"]) + @pytest.mark.parametrize( + "method", ["intersection", "union", "difference", "symmetric_difference"] + ) + def test_set_ops_error_cases(self, case, method, index): + # non-iterable input + msg = "Input must be Index or array-like" + with pytest.raises(TypeError, match=msg): + getattr(index, method)(case) + + @pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning") + def test_intersection_base(self, index): + if isinstance(index, CategoricalIndex): + pytest.skip(f"Not relevant for {type(index).__name__}") + + first = index[:5].unique() + second = index[:3].unique() + intersect = first.intersection(second) + tm.assert_index_equal(intersect, second) + + if isinstance(index.dtype, DatetimeTZDtype): + # The second.values below will drop tz, so the rest of this test + # is not applicable. + return + + # GH#10149 + cases = [second.to_numpy(), second.to_series(), second.to_list()] + for case in cases: + result = first.intersection(case) + assert equal_contents(result, second) + + if isinstance(index, MultiIndex): + msg = "other must be a MultiIndex or a list of tuples" + with pytest.raises(TypeError, match=msg): + first.intersection([1, 2, 3]) + + @pytest.mark.filterwarnings( + "ignore:Falling back on a non-pyarrow:pandas.errors.PerformanceWarning" + ) + @pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning") + def test_union_base(self, index): + index = index.unique() + first = index[3:] + second = index[:5] + everything = index + + union = first.union(second) + tm.assert_index_equal(union.sort_values(), everything.sort_values()) + + if isinstance(index.dtype, DatetimeTZDtype): + # The second.values below will drop tz, so the rest of this test + # is not applicable. + return + + # GH#10149 + cases = [second.to_numpy(), second.to_series(), second.to_list()] + for case in cases: + result = first.union(case) + assert equal_contents(result, everything) + + if isinstance(index, MultiIndex): + msg = "other must be a MultiIndex or a list of tuples" + with pytest.raises(TypeError, match=msg): + first.union([1, 2, 3]) + + @pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning") + @pytest.mark.filterwarnings( + "ignore:Falling back on a non-pyarrow:pandas.errors.PerformanceWarning" + ) + def test_difference_base(self, sort, index): + first = index[2:] + second = index[:4] + if index.inferred_type == "boolean": + # i think (TODO: be sure) there assumptions baked in about + # the index fixture that don't hold here? + answer = set(first).difference(set(second)) + elif isinstance(index, CategoricalIndex): + answer = [] + else: + answer = index[4:] + result = first.difference(second, sort) + assert equal_contents(result, answer) + + # GH#10149 + cases = [second.to_numpy(), second.to_series(), second.to_list()] + for case in cases: + result = first.difference(case, sort) + assert equal_contents(result, answer) + + if isinstance(index, MultiIndex): + msg = "other must be a MultiIndex or a list of tuples" + with pytest.raises(TypeError, match=msg): + first.difference([1, 2, 3], sort) + + @pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning") + @pytest.mark.filterwarnings( + "ignore:Falling back on a non-pyarrow:pandas.errors.PerformanceWarning" + ) + def test_symmetric_difference(self, index): + if isinstance(index, CategoricalIndex): + pytest.skip(f"Not relevant for {type(index).__name__}") + if len(index) < 2: + pytest.skip("Too few values for test") + if index[0] in index[1:] or index[-1] in index[:-1]: + # index fixture has e.g. an index of bools that does not satisfy this, + # another with [0, 0, 1, 1, 2, 2] + pytest.skip("Index values no not satisfy test condition.") + + first = index[1:] + second = index[:-1] + answer = index[[0, -1]] + result = first.symmetric_difference(second) + tm.assert_index_equal(result.sort_values(), answer.sort_values()) + + # GH#10149 + cases = [second.to_numpy(), second.to_series(), second.to_list()] + for case in cases: + result = first.symmetric_difference(case) + assert equal_contents(result, answer) + + if isinstance(index, MultiIndex): + msg = "other must be a MultiIndex or a list of tuples" + with pytest.raises(TypeError, match=msg): + first.symmetric_difference([1, 2, 3]) + + @pytest.mark.parametrize( + "fname, sname, expected_name", + [ + ("A", "A", "A"), + ("A", "B", None), + ("A", None, None), + (None, "B", None), + (None, None, None), + ], + ) + def test_corner_union(self, index_flat, fname, sname, expected_name): + # GH#9943, GH#9862 + # Test unions with various name combinations + # Do not test MultiIndex or repeats + if not index_flat.is_unique: + index = index_flat.unique() + else: + index = index_flat + + # Test copy.union(copy) + first = index.copy().set_names(fname) + second = index.copy().set_names(sname) + union = first.union(second) + expected = index.copy().set_names(expected_name) + tm.assert_index_equal(union, expected) + + # Test copy.union(empty) + first = index.copy().set_names(fname) + second = index.drop(index).set_names(sname) + union = first.union(second) + expected = index.copy().set_names(expected_name) + tm.assert_index_equal(union, expected) + + # Test empty.union(copy) + first = index.drop(index).set_names(fname) + second = index.copy().set_names(sname) + union = first.union(second) + expected = index.copy().set_names(expected_name) + tm.assert_index_equal(union, expected) + + # Test empty.union(empty) + first = index.drop(index).set_names(fname) + second = index.drop(index).set_names(sname) + union = first.union(second) + expected = index.drop(index).set_names(expected_name) + tm.assert_index_equal(union, expected) + + @pytest.mark.parametrize( + "fname, sname, expected_name", + [ + ("A", "A", "A"), + ("A", "B", None), + ("A", None, None), + (None, "B", None), + (None, None, None), + ], + ) + def test_union_unequal(self, index_flat, fname, sname, expected_name): + if not index_flat.is_unique: + index = index_flat.unique() + else: + index = index_flat + + # test copy.union(subset) - need sort for unicode and string + first = index.copy().set_names(fname) + second = index[1:].set_names(sname) + union = first.union(second).sort_values() + expected = index.set_names(expected_name).sort_values() + tm.assert_index_equal(union, expected) + + @pytest.mark.parametrize( + "fname, sname, expected_name", + [ + ("A", "A", "A"), + ("A", "B", None), + ("A", None, None), + (None, "B", None), + (None, None, None), + ], + ) + def test_corner_intersect(self, index_flat, fname, sname, expected_name): + # GH#35847 + # Test intersections with various name combinations + if not index_flat.is_unique: + index = index_flat.unique() + else: + index = index_flat + + # Test copy.intersection(copy) + first = index.copy().set_names(fname) + second = index.copy().set_names(sname) + intersect = first.intersection(second) + expected = index.copy().set_names(expected_name) + tm.assert_index_equal(intersect, expected) + + # Test copy.intersection(empty) + first = index.copy().set_names(fname) + second = index.drop(index).set_names(sname) + intersect = first.intersection(second) + expected = index.drop(index).set_names(expected_name) + tm.assert_index_equal(intersect, expected) + + # Test empty.intersection(copy) + first = index.drop(index).set_names(fname) + second = index.copy().set_names(sname) + intersect = first.intersection(second) + expected = index.drop(index).set_names(expected_name) + tm.assert_index_equal(intersect, expected) + + # Test empty.intersection(empty) + first = index.drop(index).set_names(fname) + second = index.drop(index).set_names(sname) + intersect = first.intersection(second) + expected = index.drop(index).set_names(expected_name) + tm.assert_index_equal(intersect, expected) + + @pytest.mark.parametrize( + "fname, sname, expected_name", + [ + ("A", "A", "A"), + ("A", "B", None), + ("A", None, None), + (None, "B", None), + (None, None, None), + ], + ) + def test_intersect_unequal(self, index_flat, fname, sname, expected_name): + if not index_flat.is_unique: + index = index_flat.unique() + else: + index = index_flat + + # test copy.intersection(subset) - need sort for unicode and string + first = index.copy().set_names(fname) + second = index[1:].set_names(sname) + intersect = first.intersection(second).sort_values() + expected = index[1:].set_names(expected_name).sort_values() + tm.assert_index_equal(intersect, expected) + + @pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning") + def test_intersection_name_retention_with_nameless(self, index): + if isinstance(index, MultiIndex): + index = index.rename(list(range(index.nlevels))) + else: + index = index.rename("foo") + + other = np.asarray(index) + + result = index.intersection(other) + assert result.name == index.name + + # empty other, same dtype + result = index.intersection(other[:0]) + assert result.name == index.name + + # empty `self` + result = index[:0].intersection(other) + assert result.name == index.name + + def test_difference_preserves_type_empty(self, index, sort): + # GH#20040 + # If taking difference of a set and itself, it + # needs to preserve the type of the index + if not index.is_unique: + pytest.skip("Not relevant since index is not unique") + result = index.difference(index, sort=sort) + expected = index[:0] + tm.assert_index_equal(result, expected, exact=True) + + def test_difference_name_retention_equals(self, index, names): + if isinstance(index, MultiIndex): + names = [[x] * index.nlevels for x in names] + index = index.rename(names[0]) + other = index.rename(names[1]) + + assert index.equals(other) + + result = index.difference(other) + expected = index[:0].rename(names[2]) + tm.assert_index_equal(result, expected) + + def test_intersection_difference_match_empty(self, index, sort): + # GH#20040 + # Test that the intersection of an index with an + # empty index produces the same index as the difference + # of an index with itself. Test for all types + if not index.is_unique: + pytest.skip("Not relevant because index is not unique") + inter = index.intersection(index[:0]) + diff = index.difference(index, sort=sort) + tm.assert_index_equal(inter, diff, exact=True) + + +@pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning") +@pytest.mark.filterwarnings( + "ignore:Falling back on a non-pyarrow:pandas.errors.PerformanceWarning" +) +@pytest.mark.parametrize( + "method", ["intersection", "union", "difference", "symmetric_difference"] +) +def test_setop_with_categorical(index_flat, sort, method): + # MultiIndex tested separately in tests.indexes.multi.test_setops + index = index_flat + + other = index.astype("category") + exact = "equiv" if isinstance(index, RangeIndex) else True + + result = getattr(index, method)(other, sort=sort) + expected = getattr(index, method)(index, sort=sort) + tm.assert_index_equal(result, expected, exact=exact) + + result = getattr(index, method)(other[:5], sort=sort) + expected = getattr(index, method)(index[:5], sort=sort) + tm.assert_index_equal(result, expected, exact=exact) + + +def test_intersection_duplicates_all_indexes(index): + # GH#38743 + if index.empty: + # No duplicates in empty indexes + pytest.skip("Not relevant for empty Index") + + idx = index + idx_non_unique = idx[[0, 0, 1, 2]] + + assert idx.intersection(idx_non_unique).equals(idx_non_unique.intersection(idx)) + assert idx.intersection(idx_non_unique).is_unique + + +def test_union_duplicate_index_subsets_of_each_other( + any_dtype_for_small_pos_integer_indexes, +): + # GH#31326 + dtype = any_dtype_for_small_pos_integer_indexes + a = Index([1, 2, 2, 3], dtype=dtype) + b = Index([3, 3, 4], dtype=dtype) + + expected = Index([1, 2, 2, 3, 3, 4], dtype=dtype) + if isinstance(a, CategoricalIndex): + expected = Index([1, 2, 2, 3, 3, 4]) + result = a.union(b) + tm.assert_index_equal(result, expected) + result = a.union(b, sort=False) + tm.assert_index_equal(result, expected) + + +def test_union_with_duplicate_index_and_non_monotonic( + any_dtype_for_small_pos_integer_indexes, +): + # GH#36289 + dtype = any_dtype_for_small_pos_integer_indexes + a = Index([1, 0, 0], dtype=dtype) + b = Index([0, 1], dtype=dtype) + expected = Index([0, 0, 1], dtype=dtype) + + result = a.union(b) + tm.assert_index_equal(result, expected) + + result = b.union(a) + tm.assert_index_equal(result, expected) + + +def test_union_duplicate_index_different_dtypes(): + # GH#36289 + a = Index([1, 2, 2, 3]) + b = Index(["1", "0", "0"]) + expected = Index([1, 2, 2, 3, "1", "0", "0"]) + result = a.union(b, sort=False) + tm.assert_index_equal(result, expected) + + +def test_union_same_value_duplicated_in_both(): + # GH#36289 + a = Index([0, 0, 1]) + b = Index([0, 0, 1, 2]) + result = a.union(b) + expected = Index([0, 0, 1, 2]) + tm.assert_index_equal(result, expected) + + +@pytest.mark.parametrize("dup", [1, np.nan]) +def test_union_nan_in_both(dup): + # GH#36289 + a = Index([np.nan, 1, 2, 2]) + b = Index([np.nan, dup, 1, 2]) + result = a.union(b, sort=False) + expected = Index([np.nan, dup, 1.0, 2.0, 2.0]) + tm.assert_index_equal(result, expected) + + +def test_union_rangeindex_sort_true(): + # GH 53490 + idx1 = RangeIndex(1, 100, 6) + idx2 = RangeIndex(1, 50, 3) + result = idx1.union(idx2, sort=True) + expected = Index( + [ + 1, + 4, + 7, + 10, + 13, + 16, + 19, + 22, + 25, + 28, + 31, + 34, + 37, + 40, + 43, + 46, + 49, + 55, + 61, + 67, + 73, + 79, + 85, + 91, + 97, + ] + ) + tm.assert_index_equal(result, expected) + + +def test_union_with_duplicate_index_not_subset_and_non_monotonic( + any_dtype_for_small_pos_integer_indexes, +): + # GH#36289 + dtype = any_dtype_for_small_pos_integer_indexes + a = Index([1, 0, 2], dtype=dtype) + b = Index([0, 0, 1], dtype=dtype) + expected = Index([0, 0, 1, 2], dtype=dtype) + if isinstance(a, CategoricalIndex): + expected = Index([0, 0, 1, 2]) + + result = a.union(b) + tm.assert_index_equal(result, expected) + + result = b.union(a) + tm.assert_index_equal(result, expected) + + +def test_union_int_categorical_with_nan(): + ci = CategoricalIndex([1, 2, np.nan]) + assert ci.categories.dtype.kind == "i" + + idx = Index([1, 2]) + + result = idx.union(ci) + expected = Index([1, 2, np.nan], dtype=np.float64) + tm.assert_index_equal(result, expected) + + result = ci.union(idx) + tm.assert_index_equal(result, expected) + + +class TestSetOpsUnsorted: + # These may eventually belong in a dtype-specific test_setops, or + # parametrized over a more general fixture + def test_intersect_str_dates(self): + dt_dates = [datetime(2012, 2, 9), datetime(2012, 2, 22)] + + index1 = Index(dt_dates, dtype=object) + index2 = Index(["aa"], dtype=object) + result = index2.intersection(index1) + + expected = Index([], dtype=object) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize("index", ["string"], indirect=True) + def test_intersection(self, index, sort): + first = index[:20] + second = index[:10] + intersect = first.intersection(second, sort=sort) + if sort in (None, False): + tm.assert_index_equal(intersect.sort_values(), second.sort_values()) + else: + tm.assert_index_equal(intersect, second) + + # Corner cases + inter = first.intersection(first, sort=sort) + assert inter is first + + @pytest.mark.parametrize( + "index2,keeps_name", + [ + (Index([3, 4, 5, 6, 7], name="index"), True), # preserve same name + (Index([3, 4, 5, 6, 7], name="other"), False), # drop diff names + (Index([3, 4, 5, 6, 7]), False), + ], + ) + def test_intersection_name_preservation(self, index2, keeps_name, sort): + index1 = Index([1, 2, 3, 4, 5], name="index") + expected = Index([3, 4, 5]) + result = index1.intersection(index2, sort) + + if keeps_name: + expected.name = "index" + + assert result.name == expected.name + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize("index", ["string"], indirect=True) + @pytest.mark.parametrize( + "first_name,second_name,expected_name", + [("A", "A", "A"), ("A", "B", None), (None, "B", None)], + ) + def test_intersection_name_preservation2( + self, index, first_name, second_name, expected_name, sort + ): + first = index[5:20] + second = index[:10] + first.name = first_name + second.name = second_name + intersect = first.intersection(second, sort=sort) + assert intersect.name == expected_name + + def test_chained_union(self, sort): + # Chained unions handles names correctly + i1 = Index([1, 2], name="i1") + i2 = Index([5, 6], name="i2") + i3 = Index([3, 4], name="i3") + union = i1.union(i2.union(i3, sort=sort), sort=sort) + expected = i1.union(i2, sort=sort).union(i3, sort=sort) + tm.assert_index_equal(union, expected) + + j1 = Index([1, 2], name="j1") + j2 = Index([], name="j2") + j3 = Index([], name="j3") + union = j1.union(j2.union(j3, sort=sort), sort=sort) + expected = j1.union(j2, sort=sort).union(j3, sort=sort) + tm.assert_index_equal(union, expected) + + @pytest.mark.parametrize("index", ["string"], indirect=True) + def test_union(self, index, sort): + first = index[5:20] + second = index[:10] + everything = index[:20] + + union = first.union(second, sort=sort) + if sort in (None, False): + tm.assert_index_equal(union.sort_values(), everything.sort_values()) + else: + tm.assert_index_equal(union, everything) + + @pytest.mark.parametrize("klass", [np.array, Series, list]) + @pytest.mark.parametrize("index", ["string"], indirect=True) + def test_union_from_iterables(self, index, klass, sort): + # GH#10149 + first = index[5:20] + second = index[:10] + everything = index[:20] + + case = klass(second.values) + result = first.union(case, sort=sort) + if sort in (None, False): + tm.assert_index_equal(result.sort_values(), everything.sort_values()) + else: + tm.assert_index_equal(result, everything) + + @pytest.mark.parametrize("index", ["string"], indirect=True) + def test_union_identity(self, index, sort): + first = index[5:20] + + union = first.union(first, sort=sort) + # i.e. identity is not preserved when sort is True + assert (union is first) is (not sort) + + # This should no longer be the same object, since [] is not consistent, + # both objects will be recast to dtype('O') + union = first.union(Index([], dtype=first.dtype), sort=sort) + assert (union is first) is (not sort) + + union = Index([], dtype=first.dtype).union(first, sort=sort) + assert (union is first) is (not sort) + + @pytest.mark.parametrize("index", ["string"], indirect=True) + @pytest.mark.parametrize("second_name,expected", [(None, None), ("name", "name")]) + def test_difference_name_preservation(self, index, second_name, expected, sort): + first = index[5:20] + second = index[:10] + answer = index[10:20] + + first.name = "name" + second.name = second_name + result = first.difference(second, sort=sort) + + if sort is True: + tm.assert_index_equal(result, answer) + else: + answer.name = second_name + tm.assert_index_equal(result.sort_values(), answer.sort_values()) + + if expected is None: + assert result.name is None + else: + assert result.name == expected + + def test_difference_empty_arg(self, index, sort): + first = index.copy() + first = first[5:20] + first.name = "name" + result = first.difference([], sort) + expected = index[5:20].unique() + expected.name = "name" + tm.assert_index_equal(result, expected) + + def test_difference_should_not_compare(self): + # GH 55113 + left = Index([1, 1]) + right = Index([True]) + result = left.difference(right) + expected = Index([1]) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize("index", ["string"], indirect=True) + def test_difference_identity(self, index, sort): + first = index[5:20] + first.name = "name" + result = first.difference(first, sort) + + assert len(result) == 0 + assert result.name == first.name + + @pytest.mark.parametrize("index", ["string"], indirect=True) + def test_difference_sort(self, index, sort): + first = index[5:20] + second = index[:10] + + result = first.difference(second, sort) + expected = index[10:20] + + if sort is None: + expected = expected.sort_values() + + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize("opname", ["difference", "symmetric_difference"]) + def test_difference_incomparable(self, opname): + a = Index([3, Timestamp("2000"), 1]) + b = Index([2, Timestamp("1999"), 1]) + op = operator.methodcaller(opname, b) + + with tm.assert_produces_warning(RuntimeWarning): + # sort=None, the default + result = op(a) + expected = Index([3, Timestamp("2000"), 2, Timestamp("1999")]) + if opname == "difference": + expected = expected[:2] + tm.assert_index_equal(result, expected) + + # sort=False + op = operator.methodcaller(opname, b, sort=False) + result = op(a) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize("opname", ["difference", "symmetric_difference"]) + def test_difference_incomparable_true(self, opname): + a = Index([3, Timestamp("2000"), 1]) + b = Index([2, Timestamp("1999"), 1]) + op = operator.methodcaller(opname, b, sort=True) + + msg = "'<' not supported between instances of 'Timestamp' and 'int'" + with pytest.raises(TypeError, match=msg): + op(a) + + def test_symmetric_difference_mi(self, sort): + index1 = MultiIndex.from_tuples(zip(["foo", "bar", "baz"], [1, 2, 3])) + index2 = MultiIndex.from_tuples([("foo", 1), ("bar", 3)]) + result = index1.symmetric_difference(index2, sort=sort) + expected = MultiIndex.from_tuples([("bar", 2), ("baz", 3), ("bar", 3)]) + if sort is None: + expected = expected.sort_values() + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "index2,expected", + [ + (Index([0, 1, np.nan]), Index([2.0, 3.0, 0.0])), + (Index([0, 1]), Index([np.nan, 2.0, 3.0, 0.0])), + ], + ) + def test_symmetric_difference_missing(self, index2, expected, sort): + # GH#13514 change: {nan} - {nan} == {} + # (GH#6444, sorting of nans, is no longer an issue) + index1 = Index([1, np.nan, 2, 3]) + + result = index1.symmetric_difference(index2, sort=sort) + if sort is None: + expected = expected.sort_values() + tm.assert_index_equal(result, expected) + + def test_symmetric_difference_non_index(self, sort): + index1 = Index([1, 2, 3, 4], name="index1") + index2 = np.array([2, 3, 4, 5]) + expected = Index([1, 5], name="index1") + result = index1.symmetric_difference(index2, sort=sort) + if sort in (None, True): + tm.assert_index_equal(result, expected) + else: + tm.assert_index_equal(result.sort_values(), expected) + assert result.name == "index1" + + result = index1.symmetric_difference(index2, result_name="new_name", sort=sort) + expected.name = "new_name" + if sort in (None, True): + tm.assert_index_equal(result, expected) + else: + tm.assert_index_equal(result.sort_values(), expected) + assert result.name == "new_name" + + def test_union_ea_dtypes(self, any_numeric_ea_and_arrow_dtype): + # GH#51365 + idx = Index([1, 2, 3], dtype=any_numeric_ea_and_arrow_dtype) + idx2 = Index([3, 4, 5], dtype=any_numeric_ea_and_arrow_dtype) + result = idx.union(idx2) + expected = Index([1, 2, 3, 4, 5], dtype=any_numeric_ea_and_arrow_dtype) + tm.assert_index_equal(result, expected) + + def test_union_string_array(self, any_string_dtype): + idx1 = Index(["a"], dtype=any_string_dtype) + idx2 = Index(["b"], dtype=any_string_dtype) + result = idx1.union(idx2) + expected = Index(["a", "b"], dtype=any_string_dtype) + tm.assert_index_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/test_subclass.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/test_subclass.py new file mode 100644 index 0000000000000000000000000000000000000000..c3287e1ddcddcedc14857f2299798d3957830921 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/test_subclass.py @@ -0,0 +1,40 @@ +""" +Tests involving custom Index subclasses +""" +import numpy as np + +from pandas import ( + DataFrame, + Index, +) +import pandas._testing as tm + + +class CustomIndex(Index): + def __new__(cls, data, name=None): + # assert that this index class cannot hold strings + if any(isinstance(val, str) for val in data): + raise TypeError("CustomIndex cannot hold strings") + + if name is None and hasattr(data, "name"): + name = data.name + data = np.array(data, dtype="O") + + return cls._simple_new(data, name) + + +def test_insert_fallback_to_base_index(): + # https://github.com/pandas-dev/pandas/issues/47071 + + idx = CustomIndex([1, 2, 3]) + result = idx.insert(0, "string") + expected = Index(["string", 1, 2, 3], dtype=object) + tm.assert_index_equal(result, expected) + + df = DataFrame( + np.random.default_rng(2).standard_normal((2, 3)), + columns=idx, + index=Index([1, 2], name="string"), + ) + result = df.reset_index() + tm.assert_index_equal(result.columns, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/timedeltas/test_arithmetic.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/timedeltas/test_arithmetic.py new file mode 100644 index 0000000000000000000000000000000000000000..a431e10dc18ab15da0fd07f798d54b6dead26073 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/timedeltas/test_arithmetic.py @@ -0,0 +1,51 @@ +# Arithmetic tests for TimedeltaIndex are generally about the result's `freq` attribute. +# Other cases can be shared in tests.arithmetic.test_timedelta64 +import numpy as np + +from pandas import ( + NaT, + Timedelta, + timedelta_range, +) +import pandas._testing as tm + + +class TestTimedeltaIndexArithmetic: + def test_arithmetic_zero_freq(self): + # GH#51575 don't get a .freq with freq.n = 0 + tdi = timedelta_range(0, periods=100, freq="ns") + result = tdi / 2 + assert result.freq is None + expected = tdi[:50].repeat(2) + tm.assert_index_equal(result, expected) + + result2 = tdi // 2 + assert result2.freq is None + expected2 = expected + tm.assert_index_equal(result2, expected2) + + result3 = tdi * 0 + assert result3.freq is None + expected3 = tdi[:1].repeat(100) + tm.assert_index_equal(result3, expected3) + + def test_tdi_division(self, index_or_series): + # doc example + + scalar = Timedelta(days=31) + td = index_or_series( + [scalar, scalar, scalar + Timedelta(minutes=5, seconds=3), NaT], + dtype="m8[ns]", + ) + + result = td / np.timedelta64(1, "D") + expected = index_or_series( + [31, 31, (31 * 86400 + 5 * 60 + 3) / 86400.0, np.nan] + ) + tm.assert_equal(result, expected) + + result = td / np.timedelta64(1, "s") + expected = index_or_series( + [31 * 86400, 31 * 86400, 31 * 86400 + 5 * 60 + 3, np.nan] + ) + tm.assert_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/timedeltas/test_formats.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/timedeltas/test_formats.py new file mode 100644 index 0000000000000000000000000000000000000000..607336060cbbc2093e224e31614e26a2c03bd72f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/timedeltas/test_formats.py @@ -0,0 +1,106 @@ +import pytest + +import pandas as pd +from pandas import ( + Series, + TimedeltaIndex, +) + + +class TestTimedeltaIndexRendering: + def test_repr_round_days_non_nano(self): + # GH#55405 + # we should get "1 days", not "1 days 00:00:00" with non-nano + tdi = TimedeltaIndex(["1 days"], freq="D").as_unit("s") + result = repr(tdi) + expected = "TimedeltaIndex(['1 days'], dtype='timedelta64[s]', freq='D')" + assert result == expected + + result2 = repr(Series(tdi)) + expected2 = "0 1 days\ndtype: timedelta64[s]" + assert result2 == expected2 + + @pytest.mark.parametrize("method", ["__repr__", "__str__"]) + def test_representation(self, method): + idx1 = TimedeltaIndex([], freq="D") + idx2 = TimedeltaIndex(["1 days"], freq="D") + idx3 = TimedeltaIndex(["1 days", "2 days"], freq="D") + idx4 = TimedeltaIndex(["1 days", "2 days", "3 days"], freq="D") + idx5 = TimedeltaIndex(["1 days 00:00:01", "2 days", "3 days"]) + + exp1 = "TimedeltaIndex([], dtype='timedelta64[ns]', freq='D')" + + exp2 = "TimedeltaIndex(['1 days'], dtype='timedelta64[ns]', freq='D')" + + exp3 = "TimedeltaIndex(['1 days', '2 days'], dtype='timedelta64[ns]', freq='D')" + + exp4 = ( + "TimedeltaIndex(['1 days', '2 days', '3 days'], " + "dtype='timedelta64[ns]', freq='D')" + ) + + exp5 = ( + "TimedeltaIndex(['1 days 00:00:01', '2 days 00:00:00', " + "'3 days 00:00:00'], dtype='timedelta64[ns]', freq=None)" + ) + + with pd.option_context("display.width", 300): + for idx, expected in zip( + [idx1, idx2, idx3, idx4, idx5], [exp1, exp2, exp3, exp4, exp5] + ): + result = getattr(idx, method)() + assert result == expected + + # TODO: this is a Series.__repr__ test + def test_representation_to_series(self): + idx1 = TimedeltaIndex([], freq="D") + idx2 = TimedeltaIndex(["1 days"], freq="D") + idx3 = TimedeltaIndex(["1 days", "2 days"], freq="D") + idx4 = TimedeltaIndex(["1 days", "2 days", "3 days"], freq="D") + idx5 = TimedeltaIndex(["1 days 00:00:01", "2 days", "3 days"]) + + exp1 = """Series([], dtype: timedelta64[ns])""" + + exp2 = "0 1 days\ndtype: timedelta64[ns]" + + exp3 = "0 1 days\n1 2 days\ndtype: timedelta64[ns]" + + exp4 = "0 1 days\n1 2 days\n2 3 days\ndtype: timedelta64[ns]" + + exp5 = ( + "0 1 days 00:00:01\n" + "1 2 days 00:00:00\n" + "2 3 days 00:00:00\n" + "dtype: timedelta64[ns]" + ) + + with pd.option_context("display.width", 300): + for idx, expected in zip( + [idx1, idx2, idx3, idx4, idx5], [exp1, exp2, exp3, exp4, exp5] + ): + result = repr(Series(idx)) + assert result == expected + + def test_summary(self): + # GH#9116 + idx1 = TimedeltaIndex([], freq="D") + idx2 = TimedeltaIndex(["1 days"], freq="D") + idx3 = TimedeltaIndex(["1 days", "2 days"], freq="D") + idx4 = TimedeltaIndex(["1 days", "2 days", "3 days"], freq="D") + idx5 = TimedeltaIndex(["1 days 00:00:01", "2 days", "3 days"]) + + exp1 = "TimedeltaIndex: 0 entries\nFreq: D" + + exp2 = "TimedeltaIndex: 1 entries, 1 days to 1 days\nFreq: D" + + exp3 = "TimedeltaIndex: 2 entries, 1 days to 2 days\nFreq: D" + + exp4 = "TimedeltaIndex: 3 entries, 1 days to 3 days\nFreq: D" + + exp5 = "TimedeltaIndex: 3 entries, 1 days 00:00:01 to 3 days 00:00:00" + + for idx, expected in zip( + [idx1, idx2, idx3, idx4, idx5], [exp1, exp2, exp3, exp4, exp5] + ): + result = idx._summary() + assert result == expected diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/timedeltas/test_ops.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/timedeltas/test_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..f6013baf86edcd566a17cd3127467a7443ac475a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/timedeltas/test_ops.py @@ -0,0 +1,14 @@ +from pandas import ( + TimedeltaIndex, + timedelta_range, +) +import pandas._testing as tm + + +class TestTimedeltaIndexOps: + def test_infer_freq(self, freq_sample): + # GH#11018 + idx = timedelta_range("1", freq=freq_sample, periods=10) + result = TimedeltaIndex(idx.asi8, freq="infer") + tm.assert_index_equal(idx, result) + assert result.freq == freq_sample diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/timedeltas/test_timedelta.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/timedeltas/test_timedelta.py new file mode 100644 index 0000000000000000000000000000000000000000..3120066741ffa292dc1533056438ebf481cb1849 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexes/timedeltas/test_timedelta.py @@ -0,0 +1,61 @@ +import numpy as np +import pytest + +from pandas import ( + Index, + Series, + Timedelta, + timedelta_range, +) +import pandas._testing as tm + + +class TestTimedeltaIndex: + def test_misc_coverage(self): + rng = timedelta_range("1 day", periods=5) + result = rng.groupby(rng.days) + assert isinstance(next(iter(result.values()))[0], Timedelta) + + def test_map(self): + # test_map_dictlike generally tests + + rng = timedelta_range("1 day", periods=10) + + f = lambda x: x.days + result = rng.map(f) + exp = Index([f(x) for x in rng], dtype=np.int64) + tm.assert_index_equal(result, exp) + + def test_fields(self): + rng = timedelta_range("1 days, 10:11:12.100123456", periods=2, freq="s") + tm.assert_index_equal(rng.days, Index([1, 1], dtype=np.int64)) + tm.assert_index_equal( + rng.seconds, + Index([10 * 3600 + 11 * 60 + 12, 10 * 3600 + 11 * 60 + 13], dtype=np.int32), + ) + tm.assert_index_equal( + rng.microseconds, + Index([100 * 1000 + 123, 100 * 1000 + 123], dtype=np.int32), + ) + tm.assert_index_equal(rng.nanoseconds, Index([456, 456], dtype=np.int32)) + + msg = "'TimedeltaIndex' object has no attribute '{}'" + with pytest.raises(AttributeError, match=msg.format("hours")): + rng.hours + with pytest.raises(AttributeError, match=msg.format("minutes")): + rng.minutes + with pytest.raises(AttributeError, match=msg.format("milliseconds")): + rng.milliseconds + + # with nat + s = Series(rng) + s[1] = np.nan + + tm.assert_series_equal(s.dt.days, Series([1, np.nan], index=[0, 1])) + tm.assert_series_equal( + s.dt.seconds, Series([10 * 3600 + 11 * 60 + 12, np.nan], index=[0, 1]) + ) + + # preserve name (GH15589) + rng.name = "name" + assert rng.days.name == "name" diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fcee66edafb36b10bed4ba86c0bd96ffc2062f38 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/common.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/common.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f69869f032baa561154861b53838e40b55005ad0 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/common.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/conftest.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/conftest.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84a6a568433360590e6363d592d6b43e0ea6a76f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/conftest.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/test_at.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/test_at.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1cd4deb2ea2d8de9ddbeb4cddca0e8dd92fa1a48 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/test_at.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/test_categorical.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/test_categorical.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c91d09fa04afc6fa545e364351780450c50b0675 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/test_categorical.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/test_chaining_and_caching.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/test_chaining_and_caching.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9cc5bc30be12a2e2494d26839e744bae3c07e48 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/test_chaining_and_caching.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/test_check_indexer.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/test_check_indexer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad1317a725ba48d12c4079fc0feb6d0ef2de4c47 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/test_check_indexer.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/test_coercion.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/test_coercion.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86569104863e4cad6f4051ccc95dd31168ec75ea Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/test_coercion.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/test_datetime.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/test_datetime.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21c3919808656aefef73b1ec81e1353cb29f091e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/test_datetime.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/test_floats.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/test_floats.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38ddae3b488ea9ff4366f5e8aa78dcaa209c092b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/test_floats.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/test_iat.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/test_iat.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b4799da3ac1ca4fd0c19d70d843d764955e8dbf Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/test_iat.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/test_iloc.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/test_iloc.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..832eacb49173a534a0a02b10f9cad063afa14855 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/test_iloc.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/test_indexers.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/test_indexers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..23455ffec38f9102842637934831c02efadf91b8 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/test_indexers.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/test_indexing.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/test_indexing.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36dea117baf2ad14730088aaff1ffbc888b0eb0d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/test_indexing.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/test_na_indexing.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/test_na_indexing.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81c19456acbd00f5552cd95b7b21703af922e5c0 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/test_na_indexing.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/test_partial.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/test_partial.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b217fe1d31c4b7815aaaafc9807b5e84d4f1352e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/test_partial.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/test_scalar.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/test_scalar.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a763f759ae5b336cf17cf7ffa2ee419805f54da Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/__pycache__/test_scalar.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/common.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/common.py new file mode 100644 index 0000000000000000000000000000000000000000..2af76f69a4300ac744a5e6f1f7dab185e19767ca --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/common.py @@ -0,0 +1,40 @@ +""" common utilities """ +from __future__ import annotations + +from typing import ( + Any, + Literal, +) + + +def _mklbl(prefix: str, n: int): + return [f"{prefix}{i}" for i in range(n)] + + +def check_indexing_smoketest_or_raises( + obj, + method: Literal["iloc", "loc"], + key: Any, + axes: Literal[0, 1] | None = None, + fails=None, +) -> None: + if axes is None: + axes_list = [0, 1] + else: + assert axes in [0, 1] + axes_list = [axes] + + for ax in axes_list: + if ax < obj.ndim: + # create a tuple accessor + new_axes = [slice(None)] * obj.ndim + new_axes[ax] = key + axified = tuple(new_axes) + try: + getattr(obj, method).__getitem__(axified) + except (IndexError, TypeError, KeyError) as detail: + # if we are in fails, the ok, otherwise raise it + if fails is not None: + if isinstance(detail, fails): + return + raise diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/conftest.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..4184c6a0047ccf0dccb8a72f028b27879130aea5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/conftest.py @@ -0,0 +1,127 @@ +import numpy as np +import pytest + +from pandas import ( + DataFrame, + Index, + MultiIndex, + Series, + date_range, +) + + +@pytest.fixture +def series_ints(): + return Series(np.random.default_rng(2).random(4), index=np.arange(0, 8, 2)) + + +@pytest.fixture +def frame_ints(): + return DataFrame( + np.random.default_rng(2).standard_normal((4, 4)), + index=np.arange(0, 8, 2), + columns=np.arange(0, 12, 3), + ) + + +@pytest.fixture +def series_uints(): + return Series( + np.random.default_rng(2).random(4), + index=Index(np.arange(0, 8, 2, dtype=np.uint64)), + ) + + +@pytest.fixture +def frame_uints(): + return DataFrame( + np.random.default_rng(2).standard_normal((4, 4)), + index=Index(range(0, 8, 2), dtype=np.uint64), + columns=Index(range(0, 12, 3), dtype=np.uint64), + ) + + +@pytest.fixture +def series_labels(): + return Series(np.random.default_rng(2).standard_normal(4), index=list("abcd")) + + +@pytest.fixture +def frame_labels(): + return DataFrame( + np.random.default_rng(2).standard_normal((4, 4)), + index=list("abcd"), + columns=list("ABCD"), + ) + + +@pytest.fixture +def series_ts(): + return Series( + np.random.default_rng(2).standard_normal(4), + index=date_range("20130101", periods=4), + ) + + +@pytest.fixture +def frame_ts(): + return DataFrame( + np.random.default_rng(2).standard_normal((4, 4)), + index=date_range("20130101", periods=4), + ) + + +@pytest.fixture +def series_floats(): + return Series( + np.random.default_rng(2).random(4), + index=Index(range(0, 8, 2), dtype=np.float64), + ) + + +@pytest.fixture +def frame_floats(): + return DataFrame( + np.random.default_rng(2).standard_normal((4, 4)), + index=Index(range(0, 8, 2), dtype=np.float64), + columns=Index(range(0, 12, 3), dtype=np.float64), + ) + + +@pytest.fixture +def series_mixed(): + return Series(np.random.default_rng(2).standard_normal(4), index=[2, 4, "null", 8]) + + +@pytest.fixture +def frame_mixed(): + return DataFrame( + np.random.default_rng(2).standard_normal((4, 4)), index=[2, 4, "null", 8] + ) + + +@pytest.fixture +def frame_empty(): + return DataFrame() + + +@pytest.fixture +def series_empty(): + return Series(dtype=object) + + +@pytest.fixture +def frame_multi(): + return DataFrame( + np.random.default_rng(2).standard_normal((4, 4)), + index=MultiIndex.from_product([[1, 2], [3, 4]]), + columns=MultiIndex.from_product([[5, 6], [7, 8]]), + ) + + +@pytest.fixture +def series_multi(): + return Series( + np.random.default_rng(2).random(4), + index=MultiIndex.from_product([[1, 2], [3, 4]]), + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/interval/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/interval/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dab0519bcf8ad7c5de310e73c15dce6ac94566bf Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/interval/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/interval/__pycache__/test_interval.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/interval/__pycache__/test_interval.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e526aca7364660ce453b5f32b9a65f3d104fac1b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/interval/__pycache__/test_interval.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/interval/__pycache__/test_interval_new.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/interval/__pycache__/test_interval_new.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06df9f5c3348badbdc219333f859eae2c9da44f2 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/interval/__pycache__/test_interval_new.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/multiindex/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/multiindex/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b932a049143a787947836810b6fdd71a4700f9f6 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/multiindex/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/multiindex/__pycache__/test_chaining_and_caching.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/multiindex/__pycache__/test_chaining_and_caching.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ef656fe3971a561e0c81d9e0042bea45e2f4424 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/multiindex/__pycache__/test_chaining_and_caching.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/multiindex/__pycache__/test_datetime.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/multiindex/__pycache__/test_datetime.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..479f6cbe0058827d9c2886db32f362a9616aac0e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/multiindex/__pycache__/test_datetime.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/multiindex/__pycache__/test_getitem.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/multiindex/__pycache__/test_getitem.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee1d4d233f99d684dc22972c542e79646709ad73 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/multiindex/__pycache__/test_getitem.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/multiindex/__pycache__/test_iloc.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/multiindex/__pycache__/test_iloc.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77ccb7a30376314fc000aaffda6b8f626db845db Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/multiindex/__pycache__/test_iloc.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/multiindex/__pycache__/test_indexing_slow.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/multiindex/__pycache__/test_indexing_slow.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f818729719db2a726d309d0d918814cc8f72cd6 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/multiindex/__pycache__/test_indexing_slow.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/multiindex/__pycache__/test_loc.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/multiindex/__pycache__/test_loc.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5b9808b11f266933b1a42e1a7389984d0f5bf01 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/multiindex/__pycache__/test_loc.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/multiindex/__pycache__/test_multiindex.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/multiindex/__pycache__/test_multiindex.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6106cdef543b935d551b296eeaa9b04f9a911e6c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/multiindex/__pycache__/test_multiindex.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/multiindex/__pycache__/test_partial.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/multiindex/__pycache__/test_partial.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec6778a59a9f6a6533244c796e052796062559ec Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/multiindex/__pycache__/test_partial.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/multiindex/__pycache__/test_setitem.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/multiindex/__pycache__/test_setitem.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07354de697790ef06a2f16a9888b527bed24bf0f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/multiindex/__pycache__/test_setitem.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/multiindex/__pycache__/test_slice.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/multiindex/__pycache__/test_slice.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab5c0ccd5143bd75a85849f317d7b36430459f87 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/multiindex/__pycache__/test_slice.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/multiindex/__pycache__/test_sorted.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/multiindex/__pycache__/test_sorted.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9776afa102eeeab6747c17ad129464d599a1403 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/multiindex/__pycache__/test_sorted.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/multiindex/test_iloc.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/multiindex/test_iloc.py new file mode 100644 index 0000000000000000000000000000000000000000..8939ecc78000be08812afb702358e7eee1ae9499 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/multiindex/test_iloc.py @@ -0,0 +1,171 @@ +import numpy as np +import pytest + +from pandas import ( + DataFrame, + MultiIndex, + Series, +) +import pandas._testing as tm + + +@pytest.fixture +def simple_multiindex_dataframe(): + """ + Factory function to create simple 3 x 3 dataframe with + both columns and row MultiIndex using supplied data or + random data by default. + """ + + data = np.random.default_rng(2).standard_normal((3, 3)) + return DataFrame( + data, columns=[[2, 2, 4], [6, 8, 10]], index=[[4, 4, 8], [8, 10, 12]] + ) + + +@pytest.mark.parametrize( + "indexer, expected", + [ + ( + lambda df: df.iloc[0], + lambda arr: Series(arr[0], index=[[2, 2, 4], [6, 8, 10]], name=(4, 8)), + ), + ( + lambda df: df.iloc[2], + lambda arr: Series(arr[2], index=[[2, 2, 4], [6, 8, 10]], name=(8, 12)), + ), + ( + lambda df: df.iloc[:, 2], + lambda arr: Series(arr[:, 2], index=[[4, 4, 8], [8, 10, 12]], name=(4, 10)), + ), + ], +) +def test_iloc_returns_series(indexer, expected, simple_multiindex_dataframe): + df = simple_multiindex_dataframe + arr = df.values + result = indexer(df) + expected = expected(arr) + tm.assert_series_equal(result, expected) + + +def test_iloc_returns_dataframe(simple_multiindex_dataframe): + df = simple_multiindex_dataframe + result = df.iloc[[0, 1]] + expected = df.xs(4, drop_level=False) + tm.assert_frame_equal(result, expected) + + +def test_iloc_returns_scalar(simple_multiindex_dataframe): + df = simple_multiindex_dataframe + arr = df.values + result = df.iloc[2, 2] + expected = arr[2, 2] + assert result == expected + + +def test_iloc_getitem_multiple_items(): + # GH 5528 + tup = zip(*[["a", "a", "b", "b"], ["x", "y", "x", "y"]]) + index = MultiIndex.from_tuples(tup) + df = DataFrame(np.random.default_rng(2).standard_normal((4, 4)), index=index) + result = df.iloc[[2, 3]] + expected = df.xs("b", drop_level=False) + tm.assert_frame_equal(result, expected) + + +def test_iloc_getitem_labels(): + # this is basically regular indexing + arr = np.random.default_rng(2).standard_normal((4, 3)) + df = DataFrame( + arr, + columns=[["i", "i", "j"], ["A", "A", "B"]], + index=[["i", "i", "j", "k"], ["X", "X", "Y", "Y"]], + ) + result = df.iloc[2, 2] + expected = arr[2, 2] + assert result == expected + + +def test_frame_getitem_slice(multiindex_dataframe_random_data): + df = multiindex_dataframe_random_data + result = df.iloc[:4] + expected = df[:4] + tm.assert_frame_equal(result, expected) + + +def test_frame_setitem_slice(multiindex_dataframe_random_data): + df = multiindex_dataframe_random_data + df.iloc[:4] = 0 + + assert (df.values[:4] == 0).all() + assert (df.values[4:] != 0).all() + + +def test_indexing_ambiguity_bug_1678(): + # GH 1678 + columns = MultiIndex.from_tuples( + [("Ohio", "Green"), ("Ohio", "Red"), ("Colorado", "Green")] + ) + index = MultiIndex.from_tuples([("a", 1), ("a", 2), ("b", 1), ("b", 2)]) + + df = DataFrame(np.arange(12).reshape((4, 3)), index=index, columns=columns) + + result = df.iloc[:, 1] + expected = df.loc[:, ("Ohio", "Red")] + tm.assert_series_equal(result, expected) + + +def test_iloc_integer_locations(): + # GH 13797 + data = [ + ["str00", "str01"], + ["str10", "str11"], + ["str20", "srt21"], + ["str30", "str31"], + ["str40", "str41"], + ] + + index = MultiIndex.from_tuples( + [("CC", "A"), ("CC", "B"), ("CC", "B"), ("BB", "a"), ("BB", "b")] + ) + + expected = DataFrame(data) + df = DataFrame(data, index=index) + + result = DataFrame([[df.iloc[r, c] for c in range(2)] for r in range(5)]) + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "data, indexes, values, expected_k", + [ + # test without indexer value in first level of MultiIndex + ([[2, 22, 5], [2, 33, 6]], [0, -1, 1], [2, 3, 1], [7, 10]), + # test like code sample 1 in the issue + ([[1, 22, 555], [1, 33, 666]], [0, -1, 1], [200, 300, 100], [755, 1066]), + # test like code sample 2 in the issue + ([[1, 3, 7], [2, 4, 8]], [0, -1, 1], [10, 10, 1000], [17, 1018]), + # test like code sample 3 in the issue + ([[1, 11, 4], [2, 22, 5], [3, 33, 6]], [0, -1, 1], [4, 7, 10], [8, 15, 13]), + ], +) +def test_iloc_setitem_int_multiindex_series(data, indexes, values, expected_k): + # GH17148 + df = DataFrame(data=data, columns=["i", "j", "k"]) + df = df.set_index(["i", "j"]) + + series = df.k.copy() + for i, v in zip(indexes, values): + series.iloc[i] += v + + df["k"] = expected_k + expected = df.k + tm.assert_series_equal(series, expected) + + +def test_getitem_iloc(multiindex_dataframe_random_data): + df = multiindex_dataframe_random_data + result = df.iloc[2] + expected = df.xs(df.index[2]) + tm.assert_series_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/test_categorical.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/test_categorical.py new file mode 100644 index 0000000000000000000000000000000000000000..1b58f8e8b983113e4a627e75cf6db7917c33866a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/test_categorical.py @@ -0,0 +1,573 @@ +import re + +import numpy as np +import pytest + +import pandas.util._test_decorators as td + +import pandas as pd +from pandas import ( + Categorical, + CategoricalDtype, + CategoricalIndex, + DataFrame, + Index, + Interval, + Series, + Timedelta, + Timestamp, + option_context, +) +import pandas._testing as tm + + +@pytest.fixture +def df(): + return DataFrame( + { + "A": np.arange(6, dtype="int64"), + }, + index=CategoricalIndex( + list("aabbca"), dtype=CategoricalDtype(list("cab")), name="B" + ), + ) + + +@pytest.fixture +def df2(): + return DataFrame( + { + "A": np.arange(6, dtype="int64"), + }, + index=CategoricalIndex( + list("aabbca"), dtype=CategoricalDtype(list("cabe")), name="B" + ), + ) + + +class TestCategoricalIndex: + def test_loc_scalar(self, df): + dtype = CategoricalDtype(list("cab")) + result = df.loc["a"] + bidx = Series(list("aaa"), name="B").astype(dtype) + assert bidx.dtype == dtype + + expected = DataFrame({"A": [0, 1, 5]}, index=Index(bidx)) + tm.assert_frame_equal(result, expected) + + df = df.copy() + df.loc["a"] = 20 + bidx2 = Series(list("aabbca"), name="B").astype(dtype) + assert bidx2.dtype == dtype + expected = DataFrame( + { + "A": [20, 20, 2, 3, 4, 20], + }, + index=Index(bidx2), + ) + tm.assert_frame_equal(df, expected) + + # value not in the categories + with pytest.raises(KeyError, match=r"^'d'$"): + df.loc["d"] + + df2 = df.copy() + expected = df2.copy() + expected.index = expected.index.astype(object) + expected.loc["d"] = 10 + df2.loc["d"] = 10 + tm.assert_frame_equal(df2, expected) + + def test_loc_setitem_with_expansion_non_category(self, df): + # Setting-with-expansion with a new key "d" that is not among caegories + df.loc["a"] = 20 + + # Setting a new row on an existing column + df3 = df.copy() + df3.loc["d", "A"] = 10 + bidx3 = Index(list("aabbcad"), name="B") + expected3 = DataFrame( + { + "A": [20, 20, 2, 3, 4, 20, 10.0], + }, + index=Index(bidx3), + ) + tm.assert_frame_equal(df3, expected3) + + # Setting a new row _and_ new column + df4 = df.copy() + df4.loc["d", "C"] = 10 + expected3 = DataFrame( + { + "A": [20, 20, 2, 3, 4, 20, np.nan], + "C": [np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, 10], + }, + index=Index(bidx3), + ) + tm.assert_frame_equal(df4, expected3) + + def test_loc_getitem_scalar_non_category(self, df): + with pytest.raises(KeyError, match="^1$"): + df.loc[1] + + def test_slicing(self): + cat = Series(Categorical([1, 2, 3, 4])) + reverse = cat[::-1] + exp = np.array([4, 3, 2, 1], dtype=np.int64) + tm.assert_numpy_array_equal(reverse.__array__(), exp) + + df = DataFrame({"value": (np.arange(100) + 1).astype("int64")}) + df["D"] = pd.cut(df.value, bins=[0, 25, 50, 75, 100]) + + expected = Series([11, Interval(0, 25)], index=["value", "D"], name=10) + result = df.iloc[10] + tm.assert_series_equal(result, expected) + + expected = DataFrame( + {"value": np.arange(11, 21).astype("int64")}, + index=np.arange(10, 20).astype("int64"), + ) + expected["D"] = pd.cut(expected.value, bins=[0, 25, 50, 75, 100]) + result = df.iloc[10:20] + tm.assert_frame_equal(result, expected) + + expected = Series([9, Interval(0, 25)], index=["value", "D"], name=8) + result = df.loc[8] + tm.assert_series_equal(result, expected) + + def test_slicing_and_getting_ops(self): + # systematically test the slicing operations: + # for all slicing ops: + # - returning a dataframe + # - returning a column + # - returning a row + # - returning a single value + + cats = Categorical( + ["a", "c", "b", "c", "c", "c", "c"], categories=["a", "b", "c"] + ) + idx = Index(["h", "i", "j", "k", "l", "m", "n"]) + values = [1, 2, 3, 4, 5, 6, 7] + df = DataFrame({"cats": cats, "values": values}, index=idx) + + # the expected values + cats2 = Categorical(["b", "c"], categories=["a", "b", "c"]) + idx2 = Index(["j", "k"]) + values2 = [3, 4] + + # 2:4,: | "j":"k",: + exp_df = DataFrame({"cats": cats2, "values": values2}, index=idx2) + + # :,"cats" | :,0 + exp_col = Series(cats, index=idx, name="cats") + + # "j",: | 2,: + exp_row = Series(["b", 3], index=["cats", "values"], dtype="object", name="j") + + # "j","cats | 2,0 + exp_val = "b" + + # iloc + # frame + res_df = df.iloc[2:4, :] + tm.assert_frame_equal(res_df, exp_df) + assert isinstance(res_df["cats"].dtype, CategoricalDtype) + + # row + res_row = df.iloc[2, :] + tm.assert_series_equal(res_row, exp_row) + assert isinstance(res_row["cats"], str) + + # col + res_col = df.iloc[:, 0] + tm.assert_series_equal(res_col, exp_col) + assert isinstance(res_col.dtype, CategoricalDtype) + + # single value + res_val = df.iloc[2, 0] + assert res_val == exp_val + + # loc + # frame + res_df = df.loc["j":"k", :] + tm.assert_frame_equal(res_df, exp_df) + assert isinstance(res_df["cats"].dtype, CategoricalDtype) + + # row + res_row = df.loc["j", :] + tm.assert_series_equal(res_row, exp_row) + assert isinstance(res_row["cats"], str) + + # col + res_col = df.loc[:, "cats"] + tm.assert_series_equal(res_col, exp_col) + assert isinstance(res_col.dtype, CategoricalDtype) + + # single value + res_val = df.loc["j", "cats"] + assert res_val == exp_val + + # single value + res_val = df.loc["j", df.columns[0]] + assert res_val == exp_val + + # iat + res_val = df.iat[2, 0] + assert res_val == exp_val + + # at + res_val = df.at["j", "cats"] + assert res_val == exp_val + + # fancy indexing + exp_fancy = df.iloc[[2]] + + res_fancy = df[df["cats"] == "b"] + tm.assert_frame_equal(res_fancy, exp_fancy) + res_fancy = df[df["values"] == 3] + tm.assert_frame_equal(res_fancy, exp_fancy) + + # get_value + res_val = df.at["j", "cats"] + assert res_val == exp_val + + # i : int, slice, or sequence of integers + res_row = df.iloc[2] + tm.assert_series_equal(res_row, exp_row) + assert isinstance(res_row["cats"], str) + + res_df = df.iloc[slice(2, 4)] + tm.assert_frame_equal(res_df, exp_df) + assert isinstance(res_df["cats"].dtype, CategoricalDtype) + + res_df = df.iloc[[2, 3]] + tm.assert_frame_equal(res_df, exp_df) + assert isinstance(res_df["cats"].dtype, CategoricalDtype) + + res_col = df.iloc[:, 0] + tm.assert_series_equal(res_col, exp_col) + assert isinstance(res_col.dtype, CategoricalDtype) + + res_df = df.iloc[:, slice(0, 2)] + tm.assert_frame_equal(res_df, df) + assert isinstance(res_df["cats"].dtype, CategoricalDtype) + + res_df = df.iloc[:, [0, 1]] + tm.assert_frame_equal(res_df, df) + assert isinstance(res_df["cats"].dtype, CategoricalDtype) + + def test_slicing_doc_examples(self): + # GH 7918 + cats = Categorical( + ["a", "b", "b", "b", "c", "c", "c"], categories=["a", "b", "c"] + ) + idx = Index(["h", "i", "j", "k", "l", "m", "n"]) + values = [1, 2, 2, 2, 3, 4, 5] + df = DataFrame({"cats": cats, "values": values}, index=idx) + + result = df.iloc[2:4, :] + expected = DataFrame( + { + "cats": Categorical(["b", "b"], categories=["a", "b", "c"]), + "values": [2, 2], + }, + index=["j", "k"], + ) + tm.assert_frame_equal(result, expected) + + result = df.iloc[2:4, :].dtypes + expected = Series(["category", "int64"], ["cats", "values"], dtype=object) + tm.assert_series_equal(result, expected) + + result = df.loc["h":"j", "cats"] + expected = Series( + Categorical(["a", "b", "b"], categories=["a", "b", "c"]), + index=["h", "i", "j"], + name="cats", + ) + tm.assert_series_equal(result, expected) + + result = df.loc["h":"j", df.columns[0:1]] + expected = DataFrame( + {"cats": Categorical(["a", "b", "b"], categories=["a", "b", "c"])}, + index=["h", "i", "j"], + ) + tm.assert_frame_equal(result, expected) + + def test_loc_getitem_listlike_labels(self, df): + # list of labels + result = df.loc[["c", "a"]] + expected = df.iloc[[4, 0, 1, 5]] + tm.assert_frame_equal(result, expected, check_index_type=True) + + def test_loc_getitem_listlike_unused_category(self, df2): + # GH#37901 a label that is in index.categories but not in index + # listlike containing an element in the categories but not in the values + with pytest.raises(KeyError, match=re.escape("['e'] not in index")): + df2.loc[["a", "b", "e"]] + + def test_loc_getitem_label_unused_category(self, df2): + # element in the categories but not in the values + with pytest.raises(KeyError, match=r"^'e'$"): + df2.loc["e"] + + def test_loc_getitem_non_category(self, df2): + # not all labels in the categories + with pytest.raises(KeyError, match=re.escape("['d'] not in index")): + df2.loc[["a", "d"]] + + def test_loc_setitem_expansion_label_unused_category(self, df2): + # assigning with a label that is in the categories but not in the index + df = df2.copy() + df.loc["e"] = 20 + result = df.loc[["a", "b", "e"]] + exp_index = CategoricalIndex(list("aaabbe"), categories=list("cabe"), name="B") + expected = DataFrame({"A": [0, 1, 5, 2, 3, 20]}, index=exp_index) + tm.assert_frame_equal(result, expected) + + def test_loc_listlike_dtypes(self): + # GH 11586 + + # unique categories and codes + index = CategoricalIndex(["a", "b", "c"]) + df = DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}, index=index) + + # unique slice + res = df.loc[["a", "b"]] + exp_index = CategoricalIndex(["a", "b"], categories=index.categories) + exp = DataFrame({"A": [1, 2], "B": [4, 5]}, index=exp_index) + tm.assert_frame_equal(res, exp, check_index_type=True) + + # duplicated slice + res = df.loc[["a", "a", "b"]] + + exp_index = CategoricalIndex(["a", "a", "b"], categories=index.categories) + exp = DataFrame({"A": [1, 1, 2], "B": [4, 4, 5]}, index=exp_index) + tm.assert_frame_equal(res, exp, check_index_type=True) + + with pytest.raises(KeyError, match=re.escape("['x'] not in index")): + df.loc[["a", "x"]] + + def test_loc_listlike_dtypes_duplicated_categories_and_codes(self): + # duplicated categories and codes + index = CategoricalIndex(["a", "b", "a"]) + df = DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}, index=index) + + # unique slice + res = df.loc[["a", "b"]] + exp = DataFrame( + {"A": [1, 3, 2], "B": [4, 6, 5]}, index=CategoricalIndex(["a", "a", "b"]) + ) + tm.assert_frame_equal(res, exp, check_index_type=True) + + # duplicated slice + res = df.loc[["a", "a", "b"]] + exp = DataFrame( + {"A": [1, 3, 1, 3, 2], "B": [4, 6, 4, 6, 5]}, + index=CategoricalIndex(["a", "a", "a", "a", "b"]), + ) + tm.assert_frame_equal(res, exp, check_index_type=True) + + with pytest.raises(KeyError, match=re.escape("['x'] not in index")): + df.loc[["a", "x"]] + + def test_loc_listlike_dtypes_unused_category(self): + # contains unused category + index = CategoricalIndex(["a", "b", "a", "c"], categories=list("abcde")) + df = DataFrame({"A": [1, 2, 3, 4], "B": [5, 6, 7, 8]}, index=index) + + res = df.loc[["a", "b"]] + exp = DataFrame( + {"A": [1, 3, 2], "B": [5, 7, 6]}, + index=CategoricalIndex(["a", "a", "b"], categories=list("abcde")), + ) + tm.assert_frame_equal(res, exp, check_index_type=True) + + # duplicated slice + res = df.loc[["a", "a", "b"]] + exp = DataFrame( + {"A": [1, 3, 1, 3, 2], "B": [5, 7, 5, 7, 6]}, + index=CategoricalIndex(["a", "a", "a", "a", "b"], categories=list("abcde")), + ) + tm.assert_frame_equal(res, exp, check_index_type=True) + + with pytest.raises(KeyError, match=re.escape("['x'] not in index")): + df.loc[["a", "x"]] + + def test_loc_getitem_listlike_unused_category_raises_keyerror(self): + # key that is an *unused* category raises + index = CategoricalIndex(["a", "b", "a", "c"], categories=list("abcde")) + df = DataFrame({"A": [1, 2, 3, 4], "B": [5, 6, 7, 8]}, index=index) + + with pytest.raises(KeyError, match="e"): + # For comparison, check the scalar behavior + df.loc["e"] + + with pytest.raises(KeyError, match=re.escape("['e'] not in index")): + df.loc[["a", "e"]] + + def test_ix_categorical_index(self): + # GH 12531 + df = DataFrame( + np.random.default_rng(2).standard_normal((3, 3)), + index=list("ABC"), + columns=list("XYZ"), + ) + cdf = df.copy() + cdf.index = CategoricalIndex(df.index) + cdf.columns = CategoricalIndex(df.columns) + + expect = Series(df.loc["A", :], index=cdf.columns, name="A") + tm.assert_series_equal(cdf.loc["A", :], expect) + + expect = Series(df.loc[:, "X"], index=cdf.index, name="X") + tm.assert_series_equal(cdf.loc[:, "X"], expect) + + exp_index = CategoricalIndex(list("AB"), categories=["A", "B", "C"]) + expect = DataFrame(df.loc[["A", "B"], :], columns=cdf.columns, index=exp_index) + tm.assert_frame_equal(cdf.loc[["A", "B"], :], expect) + + exp_columns = CategoricalIndex(list("XY"), categories=["X", "Y", "Z"]) + expect = DataFrame(df.loc[:, ["X", "Y"]], index=cdf.index, columns=exp_columns) + tm.assert_frame_equal(cdf.loc[:, ["X", "Y"]], expect) + + @pytest.mark.parametrize( + "infer_string", [False, pytest.param(True, marks=td.skip_if_no("pyarrow"))] + ) + def test_ix_categorical_index_non_unique(self, infer_string): + # non-unique + with option_context("future.infer_string", infer_string): + df = DataFrame( + np.random.default_rng(2).standard_normal((3, 3)), + index=list("ABA"), + columns=list("XYX"), + ) + cdf = df.copy() + cdf.index = CategoricalIndex(df.index) + cdf.columns = CategoricalIndex(df.columns) + + exp_index = CategoricalIndex(list("AA"), categories=["A", "B"]) + expect = DataFrame(df.loc["A", :], columns=cdf.columns, index=exp_index) + tm.assert_frame_equal(cdf.loc["A", :], expect) + + exp_columns = CategoricalIndex(list("XX"), categories=["X", "Y"]) + expect = DataFrame(df.loc[:, "X"], index=cdf.index, columns=exp_columns) + tm.assert_frame_equal(cdf.loc[:, "X"], expect) + + expect = DataFrame( + df.loc[["A", "B"], :], + columns=cdf.columns, + index=CategoricalIndex(list("AAB")), + ) + tm.assert_frame_equal(cdf.loc[["A", "B"], :], expect) + + expect = DataFrame( + df.loc[:, ["X", "Y"]], + index=cdf.index, + columns=CategoricalIndex(list("XXY")), + ) + tm.assert_frame_equal(cdf.loc[:, ["X", "Y"]], expect) + + def test_loc_slice(self, df): + # GH9748 + msg = ( + "cannot do slice indexing on CategoricalIndex with these " + r"indexers \[1\] of type int" + ) + with pytest.raises(TypeError, match=msg): + df.loc[1:5] + + result = df.loc["b":"c"] + expected = df.iloc[[2, 3, 4]] + tm.assert_frame_equal(result, expected) + + def test_loc_and_at_with_categorical_index(self): + # GH 20629 + df = DataFrame( + [[1, 2], [3, 4], [5, 6]], index=CategoricalIndex(["A", "B", "C"]) + ) + + s = df[0] + assert s.loc["A"] == 1 + assert s.at["A"] == 1 + + assert df.loc["B", 1] == 4 + assert df.at["B", 1] == 4 + + @pytest.mark.parametrize( + "idx_values", + [ + # python types + [1, 2, 3], + [-1, -2, -3], + [1.5, 2.5, 3.5], + [-1.5, -2.5, -3.5], + # numpy int/uint + *(np.array([1, 2, 3], dtype=dtype) for dtype in tm.ALL_INT_NUMPY_DTYPES), + # numpy floats + *(np.array([1.5, 2.5, 3.5], dtype=dtyp) for dtyp in tm.FLOAT_NUMPY_DTYPES), + # numpy object + np.array([1, "b", 3.5], dtype=object), + # pandas scalars + [Interval(1, 4), Interval(4, 6), Interval(6, 9)], + [Timestamp(2019, 1, 1), Timestamp(2019, 2, 1), Timestamp(2019, 3, 1)], + [Timedelta(1, "d"), Timedelta(2, "d"), Timedelta(3, "D")], + # pandas Integer arrays + *(pd.array([1, 2, 3], dtype=dtype) for dtype in tm.ALL_INT_EA_DTYPES), + # other pandas arrays + pd.IntervalIndex.from_breaks([1, 4, 6, 9]).array, + pd.date_range("2019-01-01", periods=3).array, + pd.timedelta_range(start="1d", periods=3).array, + ], + ) + def test_loc_getitem_with_non_string_categories(self, idx_values, ordered): + # GH-17569 + cat_idx = CategoricalIndex(idx_values, ordered=ordered) + df = DataFrame({"A": ["foo", "bar", "baz"]}, index=cat_idx) + sl = slice(idx_values[0], idx_values[1]) + + # scalar selection + result = df.loc[idx_values[0]] + expected = Series(["foo"], index=["A"], name=idx_values[0]) + tm.assert_series_equal(result, expected) + + # list selection + result = df.loc[idx_values[:2]] + expected = DataFrame(["foo", "bar"], index=cat_idx[:2], columns=["A"]) + tm.assert_frame_equal(result, expected) + + # slice selection + result = df.loc[sl] + expected = DataFrame(["foo", "bar"], index=cat_idx[:2], columns=["A"]) + tm.assert_frame_equal(result, expected) + + # scalar assignment + result = df.copy() + result.loc[idx_values[0]] = "qux" + expected = DataFrame({"A": ["qux", "bar", "baz"]}, index=cat_idx) + tm.assert_frame_equal(result, expected) + + # list assignment + result = df.copy() + result.loc[idx_values[:2], "A"] = ["qux", "qux2"] + expected = DataFrame({"A": ["qux", "qux2", "baz"]}, index=cat_idx) + tm.assert_frame_equal(result, expected) + + # slice assignment + result = df.copy() + result.loc[sl, "A"] = ["qux", "qux2"] + expected = DataFrame({"A": ["qux", "qux2", "baz"]}, index=cat_idx) + tm.assert_frame_equal(result, expected) + + def test_getitem_categorical_with_nan(self): + # GH#41933 + ci = CategoricalIndex(["A", "B", np.nan]) + + ser = Series(range(3), index=ci) + + assert ser[np.nan] == 2 + assert ser.loc[np.nan] == 2 + + df = DataFrame(ser) + assert df.loc[np.nan, 0] == 2 + assert df.loc[np.nan][0] == 2 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/test_chaining_and_caching.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/test_chaining_and_caching.py new file mode 100644 index 0000000000000000000000000000000000000000..b97df376ac47fd8b62f631e38133ae0c0251fd63 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/test_chaining_and_caching.py @@ -0,0 +1,647 @@ +from string import ascii_letters + +import numpy as np +import pytest + +from pandas.errors import ( + SettingWithCopyError, + SettingWithCopyWarning, +) +import pandas.util._test_decorators as td + +import pandas as pd +from pandas import ( + DataFrame, + Index, + Series, + Timestamp, + date_range, + option_context, +) +import pandas._testing as tm + +msg = "A value is trying to be set on a copy of a slice from a DataFrame" + + +def random_text(nobs=100): + # Construct a DataFrame where each row is a random slice from 'letters' + idxs = np.random.default_rng(2).integers(len(ascii_letters), size=(nobs, 2)) + idxs.sort(axis=1) + strings = [ascii_letters[x[0] : x[1]] for x in idxs] + + return DataFrame(strings, columns=["letters"]) + + +class TestCaching: + def test_slice_consolidate_invalidate_item_cache(self, using_copy_on_write): + # this is chained assignment, but will 'work' + with option_context("chained_assignment", None): + # #3970 + df = DataFrame({"aa": np.arange(5), "bb": [2.2] * 5}) + + # Creates a second float block + df["cc"] = 0.0 + + # caches a reference to the 'bb' series + df["bb"] + + # Assignment to wrong series + with tm.raises_chained_assignment_error(): + df["bb"].iloc[0] = 0.17 + df._clear_item_cache() + if not using_copy_on_write: + tm.assert_almost_equal(df["bb"][0], 0.17) + else: + # with ArrayManager, parent is not mutated with chained assignment + tm.assert_almost_equal(df["bb"][0], 2.2) + + @pytest.mark.parametrize("do_ref", [True, False]) + def test_setitem_cache_updating(self, do_ref): + # GH 5424 + cont = ["one", "two", "three", "four", "five", "six", "seven"] + + df = DataFrame({"a": cont, "b": cont[3:] + cont[:3], "c": np.arange(7)}) + + # ref the cache + if do_ref: + df.loc[0, "c"] + + # set it + df.loc[7, "c"] = 1 + + assert df.loc[0, "c"] == 0.0 + assert df.loc[7, "c"] == 1.0 + + def test_setitem_cache_updating_slices( + self, using_copy_on_write, warn_copy_on_write + ): + # GH 7084 + # not updating cache on series setting with slices + expected = DataFrame( + {"A": [600, 600, 600]}, index=date_range("5/7/2014", "5/9/2014") + ) + out = DataFrame({"A": [0, 0, 0]}, index=date_range("5/7/2014", "5/9/2014")) + df = DataFrame({"C": ["A", "A", "A"], "D": [100, 200, 300]}) + + # loop through df to update out + six = Timestamp("5/7/2014") + eix = Timestamp("5/9/2014") + for ix, row in df.iterrows(): + out.loc[six:eix, row["C"]] = out.loc[six:eix, row["C"]] + row["D"] + + tm.assert_frame_equal(out, expected) + tm.assert_series_equal(out["A"], expected["A"]) + + # try via a chain indexing + # this actually works + out = DataFrame({"A": [0, 0, 0]}, index=date_range("5/7/2014", "5/9/2014")) + out_original = out.copy() + for ix, row in df.iterrows(): + v = out[row["C"]][six:eix] + row["D"] + with tm.raises_chained_assignment_error( + (ix == 0) or warn_copy_on_write or using_copy_on_write + ): + out[row["C"]][six:eix] = v + + if not using_copy_on_write: + tm.assert_frame_equal(out, expected) + tm.assert_series_equal(out["A"], expected["A"]) + else: + tm.assert_frame_equal(out, out_original) + tm.assert_series_equal(out["A"], out_original["A"]) + + out = DataFrame({"A": [0, 0, 0]}, index=date_range("5/7/2014", "5/9/2014")) + for ix, row in df.iterrows(): + out.loc[six:eix, row["C"]] += row["D"] + + tm.assert_frame_equal(out, expected) + tm.assert_series_equal(out["A"], expected["A"]) + + def test_altering_series_clears_parent_cache( + self, using_copy_on_write, warn_copy_on_write + ): + # GH #33675 + df = DataFrame([[1, 2], [3, 4]], index=["a", "b"], columns=["A", "B"]) + ser = df["A"] + + if using_copy_on_write or warn_copy_on_write: + assert "A" not in df._item_cache + else: + assert "A" in df._item_cache + + # Adding a new entry to ser swaps in a new array, so "A" needs to + # be removed from df._item_cache + ser["c"] = 5 + assert len(ser) == 3 + assert "A" not in df._item_cache + assert df["A"] is not ser + assert len(df["A"]) == 2 + + +class TestChaining: + def test_setitem_chained_setfault(self, using_copy_on_write): + # GH6026 + data = ["right", "left", "left", "left", "right", "left", "timeout"] + mdata = ["right", "left", "left", "left", "right", "left", "none"] + + df = DataFrame({"response": np.array(data)}) + mask = df.response == "timeout" + with tm.raises_chained_assignment_error(): + df.response[mask] = "none" + if using_copy_on_write: + tm.assert_frame_equal(df, DataFrame({"response": data})) + else: + tm.assert_frame_equal(df, DataFrame({"response": mdata})) + + recarray = np.rec.fromarrays([data], names=["response"]) + df = DataFrame(recarray) + mask = df.response == "timeout" + with tm.raises_chained_assignment_error(): + df.response[mask] = "none" + if using_copy_on_write: + tm.assert_frame_equal(df, DataFrame({"response": data})) + else: + tm.assert_frame_equal(df, DataFrame({"response": mdata})) + + df = DataFrame({"response": data, "response1": data}) + df_original = df.copy() + mask = df.response == "timeout" + with tm.raises_chained_assignment_error(): + df.response[mask] = "none" + if using_copy_on_write: + tm.assert_frame_equal(df, df_original) + else: + tm.assert_frame_equal(df, DataFrame({"response": mdata, "response1": data})) + + # GH 6056 + expected = DataFrame({"A": [np.nan, "bar", "bah", "foo", "bar"]}) + df = DataFrame({"A": np.array(["foo", "bar", "bah", "foo", "bar"])}) + with tm.raises_chained_assignment_error(): + df["A"].iloc[0] = np.nan + if using_copy_on_write: + expected = DataFrame({"A": ["foo", "bar", "bah", "foo", "bar"]}) + else: + expected = DataFrame({"A": [np.nan, "bar", "bah", "foo", "bar"]}) + result = df.head() + tm.assert_frame_equal(result, expected) + + df = DataFrame({"A": np.array(["foo", "bar", "bah", "foo", "bar"])}) + with tm.raises_chained_assignment_error(): + df.A.iloc[0] = np.nan + result = df.head() + tm.assert_frame_equal(result, expected) + + @pytest.mark.arm_slow + def test_detect_chained_assignment(self, using_copy_on_write): + with option_context("chained_assignment", "raise"): + # work with the chain + expected = DataFrame([[-5, 1], [-6, 3]], columns=list("AB")) + df = DataFrame( + np.arange(4).reshape(2, 2), columns=list("AB"), dtype="int64" + ) + df_original = df.copy() + assert df._is_copy is None + + with tm.raises_chained_assignment_error(): + df["A"][0] = -5 + with tm.raises_chained_assignment_error(): + df["A"][1] = -6 + if using_copy_on_write: + tm.assert_frame_equal(df, df_original) + else: + tm.assert_frame_equal(df, expected) + + @pytest.mark.arm_slow + def test_detect_chained_assignment_raises( + self, using_array_manager, using_copy_on_write, warn_copy_on_write + ): + # test with the chaining + df = DataFrame( + { + "A": Series(range(2), dtype="int64"), + "B": np.array(np.arange(2, 4), dtype=np.float64), + } + ) + df_original = df.copy() + assert df._is_copy is None + + if using_copy_on_write: + with tm.raises_chained_assignment_error(): + df["A"][0] = -5 + with tm.raises_chained_assignment_error(): + df["A"][1] = -6 + tm.assert_frame_equal(df, df_original) + elif warn_copy_on_write: + with tm.raises_chained_assignment_error(): + df["A"][0] = -5 + with tm.raises_chained_assignment_error(): + df["A"][1] = np.nan + elif not using_array_manager: + with pytest.raises(SettingWithCopyError, match=msg): + with tm.raises_chained_assignment_error(): + df["A"][0] = -5 + + with pytest.raises(SettingWithCopyError, match=msg): + with tm.raises_chained_assignment_error(): + df["A"][1] = np.nan + + assert df["A"]._is_copy is None + else: + # INFO(ArrayManager) for ArrayManager it doesn't matter that it's + # a mixed dataframe + df["A"][0] = -5 + df["A"][1] = -6 + expected = DataFrame([[-5, 2], [-6, 3]], columns=list("AB")) + expected["B"] = expected["B"].astype("float64") + tm.assert_frame_equal(df, expected) + + @pytest.mark.arm_slow + def test_detect_chained_assignment_fails( + self, using_copy_on_write, warn_copy_on_write + ): + # Using a copy (the chain), fails + df = DataFrame( + { + "A": Series(range(2), dtype="int64"), + "B": np.array(np.arange(2, 4), dtype=np.float64), + } + ) + + if using_copy_on_write or warn_copy_on_write: + with tm.raises_chained_assignment_error(): + df.loc[0]["A"] = -5 + else: + with pytest.raises(SettingWithCopyError, match=msg): + df.loc[0]["A"] = -5 + + @pytest.mark.arm_slow + def test_detect_chained_assignment_doc_example( + self, using_copy_on_write, warn_copy_on_write + ): + # Doc example + df = DataFrame( + { + "a": ["one", "one", "two", "three", "two", "one", "six"], + "c": Series(range(7), dtype="int64"), + } + ) + assert df._is_copy is None + + indexer = df.a.str.startswith("o") + if using_copy_on_write or warn_copy_on_write: + with tm.raises_chained_assignment_error(): + df[indexer]["c"] = 42 + else: + with pytest.raises(SettingWithCopyError, match=msg): + df[indexer]["c"] = 42 + + @pytest.mark.arm_slow + def test_detect_chained_assignment_object_dtype( + self, using_array_manager, using_copy_on_write, warn_copy_on_write + ): + expected = DataFrame({"A": [111, "bbb", "ccc"], "B": [1, 2, 3]}) + df = DataFrame( + {"A": Series(["aaa", "bbb", "ccc"], dtype=object), "B": [1, 2, 3]} + ) + df_original = df.copy() + + if not using_copy_on_write and not warn_copy_on_write: + with pytest.raises(SettingWithCopyError, match=msg): + df.loc[0]["A"] = 111 + + if using_copy_on_write: + with tm.raises_chained_assignment_error(): + df["A"][0] = 111 + tm.assert_frame_equal(df, df_original) + elif warn_copy_on_write: + with tm.raises_chained_assignment_error(): + df["A"][0] = 111 + tm.assert_frame_equal(df, expected) + elif not using_array_manager: + with pytest.raises(SettingWithCopyError, match=msg): + with tm.raises_chained_assignment_error(): + df["A"][0] = 111 + + df.loc[0, "A"] = 111 + tm.assert_frame_equal(df, expected) + else: + # INFO(ArrayManager) for ArrayManager it doesn't matter that it's + # a mixed dataframe + df["A"][0] = 111 + tm.assert_frame_equal(df, expected) + + @pytest.mark.arm_slow + def test_detect_chained_assignment_is_copy_pickle(self): + # gh-5475: Make sure that is_copy is picked up reconstruction + df = DataFrame({"A": [1, 2]}) + assert df._is_copy is None + + with tm.ensure_clean("__tmp__pickle") as path: + df.to_pickle(path) + df2 = pd.read_pickle(path) + df2["B"] = df2["A"] + df2["B"] = df2["A"] + + @pytest.mark.arm_slow + def test_detect_chained_assignment_setting_entire_column(self): + # gh-5597: a spurious raise as we are setting the entire column here + + df = random_text(100000) + + # Always a copy + x = df.iloc[[0, 1, 2]] + assert x._is_copy is not None + + x = df.iloc[[0, 1, 2, 4]] + assert x._is_copy is not None + + # Explicitly copy + indexer = df.letters.apply(lambda x: len(x) > 10) + df = df.loc[indexer].copy() + + assert df._is_copy is None + df["letters"] = df["letters"].apply(str.lower) + + @pytest.mark.arm_slow + def test_detect_chained_assignment_implicit_take(self): + # Implicitly take + df = random_text(100000) + indexer = df.letters.apply(lambda x: len(x) > 10) + df = df.loc[indexer] + + assert df._is_copy is not None + df["letters"] = df["letters"].apply(str.lower) + + @pytest.mark.arm_slow + def test_detect_chained_assignment_implicit_take2( + self, using_copy_on_write, warn_copy_on_write + ): + if using_copy_on_write or warn_copy_on_write: + pytest.skip("_is_copy is not always set for CoW") + # Implicitly take 2 + df = random_text(100000) + indexer = df.letters.apply(lambda x: len(x) > 10) + + df = df.loc[indexer] + assert df._is_copy is not None + df.loc[:, "letters"] = df["letters"].apply(str.lower) + + # with the enforcement of #45333 in 2.0, the .loc[:, letters] setting + # is inplace, so df._is_copy remains non-None. + assert df._is_copy is not None + + df["letters"] = df["letters"].apply(str.lower) + assert df._is_copy is None + + @pytest.mark.arm_slow + def test_detect_chained_assignment_str(self): + df = random_text(100000) + indexer = df.letters.apply(lambda x: len(x) > 10) + df.loc[indexer, "letters"] = df.loc[indexer, "letters"].apply(str.lower) + + @pytest.mark.arm_slow + def test_detect_chained_assignment_is_copy(self): + # an identical take, so no copy + df = DataFrame({"a": [1]}).dropna() + assert df._is_copy is None + df["a"] += 1 + + @pytest.mark.arm_slow + def test_detect_chained_assignment_sorting(self): + df = DataFrame(np.random.default_rng(2).standard_normal((10, 4))) + ser = df.iloc[:, 0].sort_values() + + tm.assert_series_equal(ser, df.iloc[:, 0].sort_values()) + tm.assert_series_equal(ser, df[0].sort_values()) + + @pytest.mark.arm_slow + def test_detect_chained_assignment_false_positives(self): + # see gh-6025: false positives + df = DataFrame({"column1": ["a", "a", "a"], "column2": [4, 8, 9]}) + str(df) + + df["column1"] = df["column1"] + "b" + str(df) + + df = df[df["column2"] != 8] + str(df) + + df["column1"] = df["column1"] + "c" + str(df) + + @pytest.mark.arm_slow + def test_detect_chained_assignment_undefined_column( + self, using_copy_on_write, warn_copy_on_write + ): + # from SO: + # https://stackoverflow.com/questions/24054495/potential-bug-setting-value-for-undefined-column-using-iloc + df = DataFrame(np.arange(0, 9), columns=["count"]) + df["group"] = "b" + df_original = df.copy() + + if using_copy_on_write: + with tm.raises_chained_assignment_error(): + df.iloc[0:5]["group"] = "a" + tm.assert_frame_equal(df, df_original) + elif warn_copy_on_write: + with tm.raises_chained_assignment_error(): + df.iloc[0:5]["group"] = "a" + else: + with pytest.raises(SettingWithCopyError, match=msg): + with tm.raises_chained_assignment_error(): + df.iloc[0:5]["group"] = "a" + + @pytest.mark.arm_slow + def test_detect_chained_assignment_changing_dtype( + self, using_array_manager, using_copy_on_write, warn_copy_on_write + ): + # Mixed type setting but same dtype & changing dtype + df = DataFrame( + { + "A": date_range("20130101", periods=5), + "B": np.random.default_rng(2).standard_normal(5), + "C": np.arange(5, dtype="int64"), + "D": ["a", "b", "c", "d", "e"], + } + ) + df_original = df.copy() + + if using_copy_on_write or warn_copy_on_write: + with tm.raises_chained_assignment_error(): + df.loc[2]["D"] = "foo" + with tm.raises_chained_assignment_error(): + df.loc[2]["C"] = "foo" + tm.assert_frame_equal(df, df_original) + with tm.raises_chained_assignment_error(extra_warnings=(FutureWarning,)): + df["C"][2] = "foo" + if using_copy_on_write: + tm.assert_frame_equal(df, df_original) + else: + assert df.loc[2, "C"] == "foo" + else: + with pytest.raises(SettingWithCopyError, match=msg): + df.loc[2]["D"] = "foo" + + with pytest.raises(SettingWithCopyError, match=msg): + df.loc[2]["C"] = "foo" + + if not using_array_manager: + with pytest.raises(SettingWithCopyError, match=msg): + with tm.raises_chained_assignment_error(): + df["C"][2] = "foo" + else: + # INFO(ArrayManager) for ArrayManager it doesn't matter if it's + # changing the dtype or not + df["C"][2] = "foo" + assert df.loc[2, "C"] == "foo" + + def test_setting_with_copy_bug(self, using_copy_on_write, warn_copy_on_write): + # operating on a copy + df = DataFrame( + {"a": list(range(4)), "b": list("ab.."), "c": ["a", "b", np.nan, "d"]} + ) + df_original = df.copy() + mask = pd.isna(df.c) + + if using_copy_on_write: + with tm.raises_chained_assignment_error(): + df[["c"]][mask] = df[["b"]][mask] + tm.assert_frame_equal(df, df_original) + elif warn_copy_on_write: + with tm.raises_chained_assignment_error(): + df[["c"]][mask] = df[["b"]][mask] + else: + with pytest.raises(SettingWithCopyError, match=msg): + df[["c"]][mask] = df[["b"]][mask] + + def test_setting_with_copy_bug_no_warning(self): + # invalid warning as we are returning a new object + # GH 8730 + df1 = DataFrame({"x": Series(["a", "b", "c"]), "y": Series(["d", "e", "f"])}) + df2 = df1[["x"]] + + # this should not raise + df2["y"] = ["g", "h", "i"] + + def test_detect_chained_assignment_warnings_errors( + self, using_copy_on_write, warn_copy_on_write + ): + df = DataFrame({"A": ["aaa", "bbb", "ccc"], "B": [1, 2, 3]}) + if using_copy_on_write or warn_copy_on_write: + with tm.raises_chained_assignment_error(): + df.loc[0]["A"] = 111 + return + + with option_context("chained_assignment", "warn"): + with tm.assert_produces_warning(SettingWithCopyWarning): + df.loc[0]["A"] = 111 + + with option_context("chained_assignment", "raise"): + with pytest.raises(SettingWithCopyError, match=msg): + df.loc[0]["A"] = 111 + + @pytest.mark.parametrize("rhs", [3, DataFrame({0: [1, 2, 3, 4]})]) + def test_detect_chained_assignment_warning_stacklevel( + self, rhs, using_copy_on_write, warn_copy_on_write + ): + # GH#42570 + df = DataFrame(np.arange(25).reshape(5, 5)) + df_original = df.copy() + chained = df.loc[:3] + with option_context("chained_assignment", "warn"): + if not using_copy_on_write and not warn_copy_on_write: + with tm.assert_produces_warning(SettingWithCopyWarning) as t: + chained[2] = rhs + assert t[0].filename == __file__ + else: + # INFO(CoW) no warning, and original dataframe not changed + chained[2] = rhs + tm.assert_frame_equal(df, df_original) + + # TODO(ArrayManager) fast_xs with array-like scalars is not yet working + @td.skip_array_manager_not_yet_implemented + def test_chained_getitem_with_lists(self): + # GH6394 + # Regression in chained getitem indexing with embedded list-like from + # 0.12 + + df = DataFrame({"A": 5 * [np.zeros(3)], "B": 5 * [np.ones(3)]}) + expected = df["A"].iloc[2] + result = df.loc[2, "A"] + tm.assert_numpy_array_equal(result, expected) + result2 = df.iloc[2]["A"] + tm.assert_numpy_array_equal(result2, expected) + result3 = df["A"].loc[2] + tm.assert_numpy_array_equal(result3, expected) + result4 = df["A"].iloc[2] + tm.assert_numpy_array_equal(result4, expected) + + def test_cache_updating(self): + # GH 4939, make sure to update the cache on setitem + + df = DataFrame( + np.zeros((10, 4)), + columns=Index(list("ABCD"), dtype=object), + ) + df["A"] # cache series + df.loc["Hello Friend"] = df.iloc[0] + assert "Hello Friend" in df["A"].index + assert "Hello Friend" in df["B"].index + + def test_cache_updating2(self, using_copy_on_write): + # 10264 + df = DataFrame( + np.zeros((5, 5), dtype="int64"), + columns=["a", "b", "c", "d", "e"], + index=range(5), + ) + df["f"] = 0 + df_orig = df.copy() + if using_copy_on_write: + with pytest.raises(ValueError, match="read-only"): + df.f.values[3] = 1 + tm.assert_frame_equal(df, df_orig) + return + + df.f.values[3] = 1 + + df.f.values[3] = 2 + expected = DataFrame( + np.zeros((5, 6), dtype="int64"), + columns=["a", "b", "c", "d", "e", "f"], + index=range(5), + ) + expected.at[3, "f"] = 2 + tm.assert_frame_equal(df, expected) + expected = Series([0, 0, 0, 2, 0], name="f") + tm.assert_series_equal(df.f, expected) + + def test_iloc_setitem_chained_assignment(self, using_copy_on_write): + # GH#3970 + with option_context("chained_assignment", None): + df = DataFrame({"aa": range(5), "bb": [2.2] * 5}) + df["cc"] = 0.0 + + ck = [True] * len(df) + + with tm.raises_chained_assignment_error(): + df["bb"].iloc[0] = 0.13 + + # GH#3970 this lookup used to break the chained setting to 0.15 + df.iloc[ck] + + with tm.raises_chained_assignment_error(): + df["bb"].iloc[0] = 0.15 + + if not using_copy_on_write: + assert df["bb"].iloc[0] == 0.15 + else: + assert df["bb"].iloc[0] == 2.2 + + def test_getitem_loc_assignment_slice_state(self): + # GH 13569 + df = DataFrame({"a": [10, 20, 30]}) + with tm.raises_chained_assignment_error(): + df["a"].loc[4] = 40 + tm.assert_frame_equal(df, DataFrame({"a": [10, 20, 30]})) + tm.assert_series_equal(df["a"], Series([10, 20, 30], name="a")) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/test_check_indexer.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/test_check_indexer.py new file mode 100644 index 0000000000000000000000000000000000000000..975a31b873792c6afe59a23e5fef43b56ce7e46e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/test_check_indexer.py @@ -0,0 +1,105 @@ +import numpy as np +import pytest + +import pandas as pd +import pandas._testing as tm +from pandas.api.indexers import check_array_indexer + + +@pytest.mark.parametrize( + "indexer, expected", + [ + # integer + ([1, 2], np.array([1, 2], dtype=np.intp)), + (np.array([1, 2], dtype="int64"), np.array([1, 2], dtype=np.intp)), + (pd.array([1, 2], dtype="Int32"), np.array([1, 2], dtype=np.intp)), + (pd.Index([1, 2]), np.array([1, 2], dtype=np.intp)), + # boolean + ([True, False, True], np.array([True, False, True], dtype=np.bool_)), + (np.array([True, False, True]), np.array([True, False, True], dtype=np.bool_)), + ( + pd.array([True, False, True], dtype="boolean"), + np.array([True, False, True], dtype=np.bool_), + ), + # other + ([], np.array([], dtype=np.intp)), + ], +) +def test_valid_input(indexer, expected): + arr = np.array([1, 2, 3]) + result = check_array_indexer(arr, indexer) + tm.assert_numpy_array_equal(result, expected) + + +@pytest.mark.parametrize( + "indexer", [[True, False, None], pd.array([True, False, None], dtype="boolean")] +) +def test_boolean_na_returns_indexer(indexer): + # https://github.com/pandas-dev/pandas/issues/31503 + arr = np.array([1, 2, 3]) + + result = check_array_indexer(arr, indexer) + expected = np.array([True, False, False], dtype=bool) + + tm.assert_numpy_array_equal(result, expected) + + +@pytest.mark.parametrize( + "indexer", + [ + [True, False], + pd.array([True, False], dtype="boolean"), + np.array([True, False], dtype=np.bool_), + ], +) +def test_bool_raise_length(indexer): + arr = np.array([1, 2, 3]) + + msg = "Boolean index has wrong length" + with pytest.raises(IndexError, match=msg): + check_array_indexer(arr, indexer) + + +@pytest.mark.parametrize( + "indexer", [[0, 1, None], pd.array([0, 1, pd.NA], dtype="Int64")] +) +def test_int_raise_missing_values(indexer): + arr = np.array([1, 2, 3]) + + msg = "Cannot index with an integer indexer containing NA values" + with pytest.raises(ValueError, match=msg): + check_array_indexer(arr, indexer) + + +@pytest.mark.parametrize( + "indexer", + [ + [0.0, 1.0], + np.array([1.0, 2.0], dtype="float64"), + np.array([True, False], dtype=object), + pd.Index([True, False], dtype=object), + ], +) +def test_raise_invalid_array_dtypes(indexer): + arr = np.array([1, 2, 3]) + + msg = "arrays used as indices must be of integer or boolean type" + with pytest.raises(IndexError, match=msg): + check_array_indexer(arr, indexer) + + +def test_raise_nullable_string_dtype(nullable_string_dtype): + indexer = pd.array(["a", "b"], dtype=nullable_string_dtype) + arr = np.array([1, 2, 3]) + + msg = "arrays used as indices must be of integer or boolean type" + with pytest.raises(IndexError, match=msg): + check_array_indexer(arr, indexer) + + +@pytest.mark.parametrize("indexer", [None, Ellipsis, slice(0, 3), (None,)]) +def test_pass_through_non_array_likes(indexer): + arr = np.array([1, 2, 3]) + + result = check_array_indexer(arr, indexer) + assert result == indexer diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/test_iat.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/test_iat.py new file mode 100644 index 0000000000000000000000000000000000000000..5b8c4f2d4b9b97228eb768797b224cedffb239a8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/test_iat.py @@ -0,0 +1,53 @@ +import numpy as np + +from pandas import ( + DataFrame, + Series, + period_range, +) +import pandas._testing as tm + + +def test_iat(float_frame): + for i, row in enumerate(float_frame.index): + for j, col in enumerate(float_frame.columns): + result = float_frame.iat[i, j] + expected = float_frame.at[row, col] + assert result == expected + + +def test_iat_duplicate_columns(): + # https://github.com/pandas-dev/pandas/issues/11754 + df = DataFrame([[1, 2]], columns=["x", "x"]) + assert df.iat[0, 0] == 1 + + +def test_iat_getitem_series_with_period_index(): + # GH#4390, iat incorrectly indexing + index = period_range("1/1/2001", periods=10) + ser = Series(np.random.default_rng(2).standard_normal(10), index=index) + expected = ser[index[0]] + result = ser.iat[0] + assert expected == result + + +def test_iat_setitem_item_cache_cleared( + indexer_ial, using_copy_on_write, warn_copy_on_write +): + # GH#45684 + data = {"x": np.arange(8, dtype=np.int64), "y": np.int64(0)} + df = DataFrame(data).copy() + ser = df["y"] + + # previously this iat setting would split the block and fail to clear + # the item_cache. + with tm.assert_cow_warning(warn_copy_on_write): + indexer_ial(df)[7, 0] = 9999 + + with tm.assert_cow_warning(warn_copy_on_write): + indexer_ial(df)[7, 1] = 1234 + + assert df.iat[7, 1] == 1234 + if not using_copy_on_write: + assert ser.iloc[-1] == 1234 + assert df.iloc[-1, -1] == 1234 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/test_indexers.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/test_indexers.py new file mode 100644 index 0000000000000000000000000000000000000000..ddc5c039160d5ada6c6dccb62514590a4ce9f620 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/test_indexers.py @@ -0,0 +1,61 @@ +# Tests aimed at pandas.core.indexers +import numpy as np +import pytest + +from pandas.core.indexers import ( + is_scalar_indexer, + length_of_indexer, + validate_indices, +) + + +def test_length_of_indexer(): + arr = np.zeros(4, dtype=bool) + arr[0] = 1 + result = length_of_indexer(arr) + assert result == 1 + + +def test_is_scalar_indexer(): + indexer = (0, 1) + assert is_scalar_indexer(indexer, 2) + assert not is_scalar_indexer(indexer[0], 2) + + indexer = (np.array([2]), 1) + assert not is_scalar_indexer(indexer, 2) + + indexer = (np.array([2]), np.array([3])) + assert not is_scalar_indexer(indexer, 2) + + indexer = (np.array([2]), np.array([3, 4])) + assert not is_scalar_indexer(indexer, 2) + + assert not is_scalar_indexer(slice(None), 1) + + indexer = 0 + assert is_scalar_indexer(indexer, 1) + + indexer = (0,) + assert is_scalar_indexer(indexer, 1) + + +class TestValidateIndices: + def test_validate_indices_ok(self): + indices = np.asarray([0, 1]) + validate_indices(indices, 2) + validate_indices(indices[:0], 0) + validate_indices(np.array([-1, -1]), 0) + + def test_validate_indices_low(self): + indices = np.asarray([0, -2]) + with pytest.raises(ValueError, match="'indices' contains"): + validate_indices(indices, 2) + + def test_validate_indices_high(self): + indices = np.asarray([0, 1, 2]) + with pytest.raises(IndexError, match="indices are out"): + validate_indices(indices, 2) + + def test_validate_indices_empty(self): + with pytest.raises(IndexError, match="indices are out"): + validate_indices(np.array([0, 1]), 0) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/test_indexing.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/test_indexing.py new file mode 100644 index 0000000000000000000000000000000000000000..57f45f867254dbb4c5075b1180f655d3679a3a83 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/test_indexing.py @@ -0,0 +1,1157 @@ +""" test fancy indexing & misc """ + +import array +from datetime import datetime +import re +import weakref + +import numpy as np +import pytest + +from pandas._config import using_pyarrow_string_dtype + +from pandas.errors import IndexingError + +from pandas.core.dtypes.common import ( + is_float_dtype, + is_integer_dtype, + is_object_dtype, +) + +import pandas as pd +from pandas import ( + DataFrame, + Index, + NaT, + Series, + date_range, + offsets, + timedelta_range, +) +import pandas._testing as tm +from pandas.tests.indexing.common import _mklbl +from pandas.tests.indexing.test_floats import gen_obj + +# ------------------------------------------------------------------------ +# Indexing test cases + + +class TestFancy: + """pure get/set item & fancy indexing""" + + def test_setitem_ndarray_1d(self): + # GH5508 + + # len of indexer vs length of the 1d ndarray + df = DataFrame(index=Index(np.arange(1, 11), dtype=np.int64)) + df["foo"] = np.zeros(10, dtype=np.float64) + df["bar"] = np.zeros(10, dtype=complex) + + # invalid + msg = "Must have equal len keys and value when setting with an iterable" + with pytest.raises(ValueError, match=msg): + df.loc[df.index[2:5], "bar"] = np.array([2.33j, 1.23 + 0.1j, 2.2, 1.0]) + + # valid + df.loc[df.index[2:6], "bar"] = np.array([2.33j, 1.23 + 0.1j, 2.2, 1.0]) + + result = df.loc[df.index[2:6], "bar"] + expected = Series( + [2.33j, 1.23 + 0.1j, 2.2, 1.0], index=[3, 4, 5, 6], name="bar" + ) + tm.assert_series_equal(result, expected) + + def test_setitem_ndarray_1d_2(self): + # GH5508 + + # dtype getting changed? + df = DataFrame(index=Index(np.arange(1, 11))) + df["foo"] = np.zeros(10, dtype=np.float64) + df["bar"] = np.zeros(10, dtype=complex) + + msg = "Must have equal len keys and value when setting with an iterable" + with pytest.raises(ValueError, match=msg): + df[2:5] = np.arange(1, 4) * 1j + + @pytest.mark.filterwarnings( + "ignore:Series.__getitem__ treating keys as positions is deprecated:" + "FutureWarning" + ) + def test_getitem_ndarray_3d( + self, index, frame_or_series, indexer_sli, using_array_manager + ): + # GH 25567 + obj = gen_obj(frame_or_series, index) + idxr = indexer_sli(obj) + nd3 = np.random.default_rng(2).integers(5, size=(2, 2, 2)) + + msgs = [] + if frame_or_series is Series and indexer_sli in [tm.setitem, tm.iloc]: + msgs.append(r"Wrong number of dimensions. values.ndim > ndim \[3 > 1\]") + if using_array_manager: + msgs.append("Passed array should be 1-dimensional") + if frame_or_series is Series or indexer_sli is tm.iloc: + msgs.append(r"Buffer has wrong number of dimensions \(expected 1, got 3\)") + if using_array_manager: + msgs.append("indexer should be 1-dimensional") + if indexer_sli is tm.loc or ( + frame_or_series is Series and indexer_sli is tm.setitem + ): + msgs.append("Cannot index with multidimensional key") + if frame_or_series is DataFrame and indexer_sli is tm.setitem: + msgs.append("Index data must be 1-dimensional") + if isinstance(index, pd.IntervalIndex) and indexer_sli is tm.iloc: + msgs.append("Index data must be 1-dimensional") + if isinstance(index, (pd.TimedeltaIndex, pd.DatetimeIndex, pd.PeriodIndex)): + msgs.append("Data must be 1-dimensional") + if len(index) == 0 or isinstance(index, pd.MultiIndex): + msgs.append("positional indexers are out-of-bounds") + if type(index) is Index and not isinstance(index._values, np.ndarray): + # e.g. Int64 + msgs.append("values must be a 1D array") + + # string[pyarrow] + msgs.append("only handle 1-dimensional arrays") + + msg = "|".join(msgs) + + potential_errors = (IndexError, ValueError, NotImplementedError) + with pytest.raises(potential_errors, match=msg): + idxr[nd3] + + @pytest.mark.filterwarnings( + "ignore:Series.__setitem__ treating keys as positions is deprecated:" + "FutureWarning" + ) + def test_setitem_ndarray_3d(self, index, frame_or_series, indexer_sli): + # GH 25567 + obj = gen_obj(frame_or_series, index) + idxr = indexer_sli(obj) + nd3 = np.random.default_rng(2).integers(5, size=(2, 2, 2)) + + if indexer_sli is tm.iloc: + err = ValueError + msg = f"Cannot set values with ndim > {obj.ndim}" + else: + err = ValueError + msg = "|".join( + [ + r"Buffer has wrong number of dimensions \(expected 1, got 3\)", + "Cannot set values with ndim > 1", + "Index data must be 1-dimensional", + "Data must be 1-dimensional", + "Array conditional must be same shape as self", + ] + ) + + with pytest.raises(err, match=msg): + idxr[nd3] = 0 + + def test_getitem_ndarray_0d(self): + # GH#24924 + key = np.array(0) + + # dataframe __getitem__ + df = DataFrame([[1, 2], [3, 4]]) + result = df[key] + expected = Series([1, 3], name=0) + tm.assert_series_equal(result, expected) + + # series __getitem__ + ser = Series([1, 2]) + result = ser[key] + assert result == 1 + + def test_inf_upcast(self): + # GH 16957 + # We should be able to use np.inf as a key + # np.inf should cause an index to convert to float + + # Test with np.inf in rows + df = DataFrame(columns=[0]) + df.loc[1] = 1 + df.loc[2] = 2 + df.loc[np.inf] = 3 + + # make sure we can look up the value + assert df.loc[np.inf, 0] == 3 + + result = df.index + expected = Index([1, 2, np.inf], dtype=np.float64) + tm.assert_index_equal(result, expected) + + def test_setitem_dtype_upcast(self): + # GH3216 + df = DataFrame([{"a": 1}, {"a": 3, "b": 2}]) + df["c"] = np.nan + assert df["c"].dtype == np.float64 + + with tm.assert_produces_warning( + FutureWarning, match="item of incompatible dtype" + ): + df.loc[0, "c"] = "foo" + expected = DataFrame( + {"a": [1, 3], "b": [np.nan, 2], "c": Series(["foo", np.nan], dtype=object)} + ) + tm.assert_frame_equal(df, expected) + + @pytest.mark.parametrize("val", [3.14, "wxyz"]) + def test_setitem_dtype_upcast2(self, val): + # GH10280 + df = DataFrame( + np.arange(6, dtype="int64").reshape(2, 3), + index=list("ab"), + columns=["foo", "bar", "baz"], + ) + + left = df.copy() + with tm.assert_produces_warning( + FutureWarning, match="item of incompatible dtype" + ): + left.loc["a", "bar"] = val + right = DataFrame( + [[0, val, 2], [3, 4, 5]], + index=list("ab"), + columns=["foo", "bar", "baz"], + ) + + tm.assert_frame_equal(left, right) + assert is_integer_dtype(left["foo"]) + assert is_integer_dtype(left["baz"]) + + def test_setitem_dtype_upcast3(self): + left = DataFrame( + np.arange(6, dtype="int64").reshape(2, 3) / 10.0, + index=list("ab"), + columns=["foo", "bar", "baz"], + ) + with tm.assert_produces_warning( + FutureWarning, match="item of incompatible dtype" + ): + left.loc["a", "bar"] = "wxyz" + + right = DataFrame( + [[0, "wxyz", 0.2], [0.3, 0.4, 0.5]], + index=list("ab"), + columns=["foo", "bar", "baz"], + ) + + tm.assert_frame_equal(left, right) + assert is_float_dtype(left["foo"]) + assert is_float_dtype(left["baz"]) + + def test_dups_fancy_indexing(self): + # GH 3455 + + df = DataFrame(np.eye(3), columns=["a", "a", "b"]) + result = df[["b", "a"]].columns + expected = Index(["b", "a", "a"]) + tm.assert_index_equal(result, expected) + + def test_dups_fancy_indexing_across_dtypes(self): + # across dtypes + df = DataFrame([[1, 2, 1.0, 2.0, 3.0, "foo", "bar"]], columns=list("aaaaaaa")) + result = DataFrame([[1, 2, 1.0, 2.0, 3.0, "foo", "bar"]]) + result.columns = list("aaaaaaa") # GH#3468 + + # GH#3509 smoke tests for indexing with duplicate columns + df.iloc[:, 4] + result.iloc[:, 4] + + tm.assert_frame_equal(df, result) + + def test_dups_fancy_indexing_not_in_order(self): + # GH 3561, dups not in selected order + df = DataFrame( + {"test": [5, 7, 9, 11], "test1": [4.0, 5, 6, 7], "other": list("abcd")}, + index=["A", "A", "B", "C"], + ) + rows = ["C", "B"] + expected = DataFrame( + {"test": [11, 9], "test1": [7.0, 6], "other": ["d", "c"]}, index=rows + ) + result = df.loc[rows] + tm.assert_frame_equal(result, expected) + + result = df.loc[Index(rows)] + tm.assert_frame_equal(result, expected) + + rows = ["C", "B", "E"] + with pytest.raises(KeyError, match="not in index"): + df.loc[rows] + + # see GH5553, make sure we use the right indexer + rows = ["F", "G", "H", "C", "B", "E"] + with pytest.raises(KeyError, match="not in index"): + df.loc[rows] + + def test_dups_fancy_indexing_only_missing_label(self, using_infer_string): + # List containing only missing label + dfnu = DataFrame( + np.random.default_rng(2).standard_normal((5, 3)), index=list("AABCD") + ) + if using_infer_string: + with pytest.raises( + KeyError, + match=re.escape( + "\"None of [Index(['E'], dtype='string')] are in the [index]\"" + ), + ): + dfnu.loc[["E"]] + else: + with pytest.raises( + KeyError, + match=re.escape( + "\"None of [Index(['E'], dtype='object')] are in the [index]\"" + ), + ): + dfnu.loc[["E"]] + + @pytest.mark.parametrize("vals", [[0, 1, 2], list("abc")]) + def test_dups_fancy_indexing_missing_label(self, vals): + # GH 4619; duplicate indexer with missing label + df = DataFrame({"A": vals}) + with pytest.raises(KeyError, match="not in index"): + df.loc[[0, 8, 0]] + + def test_dups_fancy_indexing_non_unique(self): + # non unique with non unique selector + df = DataFrame({"test": [5, 7, 9, 11]}, index=["A", "A", "B", "C"]) + with pytest.raises(KeyError, match="not in index"): + df.loc[["A", "A", "E"]] + + def test_dups_fancy_indexing2(self): + # GH 5835 + # dups on index and missing values + df = DataFrame( + np.random.default_rng(2).standard_normal((5, 5)), + columns=["A", "B", "B", "B", "A"], + ) + + with pytest.raises(KeyError, match="not in index"): + df.loc[:, ["A", "B", "C"]] + + def test_dups_fancy_indexing3(self): + # GH 6504, multi-axis indexing + df = DataFrame( + np.random.default_rng(2).standard_normal((9, 2)), + index=[1, 1, 1, 2, 2, 2, 3, 3, 3], + columns=["a", "b"], + ) + + expected = df.iloc[0:6] + result = df.loc[[1, 2]] + tm.assert_frame_equal(result, expected) + + expected = df + result = df.loc[:, ["a", "b"]] + tm.assert_frame_equal(result, expected) + + expected = df.iloc[0:6, :] + result = df.loc[[1, 2], ["a", "b"]] + tm.assert_frame_equal(result, expected) + + def test_duplicate_int_indexing(self, indexer_sl): + # GH 17347 + ser = Series(range(3), index=[1, 1, 3]) + expected = Series(range(2), index=[1, 1]) + result = indexer_sl(ser)[[1]] + tm.assert_series_equal(result, expected) + + def test_indexing_mixed_frame_bug(self): + # GH3492 + df = DataFrame( + {"a": {1: "aaa", 2: "bbb", 3: "ccc"}, "b": {1: 111, 2: 222, 3: 333}} + ) + + # this works, new column is created correctly + df["test"] = df["a"].apply(lambda x: "_" if x == "aaa" else x) + + # this does not work, ie column test is not changed + idx = df["test"] == "_" + temp = df.loc[idx, "a"].apply(lambda x: "-----" if x == "aaa" else x) + df.loc[idx, "test"] = temp + assert df.iloc[0, 2] == "-----" + + def test_multitype_list_index_access(self): + # GH 10610 + df = DataFrame( + np.random.default_rng(2).random((10, 5)), columns=["a"] + [20, 21, 22, 23] + ) + + with pytest.raises(KeyError, match=re.escape("'[26, -8] not in index'")): + df[[22, 26, -8]] + assert df[21].shape[0] == df.shape[0] + + def test_set_index_nan(self): + # GH 3586 + df = DataFrame( + { + "PRuid": { + 17: "nonQC", + 18: "nonQC", + 19: "nonQC", + 20: "10", + 21: "11", + 22: "12", + 23: "13", + 24: "24", + 25: "35", + 26: "46", + 27: "47", + 28: "48", + 29: "59", + 30: "10", + }, + "QC": { + 17: 0.0, + 18: 0.0, + 19: 0.0, + 20: np.nan, + 21: np.nan, + 22: np.nan, + 23: np.nan, + 24: 1.0, + 25: np.nan, + 26: np.nan, + 27: np.nan, + 28: np.nan, + 29: np.nan, + 30: np.nan, + }, + "data": { + 17: 7.9544899999999998, + 18: 8.0142609999999994, + 19: 7.8591520000000008, + 20: 0.86140349999999999, + 21: 0.87853110000000001, + 22: 0.8427041999999999, + 23: 0.78587700000000005, + 24: 0.73062459999999996, + 25: 0.81668560000000001, + 26: 0.81927080000000008, + 27: 0.80705009999999999, + 28: 0.81440240000000008, + 29: 0.80140849999999997, + 30: 0.81307740000000006, + }, + "year": { + 17: 2006, + 18: 2007, + 19: 2008, + 20: 1985, + 21: 1985, + 22: 1985, + 23: 1985, + 24: 1985, + 25: 1985, + 26: 1985, + 27: 1985, + 28: 1985, + 29: 1985, + 30: 1986, + }, + } + ).reset_index() + + result = ( + df.set_index(["year", "PRuid", "QC"]) + .reset_index() + .reindex(columns=df.columns) + ) + tm.assert_frame_equal(result, df) + + @pytest.mark.xfail( + using_pyarrow_string_dtype(), reason="can't multiply arrow strings" + ) + def test_multi_assign(self): + # GH 3626, an assignment of a sub-df to a df + # set float64 to avoid upcast when setting nan + df = DataFrame( + { + "FC": ["a", "b", "a", "b", "a", "b"], + "PF": [0, 0, 0, 0, 1, 1], + "col1": list(range(6)), + "col2": list(range(6, 12)), + } + ).astype({"col2": "float64"}) + df.iloc[1, 0] = np.nan + df2 = df.copy() + + mask = ~df2.FC.isna() + cols = ["col1", "col2"] + + dft = df2 * 2 + dft.iloc[3, 3] = np.nan + + expected = DataFrame( + { + "FC": ["a", np.nan, "a", "b", "a", "b"], + "PF": [0, 0, 0, 0, 1, 1], + "col1": Series([0, 1, 4, 6, 8, 10]), + "col2": [12, 7, 16, np.nan, 20, 22], + } + ) + + # frame on rhs + df2.loc[mask, cols] = dft.loc[mask, cols] + tm.assert_frame_equal(df2, expected) + + # with an ndarray on rhs + # coerces to float64 because values has float64 dtype + # GH 14001 + expected = DataFrame( + { + "FC": ["a", np.nan, "a", "b", "a", "b"], + "PF": [0, 0, 0, 0, 1, 1], + "col1": [0, 1, 4, 6, 8, 10], + "col2": [12, 7, 16, np.nan, 20, 22], + } + ) + df2 = df.copy() + df2.loc[mask, cols] = dft.loc[mask, cols].values + tm.assert_frame_equal(df2, expected) + + def test_multi_assign_broadcasting_rhs(self): + # broadcasting on the rhs is required + df = DataFrame( + { + "A": [1, 2, 0, 0, 0], + "B": [0, 0, 0, 10, 11], + "C": [0, 0, 0, 10, 11], + "D": [3, 4, 5, 6, 7], + } + ) + + expected = df.copy() + mask = expected["A"] == 0 + for col in ["A", "B"]: + expected.loc[mask, col] = df["D"] + + df.loc[df["A"] == 0, ["A", "B"]] = df["D"].copy() + tm.assert_frame_equal(df, expected) + + def test_setitem_list(self): + # GH 6043 + # iloc with a list + df = DataFrame(index=[0, 1], columns=[0]) + df.iloc[1, 0] = [1, 2, 3] + df.iloc[1, 0] = [1, 2] + + result = DataFrame(index=[0, 1], columns=[0]) + result.iloc[1, 0] = [1, 2] + + tm.assert_frame_equal(result, df) + + def test_string_slice(self): + # GH 14424 + # string indexing against datetimelike with object + # dtype should properly raises KeyError + df = DataFrame([1], Index([pd.Timestamp("2011-01-01")], dtype=object)) + assert df.index._is_all_dates + with pytest.raises(KeyError, match="'2011'"): + df["2011"] + + with pytest.raises(KeyError, match="'2011'"): + df.loc["2011", 0] + + def test_string_slice_empty(self): + # GH 14424 + + df = DataFrame() + assert not df.index._is_all_dates + with pytest.raises(KeyError, match="'2011'"): + df["2011"] + + with pytest.raises(KeyError, match="^0$"): + df.loc["2011", 0] + + def test_astype_assignment(self, using_infer_string): + # GH4312 (iloc) + df_orig = DataFrame( + [["1", "2", "3", ".4", 5, 6.0, "foo"]], columns=list("ABCDEFG") + ) + + df = df_orig.copy() + + # with the enforcement of GH#45333 in 2.0, this setting is attempted inplace, + # so object dtype is retained + df.iloc[:, 0:2] = df.iloc[:, 0:2].astype(np.int64) + expected = DataFrame( + [[1, 2, "3", ".4", 5, 6.0, "foo"]], columns=list("ABCDEFG") + ) + if not using_infer_string: + expected["A"] = expected["A"].astype(object) + expected["B"] = expected["B"].astype(object) + tm.assert_frame_equal(df, expected) + + # GH5702 (loc) + df = df_orig.copy() + df.loc[:, "A"] = df.loc[:, "A"].astype(np.int64) + expected = DataFrame( + [[1, "2", "3", ".4", 5, 6.0, "foo"]], columns=list("ABCDEFG") + ) + if not using_infer_string: + expected["A"] = expected["A"].astype(object) + tm.assert_frame_equal(df, expected) + + df = df_orig.copy() + df.loc[:, ["B", "C"]] = df.loc[:, ["B", "C"]].astype(np.int64) + expected = DataFrame( + [["1", 2, 3, ".4", 5, 6.0, "foo"]], columns=list("ABCDEFG") + ) + if not using_infer_string: + expected["B"] = expected["B"].astype(object) + expected["C"] = expected["C"].astype(object) + tm.assert_frame_equal(df, expected) + + def test_astype_assignment_full_replacements(self): + # full replacements / no nans + df = DataFrame({"A": [1.0, 2.0, 3.0, 4.0]}) + + # With the enforcement of GH#45333 in 2.0, this assignment occurs inplace, + # so float64 is retained + df.iloc[:, 0] = df["A"].astype(np.int64) + expected = DataFrame({"A": [1.0, 2.0, 3.0, 4.0]}) + tm.assert_frame_equal(df, expected) + + df = DataFrame({"A": [1.0, 2.0, 3.0, 4.0]}) + df.loc[:, "A"] = df["A"].astype(np.int64) + tm.assert_frame_equal(df, expected) + + @pytest.mark.parametrize("indexer", [tm.getitem, tm.loc]) + def test_index_type_coercion(self, indexer): + # GH 11836 + # if we have an index type and set it with something that looks + # to numpy like the same, but is actually, not + # (e.g. setting with a float or string '0') + # then we need to coerce to object + + # integer indexes + for s in [Series(range(5)), Series(range(5), index=range(1, 6))]: + assert is_integer_dtype(s.index) + + s2 = s.copy() + indexer(s2)[0.1] = 0 + assert is_float_dtype(s2.index) + assert indexer(s2)[0.1] == 0 + + s2 = s.copy() + indexer(s2)[0.0] = 0 + exp = s.index + if 0 not in s: + exp = Index(s.index.tolist() + [0]) + tm.assert_index_equal(s2.index, exp) + + s2 = s.copy() + indexer(s2)["0"] = 0 + assert is_object_dtype(s2.index) + + for s in [Series(range(5), index=np.arange(5.0))]: + assert is_float_dtype(s.index) + + s2 = s.copy() + indexer(s2)[0.1] = 0 + assert is_float_dtype(s2.index) + assert indexer(s2)[0.1] == 0 + + s2 = s.copy() + indexer(s2)[0.0] = 0 + tm.assert_index_equal(s2.index, s.index) + + s2 = s.copy() + indexer(s2)["0"] = 0 + assert is_object_dtype(s2.index) + + +class TestMisc: + def test_float_index_to_mixed(self): + df = DataFrame( + { + 0.0: np.random.default_rng(2).random(10), + 1.0: np.random.default_rng(2).random(10), + } + ) + df["a"] = 10 + + expected = DataFrame({0.0: df[0.0], 1.0: df[1.0], "a": [10] * 10}) + tm.assert_frame_equal(expected, df) + + def test_float_index_non_scalar_assignment(self): + df = DataFrame({"a": [1, 2, 3], "b": [3, 4, 5]}, index=[1.0, 2.0, 3.0]) + df.loc[df.index[:2]] = 1 + expected = DataFrame({"a": [1, 1, 3], "b": [1, 1, 5]}, index=df.index) + tm.assert_frame_equal(expected, df) + + def test_loc_setitem_fullindex_views(self): + df = DataFrame({"a": [1, 2, 3], "b": [3, 4, 5]}, index=[1.0, 2.0, 3.0]) + df2 = df.copy() + df.loc[df.index] = df.loc[df.index] + tm.assert_frame_equal(df, df2) + + @pytest.mark.xfail(using_pyarrow_string_dtype(), reason="can't set int into string") + def test_rhs_alignment(self): + # GH8258, tests that both rows & columns are aligned to what is + # assigned to. covers both uniform data-type & multi-type cases + def run_tests(df, rhs, right_loc, right_iloc): + # label, index, slice + lbl_one, idx_one, slice_one = list("bcd"), [1, 2, 3], slice(1, 4) + lbl_two, idx_two, slice_two = ["joe", "jolie"], [1, 2], slice(1, 3) + + left = df.copy() + left.loc[lbl_one, lbl_two] = rhs + tm.assert_frame_equal(left, right_loc) + + left = df.copy() + left.iloc[idx_one, idx_two] = rhs + tm.assert_frame_equal(left, right_iloc) + + left = df.copy() + left.iloc[slice_one, slice_two] = rhs + tm.assert_frame_equal(left, right_iloc) + + xs = np.arange(20).reshape(5, 4) + cols = ["jim", "joe", "jolie", "joline"] + df = DataFrame(xs, columns=cols, index=list("abcde"), dtype="int64") + + # right hand side; permute the indices and multiplpy by -2 + rhs = -2 * df.iloc[3:0:-1, 2:0:-1] + + # expected `right` result; just multiply by -2 + right_iloc = df.copy() + right_iloc["joe"] = [1, 14, 10, 6, 17] + right_iloc["jolie"] = [2, 13, 9, 5, 18] + right_iloc.iloc[1:4, 1:3] *= -2 + right_loc = df.copy() + right_loc.iloc[1:4, 1:3] *= -2 + + # run tests with uniform dtypes + run_tests(df, rhs, right_loc, right_iloc) + + # make frames multi-type & re-run tests + for frame in [df, rhs, right_loc, right_iloc]: + frame["joe"] = frame["joe"].astype("float64") + frame["jolie"] = frame["jolie"].map(lambda x: f"@{x}") + right_iloc["joe"] = [1.0, "@-28", "@-20", "@-12", 17.0] + right_iloc["jolie"] = ["@2", -26.0, -18.0, -10.0, "@18"] + with tm.assert_produces_warning(FutureWarning, match="incompatible dtype"): + run_tests(df, rhs, right_loc, right_iloc) + + @pytest.mark.parametrize( + "idx", [_mklbl("A", 20), np.arange(20) + 100, np.linspace(100, 150, 20)] + ) + def test_str_label_slicing_with_negative_step(self, idx): + SLC = pd.IndexSlice + + idx = Index(idx) + ser = Series(np.arange(20), index=idx) + tm.assert_indexing_slices_equivalent(ser, SLC[idx[9] :: -1], SLC[9::-1]) + tm.assert_indexing_slices_equivalent(ser, SLC[: idx[9] : -1], SLC[:8:-1]) + tm.assert_indexing_slices_equivalent( + ser, SLC[idx[13] : idx[9] : -1], SLC[13:8:-1] + ) + tm.assert_indexing_slices_equivalent(ser, SLC[idx[9] : idx[13] : -1], SLC[:0]) + + def test_slice_with_zero_step_raises(self, index, indexer_sl, frame_or_series): + obj = frame_or_series(np.arange(len(index)), index=index) + with pytest.raises(ValueError, match="slice step cannot be zero"): + indexer_sl(obj)[::0] + + def test_loc_setitem_indexing_assignment_dict_already_exists(self): + index = Index([-5, 0, 5], name="z") + df = DataFrame({"x": [1, 2, 6], "y": [2, 2, 8]}, index=index) + expected = df.copy() + rhs = {"x": 9, "y": 99} + df.loc[5] = rhs + expected.loc[5] = [9, 99] + tm.assert_frame_equal(df, expected) + + # GH#38335 same thing, mixed dtypes + df = DataFrame({"x": [1, 2, 6], "y": [2.0, 2.0, 8.0]}, index=index) + df.loc[5] = rhs + expected = DataFrame({"x": [1, 2, 9], "y": [2.0, 2.0, 99.0]}, index=index) + tm.assert_frame_equal(df, expected) + + def test_iloc_getitem_indexing_dtypes_on_empty(self): + # Check that .iloc returns correct dtypes GH9983 + df = DataFrame({"a": [1, 2, 3], "b": ["b", "b2", "b3"]}) + df2 = df.iloc[[], :] + + assert df2.loc[:, "a"].dtype == np.int64 + tm.assert_series_equal(df2.loc[:, "a"], df2.iloc[:, 0]) + + @pytest.mark.parametrize("size", [5, 999999, 1000000]) + def test_loc_range_in_series_indexing(self, size): + # range can cause an indexing error + # GH 11652 + s = Series(index=range(size), dtype=np.float64) + s.loc[range(1)] = 42 + tm.assert_series_equal(s.loc[range(1)], Series(42.0, index=[0])) + + s.loc[range(2)] = 43 + tm.assert_series_equal(s.loc[range(2)], Series(43.0, index=[0, 1])) + + def test_partial_boolean_frame_indexing(self): + # GH 17170 + df = DataFrame( + np.arange(9.0).reshape(3, 3), index=list("abc"), columns=list("ABC") + ) + index_df = DataFrame(1, index=list("ab"), columns=list("AB")) + result = df[index_df.notnull()] + expected = DataFrame( + np.array([[0.0, 1.0, np.nan], [3.0, 4.0, np.nan], [np.nan] * 3]), + index=list("abc"), + columns=list("ABC"), + ) + tm.assert_frame_equal(result, expected) + + def test_no_reference_cycle(self): + df = DataFrame({"a": [0, 1], "b": [2, 3]}) + for name in ("loc", "iloc", "at", "iat"): + getattr(df, name) + wr = weakref.ref(df) + del df + assert wr() is None + + def test_label_indexing_on_nan(self, nulls_fixture): + # GH 32431 + df = Series([1, "{1,2}", 1, nulls_fixture]) + vc = df.value_counts(dropna=False) + result1 = vc.loc[nulls_fixture] + result2 = vc[nulls_fixture] + + expected = 1 + assert result1 == expected + assert result2 == expected + + +class TestDataframeNoneCoercion: + EXPECTED_SINGLE_ROW_RESULTS = [ + # For numeric series, we should coerce to NaN. + ([1, 2, 3], [np.nan, 2, 3], FutureWarning), + ([1.0, 2.0, 3.0], [np.nan, 2.0, 3.0], None), + # For datetime series, we should coerce to NaT. + ( + [datetime(2000, 1, 1), datetime(2000, 1, 2), datetime(2000, 1, 3)], + [NaT, datetime(2000, 1, 2), datetime(2000, 1, 3)], + None, + ), + # For objects, we should preserve the None value. + (["foo", "bar", "baz"], [None, "bar", "baz"], None), + ] + + @pytest.mark.parametrize("expected", EXPECTED_SINGLE_ROW_RESULTS) + def test_coercion_with_loc(self, expected): + start_data, expected_result, warn = expected + + start_dataframe = DataFrame({"foo": start_data}) + start_dataframe.loc[0, ["foo"]] = None + + expected_dataframe = DataFrame({"foo": expected_result}) + tm.assert_frame_equal(start_dataframe, expected_dataframe) + + @pytest.mark.parametrize("expected", EXPECTED_SINGLE_ROW_RESULTS) + def test_coercion_with_setitem_and_dataframe(self, expected): + start_data, expected_result, warn = expected + + start_dataframe = DataFrame({"foo": start_data}) + start_dataframe[start_dataframe["foo"] == start_dataframe["foo"][0]] = None + + expected_dataframe = DataFrame({"foo": expected_result}) + tm.assert_frame_equal(start_dataframe, expected_dataframe) + + @pytest.mark.parametrize("expected", EXPECTED_SINGLE_ROW_RESULTS) + def test_none_coercion_loc_and_dataframe(self, expected): + start_data, expected_result, warn = expected + + start_dataframe = DataFrame({"foo": start_data}) + start_dataframe.loc[start_dataframe["foo"] == start_dataframe["foo"][0]] = None + + expected_dataframe = DataFrame({"foo": expected_result}) + tm.assert_frame_equal(start_dataframe, expected_dataframe) + + def test_none_coercion_mixed_dtypes(self): + start_dataframe = DataFrame( + { + "a": [1, 2, 3], + "b": [1.0, 2.0, 3.0], + "c": [datetime(2000, 1, 1), datetime(2000, 1, 2), datetime(2000, 1, 3)], + "d": ["a", "b", "c"], + } + ) + start_dataframe.iloc[0] = None + + exp = DataFrame( + { + "a": [np.nan, 2, 3], + "b": [np.nan, 2.0, 3.0], + "c": [NaT, datetime(2000, 1, 2), datetime(2000, 1, 3)], + "d": [None, "b", "c"], + } + ) + tm.assert_frame_equal(start_dataframe, exp) + + +class TestDatetimelikeCoercion: + def test_setitem_dt64_string_scalar(self, tz_naive_fixture, indexer_sli): + # dispatching _can_hold_element to underlying DatetimeArray + tz = tz_naive_fixture + + dti = date_range("2016-01-01", periods=3, tz=tz) + ser = Series(dti.copy(deep=True)) + + values = ser._values + + newval = "2018-01-01" + values._validate_setitem_value(newval) + + indexer_sli(ser)[0] = newval + + if tz is None: + # TODO(EA2D): we can make this no-copy in tz-naive case too + assert ser.dtype == dti.dtype + assert ser._values._ndarray is values._ndarray + else: + assert ser._values is values + + @pytest.mark.parametrize("box", [list, np.array, pd.array, pd.Categorical, Index]) + @pytest.mark.parametrize( + "key", [[0, 1], slice(0, 2), np.array([True, True, False])] + ) + def test_setitem_dt64_string_values(self, tz_naive_fixture, indexer_sli, key, box): + # dispatching _can_hold_element to underling DatetimeArray + tz = tz_naive_fixture + + if isinstance(key, slice) and indexer_sli is tm.loc: + key = slice(0, 1) + + dti = date_range("2016-01-01", periods=3, tz=tz) + ser = Series(dti.copy(deep=True)) + + values = ser._values + + newvals = box(["2019-01-01", "2010-01-02"]) + values._validate_setitem_value(newvals) + + indexer_sli(ser)[key] = newvals + + if tz is None: + # TODO(EA2D): we can make this no-copy in tz-naive case too + assert ser.dtype == dti.dtype + assert ser._values._ndarray is values._ndarray + else: + assert ser._values is values + + @pytest.mark.parametrize("scalar", ["3 Days", offsets.Hour(4)]) + def test_setitem_td64_scalar(self, indexer_sli, scalar): + # dispatching _can_hold_element to underling TimedeltaArray + tdi = timedelta_range("1 Day", periods=3) + ser = Series(tdi.copy(deep=True)) + + values = ser._values + values._validate_setitem_value(scalar) + + indexer_sli(ser)[0] = scalar + assert ser._values._ndarray is values._ndarray + + @pytest.mark.parametrize("box", [list, np.array, pd.array, pd.Categorical, Index]) + @pytest.mark.parametrize( + "key", [[0, 1], slice(0, 2), np.array([True, True, False])] + ) + def test_setitem_td64_string_values(self, indexer_sli, key, box): + # dispatching _can_hold_element to underling TimedeltaArray + if isinstance(key, slice) and indexer_sli is tm.loc: + key = slice(0, 1) + + tdi = timedelta_range("1 Day", periods=3) + ser = Series(tdi.copy(deep=True)) + + values = ser._values + + newvals = box(["10 Days", "44 hours"]) + values._validate_setitem_value(newvals) + + indexer_sli(ser)[key] = newvals + assert ser._values._ndarray is values._ndarray + + +def test_extension_array_cross_section(): + # A cross-section of a homogeneous EA should be an EA + df = DataFrame( + { + "A": pd.array([1, 2], dtype="Int64"), + "B": pd.array([3, 4], dtype="Int64"), + }, + index=["a", "b"], + ) + expected = Series(pd.array([1, 3], dtype="Int64"), index=["A", "B"], name="a") + result = df.loc["a"] + tm.assert_series_equal(result, expected) + + result = df.iloc[0] + tm.assert_series_equal(result, expected) + + +def test_extension_array_cross_section_converts(): + # all numeric columns -> numeric series + df = DataFrame( + { + "A": pd.array([1, 2], dtype="Int64"), + "B": np.array([1, 2], dtype="int64"), + }, + index=["a", "b"], + ) + result = df.loc["a"] + expected = Series([1, 1], dtype="Int64", index=["A", "B"], name="a") + tm.assert_series_equal(result, expected) + + result = df.iloc[0] + tm.assert_series_equal(result, expected) + + # mixed columns -> object series + df = DataFrame( + {"A": pd.array([1, 2], dtype="Int64"), "B": np.array(["a", "b"])}, + index=["a", "b"], + ) + result = df.loc["a"] + expected = Series([1, "a"], dtype=object, index=["A", "B"], name="a") + tm.assert_series_equal(result, expected) + + result = df.iloc[0] + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "ser, keys", + [(Series([10]), (0, 0)), (Series([1, 2, 3], index=list("abc")), (0, 1))], +) +def test_ser_tup_indexer_exceeds_dimensions(ser, keys, indexer_li): + # GH#13831 + exp_err, exp_msg = IndexingError, "Too many indexers" + with pytest.raises(exp_err, match=exp_msg): + indexer_li(ser)[keys] + + if indexer_li == tm.iloc: + # For iloc.__setitem__ we let numpy handle the error reporting. + exp_err, exp_msg = IndexError, "too many indices for array" + + with pytest.raises(exp_err, match=exp_msg): + indexer_li(ser)[keys] = 0 + + +def test_ser_list_indexer_exceeds_dimensions(indexer_li): + # GH#13831 + # Make sure an exception is raised when a tuple exceeds the dimension of the series, + # but not list when a list is used. + ser = Series([10]) + res = indexer_li(ser)[[0, 0]] + exp = Series([10, 10], index=Index([0, 0])) + tm.assert_series_equal(res, exp) + + +@pytest.mark.parametrize( + "value", [(0, 1), [0, 1], np.array([0, 1]), array.array("b", [0, 1])] +) +def test_scalar_setitem_with_nested_value(value): + # For numeric data, we try to unpack and thus raise for mismatching length + df = DataFrame({"A": [1, 2, 3]}) + msg = "|".join( + [ + "Must have equal len keys and value", + "setting an array element with a sequence", + ] + ) + with pytest.raises(ValueError, match=msg): + df.loc[0, "B"] = value + + # TODO For object dtype this happens as well, but should we rather preserve + # the nested data and set as such? + df = DataFrame({"A": [1, 2, 3], "B": np.array([1, "a", "b"], dtype=object)}) + with pytest.raises(ValueError, match="Must have equal len keys and value"): + df.loc[0, "B"] = value + # if isinstance(value, np.ndarray): + # assert (df.loc[0, "B"] == value).all() + # else: + # assert df.loc[0, "B"] == value + + +@pytest.mark.parametrize( + "value", [(0, 1), [0, 1], np.array([0, 1]), array.array("b", [0, 1])] +) +def test_scalar_setitem_series_with_nested_value(value, indexer_sli): + # For numeric data, we try to unpack and thus raise for mismatching length + ser = Series([1, 2, 3]) + with pytest.raises(ValueError, match="setting an array element with a sequence"): + indexer_sli(ser)[0] = value + + # but for object dtype we preserve the nested data and set as such + ser = Series([1, "a", "b"], dtype=object) + indexer_sli(ser)[0] = value + if isinstance(value, np.ndarray): + assert (ser.loc[0] == value).all() + else: + assert ser.loc[0] == value + + +@pytest.mark.parametrize( + "value", [(0.0,), [0.0], np.array([0.0]), array.array("d", [0.0])] +) +def test_scalar_setitem_with_nested_value_length1(value): + # https://github.com/pandas-dev/pandas/issues/46268 + + # For numeric data, assigning length-1 array to scalar position gets unpacked + df = DataFrame({"A": [1, 2, 3]}) + df.loc[0, "B"] = value + expected = DataFrame({"A": [1, 2, 3], "B": [0.0, np.nan, np.nan]}) + tm.assert_frame_equal(df, expected) + + # but for object dtype we preserve the nested data + df = DataFrame({"A": [1, 2, 3], "B": np.array([1, "a", "b"], dtype=object)}) + df.loc[0, "B"] = value + if isinstance(value, np.ndarray): + assert (df.loc[0, "B"] == value).all() + else: + assert df.loc[0, "B"] == value + + +@pytest.mark.parametrize( + "value", [(0.0,), [0.0], np.array([0.0]), array.array("d", [0.0])] +) +def test_scalar_setitem_series_with_nested_value_length1(value, indexer_sli): + # For numeric data, assigning length-1 array to scalar position gets unpacked + # TODO this only happens in case of ndarray, should we make this consistent + # for all list-likes? (as happens for DataFrame.(i)loc, see test above) + ser = Series([1.0, 2.0, 3.0]) + if isinstance(value, np.ndarray): + indexer_sli(ser)[0] = value + expected = Series([0.0, 2.0, 3.0]) + tm.assert_series_equal(ser, expected) + else: + with pytest.raises( + ValueError, match="setting an array element with a sequence" + ): + indexer_sli(ser)[0] = value + + # but for object dtype we preserve the nested data + ser = Series([1, "a", "b"], dtype=object) + indexer_sli(ser)[0] = value + if isinstance(value, np.ndarray): + assert (ser.loc[0] == value).all() + else: + assert ser.loc[0] == value + + +def test_object_dtype_series_set_series_element(): + # GH 48933 + s1 = Series(dtype="O", index=["a", "b"]) + + s1["a"] = Series() + s1.loc["b"] = Series() + + tm.assert_series_equal(s1.loc["a"], Series()) + tm.assert_series_equal(s1.loc["b"], Series()) + + s2 = Series(dtype="O", index=["a", "b"]) + + s2.iloc[1] = Series() + tm.assert_series_equal(s2.iloc[1], Series()) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/test_loc.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/test_loc.py new file mode 100644 index 0000000000000000000000000000000000000000..0cd1390d41461cedf3c11c6e8e633007d05e3ecb --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/test_loc.py @@ -0,0 +1,3366 @@ +""" test label based indexing with loc """ +from collections import namedtuple +from datetime import ( + date, + datetime, + time, + timedelta, +) +import re + +from dateutil.tz import gettz +import numpy as np +import pytest + +from pandas._config import using_pyarrow_string_dtype + +from pandas._libs import index as libindex +from pandas.compat.numpy import np_version_gt2 +from pandas.errors import IndexingError +import pandas.util._test_decorators as td + +import pandas as pd +from pandas import ( + Categorical, + CategoricalDtype, + CategoricalIndex, + DataFrame, + DatetimeIndex, + Index, + IndexSlice, + MultiIndex, + Period, + PeriodIndex, + Series, + SparseDtype, + Timedelta, + Timestamp, + date_range, + timedelta_range, + to_datetime, + to_timedelta, +) +import pandas._testing as tm +from pandas.api.types import is_scalar +from pandas.core.indexing import _one_ellipsis_message +from pandas.tests.indexing.common import check_indexing_smoketest_or_raises + + +@pytest.mark.parametrize( + "series, new_series, expected_ser", + [ + [[np.nan, np.nan, "b"], ["a", np.nan, np.nan], [False, True, True]], + [[np.nan, "b"], ["a", np.nan], [False, True]], + ], +) +def test_not_change_nan_loc(series, new_series, expected_ser): + # GH 28403 + df = DataFrame({"A": series}) + df.loc[:, "A"] = new_series + expected = DataFrame({"A": expected_ser}) + tm.assert_frame_equal(df.isna(), expected) + tm.assert_frame_equal(df.notna(), ~expected) + + +class TestLoc: + def test_none_values_on_string_columns(self): + # Issue #32218 + df = DataFrame(["1", "2", None], columns=["a"], dtype="str") + + assert df.loc[2, "a"] is None + + @pytest.mark.parametrize("kind", ["series", "frame"]) + def test_loc_getitem_int(self, kind, request): + # int label + obj = request.getfixturevalue(f"{kind}_labels") + check_indexing_smoketest_or_raises(obj, "loc", 2, fails=KeyError) + + @pytest.mark.parametrize("kind", ["series", "frame"]) + def test_loc_getitem_label(self, kind, request): + # label + obj = request.getfixturevalue(f"{kind}_empty") + check_indexing_smoketest_or_raises(obj, "loc", "c", fails=KeyError) + + @pytest.mark.parametrize( + "key, typs, axes", + [ + ["f", ["ints", "uints", "labels", "mixed", "ts"], None], + ["f", ["floats"], None], + [20, ["ints", "uints", "mixed"], None], + [20, ["labels"], None], + [20, ["ts"], 0], + [20, ["floats"], 0], + ], + ) + @pytest.mark.parametrize("kind", ["series", "frame"]) + def test_loc_getitem_label_out_of_range(self, key, typs, axes, kind, request): + for typ in typs: + obj = request.getfixturevalue(f"{kind}_{typ}") + # out of range label + check_indexing_smoketest_or_raises( + obj, "loc", key, axes=axes, fails=KeyError + ) + + @pytest.mark.parametrize( + "key, typs", + [ + [[0, 1, 2], ["ints", "uints", "floats"]], + [[1, 3.0, "A"], ["ints", "uints", "floats"]], + ], + ) + @pytest.mark.parametrize("kind", ["series", "frame"]) + def test_loc_getitem_label_list(self, key, typs, kind, request): + for typ in typs: + obj = request.getfixturevalue(f"{kind}_{typ}") + # list of labels + check_indexing_smoketest_or_raises(obj, "loc", key, fails=KeyError) + + @pytest.mark.parametrize( + "key, typs, axes", + [ + [[0, 1, 2], ["empty"], None], + [[0, 2, 10], ["ints", "uints", "floats"], 0], + [[3, 6, 7], ["ints", "uints", "floats"], 1], + # GH 17758 - MultiIndex and missing keys + [[(1, 3), (1, 4), (2, 5)], ["multi"], 0], + ], + ) + @pytest.mark.parametrize("kind", ["series", "frame"]) + def test_loc_getitem_label_list_with_missing(self, key, typs, axes, kind, request): + for typ in typs: + obj = request.getfixturevalue(f"{kind}_{typ}") + check_indexing_smoketest_or_raises( + obj, "loc", key, axes=axes, fails=KeyError + ) + + @pytest.mark.parametrize("typs", ["ints", "uints"]) + @pytest.mark.parametrize("kind", ["series", "frame"]) + def test_loc_getitem_label_list_fails(self, typs, kind, request): + # fails + obj = request.getfixturevalue(f"{kind}_{typs}") + check_indexing_smoketest_or_raises( + obj, "loc", [20, 30, 40], axes=1, fails=KeyError + ) + + def test_loc_getitem_label_array_like(self): + # TODO: test something? + # array like + pass + + @pytest.mark.parametrize("kind", ["series", "frame"]) + def test_loc_getitem_bool(self, kind, request): + obj = request.getfixturevalue(f"{kind}_empty") + # boolean indexers + b = [True, False, True, False] + + check_indexing_smoketest_or_raises(obj, "loc", b, fails=IndexError) + + @pytest.mark.parametrize( + "slc, typs, axes, fails", + [ + [ + slice(1, 3), + ["labels", "mixed", "empty", "ts", "floats"], + None, + TypeError, + ], + [slice("20130102", "20130104"), ["ts"], 1, TypeError], + [slice(2, 8), ["mixed"], 0, TypeError], + [slice(2, 8), ["mixed"], 1, KeyError], + [slice(2, 4, 2), ["mixed"], 0, TypeError], + ], + ) + @pytest.mark.parametrize("kind", ["series", "frame"]) + def test_loc_getitem_label_slice(self, slc, typs, axes, fails, kind, request): + # label slices (with ints) + + # real label slices + + # GH 14316 + for typ in typs: + obj = request.getfixturevalue(f"{kind}_{typ}") + check_indexing_smoketest_or_raises( + obj, + "loc", + slc, + axes=axes, + fails=fails, + ) + + def test_setitem_from_duplicate_axis(self): + # GH#34034 + df = DataFrame( + [[20, "a"], [200, "a"], [200, "a"]], + columns=["col1", "col2"], + index=[10, 1, 1], + ) + df.loc[1, "col1"] = np.arange(2) + expected = DataFrame( + [[20, "a"], [0, "a"], [1, "a"]], columns=["col1", "col2"], index=[10, 1, 1] + ) + tm.assert_frame_equal(df, expected) + + def test_column_types_consistent(self): + # GH 26779 + df = DataFrame( + data={ + "channel": [1, 2, 3], + "A": ["String 1", np.nan, "String 2"], + "B": [ + Timestamp("2019-06-11 11:00:00"), + pd.NaT, + Timestamp("2019-06-11 12:00:00"), + ], + } + ) + df2 = DataFrame( + data={"A": ["String 3"], "B": [Timestamp("2019-06-11 12:00:00")]} + ) + # Change Columns A and B to df2.values wherever Column A is NaN + df.loc[df["A"].isna(), ["A", "B"]] = df2.values + expected = DataFrame( + data={ + "channel": [1, 2, 3], + "A": ["String 1", "String 3", "String 2"], + "B": [ + Timestamp("2019-06-11 11:00:00"), + Timestamp("2019-06-11 12:00:00"), + Timestamp("2019-06-11 12:00:00"), + ], + } + ) + tm.assert_frame_equal(df, expected) + + @pytest.mark.parametrize( + "obj, key, exp", + [ + ( + DataFrame([[1]], columns=Index([False])), + IndexSlice[:, False], + Series([1], name=False), + ), + (Series([1], index=Index([False])), False, [1]), + (DataFrame([[1]], index=Index([False])), False, Series([1], name=False)), + ], + ) + def test_loc_getitem_single_boolean_arg(self, obj, key, exp): + # GH 44322 + res = obj.loc[key] + if isinstance(exp, (DataFrame, Series)): + tm.assert_equal(res, exp) + else: + assert res == exp + + +class TestLocBaseIndependent: + # Tests for loc that do not depend on subclassing Base + def test_loc_npstr(self): + # GH#45580 + df = DataFrame(index=date_range("2021", "2022")) + result = df.loc[np.array(["2021/6/1"])[0] :] + expected = df.iloc[151:] + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "msg, key", + [ + (r"Period\('2019', 'Y-DEC'\), 'foo', 'bar'", (Period(2019), "foo", "bar")), + (r"Period\('2019', 'Y-DEC'\), 'y1', 'bar'", (Period(2019), "y1", "bar")), + (r"Period\('2019', 'Y-DEC'\), 'foo', 'z1'", (Period(2019), "foo", "z1")), + ( + r"Period\('2018', 'Y-DEC'\), Period\('2016', 'Y-DEC'\), 'bar'", + (Period(2018), Period(2016), "bar"), + ), + (r"Period\('2018', 'Y-DEC'\), 'foo', 'y1'", (Period(2018), "foo", "y1")), + ( + r"Period\('2017', 'Y-DEC'\), 'foo', Period\('2015', 'Y-DEC'\)", + (Period(2017), "foo", Period(2015)), + ), + (r"Period\('2017', 'Y-DEC'\), 'z1', 'bar'", (Period(2017), "z1", "bar")), + ], + ) + def test_contains_raise_error_if_period_index_is_in_multi_index(self, msg, key): + # GH#20684 + """ + parse_datetime_string_with_reso return parameter if type not matched. + PeriodIndex.get_loc takes returned value from parse_datetime_string_with_reso + as a tuple. + If first argument is Period and a tuple has 3 items, + process go on not raise exception + """ + df = DataFrame( + { + "A": [Period(2019), "x1", "x2"], + "B": [Period(2018), Period(2016), "y1"], + "C": [Period(2017), "z1", Period(2015)], + "V1": [1, 2, 3], + "V2": [10, 20, 30], + } + ).set_index(["A", "B", "C"]) + with pytest.raises(KeyError, match=msg): + df.loc[key] + + def test_loc_getitem_missing_unicode_key(self): + df = DataFrame({"a": [1]}) + with pytest.raises(KeyError, match="\u05d0"): + df.loc[:, "\u05d0"] # should not raise UnicodeEncodeError + + def test_loc_getitem_dups(self): + # GH 5678 + # repeated getitems on a dup index returning a ndarray + df = DataFrame( + np.random.default_rng(2).random((20, 5)), + index=["ABCDE"[x % 5] for x in range(20)], + ) + expected = df.loc["A", 0] + result = df.loc[:, 0].loc["A"] + tm.assert_series_equal(result, expected) + + def test_loc_getitem_dups2(self): + # GH4726 + # dup indexing with iloc/loc + df = DataFrame( + [[1, 2, "foo", "bar", Timestamp("20130101")]], + columns=["a", "a", "a", "a", "a"], + index=[1], + ) + expected = Series( + [1, 2, "foo", "bar", Timestamp("20130101")], + index=["a", "a", "a", "a", "a"], + name=1, + ) + + result = df.iloc[0] + tm.assert_series_equal(result, expected) + + result = df.loc[1] + tm.assert_series_equal(result, expected) + + def test_loc_setitem_dups(self): + # GH 6541 + df_orig = DataFrame( + { + "me": list("rttti"), + "foo": list("aaade"), + "bar": np.arange(5, dtype="float64") * 1.34 + 2, + "bar2": np.arange(5, dtype="float64") * -0.34 + 2, + } + ).set_index("me") + + indexer = ( + "r", + ["bar", "bar2"], + ) + df = df_orig.copy() + df.loc[indexer] *= 2.0 + tm.assert_series_equal(df.loc[indexer], 2.0 * df_orig.loc[indexer]) + + indexer = ( + "r", + "bar", + ) + df = df_orig.copy() + df.loc[indexer] *= 2.0 + assert df.loc[indexer] == 2.0 * df_orig.loc[indexer] + + indexer = ( + "t", + ["bar", "bar2"], + ) + df = df_orig.copy() + df.loc[indexer] *= 2.0 + tm.assert_frame_equal(df.loc[indexer], 2.0 * df_orig.loc[indexer]) + + def test_loc_setitem_slice(self): + # GH10503 + + # assigning the same type should not change the type + df1 = DataFrame({"a": [0, 1, 1], "b": Series([100, 200, 300], dtype="uint32")}) + ix = df1["a"] == 1 + newb1 = df1.loc[ix, "b"] + 1 + df1.loc[ix, "b"] = newb1 + expected = DataFrame( + {"a": [0, 1, 1], "b": Series([100, 201, 301], dtype="uint32")} + ) + tm.assert_frame_equal(df1, expected) + + # assigning a new type should get the inferred type + df2 = DataFrame({"a": [0, 1, 1], "b": [100, 200, 300]}, dtype="uint64") + ix = df1["a"] == 1 + newb2 = df2.loc[ix, "b"] + with tm.assert_produces_warning( + FutureWarning, match="item of incompatible dtype" + ): + df1.loc[ix, "b"] = newb2 + expected = DataFrame({"a": [0, 1, 1], "b": [100, 200, 300]}, dtype="uint64") + tm.assert_frame_equal(df2, expected) + + def test_loc_setitem_dtype(self): + # GH31340 + df = DataFrame({"id": ["A"], "a": [1.2], "b": [0.0], "c": [-2.5]}) + cols = ["a", "b", "c"] + df.loc[:, cols] = df.loc[:, cols].astype("float32") + + # pre-2.0 this setting would swap in new arrays, in 2.0 it is correctly + # in-place, consistent with non-split-path + expected = DataFrame( + { + "id": ["A"], + "a": np.array([1.2], dtype="float64"), + "b": np.array([0.0], dtype="float64"), + "c": np.array([-2.5], dtype="float64"), + } + ) # id is inferred as object + + tm.assert_frame_equal(df, expected) + + def test_getitem_label_list_with_missing(self): + s = Series(range(3), index=["a", "b", "c"]) + + # consistency + with pytest.raises(KeyError, match="not in index"): + s[["a", "d"]] + + s = Series(range(3)) + with pytest.raises(KeyError, match="not in index"): + s[[0, 3]] + + @pytest.mark.parametrize("index", [[True, False], [True, False, True, False]]) + def test_loc_getitem_bool_diff_len(self, index): + # GH26658 + s = Series([1, 2, 3]) + msg = f"Boolean index has wrong length: {len(index)} instead of {len(s)}" + with pytest.raises(IndexError, match=msg): + s.loc[index] + + def test_loc_getitem_int_slice(self): + # TODO: test something here? + pass + + def test_loc_to_fail(self): + # GH3449 + df = DataFrame( + np.random.default_rng(2).random((3, 3)), + index=["a", "b", "c"], + columns=["e", "f", "g"], + ) + + msg = ( + rf"\"None of \[Index\(\[1, 2\], dtype='{np.dtype(int)}'\)\] are " + r"in the \[index\]\"" + ) + with pytest.raises(KeyError, match=msg): + df.loc[[1, 2], [1, 2]] + + def test_loc_to_fail2(self): + # GH 7496 + # loc should not fallback + + s = Series(dtype=object) + s.loc[1] = 1 + s.loc["a"] = 2 + + with pytest.raises(KeyError, match=r"^-1$"): + s.loc[-1] + + msg = ( + rf"\"None of \[Index\(\[-1, -2\], dtype='{np.dtype(int)}'\)\] are " + r"in the \[index\]\"" + ) + with pytest.raises(KeyError, match=msg): + s.loc[[-1, -2]] + + msg = r"\"None of \[Index\(\['4'\], dtype='object'\)\] are in the \[index\]\"" + with pytest.raises(KeyError, match=msg): + s.loc[Index(["4"], dtype=object)] + + s.loc[-1] = 3 + with pytest.raises(KeyError, match="not in index"): + s.loc[[-1, -2]] + + s["a"] = 2 + msg = ( + rf"\"None of \[Index\(\[-2\], dtype='{np.dtype(int)}'\)\] are " + r"in the \[index\]\"" + ) + with pytest.raises(KeyError, match=msg): + s.loc[[-2]] + + del s["a"] + + with pytest.raises(KeyError, match=msg): + s.loc[[-2]] = 0 + + def test_loc_to_fail3(self): + # inconsistency between .loc[values] and .loc[values,:] + # GH 7999 + df = DataFrame([["a"], ["b"]], index=[1, 2], columns=["value"]) + + msg = ( + rf"\"None of \[Index\(\[3\], dtype='{np.dtype(int)}'\)\] are " + r"in the \[index\]\"" + ) + with pytest.raises(KeyError, match=msg): + df.loc[[3], :] + + with pytest.raises(KeyError, match=msg): + df.loc[[3]] + + def test_loc_getitem_list_with_fail(self): + # 15747 + # should KeyError if *any* missing labels + + s = Series([1, 2, 3]) + + s.loc[[2]] + + msg = f"\"None of [Index([3], dtype='{np.dtype(int)}')] are in the [index]" + with pytest.raises(KeyError, match=re.escape(msg)): + s.loc[[3]] + + # a non-match and a match + with pytest.raises(KeyError, match="not in index"): + s.loc[[2, 3]] + + def test_loc_index(self): + # gh-17131 + # a boolean index should index like a boolean numpy array + + df = DataFrame( + np.random.default_rng(2).random(size=(5, 10)), + index=["alpha_0", "alpha_1", "alpha_2", "beta_0", "beta_1"], + ) + + mask = df.index.map(lambda x: "alpha" in x) + expected = df.loc[np.array(mask)] + + result = df.loc[mask] + tm.assert_frame_equal(result, expected) + + result = df.loc[mask.values] + tm.assert_frame_equal(result, expected) + + result = df.loc[pd.array(mask, dtype="boolean")] + tm.assert_frame_equal(result, expected) + + def test_loc_general(self): + df = DataFrame( + np.random.default_rng(2).random((4, 4)), + columns=["A", "B", "C", "D"], + index=["A", "B", "C", "D"], + ) + + # want this to work + result = df.loc[:, "A":"B"].iloc[0:2, :] + assert (result.columns == ["A", "B"]).all() + assert (result.index == ["A", "B"]).all() + + # mixed type + result = DataFrame({"a": [Timestamp("20130101")], "b": [1]}).iloc[0] + expected = Series([Timestamp("20130101"), 1], index=["a", "b"], name=0) + tm.assert_series_equal(result, expected) + assert result.dtype == object + + @pytest.fixture + def frame_for_consistency(self): + return DataFrame( + { + "date": date_range("2000-01-01", "2000-01-5"), + "val": Series(range(5), dtype=np.int64), + } + ) + + @pytest.mark.parametrize( + "val", + [0, np.array(0, dtype=np.int64), np.array([0, 0, 0, 0, 0], dtype=np.int64)], + ) + def test_loc_setitem_consistency(self, frame_for_consistency, val): + # GH 6149 + # coerce similarly for setitem and loc when rows have a null-slice + expected = DataFrame( + { + "date": Series(0, index=range(5), dtype=np.int64), + "val": Series(range(5), dtype=np.int64), + } + ) + df = frame_for_consistency.copy() + with tm.assert_produces_warning(FutureWarning, match="incompatible dtype"): + df.loc[:, "date"] = val + tm.assert_frame_equal(df, expected) + + def test_loc_setitem_consistency_dt64_to_str(self, frame_for_consistency): + # GH 6149 + # coerce similarly for setitem and loc when rows have a null-slice + + expected = DataFrame( + { + "date": Series("foo", index=range(5)), + "val": Series(range(5), dtype=np.int64), + } + ) + df = frame_for_consistency.copy() + with tm.assert_produces_warning(FutureWarning, match="incompatible dtype"): + df.loc[:, "date"] = "foo" + tm.assert_frame_equal(df, expected) + + def test_loc_setitem_consistency_dt64_to_float(self, frame_for_consistency): + # GH 6149 + # coerce similarly for setitem and loc when rows have a null-slice + expected = DataFrame( + { + "date": Series(1.0, index=range(5)), + "val": Series(range(5), dtype=np.int64), + } + ) + df = frame_for_consistency.copy() + with tm.assert_produces_warning(FutureWarning, match="incompatible dtype"): + df.loc[:, "date"] = 1.0 + tm.assert_frame_equal(df, expected) + + def test_loc_setitem_consistency_single_row(self): + # GH 15494 + # setting on frame with single row + df = DataFrame({"date": Series([Timestamp("20180101")])}) + with tm.assert_produces_warning(FutureWarning, match="incompatible dtype"): + df.loc[:, "date"] = "string" + expected = DataFrame({"date": Series(["string"])}) + tm.assert_frame_equal(df, expected) + + def test_loc_setitem_consistency_empty(self): + # empty (essentially noops) + # before the enforcement of #45333 in 2.0, the loc.setitem here would + # change the dtype of df.x to int64 + expected = DataFrame(columns=["x", "y"]) + df = DataFrame(columns=["x", "y"]) + with tm.assert_produces_warning(None): + df.loc[:, "x"] = 1 + tm.assert_frame_equal(df, expected) + + # setting with setitem swaps in a new array, so changes the dtype + df = DataFrame(columns=["x", "y"]) + df["x"] = 1 + expected["x"] = expected["x"].astype(np.int64) + tm.assert_frame_equal(df, expected) + + def test_loc_setitem_consistency_slice_column_len(self): + # .loc[:,column] setting with slice == len of the column + # GH10408 + levels = [ + ["Region_1"] * 4, + ["Site_1", "Site_1", "Site_2", "Site_2"], + [3987227376, 3980680971, 3977723249, 3977723089], + ] + mi = MultiIndex.from_arrays(levels, names=["Region", "Site", "RespondentID"]) + + clevels = [ + ["Respondent", "Respondent", "Respondent", "OtherCat", "OtherCat"], + ["Something", "StartDate", "EndDate", "Yes/No", "SomethingElse"], + ] + cols = MultiIndex.from_arrays(clevels, names=["Level_0", "Level_1"]) + + values = [ + ["A", "5/25/2015 10:59", "5/25/2015 11:22", "Yes", np.nan], + ["A", "5/21/2015 9:40", "5/21/2015 9:52", "Yes", "Yes"], + ["A", "5/20/2015 8:27", "5/20/2015 8:41", "Yes", np.nan], + ["A", "5/20/2015 8:33", "5/20/2015 9:09", "Yes", "No"], + ] + df = DataFrame(values, index=mi, columns=cols) + + df.loc[:, ("Respondent", "StartDate")] = to_datetime( + df.loc[:, ("Respondent", "StartDate")] + ) + df.loc[:, ("Respondent", "EndDate")] = to_datetime( + df.loc[:, ("Respondent", "EndDate")] + ) + df = df.infer_objects(copy=False) + + # Adding a new key + df.loc[:, ("Respondent", "Duration")] = ( + df.loc[:, ("Respondent", "EndDate")] + - df.loc[:, ("Respondent", "StartDate")] + ) + + # timedelta64[m] -> float, so this cannot be done inplace, so + # no warning + with tm.assert_produces_warning(FutureWarning, match="incompatible dtype"): + df.loc[:, ("Respondent", "Duration")] = df.loc[ + :, ("Respondent", "Duration") + ] / Timedelta(60_000_000_000) + + expected = Series( + [23.0, 12.0, 14.0, 36.0], index=df.index, name=("Respondent", "Duration") + ) + tm.assert_series_equal(df[("Respondent", "Duration")], expected) + + @pytest.mark.parametrize("unit", ["Y", "M", "D", "h", "m", "s", "ms", "us"]) + def test_loc_assign_non_ns_datetime(self, unit): + # GH 27395, non-ns dtype assignment via .loc should work + # and return the same result when using simple assignment + df = DataFrame( + { + "timestamp": [ + np.datetime64("2017-02-11 12:41:29"), + np.datetime64("1991-11-07 04:22:37"), + ] + } + ) + + df.loc[:, unit] = df.loc[:, "timestamp"].values.astype(f"datetime64[{unit}]") + df["expected"] = df.loc[:, "timestamp"].values.astype(f"datetime64[{unit}]") + expected = Series(df.loc[:, "expected"], name=unit) + tm.assert_series_equal(df.loc[:, unit], expected) + + def test_loc_modify_datetime(self): + # see gh-28837 + df = DataFrame.from_dict( + {"date": [1485264372711, 1485265925110, 1540215845888, 1540282121025]} + ) + + df["date_dt"] = to_datetime(df["date"], unit="ms", cache=True) + + df.loc[:, "date_dt_cp"] = df.loc[:, "date_dt"] + df.loc[[2, 3], "date_dt_cp"] = df.loc[[2, 3], "date_dt"] + + expected = DataFrame( + [ + [1485264372711, "2017-01-24 13:26:12.711", "2017-01-24 13:26:12.711"], + [1485265925110, "2017-01-24 13:52:05.110", "2017-01-24 13:52:05.110"], + [1540215845888, "2018-10-22 13:44:05.888", "2018-10-22 13:44:05.888"], + [1540282121025, "2018-10-23 08:08:41.025", "2018-10-23 08:08:41.025"], + ], + columns=["date", "date_dt", "date_dt_cp"], + ) + + columns = ["date_dt", "date_dt_cp"] + expected[columns] = expected[columns].apply(to_datetime) + + tm.assert_frame_equal(df, expected) + + def test_loc_setitem_frame_with_reindex(self): + # GH#6254 setting issue + df = DataFrame(index=[3, 5, 4], columns=["A"], dtype=float) + df.loc[[4, 3, 5], "A"] = np.array([1, 2, 3], dtype="int64") + + # setting integer values into a float dataframe with loc is inplace, + # so we retain float dtype + ser = Series([2, 3, 1], index=[3, 5, 4], dtype=float) + expected = DataFrame({"A": ser}) + tm.assert_frame_equal(df, expected) + + def test_loc_setitem_frame_with_reindex_mixed(self): + # GH#40480 + df = DataFrame(index=[3, 5, 4], columns=["A", "B"], dtype=float) + df["B"] = "string" + df.loc[[4, 3, 5], "A"] = np.array([1, 2, 3], dtype="int64") + ser = Series([2, 3, 1], index=[3, 5, 4], dtype="int64") + # pre-2.0 this setting swapped in a new array, now it is inplace + # consistent with non-split-path + expected = DataFrame({"A": ser.astype(float)}) + expected["B"] = "string" + tm.assert_frame_equal(df, expected) + + def test_loc_setitem_frame_with_inverted_slice(self): + # GH#40480 + df = DataFrame(index=[1, 2, 3], columns=["A", "B"], dtype=float) + df["B"] = "string" + df.loc[slice(3, 0, -1), "A"] = np.array([1, 2, 3], dtype="int64") + # pre-2.0 this setting swapped in a new array, now it is inplace + # consistent with non-split-path + expected = DataFrame({"A": [3.0, 2.0, 1.0], "B": "string"}, index=[1, 2, 3]) + tm.assert_frame_equal(df, expected) + + def test_loc_setitem_empty_frame(self): + # GH#6252 setting with an empty frame + keys1 = ["@" + str(i) for i in range(5)] + val1 = np.arange(5, dtype="int64") + + keys2 = ["@" + str(i) for i in range(4)] + val2 = np.arange(4, dtype="int64") + + index = list(set(keys1).union(keys2)) + df = DataFrame(index=index) + df["A"] = np.nan + df.loc[keys1, "A"] = val1 + + df["B"] = np.nan + df.loc[keys2, "B"] = val2 + + # Because df["A"] was initialized as float64, setting values into it + # is inplace, so that dtype is retained + sera = Series(val1, index=keys1, dtype=np.float64) + serb = Series(val2, index=keys2) + expected = DataFrame( + {"A": sera, "B": serb}, columns=Index(["A", "B"], dtype=object) + ).reindex(index=index) + tm.assert_frame_equal(df, expected) + + def test_loc_setitem_frame(self): + df = DataFrame( + np.random.default_rng(2).standard_normal((4, 4)), + index=list("abcd"), + columns=list("ABCD"), + ) + + result = df.iloc[0, 0] + + df.loc["a", "A"] = 1 + result = df.loc["a", "A"] + assert result == 1 + + result = df.iloc[0, 0] + assert result == 1 + + df.loc[:, "B":"D"] = 0 + expected = df.loc[:, "B":"D"] + result = df.iloc[:, 1:] + tm.assert_frame_equal(result, expected) + + def test_loc_setitem_frame_nan_int_coercion_invalid(self): + # GH 8669 + # invalid coercion of nan -> int + df = DataFrame({"A": [1, 2, 3], "B": np.nan}) + df.loc[df.B > df.A, "B"] = df.A + expected = DataFrame({"A": [1, 2, 3], "B": np.nan}) + tm.assert_frame_equal(df, expected) + + def test_loc_setitem_frame_mixed_labels(self): + # GH 6546 + # setting with mixed labels + df = DataFrame({1: [1, 2], 2: [3, 4], "a": ["a", "b"]}) + + result = df.loc[0, [1, 2]] + expected = Series( + [1, 3], index=Index([1, 2], dtype=object), dtype=object, name=0 + ) + tm.assert_series_equal(result, expected) + + expected = DataFrame({1: [5, 2], 2: [6, 4], "a": ["a", "b"]}) + df.loc[0, [1, 2]] = [5, 6] + tm.assert_frame_equal(df, expected) + + @pytest.mark.filterwarnings("ignore:Setting a value on a view:FutureWarning") + def test_loc_setitem_frame_multiples(self, warn_copy_on_write): + # multiple setting + df = DataFrame( + {"A": ["foo", "bar", "baz"], "B": Series(range(3), dtype=np.int64)} + ) + rhs = df.loc[1:2] + rhs.index = df.index[0:2] + df.loc[0:1] = rhs + expected = DataFrame( + {"A": ["bar", "baz", "baz"], "B": Series([1, 2, 2], dtype=np.int64)} + ) + tm.assert_frame_equal(df, expected) + + # multiple setting with frame on rhs (with M8) + df = DataFrame( + { + "date": date_range("2000-01-01", "2000-01-5"), + "val": Series(range(5), dtype=np.int64), + } + ) + expected = DataFrame( + { + "date": [ + Timestamp("20000101"), + Timestamp("20000102"), + Timestamp("20000101"), + Timestamp("20000102"), + Timestamp("20000103"), + ], + "val": Series([0, 1, 0, 1, 2], dtype=np.int64), + } + ) + rhs = df.loc[0:2] + rhs.index = df.index[2:5] + df.loc[2:4] = rhs + tm.assert_frame_equal(df, expected) + + @pytest.mark.parametrize( + "indexer", [["A"], slice(None, "A", None), np.array(["A"])] + ) + @pytest.mark.parametrize("value", [["Z"], np.array(["Z"])]) + def test_loc_setitem_with_scalar_index(self, indexer, value): + # GH #19474 + # assigning like "df.loc[0, ['A']] = ['Z']" should be evaluated + # elementwisely, not using "setter('A', ['Z'])". + + # Set object dtype to avoid upcast when setting 'Z' + df = DataFrame([[1, 2], [3, 4]], columns=["A", "B"]).astype({"A": object}) + df.loc[0, indexer] = value + result = df.loc[0, "A"] + + assert is_scalar(result) and result == "Z" + + @pytest.mark.parametrize( + "index,box,expected", + [ + ( + ([0, 2], ["A", "B", "C", "D"]), + 7, + DataFrame( + [[7, 7, 7, 7], [3, 4, np.nan, np.nan], [7, 7, 7, 7]], + columns=["A", "B", "C", "D"], + ), + ), + ( + (1, ["C", "D"]), + [7, 8], + DataFrame( + [[1, 2, np.nan, np.nan], [3, 4, 7, 8], [5, 6, np.nan, np.nan]], + columns=["A", "B", "C", "D"], + ), + ), + ( + (1, ["A", "B", "C"]), + np.array([7, 8, 9], dtype=np.int64), + DataFrame( + [[1, 2, np.nan], [7, 8, 9], [5, 6, np.nan]], columns=["A", "B", "C"] + ), + ), + ( + (slice(1, 3, None), ["B", "C", "D"]), + [[7, 8, 9], [10, 11, 12]], + DataFrame( + [[1, 2, np.nan, np.nan], [3, 7, 8, 9], [5, 10, 11, 12]], + columns=["A", "B", "C", "D"], + ), + ), + ( + (slice(1, 3, None), ["C", "A", "D"]), + np.array([[7, 8, 9], [10, 11, 12]], dtype=np.int64), + DataFrame( + [[1, 2, np.nan, np.nan], [8, 4, 7, 9], [11, 6, 10, 12]], + columns=["A", "B", "C", "D"], + ), + ), + ( + (slice(None, None, None), ["A", "C"]), + DataFrame([[7, 8], [9, 10], [11, 12]], columns=["A", "C"]), + DataFrame( + [[7, 2, 8], [9, 4, 10], [11, 6, 12]], columns=["A", "B", "C"] + ), + ), + ], + ) + def test_loc_setitem_missing_columns(self, index, box, expected): + # GH 29334 + df = DataFrame([[1, 2], [3, 4], [5, 6]], columns=["A", "B"]) + + df.loc[index] = box + tm.assert_frame_equal(df, expected) + + def test_loc_coercion(self): + # GH#12411 + df = DataFrame({"date": [Timestamp("20130101").tz_localize("UTC"), pd.NaT]}) + expected = df.dtypes + + result = df.iloc[[0]] + tm.assert_series_equal(result.dtypes, expected) + + result = df.iloc[[1]] + tm.assert_series_equal(result.dtypes, expected) + + def test_loc_coercion2(self): + # GH#12045 + df = DataFrame({"date": [datetime(2012, 1, 1), datetime(1012, 1, 2)]}) + expected = df.dtypes + + result = df.iloc[[0]] + tm.assert_series_equal(result.dtypes, expected) + + result = df.iloc[[1]] + tm.assert_series_equal(result.dtypes, expected) + + def test_loc_coercion3(self): + # GH#11594 + df = DataFrame({"text": ["some words"] + [None] * 9}) + expected = df.dtypes + + result = df.iloc[0:2] + tm.assert_series_equal(result.dtypes, expected) + + result = df.iloc[3:] + tm.assert_series_equal(result.dtypes, expected) + + def test_setitem_new_key_tz(self, indexer_sl): + # GH#12862 should not raise on assigning the second value + vals = [ + to_datetime(42).tz_localize("UTC"), + to_datetime(666).tz_localize("UTC"), + ] + expected = Series(vals, index=Index(["foo", "bar"], dtype=object)) + + ser = Series(dtype=object) + indexer_sl(ser)["foo"] = vals[0] + indexer_sl(ser)["bar"] = vals[1] + + tm.assert_series_equal(ser, expected) + + def test_loc_non_unique(self): + # GH3659 + # non-unique indexer with loc slice + # https://groups.google.com/forum/?fromgroups#!topic/pydata/zTm2No0crYs + + # these are going to raise because the we are non monotonic + df = DataFrame( + {"A": [1, 2, 3, 4, 5, 6], "B": [3, 4, 5, 6, 7, 8]}, index=[0, 1, 0, 1, 2, 3] + ) + msg = "'Cannot get left slice bound for non-unique label: 1'" + with pytest.raises(KeyError, match=msg): + df.loc[1:] + msg = "'Cannot get left slice bound for non-unique label: 0'" + with pytest.raises(KeyError, match=msg): + df.loc[0:] + msg = "'Cannot get left slice bound for non-unique label: 1'" + with pytest.raises(KeyError, match=msg): + df.loc[1:2] + + # monotonic are ok + df = DataFrame( + {"A": [1, 2, 3, 4, 5, 6], "B": [3, 4, 5, 6, 7, 8]}, index=[0, 1, 0, 1, 2, 3] + ).sort_index(axis=0) + result = df.loc[1:] + expected = DataFrame({"A": [2, 4, 5, 6], "B": [4, 6, 7, 8]}, index=[1, 1, 2, 3]) + tm.assert_frame_equal(result, expected) + + result = df.loc[0:] + tm.assert_frame_equal(result, df) + + result = df.loc[1:2] + expected = DataFrame({"A": [2, 4, 5], "B": [4, 6, 7]}, index=[1, 1, 2]) + tm.assert_frame_equal(result, expected) + + @pytest.mark.arm_slow + @pytest.mark.parametrize("length, l2", [[900, 100], [900000, 100000]]) + def test_loc_non_unique_memory_error(self, length, l2): + # GH 4280 + # non_unique index with a large selection triggers a memory error + + columns = list("ABCDEFG") + + df = pd.concat( + [ + DataFrame( + np.random.default_rng(2).standard_normal((length, len(columns))), + index=np.arange(length), + columns=columns, + ), + DataFrame(np.ones((l2, len(columns))), index=[0] * l2, columns=columns), + ] + ) + + assert df.index.is_unique is False + + mask = np.arange(l2) + result = df.loc[mask] + expected = pd.concat( + [ + df.take([0]), + DataFrame( + np.ones((len(mask), len(columns))), + index=[0] * len(mask), + columns=columns, + ), + df.take(mask[1:]), + ] + ) + tm.assert_frame_equal(result, expected) + + def test_loc_name(self): + # GH 3880 + df = DataFrame([[1, 1], [1, 1]]) + df.index.name = "index_name" + result = df.iloc[[0, 1]].index.name + assert result == "index_name" + + result = df.loc[[0, 1]].index.name + assert result == "index_name" + + def test_loc_empty_list_indexer_is_ok(self): + df = DataFrame( + np.ones((5, 2)), + index=Index([f"i-{i}" for i in range(5)], name="a"), + columns=Index([f"i-{i}" for i in range(2)], name="a"), + ) + # vertical empty + tm.assert_frame_equal( + df.loc[:, []], df.iloc[:, :0], check_index_type=True, check_column_type=True + ) + # horizontal empty + tm.assert_frame_equal( + df.loc[[], :], df.iloc[:0, :], check_index_type=True, check_column_type=True + ) + # horizontal empty + tm.assert_frame_equal( + df.loc[[]], df.iloc[:0, :], check_index_type=True, check_column_type=True + ) + + def test_identity_slice_returns_new_object( + self, using_copy_on_write, warn_copy_on_write + ): + # GH13873 + + original_df = DataFrame({"a": [1, 2, 3]}) + sliced_df = original_df.loc[:] + assert sliced_df is not original_df + assert original_df[:] is not original_df + assert original_df.loc[:, :] is not original_df + + # should be a shallow copy + assert np.shares_memory(original_df["a"]._values, sliced_df["a"]._values) + + # Setting using .loc[:, "a"] sets inplace so alters both sliced and orig + # depending on CoW + with tm.assert_cow_warning(warn_copy_on_write): + original_df.loc[:, "a"] = [4, 4, 4] + if using_copy_on_write: + assert (sliced_df["a"] == [1, 2, 3]).all() + else: + assert (sliced_df["a"] == 4).all() + + # These should not return copies + df = DataFrame(np.random.default_rng(2).standard_normal((10, 4))) + if using_copy_on_write or warn_copy_on_write: + assert df[0] is not df.loc[:, 0] + else: + assert df[0] is df.loc[:, 0] + + # Same tests for Series + original_series = Series([1, 2, 3, 4, 5, 6]) + sliced_series = original_series.loc[:] + assert sliced_series is not original_series + assert original_series[:] is not original_series + + with tm.assert_cow_warning(warn_copy_on_write): + original_series[:3] = [7, 8, 9] + if using_copy_on_write: + assert all(sliced_series[:3] == [1, 2, 3]) + else: + assert all(sliced_series[:3] == [7, 8, 9]) + + def test_loc_copy_vs_view(self, request, using_copy_on_write): + # GH 15631 + + if not using_copy_on_write: + mark = pytest.mark.xfail(reason="accidental fix reverted - GH37497") + request.applymarker(mark) + x = DataFrame(zip(range(3), range(3)), columns=["a", "b"]) + + y = x.copy() + q = y.loc[:, "a"] + q += 2 + + tm.assert_frame_equal(x, y) + + z = x.copy() + q = z.loc[x.index, "a"] + q += 2 + + tm.assert_frame_equal(x, z) + + def test_loc_uint64(self): + # GH20722 + # Test whether loc accept uint64 max value as index. + umax = np.iinfo("uint64").max + ser = Series([1, 2], index=[umax - 1, umax]) + + result = ser.loc[umax - 1] + expected = ser.iloc[0] + assert result == expected + + result = ser.loc[[umax - 1]] + expected = ser.iloc[[0]] + tm.assert_series_equal(result, expected) + + result = ser.loc[[umax - 1, umax]] + tm.assert_series_equal(result, ser) + + def test_loc_uint64_disallow_negative(self): + # GH#41775 + umax = np.iinfo("uint64").max + ser = Series([1, 2], index=[umax - 1, umax]) + + with pytest.raises(KeyError, match="-1"): + # don't wrap around + ser.loc[-1] + + with pytest.raises(KeyError, match="-1"): + # don't wrap around + ser.loc[[-1]] + + def test_loc_setitem_empty_append_expands_rows(self): + # GH6173, various appends to an empty dataframe + + data = [1, 2, 3] + expected = DataFrame( + {"x": data, "y": np.array([np.nan] * len(data), dtype=object)} + ) + + # appends to fit length of data + df = DataFrame(columns=["x", "y"]) + df.loc[:, "x"] = data + tm.assert_frame_equal(df, expected) + + def test_loc_setitem_empty_append_expands_rows_mixed_dtype(self): + # GH#37932 same as test_loc_setitem_empty_append_expands_rows + # but with mixed dtype so we go through take_split_path + data = [1, 2, 3] + expected = DataFrame( + {"x": data, "y": np.array([np.nan] * len(data), dtype=object)} + ) + + df = DataFrame(columns=["x", "y"]) + df["x"] = df["x"].astype(np.int64) + df.loc[:, "x"] = data + tm.assert_frame_equal(df, expected) + + def test_loc_setitem_empty_append_single_value(self): + # only appends one value + expected = DataFrame({"x": [1.0], "y": [np.nan]}) + df = DataFrame(columns=["x", "y"], dtype=float) + df.loc[0, "x"] = expected.loc[0, "x"] + tm.assert_frame_equal(df, expected) + + def test_loc_setitem_empty_append_raises(self): + # GH6173, various appends to an empty dataframe + + data = [1, 2] + df = DataFrame(columns=["x", "y"]) + df.index = df.index.astype(np.int64) + msg = ( + rf"None of \[Index\(\[0, 1\], dtype='{np.dtype(int)}'\)\] " + r"are in the \[index\]" + ) + with pytest.raises(KeyError, match=msg): + df.loc[[0, 1], "x"] = data + + msg = "setting an array element with a sequence." + with pytest.raises(ValueError, match=msg): + df.loc[0:2, "x"] = data + + def test_indexing_zerodim_np_array(self): + # GH24924 + df = DataFrame([[1, 2], [3, 4]]) + result = df.loc[np.array(0)] + s = Series([1, 2], name=0) + tm.assert_series_equal(result, s) + + def test_series_indexing_zerodim_np_array(self): + # GH24924 + s = Series([1, 2]) + result = s.loc[np.array(0)] + assert result == 1 + + def test_loc_reverse_assignment(self): + # GH26939 + data = [1, 2, 3, 4, 5, 6] + [None] * 4 + expected = Series(data, index=range(2010, 2020)) + + result = Series(index=range(2010, 2020), dtype=np.float64) + result.loc[2015:2010:-1] = [6, 5, 4, 3, 2, 1] + + tm.assert_series_equal(result, expected) + + @pytest.mark.xfail(using_pyarrow_string_dtype(), reason="can't set int into string") + def test_loc_setitem_str_to_small_float_conversion_type(self): + # GH#20388 + + col_data = [str(np.random.default_rng(2).random() * 1e-12) for _ in range(5)] + result = DataFrame(col_data, columns=["A"]) + expected = DataFrame(col_data, columns=["A"], dtype=object) + tm.assert_frame_equal(result, expected) + + # assigning with loc/iloc attempts to set the values inplace, which + # in this case is successful + result.loc[result.index, "A"] = [float(x) for x in col_data] + expected = DataFrame(col_data, columns=["A"], dtype=float).astype(object) + tm.assert_frame_equal(result, expected) + + # assigning the entire column using __setitem__ swaps in the new array + # GH#??? + result["A"] = [float(x) for x in col_data] + expected = DataFrame(col_data, columns=["A"], dtype=float) + tm.assert_frame_equal(result, expected) + + def test_loc_getitem_time_object(self, frame_or_series): + rng = date_range("1/1/2000", "1/5/2000", freq="5min") + mask = (rng.hour == 9) & (rng.minute == 30) + + obj = DataFrame( + np.random.default_rng(2).standard_normal((len(rng), 3)), index=rng + ) + obj = tm.get_obj(obj, frame_or_series) + + result = obj.loc[time(9, 30)] + exp = obj.loc[mask] + tm.assert_equal(result, exp) + + chunk = obj.loc["1/4/2000":] + result = chunk.loc[time(9, 30)] + expected = result[-1:] + + # Without resetting the freqs, these are 5 min and 1440 min, respectively + result.index = result.index._with_freq(None) + expected.index = expected.index._with_freq(None) + tm.assert_equal(result, expected) + + @pytest.mark.parametrize("spmatrix_t", ["coo_matrix", "csc_matrix", "csr_matrix"]) + @pytest.mark.parametrize("dtype", [np.int64, np.float64, complex]) + def test_loc_getitem_range_from_spmatrix(self, spmatrix_t, dtype): + sp_sparse = pytest.importorskip("scipy.sparse") + + spmatrix_t = getattr(sp_sparse, spmatrix_t) + + # The bug is triggered by a sparse matrix with purely sparse columns. So the + # recipe below generates a rectangular matrix of dimension (5, 7) where all the + # diagonal cells are ones, meaning the last two columns are purely sparse. + rows, cols = 5, 7 + spmatrix = spmatrix_t(np.eye(rows, cols, dtype=dtype), dtype=dtype) + df = DataFrame.sparse.from_spmatrix(spmatrix) + + # regression test for GH#34526 + itr_idx = range(2, rows) + result = df.loc[itr_idx].values + expected = spmatrix.toarray()[itr_idx] + tm.assert_numpy_array_equal(result, expected) + + # regression test for GH#34540 + result = df.loc[itr_idx].dtypes.values + expected = np.full(cols, SparseDtype(dtype, fill_value=0)) + tm.assert_numpy_array_equal(result, expected) + + def test_loc_getitem_listlike_all_retains_sparse(self): + df = DataFrame({"A": pd.array([0, 0], dtype=SparseDtype("int64"))}) + result = df.loc[[0, 1]] + tm.assert_frame_equal(result, df) + + def test_loc_getitem_sparse_frame(self): + # GH34687 + sp_sparse = pytest.importorskip("scipy.sparse") + + df = DataFrame.sparse.from_spmatrix(sp_sparse.eye(5)) + result = df.loc[range(2)] + expected = DataFrame( + [[1.0, 0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0, 0.0]], + dtype=SparseDtype("float64", 0.0), + ) + tm.assert_frame_equal(result, expected) + + result = df.loc[range(2)].loc[range(1)] + expected = DataFrame( + [[1.0, 0.0, 0.0, 0.0, 0.0]], dtype=SparseDtype("float64", 0.0) + ) + tm.assert_frame_equal(result, expected) + + def test_loc_getitem_sparse_series(self): + # GH34687 + s = Series([1.0, 0.0, 0.0, 0.0, 0.0], dtype=SparseDtype("float64", 0.0)) + + result = s.loc[range(2)] + expected = Series([1.0, 0.0], dtype=SparseDtype("float64", 0.0)) + tm.assert_series_equal(result, expected) + + result = s.loc[range(3)].loc[range(2)] + expected = Series([1.0, 0.0], dtype=SparseDtype("float64", 0.0)) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("indexer", ["loc", "iloc"]) + def test_getitem_single_row_sparse_df(self, indexer): + # GH#46406 + df = DataFrame([[1.0, 0.0, 1.5], [0.0, 2.0, 0.0]], dtype=SparseDtype(float)) + result = getattr(df, indexer)[0] + expected = Series([1.0, 0.0, 1.5], dtype=SparseDtype(float), name=0) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("key_type", [iter, np.array, Series, Index]) + def test_loc_getitem_iterable(self, float_frame, key_type): + idx = key_type(["A", "B", "C"]) + result = float_frame.loc[:, idx] + expected = float_frame.loc[:, ["A", "B", "C"]] + tm.assert_frame_equal(result, expected) + + def test_loc_getitem_timedelta_0seconds(self): + # GH#10583 + df = DataFrame(np.random.default_rng(2).normal(size=(10, 4))) + df.index = timedelta_range(start="0s", periods=10, freq="s") + expected = df.loc[Timedelta("0s") :, :] + result = df.loc["0s":, :] + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "val,expected", [(2**63 - 1, Series([1])), (2**63, Series([2]))] + ) + def test_loc_getitem_uint64_scalar(self, val, expected): + # see GH#19399 + df = DataFrame([1, 2], index=[2**63 - 1, 2**63]) + result = df.loc[val] + + expected.name = val + tm.assert_series_equal(result, expected) + + def test_loc_setitem_int_label_with_float_index(self, float_numpy_dtype): + # note labels are floats + dtype = float_numpy_dtype + ser = Series(["a", "b", "c"], index=Index([0, 0.5, 1], dtype=dtype)) + expected = ser.copy() + + ser.loc[1] = "zoo" + expected.iloc[2] = "zoo" + + tm.assert_series_equal(ser, expected) + + @pytest.mark.parametrize( + "indexer, expected", + [ + # The test name is a misnomer in the 0 case as df.index[indexer] + # is a scalar. + (0, [20, 1, 2, 3, 4, 5, 6, 7, 8, 9]), + (slice(4, 8), [0, 1, 2, 3, 20, 20, 20, 20, 8, 9]), + ([3, 5], [0, 1, 2, 20, 4, 20, 6, 7, 8, 9]), + ], + ) + def test_loc_setitem_listlike_with_timedelta64index(self, indexer, expected): + # GH#16637 + tdi = to_timedelta(range(10), unit="s") + df = DataFrame({"x": range(10)}, dtype="int64", index=tdi) + + df.loc[df.index[indexer], "x"] = 20 + + expected = DataFrame( + expected, + index=tdi, + columns=["x"], + dtype="int64", + ) + + tm.assert_frame_equal(expected, df) + + def test_loc_setitem_categorical_values_partial_column_slice(self): + # Assigning a Category to parts of a int/... column uses the values of + # the Categorical + df = DataFrame({"a": [1, 1, 1, 1, 1], "b": list("aaaaa")}) + exp = DataFrame({"a": [1, "b", "b", 1, 1], "b": list("aabba")}) + with tm.assert_produces_warning( + FutureWarning, match="item of incompatible dtype" + ): + df.loc[1:2, "a"] = Categorical(["b", "b"], categories=["a", "b"]) + df.loc[2:3, "b"] = Categorical(["b", "b"], categories=["a", "b"]) + tm.assert_frame_equal(df, exp) + + def test_loc_setitem_single_row_categorical(self, using_infer_string): + # GH#25495 + df = DataFrame({"Alpha": ["a"], "Numeric": [0]}) + categories = Categorical(df["Alpha"], categories=["a", "b", "c"]) + + # pre-2.0 this swapped in a new array, in 2.0 it operates inplace, + # consistent with non-split-path + df.loc[:, "Alpha"] = categories + + result = df["Alpha"] + expected = Series(categories, index=df.index, name="Alpha").astype( + object if not using_infer_string else "string[pyarrow_numpy]" + ) + tm.assert_series_equal(result, expected) + + # double-check that the non-loc setting retains categoricalness + df["Alpha"] = categories + tm.assert_series_equal(df["Alpha"], Series(categories, name="Alpha")) + + def test_loc_setitem_datetime_coercion(self): + # GH#1048 + df = DataFrame({"c": [Timestamp("2010-10-01")] * 3}) + df.loc[0:1, "c"] = np.datetime64("2008-08-08") + assert Timestamp("2008-08-08") == df.loc[0, "c"] + assert Timestamp("2008-08-08") == df.loc[1, "c"] + with tm.assert_produces_warning(FutureWarning, match="incompatible dtype"): + df.loc[2, "c"] = date(2005, 5, 5) + assert Timestamp("2005-05-05").date() == df.loc[2, "c"] + + @pytest.mark.parametrize("idxer", ["var", ["var"]]) + def test_loc_setitem_datetimeindex_tz(self, idxer, tz_naive_fixture): + # GH#11365 + tz = tz_naive_fixture + idx = date_range(start="2015-07-12", periods=3, freq="h", tz=tz) + expected = DataFrame(1.2, index=idx, columns=["var"]) + # if result started off with object dtype, then the .loc.__setitem__ + # below would retain object dtype + result = DataFrame(index=idx, columns=["var"], dtype=np.float64) + with tm.assert_produces_warning( + FutureWarning if idxer == "var" else None, match="incompatible dtype" + ): + # See https://github.com/pandas-dev/pandas/issues/56223 + result.loc[:, idxer] = expected + tm.assert_frame_equal(result, expected) + + def test_loc_setitem_time_key(self, using_array_manager): + index = date_range("2012-01-01", "2012-01-05", freq="30min") + df = DataFrame( + np.random.default_rng(2).standard_normal((len(index), 5)), index=index + ) + akey = time(12, 0, 0) + bkey = slice(time(13, 0, 0), time(14, 0, 0)) + ainds = [24, 72, 120, 168] + binds = [26, 27, 28, 74, 75, 76, 122, 123, 124, 170, 171, 172] + + result = df.copy() + result.loc[akey] = 0 + result = result.loc[akey] + expected = df.loc[akey].copy() + expected.loc[:] = 0 + if using_array_manager: + # TODO(ArrayManager) we are still overwriting columns + expected = expected.astype(float) + tm.assert_frame_equal(result, expected) + + result = df.copy() + result.loc[akey] = 0 + result.loc[akey] = df.iloc[ainds] + tm.assert_frame_equal(result, df) + + result = df.copy() + result.loc[bkey] = 0 + result = result.loc[bkey] + expected = df.loc[bkey].copy() + expected.loc[:] = 0 + if using_array_manager: + # TODO(ArrayManager) we are still overwriting columns + expected = expected.astype(float) + tm.assert_frame_equal(result, expected) + + result = df.copy() + result.loc[bkey] = 0 + result.loc[bkey] = df.iloc[binds] + tm.assert_frame_equal(result, df) + + @pytest.mark.parametrize("key", ["A", ["A"], ("A", slice(None))]) + def test_loc_setitem_unsorted_multiindex_columns(self, key): + # GH#38601 + mi = MultiIndex.from_tuples([("A", 4), ("B", "3"), ("A", "2")]) + df = DataFrame([[1, 2, 3], [4, 5, 6]], columns=mi) + obj = df.copy() + obj.loc[:, key] = np.zeros((2, 2), dtype="int64") + expected = DataFrame([[0, 2, 0], [0, 5, 0]], columns=mi) + tm.assert_frame_equal(obj, expected) + + df = df.sort_index(axis=1) + df.loc[:, key] = np.zeros((2, 2), dtype="int64") + expected = expected.sort_index(axis=1) + tm.assert_frame_equal(df, expected) + + def test_loc_setitem_uint_drop(self, any_int_numpy_dtype): + # see GH#18311 + # assigning series.loc[0] = 4 changed series.dtype to int + series = Series([1, 2, 3], dtype=any_int_numpy_dtype) + series.loc[0] = 4 + expected = Series([4, 2, 3], dtype=any_int_numpy_dtype) + tm.assert_series_equal(series, expected) + + def test_loc_setitem_td64_non_nano(self): + # GH#14155 + ser = Series(10 * [np.timedelta64(10, "m")]) + ser.loc[[1, 2, 3]] = np.timedelta64(20, "m") + expected = Series(10 * [np.timedelta64(10, "m")]) + expected.loc[[1, 2, 3]] = Timedelta(np.timedelta64(20, "m")) + tm.assert_series_equal(ser, expected) + + def test_loc_setitem_2d_to_1d_raises(self): + data = np.random.default_rng(2).standard_normal((2, 2)) + # float64 dtype to avoid upcast when trying to set float data + ser = Series(range(2), dtype="float64") + + msg = "setting an array element with a sequence." + with pytest.raises(ValueError, match=msg): + ser.loc[range(2)] = data + + with pytest.raises(ValueError, match=msg): + ser.loc[:] = data + + def test_loc_getitem_interval_index(self): + # GH#19977 + index = pd.interval_range(start=0, periods=3) + df = DataFrame( + [[1, 2, 3], [4, 5, 6], [7, 8, 9]], index=index, columns=["A", "B", "C"] + ) + + expected = 1 + result = df.loc[0.5, "A"] + tm.assert_almost_equal(result, expected) + + def test_loc_getitem_interval_index2(self): + # GH#19977 + index = pd.interval_range(start=0, periods=3, closed="both") + df = DataFrame( + [[1, 2, 3], [4, 5, 6], [7, 8, 9]], index=index, columns=["A", "B", "C"] + ) + + index_exp = pd.interval_range(start=0, periods=2, freq=1, closed="both") + expected = Series([1, 4], index=index_exp, name="A") + result = df.loc[1, "A"] + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("tpl", [(1,), (1, 2)]) + def test_loc_getitem_index_single_double_tuples(self, tpl): + # GH#20991 + idx = Index( + [(1,), (1, 2)], + name="A", + tupleize_cols=False, + ) + df = DataFrame(index=idx) + + result = df.loc[[tpl]] + idx = Index([tpl], name="A", tupleize_cols=False) + expected = DataFrame(index=idx) + tm.assert_frame_equal(result, expected) + + def test_loc_getitem_index_namedtuple(self): + IndexType = namedtuple("IndexType", ["a", "b"]) + idx1 = IndexType("foo", "bar") + idx2 = IndexType("baz", "bof") + index = Index([idx1, idx2], name="composite_index", tupleize_cols=False) + df = DataFrame([(1, 2), (3, 4)], index=index, columns=["A", "B"]) + + result = df.loc[IndexType("foo", "bar")]["A"] + assert result == 1 + + def test_loc_setitem_single_column_mixed(self, using_infer_string): + df = DataFrame( + np.random.default_rng(2).standard_normal((5, 3)), + index=["a", "b", "c", "d", "e"], + columns=["foo", "bar", "baz"], + ) + df["str"] = "qux" + df.loc[df.index[::2], "str"] = np.nan + expected = Series( + [np.nan, "qux", np.nan, "qux", np.nan], + dtype=object if not using_infer_string else "string[pyarrow_numpy]", + ).values + tm.assert_almost_equal(df["str"].values, expected) + + def test_loc_setitem_cast2(self): + # GH#7704 + # dtype conversion on setting + df = DataFrame(np.random.default_rng(2).random((30, 3)), columns=tuple("ABC")) + df["event"] = np.nan + with tm.assert_produces_warning( + FutureWarning, match="item of incompatible dtype" + ): + df.loc[10, "event"] = "foo" + result = df.dtypes + expected = Series( + [np.dtype("float64")] * 3 + [np.dtype("object")], + index=["A", "B", "C", "event"], + ) + tm.assert_series_equal(result, expected) + + def test_loc_setitem_cast3(self): + # Test that data type is preserved . GH#5782 + df = DataFrame({"one": np.arange(6, dtype=np.int8)}) + df.loc[1, "one"] = 6 + assert df.dtypes.one == np.dtype(np.int8) + df.one = np.int8(7) + assert df.dtypes.one == np.dtype(np.int8) + + def test_loc_setitem_range_key(self, frame_or_series): + # GH#45479 don't treat range key as positional + obj = frame_or_series(range(5), index=[3, 4, 1, 0, 2]) + + values = [9, 10, 11] + if obj.ndim == 2: + values = [[9], [10], [11]] + + obj.loc[range(3)] = values + + expected = frame_or_series([0, 1, 10, 9, 11], index=obj.index) + tm.assert_equal(obj, expected) + + def test_loc_setitem_numpy_frame_categorical_value(self): + # GH#52927 + df = DataFrame({"a": [1, 1, 1, 1, 1], "b": ["a", "a", "a", "a", "a"]}) + df.loc[1:2, "a"] = Categorical([2, 2], categories=[1, 2]) + + expected = DataFrame({"a": [1, 2, 2, 1, 1], "b": ["a", "a", "a", "a", "a"]}) + tm.assert_frame_equal(df, expected) + + +class TestLocWithEllipsis: + @pytest.fixture(params=[tm.loc, tm.iloc]) + def indexer(self, request): + # Test iloc while we're here + return request.param + + @pytest.fixture + def obj(self, series_with_simple_index, frame_or_series): + obj = series_with_simple_index + if frame_or_series is not Series: + obj = obj.to_frame() + return obj + + def test_loc_iloc_getitem_ellipsis(self, obj, indexer): + result = indexer(obj)[...] + tm.assert_equal(result, obj) + + @pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning") + def test_loc_iloc_getitem_leading_ellipses(self, series_with_simple_index, indexer): + obj = series_with_simple_index + key = 0 if (indexer is tm.iloc or len(obj) == 0) else obj.index[0] + + if indexer is tm.loc and obj.index.inferred_type == "boolean": + # passing [False] will get interpreted as a boolean mask + # TODO: should it? unambiguous when lengths dont match? + return + if indexer is tm.loc and isinstance(obj.index, MultiIndex): + msg = "MultiIndex does not support indexing with Ellipsis" + with pytest.raises(NotImplementedError, match=msg): + result = indexer(obj)[..., [key]] + + elif len(obj) != 0: + result = indexer(obj)[..., [key]] + expected = indexer(obj)[[key]] + tm.assert_series_equal(result, expected) + + key2 = 0 if indexer is tm.iloc else obj.name + df = obj.to_frame() + result = indexer(df)[..., [key2]] + expected = indexer(df)[:, [key2]] + tm.assert_frame_equal(result, expected) + + def test_loc_iloc_getitem_ellipses_only_one_ellipsis(self, obj, indexer): + # GH37750 + key = 0 if (indexer is tm.iloc or len(obj) == 0) else obj.index[0] + + with pytest.raises(IndexingError, match=_one_ellipsis_message): + indexer(obj)[..., ...] + + with pytest.raises(IndexingError, match=_one_ellipsis_message): + indexer(obj)[..., [key], ...] + + with pytest.raises(IndexingError, match=_one_ellipsis_message): + indexer(obj)[..., ..., key] + + # one_ellipsis_message takes precedence over "Too many indexers" + # only when the first key is Ellipsis + with pytest.raises(IndexingError, match="Too many indexers"): + indexer(obj)[key, ..., ...] + + +class TestLocWithMultiIndex: + @pytest.mark.parametrize( + "keys, expected", + [ + (["b", "a"], [["b", "b", "a", "a"], [1, 2, 1, 2]]), + (["a", "b"], [["a", "a", "b", "b"], [1, 2, 1, 2]]), + ((["a", "b"], [1, 2]), [["a", "a", "b", "b"], [1, 2, 1, 2]]), + ((["a", "b"], [2, 1]), [["a", "a", "b", "b"], [2, 1, 2, 1]]), + ((["b", "a"], [2, 1]), [["b", "b", "a", "a"], [2, 1, 2, 1]]), + ((["b", "a"], [1, 2]), [["b", "b", "a", "a"], [1, 2, 1, 2]]), + ((["c", "a"], [2, 1]), [["c", "a", "a"], [1, 2, 1]]), + ], + ) + @pytest.mark.parametrize("dim", ["index", "columns"]) + def test_loc_getitem_multilevel_index_order(self, dim, keys, expected): + # GH#22797 + # Try to respect order of keys given for MultiIndex.loc + kwargs = {dim: [["c", "a", "a", "b", "b"], [1, 1, 2, 1, 2]]} + df = DataFrame(np.arange(25).reshape(5, 5), **kwargs) + exp_index = MultiIndex.from_arrays(expected) + if dim == "index": + res = df.loc[keys, :] + tm.assert_index_equal(res.index, exp_index) + elif dim == "columns": + res = df.loc[:, keys] + tm.assert_index_equal(res.columns, exp_index) + + def test_loc_preserve_names(self, multiindex_year_month_day_dataframe_random_data): + ymd = multiindex_year_month_day_dataframe_random_data + + result = ymd.loc[2000] + result2 = ymd["A"].loc[2000] + assert result.index.names == ymd.index.names[1:] + assert result2.index.names == ymd.index.names[1:] + + result = ymd.loc[2000, 2] + result2 = ymd["A"].loc[2000, 2] + assert result.index.name == ymd.index.names[2] + assert result2.index.name == ymd.index.names[2] + + def test_loc_getitem_multiindex_nonunique_len_zero(self): + # GH#13691 + mi = MultiIndex.from_product([[0], [1, 1]]) + ser = Series(0, index=mi) + + res = ser.loc[[]] + + expected = ser[:0] + tm.assert_series_equal(res, expected) + + res2 = ser.loc[ser.iloc[0:0]] + tm.assert_series_equal(res2, expected) + + def test_loc_getitem_access_none_value_in_multiindex(self): + # GH#34318: test that you can access a None value using .loc + # through a Multiindex + + ser = Series([None], MultiIndex.from_arrays([["Level1"], ["Level2"]])) + result = ser.loc[("Level1", "Level2")] + assert result is None + + midx = MultiIndex.from_product([["Level1"], ["Level2_a", "Level2_b"]]) + ser = Series([None] * len(midx), dtype=object, index=midx) + result = ser.loc[("Level1", "Level2_a")] + assert result is None + + ser = Series([1] * len(midx), dtype=object, index=midx) + result = ser.loc[("Level1", "Level2_a")] + assert result == 1 + + def test_loc_setitem_multiindex_slice(self): + # GH 34870 + + index = MultiIndex.from_tuples( + zip( + ["bar", "bar", "baz", "baz", "foo", "foo", "qux", "qux"], + ["one", "two", "one", "two", "one", "two", "one", "two"], + ), + names=["first", "second"], + ) + + result = Series([1, 1, 1, 1, 1, 1, 1, 1], index=index) + result.loc[("baz", "one"):("foo", "two")] = 100 + + expected = Series([1, 1, 100, 100, 100, 100, 1, 1], index=index) + + tm.assert_series_equal(result, expected) + + def test_loc_getitem_slice_datetime_objs_with_datetimeindex(self): + times = date_range("2000-01-01", freq="10min", periods=100000) + ser = Series(range(100000), times) + result = ser.loc[datetime(1900, 1, 1) : datetime(2100, 1, 1)] + tm.assert_series_equal(result, ser) + + def test_loc_getitem_datetime_string_with_datetimeindex(self): + # GH 16710 + df = DataFrame( + {"a": range(10), "b": range(10)}, + index=date_range("2010-01-01", "2010-01-10"), + ) + result = df.loc[["2010-01-01", "2010-01-05"], ["a", "b"]] + expected = DataFrame( + {"a": [0, 4], "b": [0, 4]}, + index=DatetimeIndex(["2010-01-01", "2010-01-05"]), + ) + tm.assert_frame_equal(result, expected) + + def test_loc_getitem_sorted_index_level_with_duplicates(self): + # GH#4516 sorting a MultiIndex with duplicates and multiple dtypes + mi = MultiIndex.from_tuples( + [ + ("foo", "bar"), + ("foo", "bar"), + ("bah", "bam"), + ("bah", "bam"), + ("foo", "bar"), + ("bah", "bam"), + ], + names=["A", "B"], + ) + df = DataFrame( + [ + [1.0, 1], + [2.0, 2], + [3.0, 3], + [4.0, 4], + [5.0, 5], + [6.0, 6], + ], + index=mi, + columns=["C", "D"], + ) + df = df.sort_index(level=0) + + expected = DataFrame( + [[1.0, 1], [2.0, 2], [5.0, 5]], columns=["C", "D"], index=mi.take([0, 1, 4]) + ) + + result = df.loc[("foo", "bar")] + tm.assert_frame_equal(result, expected) + + def test_additional_element_to_categorical_series_loc(self): + # GH#47677 + result = Series(["a", "b", "c"], dtype="category") + result.loc[3] = 0 + expected = Series(["a", "b", "c", 0], dtype="object") + tm.assert_series_equal(result, expected) + + def test_additional_categorical_element_loc(self): + # GH#47677 + result = Series(["a", "b", "c"], dtype="category") + result.loc[3] = "a" + expected = Series(["a", "b", "c", "a"], dtype="category") + tm.assert_series_equal(result, expected) + + def test_loc_set_nan_in_categorical_series(self, any_numeric_ea_dtype): + # GH#47677 + srs = Series( + [1, 2, 3], + dtype=CategoricalDtype(Index([1, 2, 3], dtype=any_numeric_ea_dtype)), + ) + # enlarge + srs.loc[3] = np.nan + expected = Series( + [1, 2, 3, np.nan], + dtype=CategoricalDtype(Index([1, 2, 3], dtype=any_numeric_ea_dtype)), + ) + tm.assert_series_equal(srs, expected) + # set into + srs.loc[1] = np.nan + expected = Series( + [1, np.nan, 3, np.nan], + dtype=CategoricalDtype(Index([1, 2, 3], dtype=any_numeric_ea_dtype)), + ) + tm.assert_series_equal(srs, expected) + + @pytest.mark.parametrize("na", (np.nan, pd.NA, None, pd.NaT)) + def test_loc_consistency_series_enlarge_set_into(self, na): + # GH#47677 + srs_enlarge = Series(["a", "b", "c"], dtype="category") + srs_enlarge.loc[3] = na + + srs_setinto = Series(["a", "b", "c", "a"], dtype="category") + srs_setinto.loc[3] = na + + tm.assert_series_equal(srs_enlarge, srs_setinto) + expected = Series(["a", "b", "c", na], dtype="category") + tm.assert_series_equal(srs_enlarge, expected) + + def test_loc_getitem_preserves_index_level_category_dtype(self): + # GH#15166 + df = DataFrame( + data=np.arange(2, 22, 2), + index=MultiIndex( + levels=[CategoricalIndex(["a", "b"]), range(10)], + codes=[[0] * 5 + [1] * 5, range(10)], + names=["Index1", "Index2"], + ), + ) + + expected = CategoricalIndex( + ["a", "b"], + categories=["a", "b"], + ordered=False, + name="Index1", + dtype="category", + ) + + result = df.index.levels[0] + tm.assert_index_equal(result, expected) + + result = df.loc[["a"]].index.levels[0] + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize("lt_value", [30, 10]) + def test_loc_multiindex_levels_contain_values_not_in_index_anymore(self, lt_value): + # GH#41170 + df = DataFrame({"a": [12, 23, 34, 45]}, index=[list("aabb"), [0, 1, 2, 3]]) + with pytest.raises(KeyError, match=r"\['b'\] not in index"): + df.loc[df["a"] < lt_value, :].loc[["b"], :] + + def test_loc_multiindex_null_slice_na_level(self): + # GH#42055 + lev1 = np.array([np.nan, np.nan]) + lev2 = ["bar", "baz"] + mi = MultiIndex.from_arrays([lev1, lev2]) + ser = Series([0, 1], index=mi) + result = ser.loc[:, "bar"] + + # TODO: should we have name="bar"? + expected = Series([0], index=[np.nan]) + tm.assert_series_equal(result, expected) + + def test_loc_drops_level(self): + # Based on test_series_varied_multiindex_alignment, where + # this used to fail to drop the first level + mi = MultiIndex.from_product( + [list("ab"), list("xy"), [1, 2]], names=["ab", "xy", "num"] + ) + ser = Series(range(8), index=mi) + + loc_result = ser.loc["a", :, :] + expected = ser.index.droplevel(0)[:4] + tm.assert_index_equal(loc_result.index, expected) + + +class TestLocSetitemWithExpansion: + def test_loc_setitem_with_expansion_large_dataframe(self, monkeypatch): + # GH#10692 + size_cutoff = 50 + with monkeypatch.context(): + monkeypatch.setattr(libindex, "_SIZE_CUTOFF", size_cutoff) + result = DataFrame({"x": range(size_cutoff)}, dtype="int64") + result.loc[size_cutoff] = size_cutoff + expected = DataFrame({"x": range(size_cutoff + 1)}, dtype="int64") + tm.assert_frame_equal(result, expected) + + def test_loc_setitem_empty_series(self): + # GH#5226 + + # partially set with an empty object series + ser = Series(dtype=object) + ser.loc[1] = 1 + tm.assert_series_equal(ser, Series([1], index=[1])) + ser.loc[3] = 3 + tm.assert_series_equal(ser, Series([1, 3], index=[1, 3])) + + def test_loc_setitem_empty_series_float(self): + # GH#5226 + + # partially set with an empty object series + ser = Series(dtype=object) + ser.loc[1] = 1.0 + tm.assert_series_equal(ser, Series([1.0], index=[1])) + ser.loc[3] = 3.0 + tm.assert_series_equal(ser, Series([1.0, 3.0], index=[1, 3])) + + def test_loc_setitem_empty_series_str_idx(self): + # GH#5226 + + # partially set with an empty object series + ser = Series(dtype=object) + ser.loc["foo"] = 1 + tm.assert_series_equal(ser, Series([1], index=Index(["foo"], dtype=object))) + ser.loc["bar"] = 3 + tm.assert_series_equal( + ser, Series([1, 3], index=Index(["foo", "bar"], dtype=object)) + ) + ser.loc[3] = 4 + tm.assert_series_equal( + ser, Series([1, 3, 4], index=Index(["foo", "bar", 3], dtype=object)) + ) + + def test_loc_setitem_incremental_with_dst(self): + # GH#20724 + base = datetime(2015, 11, 1, tzinfo=gettz("US/Pacific")) + idxs = [base + timedelta(seconds=i * 900) for i in range(16)] + result = Series([0], index=[idxs[0]]) + for ts in idxs: + result.loc[ts] = 1 + expected = Series(1, index=idxs) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "conv", + [ + lambda x: x, + lambda x: x.to_datetime64(), + lambda x: x.to_pydatetime(), + lambda x: np.datetime64(x), + ], + ids=["self", "to_datetime64", "to_pydatetime", "np.datetime64"], + ) + def test_loc_setitem_datetime_keys_cast(self, conv): + # GH#9516 + dt1 = Timestamp("20130101 09:00:00") + dt2 = Timestamp("20130101 10:00:00") + df = DataFrame() + df.loc[conv(dt1), "one"] = 100 + df.loc[conv(dt2), "one"] = 200 + + expected = DataFrame( + {"one": [100.0, 200.0]}, + index=[dt1, dt2], + columns=Index(["one"], dtype=object), + ) + tm.assert_frame_equal(df, expected) + + def test_loc_setitem_categorical_column_retains_dtype(self, ordered): + # GH16360 + result = DataFrame({"A": [1]}) + result.loc[:, "B"] = Categorical(["b"], ordered=ordered) + expected = DataFrame({"A": [1], "B": Categorical(["b"], ordered=ordered)}) + tm.assert_frame_equal(result, expected) + + def test_loc_setitem_with_expansion_and_existing_dst(self): + # GH#18308 + start = Timestamp("2017-10-29 00:00:00+0200", tz="Europe/Madrid") + end = Timestamp("2017-10-29 03:00:00+0100", tz="Europe/Madrid") + ts = Timestamp("2016-10-10 03:00:00", tz="Europe/Madrid") + idx = date_range(start, end, inclusive="left", freq="h") + assert ts not in idx # i.e. result.loc setitem is with-expansion + + result = DataFrame(index=idx, columns=["value"]) + result.loc[ts, "value"] = 12 + expected = DataFrame( + [np.nan] * len(idx) + [12], + index=idx.append(DatetimeIndex([ts])), + columns=["value"], + dtype=object, + ) + tm.assert_frame_equal(result, expected) + + def test_setitem_with_expansion(self): + # indexing - setting an element + df = DataFrame( + data=to_datetime(["2015-03-30 20:12:32", "2015-03-12 00:11:11"]), + columns=["time"], + ) + df["new_col"] = ["new", "old"] + df.time = df.set_index("time").index.tz_localize("UTC") + v = df[df.new_col == "new"].set_index("time").index.tz_convert("US/Pacific") + + # pre-2.0 trying to set a single element on a part of a different + # timezone converted to object; in 2.0 it retains dtype + df2 = df.copy() + df2.loc[df2.new_col == "new", "time"] = v + + expected = Series([v[0].tz_convert("UTC"), df.loc[1, "time"]], name="time") + tm.assert_series_equal(df2.time, expected) + + v = df.loc[df.new_col == "new", "time"] + Timedelta("1s") + df.loc[df.new_col == "new", "time"] = v + tm.assert_series_equal(df.loc[df.new_col == "new", "time"], v) + + def test_loc_setitem_with_expansion_inf_upcast_empty(self): + # Test with np.inf in columns + df = DataFrame() + df.loc[0, 0] = 1 + df.loc[1, 1] = 2 + df.loc[0, np.inf] = 3 + + result = df.columns + expected = Index([0, 1, np.inf], dtype=np.float64) + tm.assert_index_equal(result, expected) + + @pytest.mark.filterwarnings("ignore:indexing past lexsort depth") + def test_loc_setitem_with_expansion_nonunique_index(self, index): + # GH#40096 + if not len(index): + pytest.skip("Not relevant for empty Index") + + index = index.repeat(2) # ensure non-unique + N = len(index) + arr = np.arange(N).astype(np.int64) + + orig = DataFrame(arr, index=index, columns=[0]) + + # key that will requiring object-dtype casting in the index + key = "kapow" + assert key not in index # otherwise test is invalid + # TODO: using a tuple key breaks here in many cases + + exp_index = index.insert(len(index), key) + if isinstance(index, MultiIndex): + assert exp_index[-1][0] == key + else: + assert exp_index[-1] == key + exp_data = np.arange(N + 1).astype(np.float64) + expected = DataFrame(exp_data, index=exp_index, columns=[0]) + + # Add new row, but no new columns + df = orig.copy() + df.loc[key, 0] = N + tm.assert_frame_equal(df, expected) + + # add new row on a Series + ser = orig.copy()[0] + ser.loc[key] = N + # the series machinery lets us preserve int dtype instead of float + expected = expected[0].astype(np.int64) + tm.assert_series_equal(ser, expected) + + # add new row and new column + df = orig.copy() + df.loc[key, 1] = N + expected = DataFrame( + {0: list(arr) + [np.nan], 1: [np.nan] * N + [float(N)]}, + index=exp_index, + ) + tm.assert_frame_equal(df, expected) + + @pytest.mark.parametrize( + "dtype", ["Int32", "Int64", "UInt32", "UInt64", "Float32", "Float64"] + ) + def test_loc_setitem_with_expansion_preserves_nullable_int(self, dtype): + # GH#42099 + ser = Series([0, 1, 2, 3], dtype=dtype) + df = DataFrame({"data": ser}) + + result = DataFrame(index=df.index) + result.loc[df.index, "data"] = ser + + tm.assert_frame_equal(result, df, check_column_type=False) + + result = DataFrame(index=df.index) + result.loc[df.index, "data"] = ser._values + tm.assert_frame_equal(result, df, check_column_type=False) + + def test_loc_setitem_ea_not_full_column(self): + # GH#39163 + df = DataFrame({"A": range(5)}) + + val = date_range("2016-01-01", periods=3, tz="US/Pacific") + + df.loc[[0, 1, 2], "B"] = val + + bex = val.append(DatetimeIndex([pd.NaT, pd.NaT], dtype=val.dtype)) + expected = DataFrame({"A": range(5), "B": bex}) + assert expected.dtypes["B"] == val.dtype + tm.assert_frame_equal(df, expected) + + +class TestLocCallable: + def test_frame_loc_getitem_callable(self): + # GH#11485 + df = DataFrame({"A": [1, 2, 3, 4], "B": list("aabb"), "C": [1, 2, 3, 4]}) + # iloc cannot use boolean Series (see GH3635) + + # return bool indexer + res = df.loc[lambda x: x.A > 2] + tm.assert_frame_equal(res, df.loc[df.A > 2]) + + res = df.loc[lambda x: x.B == "b", :] + tm.assert_frame_equal(res, df.loc[df.B == "b", :]) + + res = df.loc[lambda x: x.A > 2, lambda x: x.columns == "B"] + tm.assert_frame_equal(res, df.loc[df.A > 2, [False, True, False]]) + + res = df.loc[lambda x: x.A > 2, lambda x: "B"] + tm.assert_series_equal(res, df.loc[df.A > 2, "B"]) + + res = df.loc[lambda x: x.A > 2, lambda x: ["A", "B"]] + tm.assert_frame_equal(res, df.loc[df.A > 2, ["A", "B"]]) + + res = df.loc[lambda x: x.A == 2, lambda x: ["A", "B"]] + tm.assert_frame_equal(res, df.loc[df.A == 2, ["A", "B"]]) + + # scalar + res = df.loc[lambda x: 1, lambda x: "A"] + assert res == df.loc[1, "A"] + + def test_frame_loc_getitem_callable_mixture(self): + # GH#11485 + df = DataFrame({"A": [1, 2, 3, 4], "B": list("aabb"), "C": [1, 2, 3, 4]}) + + res = df.loc[lambda x: x.A > 2, ["A", "B"]] + tm.assert_frame_equal(res, df.loc[df.A > 2, ["A", "B"]]) + + res = df.loc[[2, 3], lambda x: ["A", "B"]] + tm.assert_frame_equal(res, df.loc[[2, 3], ["A", "B"]]) + + res = df.loc[3, lambda x: ["A", "B"]] + tm.assert_series_equal(res, df.loc[3, ["A", "B"]]) + + def test_frame_loc_getitem_callable_labels(self): + # GH#11485 + df = DataFrame({"X": [1, 2, 3, 4], "Y": list("aabb")}, index=list("ABCD")) + + # return label + res = df.loc[lambda x: ["A", "C"]] + tm.assert_frame_equal(res, df.loc[["A", "C"]]) + + res = df.loc[lambda x: ["A", "C"], :] + tm.assert_frame_equal(res, df.loc[["A", "C"], :]) + + res = df.loc[lambda x: ["A", "C"], lambda x: "X"] + tm.assert_series_equal(res, df.loc[["A", "C"], "X"]) + + res = df.loc[lambda x: ["A", "C"], lambda x: ["X"]] + tm.assert_frame_equal(res, df.loc[["A", "C"], ["X"]]) + + # mixture + res = df.loc[["A", "C"], lambda x: "X"] + tm.assert_series_equal(res, df.loc[["A", "C"], "X"]) + + res = df.loc[["A", "C"], lambda x: ["X"]] + tm.assert_frame_equal(res, df.loc[["A", "C"], ["X"]]) + + res = df.loc[lambda x: ["A", "C"], "X"] + tm.assert_series_equal(res, df.loc[["A", "C"], "X"]) + + res = df.loc[lambda x: ["A", "C"], ["X"]] + tm.assert_frame_equal(res, df.loc[["A", "C"], ["X"]]) + + def test_frame_loc_setitem_callable(self): + # GH#11485 + df = DataFrame( + {"X": [1, 2, 3, 4], "Y": Series(list("aabb"), dtype=object)}, + index=list("ABCD"), + ) + + # return label + res = df.copy() + res.loc[lambda x: ["A", "C"]] = -20 + exp = df.copy() + exp.loc[["A", "C"]] = -20 + tm.assert_frame_equal(res, exp) + + res = df.copy() + res.loc[lambda x: ["A", "C"], :] = 20 + exp = df.copy() + exp.loc[["A", "C"], :] = 20 + tm.assert_frame_equal(res, exp) + + res = df.copy() + res.loc[lambda x: ["A", "C"], lambda x: "X"] = -1 + exp = df.copy() + exp.loc[["A", "C"], "X"] = -1 + tm.assert_frame_equal(res, exp) + + res = df.copy() + res.loc[lambda x: ["A", "C"], lambda x: ["X"]] = [5, 10] + exp = df.copy() + exp.loc[["A", "C"], ["X"]] = [5, 10] + tm.assert_frame_equal(res, exp) + + # mixture + res = df.copy() + res.loc[["A", "C"], lambda x: "X"] = np.array([-1, -2]) + exp = df.copy() + exp.loc[["A", "C"], "X"] = np.array([-1, -2]) + tm.assert_frame_equal(res, exp) + + res = df.copy() + res.loc[["A", "C"], lambda x: ["X"]] = 10 + exp = df.copy() + exp.loc[["A", "C"], ["X"]] = 10 + tm.assert_frame_equal(res, exp) + + res = df.copy() + res.loc[lambda x: ["A", "C"], "X"] = -2 + exp = df.copy() + exp.loc[["A", "C"], "X"] = -2 + tm.assert_frame_equal(res, exp) + + res = df.copy() + res.loc[lambda x: ["A", "C"], ["X"]] = -4 + exp = df.copy() + exp.loc[["A", "C"], ["X"]] = -4 + tm.assert_frame_equal(res, exp) + + +class TestPartialStringSlicing: + def test_loc_getitem_partial_string_slicing_datetimeindex(self): + # GH#35509 + df = DataFrame( + {"col1": ["a", "b", "c"], "col2": [1, 2, 3]}, + index=to_datetime(["2020-08-01", "2020-07-02", "2020-08-05"]), + ) + expected = DataFrame( + {"col1": ["a", "c"], "col2": [1, 3]}, + index=to_datetime(["2020-08-01", "2020-08-05"]), + ) + result = df.loc["2020-08"] + tm.assert_frame_equal(result, expected) + + def test_loc_getitem_partial_string_slicing_with_periodindex(self): + pi = pd.period_range(start="2017-01-01", end="2018-01-01", freq="M") + ser = pi.to_series() + result = ser.loc[:"2017-12"] + expected = ser.iloc[:-1] + + tm.assert_series_equal(result, expected) + + def test_loc_getitem_partial_string_slicing_with_timedeltaindex(self): + ix = timedelta_range(start="1 day", end="2 days", freq="1h") + ser = ix.to_series() + result = ser.loc[:"1 days"] + expected = ser.iloc[:-1] + + tm.assert_series_equal(result, expected) + + def test_loc_getitem_str_timedeltaindex(self): + # GH#16896 + df = DataFrame({"x": range(3)}, index=to_timedelta(range(3), unit="days")) + expected = df.iloc[0] + sliced = df.loc["0 days"] + tm.assert_series_equal(sliced, expected) + + @pytest.mark.parametrize("indexer_end", [None, "2020-01-02 23:59:59.999999999"]) + def test_loc_getitem_partial_slice_non_monotonicity( + self, tz_aware_fixture, indexer_end, frame_or_series + ): + # GH#33146 + obj = frame_or_series( + [1] * 5, + index=DatetimeIndex( + [ + Timestamp("2019-12-30"), + Timestamp("2020-01-01"), + Timestamp("2019-12-25"), + Timestamp("2020-01-02 23:59:59.999999999"), + Timestamp("2019-12-19"), + ], + tz=tz_aware_fixture, + ), + ) + expected = frame_or_series( + [1] * 2, + index=DatetimeIndex( + [ + Timestamp("2020-01-01"), + Timestamp("2020-01-02 23:59:59.999999999"), + ], + tz=tz_aware_fixture, + ), + ) + indexer = slice("2020-01-01", indexer_end) + + result = obj[indexer] + tm.assert_equal(result, expected) + + result = obj.loc[indexer] + tm.assert_equal(result, expected) + + +class TestLabelSlicing: + def test_loc_getitem_slicing_datetimes_frame(self): + # GH#7523 + + # unique + df_unique = DataFrame( + np.arange(4.0, dtype="float64"), + index=[datetime(2001, 1, i, 10, 00) for i in [1, 2, 3, 4]], + ) + + # duplicates + df_dups = DataFrame( + np.arange(5.0, dtype="float64"), + index=[datetime(2001, 1, i, 10, 00) for i in [1, 2, 2, 3, 4]], + ) + + for df in [df_unique, df_dups]: + result = df.loc[datetime(2001, 1, 1, 10) :] + tm.assert_frame_equal(result, df) + result = df.loc[: datetime(2001, 1, 4, 10)] + tm.assert_frame_equal(result, df) + result = df.loc[datetime(2001, 1, 1, 10) : datetime(2001, 1, 4, 10)] + tm.assert_frame_equal(result, df) + + result = df.loc[datetime(2001, 1, 1, 11) :] + expected = df.iloc[1:] + tm.assert_frame_equal(result, expected) + result = df.loc["20010101 11":] + tm.assert_frame_equal(result, expected) + + def test_loc_getitem_label_slice_across_dst(self): + # GH#21846 + idx = date_range( + "2017-10-29 01:30:00", tz="Europe/Berlin", periods=5, freq="30 min" + ) + series2 = Series([0, 1, 2, 3, 4], index=idx) + + t_1 = Timestamp("2017-10-29 02:30:00+02:00", tz="Europe/Berlin") + t_2 = Timestamp("2017-10-29 02:00:00+01:00", tz="Europe/Berlin") + result = series2.loc[t_1:t_2] + expected = Series([2, 3], index=idx[2:4]) + tm.assert_series_equal(result, expected) + + result = series2[t_1] + expected = 2 + assert result == expected + + @pytest.mark.parametrize( + "index", + [ + pd.period_range(start="2017-01-01", end="2018-01-01", freq="M"), + timedelta_range(start="1 day", end="2 days", freq="1h"), + ], + ) + def test_loc_getitem_label_slice_period_timedelta(self, index): + ser = index.to_series() + result = ser.loc[: index[-2]] + expected = ser.iloc[:-1] + + tm.assert_series_equal(result, expected) + + def test_loc_getitem_slice_floats_inexact(self): + index = [52195.504153, 52196.303147, 52198.369883] + df = DataFrame(np.random.default_rng(2).random((3, 2)), index=index) + + s1 = df.loc[52195.1:52196.5] + assert len(s1) == 2 + + s1 = df.loc[52195.1:52196.6] + assert len(s1) == 2 + + s1 = df.loc[52195.1:52198.9] + assert len(s1) == 3 + + def test_loc_getitem_float_slice_floatindex(self, float_numpy_dtype): + dtype = float_numpy_dtype + ser = Series( + np.random.default_rng(2).random(10), index=np.arange(10, 20, dtype=dtype) + ) + + assert len(ser.loc[12.0:]) == 8 + assert len(ser.loc[12.5:]) == 7 + + idx = np.arange(10, 20, dtype=dtype) + idx[2] = 12.2 + ser.index = idx + assert len(ser.loc[12.0:]) == 8 + assert len(ser.loc[12.5:]) == 7 + + @pytest.mark.parametrize( + "start,stop, expected_slice", + [ + [np.timedelta64(0, "ns"), None, slice(0, 11)], + [np.timedelta64(1, "D"), np.timedelta64(6, "D"), slice(1, 7)], + [None, np.timedelta64(4, "D"), slice(0, 5)], + ], + ) + def test_loc_getitem_slice_label_td64obj(self, start, stop, expected_slice): + # GH#20393 + ser = Series(range(11), timedelta_range("0 days", "10 days")) + result = ser.loc[slice(start, stop)] + expected = ser.iloc[expected_slice] + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("start", ["2018", "2020"]) + def test_loc_getitem_slice_unordered_dt_index(self, frame_or_series, start): + obj = frame_or_series( + [1, 2, 3], + index=[Timestamp("2016"), Timestamp("2019"), Timestamp("2017")], + ) + with pytest.raises( + KeyError, match="Value based partial slicing on non-monotonic" + ): + obj.loc[start:"2022"] + + @pytest.mark.parametrize("value", [1, 1.5]) + def test_loc_getitem_slice_labels_int_in_object_index(self, frame_or_series, value): + # GH: 26491 + obj = frame_or_series(range(4), index=[value, "first", 2, "third"]) + result = obj.loc[value:"third"] + expected = frame_or_series(range(4), index=[value, "first", 2, "third"]) + tm.assert_equal(result, expected) + + def test_loc_getitem_slice_columns_mixed_dtype(self): + # GH: 20975 + df = DataFrame({"test": 1, 1: 2, 2: 3}, index=[0]) + expected = DataFrame( + data=[[2, 3]], index=[0], columns=Index([1, 2], dtype=object) + ) + tm.assert_frame_equal(df.loc[:, 1:], expected) + + +class TestLocBooleanLabelsAndSlices: + @pytest.mark.parametrize("bool_value", [True, False]) + def test_loc_bool_incompatible_index_raises( + self, index, frame_or_series, bool_value + ): + # GH20432 + message = f"{bool_value}: boolean label can not be used without a boolean index" + if index.inferred_type != "boolean": + obj = frame_or_series(index=index, dtype="object") + with pytest.raises(KeyError, match=message): + obj.loc[bool_value] + + @pytest.mark.parametrize("bool_value", [True, False]) + def test_loc_bool_should_not_raise(self, frame_or_series, bool_value): + obj = frame_or_series( + index=Index([True, False], dtype="boolean"), dtype="object" + ) + obj.loc[bool_value] + + def test_loc_bool_slice_raises(self, index, frame_or_series): + # GH20432 + message = ( + r"slice\(True, False, None\): boolean values can not be used in a slice" + ) + obj = frame_or_series(index=index, dtype="object") + with pytest.raises(TypeError, match=message): + obj.loc[True:False] + + +class TestLocBooleanMask: + def test_loc_setitem_bool_mask_timedeltaindex(self): + # GH#14946 + df = DataFrame({"x": range(10)}) + df.index = to_timedelta(range(10), unit="s") + conditions = [df["x"] > 3, df["x"] == 3, df["x"] < 3] + expected_data = [ + [0, 1, 2, 3, 10, 10, 10, 10, 10, 10], + [0, 1, 2, 10, 4, 5, 6, 7, 8, 9], + [10, 10, 10, 3, 4, 5, 6, 7, 8, 9], + ] + for cond, data in zip(conditions, expected_data): + result = df.copy() + result.loc[cond, "x"] = 10 + + expected = DataFrame( + data, + index=to_timedelta(range(10), unit="s"), + columns=["x"], + dtype="int64", + ) + tm.assert_frame_equal(expected, result) + + @pytest.mark.parametrize("tz", [None, "UTC"]) + def test_loc_setitem_mask_with_datetimeindex_tz(self, tz): + # GH#16889 + # support .loc with alignment and tz-aware DatetimeIndex + mask = np.array([True, False, True, False]) + + idx = date_range("20010101", periods=4, tz=tz) + df = DataFrame({"a": np.arange(4)}, index=idx).astype("float64") + + result = df.copy() + result.loc[mask, :] = df.loc[mask, :] + tm.assert_frame_equal(result, df) + + result = df.copy() + result.loc[mask] = df.loc[mask] + tm.assert_frame_equal(result, df) + + def test_loc_setitem_mask_and_label_with_datetimeindex(self): + # GH#9478 + # a datetimeindex alignment issue with partial setting + df = DataFrame( + np.arange(6.0).reshape(3, 2), + columns=list("AB"), + index=date_range("1/1/2000", periods=3, freq="1h"), + ) + expected = df.copy() + expected["C"] = [expected.index[0]] + [pd.NaT, pd.NaT] + + mask = df.A < 1 + df.loc[mask, "C"] = df.loc[mask].index + tm.assert_frame_equal(df, expected) + + def test_loc_setitem_mask_td64_series_value(self): + # GH#23462 key list of bools, value is a Series + td1 = Timedelta(0) + td2 = Timedelta(28767471428571405) + df = DataFrame({"col": Series([td1, td2])}) + df_copy = df.copy() + ser = Series([td1]) + + expected = df["col"].iloc[1]._value + df.loc[[True, False]] = ser + result = df["col"].iloc[1]._value + + assert expected == result + tm.assert_frame_equal(df, df_copy) + + @td.skip_array_manager_invalid_test # TODO(ArrayManager) rewrite not using .values + def test_loc_setitem_boolean_and_column(self, float_frame): + expected = float_frame.copy() + mask = float_frame["A"] > 0 + + float_frame.loc[mask, "B"] = 0 + + values = expected.values.copy() + values[mask.values, 1] = 0 + expected = DataFrame(values, index=expected.index, columns=expected.columns) + tm.assert_frame_equal(float_frame, expected) + + def test_loc_setitem_ndframe_values_alignment( + self, using_copy_on_write, warn_copy_on_write + ): + # GH#45501 + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + df.loc[[False, False, True], ["a"]] = DataFrame( + {"a": [10, 20, 30]}, index=[2, 1, 0] + ) + + expected = DataFrame({"a": [1, 2, 10], "b": [4, 5, 6]}) + tm.assert_frame_equal(df, expected) + + # same thing with Series RHS + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + df.loc[[False, False, True], ["a"]] = Series([10, 11, 12], index=[2, 1, 0]) + tm.assert_frame_equal(df, expected) + + # same thing but setting "a" instead of ["a"] + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + df.loc[[False, False, True], "a"] = Series([10, 11, 12], index=[2, 1, 0]) + tm.assert_frame_equal(df, expected) + + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + df_orig = df.copy() + ser = df["a"] + with tm.assert_cow_warning(warn_copy_on_write): + ser.loc[[False, False, True]] = Series([10, 11, 12], index=[2, 1, 0]) + if using_copy_on_write: + tm.assert_frame_equal(df, df_orig) + else: + tm.assert_frame_equal(df, expected) + + def test_loc_indexer_empty_broadcast(self): + # GH#51450 + df = DataFrame({"a": [], "b": []}, dtype=object) + expected = df.copy() + df.loc[np.array([], dtype=np.bool_), ["a"]] = df["a"].copy() + tm.assert_frame_equal(df, expected) + + def test_loc_indexer_all_false_broadcast(self): + # GH#51450 + df = DataFrame({"a": ["x"], "b": ["y"]}, dtype=object) + expected = df.copy() + df.loc[np.array([False], dtype=np.bool_), ["a"]] = df["b"].copy() + tm.assert_frame_equal(df, expected) + + def test_loc_indexer_length_one(self): + # GH#51435 + df = DataFrame({"a": ["x"], "b": ["y"]}, dtype=object) + expected = DataFrame({"a": ["y"], "b": ["y"]}, dtype=object) + df.loc[np.array([True], dtype=np.bool_), ["a"]] = df["b"].copy() + tm.assert_frame_equal(df, expected) + + +class TestLocListlike: + @pytest.mark.parametrize("box", [lambda x: x, np.asarray, list]) + def test_loc_getitem_list_of_labels_categoricalindex_with_na(self, box): + # passing a list can include valid categories _or_ NA values + ci = CategoricalIndex(["A", "B", np.nan]) + ser = Series(range(3), index=ci) + + result = ser.loc[box(ci)] + tm.assert_series_equal(result, ser) + + result = ser[box(ci)] + tm.assert_series_equal(result, ser) + + result = ser.to_frame().loc[box(ci)] + tm.assert_frame_equal(result, ser.to_frame()) + + ser2 = ser[:-1] + ci2 = ci[1:] + # but if there are no NAs present, this should raise KeyError + msg = "not in index" + with pytest.raises(KeyError, match=msg): + ser2.loc[box(ci2)] + + with pytest.raises(KeyError, match=msg): + ser2[box(ci2)] + + with pytest.raises(KeyError, match=msg): + ser2.to_frame().loc[box(ci2)] + + def test_loc_getitem_series_label_list_missing_values(self): + # gh-11428 + key = np.array( + ["2001-01-04", "2001-01-02", "2001-01-04", "2001-01-14"], dtype="datetime64" + ) + ser = Series([2, 5, 8, 11], date_range("2001-01-01", freq="D", periods=4)) + with pytest.raises(KeyError, match="not in index"): + ser.loc[key] + + def test_loc_getitem_series_label_list_missing_integer_values(self): + # GH: 25927 + ser = Series( + index=np.array([9730701000001104, 10049011000001109]), + data=np.array([999000011000001104, 999000011000001104]), + ) + with pytest.raises(KeyError, match="not in index"): + ser.loc[np.array([9730701000001104, 10047311000001102])] + + @pytest.mark.parametrize("to_period", [True, False]) + def test_loc_getitem_listlike_of_datetimelike_keys(self, to_period): + # GH#11497 + + idx = date_range("2011-01-01", "2011-01-02", freq="D", name="idx") + if to_period: + idx = idx.to_period("D") + ser = Series([0.1, 0.2], index=idx, name="s") + + keys = [Timestamp("2011-01-01"), Timestamp("2011-01-02")] + if to_period: + keys = [x.to_period("D") for x in keys] + result = ser.loc[keys] + exp = Series([0.1, 0.2], index=idx, name="s") + if not to_period: + exp.index = exp.index._with_freq(None) + tm.assert_series_equal(result, exp, check_index_type=True) + + keys = [ + Timestamp("2011-01-02"), + Timestamp("2011-01-02"), + Timestamp("2011-01-01"), + ] + if to_period: + keys = [x.to_period("D") for x in keys] + exp = Series( + [0.2, 0.2, 0.1], index=Index(keys, name="idx", dtype=idx.dtype), name="s" + ) + result = ser.loc[keys] + tm.assert_series_equal(result, exp, check_index_type=True) + + keys = [ + Timestamp("2011-01-03"), + Timestamp("2011-01-02"), + Timestamp("2011-01-03"), + ] + if to_period: + keys = [x.to_period("D") for x in keys] + + with pytest.raises(KeyError, match="not in index"): + ser.loc[keys] + + def test_loc_named_index(self): + # GH 42790 + df = DataFrame( + [[1, 2], [4, 5], [7, 8]], + index=["cobra", "viper", "sidewinder"], + columns=["max_speed", "shield"], + ) + expected = df.iloc[:2] + expected.index.name = "foo" + result = df.loc[Index(["cobra", "viper"], name="foo")] + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "columns, column_key, expected_columns", + [ + ([2011, 2012, 2013], [2011, 2012], [0, 1]), + ([2011, 2012, "All"], [2011, 2012], [0, 1]), + ([2011, 2012, "All"], [2011, "All"], [0, 2]), + ], +) +def test_loc_getitem_label_list_integer_labels(columns, column_key, expected_columns): + # gh-14836 + df = DataFrame( + np.random.default_rng(2).random((3, 3)), columns=columns, index=list("ABC") + ) + expected = df.iloc[:, expected_columns] + result = df.loc[["A", "B", "C"], column_key] + + tm.assert_frame_equal(result, expected, check_column_type=True) + + +def test_loc_setitem_float_intindex(): + # GH 8720 + rand_data = np.random.default_rng(2).standard_normal((8, 4)) + result = DataFrame(rand_data) + result.loc[:, 0.5] = np.nan + expected_data = np.hstack((rand_data, np.array([np.nan] * 8).reshape(8, 1))) + expected = DataFrame(expected_data, columns=[0.0, 1.0, 2.0, 3.0, 0.5]) + tm.assert_frame_equal(result, expected) + + result = DataFrame(rand_data) + result.loc[:, 0.5] = np.nan + tm.assert_frame_equal(result, expected) + + +def test_loc_axis_1_slice(): + # GH 10586 + cols = [(yr, m) for yr in [2014, 2015] for m in [7, 8, 9, 10]] + df = DataFrame( + np.ones((10, 8)), + index=tuple("ABCDEFGHIJ"), + columns=MultiIndex.from_tuples(cols), + ) + result = df.loc(axis=1)[(2014, 9):(2015, 8)] + expected = DataFrame( + np.ones((10, 4)), + index=tuple("ABCDEFGHIJ"), + columns=MultiIndex.from_tuples([(2014, 9), (2014, 10), (2015, 7), (2015, 8)]), + ) + tm.assert_frame_equal(result, expected) + + +def test_loc_set_dataframe_multiindex(): + # GH 14592 + expected = DataFrame( + "a", index=range(2), columns=MultiIndex.from_product([range(2), range(2)]) + ) + result = expected.copy() + result.loc[0, [(0, 1)]] = result.loc[0, [(0, 1)]] + tm.assert_frame_equal(result, expected) + + +def test_loc_mixed_int_float(): + # GH#19456 + ser = Series(range(2), Index([1, 2.0], dtype=object)) + + result = ser.loc[1] + assert result == 0 + + +def test_loc_with_positional_slice_raises(): + # GH#31840 + ser = Series(range(4), index=["A", "B", "C", "D"]) + + with pytest.raises(TypeError, match="Slicing a positional slice with .loc"): + ser.loc[:3] = 2 + + +def test_loc_slice_disallows_positional(): + # GH#16121, GH#24612, GH#31810 + dti = date_range("2016-01-01", periods=3) + df = DataFrame(np.random.default_rng(2).random((3, 2)), index=dti) + + ser = df[0] + + msg = ( + "cannot do slice indexing on DatetimeIndex with these " + r"indexers \[1\] of type int" + ) + + for obj in [df, ser]: + with pytest.raises(TypeError, match=msg): + obj.loc[1:3] + + with pytest.raises(TypeError, match="Slicing a positional slice with .loc"): + # GH#31840 enforce incorrect behavior + obj.loc[1:3] = 1 + + with pytest.raises(TypeError, match=msg): + df.loc[1:3, 1] + + with pytest.raises(TypeError, match="Slicing a positional slice with .loc"): + # GH#31840 enforce incorrect behavior + df.loc[1:3, 1] = 2 + + +def test_loc_datetimelike_mismatched_dtypes(): + # GH#32650 dont mix and match datetime/timedelta/period dtypes + + df = DataFrame( + np.random.default_rng(2).standard_normal((5, 3)), + columns=["a", "b", "c"], + index=date_range("2012", freq="h", periods=5), + ) + # create dataframe with non-unique DatetimeIndex + df = df.iloc[[0, 2, 2, 3]].copy() + + dti = df.index + tdi = pd.TimedeltaIndex(dti.asi8) # matching i8 values + + msg = r"None of \[TimedeltaIndex.* are in the \[index\]" + with pytest.raises(KeyError, match=msg): + df.loc[tdi] + + with pytest.raises(KeyError, match=msg): + df["a"].loc[tdi] + + +def test_loc_with_period_index_indexer(): + # GH#4125 + idx = pd.period_range("2002-01", "2003-12", freq="M") + df = DataFrame(np.random.default_rng(2).standard_normal((24, 10)), index=idx) + tm.assert_frame_equal(df, df.loc[idx]) + tm.assert_frame_equal(df, df.loc[list(idx)]) + tm.assert_frame_equal(df, df.loc[list(idx)]) + tm.assert_frame_equal(df.iloc[0:5], df.loc[idx[0:5]]) + tm.assert_frame_equal(df, df.loc[list(idx)]) + + +def test_loc_setitem_multiindex_timestamp(): + # GH#13831 + vals = np.random.default_rng(2).standard_normal((8, 6)) + idx = date_range("1/1/2000", periods=8) + cols = ["A", "B", "C", "D", "E", "F"] + exp = DataFrame(vals, index=idx, columns=cols) + exp.loc[exp.index[1], ("A", "B")] = np.nan + vals[1][0:2] = np.nan + res = DataFrame(vals, index=idx, columns=cols) + tm.assert_frame_equal(res, exp) + + +def test_loc_getitem_multiindex_tuple_level(): + # GH#27591 + lev1 = ["a", "b", "c"] + lev2 = [(0, 1), (1, 0)] + lev3 = [0, 1] + cols = MultiIndex.from_product([lev1, lev2, lev3], names=["x", "y", "z"]) + df = DataFrame(6, index=range(5), columns=cols) + + # the lev2[0] here should be treated as a single label, not as a sequence + # of labels + result = df.loc[:, (lev1[0], lev2[0], lev3[0])] + + # TODO: i think this actually should drop levels + expected = df.iloc[:, :1] + tm.assert_frame_equal(result, expected) + + alt = df.xs((lev1[0], lev2[0], lev3[0]), level=[0, 1, 2], axis=1) + tm.assert_frame_equal(alt, expected) + + # same thing on a Series + ser = df.iloc[0] + expected2 = ser.iloc[:1] + + alt2 = ser.xs((lev1[0], lev2[0], lev3[0]), level=[0, 1, 2], axis=0) + tm.assert_series_equal(alt2, expected2) + + result2 = ser.loc[lev1[0], lev2[0], lev3[0]] + assert result2 == 6 + + +def test_loc_getitem_nullable_index_with_duplicates(): + # GH#34497 + df = DataFrame( + data=np.array([[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, np.nan, np.nan]]).T, + columns=["a", "b", "c"], + dtype="Int64", + ) + df2 = df.set_index("c") + assert df2.index.dtype == "Int64" + + res = df2.loc[1] + expected = Series([1, 5], index=df2.columns, dtype="Int64", name=1) + tm.assert_series_equal(res, expected) + + # pd.NA and duplicates in an object-dtype Index + df2.index = df2.index.astype(object) + res = df2.loc[1] + tm.assert_series_equal(res, expected) + + +@pytest.mark.parametrize("value", [300, np.uint16(300), np.int16(300)]) +def test_loc_setitem_uint8_upcast(value): + # GH#26049 + + df = DataFrame([1, 2, 3, 4], columns=["col1"], dtype="uint8") + with tm.assert_produces_warning(FutureWarning, match="item of incompatible dtype"): + df.loc[2, "col1"] = value # value that can't be held in uint8 + + if np_version_gt2 and isinstance(value, np.int16): + # Note, result type of uint8 + int16 is int16 + # in numpy < 2, though, numpy would inspect the + # value and see that it could fit in an uint16, resulting in a uint16 + dtype = "int16" + else: + dtype = "uint16" + + expected = DataFrame([1, 2, 300, 4], columns=["col1"], dtype=dtype) + tm.assert_frame_equal(df, expected) + + +@pytest.mark.parametrize( + "fill_val,exp_dtype", + [ + (Timestamp("2022-01-06"), "datetime64[ns]"), + (Timestamp("2022-01-07", tz="US/Eastern"), "datetime64[ns, US/Eastern]"), + ], +) +def test_loc_setitem_using_datetimelike_str_as_index(fill_val, exp_dtype): + data = ["2022-01-02", "2022-01-03", "2022-01-04", fill_val.date()] + index = DatetimeIndex(data, tz=fill_val.tz, dtype=exp_dtype) + df = DataFrame([10, 11, 12, 14], columns=["a"], index=index) + # adding new row using an unexisting datetime-like str index + df.loc["2022-01-08", "a"] = 13 + + data.append("2022-01-08") + expected_index = DatetimeIndex(data, dtype=exp_dtype) + tm.assert_index_equal(df.index, expected_index, exact=True) + + +def test_loc_set_int_dtype(): + # GH#23326 + df = DataFrame([list("abc")]) + df.loc[:, "col1"] = 5 + + expected = DataFrame({0: ["a"], 1: ["b"], 2: ["c"], "col1": [5]}) + tm.assert_frame_equal(df, expected) + + +@pytest.mark.filterwarnings(r"ignore:Period with BDay freq is deprecated:FutureWarning") +@pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning") +def test_loc_periodindex_3_levels(): + # GH#24091 + p_index = PeriodIndex( + ["20181101 1100", "20181101 1200", "20181102 1300", "20181102 1400"], + name="datetime", + freq="B", + ) + mi_series = DataFrame( + [["A", "B", 1.0], ["A", "C", 2.0], ["Z", "Q", 3.0], ["W", "F", 4.0]], + index=p_index, + columns=["ONE", "TWO", "VALUES"], + ) + mi_series = mi_series.set_index(["ONE", "TWO"], append=True)["VALUES"] + assert mi_series.loc[(p_index[0], "A", "B")] == 1.0 + + +def test_loc_setitem_pyarrow_strings(): + # GH#52319 + pytest.importorskip("pyarrow") + df = DataFrame( + { + "strings": Series(["A", "B", "C"], dtype="string[pyarrow]"), + "ids": Series([True, True, False]), + } + ) + new_value = Series(["X", "Y"]) + df.loc[df.ids, "strings"] = new_value + + expected_df = DataFrame( + { + "strings": Series(["X", "Y", "C"], dtype="string[pyarrow]"), + "ids": Series([True, True, False]), + } + ) + + tm.assert_frame_equal(df, expected_df) + + +class TestLocSeries: + @pytest.mark.parametrize("val,expected", [(2**63 - 1, 3), (2**63, 4)]) + def test_loc_uint64(self, val, expected): + # see GH#19399 + ser = Series({2**63 - 1: 3, 2**63: 4}) + assert ser.loc[val] == expected + + def test_loc_getitem(self, string_series, datetime_series): + inds = string_series.index[[3, 4, 7]] + tm.assert_series_equal(string_series.loc[inds], string_series.reindex(inds)) + tm.assert_series_equal(string_series.iloc[5::2], string_series[5::2]) + + # slice with indices + d1, d2 = datetime_series.index[[5, 15]] + result = datetime_series.loc[d1:d2] + expected = datetime_series.truncate(d1, d2) + tm.assert_series_equal(result, expected) + + # boolean + mask = string_series > string_series.median() + tm.assert_series_equal(string_series.loc[mask], string_series[mask]) + + # ask for index value + assert datetime_series.loc[d1] == datetime_series[d1] + assert datetime_series.loc[d2] == datetime_series[d2] + + def test_loc_getitem_not_monotonic(self, datetime_series): + d1, d2 = datetime_series.index[[5, 15]] + + ts2 = datetime_series[::2].iloc[[1, 2, 0]] + + msg = r"Timestamp\('2000-01-10 00:00:00'\)" + with pytest.raises(KeyError, match=msg): + ts2.loc[d1:d2] + with pytest.raises(KeyError, match=msg): + ts2.loc[d1:d2] = 0 + + def test_loc_getitem_setitem_integer_slice_keyerrors(self): + ser = Series( + np.random.default_rng(2).standard_normal(10), index=list(range(0, 20, 2)) + ) + + # this is OK + cp = ser.copy() + cp.iloc[4:10] = 0 + assert (cp.iloc[4:10] == 0).all() + + # so is this + cp = ser.copy() + cp.iloc[3:11] = 0 + assert (cp.iloc[3:11] == 0).values.all() + + result = ser.iloc[2:6] + result2 = ser.loc[3:11] + expected = ser.reindex([4, 6, 8, 10]) + + tm.assert_series_equal(result, expected) + tm.assert_series_equal(result2, expected) + + # non-monotonic, raise KeyError + s2 = ser.iloc[list(range(5)) + list(range(9, 4, -1))] + with pytest.raises(KeyError, match=r"^3$"): + s2.loc[3:11] + with pytest.raises(KeyError, match=r"^3$"): + s2.loc[3:11] = 0 + + def test_loc_getitem_iterator(self, string_series): + idx = iter(string_series.index[:10]) + result = string_series.loc[idx] + tm.assert_series_equal(result, string_series[:10]) + + def test_loc_setitem_boolean(self, string_series): + mask = string_series > string_series.median() + + result = string_series.copy() + result.loc[mask] = 0 + expected = string_series + expected[mask] = 0 + tm.assert_series_equal(result, expected) + + def test_loc_setitem_corner(self, string_series): + inds = list(string_series.index[[5, 8, 12]]) + string_series.loc[inds] = 5 + msg = r"\['foo'\] not in index" + with pytest.raises(KeyError, match=msg): + string_series.loc[inds + ["foo"]] = 5 + + def test_basic_setitem_with_labels(self, datetime_series): + indices = datetime_series.index[[5, 10, 15]] + + cp = datetime_series.copy() + exp = datetime_series.copy() + cp[indices] = 0 + exp.loc[indices] = 0 + tm.assert_series_equal(cp, exp) + + cp = datetime_series.copy() + exp = datetime_series.copy() + cp[indices[0] : indices[2]] = 0 + exp.loc[indices[0] : indices[2]] = 0 + tm.assert_series_equal(cp, exp) + + def test_loc_setitem_listlike_of_ints(self): + # integer indexes, be careful + ser = Series( + np.random.default_rng(2).standard_normal(10), index=list(range(0, 20, 2)) + ) + inds = [0, 4, 6] + arr_inds = np.array([0, 4, 6]) + + cp = ser.copy() + exp = ser.copy() + ser[inds] = 0 + ser.loc[inds] = 0 + tm.assert_series_equal(cp, exp) + + cp = ser.copy() + exp = ser.copy() + ser[arr_inds] = 0 + ser.loc[arr_inds] = 0 + tm.assert_series_equal(cp, exp) + + inds_notfound = [0, 4, 5, 6] + arr_inds_notfound = np.array([0, 4, 5, 6]) + msg = r"\[5\] not in index" + with pytest.raises(KeyError, match=msg): + ser[inds_notfound] = 0 + with pytest.raises(Exception, match=msg): + ser[arr_inds_notfound] = 0 + + def test_loc_setitem_dt64tz_values(self): + # GH#12089 + ser = Series( + date_range("2011-01-01", periods=3, tz="US/Eastern"), + index=["a", "b", "c"], + ) + s2 = ser.copy() + expected = Timestamp("2011-01-03", tz="US/Eastern") + s2.loc["a"] = expected + result = s2.loc["a"] + assert result == expected + + s2 = ser.copy() + s2.iloc[0] = expected + result = s2.iloc[0] + assert result == expected + + s2 = ser.copy() + s2["a"] = expected + result = s2["a"] + assert result == expected + + @pytest.mark.parametrize("array_fn", [np.array, pd.array, list, tuple]) + @pytest.mark.parametrize("size", [0, 4, 5, 6]) + def test_loc_iloc_setitem_with_listlike(self, size, array_fn): + # GH37748 + # testing insertion, in a Series of size N (here 5), of a listlike object + # of size 0, N-1, N, N+1 + + arr = array_fn([0] * size) + expected = Series([arr, 0, 0, 0, 0], index=list("abcde"), dtype=object) + + ser = Series(0, index=list("abcde"), dtype=object) + ser.loc["a"] = arr + tm.assert_series_equal(ser, expected) + + ser = Series(0, index=list("abcde"), dtype=object) + ser.iloc[0] = arr + tm.assert_series_equal(ser, expected) + + @pytest.mark.parametrize("indexer", [IndexSlice["A", :], ("A", slice(None))]) + def test_loc_series_getitem_too_many_dimensions(self, indexer): + # GH#35349 + ser = Series( + index=MultiIndex.from_tuples([("A", "0"), ("A", "1"), ("B", "0")]), + data=[21, 22, 23], + ) + msg = "Too many indexers" + with pytest.raises(IndexingError, match=msg): + ser.loc[indexer, :] + + with pytest.raises(IndexingError, match=msg): + ser.loc[indexer, :] = 1 + + def test_loc_setitem(self, string_series): + inds = string_series.index[[3, 4, 7]] + + result = string_series.copy() + result.loc[inds] = 5 + + expected = string_series.copy() + expected.iloc[[3, 4, 7]] = 5 + tm.assert_series_equal(result, expected) + + result.iloc[5:10] = 10 + expected[5:10] = 10 + tm.assert_series_equal(result, expected) + + # set slice with indices + d1, d2 = string_series.index[[5, 15]] + result.loc[d1:d2] = 6 + expected[5:16] = 6 # because it's inclusive + tm.assert_series_equal(result, expected) + + # set index value + string_series.loc[d1] = 4 + string_series.loc[d2] = 6 + assert string_series[d1] == 4 + assert string_series[d2] == 6 + + @pytest.mark.parametrize("dtype", ["object", "string"]) + def test_loc_assign_dict_to_row(self, dtype): + # GH41044 + df = DataFrame({"A": ["abc", "def"], "B": ["ghi", "jkl"]}, dtype=dtype) + df.loc[0, :] = {"A": "newA", "B": "newB"} + + expected = DataFrame({"A": ["newA", "def"], "B": ["newB", "jkl"]}, dtype=dtype) + + tm.assert_frame_equal(df, expected) + + @td.skip_array_manager_invalid_test + def test_loc_setitem_dict_timedelta_multiple_set(self): + # GH 16309 + result = DataFrame(columns=["time", "value"]) + result.loc[1] = {"time": Timedelta(6, unit="s"), "value": "foo"} + result.loc[1] = {"time": Timedelta(6, unit="s"), "value": "foo"} + expected = DataFrame( + [[Timedelta(6, unit="s"), "foo"]], columns=["time", "value"], index=[1] + ) + tm.assert_frame_equal(result, expected) + + def test_loc_set_multiple_items_in_multiple_new_columns(self): + # GH 25594 + df = DataFrame(index=[1, 2], columns=["a"]) + df.loc[1, ["b", "c"]] = [6, 7] + + expected = DataFrame( + { + "a": Series([np.nan, np.nan], dtype="object"), + "b": [6, np.nan], + "c": [7, np.nan], + }, + index=[1, 2], + ) + + tm.assert_frame_equal(df, expected) + + def test_getitem_loc_str_periodindex(self): + # GH#33964 + msg = "Period with BDay freq is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + index = pd.period_range(start="2000", periods=20, freq="B") + series = Series(range(20), index=index) + assert series.loc["2000-01-14"] == 9 + + def test_loc_nonunique_masked_index(self): + # GH 57027 + ids = list(range(11)) + index = Index(ids * 1000, dtype="Int64") + df = DataFrame({"val": np.arange(len(index), dtype=np.intp)}, index=index) + result = df.loc[ids] + expected = DataFrame( + {"val": index.argsort(kind="stable").astype(np.intp)}, + index=Index(np.array(ids).repeat(1000), dtype="Int64"), + ) + tm.assert_frame_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/test_na_indexing.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/test_na_indexing.py new file mode 100644 index 0000000000000000000000000000000000000000..5364cfe85243001040bf40c8b72b4f71808c3d9c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/test_na_indexing.py @@ -0,0 +1,75 @@ +import pytest + +import pandas as pd +import pandas._testing as tm + + +@pytest.mark.parametrize( + "values, dtype", + [ + ([], "object"), + ([1, 2, 3], "int64"), + ([1.0, 2.0, 3.0], "float64"), + (["a", "b", "c"], "object"), + (["a", "b", "c"], "string"), + ([1, 2, 3], "datetime64[ns]"), + ([1, 2, 3], "datetime64[ns, CET]"), + ([1, 2, 3], "timedelta64[ns]"), + (["2000", "2001", "2002"], "Period[D]"), + ([1, 0, 3], "Sparse"), + ([pd.Interval(0, 1), pd.Interval(1, 2), pd.Interval(3, 4)], "interval"), + ], +) +@pytest.mark.parametrize( + "mask", [[True, False, False], [True, True, True], [False, False, False]] +) +@pytest.mark.parametrize("indexer_class", [list, pd.array, pd.Index, pd.Series]) +@pytest.mark.parametrize("frame", [True, False]) +def test_series_mask_boolean(values, dtype, mask, indexer_class, frame): + # In case len(values) < 3 + index = ["a", "b", "c"][: len(values)] + mask = mask[: len(values)] + + obj = pd.Series(values, dtype=dtype, index=index) + if frame: + if len(values) == 0: + # Otherwise obj is an empty DataFrame with shape (0, 1) + obj = pd.DataFrame(dtype=dtype, index=index) + else: + obj = obj.to_frame() + + if indexer_class is pd.array: + mask = pd.array(mask, dtype="boolean") + elif indexer_class is pd.Series: + mask = pd.Series(mask, index=obj.index, dtype="boolean") + else: + mask = indexer_class(mask) + + expected = obj[mask] + + result = obj[mask] + tm.assert_equal(result, expected) + + if indexer_class is pd.Series: + msg = "iLocation based boolean indexing cannot use an indexable as a mask" + with pytest.raises(ValueError, match=msg): + result = obj.iloc[mask] + tm.assert_equal(result, expected) + else: + result = obj.iloc[mask] + tm.assert_equal(result, expected) + + result = obj.loc[mask] + tm.assert_equal(result, expected) + + +def test_na_treated_as_false(frame_or_series, indexer_sli): + # https://github.com/pandas-dev/pandas/issues/31503 + obj = frame_or_series([1, 2, 3]) + + mask = pd.array([True, False, None], dtype="boolean") + + result = indexer_sli(obj)[mask] + expected = indexer_sli(obj)[mask.fillna(False)] + + tm.assert_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/test_partial.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/test_partial.py new file mode 100644 index 0000000000000000000000000000000000000000..ca551024b4c1fe528a14fece4213a8aaad4a98d1 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/test_partial.py @@ -0,0 +1,702 @@ +""" +test setting *parts* of objects both positionally and label based + +TODO: these should be split among the indexer tests +""" + +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + DataFrame, + Index, + Period, + Series, + Timestamp, + date_range, + period_range, +) +import pandas._testing as tm + + +class TestEmptyFrameSetitemExpansion: + def test_empty_frame_setitem_index_name_retained(self): + # GH#31368 empty frame has non-None index.name -> retained + df = DataFrame({}, index=pd.RangeIndex(0, name="df_index")) + series = Series(1.23, index=pd.RangeIndex(4, name="series_index")) + + df["series"] = series + expected = DataFrame( + {"series": [1.23] * 4}, + index=pd.RangeIndex(4, name="df_index"), + columns=Index(["series"], dtype=object), + ) + + tm.assert_frame_equal(df, expected) + + def test_empty_frame_setitem_index_name_inherited(self): + # GH#36527 empty frame has None index.name -> not retained + df = DataFrame() + series = Series(1.23, index=pd.RangeIndex(4, name="series_index")) + df["series"] = series + expected = DataFrame( + {"series": [1.23] * 4}, + index=pd.RangeIndex(4, name="series_index"), + columns=Index(["series"], dtype=object), + ) + tm.assert_frame_equal(df, expected) + + def test_loc_setitem_zerolen_series_columns_align(self): + # columns will align + df = DataFrame(columns=["A", "B"]) + df.loc[0] = Series(1, index=range(4)) + expected = DataFrame(columns=["A", "B"], index=[0], dtype=np.float64) + tm.assert_frame_equal(df, expected) + + # columns will align + df = DataFrame(columns=["A", "B"]) + df.loc[0] = Series(1, index=["B"]) + + exp = DataFrame([[np.nan, 1]], columns=["A", "B"], index=[0], dtype="float64") + tm.assert_frame_equal(df, exp) + + def test_loc_setitem_zerolen_list_length_must_match_columns(self): + # list-like must conform + df = DataFrame(columns=["A", "B"]) + + msg = "cannot set a row with mismatched columns" + with pytest.raises(ValueError, match=msg): + df.loc[0] = [1, 2, 3] + + df = DataFrame(columns=["A", "B"]) + df.loc[3] = [6, 7] # length matches len(df.columns) --> OK! + + exp = DataFrame([[6, 7]], index=[3], columns=["A", "B"], dtype=np.int64) + tm.assert_frame_equal(df, exp) + + def test_partial_set_empty_frame(self): + # partially set with an empty object + # frame + df = DataFrame() + + msg = "cannot set a frame with no defined columns" + + with pytest.raises(ValueError, match=msg): + df.loc[1] = 1 + + with pytest.raises(ValueError, match=msg): + df.loc[1] = Series([1], index=["foo"]) + + msg = "cannot set a frame with no defined index and a scalar" + with pytest.raises(ValueError, match=msg): + df.loc[:, 1] = 1 + + def test_partial_set_empty_frame2(self): + # these work as they don't really change + # anything but the index + # GH#5632 + expected = DataFrame( + columns=Index(["foo"], dtype=object), index=Index([], dtype="object") + ) + + df = DataFrame(index=Index([], dtype="object")) + df["foo"] = Series([], dtype="object") + + tm.assert_frame_equal(df, expected) + + df = DataFrame(index=Index([])) + df["foo"] = Series(df.index) + + tm.assert_frame_equal(df, expected) + + df = DataFrame(index=Index([])) + df["foo"] = df.index + + tm.assert_frame_equal(df, expected) + + def test_partial_set_empty_frame3(self): + expected = DataFrame( + columns=Index(["foo"], dtype=object), index=Index([], dtype="int64") + ) + expected["foo"] = expected["foo"].astype("float64") + + df = DataFrame(index=Index([], dtype="int64")) + df["foo"] = [] + + tm.assert_frame_equal(df, expected) + + df = DataFrame(index=Index([], dtype="int64")) + df["foo"] = Series(np.arange(len(df)), dtype="float64") + + tm.assert_frame_equal(df, expected) + + def test_partial_set_empty_frame4(self): + df = DataFrame(index=Index([], dtype="int64")) + df["foo"] = range(len(df)) + + expected = DataFrame( + columns=Index(["foo"], dtype=object), index=Index([], dtype="int64") + ) + # range is int-dtype-like, so we get int64 dtype + expected["foo"] = expected["foo"].astype("int64") + tm.assert_frame_equal(df, expected) + + def test_partial_set_empty_frame5(self): + df = DataFrame() + tm.assert_index_equal(df.columns, pd.RangeIndex(0)) + df2 = DataFrame() + df2[1] = Series([1], index=["foo"]) + df.loc[:, 1] = Series([1], index=["foo"]) + tm.assert_frame_equal(df, DataFrame([[1]], index=["foo"], columns=[1])) + tm.assert_frame_equal(df, df2) + + def test_partial_set_empty_frame_no_index(self): + # no index to start + expected = DataFrame({0: Series(1, index=range(4))}, columns=["A", "B", 0]) + + df = DataFrame(columns=["A", "B"]) + df[0] = Series(1, index=range(4)) + tm.assert_frame_equal(df, expected) + + df = DataFrame(columns=["A", "B"]) + df.loc[:, 0] = Series(1, index=range(4)) + tm.assert_frame_equal(df, expected) + + def test_partial_set_empty_frame_row(self): + # GH#5720, GH#5744 + # don't create rows when empty + expected = DataFrame(columns=["A", "B", "New"], index=Index([], dtype="int64")) + expected["A"] = expected["A"].astype("int64") + expected["B"] = expected["B"].astype("float64") + expected["New"] = expected["New"].astype("float64") + + df = DataFrame({"A": [1, 2, 3], "B": [1.2, 4.2, 5.2]}) + y = df[df.A > 5] + y["New"] = np.nan + tm.assert_frame_equal(y, expected) + + expected = DataFrame(columns=["a", "b", "c c", "d"]) + expected["d"] = expected["d"].astype("int64") + df = DataFrame(columns=["a", "b", "c c"]) + df["d"] = 3 + tm.assert_frame_equal(df, expected) + tm.assert_series_equal(df["c c"], Series(name="c c", dtype=object)) + + # reindex columns is ok + df = DataFrame({"A": [1, 2, 3], "B": [1.2, 4.2, 5.2]}) + y = df[df.A > 5] + result = y.reindex(columns=["A", "B", "C"]) + expected = DataFrame(columns=["A", "B", "C"]) + expected["A"] = expected["A"].astype("int64") + expected["B"] = expected["B"].astype("float64") + expected["C"] = expected["C"].astype("float64") + tm.assert_frame_equal(result, expected) + + def test_partial_set_empty_frame_set_series(self): + # GH#5756 + # setting with empty Series + df = DataFrame(Series(dtype=object)) + expected = DataFrame({0: Series(dtype=object)}) + tm.assert_frame_equal(df, expected) + + df = DataFrame(Series(name="foo", dtype=object)) + expected = DataFrame({"foo": Series(dtype=object)}) + tm.assert_frame_equal(df, expected) + + def test_partial_set_empty_frame_empty_copy_assignment(self): + # GH#5932 + # copy on empty with assignment fails + df = DataFrame(index=[0]) + df = df.copy() + df["a"] = 0 + expected = DataFrame(0, index=[0], columns=Index(["a"], dtype=object)) + tm.assert_frame_equal(df, expected) + + def test_partial_set_empty_frame_empty_consistencies(self, using_infer_string): + # GH#6171 + # consistency on empty frames + df = DataFrame(columns=["x", "y"]) + df["x"] = [1, 2] + expected = DataFrame({"x": [1, 2], "y": [np.nan, np.nan]}) + tm.assert_frame_equal(df, expected, check_dtype=False) + + df = DataFrame(columns=["x", "y"]) + df["x"] = ["1", "2"] + expected = DataFrame( + { + "x": Series( + ["1", "2"], + dtype=object if not using_infer_string else "string[pyarrow_numpy]", + ), + "y": Series([np.nan, np.nan], dtype=object), + } + ) + tm.assert_frame_equal(df, expected) + + df = DataFrame(columns=["x", "y"]) + df.loc[0, "x"] = 1 + expected = DataFrame({"x": [1], "y": [np.nan]}) + tm.assert_frame_equal(df, expected, check_dtype=False) + + +class TestPartialSetting: + def test_partial_setting(self): + # GH2578, allow ix and friends to partially set + + # series + s_orig = Series([1, 2, 3]) + + s = s_orig.copy() + s[5] = 5 + expected = Series([1, 2, 3, 5], index=[0, 1, 2, 5]) + tm.assert_series_equal(s, expected) + + s = s_orig.copy() + s.loc[5] = 5 + expected = Series([1, 2, 3, 5], index=[0, 1, 2, 5]) + tm.assert_series_equal(s, expected) + + s = s_orig.copy() + s[5] = 5.0 + expected = Series([1, 2, 3, 5.0], index=[0, 1, 2, 5]) + tm.assert_series_equal(s, expected) + + s = s_orig.copy() + s.loc[5] = 5.0 + expected = Series([1, 2, 3, 5.0], index=[0, 1, 2, 5]) + tm.assert_series_equal(s, expected) + + # iloc/iat raise + s = s_orig.copy() + + msg = "iloc cannot enlarge its target object" + with pytest.raises(IndexError, match=msg): + s.iloc[3] = 5.0 + + msg = "index 3 is out of bounds for axis 0 with size 3" + with pytest.raises(IndexError, match=msg): + s.iat[3] = 5.0 + + @pytest.mark.filterwarnings("ignore:Setting a value on a view:FutureWarning") + def test_partial_setting_frame(self, using_array_manager): + df_orig = DataFrame( + np.arange(6).reshape(3, 2), columns=["A", "B"], dtype="int64" + ) + + # iloc/iat raise + df = df_orig.copy() + + msg = "iloc cannot enlarge its target object" + with pytest.raises(IndexError, match=msg): + df.iloc[4, 2] = 5.0 + + msg = "index 2 is out of bounds for axis 0 with size 2" + if using_array_manager: + msg = "list index out of range" + with pytest.raises(IndexError, match=msg): + df.iat[4, 2] = 5.0 + + # row setting where it exists + expected = DataFrame({"A": [0, 4, 4], "B": [1, 5, 5]}) + df = df_orig.copy() + df.iloc[1] = df.iloc[2] + tm.assert_frame_equal(df, expected) + + expected = DataFrame({"A": [0, 4, 4], "B": [1, 5, 5]}) + df = df_orig.copy() + df.loc[1] = df.loc[2] + tm.assert_frame_equal(df, expected) + + # like 2578, partial setting with dtype preservation + expected = DataFrame({"A": [0, 2, 4, 4], "B": [1, 3, 5, 5]}) + df = df_orig.copy() + df.loc[3] = df.loc[2] + tm.assert_frame_equal(df, expected) + + # single dtype frame, overwrite + expected = DataFrame({"A": [0, 2, 4], "B": [0, 2, 4]}) + df = df_orig.copy() + df.loc[:, "B"] = df.loc[:, "A"] + tm.assert_frame_equal(df, expected) + + # mixed dtype frame, overwrite + expected = DataFrame({"A": [0, 2, 4], "B": Series([0.0, 2.0, 4.0])}) + df = df_orig.copy() + df["B"] = df["B"].astype(np.float64) + # as of 2.0, df.loc[:, "B"] = ... attempts (and here succeeds) at + # setting inplace + df.loc[:, "B"] = df.loc[:, "A"] + tm.assert_frame_equal(df, expected) + + # single dtype frame, partial setting + expected = df_orig.copy() + expected["C"] = df["A"] + df = df_orig.copy() + df.loc[:, "C"] = df.loc[:, "A"] + tm.assert_frame_equal(df, expected) + + # mixed frame, partial setting + expected = df_orig.copy() + expected["C"] = df["A"] + df = df_orig.copy() + df.loc[:, "C"] = df.loc[:, "A"] + tm.assert_frame_equal(df, expected) + + def test_partial_setting2(self): + # GH 8473 + dates = date_range("1/1/2000", periods=8) + df_orig = DataFrame( + np.random.default_rng(2).standard_normal((8, 4)), + index=dates, + columns=["A", "B", "C", "D"], + ) + + expected = pd.concat( + [df_orig, DataFrame({"A": 7}, index=dates[-1:] + dates.freq)], sort=True + ) + df = df_orig.copy() + df.loc[dates[-1] + dates.freq, "A"] = 7 + tm.assert_frame_equal(df, expected) + df = df_orig.copy() + df.at[dates[-1] + dates.freq, "A"] = 7 + tm.assert_frame_equal(df, expected) + + exp_other = DataFrame({0: 7}, index=dates[-1:] + dates.freq) + expected = pd.concat([df_orig, exp_other], axis=1) + + df = df_orig.copy() + df.loc[dates[-1] + dates.freq, 0] = 7 + tm.assert_frame_equal(df, expected) + df = df_orig.copy() + df.at[dates[-1] + dates.freq, 0] = 7 + tm.assert_frame_equal(df, expected) + + def test_partial_setting_mixed_dtype(self): + # in a mixed dtype environment, try to preserve dtypes + # by appending + df = DataFrame([[True, 1], [False, 2]], columns=["female", "fitness"]) + + s = df.loc[1].copy() + s.name = 2 + expected = pd.concat([df, DataFrame(s).T.infer_objects()]) + + df.loc[2] = df.loc[1] + tm.assert_frame_equal(df, expected) + + def test_series_partial_set(self): + # partial set with new index + # Regression from GH4825 + ser = Series([0.1, 0.2], index=[1, 2]) + + # loc equiv to .reindex + expected = Series([np.nan, 0.2, np.nan], index=[3, 2, 3]) + with pytest.raises(KeyError, match=r"not in index"): + ser.loc[[3, 2, 3]] + + result = ser.reindex([3, 2, 3]) + tm.assert_series_equal(result, expected, check_index_type=True) + + expected = Series([np.nan, 0.2, np.nan, np.nan], index=[3, 2, 3, "x"]) + with pytest.raises(KeyError, match="not in index"): + ser.loc[[3, 2, 3, "x"]] + + result = ser.reindex([3, 2, 3, "x"]) + tm.assert_series_equal(result, expected, check_index_type=True) + + expected = Series([0.2, 0.2, 0.1], index=[2, 2, 1]) + result = ser.loc[[2, 2, 1]] + tm.assert_series_equal(result, expected, check_index_type=True) + + expected = Series([0.2, 0.2, np.nan, 0.1], index=[2, 2, "x", 1]) + with pytest.raises(KeyError, match="not in index"): + ser.loc[[2, 2, "x", 1]] + + result = ser.reindex([2, 2, "x", 1]) + tm.assert_series_equal(result, expected, check_index_type=True) + + # raises as nothing is in the index + msg = ( + rf"\"None of \[Index\(\[3, 3, 3\], dtype='{np.dtype(int)}'\)\] " + r"are in the \[index\]\"" + ) + with pytest.raises(KeyError, match=msg): + ser.loc[[3, 3, 3]] + + expected = Series([0.2, 0.2, np.nan], index=[2, 2, 3]) + with pytest.raises(KeyError, match="not in index"): + ser.loc[[2, 2, 3]] + + result = ser.reindex([2, 2, 3]) + tm.assert_series_equal(result, expected, check_index_type=True) + + s = Series([0.1, 0.2, 0.3], index=[1, 2, 3]) + expected = Series([0.3, np.nan, np.nan], index=[3, 4, 4]) + with pytest.raises(KeyError, match="not in index"): + s.loc[[3, 4, 4]] + + result = s.reindex([3, 4, 4]) + tm.assert_series_equal(result, expected, check_index_type=True) + + s = Series([0.1, 0.2, 0.3, 0.4], index=[1, 2, 3, 4]) + expected = Series([np.nan, 0.3, 0.3], index=[5, 3, 3]) + with pytest.raises(KeyError, match="not in index"): + s.loc[[5, 3, 3]] + + result = s.reindex([5, 3, 3]) + tm.assert_series_equal(result, expected, check_index_type=True) + + s = Series([0.1, 0.2, 0.3, 0.4], index=[1, 2, 3, 4]) + expected = Series([np.nan, 0.4, 0.4], index=[5, 4, 4]) + with pytest.raises(KeyError, match="not in index"): + s.loc[[5, 4, 4]] + + result = s.reindex([5, 4, 4]) + tm.assert_series_equal(result, expected, check_index_type=True) + + s = Series([0.1, 0.2, 0.3, 0.4], index=[4, 5, 6, 7]) + expected = Series([0.4, np.nan, np.nan], index=[7, 2, 2]) + with pytest.raises(KeyError, match="not in index"): + s.loc[[7, 2, 2]] + + result = s.reindex([7, 2, 2]) + tm.assert_series_equal(result, expected, check_index_type=True) + + s = Series([0.1, 0.2, 0.3, 0.4], index=[1, 2, 3, 4]) + expected = Series([0.4, np.nan, np.nan], index=[4, 5, 5]) + with pytest.raises(KeyError, match="not in index"): + s.loc[[4, 5, 5]] + + result = s.reindex([4, 5, 5]) + tm.assert_series_equal(result, expected, check_index_type=True) + + # iloc + expected = Series([0.2, 0.2, 0.1, 0.1], index=[2, 2, 1, 1]) + result = ser.iloc[[1, 1, 0, 0]] + tm.assert_series_equal(result, expected, check_index_type=True) + + def test_series_partial_set_with_name(self): + # GH 11497 + + idx = Index([1, 2], dtype="int64", name="idx") + ser = Series([0.1, 0.2], index=idx, name="s") + + # loc + with pytest.raises(KeyError, match=r"\[3\] not in index"): + ser.loc[[3, 2, 3]] + + with pytest.raises(KeyError, match=r"not in index"): + ser.loc[[3, 2, 3, "x"]] + + exp_idx = Index([2, 2, 1], dtype="int64", name="idx") + expected = Series([0.2, 0.2, 0.1], index=exp_idx, name="s") + result = ser.loc[[2, 2, 1]] + tm.assert_series_equal(result, expected, check_index_type=True) + + with pytest.raises(KeyError, match=r"\['x'\] not in index"): + ser.loc[[2, 2, "x", 1]] + + # raises as nothing is in the index + msg = ( + rf"\"None of \[Index\(\[3, 3, 3\], dtype='{np.dtype(int)}', " + r"name='idx'\)\] are in the \[index\]\"" + ) + with pytest.raises(KeyError, match=msg): + ser.loc[[3, 3, 3]] + + with pytest.raises(KeyError, match="not in index"): + ser.loc[[2, 2, 3]] + + idx = Index([1, 2, 3], dtype="int64", name="idx") + with pytest.raises(KeyError, match="not in index"): + Series([0.1, 0.2, 0.3], index=idx, name="s").loc[[3, 4, 4]] + + idx = Index([1, 2, 3, 4], dtype="int64", name="idx") + with pytest.raises(KeyError, match="not in index"): + Series([0.1, 0.2, 0.3, 0.4], index=idx, name="s").loc[[5, 3, 3]] + + idx = Index([1, 2, 3, 4], dtype="int64", name="idx") + with pytest.raises(KeyError, match="not in index"): + Series([0.1, 0.2, 0.3, 0.4], index=idx, name="s").loc[[5, 4, 4]] + + idx = Index([4, 5, 6, 7], dtype="int64", name="idx") + with pytest.raises(KeyError, match="not in index"): + Series([0.1, 0.2, 0.3, 0.4], index=idx, name="s").loc[[7, 2, 2]] + + idx = Index([1, 2, 3, 4], dtype="int64", name="idx") + with pytest.raises(KeyError, match="not in index"): + Series([0.1, 0.2, 0.3, 0.4], index=idx, name="s").loc[[4, 5, 5]] + + # iloc + exp_idx = Index([2, 2, 1, 1], dtype="int64", name="idx") + expected = Series([0.2, 0.2, 0.1, 0.1], index=exp_idx, name="s") + result = ser.iloc[[1, 1, 0, 0]] + tm.assert_series_equal(result, expected, check_index_type=True) + + @pytest.mark.parametrize("key", [100, 100.0]) + def test_setitem_with_expansion_numeric_into_datetimeindex(self, key): + # GH#4940 inserting non-strings + orig = DataFrame( + np.random.default_rng(2).standard_normal((10, 4)), + columns=Index(list("ABCD"), dtype=object), + index=date_range("2000-01-01", periods=10, freq="B"), + ) + df = orig.copy() + + df.loc[key, :] = df.iloc[0] + ex_index = Index(list(orig.index) + [key], dtype=object, name=orig.index.name) + ex_data = np.concatenate([orig.values, df.iloc[[0]].values], axis=0) + expected = DataFrame(ex_data, index=ex_index, columns=orig.columns) + + tm.assert_frame_equal(df, expected) + + def test_partial_set_invalid(self): + # GH 4940 + # allow only setting of 'valid' values + + orig = DataFrame( + np.random.default_rng(2).standard_normal((10, 4)), + columns=Index(list("ABCD"), dtype=object), + index=date_range("2000-01-01", periods=10, freq="B"), + ) + + # allow object conversion here + df = orig.copy() + df.loc["a", :] = df.iloc[0] + ser = Series(df.iloc[0], name="a") + exp = pd.concat([orig, DataFrame(ser).T.infer_objects()]) + tm.assert_frame_equal(df, exp) + tm.assert_index_equal(df.index, Index(orig.index.tolist() + ["a"])) + assert df.index.dtype == "object" + + @pytest.mark.parametrize( + "idx,labels,expected_idx", + [ + ( + period_range(start="2000", periods=20, freq="D"), + ["2000-01-04", "2000-01-08", "2000-01-12"], + [ + Period("2000-01-04", freq="D"), + Period("2000-01-08", freq="D"), + Period("2000-01-12", freq="D"), + ], + ), + ( + date_range(start="2000", periods=20, freq="D"), + ["2000-01-04", "2000-01-08", "2000-01-12"], + [ + Timestamp("2000-01-04"), + Timestamp("2000-01-08"), + Timestamp("2000-01-12"), + ], + ), + ( + pd.timedelta_range(start="1 day", periods=20), + ["4D", "8D", "12D"], + [pd.Timedelta("4 day"), pd.Timedelta("8 day"), pd.Timedelta("12 day")], + ), + ], + ) + def test_loc_with_list_of_strings_representing_datetimes( + self, idx, labels, expected_idx, frame_or_series + ): + # GH 11278 + obj = frame_or_series(range(20), index=idx) + + expected_value = [3, 7, 11] + expected = frame_or_series(expected_value, expected_idx) + + tm.assert_equal(expected, obj.loc[labels]) + if frame_or_series is Series: + tm.assert_series_equal(expected, obj[labels]) + + @pytest.mark.parametrize( + "idx,labels", + [ + ( + period_range(start="2000", periods=20, freq="D"), + ["2000-01-04", "2000-01-30"], + ), + ( + date_range(start="2000", periods=20, freq="D"), + ["2000-01-04", "2000-01-30"], + ), + (pd.timedelta_range(start="1 day", periods=20), ["3 day", "30 day"]), + ], + ) + def test_loc_with_list_of_strings_representing_datetimes_missing_value( + self, idx, labels + ): + # GH 11278 + ser = Series(range(20), index=idx) + df = DataFrame(range(20), index=idx) + msg = r"not in index" + + with pytest.raises(KeyError, match=msg): + ser.loc[labels] + with pytest.raises(KeyError, match=msg): + ser[labels] + with pytest.raises(KeyError, match=msg): + df.loc[labels] + + @pytest.mark.parametrize( + "idx,labels,msg", + [ + ( + period_range(start="2000", periods=20, freq="D"), + Index(["4D", "8D"], dtype=object), + ( + r"None of \[Index\(\['4D', '8D'\], dtype='object'\)\] " + r"are in the \[index\]" + ), + ), + ( + date_range(start="2000", periods=20, freq="D"), + Index(["4D", "8D"], dtype=object), + ( + r"None of \[Index\(\['4D', '8D'\], dtype='object'\)\] " + r"are in the \[index\]" + ), + ), + ( + pd.timedelta_range(start="1 day", periods=20), + Index(["2000-01-04", "2000-01-08"], dtype=object), + ( + r"None of \[Index\(\['2000-01-04', '2000-01-08'\], " + r"dtype='object'\)\] are in the \[index\]" + ), + ), + ], + ) + def test_loc_with_list_of_strings_representing_datetimes_not_matched_type( + self, idx, labels, msg + ): + # GH 11278 + ser = Series(range(20), index=idx) + df = DataFrame(range(20), index=idx) + + with pytest.raises(KeyError, match=msg): + ser.loc[labels] + with pytest.raises(KeyError, match=msg): + ser[labels] + with pytest.raises(KeyError, match=msg): + df.loc[labels] + + +class TestStringSlicing: + def test_slice_irregular_datetime_index_with_nan(self): + # GH36953 + index = pd.to_datetime(["2012-01-01", "2012-01-02", "2012-01-03", None]) + df = DataFrame(range(len(index)), index=index) + expected = DataFrame(range(len(index[:3])), index=index[:3]) + with pytest.raises(KeyError, match="non-existing keys is not allowed"): + # Upper bound is not in index (which is unordered) + # GH53983 + # GH37819 + df["2012-01-01":"2012-01-04"] + # Need this precision for right bound since the right slice + # bound is "rounded" up to the largest timepoint smaller than + # the next "resolution"-step of the provided point. + # e.g. 2012-01-03 is rounded up to 2012-01-04 - 1ns + result = df["2012-01-01":"2012-01-03 00:00:00.000000000"] + tm.assert_frame_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/test_scalar.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/test_scalar.py new file mode 100644 index 0000000000000000000000000000000000000000..29e3dc0aebe9551ae94566904372dde3563fbef9 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/indexing/test_scalar.py @@ -0,0 +1,303 @@ +""" test scalar indexing, including at and iat """ +from datetime import ( + datetime, + timedelta, +) +import itertools + +import numpy as np +import pytest + +from pandas import ( + DataFrame, + Series, + Timedelta, + Timestamp, + date_range, +) +import pandas._testing as tm + + +def generate_indices(f, values=False): + """ + generate the indices + if values is True , use the axis values + is False, use the range + """ + axes = f.axes + if values: + axes = (list(range(len(ax))) for ax in axes) + + return itertools.product(*axes) + + +class TestScalar: + @pytest.mark.parametrize("kind", ["series", "frame"]) + @pytest.mark.parametrize("col", ["ints", "uints"]) + def test_iat_set_ints(self, kind, col, request): + f = request.getfixturevalue(f"{kind}_{col}") + indices = generate_indices(f, True) + for i in indices: + f.iat[i] = 1 + expected = f.values[i] + tm.assert_almost_equal(expected, 1) + + @pytest.mark.parametrize("kind", ["series", "frame"]) + @pytest.mark.parametrize("col", ["labels", "ts", "floats"]) + def test_iat_set_other(self, kind, col, request): + f = request.getfixturevalue(f"{kind}_{col}") + msg = "iAt based indexing can only have integer indexers" + with pytest.raises(ValueError, match=msg): + idx = next(generate_indices(f, False)) + f.iat[idx] = 1 + + @pytest.mark.parametrize("kind", ["series", "frame"]) + @pytest.mark.parametrize("col", ["ints", "uints", "labels", "ts", "floats"]) + def test_at_set_ints_other(self, kind, col, request): + f = request.getfixturevalue(f"{kind}_{col}") + indices = generate_indices(f, False) + for i in indices: + f.at[i] = 1 + expected = f.loc[i] + tm.assert_almost_equal(expected, 1) + + +class TestAtAndiAT: + # at and iat tests that don't need Base class + + def test_float_index_at_iat(self): + ser = Series([1, 2, 3], index=[0.1, 0.2, 0.3]) + for el, item in ser.items(): + assert ser.at[el] == item + for i in range(len(ser)): + assert ser.iat[i] == i + 1 + + def test_at_iat_coercion(self): + # as timestamp is not a tuple! + dates = date_range("1/1/2000", periods=8) + df = DataFrame( + np.random.default_rng(2).standard_normal((8, 4)), + index=dates, + columns=["A", "B", "C", "D"], + ) + s = df["A"] + + result = s.at[dates[5]] + xp = s.values[5] + assert result == xp + + @pytest.mark.parametrize( + "ser, expected", + [ + [ + Series(["2014-01-01", "2014-02-02"], dtype="datetime64[ns]"), + Timestamp("2014-02-02"), + ], + [ + Series(["1 days", "2 days"], dtype="timedelta64[ns]"), + Timedelta("2 days"), + ], + ], + ) + def test_iloc_iat_coercion_datelike(self, indexer_ial, ser, expected): + # GH 7729 + # make sure we are boxing the returns + result = indexer_ial(ser)[1] + assert result == expected + + def test_imethods_with_dups(self): + # GH6493 + # iat/iloc with dups + + s = Series(range(5), index=[1, 1, 2, 2, 3], dtype="int64") + result = s.iloc[2] + assert result == 2 + result = s.iat[2] + assert result == 2 + + msg = "index 10 is out of bounds for axis 0 with size 5" + with pytest.raises(IndexError, match=msg): + s.iat[10] + msg = "index -10 is out of bounds for axis 0 with size 5" + with pytest.raises(IndexError, match=msg): + s.iat[-10] + + result = s.iloc[[2, 3]] + expected = Series([2, 3], [2, 2], dtype="int64") + tm.assert_series_equal(result, expected) + + df = s.to_frame() + result = df.iloc[2] + expected = Series(2, index=[0], name=2) + tm.assert_series_equal(result, expected) + + result = df.iat[2, 0] + assert result == 2 + + def test_frame_at_with_duplicate_axes(self): + # GH#33041 + arr = np.random.default_rng(2).standard_normal(6).reshape(3, 2) + df = DataFrame(arr, columns=["A", "A"]) + + result = df.at[0, "A"] + expected = df.iloc[0].copy() + + tm.assert_series_equal(result, expected) + + result = df.T.at["A", 0] + tm.assert_series_equal(result, expected) + + # setter + df.at[1, "A"] = 2 + expected = Series([2.0, 2.0], index=["A", "A"], name=1) + tm.assert_series_equal(df.iloc[1], expected) + + def test_at_getitem_dt64tz_values(self): + # gh-15822 + df = DataFrame( + { + "name": ["John", "Anderson"], + "date": [ + Timestamp(2017, 3, 13, 13, 32, 56), + Timestamp(2017, 2, 16, 12, 10, 3), + ], + } + ) + df["date"] = df["date"].dt.tz_localize("Asia/Shanghai") + + expected = Timestamp("2017-03-13 13:32:56+0800", tz="Asia/Shanghai") + + result = df.loc[0, "date"] + assert result == expected + + result = df.at[0, "date"] + assert result == expected + + def test_mixed_index_at_iat_loc_iloc_series(self): + # GH 19860 + s = Series([1, 2, 3, 4, 5], index=["a", "b", "c", 1, 2]) + for el, item in s.items(): + assert s.at[el] == s.loc[el] == item + for i in range(len(s)): + assert s.iat[i] == s.iloc[i] == i + 1 + + with pytest.raises(KeyError, match="^4$"): + s.at[4] + with pytest.raises(KeyError, match="^4$"): + s.loc[4] + + def test_mixed_index_at_iat_loc_iloc_dataframe(self): + # GH 19860 + df = DataFrame( + [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]], columns=["a", "b", "c", 1, 2] + ) + for rowIdx, row in df.iterrows(): + for el, item in row.items(): + assert df.at[rowIdx, el] == df.loc[rowIdx, el] == item + + for row in range(2): + for i in range(5): + assert df.iat[row, i] == df.iloc[row, i] == row * 5 + i + + with pytest.raises(KeyError, match="^3$"): + df.at[0, 3] + with pytest.raises(KeyError, match="^3$"): + df.loc[0, 3] + + def test_iat_setter_incompatible_assignment(self): + # GH 23236 + result = DataFrame({"a": [0.0, 1.0], "b": [4, 5]}) + result.iat[0, 0] = None + expected = DataFrame({"a": [None, 1], "b": [4, 5]}) + tm.assert_frame_equal(result, expected) + + +def test_iat_dont_wrap_object_datetimelike(): + # GH#32809 .iat calls go through DataFrame._get_value, should not + # call maybe_box_datetimelike + dti = date_range("2016-01-01", periods=3) + tdi = dti - dti + ser = Series(dti.to_pydatetime(), dtype=object) + ser2 = Series(tdi.to_pytimedelta(), dtype=object) + df = DataFrame({"A": ser, "B": ser2}) + assert (df.dtypes == object).all() + + for result in [df.at[0, "A"], df.iat[0, 0], df.loc[0, "A"], df.iloc[0, 0]]: + assert result is ser[0] + assert isinstance(result, datetime) + assert not isinstance(result, Timestamp) + + for result in [df.at[1, "B"], df.iat[1, 1], df.loc[1, "B"], df.iloc[1, 1]]: + assert result is ser2[1] + assert isinstance(result, timedelta) + assert not isinstance(result, Timedelta) + + +def test_at_with_tuple_index_get(): + # GH 26989 + # DataFrame.at getter works with Index of tuples + df = DataFrame({"a": [1, 2]}, index=[(1, 2), (3, 4)]) + assert df.index.nlevels == 1 + assert df.at[(1, 2), "a"] == 1 + + # Series.at getter works with Index of tuples + series = df["a"] + assert series.index.nlevels == 1 + assert series.at[(1, 2)] == 1 + + +@pytest.mark.filterwarnings("ignore:Setting a value on a view:FutureWarning") +def test_at_with_tuple_index_set(): + # GH 26989 + # DataFrame.at setter works with Index of tuples + df = DataFrame({"a": [1, 2]}, index=[(1, 2), (3, 4)]) + assert df.index.nlevels == 1 + df.at[(1, 2), "a"] = 2 + assert df.at[(1, 2), "a"] == 2 + + # Series.at setter works with Index of tuples + series = df["a"] + assert series.index.nlevels == 1 + series.at[1, 2] = 3 + assert series.at[1, 2] == 3 + + +class TestMultiIndexScalar: + def test_multiindex_at_get(self): + # GH 26989 + # DataFrame.at and DataFrame.loc getter works with MultiIndex + df = DataFrame({"a": [1, 2]}, index=[[1, 2], [3, 4]]) + assert df.index.nlevels == 2 + assert df.at[(1, 3), "a"] == 1 + assert df.loc[(1, 3), "a"] == 1 + + # Series.at and Series.loc getter works with MultiIndex + series = df["a"] + assert series.index.nlevels == 2 + assert series.at[1, 3] == 1 + assert series.loc[1, 3] == 1 + + @pytest.mark.filterwarnings("ignore:Setting a value on a view:FutureWarning") + def test_multiindex_at_set(self): + # GH 26989 + # DataFrame.at and DataFrame.loc setter works with MultiIndex + df = DataFrame({"a": [1, 2]}, index=[[1, 2], [3, 4]]) + assert df.index.nlevels == 2 + df.at[(1, 3), "a"] = 3 + assert df.at[(1, 3), "a"] == 3 + df.loc[(1, 3), "a"] = 4 + assert df.loc[(1, 3), "a"] == 4 + + # Series.at and Series.loc setter works with MultiIndex + series = df["a"] + assert series.index.nlevels == 2 + series.at[1, 3] = 5 + assert series.at[1, 3] == 5 + series.loc[1, 3] = 6 + assert series.loc[1, 3] == 6 + + def test_multiindex_at_get_one_level(self): + # GH#38053 + s2 = Series((0, 1), index=[[False, True]]) + result = s2.at[False] + assert result == 0 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/interchange/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/interchange/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/interchange/test_impl.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/interchange/test_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..25418b8bb2b37d3241ffe0d066f8877db80dded5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/interchange/test_impl.py @@ -0,0 +1,604 @@ +from datetime import ( + datetime, + timezone, +) + +import numpy as np +import pytest + +from pandas._libs.tslibs import iNaT +from pandas.compat import ( + is_ci_environment, + is_platform_windows, +) +from pandas.compat.numpy import np_version_lt1p23 + +import pandas as pd +import pandas._testing as tm +from pandas.core.interchange.column import PandasColumn +from pandas.core.interchange.dataframe_protocol import ( + ColumnNullType, + DtypeKind, +) +from pandas.core.interchange.from_dataframe import from_dataframe +from pandas.core.interchange.utils import ArrowCTypes + + +@pytest.fixture +def data_categorical(): + return { + "ordered": pd.Categorical(list("testdata") * 30, ordered=True), + "unordered": pd.Categorical(list("testdata") * 30, ordered=False), + } + + +@pytest.fixture +def string_data(): + return { + "separator data": [ + "abC|DeF,Hik", + "234,3245.67", + "gSaf,qWer|Gre", + "asd3,4sad|", + np.nan, + ] + } + + +@pytest.mark.parametrize("data", [("ordered", True), ("unordered", False)]) +def test_categorical_dtype(data, data_categorical): + df = pd.DataFrame({"A": (data_categorical[data[0]])}) + + col = df.__dataframe__().get_column_by_name("A") + assert col.dtype[0] == DtypeKind.CATEGORICAL + assert col.null_count == 0 + assert col.describe_null == (ColumnNullType.USE_SENTINEL, -1) + assert col.num_chunks() == 1 + desc_cat = col.describe_categorical + assert desc_cat["is_ordered"] == data[1] + assert desc_cat["is_dictionary"] is True + assert isinstance(desc_cat["categories"], PandasColumn) + tm.assert_series_equal( + desc_cat["categories"]._col, pd.Series(["a", "d", "e", "s", "t"]) + ) + + tm.assert_frame_equal(df, from_dataframe(df.__dataframe__())) + + +def test_categorical_pyarrow(): + # GH 49889 + pa = pytest.importorskip("pyarrow", "11.0.0") + + arr = ["Mon", "Tue", "Mon", "Wed", "Mon", "Thu", "Fri", "Sat", "Sun"] + table = pa.table({"weekday": pa.array(arr).dictionary_encode()}) + exchange_df = table.__dataframe__() + result = from_dataframe(exchange_df) + weekday = pd.Categorical( + arr, categories=["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"] + ) + expected = pd.DataFrame({"weekday": weekday}) + tm.assert_frame_equal(result, expected) + + +def test_empty_categorical_pyarrow(): + # https://github.com/pandas-dev/pandas/issues/53077 + pa = pytest.importorskip("pyarrow", "11.0.0") + + arr = [None] + table = pa.table({"arr": pa.array(arr, "float64").dictionary_encode()}) + exchange_df = table.__dataframe__() + result = pd.api.interchange.from_dataframe(exchange_df) + expected = pd.DataFrame({"arr": pd.Categorical([np.nan])}) + tm.assert_frame_equal(result, expected) + + +def test_large_string_pyarrow(): + # GH 52795 + pa = pytest.importorskip("pyarrow", "11.0.0") + + arr = ["Mon", "Tue"] + table = pa.table({"weekday": pa.array(arr, "large_string")}) + exchange_df = table.__dataframe__() + result = from_dataframe(exchange_df) + expected = pd.DataFrame({"weekday": ["Mon", "Tue"]}) + tm.assert_frame_equal(result, expected) + + # check round-trip + assert pa.Table.equals(pa.interchange.from_dataframe(result), table) + + +@pytest.mark.parametrize( + ("offset", "length", "expected_values"), + [ + (0, None, [3.3, float("nan"), 2.1]), + (1, None, [float("nan"), 2.1]), + (2, None, [2.1]), + (0, 2, [3.3, float("nan")]), + (0, 1, [3.3]), + (1, 1, [float("nan")]), + ], +) +def test_bitmasks_pyarrow(offset, length, expected_values): + # GH 52795 + pa = pytest.importorskip("pyarrow", "11.0.0") + + arr = [3.3, None, 2.1] + table = pa.table({"arr": arr}).slice(offset, length) + exchange_df = table.__dataframe__() + result = from_dataframe(exchange_df) + expected = pd.DataFrame({"arr": expected_values}) + tm.assert_frame_equal(result, expected) + + # check round-trip + assert pa.Table.equals(pa.interchange.from_dataframe(result), table) + + +@pytest.mark.parametrize( + "data", + [ + lambda: np.random.default_rng(2).integers(-100, 100), + lambda: np.random.default_rng(2).integers(1, 100), + lambda: np.random.default_rng(2).random(), + lambda: np.random.default_rng(2).choice([True, False]), + lambda: datetime( + year=np.random.default_rng(2).integers(1900, 2100), + month=np.random.default_rng(2).integers(1, 12), + day=np.random.default_rng(2).integers(1, 20), + ), + ], +) +def test_dataframe(data): + NCOLS, NROWS = 10, 20 + data = { + f"col{int((i - NCOLS / 2) % NCOLS + 1)}": [data() for _ in range(NROWS)] + for i in range(NCOLS) + } + df = pd.DataFrame(data) + + df2 = df.__dataframe__() + + assert df2.num_columns() == NCOLS + assert df2.num_rows() == NROWS + + assert list(df2.column_names()) == list(data.keys()) + + indices = (0, 2) + names = tuple(list(data.keys())[idx] for idx in indices) + + result = from_dataframe(df2.select_columns(indices)) + expected = from_dataframe(df2.select_columns_by_name(names)) + tm.assert_frame_equal(result, expected) + + assert isinstance(result.attrs["_INTERCHANGE_PROTOCOL_BUFFERS"], list) + assert isinstance(expected.attrs["_INTERCHANGE_PROTOCOL_BUFFERS"], list) + + +def test_missing_from_masked(): + df = pd.DataFrame( + { + "x": np.array([1.0, 2.0, 3.0, 4.0, 0.0]), + "y": np.array([1.5, 2.5, 3.5, 4.5, 0]), + "z": np.array([1.0, 0.0, 1.0, 1.0, 1.0]), + } + ) + + rng = np.random.default_rng(2) + dict_null = {col: rng.integers(low=0, high=len(df)) for col in df.columns} + for col, num_nulls in dict_null.items(): + null_idx = df.index[ + rng.choice(np.arange(len(df)), size=num_nulls, replace=False) + ] + df.loc[null_idx, col] = None + + df2 = df.__dataframe__() + + assert df2.get_column_by_name("x").null_count == dict_null["x"] + assert df2.get_column_by_name("y").null_count == dict_null["y"] + assert df2.get_column_by_name("z").null_count == dict_null["z"] + + +@pytest.mark.parametrize( + "data", + [ + {"x": [1.5, 2.5, 3.5], "y": [9.2, 10.5, 11.8]}, + {"x": [1, 2, 0], "y": [9.2, 10.5, 11.8]}, + { + "x": np.array([True, True, False]), + "y": np.array([1, 2, 0]), + "z": np.array([9.2, 10.5, 11.8]), + }, + ], +) +def test_mixed_data(data): + df = pd.DataFrame(data) + df2 = df.__dataframe__() + + for col_name in df.columns: + assert df2.get_column_by_name(col_name).null_count == 0 + + +def test_mixed_missing(): + df = pd.DataFrame( + { + "x": np.array([True, None, False, None, True]), + "y": np.array([None, 2, None, 1, 2]), + "z": np.array([9.2, 10.5, None, 11.8, None]), + } + ) + + df2 = df.__dataframe__() + + for col_name in df.columns: + assert df2.get_column_by_name(col_name).null_count == 2 + + +def test_string(string_data): + test_str_data = string_data["separator data"] + [""] + df = pd.DataFrame({"A": test_str_data}) + col = df.__dataframe__().get_column_by_name("A") + + assert col.size() == 6 + assert col.null_count == 1 + assert col.dtype[0] == DtypeKind.STRING + assert col.describe_null == (ColumnNullType.USE_BYTEMASK, 0) + + df_sliced = df[1:] + col = df_sliced.__dataframe__().get_column_by_name("A") + assert col.size() == 5 + assert col.null_count == 1 + assert col.dtype[0] == DtypeKind.STRING + assert col.describe_null == (ColumnNullType.USE_BYTEMASK, 0) + + +def test_nonstring_object(): + df = pd.DataFrame({"A": ["a", 10, 1.0, ()]}) + col = df.__dataframe__().get_column_by_name("A") + with pytest.raises(NotImplementedError, match="not supported yet"): + col.dtype + + +def test_datetime(): + df = pd.DataFrame({"A": [pd.Timestamp("2022-01-01"), pd.NaT]}) + col = df.__dataframe__().get_column_by_name("A") + + assert col.size() == 2 + assert col.null_count == 1 + assert col.dtype[0] == DtypeKind.DATETIME + assert col.describe_null == (ColumnNullType.USE_SENTINEL, iNaT) + + tm.assert_frame_equal(df, from_dataframe(df.__dataframe__())) + + +@pytest.mark.skipif(np_version_lt1p23, reason="Numpy > 1.23 required") +def test_categorical_to_numpy_dlpack(): + # https://github.com/pandas-dev/pandas/issues/48393 + df = pd.DataFrame({"A": pd.Categorical(["a", "b", "a"])}) + col = df.__dataframe__().get_column_by_name("A") + result = np.from_dlpack(col.get_buffers()["data"][0]) + expected = np.array([0, 1, 0], dtype="int8") + tm.assert_numpy_array_equal(result, expected) + + +@pytest.mark.parametrize("data", [{}, {"a": []}]) +def test_empty_pyarrow(data): + # GH 53155 + pytest.importorskip("pyarrow", "11.0.0") + from pyarrow.interchange import from_dataframe as pa_from_dataframe + + expected = pd.DataFrame(data) + arrow_df = pa_from_dataframe(expected) + result = from_dataframe(arrow_df) + tm.assert_frame_equal(result, expected) + + +def test_multi_chunk_pyarrow() -> None: + pa = pytest.importorskip("pyarrow", "11.0.0") + n_legs = pa.chunked_array([[2, 2, 4], [4, 5, 100]]) + names = ["n_legs"] + table = pa.table([n_legs], names=names) + with pytest.raises( + RuntimeError, + match="To join chunks a copy is required which is " + "forbidden by allow_copy=False", + ): + pd.api.interchange.from_dataframe(table, allow_copy=False) + + +def test_multi_chunk_column() -> None: + pytest.importorskip("pyarrow", "11.0.0") + ser = pd.Series([1, 2, None], dtype="Int64[pyarrow]") + df = pd.concat([ser, ser], ignore_index=True).to_frame("a") + df_orig = df.copy() + with pytest.raises( + RuntimeError, match="Found multi-chunk pyarrow array, but `allow_copy` is False" + ): + pd.api.interchange.from_dataframe(df.__dataframe__(allow_copy=False)) + result = pd.api.interchange.from_dataframe(df.__dataframe__(allow_copy=True)) + # Interchange protocol defaults to creating numpy-backed columns, so currently this + # is 'float64'. + expected = pd.DataFrame({"a": [1.0, 2.0, None, 1.0, 2.0, None]}, dtype="float64") + tm.assert_frame_equal(result, expected) + + # Check that the rechunking we did didn't modify the original DataFrame. + tm.assert_frame_equal(df, df_orig) + assert len(df["a"].array._pa_array.chunks) == 2 + assert len(df_orig["a"].array._pa_array.chunks) == 2 + + +def test_timestamp_ns_pyarrow(): + # GH 56712 + pytest.importorskip("pyarrow", "11.0.0") + timestamp_args = { + "year": 2000, + "month": 1, + "day": 1, + "hour": 1, + "minute": 1, + "second": 1, + } + df = pd.Series( + [datetime(**timestamp_args)], + dtype="timestamp[ns][pyarrow]", + name="col0", + ).to_frame() + + dfi = df.__dataframe__() + result = pd.api.interchange.from_dataframe(dfi)["col0"].item() + + expected = pd.Timestamp(**timestamp_args) + assert result == expected + + +@pytest.mark.parametrize("tz", ["UTC", "US/Pacific"]) +@pytest.mark.parametrize("unit", ["s", "ms", "us", "ns"]) +def test_datetimetzdtype(tz, unit): + # GH 54239 + tz_data = ( + pd.date_range("2018-01-01", periods=5, freq="D").tz_localize(tz).as_unit(unit) + ) + df = pd.DataFrame({"ts_tz": tz_data}) + tm.assert_frame_equal(df, from_dataframe(df.__dataframe__())) + + +def test_interchange_from_non_pandas_tz_aware(request): + # GH 54239, 54287 + pa = pytest.importorskip("pyarrow", "11.0.0") + import pyarrow.compute as pc + + if is_platform_windows() and is_ci_environment(): + mark = pytest.mark.xfail( + raises=pa.ArrowInvalid, + reason=( + "TODO: Set ARROW_TIMEZONE_DATABASE environment variable " + "on CI to path to the tzdata for pyarrow." + ), + ) + request.applymarker(mark) + + arr = pa.array([datetime(2020, 1, 1), None, datetime(2020, 1, 2)]) + arr = pc.assume_timezone(arr, "Asia/Kathmandu") + table = pa.table({"arr": arr}) + exchange_df = table.__dataframe__() + result = from_dataframe(exchange_df) + + expected = pd.DataFrame( + ["2020-01-01 00:00:00+05:45", "NaT", "2020-01-02 00:00:00+05:45"], + columns=["arr"], + dtype="datetime64[us, Asia/Kathmandu]", + ) + tm.assert_frame_equal(expected, result) + + +def test_interchange_from_corrected_buffer_dtypes(monkeypatch) -> None: + # https://github.com/pandas-dev/pandas/issues/54781 + df = pd.DataFrame({"a": ["foo", "bar"]}).__dataframe__() + interchange = df.__dataframe__() + column = interchange.get_column_by_name("a") + buffers = column.get_buffers() + buffers_data = buffers["data"] + buffer_dtype = buffers_data[1] + buffer_dtype = ( + DtypeKind.UINT, + 8, + ArrowCTypes.UINT8, + buffer_dtype[3], + ) + buffers["data"] = (buffers_data[0], buffer_dtype) + column.get_buffers = lambda: buffers + interchange.get_column_by_name = lambda _: column + monkeypatch.setattr(df, "__dataframe__", lambda allow_copy: interchange) + pd.api.interchange.from_dataframe(df) + + +def test_empty_string_column(): + # https://github.com/pandas-dev/pandas/issues/56703 + df = pd.DataFrame({"a": []}, dtype=str) + df2 = df.__dataframe__() + result = pd.api.interchange.from_dataframe(df2) + tm.assert_frame_equal(df, result) + + +def test_large_string(): + # GH#56702 + pytest.importorskip("pyarrow") + df = pd.DataFrame({"a": ["x"]}, dtype="large_string[pyarrow]") + result = pd.api.interchange.from_dataframe(df.__dataframe__()) + expected = pd.DataFrame({"a": ["x"]}, dtype="object") + tm.assert_frame_equal(result, expected) + + +def test_non_str_names(): + # https://github.com/pandas-dev/pandas/issues/56701 + df = pd.Series([1, 2, 3], name=0).to_frame() + names = df.__dataframe__().column_names() + assert names == ["0"] + + +def test_non_str_names_w_duplicates(): + # https://github.com/pandas-dev/pandas/issues/56701 + df = pd.DataFrame({"0": [1, 2, 3], 0: [4, 5, 6]}) + dfi = df.__dataframe__() + with pytest.raises( + TypeError, + match=( + "Expected a Series, got a DataFrame. This likely happened because you " + "called __dataframe__ on a DataFrame which, after converting column " + r"names to string, resulted in duplicated names: Index\(\['0', '0'\], " + r"dtype='object'\). Please rename these columns before using the " + "interchange protocol." + ), + ): + pd.api.interchange.from_dataframe(dfi, allow_copy=False) + + +@pytest.mark.parametrize( + ("data", "dtype", "expected_dtype"), + [ + ([1, 2, None], "Int64", "int64"), + ([1, 2, None], "Int64[pyarrow]", "int64"), + ([1, 2, None], "Int8", "int8"), + ([1, 2, None], "Int8[pyarrow]", "int8"), + ( + [1, 2, None], + "UInt64", + "uint64", + ), + ( + [1, 2, None], + "UInt64[pyarrow]", + "uint64", + ), + ([1.0, 2.25, None], "Float32", "float32"), + ([1.0, 2.25, None], "Float32[pyarrow]", "float32"), + ([True, False, None], "boolean", "bool"), + ([True, False, None], "boolean[pyarrow]", "bool"), + (["much ado", "about", None], "string[pyarrow_numpy]", "large_string"), + (["much ado", "about", None], "string[pyarrow]", "large_string"), + ( + [datetime(2020, 1, 1), datetime(2020, 1, 2), None], + "timestamp[ns][pyarrow]", + "timestamp[ns]", + ), + ( + [datetime(2020, 1, 1), datetime(2020, 1, 2), None], + "timestamp[us][pyarrow]", + "timestamp[us]", + ), + ( + [ + datetime(2020, 1, 1, tzinfo=timezone.utc), + datetime(2020, 1, 2, tzinfo=timezone.utc), + None, + ], + "timestamp[us, Asia/Kathmandu][pyarrow]", + "timestamp[us, tz=Asia/Kathmandu]", + ), + ], +) +def test_pandas_nullable_with_missing_values( + data: list, dtype: str, expected_dtype: str +) -> None: + # https://github.com/pandas-dev/pandas/issues/57643 + # https://github.com/pandas-dev/pandas/issues/57664 + pa = pytest.importorskip("pyarrow", "11.0.0") + import pyarrow.interchange as pai + + if expected_dtype == "timestamp[us, tz=Asia/Kathmandu]": + expected_dtype = pa.timestamp("us", "Asia/Kathmandu") + + df = pd.DataFrame({"a": data}, dtype=dtype) + result = pai.from_dataframe(df.__dataframe__())["a"] + assert result.type == expected_dtype + assert result[0].as_py() == data[0] + assert result[1].as_py() == data[1] + assert result[2].as_py() is None + + +@pytest.mark.parametrize( + ("data", "dtype", "expected_dtype"), + [ + ([1, 2, 3], "Int64", "int64"), + ([1, 2, 3], "Int64[pyarrow]", "int64"), + ([1, 2, 3], "Int8", "int8"), + ([1, 2, 3], "Int8[pyarrow]", "int8"), + ( + [1, 2, 3], + "UInt64", + "uint64", + ), + ( + [1, 2, 3], + "UInt64[pyarrow]", + "uint64", + ), + ([1.0, 2.25, 5.0], "Float32", "float32"), + ([1.0, 2.25, 5.0], "Float32[pyarrow]", "float32"), + ([True, False, False], "boolean", "bool"), + ([True, False, False], "boolean[pyarrow]", "bool"), + (["much ado", "about", "nothing"], "string[pyarrow_numpy]", "large_string"), + (["much ado", "about", "nothing"], "string[pyarrow]", "large_string"), + ( + [datetime(2020, 1, 1), datetime(2020, 1, 2), datetime(2020, 1, 3)], + "timestamp[ns][pyarrow]", + "timestamp[ns]", + ), + ( + [datetime(2020, 1, 1), datetime(2020, 1, 2), datetime(2020, 1, 3)], + "timestamp[us][pyarrow]", + "timestamp[us]", + ), + ( + [ + datetime(2020, 1, 1, tzinfo=timezone.utc), + datetime(2020, 1, 2, tzinfo=timezone.utc), + datetime(2020, 1, 3, tzinfo=timezone.utc), + ], + "timestamp[us, Asia/Kathmandu][pyarrow]", + "timestamp[us, tz=Asia/Kathmandu]", + ), + ], +) +def test_pandas_nullable_without_missing_values( + data: list, dtype: str, expected_dtype: str +) -> None: + # https://github.com/pandas-dev/pandas/issues/57643 + pa = pytest.importorskip("pyarrow", "11.0.0") + import pyarrow.interchange as pai + + if expected_dtype == "timestamp[us, tz=Asia/Kathmandu]": + expected_dtype = pa.timestamp("us", "Asia/Kathmandu") + + df = pd.DataFrame({"a": data}, dtype=dtype) + result = pai.from_dataframe(df.__dataframe__())["a"] + assert result.type == expected_dtype + assert result[0].as_py() == data[0] + assert result[1].as_py() == data[1] + assert result[2].as_py() == data[2] + + +def test_string_validity_buffer() -> None: + # https://github.com/pandas-dev/pandas/issues/57761 + pytest.importorskip("pyarrow", "11.0.0") + df = pd.DataFrame({"a": ["x"]}, dtype="large_string[pyarrow]") + result = df.__dataframe__().get_column_by_name("a").get_buffers()["validity"] + assert result is None + + +def test_string_validity_buffer_no_missing() -> None: + # https://github.com/pandas-dev/pandas/issues/57762 + pytest.importorskip("pyarrow", "11.0.0") + df = pd.DataFrame({"a": ["x", None]}, dtype="large_string[pyarrow]") + validity = df.__dataframe__().get_column_by_name("a").get_buffers()["validity"] + assert validity is not None + result = validity[1] + expected = (DtypeKind.BOOL, 1, ArrowCTypes.BOOL, "=") + assert result == expected + + +def test_empty_dataframe(): + # https://github.com/pandas-dev/pandas/issues/56700 + df = pd.DataFrame({"a": []}, dtype="int8") + dfi = df.__dataframe__() + result = pd.api.interchange.from_dataframe(dfi, allow_copy=False) + expected = pd.DataFrame({"a": []}, dtype="int8") + tm.assert_frame_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/interchange/test_spec_conformance.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/interchange/test_spec_conformance.py new file mode 100644 index 0000000000000000000000000000000000000000..7c02379c118539032cb79d682d4baa2c7ae1fb81 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/interchange/test_spec_conformance.py @@ -0,0 +1,175 @@ +""" +A verbatim copy (vendored) of the spec tests. +Taken from https://github.com/data-apis/dataframe-api +""" +import ctypes +import math + +import pytest + +import pandas as pd + + +@pytest.fixture +def df_from_dict(): + def maker(dct, is_categorical=False): + df = pd.DataFrame(dct) + return df.astype("category") if is_categorical else df + + return maker + + +@pytest.mark.parametrize( + "test_data", + [ + {"a": ["foo", "bar"], "b": ["baz", "qux"]}, + {"a": [1.5, 2.5, 3.5], "b": [9.2, 10.5, 11.8]}, + {"A": [1, 2, 3, 4], "B": [1, 2, 3, 4]}, + ], + ids=["str_data", "float_data", "int_data"], +) +def test_only_one_dtype(test_data, df_from_dict): + columns = list(test_data.keys()) + df = df_from_dict(test_data) + dfX = df.__dataframe__() + + column_size = len(test_data[columns[0]]) + for column in columns: + null_count = dfX.get_column_by_name(column).null_count + assert null_count == 0 + assert isinstance(null_count, int) + assert dfX.get_column_by_name(column).size() == column_size + assert dfX.get_column_by_name(column).offset == 0 + + +def test_mixed_dtypes(df_from_dict): + df = df_from_dict( + { + "a": [1, 2, 3], # dtype kind INT = 0 + "b": [3, 4, 5], # dtype kind INT = 0 + "c": [1.5, 2.5, 3.5], # dtype kind FLOAT = 2 + "d": [9, 10, 11], # dtype kind INT = 0 + "e": [True, False, True], # dtype kind BOOLEAN = 20 + "f": ["a", "", "c"], # dtype kind STRING = 21 + } + ) + dfX = df.__dataframe__() + # for meanings of dtype[0] see the spec; we cannot import the spec here as this + # file is expected to be vendored *anywhere*; + # values for dtype[0] are explained above + columns = {"a": 0, "b": 0, "c": 2, "d": 0, "e": 20, "f": 21} + + for column, kind in columns.items(): + colX = dfX.get_column_by_name(column) + assert colX.null_count == 0 + assert isinstance(colX.null_count, int) + assert colX.size() == 3 + assert colX.offset == 0 + + assert colX.dtype[0] == kind + + assert dfX.get_column_by_name("c").dtype[1] == 64 + + +def test_na_float(df_from_dict): + df = df_from_dict({"a": [1.0, math.nan, 2.0]}) + dfX = df.__dataframe__() + colX = dfX.get_column_by_name("a") + assert colX.null_count == 1 + assert isinstance(colX.null_count, int) + + +def test_noncategorical(df_from_dict): + df = df_from_dict({"a": [1, 2, 3]}) + dfX = df.__dataframe__() + colX = dfX.get_column_by_name("a") + with pytest.raises(TypeError, match=".*categorical.*"): + colX.describe_categorical + + +def test_categorical(df_from_dict): + df = df_from_dict( + {"weekday": ["Mon", "Tue", "Mon", "Wed", "Mon", "Thu", "Fri", "Sat", "Sun"]}, + is_categorical=True, + ) + + colX = df.__dataframe__().get_column_by_name("weekday") + categorical = colX.describe_categorical + assert isinstance(categorical["is_ordered"], bool) + assert isinstance(categorical["is_dictionary"], bool) + + +def test_dataframe(df_from_dict): + df = df_from_dict( + {"x": [True, True, False], "y": [1, 2, 0], "z": [9.2, 10.5, 11.8]} + ) + dfX = df.__dataframe__() + + assert dfX.num_columns() == 3 + assert dfX.num_rows() == 3 + assert dfX.num_chunks() == 1 + assert list(dfX.column_names()) == ["x", "y", "z"] + assert list(dfX.select_columns((0, 2)).column_names()) == list( + dfX.select_columns_by_name(("x", "z")).column_names() + ) + + +@pytest.mark.parametrize(["size", "n_chunks"], [(10, 3), (12, 3), (12, 5)]) +def test_df_get_chunks(size, n_chunks, df_from_dict): + df = df_from_dict({"x": list(range(size))}) + dfX = df.__dataframe__() + chunks = list(dfX.get_chunks(n_chunks)) + assert len(chunks) == n_chunks + assert sum(chunk.num_rows() for chunk in chunks) == size + + +@pytest.mark.parametrize(["size", "n_chunks"], [(10, 3), (12, 3), (12, 5)]) +def test_column_get_chunks(size, n_chunks, df_from_dict): + df = df_from_dict({"x": list(range(size))}) + dfX = df.__dataframe__() + chunks = list(dfX.get_column(0).get_chunks(n_chunks)) + assert len(chunks) == n_chunks + assert sum(chunk.size() for chunk in chunks) == size + + +def test_get_columns(df_from_dict): + df = df_from_dict({"a": [0, 1], "b": [2.5, 3.5]}) + dfX = df.__dataframe__() + for colX in dfX.get_columns(): + assert colX.size() == 2 + assert colX.num_chunks() == 1 + # for meanings of dtype[0] see the spec; we cannot import the spec here as this + # file is expected to be vendored *anywhere* + assert dfX.get_column(0).dtype[0] == 0 # INT + assert dfX.get_column(1).dtype[0] == 2 # FLOAT + + +def test_buffer(df_from_dict): + arr = [0, 1, -1] + df = df_from_dict({"a": arr}) + dfX = df.__dataframe__() + colX = dfX.get_column(0) + bufX = colX.get_buffers() + + dataBuf, dataDtype = bufX["data"] + + assert dataBuf.bufsize > 0 + assert dataBuf.ptr != 0 + device, _ = dataBuf.__dlpack_device__() + + # for meanings of dtype[0] see the spec; we cannot import the spec here as this + # file is expected to be vendored *anywhere* + assert dataDtype[0] == 0 # INT + + if device == 1: # CPU-only as we're going to directly read memory here + bitwidth = dataDtype[1] + ctype = { + 8: ctypes.c_int8, + 16: ctypes.c_int16, + 32: ctypes.c_int32, + 64: ctypes.c_int64, + }[bitwidth] + + for idx, truth in enumerate(arr): + val = ctype.from_address(dataBuf.ptr + idx * (bitwidth // 8)).value + assert val == truth, f"Buffer at index {idx} mismatch" diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/interchange/test_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/interchange/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a47bc2752ff32f5eb7630a3960e7611242cb73e3 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/interchange/test_utils.py @@ -0,0 +1,89 @@ +import numpy as np +import pytest + +import pandas as pd +from pandas.core.interchange.utils import dtype_to_arrow_c_fmt + +# TODO: use ArrowSchema to get reference C-string. +# At the time, there is no way to access ArrowSchema holding a type format string +# from python. The only way to access it is to export the structure to a C-pointer, +# see DataType._export_to_c() method defined in +# https://github.com/apache/arrow/blob/master/python/pyarrow/types.pxi + + +@pytest.mark.parametrize( + "pandas_dtype, c_string", + [ + (np.dtype("bool"), "b"), + (np.dtype("int8"), "c"), + (np.dtype("uint8"), "C"), + (np.dtype("int16"), "s"), + (np.dtype("uint16"), "S"), + (np.dtype("int32"), "i"), + (np.dtype("uint32"), "I"), + (np.dtype("int64"), "l"), + (np.dtype("uint64"), "L"), + (np.dtype("float16"), "e"), + (np.dtype("float32"), "f"), + (np.dtype("float64"), "g"), + (pd.Series(["a"]).dtype, "u"), + ( + pd.Series([0]).astype("datetime64[ns]").dtype, + "tsn:", + ), + (pd.CategoricalDtype(["a"]), "l"), + (np.dtype("O"), "u"), + ], +) +def test_dtype_to_arrow_c_fmt(pandas_dtype, c_string): # PR01 + """Test ``dtype_to_arrow_c_fmt`` utility function.""" + assert dtype_to_arrow_c_fmt(pandas_dtype) == c_string + + +@pytest.mark.parametrize( + "pa_dtype, args_kwargs, c_string", + [ + ["null", {}, "n"], + ["bool_", {}, "b"], + ["uint8", {}, "C"], + ["uint16", {}, "S"], + ["uint32", {}, "I"], + ["uint64", {}, "L"], + ["int8", {}, "c"], + ["int16", {}, "S"], + ["int32", {}, "i"], + ["int64", {}, "l"], + ["float16", {}, "e"], + ["float32", {}, "f"], + ["float64", {}, "g"], + ["string", {}, "u"], + ["binary", {}, "z"], + ["time32", ("s",), "tts"], + ["time32", ("ms",), "ttm"], + ["time64", ("us",), "ttu"], + ["time64", ("ns",), "ttn"], + ["date32", {}, "tdD"], + ["date64", {}, "tdm"], + ["timestamp", {"unit": "s"}, "tss:"], + ["timestamp", {"unit": "ms"}, "tsm:"], + ["timestamp", {"unit": "us"}, "tsu:"], + ["timestamp", {"unit": "ns"}, "tsn:"], + ["timestamp", {"unit": "ns", "tz": "UTC"}, "tsn:UTC"], + ["duration", ("s",), "tDs"], + ["duration", ("ms",), "tDm"], + ["duration", ("us",), "tDu"], + ["duration", ("ns",), "tDn"], + ["decimal128", {"precision": 4, "scale": 2}, "d:4,2"], + ], +) +def test_dtype_to_arrow_c_fmt_arrowdtype(pa_dtype, args_kwargs, c_string): + # GH 52323 + pa = pytest.importorskip("pyarrow") + if not args_kwargs: + pa_type = getattr(pa, pa_dtype)() + elif isinstance(args_kwargs, tuple): + pa_type = getattr(pa, pa_dtype)(*args_kwargs) + else: + pa_type = getattr(pa, pa_dtype)(**args_kwargs) + arrow_type = pd.ArrowDtype(pa_type) + assert dtype_to_arrow_c_fmt(arrow_type) == c_string diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/internals/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/internals/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/internals/test_api.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/internals/test_api.py new file mode 100644 index 0000000000000000000000000000000000000000..1251a6ae97a1cb9304de036dba252de54e7fb10b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/internals/test_api.py @@ -0,0 +1,86 @@ +""" +Tests for the pseudo-public API implemented in internals/api.py and exposed +in core.internals +""" + +import pytest + +import pandas as pd +import pandas._testing as tm +from pandas.core import internals +from pandas.core.internals import api + + +def test_internals_api(): + assert internals.make_block is api.make_block + + +def test_namespace(): + # SUBJECT TO CHANGE + + modules = [ + "blocks", + "concat", + "managers", + "construction", + "array_manager", + "base", + "api", + "ops", + ] + expected = [ + "make_block", + "DataManager", + "ArrayManager", + "BlockManager", + "SingleDataManager", + "SingleBlockManager", + "SingleArrayManager", + "concatenate_managers", + ] + + result = [x for x in dir(internals) if not x.startswith("__")] + assert set(result) == set(expected + modules) + + +@pytest.mark.parametrize( + "name", + [ + "NumericBlock", + "ObjectBlock", + "Block", + "ExtensionBlock", + "DatetimeTZBlock", + ], +) +def test_deprecations(name): + # GH#55139 + msg = f"{name} is deprecated.* Use public APIs instead" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + getattr(internals, name) + + if name not in ["NumericBlock", "ObjectBlock"]: + # NumericBlock and ObjectBlock are not in the internals.api namespace + with tm.assert_produces_warning(DeprecationWarning, match=msg): + getattr(api, name) + + +def test_make_block_2d_with_dti(): + # GH#41168 + dti = pd.date_range("2012", periods=3, tz="UTC") + blk = api.make_block(dti, placement=[0]) + + assert blk.shape == (1, 3) + assert blk.values.shape == (1, 3) + + +def test_create_block_manager_from_blocks_deprecated(): + # GH#33892 + # If they must, downstream packages should get this from internals.api, + # not internals. + msg = ( + "create_block_manager_from_blocks is deprecated and will be " + "removed in a future version. Use public APIs instead" + ) + with tm.assert_produces_warning(DeprecationWarning, match=msg): + internals.create_block_manager_from_blocks diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/internals/test_internals.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/internals/test_internals.py new file mode 100644 index 0000000000000000000000000000000000000000..ce88bae6e02f2892d7c8e4ee8f6315904bbbd65a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/internals/test_internals.py @@ -0,0 +1,1422 @@ +from datetime import ( + date, + datetime, +) +import itertools +import re + +import numpy as np +import pytest + +from pandas._libs.internals import BlockPlacement +from pandas.compat import IS64 +import pandas.util._test_decorators as td + +from pandas.core.dtypes.common import is_scalar + +import pandas as pd +from pandas import ( + Categorical, + DataFrame, + DatetimeIndex, + Index, + IntervalIndex, + Series, + Timedelta, + Timestamp, + period_range, +) +import pandas._testing as tm +import pandas.core.algorithms as algos +from pandas.core.arrays import ( + DatetimeArray, + SparseArray, + TimedeltaArray, +) +from pandas.core.internals import ( + BlockManager, + SingleBlockManager, + make_block, +) +from pandas.core.internals.blocks import ( + ensure_block_shape, + maybe_coerce_values, + new_block, +) + +# this file contains BlockManager specific tests +# TODO(ArrayManager) factor out interleave_dtype tests +pytestmark = td.skip_array_manager_invalid_test + + +@pytest.fixture(params=[new_block, make_block]) +def block_maker(request): + """ + Fixture to test both the internal new_block and pseudo-public make_block. + """ + return request.param + + +@pytest.fixture +def mgr(): + return create_mgr( + "a: f8; b: object; c: f8; d: object; e: f8;" + "f: bool; g: i8; h: complex; i: datetime-1; j: datetime-2;" + "k: M8[ns, US/Eastern]; l: M8[ns, CET];" + ) + + +def assert_block_equal(left, right): + tm.assert_numpy_array_equal(left.values, right.values) + assert left.dtype == right.dtype + assert isinstance(left.mgr_locs, BlockPlacement) + assert isinstance(right.mgr_locs, BlockPlacement) + tm.assert_numpy_array_equal(left.mgr_locs.as_array, right.mgr_locs.as_array) + + +def get_numeric_mat(shape): + arr = np.arange(shape[0]) + return np.lib.stride_tricks.as_strided( + x=arr, shape=shape, strides=(arr.itemsize,) + (0,) * (len(shape) - 1) + ).copy() + + +N = 10 + + +def create_block(typestr, placement, item_shape=None, num_offset=0, maker=new_block): + """ + Supported typestr: + + * float, f8, f4, f2 + * int, i8, i4, i2, i1 + * uint, u8, u4, u2, u1 + * complex, c16, c8 + * bool + * object, string, O + * datetime, dt, M8[ns], M8[ns, tz] + * timedelta, td, m8[ns] + * sparse (SparseArray with fill_value=0.0) + * sparse_na (SparseArray with fill_value=np.nan) + * category, category2 + + """ + placement = BlockPlacement(placement) + num_items = len(placement) + + if item_shape is None: + item_shape = (N,) + + shape = (num_items,) + item_shape + + mat = get_numeric_mat(shape) + + if typestr in ( + "float", + "f8", + "f4", + "f2", + "int", + "i8", + "i4", + "i2", + "i1", + "uint", + "u8", + "u4", + "u2", + "u1", + ): + values = mat.astype(typestr) + num_offset + elif typestr in ("complex", "c16", "c8"): + values = 1.0j * (mat.astype(typestr) + num_offset) + elif typestr in ("object", "string", "O"): + values = np.reshape([f"A{i:d}" for i in mat.ravel() + num_offset], shape) + elif typestr in ("b", "bool"): + values = np.ones(shape, dtype=np.bool_) + elif typestr in ("datetime", "dt", "M8[ns]"): + values = (mat * 1e9).astype("M8[ns]") + elif typestr.startswith("M8[ns"): + # datetime with tz + m = re.search(r"M8\[ns,\s*(\w+\/?\w*)\]", typestr) + assert m is not None, f"incompatible typestr -> {typestr}" + tz = m.groups()[0] + assert num_items == 1, "must have only 1 num items for a tz-aware" + values = DatetimeIndex(np.arange(N) * 10**9, tz=tz)._data + values = ensure_block_shape(values, ndim=len(shape)) + elif typestr in ("timedelta", "td", "m8[ns]"): + values = (mat * 1).astype("m8[ns]") + elif typestr in ("category",): + values = Categorical([1, 1, 2, 2, 3, 3, 3, 3, 4, 4]) + elif typestr in ("category2",): + values = Categorical(["a", "a", "a", "a", "b", "b", "c", "c", "c", "d"]) + elif typestr in ("sparse", "sparse_na"): + if shape[-1] != 10: + # We also are implicitly assuming this in the category cases above + raise NotImplementedError + + assert all(s == 1 for s in shape[:-1]) + if typestr.endswith("_na"): + fill_value = np.nan + else: + fill_value = 0.0 + values = SparseArray( + [fill_value, fill_value, 1, 2, 3, fill_value, 4, 5, fill_value, 6], + fill_value=fill_value, + ) + arr = values.sp_values.view() + arr += num_offset - 1 + else: + raise ValueError(f'Unsupported typestr: "{typestr}"') + + values = maybe_coerce_values(values) + return maker(values, placement=placement, ndim=len(shape)) + + +def create_single_mgr(typestr, num_rows=None): + if num_rows is None: + num_rows = N + + return SingleBlockManager( + create_block(typestr, placement=slice(0, num_rows), item_shape=()), + Index(np.arange(num_rows)), + ) + + +def create_mgr(descr, item_shape=None): + """ + Construct BlockManager from string description. + + String description syntax looks similar to np.matrix initializer. It looks + like this:: + + a,b,c: f8; d,e,f: i8 + + Rules are rather simple: + + * see list of supported datatypes in `create_block` method + * components are semicolon-separated + * each component is `NAME,NAME,NAME: DTYPE_ID` + * whitespace around colons & semicolons are removed + * components with same DTYPE_ID are combined into single block + * to force multiple blocks with same dtype, use '-SUFFIX':: + + 'a:f8-1; b:f8-2; c:f8-foobar' + + """ + if item_shape is None: + item_shape = (N,) + + offset = 0 + mgr_items = [] + block_placements = {} + for d in descr.split(";"): + d = d.strip() + if not len(d): + continue + names, blockstr = d.partition(":")[::2] + blockstr = blockstr.strip() + names = names.strip().split(",") + + mgr_items.extend(names) + placement = list(np.arange(len(names)) + offset) + try: + block_placements[blockstr].extend(placement) + except KeyError: + block_placements[blockstr] = placement + offset += len(names) + + mgr_items = Index(mgr_items) + + blocks = [] + num_offset = 0 + for blockstr, placement in block_placements.items(): + typestr = blockstr.split("-")[0] + blocks.append( + create_block( + typestr, placement, item_shape=item_shape, num_offset=num_offset + ) + ) + num_offset += len(placement) + + sblocks = sorted(blocks, key=lambda b: b.mgr_locs[0]) + return BlockManager( + tuple(sblocks), + [mgr_items] + [Index(np.arange(n)) for n in item_shape], + ) + + +@pytest.fixture +def fblock(): + return create_block("float", [0, 2, 4]) + + +class TestBlock: + def test_constructor(self): + int32block = create_block("i4", [0]) + assert int32block.dtype == np.int32 + + @pytest.mark.parametrize( + "typ, data", + [ + ["float", [0, 2, 4]], + ["complex", [7]], + ["object", [1, 3]], + ["bool", [5]], + ], + ) + def test_pickle(self, typ, data): + blk = create_block(typ, data) + assert_block_equal(tm.round_trip_pickle(blk), blk) + + def test_mgr_locs(self, fblock): + assert isinstance(fblock.mgr_locs, BlockPlacement) + tm.assert_numpy_array_equal( + fblock.mgr_locs.as_array, np.array([0, 2, 4], dtype=np.intp) + ) + + def test_attrs(self, fblock): + assert fblock.shape == fblock.values.shape + assert fblock.dtype == fblock.values.dtype + assert len(fblock) == len(fblock.values) + + def test_copy(self, fblock): + cop = fblock.copy() + assert cop is not fblock + assert_block_equal(fblock, cop) + + def test_delete(self, fblock): + newb = fblock.copy() + locs = newb.mgr_locs + nb = newb.delete(0)[0] + assert newb.mgr_locs is locs + + assert nb is not newb + + tm.assert_numpy_array_equal( + nb.mgr_locs.as_array, np.array([2, 4], dtype=np.intp) + ) + assert not (newb.values[0] == 1).all() + assert (nb.values[0] == 1).all() + + newb = fblock.copy() + locs = newb.mgr_locs + nb = newb.delete(1) + assert len(nb) == 2 + assert newb.mgr_locs is locs + + tm.assert_numpy_array_equal( + nb[0].mgr_locs.as_array, np.array([0], dtype=np.intp) + ) + tm.assert_numpy_array_equal( + nb[1].mgr_locs.as_array, np.array([4], dtype=np.intp) + ) + assert not (newb.values[1] == 2).all() + assert (nb[1].values[0] == 2).all() + + newb = fblock.copy() + nb = newb.delete(2) + assert len(nb) == 1 + tm.assert_numpy_array_equal( + nb[0].mgr_locs.as_array, np.array([0, 2], dtype=np.intp) + ) + assert (nb[0].values[1] == 1).all() + + newb = fblock.copy() + + with pytest.raises(IndexError, match=None): + newb.delete(3) + + def test_delete_datetimelike(self): + # dont use np.delete on values, as that will coerce from DTA/TDA to ndarray + arr = np.arange(20, dtype="i8").reshape(5, 4).view("m8[ns]") + df = DataFrame(arr) + blk = df._mgr.blocks[0] + assert isinstance(blk.values, TimedeltaArray) + + nb = blk.delete(1) + assert len(nb) == 2 + assert isinstance(nb[0].values, TimedeltaArray) + assert isinstance(nb[1].values, TimedeltaArray) + + df = DataFrame(arr.view("M8[ns]")) + blk = df._mgr.blocks[0] + assert isinstance(blk.values, DatetimeArray) + + nb = blk.delete([1, 3]) + assert len(nb) == 2 + assert isinstance(nb[0].values, DatetimeArray) + assert isinstance(nb[1].values, DatetimeArray) + + def test_split(self): + # GH#37799 + values = np.random.default_rng(2).standard_normal((3, 4)) + blk = new_block(values, placement=BlockPlacement([3, 1, 6]), ndim=2) + result = blk._split() + + # check that we get views, not copies + values[:] = -9999 + assert (blk.values == -9999).all() + + assert len(result) == 3 + expected = [ + new_block(values[[0]], placement=BlockPlacement([3]), ndim=2), + new_block(values[[1]], placement=BlockPlacement([1]), ndim=2), + new_block(values[[2]], placement=BlockPlacement([6]), ndim=2), + ] + for res, exp in zip(result, expected): + assert_block_equal(res, exp) + + +class TestBlockManager: + def test_attrs(self): + mgr = create_mgr("a,b,c: f8-1; d,e,f: f8-2") + assert mgr.nblocks == 2 + assert len(mgr) == 6 + + def test_duplicate_ref_loc_failure(self): + tmp_mgr = create_mgr("a:bool; a: f8") + + axes, blocks = tmp_mgr.axes, tmp_mgr.blocks + + blocks[0].mgr_locs = BlockPlacement(np.array([0])) + blocks[1].mgr_locs = BlockPlacement(np.array([0])) + + # test trying to create block manager with overlapping ref locs + + msg = "Gaps in blk ref_locs" + + with pytest.raises(AssertionError, match=msg): + mgr = BlockManager(blocks, axes) + mgr._rebuild_blknos_and_blklocs() + + blocks[0].mgr_locs = BlockPlacement(np.array([0])) + blocks[1].mgr_locs = BlockPlacement(np.array([1])) + mgr = BlockManager(blocks, axes) + mgr.iget(1) + + def test_pickle(self, mgr): + mgr2 = tm.round_trip_pickle(mgr) + tm.assert_frame_equal( + DataFrame._from_mgr(mgr, axes=mgr.axes), + DataFrame._from_mgr(mgr2, axes=mgr2.axes), + ) + + # GH2431 + assert hasattr(mgr2, "_is_consolidated") + assert hasattr(mgr2, "_known_consolidated") + + # reset to False on load + assert not mgr2._is_consolidated + assert not mgr2._known_consolidated + + @pytest.mark.parametrize("mgr_string", ["a,a,a:f8", "a: f8; a: i8"]) + def test_non_unique_pickle(self, mgr_string): + mgr = create_mgr(mgr_string) + mgr2 = tm.round_trip_pickle(mgr) + tm.assert_frame_equal( + DataFrame._from_mgr(mgr, axes=mgr.axes), + DataFrame._from_mgr(mgr2, axes=mgr2.axes), + ) + + def test_categorical_block_pickle(self): + mgr = create_mgr("a: category") + mgr2 = tm.round_trip_pickle(mgr) + tm.assert_frame_equal( + DataFrame._from_mgr(mgr, axes=mgr.axes), + DataFrame._from_mgr(mgr2, axes=mgr2.axes), + ) + + smgr = create_single_mgr("category") + smgr2 = tm.round_trip_pickle(smgr) + tm.assert_series_equal( + Series()._constructor_from_mgr(smgr, axes=smgr.axes), + Series()._constructor_from_mgr(smgr2, axes=smgr2.axes), + ) + + def test_iget(self): + cols = Index(list("abc")) + values = np.random.default_rng(2).random((3, 3)) + block = new_block( + values=values.copy(), + placement=BlockPlacement(np.arange(3, dtype=np.intp)), + ndim=values.ndim, + ) + mgr = BlockManager(blocks=(block,), axes=[cols, Index(np.arange(3))]) + + tm.assert_almost_equal(mgr.iget(0).internal_values(), values[0]) + tm.assert_almost_equal(mgr.iget(1).internal_values(), values[1]) + tm.assert_almost_equal(mgr.iget(2).internal_values(), values[2]) + + def test_set(self): + mgr = create_mgr("a,b,c: int", item_shape=(3,)) + + mgr.insert(len(mgr.items), "d", np.array(["foo"] * 3)) + mgr.iset(1, np.array(["bar"] * 3)) + tm.assert_numpy_array_equal(mgr.iget(0).internal_values(), np.array([0] * 3)) + tm.assert_numpy_array_equal( + mgr.iget(1).internal_values(), np.array(["bar"] * 3, dtype=np.object_) + ) + tm.assert_numpy_array_equal(mgr.iget(2).internal_values(), np.array([2] * 3)) + tm.assert_numpy_array_equal( + mgr.iget(3).internal_values(), np.array(["foo"] * 3, dtype=np.object_) + ) + + def test_set_change_dtype(self, mgr): + mgr.insert(len(mgr.items), "baz", np.zeros(N, dtype=bool)) + + mgr.iset(mgr.items.get_loc("baz"), np.repeat("foo", N)) + idx = mgr.items.get_loc("baz") + assert mgr.iget(idx).dtype == np.object_ + + mgr2 = mgr.consolidate() + mgr2.iset(mgr2.items.get_loc("baz"), np.repeat("foo", N)) + idx = mgr2.items.get_loc("baz") + assert mgr2.iget(idx).dtype == np.object_ + + mgr2.insert( + len(mgr2.items), + "quux", + np.random.default_rng(2).standard_normal(N).astype(int), + ) + idx = mgr2.items.get_loc("quux") + assert mgr2.iget(idx).dtype == np.dtype(int) + + mgr2.iset( + mgr2.items.get_loc("quux"), np.random.default_rng(2).standard_normal(N) + ) + assert mgr2.iget(idx).dtype == np.float64 + + def test_copy(self, mgr): + cp = mgr.copy(deep=False) + for blk, cp_blk in zip(mgr.blocks, cp.blocks): + # view assertion + tm.assert_equal(cp_blk.values, blk.values) + if isinstance(blk.values, np.ndarray): + assert cp_blk.values.base is blk.values.base + else: + # DatetimeTZBlock has DatetimeIndex values + assert cp_blk.values._ndarray.base is blk.values._ndarray.base + + # copy(deep=True) consolidates, so the block-wise assertions will + # fail is mgr is not consolidated + mgr._consolidate_inplace() + cp = mgr.copy(deep=True) + for blk, cp_blk in zip(mgr.blocks, cp.blocks): + bvals = blk.values + cpvals = cp_blk.values + + tm.assert_equal(cpvals, bvals) + + if isinstance(cpvals, np.ndarray): + lbase = cpvals.base + rbase = bvals.base + else: + lbase = cpvals._ndarray.base + rbase = bvals._ndarray.base + + # copy assertion we either have a None for a base or in case of + # some blocks it is an array (e.g. datetimetz), but was copied + if isinstance(cpvals, DatetimeArray): + assert (lbase is None and rbase is None) or (lbase is not rbase) + elif not isinstance(cpvals, np.ndarray): + assert lbase is not rbase + else: + assert lbase is None and rbase is None + + def test_sparse(self): + mgr = create_mgr("a: sparse-1; b: sparse-2") + assert mgr.as_array().dtype == np.float64 + + def test_sparse_mixed(self): + mgr = create_mgr("a: sparse-1; b: sparse-2; c: f8") + assert len(mgr.blocks) == 3 + assert isinstance(mgr, BlockManager) + + @pytest.mark.parametrize( + "mgr_string, dtype", + [("c: f4; d: f2", np.float32), ("c: f4; d: f2; e: f8", np.float64)], + ) + def test_as_array_float(self, mgr_string, dtype): + mgr = create_mgr(mgr_string) + assert mgr.as_array().dtype == dtype + + @pytest.mark.parametrize( + "mgr_string, dtype", + [ + ("a: bool-1; b: bool-2", np.bool_), + ("a: i8-1; b: i8-2; c: i4; d: i2; e: u1", np.int64), + ("c: i4; d: i2; e: u1", np.int32), + ], + ) + def test_as_array_int_bool(self, mgr_string, dtype): + mgr = create_mgr(mgr_string) + assert mgr.as_array().dtype == dtype + + def test_as_array_datetime(self): + mgr = create_mgr("h: datetime-1; g: datetime-2") + assert mgr.as_array().dtype == "M8[ns]" + + def test_as_array_datetime_tz(self): + mgr = create_mgr("h: M8[ns, US/Eastern]; g: M8[ns, CET]") + assert mgr.iget(0).dtype == "datetime64[ns, US/Eastern]" + assert mgr.iget(1).dtype == "datetime64[ns, CET]" + assert mgr.as_array().dtype == "object" + + @pytest.mark.parametrize("t", ["float16", "float32", "float64", "int32", "int64"]) + def test_astype(self, t): + # coerce all + mgr = create_mgr("c: f4; d: f2; e: f8") + + t = np.dtype(t) + tmgr = mgr.astype(t) + assert tmgr.iget(0).dtype.type == t + assert tmgr.iget(1).dtype.type == t + assert tmgr.iget(2).dtype.type == t + + # mixed + mgr = create_mgr("a,b: object; c: bool; d: datetime; e: f4; f: f2; g: f8") + + t = np.dtype(t) + tmgr = mgr.astype(t, errors="ignore") + assert tmgr.iget(2).dtype.type == t + assert tmgr.iget(4).dtype.type == t + assert tmgr.iget(5).dtype.type == t + assert tmgr.iget(6).dtype.type == t + + assert tmgr.iget(0).dtype.type == np.object_ + assert tmgr.iget(1).dtype.type == np.object_ + if t != np.int64: + assert tmgr.iget(3).dtype.type == np.datetime64 + else: + assert tmgr.iget(3).dtype.type == t + + def test_convert(self, using_infer_string): + def _compare(old_mgr, new_mgr): + """compare the blocks, numeric compare ==, object don't""" + old_blocks = set(old_mgr.blocks) + new_blocks = set(new_mgr.blocks) + assert len(old_blocks) == len(new_blocks) + + # compare non-numeric + for b in old_blocks: + found = False + for nb in new_blocks: + if (b.values == nb.values).all(): + found = True + break + assert found + + for b in new_blocks: + found = False + for ob in old_blocks: + if (b.values == ob.values).all(): + found = True + break + assert found + + # noops + mgr = create_mgr("f: i8; g: f8") + new_mgr = mgr.convert(copy=True) + _compare(mgr, new_mgr) + + # convert + mgr = create_mgr("a,b,foo: object; f: i8; g: f8") + mgr.iset(0, np.array(["1"] * N, dtype=np.object_)) + mgr.iset(1, np.array(["2."] * N, dtype=np.object_)) + mgr.iset(2, np.array(["foo."] * N, dtype=np.object_)) + new_mgr = mgr.convert(copy=True) + dtype = "string[pyarrow_numpy]" if using_infer_string else np.object_ + assert new_mgr.iget(0).dtype == dtype + assert new_mgr.iget(1).dtype == dtype + assert new_mgr.iget(2).dtype == dtype + assert new_mgr.iget(3).dtype == np.int64 + assert new_mgr.iget(4).dtype == np.float64 + + mgr = create_mgr( + "a,b,foo: object; f: i4; bool: bool; dt: datetime; i: i8; g: f8; h: f2" + ) + mgr.iset(0, np.array(["1"] * N, dtype=np.object_)) + mgr.iset(1, np.array(["2."] * N, dtype=np.object_)) + mgr.iset(2, np.array(["foo."] * N, dtype=np.object_)) + new_mgr = mgr.convert(copy=True) + assert new_mgr.iget(0).dtype == dtype + assert new_mgr.iget(1).dtype == dtype + assert new_mgr.iget(2).dtype == dtype + assert new_mgr.iget(3).dtype == np.int32 + assert new_mgr.iget(4).dtype == np.bool_ + assert new_mgr.iget(5).dtype.type, np.datetime64 + assert new_mgr.iget(6).dtype == np.int64 + assert new_mgr.iget(7).dtype == np.float64 + assert new_mgr.iget(8).dtype == np.float16 + + def test_interleave(self): + # self + for dtype in ["f8", "i8", "object", "bool", "complex", "M8[ns]", "m8[ns]"]: + mgr = create_mgr(f"a: {dtype}") + assert mgr.as_array().dtype == dtype + mgr = create_mgr(f"a: {dtype}; b: {dtype}") + assert mgr.as_array().dtype == dtype + + @pytest.mark.parametrize( + "mgr_string, dtype", + [ + ("a: category", "i8"), + ("a: category; b: category", "i8"), + ("a: category; b: category2", "object"), + ("a: category2", "object"), + ("a: category2; b: category2", "object"), + ("a: f8", "f8"), + ("a: f8; b: i8", "f8"), + ("a: f4; b: i8", "f8"), + ("a: f4; b: i8; d: object", "object"), + ("a: bool; b: i8", "object"), + ("a: complex", "complex"), + ("a: f8; b: category", "object"), + ("a: M8[ns]; b: category", "object"), + ("a: M8[ns]; b: bool", "object"), + ("a: M8[ns]; b: i8", "object"), + ("a: m8[ns]; b: bool", "object"), + ("a: m8[ns]; b: i8", "object"), + ("a: M8[ns]; b: m8[ns]", "object"), + ], + ) + def test_interleave_dtype(self, mgr_string, dtype): + # will be converted according the actual dtype of the underlying + mgr = create_mgr("a: category") + assert mgr.as_array().dtype == "i8" + mgr = create_mgr("a: category; b: category2") + assert mgr.as_array().dtype == "object" + mgr = create_mgr("a: category2") + assert mgr.as_array().dtype == "object" + + # combinations + mgr = create_mgr("a: f8") + assert mgr.as_array().dtype == "f8" + mgr = create_mgr("a: f8; b: i8") + assert mgr.as_array().dtype == "f8" + mgr = create_mgr("a: f4; b: i8") + assert mgr.as_array().dtype == "f8" + mgr = create_mgr("a: f4; b: i8; d: object") + assert mgr.as_array().dtype == "object" + mgr = create_mgr("a: bool; b: i8") + assert mgr.as_array().dtype == "object" + mgr = create_mgr("a: complex") + assert mgr.as_array().dtype == "complex" + mgr = create_mgr("a: f8; b: category") + assert mgr.as_array().dtype == "f8" + mgr = create_mgr("a: M8[ns]; b: category") + assert mgr.as_array().dtype == "object" + mgr = create_mgr("a: M8[ns]; b: bool") + assert mgr.as_array().dtype == "object" + mgr = create_mgr("a: M8[ns]; b: i8") + assert mgr.as_array().dtype == "object" + mgr = create_mgr("a: m8[ns]; b: bool") + assert mgr.as_array().dtype == "object" + mgr = create_mgr("a: m8[ns]; b: i8") + assert mgr.as_array().dtype == "object" + mgr = create_mgr("a: M8[ns]; b: m8[ns]") + assert mgr.as_array().dtype == "object" + + def test_consolidate_ordering_issues(self, mgr): + mgr.iset(mgr.items.get_loc("f"), np.random.default_rng(2).standard_normal(N)) + mgr.iset(mgr.items.get_loc("d"), np.random.default_rng(2).standard_normal(N)) + mgr.iset(mgr.items.get_loc("b"), np.random.default_rng(2).standard_normal(N)) + mgr.iset(mgr.items.get_loc("g"), np.random.default_rng(2).standard_normal(N)) + mgr.iset(mgr.items.get_loc("h"), np.random.default_rng(2).standard_normal(N)) + + # we have datetime/tz blocks in mgr + cons = mgr.consolidate() + assert cons.nblocks == 4 + cons = mgr.consolidate().get_numeric_data() + assert cons.nblocks == 1 + assert isinstance(cons.blocks[0].mgr_locs, BlockPlacement) + tm.assert_numpy_array_equal( + cons.blocks[0].mgr_locs.as_array, np.arange(len(cons.items), dtype=np.intp) + ) + + def test_reindex_items(self): + # mgr is not consolidated, f8 & f8-2 blocks + mgr = create_mgr("a: f8; b: i8; c: f8; d: i8; e: f8; f: bool; g: f8-2") + + reindexed = mgr.reindex_axis(["g", "c", "a", "d"], axis=0) + # reindex_axis does not consolidate_inplace, as that risks failing to + # invalidate _item_cache + assert not reindexed.is_consolidated() + + tm.assert_index_equal(reindexed.items, Index(["g", "c", "a", "d"])) + tm.assert_almost_equal( + mgr.iget(6).internal_values(), reindexed.iget(0).internal_values() + ) + tm.assert_almost_equal( + mgr.iget(2).internal_values(), reindexed.iget(1).internal_values() + ) + tm.assert_almost_equal( + mgr.iget(0).internal_values(), reindexed.iget(2).internal_values() + ) + tm.assert_almost_equal( + mgr.iget(3).internal_values(), reindexed.iget(3).internal_values() + ) + + def test_get_numeric_data(self, using_copy_on_write): + mgr = create_mgr( + "int: int; float: float; complex: complex;" + "str: object; bool: bool; obj: object; dt: datetime", + item_shape=(3,), + ) + mgr.iset(5, np.array([1, 2, 3], dtype=np.object_)) + + numeric = mgr.get_numeric_data() + tm.assert_index_equal(numeric.items, Index(["int", "float", "complex", "bool"])) + tm.assert_almost_equal( + mgr.iget(mgr.items.get_loc("float")).internal_values(), + numeric.iget(numeric.items.get_loc("float")).internal_values(), + ) + + # Check sharing + numeric.iset( + numeric.items.get_loc("float"), + np.array([100.0, 200.0, 300.0]), + inplace=True, + ) + if using_copy_on_write: + tm.assert_almost_equal( + mgr.iget(mgr.items.get_loc("float")).internal_values(), + np.array([1.0, 1.0, 1.0]), + ) + else: + tm.assert_almost_equal( + mgr.iget(mgr.items.get_loc("float")).internal_values(), + np.array([100.0, 200.0, 300.0]), + ) + + def test_get_bool_data(self, using_copy_on_write): + mgr = create_mgr( + "int: int; float: float; complex: complex;" + "str: object; bool: bool; obj: object; dt: datetime", + item_shape=(3,), + ) + mgr.iset(6, np.array([True, False, True], dtype=np.object_)) + + bools = mgr.get_bool_data() + tm.assert_index_equal(bools.items, Index(["bool"])) + tm.assert_almost_equal( + mgr.iget(mgr.items.get_loc("bool")).internal_values(), + bools.iget(bools.items.get_loc("bool")).internal_values(), + ) + + bools.iset(0, np.array([True, False, True]), inplace=True) + if using_copy_on_write: + tm.assert_numpy_array_equal( + mgr.iget(mgr.items.get_loc("bool")).internal_values(), + np.array([True, True, True]), + ) + else: + tm.assert_numpy_array_equal( + mgr.iget(mgr.items.get_loc("bool")).internal_values(), + np.array([True, False, True]), + ) + + def test_unicode_repr_doesnt_raise(self): + repr(create_mgr("b,\u05d0: object")) + + @pytest.mark.parametrize( + "mgr_string", ["a,b,c: i8-1; d,e,f: i8-2", "a,a,a: i8-1; b,b,b: i8-2"] + ) + def test_equals(self, mgr_string): + # unique items + bm1 = create_mgr(mgr_string) + bm2 = BlockManager(bm1.blocks[::-1], bm1.axes) + assert bm1.equals(bm2) + + @pytest.mark.parametrize( + "mgr_string", + [ + "a:i8;b:f8", # basic case + "a:i8;b:f8;c:c8;d:b", # many types + "a:i8;e:dt;f:td;g:string", # more types + "a:i8;b:category;c:category2", # categories + "c:sparse;d:sparse_na;b:f8", # sparse + ], + ) + def test_equals_block_order_different_dtypes(self, mgr_string): + # GH 9330 + bm = create_mgr(mgr_string) + block_perms = itertools.permutations(bm.blocks) + for bm_perm in block_perms: + bm_this = BlockManager(tuple(bm_perm), bm.axes) + assert bm.equals(bm_this) + assert bm_this.equals(bm) + + def test_single_mgr_ctor(self): + mgr = create_single_mgr("f8", num_rows=5) + assert mgr.external_values().tolist() == [0.0, 1.0, 2.0, 3.0, 4.0] + + @pytest.mark.parametrize("value", [1, "True", [1, 2, 3], 5.0]) + def test_validate_bool_args(self, value): + bm1 = create_mgr("a,b,c: i8-1; d,e,f: i8-2") + + msg = ( + 'For argument "inplace" expected type bool, ' + f"received type {type(value).__name__}." + ) + with pytest.raises(ValueError, match=msg): + bm1.replace_list([1], [2], inplace=value) + + def test_iset_split_block(self): + bm = create_mgr("a,b,c: i8; d: f8") + bm._iset_split_block(0, np.array([0])) + tm.assert_numpy_array_equal( + bm.blklocs, np.array([0, 0, 1, 0], dtype="int64" if IS64 else "int32") + ) + # First indexer currently does not have a block associated with it in case + tm.assert_numpy_array_equal( + bm.blknos, np.array([0, 0, 0, 1], dtype="int64" if IS64 else "int32") + ) + assert len(bm.blocks) == 2 + + def test_iset_split_block_values(self): + bm = create_mgr("a,b,c: i8; d: f8") + bm._iset_split_block(0, np.array([0]), np.array([list(range(10))])) + tm.assert_numpy_array_equal( + bm.blklocs, np.array([0, 0, 1, 0], dtype="int64" if IS64 else "int32") + ) + # First indexer currently does not have a block associated with it in case + tm.assert_numpy_array_equal( + bm.blknos, np.array([0, 2, 2, 1], dtype="int64" if IS64 else "int32") + ) + assert len(bm.blocks) == 3 + + +def _as_array(mgr): + if mgr.ndim == 1: + return mgr.external_values() + return mgr.as_array().T + + +class TestIndexing: + # Nosetests-style data-driven tests. + # + # This test applies different indexing routines to block managers and + # compares the outcome to the result of same operations on np.ndarray. + # + # NOTE: sparse (SparseBlock with fill_value != np.nan) fail a lot of tests + # and are disabled. + + MANAGERS = [ + create_single_mgr("f8", N), + create_single_mgr("i8", N), + # 2-dim + create_mgr("a,b,c,d,e,f: f8", item_shape=(N,)), + create_mgr("a,b,c,d,e,f: i8", item_shape=(N,)), + create_mgr("a,b: f8; c,d: i8; e,f: string", item_shape=(N,)), + create_mgr("a,b: f8; c,d: i8; e,f: f8", item_shape=(N,)), + ] + + @pytest.mark.parametrize("mgr", MANAGERS) + def test_get_slice(self, mgr): + def assert_slice_ok(mgr, axis, slobj): + mat = _as_array(mgr) + + # we maybe using an ndarray to test slicing and + # might not be the full length of the axis + if isinstance(slobj, np.ndarray): + ax = mgr.axes[axis] + if len(ax) and len(slobj) and len(slobj) != len(ax): + slobj = np.concatenate( + [slobj, np.zeros(len(ax) - len(slobj), dtype=bool)] + ) + + if isinstance(slobj, slice): + sliced = mgr.get_slice(slobj, axis=axis) + elif ( + mgr.ndim == 1 + and axis == 0 + and isinstance(slobj, np.ndarray) + and slobj.dtype == bool + ): + sliced = mgr.get_rows_with_mask(slobj) + else: + # BlockManager doesn't support non-slice, SingleBlockManager + # doesn't support axis > 0 + raise TypeError(slobj) + + mat_slobj = (slice(None),) * axis + (slobj,) + tm.assert_numpy_array_equal( + mat[mat_slobj], _as_array(sliced), check_dtype=False + ) + tm.assert_index_equal(mgr.axes[axis][slobj], sliced.axes[axis]) + + assert mgr.ndim <= 2, mgr.ndim + for ax in range(mgr.ndim): + # slice + assert_slice_ok(mgr, ax, slice(None)) + assert_slice_ok(mgr, ax, slice(3)) + assert_slice_ok(mgr, ax, slice(100)) + assert_slice_ok(mgr, ax, slice(1, 4)) + assert_slice_ok(mgr, ax, slice(3, 0, -2)) + + if mgr.ndim < 2: + # 2D only support slice objects + + # boolean mask + assert_slice_ok(mgr, ax, np.ones(mgr.shape[ax], dtype=np.bool_)) + assert_slice_ok(mgr, ax, np.zeros(mgr.shape[ax], dtype=np.bool_)) + + if mgr.shape[ax] >= 3: + assert_slice_ok(mgr, ax, np.arange(mgr.shape[ax]) % 3 == 0) + assert_slice_ok( + mgr, ax, np.array([True, True, False], dtype=np.bool_) + ) + + @pytest.mark.parametrize("mgr", MANAGERS) + def test_take(self, mgr): + def assert_take_ok(mgr, axis, indexer): + mat = _as_array(mgr) + taken = mgr.take(indexer, axis) + tm.assert_numpy_array_equal( + np.take(mat, indexer, axis), _as_array(taken), check_dtype=False + ) + tm.assert_index_equal(mgr.axes[axis].take(indexer), taken.axes[axis]) + + for ax in range(mgr.ndim): + # take/fancy indexer + assert_take_ok(mgr, ax, indexer=np.array([], dtype=np.intp)) + assert_take_ok(mgr, ax, indexer=np.array([0, 0, 0], dtype=np.intp)) + assert_take_ok( + mgr, ax, indexer=np.array(list(range(mgr.shape[ax])), dtype=np.intp) + ) + + if mgr.shape[ax] >= 3: + assert_take_ok(mgr, ax, indexer=np.array([0, 1, 2], dtype=np.intp)) + assert_take_ok(mgr, ax, indexer=np.array([-1, -2, -3], dtype=np.intp)) + + @pytest.mark.parametrize("mgr", MANAGERS) + @pytest.mark.parametrize("fill_value", [None, np.nan, 100.0]) + def test_reindex_axis(self, fill_value, mgr): + def assert_reindex_axis_is_ok(mgr, axis, new_labels, fill_value): + mat = _as_array(mgr) + indexer = mgr.axes[axis].get_indexer_for(new_labels) + + reindexed = mgr.reindex_axis(new_labels, axis, fill_value=fill_value) + tm.assert_numpy_array_equal( + algos.take_nd(mat, indexer, axis, fill_value=fill_value), + _as_array(reindexed), + check_dtype=False, + ) + tm.assert_index_equal(reindexed.axes[axis], new_labels) + + for ax in range(mgr.ndim): + assert_reindex_axis_is_ok(mgr, ax, Index([]), fill_value) + assert_reindex_axis_is_ok(mgr, ax, mgr.axes[ax], fill_value) + assert_reindex_axis_is_ok(mgr, ax, mgr.axes[ax][[0, 0, 0]], fill_value) + assert_reindex_axis_is_ok(mgr, ax, Index(["foo", "bar", "baz"]), fill_value) + assert_reindex_axis_is_ok( + mgr, ax, Index(["foo", mgr.axes[ax][0], "baz"]), fill_value + ) + + if mgr.shape[ax] >= 3: + assert_reindex_axis_is_ok(mgr, ax, mgr.axes[ax][:-3], fill_value) + assert_reindex_axis_is_ok(mgr, ax, mgr.axes[ax][-3::-1], fill_value) + assert_reindex_axis_is_ok( + mgr, ax, mgr.axes[ax][[0, 1, 2, 0, 1, 2]], fill_value + ) + + @pytest.mark.parametrize("mgr", MANAGERS) + @pytest.mark.parametrize("fill_value", [None, np.nan, 100.0]) + def test_reindex_indexer(self, fill_value, mgr): + def assert_reindex_indexer_is_ok(mgr, axis, new_labels, indexer, fill_value): + mat = _as_array(mgr) + reindexed_mat = algos.take_nd(mat, indexer, axis, fill_value=fill_value) + reindexed = mgr.reindex_indexer( + new_labels, indexer, axis, fill_value=fill_value + ) + tm.assert_numpy_array_equal( + reindexed_mat, _as_array(reindexed), check_dtype=False + ) + tm.assert_index_equal(reindexed.axes[axis], new_labels) + + for ax in range(mgr.ndim): + assert_reindex_indexer_is_ok( + mgr, ax, Index([]), np.array([], dtype=np.intp), fill_value + ) + assert_reindex_indexer_is_ok( + mgr, ax, mgr.axes[ax], np.arange(mgr.shape[ax]), fill_value + ) + assert_reindex_indexer_is_ok( + mgr, + ax, + Index(["foo"] * mgr.shape[ax]), + np.arange(mgr.shape[ax]), + fill_value, + ) + assert_reindex_indexer_is_ok( + mgr, ax, mgr.axes[ax][::-1], np.arange(mgr.shape[ax]), fill_value + ) + assert_reindex_indexer_is_ok( + mgr, ax, mgr.axes[ax], np.arange(mgr.shape[ax])[::-1], fill_value + ) + assert_reindex_indexer_is_ok( + mgr, ax, Index(["foo", "bar", "baz"]), np.array([0, 0, 0]), fill_value + ) + assert_reindex_indexer_is_ok( + mgr, ax, Index(["foo", "bar", "baz"]), np.array([-1, 0, -1]), fill_value + ) + assert_reindex_indexer_is_ok( + mgr, + ax, + Index(["foo", mgr.axes[ax][0], "baz"]), + np.array([-1, -1, -1]), + fill_value, + ) + + if mgr.shape[ax] >= 3: + assert_reindex_indexer_is_ok( + mgr, + ax, + Index(["foo", "bar", "baz"]), + np.array([0, 1, 2]), + fill_value, + ) + + +class TestBlockPlacement: + @pytest.mark.parametrize( + "slc, expected", + [ + (slice(0, 4), 4), + (slice(0, 4, 2), 2), + (slice(0, 3, 2), 2), + (slice(0, 1, 2), 1), + (slice(1, 0, -1), 1), + ], + ) + def test_slice_len(self, slc, expected): + assert len(BlockPlacement(slc)) == expected + + @pytest.mark.parametrize("slc", [slice(1, 1, 0), slice(1, 2, 0)]) + def test_zero_step_raises(self, slc): + msg = "slice step cannot be zero" + with pytest.raises(ValueError, match=msg): + BlockPlacement(slc) + + def test_slice_canonize_negative_stop(self): + # GH#37524 negative stop is OK with negative step and positive start + slc = slice(3, -1, -2) + + bp = BlockPlacement(slc) + assert bp.indexer == slice(3, None, -2) + + @pytest.mark.parametrize( + "slc", + [ + slice(None, None), + slice(10, None), + slice(None, None, -1), + slice(None, 10, -1), + # These are "unbounded" because negative index will + # change depending on container shape. + slice(-1, None), + slice(None, -1), + slice(-1, -1), + slice(-1, None, -1), + slice(None, -1, -1), + slice(-1, -1, -1), + ], + ) + def test_unbounded_slice_raises(self, slc): + msg = "unbounded slice" + with pytest.raises(ValueError, match=msg): + BlockPlacement(slc) + + @pytest.mark.parametrize( + "slc", + [ + slice(0, 0), + slice(100, 0), + slice(100, 100), + slice(100, 100, -1), + slice(0, 100, -1), + ], + ) + def test_not_slice_like_slices(self, slc): + assert not BlockPlacement(slc).is_slice_like + + @pytest.mark.parametrize( + "arr, slc", + [ + ([0], slice(0, 1, 1)), + ([100], slice(100, 101, 1)), + ([0, 1, 2], slice(0, 3, 1)), + ([0, 5, 10], slice(0, 15, 5)), + ([0, 100], slice(0, 200, 100)), + ([2, 1], slice(2, 0, -1)), + ], + ) + def test_array_to_slice_conversion(self, arr, slc): + assert BlockPlacement(arr).as_slice == slc + + @pytest.mark.parametrize( + "arr", + [ + [], + [-1], + [-1, -2, -3], + [-10], + [-1], + [-1, 0, 1, 2], + [-2, 0, 2, 4], + [1, 0, -1], + [1, 1, 1], + ], + ) + def test_not_slice_like_arrays(self, arr): + assert not BlockPlacement(arr).is_slice_like + + @pytest.mark.parametrize( + "slc, expected", + [(slice(0, 3), [0, 1, 2]), (slice(0, 0), []), (slice(3, 0), [])], + ) + def test_slice_iter(self, slc, expected): + assert list(BlockPlacement(slc)) == expected + + @pytest.mark.parametrize( + "slc, arr", + [ + (slice(0, 3), [0, 1, 2]), + (slice(0, 0), []), + (slice(3, 0), []), + (slice(3, 0, -1), [3, 2, 1]), + ], + ) + def test_slice_to_array_conversion(self, slc, arr): + tm.assert_numpy_array_equal( + BlockPlacement(slc).as_array, np.asarray(arr, dtype=np.intp) + ) + + def test_blockplacement_add(self): + bpl = BlockPlacement(slice(0, 5)) + assert bpl.add(1).as_slice == slice(1, 6, 1) + assert bpl.add(np.arange(5)).as_slice == slice(0, 10, 2) + assert list(bpl.add(np.arange(5, 0, -1))) == [5, 5, 5, 5, 5] + + @pytest.mark.parametrize( + "val, inc, expected", + [ + (slice(0, 0), 0, []), + (slice(1, 4), 0, [1, 2, 3]), + (slice(3, 0, -1), 0, [3, 2, 1]), + ([1, 2, 4], 0, [1, 2, 4]), + (slice(0, 0), 10, []), + (slice(1, 4), 10, [11, 12, 13]), + (slice(3, 0, -1), 10, [13, 12, 11]), + ([1, 2, 4], 10, [11, 12, 14]), + (slice(0, 0), -1, []), + (slice(1, 4), -1, [0, 1, 2]), + ([1, 2, 4], -1, [0, 1, 3]), + ], + ) + def test_blockplacement_add_int(self, val, inc, expected): + assert list(BlockPlacement(val).add(inc)) == expected + + @pytest.mark.parametrize("val", [slice(1, 4), [1, 2, 4]]) + def test_blockplacement_add_int_raises(self, val): + msg = "iadd causes length change" + with pytest.raises(ValueError, match=msg): + BlockPlacement(val).add(-10) + + +class TestCanHoldElement: + @pytest.fixture( + params=[ + lambda x: x, + lambda x: x.to_series(), + lambda x: x._data, + lambda x: list(x), + lambda x: x.astype(object), + lambda x: np.asarray(x), + lambda x: x[0], + lambda x: x[:0], + ] + ) + def element(self, request): + """ + Functions that take an Index and return an element that should have + blk._can_hold_element(element) for a Block with this index's dtype. + """ + return request.param + + def test_datetime_block_can_hold_element(self): + block = create_block("datetime", [0]) + + assert block._can_hold_element([]) + + # We will check that block._can_hold_element iff arr.__setitem__ works + arr = pd.array(block.values.ravel()) + + # coerce None + assert block._can_hold_element(None) + arr[0] = None + assert arr[0] is pd.NaT + + # coerce different types of datetime objects + vals = [np.datetime64("2010-10-10"), datetime(2010, 10, 10)] + for val in vals: + assert block._can_hold_element(val) + arr[0] = val + + val = date(2010, 10, 10) + assert not block._can_hold_element(val) + + msg = ( + "value should be a 'Timestamp', 'NaT', " + "or array of those. Got 'date' instead." + ) + with pytest.raises(TypeError, match=msg): + arr[0] = val + + @pytest.mark.parametrize("dtype", [np.int64, np.uint64, np.float64]) + def test_interval_can_hold_element_emptylist(self, dtype, element): + arr = np.array([1, 3, 4], dtype=dtype) + ii = IntervalIndex.from_breaks(arr) + blk = new_block(ii._data, BlockPlacement([1]), ndim=2) + + assert blk._can_hold_element([]) + # TODO: check this holds for all blocks + + @pytest.mark.parametrize("dtype", [np.int64, np.uint64, np.float64]) + def test_interval_can_hold_element(self, dtype, element): + arr = np.array([1, 3, 4, 9], dtype=dtype) + ii = IntervalIndex.from_breaks(arr) + blk = new_block(ii._data, BlockPlacement([1]), ndim=2) + + elem = element(ii) + self.check_series_setitem(elem, ii, True) + assert blk._can_hold_element(elem) + + # Careful: to get the expected Series-inplace behavior we need + # `elem` to not have the same length as `arr` + ii2 = IntervalIndex.from_breaks(arr[:-1], closed="neither") + elem = element(ii2) + with tm.assert_produces_warning(FutureWarning): + self.check_series_setitem(elem, ii, False) + assert not blk._can_hold_element(elem) + + ii3 = IntervalIndex.from_breaks([Timestamp(1), Timestamp(3), Timestamp(4)]) + elem = element(ii3) + with tm.assert_produces_warning(FutureWarning): + self.check_series_setitem(elem, ii, False) + assert not blk._can_hold_element(elem) + + ii4 = IntervalIndex.from_breaks([Timedelta(1), Timedelta(3), Timedelta(4)]) + elem = element(ii4) + with tm.assert_produces_warning(FutureWarning): + self.check_series_setitem(elem, ii, False) + assert not blk._can_hold_element(elem) + + def test_period_can_hold_element_emptylist(self): + pi = period_range("2016", periods=3, freq="Y") + blk = new_block(pi._data.reshape(1, 3), BlockPlacement([1]), ndim=2) + + assert blk._can_hold_element([]) + + def test_period_can_hold_element(self, element): + pi = period_range("2016", periods=3, freq="Y") + + elem = element(pi) + self.check_series_setitem(elem, pi, True) + + # Careful: to get the expected Series-inplace behavior we need + # `elem` to not have the same length as `arr` + pi2 = pi.asfreq("D")[:-1] + elem = element(pi2) + with tm.assert_produces_warning(FutureWarning): + self.check_series_setitem(elem, pi, False) + + dti = pi.to_timestamp("s")[:-1] + elem = element(dti) + with tm.assert_produces_warning(FutureWarning): + self.check_series_setitem(elem, pi, False) + + def check_can_hold_element(self, obj, elem, inplace: bool): + blk = obj._mgr.blocks[0] + if inplace: + assert blk._can_hold_element(elem) + else: + assert not blk._can_hold_element(elem) + + def check_series_setitem(self, elem, index: Index, inplace: bool): + arr = index._data.copy() + ser = Series(arr, copy=False) + + self.check_can_hold_element(ser, elem, inplace) + + if is_scalar(elem): + ser[0] = elem + else: + ser[: len(elem)] = elem + + if inplace: + assert ser.array is arr # i.e. setting was done inplace + else: + assert ser.dtype == object + + +class TestShouldStore: + def test_should_store_categorical(self): + cat = Categorical(["A", "B", "C"]) + df = DataFrame(cat) + blk = df._mgr.blocks[0] + + # matching dtype + assert blk.should_store(cat) + assert blk.should_store(cat[:-1]) + + # different dtype + assert not blk.should_store(cat.as_ordered()) + + # ndarray instead of Categorical + assert not blk.should_store(np.asarray(cat)) + + +def test_validate_ndim(): + values = np.array([1.0, 2.0]) + placement = BlockPlacement(slice(2)) + msg = r"Wrong number of dimensions. values.ndim != ndim \[1 != 2\]" + + with pytest.raises(ValueError, match=msg): + make_block(values, placement, ndim=2) + + +def test_block_shape(): + idx = Index([0, 1, 2, 3, 4]) + a = Series([1, 2, 3]).reindex(idx) + b = Series(Categorical([1, 2, 3])).reindex(idx) + + assert a._mgr.blocks[0].mgr_locs.indexer == b._mgr.blocks[0].mgr_locs.indexer + + +def test_make_block_no_pandas_array(block_maker): + # https://github.com/pandas-dev/pandas/pull/24866 + arr = pd.arrays.NumpyExtensionArray(np.array([1, 2])) + + # NumpyExtensionArray, no dtype + result = block_maker(arr, BlockPlacement(slice(len(arr))), ndim=arr.ndim) + assert result.dtype.kind in ["i", "u"] + + if block_maker is make_block: + # new_block requires caller to unwrap NumpyExtensionArray + assert result.is_extension is False + + # NumpyExtensionArray, NumpyEADtype + result = block_maker(arr, slice(len(arr)), dtype=arr.dtype, ndim=arr.ndim) + assert result.dtype.kind in ["i", "u"] + assert result.is_extension is False + + # new_block no longer taked dtype keyword + # ndarray, NumpyEADtype + result = block_maker( + arr.to_numpy(), slice(len(arr)), dtype=arr.dtype, ndim=arr.ndim + ) + assert result.dtype.kind in ["i", "u"] + assert result.is_extension is False diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/internals/test_managers.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/internals/test_managers.py new file mode 100644 index 0000000000000000000000000000000000000000..f40362c299717be5f2e8665e4547276c2af05fb0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/internals/test_managers.py @@ -0,0 +1,103 @@ +""" +Testing interaction between the different managers (BlockManager, ArrayManager) +""" +import os +import subprocess +import sys + +import pytest + +from pandas.core.dtypes.missing import array_equivalent + +import pandas as pd +import pandas._testing as tm +from pandas.core.internals import ( + ArrayManager, + BlockManager, + SingleArrayManager, + SingleBlockManager, +) + + +def test_dataframe_creation(): + msg = "data_manager option is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + with pd.option_context("mode.data_manager", "block"): + df_block = pd.DataFrame( + {"a": [1, 2, 3], "b": [0.1, 0.2, 0.3], "c": [4, 5, 6]} + ) + assert isinstance(df_block._mgr, BlockManager) + + with tm.assert_produces_warning(FutureWarning, match=msg): + with pd.option_context("mode.data_manager", "array"): + df_array = pd.DataFrame( + {"a": [1, 2, 3], "b": [0.1, 0.2, 0.3], "c": [4, 5, 6]} + ) + assert isinstance(df_array._mgr, ArrayManager) + + # also ensure both are seen as equal + tm.assert_frame_equal(df_block, df_array) + + # conversion from one manager to the other + result = df_block._as_manager("block") + assert isinstance(result._mgr, BlockManager) + result = df_block._as_manager("array") + assert isinstance(result._mgr, ArrayManager) + tm.assert_frame_equal(result, df_block) + assert all( + array_equivalent(left, right) + for left, right in zip(result._mgr.arrays, df_array._mgr.arrays) + ) + + result = df_array._as_manager("array") + assert isinstance(result._mgr, ArrayManager) + result = df_array._as_manager("block") + assert isinstance(result._mgr, BlockManager) + tm.assert_frame_equal(result, df_array) + assert len(result._mgr.blocks) == 2 + + +def test_series_creation(): + msg = "data_manager option is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + with pd.option_context("mode.data_manager", "block"): + s_block = pd.Series([1, 2, 3], name="A", index=["a", "b", "c"]) + assert isinstance(s_block._mgr, SingleBlockManager) + + with tm.assert_produces_warning(FutureWarning, match=msg): + with pd.option_context("mode.data_manager", "array"): + s_array = pd.Series([1, 2, 3], name="A", index=["a", "b", "c"]) + assert isinstance(s_array._mgr, SingleArrayManager) + + # also ensure both are seen as equal + tm.assert_series_equal(s_block, s_array) + + # conversion from one manager to the other + result = s_block._as_manager("block") + assert isinstance(result._mgr, SingleBlockManager) + result = s_block._as_manager("array") + assert isinstance(result._mgr, SingleArrayManager) + tm.assert_series_equal(result, s_block) + + result = s_array._as_manager("array") + assert isinstance(result._mgr, SingleArrayManager) + result = s_array._as_manager("block") + assert isinstance(result._mgr, SingleBlockManager) + tm.assert_series_equal(result, s_array) + + +@pytest.mark.single_cpu +@pytest.mark.parametrize("manager", ["block", "array"]) +def test_array_manager_depr_env_var(manager): + # GH#55043 + test_env = os.environ.copy() + test_env["PANDAS_DATA_MANAGER"] = manager + response = subprocess.run( + [sys.executable, "-c", "import pandas"], + capture_output=True, + env=test_env, + check=True, + ) + msg = "FutureWarning: The env variable PANDAS_DATA_MANAGER is set" + stderr_msg = response.stderr.decode("utf-8") + assert msg in stderr_msg, stderr_msg diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/conftest.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..ab6cacc4cc860d0d4c0ffe948274252daae2ee27 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/conftest.py @@ -0,0 +1,242 @@ +import shlex +import subprocess +import time +import uuid + +import pytest + +from pandas.compat import ( + is_ci_environment, + is_platform_arm, + is_platform_mac, + is_platform_windows, +) +import pandas.util._test_decorators as td + +import pandas.io.common as icom +from pandas.io.parsers import read_csv + + +@pytest.fixture +def compression_to_extension(): + return {value: key for key, value in icom.extension_to_compression.items()} + + +@pytest.fixture +def tips_file(datapath): + """Path to the tips dataset""" + return datapath("io", "data", "csv", "tips.csv") + + +@pytest.fixture +def jsonl_file(datapath): + """Path to a JSONL dataset""" + return datapath("io", "parser", "data", "items.jsonl") + + +@pytest.fixture +def salaries_table(datapath): + """DataFrame with the salaries dataset""" + return read_csv(datapath("io", "parser", "data", "salaries.csv"), sep="\t") + + +@pytest.fixture +def feather_file(datapath): + return datapath("io", "data", "feather", "feather-0_3_1.feather") + + +@pytest.fixture +def xml_file(datapath): + return datapath("io", "data", "xml", "books.xml") + + +@pytest.fixture +def s3_base(worker_id, monkeypatch): + """ + Fixture for mocking S3 interaction. + + Sets up moto server in separate process locally + Return url for motoserver/moto CI service + """ + pytest.importorskip("s3fs") + pytest.importorskip("boto3") + + # temporary workaround as moto fails for botocore >= 1.11 otherwise, + # see https://github.com/spulec/moto/issues/1924 & 1952 + monkeypatch.setenv("AWS_ACCESS_KEY_ID", "foobar_key") + monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "foobar_secret") + if is_ci_environment(): + if is_platform_arm() or is_platform_mac() or is_platform_windows(): + # NOT RUN on Windows/macOS/ARM, only Ubuntu + # - subprocess in CI can cause timeouts + # - GitHub Actions do not support + # container services for the above OSs + # - CircleCI will probably hit the Docker rate pull limit + pytest.skip( + "S3 tests do not have a corresponding service in " + "Windows, macOS or ARM platforms" + ) + else: + # set in .github/workflows/unit-tests.yml + yield "http://localhost:5000" + else: + requests = pytest.importorskip("requests") + pytest.importorskip("moto") + pytest.importorskip("flask") # server mode needs flask too + + # Launching moto in server mode, i.e., as a separate process + # with an S3 endpoint on localhost + + worker_id = "5" if worker_id == "master" else worker_id.lstrip("gw") + endpoint_port = f"555{worker_id}" + endpoint_uri = f"http://127.0.0.1:{endpoint_port}/" + + # pipe to null to avoid logging in terminal + with subprocess.Popen( + shlex.split(f"moto_server s3 -p {endpoint_port}"), + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) as proc: + timeout = 5 + while timeout > 0: + try: + # OK to go once server is accepting connections + r = requests.get(endpoint_uri) + if r.ok: + break + except Exception: + pass + timeout -= 0.1 + time.sleep(0.1) + yield endpoint_uri + + proc.terminate() + + +@pytest.fixture +def s3so(s3_base): + return {"client_kwargs": {"endpoint_url": s3_base}} + + +@pytest.fixture +def s3_resource(s3_base): + import boto3 + + s3 = boto3.resource("s3", endpoint_url=s3_base) + return s3 + + +@pytest.fixture +def s3_public_bucket(s3_resource): + bucket = s3_resource.Bucket(f"pandas-test-{uuid.uuid4()}") + bucket.create() + yield bucket + bucket.objects.delete() + bucket.delete() + + +@pytest.fixture +def s3_public_bucket_with_data( + s3_public_bucket, tips_file, jsonl_file, feather_file, xml_file +): + """ + The following datasets + are loaded. + + - tips.csv + - tips.csv.gz + - tips.csv.bz2 + - items.jsonl + """ + test_s3_files = [ + ("tips#1.csv", tips_file), + ("tips.csv", tips_file), + ("tips.csv.gz", tips_file + ".gz"), + ("tips.csv.bz2", tips_file + ".bz2"), + ("items.jsonl", jsonl_file), + ("simple_dataset.feather", feather_file), + ("books.xml", xml_file), + ] + for s3_key, file_name in test_s3_files: + with open(file_name, "rb") as f: + s3_public_bucket.put_object(Key=s3_key, Body=f) + return s3_public_bucket + + +@pytest.fixture +def s3_private_bucket(s3_resource): + bucket = s3_resource.Bucket(f"cant_get_it-{uuid.uuid4()}") + bucket.create(ACL="private") + yield bucket + bucket.objects.delete() + bucket.delete() + + +@pytest.fixture +def s3_private_bucket_with_data( + s3_private_bucket, tips_file, jsonl_file, feather_file, xml_file +): + """ + The following datasets + are loaded. + + - tips.csv + - tips.csv.gz + - tips.csv.bz2 + - items.jsonl + """ + test_s3_files = [ + ("tips#1.csv", tips_file), + ("tips.csv", tips_file), + ("tips.csv.gz", tips_file + ".gz"), + ("tips.csv.bz2", tips_file + ".bz2"), + ("items.jsonl", jsonl_file), + ("simple_dataset.feather", feather_file), + ("books.xml", xml_file), + ] + for s3_key, file_name in test_s3_files: + with open(file_name, "rb") as f: + s3_private_bucket.put_object(Key=s3_key, Body=f) + return s3_private_bucket + + +_compression_formats_params = [ + (".no_compress", None), + ("", None), + (".gz", "gzip"), + (".GZ", "gzip"), + (".bz2", "bz2"), + (".BZ2", "bz2"), + (".zip", "zip"), + (".ZIP", "zip"), + (".xz", "xz"), + (".XZ", "xz"), + pytest.param((".zst", "zstd"), marks=td.skip_if_no("zstandard")), + pytest.param((".ZST", "zstd"), marks=td.skip_if_no("zstandard")), +] + + +@pytest.fixture(params=_compression_formats_params[1:]) +def compression_format(request): + return request.param + + +@pytest.fixture(params=_compression_formats_params) +def compression_ext(request): + return request.param[0] + + +@pytest.fixture( + params=[ + "python", + pytest.param("pyarrow", marks=td.skip_if_no("pyarrow")), + ] +) +def string_storage(request): + """ + Parametrized fixture for pd.options.mode.string_storage. + + * 'python' + * 'pyarrow' + """ + return request.param diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/generate_legacy_storage_files.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/generate_legacy_storage_files.py new file mode 100644 index 0000000000000000000000000000000000000000..9bfd8eb9d51d59ef83c7a4d6fcf8bbeb1ef24025 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/generate_legacy_storage_files.py @@ -0,0 +1,342 @@ +""" +self-contained to write legacy storage pickle files + +To use this script. Create an environment where you want +generate pickles, say its for 0.20.3, with your pandas clone +in ~/pandas + +. activate pandas_0.20.3 +cd ~/pandas/pandas + +$ python -m tests.io.generate_legacy_storage_files \ + tests/io/data/legacy_pickle/0.20.3/ pickle + +This script generates a storage file for the current arch, system, +and python version + pandas version: 0.20.3 + output dir : pandas/pandas/tests/io/data/legacy_pickle/0.20.3/ + storage format: pickle +created pickle file: 0.20.3_x86_64_darwin_3.5.2.pickle + +The idea here is you are using the *current* version of the +generate_legacy_storage_files with an *older* version of pandas to +generate a pickle file. We will then check this file into a current +branch, and test using test_pickle.py. This will load the *older* +pickles and test versus the current data that is generated +(with main). These are then compared. + +If we have cases where we changed the signature (e.g. we renamed +offset -> freq in Timestamp). Then we have to conditionally execute +in the generate_legacy_storage_files.py to make it +run under the older AND the newer version. + +""" + +from datetime import timedelta +import os +import pickle +import platform as pl +import sys + +# Remove script directory from path, otherwise Python will try to +# import the JSON test directory as the json module +sys.path.pop(0) + +import numpy as np + +import pandas +from pandas import ( + Categorical, + DataFrame, + Index, + MultiIndex, + NaT, + Period, + RangeIndex, + Series, + Timestamp, + bdate_range, + date_range, + interval_range, + period_range, + timedelta_range, +) +from pandas.arrays import SparseArray + +from pandas.tseries.offsets import ( + FY5253, + BusinessDay, + BusinessHour, + CustomBusinessDay, + DateOffset, + Day, + Easter, + Hour, + LastWeekOfMonth, + Minute, + MonthBegin, + MonthEnd, + QuarterBegin, + QuarterEnd, + SemiMonthBegin, + SemiMonthEnd, + Week, + WeekOfMonth, + YearBegin, + YearEnd, +) + + +def _create_sp_series(): + nan = np.nan + + # nan-based + arr = np.arange(15, dtype=np.float64) + arr[7:12] = nan + arr[-1:] = nan + + bseries = Series(SparseArray(arr, kind="block")) + bseries.name = "bseries" + return bseries + + +def _create_sp_tsseries(): + nan = np.nan + + # nan-based + arr = np.arange(15, dtype=np.float64) + arr[7:12] = nan + arr[-1:] = nan + + date_index = bdate_range("1/1/2011", periods=len(arr)) + bseries = Series(SparseArray(arr, kind="block"), index=date_index) + bseries.name = "btsseries" + return bseries + + +def _create_sp_frame(): + nan = np.nan + + data = { + "A": [nan, nan, nan, 0, 1, 2, 3, 4, 5, 6], + "B": [0, 1, 2, nan, nan, nan, 3, 4, 5, 6], + "C": np.arange(10).astype(np.int64), + "D": [0, 1, 2, 3, 4, 5, nan, nan, nan, nan], + } + + dates = bdate_range("1/1/2011", periods=10) + return DataFrame(data, index=dates).apply(SparseArray) + + +def create_pickle_data(): + """create the pickle data""" + data = { + "A": [0.0, 1.0, 2.0, 3.0, np.nan], + "B": [0, 1, 0, 1, 0], + "C": ["foo1", "foo2", "foo3", "foo4", "foo5"], + "D": date_range("1/1/2009", periods=5), + "E": [0.0, 1, Timestamp("20100101"), "foo", 2.0], + } + + scalars = {"timestamp": Timestamp("20130101"), "period": Period("2012", "M")} + + index = { + "int": Index(np.arange(10)), + "date": date_range("20130101", periods=10), + "period": period_range("2013-01-01", freq="M", periods=10), + "float": Index(np.arange(10, dtype=np.float64)), + "uint": Index(np.arange(10, dtype=np.uint64)), + "timedelta": timedelta_range("00:00:00", freq="30min", periods=10), + } + + index["range"] = RangeIndex(10) + + index["interval"] = interval_range(0, periods=10) + + mi = { + "reg2": MultiIndex.from_tuples( + tuple( + zip( + *[ + ["bar", "bar", "baz", "baz", "foo", "foo", "qux", "qux"], + ["one", "two", "one", "two", "one", "two", "one", "two"], + ] + ) + ), + names=["first", "second"], + ) + } + + series = { + "float": Series(data["A"]), + "int": Series(data["B"]), + "mixed": Series(data["E"]), + "ts": Series( + np.arange(10).astype(np.int64), index=date_range("20130101", periods=10) + ), + "mi": Series( + np.arange(5).astype(np.float64), + index=MultiIndex.from_tuples( + tuple(zip(*[[1, 1, 2, 2, 2], [3, 4, 3, 4, 5]])), names=["one", "two"] + ), + ), + "dup": Series(np.arange(5).astype(np.float64), index=["A", "B", "C", "D", "A"]), + "cat": Series(Categorical(["foo", "bar", "baz"])), + "dt": Series(date_range("20130101", periods=5)), + "dt_tz": Series(date_range("20130101", periods=5, tz="US/Eastern")), + "period": Series([Period("2000Q1")] * 5), + } + + mixed_dup_df = DataFrame(data) + mixed_dup_df.columns = list("ABCDA") + frame = { + "float": DataFrame({"A": series["float"], "B": series["float"] + 1}), + "int": DataFrame({"A": series["int"], "B": series["int"] + 1}), + "mixed": DataFrame({k: data[k] for k in ["A", "B", "C", "D"]}), + "mi": DataFrame( + {"A": np.arange(5).astype(np.float64), "B": np.arange(5).astype(np.int64)}, + index=MultiIndex.from_tuples( + tuple( + zip( + *[ + ["bar", "bar", "baz", "baz", "baz"], + ["one", "two", "one", "two", "three"], + ] + ) + ), + names=["first", "second"], + ), + ), + "dup": DataFrame( + np.arange(15).reshape(5, 3).astype(np.float64), columns=["A", "B", "A"] + ), + "cat_onecol": DataFrame({"A": Categorical(["foo", "bar"])}), + "cat_and_float": DataFrame( + { + "A": Categorical(["foo", "bar", "baz"]), + "B": np.arange(3).astype(np.int64), + } + ), + "mixed_dup": mixed_dup_df, + "dt_mixed_tzs": DataFrame( + { + "A": Timestamp("20130102", tz="US/Eastern"), + "B": Timestamp("20130603", tz="CET"), + }, + index=range(5), + ), + "dt_mixed2_tzs": DataFrame( + { + "A": Timestamp("20130102", tz="US/Eastern"), + "B": Timestamp("20130603", tz="CET"), + "C": Timestamp("20130603", tz="UTC"), + }, + index=range(5), + ), + } + + cat = { + "int8": Categorical(list("abcdefg")), + "int16": Categorical(np.arange(1000)), + "int32": Categorical(np.arange(10000)), + } + + timestamp = { + "normal": Timestamp("2011-01-01"), + "nat": NaT, + "tz": Timestamp("2011-01-01", tz="US/Eastern"), + } + + off = { + "DateOffset": DateOffset(years=1), + "DateOffset_h_ns": DateOffset(hour=6, nanoseconds=5824), + "BusinessDay": BusinessDay(offset=timedelta(seconds=9)), + "BusinessHour": BusinessHour(normalize=True, n=6, end="15:14"), + "CustomBusinessDay": CustomBusinessDay(weekmask="Mon Fri"), + "SemiMonthBegin": SemiMonthBegin(day_of_month=9), + "SemiMonthEnd": SemiMonthEnd(day_of_month=24), + "MonthBegin": MonthBegin(1), + "MonthEnd": MonthEnd(1), + "QuarterBegin": QuarterBegin(1), + "QuarterEnd": QuarterEnd(1), + "Day": Day(1), + "YearBegin": YearBegin(1), + "YearEnd": YearEnd(1), + "Week": Week(1), + "Week_Tues": Week(2, normalize=False, weekday=1), + "WeekOfMonth": WeekOfMonth(week=3, weekday=4), + "LastWeekOfMonth": LastWeekOfMonth(n=1, weekday=3), + "FY5253": FY5253(n=2, weekday=6, startingMonth=7, variation="last"), + "Easter": Easter(), + "Hour": Hour(1), + "Minute": Minute(1), + } + + return { + "series": series, + "frame": frame, + "index": index, + "scalars": scalars, + "mi": mi, + "sp_series": {"float": _create_sp_series(), "ts": _create_sp_tsseries()}, + "sp_frame": {"float": _create_sp_frame()}, + "cat": cat, + "timestamp": timestamp, + "offsets": off, + } + + +def platform_name(): + return "_".join( + [ + str(pandas.__version__), + str(pl.machine()), + str(pl.system().lower()), + str(pl.python_version()), + ] + ) + + +def write_legacy_pickles(output_dir): + version = pandas.__version__ + + print( + "This script generates a storage file for the current arch, system, " + "and python version" + ) + print(f" pandas version: {version}") + print(f" output dir : {output_dir}") + print(" storage format: pickle") + + pth = f"{platform_name()}.pickle" + + with open(os.path.join(output_dir, pth), "wb") as fh: + pickle.dump(create_pickle_data(), fh, pickle.DEFAULT_PROTOCOL) + + print(f"created pickle file: {pth}") + + +def write_legacy_file(): + # force our cwd to be the first searched + sys.path.insert(0, "") + + if not 3 <= len(sys.argv) <= 4: + sys.exit( + "Specify output directory and storage type: generate_legacy_" + "storage_files.py " + ) + + output_dir = str(sys.argv[1]) + storage_type = str(sys.argv[2]) + + if not os.path.exists(output_dir): + os.mkdir(output_dir) + + if storage_type == "pickle": + write_legacy_pickles(output_dir=output_dir) + else: + sys.exit("storage_type must be one of {'pickle'}") + + +if __name__ == "__main__": + write_legacy_file() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_clipboard.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_clipboard.py new file mode 100644 index 0000000000000000000000000000000000000000..3c0208fcc74ec83f782e1fedf5e89b40fca3ed69 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_clipboard.py @@ -0,0 +1,423 @@ +from textwrap import dedent + +import numpy as np +import pytest + +from pandas.errors import ( + PyperclipException, + PyperclipWindowsException, +) + +import pandas as pd +from pandas import ( + NA, + DataFrame, + Series, + get_option, + read_clipboard, +) +import pandas._testing as tm +from pandas.core.arrays import ( + ArrowStringArray, + StringArray, +) + +from pandas.io.clipboard import ( + CheckedCall, + _stringifyText, + init_qt_clipboard, +) + + +def build_kwargs(sep, excel): + kwargs = {} + if excel != "default": + kwargs["excel"] = excel + if sep != "default": + kwargs["sep"] = sep + return kwargs + + +@pytest.fixture( + params=[ + "delims", + "utf8", + "utf16", + "string", + "long", + "nonascii", + "colwidth", + "mixed", + "float", + "int", + ] +) +def df(request): + data_type = request.param + + if data_type == "delims": + return DataFrame({"a": ['"a,\t"b|c', "d\tef`"], "b": ["hi'j", "k''lm"]}) + elif data_type == "utf8": + return DataFrame({"a": ["µasd", "Ωœ∑`"], "b": ["øπ∆˚¬", "œ∑`®"]}) + elif data_type == "utf16": + return DataFrame( + {"a": ["\U0001f44d\U0001f44d", "\U0001f44d\U0001f44d"], "b": ["abc", "def"]} + ) + elif data_type == "string": + return DataFrame( + np.array([f"i-{i}" for i in range(15)]).reshape(5, 3), columns=list("abc") + ) + elif data_type == "long": + max_rows = get_option("display.max_rows") + return DataFrame( + np.random.default_rng(2).integers(0, 10, size=(max_rows + 1, 3)), + columns=list("abc"), + ) + elif data_type == "nonascii": + return DataFrame({"en": "in English".split(), "es": "en español".split()}) + elif data_type == "colwidth": + _cw = get_option("display.max_colwidth") + 1 + return DataFrame( + np.array(["x" * _cw for _ in range(15)]).reshape(5, 3), columns=list("abc") + ) + elif data_type == "mixed": + return DataFrame( + { + "a": np.arange(1.0, 6.0) + 0.01, + "b": np.arange(1, 6).astype(np.int64), + "c": list("abcde"), + } + ) + elif data_type == "float": + return DataFrame(np.random.default_rng(2).random((5, 3)), columns=list("abc")) + elif data_type == "int": + return DataFrame( + np.random.default_rng(2).integers(0, 10, (5, 3)), columns=list("abc") + ) + else: + raise ValueError + + +@pytest.fixture +def mock_ctypes(monkeypatch): + """ + Mocks WinError to help with testing the clipboard. + """ + + def _mock_win_error(): + return "Window Error" + + # Set raising to False because WinError won't exist on non-windows platforms + with monkeypatch.context() as m: + m.setattr("ctypes.WinError", _mock_win_error, raising=False) + yield + + +@pytest.mark.usefixtures("mock_ctypes") +def test_checked_call_with_bad_call(monkeypatch): + """ + Give CheckCall a function that returns a falsey value and + mock get_errno so it returns false so an exception is raised. + """ + + def _return_false(): + return False + + monkeypatch.setattr("pandas.io.clipboard.get_errno", lambda: True) + msg = f"Error calling {_return_false.__name__} \\(Window Error\\)" + + with pytest.raises(PyperclipWindowsException, match=msg): + CheckedCall(_return_false)() + + +@pytest.mark.usefixtures("mock_ctypes") +def test_checked_call_with_valid_call(monkeypatch): + """ + Give CheckCall a function that returns a truthy value and + mock get_errno so it returns true so an exception is not raised. + The function should return the results from _return_true. + """ + + def _return_true(): + return True + + monkeypatch.setattr("pandas.io.clipboard.get_errno", lambda: False) + + # Give CheckedCall a callable that returns a truthy value s + checked_call = CheckedCall(_return_true) + assert checked_call() is True + + +@pytest.mark.parametrize( + "text", + [ + "String_test", + True, + 1, + 1.0, + 1j, + ], +) +def test_stringify_text(text): + valid_types = (str, int, float, bool) + + if isinstance(text, valid_types): + result = _stringifyText(text) + assert result == str(text) + else: + msg = ( + "only str, int, float, and bool values " + f"can be copied to the clipboard, not {type(text).__name__}" + ) + with pytest.raises(PyperclipException, match=msg): + _stringifyText(text) + + +@pytest.fixture +def set_pyqt_clipboard(monkeypatch): + qt_cut, qt_paste = init_qt_clipboard() + with monkeypatch.context() as m: + m.setattr(pd.io.clipboard, "clipboard_set", qt_cut) + m.setattr(pd.io.clipboard, "clipboard_get", qt_paste) + yield + + +@pytest.fixture +def clipboard(qapp): + clip = qapp.clipboard() + yield clip + clip.clear() + + +@pytest.mark.single_cpu +@pytest.mark.clipboard +@pytest.mark.usefixtures("set_pyqt_clipboard") +@pytest.mark.usefixtures("clipboard") +class TestClipboard: + # Test that default arguments copy as tab delimited + # Test that explicit delimiters are respected + @pytest.mark.parametrize("sep", [None, "\t", ",", "|"]) + @pytest.mark.parametrize("encoding", [None, "UTF-8", "utf-8", "utf8"]) + def test_round_trip_frame_sep(self, df, sep, encoding): + df.to_clipboard(excel=None, sep=sep, encoding=encoding) + result = read_clipboard(sep=sep or "\t", index_col=0, encoding=encoding) + tm.assert_frame_equal(df, result) + + # Test white space separator + def test_round_trip_frame_string(self, df): + df.to_clipboard(excel=False, sep=None) + result = read_clipboard() + assert df.to_string() == result.to_string() + assert df.shape == result.shape + + # Two character separator is not supported in to_clipboard + # Test that multi-character separators are not silently passed + def test_excel_sep_warning(self, df): + with tm.assert_produces_warning( + UserWarning, + match="to_clipboard in excel mode requires a single character separator.", + check_stacklevel=False, + ): + df.to_clipboard(excel=True, sep=r"\t") + + # Separator is ignored when excel=False and should produce a warning + def test_copy_delim_warning(self, df): + with tm.assert_produces_warning(): + df.to_clipboard(excel=False, sep="\t") + + # Tests that the default behavior of to_clipboard is tab + # delimited and excel="True" + @pytest.mark.parametrize("sep", ["\t", None, "default"]) + @pytest.mark.parametrize("excel", [True, None, "default"]) + def test_clipboard_copy_tabs_default(self, sep, excel, df, clipboard): + kwargs = build_kwargs(sep, excel) + df.to_clipboard(**kwargs) + assert clipboard.text() == df.to_csv(sep="\t") + + # Tests reading of white space separated tables + @pytest.mark.parametrize("sep", [None, "default"]) + def test_clipboard_copy_strings(self, sep, df): + kwargs = build_kwargs(sep, False) + df.to_clipboard(**kwargs) + result = read_clipboard(sep=r"\s+") + assert result.to_string() == df.to_string() + assert df.shape == result.shape + + def test_read_clipboard_infer_excel(self, clipboard): + # gh-19010: avoid warnings + clip_kwargs = {"engine": "python"} + + text = dedent( + """ + John James\tCharlie Mingus + 1\t2 + 4\tHarry Carney + """.strip() + ) + clipboard.setText(text) + df = read_clipboard(**clip_kwargs) + + # excel data is parsed correctly + assert df.iloc[1, 1] == "Harry Carney" + + # having diff tab counts doesn't trigger it + text = dedent( + """ + a\t b + 1 2 + 3 4 + """.strip() + ) + clipboard.setText(text) + res = read_clipboard(**clip_kwargs) + + text = dedent( + """ + a b + 1 2 + 3 4 + """.strip() + ) + clipboard.setText(text) + exp = read_clipboard(**clip_kwargs) + + tm.assert_frame_equal(res, exp) + + def test_infer_excel_with_nulls(self, clipboard): + # GH41108 + text = "col1\tcol2\n1\tred\n\tblue\n2\tgreen" + + clipboard.setText(text) + df = read_clipboard() + df_expected = DataFrame( + data={"col1": [1, None, 2], "col2": ["red", "blue", "green"]} + ) + + # excel data is parsed correctly + tm.assert_frame_equal(df, df_expected) + + @pytest.mark.parametrize( + "multiindex", + [ + ( # Can't use `dedent` here as it will remove the leading `\t` + "\n".join( + [ + "\t\t\tcol1\tcol2", + "A\t0\tTrue\t1\tred", + "A\t1\tTrue\t\tblue", + "B\t0\tFalse\t2\tgreen", + ] + ), + [["A", "A", "B"], [0, 1, 0], [True, True, False]], + ), + ( + "\n".join( + ["\t\tcol1\tcol2", "A\t0\t1\tred", "A\t1\t\tblue", "B\t0\t2\tgreen"] + ), + [["A", "A", "B"], [0, 1, 0]], + ), + ], + ) + def test_infer_excel_with_multiindex(self, clipboard, multiindex): + # GH41108 + + clipboard.setText(multiindex[0]) + df = read_clipboard() + df_expected = DataFrame( + data={"col1": [1, None, 2], "col2": ["red", "blue", "green"]}, + index=multiindex[1], + ) + + # excel data is parsed correctly + tm.assert_frame_equal(df, df_expected) + + def test_invalid_encoding(self, df): + msg = "clipboard only supports utf-8 encoding" + # test case for testing invalid encoding + with pytest.raises(ValueError, match=msg): + df.to_clipboard(encoding="ascii") + with pytest.raises(NotImplementedError, match=msg): + read_clipboard(encoding="ascii") + + @pytest.mark.parametrize("data", ["\U0001f44d...", "Ωœ∑`...", "abcd..."]) + def test_raw_roundtrip(self, data): + # PR #25040 wide unicode wasn't copied correctly on PY3 on windows + df = DataFrame({"data": [data]}) + df.to_clipboard() + result = read_clipboard() + tm.assert_frame_equal(df, result) + + @pytest.mark.parametrize("engine", ["c", "python"]) + def test_read_clipboard_dtype_backend( + self, clipboard, string_storage, dtype_backend, engine + ): + # GH#50502 + if string_storage == "pyarrow" or dtype_backend == "pyarrow": + pa = pytest.importorskip("pyarrow") + + if string_storage == "python": + string_array = StringArray(np.array(["x", "y"], dtype=np.object_)) + string_array_na = StringArray(np.array(["x", NA], dtype=np.object_)) + + elif dtype_backend == "pyarrow" and engine != "c": + pa = pytest.importorskip("pyarrow") + from pandas.arrays import ArrowExtensionArray + + string_array = ArrowExtensionArray(pa.array(["x", "y"])) + string_array_na = ArrowExtensionArray(pa.array(["x", None])) + + else: + string_array = ArrowStringArray(pa.array(["x", "y"])) + string_array_na = ArrowStringArray(pa.array(["x", None])) + + text = """a,b,c,d,e,f,g,h,i +x,1,4.0,x,2,4.0,,True,False +y,2,5.0,,,,,False,""" + clipboard.setText(text) + + with pd.option_context("mode.string_storage", string_storage): + result = read_clipboard(sep=",", dtype_backend=dtype_backend, engine=engine) + + expected = DataFrame( + { + "a": string_array, + "b": Series([1, 2], dtype="Int64"), + "c": Series([4.0, 5.0], dtype="Float64"), + "d": string_array_na, + "e": Series([2, NA], dtype="Int64"), + "f": Series([4.0, NA], dtype="Float64"), + "g": Series([NA, NA], dtype="Int64"), + "h": Series([True, False], dtype="boolean"), + "i": Series([False, NA], dtype="boolean"), + } + ) + if dtype_backend == "pyarrow": + from pandas.arrays import ArrowExtensionArray + + expected = DataFrame( + { + col: ArrowExtensionArray(pa.array(expected[col], from_pandas=True)) + for col in expected.columns + } + ) + expected["g"] = ArrowExtensionArray(pa.array([None, None])) + + tm.assert_frame_equal(result, expected) + + def test_invalid_dtype_backend(self): + msg = ( + "dtype_backend numpy is invalid, only 'numpy_nullable' and " + "'pyarrow' are allowed." + ) + with pytest.raises(ValueError, match=msg): + read_clipboard(dtype_backend="numpy") + + def test_to_clipboard_pos_args_deprecation(self): + # GH-54229 + df = DataFrame({"a": [1, 2, 3]}) + msg = ( + r"Starting with pandas version 3.0 all arguments of to_clipboard " + r"will be keyword-only." + ) + with tm.assert_produces_warning(FutureWarning, match=msg): + df.to_clipboard(True, None) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_common.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_common.py new file mode 100644 index 0000000000000000000000000000000000000000..e51f86563081b2749cefb027bd425c1578eda9f6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_common.py @@ -0,0 +1,653 @@ +""" +Tests for the pandas.io.common functionalities +""" +import codecs +import errno +from functools import partial +from io import ( + BytesIO, + StringIO, + UnsupportedOperation, +) +import mmap +import os +from pathlib import Path +import pickle +import tempfile + +import numpy as np +import pytest + +from pandas.compat import is_platform_windows +import pandas.util._test_decorators as td + +import pandas as pd +import pandas._testing as tm + +import pandas.io.common as icom + +pytestmark = pytest.mark.filterwarnings( + "ignore:Passing a BlockManager to DataFrame:DeprecationWarning" +) + + +class CustomFSPath: + """For testing fspath on unknown objects""" + + def __init__(self, path) -> None: + self.path = path + + def __fspath__(self): + return self.path + + +# Functions that consume a string path and return a string or path-like object +path_types = [str, CustomFSPath, Path] + +try: + from py.path import local as LocalPath + + path_types.append(LocalPath) +except ImportError: + pass + +HERE = os.path.abspath(os.path.dirname(__file__)) + + +# https://github.com/cython/cython/issues/1720 +class TestCommonIOCapabilities: + data1 = """index,A,B,C,D +foo,2,3,4,5 +bar,7,8,9,10 +baz,12,13,14,15 +qux,12,13,14,15 +foo2,12,13,14,15 +bar2,12,13,14,15 +""" + + def test_expand_user(self): + filename = "~/sometest" + expanded_name = icom._expand_user(filename) + + assert expanded_name != filename + assert os.path.isabs(expanded_name) + assert os.path.expanduser(filename) == expanded_name + + def test_expand_user_normal_path(self): + filename = "/somefolder/sometest" + expanded_name = icom._expand_user(filename) + + assert expanded_name == filename + assert os.path.expanduser(filename) == expanded_name + + def test_stringify_path_pathlib(self): + rel_path = icom.stringify_path(Path(".")) + assert rel_path == "." + redundant_path = icom.stringify_path(Path("foo//bar")) + assert redundant_path == os.path.join("foo", "bar") + + @td.skip_if_no("py.path") + def test_stringify_path_localpath(self): + path = os.path.join("foo", "bar") + abs_path = os.path.abspath(path) + lpath = LocalPath(path) + assert icom.stringify_path(lpath) == abs_path + + def test_stringify_path_fspath(self): + p = CustomFSPath("foo/bar.csv") + result = icom.stringify_path(p) + assert result == "foo/bar.csv" + + def test_stringify_file_and_path_like(self): + # GH 38125: do not stringify file objects that are also path-like + fsspec = pytest.importorskip("fsspec") + with tm.ensure_clean() as path: + with fsspec.open(f"file://{path}", mode="wb") as fsspec_obj: + assert fsspec_obj == icom.stringify_path(fsspec_obj) + + @pytest.mark.parametrize("path_type", path_types) + def test_infer_compression_from_path(self, compression_format, path_type): + extension, expected = compression_format + path = path_type("foo/bar.csv" + extension) + compression = icom.infer_compression(path, compression="infer") + assert compression == expected + + @pytest.mark.parametrize("path_type", [str, CustomFSPath, Path]) + def test_get_handle_with_path(self, path_type): + # ignore LocalPath: it creates strange paths: /absolute/~/sometest + with tempfile.TemporaryDirectory(dir=Path.home()) as tmp: + filename = path_type("~/" + Path(tmp).name + "/sometest") + with icom.get_handle(filename, "w") as handles: + assert Path(handles.handle.name).is_absolute() + assert os.path.expanduser(filename) == handles.handle.name + + def test_get_handle_with_buffer(self): + with StringIO() as input_buffer: + with icom.get_handle(input_buffer, "r") as handles: + assert handles.handle == input_buffer + assert not input_buffer.closed + assert input_buffer.closed + + # Test that BytesIOWrapper(get_handle) returns correct amount of bytes every time + def test_bytesiowrapper_returns_correct_bytes(self): + # Test latin1, ucs-2, and ucs-4 chars + data = """a,b,c +1,2,3 +©,®,® +Look,a snake,🐍""" + with icom.get_handle(StringIO(data), "rb", is_text=False) as handles: + result = b"" + chunksize = 5 + while True: + chunk = handles.handle.read(chunksize) + # Make sure each chunk is correct amount of bytes + assert len(chunk) <= chunksize + if len(chunk) < chunksize: + # Can be less amount of bytes, but only at EOF + # which happens when read returns empty + assert len(handles.handle.read()) == 0 + result += chunk + break + result += chunk + assert result == data.encode("utf-8") + + # Test that pyarrow can handle a file opened with get_handle + def test_get_handle_pyarrow_compat(self): + pa_csv = pytest.importorskip("pyarrow.csv") + + # Test latin1, ucs-2, and ucs-4 chars + data = """a,b,c +1,2,3 +©,®,® +Look,a snake,🐍""" + expected = pd.DataFrame( + {"a": ["1", "©", "Look"], "b": ["2", "®", "a snake"], "c": ["3", "®", "🐍"]} + ) + s = StringIO(data) + with icom.get_handle(s, "rb", is_text=False) as handles: + df = pa_csv.read_csv(handles.handle).to_pandas() + tm.assert_frame_equal(df, expected) + assert not s.closed + + def test_iterator(self): + with pd.read_csv(StringIO(self.data1), chunksize=1) as reader: + result = pd.concat(reader, ignore_index=True) + expected = pd.read_csv(StringIO(self.data1)) + tm.assert_frame_equal(result, expected) + + # GH12153 + with pd.read_csv(StringIO(self.data1), chunksize=1) as it: + first = next(it) + tm.assert_frame_equal(first, expected.iloc[[0]]) + tm.assert_frame_equal(pd.concat(it), expected.iloc[1:]) + + @pytest.mark.parametrize( + "reader, module, error_class, fn_ext", + [ + (pd.read_csv, "os", FileNotFoundError, "csv"), + (pd.read_fwf, "os", FileNotFoundError, "txt"), + (pd.read_excel, "xlrd", FileNotFoundError, "xlsx"), + (pd.read_feather, "pyarrow", OSError, "feather"), + (pd.read_hdf, "tables", FileNotFoundError, "h5"), + (pd.read_stata, "os", FileNotFoundError, "dta"), + (pd.read_sas, "os", FileNotFoundError, "sas7bdat"), + (pd.read_json, "os", FileNotFoundError, "json"), + (pd.read_pickle, "os", FileNotFoundError, "pickle"), + ], + ) + def test_read_non_existent(self, reader, module, error_class, fn_ext): + pytest.importorskip(module) + + path = os.path.join(HERE, "data", "does_not_exist." + fn_ext) + msg1 = rf"File (b')?.+does_not_exist\.{fn_ext}'? does not exist" + msg2 = rf"\[Errno 2\] No such file or directory: '.+does_not_exist\.{fn_ext}'" + msg3 = "Expected object or value" + msg4 = "path_or_buf needs to be a string file path or file-like" + msg5 = ( + rf"\[Errno 2\] File .+does_not_exist\.{fn_ext} does not exist: " + rf"'.+does_not_exist\.{fn_ext}'" + ) + msg6 = rf"\[Errno 2\] 没有那个文件或目录: '.+does_not_exist\.{fn_ext}'" + msg7 = ( + rf"\[Errno 2\] File o directory non esistente: '.+does_not_exist\.{fn_ext}'" + ) + msg8 = rf"Failed to open local file.+does_not_exist\.{fn_ext}" + + with pytest.raises( + error_class, + match=rf"({msg1}|{msg2}|{msg3}|{msg4}|{msg5}|{msg6}|{msg7}|{msg8})", + ): + reader(path) + + @pytest.mark.parametrize( + "method, module, error_class, fn_ext", + [ + (pd.DataFrame.to_csv, "os", OSError, "csv"), + (pd.DataFrame.to_html, "os", OSError, "html"), + (pd.DataFrame.to_excel, "xlrd", OSError, "xlsx"), + (pd.DataFrame.to_feather, "pyarrow", OSError, "feather"), + (pd.DataFrame.to_parquet, "pyarrow", OSError, "parquet"), + (pd.DataFrame.to_stata, "os", OSError, "dta"), + (pd.DataFrame.to_json, "os", OSError, "json"), + (pd.DataFrame.to_pickle, "os", OSError, "pickle"), + ], + ) + # NOTE: Missing parent directory for pd.DataFrame.to_hdf is handled by PyTables + def test_write_missing_parent_directory(self, method, module, error_class, fn_ext): + pytest.importorskip(module) + + dummy_frame = pd.DataFrame({"a": [1, 2, 3], "b": [2, 3, 4], "c": [3, 4, 5]}) + + path = os.path.join(HERE, "data", "missing_folder", "does_not_exist." + fn_ext) + + with pytest.raises( + error_class, + match=r"Cannot save file into a non-existent directory: .*missing_folder", + ): + method(dummy_frame, path) + + @pytest.mark.parametrize( + "reader, module, error_class, fn_ext", + [ + (pd.read_csv, "os", FileNotFoundError, "csv"), + (pd.read_table, "os", FileNotFoundError, "csv"), + (pd.read_fwf, "os", FileNotFoundError, "txt"), + (pd.read_excel, "xlrd", FileNotFoundError, "xlsx"), + (pd.read_feather, "pyarrow", OSError, "feather"), + (pd.read_hdf, "tables", FileNotFoundError, "h5"), + (pd.read_stata, "os", FileNotFoundError, "dta"), + (pd.read_sas, "os", FileNotFoundError, "sas7bdat"), + (pd.read_json, "os", FileNotFoundError, "json"), + (pd.read_pickle, "os", FileNotFoundError, "pickle"), + ], + ) + def test_read_expands_user_home_dir( + self, reader, module, error_class, fn_ext, monkeypatch + ): + pytest.importorskip(module) + + path = os.path.join("~", "does_not_exist." + fn_ext) + monkeypatch.setattr(icom, "_expand_user", lambda x: os.path.join("foo", x)) + + msg1 = rf"File (b')?.+does_not_exist\.{fn_ext}'? does not exist" + msg2 = rf"\[Errno 2\] No such file or directory: '.+does_not_exist\.{fn_ext}'" + msg3 = "Unexpected character found when decoding 'false'" + msg4 = "path_or_buf needs to be a string file path or file-like" + msg5 = ( + rf"\[Errno 2\] File .+does_not_exist\.{fn_ext} does not exist: " + rf"'.+does_not_exist\.{fn_ext}'" + ) + msg6 = rf"\[Errno 2\] 没有那个文件或目录: '.+does_not_exist\.{fn_ext}'" + msg7 = ( + rf"\[Errno 2\] File o directory non esistente: '.+does_not_exist\.{fn_ext}'" + ) + msg8 = rf"Failed to open local file.+does_not_exist\.{fn_ext}" + + with pytest.raises( + error_class, + match=rf"({msg1}|{msg2}|{msg3}|{msg4}|{msg5}|{msg6}|{msg7}|{msg8})", + ): + reader(path) + + @pytest.mark.parametrize( + "reader, module, path", + [ + (pd.read_csv, "os", ("io", "data", "csv", "iris.csv")), + (pd.read_table, "os", ("io", "data", "csv", "iris.csv")), + ( + pd.read_fwf, + "os", + ("io", "data", "fixed_width", "fixed_width_format.txt"), + ), + (pd.read_excel, "xlrd", ("io", "data", "excel", "test1.xlsx")), + ( + pd.read_feather, + "pyarrow", + ("io", "data", "feather", "feather-0_3_1.feather"), + ), + ( + pd.read_hdf, + "tables", + ("io", "data", "legacy_hdf", "datetimetz_object.h5"), + ), + (pd.read_stata, "os", ("io", "data", "stata", "stata10_115.dta")), + (pd.read_sas, "os", ("io", "sas", "data", "test1.sas7bdat")), + (pd.read_json, "os", ("io", "json", "data", "tsframe_v012.json")), + ( + pd.read_pickle, + "os", + ("io", "data", "pickle", "categorical.0.25.0.pickle"), + ), + ], + ) + def test_read_fspath_all(self, reader, module, path, datapath): + pytest.importorskip(module) + path = datapath(*path) + + mypath = CustomFSPath(path) + result = reader(mypath) + expected = reader(path) + + if path.endswith(".pickle"): + # categorical + tm.assert_categorical_equal(result, expected) + else: + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "writer_name, writer_kwargs, module", + [ + ("to_csv", {}, "os"), + ("to_excel", {"engine": "openpyxl"}, "openpyxl"), + ("to_feather", {}, "pyarrow"), + ("to_html", {}, "os"), + ("to_json", {}, "os"), + ("to_latex", {}, "os"), + ("to_pickle", {}, "os"), + ("to_stata", {"time_stamp": pd.to_datetime("2019-01-01 00:00")}, "os"), + ], + ) + def test_write_fspath_all(self, writer_name, writer_kwargs, module): + if writer_name in ["to_latex"]: # uses Styler implementation + pytest.importorskip("jinja2") + p1 = tm.ensure_clean("string") + p2 = tm.ensure_clean("fspath") + df = pd.DataFrame({"A": [1, 2]}) + + with p1 as string, p2 as fspath: + pytest.importorskip(module) + mypath = CustomFSPath(fspath) + writer = getattr(df, writer_name) + + writer(string, **writer_kwargs) + writer(mypath, **writer_kwargs) + with open(string, "rb") as f_str, open(fspath, "rb") as f_path: + if writer_name == "to_excel": + # binary representation of excel contains time creation + # data that causes flaky CI failures + result = pd.read_excel(f_str, **writer_kwargs) + expected = pd.read_excel(f_path, **writer_kwargs) + tm.assert_frame_equal(result, expected) + else: + result = f_str.read() + expected = f_path.read() + assert result == expected + + def test_write_fspath_hdf5(self): + # Same test as write_fspath_all, except HDF5 files aren't + # necessarily byte-for-byte identical for a given dataframe, so we'll + # have to read and compare equality + pytest.importorskip("tables") + + df = pd.DataFrame({"A": [1, 2]}) + p1 = tm.ensure_clean("string") + p2 = tm.ensure_clean("fspath") + + with p1 as string, p2 as fspath: + mypath = CustomFSPath(fspath) + df.to_hdf(mypath, key="bar") + df.to_hdf(string, key="bar") + + result = pd.read_hdf(fspath, key="bar") + expected = pd.read_hdf(string, key="bar") + + tm.assert_frame_equal(result, expected) + + +@pytest.fixture +def mmap_file(datapath): + return datapath("io", "data", "csv", "test_mmap.csv") + + +class TestMMapWrapper: + def test_constructor_bad_file(self, mmap_file): + non_file = StringIO("I am not a file") + non_file.fileno = lambda: -1 + + # the error raised is different on Windows + if is_platform_windows(): + msg = "The parameter is incorrect" + err = OSError + else: + msg = "[Errno 22]" + err = mmap.error + + with pytest.raises(err, match=msg): + icom._maybe_memory_map(non_file, True) + + with open(mmap_file, encoding="utf-8") as target: + pass + + msg = "I/O operation on closed file" + with pytest.raises(ValueError, match=msg): + icom._maybe_memory_map(target, True) + + def test_next(self, mmap_file): + with open(mmap_file, encoding="utf-8") as target: + lines = target.readlines() + + with icom.get_handle( + target, "r", is_text=True, memory_map=True + ) as wrappers: + wrapper = wrappers.handle + assert isinstance(wrapper.buffer.buffer, mmap.mmap) + + for line in lines: + next_line = next(wrapper) + assert next_line.strip() == line.strip() + + with pytest.raises(StopIteration, match=r"^$"): + next(wrapper) + + def test_unknown_engine(self): + with tm.ensure_clean() as path: + df = pd.DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=pd.Index(list("ABCD"), dtype=object), + index=pd.Index([f"i-{i}" for i in range(30)], dtype=object), + ) + df.to_csv(path) + with pytest.raises(ValueError, match="Unknown engine"): + pd.read_csv(path, engine="pyt") + + def test_binary_mode(self): + """ + 'encoding' shouldn't be passed to 'open' in binary mode. + + GH 35058 + """ + with tm.ensure_clean() as path: + df = pd.DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=pd.Index(list("ABCD"), dtype=object), + index=pd.Index([f"i-{i}" for i in range(30)], dtype=object), + ) + df.to_csv(path, mode="w+b") + tm.assert_frame_equal(df, pd.read_csv(path, index_col=0)) + + @pytest.mark.parametrize("encoding", ["utf-16", "utf-32"]) + @pytest.mark.parametrize("compression_", ["bz2", "xz"]) + def test_warning_missing_utf_bom(self, encoding, compression_): + """ + bz2 and xz do not write the byte order mark (BOM) for utf-16/32. + + https://stackoverflow.com/questions/55171439 + + GH 35681 + """ + df = pd.DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=pd.Index(list("ABCD"), dtype=object), + index=pd.Index([f"i-{i}" for i in range(30)], dtype=object), + ) + with tm.ensure_clean() as path: + with tm.assert_produces_warning(UnicodeWarning): + df.to_csv(path, compression=compression_, encoding=encoding) + + # reading should fail (otherwise we wouldn't need the warning) + msg = ( + r"UTF-\d+ stream does not start with BOM|" + r"'utf-\d+' codec can't decode byte" + ) + with pytest.raises(UnicodeError, match=msg): + pd.read_csv(path, compression=compression_, encoding=encoding) + + +def test_is_fsspec_url(): + assert icom.is_fsspec_url("gcs://pandas/somethingelse.com") + assert icom.is_fsspec_url("gs://pandas/somethingelse.com") + # the following is the only remote URL that is handled without fsspec + assert not icom.is_fsspec_url("http://pandas/somethingelse.com") + assert not icom.is_fsspec_url("random:pandas/somethingelse.com") + assert not icom.is_fsspec_url("/local/path") + assert not icom.is_fsspec_url("relative/local/path") + # fsspec URL in string should not be recognized + assert not icom.is_fsspec_url("this is not fsspec://url") + assert not icom.is_fsspec_url("{'url': 'gs://pandas/somethingelse.com'}") + # accept everything that conforms to RFC 3986 schema + assert icom.is_fsspec_url("RFC-3986+compliant.spec://something") + + +@pytest.mark.parametrize("encoding", [None, "utf-8"]) +@pytest.mark.parametrize("format", ["csv", "json"]) +def test_codecs_encoding(encoding, format): + # GH39247 + expected = pd.DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=pd.Index(list("ABCD"), dtype=object), + index=pd.Index([f"i-{i}" for i in range(30)], dtype=object), + ) + with tm.ensure_clean() as path: + with codecs.open(path, mode="w", encoding=encoding) as handle: + getattr(expected, f"to_{format}")(handle) + with codecs.open(path, mode="r", encoding=encoding) as handle: + if format == "csv": + df = pd.read_csv(handle, index_col=0) + else: + df = pd.read_json(handle) + tm.assert_frame_equal(expected, df) + + +def test_codecs_get_writer_reader(): + # GH39247 + expected = pd.DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=pd.Index(list("ABCD"), dtype=object), + index=pd.Index([f"i-{i}" for i in range(30)], dtype=object), + ) + with tm.ensure_clean() as path: + with open(path, "wb") as handle: + with codecs.getwriter("utf-8")(handle) as encoded: + expected.to_csv(encoded) + with open(path, "rb") as handle: + with codecs.getreader("utf-8")(handle) as encoded: + df = pd.read_csv(encoded, index_col=0) + tm.assert_frame_equal(expected, df) + + +@pytest.mark.parametrize( + "io_class,mode,msg", + [ + (BytesIO, "t", "a bytes-like object is required, not 'str'"), + (StringIO, "b", "string argument expected, got 'bytes'"), + ], +) +def test_explicit_encoding(io_class, mode, msg): + # GH39247; this test makes sure that if a user provides mode="*t" or "*b", + # it is used. In the case of this test it leads to an error as intentionally the + # wrong mode is requested + expected = pd.DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=pd.Index(list("ABCD"), dtype=object), + index=pd.Index([f"i-{i}" for i in range(30)], dtype=object), + ) + with io_class() as buffer: + with pytest.raises(TypeError, match=msg): + expected.to_csv(buffer, mode=f"w{mode}") + + +@pytest.mark.parametrize("encoding_errors", [None, "strict", "replace"]) +@pytest.mark.parametrize("format", ["csv", "json"]) +def test_encoding_errors(encoding_errors, format): + # GH39450 + msg = "'utf-8' codec can't decode byte" + bad_encoding = b"\xe4" + + if format == "csv": + content = b"," + bad_encoding + b"\n" + bad_encoding * 2 + b"," + bad_encoding + reader = partial(pd.read_csv, index_col=0) + else: + content = ( + b'{"' + + bad_encoding * 2 + + b'": {"' + + bad_encoding + + b'":"' + + bad_encoding + + b'"}}' + ) + reader = partial(pd.read_json, orient="index") + with tm.ensure_clean() as path: + file = Path(path) + file.write_bytes(content) + + if encoding_errors != "replace": + with pytest.raises(UnicodeDecodeError, match=msg): + reader(path, encoding_errors=encoding_errors) + else: + df = reader(path, encoding_errors=encoding_errors) + decoded = bad_encoding.decode(errors=encoding_errors) + expected = pd.DataFrame({decoded: [decoded]}, index=[decoded * 2]) + tm.assert_frame_equal(df, expected) + + +def test_bad_encdoing_errors(): + # GH 39777 + with tm.ensure_clean() as path: + with pytest.raises(LookupError, match="unknown error handler name"): + icom.get_handle(path, "w", errors="bad") + + +def test_errno_attribute(): + # GH 13872 + with pytest.raises(FileNotFoundError, match="\\[Errno 2\\]") as err: + pd.read_csv("doesnt_exist") + assert err.errno == errno.ENOENT + + +def test_fail_mmap(): + with pytest.raises(UnsupportedOperation, match="fileno"): + with BytesIO() as buffer: + icom.get_handle(buffer, "rb", memory_map=True) + + +def test_close_on_error(): + # GH 47136 + class TestError: + def close(self): + raise OSError("test") + + with pytest.raises(OSError, match="test"): + with BytesIO() as buffer: + with icom.get_handle(buffer, "rb") as handles: + handles.created_handles.append(TestError()) + + +@pytest.mark.parametrize( + "reader", + [ + pd.read_csv, + pd.read_fwf, + pd.read_excel, + pd.read_feather, + pd.read_hdf, + pd.read_stata, + pd.read_sas, + pd.read_json, + pd.read_pickle, + ], +) +def test_pickle_reader(reader): + # GH 22265 + with BytesIO() as buffer: + pickle.dump(reader, buffer) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_compression.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_compression.py new file mode 100644 index 0000000000000000000000000000000000000000..3a58dda9e8dc47f2072e0175c120036523ed83f7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_compression.py @@ -0,0 +1,378 @@ +import gzip +import io +import os +from pathlib import Path +import subprocess +import sys +import tarfile +import textwrap +import time +import zipfile + +import numpy as np +import pytest + +from pandas.compat import is_platform_windows + +import pandas as pd +import pandas._testing as tm + +import pandas.io.common as icom + + +@pytest.mark.parametrize( + "obj", + [ + pd.DataFrame( + 100 * [[0.123456, 0.234567, 0.567567], [12.32112, 123123.2, 321321.2]], + columns=["X", "Y", "Z"], + ), + pd.Series(100 * [0.123456, 0.234567, 0.567567], name="X"), + ], +) +@pytest.mark.parametrize("method", ["to_pickle", "to_json", "to_csv"]) +def test_compression_size(obj, method, compression_only): + if compression_only == "tar": + compression_only = {"method": "tar", "mode": "w:gz"} + + with tm.ensure_clean() as path: + getattr(obj, method)(path, compression=compression_only) + compressed_size = os.path.getsize(path) + getattr(obj, method)(path, compression=None) + uncompressed_size = os.path.getsize(path) + assert uncompressed_size > compressed_size + + +@pytest.mark.parametrize( + "obj", + [ + pd.DataFrame( + 100 * [[0.123456, 0.234567, 0.567567], [12.32112, 123123.2, 321321.2]], + columns=["X", "Y", "Z"], + ), + pd.Series(100 * [0.123456, 0.234567, 0.567567], name="X"), + ], +) +@pytest.mark.parametrize("method", ["to_csv", "to_json"]) +def test_compression_size_fh(obj, method, compression_only): + with tm.ensure_clean() as path: + with icom.get_handle( + path, + "w:gz" if compression_only == "tar" else "w", + compression=compression_only, + ) as handles: + getattr(obj, method)(handles.handle) + assert not handles.handle.closed + compressed_size = os.path.getsize(path) + with tm.ensure_clean() as path: + with icom.get_handle(path, "w", compression=None) as handles: + getattr(obj, method)(handles.handle) + assert not handles.handle.closed + uncompressed_size = os.path.getsize(path) + assert uncompressed_size > compressed_size + + +@pytest.mark.parametrize( + "write_method, write_kwargs, read_method", + [ + ("to_csv", {"index": False}, pd.read_csv), + ("to_json", {}, pd.read_json), + ("to_pickle", {}, pd.read_pickle), + ], +) +def test_dataframe_compression_defaults_to_infer( + write_method, write_kwargs, read_method, compression_only, compression_to_extension +): + # GH22004 + input = pd.DataFrame([[1.0, 0, -4], [3.4, 5, 2]], columns=["X", "Y", "Z"]) + extension = compression_to_extension[compression_only] + with tm.ensure_clean("compressed" + extension) as path: + getattr(input, write_method)(path, **write_kwargs) + output = read_method(path, compression=compression_only) + tm.assert_frame_equal(output, input) + + +@pytest.mark.parametrize( + "write_method,write_kwargs,read_method,read_kwargs", + [ + ("to_csv", {"index": False, "header": True}, pd.read_csv, {"squeeze": True}), + ("to_json", {}, pd.read_json, {"typ": "series"}), + ("to_pickle", {}, pd.read_pickle, {}), + ], +) +def test_series_compression_defaults_to_infer( + write_method, + write_kwargs, + read_method, + read_kwargs, + compression_only, + compression_to_extension, +): + # GH22004 + input = pd.Series([0, 5, -2, 10], name="X") + extension = compression_to_extension[compression_only] + with tm.ensure_clean("compressed" + extension) as path: + getattr(input, write_method)(path, **write_kwargs) + if "squeeze" in read_kwargs: + kwargs = read_kwargs.copy() + del kwargs["squeeze"] + output = read_method(path, compression=compression_only, **kwargs).squeeze( + "columns" + ) + else: + output = read_method(path, compression=compression_only, **read_kwargs) + tm.assert_series_equal(output, input, check_names=False) + + +def test_compression_warning(compression_only): + # Assert that passing a file object to to_csv while explicitly specifying a + # compression protocol triggers a RuntimeWarning, as per GH21227. + df = pd.DataFrame( + 100 * [[0.123456, 0.234567, 0.567567], [12.32112, 123123.2, 321321.2]], + columns=["X", "Y", "Z"], + ) + with tm.ensure_clean() as path: + with icom.get_handle(path, "w", compression=compression_only) as handles: + with tm.assert_produces_warning(RuntimeWarning): + df.to_csv(handles.handle, compression=compression_only) + + +def test_compression_binary(compression_only): + """ + Binary file handles support compression. + + GH22555 + """ + df = pd.DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=pd.Index(list("ABCD"), dtype=object), + index=pd.Index([f"i-{i}" for i in range(30)], dtype=object), + ) + + # with a file + with tm.ensure_clean() as path: + with open(path, mode="wb") as file: + df.to_csv(file, mode="wb", compression=compression_only) + file.seek(0) # file shouldn't be closed + tm.assert_frame_equal( + df, pd.read_csv(path, index_col=0, compression=compression_only) + ) + + # with BytesIO + file = io.BytesIO() + df.to_csv(file, mode="wb", compression=compression_only) + file.seek(0) # file shouldn't be closed + tm.assert_frame_equal( + df, pd.read_csv(file, index_col=0, compression=compression_only) + ) + + +def test_gzip_reproducibility_file_name(): + """ + Gzip should create reproducible archives with mtime. + + Note: Archives created with different filenames will still be different! + + GH 28103 + """ + df = pd.DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=pd.Index(list("ABCD"), dtype=object), + index=pd.Index([f"i-{i}" for i in range(30)], dtype=object), + ) + compression_options = {"method": "gzip", "mtime": 1} + + # test for filename + with tm.ensure_clean() as path: + path = Path(path) + df.to_csv(path, compression=compression_options) + time.sleep(0.1) + output = path.read_bytes() + df.to_csv(path, compression=compression_options) + assert output == path.read_bytes() + + +def test_gzip_reproducibility_file_object(): + """ + Gzip should create reproducible archives with mtime. + + GH 28103 + """ + df = pd.DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=pd.Index(list("ABCD"), dtype=object), + index=pd.Index([f"i-{i}" for i in range(30)], dtype=object), + ) + compression_options = {"method": "gzip", "mtime": 1} + + # test for file object + buffer = io.BytesIO() + df.to_csv(buffer, compression=compression_options, mode="wb") + output = buffer.getvalue() + time.sleep(0.1) + buffer = io.BytesIO() + df.to_csv(buffer, compression=compression_options, mode="wb") + assert output == buffer.getvalue() + + +@pytest.mark.single_cpu +def test_with_missing_lzma(): + """Tests if import pandas works when lzma is not present.""" + # https://github.com/pandas-dev/pandas/issues/27575 + code = textwrap.dedent( + """\ + import sys + sys.modules['lzma'] = None + import pandas + """ + ) + subprocess.check_output([sys.executable, "-c", code], stderr=subprocess.PIPE) + + +@pytest.mark.single_cpu +def test_with_missing_lzma_runtime(): + """Tests if RuntimeError is hit when calling lzma without + having the module available. + """ + code = textwrap.dedent( + """ + import sys + import pytest + sys.modules['lzma'] = None + import pandas as pd + df = pd.DataFrame() + with pytest.raises(RuntimeError, match='lzma module'): + df.to_csv('foo.csv', compression='xz') + """ + ) + subprocess.check_output([sys.executable, "-c", code], stderr=subprocess.PIPE) + + +@pytest.mark.parametrize( + "obj", + [ + pd.DataFrame( + 100 * [[0.123456, 0.234567, 0.567567], [12.32112, 123123.2, 321321.2]], + columns=["X", "Y", "Z"], + ), + pd.Series(100 * [0.123456, 0.234567, 0.567567], name="X"), + ], +) +@pytest.mark.parametrize("method", ["to_pickle", "to_json", "to_csv"]) +def test_gzip_compression_level(obj, method): + # GH33196 + with tm.ensure_clean() as path: + getattr(obj, method)(path, compression="gzip") + compressed_size_default = os.path.getsize(path) + getattr(obj, method)(path, compression={"method": "gzip", "compresslevel": 1}) + compressed_size_fast = os.path.getsize(path) + assert compressed_size_default < compressed_size_fast + + +@pytest.mark.parametrize( + "obj", + [ + pd.DataFrame( + 100 * [[0.123456, 0.234567, 0.567567], [12.32112, 123123.2, 321321.2]], + columns=["X", "Y", "Z"], + ), + pd.Series(100 * [0.123456, 0.234567, 0.567567], name="X"), + ], +) +@pytest.mark.parametrize("method", ["to_pickle", "to_json", "to_csv"]) +def test_xz_compression_level_read(obj, method): + with tm.ensure_clean() as path: + getattr(obj, method)(path, compression="xz") + compressed_size_default = os.path.getsize(path) + getattr(obj, method)(path, compression={"method": "xz", "preset": 1}) + compressed_size_fast = os.path.getsize(path) + assert compressed_size_default < compressed_size_fast + if method == "to_csv": + pd.read_csv(path, compression="xz") + + +@pytest.mark.parametrize( + "obj", + [ + pd.DataFrame( + 100 * [[0.123456, 0.234567, 0.567567], [12.32112, 123123.2, 321321.2]], + columns=["X", "Y", "Z"], + ), + pd.Series(100 * [0.123456, 0.234567, 0.567567], name="X"), + ], +) +@pytest.mark.parametrize("method", ["to_pickle", "to_json", "to_csv"]) +def test_bzip_compression_level(obj, method): + """GH33196 bzip needs file size > 100k to show a size difference between + compression levels, so here we just check if the call works when + compression is passed as a dict. + """ + with tm.ensure_clean() as path: + getattr(obj, method)(path, compression={"method": "bz2", "compresslevel": 1}) + + +@pytest.mark.parametrize( + "suffix,archive", + [ + (".zip", zipfile.ZipFile), + (".tar", tarfile.TarFile), + ], +) +def test_empty_archive_zip(suffix, archive): + with tm.ensure_clean(filename=suffix) as path: + with archive(path, "w"): + pass + with pytest.raises(ValueError, match="Zero files found"): + pd.read_csv(path) + + +def test_ambiguous_archive_zip(): + with tm.ensure_clean(filename=".zip") as path: + with zipfile.ZipFile(path, "w") as file: + file.writestr("a.csv", "foo,bar") + file.writestr("b.csv", "foo,bar") + with pytest.raises(ValueError, match="Multiple files found in ZIP file"): + pd.read_csv(path) + + +def test_ambiguous_archive_tar(tmp_path): + csvAPath = tmp_path / "a.csv" + with open(csvAPath, "w", encoding="utf-8") as a: + a.write("foo,bar\n") + csvBPath = tmp_path / "b.csv" + with open(csvBPath, "w", encoding="utf-8") as b: + b.write("foo,bar\n") + + tarpath = tmp_path / "archive.tar" + with tarfile.TarFile(tarpath, "w") as tar: + tar.add(csvAPath, "a.csv") + tar.add(csvBPath, "b.csv") + + with pytest.raises(ValueError, match="Multiple files found in TAR archive"): + pd.read_csv(tarpath) + + +def test_tar_gz_to_different_filename(): + with tm.ensure_clean(filename=".foo") as file: + pd.DataFrame( + [["1", "2"]], + columns=["foo", "bar"], + ).to_csv(file, compression={"method": "tar", "mode": "w:gz"}, index=False) + with gzip.open(file) as uncompressed: + with tarfile.TarFile(fileobj=uncompressed) as archive: + members = archive.getmembers() + assert len(members) == 1 + content = archive.extractfile(members[0]).read().decode("utf8") + + if is_platform_windows(): + expected = "foo,bar\r\n1,2\r\n" + else: + expected = "foo,bar\n1,2\n" + + assert content == expected + + +def test_tar_no_error_on_close(): + with io.BytesIO() as buffer: + with icom._BytesTarFile(fileobj=buffer, mode="w"): + pass diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_feather.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_feather.py new file mode 100644 index 0000000000000000000000000000000000000000..22a7d3b83a459a5dc48ee5d56c2f70130d644be4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_feather.py @@ -0,0 +1,252 @@ +""" test feather-format compat """ +import numpy as np +import pytest + +import pandas as pd +import pandas._testing as tm +from pandas.core.arrays import ( + ArrowStringArray, + StringArray, +) + +from pandas.io.feather_format import read_feather, to_feather # isort:skip + +pytestmark = pytest.mark.filterwarnings( + "ignore:Passing a BlockManager to DataFrame:DeprecationWarning" +) + +pa = pytest.importorskip("pyarrow") + + +@pytest.mark.single_cpu +class TestFeather: + def check_error_on_write(self, df, exc, err_msg): + # check that we are raising the exception + # on writing + + with pytest.raises(exc, match=err_msg): + with tm.ensure_clean() as path: + to_feather(df, path) + + def check_external_error_on_write(self, df): + # check that we are raising the exception + # on writing + + with tm.external_error_raised(Exception): + with tm.ensure_clean() as path: + to_feather(df, path) + + def check_round_trip(self, df, expected=None, write_kwargs={}, **read_kwargs): + if expected is None: + expected = df.copy() + + with tm.ensure_clean() as path: + to_feather(df, path, **write_kwargs) + + result = read_feather(path, **read_kwargs) + + tm.assert_frame_equal(result, expected) + + def test_error(self): + msg = "feather only support IO with DataFrames" + for obj in [ + pd.Series([1, 2, 3]), + 1, + "foo", + pd.Timestamp("20130101"), + np.array([1, 2, 3]), + ]: + self.check_error_on_write(obj, ValueError, msg) + + def test_basic(self): + df = pd.DataFrame( + { + "string": list("abc"), + "int": list(range(1, 4)), + "uint": np.arange(3, 6).astype("u1"), + "float": np.arange(4.0, 7.0, dtype="float64"), + "float_with_null": [1.0, np.nan, 3], + "bool": [True, False, True], + "bool_with_null": [True, np.nan, False], + "cat": pd.Categorical(list("abc")), + "dt": pd.DatetimeIndex( + list(pd.date_range("20130101", periods=3)), freq=None + ), + "dttz": pd.DatetimeIndex( + list(pd.date_range("20130101", periods=3, tz="US/Eastern")), + freq=None, + ), + "dt_with_null": [ + pd.Timestamp("20130101"), + pd.NaT, + pd.Timestamp("20130103"), + ], + "dtns": pd.DatetimeIndex( + list(pd.date_range("20130101", periods=3, freq="ns")), freq=None + ), + } + ) + df["periods"] = pd.period_range("2013", freq="M", periods=3) + df["timedeltas"] = pd.timedelta_range("1 day", periods=3) + df["intervals"] = pd.interval_range(0, 3, 3) + + assert df.dttz.dtype.tz.zone == "US/Eastern" + + expected = df.copy() + expected.loc[1, "bool_with_null"] = None + self.check_round_trip(df, expected=expected) + + def test_duplicate_columns(self): + # https://github.com/wesm/feather/issues/53 + # not currently able to handle duplicate columns + df = pd.DataFrame(np.arange(12).reshape(4, 3), columns=list("aaa")).copy() + self.check_external_error_on_write(df) + + def test_read_columns(self): + # GH 24025 + df = pd.DataFrame( + { + "col1": list("abc"), + "col2": list(range(1, 4)), + "col3": list("xyz"), + "col4": list(range(4, 7)), + } + ) + columns = ["col1", "col3"] + self.check_round_trip(df, expected=df[columns], columns=columns) + + def test_read_columns_different_order(self): + # GH 33878 + df = pd.DataFrame({"A": [1, 2], "B": ["x", "y"], "C": [True, False]}) + expected = df[["B", "A"]] + self.check_round_trip(df, expected, columns=["B", "A"]) + + def test_unsupported_other(self): + # mixed python objects + df = pd.DataFrame({"a": ["a", 1, 2.0]}) + self.check_external_error_on_write(df) + + def test_rw_use_threads(self): + df = pd.DataFrame({"A": np.arange(100000)}) + self.check_round_trip(df, use_threads=True) + self.check_round_trip(df, use_threads=False) + + def test_path_pathlib(self): + df = pd.DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=pd.Index(list("ABCD"), dtype=object), + index=pd.Index([f"i-{i}" for i in range(30)], dtype=object), + ).reset_index() + result = tm.round_trip_pathlib(df.to_feather, read_feather) + tm.assert_frame_equal(df, result) + + def test_path_localpath(self): + df = pd.DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=pd.Index(list("ABCD"), dtype=object), + index=pd.Index([f"i-{i}" for i in range(30)], dtype=object), + ).reset_index() + result = tm.round_trip_localpath(df.to_feather, read_feather) + tm.assert_frame_equal(df, result) + + def test_passthrough_keywords(self): + df = pd.DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=pd.Index(list("ABCD"), dtype=object), + index=pd.Index([f"i-{i}" for i in range(30)], dtype=object), + ).reset_index() + self.check_round_trip(df, write_kwargs={"version": 1}) + + @pytest.mark.network + @pytest.mark.single_cpu + def test_http_path(self, feather_file, httpserver): + # GH 29055 + expected = read_feather(feather_file) + with open(feather_file, "rb") as f: + httpserver.serve_content(content=f.read()) + res = read_feather(httpserver.url) + tm.assert_frame_equal(expected, res) + + def test_read_feather_dtype_backend(self, string_storage, dtype_backend): + # GH#50765 + df = pd.DataFrame( + { + "a": pd.Series([1, np.nan, 3], dtype="Int64"), + "b": pd.Series([1, 2, 3], dtype="Int64"), + "c": pd.Series([1.5, np.nan, 2.5], dtype="Float64"), + "d": pd.Series([1.5, 2.0, 2.5], dtype="Float64"), + "e": [True, False, None], + "f": [True, False, True], + "g": ["a", "b", "c"], + "h": ["a", "b", None], + } + ) + + if string_storage == "python": + string_array = StringArray(np.array(["a", "b", "c"], dtype=np.object_)) + string_array_na = StringArray(np.array(["a", "b", pd.NA], dtype=np.object_)) + + elif dtype_backend == "pyarrow": + from pandas.arrays import ArrowExtensionArray + + string_array = ArrowExtensionArray(pa.array(["a", "b", "c"])) + string_array_na = ArrowExtensionArray(pa.array(["a", "b", None])) + + else: + string_array = ArrowStringArray(pa.array(["a", "b", "c"])) + string_array_na = ArrowStringArray(pa.array(["a", "b", None])) + + with tm.ensure_clean() as path: + to_feather(df, path) + with pd.option_context("mode.string_storage", string_storage): + result = read_feather(path, dtype_backend=dtype_backend) + + expected = pd.DataFrame( + { + "a": pd.Series([1, np.nan, 3], dtype="Int64"), + "b": pd.Series([1, 2, 3], dtype="Int64"), + "c": pd.Series([1.5, np.nan, 2.5], dtype="Float64"), + "d": pd.Series([1.5, 2.0, 2.5], dtype="Float64"), + "e": pd.Series([True, False, pd.NA], dtype="boolean"), + "f": pd.Series([True, False, True], dtype="boolean"), + "g": string_array, + "h": string_array_na, + } + ) + + if dtype_backend == "pyarrow": + from pandas.arrays import ArrowExtensionArray + + expected = pd.DataFrame( + { + col: ArrowExtensionArray(pa.array(expected[col], from_pandas=True)) + for col in expected.columns + } + ) + + tm.assert_frame_equal(result, expected) + + def test_int_columns_and_index(self): + df = pd.DataFrame({"a": [1, 2, 3]}, index=pd.Index([3, 4, 5], name="test")) + self.check_round_trip(df) + + def test_invalid_dtype_backend(self): + msg = ( + "dtype_backend numpy is invalid, only 'numpy_nullable' and " + "'pyarrow' are allowed." + ) + df = pd.DataFrame({"int": list(range(1, 4))}) + with tm.ensure_clean("tmp.feather") as path: + df.to_feather(path) + with pytest.raises(ValueError, match=msg): + read_feather(path, dtype_backend="numpy") + + def test_string_inference(self, tmp_path): + # GH#54431 + path = tmp_path / "test_string_inference.p" + df = pd.DataFrame(data={"a": ["x", "y"]}) + df.to_feather(path) + with pd.option_context("future.infer_string", True): + result = read_feather(path) + expected = pd.DataFrame(data={"a": ["x", "y"]}, dtype="string[pyarrow_numpy]") + tm.assert_frame_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_fsspec.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_fsspec.py new file mode 100644 index 0000000000000000000000000000000000000000..a1dec8a2d05b4fc4c39cbc910544532bf4eb0cca --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_fsspec.py @@ -0,0 +1,345 @@ +import io + +import numpy as np +import pytest + +from pandas import ( + DataFrame, + date_range, + read_csv, + read_excel, + read_feather, + read_json, + read_parquet, + read_pickle, + read_stata, + read_table, +) +import pandas._testing as tm +from pandas.util import _test_decorators as td + +pytestmark = pytest.mark.filterwarnings( + "ignore:Passing a BlockManager to DataFrame:DeprecationWarning" +) + + +@pytest.fixture +def fsspectest(): + pytest.importorskip("fsspec") + from fsspec import register_implementation + from fsspec.implementations.memory import MemoryFileSystem + from fsspec.registry import _registry as registry + + class TestMemoryFS(MemoryFileSystem): + protocol = "testmem" + test = [None] + + def __init__(self, **kwargs) -> None: + self.test[0] = kwargs.pop("test", None) + super().__init__(**kwargs) + + register_implementation("testmem", TestMemoryFS, clobber=True) + yield TestMemoryFS() + registry.pop("testmem", None) + TestMemoryFS.test[0] = None + TestMemoryFS.store.clear() + + +@pytest.fixture +def df1(): + return DataFrame( + { + "int": [1, 3], + "float": [2.0, np.nan], + "str": ["t", "s"], + "dt": date_range("2018-06-18", periods=2), + } + ) + + +@pytest.fixture +def cleared_fs(): + fsspec = pytest.importorskip("fsspec") + + memfs = fsspec.filesystem("memory") + yield memfs + memfs.store.clear() + + +def test_read_csv(cleared_fs, df1): + text = str(df1.to_csv(index=False)).encode() + with cleared_fs.open("test/test.csv", "wb") as w: + w.write(text) + df2 = read_csv("memory://test/test.csv", parse_dates=["dt"]) + + tm.assert_frame_equal(df1, df2) + + +def test_reasonable_error(monkeypatch, cleared_fs): + from fsspec.registry import known_implementations + + with pytest.raises(ValueError, match="nosuchprotocol"): + read_csv("nosuchprotocol://test/test.csv") + err_msg = "test error message" + monkeypatch.setitem( + known_implementations, + "couldexist", + {"class": "unimportable.CouldExist", "err": err_msg}, + ) + with pytest.raises(ImportError, match=err_msg): + read_csv("couldexist://test/test.csv") + + +def test_to_csv(cleared_fs, df1): + df1.to_csv("memory://test/test.csv", index=True) + + df2 = read_csv("memory://test/test.csv", parse_dates=["dt"], index_col=0) + + tm.assert_frame_equal(df1, df2) + + +def test_to_excel(cleared_fs, df1): + pytest.importorskip("openpyxl") + ext = "xlsx" + path = f"memory://test/test.{ext}" + df1.to_excel(path, index=True) + + df2 = read_excel(path, parse_dates=["dt"], index_col=0) + + tm.assert_frame_equal(df1, df2) + + +@pytest.mark.parametrize("binary_mode", [False, True]) +def test_to_csv_fsspec_object(cleared_fs, binary_mode, df1): + fsspec = pytest.importorskip("fsspec") + + path = "memory://test/test.csv" + mode = "wb" if binary_mode else "w" + with fsspec.open(path, mode=mode).open() as fsspec_object: + df1.to_csv(fsspec_object, index=True) + assert not fsspec_object.closed + + mode = mode.replace("w", "r") + with fsspec.open(path, mode=mode) as fsspec_object: + df2 = read_csv( + fsspec_object, + parse_dates=["dt"], + index_col=0, + ) + assert not fsspec_object.closed + + tm.assert_frame_equal(df1, df2) + + +def test_csv_options(fsspectest): + df = DataFrame({"a": [0]}) + df.to_csv( + "testmem://test/test.csv", storage_options={"test": "csv_write"}, index=False + ) + assert fsspectest.test[0] == "csv_write" + read_csv("testmem://test/test.csv", storage_options={"test": "csv_read"}) + assert fsspectest.test[0] == "csv_read" + + +def test_read_table_options(fsspectest): + # GH #39167 + df = DataFrame({"a": [0]}) + df.to_csv( + "testmem://test/test.csv", storage_options={"test": "csv_write"}, index=False + ) + assert fsspectest.test[0] == "csv_write" + read_table("testmem://test/test.csv", storage_options={"test": "csv_read"}) + assert fsspectest.test[0] == "csv_read" + + +def test_excel_options(fsspectest): + pytest.importorskip("openpyxl") + extension = "xlsx" + + df = DataFrame({"a": [0]}) + + path = f"testmem://test/test.{extension}" + + df.to_excel(path, storage_options={"test": "write"}, index=False) + assert fsspectest.test[0] == "write" + read_excel(path, storage_options={"test": "read"}) + assert fsspectest.test[0] == "read" + + +def test_to_parquet_new_file(cleared_fs, df1): + """Regression test for writing to a not-yet-existent GCS Parquet file.""" + pytest.importorskip("fastparquet") + + df1.to_parquet( + "memory://test/test.csv", index=True, engine="fastparquet", compression=None + ) + + +def test_arrowparquet_options(fsspectest): + """Regression test for writing to a not-yet-existent GCS Parquet file.""" + pytest.importorskip("pyarrow") + df = DataFrame({"a": [0]}) + df.to_parquet( + "testmem://test/test.csv", + engine="pyarrow", + compression=None, + storage_options={"test": "parquet_write"}, + ) + assert fsspectest.test[0] == "parquet_write" + read_parquet( + "testmem://test/test.csv", + engine="pyarrow", + storage_options={"test": "parquet_read"}, + ) + assert fsspectest.test[0] == "parquet_read" + + +@td.skip_array_manager_not_yet_implemented # TODO(ArrayManager) fastparquet +def test_fastparquet_options(fsspectest): + """Regression test for writing to a not-yet-existent GCS Parquet file.""" + pytest.importorskip("fastparquet") + + df = DataFrame({"a": [0]}) + df.to_parquet( + "testmem://test/test.csv", + engine="fastparquet", + compression=None, + storage_options={"test": "parquet_write"}, + ) + assert fsspectest.test[0] == "parquet_write" + read_parquet( + "testmem://test/test.csv", + engine="fastparquet", + storage_options={"test": "parquet_read"}, + ) + assert fsspectest.test[0] == "parquet_read" + + +@pytest.mark.single_cpu +def test_from_s3_csv(s3_public_bucket_with_data, tips_file, s3so): + pytest.importorskip("s3fs") + tm.assert_equal( + read_csv( + f"s3://{s3_public_bucket_with_data.name}/tips.csv", storage_options=s3so + ), + read_csv(tips_file), + ) + # the following are decompressed by pandas, not fsspec + tm.assert_equal( + read_csv( + f"s3://{s3_public_bucket_with_data.name}/tips.csv.gz", storage_options=s3so + ), + read_csv(tips_file), + ) + tm.assert_equal( + read_csv( + f"s3://{s3_public_bucket_with_data.name}/tips.csv.bz2", storage_options=s3so + ), + read_csv(tips_file), + ) + + +@pytest.mark.single_cpu +@pytest.mark.parametrize("protocol", ["s3", "s3a", "s3n"]) +def test_s3_protocols(s3_public_bucket_with_data, tips_file, protocol, s3so): + pytest.importorskip("s3fs") + tm.assert_equal( + read_csv( + f"{protocol}://{s3_public_bucket_with_data.name}/tips.csv", + storage_options=s3so, + ), + read_csv(tips_file), + ) + + +@pytest.mark.single_cpu +@td.skip_array_manager_not_yet_implemented # TODO(ArrayManager) fastparquet +def test_s3_parquet(s3_public_bucket, s3so, df1): + pytest.importorskip("fastparquet") + pytest.importorskip("s3fs") + + fn = f"s3://{s3_public_bucket.name}/test.parquet" + df1.to_parquet( + fn, index=False, engine="fastparquet", compression=None, storage_options=s3so + ) + df2 = read_parquet(fn, engine="fastparquet", storage_options=s3so) + tm.assert_equal(df1, df2) + + +@td.skip_if_installed("fsspec") +def test_not_present_exception(): + msg = "Missing optional dependency 'fsspec'|fsspec library is required" + with pytest.raises(ImportError, match=msg): + read_csv("memory://test/test.csv") + + +def test_feather_options(fsspectest): + pytest.importorskip("pyarrow") + df = DataFrame({"a": [0]}) + df.to_feather("testmem://mockfile", storage_options={"test": "feather_write"}) + assert fsspectest.test[0] == "feather_write" + out = read_feather("testmem://mockfile", storage_options={"test": "feather_read"}) + assert fsspectest.test[0] == "feather_read" + tm.assert_frame_equal(df, out) + + +def test_pickle_options(fsspectest): + df = DataFrame({"a": [0]}) + df.to_pickle("testmem://mockfile", storage_options={"test": "pickle_write"}) + assert fsspectest.test[0] == "pickle_write" + out = read_pickle("testmem://mockfile", storage_options={"test": "pickle_read"}) + assert fsspectest.test[0] == "pickle_read" + tm.assert_frame_equal(df, out) + + +def test_json_options(fsspectest, compression): + df = DataFrame({"a": [0]}) + df.to_json( + "testmem://mockfile", + compression=compression, + storage_options={"test": "json_write"}, + ) + assert fsspectest.test[0] == "json_write" + out = read_json( + "testmem://mockfile", + compression=compression, + storage_options={"test": "json_read"}, + ) + assert fsspectest.test[0] == "json_read" + tm.assert_frame_equal(df, out) + + +def test_stata_options(fsspectest): + df = DataFrame({"a": [0]}) + df.to_stata( + "testmem://mockfile", storage_options={"test": "stata_write"}, write_index=False + ) + assert fsspectest.test[0] == "stata_write" + out = read_stata("testmem://mockfile", storage_options={"test": "stata_read"}) + assert fsspectest.test[0] == "stata_read" + tm.assert_frame_equal(df, out.astype("int64")) + + +def test_markdown_options(fsspectest): + pytest.importorskip("tabulate") + df = DataFrame({"a": [0]}) + df.to_markdown("testmem://mockfile", storage_options={"test": "md_write"}) + assert fsspectest.test[0] == "md_write" + assert fsspectest.cat("testmem://mockfile") + + +def test_non_fsspec_options(): + pytest.importorskip("pyarrow") + with pytest.raises(ValueError, match="storage_options"): + read_csv("localfile", storage_options={"a": True}) + with pytest.raises(ValueError, match="storage_options"): + # separate test for parquet, which has a different code path + read_parquet("localfile", storage_options={"a": True}) + by = io.BytesIO() + + with pytest.raises(ValueError, match="storage_options"): + read_csv(by, storage_options={"a": True}) + + df = DataFrame({"a": [0]}) + with pytest.raises(ValueError, match="storage_options"): + df.to_parquet("nonfsspecpath", storage_options={"a": True}) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_gbq.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_gbq.py new file mode 100644 index 0000000000000000000000000000000000000000..b2b212ceb2c41c9a8fa0828b691b6161db02d62f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_gbq.py @@ -0,0 +1,14 @@ +import pandas as pd +import pandas._testing as tm + + +def test_read_gbq_deprecated(): + with tm.assert_produces_warning(FutureWarning): + with tm.external_error_raised(Exception): + pd.read_gbq("fake") + + +def test_to_gbq_deprecated(): + with tm.assert_produces_warning(FutureWarning): + with tm.external_error_raised(Exception): + pd.DataFrame(range(1)).to_gbq("fake") diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_gcs.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_gcs.py new file mode 100644 index 0000000000000000000000000000000000000000..4b337b5b82052f76580da8de1b64b70fd06f4617 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_gcs.py @@ -0,0 +1,228 @@ +from io import BytesIO +import os +import pathlib +import tarfile +import zipfile + +import numpy as np +import pytest + +from pandas.compat.pyarrow import pa_version_under17p0 + +from pandas import ( + DataFrame, + Index, + date_range, + read_csv, + read_excel, + read_json, + read_parquet, +) +import pandas._testing as tm +from pandas.util import _test_decorators as td + +pytestmark = pytest.mark.filterwarnings( + "ignore:Passing a BlockManager to DataFrame:DeprecationWarning" +) + + +@pytest.fixture +def gcs_buffer(): + """Emulate GCS using a binary buffer.""" + pytest.importorskip("gcsfs") + fsspec = pytest.importorskip("fsspec") + + gcs_buffer = BytesIO() + gcs_buffer.close = lambda: True + + class MockGCSFileSystem(fsspec.AbstractFileSystem): + @staticmethod + def open(*args, **kwargs): + gcs_buffer.seek(0) + return gcs_buffer + + def ls(self, path, **kwargs): + # needed for pyarrow + return [{"name": path, "type": "file"}] + + # Overwrites the default implementation from gcsfs to our mock class + fsspec.register_implementation("gs", MockGCSFileSystem, clobber=True) + + return gcs_buffer + + +# Patches pyarrow; other processes should not pick up change +@pytest.mark.single_cpu +@pytest.mark.parametrize("format", ["csv", "json", "parquet", "excel", "markdown"]) +def test_to_read_gcs(gcs_buffer, format, monkeypatch, capsys, request): + """ + Test that many to/read functions support GCS. + + GH 33987 + """ + + df1 = DataFrame( + { + "int": [1, 3], + "float": [2.0, np.nan], + "str": ["t", "s"], + "dt": date_range("2018-06-18", periods=2), + } + ) + + path = f"gs://test/test.{format}" + + if format == "csv": + df1.to_csv(path, index=True) + df2 = read_csv(path, parse_dates=["dt"], index_col=0) + elif format == "excel": + path = "gs://test/test.xlsx" + df1.to_excel(path) + df2 = read_excel(path, parse_dates=["dt"], index_col=0) + elif format == "json": + df1.to_json(path) + df2 = read_json(path, convert_dates=["dt"]) + elif format == "parquet": + pytest.importorskip("pyarrow") + pa_fs = pytest.importorskip("pyarrow.fs") + + class MockFileSystem(pa_fs.FileSystem): + @staticmethod + def from_uri(path): + print("Using pyarrow filesystem") + to_local = pathlib.Path(path.replace("gs://", "")).absolute().as_uri() + return pa_fs.LocalFileSystem(to_local) + + request.applymarker( + pytest.mark.xfail( + not pa_version_under17p0, + raises=TypeError, + reason="pyarrow 17 broke the mocked filesystem", + ) + ) + with monkeypatch.context() as m: + m.setattr(pa_fs, "FileSystem", MockFileSystem) + df1.to_parquet(path) + df2 = read_parquet(path) + captured = capsys.readouterr() + assert captured.out == "Using pyarrow filesystem\nUsing pyarrow filesystem\n" + elif format == "markdown": + pytest.importorskip("tabulate") + df1.to_markdown(path) + df2 = df1 + + tm.assert_frame_equal(df1, df2) + + +def assert_equal_zip_safe(result: bytes, expected: bytes, compression: str): + """ + For zip compression, only compare the CRC-32 checksum of the file contents + to avoid checking the time-dependent last-modified timestamp which + in some CI builds is off-by-one + + See https://en.wikipedia.org/wiki/ZIP_(file_format)#File_headers + """ + if compression == "zip": + # Only compare the CRC checksum of the file contents + with zipfile.ZipFile(BytesIO(result)) as exp, zipfile.ZipFile( + BytesIO(expected) + ) as res: + for res_info, exp_info in zip(res.infolist(), exp.infolist()): + assert res_info.CRC == exp_info.CRC + elif compression == "tar": + with tarfile.open(fileobj=BytesIO(result)) as tar_exp, tarfile.open( + fileobj=BytesIO(expected) + ) as tar_res: + for tar_res_info, tar_exp_info in zip( + tar_res.getmembers(), tar_exp.getmembers() + ): + actual_file = tar_res.extractfile(tar_res_info) + expected_file = tar_exp.extractfile(tar_exp_info) + assert (actual_file is None) == (expected_file is None) + if actual_file is not None and expected_file is not None: + assert actual_file.read() == expected_file.read() + else: + assert result == expected + + +@pytest.mark.parametrize("encoding", ["utf-8", "cp1251"]) +def test_to_csv_compression_encoding_gcs( + gcs_buffer, compression_only, encoding, compression_to_extension +): + """ + Compression and encoding should with GCS. + + GH 35677 (to_csv, compression), GH 26124 (to_csv, encoding), and + GH 32392 (read_csv, encoding) + """ + df = DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=Index(list("ABCD"), dtype=object), + index=Index([f"i-{i}" for i in range(30)], dtype=object), + ) + + # reference of compressed and encoded file + compression = {"method": compression_only} + if compression_only == "gzip": + compression["mtime"] = 1 # be reproducible + buffer = BytesIO() + df.to_csv(buffer, compression=compression, encoding=encoding, mode="wb") + + # write compressed file with explicit compression + path_gcs = "gs://test/test.csv" + df.to_csv(path_gcs, compression=compression, encoding=encoding) + res = gcs_buffer.getvalue() + expected = buffer.getvalue() + assert_equal_zip_safe(res, expected, compression_only) + + read_df = read_csv( + path_gcs, index_col=0, compression=compression_only, encoding=encoding + ) + tm.assert_frame_equal(df, read_df) + + # write compressed file with implicit compression + file_ext = compression_to_extension[compression_only] + compression["method"] = "infer" + path_gcs += f".{file_ext}" + df.to_csv(path_gcs, compression=compression, encoding=encoding) + + res = gcs_buffer.getvalue() + expected = buffer.getvalue() + assert_equal_zip_safe(res, expected, compression_only) + + read_df = read_csv(path_gcs, index_col=0, compression="infer", encoding=encoding) + tm.assert_frame_equal(df, read_df) + + +def test_to_parquet_gcs_new_file(monkeypatch, tmpdir): + """Regression test for writing to a not-yet-existent GCS Parquet file.""" + pytest.importorskip("fastparquet") + pytest.importorskip("gcsfs") + + from fsspec import AbstractFileSystem + + df1 = DataFrame( + { + "int": [1, 3], + "float": [2.0, np.nan], + "str": ["t", "s"], + "dt": date_range("2018-06-18", periods=2), + } + ) + + class MockGCSFileSystem(AbstractFileSystem): + def open(self, path, mode="r", *args): + if "w" not in mode: + raise FileNotFoundError + return open(os.path.join(tmpdir, "test.parquet"), mode, encoding="utf-8") + + monkeypatch.setattr("gcsfs.GCSFileSystem", MockGCSFileSystem) + df1.to_parquet( + "gs://test/test.csv", index=True, engine="fastparquet", compression=None + ) + + +@td.skip_if_installed("gcsfs") +def test_gcs_not_present_exception(): + with tm.external_error_raised(ImportError): + read_csv("gs://test/test.csv") diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_http_headers.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_http_headers.py new file mode 100644 index 0000000000000000000000000000000000000000..2ca11ad1f74e6381e389577e821d37d89cc689db --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_http_headers.py @@ -0,0 +1,172 @@ +""" +Tests for the pandas custom headers in http(s) requests +""" +from functools import partial +import gzip +from io import BytesIO + +import pytest + +import pandas.util._test_decorators as td + +import pandas as pd +import pandas._testing as tm + +pytestmark = [ + pytest.mark.single_cpu, + pytest.mark.network, + pytest.mark.filterwarnings( + "ignore:Passing a BlockManager to DataFrame:DeprecationWarning" + ), +] + + +def gzip_bytes(response_bytes): + with BytesIO() as bio: + with gzip.GzipFile(fileobj=bio, mode="w") as zipper: + zipper.write(response_bytes) + return bio.getvalue() + + +def csv_responder(df): + return df.to_csv(index=False).encode("utf-8") + + +def gz_csv_responder(df): + return gzip_bytes(csv_responder(df)) + + +def json_responder(df): + return df.to_json().encode("utf-8") + + +def gz_json_responder(df): + return gzip_bytes(json_responder(df)) + + +def html_responder(df): + return df.to_html(index=False).encode("utf-8") + + +def parquetpyarrow_reponder(df): + return df.to_parquet(index=False, engine="pyarrow") + + +def parquetfastparquet_responder(df): + # the fastparquet engine doesn't like to write to a buffer + # it can do it via the open_with function being set appropriately + # however it automatically calls the close method and wipes the buffer + # so just overwrite that attribute on this instance to not do that + + # protected by an importorskip in the respective test + import fsspec + + df.to_parquet( + "memory://fastparquet_user_agent.parquet", + index=False, + engine="fastparquet", + compression=None, + ) + with fsspec.open("memory://fastparquet_user_agent.parquet", "rb") as f: + return f.read() + + +def pickle_respnder(df): + with BytesIO() as bio: + df.to_pickle(bio) + return bio.getvalue() + + +def stata_responder(df): + with BytesIO() as bio: + df.to_stata(bio, write_index=False) + return bio.getvalue() + + +@pytest.mark.parametrize( + "responder, read_method", + [ + (csv_responder, pd.read_csv), + (json_responder, pd.read_json), + ( + html_responder, + lambda *args, **kwargs: pd.read_html(*args, **kwargs)[0], + ), + pytest.param( + parquetpyarrow_reponder, + partial(pd.read_parquet, engine="pyarrow"), + marks=td.skip_if_no("pyarrow"), + ), + pytest.param( + parquetfastparquet_responder, + partial(pd.read_parquet, engine="fastparquet"), + # TODO(ArrayManager) fastparquet + marks=[ + td.skip_if_no("fastparquet"), + td.skip_if_no("fsspec"), + td.skip_array_manager_not_yet_implemented, + ], + ), + (pickle_respnder, pd.read_pickle), + (stata_responder, pd.read_stata), + (gz_csv_responder, pd.read_csv), + (gz_json_responder, pd.read_json), + ], +) +@pytest.mark.parametrize( + "storage_options", + [ + None, + {"User-Agent": "foo"}, + {"User-Agent": "foo", "Auth": "bar"}, + ], +) +def test_request_headers(responder, read_method, httpserver, storage_options): + expected = pd.DataFrame({"a": ["b"]}) + default_headers = ["Accept-Encoding", "Host", "Connection", "User-Agent"] + if "gz" in responder.__name__: + extra = {"Content-Encoding": "gzip"} + if storage_options is None: + storage_options = extra + else: + storage_options |= extra + else: + extra = None + expected_headers = set(default_headers).union( + storage_options.keys() if storage_options else [] + ) + httpserver.serve_content(content=responder(expected), headers=extra) + result = read_method(httpserver.url, storage_options=storage_options) + tm.assert_frame_equal(result, expected) + + request_headers = dict(httpserver.requests[0].headers) + for header in expected_headers: + exp = request_headers.pop(header) + if storage_options and header in storage_options: + assert exp == storage_options[header] + # No extra headers added + assert not request_headers + + +@pytest.mark.parametrize( + "engine", + [ + "pyarrow", + "fastparquet", + ], +) +def test_to_parquet_to_disk_with_storage_options(engine): + headers = { + "User-Agent": "custom", + "Auth": "other_custom", + } + + pytest.importorskip(engine) + + true_df = pd.DataFrame({"column_name": ["column_value"]}) + msg = ( + "storage_options passed with file object or non-fsspec file path|" + "storage_options passed with buffer, or non-supported URL" + ) + with pytest.raises(ValueError, match=msg): + true_df.to_parquet("/tmp/junk.parquet", storage_options=headers, engine=engine) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_orc.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_orc.py new file mode 100644 index 0000000000000000000000000000000000000000..a4021311fc963a41633ebec2680c7f6d79525044 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_orc.py @@ -0,0 +1,436 @@ +""" test orc compat """ +import datetime +from decimal import Decimal +from io import BytesIO +import os +import pathlib + +import numpy as np +import pytest + +import pandas as pd +from pandas import read_orc +import pandas._testing as tm +from pandas.core.arrays import StringArray + +pytest.importorskip("pyarrow.orc") + +import pyarrow as pa + +pytestmark = pytest.mark.filterwarnings( + "ignore:Passing a BlockManager to DataFrame:DeprecationWarning" +) + + +@pytest.fixture +def dirpath(datapath): + return datapath("io", "data", "orc") + + +@pytest.fixture( + params=[ + np.array([1, 20], dtype="uint64"), + pd.Series(["a", "b", "a"], dtype="category"), + [pd.Interval(left=0, right=2), pd.Interval(left=0, right=5)], + [pd.Period("2022-01-03", freq="D"), pd.Period("2022-01-04", freq="D")], + ] +) +def orc_writer_dtypes_not_supported(request): + # Examples of dataframes with dtypes for which conversion to ORC + # hasn't been implemented yet, that is, Category, unsigned integers, + # interval, period and sparse. + return pd.DataFrame({"unimpl": request.param}) + + +def test_orc_reader_empty(dirpath): + columns = [ + "boolean1", + "byte1", + "short1", + "int1", + "long1", + "float1", + "double1", + "bytes1", + "string1", + ] + dtypes = [ + "bool", + "int8", + "int16", + "int32", + "int64", + "float32", + "float64", + "object", + "object", + ] + expected = pd.DataFrame(index=pd.RangeIndex(0)) + for colname, dtype in zip(columns, dtypes): + expected[colname] = pd.Series(dtype=dtype) + + inputfile = os.path.join(dirpath, "TestOrcFile.emptyFile.orc") + got = read_orc(inputfile, columns=columns) + + tm.assert_equal(expected, got) + + +def test_orc_reader_basic(dirpath): + data = { + "boolean1": np.array([False, True], dtype="bool"), + "byte1": np.array([1, 100], dtype="int8"), + "short1": np.array([1024, 2048], dtype="int16"), + "int1": np.array([65536, 65536], dtype="int32"), + "long1": np.array([9223372036854775807, 9223372036854775807], dtype="int64"), + "float1": np.array([1.0, 2.0], dtype="float32"), + "double1": np.array([-15.0, -5.0], dtype="float64"), + "bytes1": np.array([b"\x00\x01\x02\x03\x04", b""], dtype="object"), + "string1": np.array(["hi", "bye"], dtype="object"), + } + expected = pd.DataFrame.from_dict(data) + + inputfile = os.path.join(dirpath, "TestOrcFile.test1.orc") + got = read_orc(inputfile, columns=data.keys()) + + tm.assert_equal(expected, got) + + +def test_orc_reader_decimal(dirpath): + # Only testing the first 10 rows of data + data = { + "_col0": np.array( + [ + Decimal("-1000.50000"), + Decimal("-999.60000"), + Decimal("-998.70000"), + Decimal("-997.80000"), + Decimal("-996.90000"), + Decimal("-995.10000"), + Decimal("-994.11000"), + Decimal("-993.12000"), + Decimal("-992.13000"), + Decimal("-991.14000"), + ], + dtype="object", + ) + } + expected = pd.DataFrame.from_dict(data) + + inputfile = os.path.join(dirpath, "TestOrcFile.decimal.orc") + got = read_orc(inputfile).iloc[:10] + + tm.assert_equal(expected, got) + + +def test_orc_reader_date_low(dirpath): + data = { + "time": np.array( + [ + "1900-05-05 12:34:56.100000", + "1900-05-05 12:34:56.100100", + "1900-05-05 12:34:56.100200", + "1900-05-05 12:34:56.100300", + "1900-05-05 12:34:56.100400", + "1900-05-05 12:34:56.100500", + "1900-05-05 12:34:56.100600", + "1900-05-05 12:34:56.100700", + "1900-05-05 12:34:56.100800", + "1900-05-05 12:34:56.100900", + ], + dtype="datetime64[ns]", + ), + "date": np.array( + [ + datetime.date(1900, 12, 25), + datetime.date(1900, 12, 25), + datetime.date(1900, 12, 25), + datetime.date(1900, 12, 25), + datetime.date(1900, 12, 25), + datetime.date(1900, 12, 25), + datetime.date(1900, 12, 25), + datetime.date(1900, 12, 25), + datetime.date(1900, 12, 25), + datetime.date(1900, 12, 25), + ], + dtype="object", + ), + } + expected = pd.DataFrame.from_dict(data) + + inputfile = os.path.join(dirpath, "TestOrcFile.testDate1900.orc") + got = read_orc(inputfile).iloc[:10] + + tm.assert_equal(expected, got) + + +def test_orc_reader_date_high(dirpath): + data = { + "time": np.array( + [ + "2038-05-05 12:34:56.100000", + "2038-05-05 12:34:56.100100", + "2038-05-05 12:34:56.100200", + "2038-05-05 12:34:56.100300", + "2038-05-05 12:34:56.100400", + "2038-05-05 12:34:56.100500", + "2038-05-05 12:34:56.100600", + "2038-05-05 12:34:56.100700", + "2038-05-05 12:34:56.100800", + "2038-05-05 12:34:56.100900", + ], + dtype="datetime64[ns]", + ), + "date": np.array( + [ + datetime.date(2038, 12, 25), + datetime.date(2038, 12, 25), + datetime.date(2038, 12, 25), + datetime.date(2038, 12, 25), + datetime.date(2038, 12, 25), + datetime.date(2038, 12, 25), + datetime.date(2038, 12, 25), + datetime.date(2038, 12, 25), + datetime.date(2038, 12, 25), + datetime.date(2038, 12, 25), + ], + dtype="object", + ), + } + expected = pd.DataFrame.from_dict(data) + + inputfile = os.path.join(dirpath, "TestOrcFile.testDate2038.orc") + got = read_orc(inputfile).iloc[:10] + + tm.assert_equal(expected, got) + + +def test_orc_reader_snappy_compressed(dirpath): + data = { + "int1": np.array( + [ + -1160101563, + 1181413113, + 2065821249, + -267157795, + 172111193, + 1752363137, + 1406072123, + 1911809390, + -1308542224, + -467100286, + ], + dtype="int32", + ), + "string1": np.array( + [ + "f50dcb8", + "382fdaaa", + "90758c6", + "9e8caf3f", + "ee97332b", + "d634da1", + "2bea4396", + "d67d89e8", + "ad71007e", + "e8c82066", + ], + dtype="object", + ), + } + expected = pd.DataFrame.from_dict(data) + + inputfile = os.path.join(dirpath, "TestOrcFile.testSnappy.orc") + got = read_orc(inputfile).iloc[:10] + + tm.assert_equal(expected, got) + + +def test_orc_roundtrip_file(dirpath): + # GH44554 + # PyArrow gained ORC write support with the current argument order + pytest.importorskip("pyarrow") + + data = { + "boolean1": np.array([False, True], dtype="bool"), + "byte1": np.array([1, 100], dtype="int8"), + "short1": np.array([1024, 2048], dtype="int16"), + "int1": np.array([65536, 65536], dtype="int32"), + "long1": np.array([9223372036854775807, 9223372036854775807], dtype="int64"), + "float1": np.array([1.0, 2.0], dtype="float32"), + "double1": np.array([-15.0, -5.0], dtype="float64"), + "bytes1": np.array([b"\x00\x01\x02\x03\x04", b""], dtype="object"), + "string1": np.array(["hi", "bye"], dtype="object"), + } + expected = pd.DataFrame.from_dict(data) + + with tm.ensure_clean() as path: + expected.to_orc(path) + got = read_orc(path) + + tm.assert_equal(expected, got) + + +def test_orc_roundtrip_bytesio(): + # GH44554 + # PyArrow gained ORC write support with the current argument order + pytest.importorskip("pyarrow") + + data = { + "boolean1": np.array([False, True], dtype="bool"), + "byte1": np.array([1, 100], dtype="int8"), + "short1": np.array([1024, 2048], dtype="int16"), + "int1": np.array([65536, 65536], dtype="int32"), + "long1": np.array([9223372036854775807, 9223372036854775807], dtype="int64"), + "float1": np.array([1.0, 2.0], dtype="float32"), + "double1": np.array([-15.0, -5.0], dtype="float64"), + "bytes1": np.array([b"\x00\x01\x02\x03\x04", b""], dtype="object"), + "string1": np.array(["hi", "bye"], dtype="object"), + } + expected = pd.DataFrame.from_dict(data) + + bytes = expected.to_orc() + got = read_orc(BytesIO(bytes)) + + tm.assert_equal(expected, got) + + +def test_orc_writer_dtypes_not_supported(orc_writer_dtypes_not_supported): + # GH44554 + # PyArrow gained ORC write support with the current argument order + pytest.importorskip("pyarrow") + + msg = "The dtype of one or more columns is not supported yet." + with pytest.raises(NotImplementedError, match=msg): + orc_writer_dtypes_not_supported.to_orc() + + +def test_orc_dtype_backend_pyarrow(): + pytest.importorskip("pyarrow") + df = pd.DataFrame( + { + "string": list("abc"), + "string_with_nan": ["a", np.nan, "c"], + "string_with_none": ["a", None, "c"], + "bytes": [b"foo", b"bar", None], + "int": list(range(1, 4)), + "float": np.arange(4.0, 7.0, dtype="float64"), + "float_with_nan": [2.0, np.nan, 3.0], + "bool": [True, False, True], + "bool_with_na": [True, False, None], + "datetime": pd.date_range("20130101", periods=3), + "datetime_with_nat": [ + pd.Timestamp("20130101"), + pd.NaT, + pd.Timestamp("20130103"), + ], + } + ) + + bytes_data = df.copy().to_orc() + result = read_orc(BytesIO(bytes_data), dtype_backend="pyarrow") + + expected = pd.DataFrame( + { + col: pd.arrays.ArrowExtensionArray(pa.array(df[col], from_pandas=True)) + for col in df.columns + } + ) + + tm.assert_frame_equal(result, expected) + + +def test_orc_dtype_backend_numpy_nullable(): + # GH#50503 + pytest.importorskip("pyarrow") + df = pd.DataFrame( + { + "string": list("abc"), + "string_with_nan": ["a", np.nan, "c"], + "string_with_none": ["a", None, "c"], + "int": list(range(1, 4)), + "int_with_nan": pd.Series([1, pd.NA, 3], dtype="Int64"), + "na_only": pd.Series([pd.NA, pd.NA, pd.NA], dtype="Int64"), + "float": np.arange(4.0, 7.0, dtype="float64"), + "float_with_nan": [2.0, np.nan, 3.0], + "bool": [True, False, True], + "bool_with_na": [True, False, None], + } + ) + + bytes_data = df.copy().to_orc() + result = read_orc(BytesIO(bytes_data), dtype_backend="numpy_nullable") + + expected = pd.DataFrame( + { + "string": StringArray(np.array(["a", "b", "c"], dtype=np.object_)), + "string_with_nan": StringArray( + np.array(["a", pd.NA, "c"], dtype=np.object_) + ), + "string_with_none": StringArray( + np.array(["a", pd.NA, "c"], dtype=np.object_) + ), + "int": pd.Series([1, 2, 3], dtype="Int64"), + "int_with_nan": pd.Series([1, pd.NA, 3], dtype="Int64"), + "na_only": pd.Series([pd.NA, pd.NA, pd.NA], dtype="Int64"), + "float": pd.Series([4.0, 5.0, 6.0], dtype="Float64"), + "float_with_nan": pd.Series([2.0, pd.NA, 3.0], dtype="Float64"), + "bool": pd.Series([True, False, True], dtype="boolean"), + "bool_with_na": pd.Series([True, False, pd.NA], dtype="boolean"), + } + ) + + tm.assert_frame_equal(result, expected) + + +def test_orc_uri_path(): + expected = pd.DataFrame({"int": list(range(1, 4))}) + with tm.ensure_clean("tmp.orc") as path: + expected.to_orc(path) + uri = pathlib.Path(path).as_uri() + result = read_orc(uri) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "index", + [ + pd.RangeIndex(start=2, stop=5, step=1), + pd.RangeIndex(start=0, stop=3, step=1, name="non-default"), + pd.Index([1, 2, 3]), + ], +) +def test_to_orc_non_default_index(index): + df = pd.DataFrame({"a": [1, 2, 3]}, index=index) + msg = ( + "orc does not support serializing a non-default index|" + "orc does not serialize index meta-data" + ) + with pytest.raises(ValueError, match=msg): + df.to_orc() + + +def test_invalid_dtype_backend(): + msg = ( + "dtype_backend numpy is invalid, only 'numpy_nullable' and " + "'pyarrow' are allowed." + ) + df = pd.DataFrame({"int": list(range(1, 4))}) + with tm.ensure_clean("tmp.orc") as path: + df.to_orc(path) + with pytest.raises(ValueError, match=msg): + read_orc(path, dtype_backend="numpy") + + +def test_string_inference(tmp_path): + # GH#54431 + path = tmp_path / "test_string_inference.p" + df = pd.DataFrame(data={"a": ["x", "y"]}) + df.to_orc(path) + with pd.option_context("future.infer_string", True): + result = read_orc(path) + expected = pd.DataFrame( + data={"a": ["x", "y"]}, + dtype="string[pyarrow_numpy]", + columns=pd.Index(["a"], dtype="string[pyarrow_numpy]"), + ) + tm.assert_frame_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_parquet.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_parquet.py new file mode 100644 index 0000000000000000000000000000000000000000..760a64c8d4c33d6e1f48460cd99cdb00e030a9b1 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_parquet.py @@ -0,0 +1,1427 @@ +""" test parquet compat """ +import datetime +from decimal import Decimal +from io import BytesIO +import os +import pathlib + +import numpy as np +import pytest + +from pandas._config import using_copy_on_write +from pandas._config.config import _get_option + +from pandas.compat import is_platform_windows +from pandas.compat.pyarrow import ( + pa_version_under11p0, + pa_version_under13p0, + pa_version_under15p0, +) + +import pandas as pd +import pandas._testing as tm +from pandas.util.version import Version + +from pandas.io.parquet import ( + FastParquetImpl, + PyArrowImpl, + get_engine, + read_parquet, + to_parquet, +) + +try: + import pyarrow + + _HAVE_PYARROW = True +except ImportError: + _HAVE_PYARROW = False + +try: + import fastparquet + + _HAVE_FASTPARQUET = True +except ImportError: + _HAVE_FASTPARQUET = False + + +# TODO(ArrayManager) fastparquet relies on BlockManager internals + +pytestmark = [ + pytest.mark.filterwarnings("ignore:DataFrame._data is deprecated:FutureWarning"), + pytest.mark.filterwarnings( + "ignore:Passing a BlockManager to DataFrame:DeprecationWarning" + ), +] + + +# setup engines & skips +@pytest.fixture( + params=[ + pytest.param( + "fastparquet", + marks=pytest.mark.skipif( + not _HAVE_FASTPARQUET + or _get_option("mode.data_manager", silent=True) == "array", + reason="fastparquet is not installed or ArrayManager is used", + ), + ), + pytest.param( + "pyarrow", + marks=pytest.mark.skipif( + not _HAVE_PYARROW, reason="pyarrow is not installed" + ), + ), + ] +) +def engine(request): + return request.param + + +@pytest.fixture +def pa(): + if not _HAVE_PYARROW: + pytest.skip("pyarrow is not installed") + return "pyarrow" + + +@pytest.fixture +def fp(): + if not _HAVE_FASTPARQUET: + pytest.skip("fastparquet is not installed") + elif _get_option("mode.data_manager", silent=True) == "array": + pytest.skip("ArrayManager is not supported with fastparquet") + return "fastparquet" + + +@pytest.fixture +def df_compat(): + return pd.DataFrame({"A": [1, 2, 3], "B": "foo"}) + + +@pytest.fixture +def df_cross_compat(): + df = pd.DataFrame( + { + "a": list("abc"), + "b": list(range(1, 4)), + # 'c': np.arange(3, 6).astype('u1'), + "d": np.arange(4.0, 7.0, dtype="float64"), + "e": [True, False, True], + "f": pd.date_range("20130101", periods=3), + # 'g': pd.date_range('20130101', periods=3, + # tz='US/Eastern'), + # 'h': pd.date_range('20130101', periods=3, freq='ns') + } + ) + return df + + +@pytest.fixture +def df_full(): + return pd.DataFrame( + { + "string": list("abc"), + "string_with_nan": ["a", np.nan, "c"], + "string_with_none": ["a", None, "c"], + "bytes": [b"foo", b"bar", b"baz"], + "unicode": ["foo", "bar", "baz"], + "int": list(range(1, 4)), + "uint": np.arange(3, 6).astype("u1"), + "float": np.arange(4.0, 7.0, dtype="float64"), + "float_with_nan": [2.0, np.nan, 3.0], + "bool": [True, False, True], + "datetime": pd.date_range("20130101", periods=3), + "datetime_with_nat": [ + pd.Timestamp("20130101"), + pd.NaT, + pd.Timestamp("20130103"), + ], + } + ) + + +@pytest.fixture( + params=[ + datetime.datetime.now(datetime.timezone.utc), + datetime.datetime.now(datetime.timezone.min), + datetime.datetime.now(datetime.timezone.max), + datetime.datetime.strptime("2019-01-04T16:41:24+0200", "%Y-%m-%dT%H:%M:%S%z"), + datetime.datetime.strptime("2019-01-04T16:41:24+0215", "%Y-%m-%dT%H:%M:%S%z"), + datetime.datetime.strptime("2019-01-04T16:41:24-0200", "%Y-%m-%dT%H:%M:%S%z"), + datetime.datetime.strptime("2019-01-04T16:41:24-0215", "%Y-%m-%dT%H:%M:%S%z"), + ] +) +def timezone_aware_date_list(request): + return request.param + + +def check_round_trip( + df, + engine=None, + path=None, + write_kwargs=None, + read_kwargs=None, + expected=None, + check_names=True, + check_like=False, + check_dtype=True, + repeat=2, +): + """Verify parquet serializer and deserializer produce the same results. + + Performs a pandas to disk and disk to pandas round trip, + then compares the 2 resulting DataFrames to verify equality. + + Parameters + ---------- + df: Dataframe + engine: str, optional + 'pyarrow' or 'fastparquet' + path: str, optional + write_kwargs: dict of str:str, optional + read_kwargs: dict of str:str, optional + expected: DataFrame, optional + Expected deserialization result, otherwise will be equal to `df` + check_names: list of str, optional + Closed set of column names to be compared + check_like: bool, optional + If True, ignore the order of index & columns. + repeat: int, optional + How many times to repeat the test + """ + write_kwargs = write_kwargs or {"compression": None} + read_kwargs = read_kwargs or {} + + if expected is None: + expected = df + + if engine: + write_kwargs["engine"] = engine + read_kwargs["engine"] = engine + + def compare(repeat): + for _ in range(repeat): + df.to_parquet(path, **write_kwargs) + actual = read_parquet(path, **read_kwargs) + + if "string_with_nan" in expected: + expected.loc[1, "string_with_nan"] = None + tm.assert_frame_equal( + expected, + actual, + check_names=check_names, + check_like=check_like, + check_dtype=check_dtype, + ) + + if path is None: + with tm.ensure_clean() as path: + compare(repeat) + else: + compare(repeat) + + +def check_partition_names(path, expected): + """Check partitions of a parquet file are as expected. + + Parameters + ---------- + path: str + Path of the dataset. + expected: iterable of str + Expected partition names. + """ + import pyarrow.dataset as ds + + dataset = ds.dataset(path, partitioning="hive") + assert dataset.partitioning.schema.names == expected + + +def test_invalid_engine(df_compat): + msg = "engine must be one of 'pyarrow', 'fastparquet'" + with pytest.raises(ValueError, match=msg): + check_round_trip(df_compat, "foo", "bar") + + +def test_options_py(df_compat, pa): + # use the set option + + with pd.option_context("io.parquet.engine", "pyarrow"): + check_round_trip(df_compat) + + +def test_options_fp(df_compat, fp): + # use the set option + + with pd.option_context("io.parquet.engine", "fastparquet"): + check_round_trip(df_compat) + + +def test_options_auto(df_compat, fp, pa): + # use the set option + + with pd.option_context("io.parquet.engine", "auto"): + check_round_trip(df_compat) + + +def test_options_get_engine(fp, pa): + assert isinstance(get_engine("pyarrow"), PyArrowImpl) + assert isinstance(get_engine("fastparquet"), FastParquetImpl) + + with pd.option_context("io.parquet.engine", "pyarrow"): + assert isinstance(get_engine("auto"), PyArrowImpl) + assert isinstance(get_engine("pyarrow"), PyArrowImpl) + assert isinstance(get_engine("fastparquet"), FastParquetImpl) + + with pd.option_context("io.parquet.engine", "fastparquet"): + assert isinstance(get_engine("auto"), FastParquetImpl) + assert isinstance(get_engine("pyarrow"), PyArrowImpl) + assert isinstance(get_engine("fastparquet"), FastParquetImpl) + + with pd.option_context("io.parquet.engine", "auto"): + assert isinstance(get_engine("auto"), PyArrowImpl) + assert isinstance(get_engine("pyarrow"), PyArrowImpl) + assert isinstance(get_engine("fastparquet"), FastParquetImpl) + + +def test_get_engine_auto_error_message(): + # Expect different error messages from get_engine(engine="auto") + # if engines aren't installed vs. are installed but bad version + from pandas.compat._optional import VERSIONS + + # Do we have engines installed, but a bad version of them? + pa_min_ver = VERSIONS.get("pyarrow") + fp_min_ver = VERSIONS.get("fastparquet") + have_pa_bad_version = ( + False + if not _HAVE_PYARROW + else Version(pyarrow.__version__) < Version(pa_min_ver) + ) + have_fp_bad_version = ( + False + if not _HAVE_FASTPARQUET + else Version(fastparquet.__version__) < Version(fp_min_ver) + ) + # Do we have usable engines installed? + have_usable_pa = _HAVE_PYARROW and not have_pa_bad_version + have_usable_fp = _HAVE_FASTPARQUET and not have_fp_bad_version + + if not have_usable_pa and not have_usable_fp: + # No usable engines found. + if have_pa_bad_version: + match = f"Pandas requires version .{pa_min_ver}. or newer of .pyarrow." + with pytest.raises(ImportError, match=match): + get_engine("auto") + else: + match = "Missing optional dependency .pyarrow." + with pytest.raises(ImportError, match=match): + get_engine("auto") + + if have_fp_bad_version: + match = f"Pandas requires version .{fp_min_ver}. or newer of .fastparquet." + with pytest.raises(ImportError, match=match): + get_engine("auto") + else: + match = "Missing optional dependency .fastparquet." + with pytest.raises(ImportError, match=match): + get_engine("auto") + + +def test_cross_engine_pa_fp(df_cross_compat, pa, fp): + # cross-compat with differing reading/writing engines + + df = df_cross_compat + with tm.ensure_clean() as path: + df.to_parquet(path, engine=pa, compression=None) + + result = read_parquet(path, engine=fp) + tm.assert_frame_equal(result, df) + + result = read_parquet(path, engine=fp, columns=["a", "d"]) + tm.assert_frame_equal(result, df[["a", "d"]]) + + +def test_cross_engine_fp_pa(df_cross_compat, pa, fp): + # cross-compat with differing reading/writing engines + df = df_cross_compat + with tm.ensure_clean() as path: + df.to_parquet(path, engine=fp, compression=None) + + result = read_parquet(path, engine=pa) + tm.assert_frame_equal(result, df) + + result = read_parquet(path, engine=pa, columns=["a", "d"]) + tm.assert_frame_equal(result, df[["a", "d"]]) + + +def test_parquet_pos_args_deprecation(engine): + # GH-54229 + df = pd.DataFrame({"a": [1, 2, 3]}) + msg = ( + r"Starting with pandas version 3.0 all arguments of to_parquet except for the " + r"argument 'path' will be keyword-only." + ) + with tm.ensure_clean() as path: + with tm.assert_produces_warning( + FutureWarning, + match=msg, + check_stacklevel=False, + raise_on_extra_warnings=False, + ): + df.to_parquet(path, engine) + + +class Base: + def check_error_on_write(self, df, engine, exc, err_msg): + # check that we are raising the exception on writing + with tm.ensure_clean() as path: + with pytest.raises(exc, match=err_msg): + to_parquet(df, path, engine, compression=None) + + def check_external_error_on_write(self, df, engine, exc): + # check that an external library is raising the exception on writing + with tm.ensure_clean() as path: + with tm.external_error_raised(exc): + to_parquet(df, path, engine, compression=None) + + @pytest.mark.network + @pytest.mark.single_cpu + def test_parquet_read_from_url(self, httpserver, datapath, df_compat, engine): + if engine != "auto": + pytest.importorskip(engine) + with open(datapath("io", "data", "parquet", "simple.parquet"), mode="rb") as f: + httpserver.serve_content(content=f.read()) + df = read_parquet(httpserver.url) + tm.assert_frame_equal(df, df_compat) + + +class TestBasic(Base): + def test_error(self, engine): + for obj in [ + pd.Series([1, 2, 3]), + 1, + "foo", + pd.Timestamp("20130101"), + np.array([1, 2, 3]), + ]: + msg = "to_parquet only supports IO with DataFrames" + self.check_error_on_write(obj, engine, ValueError, msg) + + def test_columns_dtypes(self, engine): + df = pd.DataFrame({"string": list("abc"), "int": list(range(1, 4))}) + + # unicode + df.columns = ["foo", "bar"] + check_round_trip(df, engine) + + @pytest.mark.parametrize("compression", [None, "gzip", "snappy", "brotli"]) + def test_compression(self, engine, compression): + df = pd.DataFrame({"A": [1, 2, 3]}) + check_round_trip(df, engine, write_kwargs={"compression": compression}) + + def test_read_columns(self, engine): + # GH18154 + df = pd.DataFrame({"string": list("abc"), "int": list(range(1, 4))}) + + expected = pd.DataFrame({"string": list("abc")}) + check_round_trip( + df, engine, expected=expected, read_kwargs={"columns": ["string"]} + ) + + def test_read_filters(self, engine, tmp_path): + df = pd.DataFrame( + { + "int": list(range(4)), + "part": list("aabb"), + } + ) + + expected = pd.DataFrame({"int": [0, 1]}) + check_round_trip( + df, + engine, + path=tmp_path, + expected=expected, + write_kwargs={"partition_cols": ["part"]}, + read_kwargs={"filters": [("part", "==", "a")], "columns": ["int"]}, + repeat=1, + ) + + def test_write_index(self, engine): + check_names = engine != "fastparquet" + + df = pd.DataFrame({"A": [1, 2, 3]}) + check_round_trip(df, engine) + + indexes = [ + [2, 3, 4], + pd.date_range("20130101", periods=3), + list("abc"), + [1, 3, 4], + ] + # non-default index + for index in indexes: + df.index = index + if isinstance(index, pd.DatetimeIndex): + df.index = df.index._with_freq(None) # freq doesn't round-trip + check_round_trip(df, engine, check_names=check_names) + + # index with meta-data + df.index = [0, 1, 2] + df.index.name = "foo" + check_round_trip(df, engine) + + def test_write_multiindex(self, pa): + # Not supported in fastparquet as of 0.1.3 or older pyarrow version + engine = pa + + df = pd.DataFrame({"A": [1, 2, 3]}) + index = pd.MultiIndex.from_tuples([("a", 1), ("a", 2), ("b", 1)]) + df.index = index + check_round_trip(df, engine) + + def test_multiindex_with_columns(self, pa): + engine = pa + dates = pd.date_range("01-Jan-2018", "01-Dec-2018", freq="MS") + df = pd.DataFrame( + np.random.default_rng(2).standard_normal((2 * len(dates), 3)), + columns=list("ABC"), + ) + index1 = pd.MultiIndex.from_product( + [["Level1", "Level2"], dates], names=["level", "date"] + ) + index2 = index1.copy(names=None) + for index in [index1, index2]: + df.index = index + + check_round_trip(df, engine) + check_round_trip( + df, engine, read_kwargs={"columns": ["A", "B"]}, expected=df[["A", "B"]] + ) + + def test_write_ignoring_index(self, engine): + # ENH 20768 + # Ensure index=False omits the index from the written Parquet file. + df = pd.DataFrame({"a": [1, 2, 3], "b": ["q", "r", "s"]}) + + write_kwargs = {"compression": None, "index": False} + + # Because we're dropping the index, we expect the loaded dataframe to + # have the default integer index. + expected = df.reset_index(drop=True) + + check_round_trip(df, engine, write_kwargs=write_kwargs, expected=expected) + + # Ignore custom index + df = pd.DataFrame( + {"a": [1, 2, 3], "b": ["q", "r", "s"]}, index=["zyx", "wvu", "tsr"] + ) + + check_round_trip(df, engine, write_kwargs=write_kwargs, expected=expected) + + # Ignore multi-indexes as well. + arrays = [ + ["bar", "bar", "baz", "baz", "foo", "foo", "qux", "qux"], + ["one", "two", "one", "two", "one", "two", "one", "two"], + ] + df = pd.DataFrame( + {"one": list(range(8)), "two": [-i for i in range(8)]}, index=arrays + ) + + expected = df.reset_index(drop=True) + check_round_trip(df, engine, write_kwargs=write_kwargs, expected=expected) + + def test_write_column_multiindex(self, engine): + # Not able to write column multi-indexes with non-string column names. + mi_columns = pd.MultiIndex.from_tuples([("a", 1), ("a", 2), ("b", 1)]) + df = pd.DataFrame( + np.random.default_rng(2).standard_normal((4, 3)), columns=mi_columns + ) + + if engine == "fastparquet": + self.check_error_on_write( + df, engine, TypeError, "Column name must be a string" + ) + elif engine == "pyarrow": + check_round_trip(df, engine) + + def test_write_column_multiindex_nonstring(self, engine): + # GH #34777 + + # Not able to write column multi-indexes with non-string column names + arrays = [ + ["bar", "bar", "baz", "baz", "foo", "foo", "qux", "qux"], + [1, 2, 1, 2, 1, 2, 1, 2], + ] + df = pd.DataFrame( + np.random.default_rng(2).standard_normal((8, 8)), columns=arrays + ) + df.columns.names = ["Level1", "Level2"] + if engine == "fastparquet": + self.check_error_on_write(df, engine, ValueError, "Column name") + elif engine == "pyarrow": + check_round_trip(df, engine) + + def test_write_column_multiindex_string(self, pa): + # GH #34777 + # Not supported in fastparquet as of 0.1.3 + engine = pa + + # Write column multi-indexes with string column names + arrays = [ + ["bar", "bar", "baz", "baz", "foo", "foo", "qux", "qux"], + ["one", "two", "one", "two", "one", "two", "one", "two"], + ] + df = pd.DataFrame( + np.random.default_rng(2).standard_normal((8, 8)), columns=arrays + ) + df.columns.names = ["ColLevel1", "ColLevel2"] + + check_round_trip(df, engine) + + def test_write_column_index_string(self, pa): + # GH #34777 + # Not supported in fastparquet as of 0.1.3 + engine = pa + + # Write column indexes with string column names + arrays = ["bar", "baz", "foo", "qux"] + df = pd.DataFrame( + np.random.default_rng(2).standard_normal((8, 4)), columns=arrays + ) + df.columns.name = "StringCol" + + check_round_trip(df, engine) + + def test_write_column_index_nonstring(self, engine): + # GH #34777 + + # Write column indexes with string column names + arrays = [1, 2, 3, 4] + df = pd.DataFrame( + np.random.default_rng(2).standard_normal((8, 4)), columns=arrays + ) + df.columns.name = "NonStringCol" + if engine == "fastparquet": + self.check_error_on_write( + df, engine, TypeError, "Column name must be a string" + ) + else: + check_round_trip(df, engine) + + def test_dtype_backend(self, engine, request): + pq = pytest.importorskip("pyarrow.parquet") + + if engine == "fastparquet": + # We are manually disabling fastparquet's + # nullable dtype support pending discussion + mark = pytest.mark.xfail( + reason="Fastparquet nullable dtype support is disabled" + ) + request.applymarker(mark) + + table = pyarrow.table( + { + "a": pyarrow.array([1, 2, 3, None], "int64"), + "b": pyarrow.array([1, 2, 3, None], "uint8"), + "c": pyarrow.array(["a", "b", "c", None]), + "d": pyarrow.array([True, False, True, None]), + # Test that nullable dtypes used even in absence of nulls + "e": pyarrow.array([1, 2, 3, 4], "int64"), + # GH 45694 + "f": pyarrow.array([1.0, 2.0, 3.0, None], "float32"), + "g": pyarrow.array([1.0, 2.0, 3.0, None], "float64"), + } + ) + with tm.ensure_clean() as path: + # write manually with pyarrow to write integers + pq.write_table(table, path) + result1 = read_parquet(path, engine=engine) + result2 = read_parquet(path, engine=engine, dtype_backend="numpy_nullable") + + assert result1["a"].dtype == np.dtype("float64") + expected = pd.DataFrame( + { + "a": pd.array([1, 2, 3, None], dtype="Int64"), + "b": pd.array([1, 2, 3, None], dtype="UInt8"), + "c": pd.array(["a", "b", "c", None], dtype="string"), + "d": pd.array([True, False, True, None], dtype="boolean"), + "e": pd.array([1, 2, 3, 4], dtype="Int64"), + "f": pd.array([1.0, 2.0, 3.0, None], dtype="Float32"), + "g": pd.array([1.0, 2.0, 3.0, None], dtype="Float64"), + } + ) + if engine == "fastparquet": + # Fastparquet doesn't support string columns yet + # Only int and boolean + result2 = result2.drop("c", axis=1) + expected = expected.drop("c", axis=1) + tm.assert_frame_equal(result2, expected) + + @pytest.mark.parametrize( + "dtype", + [ + "Int64", + "UInt8", + "boolean", + "object", + "datetime64[ns, UTC]", + "float", + "period[D]", + "Float64", + "string", + ], + ) + def test_read_empty_array(self, pa, dtype): + # GH #41241 + df = pd.DataFrame( + { + "value": pd.array([], dtype=dtype), + } + ) + # GH 45694 + expected = None + if dtype == "float": + expected = pd.DataFrame( + { + "value": pd.array([], dtype="Float64"), + } + ) + check_round_trip( + df, pa, read_kwargs={"dtype_backend": "numpy_nullable"}, expected=expected + ) + + +class TestParquetPyArrow(Base): + def test_basic(self, pa, df_full): + df = df_full + + # additional supported types for pyarrow + dti = pd.date_range("20130101", periods=3, tz="Europe/Brussels") + dti = dti._with_freq(None) # freq doesn't round-trip + df["datetime_tz"] = dti + df["bool_with_none"] = [True, None, True] + + check_round_trip(df, pa) + + def test_basic_subset_columns(self, pa, df_full): + # GH18628 + + df = df_full + # additional supported types for pyarrow + df["datetime_tz"] = pd.date_range("20130101", periods=3, tz="Europe/Brussels") + + check_round_trip( + df, + pa, + expected=df[["string", "int"]], + read_kwargs={"columns": ["string", "int"]}, + ) + + def test_to_bytes_without_path_or_buf_provided(self, pa, df_full): + # GH 37105 + buf_bytes = df_full.to_parquet(engine=pa) + assert isinstance(buf_bytes, bytes) + + buf_stream = BytesIO(buf_bytes) + res = read_parquet(buf_stream) + + expected = df_full.copy() + expected.loc[1, "string_with_nan"] = None + tm.assert_frame_equal(res, expected) + + def test_duplicate_columns(self, pa): + # not currently able to handle duplicate columns + df = pd.DataFrame(np.arange(12).reshape(4, 3), columns=list("aaa")).copy() + self.check_error_on_write(df, pa, ValueError, "Duplicate column names found") + + def test_timedelta(self, pa): + df = pd.DataFrame({"a": pd.timedelta_range("1 day", periods=3)}) + check_round_trip(df, pa) + + def test_unsupported(self, pa): + # mixed python objects + df = pd.DataFrame({"a": ["a", 1, 2.0]}) + # pyarrow 0.11 raises ArrowTypeError + # older pyarrows raise ArrowInvalid + self.check_external_error_on_write(df, pa, pyarrow.ArrowException) + + def test_unsupported_float16(self, pa): + # #44847, #44914 + # Not able to write float 16 column using pyarrow. + data = np.arange(2, 10, dtype=np.float16) + df = pd.DataFrame(data=data, columns=["fp16"]) + if pa_version_under15p0: + self.check_external_error_on_write(df, pa, pyarrow.ArrowException) + else: + check_round_trip(df, pa) + + @pytest.mark.xfail( + is_platform_windows(), + reason=( + "PyArrow does not cleanup of partial files dumps when unsupported " + "dtypes are passed to_parquet function in windows" + ), + ) + @pytest.mark.skipif(not pa_version_under15p0, reason="float16 works on 15") + @pytest.mark.parametrize("path_type", [str, pathlib.Path]) + def test_unsupported_float16_cleanup(self, pa, path_type): + # #44847, #44914 + # Not able to write float 16 column using pyarrow. + # Tests cleanup by pyarrow in case of an error + data = np.arange(2, 10, dtype=np.float16) + df = pd.DataFrame(data=data, columns=["fp16"]) + + with tm.ensure_clean() as path_str: + path = path_type(path_str) + with tm.external_error_raised(pyarrow.ArrowException): + df.to_parquet(path=path, engine=pa) + assert not os.path.isfile(path) + + def test_categorical(self, pa): + # supported in >= 0.7.0 + df = pd.DataFrame() + df["a"] = pd.Categorical(list("abcdef")) + + # test for null, out-of-order values, and unobserved category + df["b"] = pd.Categorical( + ["bar", "foo", "foo", "bar", None, "bar"], + dtype=pd.CategoricalDtype(["foo", "bar", "baz"]), + ) + + # test for ordered flag + df["c"] = pd.Categorical( + ["a", "b", "c", "a", "c", "b"], categories=["b", "c", "d"], ordered=True + ) + + check_round_trip(df, pa) + + @pytest.mark.single_cpu + def test_s3_roundtrip_explicit_fs(self, df_compat, s3_public_bucket, pa, s3so): + s3fs = pytest.importorskip("s3fs") + s3 = s3fs.S3FileSystem(**s3so) + kw = {"filesystem": s3} + check_round_trip( + df_compat, + pa, + path=f"{s3_public_bucket.name}/pyarrow.parquet", + read_kwargs=kw, + write_kwargs=kw, + ) + + @pytest.mark.single_cpu + def test_s3_roundtrip(self, df_compat, s3_public_bucket, pa, s3so): + # GH #19134 + s3so = {"storage_options": s3so} + check_round_trip( + df_compat, + pa, + path=f"s3://{s3_public_bucket.name}/pyarrow.parquet", + read_kwargs=s3so, + write_kwargs=s3so, + ) + + @pytest.mark.single_cpu + @pytest.mark.parametrize( + "partition_col", + [ + ["A"], + [], + ], + ) + def test_s3_roundtrip_for_dir( + self, df_compat, s3_public_bucket, pa, partition_col, s3so + ): + pytest.importorskip("s3fs") + # GH #26388 + expected_df = df_compat.copy() + + # GH #35791 + if partition_col: + expected_df = expected_df.astype(dict.fromkeys(partition_col, np.int32)) + partition_col_type = "category" + + expected_df[partition_col] = expected_df[partition_col].astype( + partition_col_type + ) + + check_round_trip( + df_compat, + pa, + expected=expected_df, + path=f"s3://{s3_public_bucket.name}/parquet_dir", + read_kwargs={"storage_options": s3so}, + write_kwargs={ + "partition_cols": partition_col, + "compression": None, + "storage_options": s3so, + }, + check_like=True, + repeat=1, + ) + + def test_read_file_like_obj_support(self, df_compat): + pytest.importorskip("pyarrow") + buffer = BytesIO() + df_compat.to_parquet(buffer) + df_from_buf = read_parquet(buffer) + tm.assert_frame_equal(df_compat, df_from_buf) + + def test_expand_user(self, df_compat, monkeypatch): + pytest.importorskip("pyarrow") + monkeypatch.setenv("HOME", "TestingUser") + monkeypatch.setenv("USERPROFILE", "TestingUser") + with pytest.raises(OSError, match=r".*TestingUser.*"): + read_parquet("~/file.parquet") + with pytest.raises(OSError, match=r".*TestingUser.*"): + df_compat.to_parquet("~/file.parquet") + + def test_partition_cols_supported(self, tmp_path, pa, df_full): + # GH #23283 + partition_cols = ["bool", "int"] + df = df_full + df.to_parquet(tmp_path, partition_cols=partition_cols, compression=None) + check_partition_names(tmp_path, partition_cols) + assert read_parquet(tmp_path).shape == df.shape + + def test_partition_cols_string(self, tmp_path, pa, df_full): + # GH #27117 + partition_cols = "bool" + partition_cols_list = [partition_cols] + df = df_full + df.to_parquet(tmp_path, partition_cols=partition_cols, compression=None) + check_partition_names(tmp_path, partition_cols_list) + assert read_parquet(tmp_path).shape == df.shape + + @pytest.mark.parametrize( + "path_type", [str, lambda x: x], ids=["string", "pathlib.Path"] + ) + def test_partition_cols_pathlib(self, tmp_path, pa, df_compat, path_type): + # GH 35902 + + partition_cols = "B" + partition_cols_list = [partition_cols] + df = df_compat + + path = path_type(tmp_path) + df.to_parquet(path, partition_cols=partition_cols_list) + assert read_parquet(path).shape == df.shape + + def test_empty_dataframe(self, pa): + # GH #27339 + df = pd.DataFrame(index=[], columns=[]) + check_round_trip(df, pa) + + def test_write_with_schema(self, pa): + import pyarrow + + df = pd.DataFrame({"x": [0, 1]}) + schema = pyarrow.schema([pyarrow.field("x", type=pyarrow.bool_())]) + out_df = df.astype(bool) + check_round_trip(df, pa, write_kwargs={"schema": schema}, expected=out_df) + + def test_additional_extension_arrays(self, pa): + # test additional ExtensionArrays that are supported through the + # __arrow_array__ protocol + pytest.importorskip("pyarrow") + df = pd.DataFrame( + { + "a": pd.Series([1, 2, 3], dtype="Int64"), + "b": pd.Series([1, 2, 3], dtype="UInt32"), + "c": pd.Series(["a", None, "c"], dtype="string"), + } + ) + check_round_trip(df, pa) + + df = pd.DataFrame({"a": pd.Series([1, 2, 3, None], dtype="Int64")}) + check_round_trip(df, pa) + + def test_pyarrow_backed_string_array(self, pa, string_storage): + # test ArrowStringArray supported through the __arrow_array__ protocol + pytest.importorskip("pyarrow") + df = pd.DataFrame({"a": pd.Series(["a", None, "c"], dtype="string[pyarrow]")}) + with pd.option_context("string_storage", string_storage): + check_round_trip(df, pa, expected=df.astype(f"string[{string_storage}]")) + + def test_additional_extension_types(self, pa): + # test additional ExtensionArrays that are supported through the + # __arrow_array__ protocol + by defining a custom ExtensionType + pytest.importorskip("pyarrow") + df = pd.DataFrame( + { + "c": pd.IntervalIndex.from_tuples([(0, 1), (1, 2), (3, 4)]), + "d": pd.period_range("2012-01-01", periods=3, freq="D"), + # GH-45881 issue with interval with datetime64[ns] subtype + "e": pd.IntervalIndex.from_breaks( + pd.date_range("2012-01-01", periods=4, freq="D") + ), + } + ) + check_round_trip(df, pa) + + def test_timestamp_nanoseconds(self, pa): + # with version 2.6, pyarrow defaults to writing the nanoseconds, so + # this should work without error + # Note in previous pyarrows(<7.0.0), only the pseudo-version 2.0 was available + ver = "2.6" + df = pd.DataFrame({"a": pd.date_range("2017-01-01", freq="1ns", periods=10)}) + check_round_trip(df, pa, write_kwargs={"version": ver}) + + def test_timezone_aware_index(self, request, pa, timezone_aware_date_list): + if timezone_aware_date_list.tzinfo != datetime.timezone.utc: + request.applymarker( + pytest.mark.xfail( + reason="temporary skip this test until it is properly resolved: " + "https://github.com/pandas-dev/pandas/issues/37286" + ) + ) + idx = 5 * [timezone_aware_date_list] + df = pd.DataFrame(index=idx, data={"index_as_col": idx}) + + # see gh-36004 + # compare time(zone) values only, skip their class: + # pyarrow always creates fixed offset timezones using pytz.FixedOffset() + # even if it was datetime.timezone() originally + # + # technically they are the same: + # they both implement datetime.tzinfo + # they both wrap datetime.timedelta() + # this use-case sets the resolution to 1 minute + check_round_trip(df, pa, check_dtype=False) + + def test_filter_row_groups(self, pa): + # https://github.com/pandas-dev/pandas/issues/26551 + pytest.importorskip("pyarrow") + df = pd.DataFrame({"a": list(range(3))}) + with tm.ensure_clean() as path: + df.to_parquet(path, engine=pa) + result = read_parquet(path, pa, filters=[("a", "==", 0)]) + assert len(result) == 1 + + def test_read_parquet_manager(self, pa, using_array_manager): + # ensure that read_parquet honors the pandas.options.mode.data_manager option + df = pd.DataFrame( + np.random.default_rng(2).standard_normal((10, 3)), columns=["A", "B", "C"] + ) + + with tm.ensure_clean() as path: + df.to_parquet(path, engine=pa) + result = read_parquet(path, pa) + if using_array_manager: + assert isinstance(result._mgr, pd.core.internals.ArrayManager) + else: + assert isinstance(result._mgr, pd.core.internals.BlockManager) + + def test_read_dtype_backend_pyarrow_config(self, pa, df_full): + import pyarrow + + df = df_full + + # additional supported types for pyarrow + dti = pd.date_range("20130101", periods=3, tz="Europe/Brussels") + dti = dti._with_freq(None) # freq doesn't round-trip + df["datetime_tz"] = dti + df["bool_with_none"] = [True, None, True] + + pa_table = pyarrow.Table.from_pandas(df) + expected = pa_table.to_pandas(types_mapper=pd.ArrowDtype) + if pa_version_under13p0: + # pyarrow infers datetimes as us instead of ns + expected["datetime"] = expected["datetime"].astype("timestamp[us][pyarrow]") + expected["datetime_with_nat"] = expected["datetime_with_nat"].astype( + "timestamp[us][pyarrow]" + ) + expected["datetime_tz"] = expected["datetime_tz"].astype( + pd.ArrowDtype(pyarrow.timestamp(unit="us", tz="Europe/Brussels")) + ) + + check_round_trip( + df, + engine=pa, + read_kwargs={"dtype_backend": "pyarrow"}, + expected=expected, + ) + + def test_read_dtype_backend_pyarrow_config_index(self, pa): + df = pd.DataFrame( + {"a": [1, 2]}, index=pd.Index([3, 4], name="test"), dtype="int64[pyarrow]" + ) + expected = df.copy() + import pyarrow + + if Version(pyarrow.__version__) > Version("11.0.0"): + expected.index = expected.index.astype("int64[pyarrow]") + check_round_trip( + df, + engine=pa, + read_kwargs={"dtype_backend": "pyarrow"}, + expected=expected, + ) + + def test_columns_dtypes_not_invalid(self, pa): + df = pd.DataFrame({"string": list("abc"), "int": list(range(1, 4))}) + + # numeric + df.columns = [0, 1] + check_round_trip(df, pa) + + # bytes + df.columns = [b"foo", b"bar"] + with pytest.raises(NotImplementedError, match="|S3"): + # Bytes fails on read_parquet + check_round_trip(df, pa) + + # python object + df.columns = [ + datetime.datetime(2011, 1, 1, 0, 0), + datetime.datetime(2011, 1, 1, 1, 1), + ] + check_round_trip(df, pa) + + def test_empty_columns(self, pa): + # GH 52034 + df = pd.DataFrame(index=pd.Index(["a", "b", "c"], name="custom name")) + check_round_trip(df, pa) + + def test_df_attrs_persistence(self, tmp_path, pa): + path = tmp_path / "test_df_metadata.p" + df = pd.DataFrame(data={1: [1]}) + df.attrs = {"test_attribute": 1} + df.to_parquet(path, engine=pa) + new_df = read_parquet(path, engine=pa) + assert new_df.attrs == df.attrs + + def test_string_inference(self, tmp_path, pa): + # GH#54431 + path = tmp_path / "test_string_inference.p" + df = pd.DataFrame(data={"a": ["x", "y"]}, index=["a", "b"]) + df.to_parquet(path, engine="pyarrow") + with pd.option_context("future.infer_string", True): + result = read_parquet(path, engine="pyarrow") + expected = pd.DataFrame( + data={"a": ["x", "y"]}, + dtype="string[pyarrow_numpy]", + index=pd.Index(["a", "b"], dtype="string[pyarrow_numpy]"), + ) + tm.assert_frame_equal(result, expected) + + @pytest.mark.skipif(pa_version_under11p0, reason="not supported before 11.0") + def test_roundtrip_decimal(self, tmp_path, pa): + # GH#54768 + import pyarrow as pa + + path = tmp_path / "decimal.p" + df = pd.DataFrame({"a": [Decimal("123.00")]}, dtype="string[pyarrow]") + df.to_parquet(path, schema=pa.schema([("a", pa.decimal128(5))])) + result = read_parquet(path) + expected = pd.DataFrame({"a": ["123"]}, dtype="string[python]") + tm.assert_frame_equal(result, expected) + + def test_infer_string_large_string_type(self, tmp_path, pa): + # GH#54798 + import pyarrow as pa + import pyarrow.parquet as pq + + path = tmp_path / "large_string.p" + + table = pa.table({"a": pa.array([None, "b", "c"], pa.large_string())}) + pq.write_table(table, path) + + with pd.option_context("future.infer_string", True): + result = read_parquet(path) + expected = pd.DataFrame( + data={"a": [None, "b", "c"]}, + dtype="string[pyarrow_numpy]", + columns=pd.Index(["a"], dtype="string[pyarrow_numpy]"), + ) + tm.assert_frame_equal(result, expected) + + # NOTE: this test is not run by default, because it requires a lot of memory (>5GB) + # @pytest.mark.slow + # def test_string_column_above_2GB(self, tmp_path, pa): + # # https://github.com/pandas-dev/pandas/issues/55606 + # # above 2GB of string data + # v1 = b"x" * 100000000 + # v2 = b"x" * 147483646 + # df = pd.DataFrame({"strings": [v1] * 20 + [v2] + ["x"] * 20}, dtype="string") + # df.to_parquet(tmp_path / "test.parquet") + # result = read_parquet(tmp_path / "test.parquet") + # assert result["strings"].dtype == "string" + + +class TestParquetFastParquet(Base): + def test_basic(self, fp, df_full): + df = df_full + + dti = pd.date_range("20130101", periods=3, tz="US/Eastern") + dti = dti._with_freq(None) # freq doesn't round-trip + df["datetime_tz"] = dti + df["timedelta"] = pd.timedelta_range("1 day", periods=3) + check_round_trip(df, fp) + + def test_columns_dtypes_invalid(self, fp): + df = pd.DataFrame({"string": list("abc"), "int": list(range(1, 4))}) + + err = TypeError + msg = "Column name must be a string" + + # numeric + df.columns = [0, 1] + self.check_error_on_write(df, fp, err, msg) + + # bytes + df.columns = [b"foo", b"bar"] + self.check_error_on_write(df, fp, err, msg) + + # python object + df.columns = [ + datetime.datetime(2011, 1, 1, 0, 0), + datetime.datetime(2011, 1, 1, 1, 1), + ] + self.check_error_on_write(df, fp, err, msg) + + def test_duplicate_columns(self, fp): + # not currently able to handle duplicate columns + df = pd.DataFrame(np.arange(12).reshape(4, 3), columns=list("aaa")).copy() + msg = "Cannot create parquet dataset with duplicate column names" + self.check_error_on_write(df, fp, ValueError, msg) + + @pytest.mark.xfail( + Version(np.__version__) >= Version("2.0.0"), + reason="fastparquet uses np.float_ in numpy2", + ) + def test_bool_with_none(self, fp): + df = pd.DataFrame({"a": [True, None, False]}) + expected = pd.DataFrame({"a": [1.0, np.nan, 0.0]}, dtype="float16") + # Fastparquet bug in 0.7.1 makes it so that this dtype becomes + # float64 + check_round_trip(df, fp, expected=expected, check_dtype=False) + + def test_unsupported(self, fp): + # period + df = pd.DataFrame({"a": pd.period_range("2013", freq="M", periods=3)}) + # error from fastparquet -> don't check exact error message + self.check_error_on_write(df, fp, ValueError, None) + + # mixed + df = pd.DataFrame({"a": ["a", 1, 2.0]}) + msg = "Can't infer object conversion type" + self.check_error_on_write(df, fp, ValueError, msg) + + def test_categorical(self, fp): + df = pd.DataFrame({"a": pd.Categorical(list("abc"))}) + check_round_trip(df, fp) + + def test_filter_row_groups(self, fp): + d = {"a": list(range(3))} + df = pd.DataFrame(d) + with tm.ensure_clean() as path: + df.to_parquet(path, engine=fp, compression=None, row_group_offsets=1) + result = read_parquet(path, fp, filters=[("a", "==", 0)]) + assert len(result) == 1 + + @pytest.mark.single_cpu + def test_s3_roundtrip(self, df_compat, s3_public_bucket, fp, s3so): + # GH #19134 + check_round_trip( + df_compat, + fp, + path=f"s3://{s3_public_bucket.name}/fastparquet.parquet", + read_kwargs={"storage_options": s3so}, + write_kwargs={"compression": None, "storage_options": s3so}, + ) + + def test_partition_cols_supported(self, tmp_path, fp, df_full): + # GH #23283 + partition_cols = ["bool", "int"] + df = df_full + df.to_parquet( + tmp_path, + engine="fastparquet", + partition_cols=partition_cols, + compression=None, + ) + assert os.path.exists(tmp_path) + import fastparquet + + actual_partition_cols = fastparquet.ParquetFile(str(tmp_path), False).cats + assert len(actual_partition_cols) == 2 + + def test_partition_cols_string(self, tmp_path, fp, df_full): + # GH #27117 + partition_cols = "bool" + df = df_full + df.to_parquet( + tmp_path, + engine="fastparquet", + partition_cols=partition_cols, + compression=None, + ) + assert os.path.exists(tmp_path) + import fastparquet + + actual_partition_cols = fastparquet.ParquetFile(str(tmp_path), False).cats + assert len(actual_partition_cols) == 1 + + def test_partition_on_supported(self, tmp_path, fp, df_full): + # GH #23283 + partition_cols = ["bool", "int"] + df = df_full + df.to_parquet( + tmp_path, + engine="fastparquet", + compression=None, + partition_on=partition_cols, + ) + assert os.path.exists(tmp_path) + import fastparquet + + actual_partition_cols = fastparquet.ParquetFile(str(tmp_path), False).cats + assert len(actual_partition_cols) == 2 + + def test_error_on_using_partition_cols_and_partition_on( + self, tmp_path, fp, df_full + ): + # GH #23283 + partition_cols = ["bool", "int"] + df = df_full + msg = ( + "Cannot use both partition_on and partition_cols. Use partition_cols for " + "partitioning data" + ) + with pytest.raises(ValueError, match=msg): + df.to_parquet( + tmp_path, + engine="fastparquet", + compression=None, + partition_on=partition_cols, + partition_cols=partition_cols, + ) + + @pytest.mark.skipif(using_copy_on_write(), reason="fastparquet writes into Index") + def test_empty_dataframe(self, fp): + # GH #27339 + df = pd.DataFrame() + expected = df.copy() + check_round_trip(df, fp, expected=expected) + + @pytest.mark.xfail( + _HAVE_FASTPARQUET and Version(fastparquet.__version__) > Version("2022.12"), + reason="fastparquet bug, see https://github.com/dask/fastparquet/issues/929", + ) + def test_timezone_aware_index(self, fp, timezone_aware_date_list): + idx = 5 * [timezone_aware_date_list] + + df = pd.DataFrame(index=idx, data={"index_as_col": idx}) + + expected = df.copy() + expected.index.name = "index" + check_round_trip(df, fp, expected=expected) + + def test_use_nullable_dtypes_not_supported(self, fp): + df = pd.DataFrame({"a": [1, 2]}) + + with tm.ensure_clean() as path: + df.to_parquet(path) + with pytest.raises(ValueError, match="not supported for the fastparquet"): + with tm.assert_produces_warning(FutureWarning): + read_parquet(path, engine="fastparquet", use_nullable_dtypes=True) + with pytest.raises(ValueError, match="not supported for the fastparquet"): + read_parquet(path, engine="fastparquet", dtype_backend="pyarrow") + + def test_close_file_handle_on_read_error(self): + with tm.ensure_clean("test.parquet") as path: + pathlib.Path(path).write_bytes(b"breakit") + with pytest.raises(Exception, match=""): # Not important which exception + read_parquet(path, engine="fastparquet") + # The next line raises an error on Windows if the file is still open + pathlib.Path(path).unlink(missing_ok=False) + + def test_bytes_file_name(self, engine): + # GH#48944 + df = pd.DataFrame(data={"A": [0, 1], "B": [1, 0]}) + with tm.ensure_clean("test.parquet") as path: + with open(path.encode(), "wb") as f: + df.to_parquet(f) + + result = read_parquet(path, engine=engine) + tm.assert_frame_equal(result, df) + + def test_filesystem_notimplemented(self): + pytest.importorskip("fastparquet") + df = pd.DataFrame(data={"A": [0, 1], "B": [1, 0]}) + with tm.ensure_clean() as path: + with pytest.raises( + NotImplementedError, match="filesystem is not implemented" + ): + df.to_parquet(path, engine="fastparquet", filesystem="foo") + + with tm.ensure_clean() as path: + pathlib.Path(path).write_bytes(b"foo") + with pytest.raises( + NotImplementedError, match="filesystem is not implemented" + ): + read_parquet(path, engine="fastparquet", filesystem="foo") + + def test_invalid_filesystem(self): + pytest.importorskip("pyarrow") + df = pd.DataFrame(data={"A": [0, 1], "B": [1, 0]}) + with tm.ensure_clean() as path: + with pytest.raises( + ValueError, match="filesystem must be a pyarrow or fsspec FileSystem" + ): + df.to_parquet(path, engine="pyarrow", filesystem="foo") + + with tm.ensure_clean() as path: + pathlib.Path(path).write_bytes(b"foo") + with pytest.raises( + ValueError, match="filesystem must be a pyarrow or fsspec FileSystem" + ): + read_parquet(path, engine="pyarrow", filesystem="foo") + + def test_unsupported_pa_filesystem_storage_options(self): + pa_fs = pytest.importorskip("pyarrow.fs") + df = pd.DataFrame(data={"A": [0, 1], "B": [1, 0]}) + with tm.ensure_clean() as path: + with pytest.raises( + NotImplementedError, + match="storage_options not supported with a pyarrow FileSystem.", + ): + df.to_parquet( + path, + engine="pyarrow", + filesystem=pa_fs.LocalFileSystem(), + storage_options={"foo": "bar"}, + ) + + with tm.ensure_clean() as path: + pathlib.Path(path).write_bytes(b"foo") + with pytest.raises( + NotImplementedError, + match="storage_options not supported with a pyarrow FileSystem.", + ): + read_parquet( + path, + engine="pyarrow", + filesystem=pa_fs.LocalFileSystem(), + storage_options={"foo": "bar"}, + ) + + def test_invalid_dtype_backend(self, engine): + msg = ( + "dtype_backend numpy is invalid, only 'numpy_nullable' and " + "'pyarrow' are allowed." + ) + df = pd.DataFrame({"int": list(range(1, 4))}) + with tm.ensure_clean("tmp.parquet") as path: + df.to_parquet(path) + with pytest.raises(ValueError, match=msg): + read_parquet(path, dtype_backend="numpy") + + @pytest.mark.skipif(using_copy_on_write(), reason="fastparquet writes into Index") + def test_empty_columns(self, fp): + # GH 52034 + df = pd.DataFrame(index=pd.Index(["a", "b", "c"], name="custom name")) + expected = pd.DataFrame(index=pd.Index(["a", "b", "c"], name="custom name")) + check_round_trip(df, fp, expected=expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_pickle.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_pickle.py new file mode 100644 index 0000000000000000000000000000000000000000..4f3993a038197e52c7f21fb4f4d40425e897600f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_pickle.py @@ -0,0 +1,652 @@ +""" +manage legacy pickle tests + +How to add pickle tests: + +1. Install pandas version intended to output the pickle. + +2. Execute "generate_legacy_storage_files.py" to create the pickle. +$ python generate_legacy_storage_files.py pickle + +3. Move the created pickle to "data/legacy_pickle/" directory. +""" +from __future__ import annotations + +from array import array +import bz2 +import datetime +import functools +from functools import partial +import gzip +import io +import os +from pathlib import Path +import pickle +import shutil +import tarfile +from typing import Any +import uuid +import zipfile + +import numpy as np +import pytest + +from pandas.compat import ( + get_lzma_file, + is_platform_little_endian, +) +from pandas.compat._optional import import_optional_dependency +from pandas.compat.compressors import flatten_buffer +import pandas.util._test_decorators as td + +import pandas as pd +from pandas import ( + DataFrame, + Index, + Series, + period_range, +) +import pandas._testing as tm +from pandas.tests.io.generate_legacy_storage_files import create_pickle_data + +import pandas.io.common as icom +from pandas.tseries.offsets import ( + Day, + MonthEnd, +) + + +# --------------------- +# comparison functions +# --------------------- +def compare_element(result, expected, typ): + if isinstance(expected, Index): + tm.assert_index_equal(expected, result) + return + + if typ.startswith("sp_"): + tm.assert_equal(result, expected) + elif typ == "timestamp": + if expected is pd.NaT: + assert result is pd.NaT + else: + assert result == expected + else: + comparator = getattr(tm, f"assert_{typ}_equal", tm.assert_almost_equal) + comparator(result, expected) + + +# --------------------- +# tests +# --------------------- + + +@pytest.mark.parametrize( + "data", + [ + b"123", + b"123456", + bytearray(b"123"), + memoryview(b"123"), + pickle.PickleBuffer(b"123"), + array("I", [1, 2, 3]), + memoryview(b"123456").cast("B", (3, 2)), + memoryview(b"123456").cast("B", (3, 2))[::2], + np.arange(12).reshape((3, 4), order="C"), + np.arange(12).reshape((3, 4), order="F"), + np.arange(12).reshape((3, 4), order="C")[:, ::2], + ], +) +def test_flatten_buffer(data): + result = flatten_buffer(data) + expected = memoryview(data).tobytes("A") + assert result == expected + if isinstance(data, (bytes, bytearray)): + assert result is data + elif isinstance(result, memoryview): + assert result.ndim == 1 + assert result.format == "B" + assert result.contiguous + assert result.shape == (result.nbytes,) + + +def test_pickles(datapath): + if not is_platform_little_endian(): + pytest.skip("known failure on non-little endian") + + # For loop for compat with --strict-data-files + for legacy_pickle in Path(__file__).parent.glob("data/legacy_pickle/*/*.p*kl*"): + legacy_pickle = datapath(legacy_pickle) + + data = pd.read_pickle(legacy_pickle) + + for typ, dv in data.items(): + for dt, result in dv.items(): + expected = data[typ][dt] + + if typ == "series" and dt == "ts": + # GH 7748 + tm.assert_series_equal(result, expected) + assert result.index.freq == expected.index.freq + assert not result.index.freq.normalize + tm.assert_series_equal(result > 0, expected > 0) + + # GH 9291 + freq = result.index.freq + assert freq + Day(1) == Day(2) + + res = freq + pd.Timedelta(hours=1) + assert isinstance(res, pd.Timedelta) + assert res == pd.Timedelta(days=1, hours=1) + + res = freq + pd.Timedelta(nanoseconds=1) + assert isinstance(res, pd.Timedelta) + assert res == pd.Timedelta(days=1, nanoseconds=1) + elif typ == "index" and dt == "period": + tm.assert_index_equal(result, expected) + assert isinstance(result.freq, MonthEnd) + assert result.freq == MonthEnd() + assert result.freqstr == "M" + tm.assert_index_equal(result.shift(2), expected.shift(2)) + elif typ == "series" and dt in ("dt_tz", "cat"): + tm.assert_series_equal(result, expected) + elif typ == "frame" and dt in ( + "dt_mixed_tzs", + "cat_onecol", + "cat_and_float", + ): + tm.assert_frame_equal(result, expected) + else: + compare_element(result, expected, typ) + + +def python_pickler(obj, path): + with open(path, "wb") as fh: + pickle.dump(obj, fh, protocol=-1) + + +def python_unpickler(path): + with open(path, "rb") as fh: + fh.seek(0) + return pickle.load(fh) + + +def flatten(data: dict) -> list[tuple[str, Any]]: + """Flatten create_pickle_data""" + return [ + (typ, example) + for typ, examples in data.items() + for example in examples.values() + ] + + +@pytest.mark.parametrize( + "pickle_writer", + [ + pytest.param(python_pickler, id="python"), + pytest.param(pd.to_pickle, id="pandas_proto_default"), + pytest.param( + functools.partial(pd.to_pickle, protocol=pickle.HIGHEST_PROTOCOL), + id="pandas_proto_highest", + ), + pytest.param(functools.partial(pd.to_pickle, protocol=4), id="pandas_proto_4"), + pytest.param( + functools.partial(pd.to_pickle, protocol=5), + id="pandas_proto_5", + ), + ], +) +@pytest.mark.parametrize("writer", [pd.to_pickle, python_pickler]) +@pytest.mark.parametrize("typ, expected", flatten(create_pickle_data())) +def test_round_trip_current(typ, expected, pickle_writer, writer): + with tm.ensure_clean() as path: + # test writing with each pickler + pickle_writer(expected, path) + + # test reading with each unpickler + result = pd.read_pickle(path) + compare_element(result, expected, typ) + + result = python_unpickler(path) + compare_element(result, expected, typ) + + # and the same for file objects (GH 35679) + with open(path, mode="wb") as handle: + writer(expected, path) + handle.seek(0) # shouldn't close file handle + with open(path, mode="rb") as handle: + result = pd.read_pickle(handle) + handle.seek(0) # shouldn't close file handle + compare_element(result, expected, typ) + + +def test_pickle_path_pathlib(): + df = DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=Index(list("ABCD"), dtype=object), + index=Index([f"i-{i}" for i in range(30)], dtype=object), + ) + result = tm.round_trip_pathlib(df.to_pickle, pd.read_pickle) + tm.assert_frame_equal(df, result) + + +def test_pickle_path_localpath(): + df = DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=Index(list("ABCD"), dtype=object), + index=Index([f"i-{i}" for i in range(30)], dtype=object), + ) + result = tm.round_trip_localpath(df.to_pickle, pd.read_pickle) + tm.assert_frame_equal(df, result) + + +# --------------------- +# test pickle compression +# --------------------- + + +@pytest.fixture +def get_random_path(): + return f"__{uuid.uuid4()}__.pickle" + + +class TestCompression: + _extension_to_compression = icom.extension_to_compression + + def compress_file(self, src_path, dest_path, compression): + if compression is None: + shutil.copyfile(src_path, dest_path) + return + + if compression == "gzip": + f = gzip.open(dest_path, "w") + elif compression == "bz2": + f = bz2.BZ2File(dest_path, "w") + elif compression == "zip": + with zipfile.ZipFile(dest_path, "w", compression=zipfile.ZIP_DEFLATED) as f: + f.write(src_path, os.path.basename(src_path)) + elif compression == "tar": + with open(src_path, "rb") as fh: + with tarfile.open(dest_path, mode="w") as tar: + tarinfo = tar.gettarinfo(src_path, os.path.basename(src_path)) + tar.addfile(tarinfo, fh) + elif compression == "xz": + f = get_lzma_file()(dest_path, "w") + elif compression == "zstd": + f = import_optional_dependency("zstandard").open(dest_path, "wb") + else: + msg = f"Unrecognized compression type: {compression}" + raise ValueError(msg) + + if compression not in ["zip", "tar"]: + with open(src_path, "rb") as fh: + with f: + f.write(fh.read()) + + def test_write_explicit(self, compression, get_random_path): + base = get_random_path + path1 = base + ".compressed" + path2 = base + ".raw" + + with tm.ensure_clean(path1) as p1, tm.ensure_clean(path2) as p2: + df = DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=Index(list("ABCD"), dtype=object), + index=Index([f"i-{i}" for i in range(30)], dtype=object), + ) + + # write to compressed file + df.to_pickle(p1, compression=compression) + + # decompress + with tm.decompress_file(p1, compression=compression) as f: + with open(p2, "wb") as fh: + fh.write(f.read()) + + # read decompressed file + df2 = pd.read_pickle(p2, compression=None) + + tm.assert_frame_equal(df, df2) + + @pytest.mark.parametrize("compression", ["", "None", "bad", "7z"]) + def test_write_explicit_bad(self, compression, get_random_path): + with pytest.raises(ValueError, match="Unrecognized compression type"): + with tm.ensure_clean(get_random_path) as path: + df = DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=Index(list("ABCD"), dtype=object), + index=Index([f"i-{i}" for i in range(30)], dtype=object), + ) + df.to_pickle(path, compression=compression) + + def test_write_infer(self, compression_ext, get_random_path): + base = get_random_path + path1 = base + compression_ext + path2 = base + ".raw" + compression = self._extension_to_compression.get(compression_ext.lower()) + + with tm.ensure_clean(path1) as p1, tm.ensure_clean(path2) as p2: + df = DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=Index(list("ABCD"), dtype=object), + index=Index([f"i-{i}" for i in range(30)], dtype=object), + ) + + # write to compressed file by inferred compression method + df.to_pickle(p1) + + # decompress + with tm.decompress_file(p1, compression=compression) as f: + with open(p2, "wb") as fh: + fh.write(f.read()) + + # read decompressed file + df2 = pd.read_pickle(p2, compression=None) + + tm.assert_frame_equal(df, df2) + + def test_read_explicit(self, compression, get_random_path): + base = get_random_path + path1 = base + ".raw" + path2 = base + ".compressed" + + with tm.ensure_clean(path1) as p1, tm.ensure_clean(path2) as p2: + df = DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=Index(list("ABCD"), dtype=object), + index=Index([f"i-{i}" for i in range(30)], dtype=object), + ) + + # write to uncompressed file + df.to_pickle(p1, compression=None) + + # compress + self.compress_file(p1, p2, compression=compression) + + # read compressed file + df2 = pd.read_pickle(p2, compression=compression) + tm.assert_frame_equal(df, df2) + + def test_read_infer(self, compression_ext, get_random_path): + base = get_random_path + path1 = base + ".raw" + path2 = base + compression_ext + compression = self._extension_to_compression.get(compression_ext.lower()) + + with tm.ensure_clean(path1) as p1, tm.ensure_clean(path2) as p2: + df = DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=Index(list("ABCD"), dtype=object), + index=Index([f"i-{i}" for i in range(30)], dtype=object), + ) + + # write to uncompressed file + df.to_pickle(p1, compression=None) + + # compress + self.compress_file(p1, p2, compression=compression) + + # read compressed file by inferred compression method + df2 = pd.read_pickle(p2) + tm.assert_frame_equal(df, df2) + + +# --------------------- +# test pickle compression +# --------------------- + + +class TestProtocol: + @pytest.mark.parametrize("protocol", [-1, 0, 1, 2]) + def test_read(self, protocol, get_random_path): + with tm.ensure_clean(get_random_path) as path: + df = DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=Index(list("ABCD"), dtype=object), + index=Index([f"i-{i}" for i in range(30)], dtype=object), + ) + df.to_pickle(path, protocol=protocol) + df2 = pd.read_pickle(path) + tm.assert_frame_equal(df, df2) + + +@pytest.mark.parametrize( + ["pickle_file", "excols"], + [ + ("test_py27.pkl", Index(["a", "b", "c"])), + ( + "test_mi_py27.pkl", + pd.MultiIndex.from_arrays([["a", "b", "c"], ["A", "B", "C"]]), + ), + ], +) +def test_unicode_decode_error(datapath, pickle_file, excols): + # pickle file written with py27, should be readable without raising + # UnicodeDecodeError, see GH#28645 and GH#31988 + path = datapath("io", "data", "pickle", pickle_file) + df = pd.read_pickle(path) + + # just test the columns are correct since the values are random + tm.assert_index_equal(df.columns, excols) + + +# --------------------- +# tests for buffer I/O +# --------------------- + + +def test_pickle_buffer_roundtrip(): + with tm.ensure_clean() as path: + df = DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=Index(list("ABCD"), dtype=object), + index=Index([f"i-{i}" for i in range(30)], dtype=object), + ) + with open(path, "wb") as fh: + df.to_pickle(fh) + with open(path, "rb") as fh: + result = pd.read_pickle(fh) + tm.assert_frame_equal(df, result) + + +# --------------------- +# tests for URL I/O +# --------------------- + + +@pytest.mark.parametrize( + "mockurl", ["http://url.com", "ftp://test.com", "http://gzip.com"] +) +def test_pickle_generalurl_read(monkeypatch, mockurl): + def python_pickler(obj, path): + with open(path, "wb") as fh: + pickle.dump(obj, fh, protocol=-1) + + class MockReadResponse: + def __init__(self, path) -> None: + self.file = open(path, "rb") + if "gzip" in path: + self.headers = {"Content-Encoding": "gzip"} + else: + self.headers = {"Content-Encoding": ""} + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + def read(self): + return self.file.read() + + def close(self): + return self.file.close() + + with tm.ensure_clean() as path: + + def mock_urlopen_read(*args, **kwargs): + return MockReadResponse(path) + + df = DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=Index(list("ABCD"), dtype=object), + index=Index([f"i-{i}" for i in range(30)], dtype=object), + ) + python_pickler(df, path) + monkeypatch.setattr("urllib.request.urlopen", mock_urlopen_read) + result = pd.read_pickle(mockurl) + tm.assert_frame_equal(df, result) + + +def test_pickle_fsspec_roundtrip(): + pytest.importorskip("fsspec") + with tm.ensure_clean(): + mockurl = "memory://mockfile" + df = DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=Index(list("ABCD"), dtype=object), + index=Index([f"i-{i}" for i in range(30)], dtype=object), + ) + df.to_pickle(mockurl) + result = pd.read_pickle(mockurl) + tm.assert_frame_equal(df, result) + + +class MyTz(datetime.tzinfo): + def __init__(self) -> None: + pass + + +def test_read_pickle_with_subclass(): + # GH 12163 + expected = Series(dtype=object), MyTz() + result = tm.round_trip_pickle(expected) + + tm.assert_series_equal(result[0], expected[0]) + assert isinstance(result[1], MyTz) + + +def test_pickle_binary_object_compression(compression): + """ + Read/write from binary file-objects w/wo compression. + + GH 26237, GH 29054, and GH 29570 + """ + df = DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=Index(list("ABCD"), dtype=object), + index=Index([f"i-{i}" for i in range(30)], dtype=object), + ) + + # reference for compression + with tm.ensure_clean() as path: + df.to_pickle(path, compression=compression) + reference = Path(path).read_bytes() + + # write + buffer = io.BytesIO() + df.to_pickle(buffer, compression=compression) + buffer.seek(0) + + # gzip and zip safe the filename: cannot compare the compressed content + assert buffer.getvalue() == reference or compression in ("gzip", "zip", "tar") + + # read + read_df = pd.read_pickle(buffer, compression=compression) + buffer.seek(0) + tm.assert_frame_equal(df, read_df) + + +def test_pickle_dataframe_with_multilevel_index( + multiindex_year_month_day_dataframe_random_data, + multiindex_dataframe_random_data, +): + ymd = multiindex_year_month_day_dataframe_random_data + frame = multiindex_dataframe_random_data + + def _test_roundtrip(frame): + unpickled = tm.round_trip_pickle(frame) + tm.assert_frame_equal(frame, unpickled) + + _test_roundtrip(frame) + _test_roundtrip(frame.T) + _test_roundtrip(ymd) + _test_roundtrip(ymd.T) + + +def test_pickle_timeseries_periodindex(): + # GH#2891 + prng = period_range("1/1/2011", "1/1/2012", freq="M") + ts = Series(np.random.default_rng(2).standard_normal(len(prng)), prng) + new_ts = tm.round_trip_pickle(ts) + assert new_ts.index.freqstr == "M" + + +@pytest.mark.parametrize( + "name", [777, 777.0, "name", datetime.datetime(2001, 11, 11), (1, 2)] +) +def test_pickle_preserve_name(name): + unpickled = tm.round_trip_pickle(Series(np.arange(10, dtype=np.float64), name=name)) + assert unpickled.name == name + + +def test_pickle_datetimes(datetime_series): + unp_ts = tm.round_trip_pickle(datetime_series) + tm.assert_series_equal(unp_ts, datetime_series) + + +def test_pickle_strings(string_series): + unp_series = tm.round_trip_pickle(string_series) + tm.assert_series_equal(unp_series, string_series) + + +@td.skip_array_manager_invalid_test +def test_pickle_preserves_block_ndim(): + # GH#37631 + ser = Series(list("abc")).astype("category").iloc[[0]] + res = tm.round_trip_pickle(ser) + + assert res._mgr.blocks[0].ndim == 1 + assert res._mgr.blocks[0].shape == (1,) + + # GH#37631 OP issue was about indexing, underlying problem was pickle + tm.assert_series_equal(res[[True]], ser) + + +@pytest.mark.parametrize("protocol", [pickle.DEFAULT_PROTOCOL, pickle.HIGHEST_PROTOCOL]) +def test_pickle_big_dataframe_compression(protocol, compression): + # GH#39002 + df = DataFrame(range(100000)) + result = tm.round_trip_pathlib( + partial(df.to_pickle, protocol=protocol, compression=compression), + partial(pd.read_pickle, compression=compression), + ) + tm.assert_frame_equal(df, result) + + +def test_pickle_frame_v124_unpickle_130(datapath): + # GH#42345 DataFrame created in 1.2.x, unpickle in 1.3.x + path = datapath( + Path(__file__).parent, + "data", + "legacy_pickle", + "1.2.4", + "empty_frame_v1_2_4-GH#42345.pkl", + ) + with open(path, "rb") as fd: + df = pickle.load(fd) + + expected = DataFrame(index=[], columns=[]) + tm.assert_frame_equal(df, expected) + + +def test_pickle_pos_args_deprecation(): + # GH-54229 + df = DataFrame({"a": [1, 2, 3]}) + msg = ( + r"Starting with pandas version 3.0 all arguments of to_pickle except for the " + r"argument 'path' will be keyword-only." + ) + with tm.assert_produces_warning(FutureWarning, match=msg): + buffer = io.BytesIO() + df.to_pickle(buffer, "infer") diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_s3.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_s3.py new file mode 100644 index 0000000000000000000000000000000000000000..79473895b662da6af68fbe29a60eb05f134a54df --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_s3.py @@ -0,0 +1,43 @@ +from io import BytesIO + +import pytest + +from pandas import read_csv + + +def test_streaming_s3_objects(): + # GH17135 + # botocore gained iteration support in 1.10.47, can now be used in read_* + pytest.importorskip("botocore", minversion="1.10.47") + from botocore.response import StreamingBody + + data = [b"foo,bar,baz\n1,2,3\n4,5,6\n", b"just,the,header\n"] + for el in data: + body = StreamingBody(BytesIO(el), content_length=len(el)) + read_csv(body) + + +@pytest.mark.single_cpu +def test_read_without_creds_from_pub_bucket(s3_public_bucket_with_data, s3so): + # GH 34626 + pytest.importorskip("s3fs") + result = read_csv( + f"s3://{s3_public_bucket_with_data.name}/tips.csv", + nrows=3, + storage_options=s3so, + ) + assert len(result) == 3 + + +@pytest.mark.single_cpu +def test_read_with_creds_from_pub_bucket(s3_public_bucket_with_data, s3so): + # Ensure we can read from a public bucket with credentials + # GH 34626 + pytest.importorskip("s3fs") + df = read_csv( + f"s3://{s3_public_bucket_with_data.name}/tips.csv", + nrows=5, + header=None, + storage_options=s3so, + ) + assert len(df) == 5 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_spss.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_spss.py new file mode 100644 index 0000000000000000000000000000000000000000..e118c90d9bc02041719cd1452b5af8e77b12db77 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_spss.py @@ -0,0 +1,164 @@ +import datetime +from pathlib import Path + +import numpy as np +import pytest + +import pandas as pd +import pandas._testing as tm +from pandas.util.version import Version + +pyreadstat = pytest.importorskip("pyreadstat") + + +# TODO(CoW) - detection of chained assignment in cython +# https://github.com/pandas-dev/pandas/issues/51315 +@pytest.mark.filterwarnings("ignore::pandas.errors.ChainedAssignmentError") +@pytest.mark.filterwarnings("ignore:ChainedAssignmentError:FutureWarning") +@pytest.mark.parametrize("path_klass", [lambda p: p, Path]) +def test_spss_labelled_num(path_klass, datapath): + # test file from the Haven project (https://haven.tidyverse.org/) + # Licence at LICENSES/HAVEN_LICENSE, LICENSES/HAVEN_MIT + fname = path_klass(datapath("io", "data", "spss", "labelled-num.sav")) + + df = pd.read_spss(fname, convert_categoricals=True) + expected = pd.DataFrame({"VAR00002": "This is one"}, index=[0]) + expected["VAR00002"] = pd.Categorical(expected["VAR00002"]) + tm.assert_frame_equal(df, expected) + + df = pd.read_spss(fname, convert_categoricals=False) + expected = pd.DataFrame({"VAR00002": 1.0}, index=[0]) + tm.assert_frame_equal(df, expected) + + +@pytest.mark.filterwarnings("ignore::pandas.errors.ChainedAssignmentError") +@pytest.mark.filterwarnings("ignore:ChainedAssignmentError:FutureWarning") +def test_spss_labelled_num_na(datapath): + # test file from the Haven project (https://haven.tidyverse.org/) + # Licence at LICENSES/HAVEN_LICENSE, LICENSES/HAVEN_MIT + fname = datapath("io", "data", "spss", "labelled-num-na.sav") + + df = pd.read_spss(fname, convert_categoricals=True) + expected = pd.DataFrame({"VAR00002": ["This is one", None]}) + expected["VAR00002"] = pd.Categorical(expected["VAR00002"]) + tm.assert_frame_equal(df, expected) + + df = pd.read_spss(fname, convert_categoricals=False) + expected = pd.DataFrame({"VAR00002": [1.0, np.nan]}) + tm.assert_frame_equal(df, expected) + + +@pytest.mark.filterwarnings("ignore::pandas.errors.ChainedAssignmentError") +@pytest.mark.filterwarnings("ignore:ChainedAssignmentError:FutureWarning") +def test_spss_labelled_str(datapath): + # test file from the Haven project (https://haven.tidyverse.org/) + # Licence at LICENSES/HAVEN_LICENSE, LICENSES/HAVEN_MIT + fname = datapath("io", "data", "spss", "labelled-str.sav") + + df = pd.read_spss(fname, convert_categoricals=True) + expected = pd.DataFrame({"gender": ["Male", "Female"]}) + expected["gender"] = pd.Categorical(expected["gender"]) + tm.assert_frame_equal(df, expected) + + df = pd.read_spss(fname, convert_categoricals=False) + expected = pd.DataFrame({"gender": ["M", "F"]}) + tm.assert_frame_equal(df, expected) + + +@pytest.mark.filterwarnings("ignore::pandas.errors.ChainedAssignmentError") +@pytest.mark.filterwarnings("ignore:ChainedAssignmentError:FutureWarning") +def test_spss_umlauts(datapath): + # test file from the Haven project (https://haven.tidyverse.org/) + # Licence at LICENSES/HAVEN_LICENSE, LICENSES/HAVEN_MIT + fname = datapath("io", "data", "spss", "umlauts.sav") + + df = pd.read_spss(fname, convert_categoricals=True) + expected = pd.DataFrame( + {"var1": ["the ä umlaut", "the ü umlaut", "the ä umlaut", "the ö umlaut"]} + ) + expected["var1"] = pd.Categorical(expected["var1"]) + tm.assert_frame_equal(df, expected) + + df = pd.read_spss(fname, convert_categoricals=False) + expected = pd.DataFrame({"var1": [1.0, 2.0, 1.0, 3.0]}) + tm.assert_frame_equal(df, expected) + + +def test_spss_usecols(datapath): + # usecols must be list-like + fname = datapath("io", "data", "spss", "labelled-num.sav") + + with pytest.raises(TypeError, match="usecols must be list-like."): + pd.read_spss(fname, usecols="VAR00002") + + +def test_spss_umlauts_dtype_backend(datapath, dtype_backend): + # test file from the Haven project (https://haven.tidyverse.org/) + # Licence at LICENSES/HAVEN_LICENSE, LICENSES/HAVEN_MIT + fname = datapath("io", "data", "spss", "umlauts.sav") + + df = pd.read_spss(fname, convert_categoricals=False, dtype_backend=dtype_backend) + expected = pd.DataFrame({"var1": [1.0, 2.0, 1.0, 3.0]}, dtype="Int64") + + if dtype_backend == "pyarrow": + pa = pytest.importorskip("pyarrow") + + from pandas.arrays import ArrowExtensionArray + + expected = pd.DataFrame( + { + col: ArrowExtensionArray(pa.array(expected[col], from_pandas=True)) + for col in expected.columns + } + ) + + tm.assert_frame_equal(df, expected) + + +def test_invalid_dtype_backend(): + msg = ( + "dtype_backend numpy is invalid, only 'numpy_nullable' and " + "'pyarrow' are allowed." + ) + with pytest.raises(ValueError, match=msg): + pd.read_spss("test", dtype_backend="numpy") + + +@pytest.mark.filterwarnings("ignore::pandas.errors.ChainedAssignmentError") +@pytest.mark.filterwarnings("ignore:ChainedAssignmentError:FutureWarning") +def test_spss_metadata(datapath): + # GH 54264 + fname = datapath("io", "data", "spss", "labelled-num.sav") + + df = pd.read_spss(fname) + metadata = { + "column_names": ["VAR00002"], + "column_labels": [None], + "column_names_to_labels": {"VAR00002": None}, + "file_encoding": "UTF-8", + "number_columns": 1, + "number_rows": 1, + "variable_value_labels": {"VAR00002": {1.0: "This is one"}}, + "value_labels": {"labels0": {1.0: "This is one"}}, + "variable_to_label": {"VAR00002": "labels0"}, + "notes": [], + "original_variable_types": {"VAR00002": "F8.0"}, + "readstat_variable_types": {"VAR00002": "double"}, + "table_name": None, + "missing_ranges": {}, + "missing_user_values": {}, + "variable_storage_width": {"VAR00002": 8}, + "variable_display_width": {"VAR00002": 8}, + "variable_alignment": {"VAR00002": "unknown"}, + "variable_measure": {"VAR00002": "unknown"}, + "file_label": None, + "file_format": "sav/zsav", + } + if Version(pyreadstat.__version__) >= Version("1.2.4"): + metadata.update( + { + "creation_time": datetime.datetime(2015, 2, 6, 14, 33, 36), + "modification_time": datetime.datetime(2015, 2, 6, 14, 33, 36), + } + ) + assert df.attrs == metadata diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_sql.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_sql.py new file mode 100644 index 0000000000000000000000000000000000000000..7068247bbfa8bcab937736e8a93b7299c1ba17d6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_sql.py @@ -0,0 +1,4395 @@ +from __future__ import annotations + +import contextlib +from contextlib import closing +import csv +from datetime import ( + date, + datetime, + time, + timedelta, +) +from io import StringIO +from pathlib import Path +import sqlite3 +from typing import TYPE_CHECKING +import uuid + +import numpy as np +import pytest + +from pandas._libs import lib +from pandas.compat import ( + pa_version_under13p0, + pa_version_under14p1, +) +from pandas.compat._optional import import_optional_dependency +import pandas.util._test_decorators as td + +import pandas as pd +from pandas import ( + DataFrame, + Index, + MultiIndex, + Series, + Timestamp, + concat, + date_range, + isna, + to_datetime, + to_timedelta, +) +import pandas._testing as tm +from pandas.core.arrays import ( + ArrowStringArray, + StringArray, +) +from pandas.util.version import Version + +from pandas.io import sql +from pandas.io.sql import ( + SQLAlchemyEngine, + SQLDatabase, + SQLiteDatabase, + get_engine, + pandasSQL_builder, + read_sql_query, + read_sql_table, +) + +if TYPE_CHECKING: + import sqlalchemy + + +pytestmark = pytest.mark.filterwarnings( + "ignore:Passing a BlockManager to DataFrame:DeprecationWarning" +) + + +@pytest.fixture +def sql_strings(): + return { + "read_parameters": { + "sqlite": "SELECT * FROM iris WHERE Name=? AND SepalLength=?", + "mysql": "SELECT * FROM iris WHERE `Name`=%s AND `SepalLength`=%s", + "postgresql": 'SELECT * FROM iris WHERE "Name"=%s AND "SepalLength"=%s', + }, + "read_named_parameters": { + "sqlite": """ + SELECT * FROM iris WHERE Name=:name AND SepalLength=:length + """, + "mysql": """ + SELECT * FROM iris WHERE + `Name`=%(name)s AND `SepalLength`=%(length)s + """, + "postgresql": """ + SELECT * FROM iris WHERE + "Name"=%(name)s AND "SepalLength"=%(length)s + """, + }, + "read_no_parameters_with_percent": { + "sqlite": "SELECT * FROM iris WHERE Name LIKE '%'", + "mysql": "SELECT * FROM iris WHERE `Name` LIKE '%'", + "postgresql": "SELECT * FROM iris WHERE \"Name\" LIKE '%'", + }, + } + + +def iris_table_metadata(): + import sqlalchemy + from sqlalchemy import ( + Column, + Double, + Float, + MetaData, + String, + Table, + ) + + dtype = Double if Version(sqlalchemy.__version__) >= Version("2.0.0") else Float + metadata = MetaData() + iris = Table( + "iris", + metadata, + Column("SepalLength", dtype), + Column("SepalWidth", dtype), + Column("PetalLength", dtype), + Column("PetalWidth", dtype), + Column("Name", String(200)), + ) + return iris + + +def create_and_load_iris_sqlite3(conn, iris_file: Path): + stmt = """CREATE TABLE iris ( + "SepalLength" REAL, + "SepalWidth" REAL, + "PetalLength" REAL, + "PetalWidth" REAL, + "Name" TEXT + )""" + + cur = conn.cursor() + cur.execute(stmt) + with iris_file.open(newline=None, encoding="utf-8") as csvfile: + reader = csv.reader(csvfile) + next(reader) + stmt = "INSERT INTO iris VALUES(?, ?, ?, ?, ?)" + # ADBC requires explicit types - no implicit str -> float conversion + records = [] + records = [ + ( + float(row[0]), + float(row[1]), + float(row[2]), + float(row[3]), + row[4], + ) + for row in reader + ] + + cur.executemany(stmt, records) + cur.close() + + conn.commit() + + +def create_and_load_iris_postgresql(conn, iris_file: Path): + stmt = """CREATE TABLE iris ( + "SepalLength" DOUBLE PRECISION, + "SepalWidth" DOUBLE PRECISION, + "PetalLength" DOUBLE PRECISION, + "PetalWidth" DOUBLE PRECISION, + "Name" TEXT + )""" + with conn.cursor() as cur: + cur.execute(stmt) + with iris_file.open(newline=None, encoding="utf-8") as csvfile: + reader = csv.reader(csvfile) + next(reader) + stmt = "INSERT INTO iris VALUES($1, $2, $3, $4, $5)" + # ADBC requires explicit types - no implicit str -> float conversion + records = [ + ( + float(row[0]), + float(row[1]), + float(row[2]), + float(row[3]), + row[4], + ) + for row in reader + ] + + cur.executemany(stmt, records) + + conn.commit() + + +def create_and_load_iris(conn, iris_file: Path): + from sqlalchemy import insert + + iris = iris_table_metadata() + + with iris_file.open(newline=None, encoding="utf-8") as csvfile: + reader = csv.reader(csvfile) + header = next(reader) + params = [dict(zip(header, row)) for row in reader] + stmt = insert(iris).values(params) + with conn.begin() as con: + iris.drop(con, checkfirst=True) + iris.create(bind=con) + con.execute(stmt) + + +def create_and_load_iris_view(conn): + stmt = "CREATE VIEW iris_view AS SELECT * FROM iris" + if isinstance(conn, sqlite3.Connection): + cur = conn.cursor() + cur.execute(stmt) + else: + adbc = import_optional_dependency("adbc_driver_manager.dbapi", errors="ignore") + if adbc and isinstance(conn, adbc.Connection): + with conn.cursor() as cur: + cur.execute(stmt) + conn.commit() + else: + from sqlalchemy import text + + stmt = text(stmt) + with conn.begin() as con: + con.execute(stmt) + + +def types_table_metadata(dialect: str): + from sqlalchemy import ( + TEXT, + Boolean, + Column, + DateTime, + Float, + Integer, + MetaData, + Table, + ) + + date_type = TEXT if dialect == "sqlite" else DateTime + bool_type = Integer if dialect == "sqlite" else Boolean + metadata = MetaData() + types = Table( + "types", + metadata, + Column("TextCol", TEXT), + Column("DateCol", date_type), + Column("IntDateCol", Integer), + Column("IntDateOnlyCol", Integer), + Column("FloatCol", Float), + Column("IntCol", Integer), + Column("BoolCol", bool_type), + Column("IntColWithNull", Integer), + Column("BoolColWithNull", bool_type), + ) + return types + + +def create_and_load_types_sqlite3(conn, types_data: list[dict]): + stmt = """CREATE TABLE types ( + "TextCol" TEXT, + "DateCol" TEXT, + "IntDateCol" INTEGER, + "IntDateOnlyCol" INTEGER, + "FloatCol" REAL, + "IntCol" INTEGER, + "BoolCol" INTEGER, + "IntColWithNull" INTEGER, + "BoolColWithNull" INTEGER + )""" + + ins_stmt = """ + INSERT INTO types + VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?) + """ + + if isinstance(conn, sqlite3.Connection): + cur = conn.cursor() + cur.execute(stmt) + cur.executemany(ins_stmt, types_data) + else: + with conn.cursor() as cur: + cur.execute(stmt) + cur.executemany(ins_stmt, types_data) + + conn.commit() + + +def create_and_load_types_postgresql(conn, types_data: list[dict]): + with conn.cursor() as cur: + stmt = """CREATE TABLE types ( + "TextCol" TEXT, + "DateCol" TIMESTAMP, + "IntDateCol" INTEGER, + "IntDateOnlyCol" INTEGER, + "FloatCol" DOUBLE PRECISION, + "IntCol" INTEGER, + "BoolCol" BOOLEAN, + "IntColWithNull" INTEGER, + "BoolColWithNull" BOOLEAN + )""" + cur.execute(stmt) + + stmt = """ + INSERT INTO types + VALUES($1, $2::timestamp, $3, $4, $5, $6, $7, $8, $9) + """ + + cur.executemany(stmt, types_data) + + conn.commit() + + +def create_and_load_types(conn, types_data: list[dict], dialect: str): + from sqlalchemy import insert + from sqlalchemy.engine import Engine + + types = types_table_metadata(dialect) + + stmt = insert(types).values(types_data) + if isinstance(conn, Engine): + with conn.connect() as conn: + with conn.begin(): + types.drop(conn, checkfirst=True) + types.create(bind=conn) + conn.execute(stmt) + else: + with conn.begin(): + types.drop(conn, checkfirst=True) + types.create(bind=conn) + conn.execute(stmt) + + +def create_and_load_postgres_datetz(conn): + from sqlalchemy import ( + Column, + DateTime, + MetaData, + Table, + insert, + ) + from sqlalchemy.engine import Engine + + metadata = MetaData() + datetz = Table("datetz", metadata, Column("DateColWithTz", DateTime(timezone=True))) + datetz_data = [ + { + "DateColWithTz": "2000-01-01 00:00:00-08:00", + }, + { + "DateColWithTz": "2000-06-01 00:00:00-07:00", + }, + ] + stmt = insert(datetz).values(datetz_data) + if isinstance(conn, Engine): + with conn.connect() as conn: + with conn.begin(): + datetz.drop(conn, checkfirst=True) + datetz.create(bind=conn) + conn.execute(stmt) + else: + with conn.begin(): + datetz.drop(conn, checkfirst=True) + datetz.create(bind=conn) + conn.execute(stmt) + + # "2000-01-01 00:00:00-08:00" should convert to + # "2000-01-01 08:00:00" + # "2000-06-01 00:00:00-07:00" should convert to + # "2000-06-01 07:00:00" + # GH 6415 + expected_data = [ + Timestamp("2000-01-01 08:00:00", tz="UTC"), + Timestamp("2000-06-01 07:00:00", tz="UTC"), + ] + return Series(expected_data, name="DateColWithTz") + + +def check_iris_frame(frame: DataFrame): + pytype = frame.dtypes.iloc[0].type + row = frame.iloc[0] + assert issubclass(pytype, np.floating) + tm.assert_series_equal( + row, Series([5.1, 3.5, 1.4, 0.2, "Iris-setosa"], index=frame.columns, name=0) + ) + assert frame.shape in ((150, 5), (8, 5)) + + +def count_rows(conn, table_name: str): + stmt = f"SELECT count(*) AS count_1 FROM {table_name}" + adbc = import_optional_dependency("adbc_driver_manager.dbapi", errors="ignore") + if isinstance(conn, sqlite3.Connection): + cur = conn.cursor() + return cur.execute(stmt).fetchone()[0] + elif adbc and isinstance(conn, adbc.Connection): + with conn.cursor() as cur: + cur.execute(stmt) + return cur.fetchone()[0] + else: + from sqlalchemy import create_engine + from sqlalchemy.engine import Engine + + if isinstance(conn, str): + try: + engine = create_engine(conn) + with engine.connect() as conn: + return conn.exec_driver_sql(stmt).scalar_one() + finally: + engine.dispose() + elif isinstance(conn, Engine): + with conn.connect() as conn: + return conn.exec_driver_sql(stmt).scalar_one() + else: + return conn.exec_driver_sql(stmt).scalar_one() + + +@pytest.fixture +def iris_path(datapath): + iris_path = datapath("io", "data", "csv", "iris.csv") + return Path(iris_path) + + +@pytest.fixture +def types_data(): + return [ + { + "TextCol": "first", + "DateCol": "2000-01-03 00:00:00", + "IntDateCol": 535852800, + "IntDateOnlyCol": 20101010, + "FloatCol": 10.10, + "IntCol": 1, + "BoolCol": False, + "IntColWithNull": 1, + "BoolColWithNull": False, + }, + { + "TextCol": "first", + "DateCol": "2000-01-04 00:00:00", + "IntDateCol": 1356998400, + "IntDateOnlyCol": 20101212, + "FloatCol": 10.10, + "IntCol": 1, + "BoolCol": False, + "IntColWithNull": None, + "BoolColWithNull": None, + }, + ] + + +@pytest.fixture +def types_data_frame(types_data): + dtypes = { + "TextCol": "str", + "DateCol": "str", + "IntDateCol": "int64", + "IntDateOnlyCol": "int64", + "FloatCol": "float", + "IntCol": "int64", + "BoolCol": "int64", + "IntColWithNull": "float", + "BoolColWithNull": "float", + } + df = DataFrame(types_data) + return df[dtypes.keys()].astype(dtypes) + + +@pytest.fixture +def test_frame1(): + columns = ["index", "A", "B", "C", "D"] + data = [ + ( + "2000-01-03 00:00:00", + 0.980268513777, + 3.68573087906, + -0.364216805298, + -1.15973806169, + ), + ( + "2000-01-04 00:00:00", + 1.04791624281, + -0.0412318367011, + -0.16181208307, + 0.212549316967, + ), + ( + "2000-01-05 00:00:00", + 0.498580885705, + 0.731167677815, + -0.537677223318, + 1.34627041952, + ), + ( + "2000-01-06 00:00:00", + 1.12020151869, + 1.56762092543, + 0.00364077397681, + 0.67525259227, + ), + ] + return DataFrame(data, columns=columns) + + +@pytest.fixture +def test_frame3(): + columns = ["index", "A", "B"] + data = [ + ("2000-01-03 00:00:00", 2**31 - 1, -1.987670), + ("2000-01-04 00:00:00", -29, -0.0412318367011), + ("2000-01-05 00:00:00", 20000, 0.731167677815), + ("2000-01-06 00:00:00", -290867, 1.56762092543), + ] + return DataFrame(data, columns=columns) + + +def get_all_views(conn): + if isinstance(conn, sqlite3.Connection): + c = conn.execute("SELECT name FROM sqlite_master WHERE type='view'") + return [view[0] for view in c.fetchall()] + else: + adbc = import_optional_dependency("adbc_driver_manager.dbapi", errors="ignore") + if adbc and isinstance(conn, adbc.Connection): + results = [] + info = conn.adbc_get_objects().read_all().to_pylist() + for catalog in info: + catalog["catalog_name"] + for schema in catalog["catalog_db_schemas"]: + schema["db_schema_name"] + for table in schema["db_schema_tables"]: + if table["table_type"] == "view": + view_name = table["table_name"] + results.append(view_name) + + return results + else: + from sqlalchemy import inspect + + return inspect(conn).get_view_names() + + +def get_all_tables(conn): + if isinstance(conn, sqlite3.Connection): + c = conn.execute("SELECT name FROM sqlite_master WHERE type='table'") + return [table[0] for table in c.fetchall()] + else: + adbc = import_optional_dependency("adbc_driver_manager.dbapi", errors="ignore") + + if adbc and isinstance(conn, adbc.Connection): + results = [] + info = conn.adbc_get_objects().read_all().to_pylist() + for catalog in info: + for schema in catalog["catalog_db_schemas"]: + for table in schema["db_schema_tables"]: + if table["table_type"] == "table": + table_name = table["table_name"] + results.append(table_name) + + return results + else: + from sqlalchemy import inspect + + return inspect(conn).get_table_names() + + +def drop_table( + table_name: str, + conn: sqlite3.Connection | sqlalchemy.engine.Engine | sqlalchemy.engine.Connection, +): + if isinstance(conn, sqlite3.Connection): + conn.execute(f"DROP TABLE IF EXISTS {sql._get_valid_sqlite_name(table_name)}") + conn.commit() + + else: + adbc = import_optional_dependency("adbc_driver_manager.dbapi", errors="ignore") + if adbc and isinstance(conn, adbc.Connection): + with conn.cursor() as cur: + cur.execute(f'DROP TABLE IF EXISTS "{table_name}"') + else: + with conn.begin() as con: + with sql.SQLDatabase(con) as db: + db.drop_table(table_name) + + +def drop_view( + view_name: str, + conn: sqlite3.Connection | sqlalchemy.engine.Engine | sqlalchemy.engine.Connection, +): + import sqlalchemy + + if isinstance(conn, sqlite3.Connection): + conn.execute(f"DROP VIEW IF EXISTS {sql._get_valid_sqlite_name(view_name)}") + conn.commit() + else: + adbc = import_optional_dependency("adbc_driver_manager.dbapi", errors="ignore") + if adbc and isinstance(conn, adbc.Connection): + with conn.cursor() as cur: + cur.execute(f'DROP VIEW IF EXISTS "{view_name}"') + else: + quoted_view = conn.engine.dialect.identifier_preparer.quote_identifier( + view_name + ) + stmt = sqlalchemy.text(f"DROP VIEW IF EXISTS {quoted_view}") + with conn.begin() as con: + con.execute(stmt) # type: ignore[union-attr] + + +@pytest.fixture +def mysql_pymysql_engine(): + sqlalchemy = pytest.importorskip("sqlalchemy") + pymysql = pytest.importorskip("pymysql") + engine = sqlalchemy.create_engine( + "mysql+pymysql://root@localhost:3306/pandas", + connect_args={"client_flag": pymysql.constants.CLIENT.MULTI_STATEMENTS}, + poolclass=sqlalchemy.pool.NullPool, + ) + yield engine + for view in get_all_views(engine): + drop_view(view, engine) + for tbl in get_all_tables(engine): + drop_table(tbl, engine) + engine.dispose() + + +@pytest.fixture +def mysql_pymysql_engine_iris(mysql_pymysql_engine, iris_path): + create_and_load_iris(mysql_pymysql_engine, iris_path) + create_and_load_iris_view(mysql_pymysql_engine) + yield mysql_pymysql_engine + + +@pytest.fixture +def mysql_pymysql_engine_types(mysql_pymysql_engine, types_data): + create_and_load_types(mysql_pymysql_engine, types_data, "mysql") + yield mysql_pymysql_engine + + +@pytest.fixture +def mysql_pymysql_conn(mysql_pymysql_engine): + with mysql_pymysql_engine.connect() as conn: + yield conn + + +@pytest.fixture +def mysql_pymysql_conn_iris(mysql_pymysql_engine_iris): + with mysql_pymysql_engine_iris.connect() as conn: + yield conn + + +@pytest.fixture +def mysql_pymysql_conn_types(mysql_pymysql_engine_types): + with mysql_pymysql_engine_types.connect() as conn: + yield conn + + +@pytest.fixture +def postgresql_psycopg2_engine(): + sqlalchemy = pytest.importorskip("sqlalchemy") + pytest.importorskip("psycopg2") + engine = sqlalchemy.create_engine( + "postgresql+psycopg2://postgres:postgres@localhost:5432/pandas", + poolclass=sqlalchemy.pool.NullPool, + ) + yield engine + for view in get_all_views(engine): + drop_view(view, engine) + for tbl in get_all_tables(engine): + drop_table(tbl, engine) + engine.dispose() + + +@pytest.fixture +def postgresql_psycopg2_engine_iris(postgresql_psycopg2_engine, iris_path): + create_and_load_iris(postgresql_psycopg2_engine, iris_path) + create_and_load_iris_view(postgresql_psycopg2_engine) + yield postgresql_psycopg2_engine + + +@pytest.fixture +def postgresql_psycopg2_engine_types(postgresql_psycopg2_engine, types_data): + create_and_load_types(postgresql_psycopg2_engine, types_data, "postgres") + yield postgresql_psycopg2_engine + + +@pytest.fixture +def postgresql_psycopg2_conn(postgresql_psycopg2_engine): + with postgresql_psycopg2_engine.connect() as conn: + yield conn + + +@pytest.fixture +def postgresql_adbc_conn(): + pytest.importorskip("adbc_driver_postgresql") + from adbc_driver_postgresql import dbapi + + uri = "postgresql://postgres:postgres@localhost:5432/pandas" + with dbapi.connect(uri) as conn: + yield conn + for view in get_all_views(conn): + drop_view(view, conn) + for tbl in get_all_tables(conn): + drop_table(tbl, conn) + conn.commit() + + +@pytest.fixture +def postgresql_adbc_iris(postgresql_adbc_conn, iris_path): + import adbc_driver_manager as mgr + + conn = postgresql_adbc_conn + + try: + conn.adbc_get_table_schema("iris") + except mgr.ProgrammingError: + conn.rollback() + create_and_load_iris_postgresql(conn, iris_path) + try: + conn.adbc_get_table_schema("iris_view") + except mgr.ProgrammingError: # note arrow-adbc issue 1022 + conn.rollback() + create_and_load_iris_view(conn) + yield conn + + +@pytest.fixture +def postgresql_adbc_types(postgresql_adbc_conn, types_data): + import adbc_driver_manager as mgr + + conn = postgresql_adbc_conn + + try: + conn.adbc_get_table_schema("types") + except mgr.ProgrammingError: + conn.rollback() + new_data = [tuple(entry.values()) for entry in types_data] + + create_and_load_types_postgresql(conn, new_data) + + yield conn + + +@pytest.fixture +def postgresql_psycopg2_conn_iris(postgresql_psycopg2_engine_iris): + with postgresql_psycopg2_engine_iris.connect() as conn: + yield conn + + +@pytest.fixture +def postgresql_psycopg2_conn_types(postgresql_psycopg2_engine_types): + with postgresql_psycopg2_engine_types.connect() as conn: + yield conn + + +@pytest.fixture +def sqlite_str(): + pytest.importorskip("sqlalchemy") + with tm.ensure_clean() as name: + yield f"sqlite:///{name}" + + +@pytest.fixture +def sqlite_engine(sqlite_str): + sqlalchemy = pytest.importorskip("sqlalchemy") + engine = sqlalchemy.create_engine(sqlite_str, poolclass=sqlalchemy.pool.NullPool) + yield engine + for view in get_all_views(engine): + drop_view(view, engine) + for tbl in get_all_tables(engine): + drop_table(tbl, engine) + engine.dispose() + + +@pytest.fixture +def sqlite_conn(sqlite_engine): + with sqlite_engine.connect() as conn: + yield conn + + +@pytest.fixture +def sqlite_str_iris(sqlite_str, iris_path): + sqlalchemy = pytest.importorskip("sqlalchemy") + engine = sqlalchemy.create_engine(sqlite_str) + create_and_load_iris(engine, iris_path) + create_and_load_iris_view(engine) + engine.dispose() + return sqlite_str + + +@pytest.fixture +def sqlite_engine_iris(sqlite_engine, iris_path): + create_and_load_iris(sqlite_engine, iris_path) + create_and_load_iris_view(sqlite_engine) + yield sqlite_engine + + +@pytest.fixture +def sqlite_conn_iris(sqlite_engine_iris): + with sqlite_engine_iris.connect() as conn: + yield conn + + +@pytest.fixture +def sqlite_str_types(sqlite_str, types_data): + sqlalchemy = pytest.importorskip("sqlalchemy") + engine = sqlalchemy.create_engine(sqlite_str) + create_and_load_types(engine, types_data, "sqlite") + engine.dispose() + return sqlite_str + + +@pytest.fixture +def sqlite_engine_types(sqlite_engine, types_data): + create_and_load_types(sqlite_engine, types_data, "sqlite") + yield sqlite_engine + + +@pytest.fixture +def sqlite_conn_types(sqlite_engine_types): + with sqlite_engine_types.connect() as conn: + yield conn + + +@pytest.fixture +def sqlite_adbc_conn(): + pytest.importorskip("adbc_driver_sqlite") + from adbc_driver_sqlite import dbapi + + with tm.ensure_clean() as name: + uri = f"file:{name}" + with dbapi.connect(uri) as conn: + yield conn + for view in get_all_views(conn): + drop_view(view, conn) + for tbl in get_all_tables(conn): + drop_table(tbl, conn) + conn.commit() + + +@pytest.fixture +def sqlite_adbc_iris(sqlite_adbc_conn, iris_path): + import adbc_driver_manager as mgr + + conn = sqlite_adbc_conn + try: + conn.adbc_get_table_schema("iris") + except mgr.ProgrammingError: + conn.rollback() + create_and_load_iris_sqlite3(conn, iris_path) + try: + conn.adbc_get_table_schema("iris_view") + except mgr.ProgrammingError: + conn.rollback() + create_and_load_iris_view(conn) + yield conn + + +@pytest.fixture +def sqlite_adbc_types(sqlite_adbc_conn, types_data): + import adbc_driver_manager as mgr + + conn = sqlite_adbc_conn + try: + conn.adbc_get_table_schema("types") + except mgr.ProgrammingError: + conn.rollback() + new_data = [] + for entry in types_data: + entry["BoolCol"] = int(entry["BoolCol"]) + if entry["BoolColWithNull"] is not None: + entry["BoolColWithNull"] = int(entry["BoolColWithNull"]) + new_data.append(tuple(entry.values())) + + create_and_load_types_sqlite3(conn, new_data) + conn.commit() + + yield conn + + +@pytest.fixture +def sqlite_buildin(): + with contextlib.closing(sqlite3.connect(":memory:")) as closing_conn: + with closing_conn as conn: + yield conn + + +@pytest.fixture +def sqlite_buildin_iris(sqlite_buildin, iris_path): + create_and_load_iris_sqlite3(sqlite_buildin, iris_path) + create_and_load_iris_view(sqlite_buildin) + yield sqlite_buildin + + +@pytest.fixture +def sqlite_buildin_types(sqlite_buildin, types_data): + types_data = [tuple(entry.values()) for entry in types_data] + create_and_load_types_sqlite3(sqlite_buildin, types_data) + yield sqlite_buildin + + +mysql_connectable = [ + pytest.param("mysql_pymysql_engine", marks=pytest.mark.db), + pytest.param("mysql_pymysql_conn", marks=pytest.mark.db), +] + +mysql_connectable_iris = [ + pytest.param("mysql_pymysql_engine_iris", marks=pytest.mark.db), + pytest.param("mysql_pymysql_conn_iris", marks=pytest.mark.db), +] + +mysql_connectable_types = [ + pytest.param("mysql_pymysql_engine_types", marks=pytest.mark.db), + pytest.param("mysql_pymysql_conn_types", marks=pytest.mark.db), +] + +postgresql_connectable = [ + pytest.param("postgresql_psycopg2_engine", marks=pytest.mark.db), + pytest.param("postgresql_psycopg2_conn", marks=pytest.mark.db), +] + +postgresql_connectable_iris = [ + pytest.param("postgresql_psycopg2_engine_iris", marks=pytest.mark.db), + pytest.param("postgresql_psycopg2_conn_iris", marks=pytest.mark.db), +] + +postgresql_connectable_types = [ + pytest.param("postgresql_psycopg2_engine_types", marks=pytest.mark.db), + pytest.param("postgresql_psycopg2_conn_types", marks=pytest.mark.db), +] + +sqlite_connectable = [ + "sqlite_engine", + "sqlite_conn", + "sqlite_str", +] + +sqlite_connectable_iris = [ + "sqlite_engine_iris", + "sqlite_conn_iris", + "sqlite_str_iris", +] + +sqlite_connectable_types = [ + "sqlite_engine_types", + "sqlite_conn_types", + "sqlite_str_types", +] + +sqlalchemy_connectable = mysql_connectable + postgresql_connectable + sqlite_connectable + +sqlalchemy_connectable_iris = ( + mysql_connectable_iris + postgresql_connectable_iris + sqlite_connectable_iris +) + +sqlalchemy_connectable_types = ( + mysql_connectable_types + postgresql_connectable_types + sqlite_connectable_types +) + +adbc_connectable = [ + "sqlite_adbc_conn", + pytest.param("postgresql_adbc_conn", marks=pytest.mark.db), +] + +adbc_connectable_iris = [ + pytest.param("postgresql_adbc_iris", marks=pytest.mark.db), + pytest.param("sqlite_adbc_iris", marks=pytest.mark.db), +] + +adbc_connectable_types = [ + pytest.param("postgresql_adbc_types", marks=pytest.mark.db), + pytest.param("sqlite_adbc_types", marks=pytest.mark.db), +] + + +all_connectable = sqlalchemy_connectable + ["sqlite_buildin"] + adbc_connectable + +all_connectable_iris = ( + sqlalchemy_connectable_iris + ["sqlite_buildin_iris"] + adbc_connectable_iris +) + +all_connectable_types = ( + sqlalchemy_connectable_types + ["sqlite_buildin_types"] + adbc_connectable_types +) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_dataframe_to_sql(conn, test_frame1, request): + # GH 51086 if conn is sqlite_engine + conn = request.getfixturevalue(conn) + test_frame1.to_sql(name="test", con=conn, if_exists="append", index=False) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_dataframe_to_sql_empty(conn, test_frame1, request): + if conn == "postgresql_adbc_conn": + request.node.add_marker( + pytest.mark.xfail( + reason="postgres ADBC driver cannot insert index with null type", + strict=True, + ) + ) + # GH 51086 if conn is sqlite_engine + conn = request.getfixturevalue(conn) + empty_df = test_frame1.iloc[:0] + empty_df.to_sql(name="test", con=conn, if_exists="append", index=False) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_dataframe_to_sql_arrow_dtypes(conn, request): + # GH 52046 + pytest.importorskip("pyarrow") + df = DataFrame( + { + "int": pd.array([1], dtype="int8[pyarrow]"), + "datetime": pd.array( + [datetime(2023, 1, 1)], dtype="timestamp[ns][pyarrow]" + ), + "date": pd.array([date(2023, 1, 1)], dtype="date32[day][pyarrow]"), + "timedelta": pd.array([timedelta(1)], dtype="duration[ns][pyarrow]"), + "string": pd.array(["a"], dtype="string[pyarrow]"), + } + ) + + if "adbc" in conn: + if conn == "sqlite_adbc_conn": + df = df.drop(columns=["timedelta"]) + if pa_version_under14p1: + exp_warning = DeprecationWarning + msg = "is_sparse is deprecated" + else: + exp_warning = None + msg = "" + else: + exp_warning = UserWarning + msg = "the 'timedelta'" + + conn = request.getfixturevalue(conn) + with tm.assert_produces_warning(exp_warning, match=msg, check_stacklevel=False): + df.to_sql(name="test_arrow", con=conn, if_exists="replace", index=False) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_dataframe_to_sql_arrow_dtypes_missing(conn, request, nulls_fixture): + # GH 52046 + pytest.importorskip("pyarrow") + df = DataFrame( + { + "datetime": pd.array( + [datetime(2023, 1, 1), nulls_fixture], dtype="timestamp[ns][pyarrow]" + ), + } + ) + conn = request.getfixturevalue(conn) + df.to_sql(name="test_arrow", con=conn, if_exists="replace", index=False) + + +@pytest.mark.parametrize("conn", all_connectable) +@pytest.mark.parametrize("method", [None, "multi"]) +def test_to_sql(conn, method, test_frame1, request): + if method == "multi" and "adbc" in conn: + request.node.add_marker( + pytest.mark.xfail( + reason="'method' not implemented for ADBC drivers", strict=True + ) + ) + + conn = request.getfixturevalue(conn) + with pandasSQL_builder(conn, need_transaction=True) as pandasSQL: + pandasSQL.to_sql(test_frame1, "test_frame", method=method) + assert pandasSQL.has_table("test_frame") + assert count_rows(conn, "test_frame") == len(test_frame1) + + +@pytest.mark.parametrize("conn", all_connectable) +@pytest.mark.parametrize("mode, num_row_coef", [("replace", 1), ("append", 2)]) +def test_to_sql_exist(conn, mode, num_row_coef, test_frame1, request): + conn = request.getfixturevalue(conn) + with pandasSQL_builder(conn, need_transaction=True) as pandasSQL: + pandasSQL.to_sql(test_frame1, "test_frame", if_exists="fail") + pandasSQL.to_sql(test_frame1, "test_frame", if_exists=mode) + assert pandasSQL.has_table("test_frame") + assert count_rows(conn, "test_frame") == num_row_coef * len(test_frame1) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_to_sql_exist_fail(conn, test_frame1, request): + conn = request.getfixturevalue(conn) + with pandasSQL_builder(conn, need_transaction=True) as pandasSQL: + pandasSQL.to_sql(test_frame1, "test_frame", if_exists="fail") + assert pandasSQL.has_table("test_frame") + + msg = "Table 'test_frame' already exists" + with pytest.raises(ValueError, match=msg): + pandasSQL.to_sql(test_frame1, "test_frame", if_exists="fail") + + +@pytest.mark.parametrize("conn", all_connectable_iris) +def test_read_iris_query(conn, request): + conn = request.getfixturevalue(conn) + iris_frame = read_sql_query("SELECT * FROM iris", conn) + check_iris_frame(iris_frame) + iris_frame = pd.read_sql("SELECT * FROM iris", conn) + check_iris_frame(iris_frame) + iris_frame = pd.read_sql("SELECT * FROM iris where 0=1", conn) + assert iris_frame.shape == (0, 5) + assert "SepalWidth" in iris_frame.columns + + +@pytest.mark.parametrize("conn", all_connectable_iris) +def test_read_iris_query_chunksize(conn, request): + if "adbc" in conn: + request.node.add_marker( + pytest.mark.xfail( + reason="'chunksize' not implemented for ADBC drivers", + strict=True, + ) + ) + conn = request.getfixturevalue(conn) + iris_frame = concat(read_sql_query("SELECT * FROM iris", conn, chunksize=7)) + check_iris_frame(iris_frame) + iris_frame = concat(pd.read_sql("SELECT * FROM iris", conn, chunksize=7)) + check_iris_frame(iris_frame) + iris_frame = concat(pd.read_sql("SELECT * FROM iris where 0=1", conn, chunksize=7)) + assert iris_frame.shape == (0, 5) + assert "SepalWidth" in iris_frame.columns + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable_iris) +def test_read_iris_query_expression_with_parameter(conn, request): + if "adbc" in conn: + request.node.add_marker( + pytest.mark.xfail( + reason="'chunksize' not implemented for ADBC drivers", + strict=True, + ) + ) + conn = request.getfixturevalue(conn) + from sqlalchemy import ( + MetaData, + Table, + create_engine, + select, + ) + + metadata = MetaData() + autoload_con = create_engine(conn) if isinstance(conn, str) else conn + iris = Table("iris", metadata, autoload_with=autoload_con) + iris_frame = read_sql_query( + select(iris), conn, params={"name": "Iris-setosa", "length": 5.1} + ) + check_iris_frame(iris_frame) + if isinstance(conn, str): + autoload_con.dispose() + + +@pytest.mark.parametrize("conn", all_connectable_iris) +def test_read_iris_query_string_with_parameter(conn, request, sql_strings): + if "adbc" in conn: + request.node.add_marker( + pytest.mark.xfail( + reason="'chunksize' not implemented for ADBC drivers", + strict=True, + ) + ) + + for db, query in sql_strings["read_parameters"].items(): + if db in conn: + break + else: + raise KeyError(f"No part of {conn} found in sql_strings['read_parameters']") + conn = request.getfixturevalue(conn) + iris_frame = read_sql_query(query, conn, params=("Iris-setosa", 5.1)) + check_iris_frame(iris_frame) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable_iris) +def test_read_iris_table(conn, request): + # GH 51015 if conn = sqlite_iris_str + conn = request.getfixturevalue(conn) + iris_frame = read_sql_table("iris", conn) + check_iris_frame(iris_frame) + iris_frame = pd.read_sql("iris", conn) + check_iris_frame(iris_frame) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable_iris) +def test_read_iris_table_chunksize(conn, request): + if "adbc" in conn: + request.node.add_marker( + pytest.mark.xfail(reason="chunksize argument NotImplemented with ADBC") + ) + conn = request.getfixturevalue(conn) + iris_frame = concat(read_sql_table("iris", conn, chunksize=7)) + check_iris_frame(iris_frame) + iris_frame = concat(pd.read_sql("iris", conn, chunksize=7)) + check_iris_frame(iris_frame) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_to_sql_callable(conn, test_frame1, request): + conn = request.getfixturevalue(conn) + + check = [] # used to double check function below is really being used + + def sample(pd_table, conn, keys, data_iter): + check.append(1) + data = [dict(zip(keys, row)) for row in data_iter] + conn.execute(pd_table.table.insert(), data) + + with pandasSQL_builder(conn, need_transaction=True) as pandasSQL: + pandasSQL.to_sql(test_frame1, "test_frame", method=sample) + assert pandasSQL.has_table("test_frame") + assert check == [1] + assert count_rows(conn, "test_frame") == len(test_frame1) + + +@pytest.mark.parametrize("conn", all_connectable_types) +def test_default_type_conversion(conn, request): + conn_name = conn + if conn_name == "sqlite_buildin_types": + request.applymarker( + pytest.mark.xfail( + reason="sqlite_buildin connection does not implement read_sql_table" + ) + ) + + conn = request.getfixturevalue(conn) + df = sql.read_sql_table("types", conn) + + assert issubclass(df.FloatCol.dtype.type, np.floating) + assert issubclass(df.IntCol.dtype.type, np.integer) + + # MySQL/sqlite has no real BOOL type + if "postgresql" in conn_name: + assert issubclass(df.BoolCol.dtype.type, np.bool_) + else: + assert issubclass(df.BoolCol.dtype.type, np.integer) + + # Int column with NA values stays as float + assert issubclass(df.IntColWithNull.dtype.type, np.floating) + + # Bool column with NA = int column with NA values => becomes float + if "postgresql" in conn_name: + assert issubclass(df.BoolColWithNull.dtype.type, object) + else: + assert issubclass(df.BoolColWithNull.dtype.type, np.floating) + + +@pytest.mark.parametrize("conn", mysql_connectable) +def test_read_procedure(conn, request): + conn = request.getfixturevalue(conn) + + # GH 7324 + # Although it is more an api test, it is added to the + # mysql tests as sqlite does not have stored procedures + from sqlalchemy import text + from sqlalchemy.engine import Engine + + df = DataFrame({"a": [1, 2, 3], "b": [0.1, 0.2, 0.3]}) + df.to_sql(name="test_frame", con=conn, index=False) + + proc = """DROP PROCEDURE IF EXISTS get_testdb; + + CREATE PROCEDURE get_testdb () + + BEGIN + SELECT * FROM test_frame; + END""" + proc = text(proc) + if isinstance(conn, Engine): + with conn.connect() as engine_conn: + with engine_conn.begin(): + engine_conn.execute(proc) + else: + with conn.begin(): + conn.execute(proc) + + res1 = sql.read_sql_query("CALL get_testdb();", conn) + tm.assert_frame_equal(df, res1) + + # test delegation to read_sql_query + res2 = sql.read_sql("CALL get_testdb();", conn) + tm.assert_frame_equal(df, res2) + + +@pytest.mark.parametrize("conn", postgresql_connectable) +@pytest.mark.parametrize("expected_count", [2, "Success!"]) +def test_copy_from_callable_insertion_method(conn, expected_count, request): + # GH 8953 + # Example in io.rst found under _io.sql.method + # not available in sqlite, mysql + def psql_insert_copy(table, conn, keys, data_iter): + # gets a DBAPI connection that can provide a cursor + dbapi_conn = conn.connection + with dbapi_conn.cursor() as cur: + s_buf = StringIO() + writer = csv.writer(s_buf) + writer.writerows(data_iter) + s_buf.seek(0) + + columns = ", ".join([f'"{k}"' for k in keys]) + if table.schema: + table_name = f"{table.schema}.{table.name}" + else: + table_name = table.name + + sql_query = f"COPY {table_name} ({columns}) FROM STDIN WITH CSV" + cur.copy_expert(sql=sql_query, file=s_buf) + return expected_count + + conn = request.getfixturevalue(conn) + expected = DataFrame({"col1": [1, 2], "col2": [0.1, 0.2], "col3": ["a", "n"]}) + result_count = expected.to_sql( + name="test_frame", con=conn, index=False, method=psql_insert_copy + ) + # GH 46891 + if expected_count is None: + assert result_count is None + else: + assert result_count == expected_count + result = sql.read_sql_table("test_frame", conn) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("conn", postgresql_connectable) +def test_insertion_method_on_conflict_do_nothing(conn, request): + # GH 15988: Example in to_sql docstring + conn = request.getfixturevalue(conn) + + from sqlalchemy.dialects.postgresql import insert + from sqlalchemy.engine import Engine + from sqlalchemy.sql import text + + def insert_on_conflict(table, conn, keys, data_iter): + data = [dict(zip(keys, row)) for row in data_iter] + stmt = ( + insert(table.table) + .values(data) + .on_conflict_do_nothing(index_elements=["a"]) + ) + result = conn.execute(stmt) + return result.rowcount + + create_sql = text( + """ + CREATE TABLE test_insert_conflict ( + a integer PRIMARY KEY, + b numeric, + c text + ); + """ + ) + if isinstance(conn, Engine): + with conn.connect() as con: + with con.begin(): + con.execute(create_sql) + else: + with conn.begin(): + conn.execute(create_sql) + + expected = DataFrame([[1, 2.1, "a"]], columns=list("abc")) + expected.to_sql( + name="test_insert_conflict", con=conn, if_exists="append", index=False + ) + + df_insert = DataFrame([[1, 3.2, "b"]], columns=list("abc")) + inserted = df_insert.to_sql( + name="test_insert_conflict", + con=conn, + index=False, + if_exists="append", + method=insert_on_conflict, + ) + result = sql.read_sql_table("test_insert_conflict", conn) + tm.assert_frame_equal(result, expected) + assert inserted == 0 + + # Cleanup + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("test_insert_conflict") + + +@pytest.mark.parametrize("conn", all_connectable) +def test_to_sql_on_public_schema(conn, request): + if "sqlite" in conn or "mysql" in conn: + request.applymarker( + pytest.mark.xfail( + reason="test for public schema only specific to postgresql" + ) + ) + + conn = request.getfixturevalue(conn) + + test_data = DataFrame([[1, 2.1, "a"], [2, 3.1, "b"]], columns=list("abc")) + test_data.to_sql( + name="test_public_schema", + con=conn, + if_exists="append", + index=False, + schema="public", + ) + + df_out = sql.read_sql_table("test_public_schema", conn, schema="public") + tm.assert_frame_equal(test_data, df_out) + + +@pytest.mark.parametrize("conn", mysql_connectable) +def test_insertion_method_on_conflict_update(conn, request): + # GH 14553: Example in to_sql docstring + conn = request.getfixturevalue(conn) + + from sqlalchemy.dialects.mysql import insert + from sqlalchemy.engine import Engine + from sqlalchemy.sql import text + + def insert_on_conflict(table, conn, keys, data_iter): + data = [dict(zip(keys, row)) for row in data_iter] + stmt = insert(table.table).values(data) + stmt = stmt.on_duplicate_key_update(b=stmt.inserted.b, c=stmt.inserted.c) + result = conn.execute(stmt) + return result.rowcount + + create_sql = text( + """ + CREATE TABLE test_insert_conflict ( + a INT PRIMARY KEY, + b FLOAT, + c VARCHAR(10) + ); + """ + ) + if isinstance(conn, Engine): + with conn.connect() as con: + with con.begin(): + con.execute(create_sql) + else: + with conn.begin(): + conn.execute(create_sql) + + df = DataFrame([[1, 2.1, "a"]], columns=list("abc")) + df.to_sql(name="test_insert_conflict", con=conn, if_exists="append", index=False) + + expected = DataFrame([[1, 3.2, "b"]], columns=list("abc")) + inserted = expected.to_sql( + name="test_insert_conflict", + con=conn, + index=False, + if_exists="append", + method=insert_on_conflict, + ) + result = sql.read_sql_table("test_insert_conflict", conn) + tm.assert_frame_equal(result, expected) + assert inserted == 2 + + # Cleanup + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("test_insert_conflict") + + +@pytest.mark.parametrize("conn", postgresql_connectable) +def test_read_view_postgres(conn, request): + # GH 52969 + conn = request.getfixturevalue(conn) + + from sqlalchemy.engine import Engine + from sqlalchemy.sql import text + + table_name = f"group_{uuid.uuid4().hex}" + view_name = f"group_view_{uuid.uuid4().hex}" + + sql_stmt = text( + f""" + CREATE TABLE {table_name} ( + group_id INTEGER, + name TEXT + ); + INSERT INTO {table_name} VALUES + (1, 'name'); + CREATE VIEW {view_name} + AS + SELECT * FROM {table_name}; + """ + ) + if isinstance(conn, Engine): + with conn.connect() as con: + with con.begin(): + con.execute(sql_stmt) + else: + with conn.begin(): + conn.execute(sql_stmt) + result = read_sql_table(view_name, conn) + expected = DataFrame({"group_id": [1], "name": "name"}) + tm.assert_frame_equal(result, expected) + + +def test_read_view_sqlite(sqlite_buildin): + # GH 52969 + create_table = """ +CREATE TABLE groups ( + group_id INTEGER, + name TEXT +); +""" + insert_into = """ +INSERT INTO groups VALUES + (1, 'name'); +""" + create_view = """ +CREATE VIEW group_view +AS +SELECT * FROM groups; +""" + sqlite_buildin.execute(create_table) + sqlite_buildin.execute(insert_into) + sqlite_buildin.execute(create_view) + result = pd.read_sql("SELECT * FROM group_view", sqlite_buildin) + expected = DataFrame({"group_id": [1], "name": "name"}) + tm.assert_frame_equal(result, expected) + + +def test_execute_typeerror(sqlite_engine_iris): + with pytest.raises(TypeError, match="pandas.io.sql.execute requires a connection"): + with tm.assert_produces_warning( + FutureWarning, + match="`pandas.io.sql.execute` is deprecated and " + "will be removed in the future version.", + ): + sql.execute("select * from iris", sqlite_engine_iris) + + +def test_execute_deprecated(sqlite_conn_iris): + # GH50185 + with tm.assert_produces_warning( + FutureWarning, + match="`pandas.io.sql.execute` is deprecated and " + "will be removed in the future version.", + ): + sql.execute("select * from iris", sqlite_conn_iris) + + +def flavor(conn_name): + if "postgresql" in conn_name: + return "postgresql" + elif "sqlite" in conn_name: + return "sqlite" + elif "mysql" in conn_name: + return "mysql" + + raise ValueError(f"unsupported connection: {conn_name}") + + +@pytest.mark.parametrize("conn", all_connectable_iris) +def test_read_sql_iris_parameter(conn, request, sql_strings): + if "adbc" in conn: + request.node.add_marker( + pytest.mark.xfail( + reason="'params' not implemented for ADBC drivers", + strict=True, + ) + ) + conn_name = conn + conn = request.getfixturevalue(conn) + query = sql_strings["read_parameters"][flavor(conn_name)] + params = ("Iris-setosa", 5.1) + with pandasSQL_builder(conn) as pandasSQL: + with pandasSQL.run_transaction(): + iris_frame = pandasSQL.read_query(query, params=params) + check_iris_frame(iris_frame) + + +@pytest.mark.parametrize("conn", all_connectable_iris) +def test_read_sql_iris_named_parameter(conn, request, sql_strings): + if "adbc" in conn: + request.node.add_marker( + pytest.mark.xfail( + reason="'params' not implemented for ADBC drivers", + strict=True, + ) + ) + + conn_name = conn + conn = request.getfixturevalue(conn) + query = sql_strings["read_named_parameters"][flavor(conn_name)] + params = {"name": "Iris-setosa", "length": 5.1} + with pandasSQL_builder(conn) as pandasSQL: + with pandasSQL.run_transaction(): + iris_frame = pandasSQL.read_query(query, params=params) + check_iris_frame(iris_frame) + + +@pytest.mark.parametrize("conn", all_connectable_iris) +def test_read_sql_iris_no_parameter_with_percent(conn, request, sql_strings): + if "mysql" in conn or ("postgresql" in conn and "adbc" not in conn): + request.applymarker(pytest.mark.xfail(reason="broken test")) + + conn_name = conn + conn = request.getfixturevalue(conn) + + query = sql_strings["read_no_parameters_with_percent"][flavor(conn_name)] + with pandasSQL_builder(conn) as pandasSQL: + with pandasSQL.run_transaction(): + iris_frame = pandasSQL.read_query(query, params=None) + check_iris_frame(iris_frame) + + +# ----------------------------------------------------------------------------- +# -- Testing the public API + + +@pytest.mark.parametrize("conn", all_connectable_iris) +def test_api_read_sql_view(conn, request): + conn = request.getfixturevalue(conn) + iris_frame = sql.read_sql_query("SELECT * FROM iris_view", conn) + check_iris_frame(iris_frame) + + +@pytest.mark.parametrize("conn", all_connectable_iris) +def test_api_read_sql_with_chunksize_no_result(conn, request): + if "adbc" in conn: + request.node.add_marker( + pytest.mark.xfail(reason="chunksize argument NotImplemented with ADBC") + ) + conn = request.getfixturevalue(conn) + query = 'SELECT * FROM iris_view WHERE "SepalLength" < 0.0' + with_batch = sql.read_sql_query(query, conn, chunksize=5) + without_batch = sql.read_sql_query(query, conn) + tm.assert_frame_equal(concat(with_batch), without_batch) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_api_to_sql(conn, request, test_frame1): + conn = request.getfixturevalue(conn) + if sql.has_table("test_frame1", conn): + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("test_frame1") + + sql.to_sql(test_frame1, "test_frame1", conn) + assert sql.has_table("test_frame1", conn) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_api_to_sql_fail(conn, request, test_frame1): + conn = request.getfixturevalue(conn) + if sql.has_table("test_frame2", conn): + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("test_frame2") + + sql.to_sql(test_frame1, "test_frame2", conn, if_exists="fail") + assert sql.has_table("test_frame2", conn) + + msg = "Table 'test_frame2' already exists" + with pytest.raises(ValueError, match=msg): + sql.to_sql(test_frame1, "test_frame2", conn, if_exists="fail") + + +@pytest.mark.parametrize("conn", all_connectable) +def test_api_to_sql_replace(conn, request, test_frame1): + conn = request.getfixturevalue(conn) + if sql.has_table("test_frame3", conn): + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("test_frame3") + + sql.to_sql(test_frame1, "test_frame3", conn, if_exists="fail") + # Add to table again + sql.to_sql(test_frame1, "test_frame3", conn, if_exists="replace") + assert sql.has_table("test_frame3", conn) + + num_entries = len(test_frame1) + num_rows = count_rows(conn, "test_frame3") + + assert num_rows == num_entries + + +@pytest.mark.parametrize("conn", all_connectable) +def test_api_to_sql_append(conn, request, test_frame1): + conn = request.getfixturevalue(conn) + if sql.has_table("test_frame4", conn): + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("test_frame4") + + assert sql.to_sql(test_frame1, "test_frame4", conn, if_exists="fail") == 4 + + # Add to table again + assert sql.to_sql(test_frame1, "test_frame4", conn, if_exists="append") == 4 + assert sql.has_table("test_frame4", conn) + + num_entries = 2 * len(test_frame1) + num_rows = count_rows(conn, "test_frame4") + + assert num_rows == num_entries + + +@pytest.mark.parametrize("conn", all_connectable) +def test_api_to_sql_type_mapping(conn, request, test_frame3): + conn = request.getfixturevalue(conn) + if sql.has_table("test_frame5", conn): + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("test_frame5") + + sql.to_sql(test_frame3, "test_frame5", conn, index=False) + result = sql.read_sql("SELECT * FROM test_frame5", conn) + + tm.assert_frame_equal(test_frame3, result) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_api_to_sql_series(conn, request): + conn = request.getfixturevalue(conn) + if sql.has_table("test_series", conn): + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("test_series") + + s = Series(np.arange(5, dtype="int64"), name="series") + sql.to_sql(s, "test_series", conn, index=False) + s2 = sql.read_sql_query("SELECT * FROM test_series", conn) + tm.assert_frame_equal(s.to_frame(), s2) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_api_roundtrip(conn, request, test_frame1): + conn_name = conn + conn = request.getfixturevalue(conn) + if sql.has_table("test_frame_roundtrip", conn): + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("test_frame_roundtrip") + + sql.to_sql(test_frame1, "test_frame_roundtrip", con=conn) + result = sql.read_sql_query("SELECT * FROM test_frame_roundtrip", con=conn) + + # HACK! + if "adbc" in conn_name: + result = result.rename(columns={"__index_level_0__": "level_0"}) + result.index = test_frame1.index + result.set_index("level_0", inplace=True) + result.index.astype(int) + result.index.name = None + tm.assert_frame_equal(result, test_frame1) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_api_roundtrip_chunksize(conn, request, test_frame1): + if "adbc" in conn: + request.node.add_marker( + pytest.mark.xfail(reason="chunksize argument NotImplemented with ADBC") + ) + conn = request.getfixturevalue(conn) + if sql.has_table("test_frame_roundtrip", conn): + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("test_frame_roundtrip") + + sql.to_sql( + test_frame1, + "test_frame_roundtrip", + con=conn, + index=False, + chunksize=2, + ) + result = sql.read_sql_query("SELECT * FROM test_frame_roundtrip", con=conn) + tm.assert_frame_equal(result, test_frame1) + + +@pytest.mark.parametrize("conn", all_connectable_iris) +def test_api_execute_sql(conn, request): + # drop_sql = "DROP TABLE IF EXISTS test" # should already be done + conn = request.getfixturevalue(conn) + with sql.pandasSQL_builder(conn) as pandas_sql: + iris_results = pandas_sql.execute("SELECT * FROM iris") + row = iris_results.fetchone() + iris_results.close() + assert list(row) == [5.1, 3.5, 1.4, 0.2, "Iris-setosa"] + + +@pytest.mark.parametrize("conn", all_connectable_types) +def test_api_date_parsing(conn, request): + conn_name = conn + conn = request.getfixturevalue(conn) + # Test date parsing in read_sql + # No Parsing + df = sql.read_sql_query("SELECT * FROM types", conn) + if not ("mysql" in conn_name or "postgres" in conn_name): + assert not issubclass(df.DateCol.dtype.type, np.datetime64) + + df = sql.read_sql_query("SELECT * FROM types", conn, parse_dates=["DateCol"]) + assert issubclass(df.DateCol.dtype.type, np.datetime64) + assert df.DateCol.tolist() == [ + Timestamp(2000, 1, 3, 0, 0, 0), + Timestamp(2000, 1, 4, 0, 0, 0), + ] + + df = sql.read_sql_query( + "SELECT * FROM types", + conn, + parse_dates={"DateCol": "%Y-%m-%d %H:%M:%S"}, + ) + assert issubclass(df.DateCol.dtype.type, np.datetime64) + assert df.DateCol.tolist() == [ + Timestamp(2000, 1, 3, 0, 0, 0), + Timestamp(2000, 1, 4, 0, 0, 0), + ] + + df = sql.read_sql_query("SELECT * FROM types", conn, parse_dates=["IntDateCol"]) + assert issubclass(df.IntDateCol.dtype.type, np.datetime64) + assert df.IntDateCol.tolist() == [ + Timestamp(1986, 12, 25, 0, 0, 0), + Timestamp(2013, 1, 1, 0, 0, 0), + ] + + df = sql.read_sql_query( + "SELECT * FROM types", conn, parse_dates={"IntDateCol": "s"} + ) + assert issubclass(df.IntDateCol.dtype.type, np.datetime64) + assert df.IntDateCol.tolist() == [ + Timestamp(1986, 12, 25, 0, 0, 0), + Timestamp(2013, 1, 1, 0, 0, 0), + ] + + df = sql.read_sql_query( + "SELECT * FROM types", + conn, + parse_dates={"IntDateOnlyCol": "%Y%m%d"}, + ) + assert issubclass(df.IntDateOnlyCol.dtype.type, np.datetime64) + assert df.IntDateOnlyCol.tolist() == [ + Timestamp("2010-10-10"), + Timestamp("2010-12-12"), + ] + + +@pytest.mark.parametrize("conn", all_connectable_types) +@pytest.mark.parametrize("error", ["ignore", "raise", "coerce"]) +@pytest.mark.parametrize( + "read_sql, text, mode", + [ + (sql.read_sql, "SELECT * FROM types", ("sqlalchemy", "fallback")), + (sql.read_sql, "types", ("sqlalchemy")), + ( + sql.read_sql_query, + "SELECT * FROM types", + ("sqlalchemy", "fallback"), + ), + (sql.read_sql_table, "types", ("sqlalchemy")), + ], +) +def test_api_custom_dateparsing_error( + conn, request, read_sql, text, mode, error, types_data_frame +): + conn_name = conn + conn = request.getfixturevalue(conn) + if text == "types" and conn_name == "sqlite_buildin_types": + request.applymarker( + pytest.mark.xfail(reason="failing combination of arguments") + ) + + expected = types_data_frame.astype({"DateCol": "datetime64[ns]"}) + + result = read_sql( + text, + con=conn, + parse_dates={ + "DateCol": {"errors": error}, + }, + ) + if "postgres" in conn_name: + # TODO: clean up types_data_frame fixture + result["BoolCol"] = result["BoolCol"].astype(int) + result["BoolColWithNull"] = result["BoolColWithNull"].astype(float) + + if conn_name == "postgresql_adbc_types": + expected = expected.astype( + { + "IntDateCol": "int32", + "IntDateOnlyCol": "int32", + "IntCol": "int32", + } + ) + + if not pa_version_under13p0: + # TODO: is this astype safe? + expected["DateCol"] = expected["DateCol"].astype("datetime64[us]") + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("conn", all_connectable_types) +def test_api_date_and_index(conn, request): + # Test case where same column appears in parse_date and index_col + conn = request.getfixturevalue(conn) + df = sql.read_sql_query( + "SELECT * FROM types", + conn, + index_col="DateCol", + parse_dates=["DateCol", "IntDateCol"], + ) + + assert issubclass(df.index.dtype.type, np.datetime64) + assert issubclass(df.IntDateCol.dtype.type, np.datetime64) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_api_timedelta(conn, request): + # see #6921 + conn_name = conn + conn = request.getfixturevalue(conn) + if sql.has_table("test_timedelta", conn): + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("test_timedelta") + + df = to_timedelta(Series(["00:00:01", "00:00:03"], name="foo")).to_frame() + + if conn_name == "sqlite_adbc_conn": + request.node.add_marker( + pytest.mark.xfail( + reason="sqlite ADBC driver doesn't implement timedelta", + ) + ) + + if "adbc" in conn_name: + if pa_version_under14p1: + exp_warning = DeprecationWarning + else: + exp_warning = None + else: + exp_warning = UserWarning + + with tm.assert_produces_warning(exp_warning, check_stacklevel=False): + result_count = df.to_sql(name="test_timedelta", con=conn) + assert result_count == 2 + result = sql.read_sql_query("SELECT * FROM test_timedelta", conn) + + if conn_name == "postgresql_adbc_conn": + # TODO: Postgres stores an INTERVAL, which ADBC reads as a Month-Day-Nano + # Interval; the default pandas type mapper maps this to a DateOffset + # but maybe we should try and restore the timedelta here? + expected = Series( + [ + pd.DateOffset(months=0, days=0, microseconds=1000000, nanoseconds=0), + pd.DateOffset(months=0, days=0, microseconds=3000000, nanoseconds=0), + ], + name="foo", + ) + else: + expected = df["foo"].astype("int64") + tm.assert_series_equal(result["foo"], expected) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_api_complex_raises(conn, request): + conn_name = conn + conn = request.getfixturevalue(conn) + df = DataFrame({"a": [1 + 1j, 2j]}) + + if "adbc" in conn_name: + msg = "datatypes not supported" + else: + msg = "Complex datatypes not supported" + with pytest.raises(ValueError, match=msg): + assert df.to_sql("test_complex", con=conn) is None + + +@pytest.mark.parametrize("conn", all_connectable) +@pytest.mark.parametrize( + "index_name,index_label,expected", + [ + # no index name, defaults to 'index' + (None, None, "index"), + # specifying index_label + (None, "other_label", "other_label"), + # using the index name + ("index_name", None, "index_name"), + # has index name, but specifying index_label + ("index_name", "other_label", "other_label"), + # index name is integer + (0, None, "0"), + # index name is None but index label is integer + (None, 0, "0"), + ], +) +def test_api_to_sql_index_label(conn, request, index_name, index_label, expected): + if "adbc" in conn: + request.node.add_marker( + pytest.mark.xfail(reason="index_label argument NotImplemented with ADBC") + ) + conn = request.getfixturevalue(conn) + if sql.has_table("test_index_label", conn): + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("test_index_label") + + temp_frame = DataFrame({"col1": range(4)}) + temp_frame.index.name = index_name + query = "SELECT * FROM test_index_label" + sql.to_sql(temp_frame, "test_index_label", conn, index_label=index_label) + frame = sql.read_sql_query(query, conn) + assert frame.columns[0] == expected + + +@pytest.mark.parametrize("conn", all_connectable) +def test_api_to_sql_index_label_multiindex(conn, request): + conn_name = conn + if "mysql" in conn_name: + request.applymarker( + pytest.mark.xfail( + reason="MySQL can fail using TEXT without length as key", strict=False + ) + ) + elif "adbc" in conn_name: + request.node.add_marker( + pytest.mark.xfail(reason="index_label argument NotImplemented with ADBC") + ) + + conn = request.getfixturevalue(conn) + if sql.has_table("test_index_label", conn): + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("test_index_label") + + expected_row_count = 4 + temp_frame = DataFrame( + {"col1": range(4)}, + index=MultiIndex.from_product([("A0", "A1"), ("B0", "B1")]), + ) + + # no index name, defaults to 'level_0' and 'level_1' + result = sql.to_sql(temp_frame, "test_index_label", conn) + assert result == expected_row_count + frame = sql.read_sql_query("SELECT * FROM test_index_label", conn) + assert frame.columns[0] == "level_0" + assert frame.columns[1] == "level_1" + + # specifying index_label + result = sql.to_sql( + temp_frame, + "test_index_label", + conn, + if_exists="replace", + index_label=["A", "B"], + ) + assert result == expected_row_count + frame = sql.read_sql_query("SELECT * FROM test_index_label", conn) + assert frame.columns[:2].tolist() == ["A", "B"] + + # using the index name + temp_frame.index.names = ["A", "B"] + result = sql.to_sql(temp_frame, "test_index_label", conn, if_exists="replace") + assert result == expected_row_count + frame = sql.read_sql_query("SELECT * FROM test_index_label", conn) + assert frame.columns[:2].tolist() == ["A", "B"] + + # has index name, but specifying index_label + result = sql.to_sql( + temp_frame, + "test_index_label", + conn, + if_exists="replace", + index_label=["C", "D"], + ) + assert result == expected_row_count + frame = sql.read_sql_query("SELECT * FROM test_index_label", conn) + assert frame.columns[:2].tolist() == ["C", "D"] + + msg = "Length of 'index_label' should match number of levels, which is 2" + with pytest.raises(ValueError, match=msg): + sql.to_sql( + temp_frame, + "test_index_label", + conn, + if_exists="replace", + index_label="C", + ) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_api_multiindex_roundtrip(conn, request): + conn = request.getfixturevalue(conn) + if sql.has_table("test_multiindex_roundtrip", conn): + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("test_multiindex_roundtrip") + + df = DataFrame.from_records( + [(1, 2.1, "line1"), (2, 1.5, "line2")], + columns=["A", "B", "C"], + index=["A", "B"], + ) + + df.to_sql(name="test_multiindex_roundtrip", con=conn) + result = sql.read_sql_query( + "SELECT * FROM test_multiindex_roundtrip", conn, index_col=["A", "B"] + ) + tm.assert_frame_equal(df, result, check_index_type=True) + + +@pytest.mark.parametrize("conn", all_connectable) +@pytest.mark.parametrize( + "dtype", + [ + None, + int, + float, + {"A": int, "B": float}, + ], +) +def test_api_dtype_argument(conn, request, dtype): + # GH10285 Add dtype argument to read_sql_query + conn_name = conn + conn = request.getfixturevalue(conn) + if sql.has_table("test_dtype_argument", conn): + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("test_dtype_argument") + + df = DataFrame([[1.2, 3.4], [5.6, 7.8]], columns=["A", "B"]) + assert df.to_sql(name="test_dtype_argument", con=conn) == 2 + + expected = df.astype(dtype) + + if "postgres" in conn_name: + query = 'SELECT "A", "B" FROM test_dtype_argument' + else: + query = "SELECT A, B FROM test_dtype_argument" + result = sql.read_sql_query(query, con=conn, dtype=dtype) + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_api_integer_col_names(conn, request): + conn = request.getfixturevalue(conn) + df = DataFrame([[1, 2], [3, 4]], columns=[0, 1]) + sql.to_sql(df, "test_frame_integer_col_names", conn, if_exists="replace") + + +@pytest.mark.parametrize("conn", all_connectable) +def test_api_get_schema(conn, request, test_frame1): + if "adbc" in conn: + request.node.add_marker( + pytest.mark.xfail( + reason="'get_schema' not implemented for ADBC drivers", + strict=True, + ) + ) + conn = request.getfixturevalue(conn) + create_sql = sql.get_schema(test_frame1, "test", con=conn) + assert "CREATE" in create_sql + + +@pytest.mark.parametrize("conn", all_connectable) +def test_api_get_schema_with_schema(conn, request, test_frame1): + # GH28486 + if "adbc" in conn: + request.node.add_marker( + pytest.mark.xfail( + reason="'get_schema' not implemented for ADBC drivers", + strict=True, + ) + ) + conn = request.getfixturevalue(conn) + create_sql = sql.get_schema(test_frame1, "test", con=conn, schema="pypi") + assert "CREATE TABLE pypi." in create_sql + + +@pytest.mark.parametrize("conn", all_connectable) +def test_api_get_schema_dtypes(conn, request): + if "adbc" in conn: + request.node.add_marker( + pytest.mark.xfail( + reason="'get_schema' not implemented for ADBC drivers", + strict=True, + ) + ) + conn_name = conn + conn = request.getfixturevalue(conn) + float_frame = DataFrame({"a": [1.1, 1.2], "b": [2.1, 2.2]}) + + if conn_name == "sqlite_buildin": + dtype = "INTEGER" + else: + from sqlalchemy import Integer + + dtype = Integer + create_sql = sql.get_schema(float_frame, "test", con=conn, dtype={"b": dtype}) + assert "CREATE" in create_sql + assert "INTEGER" in create_sql + + +@pytest.mark.parametrize("conn", all_connectable) +def test_api_get_schema_keys(conn, request, test_frame1): + if "adbc" in conn: + request.node.add_marker( + pytest.mark.xfail( + reason="'get_schema' not implemented for ADBC drivers", + strict=True, + ) + ) + conn_name = conn + conn = request.getfixturevalue(conn) + frame = DataFrame({"Col1": [1.1, 1.2], "Col2": [2.1, 2.2]}) + create_sql = sql.get_schema(frame, "test", con=conn, keys="Col1") + + if "mysql" in conn_name: + constraint_sentence = "CONSTRAINT test_pk PRIMARY KEY (`Col1`)" + else: + constraint_sentence = 'CONSTRAINT test_pk PRIMARY KEY ("Col1")' + assert constraint_sentence in create_sql + + # multiple columns as key (GH10385) + create_sql = sql.get_schema(test_frame1, "test", con=conn, keys=["A", "B"]) + if "mysql" in conn_name: + constraint_sentence = "CONSTRAINT test_pk PRIMARY KEY (`A`, `B`)" + else: + constraint_sentence = 'CONSTRAINT test_pk PRIMARY KEY ("A", "B")' + assert constraint_sentence in create_sql + + +@pytest.mark.parametrize("conn", all_connectable) +def test_api_chunksize_read(conn, request): + if "adbc" in conn: + request.node.add_marker( + pytest.mark.xfail(reason="chunksize argument NotImplemented with ADBC") + ) + conn_name = conn + conn = request.getfixturevalue(conn) + if sql.has_table("test_chunksize", conn): + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("test_chunksize") + + df = DataFrame( + np.random.default_rng(2).standard_normal((22, 5)), columns=list("abcde") + ) + df.to_sql(name="test_chunksize", con=conn, index=False) + + # reading the query in one time + res1 = sql.read_sql_query("select * from test_chunksize", conn) + + # reading the query in chunks with read_sql_query + res2 = DataFrame() + i = 0 + sizes = [5, 5, 5, 5, 2] + + for chunk in sql.read_sql_query("select * from test_chunksize", conn, chunksize=5): + res2 = concat([res2, chunk], ignore_index=True) + assert len(chunk) == sizes[i] + i += 1 + + tm.assert_frame_equal(res1, res2) + + # reading the query in chunks with read_sql_query + if conn_name == "sqlite_buildin": + with pytest.raises(NotImplementedError, match=""): + sql.read_sql_table("test_chunksize", conn, chunksize=5) + else: + res3 = DataFrame() + i = 0 + sizes = [5, 5, 5, 5, 2] + + for chunk in sql.read_sql_table("test_chunksize", conn, chunksize=5): + res3 = concat([res3, chunk], ignore_index=True) + assert len(chunk) == sizes[i] + i += 1 + + tm.assert_frame_equal(res1, res3) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_api_categorical(conn, request): + if conn == "postgresql_adbc_conn": + adbc = import_optional_dependency("adbc_driver_postgresql", errors="ignore") + if adbc is not None and Version(adbc.__version__) < Version("0.9.0"): + request.node.add_marker( + pytest.mark.xfail( + reason="categorical dtype not implemented for ADBC postgres driver", + strict=True, + ) + ) + # GH8624 + # test that categorical gets written correctly as dense column + conn = request.getfixturevalue(conn) + if sql.has_table("test_categorical", conn): + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("test_categorical") + + df = DataFrame( + { + "person_id": [1, 2, 3], + "person_name": ["John P. Doe", "Jane Dove", "John P. Doe"], + } + ) + df2 = df.copy() + df2["person_name"] = df2["person_name"].astype("category") + + df2.to_sql(name="test_categorical", con=conn, index=False) + res = sql.read_sql_query("SELECT * FROM test_categorical", conn) + + tm.assert_frame_equal(res, df) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_api_unicode_column_name(conn, request): + # GH 11431 + conn = request.getfixturevalue(conn) + if sql.has_table("test_unicode", conn): + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("test_unicode") + + df = DataFrame([[1, 2], [3, 4]], columns=["\xe9", "b"]) + df.to_sql(name="test_unicode", con=conn, index=False) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_api_escaped_table_name(conn, request): + # GH 13206 + conn_name = conn + conn = request.getfixturevalue(conn) + if sql.has_table("d1187b08-4943-4c8d-a7f6", conn): + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("d1187b08-4943-4c8d-a7f6") + + df = DataFrame({"A": [0, 1, 2], "B": [0.2, np.nan, 5.6]}) + df.to_sql(name="d1187b08-4943-4c8d-a7f6", con=conn, index=False) + + if "postgres" in conn_name: + query = 'SELECT * FROM "d1187b08-4943-4c8d-a7f6"' + else: + query = "SELECT * FROM `d1187b08-4943-4c8d-a7f6`" + res = sql.read_sql_query(query, conn) + + tm.assert_frame_equal(res, df) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_api_read_sql_duplicate_columns(conn, request): + # GH#53117 + if "adbc" in conn: + pa = pytest.importorskip("pyarrow") + if not ( + Version(pa.__version__) >= Version("16.0") + and conn in ["sqlite_adbc_conn", "postgresql_adbc_conn"] + ): + request.node.add_marker( + pytest.mark.xfail( + reason="pyarrow->pandas throws ValueError", strict=True + ) + ) + conn = request.getfixturevalue(conn) + if sql.has_table("test_table", conn): + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("test_table") + + df = DataFrame({"a": [1, 2, 3], "b": [0.1, 0.2, 0.3], "c": 1}) + df.to_sql(name="test_table", con=conn, index=False) + + result = pd.read_sql("SELECT a, b, a +1 as a, c FROM test_table", conn) + expected = DataFrame( + [[1, 0.1, 2, 1], [2, 0.2, 3, 1], [3, 0.3, 4, 1]], + columns=["a", "b", "a", "c"], + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_read_table_columns(conn, request, test_frame1): + # test columns argument in read_table + conn_name = conn + if conn_name == "sqlite_buildin": + request.applymarker(pytest.mark.xfail(reason="Not Implemented")) + + conn = request.getfixturevalue(conn) + sql.to_sql(test_frame1, "test_frame", conn) + + cols = ["A", "B"] + + result = sql.read_sql_table("test_frame", conn, columns=cols) + assert result.columns.tolist() == cols + + +@pytest.mark.parametrize("conn", all_connectable) +def test_read_table_index_col(conn, request, test_frame1): + # test columns argument in read_table + conn_name = conn + if conn_name == "sqlite_buildin": + request.applymarker(pytest.mark.xfail(reason="Not Implemented")) + + conn = request.getfixturevalue(conn) + sql.to_sql(test_frame1, "test_frame", conn) + + result = sql.read_sql_table("test_frame", conn, index_col="index") + assert result.index.names == ["index"] + + result = sql.read_sql_table("test_frame", conn, index_col=["A", "B"]) + assert result.index.names == ["A", "B"] + + result = sql.read_sql_table( + "test_frame", conn, index_col=["A", "B"], columns=["C", "D"] + ) + assert result.index.names == ["A", "B"] + assert result.columns.tolist() == ["C", "D"] + + +@pytest.mark.parametrize("conn", all_connectable_iris) +def test_read_sql_delegate(conn, request): + if conn == "sqlite_buildin_iris": + request.applymarker( + pytest.mark.xfail( + reason="sqlite_buildin connection does not implement read_sql_table" + ) + ) + + conn = request.getfixturevalue(conn) + iris_frame1 = sql.read_sql_query("SELECT * FROM iris", conn) + iris_frame2 = sql.read_sql("SELECT * FROM iris", conn) + tm.assert_frame_equal(iris_frame1, iris_frame2) + + iris_frame1 = sql.read_sql_table("iris", conn) + iris_frame2 = sql.read_sql("iris", conn) + tm.assert_frame_equal(iris_frame1, iris_frame2) + + +def test_not_reflect_all_tables(sqlite_conn): + conn = sqlite_conn + from sqlalchemy import text + from sqlalchemy.engine import Engine + + # create invalid table + query_list = [ + text("CREATE TABLE invalid (x INTEGER, y UNKNOWN);"), + text("CREATE TABLE other_table (x INTEGER, y INTEGER);"), + ] + + for query in query_list: + if isinstance(conn, Engine): + with conn.connect() as conn: + with conn.begin(): + conn.execute(query) + else: + with conn.begin(): + conn.execute(query) + + with tm.assert_produces_warning(None): + sql.read_sql_table("other_table", conn) + sql.read_sql_query("SELECT * FROM other_table", conn) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_warning_case_insensitive_table_name(conn, request, test_frame1): + conn_name = conn + if conn_name == "sqlite_buildin" or "adbc" in conn_name: + request.applymarker(pytest.mark.xfail(reason="Does not raise warning")) + + conn = request.getfixturevalue(conn) + # see gh-7815 + with tm.assert_produces_warning( + UserWarning, + match=( + r"The provided table name 'TABLE1' is not found exactly as such in " + r"the database after writing the table, possibly due to case " + r"sensitivity issues. Consider using lower case table names." + ), + ): + with sql.SQLDatabase(conn) as db: + db.check_case_sensitive("TABLE1", "") + + # Test that the warning is certainly NOT triggered in a normal case. + with tm.assert_produces_warning(None): + test_frame1.to_sql(name="CaseSensitive", con=conn) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_sqlalchemy_type_mapping(conn, request): + conn = request.getfixturevalue(conn) + from sqlalchemy import TIMESTAMP + + # Test Timestamp objects (no datetime64 because of timezone) (GH9085) + df = DataFrame( + {"time": to_datetime(["2014-12-12 01:54", "2014-12-11 02:54"], utc=True)} + ) + with sql.SQLDatabase(conn) as db: + table = sql.SQLTable("test_type", db, frame=df) + # GH 9086: TIMESTAMP is the suggested type for datetimes with timezones + assert isinstance(table.table.c["time"].type, TIMESTAMP) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +@pytest.mark.parametrize( + "integer, expected", + [ + ("int8", "SMALLINT"), + ("Int8", "SMALLINT"), + ("uint8", "SMALLINT"), + ("UInt8", "SMALLINT"), + ("int16", "SMALLINT"), + ("Int16", "SMALLINT"), + ("uint16", "INTEGER"), + ("UInt16", "INTEGER"), + ("int32", "INTEGER"), + ("Int32", "INTEGER"), + ("uint32", "BIGINT"), + ("UInt32", "BIGINT"), + ("int64", "BIGINT"), + ("Int64", "BIGINT"), + (int, "BIGINT" if np.dtype(int).name == "int64" else "INTEGER"), + ], +) +def test_sqlalchemy_integer_mapping(conn, request, integer, expected): + # GH35076 Map pandas integer to optimal SQLAlchemy integer type + conn = request.getfixturevalue(conn) + df = DataFrame([0, 1], columns=["a"], dtype=integer) + with sql.SQLDatabase(conn) as db: + table = sql.SQLTable("test_type", db, frame=df) + + result = str(table.table.c.a.type) + assert result == expected + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +@pytest.mark.parametrize("integer", ["uint64", "UInt64"]) +def test_sqlalchemy_integer_overload_mapping(conn, request, integer): + conn = request.getfixturevalue(conn) + # GH35076 Map pandas integer to optimal SQLAlchemy integer type + df = DataFrame([0, 1], columns=["a"], dtype=integer) + with sql.SQLDatabase(conn) as db: + with pytest.raises( + ValueError, match="Unsigned 64 bit integer datatype is not supported" + ): + sql.SQLTable("test_type", db, frame=df) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_database_uri_string(conn, request, test_frame1): + pytest.importorskip("sqlalchemy") + conn = request.getfixturevalue(conn) + # Test read_sql and .to_sql method with a database URI (GH10654) + # db_uri = 'sqlite:///:memory:' # raises + # sqlalchemy.exc.OperationalError: (sqlite3.OperationalError) near + # "iris": syntax error [SQL: 'iris'] + with tm.ensure_clean() as name: + db_uri = "sqlite:///" + name + table = "iris" + test_frame1.to_sql(name=table, con=db_uri, if_exists="replace", index=False) + test_frame2 = sql.read_sql(table, db_uri) + test_frame3 = sql.read_sql_table(table, db_uri) + query = "SELECT * FROM iris" + test_frame4 = sql.read_sql_query(query, db_uri) + tm.assert_frame_equal(test_frame1, test_frame2) + tm.assert_frame_equal(test_frame1, test_frame3) + tm.assert_frame_equal(test_frame1, test_frame4) + + +@td.skip_if_installed("pg8000") +@pytest.mark.parametrize("conn", all_connectable) +def test_pg8000_sqlalchemy_passthrough_error(conn, request): + pytest.importorskip("sqlalchemy") + conn = request.getfixturevalue(conn) + # using driver that will not be installed on CI to trigger error + # in sqlalchemy.create_engine -> test passing of this error to user + db_uri = "postgresql+pg8000://user:pass@host/dbname" + with pytest.raises(ImportError, match="pg8000"): + sql.read_sql("select * from table", db_uri) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable_iris) +def test_query_by_text_obj(conn, request): + # WIP : GH10846 + conn_name = conn + conn = request.getfixturevalue(conn) + from sqlalchemy import text + + if "postgres" in conn_name: + name_text = text('select * from iris where "Name"=:name') + else: + name_text = text("select * from iris where name=:name") + iris_df = sql.read_sql(name_text, conn, params={"name": "Iris-versicolor"}) + all_names = set(iris_df["Name"]) + assert all_names == {"Iris-versicolor"} + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable_iris) +def test_query_by_select_obj(conn, request): + conn = request.getfixturevalue(conn) + # WIP : GH10846 + from sqlalchemy import ( + bindparam, + select, + ) + + iris = iris_table_metadata() + name_select = select(iris).where(iris.c.Name == bindparam("name")) + iris_df = sql.read_sql(name_select, conn, params={"name": "Iris-setosa"}) + all_names = set(iris_df["Name"]) + assert all_names == {"Iris-setosa"} + + +@pytest.mark.parametrize("conn", all_connectable) +def test_column_with_percentage(conn, request): + # GH 37157 + conn_name = conn + if conn_name == "sqlite_buildin": + request.applymarker(pytest.mark.xfail(reason="Not Implemented")) + + conn = request.getfixturevalue(conn) + df = DataFrame({"A": [0, 1, 2], "%_variation": [3, 4, 5]}) + df.to_sql(name="test_column_percentage", con=conn, index=False) + + res = sql.read_sql_table("test_column_percentage", conn) + + tm.assert_frame_equal(res, df) + + +def test_sql_open_close(test_frame3): + # Test if the IO in the database still work if the connection closed + # between the writing and reading (as in many real situations). + + with tm.ensure_clean() as name: + with closing(sqlite3.connect(name)) as conn: + assert sql.to_sql(test_frame3, "test_frame3_legacy", conn, index=False) == 4 + + with closing(sqlite3.connect(name)) as conn: + result = sql.read_sql_query("SELECT * FROM test_frame3_legacy;", conn) + + tm.assert_frame_equal(test_frame3, result) + + +@td.skip_if_installed("sqlalchemy") +def test_con_string_import_error(): + conn = "mysql://root@localhost/pandas" + msg = "Using URI string without sqlalchemy installed" + with pytest.raises(ImportError, match=msg): + sql.read_sql("SELECT * FROM iris", conn) + + +@td.skip_if_installed("sqlalchemy") +def test_con_unknown_dbapi2_class_does_not_error_without_sql_alchemy_installed(): + class MockSqliteConnection: + def __init__(self, *args, **kwargs) -> None: + self.conn = sqlite3.Connection(*args, **kwargs) + + def __getattr__(self, name): + return getattr(self.conn, name) + + def close(self): + self.conn.close() + + with contextlib.closing(MockSqliteConnection(":memory:")) as conn: + with tm.assert_produces_warning(UserWarning): + sql.read_sql("SELECT 1", conn) + + +def test_sqlite_read_sql_delegate(sqlite_buildin_iris): + conn = sqlite_buildin_iris + iris_frame1 = sql.read_sql_query("SELECT * FROM iris", conn) + iris_frame2 = sql.read_sql("SELECT * FROM iris", conn) + tm.assert_frame_equal(iris_frame1, iris_frame2) + + msg = "Execution failed on sql 'iris': near \"iris\": syntax error" + with pytest.raises(sql.DatabaseError, match=msg): + sql.read_sql("iris", conn) + + +def test_get_schema2(test_frame1): + # without providing a connection object (available for backwards comp) + create_sql = sql.get_schema(test_frame1, "test") + assert "CREATE" in create_sql + + +def test_sqlite_type_mapping(sqlite_buildin): + # Test Timestamp objects (no datetime64 because of timezone) (GH9085) + conn = sqlite_buildin + df = DataFrame( + {"time": to_datetime(["2014-12-12 01:54", "2014-12-11 02:54"], utc=True)} + ) + db = sql.SQLiteDatabase(conn) + table = sql.SQLiteTable("test_type", db, frame=df) + schema = table.sql_schema() + for col in schema.split("\n"): + if col.split()[0].strip('"') == "time": + assert col.split()[1] == "TIMESTAMP" + + +# ----------------------------------------------------------------------------- +# -- Database flavor specific tests + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_create_table(conn, request): + if conn == "sqlite_str": + pytest.skip("sqlite_str has no inspection system") + + conn = request.getfixturevalue(conn) + + from sqlalchemy import inspect + + temp_frame = DataFrame({"one": [1.0, 2.0, 3.0, 4.0], "two": [4.0, 3.0, 2.0, 1.0]}) + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + assert pandasSQL.to_sql(temp_frame, "temp_frame") == 4 + + insp = inspect(conn) + assert insp.has_table("temp_frame") + + # Cleanup + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("temp_frame") + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_drop_table(conn, request): + if conn == "sqlite_str": + pytest.skip("sqlite_str has no inspection system") + + conn = request.getfixturevalue(conn) + + from sqlalchemy import inspect + + temp_frame = DataFrame({"one": [1.0, 2.0, 3.0, 4.0], "two": [4.0, 3.0, 2.0, 1.0]}) + with sql.SQLDatabase(conn) as pandasSQL: + with pandasSQL.run_transaction(): + assert pandasSQL.to_sql(temp_frame, "temp_frame") == 4 + + insp = inspect(conn) + assert insp.has_table("temp_frame") + + with pandasSQL.run_transaction(): + pandasSQL.drop_table("temp_frame") + try: + insp.clear_cache() # needed with SQLAlchemy 2.0, unavailable prior + except AttributeError: + pass + assert not insp.has_table("temp_frame") + + +@pytest.mark.parametrize("conn", all_connectable) +def test_roundtrip(conn, request, test_frame1): + if conn == "sqlite_str": + pytest.skip("sqlite_str has no inspection system") + + conn_name = conn + conn = request.getfixturevalue(conn) + pandasSQL = pandasSQL_builder(conn) + with pandasSQL.run_transaction(): + assert pandasSQL.to_sql(test_frame1, "test_frame_roundtrip") == 4 + result = pandasSQL.read_query("SELECT * FROM test_frame_roundtrip") + + if "adbc" in conn_name: + result = result.rename(columns={"__index_level_0__": "level_0"}) + result.set_index("level_0", inplace=True) + # result.index.astype(int) + + result.index.name = None + + tm.assert_frame_equal(result, test_frame1) + + +@pytest.mark.parametrize("conn", all_connectable_iris) +def test_execute_sql(conn, request): + conn = request.getfixturevalue(conn) + with pandasSQL_builder(conn) as pandasSQL: + with pandasSQL.run_transaction(): + iris_results = pandasSQL.execute("SELECT * FROM iris") + row = iris_results.fetchone() + iris_results.close() + assert list(row) == [5.1, 3.5, 1.4, 0.2, "Iris-setosa"] + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable_iris) +def test_sqlalchemy_read_table(conn, request): + conn = request.getfixturevalue(conn) + iris_frame = sql.read_sql_table("iris", con=conn) + check_iris_frame(iris_frame) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable_iris) +def test_sqlalchemy_read_table_columns(conn, request): + conn = request.getfixturevalue(conn) + iris_frame = sql.read_sql_table( + "iris", con=conn, columns=["SepalLength", "SepalLength"] + ) + tm.assert_index_equal(iris_frame.columns, Index(["SepalLength", "SepalLength__1"])) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable_iris) +def test_read_table_absent_raises(conn, request): + conn = request.getfixturevalue(conn) + msg = "Table this_doesnt_exist not found" + with pytest.raises(ValueError, match=msg): + sql.read_sql_table("this_doesnt_exist", con=conn) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable_types) +def test_sqlalchemy_default_type_conversion(conn, request): + conn_name = conn + if conn_name == "sqlite_str": + pytest.skip("types tables not created in sqlite_str fixture") + elif "mysql" in conn_name or "sqlite" in conn_name: + request.applymarker( + pytest.mark.xfail(reason="boolean dtype not inferred properly") + ) + + conn = request.getfixturevalue(conn) + df = sql.read_sql_table("types", conn) + + assert issubclass(df.FloatCol.dtype.type, np.floating) + assert issubclass(df.IntCol.dtype.type, np.integer) + assert issubclass(df.BoolCol.dtype.type, np.bool_) + + # Int column with NA values stays as float + assert issubclass(df.IntColWithNull.dtype.type, np.floating) + # Bool column with NA values becomes object + assert issubclass(df.BoolColWithNull.dtype.type, object) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_bigint(conn, request): + # int64 should be converted to BigInteger, GH7433 + conn = request.getfixturevalue(conn) + df = DataFrame(data={"i64": [2**62]}) + assert df.to_sql(name="test_bigint", con=conn, index=False) == 1 + result = sql.read_sql_table("test_bigint", conn) + + tm.assert_frame_equal(df, result) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable_types) +def test_default_date_load(conn, request): + conn_name = conn + if conn_name == "sqlite_str": + pytest.skip("types tables not created in sqlite_str fixture") + elif "sqlite" in conn_name: + request.applymarker( + pytest.mark.xfail(reason="sqlite does not read date properly") + ) + + conn = request.getfixturevalue(conn) + df = sql.read_sql_table("types", conn) + + assert issubclass(df.DateCol.dtype.type, np.datetime64) + + +@pytest.mark.parametrize("conn", postgresql_connectable) +@pytest.mark.parametrize("parse_dates", [None, ["DateColWithTz"]]) +def test_datetime_with_timezone_query(conn, request, parse_dates): + # edge case that converts postgresql datetime with time zone types + # to datetime64[ns,psycopg2.tz.FixedOffsetTimezone..], which is ok + # but should be more natural, so coerce to datetime64[ns] for now + conn = request.getfixturevalue(conn) + expected = create_and_load_postgres_datetz(conn) + + # GH11216 + df = read_sql_query("select * from datetz", conn, parse_dates=parse_dates) + col = df.DateColWithTz + tm.assert_series_equal(col, expected) + + +@pytest.mark.parametrize("conn", postgresql_connectable) +def test_datetime_with_timezone_query_chunksize(conn, request): + conn = request.getfixturevalue(conn) + expected = create_and_load_postgres_datetz(conn) + + df = concat( + list(read_sql_query("select * from datetz", conn, chunksize=1)), + ignore_index=True, + ) + col = df.DateColWithTz + tm.assert_series_equal(col, expected) + + +@pytest.mark.parametrize("conn", postgresql_connectable) +def test_datetime_with_timezone_table(conn, request): + conn = request.getfixturevalue(conn) + expected = create_and_load_postgres_datetz(conn) + result = sql.read_sql_table("datetz", conn) + tm.assert_frame_equal(result, expected.to_frame()) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_datetime_with_timezone_roundtrip(conn, request): + conn_name = conn + conn = request.getfixturevalue(conn) + # GH 9086 + # Write datetimetz data to a db and read it back + # For dbs that support timestamps with timezones, should get back UTC + # otherwise naive data should be returned + expected = DataFrame( + {"A": date_range("2013-01-01 09:00:00", periods=3, tz="US/Pacific")} + ) + assert expected.to_sql(name="test_datetime_tz", con=conn, index=False) == 3 + + if "postgresql" in conn_name: + # SQLAlchemy "timezones" (i.e. offsets) are coerced to UTC + expected["A"] = expected["A"].dt.tz_convert("UTC") + else: + # Otherwise, timestamps are returned as local, naive + expected["A"] = expected["A"].dt.tz_localize(None) + + result = sql.read_sql_table("test_datetime_tz", conn) + tm.assert_frame_equal(result, expected) + + result = sql.read_sql_query("SELECT * FROM test_datetime_tz", conn) + if "sqlite" in conn_name: + # read_sql_query does not return datetime type like read_sql_table + assert isinstance(result.loc[0, "A"], str) + result["A"] = to_datetime(result["A"]) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_out_of_bounds_datetime(conn, request): + # GH 26761 + conn = request.getfixturevalue(conn) + data = DataFrame({"date": datetime(9999, 1, 1)}, index=[0]) + assert data.to_sql(name="test_datetime_obb", con=conn, index=False) == 1 + result = sql.read_sql_table("test_datetime_obb", conn) + expected = DataFrame([pd.NaT], columns=["date"]) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_naive_datetimeindex_roundtrip(conn, request): + # GH 23510 + # Ensure that a naive DatetimeIndex isn't converted to UTC + conn = request.getfixturevalue(conn) + dates = date_range("2018-01-01", periods=5, freq="6h")._with_freq(None) + expected = DataFrame({"nums": range(5)}, index=dates) + assert expected.to_sql(name="foo_table", con=conn, index_label="info_date") == 5 + result = sql.read_sql_table("foo_table", conn, index_col="info_date") + # result index with gain a name from a set_index operation; expected + tm.assert_frame_equal(result, expected, check_names=False) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable_types) +def test_date_parsing(conn, request): + # No Parsing + conn_name = conn + conn = request.getfixturevalue(conn) + df = sql.read_sql_table("types", conn) + expected_type = object if "sqlite" in conn_name else np.datetime64 + assert issubclass(df.DateCol.dtype.type, expected_type) + + df = sql.read_sql_table("types", conn, parse_dates=["DateCol"]) + assert issubclass(df.DateCol.dtype.type, np.datetime64) + + df = sql.read_sql_table("types", conn, parse_dates={"DateCol": "%Y-%m-%d %H:%M:%S"}) + assert issubclass(df.DateCol.dtype.type, np.datetime64) + + df = sql.read_sql_table( + "types", + conn, + parse_dates={"DateCol": {"format": "%Y-%m-%d %H:%M:%S"}}, + ) + assert issubclass(df.DateCol.dtype.type, np.datetime64) + + df = sql.read_sql_table("types", conn, parse_dates=["IntDateCol"]) + assert issubclass(df.IntDateCol.dtype.type, np.datetime64) + + df = sql.read_sql_table("types", conn, parse_dates={"IntDateCol": "s"}) + assert issubclass(df.IntDateCol.dtype.type, np.datetime64) + + df = sql.read_sql_table("types", conn, parse_dates={"IntDateCol": {"unit": "s"}}) + assert issubclass(df.IntDateCol.dtype.type, np.datetime64) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_datetime(conn, request): + conn_name = conn + conn = request.getfixturevalue(conn) + df = DataFrame( + {"A": date_range("2013-01-01 09:00:00", periods=3), "B": np.arange(3.0)} + ) + assert df.to_sql(name="test_datetime", con=conn) == 3 + + # with read_table -> type information from schema used + result = sql.read_sql_table("test_datetime", conn) + result = result.drop("index", axis=1) + tm.assert_frame_equal(result, df) + + # with read_sql -> no type information -> sqlite has no native + result = sql.read_sql_query("SELECT * FROM test_datetime", conn) + result = result.drop("index", axis=1) + if "sqlite" in conn_name: + assert isinstance(result.loc[0, "A"], str) + result["A"] = to_datetime(result["A"]) + tm.assert_frame_equal(result, df) + else: + tm.assert_frame_equal(result, df) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_datetime_NaT(conn, request): + conn_name = conn + conn = request.getfixturevalue(conn) + df = DataFrame( + {"A": date_range("2013-01-01 09:00:00", periods=3), "B": np.arange(3.0)} + ) + df.loc[1, "A"] = np.nan + assert df.to_sql(name="test_datetime", con=conn, index=False) == 3 + + # with read_table -> type information from schema used + result = sql.read_sql_table("test_datetime", conn) + tm.assert_frame_equal(result, df) + + # with read_sql -> no type information -> sqlite has no native + result = sql.read_sql_query("SELECT * FROM test_datetime", conn) + if "sqlite" in conn_name: + assert isinstance(result.loc[0, "A"], str) + result["A"] = to_datetime(result["A"], errors="coerce") + tm.assert_frame_equal(result, df) + else: + tm.assert_frame_equal(result, df) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_datetime_date(conn, request): + # test support for datetime.date + conn = request.getfixturevalue(conn) + df = DataFrame([date(2014, 1, 1), date(2014, 1, 2)], columns=["a"]) + assert df.to_sql(name="test_date", con=conn, index=False) == 2 + res = read_sql_table("test_date", conn) + result = res["a"] + expected = to_datetime(df["a"]) + # comes back as datetime64 + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_datetime_time(conn, request, sqlite_buildin): + # test support for datetime.time + conn_name = conn + conn = request.getfixturevalue(conn) + df = DataFrame([time(9, 0, 0), time(9, 1, 30)], columns=["a"]) + assert df.to_sql(name="test_time", con=conn, index=False) == 2 + res = read_sql_table("test_time", conn) + tm.assert_frame_equal(res, df) + + # GH8341 + # first, use the fallback to have the sqlite adapter put in place + sqlite_conn = sqlite_buildin + assert sql.to_sql(df, "test_time2", sqlite_conn, index=False) == 2 + res = sql.read_sql_query("SELECT * FROM test_time2", sqlite_conn) + ref = df.map(lambda _: _.strftime("%H:%M:%S.%f")) + tm.assert_frame_equal(ref, res) # check if adapter is in place + # then test if sqlalchemy is unaffected by the sqlite adapter + assert sql.to_sql(df, "test_time3", conn, index=False) == 2 + if "sqlite" in conn_name: + res = sql.read_sql_query("SELECT * FROM test_time3", conn) + ref = df.map(lambda _: _.strftime("%H:%M:%S.%f")) + tm.assert_frame_equal(ref, res) + res = sql.read_sql_table("test_time3", conn) + tm.assert_frame_equal(df, res) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_mixed_dtype_insert(conn, request): + # see GH6509 + conn = request.getfixturevalue(conn) + s1 = Series(2**25 + 1, dtype=np.int32) + s2 = Series(0.0, dtype=np.float32) + df = DataFrame({"s1": s1, "s2": s2}) + + # write and read again + assert df.to_sql(name="test_read_write", con=conn, index=False) == 1 + df2 = sql.read_sql_table("test_read_write", conn) + + tm.assert_frame_equal(df, df2, check_dtype=False, check_exact=True) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_nan_numeric(conn, request): + # NaNs in numeric float column + conn = request.getfixturevalue(conn) + df = DataFrame({"A": [0, 1, 2], "B": [0.2, np.nan, 5.6]}) + assert df.to_sql(name="test_nan", con=conn, index=False) == 3 + + # with read_table + result = sql.read_sql_table("test_nan", conn) + tm.assert_frame_equal(result, df) + + # with read_sql + result = sql.read_sql_query("SELECT * FROM test_nan", conn) + tm.assert_frame_equal(result, df) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_nan_fullcolumn(conn, request): + # full NaN column (numeric float column) + conn = request.getfixturevalue(conn) + df = DataFrame({"A": [0, 1, 2], "B": [np.nan, np.nan, np.nan]}) + assert df.to_sql(name="test_nan", con=conn, index=False) == 3 + + # with read_table + result = sql.read_sql_table("test_nan", conn) + tm.assert_frame_equal(result, df) + + # with read_sql -> not type info from table -> stays None + df["B"] = df["B"].astype("object") + df["B"] = None + result = sql.read_sql_query("SELECT * FROM test_nan", conn) + tm.assert_frame_equal(result, df) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_nan_string(conn, request): + # NaNs in string column + conn = request.getfixturevalue(conn) + df = DataFrame({"A": [0, 1, 2], "B": ["a", "b", np.nan]}) + assert df.to_sql(name="test_nan", con=conn, index=False) == 3 + + # NaNs are coming back as None + df.loc[2, "B"] = None + + # with read_table + result = sql.read_sql_table("test_nan", conn) + tm.assert_frame_equal(result, df) + + # with read_sql + result = sql.read_sql_query("SELECT * FROM test_nan", conn) + tm.assert_frame_equal(result, df) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_to_sql_save_index(conn, request): + if "adbc" in conn: + request.node.add_marker( + pytest.mark.xfail( + reason="ADBC implementation does not create index", strict=True + ) + ) + conn_name = conn + conn = request.getfixturevalue(conn) + df = DataFrame.from_records( + [(1, 2.1, "line1"), (2, 1.5, "line2")], columns=["A", "B", "C"], index=["A"] + ) + + tbl_name = "test_to_sql_saves_index" + with pandasSQL_builder(conn) as pandasSQL: + with pandasSQL.run_transaction(): + assert pandasSQL.to_sql(df, tbl_name) == 2 + + if conn_name in {"sqlite_buildin", "sqlite_str"}: + ixs = sql.read_sql_query( + "SELECT * FROM sqlite_master WHERE type = 'index' " + f"AND tbl_name = '{tbl_name}'", + conn, + ) + ix_cols = [] + for ix_name in ixs.name: + ix_info = sql.read_sql_query(f"PRAGMA index_info({ix_name})", conn) + ix_cols.append(ix_info.name.tolist()) + else: + from sqlalchemy import inspect + + insp = inspect(conn) + + ixs = insp.get_indexes(tbl_name) + ix_cols = [i["column_names"] for i in ixs] + + assert ix_cols == [["A"]] + + +@pytest.mark.parametrize("conn", all_connectable) +def test_transactions(conn, request): + conn_name = conn + conn = request.getfixturevalue(conn) + + stmt = "CREATE TABLE test_trans (A INT, B TEXT)" + if conn_name != "sqlite_buildin" and "adbc" not in conn_name: + from sqlalchemy import text + + stmt = text(stmt) + + with pandasSQL_builder(conn) as pandasSQL: + with pandasSQL.run_transaction() as trans: + trans.execute(stmt) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_transaction_rollback(conn, request): + conn_name = conn + conn = request.getfixturevalue(conn) + with pandasSQL_builder(conn) as pandasSQL: + with pandasSQL.run_transaction() as trans: + stmt = "CREATE TABLE test_trans (A INT, B TEXT)" + if "adbc" in conn_name or isinstance(pandasSQL, SQLiteDatabase): + trans.execute(stmt) + else: + from sqlalchemy import text + + stmt = text(stmt) + trans.execute(stmt) + + class DummyException(Exception): + pass + + # Make sure when transaction is rolled back, no rows get inserted + ins_sql = "INSERT INTO test_trans (A,B) VALUES (1, 'blah')" + if isinstance(pandasSQL, SQLDatabase): + from sqlalchemy import text + + ins_sql = text(ins_sql) + try: + with pandasSQL.run_transaction() as trans: + trans.execute(ins_sql) + raise DummyException("error") + except DummyException: + # ignore raised exception + pass + with pandasSQL.run_transaction(): + res = pandasSQL.read_query("SELECT * FROM test_trans") + assert len(res) == 0 + + # Make sure when transaction is committed, rows do get inserted + with pandasSQL.run_transaction() as trans: + trans.execute(ins_sql) + res2 = pandasSQL.read_query("SELECT * FROM test_trans") + assert len(res2) == 1 + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_get_schema_create_table(conn, request, test_frame3): + # Use a dataframe without a bool column, since MySQL converts bool to + # TINYINT (which read_sql_table returns as an int and causes a dtype + # mismatch) + if conn == "sqlite_str": + request.applymarker( + pytest.mark.xfail(reason="test does not support sqlite_str fixture") + ) + + conn = request.getfixturevalue(conn) + + from sqlalchemy import text + from sqlalchemy.engine import Engine + + tbl = "test_get_schema_create_table" + create_sql = sql.get_schema(test_frame3, tbl, con=conn) + blank_test_df = test_frame3.iloc[:0] + + create_sql = text(create_sql) + if isinstance(conn, Engine): + with conn.connect() as newcon: + with newcon.begin(): + newcon.execute(create_sql) + else: + conn.execute(create_sql) + returned_df = sql.read_sql_table(tbl, conn) + tm.assert_frame_equal(returned_df, blank_test_df, check_index_type=False) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_dtype(conn, request): + if conn == "sqlite_str": + pytest.skip("sqlite_str has no inspection system") + + conn = request.getfixturevalue(conn) + + from sqlalchemy import ( + TEXT, + String, + ) + from sqlalchemy.schema import MetaData + + cols = ["A", "B"] + data = [(0.8, True), (0.9, None)] + df = DataFrame(data, columns=cols) + assert df.to_sql(name="dtype_test", con=conn) == 2 + assert df.to_sql(name="dtype_test2", con=conn, dtype={"B": TEXT}) == 2 + meta = MetaData() + meta.reflect(bind=conn) + sqltype = meta.tables["dtype_test2"].columns["B"].type + assert isinstance(sqltype, TEXT) + msg = "The type of B is not a SQLAlchemy type" + with pytest.raises(ValueError, match=msg): + df.to_sql(name="error", con=conn, dtype={"B": str}) + + # GH9083 + assert df.to_sql(name="dtype_test3", con=conn, dtype={"B": String(10)}) == 2 + meta.reflect(bind=conn) + sqltype = meta.tables["dtype_test3"].columns["B"].type + assert isinstance(sqltype, String) + assert sqltype.length == 10 + + # single dtype + assert df.to_sql(name="single_dtype_test", con=conn, dtype=TEXT) == 2 + meta.reflect(bind=conn) + sqltypea = meta.tables["single_dtype_test"].columns["A"].type + sqltypeb = meta.tables["single_dtype_test"].columns["B"].type + assert isinstance(sqltypea, TEXT) + assert isinstance(sqltypeb, TEXT) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_notna_dtype(conn, request): + if conn == "sqlite_str": + pytest.skip("sqlite_str has no inspection system") + + conn_name = conn + conn = request.getfixturevalue(conn) + + from sqlalchemy import ( + Boolean, + DateTime, + Float, + Integer, + ) + from sqlalchemy.schema import MetaData + + cols = { + "Bool": Series([True, None]), + "Date": Series([datetime(2012, 5, 1), None]), + "Int": Series([1, None], dtype="object"), + "Float": Series([1.1, None]), + } + df = DataFrame(cols) + + tbl = "notna_dtype_test" + assert df.to_sql(name=tbl, con=conn) == 2 + _ = sql.read_sql_table(tbl, conn) + meta = MetaData() + meta.reflect(bind=conn) + my_type = Integer if "mysql" in conn_name else Boolean + col_dict = meta.tables[tbl].columns + assert isinstance(col_dict["Bool"].type, my_type) + assert isinstance(col_dict["Date"].type, DateTime) + assert isinstance(col_dict["Int"].type, Integer) + assert isinstance(col_dict["Float"].type, Float) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_double_precision(conn, request): + if conn == "sqlite_str": + pytest.skip("sqlite_str has no inspection system") + + conn = request.getfixturevalue(conn) + + from sqlalchemy import ( + BigInteger, + Float, + Integer, + ) + from sqlalchemy.schema import MetaData + + V = 1.23456789101112131415 + + df = DataFrame( + { + "f32": Series([V], dtype="float32"), + "f64": Series([V], dtype="float64"), + "f64_as_f32": Series([V], dtype="float64"), + "i32": Series([5], dtype="int32"), + "i64": Series([5], dtype="int64"), + } + ) + + assert ( + df.to_sql( + name="test_dtypes", + con=conn, + index=False, + if_exists="replace", + dtype={"f64_as_f32": Float(precision=23)}, + ) + == 1 + ) + res = sql.read_sql_table("test_dtypes", conn) + + # check precision of float64 + assert np.round(df["f64"].iloc[0], 14) == np.round(res["f64"].iloc[0], 14) + + # check sql types + meta = MetaData() + meta.reflect(bind=conn) + col_dict = meta.tables["test_dtypes"].columns + assert str(col_dict["f32"].type) == str(col_dict["f64_as_f32"].type) + assert isinstance(col_dict["f32"].type, Float) + assert isinstance(col_dict["f64"].type, Float) + assert isinstance(col_dict["i32"].type, Integer) + assert isinstance(col_dict["i64"].type, BigInteger) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_connectable_issue_example(conn, request): + conn = request.getfixturevalue(conn) + + # This tests the example raised in issue + # https://github.com/pandas-dev/pandas/issues/10104 + from sqlalchemy.engine import Engine + + def test_select(connection): + query = "SELECT test_foo_data FROM test_foo_data" + return sql.read_sql_query(query, con=connection) + + def test_append(connection, data): + data.to_sql(name="test_foo_data", con=connection, if_exists="append") + + def test_connectable(conn): + # https://github.com/sqlalchemy/sqlalchemy/commit/ + # 00b5c10846e800304caa86549ab9da373b42fa5d#r48323973 + foo_data = test_select(conn) + test_append(conn, foo_data) + + def main(connectable): + if isinstance(connectable, Engine): + with connectable.connect() as conn: + with conn.begin(): + test_connectable(conn) + else: + test_connectable(connectable) + + assert ( + DataFrame({"test_foo_data": [0, 1, 2]}).to_sql(name="test_foo_data", con=conn) + == 3 + ) + main(conn) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +@pytest.mark.parametrize( + "input", + [{"foo": [np.inf]}, {"foo": [-np.inf]}, {"foo": [-np.inf], "infe0": ["bar"]}], +) +def test_to_sql_with_negative_npinf(conn, request, input): + # GH 34431 + + df = DataFrame(input) + conn_name = conn + conn = request.getfixturevalue(conn) + + if "mysql" in conn_name: + # GH 36465 + # The input {"foo": [-np.inf], "infe0": ["bar"]} does not raise any error + # for pymysql version >= 0.10 + # TODO(GH#36465): remove this version check after GH 36465 is fixed + pymysql = pytest.importorskip("pymysql") + + if Version(pymysql.__version__) < Version("1.0.3") and "infe0" in df.columns: + mark = pytest.mark.xfail(reason="GH 36465") + request.applymarker(mark) + + msg = "inf cannot be used with MySQL" + with pytest.raises(ValueError, match=msg): + df.to_sql(name="foobar", con=conn, index=False) + else: + assert df.to_sql(name="foobar", con=conn, index=False) == 1 + res = sql.read_sql_table("foobar", conn) + tm.assert_equal(df, res) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_temporary_table(conn, request): + if conn == "sqlite_str": + pytest.skip("test does not work with str connection") + + conn = request.getfixturevalue(conn) + + from sqlalchemy import ( + Column, + Integer, + Unicode, + select, + ) + from sqlalchemy.orm import ( + Session, + declarative_base, + ) + + test_data = "Hello, World!" + expected = DataFrame({"spam": [test_data]}) + Base = declarative_base() + + class Temporary(Base): + __tablename__ = "temp_test" + __table_args__ = {"prefixes": ["TEMPORARY"]} + id = Column(Integer, primary_key=True) + spam = Column(Unicode(30), nullable=False) + + with Session(conn) as session: + with session.begin(): + conn = session.connection() + Temporary.__table__.create(conn) + session.add(Temporary(spam=test_data)) + session.flush() + df = sql.read_sql_query(sql=select(Temporary.spam), con=conn) + tm.assert_frame_equal(df, expected) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_invalid_engine(conn, request, test_frame1): + if conn == "sqlite_buildin" or "adbc" in conn: + request.applymarker( + pytest.mark.xfail( + reason="SQLiteDatabase/ADBCDatabase does not raise for bad engine" + ) + ) + + conn = request.getfixturevalue(conn) + msg = "engine must be one of 'auto', 'sqlalchemy'" + with pandasSQL_builder(conn) as pandasSQL: + with pytest.raises(ValueError, match=msg): + pandasSQL.to_sql(test_frame1, "test_frame1", engine="bad_engine") + + +@pytest.mark.parametrize("conn", all_connectable) +def test_to_sql_with_sql_engine(conn, request, test_frame1): + """`to_sql` with the `engine` param""" + # mostly copied from this class's `_to_sql()` method + conn = request.getfixturevalue(conn) + with pandasSQL_builder(conn) as pandasSQL: + with pandasSQL.run_transaction(): + assert pandasSQL.to_sql(test_frame1, "test_frame1", engine="auto") == 4 + assert pandasSQL.has_table("test_frame1") + + num_entries = len(test_frame1) + num_rows = count_rows(conn, "test_frame1") + assert num_rows == num_entries + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_options_sqlalchemy(conn, request, test_frame1): + # use the set option + conn = request.getfixturevalue(conn) + with pd.option_context("io.sql.engine", "sqlalchemy"): + with pandasSQL_builder(conn) as pandasSQL: + with pandasSQL.run_transaction(): + assert pandasSQL.to_sql(test_frame1, "test_frame1") == 4 + assert pandasSQL.has_table("test_frame1") + + num_entries = len(test_frame1) + num_rows = count_rows(conn, "test_frame1") + assert num_rows == num_entries + + +@pytest.mark.parametrize("conn", all_connectable) +def test_options_auto(conn, request, test_frame1): + # use the set option + conn = request.getfixturevalue(conn) + with pd.option_context("io.sql.engine", "auto"): + with pandasSQL_builder(conn) as pandasSQL: + with pandasSQL.run_transaction(): + assert pandasSQL.to_sql(test_frame1, "test_frame1") == 4 + assert pandasSQL.has_table("test_frame1") + + num_entries = len(test_frame1) + num_rows = count_rows(conn, "test_frame1") + assert num_rows == num_entries + + +def test_options_get_engine(): + pytest.importorskip("sqlalchemy") + assert isinstance(get_engine("sqlalchemy"), SQLAlchemyEngine) + + with pd.option_context("io.sql.engine", "sqlalchemy"): + assert isinstance(get_engine("auto"), SQLAlchemyEngine) + assert isinstance(get_engine("sqlalchemy"), SQLAlchemyEngine) + + with pd.option_context("io.sql.engine", "auto"): + assert isinstance(get_engine("auto"), SQLAlchemyEngine) + assert isinstance(get_engine("sqlalchemy"), SQLAlchemyEngine) + + +def test_get_engine_auto_error_message(): + # Expect different error messages from get_engine(engine="auto") + # if engines aren't installed vs. are installed but bad version + pass + # TODO(GH#36893) fill this in when we add more engines + + +@pytest.mark.parametrize("conn", all_connectable) +@pytest.mark.parametrize("func", ["read_sql", "read_sql_query"]) +def test_read_sql_dtype_backend( + conn, + request, + string_storage, + func, + dtype_backend, + dtype_backend_data, + dtype_backend_expected, +): + # GH#50048 + conn_name = conn + conn = request.getfixturevalue(conn) + table = "test" + df = dtype_backend_data + df.to_sql(name=table, con=conn, index=False, if_exists="replace") + + with pd.option_context("mode.string_storage", string_storage): + result = getattr(pd, func)( + f"Select * from {table}", conn, dtype_backend=dtype_backend + ) + expected = dtype_backend_expected(string_storage, dtype_backend, conn_name) + tm.assert_frame_equal(result, expected) + + if "adbc" in conn_name: + # adbc does not support chunksize argument + request.applymarker( + pytest.mark.xfail(reason="adbc does not support chunksize argument") + ) + + with pd.option_context("mode.string_storage", string_storage): + iterator = getattr(pd, func)( + f"Select * from {table}", + con=conn, + dtype_backend=dtype_backend, + chunksize=3, + ) + expected = dtype_backend_expected(string_storage, dtype_backend, conn_name) + for result in iterator: + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("conn", all_connectable) +@pytest.mark.parametrize("func", ["read_sql", "read_sql_table"]) +def test_read_sql_dtype_backend_table( + conn, + request, + string_storage, + func, + dtype_backend, + dtype_backend_data, + dtype_backend_expected, +): + if "sqlite" in conn and "adbc" not in conn: + request.applymarker( + pytest.mark.xfail( + reason=( + "SQLite actually returns proper boolean values via " + "read_sql_table, but before pytest refactor was skipped" + ) + ) + ) + # GH#50048 + conn_name = conn + conn = request.getfixturevalue(conn) + table = "test" + df = dtype_backend_data + df.to_sql(name=table, con=conn, index=False, if_exists="replace") + + with pd.option_context("mode.string_storage", string_storage): + result = getattr(pd, func)(table, conn, dtype_backend=dtype_backend) + expected = dtype_backend_expected(string_storage, dtype_backend, conn_name) + tm.assert_frame_equal(result, expected) + + if "adbc" in conn_name: + # adbc does not support chunksize argument + return + + with pd.option_context("mode.string_storage", string_storage): + iterator = getattr(pd, func)( + table, + conn, + dtype_backend=dtype_backend, + chunksize=3, + ) + expected = dtype_backend_expected(string_storage, dtype_backend, conn_name) + for result in iterator: + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("conn", all_connectable) +@pytest.mark.parametrize("func", ["read_sql", "read_sql_table", "read_sql_query"]) +def test_read_sql_invalid_dtype_backend_table(conn, request, func, dtype_backend_data): + conn = request.getfixturevalue(conn) + table = "test" + df = dtype_backend_data + df.to_sql(name=table, con=conn, index=False, if_exists="replace") + + msg = ( + "dtype_backend numpy is invalid, only 'numpy_nullable' and " + "'pyarrow' are allowed." + ) + with pytest.raises(ValueError, match=msg): + getattr(pd, func)(table, conn, dtype_backend="numpy") + + +@pytest.fixture +def dtype_backend_data() -> DataFrame: + return DataFrame( + { + "a": Series([1, np.nan, 3], dtype="Int64"), + "b": Series([1, 2, 3], dtype="Int64"), + "c": Series([1.5, np.nan, 2.5], dtype="Float64"), + "d": Series([1.5, 2.0, 2.5], dtype="Float64"), + "e": [True, False, None], + "f": [True, False, True], + "g": ["a", "b", "c"], + "h": ["a", "b", None], + } + ) + + +@pytest.fixture +def dtype_backend_expected(): + def func(storage, dtype_backend, conn_name) -> DataFrame: + string_array: StringArray | ArrowStringArray + string_array_na: StringArray | ArrowStringArray + if storage == "python": + string_array = StringArray(np.array(["a", "b", "c"], dtype=np.object_)) + string_array_na = StringArray(np.array(["a", "b", pd.NA], dtype=np.object_)) + + elif dtype_backend == "pyarrow": + pa = pytest.importorskip("pyarrow") + from pandas.arrays import ArrowExtensionArray + + string_array = ArrowExtensionArray(pa.array(["a", "b", "c"])) # type: ignore[assignment] + string_array_na = ArrowExtensionArray(pa.array(["a", "b", None])) # type: ignore[assignment] + + else: + pa = pytest.importorskip("pyarrow") + string_array = ArrowStringArray(pa.array(["a", "b", "c"])) + string_array_na = ArrowStringArray(pa.array(["a", "b", None])) + + df = DataFrame( + { + "a": Series([1, np.nan, 3], dtype="Int64"), + "b": Series([1, 2, 3], dtype="Int64"), + "c": Series([1.5, np.nan, 2.5], dtype="Float64"), + "d": Series([1.5, 2.0, 2.5], dtype="Float64"), + "e": Series([True, False, pd.NA], dtype="boolean"), + "f": Series([True, False, True], dtype="boolean"), + "g": string_array, + "h": string_array_na, + } + ) + if dtype_backend == "pyarrow": + pa = pytest.importorskip("pyarrow") + + from pandas.arrays import ArrowExtensionArray + + df = DataFrame( + { + col: ArrowExtensionArray(pa.array(df[col], from_pandas=True)) + for col in df.columns + } + ) + + if "mysql" in conn_name or "sqlite" in conn_name: + if dtype_backend == "numpy_nullable": + df = df.astype({"e": "Int64", "f": "Int64"}) + else: + df = df.astype({"e": "int64[pyarrow]", "f": "int64[pyarrow]"}) + + return df + + return func + + +@pytest.mark.parametrize("conn", all_connectable) +def test_chunksize_empty_dtypes(conn, request): + # GH#50245 + if "adbc" in conn: + request.node.add_marker( + pytest.mark.xfail(reason="chunksize argument NotImplemented with ADBC") + ) + conn = request.getfixturevalue(conn) + dtypes = {"a": "int64", "b": "object"} + df = DataFrame(columns=["a", "b"]).astype(dtypes) + expected = df.copy() + df.to_sql(name="test", con=conn, index=False, if_exists="replace") + + for result in read_sql_query( + "SELECT * FROM test", + conn, + dtype=dtypes, + chunksize=1, + ): + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("conn", all_connectable) +@pytest.mark.parametrize("dtype_backend", [lib.no_default, "numpy_nullable"]) +@pytest.mark.parametrize("func", ["read_sql", "read_sql_query"]) +def test_read_sql_dtype(conn, request, func, dtype_backend): + # GH#50797 + conn = request.getfixturevalue(conn) + table = "test" + df = DataFrame({"a": [1, 2, 3], "b": 5}) + df.to_sql(name=table, con=conn, index=False, if_exists="replace") + + result = getattr(pd, func)( + f"Select * from {table}", + conn, + dtype={"a": np.float64}, + dtype_backend=dtype_backend, + ) + expected = DataFrame( + { + "a": Series([1, 2, 3], dtype=np.float64), + "b": Series( + [5, 5, 5], + dtype="int64" if not dtype_backend == "numpy_nullable" else "Int64", + ), + } + ) + tm.assert_frame_equal(result, expected) + + +def test_keyword_deprecation(sqlite_engine): + conn = sqlite_engine + # GH 54397 + msg = ( + "Starting with pandas version 3.0 all arguments of to_sql except for the " + "arguments 'name' and 'con' will be keyword-only." + ) + df = DataFrame([{"A": 1, "B": 2, "C": 3}, {"A": 1, "B": 2, "C": 3}]) + df.to_sql("example", conn) + + with tm.assert_produces_warning(FutureWarning, match=msg): + df.to_sql("example", conn, None, if_exists="replace") + + +def test_bigint_warning(sqlite_engine): + conn = sqlite_engine + # test no warning for BIGINT (to support int64) is raised (GH7433) + df = DataFrame({"a": [1, 2]}, dtype="int64") + assert df.to_sql(name="test_bigintwarning", con=conn, index=False) == 2 + + with tm.assert_produces_warning(None): + sql.read_sql_table("test_bigintwarning", conn) + + +def test_valueerror_exception(sqlite_engine): + conn = sqlite_engine + df = DataFrame({"col1": [1, 2], "col2": [3, 4]}) + with pytest.raises(ValueError, match="Empty table name specified"): + df.to_sql(name="", con=conn, if_exists="replace", index=False) + + +def test_row_object_is_named_tuple(sqlite_engine): + conn = sqlite_engine + # GH 40682 + # Test for the is_named_tuple() function + # Placed here due to its usage of sqlalchemy + + from sqlalchemy import ( + Column, + Integer, + String, + ) + from sqlalchemy.orm import ( + declarative_base, + sessionmaker, + ) + + BaseModel = declarative_base() + + class Test(BaseModel): + __tablename__ = "test_frame" + id = Column(Integer, primary_key=True) + string_column = Column(String(50)) + + with conn.begin(): + BaseModel.metadata.create_all(conn) + Session = sessionmaker(bind=conn) + with Session() as session: + df = DataFrame({"id": [0, 1], "string_column": ["hello", "world"]}) + assert ( + df.to_sql(name="test_frame", con=conn, index=False, if_exists="replace") + == 2 + ) + session.commit() + test_query = session.query(Test.id, Test.string_column) + df = DataFrame(test_query) + + assert list(df.columns) == ["id", "string_column"] + + +def test_read_sql_string_inference(sqlite_engine): + conn = sqlite_engine + # GH#54430 + pytest.importorskip("pyarrow") + table = "test" + df = DataFrame({"a": ["x", "y"]}) + df.to_sql(table, con=conn, index=False, if_exists="replace") + + with pd.option_context("future.infer_string", True): + result = read_sql_table(table, conn) + + dtype = "string[pyarrow_numpy]" + expected = DataFrame( + {"a": ["x", "y"]}, dtype=dtype, columns=Index(["a"], dtype=dtype) + ) + + tm.assert_frame_equal(result, expected) + + +def test_roundtripping_datetimes(sqlite_engine): + conn = sqlite_engine + # GH#54877 + df = DataFrame({"t": [datetime(2020, 12, 31, 12)]}, dtype="datetime64[ns]") + df.to_sql("test", conn, if_exists="replace", index=False) + result = pd.read_sql("select * from test", conn).iloc[0, 0] + assert result == "2020-12-31 12:00:00.000000" + + +@pytest.fixture +def sqlite_builtin_detect_types(): + with contextlib.closing( + sqlite3.connect(":memory:", detect_types=sqlite3.PARSE_DECLTYPES) + ) as closing_conn: + with closing_conn as conn: + yield conn + + +def test_roundtripping_datetimes_detect_types(sqlite_builtin_detect_types): + # https://github.com/pandas-dev/pandas/issues/55554 + conn = sqlite_builtin_detect_types + df = DataFrame({"t": [datetime(2020, 12, 31, 12)]}, dtype="datetime64[ns]") + df.to_sql("test", conn, if_exists="replace", index=False) + result = pd.read_sql("select * from test", conn).iloc[0, 0] + assert result == Timestamp("2020-12-31 12:00:00.000000") + + +@pytest.mark.db +def test_psycopg2_schema_support(postgresql_psycopg2_engine): + conn = postgresql_psycopg2_engine + + # only test this for postgresql (schema's not supported in + # mysql/sqlite) + df = DataFrame({"col1": [1, 2], "col2": [0.1, 0.2], "col3": ["a", "n"]}) + + # create a schema + with conn.connect() as con: + with con.begin(): + con.exec_driver_sql("DROP SCHEMA IF EXISTS other CASCADE;") + con.exec_driver_sql("CREATE SCHEMA other;") + + # write dataframe to different schema's + assert df.to_sql(name="test_schema_public", con=conn, index=False) == 2 + assert ( + df.to_sql( + name="test_schema_public_explicit", + con=conn, + index=False, + schema="public", + ) + == 2 + ) + assert ( + df.to_sql(name="test_schema_other", con=conn, index=False, schema="other") == 2 + ) + + # read dataframes back in + res1 = sql.read_sql_table("test_schema_public", conn) + tm.assert_frame_equal(df, res1) + res2 = sql.read_sql_table("test_schema_public_explicit", conn) + tm.assert_frame_equal(df, res2) + res3 = sql.read_sql_table("test_schema_public_explicit", conn, schema="public") + tm.assert_frame_equal(df, res3) + res4 = sql.read_sql_table("test_schema_other", conn, schema="other") + tm.assert_frame_equal(df, res4) + msg = "Table test_schema_other not found" + with pytest.raises(ValueError, match=msg): + sql.read_sql_table("test_schema_other", conn, schema="public") + + # different if_exists options + + # create a schema + with conn.connect() as con: + with con.begin(): + con.exec_driver_sql("DROP SCHEMA IF EXISTS other CASCADE;") + con.exec_driver_sql("CREATE SCHEMA other;") + + # write dataframe with different if_exists options + assert ( + df.to_sql(name="test_schema_other", con=conn, schema="other", index=False) == 2 + ) + df.to_sql( + name="test_schema_other", + con=conn, + schema="other", + index=False, + if_exists="replace", + ) + assert ( + df.to_sql( + name="test_schema_other", + con=conn, + schema="other", + index=False, + if_exists="append", + ) + == 2 + ) + res = sql.read_sql_table("test_schema_other", conn, schema="other") + tm.assert_frame_equal(concat([df, df], ignore_index=True), res) + + +@pytest.mark.db +def test_self_join_date_columns(postgresql_psycopg2_engine): + # GH 44421 + conn = postgresql_psycopg2_engine + from sqlalchemy.sql import text + + create_table = text( + """ + CREATE TABLE person + ( + id serial constraint person_pkey primary key, + created_dt timestamp with time zone + ); + + INSERT INTO person + VALUES (1, '2021-01-01T00:00:00Z'); + """ + ) + with conn.connect() as con: + with con.begin(): + con.execute(create_table) + + sql_query = ( + 'SELECT * FROM "person" AS p1 INNER JOIN "person" AS p2 ON p1.id = p2.id;' + ) + result = pd.read_sql(sql_query, conn) + expected = DataFrame( + [[1, Timestamp("2021", tz="UTC")] * 2], columns=["id", "created_dt"] * 2 + ) + tm.assert_frame_equal(result, expected) + + # Cleanup + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("person") + + +def test_create_and_drop_table(sqlite_engine): + conn = sqlite_engine + temp_frame = DataFrame({"one": [1.0, 2.0, 3.0, 4.0], "two": [4.0, 3.0, 2.0, 1.0]}) + with sql.SQLDatabase(conn) as pandasSQL: + with pandasSQL.run_transaction(): + assert pandasSQL.to_sql(temp_frame, "drop_test_frame") == 4 + + assert pandasSQL.has_table("drop_test_frame") + + with pandasSQL.run_transaction(): + pandasSQL.drop_table("drop_test_frame") + + assert not pandasSQL.has_table("drop_test_frame") + + +def test_sqlite_datetime_date(sqlite_buildin): + conn = sqlite_buildin + df = DataFrame([date(2014, 1, 1), date(2014, 1, 2)], columns=["a"]) + assert df.to_sql(name="test_date", con=conn, index=False) == 2 + res = read_sql_query("SELECT * FROM test_date", conn) + # comes back as strings + tm.assert_frame_equal(res, df.astype(str)) + + +@pytest.mark.parametrize("tz_aware", [False, True]) +def test_sqlite_datetime_time(tz_aware, sqlite_buildin): + conn = sqlite_buildin + # test support for datetime.time, GH #8341 + if not tz_aware: + tz_times = [time(9, 0, 0), time(9, 1, 30)] + else: + tz_dt = date_range("2013-01-01 09:00:00", periods=2, tz="US/Pacific") + tz_times = Series(tz_dt.to_pydatetime()).map(lambda dt: dt.timetz()) + + df = DataFrame(tz_times, columns=["a"]) + + assert df.to_sql(name="test_time", con=conn, index=False) == 2 + res = read_sql_query("SELECT * FROM test_time", conn) + # comes back as strings + expected = df.map(lambda _: _.strftime("%H:%M:%S.%f")) + tm.assert_frame_equal(res, expected) + + +def get_sqlite_column_type(conn, table, column): + recs = conn.execute(f"PRAGMA table_info({table})") + for cid, name, ctype, not_null, default, pk in recs: + if name == column: + return ctype + raise ValueError(f"Table {table}, column {column} not found") + + +def test_sqlite_test_dtype(sqlite_buildin): + conn = sqlite_buildin + cols = ["A", "B"] + data = [(0.8, True), (0.9, None)] + df = DataFrame(data, columns=cols) + assert df.to_sql(name="dtype_test", con=conn) == 2 + assert df.to_sql(name="dtype_test2", con=conn, dtype={"B": "STRING"}) == 2 + + # sqlite stores Boolean values as INTEGER + assert get_sqlite_column_type(conn, "dtype_test", "B") == "INTEGER" + + assert get_sqlite_column_type(conn, "dtype_test2", "B") == "STRING" + msg = r"B \(\) not a string" + with pytest.raises(ValueError, match=msg): + df.to_sql(name="error", con=conn, dtype={"B": bool}) + + # single dtype + assert df.to_sql(name="single_dtype_test", con=conn, dtype="STRING") == 2 + assert get_sqlite_column_type(conn, "single_dtype_test", "A") == "STRING" + assert get_sqlite_column_type(conn, "single_dtype_test", "B") == "STRING" + + +def test_sqlite_notna_dtype(sqlite_buildin): + conn = sqlite_buildin + cols = { + "Bool": Series([True, None]), + "Date": Series([datetime(2012, 5, 1), None]), + "Int": Series([1, None], dtype="object"), + "Float": Series([1.1, None]), + } + df = DataFrame(cols) + + tbl = "notna_dtype_test" + assert df.to_sql(name=tbl, con=conn) == 2 + + assert get_sqlite_column_type(conn, tbl, "Bool") == "INTEGER" + assert get_sqlite_column_type(conn, tbl, "Date") == "TIMESTAMP" + assert get_sqlite_column_type(conn, tbl, "Int") == "INTEGER" + assert get_sqlite_column_type(conn, tbl, "Float") == "REAL" + + +def test_sqlite_illegal_names(sqlite_buildin): + # For sqlite, these should work fine + conn = sqlite_buildin + df = DataFrame([[1, 2], [3, 4]], columns=["a", "b"]) + + msg = "Empty table or column name specified" + with pytest.raises(ValueError, match=msg): + df.to_sql(name="", con=conn) + + for ndx, weird_name in enumerate( + [ + "test_weird_name]", + "test_weird_name[", + "test_weird_name`", + 'test_weird_name"', + "test_weird_name'", + "_b.test_weird_name_01-30", + '"_b.test_weird_name_01-30"', + "99beginswithnumber", + "12345", + "\xe9", + ] + ): + assert df.to_sql(name=weird_name, con=conn) == 2 + sql.table_exists(weird_name, conn) + + df2 = DataFrame([[1, 2], [3, 4]], columns=["a", weird_name]) + c_tbl = f"test_weird_col_name{ndx:d}" + assert df2.to_sql(name=c_tbl, con=conn) == 2 + sql.table_exists(c_tbl, conn) + + +def format_query(sql, *args): + _formatters = { + datetime: "'{}'".format, + str: "'{}'".format, + np.str_: "'{}'".format, + bytes: "'{}'".format, + float: "{:.8f}".format, + int: "{:d}".format, + type(None): lambda x: "NULL", + np.float64: "{:.10f}".format, + bool: "'{!s}'".format, + } + processed_args = [] + for arg in args: + if isinstance(arg, float) and isna(arg): + arg = None + + formatter = _formatters[type(arg)] + processed_args.append(formatter(arg)) + + return sql % tuple(processed_args) + + +def tquery(query, con=None): + """Replace removed sql.tquery function""" + with sql.pandasSQL_builder(con) as pandas_sql: + res = pandas_sql.execute(query).fetchall() + return None if res is None else list(res) + + +def test_xsqlite_basic(sqlite_buildin): + frame = DataFrame( + np.random.default_rng(2).standard_normal((10, 4)), + columns=Index(list("ABCD"), dtype=object), + index=date_range("2000-01-01", periods=10, freq="B"), + ) + assert sql.to_sql(frame, name="test_table", con=sqlite_buildin, index=False) == 10 + result = sql.read_sql("select * from test_table", sqlite_buildin) + + # HACK! Change this once indexes are handled properly. + result.index = frame.index + + expected = frame + tm.assert_frame_equal(result, frame) + + frame["txt"] = ["a"] * len(frame) + frame2 = frame.copy() + new_idx = Index(np.arange(len(frame2)), dtype=np.int64) + 10 + frame2["Idx"] = new_idx.copy() + assert sql.to_sql(frame2, name="test_table2", con=sqlite_buildin, index=False) == 10 + result = sql.read_sql("select * from test_table2", sqlite_buildin, index_col="Idx") + expected = frame.copy() + expected.index = new_idx + expected.index.name = "Idx" + tm.assert_frame_equal(expected, result) + + +def test_xsqlite_write_row_by_row(sqlite_buildin): + frame = DataFrame( + np.random.default_rng(2).standard_normal((10, 4)), + columns=Index(list("ABCD"), dtype=object), + index=date_range("2000-01-01", periods=10, freq="B"), + ) + frame.iloc[0, 0] = np.nan + create_sql = sql.get_schema(frame, "test") + cur = sqlite_buildin.cursor() + cur.execute(create_sql) + + ins = "INSERT INTO test VALUES (%s, %s, %s, %s)" + for _, row in frame.iterrows(): + fmt_sql = format_query(ins, *row) + tquery(fmt_sql, con=sqlite_buildin) + + sqlite_buildin.commit() + + result = sql.read_sql("select * from test", con=sqlite_buildin) + result.index = frame.index + tm.assert_frame_equal(result, frame, rtol=1e-3) + + +def test_xsqlite_execute(sqlite_buildin): + frame = DataFrame( + np.random.default_rng(2).standard_normal((10, 4)), + columns=Index(list("ABCD"), dtype=object), + index=date_range("2000-01-01", periods=10, freq="B"), + ) + create_sql = sql.get_schema(frame, "test") + cur = sqlite_buildin.cursor() + cur.execute(create_sql) + ins = "INSERT INTO test VALUES (?, ?, ?, ?)" + + row = frame.iloc[0] + with sql.pandasSQL_builder(sqlite_buildin) as pandas_sql: + pandas_sql.execute(ins, tuple(row)) + sqlite_buildin.commit() + + result = sql.read_sql("select * from test", sqlite_buildin) + result.index = frame.index[:1] + tm.assert_frame_equal(result, frame[:1]) + + +def test_xsqlite_schema(sqlite_buildin): + frame = DataFrame( + np.random.default_rng(2).standard_normal((10, 4)), + columns=Index(list("ABCD"), dtype=object), + index=date_range("2000-01-01", periods=10, freq="B"), + ) + create_sql = sql.get_schema(frame, "test") + lines = create_sql.splitlines() + for line in lines: + tokens = line.split(" ") + if len(tokens) == 2 and tokens[0] == "A": + assert tokens[1] == "DATETIME" + + create_sql = sql.get_schema(frame, "test", keys=["A", "B"]) + lines = create_sql.splitlines() + assert 'PRIMARY KEY ("A", "B")' in create_sql + cur = sqlite_buildin.cursor() + cur.execute(create_sql) + + +def test_xsqlite_execute_fail(sqlite_buildin): + create_sql = """ + CREATE TABLE test + ( + a TEXT, + b TEXT, + c REAL, + PRIMARY KEY (a, b) + ); + """ + cur = sqlite_buildin.cursor() + cur.execute(create_sql) + + with sql.pandasSQL_builder(sqlite_buildin) as pandas_sql: + pandas_sql.execute('INSERT INTO test VALUES("foo", "bar", 1.234)') + pandas_sql.execute('INSERT INTO test VALUES("foo", "baz", 2.567)') + + with pytest.raises(sql.DatabaseError, match="Execution failed on sql"): + pandas_sql.execute('INSERT INTO test VALUES("foo", "bar", 7)') + + +def test_xsqlite_execute_closed_connection(): + create_sql = """ + CREATE TABLE test + ( + a TEXT, + b TEXT, + c REAL, + PRIMARY KEY (a, b) + ); + """ + with contextlib.closing(sqlite3.connect(":memory:")) as conn: + cur = conn.cursor() + cur.execute(create_sql) + + with sql.pandasSQL_builder(conn) as pandas_sql: + pandas_sql.execute('INSERT INTO test VALUES("foo", "bar", 1.234)') + + msg = "Cannot operate on a closed database." + with pytest.raises(sqlite3.ProgrammingError, match=msg): + tquery("select * from test", con=conn) + + +def test_xsqlite_keyword_as_column_names(sqlite_buildin): + df = DataFrame({"From": np.ones(5)}) + assert sql.to_sql(df, con=sqlite_buildin, name="testkeywords", index=False) == 5 + + +def test_xsqlite_onecolumn_of_integer(sqlite_buildin): + # GH 3628 + # a column_of_integers dataframe should transfer well to sql + + mono_df = DataFrame([1, 2], columns=["c0"]) + assert sql.to_sql(mono_df, con=sqlite_buildin, name="mono_df", index=False) == 2 + # computing the sum via sql + con_x = sqlite_buildin + the_sum = sum(my_c0[0] for my_c0 in con_x.execute("select * from mono_df")) + # it should not fail, and gives 3 ( Issue #3628 ) + assert the_sum == 3 + + result = sql.read_sql("select * from mono_df", con_x) + tm.assert_frame_equal(result, mono_df) + + +def test_xsqlite_if_exists(sqlite_buildin): + df_if_exists_1 = DataFrame({"col1": [1, 2], "col2": ["A", "B"]}) + df_if_exists_2 = DataFrame({"col1": [3, 4, 5], "col2": ["C", "D", "E"]}) + table_name = "table_if_exists" + sql_select = f"SELECT * FROM {table_name}" + + msg = "'notvalidvalue' is not valid for if_exists" + with pytest.raises(ValueError, match=msg): + sql.to_sql( + frame=df_if_exists_1, + con=sqlite_buildin, + name=table_name, + if_exists="notvalidvalue", + ) + drop_table(table_name, sqlite_buildin) + + # test if_exists='fail' + sql.to_sql( + frame=df_if_exists_1, con=sqlite_buildin, name=table_name, if_exists="fail" + ) + msg = "Table 'table_if_exists' already exists" + with pytest.raises(ValueError, match=msg): + sql.to_sql( + frame=df_if_exists_1, + con=sqlite_buildin, + name=table_name, + if_exists="fail", + ) + # test if_exists='replace' + sql.to_sql( + frame=df_if_exists_1, + con=sqlite_buildin, + name=table_name, + if_exists="replace", + index=False, + ) + assert tquery(sql_select, con=sqlite_buildin) == [(1, "A"), (2, "B")] + assert ( + sql.to_sql( + frame=df_if_exists_2, + con=sqlite_buildin, + name=table_name, + if_exists="replace", + index=False, + ) + == 3 + ) + assert tquery(sql_select, con=sqlite_buildin) == [(3, "C"), (4, "D"), (5, "E")] + drop_table(table_name, sqlite_buildin) + + # test if_exists='append' + assert ( + sql.to_sql( + frame=df_if_exists_1, + con=sqlite_buildin, + name=table_name, + if_exists="fail", + index=False, + ) + == 2 + ) + assert tquery(sql_select, con=sqlite_buildin) == [(1, "A"), (2, "B")] + assert ( + sql.to_sql( + frame=df_if_exists_2, + con=sqlite_buildin, + name=table_name, + if_exists="append", + index=False, + ) + == 3 + ) + assert tquery(sql_select, con=sqlite_buildin) == [ + (1, "A"), + (2, "B"), + (3, "C"), + (4, "D"), + (5, "E"), + ] + drop_table(table_name, sqlite_buildin) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_stata.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_stata.py new file mode 100644 index 0000000000000000000000000000000000000000..6bd74faa8a3dbbae94f2a8fd79aa23bf677e6220 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/io/test_stata.py @@ -0,0 +1,2381 @@ +import bz2 +import datetime as dt +from datetime import datetime +import gzip +import io +import os +import struct +import tarfile +import zipfile + +import numpy as np +import pytest + +import pandas.util._test_decorators as td + +import pandas as pd +from pandas import CategoricalDtype +import pandas._testing as tm +from pandas.core.frame import ( + DataFrame, + Series, +) + +from pandas.io.parsers import read_csv +from pandas.io.stata import ( + CategoricalConversionWarning, + InvalidColumnName, + PossiblePrecisionLoss, + StataMissingValue, + StataReader, + StataWriter, + StataWriterUTF8, + ValueLabelTypeMismatch, + read_stata, +) + + +@pytest.fixture +def mixed_frame(): + return DataFrame( + { + "a": [1, 2, 3, 4], + "b": [1.0, 3.0, 27.0, 81.0], + "c": ["Atlanta", "Birmingham", "Cincinnati", "Detroit"], + } + ) + + +@pytest.fixture +def parsed_114(datapath): + dta14_114 = datapath("io", "data", "stata", "stata5_114.dta") + parsed_114 = read_stata(dta14_114, convert_dates=True) + parsed_114.index.name = "index" + return parsed_114 + + +class TestStata: + def read_dta(self, file): + # Legacy default reader configuration + return read_stata(file, convert_dates=True) + + def read_csv(self, file): + return read_csv(file, parse_dates=True) + + @pytest.mark.parametrize("version", [114, 117, 118, 119, None]) + def test_read_empty_dta(self, version): + empty_ds = DataFrame(columns=["unit"]) + # GH 7369, make sure can read a 0-obs dta file + with tm.ensure_clean() as path: + empty_ds.to_stata(path, write_index=False, version=version) + empty_ds2 = read_stata(path) + tm.assert_frame_equal(empty_ds, empty_ds2) + + @pytest.mark.parametrize("version", [114, 117, 118, 119, None]) + def test_read_empty_dta_with_dtypes(self, version): + # GH 46240 + # Fixing above bug revealed that types are not correctly preserved when + # writing empty DataFrames + empty_df_typed = DataFrame( + { + "i8": np.array([0], dtype=np.int8), + "i16": np.array([0], dtype=np.int16), + "i32": np.array([0], dtype=np.int32), + "i64": np.array([0], dtype=np.int64), + "u8": np.array([0], dtype=np.uint8), + "u16": np.array([0], dtype=np.uint16), + "u32": np.array([0], dtype=np.uint32), + "u64": np.array([0], dtype=np.uint64), + "f32": np.array([0], dtype=np.float32), + "f64": np.array([0], dtype=np.float64), + } + ) + expected = empty_df_typed.copy() + # No uint# support. Downcast since values in range for int# + expected["u8"] = expected["u8"].astype(np.int8) + expected["u16"] = expected["u16"].astype(np.int16) + expected["u32"] = expected["u32"].astype(np.int32) + # No int64 supported at all. Downcast since values in range for int32 + expected["u64"] = expected["u64"].astype(np.int32) + expected["i64"] = expected["i64"].astype(np.int32) + + # GH 7369, make sure can read a 0-obs dta file + with tm.ensure_clean() as path: + empty_df_typed.to_stata(path, write_index=False, version=version) + empty_reread = read_stata(path) + tm.assert_frame_equal(expected, empty_reread) + tm.assert_series_equal(expected.dtypes, empty_reread.dtypes) + + @pytest.mark.parametrize("version", [114, 117, 118, 119, None]) + def test_read_index_col_none(self, version): + df = DataFrame({"a": range(5), "b": ["b1", "b2", "b3", "b4", "b5"]}) + # GH 7369, make sure can read a 0-obs dta file + with tm.ensure_clean() as path: + df.to_stata(path, write_index=False, version=version) + read_df = read_stata(path) + + assert isinstance(read_df.index, pd.RangeIndex) + expected = df.copy() + expected["a"] = expected["a"].astype(np.int32) + tm.assert_frame_equal(read_df, expected, check_index_type=True) + + @pytest.mark.parametrize("file", ["stata1_114", "stata1_117"]) + def test_read_dta1(self, file, datapath): + file = datapath("io", "data", "stata", f"{file}.dta") + parsed = self.read_dta(file) + + # Pandas uses np.nan as missing value. + # Thus, all columns will be of type float, regardless of their name. + expected = DataFrame( + [(np.nan, np.nan, np.nan, np.nan, np.nan)], + columns=["float_miss", "double_miss", "byte_miss", "int_miss", "long_miss"], + ) + + # this is an oddity as really the nan should be float64, but + # the casting doesn't fail so need to match stata here + expected["float_miss"] = expected["float_miss"].astype(np.float32) + + tm.assert_frame_equal(parsed, expected) + + def test_read_dta2(self, datapath): + expected = DataFrame.from_records( + [ + ( + datetime(2006, 11, 19, 23, 13, 20), + 1479596223000, + datetime(2010, 1, 20), + datetime(2010, 1, 8), + datetime(2010, 1, 1), + datetime(1974, 7, 1), + datetime(2010, 1, 1), + datetime(2010, 1, 1), + ), + ( + datetime(1959, 12, 31, 20, 3, 20), + -1479590, + datetime(1953, 10, 2), + datetime(1948, 6, 10), + datetime(1955, 1, 1), + datetime(1955, 7, 1), + datetime(1955, 1, 1), + datetime(2, 1, 1), + ), + (pd.NaT, pd.NaT, pd.NaT, pd.NaT, pd.NaT, pd.NaT, pd.NaT, pd.NaT), + ], + columns=[ + "datetime_c", + "datetime_big_c", + "date", + "weekly_date", + "monthly_date", + "quarterly_date", + "half_yearly_date", + "yearly_date", + ], + ) + expected["yearly_date"] = expected["yearly_date"].astype("O") + + path1 = datapath("io", "data", "stata", "stata2_114.dta") + path2 = datapath("io", "data", "stata", "stata2_115.dta") + path3 = datapath("io", "data", "stata", "stata2_117.dta") + + with tm.assert_produces_warning(UserWarning): + parsed_114 = self.read_dta(path1) + with tm.assert_produces_warning(UserWarning): + parsed_115 = self.read_dta(path2) + with tm.assert_produces_warning(UserWarning): + parsed_117 = self.read_dta(path3) + # FIXME: don't leave commented-out + # 113 is buggy due to limits of date format support in Stata + # parsed_113 = self.read_dta( + # datapath("io", "data", "stata", "stata2_113.dta") + # ) + + # FIXME: don't leave commented-out + # buggy test because of the NaT comparison on certain platforms + # Format 113 test fails since it does not support tc and tC formats + # tm.assert_frame_equal(parsed_113, expected) + tm.assert_frame_equal(parsed_114, expected, check_datetimelike_compat=True) + tm.assert_frame_equal(parsed_115, expected, check_datetimelike_compat=True) + tm.assert_frame_equal(parsed_117, expected, check_datetimelike_compat=True) + + @pytest.mark.parametrize( + "file", ["stata3_113", "stata3_114", "stata3_115", "stata3_117"] + ) + def test_read_dta3(self, file, datapath): + file = datapath("io", "data", "stata", f"{file}.dta") + parsed = self.read_dta(file) + + # match stata here + expected = self.read_csv(datapath("io", "data", "stata", "stata3.csv")) + expected = expected.astype(np.float32) + expected["year"] = expected["year"].astype(np.int16) + expected["quarter"] = expected["quarter"].astype(np.int8) + + tm.assert_frame_equal(parsed, expected) + + @pytest.mark.parametrize( + "file", ["stata4_113", "stata4_114", "stata4_115", "stata4_117"] + ) + def test_read_dta4(self, file, datapath): + file = datapath("io", "data", "stata", f"{file}.dta") + parsed = self.read_dta(file) + + expected = DataFrame.from_records( + [ + ["one", "ten", "one", "one", "one"], + ["two", "nine", "two", "two", "two"], + ["three", "eight", "three", "three", "three"], + ["four", "seven", 4, "four", "four"], + ["five", "six", 5, np.nan, "five"], + ["six", "five", 6, np.nan, "six"], + ["seven", "four", 7, np.nan, "seven"], + ["eight", "three", 8, np.nan, "eight"], + ["nine", "two", 9, np.nan, "nine"], + ["ten", "one", "ten", np.nan, "ten"], + ], + columns=[ + "fully_labeled", + "fully_labeled2", + "incompletely_labeled", + "labeled_with_missings", + "float_labelled", + ], + ) + + # these are all categoricals + for col in expected: + orig = expected[col].copy() + + categories = np.asarray(expected["fully_labeled"][orig.notna()]) + if col == "incompletely_labeled": + categories = orig + + cat = orig.astype("category")._values + cat = cat.set_categories(categories, ordered=True) + cat.categories.rename(None, inplace=True) + + expected[col] = cat + + # stata doesn't save .category metadata + tm.assert_frame_equal(parsed, expected) + + # File containing strls + def test_read_dta12(self, datapath): + parsed_117 = self.read_dta(datapath("io", "data", "stata", "stata12_117.dta")) + expected = DataFrame.from_records( + [ + [1, "abc", "abcdefghi"], + [3, "cba", "qwertywertyqwerty"], + [93, "", "strl"], + ], + columns=["x", "y", "z"], + ) + + tm.assert_frame_equal(parsed_117, expected, check_dtype=False) + + def test_read_dta18(self, datapath): + parsed_118 = self.read_dta(datapath("io", "data", "stata", "stata14_118.dta")) + parsed_118["Bytes"] = parsed_118["Bytes"].astype("O") + expected = DataFrame.from_records( + [ + ["Cat", "Bogota", "Bogotá", 1, 1.0, "option b Ünicode", 1.0], + ["Dog", "Boston", "Uzunköprü", np.nan, np.nan, np.nan, np.nan], + ["Plane", "Rome", "Tromsø", 0, 0.0, "option a", 0.0], + ["Potato", "Tokyo", "Elâzığ", -4, 4.0, 4, 4], # noqa: RUF001 + ["", "", "", 0, 0.3332999, "option a", 1 / 3.0], + ], + columns=[ + "Things", + "Cities", + "Unicode_Cities_Strl", + "Ints", + "Floats", + "Bytes", + "Longs", + ], + ) + expected["Floats"] = expected["Floats"].astype(np.float32) + for col in parsed_118.columns: + tm.assert_almost_equal(parsed_118[col], expected[col]) + + with StataReader(datapath("io", "data", "stata", "stata14_118.dta")) as rdr: + vl = rdr.variable_labels() + vl_expected = { + "Unicode_Cities_Strl": "Here are some strls with Ünicode chars", + "Longs": "long data", + "Things": "Here are some things", + "Bytes": "byte data", + "Ints": "int data", + "Cities": "Here are some cities", + "Floats": "float data", + } + tm.assert_dict_equal(vl, vl_expected) + + assert rdr.data_label == "This is a Ünicode data label" + + def test_read_write_dta5(self): + original = DataFrame( + [(np.nan, np.nan, np.nan, np.nan, np.nan)], + columns=["float_miss", "double_miss", "byte_miss", "int_miss", "long_miss"], + ) + original.index.name = "index" + + with tm.ensure_clean() as path: + original.to_stata(path, convert_dates=None) + written_and_read_again = self.read_dta(path) + + expected = original.copy() + expected.index = expected.index.astype(np.int32) + tm.assert_frame_equal(written_and_read_again.set_index("index"), expected) + + def test_write_dta6(self, datapath): + original = self.read_csv(datapath("io", "data", "stata", "stata3.csv")) + original.index.name = "index" + original.index = original.index.astype(np.int32) + original["year"] = original["year"].astype(np.int32) + original["quarter"] = original["quarter"].astype(np.int32) + + with tm.ensure_clean() as path: + original.to_stata(path, convert_dates=None) + written_and_read_again = self.read_dta(path) + tm.assert_frame_equal( + written_and_read_again.set_index("index"), + original, + check_index_type=False, + ) + + @pytest.mark.parametrize("version", [114, 117, 118, 119, None]) + def test_read_write_dta10(self, version): + original = DataFrame( + data=[["string", "object", 1, 1.1, np.datetime64("2003-12-25")]], + columns=["string", "object", "integer", "floating", "datetime"], + ) + original["object"] = Series(original["object"], dtype=object) + original.index.name = "index" + original.index = original.index.astype(np.int32) + original["integer"] = original["integer"].astype(np.int32) + + with tm.ensure_clean() as path: + original.to_stata(path, convert_dates={"datetime": "tc"}, version=version) + written_and_read_again = self.read_dta(path) + # original.index is np.int32, read index is np.int64 + tm.assert_frame_equal( + written_and_read_again.set_index("index"), + original, + check_index_type=False, + ) + + def test_stata_doc_examples(self): + with tm.ensure_clean() as path: + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 2)), columns=list("AB") + ) + df.to_stata(path) + + def test_write_preserves_original(self): + # 9795 + + df = DataFrame( + np.random.default_rng(2).standard_normal((5, 4)), columns=list("abcd") + ) + df.loc[2, "a":"c"] = np.nan + df_copy = df.copy() + with tm.ensure_clean() as path: + df.to_stata(path, write_index=False) + tm.assert_frame_equal(df, df_copy) + + @pytest.mark.parametrize("version", [114, 117, 118, 119, None]) + def test_encoding(self, version, datapath): + # GH 4626, proper encoding handling + raw = read_stata(datapath("io", "data", "stata", "stata1_encoding.dta")) + encoded = read_stata(datapath("io", "data", "stata", "stata1_encoding.dta")) + result = encoded.kreis1849[0] + + expected = raw.kreis1849[0] + assert result == expected + assert isinstance(result, str) + + with tm.ensure_clean() as path: + encoded.to_stata(path, write_index=False, version=version) + reread_encoded = read_stata(path) + tm.assert_frame_equal(encoded, reread_encoded) + + def test_read_write_dta11(self): + original = DataFrame( + [(1, 2, 3, 4)], + columns=[ + "good", + "b\u00E4d", + "8number", + "astringwithmorethan32characters______", + ], + ) + formatted = DataFrame( + [(1, 2, 3, 4)], + columns=["good", "b_d", "_8number", "astringwithmorethan32characters_"], + ) + formatted.index.name = "index" + formatted = formatted.astype(np.int32) + + with tm.ensure_clean() as path: + with tm.assert_produces_warning(InvalidColumnName): + original.to_stata(path, convert_dates=None) + + written_and_read_again = self.read_dta(path) + + expected = formatted.copy() + expected.index = expected.index.astype(np.int32) + tm.assert_frame_equal(written_and_read_again.set_index("index"), expected) + + @pytest.mark.parametrize("version", [114, 117, 118, 119, None]) + def test_read_write_dta12(self, version): + original = DataFrame( + [(1, 2, 3, 4, 5, 6)], + columns=[ + "astringwithmorethan32characters_1", + "astringwithmorethan32characters_2", + "+", + "-", + "short", + "delete", + ], + ) + formatted = DataFrame( + [(1, 2, 3, 4, 5, 6)], + columns=[ + "astringwithmorethan32characters_", + "_0astringwithmorethan32character", + "_", + "_1_", + "_short", + "_delete", + ], + ) + formatted.index.name = "index" + formatted = formatted.astype(np.int32) + + with tm.ensure_clean() as path: + with tm.assert_produces_warning(InvalidColumnName): + original.to_stata(path, convert_dates=None, version=version) + # should get a warning for that format. + + written_and_read_again = self.read_dta(path) + + expected = formatted.copy() + expected.index = expected.index.astype(np.int32) + tm.assert_frame_equal(written_and_read_again.set_index("index"), expected) + + def test_read_write_dta13(self): + s1 = Series(2**9, dtype=np.int16) + s2 = Series(2**17, dtype=np.int32) + s3 = Series(2**33, dtype=np.int64) + original = DataFrame({"int16": s1, "int32": s2, "int64": s3}) + original.index.name = "index" + + formatted = original + formatted["int64"] = formatted["int64"].astype(np.float64) + + with tm.ensure_clean() as path: + original.to_stata(path) + written_and_read_again = self.read_dta(path) + + expected = formatted.copy() + expected.index = expected.index.astype(np.int32) + tm.assert_frame_equal(written_and_read_again.set_index("index"), expected) + + @pytest.mark.parametrize("version", [114, 117, 118, 119, None]) + @pytest.mark.parametrize( + "file", ["stata5_113", "stata5_114", "stata5_115", "stata5_117"] + ) + def test_read_write_reread_dta14(self, file, parsed_114, version, datapath): + file = datapath("io", "data", "stata", f"{file}.dta") + parsed = self.read_dta(file) + parsed.index.name = "index" + + tm.assert_frame_equal(parsed_114, parsed) + + with tm.ensure_clean() as path: + parsed_114.to_stata(path, convert_dates={"date_td": "td"}, version=version) + written_and_read_again = self.read_dta(path) + + expected = parsed_114.copy() + expected.index = expected.index.astype(np.int32) + tm.assert_frame_equal(written_and_read_again.set_index("index"), expected) + + @pytest.mark.parametrize( + "file", ["stata6_113", "stata6_114", "stata6_115", "stata6_117"] + ) + def test_read_write_reread_dta15(self, file, datapath): + expected = self.read_csv(datapath("io", "data", "stata", "stata6.csv")) + expected["byte_"] = expected["byte_"].astype(np.int8) + expected["int_"] = expected["int_"].astype(np.int16) + expected["long_"] = expected["long_"].astype(np.int32) + expected["float_"] = expected["float_"].astype(np.float32) + expected["double_"] = expected["double_"].astype(np.float64) + expected["date_td"] = expected["date_td"].apply( + datetime.strptime, args=("%Y-%m-%d",) + ) + + file = datapath("io", "data", "stata", f"{file}.dta") + parsed = self.read_dta(file) + + tm.assert_frame_equal(expected, parsed) + + @pytest.mark.parametrize("version", [114, 117, 118, 119, None]) + def test_timestamp_and_label(self, version): + original = DataFrame([(1,)], columns=["variable"]) + time_stamp = datetime(2000, 2, 29, 14, 21) + data_label = "This is a data file." + with tm.ensure_clean() as path: + original.to_stata( + path, time_stamp=time_stamp, data_label=data_label, version=version + ) + + with StataReader(path) as reader: + assert reader.time_stamp == "29 Feb 2000 14:21" + assert reader.data_label == data_label + + @pytest.mark.parametrize("version", [114, 117, 118, 119, None]) + def test_invalid_timestamp(self, version): + original = DataFrame([(1,)], columns=["variable"]) + time_stamp = "01 Jan 2000, 00:00:00" + with tm.ensure_clean() as path: + msg = "time_stamp should be datetime type" + with pytest.raises(ValueError, match=msg): + original.to_stata(path, time_stamp=time_stamp, version=version) + assert not os.path.isfile(path) + + def test_numeric_column_names(self): + original = DataFrame(np.reshape(np.arange(25.0), (5, 5))) + original.index.name = "index" + with tm.ensure_clean() as path: + # should get a warning for that format. + with tm.assert_produces_warning(InvalidColumnName): + original.to_stata(path) + + written_and_read_again = self.read_dta(path) + + written_and_read_again = written_and_read_again.set_index("index") + columns = list(written_and_read_again.columns) + convert_col_name = lambda x: int(x[1]) + written_and_read_again.columns = map(convert_col_name, columns) + + expected = original.copy() + expected.index = expected.index.astype(np.int32) + tm.assert_frame_equal(expected, written_and_read_again) + + @pytest.mark.parametrize("version", [114, 117, 118, 119, None]) + def test_nan_to_missing_value(self, version): + s1 = Series(np.arange(4.0), dtype=np.float32) + s2 = Series(np.arange(4.0), dtype=np.float64) + s1[::2] = np.nan + s2[1::2] = np.nan + original = DataFrame({"s1": s1, "s2": s2}) + original.index.name = "index" + + with tm.ensure_clean() as path: + original.to_stata(path, version=version) + written_and_read_again = self.read_dta(path) + + written_and_read_again = written_and_read_again.set_index("index") + expected = original.copy() + expected.index = expected.index.astype(np.int32) + tm.assert_frame_equal(written_and_read_again, expected) + + def test_no_index(self): + columns = ["x", "y"] + original = DataFrame(np.reshape(np.arange(10.0), (5, 2)), columns=columns) + original.index.name = "index_not_written" + with tm.ensure_clean() as path: + original.to_stata(path, write_index=False) + written_and_read_again = self.read_dta(path) + with pytest.raises(KeyError, match=original.index.name): + written_and_read_again["index_not_written"] + + def test_string_no_dates(self): + s1 = Series(["a", "A longer string"]) + s2 = Series([1.0, 2.0], dtype=np.float64) + original = DataFrame({"s1": s1, "s2": s2}) + original.index.name = "index" + with tm.ensure_clean() as path: + original.to_stata(path) + written_and_read_again = self.read_dta(path) + + expected = original.copy() + expected.index = expected.index.astype(np.int32) + tm.assert_frame_equal(written_and_read_again.set_index("index"), expected) + + def test_large_value_conversion(self): + s0 = Series([1, 99], dtype=np.int8) + s1 = Series([1, 127], dtype=np.int8) + s2 = Series([1, 2**15 - 1], dtype=np.int16) + s3 = Series([1, 2**63 - 1], dtype=np.int64) + original = DataFrame({"s0": s0, "s1": s1, "s2": s2, "s3": s3}) + original.index.name = "index" + with tm.ensure_clean() as path: + with tm.assert_produces_warning(PossiblePrecisionLoss): + original.to_stata(path) + + written_and_read_again = self.read_dta(path) + + modified = original.copy() + modified["s1"] = Series(modified["s1"], dtype=np.int16) + modified["s2"] = Series(modified["s2"], dtype=np.int32) + modified["s3"] = Series(modified["s3"], dtype=np.float64) + modified.index = original.index.astype(np.int32) + tm.assert_frame_equal(written_and_read_again.set_index("index"), modified) + + def test_dates_invalid_column(self): + original = DataFrame([datetime(2006, 11, 19, 23, 13, 20)]) + original.index.name = "index" + with tm.ensure_clean() as path: + with tm.assert_produces_warning(InvalidColumnName): + original.to_stata(path, convert_dates={0: "tc"}) + + written_and_read_again = self.read_dta(path) + + modified = original.copy() + modified.columns = ["_0"] + modified.index = original.index.astype(np.int32) + tm.assert_frame_equal(written_and_read_again.set_index("index"), modified) + + def test_105(self, datapath): + # Data obtained from: + # http://go.worldbank.org/ZXY29PVJ21 + dpath = datapath("io", "data", "stata", "S4_EDUC1.dta") + df = read_stata(dpath) + df0 = [[1, 1, 3, -2], [2, 1, 2, -2], [4, 1, 1, -2]] + df0 = DataFrame(df0) + df0.columns = ["clustnum", "pri_schl", "psch_num", "psch_dis"] + df0["clustnum"] = df0["clustnum"].astype(np.int16) + df0["pri_schl"] = df0["pri_schl"].astype(np.int8) + df0["psch_num"] = df0["psch_num"].astype(np.int8) + df0["psch_dis"] = df0["psch_dis"].astype(np.float32) + tm.assert_frame_equal(df.head(3), df0) + + def test_value_labels_old_format(self, datapath): + # GH 19417 + # + # Test that value_labels() returns an empty dict if the file format + # predates supporting value labels. + dpath = datapath("io", "data", "stata", "S4_EDUC1.dta") + with StataReader(dpath) as reader: + assert reader.value_labels() == {} + + def test_date_export_formats(self): + columns = ["tc", "td", "tw", "tm", "tq", "th", "ty"] + conversions = {c: c for c in columns} + data = [datetime(2006, 11, 20, 23, 13, 20)] * len(columns) + original = DataFrame([data], columns=columns) + original.index.name = "index" + expected_values = [ + datetime(2006, 11, 20, 23, 13, 20), # Time + datetime(2006, 11, 20), # Day + datetime(2006, 11, 19), # Week + datetime(2006, 11, 1), # Month + datetime(2006, 10, 1), # Quarter year + datetime(2006, 7, 1), # Half year + datetime(2006, 1, 1), + ] # Year + + expected = DataFrame( + [expected_values], + index=pd.Index([0], dtype=np.int32, name="index"), + columns=columns, + ) + + with tm.ensure_clean() as path: + original.to_stata(path, convert_dates=conversions) + written_and_read_again = self.read_dta(path) + + tm.assert_frame_equal(written_and_read_again.set_index("index"), expected) + + def test_write_missing_strings(self): + original = DataFrame([["1"], [None]], columns=["foo"]) + + expected = DataFrame( + [["1"], [""]], + index=pd.Index([0, 1], dtype=np.int32, name="index"), + columns=["foo"], + ) + + with tm.ensure_clean() as path: + original.to_stata(path) + written_and_read_again = self.read_dta(path) + + tm.assert_frame_equal(written_and_read_again.set_index("index"), expected) + + @pytest.mark.parametrize("version", [114, 117, 118, 119, None]) + @pytest.mark.parametrize("byteorder", [">", "<"]) + def test_bool_uint(self, byteorder, version): + s0 = Series([0, 1, True], dtype=np.bool_) + s1 = Series([0, 1, 100], dtype=np.uint8) + s2 = Series([0, 1, 255], dtype=np.uint8) + s3 = Series([0, 1, 2**15 - 100], dtype=np.uint16) + s4 = Series([0, 1, 2**16 - 1], dtype=np.uint16) + s5 = Series([0, 1, 2**31 - 100], dtype=np.uint32) + s6 = Series([0, 1, 2**32 - 1], dtype=np.uint32) + + original = DataFrame( + {"s0": s0, "s1": s1, "s2": s2, "s3": s3, "s4": s4, "s5": s5, "s6": s6} + ) + original.index.name = "index" + expected = original.copy() + expected.index = original.index.astype(np.int32) + expected_types = ( + np.int8, + np.int8, + np.int16, + np.int16, + np.int32, + np.int32, + np.float64, + ) + for c, t in zip(expected.columns, expected_types): + expected[c] = expected[c].astype(t) + + with tm.ensure_clean() as path: + original.to_stata(path, byteorder=byteorder, version=version) + written_and_read_again = self.read_dta(path) + + written_and_read_again = written_and_read_again.set_index("index") + tm.assert_frame_equal(written_and_read_again, expected) + + def test_variable_labels(self, datapath): + with StataReader(datapath("io", "data", "stata", "stata7_115.dta")) as rdr: + sr_115 = rdr.variable_labels() + with StataReader(datapath("io", "data", "stata", "stata7_117.dta")) as rdr: + sr_117 = rdr.variable_labels() + keys = ("var1", "var2", "var3") + labels = ("label1", "label2", "label3") + for k, v in sr_115.items(): + assert k in sr_117 + assert v == sr_117[k] + assert k in keys + assert v in labels + + def test_minimal_size_col(self): + str_lens = (1, 100, 244) + s = {} + for str_len in str_lens: + s["s" + str(str_len)] = Series( + ["a" * str_len, "b" * str_len, "c" * str_len] + ) + original = DataFrame(s) + with tm.ensure_clean() as path: + original.to_stata(path, write_index=False) + + with StataReader(path) as sr: + sr._ensure_open() # The `_*list` variables are initialized here + for variable, fmt, typ in zip(sr._varlist, sr._fmtlist, sr._typlist): + assert int(variable[1:]) == int(fmt[1:-1]) + assert int(variable[1:]) == typ + + def test_excessively_long_string(self): + str_lens = (1, 244, 500) + s = {} + for str_len in str_lens: + s["s" + str(str_len)] = Series( + ["a" * str_len, "b" * str_len, "c" * str_len] + ) + original = DataFrame(s) + msg = ( + r"Fixed width strings in Stata \.dta files are limited to 244 " + r"\(or fewer\)\ncharacters\. Column 's500' does not satisfy " + r"this restriction\. Use the\n'version=117' parameter to write " + r"the newer \(Stata 13 and later\) format\." + ) + with pytest.raises(ValueError, match=msg): + with tm.ensure_clean() as path: + original.to_stata(path) + + def test_missing_value_generator(self): + types = ("b", "h", "l") + df = DataFrame([[0.0]], columns=["float_"]) + with tm.ensure_clean() as path: + df.to_stata(path) + with StataReader(path) as rdr: + valid_range = rdr.VALID_RANGE + expected_values = ["." + chr(97 + i) for i in range(26)] + expected_values.insert(0, ".") + for t in types: + offset = valid_range[t][1] + for i in range(27): + val = StataMissingValue(offset + 1 + i) + assert val.string == expected_values[i] + + # Test extremes for floats + val = StataMissingValue(struct.unpack(" DataFrame: + """ + Emulate the categorical casting behavior we expect from roundtripping. + """ + for col in from_frame: + ser = from_frame[col] + if isinstance(ser.dtype, CategoricalDtype): + cat = ser._values.remove_unused_categories() + if cat.categories.dtype == object: + categories = pd.Index._with_infer(cat.categories._values) + cat = cat.set_categories(categories) + from_frame[col] = cat + return from_frame + + def test_iterator(self, datapath): + fname = datapath("io", "data", "stata", "stata3_117.dta") + + parsed = read_stata(fname) + + with read_stata(fname, iterator=True) as itr: + chunk = itr.read(5) + tm.assert_frame_equal(parsed.iloc[0:5, :], chunk) + + with read_stata(fname, chunksize=5) as itr: + chunk = list(itr) + tm.assert_frame_equal(parsed.iloc[0:5, :], chunk[0]) + + with read_stata(fname, iterator=True) as itr: + chunk = itr.get_chunk(5) + tm.assert_frame_equal(parsed.iloc[0:5, :], chunk) + + with read_stata(fname, chunksize=5) as itr: + chunk = itr.get_chunk() + tm.assert_frame_equal(parsed.iloc[0:5, :], chunk) + + # GH12153 + with read_stata(fname, chunksize=4) as itr: + from_chunks = pd.concat(itr) + tm.assert_frame_equal(parsed, from_chunks) + + @pytest.mark.filterwarnings("ignore::UserWarning") + @pytest.mark.parametrize( + "file", + [ + "stata2_115", + "stata3_115", + "stata4_115", + "stata5_115", + "stata6_115", + "stata7_115", + "stata8_115", + "stata9_115", + "stata10_115", + "stata11_115", + ], + ) + @pytest.mark.parametrize("chunksize", [1, 2]) + @pytest.mark.parametrize("convert_categoricals", [False, True]) + @pytest.mark.parametrize("convert_dates", [False, True]) + def test_read_chunks_115( + self, file, chunksize, convert_categoricals, convert_dates, datapath + ): + fname = datapath("io", "data", "stata", f"{file}.dta") + + # Read the whole file + parsed = read_stata( + fname, + convert_categoricals=convert_categoricals, + convert_dates=convert_dates, + ) + + # Compare to what we get when reading by chunk + with read_stata( + fname, + iterator=True, + convert_dates=convert_dates, + convert_categoricals=convert_categoricals, + ) as itr: + pos = 0 + for j in range(5): + try: + chunk = itr.read(chunksize) + except StopIteration: + break + from_frame = parsed.iloc[pos : pos + chunksize, :].copy() + from_frame = self._convert_categorical(from_frame) + tm.assert_frame_equal( + from_frame, chunk, check_dtype=False, check_datetimelike_compat=True + ) + pos += chunksize + + def test_read_chunks_columns(self, datapath): + fname = datapath("io", "data", "stata", "stata3_117.dta") + columns = ["quarter", "cpi", "m1"] + chunksize = 2 + + parsed = read_stata(fname, columns=columns) + with read_stata(fname, iterator=True) as itr: + pos = 0 + for j in range(5): + chunk = itr.read(chunksize, columns=columns) + if chunk is None: + break + from_frame = parsed.iloc[pos : pos + chunksize, :] + tm.assert_frame_equal(from_frame, chunk, check_dtype=False) + pos += chunksize + + @pytest.mark.parametrize("version", [114, 117, 118, 119, None]) + def test_write_variable_labels(self, version, mixed_frame): + # GH 13631, add support for writing variable labels + mixed_frame.index.name = "index" + variable_labels = {"a": "City Rank", "b": "City Exponent", "c": "City"} + with tm.ensure_clean() as path: + mixed_frame.to_stata(path, variable_labels=variable_labels, version=version) + with StataReader(path) as sr: + read_labels = sr.variable_labels() + expected_labels = { + "index": "", + "a": "City Rank", + "b": "City Exponent", + "c": "City", + } + assert read_labels == expected_labels + + variable_labels["index"] = "The Index" + with tm.ensure_clean() as path: + mixed_frame.to_stata(path, variable_labels=variable_labels, version=version) + with StataReader(path) as sr: + read_labels = sr.variable_labels() + assert read_labels == variable_labels + + @pytest.mark.parametrize("version", [114, 117, 118, 119, None]) + def test_invalid_variable_labels(self, version, mixed_frame): + mixed_frame.index.name = "index" + variable_labels = {"a": "very long" * 10, "b": "City Exponent", "c": "City"} + with tm.ensure_clean() as path: + msg = "Variable labels must be 80 characters or fewer" + with pytest.raises(ValueError, match=msg): + mixed_frame.to_stata( + path, variable_labels=variable_labels, version=version + ) + + @pytest.mark.parametrize("version", [114, 117]) + def test_invalid_variable_label_encoding(self, version, mixed_frame): + mixed_frame.index.name = "index" + variable_labels = {"a": "very long" * 10, "b": "City Exponent", "c": "City"} + variable_labels["a"] = "invalid character Œ" + with tm.ensure_clean() as path: + with pytest.raises( + ValueError, match="Variable labels must contain only characters" + ): + mixed_frame.to_stata( + path, variable_labels=variable_labels, version=version + ) + + def test_write_variable_label_errors(self, mixed_frame): + values = ["\u03A1", "\u0391", "\u039D", "\u0394", "\u0391", "\u03A3"] + + variable_labels_utf8 = { + "a": "City Rank", + "b": "City Exponent", + "c": "".join(values), + } + + msg = ( + "Variable labels must contain only characters that can be " + "encoded in Latin-1" + ) + with pytest.raises(ValueError, match=msg): + with tm.ensure_clean() as path: + mixed_frame.to_stata(path, variable_labels=variable_labels_utf8) + + variable_labels_long = { + "a": "City Rank", + "b": "City Exponent", + "c": "A very, very, very long variable label " + "that is too long for Stata which means " + "that it has more than 80 characters", + } + + msg = "Variable labels must be 80 characters or fewer" + with pytest.raises(ValueError, match=msg): + with tm.ensure_clean() as path: + mixed_frame.to_stata(path, variable_labels=variable_labels_long) + + def test_default_date_conversion(self): + # GH 12259 + dates = [ + dt.datetime(1999, 12, 31, 12, 12, 12, 12000), + dt.datetime(2012, 12, 21, 12, 21, 12, 21000), + dt.datetime(1776, 7, 4, 7, 4, 7, 4000), + ] + original = DataFrame( + { + "nums": [1.0, 2.0, 3.0], + "strs": ["apple", "banana", "cherry"], + "dates": dates, + } + ) + + with tm.ensure_clean() as path: + original.to_stata(path, write_index=False) + reread = read_stata(path, convert_dates=True) + tm.assert_frame_equal(original, reread) + + original.to_stata(path, write_index=False, convert_dates={"dates": "tc"}) + direct = read_stata(path, convert_dates=True) + tm.assert_frame_equal(reread, direct) + + dates_idx = original.columns.tolist().index("dates") + original.to_stata(path, write_index=False, convert_dates={dates_idx: "tc"}) + direct = read_stata(path, convert_dates=True) + tm.assert_frame_equal(reread, direct) + + def test_unsupported_type(self): + original = DataFrame({"a": [1 + 2j, 2 + 4j]}) + + msg = "Data type complex128 not supported" + with pytest.raises(NotImplementedError, match=msg): + with tm.ensure_clean() as path: + original.to_stata(path) + + def test_unsupported_datetype(self): + dates = [ + dt.datetime(1999, 12, 31, 12, 12, 12, 12000), + dt.datetime(2012, 12, 21, 12, 21, 12, 21000), + dt.datetime(1776, 7, 4, 7, 4, 7, 4000), + ] + original = DataFrame( + { + "nums": [1.0, 2.0, 3.0], + "strs": ["apple", "banana", "cherry"], + "dates": dates, + } + ) + + msg = "Format %tC not implemented" + with pytest.raises(NotImplementedError, match=msg): + with tm.ensure_clean() as path: + original.to_stata(path, convert_dates={"dates": "tC"}) + + dates = pd.date_range("1-1-1990", periods=3, tz="Asia/Hong_Kong") + original = DataFrame( + { + "nums": [1.0, 2.0, 3.0], + "strs": ["apple", "banana", "cherry"], + "dates": dates, + } + ) + with pytest.raises(NotImplementedError, match="Data type datetime64"): + with tm.ensure_clean() as path: + original.to_stata(path) + + def test_repeated_column_labels(self, datapath): + # GH 13923, 25772 + msg = """ +Value labels for column ethnicsn are not unique. These cannot be converted to +pandas categoricals. + +Either read the file with `convert_categoricals` set to False or use the +low level interface in `StataReader` to separately read the values and the +value_labels. + +The repeated labels are:\n-+\nwolof +""" + with pytest.raises(ValueError, match=msg): + read_stata( + datapath("io", "data", "stata", "stata15.dta"), + convert_categoricals=True, + ) + + def test_stata_111(self, datapath): + # 111 is an old version but still used by current versions of + # SAS when exporting to Stata format. We do not know of any + # on-line documentation for this version. + df = read_stata(datapath("io", "data", "stata", "stata7_111.dta")) + original = DataFrame( + { + "y": [1, 1, 1, 1, 1, 0, 0, np.nan, 0, 0], + "x": [1, 2, 1, 3, np.nan, 4, 3, 5, 1, 6], + "w": [2, np.nan, 5, 2, 4, 4, 3, 1, 2, 3], + "z": ["a", "b", "c", "d", "e", "", "g", "h", "i", "j"], + } + ) + original = original[["y", "x", "w", "z"]] + tm.assert_frame_equal(original, df) + + def test_out_of_range_double(self): + # GH 14618 + df = DataFrame( + { + "ColumnOk": [0.0, np.finfo(np.double).eps, 4.49423283715579e307], + "ColumnTooBig": [0.0, np.finfo(np.double).eps, np.finfo(np.double).max], + } + ) + msg = ( + r"Column ColumnTooBig has a maximum value \(.+\) outside the range " + r"supported by Stata \(.+\)" + ) + with pytest.raises(ValueError, match=msg): + with tm.ensure_clean() as path: + df.to_stata(path) + + def test_out_of_range_float(self): + original = DataFrame( + { + "ColumnOk": [ + 0.0, + np.finfo(np.float32).eps, + np.finfo(np.float32).max / 10.0, + ], + "ColumnTooBig": [ + 0.0, + np.finfo(np.float32).eps, + np.finfo(np.float32).max, + ], + } + ) + original.index.name = "index" + for col in original: + original[col] = original[col].astype(np.float32) + + with tm.ensure_clean() as path: + original.to_stata(path) + reread = read_stata(path) + + original["ColumnTooBig"] = original["ColumnTooBig"].astype(np.float64) + expected = original.copy() + expected.index = expected.index.astype(np.int32) + tm.assert_frame_equal(reread.set_index("index"), expected) + + @pytest.mark.parametrize("infval", [np.inf, -np.inf]) + def test_inf(self, infval): + # GH 45350 + df = DataFrame({"WithoutInf": [0.0, 1.0], "WithInf": [2.0, infval]}) + msg = ( + "Column WithInf contains infinity or -infinity" + "which is outside the range supported by Stata." + ) + with pytest.raises(ValueError, match=msg): + with tm.ensure_clean() as path: + df.to_stata(path) + + def test_path_pathlib(self): + df = DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=pd.Index(list("ABCD"), dtype=object), + index=pd.Index([f"i-{i}" for i in range(30)], dtype=object), + ) + df.index.name = "index" + reader = lambda x: read_stata(x).set_index("index") + result = tm.round_trip_pathlib(df.to_stata, reader) + tm.assert_frame_equal(df, result) + + def test_pickle_path_localpath(self): + df = DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=pd.Index(list("ABCD"), dtype=object), + index=pd.Index([f"i-{i}" for i in range(30)], dtype=object), + ) + df.index.name = "index" + reader = lambda x: read_stata(x).set_index("index") + result = tm.round_trip_localpath(df.to_stata, reader) + tm.assert_frame_equal(df, result) + + @pytest.mark.parametrize("write_index", [True, False]) + def test_value_labels_iterator(self, write_index): + # GH 16923 + d = {"A": ["B", "E", "C", "A", "E"]} + df = DataFrame(data=d) + df["A"] = df["A"].astype("category") + with tm.ensure_clean() as path: + df.to_stata(path, write_index=write_index) + + with read_stata(path, iterator=True) as dta_iter: + value_labels = dta_iter.value_labels() + assert value_labels == {"A": {0: "A", 1: "B", 2: "C", 3: "E"}} + + def test_set_index(self): + # GH 17328 + df = DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=pd.Index(list("ABCD"), dtype=object), + index=pd.Index([f"i-{i}" for i in range(30)], dtype=object), + ) + df.index.name = "index" + with tm.ensure_clean() as path: + df.to_stata(path) + reread = read_stata(path, index_col="index") + tm.assert_frame_equal(df, reread) + + @pytest.mark.parametrize( + "column", ["ms", "day", "week", "month", "qtr", "half", "yr"] + ) + def test_date_parsing_ignores_format_details(self, column, datapath): + # GH 17797 + # + # Test that display formats are ignored when determining if a numeric + # column is a date value. + # + # All date types are stored as numbers and format associated with the + # column denotes both the type of the date and the display format. + # + # STATA supports 9 date types which each have distinct units. We test 7 + # of the 9 types, ignoring %tC and %tb. %tC is a variant of %tc that + # accounts for leap seconds and %tb relies on STATAs business calendar. + df = read_stata(datapath("io", "data", "stata", "stata13_dates.dta")) + unformatted = df.loc[0, column] + formatted = df.loc[0, column + "_fmt"] + assert unformatted == formatted + + def test_writer_117(self): + original = DataFrame( + data=[ + [ + "string", + "object", + 1, + 1, + 1, + 1.1, + 1.1, + np.datetime64("2003-12-25"), + "a", + "a" * 2045, + "a" * 5000, + "a", + ], + [ + "string-1", + "object-1", + 1, + 1, + 1, + 1.1, + 1.1, + np.datetime64("2003-12-26"), + "b", + "b" * 2045, + "", + "", + ], + ], + columns=[ + "string", + "object", + "int8", + "int16", + "int32", + "float32", + "float64", + "datetime", + "s1", + "s2045", + "srtl", + "forced_strl", + ], + ) + original["object"] = Series(original["object"], dtype=object) + original["int8"] = Series(original["int8"], dtype=np.int8) + original["int16"] = Series(original["int16"], dtype=np.int16) + original["int32"] = original["int32"].astype(np.int32) + original["float32"] = Series(original["float32"], dtype=np.float32) + original.index.name = "index" + original.index = original.index.astype(np.int32) + copy = original.copy() + with tm.ensure_clean() as path: + original.to_stata( + path, + convert_dates={"datetime": "tc"}, + convert_strl=["forced_strl"], + version=117, + ) + written_and_read_again = self.read_dta(path) + # original.index is np.int32, read index is np.int64 + tm.assert_frame_equal( + written_and_read_again.set_index("index"), + original, + check_index_type=False, + ) + tm.assert_frame_equal(original, copy) + + def test_convert_strl_name_swap(self): + original = DataFrame( + [["a" * 3000, "A", "apple"], ["b" * 1000, "B", "banana"]], + columns=["long1" * 10, "long", 1], + ) + original.index.name = "index" + + with tm.assert_produces_warning(InvalidColumnName): + with tm.ensure_clean() as path: + original.to_stata(path, convert_strl=["long", 1], version=117) + reread = self.read_dta(path) + reread = reread.set_index("index") + reread.columns = original.columns + tm.assert_frame_equal(reread, original, check_index_type=False) + + def test_invalid_date_conversion(self): + # GH 12259 + dates = [ + dt.datetime(1999, 12, 31, 12, 12, 12, 12000), + dt.datetime(2012, 12, 21, 12, 21, 12, 21000), + dt.datetime(1776, 7, 4, 7, 4, 7, 4000), + ] + original = DataFrame( + { + "nums": [1.0, 2.0, 3.0], + "strs": ["apple", "banana", "cherry"], + "dates": dates, + } + ) + + with tm.ensure_clean() as path: + msg = "convert_dates key must be a column or an integer" + with pytest.raises(ValueError, match=msg): + original.to_stata(path, convert_dates={"wrong_name": "tc"}) + + @pytest.mark.parametrize("version", [114, 117, 118, 119, None]) + def test_nonfile_writing(self, version): + # GH 21041 + bio = io.BytesIO() + df = DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=pd.Index(list("ABCD"), dtype=object), + index=pd.Index([f"i-{i}" for i in range(30)], dtype=object), + ) + df.index.name = "index" + with tm.ensure_clean() as path: + df.to_stata(bio, version=version) + bio.seek(0) + with open(path, "wb") as dta: + dta.write(bio.read()) + reread = read_stata(path, index_col="index") + tm.assert_frame_equal(df, reread) + + def test_gzip_writing(self): + # writing version 117 requires seek and cannot be used with gzip + df = DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=pd.Index(list("ABCD"), dtype=object), + index=pd.Index([f"i-{i}" for i in range(30)], dtype=object), + ) + df.index.name = "index" + with tm.ensure_clean() as path: + with gzip.GzipFile(path, "wb") as gz: + df.to_stata(gz, version=114) + with gzip.GzipFile(path, "rb") as gz: + reread = read_stata(gz, index_col="index") + tm.assert_frame_equal(df, reread) + + def test_unicode_dta_118(self, datapath): + unicode_df = self.read_dta(datapath("io", "data", "stata", "stata16_118.dta")) + + columns = ["utf8", "latin1", "ascii", "utf8_strl", "ascii_strl"] + values = [ + ["ραηδας", "PÄNDÄS", "p", "ραηδας", "p"], + ["ƤĀńĐąŜ", "Ö", "a", "ƤĀńĐąŜ", "a"], + ["ᴘᴀᴎᴅᴀS", "Ü", "n", "ᴘᴀᴎᴅᴀS", "n"], + [" ", " ", "d", " ", "d"], + [" ", "", "a", " ", "a"], + ["", "", "s", "", "s"], + ["", "", " ", "", " "], + ] + expected = DataFrame(values, columns=columns) + + tm.assert_frame_equal(unicode_df, expected) + + def test_mixed_string_strl(self): + # GH 23633 + output = [{"mixed": "string" * 500, "number": 0}, {"mixed": None, "number": 1}] + output = DataFrame(output) + output.number = output.number.astype("int32") + + with tm.ensure_clean() as path: + output.to_stata(path, write_index=False, version=117) + reread = read_stata(path) + expected = output.fillna("") + tm.assert_frame_equal(reread, expected) + + # Check strl supports all None (null) + output["mixed"] = None + output.to_stata( + path, write_index=False, convert_strl=["mixed"], version=117 + ) + reread = read_stata(path) + expected = output.fillna("") + tm.assert_frame_equal(reread, expected) + + @pytest.mark.parametrize("version", [114, 117, 118, 119, None]) + def test_all_none_exception(self, version): + output = [{"none": "none", "number": 0}, {"none": None, "number": 1}] + output = DataFrame(output) + output["none"] = None + with tm.ensure_clean() as path: + with pytest.raises(ValueError, match="Column `none` cannot be exported"): + output.to_stata(path, version=version) + + @pytest.mark.parametrize("version", [114, 117, 118, 119, None]) + def test_invalid_file_not_written(self, version): + content = "Here is one __�__ Another one __·__ Another one __½__" + df = DataFrame([content], columns=["invalid"]) + with tm.ensure_clean() as path: + msg1 = ( + r"'latin-1' codec can't encode character '\\ufffd' " + r"in position 14: ordinal not in range\(256\)" + ) + msg2 = ( + "'ascii' codec can't decode byte 0xef in position 14: " + r"ordinal not in range\(128\)" + ) + with pytest.raises(UnicodeEncodeError, match=f"{msg1}|{msg2}"): + df.to_stata(path) + + def test_strl_latin1(self): + # GH 23573, correct GSO data to reflect correct size + output = DataFrame( + [["pandas"] * 2, ["þâÑÐŧ"] * 2], columns=["var_str", "var_strl"] + ) + + with tm.ensure_clean() as path: + output.to_stata(path, version=117, convert_strl=["var_strl"]) + with open(path, "rb") as reread: + content = reread.read() + expected = "þâÑÐŧ" + assert expected.encode("latin-1") in content + assert expected.encode("utf-8") in content + gsos = content.split(b"strls")[1][1:-2] + for gso in gsos.split(b"GSO")[1:]: + val = gso.split(b"\x00")[-2] + size = gso[gso.find(b"\x82") + 1] + assert len(val) == size - 1 + + def test_encoding_latin1_118(self, datapath): + # GH 25960 + msg = """ +One or more strings in the dta file could not be decoded using utf-8, and +so the fallback encoding of latin-1 is being used. This can happen when a file +has been incorrectly encoded by Stata or some other software. You should verify +the string values returned are correct.""" + # Move path outside of read_stata, or else assert_produces_warning + # will block pytests skip mechanism from triggering (failing the test) + # if the path is not present + path = datapath("io", "data", "stata", "stata1_encoding_118.dta") + with tm.assert_produces_warning(UnicodeWarning, filter_level="once") as w: + encoded = read_stata(path) + # with filter_level="always", produces 151 warnings which can be slow + assert len(w) == 1 + assert w[0].message.args[0] == msg + + expected = DataFrame([["Düsseldorf"]] * 151, columns=["kreis1849"]) + tm.assert_frame_equal(encoded, expected) + + @pytest.mark.slow + def test_stata_119(self, datapath): + # Gzipped since contains 32,999 variables and uncompressed is 20MiB + # Just validate that the reader reports correct number of variables + # to avoid high peak memory + with gzip.open( + datapath("io", "data", "stata", "stata1_119.dta.gz"), "rb" + ) as gz: + with StataReader(gz) as reader: + reader._ensure_open() + assert reader._nvar == 32999 + + @pytest.mark.parametrize("version", [118, 119, None]) + def test_utf8_writer(self, version): + cat = pd.Categorical(["a", "β", "ĉ"], ordered=True) + data = DataFrame( + [ + [1.0, 1, "ᴬ", "ᴀ relatively long ŝtring"], + [2.0, 2, "ᴮ", ""], + [3.0, 3, "ᴰ", None], + ], + columns=["Å", "β", "ĉ", "strls"], + ) + data["ᴐᴬᵀ"] = cat + variable_labels = { + "Å": "apple", + "β": "ᵈᵉᵊ", + "ĉ": "ᴎტჄႲႳႴႶႺ", + "strls": "Long Strings", + "ᴐᴬᵀ": "", + } + data_label = "ᴅaᵀa-label" + value_labels = {"β": {1: "label", 2: "æøå", 3: "ŋot valid latin-1"}} + data["β"] = data["β"].astype(np.int32) + with tm.ensure_clean() as path: + writer = StataWriterUTF8( + path, + data, + data_label=data_label, + convert_strl=["strls"], + variable_labels=variable_labels, + write_index=False, + version=version, + value_labels=value_labels, + ) + writer.write_file() + reread_encoded = read_stata(path) + # Missing is intentionally converted to empty strl + data["strls"] = data["strls"].fillna("") + # Variable with value labels is reread as categorical + data["β"] = ( + data["β"].replace(value_labels["β"]).astype("category").cat.as_ordered() + ) + tm.assert_frame_equal(data, reread_encoded) + with StataReader(path) as reader: + assert reader.data_label == data_label + assert reader.variable_labels() == variable_labels + + data.to_stata(path, version=version, write_index=False) + reread_to_stata = read_stata(path) + tm.assert_frame_equal(data, reread_to_stata) + + def test_writer_118_exceptions(self): + df = DataFrame(np.zeros((1, 33000), dtype=np.int8)) + with tm.ensure_clean() as path: + with pytest.raises(ValueError, match="version must be either 118 or 119."): + StataWriterUTF8(path, df, version=117) + with tm.ensure_clean() as path: + with pytest.raises(ValueError, match="You must use version 119"): + StataWriterUTF8(path, df, version=118) + + @pytest.mark.parametrize( + "dtype_backend", + ["numpy_nullable", pytest.param("pyarrow", marks=td.skip_if_no("pyarrow"))], + ) + def test_read_write_ea_dtypes(self, dtype_backend): + df = DataFrame( + { + "a": [1, 2, None], + "b": ["a", "b", "c"], + "c": [True, False, None], + "d": [1.5, 2.5, 3.5], + "e": pd.date_range("2020-12-31", periods=3, freq="D"), + }, + index=pd.Index([0, 1, 2], name="index"), + ) + df = df.convert_dtypes(dtype_backend=dtype_backend) + df.to_stata("test_stata.dta", version=118) + + with tm.ensure_clean() as path: + df.to_stata(path) + written_and_read_again = self.read_dta(path) + + expected = DataFrame( + { + "a": [1, 2, np.nan], + "b": ["a", "b", "c"], + "c": [1.0, 0, np.nan], + "d": [1.5, 2.5, 3.5], + "e": pd.date_range("2020-12-31", periods=3, freq="D"), + }, + index=pd.Index([0, 1, 2], name="index", dtype=np.int32), + ) + + tm.assert_frame_equal(written_and_read_again.set_index("index"), expected) + + +@pytest.mark.parametrize("version", [105, 108, 111, 113, 114]) +def test_backward_compat(version, datapath): + data_base = datapath("io", "data", "stata") + ref = os.path.join(data_base, "stata-compat-118.dta") + old = os.path.join(data_base, f"stata-compat-{version}.dta") + expected = read_stata(ref) + old_dta = read_stata(old) + tm.assert_frame_equal(old_dta, expected, check_dtype=False) + + +def test_direct_read(datapath, monkeypatch): + file_path = datapath("io", "data", "stata", "stata-compat-118.dta") + + # Test that opening a file path doesn't buffer the file. + with StataReader(file_path) as reader: + # Must not have been buffered to memory + assert not reader.read().empty + assert not isinstance(reader._path_or_buf, io.BytesIO) + + # Test that we use a given fp exactly, if possible. + with open(file_path, "rb") as fp: + with StataReader(fp) as reader: + assert not reader.read().empty + assert reader._path_or_buf is fp + + # Test that we use a given BytesIO exactly, if possible. + with open(file_path, "rb") as fp: + with io.BytesIO(fp.read()) as bio: + with StataReader(bio) as reader: + assert not reader.read().empty + assert reader._path_or_buf is bio + + +def test_statareader_warns_when_used_without_context(datapath): + file_path = datapath("io", "data", "stata", "stata-compat-118.dta") + with tm.assert_produces_warning( + ResourceWarning, + match="without using a context manager", + ): + sr = StataReader(file_path) + sr.read() + with tm.assert_produces_warning( + FutureWarning, + match="is not part of the public API", + ): + sr.close() + + +@pytest.mark.parametrize("version", [114, 117, 118, 119, None]) +@pytest.mark.parametrize("use_dict", [True, False]) +@pytest.mark.parametrize("infer", [True, False]) +def test_compression(compression, version, use_dict, infer, compression_to_extension): + file_name = "dta_inferred_compression.dta" + if compression: + if use_dict: + file_ext = compression + else: + file_ext = compression_to_extension[compression] + file_name += f".{file_ext}" + compression_arg = compression + if infer: + compression_arg = "infer" + if use_dict: + compression_arg = {"method": compression} + + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 2)), columns=list("AB") + ) + df.index.name = "index" + with tm.ensure_clean(file_name) as path: + df.to_stata(path, version=version, compression=compression_arg) + if compression == "gzip": + with gzip.open(path, "rb") as comp: + fp = io.BytesIO(comp.read()) + elif compression == "zip": + with zipfile.ZipFile(path, "r") as comp: + fp = io.BytesIO(comp.read(comp.filelist[0])) + elif compression == "tar": + with tarfile.open(path) as tar: + fp = io.BytesIO(tar.extractfile(tar.getnames()[0]).read()) + elif compression == "bz2": + with bz2.open(path, "rb") as comp: + fp = io.BytesIO(comp.read()) + elif compression == "zstd": + zstd = pytest.importorskip("zstandard") + with zstd.open(path, "rb") as comp: + fp = io.BytesIO(comp.read()) + elif compression == "xz": + lzma = pytest.importorskip("lzma") + with lzma.open(path, "rb") as comp: + fp = io.BytesIO(comp.read()) + elif compression is None: + fp = path + reread = read_stata(fp, index_col="index") + + expected = df.copy() + expected.index = expected.index.astype(np.int32) + tm.assert_frame_equal(reread, expected) + + +@pytest.mark.parametrize("method", ["zip", "infer"]) +@pytest.mark.parametrize("file_ext", [None, "dta", "zip"]) +def test_compression_dict(method, file_ext): + file_name = f"test.{file_ext}" + archive_name = "test.dta" + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 2)), columns=list("AB") + ) + df.index.name = "index" + with tm.ensure_clean(file_name) as path: + compression = {"method": method, "archive_name": archive_name} + df.to_stata(path, compression=compression) + if method == "zip" or file_ext == "zip": + with zipfile.ZipFile(path, "r") as zp: + assert len(zp.filelist) == 1 + assert zp.filelist[0].filename == archive_name + fp = io.BytesIO(zp.read(zp.filelist[0])) + else: + fp = path + reread = read_stata(fp, index_col="index") + + expected = df.copy() + expected.index = expected.index.astype(np.int32) + tm.assert_frame_equal(reread, expected) + + +@pytest.mark.parametrize("version", [114, 117, 118, 119, None]) +def test_chunked_categorical(version): + df = DataFrame({"cats": Series(["a", "b", "a", "b", "c"], dtype="category")}) + df.index.name = "index" + + expected = df.copy() + expected.index = expected.index.astype(np.int32) + + with tm.ensure_clean() as path: + df.to_stata(path, version=version) + with StataReader(path, chunksize=2, order_categoricals=False) as reader: + for i, block in enumerate(reader): + block = block.set_index("index") + assert "cats" in block + tm.assert_series_equal( + block.cats, expected.cats.iloc[2 * i : 2 * (i + 1)] + ) + + +def test_chunked_categorical_partial(datapath): + dta_file = datapath("io", "data", "stata", "stata-dta-partially-labeled.dta") + values = ["a", "b", "a", "b", 3.0] + with StataReader(dta_file, chunksize=2) as reader: + with tm.assert_produces_warning(CategoricalConversionWarning): + for i, block in enumerate(reader): + assert list(block.cats) == values[2 * i : 2 * (i + 1)] + if i < 2: + idx = pd.Index(["a", "b"]) + else: + idx = pd.Index([3.0], dtype="float64") + tm.assert_index_equal(block.cats.cat.categories, idx) + with tm.assert_produces_warning(CategoricalConversionWarning): + with StataReader(dta_file, chunksize=5) as reader: + large_chunk = reader.__next__() + direct = read_stata(dta_file) + tm.assert_frame_equal(direct, large_chunk) + + +@pytest.mark.parametrize("chunksize", (-1, 0, "apple")) +def test_iterator_errors(datapath, chunksize): + dta_file = datapath("io", "data", "stata", "stata-dta-partially-labeled.dta") + with pytest.raises(ValueError, match="chunksize must be a positive"): + with StataReader(dta_file, chunksize=chunksize): + pass + + +def test_iterator_value_labels(): + # GH 31544 + values = ["c_label", "b_label"] + ["a_label"] * 500 + df = DataFrame({f"col{k}": pd.Categorical(values, ordered=True) for k in range(2)}) + with tm.ensure_clean() as path: + df.to_stata(path, write_index=False) + expected = pd.Index(["a_label", "b_label", "c_label"], dtype="object") + with read_stata(path, chunksize=100) as reader: + for j, chunk in enumerate(reader): + for i in range(2): + tm.assert_index_equal(chunk.dtypes.iloc[i].categories, expected) + tm.assert_frame_equal(chunk, df.iloc[j * 100 : (j + 1) * 100]) + + +def test_precision_loss(): + df = DataFrame( + [[sum(2**i for i in range(60)), sum(2**i for i in range(52))]], + columns=["big", "little"], + ) + with tm.ensure_clean() as path: + with tm.assert_produces_warning( + PossiblePrecisionLoss, match="Column converted from int64 to float64" + ): + df.to_stata(path, write_index=False) + reread = read_stata(path) + expected_dt = Series([np.float64, np.float64], index=["big", "little"]) + tm.assert_series_equal(reread.dtypes, expected_dt) + assert reread.loc[0, "little"] == df.loc[0, "little"] + assert reread.loc[0, "big"] == float(df.loc[0, "big"]) + + +def test_compression_roundtrip(compression): + df = DataFrame( + [[0.123456, 0.234567, 0.567567], [12.32112, 123123.2, 321321.2]], + index=["A", "B"], + columns=["X", "Y", "Z"], + ) + df.index.name = "index" + + with tm.ensure_clean() as path: + df.to_stata(path, compression=compression) + reread = read_stata(path, compression=compression, index_col="index") + tm.assert_frame_equal(df, reread) + + # explicitly ensure file was compressed. + with tm.decompress_file(path, compression) as fh: + contents = io.BytesIO(fh.read()) + reread = read_stata(contents, index_col="index") + tm.assert_frame_equal(df, reread) + + +@pytest.mark.parametrize("to_infer", [True, False]) +@pytest.mark.parametrize("read_infer", [True, False]) +def test_stata_compression( + compression_only, read_infer, to_infer, compression_to_extension +): + compression = compression_only + + ext = compression_to_extension[compression] + filename = f"test.{ext}" + + df = DataFrame( + [[0.123456, 0.234567, 0.567567], [12.32112, 123123.2, 321321.2]], + index=["A", "B"], + columns=["X", "Y", "Z"], + ) + df.index.name = "index" + + to_compression = "infer" if to_infer else compression + read_compression = "infer" if read_infer else compression + + with tm.ensure_clean(filename) as path: + df.to_stata(path, compression=to_compression) + result = read_stata(path, compression=read_compression, index_col="index") + tm.assert_frame_equal(result, df) + + +def test_non_categorical_value_labels(): + data = DataFrame( + { + "fully_labelled": [1, 2, 3, 3, 1], + "partially_labelled": [1.0, 2.0, np.nan, 9.0, np.nan], + "Y": [7, 7, 9, 8, 10], + "Z": pd.Categorical(["j", "k", "l", "k", "j"]), + } + ) + + with tm.ensure_clean() as path: + value_labels = { + "fully_labelled": {1: "one", 2: "two", 3: "three"}, + "partially_labelled": {1.0: "one", 2.0: "two"}, + } + expected = {**value_labels, "Z": {0: "j", 1: "k", 2: "l"}} + + writer = StataWriter(path, data, value_labels=value_labels) + writer.write_file() + + with StataReader(path) as reader: + reader_value_labels = reader.value_labels() + assert reader_value_labels == expected + + msg = "Can't create value labels for notY, it wasn't found in the dataset." + with pytest.raises(KeyError, match=msg): + value_labels = {"notY": {7: "label1", 8: "label2"}} + StataWriter(path, data, value_labels=value_labels) + + msg = ( + "Can't create value labels for Z, value labels " + "can only be applied to numeric columns." + ) + with pytest.raises(ValueError, match=msg): + value_labels = {"Z": {1: "a", 2: "k", 3: "j", 4: "i"}} + StataWriter(path, data, value_labels=value_labels) + + +def test_non_categorical_value_label_name_conversion(): + # Check conversion of invalid variable names + data = DataFrame( + { + "invalid~!": [1, 1, 2, 3, 5, 8], # Only alphanumeric and _ + "6_invalid": [1, 1, 2, 3, 5, 8], # Must start with letter or _ + "invalid_name_longer_than_32_characters": [8, 8, 9, 9, 8, 8], # Too long + "aggregate": [2, 5, 5, 6, 6, 9], # Reserved words + (1, 2): [1, 2, 3, 4, 5, 6], # Hashable non-string + } + ) + + value_labels = { + "invalid~!": {1: "label1", 2: "label2"}, + "6_invalid": {1: "label1", 2: "label2"}, + "invalid_name_longer_than_32_characters": {8: "eight", 9: "nine"}, + "aggregate": {5: "five"}, + (1, 2): {3: "three"}, + } + + expected = { + "invalid__": {1: "label1", 2: "label2"}, + "_6_invalid": {1: "label1", 2: "label2"}, + "invalid_name_longer_than_32_char": {8: "eight", 9: "nine"}, + "_aggregate": {5: "five"}, + "_1__2_": {3: "three"}, + } + + with tm.ensure_clean() as path: + with tm.assert_produces_warning(InvalidColumnName): + data.to_stata(path, value_labels=value_labels) + + with StataReader(path) as reader: + reader_value_labels = reader.value_labels() + assert reader_value_labels == expected + + +def test_non_categorical_value_label_convert_categoricals_error(): + # Mapping more than one value to the same label is valid for Stata + # labels, but can't be read with convert_categoricals=True + value_labels = { + "repeated_labels": {10: "Ten", 20: "More than ten", 40: "More than ten"} + } + + data = DataFrame( + { + "repeated_labels": [10, 10, 20, 20, 40, 40], + } + ) + + with tm.ensure_clean() as path: + data.to_stata(path, value_labels=value_labels) + + with StataReader(path, convert_categoricals=False) as reader: + reader_value_labels = reader.value_labels() + assert reader_value_labels == value_labels + + col = "repeated_labels" + repeats = "-" * 80 + "\n" + "\n".join(["More than ten"]) + + msg = f""" +Value labels for column {col} are not unique. These cannot be converted to +pandas categoricals. + +Either read the file with `convert_categoricals` set to False or use the +low level interface in `StataReader` to separately read the values and the +value_labels. + +The repeated labels are: +{repeats} +""" + with pytest.raises(ValueError, match=msg): + read_stata(path, convert_categoricals=True) + + +@pytest.mark.parametrize("version", [114, 117, 118, 119, None]) +@pytest.mark.parametrize( + "dtype", + [ + pd.BooleanDtype, + pd.Int8Dtype, + pd.Int16Dtype, + pd.Int32Dtype, + pd.Int64Dtype, + pd.UInt8Dtype, + pd.UInt16Dtype, + pd.UInt32Dtype, + pd.UInt64Dtype, + ], +) +def test_nullable_support(dtype, version): + df = DataFrame( + { + "a": Series([1.0, 2.0, 3.0]), + "b": Series([1, pd.NA, pd.NA], dtype=dtype.name), + "c": Series(["a", "b", None]), + } + ) + dtype_name = df.b.dtype.numpy_dtype.name + # Only use supported names: no uint, bool or int64 + dtype_name = dtype_name.replace("u", "") + if dtype_name == "int64": + dtype_name = "int32" + elif dtype_name == "bool": + dtype_name = "int8" + value = StataMissingValue.BASE_MISSING_VALUES[dtype_name] + smv = StataMissingValue(value) + expected_b = Series([1, smv, smv], dtype=object, name="b") + expected_c = Series(["a", "b", ""], name="c") + with tm.ensure_clean() as path: + df.to_stata(path, write_index=False, version=version) + reread = read_stata(path, convert_missing=True) + tm.assert_series_equal(df.a, reread.a) + tm.assert_series_equal(reread.b, expected_b) + tm.assert_series_equal(reread.c, expected_c) + + +def test_empty_frame(): + # GH 46240 + # create an empty DataFrame with int64 and float64 dtypes + df = DataFrame(data={"a": range(3), "b": [1.0, 2.0, 3.0]}).head(0) + with tm.ensure_clean() as path: + df.to_stata(path, write_index=False, version=117) + # Read entire dataframe + df2 = read_stata(path) + assert "b" in df2 + # Dtypes don't match since no support for int32 + dtypes = Series({"a": np.dtype("int32"), "b": np.dtype("float64")}) + tm.assert_series_equal(df2.dtypes, dtypes) + # read one column of empty .dta file + df3 = read_stata(path, columns=["a"]) + assert "b" not in df3 + tm.assert_series_equal(df3.dtypes, dtypes.loc[["a"]]) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/libs/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/libs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/libs/test_hashtable.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/libs/test_hashtable.py new file mode 100644 index 0000000000000000000000000000000000000000..e54764f9ac4a69714f740124ad886eca10f33fc2 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/libs/test_hashtable.py @@ -0,0 +1,748 @@ +from collections.abc import Generator +from contextlib import contextmanager +import re +import struct +import tracemalloc + +import numpy as np +import pytest + +from pandas._libs import hashtable as ht + +import pandas as pd +import pandas._testing as tm +from pandas.core.algorithms import isin + + +@contextmanager +def activated_tracemalloc() -> Generator[None, None, None]: + tracemalloc.start() + try: + yield + finally: + tracemalloc.stop() + + +def get_allocated_khash_memory(): + snapshot = tracemalloc.take_snapshot() + snapshot = snapshot.filter_traces( + (tracemalloc.DomainFilter(True, ht.get_hashtable_trace_domain()),) + ) + return sum(x.size for x in snapshot.traces) + + +@pytest.mark.parametrize( + "table_type, dtype", + [ + (ht.PyObjectHashTable, np.object_), + (ht.Complex128HashTable, np.complex128), + (ht.Int64HashTable, np.int64), + (ht.UInt64HashTable, np.uint64), + (ht.Float64HashTable, np.float64), + (ht.Complex64HashTable, np.complex64), + (ht.Int32HashTable, np.int32), + (ht.UInt32HashTable, np.uint32), + (ht.Float32HashTable, np.float32), + (ht.Int16HashTable, np.int16), + (ht.UInt16HashTable, np.uint16), + (ht.Int8HashTable, np.int8), + (ht.UInt8HashTable, np.uint8), + (ht.IntpHashTable, np.intp), + ], +) +class TestHashTable: + def test_get_set_contains_len(self, table_type, dtype): + index = 5 + table = table_type(55) + assert len(table) == 0 + assert index not in table + + table.set_item(index, 42) + assert len(table) == 1 + assert index in table + assert table.get_item(index) == 42 + + table.set_item(index + 1, 41) + assert index in table + assert index + 1 in table + assert len(table) == 2 + assert table.get_item(index) == 42 + assert table.get_item(index + 1) == 41 + + table.set_item(index, 21) + assert index in table + assert index + 1 in table + assert len(table) == 2 + assert table.get_item(index) == 21 + assert table.get_item(index + 1) == 41 + assert index + 2 not in table + + table.set_item(index + 1, 21) + assert index in table + assert index + 1 in table + assert len(table) == 2 + assert table.get_item(index) == 21 + assert table.get_item(index + 1) == 21 + + with pytest.raises(KeyError, match=str(index + 2)): + table.get_item(index + 2) + + def test_get_set_contains_len_mask(self, table_type, dtype): + if table_type == ht.PyObjectHashTable: + pytest.skip("Mask not supported for object") + index = 5 + table = table_type(55, uses_mask=True) + assert len(table) == 0 + assert index not in table + + table.set_item(index, 42) + assert len(table) == 1 + assert index in table + assert table.get_item(index) == 42 + with pytest.raises(KeyError, match="NA"): + table.get_na() + + table.set_item(index + 1, 41) + table.set_na(41) + assert pd.NA in table + assert index in table + assert index + 1 in table + assert len(table) == 3 + assert table.get_item(index) == 42 + assert table.get_item(index + 1) == 41 + assert table.get_na() == 41 + + table.set_na(21) + assert index in table + assert index + 1 in table + assert len(table) == 3 + assert table.get_item(index + 1) == 41 + assert table.get_na() == 21 + assert index + 2 not in table + + with pytest.raises(KeyError, match=str(index + 2)): + table.get_item(index + 2) + + def test_map_keys_to_values(self, table_type, dtype, writable): + # only Int64HashTable has this method + if table_type == ht.Int64HashTable: + N = 77 + table = table_type() + keys = np.arange(N).astype(dtype) + vals = np.arange(N).astype(np.int64) + N + keys.flags.writeable = writable + vals.flags.writeable = writable + table.map_keys_to_values(keys, vals) + for i in range(N): + assert table.get_item(keys[i]) == i + N + + def test_map_locations(self, table_type, dtype, writable): + N = 8 + table = table_type() + keys = (np.arange(N) + N).astype(dtype) + keys.flags.writeable = writable + table.map_locations(keys) + for i in range(N): + assert table.get_item(keys[i]) == i + + def test_map_locations_mask(self, table_type, dtype, writable): + if table_type == ht.PyObjectHashTable: + pytest.skip("Mask not supported for object") + N = 3 + table = table_type(uses_mask=True) + keys = (np.arange(N) + N).astype(dtype) + keys.flags.writeable = writable + table.map_locations(keys, np.array([False, False, True])) + for i in range(N - 1): + assert table.get_item(keys[i]) == i + + with pytest.raises(KeyError, match=re.escape(str(keys[N - 1]))): + table.get_item(keys[N - 1]) + + assert table.get_na() == 2 + + def test_lookup(self, table_type, dtype, writable): + N = 3 + table = table_type() + keys = (np.arange(N) + N).astype(dtype) + keys.flags.writeable = writable + table.map_locations(keys) + result = table.lookup(keys) + expected = np.arange(N) + tm.assert_numpy_array_equal(result.astype(np.int64), expected.astype(np.int64)) + + def test_lookup_wrong(self, table_type, dtype): + if dtype in (np.int8, np.uint8): + N = 100 + else: + N = 512 + table = table_type() + keys = (np.arange(N) + N).astype(dtype) + table.map_locations(keys) + wrong_keys = np.arange(N).astype(dtype) + result = table.lookup(wrong_keys) + assert np.all(result == -1) + + def test_lookup_mask(self, table_type, dtype, writable): + if table_type == ht.PyObjectHashTable: + pytest.skip("Mask not supported for object") + N = 3 + table = table_type(uses_mask=True) + keys = (np.arange(N) + N).astype(dtype) + mask = np.array([False, True, False]) + keys.flags.writeable = writable + table.map_locations(keys, mask) + result = table.lookup(keys, mask) + expected = np.arange(N) + tm.assert_numpy_array_equal(result.astype(np.int64), expected.astype(np.int64)) + + result = table.lookup(np.array([1 + N]).astype(dtype), np.array([False])) + tm.assert_numpy_array_equal( + result.astype(np.int64), np.array([-1], dtype=np.int64) + ) + + def test_unique(self, table_type, dtype, writable): + if dtype in (np.int8, np.uint8): + N = 88 + else: + N = 1000 + table = table_type() + expected = (np.arange(N) + N).astype(dtype) + keys = np.repeat(expected, 5) + keys.flags.writeable = writable + unique = table.unique(keys) + tm.assert_numpy_array_equal(unique, expected) + + def test_tracemalloc_works(self, table_type, dtype): + if dtype in (np.int8, np.uint8): + N = 256 + else: + N = 30000 + keys = np.arange(N).astype(dtype) + with activated_tracemalloc(): + table = table_type() + table.map_locations(keys) + used = get_allocated_khash_memory() + my_size = table.sizeof() + assert used == my_size + del table + assert get_allocated_khash_memory() == 0 + + def test_tracemalloc_for_empty(self, table_type, dtype): + with activated_tracemalloc(): + table = table_type() + used = get_allocated_khash_memory() + my_size = table.sizeof() + assert used == my_size + del table + assert get_allocated_khash_memory() == 0 + + def test_get_state(self, table_type, dtype): + table = table_type(1000) + state = table.get_state() + assert state["size"] == 0 + assert state["n_occupied"] == 0 + assert "n_buckets" in state + assert "upper_bound" in state + + @pytest.mark.parametrize("N", range(1, 110)) + def test_no_reallocation(self, table_type, dtype, N): + keys = np.arange(N).astype(dtype) + preallocated_table = table_type(N) + n_buckets_start = preallocated_table.get_state()["n_buckets"] + preallocated_table.map_locations(keys) + n_buckets_end = preallocated_table.get_state()["n_buckets"] + # original number of buckets was enough: + assert n_buckets_start == n_buckets_end + # check with clean table (not too much preallocated) + clean_table = table_type() + clean_table.map_locations(keys) + assert n_buckets_start == clean_table.get_state()["n_buckets"] + + +class TestHashTableUnsorted: + # TODO: moved from test_algos; may be redundancies with other tests + def test_string_hashtable_set_item_signature(self): + # GH#30419 fix typing in StringHashTable.set_item to prevent segfault + tbl = ht.StringHashTable() + + tbl.set_item("key", 1) + assert tbl.get_item("key") == 1 + + with pytest.raises(TypeError, match="'key' has incorrect type"): + # key arg typed as string, not object + tbl.set_item(4, 6) + with pytest.raises(TypeError, match="'val' has incorrect type"): + tbl.get_item(4) + + def test_lookup_nan(self, writable): + # GH#21688 ensure we can deal with readonly memory views + xs = np.array([2.718, 3.14, np.nan, -7, 5, 2, 3]) + xs.setflags(write=writable) + m = ht.Float64HashTable() + m.map_locations(xs) + tm.assert_numpy_array_equal(m.lookup(xs), np.arange(len(xs), dtype=np.intp)) + + def test_add_signed_zeros(self): + # GH#21866 inconsistent hash-function for float64 + # default hash-function would lead to different hash-buckets + # for 0.0 and -0.0 if there are more than 2^30 hash-buckets + # but this would mean 16GB + N = 4 # 12 * 10**8 would trigger the error, if you have enough memory + m = ht.Float64HashTable(N) + m.set_item(0.0, 0) + m.set_item(-0.0, 0) + assert len(m) == 1 # 0.0 and -0.0 are equivalent + + def test_add_different_nans(self): + # GH#21866 inconsistent hash-function for float64 + # create different nans from bit-patterns: + NAN1 = struct.unpack("d", struct.pack("=Q", 0x7FF8000000000000))[0] + NAN2 = struct.unpack("d", struct.pack("=Q", 0x7FF8000000000001))[0] + assert NAN1 != NAN1 + assert NAN2 != NAN2 + # default hash function would lead to different hash-buckets + # for NAN1 and NAN2 even if there are only 4 buckets: + m = ht.Float64HashTable() + m.set_item(NAN1, 0) + m.set_item(NAN2, 0) + assert len(m) == 1 # NAN1 and NAN2 are equivalent + + def test_lookup_overflow(self, writable): + xs = np.array([1, 2, 2**63], dtype=np.uint64) + # GH 21688 ensure we can deal with readonly memory views + xs.setflags(write=writable) + m = ht.UInt64HashTable() + m.map_locations(xs) + tm.assert_numpy_array_equal(m.lookup(xs), np.arange(len(xs), dtype=np.intp)) + + @pytest.mark.parametrize("nvals", [0, 10]) # resizing to 0 is special case + @pytest.mark.parametrize( + "htable, uniques, dtype, safely_resizes", + [ + (ht.PyObjectHashTable, ht.ObjectVector, "object", False), + (ht.StringHashTable, ht.ObjectVector, "object", True), + (ht.Float64HashTable, ht.Float64Vector, "float64", False), + (ht.Int64HashTable, ht.Int64Vector, "int64", False), + (ht.Int32HashTable, ht.Int32Vector, "int32", False), + (ht.UInt64HashTable, ht.UInt64Vector, "uint64", False), + ], + ) + def test_vector_resize( + self, writable, htable, uniques, dtype, safely_resizes, nvals + ): + # Test for memory errors after internal vector + # reallocations (GH 7157) + # Changed from using np.random.default_rng(2).rand to range + # which could cause flaky CI failures when safely_resizes=False + vals = np.array(range(1000), dtype=dtype) + + # GH 21688 ensures we can deal with read-only memory views + vals.setflags(write=writable) + + # initialise instances; cannot initialise in parametrization, + # as otherwise external views would be held on the array (which is + # one of the things this test is checking) + htable = htable() + uniques = uniques() + + # get_labels may append to uniques + htable.get_labels(vals[:nvals], uniques, 0, -1) + # to_array() sets an external_view_exists flag on uniques. + tmp = uniques.to_array() + oldshape = tmp.shape + + # subsequent get_labels() calls can no longer append to it + # (except for StringHashTables + ObjectVector) + if safely_resizes: + htable.get_labels(vals, uniques, 0, -1) + else: + with pytest.raises(ValueError, match="external reference.*"): + htable.get_labels(vals, uniques, 0, -1) + + uniques.to_array() # should not raise here + assert tmp.shape == oldshape + + @pytest.mark.parametrize( + "hashtable", + [ + ht.PyObjectHashTable, + ht.StringHashTable, + ht.Float64HashTable, + ht.Int64HashTable, + ht.Int32HashTable, + ht.UInt64HashTable, + ], + ) + def test_hashtable_large_sizehint(self, hashtable): + # GH#22729 smoketest for not raising when passing a large size_hint + size_hint = np.iinfo(np.uint32).max + 1 + hashtable(size_hint=size_hint) + + +class TestPyObjectHashTableWithNans: + def test_nan_float(self): + nan1 = float("nan") + nan2 = float("nan") + assert nan1 is not nan2 + table = ht.PyObjectHashTable() + table.set_item(nan1, 42) + assert table.get_item(nan2) == 42 + + def test_nan_complex_both(self): + nan1 = complex(float("nan"), float("nan")) + nan2 = complex(float("nan"), float("nan")) + assert nan1 is not nan2 + table = ht.PyObjectHashTable() + table.set_item(nan1, 42) + assert table.get_item(nan2) == 42 + + def test_nan_complex_real(self): + nan1 = complex(float("nan"), 1) + nan2 = complex(float("nan"), 1) + other = complex(float("nan"), 2) + assert nan1 is not nan2 + table = ht.PyObjectHashTable() + table.set_item(nan1, 42) + assert table.get_item(nan2) == 42 + with pytest.raises(KeyError, match=None) as error: + table.get_item(other) + assert str(error.value) == str(other) + + def test_nan_complex_imag(self): + nan1 = complex(1, float("nan")) + nan2 = complex(1, float("nan")) + other = complex(2, float("nan")) + assert nan1 is not nan2 + table = ht.PyObjectHashTable() + table.set_item(nan1, 42) + assert table.get_item(nan2) == 42 + with pytest.raises(KeyError, match=None) as error: + table.get_item(other) + assert str(error.value) == str(other) + + def test_nan_in_tuple(self): + nan1 = (float("nan"),) + nan2 = (float("nan"),) + assert nan1[0] is not nan2[0] + table = ht.PyObjectHashTable() + table.set_item(nan1, 42) + assert table.get_item(nan2) == 42 + + def test_nan_in_nested_tuple(self): + nan1 = (1, (2, (float("nan"),))) + nan2 = (1, (2, (float("nan"),))) + other = (1, 2) + table = ht.PyObjectHashTable() + table.set_item(nan1, 42) + assert table.get_item(nan2) == 42 + with pytest.raises(KeyError, match=None) as error: + table.get_item(other) + assert str(error.value) == str(other) + + +def test_hash_equal_tuple_with_nans(): + a = (float("nan"), (float("nan"), float("nan"))) + b = (float("nan"), (float("nan"), float("nan"))) + assert ht.object_hash(a) == ht.object_hash(b) + assert ht.objects_are_equal(a, b) + + +def test_get_labels_groupby_for_Int64(writable): + table = ht.Int64HashTable() + vals = np.array([1, 2, -1, 2, 1, -1], dtype=np.int64) + vals.flags.writeable = writable + arr, unique = table.get_labels_groupby(vals) + expected_arr = np.array([0, 1, -1, 1, 0, -1], dtype=np.intp) + expected_unique = np.array([1, 2], dtype=np.int64) + tm.assert_numpy_array_equal(arr, expected_arr) + tm.assert_numpy_array_equal(unique, expected_unique) + + +def test_tracemalloc_works_for_StringHashTable(): + N = 1000 + keys = np.arange(N).astype(np.str_).astype(np.object_) + with activated_tracemalloc(): + table = ht.StringHashTable() + table.map_locations(keys) + used = get_allocated_khash_memory() + my_size = table.sizeof() + assert used == my_size + del table + assert get_allocated_khash_memory() == 0 + + +def test_tracemalloc_for_empty_StringHashTable(): + with activated_tracemalloc(): + table = ht.StringHashTable() + used = get_allocated_khash_memory() + my_size = table.sizeof() + assert used == my_size + del table + assert get_allocated_khash_memory() == 0 + + +@pytest.mark.parametrize("N", range(1, 110)) +def test_no_reallocation_StringHashTable(N): + keys = np.arange(N).astype(np.str_).astype(np.object_) + preallocated_table = ht.StringHashTable(N) + n_buckets_start = preallocated_table.get_state()["n_buckets"] + preallocated_table.map_locations(keys) + n_buckets_end = preallocated_table.get_state()["n_buckets"] + # original number of buckets was enough: + assert n_buckets_start == n_buckets_end + # check with clean table (not too much preallocated) + clean_table = ht.StringHashTable() + clean_table.map_locations(keys) + assert n_buckets_start == clean_table.get_state()["n_buckets"] + + +@pytest.mark.parametrize( + "table_type, dtype", + [ + (ht.Float64HashTable, np.float64), + (ht.Float32HashTable, np.float32), + (ht.Complex128HashTable, np.complex128), + (ht.Complex64HashTable, np.complex64), + ], +) +class TestHashTableWithNans: + def test_get_set_contains_len(self, table_type, dtype): + index = float("nan") + table = table_type() + assert index not in table + + table.set_item(index, 42) + assert len(table) == 1 + assert index in table + assert table.get_item(index) == 42 + + table.set_item(index, 41) + assert len(table) == 1 + assert index in table + assert table.get_item(index) == 41 + + def test_map_locations(self, table_type, dtype): + N = 10 + table = table_type() + keys = np.full(N, np.nan, dtype=dtype) + table.map_locations(keys) + assert len(table) == 1 + assert table.get_item(np.nan) == N - 1 + + def test_unique(self, table_type, dtype): + N = 1020 + table = table_type() + keys = np.full(N, np.nan, dtype=dtype) + unique = table.unique(keys) + assert np.all(np.isnan(unique)) and len(unique) == 1 + + +def test_unique_for_nan_objects_floats(): + table = ht.PyObjectHashTable() + keys = np.array([float("nan") for i in range(50)], dtype=np.object_) + unique = table.unique(keys) + assert len(unique) == 1 + + +def test_unique_for_nan_objects_complex(): + table = ht.PyObjectHashTable() + keys = np.array([complex(float("nan"), 1.0) for i in range(50)], dtype=np.object_) + unique = table.unique(keys) + assert len(unique) == 1 + + +def test_unique_for_nan_objects_tuple(): + table = ht.PyObjectHashTable() + keys = np.array( + [1] + [(1.0, (float("nan"), 1.0)) for i in range(50)], dtype=np.object_ + ) + unique = table.unique(keys) + assert len(unique) == 2 + + +@pytest.mark.parametrize( + "dtype", + [ + np.object_, + np.complex128, + np.int64, + np.uint64, + np.float64, + np.complex64, + np.int32, + np.uint32, + np.float32, + np.int16, + np.uint16, + np.int8, + np.uint8, + np.intp, + ], +) +class TestHelpFunctions: + def test_value_count(self, dtype, writable): + N = 43 + expected = (np.arange(N) + N).astype(dtype) + values = np.repeat(expected, 5) + values.flags.writeable = writable + keys, counts, _ = ht.value_count(values, False) + tm.assert_numpy_array_equal(np.sort(keys), expected) + assert np.all(counts == 5) + + def test_value_count_mask(self, dtype): + if dtype == np.object_: + pytest.skip("mask not implemented for object dtype") + values = np.array([1] * 5, dtype=dtype) + mask = np.zeros((5,), dtype=np.bool_) + mask[1] = True + mask[4] = True + keys, counts, na_counter = ht.value_count(values, False, mask=mask) + assert len(keys) == 2 + assert na_counter == 2 + + def test_value_count_stable(self, dtype, writable): + # GH12679 + values = np.array([2, 1, 5, 22, 3, -1, 8]).astype(dtype) + values.flags.writeable = writable + keys, counts, _ = ht.value_count(values, False) + tm.assert_numpy_array_equal(keys, values) + assert np.all(counts == 1) + + def test_duplicated_first(self, dtype, writable): + N = 100 + values = np.repeat(np.arange(N).astype(dtype), 5) + values.flags.writeable = writable + result = ht.duplicated(values) + expected = np.ones_like(values, dtype=np.bool_) + expected[::5] = False + tm.assert_numpy_array_equal(result, expected) + + def test_ismember_yes(self, dtype, writable): + N = 127 + arr = np.arange(N).astype(dtype) + values = np.arange(N).astype(dtype) + arr.flags.writeable = writable + values.flags.writeable = writable + result = ht.ismember(arr, values) + expected = np.ones_like(values, dtype=np.bool_) + tm.assert_numpy_array_equal(result, expected) + + def test_ismember_no(self, dtype): + N = 17 + arr = np.arange(N).astype(dtype) + values = (np.arange(N) + N).astype(dtype) + result = ht.ismember(arr, values) + expected = np.zeros_like(values, dtype=np.bool_) + tm.assert_numpy_array_equal(result, expected) + + def test_mode(self, dtype, writable): + if dtype in (np.int8, np.uint8): + N = 53 + else: + N = 11111 + values = np.repeat(np.arange(N).astype(dtype), 5) + values[0] = 42 + values.flags.writeable = writable + result = ht.mode(values, False)[0] + assert result == 42 + + def test_mode_stable(self, dtype, writable): + values = np.array([2, 1, 5, 22, 3, -1, 8]).astype(dtype) + values.flags.writeable = writable + keys = ht.mode(values, False)[0] + tm.assert_numpy_array_equal(keys, values) + + +def test_modes_with_nans(): + # GH42688, nans aren't mangled + nulls = [pd.NA, np.nan, pd.NaT, None] + values = np.array([True] + nulls * 2, dtype=np.object_) + modes = ht.mode(values, False)[0] + assert modes.size == len(nulls) + + +def test_unique_label_indices_intp(writable): + keys = np.array([1, 2, 2, 2, 1, 3], dtype=np.intp) + keys.flags.writeable = writable + result = ht.unique_label_indices(keys) + expected = np.array([0, 1, 5], dtype=np.intp) + tm.assert_numpy_array_equal(result, expected) + + +def test_unique_label_indices(): + a = np.random.default_rng(2).integers(1, 1 << 10, 1 << 15).astype(np.intp) + + left = ht.unique_label_indices(a) + right = np.unique(a, return_index=True)[1] + + tm.assert_numpy_array_equal(left, right, check_dtype=False) + + a[np.random.default_rng(2).choice(len(a), 10)] = -1 + left = ht.unique_label_indices(a) + right = np.unique(a, return_index=True)[1][1:] + tm.assert_numpy_array_equal(left, right, check_dtype=False) + + +@pytest.mark.parametrize( + "dtype", + [ + np.float64, + np.float32, + np.complex128, + np.complex64, + ], +) +class TestHelpFunctionsWithNans: + def test_value_count(self, dtype): + values = np.array([np.nan, np.nan, np.nan], dtype=dtype) + keys, counts, _ = ht.value_count(values, True) + assert len(keys) == 0 + keys, counts, _ = ht.value_count(values, False) + assert len(keys) == 1 and np.all(np.isnan(keys)) + assert counts[0] == 3 + + def test_duplicated_first(self, dtype): + values = np.array([np.nan, np.nan, np.nan], dtype=dtype) + result = ht.duplicated(values) + expected = np.array([False, True, True]) + tm.assert_numpy_array_equal(result, expected) + + def test_ismember_yes(self, dtype): + arr = np.array([np.nan, np.nan, np.nan], dtype=dtype) + values = np.array([np.nan, np.nan], dtype=dtype) + result = ht.ismember(arr, values) + expected = np.array([True, True, True], dtype=np.bool_) + tm.assert_numpy_array_equal(result, expected) + + def test_ismember_no(self, dtype): + arr = np.array([np.nan, np.nan, np.nan], dtype=dtype) + values = np.array([1], dtype=dtype) + result = ht.ismember(arr, values) + expected = np.array([False, False, False], dtype=np.bool_) + tm.assert_numpy_array_equal(result, expected) + + def test_mode(self, dtype): + values = np.array([42, np.nan, np.nan, np.nan], dtype=dtype) + assert ht.mode(values, True)[0] == 42 + assert np.isnan(ht.mode(values, False)[0]) + + +def test_ismember_tuple_with_nans(): + # GH-41836 + values = [("a", float("nan")), ("b", 1)] + comps = [("a", float("nan"))] + + msg = "isin with argument that is not not a Series" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = isin(values, comps) + expected = np.array([True, False], dtype=np.bool_) + tm.assert_numpy_array_equal(result, expected) + + +def test_float_complex_int_are_equal_as_objects(): + values = ["a", 5, 5.0, 5.0 + 0j] + comps = list(range(129)) + result = isin(np.array(values, dtype=object), np.asarray(comps)) + expected = np.array([False, True, True, True], dtype=np.bool_) + tm.assert_numpy_array_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/libs/test_join.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/libs/test_join.py new file mode 100644 index 0000000000000000000000000000000000000000..ba2e6e713092916648d375a991e3cb4d9fc7828d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/libs/test_join.py @@ -0,0 +1,390 @@ +import numpy as np +import pytest + +from pandas._libs import join as libjoin +from pandas._libs.join import ( + inner_join, + left_outer_join, +) + +import pandas._testing as tm + + +class TestIndexer: + @pytest.mark.parametrize( + "dtype", ["int32", "int64", "float32", "float64", "object"] + ) + def test_outer_join_indexer(self, dtype): + indexer = libjoin.outer_join_indexer + + left = np.arange(3, dtype=dtype) + right = np.arange(2, 5, dtype=dtype) + empty = np.array([], dtype=dtype) + + result, lindexer, rindexer = indexer(left, right) + assert isinstance(result, np.ndarray) + assert isinstance(lindexer, np.ndarray) + assert isinstance(rindexer, np.ndarray) + tm.assert_numpy_array_equal(result, np.arange(5, dtype=dtype)) + exp = np.array([0, 1, 2, -1, -1], dtype=np.intp) + tm.assert_numpy_array_equal(lindexer, exp) + exp = np.array([-1, -1, 0, 1, 2], dtype=np.intp) + tm.assert_numpy_array_equal(rindexer, exp) + + result, lindexer, rindexer = indexer(empty, right) + tm.assert_numpy_array_equal(result, right) + exp = np.array([-1, -1, -1], dtype=np.intp) + tm.assert_numpy_array_equal(lindexer, exp) + exp = np.array([0, 1, 2], dtype=np.intp) + tm.assert_numpy_array_equal(rindexer, exp) + + result, lindexer, rindexer = indexer(left, empty) + tm.assert_numpy_array_equal(result, left) + exp = np.array([0, 1, 2], dtype=np.intp) + tm.assert_numpy_array_equal(lindexer, exp) + exp = np.array([-1, -1, -1], dtype=np.intp) + tm.assert_numpy_array_equal(rindexer, exp) + + def test_cython_left_outer_join(self): + left = np.array([0, 1, 2, 1, 2, 0, 0, 1, 2, 3, 3], dtype=np.intp) + right = np.array([1, 1, 0, 4, 2, 2, 1], dtype=np.intp) + max_group = 5 + + ls, rs = left_outer_join(left, right, max_group) + + exp_ls = left.argsort(kind="mergesort") + exp_rs = right.argsort(kind="mergesort") + + exp_li = np.array([0, 1, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 7, 7, 8, 8, 9, 10]) + exp_ri = np.array( + [0, 0, 0, 1, 2, 3, 1, 2, 3, 1, 2, 3, 4, 5, 4, 5, 4, 5, -1, -1] + ) + + exp_ls = exp_ls.take(exp_li) + exp_ls[exp_li == -1] = -1 + + exp_rs = exp_rs.take(exp_ri) + exp_rs[exp_ri == -1] = -1 + + tm.assert_numpy_array_equal(ls, exp_ls, check_dtype=False) + tm.assert_numpy_array_equal(rs, exp_rs, check_dtype=False) + + def test_cython_right_outer_join(self): + left = np.array([0, 1, 2, 1, 2, 0, 0, 1, 2, 3, 3], dtype=np.intp) + right = np.array([1, 1, 0, 4, 2, 2, 1], dtype=np.intp) + max_group = 5 + + rs, ls = left_outer_join(right, left, max_group) + + exp_ls = left.argsort(kind="mergesort") + exp_rs = right.argsort(kind="mergesort") + + # 0 1 1 1 + exp_li = np.array( + [ + 0, + 1, + 2, + 3, + 4, + 5, + 3, + 4, + 5, + 3, + 4, + 5, + # 2 2 4 + 6, + 7, + 8, + 6, + 7, + 8, + -1, + ] + ) + exp_ri = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6]) + + exp_ls = exp_ls.take(exp_li) + exp_ls[exp_li == -1] = -1 + + exp_rs = exp_rs.take(exp_ri) + exp_rs[exp_ri == -1] = -1 + + tm.assert_numpy_array_equal(ls, exp_ls) + tm.assert_numpy_array_equal(rs, exp_rs) + + def test_cython_inner_join(self): + left = np.array([0, 1, 2, 1, 2, 0, 0, 1, 2, 3, 3], dtype=np.intp) + right = np.array([1, 1, 0, 4, 2, 2, 1, 4], dtype=np.intp) + max_group = 5 + + ls, rs = inner_join(left, right, max_group) + + exp_ls = left.argsort(kind="mergesort") + exp_rs = right.argsort(kind="mergesort") + + exp_li = np.array([0, 1, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 7, 7, 8, 8]) + exp_ri = np.array([0, 0, 0, 1, 2, 3, 1, 2, 3, 1, 2, 3, 4, 5, 4, 5, 4, 5]) + + exp_ls = exp_ls.take(exp_li) + exp_ls[exp_li == -1] = -1 + + exp_rs = exp_rs.take(exp_ri) + exp_rs[exp_ri == -1] = -1 + + tm.assert_numpy_array_equal(ls, exp_ls) + tm.assert_numpy_array_equal(rs, exp_rs) + + +@pytest.mark.parametrize("readonly", [True, False]) +def test_left_join_indexer_unique(readonly): + a = np.array([1, 2, 3, 4, 5], dtype=np.int64) + b = np.array([2, 2, 3, 4, 4], dtype=np.int64) + if readonly: + # GH#37312, GH#37264 + a.setflags(write=False) + b.setflags(write=False) + + result = libjoin.left_join_indexer_unique(b, a) + expected = np.array([1, 1, 2, 3, 3], dtype=np.intp) + tm.assert_numpy_array_equal(result, expected) + + +def test_left_outer_join_bug(): + left = np.array( + [ + 0, + 1, + 0, + 1, + 1, + 2, + 3, + 1, + 0, + 2, + 1, + 2, + 0, + 1, + 1, + 2, + 3, + 2, + 3, + 2, + 1, + 1, + 3, + 0, + 3, + 2, + 3, + 0, + 0, + 2, + 3, + 2, + 0, + 3, + 1, + 3, + 0, + 1, + 3, + 0, + 0, + 1, + 0, + 3, + 1, + 0, + 1, + 0, + 1, + 1, + 0, + 2, + 2, + 2, + 2, + 2, + 0, + 3, + 1, + 2, + 0, + 0, + 3, + 1, + 3, + 2, + 2, + 0, + 1, + 3, + 0, + 2, + 3, + 2, + 3, + 3, + 2, + 3, + 3, + 1, + 3, + 2, + 0, + 0, + 3, + 1, + 1, + 1, + 0, + 2, + 3, + 3, + 1, + 2, + 0, + 3, + 1, + 2, + 0, + 2, + ], + dtype=np.intp, + ) + + right = np.array([3, 1], dtype=np.intp) + max_groups = 4 + + lidx, ridx = libjoin.left_outer_join(left, right, max_groups, sort=False) + + exp_lidx = np.arange(len(left), dtype=np.intp) + exp_ridx = -np.ones(len(left), dtype=np.intp) + + exp_ridx[left == 1] = 1 + exp_ridx[left == 3] = 0 + + tm.assert_numpy_array_equal(lidx, exp_lidx) + tm.assert_numpy_array_equal(ridx, exp_ridx) + + +def test_inner_join_indexer(): + a = np.array([1, 2, 3, 4, 5], dtype=np.int64) + b = np.array([0, 3, 5, 7, 9], dtype=np.int64) + + index, ares, bres = libjoin.inner_join_indexer(a, b) + + index_exp = np.array([3, 5], dtype=np.int64) + tm.assert_almost_equal(index, index_exp) + + aexp = np.array([2, 4], dtype=np.intp) + bexp = np.array([1, 2], dtype=np.intp) + tm.assert_almost_equal(ares, aexp) + tm.assert_almost_equal(bres, bexp) + + a = np.array([5], dtype=np.int64) + b = np.array([5], dtype=np.int64) + + index, ares, bres = libjoin.inner_join_indexer(a, b) + tm.assert_numpy_array_equal(index, np.array([5], dtype=np.int64)) + tm.assert_numpy_array_equal(ares, np.array([0], dtype=np.intp)) + tm.assert_numpy_array_equal(bres, np.array([0], dtype=np.intp)) + + +def test_outer_join_indexer(): + a = np.array([1, 2, 3, 4, 5], dtype=np.int64) + b = np.array([0, 3, 5, 7, 9], dtype=np.int64) + + index, ares, bres = libjoin.outer_join_indexer(a, b) + + index_exp = np.array([0, 1, 2, 3, 4, 5, 7, 9], dtype=np.int64) + tm.assert_almost_equal(index, index_exp) + + aexp = np.array([-1, 0, 1, 2, 3, 4, -1, -1], dtype=np.intp) + bexp = np.array([0, -1, -1, 1, -1, 2, 3, 4], dtype=np.intp) + tm.assert_almost_equal(ares, aexp) + tm.assert_almost_equal(bres, bexp) + + a = np.array([5], dtype=np.int64) + b = np.array([5], dtype=np.int64) + + index, ares, bres = libjoin.outer_join_indexer(a, b) + tm.assert_numpy_array_equal(index, np.array([5], dtype=np.int64)) + tm.assert_numpy_array_equal(ares, np.array([0], dtype=np.intp)) + tm.assert_numpy_array_equal(bres, np.array([0], dtype=np.intp)) + + +def test_left_join_indexer(): + a = np.array([1, 2, 3, 4, 5], dtype=np.int64) + b = np.array([0, 3, 5, 7, 9], dtype=np.int64) + + index, ares, bres = libjoin.left_join_indexer(a, b) + + tm.assert_almost_equal(index, a) + + aexp = np.array([0, 1, 2, 3, 4], dtype=np.intp) + bexp = np.array([-1, -1, 1, -1, 2], dtype=np.intp) + tm.assert_almost_equal(ares, aexp) + tm.assert_almost_equal(bres, bexp) + + a = np.array([5], dtype=np.int64) + b = np.array([5], dtype=np.int64) + + index, ares, bres = libjoin.left_join_indexer(a, b) + tm.assert_numpy_array_equal(index, np.array([5], dtype=np.int64)) + tm.assert_numpy_array_equal(ares, np.array([0], dtype=np.intp)) + tm.assert_numpy_array_equal(bres, np.array([0], dtype=np.intp)) + + +def test_left_join_indexer2(): + idx = np.array([1, 1, 2, 5], dtype=np.int64) + idx2 = np.array([1, 2, 5, 7, 9], dtype=np.int64) + + res, lidx, ridx = libjoin.left_join_indexer(idx2, idx) + + exp_res = np.array([1, 1, 2, 5, 7, 9], dtype=np.int64) + tm.assert_almost_equal(res, exp_res) + + exp_lidx = np.array([0, 0, 1, 2, 3, 4], dtype=np.intp) + tm.assert_almost_equal(lidx, exp_lidx) + + exp_ridx = np.array([0, 1, 2, 3, -1, -1], dtype=np.intp) + tm.assert_almost_equal(ridx, exp_ridx) + + +def test_outer_join_indexer2(): + idx = np.array([1, 1, 2, 5], dtype=np.int64) + idx2 = np.array([1, 2, 5, 7, 9], dtype=np.int64) + + res, lidx, ridx = libjoin.outer_join_indexer(idx2, idx) + + exp_res = np.array([1, 1, 2, 5, 7, 9], dtype=np.int64) + tm.assert_almost_equal(res, exp_res) + + exp_lidx = np.array([0, 0, 1, 2, 3, 4], dtype=np.intp) + tm.assert_almost_equal(lidx, exp_lidx) + + exp_ridx = np.array([0, 1, 2, 3, -1, -1], dtype=np.intp) + tm.assert_almost_equal(ridx, exp_ridx) + + +def test_inner_join_indexer2(): + idx = np.array([1, 1, 2, 5], dtype=np.int64) + idx2 = np.array([1, 2, 5, 7, 9], dtype=np.int64) + + res, lidx, ridx = libjoin.inner_join_indexer(idx2, idx) + + exp_res = np.array([1, 1, 2, 5], dtype=np.int64) + tm.assert_almost_equal(res, exp_res) + + exp_lidx = np.array([0, 0, 1, 2], dtype=np.intp) + tm.assert_almost_equal(lidx, exp_lidx) + + exp_ridx = np.array([0, 1, 2, 3], dtype=np.intp) + tm.assert_almost_equal(ridx, exp_ridx) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/libs/test_lib.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/libs/test_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..8583d8bcc052c4d76e090227272facca2faafa1f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/libs/test_lib.py @@ -0,0 +1,285 @@ +import numpy as np +import pytest + +from pandas._libs import ( + Timedelta, + lib, + writers as libwriters, +) +from pandas.compat import IS64 + +from pandas import Index +import pandas._testing as tm + + +class TestMisc: + def test_max_len_string_array(self): + arr = a = np.array(["foo", "b", np.nan], dtype="object") + assert libwriters.max_len_string_array(arr) == 3 + + # unicode + arr = a.astype("U").astype(object) + assert libwriters.max_len_string_array(arr) == 3 + + # bytes for python3 + arr = a.astype("S").astype(object) + assert libwriters.max_len_string_array(arr) == 3 + + # raises + msg = "No matching signature found" + with pytest.raises(TypeError, match=msg): + libwriters.max_len_string_array(arr.astype("U")) + + def test_fast_unique_multiple_list_gen_sort(self): + keys = [["p", "a"], ["n", "d"], ["a", "s"]] + + gen = (key for key in keys) + expected = np.array(["a", "d", "n", "p", "s"]) + out = lib.fast_unique_multiple_list_gen(gen, sort=True) + tm.assert_numpy_array_equal(np.array(out), expected) + + gen = (key for key in keys) + expected = np.array(["p", "a", "n", "d", "s"]) + out = lib.fast_unique_multiple_list_gen(gen, sort=False) + tm.assert_numpy_array_equal(np.array(out), expected) + + def test_fast_multiget_timedelta_resos(self): + # This will become relevant for test_constructor_dict_timedelta64_index + # once Timedelta constructor preserves reso when passed a + # np.timedelta64 object + td = Timedelta(days=1) + + mapping1 = {td: 1} + mapping2 = {td.as_unit("s"): 1} + + oindex = Index([td * n for n in range(3)])._values.astype(object) + + expected = lib.fast_multiget(mapping1, oindex) + result = lib.fast_multiget(mapping2, oindex) + tm.assert_numpy_array_equal(result, expected) + + # case that can't be cast to td64ns + td = Timedelta(np.timedelta64(146000, "D")) + assert hash(td) == hash(td.as_unit("ms")) + assert hash(td) == hash(td.as_unit("us")) + mapping1 = {td: 1} + mapping2 = {td.as_unit("ms"): 1} + + oindex = Index([td * n for n in range(3)])._values.astype(object) + + expected = lib.fast_multiget(mapping1, oindex) + result = lib.fast_multiget(mapping2, oindex) + tm.assert_numpy_array_equal(result, expected) + + +class TestIndexing: + def test_maybe_indices_to_slice_left_edge(self): + target = np.arange(100) + + # slice + indices = np.array([], dtype=np.intp) + maybe_slice = lib.maybe_indices_to_slice(indices, len(target)) + + assert isinstance(maybe_slice, slice) + tm.assert_numpy_array_equal(target[indices], target[maybe_slice]) + + @pytest.mark.parametrize("end", [1, 2, 5, 20, 99]) + @pytest.mark.parametrize("step", [1, 2, 4]) + def test_maybe_indices_to_slice_left_edge_not_slice_end_steps(self, end, step): + target = np.arange(100) + indices = np.arange(0, end, step, dtype=np.intp) + maybe_slice = lib.maybe_indices_to_slice(indices, len(target)) + + assert isinstance(maybe_slice, slice) + tm.assert_numpy_array_equal(target[indices], target[maybe_slice]) + + # reverse + indices = indices[::-1] + maybe_slice = lib.maybe_indices_to_slice(indices, len(target)) + + assert isinstance(maybe_slice, slice) + tm.assert_numpy_array_equal(target[indices], target[maybe_slice]) + + @pytest.mark.parametrize( + "case", [[2, 1, 2, 0], [2, 2, 1, 0], [0, 1, 2, 1], [-2, 0, 2], [2, 0, -2]] + ) + def test_maybe_indices_to_slice_left_edge_not_slice(self, case): + # not slice + target = np.arange(100) + indices = np.array(case, dtype=np.intp) + maybe_slice = lib.maybe_indices_to_slice(indices, len(target)) + + assert not isinstance(maybe_slice, slice) + tm.assert_numpy_array_equal(maybe_slice, indices) + tm.assert_numpy_array_equal(target[indices], target[maybe_slice]) + + @pytest.mark.parametrize("start", [0, 2, 5, 20, 97, 98]) + @pytest.mark.parametrize("step", [1, 2, 4]) + def test_maybe_indices_to_slice_right_edge(self, start, step): + target = np.arange(100) + + # slice + indices = np.arange(start, 99, step, dtype=np.intp) + maybe_slice = lib.maybe_indices_to_slice(indices, len(target)) + + assert isinstance(maybe_slice, slice) + tm.assert_numpy_array_equal(target[indices], target[maybe_slice]) + + # reverse + indices = indices[::-1] + maybe_slice = lib.maybe_indices_to_slice(indices, len(target)) + + assert isinstance(maybe_slice, slice) + tm.assert_numpy_array_equal(target[indices], target[maybe_slice]) + + def test_maybe_indices_to_slice_right_edge_not_slice(self): + # not slice + target = np.arange(100) + indices = np.array([97, 98, 99, 100], dtype=np.intp) + maybe_slice = lib.maybe_indices_to_slice(indices, len(target)) + + assert not isinstance(maybe_slice, slice) + tm.assert_numpy_array_equal(maybe_slice, indices) + + msg = "index 100 is out of bounds for axis (0|1) with size 100" + + with pytest.raises(IndexError, match=msg): + target[indices] + with pytest.raises(IndexError, match=msg): + target[maybe_slice] + + indices = np.array([100, 99, 98, 97], dtype=np.intp) + maybe_slice = lib.maybe_indices_to_slice(indices, len(target)) + + assert not isinstance(maybe_slice, slice) + tm.assert_numpy_array_equal(maybe_slice, indices) + + with pytest.raises(IndexError, match=msg): + target[indices] + with pytest.raises(IndexError, match=msg): + target[maybe_slice] + + @pytest.mark.parametrize( + "case", [[99, 97, 99, 96], [99, 99, 98, 97], [98, 98, 97, 96]] + ) + def test_maybe_indices_to_slice_right_edge_cases(self, case): + target = np.arange(100) + indices = np.array(case, dtype=np.intp) + maybe_slice = lib.maybe_indices_to_slice(indices, len(target)) + + assert not isinstance(maybe_slice, slice) + tm.assert_numpy_array_equal(maybe_slice, indices) + tm.assert_numpy_array_equal(target[indices], target[maybe_slice]) + + @pytest.mark.parametrize("step", [1, 2, 4, 5, 8, 9]) + def test_maybe_indices_to_slice_both_edges(self, step): + target = np.arange(10) + + # slice + indices = np.arange(0, 9, step, dtype=np.intp) + maybe_slice = lib.maybe_indices_to_slice(indices, len(target)) + assert isinstance(maybe_slice, slice) + tm.assert_numpy_array_equal(target[indices], target[maybe_slice]) + + # reverse + indices = indices[::-1] + maybe_slice = lib.maybe_indices_to_slice(indices, len(target)) + assert isinstance(maybe_slice, slice) + tm.assert_numpy_array_equal(target[indices], target[maybe_slice]) + + @pytest.mark.parametrize("case", [[4, 2, 0, -2], [2, 2, 1, 0], [0, 1, 2, 1]]) + def test_maybe_indices_to_slice_both_edges_not_slice(self, case): + # not slice + target = np.arange(10) + indices = np.array(case, dtype=np.intp) + maybe_slice = lib.maybe_indices_to_slice(indices, len(target)) + assert not isinstance(maybe_slice, slice) + tm.assert_numpy_array_equal(maybe_slice, indices) + tm.assert_numpy_array_equal(target[indices], target[maybe_slice]) + + @pytest.mark.parametrize("start, end", [(2, 10), (5, 25), (65, 97)]) + @pytest.mark.parametrize("step", [1, 2, 4, 20]) + def test_maybe_indices_to_slice_middle(self, start, end, step): + target = np.arange(100) + + # slice + indices = np.arange(start, end, step, dtype=np.intp) + maybe_slice = lib.maybe_indices_to_slice(indices, len(target)) + + assert isinstance(maybe_slice, slice) + tm.assert_numpy_array_equal(target[indices], target[maybe_slice]) + + # reverse + indices = indices[::-1] + maybe_slice = lib.maybe_indices_to_slice(indices, len(target)) + + assert isinstance(maybe_slice, slice) + tm.assert_numpy_array_equal(target[indices], target[maybe_slice]) + + @pytest.mark.parametrize( + "case", [[14, 12, 10, 12], [12, 12, 11, 10], [10, 11, 12, 11]] + ) + def test_maybe_indices_to_slice_middle_not_slice(self, case): + # not slice + target = np.arange(100) + indices = np.array(case, dtype=np.intp) + maybe_slice = lib.maybe_indices_to_slice(indices, len(target)) + + assert not isinstance(maybe_slice, slice) + tm.assert_numpy_array_equal(maybe_slice, indices) + tm.assert_numpy_array_equal(target[indices], target[maybe_slice]) + + def test_maybe_booleans_to_slice(self): + arr = np.array([0, 0, 1, 1, 1, 0, 1], dtype=np.uint8) + result = lib.maybe_booleans_to_slice(arr) + assert result.dtype == np.bool_ + + result = lib.maybe_booleans_to_slice(arr[:0]) + assert result == slice(0, 0) + + def test_get_reverse_indexer(self): + indexer = np.array([-1, -1, 1, 2, 0, -1, 3, 4], dtype=np.intp) + result = lib.get_reverse_indexer(indexer, 5) + expected = np.array([4, 2, 3, 6, 7], dtype=np.intp) + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize("dtype", ["int64", "int32"]) + def test_is_range_indexer(self, dtype): + # GH#50592 + left = np.arange(0, 100, dtype=dtype) + assert lib.is_range_indexer(left, 100) + + @pytest.mark.skipif( + not IS64, + reason="2**31 is too big for Py_ssize_t on 32-bit. " + "It doesn't matter though since you cannot create an array that long on 32-bit", + ) + @pytest.mark.parametrize("dtype", ["int64", "int32"]) + def test_is_range_indexer_big_n(self, dtype): + # GH53616 + left = np.arange(0, 100, dtype=dtype) + + assert not lib.is_range_indexer(left, 2**31) + + @pytest.mark.parametrize("dtype", ["int64", "int32"]) + def test_is_range_indexer_not_equal(self, dtype): + # GH#50592 + left = np.array([1, 2], dtype=dtype) + assert not lib.is_range_indexer(left, 2) + + @pytest.mark.parametrize("dtype", ["int64", "int32"]) + def test_is_range_indexer_not_equal_shape(self, dtype): + # GH#50592 + left = np.array([0, 1, 2], dtype=dtype) + assert not lib.is_range_indexer(left, 2) + + +def test_cache_readonly_preserve_docstrings(): + # GH18197 + assert Index.hasnans.__doc__ is not None + + +def test_no_default_pickle(): + # GH#40397 + obj = tm.round_trip_pickle(lib.no_default) + assert obj is lib.no_default diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/libs/test_libalgos.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/libs/test_libalgos.py new file mode 100644 index 0000000000000000000000000000000000000000..42d09c72aab2baa9636093d172d864cbe0e41b12 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/libs/test_libalgos.py @@ -0,0 +1,162 @@ +from datetime import datetime +from itertools import permutations + +import numpy as np + +from pandas._libs import algos as libalgos + +import pandas._testing as tm + + +def test_ensure_platform_int(): + arr = np.arange(100, dtype=np.intp) + + result = libalgos.ensure_platform_int(arr) + assert result is arr + + +def test_is_lexsorted(): + failure = [ + np.array( + ([3] * 32) + ([2] * 32) + ([1] * 32) + ([0] * 32), + dtype="int64", + ), + np.array( + list(range(31))[::-1] * 4, + dtype="int64", + ), + ] + + assert not libalgos.is_lexsorted(failure) + + +def test_groupsort_indexer(): + a = np.random.default_rng(2).integers(0, 1000, 100).astype(np.intp) + b = np.random.default_rng(2).integers(0, 1000, 100).astype(np.intp) + + result = libalgos.groupsort_indexer(a, 1000)[0] + + # need to use a stable sort + # np.argsort returns int, groupsort_indexer + # always returns intp + expected = np.argsort(a, kind="mergesort") + expected = expected.astype(np.intp) + + tm.assert_numpy_array_equal(result, expected) + + # compare with lexsort + # np.lexsort returns int, groupsort_indexer + # always returns intp + key = a * 1000 + b + result = libalgos.groupsort_indexer(key, 1000000)[0] + expected = np.lexsort((b, a)) + expected = expected.astype(np.intp) + + tm.assert_numpy_array_equal(result, expected) + + +class TestPadBackfill: + def test_backfill(self): + old = np.array([1, 5, 10], dtype=np.int64) + new = np.array(list(range(12)), dtype=np.int64) + + filler = libalgos.backfill["int64_t"](old, new) + + expect_filler = np.array([0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2, -1], dtype=np.intp) + tm.assert_numpy_array_equal(filler, expect_filler) + + # corner case + old = np.array([1, 4], dtype=np.int64) + new = np.array(list(range(5, 10)), dtype=np.int64) + filler = libalgos.backfill["int64_t"](old, new) + + expect_filler = np.array([-1, -1, -1, -1, -1], dtype=np.intp) + tm.assert_numpy_array_equal(filler, expect_filler) + + def test_pad(self): + old = np.array([1, 5, 10], dtype=np.int64) + new = np.array(list(range(12)), dtype=np.int64) + + filler = libalgos.pad["int64_t"](old, new) + + expect_filler = np.array([-1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2], dtype=np.intp) + tm.assert_numpy_array_equal(filler, expect_filler) + + # corner case + old = np.array([5, 10], dtype=np.int64) + new = np.arange(5, dtype=np.int64) + filler = libalgos.pad["int64_t"](old, new) + expect_filler = np.array([-1, -1, -1, -1, -1], dtype=np.intp) + tm.assert_numpy_array_equal(filler, expect_filler) + + def test_pad_backfill_object_segfault(self): + old = np.array([], dtype="O") + new = np.array([datetime(2010, 12, 31)], dtype="O") + + result = libalgos.pad["object"](old, new) + expected = np.array([-1], dtype=np.intp) + tm.assert_numpy_array_equal(result, expected) + + result = libalgos.pad["object"](new, old) + expected = np.array([], dtype=np.intp) + tm.assert_numpy_array_equal(result, expected) + + result = libalgos.backfill["object"](old, new) + expected = np.array([-1], dtype=np.intp) + tm.assert_numpy_array_equal(result, expected) + + result = libalgos.backfill["object"](new, old) + expected = np.array([], dtype=np.intp) + tm.assert_numpy_array_equal(result, expected) + + +class TestInfinity: + def test_infinity_sort(self): + # GH#13445 + # numpy's argsort can be unhappy if something is less than + # itself. Instead, let's give our infinities a self-consistent + # ordering, but outside the float extended real line. + + Inf = libalgos.Infinity() + NegInf = libalgos.NegInfinity() + + ref_nums = [NegInf, float("-inf"), -1e100, 0, 1e100, float("inf"), Inf] + + assert all(Inf >= x for x in ref_nums) + assert all(Inf > x or x is Inf for x in ref_nums) + assert Inf >= Inf and Inf == Inf + assert not Inf < Inf and not Inf > Inf + assert libalgos.Infinity() == libalgos.Infinity() + assert not libalgos.Infinity() != libalgos.Infinity() + + assert all(NegInf <= x for x in ref_nums) + assert all(NegInf < x or x is NegInf for x in ref_nums) + assert NegInf <= NegInf and NegInf == NegInf + assert not NegInf < NegInf and not NegInf > NegInf + assert libalgos.NegInfinity() == libalgos.NegInfinity() + assert not libalgos.NegInfinity() != libalgos.NegInfinity() + + for perm in permutations(ref_nums): + assert sorted(perm) == ref_nums + + # smoke tests + np.array([libalgos.Infinity()] * 32).argsort() + np.array([libalgos.NegInfinity()] * 32).argsort() + + def test_infinity_against_nan(self): + Inf = libalgos.Infinity() + NegInf = libalgos.NegInfinity() + + assert not Inf > np.nan + assert not Inf >= np.nan + assert not Inf < np.nan + assert not Inf <= np.nan + assert not Inf == np.nan + assert Inf != np.nan + + assert not NegInf > np.nan + assert not NegInf >= np.nan + assert not NegInf < np.nan + assert not NegInf <= np.nan + assert not NegInf == np.nan + assert NegInf != np.nan diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reductions/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reductions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e3851753b67421842a0d3d9fd5f88e7eb72734dd --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reductions/__init__.py @@ -0,0 +1,4 @@ +""" +Tests for reductions where we want to test for matching behavior across +Array, Index, Series, and DataFrame methods. +""" diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reductions/test_reductions.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reductions/test_reductions.py new file mode 100644 index 0000000000000000000000000000000000000000..30ec0d0affaa3b30facdb8bf55062017a217b5ae --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reductions/test_reductions.py @@ -0,0 +1,1673 @@ +from datetime import ( + datetime, + timedelta, +) +from decimal import Decimal + +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + Categorical, + DataFrame, + DatetimeIndex, + Index, + NaT, + Period, + PeriodIndex, + RangeIndex, + Series, + Timedelta, + TimedeltaIndex, + Timestamp, + date_range, + isna, + period_range, + timedelta_range, + to_timedelta, +) +import pandas._testing as tm +from pandas.core import nanops +from pandas.core.arrays.string_arrow import ArrowStringArrayNumpySemantics + + +def get_objs(): + indexes = [ + Index([True, False] * 5, name="a"), + Index(np.arange(10), dtype=np.int64, name="a"), + Index(np.arange(10), dtype=np.float64, name="a"), + DatetimeIndex(date_range("2020-01-01", periods=10), name="a"), + DatetimeIndex(date_range("2020-01-01", periods=10), name="a").tz_localize( + tz="US/Eastern" + ), + PeriodIndex(period_range("2020-01-01", periods=10, freq="D"), name="a"), + Index([str(i) for i in range(10)], name="a"), + ] + + arr = np.random.default_rng(2).standard_normal(10) + series = [Series(arr, index=idx, name="a") for idx in indexes] + + objs = indexes + series + return objs + + +class TestReductions: + @pytest.mark.filterwarnings( + "ignore:Period with BDay freq is deprecated:FutureWarning" + ) + @pytest.mark.parametrize("opname", ["max", "min"]) + @pytest.mark.parametrize("obj", get_objs()) + def test_ops(self, opname, obj): + result = getattr(obj, opname)() + if not isinstance(obj, PeriodIndex): + if isinstance(obj.values, ArrowStringArrayNumpySemantics): + # max not on the interface + expected = getattr(np.array(obj.values), opname)() + else: + expected = getattr(obj.values, opname)() + else: + expected = Period(ordinal=getattr(obj.asi8, opname)(), freq=obj.freq) + + if getattr(obj, "tz", None) is not None: + # We need to de-localize before comparing to the numpy-produced result + expected = expected.astype("M8[ns]").astype("int64") + assert result._value == expected + else: + assert result == expected + + @pytest.mark.parametrize("opname", ["max", "min"]) + @pytest.mark.parametrize( + "dtype, val", + [ + ("object", 2.0), + ("float64", 2.0), + ("datetime64[ns]", datetime(2011, 11, 1)), + ("Int64", 2), + ("boolean", True), + ], + ) + def test_nanminmax(self, opname, dtype, val, index_or_series): + # GH#7261 + klass = index_or_series + + def check_missing(res): + if dtype == "datetime64[ns]": + return res is NaT + elif dtype in ["Int64", "boolean"]: + return res is pd.NA + else: + return isna(res) + + obj = klass([None], dtype=dtype) + assert check_missing(getattr(obj, opname)()) + assert check_missing(getattr(obj, opname)(skipna=False)) + + obj = klass([], dtype=dtype) + assert check_missing(getattr(obj, opname)()) + assert check_missing(getattr(obj, opname)(skipna=False)) + + if dtype == "object": + # generic test with object only works for empty / all NaN + return + + obj = klass([None, val], dtype=dtype) + assert getattr(obj, opname)() == val + assert check_missing(getattr(obj, opname)(skipna=False)) + + obj = klass([None, val, None], dtype=dtype) + assert getattr(obj, opname)() == val + assert check_missing(getattr(obj, opname)(skipna=False)) + + @pytest.mark.parametrize("opname", ["max", "min"]) + def test_nanargminmax(self, opname, index_or_series): + # GH#7261 + klass = index_or_series + arg_op = "arg" + opname if klass is Index else "idx" + opname + + obj = klass([NaT, datetime(2011, 11, 1)]) + assert getattr(obj, arg_op)() == 1 + + msg = ( + "The behavior of (DatetimeIndex|Series).argmax/argmin with " + "skipna=False and NAs" + ) + if klass is Series: + msg = "The behavior of Series.(idxmax|idxmin) with all-NA" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = getattr(obj, arg_op)(skipna=False) + if klass is Series: + assert np.isnan(result) + else: + assert result == -1 + + obj = klass([NaT, datetime(2011, 11, 1), NaT]) + # check DatetimeIndex non-monotonic path + assert getattr(obj, arg_op)() == 1 + with tm.assert_produces_warning(FutureWarning, match=msg): + result = getattr(obj, arg_op)(skipna=False) + if klass is Series: + assert np.isnan(result) + else: + assert result == -1 + + @pytest.mark.parametrize("opname", ["max", "min"]) + @pytest.mark.parametrize("dtype", ["M8[ns]", "datetime64[ns, UTC]"]) + def test_nanops_empty_object(self, opname, index_or_series, dtype): + klass = index_or_series + arg_op = "arg" + opname if klass is Index else "idx" + opname + + obj = klass([], dtype=dtype) + + assert getattr(obj, opname)() is NaT + assert getattr(obj, opname)(skipna=False) is NaT + + with pytest.raises(ValueError, match="empty sequence"): + getattr(obj, arg_op)() + with pytest.raises(ValueError, match="empty sequence"): + getattr(obj, arg_op)(skipna=False) + + def test_argminmax(self): + obj = Index(np.arange(5, dtype="int64")) + assert obj.argmin() == 0 + assert obj.argmax() == 4 + + obj = Index([np.nan, 1, np.nan, 2]) + assert obj.argmin() == 1 + assert obj.argmax() == 3 + msg = "The behavior of Index.argmax/argmin with skipna=False and NAs" + with tm.assert_produces_warning(FutureWarning, match=msg): + assert obj.argmin(skipna=False) == -1 + with tm.assert_produces_warning(FutureWarning, match=msg): + assert obj.argmax(skipna=False) == -1 + + obj = Index([np.nan]) + with tm.assert_produces_warning(FutureWarning, match=msg): + assert obj.argmin() == -1 + with tm.assert_produces_warning(FutureWarning, match=msg): + assert obj.argmax() == -1 + with tm.assert_produces_warning(FutureWarning, match=msg): + assert obj.argmin(skipna=False) == -1 + with tm.assert_produces_warning(FutureWarning, match=msg): + assert obj.argmax(skipna=False) == -1 + + msg = "The behavior of DatetimeIndex.argmax/argmin with skipna=False and NAs" + obj = Index([NaT, datetime(2011, 11, 1), datetime(2011, 11, 2), NaT]) + assert obj.argmin() == 1 + assert obj.argmax() == 2 + with tm.assert_produces_warning(FutureWarning, match=msg): + assert obj.argmin(skipna=False) == -1 + with tm.assert_produces_warning(FutureWarning, match=msg): + assert obj.argmax(skipna=False) == -1 + + obj = Index([NaT]) + with tm.assert_produces_warning(FutureWarning, match=msg): + assert obj.argmin() == -1 + with tm.assert_produces_warning(FutureWarning, match=msg): + assert obj.argmax() == -1 + with tm.assert_produces_warning(FutureWarning, match=msg): + assert obj.argmin(skipna=False) == -1 + with tm.assert_produces_warning(FutureWarning, match=msg): + assert obj.argmax(skipna=False) == -1 + + @pytest.mark.parametrize("op, expected_col", [["max", "a"], ["min", "b"]]) + def test_same_tz_min_max_axis_1(self, op, expected_col): + # GH 10390 + df = DataFrame( + date_range("2016-01-01 00:00:00", periods=3, tz="UTC"), columns=["a"] + ) + df["b"] = df.a.subtract(Timedelta(seconds=3600)) + result = getattr(df, op)(axis=1) + expected = df[expected_col].rename(None) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("func", ["maximum", "minimum"]) + def test_numpy_reduction_with_tz_aware_dtype(self, tz_aware_fixture, func): + # GH 15552 + tz = tz_aware_fixture + arg = pd.to_datetime(["2019"]).tz_localize(tz) + expected = Series(arg) + result = getattr(np, func)(expected, expected) + tm.assert_series_equal(result, expected) + + def test_nan_int_timedelta_sum(self): + # GH 27185 + df = DataFrame( + { + "A": Series([1, 2, NaT], dtype="timedelta64[ns]"), + "B": Series([1, 2, np.nan], dtype="Int64"), + } + ) + expected = Series({"A": Timedelta(3), "B": 3}) + result = df.sum() + tm.assert_series_equal(result, expected) + + +class TestIndexReductions: + # Note: the name TestIndexReductions indicates these tests + # were moved from a Index-specific test file, _not_ that these tests are + # intended long-term to be Index-specific + + @pytest.mark.parametrize( + "start,stop,step", + [ + (0, 400, 3), + (500, 0, -6), + (-(10**6), 10**6, 4), + (10**6, -(10**6), -4), + (0, 10, 20), + ], + ) + def test_max_min_range(self, start, stop, step): + # GH#17607 + idx = RangeIndex(start, stop, step) + expected = idx._values.max() + result = idx.max() + assert result == expected + + # skipna should be irrelevant since RangeIndex should never have NAs + result2 = idx.max(skipna=False) + assert result2 == expected + + expected = idx._values.min() + result = idx.min() + assert result == expected + + # skipna should be irrelevant since RangeIndex should never have NAs + result2 = idx.min(skipna=False) + assert result2 == expected + + # empty + idx = RangeIndex(start, stop, -step) + assert isna(idx.max()) + assert isna(idx.min()) + + def test_minmax_timedelta64(self): + # monotonic + idx1 = TimedeltaIndex(["1 days", "2 days", "3 days"]) + assert idx1.is_monotonic_increasing + + # non-monotonic + idx2 = TimedeltaIndex(["1 days", np.nan, "3 days", "NaT"]) + assert not idx2.is_monotonic_increasing + + for idx in [idx1, idx2]: + assert idx.min() == Timedelta("1 days") + assert idx.max() == Timedelta("3 days") + assert idx.argmin() == 0 + assert idx.argmax() == 2 + + @pytest.mark.parametrize("op", ["min", "max"]) + def test_minmax_timedelta_empty_or_na(self, op): + # Return NaT + obj = TimedeltaIndex([]) + assert getattr(obj, op)() is NaT + + obj = TimedeltaIndex([NaT]) + assert getattr(obj, op)() is NaT + + obj = TimedeltaIndex([NaT, NaT, NaT]) + assert getattr(obj, op)() is NaT + + def test_numpy_minmax_timedelta64(self): + td = timedelta_range("16815 days", "16820 days", freq="D") + + assert np.min(td) == Timedelta("16815 days") + assert np.max(td) == Timedelta("16820 days") + + errmsg = "the 'out' parameter is not supported" + with pytest.raises(ValueError, match=errmsg): + np.min(td, out=0) + with pytest.raises(ValueError, match=errmsg): + np.max(td, out=0) + + assert np.argmin(td) == 0 + assert np.argmax(td) == 5 + + errmsg = "the 'out' parameter is not supported" + with pytest.raises(ValueError, match=errmsg): + np.argmin(td, out=0) + with pytest.raises(ValueError, match=errmsg): + np.argmax(td, out=0) + + def test_timedelta_ops(self): + # GH#4984 + # make sure ops return Timedelta + s = Series( + [Timestamp("20130101") + timedelta(seconds=i * i) for i in range(10)] + ) + td = s.diff() + + result = td.mean() + expected = to_timedelta(timedelta(seconds=9)) + assert result == expected + + result = td.to_frame().mean() + assert result[0] == expected + + result = td.quantile(0.1) + expected = Timedelta(np.timedelta64(2600, "ms")) + assert result == expected + + result = td.median() + expected = to_timedelta("00:00:09") + assert result == expected + + result = td.to_frame().median() + assert result[0] == expected + + # GH#6462 + # consistency in returned values for sum + result = td.sum() + expected = to_timedelta("00:01:21") + assert result == expected + + result = td.to_frame().sum() + assert result[0] == expected + + # std + result = td.std() + expected = to_timedelta(Series(td.dropna().values).std()) + assert result == expected + + result = td.to_frame().std() + assert result[0] == expected + + # GH#10040 + # make sure NaT is properly handled by median() + s = Series([Timestamp("2015-02-03"), Timestamp("2015-02-07")]) + assert s.diff().median() == timedelta(days=4) + + s = Series( + [Timestamp("2015-02-03"), Timestamp("2015-02-07"), Timestamp("2015-02-15")] + ) + assert s.diff().median() == timedelta(days=6) + + @pytest.mark.parametrize("opname", ["skew", "kurt", "sem", "prod", "var"]) + def test_invalid_td64_reductions(self, opname): + s = Series( + [Timestamp("20130101") + timedelta(seconds=i * i) for i in range(10)] + ) + td = s.diff() + + msg = "|".join( + [ + f"reduction operation '{opname}' not allowed for this dtype", + rf"cannot perform {opname} with type timedelta64\[ns\]", + f"does not support reduction '{opname}'", + ] + ) + + with pytest.raises(TypeError, match=msg): + getattr(td, opname)() + + with pytest.raises(TypeError, match=msg): + getattr(td.to_frame(), opname)(numeric_only=False) + + def test_minmax_tz(self, tz_naive_fixture): + tz = tz_naive_fixture + # monotonic + idx1 = DatetimeIndex(["2011-01-01", "2011-01-02", "2011-01-03"], tz=tz) + assert idx1.is_monotonic_increasing + + # non-monotonic + idx2 = DatetimeIndex( + ["2011-01-01", NaT, "2011-01-03", "2011-01-02", NaT], tz=tz + ) + assert not idx2.is_monotonic_increasing + + for idx in [idx1, idx2]: + assert idx.min() == Timestamp("2011-01-01", tz=tz) + assert idx.max() == Timestamp("2011-01-03", tz=tz) + assert idx.argmin() == 0 + assert idx.argmax() == 2 + + @pytest.mark.parametrize("op", ["min", "max"]) + def test_minmax_nat_datetime64(self, op): + # Return NaT + obj = DatetimeIndex([]) + assert isna(getattr(obj, op)()) + + obj = DatetimeIndex([NaT]) + assert isna(getattr(obj, op)()) + + obj = DatetimeIndex([NaT, NaT, NaT]) + assert isna(getattr(obj, op)()) + + def test_numpy_minmax_integer(self): + # GH#26125 + idx = Index([1, 2, 3]) + + expected = idx.values.max() + result = np.max(idx) + assert result == expected + + expected = idx.values.min() + result = np.min(idx) + assert result == expected + + errmsg = "the 'out' parameter is not supported" + with pytest.raises(ValueError, match=errmsg): + np.min(idx, out=0) + with pytest.raises(ValueError, match=errmsg): + np.max(idx, out=0) + + expected = idx.values.argmax() + result = np.argmax(idx) + assert result == expected + + expected = idx.values.argmin() + result = np.argmin(idx) + assert result == expected + + errmsg = "the 'out' parameter is not supported" + with pytest.raises(ValueError, match=errmsg): + np.argmin(idx, out=0) + with pytest.raises(ValueError, match=errmsg): + np.argmax(idx, out=0) + + def test_numpy_minmax_range(self): + # GH#26125 + idx = RangeIndex(0, 10, 3) + + result = np.max(idx) + assert result == 9 + + result = np.min(idx) + assert result == 0 + + errmsg = "the 'out' parameter is not supported" + with pytest.raises(ValueError, match=errmsg): + np.min(idx, out=0) + with pytest.raises(ValueError, match=errmsg): + np.max(idx, out=0) + + # No need to test again argmax/argmin compat since the implementation + # is the same as basic integer index + + def test_numpy_minmax_datetime64(self): + dr = date_range(start="2016-01-15", end="2016-01-20") + + assert np.min(dr) == Timestamp("2016-01-15 00:00:00") + assert np.max(dr) == Timestamp("2016-01-20 00:00:00") + + errmsg = "the 'out' parameter is not supported" + with pytest.raises(ValueError, match=errmsg): + np.min(dr, out=0) + + with pytest.raises(ValueError, match=errmsg): + np.max(dr, out=0) + + assert np.argmin(dr) == 0 + assert np.argmax(dr) == 5 + + errmsg = "the 'out' parameter is not supported" + with pytest.raises(ValueError, match=errmsg): + np.argmin(dr, out=0) + + with pytest.raises(ValueError, match=errmsg): + np.argmax(dr, out=0) + + def test_minmax_period(self): + # monotonic + idx1 = PeriodIndex([NaT, "2011-01-01", "2011-01-02", "2011-01-03"], freq="D") + assert not idx1.is_monotonic_increasing + assert idx1[1:].is_monotonic_increasing + + # non-monotonic + idx2 = PeriodIndex( + ["2011-01-01", NaT, "2011-01-03", "2011-01-02", NaT], freq="D" + ) + assert not idx2.is_monotonic_increasing + + for idx in [idx1, idx2]: + assert idx.min() == Period("2011-01-01", freq="D") + assert idx.max() == Period("2011-01-03", freq="D") + assert idx1.argmin() == 1 + assert idx2.argmin() == 0 + assert idx1.argmax() == 3 + assert idx2.argmax() == 2 + + @pytest.mark.parametrize("op", ["min", "max"]) + @pytest.mark.parametrize("data", [[], [NaT], [NaT, NaT, NaT]]) + def test_minmax_period_empty_nat(self, op, data): + # Return NaT + obj = PeriodIndex(data, freq="M") + result = getattr(obj, op)() + assert result is NaT + + def test_numpy_minmax_period(self): + pr = period_range(start="2016-01-15", end="2016-01-20") + + assert np.min(pr) == Period("2016-01-15", freq="D") + assert np.max(pr) == Period("2016-01-20", freq="D") + + errmsg = "the 'out' parameter is not supported" + with pytest.raises(ValueError, match=errmsg): + np.min(pr, out=0) + with pytest.raises(ValueError, match=errmsg): + np.max(pr, out=0) + + assert np.argmin(pr) == 0 + assert np.argmax(pr) == 5 + + errmsg = "the 'out' parameter is not supported" + with pytest.raises(ValueError, match=errmsg): + np.argmin(pr, out=0) + with pytest.raises(ValueError, match=errmsg): + np.argmax(pr, out=0) + + def test_min_max_categorical(self): + ci = pd.CategoricalIndex(list("aabbca"), categories=list("cab"), ordered=False) + msg = ( + r"Categorical is not ordered for operation min\n" + r"you can use .as_ordered\(\) to change the Categorical to an ordered one\n" + ) + with pytest.raises(TypeError, match=msg): + ci.min() + msg = ( + r"Categorical is not ordered for operation max\n" + r"you can use .as_ordered\(\) to change the Categorical to an ordered one\n" + ) + with pytest.raises(TypeError, match=msg): + ci.max() + + ci = pd.CategoricalIndex(list("aabbca"), categories=list("cab"), ordered=True) + assert ci.min() == "c" + assert ci.max() == "b" + + +class TestSeriesReductions: + # Note: the name TestSeriesReductions indicates these tests + # were moved from a series-specific test file, _not_ that these tests are + # intended long-term to be series-specific + + def test_sum_inf(self): + s = Series(np.random.default_rng(2).standard_normal(10)) + s2 = s.copy() + + s[5:8] = np.inf + s2[5:8] = np.nan + + assert np.isinf(s.sum()) + + arr = np.random.default_rng(2).standard_normal((100, 100)).astype("f4") + arr[:, 2] = np.inf + + msg = "use_inf_as_na option is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + with pd.option_context("mode.use_inf_as_na", True): + tm.assert_almost_equal(s.sum(), s2.sum()) + + res = nanops.nansum(arr, axis=1) + assert np.isinf(res).all() + + @pytest.mark.parametrize( + "dtype", ["float64", "Float32", "Int64", "boolean", "object"] + ) + @pytest.mark.parametrize("use_bottleneck", [True, False]) + @pytest.mark.parametrize("method, unit", [("sum", 0.0), ("prod", 1.0)]) + def test_empty(self, method, unit, use_bottleneck, dtype): + with pd.option_context("use_bottleneck", use_bottleneck): + # GH#9422 / GH#18921 + # Entirely empty + s = Series([], dtype=dtype) + # NA by default + result = getattr(s, method)() + assert result == unit + + # Explicit + result = getattr(s, method)(min_count=0) + assert result == unit + + result = getattr(s, method)(min_count=1) + assert isna(result) + + # Skipna, default + result = getattr(s, method)(skipna=True) + result == unit + + # Skipna, explicit + result = getattr(s, method)(skipna=True, min_count=0) + assert result == unit + + result = getattr(s, method)(skipna=True, min_count=1) + assert isna(result) + + result = getattr(s, method)(skipna=False, min_count=0) + assert result == unit + + result = getattr(s, method)(skipna=False, min_count=1) + assert isna(result) + + # All-NA + s = Series([np.nan], dtype=dtype) + # NA by default + result = getattr(s, method)() + assert result == unit + + # Explicit + result = getattr(s, method)(min_count=0) + assert result == unit + + result = getattr(s, method)(min_count=1) + assert isna(result) + + # Skipna, default + result = getattr(s, method)(skipna=True) + result == unit + + # skipna, explicit + result = getattr(s, method)(skipna=True, min_count=0) + assert result == unit + + result = getattr(s, method)(skipna=True, min_count=1) + assert isna(result) + + # Mix of valid, empty + s = Series([np.nan, 1], dtype=dtype) + # Default + result = getattr(s, method)() + assert result == 1.0 + + # Explicit + result = getattr(s, method)(min_count=0) + assert result == 1.0 + + result = getattr(s, method)(min_count=1) + assert result == 1.0 + + # Skipna + result = getattr(s, method)(skipna=True) + assert result == 1.0 + + result = getattr(s, method)(skipna=True, min_count=0) + assert result == 1.0 + + # GH#844 (changed in GH#9422) + df = DataFrame(np.empty((10, 0)), dtype=dtype) + assert (getattr(df, method)(1) == unit).all() + + s = Series([1], dtype=dtype) + result = getattr(s, method)(min_count=2) + assert isna(result) + + result = getattr(s, method)(skipna=False, min_count=2) + assert isna(result) + + s = Series([np.nan], dtype=dtype) + result = getattr(s, method)(min_count=2) + assert isna(result) + + s = Series([np.nan, 1], dtype=dtype) + result = getattr(s, method)(min_count=2) + assert isna(result) + + @pytest.mark.parametrize("method", ["mean", "var"]) + @pytest.mark.parametrize("dtype", ["Float64", "Int64", "boolean"]) + def test_ops_consistency_on_empty_nullable(self, method, dtype): + # GH#34814 + # consistency for nullable dtypes on empty or ALL-NA mean + + # empty series + eser = Series([], dtype=dtype) + result = getattr(eser, method)() + assert result is pd.NA + + # ALL-NA series + nser = Series([np.nan], dtype=dtype) + result = getattr(nser, method)() + assert result is pd.NA + + @pytest.mark.parametrize("method", ["mean", "median", "std", "var"]) + def test_ops_consistency_on_empty(self, method): + # GH#7869 + # consistency on empty + + # float + result = getattr(Series(dtype=float), method)() + assert isna(result) + + # timedelta64[ns] + tdser = Series([], dtype="m8[ns]") + if method == "var": + msg = "|".join( + [ + "operation 'var' not allowed", + r"cannot perform var with type timedelta64\[ns\]", + "does not support reduction 'var'", + ] + ) + with pytest.raises(TypeError, match=msg): + getattr(tdser, method)() + else: + result = getattr(tdser, method)() + assert result is NaT + + def test_nansum_buglet(self): + ser = Series([1.0, np.nan], index=[0, 1]) + result = np.nansum(ser) + tm.assert_almost_equal(result, 1) + + @pytest.mark.parametrize("use_bottleneck", [True, False]) + @pytest.mark.parametrize("dtype", ["int32", "int64"]) + def test_sum_overflow_int(self, use_bottleneck, dtype): + with pd.option_context("use_bottleneck", use_bottleneck): + # GH#6915 + # overflowing on the smaller int dtypes + v = np.arange(5000000, dtype=dtype) + s = Series(v) + + result = s.sum(skipna=False) + assert int(result) == v.sum(dtype="int64") + result = s.min(skipna=False) + assert int(result) == 0 + result = s.max(skipna=False) + assert int(result) == v[-1] + + @pytest.mark.parametrize("use_bottleneck", [True, False]) + @pytest.mark.parametrize("dtype", ["float32", "float64"]) + def test_sum_overflow_float(self, use_bottleneck, dtype): + with pd.option_context("use_bottleneck", use_bottleneck): + v = np.arange(5000000, dtype=dtype) + s = Series(v) + + result = s.sum(skipna=False) + assert result == v.sum(dtype=dtype) + result = s.min(skipna=False) + assert np.allclose(float(result), 0.0) + result = s.max(skipna=False) + assert np.allclose(float(result), v[-1]) + + def test_mean_masked_overflow(self): + # GH#48378 + val = 100_000_000_000_000_000 + n_elements = 100 + na = np.array([val] * n_elements) + ser = Series([val] * n_elements, dtype="Int64") + + result_numpy = np.mean(na) + result_masked = ser.mean() + assert result_masked - result_numpy == 0 + assert result_masked == 1e17 + + @pytest.mark.parametrize("ddof, exp", [(1, 2.5), (0, 2.0)]) + def test_var_masked_array(self, ddof, exp): + # GH#48379 + ser = Series([1, 2, 3, 4, 5], dtype="Int64") + ser_numpy_dtype = Series([1, 2, 3, 4, 5], dtype="int64") + result = ser.var(ddof=ddof) + result_numpy_dtype = ser_numpy_dtype.var(ddof=ddof) + assert result == result_numpy_dtype + assert result == exp + + @pytest.mark.parametrize("dtype", ("m8[ns]", "m8[ns]", "M8[ns]", "M8[ns, UTC]")) + @pytest.mark.parametrize("skipna", [True, False]) + def test_empty_timeseries_reductions_return_nat(self, dtype, skipna): + # covers GH#11245 + assert Series([], dtype=dtype).min(skipna=skipna) is NaT + assert Series([], dtype=dtype).max(skipna=skipna) is NaT + + def test_numpy_argmin(self): + # See GH#16830 + data = np.arange(1, 11) + + s = Series(data, index=data) + result = np.argmin(s) + + expected = np.argmin(data) + assert result == expected + + result = s.argmin() + + assert result == expected + + msg = "the 'out' parameter is not supported" + with pytest.raises(ValueError, match=msg): + np.argmin(s, out=data) + + def test_numpy_argmax(self): + # See GH#16830 + data = np.arange(1, 11) + + ser = Series(data, index=data) + result = np.argmax(ser) + expected = np.argmax(data) + assert result == expected + + result = ser.argmax() + + assert result == expected + + msg = "the 'out' parameter is not supported" + with pytest.raises(ValueError, match=msg): + np.argmax(ser, out=data) + + def test_idxmin_dt64index(self, unit): + # GH#43587 should have NaT instead of NaN + dti = DatetimeIndex(["NaT", "2015-02-08", "NaT"]).as_unit(unit) + ser = Series([1.0, 2.0, np.nan], index=dti) + msg = "The behavior of Series.idxmin with all-NA values" + with tm.assert_produces_warning(FutureWarning, match=msg): + res = ser.idxmin(skipna=False) + assert res is NaT + msg = "The behavior of Series.idxmax with all-NA values" + with tm.assert_produces_warning(FutureWarning, match=msg): + res = ser.idxmax(skipna=False) + assert res is NaT + + df = ser.to_frame() + msg = "The behavior of DataFrame.idxmin with all-NA values" + with tm.assert_produces_warning(FutureWarning, match=msg): + res = df.idxmin(skipna=False) + assert res.dtype == f"M8[{unit}]" + assert res.isna().all() + msg = "The behavior of DataFrame.idxmax with all-NA values" + with tm.assert_produces_warning(FutureWarning, match=msg): + res = df.idxmax(skipna=False) + assert res.dtype == f"M8[{unit}]" + assert res.isna().all() + + def test_idxmin(self): + # test idxmin + # _check_stat_op approach can not be used here because of isna check. + string_series = Series(range(20), dtype=np.float64, name="series") + + # add some NaNs + string_series[5:15] = np.nan + + # skipna or no + assert string_series[string_series.idxmin()] == string_series.min() + msg = "The behavior of Series.idxmin" + with tm.assert_produces_warning(FutureWarning, match=msg): + assert isna(string_series.idxmin(skipna=False)) + + # no NaNs + nona = string_series.dropna() + assert nona[nona.idxmin()] == nona.min() + assert nona.index.values.tolist().index(nona.idxmin()) == nona.values.argmin() + + # all NaNs + allna = string_series * np.nan + with tm.assert_produces_warning(FutureWarning, match=msg): + assert isna(allna.idxmin()) + + # datetime64[ns] + s = Series(date_range("20130102", periods=6)) + result = s.idxmin() + assert result == 0 + + s[0] = np.nan + result = s.idxmin() + assert result == 1 + + def test_idxmax(self): + # test idxmax + # _check_stat_op approach can not be used here because of isna check. + string_series = Series(range(20), dtype=np.float64, name="series") + + # add some NaNs + string_series[5:15] = np.nan + + # skipna or no + assert string_series[string_series.idxmax()] == string_series.max() + msg = "The behavior of Series.idxmax with all-NA values" + with tm.assert_produces_warning(FutureWarning, match=msg): + assert isna(string_series.idxmax(skipna=False)) + + # no NaNs + nona = string_series.dropna() + assert nona[nona.idxmax()] == nona.max() + assert nona.index.values.tolist().index(nona.idxmax()) == nona.values.argmax() + + # all NaNs + allna = string_series * np.nan + msg = "The behavior of Series.idxmax with all-NA values" + with tm.assert_produces_warning(FutureWarning, match=msg): + assert isna(allna.idxmax()) + + s = Series(date_range("20130102", periods=6)) + result = s.idxmax() + assert result == 5 + + s[5] = np.nan + result = s.idxmax() + assert result == 4 + + # Index with float64 dtype + # GH#5914 + s = Series([1, 2, 3], [1.1, 2.1, 3.1]) + result = s.idxmax() + assert result == 3.1 + result = s.idxmin() + assert result == 1.1 + + s = Series(s.index, s.index) + result = s.idxmax() + assert result == 3.1 + result = s.idxmin() + assert result == 1.1 + + def test_all_any(self): + ts = Series( + np.arange(10, dtype=np.float64), + index=date_range("2020-01-01", periods=10), + name="ts", + ) + bool_series = ts > 0 + assert not bool_series.all() + assert bool_series.any() + + # Alternative types, with implicit 'object' dtype. + s = Series(["abc", True]) + assert s.any() + + def test_numpy_all_any(self, index_or_series): + # GH#40180 + idx = index_or_series([0, 1, 2]) + assert not np.all(idx) + assert np.any(idx) + idx = Index([1, 2, 3]) + assert np.all(idx) + + def test_all_any_skipna(self): + # Check skipna, with implicit 'object' dtype. + s1 = Series([np.nan, True]) + s2 = Series([np.nan, False]) + assert s1.all(skipna=False) # nan && True => True + assert s1.all(skipna=True) + assert s2.any(skipna=False) + assert not s2.any(skipna=True) + + def test_all_any_bool_only(self): + s = Series([False, False, True, True, False, True], index=[0, 0, 1, 1, 2, 2]) + + # GH#47500 - test bool_only works + assert s.any(bool_only=True) + assert not s.all(bool_only=True) + + @pytest.mark.parametrize("bool_agg_func", ["any", "all"]) + @pytest.mark.parametrize("skipna", [True, False]) + def test_any_all_object_dtype(self, bool_agg_func, skipna): + # GH#12863 + ser = Series(["a", "b", "c", "d", "e"], dtype=object) + result = getattr(ser, bool_agg_func)(skipna=skipna) + expected = True + + assert result == expected + + @pytest.mark.parametrize("bool_agg_func", ["any", "all"]) + @pytest.mark.parametrize( + "data", [[False, None], [None, False], [False, np.nan], [np.nan, False]] + ) + def test_any_all_object_dtype_missing(self, data, bool_agg_func): + # GH#27709 + ser = Series(data) + result = getattr(ser, bool_agg_func)(skipna=False) + + # None is treated is False, but np.nan is treated as True + expected = bool_agg_func == "any" and None not in data + assert result == expected + + @pytest.mark.parametrize("dtype", ["boolean", "Int64", "UInt64", "Float64"]) + @pytest.mark.parametrize("bool_agg_func", ["any", "all"]) + @pytest.mark.parametrize("skipna", [True, False]) + @pytest.mark.parametrize( + # expected_data indexed as [[skipna=False/any, skipna=False/all], + # [skipna=True/any, skipna=True/all]] + "data,expected_data", + [ + ([0, 0, 0], [[False, False], [False, False]]), + ([1, 1, 1], [[True, True], [True, True]]), + ([pd.NA, pd.NA, pd.NA], [[pd.NA, pd.NA], [False, True]]), + ([0, pd.NA, 0], [[pd.NA, False], [False, False]]), + ([1, pd.NA, 1], [[True, pd.NA], [True, True]]), + ([1, pd.NA, 0], [[True, False], [True, False]]), + ], + ) + def test_any_all_nullable_kleene_logic( + self, bool_agg_func, skipna, data, dtype, expected_data + ): + # GH-37506, GH-41967 + ser = Series(data, dtype=dtype) + expected = expected_data[skipna][bool_agg_func == "all"] + + result = getattr(ser, bool_agg_func)(skipna=skipna) + assert (result is pd.NA and expected is pd.NA) or result == expected + + def test_any_axis1_bool_only(self): + # GH#32432 + df = DataFrame({"A": [True, False], "B": [1, 2]}) + result = df.any(axis=1, bool_only=True) + expected = Series([True, False]) + tm.assert_series_equal(result, expected) + + def test_any_all_datetimelike(self): + # GH#38723 these may not be the desired long-term behavior (GH#34479) + # but in the interim should be internally consistent + dta = date_range("1995-01-02", periods=3)._data + ser = Series(dta) + df = DataFrame(ser) + + msg = "'(any|all)' with datetime64 dtypes is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + # GH#34479 + assert dta.all() + assert dta.any() + + assert ser.all() + assert ser.any() + + assert df.any().all() + assert df.all().all() + + dta = dta.tz_localize("UTC") + ser = Series(dta) + df = DataFrame(ser) + + with tm.assert_produces_warning(FutureWarning, match=msg): + # GH#34479 + assert dta.all() + assert dta.any() + + assert ser.all() + assert ser.any() + + assert df.any().all() + assert df.all().all() + + tda = dta - dta[0] + ser = Series(tda) + df = DataFrame(ser) + + assert tda.any() + assert not tda.all() + + assert ser.any() + assert not ser.all() + + assert df.any().all() + assert not df.all().any() + + def test_any_all_pyarrow_string(self): + # GH#54591 + pytest.importorskip("pyarrow") + ser = Series(["", "a"], dtype="string[pyarrow_numpy]") + assert ser.any() + assert not ser.all() + + ser = Series([None, "a"], dtype="string[pyarrow_numpy]") + assert ser.any() + assert ser.all() + assert not ser.all(skipna=False) + + ser = Series([None, ""], dtype="string[pyarrow_numpy]") + assert not ser.any() + assert not ser.all() + + ser = Series(["a", "b"], dtype="string[pyarrow_numpy]") + assert ser.any() + assert ser.all() + + def test_timedelta64_analytics(self): + # index min/max + dti = date_range("2012-1-1", periods=3, freq="D") + td = Series(dti) - Timestamp("20120101") + + result = td.idxmin() + assert result == 0 + + result = td.idxmax() + assert result == 2 + + # GH#2982 + # with NaT + td[0] = np.nan + + result = td.idxmin() + assert result == 1 + + result = td.idxmax() + assert result == 2 + + # abs + s1 = Series(date_range("20120101", periods=3)) + s2 = Series(date_range("20120102", periods=3)) + expected = Series(s2 - s1) + + result = np.abs(s1 - s2) + tm.assert_series_equal(result, expected) + + result = (s1 - s2).abs() + tm.assert_series_equal(result, expected) + + # max/min + result = td.max() + expected = Timedelta("2 days") + assert result == expected + + result = td.min() + expected = Timedelta("1 days") + assert result == expected + + @pytest.mark.parametrize( + "test_input,error_type", + [ + (Series([], dtype="float64"), ValueError), + # For strings, or any Series with dtype 'O' + (Series(["foo", "bar", "baz"]), TypeError), + (Series([(1,), (2,)]), TypeError), + # For mixed data types + (Series(["foo", "foo", "bar", "bar", None, np.nan, "baz"]), TypeError), + ], + ) + def test_assert_idxminmax_empty_raises(self, test_input, error_type): + """ + Cases where ``Series.argmax`` and related should raise an exception + """ + test_input = Series([], dtype="float64") + msg = "attempt to get argmin of an empty sequence" + with pytest.raises(ValueError, match=msg): + test_input.idxmin() + with pytest.raises(ValueError, match=msg): + test_input.idxmin(skipna=False) + msg = "attempt to get argmax of an empty sequence" + with pytest.raises(ValueError, match=msg): + test_input.idxmax() + with pytest.raises(ValueError, match=msg): + test_input.idxmax(skipna=False) + + def test_idxminmax_object_dtype(self, using_infer_string): + # pre-2.1 object-dtype was disallowed for argmin/max + ser = Series(["foo", "bar", "baz"]) + assert ser.idxmax() == 0 + assert ser.idxmax(skipna=False) == 0 + assert ser.idxmin() == 1 + assert ser.idxmin(skipna=False) == 1 + + ser2 = Series([(1,), (2,)]) + assert ser2.idxmax() == 1 + assert ser2.idxmax(skipna=False) == 1 + assert ser2.idxmin() == 0 + assert ser2.idxmin(skipna=False) == 0 + + if not using_infer_string: + # attempting to compare np.nan with string raises + ser3 = Series(["foo", "foo", "bar", "bar", None, np.nan, "baz"]) + msg = "'>' not supported between instances of 'float' and 'str'" + with pytest.raises(TypeError, match=msg): + ser3.idxmax() + with pytest.raises(TypeError, match=msg): + ser3.idxmax(skipna=False) + msg = "'<' not supported between instances of 'float' and 'str'" + with pytest.raises(TypeError, match=msg): + ser3.idxmin() + with pytest.raises(TypeError, match=msg): + ser3.idxmin(skipna=False) + + def test_idxminmax_object_frame(self): + # GH#4279 + df = DataFrame([["zimm", 2.5], ["biff", 1.0], ["bid", 12.0]]) + res = df.idxmax() + exp = Series([0, 2]) + tm.assert_series_equal(res, exp) + + def test_idxminmax_object_tuples(self): + # GH#43697 + ser = Series([(1, 3), (2, 2), (3, 1)]) + assert ser.idxmax() == 2 + assert ser.idxmin() == 0 + assert ser.idxmax(skipna=False) == 2 + assert ser.idxmin(skipna=False) == 0 + + def test_idxminmax_object_decimals(self): + # GH#40685 + df = DataFrame( + { + "idx": [0, 1], + "x": [Decimal("8.68"), Decimal("42.23")], + "y": [Decimal("7.11"), Decimal("79.61")], + } + ) + res = df.idxmax() + exp = Series({"idx": 1, "x": 1, "y": 1}) + tm.assert_series_equal(res, exp) + + res2 = df.idxmin() + exp2 = exp - 1 + tm.assert_series_equal(res2, exp2) + + def test_argminmax_object_ints(self): + # GH#18021 + ser = Series([0, 1], dtype="object") + assert ser.argmax() == 1 + assert ser.argmin() == 0 + assert ser.argmax(skipna=False) == 1 + assert ser.argmin(skipna=False) == 0 + + def test_idxminmax_with_inf(self): + # For numeric data with NA and Inf (GH #13595) + s = Series([0, -np.inf, np.inf, np.nan]) + + assert s.idxmin() == 1 + msg = "The behavior of Series.idxmin with all-NA values" + with tm.assert_produces_warning(FutureWarning, match=msg): + assert np.isnan(s.idxmin(skipna=False)) + + assert s.idxmax() == 2 + msg = "The behavior of Series.idxmax with all-NA values" + with tm.assert_produces_warning(FutureWarning, match=msg): + assert np.isnan(s.idxmax(skipna=False)) + + msg = "use_inf_as_na option is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + # Using old-style behavior that treats floating point nan, -inf, and + # +inf as missing + with pd.option_context("mode.use_inf_as_na", True): + assert s.idxmin() == 0 + assert np.isnan(s.idxmin(skipna=False)) + assert s.idxmax() == 0 + np.isnan(s.idxmax(skipna=False)) + + def test_sum_uint64(self): + # GH 53401 + s = Series([10000000000000000000], dtype="uint64") + result = s.sum() + expected = np.uint64(10000000000000000000) + tm.assert_almost_equal(result, expected) + + +class TestDatetime64SeriesReductions: + # Note: the name TestDatetime64SeriesReductions indicates these tests + # were moved from a series-specific test file, _not_ that these tests are + # intended long-term to be series-specific + + @pytest.mark.parametrize( + "nat_ser", + [ + Series([NaT, NaT]), + Series([NaT, Timedelta("nat")]), + Series([Timedelta("nat"), Timedelta("nat")]), + ], + ) + def test_minmax_nat_series(self, nat_ser): + # GH#23282 + assert nat_ser.min() is NaT + assert nat_ser.max() is NaT + assert nat_ser.min(skipna=False) is NaT + assert nat_ser.max(skipna=False) is NaT + + @pytest.mark.parametrize( + "nat_df", + [ + DataFrame([NaT, NaT]), + DataFrame([NaT, Timedelta("nat")]), + DataFrame([Timedelta("nat"), Timedelta("nat")]), + ], + ) + def test_minmax_nat_dataframe(self, nat_df): + # GH#23282 + assert nat_df.min()[0] is NaT + assert nat_df.max()[0] is NaT + assert nat_df.min(skipna=False)[0] is NaT + assert nat_df.max(skipna=False)[0] is NaT + + def test_min_max(self): + rng = date_range("1/1/2000", "12/31/2000") + rng2 = rng.take(np.random.default_rng(2).permutation(len(rng))) + + the_min = rng2.min() + the_max = rng2.max() + assert isinstance(the_min, Timestamp) + assert isinstance(the_max, Timestamp) + assert the_min == rng[0] + assert the_max == rng[-1] + + assert rng.min() == rng[0] + assert rng.max() == rng[-1] + + def test_min_max_series(self): + rng = date_range("1/1/2000", periods=10, freq="4h") + lvls = ["A", "A", "A", "B", "B", "B", "C", "C", "C", "C"] + df = DataFrame( + { + "TS": rng, + "V": np.random.default_rng(2).standard_normal(len(rng)), + "L": lvls, + } + ) + + result = df.TS.max() + exp = Timestamp(df.TS.iat[-1]) + assert isinstance(result, Timestamp) + assert result == exp + + result = df.TS.min() + exp = Timestamp(df.TS.iat[0]) + assert isinstance(result, Timestamp) + assert result == exp + + +class TestCategoricalSeriesReductions: + # Note: the name TestCategoricalSeriesReductions indicates these tests + # were moved from a series-specific test file, _not_ that these tests are + # intended long-term to be series-specific + + @pytest.mark.parametrize("function", ["min", "max"]) + def test_min_max_unordered_raises(self, function): + # unordered cats have no min/max + cat = Series(Categorical(["a", "b", "c", "d"], ordered=False)) + msg = f"Categorical is not ordered for operation {function}" + with pytest.raises(TypeError, match=msg): + getattr(cat, function)() + + @pytest.mark.parametrize( + "values, categories", + [ + (list("abc"), list("abc")), + (list("abc"), list("cba")), + (list("abc") + [np.nan], list("cba")), + ([1, 2, 3], [3, 2, 1]), + ([1, 2, 3, np.nan], [3, 2, 1]), + ], + ) + @pytest.mark.parametrize("function", ["min", "max"]) + def test_min_max_ordered(self, values, categories, function): + # GH 25303 + cat = Series(Categorical(values, categories=categories, ordered=True)) + result = getattr(cat, function)(skipna=True) + expected = categories[0] if function == "min" else categories[2] + assert result == expected + + @pytest.mark.parametrize("function", ["min", "max"]) + @pytest.mark.parametrize("skipna", [True, False]) + def test_min_max_ordered_with_nan_only(self, function, skipna): + # https://github.com/pandas-dev/pandas/issues/33450 + cat = Series(Categorical([np.nan], categories=[1, 2], ordered=True)) + result = getattr(cat, function)(skipna=skipna) + assert result is np.nan + + @pytest.mark.parametrize("function", ["min", "max"]) + @pytest.mark.parametrize("skipna", [True, False]) + def test_min_max_skipna(self, function, skipna): + cat = Series( + Categorical(["a", "b", np.nan, "a"], categories=["b", "a"], ordered=True) + ) + result = getattr(cat, function)(skipna=skipna) + + if skipna is True: + expected = "b" if function == "min" else "a" + assert result == expected + else: + assert result is np.nan + + +class TestSeriesMode: + # Note: the name TestSeriesMode indicates these tests + # were moved from a series-specific test file, _not_ that these tests are + # intended long-term to be series-specific + + @pytest.mark.parametrize( + "dropna, expected", + [(True, Series([], dtype=np.float64)), (False, Series([], dtype=np.float64))], + ) + def test_mode_empty(self, dropna, expected): + s = Series([], dtype=np.float64) + result = s.mode(dropna) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "dropna, data, expected", + [ + (True, [1, 1, 1, 2], [1]), + (True, [1, 1, 1, 2, 3, 3, 3], [1, 3]), + (False, [1, 1, 1, 2], [1]), + (False, [1, 1, 1, 2, 3, 3, 3], [1, 3]), + ], + ) + @pytest.mark.parametrize( + "dt", list(np.typecodes["AllInteger"] + np.typecodes["Float"]) + ) + def test_mode_numerical(self, dropna, data, expected, dt): + s = Series(data, dtype=dt) + result = s.mode(dropna) + expected = Series(expected, dtype=dt) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("dropna, expected", [(True, [1.0]), (False, [1, np.nan])]) + def test_mode_numerical_nan(self, dropna, expected): + s = Series([1, 1, 2, np.nan, np.nan]) + result = s.mode(dropna) + expected = Series(expected) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "dropna, expected1, expected2, expected3", + [(True, ["b"], ["bar"], ["nan"]), (False, ["b"], [np.nan], ["nan"])], + ) + def test_mode_str_obj(self, dropna, expected1, expected2, expected3): + # Test string and object types. + data = ["a"] * 2 + ["b"] * 3 + + s = Series(data, dtype="c") + result = s.mode(dropna) + expected1 = Series(expected1, dtype="c") + tm.assert_series_equal(result, expected1) + + data = ["foo", "bar", "bar", np.nan, np.nan, np.nan] + + s = Series(data, dtype=object) + result = s.mode(dropna) + expected2 = Series(expected2, dtype=None if expected2 == ["bar"] else object) + tm.assert_series_equal(result, expected2) + + data = ["foo", "bar", "bar", np.nan, np.nan, np.nan] + + s = Series(data, dtype=object).astype(str) + result = s.mode(dropna) + expected3 = Series(expected3) + tm.assert_series_equal(result, expected3) + + @pytest.mark.parametrize( + "dropna, expected1, expected2", + [(True, ["foo"], ["foo"]), (False, ["foo"], [np.nan])], + ) + def test_mode_mixeddtype(self, dropna, expected1, expected2): + s = Series([1, "foo", "foo"]) + result = s.mode(dropna) + expected = Series(expected1) + tm.assert_series_equal(result, expected) + + s = Series([1, "foo", "foo", np.nan, np.nan, np.nan]) + result = s.mode(dropna) + expected = Series(expected2, dtype=None if expected2 == ["foo"] else object) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "dropna, expected1, expected2", + [ + ( + True, + ["1900-05-03", "2011-01-03", "2013-01-02"], + ["2011-01-03", "2013-01-02"], + ), + (False, [np.nan], [np.nan, "2011-01-03", "2013-01-02"]), + ], + ) + def test_mode_datetime(self, dropna, expected1, expected2): + s = Series( + ["2011-01-03", "2013-01-02", "1900-05-03", "nan", "nan"], dtype="M8[ns]" + ) + result = s.mode(dropna) + expected1 = Series(expected1, dtype="M8[ns]") + tm.assert_series_equal(result, expected1) + + s = Series( + [ + "2011-01-03", + "2013-01-02", + "1900-05-03", + "2011-01-03", + "2013-01-02", + "nan", + "nan", + ], + dtype="M8[ns]", + ) + result = s.mode(dropna) + expected2 = Series(expected2, dtype="M8[ns]") + tm.assert_series_equal(result, expected2) + + @pytest.mark.parametrize( + "dropna, expected1, expected2", + [ + (True, ["-1 days", "0 days", "1 days"], ["2 min", "1 day"]), + (False, [np.nan], [np.nan, "2 min", "1 day"]), + ], + ) + def test_mode_timedelta(self, dropna, expected1, expected2): + # gh-5986: Test timedelta types. + + s = Series( + ["1 days", "-1 days", "0 days", "nan", "nan"], dtype="timedelta64[ns]" + ) + result = s.mode(dropna) + expected1 = Series(expected1, dtype="timedelta64[ns]") + tm.assert_series_equal(result, expected1) + + s = Series( + [ + "1 day", + "1 day", + "-1 day", + "-1 day 2 min", + "2 min", + "2 min", + "nan", + "nan", + ], + dtype="timedelta64[ns]", + ) + result = s.mode(dropna) + expected2 = Series(expected2, dtype="timedelta64[ns]") + tm.assert_series_equal(result, expected2) + + @pytest.mark.parametrize( + "dropna, expected1, expected2, expected3", + [ + ( + True, + Categorical([1, 2], categories=[1, 2]), + Categorical(["a"], categories=[1, "a"]), + Categorical([3, 1], categories=[3, 2, 1], ordered=True), + ), + ( + False, + Categorical([np.nan], categories=[1, 2]), + Categorical([np.nan, "a"], categories=[1, "a"]), + Categorical([np.nan, 3, 1], categories=[3, 2, 1], ordered=True), + ), + ], + ) + def test_mode_category(self, dropna, expected1, expected2, expected3): + s = Series(Categorical([1, 2, np.nan, np.nan])) + result = s.mode(dropna) + expected1 = Series(expected1, dtype="category") + tm.assert_series_equal(result, expected1) + + s = Series(Categorical([1, "a", "a", np.nan, np.nan])) + result = s.mode(dropna) + expected2 = Series(expected2, dtype="category") + tm.assert_series_equal(result, expected2) + + s = Series( + Categorical( + [1, 1, 2, 3, 3, np.nan, np.nan], categories=[3, 2, 1], ordered=True + ) + ) + result = s.mode(dropna) + expected3 = Series(expected3, dtype="category") + tm.assert_series_equal(result, expected3) + + @pytest.mark.parametrize( + "dropna, expected1, expected2", + [(True, [2**63], [1, 2**63]), (False, [2**63], [1, 2**63])], + ) + def test_mode_intoverflow(self, dropna, expected1, expected2): + # Test for uint64 overflow. + s = Series([1, 2**63, 2**63], dtype=np.uint64) + result = s.mode(dropna) + expected1 = Series(expected1, dtype=np.uint64) + tm.assert_series_equal(result, expected1) + + s = Series([1, 2**63], dtype=np.uint64) + result = s.mode(dropna) + expected2 = Series(expected2, dtype=np.uint64) + tm.assert_series_equal(result, expected2) + + def test_mode_sortwarning(self): + # Check for the warning that is raised when the mode + # results cannot be sorted + + expected = Series(["foo", np.nan]) + s = Series([1, "foo", "foo", np.nan, np.nan]) + + with tm.assert_produces_warning(UserWarning): + result = s.mode(dropna=False) + result = result.sort_values().reset_index(drop=True) + + tm.assert_series_equal(result, expected) + + def test_mode_boolean_with_na(self): + # GH#42107 + ser = Series([True, False, True, pd.NA], dtype="boolean") + result = ser.mode() + expected = Series({0: True}, dtype="boolean") + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "array,expected,dtype", + [ + ( + [0, 1j, 1, 1, 1 + 1j, 1 + 2j], + Series([1], dtype=np.complex128), + np.complex128, + ), + ( + [0, 1j, 1, 1, 1 + 1j, 1 + 2j], + Series([1], dtype=np.complex64), + np.complex64, + ), + ( + [1 + 1j, 2j, 1 + 1j], + Series([1 + 1j], dtype=np.complex128), + np.complex128, + ), + ], + ) + def test_single_mode_value_complex(self, array, expected, dtype): + result = Series(array, dtype=dtype).mode() + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "array,expected,dtype", + [ + ( + # no modes + [0, 1j, 1, 1 + 1j, 1 + 2j], + Series([0j, 1j, 1 + 0j, 1 + 1j, 1 + 2j], dtype=np.complex128), + np.complex128, + ), + ( + [1 + 1j, 2j, 1 + 1j, 2j, 3], + Series([2j, 1 + 1j], dtype=np.complex64), + np.complex64, + ), + ], + ) + def test_multimode_complex(self, array, expected, dtype): + # GH 17927 + # mode tries to sort multimodal series. + # Complex numbers are sorted by their magnitude + result = Series(array, dtype=dtype).mode() + tm.assert_series_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reductions/test_stat_reductions.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reductions/test_stat_reductions.py new file mode 100644 index 0000000000000000000000000000000000000000..8fbb78737474c8abf34b8720603e32f6a93d83e7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reductions/test_stat_reductions.py @@ -0,0 +1,276 @@ +""" +Tests for statistical reductions of 2nd moment or higher: var, skew, kurt, ... +""" +import inspect + +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + DataFrame, + Series, + date_range, +) +import pandas._testing as tm + + +class TestDatetimeLikeStatReductions: + @pytest.mark.parametrize("box", [Series, pd.Index, pd.array]) + def test_dt64_mean(self, tz_naive_fixture, box): + tz = tz_naive_fixture + + dti = date_range("2001-01-01", periods=11, tz=tz) + # shuffle so that we are not just working with monotone-increasing + dti = dti.take([4, 1, 3, 10, 9, 7, 8, 5, 0, 2, 6]) + dtarr = dti._data + + obj = box(dtarr) + assert obj.mean() == pd.Timestamp("2001-01-06", tz=tz) + assert obj.mean(skipna=False) == pd.Timestamp("2001-01-06", tz=tz) + + # dtarr[-2] will be the first date 2001-01-1 + dtarr[-2] = pd.NaT + + obj = box(dtarr) + assert obj.mean() == pd.Timestamp("2001-01-06 07:12:00", tz=tz) + assert obj.mean(skipna=False) is pd.NaT + + @pytest.mark.parametrize("box", [Series, pd.Index, pd.array]) + @pytest.mark.parametrize("freq", ["s", "h", "D", "W", "B"]) + def test_period_mean(self, box, freq): + # GH#24757 + dti = date_range("2001-01-01", periods=11) + # shuffle so that we are not just working with monotone-increasing + dti = dti.take([4, 1, 3, 10, 9, 7, 8, 5, 0, 2, 6]) + + warn = FutureWarning if freq == "B" else None + msg = r"PeriodDtype\[B\] is deprecated" + with tm.assert_produces_warning(warn, match=msg): + parr = dti._data.to_period(freq) + obj = box(parr) + with pytest.raises(TypeError, match="ambiguous"): + obj.mean() + with pytest.raises(TypeError, match="ambiguous"): + obj.mean(skipna=True) + + # parr[-2] will be the first date 2001-01-1 + parr[-2] = pd.NaT + + with pytest.raises(TypeError, match="ambiguous"): + obj.mean() + with pytest.raises(TypeError, match="ambiguous"): + obj.mean(skipna=True) + + @pytest.mark.parametrize("box", [Series, pd.Index, pd.array]) + def test_td64_mean(self, box): + m8values = np.array([0, 3, -2, -7, 1, 2, -1, 3, 5, -2, 4], "m8[D]") + tdi = pd.TimedeltaIndex(m8values).as_unit("ns") + + tdarr = tdi._data + obj = box(tdarr, copy=False) + + result = obj.mean() + expected = np.array(tdarr).mean() + assert result == expected + + tdarr[0] = pd.NaT + assert obj.mean(skipna=False) is pd.NaT + + result2 = obj.mean(skipna=True) + assert result2 == tdi[1:].mean() + + # exact equality fails by 1 nanosecond + assert result2.round("us") == (result * 11.0 / 10).round("us") + + +class TestSeriesStatReductions: + # Note: the name TestSeriesStatReductions indicates these tests + # were moved from a series-specific test file, _not_ that these tests are + # intended long-term to be series-specific + + def _check_stat_op( + self, name, alternate, string_series_, check_objects=False, check_allna=False + ): + with pd.option_context("use_bottleneck", False): + f = getattr(Series, name) + + # add some NaNs + string_series_[5:15] = np.nan + + # mean, idxmax, idxmin, min, and max are valid for dates + if name not in ["max", "min", "mean", "median", "std"]: + ds = Series(date_range("1/1/2001", periods=10)) + msg = f"does not support reduction '{name}'" + with pytest.raises(TypeError, match=msg): + f(ds) + + # skipna or no + assert pd.notna(f(string_series_)) + assert pd.isna(f(string_series_, skipna=False)) + + # check the result is correct + nona = string_series_.dropna() + tm.assert_almost_equal(f(nona), alternate(nona.values)) + tm.assert_almost_equal(f(string_series_), alternate(nona.values)) + + allna = string_series_ * np.nan + + if check_allna: + assert np.isnan(f(allna)) + + # dtype=object with None, it works! + s = Series([1, 2, 3, None, 5]) + f(s) + + # GH#2888 + items = [0] + items.extend(range(2**40, 2**40 + 1000)) + s = Series(items, dtype="int64") + tm.assert_almost_equal(float(f(s)), float(alternate(s.values))) + + # check date range + if check_objects: + s = Series(pd.bdate_range("1/1/2000", periods=10)) + res = f(s) + exp = alternate(s) + assert res == exp + + # check on string data + if name not in ["sum", "min", "max"]: + with pytest.raises(TypeError, match=None): + f(Series(list("abc"))) + + # Invalid axis. + msg = "No axis named 1 for object type Series" + with pytest.raises(ValueError, match=msg): + f(string_series_, axis=1) + + if "numeric_only" in inspect.getfullargspec(f).args: + # only the index is string; dtype is float + f(string_series_, numeric_only=True) + + def test_sum(self): + string_series = Series(range(20), dtype=np.float64, name="series") + self._check_stat_op("sum", np.sum, string_series, check_allna=False) + + def test_mean(self): + string_series = Series(range(20), dtype=np.float64, name="series") + self._check_stat_op("mean", np.mean, string_series) + + def test_median(self): + string_series = Series(range(20), dtype=np.float64, name="series") + self._check_stat_op("median", np.median, string_series) + + # test with integers, test failure + int_ts = Series(np.ones(10, dtype=int), index=range(10)) + tm.assert_almost_equal(np.median(int_ts), int_ts.median()) + + def test_prod(self): + string_series = Series(range(20), dtype=np.float64, name="series") + self._check_stat_op("prod", np.prod, string_series) + + def test_min(self): + string_series = Series(range(20), dtype=np.float64, name="series") + self._check_stat_op("min", np.min, string_series, check_objects=True) + + def test_max(self): + string_series = Series(range(20), dtype=np.float64, name="series") + self._check_stat_op("max", np.max, string_series, check_objects=True) + + def test_var_std(self): + string_series = Series(range(20), dtype=np.float64, name="series") + datetime_series = Series( + np.arange(10, dtype=np.float64), + index=date_range("2020-01-01", periods=10), + name="ts", + ) + + alt = lambda x: np.std(x, ddof=1) + self._check_stat_op("std", alt, string_series) + + alt = lambda x: np.var(x, ddof=1) + self._check_stat_op("var", alt, string_series) + + result = datetime_series.std(ddof=4) + expected = np.std(datetime_series.values, ddof=4) + tm.assert_almost_equal(result, expected) + + result = datetime_series.var(ddof=4) + expected = np.var(datetime_series.values, ddof=4) + tm.assert_almost_equal(result, expected) + + # 1 - element series with ddof=1 + s = datetime_series.iloc[[0]] + result = s.var(ddof=1) + assert pd.isna(result) + + result = s.std(ddof=1) + assert pd.isna(result) + + def test_sem(self): + string_series = Series(range(20), dtype=np.float64, name="series") + datetime_series = Series( + np.arange(10, dtype=np.float64), + index=date_range("2020-01-01", periods=10), + name="ts", + ) + + alt = lambda x: np.std(x, ddof=1) / np.sqrt(len(x)) + self._check_stat_op("sem", alt, string_series) + + result = datetime_series.sem(ddof=4) + expected = np.std(datetime_series.values, ddof=4) / np.sqrt( + len(datetime_series.values) + ) + tm.assert_almost_equal(result, expected) + + # 1 - element series with ddof=1 + s = datetime_series.iloc[[0]] + result = s.sem(ddof=1) + assert pd.isna(result) + + def test_skew(self): + sp_stats = pytest.importorskip("scipy.stats") + + string_series = Series(range(20), dtype=np.float64, name="series") + + alt = lambda x: sp_stats.skew(x, bias=False) + self._check_stat_op("skew", alt, string_series) + + # test corner cases, skew() returns NaN unless there's at least 3 + # values + min_N = 3 + for i in range(1, min_N + 1): + s = Series(np.ones(i)) + df = DataFrame(np.ones((i, i))) + if i < min_N: + assert np.isnan(s.skew()) + assert np.isnan(df.skew()).all() + else: + assert 0 == s.skew() + assert isinstance(s.skew(), np.float64) # GH53482 + assert (df.skew() == 0).all() + + def test_kurt(self): + sp_stats = pytest.importorskip("scipy.stats") + + string_series = Series(range(20), dtype=np.float64, name="series") + + alt = lambda x: sp_stats.kurtosis(x, bias=False) + self._check_stat_op("kurt", alt, string_series) + + def test_kurt_corner(self): + # test corner cases, kurt() returns NaN unless there's at least 4 + # values + min_N = 4 + for i in range(1, min_N + 1): + s = Series(np.ones(i)) + df = DataFrame(np.ones((i, i))) + if i < min_N: + assert np.isnan(s.kurt()) + assert np.isnan(df.kurt()).all() + else: + assert 0 == s.kurt() + assert isinstance(s.kurt(), np.float64) # GH53482 + assert (df.kurt() == 0).all() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/resample/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/resample/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/resample/conftest.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/resample/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..1033d908eb22d38e58b46fd0e4a042dc31b3aca9 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/resample/conftest.py @@ -0,0 +1,143 @@ +from datetime import datetime + +import numpy as np +import pytest + +from pandas import ( + DataFrame, + Series, +) + +# The various methods we support +downsample_methods = [ + "min", + "max", + "first", + "last", + "sum", + "mean", + "sem", + "median", + "prod", + "var", + "std", + "ohlc", + "quantile", +] +upsample_methods = ["count", "size"] +series_methods = ["nunique"] +resample_methods = downsample_methods + upsample_methods + series_methods + + +@pytest.fixture(params=downsample_methods) +def downsample_method(request): + """Fixture for parametrization of Grouper downsample methods.""" + return request.param + + +@pytest.fixture(params=resample_methods) +def resample_method(request): + """Fixture for parametrization of Grouper resample methods.""" + return request.param + + +@pytest.fixture +def _index_start(): + """Fixture for parametrization of index, series and frame.""" + return datetime(2005, 1, 1) + + +@pytest.fixture +def _index_end(): + """Fixture for parametrization of index, series and frame.""" + return datetime(2005, 1, 10) + + +@pytest.fixture +def _index_freq(): + """Fixture for parametrization of index, series and frame.""" + return "D" + + +@pytest.fixture +def _index_name(): + """Fixture for parametrization of index, series and frame.""" + return None + + +@pytest.fixture +def index(_index_factory, _index_start, _index_end, _index_freq, _index_name): + """ + Fixture for parametrization of date_range, period_range and + timedelta_range indexes + """ + return _index_factory(_index_start, _index_end, freq=_index_freq, name=_index_name) + + +@pytest.fixture +def _static_values(index): + """ + Fixture for parametrization of values used in parametrization of + Series and DataFrames with date_range, period_range and + timedelta_range indexes + """ + return np.arange(len(index)) + + +@pytest.fixture +def _series_name(): + """ + Fixture for parametrization of Series name for Series used with + date_range, period_range and timedelta_range indexes + """ + return None + + +@pytest.fixture +def series(index, _series_name, _static_values): + """ + Fixture for parametrization of Series with date_range, period_range and + timedelta_range indexes + """ + return Series(_static_values, index=index, name=_series_name) + + +@pytest.fixture +def empty_series_dti(series): + """ + Fixture for parametrization of empty Series with date_range, + period_range and timedelta_range indexes + """ + return series[:0] + + +@pytest.fixture +def frame(index, _series_name, _static_values): + """ + Fixture for parametrization of DataFrame with date_range, period_range + and timedelta_range indexes + """ + # _series_name is intentionally unused + return DataFrame({"value": _static_values}, index=index) + + +@pytest.fixture +def empty_frame_dti(series): + """ + Fixture for parametrization of empty DataFrame with date_range, + period_range and timedelta_range indexes + """ + index = series.index[:0] + return DataFrame(index=index) + + +@pytest.fixture +def series_and_frame(frame_or_series, series, frame): + """ + Fixture for parametrization of Series and DataFrame with date_range, + period_range and timedelta_range indexes + """ + if frame_or_series == Series: + return series + if frame_or_series == DataFrame: + return frame diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/resample/test_base.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/resample/test_base.py new file mode 100644 index 0000000000000000000000000000000000000000..dcf6c6099abab66d7a3fc78606c0eaf2c4298421 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/resample/test_base.py @@ -0,0 +1,460 @@ +from datetime import datetime + +import numpy as np +import pytest + +from pandas.core.dtypes.common import is_extension_array_dtype + +import pandas as pd +from pandas import ( + DataFrame, + DatetimeIndex, + Index, + MultiIndex, + NaT, + PeriodIndex, + Series, + TimedeltaIndex, +) +import pandas._testing as tm +from pandas.core.groupby.groupby import DataError +from pandas.core.groupby.grouper import Grouper +from pandas.core.indexes.datetimes import date_range +from pandas.core.indexes.period import period_range +from pandas.core.indexes.timedeltas import timedelta_range +from pandas.core.resample import _asfreq_compat + +# a fixture value can be overridden by the test parameter value. Note that the +# value of the fixture can be overridden this way even if the test doesn't use +# it directly (doesn't mention it in the function prototype). +# see https://docs.pytest.org/en/latest/fixture.html#override-a-fixture-with-direct-test-parametrization # noqa: E501 +# in this module we override the fixture values defined in conftest.py +# tuples of '_index_factory,_series_name,_index_start,_index_end' +DATE_RANGE = (date_range, "dti", datetime(2005, 1, 1), datetime(2005, 1, 10)) +PERIOD_RANGE = (period_range, "pi", datetime(2005, 1, 1), datetime(2005, 1, 10)) +TIMEDELTA_RANGE = (timedelta_range, "tdi", "1 day", "10 day") + +all_ts = pytest.mark.parametrize( + "_index_factory,_series_name,_index_start,_index_end", + [DATE_RANGE, PERIOD_RANGE, TIMEDELTA_RANGE], +) + + +@pytest.fixture +def create_index(_index_factory): + def _create_index(*args, **kwargs): + """return the _index_factory created using the args, kwargs""" + return _index_factory(*args, **kwargs) + + return _create_index + + +@pytest.mark.parametrize("freq", ["2D", "1h"]) +@pytest.mark.parametrize( + "_index_factory,_series_name,_index_start,_index_end", [DATE_RANGE, TIMEDELTA_RANGE] +) +def test_asfreq(series_and_frame, freq, create_index): + obj = series_and_frame + + result = obj.resample(freq).asfreq() + new_index = create_index(obj.index[0], obj.index[-1], freq=freq) + expected = obj.reindex(new_index) + tm.assert_almost_equal(result, expected) + + +@pytest.mark.parametrize( + "_index_factory,_series_name,_index_start,_index_end", [DATE_RANGE, TIMEDELTA_RANGE] +) +def test_asfreq_fill_value(series, create_index): + # test for fill value during resampling, issue 3715 + + ser = series + + result = ser.resample("1h").asfreq() + new_index = create_index(ser.index[0], ser.index[-1], freq="1h") + expected = ser.reindex(new_index) + tm.assert_series_equal(result, expected) + + # Explicit cast to float to avoid implicit cast when setting None + frame = ser.astype("float").to_frame("value") + frame.iloc[1] = None + result = frame.resample("1h").asfreq(fill_value=4.0) + new_index = create_index(frame.index[0], frame.index[-1], freq="1h") + expected = frame.reindex(new_index, fill_value=4.0) + tm.assert_frame_equal(result, expected) + + +@all_ts +def test_resample_interpolate(frame): + # GH#12925 + df = frame + warn = None + if isinstance(df.index, PeriodIndex): + warn = FutureWarning + msg = "Resampling with a PeriodIndex is deprecated" + with tm.assert_produces_warning(warn, match=msg): + result = df.resample("1min").asfreq().interpolate() + expected = df.resample("1min").interpolate() + tm.assert_frame_equal(result, expected) + + +def test_raises_on_non_datetimelike_index(): + # this is a non datetimelike index + xp = DataFrame() + msg = ( + "Only valid with DatetimeIndex, TimedeltaIndex or PeriodIndex, " + "but got an instance of 'RangeIndex'" + ) + with pytest.raises(TypeError, match=msg): + xp.resample("YE") + + +@all_ts +@pytest.mark.parametrize("freq", ["ME", "D", "h"]) +def test_resample_empty_series(freq, empty_series_dti, resample_method): + # GH12771 & GH12868 + + ser = empty_series_dti + if freq == "ME" and isinstance(ser.index, TimedeltaIndex): + msg = ( + "Resampling on a TimedeltaIndex requires fixed-duration `freq`, " + "e.g. '24h' or '3D', not " + ) + with pytest.raises(ValueError, match=msg): + ser.resample(freq) + return + elif freq == "ME" and isinstance(ser.index, PeriodIndex): + # index is PeriodIndex, so convert to corresponding Period freq + freq = "M" + + warn = None + if isinstance(ser.index, PeriodIndex): + warn = FutureWarning + msg = "Resampling with a PeriodIndex is deprecated" + with tm.assert_produces_warning(warn, match=msg): + rs = ser.resample(freq) + result = getattr(rs, resample_method)() + + if resample_method == "ohlc": + expected = DataFrame( + [], index=ser.index[:0].copy(), columns=["open", "high", "low", "close"] + ) + expected.index = _asfreq_compat(ser.index, freq) + tm.assert_frame_equal(result, expected, check_dtype=False) + else: + expected = ser.copy() + expected.index = _asfreq_compat(ser.index, freq) + tm.assert_series_equal(result, expected, check_dtype=False) + + tm.assert_index_equal(result.index, expected.index) + assert result.index.freq == expected.index.freq + + +@all_ts +@pytest.mark.parametrize( + "freq", + [ + pytest.param("ME", marks=pytest.mark.xfail(reason="Don't know why this fails")), + "D", + "h", + ], +) +def test_resample_nat_index_series(freq, series, resample_method): + # GH39227 + + ser = series.copy() + ser.index = PeriodIndex([NaT] * len(ser), freq=freq) + + msg = "Resampling with a PeriodIndex is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + rs = ser.resample(freq) + result = getattr(rs, resample_method)() + + if resample_method == "ohlc": + expected = DataFrame( + [], index=ser.index[:0].copy(), columns=["open", "high", "low", "close"] + ) + tm.assert_frame_equal(result, expected, check_dtype=False) + else: + expected = ser[:0].copy() + tm.assert_series_equal(result, expected, check_dtype=False) + tm.assert_index_equal(result.index, expected.index) + assert result.index.freq == expected.index.freq + + +@all_ts +@pytest.mark.parametrize("freq", ["ME", "D", "h"]) +@pytest.mark.parametrize("resample_method", ["count", "size"]) +def test_resample_count_empty_series(freq, empty_series_dti, resample_method): + # GH28427 + ser = empty_series_dti + if freq == "ME" and isinstance(ser.index, TimedeltaIndex): + msg = ( + "Resampling on a TimedeltaIndex requires fixed-duration `freq`, " + "e.g. '24h' or '3D', not " + ) + with pytest.raises(ValueError, match=msg): + ser.resample(freq) + return + elif freq == "ME" and isinstance(ser.index, PeriodIndex): + # index is PeriodIndex, so convert to corresponding Period freq + freq = "M" + + warn = None + if isinstance(ser.index, PeriodIndex): + warn = FutureWarning + msg = "Resampling with a PeriodIndex is deprecated" + with tm.assert_produces_warning(warn, match=msg): + rs = ser.resample(freq) + + result = getattr(rs, resample_method)() + + index = _asfreq_compat(ser.index, freq) + + expected = Series([], dtype="int64", index=index, name=ser.name) + + tm.assert_series_equal(result, expected) + + +@all_ts +@pytest.mark.parametrize("freq", ["ME", "D", "h"]) +def test_resample_empty_dataframe(empty_frame_dti, freq, resample_method): + # GH13212 + df = empty_frame_dti + # count retains dimensions too + if freq == "ME" and isinstance(df.index, TimedeltaIndex): + msg = ( + "Resampling on a TimedeltaIndex requires fixed-duration `freq`, " + "e.g. '24h' or '3D', not " + ) + with pytest.raises(ValueError, match=msg): + df.resample(freq, group_keys=False) + return + elif freq == "ME" and isinstance(df.index, PeriodIndex): + # index is PeriodIndex, so convert to corresponding Period freq + freq = "M" + + warn = None + if isinstance(df.index, PeriodIndex): + warn = FutureWarning + msg = "Resampling with a PeriodIndex is deprecated" + with tm.assert_produces_warning(warn, match=msg): + rs = df.resample(freq, group_keys=False) + result = getattr(rs, resample_method)() + if resample_method == "ohlc": + # TODO: no tests with len(df.columns) > 0 + mi = MultiIndex.from_product([df.columns, ["open", "high", "low", "close"]]) + expected = DataFrame( + [], index=df.index[:0].copy(), columns=mi, dtype=np.float64 + ) + expected.index = _asfreq_compat(df.index, freq) + + elif resample_method != "size": + expected = df.copy() + else: + # GH14962 + expected = Series([], dtype=np.int64) + + expected.index = _asfreq_compat(df.index, freq) + + tm.assert_index_equal(result.index, expected.index) + assert result.index.freq == expected.index.freq + tm.assert_almost_equal(result, expected) + + # test size for GH13212 (currently stays as df) + + +@all_ts +@pytest.mark.parametrize("freq", ["ME", "D", "h"]) +def test_resample_count_empty_dataframe(freq, empty_frame_dti): + # GH28427 + + empty_frame_dti["a"] = [] + + if freq == "ME" and isinstance(empty_frame_dti.index, TimedeltaIndex): + msg = ( + "Resampling on a TimedeltaIndex requires fixed-duration `freq`, " + "e.g. '24h' or '3D', not " + ) + with pytest.raises(ValueError, match=msg): + empty_frame_dti.resample(freq) + return + elif freq == "ME" and isinstance(empty_frame_dti.index, PeriodIndex): + # index is PeriodIndex, so convert to corresponding Period freq + freq = "M" + + warn = None + if isinstance(empty_frame_dti.index, PeriodIndex): + warn = FutureWarning + msg = "Resampling with a PeriodIndex is deprecated" + with tm.assert_produces_warning(warn, match=msg): + rs = empty_frame_dti.resample(freq) + result = rs.count() + + index = _asfreq_compat(empty_frame_dti.index, freq) + + expected = DataFrame(dtype="int64", index=index, columns=Index(["a"], dtype=object)) + + tm.assert_frame_equal(result, expected) + + +@all_ts +@pytest.mark.parametrize("freq", ["ME", "D", "h"]) +def test_resample_size_empty_dataframe(freq, empty_frame_dti): + # GH28427 + + empty_frame_dti["a"] = [] + + if freq == "ME" and isinstance(empty_frame_dti.index, TimedeltaIndex): + msg = ( + "Resampling on a TimedeltaIndex requires fixed-duration `freq`, " + "e.g. '24h' or '3D', not " + ) + with pytest.raises(ValueError, match=msg): + empty_frame_dti.resample(freq) + return + elif freq == "ME" and isinstance(empty_frame_dti.index, PeriodIndex): + # index is PeriodIndex, so convert to corresponding Period freq + freq = "M" + + msg = "Resampling with a PeriodIndex" + warn = None + if isinstance(empty_frame_dti.index, PeriodIndex): + warn = FutureWarning + with tm.assert_produces_warning(warn, match=msg): + rs = empty_frame_dti.resample(freq) + result = rs.size() + + index = _asfreq_compat(empty_frame_dti.index, freq) + + expected = Series([], dtype="int64", index=index) + + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "index", + [ + PeriodIndex([], freq="M", name="a"), + DatetimeIndex([], name="a"), + TimedeltaIndex([], name="a"), + ], +) +@pytest.mark.parametrize("dtype", [float, int, object, "datetime64[ns]"]) +@pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning") +def test_resample_empty_dtypes(index, dtype, resample_method): + # Empty series were sometimes causing a segfault (for the functions + # with Cython bounds-checking disabled) or an IndexError. We just run + # them to ensure they no longer do. (GH #10228) + warn = None + if isinstance(index, PeriodIndex): + # GH#53511 + index = PeriodIndex([], freq="B", name=index.name) + warn = FutureWarning + msg = "Resampling with a PeriodIndex is deprecated" + + empty_series_dti = Series([], index, dtype) + with tm.assert_produces_warning(warn, match=msg): + rs = empty_series_dti.resample("d", group_keys=False) + try: + getattr(rs, resample_method)() + except DataError: + # Ignore these since some combinations are invalid + # (ex: doing mean with dtype of np.object_) + pass + + +@all_ts +@pytest.mark.parametrize("freq", ["ME", "D", "h"]) +def test_apply_to_empty_series(empty_series_dti, freq): + # GH 14313 + ser = empty_series_dti + + if freq == "ME" and isinstance(empty_series_dti.index, TimedeltaIndex): + msg = ( + "Resampling on a TimedeltaIndex requires fixed-duration `freq`, " + "e.g. '24h' or '3D', not " + ) + with pytest.raises(ValueError, match=msg): + empty_series_dti.resample(freq) + return + elif freq == "ME" and isinstance(empty_series_dti.index, PeriodIndex): + # index is PeriodIndex, so convert to corresponding Period freq + freq = "M" + + msg = "Resampling with a PeriodIndex" + warn = None + if isinstance(empty_series_dti.index, PeriodIndex): + warn = FutureWarning + + with tm.assert_produces_warning(warn, match=msg): + rs = ser.resample(freq, group_keys=False) + + result = rs.apply(lambda x: 1) + with tm.assert_produces_warning(warn, match=msg): + expected = ser.resample(freq).apply("sum") + + tm.assert_series_equal(result, expected, check_dtype=False) + + +@all_ts +def test_resampler_is_iterable(series): + # GH 15314 + freq = "h" + tg = Grouper(freq=freq, convention="start") + msg = "Resampling with a PeriodIndex" + warn = None + if isinstance(series.index, PeriodIndex): + warn = FutureWarning + + with tm.assert_produces_warning(warn, match=msg): + grouped = series.groupby(tg) + + with tm.assert_produces_warning(warn, match=msg): + resampled = series.resample(freq) + for (rk, rv), (gk, gv) in zip(resampled, grouped): + assert rk == gk + tm.assert_series_equal(rv, gv) + + +@all_ts +def test_resample_quantile(series): + # GH 15023 + ser = series + q = 0.75 + freq = "h" + + msg = "Resampling with a PeriodIndex" + warn = None + if isinstance(series.index, PeriodIndex): + warn = FutureWarning + with tm.assert_produces_warning(warn, match=msg): + result = ser.resample(freq).quantile(q) + expected = ser.resample(freq).agg(lambda x: x.quantile(q)).rename(ser.name) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("how", ["first", "last"]) +def test_first_last_skipna(any_real_nullable_dtype, skipna, how): + # GH#57019 + if is_extension_array_dtype(any_real_nullable_dtype): + na_value = Series(dtype=any_real_nullable_dtype).dtype.na_value + else: + na_value = np.nan + df = DataFrame( + { + "a": [2, 1, 1, 2], + "b": [na_value, 3.0, na_value, 4.0], + "c": [na_value, 3.0, na_value, 4.0], + }, + index=date_range("2020-01-01", periods=4, freq="D"), + dtype=any_real_nullable_dtype, + ) + rs = df.resample("ME") + method = getattr(rs, how) + result = method(skipna=skipna) + + gb = df.groupby(df.shape[0] * [pd.to_datetime("2020-01-31")]) + expected = getattr(gb, how)(skipna=skipna) + expected.index.freq = "ME" + tm.assert_frame_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/resample/test_datetime_index.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/resample/test_datetime_index.py new file mode 100644 index 0000000000000000000000000000000000000000..ddd81ab1d347d315de0aba6e75908a18beb923ec --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/resample/test_datetime_index.py @@ -0,0 +1,2231 @@ +from datetime import datetime +from functools import partial + +import numpy as np +import pytest +import pytz + +from pandas._libs import lib +from pandas._typing import DatetimeNaTType +from pandas.compat import is_platform_windows +import pandas.util._test_decorators as td + +import pandas as pd +from pandas import ( + DataFrame, + Index, + Series, + Timedelta, + Timestamp, + isna, + notna, +) +import pandas._testing as tm +from pandas.core.groupby.grouper import Grouper +from pandas.core.indexes.datetimes import date_range +from pandas.core.indexes.period import ( + Period, + period_range, +) +from pandas.core.resample import ( + DatetimeIndex, + _get_timestamp_range_edges, +) + +from pandas.tseries import offsets +from pandas.tseries.offsets import Minute + + +@pytest.fixture() +def _index_factory(): + return date_range + + +@pytest.fixture +def _index_freq(): + return "Min" + + +@pytest.fixture +def _static_values(index): + return np.random.default_rng(2).random(len(index)) + + +@pytest.fixture(params=["s", "ms", "us", "ns"]) +def unit(request): + return request.param + + +@pytest.fixture +def simple_date_range_series(): + """ + Series with date range index and random data for test purposes. + """ + + def _simple_date_range_series(start, end, freq="D"): + rng = date_range(start, end, freq=freq) + return Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + + return _simple_date_range_series + + +def test_custom_grouper(index, unit): + dti = index.as_unit(unit) + s = Series(np.array([1] * len(dti)), index=dti, dtype="int64") + + b = Grouper(freq=Minute(5)) + g = s.groupby(b) + + # check all cython functions work + g.ohlc() # doesn't use _cython_agg_general + funcs = ["sum", "mean", "prod", "min", "max", "var"] + for f in funcs: + g._cython_agg_general(f, alt=None, numeric_only=True) + + b = Grouper(freq=Minute(5), closed="right", label="right") + g = s.groupby(b) + # check all cython functions work + g.ohlc() # doesn't use _cython_agg_general + funcs = ["sum", "mean", "prod", "min", "max", "var"] + for f in funcs: + g._cython_agg_general(f, alt=None, numeric_only=True) + + assert g.ngroups == 2593 + assert notna(g.mean()).all() + + # construct expected val + arr = [1] + [5] * 2592 + idx = dti[0:-1:5] + idx = idx.append(dti[-1:]) + idx = DatetimeIndex(idx, freq="5min").as_unit(unit) + expect = Series(arr, index=idx) + + # GH2763 - return input dtype if we can + result = g.agg("sum") + tm.assert_series_equal(result, expect) + + +def test_custom_grouper_df(index, unit): + b = Grouper(freq=Minute(5), closed="right", label="right") + dti = index.as_unit(unit) + df = DataFrame( + np.random.default_rng(2).random((len(dti), 10)), index=dti, dtype="float64" + ) + r = df.groupby(b).agg("sum") + + assert len(r.columns) == 10 + assert len(r.index) == 2593 + + +@pytest.mark.parametrize( + "_index_start,_index_end,_index_name", + [("1/1/2000 00:00:00", "1/1/2000 00:13:00", "index")], +) +@pytest.mark.parametrize( + "closed, expected", + [ + ( + "right", + lambda s: Series( + [s.iloc[0], s[1:6].mean(), s[6:11].mean(), s[11:].mean()], + index=date_range("1/1/2000", periods=4, freq="5min", name="index"), + ), + ), + ( + "left", + lambda s: Series( + [s[:5].mean(), s[5:10].mean(), s[10:].mean()], + index=date_range( + "1/1/2000 00:05", periods=3, freq="5min", name="index" + ), + ), + ), + ], +) +def test_resample_basic(series, closed, expected, unit): + s = series + s.index = s.index.as_unit(unit) + expected = expected(s) + expected.index = expected.index.as_unit(unit) + result = s.resample("5min", closed=closed, label="right").mean() + tm.assert_series_equal(result, expected) + + +def test_resample_integerarray(unit): + # GH 25580, resample on IntegerArray + ts = Series( + range(9), + index=date_range("1/1/2000", periods=9, freq="min").as_unit(unit), + dtype="Int64", + ) + result = ts.resample("3min").sum() + expected = Series( + [3, 12, 21], + index=date_range("1/1/2000", periods=3, freq="3min").as_unit(unit), + dtype="Int64", + ) + tm.assert_series_equal(result, expected) + + result = ts.resample("3min").mean() + expected = Series( + [1, 4, 7], + index=date_range("1/1/2000", periods=3, freq="3min").as_unit(unit), + dtype="Float64", + ) + tm.assert_series_equal(result, expected) + + +def test_resample_basic_grouper(series, unit): + s = series + s.index = s.index.as_unit(unit) + result = s.resample("5Min").last() + grouper = Grouper(freq=Minute(5), closed="left", label="left") + expected = s.groupby(grouper).agg(lambda x: x.iloc[-1]) + tm.assert_series_equal(result, expected) + + +@pytest.mark.filterwarnings( + "ignore:The 'convention' keyword in Series.resample:FutureWarning" +) +@pytest.mark.parametrize( + "_index_start,_index_end,_index_name", + [("1/1/2000 00:00:00", "1/1/2000 00:13:00", "index")], +) +@pytest.mark.parametrize( + "keyword,value", + [("label", "righttt"), ("closed", "righttt"), ("convention", "starttt")], +) +def test_resample_string_kwargs(series, keyword, value, unit): + # see gh-19303 + # Check that wrong keyword argument strings raise an error + series.index = series.index.as_unit(unit) + msg = f"Unsupported value {value} for `{keyword}`" + with pytest.raises(ValueError, match=msg): + series.resample("5min", **({keyword: value})) + + +@pytest.mark.parametrize( + "_index_start,_index_end,_index_name", + [("1/1/2000 00:00:00", "1/1/2000 00:13:00", "index")], +) +def test_resample_how(series, downsample_method, unit): + if downsample_method == "ohlc": + pytest.skip("covered by test_resample_how_ohlc") + + s = series + s.index = s.index.as_unit(unit) + grouplist = np.ones_like(s) + grouplist[0] = 0 + grouplist[1:6] = 1 + grouplist[6:11] = 2 + grouplist[11:] = 3 + expected = s.groupby(grouplist).agg(downsample_method) + expected.index = date_range( + "1/1/2000", periods=4, freq="5min", name="index" + ).as_unit(unit) + + result = getattr( + s.resample("5min", closed="right", label="right"), downsample_method + )() + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "_index_start,_index_end,_index_name", + [("1/1/2000 00:00:00", "1/1/2000 00:13:00", "index")], +) +def test_resample_how_ohlc(series, unit): + s = series + s.index = s.index.as_unit(unit) + grouplist = np.ones_like(s) + grouplist[0] = 0 + grouplist[1:6] = 1 + grouplist[6:11] = 2 + grouplist[11:] = 3 + + def _ohlc(group): + if isna(group).all(): + return np.repeat(np.nan, 4) + return [group.iloc[0], group.max(), group.min(), group.iloc[-1]] + + expected = DataFrame( + s.groupby(grouplist).agg(_ohlc).values.tolist(), + index=date_range("1/1/2000", periods=4, freq="5min", name="index").as_unit( + unit + ), + columns=["open", "high", "low", "close"], + ) + + result = s.resample("5min", closed="right", label="right").ohlc() + tm.assert_frame_equal(result, expected) + + +def test_resample_how_callables(unit): + # GH#7929 + data = np.arange(5, dtype=np.int64) + ind = date_range(start="2014-01-01", periods=len(data), freq="d").as_unit(unit) + df = DataFrame({"A": data, "B": data}, index=ind) + + def fn(x, a=1): + return str(type(x)) + + class FnClass: + def __call__(self, x): + return str(type(x)) + + df_standard = df.resample("ME").apply(fn) + df_lambda = df.resample("ME").apply(lambda x: str(type(x))) + df_partial = df.resample("ME").apply(partial(fn)) + df_partial2 = df.resample("ME").apply(partial(fn, a=2)) + df_class = df.resample("ME").apply(FnClass()) + + tm.assert_frame_equal(df_standard, df_lambda) + tm.assert_frame_equal(df_standard, df_partial) + tm.assert_frame_equal(df_standard, df_partial2) + tm.assert_frame_equal(df_standard, df_class) + + +def test_resample_rounding(unit): + # GH 8371 + # odd results when rounding is needed + + ts = [ + "2014-11-08 00:00:01", + "2014-11-08 00:00:02", + "2014-11-08 00:00:02", + "2014-11-08 00:00:03", + "2014-11-08 00:00:07", + "2014-11-08 00:00:07", + "2014-11-08 00:00:08", + "2014-11-08 00:00:08", + "2014-11-08 00:00:08", + "2014-11-08 00:00:09", + "2014-11-08 00:00:10", + "2014-11-08 00:00:11", + "2014-11-08 00:00:11", + "2014-11-08 00:00:13", + "2014-11-08 00:00:14", + "2014-11-08 00:00:15", + "2014-11-08 00:00:17", + "2014-11-08 00:00:20", + "2014-11-08 00:00:21", + ] + df = DataFrame({"value": [1] * 19}, index=pd.to_datetime(ts)) + df.index = df.index.as_unit(unit) + + result = df.resample("6s").sum() + expected = DataFrame( + {"value": [4, 9, 4, 2]}, + index=date_range("2014-11-08", freq="6s", periods=4).as_unit(unit), + ) + tm.assert_frame_equal(result, expected) + + result = df.resample("7s").sum() + expected = DataFrame( + {"value": [4, 10, 4, 1]}, + index=date_range("2014-11-08", freq="7s", periods=4).as_unit(unit), + ) + tm.assert_frame_equal(result, expected) + + result = df.resample("11s").sum() + expected = DataFrame( + {"value": [11, 8]}, + index=date_range("2014-11-08", freq="11s", periods=2).as_unit(unit), + ) + tm.assert_frame_equal(result, expected) + + result = df.resample("13s").sum() + expected = DataFrame( + {"value": [13, 6]}, + index=date_range("2014-11-08", freq="13s", periods=2).as_unit(unit), + ) + tm.assert_frame_equal(result, expected) + + result = df.resample("17s").sum() + expected = DataFrame( + {"value": [16, 3]}, + index=date_range("2014-11-08", freq="17s", periods=2).as_unit(unit), + ) + tm.assert_frame_equal(result, expected) + + +def test_resample_basic_from_daily(unit): + # from daily + dti = date_range( + start=datetime(2005, 1, 1), end=datetime(2005, 1, 10), freq="D", name="index" + ).as_unit(unit) + + s = Series(np.random.default_rng(2).random(len(dti)), dti) + + # to weekly + result = s.resample("w-sun").last() + + assert len(result) == 3 + assert (result.index.dayofweek == [6, 6, 6]).all() + assert result.iloc[0] == s["1/2/2005"] + assert result.iloc[1] == s["1/9/2005"] + assert result.iloc[2] == s.iloc[-1] + + result = s.resample("W-MON").last() + assert len(result) == 2 + assert (result.index.dayofweek == [0, 0]).all() + assert result.iloc[0] == s["1/3/2005"] + assert result.iloc[1] == s["1/10/2005"] + + result = s.resample("W-TUE").last() + assert len(result) == 2 + assert (result.index.dayofweek == [1, 1]).all() + assert result.iloc[0] == s["1/4/2005"] + assert result.iloc[1] == s["1/10/2005"] + + result = s.resample("W-WED").last() + assert len(result) == 2 + assert (result.index.dayofweek == [2, 2]).all() + assert result.iloc[0] == s["1/5/2005"] + assert result.iloc[1] == s["1/10/2005"] + + result = s.resample("W-THU").last() + assert len(result) == 2 + assert (result.index.dayofweek == [3, 3]).all() + assert result.iloc[0] == s["1/6/2005"] + assert result.iloc[1] == s["1/10/2005"] + + result = s.resample("W-FRI").last() + assert len(result) == 2 + assert (result.index.dayofweek == [4, 4]).all() + assert result.iloc[0] == s["1/7/2005"] + assert result.iloc[1] == s["1/10/2005"] + + # to biz day + result = s.resample("B").last() + assert len(result) == 7 + assert (result.index.dayofweek == [4, 0, 1, 2, 3, 4, 0]).all() + + assert result.iloc[0] == s["1/2/2005"] + assert result.iloc[1] == s["1/3/2005"] + assert result.iloc[5] == s["1/9/2005"] + assert result.index.name == "index" + + +def test_resample_upsampling_picked_but_not_correct(unit): + # Test for issue #3020 + dates = date_range("01-Jan-2014", "05-Jan-2014", freq="D").as_unit(unit) + series = Series(1, index=dates) + + result = series.resample("D").mean() + assert result.index[0] == dates[0] + + # GH 5955 + # incorrect deciding to upsample when the axis frequency matches the + # resample frequency + + s = Series( + np.arange(1.0, 6), index=[datetime(1975, 1, i, 12, 0) for i in range(1, 6)] + ) + s.index = s.index.as_unit(unit) + expected = Series( + np.arange(1.0, 6), + index=date_range("19750101", periods=5, freq="D").as_unit(unit), + ) + + result = s.resample("D").count() + tm.assert_series_equal(result, Series(1, index=expected.index)) + + result1 = s.resample("D").sum() + result2 = s.resample("D").mean() + tm.assert_series_equal(result1, expected) + tm.assert_series_equal(result2, expected) + + +@pytest.mark.parametrize("f", ["sum", "mean", "prod", "min", "max", "var"]) +def test_resample_frame_basic_cy_funcs(f, unit): + df = DataFrame( + np.random.default_rng(2).standard_normal((50, 4)), + columns=Index(list("ABCD"), dtype=object), + index=date_range("2000-01-01", periods=50, freq="B"), + ) + df.index = df.index.as_unit(unit) + + b = Grouper(freq="ME") + g = df.groupby(b) + + # check all cython functions work + g._cython_agg_general(f, alt=None, numeric_only=True) + + +@pytest.mark.parametrize("freq", ["YE", "ME"]) +def test_resample_frame_basic_M_A(freq, unit): + df = DataFrame( + np.random.default_rng(2).standard_normal((50, 4)), + columns=Index(list("ABCD"), dtype=object), + index=date_range("2000-01-01", periods=50, freq="B"), + ) + df.index = df.index.as_unit(unit) + result = df.resample(freq).mean() + tm.assert_series_equal(result["A"], df["A"].resample(freq).mean()) + + +@pytest.mark.parametrize("freq", ["W-WED", "ME"]) +def test_resample_frame_basic_kind(freq, unit): + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 4)), + columns=Index(list("ABCD"), dtype=object), + index=date_range("2000-01-01", periods=10, freq="B"), + ) + df.index = df.index.as_unit(unit) + msg = "The 'kind' keyword in DataFrame.resample is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + df.resample(freq, kind="period").mean() + + +def test_resample_upsample(unit): + # from daily + dti = date_range( + start=datetime(2005, 1, 1), end=datetime(2005, 1, 10), freq="D", name="index" + ).as_unit(unit) + + s = Series(np.random.default_rng(2).random(len(dti)), dti) + + # to minutely, by padding + result = s.resample("Min").ffill() + assert len(result) == 12961 + assert result.iloc[0] == s.iloc[0] + assert result.iloc[-1] == s.iloc[-1] + + assert result.index.name == "index" + + +def test_resample_how_method(unit): + # GH9915 + s = Series( + [11, 22], + index=[ + Timestamp("2015-03-31 21:48:52.672000"), + Timestamp("2015-03-31 21:49:52.739000"), + ], + ) + s.index = s.index.as_unit(unit) + expected = Series( + [11, np.nan, np.nan, np.nan, np.nan, np.nan, 22], + index=DatetimeIndex( + [ + Timestamp("2015-03-31 21:48:50"), + Timestamp("2015-03-31 21:49:00"), + Timestamp("2015-03-31 21:49:10"), + Timestamp("2015-03-31 21:49:20"), + Timestamp("2015-03-31 21:49:30"), + Timestamp("2015-03-31 21:49:40"), + Timestamp("2015-03-31 21:49:50"), + ], + freq="10s", + ), + ) + expected.index = expected.index.as_unit(unit) + tm.assert_series_equal(s.resample("10s").mean(), expected) + + +def test_resample_extra_index_point(unit): + # GH#9756 + index = date_range(start="20150101", end="20150331", freq="BME").as_unit(unit) + expected = DataFrame({"A": Series([21, 41, 63], index=index)}) + + index = date_range(start="20150101", end="20150331", freq="B").as_unit(unit) + df = DataFrame({"A": Series(range(len(index)), index=index)}, dtype="int64") + result = df.resample("BME").last() + tm.assert_frame_equal(result, expected) + + +def test_upsample_with_limit(unit): + rng = date_range("1/1/2000", periods=3, freq="5min").as_unit(unit) + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), rng) + + result = ts.resample("min").ffill(limit=2) + expected = ts.reindex(result.index, method="ffill", limit=2) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("freq", ["1D", "10h", "5Min", "10s"]) +@pytest.mark.parametrize("rule", ["YE", "3ME", "15D", "30h", "15Min", "30s"]) +def test_nearest_upsample_with_limit(tz_aware_fixture, freq, rule, unit): + # GH 33939 + rng = date_range("1/1/2000", periods=3, freq=freq, tz=tz_aware_fixture).as_unit( + unit + ) + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), rng) + + result = ts.resample(rule).nearest(limit=2) + expected = ts.reindex(result.index, method="nearest", limit=2) + tm.assert_series_equal(result, expected) + + +def test_resample_ohlc(series, unit): + s = series + s.index = s.index.as_unit(unit) + + grouper = Grouper(freq=Minute(5)) + expect = s.groupby(grouper).agg(lambda x: x.iloc[-1]) + result = s.resample("5Min").ohlc() + + assert len(result) == len(expect) + assert len(result.columns) == 4 + + xs = result.iloc[-2] + assert xs["open"] == s.iloc[-6] + assert xs["high"] == s[-6:-1].max() + assert xs["low"] == s[-6:-1].min() + assert xs["close"] == s.iloc[-2] + + xs = result.iloc[0] + assert xs["open"] == s.iloc[0] + assert xs["high"] == s[:5].max() + assert xs["low"] == s[:5].min() + assert xs["close"] == s.iloc[4] + + +def test_resample_ohlc_result(unit): + # GH 12332 + index = date_range("1-1-2000", "2-15-2000", freq="h").as_unit(unit) + index = index.union(date_range("4-15-2000", "5-15-2000", freq="h").as_unit(unit)) + s = Series(range(len(index)), index=index) + + a = s.loc[:"4-15-2000"].resample("30min").ohlc() + assert isinstance(a, DataFrame) + + b = s.loc[:"4-14-2000"].resample("30min").ohlc() + assert isinstance(b, DataFrame) + + +def test_resample_ohlc_result_odd_period(unit): + # GH12348 + # raising on odd period + rng = date_range("2013-12-30", "2014-01-07").as_unit(unit) + index = rng.drop( + [ + Timestamp("2014-01-01"), + Timestamp("2013-12-31"), + Timestamp("2014-01-04"), + Timestamp("2014-01-05"), + ] + ) + df = DataFrame(data=np.arange(len(index)), index=index) + result = df.resample("B").mean() + expected = df.reindex(index=date_range(rng[0], rng[-1], freq="B").as_unit(unit)) + tm.assert_frame_equal(result, expected) + + +def test_resample_ohlc_dataframe(unit): + df = ( + DataFrame( + { + "PRICE": { + Timestamp("2011-01-06 10:59:05", tz=None): 24990, + Timestamp("2011-01-06 12:43:33", tz=None): 25499, + Timestamp("2011-01-06 12:54:09", tz=None): 25499, + }, + "VOLUME": { + Timestamp("2011-01-06 10:59:05", tz=None): 1500000000, + Timestamp("2011-01-06 12:43:33", tz=None): 5000000000, + Timestamp("2011-01-06 12:54:09", tz=None): 100000000, + }, + } + ) + ).reindex(["VOLUME", "PRICE"], axis=1) + df.index = df.index.as_unit(unit) + df.columns.name = "Cols" + res = df.resample("h").ohlc() + exp = pd.concat( + [df["VOLUME"].resample("h").ohlc(), df["PRICE"].resample("h").ohlc()], + axis=1, + keys=df.columns, + ) + assert exp.columns.names[0] == "Cols" + tm.assert_frame_equal(exp, res) + + df.columns = [["a", "b"], ["c", "d"]] + res = df.resample("h").ohlc() + exp.columns = pd.MultiIndex.from_tuples( + [ + ("a", "c", "open"), + ("a", "c", "high"), + ("a", "c", "low"), + ("a", "c", "close"), + ("b", "d", "open"), + ("b", "d", "high"), + ("b", "d", "low"), + ("b", "d", "close"), + ] + ) + tm.assert_frame_equal(exp, res) + + # dupe columns fail atm + # df.columns = ['PRICE', 'PRICE'] + + +def test_resample_dup_index(): + # GH 4812 + # dup columns with resample raising + df = DataFrame( + np.random.default_rng(2).standard_normal((4, 12)), + index=[2000, 2000, 2000, 2000], + columns=[Period(year=2000, month=i + 1, freq="M") for i in range(12)], + ) + df.iloc[3, :] = np.nan + warning_msg = "DataFrame.resample with axis=1 is deprecated." + with tm.assert_produces_warning(FutureWarning, match=warning_msg): + result = df.resample("QE", axis=1).mean() + + msg = "DataFrame.groupby with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + expected = df.groupby(lambda x: int((x.month - 1) / 3), axis=1).mean() + expected.columns = [Period(year=2000, quarter=i + 1, freq="Q") for i in range(4)] + tm.assert_frame_equal(result, expected) + + +def test_resample_reresample(unit): + dti = date_range( + start=datetime(2005, 1, 1), end=datetime(2005, 1, 10), freq="D" + ).as_unit(unit) + s = Series(np.random.default_rng(2).random(len(dti)), dti) + bs = s.resample("B", closed="right", label="right").mean() + result = bs.resample("8h").mean() + assert len(result) == 25 + assert isinstance(result.index.freq, offsets.DateOffset) + assert result.index.freq == offsets.Hour(8) + + +@pytest.mark.parametrize( + "freq, expected_kwargs", + [ + ["YE-DEC", {"start": "1990", "end": "2000", "freq": "Y-DEC"}], + ["YE-JUN", {"start": "1990", "end": "2000", "freq": "Y-JUN"}], + ["ME", {"start": "1990-01", "end": "2000-01", "freq": "M"}], + ], +) +def test_resample_timestamp_to_period( + simple_date_range_series, freq, expected_kwargs, unit +): + ts = simple_date_range_series("1/1/1990", "1/1/2000") + ts.index = ts.index.as_unit(unit) + + msg = "The 'kind' keyword in Series.resample is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = ts.resample(freq, kind="period").mean() + expected = ts.resample(freq).mean() + expected.index = period_range(**expected_kwargs) + tm.assert_series_equal(result, expected) + + +def test_ohlc_5min(unit): + def _ohlc(group): + if isna(group).all(): + return np.repeat(np.nan, 4) + return [group.iloc[0], group.max(), group.min(), group.iloc[-1]] + + rng = date_range("1/1/2000 00:00:00", "1/1/2000 5:59:50", freq="10s").as_unit(unit) + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + + resampled = ts.resample("5min", closed="right", label="right").ohlc() + + assert (resampled.loc["1/1/2000 00:00"] == ts.iloc[0]).all() + + exp = _ohlc(ts[1:31]) + assert (resampled.loc["1/1/2000 00:05"] == exp).all() + + exp = _ohlc(ts["1/1/2000 5:55:01":]) + assert (resampled.loc["1/1/2000 6:00:00"] == exp).all() + + +def test_downsample_non_unique(unit): + rng = date_range("1/1/2000", "2/29/2000").as_unit(unit) + rng2 = rng.repeat(5).values + ts = Series(np.random.default_rng(2).standard_normal(len(rng2)), index=rng2) + + result = ts.resample("ME").mean() + + expected = ts.groupby(lambda x: x.month).mean() + assert len(result) == 2 + tm.assert_almost_equal(result.iloc[0], expected[1]) + tm.assert_almost_equal(result.iloc[1], expected[2]) + + +def test_asfreq_non_unique(unit): + # GH #1077 + rng = date_range("1/1/2000", "2/29/2000").as_unit(unit) + rng2 = rng.repeat(2).values + ts = Series(np.random.default_rng(2).standard_normal(len(rng2)), index=rng2) + + msg = "cannot reindex on an axis with duplicate labels" + with pytest.raises(ValueError, match=msg): + ts.asfreq("B") + + +def test_resample_axis1(unit): + rng = date_range("1/1/2000", "2/29/2000").as_unit(unit) + df = DataFrame( + np.random.default_rng(2).standard_normal((3, len(rng))), + columns=rng, + index=["a", "b", "c"], + ) + + warning_msg = "DataFrame.resample with axis=1 is deprecated." + with tm.assert_produces_warning(FutureWarning, match=warning_msg): + result = df.resample("ME", axis=1).mean() + expected = df.T.resample("ME").mean().T + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("freq", ["min", "5min", "15min", "30min", "4h", "12h"]) +def test_resample_anchored_ticks(freq, unit): + # If a fixed delta (5 minute, 4 hour) evenly divides a day, we should + # "anchor" the origin at midnight so we get regular intervals rather + # than starting from the first timestamp which might start in the + # middle of a desired interval + + rng = date_range("1/1/2000 04:00:00", periods=86400, freq="s").as_unit(unit) + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + ts[:2] = np.nan # so results are the same + result = ts[2:].resample(freq, closed="left", label="left").mean() + expected = ts.resample(freq, closed="left", label="left").mean() + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("end", [1, 2]) +def test_resample_single_group(end, unit): + mysum = lambda x: x.sum() + + rng = date_range("2000-1-1", f"2000-{end}-10", freq="D").as_unit(unit) + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + tm.assert_series_equal(ts.resample("ME").sum(), ts.resample("ME").apply(mysum)) + + +def test_resample_single_group_std(unit): + # GH 3849 + s = Series( + [30.1, 31.6], + index=[Timestamp("20070915 15:30:00"), Timestamp("20070915 15:40:00")], + ) + s.index = s.index.as_unit(unit) + expected = Series( + [0.75], index=DatetimeIndex([Timestamp("20070915")], freq="D").as_unit(unit) + ) + result = s.resample("D").apply(lambda x: np.std(x)) + tm.assert_series_equal(result, expected) + + +def test_resample_offset(unit): + # GH 31809 + + rng = date_range("1/1/2000 00:00:00", "1/1/2000 02:00", freq="s").as_unit(unit) + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + + resampled = ts.resample("5min", offset="2min").mean() + exp_rng = date_range("12/31/1999 23:57:00", "1/1/2000 01:57", freq="5min").as_unit( + unit + ) + tm.assert_index_equal(resampled.index, exp_rng) + + +@pytest.mark.parametrize( + "kwargs", + [ + {"origin": "1999-12-31 23:57:00"}, + {"origin": Timestamp("1970-01-01 00:02:00")}, + {"origin": "epoch", "offset": "2m"}, + # origin of '1999-31-12 12:02:00' should be equivalent for this case + {"origin": "1999-12-31 12:02:00"}, + {"offset": "-3m"}, + ], +) +def test_resample_origin(kwargs, unit): + # GH 31809 + rng = date_range("2000-01-01 00:00:00", "2000-01-01 02:00", freq="s").as_unit(unit) + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + + exp_rng = date_range( + "1999-12-31 23:57:00", "2000-01-01 01:57", freq="5min" + ).as_unit(unit) + + resampled = ts.resample("5min", **kwargs).mean() + tm.assert_index_equal(resampled.index, exp_rng) + + +@pytest.mark.parametrize( + "origin", ["invalid_value", "epch", "startday", "startt", "2000-30-30", object()] +) +def test_resample_bad_origin(origin, unit): + rng = date_range("2000-01-01 00:00:00", "2000-01-01 02:00", freq="s").as_unit(unit) + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + msg = ( + "'origin' should be equal to 'epoch', 'start', 'start_day', " + "'end', 'end_day' or should be a Timestamp convertible type. Got " + f"'{origin}' instead." + ) + with pytest.raises(ValueError, match=msg): + ts.resample("5min", origin=origin) + + +@pytest.mark.parametrize("offset", ["invalid_value", "12dayys", "2000-30-30", object()]) +def test_resample_bad_offset(offset, unit): + rng = date_range("2000-01-01 00:00:00", "2000-01-01 02:00", freq="s").as_unit(unit) + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + msg = f"'offset' should be a Timedelta convertible type. Got '{offset}' instead." + with pytest.raises(ValueError, match=msg): + ts.resample("5min", offset=offset) + + +def test_resample_origin_prime_freq(unit): + # GH 31809 + start, end = "2000-10-01 23:30:00", "2000-10-02 00:30:00" + rng = date_range(start, end, freq="7min").as_unit(unit) + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + + exp_rng = date_range( + "2000-10-01 23:14:00", "2000-10-02 00:22:00", freq="17min" + ).as_unit(unit) + resampled = ts.resample("17min").mean() + tm.assert_index_equal(resampled.index, exp_rng) + resampled = ts.resample("17min", origin="start_day").mean() + tm.assert_index_equal(resampled.index, exp_rng) + + exp_rng = date_range( + "2000-10-01 23:30:00", "2000-10-02 00:21:00", freq="17min" + ).as_unit(unit) + resampled = ts.resample("17min", origin="start").mean() + tm.assert_index_equal(resampled.index, exp_rng) + resampled = ts.resample("17min", offset="23h30min").mean() + tm.assert_index_equal(resampled.index, exp_rng) + resampled = ts.resample("17min", origin="start_day", offset="23h30min").mean() + tm.assert_index_equal(resampled.index, exp_rng) + + exp_rng = date_range( + "2000-10-01 23:18:00", "2000-10-02 00:26:00", freq="17min" + ).as_unit(unit) + resampled = ts.resample("17min", origin="epoch").mean() + tm.assert_index_equal(resampled.index, exp_rng) + + exp_rng = date_range( + "2000-10-01 23:24:00", "2000-10-02 00:15:00", freq="17min" + ).as_unit(unit) + resampled = ts.resample("17min", origin="2000-01-01").mean() + tm.assert_index_equal(resampled.index, exp_rng) + + +def test_resample_origin_with_tz(unit): + # GH 31809 + msg = "The origin must have the same timezone as the index." + + tz = "Europe/Paris" + rng = date_range( + "2000-01-01 00:00:00", "2000-01-01 02:00", freq="s", tz=tz + ).as_unit(unit) + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + + exp_rng = date_range( + "1999-12-31 23:57:00", "2000-01-01 01:57", freq="5min", tz=tz + ).as_unit(unit) + resampled = ts.resample("5min", origin="1999-12-31 23:57:00+00:00").mean() + tm.assert_index_equal(resampled.index, exp_rng) + + # origin of '1999-31-12 12:02:00+03:00' should be equivalent for this case + resampled = ts.resample("5min", origin="1999-12-31 12:02:00+03:00").mean() + tm.assert_index_equal(resampled.index, exp_rng) + + resampled = ts.resample("5min", origin="epoch", offset="2m").mean() + tm.assert_index_equal(resampled.index, exp_rng) + + with pytest.raises(ValueError, match=msg): + ts.resample("5min", origin="12/31/1999 23:57:00").mean() + + # if the series is not tz aware, origin should not be tz aware + rng = date_range("2000-01-01 00:00:00", "2000-01-01 02:00", freq="s").as_unit(unit) + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + with pytest.raises(ValueError, match=msg): + ts.resample("5min", origin="12/31/1999 23:57:00+03:00").mean() + + +def test_resample_origin_epoch_with_tz_day_vs_24h(unit): + # GH 34474 + start, end = "2000-10-01 23:30:00+0500", "2000-12-02 00:30:00+0500" + rng = date_range(start, end, freq="7min").as_unit(unit) + random_values = np.random.default_rng(2).standard_normal(len(rng)) + ts_1 = Series(random_values, index=rng) + + result_1 = ts_1.resample("D", origin="epoch").mean() + result_2 = ts_1.resample("24h", origin="epoch").mean() + tm.assert_series_equal(result_1, result_2) + + # check that we have the same behavior with epoch even if we are not timezone aware + ts_no_tz = ts_1.tz_localize(None) + result_3 = ts_no_tz.resample("D", origin="epoch").mean() + result_4 = ts_no_tz.resample("24h", origin="epoch").mean() + tm.assert_series_equal(result_1, result_3.tz_localize(rng.tz), check_freq=False) + tm.assert_series_equal(result_1, result_4.tz_localize(rng.tz), check_freq=False) + + # check that we have the similar results with two different timezones (+2H and +5H) + start, end = "2000-10-01 23:30:00+0200", "2000-12-02 00:30:00+0200" + rng = date_range(start, end, freq="7min").as_unit(unit) + ts_2 = Series(random_values, index=rng) + result_5 = ts_2.resample("D", origin="epoch").mean() + result_6 = ts_2.resample("24h", origin="epoch").mean() + tm.assert_series_equal(result_1.tz_localize(None), result_5.tz_localize(None)) + tm.assert_series_equal(result_1.tz_localize(None), result_6.tz_localize(None)) + + +def test_resample_origin_with_day_freq_on_dst(unit): + # GH 31809 + tz = "America/Chicago" + + def _create_series(values, timestamps, freq="D"): + return Series( + values, + index=DatetimeIndex( + [Timestamp(t, tz=tz) for t in timestamps], freq=freq, ambiguous=True + ).as_unit(unit), + ) + + # test classical behavior of origin in a DST context + start = Timestamp("2013-11-02", tz=tz) + end = Timestamp("2013-11-03 23:59", tz=tz) + rng = date_range(start, end, freq="1h").as_unit(unit) + ts = Series(np.ones(len(rng)), index=rng) + + expected = _create_series([24.0, 25.0], ["2013-11-02", "2013-11-03"]) + for origin in ["epoch", "start", "start_day", start, None]: + result = ts.resample("D", origin=origin).sum() + tm.assert_series_equal(result, expected) + + # test complex behavior of origin/offset in a DST context + start = Timestamp("2013-11-03", tz=tz) + end = Timestamp("2013-11-03 23:59", tz=tz) + rng = date_range(start, end, freq="1h").as_unit(unit) + ts = Series(np.ones(len(rng)), index=rng) + + expected_ts = ["2013-11-02 22:00-05:00", "2013-11-03 22:00-06:00"] + expected = _create_series([23.0, 2.0], expected_ts) + result = ts.resample("D", origin="start", offset="-2h").sum() + tm.assert_series_equal(result, expected) + + expected_ts = ["2013-11-02 22:00-05:00", "2013-11-03 21:00-06:00"] + expected = _create_series([22.0, 3.0], expected_ts, freq="24h") + result = ts.resample("24h", origin="start", offset="-2h").sum() + tm.assert_series_equal(result, expected) + + expected_ts = ["2013-11-02 02:00-05:00", "2013-11-03 02:00-06:00"] + expected = _create_series([3.0, 22.0], expected_ts) + result = ts.resample("D", origin="start", offset="2h").sum() + tm.assert_series_equal(result, expected) + + expected_ts = ["2013-11-02 23:00-05:00", "2013-11-03 23:00-06:00"] + expected = _create_series([24.0, 1.0], expected_ts) + result = ts.resample("D", origin="start", offset="-1h").sum() + tm.assert_series_equal(result, expected) + + expected_ts = ["2013-11-02 01:00-05:00", "2013-11-03 01:00:00-0500"] + expected = _create_series([1.0, 24.0], expected_ts) + result = ts.resample("D", origin="start", offset="1h").sum() + tm.assert_series_equal(result, expected) + + +def test_resample_daily_anchored(unit): + rng = date_range("1/1/2000 0:00:00", periods=10000, freq="min").as_unit(unit) + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + ts[:2] = np.nan # so results are the same + + result = ts[2:].resample("D", closed="left", label="left").mean() + expected = ts.resample("D", closed="left", label="left").mean() + tm.assert_series_equal(result, expected) + + +def test_resample_to_period_monthly_buglet(unit): + # GH #1259 + + rng = date_range("1/1/2000", "12/31/2000").as_unit(unit) + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + + msg = "The 'kind' keyword in Series.resample is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = ts.resample("ME", kind="period").mean() + exp_index = period_range("Jan-2000", "Dec-2000", freq="M") + tm.assert_index_equal(result.index, exp_index) + + +def test_period_with_agg(): + # aggregate a period resampler with a lambda + s2 = Series( + np.random.default_rng(2).integers(0, 5, 50), + index=period_range("2012-01-01", freq="h", periods=50), + dtype="float64", + ) + + expected = s2.to_timestamp().resample("D").mean().to_period() + msg = "Resampling with a PeriodIndex is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + rs = s2.resample("D") + result = rs.agg(lambda x: x.mean()) + tm.assert_series_equal(result, expected) + + +def test_resample_segfault(unit): + # GH 8573 + # segfaulting in older versions + all_wins_and_wagers = [ + (1, datetime(2013, 10, 1, 16, 20), 1, 0), + (2, datetime(2013, 10, 1, 16, 10), 1, 0), + (2, datetime(2013, 10, 1, 18, 15), 1, 0), + (2, datetime(2013, 10, 1, 16, 10, 31), 1, 0), + ] + + df = DataFrame.from_records( + all_wins_and_wagers, columns=("ID", "timestamp", "A", "B") + ).set_index("timestamp") + df.index = df.index.as_unit(unit) + msg = "DataFrameGroupBy.resample operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = df.groupby("ID").resample("5min").sum() + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + expected = df.groupby("ID").apply(lambda x: x.resample("5min").sum()) + tm.assert_frame_equal(result, expected) + + +def test_resample_dtype_preservation(unit): + # GH 12202 + # validation tests for dtype preservation + + df = DataFrame( + { + "date": date_range(start="2016-01-01", periods=4, freq="W").as_unit(unit), + "group": [1, 1, 2, 2], + "val": Series([5, 6, 7, 8], dtype="int32"), + } + ).set_index("date") + + result = df.resample("1D").ffill() + assert result.val.dtype == np.int32 + + msg = "DataFrameGroupBy.resample operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = df.groupby("group").resample("1D").ffill() + assert result.val.dtype == np.int32 + + +def test_resample_dtype_coercion(unit): + pytest.importorskip("scipy.interpolate") + + # GH 16361 + df = {"a": [1, 3, 1, 4]} + df = DataFrame(df, index=date_range("2017-01-01", "2017-01-04").as_unit(unit)) + + expected = df.astype("float64").resample("h").mean()["a"].interpolate("cubic") + + result = df.resample("h")["a"].mean().interpolate("cubic") + tm.assert_series_equal(result, expected) + + result = df.resample("h").mean()["a"].interpolate("cubic") + tm.assert_series_equal(result, expected) + + +def test_weekly_resample_buglet(unit): + # #1327 + rng = date_range("1/1/2000", freq="B", periods=20).as_unit(unit) + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + + resampled = ts.resample("W").mean() + expected = ts.resample("W-SUN").mean() + tm.assert_series_equal(resampled, expected) + + +def test_monthly_resample_error(unit): + # #1451 + dates = date_range("4/16/2012 20:00", periods=5000, freq="h").as_unit(unit) + ts = Series(np.random.default_rng(2).standard_normal(len(dates)), index=dates) + # it works! + ts.resample("ME") + + +def test_nanosecond_resample_error(): + # GH 12307 - Values falls after last bin when + # Resampling using pd.tseries.offsets.Nano as period + start = 1443707890427 + exp_start = 1443707890400 + indx = date_range(start=pd.to_datetime(start), periods=10, freq="100ns") + ts = Series(range(len(indx)), index=indx) + r = ts.resample(pd.tseries.offsets.Nano(100)) + result = r.agg("mean") + + exp_indx = date_range(start=pd.to_datetime(exp_start), periods=10, freq="100ns") + exp = Series(range(len(exp_indx)), index=exp_indx, dtype=float) + + tm.assert_series_equal(result, exp) + + +def test_resample_anchored_intraday(unit): + # #1471, #1458 + + rng = date_range("1/1/2012", "4/1/2012", freq="100min").as_unit(unit) + df = DataFrame(rng.month, index=rng) + + result = df.resample("ME").mean() + msg = "The 'kind' keyword in DataFrame.resample is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + expected = df.resample("ME", kind="period").mean().to_timestamp(how="end") + expected.index += Timedelta(1, "ns") - Timedelta(1, "D") + expected.index = expected.index.as_unit(unit)._with_freq("infer") + assert expected.index.freq == "ME" + tm.assert_frame_equal(result, expected) + + result = df.resample("ME", closed="left").mean() + msg = "The 'kind' keyword in DataFrame.resample is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + exp = df.shift(1, freq="D").resample("ME", kind="period").mean() + exp = exp.to_timestamp(how="end") + + exp.index = exp.index + Timedelta(1, "ns") - Timedelta(1, "D") + exp.index = exp.index.as_unit(unit)._with_freq("infer") + assert exp.index.freq == "ME" + tm.assert_frame_equal(result, exp) + + +def test_resample_anchored_intraday2(unit): + rng = date_range("1/1/2012", "4/1/2012", freq="100min").as_unit(unit) + df = DataFrame(rng.month, index=rng) + + result = df.resample("QE").mean() + msg = "The 'kind' keyword in DataFrame.resample is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + expected = df.resample("QE", kind="period").mean().to_timestamp(how="end") + expected.index += Timedelta(1, "ns") - Timedelta(1, "D") + expected.index._data.freq = "QE" + expected.index._freq = lib.no_default + expected.index = expected.index.as_unit(unit) + tm.assert_frame_equal(result, expected) + + result = df.resample("QE", closed="left").mean() + msg = "The 'kind' keyword in DataFrame.resample is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + expected = ( + df.shift(1, freq="D").resample("QE", kind="period", closed="left").mean() + ) + expected = expected.to_timestamp(how="end") + expected.index += Timedelta(1, "ns") - Timedelta(1, "D") + expected.index._data.freq = "QE" + expected.index._freq = lib.no_default + expected.index = expected.index.as_unit(unit) + tm.assert_frame_equal(result, expected) + + +def test_resample_anchored_intraday3(simple_date_range_series, unit): + ts = simple_date_range_series("2012-04-29 23:00", "2012-04-30 5:00", freq="h") + ts.index = ts.index.as_unit(unit) + resampled = ts.resample("ME").mean() + assert len(resampled) == 1 + + +@pytest.mark.parametrize("freq", ["MS", "BMS", "QS-MAR", "YS-DEC", "YS-JUN"]) +def test_resample_anchored_monthstart(simple_date_range_series, freq, unit): + ts = simple_date_range_series("1/1/2000", "12/31/2002") + ts.index = ts.index.as_unit(unit) + ts.resample(freq).mean() + + +@pytest.mark.parametrize("label, sec", [[None, 2.0], ["right", "4.2"]]) +def test_resample_anchored_multiday(label, sec): + # When resampling a range spanning multiple days, ensure that the + # start date gets used to determine the offset. Fixes issue where + # a one day period is not a multiple of the frequency. + # + # See: https://github.com/pandas-dev/pandas/issues/8683 + + index1 = date_range("2014-10-14 23:06:23.206", periods=3, freq="400ms") + index2 = date_range("2014-10-15 23:00:00", periods=2, freq="2200ms") + index = index1.union(index2) + + s = Series(np.random.default_rng(2).standard_normal(5), index=index) + + # Ensure left closing works + result = s.resample("2200ms", label=label).mean() + assert result.index[-1] == Timestamp(f"2014-10-15 23:00:{sec}00") + + +def test_corner_cases(unit): + # miscellaneous test coverage + + rng = date_range("1/1/2000", periods=12, freq="min").as_unit(unit) + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + + result = ts.resample("5min", closed="right", label="left").mean() + ex_index = date_range("1999-12-31 23:55", periods=4, freq="5min").as_unit(unit) + tm.assert_index_equal(result.index, ex_index) + + +def test_corner_cases_date(simple_date_range_series, unit): + # resample to periods + ts = simple_date_range_series("2000-04-28", "2000-04-30 11:00", freq="h") + ts.index = ts.index.as_unit(unit) + msg = "The 'kind' keyword in Series.resample is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = ts.resample("ME", kind="period").mean() + assert len(result) == 1 + assert result.index[0] == Period("2000-04", freq="M") + + +def test_anchored_lowercase_buglet(unit): + dates = date_range("4/16/2012 20:00", periods=50000, freq="s").as_unit(unit) + ts = Series(np.random.default_rng(2).standard_normal(len(dates)), index=dates) + # it works! + ts.resample("d").mean() + + +def test_upsample_apply_functions(unit): + # #1596 + rng = date_range("2012-06-12", periods=4, freq="h").as_unit(unit) + + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + + result = ts.resample("20min").aggregate(["mean", "sum"]) + assert isinstance(result, DataFrame) + + +def test_resample_not_monotonic(unit): + rng = date_range("2012-06-12", periods=200, freq="h").as_unit(unit) + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + + ts = ts.take(np.random.default_rng(2).permutation(len(ts))) + + result = ts.resample("D").sum() + exp = ts.sort_index().resample("D").sum() + tm.assert_series_equal(result, exp) + + +@pytest.mark.parametrize( + "dtype", + [ + "int64", + "int32", + "float64", + pytest.param( + "float32", + marks=pytest.mark.xfail( + reason="Empty groups cause x.mean() to return float64" + ), + ), + ], +) +def test_resample_median_bug_1688(dtype, unit): + # GH#55958 + dti = DatetimeIndex( + [datetime(2012, 1, 1, 0, 0, 0), datetime(2012, 1, 1, 0, 5, 0)] + ).as_unit(unit) + df = DataFrame( + [1, 2], + index=dti, + dtype=dtype, + ) + + result = df.resample("min").apply(lambda x: x.mean()) + exp = df.asfreq("min") + tm.assert_frame_equal(result, exp) + + result = df.resample("min").median() + exp = df.asfreq("min") + tm.assert_frame_equal(result, exp) + + +def test_how_lambda_functions(simple_date_range_series, unit): + ts = simple_date_range_series("1/1/2000", "4/1/2000") + ts.index = ts.index.as_unit(unit) + + result = ts.resample("ME").apply(lambda x: x.mean()) + exp = ts.resample("ME").mean() + tm.assert_series_equal(result, exp) + + foo_exp = ts.resample("ME").mean() + foo_exp.name = "foo" + bar_exp = ts.resample("ME").std() + bar_exp.name = "bar" + + result = ts.resample("ME").apply([lambda x: x.mean(), lambda x: x.std(ddof=1)]) + result.columns = ["foo", "bar"] + tm.assert_series_equal(result["foo"], foo_exp) + tm.assert_series_equal(result["bar"], bar_exp) + + # this is a MI Series, so comparing the names of the results + # doesn't make sense + result = ts.resample("ME").aggregate( + {"foo": lambda x: x.mean(), "bar": lambda x: x.std(ddof=1)} + ) + tm.assert_series_equal(result["foo"], foo_exp, check_names=False) + tm.assert_series_equal(result["bar"], bar_exp, check_names=False) + + +def test_resample_unequal_times(unit): + # #1772 + start = datetime(1999, 3, 1, 5) + # end hour is less than start + end = datetime(2012, 7, 31, 4) + bad_ind = date_range(start, end, freq="30min").as_unit(unit) + df = DataFrame({"close": 1}, index=bad_ind) + + # it works! + df.resample("YS").sum() + + +def test_resample_consistency(unit): + # GH 6418 + # resample with bfill / limit / reindex consistency + + i30 = date_range("2002-02-02", periods=4, freq="30min").as_unit(unit) + s = Series(np.arange(4.0), index=i30) + s.iloc[2] = np.nan + + # Upsample by factor 3 with reindex() and resample() methods: + i10 = date_range(i30[0], i30[-1], freq="10min").as_unit(unit) + + s10 = s.reindex(index=i10, method="bfill") + s10_2 = s.reindex(index=i10, method="bfill", limit=2) + rl = s.reindex_like(s10, method="bfill", limit=2) + r10_2 = s.resample("10Min").bfill(limit=2) + r10 = s.resample("10Min").bfill() + + # s10_2, r10, r10_2, rl should all be equal + tm.assert_series_equal(s10_2, r10) + tm.assert_series_equal(s10_2, r10_2) + tm.assert_series_equal(s10_2, rl) + + +dates1: list[DatetimeNaTType] = [ + datetime(2014, 10, 1), + datetime(2014, 9, 3), + datetime(2014, 11, 5), + datetime(2014, 9, 5), + datetime(2014, 10, 8), + datetime(2014, 7, 15), +] + +dates2: list[DatetimeNaTType] = ( + dates1[:2] + [pd.NaT] + dates1[2:4] + [pd.NaT] + dates1[4:] +) +dates3 = [pd.NaT] + dates1 + [pd.NaT] + + +@pytest.mark.parametrize("dates", [dates1, dates2, dates3]) +def test_resample_timegrouper(dates, unit): + # GH 7227 + dates = DatetimeIndex(dates).as_unit(unit) + df = DataFrame({"A": dates, "B": np.arange(len(dates))}) + result = df.set_index("A").resample("ME").count() + exp_idx = DatetimeIndex( + ["2014-07-31", "2014-08-31", "2014-09-30", "2014-10-31", "2014-11-30"], + freq="ME", + name="A", + ).as_unit(unit) + expected = DataFrame({"B": [1, 0, 2, 2, 1]}, index=exp_idx) + if df["A"].isna().any(): + expected.index = expected.index._with_freq(None) + tm.assert_frame_equal(result, expected) + + result = df.groupby(Grouper(freq="ME", key="A")).count() + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("dates", [dates1, dates2, dates3]) +def test_resample_timegrouper2(dates, unit): + dates = DatetimeIndex(dates).as_unit(unit) + + df = DataFrame({"A": dates, "B": np.arange(len(dates)), "C": np.arange(len(dates))}) + result = df.set_index("A").resample("ME").count() + + exp_idx = DatetimeIndex( + ["2014-07-31", "2014-08-31", "2014-09-30", "2014-10-31", "2014-11-30"], + freq="ME", + name="A", + ).as_unit(unit) + expected = DataFrame( + {"B": [1, 0, 2, 2, 1], "C": [1, 0, 2, 2, 1]}, + index=exp_idx, + columns=["B", "C"], + ) + if df["A"].isna().any(): + expected.index = expected.index._with_freq(None) + tm.assert_frame_equal(result, expected) + + result = df.groupby(Grouper(freq="ME", key="A")).count() + tm.assert_frame_equal(result, expected) + + +def test_resample_nunique(unit): + # GH 12352 + df = DataFrame( + { + "ID": { + Timestamp("2015-06-05 00:00:00"): "0010100903", + Timestamp("2015-06-08 00:00:00"): "0010150847", + }, + "DATE": { + Timestamp("2015-06-05 00:00:00"): "2015-06-05", + Timestamp("2015-06-08 00:00:00"): "2015-06-08", + }, + } + ) + df.index = df.index.as_unit(unit) + r = df.resample("D") + g = df.groupby(Grouper(freq="D")) + expected = df.groupby(Grouper(freq="D")).ID.apply(lambda x: x.nunique()) + assert expected.name == "ID" + + for t in [r, g]: + result = t.ID.nunique() + tm.assert_series_equal(result, expected) + + result = df.ID.resample("D").nunique() + tm.assert_series_equal(result, expected) + + result = df.ID.groupby(Grouper(freq="D")).nunique() + tm.assert_series_equal(result, expected) + + +def test_resample_nunique_preserves_column_level_names(unit): + # see gh-23222 + df = DataFrame( + np.random.default_rng(2).standard_normal((5, 4)), + columns=Index(list("ABCD"), dtype=object), + index=date_range("2000-01-01", periods=5, freq="D"), + ).abs() + df.index = df.index.as_unit(unit) + df.columns = pd.MultiIndex.from_arrays( + [df.columns.tolist()] * 2, names=["lev0", "lev1"] + ) + result = df.resample("1h").nunique() + tm.assert_index_equal(df.columns, result.columns) + + +@pytest.mark.parametrize( + "func", + [ + lambda x: x.nunique(), + lambda x: x.agg(Series.nunique), + lambda x: x.agg("nunique"), + ], + ids=["nunique", "series_nunique", "nunique_str"], +) +def test_resample_nunique_with_date_gap(func, unit): + # GH 13453 + # Since all elements are unique, these should all be the same + index = date_range("1-1-2000", "2-15-2000", freq="h").as_unit(unit) + index2 = date_range("4-15-2000", "5-15-2000", freq="h").as_unit(unit) + index3 = index.append(index2) + s = Series(range(len(index3)), index=index3, dtype="int64") + r = s.resample("ME") + result = r.count() + expected = func(r) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("n", [10000, 100000]) +@pytest.mark.parametrize("k", [10, 100, 1000]) +def test_resample_group_info(n, k, unit): + # GH10914 + + # use a fixed seed to always have the same uniques + prng = np.random.default_rng(2) + + dr = date_range(start="2015-08-27", periods=n // 10, freq="min").as_unit(unit) + ts = Series(prng.integers(0, n // k, n).astype("int64"), index=prng.choice(dr, n)) + + left = ts.resample("30min").nunique() + ix = date_range(start=ts.index.min(), end=ts.index.max(), freq="30min").as_unit( + unit + ) + + vals = ts.values + bins = np.searchsorted(ix.values, ts.index, side="right") + + sorter = np.lexsort((vals, bins)) + vals, bins = vals[sorter], bins[sorter] + + mask = np.r_[True, vals[1:] != vals[:-1]] + mask |= np.r_[True, bins[1:] != bins[:-1]] + + arr = np.bincount(bins[mask] - 1, minlength=len(ix)).astype("int64", copy=False) + right = Series(arr, index=ix) + + tm.assert_series_equal(left, right) + + +def test_resample_size(unit): + n = 10000 + dr = date_range("2015-09-19", periods=n, freq="min").as_unit(unit) + ts = Series( + np.random.default_rng(2).standard_normal(n), + index=np.random.default_rng(2).choice(dr, n), + ) + + left = ts.resample("7min").size() + ix = date_range(start=left.index.min(), end=ts.index.max(), freq="7min").as_unit( + unit + ) + + bins = np.searchsorted(ix.values, ts.index.values, side="right") + val = np.bincount(bins, minlength=len(ix) + 1)[1:].astype("int64", copy=False) + + right = Series(val, index=ix) + tm.assert_series_equal(left, right) + + +def test_resample_across_dst(): + # The test resamples a DatetimeIndex with values before and after a + # DST change + # Issue: 14682 + + # The DatetimeIndex we will start with + # (note that DST happens at 03:00+02:00 -> 02:00+01:00) + # 2016-10-30 02:23:00+02:00, 2016-10-30 02:23:00+01:00 + df1 = DataFrame([1477786980, 1477790580], columns=["ts"]) + dti1 = DatetimeIndex( + pd.to_datetime(df1.ts, unit="s") + .dt.tz_localize("UTC") + .dt.tz_convert("Europe/Madrid") + ) + + # The expected DatetimeIndex after resampling. + # 2016-10-30 02:00:00+02:00, 2016-10-30 02:00:00+01:00 + df2 = DataFrame([1477785600, 1477789200], columns=["ts"]) + dti2 = DatetimeIndex( + pd.to_datetime(df2.ts, unit="s") + .dt.tz_localize("UTC") + .dt.tz_convert("Europe/Madrid"), + freq="h", + ) + df = DataFrame([5, 5], index=dti1) + + result = df.resample(rule="h").sum() + expected = DataFrame([5, 5], index=dti2) + + tm.assert_frame_equal(result, expected) + + +def test_groupby_with_dst_time_change(unit): + # GH 24972 + index = ( + DatetimeIndex([1478064900001000000, 1480037118776792000], tz="UTC") + .tz_convert("America/Chicago") + .as_unit(unit) + ) + + df = DataFrame([1, 2], index=index) + result = df.groupby(Grouper(freq="1d")).last() + expected_index_values = date_range( + "2016-11-02", "2016-11-24", freq="d", tz="America/Chicago" + ).as_unit(unit) + + index = DatetimeIndex(expected_index_values) + expected = DataFrame([1.0] + ([np.nan] * 21) + [2.0], index=index) + tm.assert_frame_equal(result, expected) + + +def test_resample_dst_anchor(unit): + # 5172 + dti = DatetimeIndex([datetime(2012, 11, 4, 23)], tz="US/Eastern").as_unit(unit) + df = DataFrame([5], index=dti) + + dti = DatetimeIndex(df.index.normalize(), freq="D").as_unit(unit) + expected = DataFrame([5], index=dti) + tm.assert_frame_equal(df.resample(rule="D").sum(), expected) + df.resample(rule="MS").sum() + tm.assert_frame_equal( + df.resample(rule="MS").sum(), + DataFrame( + [5], + index=DatetimeIndex( + [datetime(2012, 11, 1)], tz="US/Eastern", freq="MS" + ).as_unit(unit), + ), + ) + + +def test_resample_dst_anchor2(unit): + dti = date_range( + "2013-09-30", "2013-11-02", freq="30Min", tz="Europe/Paris" + ).as_unit(unit) + values = range(dti.size) + df = DataFrame({"a": values, "b": values, "c": values}, index=dti, dtype="int64") + how = {"a": "min", "b": "max", "c": "count"} + + rs = df.resample("W-MON") + result = rs.agg(how)[["a", "b", "c"]] + expected = DataFrame( + { + "a": [0, 48, 384, 720, 1056, 1394], + "b": [47, 383, 719, 1055, 1393, 1586], + "c": [48, 336, 336, 336, 338, 193], + }, + index=date_range( + "9/30/2013", "11/4/2013", freq="W-MON", tz="Europe/Paris" + ).as_unit(unit), + ) + tm.assert_frame_equal( + result, + expected, + "W-MON Frequency", + ) + + rs2 = df.resample("2W-MON") + result2 = rs2.agg(how)[["a", "b", "c"]] + expected2 = DataFrame( + { + "a": [0, 48, 720, 1394], + "b": [47, 719, 1393, 1586], + "c": [48, 672, 674, 193], + }, + index=date_range( + "9/30/2013", "11/11/2013", freq="2W-MON", tz="Europe/Paris" + ).as_unit(unit), + ) + tm.assert_frame_equal( + result2, + expected2, + "2W-MON Frequency", + ) + + rs3 = df.resample("MS") + result3 = rs3.agg(how)[["a", "b", "c"]] + expected3 = DataFrame( + {"a": [0, 48, 1538], "b": [47, 1537, 1586], "c": [48, 1490, 49]}, + index=date_range("9/1/2013", "11/1/2013", freq="MS", tz="Europe/Paris").as_unit( + unit + ), + ) + tm.assert_frame_equal( + result3, + expected3, + "MS Frequency", + ) + + rs4 = df.resample("2MS") + result4 = rs4.agg(how)[["a", "b", "c"]] + expected4 = DataFrame( + {"a": [0, 1538], "b": [1537, 1586], "c": [1538, 49]}, + index=date_range( + "9/1/2013", "11/1/2013", freq="2MS", tz="Europe/Paris" + ).as_unit(unit), + ) + tm.assert_frame_equal( + result4, + expected4, + "2MS Frequency", + ) + + df_daily = df["10/26/2013":"10/29/2013"] + rs_d = df_daily.resample("D") + result_d = rs_d.agg({"a": "min", "b": "max", "c": "count"})[["a", "b", "c"]] + expected_d = DataFrame( + { + "a": [1248, 1296, 1346, 1394], + "b": [1295, 1345, 1393, 1441], + "c": [48, 50, 48, 48], + }, + index=date_range( + "10/26/2013", "10/29/2013", freq="D", tz="Europe/Paris" + ).as_unit(unit), + ) + tm.assert_frame_equal( + result_d, + expected_d, + "D Frequency", + ) + + +def test_downsample_across_dst(unit): + # GH 8531 + tz = pytz.timezone("Europe/Berlin") + dt = datetime(2014, 10, 26) + dates = date_range(tz.localize(dt), periods=4, freq="2h").as_unit(unit) + result = Series(5, index=dates).resample("h").mean() + expected = Series( + [5.0, np.nan] * 3 + [5.0], + index=date_range(tz.localize(dt), periods=7, freq="h").as_unit(unit), + ) + tm.assert_series_equal(result, expected) + + +def test_downsample_across_dst_weekly(unit): + # GH 9119, GH 21459 + df = DataFrame( + index=DatetimeIndex( + ["2017-03-25", "2017-03-26", "2017-03-27", "2017-03-28", "2017-03-29"], + tz="Europe/Amsterdam", + ).as_unit(unit), + data=[11, 12, 13, 14, 15], + ) + result = df.resample("1W").sum() + expected = DataFrame( + [23, 42], + index=DatetimeIndex( + ["2017-03-26", "2017-04-02"], tz="Europe/Amsterdam", freq="W" + ).as_unit(unit), + ) + tm.assert_frame_equal(result, expected) + + +def test_downsample_across_dst_weekly_2(unit): + # GH 9119, GH 21459 + idx = date_range("2013-04-01", "2013-05-01", tz="Europe/London", freq="h").as_unit( + unit + ) + s = Series(index=idx, dtype=np.float64) + result = s.resample("W").mean() + expected = Series( + index=date_range("2013-04-07", freq="W", periods=5, tz="Europe/London").as_unit( + unit + ), + dtype=np.float64, + ) + tm.assert_series_equal(result, expected) + + +def test_downsample_dst_at_midnight(unit): + # GH 25758 + start = datetime(2018, 11, 3, 12) + end = datetime(2018, 11, 5, 12) + index = date_range(start, end, freq="1h").as_unit(unit) + index = index.tz_localize("UTC").tz_convert("America/Havana") + data = list(range(len(index))) + dataframe = DataFrame(data, index=index) + result = dataframe.groupby(Grouper(freq="1D")).mean() + + dti = date_range("2018-11-03", periods=3).tz_localize( + "America/Havana", ambiguous=True + ) + dti = DatetimeIndex(dti, freq="D").as_unit(unit) + expected = DataFrame([7.5, 28.0, 44.5], index=dti) + tm.assert_frame_equal(result, expected) + + +def test_resample_with_nat(unit): + # GH 13020 + index = DatetimeIndex( + [ + pd.NaT, + "1970-01-01 00:00:00", + pd.NaT, + "1970-01-01 00:00:01", + "1970-01-01 00:00:02", + ] + ).as_unit(unit) + frame = DataFrame([2, 3, 5, 7, 11], index=index) + + index_1s = DatetimeIndex( + ["1970-01-01 00:00:00", "1970-01-01 00:00:01", "1970-01-01 00:00:02"] + ).as_unit(unit) + frame_1s = DataFrame([3.0, 7.0, 11.0], index=index_1s) + tm.assert_frame_equal(frame.resample("1s").mean(), frame_1s) + + index_2s = DatetimeIndex(["1970-01-01 00:00:00", "1970-01-01 00:00:02"]).as_unit( + unit + ) + frame_2s = DataFrame([5.0, 11.0], index=index_2s) + tm.assert_frame_equal(frame.resample("2s").mean(), frame_2s) + + index_3s = DatetimeIndex(["1970-01-01 00:00:00"]).as_unit(unit) + frame_3s = DataFrame([7.0], index=index_3s) + tm.assert_frame_equal(frame.resample("3s").mean(), frame_3s) + + tm.assert_frame_equal(frame.resample("60s").mean(), frame_3s) + + +def test_resample_datetime_values(unit): + # GH 13119 + # check that datetime dtype is preserved when NaT values are + # introduced by the resampling + + dates = [datetime(2016, 1, 15), datetime(2016, 1, 19)] + df = DataFrame({"timestamp": dates}, index=dates) + df.index = df.index.as_unit(unit) + + exp = Series( + [datetime(2016, 1, 15), pd.NaT, datetime(2016, 1, 19)], + index=date_range("2016-01-15", periods=3, freq="2D").as_unit(unit), + name="timestamp", + ) + + res = df.resample("2D").first()["timestamp"] + tm.assert_series_equal(res, exp) + res = df["timestamp"].resample("2D").first() + tm.assert_series_equal(res, exp) + + +def test_resample_apply_with_additional_args(series, unit): + # GH 14615 + def f(data, add_arg): + return np.mean(data) * add_arg + + series.index = series.index.as_unit(unit) + + multiplier = 10 + result = series.resample("D").apply(f, multiplier) + expected = series.resample("D").mean().multiply(multiplier) + tm.assert_series_equal(result, expected) + + # Testing as kwarg + result = series.resample("D").apply(f, add_arg=multiplier) + expected = series.resample("D").mean().multiply(multiplier) + tm.assert_series_equal(result, expected) + + +def test_resample_apply_with_additional_args2(): + # Testing dataframe + def f(data, add_arg): + return np.mean(data) * add_arg + + multiplier = 10 + + df = DataFrame({"A": 1, "B": 2}, index=date_range("2017", periods=10)) + msg = "DataFrameGroupBy.resample operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = df.groupby("A").resample("D").agg(f, multiplier).astype(float) + msg = "DataFrameGroupBy.resample operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + expected = df.groupby("A").resample("D").mean().multiply(multiplier) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("k", [1, 2, 3]) +@pytest.mark.parametrize( + "n1, freq1, n2, freq2", + [ + (30, "s", 0.5, "Min"), + (60, "s", 1, "Min"), + (3600, "s", 1, "h"), + (60, "Min", 1, "h"), + (21600, "s", 0.25, "D"), + (86400, "s", 1, "D"), + (43200, "s", 0.5, "D"), + (1440, "Min", 1, "D"), + (12, "h", 0.5, "D"), + (24, "h", 1, "D"), + ], +) +def test_resample_equivalent_offsets(n1, freq1, n2, freq2, k, unit): + # GH 24127 + n1_ = n1 * k + n2_ = n2 * k + dti = date_range("1991-09-05", "1991-09-12", freq=freq1).as_unit(unit) + ser = Series(range(len(dti)), index=dti) + + result1 = ser.resample(str(n1_) + freq1).mean() + result2 = ser.resample(str(n2_) + freq2).mean() + tm.assert_series_equal(result1, result2) + + +@pytest.mark.parametrize( + "first,last,freq,exp_first,exp_last", + [ + ("19910905", "19920406", "D", "19910905", "19920407"), + ("19910905 00:00", "19920406 06:00", "D", "19910905", "19920407"), + ("19910905 06:00", "19920406 06:00", "h", "19910905 06:00", "19920406 07:00"), + ("19910906", "19920406", "ME", "19910831", "19920430"), + ("19910831", "19920430", "ME", "19910831", "19920531"), + ("1991-08", "1992-04", "ME", "19910831", "19920531"), + ], +) +def test_get_timestamp_range_edges(first, last, freq, exp_first, exp_last, unit): + first = Period(first) + first = first.to_timestamp(first.freq).as_unit(unit) + last = Period(last) + last = last.to_timestamp(last.freq).as_unit(unit) + + exp_first = Timestamp(exp_first) + exp_last = Timestamp(exp_last) + + freq = pd.tseries.frequencies.to_offset(freq) + result = _get_timestamp_range_edges(first, last, freq, unit="ns") + expected = (exp_first, exp_last) + assert result == expected + + +@pytest.mark.parametrize("duplicates", [True, False]) +def test_resample_apply_product(duplicates, unit): + # GH 5586 + index = date_range(start="2012-01-31", freq="ME", periods=12).as_unit(unit) + + ts = Series(range(12), index=index) + df = DataFrame({"A": ts, "B": ts + 2}) + if duplicates: + df.columns = ["A", "A"] + + msg = "using DatetimeIndexResampler.prod" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = df.resample("QE").apply(np.prod) + expected = DataFrame( + np.array([[0, 24], [60, 210], [336, 720], [990, 1716]], dtype=np.int64), + index=DatetimeIndex( + ["2012-03-31", "2012-06-30", "2012-09-30", "2012-12-31"], freq="QE-DEC" + ).as_unit(unit), + columns=df.columns, + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "first,last,freq_in,freq_out,exp_last", + [ + ( + "2020-03-28", + "2020-03-31", + "D", + "24h", + "2020-03-30 01:00", + ), # includes transition into DST + ( + "2020-03-28", + "2020-10-27", + "D", + "24h", + "2020-10-27 00:00", + ), # includes transition into and out of DST + ( + "2020-10-25", + "2020-10-27", + "D", + "24h", + "2020-10-26 23:00", + ), # includes transition out of DST + ( + "2020-03-28", + "2020-03-31", + "24h", + "D", + "2020-03-30 00:00", + ), # same as above, but from 24H to D + ("2020-03-28", "2020-10-27", "24h", "D", "2020-10-27 00:00"), + ("2020-10-25", "2020-10-27", "24h", "D", "2020-10-26 00:00"), + ], +) +def test_resample_calendar_day_with_dst( + first: str, last: str, freq_in: str, freq_out: str, exp_last: str, unit +): + # GH 35219 + ts = Series( + 1.0, date_range(first, last, freq=freq_in, tz="Europe/Amsterdam").as_unit(unit) + ) + result = ts.resample(freq_out).ffill() + expected = Series( + 1.0, + date_range(first, exp_last, freq=freq_out, tz="Europe/Amsterdam").as_unit(unit), + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("func", ["min", "max", "first", "last"]) +def test_resample_aggregate_functions_min_count(func, unit): + # GH#37768 + index = date_range(start="2020", freq="ME", periods=3).as_unit(unit) + ser = Series([1, np.nan, np.nan], index) + result = getattr(ser.resample("QE"), func)(min_count=2) + expected = Series( + [np.nan], + index=DatetimeIndex(["2020-03-31"], freq="QE-DEC").as_unit(unit), + ) + tm.assert_series_equal(result, expected) + + +def test_resample_unsigned_int(any_unsigned_int_numpy_dtype, unit): + # gh-43329 + df = DataFrame( + index=date_range(start="2000-01-01", end="2000-01-03 23", freq="12h").as_unit( + unit + ), + columns=["x"], + data=[0, 1, 0] * 2, + dtype=any_unsigned_int_numpy_dtype, + ) + df = df.loc[(df.index < "2000-01-02") | (df.index > "2000-01-03"), :] + + result = df.resample("D").max() + + expected = DataFrame( + [1, np.nan, 0], + columns=["x"], + index=date_range(start="2000-01-01", end="2000-01-03 23", freq="D").as_unit( + unit + ), + ) + tm.assert_frame_equal(result, expected) + + +def test_long_rule_non_nano(): + # https://github.com/pandas-dev/pandas/issues/51024 + idx = date_range("0300-01-01", "2000-01-01", unit="s", freq="100YE") + ser = Series([1, 4, 2, 8, 5, 7, 1, 4, 2, 8, 5, 7, 1, 4, 2, 8, 5], index=idx) + result = ser.resample("200YE").mean() + expected_idx = DatetimeIndex( + np.array( + [ + "0300-12-31", + "0500-12-31", + "0700-12-31", + "0900-12-31", + "1100-12-31", + "1300-12-31", + "1500-12-31", + "1700-12-31", + "1900-12-31", + ] + ).astype("datetime64[s]"), + freq="200YE-DEC", + ) + expected = Series([1.0, 3.0, 6.5, 4.0, 3.0, 6.5, 4.0, 3.0, 6.5], index=expected_idx) + tm.assert_series_equal(result, expected) + + +def test_resample_empty_series_with_tz(): + # GH#53664 + df = DataFrame({"ts": [], "values": []}).astype( + {"ts": "datetime64[ns, Atlantic/Faroe]"} + ) + result = df.resample("2MS", on="ts", closed="left", label="left", origin="start")[ + "values" + ].sum() + + expected_idx = DatetimeIndex( + [], freq="2MS", name="ts", dtype="datetime64[ns, Atlantic/Faroe]" + ) + expected = Series([], index=expected_idx, name="values", dtype="float64") + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "freq, freq_depr", + [ + ("2ME", "2M"), + ("2QE", "2Q"), + ("2QE-SEP", "2Q-SEP"), + ("1YE", "1Y"), + ("2YE-MAR", "2Y-MAR"), + ("1YE", "1A"), + ("2YE-MAR", "2A-MAR"), + ], +) +def test_resample_M_Q_Y_A_deprecated(freq, freq_depr): + # GH#9586 + depr_msg = f"'{freq_depr[1:]}' is deprecated and will be removed " + f"in a future version, please use '{freq[1:]}' instead." + + s = Series(range(10), index=date_range("20130101", freq="d", periods=10)) + expected = s.resample(freq).mean() + with tm.assert_produces_warning(FutureWarning, match=depr_msg): + result = s.resample(freq_depr).mean() + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "freq, freq_depr", + [ + ("2BME", "2BM"), + ("2BQE", "2BQ"), + ("2BQE-MAR", "2BQ-MAR"), + ], +) +def test_resample_BM_BQ_deprecated(freq, freq_depr): + # GH#52064 + depr_msg = f"'{freq_depr[1:]}' is deprecated and will be removed " + f"in a future version, please use '{freq[1:]}' instead." + + s = Series(range(10), index=date_range("20130101", freq="d", periods=10)) + expected = s.resample(freq).mean() + with tm.assert_produces_warning(FutureWarning, match=depr_msg): + result = s.resample(freq_depr).mean() + tm.assert_series_equal(result, expected) + + +def test_resample_ms_closed_right(unit): + # https://github.com/pandas-dev/pandas/issues/55271 + dti = date_range(start="2020-01-31", freq="1min", periods=6000, unit=unit) + df = DataFrame({"ts": dti}, index=dti) + grouped = df.resample("MS", closed="right") + result = grouped.last() + exp_dti = DatetimeIndex( + [datetime(2020, 1, 1), datetime(2020, 2, 1)], freq="MS" + ).as_unit(unit) + expected = DataFrame( + {"ts": [datetime(2020, 2, 1), datetime(2020, 2, 4, 3, 59)]}, + index=exp_dti, + ).astype(f"M8[{unit}]") + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("freq", ["B", "C"]) +def test_resample_c_b_closed_right(freq: str, unit): + # https://github.com/pandas-dev/pandas/issues/55281 + dti = date_range(start="2020-01-31", freq="1min", periods=6000, unit=unit) + df = DataFrame({"ts": dti}, index=dti) + grouped = df.resample(freq, closed="right") + result = grouped.last() + + exp_dti = DatetimeIndex( + [ + datetime(2020, 1, 30), + datetime(2020, 1, 31), + datetime(2020, 2, 3), + datetime(2020, 2, 4), + ], + freq=freq, + ).as_unit(unit) + expected = DataFrame( + { + "ts": [ + datetime(2020, 1, 31), + datetime(2020, 2, 3), + datetime(2020, 2, 4), + datetime(2020, 2, 4, 3, 59), + ] + }, + index=exp_dti, + ).astype(f"M8[{unit}]") + tm.assert_frame_equal(result, expected) + + +def test_resample_b_55282(unit): + # https://github.com/pandas-dev/pandas/issues/55282 + dti = date_range("2023-09-26", periods=6, freq="12h", unit=unit) + ser = Series([1, 2, 3, 4, 5, 6], index=dti) + result = ser.resample("B", closed="right", label="right").mean() + + exp_dti = DatetimeIndex( + [ + datetime(2023, 9, 26), + datetime(2023, 9, 27), + datetime(2023, 9, 28), + datetime(2023, 9, 29), + ], + freq="B", + ).as_unit(unit) + expected = Series( + [1.0, 2.5, 4.5, 6.0], + index=exp_dti, + ) + tm.assert_series_equal(result, expected) + + +@td.skip_if_no("pyarrow") +@pytest.mark.parametrize( + "tz", + [ + None, + pytest.param( + "UTC", + marks=pytest.mark.xfail( + condition=is_platform_windows(), + reason="TODO: Set ARROW_TIMEZONE_DATABASE env var in CI", + ), + ), + ], +) +def test_arrow_timestamp_resample(tz): + # GH 56371 + idx = Series(date_range("2020-01-01", periods=5), dtype="timestamp[ns][pyarrow]") + if tz is not None: + idx = idx.dt.tz_localize(tz) + expected = Series(np.arange(5, dtype=np.float64), index=idx) + result = expected.resample("1D").mean() + tm.assert_series_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/resample/test_period_index.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/resample/test_period_index.py new file mode 100644 index 0000000000000000000000000000000000000000..6b7cce7d15a5b46ed121d6ea4e15c99878f35f97 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/resample/test_period_index.py @@ -0,0 +1,1100 @@ +from datetime import datetime +import warnings + +import dateutil +import numpy as np +import pytest +import pytz + +from pandas._libs.tslibs.ccalendar import ( + DAYS, + MONTHS, +) +from pandas._libs.tslibs.period import IncompatibleFrequency +from pandas.errors import InvalidIndexError + +import pandas as pd +from pandas import ( + DataFrame, + Series, + Timestamp, +) +import pandas._testing as tm +from pandas.core.indexes.datetimes import date_range +from pandas.core.indexes.period import ( + Period, + PeriodIndex, + period_range, +) +from pandas.core.resample import _get_period_range_edges + +from pandas.tseries import offsets + +pytestmark = pytest.mark.filterwarnings( + "ignore:Resampling with a PeriodIndex is deprecated:FutureWarning" +) + + +@pytest.fixture() +def _index_factory(): + return period_range + + +@pytest.fixture +def _series_name(): + return "pi" + + +@pytest.fixture +def simple_period_range_series(): + """ + Series with period range index and random data for test purposes. + """ + + def _simple_period_range_series(start, end, freq="D"): + with warnings.catch_warnings(): + # suppress Period[B] deprecation warning + msg = "|".join(["Period with BDay freq", r"PeriodDtype\[B\] is deprecated"]) + warnings.filterwarnings( + "ignore", + msg, + category=FutureWarning, + ) + rng = period_range(start, end, freq=freq) + return Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + + return _simple_period_range_series + + +class TestPeriodIndex: + @pytest.mark.parametrize("freq", ["2D", "1h", "2h"]) + @pytest.mark.parametrize("kind", ["period", None, "timestamp"]) + def test_asfreq(self, series_and_frame, freq, kind): + # GH 12884, 15944 + # make sure .asfreq() returns PeriodIndex (except kind='timestamp') + + obj = series_and_frame + if kind == "timestamp": + expected = obj.to_timestamp().resample(freq).asfreq() + else: + start = obj.index[0].to_timestamp(how="start") + end = (obj.index[-1] + obj.index.freq).to_timestamp(how="start") + new_index = date_range(start=start, end=end, freq=freq, inclusive="left") + expected = obj.to_timestamp().reindex(new_index).to_period(freq) + msg = "The 'kind' keyword in (Series|DataFrame).resample is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = obj.resample(freq, kind=kind).asfreq() + tm.assert_almost_equal(result, expected) + + def test_asfreq_fill_value(self, series): + # test for fill value during resampling, issue 3715 + + s = series + new_index = date_range( + s.index[0].to_timestamp(how="start"), + (s.index[-1]).to_timestamp(how="start"), + freq="1h", + ) + expected = s.to_timestamp().reindex(new_index, fill_value=4.0) + msg = "The 'kind' keyword in Series.resample is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = s.resample("1h", kind="timestamp").asfreq(fill_value=4.0) + tm.assert_series_equal(result, expected) + + frame = s.to_frame("value") + new_index = date_range( + frame.index[0].to_timestamp(how="start"), + (frame.index[-1]).to_timestamp(how="start"), + freq="1h", + ) + expected = frame.to_timestamp().reindex(new_index, fill_value=3.0) + msg = "The 'kind' keyword in DataFrame.resample is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = frame.resample("1h", kind="timestamp").asfreq(fill_value=3.0) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("freq", ["h", "12h", "2D", "W"]) + @pytest.mark.parametrize("kind", [None, "period", "timestamp"]) + @pytest.mark.parametrize("kwargs", [{"on": "date"}, {"level": "d"}]) + def test_selection(self, index, freq, kind, kwargs): + # This is a bug, these should be implemented + # GH 14008 + rng = np.arange(len(index), dtype=np.int64) + df = DataFrame( + {"date": index, "a": rng}, + index=pd.MultiIndex.from_arrays([rng, index], names=["v", "d"]), + ) + msg = ( + "Resampling from level= or on= selection with a PeriodIndex is " + r"not currently supported, use \.set_index\(\.\.\.\) to " + "explicitly set index" + ) + depr_msg = "The 'kind' keyword in DataFrame.resample is deprecated" + with pytest.raises(NotImplementedError, match=msg): + with tm.assert_produces_warning(FutureWarning, match=depr_msg): + df.resample(freq, kind=kind, **kwargs) + + @pytest.mark.parametrize("month", MONTHS) + @pytest.mark.parametrize("meth", ["ffill", "bfill"]) + @pytest.mark.parametrize("conv", ["start", "end"]) + @pytest.mark.parametrize( + ("offset", "period"), [("D", "D"), ("B", "B"), ("ME", "M"), ("QE", "Q")] + ) + def test_annual_upsample_cases( + self, offset, period, conv, meth, month, simple_period_range_series + ): + ts = simple_period_range_series("1/1/1990", "12/31/1991", freq=f"Y-{month}") + warn = FutureWarning if period == "B" else None + msg = r"PeriodDtype\[B\] is deprecated" + if warn is None: + msg = "Resampling with a PeriodIndex is deprecated" + warn = FutureWarning + with tm.assert_produces_warning(warn, match=msg): + result = getattr(ts.resample(period, convention=conv), meth)() + expected = result.to_timestamp(period, how=conv) + expected = expected.asfreq(offset, meth).to_period() + tm.assert_series_equal(result, expected) + + def test_basic_downsample(self, simple_period_range_series): + ts = simple_period_range_series("1/1/1990", "6/30/1995", freq="M") + result = ts.resample("Y-DEC").mean() + + expected = ts.groupby(ts.index.year).mean() + expected.index = period_range("1/1/1990", "6/30/1995", freq="Y-DEC") + tm.assert_series_equal(result, expected) + + # this is ok + tm.assert_series_equal(ts.resample("Y-DEC").mean(), result) + tm.assert_series_equal(ts.resample("Y").mean(), result) + + @pytest.mark.parametrize( + "rule,expected_error_msg", + [ + ("Y-DEC", ""), + ("Q-MAR", ""), + ("M", ""), + ("w-thu", ""), + ], + ) + def test_not_subperiod(self, simple_period_range_series, rule, expected_error_msg): + # These are incompatible period rules for resampling + ts = simple_period_range_series("1/1/1990", "6/30/1995", freq="w-wed") + msg = ( + "Frequency cannot be resampled to " + f"{expected_error_msg}, as they are not sub or super periods" + ) + with pytest.raises(IncompatibleFrequency, match=msg): + ts.resample(rule).mean() + + @pytest.mark.parametrize("freq", ["D", "2D"]) + def test_basic_upsample(self, freq, simple_period_range_series): + ts = simple_period_range_series("1/1/1990", "6/30/1995", freq="M") + result = ts.resample("Y-DEC").mean() + + msg = "The 'convention' keyword in Series.resample is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + resampled = result.resample(freq, convention="end").ffill() + expected = result.to_timestamp(freq, how="end") + expected = expected.asfreq(freq, "ffill").to_period(freq) + tm.assert_series_equal(resampled, expected) + + def test_upsample_with_limit(self): + rng = period_range("1/1/2000", periods=5, freq="Y") + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), rng) + + msg = "The 'convention' keyword in Series.resample is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = ts.resample("M", convention="end").ffill(limit=2) + expected = ts.asfreq("M").reindex(result.index, method="ffill", limit=2) + tm.assert_series_equal(result, expected) + + def test_annual_upsample(self, simple_period_range_series): + ts = simple_period_range_series("1/1/1990", "12/31/1995", freq="Y-DEC") + df = DataFrame({"a": ts}) + rdf = df.resample("D").ffill() + exp = df["a"].resample("D").ffill() + tm.assert_series_equal(rdf["a"], exp) + + def test_annual_upsample2(self): + rng = period_range("2000", "2003", freq="Y-DEC") + ts = Series([1, 2, 3, 4], index=rng) + + result = ts.resample("M").ffill() + ex_index = period_range("2000-01", "2003-12", freq="M") + + expected = ts.asfreq("M", how="start").reindex(ex_index, method="ffill") + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("month", MONTHS) + @pytest.mark.parametrize("convention", ["start", "end"]) + @pytest.mark.parametrize( + ("offset", "period"), [("D", "D"), ("B", "B"), ("ME", "M")] + ) + def test_quarterly_upsample( + self, month, offset, period, convention, simple_period_range_series + ): + freq = f"Q-{month}" + ts = simple_period_range_series("1/1/1990", "12/31/1995", freq=freq) + warn = FutureWarning if period == "B" else None + msg = r"PeriodDtype\[B\] is deprecated" + if warn is None: + msg = "Resampling with a PeriodIndex is deprecated" + warn = FutureWarning + with tm.assert_produces_warning(warn, match=msg): + result = ts.resample(period, convention=convention).ffill() + expected = result.to_timestamp(period, how=convention) + expected = expected.asfreq(offset, "ffill").to_period() + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("target", ["D", "B"]) + @pytest.mark.parametrize("convention", ["start", "end"]) + def test_monthly_upsample(self, target, convention, simple_period_range_series): + ts = simple_period_range_series("1/1/1990", "12/31/1995", freq="M") + + warn = None if target == "D" else FutureWarning + msg = r"PeriodDtype\[B\] is deprecated" + if warn is None: + msg = "Resampling with a PeriodIndex is deprecated" + warn = FutureWarning + with tm.assert_produces_warning(warn, match=msg): + result = ts.resample(target, convention=convention).ffill() + expected = result.to_timestamp(target, how=convention) + expected = expected.asfreq(target, "ffill").to_period() + tm.assert_series_equal(result, expected) + + def test_resample_basic(self): + # GH3609 + s = Series( + range(100), + index=date_range("20130101", freq="s", periods=100, name="idx"), + dtype="float", + ) + s[10:30] = np.nan + index = PeriodIndex( + [Period("2013-01-01 00:00", "min"), Period("2013-01-01 00:01", "min")], + name="idx", + ) + expected = Series([34.5, 79.5], index=index) + msg = "The 'kind' keyword in Series.resample is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = s.to_period().resample("min", kind="period").mean() + tm.assert_series_equal(result, expected) + with tm.assert_produces_warning(FutureWarning, match=msg): + result2 = s.resample("min", kind="period").mean() + tm.assert_series_equal(result2, expected) + + @pytest.mark.parametrize( + "freq,expected_vals", [("M", [31, 29, 31, 9]), ("2M", [31 + 29, 31 + 9])] + ) + def test_resample_count(self, freq, expected_vals): + # GH12774 + series = Series(1, index=period_range(start="2000", periods=100)) + result = series.resample(freq).count() + expected_index = period_range( + start="2000", freq=freq, periods=len(expected_vals) + ) + expected = Series(expected_vals, index=expected_index) + tm.assert_series_equal(result, expected) + + def test_resample_same_freq(self, resample_method): + # GH12770 + series = Series(range(3), index=period_range(start="2000", periods=3, freq="M")) + expected = series + + result = getattr(series.resample("M"), resample_method)() + tm.assert_series_equal(result, expected) + + def test_resample_incompat_freq(self): + msg = ( + "Frequency cannot be resampled to , " + "as they are not sub or super periods" + ) + pi = period_range(start="2000", periods=3, freq="M") + ser = Series(range(3), index=pi) + rs = ser.resample("W") + with pytest.raises(IncompatibleFrequency, match=msg): + # TODO: should this raise at the resample call instead of at the mean call? + rs.mean() + + @pytest.mark.parametrize( + "tz", + [ + pytz.timezone("America/Los_Angeles"), + dateutil.tz.gettz("America/Los_Angeles"), + ], + ) + def test_with_local_timezone(self, tz): + # see gh-5430 + local_timezone = tz + + start = datetime(year=2013, month=11, day=1, hour=0, minute=0, tzinfo=pytz.utc) + # 1 day later + end = datetime(year=2013, month=11, day=2, hour=0, minute=0, tzinfo=pytz.utc) + + index = date_range(start, end, freq="h", name="idx") + + series = Series(1, index=index) + series = series.tz_convert(local_timezone) + msg = "The 'kind' keyword in Series.resample is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = series.resample("D", kind="period").mean() + + # Create the expected series + # Index is moved back a day with the timezone conversion from UTC to + # Pacific + expected_index = ( + period_range(start=start, end=end, freq="D", name="idx") - offsets.Day() + ) + expected = Series(1.0, index=expected_index) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "tz", + [ + pytz.timezone("America/Los_Angeles"), + dateutil.tz.gettz("America/Los_Angeles"), + ], + ) + def test_resample_with_tz(self, tz, unit): + # GH 13238 + dti = date_range("2017-01-01", periods=48, freq="h", tz=tz, unit=unit) + ser = Series(2, index=dti) + result = ser.resample("D").mean() + exp_dti = pd.DatetimeIndex( + ["2017-01-01", "2017-01-02"], tz=tz, freq="D" + ).as_unit(unit) + expected = Series( + 2.0, + index=exp_dti, + ) + tm.assert_series_equal(result, expected) + # Especially assert that the timezone is LMT for pytz + assert result.index.tz == tz + + def test_resample_nonexistent_time_bin_edge(self): + # GH 19375 + index = date_range("2017-03-12", "2017-03-12 1:45:00", freq="15min") + s = Series(np.zeros(len(index)), index=index) + expected = s.tz_localize("US/Pacific") + expected.index = pd.DatetimeIndex(expected.index, freq="900s") + result = expected.resample("900s").mean() + tm.assert_series_equal(result, expected) + + def test_resample_nonexistent_time_bin_edge2(self): + # GH 23742 + index = date_range(start="2017-10-10", end="2017-10-20", freq="1h") + index = index.tz_localize("UTC").tz_convert("America/Sao_Paulo") + df = DataFrame(data=list(range(len(index))), index=index) + result = df.groupby(pd.Grouper(freq="1D")).count() + expected = date_range( + start="2017-10-09", + end="2017-10-20", + freq="D", + tz="America/Sao_Paulo", + nonexistent="shift_forward", + inclusive="left", + ) + tm.assert_index_equal(result.index, expected) + + def test_resample_ambiguous_time_bin_edge(self): + # GH 10117 + idx = date_range( + "2014-10-25 22:00:00", + "2014-10-26 00:30:00", + freq="30min", + tz="Europe/London", + ) + expected = Series(np.zeros(len(idx)), index=idx) + result = expected.resample("30min").mean() + tm.assert_series_equal(result, expected) + + def test_fill_method_and_how_upsample(self): + # GH2073 + s = Series( + np.arange(9, dtype="int64"), + index=date_range("2010-01-01", periods=9, freq="QE"), + ) + last = s.resample("ME").ffill() + both = s.resample("ME").ffill().resample("ME").last().astype("int64") + tm.assert_series_equal(last, both) + + @pytest.mark.parametrize("day", DAYS) + @pytest.mark.parametrize("target", ["D", "B"]) + @pytest.mark.parametrize("convention", ["start", "end"]) + def test_weekly_upsample(self, day, target, convention, simple_period_range_series): + freq = f"W-{day}" + ts = simple_period_range_series("1/1/1990", "12/31/1995", freq=freq) + + warn = None if target == "D" else FutureWarning + msg = r"PeriodDtype\[B\] is deprecated" + if warn is None: + msg = "Resampling with a PeriodIndex is deprecated" + warn = FutureWarning + with tm.assert_produces_warning(warn, match=msg): + result = ts.resample(target, convention=convention).ffill() + expected = result.to_timestamp(target, how=convention) + expected = expected.asfreq(target, "ffill").to_period() + tm.assert_series_equal(result, expected) + + def test_resample_to_timestamps(self, simple_period_range_series): + ts = simple_period_range_series("1/1/1990", "12/31/1995", freq="M") + + msg = "The 'kind' keyword in Series.resample is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = ts.resample("Y-DEC", kind="timestamp").mean() + expected = ts.to_timestamp(how="start").resample("YE-DEC").mean() + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("month", MONTHS) + def test_resample_to_quarterly(self, simple_period_range_series, month): + ts = simple_period_range_series("1990", "1992", freq=f"Y-{month}") + quar_ts = ts.resample(f"Q-{month}").ffill() + + stamps = ts.to_timestamp("D", how="start") + qdates = period_range( + ts.index[0].asfreq("D", "start"), + ts.index[-1].asfreq("D", "end"), + freq=f"Q-{month}", + ) + + expected = stamps.reindex(qdates.to_timestamp("D", "s"), method="ffill") + expected.index = qdates + + tm.assert_series_equal(quar_ts, expected) + + @pytest.mark.parametrize("how", ["start", "end"]) + def test_resample_to_quarterly_start_end(self, simple_period_range_series, how): + # conforms, but different month + ts = simple_period_range_series("1990", "1992", freq="Y-JUN") + msg = "The 'convention' keyword in Series.resample is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = ts.resample("Q-MAR", convention=how).ffill() + expected = ts.asfreq("Q-MAR", how=how) + expected = expected.reindex(result.index, method="ffill") + + # FIXME: don't leave commented-out + # .to_timestamp('D') + # expected = expected.resample('Q-MAR').ffill() + + tm.assert_series_equal(result, expected) + + def test_resample_fill_missing(self): + rng = PeriodIndex([2000, 2005, 2007, 2009], freq="Y") + + s = Series(np.random.default_rng(2).standard_normal(4), index=rng) + + stamps = s.to_timestamp() + filled = s.resample("Y").ffill() + expected = stamps.resample("YE").ffill().to_period("Y") + tm.assert_series_equal(filled, expected) + + def test_cant_fill_missing_dups(self): + rng = PeriodIndex([2000, 2005, 2005, 2007, 2007], freq="Y") + s = Series(np.random.default_rng(2).standard_normal(5), index=rng) + msg = "Reindexing only valid with uniquely valued Index objects" + with pytest.raises(InvalidIndexError, match=msg): + s.resample("Y").ffill() + + @pytest.mark.parametrize("freq", ["5min"]) + @pytest.mark.parametrize("kind", ["period", None, "timestamp"]) + def test_resample_5minute(self, freq, kind): + rng = period_range("1/1/2000", "1/5/2000", freq="min") + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + expected = ts.to_timestamp().resample(freq).mean() + if kind != "timestamp": + expected = expected.to_period(freq) + msg = "The 'kind' keyword in Series.resample is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = ts.resample(freq, kind=kind).mean() + tm.assert_series_equal(result, expected) + + def test_upsample_daily_business_daily(self, simple_period_range_series): + ts = simple_period_range_series("1/1/2000", "2/1/2000", freq="B") + + result = ts.resample("D").asfreq() + expected = ts.asfreq("D").reindex(period_range("1/3/2000", "2/1/2000")) + tm.assert_series_equal(result, expected) + + ts = simple_period_range_series("1/1/2000", "2/1/2000") + msg = "The 'convention' keyword in Series.resample is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = ts.resample("h", convention="s").asfreq() + exp_rng = period_range("1/1/2000", "2/1/2000 23:00", freq="h") + expected = ts.asfreq("h", how="s").reindex(exp_rng) + tm.assert_series_equal(result, expected) + + def test_resample_irregular_sparse(self): + dr = date_range(start="1/1/2012", freq="5min", periods=1000) + s = Series(np.array(100), index=dr) + # subset the data. + subset = s[:"2012-01-04 06:55"] + + result = subset.resample("10min").apply(len) + expected = s.resample("10min").apply(len).loc[result.index] + tm.assert_series_equal(result, expected) + + def test_resample_weekly_all_na(self): + rng = date_range("1/1/2000", periods=10, freq="W-WED") + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + + result = ts.resample("W-THU").asfreq() + + assert result.isna().all() + + result = ts.resample("W-THU").asfreq().ffill()[:-1] + expected = ts.asfreq("W-THU").ffill() + tm.assert_series_equal(result, expected) + + def test_resample_tz_localized(self, unit): + dr = date_range(start="2012-4-13", end="2012-5-1", unit=unit) + ts = Series(range(len(dr)), index=dr) + + ts_utc = ts.tz_localize("UTC") + ts_local = ts_utc.tz_convert("America/Los_Angeles") + + result = ts_local.resample("W").mean() + + ts_local_naive = ts_local.copy() + ts_local_naive.index = ts_local_naive.index.tz_localize(None) + + exp = ts_local_naive.resample("W").mean().tz_localize("America/Los_Angeles") + exp.index = pd.DatetimeIndex(exp.index, freq="W") + + tm.assert_series_equal(result, exp) + + # it works + result = ts_local.resample("D").mean() + + def test_resample_tz_localized2(self): + # #2245 + idx = date_range( + "2001-09-20 15:59", "2001-09-20 16:00", freq="min", tz="Australia/Sydney" + ) + s = Series([1, 2], index=idx) + + result = s.resample("D", closed="right", label="right").mean() + ex_index = date_range("2001-09-21", periods=1, freq="D", tz="Australia/Sydney") + expected = Series([1.5], index=ex_index) + + tm.assert_series_equal(result, expected) + + # for good measure + msg = "The 'kind' keyword in Series.resample is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = s.resample("D", kind="period").mean() + ex_index = period_range("2001-09-20", periods=1, freq="D") + expected = Series([1.5], index=ex_index) + tm.assert_series_equal(result, expected) + + def test_resample_tz_localized3(self): + # GH 6397 + # comparing an offset that doesn't propagate tz's + rng = date_range("1/1/2011", periods=20000, freq="h") + rng = rng.tz_localize("EST") + ts = DataFrame(index=rng) + ts["first"] = np.random.default_rng(2).standard_normal(len(rng)) + ts["second"] = np.cumsum(np.random.default_rng(2).standard_normal(len(rng))) + expected = DataFrame( + { + "first": ts.resample("YE").sum()["first"], + "second": ts.resample("YE").mean()["second"], + }, + columns=["first", "second"], + ) + result = ( + ts.resample("YE") + .agg({"first": "sum", "second": "mean"}) + .reindex(columns=["first", "second"]) + ) + tm.assert_frame_equal(result, expected) + + def test_closed_left_corner(self): + # #1465 + s = Series( + np.random.default_rng(2).standard_normal(21), + index=date_range(start="1/1/2012 9:30", freq="1min", periods=21), + ) + s.iloc[0] = np.nan + + result = s.resample("10min", closed="left", label="right").mean() + exp = s[1:].resample("10min", closed="left", label="right").mean() + tm.assert_series_equal(result, exp) + + result = s.resample("10min", closed="left", label="left").mean() + exp = s[1:].resample("10min", closed="left", label="left").mean() + + ex_index = date_range(start="1/1/2012 9:30", freq="10min", periods=3) + + tm.assert_index_equal(result.index, ex_index) + tm.assert_series_equal(result, exp) + + def test_quarterly_resampling(self): + rng = period_range("2000Q1", periods=10, freq="Q-DEC") + ts = Series(np.arange(10), index=rng) + + result = ts.resample("Y").mean() + exp = ts.to_timestamp().resample("YE").mean().to_period() + tm.assert_series_equal(result, exp) + + def test_resample_weekly_bug_1726(self): + # 8/6/12 is a Monday + ind = date_range(start="8/6/2012", end="8/26/2012", freq="D") + n = len(ind) + data = [[x] * 5 for x in range(n)] + df = DataFrame(data, columns=["open", "high", "low", "close", "vol"], index=ind) + + # it works! + df.resample("W-MON", closed="left", label="left").first() + + def test_resample_with_dst_time_change(self): + # GH 15549 + index = ( + pd.DatetimeIndex([1457537600000000000, 1458059600000000000]) + .tz_localize("UTC") + .tz_convert("America/Chicago") + ) + df = DataFrame([1, 2], index=index) + result = df.resample("12h", closed="right", label="right").last().ffill() + + expected_index_values = [ + "2016-03-09 12:00:00-06:00", + "2016-03-10 00:00:00-06:00", + "2016-03-10 12:00:00-06:00", + "2016-03-11 00:00:00-06:00", + "2016-03-11 12:00:00-06:00", + "2016-03-12 00:00:00-06:00", + "2016-03-12 12:00:00-06:00", + "2016-03-13 00:00:00-06:00", + "2016-03-13 13:00:00-05:00", + "2016-03-14 01:00:00-05:00", + "2016-03-14 13:00:00-05:00", + "2016-03-15 01:00:00-05:00", + "2016-03-15 13:00:00-05:00", + ] + index = ( + pd.to_datetime(expected_index_values, utc=True) + .tz_convert("America/Chicago") + .as_unit(index.unit) + ) + index = pd.DatetimeIndex(index, freq="12h") + expected = DataFrame( + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0], + index=index, + ) + tm.assert_frame_equal(result, expected) + + def test_resample_bms_2752(self): + # GH2753 + timeseries = Series( + index=pd.bdate_range("20000101", "20000201"), dtype=np.float64 + ) + res1 = timeseries.resample("BMS").mean() + res2 = timeseries.resample("BMS").mean().resample("B").mean() + assert res1.index[0] == Timestamp("20000103") + assert res1.index[0] == res2.index[0] + + @pytest.mark.xfail(reason="Commented out for more than 3 years. Should this work?") + def test_monthly_convention_span(self): + rng = period_range("2000-01", periods=3, freq="ME") + ts = Series(np.arange(3), index=rng) + + # hacky way to get same thing + exp_index = period_range("2000-01-01", "2000-03-31", freq="D") + expected = ts.asfreq("D", how="end").reindex(exp_index) + expected = expected.fillna(method="bfill") + + result = ts.resample("D").mean() + + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "from_freq, to_freq", [("D", "ME"), ("QE", "YE"), ("ME", "QE"), ("D", "W")] + ) + def test_default_right_closed_label(self, from_freq, to_freq): + idx = date_range(start="8/15/2012", periods=100, freq=from_freq) + df = DataFrame(np.random.default_rng(2).standard_normal((len(idx), 2)), idx) + + resampled = df.resample(to_freq).mean() + tm.assert_frame_equal( + resampled, df.resample(to_freq, closed="right", label="right").mean() + ) + + @pytest.mark.parametrize( + "from_freq, to_freq", + [("D", "MS"), ("QE", "YS"), ("ME", "QS"), ("h", "D"), ("min", "h")], + ) + def test_default_left_closed_label(self, from_freq, to_freq): + idx = date_range(start="8/15/2012", periods=100, freq=from_freq) + df = DataFrame(np.random.default_rng(2).standard_normal((len(idx), 2)), idx) + + resampled = df.resample(to_freq).mean() + tm.assert_frame_equal( + resampled, df.resample(to_freq, closed="left", label="left").mean() + ) + + def test_all_values_single_bin(self): + # GH#2070 + index = period_range(start="2012-01-01", end="2012-12-31", freq="M") + ser = Series(np.random.default_rng(2).standard_normal(len(index)), index=index) + + result = ser.resample("Y").mean() + tm.assert_almost_equal(result.iloc[0], ser.mean()) + + def test_evenly_divisible_with_no_extra_bins(self): + # GH#4076 + # when the frequency is evenly divisible, sometimes extra bins + + df = DataFrame( + np.random.default_rng(2).standard_normal((9, 3)), + index=date_range("2000-1-1", periods=9), + ) + result = df.resample("5D").mean() + expected = pd.concat([df.iloc[0:5].mean(), df.iloc[5:].mean()], axis=1).T + expected.index = pd.DatetimeIndex( + [Timestamp("2000-1-1"), Timestamp("2000-1-6")], dtype="M8[ns]", freq="5D" + ) + tm.assert_frame_equal(result, expected) + + def test_evenly_divisible_with_no_extra_bins2(self): + index = date_range(start="2001-5-4", periods=28) + df = DataFrame( + [ + { + "REST_KEY": 1, + "DLY_TRN_QT": 80, + "DLY_SLS_AMT": 90, + "COOP_DLY_TRN_QT": 30, + "COOP_DLY_SLS_AMT": 20, + } + ] + * 28 + + [ + { + "REST_KEY": 2, + "DLY_TRN_QT": 70, + "DLY_SLS_AMT": 10, + "COOP_DLY_TRN_QT": 50, + "COOP_DLY_SLS_AMT": 20, + } + ] + * 28, + index=index.append(index), + ).sort_index() + + index = date_range("2001-5-4", periods=4, freq="7D") + expected = DataFrame( + [ + { + "REST_KEY": 14, + "DLY_TRN_QT": 14, + "DLY_SLS_AMT": 14, + "COOP_DLY_TRN_QT": 14, + "COOP_DLY_SLS_AMT": 14, + } + ] + * 4, + index=index, + ) + result = df.resample("7D").count() + tm.assert_frame_equal(result, expected) + + expected = DataFrame( + [ + { + "REST_KEY": 21, + "DLY_TRN_QT": 1050, + "DLY_SLS_AMT": 700, + "COOP_DLY_TRN_QT": 560, + "COOP_DLY_SLS_AMT": 280, + } + ] + * 4, + index=index, + ) + result = df.resample("7D").sum() + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("freq, period_mult", [("h", 24), ("12h", 2)]) + @pytest.mark.parametrize("kind", [None, "period"]) + def test_upsampling_ohlc(self, freq, period_mult, kind): + # GH 13083 + pi = period_range(start="2000", freq="D", periods=10) + s = Series(range(len(pi)), index=pi) + expected = s.to_timestamp().resample(freq).ohlc().to_period(freq) + + # timestamp-based resampling doesn't include all sub-periods + # of the last original period, so extend accordingly: + new_index = period_range(start="2000", freq=freq, periods=period_mult * len(pi)) + expected = expected.reindex(new_index) + msg = "The 'kind' keyword in Series.resample is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = s.resample(freq, kind=kind).ohlc() + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "periods, values", + [ + ( + [ + pd.NaT, + "1970-01-01 00:00:00", + pd.NaT, + "1970-01-01 00:00:02", + "1970-01-01 00:00:03", + ], + [2, 3, 5, 7, 11], + ), + ( + [ + pd.NaT, + pd.NaT, + "1970-01-01 00:00:00", + pd.NaT, + pd.NaT, + pd.NaT, + "1970-01-01 00:00:02", + "1970-01-01 00:00:03", + pd.NaT, + pd.NaT, + ], + [1, 2, 3, 5, 6, 8, 7, 11, 12, 13], + ), + ], + ) + @pytest.mark.parametrize( + "freq, expected_values", + [ + ("1s", [3, np.nan, 7, 11]), + ("2s", [3, (7 + 11) / 2]), + ("3s", [(3 + 7) / 2, 11]), + ], + ) + def test_resample_with_nat(self, periods, values, freq, expected_values): + # GH 13224 + index = PeriodIndex(periods, freq="s") + frame = DataFrame(values, index=index) + + expected_index = period_range( + "1970-01-01 00:00:00", periods=len(expected_values), freq=freq + ) + expected = DataFrame(expected_values, index=expected_index) + msg = "Resampling with a PeriodIndex is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + rs = frame.resample(freq) + result = rs.mean() + tm.assert_frame_equal(result, expected) + + def test_resample_with_only_nat(self): + # GH 13224 + pi = PeriodIndex([pd.NaT] * 3, freq="s") + frame = DataFrame([2, 3, 5], index=pi, columns=["a"]) + expected_index = PeriodIndex(data=[], freq=pi.freq) + expected = DataFrame(index=expected_index, columns=["a"], dtype="float64") + result = frame.resample("1s").mean() + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "start,end,start_freq,end_freq,offset", + [ + ("19910905", "19910909 03:00", "h", "24h", "10h"), + ("19910905", "19910909 12:00", "h", "24h", "10h"), + ("19910905", "19910909 23:00", "h", "24h", "10h"), + ("19910905 10:00", "19910909", "h", "24h", "10h"), + ("19910905 10:00", "19910909 10:00", "h", "24h", "10h"), + ("19910905", "19910909 10:00", "h", "24h", "10h"), + ("19910905 12:00", "19910909", "h", "24h", "10h"), + ("19910905 12:00", "19910909 03:00", "h", "24h", "10h"), + ("19910905 12:00", "19910909 12:00", "h", "24h", "10h"), + ("19910905 12:00", "19910909 12:00", "h", "24h", "34h"), + ("19910905 12:00", "19910909 12:00", "h", "17h", "10h"), + ("19910905 12:00", "19910909 12:00", "h", "17h", "3h"), + ("19910905", "19910913 06:00", "2h", "24h", "10h"), + ("19910905", "19910905 01:39", "Min", "5Min", "3Min"), + ("19910905", "19910905 03:18", "2Min", "5Min", "3Min"), + ], + ) + def test_resample_with_offset(self, start, end, start_freq, end_freq, offset): + # GH 23882 & 31809 + pi = period_range(start, end, freq=start_freq) + ser = Series(np.arange(len(pi)), index=pi) + msg = "Resampling with a PeriodIndex is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + rs = ser.resample(end_freq, offset=offset) + result = rs.mean() + result = result.to_timestamp(end_freq) + + expected = ser.to_timestamp().resample(end_freq, offset=offset).mean() + tm.assert_series_equal(result, expected) + + def test_resample_with_offset_month(self): + # GH 23882 & 31809 + pi = period_range("19910905 12:00", "19910909 1:00", freq="h") + ser = Series(np.arange(len(pi)), index=pi) + msg = "Resampling with a PeriodIndex is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + rs = ser.resample("M", offset="3h") + result = rs.mean() + result = result.to_timestamp("M") + expected = ser.to_timestamp().resample("ME", offset="3h").mean() + # TODO: is non-tick the relevant characteristic? (GH 33815) + expected.index = expected.index._with_freq(None) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "first,last,freq,freq_to_offset,exp_first,exp_last", + [ + ("19910905", "19920406", "D", "D", "19910905", "19920406"), + ("19910905 00:00", "19920406 06:00", "D", "D", "19910905", "19920406"), + ( + "19910905 06:00", + "19920406 06:00", + "h", + "h", + "19910905 06:00", + "19920406 06:00", + ), + ("19910906", "19920406", "M", "ME", "1991-09", "1992-04"), + ("19910831", "19920430", "M", "ME", "1991-08", "1992-04"), + ("1991-08", "1992-04", "M", "ME", "1991-08", "1992-04"), + ], + ) + def test_get_period_range_edges( + self, first, last, freq, freq_to_offset, exp_first, exp_last + ): + first = Period(first) + last = Period(last) + + exp_first = Period(exp_first, freq=freq) + exp_last = Period(exp_last, freq=freq) + + freq = pd.tseries.frequencies.to_offset(freq_to_offset) + result = _get_period_range_edges(first, last, freq) + expected = (exp_first, exp_last) + assert result == expected + + def test_sum_min_count(self): + # GH 19974 + index = date_range(start="2018", freq="ME", periods=6) + data = np.ones(6) + data[3:6] = np.nan + s = Series(data, index).to_period() + msg = "Resampling with a PeriodIndex is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + rs = s.resample("Q") + result = rs.sum(min_count=1) + expected = Series( + [3.0, np.nan], index=PeriodIndex(["2018Q1", "2018Q2"], freq="Q-DEC") + ) + tm.assert_series_equal(result, expected) + + def test_resample_t_l_deprecated(self): + # GH#52536 + msg_t = "'T' is deprecated and will be removed in a future version." + msg_l = "'L' is deprecated and will be removed in a future version." + + with tm.assert_produces_warning(FutureWarning, match=msg_l): + rng_l = period_range( + "2020-01-01 00:00:00 00:00", "2020-01-01 00:00:00 00:01", freq="L" + ) + ser = Series(np.arange(len(rng_l)), index=rng_l) + + rng = period_range( + "2020-01-01 00:00:00 00:00", "2020-01-01 00:00:00 00:01", freq="min" + ) + expected = Series([29999.5, 60000.0], index=rng) + with tm.assert_produces_warning(FutureWarning, match=msg_t): + result = ser.resample("T").mean() + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "freq, freq_depr, freq_res, freq_depr_res, data", + [ + ("2Q", "2q", "2Y", "2y", [0.5]), + ("2M", "2m", "2Q", "2q", [1.0, 3.0]), + ], + ) + def test_resample_lowercase_frequency_deprecated( + self, freq, freq_depr, freq_res, freq_depr_res, data + ): + depr_msg = f"'{freq_depr[1:]}' is deprecated and will be removed in a " + f"future version. Please use '{freq[1:]}' instead." + depr_msg_res = f"'{freq_depr_res[1:]}' is deprecated and will be removed in a " + f"future version. Please use '{freq_res[1:]}' instead." + + with tm.assert_produces_warning(FutureWarning, match=depr_msg): + rng_l = period_range("2020-01-01", "2020-08-01", freq=freq_depr) + ser = Series(np.arange(len(rng_l)), index=rng_l) + + rng = period_range("2020-01-01", "2020-08-01", freq=freq_res) + expected = Series(data=data, index=rng) + + with tm.assert_produces_warning(FutureWarning, match=depr_msg_res): + result = ser.resample(freq_depr_res).mean() + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "offset", + [ + offsets.MonthBegin(), + offsets.BYearBegin(2), + offsets.BusinessHour(2), + ], + ) + def test_asfreq_invalid_period_offset(self, offset, series_and_frame): + # GH#55785 + msg = f"Invalid offset: '{offset.base}' for converting time series " + + df = series_and_frame + with pytest.raises(ValueError, match=msg): + df.asfreq(freq=offset) + + +@pytest.mark.parametrize( + "freq,freq_depr", + [ + ("2M", "2ME"), + ("2Q", "2QE"), + ("2Q-FEB", "2QE-FEB"), + ("2Y", "2YE"), + ("2Y-MAR", "2YE-MAR"), + ("2M", "2me"), + ("2Q", "2qe"), + ("2Y-MAR", "2ye-mar"), + ], +) +def test_resample_frequency_ME_QE_YE_error_message(series_and_frame, freq, freq_depr): + # GH#9586 + msg = f"for Period, please use '{freq[1:]}' instead of '{freq_depr[1:]}'" + + obj = series_and_frame + with pytest.raises(ValueError, match=msg): + obj.resample(freq_depr) + + +def test_corner_cases_period(simple_period_range_series): + # miscellaneous test coverage + len0pts = simple_period_range_series("2007-01", "2010-05", freq="M")[:0] + # it works + msg = "Resampling with a PeriodIndex is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = len0pts.resample("Y-DEC").mean() + assert len(result) == 0 + + +@pytest.mark.parametrize( + "freq_depr", + [ + "2BME", + "2CBME", + "2SME", + "2BQE-FEB", + "2BYE-MAR", + ], +) +def test_resample_frequency_invalid_freq(series_and_frame, freq_depr): + # GH#9586 + msg = f"Invalid frequency: {freq_depr[1:]}" + + obj = series_and_frame + with pytest.raises(ValueError, match=msg): + obj.resample(freq_depr) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/resample/test_resample_api.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/resample/test_resample_api.py new file mode 100644 index 0000000000000000000000000000000000000000..12abd1c98784bf2d080f2eb2a26cc00e8fbf4b5f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/resample/test_resample_api.py @@ -0,0 +1,1099 @@ +from datetime import datetime +import re + +import numpy as np +import pytest + +from pandas._libs import lib +from pandas.errors import UnsupportedFunctionCall + +import pandas as pd +from pandas import ( + DataFrame, + NamedAgg, + Series, +) +import pandas._testing as tm +from pandas.core.indexes.datetimes import date_range + + +@pytest.fixture +def dti(): + return date_range(start=datetime(2005, 1, 1), end=datetime(2005, 1, 10), freq="Min") + + +@pytest.fixture +def _test_series(dti): + return Series(np.random.default_rng(2).random(len(dti)), dti) + + +@pytest.fixture +def test_frame(dti, _test_series): + return DataFrame({"A": _test_series, "B": _test_series, "C": np.arange(len(dti))}) + + +def test_str(_test_series): + r = _test_series.resample("h") + assert ( + "DatetimeIndexResampler [freq=, axis=0, closed=left, " + "label=left, convention=start, origin=start_day]" in str(r) + ) + + r = _test_series.resample("h", origin="2000-01-01") + assert ( + "DatetimeIndexResampler [freq=, axis=0, closed=left, " + "label=left, convention=start, origin=2000-01-01 00:00:00]" in str(r) + ) + + +def test_api(_test_series): + r = _test_series.resample("h") + result = r.mean() + assert isinstance(result, Series) + assert len(result) == 217 + + r = _test_series.to_frame().resample("h") + result = r.mean() + assert isinstance(result, DataFrame) + assert len(result) == 217 + + +def test_groupby_resample_api(): + # GH 12448 + # .groupby(...).resample(...) hitting warnings + # when appropriate + df = DataFrame( + { + "date": date_range(start="2016-01-01", periods=4, freq="W"), + "group": [1, 1, 2, 2], + "val": [5, 6, 7, 8], + } + ).set_index("date") + + # replication step + i = ( + date_range("2016-01-03", periods=8).tolist() + + date_range("2016-01-17", periods=8).tolist() + ) + index = pd.MultiIndex.from_arrays([[1] * 8 + [2] * 8, i], names=["group", "date"]) + expected = DataFrame({"val": [5] * 7 + [6] + [7] * 7 + [8]}, index=index) + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = df.groupby("group").apply(lambda x: x.resample("1D").ffill())[["val"]] + tm.assert_frame_equal(result, expected) + + +def test_groupby_resample_on_api(): + # GH 15021 + # .groupby(...).resample(on=...) results in an unexpected + # keyword warning. + df = DataFrame( + { + "key": ["A", "B"] * 5, + "dates": date_range("2016-01-01", periods=10), + "values": np.random.default_rng(2).standard_normal(10), + } + ) + + expected = df.set_index("dates").groupby("key").resample("D").mean() + result = df.groupby("key").resample("D", on="dates").mean() + tm.assert_frame_equal(result, expected) + + +def test_resample_group_keys(): + df = DataFrame({"A": 1, "B": 2}, index=date_range("2000", periods=10)) + expected = df.copy() + + # group_keys=False + g = df.resample("5D", group_keys=False) + result = g.apply(lambda x: x) + tm.assert_frame_equal(result, expected) + + # group_keys defaults to False + g = df.resample("5D") + result = g.apply(lambda x: x) + tm.assert_frame_equal(result, expected) + + # group_keys=True + expected.index = pd.MultiIndex.from_arrays( + [ + pd.to_datetime(["2000-01-01", "2000-01-06"]).as_unit("ns").repeat(5), + expected.index, + ] + ) + g = df.resample("5D", group_keys=True) + result = g.apply(lambda x: x) + tm.assert_frame_equal(result, expected) + + +def test_pipe(test_frame, _test_series): + # GH17905 + + # series + r = _test_series.resample("h") + expected = r.max() - r.mean() + result = r.pipe(lambda x: x.max() - x.mean()) + tm.assert_series_equal(result, expected) + + # dataframe + r = test_frame.resample("h") + expected = r.max() - r.mean() + result = r.pipe(lambda x: x.max() - x.mean()) + tm.assert_frame_equal(result, expected) + + +def test_getitem(test_frame): + r = test_frame.resample("h") + tm.assert_index_equal(r._selected_obj.columns, test_frame.columns) + + r = test_frame.resample("h")["B"] + assert r._selected_obj.name == test_frame.columns[1] + + # technically this is allowed + r = test_frame.resample("h")["A", "B"] + tm.assert_index_equal(r._selected_obj.columns, test_frame.columns[[0, 1]]) + + r = test_frame.resample("h")["A", "B"] + tm.assert_index_equal(r._selected_obj.columns, test_frame.columns[[0, 1]]) + + +@pytest.mark.parametrize("key", [["D"], ["A", "D"]]) +def test_select_bad_cols(key, test_frame): + g = test_frame.resample("h") + # 'A' should not be referenced as a bad column... + # will have to rethink regex if you change message! + msg = r"^\"Columns not found: 'D'\"$" + with pytest.raises(KeyError, match=msg): + g[key] + + +def test_attribute_access(test_frame): + r = test_frame.resample("h") + tm.assert_series_equal(r.A.sum(), r["A"].sum()) + + +@pytest.mark.parametrize("attr", ["groups", "ngroups", "indices"]) +def test_api_compat_before_use(attr): + # make sure that we are setting the binner + # on these attributes + rng = date_range("1/1/2012", periods=100, freq="s") + ts = Series(np.arange(len(rng)), index=rng) + rs = ts.resample("30s") + + # before use + getattr(rs, attr) + + # after grouper is initialized is ok + rs.mean() + getattr(rs, attr) + + +def tests_raises_on_nuisance(test_frame): + df = test_frame + df["D"] = "foo" + r = df.resample("h") + result = r[["A", "B"]].mean() + expected = pd.concat([r.A.mean(), r.B.mean()], axis=1) + tm.assert_frame_equal(result, expected) + + expected = r[["A", "B", "C"]].mean() + msg = re.escape("agg function failed [how->mean,dtype->") + with pytest.raises(TypeError, match=msg): + r.mean() + result = r.mean(numeric_only=True) + tm.assert_frame_equal(result, expected) + + +def test_downsample_but_actually_upsampling(): + # this is reindex / asfreq + rng = date_range("1/1/2012", periods=100, freq="s") + ts = Series(np.arange(len(rng), dtype="int64"), index=rng) + result = ts.resample("20s").asfreq() + expected = Series( + [0, 20, 40, 60, 80], + index=date_range("2012-01-01 00:00:00", freq="20s", periods=5), + ) + tm.assert_series_equal(result, expected) + + +def test_combined_up_downsampling_of_irregular(): + # since we are really doing an operation like this + # ts2.resample('2s').mean().ffill() + # preserve these semantics + + rng = date_range("1/1/2012", periods=100, freq="s") + ts = Series(np.arange(len(rng)), index=rng) + ts2 = ts.iloc[[0, 1, 2, 3, 5, 7, 11, 15, 16, 25, 30]] + + result = ts2.resample("2s").mean().ffill() + expected = Series( + [ + 0.5, + 2.5, + 5.0, + 7.0, + 7.0, + 11.0, + 11.0, + 15.0, + 16.0, + 16.0, + 16.0, + 16.0, + 25.0, + 25.0, + 25.0, + 30.0, + ], + index=pd.DatetimeIndex( + [ + "2012-01-01 00:00:00", + "2012-01-01 00:00:02", + "2012-01-01 00:00:04", + "2012-01-01 00:00:06", + "2012-01-01 00:00:08", + "2012-01-01 00:00:10", + "2012-01-01 00:00:12", + "2012-01-01 00:00:14", + "2012-01-01 00:00:16", + "2012-01-01 00:00:18", + "2012-01-01 00:00:20", + "2012-01-01 00:00:22", + "2012-01-01 00:00:24", + "2012-01-01 00:00:26", + "2012-01-01 00:00:28", + "2012-01-01 00:00:30", + ], + dtype="datetime64[ns]", + freq="2s", + ), + ) + tm.assert_series_equal(result, expected) + + +def test_transform_series(_test_series): + r = _test_series.resample("20min") + expected = _test_series.groupby(pd.Grouper(freq="20min")).transform("mean") + result = r.transform("mean") + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("on", [None, "date"]) +def test_transform_frame(on): + # GH#47079 + index = date_range(datetime(2005, 1, 1), datetime(2005, 1, 10), freq="D") + index.name = "date" + df = DataFrame( + np.random.default_rng(2).random((10, 2)), columns=list("AB"), index=index + ) + expected = df.groupby(pd.Grouper(freq="20min")).transform("mean") + if on == "date": + # Move date to being a column; result will then have a RangeIndex + expected = expected.reset_index(drop=True) + df = df.reset_index() + + r = df.resample("20min", on=on) + result = r.transform("mean") + tm.assert_frame_equal(result, expected) + + +def test_fillna(): + # need to upsample here + rng = date_range("1/1/2012", periods=10, freq="2s") + ts = Series(np.arange(len(rng), dtype="int64"), index=rng) + r = ts.resample("s") + + expected = r.ffill() + msg = "DatetimeIndexResampler.fillna is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = r.fillna(method="ffill") + tm.assert_series_equal(result, expected) + + expected = r.bfill() + with tm.assert_produces_warning(FutureWarning, match=msg): + result = r.fillna(method="bfill") + tm.assert_series_equal(result, expected) + + msg2 = ( + r"Invalid fill method\. Expecting pad \(ffill\), backfill " + r"\(bfill\) or nearest\. Got 0" + ) + with pytest.raises(ValueError, match=msg2): + with tm.assert_produces_warning(FutureWarning, match=msg): + r.fillna(0) + + +@pytest.mark.parametrize( + "func", + [ + lambda x: x.resample("20min", group_keys=False), + lambda x: x.groupby(pd.Grouper(freq="20min"), group_keys=False), + ], + ids=["resample", "groupby"], +) +def test_apply_without_aggregation(func, _test_series): + # both resample and groupby should work w/o aggregation + t = func(_test_series) + result = t.apply(lambda x: x) + tm.assert_series_equal(result, _test_series) + + +def test_apply_without_aggregation2(_test_series): + grouped = _test_series.to_frame(name="foo").resample("20min", group_keys=False) + result = grouped["foo"].apply(lambda x: x) + tm.assert_series_equal(result, _test_series.rename("foo")) + + +def test_agg_consistency(): + # make sure that we are consistent across + # similar aggregations with and w/o selection list + df = DataFrame( + np.random.default_rng(2).standard_normal((1000, 3)), + index=date_range("1/1/2012", freq="s", periods=1000), + columns=["A", "B", "C"], + ) + + r = df.resample("3min") + + msg = r"Column\(s\) \['r1', 'r2'\] do not exist" + with pytest.raises(KeyError, match=msg): + r.agg({"r1": "mean", "r2": "sum"}) + + +def test_agg_consistency_int_str_column_mix(): + # GH#39025 + df = DataFrame( + np.random.default_rng(2).standard_normal((1000, 2)), + index=date_range("1/1/2012", freq="s", periods=1000), + columns=[1, "a"], + ) + + r = df.resample("3min") + + msg = r"Column\(s\) \[2, 'b'\] do not exist" + with pytest.raises(KeyError, match=msg): + r.agg({2: "mean", "b": "sum"}) + + +# TODO(GH#14008): once GH 14008 is fixed, move these tests into +# `Base` test class + + +@pytest.fixture +def index(): + index = date_range(datetime(2005, 1, 1), datetime(2005, 1, 10), freq="D") + index.name = "date" + return index + + +@pytest.fixture +def df(index): + frame = DataFrame( + np.random.default_rng(2).random((10, 2)), columns=list("AB"), index=index + ) + return frame + + +@pytest.fixture +def df_col(df): + return df.reset_index() + + +@pytest.fixture +def df_mult(df_col, index): + df_mult = df_col.copy() + df_mult.index = pd.MultiIndex.from_arrays( + [range(10), index], names=["index", "date"] + ) + return df_mult + + +@pytest.fixture +def a_mean(df): + return df.resample("2D")["A"].mean() + + +@pytest.fixture +def a_std(df): + return df.resample("2D")["A"].std() + + +@pytest.fixture +def a_sum(df): + return df.resample("2D")["A"].sum() + + +@pytest.fixture +def b_mean(df): + return df.resample("2D")["B"].mean() + + +@pytest.fixture +def b_std(df): + return df.resample("2D")["B"].std() + + +@pytest.fixture +def b_sum(df): + return df.resample("2D")["B"].sum() + + +@pytest.fixture +def df_resample(df): + return df.resample("2D") + + +@pytest.fixture +def df_col_resample(df_col): + return df_col.resample("2D", on="date") + + +@pytest.fixture +def df_mult_resample(df_mult): + return df_mult.resample("2D", level="date") + + +@pytest.fixture +def df_grouper_resample(df): + return df.groupby(pd.Grouper(freq="2D")) + + +@pytest.fixture( + params=["df_resample", "df_col_resample", "df_mult_resample", "df_grouper_resample"] +) +def cases(request): + return request.getfixturevalue(request.param) + + +def test_agg_mixed_column_aggregation(cases, a_mean, a_std, b_mean, b_std, request): + expected = pd.concat([a_mean, a_std, b_mean, b_std], axis=1) + expected.columns = pd.MultiIndex.from_product([["A", "B"], ["mean", "std"]]) + msg = "using SeriesGroupBy.[mean|std]" + # "date" is an index and a column, so get included in the agg + if "df_mult" in request.node.callspec.id: + date_mean = cases["date"].mean() + date_std = cases["date"].std() + expected = pd.concat([date_mean, date_std, expected], axis=1) + expected.columns = pd.MultiIndex.from_product( + [["date", "A", "B"], ["mean", "std"]] + ) + with tm.assert_produces_warning(FutureWarning, match=msg): + result = cases.aggregate([np.mean, np.std]) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "agg", + [ + {"func": {"A": np.mean, "B": np.std}}, + {"A": ("A", np.mean), "B": ("B", np.std)}, + {"A": NamedAgg("A", np.mean), "B": NamedAgg("B", np.std)}, + ], +) +def test_agg_both_mean_std_named_result(cases, a_mean, b_std, agg): + msg = "using SeriesGroupBy.[mean|std]" + expected = pd.concat([a_mean, b_std], axis=1) + with tm.assert_produces_warning(FutureWarning, match=msg): + result = cases.aggregate(**agg) + tm.assert_frame_equal(result, expected, check_like=True) + + +def test_agg_both_mean_std_dict_of_list(cases, a_mean, a_std): + expected = pd.concat([a_mean, a_std], axis=1) + expected.columns = pd.MultiIndex.from_tuples([("A", "mean"), ("A", "std")]) + result = cases.aggregate({"A": ["mean", "std"]}) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "agg", [{"func": ["mean", "sum"]}, {"mean": "mean", "sum": "sum"}] +) +def test_agg_both_mean_sum(cases, a_mean, a_sum, agg): + expected = pd.concat([a_mean, a_sum], axis=1) + expected.columns = ["mean", "sum"] + result = cases["A"].aggregate(**agg) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "agg", + [ + {"A": {"mean": "mean", "sum": "sum"}}, + { + "A": {"mean": "mean", "sum": "sum"}, + "B": {"mean2": "mean", "sum2": "sum"}, + }, + ], +) +def test_agg_dict_of_dict_specificationerror(cases, agg): + msg = "nested renamer is not supported" + with pytest.raises(pd.errors.SpecificationError, match=msg): + cases.aggregate(agg) + + +def test_agg_dict_of_lists(cases, a_mean, a_std, b_mean, b_std): + expected = pd.concat([a_mean, a_std, b_mean, b_std], axis=1) + expected.columns = pd.MultiIndex.from_tuples( + [("A", "mean"), ("A", "std"), ("B", "mean"), ("B", "std")] + ) + result = cases.aggregate({"A": ["mean", "std"], "B": ["mean", "std"]}) + tm.assert_frame_equal(result, expected, check_like=True) + + +@pytest.mark.parametrize( + "agg", + [ + {"func": {"A": np.sum, "B": lambda x: np.std(x, ddof=1)}}, + {"A": ("A", np.sum), "B": ("B", lambda x: np.std(x, ddof=1))}, + {"A": NamedAgg("A", np.sum), "B": NamedAgg("B", lambda x: np.std(x, ddof=1))}, + ], +) +def test_agg_with_lambda(cases, agg): + # passed lambda + msg = "using SeriesGroupBy.sum" + rcustom = cases["B"].apply(lambda x: np.std(x, ddof=1)) + expected = pd.concat([cases["A"].sum(), rcustom], axis=1) + with tm.assert_produces_warning(FutureWarning, match=msg): + result = cases.agg(**agg) + tm.assert_frame_equal(result, expected, check_like=True) + + +@pytest.mark.parametrize( + "agg", + [ + {"func": {"result1": np.sum, "result2": np.mean}}, + {"A": ("result1", np.sum), "B": ("result2", np.mean)}, + {"A": NamedAgg("result1", np.sum), "B": NamedAgg("result2", np.mean)}, + ], +) +def test_agg_no_column(cases, agg): + msg = r"Column\(s\) \['result1', 'result2'\] do not exist" + with pytest.raises(KeyError, match=msg): + cases[["A", "B"]].agg(**agg) + + +@pytest.mark.parametrize( + "cols, agg", + [ + [None, {"A": ["sum", "std"], "B": ["mean", "std"]}], + [ + [ + "A", + "B", + ], + {"A": ["sum", "std"], "B": ["mean", "std"]}, + ], + ], +) +def test_agg_specificationerror_nested(cases, cols, agg, a_sum, a_std, b_mean, b_std): + # agg with different hows + # equivalent of using a selection list / or not + expected = pd.concat([a_sum, a_std, b_mean, b_std], axis=1) + expected.columns = pd.MultiIndex.from_tuples( + [("A", "sum"), ("A", "std"), ("B", "mean"), ("B", "std")] + ) + if cols is not None: + obj = cases[cols] + else: + obj = cases + + result = obj.agg(agg) + tm.assert_frame_equal(result, expected, check_like=True) + + +@pytest.mark.parametrize( + "agg", [{"A": ["sum", "std"]}, {"A": ["sum", "std"], "B": ["mean", "std"]}] +) +def test_agg_specificationerror_series(cases, agg): + msg = "nested renamer is not supported" + + # series like aggs + with pytest.raises(pd.errors.SpecificationError, match=msg): + cases["A"].agg(agg) + + +def test_agg_specificationerror_invalid_names(cases): + # errors + # invalid names in the agg specification + msg = r"Column\(s\) \['B'\] do not exist" + with pytest.raises(KeyError, match=msg): + cases[["A"]].agg({"A": ["sum", "std"], "B": ["mean", "std"]}) + + +@pytest.mark.parametrize( + "func", [["min"], ["mean", "max"], {"A": "sum"}, {"A": "prod", "B": "median"}] +) +def test_multi_agg_axis_1_raises(func): + # GH#46904 + + index = date_range(datetime(2005, 1, 1), datetime(2005, 1, 10), freq="D") + index.name = "date" + df = DataFrame( + np.random.default_rng(2).random((10, 2)), columns=list("AB"), index=index + ).T + warning_msg = "DataFrame.resample with axis=1 is deprecated." + with tm.assert_produces_warning(FutureWarning, match=warning_msg): + res = df.resample("ME", axis=1) + with pytest.raises( + NotImplementedError, match="axis other than 0 is not supported" + ): + res.agg(func) + + +def test_agg_nested_dicts(): + index = date_range(datetime(2005, 1, 1), datetime(2005, 1, 10), freq="D") + index.name = "date" + df = DataFrame( + np.random.default_rng(2).random((10, 2)), columns=list("AB"), index=index + ) + df_col = df.reset_index() + df_mult = df_col.copy() + df_mult.index = pd.MultiIndex.from_arrays( + [range(10), df.index], names=["index", "date"] + ) + r = df.resample("2D") + cases = [ + r, + df_col.resample("2D", on="date"), + df_mult.resample("2D", level="date"), + df.groupby(pd.Grouper(freq="2D")), + ] + + msg = "nested renamer is not supported" + for t in cases: + with pytest.raises(pd.errors.SpecificationError, match=msg): + t.aggregate({"r1": {"A": ["mean", "sum"]}, "r2": {"B": ["mean", "sum"]}}) + + for t in cases: + with pytest.raises(pd.errors.SpecificationError, match=msg): + t[["A", "B"]].agg( + {"A": {"ra": ["mean", "std"]}, "B": {"rb": ["mean", "std"]}} + ) + + with pytest.raises(pd.errors.SpecificationError, match=msg): + t.agg({"A": {"ra": ["mean", "std"]}, "B": {"rb": ["mean", "std"]}}) + + +def test_try_aggregate_non_existing_column(): + # GH 16766 + data = [ + {"dt": datetime(2017, 6, 1, 0), "x": 1.0, "y": 2.0}, + {"dt": datetime(2017, 6, 1, 1), "x": 2.0, "y": 2.0}, + {"dt": datetime(2017, 6, 1, 2), "x": 3.0, "y": 1.5}, + ] + df = DataFrame(data).set_index("dt") + + # Error as we don't have 'z' column + msg = r"Column\(s\) \['z'\] do not exist" + with pytest.raises(KeyError, match=msg): + df.resample("30min").agg({"x": ["mean"], "y": ["median"], "z": ["sum"]}) + + +def test_agg_list_like_func_with_args(): + # 50624 + df = DataFrame( + {"x": [1, 2, 3]}, index=date_range("2020-01-01", periods=3, freq="D") + ) + + def foo1(x, a=1, c=0): + return x + a + c + + def foo2(x, b=2, c=0): + return x + b + c + + msg = r"foo1\(\) got an unexpected keyword argument 'b'" + with pytest.raises(TypeError, match=msg): + df.resample("D").agg([foo1, foo2], 3, b=3, c=4) + + result = df.resample("D").agg([foo1, foo2], 3, c=4) + expected = DataFrame( + [[8, 8], [9, 9], [10, 10]], + index=date_range("2020-01-01", periods=3, freq="D"), + columns=pd.MultiIndex.from_tuples([("x", "foo1"), ("x", "foo2")]), + ) + tm.assert_frame_equal(result, expected) + + +def test_selection_api_validation(): + # GH 13500 + index = date_range(datetime(2005, 1, 1), datetime(2005, 1, 10), freq="D") + + rng = np.arange(len(index), dtype=np.int64) + df = DataFrame( + {"date": index, "a": rng}, + index=pd.MultiIndex.from_arrays([rng, index], names=["v", "d"]), + ) + df_exp = DataFrame({"a": rng}, index=index) + + # non DatetimeIndex + msg = ( + "Only valid with DatetimeIndex, TimedeltaIndex or PeriodIndex, " + "but got an instance of 'Index'" + ) + with pytest.raises(TypeError, match=msg): + df.resample("2D", level="v") + + msg = "The Grouper cannot specify both a key and a level!" + with pytest.raises(ValueError, match=msg): + df.resample("2D", on="date", level="d") + + msg = "unhashable type: 'list'" + with pytest.raises(TypeError, match=msg): + df.resample("2D", on=["a", "date"]) + + msg = r"\"Level \['a', 'date'\] not found\"" + with pytest.raises(KeyError, match=msg): + df.resample("2D", level=["a", "date"]) + + # upsampling not allowed + msg = ( + "Upsampling from level= or on= selection is not supported, use " + r"\.set_index\(\.\.\.\) to explicitly set index to datetime-like" + ) + with pytest.raises(ValueError, match=msg): + df.resample("2D", level="d").asfreq() + with pytest.raises(ValueError, match=msg): + df.resample("2D", on="date").asfreq() + + exp = df_exp.resample("2D").sum() + exp.index.name = "date" + result = df.resample("2D", on="date").sum() + tm.assert_frame_equal(exp, result) + + exp.index.name = "d" + with pytest.raises(TypeError, match="datetime64 type does not support sum"): + df.resample("2D", level="d").sum() + result = df.resample("2D", level="d").sum(numeric_only=True) + tm.assert_frame_equal(exp, result) + + +@pytest.mark.parametrize( + "col_name", ["t2", "t2x", "t2q", "T_2M", "t2p", "t2m", "t2m1", "T2M"] +) +def test_agg_with_datetime_index_list_agg_func(col_name): + # GH 22660 + # The parametrized column names would get converted to dates by our + # date parser. Some would result in OutOfBoundsError (ValueError) while + # others would result in OverflowError when passed into Timestamp. + # We catch these errors and move on to the correct branch. + df = DataFrame( + list(range(200)), + index=date_range( + start="2017-01-01", freq="15min", periods=200, tz="Europe/Berlin" + ), + columns=[col_name], + ) + result = df.resample("1d").aggregate(["mean"]) + expected = DataFrame( + [47.5, 143.5, 195.5], + index=date_range(start="2017-01-01", freq="D", periods=3, tz="Europe/Berlin"), + columns=pd.MultiIndex(levels=[[col_name], ["mean"]], codes=[[0], [0]]), + ) + tm.assert_frame_equal(result, expected) + + +def test_resample_agg_readonly(): + # GH#31710 cython needs to allow readonly data + index = date_range("2020-01-01", "2020-01-02", freq="1h") + arr = np.zeros_like(index) + arr.setflags(write=False) + + ser = Series(arr, index=index) + rs = ser.resample("1D") + + expected = Series([pd.Timestamp(0), pd.Timestamp(0)], index=index[::24]) + + result = rs.agg("last") + tm.assert_series_equal(result, expected) + + result = rs.agg("first") + tm.assert_series_equal(result, expected) + + result = rs.agg("max") + tm.assert_series_equal(result, expected) + + result = rs.agg("min") + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "start,end,freq,data,resample_freq,origin,closed,exp_data,exp_end,exp_periods", + [ + ( + "2000-10-01 23:30:00", + "2000-10-02 00:26:00", + "7min", + [0, 3, 6, 9, 12, 15, 18, 21, 24], + "17min", + "end", + None, + [0, 18, 27, 63], + "20001002 00:26:00", + 4, + ), + ( + "20200101 8:26:35", + "20200101 9:31:58", + "77s", + [1] * 51, + "7min", + "end", + "right", + [1, 6, 5, 6, 5, 6, 5, 6, 5, 6], + "2020-01-01 09:30:45", + 10, + ), + ( + "2000-10-01 23:30:00", + "2000-10-02 00:26:00", + "7min", + [0, 3, 6, 9, 12, 15, 18, 21, 24], + "17min", + "end", + "left", + [0, 18, 27, 39, 24], + "20001002 00:43:00", + 5, + ), + ( + "2000-10-01 23:30:00", + "2000-10-02 00:26:00", + "7min", + [0, 3, 6, 9, 12, 15, 18, 21, 24], + "17min", + "end_day", + None, + [3, 15, 45, 45], + "2000-10-02 00:29:00", + 4, + ), + ], +) +def test_end_and_end_day_origin( + start, + end, + freq, + data, + resample_freq, + origin, + closed, + exp_data, + exp_end, + exp_periods, +): + rng = date_range(start, end, freq=freq) + ts = Series(data, index=rng) + + res = ts.resample(resample_freq, origin=origin, closed=closed).sum() + expected = Series( + exp_data, + index=date_range(end=exp_end, freq=resample_freq, periods=exp_periods), + ) + + tm.assert_series_equal(res, expected) + + +@pytest.mark.parametrize( + # expected_data is a string when op raises a ValueError + "method, numeric_only, expected_data", + [ + ("sum", True, {"num": [25]}), + ("sum", False, {"cat": ["cat_1cat_2"], "num": [25]}), + ("sum", lib.no_default, {"cat": ["cat_1cat_2"], "num": [25]}), + ("prod", True, {"num": [100]}), + ("prod", False, "can't multiply sequence"), + ("prod", lib.no_default, "can't multiply sequence"), + ("min", True, {"num": [5]}), + ("min", False, {"cat": ["cat_1"], "num": [5]}), + ("min", lib.no_default, {"cat": ["cat_1"], "num": [5]}), + ("max", True, {"num": [20]}), + ("max", False, {"cat": ["cat_2"], "num": [20]}), + ("max", lib.no_default, {"cat": ["cat_2"], "num": [20]}), + ("first", True, {"num": [5]}), + ("first", False, {"cat": ["cat_1"], "num": [5]}), + ("first", lib.no_default, {"cat": ["cat_1"], "num": [5]}), + ("last", True, {"num": [20]}), + ("last", False, {"cat": ["cat_2"], "num": [20]}), + ("last", lib.no_default, {"cat": ["cat_2"], "num": [20]}), + ("mean", True, {"num": [12.5]}), + ("mean", False, "Could not convert"), + ("mean", lib.no_default, "Could not convert"), + ("median", True, {"num": [12.5]}), + ("median", False, r"Cannot convert \['cat_1' 'cat_2'\] to numeric"), + ("median", lib.no_default, r"Cannot convert \['cat_1' 'cat_2'\] to numeric"), + ("std", True, {"num": [10.606601717798213]}), + ("std", False, "could not convert string to float"), + ("std", lib.no_default, "could not convert string to float"), + ("var", True, {"num": [112.5]}), + ("var", False, "could not convert string to float"), + ("var", lib.no_default, "could not convert string to float"), + ("sem", True, {"num": [7.5]}), + ("sem", False, "could not convert string to float"), + ("sem", lib.no_default, "could not convert string to float"), + ], +) +def test_frame_downsample_method(method, numeric_only, expected_data): + # GH#46442 test if `numeric_only` behave as expected for DataFrameGroupBy + + index = date_range("2018-01-01", periods=2, freq="D") + expected_index = date_range("2018-12-31", periods=1, freq="YE") + df = DataFrame({"cat": ["cat_1", "cat_2"], "num": [5, 20]}, index=index) + resampled = df.resample("YE") + if numeric_only is lib.no_default: + kwargs = {} + else: + kwargs = {"numeric_only": numeric_only} + + func = getattr(resampled, method) + if isinstance(expected_data, str): + if method in ("var", "mean", "median", "prod"): + klass = TypeError + msg = re.escape(f"agg function failed [how->{method},dtype->") + else: + klass = ValueError + msg = expected_data + with pytest.raises(klass, match=msg): + _ = func(**kwargs) + else: + result = func(**kwargs) + expected = DataFrame(expected_data, index=expected_index) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "method, numeric_only, expected_data", + [ + ("sum", True, ()), + ("sum", False, ["cat_1cat_2"]), + ("sum", lib.no_default, ["cat_1cat_2"]), + ("prod", True, ()), + ("prod", False, ()), + ("prod", lib.no_default, ()), + ("min", True, ()), + ("min", False, ["cat_1"]), + ("min", lib.no_default, ["cat_1"]), + ("max", True, ()), + ("max", False, ["cat_2"]), + ("max", lib.no_default, ["cat_2"]), + ("first", True, ()), + ("first", False, ["cat_1"]), + ("first", lib.no_default, ["cat_1"]), + ("last", True, ()), + ("last", False, ["cat_2"]), + ("last", lib.no_default, ["cat_2"]), + ], +) +def test_series_downsample_method(method, numeric_only, expected_data): + # GH#46442 test if `numeric_only` behave as expected for SeriesGroupBy + + index = date_range("2018-01-01", periods=2, freq="D") + expected_index = date_range("2018-12-31", periods=1, freq="YE") + df = Series(["cat_1", "cat_2"], index=index) + resampled = df.resample("YE") + kwargs = {} if numeric_only is lib.no_default else {"numeric_only": numeric_only} + + func = getattr(resampled, method) + if numeric_only and numeric_only is not lib.no_default: + msg = rf"Cannot use numeric_only=True with SeriesGroupBy\.{method}" + with pytest.raises(TypeError, match=msg): + func(**kwargs) + elif method == "prod": + msg = re.escape("agg function failed [how->prod,dtype->") + with pytest.raises(TypeError, match=msg): + func(**kwargs) + else: + result = func(**kwargs) + expected = Series(expected_data, index=expected_index) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "method, raises", + [ + ("sum", True), + ("prod", True), + ("min", True), + ("max", True), + ("first", False), + ("last", False), + ("median", False), + ("mean", True), + ("std", True), + ("var", True), + ("sem", False), + ("ohlc", False), + ("nunique", False), + ], +) +def test_args_kwargs_depr(method, raises): + index = date_range("20180101", periods=3, freq="h") + df = Series([2, 4, 6], index=index) + resampled = df.resample("30min") + args = () + + func = getattr(resampled, method) + + error_msg = "numpy operations are not valid with resample." + error_msg_type = "too many arguments passed in" + warn_msg = f"Passing additional args to DatetimeIndexResampler.{method}" + + if raises: + with tm.assert_produces_warning(FutureWarning, match=warn_msg): + with pytest.raises(UnsupportedFunctionCall, match=error_msg): + func(*args, 1, 2, 3, 4) + else: + with tm.assert_produces_warning(FutureWarning, match=warn_msg): + with pytest.raises(TypeError, match=error_msg_type): + func(*args, 1, 2, 3, 4) + + +def test_df_axis_param_depr(): + index = date_range(datetime(2005, 1, 1), datetime(2005, 1, 10), freq="D") + index.name = "date" + df = DataFrame( + np.random.default_rng(2).random((10, 2)), columns=list("AB"), index=index + ).T + + # Deprecation error when axis=1 is explicitly passed + warning_msg = "DataFrame.resample with axis=1 is deprecated." + with tm.assert_produces_warning(FutureWarning, match=warning_msg): + df.resample("ME", axis=1) + + # Deprecation error when axis=0 is explicitly passed + df = df.T + warning_msg = ( + "The 'axis' keyword in DataFrame.resample is deprecated and " + "will be removed in a future version." + ) + with tm.assert_produces_warning(FutureWarning, match=warning_msg): + df.resample("ME", axis=0) + + +def test_series_axis_param_depr(_test_series): + warning_msg = ( + "The 'axis' keyword in Series.resample is " + "deprecated and will be removed in a future version." + ) + with tm.assert_produces_warning(FutureWarning, match=warning_msg): + _test_series.resample("h", axis=0) + + +def test_resample_empty(): + # GH#52484 + df = DataFrame( + index=pd.to_datetime( + ["2018-01-01 00:00:00", "2018-01-01 12:00:00", "2018-01-02 00:00:00"] + ) + ) + expected = DataFrame( + index=pd.to_datetime( + [ + "2018-01-01 00:00:00", + "2018-01-01 08:00:00", + "2018-01-01 16:00:00", + "2018-01-02 00:00:00", + ] + ) + ) + result = df.resample("8h").mean() + tm.assert_frame_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/resample/test_resampler_grouper.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/resample/test_resampler_grouper.py new file mode 100644 index 0000000000000000000000000000000000000000..550523a432a894c7edbd8202a8217eeeef876cf8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/resample/test_resampler_grouper.py @@ -0,0 +1,715 @@ +from textwrap import dedent + +import numpy as np +import pytest + +from pandas.compat import is_platform_windows + +import pandas as pd +from pandas import ( + DataFrame, + Index, + Series, + TimedeltaIndex, + Timestamp, +) +import pandas._testing as tm +from pandas.core.indexes.datetimes import date_range + + +@pytest.fixture +def test_frame(): + return DataFrame( + {"A": [1] * 20 + [2] * 12 + [3] * 8, "B": np.arange(40)}, + index=date_range("1/1/2000", freq="s", periods=40), + ) + + +def test_tab_complete_ipython6_warning(ip): + from IPython.core.completer import provisionalcompleter + + code = dedent( + """\ + import numpy as np + from pandas import Series, date_range + data = np.arange(10, dtype=np.float64) + index = date_range("2020-01-01", periods=len(data)) + s = Series(data, index=index) + rs = s.resample("D") + """ + ) + ip.run_cell(code) + + # GH 31324 newer jedi version raises Deprecation warning; + # appears resolved 2021-02-02 + with tm.assert_produces_warning(None, raise_on_extra_warnings=False): + with provisionalcompleter("ignore"): + list(ip.Completer.completions("rs.", 1)) + + +def test_deferred_with_groupby(): + # GH 12486 + # support deferred resample ops with groupby + data = [ + ["2010-01-01", "A", 2], + ["2010-01-02", "A", 3], + ["2010-01-05", "A", 8], + ["2010-01-10", "A", 7], + ["2010-01-13", "A", 3], + ["2010-01-01", "B", 5], + ["2010-01-03", "B", 2], + ["2010-01-04", "B", 1], + ["2010-01-11", "B", 7], + ["2010-01-14", "B", 3], + ] + + df = DataFrame(data, columns=["date", "id", "score"]) + df.date = pd.to_datetime(df.date) + + def f_0(x): + return x.set_index("date").resample("D").asfreq() + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + expected = df.groupby("id").apply(f_0) + msg = "DataFrameGroupBy.resample operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = df.set_index("date").groupby("id").resample("D").asfreq() + tm.assert_frame_equal(result, expected) + + df = DataFrame( + { + "date": date_range(start="2016-01-01", periods=4, freq="W"), + "group": [1, 1, 2, 2], + "val": [5, 6, 7, 8], + } + ).set_index("date") + + def f_1(x): + return x.resample("1D").ffill() + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + expected = df.groupby("group").apply(f_1) + msg = "DataFrameGroupBy.resample operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = df.groupby("group").resample("1D").ffill() + tm.assert_frame_equal(result, expected) + + +def test_getitem(test_frame): + g = test_frame.groupby("A") + + expected = g.B.apply(lambda x: x.resample("2s").mean()) + + result = g.resample("2s").B.mean() + tm.assert_series_equal(result, expected) + + result = g.B.resample("2s").mean() + tm.assert_series_equal(result, expected) + + msg = "DataFrameGroupBy.resample operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = g.resample("2s").mean().B + tm.assert_series_equal(result, expected) + + +def test_getitem_multiple(): + # GH 13174 + # multiple calls after selection causing an issue with aliasing + data = [{"id": 1, "buyer": "A"}, {"id": 2, "buyer": "B"}] + df = DataFrame(data, index=date_range("2016-01-01", periods=2)) + r = df.groupby("id").resample("1D") + result = r["buyer"].count() + + exp_mi = pd.MultiIndex.from_arrays([[1, 2], df.index], names=("id", None)) + expected = Series( + [1, 1], + index=exp_mi, + name="buyer", + ) + tm.assert_series_equal(result, expected) + + result = r["buyer"].count() + tm.assert_series_equal(result, expected) + + +def test_groupby_resample_on_api_with_getitem(): + # GH 17813 + df = DataFrame( + {"id": list("aabbb"), "date": date_range("1-1-2016", periods=5), "data": 1} + ) + exp = df.set_index("date").groupby("id").resample("2D")["data"].sum() + result = df.groupby("id").resample("2D", on="date")["data"].sum() + tm.assert_series_equal(result, exp) + + +def test_groupby_with_origin(): + # GH 31809 + + freq = "1399min" # prime number that is smaller than 24h + start, end = "1/1/2000 00:00:00", "1/31/2000 00:00" + middle = "1/15/2000 00:00:00" + + rng = date_range(start, end, freq="1231min") # prime number + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + ts2 = ts[middle:end] + + # proves that grouper without a fixed origin does not work + # when dealing with unusual frequencies + simple_grouper = pd.Grouper(freq=freq) + count_ts = ts.groupby(simple_grouper).agg("count") + count_ts = count_ts[middle:end] + count_ts2 = ts2.groupby(simple_grouper).agg("count") + with pytest.raises(AssertionError, match="Index are different"): + tm.assert_index_equal(count_ts.index, count_ts2.index) + + # test origin on 1970-01-01 00:00:00 + origin = Timestamp(0) + adjusted_grouper = pd.Grouper(freq=freq, origin=origin) + adjusted_count_ts = ts.groupby(adjusted_grouper).agg("count") + adjusted_count_ts = adjusted_count_ts[middle:end] + adjusted_count_ts2 = ts2.groupby(adjusted_grouper).agg("count") + tm.assert_series_equal(adjusted_count_ts, adjusted_count_ts2) + + # test origin on 2049-10-18 20:00:00 + origin_future = Timestamp(0) + pd.Timedelta("1399min") * 30_000 + adjusted_grouper2 = pd.Grouper(freq=freq, origin=origin_future) + adjusted2_count_ts = ts.groupby(adjusted_grouper2).agg("count") + adjusted2_count_ts = adjusted2_count_ts[middle:end] + adjusted2_count_ts2 = ts2.groupby(adjusted_grouper2).agg("count") + tm.assert_series_equal(adjusted2_count_ts, adjusted2_count_ts2) + + # both grouper use an adjusted timestamp that is a multiple of 1399 min + # they should be equals even if the adjusted_timestamp is in the future + tm.assert_series_equal(adjusted_count_ts, adjusted2_count_ts2) + + +def test_nearest(): + # GH 17496 + # Resample nearest + index = date_range("1/1/2000", periods=3, freq="min") + result = Series(range(3), index=index).resample("20s").nearest() + + expected = Series( + [0, 0, 1, 1, 1, 2, 2], + index=pd.DatetimeIndex( + [ + "2000-01-01 00:00:00", + "2000-01-01 00:00:20", + "2000-01-01 00:00:40", + "2000-01-01 00:01:00", + "2000-01-01 00:01:20", + "2000-01-01 00:01:40", + "2000-01-01 00:02:00", + ], + dtype="datetime64[ns]", + freq="20s", + ), + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "f", + [ + "first", + "last", + "median", + "sem", + "sum", + "mean", + "min", + "max", + "size", + "count", + "nearest", + "bfill", + "ffill", + "asfreq", + "ohlc", + ], +) +def test_methods(f, test_frame): + g = test_frame.groupby("A") + r = g.resample("2s") + + msg = "DataFrameGroupBy.resample operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = getattr(r, f)() + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + expected = g.apply(lambda x: getattr(x.resample("2s"), f)()) + tm.assert_equal(result, expected) + + +def test_methods_nunique(test_frame): + # series only + g = test_frame.groupby("A") + r = g.resample("2s") + result = r.B.nunique() + expected = g.B.apply(lambda x: x.resample("2s").nunique()) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("f", ["std", "var"]) +def test_methods_std_var(f, test_frame): + g = test_frame.groupby("A") + r = g.resample("2s") + msg = "DataFrameGroupBy.resample operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = getattr(r, f)(ddof=1) + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + expected = g.apply(lambda x: getattr(x.resample("2s"), f)(ddof=1)) + tm.assert_frame_equal(result, expected) + + +def test_apply(test_frame): + g = test_frame.groupby("A") + r = g.resample("2s") + + # reduction + msg = "DataFrameGroupBy.resample operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + expected = g.resample("2s").sum() + + def f_0(x): + return x.resample("2s").sum() + + msg = "DataFrameGroupBy.resample operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = r.apply(f_0) + tm.assert_frame_equal(result, expected) + + def f_1(x): + return x.resample("2s").apply(lambda y: y.sum()) + + msg = "DataFrameGroupBy.apply operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = g.apply(f_1) + # y.sum() results in int64 instead of int32 on 32-bit architectures + expected = expected.astype("int64") + tm.assert_frame_equal(result, expected) + + +def test_apply_with_mutated_index(): + # GH 15169 + index = date_range("1-1-2015", "12-31-15", freq="D") + df = DataFrame( + data={"col1": np.random.default_rng(2).random(len(index))}, index=index + ) + + def f(x): + s = Series([1, 2], index=["a", "b"]) + return s + + expected = df.groupby(pd.Grouper(freq="ME")).apply(f) + + result = df.resample("ME").apply(f) + tm.assert_frame_equal(result, expected) + + # A case for series + expected = df["col1"].groupby(pd.Grouper(freq="ME"), group_keys=False).apply(f) + result = df["col1"].resample("ME").apply(f) + tm.assert_series_equal(result, expected) + + +def test_apply_columns_multilevel(): + # GH 16231 + cols = pd.MultiIndex.from_tuples([("A", "a", "", "one"), ("B", "b", "i", "two")]) + ind = date_range(start="2017-01-01", freq="15Min", periods=8) + df = DataFrame(np.array([0] * 16).reshape(8, 2), index=ind, columns=cols) + agg_dict = {col: (np.sum if col[3] == "one" else np.mean) for col in df.columns} + result = df.resample("h").apply(lambda x: agg_dict[x.name](x)) + expected = DataFrame( + 2 * [[0, 0.0]], + index=date_range(start="2017-01-01", freq="1h", periods=2), + columns=pd.MultiIndex.from_tuples( + [("A", "a", "", "one"), ("B", "b", "i", "two")] + ), + ) + tm.assert_frame_equal(result, expected) + + +def test_apply_non_naive_index(): + def weighted_quantile(series, weights, q): + series = series.sort_values() + cumsum = weights.reindex(series.index).fillna(0).cumsum() + cutoff = cumsum.iloc[-1] * q + return series[cumsum >= cutoff].iloc[0] + + times = date_range("2017-6-23 18:00", periods=8, freq="15min", tz="UTC") + data = Series([1.0, 1, 1, 1, 1, 2, 2, 0], index=times) + weights = Series([160.0, 91, 65, 43, 24, 10, 1, 0], index=times) + + result = data.resample("D").apply(weighted_quantile, weights=weights, q=0.5) + ind = date_range( + "2017-06-23 00:00:00+00:00", "2017-06-23 00:00:00+00:00", freq="D", tz="UTC" + ) + expected = Series([1.0], index=ind) + tm.assert_series_equal(result, expected) + + +def test_resample_groupby_with_label(unit): + # GH 13235 + index = date_range("2000-01-01", freq="2D", periods=5, unit=unit) + df = DataFrame(index=index, data={"col0": [0, 0, 1, 1, 2], "col1": [1, 1, 1, 1, 1]}) + msg = "DataFrameGroupBy.resample operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = df.groupby("col0").resample("1W", label="left").sum() + + mi = [ + np.array([0, 0, 1, 2], dtype=np.int64), + np.array( + ["1999-12-26", "2000-01-02", "2000-01-02", "2000-01-02"], + dtype=f"M8[{unit}]", + ), + ] + mindex = pd.MultiIndex.from_arrays(mi, names=["col0", None]) + expected = DataFrame( + data={"col0": [0, 0, 2, 2], "col1": [1, 1, 2, 1]}, index=mindex + ) + + tm.assert_frame_equal(result, expected) + + +def test_consistency_with_window(test_frame): + # consistent return values with window + df = test_frame + expected = Index([1, 2, 3], name="A") + msg = "DataFrameGroupBy.resample operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = df.groupby("A").resample("2s").mean() + assert result.index.nlevels == 2 + tm.assert_index_equal(result.index.levels[0], expected) + + result = df.groupby("A").rolling(20).mean() + assert result.index.nlevels == 2 + tm.assert_index_equal(result.index.levels[0], expected) + + +def test_median_duplicate_columns(): + # GH 14233 + + df = DataFrame( + np.random.default_rng(2).standard_normal((20, 3)), + columns=list("aaa"), + index=date_range("2012-01-01", periods=20, freq="s"), + ) + df2 = df.copy() + df2.columns = ["a", "b", "c"] + expected = df2.resample("5s").median() + result = df.resample("5s").median() + expected.columns = result.columns + tm.assert_frame_equal(result, expected) + + +def test_apply_to_one_column_of_df(): + # GH: 36951 + df = DataFrame( + {"col": range(10), "col1": range(10, 20)}, + index=date_range("2012-01-01", periods=10, freq="20min"), + ) + + # access "col" via getattr -> make sure we handle AttributeError + result = df.resample("h").apply(lambda group: group.col.sum()) + expected = Series( + [3, 12, 21, 9], index=date_range("2012-01-01", periods=4, freq="h") + ) + tm.assert_series_equal(result, expected) + + # access "col" via _getitem__ -> make sure we handle KeyErrpr + result = df.resample("h").apply(lambda group: group["col"].sum()) + tm.assert_series_equal(result, expected) + + +def test_resample_groupby_agg(): + # GH: 33548 + df = DataFrame( + { + "cat": [ + "cat_1", + "cat_1", + "cat_2", + "cat_1", + "cat_2", + "cat_1", + "cat_2", + "cat_1", + ], + "num": [5, 20, 22, 3, 4, 30, 10, 50], + "date": [ + "2019-2-1", + "2018-02-03", + "2020-3-11", + "2019-2-2", + "2019-2-2", + "2018-12-4", + "2020-3-11", + "2020-12-12", + ], + } + ) + df["date"] = pd.to_datetime(df["date"]) + + resampled = df.groupby("cat").resample("YE", on="date") + expected = resampled[["num"]].sum() + result = resampled.agg({"num": "sum"}) + + tm.assert_frame_equal(result, expected) + + +def test_resample_groupby_agg_listlike(): + # GH 42905 + ts = Timestamp("2021-02-28 00:00:00") + df = DataFrame({"class": ["beta"], "value": [69]}, index=Index([ts], name="date")) + resampled = df.groupby("class").resample("ME")["value"] + result = resampled.agg(["sum", "size"]) + expected = DataFrame( + [[69, 1]], + index=pd.MultiIndex.from_tuples([("beta", ts)], names=["class", "date"]), + columns=["sum", "size"], + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("keys", [["a"], ["a", "b"]]) +def test_empty(keys): + # GH 26411 + df = DataFrame([], columns=["a", "b"], index=TimedeltaIndex([])) + msg = "DataFrameGroupBy.resample operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = df.groupby(keys).resample(rule=pd.to_timedelta("00:00:01")).mean() + expected = ( + DataFrame(columns=["a", "b"]) + .set_index(keys, drop=False) + .set_index(TimedeltaIndex([]), append=True) + ) + if len(keys) == 1: + expected.index.name = keys[0] + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("consolidate", [True, False]) +def test_resample_groupby_agg_object_dtype_all_nan(consolidate): + # https://github.com/pandas-dev/pandas/issues/39329 + + dates = date_range("2020-01-01", periods=15, freq="D") + df1 = DataFrame({"key": "A", "date": dates, "col1": range(15), "col_object": "val"}) + df2 = DataFrame({"key": "B", "date": dates, "col1": range(15)}) + df = pd.concat([df1, df2], ignore_index=True) + if consolidate: + df = df._consolidate() + + msg = "DataFrameGroupBy.resample operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = df.groupby(["key"]).resample("W", on="date").min() + idx = pd.MultiIndex.from_arrays( + [ + ["A"] * 3 + ["B"] * 3, + pd.to_datetime(["2020-01-05", "2020-01-12", "2020-01-19"] * 2).as_unit( + "ns" + ), + ], + names=["key", "date"], + ) + expected = DataFrame( + { + "key": ["A"] * 3 + ["B"] * 3, + "col1": [0, 5, 12] * 2, + "col_object": ["val"] * 3 + [np.nan] * 3, + }, + index=idx, + ) + tm.assert_frame_equal(result, expected) + + +def test_groupby_resample_with_list_of_keys(): + # GH 47362 + df = DataFrame( + data={ + "date": date_range(start="2016-01-01", periods=8), + "group": [0, 0, 0, 0, 1, 1, 1, 1], + "val": [1, 7, 5, 2, 3, 10, 5, 1], + } + ) + result = df.groupby("group").resample("2D", on="date")[["val"]].mean() + + mi_exp = pd.MultiIndex.from_arrays( + [[0, 0, 1, 1], df["date"]._values[::2]], names=["group", "date"] + ) + expected = DataFrame( + data={ + "val": [4.0, 3.5, 6.5, 3.0], + }, + index=mi_exp, + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("keys", [["a"], ["a", "b"]]) +def test_resample_no_index(keys): + # GH 47705 + df = DataFrame([], columns=["a", "b", "date"]) + df["date"] = pd.to_datetime(df["date"]) + df = df.set_index("date") + msg = "DataFrameGroupBy.resample operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = df.groupby(keys).resample(rule=pd.to_timedelta("00:00:01")).mean() + expected = DataFrame(columns=["a", "b", "date"]).set_index(keys, drop=False) + expected["date"] = pd.to_datetime(expected["date"]) + expected = expected.set_index("date", append=True, drop=True) + if len(keys) == 1: + expected.index.name = keys[0] + + tm.assert_frame_equal(result, expected) + + +def test_resample_no_columns(): + # GH#52484 + df = DataFrame( + index=Index( + pd.to_datetime( + ["2018-01-01 00:00:00", "2018-01-01 12:00:00", "2018-01-02 00:00:00"] + ), + name="date", + ) + ) + result = df.groupby([0, 0, 1]).resample(rule=pd.to_timedelta("06:00:00")).mean() + index = pd.to_datetime( + [ + "2018-01-01 00:00:00", + "2018-01-01 06:00:00", + "2018-01-01 12:00:00", + "2018-01-02 00:00:00", + ] + ) + expected = DataFrame( + index=pd.MultiIndex( + levels=[np.array([0, 1], dtype=np.intp), index], + codes=[[0, 0, 0, 1], [0, 1, 2, 3]], + names=[None, "date"], + ) + ) + + # GH#52710 - Index comes out as 32-bit on 64-bit Windows + tm.assert_frame_equal(result, expected, check_index_type=not is_platform_windows()) + + +def test_groupby_resample_size_all_index_same(): + # GH 46826 + df = DataFrame( + {"A": [1] * 3 + [2] * 3 + [1] * 3 + [2] * 3, "B": np.arange(12)}, + index=date_range("31/12/2000 18:00", freq="h", periods=12), + ) + msg = "DataFrameGroupBy.resample operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = df.groupby("A").resample("D").size() + + mi_exp = pd.MultiIndex.from_arrays( + [ + [1, 1, 2, 2], + pd.DatetimeIndex(["2000-12-31", "2001-01-01"] * 2, dtype="M8[ns]"), + ], + names=["A", None], + ) + expected = Series( + 3, + index=mi_exp, + ) + tm.assert_series_equal(result, expected) + + +def test_groupby_resample_on_index_with_list_of_keys(): + # GH 50840 + df = DataFrame( + data={ + "group": [0, 0, 0, 0, 1, 1, 1, 1], + "val": [3, 1, 4, 1, 5, 9, 2, 6], + }, + index=date_range(start="2016-01-01", periods=8, name="date"), + ) + result = df.groupby("group").resample("2D")[["val"]].mean() + + mi_exp = pd.MultiIndex.from_arrays( + [[0, 0, 1, 1], df.index[::2]], names=["group", "date"] + ) + expected = DataFrame( + data={ + "val": [2.0, 2.5, 7.0, 4.0], + }, + index=mi_exp, + ) + tm.assert_frame_equal(result, expected) + + +def test_groupby_resample_on_index_with_list_of_keys_multi_columns(): + # GH 50876 + df = DataFrame( + data={ + "group": [0, 0, 0, 0, 1, 1, 1, 1], + "first_val": [3, 1, 4, 1, 5, 9, 2, 6], + "second_val": [2, 7, 1, 8, 2, 8, 1, 8], + "third_val": [1, 4, 1, 4, 2, 1, 3, 5], + }, + index=date_range(start="2016-01-01", periods=8, name="date"), + ) + result = df.groupby("group").resample("2D")[["first_val", "second_val"]].mean() + + mi_exp = pd.MultiIndex.from_arrays( + [[0, 0, 1, 1], df.index[::2]], names=["group", "date"] + ) + expected = DataFrame( + data={ + "first_val": [2.0, 2.5, 7.0, 4.0], + "second_val": [4.5, 4.5, 5.0, 4.5], + }, + index=mi_exp, + ) + tm.assert_frame_equal(result, expected) + + +def test_groupby_resample_on_index_with_list_of_keys_missing_column(): + # GH 50876 + df = DataFrame( + data={ + "group": [0, 0, 0, 0, 1, 1, 1, 1], + "val": [3, 1, 4, 1, 5, 9, 2, 6], + }, + index=Series( + date_range(start="2016-01-01", periods=8), + name="date", + ), + ) + gb = df.groupby("group") + rs = gb.resample("2D") + with pytest.raises(KeyError, match="Columns not found"): + rs[["val_not_in_dataframe"]] + + +@pytest.mark.parametrize("kind", ["datetime", "period"]) +def test_groupby_resample_kind(kind): + # GH 24103 + df = DataFrame( + { + "datetime": pd.to_datetime( + ["20181101 1100", "20181101 1200", "20181102 1300", "20181102 1400"] + ), + "group": ["A", "B", "A", "B"], + "value": [1, 2, 3, 4], + } + ) + df = df.set_index("datetime") + result = df.groupby("group")["value"].resample("D", kind=kind).last() + + dt_level = pd.DatetimeIndex(["2018-11-01", "2018-11-02"]) + if kind == "period": + dt_level = dt_level.to_period(freq="D") + expected_index = pd.MultiIndex.from_product( + [["A", "B"], dt_level], + names=["group", "datetime"], + ) + expected = Series([1, 3, 2, 4], index=expected_index, name="value") + tm.assert_series_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/resample/test_time_grouper.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/resample/test_time_grouper.py new file mode 100644 index 0000000000000000000000000000000000000000..3f9340b800eae4775a1f697ae2e69a31ab42dcf2 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/resample/test_time_grouper.py @@ -0,0 +1,390 @@ +from datetime import datetime +from operator import methodcaller + +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + DataFrame, + Index, + Series, + Timestamp, +) +import pandas._testing as tm +from pandas.core.groupby.grouper import Grouper +from pandas.core.indexes.datetimes import date_range + + +@pytest.fixture +def test_series(): + return Series( + np.random.default_rng(2).standard_normal(1000), + index=date_range("1/1/2000", periods=1000), + ) + + +def test_apply(test_series): + grouper = Grouper(freq="YE", label="right", closed="right") + + grouped = test_series.groupby(grouper) + + def f(x): + return x.sort_values()[-3:] + + applied = grouped.apply(f) + expected = test_series.groupby(lambda x: x.year).apply(f) + + applied.index = applied.index.droplevel(0) + expected.index = expected.index.droplevel(0) + tm.assert_series_equal(applied, expected) + + +def test_count(test_series): + test_series[::3] = np.nan + + expected = test_series.groupby(lambda x: x.year).count() + + grouper = Grouper(freq="YE", label="right", closed="right") + result = test_series.groupby(grouper).count() + expected.index = result.index + tm.assert_series_equal(result, expected) + + result = test_series.resample("YE").count() + expected.index = result.index + tm.assert_series_equal(result, expected) + + +def test_numpy_reduction(test_series): + result = test_series.resample("YE", closed="right").prod() + + msg = "using SeriesGroupBy.prod" + with tm.assert_produces_warning(FutureWarning, match=msg): + expected = test_series.groupby(lambda x: x.year).agg(np.prod) + expected.index = result.index + + tm.assert_series_equal(result, expected) + + +def test_apply_iteration(): + # #2300 + N = 1000 + ind = date_range(start="2000-01-01", freq="D", periods=N) + df = DataFrame({"open": 1, "close": 2}, index=ind) + tg = Grouper(freq="ME") + + grouper, _ = tg._get_grouper(df) + + # Errors + grouped = df.groupby(grouper, group_keys=False) + + def f(df): + return df["close"] / df["open"] + + # it works! + result = grouped.apply(f) + tm.assert_index_equal(result.index, df.index) + + +@pytest.mark.parametrize( + "index", + [ + Index([1, 2]), + Index(["a", "b"]), + Index([1.1, 2.2]), + pd.MultiIndex.from_arrays([[1, 2], ["a", "b"]]), + ], +) +def test_fails_on_no_datetime_index(index): + name = type(index).__name__ + df = DataFrame({"a": range(len(index))}, index=index) + + msg = ( + "Only valid with DatetimeIndex, TimedeltaIndex " + f"or PeriodIndex, but got an instance of '{name}'" + ) + with pytest.raises(TypeError, match=msg): + df.groupby(Grouper(freq="D")) + + +def test_aaa_group_order(): + # GH 12840 + # check TimeGrouper perform stable sorts + n = 20 + data = np.random.default_rng(2).standard_normal((n, 4)) + df = DataFrame(data, columns=["A", "B", "C", "D"]) + df["key"] = [ + datetime(2013, 1, 1), + datetime(2013, 1, 2), + datetime(2013, 1, 3), + datetime(2013, 1, 4), + datetime(2013, 1, 5), + ] * 4 + grouped = df.groupby(Grouper(key="key", freq="D")) + + tm.assert_frame_equal(grouped.get_group(datetime(2013, 1, 1)), df[::5]) + tm.assert_frame_equal(grouped.get_group(datetime(2013, 1, 2)), df[1::5]) + tm.assert_frame_equal(grouped.get_group(datetime(2013, 1, 3)), df[2::5]) + tm.assert_frame_equal(grouped.get_group(datetime(2013, 1, 4)), df[3::5]) + tm.assert_frame_equal(grouped.get_group(datetime(2013, 1, 5)), df[4::5]) + + +def test_aggregate_normal(resample_method): + """Check TimeGrouper's aggregation is identical as normal groupby.""" + + data = np.random.default_rng(2).standard_normal((20, 4)) + normal_df = DataFrame(data, columns=["A", "B", "C", "D"]) + normal_df["key"] = [1, 2, 3, 4, 5] * 4 + + dt_df = DataFrame(data, columns=["A", "B", "C", "D"]) + dt_df["key"] = Index( + [ + datetime(2013, 1, 1), + datetime(2013, 1, 2), + datetime(2013, 1, 3), + datetime(2013, 1, 4), + datetime(2013, 1, 5), + ] + * 4, + dtype="M8[ns]", + ) + + normal_grouped = normal_df.groupby("key") + dt_grouped = dt_df.groupby(Grouper(key="key", freq="D")) + + expected = getattr(normal_grouped, resample_method)() + dt_result = getattr(dt_grouped, resample_method)() + expected.index = date_range(start="2013-01-01", freq="D", periods=5, name="key") + tm.assert_equal(expected, dt_result) + + +@pytest.mark.xfail(reason="if TimeGrouper is used included, 'nth' doesn't work yet") +def test_aggregate_nth(): + """Check TimeGrouper's aggregation is identical as normal groupby.""" + + data = np.random.default_rng(2).standard_normal((20, 4)) + normal_df = DataFrame(data, columns=["A", "B", "C", "D"]) + normal_df["key"] = [1, 2, 3, 4, 5] * 4 + + dt_df = DataFrame(data, columns=["A", "B", "C", "D"]) + dt_df["key"] = [ + datetime(2013, 1, 1), + datetime(2013, 1, 2), + datetime(2013, 1, 3), + datetime(2013, 1, 4), + datetime(2013, 1, 5), + ] * 4 + + normal_grouped = normal_df.groupby("key") + dt_grouped = dt_df.groupby(Grouper(key="key", freq="D")) + + expected = normal_grouped.nth(3) + expected.index = date_range(start="2013-01-01", freq="D", periods=5, name="key") + dt_result = dt_grouped.nth(3) + tm.assert_frame_equal(expected, dt_result) + + +@pytest.mark.parametrize( + "method, method_args, unit", + [ + ("sum", {}, 0), + ("sum", {"min_count": 0}, 0), + ("sum", {"min_count": 1}, np.nan), + ("prod", {}, 1), + ("prod", {"min_count": 0}, 1), + ("prod", {"min_count": 1}, np.nan), + ], +) +def test_resample_entirely_nat_window(method, method_args, unit): + ser = Series([0] * 2 + [np.nan] * 2, index=date_range("2017", periods=4)) + result = methodcaller(method, **method_args)(ser.resample("2d")) + + exp_dti = pd.DatetimeIndex(["2017-01-01", "2017-01-03"], dtype="M8[ns]", freq="2D") + expected = Series([0.0, unit], index=exp_dti) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "func, fill_value", + [("min", np.nan), ("max", np.nan), ("sum", 0), ("prod", 1), ("count", 0)], +) +def test_aggregate_with_nat(func, fill_value): + # check TimeGrouper's aggregation is identical as normal groupby + # if NaT is included, 'var', 'std', 'mean', 'first','last' + # and 'nth' doesn't work yet + + n = 20 + data = np.random.default_rng(2).standard_normal((n, 4)).astype("int64") + normal_df = DataFrame(data, columns=["A", "B", "C", "D"]) + normal_df["key"] = [1, 2, np.nan, 4, 5] * 4 + + dt_df = DataFrame(data, columns=["A", "B", "C", "D"]) + dt_df["key"] = Index( + [ + datetime(2013, 1, 1), + datetime(2013, 1, 2), + pd.NaT, + datetime(2013, 1, 4), + datetime(2013, 1, 5), + ] + * 4, + dtype="M8[ns]", + ) + + normal_grouped = normal_df.groupby("key") + dt_grouped = dt_df.groupby(Grouper(key="key", freq="D")) + + normal_result = getattr(normal_grouped, func)() + dt_result = getattr(dt_grouped, func)() + + pad = DataFrame([[fill_value] * 4], index=[3], columns=["A", "B", "C", "D"]) + expected = pd.concat([normal_result, pad]) + expected = expected.sort_index() + dti = date_range( + start="2013-01-01", + freq="D", + periods=5, + name="key", + unit=dt_df["key"]._values.unit, + ) + expected.index = dti._with_freq(None) # TODO: is this desired? + tm.assert_frame_equal(expected, dt_result) + assert dt_result.index.name == "key" + + +def test_aggregate_with_nat_size(): + # GH 9925 + n = 20 + data = np.random.default_rng(2).standard_normal((n, 4)).astype("int64") + normal_df = DataFrame(data, columns=["A", "B", "C", "D"]) + normal_df["key"] = [1, 2, np.nan, 4, 5] * 4 + + dt_df = DataFrame(data, columns=["A", "B", "C", "D"]) + dt_df["key"] = Index( + [ + datetime(2013, 1, 1), + datetime(2013, 1, 2), + pd.NaT, + datetime(2013, 1, 4), + datetime(2013, 1, 5), + ] + * 4, + dtype="M8[ns]", + ) + + normal_grouped = normal_df.groupby("key") + dt_grouped = dt_df.groupby(Grouper(key="key", freq="D")) + + normal_result = normal_grouped.size() + dt_result = dt_grouped.size() + + pad = Series([0], index=[3]) + expected = pd.concat([normal_result, pad]) + expected = expected.sort_index() + expected.index = date_range( + start="2013-01-01", + freq="D", + periods=5, + name="key", + unit=dt_df["key"]._values.unit, + )._with_freq(None) + tm.assert_series_equal(expected, dt_result) + assert dt_result.index.name == "key" + + +def test_repr(): + # GH18203 + result = repr(Grouper(key="A", freq="h")) + expected = ( + "TimeGrouper(key='A', freq=, axis=0, sort=True, dropna=True, " + "closed='left', label='left', how='mean', " + "convention='e', origin='start_day')" + ) + assert result == expected + + result = repr(Grouper(key="A", freq="h", origin="2000-01-01")) + expected = ( + "TimeGrouper(key='A', freq=, axis=0, sort=True, dropna=True, " + "closed='left', label='left', how='mean', " + "convention='e', origin=Timestamp('2000-01-01 00:00:00'))" + ) + assert result == expected + + +@pytest.mark.parametrize( + "method, method_args, expected_values", + [ + ("sum", {}, [1, 0, 1]), + ("sum", {"min_count": 0}, [1, 0, 1]), + ("sum", {"min_count": 1}, [1, np.nan, 1]), + ("sum", {"min_count": 2}, [np.nan, np.nan, np.nan]), + ("prod", {}, [1, 1, 1]), + ("prod", {"min_count": 0}, [1, 1, 1]), + ("prod", {"min_count": 1}, [1, np.nan, 1]), + ("prod", {"min_count": 2}, [np.nan, np.nan, np.nan]), + ], +) +def test_upsample_sum(method, method_args, expected_values): + ser = Series(1, index=date_range("2017", periods=2, freq="h")) + resampled = ser.resample("30min") + index = pd.DatetimeIndex( + ["2017-01-01T00:00:00", "2017-01-01T00:30:00", "2017-01-01T01:00:00"], + dtype="M8[ns]", + freq="30min", + ) + result = methodcaller(method, **method_args)(resampled) + expected = Series(expected_values, index=index) + tm.assert_series_equal(result, expected) + + +def test_groupby_resample_interpolate(): + # GH 35325 + d = {"price": [10, 11, 9], "volume": [50, 60, 50]} + + df = DataFrame(d) + + df["week_starting"] = date_range("01/01/2018", periods=3, freq="W") + + msg = "DataFrameGroupBy.resample operated on the grouping columns" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + result = ( + df.set_index("week_starting") + .groupby("volume") + .resample("1D") + .interpolate(method="linear") + ) + + volume = [50] * 15 + [60] + week_starting = list(date_range("2018-01-07", "2018-01-21")) + [ + Timestamp("2018-01-14") + ] + expected_ind = pd.MultiIndex.from_arrays( + [volume, week_starting], + names=["volume", "week_starting"], + ) + + expected = DataFrame( + data={ + "price": [ + 10.0, + 9.928571428571429, + 9.857142857142858, + 9.785714285714286, + 9.714285714285714, + 9.642857142857142, + 9.571428571428571, + 9.5, + 9.428571428571429, + 9.357142857142858, + 9.285714285714286, + 9.214285714285714, + 9.142857142857142, + 9.071428571428571, + 9.0, + 11.0, + ], + "volume": [50.0] * 15 + [60], + }, + index=expected_ind, + ) + tm.assert_frame_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/resample/test_timedelta.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/resample/test_timedelta.py new file mode 100644 index 0000000000000000000000000000000000000000..7c70670d42908af9bd488eba5941732a9fc93da0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/resample/test_timedelta.py @@ -0,0 +1,220 @@ +from datetime import timedelta + +import numpy as np +import pytest + +import pandas.util._test_decorators as td + +import pandas as pd +from pandas import ( + DataFrame, + Series, +) +import pandas._testing as tm +from pandas.core.indexes.timedeltas import timedelta_range + + +def test_asfreq_bug(): + df = DataFrame(data=[1, 3], index=[timedelta(), timedelta(minutes=3)]) + result = df.resample("1min").asfreq() + expected = DataFrame( + data=[1, np.nan, np.nan, 3], + index=timedelta_range("0 day", periods=4, freq="1min"), + ) + tm.assert_frame_equal(result, expected) + + +def test_resample_with_nat(): + # GH 13223 + index = pd.to_timedelta(["0s", pd.NaT, "2s"]) + result = DataFrame({"value": [2, 3, 5]}, index).resample("1s").mean() + expected = DataFrame( + {"value": [2.5, np.nan, 5.0]}, + index=timedelta_range("0 day", periods=3, freq="1s"), + ) + tm.assert_frame_equal(result, expected) + + +def test_resample_as_freq_with_subperiod(): + # GH 13022 + index = timedelta_range("00:00:00", "00:10:00", freq="5min") + df = DataFrame(data={"value": [1, 5, 10]}, index=index) + result = df.resample("2min").asfreq() + expected_data = {"value": [1, np.nan, np.nan, np.nan, np.nan, 10]} + expected = DataFrame( + data=expected_data, index=timedelta_range("00:00:00", "00:10:00", freq="2min") + ) + tm.assert_frame_equal(result, expected) + + +def test_resample_with_timedeltas(): + expected = DataFrame({"A": np.arange(1480)}) + expected = expected.groupby(expected.index // 30).sum() + expected.index = timedelta_range("0 days", freq="30min", periods=50) + + df = DataFrame( + {"A": np.arange(1480)}, index=pd.to_timedelta(np.arange(1480), unit="min") + ) + result = df.resample("30min").sum() + + tm.assert_frame_equal(result, expected) + + s = df["A"] + result = s.resample("30min").sum() + tm.assert_series_equal(result, expected["A"]) + + +def test_resample_single_period_timedelta(): + s = Series(list(range(5)), index=timedelta_range("1 day", freq="s", periods=5)) + result = s.resample("2s").sum() + expected = Series([1, 5, 4], index=timedelta_range("1 day", freq="2s", periods=3)) + tm.assert_series_equal(result, expected) + + +def test_resample_timedelta_idempotency(): + # GH 12072 + index = timedelta_range("0", periods=9, freq="10ms") + series = Series(range(9), index=index) + result = series.resample("10ms").mean() + expected = series.astype(float) + tm.assert_series_equal(result, expected) + + +def test_resample_offset_with_timedeltaindex(): + # GH 10530 & 31809 + rng = timedelta_range(start="0s", periods=25, freq="s") + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + + with_base = ts.resample("2s", offset="5s").mean() + without_base = ts.resample("2s").mean() + + exp_without_base = timedelta_range(start="0s", end="25s", freq="2s") + exp_with_base = timedelta_range(start="5s", end="29s", freq="2s") + + tm.assert_index_equal(without_base.index, exp_without_base) + tm.assert_index_equal(with_base.index, exp_with_base) + + +def test_resample_categorical_data_with_timedeltaindex(): + # GH #12169 + df = DataFrame({"Group_obj": "A"}, index=pd.to_timedelta(list(range(20)), unit="s")) + df["Group"] = df["Group_obj"].astype("category") + result = df.resample("10s").agg(lambda x: (x.value_counts().index[0])) + exp_tdi = pd.TimedeltaIndex(np.array([0, 10], dtype="m8[s]"), freq="10s").as_unit( + "ns" + ) + expected = DataFrame( + {"Group_obj": ["A", "A"], "Group": ["A", "A"]}, + index=exp_tdi, + ) + expected = expected.reindex(["Group_obj", "Group"], axis=1) + expected["Group"] = expected["Group_obj"].astype("category") + tm.assert_frame_equal(result, expected) + + +def test_resample_timedelta_values(): + # GH 13119 + # check that timedelta dtype is preserved when NaT values are + # introduced by the resampling + + times = timedelta_range("1 day", "6 day", freq="4D") + df = DataFrame({"time": times}, index=times) + + times2 = timedelta_range("1 day", "6 day", freq="2D") + exp = Series(times2, index=times2, name="time") + exp.iloc[1] = pd.NaT + + res = df.resample("2D").first()["time"] + tm.assert_series_equal(res, exp) + res = df["time"].resample("2D").first() + tm.assert_series_equal(res, exp) + + +@pytest.mark.parametrize( + "start, end, freq, resample_freq", + [ + ("8h", "21h59min50s", "10s", "3h"), # GH 30353 example + ("3h", "22h", "1h", "5h"), + ("527D", "5006D", "3D", "10D"), + ("1D", "10D", "1D", "2D"), # GH 13022 example + # tests that worked before GH 33498: + ("8h", "21h59min50s", "10s", "2h"), + ("0h", "21h59min50s", "10s", "3h"), + ("10D", "85D", "D", "2D"), + ], +) +def test_resample_timedelta_edge_case(start, end, freq, resample_freq): + # GH 33498 + # check that the timedelta bins does not contains an extra bin + idx = timedelta_range(start=start, end=end, freq=freq) + s = Series(np.arange(len(idx)), index=idx) + result = s.resample(resample_freq).min() + expected_index = timedelta_range(freq=resample_freq, start=start, end=end) + tm.assert_index_equal(result.index, expected_index) + assert result.index.freq == expected_index.freq + assert not np.isnan(result.iloc[-1]) + + +@pytest.mark.parametrize("duplicates", [True, False]) +def test_resample_with_timedelta_yields_no_empty_groups(duplicates): + # GH 10603 + df = DataFrame( + np.random.default_rng(2).normal(size=(10000, 4)), + index=timedelta_range(start="0s", periods=10000, freq="3906250ns"), + ) + if duplicates: + # case with non-unique columns + df.columns = ["A", "B", "A", "C"] + + result = df.loc["1s":, :].resample("3s").apply(lambda x: len(x)) + + expected = DataFrame( + [[768] * 4] * 12 + [[528] * 4], + index=timedelta_range(start="1s", periods=13, freq="3s"), + ) + expected.columns = df.columns + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("unit", ["s", "ms", "us", "ns"]) +def test_resample_quantile_timedelta(unit): + # GH: 29485 + dtype = np.dtype(f"m8[{unit}]") + df = DataFrame( + {"value": pd.to_timedelta(np.arange(4), unit="s").astype(dtype)}, + index=pd.date_range("20200101", periods=4, tz="UTC"), + ) + result = df.resample("2D").quantile(0.99) + expected = DataFrame( + { + "value": [ + pd.Timedelta("0 days 00:00:00.990000"), + pd.Timedelta("0 days 00:00:02.990000"), + ] + }, + index=pd.date_range("20200101", periods=2, tz="UTC", freq="2D"), + ).astype(dtype) + tm.assert_frame_equal(result, expected) + + +def test_resample_closed_right(): + # GH#45414 + idx = pd.Index([pd.Timedelta(seconds=120 + i * 30) for i in range(10)]) + ser = Series(range(10), index=idx) + result = ser.resample("min", closed="right", label="right").sum() + expected = Series( + [0, 3, 7, 11, 15, 9], + index=pd.TimedeltaIndex( + [pd.Timedelta(seconds=120 + i * 60) for i in range(6)], freq="min" + ), + ) + tm.assert_series_equal(result, expected) + + +@td.skip_if_no("pyarrow") +def test_arrow_duration_resample(): + # GH 56371 + idx = pd.Index(timedelta_range("1 day", periods=5), dtype="duration[ns][pyarrow]") + expected = Series(np.arange(5, dtype=np.float64), index=idx) + result = expected.resample("1D").mean() + tm.assert_series_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..abd60bf68c56f905a84bd7acae5ad1099fc5079d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/__pycache__/test_crosstab.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/__pycache__/test_crosstab.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35451164daa35080574194e8aaf1b9cc45d828bf Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/__pycache__/test_crosstab.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/__pycache__/test_cut.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/__pycache__/test_cut.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5221c63481f99dc329968b87138d2348d6a2cf68 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/__pycache__/test_cut.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/__pycache__/test_from_dummies.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/__pycache__/test_from_dummies.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd5efab201c2816205fda9c4398b9e0d9bb1bddc Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/__pycache__/test_from_dummies.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/__pycache__/test_get_dummies.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/__pycache__/test_get_dummies.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1887b935b9a5e1bbb96f07e1b0c9b4fcab0dcc81 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/__pycache__/test_get_dummies.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/__pycache__/test_melt.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/__pycache__/test_melt.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..76d6a95050b79795bf44cc4cfc5856652d76185e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/__pycache__/test_melt.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/__pycache__/test_pivot_multilevel.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/__pycache__/test_pivot_multilevel.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b6ba4597d37ff5153848ac6b446bc0c84f12312 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/__pycache__/test_pivot_multilevel.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/__pycache__/test_qcut.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/__pycache__/test_qcut.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f93f027bb8973a15bb190e950ec95a1477e2ea3 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/__pycache__/test_qcut.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/__pycache__/test_union_categoricals.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/__pycache__/test_union_categoricals.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e90be5afa61ba05793073ebc91a5afa86e90406 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/__pycache__/test_union_categoricals.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/__pycache__/test_util.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/__pycache__/test_util.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f71bdaa6d34974ff377c536a76528b466f542c5 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/__pycache__/test_util.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/concat/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/concat/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..374210944e524346bb56c06231e5e7cb9408cbd8 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/concat/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/concat/__pycache__/conftest.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/concat/__pycache__/conftest.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de7ecf5ddc3ccf7d7a04825f5782accd0dc5499f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/concat/__pycache__/conftest.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/concat/__pycache__/test_append.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/concat/__pycache__/test_append.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c36d655ce203069c381df0634fb2afee46af316 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/concat/__pycache__/test_append.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/concat/__pycache__/test_append_common.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/concat/__pycache__/test_append_common.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5312f474981d3e46f1cbaadd38d082064c65e9f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/concat/__pycache__/test_append_common.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/concat/__pycache__/test_categorical.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/concat/__pycache__/test_categorical.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e9cb3b9d953ab3b2ce306380d491ae94b5a5158 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/concat/__pycache__/test_categorical.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/concat/__pycache__/test_concat.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/concat/__pycache__/test_concat.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3002df312b627da4c32b13931c8a78397c85ad3e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/concat/__pycache__/test_concat.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/concat/__pycache__/test_dataframe.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/concat/__pycache__/test_dataframe.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3da8e7c70b4c1081a0fcb029c0b5f81ed3779257 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/concat/__pycache__/test_dataframe.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/concat/__pycache__/test_datetimes.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/concat/__pycache__/test_datetimes.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f45cf37cb8028718bd600906bd6e656107d26fd7 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/concat/__pycache__/test_datetimes.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/concat/__pycache__/test_empty.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/concat/__pycache__/test_empty.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9b0832312c0f8dd0c6c8d67860782a4c67003bb Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/concat/__pycache__/test_empty.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/concat/__pycache__/test_index.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/concat/__pycache__/test_index.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18871f4ab0281b2c556cc79125b3b1e54d679bd4 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/concat/__pycache__/test_index.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/concat/__pycache__/test_invalid.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/concat/__pycache__/test_invalid.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c809da307d9c89441f6d9604b6531ab3afa6ef7c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/concat/__pycache__/test_invalid.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/concat/__pycache__/test_series.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/concat/__pycache__/test_series.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b4c194b6adc6780a125ed2db4df976df15e7645 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/concat/__pycache__/test_series.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/concat/__pycache__/test_sort.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/concat/__pycache__/test_sort.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..971744bf57222e3a1da496ee4450124ddf09cd3c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/concat/__pycache__/test_sort.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/merge/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/merge/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa789f9f930dd58ff3e7dc6694591b8fbd8928b9 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/merge/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/merge/__pycache__/test_join.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/merge/__pycache__/test_join.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6026d5b53a0722e40fdd390082511d17ada7b0e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/merge/__pycache__/test_join.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/merge/__pycache__/test_merge_asof.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/merge/__pycache__/test_merge_asof.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da2d40e21381b01d5bd163ab4498e7592ee8b0aa Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/merge/__pycache__/test_merge_asof.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/merge/__pycache__/test_merge_cross.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/merge/__pycache__/test_merge_cross.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b026b9c1a3c3f1a4921f9270dd886c6eec1b2328 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/merge/__pycache__/test_merge_cross.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/merge/__pycache__/test_merge_index_as_string.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/merge/__pycache__/test_merge_index_as_string.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e500979ac0eba72f2956d8c8dde802a64e27777 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/merge/__pycache__/test_merge_index_as_string.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/merge/__pycache__/test_merge_ordered.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/merge/__pycache__/test_merge_ordered.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eaa3282a50f1c7b81cfeef2fc77f6a2147e1a0a9 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/merge/__pycache__/test_merge_ordered.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/merge/__pycache__/test_multi.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/merge/__pycache__/test_multi.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..055b24ea9bc9c3f1840ef05f1190cbe45c5d2771 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/merge/__pycache__/test_multi.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/merge/test_merge_asof.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/merge/test_merge_asof.py new file mode 100644 index 0000000000000000000000000000000000000000..a2e22ea73fd86ef1795298d99404bfdb04be81f2 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/merge/test_merge_asof.py @@ -0,0 +1,3657 @@ +import datetime + +import numpy as np +import pytest +import pytz + +import pandas.util._test_decorators as td + +import pandas as pd +from pandas import ( + Index, + Timedelta, + merge_asof, + option_context, + to_datetime, +) +import pandas._testing as tm +from pandas.core.reshape.merge import MergeError + + +@pytest.fixture(params=["s", "ms", "us", "ns"]) +def unit(request): + """ + Resolution for datetimelike dtypes. + """ + return request.param + + +class TestAsOfMerge: + def prep_data(self, df, dedupe=False): + if dedupe: + df = df.drop_duplicates(["time", "ticker"], keep="last").reset_index( + drop=True + ) + df.time = to_datetime(df.time) + return df + + @pytest.fixture + def trades(self): + df = pd.DataFrame( + [ + ["20160525 13:30:00.023", "MSFT", "51.9500", "75", "NASDAQ"], + ["20160525 13:30:00.038", "MSFT", "51.9500", "155", "NASDAQ"], + ["20160525 13:30:00.048", "GOOG", "720.7700", "100", "NASDAQ"], + ["20160525 13:30:00.048", "GOOG", "720.9200", "100", "NASDAQ"], + ["20160525 13:30:00.048", "GOOG", "720.9300", "200", "NASDAQ"], + ["20160525 13:30:00.048", "GOOG", "720.9300", "300", "NASDAQ"], + ["20160525 13:30:00.048", "GOOG", "720.9300", "600", "NASDAQ"], + ["20160525 13:30:00.048", "GOOG", "720.9300", "44", "NASDAQ"], + ["20160525 13:30:00.074", "AAPL", "98.6700", "478343", "NASDAQ"], + ["20160525 13:30:00.075", "AAPL", "98.6700", "478343", "NASDAQ"], + ["20160525 13:30:00.075", "AAPL", "98.6600", "6", "NASDAQ"], + ["20160525 13:30:00.075", "AAPL", "98.6500", "30", "NASDAQ"], + ["20160525 13:30:00.075", "AAPL", "98.6500", "75", "NASDAQ"], + ["20160525 13:30:00.075", "AAPL", "98.6500", "20", "NASDAQ"], + ["20160525 13:30:00.075", "AAPL", "98.6500", "35", "NASDAQ"], + ["20160525 13:30:00.075", "AAPL", "98.6500", "10", "NASDAQ"], + ["20160525 13:30:00.075", "AAPL", "98.5500", "6", "ARCA"], + ["20160525 13:30:00.075", "AAPL", "98.5500", "6", "ARCA"], + ["20160525 13:30:00.076", "AAPL", "98.5600", "1000", "ARCA"], + ["20160525 13:30:00.076", "AAPL", "98.5600", "200", "ARCA"], + ["20160525 13:30:00.076", "AAPL", "98.5600", "300", "ARCA"], + ["20160525 13:30:00.076", "AAPL", "98.5600", "400", "ARCA"], + ["20160525 13:30:00.076", "AAPL", "98.5600", "600", "ARCA"], + ["20160525 13:30:00.076", "AAPL", "98.5600", "200", "ARCA"], + ["20160525 13:30:00.078", "MSFT", "51.9500", "783", "NASDAQ"], + ["20160525 13:30:00.078", "MSFT", "51.9500", "100", "NASDAQ"], + ["20160525 13:30:00.078", "MSFT", "51.9500", "100", "NASDAQ"], + ], + columns="time,ticker,price,quantity,marketCenter".split(","), + ) + df["price"] = df["price"].astype("float64") + df["quantity"] = df["quantity"].astype("int64") + return self.prep_data(df) + + @pytest.fixture + def quotes(self): + df = pd.DataFrame( + [ + ["20160525 13:30:00.023", "GOOG", "720.50", "720.93"], + ["20160525 13:30:00.023", "MSFT", "51.95", "51.95"], + ["20160525 13:30:00.041", "MSFT", "51.95", "51.95"], + ["20160525 13:30:00.048", "GOOG", "720.50", "720.93"], + ["20160525 13:30:00.048", "GOOG", "720.50", "720.93"], + ["20160525 13:30:00.048", "GOOG", "720.50", "720.93"], + ["20160525 13:30:00.048", "GOOG", "720.50", "720.93"], + ["20160525 13:30:00.072", "GOOG", "720.50", "720.88"], + ["20160525 13:30:00.075", "AAPL", "98.55", "98.56"], + ["20160525 13:30:00.076", "AAPL", "98.55", "98.56"], + ["20160525 13:30:00.076", "AAPL", "98.55", "98.56"], + ["20160525 13:30:00.076", "AAPL", "98.55", "98.56"], + ["20160525 13:30:00.078", "MSFT", "51.95", "51.95"], + ["20160525 13:30:00.078", "MSFT", "51.95", "51.95"], + ["20160525 13:30:00.078", "MSFT", "51.95", "51.95"], + ["20160525 13:30:00.078", "MSFT", "51.92", "51.95"], + ], + columns="time,ticker,bid,ask".split(","), + ) + df["bid"] = df["bid"].astype("float64") + df["ask"] = df["ask"].astype("float64") + return self.prep_data(df, dedupe=True) + + @pytest.fixture + def asof(self): + df = pd.DataFrame( + [ + [ + "20160525 13:30:00.023", + "MSFT", + "51.95", + "75", + "NASDAQ", + "51.95", + "51.95", + ], + [ + "20160525 13:30:00.038", + "MSFT", + "51.95", + "155", + "NASDAQ", + "51.95", + "51.95", + ], + [ + "20160525 13:30:00.048", + "GOOG", + "720.77", + "100", + "NASDAQ", + "720.5", + "720.93", + ], + [ + "20160525 13:30:00.048", + "GOOG", + "720.92", + "100", + "NASDAQ", + "720.5", + "720.93", + ], + [ + "20160525 13:30:00.048", + "GOOG", + "720.93", + "200", + "NASDAQ", + "720.5", + "720.93", + ], + [ + "20160525 13:30:00.048", + "GOOG", + "720.93", + "300", + "NASDAQ", + "720.5", + "720.93", + ], + [ + "20160525 13:30:00.048", + "GOOG", + "720.93", + "600", + "NASDAQ", + "720.5", + "720.93", + ], + [ + "20160525 13:30:00.048", + "GOOG", + "720.93", + "44", + "NASDAQ", + "720.5", + "720.93", + ], + [ + "20160525 13:30:00.074", + "AAPL", + "98.67", + "478343", + "NASDAQ", + np.nan, + np.nan, + ], + [ + "20160525 13:30:00.075", + "AAPL", + "98.67", + "478343", + "NASDAQ", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.075", + "AAPL", + "98.66", + "6", + "NASDAQ", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.075", + "AAPL", + "98.65", + "30", + "NASDAQ", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.075", + "AAPL", + "98.65", + "75", + "NASDAQ", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.075", + "AAPL", + "98.65", + "20", + "NASDAQ", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.075", + "AAPL", + "98.65", + "35", + "NASDAQ", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.075", + "AAPL", + "98.65", + "10", + "NASDAQ", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.075", + "AAPL", + "98.55", + "6", + "ARCA", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.075", + "AAPL", + "98.55", + "6", + "ARCA", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.076", + "AAPL", + "98.56", + "1000", + "ARCA", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.076", + "AAPL", + "98.56", + "200", + "ARCA", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.076", + "AAPL", + "98.56", + "300", + "ARCA", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.076", + "AAPL", + "98.56", + "400", + "ARCA", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.076", + "AAPL", + "98.56", + "600", + "ARCA", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.076", + "AAPL", + "98.56", + "200", + "ARCA", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.078", + "MSFT", + "51.95", + "783", + "NASDAQ", + "51.92", + "51.95", + ], + [ + "20160525 13:30:00.078", + "MSFT", + "51.95", + "100", + "NASDAQ", + "51.92", + "51.95", + ], + [ + "20160525 13:30:00.078", + "MSFT", + "51.95", + "100", + "NASDAQ", + "51.92", + "51.95", + ], + ], + columns="time,ticker,price,quantity,marketCenter,bid,ask".split(","), + ) + df["price"] = df["price"].astype("float64") + df["quantity"] = df["quantity"].astype("int64") + df["bid"] = df["bid"].astype("float64") + df["ask"] = df["ask"].astype("float64") + return self.prep_data(df) + + @pytest.fixture + def tolerance(self): + df = pd.DataFrame( + [ + [ + "20160525 13:30:00.023", + "MSFT", + "51.95", + "75", + "NASDAQ", + "51.95", + "51.95", + ], + [ + "20160525 13:30:00.038", + "MSFT", + "51.95", + "155", + "NASDAQ", + "51.95", + "51.95", + ], + [ + "20160525 13:30:00.048", + "GOOG", + "720.77", + "100", + "NASDAQ", + "720.5", + "720.93", + ], + [ + "20160525 13:30:00.048", + "GOOG", + "720.92", + "100", + "NASDAQ", + "720.5", + "720.93", + ], + [ + "20160525 13:30:00.048", + "GOOG", + "720.93", + "200", + "NASDAQ", + "720.5", + "720.93", + ], + [ + "20160525 13:30:00.048", + "GOOG", + "720.93", + "300", + "NASDAQ", + "720.5", + "720.93", + ], + [ + "20160525 13:30:00.048", + "GOOG", + "720.93", + "600", + "NASDAQ", + "720.5", + "720.93", + ], + [ + "20160525 13:30:00.048", + "GOOG", + "720.93", + "44", + "NASDAQ", + "720.5", + "720.93", + ], + [ + "20160525 13:30:00.074", + "AAPL", + "98.67", + "478343", + "NASDAQ", + np.nan, + np.nan, + ], + [ + "20160525 13:30:00.075", + "AAPL", + "98.67", + "478343", + "NASDAQ", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.075", + "AAPL", + "98.66", + "6", + "NASDAQ", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.075", + "AAPL", + "98.65", + "30", + "NASDAQ", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.075", + "AAPL", + "98.65", + "75", + "NASDAQ", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.075", + "AAPL", + "98.65", + "20", + "NASDAQ", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.075", + "AAPL", + "98.65", + "35", + "NASDAQ", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.075", + "AAPL", + "98.65", + "10", + "NASDAQ", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.075", + "AAPL", + "98.55", + "6", + "ARCA", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.075", + "AAPL", + "98.55", + "6", + "ARCA", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.076", + "AAPL", + "98.56", + "1000", + "ARCA", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.076", + "AAPL", + "98.56", + "200", + "ARCA", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.076", + "AAPL", + "98.56", + "300", + "ARCA", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.076", + "AAPL", + "98.56", + "400", + "ARCA", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.076", + "AAPL", + "98.56", + "600", + "ARCA", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.076", + "AAPL", + "98.56", + "200", + "ARCA", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.078", + "MSFT", + "51.95", + "783", + "NASDAQ", + "51.92", + "51.95", + ], + [ + "20160525 13:30:00.078", + "MSFT", + "51.95", + "100", + "NASDAQ", + "51.92", + "51.95", + ], + [ + "20160525 13:30:00.078", + "MSFT", + "51.95", + "100", + "NASDAQ", + "51.92", + "51.95", + ], + ], + columns="time,ticker,price,quantity,marketCenter,bid,ask".split(","), + ) + df["price"] = df["price"].astype("float64") + df["quantity"] = df["quantity"].astype("int64") + df["bid"] = df["bid"].astype("float64") + df["ask"] = df["ask"].astype("float64") + return self.prep_data(df) + + @pytest.fixture + def allow_exact_matches(self, datapath): + df = pd.DataFrame( + [ + [ + "20160525 13:30:00.023", + "MSFT", + "51.95", + "75", + "NASDAQ", + np.nan, + np.nan, + ], + [ + "20160525 13:30:00.038", + "MSFT", + "51.95", + "155", + "NASDAQ", + "51.95", + "51.95", + ], + [ + "20160525 13:30:00.048", + "GOOG", + "720.77", + "100", + "NASDAQ", + "720.5", + "720.93", + ], + [ + "20160525 13:30:00.048", + "GOOG", + "720.92", + "100", + "NASDAQ", + "720.5", + "720.93", + ], + [ + "20160525 13:30:00.048", + "GOOG", + "720.93", + "200", + "NASDAQ", + "720.5", + "720.93", + ], + [ + "20160525 13:30:00.048", + "GOOG", + "720.93", + "300", + "NASDAQ", + "720.5", + "720.93", + ], + [ + "20160525 13:30:00.048", + "GOOG", + "720.93", + "600", + "NASDAQ", + "720.5", + "720.93", + ], + [ + "20160525 13:30:00.048", + "GOOG", + "720.93", + "44", + "NASDAQ", + "720.5", + "720.93", + ], + [ + "20160525 13:30:00.074", + "AAPL", + "98.67", + "478343", + "NASDAQ", + np.nan, + np.nan, + ], + [ + "20160525 13:30:00.075", + "AAPL", + "98.67", + "478343", + "NASDAQ", + np.nan, + np.nan, + ], + [ + "20160525 13:30:00.075", + "AAPL", + "98.66", + "6", + "NASDAQ", + np.nan, + np.nan, + ], + [ + "20160525 13:30:00.075", + "AAPL", + "98.65", + "30", + "NASDAQ", + np.nan, + np.nan, + ], + [ + "20160525 13:30:00.075", + "AAPL", + "98.65", + "75", + "NASDAQ", + np.nan, + np.nan, + ], + [ + "20160525 13:30:00.075", + "AAPL", + "98.65", + "20", + "NASDAQ", + np.nan, + np.nan, + ], + [ + "20160525 13:30:00.075", + "AAPL", + "98.65", + "35", + "NASDAQ", + np.nan, + np.nan, + ], + [ + "20160525 13:30:00.075", + "AAPL", + "98.65", + "10", + "NASDAQ", + np.nan, + np.nan, + ], + ["20160525 13:30:00.075", "AAPL", "98.55", "6", "ARCA", np.nan, np.nan], + ["20160525 13:30:00.075", "AAPL", "98.55", "6", "ARCA", np.nan, np.nan], + [ + "20160525 13:30:00.076", + "AAPL", + "98.56", + "1000", + "ARCA", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.076", + "AAPL", + "98.56", + "200", + "ARCA", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.076", + "AAPL", + "98.56", + "300", + "ARCA", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.076", + "AAPL", + "98.56", + "400", + "ARCA", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.076", + "AAPL", + "98.56", + "600", + "ARCA", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.076", + "AAPL", + "98.56", + "200", + "ARCA", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.078", + "MSFT", + "51.95", + "783", + "NASDAQ", + "51.95", + "51.95", + ], + [ + "20160525 13:30:00.078", + "MSFT", + "51.95", + "100", + "NASDAQ", + "51.95", + "51.95", + ], + [ + "20160525 13:30:00.078", + "MSFT", + "51.95", + "100", + "NASDAQ", + "51.95", + "51.95", + ], + ], + columns="time,ticker,price,quantity,marketCenter,bid,ask".split(","), + ) + df["price"] = df["price"].astype("float64") + df["quantity"] = df["quantity"].astype("int64") + df["bid"] = df["bid"].astype("float64") + df["ask"] = df["ask"].astype("float64") + return self.prep_data(df) + + @pytest.fixture + def allow_exact_matches_and_tolerance(self): + df = pd.DataFrame( + [ + [ + "20160525 13:30:00.023", + "MSFT", + "51.95", + "75", + "NASDAQ", + np.nan, + np.nan, + ], + [ + "20160525 13:30:00.038", + "MSFT", + "51.95", + "155", + "NASDAQ", + "51.95", + "51.95", + ], + [ + "20160525 13:30:00.048", + "GOOG", + "720.77", + "100", + "NASDAQ", + "720.5", + "720.93", + ], + [ + "20160525 13:30:00.048", + "GOOG", + "720.92", + "100", + "NASDAQ", + "720.5", + "720.93", + ], + [ + "20160525 13:30:00.048", + "GOOG", + "720.93", + "200", + "NASDAQ", + "720.5", + "720.93", + ], + [ + "20160525 13:30:00.048", + "GOOG", + "720.93", + "300", + "NASDAQ", + "720.5", + "720.93", + ], + [ + "20160525 13:30:00.048", + "GOOG", + "720.93", + "600", + "NASDAQ", + "720.5", + "720.93", + ], + [ + "20160525 13:30:00.048", + "GOOG", + "720.93", + "44", + "NASDAQ", + "720.5", + "720.93", + ], + [ + "20160525 13:30:00.074", + "AAPL", + "98.67", + "478343", + "NASDAQ", + np.nan, + np.nan, + ], + [ + "20160525 13:30:00.075", + "AAPL", + "98.67", + "478343", + "NASDAQ", + np.nan, + np.nan, + ], + [ + "20160525 13:30:00.075", + "AAPL", + "98.66", + "6", + "NASDAQ", + np.nan, + np.nan, + ], + [ + "20160525 13:30:00.075", + "AAPL", + "98.65", + "30", + "NASDAQ", + np.nan, + np.nan, + ], + [ + "20160525 13:30:00.075", + "AAPL", + "98.65", + "75", + "NASDAQ", + np.nan, + np.nan, + ], + [ + "20160525 13:30:00.075", + "AAPL", + "98.65", + "20", + "NASDAQ", + np.nan, + np.nan, + ], + [ + "20160525 13:30:00.075", + "AAPL", + "98.65", + "35", + "NASDAQ", + np.nan, + np.nan, + ], + [ + "20160525 13:30:00.075", + "AAPL", + "98.65", + "10", + "NASDAQ", + np.nan, + np.nan, + ], + ["20160525 13:30:00.075", "AAPL", "98.55", "6", "ARCA", np.nan, np.nan], + ["20160525 13:30:00.075", "AAPL", "98.55", "6", "ARCA", np.nan, np.nan], + [ + "20160525 13:30:00.076", + "AAPL", + "98.56", + "1000", + "ARCA", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.076", + "AAPL", + "98.56", + "200", + "ARCA", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.076", + "AAPL", + "98.56", + "300", + "ARCA", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.076", + "AAPL", + "98.56", + "400", + "ARCA", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.076", + "AAPL", + "98.56", + "600", + "ARCA", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.076", + "AAPL", + "98.56", + "200", + "ARCA", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.078", + "MSFT", + "51.95", + "783", + "NASDAQ", + "51.95", + "51.95", + ], + [ + "20160525 13:30:00.078", + "MSFT", + "51.95", + "100", + "NASDAQ", + "51.95", + "51.95", + ], + [ + "20160525 13:30:00.078", + "MSFT", + "51.95", + "100", + "NASDAQ", + "51.95", + "51.95", + ], + ], + columns="time,ticker,price,quantity,marketCenter,bid,ask".split(","), + ) + df["price"] = df["price"].astype("float64") + df["quantity"] = df["quantity"].astype("int64") + df["bid"] = df["bid"].astype("float64") + df["ask"] = df["ask"].astype("float64") + return self.prep_data(df) + + def test_examples1(self): + """doc-string examples""" + left = pd.DataFrame({"a": [1, 5, 10], "left_val": ["a", "b", "c"]}) + right = pd.DataFrame({"a": [1, 2, 3, 6, 7], "right_val": [1, 2, 3, 6, 7]}) + + expected = pd.DataFrame( + {"a": [1, 5, 10], "left_val": ["a", "b", "c"], "right_val": [1, 3, 7]} + ) + + result = merge_asof(left, right, on="a") + tm.assert_frame_equal(result, expected) + + def test_examples2(self, unit): + """doc-string examples""" + if unit == "s": + pytest.skip( + "This test is invalid for unit='s' because that would " + "round the trades['time']]" + ) + trades = pd.DataFrame( + { + "time": to_datetime( + [ + "20160525 13:30:00.023", + "20160525 13:30:00.038", + "20160525 13:30:00.048", + "20160525 13:30:00.048", + "20160525 13:30:00.048", + ] + ).astype(f"M8[{unit}]"), + "ticker": ["MSFT", "MSFT", "GOOG", "GOOG", "AAPL"], + "price": [51.95, 51.95, 720.77, 720.92, 98.00], + "quantity": [75, 155, 100, 100, 100], + }, + columns=["time", "ticker", "price", "quantity"], + ) + + quotes = pd.DataFrame( + { + "time": to_datetime( + [ + "20160525 13:30:00.023", + "20160525 13:30:00.023", + "20160525 13:30:00.030", + "20160525 13:30:00.041", + "20160525 13:30:00.048", + "20160525 13:30:00.049", + "20160525 13:30:00.072", + "20160525 13:30:00.075", + ] + ).astype(f"M8[{unit}]"), + "ticker": [ + "GOOG", + "MSFT", + "MSFT", + "MSFT", + "GOOG", + "AAPL", + "GOOG", + "MSFT", + ], + "bid": [720.50, 51.95, 51.97, 51.99, 720.50, 97.99, 720.50, 52.01], + "ask": [720.93, 51.96, 51.98, 52.00, 720.93, 98.01, 720.88, 52.03], + }, + columns=["time", "ticker", "bid", "ask"], + ) + + merge_asof(trades, quotes, on="time", by="ticker") + + merge_asof(trades, quotes, on="time", by="ticker", tolerance=Timedelta("2ms")) + + expected = pd.DataFrame( + { + "time": to_datetime( + [ + "20160525 13:30:00.023", + "20160525 13:30:00.038", + "20160525 13:30:00.048", + "20160525 13:30:00.048", + "20160525 13:30:00.048", + ] + ).astype(f"M8[{unit}]"), + "ticker": ["MSFT", "MSFT", "GOOG", "GOOG", "AAPL"], + "price": [51.95, 51.95, 720.77, 720.92, 98.00], + "quantity": [75, 155, 100, 100, 100], + "bid": [np.nan, 51.97, np.nan, np.nan, np.nan], + "ask": [np.nan, 51.98, np.nan, np.nan, np.nan], + }, + columns=["time", "ticker", "price", "quantity", "bid", "ask"], + ) + + result = merge_asof( + trades, + quotes, + on="time", + by="ticker", + tolerance=Timedelta("10ms"), + allow_exact_matches=False, + ) + tm.assert_frame_equal(result, expected) + + def test_examples3(self): + """doc-string examples""" + # GH14887 + + left = pd.DataFrame({"a": [1, 5, 10], "left_val": ["a", "b", "c"]}) + right = pd.DataFrame({"a": [1, 2, 3, 6, 7], "right_val": [1, 2, 3, 6, 7]}) + + expected = pd.DataFrame( + {"a": [1, 5, 10], "left_val": ["a", "b", "c"], "right_val": [1, 6, np.nan]} + ) + + result = merge_asof(left, right, on="a", direction="forward") + tm.assert_frame_equal(result, expected) + + def test_examples4(self): + """doc-string examples""" + # GH14887 + + left = pd.DataFrame({"a": [1, 5, 10], "left_val": ["a", "b", "c"]}) + right = pd.DataFrame({"a": [1, 2, 3, 6, 7], "right_val": [1, 2, 3, 6, 7]}) + + expected = pd.DataFrame( + {"a": [1, 5, 10], "left_val": ["a", "b", "c"], "right_val": [1, 6, 7]} + ) + + result = merge_asof(left, right, on="a", direction="nearest") + tm.assert_frame_equal(result, expected) + + def test_basic(self, trades, asof, quotes): + expected = asof + + result = merge_asof(trades, quotes, on="time", by="ticker") + tm.assert_frame_equal(result, expected) + + def test_basic_categorical(self, trades, asof, quotes): + expected = asof + trades.ticker = trades.ticker.astype("category") + quotes.ticker = quotes.ticker.astype("category") + expected.ticker = expected.ticker.astype("category") + + result = merge_asof(trades, quotes, on="time", by="ticker") + tm.assert_frame_equal(result, expected) + + def test_basic_left_index(self, trades, asof, quotes): + # GH14253 + expected = asof + trades = trades.set_index("time") + + result = merge_asof( + trades, quotes, left_index=True, right_on="time", by="ticker" + ) + # left-only index uses right"s index, oddly + expected.index = result.index + # time column appears after left"s columns + expected = expected[result.columns] + tm.assert_frame_equal(result, expected) + + def test_basic_right_index(self, trades, asof, quotes): + expected = asof + quotes = quotes.set_index("time") + + result = merge_asof( + trades, quotes, left_on="time", right_index=True, by="ticker" + ) + tm.assert_frame_equal(result, expected) + + def test_basic_left_index_right_index(self, trades, asof, quotes): + expected = asof.set_index("time") + trades = trades.set_index("time") + quotes = quotes.set_index("time") + + result = merge_asof( + trades, quotes, left_index=True, right_index=True, by="ticker" + ) + tm.assert_frame_equal(result, expected) + + def test_multi_index_left(self, trades, quotes): + # MultiIndex is prohibited + trades = trades.set_index(["time", "price"]) + quotes = quotes.set_index("time") + with pytest.raises(MergeError, match="left can only have one index"): + merge_asof(trades, quotes, left_index=True, right_index=True) + + def test_multi_index_right(self, trades, quotes): + # MultiIndex is prohibited + trades = trades.set_index("time") + quotes = quotes.set_index(["time", "bid"]) + with pytest.raises(MergeError, match="right can only have one index"): + merge_asof(trades, quotes, left_index=True, right_index=True) + + def test_on_and_index_left_on(self, trades, quotes): + # "on" parameter and index together is prohibited + trades = trades.set_index("time") + quotes = quotes.set_index("time") + msg = 'Can only pass argument "left_on" OR "left_index" not both.' + with pytest.raises(MergeError, match=msg): + merge_asof( + trades, quotes, left_on="price", left_index=True, right_index=True + ) + + def test_on_and_index_right_on(self, trades, quotes): + trades = trades.set_index("time") + quotes = quotes.set_index("time") + msg = 'Can only pass argument "right_on" OR "right_index" not both.' + with pytest.raises(MergeError, match=msg): + merge_asof( + trades, quotes, right_on="bid", left_index=True, right_index=True + ) + + def test_basic_left_by_right_by(self, trades, asof, quotes): + # GH14253 + expected = asof + + result = merge_asof( + trades, quotes, on="time", left_by="ticker", right_by="ticker" + ) + tm.assert_frame_equal(result, expected) + + def test_missing_right_by(self, trades, asof, quotes): + expected = asof + + q = quotes[quotes.ticker != "MSFT"] + result = merge_asof(trades, q, on="time", by="ticker") + expected.loc[expected.ticker == "MSFT", ["bid", "ask"]] = np.nan + tm.assert_frame_equal(result, expected) + + def test_multiby(self): + # GH13936 + trades = pd.DataFrame( + { + "time": to_datetime( + [ + "20160525 13:30:00.023", + "20160525 13:30:00.023", + "20160525 13:30:00.046", + "20160525 13:30:00.048", + "20160525 13:30:00.050", + ] + ), + "ticker": ["MSFT", "MSFT", "GOOG", "GOOG", "AAPL"], + "exch": ["ARCA", "NSDQ", "NSDQ", "BATS", "NSDQ"], + "price": [51.95, 51.95, 720.77, 720.92, 98.00], + "quantity": [75, 155, 100, 100, 100], + }, + columns=["time", "ticker", "exch", "price", "quantity"], + ) + + quotes = pd.DataFrame( + { + "time": to_datetime( + [ + "20160525 13:30:00.023", + "20160525 13:30:00.023", + "20160525 13:30:00.030", + "20160525 13:30:00.041", + "20160525 13:30:00.045", + "20160525 13:30:00.049", + ] + ), + "ticker": ["GOOG", "MSFT", "MSFT", "MSFT", "GOOG", "AAPL"], + "exch": ["BATS", "NSDQ", "ARCA", "ARCA", "NSDQ", "ARCA"], + "bid": [720.51, 51.95, 51.97, 51.99, 720.50, 97.99], + "ask": [720.92, 51.96, 51.98, 52.00, 720.93, 98.01], + }, + columns=["time", "ticker", "exch", "bid", "ask"], + ) + + expected = pd.DataFrame( + { + "time": to_datetime( + [ + "20160525 13:30:00.023", + "20160525 13:30:00.023", + "20160525 13:30:00.046", + "20160525 13:30:00.048", + "20160525 13:30:00.050", + ] + ), + "ticker": ["MSFT", "MSFT", "GOOG", "GOOG", "AAPL"], + "exch": ["ARCA", "NSDQ", "NSDQ", "BATS", "NSDQ"], + "price": [51.95, 51.95, 720.77, 720.92, 98.00], + "quantity": [75, 155, 100, 100, 100], + "bid": [np.nan, 51.95, 720.50, 720.51, np.nan], + "ask": [np.nan, 51.96, 720.93, 720.92, np.nan], + }, + columns=["time", "ticker", "exch", "price", "quantity", "bid", "ask"], + ) + + result = merge_asof(trades, quotes, on="time", by=["ticker", "exch"]) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("dtype", ["object", "string"]) + def test_multiby_heterogeneous_types(self, dtype): + # GH13936 + trades = pd.DataFrame( + { + "time": to_datetime( + [ + "20160525 13:30:00.023", + "20160525 13:30:00.023", + "20160525 13:30:00.046", + "20160525 13:30:00.048", + "20160525 13:30:00.050", + ] + ), + "ticker": [0, 0, 1, 1, 2], + "exch": ["ARCA", "NSDQ", "NSDQ", "BATS", "NSDQ"], + "price": [51.95, 51.95, 720.77, 720.92, 98.00], + "quantity": [75, 155, 100, 100, 100], + }, + columns=["time", "ticker", "exch", "price", "quantity"], + ) + trades = trades.astype({"ticker": dtype, "exch": dtype}) + + quotes = pd.DataFrame( + { + "time": to_datetime( + [ + "20160525 13:30:00.023", + "20160525 13:30:00.023", + "20160525 13:30:00.030", + "20160525 13:30:00.041", + "20160525 13:30:00.045", + "20160525 13:30:00.049", + ] + ), + "ticker": [1, 0, 0, 0, 1, 2], + "exch": ["BATS", "NSDQ", "ARCA", "ARCA", "NSDQ", "ARCA"], + "bid": [720.51, 51.95, 51.97, 51.99, 720.50, 97.99], + "ask": [720.92, 51.96, 51.98, 52.00, 720.93, 98.01], + }, + columns=["time", "ticker", "exch", "bid", "ask"], + ) + quotes = quotes.astype({"ticker": dtype, "exch": dtype}) + + expected = pd.DataFrame( + { + "time": to_datetime( + [ + "20160525 13:30:00.023", + "20160525 13:30:00.023", + "20160525 13:30:00.046", + "20160525 13:30:00.048", + "20160525 13:30:00.050", + ] + ), + "ticker": [0, 0, 1, 1, 2], + "exch": ["ARCA", "NSDQ", "NSDQ", "BATS", "NSDQ"], + "price": [51.95, 51.95, 720.77, 720.92, 98.00], + "quantity": [75, 155, 100, 100, 100], + "bid": [np.nan, 51.95, 720.50, 720.51, np.nan], + "ask": [np.nan, 51.96, 720.93, 720.92, np.nan], + }, + columns=["time", "ticker", "exch", "price", "quantity", "bid", "ask"], + ) + expected = expected.astype({"ticker": dtype, "exch": dtype}) + + result = merge_asof(trades, quotes, on="time", by=["ticker", "exch"]) + tm.assert_frame_equal(result, expected) + + def test_mismatched_index_dtype(self): + # similar to test_multiby_indexed, but we change the dtype on left.index + left = pd.DataFrame( + [ + [to_datetime("20160602"), 1, "a"], + [to_datetime("20160602"), 2, "a"], + [to_datetime("20160603"), 1, "b"], + [to_datetime("20160603"), 2, "b"], + ], + columns=["time", "k1", "k2"], + ).set_index("time") + # different dtype for the index + left.index = left.index - pd.Timestamp(0) + + right = pd.DataFrame( + [ + [to_datetime("20160502"), 1, "a", 1.0], + [to_datetime("20160502"), 2, "a", 2.0], + [to_datetime("20160503"), 1, "b", 3.0], + [to_datetime("20160503"), 2, "b", 4.0], + ], + columns=["time", "k1", "k2", "value"], + ).set_index("time") + + msg = "incompatible merge keys" + with pytest.raises(MergeError, match=msg): + merge_asof(left, right, left_index=True, right_index=True, by=["k1", "k2"]) + + def test_multiby_indexed(self): + # GH15676 + left = pd.DataFrame( + [ + [to_datetime("20160602"), 1, "a"], + [to_datetime("20160602"), 2, "a"], + [to_datetime("20160603"), 1, "b"], + [to_datetime("20160603"), 2, "b"], + ], + columns=["time", "k1", "k2"], + ).set_index("time") + + right = pd.DataFrame( + [ + [to_datetime("20160502"), 1, "a", 1.0], + [to_datetime("20160502"), 2, "a", 2.0], + [to_datetime("20160503"), 1, "b", 3.0], + [to_datetime("20160503"), 2, "b", 4.0], + ], + columns=["time", "k1", "k2", "value"], + ).set_index("time") + + expected = pd.DataFrame( + [ + [to_datetime("20160602"), 1, "a", 1.0], + [to_datetime("20160602"), 2, "a", 2.0], + [to_datetime("20160603"), 1, "b", 3.0], + [to_datetime("20160603"), 2, "b", 4.0], + ], + columns=["time", "k1", "k2", "value"], + ).set_index("time") + + result = merge_asof( + left, right, left_index=True, right_index=True, by=["k1", "k2"] + ) + + tm.assert_frame_equal(expected, result) + + with pytest.raises( + MergeError, match="left_by and right_by must be the same length" + ): + merge_asof( + left, + right, + left_index=True, + right_index=True, + left_by=["k1", "k2"], + right_by=["k1"], + ) + + def test_basic2(self, datapath): + expected = pd.DataFrame( + [ + [ + "20160525 13:30:00.023", + "MSFT", + "51.95", + "75", + "NASDAQ", + "51.95", + "51.95", + ], + [ + "20160525 13:30:00.038", + "MSFT", + "51.95", + "155", + "NASDAQ", + "51.95", + "51.95", + ], + [ + "20160525 13:30:00.048", + "GOOG", + "720.77", + "100", + "NASDAQ", + "720.5", + "720.93", + ], + [ + "20160525 13:30:00.048", + "GOOG", + "720.92", + "100", + "NASDAQ", + "720.5", + "720.93", + ], + [ + "20160525 13:30:00.048", + "GOOG", + "720.93", + "200", + "NASDAQ", + "720.5", + "720.93", + ], + [ + "20160525 13:30:00.048", + "GOOG", + "720.93", + "300", + "NASDAQ", + "720.5", + "720.93", + ], + [ + "20160525 13:30:00.048", + "GOOG", + "720.93", + "600", + "NASDAQ", + "720.5", + "720.93", + ], + [ + "20160525 13:30:00.048", + "GOOG", + "720.93", + "44", + "NASDAQ", + "720.5", + "720.93", + ], + [ + "20160525 13:30:00.074", + "AAPL", + "98.67", + "478343", + "NASDAQ", + np.nan, + np.nan, + ], + [ + "20160525 13:30:00.075", + "AAPL", + "98.67", + "478343", + "NASDAQ", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.075", + "AAPL", + "98.66", + "6", + "NASDAQ", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.075", + "AAPL", + "98.65", + "30", + "NASDAQ", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.075", + "AAPL", + "98.65", + "75", + "NASDAQ", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.075", + "AAPL", + "98.65", + "20", + "NASDAQ", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.075", + "AAPL", + "98.65", + "35", + "NASDAQ", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.075", + "AAPL", + "98.65", + "10", + "NASDAQ", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.075", + "AAPL", + "98.55", + "6", + "ARCA", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.075", + "AAPL", + "98.55", + "6", + "ARCA", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.076", + "AAPL", + "98.56", + "1000", + "ARCA", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.076", + "AAPL", + "98.56", + "200", + "ARCA", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.076", + "AAPL", + "98.56", + "300", + "ARCA", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.076", + "AAPL", + "98.56", + "400", + "ARCA", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.076", + "AAPL", + "98.56", + "600", + "ARCA", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.076", + "AAPL", + "98.56", + "200", + "ARCA", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.078", + "MSFT", + "51.95", + "783", + "NASDAQ", + "51.92", + "51.95", + ], + [ + "20160525 13:30:00.078", + "MSFT", + "51.95", + "100", + "NASDAQ", + "51.92", + "51.95", + ], + [ + "20160525 13:30:00.078", + "MSFT", + "51.95", + "100", + "NASDAQ", + "51.92", + "51.95", + ], + [ + "20160525 13:30:00.084", + "AAPL", + "98.64", + "40", + "NASDAQ", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.084", + "AAPL", + "98.55", + "149", + "EDGX", + "98.55", + "98.56", + ], + [ + "20160525 13:30:00.086", + "AAPL", + "98.56", + "500", + "ARCA", + "98.55", + "98.63", + ], + [ + "20160525 13:30:00.104", + "AAPL", + "98.63", + "647", + "EDGX", + "98.62", + "98.63", + ], + [ + "20160525 13:30:00.104", + "AAPL", + "98.63", + "300", + "EDGX", + "98.62", + "98.63", + ], + [ + "20160525 13:30:00.104", + "AAPL", + "98.63", + "50", + "NASDAQ", + "98.62", + "98.63", + ], + [ + "20160525 13:30:00.104", + "AAPL", + "98.63", + "50", + "NASDAQ", + "98.62", + "98.63", + ], + [ + "20160525 13:30:00.104", + "AAPL", + "98.63", + "70", + "NASDAQ", + "98.62", + "98.63", + ], + [ + "20160525 13:30:00.104", + "AAPL", + "98.63", + "70", + "NASDAQ", + "98.62", + "98.63", + ], + [ + "20160525 13:30:00.104", + "AAPL", + "98.63", + "1", + "NASDAQ", + "98.62", + "98.63", + ], + [ + "20160525 13:30:00.104", + "AAPL", + "98.63", + "62", + "NASDAQ", + "98.62", + "98.63", + ], + [ + "20160525 13:30:00.104", + "AAPL", + "98.63", + "10", + "NASDAQ", + "98.62", + "98.63", + ], + [ + "20160525 13:30:00.104", + "AAPL", + "98.63", + "100", + "ARCA", + "98.62", + "98.63", + ], + [ + "20160525 13:30:00.105", + "AAPL", + "98.63", + "100", + "ARCA", + "98.62", + "98.63", + ], + [ + "20160525 13:30:00.105", + "AAPL", + "98.63", + "700", + "ARCA", + "98.62", + "98.63", + ], + [ + "20160525 13:30:00.106", + "AAPL", + "98.63", + "61", + "EDGX", + "98.62", + "98.63", + ], + [ + "20160525 13:30:00.107", + "AAPL", + "98.63", + "100", + "ARCA", + "98.62", + "98.63", + ], + [ + "20160525 13:30:00.107", + "AAPL", + "98.63", + "53", + "ARCA", + "98.62", + "98.63", + ], + [ + "20160525 13:30:00.108", + "AAPL", + "98.63", + "100", + "ARCA", + "98.62", + "98.63", + ], + [ + "20160525 13:30:00.108", + "AAPL", + "98.63", + "839", + "ARCA", + "98.62", + "98.63", + ], + [ + "20160525 13:30:00.115", + "AAPL", + "98.63", + "5", + "EDGX", + "98.62", + "98.63", + ], + [ + "20160525 13:30:00.118", + "AAPL", + "98.63", + "295", + "EDGX", + "98.62", + "98.63", + ], + [ + "20160525 13:30:00.118", + "AAPL", + "98.63", + "5", + "EDGX", + "98.62", + "98.63", + ], + [ + "20160525 13:30:00.128", + "AAPL", + "98.63", + "100", + "NASDAQ", + "98.62", + "98.63", + ], + [ + "20160525 13:30:00.128", + "AAPL", + "98.63", + "100", + "NASDAQ", + "98.62", + "98.63", + ], + [ + "20160525 13:30:00.128", + "MSFT", + "51.92", + "100", + "ARCA", + "51.92", + "51.95", + ], + [ + "20160525 13:30:00.129", + "AAPL", + "98.62", + "100", + "NASDAQ", + "98.61", + "98.63", + ], + [ + "20160525 13:30:00.129", + "AAPL", + "98.62", + "10", + "NASDAQ", + "98.61", + "98.63", + ], + [ + "20160525 13:30:00.129", + "AAPL", + "98.62", + "59", + "NASDAQ", + "98.61", + "98.63", + ], + [ + "20160525 13:30:00.129", + "AAPL", + "98.62", + "31", + "NASDAQ", + "98.61", + "98.63", + ], + [ + "20160525 13:30:00.129", + "AAPL", + "98.62", + "69", + "NASDAQ", + "98.61", + "98.63", + ], + [ + "20160525 13:30:00.129", + "AAPL", + "98.62", + "12", + "NASDAQ", + "98.61", + "98.63", + ], + [ + "20160525 13:30:00.129", + "AAPL", + "98.62", + "12", + "EDGX", + "98.61", + "98.63", + ], + [ + "20160525 13:30:00.129", + "AAPL", + "98.62", + "100", + "ARCA", + "98.61", + "98.63", + ], + [ + "20160525 13:30:00.129", + "AAPL", + "98.62", + "100", + "ARCA", + "98.61", + "98.63", + ], + [ + "20160525 13:30:00.130", + "MSFT", + "51.95", + "317", + "ARCA", + "51.93", + "51.95", + ], + [ + "20160525 13:30:00.130", + "MSFT", + "51.95", + "283", + "ARCA", + "51.93", + "51.95", + ], + [ + "20160525 13:30:00.135", + "MSFT", + "51.93", + "100", + "EDGX", + "51.92", + "51.95", + ], + [ + "20160525 13:30:00.135", + "AAPL", + "98.62", + "100", + "ARCA", + "98.61", + "98.62", + ], + [ + "20160525 13:30:00.144", + "AAPL", + "98.62", + "12", + "NASDAQ", + "98.61", + "98.62", + ], + [ + "20160525 13:30:00.144", + "AAPL", + "98.62", + "88", + "NASDAQ", + "98.61", + "98.62", + ], + [ + "20160525 13:30:00.144", + "AAPL", + "98.62", + "162", + "NASDAQ", + "98.61", + "98.62", + ], + [ + "20160525 13:30:00.144", + "AAPL", + "98.61", + "100", + "BATS", + "98.61", + "98.62", + ], + [ + "20160525 13:30:00.144", + "AAPL", + "98.62", + "61", + "ARCA", + "98.61", + "98.62", + ], + [ + "20160525 13:30:00.144", + "AAPL", + "98.62", + "25", + "ARCA", + "98.61", + "98.62", + ], + [ + "20160525 13:30:00.144", + "AAPL", + "98.62", + "14", + "ARCA", + "98.61", + "98.62", + ], + [ + "20160525 13:30:00.145", + "AAPL", + "98.62", + "12", + "ARCA", + "98.6", + "98.63", + ], + [ + "20160525 13:30:00.145", + "AAPL", + "98.62", + "100", + "ARCA", + "98.6", + "98.63", + ], + [ + "20160525 13:30:00.145", + "AAPL", + "98.63", + "100", + "NASDAQ", + "98.6", + "98.63", + ], + [ + "20160525 13:30:00.145", + "AAPL", + "98.63", + "100", + "NASDAQ", + "98.6", + "98.63", + ], + ], + columns="time,ticker,price,quantity,marketCenter,bid,ask".split(","), + ) + expected["price"] = expected["price"].astype("float64") + expected["quantity"] = expected["quantity"].astype("int64") + expected["bid"] = expected["bid"].astype("float64") + expected["ask"] = expected["ask"].astype("float64") + expected = self.prep_data(expected) + + trades = pd.DataFrame( + [ + ["20160525 13:30:00.023", "MSFT", "51.9500", "75", "NASDAQ"], + ["20160525 13:30:00.038", "MSFT", "51.9500", "155", "NASDAQ"], + ["20160525 13:30:00.048", "GOOG", "720.7700", "100", "NASDAQ"], + ["20160525 13:30:00.048", "GOOG", "720.9200", "100", "NASDAQ"], + ["20160525 13:30:00.048", "GOOG", "720.9300", "200", "NASDAQ"], + ["20160525 13:30:00.048", "GOOG", "720.9300", "300", "NASDAQ"], + ["20160525 13:30:00.048", "GOOG", "720.9300", "600", "NASDAQ"], + ["20160525 13:30:00.048", "GOOG", "720.9300", "44", "NASDAQ"], + ["20160525 13:30:00.074", "AAPL", "98.6700", "478343", "NASDAQ"], + ["20160525 13:30:00.075", "AAPL", "98.6700", "478343", "NASDAQ"], + ["20160525 13:30:00.075", "AAPL", "98.6600", "6", "NASDAQ"], + ["20160525 13:30:00.075", "AAPL", "98.6500", "30", "NASDAQ"], + ["20160525 13:30:00.075", "AAPL", "98.6500", "75", "NASDAQ"], + ["20160525 13:30:00.075", "AAPL", "98.6500", "20", "NASDAQ"], + ["20160525 13:30:00.075", "AAPL", "98.6500", "35", "NASDAQ"], + ["20160525 13:30:00.075", "AAPL", "98.6500", "10", "NASDAQ"], + ["20160525 13:30:00.075", "AAPL", "98.5500", "6", "ARCA"], + ["20160525 13:30:00.075", "AAPL", "98.5500", "6", "ARCA"], + ["20160525 13:30:00.076", "AAPL", "98.5600", "1000", "ARCA"], + ["20160525 13:30:00.076", "AAPL", "98.5600", "200", "ARCA"], + ["20160525 13:30:00.076", "AAPL", "98.5600", "300", "ARCA"], + ["20160525 13:30:00.076", "AAPL", "98.5600", "400", "ARCA"], + ["20160525 13:30:00.076", "AAPL", "98.5600", "600", "ARCA"], + ["20160525 13:30:00.076", "AAPL", "98.5600", "200", "ARCA"], + ["20160525 13:30:00.078", "MSFT", "51.9500", "783", "NASDAQ"], + ["20160525 13:30:00.078", "MSFT", "51.9500", "100", "NASDAQ"], + ["20160525 13:30:00.078", "MSFT", "51.9500", "100", "NASDAQ"], + ["20160525 13:30:00.084", "AAPL", "98.6400", "40", "NASDAQ"], + ["20160525 13:30:00.084", "AAPL", "98.5500", "149", "EDGX"], + ["20160525 13:30:00.086", "AAPL", "98.5600", "500", "ARCA"], + ["20160525 13:30:00.104", "AAPL", "98.6300", "647", "EDGX"], + ["20160525 13:30:00.104", "AAPL", "98.6300", "300", "EDGX"], + ["20160525 13:30:00.104", "AAPL", "98.6300", "50", "NASDAQ"], + ["20160525 13:30:00.104", "AAPL", "98.6300", "50", "NASDAQ"], + ["20160525 13:30:00.104", "AAPL", "98.6300", "70", "NASDAQ"], + ["20160525 13:30:00.104", "AAPL", "98.6300", "70", "NASDAQ"], + ["20160525 13:30:00.104", "AAPL", "98.6300", "1", "NASDAQ"], + ["20160525 13:30:00.104", "AAPL", "98.6300", "62", "NASDAQ"], + ["20160525 13:30:00.104", "AAPL", "98.6300", "10", "NASDAQ"], + ["20160525 13:30:00.104", "AAPL", "98.6300", "100", "ARCA"], + ["20160525 13:30:00.105", "AAPL", "98.6300", "100", "ARCA"], + ["20160525 13:30:00.105", "AAPL", "98.6300", "700", "ARCA"], + ["20160525 13:30:00.106", "AAPL", "98.6300", "61", "EDGX"], + ["20160525 13:30:00.107", "AAPL", "98.6300", "100", "ARCA"], + ["20160525 13:30:00.107", "AAPL", "98.6300", "53", "ARCA"], + ["20160525 13:30:00.108", "AAPL", "98.6300", "100", "ARCA"], + ["20160525 13:30:00.108", "AAPL", "98.6300", "839", "ARCA"], + ["20160525 13:30:00.115", "AAPL", "98.6300", "5", "EDGX"], + ["20160525 13:30:00.118", "AAPL", "98.6300", "295", "EDGX"], + ["20160525 13:30:00.118", "AAPL", "98.6300", "5", "EDGX"], + ["20160525 13:30:00.128", "AAPL", "98.6300", "100", "NASDAQ"], + ["20160525 13:30:00.128", "AAPL", "98.6300", "100", "NASDAQ"], + ["20160525 13:30:00.128", "MSFT", "51.9200", "100", "ARCA"], + ["20160525 13:30:00.129", "AAPL", "98.6200", "100", "NASDAQ"], + ["20160525 13:30:00.129", "AAPL", "98.6200", "10", "NASDAQ"], + ["20160525 13:30:00.129", "AAPL", "98.6200", "59", "NASDAQ"], + ["20160525 13:30:00.129", "AAPL", "98.6200", "31", "NASDAQ"], + ["20160525 13:30:00.129", "AAPL", "98.6200", "69", "NASDAQ"], + ["20160525 13:30:00.129", "AAPL", "98.6200", "12", "NASDAQ"], + ["20160525 13:30:00.129", "AAPL", "98.6200", "12", "EDGX"], + ["20160525 13:30:00.129", "AAPL", "98.6200", "100", "ARCA"], + ["20160525 13:30:00.129", "AAPL", "98.6200", "100", "ARCA"], + ["20160525 13:30:00.130", "MSFT", "51.9500", "317", "ARCA"], + ["20160525 13:30:00.130", "MSFT", "51.9500", "283", "ARCA"], + ["20160525 13:30:00.135", "MSFT", "51.9300", "100", "EDGX"], + ["20160525 13:30:00.135", "AAPL", "98.6200", "100", "ARCA"], + ["20160525 13:30:00.144", "AAPL", "98.6200", "12", "NASDAQ"], + ["20160525 13:30:00.144", "AAPL", "98.6200", "88", "NASDAQ"], + ["20160525 13:30:00.144", "AAPL", "98.6200", "162", "NASDAQ"], + ["20160525 13:30:00.144", "AAPL", "98.6100", "100", "BATS"], + ["20160525 13:30:00.144", "AAPL", "98.6200", "61", "ARCA"], + ["20160525 13:30:00.144", "AAPL", "98.6200", "25", "ARCA"], + ["20160525 13:30:00.144", "AAPL", "98.6200", "14", "ARCA"], + ["20160525 13:30:00.145", "AAPL", "98.6200", "12", "ARCA"], + ["20160525 13:30:00.145", "AAPL", "98.6200", "100", "ARCA"], + ["20160525 13:30:00.145", "AAPL", "98.6300", "100", "NASDAQ"], + ["20160525 13:30:00.145", "AAPL", "98.6300", "100", "NASDAQ"], + ], + columns="time,ticker,price,quantity,marketCenter".split(","), + ) + trades["price"] = trades["price"].astype("float64") + trades["quantity"] = trades["quantity"].astype("int64") + trades = self.prep_data(trades) + + quotes = pd.DataFrame( + [ + ["20160525 13:30:00.023", "GOOG", "720.50", "720.93"], + ["20160525 13:30:00.023", "MSFT", "51.95", "51.95"], + ["20160525 13:30:00.041", "MSFT", "51.95", "51.95"], + ["20160525 13:30:00.048", "GOOG", "720.50", "720.93"], + ["20160525 13:30:00.048", "GOOG", "720.50", "720.93"], + ["20160525 13:30:00.048", "GOOG", "720.50", "720.93"], + ["20160525 13:30:00.048", "GOOG", "720.50", "720.93"], + ["20160525 13:30:00.072", "GOOG", "720.50", "720.88"], + ["20160525 13:30:00.075", "AAPL", "98.55", "98.56"], + ["20160525 13:30:00.076", "AAPL", "98.55", "98.56"], + ["20160525 13:30:00.076", "AAPL", "98.55", "98.56"], + ["20160525 13:30:00.076", "AAPL", "98.55", "98.56"], + ["20160525 13:30:00.078", "MSFT", "51.95", "51.95"], + ["20160525 13:30:00.078", "MSFT", "51.95", "51.95"], + ["20160525 13:30:00.078", "MSFT", "51.95", "51.95"], + ["20160525 13:30:00.078", "MSFT", "51.92", "51.95"], + ["20160525 13:30:00.079", "MSFT", "51.92", "51.95"], + ["20160525 13:30:00.080", "AAPL", "98.55", "98.56"], + ["20160525 13:30:00.084", "AAPL", "98.55", "98.56"], + ["20160525 13:30:00.086", "AAPL", "98.55", "98.63"], + ["20160525 13:30:00.088", "AAPL", "98.65", "98.63"], + ["20160525 13:30:00.089", "AAPL", "98.63", "98.63"], + ["20160525 13:30:00.104", "AAPL", "98.63", "98.63"], + ["20160525 13:30:00.104", "AAPL", "98.63", "98.63"], + ["20160525 13:30:00.104", "AAPL", "98.63", "98.63"], + ["20160525 13:30:00.104", "AAPL", "98.63", "98.63"], + ["20160525 13:30:00.104", "AAPL", "98.62", "98.63"], + ["20160525 13:30:00.105", "AAPL", "98.62", "98.63"], + ["20160525 13:30:00.107", "AAPL", "98.62", "98.63"], + ["20160525 13:30:00.115", "AAPL", "98.62", "98.63"], + ["20160525 13:30:00.115", "AAPL", "98.62", "98.63"], + ["20160525 13:30:00.118", "AAPL", "98.62", "98.63"], + ["20160525 13:30:00.128", "AAPL", "98.62", "98.63"], + ["20160525 13:30:00.128", "AAPL", "98.62", "98.63"], + ["20160525 13:30:00.129", "AAPL", "98.62", "98.63"], + ["20160525 13:30:00.129", "AAPL", "98.61", "98.63"], + ["20160525 13:30:00.129", "AAPL", "98.62", "98.63"], + ["20160525 13:30:00.129", "AAPL", "98.62", "98.63"], + ["20160525 13:30:00.129", "AAPL", "98.61", "98.63"], + ["20160525 13:30:00.130", "MSFT", "51.93", "51.95"], + ["20160525 13:30:00.130", "MSFT", "51.93", "51.95"], + ["20160525 13:30:00.130", "AAPL", "98.61", "98.63"], + ["20160525 13:30:00.131", "AAPL", "98.61", "98.62"], + ["20160525 13:30:00.131", "AAPL", "98.61", "98.62"], + ["20160525 13:30:00.135", "MSFT", "51.92", "51.95"], + ["20160525 13:30:00.135", "AAPL", "98.61", "98.62"], + ["20160525 13:30:00.136", "AAPL", "98.61", "98.62"], + ["20160525 13:30:00.136", "AAPL", "98.61", "98.62"], + ["20160525 13:30:00.144", "AAPL", "98.61", "98.62"], + ["20160525 13:30:00.144", "AAPL", "98.61", "98.62"], + ["20160525 13:30:00.145", "AAPL", "98.61", "98.62"], + ["20160525 13:30:00.145", "AAPL", "98.61", "98.63"], + ["20160525 13:30:00.145", "AAPL", "98.61", "98.63"], + ["20160525 13:30:00.145", "AAPL", "98.60", "98.63"], + ["20160525 13:30:00.145", "AAPL", "98.61", "98.63"], + ["20160525 13:30:00.145", "AAPL", "98.60", "98.63"], + ], + columns="time,ticker,bid,ask".split(","), + ) + quotes["bid"] = quotes["bid"].astype("float64") + quotes["ask"] = quotes["ask"].astype("float64") + quotes = self.prep_data(quotes, dedupe=True) + + result = merge_asof(trades, quotes, on="time", by="ticker") + tm.assert_frame_equal(result, expected) + + def test_basic_no_by(self, trades, asof, quotes): + f = ( + lambda x: x[x.ticker == "MSFT"] + .drop("ticker", axis=1) + .reset_index(drop=True) + ) + + # just use a single ticker + expected = f(asof) + trades = f(trades) + quotes = f(quotes) + + result = merge_asof(trades, quotes, on="time") + tm.assert_frame_equal(result, expected) + + def test_valid_join_keys(self, trades, quotes): + msg = r"incompatible merge keys \[1\] .* must be the same type" + + with pytest.raises(MergeError, match=msg): + merge_asof(trades, quotes, left_on="time", right_on="bid", by="ticker") + + with pytest.raises(MergeError, match="can only asof on a key for left"): + merge_asof(trades, quotes, on=["time", "ticker"], by="ticker") + + with pytest.raises(MergeError, match="can only asof on a key for left"): + merge_asof(trades, quotes, by="ticker") + + def test_with_duplicates(self, datapath, trades, quotes, asof): + q = ( + pd.concat([quotes, quotes]) + .sort_values(["time", "ticker"]) + .reset_index(drop=True) + ) + result = merge_asof(trades, q, on="time", by="ticker") + expected = self.prep_data(asof) + tm.assert_frame_equal(result, expected) + + def test_with_duplicates_no_on(self): + df1 = pd.DataFrame({"key": [1, 1, 3], "left_val": [1, 2, 3]}) + df2 = pd.DataFrame({"key": [1, 2, 2], "right_val": [1, 2, 3]}) + result = merge_asof(df1, df2, on="key") + expected = pd.DataFrame( + {"key": [1, 1, 3], "left_val": [1, 2, 3], "right_val": [1, 1, 3]} + ) + tm.assert_frame_equal(result, expected) + + def test_valid_allow_exact_matches(self, trades, quotes): + msg = "allow_exact_matches must be boolean, passed foo" + + with pytest.raises(MergeError, match=msg): + merge_asof( + trades, quotes, on="time", by="ticker", allow_exact_matches="foo" + ) + + def test_valid_tolerance(self, trades, quotes): + # dti + merge_asof(trades, quotes, on="time", by="ticker", tolerance=Timedelta("1s")) + + # integer + merge_asof( + trades.reset_index(), + quotes.reset_index(), + on="index", + by="ticker", + tolerance=1, + ) + + msg = r"incompatible tolerance .*, must be compat with type .*" + + # incompat + with pytest.raises(MergeError, match=msg): + merge_asof(trades, quotes, on="time", by="ticker", tolerance=1) + + # invalid + with pytest.raises(MergeError, match=msg): + merge_asof( + trades.reset_index(), + quotes.reset_index(), + on="index", + by="ticker", + tolerance=1.0, + ) + + msg = "tolerance must be positive" + + # invalid negative + with pytest.raises(MergeError, match=msg): + merge_asof( + trades, quotes, on="time", by="ticker", tolerance=-Timedelta("1s") + ) + + with pytest.raises(MergeError, match=msg): + merge_asof( + trades.reset_index(), + quotes.reset_index(), + on="index", + by="ticker", + tolerance=-1, + ) + + def test_non_sorted(self, trades, quotes): + trades = trades.sort_values("time", ascending=False) + quotes = quotes.sort_values("time", ascending=False) + + # we require that we are already sorted on time & quotes + assert not trades.time.is_monotonic_increasing + assert not quotes.time.is_monotonic_increasing + with pytest.raises(ValueError, match="left keys must be sorted"): + merge_asof(trades, quotes, on="time", by="ticker") + + trades = trades.sort_values("time") + assert trades.time.is_monotonic_increasing + assert not quotes.time.is_monotonic_increasing + with pytest.raises(ValueError, match="right keys must be sorted"): + merge_asof(trades, quotes, on="time", by="ticker") + + quotes = quotes.sort_values("time") + assert trades.time.is_monotonic_increasing + assert quotes.time.is_monotonic_increasing + + # ok, though has dupes + merge_asof(trades, quotes, on="time", by="ticker") + + @pytest.mark.parametrize( + "tolerance_ts", + [Timedelta("1day"), datetime.timedelta(days=1)], + ids=["Timedelta", "datetime.timedelta"], + ) + def test_tolerance(self, tolerance_ts, trades, quotes, tolerance): + result = merge_asof( + trades, quotes, on="time", by="ticker", tolerance=tolerance_ts + ) + expected = tolerance + tm.assert_frame_equal(result, expected) + + def test_tolerance_forward(self): + # GH14887 + + left = pd.DataFrame({"a": [1, 5, 10], "left_val": ["a", "b", "c"]}) + right = pd.DataFrame({"a": [1, 2, 3, 7, 11], "right_val": [1, 2, 3, 7, 11]}) + + expected = pd.DataFrame( + {"a": [1, 5, 10], "left_val": ["a", "b", "c"], "right_val": [1, np.nan, 11]} + ) + + result = merge_asof(left, right, on="a", direction="forward", tolerance=1) + tm.assert_frame_equal(result, expected) + + def test_tolerance_nearest(self): + # GH14887 + + left = pd.DataFrame({"a": [1, 5, 10], "left_val": ["a", "b", "c"]}) + right = pd.DataFrame({"a": [1, 2, 3, 7, 11], "right_val": [1, 2, 3, 7, 11]}) + + expected = pd.DataFrame( + {"a": [1, 5, 10], "left_val": ["a", "b", "c"], "right_val": [1, np.nan, 11]} + ) + + result = merge_asof(left, right, on="a", direction="nearest", tolerance=1) + tm.assert_frame_equal(result, expected) + + def test_tolerance_tz(self, unit): + # GH 14844 + left = pd.DataFrame( + { + "date": pd.date_range( + start=to_datetime("2016-01-02"), + freq="D", + periods=5, + tz=pytz.timezone("UTC"), + unit=unit, + ), + "value1": np.arange(5), + } + ) + right = pd.DataFrame( + { + "date": pd.date_range( + start=to_datetime("2016-01-01"), + freq="D", + periods=5, + tz=pytz.timezone("UTC"), + unit=unit, + ), + "value2": list("ABCDE"), + } + ) + result = merge_asof(left, right, on="date", tolerance=Timedelta("1 day")) + + expected = pd.DataFrame( + { + "date": pd.date_range( + start=to_datetime("2016-01-02"), + freq="D", + periods=5, + tz=pytz.timezone("UTC"), + unit=unit, + ), + "value1": np.arange(5), + "value2": list("BCDEE"), + } + ) + tm.assert_frame_equal(result, expected) + + def test_tolerance_float(self): + # GH22981 + left = pd.DataFrame({"a": [1.1, 3.5, 10.9], "left_val": ["a", "b", "c"]}) + right = pd.DataFrame( + {"a": [1.0, 2.5, 3.3, 7.5, 11.5], "right_val": [1.0, 2.5, 3.3, 7.5, 11.5]} + ) + + expected = pd.DataFrame( + { + "a": [1.1, 3.5, 10.9], + "left_val": ["a", "b", "c"], + "right_val": [1, 3.3, np.nan], + } + ) + + result = merge_asof(left, right, on="a", direction="nearest", tolerance=0.5) + tm.assert_frame_equal(result, expected) + + def test_index_tolerance(self, trades, quotes, tolerance): + # GH 15135 + expected = tolerance.set_index("time") + trades = trades.set_index("time") + quotes = quotes.set_index("time") + + result = merge_asof( + trades, + quotes, + left_index=True, + right_index=True, + by="ticker", + tolerance=Timedelta("1day"), + ) + tm.assert_frame_equal(result, expected) + + def test_allow_exact_matches(self, trades, quotes, allow_exact_matches): + result = merge_asof( + trades, quotes, on="time", by="ticker", allow_exact_matches=False + ) + expected = allow_exact_matches + tm.assert_frame_equal(result, expected) + + def test_allow_exact_matches_forward(self): + # GH14887 + + left = pd.DataFrame({"a": [1, 5, 10], "left_val": ["a", "b", "c"]}) + right = pd.DataFrame({"a": [1, 2, 3, 7, 11], "right_val": [1, 2, 3, 7, 11]}) + + expected = pd.DataFrame( + {"a": [1, 5, 10], "left_val": ["a", "b", "c"], "right_val": [2, 7, 11]} + ) + + result = merge_asof( + left, right, on="a", direction="forward", allow_exact_matches=False + ) + tm.assert_frame_equal(result, expected) + + def test_allow_exact_matches_nearest(self): + # GH14887 + + left = pd.DataFrame({"a": [1, 5, 10], "left_val": ["a", "b", "c"]}) + right = pd.DataFrame({"a": [1, 2, 3, 7, 11], "right_val": [1, 2, 3, 7, 11]}) + + expected = pd.DataFrame( + {"a": [1, 5, 10], "left_val": ["a", "b", "c"], "right_val": [2, 3, 11]} + ) + + result = merge_asof( + left, right, on="a", direction="nearest", allow_exact_matches=False + ) + tm.assert_frame_equal(result, expected) + + def test_allow_exact_matches_and_tolerance( + self, trades, quotes, allow_exact_matches_and_tolerance + ): + result = merge_asof( + trades, + quotes, + on="time", + by="ticker", + tolerance=Timedelta("100ms"), + allow_exact_matches=False, + ) + expected = allow_exact_matches_and_tolerance + tm.assert_frame_equal(result, expected) + + def test_allow_exact_matches_and_tolerance2(self): + # GH 13695 + df1 = pd.DataFrame( + {"time": to_datetime(["2016-07-15 13:30:00.030"]), "username": ["bob"]} + ) + df2 = pd.DataFrame( + { + "time": to_datetime( + ["2016-07-15 13:30:00.000", "2016-07-15 13:30:00.030"] + ), + "version": [1, 2], + } + ) + + result = merge_asof(df1, df2, on="time") + expected = pd.DataFrame( + { + "time": to_datetime(["2016-07-15 13:30:00.030"]), + "username": ["bob"], + "version": [2], + } + ) + tm.assert_frame_equal(result, expected) + + result = merge_asof(df1, df2, on="time", allow_exact_matches=False) + expected = pd.DataFrame( + { + "time": to_datetime(["2016-07-15 13:30:00.030"]), + "username": ["bob"], + "version": [1], + } + ) + tm.assert_frame_equal(result, expected) + + result = merge_asof( + df1, + df2, + on="time", + allow_exact_matches=False, + tolerance=Timedelta("10ms"), + ) + expected = pd.DataFrame( + { + "time": to_datetime(["2016-07-15 13:30:00.030"]), + "username": ["bob"], + "version": [np.nan], + } + ) + tm.assert_frame_equal(result, expected) + + def test_allow_exact_matches_and_tolerance3(self): + # GH 13709 + df1 = pd.DataFrame( + { + "time": to_datetime( + ["2016-07-15 13:30:00.030", "2016-07-15 13:30:00.030"] + ), + "username": ["bob", "charlie"], + } + ) + df2 = pd.DataFrame( + { + "time": to_datetime( + ["2016-07-15 13:30:00.000", "2016-07-15 13:30:00.030"] + ), + "version": [1, 2], + } + ) + + result = merge_asof( + df1, + df2, + on="time", + allow_exact_matches=False, + tolerance=Timedelta("10ms"), + ) + expected = pd.DataFrame( + { + "time": to_datetime( + ["2016-07-15 13:30:00.030", "2016-07-15 13:30:00.030"] + ), + "username": ["bob", "charlie"], + "version": [np.nan, np.nan], + } + ) + tm.assert_frame_equal(result, expected) + + def test_allow_exact_matches_and_tolerance_forward(self): + # GH14887 + + left = pd.DataFrame({"a": [1, 5, 10], "left_val": ["a", "b", "c"]}) + right = pd.DataFrame({"a": [1, 3, 4, 6, 11], "right_val": [1, 3, 4, 6, 11]}) + + expected = pd.DataFrame( + {"a": [1, 5, 10], "left_val": ["a", "b", "c"], "right_val": [np.nan, 6, 11]} + ) + + result = merge_asof( + left, + right, + on="a", + direction="forward", + allow_exact_matches=False, + tolerance=1, + ) + tm.assert_frame_equal(result, expected) + + def test_allow_exact_matches_and_tolerance_nearest(self): + # GH14887 + + left = pd.DataFrame({"a": [1, 5, 10], "left_val": ["a", "b", "c"]}) + right = pd.DataFrame({"a": [1, 3, 4, 6, 11], "right_val": [1, 3, 4, 7, 11]}) + + expected = pd.DataFrame( + {"a": [1, 5, 10], "left_val": ["a", "b", "c"], "right_val": [np.nan, 4, 11]} + ) + + result = merge_asof( + left, + right, + on="a", + direction="nearest", + allow_exact_matches=False, + tolerance=1, + ) + tm.assert_frame_equal(result, expected) + + def test_forward_by(self): + # GH14887 + + left = pd.DataFrame( + { + "a": [1, 5, 10, 12, 15], + "b": ["X", "X", "Y", "Z", "Y"], + "left_val": ["a", "b", "c", "d", "e"], + } + ) + right = pd.DataFrame( + { + "a": [1, 6, 11, 15, 16], + "b": ["X", "Z", "Y", "Z", "Y"], + "right_val": [1, 6, 11, 15, 16], + } + ) + + expected = pd.DataFrame( + { + "a": [1, 5, 10, 12, 15], + "b": ["X", "X", "Y", "Z", "Y"], + "left_val": ["a", "b", "c", "d", "e"], + "right_val": [1, np.nan, 11, 15, 16], + } + ) + + result = merge_asof(left, right, on="a", by="b", direction="forward") + tm.assert_frame_equal(result, expected) + + def test_nearest_by(self): + # GH14887 + + left = pd.DataFrame( + { + "a": [1, 5, 10, 12, 15], + "b": ["X", "X", "Z", "Z", "Y"], + "left_val": ["a", "b", "c", "d", "e"], + } + ) + right = pd.DataFrame( + { + "a": [1, 6, 11, 15, 16], + "b": ["X", "Z", "Z", "Z", "Y"], + "right_val": [1, 6, 11, 15, 16], + } + ) + + expected = pd.DataFrame( + { + "a": [1, 5, 10, 12, 15], + "b": ["X", "X", "Z", "Z", "Y"], + "left_val": ["a", "b", "c", "d", "e"], + "right_val": [1, 1, 11, 11, 16], + } + ) + + result = merge_asof(left, right, on="a", by="b", direction="nearest") + tm.assert_frame_equal(result, expected) + + def test_by_int(self): + # we specialize by type, so test that this is correct + df1 = pd.DataFrame( + { + "time": to_datetime( + [ + "20160525 13:30:00.020", + "20160525 13:30:00.030", + "20160525 13:30:00.040", + "20160525 13:30:00.050", + "20160525 13:30:00.060", + ] + ), + "key": [1, 2, 1, 3, 2], + "value1": [1.1, 1.2, 1.3, 1.4, 1.5], + }, + columns=["time", "key", "value1"], + ) + + df2 = pd.DataFrame( + { + "time": to_datetime( + [ + "20160525 13:30:00.015", + "20160525 13:30:00.020", + "20160525 13:30:00.025", + "20160525 13:30:00.035", + "20160525 13:30:00.040", + "20160525 13:30:00.055", + "20160525 13:30:00.060", + "20160525 13:30:00.065", + ] + ), + "key": [2, 1, 1, 3, 2, 1, 2, 3], + "value2": [2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8], + }, + columns=["time", "key", "value2"], + ) + + result = merge_asof(df1, df2, on="time", by="key") + + expected = pd.DataFrame( + { + "time": to_datetime( + [ + "20160525 13:30:00.020", + "20160525 13:30:00.030", + "20160525 13:30:00.040", + "20160525 13:30:00.050", + "20160525 13:30:00.060", + ] + ), + "key": [1, 2, 1, 3, 2], + "value1": [1.1, 1.2, 1.3, 1.4, 1.5], + "value2": [2.2, 2.1, 2.3, 2.4, 2.7], + }, + columns=["time", "key", "value1", "value2"], + ) + + tm.assert_frame_equal(result, expected) + + def test_on_float(self): + # mimics how to determine the minimum-price variation + df1 = pd.DataFrame( + { + "price": [5.01, 0.0023, 25.13, 340.05, 30.78, 1040.90, 0.0078], + "symbol": list("ABCDEFG"), + }, + columns=["symbol", "price"], + ) + + df2 = pd.DataFrame( + {"price": [0.0, 1.0, 100.0], "mpv": [0.0001, 0.01, 0.05]}, + columns=["price", "mpv"], + ) + + df1 = df1.sort_values("price").reset_index(drop=True) + + result = merge_asof(df1, df2, on="price") + + expected = pd.DataFrame( + { + "symbol": list("BGACEDF"), + "price": [0.0023, 0.0078, 5.01, 25.13, 30.78, 340.05, 1040.90], + "mpv": [0.0001, 0.0001, 0.01, 0.01, 0.01, 0.05, 0.05], + }, + columns=["symbol", "price", "mpv"], + ) + + tm.assert_frame_equal(result, expected) + + def test_on_specialized_type(self, any_real_numpy_dtype): + # see gh-13936 + dtype = np.dtype(any_real_numpy_dtype).type + + df1 = pd.DataFrame( + {"value": [5, 2, 25, 100, 78, 120, 79], "symbol": list("ABCDEFG")}, + columns=["symbol", "value"], + ) + df1.value = dtype(df1.value) + + df2 = pd.DataFrame( + {"value": [0, 80, 120, 125], "result": list("xyzw")}, + columns=["value", "result"], + ) + df2.value = dtype(df2.value) + + df1 = df1.sort_values("value").reset_index(drop=True) + result = merge_asof(df1, df2, on="value") + + expected = pd.DataFrame( + { + "symbol": list("BACEGDF"), + "value": [2, 5, 25, 78, 79, 100, 120], + "result": list("xxxxxyz"), + }, + columns=["symbol", "value", "result"], + ) + expected.value = dtype(expected.value) + + tm.assert_frame_equal(result, expected) + + def test_on_specialized_type_by_int(self, any_real_numpy_dtype): + # see gh-13936 + dtype = np.dtype(any_real_numpy_dtype).type + + df1 = pd.DataFrame( + { + "value": [5, 2, 25, 100, 78, 120, 79], + "key": [1, 2, 3, 2, 3, 1, 2], + "symbol": list("ABCDEFG"), + }, + columns=["symbol", "key", "value"], + ) + df1.value = dtype(df1.value) + + df2 = pd.DataFrame( + {"value": [0, 80, 120, 125], "key": [1, 2, 2, 3], "result": list("xyzw")}, + columns=["value", "key", "result"], + ) + df2.value = dtype(df2.value) + + df1 = df1.sort_values("value").reset_index(drop=True) + result = merge_asof(df1, df2, on="value", by="key") + + expected = pd.DataFrame( + { + "symbol": list("BACEGDF"), + "key": [2, 1, 3, 3, 2, 2, 1], + "value": [2, 5, 25, 78, 79, 100, 120], + "result": [np.nan, "x", np.nan, np.nan, np.nan, "y", "x"], + }, + columns=["symbol", "key", "value", "result"], + ) + expected.value = dtype(expected.value) + + tm.assert_frame_equal(result, expected) + + def test_on_float_by_int(self): + # type specialize both "by" and "on" parameters + df1 = pd.DataFrame( + { + "symbol": list("AAABBBCCC"), + "exch": [1, 2, 3, 1, 2, 3, 1, 2, 3], + "price": [ + 3.26, + 3.2599, + 3.2598, + 12.58, + 12.59, + 12.5, + 378.15, + 378.2, + 378.25, + ], + }, + columns=["symbol", "exch", "price"], + ) + + df2 = pd.DataFrame( + { + "exch": [1, 1, 1, 2, 2, 2, 3, 3, 3], + "price": [0.0, 1.0, 100.0, 0.0, 5.0, 100.0, 0.0, 5.0, 1000.0], + "mpv": [0.0001, 0.01, 0.05, 0.0001, 0.01, 0.1, 0.0001, 0.25, 1.0], + }, + columns=["exch", "price", "mpv"], + ) + + df1 = df1.sort_values("price").reset_index(drop=True) + df2 = df2.sort_values("price").reset_index(drop=True) + + result = merge_asof(df1, df2, on="price", by="exch") + + expected = pd.DataFrame( + { + "symbol": list("AAABBBCCC"), + "exch": [3, 2, 1, 3, 1, 2, 1, 2, 3], + "price": [ + 3.2598, + 3.2599, + 3.26, + 12.5, + 12.58, + 12.59, + 378.15, + 378.2, + 378.25, + ], + "mpv": [0.0001, 0.0001, 0.01, 0.25, 0.01, 0.01, 0.05, 0.1, 0.25], + }, + columns=["symbol", "exch", "price", "mpv"], + ) + + tm.assert_frame_equal(result, expected) + + def test_merge_datatype_error_raises(self, using_infer_string): + if using_infer_string: + msg = "incompatible merge keys" + else: + msg = r"Incompatible merge dtype, .*, both sides must have numeric dtype" + + left = pd.DataFrame({"left_val": [1, 5, 10], "a": ["a", "b", "c"]}) + right = pd.DataFrame({"right_val": [1, 2, 3, 6, 7], "a": [1, 2, 3, 6, 7]}) + + with pytest.raises(MergeError, match=msg): + merge_asof(left, right, on="a") + + def test_merge_datatype_categorical_error_raises(self): + msg = ( + r"incompatible merge keys \[0\] .* both sides category, " + "but not equal ones" + ) + + left = pd.DataFrame( + {"left_val": [1, 5, 10], "a": pd.Categorical(["a", "b", "c"])} + ) + right = pd.DataFrame( + { + "right_val": [1, 2, 3, 6, 7], + "a": pd.Categorical(["a", "X", "c", "X", "b"]), + } + ) + + with pytest.raises(MergeError, match=msg): + merge_asof(left, right, on="a") + + def test_merge_groupby_multiple_column_with_categorical_column(self): + # GH 16454 + df = pd.DataFrame({"x": [0], "y": [0], "z": pd.Categorical([0])}) + result = merge_asof(df, df, on="x", by=["y", "z"]) + expected = pd.DataFrame({"x": [0], "y": [0], "z": pd.Categorical([0])}) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "func", [lambda x: x, lambda x: to_datetime(x)], ids=["numeric", "datetime"] + ) + @pytest.mark.parametrize("side", ["left", "right"]) + def test_merge_on_nans(self, func, side): + # GH 23189 + msg = f"Merge keys contain null values on {side} side" + nulls = func([1.0, 5.0, np.nan]) + non_nulls = func([1.0, 5.0, 10.0]) + df_null = pd.DataFrame({"a": nulls, "left_val": ["a", "b", "c"]}) + df = pd.DataFrame({"a": non_nulls, "right_val": [1, 6, 11]}) + + with pytest.raises(ValueError, match=msg): + if side == "left": + merge_asof(df_null, df, on="a") + else: + merge_asof(df, df_null, on="a") + + def test_by_nullable(self, any_numeric_ea_dtype, using_infer_string): + # Note: this test passes if instead of using pd.array we use + # np.array([np.nan, 1]). Other than that, I (@jbrockmendel) + # have NO IDEA what the expected behavior is. + # TODO(GH#32306): may be relevant to the expected behavior here. + + arr = pd.array([pd.NA, 0, 1], dtype=any_numeric_ea_dtype) + if arr.dtype.kind in ["i", "u"]: + max_val = np.iinfo(arr.dtype.numpy_dtype).max + else: + max_val = np.finfo(arr.dtype.numpy_dtype).max + # set value s.t. (at least for integer dtypes) arr._values_for_argsort + # is not an injection + arr[2] = max_val + + left = pd.DataFrame( + { + "by_col1": arr, + "by_col2": ["HELLO", "To", "You"], + "on_col": [2, 4, 6], + "value": ["a", "c", "e"], + } + ) + right = pd.DataFrame( + { + "by_col1": arr, + "by_col2": ["WORLD", "Wide", "Web"], + "on_col": [1, 2, 6], + "value": ["b", "d", "f"], + } + ) + + result = merge_asof(left, right, by=["by_col1", "by_col2"], on="on_col") + expected = pd.DataFrame( + { + "by_col1": arr, + "by_col2": ["HELLO", "To", "You"], + "on_col": [2, 4, 6], + "value_x": ["a", "c", "e"], + } + ) + expected["value_y"] = np.array([np.nan, np.nan, np.nan], dtype=object) + if using_infer_string: + expected["value_y"] = expected["value_y"].astype("string[pyarrow_numpy]") + tm.assert_frame_equal(result, expected) + + def test_merge_by_col_tz_aware(self): + # GH 21184 + left = pd.DataFrame( + { + "by_col": pd.DatetimeIndex(["2018-01-01"]).tz_localize("UTC"), + "on_col": [2], + "values": ["a"], + } + ) + right = pd.DataFrame( + { + "by_col": pd.DatetimeIndex(["2018-01-01"]).tz_localize("UTC"), + "on_col": [1], + "values": ["b"], + } + ) + result = merge_asof(left, right, by="by_col", on="on_col") + expected = pd.DataFrame( + [[pd.Timestamp("2018-01-01", tz="UTC"), 2, "a", "b"]], + columns=["by_col", "on_col", "values_x", "values_y"], + ) + tm.assert_frame_equal(result, expected) + + def test_by_mixed_tz_aware(self, using_infer_string): + # GH 26649 + left = pd.DataFrame( + { + "by_col1": pd.DatetimeIndex(["2018-01-01"]).tz_localize("UTC"), + "by_col2": ["HELLO"], + "on_col": [2], + "value": ["a"], + } + ) + right = pd.DataFrame( + { + "by_col1": pd.DatetimeIndex(["2018-01-01"]).tz_localize("UTC"), + "by_col2": ["WORLD"], + "on_col": [1], + "value": ["b"], + } + ) + result = merge_asof(left, right, by=["by_col1", "by_col2"], on="on_col") + expected = pd.DataFrame( + [[pd.Timestamp("2018-01-01", tz="UTC"), "HELLO", 2, "a"]], + columns=["by_col1", "by_col2", "on_col", "value_x"], + ) + expected["value_y"] = np.array([np.nan], dtype=object) + if using_infer_string: + expected["value_y"] = expected["value_y"].astype("string[pyarrow_numpy]") + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("dtype", ["float64", "int16", "m8[ns]", "M8[us]"]) + def test_by_dtype(self, dtype): + # GH 55453, GH 22794 + left = pd.DataFrame( + { + "by_col": np.array([1], dtype=dtype), + "on_col": [2], + "value": ["a"], + } + ) + right = pd.DataFrame( + { + "by_col": np.array([1], dtype=dtype), + "on_col": [1], + "value": ["b"], + } + ) + result = merge_asof(left, right, by="by_col", on="on_col") + expected = pd.DataFrame( + { + "by_col": np.array([1], dtype=dtype), + "on_col": [2], + "value_x": ["a"], + "value_y": ["b"], + } + ) + tm.assert_frame_equal(result, expected) + + def test_timedelta_tolerance_nearest(self, unit): + # GH 27642 + if unit == "s": + pytest.skip( + "This test is invalid with unit='s' because that would " + "round left['time']" + ) + + left = pd.DataFrame( + list(zip([0, 5, 10, 15, 20, 25], [0, 1, 2, 3, 4, 5])), + columns=["time", "left"], + ) + + left["time"] = pd.to_timedelta(left["time"], "ms").astype(f"m8[{unit}]") + + right = pd.DataFrame( + list(zip([0, 3, 9, 12, 15, 18], [0, 1, 2, 3, 4, 5])), + columns=["time", "right"], + ) + + right["time"] = pd.to_timedelta(right["time"], "ms").astype(f"m8[{unit}]") + + expected = pd.DataFrame( + list( + zip( + [0, 5, 10, 15, 20, 25], + [0, 1, 2, 3, 4, 5], + [0, np.nan, 2, 4, np.nan, np.nan], + ) + ), + columns=["time", "left", "right"], + ) + + expected["time"] = pd.to_timedelta(expected["time"], "ms").astype(f"m8[{unit}]") + + result = merge_asof( + left, right, on="time", tolerance=Timedelta("1ms"), direction="nearest" + ) + + tm.assert_frame_equal(result, expected) + + def test_int_type_tolerance(self, any_int_dtype): + # GH #28870 + + left = pd.DataFrame({"a": [0, 10, 20], "left_val": [1, 2, 3]}) + right = pd.DataFrame({"a": [5, 15, 25], "right_val": [1, 2, 3]}) + left["a"] = left["a"].astype(any_int_dtype) + right["a"] = right["a"].astype(any_int_dtype) + + expected = pd.DataFrame( + {"a": [0, 10, 20], "left_val": [1, 2, 3], "right_val": [np.nan, 1.0, 2.0]} + ) + expected["a"] = expected["a"].astype(any_int_dtype) + + result = merge_asof(left, right, on="a", tolerance=10) + tm.assert_frame_equal(result, expected) + + def test_merge_index_column_tz(self): + # GH 29864 + index = pd.date_range("2019-10-01", freq="30min", periods=5, tz="UTC") + left = pd.DataFrame([0.9, 0.8, 0.7, 0.6], columns=["xyz"], index=index[1:]) + right = pd.DataFrame({"from_date": index, "abc": [2.46] * 4 + [2.19]}) + result = merge_asof( + left=left, right=right, left_index=True, right_on=["from_date"] + ) + expected = pd.DataFrame( + { + "xyz": [0.9, 0.8, 0.7, 0.6], + "from_date": index[1:], + "abc": [2.46] * 3 + [2.19], + }, + index=pd.date_range( + "2019-10-01 00:30:00", freq="30min", periods=4, tz="UTC" + ), + ) + tm.assert_frame_equal(result, expected) + + result = merge_asof( + left=right, right=left, right_index=True, left_on=["from_date"] + ) + expected = pd.DataFrame( + { + "from_date": index, + "abc": [2.46] * 4 + [2.19], + "xyz": [np.nan, 0.9, 0.8, 0.7, 0.6], + }, + index=Index([0, 1, 2, 3, 4]), + ) + tm.assert_frame_equal(result, expected) + + def test_left_index_right_index_tolerance(self, unit): + # https://github.com/pandas-dev/pandas/issues/35558 + if unit == "s": + pytest.skip( + "This test is invalid with unit='s' because that would round dr1" + ) + + dr1 = pd.date_range( + start="1/1/2020", end="1/20/2020", freq="2D", unit=unit + ) + Timedelta(seconds=0.4).as_unit(unit) + dr2 = pd.date_range(start="1/1/2020", end="2/1/2020", unit=unit) + + df1 = pd.DataFrame({"val1": "foo"}, index=pd.DatetimeIndex(dr1)) + df2 = pd.DataFrame({"val2": "bar"}, index=pd.DatetimeIndex(dr2)) + + expected = pd.DataFrame( + {"val1": "foo", "val2": "bar"}, index=pd.DatetimeIndex(dr1) + ) + result = merge_asof( + df1, + df2, + left_index=True, + right_index=True, + tolerance=Timedelta(seconds=0.5), + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "infer_string", [False, pytest.param(True, marks=td.skip_if_no("pyarrow"))] +) +@pytest.mark.parametrize( + "kwargs", [{"on": "x"}, {"left_index": True, "right_index": True}] +) +@pytest.mark.parametrize( + "data", + [["2019-06-01 00:09:12", "2019-06-01 00:10:29"], [1.0, "2019-06-01 00:10:29"]], +) +def test_merge_asof_non_numerical_dtype(kwargs, data, infer_string): + # GH#29130 + with option_context("future.infer_string", infer_string): + left = pd.DataFrame({"x": data}, index=data) + right = pd.DataFrame({"x": data}, index=data) + with pytest.raises( + MergeError, + match=r"Incompatible merge dtype, .*, both sides must have numeric dtype", + ): + merge_asof(left, right, **kwargs) + + +def test_merge_asof_non_numerical_dtype_object(): + # GH#29130 + left = pd.DataFrame({"a": ["12", "13", "15"], "left_val1": ["a", "b", "c"]}) + right = pd.DataFrame({"a": ["a", "b", "c"], "left_val": ["d", "e", "f"]}) + with pytest.raises( + MergeError, + match=r"Incompatible merge dtype, .*, both sides must have numeric dtype", + ): + merge_asof( + left, + right, + left_on="left_val1", + right_on="a", + left_by="a", + right_by="left_val", + ) + + +@pytest.mark.parametrize( + "kwargs", + [ + {"right_index": True, "left_index": True}, + {"left_on": "left_time", "right_index": True}, + {"left_index": True, "right_on": "right"}, + ], +) +def test_merge_asof_index_behavior(kwargs): + # GH 33463 + index = Index([1, 5, 10], name="test") + left = pd.DataFrame({"left": ["a", "b", "c"], "left_time": [1, 4, 10]}, index=index) + right = pd.DataFrame({"right": [1, 2, 3, 6, 7]}, index=[1, 2, 3, 6, 7]) + result = merge_asof(left, right, **kwargs) + + expected = pd.DataFrame( + {"left": ["a", "b", "c"], "left_time": [1, 4, 10], "right": [1, 3, 7]}, + index=index, + ) + tm.assert_frame_equal(result, expected) + + +def test_merge_asof_numeric_column_in_index(): + # GH#34488 + left = pd.DataFrame({"b": [10, 11, 12]}, index=Index([1, 2, 3], name="a")) + right = pd.DataFrame({"c": [20, 21, 22]}, index=Index([0, 2, 3], name="a")) + + result = merge_asof(left, right, left_on="a", right_on="a") + expected = pd.DataFrame({"a": [1, 2, 3], "b": [10, 11, 12], "c": [20, 21, 22]}) + tm.assert_frame_equal(result, expected) + + +def test_merge_asof_numeric_column_in_multiindex(): + # GH#34488 + left = pd.DataFrame( + {"b": [10, 11, 12]}, + index=pd.MultiIndex.from_arrays([[1, 2, 3], ["a", "b", "c"]], names=["a", "z"]), + ) + right = pd.DataFrame( + {"c": [20, 21, 22]}, + index=pd.MultiIndex.from_arrays([[1, 2, 3], ["x", "y", "z"]], names=["a", "y"]), + ) + + result = merge_asof(left, right, left_on="a", right_on="a") + expected = pd.DataFrame({"a": [1, 2, 3], "b": [10, 11, 12], "c": [20, 21, 22]}) + tm.assert_frame_equal(result, expected) + + +def test_merge_asof_numeri_column_in_index_object_dtype(): + # GH#34488 + left = pd.DataFrame({"b": [10, 11, 12]}, index=Index(["1", "2", "3"], name="a")) + right = pd.DataFrame({"c": [20, 21, 22]}, index=Index(["m", "n", "o"], name="a")) + + with pytest.raises( + MergeError, + match=r"Incompatible merge dtype, .*, both sides must have numeric dtype", + ): + merge_asof(left, right, left_on="a", right_on="a") + + left = left.reset_index().set_index(["a", "b"]) + right = right.reset_index().set_index(["a", "c"]) + + with pytest.raises( + MergeError, + match=r"Incompatible merge dtype, .*, both sides must have numeric dtype", + ): + merge_asof(left, right, left_on="a", right_on="a") + + +def test_merge_asof_array_as_on(unit): + # GH#42844 + dti = pd.DatetimeIndex( + ["2021/01/01 00:37", "2021/01/01 01:40"], dtype=f"M8[{unit}]" + ) + right = pd.DataFrame( + { + "a": [2, 6], + "ts": dti, + } + ) + ts_merge = pd.date_range( + start=pd.Timestamp("2021/01/01 00:00"), periods=3, freq="1h", unit=unit + ) + left = pd.DataFrame({"b": [4, 8, 7]}) + result = merge_asof( + left, + right, + left_on=ts_merge, + right_on="ts", + allow_exact_matches=False, + direction="backward", + ) + expected = pd.DataFrame({"b": [4, 8, 7], "a": [np.nan, 2, 6], "ts": ts_merge}) + tm.assert_frame_equal(result, expected) + + result = merge_asof( + right, + left, + left_on="ts", + right_on=ts_merge, + allow_exact_matches=False, + direction="backward", + ) + expected = pd.DataFrame( + { + "a": [2, 6], + "ts": dti, + "b": [4, 8], + } + ) + tm.assert_frame_equal(result, expected) + + +def test_merge_asof_raise_for_duplicate_columns(): + # GH#50102 + left = pd.DataFrame([[1, 2, "a"]], columns=["a", "a", "left_val"]) + right = pd.DataFrame([[1, 1, 1]], columns=["a", "a", "right_val"]) + + with pytest.raises(ValueError, match="column label 'a'"): + merge_asof(left, right, on="a") + + with pytest.raises(ValueError, match="column label 'a'"): + merge_asof(left, right, left_on="a", right_on="right_val") + + with pytest.raises(ValueError, match="column label 'a'"): + merge_asof(left, right, left_on="left_val", right_on="a") + + +@pytest.mark.parametrize( + "dtype", + [ + "Int64", + pytest.param("int64[pyarrow]", marks=td.skip_if_no("pyarrow")), + pytest.param("timestamp[s][pyarrow]", marks=td.skip_if_no("pyarrow")), + ], +) +def test_merge_asof_extension_dtype(dtype): + # GH 52904 + left = pd.DataFrame( + { + "join_col": [1, 3, 5], + "left_val": [1, 2, 3], + } + ) + right = pd.DataFrame( + { + "join_col": [2, 3, 4], + "right_val": [1, 2, 3], + } + ) + left = left.astype({"join_col": dtype}) + right = right.astype({"join_col": dtype}) + result = merge_asof(left, right, on="join_col") + expected = pd.DataFrame( + { + "join_col": [1, 3, 5], + "left_val": [1, 2, 3], + "right_val": [np.nan, 2.0, 3.0], + } + ) + expected = expected.astype({"join_col": dtype}) + tm.assert_frame_equal(result, expected) + + +@td.skip_if_no("pyarrow") +def test_merge_asof_pyarrow_td_tolerance(): + # GH 56486 + ser = pd.Series( + [datetime.datetime(2023, 1, 1)], dtype="timestamp[us, UTC][pyarrow]" + ) + df = pd.DataFrame( + { + "timestamp": ser, + "value": [1], + } + ) + result = merge_asof(df, df, on="timestamp", tolerance=Timedelta("1s")) + expected = pd.DataFrame( + { + "timestamp": ser, + "value_x": [1], + "value_y": [1], + } + ) + tm.assert_frame_equal(result, expected) + + +def test_merge_asof_read_only_ndarray(): + # GH 53513 + left = pd.Series([2], index=[2], name="left") + right = pd.Series([1], index=[1], name="right") + # set to read-only + left.index.values.flags.writeable = False + right.index.values.flags.writeable = False + result = merge_asof(left, right, left_index=True, right_index=True) + expected = pd.DataFrame({"left": [2], "right": [1]}, index=[2]) + tm.assert_frame_equal(result, expected) + + +def test_merge_asof_multiby_with_categorical(): + # GH 43541 + left = pd.DataFrame( + { + "c1": pd.Categorical(["a", "a", "b", "b"], categories=["a", "b"]), + "c2": ["x"] * 4, + "t": [1] * 4, + "v": range(4), + } + ) + right = pd.DataFrame( + { + "c1": pd.Categorical(["b", "b"], categories=["b", "a"]), + "c2": ["x"] * 2, + "t": [1, 2], + "v": range(2), + } + ) + result = merge_asof( + left, + right, + by=["c1", "c2"], + on="t", + direction="forward", + suffixes=["_left", "_right"], + ) + expected = pd.DataFrame( + { + "c1": pd.Categorical(["a", "a", "b", "b"], categories=["a", "b"]), + "c2": ["x"] * 4, + "t": [1] * 4, + "v_left": range(4), + "v_right": [np.nan, np.nan, 0.0, 0.0], + } + ) + tm.assert_frame_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/merge/test_merge_cross.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/merge/test_merge_cross.py new file mode 100644 index 0000000000000000000000000000000000000000..14f9036e43fce13580916dacfd6409e4639d025e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/merge/test_merge_cross.py @@ -0,0 +1,111 @@ +import pytest + +from pandas import ( + DataFrame, + Series, +) +import pandas._testing as tm +from pandas.core.reshape.merge import ( + MergeError, + merge, +) + + +@pytest.mark.parametrize( + ("input_col", "output_cols"), [("b", ["a", "b"]), ("a", ["a_x", "a_y"])] +) +def test_merge_cross(input_col, output_cols): + # GH#5401 + left = DataFrame({"a": [1, 3]}) + right = DataFrame({input_col: [3, 4]}) + left_copy = left.copy() + right_copy = right.copy() + result = merge(left, right, how="cross") + expected = DataFrame({output_cols[0]: [1, 1, 3, 3], output_cols[1]: [3, 4, 3, 4]}) + tm.assert_frame_equal(result, expected) + tm.assert_frame_equal(left, left_copy) + tm.assert_frame_equal(right, right_copy) + + +@pytest.mark.parametrize( + "kwargs", + [ + {"left_index": True}, + {"right_index": True}, + {"on": "a"}, + {"left_on": "a"}, + {"right_on": "b"}, + ], +) +def test_merge_cross_error_reporting(kwargs): + # GH#5401 + left = DataFrame({"a": [1, 3]}) + right = DataFrame({"b": [3, 4]}) + msg = ( + "Can not pass on, right_on, left_on or set right_index=True or " + "left_index=True" + ) + with pytest.raises(MergeError, match=msg): + merge(left, right, how="cross", **kwargs) + + +def test_merge_cross_mixed_dtypes(): + # GH#5401 + left = DataFrame(["a", "b", "c"], columns=["A"]) + right = DataFrame(range(2), columns=["B"]) + result = merge(left, right, how="cross") + expected = DataFrame({"A": ["a", "a", "b", "b", "c", "c"], "B": [0, 1, 0, 1, 0, 1]}) + tm.assert_frame_equal(result, expected) + + +def test_merge_cross_more_than_one_column(): + # GH#5401 + left = DataFrame({"A": list("ab"), "B": [2, 1]}) + right = DataFrame({"C": range(2), "D": range(4, 6)}) + result = merge(left, right, how="cross") + expected = DataFrame( + { + "A": ["a", "a", "b", "b"], + "B": [2, 2, 1, 1], + "C": [0, 1, 0, 1], + "D": [4, 5, 4, 5], + } + ) + tm.assert_frame_equal(result, expected) + + +def test_merge_cross_null_values(nulls_fixture): + # GH#5401 + left = DataFrame({"a": [1, nulls_fixture]}) + right = DataFrame({"b": ["a", "b"], "c": [1.0, 2.0]}) + result = merge(left, right, how="cross") + expected = DataFrame( + { + "a": [1, 1, nulls_fixture, nulls_fixture], + "b": ["a", "b", "a", "b"], + "c": [1.0, 2.0, 1.0, 2.0], + } + ) + tm.assert_frame_equal(result, expected) + + +def test_join_cross_error_reporting(): + # GH#5401 + left = DataFrame({"a": [1, 3]}) + right = DataFrame({"a": [3, 4]}) + msg = ( + "Can not pass on, right_on, left_on or set right_index=True or " + "left_index=True" + ) + with pytest.raises(MergeError, match=msg): + left.join(right, how="cross", on="a") + + +def test_merge_cross_series(): + # GH#54055 + ls = Series([1, 2, 3, 4], index=[1, 2, 3, 4], name="left") + rs = Series([3, 4, 5, 6], index=[3, 4, 5, 6], name="right") + res = merge(ls, rs, how="cross") + + expected = merge(ls.to_frame(), rs.to_frame(), how="cross") + tm.assert_frame_equal(res, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/merge/test_merge_ordered.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/merge/test_merge_ordered.py new file mode 100644 index 0000000000000000000000000000000000000000..0bd3ca3cf2c1bd8ceae38d8a04cd6d48c4c69f68 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/merge/test_merge_ordered.py @@ -0,0 +1,244 @@ +import re + +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + DataFrame, + merge_ordered, +) +import pandas._testing as tm + + +@pytest.fixture +def left(): + return DataFrame({"key": ["a", "c", "e"], "lvalue": [1, 2.0, 3]}) + + +@pytest.fixture +def right(): + return DataFrame({"key": ["b", "c", "d", "f"], "rvalue": [1, 2, 3.0, 4]}) + + +class TestMergeOrdered: + def test_basic(self, left, right): + result = merge_ordered(left, right, on="key") + expected = DataFrame( + { + "key": ["a", "b", "c", "d", "e", "f"], + "lvalue": [1, np.nan, 2, np.nan, 3, np.nan], + "rvalue": [np.nan, 1, 2, 3, np.nan, 4], + } + ) + + tm.assert_frame_equal(result, expected) + + def test_ffill(self, left, right): + result = merge_ordered(left, right, on="key", fill_method="ffill") + expected = DataFrame( + { + "key": ["a", "b", "c", "d", "e", "f"], + "lvalue": [1.0, 1, 2, 2, 3, 3.0], + "rvalue": [np.nan, 1, 2, 3, 3, 4], + } + ) + tm.assert_frame_equal(result, expected) + + def test_multigroup(self, left, right): + left = pd.concat([left, left], ignore_index=True) + + left["group"] = ["a"] * 3 + ["b"] * 3 + + result = merge_ordered( + left, right, on="key", left_by="group", fill_method="ffill" + ) + expected = DataFrame( + { + "key": ["a", "b", "c", "d", "e", "f"] * 2, + "lvalue": [1.0, 1, 2, 2, 3, 3.0] * 2, + "rvalue": [np.nan, 1, 2, 3, 3, 4] * 2, + } + ) + expected["group"] = ["a"] * 6 + ["b"] * 6 + + tm.assert_frame_equal(result, expected.loc[:, result.columns]) + + result2 = merge_ordered( + right, left, on="key", right_by="group", fill_method="ffill" + ) + tm.assert_frame_equal(result, result2.loc[:, result.columns]) + + result = merge_ordered(left, right, on="key", left_by="group") + assert result["group"].notna().all() + + @pytest.mark.filterwarnings( + "ignore:Passing a BlockManager|Passing a SingleBlockManager:DeprecationWarning" + ) + def test_merge_type(self, left, right): + class NotADataFrame(DataFrame): + @property + def _constructor(self): + return NotADataFrame + + nad = NotADataFrame(left) + result = nad.merge(right, on="key") + + assert isinstance(result, NotADataFrame) + + @pytest.mark.parametrize( + "df_seq, pattern", + [ + ((), "[Nn]o objects"), + ([], "[Nn]o objects"), + ({}, "[Nn]o objects"), + ([None], "objects.*None"), + ([None, None], "objects.*None"), + ], + ) + def test_empty_sequence_concat(self, df_seq, pattern): + # GH 9157 + with pytest.raises(ValueError, match=pattern): + pd.concat(df_seq) + + @pytest.mark.parametrize( + "arg", [[DataFrame()], [None, DataFrame()], [DataFrame(), None]] + ) + def test_empty_sequence_concat_ok(self, arg): + pd.concat(arg) + + def test_doc_example(self): + left = DataFrame( + { + "group": list("aaabbb"), + "key": ["a", "c", "e", "a", "c", "e"], + "lvalue": [1, 2, 3] * 2, + } + ) + + right = DataFrame({"key": ["b", "c", "d"], "rvalue": [1, 2, 3]}) + + result = merge_ordered(left, right, fill_method="ffill", left_by="group") + + expected = DataFrame( + { + "group": list("aaaaabbbbb"), + "key": ["a", "b", "c", "d", "e"] * 2, + "lvalue": [1, 1, 2, 2, 3] * 2, + "rvalue": [np.nan, 1, 2, 3, 3] * 2, + } + ) + + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "left, right, on, left_by, right_by, expected", + [ + ( + DataFrame({"G": ["g", "g"], "H": ["h", "h"], "T": [1, 3]}), + DataFrame({"T": [2], "E": [1]}), + ["T"], + ["G", "H"], + None, + DataFrame( + { + "G": ["g"] * 3, + "H": ["h"] * 3, + "T": [1, 2, 3], + "E": [np.nan, 1.0, np.nan], + } + ), + ), + ( + DataFrame({"G": ["g", "g"], "H": ["h", "h"], "T": [1, 3]}), + DataFrame({"T": [2], "E": [1]}), + "T", + ["G", "H"], + None, + DataFrame( + { + "G": ["g"] * 3, + "H": ["h"] * 3, + "T": [1, 2, 3], + "E": [np.nan, 1.0, np.nan], + } + ), + ), + ( + DataFrame({"T": [2], "E": [1]}), + DataFrame({"G": ["g", "g"], "H": ["h", "h"], "T": [1, 3]}), + ["T"], + None, + ["G", "H"], + DataFrame( + { + "T": [1, 2, 3], + "E": [np.nan, 1.0, np.nan], + "G": ["g"] * 3, + "H": ["h"] * 3, + } + ), + ), + ], + ) + def test_list_type_by(self, left, right, on, left_by, right_by, expected): + # GH 35269 + result = merge_ordered( + left=left, + right=right, + on=on, + left_by=left_by, + right_by=right_by, + ) + + tm.assert_frame_equal(result, expected) + + def test_left_by_length_equals_to_right_shape0(self): + # GH 38166 + left = DataFrame([["g", "h", 1], ["g", "h", 3]], columns=list("GHE")) + right = DataFrame([[2, 1]], columns=list("ET")) + result = merge_ordered(left, right, on="E", left_by=["G", "H"]) + expected = DataFrame( + {"G": ["g"] * 3, "H": ["h"] * 3, "E": [1, 2, 3], "T": [np.nan, 1.0, np.nan]} + ) + + tm.assert_frame_equal(result, expected) + + def test_elements_not_in_by_but_in_df(self): + # GH 38167 + left = DataFrame([["g", "h", 1], ["g", "h", 3]], columns=list("GHE")) + right = DataFrame([[2, 1]], columns=list("ET")) + msg = r"\{'h'\} not found in left columns" + with pytest.raises(KeyError, match=msg): + merge_ordered(left, right, on="E", left_by=["G", "h"]) + + @pytest.mark.parametrize("invalid_method", ["linear", "carrot"]) + def test_ffill_validate_fill_method(self, left, right, invalid_method): + # GH 55884 + with pytest.raises( + ValueError, match=re.escape("fill_method must be 'ffill' or None") + ): + merge_ordered(left, right, on="key", fill_method=invalid_method) + + def test_ffill_left_merge(self): + # GH 57010 + df1 = DataFrame( + { + "key": ["a", "c", "e", "a", "c", "e"], + "lvalue": [1, 2, 3, 1, 2, 3], + "group": ["a", "a", "a", "b", "b", "b"], + } + ) + df2 = DataFrame({"key": ["b", "c", "d"], "rvalue": [1, 2, 3]}) + result = merge_ordered( + df1, df2, fill_method="ffill", left_by="group", how="left" + ) + expected = DataFrame( + { + "key": ["a", "c", "e", "a", "c", "e"], + "lvalue": [1, 2, 3, 1, 2, 3], + "group": ["a", "a", "a", "b", "b", "b"], + "rvalue": [np.nan, 2.0, 2.0, np.nan, 2.0, 2.0], + } + ) + tm.assert_frame_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/test_crosstab.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/test_crosstab.py new file mode 100644 index 0000000000000000000000000000000000000000..136e76986df9d86afbf5a6ca8e3650d7745dae3a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/test_crosstab.py @@ -0,0 +1,886 @@ +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + CategoricalDtype, + CategoricalIndex, + DataFrame, + Index, + MultiIndex, + Series, + crosstab, +) +import pandas._testing as tm + + +@pytest.fixture +def df(): + df = DataFrame( + { + "A": [ + "foo", + "foo", + "foo", + "foo", + "bar", + "bar", + "bar", + "bar", + "foo", + "foo", + "foo", + ], + "B": [ + "one", + "one", + "one", + "two", + "one", + "one", + "one", + "two", + "two", + "two", + "one", + ], + "C": [ + "dull", + "dull", + "shiny", + "dull", + "dull", + "shiny", + "shiny", + "dull", + "shiny", + "shiny", + "shiny", + ], + "D": np.random.default_rng(2).standard_normal(11), + "E": np.random.default_rng(2).standard_normal(11), + "F": np.random.default_rng(2).standard_normal(11), + } + ) + + return pd.concat([df, df], ignore_index=True) + + +class TestCrosstab: + def test_crosstab_single(self, df): + result = crosstab(df["A"], df["C"]) + expected = df.groupby(["A", "C"]).size().unstack() + tm.assert_frame_equal(result, expected.fillna(0).astype(np.int64)) + + def test_crosstab_multiple(self, df): + result = crosstab(df["A"], [df["B"], df["C"]]) + expected = df.groupby(["A", "B", "C"]).size() + expected = expected.unstack("B").unstack("C").fillna(0).astype(np.int64) + tm.assert_frame_equal(result, expected) + + result = crosstab([df["B"], df["C"]], df["A"]) + expected = df.groupby(["B", "C", "A"]).size() + expected = expected.unstack("A").fillna(0).astype(np.int64) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("box", [np.array, list, tuple]) + def test_crosstab_ndarray(self, box): + # GH 44076 + a = box(np.random.default_rng(2).integers(0, 5, size=100)) + b = box(np.random.default_rng(2).integers(0, 3, size=100)) + c = box(np.random.default_rng(2).integers(0, 10, size=100)) + + df = DataFrame({"a": a, "b": b, "c": c}) + + result = crosstab(a, [b, c], rownames=["a"], colnames=("b", "c")) + expected = crosstab(df["a"], [df["b"], df["c"]]) + tm.assert_frame_equal(result, expected) + + result = crosstab([b, c], a, colnames=["a"], rownames=("b", "c")) + expected = crosstab([df["b"], df["c"]], df["a"]) + tm.assert_frame_equal(result, expected) + + # assign arbitrary names + result = crosstab(a, c) + expected = crosstab(df["a"], df["c"]) + expected.index.names = ["row_0"] + expected.columns.names = ["col_0"] + tm.assert_frame_equal(result, expected) + + def test_crosstab_non_aligned(self): + # GH 17005 + a = Series([0, 1, 1], index=["a", "b", "c"]) + b = Series([3, 4, 3, 4, 3], index=["a", "b", "c", "d", "f"]) + c = np.array([3, 4, 3], dtype=np.int64) + + expected = DataFrame( + [[1, 0], [1, 1]], + index=Index([0, 1], name="row_0"), + columns=Index([3, 4], name="col_0"), + ) + + result = crosstab(a, b) + tm.assert_frame_equal(result, expected) + + result = crosstab(a, c) + tm.assert_frame_equal(result, expected) + + def test_crosstab_margins(self): + a = np.random.default_rng(2).integers(0, 7, size=100) + b = np.random.default_rng(2).integers(0, 3, size=100) + c = np.random.default_rng(2).integers(0, 5, size=100) + + df = DataFrame({"a": a, "b": b, "c": c}) + + result = crosstab(a, [b, c], rownames=["a"], colnames=("b", "c"), margins=True) + + assert result.index.names == ("a",) + assert result.columns.names == ["b", "c"] + + all_cols = result["All", ""] + exp_cols = df.groupby(["a"]).size().astype("i8") + # to keep index.name + exp_margin = Series([len(df)], index=Index(["All"], name="a")) + exp_cols = pd.concat([exp_cols, exp_margin]) + exp_cols.name = ("All", "") + + tm.assert_series_equal(all_cols, exp_cols) + + all_rows = result.loc["All"] + exp_rows = df.groupby(["b", "c"]).size().astype("i8") + exp_rows = pd.concat([exp_rows, Series([len(df)], index=[("All", "")])]) + exp_rows.name = "All" + + exp_rows = exp_rows.reindex(all_rows.index) + exp_rows = exp_rows.fillna(0).astype(np.int64) + tm.assert_series_equal(all_rows, exp_rows) + + def test_crosstab_margins_set_margin_name(self): + # GH 15972 + a = np.random.default_rng(2).integers(0, 7, size=100) + b = np.random.default_rng(2).integers(0, 3, size=100) + c = np.random.default_rng(2).integers(0, 5, size=100) + + df = DataFrame({"a": a, "b": b, "c": c}) + + result = crosstab( + a, + [b, c], + rownames=["a"], + colnames=("b", "c"), + margins=True, + margins_name="TOTAL", + ) + + assert result.index.names == ("a",) + assert result.columns.names == ["b", "c"] + + all_cols = result["TOTAL", ""] + exp_cols = df.groupby(["a"]).size().astype("i8") + # to keep index.name + exp_margin = Series([len(df)], index=Index(["TOTAL"], name="a")) + exp_cols = pd.concat([exp_cols, exp_margin]) + exp_cols.name = ("TOTAL", "") + + tm.assert_series_equal(all_cols, exp_cols) + + all_rows = result.loc["TOTAL"] + exp_rows = df.groupby(["b", "c"]).size().astype("i8") + exp_rows = pd.concat([exp_rows, Series([len(df)], index=[("TOTAL", "")])]) + exp_rows.name = "TOTAL" + + exp_rows = exp_rows.reindex(all_rows.index) + exp_rows = exp_rows.fillna(0).astype(np.int64) + tm.assert_series_equal(all_rows, exp_rows) + + msg = "margins_name argument must be a string" + for margins_name in [666, None, ["a", "b"]]: + with pytest.raises(ValueError, match=msg): + crosstab( + a, + [b, c], + rownames=["a"], + colnames=("b", "c"), + margins=True, + margins_name=margins_name, + ) + + def test_crosstab_pass_values(self): + a = np.random.default_rng(2).integers(0, 7, size=100) + b = np.random.default_rng(2).integers(0, 3, size=100) + c = np.random.default_rng(2).integers(0, 5, size=100) + values = np.random.default_rng(2).standard_normal(100) + + table = crosstab( + [a, b], c, values, aggfunc="sum", rownames=["foo", "bar"], colnames=["baz"] + ) + + df = DataFrame({"foo": a, "bar": b, "baz": c, "values": values}) + + expected = df.pivot_table( + "values", index=["foo", "bar"], columns="baz", aggfunc="sum" + ) + tm.assert_frame_equal(table, expected) + + def test_crosstab_dropna(self): + # GH 3820 + a = np.array(["foo", "foo", "foo", "bar", "bar", "foo", "foo"], dtype=object) + b = np.array(["one", "one", "two", "one", "two", "two", "two"], dtype=object) + c = np.array( + ["dull", "dull", "dull", "dull", "dull", "shiny", "shiny"], dtype=object + ) + res = crosstab(a, [b, c], rownames=["a"], colnames=["b", "c"], dropna=False) + m = MultiIndex.from_tuples( + [("one", "dull"), ("one", "shiny"), ("two", "dull"), ("two", "shiny")], + names=["b", "c"], + ) + tm.assert_index_equal(res.columns, m) + + def test_crosstab_no_overlap(self): + # GS 10291 + + s1 = Series([1, 2, 3], index=[1, 2, 3]) + s2 = Series([4, 5, 6], index=[4, 5, 6]) + + actual = crosstab(s1, s2) + expected = DataFrame( + index=Index([], dtype="int64", name="row_0"), + columns=Index([], dtype="int64", name="col_0"), + ) + + tm.assert_frame_equal(actual, expected) + + def test_margin_dropna(self): + # GH 12577 + # pivot_table counts null into margin ('All') + # when margins=true and dropna=true + + df = DataFrame({"a": [1, 2, 2, 2, 2, np.nan], "b": [3, 3, 4, 4, 4, 4]}) + actual = crosstab(df.a, df.b, margins=True, dropna=True) + expected = DataFrame([[1, 0, 1], [1, 3, 4], [2, 3, 5]]) + expected.index = Index([1.0, 2.0, "All"], name="a") + expected.columns = Index([3, 4, "All"], name="b") + tm.assert_frame_equal(actual, expected) + + def test_margin_dropna2(self): + df = DataFrame( + {"a": [1, np.nan, np.nan, np.nan, 2, np.nan], "b": [3, np.nan, 4, 4, 4, 4]} + ) + actual = crosstab(df.a, df.b, margins=True, dropna=True) + expected = DataFrame([[1, 0, 1], [0, 1, 1], [1, 1, 2]]) + expected.index = Index([1.0, 2.0, "All"], name="a") + expected.columns = Index([3.0, 4.0, "All"], name="b") + tm.assert_frame_equal(actual, expected) + + def test_margin_dropna3(self): + df = DataFrame( + {"a": [1, np.nan, np.nan, np.nan, np.nan, 2], "b": [3, 3, 4, 4, 4, 4]} + ) + actual = crosstab(df.a, df.b, margins=True, dropna=True) + expected = DataFrame([[1, 0, 1], [0, 1, 1], [1, 1, 2]]) + expected.index = Index([1.0, 2.0, "All"], name="a") + expected.columns = Index([3, 4, "All"], name="b") + tm.assert_frame_equal(actual, expected) + + def test_margin_dropna4(self): + # GH 12642 + # _add_margins raises KeyError: Level None not found + # when margins=True and dropna=False + # GH: 10772: Keep np.nan in result with dropna=False + df = DataFrame({"a": [1, 2, 2, 2, 2, np.nan], "b": [3, 3, 4, 4, 4, 4]}) + actual = crosstab(df.a, df.b, margins=True, dropna=False) + expected = DataFrame([[1, 0, 1.0], [1, 3, 4.0], [0, 1, np.nan], [2, 4, 6.0]]) + expected.index = Index([1.0, 2.0, np.nan, "All"], name="a") + expected.columns = Index([3, 4, "All"], name="b") + tm.assert_frame_equal(actual, expected) + + def test_margin_dropna5(self): + # GH: 10772: Keep np.nan in result with dropna=False + df = DataFrame( + {"a": [1, np.nan, np.nan, np.nan, 2, np.nan], "b": [3, np.nan, 4, 4, 4, 4]} + ) + actual = crosstab(df.a, df.b, margins=True, dropna=False) + expected = DataFrame( + [[1, 0, 0, 1.0], [0, 1, 0, 1.0], [0, 3, 1, np.nan], [1, 4, 0, 6.0]] + ) + expected.index = Index([1.0, 2.0, np.nan, "All"], name="a") + expected.columns = Index([3.0, 4.0, np.nan, "All"], name="b") + tm.assert_frame_equal(actual, expected) + + def test_margin_dropna6(self): + # GH: 10772: Keep np.nan in result with dropna=False + a = np.array(["foo", "foo", "foo", "bar", "bar", "foo", "foo"], dtype=object) + b = np.array(["one", "one", "two", "one", "two", np.nan, "two"], dtype=object) + c = np.array( + ["dull", "dull", "dull", "dull", "dull", "shiny", "shiny"], dtype=object + ) + + actual = crosstab( + a, [b, c], rownames=["a"], colnames=["b", "c"], margins=True, dropna=False + ) + m = MultiIndex.from_arrays( + [ + ["one", "one", "two", "two", np.nan, np.nan, "All"], + ["dull", "shiny", "dull", "shiny", "dull", "shiny", ""], + ], + names=["b", "c"], + ) + expected = DataFrame( + [[1, 0, 1, 0, 0, 0, 2], [2, 0, 1, 1, 0, 1, 5], [3, 0, 2, 1, 0, 0, 7]], + columns=m, + ) + expected.index = Index(["bar", "foo", "All"], name="a") + tm.assert_frame_equal(actual, expected) + + actual = crosstab( + [a, b], c, rownames=["a", "b"], colnames=["c"], margins=True, dropna=False + ) + m = MultiIndex.from_arrays( + [ + ["bar", "bar", "bar", "foo", "foo", "foo", "All"], + ["one", "two", np.nan, "one", "two", np.nan, ""], + ], + names=["a", "b"], + ) + expected = DataFrame( + [ + [1, 0, 1.0], + [1, 0, 1.0], + [0, 0, np.nan], + [2, 0, 2.0], + [1, 1, 2.0], + [0, 1, np.nan], + [5, 2, 7.0], + ], + index=m, + ) + expected.columns = Index(["dull", "shiny", "All"], name="c") + tm.assert_frame_equal(actual, expected) + + actual = crosstab( + [a, b], c, rownames=["a", "b"], colnames=["c"], margins=True, dropna=True + ) + m = MultiIndex.from_arrays( + [["bar", "bar", "foo", "foo", "All"], ["one", "two", "one", "two", ""]], + names=["a", "b"], + ) + expected = DataFrame( + [[1, 0, 1], [1, 0, 1], [2, 0, 2], [1, 1, 2], [5, 1, 6]], index=m + ) + expected.columns = Index(["dull", "shiny", "All"], name="c") + tm.assert_frame_equal(actual, expected) + + def test_crosstab_normalize(self): + # Issue 12578 + df = DataFrame( + {"a": [1, 2, 2, 2, 2], "b": [3, 3, 4, 4, 4], "c": [1, 1, np.nan, 1, 1]} + ) + + rindex = Index([1, 2], name="a") + cindex = Index([3, 4], name="b") + full_normal = DataFrame([[0.2, 0], [0.2, 0.6]], index=rindex, columns=cindex) + row_normal = DataFrame([[1.0, 0], [0.25, 0.75]], index=rindex, columns=cindex) + col_normal = DataFrame([[0.5, 0], [0.5, 1.0]], index=rindex, columns=cindex) + + # Check all normalize args + tm.assert_frame_equal(crosstab(df.a, df.b, normalize="all"), full_normal) + tm.assert_frame_equal(crosstab(df.a, df.b, normalize=True), full_normal) + tm.assert_frame_equal(crosstab(df.a, df.b, normalize="index"), row_normal) + tm.assert_frame_equal(crosstab(df.a, df.b, normalize="columns"), col_normal) + tm.assert_frame_equal( + crosstab(df.a, df.b, normalize=1), + crosstab(df.a, df.b, normalize="columns"), + ) + tm.assert_frame_equal( + crosstab(df.a, df.b, normalize=0), crosstab(df.a, df.b, normalize="index") + ) + + row_normal_margins = DataFrame( + [[1.0, 0], [0.25, 0.75], [0.4, 0.6]], + index=Index([1, 2, "All"], name="a", dtype="object"), + columns=Index([3, 4], name="b", dtype="object"), + ) + col_normal_margins = DataFrame( + [[0.5, 0, 0.2], [0.5, 1.0, 0.8]], + index=Index([1, 2], name="a", dtype="object"), + columns=Index([3, 4, "All"], name="b", dtype="object"), + ) + + all_normal_margins = DataFrame( + [[0.2, 0, 0.2], [0.2, 0.6, 0.8], [0.4, 0.6, 1]], + index=Index([1, 2, "All"], name="a", dtype="object"), + columns=Index([3, 4, "All"], name="b", dtype="object"), + ) + tm.assert_frame_equal( + crosstab(df.a, df.b, normalize="index", margins=True), row_normal_margins + ) + tm.assert_frame_equal( + crosstab(df.a, df.b, normalize="columns", margins=True), col_normal_margins + ) + tm.assert_frame_equal( + crosstab(df.a, df.b, normalize=True, margins=True), all_normal_margins + ) + + def test_crosstab_normalize_arrays(self): + # GH#12578 + df = DataFrame( + {"a": [1, 2, 2, 2, 2], "b": [3, 3, 4, 4, 4], "c": [1, 1, np.nan, 1, 1]} + ) + + # Test arrays + crosstab( + [np.array([1, 1, 2, 2]), np.array([1, 2, 1, 2])], np.array([1, 2, 1, 2]) + ) + + # Test with aggfunc + norm_counts = DataFrame( + [[0.25, 0, 0.25], [0.25, 0.5, 0.75], [0.5, 0.5, 1]], + index=Index([1, 2, "All"], name="a", dtype="object"), + columns=Index([3, 4, "All"], name="b"), + ) + test_case = crosstab( + df.a, df.b, df.c, aggfunc="count", normalize="all", margins=True + ) + tm.assert_frame_equal(test_case, norm_counts) + + df = DataFrame( + {"a": [1, 2, 2, 2, 2], "b": [3, 3, 4, 4, 4], "c": [0, 4, np.nan, 3, 3]} + ) + + norm_sum = DataFrame( + [[0, 0, 0.0], [0.4, 0.6, 1], [0.4, 0.6, 1]], + index=Index([1, 2, "All"], name="a", dtype="object"), + columns=Index([3, 4, "All"], name="b", dtype="object"), + ) + msg = "using DataFrameGroupBy.sum" + with tm.assert_produces_warning(FutureWarning, match=msg): + test_case = crosstab( + df.a, df.b, df.c, aggfunc=np.sum, normalize="all", margins=True + ) + tm.assert_frame_equal(test_case, norm_sum) + + def test_crosstab_with_empties(self, using_array_manager): + # Check handling of empties + df = DataFrame( + { + "a": [1, 2, 2, 2, 2], + "b": [3, 3, 4, 4, 4], + "c": [np.nan, np.nan, np.nan, np.nan, np.nan], + } + ) + + empty = DataFrame( + [[0.0, 0.0], [0.0, 0.0]], + index=Index([1, 2], name="a", dtype="int64"), + columns=Index([3, 4], name="b"), + ) + + for i in [True, "index", "columns"]: + calculated = crosstab(df.a, df.b, values=df.c, aggfunc="count", normalize=i) + tm.assert_frame_equal(empty, calculated) + + nans = DataFrame( + [[0.0, np.nan], [0.0, 0.0]], + index=Index([1, 2], name="a", dtype="int64"), + columns=Index([3, 4], name="b"), + ) + if using_array_manager: + # INFO(ArrayManager) column without NaNs can preserve int dtype + nans[3] = nans[3].astype("int64") + + calculated = crosstab(df.a, df.b, values=df.c, aggfunc="count", normalize=False) + tm.assert_frame_equal(nans, calculated) + + def test_crosstab_errors(self): + # Issue 12578 + + df = DataFrame( + {"a": [1, 2, 2, 2, 2], "b": [3, 3, 4, 4, 4], "c": [1, 1, np.nan, 1, 1]} + ) + + error = "values cannot be used without an aggfunc." + with pytest.raises(ValueError, match=error): + crosstab(df.a, df.b, values=df.c) + + error = "aggfunc cannot be used without values" + with pytest.raises(ValueError, match=error): + crosstab(df.a, df.b, aggfunc=np.mean) + + error = "Not a valid normalize argument" + with pytest.raises(ValueError, match=error): + crosstab(df.a, df.b, normalize="42") + + with pytest.raises(ValueError, match=error): + crosstab(df.a, df.b, normalize=42) + + error = "Not a valid margins argument" + with pytest.raises(ValueError, match=error): + crosstab(df.a, df.b, normalize="all", margins=42) + + def test_crosstab_with_categorial_columns(self): + # GH 8860 + df = DataFrame( + { + "MAKE": ["Honda", "Acura", "Tesla", "Honda", "Honda", "Acura"], + "MODEL": ["Sedan", "Sedan", "Electric", "Pickup", "Sedan", "Sedan"], + } + ) + categories = ["Sedan", "Electric", "Pickup"] + df["MODEL"] = df["MODEL"].astype("category").cat.set_categories(categories) + result = crosstab(df["MAKE"], df["MODEL"]) + + expected_index = Index(["Acura", "Honda", "Tesla"], name="MAKE") + expected_columns = CategoricalIndex( + categories, categories=categories, ordered=False, name="MODEL" + ) + expected_data = [[2, 0, 0], [2, 0, 1], [0, 1, 0]] + expected = DataFrame( + expected_data, index=expected_index, columns=expected_columns + ) + tm.assert_frame_equal(result, expected) + + def test_crosstab_with_numpy_size(self): + # GH 4003 + df = DataFrame( + { + "A": ["one", "one", "two", "three"] * 6, + "B": ["A", "B", "C"] * 8, + "C": ["foo", "foo", "foo", "bar", "bar", "bar"] * 4, + "D": np.random.default_rng(2).standard_normal(24), + "E": np.random.default_rng(2).standard_normal(24), + } + ) + result = crosstab( + index=[df["A"], df["B"]], + columns=[df["C"]], + margins=True, + aggfunc=np.size, + values=df["D"], + ) + expected_index = MultiIndex( + levels=[["All", "one", "three", "two"], ["", "A", "B", "C"]], + codes=[[1, 1, 1, 2, 2, 2, 3, 3, 3, 0], [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]], + names=["A", "B"], + ) + expected_column = Index(["bar", "foo", "All"], name="C") + expected_data = np.array( + [ + [2.0, 2.0, 4.0], + [2.0, 2.0, 4.0], + [2.0, 2.0, 4.0], + [2.0, np.nan, 2.0], + [np.nan, 2.0, 2.0], + [2.0, np.nan, 2.0], + [np.nan, 2.0, 2.0], + [2.0, np.nan, 2.0], + [np.nan, 2.0, 2.0], + [12.0, 12.0, 24.0], + ] + ) + expected = DataFrame( + expected_data, index=expected_index, columns=expected_column + ) + # aggfunc is np.size, resulting in integers + expected["All"] = expected["All"].astype("int64") + tm.assert_frame_equal(result, expected) + + def test_crosstab_duplicate_names(self): + # GH 13279 / 22529 + + s1 = Series(range(3), name="foo") + s2_foo = Series(range(1, 4), name="foo") + s2_bar = Series(range(1, 4), name="bar") + s3 = Series(range(3), name="waldo") + + # check result computed with duplicate labels against + # result computed with unique labels, then relabelled + mapper = {"bar": "foo"} + + # duplicate row, column labels + result = crosstab(s1, s2_foo) + expected = crosstab(s1, s2_bar).rename_axis(columns=mapper, axis=1) + tm.assert_frame_equal(result, expected) + + # duplicate row, unique column labels + result = crosstab([s1, s2_foo], s3) + expected = crosstab([s1, s2_bar], s3).rename_axis(index=mapper, axis=0) + tm.assert_frame_equal(result, expected) + + # unique row, duplicate column labels + result = crosstab(s3, [s1, s2_foo]) + expected = crosstab(s3, [s1, s2_bar]).rename_axis(columns=mapper, axis=1) + + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("names", [["a", ("b", "c")], [("a", "b"), "c"]]) + def test_crosstab_tuple_name(self, names): + s1 = Series(range(3), name=names[0]) + s2 = Series(range(1, 4), name=names[1]) + + mi = MultiIndex.from_arrays([range(3), range(1, 4)], names=names) + expected = Series(1, index=mi).unstack(1, fill_value=0) + + result = crosstab(s1, s2) + tm.assert_frame_equal(result, expected) + + def test_crosstab_both_tuple_names(self): + # GH 18321 + s1 = Series(range(3), name=("a", "b")) + s2 = Series(range(3), name=("c", "d")) + + expected = DataFrame( + np.eye(3, dtype="int64"), + index=Index(range(3), name=("a", "b")), + columns=Index(range(3), name=("c", "d")), + ) + result = crosstab(s1, s2) + tm.assert_frame_equal(result, expected) + + def test_crosstab_unsorted_order(self): + df = DataFrame({"b": [3, 1, 2], "a": [5, 4, 6]}, index=["C", "A", "B"]) + result = crosstab(df.index, [df.b, df.a]) + e_idx = Index(["A", "B", "C"], name="row_0") + e_columns = MultiIndex.from_tuples([(1, 4), (2, 6), (3, 5)], names=["b", "a"]) + expected = DataFrame( + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], index=e_idx, columns=e_columns + ) + tm.assert_frame_equal(result, expected) + + def test_crosstab_normalize_multiple_columns(self): + # GH 15150 + df = DataFrame( + { + "A": ["one", "one", "two", "three"] * 6, + "B": ["A", "B", "C"] * 8, + "C": ["foo", "foo", "foo", "bar", "bar", "bar"] * 4, + "D": [0] * 24, + "E": [0] * 24, + } + ) + + msg = "using DataFrameGroupBy.sum" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = crosstab( + [df.A, df.B], + df.C, + values=df.D, + aggfunc=np.sum, + normalize=True, + margins=True, + ) + expected = DataFrame( + np.array([0] * 29 + [1], dtype=float).reshape(10, 3), + columns=Index(["bar", "foo", "All"], name="C"), + index=MultiIndex.from_tuples( + [ + ("one", "A"), + ("one", "B"), + ("one", "C"), + ("three", "A"), + ("three", "B"), + ("three", "C"), + ("two", "A"), + ("two", "B"), + ("two", "C"), + ("All", ""), + ], + names=["A", "B"], + ), + ) + tm.assert_frame_equal(result, expected) + + def test_margin_normalize(self): + # GH 27500 + df = DataFrame( + { + "A": ["foo", "foo", "foo", "foo", "foo", "bar", "bar", "bar", "bar"], + "B": ["one", "one", "one", "two", "two", "one", "one", "two", "two"], + "C": [ + "small", + "large", + "large", + "small", + "small", + "large", + "small", + "small", + "large", + ], + "D": [1, 2, 2, 3, 3, 4, 5, 6, 7], + "E": [2, 4, 5, 5, 6, 6, 8, 9, 9], + } + ) + # normalize on index + result = crosstab( + [df.A, df.B], df.C, margins=True, margins_name="Sub-Total", normalize=0 + ) + expected = DataFrame( + [[0.5, 0.5], [0.5, 0.5], [0.666667, 0.333333], [0, 1], [0.444444, 0.555556]] + ) + expected.index = MultiIndex( + levels=[["Sub-Total", "bar", "foo"], ["", "one", "two"]], + codes=[[1, 1, 2, 2, 0], [1, 2, 1, 2, 0]], + names=["A", "B"], + ) + expected.columns = Index(["large", "small"], name="C") + tm.assert_frame_equal(result, expected) + + # normalize on columns + result = crosstab( + [df.A, df.B], df.C, margins=True, margins_name="Sub-Total", normalize=1 + ) + expected = DataFrame( + [ + [0.25, 0.2, 0.222222], + [0.25, 0.2, 0.222222], + [0.5, 0.2, 0.333333], + [0, 0.4, 0.222222], + ] + ) + expected.columns = Index(["large", "small", "Sub-Total"], name="C") + expected.index = MultiIndex( + levels=[["bar", "foo"], ["one", "two"]], + codes=[[0, 0, 1, 1], [0, 1, 0, 1]], + names=["A", "B"], + ) + tm.assert_frame_equal(result, expected) + + # normalize on both index and column + result = crosstab( + [df.A, df.B], df.C, margins=True, margins_name="Sub-Total", normalize=True + ) + expected = DataFrame( + [ + [0.111111, 0.111111, 0.222222], + [0.111111, 0.111111, 0.222222], + [0.222222, 0.111111, 0.333333], + [0.000000, 0.222222, 0.222222], + [0.444444, 0.555555, 1], + ] + ) + expected.columns = Index(["large", "small", "Sub-Total"], name="C") + expected.index = MultiIndex( + levels=[["Sub-Total", "bar", "foo"], ["", "one", "two"]], + codes=[[1, 1, 2, 2, 0], [1, 2, 1, 2, 0]], + names=["A", "B"], + ) + tm.assert_frame_equal(result, expected) + + def test_margin_normalize_multiple_columns(self): + # GH 35144 + # use multiple columns with margins and normalization + df = DataFrame( + { + "A": ["foo", "foo", "foo", "foo", "foo", "bar", "bar", "bar", "bar"], + "B": ["one", "one", "one", "two", "two", "one", "one", "two", "two"], + "C": [ + "small", + "large", + "large", + "small", + "small", + "large", + "small", + "small", + "large", + ], + "D": [1, 2, 2, 3, 3, 4, 5, 6, 7], + "E": [2, 4, 5, 5, 6, 6, 8, 9, 9], + } + ) + result = crosstab( + index=df.C, + columns=[df.A, df.B], + margins=True, + margins_name="margin", + normalize=True, + ) + expected = DataFrame( + [ + [0.111111, 0.111111, 0.222222, 0.000000, 0.444444], + [0.111111, 0.111111, 0.111111, 0.222222, 0.555556], + [0.222222, 0.222222, 0.333333, 0.222222, 1.0], + ], + index=["large", "small", "margin"], + ) + expected.columns = MultiIndex( + levels=[["bar", "foo", "margin"], ["", "one", "two"]], + codes=[[0, 0, 1, 1, 2], [1, 2, 1, 2, 0]], + names=["A", "B"], + ) + expected.index.name = "C" + tm.assert_frame_equal(result, expected) + + def test_margin_support_Float(self): + # GH 50313 + # use Float64 formats and function aggfunc with margins + df = DataFrame( + {"A": [1, 2, 2, 1], "B": [3, 3, 4, 5], "C": [-1.0, 10.0, 1.0, 10.0]}, + dtype="Float64", + ) + result = crosstab( + df["A"], + df["B"], + values=df["C"], + aggfunc="sum", + margins=True, + ) + expected = DataFrame( + [ + [-1.0, pd.NA, 10.0, 9.0], + [10.0, 1.0, pd.NA, 11.0], + [9.0, 1.0, 10.0, 20.0], + ], + index=Index([1.0, 2.0, "All"], dtype="object", name="A"), + columns=Index([3.0, 4.0, 5.0, "All"], dtype="object", name="B"), + dtype="Float64", + ) + tm.assert_frame_equal(result, expected) + + def test_margin_with_ordered_categorical_column(self): + # GH 25278 + df = DataFrame( + { + "First": ["B", "B", "C", "A", "B", "C"], + "Second": ["C", "B", "B", "B", "C", "A"], + } + ) + df["First"] = df["First"].astype(CategoricalDtype(ordered=True)) + customized_categories_order = ["C", "A", "B"] + df["First"] = df["First"].cat.reorder_categories(customized_categories_order) + result = crosstab(df["First"], df["Second"], margins=True) + + expected_index = Index(["C", "A", "B", "All"], name="First") + expected_columns = Index(["A", "B", "C", "All"], name="Second") + expected_data = [[1, 1, 0, 2], [0, 1, 0, 1], [0, 1, 2, 3], [1, 3, 2, 6]] + expected = DataFrame( + expected_data, index=expected_index, columns=expected_columns + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("a_dtype", ["category", "int64"]) +@pytest.mark.parametrize("b_dtype", ["category", "int64"]) +def test_categoricals(a_dtype, b_dtype): + # https://github.com/pandas-dev/pandas/issues/37465 + g = np.random.default_rng(2) + a = Series(g.integers(0, 3, size=100)).astype(a_dtype) + b = Series(g.integers(0, 2, size=100)).astype(b_dtype) + result = crosstab(a, b, margins=True, dropna=False) + columns = Index([0, 1, "All"], dtype="object", name="col_0") + index = Index([0, 1, 2, "All"], dtype="object", name="row_0") + values = [[10, 18, 28], [23, 16, 39], [17, 16, 33], [50, 50, 100]] + expected = DataFrame(values, index, columns) + tm.assert_frame_equal(result, expected) + + # Verify when categorical does not have all values present + a.loc[a == 1] = 2 + a_is_cat = isinstance(a.dtype, CategoricalDtype) + assert not a_is_cat or a.value_counts().loc[1] == 0 + result = crosstab(a, b, margins=True, dropna=False) + values = [[10, 18, 28], [0, 0, 0], [40, 32, 72], [50, 50, 100]] + expected = DataFrame(values, index, columns) + if not a_is_cat: + expected = expected.loc[[0, 2, "All"]] + expected["All"] = expected["All"].astype("int64") + tm.assert_frame_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/test_cut.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/test_cut.py new file mode 100644 index 0000000000000000000000000000000000000000..0811c69859c0dadf9b3b909437d9a6a16a584c24 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/test_cut.py @@ -0,0 +1,791 @@ +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + Categorical, + DataFrame, + DatetimeIndex, + Index, + Interval, + IntervalIndex, + Series, + TimedeltaIndex, + Timestamp, + cut, + date_range, + interval_range, + isna, + qcut, + timedelta_range, + to_datetime, +) +import pandas._testing as tm +from pandas.api.types import CategoricalDtype +import pandas.core.reshape.tile as tmod + + +def test_simple(): + data = np.ones(5, dtype="int64") + result = cut(data, 4, labels=False) + + expected = np.array([1, 1, 1, 1, 1]) + tm.assert_numpy_array_equal(result, expected, check_dtype=False) + + +@pytest.mark.parametrize("func", [list, np.array]) +def test_bins(func): + data = func([0.2, 1.4, 2.5, 6.2, 9.7, 2.1]) + result, bins = cut(data, 3, retbins=True) + + intervals = IntervalIndex.from_breaks(bins.round(3)) + intervals = intervals.take([0, 0, 0, 1, 2, 0]) + expected = Categorical(intervals, ordered=True) + + tm.assert_categorical_equal(result, expected) + tm.assert_almost_equal(bins, np.array([0.1905, 3.36666667, 6.53333333, 9.7])) + + +def test_right(): + data = np.array([0.2, 1.4, 2.5, 6.2, 9.7, 2.1, 2.575]) + result, bins = cut(data, 4, right=True, retbins=True) + + intervals = IntervalIndex.from_breaks(bins.round(3)) + expected = Categorical(intervals, ordered=True) + expected = expected.take([0, 0, 0, 2, 3, 0, 0]) + + tm.assert_categorical_equal(result, expected) + tm.assert_almost_equal(bins, np.array([0.1905, 2.575, 4.95, 7.325, 9.7])) + + +def test_no_right(): + data = np.array([0.2, 1.4, 2.5, 6.2, 9.7, 2.1, 2.575]) + result, bins = cut(data, 4, right=False, retbins=True) + + intervals = IntervalIndex.from_breaks(bins.round(3), closed="left") + intervals = intervals.take([0, 0, 0, 2, 3, 0, 1]) + expected = Categorical(intervals, ordered=True) + + tm.assert_categorical_equal(result, expected) + tm.assert_almost_equal(bins, np.array([0.2, 2.575, 4.95, 7.325, 9.7095])) + + +def test_bins_from_interval_index(): + c = cut(range(5), 3) + expected = c + result = cut(range(5), bins=expected.categories) + tm.assert_categorical_equal(result, expected) + + expected = Categorical.from_codes( + np.append(c.codes, -1), categories=c.categories, ordered=True + ) + result = cut(range(6), bins=expected.categories) + tm.assert_categorical_equal(result, expected) + + +def test_bins_from_interval_index_doc_example(): + # Make sure we preserve the bins. + ages = np.array([10, 15, 13, 12, 23, 25, 28, 59, 60]) + c = cut(ages, bins=[0, 18, 35, 70]) + expected = IntervalIndex.from_tuples([(0, 18), (18, 35), (35, 70)]) + tm.assert_index_equal(c.categories, expected) + + result = cut([25, 20, 50], bins=c.categories) + tm.assert_index_equal(result.categories, expected) + tm.assert_numpy_array_equal(result.codes, np.array([1, 1, 2], dtype="int8")) + + +def test_bins_not_overlapping_from_interval_index(): + # see gh-23980 + msg = "Overlapping IntervalIndex is not accepted" + ii = IntervalIndex.from_tuples([(0, 10), (2, 12), (4, 14)]) + + with pytest.raises(ValueError, match=msg): + cut([5, 6], bins=ii) + + +def test_bins_not_monotonic(): + msg = "bins must increase monotonically" + data = [0.2, 1.4, 2.5, 6.2, 9.7, 2.1] + + with pytest.raises(ValueError, match=msg): + cut(data, [0.1, 1.5, 1, 10]) + + +@pytest.mark.parametrize( + "x, bins, expected", + [ + ( + date_range("2017-12-31", periods=3), + [Timestamp.min, Timestamp("2018-01-01"), Timestamp.max], + IntervalIndex.from_tuples( + [ + (Timestamp.min, Timestamp("2018-01-01")), + (Timestamp("2018-01-01"), Timestamp.max), + ] + ), + ), + ( + [-1, 0, 1], + np.array( + [np.iinfo(np.int64).min, 0, np.iinfo(np.int64).max], dtype="int64" + ), + IntervalIndex.from_tuples( + [(np.iinfo(np.int64).min, 0), (0, np.iinfo(np.int64).max)] + ), + ), + ( + [ + np.timedelta64(-1, "ns"), + np.timedelta64(0, "ns"), + np.timedelta64(1, "ns"), + ], + np.array( + [ + np.timedelta64(-np.iinfo(np.int64).max, "ns"), + np.timedelta64(0, "ns"), + np.timedelta64(np.iinfo(np.int64).max, "ns"), + ] + ), + IntervalIndex.from_tuples( + [ + ( + np.timedelta64(-np.iinfo(np.int64).max, "ns"), + np.timedelta64(0, "ns"), + ), + ( + np.timedelta64(0, "ns"), + np.timedelta64(np.iinfo(np.int64).max, "ns"), + ), + ] + ), + ), + ], +) +def test_bins_monotonic_not_overflowing(x, bins, expected): + # GH 26045 + result = cut(x, bins) + tm.assert_index_equal(result.categories, expected) + + +def test_wrong_num_labels(): + msg = "Bin labels must be one fewer than the number of bin edges" + data = [0.2, 1.4, 2.5, 6.2, 9.7, 2.1] + + with pytest.raises(ValueError, match=msg): + cut(data, [0, 1, 10], labels=["foo", "bar", "baz"]) + + +@pytest.mark.parametrize( + "x,bins,msg", + [ + ([], 2, "Cannot cut empty array"), + ([1, 2, 3], 0.5, "`bins` should be a positive integer"), + ], +) +def test_cut_corner(x, bins, msg): + with pytest.raises(ValueError, match=msg): + cut(x, bins) + + +@pytest.mark.parametrize("arg", [2, np.eye(2), DataFrame(np.eye(2))]) +@pytest.mark.parametrize("cut_func", [cut, qcut]) +def test_cut_not_1d_arg(arg, cut_func): + msg = "Input array must be 1 dimensional" + with pytest.raises(ValueError, match=msg): + cut_func(arg, 2) + + +@pytest.mark.parametrize( + "data", + [ + [0, 1, 2, 3, 4, np.inf], + [-np.inf, 0, 1, 2, 3, 4], + [-np.inf, 0, 1, 2, 3, 4, np.inf], + ], +) +def test_int_bins_with_inf(data): + # GH 24314 + msg = "cannot specify integer `bins` when input data contains infinity" + with pytest.raises(ValueError, match=msg): + cut(data, bins=3) + + +def test_cut_out_of_range_more(): + # see gh-1511 + name = "x" + + ser = Series([0, -1, 0, 1, -3], name=name) + ind = cut(ser, [0, 1], labels=False) + + exp = Series([np.nan, np.nan, np.nan, 0, np.nan], name=name) + tm.assert_series_equal(ind, exp) + + +@pytest.mark.parametrize( + "right,breaks,closed", + [ + (True, [-1e-3, 0.25, 0.5, 0.75, 1], "right"), + (False, [0, 0.25, 0.5, 0.75, 1 + 1e-3], "left"), + ], +) +def test_labels(right, breaks, closed): + arr = np.tile(np.arange(0, 1.01, 0.1), 4) + + result, bins = cut(arr, 4, retbins=True, right=right) + ex_levels = IntervalIndex.from_breaks(breaks, closed=closed) + tm.assert_index_equal(result.categories, ex_levels) + + +def test_cut_pass_series_name_to_factor(): + name = "foo" + ser = Series(np.random.default_rng(2).standard_normal(100), name=name) + + factor = cut(ser, 4) + assert factor.name == name + + +def test_label_precision(): + arr = np.arange(0, 0.73, 0.01) + result = cut(arr, 4, precision=2) + + ex_levels = IntervalIndex.from_breaks([-0.00072, 0.18, 0.36, 0.54, 0.72]) + tm.assert_index_equal(result.categories, ex_levels) + + +@pytest.mark.parametrize("labels", [None, False]) +def test_na_handling(labels): + arr = np.arange(0, 0.75, 0.01) + arr[::3] = np.nan + + result = cut(arr, 4, labels=labels) + result = np.asarray(result) + + expected = np.where(isna(arr), np.nan, result) + tm.assert_almost_equal(result, expected) + + +def test_inf_handling(): + data = np.arange(6) + data_ser = Series(data, dtype="int64") + + bins = [-np.inf, 2, 4, np.inf] + result = cut(data, bins) + result_ser = cut(data_ser, bins) + + ex_uniques = IntervalIndex.from_breaks(bins) + tm.assert_index_equal(result.categories, ex_uniques) + + assert result[5] == Interval(4, np.inf) + assert result[0] == Interval(-np.inf, 2) + assert result_ser[5] == Interval(4, np.inf) + assert result_ser[0] == Interval(-np.inf, 2) + + +def test_cut_out_of_bounds(): + arr = np.random.default_rng(2).standard_normal(100) + result = cut(arr, [-1, 0, 1]) + + mask = isna(result) + ex_mask = (arr < -1) | (arr > 1) + tm.assert_numpy_array_equal(mask, ex_mask) + + +@pytest.mark.parametrize( + "get_labels,get_expected", + [ + ( + lambda labels: labels, + lambda labels: Categorical( + ["Medium"] + 4 * ["Small"] + ["Medium", "Large"], + categories=labels, + ordered=True, + ), + ), + ( + lambda labels: Categorical.from_codes([0, 1, 2], labels), + lambda labels: Categorical.from_codes([1] + 4 * [0] + [1, 2], labels), + ), + ], +) +def test_cut_pass_labels(get_labels, get_expected): + bins = [0, 25, 50, 100] + arr = [50, 5, 10, 15, 20, 30, 70] + labels = ["Small", "Medium", "Large"] + + result = cut(arr, bins, labels=get_labels(labels)) + tm.assert_categorical_equal(result, get_expected(labels)) + + +def test_cut_pass_labels_compat(): + # see gh-16459 + arr = [50, 5, 10, 15, 20, 30, 70] + labels = ["Good", "Medium", "Bad"] + + result = cut(arr, 3, labels=labels) + exp = cut(arr, 3, labels=Categorical(labels, categories=labels, ordered=True)) + tm.assert_categorical_equal(result, exp) + + +@pytest.mark.parametrize("x", [np.arange(11.0), np.arange(11.0) / 1e10]) +def test_round_frac_just_works(x): + # It works. + cut(x, 2) + + +@pytest.mark.parametrize( + "val,precision,expected", + [ + (-117.9998, 3, -118), + (117.9998, 3, 118), + (117.9998, 2, 118), + (0.000123456, 2, 0.00012), + ], +) +def test_round_frac(val, precision, expected): + # see gh-1979 + result = tmod._round_frac(val, precision=precision) + assert result == expected + + +def test_cut_return_intervals(): + ser = Series([0, 1, 2, 3, 4, 5, 6, 7, 8]) + result = cut(ser, 3) + + exp_bins = np.linspace(0, 8, num=4).round(3) + exp_bins[0] -= 0.008 + + expected = Series( + IntervalIndex.from_breaks(exp_bins, closed="right").take( + [0, 0, 0, 1, 1, 1, 2, 2, 2] + ) + ).astype(CategoricalDtype(ordered=True)) + tm.assert_series_equal(result, expected) + + +def test_series_ret_bins(): + # see gh-8589 + ser = Series(np.arange(4)) + result, bins = cut(ser, 2, retbins=True) + + expected = Series( + IntervalIndex.from_breaks([-0.003, 1.5, 3], closed="right").repeat(2) + ).astype(CategoricalDtype(ordered=True)) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "kwargs,msg", + [ + ({"duplicates": "drop"}, None), + ({}, "Bin edges must be unique"), + ({"duplicates": "raise"}, "Bin edges must be unique"), + ({"duplicates": "foo"}, "invalid value for 'duplicates' parameter"), + ], +) +def test_cut_duplicates_bin(kwargs, msg): + # see gh-20947 + bins = [0, 2, 4, 6, 10, 10] + values = Series(np.array([1, 3, 5, 7, 9]), index=["a", "b", "c", "d", "e"]) + + if msg is not None: + with pytest.raises(ValueError, match=msg): + cut(values, bins, **kwargs) + else: + result = cut(values, bins, **kwargs) + expected = cut(values, pd.unique(np.asarray(bins))) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("data", [9.0, -9.0, 0.0]) +@pytest.mark.parametrize("length", [1, 2]) +def test_single_bin(data, length): + # see gh-14652, gh-15428 + ser = Series([data] * length) + result = cut(ser, 1, labels=False) + + expected = Series([0] * length, dtype=np.intp) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "array_1_writeable,array_2_writeable", [(True, True), (True, False), (False, False)] +) +def test_cut_read_only(array_1_writeable, array_2_writeable): + # issue 18773 + array_1 = np.arange(0, 100, 10) + array_1.flags.writeable = array_1_writeable + + array_2 = np.arange(0, 100, 10) + array_2.flags.writeable = array_2_writeable + + hundred_elements = np.arange(100) + tm.assert_categorical_equal( + cut(hundred_elements, array_1), cut(hundred_elements, array_2) + ) + + +@pytest.mark.parametrize( + "conv", + [ + lambda v: Timestamp(v), + lambda v: to_datetime(v), + lambda v: np.datetime64(v), + lambda v: Timestamp(v).to_pydatetime(), + ], +) +def test_datetime_bin(conv): + data = [np.datetime64("2012-12-13"), np.datetime64("2012-12-15")] + bin_data = ["2012-12-12", "2012-12-14", "2012-12-16"] + + expected = Series( + IntervalIndex( + [ + Interval(Timestamp(bin_data[0]), Timestamp(bin_data[1])), + Interval(Timestamp(bin_data[1]), Timestamp(bin_data[2])), + ] + ) + ).astype(CategoricalDtype(ordered=True)) + + bins = [conv(v) for v in bin_data] + result = Series(cut(data, bins=bins)) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("box", [Series, Index, np.array, list]) +def test_datetime_cut(unit, box): + # see gh-14714 + # + # Testing time data when it comes in various collection types. + data = to_datetime(["2013-01-01", "2013-01-02", "2013-01-03"]).astype(f"M8[{unit}]") + data = box(data) + result, _ = cut(data, 3, retbins=True) + + if box is list: + # We don't (yet) do inference on these, so get nanos + unit = "ns" + + if unit == "s": + # See https://github.com/pandas-dev/pandas/pull/56101#discussion_r1405325425 + # for why we round to 8 seconds instead of 7 + left = DatetimeIndex( + ["2012-12-31 23:57:08", "2013-01-01 16:00:00", "2013-01-02 08:00:00"], + dtype=f"M8[{unit}]", + ) + else: + left = DatetimeIndex( + [ + "2012-12-31 23:57:07.200000", + "2013-01-01 16:00:00", + "2013-01-02 08:00:00", + ], + dtype=f"M8[{unit}]", + ) + right = DatetimeIndex( + ["2013-01-01 16:00:00", "2013-01-02 08:00:00", "2013-01-03 00:00:00"], + dtype=f"M8[{unit}]", + ) + + exp_intervals = IntervalIndex.from_arrays(left, right) + expected = Series(exp_intervals).astype(CategoricalDtype(ordered=True)) + tm.assert_series_equal(Series(result), expected) + + +@pytest.mark.parametrize("box", [list, np.array, Index, Series]) +def test_datetime_tz_cut_mismatched_tzawareness(box): + # GH#54964 + bins = box( + [ + Timestamp("2013-01-01 04:57:07.200000"), + Timestamp("2013-01-01 21:00:00"), + Timestamp("2013-01-02 13:00:00"), + Timestamp("2013-01-03 05:00:00"), + ] + ) + ser = Series(date_range("20130101", periods=3, tz="US/Eastern")) + + msg = "Cannot use timezone-naive bins with timezone-aware values" + with pytest.raises(ValueError, match=msg): + cut(ser, bins) + + +@pytest.mark.parametrize( + "bins", + [ + 3, + [ + Timestamp("2013-01-01 04:57:07.200000", tz="UTC").tz_convert("US/Eastern"), + Timestamp("2013-01-01 21:00:00", tz="UTC").tz_convert("US/Eastern"), + Timestamp("2013-01-02 13:00:00", tz="UTC").tz_convert("US/Eastern"), + Timestamp("2013-01-03 05:00:00", tz="UTC").tz_convert("US/Eastern"), + ], + ], +) +@pytest.mark.parametrize("box", [list, np.array, Index, Series]) +def test_datetime_tz_cut(bins, box): + # see gh-19872 + tz = "US/Eastern" + ser = Series(date_range("20130101", periods=3, tz=tz)) + + if not isinstance(bins, int): + bins = box(bins) + + result = cut(ser, bins) + expected = Series( + IntervalIndex( + [ + Interval( + Timestamp("2012-12-31 23:57:07.200000", tz=tz), + Timestamp("2013-01-01 16:00:00", tz=tz), + ), + Interval( + Timestamp("2013-01-01 16:00:00", tz=tz), + Timestamp("2013-01-02 08:00:00", tz=tz), + ), + Interval( + Timestamp("2013-01-02 08:00:00", tz=tz), + Timestamp("2013-01-03 00:00:00", tz=tz), + ), + ] + ) + ).astype(CategoricalDtype(ordered=True)) + tm.assert_series_equal(result, expected) + + +def test_datetime_nan_error(): + msg = "bins must be of datetime64 dtype" + + with pytest.raises(ValueError, match=msg): + cut(date_range("20130101", periods=3), bins=[0, 2, 4]) + + +def test_datetime_nan_mask(): + result = cut( + date_range("20130102", periods=5), bins=date_range("20130101", periods=2) + ) + + mask = result.categories.isna() + tm.assert_numpy_array_equal(mask, np.array([False])) + + mask = result.isna() + tm.assert_numpy_array_equal(mask, np.array([False, True, True, True, True])) + + +@pytest.mark.parametrize("tz", [None, "UTC", "US/Pacific"]) +def test_datetime_cut_roundtrip(tz, unit): + # see gh-19891 + ser = Series(date_range("20180101", periods=3, tz=tz, unit=unit)) + result, result_bins = cut(ser, 2, retbins=True) + + expected = cut(ser, result_bins) + tm.assert_series_equal(result, expected) + + if unit == "s": + # TODO: constructing DatetimeIndex with dtype="M8[s]" without truncating + # the first entry here raises in array_to_datetime. Should truncate + # instead of raising? + # See https://github.com/pandas-dev/pandas/pull/56101#discussion_r1405325425 + # for why we round to 8 seconds instead of 7 + expected_bins = DatetimeIndex( + ["2017-12-31 23:57:08", "2018-01-02 00:00:00", "2018-01-03 00:00:00"], + dtype=f"M8[{unit}]", + ) + else: + expected_bins = DatetimeIndex( + [ + "2017-12-31 23:57:07.200000", + "2018-01-02 00:00:00", + "2018-01-03 00:00:00", + ], + dtype=f"M8[{unit}]", + ) + expected_bins = expected_bins.tz_localize(tz) + tm.assert_index_equal(result_bins, expected_bins) + + +def test_timedelta_cut_roundtrip(): + # see gh-19891 + ser = Series(timedelta_range("1day", periods=3)) + result, result_bins = cut(ser, 2, retbins=True) + + expected = cut(ser, result_bins) + tm.assert_series_equal(result, expected) + + expected_bins = TimedeltaIndex( + ["0 days 23:57:07.200000", "2 days 00:00:00", "3 days 00:00:00"] + ) + tm.assert_index_equal(result_bins, expected_bins) + + +@pytest.mark.parametrize("bins", [6, 7]) +@pytest.mark.parametrize( + "box, compare", + [ + (Series, tm.assert_series_equal), + (np.array, tm.assert_categorical_equal), + (list, tm.assert_equal), + ], +) +def test_cut_bool_coercion_to_int(bins, box, compare): + # issue 20303 + data_expected = box([0, 1, 1, 0, 1] * 10) + data_result = box([False, True, True, False, True] * 10) + expected = cut(data_expected, bins, duplicates="drop") + result = cut(data_result, bins, duplicates="drop") + compare(result, expected) + + +@pytest.mark.parametrize("labels", ["foo", 1, True]) +def test_cut_incorrect_labels(labels): + # GH 13318 + values = range(5) + msg = "Bin labels must either be False, None or passed in as a list-like argument" + with pytest.raises(ValueError, match=msg): + cut(values, 4, labels=labels) + + +@pytest.mark.parametrize("bins", [3, [0, 5, 15]]) +@pytest.mark.parametrize("right", [True, False]) +@pytest.mark.parametrize("include_lowest", [True, False]) +def test_cut_nullable_integer(bins, right, include_lowest): + a = np.random.default_rng(2).integers(0, 10, size=50).astype(float) + a[::2] = np.nan + result = cut( + pd.array(a, dtype="Int64"), bins, right=right, include_lowest=include_lowest + ) + expected = cut(a, bins, right=right, include_lowest=include_lowest) + tm.assert_categorical_equal(result, expected) + + +@pytest.mark.parametrize( + "data, bins, labels, expected_codes, expected_labels", + [ + ([15, 17, 19], [14, 16, 18, 20], ["A", "B", "A"], [0, 1, 0], ["A", "B"]), + ([1, 3, 5], [0, 2, 4, 6, 8], [2, 0, 1, 2], [2, 0, 1], [0, 1, 2]), + ], +) +def test_cut_non_unique_labels(data, bins, labels, expected_codes, expected_labels): + # GH 33141 + result = cut(data, bins=bins, labels=labels, ordered=False) + expected = Categorical.from_codes( + expected_codes, categories=expected_labels, ordered=False + ) + tm.assert_categorical_equal(result, expected) + + +@pytest.mark.parametrize( + "data, bins, labels, expected_codes, expected_labels", + [ + ([15, 17, 19], [14, 16, 18, 20], ["C", "B", "A"], [0, 1, 2], ["C", "B", "A"]), + ([1, 3, 5], [0, 2, 4, 6, 8], [3, 0, 1, 2], [0, 1, 2], [3, 0, 1, 2]), + ], +) +def test_cut_unordered_labels(data, bins, labels, expected_codes, expected_labels): + # GH 33141 + result = cut(data, bins=bins, labels=labels, ordered=False) + expected = Categorical.from_codes( + expected_codes, categories=expected_labels, ordered=False + ) + tm.assert_categorical_equal(result, expected) + + +def test_cut_unordered_with_missing_labels_raises_error(): + # GH 33141 + msg = "'labels' must be provided if 'ordered = False'" + with pytest.raises(ValueError, match=msg): + cut([0.5, 3], bins=[0, 1, 2], ordered=False) + + +def test_cut_unordered_with_series_labels(): + # https://github.com/pandas-dev/pandas/issues/36603 + ser = Series([1, 2, 3, 4, 5]) + bins = Series([0, 2, 4, 6]) + labels = Series(["a", "b", "c"]) + result = cut(ser, bins=bins, labels=labels, ordered=False) + expected = Series(["a", "a", "b", "b", "c"], dtype="category") + tm.assert_series_equal(result, expected) + + +def test_cut_no_warnings(): + df = DataFrame({"value": np.random.default_rng(2).integers(0, 100, 20)}) + labels = [f"{i} - {i + 9}" for i in range(0, 100, 10)] + with tm.assert_produces_warning(False): + df["group"] = cut(df.value, range(0, 105, 10), right=False, labels=labels) + + +def test_cut_with_duplicated_index_lowest_included(): + # GH 42185 + expected = Series( + [Interval(-0.001, 2, closed="right")] * 3 + + [Interval(2, 4, closed="right"), Interval(-0.001, 2, closed="right")], + index=[0, 1, 2, 3, 0], + dtype="category", + ).cat.as_ordered() + + ser = Series([0, 1, 2, 3, 0], index=[0, 1, 2, 3, 0]) + result = cut(ser, bins=[0, 2, 4], include_lowest=True) + tm.assert_series_equal(result, expected) + + +def test_cut_with_nonexact_categorical_indices(): + # GH 42424 + + ser = Series(range(100)) + ser1 = cut(ser, 10).value_counts().head(5) + ser2 = cut(ser, 10).value_counts().tail(5) + result = DataFrame({"1": ser1, "2": ser2}) + + index = pd.CategoricalIndex( + [ + Interval(-0.099, 9.9, closed="right"), + Interval(9.9, 19.8, closed="right"), + Interval(19.8, 29.7, closed="right"), + Interval(29.7, 39.6, closed="right"), + Interval(39.6, 49.5, closed="right"), + Interval(49.5, 59.4, closed="right"), + Interval(59.4, 69.3, closed="right"), + Interval(69.3, 79.2, closed="right"), + Interval(79.2, 89.1, closed="right"), + Interval(89.1, 99, closed="right"), + ], + ordered=True, + ) + + expected = DataFrame( + {"1": [10] * 5 + [np.nan] * 5, "2": [np.nan] * 5 + [10] * 5}, index=index + ) + + tm.assert_frame_equal(expected, result) + + +def test_cut_with_timestamp_tuple_labels(): + # GH 40661 + labels = [(Timestamp(10),), (Timestamp(20),), (Timestamp(30),)] + result = cut([2, 4, 6], bins=[1, 3, 5, 7], labels=labels) + + expected = Categorical.from_codes([0, 1, 2], labels, ordered=True) + tm.assert_categorical_equal(result, expected) + + +def test_cut_bins_datetime_intervalindex(): + # https://github.com/pandas-dev/pandas/issues/46218 + bins = interval_range(Timestamp("2022-02-25"), Timestamp("2022-02-27"), freq="1D") + # passing Series instead of list is important to trigger bug + result = cut(Series([Timestamp("2022-02-26")]).astype("M8[ns]"), bins=bins) + expected = Categorical.from_codes([0], bins, ordered=True) + tm.assert_categorical_equal(result.array, expected) + + +def test_cut_with_nullable_int64(): + # GH 30787 + series = Series([0, 1, 2, 3, 4, pd.NA, 6, 7], dtype="Int64") + bins = [0, 2, 4, 6, 8] + intervals = IntervalIndex.from_breaks(bins) + + expected = Series( + Categorical.from_codes([-1, 0, 0, 1, 1, -1, 2, 3], intervals, ordered=True) + ) + + result = cut(series, bins=bins) + + tm.assert_series_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/test_from_dummies.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/test_from_dummies.py new file mode 100644 index 0000000000000000000000000000000000000000..f9a03222c8057157072fc12af015b183da0f4de8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/test_from_dummies.py @@ -0,0 +1,447 @@ +import numpy as np +import pytest + +from pandas import ( + DataFrame, + Series, + from_dummies, + get_dummies, +) +import pandas._testing as tm + + +@pytest.fixture +def dummies_basic(): + return DataFrame( + { + "col1_a": [1, 0, 1], + "col1_b": [0, 1, 0], + "col2_a": [0, 1, 0], + "col2_b": [1, 0, 0], + "col2_c": [0, 0, 1], + }, + ) + + +@pytest.fixture +def dummies_with_unassigned(): + return DataFrame( + { + "col1_a": [1, 0, 0], + "col1_b": [0, 1, 0], + "col2_a": [0, 1, 0], + "col2_b": [0, 0, 0], + "col2_c": [0, 0, 1], + }, + ) + + +def test_error_wrong_data_type(): + dummies = [0, 1, 0] + with pytest.raises( + TypeError, + match=r"Expected 'data' to be a 'DataFrame'; Received 'data' of type: list", + ): + from_dummies(dummies) + + +def test_error_no_prefix_contains_unassigned(): + dummies = DataFrame({"a": [1, 0, 0], "b": [0, 1, 0]}) + with pytest.raises( + ValueError, + match=( + r"Dummy DataFrame contains unassigned value\(s\); " + r"First instance in row: 2" + ), + ): + from_dummies(dummies) + + +def test_error_no_prefix_wrong_default_category_type(): + dummies = DataFrame({"a": [1, 0, 1], "b": [0, 1, 1]}) + with pytest.raises( + TypeError, + match=( + r"Expected 'default_category' to be of type 'None', 'Hashable', or 'dict'; " + r"Received 'default_category' of type: list" + ), + ): + from_dummies(dummies, default_category=["c", "d"]) + + +def test_error_no_prefix_multi_assignment(): + dummies = DataFrame({"a": [1, 0, 1], "b": [0, 1, 1]}) + with pytest.raises( + ValueError, + match=( + r"Dummy DataFrame contains multi-assignment\(s\); " + r"First instance in row: 2" + ), + ): + from_dummies(dummies) + + +def test_error_no_prefix_contains_nan(): + dummies = DataFrame({"a": [1, 0, 0], "b": [0, 1, np.nan]}) + with pytest.raises( + ValueError, match=r"Dummy DataFrame contains NA value in column: 'b'" + ): + from_dummies(dummies) + + +def test_error_contains_non_dummies(): + dummies = DataFrame( + {"a": [1, 6, 3, 1], "b": [0, 1, 0, 2], "c": ["c1", "c2", "c3", "c4"]} + ) + with pytest.raises( + TypeError, + match=r"Passed DataFrame contains non-dummy data", + ): + from_dummies(dummies) + + +def test_error_with_prefix_multiple_seperators(): + dummies = DataFrame( + { + "col1_a": [1, 0, 1], + "col1_b": [0, 1, 0], + "col2-a": [0, 1, 0], + "col2-b": [1, 0, 1], + }, + ) + with pytest.raises( + ValueError, + match=(r"Separator not specified for column: col2-a"), + ): + from_dummies(dummies, sep="_") + + +def test_error_with_prefix_sep_wrong_type(dummies_basic): + with pytest.raises( + TypeError, + match=( + r"Expected 'sep' to be of type 'str' or 'None'; " + r"Received 'sep' of type: list" + ), + ): + from_dummies(dummies_basic, sep=["_"]) + + +def test_error_with_prefix_contains_unassigned(dummies_with_unassigned): + with pytest.raises( + ValueError, + match=( + r"Dummy DataFrame contains unassigned value\(s\); " + r"First instance in row: 2" + ), + ): + from_dummies(dummies_with_unassigned, sep="_") + + +def test_error_with_prefix_default_category_wrong_type(dummies_with_unassigned): + with pytest.raises( + TypeError, + match=( + r"Expected 'default_category' to be of type 'None', 'Hashable', or 'dict'; " + r"Received 'default_category' of type: list" + ), + ): + from_dummies(dummies_with_unassigned, sep="_", default_category=["x", "y"]) + + +def test_error_with_prefix_default_category_dict_not_complete( + dummies_with_unassigned, +): + with pytest.raises( + ValueError, + match=( + r"Length of 'default_category' \(1\) did not match " + r"the length of the columns being encoded \(2\)" + ), + ): + from_dummies(dummies_with_unassigned, sep="_", default_category={"col1": "x"}) + + +def test_error_with_prefix_contains_nan(dummies_basic): + # Set float64 dtype to avoid upcast when setting np.nan + dummies_basic["col2_c"] = dummies_basic["col2_c"].astype("float64") + dummies_basic.loc[2, "col2_c"] = np.nan + with pytest.raises( + ValueError, match=r"Dummy DataFrame contains NA value in column: 'col2_c'" + ): + from_dummies(dummies_basic, sep="_") + + +def test_error_with_prefix_contains_non_dummies(dummies_basic): + # Set object dtype to avoid upcast when setting "str" + dummies_basic["col2_c"] = dummies_basic["col2_c"].astype(object) + dummies_basic.loc[2, "col2_c"] = "str" + with pytest.raises(TypeError, match=r"Passed DataFrame contains non-dummy data"): + from_dummies(dummies_basic, sep="_") + + +def test_error_with_prefix_double_assignment(): + dummies = DataFrame( + { + "col1_a": [1, 0, 1], + "col1_b": [1, 1, 0], + "col2_a": [0, 1, 0], + "col2_b": [1, 0, 0], + "col2_c": [0, 0, 1], + }, + ) + with pytest.raises( + ValueError, + match=( + r"Dummy DataFrame contains multi-assignment\(s\); " + r"First instance in row: 0" + ), + ): + from_dummies(dummies, sep="_") + + +def test_roundtrip_series_to_dataframe(): + categories = Series(["a", "b", "c", "a"]) + dummies = get_dummies(categories) + result = from_dummies(dummies) + expected = DataFrame({"": ["a", "b", "c", "a"]}) + tm.assert_frame_equal(result, expected) + + +def test_roundtrip_single_column_dataframe(): + categories = DataFrame({"": ["a", "b", "c", "a"]}) + dummies = get_dummies(categories) + result = from_dummies(dummies, sep="_") + expected = categories + tm.assert_frame_equal(result, expected) + + +def test_roundtrip_with_prefixes(): + categories = DataFrame({"col1": ["a", "b", "a"], "col2": ["b", "a", "c"]}) + dummies = get_dummies(categories) + result = from_dummies(dummies, sep="_") + expected = categories + tm.assert_frame_equal(result, expected) + + +def test_no_prefix_string_cats_basic(): + dummies = DataFrame({"a": [1, 0, 0, 1], "b": [0, 1, 0, 0], "c": [0, 0, 1, 0]}) + expected = DataFrame({"": ["a", "b", "c", "a"]}) + result = from_dummies(dummies) + tm.assert_frame_equal(result, expected) + + +def test_no_prefix_string_cats_basic_bool_values(): + dummies = DataFrame( + { + "a": [True, False, False, True], + "b": [False, True, False, False], + "c": [False, False, True, False], + } + ) + expected = DataFrame({"": ["a", "b", "c", "a"]}) + result = from_dummies(dummies) + tm.assert_frame_equal(result, expected) + + +def test_no_prefix_string_cats_basic_mixed_bool_values(): + dummies = DataFrame( + {"a": [1, 0, 0, 1], "b": [False, True, False, False], "c": [0, 0, 1, 0]} + ) + expected = DataFrame({"": ["a", "b", "c", "a"]}) + result = from_dummies(dummies) + tm.assert_frame_equal(result, expected) + + +def test_no_prefix_int_cats_basic(): + dummies = DataFrame( + {1: [1, 0, 0, 0], 25: [0, 1, 0, 0], 2: [0, 0, 1, 0], 5: [0, 0, 0, 1]} + ) + expected = DataFrame({"": [1, 25, 2, 5]}) + result = from_dummies(dummies) + tm.assert_frame_equal(result, expected) + + +def test_no_prefix_float_cats_basic(): + dummies = DataFrame( + {1.0: [1, 0, 0, 0], 25.0: [0, 1, 0, 0], 2.5: [0, 0, 1, 0], 5.84: [0, 0, 0, 1]} + ) + expected = DataFrame({"": [1.0, 25.0, 2.5, 5.84]}) + result = from_dummies(dummies) + tm.assert_frame_equal(result, expected) + + +def test_no_prefix_mixed_cats_basic(): + dummies = DataFrame( + { + 1.23: [1, 0, 0, 0, 0], + "c": [0, 1, 0, 0, 0], + 2: [0, 0, 1, 0, 0], + False: [0, 0, 0, 1, 0], + None: [0, 0, 0, 0, 1], + } + ) + expected = DataFrame({"": [1.23, "c", 2, False, None]}, dtype="object") + result = from_dummies(dummies) + tm.assert_frame_equal(result, expected) + + +def test_no_prefix_string_cats_contains_get_dummies_NaN_column(): + dummies = DataFrame({"a": [1, 0, 0], "b": [0, 1, 0], "NaN": [0, 0, 1]}) + expected = DataFrame({"": ["a", "b", "NaN"]}) + result = from_dummies(dummies) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "default_category, expected", + [ + pytest.param( + "c", + DataFrame({"": ["a", "b", "c"]}), + id="default_category is a str", + ), + pytest.param( + 1, + DataFrame({"": ["a", "b", 1]}), + id="default_category is a int", + ), + pytest.param( + 1.25, + DataFrame({"": ["a", "b", 1.25]}), + id="default_category is a float", + ), + pytest.param( + 0, + DataFrame({"": ["a", "b", 0]}), + id="default_category is a 0", + ), + pytest.param( + False, + DataFrame({"": ["a", "b", False]}), + id="default_category is a bool", + ), + pytest.param( + (1, 2), + DataFrame({"": ["a", "b", (1, 2)]}), + id="default_category is a tuple", + ), + ], +) +def test_no_prefix_string_cats_default_category( + default_category, expected, using_infer_string +): + dummies = DataFrame({"a": [1, 0, 0], "b": [0, 1, 0]}) + result = from_dummies(dummies, default_category=default_category) + if using_infer_string: + expected[""] = expected[""].astype("string[pyarrow_numpy]") + tm.assert_frame_equal(result, expected) + + +def test_with_prefix_basic(dummies_basic): + expected = DataFrame({"col1": ["a", "b", "a"], "col2": ["b", "a", "c"]}) + result = from_dummies(dummies_basic, sep="_") + tm.assert_frame_equal(result, expected) + + +def test_with_prefix_contains_get_dummies_NaN_column(): + dummies = DataFrame( + { + "col1_a": [1, 0, 0], + "col1_b": [0, 1, 0], + "col1_NaN": [0, 0, 1], + "col2_a": [0, 1, 0], + "col2_b": [0, 0, 0], + "col2_c": [0, 0, 1], + "col2_NaN": [1, 0, 0], + }, + ) + expected = DataFrame({"col1": ["a", "b", "NaN"], "col2": ["NaN", "a", "c"]}) + result = from_dummies(dummies, sep="_") + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "default_category, expected", + [ + pytest.param( + "x", + DataFrame({"col1": ["a", "b", "x"], "col2": ["x", "a", "c"]}), + id="default_category is a str", + ), + pytest.param( + 0, + DataFrame({"col1": ["a", "b", 0], "col2": [0, "a", "c"]}), + id="default_category is a 0", + ), + pytest.param( + False, + DataFrame({"col1": ["a", "b", False], "col2": [False, "a", "c"]}), + id="default_category is a False", + ), + pytest.param( + {"col2": 1, "col1": 2.5}, + DataFrame({"col1": ["a", "b", 2.5], "col2": [1, "a", "c"]}), + id="default_category is a dict with int and float values", + ), + pytest.param( + {"col2": None, "col1": False}, + DataFrame({"col1": ["a", "b", False], "col2": [None, "a", "c"]}), + id="default_category is a dict with bool and None values", + ), + pytest.param( + {"col2": (1, 2), "col1": [1.25, False]}, + DataFrame({"col1": ["a", "b", [1.25, False]], "col2": [(1, 2), "a", "c"]}), + id="default_category is a dict with list and tuple values", + ), + ], +) +def test_with_prefix_default_category( + dummies_with_unassigned, default_category, expected +): + result = from_dummies( + dummies_with_unassigned, sep="_", default_category=default_category + ) + tm.assert_frame_equal(result, expected) + + +def test_ea_categories(): + # GH 54300 + df = DataFrame({"a": [1, 0, 0, 1], "b": [0, 1, 0, 0], "c": [0, 0, 1, 0]}) + df.columns = df.columns.astype("string[python]") + result = from_dummies(df) + expected = DataFrame({"": Series(list("abca"), dtype="string[python]")}) + tm.assert_frame_equal(result, expected) + + +def test_ea_categories_with_sep(): + # GH 54300 + df = DataFrame( + { + "col1_a": [1, 0, 1], + "col1_b": [0, 1, 0], + "col2_a": [0, 1, 0], + "col2_b": [1, 0, 0], + "col2_c": [0, 0, 1], + } + ) + df.columns = df.columns.astype("string[python]") + result = from_dummies(df, sep="_") + expected = DataFrame( + { + "col1": Series(list("aba"), dtype="string[python]"), + "col2": Series(list("bac"), dtype="string[python]"), + } + ) + expected.columns = expected.columns.astype("string[python]") + tm.assert_frame_equal(result, expected) + + +def test_maintain_original_index(): + # GH 54300 + df = DataFrame( + {"a": [1, 0, 0, 1], "b": [0, 1, 0, 0], "c": [0, 0, 1, 0]}, index=list("abcd") + ) + result = from_dummies(df) + expected = DataFrame({"": list("abca")}, index=list("abcd")) + tm.assert_frame_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/test_get_dummies.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/test_get_dummies.py new file mode 100644 index 0000000000000000000000000000000000000000..31260e4dcb7d2ecf5f003ce4aadf890493b59887 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/test_get_dummies.py @@ -0,0 +1,743 @@ +import re +import unicodedata + +import numpy as np +import pytest + +import pandas.util._test_decorators as td + +from pandas.core.dtypes.common import is_integer_dtype + +import pandas as pd +from pandas import ( + ArrowDtype, + Categorical, + CategoricalDtype, + CategoricalIndex, + DataFrame, + Index, + RangeIndex, + Series, + SparseDtype, + get_dummies, +) +import pandas._testing as tm +from pandas.core.arrays.sparse import SparseArray + +try: + import pyarrow as pa +except ImportError: + pa = None + + +class TestGetDummies: + @pytest.fixture + def df(self): + return DataFrame({"A": ["a", "b", "a"], "B": ["b", "b", "c"], "C": [1, 2, 3]}) + + @pytest.fixture(params=["uint8", "i8", np.float64, bool, None]) + def dtype(self, request): + return np.dtype(request.param) + + @pytest.fixture(params=["dense", "sparse"]) + def sparse(self, request): + # params are strings to simplify reading test results, + # e.g. TestGetDummies::test_basic[uint8-sparse] instead of [uint8-True] + return request.param == "sparse" + + def effective_dtype(self, dtype): + if dtype is None: + return np.uint8 + return dtype + + def test_get_dummies_raises_on_dtype_object(self, df): + msg = "dtype=object is not a valid dtype for get_dummies" + with pytest.raises(ValueError, match=msg): + get_dummies(df, dtype="object") + + def test_get_dummies_basic(self, sparse, dtype): + s_list = list("abc") + s_series = Series(s_list) + s_series_index = Series(s_list, list("ABC")) + + expected = DataFrame( + {"a": [1, 0, 0], "b": [0, 1, 0], "c": [0, 0, 1]}, + dtype=self.effective_dtype(dtype), + ) + if sparse: + if dtype.kind == "b": + expected = expected.apply(SparseArray, fill_value=False) + else: + expected = expected.apply(SparseArray, fill_value=0.0) + result = get_dummies(s_list, sparse=sparse, dtype=dtype) + tm.assert_frame_equal(result, expected) + + result = get_dummies(s_series, sparse=sparse, dtype=dtype) + tm.assert_frame_equal(result, expected) + + expected.index = list("ABC") + result = get_dummies(s_series_index, sparse=sparse, dtype=dtype) + tm.assert_frame_equal(result, expected) + + def test_get_dummies_basic_types(self, sparse, dtype, using_infer_string): + # GH 10531 + s_list = list("abc") + s_series = Series(s_list) + s_df = DataFrame( + {"a": [0, 1, 0, 1, 2], "b": ["A", "A", "B", "C", "C"], "c": [2, 3, 3, 3, 2]} + ) + + expected = DataFrame( + {"a": [1, 0, 0], "b": [0, 1, 0], "c": [0, 0, 1]}, + dtype=self.effective_dtype(dtype), + columns=list("abc"), + ) + if sparse: + if is_integer_dtype(dtype): + fill_value = 0 + elif dtype == bool: + fill_value = False + else: + fill_value = 0.0 + + expected = expected.apply(SparseArray, fill_value=fill_value) + result = get_dummies(s_list, sparse=sparse, dtype=dtype) + tm.assert_frame_equal(result, expected) + + result = get_dummies(s_series, sparse=sparse, dtype=dtype) + tm.assert_frame_equal(result, expected) + + result = get_dummies(s_df, columns=s_df.columns, sparse=sparse, dtype=dtype) + if sparse: + dtype_name = f"Sparse[{self.effective_dtype(dtype).name}, {fill_value}]" + else: + dtype_name = self.effective_dtype(dtype).name + + expected = Series({dtype_name: 8}, name="count") + result = result.dtypes.value_counts() + result.index = [str(i) for i in result.index] + tm.assert_series_equal(result, expected) + + result = get_dummies(s_df, columns=["a"], sparse=sparse, dtype=dtype) + + key = "string" if using_infer_string else "object" + expected_counts = {"int64": 1, key: 1} + expected_counts[dtype_name] = 3 + expected_counts.get(dtype_name, 0) + + expected = Series(expected_counts, name="count").sort_index() + result = result.dtypes.value_counts() + result.index = [str(i) for i in result.index] + result = result.sort_index() + tm.assert_series_equal(result, expected) + + def test_get_dummies_just_na(self, sparse): + just_na_list = [np.nan] + just_na_series = Series(just_na_list) + just_na_series_index = Series(just_na_list, index=["A"]) + + res_list = get_dummies(just_na_list, sparse=sparse) + res_series = get_dummies(just_na_series, sparse=sparse) + res_series_index = get_dummies(just_na_series_index, sparse=sparse) + + assert res_list.empty + assert res_series.empty + assert res_series_index.empty + + assert res_list.index.tolist() == [0] + assert res_series.index.tolist() == [0] + assert res_series_index.index.tolist() == ["A"] + + def test_get_dummies_include_na(self, sparse, dtype): + s = ["a", "b", np.nan] + res = get_dummies(s, sparse=sparse, dtype=dtype) + exp = DataFrame( + {"a": [1, 0, 0], "b": [0, 1, 0]}, dtype=self.effective_dtype(dtype) + ) + if sparse: + if dtype.kind == "b": + exp = exp.apply(SparseArray, fill_value=False) + else: + exp = exp.apply(SparseArray, fill_value=0.0) + tm.assert_frame_equal(res, exp) + + # Sparse dataframes do not allow nan labelled columns, see #GH8822 + res_na = get_dummies(s, dummy_na=True, sparse=sparse, dtype=dtype) + exp_na = DataFrame( + {np.nan: [0, 0, 1], "a": [1, 0, 0], "b": [0, 1, 0]}, + dtype=self.effective_dtype(dtype), + ) + exp_na = exp_na.reindex(["a", "b", np.nan], axis=1) + # hack (NaN handling in assert_index_equal) + exp_na.columns = res_na.columns + if sparse: + if dtype.kind == "b": + exp_na = exp_na.apply(SparseArray, fill_value=False) + else: + exp_na = exp_na.apply(SparseArray, fill_value=0.0) + tm.assert_frame_equal(res_na, exp_na) + + res_just_na = get_dummies([np.nan], dummy_na=True, sparse=sparse, dtype=dtype) + exp_just_na = DataFrame( + Series(1, index=[0]), columns=[np.nan], dtype=self.effective_dtype(dtype) + ) + tm.assert_numpy_array_equal(res_just_na.values, exp_just_na.values) + + def test_get_dummies_unicode(self, sparse): + # See GH 6885 - get_dummies chokes on unicode values + e = "e" + eacute = unicodedata.lookup("LATIN SMALL LETTER E WITH ACUTE") + s = [e, eacute, eacute] + res = get_dummies(s, prefix="letter", sparse=sparse) + exp = DataFrame( + {"letter_e": [True, False, False], f"letter_{eacute}": [False, True, True]} + ) + if sparse: + exp = exp.apply(SparseArray, fill_value=False) + tm.assert_frame_equal(res, exp) + + def test_dataframe_dummies_all_obj(self, df, sparse): + df = df[["A", "B"]] + result = get_dummies(df, sparse=sparse) + expected = DataFrame( + {"A_a": [1, 0, 1], "A_b": [0, 1, 0], "B_b": [1, 1, 0], "B_c": [0, 0, 1]}, + dtype=bool, + ) + if sparse: + expected = DataFrame( + { + "A_a": SparseArray([1, 0, 1], dtype="bool"), + "A_b": SparseArray([0, 1, 0], dtype="bool"), + "B_b": SparseArray([1, 1, 0], dtype="bool"), + "B_c": SparseArray([0, 0, 1], dtype="bool"), + } + ) + + tm.assert_frame_equal(result, expected) + + def test_dataframe_dummies_string_dtype(self, df, using_infer_string): + # GH44965 + df = df[["A", "B"]] + df = df.astype({"A": "object", "B": "string"}) + result = get_dummies(df) + expected = DataFrame( + { + "A_a": [1, 0, 1], + "A_b": [0, 1, 0], + "B_b": [1, 1, 0], + "B_c": [0, 0, 1], + }, + dtype=bool, + ) + if not using_infer_string: + # infer_string returns numpy bools + expected[["B_b", "B_c"]] = expected[["B_b", "B_c"]].astype("boolean") + tm.assert_frame_equal(result, expected) + + def test_dataframe_dummies_mix_default(self, df, sparse, dtype): + result = get_dummies(df, sparse=sparse, dtype=dtype) + if sparse: + arr = SparseArray + if dtype.kind == "b": + typ = SparseDtype(dtype, False) + else: + typ = SparseDtype(dtype, 0) + else: + arr = np.array + typ = dtype + expected = DataFrame( + { + "C": [1, 2, 3], + "A_a": arr([1, 0, 1], dtype=typ), + "A_b": arr([0, 1, 0], dtype=typ), + "B_b": arr([1, 1, 0], dtype=typ), + "B_c": arr([0, 0, 1], dtype=typ), + } + ) + expected = expected[["C", "A_a", "A_b", "B_b", "B_c"]] + tm.assert_frame_equal(result, expected) + + def test_dataframe_dummies_prefix_list(self, df, sparse): + prefixes = ["from_A", "from_B"] + result = get_dummies(df, prefix=prefixes, sparse=sparse) + expected = DataFrame( + { + "C": [1, 2, 3], + "from_A_a": [True, False, True], + "from_A_b": [False, True, False], + "from_B_b": [True, True, False], + "from_B_c": [False, False, True], + }, + ) + expected[["C"]] = df[["C"]] + cols = ["from_A_a", "from_A_b", "from_B_b", "from_B_c"] + expected = expected[["C"] + cols] + + typ = SparseArray if sparse else Series + expected[cols] = expected[cols].apply(lambda x: typ(x)) + tm.assert_frame_equal(result, expected) + + def test_dataframe_dummies_prefix_str(self, df, sparse): + # not that you should do this... + result = get_dummies(df, prefix="bad", sparse=sparse) + bad_columns = ["bad_a", "bad_b", "bad_b", "bad_c"] + expected = DataFrame( + [ + [1, True, False, True, False], + [2, False, True, True, False], + [3, True, False, False, True], + ], + columns=["C"] + bad_columns, + ) + expected = expected.astype({"C": np.int64}) + if sparse: + # work around astyping & assigning with duplicate columns + # https://github.com/pandas-dev/pandas/issues/14427 + expected = pd.concat( + [ + Series([1, 2, 3], name="C"), + Series([True, False, True], name="bad_a", dtype="Sparse[bool]"), + Series([False, True, False], name="bad_b", dtype="Sparse[bool]"), + Series([True, True, False], name="bad_b", dtype="Sparse[bool]"), + Series([False, False, True], name="bad_c", dtype="Sparse[bool]"), + ], + axis=1, + ) + + tm.assert_frame_equal(result, expected) + + def test_dataframe_dummies_subset(self, df, sparse): + result = get_dummies(df, prefix=["from_A"], columns=["A"], sparse=sparse) + expected = DataFrame( + { + "B": ["b", "b", "c"], + "C": [1, 2, 3], + "from_A_a": [1, 0, 1], + "from_A_b": [0, 1, 0], + }, + ) + cols = expected.columns + expected[cols[1:]] = expected[cols[1:]].astype(bool) + expected[["C"]] = df[["C"]] + if sparse: + cols = ["from_A_a", "from_A_b"] + expected[cols] = expected[cols].astype(SparseDtype("bool", False)) + tm.assert_frame_equal(result, expected) + + def test_dataframe_dummies_prefix_sep(self, df, sparse): + result = get_dummies(df, prefix_sep="..", sparse=sparse) + expected = DataFrame( + { + "C": [1, 2, 3], + "A..a": [True, False, True], + "A..b": [False, True, False], + "B..b": [True, True, False], + "B..c": [False, False, True], + }, + ) + expected[["C"]] = df[["C"]] + expected = expected[["C", "A..a", "A..b", "B..b", "B..c"]] + if sparse: + cols = ["A..a", "A..b", "B..b", "B..c"] + expected[cols] = expected[cols].astype(SparseDtype("bool", False)) + + tm.assert_frame_equal(result, expected) + + result = get_dummies(df, prefix_sep=["..", "__"], sparse=sparse) + expected = expected.rename(columns={"B..b": "B__b", "B..c": "B__c"}) + tm.assert_frame_equal(result, expected) + + result = get_dummies(df, prefix_sep={"A": "..", "B": "__"}, sparse=sparse) + tm.assert_frame_equal(result, expected) + + def test_dataframe_dummies_prefix_bad_length(self, df, sparse): + msg = re.escape( + "Length of 'prefix' (1) did not match the length of the columns being " + "encoded (2)" + ) + with pytest.raises(ValueError, match=msg): + get_dummies(df, prefix=["too few"], sparse=sparse) + + def test_dataframe_dummies_prefix_sep_bad_length(self, df, sparse): + msg = re.escape( + "Length of 'prefix_sep' (1) did not match the length of the columns being " + "encoded (2)" + ) + with pytest.raises(ValueError, match=msg): + get_dummies(df, prefix_sep=["bad"], sparse=sparse) + + def test_dataframe_dummies_prefix_dict(self, sparse): + prefixes = {"A": "from_A", "B": "from_B"} + df = DataFrame({"C": [1, 2, 3], "A": ["a", "b", "a"], "B": ["b", "b", "c"]}) + result = get_dummies(df, prefix=prefixes, sparse=sparse) + + expected = DataFrame( + { + "C": [1, 2, 3], + "from_A_a": [1, 0, 1], + "from_A_b": [0, 1, 0], + "from_B_b": [1, 1, 0], + "from_B_c": [0, 0, 1], + } + ) + + columns = ["from_A_a", "from_A_b", "from_B_b", "from_B_c"] + expected[columns] = expected[columns].astype(bool) + if sparse: + expected[columns] = expected[columns].astype(SparseDtype("bool", False)) + + tm.assert_frame_equal(result, expected) + + def test_dataframe_dummies_with_na(self, df, sparse, dtype): + df.loc[3, :] = [np.nan, np.nan, np.nan] + result = get_dummies(df, dummy_na=True, sparse=sparse, dtype=dtype).sort_index( + axis=1 + ) + + if sparse: + arr = SparseArray + if dtype.kind == "b": + typ = SparseDtype(dtype, False) + else: + typ = SparseDtype(dtype, 0) + else: + arr = np.array + typ = dtype + + expected = DataFrame( + { + "C": [1, 2, 3, np.nan], + "A_a": arr([1, 0, 1, 0], dtype=typ), + "A_b": arr([0, 1, 0, 0], dtype=typ), + "A_nan": arr([0, 0, 0, 1], dtype=typ), + "B_b": arr([1, 1, 0, 0], dtype=typ), + "B_c": arr([0, 0, 1, 0], dtype=typ), + "B_nan": arr([0, 0, 0, 1], dtype=typ), + } + ).sort_index(axis=1) + + tm.assert_frame_equal(result, expected) + + result = get_dummies(df, dummy_na=False, sparse=sparse, dtype=dtype) + expected = expected[["C", "A_a", "A_b", "B_b", "B_c"]] + tm.assert_frame_equal(result, expected) + + def test_dataframe_dummies_with_categorical(self, df, sparse, dtype): + df["cat"] = Categorical(["x", "y", "y"]) + result = get_dummies(df, sparse=sparse, dtype=dtype).sort_index(axis=1) + if sparse: + arr = SparseArray + if dtype.kind == "b": + typ = SparseDtype(dtype, False) + else: + typ = SparseDtype(dtype, 0) + else: + arr = np.array + typ = dtype + + expected = DataFrame( + { + "C": [1, 2, 3], + "A_a": arr([1, 0, 1], dtype=typ), + "A_b": arr([0, 1, 0], dtype=typ), + "B_b": arr([1, 1, 0], dtype=typ), + "B_c": arr([0, 0, 1], dtype=typ), + "cat_x": arr([1, 0, 0], dtype=typ), + "cat_y": arr([0, 1, 1], dtype=typ), + } + ).sort_index(axis=1) + + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "get_dummies_kwargs,expected", + [ + ( + {"data": DataFrame({"ä": ["a"]})}, + DataFrame({"ä_a": [True]}), + ), + ( + {"data": DataFrame({"x": ["ä"]})}, + DataFrame({"x_ä": [True]}), + ), + ( + {"data": DataFrame({"x": ["a"]}), "prefix": "ä"}, + DataFrame({"ä_a": [True]}), + ), + ( + {"data": DataFrame({"x": ["a"]}), "prefix_sep": "ä"}, + DataFrame({"xäa": [True]}), + ), + ], + ) + def test_dataframe_dummies_unicode(self, get_dummies_kwargs, expected): + # GH22084 get_dummies incorrectly encodes unicode characters + # in dataframe column names + result = get_dummies(**get_dummies_kwargs) + tm.assert_frame_equal(result, expected) + + def test_get_dummies_basic_drop_first(self, sparse): + # GH12402 Add a new parameter `drop_first` to avoid collinearity + # Basic case + s_list = list("abc") + s_series = Series(s_list) + s_series_index = Series(s_list, list("ABC")) + + expected = DataFrame({"b": [0, 1, 0], "c": [0, 0, 1]}, dtype=bool) + + result = get_dummies(s_list, drop_first=True, sparse=sparse) + if sparse: + expected = expected.apply(SparseArray, fill_value=False) + tm.assert_frame_equal(result, expected) + + result = get_dummies(s_series, drop_first=True, sparse=sparse) + tm.assert_frame_equal(result, expected) + + expected.index = list("ABC") + result = get_dummies(s_series_index, drop_first=True, sparse=sparse) + tm.assert_frame_equal(result, expected) + + def test_get_dummies_basic_drop_first_one_level(self, sparse): + # Test the case that categorical variable only has one level. + s_list = list("aaa") + s_series = Series(s_list) + s_series_index = Series(s_list, list("ABC")) + + expected = DataFrame(index=RangeIndex(3)) + + result = get_dummies(s_list, drop_first=True, sparse=sparse) + tm.assert_frame_equal(result, expected) + + result = get_dummies(s_series, drop_first=True, sparse=sparse) + tm.assert_frame_equal(result, expected) + + expected = DataFrame(index=list("ABC")) + result = get_dummies(s_series_index, drop_first=True, sparse=sparse) + tm.assert_frame_equal(result, expected) + + def test_get_dummies_basic_drop_first_NA(self, sparse): + # Test NA handling together with drop_first + s_NA = ["a", "b", np.nan] + res = get_dummies(s_NA, drop_first=True, sparse=sparse) + exp = DataFrame({"b": [0, 1, 0]}, dtype=bool) + if sparse: + exp = exp.apply(SparseArray, fill_value=False) + + tm.assert_frame_equal(res, exp) + + res_na = get_dummies(s_NA, dummy_na=True, drop_first=True, sparse=sparse) + exp_na = DataFrame({"b": [0, 1, 0], np.nan: [0, 0, 1]}, dtype=bool).reindex( + ["b", np.nan], axis=1 + ) + if sparse: + exp_na = exp_na.apply(SparseArray, fill_value=False) + tm.assert_frame_equal(res_na, exp_na) + + res_just_na = get_dummies( + [np.nan], dummy_na=True, drop_first=True, sparse=sparse + ) + exp_just_na = DataFrame(index=RangeIndex(1)) + tm.assert_frame_equal(res_just_na, exp_just_na) + + def test_dataframe_dummies_drop_first(self, df, sparse): + df = df[["A", "B"]] + result = get_dummies(df, drop_first=True, sparse=sparse) + expected = DataFrame({"A_b": [0, 1, 0], "B_c": [0, 0, 1]}, dtype=bool) + if sparse: + expected = expected.apply(SparseArray, fill_value=False) + tm.assert_frame_equal(result, expected) + + def test_dataframe_dummies_drop_first_with_categorical(self, df, sparse, dtype): + df["cat"] = Categorical(["x", "y", "y"]) + result = get_dummies(df, drop_first=True, sparse=sparse) + expected = DataFrame( + {"C": [1, 2, 3], "A_b": [0, 1, 0], "B_c": [0, 0, 1], "cat_y": [0, 1, 1]} + ) + cols = ["A_b", "B_c", "cat_y"] + expected[cols] = expected[cols].astype(bool) + expected = expected[["C", "A_b", "B_c", "cat_y"]] + if sparse: + for col in cols: + expected[col] = SparseArray(expected[col]) + tm.assert_frame_equal(result, expected) + + def test_dataframe_dummies_drop_first_with_na(self, df, sparse): + df.loc[3, :] = [np.nan, np.nan, np.nan] + result = get_dummies( + df, dummy_na=True, drop_first=True, sparse=sparse + ).sort_index(axis=1) + expected = DataFrame( + { + "C": [1, 2, 3, np.nan], + "A_b": [0, 1, 0, 0], + "A_nan": [0, 0, 0, 1], + "B_c": [0, 0, 1, 0], + "B_nan": [0, 0, 0, 1], + } + ) + cols = ["A_b", "A_nan", "B_c", "B_nan"] + expected[cols] = expected[cols].astype(bool) + expected = expected.sort_index(axis=1) + if sparse: + for col in cols: + expected[col] = SparseArray(expected[col]) + + tm.assert_frame_equal(result, expected) + + result = get_dummies(df, dummy_na=False, drop_first=True, sparse=sparse) + expected = expected[["C", "A_b", "B_c"]] + tm.assert_frame_equal(result, expected) + + def test_get_dummies_int_int(self): + data = Series([1, 2, 1]) + result = get_dummies(data) + expected = DataFrame([[1, 0], [0, 1], [1, 0]], columns=[1, 2], dtype=bool) + tm.assert_frame_equal(result, expected) + + data = Series(Categorical(["a", "b", "a"])) + result = get_dummies(data) + expected = DataFrame( + [[1, 0], [0, 1], [1, 0]], columns=Categorical(["a", "b"]), dtype=bool + ) + tm.assert_frame_equal(result, expected) + + def test_get_dummies_int_df(self, dtype): + data = DataFrame( + { + "A": [1, 2, 1], + "B": Categorical(["a", "b", "a"]), + "C": [1, 2, 1], + "D": [1.0, 2.0, 1.0], + } + ) + columns = ["C", "D", "A_1", "A_2", "B_a", "B_b"] + expected = DataFrame( + [[1, 1.0, 1, 0, 1, 0], [2, 2.0, 0, 1, 0, 1], [1, 1.0, 1, 0, 1, 0]], + columns=columns, + ) + expected[columns[2:]] = expected[columns[2:]].astype(dtype) + result = get_dummies(data, columns=["A", "B"], dtype=dtype) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("ordered", [True, False]) + def test_dataframe_dummies_preserve_categorical_dtype(self, dtype, ordered): + # GH13854 + cat = Categorical(list("xy"), categories=list("xyz"), ordered=ordered) + result = get_dummies(cat, dtype=dtype) + + data = np.array([[1, 0, 0], [0, 1, 0]], dtype=self.effective_dtype(dtype)) + cols = CategoricalIndex( + cat.categories, categories=cat.categories, ordered=ordered + ) + expected = DataFrame(data, columns=cols, dtype=self.effective_dtype(dtype)) + + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("sparse", [True, False]) + def test_get_dummies_dont_sparsify_all_columns(self, sparse): + # GH18914 + df = DataFrame.from_dict({"GDP": [1, 2], "Nation": ["AB", "CD"]}) + df = get_dummies(df, columns=["Nation"], sparse=sparse) + df2 = df.reindex(columns=["GDP"]) + + tm.assert_frame_equal(df[["GDP"]], df2) + + def test_get_dummies_duplicate_columns(self, df): + # GH20839 + df.columns = ["A", "A", "A"] + result = get_dummies(df).sort_index(axis=1) + + expected = DataFrame( + [ + [1, True, False, True, False], + [2, False, True, True, False], + [3, True, False, False, True], + ], + columns=["A", "A_a", "A_b", "A_b", "A_c"], + ).sort_index(axis=1) + + expected = expected.astype({"A": np.int64}) + + tm.assert_frame_equal(result, expected) + + def test_get_dummies_all_sparse(self): + df = DataFrame({"A": [1, 2]}) + result = get_dummies(df, columns=["A"], sparse=True) + dtype = SparseDtype("bool", False) + expected = DataFrame( + { + "A_1": SparseArray([1, 0], dtype=dtype), + "A_2": SparseArray([0, 1], dtype=dtype), + } + ) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("values", ["baz"]) + def test_get_dummies_with_string_values(self, values): + # issue #28383 + df = DataFrame( + { + "bar": [1, 2, 3, 4, 5, 6], + "foo": ["one", "one", "one", "two", "two", "two"], + "baz": ["A", "B", "C", "A", "B", "C"], + "zoo": ["x", "y", "z", "q", "w", "t"], + } + ) + + msg = "Input must be a list-like for parameter `columns`" + + with pytest.raises(TypeError, match=msg): + get_dummies(df, columns=values) + + def test_get_dummies_ea_dtype_series(self, any_numeric_ea_and_arrow_dtype): + # GH#32430 + ser = Series(list("abca")) + result = get_dummies(ser, dtype=any_numeric_ea_and_arrow_dtype) + expected = DataFrame( + {"a": [1, 0, 0, 1], "b": [0, 1, 0, 0], "c": [0, 0, 1, 0]}, + dtype=any_numeric_ea_and_arrow_dtype, + ) + tm.assert_frame_equal(result, expected) + + def test_get_dummies_ea_dtype_dataframe(self, any_numeric_ea_and_arrow_dtype): + # GH#32430 + df = DataFrame({"x": list("abca")}) + result = get_dummies(df, dtype=any_numeric_ea_and_arrow_dtype) + expected = DataFrame( + {"x_a": [1, 0, 0, 1], "x_b": [0, 1, 0, 0], "x_c": [0, 0, 1, 0]}, + dtype=any_numeric_ea_and_arrow_dtype, + ) + tm.assert_frame_equal(result, expected) + + @td.skip_if_no("pyarrow") + def test_get_dummies_ea_dtype(self): + # GH#56273 + for dtype, exp_dtype in [ + ("string[pyarrow]", "boolean"), + ("string[pyarrow_numpy]", "bool"), + (CategoricalDtype(Index(["a"], dtype="string[pyarrow]")), "boolean"), + (CategoricalDtype(Index(["a"], dtype="string[pyarrow_numpy]")), "bool"), + ]: + df = DataFrame({"name": Series(["a"], dtype=dtype), "x": 1}) + result = get_dummies(df) + expected = DataFrame({"x": 1, "name_a": Series([True], dtype=exp_dtype)}) + tm.assert_frame_equal(result, expected) + + @td.skip_if_no("pyarrow") + def test_get_dummies_arrow_dtype(self): + # GH#56273 + df = DataFrame({"name": Series(["a"], dtype=ArrowDtype(pa.string())), "x": 1}) + result = get_dummies(df) + expected = DataFrame({"x": 1, "name_a": Series([True], dtype="bool[pyarrow]")}) + tm.assert_frame_equal(result, expected) + + df = DataFrame( + { + "name": Series( + ["a"], + dtype=CategoricalDtype(Index(["a"], dtype=ArrowDtype(pa.string()))), + ), + "x": 1, + } + ) + result = get_dummies(df) + tm.assert_frame_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/test_melt.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/test_melt.py new file mode 100644 index 0000000000000000000000000000000000000000..272c5b34032937215b3ffff45ea10a9e31048041 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/test_melt.py @@ -0,0 +1,1252 @@ +import re + +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + DataFrame, + Index, + date_range, + lreshape, + melt, + wide_to_long, +) +import pandas._testing as tm + + +@pytest.fixture +def df(): + res = DataFrame( + np.random.default_rng(2).standard_normal((10, 4)), + columns=Index(list("ABCD"), dtype=object), + index=date_range("2000-01-01", periods=10, freq="B"), + ) + res["id1"] = (res["A"] > 0).astype(np.int64) + res["id2"] = (res["B"] > 0).astype(np.int64) + return res + + +@pytest.fixture +def df1(): + res = DataFrame( + [ + [1.067683, -1.110463, 0.20867], + [-1.321405, 0.368915, -1.055342], + [-0.807333, 0.08298, -0.873361], + ] + ) + res.columns = [list("ABC"), list("abc")] + res.columns.names = ["CAP", "low"] + return res + + +@pytest.fixture +def var_name(): + return "var" + + +@pytest.fixture +def value_name(): + return "val" + + +class TestMelt: + def test_top_level_method(self, df): + result = melt(df) + assert result.columns.tolist() == ["variable", "value"] + + def test_method_signatures(self, df, df1, var_name, value_name): + tm.assert_frame_equal(df.melt(), melt(df)) + + tm.assert_frame_equal( + df.melt(id_vars=["id1", "id2"], value_vars=["A", "B"]), + melt(df, id_vars=["id1", "id2"], value_vars=["A", "B"]), + ) + + tm.assert_frame_equal( + df.melt(var_name=var_name, value_name=value_name), + melt(df, var_name=var_name, value_name=value_name), + ) + + tm.assert_frame_equal(df1.melt(col_level=0), melt(df1, col_level=0)) + + def test_default_col_names(self, df): + result = df.melt() + assert result.columns.tolist() == ["variable", "value"] + + result1 = df.melt(id_vars=["id1"]) + assert result1.columns.tolist() == ["id1", "variable", "value"] + + result2 = df.melt(id_vars=["id1", "id2"]) + assert result2.columns.tolist() == ["id1", "id2", "variable", "value"] + + def test_value_vars(self, df): + result3 = df.melt(id_vars=["id1", "id2"], value_vars="A") + assert len(result3) == 10 + + result4 = df.melt(id_vars=["id1", "id2"], value_vars=["A", "B"]) + expected4 = DataFrame( + { + "id1": df["id1"].tolist() * 2, + "id2": df["id2"].tolist() * 2, + "variable": ["A"] * 10 + ["B"] * 10, + "value": (df["A"].tolist() + df["B"].tolist()), + }, + columns=["id1", "id2", "variable", "value"], + ) + tm.assert_frame_equal(result4, expected4) + + @pytest.mark.parametrize("type_", (tuple, list, np.array)) + def test_value_vars_types(self, type_, df): + # GH 15348 + expected = DataFrame( + { + "id1": df["id1"].tolist() * 2, + "id2": df["id2"].tolist() * 2, + "variable": ["A"] * 10 + ["B"] * 10, + "value": (df["A"].tolist() + df["B"].tolist()), + }, + columns=["id1", "id2", "variable", "value"], + ) + result = df.melt(id_vars=["id1", "id2"], value_vars=type_(("A", "B"))) + tm.assert_frame_equal(result, expected) + + def test_vars_work_with_multiindex(self, df1): + expected = DataFrame( + { + ("A", "a"): df1[("A", "a")], + "CAP": ["B"] * len(df1), + "low": ["b"] * len(df1), + "value": df1[("B", "b")], + }, + columns=[("A", "a"), "CAP", "low", "value"], + ) + + result = df1.melt(id_vars=[("A", "a")], value_vars=[("B", "b")]) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "id_vars, value_vars, col_level, expected", + [ + ( + ["A"], + ["B"], + 0, + DataFrame( + { + "A": {0: 1.067683, 1: -1.321405, 2: -0.807333}, + "CAP": {0: "B", 1: "B", 2: "B"}, + "value": {0: -1.110463, 1: 0.368915, 2: 0.08298}, + } + ), + ), + ( + ["a"], + ["b"], + 1, + DataFrame( + { + "a": {0: 1.067683, 1: -1.321405, 2: -0.807333}, + "low": {0: "b", 1: "b", 2: "b"}, + "value": {0: -1.110463, 1: 0.368915, 2: 0.08298}, + } + ), + ), + ], + ) + def test_single_vars_work_with_multiindex( + self, id_vars, value_vars, col_level, expected, df1 + ): + result = df1.melt(id_vars, value_vars, col_level=col_level) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "id_vars, value_vars", + [ + [("A", "a"), [("B", "b")]], + [[("A", "a")], ("B", "b")], + [("A", "a"), ("B", "b")], + ], + ) + def test_tuple_vars_fail_with_multiindex(self, id_vars, value_vars, df1): + # melt should fail with an informative error message if + # the columns have a MultiIndex and a tuple is passed + # for id_vars or value_vars. + msg = r"(id|value)_vars must be a list of tuples when columns are a MultiIndex" + with pytest.raises(ValueError, match=msg): + df1.melt(id_vars=id_vars, value_vars=value_vars) + + def test_custom_var_name(self, df, var_name): + result5 = df.melt(var_name=var_name) + assert result5.columns.tolist() == ["var", "value"] + + result6 = df.melt(id_vars=["id1"], var_name=var_name) + assert result6.columns.tolist() == ["id1", "var", "value"] + + result7 = df.melt(id_vars=["id1", "id2"], var_name=var_name) + assert result7.columns.tolist() == ["id1", "id2", "var", "value"] + + result8 = df.melt(id_vars=["id1", "id2"], value_vars="A", var_name=var_name) + assert result8.columns.tolist() == ["id1", "id2", "var", "value"] + + result9 = df.melt( + id_vars=["id1", "id2"], value_vars=["A", "B"], var_name=var_name + ) + expected9 = DataFrame( + { + "id1": df["id1"].tolist() * 2, + "id2": df["id2"].tolist() * 2, + var_name: ["A"] * 10 + ["B"] * 10, + "value": (df["A"].tolist() + df["B"].tolist()), + }, + columns=["id1", "id2", var_name, "value"], + ) + tm.assert_frame_equal(result9, expected9) + + def test_custom_value_name(self, df, value_name): + result10 = df.melt(value_name=value_name) + assert result10.columns.tolist() == ["variable", "val"] + + result11 = df.melt(id_vars=["id1"], value_name=value_name) + assert result11.columns.tolist() == ["id1", "variable", "val"] + + result12 = df.melt(id_vars=["id1", "id2"], value_name=value_name) + assert result12.columns.tolist() == ["id1", "id2", "variable", "val"] + + result13 = df.melt( + id_vars=["id1", "id2"], value_vars="A", value_name=value_name + ) + assert result13.columns.tolist() == ["id1", "id2", "variable", "val"] + + result14 = df.melt( + id_vars=["id1", "id2"], value_vars=["A", "B"], value_name=value_name + ) + expected14 = DataFrame( + { + "id1": df["id1"].tolist() * 2, + "id2": df["id2"].tolist() * 2, + "variable": ["A"] * 10 + ["B"] * 10, + value_name: (df["A"].tolist() + df["B"].tolist()), + }, + columns=["id1", "id2", "variable", value_name], + ) + tm.assert_frame_equal(result14, expected14) + + def test_custom_var_and_value_name(self, df, value_name, var_name): + result15 = df.melt(var_name=var_name, value_name=value_name) + assert result15.columns.tolist() == ["var", "val"] + + result16 = df.melt(id_vars=["id1"], var_name=var_name, value_name=value_name) + assert result16.columns.tolist() == ["id1", "var", "val"] + + result17 = df.melt( + id_vars=["id1", "id2"], var_name=var_name, value_name=value_name + ) + assert result17.columns.tolist() == ["id1", "id2", "var", "val"] + + result18 = df.melt( + id_vars=["id1", "id2"], + value_vars="A", + var_name=var_name, + value_name=value_name, + ) + assert result18.columns.tolist() == ["id1", "id2", "var", "val"] + + result19 = df.melt( + id_vars=["id1", "id2"], + value_vars=["A", "B"], + var_name=var_name, + value_name=value_name, + ) + expected19 = DataFrame( + { + "id1": df["id1"].tolist() * 2, + "id2": df["id2"].tolist() * 2, + var_name: ["A"] * 10 + ["B"] * 10, + value_name: (df["A"].tolist() + df["B"].tolist()), + }, + columns=["id1", "id2", var_name, value_name], + ) + tm.assert_frame_equal(result19, expected19) + + df20 = df.copy() + df20.columns.name = "foo" + result20 = df20.melt() + assert result20.columns.tolist() == ["foo", "value"] + + @pytest.mark.parametrize("col_level", [0, "CAP"]) + def test_col_level(self, col_level, df1): + res = df1.melt(col_level=col_level) + assert res.columns.tolist() == ["CAP", "value"] + + def test_multiindex(self, df1): + res = df1.melt() + assert res.columns.tolist() == ["CAP", "low", "value"] + + @pytest.mark.parametrize( + "col", + [ + pd.Series(date_range("2010", periods=5, tz="US/Pacific")), + pd.Series(["a", "b", "c", "a", "d"], dtype="category"), + pd.Series([0, 1, 0, 0, 0]), + ], + ) + def test_pandas_dtypes(self, col): + # GH 15785 + df = DataFrame( + {"klass": range(5), "col": col, "attr1": [1, 0, 0, 0, 0], "attr2": col} + ) + expected_value = pd.concat([pd.Series([1, 0, 0, 0, 0]), col], ignore_index=True) + result = melt( + df, id_vars=["klass", "col"], var_name="attribute", value_name="value" + ) + expected = DataFrame( + { + 0: list(range(5)) * 2, + 1: pd.concat([col] * 2, ignore_index=True), + 2: ["attr1"] * 5 + ["attr2"] * 5, + 3: expected_value, + } + ) + expected.columns = ["klass", "col", "attribute", "value"] + tm.assert_frame_equal(result, expected) + + def test_preserve_category(self): + # GH 15853 + data = DataFrame({"A": [1, 2], "B": pd.Categorical(["X", "Y"])}) + result = melt(data, ["B"], ["A"]) + expected = DataFrame( + {"B": pd.Categorical(["X", "Y"]), "variable": ["A", "A"], "value": [1, 2]} + ) + + tm.assert_frame_equal(result, expected) + + def test_melt_missing_columns_raises(self): + # GH-23575 + # This test is to ensure that pandas raises an error if melting is + # attempted with column names absent from the dataframe + + # Generate data + df = DataFrame( + np.random.default_rng(2).standard_normal((5, 4)), columns=list("abcd") + ) + + # Try to melt with missing `value_vars` column name + msg = "The following id_vars or value_vars are not present in the DataFrame:" + with pytest.raises(KeyError, match=msg): + df.melt(["a", "b"], ["C", "d"]) + + # Try to melt with missing `id_vars` column name + with pytest.raises(KeyError, match=msg): + df.melt(["A", "b"], ["c", "d"]) + + # Multiple missing + with pytest.raises( + KeyError, + match=msg, + ): + df.melt(["a", "b", "not_here", "or_there"], ["c", "d"]) + + # Multiindex melt fails if column is missing from multilevel melt + multi = df.copy() + multi.columns = [list("ABCD"), list("abcd")] + with pytest.raises(KeyError, match=msg): + multi.melt([("E", "a")], [("B", "b")]) + # Multiindex fails if column is missing from single level melt + with pytest.raises(KeyError, match=msg): + multi.melt(["A"], ["F"], col_level=0) + + def test_melt_mixed_int_str_id_vars(self): + # GH 29718 + df = DataFrame({0: ["foo"], "a": ["bar"], "b": [1], "d": [2]}) + result = melt(df, id_vars=[0, "a"], value_vars=["b", "d"]) + expected = DataFrame( + {0: ["foo"] * 2, "a": ["bar"] * 2, "variable": list("bd"), "value": [1, 2]} + ) + tm.assert_frame_equal(result, expected) + + def test_melt_mixed_int_str_value_vars(self): + # GH 29718 + df = DataFrame({0: ["foo"], "a": ["bar"]}) + result = melt(df, value_vars=[0, "a"]) + expected = DataFrame({"variable": [0, "a"], "value": ["foo", "bar"]}) + tm.assert_frame_equal(result, expected) + + def test_ignore_index(self): + # GH 17440 + df = DataFrame({"foo": [0], "bar": [1]}, index=["first"]) + result = melt(df, ignore_index=False) + expected = DataFrame( + {"variable": ["foo", "bar"], "value": [0, 1]}, index=["first", "first"] + ) + tm.assert_frame_equal(result, expected) + + def test_ignore_multiindex(self): + # GH 17440 + index = pd.MultiIndex.from_tuples( + [("first", "second"), ("first", "third")], names=["baz", "foobar"] + ) + df = DataFrame({"foo": [0, 1], "bar": [2, 3]}, index=index) + result = melt(df, ignore_index=False) + + expected_index = pd.MultiIndex.from_tuples( + [("first", "second"), ("first", "third")] * 2, names=["baz", "foobar"] + ) + expected = DataFrame( + {"variable": ["foo"] * 2 + ["bar"] * 2, "value": [0, 1, 2, 3]}, + index=expected_index, + ) + + tm.assert_frame_equal(result, expected) + + def test_ignore_index_name_and_type(self): + # GH 17440 + index = Index(["foo", "bar"], dtype="category", name="baz") + df = DataFrame({"x": [0, 1], "y": [2, 3]}, index=index) + result = melt(df, ignore_index=False) + + expected_index = Index(["foo", "bar"] * 2, dtype="category", name="baz") + expected = DataFrame( + {"variable": ["x", "x", "y", "y"], "value": [0, 1, 2, 3]}, + index=expected_index, + ) + + tm.assert_frame_equal(result, expected) + + def test_melt_with_duplicate_columns(self): + # GH#41951 + df = DataFrame([["id", 2, 3]], columns=["a", "b", "b"]) + result = df.melt(id_vars=["a"], value_vars=["b"]) + expected = DataFrame( + [["id", "b", 2], ["id", "b", 3]], columns=["a", "variable", "value"] + ) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("dtype", ["Int8", "Int64"]) + def test_melt_ea_dtype(self, dtype): + # GH#41570 + df = DataFrame( + { + "a": pd.Series([1, 2], dtype="Int8"), + "b": pd.Series([3, 4], dtype=dtype), + } + ) + result = df.melt() + expected = DataFrame( + { + "variable": ["a", "a", "b", "b"], + "value": pd.Series([1, 2, 3, 4], dtype=dtype), + } + ) + tm.assert_frame_equal(result, expected) + + def test_melt_ea_columns(self): + # GH 54297 + df = DataFrame( + { + "A": {0: "a", 1: "b", 2: "c"}, + "B": {0: 1, 1: 3, 2: 5}, + "C": {0: 2, 1: 4, 2: 6}, + } + ) + df.columns = df.columns.astype("string[python]") + result = df.melt(id_vars=["A"], value_vars=["B"]) + expected = DataFrame( + { + "A": list("abc"), + "variable": pd.Series(["B"] * 3, dtype="string[python]"), + "value": [1, 3, 5], + } + ) + tm.assert_frame_equal(result, expected) + + def test_melt_preserves_datetime(self): + df = DataFrame( + data=[ + { + "type": "A0", + "start_date": pd.Timestamp("2023/03/01", tz="Asia/Tokyo"), + "end_date": pd.Timestamp("2023/03/10", tz="Asia/Tokyo"), + }, + { + "type": "A1", + "start_date": pd.Timestamp("2023/03/01", tz="Asia/Tokyo"), + "end_date": pd.Timestamp("2023/03/11", tz="Asia/Tokyo"), + }, + ], + index=["aaaa", "bbbb"], + ) + result = df.melt( + id_vars=["type"], + value_vars=["start_date", "end_date"], + var_name="start/end", + value_name="date", + ) + expected = DataFrame( + { + "type": {0: "A0", 1: "A1", 2: "A0", 3: "A1"}, + "start/end": { + 0: "start_date", + 1: "start_date", + 2: "end_date", + 3: "end_date", + }, + "date": { + 0: pd.Timestamp("2023-03-01 00:00:00+0900", tz="Asia/Tokyo"), + 1: pd.Timestamp("2023-03-01 00:00:00+0900", tz="Asia/Tokyo"), + 2: pd.Timestamp("2023-03-10 00:00:00+0900", tz="Asia/Tokyo"), + 3: pd.Timestamp("2023-03-11 00:00:00+0900", tz="Asia/Tokyo"), + }, + } + ) + tm.assert_frame_equal(result, expected) + + def test_melt_allows_non_scalar_id_vars(self): + df = DataFrame( + data={"a": [1, 2, 3], "b": [4, 5, 6]}, + index=["11", "22", "33"], + ) + result = df.melt( + id_vars="a", + var_name=0, + value_name=1, + ) + expected = DataFrame({"a": [1, 2, 3], 0: ["b"] * 3, 1: [4, 5, 6]}) + tm.assert_frame_equal(result, expected) + + def test_melt_allows_non_string_var_name(self): + df = DataFrame( + data={"a": [1, 2, 3], "b": [4, 5, 6]}, + index=["11", "22", "33"], + ) + result = df.melt( + id_vars=["a"], + var_name=0, + value_name=1, + ) + expected = DataFrame({"a": [1, 2, 3], 0: ["b"] * 3, 1: [4, 5, 6]}) + tm.assert_frame_equal(result, expected) + + def test_melt_non_scalar_var_name_raises(self): + df = DataFrame( + data={"a": [1, 2, 3], "b": [4, 5, 6]}, + index=["11", "22", "33"], + ) + with pytest.raises(ValueError, match=r".* must be a scalar."): + df.melt(id_vars=["a"], var_name=[1, 2]) + + +class TestLreshape: + def test_pairs(self): + data = { + "birthdt": [ + "08jan2009", + "20dec2008", + "30dec2008", + "21dec2008", + "11jan2009", + ], + "birthwt": [1766, 3301, 1454, 3139, 4133], + "id": [101, 102, 103, 104, 105], + "sex": ["Male", "Female", "Female", "Female", "Female"], + "visitdt1": [ + "11jan2009", + "22dec2008", + "04jan2009", + "29dec2008", + "20jan2009", + ], + "visitdt2": ["21jan2009", np.nan, "22jan2009", "31dec2008", "03feb2009"], + "visitdt3": ["05feb2009", np.nan, np.nan, "02jan2009", "15feb2009"], + "wt1": [1823, 3338, 1549, 3298, 4306], + "wt2": [2011.0, np.nan, 1892.0, 3338.0, 4575.0], + "wt3": [2293.0, np.nan, np.nan, 3377.0, 4805.0], + } + + df = DataFrame(data) + + spec = { + "visitdt": [f"visitdt{i:d}" for i in range(1, 4)], + "wt": [f"wt{i:d}" for i in range(1, 4)], + } + result = lreshape(df, spec) + + exp_data = { + "birthdt": [ + "08jan2009", + "20dec2008", + "30dec2008", + "21dec2008", + "11jan2009", + "08jan2009", + "30dec2008", + "21dec2008", + "11jan2009", + "08jan2009", + "21dec2008", + "11jan2009", + ], + "birthwt": [ + 1766, + 3301, + 1454, + 3139, + 4133, + 1766, + 1454, + 3139, + 4133, + 1766, + 3139, + 4133, + ], + "id": [101, 102, 103, 104, 105, 101, 103, 104, 105, 101, 104, 105], + "sex": [ + "Male", + "Female", + "Female", + "Female", + "Female", + "Male", + "Female", + "Female", + "Female", + "Male", + "Female", + "Female", + ], + "visitdt": [ + "11jan2009", + "22dec2008", + "04jan2009", + "29dec2008", + "20jan2009", + "21jan2009", + "22jan2009", + "31dec2008", + "03feb2009", + "05feb2009", + "02jan2009", + "15feb2009", + ], + "wt": [ + 1823.0, + 3338.0, + 1549.0, + 3298.0, + 4306.0, + 2011.0, + 1892.0, + 3338.0, + 4575.0, + 2293.0, + 3377.0, + 4805.0, + ], + } + exp = DataFrame(exp_data, columns=result.columns) + tm.assert_frame_equal(result, exp) + + result = lreshape(df, spec, dropna=False) + exp_data = { + "birthdt": [ + "08jan2009", + "20dec2008", + "30dec2008", + "21dec2008", + "11jan2009", + "08jan2009", + "20dec2008", + "30dec2008", + "21dec2008", + "11jan2009", + "08jan2009", + "20dec2008", + "30dec2008", + "21dec2008", + "11jan2009", + ], + "birthwt": [ + 1766, + 3301, + 1454, + 3139, + 4133, + 1766, + 3301, + 1454, + 3139, + 4133, + 1766, + 3301, + 1454, + 3139, + 4133, + ], + "id": [ + 101, + 102, + 103, + 104, + 105, + 101, + 102, + 103, + 104, + 105, + 101, + 102, + 103, + 104, + 105, + ], + "sex": [ + "Male", + "Female", + "Female", + "Female", + "Female", + "Male", + "Female", + "Female", + "Female", + "Female", + "Male", + "Female", + "Female", + "Female", + "Female", + ], + "visitdt": [ + "11jan2009", + "22dec2008", + "04jan2009", + "29dec2008", + "20jan2009", + "21jan2009", + np.nan, + "22jan2009", + "31dec2008", + "03feb2009", + "05feb2009", + np.nan, + np.nan, + "02jan2009", + "15feb2009", + ], + "wt": [ + 1823.0, + 3338.0, + 1549.0, + 3298.0, + 4306.0, + 2011.0, + np.nan, + 1892.0, + 3338.0, + 4575.0, + 2293.0, + np.nan, + np.nan, + 3377.0, + 4805.0, + ], + } + exp = DataFrame(exp_data, columns=result.columns) + tm.assert_frame_equal(result, exp) + + spec = { + "visitdt": [f"visitdt{i:d}" for i in range(1, 3)], + "wt": [f"wt{i:d}" for i in range(1, 4)], + } + msg = "All column lists must be same length" + with pytest.raises(ValueError, match=msg): + lreshape(df, spec) + + +class TestWideToLong: + def test_simple(self): + x = np.random.default_rng(2).standard_normal(3) + df = DataFrame( + { + "A1970": {0: "a", 1: "b", 2: "c"}, + "A1980": {0: "d", 1: "e", 2: "f"}, + "B1970": {0: 2.5, 1: 1.2, 2: 0.7}, + "B1980": {0: 3.2, 1: 1.3, 2: 0.1}, + "X": dict(zip(range(3), x)), + } + ) + df["id"] = df.index + exp_data = { + "X": x.tolist() + x.tolist(), + "A": ["a", "b", "c", "d", "e", "f"], + "B": [2.5, 1.2, 0.7, 3.2, 1.3, 0.1], + "year": [1970, 1970, 1970, 1980, 1980, 1980], + "id": [0, 1, 2, 0, 1, 2], + } + expected = DataFrame(exp_data) + expected = expected.set_index(["id", "year"])[["X", "A", "B"]] + result = wide_to_long(df, ["A", "B"], i="id", j="year") + tm.assert_frame_equal(result, expected) + + def test_stubs(self): + # GH9204 wide_to_long call should not modify 'stubs' list + df = DataFrame([[0, 1, 2, 3, 8], [4, 5, 6, 7, 9]]) + df.columns = ["id", "inc1", "inc2", "edu1", "edu2"] + stubs = ["inc", "edu"] + + wide_to_long(df, stubs, i="id", j="age") + + assert stubs == ["inc", "edu"] + + def test_separating_character(self): + # GH14779 + + x = np.random.default_rng(2).standard_normal(3) + df = DataFrame( + { + "A.1970": {0: "a", 1: "b", 2: "c"}, + "A.1980": {0: "d", 1: "e", 2: "f"}, + "B.1970": {0: 2.5, 1: 1.2, 2: 0.7}, + "B.1980": {0: 3.2, 1: 1.3, 2: 0.1}, + "X": dict(zip(range(3), x)), + } + ) + df["id"] = df.index + exp_data = { + "X": x.tolist() + x.tolist(), + "A": ["a", "b", "c", "d", "e", "f"], + "B": [2.5, 1.2, 0.7, 3.2, 1.3, 0.1], + "year": [1970, 1970, 1970, 1980, 1980, 1980], + "id": [0, 1, 2, 0, 1, 2], + } + expected = DataFrame(exp_data) + expected = expected.set_index(["id", "year"])[["X", "A", "B"]] + result = wide_to_long(df, ["A", "B"], i="id", j="year", sep=".") + tm.assert_frame_equal(result, expected) + + def test_escapable_characters(self): + x = np.random.default_rng(2).standard_normal(3) + df = DataFrame( + { + "A(quarterly)1970": {0: "a", 1: "b", 2: "c"}, + "A(quarterly)1980": {0: "d", 1: "e", 2: "f"}, + "B(quarterly)1970": {0: 2.5, 1: 1.2, 2: 0.7}, + "B(quarterly)1980": {0: 3.2, 1: 1.3, 2: 0.1}, + "X": dict(zip(range(3), x)), + } + ) + df["id"] = df.index + exp_data = { + "X": x.tolist() + x.tolist(), + "A(quarterly)": ["a", "b", "c", "d", "e", "f"], + "B(quarterly)": [2.5, 1.2, 0.7, 3.2, 1.3, 0.1], + "year": [1970, 1970, 1970, 1980, 1980, 1980], + "id": [0, 1, 2, 0, 1, 2], + } + expected = DataFrame(exp_data) + expected = expected.set_index(["id", "year"])[ + ["X", "A(quarterly)", "B(quarterly)"] + ] + result = wide_to_long(df, ["A(quarterly)", "B(quarterly)"], i="id", j="year") + tm.assert_frame_equal(result, expected) + + def test_unbalanced(self): + # test that we can have a varying amount of time variables + df = DataFrame( + { + "A2010": [1.0, 2.0], + "A2011": [3.0, 4.0], + "B2010": [5.0, 6.0], + "X": ["X1", "X2"], + } + ) + df["id"] = df.index + exp_data = { + "X": ["X1", "X2", "X1", "X2"], + "A": [1.0, 2.0, 3.0, 4.0], + "B": [5.0, 6.0, np.nan, np.nan], + "id": [0, 1, 0, 1], + "year": [2010, 2010, 2011, 2011], + } + expected = DataFrame(exp_data) + expected = expected.set_index(["id", "year"])[["X", "A", "B"]] + result = wide_to_long(df, ["A", "B"], i="id", j="year") + tm.assert_frame_equal(result, expected) + + def test_character_overlap(self): + # Test we handle overlapping characters in both id_vars and value_vars + df = DataFrame( + { + "A11": ["a11", "a22", "a33"], + "A12": ["a21", "a22", "a23"], + "B11": ["b11", "b12", "b13"], + "B12": ["b21", "b22", "b23"], + "BB11": [1, 2, 3], + "BB12": [4, 5, 6], + "BBBX": [91, 92, 93], + "BBBZ": [91, 92, 93], + } + ) + df["id"] = df.index + expected = DataFrame( + { + "BBBX": [91, 92, 93, 91, 92, 93], + "BBBZ": [91, 92, 93, 91, 92, 93], + "A": ["a11", "a22", "a33", "a21", "a22", "a23"], + "B": ["b11", "b12", "b13", "b21", "b22", "b23"], + "BB": [1, 2, 3, 4, 5, 6], + "id": [0, 1, 2, 0, 1, 2], + "year": [11, 11, 11, 12, 12, 12], + } + ) + expected = expected.set_index(["id", "year"])[["BBBX", "BBBZ", "A", "B", "BB"]] + result = wide_to_long(df, ["A", "B", "BB"], i="id", j="year") + tm.assert_frame_equal(result.sort_index(axis=1), expected.sort_index(axis=1)) + + def test_invalid_separator(self): + # if an invalid separator is supplied a empty data frame is returned + sep = "nope!" + df = DataFrame( + { + "A2010": [1.0, 2.0], + "A2011": [3.0, 4.0], + "B2010": [5.0, 6.0], + "X": ["X1", "X2"], + } + ) + df["id"] = df.index + exp_data = { + "X": "", + "A2010": [], + "A2011": [], + "B2010": [], + "id": [], + "year": [], + "A": [], + "B": [], + } + expected = DataFrame(exp_data).astype({"year": np.int64}) + expected = expected.set_index(["id", "year"])[ + ["X", "A2010", "A2011", "B2010", "A", "B"] + ] + expected.index = expected.index.set_levels([0, 1], level=0) + result = wide_to_long(df, ["A", "B"], i="id", j="year", sep=sep) + tm.assert_frame_equal(result.sort_index(axis=1), expected.sort_index(axis=1)) + + def test_num_string_disambiguation(self): + # Test that we can disambiguate number value_vars from + # string value_vars + df = DataFrame( + { + "A11": ["a11", "a22", "a33"], + "A12": ["a21", "a22", "a23"], + "B11": ["b11", "b12", "b13"], + "B12": ["b21", "b22", "b23"], + "BB11": [1, 2, 3], + "BB12": [4, 5, 6], + "Arating": [91, 92, 93], + "Arating_old": [91, 92, 93], + } + ) + df["id"] = df.index + expected = DataFrame( + { + "Arating": [91, 92, 93, 91, 92, 93], + "Arating_old": [91, 92, 93, 91, 92, 93], + "A": ["a11", "a22", "a33", "a21", "a22", "a23"], + "B": ["b11", "b12", "b13", "b21", "b22", "b23"], + "BB": [1, 2, 3, 4, 5, 6], + "id": [0, 1, 2, 0, 1, 2], + "year": [11, 11, 11, 12, 12, 12], + } + ) + expected = expected.set_index(["id", "year"])[ + ["Arating", "Arating_old", "A", "B", "BB"] + ] + result = wide_to_long(df, ["A", "B", "BB"], i="id", j="year") + tm.assert_frame_equal(result.sort_index(axis=1), expected.sort_index(axis=1)) + + def test_invalid_suffixtype(self): + # If all stubs names end with a string, but a numeric suffix is + # assumed, an empty data frame is returned + df = DataFrame( + { + "Aone": [1.0, 2.0], + "Atwo": [3.0, 4.0], + "Bone": [5.0, 6.0], + "X": ["X1", "X2"], + } + ) + df["id"] = df.index + exp_data = { + "X": "", + "Aone": [], + "Atwo": [], + "Bone": [], + "id": [], + "year": [], + "A": [], + "B": [], + } + expected = DataFrame(exp_data).astype({"year": np.int64}) + + expected = expected.set_index(["id", "year"]) + expected.index = expected.index.set_levels([0, 1], level=0) + result = wide_to_long(df, ["A", "B"], i="id", j="year") + tm.assert_frame_equal(result.sort_index(axis=1), expected.sort_index(axis=1)) + + def test_multiple_id_columns(self): + # Taken from http://www.ats.ucla.edu/stat/stata/modules/reshapel.htm + df = DataFrame( + { + "famid": [1, 1, 1, 2, 2, 2, 3, 3, 3], + "birth": [1, 2, 3, 1, 2, 3, 1, 2, 3], + "ht1": [2.8, 2.9, 2.2, 2, 1.8, 1.9, 2.2, 2.3, 2.1], + "ht2": [3.4, 3.8, 2.9, 3.2, 2.8, 2.4, 3.3, 3.4, 2.9], + } + ) + expected = DataFrame( + { + "ht": [ + 2.8, + 3.4, + 2.9, + 3.8, + 2.2, + 2.9, + 2.0, + 3.2, + 1.8, + 2.8, + 1.9, + 2.4, + 2.2, + 3.3, + 2.3, + 3.4, + 2.1, + 2.9, + ], + "famid": [1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3], + "birth": [1, 1, 2, 2, 3, 3, 1, 1, 2, 2, 3, 3, 1, 1, 2, 2, 3, 3], + "age": [1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2], + } + ) + expected = expected.set_index(["famid", "birth", "age"])[["ht"]] + result = wide_to_long(df, "ht", i=["famid", "birth"], j="age") + tm.assert_frame_equal(result, expected) + + def test_non_unique_idvars(self): + # GH16382 + # Raise an error message if non unique id vars (i) are passed + df = DataFrame( + {"A_A1": [1, 2, 3, 4, 5], "B_B1": [1, 2, 3, 4, 5], "x": [1, 1, 1, 1, 1]} + ) + msg = "the id variables need to uniquely identify each row" + with pytest.raises(ValueError, match=msg): + wide_to_long(df, ["A_A", "B_B"], i="x", j="colname") + + def test_cast_j_int(self): + df = DataFrame( + { + "actor_1": ["CCH Pounder", "Johnny Depp", "Christoph Waltz"], + "actor_2": ["Joel David Moore", "Orlando Bloom", "Rory Kinnear"], + "actor_fb_likes_1": [1000.0, 40000.0, 11000.0], + "actor_fb_likes_2": [936.0, 5000.0, 393.0], + "title": ["Avatar", "Pirates of the Caribbean", "Spectre"], + } + ) + + expected = DataFrame( + { + "actor": [ + "CCH Pounder", + "Johnny Depp", + "Christoph Waltz", + "Joel David Moore", + "Orlando Bloom", + "Rory Kinnear", + ], + "actor_fb_likes": [1000.0, 40000.0, 11000.0, 936.0, 5000.0, 393.0], + "num": [1, 1, 1, 2, 2, 2], + "title": [ + "Avatar", + "Pirates of the Caribbean", + "Spectre", + "Avatar", + "Pirates of the Caribbean", + "Spectre", + ], + } + ).set_index(["title", "num"]) + result = wide_to_long( + df, ["actor", "actor_fb_likes"], i="title", j="num", sep="_" + ) + + tm.assert_frame_equal(result, expected) + + def test_identical_stubnames(self): + df = DataFrame( + { + "A2010": [1.0, 2.0], + "A2011": [3.0, 4.0], + "B2010": [5.0, 6.0], + "A": ["X1", "X2"], + } + ) + msg = "stubname can't be identical to a column name" + with pytest.raises(ValueError, match=msg): + wide_to_long(df, ["A", "B"], i="A", j="colname") + + def test_nonnumeric_suffix(self): + df = DataFrame( + { + "treatment_placebo": [1.0, 2.0], + "treatment_test": [3.0, 4.0], + "result_placebo": [5.0, 6.0], + "A": ["X1", "X2"], + } + ) + expected = DataFrame( + { + "A": ["X1", "X2", "X1", "X2"], + "colname": ["placebo", "placebo", "test", "test"], + "result": [5.0, 6.0, np.nan, np.nan], + "treatment": [1.0, 2.0, 3.0, 4.0], + } + ) + expected = expected.set_index(["A", "colname"]) + result = wide_to_long( + df, ["result", "treatment"], i="A", j="colname", suffix="[a-z]+", sep="_" + ) + tm.assert_frame_equal(result, expected) + + def test_mixed_type_suffix(self): + df = DataFrame( + { + "A": ["X1", "X2"], + "result_1": [0, 9], + "result_foo": [5.0, 6.0], + "treatment_1": [1.0, 2.0], + "treatment_foo": [3.0, 4.0], + } + ) + expected = DataFrame( + { + "A": ["X1", "X2", "X1", "X2"], + "colname": ["1", "1", "foo", "foo"], + "result": [0.0, 9.0, 5.0, 6.0], + "treatment": [1.0, 2.0, 3.0, 4.0], + } + ).set_index(["A", "colname"]) + result = wide_to_long( + df, ["result", "treatment"], i="A", j="colname", suffix=".+", sep="_" + ) + tm.assert_frame_equal(result, expected) + + def test_float_suffix(self): + df = DataFrame( + { + "treatment_1.1": [1.0, 2.0], + "treatment_2.1": [3.0, 4.0], + "result_1.2": [5.0, 6.0], + "result_1": [0, 9], + "A": ["X1", "X2"], + } + ) + expected = DataFrame( + { + "A": ["X1", "X2", "X1", "X2", "X1", "X2", "X1", "X2"], + "colname": [1.2, 1.2, 1.0, 1.0, 1.1, 1.1, 2.1, 2.1], + "result": [5.0, 6.0, 0.0, 9.0, np.nan, np.nan, np.nan, np.nan], + "treatment": [np.nan, np.nan, np.nan, np.nan, 1.0, 2.0, 3.0, 4.0], + } + ) + expected = expected.set_index(["A", "colname"]) + result = wide_to_long( + df, ["result", "treatment"], i="A", j="colname", suffix="[0-9.]+", sep="_" + ) + tm.assert_frame_equal(result, expected) + + def test_col_substring_of_stubname(self): + # GH22468 + # Don't raise ValueError when a column name is a substring + # of a stubname that's been passed as a string + wide_data = { + "node_id": {0: 0, 1: 1, 2: 2, 3: 3, 4: 4}, + "A": {0: 0.80, 1: 0.0, 2: 0.25, 3: 1.0, 4: 0.81}, + "PA0": {0: 0.74, 1: 0.56, 2: 0.56, 3: 0.98, 4: 0.6}, + "PA1": {0: 0.77, 1: 0.64, 2: 0.52, 3: 0.98, 4: 0.67}, + "PA3": {0: 0.34, 1: 0.70, 2: 0.52, 3: 0.98, 4: 0.67}, + } + wide_df = DataFrame.from_dict(wide_data) + expected = wide_to_long(wide_df, stubnames=["PA"], i=["node_id", "A"], j="time") + result = wide_to_long(wide_df, stubnames="PA", i=["node_id", "A"], j="time") + tm.assert_frame_equal(result, expected) + + def test_raise_of_column_name_value(self): + # GH34731, enforced in 2.0 + # raise a ValueError if the resultant value column name matches + # a name in the dataframe already (default name is "value") + df = DataFrame({"col": list("ABC"), "value": range(10, 16, 2)}) + + with pytest.raises( + ValueError, match=re.escape("value_name (value) cannot match") + ): + df.melt(id_vars="value", value_name="value") + + @pytest.mark.parametrize("dtype", ["O", "string"]) + def test_missing_stubname(self, dtype): + # GH46044 + df = DataFrame({"id": ["1", "2"], "a-1": [100, 200], "a-2": [300, 400]}) + df = df.astype({"id": dtype}) + result = wide_to_long( + df, + stubnames=["a", "b"], + i="id", + j="num", + sep="-", + ) + index = Index( + [("1", 1), ("2", 1), ("1", 2), ("2", 2)], + name=("id", "num"), + ) + expected = DataFrame( + {"a": [100, 200, 300, 400], "b": [np.nan] * 4}, + index=index, + ) + new_level = expected.index.levels[0].astype(dtype) + expected.index = expected.index.set_levels(new_level, level=0) + tm.assert_frame_equal(result, expected) + + +def test_wide_to_long_pyarrow_string_columns(): + # GH 57066 + pytest.importorskip("pyarrow") + df = DataFrame( + { + "ID": {0: 1}, + "R_test1": {0: 1}, + "R_test2": {0: 1}, + "R_test3": {0: 2}, + "D": {0: 1}, + } + ) + df.columns = df.columns.astype("string[pyarrow_numpy]") + result = wide_to_long( + df, stubnames="R", i="ID", j="UNPIVOTED", sep="_", suffix=".*" + ) + expected = DataFrame( + [[1, 1], [1, 1], [1, 2]], + columns=Index(["D", "R"], dtype=object), + index=pd.MultiIndex.from_arrays( + [ + [1, 1, 1], + Index(["test1", "test2", "test3"], dtype="string[pyarrow_numpy]"), + ], + names=["ID", "UNPIVOTED"], + ), + ) + tm.assert_frame_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/test_pivot.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/test_pivot.py new file mode 100644 index 0000000000000000000000000000000000000000..18a449b4d0c67b55f24aa590229dddb0ed456054 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/test_pivot.py @@ -0,0 +1,2714 @@ +from datetime import ( + date, + datetime, + timedelta, +) +from itertools import product +import re + +import numpy as np +import pytest + +from pandas._config import using_pyarrow_string_dtype + +from pandas.errors import PerformanceWarning + +import pandas as pd +from pandas import ( + Categorical, + DataFrame, + Grouper, + Index, + MultiIndex, + Series, + concat, + date_range, +) +import pandas._testing as tm +from pandas.api.types import CategoricalDtype +from pandas.core.reshape import reshape as reshape_lib +from pandas.core.reshape.pivot import pivot_table + + +@pytest.fixture(params=[True, False]) +def dropna(request): + return request.param + + +@pytest.fixture(params=[([0] * 4, [1] * 4), (range(3), range(1, 4))]) +def interval_values(request, closed): + left, right = request.param + return Categorical(pd.IntervalIndex.from_arrays(left, right, closed)) + + +class TestPivotTable: + @pytest.fixture + def data(self): + return DataFrame( + { + "A": [ + "foo", + "foo", + "foo", + "foo", + "bar", + "bar", + "bar", + "bar", + "foo", + "foo", + "foo", + ], + "B": [ + "one", + "one", + "one", + "two", + "one", + "one", + "one", + "two", + "two", + "two", + "one", + ], + "C": [ + "dull", + "dull", + "shiny", + "dull", + "dull", + "shiny", + "shiny", + "dull", + "shiny", + "shiny", + "shiny", + ], + "D": np.random.default_rng(2).standard_normal(11), + "E": np.random.default_rng(2).standard_normal(11), + "F": np.random.default_rng(2).standard_normal(11), + } + ) + + def test_pivot_table(self, observed, data): + index = ["A", "B"] + columns = "C" + table = pivot_table( + data, values="D", index=index, columns=columns, observed=observed + ) + + table2 = data.pivot_table( + values="D", index=index, columns=columns, observed=observed + ) + tm.assert_frame_equal(table, table2) + + # this works + pivot_table(data, values="D", index=index, observed=observed) + + if len(index) > 1: + assert table.index.names == tuple(index) + else: + assert table.index.name == index[0] + + if len(columns) > 1: + assert table.columns.names == columns + else: + assert table.columns.name == columns[0] + + expected = data.groupby(index + [columns])["D"].agg("mean").unstack() + tm.assert_frame_equal(table, expected) + + def test_pivot_table_categorical_observed_equal(self, observed): + # issue #24923 + df = DataFrame( + {"col1": list("abcde"), "col2": list("fghij"), "col3": [1, 2, 3, 4, 5]} + ) + + expected = df.pivot_table( + index="col1", values="col3", columns="col2", aggfunc="sum", fill_value=0 + ) + + expected.index = expected.index.astype("category") + expected.columns = expected.columns.astype("category") + + df.col1 = df.col1.astype("category") + df.col2 = df.col2.astype("category") + + result = df.pivot_table( + index="col1", + values="col3", + columns="col2", + aggfunc="sum", + fill_value=0, + observed=observed, + ) + + tm.assert_frame_equal(result, expected) + + def test_pivot_table_nocols(self): + df = DataFrame( + {"rows": ["a", "b", "c"], "cols": ["x", "y", "z"], "values": [1, 2, 3]} + ) + rs = df.pivot_table(columns="cols", aggfunc="sum") + xp = df.pivot_table(index="cols", aggfunc="sum").T + tm.assert_frame_equal(rs, xp) + + rs = df.pivot_table(columns="cols", aggfunc={"values": "mean"}) + xp = df.pivot_table(index="cols", aggfunc={"values": "mean"}).T + tm.assert_frame_equal(rs, xp) + + def test_pivot_table_dropna(self): + df = DataFrame( + { + "amount": {0: 60000, 1: 100000, 2: 50000, 3: 30000}, + "customer": {0: "A", 1: "A", 2: "B", 3: "C"}, + "month": {0: 201307, 1: 201309, 2: 201308, 3: 201310}, + "product": {0: "a", 1: "b", 2: "c", 3: "d"}, + "quantity": {0: 2000000, 1: 500000, 2: 1000000, 3: 1000000}, + } + ) + pv_col = df.pivot_table( + "quantity", "month", ["customer", "product"], dropna=False + ) + pv_ind = df.pivot_table( + "quantity", ["customer", "product"], "month", dropna=False + ) + + m = MultiIndex.from_tuples( + [ + ("A", "a"), + ("A", "b"), + ("A", "c"), + ("A", "d"), + ("B", "a"), + ("B", "b"), + ("B", "c"), + ("B", "d"), + ("C", "a"), + ("C", "b"), + ("C", "c"), + ("C", "d"), + ], + names=["customer", "product"], + ) + tm.assert_index_equal(pv_col.columns, m) + tm.assert_index_equal(pv_ind.index, m) + + def test_pivot_table_categorical(self): + cat1 = Categorical( + ["a", "a", "b", "b"], categories=["a", "b", "z"], ordered=True + ) + cat2 = Categorical( + ["c", "d", "c", "d"], categories=["c", "d", "y"], ordered=True + ) + df = DataFrame({"A": cat1, "B": cat2, "values": [1, 2, 3, 4]}) + msg = "The default value of observed=False is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = pivot_table(df, values="values", index=["A", "B"], dropna=True) + + exp_index = MultiIndex.from_arrays([cat1, cat2], names=["A", "B"]) + expected = DataFrame({"values": [1.0, 2.0, 3.0, 4.0]}, index=exp_index) + tm.assert_frame_equal(result, expected) + + def test_pivot_table_dropna_categoricals(self, dropna): + # GH 15193 + categories = ["a", "b", "c", "d"] + + df = DataFrame( + { + "A": ["a", "a", "a", "b", "b", "b", "c", "c", "c"], + "B": [1, 2, 3, 1, 2, 3, 1, 2, 3], + "C": range(9), + } + ) + + df["A"] = df["A"].astype(CategoricalDtype(categories, ordered=False)) + msg = "The default value of observed=False is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = df.pivot_table(index="B", columns="A", values="C", dropna=dropna) + expected_columns = Series(["a", "b", "c"], name="A") + expected_columns = expected_columns.astype( + CategoricalDtype(categories, ordered=False) + ) + expected_index = Series([1, 2, 3], name="B") + expected = DataFrame( + [[0.0, 3.0, 6.0], [1.0, 4.0, 7.0], [2.0, 5.0, 8.0]], + index=expected_index, + columns=expected_columns, + ) + if not dropna: + # add back the non observed to compare + expected = expected.reindex(columns=Categorical(categories)).astype("float") + + tm.assert_frame_equal(result, expected) + + def test_pivot_with_non_observable_dropna(self, dropna): + # gh-21133 + df = DataFrame( + { + "A": Categorical( + [np.nan, "low", "high", "low", "high"], + categories=["low", "high"], + ordered=True, + ), + "B": [0.0, 1.0, 2.0, 3.0, 4.0], + } + ) + + msg = "The default value of observed=False is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = df.pivot_table(index="A", values="B", dropna=dropna) + if dropna: + values = [2.0, 3.0] + codes = [0, 1] + else: + # GH: 10772 + values = [2.0, 3.0, 0.0] + codes = [0, 1, -1] + expected = DataFrame( + {"B": values}, + index=Index( + Categorical.from_codes( + codes, categories=["low", "high"], ordered=dropna + ), + name="A", + ), + ) + + tm.assert_frame_equal(result, expected) + + def test_pivot_with_non_observable_dropna_multi_cat(self, dropna): + # gh-21378 + df = DataFrame( + { + "A": Categorical( + ["left", "low", "high", "low", "high"], + categories=["low", "high", "left"], + ordered=True, + ), + "B": range(5), + } + ) + + msg = "The default value of observed=False is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = df.pivot_table(index="A", values="B", dropna=dropna) + expected = DataFrame( + {"B": [2.0, 3.0, 0.0]}, + index=Index( + Categorical.from_codes( + [0, 1, 2], categories=["low", "high", "left"], ordered=True + ), + name="A", + ), + ) + if not dropna: + expected["B"] = expected["B"].astype(float) + + tm.assert_frame_equal(result, expected) + + def test_pivot_with_interval_index(self, interval_values, dropna): + # GH 25814 + df = DataFrame({"A": interval_values, "B": 1}) + + msg = "The default value of observed=False is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = df.pivot_table(index="A", values="B", dropna=dropna) + expected = DataFrame( + {"B": 1.0}, index=Index(interval_values.unique(), name="A") + ) + if not dropna: + expected = expected.astype(float) + tm.assert_frame_equal(result, expected) + + def test_pivot_with_interval_index_margins(self): + # GH 25815 + ordered_cat = pd.IntervalIndex.from_arrays([0, 0, 1, 1], [1, 1, 2, 2]) + df = DataFrame( + { + "A": np.arange(4, 0, -1, dtype=np.intp), + "B": ["a", "b", "a", "b"], + "C": Categorical(ordered_cat, ordered=True).sort_values( + ascending=False + ), + } + ) + + msg = "The default value of observed=False is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + pivot_tab = pivot_table( + df, index="C", columns="B", values="A", aggfunc="sum", margins=True + ) + + result = pivot_tab["All"] + expected = Series( + [3, 7, 10], + index=Index([pd.Interval(0, 1), pd.Interval(1, 2), "All"], name="C"), + name="All", + dtype=np.intp, + ) + tm.assert_series_equal(result, expected) + + def test_pass_array(self, data): + result = data.pivot_table("D", index=data.A, columns=data.C) + expected = data.pivot_table("D", index="A", columns="C") + tm.assert_frame_equal(result, expected) + + def test_pass_function(self, data): + result = data.pivot_table("D", index=lambda x: x // 5, columns=data.C) + expected = data.pivot_table("D", index=data.index // 5, columns="C") + tm.assert_frame_equal(result, expected) + + def test_pivot_table_multiple(self, data): + index = ["A", "B"] + columns = "C" + table = pivot_table(data, index=index, columns=columns) + expected = data.groupby(index + [columns]).agg("mean").unstack() + tm.assert_frame_equal(table, expected) + + def test_pivot_dtypes(self): + # can convert dtypes + f = DataFrame( + { + "a": ["cat", "bat", "cat", "bat"], + "v": [1, 2, 3, 4], + "i": ["a", "b", "a", "b"], + } + ) + assert f.dtypes["v"] == "int64" + + z = pivot_table( + f, values="v", index=["a"], columns=["i"], fill_value=0, aggfunc="sum" + ) + result = z.dtypes + expected = Series([np.dtype("int64")] * 2, index=Index(list("ab"), name="i")) + tm.assert_series_equal(result, expected) + + # cannot convert dtypes + f = DataFrame( + { + "a": ["cat", "bat", "cat", "bat"], + "v": [1.5, 2.5, 3.5, 4.5], + "i": ["a", "b", "a", "b"], + } + ) + assert f.dtypes["v"] == "float64" + + z = pivot_table( + f, values="v", index=["a"], columns=["i"], fill_value=0, aggfunc="mean" + ) + result = z.dtypes + expected = Series([np.dtype("float64")] * 2, index=Index(list("ab"), name="i")) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "columns,values", + [ + ("bool1", ["float1", "float2"]), + ("bool1", ["float1", "float2", "bool1"]), + ("bool2", ["float1", "float2", "bool1"]), + ], + ) + def test_pivot_preserve_dtypes(self, columns, values): + # GH 7142 regression test + v = np.arange(5, dtype=np.float64) + df = DataFrame( + {"float1": v, "float2": v + 2.0, "bool1": v <= 2, "bool2": v <= 3} + ) + + df_res = df.reset_index().pivot_table( + index="index", columns=columns, values=values + ) + + result = dict(df_res.dtypes) + expected = {col: np.dtype("float64") for col in df_res} + assert result == expected + + def test_pivot_no_values(self): + # GH 14380 + idx = pd.DatetimeIndex( + ["2011-01-01", "2011-02-01", "2011-01-02", "2011-01-01", "2011-01-02"] + ) + df = DataFrame({"A": [1, 2, 3, 4, 5]}, index=idx) + res = df.pivot_table(index=df.index.month, columns=df.index.day) + + exp_columns = MultiIndex.from_tuples([("A", 1), ("A", 2)]) + exp_columns = exp_columns.set_levels( + exp_columns.levels[1].astype(np.int32), level=1 + ) + exp = DataFrame( + [[2.5, 4.0], [2.0, np.nan]], + index=Index([1, 2], dtype=np.int32), + columns=exp_columns, + ) + tm.assert_frame_equal(res, exp) + + df = DataFrame( + { + "A": [1, 2, 3, 4, 5], + "dt": date_range("2011-01-01", freq="D", periods=5), + }, + index=idx, + ) + res = df.pivot_table(index=df.index.month, columns=Grouper(key="dt", freq="ME")) + exp_columns = MultiIndex.from_arrays( + [["A"], pd.DatetimeIndex(["2011-01-31"], dtype="M8[ns]")], + names=[None, "dt"], + ) + exp = DataFrame( + [3.25, 2.0], index=Index([1, 2], dtype=np.int32), columns=exp_columns + ) + tm.assert_frame_equal(res, exp) + + res = df.pivot_table( + index=Grouper(freq="YE"), columns=Grouper(key="dt", freq="ME") + ) + exp = DataFrame( + [3.0], + index=pd.DatetimeIndex(["2011-12-31"], freq="YE"), + columns=exp_columns, + ) + tm.assert_frame_equal(res, exp) + + def test_pivot_multi_values(self, data): + result = pivot_table( + data, values=["D", "E"], index="A", columns=["B", "C"], fill_value=0 + ) + expected = pivot_table( + data.drop(["F"], axis=1), index="A", columns=["B", "C"], fill_value=0 + ) + tm.assert_frame_equal(result, expected) + + def test_pivot_multi_functions(self, data): + f = lambda func: pivot_table( + data, values=["D", "E"], index=["A", "B"], columns="C", aggfunc=func + ) + result = f(["mean", "std"]) + means = f("mean") + stds = f("std") + expected = concat([means, stds], keys=["mean", "std"], axis=1) + tm.assert_frame_equal(result, expected) + + # margins not supported?? + f = lambda func: pivot_table( + data, + values=["D", "E"], + index=["A", "B"], + columns="C", + aggfunc=func, + margins=True, + ) + result = f(["mean", "std"]) + means = f("mean") + stds = f("std") + expected = concat([means, stds], keys=["mean", "std"], axis=1) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("method", [True, False]) + def test_pivot_index_with_nan(self, method): + # GH 3588 + nan = np.nan + df = DataFrame( + { + "a": ["R1", "R2", nan, "R4"], + "b": ["C1", "C2", "C3", "C4"], + "c": [10, 15, 17, 20], + } + ) + if method: + result = df.pivot(index="a", columns="b", values="c") + else: + result = pd.pivot(df, index="a", columns="b", values="c") + expected = DataFrame( + [ + [nan, nan, 17, nan], + [10, nan, nan, nan], + [nan, 15, nan, nan], + [nan, nan, nan, 20], + ], + index=Index([nan, "R1", "R2", "R4"], name="a"), + columns=Index(["C1", "C2", "C3", "C4"], name="b"), + ) + tm.assert_frame_equal(result, expected) + tm.assert_frame_equal(df.pivot(index="b", columns="a", values="c"), expected.T) + + @pytest.mark.parametrize("method", [True, False]) + def test_pivot_index_with_nan_dates(self, method): + # GH9491 + df = DataFrame( + { + "a": date_range("2014-02-01", periods=6, freq="D"), + "c": 100 + np.arange(6), + } + ) + df["b"] = df["a"] - pd.Timestamp("2014-02-02") + df.loc[1, "a"] = df.loc[3, "a"] = np.nan + df.loc[1, "b"] = df.loc[4, "b"] = np.nan + + if method: + pv = df.pivot(index="a", columns="b", values="c") + else: + pv = pd.pivot(df, index="a", columns="b", values="c") + assert pv.notna().values.sum() == len(df) + + for _, row in df.iterrows(): + assert pv.loc[row["a"], row["b"]] == row["c"] + + if method: + result = df.pivot(index="b", columns="a", values="c") + else: + result = pd.pivot(df, index="b", columns="a", values="c") + tm.assert_frame_equal(result, pv.T) + + @pytest.mark.parametrize("method", [True, False]) + def test_pivot_with_tz(self, method, unit): + # GH 5878 + df = DataFrame( + { + "dt1": pd.DatetimeIndex( + [ + datetime(2013, 1, 1, 9, 0), + datetime(2013, 1, 2, 9, 0), + datetime(2013, 1, 1, 9, 0), + datetime(2013, 1, 2, 9, 0), + ], + dtype=f"M8[{unit}, US/Pacific]", + ), + "dt2": pd.DatetimeIndex( + [ + datetime(2014, 1, 1, 9, 0), + datetime(2014, 1, 1, 9, 0), + datetime(2014, 1, 2, 9, 0), + datetime(2014, 1, 2, 9, 0), + ], + dtype=f"M8[{unit}, Asia/Tokyo]", + ), + "data1": np.arange(4, dtype="int64"), + "data2": np.arange(4, dtype="int64"), + } + ) + + exp_col1 = Index(["data1", "data1", "data2", "data2"]) + exp_col2 = pd.DatetimeIndex( + ["2014/01/01 09:00", "2014/01/02 09:00"] * 2, + name="dt2", + dtype=f"M8[{unit}, Asia/Tokyo]", + ) + exp_col = MultiIndex.from_arrays([exp_col1, exp_col2]) + exp_idx = pd.DatetimeIndex( + ["2013/01/01 09:00", "2013/01/02 09:00"], + name="dt1", + dtype=f"M8[{unit}, US/Pacific]", + ) + expected = DataFrame( + [[0, 2, 0, 2], [1, 3, 1, 3]], + index=exp_idx, + columns=exp_col, + ) + + if method: + pv = df.pivot(index="dt1", columns="dt2") + else: + pv = pd.pivot(df, index="dt1", columns="dt2") + tm.assert_frame_equal(pv, expected) + + expected = DataFrame( + [[0, 2], [1, 3]], + index=exp_idx, + columns=exp_col2[:2], + ) + + if method: + pv = df.pivot(index="dt1", columns="dt2", values="data1") + else: + pv = pd.pivot(df, index="dt1", columns="dt2", values="data1") + tm.assert_frame_equal(pv, expected) + + def test_pivot_tz_in_values(self): + # GH 14948 + df = DataFrame( + [ + { + "uid": "aa", + "ts": pd.Timestamp("2016-08-12 13:00:00-0700", tz="US/Pacific"), + }, + { + "uid": "aa", + "ts": pd.Timestamp("2016-08-12 08:00:00-0700", tz="US/Pacific"), + }, + { + "uid": "aa", + "ts": pd.Timestamp("2016-08-12 14:00:00-0700", tz="US/Pacific"), + }, + { + "uid": "aa", + "ts": pd.Timestamp("2016-08-25 11:00:00-0700", tz="US/Pacific"), + }, + { + "uid": "aa", + "ts": pd.Timestamp("2016-08-25 13:00:00-0700", tz="US/Pacific"), + }, + ] + ) + + df = df.set_index("ts").reset_index() + mins = df.ts.map(lambda x: x.replace(hour=0, minute=0, second=0, microsecond=0)) + + result = pivot_table( + df.set_index("ts").reset_index(), + values="ts", + index=["uid"], + columns=[mins], + aggfunc="min", + ) + expected = DataFrame( + [ + [ + pd.Timestamp("2016-08-12 08:00:00-0700", tz="US/Pacific"), + pd.Timestamp("2016-08-25 11:00:00-0700", tz="US/Pacific"), + ] + ], + index=Index(["aa"], name="uid"), + columns=pd.DatetimeIndex( + [ + pd.Timestamp("2016-08-12 00:00:00", tz="US/Pacific"), + pd.Timestamp("2016-08-25 00:00:00", tz="US/Pacific"), + ], + name="ts", + ), + ) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("method", [True, False]) + def test_pivot_periods(self, method): + df = DataFrame( + { + "p1": [ + pd.Period("2013-01-01", "D"), + pd.Period("2013-01-02", "D"), + pd.Period("2013-01-01", "D"), + pd.Period("2013-01-02", "D"), + ], + "p2": [ + pd.Period("2013-01", "M"), + pd.Period("2013-01", "M"), + pd.Period("2013-02", "M"), + pd.Period("2013-02", "M"), + ], + "data1": np.arange(4, dtype="int64"), + "data2": np.arange(4, dtype="int64"), + } + ) + + exp_col1 = Index(["data1", "data1", "data2", "data2"]) + exp_col2 = pd.PeriodIndex(["2013-01", "2013-02"] * 2, name="p2", freq="M") + exp_col = MultiIndex.from_arrays([exp_col1, exp_col2]) + expected = DataFrame( + [[0, 2, 0, 2], [1, 3, 1, 3]], + index=pd.PeriodIndex(["2013-01-01", "2013-01-02"], name="p1", freq="D"), + columns=exp_col, + ) + if method: + pv = df.pivot(index="p1", columns="p2") + else: + pv = pd.pivot(df, index="p1", columns="p2") + tm.assert_frame_equal(pv, expected) + + expected = DataFrame( + [[0, 2], [1, 3]], + index=pd.PeriodIndex(["2013-01-01", "2013-01-02"], name="p1", freq="D"), + columns=pd.PeriodIndex(["2013-01", "2013-02"], name="p2", freq="M"), + ) + if method: + pv = df.pivot(index="p1", columns="p2", values="data1") + else: + pv = pd.pivot(df, index="p1", columns="p2", values="data1") + tm.assert_frame_equal(pv, expected) + + def test_pivot_periods_with_margins(self): + # GH 28323 + df = DataFrame( + { + "a": [1, 1, 2, 2], + "b": [ + pd.Period("2019Q1"), + pd.Period("2019Q2"), + pd.Period("2019Q1"), + pd.Period("2019Q2"), + ], + "x": 1.0, + } + ) + + expected = DataFrame( + data=1.0, + index=Index([1, 2, "All"], name="a"), + columns=Index([pd.Period("2019Q1"), pd.Period("2019Q2"), "All"], name="b"), + ) + + result = df.pivot_table(index="a", columns="b", values="x", margins=True) + tm.assert_frame_equal(expected, result) + + @pytest.mark.parametrize( + "values", + [ + ["baz", "zoo"], + np.array(["baz", "zoo"]), + Series(["baz", "zoo"]), + Index(["baz", "zoo"]), + ], + ) + @pytest.mark.parametrize("method", [True, False]) + def test_pivot_with_list_like_values(self, values, method): + # issue #17160 + df = DataFrame( + { + "foo": ["one", "one", "one", "two", "two", "two"], + "bar": ["A", "B", "C", "A", "B", "C"], + "baz": [1, 2, 3, 4, 5, 6], + "zoo": ["x", "y", "z", "q", "w", "t"], + } + ) + + if method: + result = df.pivot(index="foo", columns="bar", values=values) + else: + result = pd.pivot(df, index="foo", columns="bar", values=values) + + data = [[1, 2, 3, "x", "y", "z"], [4, 5, 6, "q", "w", "t"]] + index = Index(data=["one", "two"], name="foo") + columns = MultiIndex( + levels=[["baz", "zoo"], ["A", "B", "C"]], + codes=[[0, 0, 0, 1, 1, 1], [0, 1, 2, 0, 1, 2]], + names=[None, "bar"], + ) + expected = DataFrame(data=data, index=index, columns=columns) + expected["baz"] = expected["baz"].astype(object) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "values", + [ + ["bar", "baz"], + np.array(["bar", "baz"]), + Series(["bar", "baz"]), + Index(["bar", "baz"]), + ], + ) + @pytest.mark.parametrize("method", [True, False]) + def test_pivot_with_list_like_values_nans(self, values, method): + # issue #17160 + df = DataFrame( + { + "foo": ["one", "one", "one", "two", "two", "two"], + "bar": ["A", "B", "C", "A", "B", "C"], + "baz": [1, 2, 3, 4, 5, 6], + "zoo": ["x", "y", "z", "q", "w", "t"], + } + ) + + if method: + result = df.pivot(index="zoo", columns="foo", values=values) + else: + result = pd.pivot(df, index="zoo", columns="foo", values=values) + + data = [ + [np.nan, "A", np.nan, 4], + [np.nan, "C", np.nan, 6], + [np.nan, "B", np.nan, 5], + ["A", np.nan, 1, np.nan], + ["B", np.nan, 2, np.nan], + ["C", np.nan, 3, np.nan], + ] + index = Index(data=["q", "t", "w", "x", "y", "z"], name="zoo") + columns = MultiIndex( + levels=[["bar", "baz"], ["one", "two"]], + codes=[[0, 0, 1, 1], [0, 1, 0, 1]], + names=[None, "foo"], + ) + expected = DataFrame(data=data, index=index, columns=columns) + expected["baz"] = expected["baz"].astype(object) + tm.assert_frame_equal(result, expected) + + def test_pivot_columns_none_raise_error(self): + # GH 30924 + df = DataFrame({"col1": ["a", "b", "c"], "col2": [1, 2, 3], "col3": [1, 2, 3]}) + msg = r"pivot\(\) missing 1 required keyword-only argument: 'columns'" + with pytest.raises(TypeError, match=msg): + df.pivot(index="col1", values="col3") # pylint: disable=missing-kwoa + + @pytest.mark.xfail( + reason="MultiIndexed unstack with tuple names fails with KeyError GH#19966" + ) + @pytest.mark.parametrize("method", [True, False]) + def test_pivot_with_multiindex(self, method): + # issue #17160 + index = Index(data=[0, 1, 2, 3, 4, 5]) + data = [ + ["one", "A", 1, "x"], + ["one", "B", 2, "y"], + ["one", "C", 3, "z"], + ["two", "A", 4, "q"], + ["two", "B", 5, "w"], + ["two", "C", 6, "t"], + ] + columns = MultiIndex( + levels=[["bar", "baz"], ["first", "second"]], + codes=[[0, 0, 1, 1], [0, 1, 0, 1]], + ) + df = DataFrame(data=data, index=index, columns=columns, dtype="object") + if method: + result = df.pivot( + index=("bar", "first"), + columns=("bar", "second"), + values=("baz", "first"), + ) + else: + result = pd.pivot( + df, + index=("bar", "first"), + columns=("bar", "second"), + values=("baz", "first"), + ) + + data = { + "A": Series([1, 4], index=["one", "two"]), + "B": Series([2, 5], index=["one", "two"]), + "C": Series([3, 6], index=["one", "two"]), + } + expected = DataFrame(data) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("method", [True, False]) + def test_pivot_with_tuple_of_values(self, method): + # issue #17160 + df = DataFrame( + { + "foo": ["one", "one", "one", "two", "two", "two"], + "bar": ["A", "B", "C", "A", "B", "C"], + "baz": [1, 2, 3, 4, 5, 6], + "zoo": ["x", "y", "z", "q", "w", "t"], + } + ) + with pytest.raises(KeyError, match=r"^\('bar', 'baz'\)$"): + # tuple is seen as a single column name + if method: + df.pivot(index="zoo", columns="foo", values=("bar", "baz")) + else: + pd.pivot(df, index="zoo", columns="foo", values=("bar", "baz")) + + def _check_output( + self, + result, + values_col, + data, + index=["A", "B"], + columns=["C"], + margins_col="All", + ): + col_margins = result.loc[result.index[:-1], margins_col] + expected_col_margins = data.groupby(index)[values_col].mean() + tm.assert_series_equal(col_margins, expected_col_margins, check_names=False) + assert col_margins.name == margins_col + + result = result.sort_index() + index_margins = result.loc[(margins_col, "")].iloc[:-1] + + expected_ix_margins = data.groupby(columns)[values_col].mean() + tm.assert_series_equal(index_margins, expected_ix_margins, check_names=False) + assert index_margins.name == (margins_col, "") + + grand_total_margins = result.loc[(margins_col, ""), margins_col] + expected_total_margins = data[values_col].mean() + assert grand_total_margins == expected_total_margins + + def test_margins(self, data): + # column specified + result = data.pivot_table( + values="D", index=["A", "B"], columns="C", margins=True, aggfunc="mean" + ) + self._check_output(result, "D", data) + + # Set a different margins_name (not 'All') + result = data.pivot_table( + values="D", + index=["A", "B"], + columns="C", + margins=True, + aggfunc="mean", + margins_name="Totals", + ) + self._check_output(result, "D", data, margins_col="Totals") + + # no column specified + table = data.pivot_table( + index=["A", "B"], columns="C", margins=True, aggfunc="mean" + ) + for value_col in table.columns.levels[0]: + self._check_output(table[value_col], value_col, data) + + def test_no_col(self, data): + # no col + + # to help with a buglet + data.columns = [k * 2 for k in data.columns] + msg = re.escape("agg function failed [how->mean,dtype->") + with pytest.raises(TypeError, match=msg): + data.pivot_table(index=["AA", "BB"], margins=True, aggfunc="mean") + table = data.drop(columns="CC").pivot_table( + index=["AA", "BB"], margins=True, aggfunc="mean" + ) + for value_col in table.columns: + totals = table.loc[("All", ""), value_col] + assert totals == data[value_col].mean() + + with pytest.raises(TypeError, match=msg): + data.pivot_table(index=["AA", "BB"], margins=True, aggfunc="mean") + table = data.drop(columns="CC").pivot_table( + index=["AA", "BB"], margins=True, aggfunc="mean" + ) + for item in ["DD", "EE", "FF"]: + totals = table.loc[("All", ""), item] + assert totals == data[item].mean() + + @pytest.mark.parametrize( + "columns, aggfunc, values, expected_columns", + [ + ( + "A", + "mean", + [[5.5, 5.5, 2.2, 2.2], [8.0, 8.0, 4.4, 4.4]], + Index(["bar", "All", "foo", "All"], name="A"), + ), + ( + ["A", "B"], + "sum", + [ + [9, 13, 22, 5, 6, 11], + [14, 18, 32, 11, 11, 22], + ], + MultiIndex.from_tuples( + [ + ("bar", "one"), + ("bar", "two"), + ("bar", "All"), + ("foo", "one"), + ("foo", "two"), + ("foo", "All"), + ], + names=["A", "B"], + ), + ), + ], + ) + def test_margin_with_only_columns_defined( + self, columns, aggfunc, values, expected_columns + ): + # GH 31016 + df = DataFrame( + { + "A": ["foo", "foo", "foo", "foo", "foo", "bar", "bar", "bar", "bar"], + "B": ["one", "one", "one", "two", "two", "one", "one", "two", "two"], + "C": [ + "small", + "large", + "large", + "small", + "small", + "large", + "small", + "small", + "large", + ], + "D": [1, 2, 2, 3, 3, 4, 5, 6, 7], + "E": [2, 4, 5, 5, 6, 6, 8, 9, 9], + } + ) + if aggfunc != "sum": + msg = re.escape("agg function failed [how->mean,dtype->") + with pytest.raises(TypeError, match=msg): + df.pivot_table(columns=columns, margins=True, aggfunc=aggfunc) + if "B" not in columns: + df = df.drop(columns="B") + result = df.drop(columns="C").pivot_table( + columns=columns, margins=True, aggfunc=aggfunc + ) + expected = DataFrame(values, index=Index(["D", "E"]), columns=expected_columns) + + tm.assert_frame_equal(result, expected) + + def test_margins_dtype(self, data): + # GH 17013 + + df = data.copy() + df[["D", "E", "F"]] = np.arange(len(df) * 3).reshape(len(df), 3).astype("i8") + + mi_val = list(product(["bar", "foo"], ["one", "two"])) + [("All", "")] + mi = MultiIndex.from_tuples(mi_val, names=("A", "B")) + expected = DataFrame( + {"dull": [12, 21, 3, 9, 45], "shiny": [33, 0, 36, 51, 120]}, index=mi + ).rename_axis("C", axis=1) + expected["All"] = expected["dull"] + expected["shiny"] + + result = df.pivot_table( + values="D", + index=["A", "B"], + columns="C", + margins=True, + aggfunc="sum", + fill_value=0, + ) + + tm.assert_frame_equal(expected, result) + + def test_margins_dtype_len(self, data): + mi_val = list(product(["bar", "foo"], ["one", "two"])) + [("All", "")] + mi = MultiIndex.from_tuples(mi_val, names=("A", "B")) + expected = DataFrame( + {"dull": [1, 1, 2, 1, 5], "shiny": [2, 0, 2, 2, 6]}, index=mi + ).rename_axis("C", axis=1) + expected["All"] = expected["dull"] + expected["shiny"] + + result = data.pivot_table( + values="D", + index=["A", "B"], + columns="C", + margins=True, + aggfunc=len, + fill_value=0, + ) + + tm.assert_frame_equal(expected, result) + + @pytest.mark.parametrize("cols", [(1, 2), ("a", "b"), (1, "b"), ("a", 1)]) + def test_pivot_table_multiindex_only(self, cols): + # GH 17038 + df2 = DataFrame({cols[0]: [1, 2, 3], cols[1]: [1, 2, 3], "v": [4, 5, 6]}) + + result = df2.pivot_table(values="v", columns=cols) + expected = DataFrame( + [[4.0, 5.0, 6.0]], + columns=MultiIndex.from_tuples([(1, 1), (2, 2), (3, 3)], names=cols), + index=Index(["v"], dtype=object), + ) + + tm.assert_frame_equal(result, expected) + + def test_pivot_table_retains_tz(self): + dti = date_range("2016-01-01", periods=3, tz="Europe/Amsterdam") + df = DataFrame( + { + "A": np.random.default_rng(2).standard_normal(3), + "B": np.random.default_rng(2).standard_normal(3), + "C": dti, + } + ) + result = df.pivot_table(index=["B", "C"], dropna=False) + + # check tz retention + assert result.index.levels[1].equals(dti) + + def test_pivot_integer_columns(self): + # caused by upstream bug in unstack + + d = date.min + data = list( + product( + ["foo", "bar"], + ["A", "B", "C"], + ["x1", "x2"], + [d + timedelta(i) for i in range(20)], + [1.0], + ) + ) + df = DataFrame(data) + table = df.pivot_table(values=4, index=[0, 1, 3], columns=[2]) + + df2 = df.rename(columns=str) + table2 = df2.pivot_table(values="4", index=["0", "1", "3"], columns=["2"]) + + tm.assert_frame_equal(table, table2, check_names=False) + + def test_pivot_no_level_overlap(self): + # GH #1181 + + data = DataFrame( + { + "a": ["a", "a", "a", "a", "b", "b", "b", "b"] * 2, + "b": [0, 0, 0, 0, 1, 1, 1, 1] * 2, + "c": (["foo"] * 4 + ["bar"] * 4) * 2, + "value": np.random.default_rng(2).standard_normal(16), + } + ) + + table = data.pivot_table("value", index="a", columns=["b", "c"]) + + grouped = data.groupby(["a", "b", "c"])["value"].mean() + expected = grouped.unstack("b").unstack("c").dropna(axis=1, how="all") + tm.assert_frame_equal(table, expected) + + def test_pivot_columns_lexsorted(self): + n = 10000 + + dtype = np.dtype( + [ + ("Index", object), + ("Symbol", object), + ("Year", int), + ("Month", int), + ("Day", int), + ("Quantity", int), + ("Price", float), + ] + ) + + products = np.array( + [ + ("SP500", "ADBE"), + ("SP500", "NVDA"), + ("SP500", "ORCL"), + ("NDQ100", "AAPL"), + ("NDQ100", "MSFT"), + ("NDQ100", "GOOG"), + ("FTSE", "DGE.L"), + ("FTSE", "TSCO.L"), + ("FTSE", "GSK.L"), + ], + dtype=[("Index", object), ("Symbol", object)], + ) + items = np.empty(n, dtype=dtype) + iproduct = np.random.default_rng(2).integers(0, len(products), n) + items["Index"] = products["Index"][iproduct] + items["Symbol"] = products["Symbol"][iproduct] + dr = date_range(date(2000, 1, 1), date(2010, 12, 31)) + dates = dr[np.random.default_rng(2).integers(0, len(dr), n)] + items["Year"] = dates.year + items["Month"] = dates.month + items["Day"] = dates.day + items["Price"] = np.random.default_rng(2).lognormal(4.0, 2.0, n) + + df = DataFrame(items) + + pivoted = df.pivot_table( + "Price", + index=["Month", "Day"], + columns=["Index", "Symbol", "Year"], + aggfunc="mean", + ) + + assert pivoted.columns.is_monotonic_increasing + + def test_pivot_complex_aggfunc(self, data): + f = {"D": ["std"], "E": ["sum"]} + expected = data.groupby(["A", "B"]).agg(f).unstack("B") + result = data.pivot_table(index="A", columns="B", aggfunc=f) + + tm.assert_frame_equal(result, expected) + + def test_margins_no_values_no_cols(self, data): + # Regression test on pivot table: no values or cols passed. + result = data[["A", "B"]].pivot_table( + index=["A", "B"], aggfunc=len, margins=True + ) + result_list = result.tolist() + assert sum(result_list[:-1]) == result_list[-1] + + def test_margins_no_values_two_rows(self, data): + # Regression test on pivot table: no values passed but rows are a + # multi-index + result = data[["A", "B", "C"]].pivot_table( + index=["A", "B"], columns="C", aggfunc=len, margins=True + ) + assert result.All.tolist() == [3.0, 1.0, 4.0, 3.0, 11.0] + + def test_margins_no_values_one_row_one_col(self, data): + # Regression test on pivot table: no values passed but row and col + # defined + result = data[["A", "B"]].pivot_table( + index="A", columns="B", aggfunc=len, margins=True + ) + assert result.All.tolist() == [4.0, 7.0, 11.0] + + def test_margins_no_values_two_row_two_cols(self, data): + # Regression test on pivot table: no values passed but rows and cols + # are multi-indexed + data["D"] = ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k"] + result = data[["A", "B", "C", "D"]].pivot_table( + index=["A", "B"], columns=["C", "D"], aggfunc=len, margins=True + ) + assert result.All.tolist() == [3.0, 1.0, 4.0, 3.0, 11.0] + + @pytest.mark.parametrize("margin_name", ["foo", "one", 666, None, ["a", "b"]]) + def test_pivot_table_with_margins_set_margin_name(self, margin_name, data): + # see gh-3335 + msg = ( + f'Conflicting name "{margin_name}" in margins|' + "margins_name argument must be a string" + ) + with pytest.raises(ValueError, match=msg): + # multi-index index + pivot_table( + data, + values="D", + index=["A", "B"], + columns=["C"], + margins=True, + margins_name=margin_name, + ) + with pytest.raises(ValueError, match=msg): + # multi-index column + pivot_table( + data, + values="D", + index=["C"], + columns=["A", "B"], + margins=True, + margins_name=margin_name, + ) + with pytest.raises(ValueError, match=msg): + # non-multi-index index/column + pivot_table( + data, + values="D", + index=["A"], + columns=["B"], + margins=True, + margins_name=margin_name, + ) + + def test_pivot_timegrouper(self, using_array_manager): + df = DataFrame( + { + "Branch": "A A A A A A A B".split(), + "Buyer": "Carl Mark Carl Carl Joe Joe Joe Carl".split(), + "Quantity": [1, 3, 5, 1, 8, 1, 9, 3], + "Date": [ + datetime(2013, 1, 1), + datetime(2013, 1, 1), + datetime(2013, 10, 1), + datetime(2013, 10, 2), + datetime(2013, 10, 1), + datetime(2013, 10, 2), + datetime(2013, 12, 2), + datetime(2013, 12, 2), + ], + } + ).set_index("Date") + + expected = DataFrame( + np.array([10, 18, 3], dtype="int64").reshape(1, 3), + index=pd.DatetimeIndex([datetime(2013, 12, 31)], freq="YE"), + columns="Carl Joe Mark".split(), + ) + expected.index.name = "Date" + expected.columns.name = "Buyer" + + result = pivot_table( + df, + index=Grouper(freq="YE"), + columns="Buyer", + values="Quantity", + aggfunc="sum", + ) + tm.assert_frame_equal(result, expected) + + result = pivot_table( + df, + index="Buyer", + columns=Grouper(freq="YE"), + values="Quantity", + aggfunc="sum", + ) + tm.assert_frame_equal(result, expected.T) + + expected = DataFrame( + np.array([1, np.nan, 3, 9, 18, np.nan]).reshape(2, 3), + index=pd.DatetimeIndex( + [datetime(2013, 1, 1), datetime(2013, 7, 1)], freq="6MS" + ), + columns="Carl Joe Mark".split(), + ) + expected.index.name = "Date" + expected.columns.name = "Buyer" + if using_array_manager: + # INFO(ArrayManager) column without NaNs can preserve int dtype + expected["Carl"] = expected["Carl"].astype("int64") + + result = pivot_table( + df, + index=Grouper(freq="6MS"), + columns="Buyer", + values="Quantity", + aggfunc="sum", + ) + tm.assert_frame_equal(result, expected) + + result = pivot_table( + df, + index="Buyer", + columns=Grouper(freq="6MS"), + values="Quantity", + aggfunc="sum", + ) + tm.assert_frame_equal(result, expected.T) + + # passing the name + df = df.reset_index() + result = pivot_table( + df, + index=Grouper(freq="6MS", key="Date"), + columns="Buyer", + values="Quantity", + aggfunc="sum", + ) + tm.assert_frame_equal(result, expected) + + result = pivot_table( + df, + index="Buyer", + columns=Grouper(freq="6MS", key="Date"), + values="Quantity", + aggfunc="sum", + ) + tm.assert_frame_equal(result, expected.T) + + msg = "'The grouper name foo is not found'" + with pytest.raises(KeyError, match=msg): + pivot_table( + df, + index=Grouper(freq="6MS", key="foo"), + columns="Buyer", + values="Quantity", + aggfunc="sum", + ) + with pytest.raises(KeyError, match=msg): + pivot_table( + df, + index="Buyer", + columns=Grouper(freq="6MS", key="foo"), + values="Quantity", + aggfunc="sum", + ) + + # passing the level + df = df.set_index("Date") + result = pivot_table( + df, + index=Grouper(freq="6MS", level="Date"), + columns="Buyer", + values="Quantity", + aggfunc="sum", + ) + tm.assert_frame_equal(result, expected) + + result = pivot_table( + df, + index="Buyer", + columns=Grouper(freq="6MS", level="Date"), + values="Quantity", + aggfunc="sum", + ) + tm.assert_frame_equal(result, expected.T) + + msg = "The level foo is not valid" + with pytest.raises(ValueError, match=msg): + pivot_table( + df, + index=Grouper(freq="6MS", level="foo"), + columns="Buyer", + values="Quantity", + aggfunc="sum", + ) + with pytest.raises(ValueError, match=msg): + pivot_table( + df, + index="Buyer", + columns=Grouper(freq="6MS", level="foo"), + values="Quantity", + aggfunc="sum", + ) + + def test_pivot_timegrouper_double(self): + # double grouper + df = DataFrame( + { + "Branch": "A A A A A A A B".split(), + "Buyer": "Carl Mark Carl Carl Joe Joe Joe Carl".split(), + "Quantity": [1, 3, 5, 1, 8, 1, 9, 3], + "Date": [ + datetime(2013, 11, 1, 13, 0), + datetime(2013, 9, 1, 13, 5), + datetime(2013, 10, 1, 20, 0), + datetime(2013, 10, 2, 10, 0), + datetime(2013, 11, 1, 20, 0), + datetime(2013, 10, 2, 10, 0), + datetime(2013, 10, 2, 12, 0), + datetime(2013, 12, 5, 14, 0), + ], + "PayDay": [ + datetime(2013, 10, 4, 0, 0), + datetime(2013, 10, 15, 13, 5), + datetime(2013, 9, 5, 20, 0), + datetime(2013, 11, 2, 10, 0), + datetime(2013, 10, 7, 20, 0), + datetime(2013, 9, 5, 10, 0), + datetime(2013, 12, 30, 12, 0), + datetime(2013, 11, 20, 14, 0), + ], + } + ) + + result = pivot_table( + df, + index=Grouper(freq="ME", key="Date"), + columns=Grouper(freq="ME", key="PayDay"), + values="Quantity", + aggfunc="sum", + ) + expected = DataFrame( + np.array( + [ + np.nan, + 3, + np.nan, + np.nan, + 6, + np.nan, + 1, + 9, + np.nan, + 9, + np.nan, + np.nan, + np.nan, + np.nan, + 3, + np.nan, + ] + ).reshape(4, 4), + index=pd.DatetimeIndex( + [ + datetime(2013, 9, 30), + datetime(2013, 10, 31), + datetime(2013, 11, 30), + datetime(2013, 12, 31), + ], + freq="ME", + ), + columns=pd.DatetimeIndex( + [ + datetime(2013, 9, 30), + datetime(2013, 10, 31), + datetime(2013, 11, 30), + datetime(2013, 12, 31), + ], + freq="ME", + ), + ) + expected.index.name = "Date" + expected.columns.name = "PayDay" + + tm.assert_frame_equal(result, expected) + + result = pivot_table( + df, + index=Grouper(freq="ME", key="PayDay"), + columns=Grouper(freq="ME", key="Date"), + values="Quantity", + aggfunc="sum", + ) + tm.assert_frame_equal(result, expected.T) + + tuples = [ + (datetime(2013, 9, 30), datetime(2013, 10, 31)), + (datetime(2013, 10, 31), datetime(2013, 9, 30)), + (datetime(2013, 10, 31), datetime(2013, 11, 30)), + (datetime(2013, 10, 31), datetime(2013, 12, 31)), + (datetime(2013, 11, 30), datetime(2013, 10, 31)), + (datetime(2013, 12, 31), datetime(2013, 11, 30)), + ] + idx = MultiIndex.from_tuples(tuples, names=["Date", "PayDay"]) + expected = DataFrame( + np.array( + [3, np.nan, 6, np.nan, 1, np.nan, 9, np.nan, 9, np.nan, np.nan, 3] + ).reshape(6, 2), + index=idx, + columns=["A", "B"], + ) + expected.columns.name = "Branch" + + result = pivot_table( + df, + index=[Grouper(freq="ME", key="Date"), Grouper(freq="ME", key="PayDay")], + columns=["Branch"], + values="Quantity", + aggfunc="sum", + ) + tm.assert_frame_equal(result, expected) + + result = pivot_table( + df, + index=["Branch"], + columns=[Grouper(freq="ME", key="Date"), Grouper(freq="ME", key="PayDay")], + values="Quantity", + aggfunc="sum", + ) + tm.assert_frame_equal(result, expected.T) + + def test_pivot_datetime_tz(self): + dates1 = pd.DatetimeIndex( + [ + "2011-07-19 07:00:00", + "2011-07-19 08:00:00", + "2011-07-19 09:00:00", + "2011-07-19 07:00:00", + "2011-07-19 08:00:00", + "2011-07-19 09:00:00", + ], + dtype="M8[ns, US/Pacific]", + name="dt1", + ) + dates2 = pd.DatetimeIndex( + [ + "2013-01-01 15:00:00", + "2013-01-01 15:00:00", + "2013-01-01 15:00:00", + "2013-02-01 15:00:00", + "2013-02-01 15:00:00", + "2013-02-01 15:00:00", + ], + dtype="M8[ns, Asia/Tokyo]", + ) + df = DataFrame( + { + "label": ["a", "a", "a", "b", "b", "b"], + "dt1": dates1, + "dt2": dates2, + "value1": np.arange(6, dtype="int64"), + "value2": [1, 2] * 3, + } + ) + + exp_idx = dates1[:3] + exp_col1 = Index(["value1", "value1"]) + exp_col2 = Index(["a", "b"], name="label") + exp_col = MultiIndex.from_arrays([exp_col1, exp_col2]) + expected = DataFrame( + [[0.0, 3.0], [1.0, 4.0], [2.0, 5.0]], index=exp_idx, columns=exp_col + ) + result = pivot_table(df, index=["dt1"], columns=["label"], values=["value1"]) + tm.assert_frame_equal(result, expected) + + exp_col1 = Index(["sum", "sum", "sum", "sum", "mean", "mean", "mean", "mean"]) + exp_col2 = Index(["value1", "value1", "value2", "value2"] * 2) + exp_col3 = pd.DatetimeIndex( + ["2013-01-01 15:00:00", "2013-02-01 15:00:00"] * 4, + dtype="M8[ns, Asia/Tokyo]", + name="dt2", + ) + exp_col = MultiIndex.from_arrays([exp_col1, exp_col2, exp_col3]) + expected1 = DataFrame( + np.array( + [ + [ + 0, + 3, + 1, + 2, + ], + [1, 4, 2, 1], + [2, 5, 1, 2], + ], + dtype="int64", + ), + index=exp_idx, + columns=exp_col[:4], + ) + expected2 = DataFrame( + np.array( + [ + [0.0, 3.0, 1.0, 2.0], + [1.0, 4.0, 2.0, 1.0], + [2.0, 5.0, 1.0, 2.0], + ], + ), + index=exp_idx, + columns=exp_col[4:], + ) + expected = concat([expected1, expected2], axis=1) + + result = pivot_table( + df, + index=["dt1"], + columns=["dt2"], + values=["value1", "value2"], + aggfunc=["sum", "mean"], + ) + tm.assert_frame_equal(result, expected) + + def test_pivot_dtaccessor(self): + # GH 8103 + dates1 = pd.DatetimeIndex( + [ + "2011-07-19 07:00:00", + "2011-07-19 08:00:00", + "2011-07-19 09:00:00", + "2011-07-19 07:00:00", + "2011-07-19 08:00:00", + "2011-07-19 09:00:00", + ] + ) + dates2 = pd.DatetimeIndex( + [ + "2013-01-01 15:00:00", + "2013-01-01 15:00:00", + "2013-01-01 15:00:00", + "2013-02-01 15:00:00", + "2013-02-01 15:00:00", + "2013-02-01 15:00:00", + ] + ) + df = DataFrame( + { + "label": ["a", "a", "a", "b", "b", "b"], + "dt1": dates1, + "dt2": dates2, + "value1": np.arange(6, dtype="int64"), + "value2": [1, 2] * 3, + } + ) + + result = pivot_table( + df, index="label", columns=df["dt1"].dt.hour, values="value1" + ) + + exp_idx = Index(["a", "b"], name="label") + expected = DataFrame( + {7: [0.0, 3.0], 8: [1.0, 4.0], 9: [2.0, 5.0]}, + index=exp_idx, + columns=Index([7, 8, 9], dtype=np.int32, name="dt1"), + ) + tm.assert_frame_equal(result, expected) + + result = pivot_table( + df, index=df["dt2"].dt.month, columns=df["dt1"].dt.hour, values="value1" + ) + + expected = DataFrame( + {7: [0.0, 3.0], 8: [1.0, 4.0], 9: [2.0, 5.0]}, + index=Index([1, 2], dtype=np.int32, name="dt2"), + columns=Index([7, 8, 9], dtype=np.int32, name="dt1"), + ) + tm.assert_frame_equal(result, expected) + + result = pivot_table( + df, + index=df["dt2"].dt.year.values, + columns=[df["dt1"].dt.hour, df["dt2"].dt.month], + values="value1", + ) + + exp_col = MultiIndex.from_arrays( + [ + np.array([7, 7, 8, 8, 9, 9], dtype=np.int32), + np.array([1, 2] * 3, dtype=np.int32), + ], + names=["dt1", "dt2"], + ) + expected = DataFrame( + np.array([[0.0, 3.0, 1.0, 4.0, 2.0, 5.0]]), + index=Index([2013], dtype=np.int32), + columns=exp_col, + ) + tm.assert_frame_equal(result, expected) + + result = pivot_table( + df, + index=np.array(["X", "X", "X", "X", "Y", "Y"]), + columns=[df["dt1"].dt.hour, df["dt2"].dt.month], + values="value1", + ) + expected = DataFrame( + np.array( + [[0, 3, 1, np.nan, 2, np.nan], [np.nan, np.nan, np.nan, 4, np.nan, 5]] + ), + index=["X", "Y"], + columns=exp_col, + ) + tm.assert_frame_equal(result, expected) + + def test_daily(self): + rng = date_range("1/1/2000", "12/31/2004", freq="D") + ts = Series(np.arange(len(rng)), index=rng) + + result = pivot_table( + DataFrame(ts), index=ts.index.year, columns=ts.index.dayofyear + ) + result.columns = result.columns.droplevel(0) + + doy = np.asarray(ts.index.dayofyear) + + expected = {} + for y in ts.index.year.unique().values: + mask = ts.index.year == y + expected[y] = Series(ts.values[mask], index=doy[mask]) + expected = DataFrame(expected, dtype=float).T + tm.assert_frame_equal(result, expected) + + def test_monthly(self): + rng = date_range("1/1/2000", "12/31/2004", freq="ME") + ts = Series(np.arange(len(rng)), index=rng) + + result = pivot_table(DataFrame(ts), index=ts.index.year, columns=ts.index.month) + result.columns = result.columns.droplevel(0) + + month = np.asarray(ts.index.month) + expected = {} + for y in ts.index.year.unique().values: + mask = ts.index.year == y + expected[y] = Series(ts.values[mask], index=month[mask]) + expected = DataFrame(expected, dtype=float).T + tm.assert_frame_equal(result, expected) + + def test_pivot_table_with_iterator_values(self, data): + # GH 12017 + aggs = {"D": "sum", "E": "mean"} + + pivot_values_list = pivot_table( + data, index=["A"], values=list(aggs.keys()), aggfunc=aggs + ) + + pivot_values_keys = pivot_table( + data, index=["A"], values=aggs.keys(), aggfunc=aggs + ) + tm.assert_frame_equal(pivot_values_keys, pivot_values_list) + + agg_values_gen = (value for value in aggs) + pivot_values_gen = pivot_table( + data, index=["A"], values=agg_values_gen, aggfunc=aggs + ) + tm.assert_frame_equal(pivot_values_gen, pivot_values_list) + + def test_pivot_table_margins_name_with_aggfunc_list(self): + # GH 13354 + margins_name = "Weekly" + costs = DataFrame( + { + "item": ["bacon", "cheese", "bacon", "cheese"], + "cost": [2.5, 4.5, 3.2, 3.3], + "day": ["ME", "ME", "T", "T"], + } + ) + table = costs.pivot_table( + index="item", + columns="day", + margins=True, + margins_name=margins_name, + aggfunc=["mean", "max"], + ) + ix = Index(["bacon", "cheese", margins_name], name="item") + tups = [ + ("mean", "cost", "ME"), + ("mean", "cost", "T"), + ("mean", "cost", margins_name), + ("max", "cost", "ME"), + ("max", "cost", "T"), + ("max", "cost", margins_name), + ] + cols = MultiIndex.from_tuples(tups, names=[None, None, "day"]) + expected = DataFrame(table.values, index=ix, columns=cols) + tm.assert_frame_equal(table, expected) + + def test_categorical_margins(self, observed): + # GH 10989 + df = DataFrame( + {"x": np.arange(8), "y": np.arange(8) // 4, "z": np.arange(8) % 2} + ) + + expected = DataFrame([[1.0, 2.0, 1.5], [5, 6, 5.5], [3, 4, 3.5]]) + expected.index = Index([0, 1, "All"], name="y") + expected.columns = Index([0, 1, "All"], name="z") + + table = df.pivot_table("x", "y", "z", dropna=observed, margins=True) + tm.assert_frame_equal(table, expected) + + def test_categorical_margins_category(self, observed): + df = DataFrame( + {"x": np.arange(8), "y": np.arange(8) // 4, "z": np.arange(8) % 2} + ) + + expected = DataFrame([[1.0, 2.0, 1.5], [5, 6, 5.5], [3, 4, 3.5]]) + expected.index = Index([0, 1, "All"], name="y") + expected.columns = Index([0, 1, "All"], name="z") + + df.y = df.y.astype("category") + df.z = df.z.astype("category") + msg = "The default value of observed=False is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + table = df.pivot_table("x", "y", "z", dropna=observed, margins=True) + tm.assert_frame_equal(table, expected) + + def test_margins_casted_to_float(self): + # GH 24893 + df = DataFrame( + { + "A": [2, 4, 6, 8], + "B": [1, 4, 5, 8], + "C": [1, 3, 4, 6], + "D": ["X", "X", "Y", "Y"], + } + ) + + result = pivot_table(df, index="D", margins=True) + expected = DataFrame( + {"A": [3.0, 7.0, 5], "B": [2.5, 6.5, 4.5], "C": [2.0, 5.0, 3.5]}, + index=Index(["X", "Y", "All"], name="D"), + ) + tm.assert_frame_equal(result, expected) + + def test_pivot_with_categorical(self, observed, ordered): + # gh-21370 + idx = [np.nan, "low", "high", "low", np.nan] + col = [np.nan, "A", "B", np.nan, "A"] + df = DataFrame( + { + "In": Categorical(idx, categories=["low", "high"], ordered=ordered), + "Col": Categorical(col, categories=["A", "B"], ordered=ordered), + "Val": range(1, 6), + } + ) + # case with index/columns/value + result = df.pivot_table( + index="In", columns="Col", values="Val", observed=observed + ) + + expected_cols = pd.CategoricalIndex(["A", "B"], ordered=ordered, name="Col") + + expected = DataFrame(data=[[2.0, np.nan], [np.nan, 3.0]], columns=expected_cols) + expected.index = Index( + Categorical(["low", "high"], categories=["low", "high"], ordered=ordered), + name="In", + ) + + tm.assert_frame_equal(result, expected) + + # case with columns/value + result = df.pivot_table(columns="Col", values="Val", observed=observed) + + expected = DataFrame( + data=[[3.5, 3.0]], columns=expected_cols, index=Index(["Val"]) + ) + + tm.assert_frame_equal(result, expected) + + def test_categorical_aggfunc(self, observed): + # GH 9534 + df = DataFrame( + {"C1": ["A", "B", "C", "C"], "C2": ["a", "a", "b", "b"], "V": [1, 2, 3, 4]} + ) + df["C1"] = df["C1"].astype("category") + msg = "The default value of observed=False is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = df.pivot_table( + "V", index="C1", columns="C2", dropna=observed, aggfunc="count" + ) + + expected_index = pd.CategoricalIndex( + ["A", "B", "C"], categories=["A", "B", "C"], ordered=False, name="C1" + ) + expected_columns = Index(["a", "b"], name="C2") + expected_data = np.array([[1, 0], [1, 0], [0, 2]], dtype=np.int64) + expected = DataFrame( + expected_data, index=expected_index, columns=expected_columns + ) + tm.assert_frame_equal(result, expected) + + def test_categorical_pivot_index_ordering(self, observed): + # GH 8731 + df = DataFrame( + { + "Sales": [100, 120, 220], + "Month": ["January", "January", "January"], + "Year": [2013, 2014, 2013], + } + ) + months = [ + "January", + "February", + "March", + "April", + "May", + "June", + "July", + "August", + "September", + "October", + "November", + "December", + ] + df["Month"] = df["Month"].astype("category").cat.set_categories(months) + result = df.pivot_table( + values="Sales", + index="Month", + columns="Year", + observed=observed, + aggfunc="sum", + ) + expected_columns = Index([2013, 2014], name="Year", dtype="int64") + expected_index = pd.CategoricalIndex( + months, categories=months, ordered=False, name="Month" + ) + expected_data = [[320, 120]] + [[0, 0]] * 11 + expected = DataFrame( + expected_data, index=expected_index, columns=expected_columns + ) + if observed: + expected = expected.loc[["January"]] + + tm.assert_frame_equal(result, expected) + + def test_pivot_table_not_series(self): + # GH 4386 + # pivot_table always returns a DataFrame + # when values is not list like and columns is None + # and aggfunc is not instance of list + df = DataFrame({"col1": [3, 4, 5], "col2": ["C", "D", "E"], "col3": [1, 3, 9]}) + + result = df.pivot_table("col1", index=["col3", "col2"], aggfunc="sum") + m = MultiIndex.from_arrays([[1, 3, 9], ["C", "D", "E"]], names=["col3", "col2"]) + expected = DataFrame([3, 4, 5], index=m, columns=["col1"]) + + tm.assert_frame_equal(result, expected) + + result = df.pivot_table("col1", index="col3", columns="col2", aggfunc="sum") + expected = DataFrame( + [[3, np.nan, np.nan], [np.nan, 4, np.nan], [np.nan, np.nan, 5]], + index=Index([1, 3, 9], name="col3"), + columns=Index(["C", "D", "E"], name="col2"), + ) + + tm.assert_frame_equal(result, expected) + + result = df.pivot_table("col1", index="col3", aggfunc=["sum"]) + m = MultiIndex.from_arrays([["sum"], ["col1"]]) + expected = DataFrame([3, 4, 5], index=Index([1, 3, 9], name="col3"), columns=m) + + tm.assert_frame_equal(result, expected) + + def test_pivot_margins_name_unicode(self): + # issue #13292 + greek = "\u0394\u03bf\u03ba\u03b9\u03bc\u03ae" + frame = DataFrame({"foo": [1, 2, 3]}, columns=Index(["foo"], dtype=object)) + table = pivot_table( + frame, index=["foo"], aggfunc=len, margins=True, margins_name=greek + ) + index = Index([1, 2, 3, greek], dtype="object", name="foo") + expected = DataFrame(index=index, columns=[]) + tm.assert_frame_equal(table, expected) + + def test_pivot_string_as_func(self): + # GH #18713 + # for correctness purposes + data = DataFrame( + { + "A": [ + "foo", + "foo", + "foo", + "foo", + "bar", + "bar", + "bar", + "bar", + "foo", + "foo", + "foo", + ], + "B": [ + "one", + "one", + "one", + "two", + "one", + "one", + "one", + "two", + "two", + "two", + "one", + ], + "C": range(11), + } + ) + + result = pivot_table(data, index="A", columns="B", aggfunc="sum") + mi = MultiIndex( + levels=[["C"], ["one", "two"]], codes=[[0, 0], [0, 1]], names=[None, "B"] + ) + expected = DataFrame( + {("C", "one"): {"bar": 15, "foo": 13}, ("C", "two"): {"bar": 7, "foo": 20}}, + columns=mi, + ).rename_axis("A") + tm.assert_frame_equal(result, expected) + + result = pivot_table(data, index="A", columns="B", aggfunc=["sum", "mean"]) + mi = MultiIndex( + levels=[["sum", "mean"], ["C"], ["one", "two"]], + codes=[[0, 0, 1, 1], [0, 0, 0, 0], [0, 1, 0, 1]], + names=[None, None, "B"], + ) + expected = DataFrame( + { + ("mean", "C", "one"): {"bar": 5.0, "foo": 3.25}, + ("mean", "C", "two"): {"bar": 7.0, "foo": 6.666666666666667}, + ("sum", "C", "one"): {"bar": 15, "foo": 13}, + ("sum", "C", "two"): {"bar": 7, "foo": 20}, + }, + columns=mi, + ).rename_axis("A") + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "f, f_numpy", + [ + ("sum", np.sum), + ("mean", np.mean), + ("std", np.std), + (["sum", "mean"], [np.sum, np.mean]), + (["sum", "std"], [np.sum, np.std]), + (["std", "mean"], [np.std, np.mean]), + ], + ) + def test_pivot_string_func_vs_func(self, f, f_numpy, data): + # GH #18713 + # for consistency purposes + data = data.drop(columns="C") + result = pivot_table(data, index="A", columns="B", aggfunc=f) + ops = "|".join(f) if isinstance(f, list) else f + msg = f"using DataFrameGroupBy.[{ops}]" + with tm.assert_produces_warning(FutureWarning, match=msg): + expected = pivot_table(data, index="A", columns="B", aggfunc=f_numpy) + tm.assert_frame_equal(result, expected) + + @pytest.mark.slow + def test_pivot_number_of_levels_larger_than_int32(self, monkeypatch): + # GH 20601 + # GH 26314: Change ValueError to PerformanceWarning + class MockUnstacker(reshape_lib._Unstacker): + def __init__(self, *args, **kwargs) -> None: + # __init__ will raise the warning + super().__init__(*args, **kwargs) + raise Exception("Don't compute final result.") + + with monkeypatch.context() as m: + m.setattr(reshape_lib, "_Unstacker", MockUnstacker) + df = DataFrame( + {"ind1": np.arange(2**16), "ind2": np.arange(2**16), "count": 0} + ) + + msg = "The following operation may generate" + with tm.assert_produces_warning(PerformanceWarning, match=msg): + with pytest.raises(Exception, match="Don't compute final result."): + df.pivot_table( + index="ind1", columns="ind2", values="count", aggfunc="count" + ) + + def test_pivot_table_aggfunc_dropna(self, dropna): + # GH 22159 + df = DataFrame( + { + "fruit": ["apple", "peach", "apple"], + "size": [1, 1, 2], + "taste": [7, 6, 6], + } + ) + + def ret_one(x): + return 1 + + def ret_sum(x): + return sum(x) + + def ret_none(x): + return np.nan + + result = pivot_table( + df, columns="fruit", aggfunc=[ret_sum, ret_none, ret_one], dropna=dropna + ) + + data = [[3, 1, np.nan, np.nan, 1, 1], [13, 6, np.nan, np.nan, 1, 1]] + col = MultiIndex.from_product( + [["ret_sum", "ret_none", "ret_one"], ["apple", "peach"]], + names=[None, "fruit"], + ) + expected = DataFrame(data, index=["size", "taste"], columns=col) + + if dropna: + expected = expected.dropna(axis="columns") + + tm.assert_frame_equal(result, expected) + + def test_pivot_table_aggfunc_scalar_dropna(self, dropna): + # GH 22159 + df = DataFrame( + {"A": ["one", "two", "one"], "x": [3, np.nan, 2], "y": [1, np.nan, np.nan]} + ) + + result = pivot_table(df, columns="A", aggfunc="mean", dropna=dropna) + + data = [[2.5, np.nan], [1, np.nan]] + col = Index(["one", "two"], name="A") + expected = DataFrame(data, index=["x", "y"], columns=col) + + if dropna: + expected = expected.dropna(axis="columns") + + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("margins", [True, False]) + def test_pivot_table_empty_aggfunc(self, margins): + # GH 9186 & GH 13483 & GH 49240 + df = DataFrame( + { + "A": [2, 2, 3, 3, 2], + "id": [5, 6, 7, 8, 9], + "C": ["p", "q", "q", "p", "q"], + "D": [None, None, None, None, None], + } + ) + result = df.pivot_table( + index="A", columns="D", values="id", aggfunc=np.size, margins=margins + ) + exp_cols = Index([], name="D") + expected = DataFrame(index=Index([], dtype="int64", name="A"), columns=exp_cols) + tm.assert_frame_equal(result, expected) + + def test_pivot_table_no_column_raises(self): + # GH 10326 + def agg(arr): + return np.mean(arr) + + df = DataFrame({"X": [0, 0, 1, 1], "Y": [0, 1, 0, 1], "Z": [10, 20, 30, 40]}) + with pytest.raises(KeyError, match="notpresent"): + df.pivot_table("notpresent", "X", "Y", aggfunc=agg) + + def test_pivot_table_multiindex_columns_doctest_case(self): + # The relevant characteristic is that the call + # to maybe_downcast_to_dtype(agged[v], data[v].dtype) in + # __internal_pivot_table has `agged[v]` a DataFrame instead of Series, + # In this case this is because agged.columns is a MultiIndex and 'v' + # is only indexing on its first level. + df = DataFrame( + { + "A": ["foo", "foo", "foo", "foo", "foo", "bar", "bar", "bar", "bar"], + "B": ["one", "one", "one", "two", "two", "one", "one", "two", "two"], + "C": [ + "small", + "large", + "large", + "small", + "small", + "large", + "small", + "small", + "large", + ], + "D": [1, 2, 2, 3, 3, 4, 5, 6, 7], + "E": [2, 4, 5, 5, 6, 6, 8, 9, 9], + } + ) + + table = pivot_table( + df, + values=["D", "E"], + index=["A", "C"], + aggfunc={"D": "mean", "E": ["min", "max", "mean"]}, + ) + cols = MultiIndex.from_tuples( + [("D", "mean"), ("E", "max"), ("E", "mean"), ("E", "min")] + ) + index = MultiIndex.from_tuples( + [("bar", "large"), ("bar", "small"), ("foo", "large"), ("foo", "small")], + names=["A", "C"], + ) + vals = np.array( + [ + [5.5, 9.0, 7.5, 6.0], + [5.5, 9.0, 8.5, 8.0], + [2.0, 5.0, 4.5, 4.0], + [2.33333333, 6.0, 4.33333333, 2.0], + ] + ) + expected = DataFrame(vals, columns=cols, index=index) + expected[("E", "min")] = expected[("E", "min")].astype(np.int64) + expected[("E", "max")] = expected[("E", "max")].astype(np.int64) + tm.assert_frame_equal(table, expected) + + def test_pivot_table_sort_false(self): + # GH#39143 + df = DataFrame( + { + "a": ["d1", "d4", "d3"], + "col": ["a", "b", "c"], + "num": [23, 21, 34], + "year": ["2018", "2018", "2019"], + } + ) + result = df.pivot_table( + index=["a", "col"], columns="year", values="num", aggfunc="sum", sort=False + ) + expected = DataFrame( + [[23, np.nan], [21, np.nan], [np.nan, 34]], + columns=Index(["2018", "2019"], name="year"), + index=MultiIndex.from_arrays( + [["d1", "d4", "d3"], ["a", "b", "c"]], names=["a", "col"] + ), + ) + tm.assert_frame_equal(result, expected) + + def test_pivot_table_nullable_margins(self): + # GH#48681 + df = DataFrame( + {"a": "A", "b": [1, 2], "sales": Series([10, 11], dtype="Int64")} + ) + + result = df.pivot_table(index="b", columns="a", margins=True, aggfunc="sum") + expected = DataFrame( + [[10, 10], [11, 11], [21, 21]], + index=Index([1, 2, "All"], name="b"), + columns=MultiIndex.from_tuples( + [("sales", "A"), ("sales", "All")], names=[None, "a"] + ), + dtype="Int64", + ) + tm.assert_frame_equal(result, expected) + + def test_pivot_table_sort_false_with_multiple_values(self): + df = DataFrame( + { + "firstname": ["John", "Michael"], + "lastname": ["Foo", "Bar"], + "height": [173, 182], + "age": [47, 33], + } + ) + result = df.pivot_table( + index=["lastname", "firstname"], values=["height", "age"], sort=False + ) + expected = DataFrame( + [[173.0, 47.0], [182.0, 33.0]], + columns=["height", "age"], + index=MultiIndex.from_tuples( + [("Foo", "John"), ("Bar", "Michael")], + names=["lastname", "firstname"], + ), + ) + tm.assert_frame_equal(result, expected) + + def test_pivot_table_with_margins_and_numeric_columns(self): + # GH 26568 + df = DataFrame([["a", "x", 1], ["a", "y", 2], ["b", "y", 3], ["b", "z", 4]]) + df.columns = [10, 20, 30] + + result = df.pivot_table( + index=10, columns=20, values=30, aggfunc="sum", fill_value=0, margins=True + ) + + expected = DataFrame([[1, 2, 0, 3], [0, 3, 4, 7], [1, 5, 4, 10]]) + expected.columns = ["x", "y", "z", "All"] + expected.index = ["a", "b", "All"] + expected.columns.name = 20 + expected.index.name = 10 + + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("dropna", [True, False]) + def test_pivot_ea_dtype_dropna(self, dropna): + # GH#47477 + df = DataFrame({"x": "a", "y": "b", "age": Series([20, 40], dtype="Int64")}) + result = df.pivot_table( + index="x", columns="y", values="age", aggfunc="mean", dropna=dropna + ) + expected = DataFrame( + [[30]], + index=Index(["a"], name="x"), + columns=Index(["b"], name="y"), + dtype="Float64", + ) + tm.assert_frame_equal(result, expected) + + def test_pivot_table_datetime_warning(self): + # GH#48683 + df = DataFrame( + { + "a": "A", + "b": [1, 2], + "date": pd.Timestamp("2019-12-31"), + "sales": [10.0, 11], + } + ) + with tm.assert_produces_warning(None): + result = df.pivot_table( + index=["b", "date"], columns="a", margins=True, aggfunc="sum" + ) + expected = DataFrame( + [[10.0, 10.0], [11.0, 11.0], [21.0, 21.0]], + index=MultiIndex.from_arrays( + [ + Index([1, 2, "All"], name="b"), + Index( + [pd.Timestamp("2019-12-31"), pd.Timestamp("2019-12-31"), ""], + dtype=object, + name="date", + ), + ] + ), + columns=MultiIndex.from_tuples( + [("sales", "A"), ("sales", "All")], names=[None, "a"] + ), + ) + tm.assert_frame_equal(result, expected) + + def test_pivot_table_with_mixed_nested_tuples(self, using_array_manager): + # GH 50342 + df = DataFrame( + { + "A": ["foo", "foo", "foo", "foo", "foo", "bar", "bar", "bar", "bar"], + "B": ["one", "one", "one", "two", "two", "one", "one", "two", "two"], + "C": [ + "small", + "large", + "large", + "small", + "small", + "large", + "small", + "small", + "large", + ], + "D": [1, 2, 2, 3, 3, 4, 5, 6, 7], + "E": [2, 4, 5, 5, 6, 6, 8, 9, 9], + ("col5",): [ + "foo", + "foo", + "foo", + "foo", + "foo", + "bar", + "bar", + "bar", + "bar", + ], + ("col6", 6): [ + "one", + "one", + "one", + "two", + "two", + "one", + "one", + "two", + "two", + ], + (7, "seven"): [ + "small", + "large", + "large", + "small", + "small", + "large", + "small", + "small", + "large", + ], + } + ) + result = pivot_table( + df, values="D", index=["A", "B"], columns=[(7, "seven")], aggfunc="sum" + ) + expected = DataFrame( + [[4.0, 5.0], [7.0, 6.0], [4.0, 1.0], [np.nan, 6.0]], + columns=Index(["large", "small"], name=(7, "seven")), + index=MultiIndex.from_arrays( + [["bar", "bar", "foo", "foo"], ["one", "two"] * 2], names=["A", "B"] + ), + ) + if using_array_manager: + # INFO(ArrayManager) column without NaNs can preserve int dtype + expected["small"] = expected["small"].astype("int64") + tm.assert_frame_equal(result, expected) + + def test_pivot_table_aggfunc_nunique_with_different_values(self): + test = DataFrame( + { + "a": range(10), + "b": range(10), + "c": range(10), + "d": range(10), + } + ) + + columnval = MultiIndex.from_arrays( + [ + ["nunique" for i in range(10)], + ["c" for i in range(10)], + range(10), + ], + names=(None, None, "b"), + ) + nparr = np.full((10, 10), np.nan) + np.fill_diagonal(nparr, 1.0) + + expected = DataFrame(nparr, index=Index(range(10), name="a"), columns=columnval) + result = test.pivot_table( + index=[ + "a", + ], + columns=[ + "b", + ], + values=[ + "c", + ], + aggfunc=["nunique"], + ) + + tm.assert_frame_equal(result, expected) + + +class TestPivot: + def test_pivot(self): + data = { + "index": ["A", "B", "C", "C", "B", "A"], + "columns": ["One", "One", "One", "Two", "Two", "Two"], + "values": [1.0, 2.0, 3.0, 3.0, 2.0, 1.0], + } + + frame = DataFrame(data) + pivoted = frame.pivot(index="index", columns="columns", values="values") + + expected = DataFrame( + { + "One": {"A": 1.0, "B": 2.0, "C": 3.0}, + "Two": {"A": 1.0, "B": 2.0, "C": 3.0}, + } + ) + + expected.index.name, expected.columns.name = "index", "columns" + tm.assert_frame_equal(pivoted, expected) + + # name tracking + assert pivoted.index.name == "index" + assert pivoted.columns.name == "columns" + + # don't specify values + pivoted = frame.pivot(index="index", columns="columns") + assert pivoted.index.name == "index" + assert pivoted.columns.names == (None, "columns") + + def test_pivot_duplicates(self): + data = DataFrame( + { + "a": ["bar", "bar", "foo", "foo", "foo"], + "b": ["one", "two", "one", "one", "two"], + "c": [1.0, 2.0, 3.0, 3.0, 4.0], + } + ) + with pytest.raises(ValueError, match="duplicate entries"): + data.pivot(index="a", columns="b", values="c") + + def test_pivot_empty(self): + df = DataFrame(columns=["a", "b", "c"]) + result = df.pivot(index="a", columns="b", values="c") + expected = DataFrame(index=[], columns=[]) + tm.assert_frame_equal(result, expected, check_names=False) + + @pytest.mark.parametrize("dtype", [object, "string"]) + def test_pivot_integer_bug(self, dtype): + df = DataFrame(data=[("A", "1", "A1"), ("B", "2", "B2")], dtype=dtype) + + result = df.pivot(index=1, columns=0, values=2) + tm.assert_index_equal(result.columns, Index(["A", "B"], name=0, dtype=dtype)) + + def test_pivot_index_none(self): + # GH#3962 + data = { + "index": ["A", "B", "C", "C", "B", "A"], + "columns": ["One", "One", "One", "Two", "Two", "Two"], + "values": [1.0, 2.0, 3.0, 3.0, 2.0, 1.0], + } + + frame = DataFrame(data).set_index("index") + result = frame.pivot(columns="columns", values="values") + expected = DataFrame( + { + "One": {"A": 1.0, "B": 2.0, "C": 3.0}, + "Two": {"A": 1.0, "B": 2.0, "C": 3.0}, + } + ) + + expected.index.name, expected.columns.name = "index", "columns" + tm.assert_frame_equal(result, expected) + + # omit values + result = frame.pivot(columns="columns") + + expected.columns = MultiIndex.from_tuples( + [("values", "One"), ("values", "Two")], names=[None, "columns"] + ) + expected.index.name = "index" + tm.assert_frame_equal(result, expected, check_names=False) + assert result.index.name == "index" + assert result.columns.names == (None, "columns") + expected.columns = expected.columns.droplevel(0) + result = frame.pivot(columns="columns", values="values") + + expected.columns.name = "columns" + tm.assert_frame_equal(result, expected) + + def test_pivot_index_list_values_none_immutable_args(self): + # GH37635 + df = DataFrame( + { + "lev1": [1, 1, 1, 2, 2, 2], + "lev2": [1, 1, 2, 1, 1, 2], + "lev3": [1, 2, 1, 2, 1, 2], + "lev4": [1, 2, 3, 4, 5, 6], + "values": [0, 1, 2, 3, 4, 5], + } + ) + index = ["lev1", "lev2"] + columns = ["lev3"] + result = df.pivot(index=index, columns=columns) + + expected = DataFrame( + np.array( + [ + [1.0, 2.0, 0.0, 1.0], + [3.0, np.nan, 2.0, np.nan], + [5.0, 4.0, 4.0, 3.0], + [np.nan, 6.0, np.nan, 5.0], + ] + ), + index=MultiIndex.from_arrays( + [(1, 1, 2, 2), (1, 2, 1, 2)], names=["lev1", "lev2"] + ), + columns=MultiIndex.from_arrays( + [("lev4", "lev4", "values", "values"), (1, 2, 1, 2)], + names=[None, "lev3"], + ), + ) + + tm.assert_frame_equal(result, expected) + + assert index == ["lev1", "lev2"] + assert columns == ["lev3"] + + def test_pivot_columns_not_given(self): + # GH#48293 + df = DataFrame({"a": [1], "b": 1}) + with pytest.raises(TypeError, match="missing 1 required keyword-only argument"): + df.pivot() # pylint: disable=missing-kwoa + + @pytest.mark.xfail(using_pyarrow_string_dtype(), reason="None is cast to NaN") + def test_pivot_columns_is_none(self): + # GH#48293 + df = DataFrame({None: [1], "b": 2, "c": 3}) + result = df.pivot(columns=None) + expected = DataFrame({("b", 1): [2], ("c", 1): 3}) + tm.assert_frame_equal(result, expected) + + result = df.pivot(columns=None, index="b") + expected = DataFrame({("c", 1): 3}, index=Index([2], name="b")) + tm.assert_frame_equal(result, expected) + + result = df.pivot(columns=None, index="b", values="c") + expected = DataFrame({1: 3}, index=Index([2], name="b")) + tm.assert_frame_equal(result, expected) + + @pytest.mark.xfail(using_pyarrow_string_dtype(), reason="None is cast to NaN") + def test_pivot_index_is_none(self): + # GH#48293 + df = DataFrame({None: [1], "b": 2, "c": 3}) + + result = df.pivot(columns="b", index=None) + expected = DataFrame({("c", 2): 3}, index=[1]) + expected.columns.names = [None, "b"] + tm.assert_frame_equal(result, expected) + + result = df.pivot(columns="b", index=None, values="c") + expected = DataFrame(3, index=[1], columns=Index([2], name="b")) + tm.assert_frame_equal(result, expected) + + @pytest.mark.xfail(using_pyarrow_string_dtype(), reason="None is cast to NaN") + def test_pivot_values_is_none(self): + # GH#48293 + df = DataFrame({None: [1], "b": 2, "c": 3}) + + result = df.pivot(columns="b", index="c", values=None) + expected = DataFrame( + 1, index=Index([3], name="c"), columns=Index([2], name="b") + ) + tm.assert_frame_equal(result, expected) + + result = df.pivot(columns="b", values=None) + expected = DataFrame(1, index=[0], columns=Index([2], name="b")) + tm.assert_frame_equal(result, expected) + + def test_pivot_not_changing_index_name(self): + # GH#52692 + df = DataFrame({"one": ["a"], "two": 0, "three": 1}) + expected = df.copy(deep=True) + df.pivot(index="one", columns="two", values="three") + tm.assert_frame_equal(df, expected) + + def test_pivot_table_empty_dataframe_correct_index(self): + # GH 21932 + df = DataFrame([], columns=["a", "b", "value"]) + pivot = df.pivot_table(index="a", columns="b", values="value", aggfunc="count") + + expected = Index([], dtype="object", name="b") + tm.assert_index_equal(pivot.columns, expected) + + def test_pivot_table_handles_explicit_datetime_types(self): + # GH#43574 + df = DataFrame( + [ + {"a": "x", "date_str": "2023-01-01", "amount": 1}, + {"a": "y", "date_str": "2023-01-02", "amount": 2}, + {"a": "z", "date_str": "2023-01-03", "amount": 3}, + ] + ) + df["date"] = pd.to_datetime(df["date_str"]) + + with tm.assert_produces_warning(False): + pivot = df.pivot_table( + index=["a", "date"], values=["amount"], aggfunc="sum", margins=True + ) + + expected = MultiIndex.from_tuples( + [ + ("x", datetime.strptime("2023-01-01 00:00:00", "%Y-%m-%d %H:%M:%S")), + ("y", datetime.strptime("2023-01-02 00:00:00", "%Y-%m-%d %H:%M:%S")), + ("z", datetime.strptime("2023-01-03 00:00:00", "%Y-%m-%d %H:%M:%S")), + ("All", ""), + ], + names=["a", "date"], + ) + tm.assert_index_equal(pivot.index, expected) + + def test_pivot_table_with_margins_and_numeric_column_names(self): + # GH#26568 + df = DataFrame([["a", "x", 1], ["a", "y", 2], ["b", "y", 3], ["b", "z", 4]]) + + result = df.pivot_table( + index=0, columns=1, values=2, aggfunc="sum", fill_value=0, margins=True + ) + + expected = DataFrame( + [[1, 2, 0, 3], [0, 3, 4, 7], [1, 5, 4, 10]], + columns=Index(["x", "y", "z", "All"], name=1), + index=Index(["a", "b", "All"], name=0), + ) + tm.assert_frame_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/test_pivot_multilevel.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/test_pivot_multilevel.py new file mode 100644 index 0000000000000000000000000000000000000000..08ef29440825f006bf53eea7f21f0809bff99908 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/test_pivot_multilevel.py @@ -0,0 +1,254 @@ +import numpy as np +import pytest + +from pandas._libs import lib + +import pandas as pd +from pandas import ( + Index, + MultiIndex, +) +import pandas._testing as tm + + +@pytest.mark.parametrize( + "input_index, input_columns, input_values, " + "expected_values, expected_columns, expected_index", + [ + ( + ["lev4"], + "lev3", + "values", + [ + [0.0, np.nan], + [np.nan, 1.0], + [2.0, np.nan], + [np.nan, 3.0], + [4.0, np.nan], + [np.nan, 5.0], + [6.0, np.nan], + [np.nan, 7.0], + ], + Index([1, 2], name="lev3"), + Index([1, 2, 3, 4, 5, 6, 7, 8], name="lev4"), + ), + ( + ["lev4"], + "lev3", + lib.no_default, + [ + [1.0, np.nan, 1.0, np.nan, 0.0, np.nan], + [np.nan, 1.0, np.nan, 1.0, np.nan, 1.0], + [1.0, np.nan, 2.0, np.nan, 2.0, np.nan], + [np.nan, 1.0, np.nan, 2.0, np.nan, 3.0], + [2.0, np.nan, 1.0, np.nan, 4.0, np.nan], + [np.nan, 2.0, np.nan, 1.0, np.nan, 5.0], + [2.0, np.nan, 2.0, np.nan, 6.0, np.nan], + [np.nan, 2.0, np.nan, 2.0, np.nan, 7.0], + ], + MultiIndex.from_tuples( + [ + ("lev1", 1), + ("lev1", 2), + ("lev2", 1), + ("lev2", 2), + ("values", 1), + ("values", 2), + ], + names=[None, "lev3"], + ), + Index([1, 2, 3, 4, 5, 6, 7, 8], name="lev4"), + ), + ( + ["lev1", "lev2"], + "lev3", + "values", + [[0, 1], [2, 3], [4, 5], [6, 7]], + Index([1, 2], name="lev3"), + MultiIndex.from_tuples( + [(1, 1), (1, 2), (2, 1), (2, 2)], names=["lev1", "lev2"] + ), + ), + ( + ["lev1", "lev2"], + "lev3", + lib.no_default, + [[1, 2, 0, 1], [3, 4, 2, 3], [5, 6, 4, 5], [7, 8, 6, 7]], + MultiIndex.from_tuples( + [("lev4", 1), ("lev4", 2), ("values", 1), ("values", 2)], + names=[None, "lev3"], + ), + MultiIndex.from_tuples( + [(1, 1), (1, 2), (2, 1), (2, 2)], names=["lev1", "lev2"] + ), + ), + ], +) +def test_pivot_list_like_index( + input_index, + input_columns, + input_values, + expected_values, + expected_columns, + expected_index, +): + # GH 21425, test when index is given a list + df = pd.DataFrame( + { + "lev1": [1, 1, 1, 1, 2, 2, 2, 2], + "lev2": [1, 1, 2, 2, 1, 1, 2, 2], + "lev3": [1, 2, 1, 2, 1, 2, 1, 2], + "lev4": [1, 2, 3, 4, 5, 6, 7, 8], + "values": [0, 1, 2, 3, 4, 5, 6, 7], + } + ) + + result = df.pivot(index=input_index, columns=input_columns, values=input_values) + expected = pd.DataFrame( + expected_values, columns=expected_columns, index=expected_index + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "input_index, input_columns, input_values, " + "expected_values, expected_columns, expected_index", + [ + ( + "lev4", + ["lev3"], + "values", + [ + [0.0, np.nan], + [np.nan, 1.0], + [2.0, np.nan], + [np.nan, 3.0], + [4.0, np.nan], + [np.nan, 5.0], + [6.0, np.nan], + [np.nan, 7.0], + ], + Index([1, 2], name="lev3"), + Index([1, 2, 3, 4, 5, 6, 7, 8], name="lev4"), + ), + ( + ["lev1", "lev2"], + ["lev3"], + "values", + [[0, 1], [2, 3], [4, 5], [6, 7]], + Index([1, 2], name="lev3"), + MultiIndex.from_tuples( + [(1, 1), (1, 2), (2, 1), (2, 2)], names=["lev1", "lev2"] + ), + ), + ( + ["lev1"], + ["lev2", "lev3"], + "values", + [[0, 1, 2, 3], [4, 5, 6, 7]], + MultiIndex.from_tuples( + [(1, 1), (1, 2), (2, 1), (2, 2)], names=["lev2", "lev3"] + ), + Index([1, 2], name="lev1"), + ), + ( + ["lev1", "lev2"], + ["lev3", "lev4"], + "values", + [ + [0.0, 1.0, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan], + [np.nan, np.nan, 2.0, 3.0, np.nan, np.nan, np.nan, np.nan], + [np.nan, np.nan, np.nan, np.nan, 4.0, 5.0, np.nan, np.nan], + [np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, 6.0, 7.0], + ], + MultiIndex.from_tuples( + [(1, 1), (2, 2), (1, 3), (2, 4), (1, 5), (2, 6), (1, 7), (2, 8)], + names=["lev3", "lev4"], + ), + MultiIndex.from_tuples( + [(1, 1), (1, 2), (2, 1), (2, 2)], names=["lev1", "lev2"] + ), + ), + ], +) +def test_pivot_list_like_columns( + input_index, + input_columns, + input_values, + expected_values, + expected_columns, + expected_index, +): + # GH 21425, test when columns is given a list + df = pd.DataFrame( + { + "lev1": [1, 1, 1, 1, 2, 2, 2, 2], + "lev2": [1, 1, 2, 2, 1, 1, 2, 2], + "lev3": [1, 2, 1, 2, 1, 2, 1, 2], + "lev4": [1, 2, 3, 4, 5, 6, 7, 8], + "values": [0, 1, 2, 3, 4, 5, 6, 7], + } + ) + + result = df.pivot(index=input_index, columns=input_columns, values=input_values) + expected = pd.DataFrame( + expected_values, columns=expected_columns, index=expected_index + ) + tm.assert_frame_equal(result, expected) + + +def test_pivot_multiindexed_rows_and_cols(using_array_manager): + # GH 36360 + + df = pd.DataFrame( + data=np.arange(12).reshape(4, 3), + columns=MultiIndex.from_tuples( + [(0, 0), (0, 1), (0, 2)], names=["col_L0", "col_L1"] + ), + index=MultiIndex.from_tuples( + [(0, 0, 0), (0, 0, 1), (1, 1, 1), (1, 0, 0)], + names=["idx_L0", "idx_L1", "idx_L2"], + ), + ) + + res = df.pivot_table( + index=["idx_L0"], + columns=["idx_L1"], + values=[(0, 1)], + aggfunc=lambda col: col.values.sum(), + ) + + expected = pd.DataFrame( + data=[[5, np.nan], [10, 7.0]], + columns=MultiIndex.from_tuples( + [(0, 1, 0), (0, 1, 1)], names=["col_L0", "col_L1", "idx_L1"] + ), + index=Index([0, 1], dtype="int64", name="idx_L0"), + ) + if not using_array_manager: + # BlockManager does not preserve the dtypes + expected = expected.astype("float64") + + tm.assert_frame_equal(res, expected) + + +def test_pivot_df_multiindex_index_none(): + # GH 23955 + df = pd.DataFrame( + [ + ["A", "A1", "label1", 1], + ["A", "A2", "label2", 2], + ["B", "A1", "label1", 3], + ["B", "A2", "label2", 4], + ], + columns=["index_1", "index_2", "label", "value"], + ) + df = df.set_index(["index_1", "index_2"]) + + result = df.pivot(columns="label", values="value") + expected = pd.DataFrame( + [[1.0, np.nan], [np.nan, 2.0], [3.0, np.nan], [np.nan, 4.0]], + index=df.index, + columns=Index(["label1", "label2"], name="label"), + ) + tm.assert_frame_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/test_qcut.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/test_qcut.py new file mode 100644 index 0000000000000000000000000000000000000000..b5b19eef1106fd2f500fc74e0eb5157ca63cfeb7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/test_qcut.py @@ -0,0 +1,305 @@ +import os + +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + Categorical, + DatetimeIndex, + Interval, + IntervalIndex, + NaT, + Series, + Timedelta, + TimedeltaIndex, + Timestamp, + cut, + date_range, + isna, + qcut, + timedelta_range, +) +import pandas._testing as tm +from pandas.api.types import CategoricalDtype + +from pandas.tseries.offsets import Day + + +def test_qcut(): + arr = np.random.default_rng(2).standard_normal(1000) + + # We store the bins as Index that have been + # rounded to comparisons are a bit tricky. + labels, _ = qcut(arr, 4, retbins=True) + ex_bins = np.quantile(arr, [0, 0.25, 0.5, 0.75, 1.0]) + + result = labels.categories.left.values + assert np.allclose(result, ex_bins[:-1], atol=1e-2) + + result = labels.categories.right.values + assert np.allclose(result, ex_bins[1:], atol=1e-2) + + ex_levels = cut(arr, ex_bins, include_lowest=True) + tm.assert_categorical_equal(labels, ex_levels) + + +def test_qcut_bounds(): + arr = np.random.default_rng(2).standard_normal(1000) + + factor = qcut(arr, 10, labels=False) + assert len(np.unique(factor)) == 10 + + +def test_qcut_specify_quantiles(): + arr = np.random.default_rng(2).standard_normal(100) + factor = qcut(arr, [0, 0.25, 0.5, 0.75, 1.0]) + + expected = qcut(arr, 4) + tm.assert_categorical_equal(factor, expected) + + +def test_qcut_all_bins_same(): + with pytest.raises(ValueError, match="edges.*unique"): + qcut([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 3) + + +def test_qcut_include_lowest(): + values = np.arange(10) + ii = qcut(values, 4) + + ex_levels = IntervalIndex( + [ + Interval(-0.001, 2.25), + Interval(2.25, 4.5), + Interval(4.5, 6.75), + Interval(6.75, 9), + ] + ) + tm.assert_index_equal(ii.categories, ex_levels) + + +def test_qcut_nas(): + arr = np.random.default_rng(2).standard_normal(100) + arr[:20] = np.nan + + result = qcut(arr, 4) + assert isna(result[:20]).all() + + +def test_qcut_index(): + result = qcut([0, 2], 2) + intervals = [Interval(-0.001, 1), Interval(1, 2)] + + expected = Categorical(intervals, ordered=True) + tm.assert_categorical_equal(result, expected) + + +def test_qcut_binning_issues(datapath): + # see gh-1978, gh-1979 + cut_file = datapath(os.path.join("reshape", "data", "cut_data.csv")) + arr = np.loadtxt(cut_file) + result = qcut(arr, 20) + + starts = [] + ends = [] + + for lev in np.unique(result): + s = lev.left + e = lev.right + assert s != e + + starts.append(float(s)) + ends.append(float(e)) + + for (sp, sn), (ep, en) in zip( + zip(starts[:-1], starts[1:]), zip(ends[:-1], ends[1:]) + ): + assert sp < sn + assert ep < en + assert ep <= sn + + +def test_qcut_return_intervals(): + ser = Series([0, 1, 2, 3, 4, 5, 6, 7, 8]) + res = qcut(ser, [0, 0.333, 0.666, 1]) + + exp_levels = np.array( + [Interval(-0.001, 2.664), Interval(2.664, 5.328), Interval(5.328, 8)] + ) + exp = Series(exp_levels.take([0, 0, 0, 1, 1, 1, 2, 2, 2])).astype( + CategoricalDtype(ordered=True) + ) + tm.assert_series_equal(res, exp) + + +@pytest.mark.parametrize("labels", ["foo", 1, True]) +def test_qcut_incorrect_labels(labels): + # GH 13318 + values = range(5) + msg = "Bin labels must either be False, None or passed in as a list-like argument" + with pytest.raises(ValueError, match=msg): + qcut(values, 4, labels=labels) + + +@pytest.mark.parametrize("labels", [["a", "b", "c"], list(range(3))]) +def test_qcut_wrong_length_labels(labels): + # GH 13318 + values = range(10) + msg = "Bin labels must be one fewer than the number of bin edges" + with pytest.raises(ValueError, match=msg): + qcut(values, 4, labels=labels) + + +@pytest.mark.parametrize( + "labels, expected", + [ + (["a", "b", "c"], Categorical(["a", "b", "c"], ordered=True)), + (list(range(3)), Categorical([0, 1, 2], ordered=True)), + ], +) +def test_qcut_list_like_labels(labels, expected): + # GH 13318 + values = range(3) + result = qcut(values, 3, labels=labels) + tm.assert_categorical_equal(result, expected) + + +@pytest.mark.parametrize( + "kwargs,msg", + [ + ({"duplicates": "drop"}, None), + ({}, "Bin edges must be unique"), + ({"duplicates": "raise"}, "Bin edges must be unique"), + ({"duplicates": "foo"}, "invalid value for 'duplicates' parameter"), + ], +) +def test_qcut_duplicates_bin(kwargs, msg): + # see gh-7751 + values = [0, 0, 0, 0, 1, 2, 3] + + if msg is not None: + with pytest.raises(ValueError, match=msg): + qcut(values, 3, **kwargs) + else: + result = qcut(values, 3, **kwargs) + expected = IntervalIndex([Interval(-0.001, 1), Interval(1, 3)]) + tm.assert_index_equal(result.categories, expected) + + +@pytest.mark.parametrize( + "data,start,end", [(9.0, 8.999, 9.0), (0.0, -0.001, 0.0), (-9.0, -9.001, -9.0)] +) +@pytest.mark.parametrize("length", [1, 2]) +@pytest.mark.parametrize("labels", [None, False]) +def test_single_quantile(data, start, end, length, labels): + # see gh-15431 + ser = Series([data] * length) + result = qcut(ser, 1, labels=labels) + + if labels is None: + intervals = IntervalIndex([Interval(start, end)] * length, closed="right") + expected = Series(intervals).astype(CategoricalDtype(ordered=True)) + else: + expected = Series([0] * length, dtype=np.intp) + + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "ser", + [ + Series(DatetimeIndex(["20180101", NaT, "20180103"])), + Series(TimedeltaIndex(["0 days", NaT, "2 days"])), + ], + ids=lambda x: str(x.dtype), +) +def test_qcut_nat(ser, unit): + # see gh-19768 + ser = ser.dt.as_unit(unit) + td = Timedelta(1, unit=unit).as_unit(unit) + + left = Series([ser[0] - td, np.nan, ser[2] - Day()], dtype=ser.dtype) + right = Series([ser[2] - Day(), np.nan, ser[2]], dtype=ser.dtype) + intervals = IntervalIndex.from_arrays(left, right) + expected = Series(Categorical(intervals, ordered=True)) + + result = qcut(ser, 2) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("bins", [3, np.linspace(0, 1, 4)]) +def test_datetime_tz_qcut(bins): + # see gh-19872 + tz = "US/Eastern" + ser = Series(date_range("20130101", periods=3, tz=tz)) + + result = qcut(ser, bins) + expected = Series( + IntervalIndex( + [ + Interval( + Timestamp("2012-12-31 23:59:59.999999999", tz=tz), + Timestamp("2013-01-01 16:00:00", tz=tz), + ), + Interval( + Timestamp("2013-01-01 16:00:00", tz=tz), + Timestamp("2013-01-02 08:00:00", tz=tz), + ), + Interval( + Timestamp("2013-01-02 08:00:00", tz=tz), + Timestamp("2013-01-03 00:00:00", tz=tz), + ), + ] + ) + ).astype(CategoricalDtype(ordered=True)) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "arg,expected_bins", + [ + [ + timedelta_range("1day", periods=3), + TimedeltaIndex(["1 days", "2 days", "3 days"]), + ], + [ + date_range("20180101", periods=3), + DatetimeIndex(["2018-01-01", "2018-01-02", "2018-01-03"]), + ], + ], +) +def test_date_like_qcut_bins(arg, expected_bins): + # see gh-19891 + ser = Series(arg) + result, result_bins = qcut(ser, 2, retbins=True) + tm.assert_index_equal(result_bins, expected_bins) + + +@pytest.mark.parametrize("bins", [6, 7]) +@pytest.mark.parametrize( + "box, compare", + [ + (Series, tm.assert_series_equal), + (np.array, tm.assert_categorical_equal), + (list, tm.assert_equal), + ], +) +def test_qcut_bool_coercion_to_int(bins, box, compare): + # issue 20303 + data_expected = box([0, 1, 1, 0, 1] * 10) + data_result = box([False, True, True, False, True] * 10) + expected = qcut(data_expected, bins, duplicates="drop") + result = qcut(data_result, bins, duplicates="drop") + compare(result, expected) + + +@pytest.mark.parametrize("q", [2, 5, 10]) +def test_qcut_nullable_integer(q, any_numeric_ea_dtype): + arr = pd.array(np.arange(100), dtype=any_numeric_ea_dtype) + arr[::2] = pd.NA + + result = qcut(arr, q) + expected = qcut(arr.astype(float), q) + + tm.assert_categorical_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/test_union_categoricals.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/test_union_categoricals.py new file mode 100644 index 0000000000000000000000000000000000000000..8d78d34e936f0eba01bcd2b2a1135271f7e20918 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/test_union_categoricals.py @@ -0,0 +1,365 @@ +import numpy as np +import pytest + +from pandas.core.dtypes.concat import union_categoricals + +import pandas as pd +from pandas import ( + Categorical, + CategoricalIndex, + Series, +) +import pandas._testing as tm + + +class TestUnionCategoricals: + @pytest.mark.parametrize( + "a, b, combined", + [ + (list("abc"), list("abd"), list("abcabd")), + ([0, 1, 2], [2, 3, 4], [0, 1, 2, 2, 3, 4]), + ([0, 1.2, 2], [2, 3.4, 4], [0, 1.2, 2, 2, 3.4, 4]), + ( + ["b", "b", np.nan, "a"], + ["a", np.nan, "c"], + ["b", "b", np.nan, "a", "a", np.nan, "c"], + ), + ( + pd.date_range("2014-01-01", "2014-01-05"), + pd.date_range("2014-01-06", "2014-01-07"), + pd.date_range("2014-01-01", "2014-01-07"), + ), + ( + pd.date_range("2014-01-01", "2014-01-05", tz="US/Central"), + pd.date_range("2014-01-06", "2014-01-07", tz="US/Central"), + pd.date_range("2014-01-01", "2014-01-07", tz="US/Central"), + ), + ( + pd.period_range("2014-01-01", "2014-01-05"), + pd.period_range("2014-01-06", "2014-01-07"), + pd.period_range("2014-01-01", "2014-01-07"), + ), + ], + ) + @pytest.mark.parametrize("box", [Categorical, CategoricalIndex, Series]) + def test_union_categorical(self, a, b, combined, box): + # GH 13361 + result = union_categoricals([box(Categorical(a)), box(Categorical(b))]) + expected = Categorical(combined) + tm.assert_categorical_equal(result, expected) + + def test_union_categorical_ordered_appearance(self): + # new categories ordered by appearance + s = Categorical(["x", "y", "z"]) + s2 = Categorical(["a", "b", "c"]) + result = union_categoricals([s, s2]) + expected = Categorical( + ["x", "y", "z", "a", "b", "c"], categories=["x", "y", "z", "a", "b", "c"] + ) + tm.assert_categorical_equal(result, expected) + + def test_union_categorical_ordered_true(self): + s = Categorical([0, 1.2, 2], ordered=True) + s2 = Categorical([0, 1.2, 2], ordered=True) + result = union_categoricals([s, s2]) + expected = Categorical([0, 1.2, 2, 0, 1.2, 2], ordered=True) + tm.assert_categorical_equal(result, expected) + + def test_union_categorical_match_types(self): + # must exactly match types + s = Categorical([0, 1.2, 2]) + s2 = Categorical([2, 3, 4]) + msg = "dtype of categories must be the same" + with pytest.raises(TypeError, match=msg): + union_categoricals([s, s2]) + + def test_union_categorical_empty(self): + msg = "No Categoricals to union" + with pytest.raises(ValueError, match=msg): + union_categoricals([]) + + def test_union_categoricals_nan(self): + # GH 13759 + res = union_categoricals( + [Categorical([1, 2, np.nan]), Categorical([3, 2, np.nan])] + ) + exp = Categorical([1, 2, np.nan, 3, 2, np.nan]) + tm.assert_categorical_equal(res, exp) + + res = union_categoricals( + [Categorical(["A", "B"]), Categorical(["B", "B", np.nan])] + ) + exp = Categorical(["A", "B", "B", "B", np.nan]) + tm.assert_categorical_equal(res, exp) + + val1 = [pd.Timestamp("2011-01-01"), pd.Timestamp("2011-03-01"), pd.NaT] + val2 = [pd.NaT, pd.Timestamp("2011-01-01"), pd.Timestamp("2011-02-01")] + + res = union_categoricals([Categorical(val1), Categorical(val2)]) + exp = Categorical( + val1 + val2, + categories=[ + pd.Timestamp("2011-01-01"), + pd.Timestamp("2011-03-01"), + pd.Timestamp("2011-02-01"), + ], + ) + tm.assert_categorical_equal(res, exp) + + # all NaN + res = union_categoricals( + [ + Categorical(np.array([np.nan, np.nan], dtype=object)), + Categorical(["X"], categories=pd.Index(["X"], dtype=object)), + ] + ) + exp = Categorical([np.nan, np.nan, "X"]) + tm.assert_categorical_equal(res, exp) + + res = union_categoricals( + [Categorical([np.nan, np.nan]), Categorical([np.nan, np.nan])] + ) + exp = Categorical([np.nan, np.nan, np.nan, np.nan]) + tm.assert_categorical_equal(res, exp) + + @pytest.mark.parametrize("val", [[], ["1"]]) + def test_union_categoricals_empty(self, val, request, using_infer_string): + # GH 13759 + if using_infer_string and val == ["1"]: + request.applymarker(pytest.mark.xfail("object and strings dont match")) + res = union_categoricals([Categorical([]), Categorical(val)]) + exp = Categorical(val) + tm.assert_categorical_equal(res, exp) + + def test_union_categorical_same_category(self): + # check fastpath + c1 = Categorical([1, 2, 3, 4], categories=[1, 2, 3, 4]) + c2 = Categorical([3, 2, 1, np.nan], categories=[1, 2, 3, 4]) + res = union_categoricals([c1, c2]) + exp = Categorical([1, 2, 3, 4, 3, 2, 1, np.nan], categories=[1, 2, 3, 4]) + tm.assert_categorical_equal(res, exp) + + def test_union_categorical_same_category_str(self): + c1 = Categorical(["z", "z", "z"], categories=["x", "y", "z"]) + c2 = Categorical(["x", "x", "x"], categories=["x", "y", "z"]) + res = union_categoricals([c1, c2]) + exp = Categorical(["z", "z", "z", "x", "x", "x"], categories=["x", "y", "z"]) + tm.assert_categorical_equal(res, exp) + + def test_union_categorical_same_categories_different_order(self): + # https://github.com/pandas-dev/pandas/issues/19096 + c1 = Categorical(["a", "b", "c"], categories=["a", "b", "c"]) + c2 = Categorical(["a", "b", "c"], categories=["b", "a", "c"]) + result = union_categoricals([c1, c2]) + expected = Categorical( + ["a", "b", "c", "a", "b", "c"], categories=["a", "b", "c"] + ) + tm.assert_categorical_equal(result, expected) + + def test_union_categoricals_ordered(self): + c1 = Categorical([1, 2, 3], ordered=True) + c2 = Categorical([1, 2, 3], ordered=False) + + msg = "Categorical.ordered must be the same" + with pytest.raises(TypeError, match=msg): + union_categoricals([c1, c2]) + + res = union_categoricals([c1, c1]) + exp = Categorical([1, 2, 3, 1, 2, 3], ordered=True) + tm.assert_categorical_equal(res, exp) + + c1 = Categorical([1, 2, 3, np.nan], ordered=True) + c2 = Categorical([3, 2], categories=[1, 2, 3], ordered=True) + + res = union_categoricals([c1, c2]) + exp = Categorical([1, 2, 3, np.nan, 3, 2], ordered=True) + tm.assert_categorical_equal(res, exp) + + c1 = Categorical([1, 2, 3], ordered=True) + c2 = Categorical([1, 2, 3], categories=[3, 2, 1], ordered=True) + + msg = "to union ordered Categoricals, all categories must be the same" + with pytest.raises(TypeError, match=msg): + union_categoricals([c1, c2]) + + def test_union_categoricals_ignore_order(self): + # GH 15219 + c1 = Categorical([1, 2, 3], ordered=True) + c2 = Categorical([1, 2, 3], ordered=False) + + res = union_categoricals([c1, c2], ignore_order=True) + exp = Categorical([1, 2, 3, 1, 2, 3]) + tm.assert_categorical_equal(res, exp) + + msg = "Categorical.ordered must be the same" + with pytest.raises(TypeError, match=msg): + union_categoricals([c1, c2], ignore_order=False) + + res = union_categoricals([c1, c1], ignore_order=True) + exp = Categorical([1, 2, 3, 1, 2, 3]) + tm.assert_categorical_equal(res, exp) + + res = union_categoricals([c1, c1], ignore_order=False) + exp = Categorical([1, 2, 3, 1, 2, 3], categories=[1, 2, 3], ordered=True) + tm.assert_categorical_equal(res, exp) + + c1 = Categorical([1, 2, 3, np.nan], ordered=True) + c2 = Categorical([3, 2], categories=[1, 2, 3], ordered=True) + + res = union_categoricals([c1, c2], ignore_order=True) + exp = Categorical([1, 2, 3, np.nan, 3, 2]) + tm.assert_categorical_equal(res, exp) + + c1 = Categorical([1, 2, 3], ordered=True) + c2 = Categorical([1, 2, 3], categories=[3, 2, 1], ordered=True) + + res = union_categoricals([c1, c2], ignore_order=True) + exp = Categorical([1, 2, 3, 1, 2, 3]) + tm.assert_categorical_equal(res, exp) + + res = union_categoricals([c2, c1], ignore_order=True, sort_categories=True) + exp = Categorical([1, 2, 3, 1, 2, 3], categories=[1, 2, 3]) + tm.assert_categorical_equal(res, exp) + + c1 = Categorical([1, 2, 3], ordered=True) + c2 = Categorical([4, 5, 6], ordered=True) + result = union_categoricals([c1, c2], ignore_order=True) + expected = Categorical([1, 2, 3, 4, 5, 6]) + tm.assert_categorical_equal(result, expected) + + msg = "to union ordered Categoricals, all categories must be the same" + with pytest.raises(TypeError, match=msg): + union_categoricals([c1, c2], ignore_order=False) + + with pytest.raises(TypeError, match=msg): + union_categoricals([c1, c2]) + + def test_union_categoricals_sort(self): + # GH 13846 + c1 = Categorical(["x", "y", "z"]) + c2 = Categorical(["a", "b", "c"]) + result = union_categoricals([c1, c2], sort_categories=True) + expected = Categorical( + ["x", "y", "z", "a", "b", "c"], categories=["a", "b", "c", "x", "y", "z"] + ) + tm.assert_categorical_equal(result, expected) + + # fastpath + c1 = Categorical(["a", "b"], categories=["b", "a", "c"]) + c2 = Categorical(["b", "c"], categories=["b", "a", "c"]) + result = union_categoricals([c1, c2], sort_categories=True) + expected = Categorical(["a", "b", "b", "c"], categories=["a", "b", "c"]) + tm.assert_categorical_equal(result, expected) + + c1 = Categorical(["a", "b"], categories=["c", "a", "b"]) + c2 = Categorical(["b", "c"], categories=["c", "a", "b"]) + result = union_categoricals([c1, c2], sort_categories=True) + expected = Categorical(["a", "b", "b", "c"], categories=["a", "b", "c"]) + tm.assert_categorical_equal(result, expected) + + # fastpath - skip resort + c1 = Categorical(["a", "b"], categories=["a", "b", "c"]) + c2 = Categorical(["b", "c"], categories=["a", "b", "c"]) + result = union_categoricals([c1, c2], sort_categories=True) + expected = Categorical(["a", "b", "b", "c"], categories=["a", "b", "c"]) + tm.assert_categorical_equal(result, expected) + + c1 = Categorical(["x", np.nan]) + c2 = Categorical([np.nan, "b"]) + result = union_categoricals([c1, c2], sort_categories=True) + expected = Categorical(["x", np.nan, np.nan, "b"], categories=["b", "x"]) + tm.assert_categorical_equal(result, expected) + + c1 = Categorical([np.nan]) + c2 = Categorical([np.nan]) + result = union_categoricals([c1, c2], sort_categories=True) + expected = Categorical([np.nan, np.nan]) + tm.assert_categorical_equal(result, expected) + + c1 = Categorical([]) + c2 = Categorical([]) + result = union_categoricals([c1, c2], sort_categories=True) + expected = Categorical([]) + tm.assert_categorical_equal(result, expected) + + c1 = Categorical(["b", "a"], categories=["b", "a", "c"], ordered=True) + c2 = Categorical(["a", "c"], categories=["b", "a", "c"], ordered=True) + msg = "Cannot use sort_categories=True with ordered Categoricals" + with pytest.raises(TypeError, match=msg): + union_categoricals([c1, c2], sort_categories=True) + + def test_union_categoricals_sort_false(self): + # GH 13846 + c1 = Categorical(["x", "y", "z"]) + c2 = Categorical(["a", "b", "c"]) + result = union_categoricals([c1, c2], sort_categories=False) + expected = Categorical( + ["x", "y", "z", "a", "b", "c"], categories=["x", "y", "z", "a", "b", "c"] + ) + tm.assert_categorical_equal(result, expected) + + def test_union_categoricals_sort_false_fastpath(self): + # fastpath + c1 = Categorical(["a", "b"], categories=["b", "a", "c"]) + c2 = Categorical(["b", "c"], categories=["b", "a", "c"]) + result = union_categoricals([c1, c2], sort_categories=False) + expected = Categorical(["a", "b", "b", "c"], categories=["b", "a", "c"]) + tm.assert_categorical_equal(result, expected) + + def test_union_categoricals_sort_false_skipresort(self): + # fastpath - skip resort + c1 = Categorical(["a", "b"], categories=["a", "b", "c"]) + c2 = Categorical(["b", "c"], categories=["a", "b", "c"]) + result = union_categoricals([c1, c2], sort_categories=False) + expected = Categorical(["a", "b", "b", "c"], categories=["a", "b", "c"]) + tm.assert_categorical_equal(result, expected) + + def test_union_categoricals_sort_false_one_nan(self): + c1 = Categorical(["x", np.nan]) + c2 = Categorical([np.nan, "b"]) + result = union_categoricals([c1, c2], sort_categories=False) + expected = Categorical(["x", np.nan, np.nan, "b"], categories=["x", "b"]) + tm.assert_categorical_equal(result, expected) + + def test_union_categoricals_sort_false_only_nan(self): + c1 = Categorical([np.nan]) + c2 = Categorical([np.nan]) + result = union_categoricals([c1, c2], sort_categories=False) + expected = Categorical([np.nan, np.nan]) + tm.assert_categorical_equal(result, expected) + + def test_union_categoricals_sort_false_empty(self): + c1 = Categorical([]) + c2 = Categorical([]) + result = union_categoricals([c1, c2], sort_categories=False) + expected = Categorical([]) + tm.assert_categorical_equal(result, expected) + + def test_union_categoricals_sort_false_ordered_true(self): + c1 = Categorical(["b", "a"], categories=["b", "a", "c"], ordered=True) + c2 = Categorical(["a", "c"], categories=["b", "a", "c"], ordered=True) + result = union_categoricals([c1, c2], sort_categories=False) + expected = Categorical( + ["b", "a", "a", "c"], categories=["b", "a", "c"], ordered=True + ) + tm.assert_categorical_equal(result, expected) + + def test_union_categorical_unwrap(self): + # GH 14173 + c1 = Categorical(["a", "b"]) + c2 = Series(["b", "c"], dtype="category") + result = union_categoricals([c1, c2]) + expected = Categorical(["a", "b", "b", "c"]) + tm.assert_categorical_equal(result, expected) + + c2 = CategoricalIndex(c2) + result = union_categoricals([c1, c2]) + tm.assert_categorical_equal(result, expected) + + c1 = Series(c1) + result = union_categoricals([c1, c2]) + tm.assert_categorical_equal(result, expected) + + msg = "all components to combine must be Categorical" + with pytest.raises(TypeError, match=msg): + union_categoricals([c1, ["a", "b", "c"]]) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/test_util.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/test_util.py new file mode 100644 index 0000000000000000000000000000000000000000..4d0be7464cb3d97697323faef5b4e7cd0d9b6df0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/reshape/test_util.py @@ -0,0 +1,79 @@ +import numpy as np +import pytest + +from pandas import ( + Index, + date_range, +) +import pandas._testing as tm +from pandas.core.reshape.util import cartesian_product + + +class TestCartesianProduct: + def test_simple(self): + x, y = list("ABC"), [1, 22] + result1, result2 = cartesian_product([x, y]) + expected1 = np.array(["A", "A", "B", "B", "C", "C"]) + expected2 = np.array([1, 22, 1, 22, 1, 22]) + tm.assert_numpy_array_equal(result1, expected1) + tm.assert_numpy_array_equal(result2, expected2) + + def test_datetimeindex(self): + # regression test for GitHub issue #6439 + # make sure that the ordering on datetimeindex is consistent + x = date_range("2000-01-01", periods=2) + result1, result2 = (Index(y).day for y in cartesian_product([x, x])) + expected1 = Index([1, 1, 2, 2], dtype=np.int32) + expected2 = Index([1, 2, 1, 2], dtype=np.int32) + tm.assert_index_equal(result1, expected1) + tm.assert_index_equal(result2, expected2) + + def test_tzaware_retained(self): + x = date_range("2000-01-01", periods=2, tz="US/Pacific") + y = np.array([3, 4]) + result1, result2 = cartesian_product([x, y]) + + expected = x.repeat(2) + tm.assert_index_equal(result1, expected) + + def test_tzaware_retained_categorical(self): + x = date_range("2000-01-01", periods=2, tz="US/Pacific").astype("category") + y = np.array([3, 4]) + result1, result2 = cartesian_product([x, y]) + + expected = x.repeat(2) + tm.assert_index_equal(result1, expected) + + @pytest.mark.parametrize("x, y", [[[], []], [[0, 1], []], [[], ["a", "b", "c"]]]) + def test_empty(self, x, y): + # product of empty factors + expected1 = np.array([], dtype=np.asarray(x).dtype) + expected2 = np.array([], dtype=np.asarray(y).dtype) + result1, result2 = cartesian_product([x, y]) + tm.assert_numpy_array_equal(result1, expected1) + tm.assert_numpy_array_equal(result2, expected2) + + def test_empty_input(self): + # empty product (empty input): + result = cartesian_product([]) + expected = [] + assert result == expected + + @pytest.mark.parametrize( + "X", [1, [1], [1, 2], [[1], 2], "a", ["a"], ["a", "b"], [["a"], "b"]] + ) + def test_invalid_input(self, X): + msg = "Input must be a list-like of list-likes" + + with pytest.raises(TypeError, match=msg): + cartesian_product(X=X) + + def test_exceed_product_space(self): + # GH31355: raise useful error when produce space is too large + msg = "Product space too large to allocate arrays!" + + with pytest.raises(ValueError, match=msg): + dims = [np.arange(0, 22, dtype=np.int16) for i in range(12)] + [ + (np.arange(15128, dtype=np.int16)), + ] + cartesian_product(X=dims) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/scalar/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/scalar/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/scalar/test_na_scalar.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/scalar/test_na_scalar.py new file mode 100644 index 0000000000000000000000000000000000000000..287b7557f50f9f6a81763f86d0eb616cfd730f8c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/scalar/test_na_scalar.py @@ -0,0 +1,316 @@ +from datetime import ( + date, + time, + timedelta, +) +import pickle + +import numpy as np +import pytest + +from pandas._libs.missing import NA + +from pandas.core.dtypes.common import is_scalar + +import pandas as pd +import pandas._testing as tm + + +def test_singleton(): + assert NA is NA + new_NA = type(NA)() + assert new_NA is NA + + +def test_repr(): + assert repr(NA) == "" + assert str(NA) == "" + + +def test_format(): + # GH-34740 + assert format(NA) == "" + assert format(NA, ">10") == " " + assert format(NA, "xxx") == "" # NA is flexible, accept any format spec + + assert f"{NA}" == "" + assert f"{NA:>10}" == " " + assert f"{NA:xxx}" == "" + + +def test_truthiness(): + msg = "boolean value of NA is ambiguous" + + with pytest.raises(TypeError, match=msg): + bool(NA) + + with pytest.raises(TypeError, match=msg): + not NA + + +def test_hashable(): + assert hash(NA) == hash(NA) + d = {NA: "test"} + assert d[NA] == "test" + + +@pytest.mark.parametrize( + "other", [NA, 1, 1.0, "a", b"a", np.int64(1), np.nan], ids=repr +) +def test_arithmetic_ops(all_arithmetic_functions, other): + op = all_arithmetic_functions + + if op.__name__ in ("pow", "rpow", "rmod") and isinstance(other, (str, bytes)): + pytest.skip(reason=f"{op.__name__} with NA and {other} not defined.") + if op.__name__ in ("divmod", "rdivmod"): + assert op(NA, other) is (NA, NA) + else: + if op.__name__ == "rpow": + # avoid special case + other += 1 + assert op(NA, other) is NA + + +@pytest.mark.parametrize( + "other", + [ + NA, + 1, + 1.0, + "a", + b"a", + np.int64(1), + np.nan, + np.bool_(True), + time(0), + date(1, 2, 3), + timedelta(1), + pd.NaT, + ], +) +def test_comparison_ops(comparison_op, other): + assert comparison_op(NA, other) is NA + assert comparison_op(other, NA) is NA + + +@pytest.mark.parametrize( + "value", + [ + 0, + 0.0, + -0, + -0.0, + False, + np.bool_(False), + np.int_(0), + np.float64(0), + np.int_(-0), + np.float64(-0), + ], +) +@pytest.mark.parametrize("asarray", [True, False]) +def test_pow_special(value, asarray): + if asarray: + value = np.array([value]) + result = NA**value + + if asarray: + result = result[0] + else: + # this assertion isn't possible for ndarray. + assert isinstance(result, type(value)) + assert result == 1 + + +@pytest.mark.parametrize( + "value", [1, 1.0, True, np.bool_(True), np.int_(1), np.float64(1)] +) +@pytest.mark.parametrize("asarray", [True, False]) +def test_rpow_special(value, asarray): + if asarray: + value = np.array([value]) + result = value**NA + + if asarray: + result = result[0] + elif not isinstance(value, (np.float64, np.bool_, np.int_)): + # this assertion isn't possible with asarray=True + assert isinstance(result, type(value)) + + assert result == value + + +@pytest.mark.parametrize("value", [-1, -1.0, np.int_(-1), np.float64(-1)]) +@pytest.mark.parametrize("asarray", [True, False]) +def test_rpow_minus_one(value, asarray): + if asarray: + value = np.array([value]) + result = value**NA + + if asarray: + result = result[0] + + assert pd.isna(result) + + +def test_unary_ops(): + assert +NA is NA + assert -NA is NA + assert abs(NA) is NA + assert ~NA is NA + + +def test_logical_and(): + assert NA & True is NA + assert True & NA is NA + assert NA & False is False + assert False & NA is False + assert NA & NA is NA + + msg = "unsupported operand type" + with pytest.raises(TypeError, match=msg): + NA & 5 + + +def test_logical_or(): + assert NA | True is True + assert True | NA is True + assert NA | False is NA + assert False | NA is NA + assert NA | NA is NA + + msg = "unsupported operand type" + with pytest.raises(TypeError, match=msg): + NA | 5 + + +def test_logical_xor(): + assert NA ^ True is NA + assert True ^ NA is NA + assert NA ^ False is NA + assert False ^ NA is NA + assert NA ^ NA is NA + + msg = "unsupported operand type" + with pytest.raises(TypeError, match=msg): + NA ^ 5 + + +def test_logical_not(): + assert ~NA is NA + + +@pytest.mark.parametrize("shape", [(3,), (3, 3), (1, 2, 3)]) +def test_arithmetic_ndarray(shape, all_arithmetic_functions): + op = all_arithmetic_functions + a = np.zeros(shape) + if op.__name__ == "pow": + a += 5 + result = op(NA, a) + expected = np.full(a.shape, NA, dtype=object) + tm.assert_numpy_array_equal(result, expected) + + +def test_is_scalar(): + assert is_scalar(NA) is True + + +def test_isna(): + assert pd.isna(NA) is True + assert pd.notna(NA) is False + + +def test_series_isna(): + s = pd.Series([1, NA], dtype=object) + expected = pd.Series([False, True]) + tm.assert_series_equal(s.isna(), expected) + + +def test_ufunc(): + assert np.log(NA) is NA + assert np.add(NA, 1) is NA + result = np.divmod(NA, 1) + assert result[0] is NA and result[1] is NA + + result = np.frexp(NA) + assert result[0] is NA and result[1] is NA + + +def test_ufunc_raises(): + msg = "ufunc method 'at'" + with pytest.raises(ValueError, match=msg): + np.log.at(NA, 0) + + +def test_binary_input_not_dunder(): + a = np.array([1, 2, 3]) + expected = np.array([NA, NA, NA], dtype=object) + result = np.logaddexp(a, NA) + tm.assert_numpy_array_equal(result, expected) + + result = np.logaddexp(NA, a) + tm.assert_numpy_array_equal(result, expected) + + # all NA, multiple inputs + assert np.logaddexp(NA, NA) is NA + + result = np.modf(NA, NA) + assert len(result) == 2 + assert all(x is NA for x in result) + + +def test_divmod_ufunc(): + # binary in, binary out. + a = np.array([1, 2, 3]) + expected = np.array([NA, NA, NA], dtype=object) + + result = np.divmod(a, NA) + assert isinstance(result, tuple) + for arr in result: + tm.assert_numpy_array_equal(arr, expected) + tm.assert_numpy_array_equal(arr, expected) + + result = np.divmod(NA, a) + for arr in result: + tm.assert_numpy_array_equal(arr, expected) + tm.assert_numpy_array_equal(arr, expected) + + +def test_integer_hash_collision_dict(): + # GH 30013 + result = {NA: "foo", hash(NA): "bar"} + + assert result[NA] == "foo" + assert result[hash(NA)] == "bar" + + +def test_integer_hash_collision_set(): + # GH 30013 + result = {NA, hash(NA)} + + assert len(result) == 2 + assert NA in result + assert hash(NA) in result + + +def test_pickle_roundtrip(): + # https://github.com/pandas-dev/pandas/issues/31847 + result = pickle.loads(pickle.dumps(NA)) + assert result is NA + + +def test_pickle_roundtrip_pandas(): + result = tm.round_trip_pickle(NA) + assert result is NA + + +@pytest.mark.parametrize( + "values, dtype", [([1, 2, NA], "Int64"), (["A", "B", NA], "string")] +) +@pytest.mark.parametrize("as_frame", [True, False]) +def test_pickle_roundtrip_containers(as_frame, values, dtype): + s = pd.Series(pd.array(values, dtype=dtype)) + if as_frame: + s = s.to_frame(name="A") + result = tm.round_trip_pickle(s) + tm.assert_equal(result, s) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/scalar/test_nat.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/scalar/test_nat.py new file mode 100644 index 0000000000000000000000000000000000000000..cb046e0133245689f75c2def996bfabfb18e6185 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/scalar/test_nat.py @@ -0,0 +1,709 @@ +from datetime import ( + datetime, + timedelta, +) +import operator + +import numpy as np +import pytest +import pytz + +from pandas._libs.tslibs import iNaT +from pandas.compat.numpy import np_version_gte1p24p3 + +from pandas import ( + DatetimeIndex, + DatetimeTZDtype, + Index, + NaT, + Period, + Series, + Timedelta, + TimedeltaIndex, + Timestamp, + isna, + offsets, +) +import pandas._testing as tm +from pandas.core import roperator +from pandas.core.arrays import ( + DatetimeArray, + PeriodArray, + TimedeltaArray, +) + + +class TestNaTFormatting: + def test_repr(self): + assert repr(NaT) == "NaT" + + def test_str(self): + assert str(NaT) == "NaT" + + def test_isoformat(self): + assert NaT.isoformat() == "NaT" + + +@pytest.mark.parametrize( + "nat,idx", + [ + (Timestamp("NaT"), DatetimeArray), + (Timedelta("NaT"), TimedeltaArray), + (Period("NaT", freq="M"), PeriodArray), + ], +) +def test_nat_fields(nat, idx): + for field in idx._field_ops: + # weekday is a property of DTI, but a method + # on NaT/Timestamp for compat with datetime + if field == "weekday": + continue + + result = getattr(NaT, field) + assert np.isnan(result) + + result = getattr(nat, field) + assert np.isnan(result) + + for field in idx._bool_ops: + result = getattr(NaT, field) + assert result is False + + result = getattr(nat, field) + assert result is False + + +def test_nat_vector_field_access(): + idx = DatetimeIndex(["1/1/2000", None, None, "1/4/2000"]) + + for field in DatetimeArray._field_ops: + # weekday is a property of DTI, but a method + # on NaT/Timestamp for compat with datetime + if field == "weekday": + continue + + result = getattr(idx, field) + expected = Index([getattr(x, field) for x in idx]) + tm.assert_index_equal(result, expected) + + ser = Series(idx) + + for field in DatetimeArray._field_ops: + # weekday is a property of DTI, but a method + # on NaT/Timestamp for compat with datetime + if field == "weekday": + continue + + result = getattr(ser.dt, field) + expected = [getattr(x, field) for x in idx] + tm.assert_series_equal(result, Series(expected)) + + for field in DatetimeArray._bool_ops: + result = getattr(ser.dt, field) + expected = [getattr(x, field) for x in idx] + tm.assert_series_equal(result, Series(expected)) + + +@pytest.mark.parametrize("klass", [Timestamp, Timedelta, Period]) +@pytest.mark.parametrize( + "value", [None, np.nan, iNaT, float("nan"), NaT, "NaT", "nat", "", "NAT"] +) +def test_identity(klass, value): + assert klass(value) is NaT + + +@pytest.mark.parametrize("klass", [Timestamp, Timedelta]) +@pytest.mark.parametrize("method", ["round", "floor", "ceil"]) +@pytest.mark.parametrize("freq", ["s", "5s", "min", "5min", "h", "5h"]) +def test_round_nat(klass, method, freq): + # see gh-14940 + ts = klass("nat") + + round_method = getattr(ts, method) + assert round_method(freq) is ts + + +@pytest.mark.parametrize( + "method", + [ + "astimezone", + "combine", + "ctime", + "dst", + "fromordinal", + "fromtimestamp", + "fromisocalendar", + "isocalendar", + "strftime", + "strptime", + "time", + "timestamp", + "timetuple", + "timetz", + "toordinal", + "tzname", + "utcfromtimestamp", + "utcnow", + "utcoffset", + "utctimetuple", + "timestamp", + ], +) +def test_nat_methods_raise(method): + # see gh-9513, gh-17329 + msg = f"NaTType does not support {method}" + + with pytest.raises(ValueError, match=msg): + getattr(NaT, method)() + + +@pytest.mark.parametrize("method", ["weekday", "isoweekday"]) +def test_nat_methods_nan(method): + # see gh-9513, gh-17329 + assert np.isnan(getattr(NaT, method)()) + + +@pytest.mark.parametrize( + "method", ["date", "now", "replace", "today", "tz_convert", "tz_localize"] +) +def test_nat_methods_nat(method): + # see gh-8254, gh-9513, gh-17329 + assert getattr(NaT, method)() is NaT + + +@pytest.mark.parametrize( + "get_nat", [lambda x: NaT, lambda x: Timedelta(x), lambda x: Timestamp(x)] +) +def test_nat_iso_format(get_nat): + # see gh-12300 + assert get_nat("NaT").isoformat() == "NaT" + assert get_nat("NaT").isoformat(timespec="nanoseconds") == "NaT" + + +@pytest.mark.parametrize( + "klass,expected", + [ + (Timestamp, ["normalize", "to_julian_date", "to_period", "unit"]), + ( + Timedelta, + [ + "components", + "resolution_string", + "to_pytimedelta", + "to_timedelta64", + "unit", + "view", + ], + ), + ], +) +def test_missing_public_nat_methods(klass, expected): + # see gh-17327 + # + # NaT should have *most* of the Timestamp and Timedelta methods. + # Here, we check which public methods NaT does not have. We + # ignore any missing private methods. + nat_names = dir(NaT) + klass_names = dir(klass) + + missing = [x for x in klass_names if x not in nat_names and not x.startswith("_")] + missing.sort() + + assert missing == expected + + +def _get_overlap_public_nat_methods(klass, as_tuple=False): + """ + Get overlapping public methods between NaT and another class. + + Parameters + ---------- + klass : type + The class to compare with NaT + as_tuple : bool, default False + Whether to return a list of tuples of the form (klass, method). + + Returns + ------- + overlap : list + """ + nat_names = dir(NaT) + klass_names = dir(klass) + + overlap = [ + x + for x in nat_names + if x in klass_names and not x.startswith("_") and callable(getattr(klass, x)) + ] + + # Timestamp takes precedence over Timedelta in terms of overlap. + if klass is Timedelta: + ts_names = dir(Timestamp) + overlap = [x for x in overlap if x not in ts_names] + + if as_tuple: + overlap = [(klass, method) for method in overlap] + + overlap.sort() + return overlap + + +@pytest.mark.parametrize( + "klass,expected", + [ + ( + Timestamp, + [ + "as_unit", + "astimezone", + "ceil", + "combine", + "ctime", + "date", + "day_name", + "dst", + "floor", + "fromisocalendar", + "fromisoformat", + "fromordinal", + "fromtimestamp", + "isocalendar", + "isoformat", + "isoweekday", + "month_name", + "now", + "replace", + "round", + "strftime", + "strptime", + "time", + "timestamp", + "timetuple", + "timetz", + "to_datetime64", + "to_numpy", + "to_pydatetime", + "today", + "toordinal", + "tz_convert", + "tz_localize", + "tzname", + "utcfromtimestamp", + "utcnow", + "utcoffset", + "utctimetuple", + "weekday", + ], + ), + (Timedelta, ["total_seconds"]), + ], +) +def test_overlap_public_nat_methods(klass, expected): + # see gh-17327 + # + # NaT should have *most* of the Timestamp and Timedelta methods. + # In case when Timestamp, Timedelta, and NaT are overlap, the overlap + # is considered to be with Timestamp and NaT, not Timedelta. + assert _get_overlap_public_nat_methods(klass) == expected + + +@pytest.mark.parametrize( + "compare", + ( + _get_overlap_public_nat_methods(Timestamp, True) + + _get_overlap_public_nat_methods(Timedelta, True) + ), + ids=lambda x: f"{x[0].__name__}.{x[1]}", +) +def test_nat_doc_strings(compare): + # see gh-17327 + # + # The docstrings for overlapping methods should match. + klass, method = compare + klass_doc = getattr(klass, method).__doc__ + + if klass == Timestamp and method == "isoformat": + pytest.skip( + "Ignore differences with Timestamp.isoformat() as they're intentional" + ) + + if method == "to_numpy": + # GH#44460 can return either dt64 or td64 depending on dtype, + # different docstring is intentional + pytest.skip(f"different docstring for {method} is intentional") + + nat_doc = getattr(NaT, method).__doc__ + assert klass_doc == nat_doc + + +_ops = { + "left_plus_right": lambda a, b: a + b, + "right_plus_left": lambda a, b: b + a, + "left_minus_right": lambda a, b: a - b, + "right_minus_left": lambda a, b: b - a, + "left_times_right": lambda a, b: a * b, + "right_times_left": lambda a, b: b * a, + "left_div_right": lambda a, b: a / b, + "right_div_left": lambda a, b: b / a, +} + + +@pytest.mark.parametrize("op_name", list(_ops.keys())) +@pytest.mark.parametrize( + "value,val_type", + [ + (2, "scalar"), + (1.5, "floating"), + (np.nan, "floating"), + ("foo", "str"), + (timedelta(3600), "timedelta"), + (Timedelta("5s"), "timedelta"), + (datetime(2014, 1, 1), "timestamp"), + (Timestamp("2014-01-01"), "timestamp"), + (Timestamp("2014-01-01", tz="UTC"), "timestamp"), + (Timestamp("2014-01-01", tz="US/Eastern"), "timestamp"), + (pytz.timezone("Asia/Tokyo").localize(datetime(2014, 1, 1)), "timestamp"), + ], +) +def test_nat_arithmetic_scalar(op_name, value, val_type): + # see gh-6873 + invalid_ops = { + "scalar": {"right_div_left"}, + "floating": { + "right_div_left", + "left_minus_right", + "right_minus_left", + "left_plus_right", + "right_plus_left", + }, + "str": set(_ops.keys()), + "timedelta": {"left_times_right", "right_times_left"}, + "timestamp": { + "left_times_right", + "right_times_left", + "left_div_right", + "right_div_left", + }, + } + + op = _ops[op_name] + + if op_name in invalid_ops.get(val_type, set()): + if ( + val_type == "timedelta" + and "times" in op_name + and isinstance(value, Timedelta) + ): + typs = "(Timedelta|NaTType)" + msg = rf"unsupported operand type\(s\) for \*: '{typs}' and '{typs}'" + elif val_type == "str": + # un-specific check here because the message comes from str + # and varies by method + msg = "|".join( + [ + "can only concatenate str", + "unsupported operand type", + "can't multiply sequence", + "Can't convert 'NaTType'", + "must be str, not NaTType", + ] + ) + else: + msg = "unsupported operand type" + + with pytest.raises(TypeError, match=msg): + op(NaT, value) + else: + if val_type == "timedelta" and "div" in op_name: + expected = np.nan + else: + expected = NaT + + assert op(NaT, value) is expected + + +@pytest.mark.parametrize( + "val,expected", [(np.nan, NaT), (NaT, np.nan), (np.timedelta64("NaT"), np.nan)] +) +def test_nat_rfloordiv_timedelta(val, expected): + # see gh-#18846 + # + # See also test_timedelta.TestTimedeltaArithmetic.test_floordiv + td = Timedelta(hours=3, minutes=4) + assert td // val is expected + + +@pytest.mark.parametrize( + "op_name", + ["left_plus_right", "right_plus_left", "left_minus_right", "right_minus_left"], +) +@pytest.mark.parametrize( + "value", + [ + DatetimeIndex(["2011-01-01", "2011-01-02"], name="x"), + DatetimeIndex(["2011-01-01", "2011-01-02"], tz="US/Eastern", name="x"), + DatetimeArray._from_sequence(["2011-01-01", "2011-01-02"], dtype="M8[ns]"), + DatetimeArray._from_sequence( + ["2011-01-01", "2011-01-02"], dtype=DatetimeTZDtype(tz="US/Pacific") + ), + TimedeltaIndex(["1 day", "2 day"], name="x"), + ], +) +def test_nat_arithmetic_index(op_name, value): + # see gh-11718 + exp_name = "x" + exp_data = [NaT] * 2 + + if value.dtype.kind == "M" and "plus" in op_name: + expected = DatetimeIndex(exp_data, tz=value.tz, name=exp_name) + else: + expected = TimedeltaIndex(exp_data, name=exp_name) + expected = expected.as_unit(value.unit) + + if not isinstance(value, Index): + expected = expected.array + + op = _ops[op_name] + result = op(NaT, value) + tm.assert_equal(result, expected) + + +@pytest.mark.parametrize( + "op_name", + ["left_plus_right", "right_plus_left", "left_minus_right", "right_minus_left"], +) +@pytest.mark.parametrize("box", [TimedeltaIndex, Series, TimedeltaArray._from_sequence]) +def test_nat_arithmetic_td64_vector(op_name, box): + # see gh-19124 + vec = box(["1 day", "2 day"], dtype="timedelta64[ns]") + box_nat = box([NaT, NaT], dtype="timedelta64[ns]") + tm.assert_equal(_ops[op_name](vec, NaT), box_nat) + + +@pytest.mark.parametrize( + "dtype,op,out_dtype", + [ + ("datetime64[ns]", operator.add, "datetime64[ns]"), + ("datetime64[ns]", roperator.radd, "datetime64[ns]"), + ("datetime64[ns]", operator.sub, "timedelta64[ns]"), + ("datetime64[ns]", roperator.rsub, "timedelta64[ns]"), + ("timedelta64[ns]", operator.add, "datetime64[ns]"), + ("timedelta64[ns]", roperator.radd, "datetime64[ns]"), + ("timedelta64[ns]", operator.sub, "datetime64[ns]"), + ("timedelta64[ns]", roperator.rsub, "timedelta64[ns]"), + ], +) +def test_nat_arithmetic_ndarray(dtype, op, out_dtype): + other = np.arange(10).astype(dtype) + result = op(NaT, other) + + expected = np.empty(other.shape, dtype=out_dtype) + expected.fill("NaT") + tm.assert_numpy_array_equal(result, expected) + + +def test_nat_pinned_docstrings(): + # see gh-17327 + assert NaT.ctime.__doc__ == Timestamp.ctime.__doc__ + + +def test_to_numpy_alias(): + # GH 24653: alias .to_numpy() for scalars + expected = NaT.to_datetime64() + result = NaT.to_numpy() + + assert isna(expected) and isna(result) + + # GH#44460 + result = NaT.to_numpy("M8[s]") + assert isinstance(result, np.datetime64) + assert result.dtype == "M8[s]" + + result = NaT.to_numpy("m8[ns]") + assert isinstance(result, np.timedelta64) + assert result.dtype == "m8[ns]" + + result = NaT.to_numpy("m8[s]") + assert isinstance(result, np.timedelta64) + assert result.dtype == "m8[s]" + + with pytest.raises(ValueError, match="NaT.to_numpy dtype must be a "): + NaT.to_numpy(np.int64) + + +@pytest.mark.parametrize( + "other", + [ + Timedelta(0), + Timedelta(0).to_pytimedelta(), + pytest.param( + Timedelta(0).to_timedelta64(), + marks=pytest.mark.xfail( + not np_version_gte1p24p3, + reason="td64 doesn't return NotImplemented, see numpy#17017", + # When this xfail is fixed, test_nat_comparisons_numpy + # can be removed. + ), + ), + Timestamp(0), + Timestamp(0).to_pydatetime(), + pytest.param( + Timestamp(0).to_datetime64(), + marks=pytest.mark.xfail( + not np_version_gte1p24p3, + reason="dt64 doesn't return NotImplemented, see numpy#17017", + ), + ), + Timestamp(0).tz_localize("UTC"), + NaT, + ], +) +def test_nat_comparisons(compare_operators_no_eq_ne, other): + # GH 26039 + opname = compare_operators_no_eq_ne + + assert getattr(NaT, opname)(other) is False + + op = getattr(operator, opname.strip("_")) + assert op(NaT, other) is False + assert op(other, NaT) is False + + +@pytest.mark.parametrize("other", [np.timedelta64(0, "ns"), np.datetime64("now", "ns")]) +def test_nat_comparisons_numpy(other): + # Once numpy#17017 is fixed and the xfailed cases in test_nat_comparisons + # pass, this test can be removed + assert not NaT == other + assert NaT != other + assert not NaT < other + assert not NaT > other + assert not NaT <= other + assert not NaT >= other + + +@pytest.mark.parametrize("other_and_type", [("foo", "str"), (2, "int"), (2.0, "float")]) +@pytest.mark.parametrize( + "symbol_and_op", + [("<=", operator.le), ("<", operator.lt), (">=", operator.ge), (">", operator.gt)], +) +def test_nat_comparisons_invalid(other_and_type, symbol_and_op): + # GH#35585 + other, other_type = other_and_type + symbol, op = symbol_and_op + + assert not NaT == other + assert not other == NaT + + assert NaT != other + assert other != NaT + + msg = f"'{symbol}' not supported between instances of 'NaTType' and '{other_type}'" + with pytest.raises(TypeError, match=msg): + op(NaT, other) + + msg = f"'{symbol}' not supported between instances of '{other_type}' and 'NaTType'" + with pytest.raises(TypeError, match=msg): + op(other, NaT) + + +@pytest.mark.parametrize( + "other", + [ + np.array(["foo"] * 2, dtype=object), + np.array([2, 3], dtype="int64"), + np.array([2.0, 3.5], dtype="float64"), + ], + ids=["str", "int", "float"], +) +def test_nat_comparisons_invalid_ndarray(other): + # GH#40722 + expected = np.array([False, False]) + result = NaT == other + tm.assert_numpy_array_equal(result, expected) + result = other == NaT + tm.assert_numpy_array_equal(result, expected) + + expected = np.array([True, True]) + result = NaT != other + tm.assert_numpy_array_equal(result, expected) + result = other != NaT + tm.assert_numpy_array_equal(result, expected) + + for symbol, op in [ + ("<=", operator.le), + ("<", operator.lt), + (">=", operator.ge), + (">", operator.gt), + ]: + msg = f"'{symbol}' not supported between" + + with pytest.raises(TypeError, match=msg): + op(NaT, other) + + if other.dtype == np.dtype("object"): + # uses the reverse operator, so symbol changes + msg = None + with pytest.raises(TypeError, match=msg): + op(other, NaT) + + +def test_compare_date(fixed_now_ts): + # GH#39151 comparing NaT with date object is deprecated + # See also: tests.scalar.timestamps.test_comparisons::test_compare_date + + dt = fixed_now_ts.to_pydatetime().date() + + msg = "Cannot compare NaT with datetime.date object" + for left, right in [(NaT, dt), (dt, NaT)]: + assert not left == right + assert left != right + + with pytest.raises(TypeError, match=msg): + left < right + with pytest.raises(TypeError, match=msg): + left <= right + with pytest.raises(TypeError, match=msg): + left > right + with pytest.raises(TypeError, match=msg): + left >= right + + +@pytest.mark.parametrize( + "obj", + [ + offsets.YearEnd(2), + offsets.YearBegin(2), + offsets.MonthBegin(1), + offsets.MonthEnd(2), + offsets.MonthEnd(12), + offsets.Day(2), + offsets.Day(5), + offsets.Hour(24), + offsets.Hour(3), + offsets.Minute(), + np.timedelta64(3, "h"), + np.timedelta64(4, "h"), + np.timedelta64(3200, "s"), + np.timedelta64(3600, "s"), + np.timedelta64(3600 * 24, "s"), + np.timedelta64(2, "D"), + np.timedelta64(365, "D"), + timedelta(-2), + timedelta(365), + timedelta(minutes=120), + timedelta(days=4, minutes=180), + timedelta(hours=23), + timedelta(hours=23, minutes=30), + timedelta(hours=48), + ], +) +def test_nat_addsub_tdlike_scalar(obj): + assert NaT + obj is NaT + assert obj + NaT is NaT + assert NaT - obj is NaT + + +def test_pickle(): + # GH#4606 + p = tm.round_trip_pickle(NaT) + assert p is NaT diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/strings/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/strings/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..01b49b5e5b63323b065ec11fc34f6c247a7b0350 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/strings/__init__.py @@ -0,0 +1,15 @@ +import numpy as np + +import pandas as pd + +object_pyarrow_numpy = ("object", "string[pyarrow_numpy]") + + +def _convert_na_value(ser, expected): + if ser.dtype != object: + if ser.dtype.storage == "pyarrow_numpy": + expected = expected.fillna(np.nan) + else: + # GH#18463 + expected = expected.fillna(pd.NA) + return expected diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/strings/conftest.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/strings/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..036e4de20ba538bc4dbe6636fc802fb9c8f10e5d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/strings/conftest.py @@ -0,0 +1,132 @@ +import pytest + +from pandas import Series +from pandas.core.strings.accessor import StringMethods + +_any_string_method = [ + ("cat", (), {"sep": ","}), + ("cat", (Series(list("zyx")),), {"sep": ",", "join": "left"}), + ("center", (10,), {}), + ("contains", ("a",), {}), + ("count", ("a",), {}), + ("decode", ("UTF-8",), {}), + ("encode", ("UTF-8",), {}), + ("endswith", ("a",), {}), + ("endswith", ((),), {}), + ("endswith", (("a",),), {}), + ("endswith", (("a", "b"),), {}), + ("endswith", (("a", "MISSING"),), {}), + ("endswith", ("a",), {"na": True}), + ("endswith", ("a",), {"na": False}), + ("extract", ("([a-z]*)",), {"expand": False}), + ("extract", ("([a-z]*)",), {"expand": True}), + ("extractall", ("([a-z]*)",), {}), + ("find", ("a",), {}), + ("findall", ("a",), {}), + ("get", (0,), {}), + # because "index" (and "rindex") fail intentionally + # if the string is not found, search only for empty string + ("index", ("",), {}), + ("join", (",",), {}), + ("ljust", (10,), {}), + ("match", ("a",), {}), + ("fullmatch", ("a",), {}), + ("normalize", ("NFC",), {}), + ("pad", (10,), {}), + ("partition", (" ",), {"expand": False}), + ("partition", (" ",), {"expand": True}), + ("repeat", (3,), {}), + ("replace", ("a", "z"), {}), + ("rfind", ("a",), {}), + ("rindex", ("",), {}), + ("rjust", (10,), {}), + ("rpartition", (" ",), {"expand": False}), + ("rpartition", (" ",), {"expand": True}), + ("slice", (0, 1), {}), + ("slice_replace", (0, 1, "z"), {}), + ("split", (" ",), {"expand": False}), + ("split", (" ",), {"expand": True}), + ("startswith", ("a",), {}), + ("startswith", (("a",),), {}), + ("startswith", (("a", "b"),), {}), + ("startswith", (("a", "MISSING"),), {}), + ("startswith", ((),), {}), + ("startswith", ("a",), {"na": True}), + ("startswith", ("a",), {"na": False}), + ("removeprefix", ("a",), {}), + ("removesuffix", ("a",), {}), + # translating unicode points of "a" to "d" + ("translate", ({97: 100},), {}), + ("wrap", (2,), {}), + ("zfill", (10,), {}), +] + list( + zip( + [ + # methods without positional arguments: zip with empty tuple and empty dict + "capitalize", + "cat", + "get_dummies", + "isalnum", + "isalpha", + "isdecimal", + "isdigit", + "islower", + "isnumeric", + "isspace", + "istitle", + "isupper", + "len", + "lower", + "lstrip", + "partition", + "rpartition", + "rsplit", + "rstrip", + "slice", + "slice_replace", + "split", + "strip", + "swapcase", + "title", + "upper", + "casefold", + ], + [()] * 100, + [{}] * 100, + ) +) +ids, _, _ = zip(*_any_string_method) # use method name as fixture-id +missing_methods = {f for f in dir(StringMethods) if not f.startswith("_")} - set(ids) + +# test that the above list captures all methods of StringMethods +assert not missing_methods + + +@pytest.fixture(params=_any_string_method, ids=ids) +def any_string_method(request): + """ + Fixture for all public methods of `StringMethods` + + This fixture returns a tuple of the method name and sample arguments + necessary to call the method. + + Returns + ------- + method_name : str + The name of the method in `StringMethods` + args : tuple + Sample values for the positional arguments + kwargs : dict + Sample values for the keyword arguments + + Examples + -------- + >>> def test_something(any_string_method): + ... s = Series(['a', 'b', np.nan, 'd']) + ... + ... method_name, args, kwargs = any_string_method + ... method = getattr(s.str, method_name) + ... # will not raise + ... method(*args, **kwargs) + """ + return request.param diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/strings/test_api.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/strings/test_api.py new file mode 100644 index 0000000000000000000000000000000000000000..31e005466af7b935c446e01f90ed87bb4736b84b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/strings/test_api.py @@ -0,0 +1,198 @@ +import numpy as np +import pytest + +from pandas import ( + CategoricalDtype, + DataFrame, + Index, + MultiIndex, + Series, + _testing as tm, + option_context, +) +from pandas.core.strings.accessor import StringMethods + +# subset of the full set from pandas/conftest.py +_any_allowed_skipna_inferred_dtype = [ + ("string", ["a", np.nan, "c"]), + ("bytes", [b"a", np.nan, b"c"]), + ("empty", [np.nan, np.nan, np.nan]), + ("empty", []), + ("mixed-integer", ["a", np.nan, 2]), +] +ids, _ = zip(*_any_allowed_skipna_inferred_dtype) # use inferred type as id + + +@pytest.fixture(params=_any_allowed_skipna_inferred_dtype, ids=ids) +def any_allowed_skipna_inferred_dtype(request): + """ + Fixture for all (inferred) dtypes allowed in StringMethods.__init__ + + The covered (inferred) types are: + * 'string' + * 'empty' + * 'bytes' + * 'mixed' + * 'mixed-integer' + + Returns + ------- + inferred_dtype : str + The string for the inferred dtype from _libs.lib.infer_dtype + values : np.ndarray + An array of object dtype that will be inferred to have + `inferred_dtype` + + Examples + -------- + >>> from pandas._libs import lib + >>> + >>> def test_something(any_allowed_skipna_inferred_dtype): + ... inferred_dtype, values = any_allowed_skipna_inferred_dtype + ... # will pass + ... assert lib.infer_dtype(values, skipna=True) == inferred_dtype + ... + ... # constructor for .str-accessor will also pass + ... Series(values).str + """ + inferred_dtype, values = request.param + values = np.array(values, dtype=object) # object dtype to avoid casting + + # correctness of inference tested in tests/dtypes/test_inference.py + return inferred_dtype, values + + +def test_api(any_string_dtype): + # GH 6106, GH 9322 + assert Series.str is StringMethods + assert isinstance(Series([""], dtype=any_string_dtype).str, StringMethods) + + +def test_api_mi_raises(): + # GH 23679 + mi = MultiIndex.from_arrays([["a", "b", "c"]]) + msg = "Can only use .str accessor with Index, not MultiIndex" + with pytest.raises(AttributeError, match=msg): + mi.str + assert not hasattr(mi, "str") + + +@pytest.mark.parametrize("dtype", [object, "category"]) +def test_api_per_dtype(index_or_series, dtype, any_skipna_inferred_dtype): + # one instance of parametrized fixture + box = index_or_series + inferred_dtype, values = any_skipna_inferred_dtype + + t = box(values, dtype=dtype) # explicit dtype to avoid casting + + types_passing_constructor = [ + "string", + "unicode", + "empty", + "bytes", + "mixed", + "mixed-integer", + ] + if inferred_dtype in types_passing_constructor: + # GH 6106 + assert isinstance(t.str, StringMethods) + else: + # GH 9184, GH 23011, GH 23163 + msg = "Can only use .str accessor with string values.*" + with pytest.raises(AttributeError, match=msg): + t.str + assert not hasattr(t, "str") + + +@pytest.mark.parametrize("dtype", [object, "category"]) +def test_api_per_method( + index_or_series, + dtype, + any_allowed_skipna_inferred_dtype, + any_string_method, + request, +): + # this test does not check correctness of the different methods, + # just that the methods work on the specified (inferred) dtypes, + # and raise on all others + box = index_or_series + + # one instance of each parametrized fixture + inferred_dtype, values = any_allowed_skipna_inferred_dtype + method_name, args, kwargs = any_string_method + + reason = None + if box is Index and values.size == 0: + if method_name in ["partition", "rpartition"] and kwargs.get("expand", True): + raises = TypeError + reason = "Method cannot deal with empty Index" + elif method_name == "split" and kwargs.get("expand", None): + raises = TypeError + reason = "Split fails on empty Series when expand=True" + elif method_name == "get_dummies": + raises = ValueError + reason = "Need to fortify get_dummies corner cases" + + elif ( + box is Index + and inferred_dtype == "empty" + and dtype == object + and method_name == "get_dummies" + ): + raises = ValueError + reason = "Need to fortify get_dummies corner cases" + + if reason is not None: + mark = pytest.mark.xfail(raises=raises, reason=reason) + request.applymarker(mark) + + t = box(values, dtype=dtype) # explicit dtype to avoid casting + method = getattr(t.str, method_name) + + bytes_allowed = method_name in ["decode", "get", "len", "slice"] + # as of v0.23.4, all methods except 'cat' are very lenient with the + # allowed data types, just returning NaN for entries that error. + # This could be changed with an 'errors'-kwarg to the `str`-accessor, + # see discussion in GH 13877 + mixed_allowed = method_name not in ["cat"] + + allowed_types = ( + ["string", "unicode", "empty"] + + ["bytes"] * bytes_allowed + + ["mixed", "mixed-integer"] * mixed_allowed + ) + + if inferred_dtype in allowed_types: + # xref GH 23555, GH 23556 + with option_context("future.no_silent_downcasting", True): + method(*args, **kwargs) # works! + else: + # GH 23011, GH 23163 + msg = ( + f"Cannot use .str.{method_name} with values of " + f"inferred dtype {repr(inferred_dtype)}." + ) + with pytest.raises(TypeError, match=msg): + method(*args, **kwargs) + + +def test_api_for_categorical(any_string_method, any_string_dtype): + # https://github.com/pandas-dev/pandas/issues/10661 + s = Series(list("aabb"), dtype=any_string_dtype) + s = s + " " + s + c = s.astype("category") + c = c.astype(CategoricalDtype(c.dtype.categories.astype("object"))) + assert isinstance(c.str, StringMethods) + + method_name, args, kwargs = any_string_method + + result = getattr(c.str, method_name)(*args, **kwargs) + expected = getattr(s.astype("object").str, method_name)(*args, **kwargs) + + if isinstance(result, DataFrame): + tm.assert_frame_equal(result, expected) + elif isinstance(result, Series): + tm.assert_series_equal(result, expected) + else: + # str.cat(others=None) returns string, for example + assert result == expected diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/strings/test_case_justify.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/strings/test_case_justify.py new file mode 100644 index 0000000000000000000000000000000000000000..41aedae90ca7656a3df7c4f09beee79b2f533741 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/strings/test_case_justify.py @@ -0,0 +1,427 @@ +from datetime import datetime +import operator + +import numpy as np +import pytest + +from pandas import ( + Series, + _testing as tm, +) + + +def test_title(any_string_dtype): + s = Series(["FOO", "BAR", np.nan, "Blah", "blurg"], dtype=any_string_dtype) + result = s.str.title() + expected = Series(["Foo", "Bar", np.nan, "Blah", "Blurg"], dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + +def test_title_mixed_object(): + s = Series(["FOO", np.nan, "bar", True, datetime.today(), "blah", None, 1, 2.0]) + result = s.str.title() + expected = Series( + ["Foo", np.nan, "Bar", np.nan, np.nan, "Blah", None, np.nan, np.nan], + dtype=object, + ) + tm.assert_almost_equal(result, expected) + + +def test_lower_upper(any_string_dtype): + s = Series(["om", np.nan, "nom", "nom"], dtype=any_string_dtype) + + result = s.str.upper() + expected = Series(["OM", np.nan, "NOM", "NOM"], dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + result = result.str.lower() + tm.assert_series_equal(result, s) + + +def test_lower_upper_mixed_object(): + s = Series(["a", np.nan, "b", True, datetime.today(), "foo", None, 1, 2.0]) + + result = s.str.upper() + expected = Series( + ["A", np.nan, "B", np.nan, np.nan, "FOO", None, np.nan, np.nan], dtype=object + ) + tm.assert_series_equal(result, expected) + + result = s.str.lower() + expected = Series( + ["a", np.nan, "b", np.nan, np.nan, "foo", None, np.nan, np.nan], dtype=object + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "data, expected", + [ + ( + ["FOO", "BAR", np.nan, "Blah", "blurg"], + ["Foo", "Bar", np.nan, "Blah", "Blurg"], + ), + (["a", "b", "c"], ["A", "B", "C"]), + (["a b", "a bc. de"], ["A b", "A bc. de"]), + ], +) +def test_capitalize(data, expected, any_string_dtype): + s = Series(data, dtype=any_string_dtype) + result = s.str.capitalize() + expected = Series(expected, dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + +def test_capitalize_mixed_object(): + s = Series(["FOO", np.nan, "bar", True, datetime.today(), "blah", None, 1, 2.0]) + result = s.str.capitalize() + expected = Series( + ["Foo", np.nan, "Bar", np.nan, np.nan, "Blah", None, np.nan, np.nan], + dtype=object, + ) + tm.assert_series_equal(result, expected) + + +def test_swapcase(any_string_dtype): + s = Series(["FOO", "BAR", np.nan, "Blah", "blurg"], dtype=any_string_dtype) + result = s.str.swapcase() + expected = Series(["foo", "bar", np.nan, "bLAH", "BLURG"], dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + +def test_swapcase_mixed_object(): + s = Series(["FOO", np.nan, "bar", True, datetime.today(), "Blah", None, 1, 2.0]) + result = s.str.swapcase() + expected = Series( + ["foo", np.nan, "BAR", np.nan, np.nan, "bLAH", None, np.nan, np.nan], + dtype=object, + ) + tm.assert_series_equal(result, expected) + + +def test_casefold(): + # GH25405 + expected = Series(["ss", np.nan, "case", "ssd"]) + s = Series(["ß", np.nan, "case", "ßd"]) + result = s.str.casefold() + + tm.assert_series_equal(result, expected) + + +def test_casemethods(any_string_dtype): + values = ["aaa", "bbb", "CCC", "Dddd", "eEEE"] + s = Series(values, dtype=any_string_dtype) + assert s.str.lower().tolist() == [v.lower() for v in values] + assert s.str.upper().tolist() == [v.upper() for v in values] + assert s.str.title().tolist() == [v.title() for v in values] + assert s.str.capitalize().tolist() == [v.capitalize() for v in values] + assert s.str.swapcase().tolist() == [v.swapcase() for v in values] + + +def test_pad(any_string_dtype): + s = Series(["a", "b", np.nan, "c", np.nan, "eeeeee"], dtype=any_string_dtype) + + result = s.str.pad(5, side="left") + expected = Series( + [" a", " b", np.nan, " c", np.nan, "eeeeee"], dtype=any_string_dtype + ) + tm.assert_series_equal(result, expected) + + result = s.str.pad(5, side="right") + expected = Series( + ["a ", "b ", np.nan, "c ", np.nan, "eeeeee"], dtype=any_string_dtype + ) + tm.assert_series_equal(result, expected) + + result = s.str.pad(5, side="both") + expected = Series( + [" a ", " b ", np.nan, " c ", np.nan, "eeeeee"], dtype=any_string_dtype + ) + tm.assert_series_equal(result, expected) + + +def test_pad_mixed_object(): + s = Series(["a", np.nan, "b", True, datetime.today(), "ee", None, 1, 2.0]) + + result = s.str.pad(5, side="left") + expected = Series( + [" a", np.nan, " b", np.nan, np.nan, " ee", None, np.nan, np.nan], + dtype=object, + ) + tm.assert_series_equal(result, expected) + + result = s.str.pad(5, side="right") + expected = Series( + ["a ", np.nan, "b ", np.nan, np.nan, "ee ", None, np.nan, np.nan], + dtype=object, + ) + tm.assert_series_equal(result, expected) + + result = s.str.pad(5, side="both") + expected = Series( + [" a ", np.nan, " b ", np.nan, np.nan, " ee ", None, np.nan, np.nan], + dtype=object, + ) + tm.assert_series_equal(result, expected) + + +def test_pad_fillchar(any_string_dtype): + s = Series(["a", "b", np.nan, "c", np.nan, "eeeeee"], dtype=any_string_dtype) + + result = s.str.pad(5, side="left", fillchar="X") + expected = Series( + ["XXXXa", "XXXXb", np.nan, "XXXXc", np.nan, "eeeeee"], dtype=any_string_dtype + ) + tm.assert_series_equal(result, expected) + + result = s.str.pad(5, side="right", fillchar="X") + expected = Series( + ["aXXXX", "bXXXX", np.nan, "cXXXX", np.nan, "eeeeee"], dtype=any_string_dtype + ) + tm.assert_series_equal(result, expected) + + result = s.str.pad(5, side="both", fillchar="X") + expected = Series( + ["XXaXX", "XXbXX", np.nan, "XXcXX", np.nan, "eeeeee"], dtype=any_string_dtype + ) + tm.assert_series_equal(result, expected) + + +def test_pad_fillchar_bad_arg_raises(any_string_dtype): + s = Series(["a", "b", np.nan, "c", np.nan, "eeeeee"], dtype=any_string_dtype) + + msg = "fillchar must be a character, not str" + with pytest.raises(TypeError, match=msg): + s.str.pad(5, fillchar="XY") + + msg = "fillchar must be a character, not int" + with pytest.raises(TypeError, match=msg): + s.str.pad(5, fillchar=5) + + +@pytest.mark.parametrize("method_name", ["center", "ljust", "rjust", "zfill", "pad"]) +def test_pad_width_bad_arg_raises(method_name, any_string_dtype): + # see gh-13598 + s = Series(["1", "22", "a", "bb"], dtype=any_string_dtype) + op = operator.methodcaller(method_name, "f") + + msg = "width must be of integer type, not str" + with pytest.raises(TypeError, match=msg): + op(s.str) + + +def test_center_ljust_rjust(any_string_dtype): + s = Series(["a", "b", np.nan, "c", np.nan, "eeeeee"], dtype=any_string_dtype) + + result = s.str.center(5) + expected = Series( + [" a ", " b ", np.nan, " c ", np.nan, "eeeeee"], dtype=any_string_dtype + ) + tm.assert_series_equal(result, expected) + + result = s.str.ljust(5) + expected = Series( + ["a ", "b ", np.nan, "c ", np.nan, "eeeeee"], dtype=any_string_dtype + ) + tm.assert_series_equal(result, expected) + + result = s.str.rjust(5) + expected = Series( + [" a", " b", np.nan, " c", np.nan, "eeeeee"], dtype=any_string_dtype + ) + tm.assert_series_equal(result, expected) + + +def test_center_ljust_rjust_mixed_object(): + s = Series(["a", np.nan, "b", True, datetime.today(), "c", "eee", None, 1, 2.0]) + + result = s.str.center(5) + expected = Series( + [ + " a ", + np.nan, + " b ", + np.nan, + np.nan, + " c ", + " eee ", + None, + np.nan, + np.nan, + ], + dtype=object, + ) + tm.assert_series_equal(result, expected) + + result = s.str.ljust(5) + expected = Series( + [ + "a ", + np.nan, + "b ", + np.nan, + np.nan, + "c ", + "eee ", + None, + np.nan, + np.nan, + ], + dtype=object, + ) + tm.assert_series_equal(result, expected) + + result = s.str.rjust(5) + expected = Series( + [ + " a", + np.nan, + " b", + np.nan, + np.nan, + " c", + " eee", + None, + np.nan, + np.nan, + ], + dtype=object, + ) + tm.assert_series_equal(result, expected) + + +def test_center_ljust_rjust_fillchar(any_string_dtype): + if any_string_dtype == "string[pyarrow_numpy]": + pytest.skip( + "Arrow logic is different, " + "see https://github.com/pandas-dev/pandas/pull/54533/files#r1299808126", + ) + s = Series(["a", "bb", "cccc", "ddddd", "eeeeee"], dtype=any_string_dtype) + + result = s.str.center(5, fillchar="X") + expected = Series( + ["XXaXX", "XXbbX", "Xcccc", "ddddd", "eeeeee"], dtype=any_string_dtype + ) + tm.assert_series_equal(result, expected) + expected = np.array([v.center(5, "X") for v in np.array(s)], dtype=np.object_) + tm.assert_numpy_array_equal(np.array(result, dtype=np.object_), expected) + + result = s.str.ljust(5, fillchar="X") + expected = Series( + ["aXXXX", "bbXXX", "ccccX", "ddddd", "eeeeee"], dtype=any_string_dtype + ) + tm.assert_series_equal(result, expected) + expected = np.array([v.ljust(5, "X") for v in np.array(s)], dtype=np.object_) + tm.assert_numpy_array_equal(np.array(result, dtype=np.object_), expected) + + result = s.str.rjust(5, fillchar="X") + expected = Series( + ["XXXXa", "XXXbb", "Xcccc", "ddddd", "eeeeee"], dtype=any_string_dtype + ) + tm.assert_series_equal(result, expected) + expected = np.array([v.rjust(5, "X") for v in np.array(s)], dtype=np.object_) + tm.assert_numpy_array_equal(np.array(result, dtype=np.object_), expected) + + +def test_center_ljust_rjust_fillchar_bad_arg_raises(any_string_dtype): + s = Series(["a", "bb", "cccc", "ddddd", "eeeeee"], dtype=any_string_dtype) + + # If fillchar is not a character, normal str raises TypeError + # 'aaa'.ljust(5, 'XY') + # TypeError: must be char, not str + template = "fillchar must be a character, not {dtype}" + + with pytest.raises(TypeError, match=template.format(dtype="str")): + s.str.center(5, fillchar="XY") + + with pytest.raises(TypeError, match=template.format(dtype="str")): + s.str.ljust(5, fillchar="XY") + + with pytest.raises(TypeError, match=template.format(dtype="str")): + s.str.rjust(5, fillchar="XY") + + with pytest.raises(TypeError, match=template.format(dtype="int")): + s.str.center(5, fillchar=1) + + with pytest.raises(TypeError, match=template.format(dtype="int")): + s.str.ljust(5, fillchar=1) + + with pytest.raises(TypeError, match=template.format(dtype="int")): + s.str.rjust(5, fillchar=1) + + +def test_zfill(any_string_dtype): + s = Series(["1", "22", "aaa", "333", "45678"], dtype=any_string_dtype) + + result = s.str.zfill(5) + expected = Series( + ["00001", "00022", "00aaa", "00333", "45678"], dtype=any_string_dtype + ) + tm.assert_series_equal(result, expected) + expected = np.array([v.zfill(5) for v in np.array(s)], dtype=np.object_) + tm.assert_numpy_array_equal(np.array(result, dtype=np.object_), expected) + + result = s.str.zfill(3) + expected = Series(["001", "022", "aaa", "333", "45678"], dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + expected = np.array([v.zfill(3) for v in np.array(s)], dtype=np.object_) + tm.assert_numpy_array_equal(np.array(result, dtype=np.object_), expected) + + s = Series(["1", np.nan, "aaa", np.nan, "45678"], dtype=any_string_dtype) + result = s.str.zfill(5) + expected = Series( + ["00001", np.nan, "00aaa", np.nan, "45678"], dtype=any_string_dtype + ) + tm.assert_series_equal(result, expected) + + +def test_wrap(any_string_dtype): + # test values are: two words less than width, two words equal to width, + # two words greater than width, one word less than width, one word + # equal to width, one word greater than width, multiple tokens with + # trailing whitespace equal to width + s = Series( + [ + "hello world", + "hello world!", + "hello world!!", + "abcdefabcde", + "abcdefabcdef", + "abcdefabcdefa", + "ab ab ab ab ", + "ab ab ab ab a", + "\t", + ], + dtype=any_string_dtype, + ) + + # expected values + expected = Series( + [ + "hello world", + "hello world!", + "hello\nworld!!", + "abcdefabcde", + "abcdefabcdef", + "abcdefabcdef\na", + "ab ab ab ab", + "ab ab ab ab\na", + "", + ], + dtype=any_string_dtype, + ) + + result = s.str.wrap(12, break_long_words=True) + tm.assert_series_equal(result, expected) + + +def test_wrap_unicode(any_string_dtype): + # test with pre and post whitespace (non-unicode), NaN, and non-ascii Unicode + s = Series( + [" pre ", np.nan, "\xac\u20ac\U00008000 abadcafe"], dtype=any_string_dtype + ) + expected = Series( + [" pre", np.nan, "\xac\u20ac\U00008000 ab\nadcafe"], dtype=any_string_dtype + ) + result = s.str.wrap(6) + tm.assert_series_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/strings/test_cat.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/strings/test_cat.py new file mode 100644 index 0000000000000000000000000000000000000000..c1e7ad6e02779259f37bd6bad57f03428d9c3055 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/strings/test_cat.py @@ -0,0 +1,427 @@ +import re + +import numpy as np +import pytest + +import pandas.util._test_decorators as td + +from pandas import ( + DataFrame, + Index, + MultiIndex, + Series, + _testing as tm, + concat, + option_context, +) + + +@pytest.mark.parametrize("other", [None, Series, Index]) +def test_str_cat_name(index_or_series, other): + # GH 21053 + box = index_or_series + values = ["a", "b"] + if other: + other = other(values) + else: + other = values + result = box(values, name="name").str.cat(other, sep=",") + assert result.name == "name" + + +@pytest.mark.parametrize( + "infer_string", [False, pytest.param(True, marks=td.skip_if_no("pyarrow"))] +) +def test_str_cat(index_or_series, infer_string): + with option_context("future.infer_string", infer_string): + box = index_or_series + # test_cat above tests "str_cat" from ndarray; + # here testing "str.cat" from Series/Index to ndarray/list + s = box(["a", "a", "b", "b", "c", np.nan]) + + # single array + result = s.str.cat() + expected = "aabbc" + assert result == expected + + result = s.str.cat(na_rep="-") + expected = "aabbc-" + assert result == expected + + result = s.str.cat(sep="_", na_rep="NA") + expected = "a_a_b_b_c_NA" + assert result == expected + + t = np.array(["a", np.nan, "b", "d", "foo", np.nan], dtype=object) + expected = box(["aa", "a-", "bb", "bd", "cfoo", "--"]) + + # Series/Index with array + result = s.str.cat(t, na_rep="-") + tm.assert_equal(result, expected) + + # Series/Index with list + result = s.str.cat(list(t), na_rep="-") + tm.assert_equal(result, expected) + + # errors for incorrect lengths + rgx = r"If `others` contains arrays or lists \(or other list-likes.*" + z = Series(["1", "2", "3"]) + + with pytest.raises(ValueError, match=rgx): + s.str.cat(z.values) + + with pytest.raises(ValueError, match=rgx): + s.str.cat(list(z)) + + +def test_str_cat_raises_intuitive_error(index_or_series): + # GH 11334 + box = index_or_series + s = box(["a", "b", "c", "d"]) + message = "Did you mean to supply a `sep` keyword?" + with pytest.raises(ValueError, match=message): + s.str.cat("|") + with pytest.raises(ValueError, match=message): + s.str.cat(" ") + + +@pytest.mark.parametrize( + "infer_string", [False, pytest.param(True, marks=td.skip_if_no("pyarrow"))] +) +@pytest.mark.parametrize("sep", ["", None]) +@pytest.mark.parametrize("dtype_target", ["object", "category"]) +@pytest.mark.parametrize("dtype_caller", ["object", "category"]) +def test_str_cat_categorical( + index_or_series, dtype_caller, dtype_target, sep, infer_string +): + box = index_or_series + + with option_context("future.infer_string", infer_string): + s = Index(["a", "a", "b", "a"], dtype=dtype_caller) + s = s if box == Index else Series(s, index=s, dtype=s.dtype) + t = Index(["b", "a", "b", "c"], dtype=dtype_target) + + expected = Index( + ["ab", "aa", "bb", "ac"], dtype=object if dtype_caller == "object" else None + ) + expected = ( + expected + if box == Index + else Series( + expected, index=Index(s, dtype=dtype_caller), dtype=expected.dtype + ) + ) + + # Series/Index with unaligned Index -> t.values + result = s.str.cat(t.values, sep=sep) + tm.assert_equal(result, expected) + + # Series/Index with Series having matching Index + t = Series(t.values, index=Index(s, dtype=dtype_caller)) + result = s.str.cat(t, sep=sep) + tm.assert_equal(result, expected) + + # Series/Index with Series.values + result = s.str.cat(t.values, sep=sep) + tm.assert_equal(result, expected) + + # Series/Index with Series having different Index + t = Series(t.values, index=t.values) + expected = Index( + ["aa", "aa", "bb", "bb", "aa"], + dtype=object if dtype_caller == "object" else None, + ) + dtype = object if dtype_caller == "object" else s.dtype.categories.dtype + expected = ( + expected + if box == Index + else Series( + expected, + index=Index(expected.str[:1], dtype=dtype), + dtype=expected.dtype, + ) + ) + + result = s.str.cat(t, sep=sep) + tm.assert_equal(result, expected) + + +@pytest.mark.parametrize( + "data", + [[1, 2, 3], [0.1, 0.2, 0.3], [1, 2, "b"]], + ids=["integers", "floats", "mixed"], +) +# without dtype=object, np.array would cast [1, 2, 'b'] to ['1', '2', 'b'] +@pytest.mark.parametrize( + "box", + [Series, Index, list, lambda x: np.array(x, dtype=object)], + ids=["Series", "Index", "list", "np.array"], +) +def test_str_cat_wrong_dtype_raises(box, data): + # GH 22722 + s = Series(["a", "b", "c"]) + t = box(data) + + msg = "Concatenation requires list-likes containing only strings.*" + with pytest.raises(TypeError, match=msg): + # need to use outer and na_rep, as otherwise Index would not raise + s.str.cat(t, join="outer", na_rep="-") + + +def test_str_cat_mixed_inputs(index_or_series): + box = index_or_series + s = Index(["a", "b", "c", "d"]) + s = s if box == Index else Series(s, index=s) + + t = Series(["A", "B", "C", "D"], index=s.values) + d = concat([t, Series(s, index=s)], axis=1) + + expected = Index(["aAa", "bBb", "cCc", "dDd"]) + expected = expected if box == Index else Series(expected.values, index=s.values) + + # Series/Index with DataFrame + result = s.str.cat(d) + tm.assert_equal(result, expected) + + # Series/Index with two-dimensional ndarray + result = s.str.cat(d.values) + tm.assert_equal(result, expected) + + # Series/Index with list of Series + result = s.str.cat([t, s]) + tm.assert_equal(result, expected) + + # Series/Index with mixed list of Series/array + result = s.str.cat([t, s.values]) + tm.assert_equal(result, expected) + + # Series/Index with list of Series; different indexes + t.index = ["b", "c", "d", "a"] + expected = box(["aDa", "bAb", "cBc", "dCd"]) + expected = expected if box == Index else Series(expected.values, index=s.values) + result = s.str.cat([t, s]) + tm.assert_equal(result, expected) + + # Series/Index with mixed list; different index + result = s.str.cat([t, s.values]) + tm.assert_equal(result, expected) + + # Series/Index with DataFrame; different indexes + d.index = ["b", "c", "d", "a"] + expected = box(["aDd", "bAa", "cBb", "dCc"]) + expected = expected if box == Index else Series(expected.values, index=s.values) + result = s.str.cat(d) + tm.assert_equal(result, expected) + + # errors for incorrect lengths + rgx = r"If `others` contains arrays or lists \(or other list-likes.*" + z = Series(["1", "2", "3"]) + e = concat([z, z], axis=1) + + # two-dimensional ndarray + with pytest.raises(ValueError, match=rgx): + s.str.cat(e.values) + + # list of list-likes + with pytest.raises(ValueError, match=rgx): + s.str.cat([z.values, s.values]) + + # mixed list of Series/list-like + with pytest.raises(ValueError, match=rgx): + s.str.cat([z.values, s]) + + # errors for incorrect arguments in list-like + rgx = "others must be Series, Index, DataFrame,.*" + # make sure None/NaN do not crash checks in _get_series_list + u = Series(["a", np.nan, "c", None]) + + # mix of string and Series + with pytest.raises(TypeError, match=rgx): + s.str.cat([u, "u"]) + + # DataFrame in list + with pytest.raises(TypeError, match=rgx): + s.str.cat([u, d]) + + # 2-dim ndarray in list + with pytest.raises(TypeError, match=rgx): + s.str.cat([u, d.values]) + + # nested lists + with pytest.raises(TypeError, match=rgx): + s.str.cat([u, [u, d]]) + + # forbidden input type: set + # GH 23009 + with pytest.raises(TypeError, match=rgx): + s.str.cat(set(u)) + + # forbidden input type: set in list + # GH 23009 + with pytest.raises(TypeError, match=rgx): + s.str.cat([u, set(u)]) + + # other forbidden input type, e.g. int + with pytest.raises(TypeError, match=rgx): + s.str.cat(1) + + # nested list-likes + with pytest.raises(TypeError, match=rgx): + s.str.cat(iter([t.values, list(s)])) + + +@pytest.mark.parametrize("join", ["left", "outer", "inner", "right"]) +def test_str_cat_align_indexed(index_or_series, join): + # https://github.com/pandas-dev/pandas/issues/18657 + box = index_or_series + + s = Series(["a", "b", "c", "d"], index=["a", "b", "c", "d"]) + t = Series(["D", "A", "E", "B"], index=["d", "a", "e", "b"]) + sa, ta = s.align(t, join=join) + # result after manual alignment of inputs + expected = sa.str.cat(ta, na_rep="-") + + if box == Index: + s = Index(s) + sa = Index(sa) + expected = Index(expected) + + result = s.str.cat(t, join=join, na_rep="-") + tm.assert_equal(result, expected) + + +@pytest.mark.parametrize("join", ["left", "outer", "inner", "right"]) +def test_str_cat_align_mixed_inputs(join): + s = Series(["a", "b", "c", "d"]) + t = Series(["d", "a", "e", "b"], index=[3, 0, 4, 1]) + d = concat([t, t], axis=1) + + expected_outer = Series(["aaa", "bbb", "c--", "ddd", "-ee"]) + expected = expected_outer.loc[s.index.join(t.index, how=join)] + + # list of Series + result = s.str.cat([t, t], join=join, na_rep="-") + tm.assert_series_equal(result, expected) + + # DataFrame + result = s.str.cat(d, join=join, na_rep="-") + tm.assert_series_equal(result, expected) + + # mixed list of indexed/unindexed + u = np.array(["A", "B", "C", "D"]) + expected_outer = Series(["aaA", "bbB", "c-C", "ddD", "-e-"]) + # joint index of rhs [t, u]; u will be forced have index of s + rhs_idx = ( + t.index.intersection(s.index) + if join == "inner" + else t.index.union(s.index) + if join == "outer" + else t.index.append(s.index.difference(t.index)) + ) + + expected = expected_outer.loc[s.index.join(rhs_idx, how=join)] + result = s.str.cat([t, u], join=join, na_rep="-") + tm.assert_series_equal(result, expected) + + with pytest.raises(TypeError, match="others must be Series,.*"): + # nested lists are forbidden + s.str.cat([t, list(u)], join=join) + + # errors for incorrect lengths + rgx = r"If `others` contains arrays or lists \(or other list-likes.*" + z = Series(["1", "2", "3"]).values + + # unindexed object of wrong length + with pytest.raises(ValueError, match=rgx): + s.str.cat(z, join=join) + + # unindexed object of wrong length in list + with pytest.raises(ValueError, match=rgx): + s.str.cat([t, z], join=join) + + +def test_str_cat_all_na(index_or_series, index_or_series2): + # GH 24044 + box = index_or_series + other = index_or_series2 + + # check that all NaNs in caller / target work + s = Index(["a", "b", "c", "d"]) + s = s if box == Index else Series(s, index=s) + t = other([np.nan] * 4, dtype=object) + # add index of s for alignment + t = t if other == Index else Series(t, index=s) + + # all-NA target + if box == Series: + expected = Series([np.nan] * 4, index=s.index, dtype=s.dtype) + else: # box == Index + # TODO: Strimg option, this should return string dtype + expected = Index([np.nan] * 4, dtype=object) + result = s.str.cat(t, join="left") + tm.assert_equal(result, expected) + + # all-NA caller (only for Series) + if other == Series: + expected = Series([np.nan] * 4, dtype=object, index=t.index) + result = t.str.cat(s, join="left") + tm.assert_series_equal(result, expected) + + +def test_str_cat_special_cases(): + s = Series(["a", "b", "c", "d"]) + t = Series(["d", "a", "e", "b"], index=[3, 0, 4, 1]) + + # iterator of elements with different types + expected = Series(["aaa", "bbb", "c-c", "ddd", "-e-"]) + result = s.str.cat(iter([t, s.values]), join="outer", na_rep="-") + tm.assert_series_equal(result, expected) + + # right-align with different indexes in others + expected = Series(["aa-", "d-d"], index=[0, 3]) + result = s.str.cat([t.loc[[0]], t.loc[[3]]], join="right", na_rep="-") + tm.assert_series_equal(result, expected) + + +def test_cat_on_filtered_index(): + df = DataFrame( + index=MultiIndex.from_product( + [[2011, 2012], [1, 2, 3]], names=["year", "month"] + ) + ) + + df = df.reset_index() + df = df[df.month > 1] + + str_year = df.year.astype("str") + str_month = df.month.astype("str") + str_both = str_year.str.cat(str_month, sep=" ") + + assert str_both.loc[1] == "2011 2" + + str_multiple = str_year.str.cat([str_month, str_month], sep=" ") + + assert str_multiple.loc[1] == "2011 2 2" + + +@pytest.mark.parametrize("klass", [tuple, list, np.array, Series, Index]) +def test_cat_different_classes(klass): + # https://github.com/pandas-dev/pandas/issues/33425 + s = Series(["a", "b", "c"]) + result = s.str.cat(klass(["x", "y", "z"])) + expected = Series(["ax", "by", "cz"]) + tm.assert_series_equal(result, expected) + + +def test_cat_on_series_dot_str(): + # GH 28277 + ps = Series(["AbC", "de", "FGHI", "j", "kLLLm"]) + + message = re.escape( + "others must be Series, Index, DataFrame, np.ndarray " + "or list-like (either containing only strings or " + "containing only objects of type Series/Index/" + "np.ndarray[1-dim])" + ) + with pytest.raises(TypeError, match=message): + ps.str.cat(others=ps.str) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/strings/test_extract.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/strings/test_extract.py new file mode 100644 index 0000000000000000000000000000000000000000..77d008c650264889550ec70331a1b98064242d26 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/strings/test_extract.py @@ -0,0 +1,724 @@ +from datetime import datetime +import re + +import numpy as np +import pytest + +from pandas.core.dtypes.dtypes import ArrowDtype + +from pandas import ( + DataFrame, + Index, + MultiIndex, + Series, + _testing as tm, +) + + +def test_extract_expand_kwarg_wrong_type_raises(any_string_dtype): + # TODO: should this raise TypeError + values = Series(["fooBAD__barBAD", np.nan, "foo"], dtype=any_string_dtype) + with pytest.raises(ValueError, match="expand must be True or False"): + values.str.extract(".*(BAD[_]+).*(BAD)", expand=None) + + +def test_extract_expand_kwarg(any_string_dtype): + s = Series(["fooBAD__barBAD", np.nan, "foo"], dtype=any_string_dtype) + expected = DataFrame(["BAD__", np.nan, np.nan], dtype=any_string_dtype) + + result = s.str.extract(".*(BAD[_]+).*") + tm.assert_frame_equal(result, expected) + + result = s.str.extract(".*(BAD[_]+).*", expand=True) + tm.assert_frame_equal(result, expected) + + expected = DataFrame( + [["BAD__", "BAD"], [np.nan, np.nan], [np.nan, np.nan]], dtype=any_string_dtype + ) + result = s.str.extract(".*(BAD[_]+).*(BAD)", expand=False) + tm.assert_frame_equal(result, expected) + + +def test_extract_expand_False_mixed_object(): + ser = Series( + ["aBAD_BAD", np.nan, "BAD_b_BAD", True, datetime.today(), "foo", None, 1, 2.0] + ) + + # two groups + result = ser.str.extract(".*(BAD[_]+).*(BAD)", expand=False) + er = [np.nan, np.nan] # empty row + expected = DataFrame( + [["BAD_", "BAD"], er, ["BAD_", "BAD"], er, er, er, er, er, er], dtype=object + ) + tm.assert_frame_equal(result, expected) + + # single group + result = ser.str.extract(".*(BAD[_]+).*BAD", expand=False) + expected = Series( + ["BAD_", np.nan, "BAD_", np.nan, np.nan, np.nan, None, np.nan, np.nan], + dtype=object, + ) + tm.assert_series_equal(result, expected) + + +def test_extract_expand_index_raises(): + # GH9980 + # Index only works with one regex group since + # multi-group would expand to a frame + idx = Index(["A1", "A2", "A3", "A4", "B5"]) + msg = "only one regex group is supported with Index" + with pytest.raises(ValueError, match=msg): + idx.str.extract("([AB])([123])", expand=False) + + +def test_extract_expand_no_capture_groups_raises(index_or_series, any_string_dtype): + s_or_idx = index_or_series(["A1", "B2", "C3"], dtype=any_string_dtype) + msg = "pattern contains no capture groups" + + # no groups + with pytest.raises(ValueError, match=msg): + s_or_idx.str.extract("[ABC][123]", expand=False) + + # only non-capturing groups + with pytest.raises(ValueError, match=msg): + s_or_idx.str.extract("(?:[AB]).*", expand=False) + + +def test_extract_expand_single_capture_group(index_or_series, any_string_dtype): + # single group renames series/index properly + s_or_idx = index_or_series(["A1", "A2"], dtype=any_string_dtype) + result = s_or_idx.str.extract(r"(?PA)\d", expand=False) + + expected = index_or_series(["A", "A"], name="uno", dtype=any_string_dtype) + if index_or_series == Series: + tm.assert_series_equal(result, expected) + else: + tm.assert_index_equal(result, expected) + + +def test_extract_expand_capture_groups(any_string_dtype): + s = Series(["A1", "B2", "C3"], dtype=any_string_dtype) + # one group, no matches + result = s.str.extract("(_)", expand=False) + expected = Series([np.nan, np.nan, np.nan], dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + # two groups, no matches + result = s.str.extract("(_)(_)", expand=False) + expected = DataFrame( + [[np.nan, np.nan], [np.nan, np.nan], [np.nan, np.nan]], dtype=any_string_dtype + ) + tm.assert_frame_equal(result, expected) + + # one group, some matches + result = s.str.extract("([AB])[123]", expand=False) + expected = Series(["A", "B", np.nan], dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + # two groups, some matches + result = s.str.extract("([AB])([123])", expand=False) + expected = DataFrame( + [["A", "1"], ["B", "2"], [np.nan, np.nan]], dtype=any_string_dtype + ) + tm.assert_frame_equal(result, expected) + + # one named group + result = s.str.extract("(?P[AB])", expand=False) + expected = Series(["A", "B", np.nan], name="letter", dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + # two named groups + result = s.str.extract("(?P[AB])(?P[123])", expand=False) + expected = DataFrame( + [["A", "1"], ["B", "2"], [np.nan, np.nan]], + columns=["letter", "number"], + dtype=any_string_dtype, + ) + tm.assert_frame_equal(result, expected) + + # mix named and unnamed groups + result = s.str.extract("([AB])(?P[123])", expand=False) + expected = DataFrame( + [["A", "1"], ["B", "2"], [np.nan, np.nan]], + columns=[0, "number"], + dtype=any_string_dtype, + ) + tm.assert_frame_equal(result, expected) + + # one normal group, one non-capturing group + result = s.str.extract("([AB])(?:[123])", expand=False) + expected = Series(["A", "B", np.nan], dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + # two normal groups, one non-capturing group + s = Series(["A11", "B22", "C33"], dtype=any_string_dtype) + result = s.str.extract("([AB])([123])(?:[123])", expand=False) + expected = DataFrame( + [["A", "1"], ["B", "2"], [np.nan, np.nan]], dtype=any_string_dtype + ) + tm.assert_frame_equal(result, expected) + + # one optional group followed by one normal group + s = Series(["A1", "B2", "3"], dtype=any_string_dtype) + result = s.str.extract("(?P[AB])?(?P[123])", expand=False) + expected = DataFrame( + [["A", "1"], ["B", "2"], [np.nan, "3"]], + columns=["letter", "number"], + dtype=any_string_dtype, + ) + tm.assert_frame_equal(result, expected) + + # one normal group followed by one optional group + s = Series(["A1", "B2", "C"], dtype=any_string_dtype) + result = s.str.extract("(?P[ABC])(?P[123])?", expand=False) + expected = DataFrame( + [["A", "1"], ["B", "2"], ["C", np.nan]], + columns=["letter", "number"], + dtype=any_string_dtype, + ) + tm.assert_frame_equal(result, expected) + + +def test_extract_expand_capture_groups_index(index, any_string_dtype): + # https://github.com/pandas-dev/pandas/issues/6348 + # not passing index to the extractor + data = ["A1", "B2", "C"] + + if len(index) == 0: + pytest.skip("Test requires len(index) > 0") + while len(index) < len(data): + index = index.repeat(2) + + index = index[: len(data)] + ser = Series(data, index=index, dtype=any_string_dtype) + + result = ser.str.extract(r"(\d)", expand=False) + expected = Series(["1", "2", np.nan], index=index, dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + result = ser.str.extract(r"(?P\D)(?P\d)?", expand=False) + expected = DataFrame( + [["A", "1"], ["B", "2"], ["C", np.nan]], + columns=["letter", "number"], + index=index, + dtype=any_string_dtype, + ) + tm.assert_frame_equal(result, expected) + + +def test_extract_single_series_name_is_preserved(any_string_dtype): + s = Series(["a3", "b3", "c2"], name="bob", dtype=any_string_dtype) + result = s.str.extract(r"(?P[a-z])", expand=False) + expected = Series(["a", "b", "c"], name="sue", dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + +def test_extract_expand_True(any_string_dtype): + # Contains tests like those in test_match and some others. + s = Series(["fooBAD__barBAD", np.nan, "foo"], dtype=any_string_dtype) + + result = s.str.extract(".*(BAD[_]+).*(BAD)", expand=True) + expected = DataFrame( + [["BAD__", "BAD"], [np.nan, np.nan], [np.nan, np.nan]], dtype=any_string_dtype + ) + tm.assert_frame_equal(result, expected) + + +def test_extract_expand_True_mixed_object(): + er = [np.nan, np.nan] # empty row + mixed = Series( + [ + "aBAD_BAD", + np.nan, + "BAD_b_BAD", + True, + datetime.today(), + "foo", + None, + 1, + 2.0, + ] + ) + + result = mixed.str.extract(".*(BAD[_]+).*(BAD)", expand=True) + expected = DataFrame( + [["BAD_", "BAD"], er, ["BAD_", "BAD"], er, er, er, er, er, er], dtype=object + ) + tm.assert_frame_equal(result, expected) + + +def test_extract_expand_True_single_capture_group_raises( + index_or_series, any_string_dtype +): + # these should work for both Series and Index + # no groups + s_or_idx = index_or_series(["A1", "B2", "C3"], dtype=any_string_dtype) + msg = "pattern contains no capture groups" + with pytest.raises(ValueError, match=msg): + s_or_idx.str.extract("[ABC][123]", expand=True) + + # only non-capturing groups + with pytest.raises(ValueError, match=msg): + s_or_idx.str.extract("(?:[AB]).*", expand=True) + + +def test_extract_expand_True_single_capture_group(index_or_series, any_string_dtype): + # single group renames series/index properly + s_or_idx = index_or_series(["A1", "A2"], dtype=any_string_dtype) + result = s_or_idx.str.extract(r"(?PA)\d", expand=True) + expected = DataFrame({"uno": ["A", "A"]}, dtype=any_string_dtype) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("name", [None, "series_name"]) +def test_extract_series(name, any_string_dtype): + # extract should give the same result whether or not the series has a name. + s = Series(["A1", "B2", "C3"], name=name, dtype=any_string_dtype) + + # one group, no matches + result = s.str.extract("(_)", expand=True) + expected = DataFrame([np.nan, np.nan, np.nan], dtype=any_string_dtype) + tm.assert_frame_equal(result, expected) + + # two groups, no matches + result = s.str.extract("(_)(_)", expand=True) + expected = DataFrame( + [[np.nan, np.nan], [np.nan, np.nan], [np.nan, np.nan]], dtype=any_string_dtype + ) + tm.assert_frame_equal(result, expected) + + # one group, some matches + result = s.str.extract("([AB])[123]", expand=True) + expected = DataFrame(["A", "B", np.nan], dtype=any_string_dtype) + tm.assert_frame_equal(result, expected) + + # two groups, some matches + result = s.str.extract("([AB])([123])", expand=True) + expected = DataFrame( + [["A", "1"], ["B", "2"], [np.nan, np.nan]], dtype=any_string_dtype + ) + tm.assert_frame_equal(result, expected) + + # one named group + result = s.str.extract("(?P[AB])", expand=True) + expected = DataFrame({"letter": ["A", "B", np.nan]}, dtype=any_string_dtype) + tm.assert_frame_equal(result, expected) + + # two named groups + result = s.str.extract("(?P[AB])(?P[123])", expand=True) + expected = DataFrame( + [["A", "1"], ["B", "2"], [np.nan, np.nan]], + columns=["letter", "number"], + dtype=any_string_dtype, + ) + tm.assert_frame_equal(result, expected) + + # mix named and unnamed groups + result = s.str.extract("([AB])(?P[123])", expand=True) + expected = DataFrame( + [["A", "1"], ["B", "2"], [np.nan, np.nan]], + columns=[0, "number"], + dtype=any_string_dtype, + ) + tm.assert_frame_equal(result, expected) + + # one normal group, one non-capturing group + result = s.str.extract("([AB])(?:[123])", expand=True) + expected = DataFrame(["A", "B", np.nan], dtype=any_string_dtype) + tm.assert_frame_equal(result, expected) + + +def test_extract_optional_groups(any_string_dtype): + # two normal groups, one non-capturing group + s = Series(["A11", "B22", "C33"], dtype=any_string_dtype) + result = s.str.extract("([AB])([123])(?:[123])", expand=True) + expected = DataFrame( + [["A", "1"], ["B", "2"], [np.nan, np.nan]], dtype=any_string_dtype + ) + tm.assert_frame_equal(result, expected) + + # one optional group followed by one normal group + s = Series(["A1", "B2", "3"], dtype=any_string_dtype) + result = s.str.extract("(?P[AB])?(?P[123])", expand=True) + expected = DataFrame( + [["A", "1"], ["B", "2"], [np.nan, "3"]], + columns=["letter", "number"], + dtype=any_string_dtype, + ) + tm.assert_frame_equal(result, expected) + + # one normal group followed by one optional group + s = Series(["A1", "B2", "C"], dtype=any_string_dtype) + result = s.str.extract("(?P[ABC])(?P[123])?", expand=True) + expected = DataFrame( + [["A", "1"], ["B", "2"], ["C", np.nan]], + columns=["letter", "number"], + dtype=any_string_dtype, + ) + tm.assert_frame_equal(result, expected) + + +def test_extract_dataframe_capture_groups_index(index, any_string_dtype): + # GH6348 + # not passing index to the extractor + + data = ["A1", "B2", "C"] + + if len(index) < len(data): + pytest.skip(f"Index needs more than {len(data)} values") + + index = index[: len(data)] + s = Series(data, index=index, dtype=any_string_dtype) + + result = s.str.extract(r"(\d)", expand=True) + expected = DataFrame(["1", "2", np.nan], index=index, dtype=any_string_dtype) + tm.assert_frame_equal(result, expected) + + result = s.str.extract(r"(?P\D)(?P\d)?", expand=True) + expected = DataFrame( + [["A", "1"], ["B", "2"], ["C", np.nan]], + columns=["letter", "number"], + index=index, + dtype=any_string_dtype, + ) + tm.assert_frame_equal(result, expected) + + +def test_extract_single_group_returns_frame(any_string_dtype): + # GH11386 extract should always return DataFrame, even when + # there is only one group. Prior to v0.18.0, extract returned + # Series when there was only one group in the regex. + s = Series(["a3", "b3", "c2"], name="series_name", dtype=any_string_dtype) + result = s.str.extract(r"(?P[a-z])", expand=True) + expected = DataFrame({"letter": ["a", "b", "c"]}, dtype=any_string_dtype) + tm.assert_frame_equal(result, expected) + + +def test_extractall(any_string_dtype): + data = [ + "dave@google.com", + "tdhock5@gmail.com", + "maudelaperriere@gmail.com", + "rob@gmail.com some text steve@gmail.com", + "a@b.com some text c@d.com and e@f.com", + np.nan, + "", + ] + expected_tuples = [ + ("dave", "google", "com"), + ("tdhock5", "gmail", "com"), + ("maudelaperriere", "gmail", "com"), + ("rob", "gmail", "com"), + ("steve", "gmail", "com"), + ("a", "b", "com"), + ("c", "d", "com"), + ("e", "f", "com"), + ] + pat = r""" + (?P[a-z0-9]+) + @ + (?P[a-z]+) + \. + (?P[a-z]{2,4}) + """ + expected_columns = ["user", "domain", "tld"] + s = Series(data, dtype=any_string_dtype) + # extractall should return a DataFrame with one row for each match, indexed by the + # subject from which the match came. + expected_index = MultiIndex.from_tuples( + [(0, 0), (1, 0), (2, 0), (3, 0), (3, 1), (4, 0), (4, 1), (4, 2)], + names=(None, "match"), + ) + expected = DataFrame( + expected_tuples, expected_index, expected_columns, dtype=any_string_dtype + ) + result = s.str.extractall(pat, flags=re.VERBOSE) + tm.assert_frame_equal(result, expected) + + # The index of the input Series should be used to construct the index of the output + # DataFrame: + mi = MultiIndex.from_tuples( + [ + ("single", "Dave"), + ("single", "Toby"), + ("single", "Maude"), + ("multiple", "robAndSteve"), + ("multiple", "abcdef"), + ("none", "missing"), + ("none", "empty"), + ] + ) + s = Series(data, index=mi, dtype=any_string_dtype) + expected_index = MultiIndex.from_tuples( + [ + ("single", "Dave", 0), + ("single", "Toby", 0), + ("single", "Maude", 0), + ("multiple", "robAndSteve", 0), + ("multiple", "robAndSteve", 1), + ("multiple", "abcdef", 0), + ("multiple", "abcdef", 1), + ("multiple", "abcdef", 2), + ], + names=(None, None, "match"), + ) + expected = DataFrame( + expected_tuples, expected_index, expected_columns, dtype=any_string_dtype + ) + result = s.str.extractall(pat, flags=re.VERBOSE) + tm.assert_frame_equal(result, expected) + + # MultiIndexed subject with names. + s = Series(data, index=mi, dtype=any_string_dtype) + s.index.names = ("matches", "description") + expected_index.names = ("matches", "description", "match") + expected = DataFrame( + expected_tuples, expected_index, expected_columns, dtype=any_string_dtype + ) + result = s.str.extractall(pat, flags=re.VERBOSE) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "pat,expected_names", + [ + # optional groups. + ("(?P[AB])?(?P[123])", ["letter", "number"]), + # only one of two groups has a name. + ("([AB])?(?P[123])", [0, "number"]), + ], +) +def test_extractall_column_names(pat, expected_names, any_string_dtype): + s = Series(["", "A1", "32"], dtype=any_string_dtype) + + result = s.str.extractall(pat) + expected = DataFrame( + [("A", "1"), (np.nan, "3"), (np.nan, "2")], + index=MultiIndex.from_tuples([(1, 0), (2, 0), (2, 1)], names=(None, "match")), + columns=expected_names, + dtype=any_string_dtype, + ) + tm.assert_frame_equal(result, expected) + + +def test_extractall_single_group(any_string_dtype): + s = Series(["a3", "b3", "d4c2"], name="series_name", dtype=any_string_dtype) + expected_index = MultiIndex.from_tuples( + [(0, 0), (1, 0), (2, 0), (2, 1)], names=(None, "match") + ) + + # extractall(one named group) returns DataFrame with one named column. + result = s.str.extractall(r"(?P[a-z])") + expected = DataFrame( + {"letter": ["a", "b", "d", "c"]}, index=expected_index, dtype=any_string_dtype + ) + tm.assert_frame_equal(result, expected) + + # extractall(one un-named group) returns DataFrame with one un-named column. + result = s.str.extractall(r"([a-z])") + expected = DataFrame( + ["a", "b", "d", "c"], index=expected_index, dtype=any_string_dtype + ) + tm.assert_frame_equal(result, expected) + + +def test_extractall_single_group_with_quantifier(any_string_dtype): + # GH#13382 + # extractall(one un-named group with quantifier) returns DataFrame with one un-named + # column. + s = Series(["ab3", "abc3", "d4cd2"], name="series_name", dtype=any_string_dtype) + result = s.str.extractall(r"([a-z]+)") + expected = DataFrame( + ["ab", "abc", "d", "cd"], + index=MultiIndex.from_tuples( + [(0, 0), (1, 0), (2, 0), (2, 1)], names=(None, "match") + ), + dtype=any_string_dtype, + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "data, names", + [ + ([], (None,)), + ([], ("i1",)), + ([], (None, "i2")), + ([], ("i1", "i2")), + (["a3", "b3", "d4c2"], (None,)), + (["a3", "b3", "d4c2"], ("i1", "i2")), + (["a3", "b3", "d4c2"], (None, "i2")), + (["a3", "b3", "d4c2"], ("i1", "i2")), + ], +) +def test_extractall_no_matches(data, names, any_string_dtype): + # GH19075 extractall with no matches should return a valid MultiIndex + n = len(data) + if len(names) == 1: + index = Index(range(n), name=names[0]) + else: + tuples = (tuple([i] * (n - 1)) for i in range(n)) + index = MultiIndex.from_tuples(tuples, names=names) + s = Series(data, name="series_name", index=index, dtype=any_string_dtype) + expected_index = MultiIndex.from_tuples([], names=(names + ("match",))) + + # one un-named group. + result = s.str.extractall("(z)") + expected = DataFrame(columns=[0], index=expected_index, dtype=any_string_dtype) + tm.assert_frame_equal(result, expected) + + # two un-named groups. + result = s.str.extractall("(z)(z)") + expected = DataFrame(columns=[0, 1], index=expected_index, dtype=any_string_dtype) + tm.assert_frame_equal(result, expected) + + # one named group. + result = s.str.extractall("(?Pz)") + expected = DataFrame( + columns=["first"], index=expected_index, dtype=any_string_dtype + ) + tm.assert_frame_equal(result, expected) + + # two named groups. + result = s.str.extractall("(?Pz)(?Pz)") + expected = DataFrame( + columns=["first", "second"], index=expected_index, dtype=any_string_dtype + ) + tm.assert_frame_equal(result, expected) + + # one named, one un-named. + result = s.str.extractall("(z)(?Pz)") + expected = DataFrame( + columns=[0, "second"], index=expected_index, dtype=any_string_dtype + ) + tm.assert_frame_equal(result, expected) + + +def test_extractall_stringindex(any_string_dtype): + s = Series(["a1a2", "b1", "c1"], name="xxx", dtype=any_string_dtype) + result = s.str.extractall(r"[ab](?P\d)") + expected = DataFrame( + {"digit": ["1", "2", "1"]}, + index=MultiIndex.from_tuples([(0, 0), (0, 1), (1, 0)], names=[None, "match"]), + dtype=any_string_dtype, + ) + tm.assert_frame_equal(result, expected) + + # index should return the same result as the default index without name thus + # index.name doesn't affect to the result + if any_string_dtype == "object": + for idx in [ + Index(["a1a2", "b1", "c1"], dtype=object), + Index(["a1a2", "b1", "c1"], name="xxx", dtype=object), + ]: + result = idx.str.extractall(r"[ab](?P\d)") + tm.assert_frame_equal(result, expected) + + s = Series( + ["a1a2", "b1", "c1"], + name="s_name", + index=Index(["XX", "yy", "zz"], name="idx_name"), + dtype=any_string_dtype, + ) + result = s.str.extractall(r"[ab](?P\d)") + expected = DataFrame( + {"digit": ["1", "2", "1"]}, + index=MultiIndex.from_tuples( + [("XX", 0), ("XX", 1), ("yy", 0)], names=["idx_name", "match"] + ), + dtype=any_string_dtype, + ) + tm.assert_frame_equal(result, expected) + + +def test_extractall_no_capture_groups_raises(any_string_dtype): + # Does not make sense to use extractall with a regex that has no capture groups. + # (it returns DataFrame with one column for each capture group) + s = Series(["a3", "b3", "d4c2"], name="series_name", dtype=any_string_dtype) + with pytest.raises(ValueError, match="no capture groups"): + s.str.extractall(r"[a-z]") + + +def test_extract_index_one_two_groups(): + s = Series(["a3", "b3", "d4c2"], index=["A3", "B3", "D4"], name="series_name") + r = s.index.str.extract(r"([A-Z])", expand=True) + e = DataFrame(["A", "B", "D"]) + tm.assert_frame_equal(r, e) + + # Prior to v0.18.0, index.str.extract(regex with one group) + # returned Index. With more than one group, extract raised an + # error (GH9980). Now extract always returns DataFrame. + r = s.index.str.extract(r"(?P[A-Z])(?P[0-9])", expand=True) + e_list = [("A", "3"), ("B", "3"), ("D", "4")] + e = DataFrame(e_list, columns=["letter", "digit"]) + tm.assert_frame_equal(r, e) + + +def test_extractall_same_as_extract(any_string_dtype): + s = Series(["a3", "b3", "c2"], name="series_name", dtype=any_string_dtype) + + pattern_two_noname = r"([a-z])([0-9])" + extract_two_noname = s.str.extract(pattern_two_noname, expand=True) + has_multi_index = s.str.extractall(pattern_two_noname) + no_multi_index = has_multi_index.xs(0, level="match") + tm.assert_frame_equal(extract_two_noname, no_multi_index) + + pattern_two_named = r"(?P[a-z])(?P[0-9])" + extract_two_named = s.str.extract(pattern_two_named, expand=True) + has_multi_index = s.str.extractall(pattern_two_named) + no_multi_index = has_multi_index.xs(0, level="match") + tm.assert_frame_equal(extract_two_named, no_multi_index) + + pattern_one_named = r"(?P[a-z])" + extract_one_named = s.str.extract(pattern_one_named, expand=True) + has_multi_index = s.str.extractall(pattern_one_named) + no_multi_index = has_multi_index.xs(0, level="match") + tm.assert_frame_equal(extract_one_named, no_multi_index) + + pattern_one_noname = r"([a-z])" + extract_one_noname = s.str.extract(pattern_one_noname, expand=True) + has_multi_index = s.str.extractall(pattern_one_noname) + no_multi_index = has_multi_index.xs(0, level="match") + tm.assert_frame_equal(extract_one_noname, no_multi_index) + + +def test_extractall_same_as_extract_subject_index(any_string_dtype): + # same as above tests, but s has an MultiIndex. + mi = MultiIndex.from_tuples( + [("A", "first"), ("B", "second"), ("C", "third")], + names=("capital", "ordinal"), + ) + s = Series(["a3", "b3", "c2"], index=mi, name="series_name", dtype=any_string_dtype) + + pattern_two_noname = r"([a-z])([0-9])" + extract_two_noname = s.str.extract(pattern_two_noname, expand=True) + has_match_index = s.str.extractall(pattern_two_noname) + no_match_index = has_match_index.xs(0, level="match") + tm.assert_frame_equal(extract_two_noname, no_match_index) + + pattern_two_named = r"(?P[a-z])(?P[0-9])" + extract_two_named = s.str.extract(pattern_two_named, expand=True) + has_match_index = s.str.extractall(pattern_two_named) + no_match_index = has_match_index.xs(0, level="match") + tm.assert_frame_equal(extract_two_named, no_match_index) + + pattern_one_named = r"(?P[a-z])" + extract_one_named = s.str.extract(pattern_one_named, expand=True) + has_match_index = s.str.extractall(pattern_one_named) + no_match_index = has_match_index.xs(0, level="match") + tm.assert_frame_equal(extract_one_named, no_match_index) + + pattern_one_noname = r"([a-z])" + extract_one_noname = s.str.extract(pattern_one_noname, expand=True) + has_match_index = s.str.extractall(pattern_one_noname) + no_match_index = has_match_index.xs(0, level="match") + tm.assert_frame_equal(extract_one_noname, no_match_index) + + +def test_extractall_preserves_dtype(): + # Ensure that when extractall is called on a series with specific dtypes set, that + # the dtype is preserved in the resulting DataFrame's column. + pa = pytest.importorskip("pyarrow") + + result = Series(["abc", "ab"], dtype=ArrowDtype(pa.string())).str.extractall("(ab)") + assert result.dtypes[0] == "string[pyarrow]" diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/strings/test_find_replace.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/strings/test_find_replace.py new file mode 100644 index 0000000000000000000000000000000000000000..cd4707ac405de391f2fea3fca8e9578e0ba8aef2 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/strings/test_find_replace.py @@ -0,0 +1,972 @@ +from datetime import datetime +import re + +import numpy as np +import pytest + +from pandas.errors import PerformanceWarning +import pandas.util._test_decorators as td + +import pandas as pd +from pandas import ( + Series, + _testing as tm, +) +from pandas.tests.strings import ( + _convert_na_value, + object_pyarrow_numpy, +) + +# -------------------------------------------------------------------------------------- +# str.contains +# -------------------------------------------------------------------------------------- + + +def using_pyarrow(dtype): + return dtype in ("string[pyarrow]", "string[pyarrow_numpy]") + + +def test_contains(any_string_dtype): + values = np.array( + ["foo", np.nan, "fooommm__foo", "mmm_", "foommm[_]+bar"], dtype=np.object_ + ) + values = Series(values, dtype=any_string_dtype) + pat = "mmm[_]+" + + result = values.str.contains(pat) + expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean" + expected = Series( + np.array([False, np.nan, True, True, False], dtype=np.object_), + dtype=expected_dtype, + ) + tm.assert_series_equal(result, expected) + + result = values.str.contains(pat, regex=False) + expected = Series( + np.array([False, np.nan, False, False, True], dtype=np.object_), + dtype=expected_dtype, + ) + tm.assert_series_equal(result, expected) + + values = Series( + np.array(["foo", "xyz", "fooommm__foo", "mmm_"], dtype=object), + dtype=any_string_dtype, + ) + result = values.str.contains(pat) + expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean" + expected = Series(np.array([False, False, True, True]), dtype=expected_dtype) + tm.assert_series_equal(result, expected) + + # case insensitive using regex + values = Series( + np.array(["Foo", "xYz", "fOOomMm__fOo", "MMM_"], dtype=object), + dtype=any_string_dtype, + ) + + result = values.str.contains("FOO|mmm", case=False) + expected = Series(np.array([True, False, True, True]), dtype=expected_dtype) + tm.assert_series_equal(result, expected) + + # case insensitive without regex + result = values.str.contains("foo", regex=False, case=False) + expected = Series(np.array([True, False, True, False]), dtype=expected_dtype) + tm.assert_series_equal(result, expected) + + # unicode + values = Series( + np.array(["foo", np.nan, "fooommm__foo", "mmm_"], dtype=np.object_), + dtype=any_string_dtype, + ) + pat = "mmm[_]+" + + result = values.str.contains(pat) + expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean" + expected = Series( + np.array([False, np.nan, True, True], dtype=np.object_), dtype=expected_dtype + ) + tm.assert_series_equal(result, expected) + + result = values.str.contains(pat, na=False) + expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean" + expected = Series(np.array([False, False, True, True]), dtype=expected_dtype) + tm.assert_series_equal(result, expected) + + values = Series( + np.array(["foo", "xyz", "fooommm__foo", "mmm_"], dtype=np.object_), + dtype=any_string_dtype, + ) + result = values.str.contains(pat) + expected = Series(np.array([False, False, True, True]), dtype=expected_dtype) + tm.assert_series_equal(result, expected) + + +def test_contains_object_mixed(): + mixed = Series( + np.array( + ["a", np.nan, "b", True, datetime.today(), "foo", None, 1, 2.0], + dtype=object, + ) + ) + result = mixed.str.contains("o") + expected = Series( + np.array( + [False, np.nan, False, np.nan, np.nan, True, None, np.nan, np.nan], + dtype=np.object_, + ) + ) + tm.assert_series_equal(result, expected) + + +def test_contains_na_kwarg_for_object_category(): + # gh 22158 + + # na for category + values = Series(["a", "b", "c", "a", np.nan], dtype="category") + result = values.str.contains("a", na=True) + expected = Series([True, False, False, True, True]) + tm.assert_series_equal(result, expected) + + result = values.str.contains("a", na=False) + expected = Series([True, False, False, True, False]) + tm.assert_series_equal(result, expected) + + # na for objects + values = Series(["a", "b", "c", "a", np.nan]) + result = values.str.contains("a", na=True) + expected = Series([True, False, False, True, True]) + tm.assert_series_equal(result, expected) + + result = values.str.contains("a", na=False) + expected = Series([True, False, False, True, False]) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "na, expected", + [ + (None, pd.NA), + (True, True), + (False, False), + (0, False), + (3, True), + (np.nan, pd.NA), + ], +) +@pytest.mark.parametrize("regex", [True, False]) +def test_contains_na_kwarg_for_nullable_string_dtype( + nullable_string_dtype, na, expected, regex +): + # https://github.com/pandas-dev/pandas/pull/41025#issuecomment-824062416 + + values = Series(["a", "b", "c", "a", np.nan], dtype=nullable_string_dtype) + result = values.str.contains("a", na=na, regex=regex) + expected = Series([True, False, False, True, expected], dtype="boolean") + tm.assert_series_equal(result, expected) + + +def test_contains_moar(any_string_dtype): + # PR #1179 + s = Series( + ["A", "B", "C", "Aaba", "Baca", "", np.nan, "CABA", "dog", "cat"], + dtype=any_string_dtype, + ) + + result = s.str.contains("a") + expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean" + expected = Series( + [False, False, False, True, True, False, np.nan, False, False, True], + dtype=expected_dtype, + ) + tm.assert_series_equal(result, expected) + + result = s.str.contains("a", case=False) + expected = Series( + [True, False, False, True, True, False, np.nan, True, False, True], + dtype=expected_dtype, + ) + tm.assert_series_equal(result, expected) + + result = s.str.contains("Aa") + expected = Series( + [False, False, False, True, False, False, np.nan, False, False, False], + dtype=expected_dtype, + ) + tm.assert_series_equal(result, expected) + + result = s.str.contains("ba") + expected = Series( + [False, False, False, True, False, False, np.nan, False, False, False], + dtype=expected_dtype, + ) + tm.assert_series_equal(result, expected) + + result = s.str.contains("ba", case=False) + expected = Series( + [False, False, False, True, True, False, np.nan, True, False, False], + dtype=expected_dtype, + ) + tm.assert_series_equal(result, expected) + + +def test_contains_nan(any_string_dtype): + # PR #14171 + s = Series([np.nan, np.nan, np.nan], dtype=any_string_dtype) + + result = s.str.contains("foo", na=False) + expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean" + expected = Series([False, False, False], dtype=expected_dtype) + tm.assert_series_equal(result, expected) + + result = s.str.contains("foo", na=True) + expected = Series([True, True, True], dtype=expected_dtype) + tm.assert_series_equal(result, expected) + + result = s.str.contains("foo", na="foo") + if any_string_dtype == "object": + expected = Series(["foo", "foo", "foo"], dtype=np.object_) + elif any_string_dtype == "string[pyarrow_numpy]": + expected = Series([True, True, True], dtype=np.bool_) + else: + expected = Series([True, True, True], dtype="boolean") + tm.assert_series_equal(result, expected) + + result = s.str.contains("foo") + expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean" + expected = Series([np.nan, np.nan, np.nan], dtype=expected_dtype) + tm.assert_series_equal(result, expected) + + +# -------------------------------------------------------------------------------------- +# str.startswith +# -------------------------------------------------------------------------------------- + + +@pytest.mark.parametrize("pat", ["foo", ("foo", "baz")]) +@pytest.mark.parametrize("dtype", ["object", "category"]) +@pytest.mark.parametrize("null_value", [None, np.nan, pd.NA]) +@pytest.mark.parametrize("na", [True, False]) +def test_startswith(pat, dtype, null_value, na): + # add category dtype parametrizations for GH-36241 + values = Series( + ["om", null_value, "foo_nom", "nom", "bar_foo", null_value, "foo"], + dtype=dtype, + ) + + result = values.str.startswith(pat) + exp = Series([False, np.nan, True, False, False, np.nan, True]) + if dtype == "object" and null_value is pd.NA: + # GH#18463 + exp = exp.fillna(null_value) + elif dtype == "object" and null_value is None: + exp[exp.isna()] = None + tm.assert_series_equal(result, exp) + + result = values.str.startswith(pat, na=na) + exp = Series([False, na, True, False, False, na, True]) + tm.assert_series_equal(result, exp) + + # mixed + mixed = np.array( + ["a", np.nan, "b", True, datetime.today(), "foo", None, 1, 2.0], + dtype=np.object_, + ) + rs = Series(mixed).str.startswith("f") + xp = Series([False, np.nan, False, np.nan, np.nan, True, None, np.nan, np.nan]) + tm.assert_series_equal(rs, xp) + + +@pytest.mark.parametrize("na", [None, True, False]) +def test_startswith_nullable_string_dtype(nullable_string_dtype, na): + values = Series( + ["om", None, "foo_nom", "nom", "bar_foo", None, "foo", "regex", "rege."], + dtype=nullable_string_dtype, + ) + result = values.str.startswith("foo", na=na) + exp = Series( + [False, na, True, False, False, na, True, False, False], dtype="boolean" + ) + tm.assert_series_equal(result, exp) + + result = values.str.startswith("rege.", na=na) + exp = Series( + [False, na, False, False, False, na, False, False, True], dtype="boolean" + ) + tm.assert_series_equal(result, exp) + + +# -------------------------------------------------------------------------------------- +# str.endswith +# -------------------------------------------------------------------------------------- + + +@pytest.mark.parametrize("pat", ["foo", ("foo", "baz")]) +@pytest.mark.parametrize("dtype", ["object", "category"]) +@pytest.mark.parametrize("null_value", [None, np.nan, pd.NA]) +@pytest.mark.parametrize("na", [True, False]) +def test_endswith(pat, dtype, null_value, na): + # add category dtype parametrizations for GH-36241 + values = Series( + ["om", null_value, "foo_nom", "nom", "bar_foo", null_value, "foo"], + dtype=dtype, + ) + + result = values.str.endswith(pat) + exp = Series([False, np.nan, False, False, True, np.nan, True]) + if dtype == "object" and null_value is pd.NA: + # GH#18463 + exp = exp.fillna(null_value) + elif dtype == "object" and null_value is None: + exp[exp.isna()] = None + tm.assert_series_equal(result, exp) + + result = values.str.endswith(pat, na=na) + exp = Series([False, na, False, False, True, na, True]) + tm.assert_series_equal(result, exp) + + # mixed + mixed = np.array( + ["a", np.nan, "b", True, datetime.today(), "foo", None, 1, 2.0], + dtype=object, + ) + rs = Series(mixed).str.endswith("f") + xp = Series([False, np.nan, False, np.nan, np.nan, False, None, np.nan, np.nan]) + tm.assert_series_equal(rs, xp) + + +@pytest.mark.parametrize("na", [None, True, False]) +def test_endswith_nullable_string_dtype(nullable_string_dtype, na): + values = Series( + ["om", None, "foo_nom", "nom", "bar_foo", None, "foo", "regex", "rege."], + dtype=nullable_string_dtype, + ) + result = values.str.endswith("foo", na=na) + exp = Series( + [False, na, False, False, True, na, True, False, False], dtype="boolean" + ) + tm.assert_series_equal(result, exp) + + result = values.str.endswith("rege.", na=na) + exp = Series( + [False, na, False, False, False, na, False, False, True], dtype="boolean" + ) + tm.assert_series_equal(result, exp) + + +# -------------------------------------------------------------------------------------- +# str.replace +# -------------------------------------------------------------------------------------- + + +def test_replace(any_string_dtype): + ser = Series(["fooBAD__barBAD", np.nan], dtype=any_string_dtype) + + result = ser.str.replace("BAD[_]*", "", regex=True) + expected = Series(["foobar", np.nan], dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + +def test_replace_max_replacements(any_string_dtype): + ser = Series(["fooBAD__barBAD", np.nan], dtype=any_string_dtype) + + expected = Series(["foobarBAD", np.nan], dtype=any_string_dtype) + result = ser.str.replace("BAD[_]*", "", n=1, regex=True) + tm.assert_series_equal(result, expected) + + expected = Series(["foo__barBAD", np.nan], dtype=any_string_dtype) + result = ser.str.replace("BAD", "", n=1, regex=False) + tm.assert_series_equal(result, expected) + + +def test_replace_mixed_object(): + ser = Series( + ["aBAD", np.nan, "bBAD", True, datetime.today(), "fooBAD", None, 1, 2.0] + ) + result = Series(ser).str.replace("BAD[_]*", "", regex=True) + expected = Series( + ["a", np.nan, "b", np.nan, np.nan, "foo", None, np.nan, np.nan], dtype=object + ) + tm.assert_series_equal(result, expected) + + +def test_replace_unicode(any_string_dtype): + ser = Series([b"abcd,\xc3\xa0".decode("utf-8")], dtype=any_string_dtype) + expected = Series([b"abcd, \xc3\xa0".decode("utf-8")], dtype=any_string_dtype) + with tm.maybe_produces_warning(PerformanceWarning, using_pyarrow(any_string_dtype)): + result = ser.str.replace(r"(?<=\w),(?=\w)", ", ", flags=re.UNICODE, regex=True) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("repl", [None, 3, {"a": "b"}]) +@pytest.mark.parametrize("data", [["a", "b", None], ["a", "b", "c", "ad"]]) +def test_replace_wrong_repl_type_raises(any_string_dtype, index_or_series, repl, data): + # https://github.com/pandas-dev/pandas/issues/13438 + msg = "repl must be a string or callable" + obj = index_or_series(data, dtype=any_string_dtype) + with pytest.raises(TypeError, match=msg): + obj.str.replace("a", repl) + + +def test_replace_callable(any_string_dtype): + # GH 15055 + ser = Series(["fooBAD__barBAD", np.nan], dtype=any_string_dtype) + + # test with callable + repl = lambda m: m.group(0).swapcase() + with tm.maybe_produces_warning(PerformanceWarning, using_pyarrow(any_string_dtype)): + result = ser.str.replace("[a-z][A-Z]{2}", repl, n=2, regex=True) + expected = Series(["foObaD__baRbaD", np.nan], dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "repl", [lambda: None, lambda m, x: None, lambda m, x, y=None: None] +) +def test_replace_callable_raises(any_string_dtype, repl): + # GH 15055 + values = Series(["fooBAD__barBAD", np.nan], dtype=any_string_dtype) + + # test with wrong number of arguments, raising an error + msg = ( + r"((takes)|(missing)) (?(2)from \d+ to )?\d+ " + r"(?(3)required )positional arguments?" + ) + with pytest.raises(TypeError, match=msg): + with tm.maybe_produces_warning( + PerformanceWarning, using_pyarrow(any_string_dtype) + ): + values.str.replace("a", repl, regex=True) + + +def test_replace_callable_named_groups(any_string_dtype): + # test regex named groups + ser = Series(["Foo Bar Baz", np.nan], dtype=any_string_dtype) + pat = r"(?P\w+) (?P\w+) (?P\w+)" + repl = lambda m: m.group("middle").swapcase() + with tm.maybe_produces_warning(PerformanceWarning, using_pyarrow(any_string_dtype)): + result = ser.str.replace(pat, repl, regex=True) + expected = Series(["bAR", np.nan], dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + +def test_replace_compiled_regex(any_string_dtype): + # GH 15446 + ser = Series(["fooBAD__barBAD", np.nan], dtype=any_string_dtype) + + # test with compiled regex + pat = re.compile(r"BAD_*") + with tm.maybe_produces_warning(PerformanceWarning, using_pyarrow(any_string_dtype)): + result = ser.str.replace(pat, "", regex=True) + expected = Series(["foobar", np.nan], dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + with tm.maybe_produces_warning(PerformanceWarning, using_pyarrow(any_string_dtype)): + result = ser.str.replace(pat, "", n=1, regex=True) + expected = Series(["foobarBAD", np.nan], dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + +def test_replace_compiled_regex_mixed_object(): + pat = re.compile(r"BAD_*") + ser = Series( + ["aBAD", np.nan, "bBAD", True, datetime.today(), "fooBAD", None, 1, 2.0] + ) + result = Series(ser).str.replace(pat, "", regex=True) + expected = Series( + ["a", np.nan, "b", np.nan, np.nan, "foo", None, np.nan, np.nan], dtype=object + ) + tm.assert_series_equal(result, expected) + + +def test_replace_compiled_regex_unicode(any_string_dtype): + ser = Series([b"abcd,\xc3\xa0".decode("utf-8")], dtype=any_string_dtype) + expected = Series([b"abcd, \xc3\xa0".decode("utf-8")], dtype=any_string_dtype) + pat = re.compile(r"(?<=\w),(?=\w)", flags=re.UNICODE) + with tm.maybe_produces_warning(PerformanceWarning, using_pyarrow(any_string_dtype)): + result = ser.str.replace(pat, ", ", regex=True) + tm.assert_series_equal(result, expected) + + +def test_replace_compiled_regex_raises(any_string_dtype): + # case and flags provided to str.replace will have no effect + # and will produce warnings + ser = Series(["fooBAD__barBAD__bad", np.nan], dtype=any_string_dtype) + pat = re.compile(r"BAD_*") + + msg = "case and flags cannot be set when pat is a compiled regex" + + with pytest.raises(ValueError, match=msg): + ser.str.replace(pat, "", flags=re.IGNORECASE, regex=True) + + with pytest.raises(ValueError, match=msg): + ser.str.replace(pat, "", case=False, regex=True) + + with pytest.raises(ValueError, match=msg): + ser.str.replace(pat, "", case=True, regex=True) + + +def test_replace_compiled_regex_callable(any_string_dtype): + # test with callable + ser = Series(["fooBAD__barBAD", np.nan], dtype=any_string_dtype) + repl = lambda m: m.group(0).swapcase() + pat = re.compile("[a-z][A-Z]{2}") + with tm.maybe_produces_warning(PerformanceWarning, using_pyarrow(any_string_dtype)): + result = ser.str.replace(pat, repl, n=2, regex=True) + expected = Series(["foObaD__baRbaD", np.nan], dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "regex,expected", [(True, ["bao", "bao", np.nan]), (False, ["bao", "foo", np.nan])] +) +def test_replace_literal(regex, expected, any_string_dtype): + # GH16808 literal replace (regex=False vs regex=True) + ser = Series(["f.o", "foo", np.nan], dtype=any_string_dtype) + expected = Series(expected, dtype=any_string_dtype) + result = ser.str.replace("f.", "ba", regex=regex) + tm.assert_series_equal(result, expected) + + +def test_replace_literal_callable_raises(any_string_dtype): + ser = Series([], dtype=any_string_dtype) + repl = lambda m: m.group(0).swapcase() + + msg = "Cannot use a callable replacement when regex=False" + with pytest.raises(ValueError, match=msg): + ser.str.replace("abc", repl, regex=False) + + +def test_replace_literal_compiled_raises(any_string_dtype): + ser = Series([], dtype=any_string_dtype) + pat = re.compile("[a-z][A-Z]{2}") + + msg = "Cannot use a compiled regex as replacement pattern with regex=False" + with pytest.raises(ValueError, match=msg): + ser.str.replace(pat, "", regex=False) + + +def test_replace_moar(any_string_dtype): + # PR #1179 + ser = Series( + ["A", "B", "C", "Aaba", "Baca", "", np.nan, "CABA", "dog", "cat"], + dtype=any_string_dtype, + ) + + result = ser.str.replace("A", "YYY") + expected = Series( + ["YYY", "B", "C", "YYYaba", "Baca", "", np.nan, "CYYYBYYY", "dog", "cat"], + dtype=any_string_dtype, + ) + tm.assert_series_equal(result, expected) + + with tm.maybe_produces_warning(PerformanceWarning, using_pyarrow(any_string_dtype)): + result = ser.str.replace("A", "YYY", case=False) + expected = Series( + [ + "YYY", + "B", + "C", + "YYYYYYbYYY", + "BYYYcYYY", + "", + np.nan, + "CYYYBYYY", + "dog", + "cYYYt", + ], + dtype=any_string_dtype, + ) + tm.assert_series_equal(result, expected) + + with tm.maybe_produces_warning(PerformanceWarning, using_pyarrow(any_string_dtype)): + result = ser.str.replace("^.a|dog", "XX-XX ", case=False, regex=True) + expected = Series( + [ + "A", + "B", + "C", + "XX-XX ba", + "XX-XX ca", + "", + np.nan, + "XX-XX BA", + "XX-XX ", + "XX-XX t", + ], + dtype=any_string_dtype, + ) + tm.assert_series_equal(result, expected) + + +def test_replace_not_case_sensitive_not_regex(any_string_dtype): + # https://github.com/pandas-dev/pandas/issues/41602 + ser = Series(["A.", "a.", "Ab", "ab", np.nan], dtype=any_string_dtype) + + with tm.maybe_produces_warning(PerformanceWarning, using_pyarrow(any_string_dtype)): + result = ser.str.replace("a", "c", case=False, regex=False) + expected = Series(["c.", "c.", "cb", "cb", np.nan], dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + with tm.maybe_produces_warning(PerformanceWarning, using_pyarrow(any_string_dtype)): + result = ser.str.replace("a.", "c.", case=False, regex=False) + expected = Series(["c.", "c.", "Ab", "ab", np.nan], dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + +def test_replace_regex(any_string_dtype): + # https://github.com/pandas-dev/pandas/pull/24809 + s = Series(["a", "b", "ac", np.nan, ""], dtype=any_string_dtype) + result = s.str.replace("^.$", "a", regex=True) + expected = Series(["a", "a", "ac", np.nan, ""], dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("regex", [True, False]) +def test_replace_regex_single_character(regex, any_string_dtype): + # https://github.com/pandas-dev/pandas/pull/24809, enforced in 2.0 + # GH 24804 + s = Series(["a.b", ".", "b", np.nan, ""], dtype=any_string_dtype) + + result = s.str.replace(".", "a", regex=regex) + if regex: + expected = Series(["aaa", "a", "a", np.nan, ""], dtype=any_string_dtype) + else: + expected = Series(["aab", "a", "b", np.nan, ""], dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + +# -------------------------------------------------------------------------------------- +# str.match +# -------------------------------------------------------------------------------------- + + +def test_match(any_string_dtype): + # New match behavior introduced in 0.13 + expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean" + + values = Series(["fooBAD__barBAD", np.nan, "foo"], dtype=any_string_dtype) + result = values.str.match(".*(BAD[_]+).*(BAD)") + expected = Series([True, np.nan, False], dtype=expected_dtype) + tm.assert_series_equal(result, expected) + + values = Series( + ["fooBAD__barBAD", "BAD_BADleroybrown", np.nan, "foo"], dtype=any_string_dtype + ) + result = values.str.match(".*BAD[_]+.*BAD") + expected = Series([True, True, np.nan, False], dtype=expected_dtype) + tm.assert_series_equal(result, expected) + + result = values.str.match("BAD[_]+.*BAD") + expected = Series([False, True, np.nan, False], dtype=expected_dtype) + tm.assert_series_equal(result, expected) + + values = Series( + ["fooBAD__barBAD", "^BAD_BADleroybrown", np.nan, "foo"], dtype=any_string_dtype + ) + result = values.str.match("^BAD[_]+.*BAD") + expected = Series([False, False, np.nan, False], dtype=expected_dtype) + tm.assert_series_equal(result, expected) + + result = values.str.match("\\^BAD[_]+.*BAD") + expected = Series([False, True, np.nan, False], dtype=expected_dtype) + tm.assert_series_equal(result, expected) + + +def test_match_mixed_object(): + mixed = Series( + [ + "aBAD_BAD", + np.nan, + "BAD_b_BAD", + True, + datetime.today(), + "foo", + None, + 1, + 2.0, + ] + ) + result = Series(mixed).str.match(".*(BAD[_]+).*(BAD)") + expected = Series([True, np.nan, True, np.nan, np.nan, False, None, np.nan, np.nan]) + assert isinstance(result, Series) + tm.assert_series_equal(result, expected) + + +def test_match_na_kwarg(any_string_dtype): + # GH #6609 + s = Series(["a", "b", np.nan], dtype=any_string_dtype) + + result = s.str.match("a", na=False) + expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean" + expected = Series([True, False, False], dtype=expected_dtype) + tm.assert_series_equal(result, expected) + + result = s.str.match("a") + expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean" + expected = Series([True, False, np.nan], dtype=expected_dtype) + tm.assert_series_equal(result, expected) + + +def test_match_case_kwarg(any_string_dtype): + values = Series(["ab", "AB", "abc", "ABC"], dtype=any_string_dtype) + result = values.str.match("ab", case=False) + expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean" + expected = Series([True, True, True, True], dtype=expected_dtype) + tm.assert_series_equal(result, expected) + + +# -------------------------------------------------------------------------------------- +# str.fullmatch +# -------------------------------------------------------------------------------------- + + +def test_fullmatch(any_string_dtype): + # GH 32806 + ser = Series( + ["fooBAD__barBAD", "BAD_BADleroybrown", np.nan, "foo"], dtype=any_string_dtype + ) + result = ser.str.fullmatch(".*BAD[_]+.*BAD") + expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean" + expected = Series([True, False, np.nan, False], dtype=expected_dtype) + tm.assert_series_equal(result, expected) + + +def test_fullmatch_dollar_literal(any_string_dtype): + # GH 56652 + ser = Series(["foo", "foo$foo", np.nan, "foo$"], dtype=any_string_dtype) + result = ser.str.fullmatch("foo\\$") + expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean" + expected = Series([False, False, np.nan, True], dtype=expected_dtype) + tm.assert_series_equal(result, expected) + + +def test_fullmatch_na_kwarg(any_string_dtype): + ser = Series( + ["fooBAD__barBAD", "BAD_BADleroybrown", np.nan, "foo"], dtype=any_string_dtype + ) + result = ser.str.fullmatch(".*BAD[_]+.*BAD", na=False) + expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean" + expected = Series([True, False, False, False], dtype=expected_dtype) + tm.assert_series_equal(result, expected) + + +def test_fullmatch_case_kwarg(any_string_dtype): + ser = Series(["ab", "AB", "abc", "ABC"], dtype=any_string_dtype) + expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean" + + expected = Series([True, False, False, False], dtype=expected_dtype) + + result = ser.str.fullmatch("ab", case=True) + tm.assert_series_equal(result, expected) + + expected = Series([True, True, False, False], dtype=expected_dtype) + + result = ser.str.fullmatch("ab", case=False) + tm.assert_series_equal(result, expected) + + with tm.maybe_produces_warning(PerformanceWarning, using_pyarrow(any_string_dtype)): + result = ser.str.fullmatch("ab", flags=re.IGNORECASE) + tm.assert_series_equal(result, expected) + + +# -------------------------------------------------------------------------------------- +# str.findall +# -------------------------------------------------------------------------------------- + + +def test_findall(any_string_dtype): + ser = Series(["fooBAD__barBAD", np.nan, "foo", "BAD"], dtype=any_string_dtype) + result = ser.str.findall("BAD[_]*") + expected = Series([["BAD__", "BAD"], np.nan, [], ["BAD"]]) + expected = _convert_na_value(ser, expected) + tm.assert_series_equal(result, expected) + + +def test_findall_mixed_object(): + ser = Series( + [ + "fooBAD__barBAD", + np.nan, + "foo", + True, + datetime.today(), + "BAD", + None, + 1, + 2.0, + ] + ) + + result = ser.str.findall("BAD[_]*") + expected = Series( + [ + ["BAD__", "BAD"], + np.nan, + [], + np.nan, + np.nan, + ["BAD"], + None, + np.nan, + np.nan, + ] + ) + + tm.assert_series_equal(result, expected) + + +# -------------------------------------------------------------------------------------- +# str.find +# -------------------------------------------------------------------------------------- + + +def test_find(any_string_dtype): + ser = Series( + ["ABCDEFG", "BCDEFEF", "DEFGHIJEF", "EFGHEF", "XXXX"], dtype=any_string_dtype + ) + expected_dtype = np.int64 if any_string_dtype in object_pyarrow_numpy else "Int64" + + result = ser.str.find("EF") + expected = Series([4, 3, 1, 0, -1], dtype=expected_dtype) + tm.assert_series_equal(result, expected) + expected = np.array([v.find("EF") for v in np.array(ser)], dtype=np.int64) + tm.assert_numpy_array_equal(np.array(result, dtype=np.int64), expected) + + result = ser.str.rfind("EF") + expected = Series([4, 5, 7, 4, -1], dtype=expected_dtype) + tm.assert_series_equal(result, expected) + expected = np.array([v.rfind("EF") for v in np.array(ser)], dtype=np.int64) + tm.assert_numpy_array_equal(np.array(result, dtype=np.int64), expected) + + result = ser.str.find("EF", 3) + expected = Series([4, 3, 7, 4, -1], dtype=expected_dtype) + tm.assert_series_equal(result, expected) + expected = np.array([v.find("EF", 3) for v in np.array(ser)], dtype=np.int64) + tm.assert_numpy_array_equal(np.array(result, dtype=np.int64), expected) + + result = ser.str.rfind("EF", 3) + expected = Series([4, 5, 7, 4, -1], dtype=expected_dtype) + tm.assert_series_equal(result, expected) + expected = np.array([v.rfind("EF", 3) for v in np.array(ser)], dtype=np.int64) + tm.assert_numpy_array_equal(np.array(result, dtype=np.int64), expected) + + result = ser.str.find("EF", 3, 6) + expected = Series([4, 3, -1, 4, -1], dtype=expected_dtype) + tm.assert_series_equal(result, expected) + expected = np.array([v.find("EF", 3, 6) for v in np.array(ser)], dtype=np.int64) + tm.assert_numpy_array_equal(np.array(result, dtype=np.int64), expected) + + result = ser.str.rfind("EF", 3, 6) + expected = Series([4, 3, -1, 4, -1], dtype=expected_dtype) + tm.assert_series_equal(result, expected) + expected = np.array([v.rfind("EF", 3, 6) for v in np.array(ser)], dtype=np.int64) + tm.assert_numpy_array_equal(np.array(result, dtype=np.int64), expected) + + +def test_find_bad_arg_raises(any_string_dtype): + ser = Series([], dtype=any_string_dtype) + with pytest.raises(TypeError, match="expected a string object, not int"): + ser.str.find(0) + + with pytest.raises(TypeError, match="expected a string object, not int"): + ser.str.rfind(0) + + +def test_find_nan(any_string_dtype): + ser = Series( + ["ABCDEFG", np.nan, "DEFGHIJEF", np.nan, "XXXX"], dtype=any_string_dtype + ) + expected_dtype = np.float64 if any_string_dtype in object_pyarrow_numpy else "Int64" + + result = ser.str.find("EF") + expected = Series([4, np.nan, 1, np.nan, -1], dtype=expected_dtype) + tm.assert_series_equal(result, expected) + + result = ser.str.rfind("EF") + expected = Series([4, np.nan, 7, np.nan, -1], dtype=expected_dtype) + tm.assert_series_equal(result, expected) + + result = ser.str.find("EF", 3) + expected = Series([4, np.nan, 7, np.nan, -1], dtype=expected_dtype) + tm.assert_series_equal(result, expected) + + result = ser.str.rfind("EF", 3) + expected = Series([4, np.nan, 7, np.nan, -1], dtype=expected_dtype) + tm.assert_series_equal(result, expected) + + result = ser.str.find("EF", 3, 6) + expected = Series([4, np.nan, -1, np.nan, -1], dtype=expected_dtype) + tm.assert_series_equal(result, expected) + + result = ser.str.rfind("EF", 3, 6) + expected = Series([4, np.nan, -1, np.nan, -1], dtype=expected_dtype) + tm.assert_series_equal(result, expected) + + +# -------------------------------------------------------------------------------------- +# str.translate +# -------------------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "infer_string", [False, pytest.param(True, marks=td.skip_if_no("pyarrow"))] +) +def test_translate(index_or_series, any_string_dtype, infer_string): + obj = index_or_series( + ["abcdefg", "abcc", "cdddfg", "cdefggg"], dtype=any_string_dtype + ) + table = str.maketrans("abc", "cde") + result = obj.str.translate(table) + expected = index_or_series( + ["cdedefg", "cdee", "edddfg", "edefggg"], dtype=any_string_dtype + ) + tm.assert_equal(result, expected) + + +def test_translate_mixed_object(): + # Series with non-string values + s = Series(["a", "b", "c", 1.2]) + table = str.maketrans("abc", "cde") + expected = Series(["c", "d", "e", np.nan], dtype=object) + result = s.str.translate(table) + tm.assert_series_equal(result, expected) + + +# -------------------------------------------------------------------------------------- + + +def test_flags_kwarg(any_string_dtype): + data = { + "Dave": "dave@google.com", + "Steve": "steve@gmail.com", + "Rob": "rob@gmail.com", + "Wes": np.nan, + } + data = Series(data, dtype=any_string_dtype) + + pat = r"([A-Z0-9._%+-]+)@([A-Z0-9.-]+)\.([A-Z]{2,4})" + + use_pyarrow = using_pyarrow(any_string_dtype) + + result = data.str.extract(pat, flags=re.IGNORECASE, expand=True) + assert result.iloc[0].tolist() == ["dave", "google", "com"] + + with tm.maybe_produces_warning(PerformanceWarning, use_pyarrow): + result = data.str.match(pat, flags=re.IGNORECASE) + assert result.iloc[0] + + with tm.maybe_produces_warning(PerformanceWarning, use_pyarrow): + result = data.str.fullmatch(pat, flags=re.IGNORECASE) + assert result.iloc[0] + + result = data.str.findall(pat, flags=re.IGNORECASE) + assert result.iloc[0][0] == ("dave", "google", "com") + + result = data.str.count(pat, flags=re.IGNORECASE) + assert result.iloc[0] == 1 + + msg = "has match groups" + with tm.assert_produces_warning( + UserWarning, match=msg, raise_on_extra_warnings=not use_pyarrow + ): + result = data.str.contains(pat, flags=re.IGNORECASE) + assert result.iloc[0] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/strings/test_get_dummies.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/strings/test_get_dummies.py new file mode 100644 index 0000000000000000000000000000000000000000..31386e4e342ae3676a5468cfff5035686821fd52 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/strings/test_get_dummies.py @@ -0,0 +1,53 @@ +import numpy as np + +from pandas import ( + DataFrame, + Index, + MultiIndex, + Series, + _testing as tm, +) + + +def test_get_dummies(any_string_dtype): + s = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype) + result = s.str.get_dummies("|") + expected = DataFrame([[1, 1, 0], [1, 0, 1], [0, 0, 0]], columns=list("abc")) + tm.assert_frame_equal(result, expected) + + s = Series(["a;b", "a", 7], dtype=any_string_dtype) + result = s.str.get_dummies(";") + expected = DataFrame([[0, 1, 1], [0, 1, 0], [1, 0, 0]], columns=list("7ab")) + tm.assert_frame_equal(result, expected) + + +def test_get_dummies_index(): + # GH9980, GH8028 + idx = Index(["a|b", "a|c", "b|c"]) + result = idx.str.get_dummies("|") + + expected = MultiIndex.from_tuples( + [(1, 1, 0), (1, 0, 1), (0, 1, 1)], names=("a", "b", "c") + ) + tm.assert_index_equal(result, expected) + + +def test_get_dummies_with_name_dummy(any_string_dtype): + # GH 12180 + # Dummies named 'name' should work as expected + s = Series(["a", "b,name", "b"], dtype=any_string_dtype) + result = s.str.get_dummies(",") + expected = DataFrame([[1, 0, 0], [0, 1, 1], [0, 1, 0]], columns=["a", "b", "name"]) + tm.assert_frame_equal(result, expected) + + +def test_get_dummies_with_name_dummy_index(): + # GH 12180 + # Dummies named 'name' should work as expected + idx = Index(["a|b", "name|c", "b|name"]) + result = idx.str.get_dummies("|") + + expected = MultiIndex.from_tuples( + [(1, 1, 0, 0), (0, 0, 1, 1), (0, 1, 0, 1)], names=("a", "b", "c", "name") + ) + tm.assert_index_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/strings/test_split_partition.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/strings/test_split_partition.py new file mode 100644 index 0000000000000000000000000000000000000000..9ff1fc0e13ae9ed3514fe02928049d19c0c275a9 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/strings/test_split_partition.py @@ -0,0 +1,734 @@ +from datetime import datetime +import re + +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + DataFrame, + Index, + MultiIndex, + Series, + _testing as tm, +) +from pandas.tests.strings import ( + _convert_na_value, + object_pyarrow_numpy, +) + + +@pytest.mark.parametrize("method", ["split", "rsplit"]) +def test_split(any_string_dtype, method): + values = Series(["a_b_c", "c_d_e", np.nan, "f_g_h"], dtype=any_string_dtype) + + result = getattr(values.str, method)("_") + exp = Series([["a", "b", "c"], ["c", "d", "e"], np.nan, ["f", "g", "h"]]) + exp = _convert_na_value(values, exp) + tm.assert_series_equal(result, exp) + + +@pytest.mark.parametrize("method", ["split", "rsplit"]) +def test_split_more_than_one_char(any_string_dtype, method): + # more than one char + values = Series(["a__b__c", "c__d__e", np.nan, "f__g__h"], dtype=any_string_dtype) + result = getattr(values.str, method)("__") + exp = Series([["a", "b", "c"], ["c", "d", "e"], np.nan, ["f", "g", "h"]]) + exp = _convert_na_value(values, exp) + tm.assert_series_equal(result, exp) + + result = getattr(values.str, method)("__", expand=False) + tm.assert_series_equal(result, exp) + + +def test_split_more_regex_split(any_string_dtype): + # regex split + values = Series(["a,b_c", "c_d,e", np.nan, "f,g,h"], dtype=any_string_dtype) + result = values.str.split("[,_]") + exp = Series([["a", "b", "c"], ["c", "d", "e"], np.nan, ["f", "g", "h"]]) + exp = _convert_na_value(values, exp) + tm.assert_series_equal(result, exp) + + +def test_split_regex(any_string_dtype): + # GH 43563 + # explicit regex = True split + values = Series("xxxjpgzzz.jpg", dtype=any_string_dtype) + result = values.str.split(r"\.jpg", regex=True) + exp = Series([["xxxjpgzzz", ""]]) + tm.assert_series_equal(result, exp) + + +def test_split_regex_explicit(any_string_dtype): + # explicit regex = True split with compiled regex + regex_pat = re.compile(r".jpg") + values = Series("xxxjpgzzz.jpg", dtype=any_string_dtype) + result = values.str.split(regex_pat) + exp = Series([["xx", "zzz", ""]]) + tm.assert_series_equal(result, exp) + + # explicit regex = False split + result = values.str.split(r"\.jpg", regex=False) + exp = Series([["xxxjpgzzz.jpg"]]) + tm.assert_series_equal(result, exp) + + # non explicit regex split, pattern length == 1 + result = values.str.split(r".") + exp = Series([["xxxjpgzzz", "jpg"]]) + tm.assert_series_equal(result, exp) + + # non explicit regex split, pattern length != 1 + result = values.str.split(r".jpg") + exp = Series([["xx", "zzz", ""]]) + tm.assert_series_equal(result, exp) + + # regex=False with pattern compiled regex raises error + with pytest.raises( + ValueError, + match="Cannot use a compiled regex as replacement pattern with regex=False", + ): + values.str.split(regex_pat, regex=False) + + +@pytest.mark.parametrize("expand", [None, False]) +@pytest.mark.parametrize("method", ["split", "rsplit"]) +def test_split_object_mixed(expand, method): + mixed = Series(["a_b_c", np.nan, "d_e_f", True, datetime.today(), None, 1, 2.0]) + result = getattr(mixed.str, method)("_", expand=expand) + exp = Series( + [ + ["a", "b", "c"], + np.nan, + ["d", "e", "f"], + np.nan, + np.nan, + None, + np.nan, + np.nan, + ] + ) + assert isinstance(result, Series) + tm.assert_almost_equal(result, exp) + + +@pytest.mark.parametrize("method", ["split", "rsplit"]) +@pytest.mark.parametrize("n", [None, 0]) +def test_split_n(any_string_dtype, method, n): + s = Series(["a b", pd.NA, "b c"], dtype=any_string_dtype) + expected = Series([["a", "b"], pd.NA, ["b", "c"]]) + result = getattr(s.str, method)(" ", n=n) + expected = _convert_na_value(s, expected) + tm.assert_series_equal(result, expected) + + +def test_rsplit(any_string_dtype): + # regex split is not supported by rsplit + values = Series(["a,b_c", "c_d,e", np.nan, "f,g,h"], dtype=any_string_dtype) + result = values.str.rsplit("[,_]") + exp = Series([["a,b_c"], ["c_d,e"], np.nan, ["f,g,h"]]) + exp = _convert_na_value(values, exp) + tm.assert_series_equal(result, exp) + + +def test_rsplit_max_number(any_string_dtype): + # setting max number of splits, make sure it's from reverse + values = Series(["a_b_c", "c_d_e", np.nan, "f_g_h"], dtype=any_string_dtype) + result = values.str.rsplit("_", n=1) + exp = Series([["a_b", "c"], ["c_d", "e"], np.nan, ["f_g", "h"]]) + exp = _convert_na_value(values, exp) + tm.assert_series_equal(result, exp) + + +def test_split_blank_string(any_string_dtype): + # expand blank split GH 20067 + values = Series([""], name="test", dtype=any_string_dtype) + result = values.str.split(expand=True) + exp = DataFrame([[]], dtype=any_string_dtype) # NOTE: this is NOT an empty df + tm.assert_frame_equal(result, exp) + + +def test_split_blank_string_with_non_empty(any_string_dtype): + values = Series(["a b c", "a b", "", " "], name="test", dtype=any_string_dtype) + result = values.str.split(expand=True) + exp = DataFrame( + [ + ["a", "b", "c"], + ["a", "b", None], + [None, None, None], + [None, None, None], + ], + dtype=any_string_dtype, + ) + tm.assert_frame_equal(result, exp) + + +@pytest.mark.parametrize("method", ["split", "rsplit"]) +def test_split_noargs(any_string_dtype, method): + # #1859 + s = Series(["Wes McKinney", "Travis Oliphant"], dtype=any_string_dtype) + result = getattr(s.str, method)() + expected = ["Travis", "Oliphant"] + assert result[1] == expected + + +@pytest.mark.parametrize( + "data, pat", + [ + (["bd asdf jfg", "kjasdflqw asdfnfk"], None), + (["bd asdf jfg", "kjasdflqw asdfnfk"], "asdf"), + (["bd_asdf_jfg", "kjasdflqw_asdfnfk"], "_"), + ], +) +@pytest.mark.parametrize("n", [-1, 0]) +def test_split_maxsplit(data, pat, any_string_dtype, n): + # re.split 0, str.split -1 + s = Series(data, dtype=any_string_dtype) + + result = s.str.split(pat=pat, n=n) + xp = s.str.split(pat=pat) + tm.assert_series_equal(result, xp) + + +@pytest.mark.parametrize( + "data, pat, expected", + [ + ( + ["split once", "split once too!"], + None, + Series({0: ["split", "once"], 1: ["split", "once too!"]}), + ), + ( + ["split_once", "split_once_too!"], + "_", + Series({0: ["split", "once"], 1: ["split", "once_too!"]}), + ), + ], +) +def test_split_no_pat_with_nonzero_n(data, pat, expected, any_string_dtype): + s = Series(data, dtype=any_string_dtype) + result = s.str.split(pat=pat, n=1) + tm.assert_series_equal(expected, result, check_index_type=False) + + +def test_split_to_dataframe_no_splits(any_string_dtype): + s = Series(["nosplit", "alsonosplit"], dtype=any_string_dtype) + result = s.str.split("_", expand=True) + exp = DataFrame({0: Series(["nosplit", "alsonosplit"], dtype=any_string_dtype)}) + tm.assert_frame_equal(result, exp) + + +def test_split_to_dataframe(any_string_dtype): + s = Series(["some_equal_splits", "with_no_nans"], dtype=any_string_dtype) + result = s.str.split("_", expand=True) + exp = DataFrame( + {0: ["some", "with"], 1: ["equal", "no"], 2: ["splits", "nans"]}, + dtype=any_string_dtype, + ) + tm.assert_frame_equal(result, exp) + + +def test_split_to_dataframe_unequal_splits(any_string_dtype): + s = Series( + ["some_unequal_splits", "one_of_these_things_is_not"], dtype=any_string_dtype + ) + result = s.str.split("_", expand=True) + exp = DataFrame( + { + 0: ["some", "one"], + 1: ["unequal", "of"], + 2: ["splits", "these"], + 3: [None, "things"], + 4: [None, "is"], + 5: [None, "not"], + }, + dtype=any_string_dtype, + ) + tm.assert_frame_equal(result, exp) + + +def test_split_to_dataframe_with_index(any_string_dtype): + s = Series( + ["some_splits", "with_index"], index=["preserve", "me"], dtype=any_string_dtype + ) + result = s.str.split("_", expand=True) + exp = DataFrame( + {0: ["some", "with"], 1: ["splits", "index"]}, + index=["preserve", "me"], + dtype=any_string_dtype, + ) + tm.assert_frame_equal(result, exp) + + with pytest.raises(ValueError, match="expand must be"): + s.str.split("_", expand="not_a_boolean") + + +def test_split_to_multiindex_expand_no_splits(): + # https://github.com/pandas-dev/pandas/issues/23677 + + idx = Index(["nosplit", "alsonosplit", np.nan]) + result = idx.str.split("_", expand=True) + exp = idx + tm.assert_index_equal(result, exp) + assert result.nlevels == 1 + + +def test_split_to_multiindex_expand(): + idx = Index(["some_equal_splits", "with_no_nans", np.nan, None]) + result = idx.str.split("_", expand=True) + exp = MultiIndex.from_tuples( + [ + ("some", "equal", "splits"), + ("with", "no", "nans"), + [np.nan, np.nan, np.nan], + [None, None, None], + ] + ) + tm.assert_index_equal(result, exp) + assert result.nlevels == 3 + + +def test_split_to_multiindex_expand_unequal_splits(): + idx = Index(["some_unequal_splits", "one_of_these_things_is_not", np.nan, None]) + result = idx.str.split("_", expand=True) + exp = MultiIndex.from_tuples( + [ + ("some", "unequal", "splits", np.nan, np.nan, np.nan), + ("one", "of", "these", "things", "is", "not"), + (np.nan, np.nan, np.nan, np.nan, np.nan, np.nan), + (None, None, None, None, None, None), + ] + ) + tm.assert_index_equal(result, exp) + assert result.nlevels == 6 + + with pytest.raises(ValueError, match="expand must be"): + idx.str.split("_", expand="not_a_boolean") + + +def test_rsplit_to_dataframe_expand_no_splits(any_string_dtype): + s = Series(["nosplit", "alsonosplit"], dtype=any_string_dtype) + result = s.str.rsplit("_", expand=True) + exp = DataFrame({0: Series(["nosplit", "alsonosplit"])}, dtype=any_string_dtype) + tm.assert_frame_equal(result, exp) + + +def test_rsplit_to_dataframe_expand(any_string_dtype): + s = Series(["some_equal_splits", "with_no_nans"], dtype=any_string_dtype) + result = s.str.rsplit("_", expand=True) + exp = DataFrame( + {0: ["some", "with"], 1: ["equal", "no"], 2: ["splits", "nans"]}, + dtype=any_string_dtype, + ) + tm.assert_frame_equal(result, exp) + + result = s.str.rsplit("_", expand=True, n=2) + exp = DataFrame( + {0: ["some", "with"], 1: ["equal", "no"], 2: ["splits", "nans"]}, + dtype=any_string_dtype, + ) + tm.assert_frame_equal(result, exp) + + result = s.str.rsplit("_", expand=True, n=1) + exp = DataFrame( + {0: ["some_equal", "with_no"], 1: ["splits", "nans"]}, dtype=any_string_dtype + ) + tm.assert_frame_equal(result, exp) + + +def test_rsplit_to_dataframe_expand_with_index(any_string_dtype): + s = Series( + ["some_splits", "with_index"], index=["preserve", "me"], dtype=any_string_dtype + ) + result = s.str.rsplit("_", expand=True) + exp = DataFrame( + {0: ["some", "with"], 1: ["splits", "index"]}, + index=["preserve", "me"], + dtype=any_string_dtype, + ) + tm.assert_frame_equal(result, exp) + + +def test_rsplit_to_multiindex_expand_no_split(): + idx = Index(["nosplit", "alsonosplit"]) + result = idx.str.rsplit("_", expand=True) + exp = idx + tm.assert_index_equal(result, exp) + assert result.nlevels == 1 + + +def test_rsplit_to_multiindex_expand(): + idx = Index(["some_equal_splits", "with_no_nans"]) + result = idx.str.rsplit("_", expand=True) + exp = MultiIndex.from_tuples([("some", "equal", "splits"), ("with", "no", "nans")]) + tm.assert_index_equal(result, exp) + assert result.nlevels == 3 + + +def test_rsplit_to_multiindex_expand_n(): + idx = Index(["some_equal_splits", "with_no_nans"]) + result = idx.str.rsplit("_", expand=True, n=1) + exp = MultiIndex.from_tuples([("some_equal", "splits"), ("with_no", "nans")]) + tm.assert_index_equal(result, exp) + assert result.nlevels == 2 + + +def test_split_nan_expand(any_string_dtype): + # gh-18450 + s = Series(["foo,bar,baz", np.nan], dtype=any_string_dtype) + result = s.str.split(",", expand=True) + exp = DataFrame( + [["foo", "bar", "baz"], [np.nan, np.nan, np.nan]], dtype=any_string_dtype + ) + tm.assert_frame_equal(result, exp) + + # check that these are actually np.nan/pd.NA and not None + # TODO see GH 18463 + # tm.assert_frame_equal does not differentiate + if any_string_dtype in object_pyarrow_numpy: + assert all(np.isnan(x) for x in result.iloc[1]) + else: + assert all(x is pd.NA for x in result.iloc[1]) + + +def test_split_with_name_series(any_string_dtype): + # GH 12617 + + # should preserve name + s = Series(["a,b", "c,d"], name="xxx", dtype=any_string_dtype) + res = s.str.split(",") + exp = Series([["a", "b"], ["c", "d"]], name="xxx") + tm.assert_series_equal(res, exp) + + res = s.str.split(",", expand=True) + exp = DataFrame([["a", "b"], ["c", "d"]], dtype=any_string_dtype) + tm.assert_frame_equal(res, exp) + + +def test_split_with_name_index(): + # GH 12617 + idx = Index(["a,b", "c,d"], name="xxx") + res = idx.str.split(",") + exp = Index([["a", "b"], ["c", "d"]], name="xxx") + assert res.nlevels == 1 + tm.assert_index_equal(res, exp) + + res = idx.str.split(",", expand=True) + exp = MultiIndex.from_tuples([("a", "b"), ("c", "d")]) + assert res.nlevels == 2 + tm.assert_index_equal(res, exp) + + +@pytest.mark.parametrize( + "method, exp", + [ + [ + "partition", + [ + ("a", "__", "b__c"), + ("c", "__", "d__e"), + np.nan, + ("f", "__", "g__h"), + None, + ], + ], + [ + "rpartition", + [ + ("a__b", "__", "c"), + ("c__d", "__", "e"), + np.nan, + ("f__g", "__", "h"), + None, + ], + ], + ], +) +def test_partition_series_more_than_one_char(method, exp, any_string_dtype): + # https://github.com/pandas-dev/pandas/issues/23558 + # more than one char + s = Series(["a__b__c", "c__d__e", np.nan, "f__g__h", None], dtype=any_string_dtype) + result = getattr(s.str, method)("__", expand=False) + expected = Series(exp) + expected = _convert_na_value(s, expected) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "method, exp", + [ + [ + "partition", + [("a", " ", "b c"), ("c", " ", "d e"), np.nan, ("f", " ", "g h"), None], + ], + [ + "rpartition", + [("a b", " ", "c"), ("c d", " ", "e"), np.nan, ("f g", " ", "h"), None], + ], + ], +) +def test_partition_series_none(any_string_dtype, method, exp): + # https://github.com/pandas-dev/pandas/issues/23558 + # None + s = Series(["a b c", "c d e", np.nan, "f g h", None], dtype=any_string_dtype) + result = getattr(s.str, method)(expand=False) + expected = Series(exp) + expected = _convert_na_value(s, expected) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "method, exp", + [ + [ + "partition", + [("abc", "", ""), ("cde", "", ""), np.nan, ("fgh", "", ""), None], + ], + [ + "rpartition", + [("", "", "abc"), ("", "", "cde"), np.nan, ("", "", "fgh"), None], + ], + ], +) +def test_partition_series_not_split(any_string_dtype, method, exp): + # https://github.com/pandas-dev/pandas/issues/23558 + # Not split + s = Series(["abc", "cde", np.nan, "fgh", None], dtype=any_string_dtype) + result = getattr(s.str, method)("_", expand=False) + expected = Series(exp) + expected = _convert_na_value(s, expected) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "method, exp", + [ + [ + "partition", + [("a", "_", "b_c"), ("c", "_", "d_e"), np.nan, ("f", "_", "g_h")], + ], + [ + "rpartition", + [("a_b", "_", "c"), ("c_d", "_", "e"), np.nan, ("f_g", "_", "h")], + ], + ], +) +def test_partition_series_unicode(any_string_dtype, method, exp): + # https://github.com/pandas-dev/pandas/issues/23558 + # unicode + s = Series(["a_b_c", "c_d_e", np.nan, "f_g_h"], dtype=any_string_dtype) + + result = getattr(s.str, method)("_", expand=False) + expected = Series(exp) + expected = _convert_na_value(s, expected) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("method", ["partition", "rpartition"]) +def test_partition_series_stdlib(any_string_dtype, method): + # https://github.com/pandas-dev/pandas/issues/23558 + # compare to standard lib + s = Series(["A_B_C", "B_C_D", "E_F_G", "EFGHEF"], dtype=any_string_dtype) + result = getattr(s.str, method)("_", expand=False).tolist() + assert result == [getattr(v, method)("_") for v in s] + + +@pytest.mark.parametrize( + "method, expand, exp, exp_levels", + [ + [ + "partition", + False, + np.array( + [("a", "_", "b_c"), ("c", "_", "d_e"), ("f", "_", "g_h"), np.nan, None], + dtype=object, + ), + 1, + ], + [ + "rpartition", + False, + np.array( + [("a_b", "_", "c"), ("c_d", "_", "e"), ("f_g", "_", "h"), np.nan, None], + dtype=object, + ), + 1, + ], + ], +) +def test_partition_index(method, expand, exp, exp_levels): + # https://github.com/pandas-dev/pandas/issues/23558 + + values = Index(["a_b_c", "c_d_e", "f_g_h", np.nan, None]) + + result = getattr(values.str, method)("_", expand=expand) + exp = Index(exp) + tm.assert_index_equal(result, exp) + assert result.nlevels == exp_levels + + +@pytest.mark.parametrize( + "method, exp", + [ + [ + "partition", + { + 0: ["a", "c", np.nan, "f", None], + 1: ["_", "_", np.nan, "_", None], + 2: ["b_c", "d_e", np.nan, "g_h", None], + }, + ], + [ + "rpartition", + { + 0: ["a_b", "c_d", np.nan, "f_g", None], + 1: ["_", "_", np.nan, "_", None], + 2: ["c", "e", np.nan, "h", None], + }, + ], + ], +) +def test_partition_to_dataframe(any_string_dtype, method, exp): + # https://github.com/pandas-dev/pandas/issues/23558 + + s = Series(["a_b_c", "c_d_e", np.nan, "f_g_h", None], dtype=any_string_dtype) + result = getattr(s.str, method)("_") + expected = DataFrame( + exp, + dtype=any_string_dtype, + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "method, exp", + [ + [ + "partition", + { + 0: ["a", "c", np.nan, "f", None], + 1: ["_", "_", np.nan, "_", None], + 2: ["b_c", "d_e", np.nan, "g_h", None], + }, + ], + [ + "rpartition", + { + 0: ["a_b", "c_d", np.nan, "f_g", None], + 1: ["_", "_", np.nan, "_", None], + 2: ["c", "e", np.nan, "h", None], + }, + ], + ], +) +def test_partition_to_dataframe_from_series(any_string_dtype, method, exp): + # https://github.com/pandas-dev/pandas/issues/23558 + s = Series(["a_b_c", "c_d_e", np.nan, "f_g_h", None], dtype=any_string_dtype) + result = getattr(s.str, method)("_", expand=True) + expected = DataFrame( + exp, + dtype=any_string_dtype, + ) + tm.assert_frame_equal(result, expected) + + +def test_partition_with_name(any_string_dtype): + # GH 12617 + + s = Series(["a,b", "c,d"], name="xxx", dtype=any_string_dtype) + result = s.str.partition(",") + expected = DataFrame( + {0: ["a", "c"], 1: [",", ","], 2: ["b", "d"]}, dtype=any_string_dtype + ) + tm.assert_frame_equal(result, expected) + + +def test_partition_with_name_expand(any_string_dtype): + # GH 12617 + # should preserve name + s = Series(["a,b", "c,d"], name="xxx", dtype=any_string_dtype) + result = s.str.partition(",", expand=False) + expected = Series([("a", ",", "b"), ("c", ",", "d")], name="xxx") + tm.assert_series_equal(result, expected) + + +def test_partition_index_with_name(): + idx = Index(["a,b", "c,d"], name="xxx") + result = idx.str.partition(",") + expected = MultiIndex.from_tuples([("a", ",", "b"), ("c", ",", "d")]) + assert result.nlevels == 3 + tm.assert_index_equal(result, expected) + + +def test_partition_index_with_name_expand_false(): + idx = Index(["a,b", "c,d"], name="xxx") + # should preserve name + result = idx.str.partition(",", expand=False) + expected = Index(np.array([("a", ",", "b"), ("c", ",", "d")]), name="xxx") + assert result.nlevels == 1 + tm.assert_index_equal(result, expected) + + +@pytest.mark.parametrize("method", ["partition", "rpartition"]) +def test_partition_sep_kwarg(any_string_dtype, method): + # GH 22676; depr kwarg "pat" in favor of "sep" + s = Series(["a_b_c", "c_d_e", np.nan, "f_g_h"], dtype=any_string_dtype) + + expected = getattr(s.str, method)(sep="_") + result = getattr(s.str, method)("_") + tm.assert_frame_equal(result, expected) + + +def test_get(): + ser = Series(["a_b_c", "c_d_e", np.nan, "f_g_h"]) + result = ser.str.split("_").str.get(1) + expected = Series(["b", "d", np.nan, "g"], dtype=object) + tm.assert_series_equal(result, expected) + + +def test_get_mixed_object(): + ser = Series(["a_b_c", np.nan, "c_d_e", True, datetime.today(), None, 1, 2.0]) + result = ser.str.split("_").str.get(1) + expected = Series( + ["b", np.nan, "d", np.nan, np.nan, None, np.nan, np.nan], dtype=object + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("idx", [2, -3]) +def test_get_bounds(idx): + ser = Series(["1_2_3_4_5", "6_7_8_9_10", "11_12"]) + result = ser.str.split("_").str.get(idx) + expected = Series(["3", "8", np.nan], dtype=object) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "idx, exp", [[2, [3, 3, np.nan, "b"]], [-1, [3, 3, np.nan, np.nan]]] +) +def test_get_complex(idx, exp): + # GH 20671, getting value not in dict raising `KeyError` + ser = Series([(1, 2, 3), [1, 2, 3], {1, 2, 3}, {1: "a", 2: "b", 3: "c"}]) + + result = ser.str.get(idx) + expected = Series(exp) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("to_type", [tuple, list, np.array]) +def test_get_complex_nested(to_type): + ser = Series([to_type([to_type([1, 2])])]) + + result = ser.str.get(0) + expected = Series([to_type([1, 2])]) + tm.assert_series_equal(result, expected) + + result = ser.str.get(1) + expected = Series([np.nan]) + tm.assert_series_equal(result, expected) + + +def test_get_strings(any_string_dtype): + ser = Series(["a", "ab", np.nan, "abc"], dtype=any_string_dtype) + result = ser.str.get(2) + expected = Series([np.nan, np.nan, np.nan, "c"], dtype=any_string_dtype) + tm.assert_series_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/strings/test_string_array.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/strings/test_string_array.py new file mode 100644 index 0000000000000000000000000000000000000000..0b3f368afea5ec035ff11f609358e4377afb87fd --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/strings/test_string_array.py @@ -0,0 +1,112 @@ +import numpy as np +import pytest + +from pandas._libs import lib + +from pandas import ( + NA, + DataFrame, + Series, + _testing as tm, + option_context, +) + + +@pytest.mark.filterwarnings("ignore:Falling back") +def test_string_array(nullable_string_dtype, any_string_method): + method_name, args, kwargs = any_string_method + + data = ["a", "bb", np.nan, "ccc"] + a = Series(data, dtype=object) + b = Series(data, dtype=nullable_string_dtype) + + if method_name == "decode": + with pytest.raises(TypeError, match="a bytes-like object is required"): + getattr(b.str, method_name)(*args, **kwargs) + return + + expected = getattr(a.str, method_name)(*args, **kwargs) + result = getattr(b.str, method_name)(*args, **kwargs) + + if isinstance(expected, Series): + if expected.dtype == "object" and lib.is_string_array( + expected.dropna().values, + ): + assert result.dtype == nullable_string_dtype + result = result.astype(object) + + elif expected.dtype == "object" and lib.is_bool_array( + expected.values, skipna=True + ): + assert result.dtype == "boolean" + result = result.astype(object) + + elif expected.dtype == "bool": + assert result.dtype == "boolean" + result = result.astype("bool") + + elif expected.dtype == "float" and expected.isna().any(): + assert result.dtype == "Int64" + result = result.astype("float") + + if expected.dtype == object: + # GH#18463 + expected[expected.isna()] = NA + + elif isinstance(expected, DataFrame): + columns = expected.select_dtypes(include="object").columns + assert all(result[columns].dtypes == nullable_string_dtype) + result[columns] = result[columns].astype(object) + with option_context("future.no_silent_downcasting", True): + expected[columns] = expected[columns].fillna(NA) # GH#18463 + + tm.assert_equal(result, expected) + + +@pytest.mark.parametrize( + "method,expected", + [ + ("count", [2, None]), + ("find", [0, None]), + ("index", [0, None]), + ("rindex", [2, None]), + ], +) +def test_string_array_numeric_integer_array(nullable_string_dtype, method, expected): + s = Series(["aba", None], dtype=nullable_string_dtype) + result = getattr(s.str, method)("a") + expected = Series(expected, dtype="Int64") + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "method,expected", + [ + ("isdigit", [False, None, True]), + ("isalpha", [True, None, False]), + ("isalnum", [True, None, True]), + ("isnumeric", [False, None, True]), + ], +) +def test_string_array_boolean_array(nullable_string_dtype, method, expected): + s = Series(["a", None, "1"], dtype=nullable_string_dtype) + result = getattr(s.str, method)() + expected = Series(expected, dtype="boolean") + tm.assert_series_equal(result, expected) + + +def test_string_array_extract(nullable_string_dtype): + # https://github.com/pandas-dev/pandas/issues/30969 + # Only expand=False & multiple groups was failing + + a = Series(["a1", "b2", "cc"], dtype=nullable_string_dtype) + b = Series(["a1", "b2", "cc"], dtype="object") + pat = r"(\w)(\d)" + + result = a.str.extract(pat, expand=False) + expected = b.str.extract(pat, expand=False) + expected = expected.fillna(NA) # GH#18463 + assert all(result.dtypes == nullable_string_dtype) + + result = result.astype(object) + tm.assert_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/strings/test_strings.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/strings/test_strings.py new file mode 100644 index 0000000000000000000000000000000000000000..f662dfd7e2b14cc7016e58ab741c92ef5b29abe0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/strings/test_strings.py @@ -0,0 +1,720 @@ +from datetime import ( + datetime, + timedelta, +) + +import numpy as np +import pytest + +from pandas import ( + DataFrame, + Index, + MultiIndex, + Series, +) +import pandas._testing as tm +from pandas.core.strings.accessor import StringMethods +from pandas.tests.strings import object_pyarrow_numpy + + +@pytest.mark.parametrize("pattern", [0, True, Series(["foo", "bar"])]) +def test_startswith_endswith_non_str_patterns(pattern): + # GH3485 + ser = Series(["foo", "bar"]) + msg = f"expected a string or tuple, not {type(pattern).__name__}" + with pytest.raises(TypeError, match=msg): + ser.str.startswith(pattern) + with pytest.raises(TypeError, match=msg): + ser.str.endswith(pattern) + + +def test_iter_raises(): + # GH 54173 + ser = Series(["foo", "bar"]) + with pytest.raises(TypeError, match="'StringMethods' object is not iterable"): + iter(ser.str) + + +# test integer/float dtypes (inferred by constructor) and mixed + + +def test_count(any_string_dtype): + ser = Series(["foo", "foofoo", np.nan, "foooofooofommmfoo"], dtype=any_string_dtype) + result = ser.str.count("f[o]+") + expected_dtype = np.float64 if any_string_dtype in object_pyarrow_numpy else "Int64" + expected = Series([1, 2, np.nan, 4], dtype=expected_dtype) + tm.assert_series_equal(result, expected) + + +def test_count_mixed_object(): + ser = Series( + ["a", np.nan, "b", True, datetime.today(), "foo", None, 1, 2.0], + dtype=object, + ) + result = ser.str.count("a") + expected = Series([1, np.nan, 0, np.nan, np.nan, 0, np.nan, np.nan, np.nan]) + tm.assert_series_equal(result, expected) + + +def test_repeat(any_string_dtype): + ser = Series(["a", "b", np.nan, "c", np.nan, "d"], dtype=any_string_dtype) + + result = ser.str.repeat(3) + expected = Series( + ["aaa", "bbb", np.nan, "ccc", np.nan, "ddd"], dtype=any_string_dtype + ) + tm.assert_series_equal(result, expected) + + result = ser.str.repeat([1, 2, 3, 4, 5, 6]) + expected = Series( + ["a", "bb", np.nan, "cccc", np.nan, "dddddd"], dtype=any_string_dtype + ) + tm.assert_series_equal(result, expected) + + +def test_repeat_mixed_object(): + ser = Series(["a", np.nan, "b", True, datetime.today(), "foo", None, 1, 2.0]) + result = ser.str.repeat(3) + expected = Series( + ["aaa", np.nan, "bbb", np.nan, np.nan, "foofoofoo", None, np.nan, np.nan], + dtype=object, + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("arg, repeat", [[None, 4], ["b", None]]) +def test_repeat_with_null(any_string_dtype, arg, repeat): + # GH: 31632 + ser = Series(["a", arg], dtype=any_string_dtype) + result = ser.str.repeat([3, repeat]) + expected = Series(["aaa", None], dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + +def test_empty_str_methods(any_string_dtype): + empty_str = empty = Series(dtype=any_string_dtype) + if any_string_dtype in object_pyarrow_numpy: + empty_int = Series(dtype="int64") + empty_bool = Series(dtype=bool) + else: + empty_int = Series(dtype="Int64") + empty_bool = Series(dtype="boolean") + empty_object = Series(dtype=object) + empty_bytes = Series(dtype=object) + empty_df = DataFrame() + + # GH7241 + # (extract) on empty series + + tm.assert_series_equal(empty_str, empty.str.cat(empty)) + assert "" == empty.str.cat() + tm.assert_series_equal(empty_str, empty.str.title()) + tm.assert_series_equal(empty_int, empty.str.count("a")) + tm.assert_series_equal(empty_bool, empty.str.contains("a")) + tm.assert_series_equal(empty_bool, empty.str.startswith("a")) + tm.assert_series_equal(empty_bool, empty.str.endswith("a")) + tm.assert_series_equal(empty_str, empty.str.lower()) + tm.assert_series_equal(empty_str, empty.str.upper()) + tm.assert_series_equal(empty_str, empty.str.replace("a", "b")) + tm.assert_series_equal(empty_str, empty.str.repeat(3)) + tm.assert_series_equal(empty_bool, empty.str.match("^a")) + tm.assert_frame_equal( + DataFrame(columns=[0], dtype=any_string_dtype), + empty.str.extract("()", expand=True), + ) + tm.assert_frame_equal( + DataFrame(columns=[0, 1], dtype=any_string_dtype), + empty.str.extract("()()", expand=True), + ) + tm.assert_series_equal(empty_str, empty.str.extract("()", expand=False)) + tm.assert_frame_equal( + DataFrame(columns=[0, 1], dtype=any_string_dtype), + empty.str.extract("()()", expand=False), + ) + tm.assert_frame_equal(empty_df.set_axis([], axis=1), empty.str.get_dummies()) + tm.assert_series_equal(empty_str, empty_str.str.join("")) + tm.assert_series_equal(empty_int, empty.str.len()) + tm.assert_series_equal(empty_object, empty_str.str.findall("a")) + tm.assert_series_equal(empty_int, empty.str.find("a")) + tm.assert_series_equal(empty_int, empty.str.rfind("a")) + tm.assert_series_equal(empty_str, empty.str.pad(42)) + tm.assert_series_equal(empty_str, empty.str.center(42)) + tm.assert_series_equal(empty_object, empty.str.split("a")) + tm.assert_series_equal(empty_object, empty.str.rsplit("a")) + tm.assert_series_equal(empty_object, empty.str.partition("a", expand=False)) + tm.assert_frame_equal(empty_df, empty.str.partition("a")) + tm.assert_series_equal(empty_object, empty.str.rpartition("a", expand=False)) + tm.assert_frame_equal(empty_df, empty.str.rpartition("a")) + tm.assert_series_equal(empty_str, empty.str.slice(stop=1)) + tm.assert_series_equal(empty_str, empty.str.slice(step=1)) + tm.assert_series_equal(empty_str, empty.str.strip()) + tm.assert_series_equal(empty_str, empty.str.lstrip()) + tm.assert_series_equal(empty_str, empty.str.rstrip()) + tm.assert_series_equal(empty_str, empty.str.wrap(42)) + tm.assert_series_equal(empty_str, empty.str.get(0)) + tm.assert_series_equal(empty_object, empty_bytes.str.decode("ascii")) + tm.assert_series_equal(empty_bytes, empty.str.encode("ascii")) + # ismethods should always return boolean (GH 29624) + tm.assert_series_equal(empty_bool, empty.str.isalnum()) + tm.assert_series_equal(empty_bool, empty.str.isalpha()) + tm.assert_series_equal(empty_bool, empty.str.isdigit()) + tm.assert_series_equal(empty_bool, empty.str.isspace()) + tm.assert_series_equal(empty_bool, empty.str.islower()) + tm.assert_series_equal(empty_bool, empty.str.isupper()) + tm.assert_series_equal(empty_bool, empty.str.istitle()) + tm.assert_series_equal(empty_bool, empty.str.isnumeric()) + tm.assert_series_equal(empty_bool, empty.str.isdecimal()) + tm.assert_series_equal(empty_str, empty.str.capitalize()) + tm.assert_series_equal(empty_str, empty.str.swapcase()) + tm.assert_series_equal(empty_str, empty.str.normalize("NFC")) + + table = str.maketrans("a", "b") + tm.assert_series_equal(empty_str, empty.str.translate(table)) + + +@pytest.mark.parametrize( + "method, expected", + [ + ("isalnum", [True, True, True, True, True, False, True, True, False, False]), + ("isalpha", [True, True, True, False, False, False, True, False, False, False]), + ( + "isdigit", + [False, False, False, True, False, False, False, True, False, False], + ), + ( + "isnumeric", + [False, False, False, True, False, False, False, True, False, False], + ), + ( + "isspace", + [False, False, False, False, False, False, False, False, False, True], + ), + ( + "islower", + [False, True, False, False, False, False, False, False, False, False], + ), + ( + "isupper", + [True, False, False, False, True, False, True, False, False, False], + ), + ( + "istitle", + [True, False, True, False, True, False, False, False, False, False], + ), + ], +) +def test_ismethods(method, expected, any_string_dtype): + ser = Series( + ["A", "b", "Xy", "4", "3A", "", "TT", "55", "-", " "], dtype=any_string_dtype + ) + expected_dtype = "bool" if any_string_dtype in object_pyarrow_numpy else "boolean" + expected = Series(expected, dtype=expected_dtype) + result = getattr(ser.str, method)() + tm.assert_series_equal(result, expected) + + # compare with standard library + expected = [getattr(item, method)() for item in ser] + assert list(result) == expected + + +@pytest.mark.parametrize( + "method, expected", + [ + ("isnumeric", [False, True, True, False, True, True, False]), + ("isdecimal", [False, True, False, False, False, True, False]), + ], +) +def test_isnumeric_unicode(method, expected, any_string_dtype): + # 0x00bc: ¼ VULGAR FRACTION ONE QUARTER + # 0x2605: ★ not number + # 0x1378: ፸ ETHIOPIC NUMBER SEVENTY + # 0xFF13: 3 Em 3 # noqa: RUF003 + ser = Series( + ["A", "3", "¼", "★", "፸", "3", "four"], dtype=any_string_dtype # noqa: RUF001 + ) + expected_dtype = "bool" if any_string_dtype in object_pyarrow_numpy else "boolean" + expected = Series(expected, dtype=expected_dtype) + result = getattr(ser.str, method)() + tm.assert_series_equal(result, expected) + + # compare with standard library + expected = [getattr(item, method)() for item in ser] + assert list(result) == expected + + +@pytest.mark.parametrize( + "method, expected", + [ + ("isnumeric", [False, np.nan, True, False, np.nan, True, False]), + ("isdecimal", [False, np.nan, False, False, np.nan, True, False]), + ], +) +def test_isnumeric_unicode_missing(method, expected, any_string_dtype): + values = ["A", np.nan, "¼", "★", np.nan, "3", "four"] # noqa: RUF001 + ser = Series(values, dtype=any_string_dtype) + expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean" + expected = Series(expected, dtype=expected_dtype) + result = getattr(ser.str, method)() + tm.assert_series_equal(result, expected) + + +def test_spilt_join_roundtrip(any_string_dtype): + ser = Series(["a_b_c", "c_d_e", np.nan, "f_g_h"], dtype=any_string_dtype) + result = ser.str.split("_").str.join("_") + expected = ser.astype(object) + tm.assert_series_equal(result, expected) + + +def test_spilt_join_roundtrip_mixed_object(): + ser = Series( + ["a_b", np.nan, "asdf_cas_asdf", True, datetime.today(), "foo", None, 1, 2.0] + ) + result = ser.str.split("_").str.join("_") + expected = Series( + ["a_b", np.nan, "asdf_cas_asdf", np.nan, np.nan, "foo", None, np.nan, np.nan], + dtype=object, + ) + tm.assert_series_equal(result, expected) + + +def test_len(any_string_dtype): + ser = Series( + ["foo", "fooo", "fooooo", np.nan, "fooooooo", "foo\n", "あ"], + dtype=any_string_dtype, + ) + result = ser.str.len() + expected_dtype = "float64" if any_string_dtype in object_pyarrow_numpy else "Int64" + expected = Series([3, 4, 6, np.nan, 8, 4, 1], dtype=expected_dtype) + tm.assert_series_equal(result, expected) + + +def test_len_mixed(): + ser = Series( + ["a_b", np.nan, "asdf_cas_asdf", True, datetime.today(), "foo", None, 1, 2.0] + ) + result = ser.str.len() + expected = Series([3, np.nan, 13, np.nan, np.nan, 3, np.nan, np.nan, np.nan]) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "method,sub,start,end,expected", + [ + ("index", "EF", None, None, [4, 3, 1, 0]), + ("rindex", "EF", None, None, [4, 5, 7, 4]), + ("index", "EF", 3, None, [4, 3, 7, 4]), + ("rindex", "EF", 3, None, [4, 5, 7, 4]), + ("index", "E", 4, 8, [4, 5, 7, 4]), + ("rindex", "E", 0, 5, [4, 3, 1, 4]), + ], +) +def test_index(method, sub, start, end, index_or_series, any_string_dtype, expected): + obj = index_or_series( + ["ABCDEFG", "BCDEFEF", "DEFGHIJEF", "EFGHEF"], dtype=any_string_dtype + ) + expected_dtype = np.int64 if any_string_dtype in object_pyarrow_numpy else "Int64" + expected = index_or_series(expected, dtype=expected_dtype) + + result = getattr(obj.str, method)(sub, start, end) + + if index_or_series is Series: + tm.assert_series_equal(result, expected) + else: + tm.assert_index_equal(result, expected) + + # compare with standard library + expected = [getattr(item, method)(sub, start, end) for item in obj] + assert list(result) == expected + + +def test_index_not_found_raises(index_or_series, any_string_dtype): + obj = index_or_series( + ["ABCDEFG", "BCDEFEF", "DEFGHIJEF", "EFGHEF"], dtype=any_string_dtype + ) + with pytest.raises(ValueError, match="substring not found"): + obj.str.index("DE") + + +@pytest.mark.parametrize("method", ["index", "rindex"]) +def test_index_wrong_type_raises(index_or_series, any_string_dtype, method): + obj = index_or_series([], dtype=any_string_dtype) + msg = "expected a string object, not int" + + with pytest.raises(TypeError, match=msg): + getattr(obj.str, method)(0) + + +@pytest.mark.parametrize( + "method, exp", + [ + ["index", [1, 1, 0]], + ["rindex", [3, 1, 2]], + ], +) +def test_index_missing(any_string_dtype, method, exp): + ser = Series(["abcb", "ab", "bcbe", np.nan], dtype=any_string_dtype) + expected_dtype = np.float64 if any_string_dtype in object_pyarrow_numpy else "Int64" + + result = getattr(ser.str, method)("b") + expected = Series(exp + [np.nan], dtype=expected_dtype) + tm.assert_series_equal(result, expected) + + +def test_pipe_failures(any_string_dtype): + # #2119 + ser = Series(["A|B|C"], dtype=any_string_dtype) + + result = ser.str.split("|") + expected = Series([["A", "B", "C"]], dtype=object) + tm.assert_series_equal(result, expected) + + result = ser.str.replace("|", " ", regex=False) + expected = Series(["A B C"], dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "start, stop, step, expected", + [ + (2, 5, None, ["foo", "bar", np.nan, "baz"]), + (0, 3, -1, ["", "", np.nan, ""]), + (None, None, -1, ["owtoofaa", "owtrabaa", np.nan, "xuqzabaa"]), + (3, 10, 2, ["oto", "ato", np.nan, "aqx"]), + (3, 0, -1, ["ofa", "aba", np.nan, "aba"]), + ], +) +def test_slice(start, stop, step, expected, any_string_dtype): + ser = Series(["aafootwo", "aabartwo", np.nan, "aabazqux"], dtype=any_string_dtype) + result = ser.str.slice(start, stop, step) + expected = Series(expected, dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "start, stop, step, expected", + [ + (2, 5, None, ["foo", np.nan, "bar", np.nan, np.nan, None, np.nan, np.nan]), + (4, 1, -1, ["oof", np.nan, "rab", np.nan, np.nan, None, np.nan, np.nan]), + ], +) +def test_slice_mixed_object(start, stop, step, expected): + ser = Series(["aafootwo", np.nan, "aabartwo", True, datetime.today(), None, 1, 2.0]) + result = ser.str.slice(start, stop, step) + expected = Series(expected, dtype=object) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "start,stop,repl,expected", + [ + (2, 3, None, ["shrt", "a it longer", "evnlongerthanthat", "", np.nan]), + (2, 3, "z", ["shzrt", "a zit longer", "evznlongerthanthat", "z", np.nan]), + (2, 2, "z", ["shzort", "a zbit longer", "evzenlongerthanthat", "z", np.nan]), + (2, 1, "z", ["shzort", "a zbit longer", "evzenlongerthanthat", "z", np.nan]), + (-1, None, "z", ["shorz", "a bit longez", "evenlongerthanthaz", "z", np.nan]), + (None, -2, "z", ["zrt", "zer", "zat", "z", np.nan]), + (6, 8, "z", ["shortz", "a bit znger", "evenlozerthanthat", "z", np.nan]), + (-10, 3, "z", ["zrt", "a zit longer", "evenlongzerthanthat", "z", np.nan]), + ], +) +def test_slice_replace(start, stop, repl, expected, any_string_dtype): + ser = Series( + ["short", "a bit longer", "evenlongerthanthat", "", np.nan], + dtype=any_string_dtype, + ) + expected = Series(expected, dtype=any_string_dtype) + result = ser.str.slice_replace(start, stop, repl) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "method, exp", + [ + ["strip", ["aa", "bb", np.nan, "cc"]], + ["lstrip", ["aa ", "bb \n", np.nan, "cc "]], + ["rstrip", [" aa", " bb", np.nan, "cc"]], + ], +) +def test_strip_lstrip_rstrip(any_string_dtype, method, exp): + ser = Series([" aa ", " bb \n", np.nan, "cc "], dtype=any_string_dtype) + + result = getattr(ser.str, method)() + expected = Series(exp, dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "method, exp", + [ + ["strip", ["aa", np.nan, "bb"]], + ["lstrip", ["aa ", np.nan, "bb \t\n"]], + ["rstrip", [" aa", np.nan, " bb"]], + ], +) +def test_strip_lstrip_rstrip_mixed_object(method, exp): + ser = Series([" aa ", np.nan, " bb \t\n", True, datetime.today(), None, 1, 2.0]) + + result = getattr(ser.str, method)() + expected = Series(exp + [np.nan, np.nan, None, np.nan, np.nan], dtype=object) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "method, exp", + [ + ["strip", ["ABC", " BNSD", "LDFJH "]], + ["lstrip", ["ABCxx", " BNSD", "LDFJH xx"]], + ["rstrip", ["xxABC", "xx BNSD", "LDFJH "]], + ], +) +def test_strip_lstrip_rstrip_args(any_string_dtype, method, exp): + ser = Series(["xxABCxx", "xx BNSD", "LDFJH xx"], dtype=any_string_dtype) + + result = getattr(ser.str, method)("x") + expected = Series(exp, dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "prefix, expected", [("a", ["b", " b c", "bc"]), ("ab", ["", "a b c", "bc"])] +) +def test_removeprefix(any_string_dtype, prefix, expected): + ser = Series(["ab", "a b c", "bc"], dtype=any_string_dtype) + result = ser.str.removeprefix(prefix) + ser_expected = Series(expected, dtype=any_string_dtype) + tm.assert_series_equal(result, ser_expected) + + +@pytest.mark.parametrize( + "suffix, expected", [("c", ["ab", "a b ", "b"]), ("bc", ["ab", "a b c", ""])] +) +def test_removesuffix(any_string_dtype, suffix, expected): + ser = Series(["ab", "a b c", "bc"], dtype=any_string_dtype) + result = ser.str.removesuffix(suffix) + ser_expected = Series(expected, dtype=any_string_dtype) + tm.assert_series_equal(result, ser_expected) + + +def test_string_slice_get_syntax(any_string_dtype): + ser = Series( + ["YYY", "B", "C", "YYYYYYbYYY", "BYYYcYYY", np.nan, "CYYYBYYY", "dog", "cYYYt"], + dtype=any_string_dtype, + ) + + result = ser.str[0] + expected = ser.str.get(0) + tm.assert_series_equal(result, expected) + + result = ser.str[:3] + expected = ser.str.slice(stop=3) + tm.assert_series_equal(result, expected) + + result = ser.str[2::-1] + expected = ser.str.slice(start=2, step=-1) + tm.assert_series_equal(result, expected) + + +def test_string_slice_out_of_bounds_nested(): + ser = Series([(1, 2), (1,), (3, 4, 5)]) + result = ser.str[1] + expected = Series([2, np.nan, 4]) + tm.assert_series_equal(result, expected) + + +def test_string_slice_out_of_bounds(any_string_dtype): + ser = Series(["foo", "b", "ba"], dtype=any_string_dtype) + result = ser.str[1] + expected = Series(["o", np.nan, "a"], dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + +def test_encode_decode(any_string_dtype): + ser = Series(["a", "b", "a\xe4"], dtype=any_string_dtype).str.encode("utf-8") + result = ser.str.decode("utf-8") + expected = ser.map(lambda x: x.decode("utf-8")).astype(object) + tm.assert_series_equal(result, expected) + + +def test_encode_errors_kwarg(any_string_dtype): + ser = Series(["a", "b", "a\x9d"], dtype=any_string_dtype) + + msg = ( + r"'charmap' codec can't encode character '\\x9d' in position 1: " + "character maps to " + ) + with pytest.raises(UnicodeEncodeError, match=msg): + ser.str.encode("cp1252") + + result = ser.str.encode("cp1252", "ignore") + expected = ser.map(lambda x: x.encode("cp1252", "ignore")) + tm.assert_series_equal(result, expected) + + +def test_decode_errors_kwarg(): + ser = Series([b"a", b"b", b"a\x9d"]) + + msg = ( + "'charmap' codec can't decode byte 0x9d in position 1: " + "character maps to " + ) + with pytest.raises(UnicodeDecodeError, match=msg): + ser.str.decode("cp1252") + + result = ser.str.decode("cp1252", "ignore") + expected = ser.map(lambda x: x.decode("cp1252", "ignore")).astype(object) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "form, expected", + [ + ("NFKC", ["ABC", "ABC", "123", np.nan, "アイエ"]), + ("NFC", ["ABC", "ABC", "123", np.nan, "アイエ"]), # noqa: RUF001 + ], +) +def test_normalize(form, expected, any_string_dtype): + ser = Series( + ["ABC", "ABC", "123", np.nan, "アイエ"], # noqa: RUF001 + index=["a", "b", "c", "d", "e"], + dtype=any_string_dtype, + ) + expected = Series(expected, index=["a", "b", "c", "d", "e"], dtype=any_string_dtype) + result = ser.str.normalize(form) + tm.assert_series_equal(result, expected) + + +def test_normalize_bad_arg_raises(any_string_dtype): + ser = Series( + ["ABC", "ABC", "123", np.nan, "アイエ"], # noqa: RUF001 + index=["a", "b", "c", "d", "e"], + dtype=any_string_dtype, + ) + with pytest.raises(ValueError, match="invalid normalization form"): + ser.str.normalize("xxx") + + +def test_normalize_index(): + idx = Index(["ABC", "123", "アイエ"]) # noqa: RUF001 + expected = Index(["ABC", "123", "アイエ"]) + result = idx.str.normalize("NFKC") + tm.assert_index_equal(result, expected) + + +@pytest.mark.parametrize( + "values,inferred_type", + [ + (["a", "b"], "string"), + (["a", "b", 1], "mixed-integer"), + (["a", "b", 1.3], "mixed"), + (["a", "b", 1.3, 1], "mixed-integer"), + (["aa", datetime(2011, 1, 1)], "mixed"), + ], +) +def test_index_str_accessor_visibility(values, inferred_type, index_or_series): + obj = index_or_series(values) + if index_or_series is Index: + assert obj.inferred_type == inferred_type + + assert isinstance(obj.str, StringMethods) + + +@pytest.mark.parametrize( + "values,inferred_type", + [ + ([1, np.nan], "floating"), + ([datetime(2011, 1, 1)], "datetime64"), + ([timedelta(1)], "timedelta64"), + ], +) +def test_index_str_accessor_non_string_values_raises( + values, inferred_type, index_or_series +): + obj = index_or_series(values) + if index_or_series is Index: + assert obj.inferred_type == inferred_type + + msg = "Can only use .str accessor with string values" + with pytest.raises(AttributeError, match=msg): + obj.str + + +def test_index_str_accessor_multiindex_raises(): + # MultiIndex has mixed dtype, but not allow to use accessor + idx = MultiIndex.from_tuples([("a", "b"), ("a", "b")]) + assert idx.inferred_type == "mixed" + + msg = "Can only use .str accessor with Index, not MultiIndex" + with pytest.raises(AttributeError, match=msg): + idx.str + + +def test_str_accessor_no_new_attributes(any_string_dtype): + # https://github.com/pandas-dev/pandas/issues/10673 + ser = Series(list("aabbcde"), dtype=any_string_dtype) + with pytest.raises(AttributeError, match="You cannot add any new attribute"): + ser.str.xlabel = "a" + + +def test_cat_on_bytes_raises(): + lhs = Series(np.array(list("abc"), "S1").astype(object)) + rhs = Series(np.array(list("def"), "S1").astype(object)) + msg = "Cannot use .str.cat with values of inferred dtype 'bytes'" + with pytest.raises(TypeError, match=msg): + lhs.str.cat(rhs) + + +def test_str_accessor_in_apply_func(): + # https://github.com/pandas-dev/pandas/issues/38979 + df = DataFrame(zip("abc", "def")) + expected = Series(["A/D", "B/E", "C/F"]) + result = df.apply(lambda f: "/".join(f.str.upper()), axis=1) + tm.assert_series_equal(result, expected) + + +def test_zfill(): + # https://github.com/pandas-dev/pandas/issues/20868 + value = Series(["-1", "1", "1000", 10, np.nan]) + expected = Series(["-01", "001", "1000", np.nan, np.nan], dtype=object) + tm.assert_series_equal(value.str.zfill(3), expected) + + value = Series(["-2", "+5"]) + expected = Series(["-0002", "+0005"]) + tm.assert_series_equal(value.str.zfill(5), expected) + + +def test_zfill_with_non_integer_argument(): + value = Series(["-2", "+5"]) + wid = "a" + msg = f"width must be of integer type, not {type(wid).__name__}" + with pytest.raises(TypeError, match=msg): + value.str.zfill(wid) + + +def test_zfill_with_leading_sign(): + value = Series(["-cat", "-1", "+dog"]) + expected = Series(["-0cat", "-0001", "+0dog"]) + tm.assert_series_equal(value.str.zfill(5), expected) + + +def test_get_with_dict_label(): + # GH47911 + s = Series( + [ + {"name": "Hello", "value": "World"}, + {"name": "Goodbye", "value": "Planet"}, + {"value": "Sea"}, + ] + ) + result = s.str.get("name") + expected = Series(["Hello", "Goodbye", None], dtype=object) + tm.assert_series_equal(result, expected) + result = s.str.get("value") + expected = Series(["World", "Planet", "Sea"], dtype=object) + tm.assert_series_equal(result, expected) + + +def test_series_str_decode(): + # GH 22613 + result = Series([b"x", b"y"]).str.decode(encoding="UTF-8", errors="strict") + expected = Series(["x", "y"], dtype="object") + tm.assert_series_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tools/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tools/test_to_datetime.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tools/test_to_datetime.py new file mode 100644 index 0000000000000000000000000000000000000000..ede38ce9c9a09e705a20eae2dd42bde8216f19cb --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tools/test_to_datetime.py @@ -0,0 +1,3911 @@ +""" test to_datetime """ + +import calendar +from collections import deque +from datetime import ( + date, + datetime, + timedelta, + timezone, +) +from decimal import Decimal +import locale + +from dateutil.parser import parse +from dateutil.tz.tz import tzoffset +import numpy as np +import pytest +import pytz + +from pandas._libs import tslib +from pandas._libs.tslibs import ( + iNaT, + parsing, +) +from pandas.errors import ( + OutOfBoundsDatetime, + OutOfBoundsTimedelta, +) +import pandas.util._test_decorators as td + +from pandas.core.dtypes.common import is_datetime64_ns_dtype + +import pandas as pd +from pandas import ( + DataFrame, + DatetimeIndex, + Index, + NaT, + Series, + Timestamp, + date_range, + isna, + to_datetime, +) +import pandas._testing as tm +from pandas.core.arrays import DatetimeArray +from pandas.core.tools import datetimes as tools +from pandas.core.tools.datetimes import start_caching_at + +PARSING_ERR_MSG = ( + r"You might want to try:\n" + r" - passing `format` if your strings have a consistent format;\n" + r" - passing `format=\'ISO8601\'` if your strings are all ISO8601 " + r"but not necessarily in exactly the same format;\n" + r" - passing `format=\'mixed\'`, and the format will be inferred " + r"for each element individually. You might want to use `dayfirst` " + r"alongside this." +) + +pytestmark = pytest.mark.filterwarnings( + "ignore:errors='ignore' is deprecated:FutureWarning" +) + + +@pytest.fixture(params=[True, False]) +def cache(request): + """ + cache keyword to pass to to_datetime. + """ + return request.param + + +class TestTimeConversionFormats: + @pytest.mark.parametrize("readonly", [True, False]) + def test_to_datetime_readonly(self, readonly): + # GH#34857 + arr = np.array([], dtype=object) + if readonly: + arr.setflags(write=False) + result = to_datetime(arr) + expected = to_datetime([]) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "format, expected", + [ + [ + "%d/%m/%Y", + [Timestamp("20000101"), Timestamp("20000201"), Timestamp("20000301")], + ], + [ + "%m/%d/%Y", + [Timestamp("20000101"), Timestamp("20000102"), Timestamp("20000103")], + ], + ], + ) + def test_to_datetime_format(self, cache, index_or_series, format, expected): + values = index_or_series(["1/1/2000", "1/2/2000", "1/3/2000"]) + result = to_datetime(values, format=format, cache=cache) + expected = index_or_series(expected) + tm.assert_equal(result, expected) + + @pytest.mark.parametrize( + "arg, expected, format", + [ + ["1/1/2000", "20000101", "%d/%m/%Y"], + ["1/1/2000", "20000101", "%m/%d/%Y"], + ["1/2/2000", "20000201", "%d/%m/%Y"], + ["1/2/2000", "20000102", "%m/%d/%Y"], + ["1/3/2000", "20000301", "%d/%m/%Y"], + ["1/3/2000", "20000103", "%m/%d/%Y"], + ], + ) + def test_to_datetime_format_scalar(self, cache, arg, expected, format): + result = to_datetime(arg, format=format, cache=cache) + expected = Timestamp(expected) + assert result == expected + + def test_to_datetime_format_YYYYMMDD(self, cache): + ser = Series([19801222, 19801222] + [19810105] * 5) + expected = Series([Timestamp(x) for x in ser.apply(str)]) + + result = to_datetime(ser, format="%Y%m%d", cache=cache) + tm.assert_series_equal(result, expected) + + result = to_datetime(ser.apply(str), format="%Y%m%d", cache=cache) + tm.assert_series_equal(result, expected) + + def test_to_datetime_format_YYYYMMDD_with_nat(self, cache): + # Explicit cast to float to explicit cast when setting np.nan + ser = Series([19801222, 19801222] + [19810105] * 5, dtype="float") + # with NaT + expected = Series( + [Timestamp("19801222"), Timestamp("19801222")] + [Timestamp("19810105")] * 5 + ) + expected[2] = np.nan + ser[2] = np.nan + + result = to_datetime(ser, format="%Y%m%d", cache=cache) + tm.assert_series_equal(result, expected) + + # string with NaT + ser2 = ser.apply(str) + ser2[2] = "nat" + with pytest.raises( + ValueError, + match=( + 'unconverted data remains when parsing with format "%Y%m%d": ".0", ' + "at position 0" + ), + ): + # https://github.com/pandas-dev/pandas/issues/50051 + to_datetime(ser2, format="%Y%m%d", cache=cache) + + def test_to_datetime_format_YYYYMM_with_nat(self, cache): + # https://github.com/pandas-dev/pandas/issues/50237 + # Explicit cast to float to explicit cast when setting np.nan + ser = Series([198012, 198012] + [198101] * 5, dtype="float") + expected = Series( + [Timestamp("19801201"), Timestamp("19801201")] + [Timestamp("19810101")] * 5 + ) + expected[2] = np.nan + ser[2] = np.nan + result = to_datetime(ser, format="%Y%m", cache=cache) + tm.assert_series_equal(result, expected) + + def test_to_datetime_format_YYYYMMDD_ignore(self, cache): + # coercion + # GH 7930, GH 14487 + ser = Series([20121231, 20141231, 99991231]) + result = to_datetime(ser, format="%Y%m%d", errors="ignore", cache=cache) + expected = Series( + [20121231, 20141231, 99991231], + dtype=object, + ) + tm.assert_series_equal(result, expected) + + def test_to_datetime_format_YYYYMMDD_ignore_with_outofbounds(self, cache): + # https://github.com/pandas-dev/pandas/issues/26493 + result = to_datetime( + ["15010101", "20150101", np.nan], + format="%Y%m%d", + errors="ignore", + cache=cache, + ) + expected = Index(["15010101", "20150101", np.nan], dtype=object) + tm.assert_index_equal(result, expected) + + def test_to_datetime_format_YYYYMMDD_coercion(self, cache): + # coercion + # GH 7930 + ser = Series([20121231, 20141231, 99991231]) + result = to_datetime(ser, format="%Y%m%d", errors="coerce", cache=cache) + expected = Series(["20121231", "20141231", "NaT"], dtype="M8[ns]") + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "input_s", + [ + # Null values with Strings + ["19801222", "20010112", None], + ["19801222", "20010112", np.nan], + ["19801222", "20010112", NaT], + ["19801222", "20010112", "NaT"], + # Null values with Integers + [19801222, 20010112, None], + [19801222, 20010112, np.nan], + [19801222, 20010112, NaT], + [19801222, 20010112, "NaT"], + ], + ) + def test_to_datetime_format_YYYYMMDD_with_none(self, input_s): + # GH 30011 + # format='%Y%m%d' + # with None + expected = Series([Timestamp("19801222"), Timestamp("20010112"), NaT]) + result = Series(to_datetime(input_s, format="%Y%m%d")) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "input_s, expected", + [ + # NaN before strings with invalid date values + [ + Series(["19801222", np.nan, "20010012", "10019999"]), + Series([Timestamp("19801222"), np.nan, np.nan, np.nan]), + ], + # NaN after strings with invalid date values + [ + Series(["19801222", "20010012", "10019999", np.nan]), + Series([Timestamp("19801222"), np.nan, np.nan, np.nan]), + ], + # NaN before integers with invalid date values + [ + Series([20190813, np.nan, 20010012, 20019999]), + Series([Timestamp("20190813"), np.nan, np.nan, np.nan]), + ], + # NaN after integers with invalid date values + [ + Series([20190813, 20010012, np.nan, 20019999]), + Series([Timestamp("20190813"), np.nan, np.nan, np.nan]), + ], + ], + ) + def test_to_datetime_format_YYYYMMDD_overflow(self, input_s, expected): + # GH 25512 + # format='%Y%m%d', errors='coerce' + result = to_datetime(input_s, format="%Y%m%d", errors="coerce") + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "data, format, expected", + [ + ([pd.NA], "%Y%m%d%H%M%S", DatetimeIndex(["NaT"])), + ([pd.NA], None, DatetimeIndex(["NaT"])), + ( + [pd.NA, "20210202202020"], + "%Y%m%d%H%M%S", + DatetimeIndex(["NaT", "2021-02-02 20:20:20"]), + ), + (["201010", pd.NA], "%y%m%d", DatetimeIndex(["2020-10-10", "NaT"])), + (["201010", pd.NA], "%d%m%y", DatetimeIndex(["2010-10-20", "NaT"])), + ([None, np.nan, pd.NA], None, DatetimeIndex(["NaT", "NaT", "NaT"])), + ([None, np.nan, pd.NA], "%Y%m%d", DatetimeIndex(["NaT", "NaT", "NaT"])), + ], + ) + def test_to_datetime_with_NA(self, data, format, expected): + # GH#42957 + result = to_datetime(data, format=format) + tm.assert_index_equal(result, expected) + + def test_to_datetime_with_NA_with_warning(self): + # GH#42957 + result = to_datetime(["201010", pd.NA]) + expected = DatetimeIndex(["2010-10-20", "NaT"]) + tm.assert_index_equal(result, expected) + + def test_to_datetime_format_integer(self, cache): + # GH 10178 + ser = Series([2000, 2001, 2002]) + expected = Series([Timestamp(x) for x in ser.apply(str)]) + + result = to_datetime(ser, format="%Y", cache=cache) + tm.assert_series_equal(result, expected) + + ser = Series([200001, 200105, 200206]) + expected = Series([Timestamp(x[:4] + "-" + x[4:]) for x in ser.apply(str)]) + + result = to_datetime(ser, format="%Y%m", cache=cache) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "int_date, expected", + [ + # valid date, length == 8 + [20121030, datetime(2012, 10, 30)], + # short valid date, length == 6 + [199934, datetime(1999, 3, 4)], + # long integer date partially parsed to datetime(2012,1,1), length > 8 + [2012010101, 2012010101], + # invalid date partially parsed to datetime(2012,9,9), length == 8 + [20129930, 20129930], + # short integer date partially parsed to datetime(2012,9,9), length < 8 + [2012993, 2012993], + # short invalid date, length == 4 + [2121, 2121], + ], + ) + def test_int_to_datetime_format_YYYYMMDD_typeerror(self, int_date, expected): + # GH 26583 + result = to_datetime(int_date, format="%Y%m%d", errors="ignore") + assert result == expected + + def test_to_datetime_format_microsecond(self, cache): + month_abbr = calendar.month_abbr[4] + val = f"01-{month_abbr}-2011 00:00:01.978" + + format = "%d-%b-%Y %H:%M:%S.%f" + result = to_datetime(val, format=format, cache=cache) + exp = datetime.strptime(val, format) + assert result == exp + + @pytest.mark.parametrize( + "value, format, dt", + [ + ["01/10/2010 15:20", "%m/%d/%Y %H:%M", Timestamp("2010-01-10 15:20")], + ["01/10/2010 05:43", "%m/%d/%Y %I:%M", Timestamp("2010-01-10 05:43")], + [ + "01/10/2010 13:56:01", + "%m/%d/%Y %H:%M:%S", + Timestamp("2010-01-10 13:56:01"), + ], + # The 3 tests below are locale-dependent. + # They pass, except when the machine locale is zh_CN or it_IT . + pytest.param( + "01/10/2010 08:14 PM", + "%m/%d/%Y %I:%M %p", + Timestamp("2010-01-10 20:14"), + marks=pytest.mark.xfail( + locale.getlocale()[0] in ("zh_CN", "it_IT"), + reason="fail on a CI build with LC_ALL=zh_CN.utf8/it_IT.utf8", + strict=False, + ), + ), + pytest.param( + "01/10/2010 07:40 AM", + "%m/%d/%Y %I:%M %p", + Timestamp("2010-01-10 07:40"), + marks=pytest.mark.xfail( + locale.getlocale()[0] in ("zh_CN", "it_IT"), + reason="fail on a CI build with LC_ALL=zh_CN.utf8/it_IT.utf8", + strict=False, + ), + ), + pytest.param( + "01/10/2010 09:12:56 AM", + "%m/%d/%Y %I:%M:%S %p", + Timestamp("2010-01-10 09:12:56"), + marks=pytest.mark.xfail( + locale.getlocale()[0] in ("zh_CN", "it_IT"), + reason="fail on a CI build with LC_ALL=zh_CN.utf8/it_IT.utf8", + strict=False, + ), + ), + ], + ) + def test_to_datetime_format_time(self, cache, value, format, dt): + assert to_datetime(value, format=format, cache=cache) == dt + + @td.skip_if_not_us_locale + def test_to_datetime_with_non_exact(self, cache): + # GH 10834 + # 8904 + # exact kw + ser = Series( + ["19MAY11", "foobar19MAY11", "19MAY11:00:00:00", "19MAY11 00:00:00Z"] + ) + result = to_datetime(ser, format="%d%b%y", exact=False, cache=cache) + expected = to_datetime( + ser.str.extract(r"(\d+\w+\d+)", expand=False), format="%d%b%y", cache=cache + ) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "format, expected", + [ + ("%Y-%m-%d", Timestamp(2000, 1, 3)), + ("%Y-%d-%m", Timestamp(2000, 3, 1)), + ("%Y-%m-%d %H", Timestamp(2000, 1, 3, 12)), + ("%Y-%d-%m %H", Timestamp(2000, 3, 1, 12)), + ("%Y-%m-%d %H:%M", Timestamp(2000, 1, 3, 12, 34)), + ("%Y-%d-%m %H:%M", Timestamp(2000, 3, 1, 12, 34)), + ("%Y-%m-%d %H:%M:%S", Timestamp(2000, 1, 3, 12, 34, 56)), + ("%Y-%d-%m %H:%M:%S", Timestamp(2000, 3, 1, 12, 34, 56)), + ("%Y-%m-%d %H:%M:%S.%f", Timestamp(2000, 1, 3, 12, 34, 56, 123456)), + ("%Y-%d-%m %H:%M:%S.%f", Timestamp(2000, 3, 1, 12, 34, 56, 123456)), + ( + "%Y-%m-%d %H:%M:%S.%f%z", + Timestamp(2000, 1, 3, 12, 34, 56, 123456, tz="UTC+01:00"), + ), + ( + "%Y-%d-%m %H:%M:%S.%f%z", + Timestamp(2000, 3, 1, 12, 34, 56, 123456, tz="UTC+01:00"), + ), + ], + ) + def test_non_exact_doesnt_parse_whole_string(self, cache, format, expected): + # https://github.com/pandas-dev/pandas/issues/50412 + # the formats alternate between ISO8601 and non-ISO8601 to check both paths + result = to_datetime( + "2000-01-03 12:34:56.123456+01:00", format=format, exact=False + ) + assert result == expected + + @pytest.mark.parametrize( + "arg", + [ + "2012-01-01 09:00:00.000000001", + "2012-01-01 09:00:00.000001", + "2012-01-01 09:00:00.001", + "2012-01-01 09:00:00.001000", + "2012-01-01 09:00:00.001000000", + ], + ) + def test_parse_nanoseconds_with_formula(self, cache, arg): + # GH8989 + # truncating the nanoseconds when a format was provided + expected = to_datetime(arg, cache=cache) + result = to_datetime(arg, format="%Y-%m-%d %H:%M:%S.%f", cache=cache) + assert result == expected + + @pytest.mark.parametrize( + "value,fmt,expected", + [ + ["2009324", "%Y%W%w", Timestamp("2009-08-13")], + ["2013020", "%Y%U%w", Timestamp("2013-01-13")], + ], + ) + def test_to_datetime_format_weeks(self, value, fmt, expected, cache): + assert to_datetime(value, format=fmt, cache=cache) == expected + + @pytest.mark.parametrize( + "fmt,dates,expected_dates", + [ + [ + "%Y-%m-%d %H:%M:%S %Z", + ["2010-01-01 12:00:00 UTC"] * 2, + [Timestamp("2010-01-01 12:00:00", tz="UTC")] * 2, + ], + [ + "%Y-%m-%d %H:%M:%S%z", + ["2010-01-01 12:00:00+0100"] * 2, + [ + Timestamp( + "2010-01-01 12:00:00", tzinfo=timezone(timedelta(minutes=60)) + ) + ] + * 2, + ], + [ + "%Y-%m-%d %H:%M:%S %z", + ["2010-01-01 12:00:00 +0100"] * 2, + [ + Timestamp( + "2010-01-01 12:00:00", tzinfo=timezone(timedelta(minutes=60)) + ) + ] + * 2, + ], + [ + "%Y-%m-%d %H:%M:%S %z", + ["2010-01-01 12:00:00 Z", "2010-01-01 12:00:00 Z"], + [ + Timestamp( + "2010-01-01 12:00:00", tzinfo=pytz.FixedOffset(0) + ), # pytz coerces to UTC + Timestamp("2010-01-01 12:00:00", tzinfo=pytz.FixedOffset(0)), + ], + ], + ], + ) + def test_to_datetime_parse_tzname_or_tzoffset(self, fmt, dates, expected_dates): + # GH 13486 + result = to_datetime(dates, format=fmt) + expected = Index(expected_dates) + tm.assert_equal(result, expected) + + @pytest.mark.parametrize( + "fmt,dates,expected_dates", + [ + [ + "%Y-%m-%d %H:%M:%S %Z", + [ + "2010-01-01 12:00:00 UTC", + "2010-01-01 12:00:00 GMT", + "2010-01-01 12:00:00 US/Pacific", + ], + [ + Timestamp("2010-01-01 12:00:00", tz="UTC"), + Timestamp("2010-01-01 12:00:00", tz="GMT"), + Timestamp("2010-01-01 12:00:00", tz="US/Pacific"), + ], + ], + [ + "%Y-%m-%d %H:%M:%S %z", + ["2010-01-01 12:00:00 +0100", "2010-01-01 12:00:00 -0100"], + [ + Timestamp( + "2010-01-01 12:00:00", tzinfo=timezone(timedelta(minutes=60)) + ), + Timestamp( + "2010-01-01 12:00:00", tzinfo=timezone(timedelta(minutes=-60)) + ), + ], + ], + ], + ) + def test_to_datetime_parse_tzname_or_tzoffset_utc_false_deprecated( + self, fmt, dates, expected_dates + ): + # GH 13486, 50887 + msg = "parsing datetimes with mixed time zones will raise an error" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = to_datetime(dates, format=fmt) + expected = Index(expected_dates) + tm.assert_equal(result, expected) + + def test_to_datetime_parse_tzname_or_tzoffset_different_tz_to_utc(self): + # GH 32792 + dates = [ + "2010-01-01 12:00:00 +0100", + "2010-01-01 12:00:00 -0100", + "2010-01-01 12:00:00 +0300", + "2010-01-01 12:00:00 +0400", + ] + expected_dates = [ + "2010-01-01 11:00:00+00:00", + "2010-01-01 13:00:00+00:00", + "2010-01-01 09:00:00+00:00", + "2010-01-01 08:00:00+00:00", + ] + fmt = "%Y-%m-%d %H:%M:%S %z" + + result = to_datetime(dates, format=fmt, utc=True) + expected = DatetimeIndex(expected_dates) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "offset", ["+0", "-1foo", "UTCbar", ":10", "+01:000:01", ""] + ) + def test_to_datetime_parse_timezone_malformed(self, offset): + fmt = "%Y-%m-%d %H:%M:%S %z" + date = "2010-01-01 12:00:00 " + offset + + msg = "|".join( + [ + r'^time data ".*" doesn\'t match format ".*", at position 0. ' + f"{PARSING_ERR_MSG}$", + r'^unconverted data remains when parsing with format ".*": ".*", ' + f"at position 0. {PARSING_ERR_MSG}$", + ] + ) + with pytest.raises(ValueError, match=msg): + to_datetime([date], format=fmt) + + def test_to_datetime_parse_timezone_keeps_name(self): + # GH 21697 + fmt = "%Y-%m-%d %H:%M:%S %z" + arg = Index(["2010-01-01 12:00:00 Z"], name="foo") + result = to_datetime(arg, format=fmt) + expected = DatetimeIndex(["2010-01-01 12:00:00"], tz="UTC", name="foo") + tm.assert_index_equal(result, expected) + + +class TestToDatetime: + @pytest.mark.filterwarnings("ignore:Could not infer format") + def test_to_datetime_overflow(self): + # we should get an OutOfBoundsDatetime, NOT OverflowError + # TODO: Timestamp raises ValueError("could not convert string to Timestamp") + # can we make these more consistent? + arg = "08335394550" + msg = 'Parsing "08335394550" to datetime overflows, at position 0' + with pytest.raises(OutOfBoundsDatetime, match=msg): + to_datetime(arg) + + with pytest.raises(OutOfBoundsDatetime, match=msg): + to_datetime([arg]) + + res = to_datetime(arg, errors="coerce") + assert res is NaT + res = to_datetime([arg], errors="coerce") + tm.assert_index_equal(res, Index([NaT])) + + res = to_datetime(arg, errors="ignore") + assert isinstance(res, str) and res == arg + res = to_datetime([arg], errors="ignore") + tm.assert_index_equal(res, Index([arg], dtype=object)) + + def test_to_datetime_mixed_datetime_and_string(self): + # GH#47018 adapted old doctest with new behavior + d1 = datetime(2020, 1, 1, 17, tzinfo=timezone(-timedelta(hours=1))) + d2 = datetime(2020, 1, 1, 18, tzinfo=timezone(-timedelta(hours=1))) + res = to_datetime(["2020-01-01 17:00 -0100", d2]) + expected = to_datetime([d1, d2]).tz_convert(timezone(timedelta(minutes=-60))) + tm.assert_index_equal(res, expected) + + def test_to_datetime_mixed_string_and_numeric(self): + # GH#55780 np.array(vals) would incorrectly cast the number to str + vals = ["2016-01-01", 0] + expected = DatetimeIndex([Timestamp(x) for x in vals]) + result = to_datetime(vals, format="mixed") + result2 = to_datetime(vals[::-1], format="mixed")[::-1] + result3 = DatetimeIndex(vals) + result4 = DatetimeIndex(vals[::-1])[::-1] + + tm.assert_index_equal(result, expected) + tm.assert_index_equal(result2, expected) + tm.assert_index_equal(result3, expected) + tm.assert_index_equal(result4, expected) + + @pytest.mark.parametrize( + "format", ["%Y-%m-%d", "%Y-%d-%m"], ids=["ISO8601", "non-ISO8601"] + ) + def test_to_datetime_mixed_date_and_string(self, format): + # https://github.com/pandas-dev/pandas/issues/50108 + d1 = date(2020, 1, 2) + res = to_datetime(["2020-01-01", d1], format=format) + expected = DatetimeIndex(["2020-01-01", "2020-01-02"], dtype="M8[ns]") + tm.assert_index_equal(res, expected) + + @pytest.mark.parametrize( + "fmt", + ["%Y-%d-%m %H:%M:%S%z", "%Y-%m-%d %H:%M:%S%z"], + ids=["non-ISO8601 format", "ISO8601 format"], + ) + @pytest.mark.parametrize( + "utc, args, expected", + [ + pytest.param( + True, + ["2000-01-01 01:00:00-08:00", "2000-01-01 02:00:00-08:00"], + DatetimeIndex( + ["2000-01-01 09:00:00+00:00", "2000-01-01 10:00:00+00:00"], + dtype="datetime64[ns, UTC]", + ), + id="all tz-aware, with utc", + ), + pytest.param( + False, + ["2000-01-01 01:00:00+00:00", "2000-01-01 02:00:00+00:00"], + DatetimeIndex( + ["2000-01-01 01:00:00+00:00", "2000-01-01 02:00:00+00:00"], + ), + id="all tz-aware, without utc", + ), + pytest.param( + True, + ["2000-01-01 01:00:00-08:00", "2000-01-01 02:00:00+00:00"], + DatetimeIndex( + ["2000-01-01 09:00:00+00:00", "2000-01-01 02:00:00+00:00"], + dtype="datetime64[ns, UTC]", + ), + id="all tz-aware, mixed offsets, with utc", + ), + pytest.param( + True, + ["2000-01-01 01:00:00", "2000-01-01 02:00:00+00:00"], + DatetimeIndex( + ["2000-01-01 01:00:00+00:00", "2000-01-01 02:00:00+00:00"], + dtype="datetime64[ns, UTC]", + ), + id="tz-aware string, naive pydatetime, with utc", + ), + ], + ) + @pytest.mark.parametrize( + "constructor", + [Timestamp, lambda x: Timestamp(x).to_pydatetime()], + ) + def test_to_datetime_mixed_datetime_and_string_with_format( + self, fmt, utc, args, expected, constructor + ): + # https://github.com/pandas-dev/pandas/issues/49298 + # https://github.com/pandas-dev/pandas/issues/50254 + # note: ISO8601 formats go down a fastpath, so we need to check both + # a ISO8601 format and a non-ISO8601 one + ts1 = constructor(args[0]) + ts2 = args[1] + result = to_datetime([ts1, ts2], format=fmt, utc=utc) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "fmt", + ["%Y-%d-%m %H:%M:%S%z", "%Y-%m-%d %H:%M:%S%z"], + ids=["non-ISO8601 format", "ISO8601 format"], + ) + @pytest.mark.parametrize( + "constructor", + [Timestamp, lambda x: Timestamp(x).to_pydatetime()], + ) + def test_to_datetime_mixed_datetime_and_string_with_format_mixed_offsets_utc_false( + self, fmt, constructor + ): + # https://github.com/pandas-dev/pandas/issues/49298 + # https://github.com/pandas-dev/pandas/issues/50254 + # note: ISO8601 formats go down a fastpath, so we need to check both + # a ISO8601 format and a non-ISO8601 one + args = ["2000-01-01 01:00:00", "2000-01-01 02:00:00+00:00"] + ts1 = constructor(args[0]) + ts2 = args[1] + msg = "parsing datetimes with mixed time zones will raise an error" + + expected = Index( + [ + Timestamp("2000-01-01 01:00:00"), + Timestamp("2000-01-01 02:00:00+0000", tz="UTC"), + ], + ) + with tm.assert_produces_warning(FutureWarning, match=msg): + result = to_datetime([ts1, ts2], format=fmt, utc=False) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "fmt, expected", + [ + pytest.param( + "%Y-%m-%d %H:%M:%S%z", + Index( + [ + Timestamp("2000-01-01 09:00:00+0100", tz="UTC+01:00"), + Timestamp("2000-01-02 02:00:00+0200", tz="UTC+02:00"), + NaT, + ] + ), + id="ISO8601, non-UTC", + ), + pytest.param( + "%Y-%d-%m %H:%M:%S%z", + Index( + [ + Timestamp("2000-01-01 09:00:00+0100", tz="UTC+01:00"), + Timestamp("2000-02-01 02:00:00+0200", tz="UTC+02:00"), + NaT, + ] + ), + id="non-ISO8601, non-UTC", + ), + ], + ) + def test_to_datetime_mixed_offsets_with_none_tz(self, fmt, expected): + # https://github.com/pandas-dev/pandas/issues/50071 + msg = "parsing datetimes with mixed time zones will raise an error" + + with tm.assert_produces_warning(FutureWarning, match=msg): + result = to_datetime( + ["2000-01-01 09:00:00+01:00", "2000-01-02 02:00:00+02:00", None], + format=fmt, + utc=False, + ) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "fmt, expected", + [ + pytest.param( + "%Y-%m-%d %H:%M:%S%z", + DatetimeIndex( + ["2000-01-01 08:00:00+00:00", "2000-01-02 00:00:00+00:00", "NaT"], + dtype="datetime64[ns, UTC]", + ), + id="ISO8601, UTC", + ), + pytest.param( + "%Y-%d-%m %H:%M:%S%z", + DatetimeIndex( + ["2000-01-01 08:00:00+00:00", "2000-02-01 00:00:00+00:00", "NaT"], + dtype="datetime64[ns, UTC]", + ), + id="non-ISO8601, UTC", + ), + ], + ) + def test_to_datetime_mixed_offsets_with_none(self, fmt, expected): + # https://github.com/pandas-dev/pandas/issues/50071 + result = to_datetime( + ["2000-01-01 09:00:00+01:00", "2000-01-02 02:00:00+02:00", None], + format=fmt, + utc=True, + ) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "fmt", + ["%Y-%d-%m %H:%M:%S%z", "%Y-%m-%d %H:%M:%S%z"], + ids=["non-ISO8601 format", "ISO8601 format"], + ) + @pytest.mark.parametrize( + "args", + [ + pytest.param( + ["2000-01-01 01:00:00-08:00", "2000-01-01 02:00:00-07:00"], + id="all tz-aware, mixed timezones, without utc", + ), + ], + ) + @pytest.mark.parametrize( + "constructor", + [Timestamp, lambda x: Timestamp(x).to_pydatetime()], + ) + def test_to_datetime_mixed_datetime_and_string_with_format_raises( + self, fmt, args, constructor + ): + # https://github.com/pandas-dev/pandas/issues/49298 + # note: ISO8601 formats go down a fastpath, so we need to check both + # a ISO8601 format and a non-ISO8601 one + ts1 = constructor(args[0]) + ts2 = constructor(args[1]) + with pytest.raises( + ValueError, match="cannot be converted to datetime64 unless utc=True" + ): + to_datetime([ts1, ts2], format=fmt, utc=False) + + def test_to_datetime_np_str(self): + # GH#32264 + # GH#48969 + value = np.str_("2019-02-04 10:18:46.297000+0000") + + ser = Series([value]) + + exp = Timestamp("2019-02-04 10:18:46.297000", tz="UTC") + + assert to_datetime(value) == exp + assert to_datetime(ser.iloc[0]) == exp + + res = to_datetime([value]) + expected = Index([exp]) + tm.assert_index_equal(res, expected) + + res = to_datetime(ser) + expected = Series(expected) + tm.assert_series_equal(res, expected) + + @pytest.mark.parametrize( + "s, _format, dt", + [ + ["2015-1-1", "%G-%V-%u", datetime(2014, 12, 29, 0, 0)], + ["2015-1-4", "%G-%V-%u", datetime(2015, 1, 1, 0, 0)], + ["2015-1-7", "%G-%V-%u", datetime(2015, 1, 4, 0, 0)], + ], + ) + def test_to_datetime_iso_week_year_format(self, s, _format, dt): + # See GH#16607 + assert to_datetime(s, format=_format) == dt + + @pytest.mark.parametrize( + "msg, s, _format", + [ + [ + "ISO week directive '%V' is incompatible with the year directive " + "'%Y'. Use the ISO year '%G' instead.", + "1999 50", + "%Y %V", + ], + [ + "ISO year directive '%G' must be used with the ISO week directive " + "'%V' and a weekday directive '%A', '%a', '%w', or '%u'.", + "1999 51", + "%G %V", + ], + [ + "ISO year directive '%G' must be used with the ISO week directive " + "'%V' and a weekday directive '%A', '%a', '%w', or '%u'.", + "1999 Monday", + "%G %A", + ], + [ + "ISO year directive '%G' must be used with the ISO week directive " + "'%V' and a weekday directive '%A', '%a', '%w', or '%u'.", + "1999 Mon", + "%G %a", + ], + [ + "ISO year directive '%G' must be used with the ISO week directive " + "'%V' and a weekday directive '%A', '%a', '%w', or '%u'.", + "1999 6", + "%G %w", + ], + [ + "ISO year directive '%G' must be used with the ISO week directive " + "'%V' and a weekday directive '%A', '%a', '%w', or '%u'.", + "1999 6", + "%G %u", + ], + [ + "ISO year directive '%G' must be used with the ISO week directive " + "'%V' and a weekday directive '%A', '%a', '%w', or '%u'.", + "2051", + "%G", + ], + [ + "Day of the year directive '%j' is not compatible with ISO year " + "directive '%G'. Use '%Y' instead.", + "1999 51 6 256", + "%G %V %u %j", + ], + [ + "ISO week directive '%V' is incompatible with the year directive " + "'%Y'. Use the ISO year '%G' instead.", + "1999 51 Sunday", + "%Y %V %A", + ], + [ + "ISO week directive '%V' is incompatible with the year directive " + "'%Y'. Use the ISO year '%G' instead.", + "1999 51 Sun", + "%Y %V %a", + ], + [ + "ISO week directive '%V' is incompatible with the year directive " + "'%Y'. Use the ISO year '%G' instead.", + "1999 51 1", + "%Y %V %w", + ], + [ + "ISO week directive '%V' is incompatible with the year directive " + "'%Y'. Use the ISO year '%G' instead.", + "1999 51 1", + "%Y %V %u", + ], + [ + "ISO week directive '%V' must be used with the ISO year directive " + "'%G' and a weekday directive '%A', '%a', '%w', or '%u'.", + "20", + "%V", + ], + [ + "ISO week directive '%V' must be used with the ISO year directive " + "'%G' and a weekday directive '%A', '%a', '%w', or '%u'.", + "1999 51 Sunday", + "%V %A", + ], + [ + "ISO week directive '%V' must be used with the ISO year directive " + "'%G' and a weekday directive '%A', '%a', '%w', or '%u'.", + "1999 51 Sun", + "%V %a", + ], + [ + "ISO week directive '%V' must be used with the ISO year directive " + "'%G' and a weekday directive '%A', '%a', '%w', or '%u'.", + "1999 51 1", + "%V %w", + ], + [ + "ISO week directive '%V' must be used with the ISO year directive " + "'%G' and a weekday directive '%A', '%a', '%w', or '%u'.", + "1999 51 1", + "%V %u", + ], + [ + "Day of the year directive '%j' is not compatible with ISO year " + "directive '%G'. Use '%Y' instead.", + "1999 50", + "%G %j", + ], + [ + "ISO week directive '%V' must be used with the ISO year directive " + "'%G' and a weekday directive '%A', '%a', '%w', or '%u'.", + "20 Monday", + "%V %A", + ], + ], + ) + @pytest.mark.parametrize("errors", ["raise", "coerce", "ignore"]) + def test_error_iso_week_year(self, msg, s, _format, errors): + # See GH#16607, GH#50308 + # This test checks for errors thrown when giving the wrong format + # However, as discussed on PR#25541, overriding the locale + # causes a different error to be thrown due to the format being + # locale specific, but the test data is in english. + # Therefore, the tests only run when locale is not overwritten, + # as a sort of solution to this problem. + if locale.getlocale() != ("zh_CN", "UTF-8") and locale.getlocale() != ( + "it_IT", + "UTF-8", + ): + with pytest.raises(ValueError, match=msg): + to_datetime(s, format=_format, errors=errors) + + @pytest.mark.parametrize("tz", [None, "US/Central"]) + def test_to_datetime_dtarr(self, tz): + # DatetimeArray + dti = date_range("1965-04-03", periods=19, freq="2W", tz=tz) + arr = dti._data + + result = to_datetime(arr) + assert result is arr + + # Doesn't work on Windows since tzpath not set correctly + @td.skip_if_windows + @pytest.mark.parametrize("arg_class", [Series, Index]) + @pytest.mark.parametrize("utc", [True, False]) + @pytest.mark.parametrize("tz", [None, "US/Central"]) + def test_to_datetime_arrow(self, tz, utc, arg_class): + pa = pytest.importorskip("pyarrow") + + dti = date_range("1965-04-03", periods=19, freq="2W", tz=tz) + dti = arg_class(dti) + + dti_arrow = dti.astype(pd.ArrowDtype(pa.timestamp(unit="ns", tz=tz))) + + result = to_datetime(dti_arrow, utc=utc) + expected = to_datetime(dti, utc=utc).astype( + pd.ArrowDtype(pa.timestamp(unit="ns", tz=tz if not utc else "UTC")) + ) + if not utc and arg_class is not Series: + # Doesn't hold for utc=True, since that will astype + # to_datetime also returns a new object for series + assert result is dti_arrow + if arg_class is Series: + tm.assert_series_equal(result, expected) + else: + tm.assert_index_equal(result, expected) + + def test_to_datetime_pydatetime(self): + actual = to_datetime(datetime(2008, 1, 15)) + assert actual == datetime(2008, 1, 15) + + def test_to_datetime_YYYYMMDD(self): + actual = to_datetime("20080115") + assert actual == datetime(2008, 1, 15) + + def test_to_datetime_unparsable_ignore(self): + # unparsable + ser = "Month 1, 1999" + assert to_datetime(ser, errors="ignore") == ser + + @td.skip_if_windows # `tm.set_timezone` does not work in windows + def test_to_datetime_now(self): + # See GH#18666 + with tm.set_timezone("US/Eastern"): + # GH#18705 + now = Timestamp("now").as_unit("ns") + pdnow = to_datetime("now") + pdnow2 = to_datetime(["now"])[0] + + # These should all be equal with infinite perf; this gives + # a generous margin of 10 seconds + assert abs(pdnow._value - now._value) < 1e10 + assert abs(pdnow2._value - now._value) < 1e10 + + assert pdnow.tzinfo is None + assert pdnow2.tzinfo is None + + @td.skip_if_windows # `tm.set_timezone` does not work in windows + @pytest.mark.parametrize("tz", ["Pacific/Auckland", "US/Samoa"]) + def test_to_datetime_today(self, tz): + # See GH#18666 + # Test with one timezone far ahead of UTC and another far behind, so + # one of these will _almost_ always be in a different day from UTC. + # Unfortunately this test between 12 and 1 AM Samoa time + # this both of these timezones _and_ UTC will all be in the same day, + # so this test will not detect the regression introduced in #18666. + with tm.set_timezone(tz): + nptoday = np.datetime64("today").astype("datetime64[ns]").astype(np.int64) + pdtoday = to_datetime("today") + pdtoday2 = to_datetime(["today"])[0] + + tstoday = Timestamp("today").as_unit("ns") + tstoday2 = Timestamp.today().as_unit("ns") + + # These should all be equal with infinite perf; this gives + # a generous margin of 10 seconds + assert abs(pdtoday.normalize()._value - nptoday) < 1e10 + assert abs(pdtoday2.normalize()._value - nptoday) < 1e10 + assert abs(pdtoday._value - tstoday._value) < 1e10 + assert abs(pdtoday._value - tstoday2._value) < 1e10 + + assert pdtoday.tzinfo is None + assert pdtoday2.tzinfo is None + + @pytest.mark.parametrize("arg", ["now", "today"]) + def test_to_datetime_today_now_unicode_bytes(self, arg): + to_datetime([arg]) + + @pytest.mark.parametrize( + "format, expected_ds", + [ + ("%Y-%m-%d %H:%M:%S%z", "2020-01-03"), + ("%Y-%d-%m %H:%M:%S%z", "2020-03-01"), + (None, "2020-01-03"), + ], + ) + @pytest.mark.parametrize( + "string, attribute", + [ + ("now", "utcnow"), + ("today", "today"), + ], + ) + def test_to_datetime_now_with_format(self, format, expected_ds, string, attribute): + # https://github.com/pandas-dev/pandas/issues/50359 + result = to_datetime(["2020-01-03 00:00:00Z", string], format=format, utc=True) + expected = DatetimeIndex( + [expected_ds, getattr(Timestamp, attribute)()], dtype="datetime64[ns, UTC]" + ) + assert (expected - result).max().total_seconds() < 1 + + @pytest.mark.parametrize( + "dt", [np.datetime64("2000-01-01"), np.datetime64("2000-01-02")] + ) + def test_to_datetime_dt64s(self, cache, dt): + assert to_datetime(dt, cache=cache) == Timestamp(dt) + + @pytest.mark.parametrize( + "arg, format", + [ + ("2001-01-01", "%Y-%m-%d"), + ("01-01-2001", "%d-%m-%Y"), + ], + ) + def test_to_datetime_dt64s_and_str(self, arg, format): + # https://github.com/pandas-dev/pandas/issues/50036 + result = to_datetime([arg, np.datetime64("2020-01-01")], format=format) + expected = DatetimeIndex(["2001-01-01", "2020-01-01"]) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "dt", [np.datetime64("1000-01-01"), np.datetime64("5000-01-02")] + ) + @pytest.mark.parametrize("errors", ["raise", "ignore", "coerce"]) + def test_to_datetime_dt64s_out_of_ns_bounds(self, cache, dt, errors): + # GH#50369 We cast to the nearest supported reso, i.e. "s" + ts = to_datetime(dt, errors=errors, cache=cache) + assert isinstance(ts, Timestamp) + assert ts.unit == "s" + assert ts.asm8 == dt + + ts = Timestamp(dt) + assert ts.unit == "s" + assert ts.asm8 == dt + + @pytest.mark.skip_ubsan + def test_to_datetime_dt64d_out_of_bounds(self, cache): + dt64 = np.datetime64(np.iinfo(np.int64).max, "D") + + msg = "Out of bounds second timestamp: 25252734927768524-07-27" + with pytest.raises(OutOfBoundsDatetime, match=msg): + Timestamp(dt64) + with pytest.raises(OutOfBoundsDatetime, match=msg): + to_datetime(dt64, errors="raise", cache=cache) + + assert to_datetime(dt64, errors="coerce", cache=cache) is NaT + + @pytest.mark.parametrize("unit", ["s", "D"]) + def test_to_datetime_array_of_dt64s(self, cache, unit): + # https://github.com/pandas-dev/pandas/issues/31491 + # Need at least 50 to ensure cache is used. + dts = [ + np.datetime64("2000-01-01", unit), + np.datetime64("2000-01-02", unit), + ] * 30 + # Assuming all datetimes are in bounds, to_datetime() returns + # an array that is equal to Timestamp() parsing + result = to_datetime(dts, cache=cache) + if cache: + # FIXME: behavior should not depend on cache + expected = DatetimeIndex([Timestamp(x).asm8 for x in dts], dtype="M8[s]") + else: + expected = DatetimeIndex([Timestamp(x).asm8 for x in dts], dtype="M8[ns]") + + tm.assert_index_equal(result, expected) + + # A list of datetimes where the last one is out of bounds + dts_with_oob = dts + [np.datetime64("9999-01-01")] + + # As of GH#51978 we do not raise in this case + to_datetime(dts_with_oob, errors="raise") + + result = to_datetime(dts_with_oob, errors="coerce", cache=cache) + if not cache: + # FIXME: shouldn't depend on cache! + expected = DatetimeIndex( + [Timestamp(dts_with_oob[0]).asm8, Timestamp(dts_with_oob[1]).asm8] * 30 + + [NaT], + ) + else: + expected = DatetimeIndex(np.array(dts_with_oob, dtype="M8[s]")) + tm.assert_index_equal(result, expected) + + # With errors='ignore', out of bounds datetime64s + # are converted to their .item(), which depending on the version of + # numpy is either a python datetime.datetime or datetime.date + result = to_datetime(dts_with_oob, errors="ignore", cache=cache) + if not cache: + # FIXME: shouldn't depend on cache! + expected = Index(dts_with_oob) + tm.assert_index_equal(result, expected) + + def test_out_of_bounds_errors_ignore(self): + # https://github.com/pandas-dev/pandas/issues/50587 + result = to_datetime(np.datetime64("9999-01-01"), errors="ignore") + expected = np.datetime64("9999-01-01") + assert result == expected + + def test_out_of_bounds_errors_ignore2(self): + # GH#12424 + msg = "errors='ignore' is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + res = to_datetime( + Series(["2362-01-01", np.nan], dtype=object), errors="ignore" + ) + exp = Series(["2362-01-01", np.nan], dtype=object) + tm.assert_series_equal(res, exp) + + def test_to_datetime_tz(self, cache): + # xref 8260 + # uniform returns a DatetimeIndex + arr = [ + Timestamp("2013-01-01 13:00:00-0800", tz="US/Pacific"), + Timestamp("2013-01-02 14:00:00-0800", tz="US/Pacific"), + ] + result = to_datetime(arr, cache=cache) + expected = DatetimeIndex( + ["2013-01-01 13:00:00", "2013-01-02 14:00:00"], tz="US/Pacific" + ) + tm.assert_index_equal(result, expected) + + def test_to_datetime_tz_mixed(self, cache): + # mixed tzs will raise if errors='raise' + # https://github.com/pandas-dev/pandas/issues/50585 + arr = [ + Timestamp("2013-01-01 13:00:00", tz="US/Pacific"), + Timestamp("2013-01-02 14:00:00", tz="US/Eastern"), + ] + msg = ( + "Tz-aware datetime.datetime cannot be " + "converted to datetime64 unless utc=True" + ) + with pytest.raises(ValueError, match=msg): + to_datetime(arr, cache=cache) + + depr_msg = "errors='ignore' is deprecated" + with tm.assert_produces_warning(FutureWarning, match=depr_msg): + result = to_datetime(arr, cache=cache, errors="ignore") + expected = Index( + [ + Timestamp("2013-01-01 13:00:00-08:00"), + Timestamp("2013-01-02 14:00:00-05:00"), + ], + dtype="object", + ) + tm.assert_index_equal(result, expected) + result = to_datetime(arr, cache=cache, errors="coerce") + expected = DatetimeIndex( + ["2013-01-01 13:00:00-08:00", "NaT"], dtype="datetime64[ns, US/Pacific]" + ) + tm.assert_index_equal(result, expected) + + def test_to_datetime_different_offsets(self, cache): + # inspired by asv timeseries.ToDatetimeNONISO8601 benchmark + # see GH-26097 for more + ts_string_1 = "March 1, 2018 12:00:00+0400" + ts_string_2 = "March 1, 2018 12:00:00+0500" + arr = [ts_string_1] * 5 + [ts_string_2] * 5 + expected = Index([parse(x) for x in arr]) + msg = "parsing datetimes with mixed time zones will raise an error" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = to_datetime(arr, cache=cache) + tm.assert_index_equal(result, expected) + + def test_to_datetime_tz_pytz(self, cache): + # see gh-8260 + us_eastern = pytz.timezone("US/Eastern") + arr = np.array( + [ + us_eastern.localize( + datetime(year=2000, month=1, day=1, hour=3, minute=0) + ), + us_eastern.localize( + datetime(year=2000, month=6, day=1, hour=3, minute=0) + ), + ], + dtype=object, + ) + result = to_datetime(arr, utc=True, cache=cache) + expected = DatetimeIndex( + ["2000-01-01 08:00:00+00:00", "2000-06-01 07:00:00+00:00"], + dtype="datetime64[ns, UTC]", + freq=None, + ) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "init_constructor, end_constructor", + [ + (Index, DatetimeIndex), + (list, DatetimeIndex), + (np.array, DatetimeIndex), + (Series, Series), + ], + ) + def test_to_datetime_utc_true(self, cache, init_constructor, end_constructor): + # See gh-11934 & gh-6415 + data = ["20100102 121314", "20100102 121315"] + expected_data = [ + Timestamp("2010-01-02 12:13:14", tz="utc"), + Timestamp("2010-01-02 12:13:15", tz="utc"), + ] + + result = to_datetime( + init_constructor(data), format="%Y%m%d %H%M%S", utc=True, cache=cache + ) + expected = end_constructor(expected_data) + tm.assert_equal(result, expected) + + @pytest.mark.parametrize( + "scalar, expected", + [ + ["20100102 121314", Timestamp("2010-01-02 12:13:14", tz="utc")], + ["20100102 121315", Timestamp("2010-01-02 12:13:15", tz="utc")], + ], + ) + def test_to_datetime_utc_true_scalar(self, cache, scalar, expected): + # Test scalar case as well + result = to_datetime(scalar, format="%Y%m%d %H%M%S", utc=True, cache=cache) + assert result == expected + + def test_to_datetime_utc_true_with_series_single_value(self, cache): + # GH 15760 UTC=True with Series + ts = 1.5e18 + result = to_datetime(Series([ts]), utc=True, cache=cache) + expected = Series([Timestamp(ts, tz="utc")]) + tm.assert_series_equal(result, expected) + + def test_to_datetime_utc_true_with_series_tzaware_string(self, cache): + ts = "2013-01-01 00:00:00-01:00" + expected_ts = "2013-01-01 01:00:00" + data = Series([ts] * 3) + result = to_datetime(data, utc=True, cache=cache) + expected = Series([Timestamp(expected_ts, tz="utc")] * 3) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "date, dtype", + [ + ("2013-01-01 01:00:00", "datetime64[ns]"), + ("2013-01-01 01:00:00", "datetime64[ns, UTC]"), + ], + ) + def test_to_datetime_utc_true_with_series_datetime_ns(self, cache, date, dtype): + expected = Series( + [Timestamp("2013-01-01 01:00:00", tz="UTC")], dtype="M8[ns, UTC]" + ) + result = to_datetime(Series([date], dtype=dtype), utc=True, cache=cache) + tm.assert_series_equal(result, expected) + + def test_to_datetime_tz_psycopg2(self, request, cache): + # xref 8260 + psycopg2_tz = pytest.importorskip("psycopg2.tz") + + # misc cases + tz1 = psycopg2_tz.FixedOffsetTimezone(offset=-300, name=None) + tz2 = psycopg2_tz.FixedOffsetTimezone(offset=-240, name=None) + arr = np.array( + [ + datetime(2000, 1, 1, 3, 0, tzinfo=tz1), + datetime(2000, 6, 1, 3, 0, tzinfo=tz2), + ], + dtype=object, + ) + + result = to_datetime(arr, errors="coerce", utc=True, cache=cache) + expected = DatetimeIndex( + ["2000-01-01 08:00:00+00:00", "2000-06-01 07:00:00+00:00"], + dtype="datetime64[ns, UTC]", + freq=None, + ) + tm.assert_index_equal(result, expected) + + # dtype coercion + i = DatetimeIndex( + ["2000-01-01 08:00:00"], + tz=psycopg2_tz.FixedOffsetTimezone(offset=-300, name=None), + ) + assert is_datetime64_ns_dtype(i) + + # tz coercion + result = to_datetime(i, errors="coerce", cache=cache) + tm.assert_index_equal(result, i) + + result = to_datetime(i, errors="coerce", utc=True, cache=cache) + expected = DatetimeIndex(["2000-01-01 13:00:00"], dtype="datetime64[ns, UTC]") + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize("arg", [True, False]) + def test_datetime_bool(self, cache, arg): + # GH13176 + msg = r"dtype bool cannot be converted to datetime64\[ns\]" + with pytest.raises(TypeError, match=msg): + to_datetime(arg) + assert to_datetime(arg, errors="coerce", cache=cache) is NaT + assert to_datetime(arg, errors="ignore", cache=cache) is arg + + def test_datetime_bool_arrays_mixed(self, cache): + msg = f"{type(cache)} is not convertible to datetime" + with pytest.raises(TypeError, match=msg): + to_datetime([False, datetime.today()], cache=cache) + with pytest.raises( + ValueError, + match=( + r'^time data "True" doesn\'t match format "%Y%m%d", ' + f"at position 1. {PARSING_ERR_MSG}$" + ), + ): + to_datetime(["20130101", True], cache=cache) + tm.assert_index_equal( + to_datetime([0, False, NaT, 0.0], errors="coerce", cache=cache), + DatetimeIndex( + [to_datetime(0, cache=cache), NaT, NaT, to_datetime(0, cache=cache)] + ), + ) + + @pytest.mark.parametrize("arg", [bool, to_datetime]) + def test_datetime_invalid_datatype(self, arg): + # GH13176 + msg = "is not convertible to datetime" + with pytest.raises(TypeError, match=msg): + to_datetime(arg) + + @pytest.mark.parametrize("errors", ["coerce", "raise", "ignore"]) + def test_invalid_format_raises(self, errors): + # https://github.com/pandas-dev/pandas/issues/50255 + with pytest.raises( + ValueError, match="':' is a bad directive in format 'H%:M%:S%" + ): + to_datetime(["00:00:00"], format="H%:M%:S%", errors=errors) + + @pytest.mark.parametrize("value", ["a", "00:01:99"]) + @pytest.mark.parametrize("format", [None, "%H:%M:%S"]) + def test_datetime_invalid_scalar(self, value, format): + # GH24763 + res = to_datetime(value, errors="ignore", format=format) + assert res == value + + res = to_datetime(value, errors="coerce", format=format) + assert res is NaT + + msg = "|".join( + [ + r'^time data "a" doesn\'t match format "%H:%M:%S", at position 0. ' + f"{PARSING_ERR_MSG}$", + r'^Given date string "a" not likely a datetime, at position 0$', + r'^unconverted data remains when parsing with format "%H:%M:%S": "9", ' + f"at position 0. {PARSING_ERR_MSG}$", + r"^second must be in 0..59: 00:01:99, at position 0$", + ] + ) + with pytest.raises(ValueError, match=msg): + to_datetime(value, errors="raise", format=format) + + @pytest.mark.parametrize("value", ["3000/12/11 00:00:00"]) + @pytest.mark.parametrize("format", [None, "%H:%M:%S"]) + def test_datetime_outofbounds_scalar(self, value, format): + # GH24763 + res = to_datetime(value, errors="ignore", format=format) + assert res == value + + res = to_datetime(value, errors="coerce", format=format) + assert res is NaT + + if format is not None: + msg = r'^time data ".*" doesn\'t match format ".*", at position 0.' + with pytest.raises(ValueError, match=msg): + to_datetime(value, errors="raise", format=format) + else: + msg = "^Out of bounds .*, at position 0$" + with pytest.raises(OutOfBoundsDatetime, match=msg): + to_datetime(value, errors="raise", format=format) + + @pytest.mark.parametrize( + ("values"), [(["a"]), (["00:01:99"]), (["a", "b", "99:00:00"])] + ) + @pytest.mark.parametrize("format", [(None), ("%H:%M:%S")]) + def test_datetime_invalid_index(self, values, format): + # GH24763 + # Not great to have logic in tests, but this one's hard to + # parametrise over + if format is None and len(values) > 1: + warn = UserWarning + else: + warn = None + with tm.assert_produces_warning( + warn, match="Could not infer format", raise_on_extra_warnings=False + ): + res = to_datetime(values, errors="ignore", format=format) + tm.assert_index_equal(res, Index(values, dtype=object)) + + with tm.assert_produces_warning( + warn, match="Could not infer format", raise_on_extra_warnings=False + ): + res = to_datetime(values, errors="coerce", format=format) + tm.assert_index_equal(res, DatetimeIndex([NaT] * len(values))) + + msg = "|".join( + [ + r'^Given date string "a" not likely a datetime, at position 0$', + r'^time data "a" doesn\'t match format "%H:%M:%S", at position 0. ' + f"{PARSING_ERR_MSG}$", + r'^unconverted data remains when parsing with format "%H:%M:%S": "9", ' + f"at position 0. {PARSING_ERR_MSG}$", + r"^second must be in 0..59: 00:01:99, at position 0$", + ] + ) + with pytest.raises(ValueError, match=msg): + with tm.assert_produces_warning( + warn, match="Could not infer format", raise_on_extra_warnings=False + ): + to_datetime(values, errors="raise", format=format) + + @pytest.mark.parametrize("utc", [True, None]) + @pytest.mark.parametrize("format", ["%Y%m%d %H:%M:%S", None]) + @pytest.mark.parametrize("constructor", [list, tuple, np.array, Index, deque]) + def test_to_datetime_cache(self, utc, format, constructor): + date = "20130101 00:00:00" + test_dates = [date] * 10**5 + data = constructor(test_dates) + + result = to_datetime(data, utc=utc, format=format, cache=True) + expected = to_datetime(data, utc=utc, format=format, cache=False) + + tm.assert_index_equal(result, expected) + + def test_to_datetime_from_deque(self): + # GH 29403 + result = to_datetime(deque([Timestamp("2010-06-02 09:30:00")] * 51)) + expected = to_datetime([Timestamp("2010-06-02 09:30:00")] * 51) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize("utc", [True, None]) + @pytest.mark.parametrize("format", ["%Y%m%d %H:%M:%S", None]) + def test_to_datetime_cache_series(self, utc, format): + date = "20130101 00:00:00" + test_dates = [date] * 10**5 + data = Series(test_dates) + result = to_datetime(data, utc=utc, format=format, cache=True) + expected = to_datetime(data, utc=utc, format=format, cache=False) + tm.assert_series_equal(result, expected) + + def test_to_datetime_cache_scalar(self): + date = "20130101 00:00:00" + result = to_datetime(date, cache=True) + expected = Timestamp("20130101 00:00:00") + assert result == expected + + @pytest.mark.parametrize( + "datetimelikes,expected_values", + ( + ( + (None, np.nan) + (NaT,) * start_caching_at, + (NaT,) * (start_caching_at + 2), + ), + ( + (None, Timestamp("2012-07-26")) + (NaT,) * start_caching_at, + (NaT, Timestamp("2012-07-26")) + (NaT,) * start_caching_at, + ), + ( + (None,) + + (NaT,) * start_caching_at + + ("2012 July 26", Timestamp("2012-07-26")), + (NaT,) * (start_caching_at + 1) + + (Timestamp("2012-07-26"), Timestamp("2012-07-26")), + ), + ), + ) + def test_convert_object_to_datetime_with_cache( + self, datetimelikes, expected_values + ): + # GH#39882 + ser = Series( + datetimelikes, + dtype="object", + ) + result_series = to_datetime(ser, errors="coerce") + expected_series = Series( + expected_values, + dtype="datetime64[ns]", + ) + tm.assert_series_equal(result_series, expected_series) + + @pytest.mark.parametrize("cache", [True, False]) + @pytest.mark.parametrize( + "input", + [ + Series([NaT] * 20 + [None] * 20, dtype="object"), + Series([NaT] * 60 + [None] * 60, dtype="object"), + Series([None] * 20), + Series([None] * 60), + Series([""] * 20), + Series([""] * 60), + Series([pd.NA] * 20), + Series([pd.NA] * 60), + Series([np.nan] * 20), + Series([np.nan] * 60), + ], + ) + def test_to_datetime_converts_null_like_to_nat(self, cache, input): + # GH35888 + expected = Series([NaT] * len(input), dtype="M8[ns]") + result = to_datetime(input, cache=cache) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "date, format", + [ + ("2017-20", "%Y-%W"), + ("20 Sunday", "%W %A"), + ("20 Sun", "%W %a"), + ("2017-21", "%Y-%U"), + ("20 Sunday", "%U %A"), + ("20 Sun", "%U %a"), + ], + ) + def test_week_without_day_and_calendar_year(self, date, format): + # GH16774 + + msg = "Cannot use '%W' or '%U' without day and year" + with pytest.raises(ValueError, match=msg): + to_datetime(date, format=format) + + def test_to_datetime_coerce(self): + # GH 26122 + ts_strings = [ + "March 1, 2018 12:00:00+0400", + "March 1, 2018 12:00:00+0500", + "20100240", + ] + msg = "parsing datetimes with mixed time zones will raise an error" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = to_datetime(ts_strings, errors="coerce") + expected = Index( + [ + datetime(2018, 3, 1, 12, 0, tzinfo=tzoffset(None, 14400)), + datetime(2018, 3, 1, 12, 0, tzinfo=tzoffset(None, 18000)), + NaT, + ] + ) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "string_arg, format", + [("March 1, 2018", "%B %d, %Y"), ("2018-03-01", "%Y-%m-%d")], + ) + @pytest.mark.parametrize( + "outofbounds", + [ + datetime(9999, 1, 1), + date(9999, 1, 1), + np.datetime64("9999-01-01"), + "January 1, 9999", + "9999-01-01", + ], + ) + def test_to_datetime_coerce_oob(self, string_arg, format, outofbounds): + # https://github.com/pandas-dev/pandas/issues/50255 + ts_strings = [string_arg, outofbounds] + result = to_datetime(ts_strings, errors="coerce", format=format) + expected = DatetimeIndex([datetime(2018, 3, 1), NaT]) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "errors, expected", + [ + ("coerce", Index([NaT, NaT])), + ("ignore", Index(["200622-12-31", "111111-24-11"], dtype=object)), + ], + ) + def test_to_datetime_malformed_no_raise(self, errors, expected): + # GH 28299 + # GH 48633 + ts_strings = ["200622-12-31", "111111-24-11"] + with tm.assert_produces_warning( + UserWarning, match="Could not infer format", raise_on_extra_warnings=False + ): + result = to_datetime(ts_strings, errors=errors) + tm.assert_index_equal(result, expected) + + def test_to_datetime_malformed_raise(self): + # GH 48633 + ts_strings = ["200622-12-31", "111111-24-11"] + msg = ( + 'Parsed string "200622-12-31" gives an invalid tzoffset, which must ' + r"be between -timedelta\(hours=24\) and timedelta\(hours=24\), " + "at position 0" + ) + with pytest.raises( + ValueError, + match=msg, + ): + with tm.assert_produces_warning( + UserWarning, match="Could not infer format" + ): + to_datetime( + ts_strings, + errors="raise", + ) + + def test_iso_8601_strings_with_same_offset(self): + # GH 17697, 11736 + ts_str = "2015-11-18 15:30:00+05:30" + result = to_datetime(ts_str) + expected = Timestamp(ts_str) + assert result == expected + + expected = DatetimeIndex([Timestamp(ts_str)] * 2) + result = to_datetime([ts_str] * 2) + tm.assert_index_equal(result, expected) + + result = DatetimeIndex([ts_str] * 2) + tm.assert_index_equal(result, expected) + + def test_iso_8601_strings_with_different_offsets(self): + # GH 17697, 11736, 50887 + ts_strings = ["2015-11-18 15:30:00+05:30", "2015-11-18 16:30:00+06:30", NaT] + msg = "parsing datetimes with mixed time zones will raise an error" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = to_datetime(ts_strings) + expected = np.array( + [ + datetime(2015, 11, 18, 15, 30, tzinfo=tzoffset(None, 19800)), + datetime(2015, 11, 18, 16, 30, tzinfo=tzoffset(None, 23400)), + NaT, + ], + dtype=object, + ) + # GH 21864 + expected = Index(expected) + tm.assert_index_equal(result, expected) + + def test_iso_8601_strings_with_different_offsets_utc(self): + ts_strings = ["2015-11-18 15:30:00+05:30", "2015-11-18 16:30:00+06:30", NaT] + result = to_datetime(ts_strings, utc=True) + expected = DatetimeIndex( + [Timestamp(2015, 11, 18, 10), Timestamp(2015, 11, 18, 10), NaT], tz="UTC" + ) + tm.assert_index_equal(result, expected) + + def test_mixed_offsets_with_native_datetime_raises(self): + # GH 25978 + + vals = [ + "nan", + Timestamp("1990-01-01"), + "2015-03-14T16:15:14.123-08:00", + "2019-03-04T21:56:32.620-07:00", + None, + "today", + "now", + ] + ser = Series(vals) + assert all(ser[i] is vals[i] for i in range(len(vals))) # GH#40111 + + now = Timestamp("now") + today = Timestamp("today") + msg = "parsing datetimes with mixed time zones will raise an error" + with tm.assert_produces_warning(FutureWarning, match=msg): + mixed = to_datetime(ser) + expected = Series( + [ + "NaT", + Timestamp("1990-01-01"), + Timestamp("2015-03-14T16:15:14.123-08:00").to_pydatetime(), + Timestamp("2019-03-04T21:56:32.620-07:00").to_pydatetime(), + None, + ], + dtype=object, + ) + tm.assert_series_equal(mixed[:-2], expected) + # we'll check mixed[-1] and mixed[-2] match now and today to within + # call-timing tolerances + assert (now - mixed.iloc[-1]).total_seconds() <= 0.1 + assert (today - mixed.iloc[-2]).total_seconds() <= 0.1 + + with pytest.raises(ValueError, match="Tz-aware datetime.datetime"): + to_datetime(mixed) + + def test_non_iso_strings_with_tz_offset(self): + result = to_datetime(["March 1, 2018 12:00:00+0400"] * 2) + expected = DatetimeIndex( + [datetime(2018, 3, 1, 12, tzinfo=timezone(timedelta(minutes=240)))] * 2 + ) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "ts, expected", + [ + (Timestamp("2018-01-01"), Timestamp("2018-01-01", tz="UTC")), + ( + Timestamp("2018-01-01", tz="US/Pacific"), + Timestamp("2018-01-01 08:00", tz="UTC"), + ), + ], + ) + def test_timestamp_utc_true(self, ts, expected): + # GH 24415 + result = to_datetime(ts, utc=True) + assert result == expected + + @pytest.mark.parametrize("dt_str", ["00010101", "13000101", "30000101", "99990101"]) + def test_to_datetime_with_format_out_of_bounds(self, dt_str): + # GH 9107 + msg = "Out of bounds nanosecond timestamp" + with pytest.raises(OutOfBoundsDatetime, match=msg): + to_datetime(dt_str, format="%Y%m%d") + + def test_to_datetime_utc(self): + arr = np.array([parse("2012-06-13T01:39:00Z")], dtype=object) + + result = to_datetime(arr, utc=True) + assert result.tz is timezone.utc + + def test_to_datetime_fixed_offset(self): + from pandas.tests.indexes.datetimes.test_timezones import FixedOffset + + fixed_off = FixedOffset(-420, "-07:00") + + dates = [ + datetime(2000, 1, 1, tzinfo=fixed_off), + datetime(2000, 1, 2, tzinfo=fixed_off), + datetime(2000, 1, 3, tzinfo=fixed_off), + ] + result = to_datetime(dates) + assert result.tz == fixed_off + + @pytest.mark.parametrize( + "date", + [ + ["2020-10-26 00:00:00+06:00", "2020-10-26 00:00:00+01:00"], + ["2020-10-26 00:00:00+06:00", Timestamp("2018-01-01", tz="US/Pacific")], + [ + "2020-10-26 00:00:00+06:00", + datetime(2020, 1, 1, 18, tzinfo=pytz.timezone("Australia/Melbourne")), + ], + ], + ) + def test_to_datetime_mixed_offsets_with_utc_false_deprecated(self, date): + # GH 50887 + msg = "parsing datetimes with mixed time zones will raise an error" + with tm.assert_produces_warning(FutureWarning, match=msg): + to_datetime(date, utc=False) + + +class TestToDatetimeUnit: + @pytest.mark.parametrize("unit", ["Y", "M"]) + @pytest.mark.parametrize("item", [150, float(150)]) + def test_to_datetime_month_or_year_unit_int(self, cache, unit, item, request): + # GH#50870 Note we have separate tests that pd.Timestamp gets these right + ts = Timestamp(item, unit=unit) + expected = DatetimeIndex([ts], dtype="M8[ns]") + + result = to_datetime([item], unit=unit, cache=cache) + tm.assert_index_equal(result, expected) + + result = to_datetime(np.array([item], dtype=object), unit=unit, cache=cache) + tm.assert_index_equal(result, expected) + + result = to_datetime(np.array([item]), unit=unit, cache=cache) + tm.assert_index_equal(result, expected) + + # with a nan! + result = to_datetime(np.array([item, np.nan]), unit=unit, cache=cache) + assert result.isna()[1] + tm.assert_index_equal(result[:1], expected) + + @pytest.mark.parametrize("unit", ["Y", "M"]) + def test_to_datetime_month_or_year_unit_non_round_float(self, cache, unit): + # GH#50301 + # Match Timestamp behavior in disallowing non-round floats with + # Y or M unit + warn_msg = "strings will be parsed as datetime strings" + msg = f"Conversion of non-round float with unit={unit} is ambiguous" + with pytest.raises(ValueError, match=msg): + to_datetime([1.5], unit=unit, errors="raise") + with pytest.raises(ValueError, match=msg): + to_datetime(np.array([1.5]), unit=unit, errors="raise") + with pytest.raises(ValueError, match=msg): + with tm.assert_produces_warning(FutureWarning, match=warn_msg): + to_datetime(["1.5"], unit=unit, errors="raise") + + # with errors="ignore" we also end up raising within the Timestamp + # constructor; this may not be ideal + with pytest.raises(ValueError, match=msg): + to_datetime([1.5], unit=unit, errors="ignore") + + res = to_datetime([1.5], unit=unit, errors="coerce") + expected = Index([NaT], dtype="M8[ns]") + tm.assert_index_equal(res, expected) + + with tm.assert_produces_warning(FutureWarning, match=warn_msg): + res = to_datetime(["1.5"], unit=unit, errors="coerce") + tm.assert_index_equal(res, expected) + + # round floats are OK + res = to_datetime([1.0], unit=unit) + expected = to_datetime([1], unit=unit) + tm.assert_index_equal(res, expected) + + def test_unit(self, cache): + # GH 11758 + # test proper behavior with errors + msg = "cannot specify both format and unit" + with pytest.raises(ValueError, match=msg): + to_datetime([1], unit="D", format="%Y%m%d", cache=cache) + + def test_unit_str(self, cache): + # GH 57051 + # Test that strs aren't dropping precision to 32-bit accidentally. + with tm.assert_produces_warning(FutureWarning): + res = to_datetime(["1704660000"], unit="s", origin="unix") + expected = to_datetime([1704660000], unit="s", origin="unix") + tm.assert_index_equal(res, expected) + + def test_unit_array_mixed_nans(self, cache): + values = [11111111111111111, 1, 1.0, iNaT, NaT, np.nan, "NaT", ""] + result = to_datetime(values, unit="D", errors="ignore", cache=cache) + expected = Index( + [ + 11111111111111111, + Timestamp("1970-01-02"), + Timestamp("1970-01-02"), + NaT, + NaT, + NaT, + NaT, + NaT, + ], + dtype=object, + ) + tm.assert_index_equal(result, expected) + + result = to_datetime(values, unit="D", errors="coerce", cache=cache) + expected = DatetimeIndex( + ["NaT", "1970-01-02", "1970-01-02", "NaT", "NaT", "NaT", "NaT", "NaT"], + dtype="M8[ns]", + ) + tm.assert_index_equal(result, expected) + + msg = "cannot convert input 11111111111111111 with the unit 'D'" + with pytest.raises(OutOfBoundsDatetime, match=msg): + to_datetime(values, unit="D", errors="raise", cache=cache) + + def test_unit_array_mixed_nans_large_int(self, cache): + values = [1420043460000000000000000, iNaT, NaT, np.nan, "NaT"] + + result = to_datetime(values, errors="ignore", unit="s", cache=cache) + expected = Index([1420043460000000000000000, NaT, NaT, NaT, NaT], dtype=object) + tm.assert_index_equal(result, expected) + + result = to_datetime(values, errors="coerce", unit="s", cache=cache) + expected = DatetimeIndex(["NaT", "NaT", "NaT", "NaT", "NaT"], dtype="M8[ns]") + tm.assert_index_equal(result, expected) + + msg = "cannot convert input 1420043460000000000000000 with the unit 's'" + with pytest.raises(OutOfBoundsDatetime, match=msg): + to_datetime(values, errors="raise", unit="s", cache=cache) + + def test_to_datetime_invalid_str_not_out_of_bounds_valuerror(self, cache): + # if we have a string, then we raise a ValueError + # and NOT an OutOfBoundsDatetime + msg = "non convertible value foo with the unit 's'" + with pytest.raises(ValueError, match=msg): + to_datetime("foo", errors="raise", unit="s", cache=cache) + + @pytest.mark.parametrize("error", ["raise", "coerce", "ignore"]) + def test_unit_consistency(self, cache, error): + # consistency of conversions + expected = Timestamp("1970-05-09 14:25:11") + result = to_datetime(11111111, unit="s", errors=error, cache=cache) + assert result == expected + assert isinstance(result, Timestamp) + + @pytest.mark.parametrize("errors", ["ignore", "raise", "coerce"]) + @pytest.mark.parametrize("dtype", ["float64", "int64"]) + def test_unit_with_numeric(self, cache, errors, dtype): + # GH 13180 + # coercions from floats/ints are ok + expected = DatetimeIndex( + ["2015-06-19 05:33:20", "2015-05-27 22:33:20"], dtype="M8[ns]" + ) + arr = np.array([1.434692e18, 1.432766e18]).astype(dtype) + result = to_datetime(arr, errors=errors, cache=cache) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "exp, arr, warning", + [ + [ + ["NaT", "2015-06-19 05:33:20", "2015-05-27 22:33:20"], + ["foo", 1.434692e18, 1.432766e18], + UserWarning, + ], + [ + ["2015-06-19 05:33:20", "2015-05-27 22:33:20", "NaT", "NaT"], + [1.434692e18, 1.432766e18, "foo", "NaT"], + None, + ], + ], + ) + def test_unit_with_numeric_coerce(self, cache, exp, arr, warning): + # but we want to make sure that we are coercing + # if we have ints/strings + expected = DatetimeIndex(exp, dtype="M8[ns]") + with tm.assert_produces_warning(warning, match="Could not infer format"): + result = to_datetime(arr, errors="coerce", cache=cache) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "arr", + [ + [Timestamp("20130101"), 1.434692e18, 1.432766e18], + [1.434692e18, 1.432766e18, Timestamp("20130101")], + ], + ) + def test_unit_mixed(self, cache, arr): + # GH#50453 pre-2.0 with mixed numeric/datetimes and errors="coerce" + # the numeric entries would be coerced to NaT, was never clear exactly + # why. + # mixed integers/datetimes + expected = Index([Timestamp(x) for x in arr], dtype="M8[ns]") + result = to_datetime(arr, errors="coerce", cache=cache) + tm.assert_index_equal(result, expected) + + # GH#49037 pre-2.0 this raised, but it always worked with Series, + # was never clear why it was disallowed + result = to_datetime(arr, errors="raise", cache=cache) + tm.assert_index_equal(result, expected) + + result = DatetimeIndex(arr) + tm.assert_index_equal(result, expected) + + def test_unit_rounding(self, cache): + # GH 14156 & GH 20445: argument will incur floating point errors + # but no premature rounding + value = 1434743731.8770001 + result = to_datetime(value, unit="s", cache=cache) + expected = Timestamp("2015-06-19 19:55:31.877000093") + assert result == expected + + alt = Timestamp(value, unit="s") + assert alt == result + + def test_unit_ignore_keeps_name(self, cache): + # GH 21697 + expected = Index([15e9] * 2, name="name") + result = to_datetime(expected, errors="ignore", unit="s", cache=cache) + tm.assert_index_equal(result, expected) + + def test_to_datetime_errors_ignore_utc_true(self): + # GH#23758 + result = to_datetime([1], unit="s", utc=True, errors="ignore") + expected = DatetimeIndex(["1970-01-01 00:00:01"], dtype="M8[ns, UTC]") + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize("dtype", [int, float]) + def test_to_datetime_unit(self, dtype): + epoch = 1370745748 + ser = Series([epoch + t for t in range(20)]).astype(dtype) + result = to_datetime(ser, unit="s") + expected = Series( + [ + Timestamp("2013-06-09 02:42:28") + timedelta(seconds=t) + for t in range(20) + ], + dtype="M8[ns]", + ) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("null", [iNaT, np.nan]) + def test_to_datetime_unit_with_nulls(self, null): + epoch = 1370745748 + ser = Series([epoch + t for t in range(20)] + [null]) + result = to_datetime(ser, unit="s") + expected = Series( + [Timestamp("2013-06-09 02:42:28") + timedelta(seconds=t) for t in range(20)] + + [NaT], + dtype="M8[ns]", + ) + tm.assert_series_equal(result, expected) + + def test_to_datetime_unit_fractional_seconds(self): + # GH13834 + epoch = 1370745748 + ser = Series([epoch + t for t in np.arange(0, 2, 0.25)] + [iNaT]).astype(float) + result = to_datetime(ser, unit="s") + expected = Series( + [ + Timestamp("2013-06-09 02:42:28") + timedelta(seconds=t) + for t in np.arange(0, 2, 0.25) + ] + + [NaT], + dtype="M8[ns]", + ) + # GH20455 argument will incur floating point errors but no premature rounding + result = result.round("ms") + tm.assert_series_equal(result, expected) + + def test_to_datetime_unit_na_values(self): + result = to_datetime([1, 2, "NaT", NaT, np.nan], unit="D") + expected = DatetimeIndex( + [Timestamp("1970-01-02"), Timestamp("1970-01-03")] + ["NaT"] * 3, + dtype="M8[ns]", + ) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize("bad_val", ["foo", 111111111]) + def test_to_datetime_unit_invalid(self, bad_val): + msg = f"{bad_val} with the unit 'D'" + with pytest.raises(ValueError, match=msg): + to_datetime([1, 2, bad_val], unit="D") + + @pytest.mark.parametrize("bad_val", ["foo", 111111111]) + def test_to_timestamp_unit_coerce(self, bad_val): + # coerce we can process + expected = DatetimeIndex( + [Timestamp("1970-01-02"), Timestamp("1970-01-03")] + ["NaT"] * 1, + dtype="M8[ns]", + ) + result = to_datetime([1, 2, bad_val], unit="D", errors="coerce") + tm.assert_index_equal(result, expected) + + def test_float_to_datetime_raise_near_bounds(self): + # GH50183 + msg = "cannot convert input with unit 'D'" + oneday_in_ns = 1e9 * 60 * 60 * 24 + tsmax_in_days = 2**63 / oneday_in_ns # 2**63 ns, in days + # just in bounds + should_succeed = Series( + [0, tsmax_in_days - 0.005, -tsmax_in_days + 0.005], dtype=float + ) + expected = (should_succeed * oneday_in_ns).astype(np.int64) + for error_mode in ["raise", "coerce", "ignore"]: + result1 = to_datetime(should_succeed, unit="D", errors=error_mode) + # Cast to `np.float64` so that `rtol` and inexact checking kick in + # (`check_exact` doesn't take place for integer dtypes) + tm.assert_almost_equal( + result1.astype(np.int64).astype(np.float64), + expected.astype(np.float64), + rtol=1e-10, + ) + # just out of bounds + should_fail1 = Series([0, tsmax_in_days + 0.005], dtype=float) + should_fail2 = Series([0, -tsmax_in_days - 0.005], dtype=float) + with pytest.raises(OutOfBoundsDatetime, match=msg): + to_datetime(should_fail1, unit="D", errors="raise") + with pytest.raises(OutOfBoundsDatetime, match=msg): + to_datetime(should_fail2, unit="D", errors="raise") + + +class TestToDatetimeDataFrame: + @pytest.fixture + def df(self): + return DataFrame( + { + "year": [2015, 2016], + "month": [2, 3], + "day": [4, 5], + "hour": [6, 7], + "minute": [58, 59], + "second": [10, 11], + "ms": [1, 1], + "us": [2, 2], + "ns": [3, 3], + } + ) + + def test_dataframe(self, df, cache): + result = to_datetime( + {"year": df["year"], "month": df["month"], "day": df["day"]}, cache=cache + ) + expected = Series( + [Timestamp("20150204 00:00:00"), Timestamp("20160305 00:0:00")] + ) + tm.assert_series_equal(result, expected) + + # dict-like + result = to_datetime(df[["year", "month", "day"]].to_dict(), cache=cache) + tm.assert_series_equal(result, expected) + + def test_dataframe_dict_with_constructable(self, df, cache): + # dict but with constructable + df2 = df[["year", "month", "day"]].to_dict() + df2["month"] = 2 + result = to_datetime(df2, cache=cache) + expected2 = Series( + [Timestamp("20150204 00:00:00"), Timestamp("20160205 00:0:00")] + ) + tm.assert_series_equal(result, expected2) + + @pytest.mark.parametrize( + "unit", + [ + { + "year": "years", + "month": "months", + "day": "days", + "hour": "hours", + "minute": "minutes", + "second": "seconds", + }, + { + "year": "year", + "month": "month", + "day": "day", + "hour": "hour", + "minute": "minute", + "second": "second", + }, + ], + ) + def test_dataframe_field_aliases_column_subset(self, df, cache, unit): + # unit mappings + result = to_datetime(df[list(unit.keys())].rename(columns=unit), cache=cache) + expected = Series( + [Timestamp("20150204 06:58:10"), Timestamp("20160305 07:59:11")], + dtype="M8[ns]", + ) + tm.assert_series_equal(result, expected) + + def test_dataframe_field_aliases(self, df, cache): + d = { + "year": "year", + "month": "month", + "day": "day", + "hour": "hour", + "minute": "minute", + "second": "second", + "ms": "ms", + "us": "us", + "ns": "ns", + } + + result = to_datetime(df.rename(columns=d), cache=cache) + expected = Series( + [ + Timestamp("20150204 06:58:10.001002003"), + Timestamp("20160305 07:59:11.001002003"), + ] + ) + tm.assert_series_equal(result, expected) + + def test_dataframe_str_dtype(self, df, cache): + # coerce back to int + result = to_datetime(df.astype(str), cache=cache) + expected = Series( + [ + Timestamp("20150204 06:58:10.001002003"), + Timestamp("20160305 07:59:11.001002003"), + ] + ) + tm.assert_series_equal(result, expected) + + def test_dataframe_coerce(self, cache): + # passing coerce + df2 = DataFrame({"year": [2015, 2016], "month": [2, 20], "day": [4, 5]}) + + msg = ( + r'^cannot assemble the datetimes: time data ".+" doesn\'t ' + r'match format "%Y%m%d", at position 1\.' + ) + with pytest.raises(ValueError, match=msg): + to_datetime(df2, cache=cache) + + result = to_datetime(df2, errors="coerce", cache=cache) + expected = Series([Timestamp("20150204 00:00:00"), NaT]) + tm.assert_series_equal(result, expected) + + def test_dataframe_extra_keys_raisesm(self, df, cache): + # extra columns + msg = r"extra keys have been passed to the datetime assemblage: \[foo\]" + with pytest.raises(ValueError, match=msg): + df2 = df.copy() + df2["foo"] = 1 + to_datetime(df2, cache=cache) + + @pytest.mark.parametrize( + "cols", + [ + ["year"], + ["year", "month"], + ["year", "month", "second"], + ["month", "day"], + ["year", "day", "second"], + ], + ) + def test_dataframe_missing_keys_raises(self, df, cache, cols): + # not enough + msg = ( + r"to assemble mappings requires at least that \[year, month, " + r"day\] be specified: \[.+\] is missing" + ) + with pytest.raises(ValueError, match=msg): + to_datetime(df[cols], cache=cache) + + def test_dataframe_duplicate_columns_raises(self, cache): + # duplicates + msg = "cannot assemble with duplicate keys" + df2 = DataFrame({"year": [2015, 2016], "month": [2, 20], "day": [4, 5]}) + df2.columns = ["year", "year", "day"] + with pytest.raises(ValueError, match=msg): + to_datetime(df2, cache=cache) + + df2 = DataFrame( + {"year": [2015, 2016], "month": [2, 20], "day": [4, 5], "hour": [4, 5]} + ) + df2.columns = ["year", "month", "day", "day"] + with pytest.raises(ValueError, match=msg): + to_datetime(df2, cache=cache) + + def test_dataframe_int16(self, cache): + # GH#13451 + df = DataFrame({"year": [2015, 2016], "month": [2, 3], "day": [4, 5]}) + + # int16 + result = to_datetime(df.astype("int16"), cache=cache) + expected = Series( + [Timestamp("20150204 00:00:00"), Timestamp("20160305 00:00:00")] + ) + tm.assert_series_equal(result, expected) + + def test_dataframe_mixed(self, cache): + # mixed dtypes + df = DataFrame({"year": [2015, 2016], "month": [2, 3], "day": [4, 5]}) + df["month"] = df["month"].astype("int8") + df["day"] = df["day"].astype("int8") + result = to_datetime(df, cache=cache) + expected = Series( + [Timestamp("20150204 00:00:00"), Timestamp("20160305 00:00:00")] + ) + tm.assert_series_equal(result, expected) + + def test_dataframe_float(self, cache): + # float + df = DataFrame({"year": [2000, 2001], "month": [1.5, 1], "day": [1, 1]}) + msg = ( + r"^cannot assemble the datetimes: unconverted data remains when parsing " + r'with format ".*": "1", at position 0.' + ) + with pytest.raises(ValueError, match=msg): + to_datetime(df, cache=cache) + + def test_dataframe_utc_true(self): + # GH#23760 + df = DataFrame({"year": [2015, 2016], "month": [2, 3], "day": [4, 5]}) + result = to_datetime(df, utc=True) + expected = Series( + np.array(["2015-02-04", "2016-03-05"], dtype="datetime64[ns]") + ).dt.tz_localize("UTC") + tm.assert_series_equal(result, expected) + + +class TestToDatetimeMisc: + def test_to_datetime_barely_out_of_bounds(self): + # GH#19529 + # GH#19382 close enough to bounds that dropping nanos would result + # in an in-bounds datetime + arr = np.array(["2262-04-11 23:47:16.854775808"], dtype=object) + + msg = "^Out of bounds nanosecond timestamp: .*, at position 0" + with pytest.raises(OutOfBoundsDatetime, match=msg): + to_datetime(arr) + + @pytest.mark.parametrize( + "arg, exp_str", + [ + ["2012-01-01 00:00:00", "2012-01-01 00:00:00"], + ["20121001", "2012-10-01"], # bad iso 8601 + ], + ) + def test_to_datetime_iso8601(self, cache, arg, exp_str): + result = to_datetime([arg], cache=cache) + exp = Timestamp(exp_str) + assert result[0] == exp + + @pytest.mark.parametrize( + "input, format", + [ + ("2012", "%Y-%m"), + ("2012-01", "%Y-%m-%d"), + ("2012-01-01", "%Y-%m-%d %H"), + ("2012-01-01 10", "%Y-%m-%d %H:%M"), + ("2012-01-01 10:00", "%Y-%m-%d %H:%M:%S"), + ("2012-01-01 10:00:00", "%Y-%m-%d %H:%M:%S.%f"), + ("2012-01-01 10:00:00.123", "%Y-%m-%d %H:%M:%S.%f%z"), + (0, "%Y-%m-%d"), + ], + ) + @pytest.mark.parametrize("exact", [True, False]) + def test_to_datetime_iso8601_fails(self, input, format, exact): + # https://github.com/pandas-dev/pandas/issues/12649 + # `format` is longer than the string, so this fails regardless of `exact` + with pytest.raises( + ValueError, + match=( + rf"time data \"{input}\" doesn't match format " + rf"\"{format}\", at position 0" + ), + ): + to_datetime(input, format=format, exact=exact) + + @pytest.mark.parametrize( + "input, format", + [ + ("2012-01-01", "%Y-%m"), + ("2012-01-01 10", "%Y-%m-%d"), + ("2012-01-01 10:00", "%Y-%m-%d %H"), + ("2012-01-01 10:00:00", "%Y-%m-%d %H:%M"), + (0, "%Y-%m-%d"), + ], + ) + def test_to_datetime_iso8601_exact_fails(self, input, format): + # https://github.com/pandas-dev/pandas/issues/12649 + # `format` is shorter than the date string, so only fails with `exact=True` + msg = "|".join( + [ + '^unconverted data remains when parsing with format ".*": ".*"' + f", at position 0. {PARSING_ERR_MSG}$", + f'^time data ".*" doesn\'t match format ".*", at position 0. ' + f"{PARSING_ERR_MSG}$", + ] + ) + with pytest.raises( + ValueError, + match=(msg), + ): + to_datetime(input, format=format) + + @pytest.mark.parametrize( + "input, format", + [ + ("2012-01-01", "%Y-%m"), + ("2012-01-01 00", "%Y-%m-%d"), + ("2012-01-01 00:00", "%Y-%m-%d %H"), + ("2012-01-01 00:00:00", "%Y-%m-%d %H:%M"), + ], + ) + def test_to_datetime_iso8601_non_exact(self, input, format): + # https://github.com/pandas-dev/pandas/issues/12649 + expected = Timestamp(2012, 1, 1) + result = to_datetime(input, format=format, exact=False) + assert result == expected + + @pytest.mark.parametrize( + "input, format", + [ + ("2020-01", "%Y/%m"), + ("2020-01-01", "%Y/%m/%d"), + ("2020-01-01 00", "%Y/%m/%dT%H"), + ("2020-01-01T00", "%Y/%m/%d %H"), + ("2020-01-01 00:00", "%Y/%m/%dT%H:%M"), + ("2020-01-01T00:00", "%Y/%m/%d %H:%M"), + ("2020-01-01 00:00:00", "%Y/%m/%dT%H:%M:%S"), + ("2020-01-01T00:00:00", "%Y/%m/%d %H:%M:%S"), + ], + ) + def test_to_datetime_iso8601_separator(self, input, format): + # https://github.com/pandas-dev/pandas/issues/12649 + with pytest.raises( + ValueError, + match=( + rf"time data \"{input}\" doesn\'t match format " + rf"\"{format}\", at position 0" + ), + ): + to_datetime(input, format=format) + + @pytest.mark.parametrize( + "input, format", + [ + ("2020-01", "%Y-%m"), + ("2020-01-01", "%Y-%m-%d"), + ("2020-01-01 00", "%Y-%m-%d %H"), + ("2020-01-01T00", "%Y-%m-%dT%H"), + ("2020-01-01 00:00", "%Y-%m-%d %H:%M"), + ("2020-01-01T00:00", "%Y-%m-%dT%H:%M"), + ("2020-01-01 00:00:00", "%Y-%m-%d %H:%M:%S"), + ("2020-01-01T00:00:00", "%Y-%m-%dT%H:%M:%S"), + ("2020-01-01T00:00:00.000", "%Y-%m-%dT%H:%M:%S.%f"), + ("2020-01-01T00:00:00.000000", "%Y-%m-%dT%H:%M:%S.%f"), + ("2020-01-01T00:00:00.000000000", "%Y-%m-%dT%H:%M:%S.%f"), + ], + ) + def test_to_datetime_iso8601_valid(self, input, format): + # https://github.com/pandas-dev/pandas/issues/12649 + expected = Timestamp(2020, 1, 1) + result = to_datetime(input, format=format) + assert result == expected + + @pytest.mark.parametrize( + "input, format", + [ + ("2020-1", "%Y-%m"), + ("2020-1-1", "%Y-%m-%d"), + ("2020-1-1 0", "%Y-%m-%d %H"), + ("2020-1-1T0", "%Y-%m-%dT%H"), + ("2020-1-1 0:0", "%Y-%m-%d %H:%M"), + ("2020-1-1T0:0", "%Y-%m-%dT%H:%M"), + ("2020-1-1 0:0:0", "%Y-%m-%d %H:%M:%S"), + ("2020-1-1T0:0:0", "%Y-%m-%dT%H:%M:%S"), + ("2020-1-1T0:0:0.000", "%Y-%m-%dT%H:%M:%S.%f"), + ("2020-1-1T0:0:0.000000", "%Y-%m-%dT%H:%M:%S.%f"), + ("2020-1-1T0:0:0.000000000", "%Y-%m-%dT%H:%M:%S.%f"), + ], + ) + def test_to_datetime_iso8601_non_padded(self, input, format): + # https://github.com/pandas-dev/pandas/issues/21422 + expected = Timestamp(2020, 1, 1) + result = to_datetime(input, format=format) + assert result == expected + + @pytest.mark.parametrize( + "input, format", + [ + ("2020-01-01T00:00:00.000000000+00:00", "%Y-%m-%dT%H:%M:%S.%f%z"), + ("2020-01-01T00:00:00+00:00", "%Y-%m-%dT%H:%M:%S%z"), + ("2020-01-01T00:00:00Z", "%Y-%m-%dT%H:%M:%S%z"), + ], + ) + def test_to_datetime_iso8601_with_timezone_valid(self, input, format): + # https://github.com/pandas-dev/pandas/issues/12649 + expected = Timestamp(2020, 1, 1, tzinfo=pytz.UTC) + result = to_datetime(input, format=format) + assert result == expected + + def test_to_datetime_default(self, cache): + rs = to_datetime("2001", cache=cache) + xp = datetime(2001, 1, 1) + assert rs == xp + + @pytest.mark.xfail(reason="fails to enforce dayfirst=True, which would raise") + def test_to_datetime_respects_dayfirst(self, cache): + # dayfirst is essentially broken + + # The msg here is not important since it isn't actually raised yet. + msg = "Invalid date specified" + with pytest.raises(ValueError, match=msg): + # if dayfirst is respected, then this would parse as month=13, which + # would raise + with tm.assert_produces_warning(UserWarning, match="Provide format"): + to_datetime("01-13-2012", dayfirst=True, cache=cache) + + def test_to_datetime_on_datetime64_series(self, cache): + # #2699 + ser = Series(date_range("1/1/2000", periods=10)) + + result = to_datetime(ser, cache=cache) + assert result[0] == ser[0] + + def test_to_datetime_with_space_in_series(self, cache): + # GH 6428 + ser = Series(["10/18/2006", "10/18/2008", " "]) + msg = ( + r'^time data " " doesn\'t match format "%m/%d/%Y", ' + rf"at position 2. {PARSING_ERR_MSG}$" + ) + with pytest.raises(ValueError, match=msg): + to_datetime(ser, errors="raise", cache=cache) + result_coerce = to_datetime(ser, errors="coerce", cache=cache) + expected_coerce = Series([datetime(2006, 10, 18), datetime(2008, 10, 18), NaT]) + tm.assert_series_equal(result_coerce, expected_coerce) + result_ignore = to_datetime(ser, errors="ignore", cache=cache) + tm.assert_series_equal(result_ignore, ser) + + @td.skip_if_not_us_locale + def test_to_datetime_with_apply(self, cache): + # this is only locale tested with US/None locales + # GH 5195 + # with a format and coerce a single item to_datetime fails + td = Series(["May 04", "Jun 02", "Dec 11"], index=[1, 2, 3]) + expected = to_datetime(td, format="%b %y", cache=cache) + result = td.apply(to_datetime, format="%b %y", cache=cache) + tm.assert_series_equal(result, expected) + + def test_to_datetime_timezone_name(self): + # https://github.com/pandas-dev/pandas/issues/49748 + result = to_datetime("2020-01-01 00:00:00UTC", format="%Y-%m-%d %H:%M:%S%Z") + expected = Timestamp(2020, 1, 1).tz_localize("UTC") + assert result == expected + + @td.skip_if_not_us_locale + @pytest.mark.parametrize("errors", ["raise", "coerce", "ignore"]) + def test_to_datetime_with_apply_with_empty_str(self, cache, errors): + # this is only locale tested with US/None locales + # GH 5195, GH50251 + # with a format and coerce a single item to_datetime fails + td = Series(["May 04", "Jun 02", ""], index=[1, 2, 3]) + expected = to_datetime(td, format="%b %y", errors=errors, cache=cache) + + result = td.apply( + lambda x: to_datetime(x, format="%b %y", errors="coerce", cache=cache) + ) + tm.assert_series_equal(result, expected) + + def test_to_datetime_empty_stt(self, cache): + # empty string + result = to_datetime("", cache=cache) + assert result is NaT + + def test_to_datetime_empty_str_list(self, cache): + result = to_datetime(["", ""], cache=cache) + assert isna(result).all() + + def test_to_datetime_zero(self, cache): + # ints + result = Timestamp(0) + expected = to_datetime(0, cache=cache) + assert result == expected + + def test_to_datetime_strings(self, cache): + # GH 3888 (strings) + expected = to_datetime(["2012"], cache=cache)[0] + result = to_datetime("2012", cache=cache) + assert result == expected + + def test_to_datetime_strings_variation(self, cache): + array = ["2012", "20120101", "20120101 12:01:01"] + expected = [to_datetime(dt_str, cache=cache) for dt_str in array] + result = [Timestamp(date_str) for date_str in array] + tm.assert_almost_equal(result, expected) + + @pytest.mark.parametrize("result", [Timestamp("2012"), to_datetime("2012")]) + def test_to_datetime_strings_vs_constructor(self, result): + expected = Timestamp(2012, 1, 1) + assert result == expected + + def test_to_datetime_unprocessable_input(self, cache): + # GH 4928 + # GH 21864 + result = to_datetime([1, "1"], errors="ignore", cache=cache) + + expected = Index(np.array([1, "1"], dtype="O")) + tm.assert_equal(result, expected) + msg = '^Given date string "1" not likely a datetime, at position 1$' + with pytest.raises(ValueError, match=msg): + to_datetime([1, "1"], errors="raise", cache=cache) + + def test_to_datetime_unhashable_input(self, cache): + series = Series([["a"]] * 100) + result = to_datetime(series, errors="ignore", cache=cache) + tm.assert_series_equal(series, result) + + def test_to_datetime_other_datetime64_units(self): + # 5/25/2012 + scalar = np.int64(1337904000000000).view("M8[us]") + as_obj = scalar.astype("O") + + index = DatetimeIndex([scalar]) + assert index[0] == scalar.astype("O") + + value = Timestamp(scalar) + assert value == as_obj + + def test_to_datetime_list_of_integers(self): + rng = date_range("1/1/2000", periods=20) + rng = DatetimeIndex(rng.values) + + ints = list(rng.asi8) + + result = DatetimeIndex(ints) + + tm.assert_index_equal(rng, result) + + def test_to_datetime_overflow(self): + # gh-17637 + # we are overflowing Timedelta range here + msg = "Cannot cast 139999 days 00:00:00 to unit='ns' without overflow" + with pytest.raises(OutOfBoundsTimedelta, match=msg): + date_range(start="1/1/1700", freq="B", periods=100000) + + def test_string_invalid_operation(self, cache): + invalid = np.array(["87156549591102612381000001219H5"], dtype=object) + # GH #51084 + + with pytest.raises(ValueError, match="Unknown datetime string format"): + to_datetime(invalid, errors="raise", cache=cache) + + def test_string_na_nat_conversion(self, cache): + # GH #999, #858 + + strings = np.array(["1/1/2000", "1/2/2000", np.nan, "1/4/2000"], dtype=object) + + expected = np.empty(4, dtype="M8[ns]") + for i, val in enumerate(strings): + if isna(val): + expected[i] = iNaT + else: + expected[i] = parse(val) + + result = tslib.array_to_datetime(strings)[0] + tm.assert_almost_equal(result, expected) + + result2 = to_datetime(strings, cache=cache) + assert isinstance(result2, DatetimeIndex) + tm.assert_numpy_array_equal(result, result2.values) + + def test_string_na_nat_conversion_malformed(self, cache): + malformed = np.array(["1/100/2000", np.nan], dtype=object) + + # GH 10636, default is now 'raise' + msg = r"Unknown datetime string format" + with pytest.raises(ValueError, match=msg): + to_datetime(malformed, errors="raise", cache=cache) + + result = to_datetime(malformed, errors="ignore", cache=cache) + # GH 21864 + expected = Index(malformed, dtype=object) + tm.assert_index_equal(result, expected) + + with pytest.raises(ValueError, match=msg): + to_datetime(malformed, errors="raise", cache=cache) + + def test_string_na_nat_conversion_with_name(self, cache): + idx = ["a", "b", "c", "d", "e"] + series = Series( + ["1/1/2000", np.nan, "1/3/2000", np.nan, "1/5/2000"], index=idx, name="foo" + ) + dseries = Series( + [ + to_datetime("1/1/2000", cache=cache), + np.nan, + to_datetime("1/3/2000", cache=cache), + np.nan, + to_datetime("1/5/2000", cache=cache), + ], + index=idx, + name="foo", + ) + + result = to_datetime(series, cache=cache) + dresult = to_datetime(dseries, cache=cache) + + expected = Series(np.empty(5, dtype="M8[ns]"), index=idx) + for i in range(5): + x = series.iloc[i] + if isna(x): + expected.iloc[i] = NaT + else: + expected.iloc[i] = to_datetime(x, cache=cache) + + tm.assert_series_equal(result, expected, check_names=False) + assert result.name == "foo" + + tm.assert_series_equal(dresult, expected, check_names=False) + assert dresult.name == "foo" + + @pytest.mark.parametrize( + "unit", + ["h", "m", "s", "ms", "us", "ns"], + ) + def test_dti_constructor_numpy_timeunits(self, cache, unit): + # GH 9114 + dtype = np.dtype(f"M8[{unit}]") + base = to_datetime(["2000-01-01T00:00", "2000-01-02T00:00", "NaT"], cache=cache) + + values = base.values.astype(dtype) + + if unit in ["h", "m"]: + # we cast to closest supported unit + unit = "s" + exp_dtype = np.dtype(f"M8[{unit}]") + expected = DatetimeIndex(base.astype(exp_dtype)) + assert expected.dtype == exp_dtype + + tm.assert_index_equal(DatetimeIndex(values), expected) + tm.assert_index_equal(to_datetime(values, cache=cache), expected) + + def test_dayfirst(self, cache): + # GH 5917 + arr = ["10/02/2014", "11/02/2014", "12/02/2014"] + expected = DatetimeIndex( + [datetime(2014, 2, 10), datetime(2014, 2, 11), datetime(2014, 2, 12)] + ) + idx1 = DatetimeIndex(arr, dayfirst=True) + idx2 = DatetimeIndex(np.array(arr), dayfirst=True) + idx3 = to_datetime(arr, dayfirst=True, cache=cache) + idx4 = to_datetime(np.array(arr), dayfirst=True, cache=cache) + idx5 = DatetimeIndex(Index(arr), dayfirst=True) + idx6 = DatetimeIndex(Series(arr), dayfirst=True) + tm.assert_index_equal(expected, idx1) + tm.assert_index_equal(expected, idx2) + tm.assert_index_equal(expected, idx3) + tm.assert_index_equal(expected, idx4) + tm.assert_index_equal(expected, idx5) + tm.assert_index_equal(expected, idx6) + + def test_dayfirst_warnings_valid_input(self): + # GH 12585 + warning_msg = ( + "Parsing dates in .* format when dayfirst=.* was specified. " + "Pass `dayfirst=.*` or specify a format to silence this warning." + ) + + # CASE 1: valid input + arr = ["31/12/2014", "10/03/2011"] + expected = DatetimeIndex( + ["2014-12-31", "2011-03-10"], dtype="datetime64[ns]", freq=None + ) + + # A. dayfirst arg correct, no warning + res1 = to_datetime(arr, dayfirst=True) + tm.assert_index_equal(expected, res1) + + # B. dayfirst arg incorrect, warning + with tm.assert_produces_warning(UserWarning, match=warning_msg): + res2 = to_datetime(arr, dayfirst=False) + tm.assert_index_equal(expected, res2) + + def test_dayfirst_warnings_invalid_input(self): + # CASE 2: invalid input + # cannot consistently process with single format + # ValueError *always* raised + + # first in DD/MM/YYYY, second in MM/DD/YYYY + arr = ["31/12/2014", "03/30/2011"] + + with pytest.raises( + ValueError, + match=( + r'^time data "03/30/2011" doesn\'t match format ' + rf'"%d/%m/%Y", at position 1. {PARSING_ERR_MSG}$' + ), + ): + to_datetime(arr, dayfirst=True) + + @pytest.mark.parametrize("klass", [DatetimeIndex, DatetimeArray._from_sequence]) + def test_to_datetime_dta_tz(self, klass): + # GH#27733 + dti = date_range("2015-04-05", periods=3).rename("foo") + expected = dti.tz_localize("UTC") + + obj = klass(dti) + expected = klass(expected) + + result = to_datetime(obj, utc=True) + tm.assert_equal(result, expected) + + +class TestGuessDatetimeFormat: + @pytest.mark.parametrize( + "test_list", + [ + [ + "2011-12-30 00:00:00.000000", + "2011-12-30 00:00:00.000000", + "2011-12-30 00:00:00.000000", + ], + [np.nan, np.nan, "2011-12-30 00:00:00.000000"], + ["", "2011-12-30 00:00:00.000000"], + ["NaT", "2011-12-30 00:00:00.000000"], + ["2011-12-30 00:00:00.000000", "random_string"], + ["now", "2011-12-30 00:00:00.000000"], + ["today", "2011-12-30 00:00:00.000000"], + ], + ) + def test_guess_datetime_format_for_array(self, test_list): + expected_format = "%Y-%m-%d %H:%M:%S.%f" + test_array = np.array(test_list, dtype=object) + assert tools._guess_datetime_format_for_array(test_array) == expected_format + + @td.skip_if_not_us_locale + def test_guess_datetime_format_for_array_all_nans(self): + format_for_string_of_nans = tools._guess_datetime_format_for_array( + np.array([np.nan, np.nan, np.nan], dtype="O") + ) + assert format_for_string_of_nans is None + + +class TestToDatetimeInferFormat: + @pytest.mark.parametrize( + "test_format", ["%m-%d-%Y", "%m/%d/%Y %H:%M:%S.%f", "%Y-%m-%dT%H:%M:%S.%f"] + ) + def test_to_datetime_infer_datetime_format_consistent_format( + self, cache, test_format + ): + ser = Series(date_range("20000101", periods=50, freq="h")) + + s_as_dt_strings = ser.apply(lambda x: x.strftime(test_format)) + + with_format = to_datetime(s_as_dt_strings, format=test_format, cache=cache) + without_format = to_datetime(s_as_dt_strings, cache=cache) + + # Whether the format is explicitly passed, or + # it is inferred, the results should all be the same + tm.assert_series_equal(with_format, without_format) + + def test_to_datetime_inconsistent_format(self, cache): + data = ["01/01/2011 00:00:00", "01-02-2011 00:00:00", "2011-01-03T00:00:00"] + ser = Series(np.array(data)) + msg = ( + r'^time data "01-02-2011 00:00:00" doesn\'t match format ' + rf'"%m/%d/%Y %H:%M:%S", at position 1. {PARSING_ERR_MSG}$' + ) + with pytest.raises(ValueError, match=msg): + to_datetime(ser, cache=cache) + + def test_to_datetime_consistent_format(self, cache): + data = ["Jan/01/2011", "Feb/01/2011", "Mar/01/2011"] + ser = Series(np.array(data)) + result = to_datetime(ser, cache=cache) + expected = Series( + ["2011-01-01", "2011-02-01", "2011-03-01"], dtype="datetime64[ns]" + ) + tm.assert_series_equal(result, expected) + + def test_to_datetime_series_with_nans(self, cache): + ser = Series( + np.array( + ["01/01/2011 00:00:00", np.nan, "01/03/2011 00:00:00", np.nan], + dtype=object, + ) + ) + result = to_datetime(ser, cache=cache) + expected = Series( + ["2011-01-01", NaT, "2011-01-03", NaT], dtype="datetime64[ns]" + ) + tm.assert_series_equal(result, expected) + + def test_to_datetime_series_start_with_nans(self, cache): + ser = Series( + np.array( + [ + np.nan, + np.nan, + "01/01/2011 00:00:00", + "01/02/2011 00:00:00", + "01/03/2011 00:00:00", + ], + dtype=object, + ) + ) + + result = to_datetime(ser, cache=cache) + expected = Series( + [NaT, NaT, "2011-01-01", "2011-01-02", "2011-01-03"], dtype="datetime64[ns]" + ) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "tz_name, offset", + [("UTC", 0), ("UTC-3", 180), ("UTC+3", -180)], + ) + def test_infer_datetime_format_tz_name(self, tz_name, offset): + # GH 33133 + ser = Series([f"2019-02-02 08:07:13 {tz_name}"]) + result = to_datetime(ser) + tz = timezone(timedelta(minutes=offset)) + expected = Series([Timestamp("2019-02-02 08:07:13").tz_localize(tz)]) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "ts,zero_tz", + [ + ("2019-02-02 08:07:13", "Z"), + ("2019-02-02 08:07:13", ""), + ("2019-02-02 08:07:13.012345", "Z"), + ("2019-02-02 08:07:13.012345", ""), + ], + ) + def test_infer_datetime_format_zero_tz(self, ts, zero_tz): + # GH 41047 + ser = Series([ts + zero_tz]) + result = to_datetime(ser) + tz = pytz.utc if zero_tz == "Z" else None + expected = Series([Timestamp(ts, tz=tz)]) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("format", [None, "%Y-%m-%d"]) + def test_to_datetime_iso8601_noleading_0s(self, cache, format): + # GH 11871 + ser = Series(["2014-1-1", "2014-2-2", "2015-3-3"]) + expected = Series( + [ + Timestamp("2014-01-01"), + Timestamp("2014-02-02"), + Timestamp("2015-03-03"), + ] + ) + result = to_datetime(ser, format=format, cache=cache) + tm.assert_series_equal(result, expected) + + def test_parse_dates_infer_datetime_format_warning(self): + # GH 49024 + with tm.assert_produces_warning( + UserWarning, + match="The argument 'infer_datetime_format' is deprecated", + ): + to_datetime(["10-10-2000"], infer_datetime_format=True) + + +class TestDaysInMonth: + # tests for issue #10154 + + @pytest.mark.parametrize( + "arg, format", + [ + ["2015-02-29", None], + ["2015-02-29", "%Y-%m-%d"], + ["2015-02-32", "%Y-%m-%d"], + ["2015-04-31", "%Y-%m-%d"], + ], + ) + def test_day_not_in_month_coerce(self, cache, arg, format): + assert isna(to_datetime(arg, errors="coerce", format=format, cache=cache)) + + def test_day_not_in_month_raise(self, cache): + msg = "day is out of range for month: 2015-02-29, at position 0" + with pytest.raises(ValueError, match=msg): + to_datetime("2015-02-29", errors="raise", cache=cache) + + @pytest.mark.parametrize( + "arg, format, msg", + [ + ( + "2015-02-29", + "%Y-%m-%d", + f"^day is out of range for month, at position 0. {PARSING_ERR_MSG}$", + ), + ( + "2015-29-02", + "%Y-%d-%m", + f"^day is out of range for month, at position 0. {PARSING_ERR_MSG}$", + ), + ( + "2015-02-32", + "%Y-%m-%d", + '^unconverted data remains when parsing with format "%Y-%m-%d": "2", ' + f"at position 0. {PARSING_ERR_MSG}$", + ), + ( + "2015-32-02", + "%Y-%d-%m", + '^time data "2015-32-02" doesn\'t match format "%Y-%d-%m", ' + f"at position 0. {PARSING_ERR_MSG}$", + ), + ( + "2015-04-31", + "%Y-%m-%d", + f"^day is out of range for month, at position 0. {PARSING_ERR_MSG}$", + ), + ( + "2015-31-04", + "%Y-%d-%m", + f"^day is out of range for month, at position 0. {PARSING_ERR_MSG}$", + ), + ], + ) + def test_day_not_in_month_raise_value(self, cache, arg, format, msg): + # https://github.com/pandas-dev/pandas/issues/50462 + with pytest.raises(ValueError, match=msg): + to_datetime(arg, errors="raise", format=format, cache=cache) + + @pytest.mark.parametrize( + "expected, format", + [ + ["2015-02-29", None], + ["2015-02-29", "%Y-%m-%d"], + ["2015-02-29", "%Y-%m-%d"], + ["2015-04-31", "%Y-%m-%d"], + ], + ) + def test_day_not_in_month_ignore(self, cache, expected, format): + result = to_datetime(expected, errors="ignore", format=format, cache=cache) + assert result == expected + + +class TestDatetimeParsingWrappers: + @pytest.mark.parametrize( + "date_str, expected", + [ + ("2011-01-01", datetime(2011, 1, 1)), + ("2Q2005", datetime(2005, 4, 1)), + ("2Q05", datetime(2005, 4, 1)), + ("2005Q1", datetime(2005, 1, 1)), + ("05Q1", datetime(2005, 1, 1)), + ("2011Q3", datetime(2011, 7, 1)), + ("11Q3", datetime(2011, 7, 1)), + ("3Q2011", datetime(2011, 7, 1)), + ("3Q11", datetime(2011, 7, 1)), + # quarterly without space + ("2000Q4", datetime(2000, 10, 1)), + ("00Q4", datetime(2000, 10, 1)), + ("4Q2000", datetime(2000, 10, 1)), + ("4Q00", datetime(2000, 10, 1)), + ("2000q4", datetime(2000, 10, 1)), + ("2000-Q4", datetime(2000, 10, 1)), + ("00-Q4", datetime(2000, 10, 1)), + ("4Q-2000", datetime(2000, 10, 1)), + ("4Q-00", datetime(2000, 10, 1)), + ("00q4", datetime(2000, 10, 1)), + ("2005", datetime(2005, 1, 1)), + ("2005-11", datetime(2005, 11, 1)), + ("2005 11", datetime(2005, 11, 1)), + ("11-2005", datetime(2005, 11, 1)), + ("11 2005", datetime(2005, 11, 1)), + ("200511", datetime(2020, 5, 11)), + ("20051109", datetime(2005, 11, 9)), + ("20051109 10:15", datetime(2005, 11, 9, 10, 15)), + ("20051109 08H", datetime(2005, 11, 9, 8, 0)), + ("2005-11-09 10:15", datetime(2005, 11, 9, 10, 15)), + ("2005-11-09 08H", datetime(2005, 11, 9, 8, 0)), + ("2005/11/09 10:15", datetime(2005, 11, 9, 10, 15)), + ("2005/11/09 10:15:32", datetime(2005, 11, 9, 10, 15, 32)), + ("2005/11/09 10:15:32 AM", datetime(2005, 11, 9, 10, 15, 32)), + ("2005/11/09 10:15:32 PM", datetime(2005, 11, 9, 22, 15, 32)), + ("2005/11/09 08H", datetime(2005, 11, 9, 8, 0)), + ("Thu Sep 25 10:36:28 2003", datetime(2003, 9, 25, 10, 36, 28)), + ("Thu Sep 25 2003", datetime(2003, 9, 25)), + ("Sep 25 2003", datetime(2003, 9, 25)), + ("January 1 2014", datetime(2014, 1, 1)), + # GH#10537 + ("2014-06", datetime(2014, 6, 1)), + ("06-2014", datetime(2014, 6, 1)), + ("2014-6", datetime(2014, 6, 1)), + ("6-2014", datetime(2014, 6, 1)), + ("20010101 12", datetime(2001, 1, 1, 12)), + ("20010101 1234", datetime(2001, 1, 1, 12, 34)), + ("20010101 123456", datetime(2001, 1, 1, 12, 34, 56)), + ], + ) + def test_parsers(self, date_str, expected, cache): + # dateutil >= 2.5.0 defaults to yearfirst=True + # https://github.com/dateutil/dateutil/issues/217 + yearfirst = True + + result1, _ = parsing.parse_datetime_string_with_reso( + date_str, yearfirst=yearfirst + ) + result2 = to_datetime(date_str, yearfirst=yearfirst) + result3 = to_datetime([date_str], yearfirst=yearfirst) + # result5 is used below + result4 = to_datetime( + np.array([date_str], dtype=object), yearfirst=yearfirst, cache=cache + ) + result6 = DatetimeIndex([date_str], yearfirst=yearfirst) + # result7 is used below + result8 = DatetimeIndex(Index([date_str]), yearfirst=yearfirst) + result9 = DatetimeIndex(Series([date_str]), yearfirst=yearfirst) + + for res in [result1, result2]: + assert res == expected + for res in [result3, result4, result6, result8, result9]: + exp = DatetimeIndex([Timestamp(expected)]) + tm.assert_index_equal(res, exp) + + # these really need to have yearfirst, but we don't support + if not yearfirst: + result5 = Timestamp(date_str) + assert result5 == expected + result7 = date_range(date_str, freq="S", periods=1, yearfirst=yearfirst) + assert result7 == expected + + def test_na_values_with_cache( + self, cache, unique_nulls_fixture, unique_nulls_fixture2 + ): + # GH22305 + expected = Index([NaT, NaT], dtype="datetime64[ns]") + result = to_datetime([unique_nulls_fixture, unique_nulls_fixture2], cache=cache) + tm.assert_index_equal(result, expected) + + def test_parsers_nat(self): + # Test that each of several string-accepting methods return pd.NaT + result1, _ = parsing.parse_datetime_string_with_reso("NaT") + result2 = to_datetime("NaT") + result3 = Timestamp("NaT") + result4 = DatetimeIndex(["NaT"])[0] + assert result1 is NaT + assert result2 is NaT + assert result3 is NaT + assert result4 is NaT + + @pytest.mark.parametrize( + "date_str, dayfirst, yearfirst, expected", + [ + ("10-11-12", False, False, datetime(2012, 10, 11)), + ("10-11-12", True, False, datetime(2012, 11, 10)), + ("10-11-12", False, True, datetime(2010, 11, 12)), + ("10-11-12", True, True, datetime(2010, 12, 11)), + ("20/12/21", False, False, datetime(2021, 12, 20)), + ("20/12/21", True, False, datetime(2021, 12, 20)), + ("20/12/21", False, True, datetime(2020, 12, 21)), + ("20/12/21", True, True, datetime(2020, 12, 21)), + ], + ) + def test_parsers_dayfirst_yearfirst( + self, cache, date_str, dayfirst, yearfirst, expected + ): + # OK + # 2.5.1 10-11-12 [dayfirst=0, yearfirst=0] -> 2012-10-11 00:00:00 + # 2.5.2 10-11-12 [dayfirst=0, yearfirst=1] -> 2012-10-11 00:00:00 + # 2.5.3 10-11-12 [dayfirst=0, yearfirst=0] -> 2012-10-11 00:00:00 + + # OK + # 2.5.1 10-11-12 [dayfirst=0, yearfirst=1] -> 2010-11-12 00:00:00 + # 2.5.2 10-11-12 [dayfirst=0, yearfirst=1] -> 2010-11-12 00:00:00 + # 2.5.3 10-11-12 [dayfirst=0, yearfirst=1] -> 2010-11-12 00:00:00 + + # bug fix in 2.5.2 + # 2.5.1 10-11-12 [dayfirst=1, yearfirst=1] -> 2010-11-12 00:00:00 + # 2.5.2 10-11-12 [dayfirst=1, yearfirst=1] -> 2010-12-11 00:00:00 + # 2.5.3 10-11-12 [dayfirst=1, yearfirst=1] -> 2010-12-11 00:00:00 + + # OK + # 2.5.1 10-11-12 [dayfirst=1, yearfirst=0] -> 2012-11-10 00:00:00 + # 2.5.2 10-11-12 [dayfirst=1, yearfirst=0] -> 2012-11-10 00:00:00 + # 2.5.3 10-11-12 [dayfirst=1, yearfirst=0] -> 2012-11-10 00:00:00 + + # OK + # 2.5.1 20/12/21 [dayfirst=0, yearfirst=0] -> 2021-12-20 00:00:00 + # 2.5.2 20/12/21 [dayfirst=0, yearfirst=0] -> 2021-12-20 00:00:00 + # 2.5.3 20/12/21 [dayfirst=0, yearfirst=0] -> 2021-12-20 00:00:00 + + # OK + # 2.5.1 20/12/21 [dayfirst=0, yearfirst=1] -> 2020-12-21 00:00:00 + # 2.5.2 20/12/21 [dayfirst=0, yearfirst=1] -> 2020-12-21 00:00:00 + # 2.5.3 20/12/21 [dayfirst=0, yearfirst=1] -> 2020-12-21 00:00:00 + + # revert of bug in 2.5.2 + # 2.5.1 20/12/21 [dayfirst=1, yearfirst=1] -> 2020-12-21 00:00:00 + # 2.5.2 20/12/21 [dayfirst=1, yearfirst=1] -> month must be in 1..12 + # 2.5.3 20/12/21 [dayfirst=1, yearfirst=1] -> 2020-12-21 00:00:00 + + # OK + # 2.5.1 20/12/21 [dayfirst=1, yearfirst=0] -> 2021-12-20 00:00:00 + # 2.5.2 20/12/21 [dayfirst=1, yearfirst=0] -> 2021-12-20 00:00:00 + # 2.5.3 20/12/21 [dayfirst=1, yearfirst=0] -> 2021-12-20 00:00:00 + + # str : dayfirst, yearfirst, expected + + # compare with dateutil result + dateutil_result = parse(date_str, dayfirst=dayfirst, yearfirst=yearfirst) + assert dateutil_result == expected + + result1, _ = parsing.parse_datetime_string_with_reso( + date_str, dayfirst=dayfirst, yearfirst=yearfirst + ) + + # we don't support dayfirst/yearfirst here: + if not dayfirst and not yearfirst: + result2 = Timestamp(date_str) + assert result2 == expected + + result3 = to_datetime( + date_str, dayfirst=dayfirst, yearfirst=yearfirst, cache=cache + ) + + result4 = DatetimeIndex([date_str], dayfirst=dayfirst, yearfirst=yearfirst)[0] + + assert result1 == expected + assert result3 == expected + assert result4 == expected + + @pytest.mark.parametrize( + "date_str, exp_def", + [["10:15", datetime(1, 1, 1, 10, 15)], ["9:05", datetime(1, 1, 1, 9, 5)]], + ) + def test_parsers_timestring(self, date_str, exp_def): + # must be the same as dateutil result + exp_now = parse(date_str) + + result1, _ = parsing.parse_datetime_string_with_reso(date_str) + result2 = to_datetime(date_str) + result3 = to_datetime([date_str]) + result4 = Timestamp(date_str) + result5 = DatetimeIndex([date_str])[0] + # parse time string return time string based on default date + # others are not, and can't be changed because it is used in + # time series plot + assert result1 == exp_def + assert result2 == exp_now + assert result3 == exp_now + assert result4 == exp_now + assert result5 == exp_now + + @pytest.mark.parametrize( + "dt_string, tz, dt_string_repr", + [ + ( + "2013-01-01 05:45+0545", + timezone(timedelta(minutes=345)), + "Timestamp('2013-01-01 05:45:00+0545', tz='UTC+05:45')", + ), + ( + "2013-01-01 05:30+0530", + timezone(timedelta(minutes=330)), + "Timestamp('2013-01-01 05:30:00+0530', tz='UTC+05:30')", + ), + ], + ) + def test_parsers_timezone_minute_offsets_roundtrip( + self, cache, dt_string, tz, dt_string_repr + ): + # GH11708 + base = to_datetime("2013-01-01 00:00:00", cache=cache) + base = base.tz_localize("UTC").tz_convert(tz) + dt_time = to_datetime(dt_string, cache=cache) + assert base == dt_time + assert dt_string_repr == repr(dt_time) + + +@pytest.fixture(params=["D", "s", "ms", "us", "ns"]) +def units(request): + """Day and some time units. + + * D + * s + * ms + * us + * ns + """ + return request.param + + +@pytest.fixture +def epoch_1960(): + """Timestamp at 1960-01-01.""" + return Timestamp("1960-01-01") + + +@pytest.fixture +def units_from_epochs(): + return list(range(5)) + + +@pytest.fixture(params=["timestamp", "pydatetime", "datetime64", "str_1960"]) +def epochs(epoch_1960, request): + """Timestamp at 1960-01-01 in various forms. + + * Timestamp + * datetime.datetime + * numpy.datetime64 + * str + """ + assert request.param in {"timestamp", "pydatetime", "datetime64", "str_1960"} + if request.param == "timestamp": + return epoch_1960 + elif request.param == "pydatetime": + return epoch_1960.to_pydatetime() + elif request.param == "datetime64": + return epoch_1960.to_datetime64() + else: + return str(epoch_1960) + + +@pytest.fixture +def julian_dates(): + return date_range("2014-1-1", periods=10).to_julian_date().values + + +class TestOrigin: + def test_origin_and_unit(self): + # GH#42624 + ts = to_datetime(1, unit="s", origin=1) + expected = Timestamp("1970-01-01 00:00:02") + assert ts == expected + + ts = to_datetime(1, unit="s", origin=1_000_000_000) + expected = Timestamp("2001-09-09 01:46:41") + assert ts == expected + + def test_julian(self, julian_dates): + # gh-11276, gh-11745 + # for origin as julian + + result = Series(to_datetime(julian_dates, unit="D", origin="julian")) + expected = Series( + to_datetime(julian_dates - Timestamp(0).to_julian_date(), unit="D") + ) + tm.assert_series_equal(result, expected) + + def test_unix(self): + result = Series(to_datetime([0, 1, 2], unit="D", origin="unix")) + expected = Series( + [Timestamp("1970-01-01"), Timestamp("1970-01-02"), Timestamp("1970-01-03")], + dtype="M8[ns]", + ) + tm.assert_series_equal(result, expected) + + def test_julian_round_trip(self): + result = to_datetime(2456658, origin="julian", unit="D") + assert result.to_julian_date() == 2456658 + + # out-of-bounds + msg = "1 is Out of Bounds for origin='julian'" + with pytest.raises(ValueError, match=msg): + to_datetime(1, origin="julian", unit="D") + + def test_invalid_unit(self, units, julian_dates): + # checking for invalid combination of origin='julian' and unit != D + if units != "D": + msg = "unit must be 'D' for origin='julian'" + with pytest.raises(ValueError, match=msg): + to_datetime(julian_dates, unit=units, origin="julian") + + @pytest.mark.parametrize("unit", ["ns", "D"]) + def test_invalid_origin(self, unit): + # need to have a numeric specified + msg = "it must be numeric with a unit specified" + with pytest.raises(ValueError, match=msg): + to_datetime("2005-01-01", origin="1960-01-01", unit=unit) + + @pytest.mark.parametrize( + "epochs", + [ + Timestamp(1960, 1, 1), + datetime(1960, 1, 1), + "1960-01-01", + np.datetime64("1960-01-01"), + ], + ) + def test_epoch(self, units, epochs): + epoch_1960 = Timestamp(1960, 1, 1) + units_from_epochs = np.arange(5, dtype=np.int64) + expected = Series( + [pd.Timedelta(x, unit=units) + epoch_1960 for x in units_from_epochs] + ) + + result = Series(to_datetime(units_from_epochs, unit=units, origin=epochs)) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "origin, exc", + [ + ("random_string", ValueError), + ("epoch", ValueError), + ("13-24-1990", ValueError), + (datetime(1, 1, 1), OutOfBoundsDatetime), + ], + ) + def test_invalid_origins(self, origin, exc, units, units_from_epochs): + msg = "|".join( + [ + f"origin {origin} is Out of Bounds", + f"origin {origin} cannot be converted to a Timestamp", + "Cannot cast .* to unit='ns' without overflow", + ] + ) + with pytest.raises(exc, match=msg): + to_datetime(units_from_epochs, unit=units, origin=origin) + + def test_invalid_origins_tzinfo(self): + # GH16842 + with pytest.raises(ValueError, match="must be tz-naive"): + to_datetime(1, unit="D", origin=datetime(2000, 1, 1, tzinfo=pytz.utc)) + + def test_incorrect_value_exception(self): + # GH47495 + msg = ( + "Unknown datetime string format, unable to parse: yesterday, at position 1" + ) + with pytest.raises(ValueError, match=msg): + to_datetime(["today", "yesterday"]) + + @pytest.mark.parametrize( + "format, warning", + [ + (None, UserWarning), + ("%Y-%m-%d %H:%M:%S", None), + ("%Y-%d-%m %H:%M:%S", None), + ], + ) + def test_to_datetime_out_of_bounds_with_format_arg(self, format, warning): + # see gh-23830 + msg = r"^Out of bounds nanosecond timestamp: 2417-10-10 00:00:00, at position 0" + with pytest.raises(OutOfBoundsDatetime, match=msg): + to_datetime("2417-10-10 00:00:00", format=format) + + @pytest.mark.parametrize( + "arg, origin, expected_str", + [ + [200 * 365, "unix", "2169-11-13 00:00:00"], + [200 * 365, "1870-01-01", "2069-11-13 00:00:00"], + [300 * 365, "1870-01-01", "2169-10-20 00:00:00"], + ], + ) + def test_processing_order(self, arg, origin, expected_str): + # make sure we handle out-of-bounds *before* + # constructing the dates + + result = to_datetime(arg, unit="D", origin=origin) + expected = Timestamp(expected_str) + assert result == expected + + result = to_datetime(200 * 365, unit="D", origin="1870-01-01") + expected = Timestamp("2069-11-13 00:00:00") + assert result == expected + + result = to_datetime(300 * 365, unit="D", origin="1870-01-01") + expected = Timestamp("2169-10-20 00:00:00") + assert result == expected + + @pytest.mark.parametrize( + "offset,utc,exp", + [ + ["Z", True, "2019-01-01T00:00:00.000Z"], + ["Z", None, "2019-01-01T00:00:00.000Z"], + ["-01:00", True, "2019-01-01T01:00:00.000Z"], + ["-01:00", None, "2019-01-01T00:00:00.000-01:00"], + ], + ) + def test_arg_tz_ns_unit(self, offset, utc, exp): + # GH 25546 + arg = "2019-01-01T00:00:00.000" + offset + result = to_datetime([arg], unit="ns", utc=utc) + expected = to_datetime([exp]).as_unit("ns") + tm.assert_index_equal(result, expected) + + +class TestShouldCache: + @pytest.mark.parametrize( + "listlike,do_caching", + [ + ([1, 2, 3, 4, 5, 6, 7, 8, 9, 0], False), + ([1, 1, 1, 1, 4, 5, 6, 7, 8, 9], True), + ], + ) + def test_should_cache(self, listlike, do_caching): + assert ( + tools.should_cache(listlike, check_count=len(listlike), unique_share=0.7) + == do_caching + ) + + @pytest.mark.parametrize( + "unique_share,check_count, err_message", + [ + (0.5, 11, r"check_count must be in next bounds: \[0; len\(arg\)\]"), + (10, 2, r"unique_share must be in next bounds: \(0; 1\)"), + ], + ) + def test_should_cache_errors(self, unique_share, check_count, err_message): + arg = [5] * 10 + + with pytest.raises(AssertionError, match=err_message): + tools.should_cache(arg, unique_share, check_count) + + @pytest.mark.parametrize( + "listlike", + [ + (deque([Timestamp("2010-06-02 09:30:00")] * 51)), + ([Timestamp("2010-06-02 09:30:00")] * 51), + (tuple([Timestamp("2010-06-02 09:30:00")] * 51)), + ], + ) + def test_no_slicing_errors_in_should_cache(self, listlike): + # GH#29403 + assert tools.should_cache(listlike) is True + + +def test_nullable_integer_to_datetime(): + # Test for #30050 + ser = Series([1, 2, None, 2**61, None]) + ser = ser.astype("Int64") + ser_copy = ser.copy() + + res = to_datetime(ser, unit="ns") + + expected = Series( + [ + np.datetime64("1970-01-01 00:00:00.000000001"), + np.datetime64("1970-01-01 00:00:00.000000002"), + np.datetime64("NaT"), + np.datetime64("2043-01-25 23:56:49.213693952"), + np.datetime64("NaT"), + ] + ) + tm.assert_series_equal(res, expected) + # Check that ser isn't mutated + tm.assert_series_equal(ser, ser_copy) + + +@pytest.mark.parametrize("klass", [np.array, list]) +def test_na_to_datetime(nulls_fixture, klass): + if isinstance(nulls_fixture, Decimal): + with pytest.raises(TypeError, match="not convertible to datetime"): + to_datetime(klass([nulls_fixture])) + + else: + result = to_datetime(klass([nulls_fixture])) + + assert result[0] is NaT + + +@pytest.mark.parametrize("errors", ["raise", "coerce", "ignore"]) +@pytest.mark.parametrize( + "args, format", + [ + (["03/24/2016", "03/25/2016", ""], "%m/%d/%Y"), + (["2016-03-24", "2016-03-25", ""], "%Y-%m-%d"), + ], + ids=["non-ISO8601", "ISO8601"], +) +def test_empty_string_datetime(errors, args, format): + # GH13044, GH50251 + td = Series(args) + + # coerce empty string to pd.NaT + result = to_datetime(td, format=format, errors=errors) + expected = Series(["2016-03-24", "2016-03-25", NaT], dtype="datetime64[ns]") + tm.assert_series_equal(expected, result) + + +def test_empty_string_datetime_coerce__unit(): + # GH13044 + # coerce empty string to pd.NaT + result = to_datetime([1, ""], unit="s", errors="coerce") + expected = DatetimeIndex(["1970-01-01 00:00:01", "NaT"], dtype="datetime64[ns]") + tm.assert_index_equal(expected, result) + + # verify that no exception is raised even when errors='raise' is set + result = to_datetime([1, ""], unit="s", errors="raise") + tm.assert_index_equal(expected, result) + + +@pytest.mark.parametrize("cache", [True, False]) +def test_to_datetime_monotonic_increasing_index(cache): + # GH28238 + cstart = start_caching_at + times = date_range(Timestamp("1980"), periods=cstart, freq="YS") + times = times.to_frame(index=False, name="DT").sample(n=cstart, random_state=1) + times.index = times.index.to_series().astype(float) / 1000 + result = to_datetime(times.iloc[:, 0], cache=cache) + expected = times.iloc[:, 0] + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "series_length", + [40, start_caching_at, (start_caching_at + 1), (start_caching_at + 5)], +) +def test_to_datetime_cache_coerce_50_lines_outofbounds(series_length): + # GH#45319 + ser = Series( + [datetime.fromisoformat("1446-04-12 00:00:00+00:00")] + + ([datetime.fromisoformat("1991-10-20 00:00:00+00:00")] * series_length), + dtype=object, + ) + result1 = to_datetime(ser, errors="coerce", utc=True) + + expected1 = Series( + [NaT] + ([Timestamp("1991-10-20 00:00:00+00:00")] * series_length) + ) + + tm.assert_series_equal(result1, expected1) + + result2 = to_datetime(ser, errors="ignore", utc=True) + + expected2 = Series( + [datetime.fromisoformat("1446-04-12 00:00:00+00:00")] + + ([datetime.fromisoformat("1991-10-20 00:00:00+00:00")] * series_length) + ) + + tm.assert_series_equal(result2, expected2) + + with pytest.raises(OutOfBoundsDatetime, match="Out of bounds nanosecond timestamp"): + to_datetime(ser, errors="raise", utc=True) + + +def test_to_datetime_format_f_parse_nanos(): + # GH 48767 + timestamp = "15/02/2020 02:03:04.123456789" + timestamp_format = "%d/%m/%Y %H:%M:%S.%f" + result = to_datetime(timestamp, format=timestamp_format) + expected = Timestamp( + year=2020, + month=2, + day=15, + hour=2, + minute=3, + second=4, + microsecond=123456, + nanosecond=789, + ) + assert result == expected + + +def test_to_datetime_mixed_iso8601(): + # https://github.com/pandas-dev/pandas/issues/50411 + result = to_datetime(["2020-01-01", "2020-01-01 05:00:00"], format="ISO8601") + expected = DatetimeIndex(["2020-01-01 00:00:00", "2020-01-01 05:00:00"]) + tm.assert_index_equal(result, expected) + + +def test_to_datetime_mixed_other(): + # https://github.com/pandas-dev/pandas/issues/50411 + result = to_datetime(["01/11/2000", "12 January 2000"], format="mixed") + expected = DatetimeIndex(["2000-01-11", "2000-01-12"]) + tm.assert_index_equal(result, expected) + + +@pytest.mark.parametrize("exact", [True, False]) +@pytest.mark.parametrize("format", ["ISO8601", "mixed"]) +def test_to_datetime_mixed_or_iso_exact(exact, format): + msg = "Cannot use 'exact' when 'format' is 'mixed' or 'ISO8601'" + with pytest.raises(ValueError, match=msg): + to_datetime(["2020-01-01"], exact=exact, format=format) + + +def test_to_datetime_mixed_not_necessarily_iso8601_raise(): + # https://github.com/pandas-dev/pandas/issues/50411 + with pytest.raises( + ValueError, match="Time data 01-01-2000 is not ISO8601 format, at position 1" + ): + to_datetime(["2020-01-01", "01-01-2000"], format="ISO8601") + + +@pytest.mark.parametrize( + ("errors", "expected"), + [ + ("coerce", DatetimeIndex(["2020-01-01 00:00:00", NaT])), + ("ignore", Index(["2020-01-01", "01-01-2000"], dtype=object)), + ], +) +def test_to_datetime_mixed_not_necessarily_iso8601_coerce(errors, expected): + # https://github.com/pandas-dev/pandas/issues/50411 + result = to_datetime(["2020-01-01", "01-01-2000"], format="ISO8601", errors=errors) + tm.assert_index_equal(result, expected) + + +def test_ignoring_unknown_tz_deprecated(): + # GH#18702, GH#51476 + dtstr = "2014 Jan 9 05:15 FAKE" + msg = 'un-recognized timezone "FAKE". Dropping unrecognized timezones is deprecated' + with tm.assert_produces_warning(FutureWarning, match=msg): + res = Timestamp(dtstr) + assert res == Timestamp(dtstr[:-5]) + + with tm.assert_produces_warning(FutureWarning): + res = to_datetime(dtstr) + assert res == to_datetime(dtstr[:-5]) + with tm.assert_produces_warning(FutureWarning): + res = to_datetime([dtstr]) + tm.assert_index_equal(res, to_datetime([dtstr[:-5]])) + + +def test_from_numeric_arrow_dtype(any_numeric_ea_dtype): + # GH 52425 + pytest.importorskip("pyarrow") + ser = Series([1, 2], dtype=f"{any_numeric_ea_dtype.lower()}[pyarrow]") + result = to_datetime(ser) + expected = Series([1, 2], dtype="datetime64[ns]") + tm.assert_series_equal(result, expected) + + +def test_to_datetime_with_empty_str_utc_false_format_mixed(): + # GH 50887 + vals = ["2020-01-01 00:00+00:00", ""] + result = to_datetime(vals, format="mixed") + expected = Index([Timestamp("2020-01-01 00:00+00:00"), "NaT"], dtype="M8[ns, UTC]") + tm.assert_index_equal(result, expected) + + # Check that a couple of other similar paths work the same way + alt = to_datetime(vals) + tm.assert_index_equal(alt, expected) + alt2 = DatetimeIndex(vals) + tm.assert_index_equal(alt2, expected) + + +def test_to_datetime_with_empty_str_utc_false_offsets_and_format_mixed(): + # GH 50887 + msg = "parsing datetimes with mixed time zones will raise an error" + + with tm.assert_produces_warning(FutureWarning, match=msg): + to_datetime( + ["2020-01-01 00:00+00:00", "2020-01-01 00:00+02:00", ""], format="mixed" + ) + + +def test_to_datetime_mixed_tzs_mixed_types(): + # GH#55793, GH#55693 mismatched tzs but one is str and other is + # datetime object + ts = Timestamp("2016-01-02 03:04:05", tz="US/Pacific") + dtstr = "2023-10-30 15:06+01" + arr = [ts, dtstr] + + msg = ( + "Mixed timezones detected. pass utc=True in to_datetime or tz='UTC' " + "in DatetimeIndex to convert to a common timezone" + ) + with pytest.raises(ValueError, match=msg): + to_datetime(arr) + with pytest.raises(ValueError, match=msg): + to_datetime(arr, format="mixed") + with pytest.raises(ValueError, match=msg): + DatetimeIndex(arr) + + +def test_to_datetime_mixed_types_matching_tzs(): + # GH#55793 + dtstr = "2023-11-01 09:22:03-07:00" + ts = Timestamp(dtstr) + arr = [ts, dtstr] + res1 = to_datetime(arr) + res2 = to_datetime(arr[::-1])[::-1] + res3 = to_datetime(arr, format="mixed") + res4 = DatetimeIndex(arr) + + expected = DatetimeIndex([ts, ts]) + tm.assert_index_equal(res1, expected) + tm.assert_index_equal(res2, expected) + tm.assert_index_equal(res3, expected) + tm.assert_index_equal(res4, expected) + + +dtstr = "2020-01-01 00:00+00:00" +ts = Timestamp(dtstr) + + +@pytest.mark.filterwarnings("ignore:Could not infer format:UserWarning") +@pytest.mark.parametrize( + "aware_val", + [dtstr, Timestamp(dtstr)], + ids=lambda x: type(x).__name__, +) +@pytest.mark.parametrize( + "naive_val", + [dtstr[:-6], ts.tz_localize(None), ts.date(), ts.asm8, ts.value, float(ts.value)], + ids=lambda x: type(x).__name__, +) +@pytest.mark.parametrize("naive_first", [True, False]) +def test_to_datetime_mixed_awareness_mixed_types(aware_val, naive_val, naive_first): + # GH#55793, GH#55693 + # Empty string parses to NaT + vals = [aware_val, naive_val, ""] + + vec = vals + if naive_first: + # alas, the behavior is order-dependent, so we test both ways + vec = [naive_val, aware_val, ""] + + # both_strs-> paths that were previously already deprecated with warning + # issued in _array_to_datetime_object + both_strs = isinstance(aware_val, str) and isinstance(naive_val, str) + has_numeric = isinstance(naive_val, (int, float)) + + depr_msg = "In a future version of pandas, parsing datetimes with mixed time zones" + + first_non_null = next(x for x in vec if x != "") + # if first_non_null is a not a string, _guess_datetime_format_for_array + # doesn't guess a format so we don't go through array_strptime + if not isinstance(first_non_null, str): + # that case goes through array_strptime which has different behavior + msg = "Cannot mix tz-aware with tz-naive values" + if naive_first and isinstance(aware_val, Timestamp): + if isinstance(naive_val, Timestamp): + msg = "Tz-aware datetime.datetime cannot be converted to datetime64" + with pytest.raises(ValueError, match=msg): + to_datetime(vec) + else: + with pytest.raises(ValueError, match=msg): + to_datetime(vec) + + # No warning/error with utc=True + to_datetime(vec, utc=True) + + elif has_numeric and vec.index(aware_val) < vec.index(naive_val): + msg = "time data .* doesn't match format" + with pytest.raises(ValueError, match=msg): + to_datetime(vec) + with pytest.raises(ValueError, match=msg): + to_datetime(vec, utc=True) + + elif both_strs and vec.index(aware_val) < vec.index(naive_val): + msg = r"time data \"2020-01-01 00:00\" doesn't match format" + with pytest.raises(ValueError, match=msg): + to_datetime(vec) + with pytest.raises(ValueError, match=msg): + to_datetime(vec, utc=True) + + elif both_strs and vec.index(naive_val) < vec.index(aware_val): + msg = "unconverted data remains when parsing with format" + with pytest.raises(ValueError, match=msg): + to_datetime(vec) + with pytest.raises(ValueError, match=msg): + to_datetime(vec, utc=True) + + else: + with tm.assert_produces_warning(FutureWarning, match=depr_msg): + to_datetime(vec) + + # No warning/error with utc=True + to_datetime(vec, utc=True) + + if both_strs: + with tm.assert_produces_warning(FutureWarning, match=depr_msg): + to_datetime(vec, format="mixed") + with tm.assert_produces_warning(FutureWarning, match=depr_msg): + msg = "DatetimeIndex has mixed timezones" + with pytest.raises(TypeError, match=msg): + DatetimeIndex(vec) + else: + msg = "Cannot mix tz-aware with tz-naive values" + if naive_first and isinstance(aware_val, Timestamp): + if isinstance(naive_val, Timestamp): + msg = "Tz-aware datetime.datetime cannot be converted to datetime64" + with pytest.raises(ValueError, match=msg): + to_datetime(vec, format="mixed") + with pytest.raises(ValueError, match=msg): + DatetimeIndex(vec) + else: + with pytest.raises(ValueError, match=msg): + to_datetime(vec, format="mixed") + with pytest.raises(ValueError, match=msg): + DatetimeIndex(vec) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tools/test_to_numeric.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tools/test_to_numeric.py new file mode 100644 index 0000000000000000000000000000000000000000..c452382ec572bd24cf704c445f24f9af87947141 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tools/test_to_numeric.py @@ -0,0 +1,978 @@ +import decimal + +import numpy as np +from numpy import iinfo +import pytest + +import pandas.util._test_decorators as td + +import pandas as pd +from pandas import ( + ArrowDtype, + DataFrame, + Index, + Series, + option_context, + to_numeric, +) +import pandas._testing as tm + + +@pytest.fixture(params=[None, "ignore", "raise", "coerce"]) +def errors(request): + return request.param + + +@pytest.fixture(params=[True, False]) +def signed(request): + return request.param + + +@pytest.fixture(params=[lambda x: x, str], ids=["identity", "str"]) +def transform(request): + return request.param + + +@pytest.fixture(params=[47393996303418497800, 100000000000000000000]) +def large_val(request): + return request.param + + +@pytest.fixture(params=[True, False]) +def multiple_elts(request): + return request.param + + +@pytest.fixture( + params=[ + (lambda x: Index(x, name="idx"), tm.assert_index_equal), + (lambda x: Series(x, name="ser"), tm.assert_series_equal), + (lambda x: np.array(Index(x).values), tm.assert_numpy_array_equal), + ] +) +def transform_assert_equal(request): + return request.param + + +@pytest.mark.parametrize( + "input_kwargs,result_kwargs", + [ + ({}, {"dtype": np.int64}), + ({"errors": "coerce", "downcast": "integer"}, {"dtype": np.int8}), + ], +) +def test_empty(input_kwargs, result_kwargs): + # see gh-16302 + ser = Series([], dtype=object) + result = to_numeric(ser, **input_kwargs) + + expected = Series([], **result_kwargs) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "infer_string", [False, pytest.param(True, marks=td.skip_if_no("pyarrow"))] +) +@pytest.mark.parametrize("last_val", ["7", 7]) +def test_series(last_val, infer_string): + with option_context("future.infer_string", infer_string): + ser = Series(["1", "-3.14", last_val]) + result = to_numeric(ser) + + expected = Series([1, -3.14, 7]) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "data", + [ + [1, 3, 4, 5], + [1.0, 3.0, 4.0, 5.0], + # Bool is regarded as numeric. + [True, False, True, True], + ], +) +def test_series_numeric(data): + ser = Series(data, index=list("ABCD"), name="EFG") + + result = to_numeric(ser) + tm.assert_series_equal(result, ser) + + +@pytest.mark.parametrize( + "data,msg", + [ + ([1, -3.14, "apple"], 'Unable to parse string "apple" at position 2'), + ( + ["orange", 1, -3.14, "apple"], + 'Unable to parse string "orange" at position 0', + ), + ], +) +def test_error(data, msg): + ser = Series(data) + + with pytest.raises(ValueError, match=msg): + to_numeric(ser, errors="raise") + + +@pytest.mark.parametrize( + "errors,exp_data", [("ignore", [1, -3.14, "apple"]), ("coerce", [1, -3.14, np.nan])] +) +@pytest.mark.filterwarnings("ignore:errors='ignore' is deprecated:FutureWarning") +def test_ignore_error(errors, exp_data): + ser = Series([1, -3.14, "apple"]) + result = to_numeric(ser, errors=errors) + + expected = Series(exp_data) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "errors,exp", + [ + ("raise", 'Unable to parse string "apple" at position 2'), + ("ignore", [True, False, "apple"]), + # Coerces to float. + ("coerce", [1.0, 0.0, np.nan]), + ], +) +@pytest.mark.filterwarnings("ignore:errors='ignore' is deprecated:FutureWarning") +def test_bool_handling(errors, exp): + ser = Series([True, False, "apple"]) + + if isinstance(exp, str): + with pytest.raises(ValueError, match=exp): + to_numeric(ser, errors=errors) + else: + result = to_numeric(ser, errors=errors) + expected = Series(exp) + + tm.assert_series_equal(result, expected) + + +def test_list(): + ser = ["1", "-3.14", "7"] + res = to_numeric(ser) + + expected = np.array([1, -3.14, 7]) + tm.assert_numpy_array_equal(res, expected) + + +@pytest.mark.parametrize( + "data,arr_kwargs", + [ + ([1, 3, 4, 5], {"dtype": np.int64}), + ([1.0, 3.0, 4.0, 5.0], {}), + # Boolean is regarded as numeric. + ([True, False, True, True], {}), + ], +) +def test_list_numeric(data, arr_kwargs): + result = to_numeric(data) + expected = np.array(data, **arr_kwargs) + tm.assert_numpy_array_equal(result, expected) + + +@pytest.mark.parametrize("kwargs", [{"dtype": "O"}, {}]) +def test_numeric(kwargs): + data = [1, -3.14, 7] + + ser = Series(data, **kwargs) + result = to_numeric(ser) + + expected = Series(data) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "columns", + [ + # One column. + "a", + # Multiple columns. + ["a", "b"], + ], +) +def test_numeric_df_columns(columns): + # see gh-14827 + df = DataFrame( + { + "a": [1.2, decimal.Decimal(3.14), decimal.Decimal("infinity"), "0.1"], + "b": [1.0, 2.0, 3.0, 4.0], + } + ) + + expected = DataFrame({"a": [1.2, 3.14, np.inf, 0.1], "b": [1.0, 2.0, 3.0, 4.0]}) + + df_copy = df.copy() + df_copy[columns] = df_copy[columns].apply(to_numeric) + + tm.assert_frame_equal(df_copy, expected) + + +@pytest.mark.parametrize( + "data,exp_data", + [ + ( + [[decimal.Decimal(3.14), 1.0], decimal.Decimal(1.6), 0.1], + [[3.14, 1.0], 1.6, 0.1], + ), + ([np.array([decimal.Decimal(3.14), 1.0]), 0.1], [[3.14, 1.0], 0.1]), + ], +) +def test_numeric_embedded_arr_likes(data, exp_data): + # Test to_numeric with embedded lists and arrays + df = DataFrame({"a": data}) + df["a"] = df["a"].apply(to_numeric) + + expected = DataFrame({"a": exp_data}) + tm.assert_frame_equal(df, expected) + + +def test_all_nan(): + ser = Series(["a", "b", "c"]) + result = to_numeric(ser, errors="coerce") + + expected = Series([np.nan, np.nan, np.nan]) + tm.assert_series_equal(result, expected) + + +@pytest.mark.filterwarnings("ignore:errors='ignore' is deprecated:FutureWarning") +def test_type_check(errors): + # see gh-11776 + df = DataFrame({"a": [1, -3.14, 7], "b": ["4", "5", "6"]}) + kwargs = {"errors": errors} if errors is not None else {} + with pytest.raises(TypeError, match="1-d array"): + to_numeric(df, **kwargs) + + +@pytest.mark.parametrize("val", [1, 1.1, 20001]) +def test_scalar(val, signed, transform): + val = -val if signed else val + assert to_numeric(transform(val)) == float(val) + + +@pytest.mark.filterwarnings("ignore:errors='ignore' is deprecated:FutureWarning") +def test_really_large_scalar(large_val, signed, transform, errors): + # see gh-24910 + kwargs = {"errors": errors} if errors is not None else {} + val = -large_val if signed else large_val + + val = transform(val) + val_is_string = isinstance(val, str) + + if val_is_string and errors in (None, "raise"): + msg = "Integer out of range. at position 0" + with pytest.raises(ValueError, match=msg): + to_numeric(val, **kwargs) + else: + expected = float(val) if (errors == "coerce" and val_is_string) else val + tm.assert_almost_equal(to_numeric(val, **kwargs), expected) + + +@pytest.mark.filterwarnings("ignore:errors='ignore' is deprecated:FutureWarning") +def test_really_large_in_arr(large_val, signed, transform, multiple_elts, errors): + # see gh-24910 + kwargs = {"errors": errors} if errors is not None else {} + val = -large_val if signed else large_val + val = transform(val) + + extra_elt = "string" + arr = [val] + multiple_elts * [extra_elt] + + val_is_string = isinstance(val, str) + coercing = errors == "coerce" + + if errors in (None, "raise") and (val_is_string or multiple_elts): + if val_is_string: + msg = "Integer out of range. at position 0" + else: + msg = 'Unable to parse string "string" at position 1' + + with pytest.raises(ValueError, match=msg): + to_numeric(arr, **kwargs) + else: + result = to_numeric(arr, **kwargs) + + exp_val = float(val) if (coercing and val_is_string) else val + expected = [exp_val] + + if multiple_elts: + if coercing: + expected.append(np.nan) + exp_dtype = float + else: + expected.append(extra_elt) + exp_dtype = object + else: + exp_dtype = float if isinstance(exp_val, (int, float)) else object + + tm.assert_almost_equal(result, np.array(expected, dtype=exp_dtype)) + + +@pytest.mark.filterwarnings("ignore:errors='ignore' is deprecated:FutureWarning") +def test_really_large_in_arr_consistent(large_val, signed, multiple_elts, errors): + # see gh-24910 + # + # Even if we discover that we have to hold float, does not mean + # we should be lenient on subsequent elements that fail to be integer. + kwargs = {"errors": errors} if errors is not None else {} + arr = [str(-large_val if signed else large_val)] + + if multiple_elts: + arr.insert(0, large_val) + + if errors in (None, "raise"): + index = int(multiple_elts) + msg = f"Integer out of range. at position {index}" + + with pytest.raises(ValueError, match=msg): + to_numeric(arr, **kwargs) + else: + result = to_numeric(arr, **kwargs) + + if errors == "coerce": + expected = [float(i) for i in arr] + exp_dtype = float + else: + expected = arr + exp_dtype = object + + tm.assert_almost_equal(result, np.array(expected, dtype=exp_dtype)) + + +@pytest.mark.parametrize( + "errors,checker", + [ + ("raise", 'Unable to parse string "fail" at position 0'), + ("ignore", lambda x: x == "fail"), + ("coerce", lambda x: np.isnan(x)), + ], +) +@pytest.mark.filterwarnings("ignore:errors='ignore' is deprecated:FutureWarning") +def test_scalar_fail(errors, checker): + scalar = "fail" + + if isinstance(checker, str): + with pytest.raises(ValueError, match=checker): + to_numeric(scalar, errors=errors) + else: + assert checker(to_numeric(scalar, errors=errors)) + + +@pytest.mark.parametrize("data", [[1, 2, 3], [1.0, np.nan, 3, np.nan]]) +def test_numeric_dtypes(data, transform_assert_equal): + transform, assert_equal = transform_assert_equal + data = transform(data) + + result = to_numeric(data) + assert_equal(result, data) + + +@pytest.mark.parametrize( + "data,exp", + [ + (["1", "2", "3"], np.array([1, 2, 3], dtype="int64")), + (["1.5", "2.7", "3.4"], np.array([1.5, 2.7, 3.4])), + ], +) +def test_str(data, exp, transform_assert_equal): + transform, assert_equal = transform_assert_equal + result = to_numeric(transform(data)) + + expected = transform(exp) + assert_equal(result, expected) + + +def test_datetime_like(tz_naive_fixture, transform_assert_equal): + transform, assert_equal = transform_assert_equal + idx = pd.date_range("20130101", periods=3, tz=tz_naive_fixture) + + result = to_numeric(transform(idx)) + expected = transform(idx.asi8) + assert_equal(result, expected) + + +def test_timedelta(transform_assert_equal): + transform, assert_equal = transform_assert_equal + idx = pd.timedelta_range("1 days", periods=3, freq="D") + + result = to_numeric(transform(idx)) + expected = transform(idx.asi8) + assert_equal(result, expected) + + +def test_period(request, transform_assert_equal): + transform, assert_equal = transform_assert_equal + + idx = pd.period_range("2011-01", periods=3, freq="M", name="") + inp = transform(idx) + + if not isinstance(inp, Index): + request.applymarker( + pytest.mark.xfail(reason="Missing PeriodDtype support in to_numeric") + ) + result = to_numeric(inp) + expected = transform(idx.asi8) + assert_equal(result, expected) + + +@pytest.mark.parametrize( + "errors,expected", + [ + ("raise", "Invalid object type at position 0"), + ("ignore", Series([[10.0, 2], 1.0, "apple"])), + ("coerce", Series([np.nan, 1.0, np.nan])), + ], +) +@pytest.mark.filterwarnings("ignore:errors='ignore' is deprecated:FutureWarning") +def test_non_hashable(errors, expected): + # see gh-13324 + ser = Series([[10.0, 2], 1.0, "apple"]) + + if isinstance(expected, str): + with pytest.raises(TypeError, match=expected): + to_numeric(ser, errors=errors) + else: + result = to_numeric(ser, errors=errors) + tm.assert_series_equal(result, expected) + + +def test_downcast_invalid_cast(): + # see gh-13352 + data = ["1", 2, 3] + invalid_downcast = "unsigned-integer" + msg = "invalid downcasting method provided" + + with pytest.raises(ValueError, match=msg): + to_numeric(data, downcast=invalid_downcast) + + +def test_errors_invalid_value(): + # see gh-26466 + data = ["1", 2, 3] + invalid_error_value = "invalid" + msg = "invalid error value specified" + + with pytest.raises(ValueError, match=msg): + to_numeric(data, errors=invalid_error_value) + + +@pytest.mark.parametrize( + "data", + [ + ["1", 2, 3], + [1, 2, 3], + np.array(["1970-01-02", "1970-01-03", "1970-01-04"], dtype="datetime64[D]"), + ], +) +@pytest.mark.parametrize( + "kwargs,exp_dtype", + [ + # Basic function tests. + ({}, np.int64), + ({"downcast": None}, np.int64), + # Support below np.float32 is rare and far between. + ({"downcast": "float"}, np.dtype(np.float32).char), + # Basic dtype support. + ({"downcast": "unsigned"}, np.dtype(np.typecodes["UnsignedInteger"][0])), + ], +) +def test_downcast_basic(data, kwargs, exp_dtype): + # see gh-13352 + result = to_numeric(data, **kwargs) + expected = np.array([1, 2, 3], dtype=exp_dtype) + tm.assert_numpy_array_equal(result, expected) + + +@pytest.mark.parametrize("signed_downcast", ["integer", "signed"]) +@pytest.mark.parametrize( + "data", + [ + ["1", 2, 3], + [1, 2, 3], + np.array(["1970-01-02", "1970-01-03", "1970-01-04"], dtype="datetime64[D]"), + ], +) +def test_signed_downcast(data, signed_downcast): + # see gh-13352 + smallest_int_dtype = np.dtype(np.typecodes["Integer"][0]) + expected = np.array([1, 2, 3], dtype=smallest_int_dtype) + + res = to_numeric(data, downcast=signed_downcast) + tm.assert_numpy_array_equal(res, expected) + + +def test_ignore_downcast_invalid_data(): + # If we can't successfully cast the given + # data to a numeric dtype, do not bother + # with the downcast parameter. + data = ["foo", 2, 3] + expected = np.array(data, dtype=object) + + msg = "errors='ignore' is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + res = to_numeric(data, errors="ignore", downcast="unsigned") + tm.assert_numpy_array_equal(res, expected) + + +def test_ignore_downcast_neg_to_unsigned(): + # Cannot cast to an unsigned integer + # because we have a negative number. + data = ["-1", 2, 3] + expected = np.array([-1, 2, 3], dtype=np.int64) + + res = to_numeric(data, downcast="unsigned") + tm.assert_numpy_array_equal(res, expected) + + +# Warning in 32 bit platforms +@pytest.mark.filterwarnings("ignore:invalid value encountered in cast:RuntimeWarning") +@pytest.mark.parametrize("downcast", ["integer", "signed", "unsigned"]) +@pytest.mark.parametrize( + "data,expected", + [ + (["1.1", 2, 3], np.array([1.1, 2, 3], dtype=np.float64)), + ( + [10000.0, 20000, 3000, 40000.36, 50000, 50000.00], + np.array( + [10000.0, 20000, 3000, 40000.36, 50000, 50000.00], dtype=np.float64 + ), + ), + ], +) +def test_ignore_downcast_cannot_convert_float(data, expected, downcast): + # Cannot cast to an integer (signed or unsigned) + # because we have a float number. + res = to_numeric(data, downcast=downcast) + tm.assert_numpy_array_equal(res, expected) + + +@pytest.mark.parametrize( + "downcast,expected_dtype", + [("integer", np.int16), ("signed", np.int16), ("unsigned", np.uint16)], +) +def test_downcast_not8bit(downcast, expected_dtype): + # the smallest integer dtype need not be np.(u)int8 + data = ["256", 257, 258] + + expected = np.array([256, 257, 258], dtype=expected_dtype) + res = to_numeric(data, downcast=downcast) + tm.assert_numpy_array_equal(res, expected) + + +@pytest.mark.parametrize( + "dtype,downcast,min_max", + [ + ("int8", "integer", [iinfo(np.int8).min, iinfo(np.int8).max]), + ("int16", "integer", [iinfo(np.int16).min, iinfo(np.int16).max]), + ("int32", "integer", [iinfo(np.int32).min, iinfo(np.int32).max]), + ("int64", "integer", [iinfo(np.int64).min, iinfo(np.int64).max]), + ("uint8", "unsigned", [iinfo(np.uint8).min, iinfo(np.uint8).max]), + ("uint16", "unsigned", [iinfo(np.uint16).min, iinfo(np.uint16).max]), + ("uint32", "unsigned", [iinfo(np.uint32).min, iinfo(np.uint32).max]), + ("uint64", "unsigned", [iinfo(np.uint64).min, iinfo(np.uint64).max]), + ("int16", "integer", [iinfo(np.int8).min, iinfo(np.int8).max + 1]), + ("int32", "integer", [iinfo(np.int16).min, iinfo(np.int16).max + 1]), + ("int64", "integer", [iinfo(np.int32).min, iinfo(np.int32).max + 1]), + ("int16", "integer", [iinfo(np.int8).min - 1, iinfo(np.int16).max]), + ("int32", "integer", [iinfo(np.int16).min - 1, iinfo(np.int32).max]), + ("int64", "integer", [iinfo(np.int32).min - 1, iinfo(np.int64).max]), + ("uint16", "unsigned", [iinfo(np.uint8).min, iinfo(np.uint8).max + 1]), + ("uint32", "unsigned", [iinfo(np.uint16).min, iinfo(np.uint16).max + 1]), + ("uint64", "unsigned", [iinfo(np.uint32).min, iinfo(np.uint32).max + 1]), + ], +) +def test_downcast_limits(dtype, downcast, min_max): + # see gh-14404: test the limits of each downcast. + series = to_numeric(Series(min_max), downcast=downcast) + assert series.dtype == dtype + + +def test_downcast_float64_to_float32(): + # GH-43693: Check float64 preservation when >= 16,777,217 + series = Series([16777217.0, np.finfo(np.float64).max, np.nan], dtype=np.float64) + result = to_numeric(series, downcast="float") + + assert series.dtype == result.dtype + + +@pytest.mark.parametrize( + "ser,expected", + [ + ( + Series([0, 9223372036854775808]), + Series([0, 9223372036854775808], dtype=np.uint64), + ) + ], +) +def test_downcast_uint64(ser, expected): + # see gh-14422: + # BUG: to_numeric doesn't work uint64 numbers + + result = to_numeric(ser, downcast="unsigned") + + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "data,exp_data", + [ + ( + [200, 300, "", "NaN", 30000000000000000000], + [200, 300, np.nan, np.nan, 30000000000000000000], + ), + ( + ["12345678901234567890", "1234567890", "ITEM"], + [12345678901234567890, 1234567890, np.nan], + ), + ], +) +def test_coerce_uint64_conflict(data, exp_data): + # see gh-17007 and gh-17125 + # + # Still returns float despite the uint64-nan conflict, + # which would normally force the casting to object. + result = to_numeric(Series(data), errors="coerce") + expected = Series(exp_data, dtype=float) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "errors,exp", + [ + ("ignore", Series(["12345678901234567890", "1234567890", "ITEM"])), + ("raise", "Unable to parse string"), + ], +) +@pytest.mark.filterwarnings("ignore:errors='ignore' is deprecated:FutureWarning") +def test_non_coerce_uint64_conflict(errors, exp): + # see gh-17007 and gh-17125 + # + # For completeness. + ser = Series(["12345678901234567890", "1234567890", "ITEM"]) + + if isinstance(exp, str): + with pytest.raises(ValueError, match=exp): + to_numeric(ser, errors=errors) + else: + result = to_numeric(ser, errors=errors) + tm.assert_series_equal(result, ser) + + +@pytest.mark.parametrize("dc1", ["integer", "float", "unsigned"]) +@pytest.mark.parametrize("dc2", ["integer", "float", "unsigned"]) +def test_downcast_empty(dc1, dc2): + # GH32493 + + tm.assert_numpy_array_equal( + to_numeric([], downcast=dc1), + to_numeric([], downcast=dc2), + check_dtype=False, + ) + + +def test_failure_to_convert_uint64_string_to_NaN(): + # GH 32394 + result = to_numeric("uint64", errors="coerce") + assert np.isnan(result) + + ser = Series([32, 64, np.nan]) + result = to_numeric(Series(["32", "64", "uint64"]), errors="coerce") + tm.assert_series_equal(result, ser) + + +@pytest.mark.parametrize( + "strrep", + [ + "243.164", + "245.968", + "249.585", + "259.745", + "265.742", + "272.567", + "279.196", + "280.366", + "275.034", + "271.351", + "272.889", + "270.627", + "280.828", + "290.383", + "308.153", + "319.945", + "336.0", + "344.09", + "351.385", + "356.178", + "359.82", + "361.03", + "367.701", + "380.812", + "387.98", + "391.749", + "391.171", + "385.97", + "385.345", + "386.121", + "390.996", + "399.734", + "413.073", + "421.532", + "430.221", + "437.092", + "439.746", + "446.01", + "451.191", + "460.463", + "469.779", + "472.025", + "479.49", + "474.864", + "467.54", + "471.978", + ], +) +def test_precision_float_conversion(strrep): + # GH 31364 + result = to_numeric(strrep) + + assert result == float(strrep) + + +@pytest.mark.parametrize( + "values, expected", + [ + (["1", "2", None], Series([1, 2, np.nan], dtype="Int64")), + (["1", "2", "3"], Series([1, 2, 3], dtype="Int64")), + (["1", "2", 3], Series([1, 2, 3], dtype="Int64")), + (["1", "2", 3.5], Series([1, 2, 3.5], dtype="Float64")), + (["1", None, 3.5], Series([1, np.nan, 3.5], dtype="Float64")), + (["1", "2", "3.5"], Series([1, 2, 3.5], dtype="Float64")), + ], +) +def test_to_numeric_from_nullable_string(values, nullable_string_dtype, expected): + # https://github.com/pandas-dev/pandas/issues/37262 + s = Series(values, dtype=nullable_string_dtype) + result = to_numeric(s) + tm.assert_series_equal(result, expected) + + +def test_to_numeric_from_nullable_string_coerce(nullable_string_dtype): + # GH#52146 + values = ["a", "1"] + ser = Series(values, dtype=nullable_string_dtype) + result = to_numeric(ser, errors="coerce") + expected = Series([pd.NA, 1], dtype="Int64") + tm.assert_series_equal(result, expected) + + +def test_to_numeric_from_nullable_string_ignore(nullable_string_dtype): + # GH#52146 + values = ["a", "1"] + ser = Series(values, dtype=nullable_string_dtype) + expected = ser.copy() + msg = "errors='ignore' is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = to_numeric(ser, errors="ignore") + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "data, input_dtype, downcast, expected_dtype", + ( + ([1, 1], "Int64", "integer", "Int8"), + ([1.0, pd.NA], "Float64", "integer", "Int8"), + ([1.0, 1.1], "Float64", "integer", "Float64"), + ([1, pd.NA], "Int64", "integer", "Int8"), + ([450, 300], "Int64", "integer", "Int16"), + ([1, 1], "Float64", "integer", "Int8"), + ([np.iinfo(np.int64).max - 1, 1], "Int64", "integer", "Int64"), + ([1, 1], "Int64", "signed", "Int8"), + ([1.0, 1.0], "Float32", "signed", "Int8"), + ([1.0, 1.1], "Float64", "signed", "Float64"), + ([1, pd.NA], "Int64", "signed", "Int8"), + ([450, -300], "Int64", "signed", "Int16"), + ([np.iinfo(np.uint64).max - 1, 1], "UInt64", "signed", "UInt64"), + ([1, 1], "Int64", "unsigned", "UInt8"), + ([1.0, 1.0], "Float32", "unsigned", "UInt8"), + ([1.0, 1.1], "Float64", "unsigned", "Float64"), + ([1, pd.NA], "Int64", "unsigned", "UInt8"), + ([450, -300], "Int64", "unsigned", "Int64"), + ([-1, -1], "Int32", "unsigned", "Int32"), + ([1, 1], "Float64", "float", "Float32"), + ([1, 1.1], "Float64", "float", "Float32"), + ([1, 1], "Float32", "float", "Float32"), + ([1, 1.1], "Float32", "float", "Float32"), + ), +) +def test_downcast_nullable_numeric(data, input_dtype, downcast, expected_dtype): + arr = pd.array(data, dtype=input_dtype) + result = to_numeric(arr, downcast=downcast) + expected = pd.array(data, dtype=expected_dtype) + tm.assert_extension_array_equal(result, expected) + + +def test_downcast_nullable_mask_is_copied(): + # GH38974 + + arr = pd.array([1, 2, pd.NA], dtype="Int64") + + result = to_numeric(arr, downcast="integer") + expected = pd.array([1, 2, pd.NA], dtype="Int8") + tm.assert_extension_array_equal(result, expected) + + arr[1] = pd.NA # should not modify result + tm.assert_extension_array_equal(result, expected) + + +def test_to_numeric_scientific_notation(): + # GH 15898 + result = to_numeric("1.7e+308") + expected = np.float64(1.7e308) + assert result == expected + + +@pytest.mark.parametrize("val", [9876543210.0, 2.0**128]) +def test_to_numeric_large_float_not_downcast_to_float_32(val): + # GH 19729 + expected = Series([val]) + result = to_numeric(expected, downcast="float") + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "val, dtype", [(1, "Int64"), (1.5, "Float64"), (True, "boolean")] +) +def test_to_numeric_dtype_backend(val, dtype): + # GH#50505 + ser = Series([val], dtype=object) + result = to_numeric(ser, dtype_backend="numpy_nullable") + expected = Series([val], dtype=dtype) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "val, dtype", + [ + (1, "Int64"), + (1.5, "Float64"), + (True, "boolean"), + (1, "int64[pyarrow]"), + (1.5, "float64[pyarrow]"), + (True, "bool[pyarrow]"), + ], +) +def test_to_numeric_dtype_backend_na(val, dtype): + # GH#50505 + if "pyarrow" in dtype: + pytest.importorskip("pyarrow") + dtype_backend = "pyarrow" + else: + dtype_backend = "numpy_nullable" + ser = Series([val, None], dtype=object) + result = to_numeric(ser, dtype_backend=dtype_backend) + expected = Series([val, pd.NA], dtype=dtype) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "val, dtype, downcast", + [ + (1, "Int8", "integer"), + (1.5, "Float32", "float"), + (1, "Int8", "signed"), + (1, "int8[pyarrow]", "integer"), + (1.5, "float[pyarrow]", "float"), + (1, "int8[pyarrow]", "signed"), + ], +) +def test_to_numeric_dtype_backend_downcasting(val, dtype, downcast): + # GH#50505 + if "pyarrow" in dtype: + pytest.importorskip("pyarrow") + dtype_backend = "pyarrow" + else: + dtype_backend = "numpy_nullable" + ser = Series([val, None], dtype=object) + result = to_numeric(ser, dtype_backend=dtype_backend, downcast=downcast) + expected = Series([val, pd.NA], dtype=dtype) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "smaller, dtype_backend", + [["UInt8", "numpy_nullable"], ["uint8[pyarrow]", "pyarrow"]], +) +def test_to_numeric_dtype_backend_downcasting_uint(smaller, dtype_backend): + # GH#50505 + if dtype_backend == "pyarrow": + pytest.importorskip("pyarrow") + ser = Series([1, pd.NA], dtype="UInt64") + result = to_numeric(ser, dtype_backend=dtype_backend, downcast="unsigned") + expected = Series([1, pd.NA], dtype=smaller) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "dtype", + [ + "Int64", + "UInt64", + "Float64", + "boolean", + "int64[pyarrow]", + "uint64[pyarrow]", + "float64[pyarrow]", + "bool[pyarrow]", + ], +) +def test_to_numeric_dtype_backend_already_nullable(dtype): + # GH#50505 + if "pyarrow" in dtype: + pytest.importorskip("pyarrow") + ser = Series([1, pd.NA], dtype=dtype) + result = to_numeric(ser, dtype_backend="numpy_nullable") + expected = Series([1, pd.NA], dtype=dtype) + tm.assert_series_equal(result, expected) + + +def test_to_numeric_dtype_backend_error(dtype_backend): + # GH#50505 + ser = Series(["a", "b", ""]) + expected = ser.copy() + with pytest.raises(ValueError, match="Unable to parse string"): + to_numeric(ser, dtype_backend=dtype_backend) + + msg = "errors='ignore' is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = to_numeric(ser, dtype_backend=dtype_backend, errors="ignore") + tm.assert_series_equal(result, expected) + + result = to_numeric(ser, dtype_backend=dtype_backend, errors="coerce") + if dtype_backend == "pyarrow": + dtype = "double[pyarrow]" + else: + dtype = "Float64" + expected = Series([np.nan, np.nan, np.nan], dtype=dtype) + tm.assert_series_equal(result, expected) + + +def test_invalid_dtype_backend(): + ser = Series([1, 2, 3]) + msg = ( + "dtype_backend numpy is invalid, only 'numpy_nullable' and " + "'pyarrow' are allowed." + ) + with pytest.raises(ValueError, match=msg): + to_numeric(ser, dtype_backend="numpy") + + +def test_coerce_pyarrow_backend(): + # GH 52588 + pa = pytest.importorskip("pyarrow") + ser = Series(list("12x"), dtype=ArrowDtype(pa.string())) + result = to_numeric(ser, errors="coerce", dtype_backend="pyarrow") + expected = Series([1, 2, None], dtype=ArrowDtype(pa.int64())) + tm.assert_series_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tools/test_to_time.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tools/test_to_time.py new file mode 100644 index 0000000000000000000000000000000000000000..b673bd9c2ec7168971ae0ed802336e4f03ff63a7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tools/test_to_time.py @@ -0,0 +1,72 @@ +from datetime import time +import locale + +import numpy as np +import pytest + +from pandas.compat import PY311 + +from pandas import Series +import pandas._testing as tm +from pandas.core.tools.times import to_time + +# The tests marked with this are locale-dependent. +# They pass, except when the machine locale is zh_CN or it_IT. +fails_on_non_english = pytest.mark.xfail( + locale.getlocale()[0] in ("zh_CN", "it_IT"), + reason="fail on a CI build with LC_ALL=zh_CN.utf8/it_IT.utf8", + strict=False, +) + + +class TestToTime: + @pytest.mark.parametrize( + "time_string", + [ + "14:15", + "1415", + pytest.param("2:15pm", marks=fails_on_non_english), + pytest.param("0215pm", marks=fails_on_non_english), + "14:15:00", + "141500", + pytest.param("2:15:00pm", marks=fails_on_non_english), + pytest.param("021500pm", marks=fails_on_non_english), + time(14, 15), + ], + ) + def test_parsers_time(self, time_string): + # GH#11818 + assert to_time(time_string) == time(14, 15) + + def test_odd_format(self): + new_string = "14.15" + msg = r"Cannot convert arg \['14\.15'\] to a time" + if not PY311: + with pytest.raises(ValueError, match=msg): + to_time(new_string) + assert to_time(new_string, format="%H.%M") == time(14, 15) + + def test_arraylike(self): + arg = ["14:15", "20:20"] + expected_arr = [time(14, 15), time(20, 20)] + assert to_time(arg) == expected_arr + assert to_time(arg, format="%H:%M") == expected_arr + assert to_time(arg, infer_time_format=True) == expected_arr + assert to_time(arg, format="%I:%M%p", errors="coerce") == [None, None] + + msg = "errors='ignore' is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + res = to_time(arg, format="%I:%M%p", errors="ignore") + tm.assert_numpy_array_equal(res, np.array(arg, dtype=np.object_)) + + msg = "Cannot convert.+to a time with given format" + with pytest.raises(ValueError, match=msg): + to_time(arg, format="%I:%M%p", errors="raise") + + tm.assert_series_equal( + to_time(Series(arg, name="test")), Series(expected_arr, name="test") + ) + + res = to_time(np.array(arg)) + assert isinstance(res, list) + assert res == expected_arr diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tools/test_to_timedelta.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tools/test_to_timedelta.py new file mode 100644 index 0000000000000000000000000000000000000000..b67694f1c58c7016221ed629358e8867b2a1534a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tools/test_to_timedelta.py @@ -0,0 +1,340 @@ +from datetime import ( + time, + timedelta, +) + +import numpy as np +import pytest + +from pandas.compat import IS64 +from pandas.errors import OutOfBoundsTimedelta + +import pandas as pd +from pandas import ( + Series, + TimedeltaIndex, + isna, + to_timedelta, +) +import pandas._testing as tm +from pandas.core.arrays import TimedeltaArray + + +class TestTimedeltas: + def test_to_timedelta_dt64_raises(self): + # Passing datetime64-dtype data to TimedeltaIndex is no longer + # supported GH#29794 + msg = r"dtype datetime64\[ns\] cannot be converted to timedelta64\[ns\]" + + ser = Series([pd.NaT]) + with pytest.raises(TypeError, match=msg): + to_timedelta(ser) + with pytest.raises(TypeError, match=msg): + ser.to_frame().apply(to_timedelta) + + @pytest.mark.parametrize("readonly", [True, False]) + def test_to_timedelta_readonly(self, readonly): + # GH#34857 + arr = np.array([], dtype=object) + if readonly: + arr.setflags(write=False) + result = to_timedelta(arr) + expected = to_timedelta([]) + tm.assert_index_equal(result, expected) + + def test_to_timedelta_null(self): + result = to_timedelta(["", ""]) + assert isna(result).all() + + def test_to_timedelta_same_np_timedelta64(self): + # pass thru + result = to_timedelta(np.array([np.timedelta64(1, "s")])) + expected = pd.Index(np.array([np.timedelta64(1, "s")])) + tm.assert_index_equal(result, expected) + + def test_to_timedelta_series(self): + # Series + expected = Series([timedelta(days=1), timedelta(days=1, seconds=1)]) + result = to_timedelta(Series(["1d", "1days 00:00:01"])) + tm.assert_series_equal(result, expected) + + def test_to_timedelta_units(self): + # with units + result = TimedeltaIndex( + [np.timedelta64(0, "ns"), np.timedelta64(10, "s").astype("m8[ns]")] + ) + expected = to_timedelta([0, 10], unit="s") + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "dtype, unit", + [ + ["int64", "s"], + ["int64", "m"], + ["int64", "h"], + ["timedelta64[s]", "s"], + ["timedelta64[D]", "D"], + ], + ) + def test_to_timedelta_units_dtypes(self, dtype, unit): + # arrays of various dtypes + arr = np.array([1] * 5, dtype=dtype) + result = to_timedelta(arr, unit=unit) + exp_dtype = "m8[ns]" if dtype == "int64" else "m8[s]" + expected = TimedeltaIndex([np.timedelta64(1, unit)] * 5, dtype=exp_dtype) + tm.assert_index_equal(result, expected) + + def test_to_timedelta_oob_non_nano(self): + arr = np.array([pd.NaT._value + 1], dtype="timedelta64[m]") + + msg = ( + "Cannot convert -9223372036854775807 minutes to " + r"timedelta64\[s\] without overflow" + ) + with pytest.raises(OutOfBoundsTimedelta, match=msg): + to_timedelta(arr) + + with pytest.raises(OutOfBoundsTimedelta, match=msg): + TimedeltaIndex(arr) + + with pytest.raises(OutOfBoundsTimedelta, match=msg): + TimedeltaArray._from_sequence(arr, dtype="m8[s]") + + @pytest.mark.parametrize( + "arg", [np.arange(10).reshape(2, 5), pd.DataFrame(np.arange(10).reshape(2, 5))] + ) + @pytest.mark.parametrize("errors", ["ignore", "raise", "coerce"]) + @pytest.mark.filterwarnings("ignore:errors='ignore' is deprecated:FutureWarning") + def test_to_timedelta_dataframe(self, arg, errors): + # GH 11776 + with pytest.raises(TypeError, match="1-d array"): + to_timedelta(arg, errors=errors) + + def test_to_timedelta_invalid_errors(self): + # bad value for errors parameter + msg = "errors must be one of" + with pytest.raises(ValueError, match=msg): + to_timedelta(["foo"], errors="never") + + @pytest.mark.parametrize("arg", [[1, 2], 1]) + def test_to_timedelta_invalid_unit(self, arg): + # these will error + msg = "invalid unit abbreviation: foo" + with pytest.raises(ValueError, match=msg): + to_timedelta(arg, unit="foo") + + def test_to_timedelta_time(self): + # time not supported ATM + msg = ( + "Value must be Timedelta, string, integer, float, timedelta or convertible" + ) + with pytest.raises(ValueError, match=msg): + to_timedelta(time(second=1)) + assert to_timedelta(time(second=1), errors="coerce") is pd.NaT + + def test_to_timedelta_bad_value(self): + msg = "Could not convert 'foo' to NumPy timedelta" + with pytest.raises(ValueError, match=msg): + to_timedelta(["foo", "bar"]) + + def test_to_timedelta_bad_value_coerce(self): + tm.assert_index_equal( + TimedeltaIndex([pd.NaT, pd.NaT]), + to_timedelta(["foo", "bar"], errors="coerce"), + ) + + tm.assert_index_equal( + TimedeltaIndex(["1 day", pd.NaT, "1 min"]), + to_timedelta(["1 day", "bar", "1 min"], errors="coerce"), + ) + + def test_to_timedelta_invalid_errors_ignore(self): + # gh-13613: these should not error because errors='ignore' + msg = "errors='ignore' is deprecated" + invalid_data = "apple" + + with tm.assert_produces_warning(FutureWarning, match=msg): + result = to_timedelta(invalid_data, errors="ignore") + assert invalid_data == result + + invalid_data = ["apple", "1 days"] + expected = np.array(invalid_data, dtype=object) + + with tm.assert_produces_warning(FutureWarning, match=msg): + result = to_timedelta(invalid_data, errors="ignore") + tm.assert_numpy_array_equal(expected, result) + + invalid_data = pd.Index(["apple", "1 days"]) + with tm.assert_produces_warning(FutureWarning, match=msg): + result = to_timedelta(invalid_data, errors="ignore") + tm.assert_index_equal(invalid_data, result) + + invalid_data = Series(["apple", "1 days"]) + with tm.assert_produces_warning(FutureWarning, match=msg): + result = to_timedelta(invalid_data, errors="ignore") + tm.assert_series_equal(invalid_data, result) + + @pytest.mark.parametrize( + "val, errors", + [ + ("1M", True), + ("1 M", True), + ("1Y", True), + ("1 Y", True), + ("1y", True), + ("1 y", True), + ("1m", False), + ("1 m", False), + ("1 day", False), + ("2day", False), + ], + ) + def test_unambiguous_timedelta_values(self, val, errors): + # GH36666 Deprecate use of strings denoting units with 'M', 'Y', 'm' or 'y' + # in pd.to_timedelta + msg = "Units 'M', 'Y' and 'y' do not represent unambiguous timedelta" + if errors: + with pytest.raises(ValueError, match=msg): + to_timedelta(val) + else: + # check it doesn't raise + to_timedelta(val) + + def test_to_timedelta_via_apply(self): + # GH 5458 + expected = Series([np.timedelta64(1, "s")]) + result = Series(["00:00:01"]).apply(to_timedelta) + tm.assert_series_equal(result, expected) + + result = Series([to_timedelta("00:00:01")]) + tm.assert_series_equal(result, expected) + + def test_to_timedelta_inference_without_warning(self): + # GH#41731 inference produces a warning in the Series constructor, + # but _not_ in to_timedelta + vals = ["00:00:01", pd.NaT] + with tm.assert_produces_warning(None): + result = to_timedelta(vals) + + expected = TimedeltaIndex([pd.Timedelta(seconds=1), pd.NaT]) + tm.assert_index_equal(result, expected) + + def test_to_timedelta_on_missing_values(self): + # GH5438 + timedelta_NaT = np.timedelta64("NaT") + + actual = to_timedelta(Series(["00:00:01", np.nan])) + expected = Series( + [np.timedelta64(1000000000, "ns"), timedelta_NaT], + dtype=f"{tm.ENDIAN}m8[ns]", + ) + tm.assert_series_equal(actual, expected) + + ser = Series(["00:00:01", pd.NaT], dtype="m8[ns]") + actual = to_timedelta(ser) + tm.assert_series_equal(actual, expected) + + @pytest.mark.parametrize("val", [np.nan, pd.NaT, pd.NA]) + def test_to_timedelta_on_missing_values_scalar(self, val): + actual = to_timedelta(val) + assert actual._value == np.timedelta64("NaT").astype("int64") + + @pytest.mark.parametrize("val", [np.nan, pd.NaT, pd.NA]) + def test_to_timedelta_on_missing_values_list(self, val): + actual = to_timedelta([val]) + assert actual[0]._value == np.timedelta64("NaT").astype("int64") + + @pytest.mark.xfail(not IS64, reason="Floating point error") + def test_to_timedelta_float(self): + # https://github.com/pandas-dev/pandas/issues/25077 + arr = np.arange(0, 1, 1e-6)[-10:] + result = to_timedelta(arr, unit="s") + expected_asi8 = np.arange(999990000, 10**9, 1000, dtype="int64") + tm.assert_numpy_array_equal(result.asi8, expected_asi8) + + def test_to_timedelta_coerce_strings_unit(self): + arr = np.array([1, 2, "error"], dtype=object) + result = to_timedelta(arr, unit="ns", errors="coerce") + expected = to_timedelta([1, 2, pd.NaT], unit="ns") + tm.assert_index_equal(result, expected) + + def test_to_timedelta_ignore_strings_unit(self): + arr = np.array([1, 2, "error"], dtype=object) + msg = "errors='ignore' is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = to_timedelta(arr, unit="ns", errors="ignore") + tm.assert_numpy_array_equal(result, arr) + + @pytest.mark.parametrize( + "expected_val, result_val", [[timedelta(days=2), 2], [None, None]] + ) + def test_to_timedelta_nullable_int64_dtype(self, expected_val, result_val): + # GH 35574 + expected = Series([timedelta(days=1), expected_val]) + result = to_timedelta(Series([1, result_val], dtype="Int64"), unit="days") + + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + ("input", "expected"), + [ + ("8:53:08.71800000001", "8:53:08.718"), + ("8:53:08.718001", "8:53:08.718001"), + ("8:53:08.7180000001", "8:53:08.7180000001"), + ("-8:53:08.71800000001", "-8:53:08.718"), + ("8:53:08.7180000089", "8:53:08.718000008"), + ], + ) + @pytest.mark.parametrize("func", [pd.Timedelta, to_timedelta]) + def test_to_timedelta_precision_over_nanos(self, input, expected, func): + # GH: 36738 + expected = pd.Timedelta(expected) + result = func(input) + assert result == expected + + def test_to_timedelta_zerodim(self, fixed_now_ts): + # ndarray.item() incorrectly returns int for dt64[ns] and td64[ns] + dt64 = fixed_now_ts.to_datetime64() + arg = np.array(dt64) + + msg = ( + "Value must be Timedelta, string, integer, float, timedelta " + "or convertible, not datetime64" + ) + with pytest.raises(ValueError, match=msg): + to_timedelta(arg) + + arg2 = arg.view("m8[ns]") + result = to_timedelta(arg2) + assert isinstance(result, pd.Timedelta) + assert result._value == dt64.view("i8") + + def test_to_timedelta_numeric_ea(self, any_numeric_ea_dtype): + # GH#48796 + ser = Series([1, pd.NA], dtype=any_numeric_ea_dtype) + result = to_timedelta(ser) + expected = Series([pd.Timedelta(1, unit="ns"), pd.NaT]) + tm.assert_series_equal(result, expected) + + def test_to_timedelta_fraction(self): + result = to_timedelta(1.0 / 3, unit="h") + expected = pd.Timedelta("0 days 00:19:59.999999998") + assert result == expected + + +def test_from_numeric_arrow_dtype(any_numeric_ea_dtype): + # GH 52425 + pytest.importorskip("pyarrow") + ser = Series([1, 2], dtype=f"{any_numeric_ea_dtype.lower()}[pyarrow]") + result = to_timedelta(ser) + expected = Series([1, 2], dtype="timedelta64[ns]") + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("unit", ["ns", "ms"]) +def test_from_timedelta_arrow_dtype(unit): + # GH 54298 + pytest.importorskip("pyarrow") + expected = Series([timedelta(1)], dtype=f"duration[{unit}][pyarrow]") + result = to_timedelta(expected) + tm.assert_series_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..219a7482bb3f578c85ffc82008c505a8c063041c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/frequencies/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/frequencies/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/frequencies/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/frequencies/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3de55aecc7ca3c3b81952df9b36a844202097553 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/frequencies/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/frequencies/__pycache__/test_freq_code.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/frequencies/__pycache__/test_freq_code.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..203956e1ac472ba585fa12a016a31ff650fa2ee8 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/frequencies/__pycache__/test_freq_code.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/frequencies/__pycache__/test_frequencies.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/frequencies/__pycache__/test_frequencies.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62220a48bf78a4eee7295d8927f0ed73df552d31 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/frequencies/__pycache__/test_frequencies.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/frequencies/__pycache__/test_inference.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/frequencies/__pycache__/test_inference.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16488e3007d5e2883c35e2745e46a16dd36ee578 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/frequencies/__pycache__/test_inference.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/frequencies/test_freq_code.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/frequencies/test_freq_code.py new file mode 100644 index 0000000000000000000000000000000000000000..16b7190753ee2c5beaa1179b8318732f546f9bf6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/frequencies/test_freq_code.py @@ -0,0 +1,69 @@ +import numpy as np +import pytest + +from pandas._libs.tslibs import ( + Period, + to_offset, +) + + +@pytest.mark.parametrize( + "freqstr,exp_freqstr", + [("D", "D"), ("W", "D"), ("ME", "D"), ("s", "s"), ("min", "s"), ("h", "s")], +) +def test_get_to_timestamp_base(freqstr, exp_freqstr): + off = to_offset(freqstr) + per = Period._from_ordinal(1, off) + exp_code = to_offset(exp_freqstr)._period_dtype_code + + result_code = per._dtype._get_to_timestamp_base() + assert result_code == exp_code + + +@pytest.mark.parametrize( + "args,expected", + [ + ((1.5, "min"), (90, "s")), + ((62.4, "min"), (3744, "s")), + ((1.04, "h"), (3744, "s")), + ((1, "D"), (1, "D")), + ((0.342931, "h"), (1234551600, "us")), + ((1.2345, "D"), (106660800, "ms")), + ], +) +def test_resolution_bumping(args, expected): + # see gh-14378 + off = to_offset(str(args[0]) + args[1]) + assert off.n == expected[0] + assert off._prefix == expected[1] + + +@pytest.mark.parametrize( + "args", + [ + (0.5, "ns"), + # Too much precision in the input can prevent. + (0.3429324798798269273987982, "h"), + ], +) +def test_cat(args): + msg = "Invalid frequency" + + with pytest.raises(ValueError, match=msg): + to_offset(str(args[0]) + args[1]) + + +@pytest.mark.parametrize( + "freqstr,expected", + [ + ("1h", "2021-01-01T09:00:00"), + ("1D", "2021-01-02T08:00:00"), + ("1W", "2021-01-03T08:00:00"), + ("1ME", "2021-01-31T08:00:00"), + ("1YE", "2021-12-31T08:00:00"), + ], +) +def test_compatibility(freqstr, expected): + ts_np = np.datetime64("2021-01-01T08:00:00.00") + do = to_offset(freqstr) + assert ts_np + do == np.datetime64(expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/frequencies/test_frequencies.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/frequencies/test_frequencies.py new file mode 100644 index 0000000000000000000000000000000000000000..f0af290b2fb69d57b448898e9a9d8635e529e7bc --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/frequencies/test_frequencies.py @@ -0,0 +1,29 @@ +import pytest + +from pandas._libs.tslibs import offsets + +from pandas.tseries.frequencies import ( + is_subperiod, + is_superperiod, +) + + +@pytest.mark.parametrize( + "p1,p2,expected", + [ + # Input validation. + (offsets.MonthEnd(), None, False), + (offsets.YearEnd(), None, False), + (None, offsets.YearEnd(), False), + (None, offsets.MonthEnd(), False), + (None, None, False), + (offsets.YearEnd(), offsets.MonthEnd(), True), + (offsets.Hour(), offsets.Minute(), True), + (offsets.Second(), offsets.Milli(), True), + (offsets.Milli(), offsets.Micro(), True), + (offsets.Micro(), offsets.Nano(), True), + ], +) +def test_super_sub_symmetry(p1, p2, expected): + assert is_superperiod(p1, p2) is expected + assert is_subperiod(p2, p1) is expected diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/frequencies/test_inference.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/frequencies/test_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..99a504f4188c16718587b70634736dce08bfd444 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/frequencies/test_inference.py @@ -0,0 +1,558 @@ +from datetime import ( + datetime, + timedelta, +) + +import numpy as np +import pytest + +from pandas._libs.tslibs.ccalendar import ( + DAYS, + MONTHS, +) +from pandas._libs.tslibs.offsets import _get_offset +from pandas._libs.tslibs.period import INVALID_FREQ_ERR_MSG +from pandas.compat import is_platform_windows + +from pandas import ( + DatetimeIndex, + Index, + RangeIndex, + Series, + Timestamp, + date_range, + period_range, +) +import pandas._testing as tm +from pandas.core.arrays import ( + DatetimeArray, + TimedeltaArray, +) +from pandas.core.tools.datetimes import to_datetime + +from pandas.tseries import ( + frequencies, + offsets, +) + + +@pytest.fixture( + params=[ + (timedelta(1), "D"), + (timedelta(hours=1), "h"), + (timedelta(minutes=1), "min"), + (timedelta(seconds=1), "s"), + (np.timedelta64(1, "ns"), "ns"), + (timedelta(microseconds=1), "us"), + (timedelta(microseconds=1000), "ms"), + ] +) +def base_delta_code_pair(request): + return request.param + + +freqs = ( + [f"QE-{month}" for month in MONTHS] + + [f"{annual}-{month}" for annual in ["YE", "BYE"] for month in MONTHS] + + ["ME", "BME", "BMS"] + + [f"WOM-{count}{day}" for count in range(1, 5) for day in DAYS] + + [f"W-{day}" for day in DAYS] +) + + +@pytest.mark.parametrize("freq", freqs) +@pytest.mark.parametrize("periods", [5, 7]) +def test_infer_freq_range(periods, freq): + freq = freq.upper() + + gen = date_range("1/1/2000", periods=periods, freq=freq) + index = DatetimeIndex(gen.values) + + if not freq.startswith("QE-"): + assert frequencies.infer_freq(index) == gen.freqstr + else: + inf_freq = frequencies.infer_freq(index) + is_dec_range = inf_freq == "QE-DEC" and gen.freqstr in ( + "QE", + "QE-DEC", + "QE-SEP", + "QE-JUN", + "QE-MAR", + ) + is_nov_range = inf_freq == "QE-NOV" and gen.freqstr in ( + "QE-NOV", + "QE-AUG", + "QE-MAY", + "QE-FEB", + ) + is_oct_range = inf_freq == "QE-OCT" and gen.freqstr in ( + "QE-OCT", + "QE-JUL", + "QE-APR", + "QE-JAN", + ) + assert is_dec_range or is_nov_range or is_oct_range + + +def test_raise_if_period_index(): + index = period_range(start="1/1/1990", periods=20, freq="M") + msg = "Check the `freq` attribute instead of using infer_freq" + + with pytest.raises(TypeError, match=msg): + frequencies.infer_freq(index) + + +def test_raise_if_too_few(): + index = DatetimeIndex(["12/31/1998", "1/3/1999"]) + msg = "Need at least 3 dates to infer frequency" + + with pytest.raises(ValueError, match=msg): + frequencies.infer_freq(index) + + +def test_business_daily(): + index = DatetimeIndex(["01/01/1999", "1/4/1999", "1/5/1999"]) + assert frequencies.infer_freq(index) == "B" + + +def test_business_daily_look_alike(): + # see gh-16624 + # + # Do not infer "B when "weekend" (2-day gap) in wrong place. + index = DatetimeIndex(["12/31/1998", "1/3/1999", "1/4/1999"]) + assert frequencies.infer_freq(index) is None + + +def test_day_corner(): + index = DatetimeIndex(["1/1/2000", "1/2/2000", "1/3/2000"]) + assert frequencies.infer_freq(index) == "D" + + +def test_non_datetime_index(): + dates = to_datetime(["1/1/2000", "1/2/2000", "1/3/2000"]) + assert frequencies.infer_freq(dates) == "D" + + +def test_fifth_week_of_month_infer(): + # see gh-9425 + # + # Only attempt to infer up to WOM-4. + index = DatetimeIndex(["2014-03-31", "2014-06-30", "2015-03-30"]) + assert frequencies.infer_freq(index) is None + + +def test_week_of_month_fake(): + # All of these dates are on same day + # of week and are 4 or 5 weeks apart. + index = DatetimeIndex(["2013-08-27", "2013-10-01", "2013-10-29", "2013-11-26"]) + assert frequencies.infer_freq(index) != "WOM-4TUE" + + +def test_fifth_week_of_month(): + # see gh-9425 + # + # Only supports freq up to WOM-4. + msg = ( + "Of the four parameters: start, end, periods, " + "and freq, exactly three must be specified" + ) + + with pytest.raises(ValueError, match=msg): + date_range("2014-01-01", freq="WOM-5MON") + + +def test_monthly_ambiguous(): + rng = DatetimeIndex(["1/31/2000", "2/29/2000", "3/31/2000"]) + assert rng.inferred_freq == "ME" + + +def test_annual_ambiguous(): + rng = DatetimeIndex(["1/31/2000", "1/31/2001", "1/31/2002"]) + assert rng.inferred_freq == "YE-JAN" + + +@pytest.mark.parametrize("count", range(1, 5)) +def test_infer_freq_delta(base_delta_code_pair, count): + b = Timestamp(datetime.now()) + base_delta, code = base_delta_code_pair + + inc = base_delta * count + index = DatetimeIndex([b + inc * j for j in range(3)]) + + exp_freq = f"{count:d}{code}" if count > 1 else code + assert frequencies.infer_freq(index) == exp_freq + + +@pytest.mark.parametrize( + "constructor", + [ + lambda now, delta: DatetimeIndex( + [now + delta * 7] + [now + delta * j for j in range(3)] + ), + lambda now, delta: DatetimeIndex( + [now + delta * j for j in range(3)] + [now + delta * 7] + ), + ], +) +def test_infer_freq_custom(base_delta_code_pair, constructor): + b = Timestamp(datetime.now()) + base_delta, _ = base_delta_code_pair + + index = constructor(b, base_delta) + assert frequencies.infer_freq(index) is None + + +@pytest.mark.parametrize( + "freq,expected", [("Q", "QE-DEC"), ("Q-NOV", "QE-NOV"), ("Q-OCT", "QE-OCT")] +) +def test_infer_freq_index(freq, expected): + rng = period_range("1959Q2", "2009Q3", freq=freq) + with tm.assert_produces_warning(FutureWarning, match="Dtype inference"): + rng = Index(rng.to_timestamp("D", how="e").astype(object)) + + assert rng.inferred_freq == expected + + +@pytest.mark.parametrize( + "expected,dates", + list( + { + "YS-JAN": ["2009-01-01", "2010-01-01", "2011-01-01", "2012-01-01"], + "QE-OCT": ["2009-01-31", "2009-04-30", "2009-07-31", "2009-10-31"], + "ME": ["2010-11-30", "2010-12-31", "2011-01-31", "2011-02-28"], + "W-SAT": ["2010-12-25", "2011-01-01", "2011-01-08", "2011-01-15"], + "D": ["2011-01-01", "2011-01-02", "2011-01-03", "2011-01-04"], + "h": [ + "2011-12-31 22:00", + "2011-12-31 23:00", + "2012-01-01 00:00", + "2012-01-01 01:00", + ], + }.items() + ), +) +@pytest.mark.parametrize("unit", ["s", "ms", "us", "ns"]) +def test_infer_freq_tz(tz_naive_fixture, expected, dates, unit): + # see gh-7310, GH#55609 + tz = tz_naive_fixture + idx = DatetimeIndex(dates, tz=tz).as_unit(unit) + assert idx.inferred_freq == expected + + +def test_infer_freq_tz_series(tz_naive_fixture): + # infer_freq should work with both tz-naive and tz-aware series. See gh-52456 + tz = tz_naive_fixture + idx = date_range("2021-01-01", "2021-01-04", tz=tz) + series = idx.to_series().reset_index(drop=True) + inferred_freq = frequencies.infer_freq(series) + assert inferred_freq == "D" + + +@pytest.mark.parametrize( + "date_pair", + [ + ["2013-11-02", "2013-11-5"], # Fall DST + ["2014-03-08", "2014-03-11"], # Spring DST + ["2014-01-01", "2014-01-03"], # Regular Time + ], +) +@pytest.mark.parametrize( + "freq", + ["h", "3h", "10min", "3601s", "3600001ms", "3600000001us", "3600000000001ns"], +) +def test_infer_freq_tz_transition(tz_naive_fixture, date_pair, freq): + # see gh-8772 + tz = tz_naive_fixture + idx = date_range(date_pair[0], date_pair[1], freq=freq, tz=tz) + assert idx.inferred_freq == freq + + +def test_infer_freq_tz_transition_custom(): + index = date_range("2013-11-03", periods=5, freq="3h").tz_localize( + "America/Chicago" + ) + assert index.inferred_freq is None + + +@pytest.mark.parametrize( + "data,expected", + [ + # Hourly freq in a day must result in "h" + ( + [ + "2014-07-01 09:00", + "2014-07-01 10:00", + "2014-07-01 11:00", + "2014-07-01 12:00", + "2014-07-01 13:00", + "2014-07-01 14:00", + ], + "h", + ), + ( + [ + "2014-07-01 09:00", + "2014-07-01 10:00", + "2014-07-01 11:00", + "2014-07-01 12:00", + "2014-07-01 13:00", + "2014-07-01 14:00", + "2014-07-01 15:00", + "2014-07-01 16:00", + "2014-07-02 09:00", + "2014-07-02 10:00", + "2014-07-02 11:00", + ], + "bh", + ), + ( + [ + "2014-07-04 09:00", + "2014-07-04 10:00", + "2014-07-04 11:00", + "2014-07-04 12:00", + "2014-07-04 13:00", + "2014-07-04 14:00", + "2014-07-04 15:00", + "2014-07-04 16:00", + "2014-07-07 09:00", + "2014-07-07 10:00", + "2014-07-07 11:00", + ], + "bh", + ), + ( + [ + "2014-07-04 09:00", + "2014-07-04 10:00", + "2014-07-04 11:00", + "2014-07-04 12:00", + "2014-07-04 13:00", + "2014-07-04 14:00", + "2014-07-04 15:00", + "2014-07-04 16:00", + "2014-07-07 09:00", + "2014-07-07 10:00", + "2014-07-07 11:00", + "2014-07-07 12:00", + "2014-07-07 13:00", + "2014-07-07 14:00", + "2014-07-07 15:00", + "2014-07-07 16:00", + "2014-07-08 09:00", + "2014-07-08 10:00", + "2014-07-08 11:00", + "2014-07-08 12:00", + "2014-07-08 13:00", + "2014-07-08 14:00", + "2014-07-08 15:00", + "2014-07-08 16:00", + ], + "bh", + ), + ], +) +def test_infer_freq_business_hour(data, expected): + # see gh-7905 + idx = DatetimeIndex(data) + assert idx.inferred_freq == expected + + +def test_not_monotonic(): + rng = DatetimeIndex(["1/31/2000", "1/31/2001", "1/31/2002"]) + rng = rng[::-1] + + assert rng.inferred_freq == "-1YE-JAN" + + +def test_non_datetime_index2(): + rng = DatetimeIndex(["1/31/2000", "1/31/2001", "1/31/2002"]) + vals = rng.to_pydatetime() + + result = frequencies.infer_freq(vals) + assert result == rng.inferred_freq + + +@pytest.mark.parametrize( + "idx", + [ + Index(np.arange(5), dtype=np.int64), + Index(np.arange(5), dtype=np.float64), + period_range("2020-01-01", periods=5), + RangeIndex(5), + ], +) +def test_invalid_index_types(idx): + # see gh-48439 + msg = "|".join( + [ + "cannot infer freq from a non-convertible", + "Check the `freq` attribute instead of using infer_freq", + ] + ) + + with pytest.raises(TypeError, match=msg): + frequencies.infer_freq(idx) + + +@pytest.mark.skipif(is_platform_windows(), reason="see gh-10822: Windows issue") +def test_invalid_index_types_unicode(): + # see gh-10822 + # + # Odd error message on conversions to datetime for unicode. + msg = "Unknown datetime string format" + + with pytest.raises(ValueError, match=msg): + frequencies.infer_freq(Index(["ZqgszYBfuL"])) + + +def test_string_datetime_like_compat(): + # see gh-6463 + data = ["2004-01", "2004-02", "2004-03", "2004-04"] + + expected = frequencies.infer_freq(data) + result = frequencies.infer_freq(Index(data)) + + assert result == expected + + +def test_series(): + # see gh-6407 + s = Series(date_range("20130101", "20130110")) + inferred = frequencies.infer_freq(s) + assert inferred == "D" + + +@pytest.mark.parametrize("end", [10, 10.0]) +def test_series_invalid_type(end): + # see gh-6407 + msg = "cannot infer freq from a non-convertible dtype on a Series" + s = Series(np.arange(end)) + + with pytest.raises(TypeError, match=msg): + frequencies.infer_freq(s) + + +def test_series_inconvertible_string(using_infer_string): + # see gh-6407 + if using_infer_string: + msg = "cannot infer freq from" + + with pytest.raises(TypeError, match=msg): + frequencies.infer_freq(Series(["foo", "bar"])) + else: + msg = "Unknown datetime string format" + + with pytest.raises(ValueError, match=msg): + frequencies.infer_freq(Series(["foo", "bar"])) + + +@pytest.mark.parametrize("freq", [None, "ms"]) +def test_series_period_index(freq): + # see gh-6407 + # + # Cannot infer on PeriodIndex + msg = "cannot infer freq from a non-convertible dtype on a Series" + s = Series(period_range("2013", periods=10, freq=freq)) + + with pytest.raises(TypeError, match=msg): + frequencies.infer_freq(s) + + +@pytest.mark.parametrize("freq", ["ME", "ms", "s"]) +def test_series_datetime_index(freq): + s = Series(date_range("20130101", periods=10, freq=freq)) + inferred = frequencies.infer_freq(s) + assert inferred == freq + + +@pytest.mark.parametrize( + "offset_func", + [ + _get_offset, + lambda freq: date_range("2011-01-01", periods=5, freq=freq), + ], +) +@pytest.mark.parametrize( + "freq", + [ + "WEEKDAY", + "EOM", + "W@MON", + "W@TUE", + "W@WED", + "W@THU", + "W@FRI", + "W@SAT", + "W@SUN", + "QE@JAN", + "QE@FEB", + "QE@MAR", + "YE@JAN", + "YE@FEB", + "YE@MAR", + "YE@APR", + "YE@MAY", + "YE@JUN", + "YE@JUL", + "YE@AUG", + "YE@SEP", + "YE@OCT", + "YE@NOV", + "YE@DEC", + "YE@JAN", + "WOM@1MON", + "WOM@2MON", + "WOM@3MON", + "WOM@4MON", + "WOM@1TUE", + "WOM@2TUE", + "WOM@3TUE", + "WOM@4TUE", + "WOM@1WED", + "WOM@2WED", + "WOM@3WED", + "WOM@4WED", + "WOM@1THU", + "WOM@2THU", + "WOM@3THU", + "WOM@4THU", + "WOM@1FRI", + "WOM@2FRI", + "WOM@3FRI", + "WOM@4FRI", + ], +) +def test_legacy_offset_warnings(offset_func, freq): + with pytest.raises(ValueError, match=INVALID_FREQ_ERR_MSG): + offset_func(freq) + + +def test_ms_vs_capital_ms(): + left = _get_offset("ms") + right = _get_offset("MS") + + assert left == offsets.Milli() + assert right == offsets.MonthBegin() + + +def test_infer_freq_non_nano(): + arr = np.arange(10).astype(np.int64).view("M8[s]") + dta = DatetimeArray._simple_new(arr, dtype=arr.dtype) + res = frequencies.infer_freq(dta) + assert res == "s" + + arr2 = arr.view("m8[ms]") + tda = TimedeltaArray._simple_new(arr2, dtype=arr2.dtype) + res2 = frequencies.infer_freq(tda) + assert res2 == "ms" + + +def test_infer_freq_non_nano_tzaware(tz_aware_fixture): + tz = tz_aware_fixture + + dti = date_range("2016-01-01", periods=365, freq="B", tz=tz) + dta = dti._data.as_unit("s") + + res = frequencies.infer_freq(dta) + assert res == "B" diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/holiday/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/holiday/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/holiday/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/holiday/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ca340ab5943cb425507f49f3e0a44c523dd0da5 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/holiday/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/holiday/__pycache__/test_calendar.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/holiday/__pycache__/test_calendar.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a312ae510fc085074d40aedaceb761dd33ca300 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/holiday/__pycache__/test_calendar.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/holiday/__pycache__/test_federal.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/holiday/__pycache__/test_federal.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa55f918a9eeb6a5d15b0ccb9c47c4537dbb59fb Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/holiday/__pycache__/test_federal.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/holiday/__pycache__/test_holiday.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/holiday/__pycache__/test_holiday.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2009b64d4bd1baec93c0a42f34b17f89c1c8b86f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/holiday/__pycache__/test_holiday.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/holiday/__pycache__/test_observance.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/holiday/__pycache__/test_observance.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d94c49a718a3d56abc01c9ca27b3b372589e1aa Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/holiday/__pycache__/test_observance.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/holiday/test_calendar.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/holiday/test_calendar.py new file mode 100644 index 0000000000000000000000000000000000000000..99829857e68363ec845dd9dff3d90917c31adaa0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/holiday/test_calendar.py @@ -0,0 +1,119 @@ +from datetime import datetime + +import pytest + +from pandas import ( + DatetimeIndex, + offsets, + to_datetime, +) +import pandas._testing as tm + +from pandas.tseries.holiday import ( + AbstractHolidayCalendar, + Holiday, + Timestamp, + USFederalHolidayCalendar, + USLaborDay, + USThanksgivingDay, + get_calendar, +) + + +@pytest.mark.parametrize( + "transform", [lambda x: x, lambda x: x.strftime("%Y-%m-%d"), lambda x: Timestamp(x)] +) +def test_calendar(transform): + start_date = datetime(2012, 1, 1) + end_date = datetime(2012, 12, 31) + + calendar = USFederalHolidayCalendar() + holidays = calendar.holidays(transform(start_date), transform(end_date)) + + expected = [ + datetime(2012, 1, 2), + datetime(2012, 1, 16), + datetime(2012, 2, 20), + datetime(2012, 5, 28), + datetime(2012, 7, 4), + datetime(2012, 9, 3), + datetime(2012, 10, 8), + datetime(2012, 11, 12), + datetime(2012, 11, 22), + datetime(2012, 12, 25), + ] + + assert list(holidays.to_pydatetime()) == expected + + +def test_calendar_caching(): + # see gh-9552. + + class TestCalendar(AbstractHolidayCalendar): + def __init__(self, name=None, rules=None) -> None: + super().__init__(name=name, rules=rules) + + jan1 = TestCalendar(rules=[Holiday("jan1", year=2015, month=1, day=1)]) + jan2 = TestCalendar(rules=[Holiday("jan2", year=2015, month=1, day=2)]) + + # Getting holidays for Jan 1 should not alter results for Jan 2. + expected = DatetimeIndex(["01-Jan-2015"]).as_unit("ns") + tm.assert_index_equal(jan1.holidays(), expected) + + expected2 = DatetimeIndex(["02-Jan-2015"]).as_unit("ns") + tm.assert_index_equal(jan2.holidays(), expected2) + + +def test_calendar_observance_dates(): + # see gh-11477 + us_fed_cal = get_calendar("USFederalHolidayCalendar") + holidays0 = us_fed_cal.holidays( + datetime(2015, 7, 3), datetime(2015, 7, 3) + ) # <-- same start and end dates + holidays1 = us_fed_cal.holidays( + datetime(2015, 7, 3), datetime(2015, 7, 6) + ) # <-- different start and end dates + holidays2 = us_fed_cal.holidays( + datetime(2015, 7, 3), datetime(2015, 7, 3) + ) # <-- same start and end dates + + # These should all produce the same result. + # + # In addition, calling with different start and end + # dates should not alter the output if we call the + # function again with the same start and end date. + tm.assert_index_equal(holidays0, holidays1) + tm.assert_index_equal(holidays0, holidays2) + + +def test_rule_from_name(): + us_fed_cal = get_calendar("USFederalHolidayCalendar") + assert us_fed_cal.rule_from_name("Thanksgiving Day") == USThanksgivingDay + + +def test_calendar_2031(): + # See gh-27790 + # + # Labor Day 2031 is on September 1. Saturday before is August 30. + # Next working day after August 30 ought to be Tuesday, September 2. + + class testCalendar(AbstractHolidayCalendar): + rules = [USLaborDay] + + cal = testCalendar() + workDay = offsets.CustomBusinessDay(calendar=cal) + Sat_before_Labor_Day_2031 = to_datetime("2031-08-30") + next_working_day = Sat_before_Labor_Day_2031 + 0 * workDay + assert next_working_day == to_datetime("2031-09-02") + + +def test_no_holidays_calendar(): + # Test for issue #31415 + + class NoHolidaysCalendar(AbstractHolidayCalendar): + pass + + cal = NoHolidaysCalendar() + holidays = cal.holidays(Timestamp("01-Jan-2020"), Timestamp("01-Jan-2021")) + empty_index = DatetimeIndex([]) # Type is DatetimeIndex since return_name=False + tm.assert_index_equal(holidays, empty_index) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/holiday/test_federal.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/holiday/test_federal.py new file mode 100644 index 0000000000000000000000000000000000000000..2565877f8a2a44071f96cef0f670c23842a364b6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/holiday/test_federal.py @@ -0,0 +1,58 @@ +from datetime import datetime + +from pandas import DatetimeIndex +import pandas._testing as tm + +from pandas.tseries.holiday import ( + AbstractHolidayCalendar, + USFederalHolidayCalendar, + USMartinLutherKingJr, + USMemorialDay, +) + + +def test_no_mlk_before_1986(): + # see gh-10278 + class MLKCalendar(AbstractHolidayCalendar): + rules = [USMartinLutherKingJr] + + holidays = MLKCalendar().holidays(start="1984", end="1988").to_pydatetime().tolist() + + # Testing to make sure holiday is not incorrectly observed before 1986. + assert holidays == [datetime(1986, 1, 20, 0, 0), datetime(1987, 1, 19, 0, 0)] + + +def test_memorial_day(): + class MemorialDay(AbstractHolidayCalendar): + rules = [USMemorialDay] + + holidays = MemorialDay().holidays(start="1971", end="1980").to_pydatetime().tolist() + + # Fixes 5/31 error and checked manually against Wikipedia. + assert holidays == [ + datetime(1971, 5, 31, 0, 0), + datetime(1972, 5, 29, 0, 0), + datetime(1973, 5, 28, 0, 0), + datetime(1974, 5, 27, 0, 0), + datetime(1975, 5, 26, 0, 0), + datetime(1976, 5, 31, 0, 0), + datetime(1977, 5, 30, 0, 0), + datetime(1978, 5, 29, 0, 0), + datetime(1979, 5, 28, 0, 0), + ] + + +def test_federal_holiday_inconsistent_returntype(): + # GH 49075 test case + # Instantiate two calendars to rule out _cache + cal1 = USFederalHolidayCalendar() + cal2 = USFederalHolidayCalendar() + + results_2018 = cal1.holidays(start=datetime(2018, 8, 1), end=datetime(2018, 8, 31)) + results_2019 = cal2.holidays(start=datetime(2019, 8, 1), end=datetime(2019, 8, 31)) + expected_results = DatetimeIndex([], dtype="datetime64[ns]", freq=None) + + # Check against expected results to ensure both date + # ranges generate expected results as per GH49075 submission + tm.assert_index_equal(results_2018, expected_results) + tm.assert_index_equal(results_2019, expected_results) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/holiday/test_holiday.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/holiday/test_holiday.py new file mode 100644 index 0000000000000000000000000000000000000000..b2eefd04ef93b159140ae72ae3d96d8adf071719 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/holiday/test_holiday.py @@ -0,0 +1,332 @@ +from datetime import datetime + +import pytest +from pytz import utc + +from pandas import ( + DatetimeIndex, + Series, +) +import pandas._testing as tm + +from pandas.tseries.holiday import ( + MO, + SA, + AbstractHolidayCalendar, + DateOffset, + EasterMonday, + GoodFriday, + Holiday, + HolidayCalendarFactory, + Timestamp, + USColumbusDay, + USFederalHolidayCalendar, + USLaborDay, + USMartinLutherKingJr, + USMemorialDay, + USPresidentsDay, + USThanksgivingDay, + get_calendar, + next_monday, +) + + +@pytest.mark.parametrize( + "holiday,start_date,end_date,expected", + [ + ( + USMemorialDay, + datetime(2011, 1, 1), + datetime(2020, 12, 31), + [ + datetime(2011, 5, 30), + datetime(2012, 5, 28), + datetime(2013, 5, 27), + datetime(2014, 5, 26), + datetime(2015, 5, 25), + datetime(2016, 5, 30), + datetime(2017, 5, 29), + datetime(2018, 5, 28), + datetime(2019, 5, 27), + datetime(2020, 5, 25), + ], + ), + ( + Holiday("July 4th Eve", month=7, day=3), + "2001-01-01", + "2003-03-03", + [Timestamp("2001-07-03 00:00:00"), Timestamp("2002-07-03 00:00:00")], + ), + ( + Holiday("July 4th Eve", month=7, day=3, days_of_week=(0, 1, 2, 3)), + "2001-01-01", + "2008-03-03", + [ + Timestamp("2001-07-03 00:00:00"), + Timestamp("2002-07-03 00:00:00"), + Timestamp("2003-07-03 00:00:00"), + Timestamp("2006-07-03 00:00:00"), + Timestamp("2007-07-03 00:00:00"), + ], + ), + ( + EasterMonday, + datetime(2011, 1, 1), + datetime(2020, 12, 31), + [ + Timestamp("2011-04-25 00:00:00"), + Timestamp("2012-04-09 00:00:00"), + Timestamp("2013-04-01 00:00:00"), + Timestamp("2014-04-21 00:00:00"), + Timestamp("2015-04-06 00:00:00"), + Timestamp("2016-03-28 00:00:00"), + Timestamp("2017-04-17 00:00:00"), + Timestamp("2018-04-02 00:00:00"), + Timestamp("2019-04-22 00:00:00"), + Timestamp("2020-04-13 00:00:00"), + ], + ), + ( + GoodFriday, + datetime(2011, 1, 1), + datetime(2020, 12, 31), + [ + Timestamp("2011-04-22 00:00:00"), + Timestamp("2012-04-06 00:00:00"), + Timestamp("2013-03-29 00:00:00"), + Timestamp("2014-04-18 00:00:00"), + Timestamp("2015-04-03 00:00:00"), + Timestamp("2016-03-25 00:00:00"), + Timestamp("2017-04-14 00:00:00"), + Timestamp("2018-03-30 00:00:00"), + Timestamp("2019-04-19 00:00:00"), + Timestamp("2020-04-10 00:00:00"), + ], + ), + ( + USThanksgivingDay, + datetime(2011, 1, 1), + datetime(2020, 12, 31), + [ + datetime(2011, 11, 24), + datetime(2012, 11, 22), + datetime(2013, 11, 28), + datetime(2014, 11, 27), + datetime(2015, 11, 26), + datetime(2016, 11, 24), + datetime(2017, 11, 23), + datetime(2018, 11, 22), + datetime(2019, 11, 28), + datetime(2020, 11, 26), + ], + ), + ], +) +def test_holiday_dates(holiday, start_date, end_date, expected): + assert list(holiday.dates(start_date, end_date)) == expected + + # Verify that timezone info is preserved. + assert list( + holiday.dates( + utc.localize(Timestamp(start_date)), utc.localize(Timestamp(end_date)) + ) + ) == [utc.localize(dt) for dt in expected] + + +@pytest.mark.parametrize( + "holiday,start,expected", + [ + (USMemorialDay, datetime(2015, 7, 1), []), + (USMemorialDay, "2015-05-25", [Timestamp("2015-05-25")]), + (USLaborDay, datetime(2015, 7, 1), []), + (USLaborDay, "2015-09-07", [Timestamp("2015-09-07")]), + (USColumbusDay, datetime(2015, 7, 1), []), + (USColumbusDay, "2015-10-12", [Timestamp("2015-10-12")]), + (USThanksgivingDay, datetime(2015, 7, 1), []), + (USThanksgivingDay, "2015-11-26", [Timestamp("2015-11-26")]), + (USMartinLutherKingJr, datetime(2015, 7, 1), []), + (USMartinLutherKingJr, "2015-01-19", [Timestamp("2015-01-19")]), + (USPresidentsDay, datetime(2015, 7, 1), []), + (USPresidentsDay, "2015-02-16", [Timestamp("2015-02-16")]), + (GoodFriday, datetime(2015, 7, 1), []), + (GoodFriday, "2015-04-03", [Timestamp("2015-04-03")]), + (EasterMonday, "2015-04-06", [Timestamp("2015-04-06")]), + (EasterMonday, datetime(2015, 7, 1), []), + (EasterMonday, "2015-04-05", []), + ("New Year's Day", "2015-01-01", [Timestamp("2015-01-01")]), + ("New Year's Day", "2010-12-31", [Timestamp("2010-12-31")]), + ("New Year's Day", datetime(2015, 7, 1), []), + ("New Year's Day", "2011-01-01", []), + ("Independence Day", "2015-07-03", [Timestamp("2015-07-03")]), + ("Independence Day", datetime(2015, 7, 1), []), + ("Independence Day", "2015-07-04", []), + ("Veterans Day", "2012-11-12", [Timestamp("2012-11-12")]), + ("Veterans Day", datetime(2015, 7, 1), []), + ("Veterans Day", "2012-11-11", []), + ("Christmas Day", "2011-12-26", [Timestamp("2011-12-26")]), + ("Christmas Day", datetime(2015, 7, 1), []), + ("Christmas Day", "2011-12-25", []), + ("Juneteenth National Independence Day", "2020-06-19", []), + ( + "Juneteenth National Independence Day", + "2021-06-18", + [Timestamp("2021-06-18")], + ), + ("Juneteenth National Independence Day", "2022-06-19", []), + ( + "Juneteenth National Independence Day", + "2022-06-20", + [Timestamp("2022-06-20")], + ), + ], +) +def test_holidays_within_dates(holiday, start, expected): + # see gh-11477 + # + # Fix holiday behavior where holiday.dates returned dates outside + # start/end date, or observed rules could not be applied because the + # holiday was not in the original date range (e.g., 7/4/2015 -> 7/3/2015). + if isinstance(holiday, str): + calendar = get_calendar("USFederalHolidayCalendar") + holiday = calendar.rule_from_name(holiday) + + assert list(holiday.dates(start, start)) == expected + + # Verify that timezone info is preserved. + assert list( + holiday.dates(utc.localize(Timestamp(start)), utc.localize(Timestamp(start))) + ) == [utc.localize(dt) for dt in expected] + + +@pytest.mark.parametrize( + "transform", [lambda x: x.strftime("%Y-%m-%d"), lambda x: Timestamp(x)] +) +def test_argument_types(transform): + start_date = datetime(2011, 1, 1) + end_date = datetime(2020, 12, 31) + + holidays = USThanksgivingDay.dates(start_date, end_date) + holidays2 = USThanksgivingDay.dates(transform(start_date), transform(end_date)) + tm.assert_index_equal(holidays, holidays2) + + +@pytest.mark.parametrize( + "name,kwargs", + [ + ("One-Time", {"year": 2012, "month": 5, "day": 28}), + ( + "Range", + { + "month": 5, + "day": 28, + "start_date": datetime(2012, 1, 1), + "end_date": datetime(2012, 12, 31), + "offset": DateOffset(weekday=MO(1)), + }, + ), + ], +) +def test_special_holidays(name, kwargs): + base_date = [datetime(2012, 5, 28)] + holiday = Holiday(name, **kwargs) + + start_date = datetime(2011, 1, 1) + end_date = datetime(2020, 12, 31) + + assert base_date == holiday.dates(start_date, end_date) + + +def test_get_calendar(): + class TestCalendar(AbstractHolidayCalendar): + rules = [] + + calendar = get_calendar("TestCalendar") + assert TestCalendar == type(calendar) + + +def test_factory(): + class_1 = HolidayCalendarFactory( + "MemorialDay", AbstractHolidayCalendar, USMemorialDay + ) + class_2 = HolidayCalendarFactory( + "Thanksgiving", AbstractHolidayCalendar, USThanksgivingDay + ) + class_3 = HolidayCalendarFactory("Combined", class_1, class_2) + + assert len(class_1.rules) == 1 + assert len(class_2.rules) == 1 + assert len(class_3.rules) == 2 + + +def test_both_offset_observance_raises(): + # see gh-10217 + msg = "Cannot use both offset and observance" + with pytest.raises(NotImplementedError, match=msg): + Holiday( + "Cyber Monday", + month=11, + day=1, + offset=[DateOffset(weekday=SA(4))], + observance=next_monday, + ) + + +def test_half_open_interval_with_observance(): + # Prompted by GH 49075 + # Check for holidays that have a half-open date interval where + # they have either a start_date or end_date defined along + # with a defined observance pattern to make sure that the return type + # for Holiday.dates() remains consistent before & after the year that + # marks the 'edge' of the half-open date interval. + + holiday_1 = Holiday( + "Arbitrary Holiday - start 2022-03-14", + start_date=datetime(2022, 3, 14), + month=3, + day=14, + observance=next_monday, + ) + holiday_2 = Holiday( + "Arbitrary Holiday 2 - end 2022-03-20", + end_date=datetime(2022, 3, 20), + month=3, + day=20, + observance=next_monday, + ) + + class TestHolidayCalendar(AbstractHolidayCalendar): + rules = [ + USMartinLutherKingJr, + holiday_1, + holiday_2, + USLaborDay, + ] + + start = Timestamp("2022-08-01") + end = Timestamp("2022-08-31") + year_offset = DateOffset(years=5) + expected_results = DatetimeIndex([], dtype="datetime64[ns]", freq=None) + test_cal = TestHolidayCalendar() + + date_interval_low = test_cal.holidays(start - year_offset, end - year_offset) + date_window_edge = test_cal.holidays(start, end) + date_interval_high = test_cal.holidays(start + year_offset, end + year_offset) + + tm.assert_index_equal(date_interval_low, expected_results) + tm.assert_index_equal(date_window_edge, expected_results) + tm.assert_index_equal(date_interval_high, expected_results) + + +def test_holidays_with_timezone_specified_but_no_occurences(): + # GH 54580 + # _apply_rule() in holiday.py was silently dropping timezones if you passed it + # an empty list of holiday dates that had timezone information + start_date = Timestamp("2018-01-01", tz="America/Chicago") + end_date = Timestamp("2018-01-11", tz="America/Chicago") + test_case = USFederalHolidayCalendar().holidays( + start_date, end_date, return_name=True + ) + expected_results = Series("New Year's Day", index=[start_date]) + expected_results.index = expected_results.index.as_unit("ns") + + tm.assert_equal(test_case, expected_results) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/holiday/test_observance.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/holiday/test_observance.py new file mode 100644 index 0000000000000000000000000000000000000000..83038ad254b77daee6667bd269a8775016649d39 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/holiday/test_observance.py @@ -0,0 +1,105 @@ +from datetime import datetime + +import pytest + +from pandas.tseries.holiday import ( + after_nearest_workday, + before_nearest_workday, + nearest_workday, + next_monday, + next_monday_or_tuesday, + next_workday, + previous_friday, + previous_workday, + sunday_to_monday, + weekend_to_monday, +) + +_WEDNESDAY = datetime(2014, 4, 9) +_THURSDAY = datetime(2014, 4, 10) +_FRIDAY = datetime(2014, 4, 11) +_SATURDAY = datetime(2014, 4, 12) +_SUNDAY = datetime(2014, 4, 13) +_MONDAY = datetime(2014, 4, 14) +_TUESDAY = datetime(2014, 4, 15) +_NEXT_WEDNESDAY = datetime(2014, 4, 16) + + +@pytest.mark.parametrize("day", [_SATURDAY, _SUNDAY]) +def test_next_monday(day): + assert next_monday(day) == _MONDAY + + +@pytest.mark.parametrize( + "day,expected", [(_SATURDAY, _MONDAY), (_SUNDAY, _TUESDAY), (_MONDAY, _TUESDAY)] +) +def test_next_monday_or_tuesday(day, expected): + assert next_monday_or_tuesday(day) == expected + + +@pytest.mark.parametrize("day", [_SATURDAY, _SUNDAY]) +def test_previous_friday(day): + assert previous_friday(day) == _FRIDAY + + +def test_sunday_to_monday(): + assert sunday_to_monday(_SUNDAY) == _MONDAY + + +@pytest.mark.parametrize( + "day,expected", [(_SATURDAY, _FRIDAY), (_SUNDAY, _MONDAY), (_MONDAY, _MONDAY)] +) +def test_nearest_workday(day, expected): + assert nearest_workday(day) == expected + + +@pytest.mark.parametrize( + "day,expected", [(_SATURDAY, _MONDAY), (_SUNDAY, _MONDAY), (_MONDAY, _MONDAY)] +) +def test_weekend_to_monday(day, expected): + assert weekend_to_monday(day) == expected + + +@pytest.mark.parametrize( + "day,expected", + [ + (_WEDNESDAY, _THURSDAY), + (_THURSDAY, _FRIDAY), + (_SATURDAY, _MONDAY), + (_SUNDAY, _MONDAY), + (_MONDAY, _TUESDAY), + (_TUESDAY, _NEXT_WEDNESDAY), # WED is same week as TUE + ], +) +def test_next_workday(day, expected): + assert next_workday(day) == expected + + +@pytest.mark.parametrize( + "day,expected", [(_SATURDAY, _FRIDAY), (_SUNDAY, _FRIDAY), (_TUESDAY, _MONDAY)] +) +def test_previous_workday(day, expected): + assert previous_workday(day) == expected + + +@pytest.mark.parametrize( + "day,expected", + [ + (_THURSDAY, _WEDNESDAY), + (_FRIDAY, _THURSDAY), + (_SATURDAY, _THURSDAY), + (_SUNDAY, _FRIDAY), + (_MONDAY, _FRIDAY), # last week Friday + (_TUESDAY, _MONDAY), + (_NEXT_WEDNESDAY, _TUESDAY), # WED is same week as TUE + ], +) +def test_before_nearest_workday(day, expected): + assert before_nearest_workday(day) == expected + + +@pytest.mark.parametrize( + "day,expected", [(_SATURDAY, _MONDAY), (_SUNDAY, _TUESDAY), (_FRIDAY, _MONDAY)] +) +def test_after_nearest_workday(day, expected): + assert after_nearest_workday(day) == expected diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..288a74734579a556c37559200b555882601462a7 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/common.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/common.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5bb6a6bbbbf062f01dc9163941710ac0d921833 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/common.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_business_day.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_business_day.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4f434ca08cdb9420b810f2e002b828a5eda1159 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_business_day.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_business_hour.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_business_hour.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0cff82d519be9098c65051954063a5201ae3ea91 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_business_hour.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_business_month.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_business_month.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a755d839c2e7f7d60c41d011f1497756c34fd4a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_business_month.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_business_quarter.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_business_quarter.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cae45c79e10be98c452645e9f93f5c66b1403501 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_business_quarter.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_business_year.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_business_year.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d69172313ab6f1e8aa6a9451a9ca9c161a851b48 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_business_year.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_common.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_common.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66e4ae89b7a8acebb1d5ca1217bc6149a409efa8 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_common.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_custom_business_day.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_custom_business_day.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9291d0ea29b51b1e068d5213f6fb9e2b9b80aa7b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_custom_business_day.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_custom_business_hour.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_custom_business_hour.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6539092d7b0793f403bdcc7fee5a7a9af7eca297 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_custom_business_hour.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_custom_business_month.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_custom_business_month.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc6c6c34fd61b3681af6650b12ea70991f77e8bb Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_custom_business_month.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_dst.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_dst.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8180c48fa06f15fd910de19db8c3c2a55e5c4a33 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_dst.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_easter.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_easter.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f59a47acc664c69108386575c2d0e3b1510038a7 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_easter.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_fiscal.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_fiscal.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3327bbbe58ed62c8d80e1d7c5a658d1e6809c0e2 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_fiscal.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_index.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_index.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fadef1e7d9b153ac4cf4277b1a6ffc328f788a15 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_index.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_month.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_month.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f83eb261ca5f760fb999cfa471aa3f9de4f57272 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_month.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_offsets.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_offsets.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd6837c999a322515ab1fbfc7bfd249bb467a0c2 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_offsets.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_offsets_properties.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_offsets_properties.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef3d8a421befde33eb396b3a2d4f57c0d5b6e0d9 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_offsets_properties.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_quarter.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_quarter.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dce763373e423aac2bbaf830d8dc9b9d160005d3 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_quarter.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_ticks.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_ticks.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92131443f690523d31da07afcf4dd26b44bb379d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_ticks.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_week.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_week.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1fb2fdd28acee99963bb0476eb16ea193b5e081 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_week.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_year.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_year.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..185e599e3c781f03e03d592ddc9915a94d072066 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/__pycache__/test_year.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/common.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/common.py new file mode 100644 index 0000000000000000000000000000000000000000..efb010addad225cda407d55c47dc804645cf3999 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/common.py @@ -0,0 +1,37 @@ +""" +Assertion helpers and base class for offsets tests +""" +from __future__ import annotations + + +def assert_offset_equal(offset, base, expected): + actual = offset + base + actual_swapped = base + offset + actual_apply = offset._apply(base) + try: + assert actual == expected + assert actual_swapped == expected + assert actual_apply == expected + except AssertionError as err: + raise AssertionError( + f"\nExpected: {expected}\nActual: {actual}\nFor Offset: {offset})" + f"\nAt Date: {base}" + ) from err + + +def assert_is_on_offset(offset, date, expected): + actual = offset.is_on_offset(date) + assert actual == expected, ( + f"\nExpected: {expected}\nActual: {actual}\nFor Offset: {offset})" + f"\nAt Date: {date}" + ) + + +class WeekDay: + MON = 0 + TUE = 1 + WED = 2 + THU = 3 + FRI = 4 + SAT = 5 + SUN = 6 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_business_day.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_business_day.py new file mode 100644 index 0000000000000000000000000000000000000000..7db1921369023eaf05c65bf537c259de3f2a81cb --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_business_day.py @@ -0,0 +1,236 @@ +""" +Tests for offsets.BDay +""" +from __future__ import annotations + +from datetime import ( + date, + datetime, + timedelta, +) + +import numpy as np +import pytest + +from pandas._libs.tslibs.offsets import ( + ApplyTypeError, + BDay, + BMonthEnd, +) + +from pandas import ( + DatetimeIndex, + Timedelta, + _testing as tm, +) +from pandas.tests.tseries.offsets.common import ( + assert_is_on_offset, + assert_offset_equal, +) + +from pandas.tseries import offsets + + +@pytest.fixture +def dt(): + return datetime(2008, 1, 1) + + +@pytest.fixture +def _offset(): + return BDay + + +@pytest.fixture +def offset(_offset): + return _offset() + + +@pytest.fixture +def offset2(_offset): + return _offset(2) + + +class TestBusinessDay: + def test_different_normalize_equals(self, _offset, offset2): + # GH#21404 changed __eq__ to return False when `normalize` does not match + offset = _offset() + offset2 = _offset(normalize=True) + assert offset != offset2 + + def test_repr(self, offset, offset2): + assert repr(offset) == "" + assert repr(offset2) == "<2 * BusinessDays>" + + expected = "" + assert repr(offset + timedelta(1)) == expected + + def test_with_offset(self, dt, offset): + offset = offset + timedelta(hours=2) + + assert (dt + offset) == datetime(2008, 1, 2, 2) + + @pytest.mark.parametrize( + "td", + [ + Timedelta(hours=2), + Timedelta(hours=2).to_pytimedelta(), + Timedelta(hours=2).to_timedelta64(), + ], + ids=lambda x: type(x), + ) + def test_with_offset_index(self, td, dt, offset): + dti = DatetimeIndex([dt]) + expected = DatetimeIndex([datetime(2008, 1, 2, 2)]) + + result = dti + (td + offset) + tm.assert_index_equal(result, expected) + + result = dti + (offset + td) + tm.assert_index_equal(result, expected) + + def test_eq(self, offset2): + assert offset2 == offset2 + + def test_hash(self, offset2): + assert hash(offset2) == hash(offset2) + + def test_add_datetime(self, dt, offset2): + assert offset2 + dt == datetime(2008, 1, 3) + assert offset2 + np.datetime64("2008-01-01 00:00:00") == datetime(2008, 1, 3) + + def testRollback1(self, dt, _offset): + assert _offset(10).rollback(dt) == dt + + def testRollback2(self, _offset): + assert _offset(10).rollback(datetime(2008, 1, 5)) == datetime(2008, 1, 4) + + def testRollforward1(self, dt, _offset): + assert _offset(10).rollforward(dt) == dt + + def testRollforward2(self, _offset): + assert _offset(10).rollforward(datetime(2008, 1, 5)) == datetime(2008, 1, 7) + + def test_roll_date_object(self, offset): + dt = date(2012, 9, 15) + + result = offset.rollback(dt) + assert result == datetime(2012, 9, 14) + + result = offset.rollforward(dt) + assert result == datetime(2012, 9, 17) + + offset = offsets.Day() + result = offset.rollback(dt) + assert result == datetime(2012, 9, 15) + + result = offset.rollforward(dt) + assert result == datetime(2012, 9, 15) + + @pytest.mark.parametrize( + "dt, expected", + [ + (datetime(2008, 1, 1), True), + (datetime(2008, 1, 5), False), + ], + ) + def test_is_on_offset(self, offset, dt, expected): + assert_is_on_offset(offset, dt, expected) + + apply_cases: list[tuple[int, dict[datetime, datetime]]] = [ + ( + 1, + { + datetime(2008, 1, 1): datetime(2008, 1, 2), + datetime(2008, 1, 4): datetime(2008, 1, 7), + datetime(2008, 1, 5): datetime(2008, 1, 7), + datetime(2008, 1, 6): datetime(2008, 1, 7), + datetime(2008, 1, 7): datetime(2008, 1, 8), + }, + ), + ( + 2, + { + datetime(2008, 1, 1): datetime(2008, 1, 3), + datetime(2008, 1, 4): datetime(2008, 1, 8), + datetime(2008, 1, 5): datetime(2008, 1, 8), + datetime(2008, 1, 6): datetime(2008, 1, 8), + datetime(2008, 1, 7): datetime(2008, 1, 9), + }, + ), + ( + -1, + { + datetime(2008, 1, 1): datetime(2007, 12, 31), + datetime(2008, 1, 4): datetime(2008, 1, 3), + datetime(2008, 1, 5): datetime(2008, 1, 4), + datetime(2008, 1, 6): datetime(2008, 1, 4), + datetime(2008, 1, 7): datetime(2008, 1, 4), + datetime(2008, 1, 8): datetime(2008, 1, 7), + }, + ), + ( + -2, + { + datetime(2008, 1, 1): datetime(2007, 12, 28), + datetime(2008, 1, 4): datetime(2008, 1, 2), + datetime(2008, 1, 5): datetime(2008, 1, 3), + datetime(2008, 1, 6): datetime(2008, 1, 3), + datetime(2008, 1, 7): datetime(2008, 1, 3), + datetime(2008, 1, 8): datetime(2008, 1, 4), + datetime(2008, 1, 9): datetime(2008, 1, 7), + }, + ), + ( + 0, + { + datetime(2008, 1, 1): datetime(2008, 1, 1), + datetime(2008, 1, 4): datetime(2008, 1, 4), + datetime(2008, 1, 5): datetime(2008, 1, 7), + datetime(2008, 1, 6): datetime(2008, 1, 7), + datetime(2008, 1, 7): datetime(2008, 1, 7), + }, + ), + ] + + @pytest.mark.parametrize("case", apply_cases) + def test_apply(self, case, _offset): + n, cases = case + offset = _offset(n) + for base, expected in cases.items(): + assert_offset_equal(offset, base, expected) + + def test_apply_large_n(self, _offset): + dt = datetime(2012, 10, 23) + + result = dt + _offset(10) + assert result == datetime(2012, 11, 6) + + result = dt + _offset(100) - _offset(100) + assert result == dt + + off = _offset() * 6 + rs = datetime(2012, 1, 1) - off + xp = datetime(2011, 12, 23) + assert rs == xp + + st = datetime(2011, 12, 18) + rs = st + off + xp = datetime(2011, 12, 26) + assert rs == xp + + off = _offset() * 10 + rs = datetime(2014, 1, 5) + off # see #5890 + xp = datetime(2014, 1, 17) + assert rs == xp + + def test_apply_corner(self, _offset): + if _offset is BDay: + msg = "Only know how to combine business day with datetime or timedelta" + else: + msg = ( + "Only know how to combine trading day " + "with datetime, datetime64 or timedelta" + ) + with pytest.raises(ApplyTypeError, match=msg): + _offset()._apply(BMonthEnd()) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_business_hour.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_business_hour.py new file mode 100644 index 0000000000000000000000000000000000000000..2779100f5355cf7475b10d3de6bf7ceebc92af96 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_business_hour.py @@ -0,0 +1,1445 @@ +""" +Tests for offsets.BusinessHour +""" +from __future__ import annotations + +from datetime import ( + datetime, + time as dt_time, +) + +import pytest + +from pandas._libs.tslibs import ( + Timedelta, + Timestamp, +) +from pandas._libs.tslibs.offsets import ( + BDay, + BusinessHour, + Nano, +) + +from pandas import ( + DatetimeIndex, + _testing as tm, + date_range, +) +from pandas.tests.tseries.offsets.common import assert_offset_equal + + +@pytest.fixture +def dt(): + return datetime(2014, 7, 1, 10, 00) + + +@pytest.fixture +def _offset(): + return BusinessHour + + +@pytest.fixture +def offset1(): + return BusinessHour() + + +@pytest.fixture +def offset2(): + return BusinessHour(n=3) + + +@pytest.fixture +def offset3(): + return BusinessHour(n=-1) + + +@pytest.fixture +def offset4(): + return BusinessHour(n=-4) + + +@pytest.fixture +def offset5(): + return BusinessHour(start=dt_time(11, 0), end=dt_time(14, 30)) + + +@pytest.fixture +def offset6(): + return BusinessHour(start="20:00", end="05:00") + + +@pytest.fixture +def offset7(): + return BusinessHour(n=-2, start=dt_time(21, 30), end=dt_time(6, 30)) + + +@pytest.fixture +def offset8(): + return BusinessHour(start=["09:00", "13:00"], end=["12:00", "17:00"]) + + +@pytest.fixture +def offset9(): + return BusinessHour(n=3, start=["09:00", "22:00"], end=["13:00", "03:00"]) + + +@pytest.fixture +def offset10(): + return BusinessHour(n=-1, start=["23:00", "13:00"], end=["02:00", "17:00"]) + + +class TestBusinessHour: + @pytest.mark.parametrize( + "start,end,match", + [ + ( + dt_time(11, 0, 5), + "17:00", + "time data must be specified only with hour and minute", + ), + ("AAA", "17:00", "time data must match '%H:%M' format"), + ("14:00:05", "17:00", "time data must match '%H:%M' format"), + ([], "17:00", "Must include at least 1 start time"), + ("09:00", [], "Must include at least 1 end time"), + ( + ["09:00", "11:00"], + "17:00", + "number of starting time and ending time must be the same", + ), + ( + ["09:00", "11:00"], + ["10:00"], + "number of starting time and ending time must be the same", + ), + ( + ["09:00", "11:00"], + ["12:00", "20:00"], + r"invalid starting and ending time\(s\): opening hours should not " + "touch or overlap with one another", + ), + ( + ["12:00", "20:00"], + ["09:00", "11:00"], + r"invalid starting and ending time\(s\): opening hours should not " + "touch or overlap with one another", + ), + ], + ) + def test_constructor_errors(self, start, end, match): + with pytest.raises(ValueError, match=match): + BusinessHour(start=start, end=end) + + def test_different_normalize_equals(self, _offset): + # GH#21404 changed __eq__ to return False when `normalize` does not match + offset = _offset() + offset2 = _offset(normalize=True) + assert offset != offset2 + + def test_repr( + self, + offset1, + offset2, + offset3, + offset4, + offset5, + offset6, + offset7, + offset8, + offset9, + offset10, + ): + assert repr(offset1) == "" + assert repr(offset2) == "<3 * BusinessHours: bh=09:00-17:00>" + assert repr(offset3) == "<-1 * BusinessHour: bh=09:00-17:00>" + assert repr(offset4) == "<-4 * BusinessHours: bh=09:00-17:00>" + + assert repr(offset5) == "" + assert repr(offset6) == "" + assert repr(offset7) == "<-2 * BusinessHours: bh=21:30-06:30>" + assert repr(offset8) == "" + assert repr(offset9) == "<3 * BusinessHours: bh=09:00-13:00,22:00-03:00>" + assert repr(offset10) == "<-1 * BusinessHour: bh=13:00-17:00,23:00-02:00>" + + def test_with_offset(self, dt): + expected = Timestamp("2014-07-01 13:00") + + assert dt + BusinessHour() * 3 == expected + assert dt + BusinessHour(n=3) == expected + + @pytest.mark.parametrize( + "offset_name", + ["offset1", "offset2", "offset3", "offset4", "offset8", "offset9", "offset10"], + ) + def test_eq_attribute(self, offset_name, request): + offset = request.getfixturevalue(offset_name) + assert offset == offset + + @pytest.mark.parametrize( + "offset1,offset2", + [ + (BusinessHour(start="09:00"), BusinessHour()), + ( + BusinessHour(start=["23:00", "13:00"], end=["12:00", "17:00"]), + BusinessHour(start=["13:00", "23:00"], end=["17:00", "12:00"]), + ), + ], + ) + def test_eq(self, offset1, offset2): + assert offset1 == offset2 + + @pytest.mark.parametrize( + "offset1,offset2", + [ + (BusinessHour(), BusinessHour(-1)), + (BusinessHour(start="09:00"), BusinessHour(start="09:01")), + ( + BusinessHour(start="09:00", end="17:00"), + BusinessHour(start="17:00", end="09:01"), + ), + ( + BusinessHour(start=["13:00", "23:00"], end=["18:00", "07:00"]), + BusinessHour(start=["13:00", "23:00"], end=["17:00", "12:00"]), + ), + ], + ) + def test_neq(self, offset1, offset2): + assert offset1 != offset2 + + @pytest.mark.parametrize( + "offset_name", + ["offset1", "offset2", "offset3", "offset4", "offset8", "offset9", "offset10"], + ) + def test_hash(self, offset_name, request): + offset = request.getfixturevalue(offset_name) + assert offset == offset + + def test_add_datetime( + self, + dt, + offset1, + offset2, + offset3, + offset4, + offset8, + offset9, + offset10, + ): + assert offset1 + dt == datetime(2014, 7, 1, 11) + assert offset2 + dt == datetime(2014, 7, 1, 13) + assert offset3 + dt == datetime(2014, 6, 30, 17) + assert offset4 + dt == datetime(2014, 6, 30, 14) + assert offset8 + dt == datetime(2014, 7, 1, 11) + assert offset9 + dt == datetime(2014, 7, 1, 22) + assert offset10 + dt == datetime(2014, 7, 1, 1) + + def test_sub(self, dt, offset2, _offset): + off = offset2 + msg = "Cannot subtract datetime from offset" + with pytest.raises(TypeError, match=msg): + off - dt + assert 2 * off - off == off + + assert dt - offset2 == dt + _offset(-3) + + def test_multiply_by_zero(self, dt, offset1, offset2): + assert dt - 0 * offset1 == dt + assert dt + 0 * offset1 == dt + assert dt - 0 * offset2 == dt + assert dt + 0 * offset2 == dt + + def testRollback1( + self, + dt, + _offset, + offset1, + offset2, + offset3, + offset4, + offset5, + offset6, + offset7, + offset8, + offset9, + offset10, + ): + assert offset1.rollback(dt) == dt + assert offset2.rollback(dt) == dt + assert offset3.rollback(dt) == dt + assert offset4.rollback(dt) == dt + assert offset5.rollback(dt) == datetime(2014, 6, 30, 14, 30) + assert offset6.rollback(dt) == datetime(2014, 7, 1, 5, 0) + assert offset7.rollback(dt) == datetime(2014, 7, 1, 6, 30) + assert offset8.rollback(dt) == dt + assert offset9.rollback(dt) == dt + assert offset10.rollback(dt) == datetime(2014, 7, 1, 2) + + datet = datetime(2014, 7, 1, 0) + assert offset1.rollback(datet) == datetime(2014, 6, 30, 17) + assert offset2.rollback(datet) == datetime(2014, 6, 30, 17) + assert offset3.rollback(datet) == datetime(2014, 6, 30, 17) + assert offset4.rollback(datet) == datetime(2014, 6, 30, 17) + assert offset5.rollback(datet) == datetime(2014, 6, 30, 14, 30) + assert offset6.rollback(datet) == datet + assert offset7.rollback(datet) == datet + assert offset8.rollback(datet) == datetime(2014, 6, 30, 17) + assert offset9.rollback(datet) == datet + assert offset10.rollback(datet) == datet + + assert _offset(5).rollback(dt) == dt + + def testRollback2(self, _offset): + assert _offset(-3).rollback(datetime(2014, 7, 5, 15, 0)) == datetime( + 2014, 7, 4, 17, 0 + ) + + def testRollforward1( + self, + dt, + _offset, + offset1, + offset2, + offset3, + offset4, + offset5, + offset6, + offset7, + offset8, + offset9, + offset10, + ): + assert offset1.rollforward(dt) == dt + assert offset2.rollforward(dt) == dt + assert offset3.rollforward(dt) == dt + assert offset4.rollforward(dt) == dt + assert offset5.rollforward(dt) == datetime(2014, 7, 1, 11, 0) + assert offset6.rollforward(dt) == datetime(2014, 7, 1, 20, 0) + assert offset7.rollforward(dt) == datetime(2014, 7, 1, 21, 30) + assert offset8.rollforward(dt) == dt + assert offset9.rollforward(dt) == dt + assert offset10.rollforward(dt) == datetime(2014, 7, 1, 13) + + datet = datetime(2014, 7, 1, 0) + assert offset1.rollforward(datet) == datetime(2014, 7, 1, 9) + assert offset2.rollforward(datet) == datetime(2014, 7, 1, 9) + assert offset3.rollforward(datet) == datetime(2014, 7, 1, 9) + assert offset4.rollforward(datet) == datetime(2014, 7, 1, 9) + assert offset5.rollforward(datet) == datetime(2014, 7, 1, 11) + assert offset6.rollforward(datet) == datet + assert offset7.rollforward(datet) == datet + assert offset8.rollforward(datet) == datetime(2014, 7, 1, 9) + assert offset9.rollforward(datet) == datet + assert offset10.rollforward(datet) == datet + + assert _offset(5).rollforward(dt) == dt + + def testRollforward2(self, _offset): + assert _offset(-3).rollforward(datetime(2014, 7, 5, 16, 0)) == datetime( + 2014, 7, 7, 9 + ) + + def test_roll_date_object(self): + offset = BusinessHour() + + dt = datetime(2014, 7, 6, 15, 0) + + result = offset.rollback(dt) + assert result == datetime(2014, 7, 4, 17) + + result = offset.rollforward(dt) + assert result == datetime(2014, 7, 7, 9) + + normalize_cases = [] + normalize_cases.append( + ( + BusinessHour(normalize=True), + { + datetime(2014, 7, 1, 8): datetime(2014, 7, 1), + datetime(2014, 7, 1, 17): datetime(2014, 7, 2), + datetime(2014, 7, 1, 16): datetime(2014, 7, 2), + datetime(2014, 7, 1, 23): datetime(2014, 7, 2), + datetime(2014, 7, 1, 0): datetime(2014, 7, 1), + datetime(2014, 7, 4, 15): datetime(2014, 7, 4), + datetime(2014, 7, 4, 15, 59): datetime(2014, 7, 4), + datetime(2014, 7, 4, 16, 30): datetime(2014, 7, 7), + datetime(2014, 7, 5, 23): datetime(2014, 7, 7), + datetime(2014, 7, 6, 10): datetime(2014, 7, 7), + }, + ) + ) + + normalize_cases.append( + ( + BusinessHour(-1, normalize=True), + { + datetime(2014, 7, 1, 8): datetime(2014, 6, 30), + datetime(2014, 7, 1, 17): datetime(2014, 7, 1), + datetime(2014, 7, 1, 16): datetime(2014, 7, 1), + datetime(2014, 7, 1, 10): datetime(2014, 6, 30), + datetime(2014, 7, 1, 0): datetime(2014, 6, 30), + datetime(2014, 7, 7, 10): datetime(2014, 7, 4), + datetime(2014, 7, 7, 10, 1): datetime(2014, 7, 7), + datetime(2014, 7, 5, 23): datetime(2014, 7, 4), + datetime(2014, 7, 6, 10): datetime(2014, 7, 4), + }, + ) + ) + + normalize_cases.append( + ( + BusinessHour(1, normalize=True, start="17:00", end="04:00"), + { + datetime(2014, 7, 1, 8): datetime(2014, 7, 1), + datetime(2014, 7, 1, 17): datetime(2014, 7, 1), + datetime(2014, 7, 1, 23): datetime(2014, 7, 2), + datetime(2014, 7, 2, 2): datetime(2014, 7, 2), + datetime(2014, 7, 2, 3): datetime(2014, 7, 2), + datetime(2014, 7, 4, 23): datetime(2014, 7, 5), + datetime(2014, 7, 5, 2): datetime(2014, 7, 5), + datetime(2014, 7, 7, 2): datetime(2014, 7, 7), + datetime(2014, 7, 7, 17): datetime(2014, 7, 7), + }, + ) + ) + + @pytest.mark.parametrize("case", normalize_cases) + def test_normalize(self, case): + offset, cases = case + for dt, expected in cases.items(): + assert offset._apply(dt) == expected + + on_offset_cases = [] + on_offset_cases.append( + ( + BusinessHour(), + { + datetime(2014, 7, 1, 9): True, + datetime(2014, 7, 1, 8, 59): False, + datetime(2014, 7, 1, 8): False, + datetime(2014, 7, 1, 17): True, + datetime(2014, 7, 1, 17, 1): False, + datetime(2014, 7, 1, 18): False, + datetime(2014, 7, 5, 9): False, + datetime(2014, 7, 6, 12): False, + }, + ) + ) + + on_offset_cases.append( + ( + BusinessHour(start="10:00", end="15:00"), + { + datetime(2014, 7, 1, 9): False, + datetime(2014, 7, 1, 10): True, + datetime(2014, 7, 1, 15): True, + datetime(2014, 7, 1, 15, 1): False, + datetime(2014, 7, 5, 12): False, + datetime(2014, 7, 6, 12): False, + }, + ) + ) + + on_offset_cases.append( + ( + BusinessHour(start="19:00", end="05:00"), + { + datetime(2014, 7, 1, 9, 0): False, + datetime(2014, 7, 1, 10, 0): False, + datetime(2014, 7, 1, 15): False, + datetime(2014, 7, 1, 15, 1): False, + datetime(2014, 7, 5, 12, 0): False, + datetime(2014, 7, 6, 12, 0): False, + datetime(2014, 7, 1, 19, 0): True, + datetime(2014, 7, 2, 0, 0): True, + datetime(2014, 7, 4, 23): True, + datetime(2014, 7, 5, 1): True, + datetime(2014, 7, 5, 5, 0): True, + datetime(2014, 7, 6, 23, 0): False, + datetime(2014, 7, 7, 3, 0): False, + }, + ) + ) + + on_offset_cases.append( + ( + BusinessHour(start=["09:00", "13:00"], end=["12:00", "17:00"]), + { + datetime(2014, 7, 1, 9): True, + datetime(2014, 7, 1, 8, 59): False, + datetime(2014, 7, 1, 8): False, + datetime(2014, 7, 1, 17): True, + datetime(2014, 7, 1, 17, 1): False, + datetime(2014, 7, 1, 18): False, + datetime(2014, 7, 5, 9): False, + datetime(2014, 7, 6, 12): False, + datetime(2014, 7, 1, 12, 30): False, + }, + ) + ) + + on_offset_cases.append( + ( + BusinessHour(start=["19:00", "23:00"], end=["21:00", "05:00"]), + { + datetime(2014, 7, 1, 9, 0): False, + datetime(2014, 7, 1, 10, 0): False, + datetime(2014, 7, 1, 15): False, + datetime(2014, 7, 1, 15, 1): False, + datetime(2014, 7, 5, 12, 0): False, + datetime(2014, 7, 6, 12, 0): False, + datetime(2014, 7, 1, 19, 0): True, + datetime(2014, 7, 2, 0, 0): True, + datetime(2014, 7, 4, 23): True, + datetime(2014, 7, 5, 1): True, + datetime(2014, 7, 5, 5, 0): True, + datetime(2014, 7, 6, 23, 0): False, + datetime(2014, 7, 7, 3, 0): False, + datetime(2014, 7, 4, 22): False, + }, + ) + ) + + @pytest.mark.parametrize("case", on_offset_cases) + def test_is_on_offset(self, case): + offset, cases = case + for dt, expected in cases.items(): + assert offset.is_on_offset(dt) == expected + + apply_cases = [ + ( + BusinessHour(), + { + datetime(2014, 7, 1, 11): datetime(2014, 7, 1, 12), + datetime(2014, 7, 1, 13): datetime(2014, 7, 1, 14), + datetime(2014, 7, 1, 15): datetime(2014, 7, 1, 16), + datetime(2014, 7, 1, 19): datetime(2014, 7, 2, 10), + datetime(2014, 7, 1, 16): datetime(2014, 7, 2, 9), + datetime(2014, 7, 1, 16, 30, 15): datetime(2014, 7, 2, 9, 30, 15), + datetime(2014, 7, 1, 17): datetime(2014, 7, 2, 10), + datetime(2014, 7, 2, 11): datetime(2014, 7, 2, 12), + # out of business hours + datetime(2014, 7, 2, 8): datetime(2014, 7, 2, 10), + datetime(2014, 7, 2, 19): datetime(2014, 7, 3, 10), + datetime(2014, 7, 2, 23): datetime(2014, 7, 3, 10), + datetime(2014, 7, 3, 0): datetime(2014, 7, 3, 10), + # saturday + datetime(2014, 7, 5, 15): datetime(2014, 7, 7, 10), + datetime(2014, 7, 4, 17): datetime(2014, 7, 7, 10), + datetime(2014, 7, 4, 16, 30): datetime(2014, 7, 7, 9, 30), + datetime(2014, 7, 4, 16, 30, 30): datetime(2014, 7, 7, 9, 30, 30), + }, + ), + ( + BusinessHour(4), + { + datetime(2014, 7, 1, 11): datetime(2014, 7, 1, 15), + datetime(2014, 7, 1, 13): datetime(2014, 7, 2, 9), + datetime(2014, 7, 1, 15): datetime(2014, 7, 2, 11), + datetime(2014, 7, 1, 16): datetime(2014, 7, 2, 12), + datetime(2014, 7, 1, 17): datetime(2014, 7, 2, 13), + datetime(2014, 7, 2, 11): datetime(2014, 7, 2, 15), + datetime(2014, 7, 2, 8): datetime(2014, 7, 2, 13), + datetime(2014, 7, 2, 19): datetime(2014, 7, 3, 13), + datetime(2014, 7, 2, 23): datetime(2014, 7, 3, 13), + datetime(2014, 7, 3, 0): datetime(2014, 7, 3, 13), + datetime(2014, 7, 5, 15): datetime(2014, 7, 7, 13), + datetime(2014, 7, 4, 17): datetime(2014, 7, 7, 13), + datetime(2014, 7, 4, 16, 30): datetime(2014, 7, 7, 12, 30), + datetime(2014, 7, 4, 16, 30, 30): datetime(2014, 7, 7, 12, 30, 30), + }, + ), + ( + BusinessHour(-1), + { + datetime(2014, 7, 1, 11): datetime(2014, 7, 1, 10), + datetime(2014, 7, 1, 13): datetime(2014, 7, 1, 12), + datetime(2014, 7, 1, 15): datetime(2014, 7, 1, 14), + datetime(2014, 7, 1, 16): datetime(2014, 7, 1, 15), + datetime(2014, 7, 1, 10): datetime(2014, 6, 30, 17), + datetime(2014, 7, 1, 16, 30, 15): datetime(2014, 7, 1, 15, 30, 15), + datetime(2014, 7, 1, 9, 30, 15): datetime(2014, 6, 30, 16, 30, 15), + datetime(2014, 7, 1, 17): datetime(2014, 7, 1, 16), + datetime(2014, 7, 1, 5): datetime(2014, 6, 30, 16), + datetime(2014, 7, 2, 11): datetime(2014, 7, 2, 10), + # out of business hours + datetime(2014, 7, 2, 8): datetime(2014, 7, 1, 16), + datetime(2014, 7, 2, 19): datetime(2014, 7, 2, 16), + datetime(2014, 7, 2, 23): datetime(2014, 7, 2, 16), + datetime(2014, 7, 3, 0): datetime(2014, 7, 2, 16), + # saturday + datetime(2014, 7, 5, 15): datetime(2014, 7, 4, 16), + datetime(2014, 7, 7, 9): datetime(2014, 7, 4, 16), + datetime(2014, 7, 7, 9, 30): datetime(2014, 7, 4, 16, 30), + datetime(2014, 7, 7, 9, 30, 30): datetime(2014, 7, 4, 16, 30, 30), + }, + ), + ( + BusinessHour(-4), + { + datetime(2014, 7, 1, 11): datetime(2014, 6, 30, 15), + datetime(2014, 7, 1, 13): datetime(2014, 6, 30, 17), + datetime(2014, 7, 1, 15): datetime(2014, 7, 1, 11), + datetime(2014, 7, 1, 16): datetime(2014, 7, 1, 12), + datetime(2014, 7, 1, 17): datetime(2014, 7, 1, 13), + datetime(2014, 7, 2, 11): datetime(2014, 7, 1, 15), + datetime(2014, 7, 2, 8): datetime(2014, 7, 1, 13), + datetime(2014, 7, 2, 19): datetime(2014, 7, 2, 13), + datetime(2014, 7, 2, 23): datetime(2014, 7, 2, 13), + datetime(2014, 7, 3, 0): datetime(2014, 7, 2, 13), + datetime(2014, 7, 5, 15): datetime(2014, 7, 4, 13), + datetime(2014, 7, 4, 18): datetime(2014, 7, 4, 13), + datetime(2014, 7, 7, 9, 30): datetime(2014, 7, 4, 13, 30), + datetime(2014, 7, 7, 9, 30, 30): datetime(2014, 7, 4, 13, 30, 30), + }, + ), + ( + BusinessHour(start="13:00", end="16:00"), + { + datetime(2014, 7, 1, 11): datetime(2014, 7, 1, 14), + datetime(2014, 7, 1, 13): datetime(2014, 7, 1, 14), + datetime(2014, 7, 1, 15): datetime(2014, 7, 2, 13), + datetime(2014, 7, 1, 19): datetime(2014, 7, 2, 14), + datetime(2014, 7, 1, 16): datetime(2014, 7, 2, 14), + datetime(2014, 7, 1, 15, 30, 15): datetime(2014, 7, 2, 13, 30, 15), + datetime(2014, 7, 5, 15): datetime(2014, 7, 7, 14), + datetime(2014, 7, 4, 17): datetime(2014, 7, 7, 14), + }, + ), + ( + BusinessHour(n=2, start="13:00", end="16:00"), + { + datetime(2014, 7, 1, 17): datetime(2014, 7, 2, 15), + datetime(2014, 7, 2, 14): datetime(2014, 7, 3, 13), + datetime(2014, 7, 2, 8): datetime(2014, 7, 2, 15), + datetime(2014, 7, 2, 19): datetime(2014, 7, 3, 15), + datetime(2014, 7, 2, 14, 30): datetime(2014, 7, 3, 13, 30), + datetime(2014, 7, 3, 0): datetime(2014, 7, 3, 15), + datetime(2014, 7, 5, 15): datetime(2014, 7, 7, 15), + datetime(2014, 7, 4, 17): datetime(2014, 7, 7, 15), + datetime(2014, 7, 4, 14, 30): datetime(2014, 7, 7, 13, 30), + datetime(2014, 7, 4, 14, 30, 30): datetime(2014, 7, 7, 13, 30, 30), + }, + ), + ( + BusinessHour(n=-1, start="13:00", end="16:00"), + { + datetime(2014, 7, 2, 11): datetime(2014, 7, 1, 15), + datetime(2014, 7, 2, 13): datetime(2014, 7, 1, 15), + datetime(2014, 7, 2, 14): datetime(2014, 7, 1, 16), + datetime(2014, 7, 2, 15): datetime(2014, 7, 2, 14), + datetime(2014, 7, 2, 19): datetime(2014, 7, 2, 15), + datetime(2014, 7, 2, 16): datetime(2014, 7, 2, 15), + datetime(2014, 7, 2, 13, 30, 15): datetime(2014, 7, 1, 15, 30, 15), + datetime(2014, 7, 5, 15): datetime(2014, 7, 4, 15), + datetime(2014, 7, 7, 11): datetime(2014, 7, 4, 15), + }, + ), + ( + BusinessHour(n=-3, start="10:00", end="16:00"), + { + datetime(2014, 7, 1, 17): datetime(2014, 7, 1, 13), + datetime(2014, 7, 2, 14): datetime(2014, 7, 2, 11), + datetime(2014, 7, 2, 8): datetime(2014, 7, 1, 13), + datetime(2014, 7, 2, 13): datetime(2014, 7, 1, 16), + datetime(2014, 7, 2, 19): datetime(2014, 7, 2, 13), + datetime(2014, 7, 2, 11, 30): datetime(2014, 7, 1, 14, 30), + datetime(2014, 7, 3, 0): datetime(2014, 7, 2, 13), + datetime(2014, 7, 4, 10): datetime(2014, 7, 3, 13), + datetime(2014, 7, 5, 15): datetime(2014, 7, 4, 13), + datetime(2014, 7, 4, 16): datetime(2014, 7, 4, 13), + datetime(2014, 7, 4, 12, 30): datetime(2014, 7, 3, 15, 30), + datetime(2014, 7, 4, 12, 30, 30): datetime(2014, 7, 3, 15, 30, 30), + }, + ), + ( + BusinessHour(start="19:00", end="05:00"), + { + datetime(2014, 7, 1, 17): datetime(2014, 7, 1, 20), + datetime(2014, 7, 2, 14): datetime(2014, 7, 2, 20), + datetime(2014, 7, 2, 8): datetime(2014, 7, 2, 20), + datetime(2014, 7, 2, 13): datetime(2014, 7, 2, 20), + datetime(2014, 7, 2, 19): datetime(2014, 7, 2, 20), + datetime(2014, 7, 2, 4, 30): datetime(2014, 7, 2, 19, 30), + datetime(2014, 7, 3, 0): datetime(2014, 7, 3, 1), + datetime(2014, 7, 4, 10): datetime(2014, 7, 4, 20), + datetime(2014, 7, 4, 23): datetime(2014, 7, 5, 0), + datetime(2014, 7, 5, 0): datetime(2014, 7, 5, 1), + datetime(2014, 7, 5, 4): datetime(2014, 7, 7, 19), + datetime(2014, 7, 5, 4, 30): datetime(2014, 7, 7, 19, 30), + datetime(2014, 7, 5, 4, 30, 30): datetime(2014, 7, 7, 19, 30, 30), + }, + ), + ( + BusinessHour(n=-1, start="19:00", end="05:00"), + { + datetime(2014, 7, 1, 17): datetime(2014, 7, 1, 4), + datetime(2014, 7, 2, 14): datetime(2014, 7, 2, 4), + datetime(2014, 7, 2, 8): datetime(2014, 7, 2, 4), + datetime(2014, 7, 2, 13): datetime(2014, 7, 2, 4), + datetime(2014, 7, 2, 20): datetime(2014, 7, 2, 5), + datetime(2014, 7, 2, 19): datetime(2014, 7, 2, 4), + datetime(2014, 7, 2, 19, 30): datetime(2014, 7, 2, 4, 30), + datetime(2014, 7, 3, 0): datetime(2014, 7, 2, 23), + datetime(2014, 7, 3, 6): datetime(2014, 7, 3, 4), + datetime(2014, 7, 4, 23): datetime(2014, 7, 4, 22), + datetime(2014, 7, 5, 0): datetime(2014, 7, 4, 23), + datetime(2014, 7, 5, 4): datetime(2014, 7, 5, 3), + datetime(2014, 7, 7, 19, 30): datetime(2014, 7, 5, 4, 30), + datetime(2014, 7, 7, 19, 30, 30): datetime(2014, 7, 5, 4, 30, 30), + }, + ), + ( + BusinessHour(n=4, start="00:00", end="23:00"), + { + datetime(2014, 7, 3, 22): datetime(2014, 7, 4, 3), + datetime(2014, 7, 4, 22): datetime(2014, 7, 7, 3), + datetime(2014, 7, 3, 22, 30): datetime(2014, 7, 4, 3, 30), + datetime(2014, 7, 3, 22, 20): datetime(2014, 7, 4, 3, 20), + datetime(2014, 7, 4, 22, 30, 30): datetime(2014, 7, 7, 3, 30, 30), + datetime(2014, 7, 4, 22, 30, 20): datetime(2014, 7, 7, 3, 30, 20), + }, + ), + ( + BusinessHour(n=-4, start="00:00", end="23:00"), + { + datetime(2014, 7, 4, 3): datetime(2014, 7, 3, 22), + datetime(2014, 7, 7, 3): datetime(2014, 7, 4, 22), + datetime(2014, 7, 4, 3, 30): datetime(2014, 7, 3, 22, 30), + datetime(2014, 7, 4, 3, 20): datetime(2014, 7, 3, 22, 20), + datetime(2014, 7, 7, 3, 30, 30): datetime(2014, 7, 4, 22, 30, 30), + datetime(2014, 7, 7, 3, 30, 20): datetime(2014, 7, 4, 22, 30, 20), + }, + ), + ( + BusinessHour(start=["09:00", "14:00"], end=["12:00", "18:00"]), + { + datetime(2014, 7, 1, 11): datetime(2014, 7, 1, 14), + datetime(2014, 7, 1, 15): datetime(2014, 7, 1, 16), + datetime(2014, 7, 1, 19): datetime(2014, 7, 2, 10), + datetime(2014, 7, 1, 16): datetime(2014, 7, 1, 17), + datetime(2014, 7, 1, 16, 30, 15): datetime(2014, 7, 1, 17, 30, 15), + datetime(2014, 7, 1, 17): datetime(2014, 7, 2, 9), + datetime(2014, 7, 2, 11): datetime(2014, 7, 2, 14), + # out of business hours + datetime(2014, 7, 1, 13): datetime(2014, 7, 1, 15), + datetime(2014, 7, 2, 8): datetime(2014, 7, 2, 10), + datetime(2014, 7, 2, 19): datetime(2014, 7, 3, 10), + datetime(2014, 7, 2, 23): datetime(2014, 7, 3, 10), + datetime(2014, 7, 3, 0): datetime(2014, 7, 3, 10), + # saturday + datetime(2014, 7, 5, 15): datetime(2014, 7, 7, 10), + datetime(2014, 7, 4, 17): datetime(2014, 7, 7, 9), + datetime(2014, 7, 4, 17, 30): datetime(2014, 7, 7, 9, 30), + datetime(2014, 7, 4, 17, 30, 30): datetime(2014, 7, 7, 9, 30, 30), + }, + ), + ( + BusinessHour(n=4, start=["09:00", "14:00"], end=["12:00", "18:00"]), + { + datetime(2014, 7, 1, 11): datetime(2014, 7, 1, 17), + datetime(2014, 7, 1, 13): datetime(2014, 7, 2, 9), + datetime(2014, 7, 1, 15): datetime(2014, 7, 2, 10), + datetime(2014, 7, 1, 16): datetime(2014, 7, 2, 11), + datetime(2014, 7, 1, 17): datetime(2014, 7, 2, 14), + datetime(2014, 7, 2, 11): datetime(2014, 7, 2, 17), + datetime(2014, 7, 2, 8): datetime(2014, 7, 2, 15), + datetime(2014, 7, 2, 19): datetime(2014, 7, 3, 15), + datetime(2014, 7, 2, 23): datetime(2014, 7, 3, 15), + datetime(2014, 7, 3, 0): datetime(2014, 7, 3, 15), + datetime(2014, 7, 5, 15): datetime(2014, 7, 7, 15), + datetime(2014, 7, 4, 17): datetime(2014, 7, 7, 14), + datetime(2014, 7, 4, 16, 30): datetime(2014, 7, 7, 11, 30), + datetime(2014, 7, 4, 16, 30, 30): datetime(2014, 7, 7, 11, 30, 30), + }, + ), + ( + BusinessHour(n=-4, start=["09:00", "14:00"], end=["12:00", "18:00"]), + { + datetime(2014, 7, 1, 11): datetime(2014, 6, 30, 16), + datetime(2014, 7, 1, 13): datetime(2014, 6, 30, 17), + datetime(2014, 7, 1, 15): datetime(2014, 6, 30, 18), + datetime(2014, 7, 1, 16): datetime(2014, 7, 1, 10), + datetime(2014, 7, 1, 17): datetime(2014, 7, 1, 11), + datetime(2014, 7, 2, 11): datetime(2014, 7, 1, 16), + datetime(2014, 7, 2, 8): datetime(2014, 7, 1, 12), + datetime(2014, 7, 2, 19): datetime(2014, 7, 2, 12), + datetime(2014, 7, 2, 23): datetime(2014, 7, 2, 12), + datetime(2014, 7, 3, 0): datetime(2014, 7, 2, 12), + datetime(2014, 7, 5, 15): datetime(2014, 7, 4, 12), + datetime(2014, 7, 4, 18): datetime(2014, 7, 4, 12), + datetime(2014, 7, 7, 9, 30): datetime(2014, 7, 4, 14, 30), + datetime(2014, 7, 7, 9, 30, 30): datetime(2014, 7, 4, 14, 30, 30), + }, + ), + ( + BusinessHour(n=-1, start=["19:00", "03:00"], end=["01:00", "05:00"]), + { + datetime(2014, 7, 1, 17): datetime(2014, 7, 1, 4), + datetime(2014, 7, 2, 14): datetime(2014, 7, 2, 4), + datetime(2014, 7, 2, 8): datetime(2014, 7, 2, 4), + datetime(2014, 7, 2, 13): datetime(2014, 7, 2, 4), + datetime(2014, 7, 2, 20): datetime(2014, 7, 2, 5), + datetime(2014, 7, 2, 19): datetime(2014, 7, 2, 4), + datetime(2014, 7, 2, 4): datetime(2014, 7, 2, 1), + datetime(2014, 7, 2, 19, 30): datetime(2014, 7, 2, 4, 30), + datetime(2014, 7, 3, 0): datetime(2014, 7, 2, 23), + datetime(2014, 7, 3, 6): datetime(2014, 7, 3, 4), + datetime(2014, 7, 4, 23): datetime(2014, 7, 4, 22), + datetime(2014, 7, 5, 0): datetime(2014, 7, 4, 23), + datetime(2014, 7, 5, 4): datetime(2014, 7, 5, 0), + datetime(2014, 7, 7, 3, 30): datetime(2014, 7, 5, 0, 30), + datetime(2014, 7, 7, 19, 30): datetime(2014, 7, 7, 4, 30), + datetime(2014, 7, 7, 19, 30, 30): datetime(2014, 7, 7, 4, 30, 30), + }, + ), + ] + + # long business hours (see gh-26381) + + # multiple business hours + + @pytest.mark.parametrize("case", apply_cases) + def test_apply(self, case): + offset, cases = case + for base, expected in cases.items(): + assert_offset_equal(offset, base, expected) + + apply_large_n_cases = [ + ( + # A week later + BusinessHour(40), + { + datetime(2014, 7, 1, 11): datetime(2014, 7, 8, 11), + datetime(2014, 7, 1, 13): datetime(2014, 7, 8, 13), + datetime(2014, 7, 1, 15): datetime(2014, 7, 8, 15), + datetime(2014, 7, 1, 16): datetime(2014, 7, 8, 16), + datetime(2014, 7, 1, 17): datetime(2014, 7, 9, 9), + datetime(2014, 7, 2, 11): datetime(2014, 7, 9, 11), + datetime(2014, 7, 2, 8): datetime(2014, 7, 9, 9), + datetime(2014, 7, 2, 19): datetime(2014, 7, 10, 9), + datetime(2014, 7, 2, 23): datetime(2014, 7, 10, 9), + datetime(2014, 7, 3, 0): datetime(2014, 7, 10, 9), + datetime(2014, 7, 5, 15): datetime(2014, 7, 14, 9), + datetime(2014, 7, 4, 18): datetime(2014, 7, 14, 9), + datetime(2014, 7, 7, 9, 30): datetime(2014, 7, 14, 9, 30), + datetime(2014, 7, 7, 9, 30, 30): datetime(2014, 7, 14, 9, 30, 30), + }, + ), + ( + # 3 days and 1 hour before + BusinessHour(-25), + { + datetime(2014, 7, 1, 11): datetime(2014, 6, 26, 10), + datetime(2014, 7, 1, 13): datetime(2014, 6, 26, 12), + datetime(2014, 7, 1, 9): datetime(2014, 6, 25, 16), + datetime(2014, 7, 1, 10): datetime(2014, 6, 25, 17), + datetime(2014, 7, 3, 11): datetime(2014, 6, 30, 10), + datetime(2014, 7, 3, 8): datetime(2014, 6, 27, 16), + datetime(2014, 7, 3, 19): datetime(2014, 6, 30, 16), + datetime(2014, 7, 3, 23): datetime(2014, 6, 30, 16), + datetime(2014, 7, 4, 9): datetime(2014, 6, 30, 16), + datetime(2014, 7, 5, 15): datetime(2014, 7, 1, 16), + datetime(2014, 7, 6, 18): datetime(2014, 7, 1, 16), + datetime(2014, 7, 7, 9, 30): datetime(2014, 7, 1, 16, 30), + datetime(2014, 7, 7, 10, 30, 30): datetime(2014, 7, 2, 9, 30, 30), + }, + ), + ( + # 5 days and 3 hours later + BusinessHour(28, start="21:00", end="02:00"), + { + datetime(2014, 7, 1, 11): datetime(2014, 7, 9, 0), + datetime(2014, 7, 1, 22): datetime(2014, 7, 9, 1), + datetime(2014, 7, 1, 23): datetime(2014, 7, 9, 21), + datetime(2014, 7, 2, 2): datetime(2014, 7, 10, 0), + datetime(2014, 7, 3, 21): datetime(2014, 7, 11, 0), + datetime(2014, 7, 4, 1): datetime(2014, 7, 11, 23), + datetime(2014, 7, 4, 2): datetime(2014, 7, 12, 0), + datetime(2014, 7, 4, 3): datetime(2014, 7, 12, 0), + datetime(2014, 7, 5, 1): datetime(2014, 7, 14, 23), + datetime(2014, 7, 5, 15): datetime(2014, 7, 15, 0), + datetime(2014, 7, 6, 18): datetime(2014, 7, 15, 0), + datetime(2014, 7, 7, 1): datetime(2014, 7, 15, 0), + datetime(2014, 7, 7, 23, 30): datetime(2014, 7, 15, 21, 30), + }, + ), + ( + # large n for multiple opening hours (3 days and 1 hour before) + BusinessHour(n=-25, start=["09:00", "14:00"], end=["12:00", "19:00"]), + { + datetime(2014, 7, 1, 11): datetime(2014, 6, 26, 10), + datetime(2014, 7, 1, 13): datetime(2014, 6, 26, 11), + datetime(2014, 7, 1, 9): datetime(2014, 6, 25, 18), + datetime(2014, 7, 1, 10): datetime(2014, 6, 25, 19), + datetime(2014, 7, 3, 11): datetime(2014, 6, 30, 10), + datetime(2014, 7, 3, 8): datetime(2014, 6, 27, 18), + datetime(2014, 7, 3, 19): datetime(2014, 6, 30, 18), + datetime(2014, 7, 3, 23): datetime(2014, 6, 30, 18), + datetime(2014, 7, 4, 9): datetime(2014, 6, 30, 18), + datetime(2014, 7, 5, 15): datetime(2014, 7, 1, 18), + datetime(2014, 7, 6, 18): datetime(2014, 7, 1, 18), + datetime(2014, 7, 7, 9, 30): datetime(2014, 7, 1, 18, 30), + datetime(2014, 7, 7, 10, 30, 30): datetime(2014, 7, 2, 9, 30, 30), + }, + ), + ( + # 5 days and 3 hours later + BusinessHour(28, start=["21:00", "03:00"], end=["01:00", "04:00"]), + { + datetime(2014, 7, 1, 11): datetime(2014, 7, 9, 0), + datetime(2014, 7, 1, 22): datetime(2014, 7, 9, 3), + datetime(2014, 7, 1, 23): datetime(2014, 7, 9, 21), + datetime(2014, 7, 2, 2): datetime(2014, 7, 9, 23), + datetime(2014, 7, 3, 21): datetime(2014, 7, 11, 0), + datetime(2014, 7, 4, 1): datetime(2014, 7, 11, 23), + datetime(2014, 7, 4, 2): datetime(2014, 7, 11, 23), + datetime(2014, 7, 4, 3): datetime(2014, 7, 11, 23), + datetime(2014, 7, 4, 21): datetime(2014, 7, 12, 0), + datetime(2014, 7, 5, 0): datetime(2014, 7, 14, 22), + datetime(2014, 7, 5, 1): datetime(2014, 7, 14, 23), + datetime(2014, 7, 5, 15): datetime(2014, 7, 14, 23), + datetime(2014, 7, 6, 18): datetime(2014, 7, 14, 23), + datetime(2014, 7, 7, 1): datetime(2014, 7, 14, 23), + datetime(2014, 7, 7, 23, 30): datetime(2014, 7, 15, 21, 30), + }, + ), + ] + + @pytest.mark.parametrize("case", apply_large_n_cases) + def test_apply_large_n(self, case): + offset, cases = case + for base, expected in cases.items(): + assert_offset_equal(offset, base, expected) + + def test_apply_nanoseconds(self): + tests = [ + ( + BusinessHour(), + { + Timestamp("2014-07-04 15:00") + + Nano(5): Timestamp("2014-07-04 16:00") + + Nano(5), + Timestamp("2014-07-04 16:00") + + Nano(5): Timestamp("2014-07-07 09:00") + + Nano(5), + Timestamp("2014-07-04 16:00") + - Nano(5): Timestamp("2014-07-04 17:00") + - Nano(5), + }, + ), + ( + BusinessHour(-1), + { + Timestamp("2014-07-04 15:00") + + Nano(5): Timestamp("2014-07-04 14:00") + + Nano(5), + Timestamp("2014-07-04 10:00") + + Nano(5): Timestamp("2014-07-04 09:00") + + Nano(5), + Timestamp("2014-07-04 10:00") + - Nano(5): Timestamp("2014-07-03 17:00") + - Nano(5), + }, + ), + ] + + for offset, cases in tests: + for base, expected in cases.items(): + assert_offset_equal(offset, base, expected) + + @pytest.mark.parametrize("td_unit", ["s", "ms", "us", "ns"]) + @pytest.mark.parametrize("unit", ["s", "ms", "us", "ns"]) + def test_bday_ignores_timedeltas(self, unit, td_unit): + # GH#55608 + idx = date_range("2010/02/01", "2010/02/10", freq="12h", unit=unit) + td = Timedelta(3, unit="h").as_unit(td_unit) + off = BDay(offset=td) + t1 = idx + off + + exp_unit = tm.get_finest_unit(td.unit, idx.unit) + + expected = DatetimeIndex( + [ + "2010-02-02 03:00:00", + "2010-02-02 15:00:00", + "2010-02-03 03:00:00", + "2010-02-03 15:00:00", + "2010-02-04 03:00:00", + "2010-02-04 15:00:00", + "2010-02-05 03:00:00", + "2010-02-05 15:00:00", + "2010-02-08 03:00:00", + "2010-02-08 15:00:00", + "2010-02-08 03:00:00", + "2010-02-08 15:00:00", + "2010-02-08 03:00:00", + "2010-02-08 15:00:00", + "2010-02-09 03:00:00", + "2010-02-09 15:00:00", + "2010-02-10 03:00:00", + "2010-02-10 15:00:00", + "2010-02-11 03:00:00", + ], + freq=None, + ).as_unit(exp_unit) + tm.assert_index_equal(t1, expected) + + # TODO(GH#55564): as_unit will be unnecessary + pointwise = DatetimeIndex([x + off for x in idx]).as_unit(exp_unit) + tm.assert_index_equal(pointwise, expected) + + def test_add_bday_offset_nanos(self): + # GH#55608 + idx = date_range("2010/02/01", "2010/02/10", freq="12h", unit="ns") + off = BDay(offset=Timedelta(3, unit="ns")) + + result = idx + off + expected = DatetimeIndex([x + off for x in idx]) + tm.assert_index_equal(result, expected) + + +class TestOpeningTimes: + # opening time should be affected by sign of n, not by n's value and end + opening_time_cases = [ + ( + [ + BusinessHour(), + BusinessHour(n=2), + BusinessHour(n=4), + BusinessHour(end="10:00"), + BusinessHour(n=2, end="4:00"), + BusinessHour(n=4, end="15:00"), + ], + { + datetime(2014, 7, 1, 11): ( + datetime(2014, 7, 2, 9), + datetime(2014, 7, 1, 9), + ), + datetime(2014, 7, 1, 18): ( + datetime(2014, 7, 2, 9), + datetime(2014, 7, 1, 9), + ), + datetime(2014, 7, 1, 23): ( + datetime(2014, 7, 2, 9), + datetime(2014, 7, 1, 9), + ), + datetime(2014, 7, 2, 8): ( + datetime(2014, 7, 2, 9), + datetime(2014, 7, 1, 9), + ), + # if timestamp is on opening time, next opening time is + # as it is + datetime(2014, 7, 2, 9): ( + datetime(2014, 7, 2, 9), + datetime(2014, 7, 2, 9), + ), + datetime(2014, 7, 2, 10): ( + datetime(2014, 7, 3, 9), + datetime(2014, 7, 2, 9), + ), + # 2014-07-05 is saturday + datetime(2014, 7, 5, 10): ( + datetime(2014, 7, 7, 9), + datetime(2014, 7, 4, 9), + ), + datetime(2014, 7, 4, 10): ( + datetime(2014, 7, 7, 9), + datetime(2014, 7, 4, 9), + ), + datetime(2014, 7, 4, 23): ( + datetime(2014, 7, 7, 9), + datetime(2014, 7, 4, 9), + ), + datetime(2014, 7, 6, 10): ( + datetime(2014, 7, 7, 9), + datetime(2014, 7, 4, 9), + ), + datetime(2014, 7, 7, 5): ( + datetime(2014, 7, 7, 9), + datetime(2014, 7, 4, 9), + ), + datetime(2014, 7, 7, 9, 1): ( + datetime(2014, 7, 8, 9), + datetime(2014, 7, 7, 9), + ), + }, + ), + ( + [ + BusinessHour(start="11:15"), + BusinessHour(n=2, start="11:15"), + BusinessHour(n=3, start="11:15"), + BusinessHour(start="11:15", end="10:00"), + BusinessHour(n=2, start="11:15", end="4:00"), + BusinessHour(n=3, start="11:15", end="15:00"), + ], + { + datetime(2014, 7, 1, 11): ( + datetime(2014, 7, 1, 11, 15), + datetime(2014, 6, 30, 11, 15), + ), + datetime(2014, 7, 1, 18): ( + datetime(2014, 7, 2, 11, 15), + datetime(2014, 7, 1, 11, 15), + ), + datetime(2014, 7, 1, 23): ( + datetime(2014, 7, 2, 11, 15), + datetime(2014, 7, 1, 11, 15), + ), + datetime(2014, 7, 2, 8): ( + datetime(2014, 7, 2, 11, 15), + datetime(2014, 7, 1, 11, 15), + ), + datetime(2014, 7, 2, 9): ( + datetime(2014, 7, 2, 11, 15), + datetime(2014, 7, 1, 11, 15), + ), + datetime(2014, 7, 2, 10): ( + datetime(2014, 7, 2, 11, 15), + datetime(2014, 7, 1, 11, 15), + ), + datetime(2014, 7, 2, 11, 15): ( + datetime(2014, 7, 2, 11, 15), + datetime(2014, 7, 2, 11, 15), + ), + datetime(2014, 7, 2, 11, 15, 1): ( + datetime(2014, 7, 3, 11, 15), + datetime(2014, 7, 2, 11, 15), + ), + datetime(2014, 7, 5, 10): ( + datetime(2014, 7, 7, 11, 15), + datetime(2014, 7, 4, 11, 15), + ), + datetime(2014, 7, 4, 10): ( + datetime(2014, 7, 4, 11, 15), + datetime(2014, 7, 3, 11, 15), + ), + datetime(2014, 7, 4, 23): ( + datetime(2014, 7, 7, 11, 15), + datetime(2014, 7, 4, 11, 15), + ), + datetime(2014, 7, 6, 10): ( + datetime(2014, 7, 7, 11, 15), + datetime(2014, 7, 4, 11, 15), + ), + datetime(2014, 7, 7, 5): ( + datetime(2014, 7, 7, 11, 15), + datetime(2014, 7, 4, 11, 15), + ), + datetime(2014, 7, 7, 9, 1): ( + datetime(2014, 7, 7, 11, 15), + datetime(2014, 7, 4, 11, 15), + ), + }, + ), + ( + [ + BusinessHour(-1), + BusinessHour(n=-2), + BusinessHour(n=-4), + BusinessHour(n=-1, end="10:00"), + BusinessHour(n=-2, end="4:00"), + BusinessHour(n=-4, end="15:00"), + ], + { + datetime(2014, 7, 1, 11): ( + datetime(2014, 7, 1, 9), + datetime(2014, 7, 2, 9), + ), + datetime(2014, 7, 1, 18): ( + datetime(2014, 7, 1, 9), + datetime(2014, 7, 2, 9), + ), + datetime(2014, 7, 1, 23): ( + datetime(2014, 7, 1, 9), + datetime(2014, 7, 2, 9), + ), + datetime(2014, 7, 2, 8): ( + datetime(2014, 7, 1, 9), + datetime(2014, 7, 2, 9), + ), + datetime(2014, 7, 2, 9): ( + datetime(2014, 7, 2, 9), + datetime(2014, 7, 2, 9), + ), + datetime(2014, 7, 2, 10): ( + datetime(2014, 7, 2, 9), + datetime(2014, 7, 3, 9), + ), + datetime(2014, 7, 5, 10): ( + datetime(2014, 7, 4, 9), + datetime(2014, 7, 7, 9), + ), + datetime(2014, 7, 4, 10): ( + datetime(2014, 7, 4, 9), + datetime(2014, 7, 7, 9), + ), + datetime(2014, 7, 4, 23): ( + datetime(2014, 7, 4, 9), + datetime(2014, 7, 7, 9), + ), + datetime(2014, 7, 6, 10): ( + datetime(2014, 7, 4, 9), + datetime(2014, 7, 7, 9), + ), + datetime(2014, 7, 7, 5): ( + datetime(2014, 7, 4, 9), + datetime(2014, 7, 7, 9), + ), + datetime(2014, 7, 7, 9): ( + datetime(2014, 7, 7, 9), + datetime(2014, 7, 7, 9), + ), + datetime(2014, 7, 7, 9, 1): ( + datetime(2014, 7, 7, 9), + datetime(2014, 7, 8, 9), + ), + }, + ), + ( + [ + BusinessHour(start="17:00", end="05:00"), + BusinessHour(n=3, start="17:00", end="03:00"), + ], + { + datetime(2014, 7, 1, 11): ( + datetime(2014, 7, 1, 17), + datetime(2014, 6, 30, 17), + ), + datetime(2014, 7, 1, 18): ( + datetime(2014, 7, 2, 17), + datetime(2014, 7, 1, 17), + ), + datetime(2014, 7, 1, 23): ( + datetime(2014, 7, 2, 17), + datetime(2014, 7, 1, 17), + ), + datetime(2014, 7, 2, 8): ( + datetime(2014, 7, 2, 17), + datetime(2014, 7, 1, 17), + ), + datetime(2014, 7, 2, 9): ( + datetime(2014, 7, 2, 17), + datetime(2014, 7, 1, 17), + ), + datetime(2014, 7, 4, 17): ( + datetime(2014, 7, 4, 17), + datetime(2014, 7, 4, 17), + ), + datetime(2014, 7, 5, 10): ( + datetime(2014, 7, 7, 17), + datetime(2014, 7, 4, 17), + ), + datetime(2014, 7, 4, 10): ( + datetime(2014, 7, 4, 17), + datetime(2014, 7, 3, 17), + ), + datetime(2014, 7, 4, 23): ( + datetime(2014, 7, 7, 17), + datetime(2014, 7, 4, 17), + ), + datetime(2014, 7, 6, 10): ( + datetime(2014, 7, 7, 17), + datetime(2014, 7, 4, 17), + ), + datetime(2014, 7, 7, 5): ( + datetime(2014, 7, 7, 17), + datetime(2014, 7, 4, 17), + ), + datetime(2014, 7, 7, 17, 1): ( + datetime(2014, 7, 8, 17), + datetime(2014, 7, 7, 17), + ), + }, + ), + ( + [ + BusinessHour(-1, start="17:00", end="05:00"), + BusinessHour(n=-2, start="17:00", end="03:00"), + ], + { + datetime(2014, 7, 1, 11): ( + datetime(2014, 6, 30, 17), + datetime(2014, 7, 1, 17), + ), + datetime(2014, 7, 1, 18): ( + datetime(2014, 7, 1, 17), + datetime(2014, 7, 2, 17), + ), + datetime(2014, 7, 1, 23): ( + datetime(2014, 7, 1, 17), + datetime(2014, 7, 2, 17), + ), + datetime(2014, 7, 2, 8): ( + datetime(2014, 7, 1, 17), + datetime(2014, 7, 2, 17), + ), + datetime(2014, 7, 2, 9): ( + datetime(2014, 7, 1, 17), + datetime(2014, 7, 2, 17), + ), + datetime(2014, 7, 2, 16, 59): ( + datetime(2014, 7, 1, 17), + datetime(2014, 7, 2, 17), + ), + datetime(2014, 7, 5, 10): ( + datetime(2014, 7, 4, 17), + datetime(2014, 7, 7, 17), + ), + datetime(2014, 7, 4, 10): ( + datetime(2014, 7, 3, 17), + datetime(2014, 7, 4, 17), + ), + datetime(2014, 7, 4, 23): ( + datetime(2014, 7, 4, 17), + datetime(2014, 7, 7, 17), + ), + datetime(2014, 7, 6, 10): ( + datetime(2014, 7, 4, 17), + datetime(2014, 7, 7, 17), + ), + datetime(2014, 7, 7, 5): ( + datetime(2014, 7, 4, 17), + datetime(2014, 7, 7, 17), + ), + datetime(2014, 7, 7, 18): ( + datetime(2014, 7, 7, 17), + datetime(2014, 7, 8, 17), + ), + }, + ), + ( + [ + BusinessHour(start=["11:15", "15:00"], end=["13:00", "20:00"]), + BusinessHour(n=3, start=["11:15", "15:00"], end=["12:00", "20:00"]), + BusinessHour(start=["11:15", "15:00"], end=["13:00", "17:00"]), + BusinessHour(n=2, start=["11:15", "15:00"], end=["12:00", "03:00"]), + BusinessHour(n=3, start=["11:15", "15:00"], end=["13:00", "16:00"]), + ], + { + datetime(2014, 7, 1, 11): ( + datetime(2014, 7, 1, 11, 15), + datetime(2014, 6, 30, 15), + ), + datetime(2014, 7, 1, 18): ( + datetime(2014, 7, 2, 11, 15), + datetime(2014, 7, 1, 15), + ), + datetime(2014, 7, 1, 23): ( + datetime(2014, 7, 2, 11, 15), + datetime(2014, 7, 1, 15), + ), + datetime(2014, 7, 2, 8): ( + datetime(2014, 7, 2, 11, 15), + datetime(2014, 7, 1, 15), + ), + datetime(2014, 7, 2, 9): ( + datetime(2014, 7, 2, 11, 15), + datetime(2014, 7, 1, 15), + ), + datetime(2014, 7, 2, 10): ( + datetime(2014, 7, 2, 11, 15), + datetime(2014, 7, 1, 15), + ), + datetime(2014, 7, 2, 11, 15): ( + datetime(2014, 7, 2, 11, 15), + datetime(2014, 7, 2, 11, 15), + ), + datetime(2014, 7, 2, 11, 15, 1): ( + datetime(2014, 7, 2, 15), + datetime(2014, 7, 2, 11, 15), + ), + datetime(2014, 7, 5, 10): ( + datetime(2014, 7, 7, 11, 15), + datetime(2014, 7, 4, 15), + ), + datetime(2014, 7, 4, 10): ( + datetime(2014, 7, 4, 11, 15), + datetime(2014, 7, 3, 15), + ), + datetime(2014, 7, 4, 23): ( + datetime(2014, 7, 7, 11, 15), + datetime(2014, 7, 4, 15), + ), + datetime(2014, 7, 6, 10): ( + datetime(2014, 7, 7, 11, 15), + datetime(2014, 7, 4, 15), + ), + datetime(2014, 7, 7, 5): ( + datetime(2014, 7, 7, 11, 15), + datetime(2014, 7, 4, 15), + ), + datetime(2014, 7, 7, 9, 1): ( + datetime(2014, 7, 7, 11, 15), + datetime(2014, 7, 4, 15), + ), + datetime(2014, 7, 7, 12): ( + datetime(2014, 7, 7, 15), + datetime(2014, 7, 7, 11, 15), + ), + }, + ), + ( + [ + BusinessHour(n=-1, start=["17:00", "08:00"], end=["05:00", "10:00"]), + BusinessHour(n=-2, start=["08:00", "17:00"], end=["10:00", "03:00"]), + ], + { + datetime(2014, 7, 1, 11): ( + datetime(2014, 7, 1, 8), + datetime(2014, 7, 1, 17), + ), + datetime(2014, 7, 1, 18): ( + datetime(2014, 7, 1, 17), + datetime(2014, 7, 2, 8), + ), + datetime(2014, 7, 1, 23): ( + datetime(2014, 7, 1, 17), + datetime(2014, 7, 2, 8), + ), + datetime(2014, 7, 2, 8): ( + datetime(2014, 7, 2, 8), + datetime(2014, 7, 2, 8), + ), + datetime(2014, 7, 2, 9): ( + datetime(2014, 7, 2, 8), + datetime(2014, 7, 2, 17), + ), + datetime(2014, 7, 2, 16, 59): ( + datetime(2014, 7, 2, 8), + datetime(2014, 7, 2, 17), + ), + datetime(2014, 7, 5, 10): ( + datetime(2014, 7, 4, 17), + datetime(2014, 7, 7, 8), + ), + datetime(2014, 7, 4, 10): ( + datetime(2014, 7, 4, 8), + datetime(2014, 7, 4, 17), + ), + datetime(2014, 7, 4, 23): ( + datetime(2014, 7, 4, 17), + datetime(2014, 7, 7, 8), + ), + datetime(2014, 7, 6, 10): ( + datetime(2014, 7, 4, 17), + datetime(2014, 7, 7, 8), + ), + datetime(2014, 7, 7, 5): ( + datetime(2014, 7, 4, 17), + datetime(2014, 7, 7, 8), + ), + datetime(2014, 7, 7, 18): ( + datetime(2014, 7, 7, 17), + datetime(2014, 7, 8, 8), + ), + }, + ), + ] + + @pytest.mark.parametrize("case", opening_time_cases) + def test_opening_time(self, case): + _offsets, cases = case + for offset in _offsets: + for dt, (exp_next, exp_prev) in cases.items(): + assert offset._next_opening_time(dt) == exp_next + assert offset._prev_opening_time(dt) == exp_prev diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_business_month.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_business_month.py new file mode 100644 index 0000000000000000000000000000000000000000..a14451e60aa89f3a74f52add62a53759027edd21 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_business_month.py @@ -0,0 +1,217 @@ +""" +Tests for the following offsets: +- BMonthBegin +- BMonthEnd +""" +from __future__ import annotations + +from datetime import datetime + +import pytest + +import pandas as pd +from pandas.tests.tseries.offsets.common import ( + assert_is_on_offset, + assert_offset_equal, +) + +from pandas.tseries.offsets import ( + BMonthBegin, + BMonthEnd, +) + + +@pytest.mark.parametrize("n", [-2, 1]) +@pytest.mark.parametrize( + "cls", + [ + BMonthBegin, + BMonthEnd, + ], +) +def test_apply_index(cls, n): + offset = cls(n=n) + rng = pd.date_range(start="1/1/2000", periods=100000, freq="min") + ser = pd.Series(rng) + + res = rng + offset + assert res.freq is None # not retained + assert res[0] == rng[0] + offset + assert res[-1] == rng[-1] + offset + res2 = ser + offset + # apply_index is only for indexes, not series, so no res2_v2 + assert res2.iloc[0] == ser.iloc[0] + offset + assert res2.iloc[-1] == ser.iloc[-1] + offset + + +class TestBMonthBegin: + def test_offsets_compare_equal(self): + # root cause of #456 + offset1 = BMonthBegin() + offset2 = BMonthBegin() + assert not offset1 != offset2 + + offset_cases = [] + offset_cases.append( + ( + BMonthBegin(), + { + datetime(2008, 1, 1): datetime(2008, 2, 1), + datetime(2008, 1, 31): datetime(2008, 2, 1), + datetime(2006, 12, 29): datetime(2007, 1, 1), + datetime(2006, 12, 31): datetime(2007, 1, 1), + datetime(2006, 9, 1): datetime(2006, 10, 2), + datetime(2007, 1, 1): datetime(2007, 2, 1), + datetime(2006, 12, 1): datetime(2007, 1, 1), + }, + ) + ) + + offset_cases.append( + ( + BMonthBegin(0), + { + datetime(2008, 1, 1): datetime(2008, 1, 1), + datetime(2006, 10, 2): datetime(2006, 10, 2), + datetime(2008, 1, 31): datetime(2008, 2, 1), + datetime(2006, 12, 29): datetime(2007, 1, 1), + datetime(2006, 12, 31): datetime(2007, 1, 1), + datetime(2006, 9, 15): datetime(2006, 10, 2), + }, + ) + ) + + offset_cases.append( + ( + BMonthBegin(2), + { + datetime(2008, 1, 1): datetime(2008, 3, 3), + datetime(2008, 1, 15): datetime(2008, 3, 3), + datetime(2006, 12, 29): datetime(2007, 2, 1), + datetime(2006, 12, 31): datetime(2007, 2, 1), + datetime(2007, 1, 1): datetime(2007, 3, 1), + datetime(2006, 11, 1): datetime(2007, 1, 1), + }, + ) + ) + + offset_cases.append( + ( + BMonthBegin(-1), + { + datetime(2007, 1, 1): datetime(2006, 12, 1), + datetime(2008, 6, 30): datetime(2008, 6, 2), + datetime(2008, 6, 1): datetime(2008, 5, 1), + datetime(2008, 3, 10): datetime(2008, 3, 3), + datetime(2008, 12, 31): datetime(2008, 12, 1), + datetime(2006, 12, 29): datetime(2006, 12, 1), + datetime(2006, 12, 30): datetime(2006, 12, 1), + datetime(2007, 1, 1): datetime(2006, 12, 1), + }, + ) + ) + + @pytest.mark.parametrize("case", offset_cases) + def test_offset(self, case): + offset, cases = case + for base, expected in cases.items(): + assert_offset_equal(offset, base, expected) + + on_offset_cases = [ + (BMonthBegin(), datetime(2007, 12, 31), False), + (BMonthBegin(), datetime(2008, 1, 1), True), + (BMonthBegin(), datetime(2001, 4, 2), True), + (BMonthBegin(), datetime(2008, 3, 3), True), + ] + + @pytest.mark.parametrize("case", on_offset_cases) + def test_is_on_offset(self, case): + offset, dt, expected = case + assert_is_on_offset(offset, dt, expected) + + +class TestBMonthEnd: + def test_normalize(self): + dt = datetime(2007, 1, 1, 3) + + result = dt + BMonthEnd(normalize=True) + expected = dt.replace(hour=0) + BMonthEnd() + assert result == expected + + def test_offsets_compare_equal(self): + # root cause of #456 + offset1 = BMonthEnd() + offset2 = BMonthEnd() + assert not offset1 != offset2 + + offset_cases = [] + offset_cases.append( + ( + BMonthEnd(), + { + datetime(2008, 1, 1): datetime(2008, 1, 31), + datetime(2008, 1, 31): datetime(2008, 2, 29), + datetime(2006, 12, 29): datetime(2007, 1, 31), + datetime(2006, 12, 31): datetime(2007, 1, 31), + datetime(2007, 1, 1): datetime(2007, 1, 31), + datetime(2006, 12, 1): datetime(2006, 12, 29), + }, + ) + ) + + offset_cases.append( + ( + BMonthEnd(0), + { + datetime(2008, 1, 1): datetime(2008, 1, 31), + datetime(2008, 1, 31): datetime(2008, 1, 31), + datetime(2006, 12, 29): datetime(2006, 12, 29), + datetime(2006, 12, 31): datetime(2007, 1, 31), + datetime(2007, 1, 1): datetime(2007, 1, 31), + }, + ) + ) + + offset_cases.append( + ( + BMonthEnd(2), + { + datetime(2008, 1, 1): datetime(2008, 2, 29), + datetime(2008, 1, 31): datetime(2008, 3, 31), + datetime(2006, 12, 29): datetime(2007, 2, 28), + datetime(2006, 12, 31): datetime(2007, 2, 28), + datetime(2007, 1, 1): datetime(2007, 2, 28), + datetime(2006, 11, 1): datetime(2006, 12, 29), + }, + ) + ) + + offset_cases.append( + ( + BMonthEnd(-1), + { + datetime(2007, 1, 1): datetime(2006, 12, 29), + datetime(2008, 6, 30): datetime(2008, 5, 30), + datetime(2008, 12, 31): datetime(2008, 11, 28), + datetime(2006, 12, 29): datetime(2006, 11, 30), + datetime(2006, 12, 30): datetime(2006, 12, 29), + datetime(2007, 1, 1): datetime(2006, 12, 29), + }, + ) + ) + + @pytest.mark.parametrize("case", offset_cases) + def test_offset(self, case): + offset, cases = case + for base, expected in cases.items(): + assert_offset_equal(offset, base, expected) + + on_offset_cases = [ + (BMonthEnd(), datetime(2007, 12, 31), True), + (BMonthEnd(), datetime(2008, 1, 1), False), + ] + + @pytest.mark.parametrize("case", on_offset_cases) + def test_is_on_offset(self, case): + offset, dt, expected = case + assert_is_on_offset(offset, dt, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_business_quarter.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_business_quarter.py new file mode 100644 index 0000000000000000000000000000000000000000..6d7a115054b7f20e3ab024eb31f266c18920f2c5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_business_quarter.py @@ -0,0 +1,315 @@ +""" +Tests for the following offsets: +- BQuarterBegin +- BQuarterEnd +""" +from __future__ import annotations + +from datetime import datetime + +import pytest + +import pandas._testing as tm +from pandas.tests.tseries.offsets.common import ( + assert_is_on_offset, + assert_offset_equal, +) + +from pandas.tseries.offsets import ( + BQuarterBegin, + BQuarterEnd, +) + + +def test_quarterly_dont_normalize(): + date = datetime(2012, 3, 31, 5, 30) + + offsets = (BQuarterEnd, BQuarterBegin) + + for klass in offsets: + result = date + klass() + assert result.time() == date.time() + + +@pytest.mark.parametrize("offset", [BQuarterBegin(), BQuarterEnd()]) +def test_on_offset(offset): + dates = [ + datetime(2016, m, d) + for m in [10, 11, 12] + for d in [1, 2, 3, 28, 29, 30, 31] + if not (m == 11 and d == 31) + ] + for date in dates: + res = offset.is_on_offset(date) + slow_version = date == (date + offset) - offset + assert res == slow_version + + +class TestBQuarterBegin: + def test_repr(self): + expected = "" + assert repr(BQuarterBegin()) == expected + expected = "" + assert repr(BQuarterBegin(startingMonth=3)) == expected + expected = "" + assert repr(BQuarterBegin(startingMonth=1)) == expected + + def test_is_anchored(self): + msg = "BQuarterBegin.is_anchored is deprecated " + + with tm.assert_produces_warning(FutureWarning, match=msg): + assert BQuarterBegin(startingMonth=1).is_anchored() + assert BQuarterBegin().is_anchored() + assert not BQuarterBegin(2, startingMonth=1).is_anchored() + + def test_offset_corner_case(self): + # corner + offset = BQuarterBegin(n=-1, startingMonth=1) + assert datetime(2007, 4, 3) + offset == datetime(2007, 4, 2) + + offset_cases = [] + offset_cases.append( + ( + BQuarterBegin(startingMonth=1), + { + datetime(2008, 1, 1): datetime(2008, 4, 1), + datetime(2008, 1, 31): datetime(2008, 4, 1), + datetime(2008, 2, 15): datetime(2008, 4, 1), + datetime(2008, 2, 29): datetime(2008, 4, 1), + datetime(2008, 3, 15): datetime(2008, 4, 1), + datetime(2008, 3, 31): datetime(2008, 4, 1), + datetime(2008, 4, 15): datetime(2008, 7, 1), + datetime(2007, 3, 15): datetime(2007, 4, 2), + datetime(2007, 2, 28): datetime(2007, 4, 2), + datetime(2007, 1, 1): datetime(2007, 4, 2), + datetime(2007, 4, 15): datetime(2007, 7, 2), + datetime(2007, 7, 1): datetime(2007, 7, 2), + datetime(2007, 4, 1): datetime(2007, 4, 2), + datetime(2007, 4, 2): datetime(2007, 7, 2), + datetime(2008, 4, 30): datetime(2008, 7, 1), + }, + ) + ) + + offset_cases.append( + ( + BQuarterBegin(startingMonth=2), + { + datetime(2008, 1, 1): datetime(2008, 2, 1), + datetime(2008, 1, 31): datetime(2008, 2, 1), + datetime(2008, 1, 15): datetime(2008, 2, 1), + datetime(2008, 2, 29): datetime(2008, 5, 1), + datetime(2008, 3, 15): datetime(2008, 5, 1), + datetime(2008, 3, 31): datetime(2008, 5, 1), + datetime(2008, 4, 15): datetime(2008, 5, 1), + datetime(2008, 8, 15): datetime(2008, 11, 3), + datetime(2008, 9, 15): datetime(2008, 11, 3), + datetime(2008, 11, 1): datetime(2008, 11, 3), + datetime(2008, 4, 30): datetime(2008, 5, 1), + }, + ) + ) + + offset_cases.append( + ( + BQuarterBegin(startingMonth=1, n=0), + { + datetime(2008, 1, 1): datetime(2008, 1, 1), + datetime(2007, 12, 31): datetime(2008, 1, 1), + datetime(2008, 2, 15): datetime(2008, 4, 1), + datetime(2008, 2, 29): datetime(2008, 4, 1), + datetime(2008, 1, 15): datetime(2008, 4, 1), + datetime(2008, 2, 27): datetime(2008, 4, 1), + datetime(2008, 3, 15): datetime(2008, 4, 1), + datetime(2007, 4, 1): datetime(2007, 4, 2), + datetime(2007, 4, 2): datetime(2007, 4, 2), + datetime(2007, 7, 1): datetime(2007, 7, 2), + datetime(2007, 4, 15): datetime(2007, 7, 2), + datetime(2007, 7, 2): datetime(2007, 7, 2), + }, + ) + ) + + offset_cases.append( + ( + BQuarterBegin(startingMonth=1, n=-1), + { + datetime(2008, 1, 1): datetime(2007, 10, 1), + datetime(2008, 1, 31): datetime(2008, 1, 1), + datetime(2008, 2, 15): datetime(2008, 1, 1), + datetime(2008, 2, 29): datetime(2008, 1, 1), + datetime(2008, 3, 15): datetime(2008, 1, 1), + datetime(2008, 3, 31): datetime(2008, 1, 1), + datetime(2008, 4, 15): datetime(2008, 4, 1), + datetime(2007, 7, 3): datetime(2007, 7, 2), + datetime(2007, 4, 3): datetime(2007, 4, 2), + datetime(2007, 7, 2): datetime(2007, 4, 2), + datetime(2008, 4, 1): datetime(2008, 1, 1), + }, + ) + ) + + offset_cases.append( + ( + BQuarterBegin(startingMonth=1, n=2), + { + datetime(2008, 1, 1): datetime(2008, 7, 1), + datetime(2008, 1, 15): datetime(2008, 7, 1), + datetime(2008, 2, 29): datetime(2008, 7, 1), + datetime(2008, 3, 15): datetime(2008, 7, 1), + datetime(2007, 3, 31): datetime(2007, 7, 2), + datetime(2007, 4, 15): datetime(2007, 10, 1), + datetime(2008, 4, 30): datetime(2008, 10, 1), + }, + ) + ) + + @pytest.mark.parametrize("case", offset_cases) + def test_offset(self, case): + offset, cases = case + for base, expected in cases.items(): + assert_offset_equal(offset, base, expected) + + +class TestBQuarterEnd: + def test_repr(self): + expected = "" + assert repr(BQuarterEnd()) == expected + expected = "" + assert repr(BQuarterEnd(startingMonth=3)) == expected + expected = "" + assert repr(BQuarterEnd(startingMonth=1)) == expected + + def test_is_anchored(self): + msg = "BQuarterEnd.is_anchored is deprecated " + + with tm.assert_produces_warning(FutureWarning, match=msg): + assert BQuarterEnd(startingMonth=1).is_anchored() + assert BQuarterEnd().is_anchored() + assert not BQuarterEnd(2, startingMonth=1).is_anchored() + + def test_offset_corner_case(self): + # corner + offset = BQuarterEnd(n=-1, startingMonth=1) + assert datetime(2010, 1, 31) + offset == datetime(2010, 1, 29) + + offset_cases = [] + offset_cases.append( + ( + BQuarterEnd(startingMonth=1), + { + datetime(2008, 1, 1): datetime(2008, 1, 31), + datetime(2008, 1, 31): datetime(2008, 4, 30), + datetime(2008, 2, 15): datetime(2008, 4, 30), + datetime(2008, 2, 29): datetime(2008, 4, 30), + datetime(2008, 3, 15): datetime(2008, 4, 30), + datetime(2008, 3, 31): datetime(2008, 4, 30), + datetime(2008, 4, 15): datetime(2008, 4, 30), + datetime(2008, 4, 30): datetime(2008, 7, 31), + }, + ) + ) + + offset_cases.append( + ( + BQuarterEnd(startingMonth=2), + { + datetime(2008, 1, 1): datetime(2008, 2, 29), + datetime(2008, 1, 31): datetime(2008, 2, 29), + datetime(2008, 2, 15): datetime(2008, 2, 29), + datetime(2008, 2, 29): datetime(2008, 5, 30), + datetime(2008, 3, 15): datetime(2008, 5, 30), + datetime(2008, 3, 31): datetime(2008, 5, 30), + datetime(2008, 4, 15): datetime(2008, 5, 30), + datetime(2008, 4, 30): datetime(2008, 5, 30), + }, + ) + ) + + offset_cases.append( + ( + BQuarterEnd(startingMonth=1, n=0), + { + datetime(2008, 1, 1): datetime(2008, 1, 31), + datetime(2008, 1, 31): datetime(2008, 1, 31), + datetime(2008, 2, 15): datetime(2008, 4, 30), + datetime(2008, 2, 29): datetime(2008, 4, 30), + datetime(2008, 3, 15): datetime(2008, 4, 30), + datetime(2008, 3, 31): datetime(2008, 4, 30), + datetime(2008, 4, 15): datetime(2008, 4, 30), + datetime(2008, 4, 30): datetime(2008, 4, 30), + }, + ) + ) + + offset_cases.append( + ( + BQuarterEnd(startingMonth=1, n=-1), + { + datetime(2008, 1, 1): datetime(2007, 10, 31), + datetime(2008, 1, 31): datetime(2007, 10, 31), + datetime(2008, 2, 15): datetime(2008, 1, 31), + datetime(2008, 2, 29): datetime(2008, 1, 31), + datetime(2008, 3, 15): datetime(2008, 1, 31), + datetime(2008, 3, 31): datetime(2008, 1, 31), + datetime(2008, 4, 15): datetime(2008, 1, 31), + datetime(2008, 4, 30): datetime(2008, 1, 31), + }, + ) + ) + + offset_cases.append( + ( + BQuarterEnd(startingMonth=1, n=2), + { + datetime(2008, 1, 31): datetime(2008, 7, 31), + datetime(2008, 2, 15): datetime(2008, 7, 31), + datetime(2008, 2, 29): datetime(2008, 7, 31), + datetime(2008, 3, 15): datetime(2008, 7, 31), + datetime(2008, 3, 31): datetime(2008, 7, 31), + datetime(2008, 4, 15): datetime(2008, 7, 31), + datetime(2008, 4, 30): datetime(2008, 10, 31), + }, + ) + ) + + @pytest.mark.parametrize("case", offset_cases) + def test_offset(self, case): + offset, cases = case + for base, expected in cases.items(): + assert_offset_equal(offset, base, expected) + + on_offset_cases = [ + (BQuarterEnd(1, startingMonth=1), datetime(2008, 1, 31), True), + (BQuarterEnd(1, startingMonth=1), datetime(2007, 12, 31), False), + (BQuarterEnd(1, startingMonth=1), datetime(2008, 2, 29), False), + (BQuarterEnd(1, startingMonth=1), datetime(2007, 3, 30), False), + (BQuarterEnd(1, startingMonth=1), datetime(2007, 3, 31), False), + (BQuarterEnd(1, startingMonth=1), datetime(2008, 4, 30), True), + (BQuarterEnd(1, startingMonth=1), datetime(2008, 5, 30), False), + (BQuarterEnd(1, startingMonth=1), datetime(2007, 6, 29), False), + (BQuarterEnd(1, startingMonth=1), datetime(2007, 6, 30), False), + (BQuarterEnd(1, startingMonth=2), datetime(2008, 1, 31), False), + (BQuarterEnd(1, startingMonth=2), datetime(2007, 12, 31), False), + (BQuarterEnd(1, startingMonth=2), datetime(2008, 2, 29), True), + (BQuarterEnd(1, startingMonth=2), datetime(2007, 3, 30), False), + (BQuarterEnd(1, startingMonth=2), datetime(2007, 3, 31), False), + (BQuarterEnd(1, startingMonth=2), datetime(2008, 4, 30), False), + (BQuarterEnd(1, startingMonth=2), datetime(2008, 5, 30), True), + (BQuarterEnd(1, startingMonth=2), datetime(2007, 6, 29), False), + (BQuarterEnd(1, startingMonth=2), datetime(2007, 6, 30), False), + (BQuarterEnd(1, startingMonth=3), datetime(2008, 1, 31), False), + (BQuarterEnd(1, startingMonth=3), datetime(2007, 12, 31), True), + (BQuarterEnd(1, startingMonth=3), datetime(2008, 2, 29), False), + (BQuarterEnd(1, startingMonth=3), datetime(2007, 3, 30), True), + (BQuarterEnd(1, startingMonth=3), datetime(2007, 3, 31), False), + (BQuarterEnd(1, startingMonth=3), datetime(2008, 4, 30), False), + (BQuarterEnd(1, startingMonth=3), datetime(2008, 5, 30), False), + (BQuarterEnd(1, startingMonth=3), datetime(2007, 6, 29), True), + (BQuarterEnd(1, startingMonth=3), datetime(2007, 6, 30), False), + ] + + @pytest.mark.parametrize("case", on_offset_cases) + def test_is_on_offset(self, case): + offset, dt, expected = case + assert_is_on_offset(offset, dt, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_business_year.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_business_year.py new file mode 100644 index 0000000000000000000000000000000000000000..3b7a1025cc19c9c1c966b9448ceffdb12dcd8159 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_business_year.py @@ -0,0 +1,215 @@ +""" +Tests for the following offsets: +- BYearBegin +- BYearEnd +""" +from __future__ import annotations + +from datetime import datetime + +import pytest + +from pandas.tests.tseries.offsets.common import ( + assert_is_on_offset, + assert_offset_equal, +) + +from pandas.tseries.offsets import ( + BYearBegin, + BYearEnd, +) + + +class TestBYearBegin: + def test_misspecified(self): + msg = "Month must go from 1 to 12" + with pytest.raises(ValueError, match=msg): + BYearBegin(month=13) + with pytest.raises(ValueError, match=msg): + BYearEnd(month=13) + + offset_cases = [] + offset_cases.append( + ( + BYearBegin(), + { + datetime(2008, 1, 1): datetime(2009, 1, 1), + datetime(2008, 6, 30): datetime(2009, 1, 1), + datetime(2008, 12, 31): datetime(2009, 1, 1), + datetime(2011, 1, 1): datetime(2011, 1, 3), + datetime(2011, 1, 3): datetime(2012, 1, 2), + datetime(2005, 12, 30): datetime(2006, 1, 2), + datetime(2005, 12, 31): datetime(2006, 1, 2), + }, + ) + ) + + offset_cases.append( + ( + BYearBegin(0), + { + datetime(2008, 1, 1): datetime(2008, 1, 1), + datetime(2008, 6, 30): datetime(2009, 1, 1), + datetime(2008, 12, 31): datetime(2009, 1, 1), + datetime(2005, 12, 30): datetime(2006, 1, 2), + datetime(2005, 12, 31): datetime(2006, 1, 2), + }, + ) + ) + + offset_cases.append( + ( + BYearBegin(-1), + { + datetime(2007, 1, 1): datetime(2006, 1, 2), + datetime(2009, 1, 4): datetime(2009, 1, 1), + datetime(2009, 1, 1): datetime(2008, 1, 1), + datetime(2008, 6, 30): datetime(2008, 1, 1), + datetime(2008, 12, 31): datetime(2008, 1, 1), + datetime(2006, 12, 29): datetime(2006, 1, 2), + datetime(2006, 12, 30): datetime(2006, 1, 2), + datetime(2006, 1, 1): datetime(2005, 1, 3), + }, + ) + ) + + offset_cases.append( + ( + BYearBegin(-2), + { + datetime(2007, 1, 1): datetime(2005, 1, 3), + datetime(2007, 6, 30): datetime(2006, 1, 2), + datetime(2008, 12, 31): datetime(2007, 1, 1), + }, + ) + ) + + @pytest.mark.parametrize("case", offset_cases) + def test_offset(self, case): + offset, cases = case + for base, expected in cases.items(): + assert_offset_equal(offset, base, expected) + + +class TestBYearEnd: + offset_cases = [] + offset_cases.append( + ( + BYearEnd(), + { + datetime(2008, 1, 1): datetime(2008, 12, 31), + datetime(2008, 6, 30): datetime(2008, 12, 31), + datetime(2008, 12, 31): datetime(2009, 12, 31), + datetime(2005, 12, 30): datetime(2006, 12, 29), + datetime(2005, 12, 31): datetime(2006, 12, 29), + }, + ) + ) + + offset_cases.append( + ( + BYearEnd(0), + { + datetime(2008, 1, 1): datetime(2008, 12, 31), + datetime(2008, 6, 30): datetime(2008, 12, 31), + datetime(2008, 12, 31): datetime(2008, 12, 31), + datetime(2005, 12, 31): datetime(2006, 12, 29), + }, + ) + ) + + offset_cases.append( + ( + BYearEnd(-1), + { + datetime(2007, 1, 1): datetime(2006, 12, 29), + datetime(2008, 6, 30): datetime(2007, 12, 31), + datetime(2008, 12, 31): datetime(2007, 12, 31), + datetime(2006, 12, 29): datetime(2005, 12, 30), + datetime(2006, 12, 30): datetime(2006, 12, 29), + datetime(2007, 1, 1): datetime(2006, 12, 29), + }, + ) + ) + + offset_cases.append( + ( + BYearEnd(-2), + { + datetime(2007, 1, 1): datetime(2005, 12, 30), + datetime(2008, 6, 30): datetime(2006, 12, 29), + datetime(2008, 12, 31): datetime(2006, 12, 29), + }, + ) + ) + + @pytest.mark.parametrize("case", offset_cases) + def test_offset(self, case): + offset, cases = case + for base, expected in cases.items(): + assert_offset_equal(offset, base, expected) + + on_offset_cases = [ + (BYearEnd(), datetime(2007, 12, 31), True), + (BYearEnd(), datetime(2008, 1, 1), False), + (BYearEnd(), datetime(2006, 12, 31), False), + (BYearEnd(), datetime(2006, 12, 29), True), + ] + + @pytest.mark.parametrize("case", on_offset_cases) + def test_is_on_offset(self, case): + offset, dt, expected = case + assert_is_on_offset(offset, dt, expected) + + +class TestBYearEndLagged: + def test_bad_month_fail(self): + msg = "Month must go from 1 to 12" + with pytest.raises(ValueError, match=msg): + BYearEnd(month=13) + with pytest.raises(ValueError, match=msg): + BYearEnd(month=0) + + offset_cases = [] + offset_cases.append( + ( + BYearEnd(month=6), + { + datetime(2008, 1, 1): datetime(2008, 6, 30), + datetime(2007, 6, 30): datetime(2008, 6, 30), + }, + ) + ) + + offset_cases.append( + ( + BYearEnd(n=-1, month=6), + { + datetime(2008, 1, 1): datetime(2007, 6, 29), + datetime(2007, 6, 30): datetime(2007, 6, 29), + }, + ) + ) + + @pytest.mark.parametrize("case", offset_cases) + def test_offset(self, case): + offset, cases = case + for base, expected in cases.items(): + assert_offset_equal(offset, base, expected) + + def test_roll(self): + offset = BYearEnd(month=6) + date = datetime(2009, 11, 30) + + assert offset.rollforward(date) == datetime(2010, 6, 30) + assert offset.rollback(date) == datetime(2009, 6, 30) + + on_offset_cases = [ + (BYearEnd(month=2), datetime(2007, 2, 28), True), + (BYearEnd(month=6), datetime(2007, 6, 30), False), + ] + + @pytest.mark.parametrize("case", on_offset_cases) + def test_is_on_offset(self, case): + offset, dt, expected = case + assert_is_on_offset(offset, dt, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_common.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_common.py new file mode 100644 index 0000000000000000000000000000000000000000..aa4e22f71ad66147d5b9893ead4dc250d1de0ed3 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_common.py @@ -0,0 +1,268 @@ +from datetime import datetime + +from dateutil.tz.tz import tzlocal +import pytest + +from pandas._libs.tslibs import ( + OutOfBoundsDatetime, + Timestamp, +) +from pandas.compat import ( + IS64, + is_platform_windows, +) + +from pandas.tseries.offsets import ( + FY5253, + BDay, + BMonthBegin, + BMonthEnd, + BQuarterBegin, + BQuarterEnd, + BusinessHour, + BYearBegin, + BYearEnd, + CBMonthBegin, + CBMonthEnd, + CDay, + CustomBusinessHour, + DateOffset, + FY5253Quarter, + LastWeekOfMonth, + MonthBegin, + MonthEnd, + QuarterEnd, + SemiMonthBegin, + SemiMonthEnd, + Week, + WeekOfMonth, + YearBegin, + YearEnd, +) + + +def _get_offset(klass, value=1, normalize=False): + # create instance from offset class + if klass is FY5253: + klass = klass( + n=value, + startingMonth=1, + weekday=1, + variation="last", + normalize=normalize, + ) + elif klass is FY5253Quarter: + klass = klass( + n=value, + startingMonth=1, + weekday=1, + qtr_with_extra_week=1, + variation="last", + normalize=normalize, + ) + elif klass is LastWeekOfMonth: + klass = klass(n=value, weekday=5, normalize=normalize) + elif klass is WeekOfMonth: + klass = klass(n=value, week=1, weekday=5, normalize=normalize) + elif klass is Week: + klass = klass(n=value, weekday=5, normalize=normalize) + elif klass is DateOffset: + klass = klass(days=value, normalize=normalize) + else: + klass = klass(value, normalize=normalize) + return klass + + +@pytest.fixture( + params=[ + BDay, + BusinessHour, + BMonthEnd, + BMonthBegin, + BQuarterEnd, + BQuarterBegin, + BYearEnd, + BYearBegin, + CDay, + CustomBusinessHour, + CBMonthEnd, + CBMonthBegin, + MonthEnd, + MonthBegin, + SemiMonthBegin, + SemiMonthEnd, + QuarterEnd, + LastWeekOfMonth, + WeekOfMonth, + Week, + YearBegin, + YearEnd, + FY5253, + FY5253Quarter, + DateOffset, + ] +) +def _offset(request): + return request.param + + +@pytest.fixture +def dt(_offset): + if _offset in (CBMonthBegin, CBMonthEnd, BDay): + return Timestamp(2008, 1, 1) + elif _offset is (CustomBusinessHour, BusinessHour): + return Timestamp(2014, 7, 1, 10, 00) + return Timestamp(2008, 1, 2) + + +def test_apply_out_of_range(request, tz_naive_fixture, _offset): + tz = tz_naive_fixture + + # try to create an out-of-bounds result timestamp; if we can't create + # the offset skip + try: + if _offset in (BusinessHour, CustomBusinessHour): + # Using 10000 in BusinessHour fails in tz check because of DST + # difference + offset = _get_offset(_offset, value=100000) + else: + offset = _get_offset(_offset, value=10000) + + result = Timestamp("20080101") + offset + assert isinstance(result, datetime) + assert result.tzinfo is None + + # Check tz is preserved + t = Timestamp("20080101", tz=tz) + result = t + offset + assert isinstance(result, datetime) + if tz is not None: + assert t.tzinfo is not None + + if isinstance(tz, tzlocal) and not IS64 and _offset is not DateOffset: + # If we hit OutOfBoundsDatetime on non-64 bit machines + # we'll drop out of the try clause before the next test + request.applymarker( + pytest.mark.xfail(reason="OverflowError inside tzlocal past 2038") + ) + elif ( + isinstance(tz, tzlocal) + and is_platform_windows() + and _offset in (QuarterEnd, BQuarterBegin, BQuarterEnd) + ): + request.applymarker( + pytest.mark.xfail(reason="After GH#49737 t.tzinfo is None on CI") + ) + assert str(t.tzinfo) == str(result.tzinfo) + + except OutOfBoundsDatetime: + pass + except (ValueError, KeyError): + # we are creating an invalid offset + # so ignore + pass + + +def test_offsets_compare_equal(_offset): + # root cause of GH#456: __ne__ was not implemented + offset1 = _offset() + offset2 = _offset() + assert not offset1 != offset2 + assert offset1 == offset2 + + +@pytest.mark.parametrize( + "date, offset2", + [ + [Timestamp(2008, 1, 1), BDay(2)], + [Timestamp(2014, 7, 1, 10, 00), BusinessHour(n=3)], + [ + Timestamp(2014, 7, 1, 10), + CustomBusinessHour( + holidays=["2014-06-27", Timestamp(2014, 6, 30), Timestamp("2014-07-02")] + ), + ], + [Timestamp(2008, 1, 2), SemiMonthEnd(2)], + [Timestamp(2008, 1, 2), SemiMonthBegin(2)], + [Timestamp(2008, 1, 2), Week(2)], + [Timestamp(2008, 1, 2), WeekOfMonth(2)], + [Timestamp(2008, 1, 2), LastWeekOfMonth(2)], + ], +) +def test_rsub(date, offset2): + assert date - offset2 == (-offset2)._apply(date) + + +@pytest.mark.parametrize( + "date, offset2", + [ + [Timestamp(2008, 1, 1), BDay(2)], + [Timestamp(2014, 7, 1, 10, 00), BusinessHour(n=3)], + [ + Timestamp(2014, 7, 1, 10), + CustomBusinessHour( + holidays=["2014-06-27", Timestamp(2014, 6, 30), Timestamp("2014-07-02")] + ), + ], + [Timestamp(2008, 1, 2), SemiMonthEnd(2)], + [Timestamp(2008, 1, 2), SemiMonthBegin(2)], + [Timestamp(2008, 1, 2), Week(2)], + [Timestamp(2008, 1, 2), WeekOfMonth(2)], + [Timestamp(2008, 1, 2), LastWeekOfMonth(2)], + ], +) +def test_radd(date, offset2): + assert date + offset2 == offset2 + date + + +@pytest.mark.parametrize( + "date, offset_box, offset2", + [ + [Timestamp(2008, 1, 1), BDay, BDay(2)], + [Timestamp(2008, 1, 2), SemiMonthEnd, SemiMonthEnd(2)], + [Timestamp(2008, 1, 2), SemiMonthBegin, SemiMonthBegin(2)], + [Timestamp(2008, 1, 2), Week, Week(2)], + [Timestamp(2008, 1, 2), WeekOfMonth, WeekOfMonth(2)], + [Timestamp(2008, 1, 2), LastWeekOfMonth, LastWeekOfMonth(2)], + ], +) +def test_sub(date, offset_box, offset2): + off = offset2 + msg = "Cannot subtract datetime from offset" + with pytest.raises(TypeError, match=msg): + off - date + + assert 2 * off - off == off + assert date - offset2 == date + offset_box(-2) + assert date - offset2 == date - (2 * off - off) + + +@pytest.mark.parametrize( + "offset_box, offset1", + [ + [BDay, BDay()], + [LastWeekOfMonth, LastWeekOfMonth()], + [WeekOfMonth, WeekOfMonth()], + [Week, Week()], + [SemiMonthBegin, SemiMonthBegin()], + [SemiMonthEnd, SemiMonthEnd()], + [CustomBusinessHour, CustomBusinessHour(weekmask="Tue Wed Thu Fri")], + [BusinessHour, BusinessHour()], + ], +) +def test_Mult1(offset_box, offset1): + dt = Timestamp(2008, 1, 2) + assert dt + 10 * offset1 == dt + offset_box(10) + assert dt + 5 * offset1 == dt + offset_box(5) + + +def test_compare_str(_offset): + # GH#23524 + # comparing to strings that cannot be cast to DateOffsets should + # not raise for __eq__ or __ne__ + off = _get_offset(_offset) + + assert not off == "infer" + assert off != "foo" + # Note: inequalities are only implemented for Tick subclasses; + # tests for this are in test_ticks diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_custom_business_day.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_custom_business_day.py new file mode 100644 index 0000000000000000000000000000000000000000..519fb712d041534b6e96e41539fb7660e6c14114 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_custom_business_day.py @@ -0,0 +1,98 @@ +""" +Tests for offsets.CustomBusinessDay / CDay +""" +from datetime import ( + datetime, + timedelta, +) + +import numpy as np +import pytest + +from pandas._libs.tslibs.offsets import CDay + +from pandas import ( + _testing as tm, + read_pickle, +) +from pandas.tests.tseries.offsets.common import assert_offset_equal + +from pandas.tseries.holiday import USFederalHolidayCalendar + + +@pytest.fixture +def offset(): + return CDay() + + +@pytest.fixture +def offset2(): + return CDay(2) + + +class TestCustomBusinessDay: + def test_repr(self, offset, offset2): + assert repr(offset) == "" + assert repr(offset2) == "<2 * CustomBusinessDays>" + + expected = "" + assert repr(offset + timedelta(1)) == expected + + def test_holidays(self): + # Define a TradingDay offset + holidays = ["2012-05-01", datetime(2013, 5, 1), np.datetime64("2014-05-01")] + tday = CDay(holidays=holidays) + for year in range(2012, 2015): + dt = datetime(year, 4, 30) + xp = datetime(year, 5, 2) + rs = dt + tday + assert rs == xp + + def test_weekmask(self): + weekmask_saudi = "Sat Sun Mon Tue Wed" # Thu-Fri Weekend + weekmask_uae = "1111001" # Fri-Sat Weekend + weekmask_egypt = [1, 1, 1, 1, 0, 0, 1] # Fri-Sat Weekend + bday_saudi = CDay(weekmask=weekmask_saudi) + bday_uae = CDay(weekmask=weekmask_uae) + bday_egypt = CDay(weekmask=weekmask_egypt) + dt = datetime(2013, 5, 1) + xp_saudi = datetime(2013, 5, 4) + xp_uae = datetime(2013, 5, 2) + xp_egypt = datetime(2013, 5, 2) + assert xp_saudi == dt + bday_saudi + assert xp_uae == dt + bday_uae + assert xp_egypt == dt + bday_egypt + xp2 = datetime(2013, 5, 5) + assert xp2 == dt + 2 * bday_saudi + assert xp2 == dt + 2 * bday_uae + assert xp2 == dt + 2 * bday_egypt + + def test_weekmask_and_holidays(self): + weekmask_egypt = "Sun Mon Tue Wed Thu" # Fri-Sat Weekend + holidays = ["2012-05-01", datetime(2013, 5, 1), np.datetime64("2014-05-01")] + bday_egypt = CDay(holidays=holidays, weekmask=weekmask_egypt) + dt = datetime(2013, 4, 30) + xp_egypt = datetime(2013, 5, 5) + assert xp_egypt == dt + 2 * bday_egypt + + @pytest.mark.filterwarnings("ignore:Non:pandas.errors.PerformanceWarning") + def test_calendar(self): + calendar = USFederalHolidayCalendar() + dt = datetime(2014, 1, 17) + assert_offset_equal(CDay(calendar=calendar), dt, datetime(2014, 1, 21)) + + def test_roundtrip_pickle(self, offset, offset2): + def _check_roundtrip(obj): + unpickled = tm.round_trip_pickle(obj) + assert unpickled == obj + + _check_roundtrip(offset) + _check_roundtrip(offset2) + _check_roundtrip(offset * 2) + + def test_pickle_compat_0_14_1(self, datapath): + hdays = [datetime(2013, 1, 1) for ele in range(4)] + pth = datapath("tseries", "offsets", "data", "cday-0.14.1.pickle") + cday0_14_1 = read_pickle(pth) + cday = CDay(holidays=hdays) + assert cday == cday0_14_1 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_custom_business_hour.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_custom_business_hour.py new file mode 100644 index 0000000000000000000000000000000000000000..55a184f95c2d8681fca77d74827dd248a20587f1 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_custom_business_hour.py @@ -0,0 +1,329 @@ +""" +Tests for offsets.CustomBusinessHour +""" +from __future__ import annotations + +from datetime import ( + datetime, + time as dt_time, +) + +import numpy as np +import pytest + +from pandas._libs.tslibs import Timestamp +from pandas._libs.tslibs.offsets import ( + BusinessHour, + CustomBusinessHour, + Nano, +) + +from pandas.tests.tseries.offsets.common import assert_offset_equal + +from pandas.tseries.holiday import USFederalHolidayCalendar + +holidays = ["2014-06-27", datetime(2014, 6, 30), np.datetime64("2014-07-02")] + + +@pytest.fixture +def dt(): + return datetime(2014, 7, 1, 10, 00) + + +@pytest.fixture +def _offset(): + return CustomBusinessHour + + +# 2014 Calendar to check custom holidays +# Sun Mon Tue Wed Thu Fri Sat +# 6/22 23 24 25 26 27 28 +# 29 30 7/1 2 3 4 5 +# 6 7 8 9 10 11 12 +@pytest.fixture +def offset1(): + return CustomBusinessHour(weekmask="Tue Wed Thu Fri") + + +@pytest.fixture +def offset2(): + return CustomBusinessHour(holidays=holidays) + + +class TestCustomBusinessHour: + def test_constructor_errors(self): + msg = "time data must be specified only with hour and minute" + with pytest.raises(ValueError, match=msg): + CustomBusinessHour(start=dt_time(11, 0, 5)) + msg = "time data must match '%H:%M' format" + with pytest.raises(ValueError, match=msg): + CustomBusinessHour(start="AAA") + msg = "time data must match '%H:%M' format" + with pytest.raises(ValueError, match=msg): + CustomBusinessHour(start="14:00:05") + + def test_different_normalize_equals(self, _offset): + # GH#21404 changed __eq__ to return False when `normalize` does not match + offset = _offset() + offset2 = _offset(normalize=True) + assert offset != offset2 + + def test_repr(self, offset1, offset2): + assert repr(offset1) == "" + assert repr(offset2) == "" + + def test_with_offset(self, dt): + expected = Timestamp("2014-07-01 13:00") + + assert dt + CustomBusinessHour() * 3 == expected + assert dt + CustomBusinessHour(n=3) == expected + + def test_eq(self, offset1, offset2): + for offset in [offset1, offset2]: + assert offset == offset + + assert CustomBusinessHour() != CustomBusinessHour(-1) + assert CustomBusinessHour(start="09:00") == CustomBusinessHour() + assert CustomBusinessHour(start="09:00") != CustomBusinessHour(start="09:01") + assert CustomBusinessHour(start="09:00", end="17:00") != CustomBusinessHour( + start="17:00", end="09:01" + ) + + assert CustomBusinessHour(weekmask="Tue Wed Thu Fri") != CustomBusinessHour( + weekmask="Mon Tue Wed Thu Fri" + ) + assert CustomBusinessHour(holidays=["2014-06-27"]) != CustomBusinessHour( + holidays=["2014-06-28"] + ) + + def test_hash(self, offset1, offset2): + assert hash(offset1) == hash(offset1) + assert hash(offset2) == hash(offset2) + + def test_add_dateime(self, dt, offset1, offset2): + assert offset1 + dt == datetime(2014, 7, 1, 11) + assert offset2 + dt == datetime(2014, 7, 1, 11) + + def testRollback1(self, dt, offset1, offset2): + assert offset1.rollback(dt) == dt + assert offset2.rollback(dt) == dt + + d = datetime(2014, 7, 1, 0) + + # 2014/07/01 is Tuesday, 06/30 is Monday(holiday) + assert offset1.rollback(d) == datetime(2014, 6, 27, 17) + + # 2014/6/30 and 2014/6/27 are holidays + assert offset2.rollback(d) == datetime(2014, 6, 26, 17) + + def testRollback2(self, _offset): + assert _offset(-3).rollback(datetime(2014, 7, 5, 15, 0)) == datetime( + 2014, 7, 4, 17, 0 + ) + + def testRollforward1(self, dt, offset1, offset2): + assert offset1.rollforward(dt) == dt + assert offset2.rollforward(dt) == dt + + d = datetime(2014, 7, 1, 0) + assert offset1.rollforward(d) == datetime(2014, 7, 1, 9) + assert offset2.rollforward(d) == datetime(2014, 7, 1, 9) + + def testRollforward2(self, _offset): + assert _offset(-3).rollforward(datetime(2014, 7, 5, 16, 0)) == datetime( + 2014, 7, 7, 9 + ) + + def test_roll_date_object(self): + offset = BusinessHour() + + dt = datetime(2014, 7, 6, 15, 0) + + result = offset.rollback(dt) + assert result == datetime(2014, 7, 4, 17) + + result = offset.rollforward(dt) + assert result == datetime(2014, 7, 7, 9) + + normalize_cases = [ + ( + CustomBusinessHour(normalize=True, holidays=holidays), + { + datetime(2014, 7, 1, 8): datetime(2014, 7, 1), + datetime(2014, 7, 1, 17): datetime(2014, 7, 3), + datetime(2014, 7, 1, 16): datetime(2014, 7, 3), + datetime(2014, 7, 1, 23): datetime(2014, 7, 3), + datetime(2014, 7, 1, 0): datetime(2014, 7, 1), + datetime(2014, 7, 4, 15): datetime(2014, 7, 4), + datetime(2014, 7, 4, 15, 59): datetime(2014, 7, 4), + datetime(2014, 7, 4, 16, 30): datetime(2014, 7, 7), + datetime(2014, 7, 5, 23): datetime(2014, 7, 7), + datetime(2014, 7, 6, 10): datetime(2014, 7, 7), + }, + ), + ( + CustomBusinessHour(-1, normalize=True, holidays=holidays), + { + datetime(2014, 7, 1, 8): datetime(2014, 6, 26), + datetime(2014, 7, 1, 17): datetime(2014, 7, 1), + datetime(2014, 7, 1, 16): datetime(2014, 7, 1), + datetime(2014, 7, 1, 10): datetime(2014, 6, 26), + datetime(2014, 7, 1, 0): datetime(2014, 6, 26), + datetime(2014, 7, 7, 10): datetime(2014, 7, 4), + datetime(2014, 7, 7, 10, 1): datetime(2014, 7, 7), + datetime(2014, 7, 5, 23): datetime(2014, 7, 4), + datetime(2014, 7, 6, 10): datetime(2014, 7, 4), + }, + ), + ( + CustomBusinessHour( + 1, normalize=True, start="17:00", end="04:00", holidays=holidays + ), + { + datetime(2014, 7, 1, 8): datetime(2014, 7, 1), + datetime(2014, 7, 1, 17): datetime(2014, 7, 1), + datetime(2014, 7, 1, 23): datetime(2014, 7, 2), + datetime(2014, 7, 2, 2): datetime(2014, 7, 2), + datetime(2014, 7, 2, 3): datetime(2014, 7, 3), + datetime(2014, 7, 4, 23): datetime(2014, 7, 5), + datetime(2014, 7, 5, 2): datetime(2014, 7, 5), + datetime(2014, 7, 7, 2): datetime(2014, 7, 7), + datetime(2014, 7, 7, 17): datetime(2014, 7, 7), + }, + ), + ] + + @pytest.mark.parametrize("norm_cases", normalize_cases) + def test_normalize(self, norm_cases): + offset, cases = norm_cases + for dt, expected in cases.items(): + assert offset._apply(dt) == expected + + @pytest.mark.parametrize( + "dt, expected", + [ + [datetime(2014, 7, 1, 9), False], + [datetime(2014, 7, 1, 10), True], + [datetime(2014, 7, 1, 15), True], + [datetime(2014, 7, 1, 15, 1), False], + [datetime(2014, 7, 5, 12), False], + [datetime(2014, 7, 6, 12), False], + ], + ) + def test_is_on_offset(self, dt, expected): + offset = CustomBusinessHour(start="10:00", end="15:00", holidays=holidays) + assert offset.is_on_offset(dt) == expected + + apply_cases = [ + ( + CustomBusinessHour(holidays=holidays), + { + datetime(2014, 7, 1, 11): datetime(2014, 7, 1, 12), + datetime(2014, 7, 1, 13): datetime(2014, 7, 1, 14), + datetime(2014, 7, 1, 15): datetime(2014, 7, 1, 16), + datetime(2014, 7, 1, 19): datetime(2014, 7, 3, 10), + datetime(2014, 7, 1, 16): datetime(2014, 7, 3, 9), + datetime(2014, 7, 1, 16, 30, 15): datetime(2014, 7, 3, 9, 30, 15), + datetime(2014, 7, 1, 17): datetime(2014, 7, 3, 10), + datetime(2014, 7, 2, 11): datetime(2014, 7, 3, 10), + # out of business hours + datetime(2014, 7, 2, 8): datetime(2014, 7, 3, 10), + datetime(2014, 7, 2, 19): datetime(2014, 7, 3, 10), + datetime(2014, 7, 2, 23): datetime(2014, 7, 3, 10), + datetime(2014, 7, 3, 0): datetime(2014, 7, 3, 10), + # saturday + datetime(2014, 7, 5, 15): datetime(2014, 7, 7, 10), + datetime(2014, 7, 4, 17): datetime(2014, 7, 7, 10), + datetime(2014, 7, 4, 16, 30): datetime(2014, 7, 7, 9, 30), + datetime(2014, 7, 4, 16, 30, 30): datetime(2014, 7, 7, 9, 30, 30), + }, + ), + ( + CustomBusinessHour(4, holidays=holidays), + { + datetime(2014, 7, 1, 11): datetime(2014, 7, 1, 15), + datetime(2014, 7, 1, 13): datetime(2014, 7, 3, 9), + datetime(2014, 7, 1, 15): datetime(2014, 7, 3, 11), + datetime(2014, 7, 1, 16): datetime(2014, 7, 3, 12), + datetime(2014, 7, 1, 17): datetime(2014, 7, 3, 13), + datetime(2014, 7, 2, 11): datetime(2014, 7, 3, 13), + datetime(2014, 7, 2, 8): datetime(2014, 7, 3, 13), + datetime(2014, 7, 2, 19): datetime(2014, 7, 3, 13), + datetime(2014, 7, 2, 23): datetime(2014, 7, 3, 13), + datetime(2014, 7, 3, 0): datetime(2014, 7, 3, 13), + datetime(2014, 7, 5, 15): datetime(2014, 7, 7, 13), + datetime(2014, 7, 4, 17): datetime(2014, 7, 7, 13), + datetime(2014, 7, 4, 16, 30): datetime(2014, 7, 7, 12, 30), + datetime(2014, 7, 4, 16, 30, 30): datetime(2014, 7, 7, 12, 30, 30), + }, + ), + ] + + @pytest.mark.parametrize("apply_case", apply_cases) + def test_apply(self, apply_case): + offset, cases = apply_case + for base, expected in cases.items(): + assert_offset_equal(offset, base, expected) + + nano_cases = [ + ( + CustomBusinessHour(holidays=holidays), + { + Timestamp("2014-07-01 15:00") + + Nano(5): Timestamp("2014-07-01 16:00") + + Nano(5), + Timestamp("2014-07-01 16:00") + + Nano(5): Timestamp("2014-07-03 09:00") + + Nano(5), + Timestamp("2014-07-01 16:00") + - Nano(5): Timestamp("2014-07-01 17:00") + - Nano(5), + }, + ), + ( + CustomBusinessHour(-1, holidays=holidays), + { + Timestamp("2014-07-01 15:00") + + Nano(5): Timestamp("2014-07-01 14:00") + + Nano(5), + Timestamp("2014-07-01 10:00") + + Nano(5): Timestamp("2014-07-01 09:00") + + Nano(5), + Timestamp("2014-07-01 10:00") + - Nano(5): Timestamp("2014-06-26 17:00") + - Nano(5), + }, + ), + ] + + @pytest.mark.parametrize("nano_case", nano_cases) + def test_apply_nanoseconds(self, nano_case): + offset, cases = nano_case + for base, expected in cases.items(): + assert_offset_equal(offset, base, expected) + + def test_us_federal_holiday_with_datetime(self): + # GH 16867 + bhour_us = CustomBusinessHour(calendar=USFederalHolidayCalendar()) + t0 = datetime(2014, 1, 17, 15) + result = t0 + bhour_us * 8 + expected = Timestamp("2014-01-21 15:00:00") + assert result == expected + + +@pytest.mark.parametrize( + "weekmask, expected_time, mult", + [ + ["Mon Tue Wed Thu Fri Sat", "2018-11-10 09:00:00", 10], + ["Tue Wed Thu Fri Sat", "2018-11-13 08:00:00", 18], + ], +) +def test_custom_businesshour_weekmask_and_holidays(weekmask, expected_time, mult): + # GH 23542 + holidays = ["2018-11-09"] + bh = CustomBusinessHour( + start="08:00", end="17:00", weekmask=weekmask, holidays=holidays + ) + result = Timestamp("2018-11-08 08:00") + mult * bh + expected = Timestamp(expected_time) + assert result == expected diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_custom_business_month.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_custom_business_month.py new file mode 100644 index 0000000000000000000000000000000000000000..d226302e042d325d78a953a651b24f85ad4f0468 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_custom_business_month.py @@ -0,0 +1,437 @@ +""" +Tests for the following offsets: +- CustomBusinessMonthBase +- CustomBusinessMonthBegin +- CustomBusinessMonthEnd +""" +from __future__ import annotations + +from datetime import ( + date, + datetime, + timedelta, +) + +import numpy as np +import pytest + +from pandas._libs.tslibs.offsets import ( + CBMonthBegin, + CBMonthEnd, + CDay, +) + +import pandas._testing as tm +from pandas.tests.tseries.offsets.common import ( + assert_is_on_offset, + assert_offset_equal, +) + +from pandas.tseries import offsets + + +@pytest.fixture +def dt(): + return datetime(2008, 1, 1) + + +class TestCommonCBM: + @pytest.mark.parametrize("offset2", [CBMonthBegin(2), CBMonthEnd(2)]) + def test_eq(self, offset2): + assert offset2 == offset2 + + @pytest.mark.parametrize("offset2", [CBMonthBegin(2), CBMonthEnd(2)]) + def test_hash(self, offset2): + assert hash(offset2) == hash(offset2) + + @pytest.mark.parametrize("_offset", [CBMonthBegin, CBMonthEnd]) + def test_roundtrip_pickle(self, _offset): + def _check_roundtrip(obj): + unpickled = tm.round_trip_pickle(obj) + assert unpickled == obj + + _check_roundtrip(_offset()) + _check_roundtrip(_offset(2)) + _check_roundtrip(_offset() * 2) + + @pytest.mark.parametrize("_offset", [CBMonthBegin, CBMonthEnd]) + def test_copy(self, _offset): + # GH 17452 + off = _offset(weekmask="Mon Wed Fri") + assert off == off.copy() + + +class TestCustomBusinessMonthBegin: + @pytest.fixture + def _offset(self): + return CBMonthBegin + + @pytest.fixture + def offset(self): + return CBMonthBegin() + + @pytest.fixture + def offset2(self): + return CBMonthBegin(2) + + def test_different_normalize_equals(self, _offset): + # GH#21404 changed __eq__ to return False when `normalize` does not match + offset = _offset() + offset2 = _offset(normalize=True) + assert offset != offset2 + + def test_repr(self, offset, offset2): + assert repr(offset) == "" + assert repr(offset2) == "<2 * CustomBusinessMonthBegins>" + + def test_add_datetime(self, dt, offset2): + assert offset2 + dt == datetime(2008, 3, 3) + + def testRollback1(self): + assert CDay(10).rollback(datetime(2007, 12, 31)) == datetime(2007, 12, 31) + + def testRollback2(self, dt): + assert CBMonthBegin(10).rollback(dt) == datetime(2008, 1, 1) + + def testRollforward1(self, dt): + assert CBMonthBegin(10).rollforward(dt) == datetime(2008, 1, 1) + + def test_roll_date_object(self): + offset = CBMonthBegin() + + dt = date(2012, 9, 15) + + result = offset.rollback(dt) + assert result == datetime(2012, 9, 3) + + result = offset.rollforward(dt) + assert result == datetime(2012, 10, 1) + + offset = offsets.Day() + result = offset.rollback(dt) + assert result == datetime(2012, 9, 15) + + result = offset.rollforward(dt) + assert result == datetime(2012, 9, 15) + + on_offset_cases = [ + (CBMonthBegin(), datetime(2008, 1, 1), True), + (CBMonthBegin(), datetime(2008, 1, 31), False), + ] + + @pytest.mark.parametrize("case", on_offset_cases) + def test_is_on_offset(self, case): + offset, dt, expected = case + assert_is_on_offset(offset, dt, expected) + + apply_cases = [ + ( + CBMonthBegin(), + { + datetime(2008, 1, 1): datetime(2008, 2, 1), + datetime(2008, 2, 7): datetime(2008, 3, 3), + }, + ), + ( + 2 * CBMonthBegin(), + { + datetime(2008, 1, 1): datetime(2008, 3, 3), + datetime(2008, 2, 7): datetime(2008, 4, 1), + }, + ), + ( + -CBMonthBegin(), + { + datetime(2008, 1, 1): datetime(2007, 12, 3), + datetime(2008, 2, 8): datetime(2008, 2, 1), + }, + ), + ( + -2 * CBMonthBegin(), + { + datetime(2008, 1, 1): datetime(2007, 11, 1), + datetime(2008, 2, 9): datetime(2008, 1, 1), + }, + ), + ( + CBMonthBegin(0), + { + datetime(2008, 1, 1): datetime(2008, 1, 1), + datetime(2008, 1, 7): datetime(2008, 2, 1), + }, + ), + ] + + @pytest.mark.parametrize("case", apply_cases) + def test_apply(self, case): + offset, cases = case + for base, expected in cases.items(): + assert_offset_equal(offset, base, expected) + + def test_apply_large_n(self): + dt = datetime(2012, 10, 23) + + result = dt + CBMonthBegin(10) + assert result == datetime(2013, 8, 1) + + result = dt + CDay(100) - CDay(100) + assert result == dt + + off = CBMonthBegin() * 6 + rs = datetime(2012, 1, 1) - off + xp = datetime(2011, 7, 1) + assert rs == xp + + st = datetime(2011, 12, 18) + rs = st + off + + xp = datetime(2012, 6, 1) + assert rs == xp + + def test_holidays(self): + # Define a TradingDay offset + holidays = ["2012-02-01", datetime(2012, 2, 2), np.datetime64("2012-03-01")] + bm_offset = CBMonthBegin(holidays=holidays) + dt = datetime(2012, 1, 1) + + assert dt + bm_offset == datetime(2012, 1, 2) + assert dt + 2 * bm_offset == datetime(2012, 2, 3) + + @pytest.mark.parametrize( + "case", + [ + ( + CBMonthBegin(n=1, offset=timedelta(days=5)), + { + datetime(2021, 3, 1): datetime(2021, 4, 1) + timedelta(days=5), + datetime(2021, 4, 17): datetime(2021, 5, 3) + timedelta(days=5), + }, + ), + ( + CBMonthBegin(n=2, offset=timedelta(days=40)), + { + datetime(2021, 3, 10): datetime(2021, 5, 3) + timedelta(days=40), + datetime(2021, 4, 30): datetime(2021, 6, 1) + timedelta(days=40), + }, + ), + ( + CBMonthBegin(n=1, offset=timedelta(days=-5)), + { + datetime(2021, 3, 1): datetime(2021, 4, 1) - timedelta(days=5), + datetime(2021, 4, 11): datetime(2021, 5, 3) - timedelta(days=5), + }, + ), + ( + -2 * CBMonthBegin(n=1, offset=timedelta(days=10)), + { + datetime(2021, 3, 1): datetime(2021, 1, 1) + timedelta(days=10), + datetime(2021, 4, 3): datetime(2021, 3, 1) + timedelta(days=10), + }, + ), + ( + CBMonthBegin(n=0, offset=timedelta(days=1)), + { + datetime(2021, 3, 2): datetime(2021, 4, 1) + timedelta(days=1), + datetime(2021, 4, 1): datetime(2021, 4, 1) + timedelta(days=1), + }, + ), + ( + CBMonthBegin( + n=1, holidays=["2021-04-01", "2021-04-02"], offset=timedelta(days=1) + ), + { + datetime(2021, 3, 2): datetime(2021, 4, 5) + timedelta(days=1), + }, + ), + ], + ) + def test_apply_with_extra_offset(self, case): + offset, cases = case + for base, expected in cases.items(): + assert_offset_equal(offset, base, expected) + + +class TestCustomBusinessMonthEnd: + @pytest.fixture + def _offset(self): + return CBMonthEnd + + @pytest.fixture + def offset(self): + return CBMonthEnd() + + @pytest.fixture + def offset2(self): + return CBMonthEnd(2) + + def test_different_normalize_equals(self, _offset): + # GH#21404 changed __eq__ to return False when `normalize` does not match + offset = _offset() + offset2 = _offset(normalize=True) + assert offset != offset2 + + def test_repr(self, offset, offset2): + assert repr(offset) == "" + assert repr(offset2) == "<2 * CustomBusinessMonthEnds>" + + def test_add_datetime(self, dt, offset2): + assert offset2 + dt == datetime(2008, 2, 29) + + def testRollback1(self): + assert CDay(10).rollback(datetime(2007, 12, 31)) == datetime(2007, 12, 31) + + def testRollback2(self, dt): + assert CBMonthEnd(10).rollback(dt) == datetime(2007, 12, 31) + + def testRollforward1(self, dt): + assert CBMonthEnd(10).rollforward(dt) == datetime(2008, 1, 31) + + def test_roll_date_object(self): + offset = CBMonthEnd() + + dt = date(2012, 9, 15) + + result = offset.rollback(dt) + assert result == datetime(2012, 8, 31) + + result = offset.rollforward(dt) + assert result == datetime(2012, 9, 28) + + offset = offsets.Day() + result = offset.rollback(dt) + assert result == datetime(2012, 9, 15) + + result = offset.rollforward(dt) + assert result == datetime(2012, 9, 15) + + on_offset_cases = [ + (CBMonthEnd(), datetime(2008, 1, 31), True), + (CBMonthEnd(), datetime(2008, 1, 1), False), + ] + + @pytest.mark.parametrize("case", on_offset_cases) + def test_is_on_offset(self, case): + offset, dt, expected = case + assert_is_on_offset(offset, dt, expected) + + apply_cases = [ + ( + CBMonthEnd(), + { + datetime(2008, 1, 1): datetime(2008, 1, 31), + datetime(2008, 2, 7): datetime(2008, 2, 29), + }, + ), + ( + 2 * CBMonthEnd(), + { + datetime(2008, 1, 1): datetime(2008, 2, 29), + datetime(2008, 2, 7): datetime(2008, 3, 31), + }, + ), + ( + -CBMonthEnd(), + { + datetime(2008, 1, 1): datetime(2007, 12, 31), + datetime(2008, 2, 8): datetime(2008, 1, 31), + }, + ), + ( + -2 * CBMonthEnd(), + { + datetime(2008, 1, 1): datetime(2007, 11, 30), + datetime(2008, 2, 9): datetime(2007, 12, 31), + }, + ), + ( + CBMonthEnd(0), + { + datetime(2008, 1, 1): datetime(2008, 1, 31), + datetime(2008, 2, 7): datetime(2008, 2, 29), + }, + ), + ] + + @pytest.mark.parametrize("case", apply_cases) + def test_apply(self, case): + offset, cases = case + for base, expected in cases.items(): + assert_offset_equal(offset, base, expected) + + def test_apply_large_n(self): + dt = datetime(2012, 10, 23) + + result = dt + CBMonthEnd(10) + assert result == datetime(2013, 7, 31) + + result = dt + CDay(100) - CDay(100) + assert result == dt + + off = CBMonthEnd() * 6 + rs = datetime(2012, 1, 1) - off + xp = datetime(2011, 7, 29) + assert rs == xp + + st = datetime(2011, 12, 18) + rs = st + off + xp = datetime(2012, 5, 31) + assert rs == xp + + def test_holidays(self): + # Define a TradingDay offset + holidays = ["2012-01-31", datetime(2012, 2, 28), np.datetime64("2012-02-29")] + bm_offset = CBMonthEnd(holidays=holidays) + dt = datetime(2012, 1, 1) + assert dt + bm_offset == datetime(2012, 1, 30) + assert dt + 2 * bm_offset == datetime(2012, 2, 27) + + @pytest.mark.parametrize( + "case", + [ + ( + CBMonthEnd(n=1, offset=timedelta(days=5)), + { + datetime(2021, 3, 1): datetime(2021, 3, 31) + timedelta(days=5), + datetime(2021, 4, 17): datetime(2021, 4, 30) + timedelta(days=5), + }, + ), + ( + CBMonthEnd(n=2, offset=timedelta(days=40)), + { + datetime(2021, 3, 10): datetime(2021, 4, 30) + timedelta(days=40), + datetime(2021, 4, 30): datetime(2021, 6, 30) + timedelta(days=40), + }, + ), + ( + CBMonthEnd(n=1, offset=timedelta(days=-5)), + { + datetime(2021, 3, 1): datetime(2021, 3, 31) - timedelta(days=5), + datetime(2021, 4, 11): datetime(2021, 4, 30) - timedelta(days=5), + }, + ), + ( + -2 * CBMonthEnd(n=1, offset=timedelta(days=10)), + { + datetime(2021, 3, 1): datetime(2021, 1, 29) + timedelta(days=10), + datetime(2021, 4, 3): datetime(2021, 2, 26) + timedelta(days=10), + }, + ), + ( + CBMonthEnd(n=0, offset=timedelta(days=1)), + { + datetime(2021, 3, 2): datetime(2021, 3, 31) + timedelta(days=1), + datetime(2021, 4, 1): datetime(2021, 4, 30) + timedelta(days=1), + }, + ), + ( + CBMonthEnd(n=1, holidays=["2021-03-31"], offset=timedelta(days=1)), + { + datetime(2021, 3, 2): datetime(2021, 3, 30) + timedelta(days=1), + }, + ), + ], + ) + def test_apply_with_extra_offset(self, case): + offset, cases = case + for base, expected in cases.items(): + assert_offset_equal(offset, base, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_dst.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_dst.py new file mode 100644 index 0000000000000000000000000000000000000000..b22dc0b33081794cef587f0bbcf3271d35fd687b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_dst.py @@ -0,0 +1,260 @@ +""" +Tests for DateOffset additions over Daylight Savings Time +""" +from datetime import timedelta + +import pytest +import pytz + +from pandas._libs.tslibs import Timestamp +from pandas._libs.tslibs.offsets import ( + BMonthBegin, + BMonthEnd, + BQuarterBegin, + BQuarterEnd, + BYearBegin, + BYearEnd, + CBMonthBegin, + CBMonthEnd, + CustomBusinessDay, + DateOffset, + Day, + MonthBegin, + MonthEnd, + QuarterBegin, + QuarterEnd, + SemiMonthBegin, + SemiMonthEnd, + Week, + YearBegin, + YearEnd, +) +from pandas.errors import PerformanceWarning + +from pandas import DatetimeIndex +import pandas._testing as tm +from pandas.util.version import Version + +# error: Module has no attribute "__version__" +pytz_version = Version(pytz.__version__) # type: ignore[attr-defined] + + +def get_utc_offset_hours(ts): + # take a Timestamp and compute total hours of utc offset + o = ts.utcoffset() + return (o.days * 24 * 3600 + o.seconds) / 3600.0 + + +class TestDST: + # one microsecond before the DST transition + ts_pre_fallback = "2013-11-03 01:59:59.999999" + ts_pre_springfwd = "2013-03-10 01:59:59.999999" + + # test both basic names and dateutil timezones + timezone_utc_offsets = { + "US/Eastern": {"utc_offset_daylight": -4, "utc_offset_standard": -5}, + "dateutil/US/Pacific": {"utc_offset_daylight": -7, "utc_offset_standard": -8}, + } + valid_date_offsets_singular = [ + "weekday", + "day", + "hour", + "minute", + "second", + "microsecond", + ] + valid_date_offsets_plural = [ + "weeks", + "days", + "hours", + "minutes", + "seconds", + "milliseconds", + "microseconds", + ] + + def _test_all_offsets(self, n, **kwds): + valid_offsets = ( + self.valid_date_offsets_plural + if n > 1 + else self.valid_date_offsets_singular + ) + + for name in valid_offsets: + self._test_offset(offset_name=name, offset_n=n, **kwds) + + def _test_offset(self, offset_name, offset_n, tstart, expected_utc_offset): + offset = DateOffset(**{offset_name: offset_n}) + + if ( + offset_name in ["hour", "minute", "second", "microsecond"] + and offset_n == 1 + and tstart == Timestamp("2013-11-03 01:59:59.999999-0500", tz="US/Eastern") + ): + # This addition results in an ambiguous wall time + err_msg = { + "hour": "2013-11-03 01:59:59.999999", + "minute": "2013-11-03 01:01:59.999999", + "second": "2013-11-03 01:59:01.999999", + "microsecond": "2013-11-03 01:59:59.000001", + }[offset_name] + with pytest.raises(pytz.AmbiguousTimeError, match=err_msg): + tstart + offset + # While we're here, let's check that we get the same behavior in a + # vectorized path + dti = DatetimeIndex([tstart]) + warn_msg = "Non-vectorized DateOffset" + with pytest.raises(pytz.AmbiguousTimeError, match=err_msg): + with tm.assert_produces_warning(PerformanceWarning, match=warn_msg): + dti + offset + return + + t = tstart + offset + if expected_utc_offset is not None: + assert get_utc_offset_hours(t) == expected_utc_offset + + if offset_name == "weeks": + # dates should match + assert t.date() == timedelta(days=7 * offset.kwds["weeks"]) + tstart.date() + # expect the same day of week, hour of day, minute, second, ... + assert ( + t.dayofweek == tstart.dayofweek + and t.hour == tstart.hour + and t.minute == tstart.minute + and t.second == tstart.second + ) + elif offset_name == "days": + # dates should match + assert timedelta(offset.kwds["days"]) + tstart.date() == t.date() + # expect the same hour of day, minute, second, ... + assert ( + t.hour == tstart.hour + and t.minute == tstart.minute + and t.second == tstart.second + ) + elif offset_name in self.valid_date_offsets_singular: + # expect the singular offset value to match between tstart and t + datepart_offset = getattr( + t, offset_name if offset_name != "weekday" else "dayofweek" + ) + assert datepart_offset == offset.kwds[offset_name] + else: + # the offset should be the same as if it was done in UTC + assert t == (tstart.tz_convert("UTC") + offset).tz_convert("US/Pacific") + + def _make_timestamp(self, string, hrs_offset, tz): + if hrs_offset >= 0: + offset_string = f"{hrs_offset:02d}00" + else: + offset_string = f"-{(hrs_offset * -1):02}00" + return Timestamp(string + offset_string).tz_convert(tz) + + def test_springforward_plural(self): + # test moving from standard to daylight savings + for tz, utc_offsets in self.timezone_utc_offsets.items(): + hrs_pre = utc_offsets["utc_offset_standard"] + hrs_post = utc_offsets["utc_offset_daylight"] + self._test_all_offsets( + n=3, + tstart=self._make_timestamp(self.ts_pre_springfwd, hrs_pre, tz), + expected_utc_offset=hrs_post, + ) + + def test_fallback_singular(self): + # in the case of singular offsets, we don't necessarily know which utc + # offset the new Timestamp will wind up in (the tz for 1 month may be + # different from 1 second) so we don't specify an expected_utc_offset + for tz, utc_offsets in self.timezone_utc_offsets.items(): + hrs_pre = utc_offsets["utc_offset_standard"] + self._test_all_offsets( + n=1, + tstart=self._make_timestamp(self.ts_pre_fallback, hrs_pre, tz), + expected_utc_offset=None, + ) + + def test_springforward_singular(self): + for tz, utc_offsets in self.timezone_utc_offsets.items(): + hrs_pre = utc_offsets["utc_offset_standard"] + self._test_all_offsets( + n=1, + tstart=self._make_timestamp(self.ts_pre_springfwd, hrs_pre, tz), + expected_utc_offset=None, + ) + + offset_classes = { + MonthBegin: ["11/2/2012", "12/1/2012"], + MonthEnd: ["11/2/2012", "11/30/2012"], + BMonthBegin: ["11/2/2012", "12/3/2012"], + BMonthEnd: ["11/2/2012", "11/30/2012"], + CBMonthBegin: ["11/2/2012", "12/3/2012"], + CBMonthEnd: ["11/2/2012", "11/30/2012"], + SemiMonthBegin: ["11/2/2012", "11/15/2012"], + SemiMonthEnd: ["11/2/2012", "11/15/2012"], + Week: ["11/2/2012", "11/9/2012"], + YearBegin: ["11/2/2012", "1/1/2013"], + YearEnd: ["11/2/2012", "12/31/2012"], + BYearBegin: ["11/2/2012", "1/1/2013"], + BYearEnd: ["11/2/2012", "12/31/2012"], + QuarterBegin: ["11/2/2012", "12/1/2012"], + QuarterEnd: ["11/2/2012", "12/31/2012"], + BQuarterBegin: ["11/2/2012", "12/3/2012"], + BQuarterEnd: ["11/2/2012", "12/31/2012"], + Day: ["11/4/2012", "11/4/2012 23:00"], + }.items() + + @pytest.mark.parametrize("tup", offset_classes) + def test_all_offset_classes(self, tup): + offset, test_values = tup + + first = Timestamp(test_values[0], tz="US/Eastern") + offset() + second = Timestamp(test_values[1], tz="US/Eastern") + assert first == second + + +@pytest.mark.parametrize( + "original_dt, target_dt, offset, tz", + [ + pytest.param( + Timestamp("1900-01-01"), + Timestamp("1905-07-01"), + MonthBegin(66), + "Africa/Lagos", + marks=pytest.mark.xfail( + pytz_version < Version("2020.5") or pytz_version == Version("2022.2"), + reason="GH#41906: pytz utc transition dates changed", + ), + ), + ( + Timestamp("2021-10-01 01:15"), + Timestamp("2021-10-31 01:15"), + MonthEnd(1), + "Europe/London", + ), + ( + Timestamp("2010-12-05 02:59"), + Timestamp("2010-10-31 02:59"), + SemiMonthEnd(-3), + "Europe/Paris", + ), + ( + Timestamp("2021-10-31 01:20"), + Timestamp("2021-11-07 01:20"), + CustomBusinessDay(2, weekmask="Sun Mon"), + "US/Eastern", + ), + ( + Timestamp("2020-04-03 01:30"), + Timestamp("2020-11-01 01:30"), + YearBegin(1, month=11), + "America/Chicago", + ), + ], +) +def test_nontick_offset_with_ambiguous_time_error(original_dt, target_dt, offset, tz): + # .apply for non-Tick offsets throws AmbiguousTimeError when the target dt + # is dst-ambiguous + localized_dt = original_dt.tz_localize(tz) + + msg = f"Cannot infer dst time from {target_dt}, try using the 'ambiguous' argument" + with pytest.raises(pytz.AmbiguousTimeError, match=msg): + localized_dt + offset diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_easter.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_easter.py new file mode 100644 index 0000000000000000000000000000000000000000..d11a72cc1b9d54387a37d8e4102249c415c4b46e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_easter.py @@ -0,0 +1,33 @@ +""" +Tests for the following offsets: +- Easter +""" +from __future__ import annotations + +from datetime import datetime + +import pytest + +from pandas.tests.tseries.offsets.common import assert_offset_equal + +from pandas.tseries.offsets import Easter + + +class TestEaster: + @pytest.mark.parametrize( + "offset,date,expected", + [ + (Easter(), datetime(2010, 1, 1), datetime(2010, 4, 4)), + (Easter(), datetime(2010, 4, 5), datetime(2011, 4, 24)), + (Easter(2), datetime(2010, 1, 1), datetime(2011, 4, 24)), + (Easter(), datetime(2010, 4, 4), datetime(2011, 4, 24)), + (Easter(2), datetime(2010, 4, 4), datetime(2012, 4, 8)), + (-Easter(), datetime(2011, 1, 1), datetime(2010, 4, 4)), + (-Easter(), datetime(2010, 4, 5), datetime(2010, 4, 4)), + (-Easter(2), datetime(2011, 1, 1), datetime(2009, 4, 12)), + (-Easter(), datetime(2010, 4, 4), datetime(2009, 4, 12)), + (-Easter(2), datetime(2010, 4, 4), datetime(2008, 3, 23)), + ], + ) + def test_offset(self, offset, date, expected): + assert_offset_equal(offset, date, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_fiscal.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_fiscal.py new file mode 100644 index 0000000000000000000000000000000000000000..824e66a1ddef1b31708e53075949a9bba0114190 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_fiscal.py @@ -0,0 +1,656 @@ +""" +Tests for Fiscal Year and Fiscal Quarter offset classes +""" +from datetime import datetime + +from dateutil.relativedelta import relativedelta +import pytest + +from pandas import Timestamp +import pandas._testing as tm +from pandas.tests.tseries.offsets.common import ( + WeekDay, + assert_is_on_offset, + assert_offset_equal, +) + +from pandas.tseries.offsets import ( + FY5253, + FY5253Quarter, +) + + +def makeFY5253LastOfMonthQuarter(*args, **kwds): + return FY5253Quarter(*args, variation="last", **kwds) + + +def makeFY5253NearestEndMonthQuarter(*args, **kwds): + return FY5253Quarter(*args, variation="nearest", **kwds) + + +def makeFY5253NearestEndMonth(*args, **kwds): + return FY5253(*args, variation="nearest", **kwds) + + +def makeFY5253LastOfMonth(*args, **kwds): + return FY5253(*args, variation="last", **kwds) + + +def test_get_offset_name(): + assert ( + makeFY5253LastOfMonthQuarter( + weekday=1, startingMonth=3, qtr_with_extra_week=4 + ).freqstr + == "REQ-L-MAR-TUE-4" + ) + assert ( + makeFY5253NearestEndMonthQuarter( + weekday=1, startingMonth=3, qtr_with_extra_week=3 + ).freqstr + == "REQ-N-MAR-TUE-3" + ) + + +class TestFY5253LastOfMonth: + offset_lom_sat_aug = makeFY5253LastOfMonth(1, startingMonth=8, weekday=WeekDay.SAT) + offset_lom_sat_sep = makeFY5253LastOfMonth(1, startingMonth=9, weekday=WeekDay.SAT) + + on_offset_cases = [ + # From Wikipedia (see: + # https://en.wikipedia.org/wiki/4%E2%80%934%E2%80%935_calendar#Last_Saturday_of_the_month_at_fiscal_year_end) + (offset_lom_sat_aug, datetime(2006, 8, 26), True), + (offset_lom_sat_aug, datetime(2007, 8, 25), True), + (offset_lom_sat_aug, datetime(2008, 8, 30), True), + (offset_lom_sat_aug, datetime(2009, 8, 29), True), + (offset_lom_sat_aug, datetime(2010, 8, 28), True), + (offset_lom_sat_aug, datetime(2011, 8, 27), True), + (offset_lom_sat_aug, datetime(2012, 8, 25), True), + (offset_lom_sat_aug, datetime(2013, 8, 31), True), + (offset_lom_sat_aug, datetime(2014, 8, 30), True), + (offset_lom_sat_aug, datetime(2015, 8, 29), True), + (offset_lom_sat_aug, datetime(2016, 8, 27), True), + (offset_lom_sat_aug, datetime(2017, 8, 26), True), + (offset_lom_sat_aug, datetime(2018, 8, 25), True), + (offset_lom_sat_aug, datetime(2019, 8, 31), True), + (offset_lom_sat_aug, datetime(2006, 8, 27), False), + (offset_lom_sat_aug, datetime(2007, 8, 28), False), + (offset_lom_sat_aug, datetime(2008, 8, 31), False), + (offset_lom_sat_aug, datetime(2009, 8, 30), False), + (offset_lom_sat_aug, datetime(2010, 8, 29), False), + (offset_lom_sat_aug, datetime(2011, 8, 28), False), + (offset_lom_sat_aug, datetime(2006, 8, 25), False), + (offset_lom_sat_aug, datetime(2007, 8, 24), False), + (offset_lom_sat_aug, datetime(2008, 8, 29), False), + (offset_lom_sat_aug, datetime(2009, 8, 28), False), + (offset_lom_sat_aug, datetime(2010, 8, 27), False), + (offset_lom_sat_aug, datetime(2011, 8, 26), False), + (offset_lom_sat_aug, datetime(2019, 8, 30), False), + # From GMCR (see for example: + # http://yahoo.brand.edgar-online.com/Default.aspx? + # companyid=3184&formtypeID=7) + (offset_lom_sat_sep, datetime(2010, 9, 25), True), + (offset_lom_sat_sep, datetime(2011, 9, 24), True), + (offset_lom_sat_sep, datetime(2012, 9, 29), True), + ] + + @pytest.mark.parametrize("case", on_offset_cases) + def test_is_on_offset(self, case): + offset, dt, expected = case + assert_is_on_offset(offset, dt, expected) + + def test_apply(self): + offset_lom_aug_sat = makeFY5253LastOfMonth(startingMonth=8, weekday=WeekDay.SAT) + offset_lom_aug_sat_1 = makeFY5253LastOfMonth( + n=1, startingMonth=8, weekday=WeekDay.SAT + ) + + date_seq_lom_aug_sat = [ + datetime(2006, 8, 26), + datetime(2007, 8, 25), + datetime(2008, 8, 30), + datetime(2009, 8, 29), + datetime(2010, 8, 28), + datetime(2011, 8, 27), + datetime(2012, 8, 25), + datetime(2013, 8, 31), + datetime(2014, 8, 30), + datetime(2015, 8, 29), + datetime(2016, 8, 27), + ] + + tests = [ + (offset_lom_aug_sat, date_seq_lom_aug_sat), + (offset_lom_aug_sat_1, date_seq_lom_aug_sat), + (offset_lom_aug_sat, [datetime(2006, 8, 25)] + date_seq_lom_aug_sat), + (offset_lom_aug_sat_1, [datetime(2006, 8, 27)] + date_seq_lom_aug_sat[1:]), + ( + makeFY5253LastOfMonth(n=-1, startingMonth=8, weekday=WeekDay.SAT), + list(reversed(date_seq_lom_aug_sat)), + ), + ] + for test in tests: + offset, data = test + current = data[0] + for datum in data[1:]: + current = current + offset + assert current == datum + + +class TestFY5253NearestEndMonth: + def test_get_year_end(self): + assert makeFY5253NearestEndMonth( + startingMonth=8, weekday=WeekDay.SAT + ).get_year_end(datetime(2013, 1, 1)) == datetime(2013, 8, 31) + assert makeFY5253NearestEndMonth( + startingMonth=8, weekday=WeekDay.SUN + ).get_year_end(datetime(2013, 1, 1)) == datetime(2013, 9, 1) + assert makeFY5253NearestEndMonth( + startingMonth=8, weekday=WeekDay.FRI + ).get_year_end(datetime(2013, 1, 1)) == datetime(2013, 8, 30) + + offset_n = FY5253(weekday=WeekDay.TUE, startingMonth=12, variation="nearest") + assert offset_n.get_year_end(datetime(2012, 1, 1)) == datetime(2013, 1, 1) + assert offset_n.get_year_end(datetime(2012, 1, 10)) == datetime(2013, 1, 1) + + assert offset_n.get_year_end(datetime(2013, 1, 1)) == datetime(2013, 12, 31) + assert offset_n.get_year_end(datetime(2013, 1, 2)) == datetime(2013, 12, 31) + assert offset_n.get_year_end(datetime(2013, 1, 3)) == datetime(2013, 12, 31) + assert offset_n.get_year_end(datetime(2013, 1, 10)) == datetime(2013, 12, 31) + + JNJ = FY5253(n=1, startingMonth=12, weekday=6, variation="nearest") + assert JNJ.get_year_end(datetime(2006, 1, 1)) == datetime(2006, 12, 31) + + offset_lom_aug_sat = makeFY5253NearestEndMonth( + 1, startingMonth=8, weekday=WeekDay.SAT + ) + offset_lom_aug_thu = makeFY5253NearestEndMonth( + 1, startingMonth=8, weekday=WeekDay.THU + ) + offset_n = FY5253(weekday=WeekDay.TUE, startingMonth=12, variation="nearest") + + on_offset_cases = [ + # From Wikipedia (see: + # https://en.wikipedia.org/wiki/4%E2%80%934%E2%80%935_calendar + # #Saturday_nearest_the_end_of_month) + # 2006-09-02 2006 September 2 + # 2007-09-01 2007 September 1 + # 2008-08-30 2008 August 30 (leap year) + # 2009-08-29 2009 August 29 + # 2010-08-28 2010 August 28 + # 2011-09-03 2011 September 3 + # 2012-09-01 2012 September 1 (leap year) + # 2013-08-31 2013 August 31 + # 2014-08-30 2014 August 30 + # 2015-08-29 2015 August 29 + # 2016-09-03 2016 September 3 (leap year) + # 2017-09-02 2017 September 2 + # 2018-09-01 2018 September 1 + # 2019-08-31 2019 August 31 + (offset_lom_aug_sat, datetime(2006, 9, 2), True), + (offset_lom_aug_sat, datetime(2007, 9, 1), True), + (offset_lom_aug_sat, datetime(2008, 8, 30), True), + (offset_lom_aug_sat, datetime(2009, 8, 29), True), + (offset_lom_aug_sat, datetime(2010, 8, 28), True), + (offset_lom_aug_sat, datetime(2011, 9, 3), True), + (offset_lom_aug_sat, datetime(2016, 9, 3), True), + (offset_lom_aug_sat, datetime(2017, 9, 2), True), + (offset_lom_aug_sat, datetime(2018, 9, 1), True), + (offset_lom_aug_sat, datetime(2019, 8, 31), True), + (offset_lom_aug_sat, datetime(2006, 8, 27), False), + (offset_lom_aug_sat, datetime(2007, 8, 28), False), + (offset_lom_aug_sat, datetime(2008, 8, 31), False), + (offset_lom_aug_sat, datetime(2009, 8, 30), False), + (offset_lom_aug_sat, datetime(2010, 8, 29), False), + (offset_lom_aug_sat, datetime(2011, 8, 28), False), + (offset_lom_aug_sat, datetime(2006, 8, 25), False), + (offset_lom_aug_sat, datetime(2007, 8, 24), False), + (offset_lom_aug_sat, datetime(2008, 8, 29), False), + (offset_lom_aug_sat, datetime(2009, 8, 28), False), + (offset_lom_aug_sat, datetime(2010, 8, 27), False), + (offset_lom_aug_sat, datetime(2011, 8, 26), False), + (offset_lom_aug_sat, datetime(2019, 8, 30), False), + # From Micron, see: + # http://google.brand.edgar-online.com/?sym=MU&formtypeID=7 + (offset_lom_aug_thu, datetime(2012, 8, 30), True), + (offset_lom_aug_thu, datetime(2011, 9, 1), True), + (offset_n, datetime(2012, 12, 31), False), + (offset_n, datetime(2013, 1, 1), True), + (offset_n, datetime(2013, 1, 2), False), + ] + + @pytest.mark.parametrize("case", on_offset_cases) + def test_is_on_offset(self, case): + offset, dt, expected = case + assert_is_on_offset(offset, dt, expected) + + def test_apply(self): + date_seq_nem_8_sat = [ + datetime(2006, 9, 2), + datetime(2007, 9, 1), + datetime(2008, 8, 30), + datetime(2009, 8, 29), + datetime(2010, 8, 28), + datetime(2011, 9, 3), + ] + + JNJ = [ + datetime(2005, 1, 2), + datetime(2006, 1, 1), + datetime(2006, 12, 31), + datetime(2007, 12, 30), + datetime(2008, 12, 28), + datetime(2010, 1, 3), + datetime(2011, 1, 2), + datetime(2012, 1, 1), + datetime(2012, 12, 30), + ] + + DEC_SAT = FY5253(n=-1, startingMonth=12, weekday=5, variation="nearest") + + tests = [ + ( + makeFY5253NearestEndMonth(startingMonth=8, weekday=WeekDay.SAT), + date_seq_nem_8_sat, + ), + ( + makeFY5253NearestEndMonth(n=1, startingMonth=8, weekday=WeekDay.SAT), + date_seq_nem_8_sat, + ), + ( + makeFY5253NearestEndMonth(startingMonth=8, weekday=WeekDay.SAT), + [datetime(2006, 9, 1)] + date_seq_nem_8_sat, + ), + ( + makeFY5253NearestEndMonth(n=1, startingMonth=8, weekday=WeekDay.SAT), + [datetime(2006, 9, 3)] + date_seq_nem_8_sat[1:], + ), + ( + makeFY5253NearestEndMonth(n=-1, startingMonth=8, weekday=WeekDay.SAT), + list(reversed(date_seq_nem_8_sat)), + ), + ( + makeFY5253NearestEndMonth(n=1, startingMonth=12, weekday=WeekDay.SUN), + JNJ, + ), + ( + makeFY5253NearestEndMonth(n=-1, startingMonth=12, weekday=WeekDay.SUN), + list(reversed(JNJ)), + ), + ( + makeFY5253NearestEndMonth(n=1, startingMonth=12, weekday=WeekDay.SUN), + [datetime(2005, 1, 2), datetime(2006, 1, 1)], + ), + ( + makeFY5253NearestEndMonth(n=1, startingMonth=12, weekday=WeekDay.SUN), + [datetime(2006, 1, 2), datetime(2006, 12, 31)], + ), + (DEC_SAT, [datetime(2013, 1, 15), datetime(2012, 12, 29)]), + ] + for test in tests: + offset, data = test + current = data[0] + for datum in data[1:]: + current = current + offset + assert current == datum + + +class TestFY5253LastOfMonthQuarter: + def test_is_anchored(self): + msg = "FY5253Quarter.is_anchored is deprecated " + + with tm.assert_produces_warning(FutureWarning, match=msg): + assert makeFY5253LastOfMonthQuarter( + startingMonth=1, weekday=WeekDay.SAT, qtr_with_extra_week=4 + ).is_anchored() + assert makeFY5253LastOfMonthQuarter( + weekday=WeekDay.SAT, startingMonth=3, qtr_with_extra_week=4 + ).is_anchored() + assert not makeFY5253LastOfMonthQuarter( + 2, startingMonth=1, weekday=WeekDay.SAT, qtr_with_extra_week=4 + ).is_anchored() + + def test_equality(self): + assert makeFY5253LastOfMonthQuarter( + startingMonth=1, weekday=WeekDay.SAT, qtr_with_extra_week=4 + ) == makeFY5253LastOfMonthQuarter( + startingMonth=1, weekday=WeekDay.SAT, qtr_with_extra_week=4 + ) + assert makeFY5253LastOfMonthQuarter( + startingMonth=1, weekday=WeekDay.SAT, qtr_with_extra_week=4 + ) != makeFY5253LastOfMonthQuarter( + startingMonth=1, weekday=WeekDay.SUN, qtr_with_extra_week=4 + ) + assert makeFY5253LastOfMonthQuarter( + startingMonth=1, weekday=WeekDay.SAT, qtr_with_extra_week=4 + ) != makeFY5253LastOfMonthQuarter( + startingMonth=2, weekday=WeekDay.SAT, qtr_with_extra_week=4 + ) + + def test_offset(self): + offset = makeFY5253LastOfMonthQuarter( + 1, startingMonth=9, weekday=WeekDay.SAT, qtr_with_extra_week=4 + ) + offset2 = makeFY5253LastOfMonthQuarter( + 2, startingMonth=9, weekday=WeekDay.SAT, qtr_with_extra_week=4 + ) + offset4 = makeFY5253LastOfMonthQuarter( + 4, startingMonth=9, weekday=WeekDay.SAT, qtr_with_extra_week=4 + ) + + offset_neg1 = makeFY5253LastOfMonthQuarter( + -1, startingMonth=9, weekday=WeekDay.SAT, qtr_with_extra_week=4 + ) + offset_neg2 = makeFY5253LastOfMonthQuarter( + -2, startingMonth=9, weekday=WeekDay.SAT, qtr_with_extra_week=4 + ) + + GMCR = [ + datetime(2010, 3, 27), + datetime(2010, 6, 26), + datetime(2010, 9, 25), + datetime(2010, 12, 25), + datetime(2011, 3, 26), + datetime(2011, 6, 25), + datetime(2011, 9, 24), + datetime(2011, 12, 24), + datetime(2012, 3, 24), + datetime(2012, 6, 23), + datetime(2012, 9, 29), + datetime(2012, 12, 29), + datetime(2013, 3, 30), + datetime(2013, 6, 29), + ] + + assert_offset_equal(offset, base=GMCR[0], expected=GMCR[1]) + assert_offset_equal( + offset, base=GMCR[0] + relativedelta(days=-1), expected=GMCR[0] + ) + assert_offset_equal(offset, base=GMCR[1], expected=GMCR[2]) + + assert_offset_equal(offset2, base=GMCR[0], expected=GMCR[2]) + assert_offset_equal(offset4, base=GMCR[0], expected=GMCR[4]) + + assert_offset_equal(offset_neg1, base=GMCR[-1], expected=GMCR[-2]) + assert_offset_equal( + offset_neg1, base=GMCR[-1] + relativedelta(days=+1), expected=GMCR[-1] + ) + assert_offset_equal(offset_neg2, base=GMCR[-1], expected=GMCR[-3]) + + date = GMCR[0] + relativedelta(days=-1) + for expected in GMCR: + assert_offset_equal(offset, date, expected) + date = date + offset + + date = GMCR[-1] + relativedelta(days=+1) + for expected in reversed(GMCR): + assert_offset_equal(offset_neg1, date, expected) + date = date + offset_neg1 + + lomq_aug_sat_4 = makeFY5253LastOfMonthQuarter( + 1, startingMonth=8, weekday=WeekDay.SAT, qtr_with_extra_week=4 + ) + lomq_sep_sat_4 = makeFY5253LastOfMonthQuarter( + 1, startingMonth=9, weekday=WeekDay.SAT, qtr_with_extra_week=4 + ) + + on_offset_cases = [ + # From Wikipedia + (lomq_aug_sat_4, datetime(2006, 8, 26), True), + (lomq_aug_sat_4, datetime(2007, 8, 25), True), + (lomq_aug_sat_4, datetime(2008, 8, 30), True), + (lomq_aug_sat_4, datetime(2009, 8, 29), True), + (lomq_aug_sat_4, datetime(2010, 8, 28), True), + (lomq_aug_sat_4, datetime(2011, 8, 27), True), + (lomq_aug_sat_4, datetime(2019, 8, 31), True), + (lomq_aug_sat_4, datetime(2006, 8, 27), False), + (lomq_aug_sat_4, datetime(2007, 8, 28), False), + (lomq_aug_sat_4, datetime(2008, 8, 31), False), + (lomq_aug_sat_4, datetime(2009, 8, 30), False), + (lomq_aug_sat_4, datetime(2010, 8, 29), False), + (lomq_aug_sat_4, datetime(2011, 8, 28), False), + (lomq_aug_sat_4, datetime(2006, 8, 25), False), + (lomq_aug_sat_4, datetime(2007, 8, 24), False), + (lomq_aug_sat_4, datetime(2008, 8, 29), False), + (lomq_aug_sat_4, datetime(2009, 8, 28), False), + (lomq_aug_sat_4, datetime(2010, 8, 27), False), + (lomq_aug_sat_4, datetime(2011, 8, 26), False), + (lomq_aug_sat_4, datetime(2019, 8, 30), False), + # From GMCR + (lomq_sep_sat_4, datetime(2010, 9, 25), True), + (lomq_sep_sat_4, datetime(2011, 9, 24), True), + (lomq_sep_sat_4, datetime(2012, 9, 29), True), + (lomq_sep_sat_4, datetime(2013, 6, 29), True), + (lomq_sep_sat_4, datetime(2012, 6, 23), True), + (lomq_sep_sat_4, datetime(2012, 6, 30), False), + (lomq_sep_sat_4, datetime(2013, 3, 30), True), + (lomq_sep_sat_4, datetime(2012, 3, 24), True), + (lomq_sep_sat_4, datetime(2012, 12, 29), True), + (lomq_sep_sat_4, datetime(2011, 12, 24), True), + # INTC (extra week in Q1) + # See: http://www.intc.com/releasedetail.cfm?ReleaseID=542844 + ( + makeFY5253LastOfMonthQuarter( + 1, startingMonth=12, weekday=WeekDay.SAT, qtr_with_extra_week=1 + ), + datetime(2011, 4, 2), + True, + ), + # see: http://google.brand.edgar-online.com/?sym=INTC&formtypeID=7 + ( + makeFY5253LastOfMonthQuarter( + 1, startingMonth=12, weekday=WeekDay.SAT, qtr_with_extra_week=1 + ), + datetime(2012, 12, 29), + True, + ), + ( + makeFY5253LastOfMonthQuarter( + 1, startingMonth=12, weekday=WeekDay.SAT, qtr_with_extra_week=1 + ), + datetime(2011, 12, 31), + True, + ), + ( + makeFY5253LastOfMonthQuarter( + 1, startingMonth=12, weekday=WeekDay.SAT, qtr_with_extra_week=1 + ), + datetime(2010, 12, 25), + True, + ), + ] + + @pytest.mark.parametrize("case", on_offset_cases) + def test_is_on_offset(self, case): + offset, dt, expected = case + assert_is_on_offset(offset, dt, expected) + + def test_year_has_extra_week(self): + # End of long Q1 + assert makeFY5253LastOfMonthQuarter( + 1, startingMonth=12, weekday=WeekDay.SAT, qtr_with_extra_week=1 + ).year_has_extra_week(datetime(2011, 4, 2)) + + # Start of long Q1 + assert makeFY5253LastOfMonthQuarter( + 1, startingMonth=12, weekday=WeekDay.SAT, qtr_with_extra_week=1 + ).year_has_extra_week(datetime(2010, 12, 26)) + + # End of year before year with long Q1 + assert not makeFY5253LastOfMonthQuarter( + 1, startingMonth=12, weekday=WeekDay.SAT, qtr_with_extra_week=1 + ).year_has_extra_week(datetime(2010, 12, 25)) + + for year in [ + x for x in range(1994, 2011 + 1) if x not in [2011, 2005, 2000, 1994] + ]: + assert not makeFY5253LastOfMonthQuarter( + 1, startingMonth=12, weekday=WeekDay.SAT, qtr_with_extra_week=1 + ).year_has_extra_week(datetime(year, 4, 2)) + + # Other long years + assert makeFY5253LastOfMonthQuarter( + 1, startingMonth=12, weekday=WeekDay.SAT, qtr_with_extra_week=1 + ).year_has_extra_week(datetime(2005, 4, 2)) + + assert makeFY5253LastOfMonthQuarter( + 1, startingMonth=12, weekday=WeekDay.SAT, qtr_with_extra_week=1 + ).year_has_extra_week(datetime(2000, 4, 2)) + + assert makeFY5253LastOfMonthQuarter( + 1, startingMonth=12, weekday=WeekDay.SAT, qtr_with_extra_week=1 + ).year_has_extra_week(datetime(1994, 4, 2)) + + def test_get_weeks(self): + sat_dec_1 = makeFY5253LastOfMonthQuarter( + 1, startingMonth=12, weekday=WeekDay.SAT, qtr_with_extra_week=1 + ) + sat_dec_4 = makeFY5253LastOfMonthQuarter( + 1, startingMonth=12, weekday=WeekDay.SAT, qtr_with_extra_week=4 + ) + + assert sat_dec_1.get_weeks(datetime(2011, 4, 2)) == [14, 13, 13, 13] + assert sat_dec_4.get_weeks(datetime(2011, 4, 2)) == [13, 13, 13, 14] + assert sat_dec_1.get_weeks(datetime(2010, 12, 25)) == [13, 13, 13, 13] + + +class TestFY5253NearestEndMonthQuarter: + offset_nem_sat_aug_4 = makeFY5253NearestEndMonthQuarter( + 1, startingMonth=8, weekday=WeekDay.SAT, qtr_with_extra_week=4 + ) + offset_nem_thu_aug_4 = makeFY5253NearestEndMonthQuarter( + 1, startingMonth=8, weekday=WeekDay.THU, qtr_with_extra_week=4 + ) + offset_n = FY5253(weekday=WeekDay.TUE, startingMonth=12, variation="nearest") + + on_offset_cases = [ + # From Wikipedia + (offset_nem_sat_aug_4, datetime(2006, 9, 2), True), + (offset_nem_sat_aug_4, datetime(2007, 9, 1), True), + (offset_nem_sat_aug_4, datetime(2008, 8, 30), True), + (offset_nem_sat_aug_4, datetime(2009, 8, 29), True), + (offset_nem_sat_aug_4, datetime(2010, 8, 28), True), + (offset_nem_sat_aug_4, datetime(2011, 9, 3), True), + (offset_nem_sat_aug_4, datetime(2016, 9, 3), True), + (offset_nem_sat_aug_4, datetime(2017, 9, 2), True), + (offset_nem_sat_aug_4, datetime(2018, 9, 1), True), + (offset_nem_sat_aug_4, datetime(2019, 8, 31), True), + (offset_nem_sat_aug_4, datetime(2006, 8, 27), False), + (offset_nem_sat_aug_4, datetime(2007, 8, 28), False), + (offset_nem_sat_aug_4, datetime(2008, 8, 31), False), + (offset_nem_sat_aug_4, datetime(2009, 8, 30), False), + (offset_nem_sat_aug_4, datetime(2010, 8, 29), False), + (offset_nem_sat_aug_4, datetime(2011, 8, 28), False), + (offset_nem_sat_aug_4, datetime(2006, 8, 25), False), + (offset_nem_sat_aug_4, datetime(2007, 8, 24), False), + (offset_nem_sat_aug_4, datetime(2008, 8, 29), False), + (offset_nem_sat_aug_4, datetime(2009, 8, 28), False), + (offset_nem_sat_aug_4, datetime(2010, 8, 27), False), + (offset_nem_sat_aug_4, datetime(2011, 8, 26), False), + (offset_nem_sat_aug_4, datetime(2019, 8, 30), False), + # From Micron, see: + # http://google.brand.edgar-online.com/?sym=MU&formtypeID=7 + (offset_nem_thu_aug_4, datetime(2012, 8, 30), True), + (offset_nem_thu_aug_4, datetime(2011, 9, 1), True), + # See: http://google.brand.edgar-online.com/?sym=MU&formtypeID=13 + (offset_nem_thu_aug_4, datetime(2013, 5, 30), True), + (offset_nem_thu_aug_4, datetime(2013, 2, 28), True), + (offset_nem_thu_aug_4, datetime(2012, 11, 29), True), + (offset_nem_thu_aug_4, datetime(2012, 5, 31), True), + (offset_nem_thu_aug_4, datetime(2007, 3, 1), True), + (offset_nem_thu_aug_4, datetime(1994, 3, 3), True), + (offset_n, datetime(2012, 12, 31), False), + (offset_n, datetime(2013, 1, 1), True), + (offset_n, datetime(2013, 1, 2), False), + ] + + @pytest.mark.parametrize("case", on_offset_cases) + def test_is_on_offset(self, case): + offset, dt, expected = case + assert_is_on_offset(offset, dt, expected) + + def test_offset(self): + offset = makeFY5253NearestEndMonthQuarter( + 1, startingMonth=8, weekday=WeekDay.THU, qtr_with_extra_week=4 + ) + + MU = [ + datetime(2012, 5, 31), + datetime(2012, 8, 30), + datetime(2012, 11, 29), + datetime(2013, 2, 28), + datetime(2013, 5, 30), + ] + + date = MU[0] + relativedelta(days=-1) + for expected in MU: + assert_offset_equal(offset, date, expected) + date = date + offset + + assert_offset_equal(offset, datetime(2012, 5, 31), datetime(2012, 8, 30)) + assert_offset_equal(offset, datetime(2012, 5, 30), datetime(2012, 5, 31)) + + offset2 = FY5253Quarter( + weekday=5, startingMonth=12, variation="last", qtr_with_extra_week=4 + ) + + assert_offset_equal(offset2, datetime(2013, 1, 15), datetime(2013, 3, 30)) + + +def test_bunched_yearends(): + # GH#14774 cases with two fiscal year-ends in the same calendar-year + fy = FY5253(n=1, weekday=5, startingMonth=12, variation="nearest") + dt = Timestamp("2004-01-01") + assert fy.rollback(dt) == Timestamp("2002-12-28") + assert (-fy)._apply(dt) == Timestamp("2002-12-28") + assert dt - fy == Timestamp("2002-12-28") + + assert fy.rollforward(dt) == Timestamp("2004-01-03") + assert fy._apply(dt) == Timestamp("2004-01-03") + assert fy + dt == Timestamp("2004-01-03") + assert dt + fy == Timestamp("2004-01-03") + + # Same thing, but starting from a Timestamp in the previous year. + dt = Timestamp("2003-12-31") + assert fy.rollback(dt) == Timestamp("2002-12-28") + assert (-fy)._apply(dt) == Timestamp("2002-12-28") + assert dt - fy == Timestamp("2002-12-28") + + +def test_fy5253_last_onoffset(): + # GH#18877 dates on the year-end but not normalized to midnight + offset = FY5253(n=-5, startingMonth=5, variation="last", weekday=0) + ts = Timestamp("1984-05-28 06:29:43.955911354+0200", tz="Europe/San_Marino") + fast = offset.is_on_offset(ts) + slow = (ts + offset) - offset == ts + assert fast == slow + + +def test_fy5253_nearest_onoffset(): + # GH#18877 dates on the year-end but not normalized to midnight + offset = FY5253(n=3, startingMonth=7, variation="nearest", weekday=2) + ts = Timestamp("2032-07-28 00:12:59.035729419+0000", tz="Africa/Dakar") + fast = offset.is_on_offset(ts) + slow = (ts + offset) - offset == ts + assert fast == slow + + +def test_fy5253qtr_onoffset_nearest(): + # GH#19036 + ts = Timestamp("1985-09-02 23:57:46.232550356-0300", tz="Atlantic/Bermuda") + offset = FY5253Quarter( + n=3, qtr_with_extra_week=1, startingMonth=2, variation="nearest", weekday=0 + ) + fast = offset.is_on_offset(ts) + slow = (ts + offset) - offset == ts + assert fast == slow + + +def test_fy5253qtr_onoffset_last(): + # GH#19036 + offset = FY5253Quarter( + n=-2, qtr_with_extra_week=1, startingMonth=7, variation="last", weekday=2 + ) + ts = Timestamp("2011-01-26 19:03:40.331096129+0200", tz="Africa/Windhoek") + slow = (ts + offset) - offset == ts + fast = offset.is_on_offset(ts) + assert fast == slow diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_index.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_index.py new file mode 100644 index 0000000000000000000000000000000000000000..7a62944556d11b536f7a64e49d4a9ff11e90ec0e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_index.py @@ -0,0 +1,57 @@ +""" +Tests for offset behavior with indices. +""" +import pytest + +from pandas import ( + Series, + date_range, +) + +from pandas.tseries.offsets import ( + BMonthBegin, + BMonthEnd, + BQuarterBegin, + BQuarterEnd, + BYearBegin, + BYearEnd, + MonthBegin, + MonthEnd, + QuarterBegin, + QuarterEnd, + YearBegin, + YearEnd, +) + + +@pytest.mark.parametrize("n", [-2, 1]) +@pytest.mark.parametrize( + "cls", + [ + MonthBegin, + MonthEnd, + BMonthBegin, + BMonthEnd, + QuarterBegin, + QuarterEnd, + BQuarterBegin, + BQuarterEnd, + YearBegin, + YearEnd, + BYearBegin, + BYearEnd, + ], +) +def test_apply_index(cls, n): + offset = cls(n=n) + rng = date_range(start="1/1/2000", periods=100000, freq="min") + ser = Series(rng) + + res = rng + offset + assert res.freq is None # not retained + assert res[0] == rng[0] + offset + assert res[-1] == rng[-1] + offset + res2 = ser + offset + # apply_index is only for indexes, not series, so no res2_v2 + assert res2.iloc[0] == ser.iloc[0] + offset + assert res2.iloc[-1] == ser.iloc[-1] + offset diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_month.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_month.py new file mode 100644 index 0000000000000000000000000000000000000000..2b643999c3ad34057156f4dc9f382dd3950e35c5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_month.py @@ -0,0 +1,666 @@ +""" +Tests for the following offsets: +- SemiMonthBegin +- SemiMonthEnd +- MonthBegin +- MonthEnd +""" +from __future__ import annotations + +from datetime import datetime + +import pytest + +from pandas._libs.tslibs import Timestamp +from pandas._libs.tslibs.offsets import ( + MonthBegin, + MonthEnd, + SemiMonthBegin, + SemiMonthEnd, +) + +from pandas import ( + DatetimeIndex, + Series, + _testing as tm, +) +from pandas.tests.tseries.offsets.common import ( + assert_is_on_offset, + assert_offset_equal, +) + + +class TestSemiMonthEnd: + def test_offset_whole_year(self): + dates = ( + datetime(2007, 12, 31), + datetime(2008, 1, 15), + datetime(2008, 1, 31), + datetime(2008, 2, 15), + datetime(2008, 2, 29), + datetime(2008, 3, 15), + datetime(2008, 3, 31), + datetime(2008, 4, 15), + datetime(2008, 4, 30), + datetime(2008, 5, 15), + datetime(2008, 5, 31), + datetime(2008, 6, 15), + datetime(2008, 6, 30), + datetime(2008, 7, 15), + datetime(2008, 7, 31), + datetime(2008, 8, 15), + datetime(2008, 8, 31), + datetime(2008, 9, 15), + datetime(2008, 9, 30), + datetime(2008, 10, 15), + datetime(2008, 10, 31), + datetime(2008, 11, 15), + datetime(2008, 11, 30), + datetime(2008, 12, 15), + datetime(2008, 12, 31), + ) + + for base, exp_date in zip(dates[:-1], dates[1:]): + assert_offset_equal(SemiMonthEnd(), base, exp_date) + + # ensure .apply_index works as expected + shift = DatetimeIndex(dates[:-1]) + with tm.assert_produces_warning(None): + # GH#22535 check that we don't get a FutureWarning from adding + # an integer array to PeriodIndex + result = SemiMonthEnd() + shift + + exp = DatetimeIndex(dates[1:]) + tm.assert_index_equal(result, exp) + + offset_cases = [] + offset_cases.append( + ( + SemiMonthEnd(), + { + datetime(2008, 1, 1): datetime(2008, 1, 15), + datetime(2008, 1, 15): datetime(2008, 1, 31), + datetime(2008, 1, 31): datetime(2008, 2, 15), + datetime(2006, 12, 14): datetime(2006, 12, 15), + datetime(2006, 12, 29): datetime(2006, 12, 31), + datetime(2006, 12, 31): datetime(2007, 1, 15), + datetime(2007, 1, 1): datetime(2007, 1, 15), + datetime(2006, 12, 1): datetime(2006, 12, 15), + datetime(2006, 12, 15): datetime(2006, 12, 31), + }, + ) + ) + + offset_cases.append( + ( + SemiMonthEnd(day_of_month=20), + { + datetime(2008, 1, 1): datetime(2008, 1, 20), + datetime(2008, 1, 15): datetime(2008, 1, 20), + datetime(2008, 1, 21): datetime(2008, 1, 31), + datetime(2008, 1, 31): datetime(2008, 2, 20), + datetime(2006, 12, 14): datetime(2006, 12, 20), + datetime(2006, 12, 29): datetime(2006, 12, 31), + datetime(2006, 12, 31): datetime(2007, 1, 20), + datetime(2007, 1, 1): datetime(2007, 1, 20), + datetime(2006, 12, 1): datetime(2006, 12, 20), + datetime(2006, 12, 15): datetime(2006, 12, 20), + }, + ) + ) + + offset_cases.append( + ( + SemiMonthEnd(0), + { + datetime(2008, 1, 1): datetime(2008, 1, 15), + datetime(2008, 1, 16): datetime(2008, 1, 31), + datetime(2008, 1, 15): datetime(2008, 1, 15), + datetime(2008, 1, 31): datetime(2008, 1, 31), + datetime(2006, 12, 29): datetime(2006, 12, 31), + datetime(2006, 12, 31): datetime(2006, 12, 31), + datetime(2007, 1, 1): datetime(2007, 1, 15), + }, + ) + ) + + offset_cases.append( + ( + SemiMonthEnd(0, day_of_month=16), + { + datetime(2008, 1, 1): datetime(2008, 1, 16), + datetime(2008, 1, 16): datetime(2008, 1, 16), + datetime(2008, 1, 15): datetime(2008, 1, 16), + datetime(2008, 1, 31): datetime(2008, 1, 31), + datetime(2006, 12, 29): datetime(2006, 12, 31), + datetime(2006, 12, 31): datetime(2006, 12, 31), + datetime(2007, 1, 1): datetime(2007, 1, 16), + }, + ) + ) + + offset_cases.append( + ( + SemiMonthEnd(2), + { + datetime(2008, 1, 1): datetime(2008, 1, 31), + datetime(2008, 1, 31): datetime(2008, 2, 29), + datetime(2006, 12, 29): datetime(2007, 1, 15), + datetime(2006, 12, 31): datetime(2007, 1, 31), + datetime(2007, 1, 1): datetime(2007, 1, 31), + datetime(2007, 1, 16): datetime(2007, 2, 15), + datetime(2006, 11, 1): datetime(2006, 11, 30), + }, + ) + ) + + offset_cases.append( + ( + SemiMonthEnd(-1), + { + datetime(2007, 1, 1): datetime(2006, 12, 31), + datetime(2008, 6, 30): datetime(2008, 6, 15), + datetime(2008, 12, 31): datetime(2008, 12, 15), + datetime(2006, 12, 29): datetime(2006, 12, 15), + datetime(2006, 12, 30): datetime(2006, 12, 15), + datetime(2007, 1, 1): datetime(2006, 12, 31), + }, + ) + ) + + offset_cases.append( + ( + SemiMonthEnd(-1, day_of_month=4), + { + datetime(2007, 1, 1): datetime(2006, 12, 31), + datetime(2007, 1, 4): datetime(2006, 12, 31), + datetime(2008, 6, 30): datetime(2008, 6, 4), + datetime(2008, 12, 31): datetime(2008, 12, 4), + datetime(2006, 12, 5): datetime(2006, 12, 4), + datetime(2006, 12, 30): datetime(2006, 12, 4), + datetime(2007, 1, 1): datetime(2006, 12, 31), + }, + ) + ) + + offset_cases.append( + ( + SemiMonthEnd(-2), + { + datetime(2007, 1, 1): datetime(2006, 12, 15), + datetime(2008, 6, 30): datetime(2008, 5, 31), + datetime(2008, 3, 15): datetime(2008, 2, 15), + datetime(2008, 12, 31): datetime(2008, 11, 30), + datetime(2006, 12, 29): datetime(2006, 11, 30), + datetime(2006, 12, 14): datetime(2006, 11, 15), + datetime(2007, 1, 1): datetime(2006, 12, 15), + }, + ) + ) + + @pytest.mark.parametrize("case", offset_cases) + def test_offset(self, case): + offset, cases = case + for base, expected in cases.items(): + assert_offset_equal(offset, base, expected) + + @pytest.mark.parametrize("case", offset_cases) + def test_apply_index(self, case): + # https://github.com/pandas-dev/pandas/issues/34580 + offset, cases = case + shift = DatetimeIndex(cases.keys()) + exp = DatetimeIndex(cases.values()) + + with tm.assert_produces_warning(None): + # GH#22535 check that we don't get a FutureWarning from adding + # an integer array to PeriodIndex + result = offset + shift + tm.assert_index_equal(result, exp) + + on_offset_cases = [ + (datetime(2007, 12, 31), True), + (datetime(2007, 12, 15), True), + (datetime(2007, 12, 14), False), + (datetime(2007, 12, 1), False), + (datetime(2008, 2, 29), True), + ] + + @pytest.mark.parametrize("case", on_offset_cases) + def test_is_on_offset(self, case): + dt, expected = case + assert_is_on_offset(SemiMonthEnd(), dt, expected) + + @pytest.mark.parametrize("klass", [Series, DatetimeIndex]) + def test_vectorized_offset_addition(self, klass): + shift = klass( + [ + Timestamp("2000-01-15 00:15:00", tz="US/Central"), + Timestamp("2000-02-15", tz="US/Central"), + ], + name="a", + ) + + with tm.assert_produces_warning(None): + # GH#22535 check that we don't get a FutureWarning from adding + # an integer array to PeriodIndex + result = shift + SemiMonthEnd() + result2 = SemiMonthEnd() + shift + + exp = klass( + [ + Timestamp("2000-01-31 00:15:00", tz="US/Central"), + Timestamp("2000-02-29", tz="US/Central"), + ], + name="a", + ) + tm.assert_equal(result, exp) + tm.assert_equal(result2, exp) + + shift = klass( + [ + Timestamp("2000-01-01 00:15:00", tz="US/Central"), + Timestamp("2000-02-01", tz="US/Central"), + ], + name="a", + ) + + with tm.assert_produces_warning(None): + # GH#22535 check that we don't get a FutureWarning from adding + # an integer array to PeriodIndex + result = shift + SemiMonthEnd() + result2 = SemiMonthEnd() + shift + + exp = klass( + [ + Timestamp("2000-01-15 00:15:00", tz="US/Central"), + Timestamp("2000-02-15", tz="US/Central"), + ], + name="a", + ) + tm.assert_equal(result, exp) + tm.assert_equal(result2, exp) + + +class TestSemiMonthBegin: + def test_offset_whole_year(self): + dates = ( + datetime(2007, 12, 15), + datetime(2008, 1, 1), + datetime(2008, 1, 15), + datetime(2008, 2, 1), + datetime(2008, 2, 15), + datetime(2008, 3, 1), + datetime(2008, 3, 15), + datetime(2008, 4, 1), + datetime(2008, 4, 15), + datetime(2008, 5, 1), + datetime(2008, 5, 15), + datetime(2008, 6, 1), + datetime(2008, 6, 15), + datetime(2008, 7, 1), + datetime(2008, 7, 15), + datetime(2008, 8, 1), + datetime(2008, 8, 15), + datetime(2008, 9, 1), + datetime(2008, 9, 15), + datetime(2008, 10, 1), + datetime(2008, 10, 15), + datetime(2008, 11, 1), + datetime(2008, 11, 15), + datetime(2008, 12, 1), + datetime(2008, 12, 15), + ) + + for base, exp_date in zip(dates[:-1], dates[1:]): + assert_offset_equal(SemiMonthBegin(), base, exp_date) + + # ensure .apply_index works as expected + shift = DatetimeIndex(dates[:-1]) + with tm.assert_produces_warning(None): + # GH#22535 check that we don't get a FutureWarning from adding + # an integer array to PeriodIndex + result = SemiMonthBegin() + shift + + exp = DatetimeIndex(dates[1:]) + tm.assert_index_equal(result, exp) + + offset_cases = [ + ( + SemiMonthBegin(), + { + datetime(2008, 1, 1): datetime(2008, 1, 15), + datetime(2008, 1, 15): datetime(2008, 2, 1), + datetime(2008, 1, 31): datetime(2008, 2, 1), + datetime(2006, 12, 14): datetime(2006, 12, 15), + datetime(2006, 12, 29): datetime(2007, 1, 1), + datetime(2006, 12, 31): datetime(2007, 1, 1), + datetime(2007, 1, 1): datetime(2007, 1, 15), + datetime(2006, 12, 1): datetime(2006, 12, 15), + datetime(2006, 12, 15): datetime(2007, 1, 1), + }, + ), + ( + SemiMonthBegin(day_of_month=20), + { + datetime(2008, 1, 1): datetime(2008, 1, 20), + datetime(2008, 1, 15): datetime(2008, 1, 20), + datetime(2008, 1, 21): datetime(2008, 2, 1), + datetime(2008, 1, 31): datetime(2008, 2, 1), + datetime(2006, 12, 14): datetime(2006, 12, 20), + datetime(2006, 12, 29): datetime(2007, 1, 1), + datetime(2006, 12, 31): datetime(2007, 1, 1), + datetime(2007, 1, 1): datetime(2007, 1, 20), + datetime(2006, 12, 1): datetime(2006, 12, 20), + datetime(2006, 12, 15): datetime(2006, 12, 20), + }, + ), + ( + SemiMonthBegin(0), + { + datetime(2008, 1, 1): datetime(2008, 1, 1), + datetime(2008, 1, 16): datetime(2008, 2, 1), + datetime(2008, 1, 15): datetime(2008, 1, 15), + datetime(2008, 1, 31): datetime(2008, 2, 1), + datetime(2006, 12, 29): datetime(2007, 1, 1), + datetime(2006, 12, 2): datetime(2006, 12, 15), + datetime(2007, 1, 1): datetime(2007, 1, 1), + }, + ), + ( + SemiMonthBegin(0, day_of_month=16), + { + datetime(2008, 1, 1): datetime(2008, 1, 1), + datetime(2008, 1, 16): datetime(2008, 1, 16), + datetime(2008, 1, 15): datetime(2008, 1, 16), + datetime(2008, 1, 31): datetime(2008, 2, 1), + datetime(2006, 12, 29): datetime(2007, 1, 1), + datetime(2006, 12, 31): datetime(2007, 1, 1), + datetime(2007, 1, 5): datetime(2007, 1, 16), + datetime(2007, 1, 1): datetime(2007, 1, 1), + }, + ), + ( + SemiMonthBegin(2), + { + datetime(2008, 1, 1): datetime(2008, 2, 1), + datetime(2008, 1, 31): datetime(2008, 2, 15), + datetime(2006, 12, 1): datetime(2007, 1, 1), + datetime(2006, 12, 29): datetime(2007, 1, 15), + datetime(2006, 12, 15): datetime(2007, 1, 15), + datetime(2007, 1, 1): datetime(2007, 2, 1), + datetime(2007, 1, 16): datetime(2007, 2, 15), + datetime(2006, 11, 1): datetime(2006, 12, 1), + }, + ), + ( + SemiMonthBegin(-1), + { + datetime(2007, 1, 1): datetime(2006, 12, 15), + datetime(2008, 6, 30): datetime(2008, 6, 15), + datetime(2008, 6, 14): datetime(2008, 6, 1), + datetime(2008, 12, 31): datetime(2008, 12, 15), + datetime(2006, 12, 29): datetime(2006, 12, 15), + datetime(2006, 12, 15): datetime(2006, 12, 1), + datetime(2007, 1, 1): datetime(2006, 12, 15), + }, + ), + ( + SemiMonthBegin(-1, day_of_month=4), + { + datetime(2007, 1, 1): datetime(2006, 12, 4), + datetime(2007, 1, 4): datetime(2007, 1, 1), + datetime(2008, 6, 30): datetime(2008, 6, 4), + datetime(2008, 12, 31): datetime(2008, 12, 4), + datetime(2006, 12, 5): datetime(2006, 12, 4), + datetime(2006, 12, 30): datetime(2006, 12, 4), + datetime(2006, 12, 2): datetime(2006, 12, 1), + datetime(2007, 1, 1): datetime(2006, 12, 4), + }, + ), + ( + SemiMonthBegin(-2), + { + datetime(2007, 1, 1): datetime(2006, 12, 1), + datetime(2008, 6, 30): datetime(2008, 6, 1), + datetime(2008, 6, 14): datetime(2008, 5, 15), + datetime(2008, 12, 31): datetime(2008, 12, 1), + datetime(2006, 12, 29): datetime(2006, 12, 1), + datetime(2006, 12, 15): datetime(2006, 11, 15), + datetime(2007, 1, 1): datetime(2006, 12, 1), + }, + ), + ] + + @pytest.mark.parametrize("case", offset_cases) + def test_offset(self, case): + offset, cases = case + for base, expected in cases.items(): + assert_offset_equal(offset, base, expected) + + @pytest.mark.parametrize("case", offset_cases) + def test_apply_index(self, case): + offset, cases = case + shift = DatetimeIndex(cases.keys()) + + with tm.assert_produces_warning(None): + # GH#22535 check that we don't get a FutureWarning from adding + # an integer array to PeriodIndex + result = offset + shift + + exp = DatetimeIndex(cases.values()) + tm.assert_index_equal(result, exp) + + on_offset_cases = [ + (datetime(2007, 12, 1), True), + (datetime(2007, 12, 15), True), + (datetime(2007, 12, 14), False), + (datetime(2007, 12, 31), False), + (datetime(2008, 2, 15), True), + ] + + @pytest.mark.parametrize("case", on_offset_cases) + def test_is_on_offset(self, case): + dt, expected = case + assert_is_on_offset(SemiMonthBegin(), dt, expected) + + @pytest.mark.parametrize("klass", [Series, DatetimeIndex]) + def test_vectorized_offset_addition(self, klass): + shift = klass( + [ + Timestamp("2000-01-15 00:15:00", tz="US/Central"), + Timestamp("2000-02-15", tz="US/Central"), + ], + name="a", + ) + with tm.assert_produces_warning(None): + # GH#22535 check that we don't get a FutureWarning from adding + # an integer array to PeriodIndex + result = shift + SemiMonthBegin() + result2 = SemiMonthBegin() + shift + + exp = klass( + [ + Timestamp("2000-02-01 00:15:00", tz="US/Central"), + Timestamp("2000-03-01", tz="US/Central"), + ], + name="a", + ) + tm.assert_equal(result, exp) + tm.assert_equal(result2, exp) + + shift = klass( + [ + Timestamp("2000-01-01 00:15:00", tz="US/Central"), + Timestamp("2000-02-01", tz="US/Central"), + ], + name="a", + ) + with tm.assert_produces_warning(None): + # GH#22535 check that we don't get a FutureWarning from adding + # an integer array to PeriodIndex + result = shift + SemiMonthBegin() + result2 = SemiMonthBegin() + shift + + exp = klass( + [ + Timestamp("2000-01-15 00:15:00", tz="US/Central"), + Timestamp("2000-02-15", tz="US/Central"), + ], + name="a", + ) + tm.assert_equal(result, exp) + tm.assert_equal(result2, exp) + + +class TestMonthBegin: + offset_cases = [] + # NOTE: I'm not entirely happy with the logic here for Begin -ss + # see thread 'offset conventions' on the ML + offset_cases.append( + ( + MonthBegin(), + { + datetime(2008, 1, 31): datetime(2008, 2, 1), + datetime(2008, 2, 1): datetime(2008, 3, 1), + datetime(2006, 12, 31): datetime(2007, 1, 1), + datetime(2006, 12, 1): datetime(2007, 1, 1), + datetime(2007, 1, 31): datetime(2007, 2, 1), + }, + ) + ) + + offset_cases.append( + ( + MonthBegin(0), + { + datetime(2008, 1, 31): datetime(2008, 2, 1), + datetime(2008, 1, 1): datetime(2008, 1, 1), + datetime(2006, 12, 3): datetime(2007, 1, 1), + datetime(2007, 1, 31): datetime(2007, 2, 1), + }, + ) + ) + + offset_cases.append( + ( + MonthBegin(2), + { + datetime(2008, 2, 29): datetime(2008, 4, 1), + datetime(2008, 1, 31): datetime(2008, 3, 1), + datetime(2006, 12, 31): datetime(2007, 2, 1), + datetime(2007, 12, 28): datetime(2008, 2, 1), + datetime(2007, 1, 1): datetime(2007, 3, 1), + datetime(2006, 11, 1): datetime(2007, 1, 1), + }, + ) + ) + + offset_cases.append( + ( + MonthBegin(-1), + { + datetime(2007, 1, 1): datetime(2006, 12, 1), + datetime(2008, 5, 31): datetime(2008, 5, 1), + datetime(2008, 12, 31): datetime(2008, 12, 1), + datetime(2006, 12, 29): datetime(2006, 12, 1), + datetime(2006, 1, 2): datetime(2006, 1, 1), + }, + ) + ) + + @pytest.mark.parametrize("case", offset_cases) + def test_offset(self, case): + offset, cases = case + for base, expected in cases.items(): + assert_offset_equal(offset, base, expected) + + +class TestMonthEnd: + def test_day_of_month(self): + dt = datetime(2007, 1, 1) + offset = MonthEnd() + + result = dt + offset + assert result == Timestamp(2007, 1, 31) + + result = result + offset + assert result == Timestamp(2007, 2, 28) + + def test_normalize(self): + dt = datetime(2007, 1, 1, 3) + + result = dt + MonthEnd(normalize=True) + expected = dt.replace(hour=0) + MonthEnd() + assert result == expected + + offset_cases = [] + offset_cases.append( + ( + MonthEnd(), + { + datetime(2008, 1, 1): datetime(2008, 1, 31), + datetime(2008, 1, 31): datetime(2008, 2, 29), + datetime(2006, 12, 29): datetime(2006, 12, 31), + datetime(2006, 12, 31): datetime(2007, 1, 31), + datetime(2007, 1, 1): datetime(2007, 1, 31), + datetime(2006, 12, 1): datetime(2006, 12, 31), + }, + ) + ) + + offset_cases.append( + ( + MonthEnd(0), + { + datetime(2008, 1, 1): datetime(2008, 1, 31), + datetime(2008, 1, 31): datetime(2008, 1, 31), + datetime(2006, 12, 29): datetime(2006, 12, 31), + datetime(2006, 12, 31): datetime(2006, 12, 31), + datetime(2007, 1, 1): datetime(2007, 1, 31), + }, + ) + ) + + offset_cases.append( + ( + MonthEnd(2), + { + datetime(2008, 1, 1): datetime(2008, 2, 29), + datetime(2008, 1, 31): datetime(2008, 3, 31), + datetime(2006, 12, 29): datetime(2007, 1, 31), + datetime(2006, 12, 31): datetime(2007, 2, 28), + datetime(2007, 1, 1): datetime(2007, 2, 28), + datetime(2006, 11, 1): datetime(2006, 12, 31), + }, + ) + ) + + offset_cases.append( + ( + MonthEnd(-1), + { + datetime(2007, 1, 1): datetime(2006, 12, 31), + datetime(2008, 6, 30): datetime(2008, 5, 31), + datetime(2008, 12, 31): datetime(2008, 11, 30), + datetime(2006, 12, 29): datetime(2006, 11, 30), + datetime(2006, 12, 30): datetime(2006, 11, 30), + datetime(2007, 1, 1): datetime(2006, 12, 31), + }, + ) + ) + + @pytest.mark.parametrize("case", offset_cases) + def test_offset(self, case): + offset, cases = case + for base, expected in cases.items(): + assert_offset_equal(offset, base, expected) + + on_offset_cases = [ + (MonthEnd(), datetime(2007, 12, 31), True), + (MonthEnd(), datetime(2008, 1, 1), False), + ] + + @pytest.mark.parametrize("case", on_offset_cases) + def test_is_on_offset(self, case): + offset, dt, expected = case + assert_is_on_offset(offset, dt, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_offsets.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_offsets.py new file mode 100644 index 0000000000000000000000000000000000000000..62afb8b83d576a7a16565840b6c3f61cfd26e9e1 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_offsets.py @@ -0,0 +1,1185 @@ +""" +Tests of pandas.tseries.offsets +""" +from __future__ import annotations + +from datetime import ( + datetime, + timedelta, +) + +import numpy as np +import pytest + +from pandas._libs.tslibs import ( + NaT, + Timedelta, + Timestamp, + conversion, + timezones, +) +import pandas._libs.tslibs.offsets as liboffsets +from pandas._libs.tslibs.offsets import ( + _get_offset, + _offset_map, + to_offset, +) +from pandas._libs.tslibs.period import INVALID_FREQ_ERR_MSG +from pandas.errors import PerformanceWarning + +from pandas import ( + DataFrame, + DatetimeIndex, + Series, + date_range, +) +import pandas._testing as tm +from pandas.tests.tseries.offsets.common import WeekDay + +from pandas.tseries import offsets +from pandas.tseries.offsets import ( + FY5253, + BDay, + BMonthEnd, + BusinessHour, + CustomBusinessDay, + CustomBusinessHour, + CustomBusinessMonthBegin, + CustomBusinessMonthEnd, + DateOffset, + Easter, + FY5253Quarter, + LastWeekOfMonth, + MonthBegin, + Nano, + Tick, + Week, + WeekOfMonth, +) + +_ARITHMETIC_DATE_OFFSET = [ + "years", + "months", + "weeks", + "days", + "hours", + "minutes", + "seconds", + "milliseconds", + "microseconds", +] + + +def _create_offset(klass, value=1, normalize=False): + # create instance from offset class + if klass is FY5253: + klass = klass( + n=value, + startingMonth=1, + weekday=1, + variation="last", + normalize=normalize, + ) + elif klass is FY5253Quarter: + klass = klass( + n=value, + startingMonth=1, + weekday=1, + qtr_with_extra_week=1, + variation="last", + normalize=normalize, + ) + elif klass is LastWeekOfMonth: + klass = klass(n=value, weekday=5, normalize=normalize) + elif klass is WeekOfMonth: + klass = klass(n=value, week=1, weekday=5, normalize=normalize) + elif klass is Week: + klass = klass(n=value, weekday=5, normalize=normalize) + elif klass is DateOffset: + klass = klass(days=value, normalize=normalize) + else: + klass = klass(value, normalize=normalize) + return klass + + +@pytest.fixture( + params=[ + getattr(offsets, o) + for o in offsets.__all__ + if issubclass(getattr(offsets, o), liboffsets.MonthOffset) + and o != "MonthOffset" + ] +) +def month_classes(request): + """ + Fixture for month based datetime offsets available for a time series. + """ + return request.param + + +@pytest.fixture( + params=[ + getattr(offsets, o) for o in offsets.__all__ if o not in ("Tick", "BaseOffset") + ] +) +def offset_types(request): + """ + Fixture for all the datetime offsets available for a time series. + """ + return request.param + + +@pytest.fixture +def dt(): + return Timestamp(datetime(2008, 1, 2)) + + +@pytest.fixture +def expecteds(): + # executed value created by _create_offset + # are applied to 2011/01/01 09:00 (Saturday) + # used for .apply and .rollforward + return { + "Day": Timestamp("2011-01-02 09:00:00"), + "DateOffset": Timestamp("2011-01-02 09:00:00"), + "BusinessDay": Timestamp("2011-01-03 09:00:00"), + "CustomBusinessDay": Timestamp("2011-01-03 09:00:00"), + "CustomBusinessMonthEnd": Timestamp("2011-01-31 09:00:00"), + "CustomBusinessMonthBegin": Timestamp("2011-01-03 09:00:00"), + "MonthBegin": Timestamp("2011-02-01 09:00:00"), + "BusinessMonthBegin": Timestamp("2011-01-03 09:00:00"), + "MonthEnd": Timestamp("2011-01-31 09:00:00"), + "SemiMonthEnd": Timestamp("2011-01-15 09:00:00"), + "SemiMonthBegin": Timestamp("2011-01-15 09:00:00"), + "BusinessMonthEnd": Timestamp("2011-01-31 09:00:00"), + "YearBegin": Timestamp("2012-01-01 09:00:00"), + "BYearBegin": Timestamp("2011-01-03 09:00:00"), + "YearEnd": Timestamp("2011-12-31 09:00:00"), + "BYearEnd": Timestamp("2011-12-30 09:00:00"), + "QuarterBegin": Timestamp("2011-03-01 09:00:00"), + "BQuarterBegin": Timestamp("2011-03-01 09:00:00"), + "QuarterEnd": Timestamp("2011-03-31 09:00:00"), + "BQuarterEnd": Timestamp("2011-03-31 09:00:00"), + "BusinessHour": Timestamp("2011-01-03 10:00:00"), + "CustomBusinessHour": Timestamp("2011-01-03 10:00:00"), + "WeekOfMonth": Timestamp("2011-01-08 09:00:00"), + "LastWeekOfMonth": Timestamp("2011-01-29 09:00:00"), + "FY5253Quarter": Timestamp("2011-01-25 09:00:00"), + "FY5253": Timestamp("2011-01-25 09:00:00"), + "Week": Timestamp("2011-01-08 09:00:00"), + "Easter": Timestamp("2011-04-24 09:00:00"), + "Hour": Timestamp("2011-01-01 10:00:00"), + "Minute": Timestamp("2011-01-01 09:01:00"), + "Second": Timestamp("2011-01-01 09:00:01"), + "Milli": Timestamp("2011-01-01 09:00:00.001000"), + "Micro": Timestamp("2011-01-01 09:00:00.000001"), + "Nano": Timestamp("2011-01-01T09:00:00.000000001"), + } + + +class TestCommon: + def test_immutable(self, offset_types): + # GH#21341 check that __setattr__ raises + offset = _create_offset(offset_types) + msg = "objects is not writable|DateOffset objects are immutable" + with pytest.raises(AttributeError, match=msg): + offset.normalize = True + with pytest.raises(AttributeError, match=msg): + offset.n = 91 + + def test_return_type(self, offset_types): + offset = _create_offset(offset_types) + + # make sure that we are returning a Timestamp + result = Timestamp("20080101") + offset + assert isinstance(result, Timestamp) + + # make sure that we are returning NaT + assert NaT + offset is NaT + assert offset + NaT is NaT + + assert NaT - offset is NaT + assert (-offset)._apply(NaT) is NaT + + def test_offset_n(self, offset_types): + offset = _create_offset(offset_types) + assert offset.n == 1 + + neg_offset = offset * -1 + assert neg_offset.n == -1 + + mul_offset = offset * 3 + assert mul_offset.n == 3 + + def test_offset_timedelta64_arg(self, offset_types): + # check that offset._validate_n raises TypeError on a timedelt64 + # object + off = _create_offset(offset_types) + + td64 = np.timedelta64(4567, "s") + with pytest.raises(TypeError, match="argument must be an integer"): + type(off)(n=td64, **off.kwds) + + def test_offset_mul_ndarray(self, offset_types): + off = _create_offset(offset_types) + + expected = np.array([[off, off * 2], [off * 3, off * 4]]) + + result = np.array([[1, 2], [3, 4]]) * off + tm.assert_numpy_array_equal(result, expected) + + result = off * np.array([[1, 2], [3, 4]]) + tm.assert_numpy_array_equal(result, expected) + + def test_offset_freqstr(self, offset_types): + offset = _create_offset(offset_types) + + freqstr = offset.freqstr + if freqstr not in ("", "", "LWOM-SAT"): + code = _get_offset(freqstr) + assert offset.rule_code == code + + def _check_offsetfunc_works(self, offset, funcname, dt, expected, normalize=False): + if normalize and issubclass(offset, Tick): + # normalize=True disallowed for Tick subclasses GH#21427 + return + + offset_s = _create_offset(offset, normalize=normalize) + func = getattr(offset_s, funcname) + + result = func(dt) + assert isinstance(result, Timestamp) + assert result == expected + + result = func(Timestamp(dt)) + assert isinstance(result, Timestamp) + assert result == expected + + # see gh-14101 + ts = Timestamp(dt) + Nano(5) + # test nanosecond is preserved + with tm.assert_produces_warning(None): + result = func(ts) + + assert isinstance(result, Timestamp) + if normalize is False: + assert result == expected + Nano(5) + else: + assert result == expected + + if isinstance(dt, np.datetime64): + # test tz when input is datetime or Timestamp + return + + for tz in [ + None, + "UTC", + "Asia/Tokyo", + "US/Eastern", + "dateutil/Asia/Tokyo", + "dateutil/US/Pacific", + ]: + expected_localize = expected.tz_localize(tz) + tz_obj = timezones.maybe_get_tz(tz) + dt_tz = conversion.localize_pydatetime(dt, tz_obj) + + result = func(dt_tz) + assert isinstance(result, Timestamp) + assert result == expected_localize + + result = func(Timestamp(dt, tz=tz)) + assert isinstance(result, Timestamp) + assert result == expected_localize + + # see gh-14101 + ts = Timestamp(dt, tz=tz) + Nano(5) + # test nanosecond is preserved + with tm.assert_produces_warning(None): + result = func(ts) + assert isinstance(result, Timestamp) + if normalize is False: + assert result == expected_localize + Nano(5) + else: + assert result == expected_localize + + def test_apply(self, offset_types, expecteds): + sdt = datetime(2011, 1, 1, 9, 0) + ndt = np.datetime64("2011-01-01 09:00") + + expected = expecteds[offset_types.__name__] + expected_norm = Timestamp(expected.date()) + + for dt in [sdt, ndt]: + self._check_offsetfunc_works(offset_types, "_apply", dt, expected) + + self._check_offsetfunc_works( + offset_types, "_apply", dt, expected_norm, normalize=True + ) + + def test_rollforward(self, offset_types, expecteds): + expecteds = expecteds.copy() + + # result will not be changed if the target is on the offset + no_changes = [ + "Day", + "MonthBegin", + "SemiMonthBegin", + "YearBegin", + "Week", + "Hour", + "Minute", + "Second", + "Milli", + "Micro", + "Nano", + "DateOffset", + ] + for n in no_changes: + expecteds[n] = Timestamp("2011/01/01 09:00") + + expecteds["BusinessHour"] = Timestamp("2011-01-03 09:00:00") + expecteds["CustomBusinessHour"] = Timestamp("2011-01-03 09:00:00") + + # but be changed when normalize=True + norm_expected = expecteds.copy() + for k in norm_expected: + norm_expected[k] = Timestamp(norm_expected[k].date()) + + normalized = { + "Day": Timestamp("2011-01-02 00:00:00"), + "DateOffset": Timestamp("2011-01-02 00:00:00"), + "MonthBegin": Timestamp("2011-02-01 00:00:00"), + "SemiMonthBegin": Timestamp("2011-01-15 00:00:00"), + "YearBegin": Timestamp("2012-01-01 00:00:00"), + "Week": Timestamp("2011-01-08 00:00:00"), + "Hour": Timestamp("2011-01-01 00:00:00"), + "Minute": Timestamp("2011-01-01 00:00:00"), + "Second": Timestamp("2011-01-01 00:00:00"), + "Milli": Timestamp("2011-01-01 00:00:00"), + "Micro": Timestamp("2011-01-01 00:00:00"), + } + norm_expected.update(normalized) + + sdt = datetime(2011, 1, 1, 9, 0) + ndt = np.datetime64("2011-01-01 09:00") + + for dt in [sdt, ndt]: + expected = expecteds[offset_types.__name__] + self._check_offsetfunc_works(offset_types, "rollforward", dt, expected) + expected = norm_expected[offset_types.__name__] + self._check_offsetfunc_works( + offset_types, "rollforward", dt, expected, normalize=True + ) + + def test_rollback(self, offset_types): + expecteds = { + "BusinessDay": Timestamp("2010-12-31 09:00:00"), + "CustomBusinessDay": Timestamp("2010-12-31 09:00:00"), + "CustomBusinessMonthEnd": Timestamp("2010-12-31 09:00:00"), + "CustomBusinessMonthBegin": Timestamp("2010-12-01 09:00:00"), + "BusinessMonthBegin": Timestamp("2010-12-01 09:00:00"), + "MonthEnd": Timestamp("2010-12-31 09:00:00"), + "SemiMonthEnd": Timestamp("2010-12-31 09:00:00"), + "BusinessMonthEnd": Timestamp("2010-12-31 09:00:00"), + "BYearBegin": Timestamp("2010-01-01 09:00:00"), + "YearEnd": Timestamp("2010-12-31 09:00:00"), + "BYearEnd": Timestamp("2010-12-31 09:00:00"), + "QuarterBegin": Timestamp("2010-12-01 09:00:00"), + "BQuarterBegin": Timestamp("2010-12-01 09:00:00"), + "QuarterEnd": Timestamp("2010-12-31 09:00:00"), + "BQuarterEnd": Timestamp("2010-12-31 09:00:00"), + "BusinessHour": Timestamp("2010-12-31 17:00:00"), + "CustomBusinessHour": Timestamp("2010-12-31 17:00:00"), + "WeekOfMonth": Timestamp("2010-12-11 09:00:00"), + "LastWeekOfMonth": Timestamp("2010-12-25 09:00:00"), + "FY5253Quarter": Timestamp("2010-10-26 09:00:00"), + "FY5253": Timestamp("2010-01-26 09:00:00"), + "Easter": Timestamp("2010-04-04 09:00:00"), + } + + # result will not be changed if the target is on the offset + for n in [ + "Day", + "MonthBegin", + "SemiMonthBegin", + "YearBegin", + "Week", + "Hour", + "Minute", + "Second", + "Milli", + "Micro", + "Nano", + "DateOffset", + ]: + expecteds[n] = Timestamp("2011/01/01 09:00") + + # but be changed when normalize=True + norm_expected = expecteds.copy() + for k in norm_expected: + norm_expected[k] = Timestamp(norm_expected[k].date()) + + normalized = { + "Day": Timestamp("2010-12-31 00:00:00"), + "DateOffset": Timestamp("2010-12-31 00:00:00"), + "MonthBegin": Timestamp("2010-12-01 00:00:00"), + "SemiMonthBegin": Timestamp("2010-12-15 00:00:00"), + "YearBegin": Timestamp("2010-01-01 00:00:00"), + "Week": Timestamp("2010-12-25 00:00:00"), + "Hour": Timestamp("2011-01-01 00:00:00"), + "Minute": Timestamp("2011-01-01 00:00:00"), + "Second": Timestamp("2011-01-01 00:00:00"), + "Milli": Timestamp("2011-01-01 00:00:00"), + "Micro": Timestamp("2011-01-01 00:00:00"), + } + norm_expected.update(normalized) + + sdt = datetime(2011, 1, 1, 9, 0) + ndt = np.datetime64("2011-01-01 09:00") + + for dt in [sdt, ndt]: + expected = expecteds[offset_types.__name__] + self._check_offsetfunc_works(offset_types, "rollback", dt, expected) + + expected = norm_expected[offset_types.__name__] + self._check_offsetfunc_works( + offset_types, "rollback", dt, expected, normalize=True + ) + + def test_is_on_offset(self, offset_types, expecteds): + dt = expecteds[offset_types.__name__] + offset_s = _create_offset(offset_types) + assert offset_s.is_on_offset(dt) + + # when normalize=True, is_on_offset checks time is 00:00:00 + if issubclass(offset_types, Tick): + # normalize=True disallowed for Tick subclasses GH#21427 + return + offset_n = _create_offset(offset_types, normalize=True) + assert not offset_n.is_on_offset(dt) + + if offset_types in (BusinessHour, CustomBusinessHour): + # In default BusinessHour (9:00-17:00), normalized time + # cannot be in business hour range + return + date = datetime(dt.year, dt.month, dt.day) + assert offset_n.is_on_offset(date) + + def test_add(self, offset_types, tz_naive_fixture, expecteds): + tz = tz_naive_fixture + dt = datetime(2011, 1, 1, 9, 0) + + offset_s = _create_offset(offset_types) + expected = expecteds[offset_types.__name__] + + result_dt = dt + offset_s + result_ts = Timestamp(dt) + offset_s + for result in [result_dt, result_ts]: + assert isinstance(result, Timestamp) + assert result == expected + + expected_localize = expected.tz_localize(tz) + result = Timestamp(dt, tz=tz) + offset_s + assert isinstance(result, Timestamp) + assert result == expected_localize + + # normalize=True, disallowed for Tick subclasses GH#21427 + if issubclass(offset_types, Tick): + return + offset_s = _create_offset(offset_types, normalize=True) + expected = Timestamp(expected.date()) + + result_dt = dt + offset_s + result_ts = Timestamp(dt) + offset_s + for result in [result_dt, result_ts]: + assert isinstance(result, Timestamp) + assert result == expected + + expected_localize = expected.tz_localize(tz) + result = Timestamp(dt, tz=tz) + offset_s + assert isinstance(result, Timestamp) + assert result == expected_localize + + def test_add_empty_datetimeindex(self, offset_types, tz_naive_fixture): + # GH#12724, GH#30336 + offset_s = _create_offset(offset_types) + + dti = DatetimeIndex([], tz=tz_naive_fixture).as_unit("ns") + + warn = None + if isinstance( + offset_s, + ( + Easter, + WeekOfMonth, + LastWeekOfMonth, + CustomBusinessDay, + BusinessHour, + CustomBusinessHour, + CustomBusinessMonthBegin, + CustomBusinessMonthEnd, + FY5253, + FY5253Quarter, + ), + ): + # We don't have an optimized apply_index + warn = PerformanceWarning + + # stacklevel checking is slow, and we have ~800 of variants of this + # test, so let's only check the stacklevel in a subset of them + check_stacklevel = tz_naive_fixture is None + with tm.assert_produces_warning(warn, check_stacklevel=check_stacklevel): + result = dti + offset_s + tm.assert_index_equal(result, dti) + with tm.assert_produces_warning(warn, check_stacklevel=check_stacklevel): + result = offset_s + dti + tm.assert_index_equal(result, dti) + + dta = dti._data + with tm.assert_produces_warning(warn, check_stacklevel=check_stacklevel): + result = dta + offset_s + tm.assert_equal(result, dta) + with tm.assert_produces_warning(warn, check_stacklevel=check_stacklevel): + result = offset_s + dta + tm.assert_equal(result, dta) + + def test_pickle_roundtrip(self, offset_types): + off = _create_offset(offset_types) + res = tm.round_trip_pickle(off) + assert off == res + if type(off) is not DateOffset: + for attr in off._attributes: + if attr == "calendar": + # np.busdaycalendar __eq__ will return False; + # we check holidays and weekmask attrs so are OK + continue + # Make sure nothings got lost from _params (which __eq__) is based on + assert getattr(off, attr) == getattr(res, attr) + + def test_pickle_dateoffset_odd_inputs(self): + # GH#34511 + off = DateOffset(months=12) + res = tm.round_trip_pickle(off) + assert off == res + + base_dt = datetime(2020, 1, 1) + assert base_dt + off == base_dt + res + + def test_offsets_hashable(self, offset_types): + # GH: 37267 + off = _create_offset(offset_types) + assert hash(off) is not None + + # TODO: belongs in arithmetic tests? + @pytest.mark.filterwarnings( + "ignore:Non-vectorized DateOffset being applied to Series or DatetimeIndex" + ) + @pytest.mark.parametrize("unit", ["s", "ms", "us"]) + def test_add_dt64_ndarray_non_nano(self, offset_types, unit): + # check that the result with non-nano matches nano + off = _create_offset(offset_types) + + dti = date_range("2016-01-01", periods=35, freq="D", unit=unit) + + result = (dti + off)._with_freq(None) + + exp_unit = unit + if isinstance(off, Tick) and off._creso > dti._data._creso: + # cast to higher reso like we would with Timedelta scalar + exp_unit = Timedelta(off).unit + # TODO(GH#55564): as_unit will be unnecessary + expected = DatetimeIndex([x + off for x in dti]).as_unit(exp_unit) + + tm.assert_index_equal(result, expected) + + +class TestDateOffset: + def setup_method(self): + _offset_map.clear() + + def test_repr(self): + repr(DateOffset()) + repr(DateOffset(2)) + repr(2 * DateOffset()) + repr(2 * DateOffset(months=2)) + + def test_mul(self): + assert DateOffset(2) == 2 * DateOffset(1) + assert DateOffset(2) == DateOffset(1) * 2 + + @pytest.mark.parametrize("kwd", sorted(liboffsets._relativedelta_kwds)) + def test_constructor(self, kwd, request): + if kwd == "millisecond": + request.applymarker( + pytest.mark.xfail( + raises=NotImplementedError, + reason="Constructing DateOffset object with `millisecond` is not " + "yet supported.", + ) + ) + offset = DateOffset(**{kwd: 2}) + assert offset.kwds == {kwd: 2} + assert getattr(offset, kwd) == 2 + + def test_default_constructor(self, dt): + assert (dt + DateOffset(2)) == datetime(2008, 1, 4) + + def test_is_anchored(self): + msg = "DateOffset.is_anchored is deprecated " + + with tm.assert_produces_warning(FutureWarning, match=msg): + assert not DateOffset(2).is_anchored() + assert DateOffset(1).is_anchored() + + def test_copy(self): + assert DateOffset(months=2).copy() == DateOffset(months=2) + assert DateOffset(milliseconds=1).copy() == DateOffset(milliseconds=1) + + @pytest.mark.parametrize( + "arithmatic_offset_type, expected", + zip( + _ARITHMETIC_DATE_OFFSET, + [ + "2009-01-02", + "2008-02-02", + "2008-01-09", + "2008-01-03", + "2008-01-02 01:00:00", + "2008-01-02 00:01:00", + "2008-01-02 00:00:01", + "2008-01-02 00:00:00.001000000", + "2008-01-02 00:00:00.000001000", + ], + ), + ) + def test_add(self, arithmatic_offset_type, expected, dt): + assert DateOffset(**{arithmatic_offset_type: 1}) + dt == Timestamp(expected) + assert dt + DateOffset(**{arithmatic_offset_type: 1}) == Timestamp(expected) + + @pytest.mark.parametrize( + "arithmatic_offset_type, expected", + zip( + _ARITHMETIC_DATE_OFFSET, + [ + "2007-01-02", + "2007-12-02", + "2007-12-26", + "2008-01-01", + "2008-01-01 23:00:00", + "2008-01-01 23:59:00", + "2008-01-01 23:59:59", + "2008-01-01 23:59:59.999000000", + "2008-01-01 23:59:59.999999000", + ], + ), + ) + def test_sub(self, arithmatic_offset_type, expected, dt): + assert dt - DateOffset(**{arithmatic_offset_type: 1}) == Timestamp(expected) + with pytest.raises(TypeError, match="Cannot subtract datetime from offset"): + DateOffset(**{arithmatic_offset_type: 1}) - dt + + @pytest.mark.parametrize( + "arithmatic_offset_type, n, expected", + zip( + _ARITHMETIC_DATE_OFFSET, + range(1, 10), + [ + "2009-01-02", + "2008-03-02", + "2008-01-23", + "2008-01-06", + "2008-01-02 05:00:00", + "2008-01-02 00:06:00", + "2008-01-02 00:00:07", + "2008-01-02 00:00:00.008000000", + "2008-01-02 00:00:00.000009000", + ], + ), + ) + def test_mul_add(self, arithmatic_offset_type, n, expected, dt): + assert DateOffset(**{arithmatic_offset_type: 1}) * n + dt == Timestamp(expected) + assert n * DateOffset(**{arithmatic_offset_type: 1}) + dt == Timestamp(expected) + assert dt + DateOffset(**{arithmatic_offset_type: 1}) * n == Timestamp(expected) + assert dt + n * DateOffset(**{arithmatic_offset_type: 1}) == Timestamp(expected) + + @pytest.mark.parametrize( + "arithmatic_offset_type, n, expected", + zip( + _ARITHMETIC_DATE_OFFSET, + range(1, 10), + [ + "2007-01-02", + "2007-11-02", + "2007-12-12", + "2007-12-29", + "2008-01-01 19:00:00", + "2008-01-01 23:54:00", + "2008-01-01 23:59:53", + "2008-01-01 23:59:59.992000000", + "2008-01-01 23:59:59.999991000", + ], + ), + ) + def test_mul_sub(self, arithmatic_offset_type, n, expected, dt): + assert dt - DateOffset(**{arithmatic_offset_type: 1}) * n == Timestamp(expected) + assert dt - n * DateOffset(**{arithmatic_offset_type: 1}) == Timestamp(expected) + + def test_leap_year(self): + d = datetime(2008, 1, 31) + assert (d + DateOffset(months=1)) == datetime(2008, 2, 29) + + def test_eq(self): + offset1 = DateOffset(days=1) + offset2 = DateOffset(days=365) + + assert offset1 != offset2 + + assert DateOffset(milliseconds=3) != DateOffset(milliseconds=7) + + @pytest.mark.parametrize( + "offset_kwargs, expected_arg", + [ + ({"microseconds": 1, "milliseconds": 1}, "2022-01-01 00:00:00.001001"), + ({"seconds": 1, "milliseconds": 1}, "2022-01-01 00:00:01.001"), + ({"minutes": 1, "milliseconds": 1}, "2022-01-01 00:01:00.001"), + ({"hours": 1, "milliseconds": 1}, "2022-01-01 01:00:00.001"), + ({"days": 1, "milliseconds": 1}, "2022-01-02 00:00:00.001"), + ({"weeks": 1, "milliseconds": 1}, "2022-01-08 00:00:00.001"), + ({"months": 1, "milliseconds": 1}, "2022-02-01 00:00:00.001"), + ({"years": 1, "milliseconds": 1}, "2023-01-01 00:00:00.001"), + ], + ) + def test_milliseconds_combination(self, offset_kwargs, expected_arg): + # GH 49897 + offset = DateOffset(**offset_kwargs) + ts = Timestamp("2022-01-01") + result = ts + offset + expected = Timestamp(expected_arg) + + assert result == expected + + def test_offset_invalid_arguments(self): + msg = "^Invalid argument/s or bad combination of arguments" + with pytest.raises(ValueError, match=msg): + DateOffset(picoseconds=1) + + +class TestOffsetNames: + def test_get_offset_name(self): + assert BDay().freqstr == "B" + assert BDay(2).freqstr == "2B" + assert BMonthEnd().freqstr == "BME" + assert Week(weekday=0).freqstr == "W-MON" + assert Week(weekday=1).freqstr == "W-TUE" + assert Week(weekday=2).freqstr == "W-WED" + assert Week(weekday=3).freqstr == "W-THU" + assert Week(weekday=4).freqstr == "W-FRI" + + assert LastWeekOfMonth(weekday=WeekDay.SUN).freqstr == "LWOM-SUN" + + +def test_get_offset(): + with pytest.raises(ValueError, match=INVALID_FREQ_ERR_MSG): + _get_offset("gibberish") + with pytest.raises(ValueError, match=INVALID_FREQ_ERR_MSG): + _get_offset("QS-JAN-B") + + pairs = [ + ("B", BDay()), + ("b", BDay()), + ("bme", BMonthEnd()), + ("Bme", BMonthEnd()), + ("W-MON", Week(weekday=0)), + ("W-TUE", Week(weekday=1)), + ("W-WED", Week(weekday=2)), + ("W-THU", Week(weekday=3)), + ("W-FRI", Week(weekday=4)), + ] + + for name, expected in pairs: + offset = _get_offset(name) + assert offset == expected, ( + f"Expected {repr(name)} to yield {repr(expected)} " + f"(actual: {repr(offset)})" + ) + + +def test_get_offset_legacy(): + pairs = [("w@Sat", Week(weekday=5))] + for name, expected in pairs: + with pytest.raises(ValueError, match=INVALID_FREQ_ERR_MSG): + _get_offset(name) + + +class TestOffsetAliases: + def setup_method(self): + _offset_map.clear() + + def test_alias_equality(self): + for k, v in _offset_map.items(): + if v is None: + continue + assert k == v.copy() + + def test_rule_code(self): + lst = ["ME", "MS", "BME", "BMS", "D", "B", "h", "min", "s", "ms", "us"] + for k in lst: + assert k == _get_offset(k).rule_code + # should be cached - this is kind of an internals test... + assert k in _offset_map + assert k == (_get_offset(k) * 3).rule_code + + suffix_lst = ["MON", "TUE", "WED", "THU", "FRI", "SAT", "SUN"] + base = "W" + for v in suffix_lst: + alias = "-".join([base, v]) + assert alias == _get_offset(alias).rule_code + assert alias == (_get_offset(alias) * 5).rule_code + + suffix_lst = [ + "JAN", + "FEB", + "MAR", + "APR", + "MAY", + "JUN", + "JUL", + "AUG", + "SEP", + "OCT", + "NOV", + "DEC", + ] + base_lst = ["YE", "YS", "BYE", "BYS", "QE", "QS", "BQE", "BQS"] + for base in base_lst: + for v in suffix_lst: + alias = "-".join([base, v]) + assert alias == _get_offset(alias).rule_code + assert alias == (_get_offset(alias) * 5).rule_code + + +def test_freq_offsets(): + off = BDay(1, offset=timedelta(0, 1800)) + assert off.freqstr == "B+30Min" + + off = BDay(1, offset=timedelta(0, -1800)) + assert off.freqstr == "B-30Min" + + +class TestReprNames: + def test_str_for_named_is_name(self): + # look at all the amazing combinations! + month_prefixes = ["YE", "YS", "BYE", "BYS", "QE", "BQE", "BQS", "QS"] + names = [ + prefix + "-" + month + for prefix in month_prefixes + for month in [ + "JAN", + "FEB", + "MAR", + "APR", + "MAY", + "JUN", + "JUL", + "AUG", + "SEP", + "OCT", + "NOV", + "DEC", + ] + ] + days = ["MON", "TUE", "WED", "THU", "FRI", "SAT", "SUN"] + names += ["W-" + day for day in days] + names += ["WOM-" + week + day for week in ("1", "2", "3", "4") for day in days] + _offset_map.clear() + for name in names: + offset = _get_offset(name) + assert offset.freqstr == name + + +# --------------------------------------------------------------------- + + +def test_valid_default_arguments(offset_types): + # GH#19142 check that the calling the constructors without passing + # any keyword arguments produce valid offsets + cls = offset_types + cls() + + +@pytest.mark.parametrize("kwd", sorted(liboffsets._relativedelta_kwds)) +def test_valid_month_attributes(kwd, month_classes): + # GH#18226 + cls = month_classes + # check that we cannot create e.g. MonthEnd(weeks=3) + msg = rf"__init__\(\) got an unexpected keyword argument '{kwd}'" + with pytest.raises(TypeError, match=msg): + cls(**{kwd: 3}) + + +def test_month_offset_name(month_classes): + # GH#33757 off.name with n != 1 should not raise AttributeError + obj = month_classes(1) + obj2 = month_classes(2) + assert obj2.name == obj.name + + +@pytest.mark.parametrize("kwd", sorted(liboffsets._relativedelta_kwds)) +def test_valid_relativedelta_kwargs(kwd, request): + if kwd == "millisecond": + request.applymarker( + pytest.mark.xfail( + raises=NotImplementedError, + reason="Constructing DateOffset object with `millisecond` is not " + "yet supported.", + ) + ) + # Check that all the arguments specified in liboffsets._relativedelta_kwds + # are in fact valid relativedelta keyword args + DateOffset(**{kwd: 1}) + + +@pytest.mark.parametrize("kwd", sorted(liboffsets._relativedelta_kwds)) +def test_valid_tick_attributes(kwd, tick_classes): + # GH#18226 + cls = tick_classes + # check that we cannot create e.g. Hour(weeks=3) + msg = rf"__init__\(\) got an unexpected keyword argument '{kwd}'" + with pytest.raises(TypeError, match=msg): + cls(**{kwd: 3}) + + +def test_validate_n_error(): + with pytest.raises(TypeError, match="argument must be an integer"): + DateOffset(n="Doh!") + + with pytest.raises(TypeError, match="argument must be an integer"): + MonthBegin(n=timedelta(1)) + + with pytest.raises(TypeError, match="argument must be an integer"): + BDay(n=np.array([1, 2], dtype=np.int64)) + + +def test_require_integers(offset_types): + cls = offset_types + with pytest.raises(ValueError, match="argument must be an integer"): + cls(n=1.5) + + +def test_tick_normalize_raises(tick_classes): + # check that trying to create a Tick object with normalize=True raises + # GH#21427 + cls = tick_classes + msg = "Tick offset with `normalize=True` are not allowed." + with pytest.raises(ValueError, match=msg): + cls(n=3, normalize=True) + + +@pytest.mark.parametrize( + "offset_kwargs, expected_arg", + [ + ({"nanoseconds": 1}, "1970-01-01 00:00:00.000000001"), + ({"nanoseconds": 5}, "1970-01-01 00:00:00.000000005"), + ({"nanoseconds": -1}, "1969-12-31 23:59:59.999999999"), + ({"microseconds": 1}, "1970-01-01 00:00:00.000001"), + ({"microseconds": -1}, "1969-12-31 23:59:59.999999"), + ({"seconds": 1}, "1970-01-01 00:00:01"), + ({"seconds": -1}, "1969-12-31 23:59:59"), + ({"minutes": 1}, "1970-01-01 00:01:00"), + ({"minutes": -1}, "1969-12-31 23:59:00"), + ({"hours": 1}, "1970-01-01 01:00:00"), + ({"hours": -1}, "1969-12-31 23:00:00"), + ({"days": 1}, "1970-01-02 00:00:00"), + ({"days": -1}, "1969-12-31 00:00:00"), + ({"weeks": 1}, "1970-01-08 00:00:00"), + ({"weeks": -1}, "1969-12-25 00:00:00"), + ({"months": 1}, "1970-02-01 00:00:00"), + ({"months": -1}, "1969-12-01 00:00:00"), + ({"years": 1}, "1971-01-01 00:00:00"), + ({"years": -1}, "1969-01-01 00:00:00"), + ], +) +def test_dateoffset_add_sub(offset_kwargs, expected_arg): + offset = DateOffset(**offset_kwargs) + ts = Timestamp(0) + result = ts + offset + expected = Timestamp(expected_arg) + assert result == expected + result -= offset + assert result == ts + result = offset + ts + assert result == expected + + +def test_dateoffset_add_sub_timestamp_with_nano(): + offset = DateOffset(minutes=2, nanoseconds=9) + ts = Timestamp(4) + result = ts + offset + expected = Timestamp("1970-01-01 00:02:00.000000013") + assert result == expected + result -= offset + assert result == ts + result = offset + ts + assert result == expected + + offset2 = DateOffset(minutes=2, nanoseconds=9, hour=1) + assert offset2._use_relativedelta + with tm.assert_produces_warning(None): + # no warning about Discarding nonzero nanoseconds + result2 = ts + offset2 + expected2 = Timestamp("1970-01-01 01:02:00.000000013") + assert result2 == expected2 + + +@pytest.mark.parametrize( + "attribute", + [ + "hours", + "days", + "weeks", + "months", + "years", + ], +) +def test_dateoffset_immutable(attribute): + offset = DateOffset(**{attribute: 0}) + msg = "DateOffset objects are immutable" + with pytest.raises(AttributeError, match=msg): + setattr(offset, attribute, 5) + + +def test_dateoffset_misc(): + oset = offsets.DateOffset(months=2, days=4) + # it works + oset.freqstr + + assert not offsets.DateOffset(months=2) == 2 + + +@pytest.mark.parametrize("n", [-1, 1, 3]) +def test_construct_int_arg_no_kwargs_assumed_days(n): + # GH 45890, 45643 + offset = DateOffset(n) + assert offset._offset == timedelta(1) + result = Timestamp(2022, 1, 2) + offset + expected = Timestamp(2022, 1, 2 + n) + assert result == expected + + +@pytest.mark.parametrize( + "offset, expected", + [ + ( + DateOffset(minutes=7, nanoseconds=18), + Timestamp("2022-01-01 00:07:00.000000018"), + ), + (DateOffset(nanoseconds=3), Timestamp("2022-01-01 00:00:00.000000003")), + ], +) +def test_dateoffset_add_sub_timestamp_series_with_nano(offset, expected): + # GH 47856 + start_time = Timestamp("2022-01-01") + teststamp = start_time + testseries = Series([start_time]) + testseries = testseries + offset + assert testseries[0] == expected + testseries -= offset + assert testseries[0] == teststamp + testseries = offset + testseries + assert testseries[0] == expected + + +@pytest.mark.parametrize( + "n_months, scaling_factor, start_timestamp, expected_timestamp", + [ + (1, 2, "2020-01-30", "2020-03-30"), + (2, 1, "2020-01-30", "2020-03-30"), + (1, 0, "2020-01-30", "2020-01-30"), + (2, 0, "2020-01-30", "2020-01-30"), + (1, -1, "2020-01-30", "2019-12-30"), + (2, -1, "2020-01-30", "2019-11-30"), + ], +) +def test_offset_multiplication( + n_months, scaling_factor, start_timestamp, expected_timestamp +): + # GH 47953 + mo1 = DateOffset(months=n_months) + + startscalar = Timestamp(start_timestamp) + startarray = Series([startscalar]) + + resultscalar = startscalar + (mo1 * scaling_factor) + resultarray = startarray + (mo1 * scaling_factor) + + expectedscalar = Timestamp(expected_timestamp) + expectedarray = Series([expectedscalar]) + assert resultscalar == expectedscalar + + tm.assert_series_equal(resultarray, expectedarray) + + +def test_dateoffset_operations_on_dataframes(): + # GH 47953 + df = DataFrame({"T": [Timestamp("2019-04-30")], "D": [DateOffset(months=1)]}) + frameresult1 = df["T"] + 26 * df["D"] + df2 = DataFrame( + { + "T": [Timestamp("2019-04-30"), Timestamp("2019-04-30")], + "D": [DateOffset(months=1), DateOffset(months=1)], + } + ) + expecteddate = Timestamp("2021-06-30") + with tm.assert_produces_warning(PerformanceWarning): + frameresult2 = df2["T"] + 26 * df2["D"] + + assert frameresult1[0] == expecteddate + assert frameresult2[0] == expecteddate + + +def test_is_yqm_start_end(): + freq_m = to_offset("ME") + bm = to_offset("BME") + qfeb = to_offset("QE-FEB") + qsfeb = to_offset("QS-FEB") + bq = to_offset("BQE") + bqs_apr = to_offset("BQS-APR") + as_nov = to_offset("YS-NOV") + + tests = [ + (freq_m.is_month_start(Timestamp("2013-06-01")), 1), + (bm.is_month_start(Timestamp("2013-06-01")), 0), + (freq_m.is_month_start(Timestamp("2013-06-03")), 0), + (bm.is_month_start(Timestamp("2013-06-03")), 1), + (qfeb.is_month_end(Timestamp("2013-02-28")), 1), + (qfeb.is_quarter_end(Timestamp("2013-02-28")), 1), + (qfeb.is_year_end(Timestamp("2013-02-28")), 1), + (qfeb.is_month_start(Timestamp("2013-03-01")), 1), + (qfeb.is_quarter_start(Timestamp("2013-03-01")), 1), + (qfeb.is_year_start(Timestamp("2013-03-01")), 1), + (qsfeb.is_month_end(Timestamp("2013-03-31")), 1), + (qsfeb.is_quarter_end(Timestamp("2013-03-31")), 0), + (qsfeb.is_year_end(Timestamp("2013-03-31")), 0), + (qsfeb.is_month_start(Timestamp("2013-02-01")), 1), + (qsfeb.is_quarter_start(Timestamp("2013-02-01")), 1), + (qsfeb.is_year_start(Timestamp("2013-02-01")), 1), + (bq.is_month_end(Timestamp("2013-06-30")), 0), + (bq.is_quarter_end(Timestamp("2013-06-30")), 0), + (bq.is_year_end(Timestamp("2013-06-30")), 0), + (bq.is_month_end(Timestamp("2013-06-28")), 1), + (bq.is_quarter_end(Timestamp("2013-06-28")), 1), + (bq.is_year_end(Timestamp("2013-06-28")), 0), + (bqs_apr.is_month_end(Timestamp("2013-06-30")), 0), + (bqs_apr.is_quarter_end(Timestamp("2013-06-30")), 0), + (bqs_apr.is_year_end(Timestamp("2013-06-30")), 0), + (bqs_apr.is_month_end(Timestamp("2013-06-28")), 1), + (bqs_apr.is_quarter_end(Timestamp("2013-06-28")), 1), + (bqs_apr.is_year_end(Timestamp("2013-03-29")), 1), + (as_nov.is_year_start(Timestamp("2013-11-01")), 1), + (as_nov.is_year_end(Timestamp("2013-10-31")), 1), + (Timestamp("2012-02-01").days_in_month, 29), + (Timestamp("2013-02-01").days_in_month, 28), + ] + + for ts, value in tests: + assert ts == value diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_offsets_properties.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_offsets_properties.py new file mode 100644 index 0000000000000000000000000000000000000000..1b4fa9292c4031c8c2acec0e1f34fd871bcb50bd --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_offsets_properties.py @@ -0,0 +1,60 @@ +""" +Behavioral based tests for offsets and date_range. + +This file is adapted from https://github.com/pandas-dev/pandas/pull/18761 - +which was more ambitious but less idiomatic in its use of Hypothesis. + +You may wish to consult the previous version for inspiration on further +tests, or when trying to pin down the bugs exposed by the tests below. +""" +from hypothesis import ( + assume, + given, +) +import pytest +import pytz + +import pandas as pd +from pandas._testing._hypothesis import ( + DATETIME_JAN_1_1900_OPTIONAL_TZ, + YQM_OFFSET, +) + +# ---------------------------------------------------------------- +# Offset-specific behaviour tests + + +@pytest.mark.arm_slow +@given(DATETIME_JAN_1_1900_OPTIONAL_TZ, YQM_OFFSET) +def test_on_offset_implementations(dt, offset): + assume(not offset.normalize) + # check that the class-specific implementations of is_on_offset match + # the general case definition: + # (dt + offset) - offset == dt + try: + compare = (dt + offset) - offset + except (pytz.NonExistentTimeError, pytz.AmbiguousTimeError): + # When dt + offset does not exist or is DST-ambiguous, assume(False) to + # indicate to hypothesis that this is not a valid test case + # DST-ambiguous example (GH41906): + # dt = datetime.datetime(1900, 1, 1, tzinfo=pytz.timezone('Africa/Kinshasa')) + # offset = MonthBegin(66) + assume(False) + + assert offset.is_on_offset(dt) == (compare == dt) + + +@given(YQM_OFFSET) +def test_shift_across_dst(offset): + # GH#18319 check that 1) timezone is correctly normalized and + # 2) that hour is not incorrectly changed by this normalization + assume(not offset.normalize) + + # Note that dti includes a transition across DST boundary + dti = pd.date_range( + start="2017-10-30 12:00:00", end="2017-11-06", freq="D", tz="US/Eastern" + ) + assert (dti.hour == 12).all() # we haven't screwed up yet + + res = dti + offset + assert (res.hour == 12).all() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_quarter.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_quarter.py new file mode 100644 index 0000000000000000000000000000000000000000..5fd3ba0a5fb87996a4e07fd25569e7161cb08930 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_quarter.py @@ -0,0 +1,303 @@ +""" +Tests for the following offsets: +- QuarterBegin +- QuarterEnd +""" +from __future__ import annotations + +from datetime import datetime + +import pytest + +import pandas._testing as tm +from pandas.tests.tseries.offsets.common import ( + assert_is_on_offset, + assert_offset_equal, +) + +from pandas.tseries.offsets import ( + QuarterBegin, + QuarterEnd, +) + + +@pytest.mark.parametrize("klass", (QuarterBegin, QuarterEnd)) +def test_quarterly_dont_normalize(klass): + date = datetime(2012, 3, 31, 5, 30) + result = date + klass() + assert result.time() == date.time() + + +@pytest.mark.parametrize("offset", [QuarterBegin(), QuarterEnd()]) +@pytest.mark.parametrize( + "date", + [ + datetime(2016, m, d) + for m in [10, 11, 12] + for d in [1, 2, 3, 28, 29, 30, 31] + if not (m == 11 and d == 31) + ], +) +def test_on_offset(offset, date): + res = offset.is_on_offset(date) + slow_version = date == (date + offset) - offset + assert res == slow_version + + +class TestQuarterBegin: + def test_repr(self): + expected = "" + assert repr(QuarterBegin()) == expected + expected = "" + assert repr(QuarterBegin(startingMonth=3)) == expected + expected = "" + assert repr(QuarterBegin(startingMonth=1)) == expected + + def test_is_anchored(self): + msg = "QuarterBegin.is_anchored is deprecated " + + with tm.assert_produces_warning(FutureWarning, match=msg): + assert QuarterBegin(startingMonth=1).is_anchored() + assert QuarterBegin().is_anchored() + assert not QuarterBegin(2, startingMonth=1).is_anchored() + + def test_offset_corner_case(self): + # corner + offset = QuarterBegin(n=-1, startingMonth=1) + assert datetime(2010, 2, 1) + offset == datetime(2010, 1, 1) + + offset_cases = [] + offset_cases.append( + ( + QuarterBegin(startingMonth=1), + { + datetime(2007, 12, 1): datetime(2008, 1, 1), + datetime(2008, 1, 1): datetime(2008, 4, 1), + datetime(2008, 2, 15): datetime(2008, 4, 1), + datetime(2008, 2, 29): datetime(2008, 4, 1), + datetime(2008, 3, 15): datetime(2008, 4, 1), + datetime(2008, 3, 31): datetime(2008, 4, 1), + datetime(2008, 4, 15): datetime(2008, 7, 1), + datetime(2008, 4, 1): datetime(2008, 7, 1), + }, + ) + ) + + offset_cases.append( + ( + QuarterBegin(startingMonth=2), + { + datetime(2008, 1, 1): datetime(2008, 2, 1), + datetime(2008, 1, 31): datetime(2008, 2, 1), + datetime(2008, 1, 15): datetime(2008, 2, 1), + datetime(2008, 2, 29): datetime(2008, 5, 1), + datetime(2008, 3, 15): datetime(2008, 5, 1), + datetime(2008, 3, 31): datetime(2008, 5, 1), + datetime(2008, 4, 15): datetime(2008, 5, 1), + datetime(2008, 4, 30): datetime(2008, 5, 1), + }, + ) + ) + + offset_cases.append( + ( + QuarterBegin(startingMonth=1, n=0), + { + datetime(2008, 1, 1): datetime(2008, 1, 1), + datetime(2008, 12, 1): datetime(2009, 1, 1), + datetime(2008, 1, 1): datetime(2008, 1, 1), + datetime(2008, 2, 15): datetime(2008, 4, 1), + datetime(2008, 2, 29): datetime(2008, 4, 1), + datetime(2008, 3, 15): datetime(2008, 4, 1), + datetime(2008, 3, 31): datetime(2008, 4, 1), + datetime(2008, 4, 15): datetime(2008, 7, 1), + datetime(2008, 4, 30): datetime(2008, 7, 1), + }, + ) + ) + + offset_cases.append( + ( + QuarterBegin(startingMonth=1, n=-1), + { + datetime(2008, 1, 1): datetime(2007, 10, 1), + datetime(2008, 1, 31): datetime(2008, 1, 1), + datetime(2008, 2, 15): datetime(2008, 1, 1), + datetime(2008, 2, 29): datetime(2008, 1, 1), + datetime(2008, 3, 15): datetime(2008, 1, 1), + datetime(2008, 3, 31): datetime(2008, 1, 1), + datetime(2008, 4, 15): datetime(2008, 4, 1), + datetime(2008, 4, 30): datetime(2008, 4, 1), + datetime(2008, 7, 1): datetime(2008, 4, 1), + }, + ) + ) + + offset_cases.append( + ( + QuarterBegin(startingMonth=1, n=2), + { + datetime(2008, 1, 1): datetime(2008, 7, 1), + datetime(2008, 2, 15): datetime(2008, 7, 1), + datetime(2008, 2, 29): datetime(2008, 7, 1), + datetime(2008, 3, 15): datetime(2008, 7, 1), + datetime(2008, 3, 31): datetime(2008, 7, 1), + datetime(2008, 4, 15): datetime(2008, 10, 1), + datetime(2008, 4, 1): datetime(2008, 10, 1), + }, + ) + ) + + @pytest.mark.parametrize("case", offset_cases) + def test_offset(self, case): + offset, cases = case + for base, expected in cases.items(): + assert_offset_equal(offset, base, expected) + + +class TestQuarterEnd: + def test_repr(self): + expected = "" + assert repr(QuarterEnd()) == expected + expected = "" + assert repr(QuarterEnd(startingMonth=3)) == expected + expected = "" + assert repr(QuarterEnd(startingMonth=1)) == expected + + def test_is_anchored(self): + msg = "QuarterEnd.is_anchored is deprecated " + + with tm.assert_produces_warning(FutureWarning, match=msg): + assert QuarterEnd(startingMonth=1).is_anchored() + assert QuarterEnd().is_anchored() + assert not QuarterEnd(2, startingMonth=1).is_anchored() + + def test_offset_corner_case(self): + # corner + offset = QuarterEnd(n=-1, startingMonth=1) + assert datetime(2010, 2, 1) + offset == datetime(2010, 1, 31) + + offset_cases = [] + offset_cases.append( + ( + QuarterEnd(startingMonth=1), + { + datetime(2008, 1, 1): datetime(2008, 1, 31), + datetime(2008, 1, 31): datetime(2008, 4, 30), + datetime(2008, 2, 15): datetime(2008, 4, 30), + datetime(2008, 2, 29): datetime(2008, 4, 30), + datetime(2008, 3, 15): datetime(2008, 4, 30), + datetime(2008, 3, 31): datetime(2008, 4, 30), + datetime(2008, 4, 15): datetime(2008, 4, 30), + datetime(2008, 4, 30): datetime(2008, 7, 31), + }, + ) + ) + + offset_cases.append( + ( + QuarterEnd(startingMonth=2), + { + datetime(2008, 1, 1): datetime(2008, 2, 29), + datetime(2008, 1, 31): datetime(2008, 2, 29), + datetime(2008, 2, 15): datetime(2008, 2, 29), + datetime(2008, 2, 29): datetime(2008, 5, 31), + datetime(2008, 3, 15): datetime(2008, 5, 31), + datetime(2008, 3, 31): datetime(2008, 5, 31), + datetime(2008, 4, 15): datetime(2008, 5, 31), + datetime(2008, 4, 30): datetime(2008, 5, 31), + }, + ) + ) + + offset_cases.append( + ( + QuarterEnd(startingMonth=1, n=0), + { + datetime(2008, 1, 1): datetime(2008, 1, 31), + datetime(2008, 1, 31): datetime(2008, 1, 31), + datetime(2008, 2, 15): datetime(2008, 4, 30), + datetime(2008, 2, 29): datetime(2008, 4, 30), + datetime(2008, 3, 15): datetime(2008, 4, 30), + datetime(2008, 3, 31): datetime(2008, 4, 30), + datetime(2008, 4, 15): datetime(2008, 4, 30), + datetime(2008, 4, 30): datetime(2008, 4, 30), + }, + ) + ) + + offset_cases.append( + ( + QuarterEnd(startingMonth=1, n=-1), + { + datetime(2008, 1, 1): datetime(2007, 10, 31), + datetime(2008, 1, 31): datetime(2007, 10, 31), + datetime(2008, 2, 15): datetime(2008, 1, 31), + datetime(2008, 2, 29): datetime(2008, 1, 31), + datetime(2008, 3, 15): datetime(2008, 1, 31), + datetime(2008, 3, 31): datetime(2008, 1, 31), + datetime(2008, 4, 15): datetime(2008, 1, 31), + datetime(2008, 4, 30): datetime(2008, 1, 31), + datetime(2008, 7, 1): datetime(2008, 4, 30), + }, + ) + ) + + offset_cases.append( + ( + QuarterEnd(startingMonth=1, n=2), + { + datetime(2008, 1, 31): datetime(2008, 7, 31), + datetime(2008, 2, 15): datetime(2008, 7, 31), + datetime(2008, 2, 29): datetime(2008, 7, 31), + datetime(2008, 3, 15): datetime(2008, 7, 31), + datetime(2008, 3, 31): datetime(2008, 7, 31), + datetime(2008, 4, 15): datetime(2008, 7, 31), + datetime(2008, 4, 30): datetime(2008, 10, 31), + }, + ) + ) + + @pytest.mark.parametrize("case", offset_cases) + def test_offset(self, case): + offset, cases = case + for base, expected in cases.items(): + assert_offset_equal(offset, base, expected) + + on_offset_cases = [ + (QuarterEnd(1, startingMonth=1), datetime(2008, 1, 31), True), + (QuarterEnd(1, startingMonth=1), datetime(2007, 12, 31), False), + (QuarterEnd(1, startingMonth=1), datetime(2008, 2, 29), False), + (QuarterEnd(1, startingMonth=1), datetime(2007, 3, 30), False), + (QuarterEnd(1, startingMonth=1), datetime(2007, 3, 31), False), + (QuarterEnd(1, startingMonth=1), datetime(2008, 4, 30), True), + (QuarterEnd(1, startingMonth=1), datetime(2008, 5, 30), False), + (QuarterEnd(1, startingMonth=1), datetime(2008, 5, 31), False), + (QuarterEnd(1, startingMonth=1), datetime(2007, 6, 29), False), + (QuarterEnd(1, startingMonth=1), datetime(2007, 6, 30), False), + (QuarterEnd(1, startingMonth=2), datetime(2008, 1, 31), False), + (QuarterEnd(1, startingMonth=2), datetime(2007, 12, 31), False), + (QuarterEnd(1, startingMonth=2), datetime(2008, 2, 29), True), + (QuarterEnd(1, startingMonth=2), datetime(2007, 3, 30), False), + (QuarterEnd(1, startingMonth=2), datetime(2007, 3, 31), False), + (QuarterEnd(1, startingMonth=2), datetime(2008, 4, 30), False), + (QuarterEnd(1, startingMonth=2), datetime(2008, 5, 30), False), + (QuarterEnd(1, startingMonth=2), datetime(2008, 5, 31), True), + (QuarterEnd(1, startingMonth=2), datetime(2007, 6, 29), False), + (QuarterEnd(1, startingMonth=2), datetime(2007, 6, 30), False), + (QuarterEnd(1, startingMonth=3), datetime(2008, 1, 31), False), + (QuarterEnd(1, startingMonth=3), datetime(2007, 12, 31), True), + (QuarterEnd(1, startingMonth=3), datetime(2008, 2, 29), False), + (QuarterEnd(1, startingMonth=3), datetime(2007, 3, 30), False), + (QuarterEnd(1, startingMonth=3), datetime(2007, 3, 31), True), + (QuarterEnd(1, startingMonth=3), datetime(2008, 4, 30), False), + (QuarterEnd(1, startingMonth=3), datetime(2008, 5, 30), False), + (QuarterEnd(1, startingMonth=3), datetime(2008, 5, 31), False), + (QuarterEnd(1, startingMonth=3), datetime(2007, 6, 29), False), + (QuarterEnd(1, startingMonth=3), datetime(2007, 6, 30), True), + ] + + @pytest.mark.parametrize("case", on_offset_cases) + def test_is_on_offset(self, case): + offset, dt, expected = case + assert_is_on_offset(offset, dt, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_ticks.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_ticks.py new file mode 100644 index 0000000000000000000000000000000000000000..399b7038d3426a9f3e4916927bd7e38ac3996531 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_ticks.py @@ -0,0 +1,405 @@ +""" +Tests for offsets.Tick and subclasses +""" +from datetime import ( + datetime, + timedelta, +) + +from hypothesis import ( + assume, + example, + given, +) +import numpy as np +import pytest + +from pandas._libs.tslibs.offsets import delta_to_tick +from pandas.errors import OutOfBoundsTimedelta + +from pandas import ( + Timedelta, + Timestamp, +) +import pandas._testing as tm +from pandas._testing._hypothesis import INT_NEG_999_TO_POS_999 +from pandas.tests.tseries.offsets.common import assert_offset_equal + +from pandas.tseries import offsets +from pandas.tseries.offsets import ( + Hour, + Micro, + Milli, + Minute, + Nano, + Second, +) + +# --------------------------------------------------------------------- +# Test Helpers + +tick_classes = [Hour, Minute, Second, Milli, Micro, Nano] + + +# --------------------------------------------------------------------- + + +def test_apply_ticks(): + result = offsets.Hour(3) + offsets.Hour(4) + exp = offsets.Hour(7) + assert result == exp + + +def test_delta_to_tick(): + delta = timedelta(3) + + tick = delta_to_tick(delta) + assert tick == offsets.Day(3) + + td = Timedelta(nanoseconds=5) + tick = delta_to_tick(td) + assert tick == Nano(5) + + +@pytest.mark.parametrize("cls", tick_classes) +@example(n=2, m=3) +@example(n=800, m=300) +@example(n=1000, m=5) +@given(n=INT_NEG_999_TO_POS_999, m=INT_NEG_999_TO_POS_999) +def test_tick_add_sub(cls, n, m): + # For all Tick subclasses and all integers n, m, we should have + # tick(n) + tick(m) == tick(n+m) + # tick(n) - tick(m) == tick(n-m) + left = cls(n) + right = cls(m) + expected = cls(n + m) + + assert left + right == expected + + expected = cls(n - m) + assert left - right == expected + + +@pytest.mark.arm_slow +@pytest.mark.parametrize("cls", tick_classes) +@example(n=2, m=3) +@given(n=INT_NEG_999_TO_POS_999, m=INT_NEG_999_TO_POS_999) +def test_tick_equality(cls, n, m): + assume(m != n) + # tick == tock iff tick.n == tock.n + left = cls(n) + right = cls(m) + assert left != right + + right = cls(n) + assert left == right + assert not left != right + + if n != 0: + assert cls(n) != cls(-n) + + +# --------------------------------------------------------------------- + + +def test_Hour(): + assert_offset_equal(Hour(), datetime(2010, 1, 1), datetime(2010, 1, 1, 1)) + assert_offset_equal(Hour(-1), datetime(2010, 1, 1, 1), datetime(2010, 1, 1)) + assert_offset_equal(2 * Hour(), datetime(2010, 1, 1), datetime(2010, 1, 1, 2)) + assert_offset_equal(-1 * Hour(), datetime(2010, 1, 1, 1), datetime(2010, 1, 1)) + + assert Hour(3) + Hour(2) == Hour(5) + assert Hour(3) - Hour(2) == Hour() + + assert Hour(4) != Hour(1) + + +def test_Minute(): + assert_offset_equal(Minute(), datetime(2010, 1, 1), datetime(2010, 1, 1, 0, 1)) + assert_offset_equal(Minute(-1), datetime(2010, 1, 1, 0, 1), datetime(2010, 1, 1)) + assert_offset_equal(2 * Minute(), datetime(2010, 1, 1), datetime(2010, 1, 1, 0, 2)) + assert_offset_equal(-1 * Minute(), datetime(2010, 1, 1, 0, 1), datetime(2010, 1, 1)) + + assert Minute(3) + Minute(2) == Minute(5) + assert Minute(3) - Minute(2) == Minute() + assert Minute(5) != Minute() + + +def test_Second(): + assert_offset_equal(Second(), datetime(2010, 1, 1), datetime(2010, 1, 1, 0, 0, 1)) + assert_offset_equal(Second(-1), datetime(2010, 1, 1, 0, 0, 1), datetime(2010, 1, 1)) + assert_offset_equal( + 2 * Second(), datetime(2010, 1, 1), datetime(2010, 1, 1, 0, 0, 2) + ) + assert_offset_equal( + -1 * Second(), datetime(2010, 1, 1, 0, 0, 1), datetime(2010, 1, 1) + ) + + assert Second(3) + Second(2) == Second(5) + assert Second(3) - Second(2) == Second() + + +def test_Millisecond(): + assert_offset_equal( + Milli(), datetime(2010, 1, 1), datetime(2010, 1, 1, 0, 0, 0, 1000) + ) + assert_offset_equal( + Milli(-1), datetime(2010, 1, 1, 0, 0, 0, 1000), datetime(2010, 1, 1) + ) + assert_offset_equal( + Milli(2), datetime(2010, 1, 1), datetime(2010, 1, 1, 0, 0, 0, 2000) + ) + assert_offset_equal( + 2 * Milli(), datetime(2010, 1, 1), datetime(2010, 1, 1, 0, 0, 0, 2000) + ) + assert_offset_equal( + -1 * Milli(), datetime(2010, 1, 1, 0, 0, 0, 1000), datetime(2010, 1, 1) + ) + + assert Milli(3) + Milli(2) == Milli(5) + assert Milli(3) - Milli(2) == Milli() + + +def test_MillisecondTimestampArithmetic(): + assert_offset_equal( + Milli(), Timestamp("2010-01-01"), Timestamp("2010-01-01 00:00:00.001") + ) + assert_offset_equal( + Milli(-1), Timestamp("2010-01-01 00:00:00.001"), Timestamp("2010-01-01") + ) + + +def test_Microsecond(): + assert_offset_equal(Micro(), datetime(2010, 1, 1), datetime(2010, 1, 1, 0, 0, 0, 1)) + assert_offset_equal( + Micro(-1), datetime(2010, 1, 1, 0, 0, 0, 1), datetime(2010, 1, 1) + ) + + assert_offset_equal( + 2 * Micro(), datetime(2010, 1, 1), datetime(2010, 1, 1, 0, 0, 0, 2) + ) + assert_offset_equal( + -1 * Micro(), datetime(2010, 1, 1, 0, 0, 0, 1), datetime(2010, 1, 1) + ) + + assert Micro(3) + Micro(2) == Micro(5) + assert Micro(3) - Micro(2) == Micro() + + +def test_NanosecondGeneric(): + timestamp = Timestamp(datetime(2010, 1, 1)) + assert timestamp.nanosecond == 0 + + result = timestamp + Nano(10) + assert result.nanosecond == 10 + + reverse_result = Nano(10) + timestamp + assert reverse_result.nanosecond == 10 + + +def test_Nanosecond(): + timestamp = Timestamp(datetime(2010, 1, 1)) + assert_offset_equal(Nano(), timestamp, timestamp + np.timedelta64(1, "ns")) + assert_offset_equal(Nano(-1), timestamp + np.timedelta64(1, "ns"), timestamp) + assert_offset_equal(2 * Nano(), timestamp, timestamp + np.timedelta64(2, "ns")) + assert_offset_equal(-1 * Nano(), timestamp + np.timedelta64(1, "ns"), timestamp) + + assert Nano(3) + Nano(2) == Nano(5) + assert Nano(3) - Nano(2) == Nano() + + # GH9284 + assert Nano(1) + Nano(10) == Nano(11) + assert Nano(5) + Micro(1) == Nano(1005) + assert Micro(5) + Nano(1) == Nano(5001) + + +@pytest.mark.parametrize( + "kls, expected", + [ + (Hour, Timedelta(hours=5)), + (Minute, Timedelta(hours=2, minutes=3)), + (Second, Timedelta(hours=2, seconds=3)), + (Milli, Timedelta(hours=2, milliseconds=3)), + (Micro, Timedelta(hours=2, microseconds=3)), + (Nano, Timedelta(hours=2, nanoseconds=3)), + ], +) +def test_tick_addition(kls, expected): + offset = kls(3) + td = Timedelta(hours=2) + + for other in [td, td.to_pytimedelta(), td.to_timedelta64()]: + result = offset + other + assert isinstance(result, Timedelta) + assert result == expected + + result = other + offset + assert isinstance(result, Timedelta) + assert result == expected + + +def test_tick_delta_overflow(): + # GH#55503 raise OutOfBoundsTimedelta, not OverflowError + tick = offsets.Day(10**9) + msg = "Cannot cast 1000000000 days 00:00:00 to unit='ns' without overflow" + depr_msg = "Day.delta is deprecated" + with pytest.raises(OutOfBoundsTimedelta, match=msg): + with tm.assert_produces_warning(FutureWarning, match=depr_msg): + tick.delta + + +@pytest.mark.parametrize("cls", tick_classes) +def test_tick_division(cls): + off = cls(10) + + assert off / cls(5) == 2 + assert off / 2 == cls(5) + assert off / 2.0 == cls(5) + + assert off / off._as_pd_timedelta == 1 + assert off / off._as_pd_timedelta.to_timedelta64() == 1 + + assert off / Nano(1) == off._as_pd_timedelta / Nano(1)._as_pd_timedelta + + if cls is not Nano: + # A case where we end up with a smaller class + result = off / 1000 + assert isinstance(result, offsets.Tick) + assert not isinstance(result, cls) + assert result._as_pd_timedelta == off._as_pd_timedelta / 1000 + + if cls._nanos_inc < Timedelta(seconds=1)._value: + # Case where we end up with a bigger class + result = off / 0.001 + assert isinstance(result, offsets.Tick) + assert not isinstance(result, cls) + assert result._as_pd_timedelta == off._as_pd_timedelta / 0.001 + + +def test_tick_mul_float(): + off = Micro(2) + + # Case where we retain type + result = off * 1.5 + expected = Micro(3) + assert result == expected + assert isinstance(result, Micro) + + # Case where we bump up to the next type + result = off * 1.25 + expected = Nano(2500) + assert result == expected + assert isinstance(result, Nano) + + +@pytest.mark.parametrize("cls", tick_classes) +def test_tick_rdiv(cls): + off = cls(10) + delta = off._as_pd_timedelta + td64 = delta.to_timedelta64() + instance__type = ".".join([cls.__module__, cls.__name__]) + msg = ( + "unsupported operand type\\(s\\) for \\/: 'int'|'float' and " + f"'{instance__type}'" + ) + + with pytest.raises(TypeError, match=msg): + 2 / off + with pytest.raises(TypeError, match=msg): + 2.0 / off + + assert (td64 * 2.5) / off == 2.5 + + if cls is not Nano: + # skip pytimedelta for Nano since it gets dropped + assert (delta.to_pytimedelta() * 2) / off == 2 + + result = np.array([2 * td64, td64]) / off + expected = np.array([2.0, 1.0]) + tm.assert_numpy_array_equal(result, expected) + + +@pytest.mark.parametrize("cls1", tick_classes) +@pytest.mark.parametrize("cls2", tick_classes) +def test_tick_zero(cls1, cls2): + assert cls1(0) == cls2(0) + assert cls1(0) + cls2(0) == cls1(0) + + if cls1 is not Nano: + assert cls1(2) + cls2(0) == cls1(2) + + if cls1 is Nano: + assert cls1(2) + Nano(0) == cls1(2) + + +@pytest.mark.parametrize("cls", tick_classes) +def test_tick_equalities(cls): + assert cls() == cls(1) + + +@pytest.mark.parametrize("cls", tick_classes) +def test_tick_offset(cls): + msg = f"{cls.__name__}.is_anchored is deprecated " + + with tm.assert_produces_warning(FutureWarning, match=msg): + assert not cls().is_anchored() + + +@pytest.mark.parametrize("cls", tick_classes) +def test_compare_ticks(cls): + three = cls(3) + four = cls(4) + + assert three < cls(4) + assert cls(3) < four + assert four > cls(3) + assert cls(4) > three + assert cls(3) == cls(3) + assert cls(3) != cls(4) + + +@pytest.mark.parametrize("cls", tick_classes) +def test_compare_ticks_to_strs(cls): + # GH#23524 + off = cls(19) + + # These tests should work with any strings, but we particularly are + # interested in "infer" as that comparison is convenient to make in + # Datetime/Timedelta Array/Index constructors + assert not off == "infer" + assert not "foo" == off + + instance_type = ".".join([cls.__module__, cls.__name__]) + msg = ( + "'<'|'<='|'>'|'>=' not supported between instances of " + f"'str' and '{instance_type}'|'{instance_type}' and 'str'" + ) + + for left, right in [("infer", off), (off, "infer")]: + with pytest.raises(TypeError, match=msg): + left < right + with pytest.raises(TypeError, match=msg): + left <= right + with pytest.raises(TypeError, match=msg): + left > right + with pytest.raises(TypeError, match=msg): + left >= right + + +@pytest.mark.parametrize("cls", tick_classes) +def test_compare_ticks_to_timedeltalike(cls): + off = cls(19) + + td = off._as_pd_timedelta + + others = [td, td.to_timedelta64()] + if cls is not Nano: + others.append(td.to_pytimedelta()) + + for other in others: + assert off == other + assert not off != other + assert not off < other + assert not off > other + assert off <= other + assert off >= other diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_week.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_week.py new file mode 100644 index 0000000000000000000000000000000000000000..0cd6f769769ae3c3ae39f6b4a8f10641cd297715 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_week.py @@ -0,0 +1,351 @@ +""" +Tests for the following offsets: +- Week +- WeekOfMonth +- LastWeekOfMonth +""" +from __future__ import annotations + +from datetime import ( + datetime, + timedelta, +) + +import pytest + +from pandas._libs.tslibs import Timestamp +from pandas._libs.tslibs.offsets import ( + Day, + LastWeekOfMonth, + Week, + WeekOfMonth, +) + +import pandas._testing as tm +from pandas.tests.tseries.offsets.common import ( + WeekDay, + assert_is_on_offset, + assert_offset_equal, +) + + +class TestWeek: + def test_repr(self): + assert repr(Week(weekday=0)) == "" + assert repr(Week(n=-1, weekday=0)) == "<-1 * Week: weekday=0>" + assert repr(Week(n=-2, weekday=0)) == "<-2 * Weeks: weekday=0>" + + def test_corner(self): + with pytest.raises(ValueError, match="Day must be"): + Week(weekday=7) + + with pytest.raises(ValueError, match="Day must be"): + Week(weekday=-1) + + def test_is_anchored(self): + msg = "Week.is_anchored is deprecated " + + with tm.assert_produces_warning(FutureWarning, match=msg): + assert Week(weekday=0).is_anchored() + assert not Week().is_anchored() + assert not Week(2, weekday=2).is_anchored() + assert not Week(2).is_anchored() + + offset_cases = [] + # not business week + offset_cases.append( + ( + Week(), + { + datetime(2008, 1, 1): datetime(2008, 1, 8), + datetime(2008, 1, 4): datetime(2008, 1, 11), + datetime(2008, 1, 5): datetime(2008, 1, 12), + datetime(2008, 1, 6): datetime(2008, 1, 13), + datetime(2008, 1, 7): datetime(2008, 1, 14), + }, + ) + ) + + # Mon + offset_cases.append( + ( + Week(weekday=0), + { + datetime(2007, 12, 31): datetime(2008, 1, 7), + datetime(2008, 1, 4): datetime(2008, 1, 7), + datetime(2008, 1, 5): datetime(2008, 1, 7), + datetime(2008, 1, 6): datetime(2008, 1, 7), + datetime(2008, 1, 7): datetime(2008, 1, 14), + }, + ) + ) + + # n=0 -> roll forward. Mon + offset_cases.append( + ( + Week(0, weekday=0), + { + datetime(2007, 12, 31): datetime(2007, 12, 31), + datetime(2008, 1, 4): datetime(2008, 1, 7), + datetime(2008, 1, 5): datetime(2008, 1, 7), + datetime(2008, 1, 6): datetime(2008, 1, 7), + datetime(2008, 1, 7): datetime(2008, 1, 7), + }, + ) + ) + + # n=0 -> roll forward. Mon + offset_cases.append( + ( + Week(-2, weekday=1), + { + datetime(2010, 4, 6): datetime(2010, 3, 23), + datetime(2010, 4, 8): datetime(2010, 3, 30), + datetime(2010, 4, 5): datetime(2010, 3, 23), + }, + ) + ) + + @pytest.mark.parametrize("case", offset_cases) + def test_offset(self, case): + offset, cases = case + for base, expected in cases.items(): + assert_offset_equal(offset, base, expected) + + @pytest.mark.parametrize("weekday", range(7)) + def test_is_on_offset(self, weekday): + offset = Week(weekday=weekday) + + for day in range(1, 8): + date = datetime(2008, 1, day) + expected = day % 7 == weekday + assert_is_on_offset(offset, date, expected) + + @pytest.mark.parametrize( + "n,date", + [ + (2, "1862-01-13 09:03:34.873477378+0210"), + (-2, "1856-10-24 16:18:36.556360110-0717"), + ], + ) + def test_is_on_offset_weekday_none(self, n, date): + # GH 18510 Week with weekday = None, normalize = False + # should always be is_on_offset + offset = Week(n=n, weekday=None) + ts = Timestamp(date, tz="Africa/Lusaka") + fast = offset.is_on_offset(ts) + slow = (ts + offset) - offset == ts + assert fast == slow + + def test_week_add_invalid(self): + # Week with weekday should raise TypeError and _not_ AttributeError + # when adding invalid offset + offset = Week(weekday=1) + other = Day() + with pytest.raises(TypeError, match="Cannot add"): + offset + other + + +class TestWeekOfMonth: + def test_constructor(self): + with pytest.raises(ValueError, match="^Week"): + WeekOfMonth(n=1, week=4, weekday=0) + + with pytest.raises(ValueError, match="^Week"): + WeekOfMonth(n=1, week=-1, weekday=0) + + with pytest.raises(ValueError, match="^Day"): + WeekOfMonth(n=1, week=0, weekday=-1) + + with pytest.raises(ValueError, match="^Day"): + WeekOfMonth(n=1, week=0, weekday=-7) + + def test_repr(self): + assert ( + repr(WeekOfMonth(weekday=1, week=2)) == "" + ) + + def test_offset(self): + date1 = datetime(2011, 1, 4) # 1st Tuesday of Month + date2 = datetime(2011, 1, 11) # 2nd Tuesday of Month + date3 = datetime(2011, 1, 18) # 3rd Tuesday of Month + date4 = datetime(2011, 1, 25) # 4th Tuesday of Month + + # see for loop for structure + test_cases = [ + (-2, 2, 1, date1, datetime(2010, 11, 16)), + (-2, 2, 1, date2, datetime(2010, 11, 16)), + (-2, 2, 1, date3, datetime(2010, 11, 16)), + (-2, 2, 1, date4, datetime(2010, 12, 21)), + (-1, 2, 1, date1, datetime(2010, 12, 21)), + (-1, 2, 1, date2, datetime(2010, 12, 21)), + (-1, 2, 1, date3, datetime(2010, 12, 21)), + (-1, 2, 1, date4, datetime(2011, 1, 18)), + (0, 0, 1, date1, datetime(2011, 1, 4)), + (0, 0, 1, date2, datetime(2011, 2, 1)), + (0, 0, 1, date3, datetime(2011, 2, 1)), + (0, 0, 1, date4, datetime(2011, 2, 1)), + (0, 1, 1, date1, datetime(2011, 1, 11)), + (0, 1, 1, date2, datetime(2011, 1, 11)), + (0, 1, 1, date3, datetime(2011, 2, 8)), + (0, 1, 1, date4, datetime(2011, 2, 8)), + (0, 0, 1, date1, datetime(2011, 1, 4)), + (0, 1, 1, date2, datetime(2011, 1, 11)), + (0, 2, 1, date3, datetime(2011, 1, 18)), + (0, 3, 1, date4, datetime(2011, 1, 25)), + (1, 0, 0, date1, datetime(2011, 2, 7)), + (1, 0, 0, date2, datetime(2011, 2, 7)), + (1, 0, 0, date3, datetime(2011, 2, 7)), + (1, 0, 0, date4, datetime(2011, 2, 7)), + (1, 0, 1, date1, datetime(2011, 2, 1)), + (1, 0, 1, date2, datetime(2011, 2, 1)), + (1, 0, 1, date3, datetime(2011, 2, 1)), + (1, 0, 1, date4, datetime(2011, 2, 1)), + (1, 0, 2, date1, datetime(2011, 1, 5)), + (1, 0, 2, date2, datetime(2011, 2, 2)), + (1, 0, 2, date3, datetime(2011, 2, 2)), + (1, 0, 2, date4, datetime(2011, 2, 2)), + (1, 2, 1, date1, datetime(2011, 1, 18)), + (1, 2, 1, date2, datetime(2011, 1, 18)), + (1, 2, 1, date3, datetime(2011, 2, 15)), + (1, 2, 1, date4, datetime(2011, 2, 15)), + (2, 2, 1, date1, datetime(2011, 2, 15)), + (2, 2, 1, date2, datetime(2011, 2, 15)), + (2, 2, 1, date3, datetime(2011, 3, 15)), + (2, 2, 1, date4, datetime(2011, 3, 15)), + ] + + for n, week, weekday, dt, expected in test_cases: + offset = WeekOfMonth(n, week=week, weekday=weekday) + assert_offset_equal(offset, dt, expected) + + # try subtracting + result = datetime(2011, 2, 1) - WeekOfMonth(week=1, weekday=2) + assert result == datetime(2011, 1, 12) + + result = datetime(2011, 2, 3) - WeekOfMonth(week=0, weekday=2) + assert result == datetime(2011, 2, 2) + + on_offset_cases = [ + (0, 0, datetime(2011, 2, 7), True), + (0, 0, datetime(2011, 2, 6), False), + (0, 0, datetime(2011, 2, 14), False), + (1, 0, datetime(2011, 2, 14), True), + (0, 1, datetime(2011, 2, 1), True), + (0, 1, datetime(2011, 2, 8), False), + ] + + @pytest.mark.parametrize("case", on_offset_cases) + def test_is_on_offset(self, case): + week, weekday, dt, expected = case + offset = WeekOfMonth(week=week, weekday=weekday) + assert offset.is_on_offset(dt) == expected + + @pytest.mark.parametrize( + "n,week,date,tz", + [ + (2, 2, "1916-05-15 01:14:49.583410462+0422", "Asia/Qyzylorda"), + (-3, 1, "1980-12-08 03:38:52.878321185+0500", "Asia/Oral"), + ], + ) + def test_is_on_offset_nanoseconds(self, n, week, date, tz): + # GH 18864 + # Make sure that nanoseconds don't trip up is_on_offset (and with it apply) + offset = WeekOfMonth(n=n, week=week, weekday=0) + ts = Timestamp(date, tz=tz) + fast = offset.is_on_offset(ts) + slow = (ts + offset) - offset == ts + assert fast == slow + + +class TestLastWeekOfMonth: + def test_constructor(self): + with pytest.raises(ValueError, match="^N cannot be 0"): + LastWeekOfMonth(n=0, weekday=1) + + with pytest.raises(ValueError, match="^Day"): + LastWeekOfMonth(n=1, weekday=-1) + + with pytest.raises(ValueError, match="^Day"): + LastWeekOfMonth(n=1, weekday=7) + + def test_offset(self): + # Saturday + last_sat = datetime(2013, 8, 31) + next_sat = datetime(2013, 9, 28) + offset_sat = LastWeekOfMonth(n=1, weekday=5) + + one_day_before = last_sat + timedelta(days=-1) + assert one_day_before + offset_sat == last_sat + + one_day_after = last_sat + timedelta(days=+1) + assert one_day_after + offset_sat == next_sat + + # Test On that day + assert last_sat + offset_sat == next_sat + + # Thursday + + offset_thur = LastWeekOfMonth(n=1, weekday=3) + last_thurs = datetime(2013, 1, 31) + next_thurs = datetime(2013, 2, 28) + + one_day_before = last_thurs + timedelta(days=-1) + assert one_day_before + offset_thur == last_thurs + + one_day_after = last_thurs + timedelta(days=+1) + assert one_day_after + offset_thur == next_thurs + + # Test on that day + assert last_thurs + offset_thur == next_thurs + + three_before = last_thurs + timedelta(days=-3) + assert three_before + offset_thur == last_thurs + + two_after = last_thurs + timedelta(days=+2) + assert two_after + offset_thur == next_thurs + + offset_sunday = LastWeekOfMonth(n=1, weekday=WeekDay.SUN) + assert datetime(2013, 7, 31) + offset_sunday == datetime(2013, 8, 25) + + on_offset_cases = [ + (WeekDay.SUN, datetime(2013, 1, 27), True), + (WeekDay.SAT, datetime(2013, 3, 30), True), + (WeekDay.MON, datetime(2013, 2, 18), False), # Not the last Mon + (WeekDay.SUN, datetime(2013, 2, 25), False), # Not a SUN + (WeekDay.MON, datetime(2013, 2, 25), True), + (WeekDay.SAT, datetime(2013, 11, 30), True), + (WeekDay.SAT, datetime(2006, 8, 26), True), + (WeekDay.SAT, datetime(2007, 8, 25), True), + (WeekDay.SAT, datetime(2008, 8, 30), True), + (WeekDay.SAT, datetime(2009, 8, 29), True), + (WeekDay.SAT, datetime(2010, 8, 28), True), + (WeekDay.SAT, datetime(2011, 8, 27), True), + (WeekDay.SAT, datetime(2019, 8, 31), True), + ] + + @pytest.mark.parametrize("case", on_offset_cases) + def test_is_on_offset(self, case): + weekday, dt, expected = case + offset = LastWeekOfMonth(weekday=weekday) + assert offset.is_on_offset(dt) == expected + + @pytest.mark.parametrize( + "n,weekday,date,tz", + [ + (4, 6, "1917-05-27 20:55:27.084284178+0200", "Europe/Warsaw"), + (-4, 5, "2005-08-27 05:01:42.799392561-0500", "America/Rainy_River"), + ], + ) + def test_last_week_of_month_on_offset(self, n, weekday, date, tz): + # GH 19036, GH 18977 _adjust_dst was incorrect for LastWeekOfMonth + offset = LastWeekOfMonth(n=n, weekday=weekday) + ts = Timestamp(date, tz=tz) + slow = (ts + offset) - offset == ts + fast = offset.is_on_offset(ts) + assert fast == slow + + def test_repr(self): + assert ( + repr(LastWeekOfMonth(n=2, weekday=1)) == "<2 * LastWeekOfMonths: weekday=1>" + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_year.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_year.py new file mode 100644 index 0000000000000000000000000000000000000000..28cbdcf6abeccbbc02827d63c76aaa2f22b3c945 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tseries/offsets/test_year.py @@ -0,0 +1,339 @@ +""" +Tests for the following offsets: +- YearBegin +- YearEnd +""" +from __future__ import annotations + +from datetime import datetime + +import numpy as np +import pytest + +from pandas import Timestamp +from pandas.tests.tseries.offsets.common import ( + assert_is_on_offset, + assert_offset_equal, +) + +from pandas.tseries.offsets import ( + YearBegin, + YearEnd, +) + + +class TestYearBegin: + def test_misspecified(self): + with pytest.raises(ValueError, match="Month must go from 1 to 12"): + YearBegin(month=13) + + offset_cases = [] + offset_cases.append( + ( + YearBegin(), + { + datetime(2008, 1, 1): datetime(2009, 1, 1), + datetime(2008, 6, 30): datetime(2009, 1, 1), + datetime(2008, 12, 31): datetime(2009, 1, 1), + datetime(2005, 12, 30): datetime(2006, 1, 1), + datetime(2005, 12, 31): datetime(2006, 1, 1), + }, + ) + ) + + offset_cases.append( + ( + YearBegin(0), + { + datetime(2008, 1, 1): datetime(2008, 1, 1), + datetime(2008, 6, 30): datetime(2009, 1, 1), + datetime(2008, 12, 31): datetime(2009, 1, 1), + datetime(2005, 12, 30): datetime(2006, 1, 1), + datetime(2005, 12, 31): datetime(2006, 1, 1), + }, + ) + ) + + offset_cases.append( + ( + YearBegin(3), + { + datetime(2008, 1, 1): datetime(2011, 1, 1), + datetime(2008, 6, 30): datetime(2011, 1, 1), + datetime(2008, 12, 31): datetime(2011, 1, 1), + datetime(2005, 12, 30): datetime(2008, 1, 1), + datetime(2005, 12, 31): datetime(2008, 1, 1), + }, + ) + ) + + offset_cases.append( + ( + YearBegin(-1), + { + datetime(2007, 1, 1): datetime(2006, 1, 1), + datetime(2007, 1, 15): datetime(2007, 1, 1), + datetime(2008, 6, 30): datetime(2008, 1, 1), + datetime(2008, 12, 31): datetime(2008, 1, 1), + datetime(2006, 12, 29): datetime(2006, 1, 1), + datetime(2006, 12, 30): datetime(2006, 1, 1), + datetime(2007, 1, 1): datetime(2006, 1, 1), + }, + ) + ) + + offset_cases.append( + ( + YearBegin(-2), + { + datetime(2007, 1, 1): datetime(2005, 1, 1), + datetime(2008, 6, 30): datetime(2007, 1, 1), + datetime(2008, 12, 31): datetime(2007, 1, 1), + }, + ) + ) + + offset_cases.append( + ( + YearBegin(month=4), + { + datetime(2007, 4, 1): datetime(2008, 4, 1), + datetime(2007, 4, 15): datetime(2008, 4, 1), + datetime(2007, 3, 1): datetime(2007, 4, 1), + datetime(2007, 12, 15): datetime(2008, 4, 1), + datetime(2012, 1, 31): datetime(2012, 4, 1), + }, + ) + ) + + offset_cases.append( + ( + YearBegin(0, month=4), + { + datetime(2007, 4, 1): datetime(2007, 4, 1), + datetime(2007, 3, 1): datetime(2007, 4, 1), + datetime(2007, 12, 15): datetime(2008, 4, 1), + datetime(2012, 1, 31): datetime(2012, 4, 1), + }, + ) + ) + + offset_cases.append( + ( + YearBegin(4, month=4), + { + datetime(2007, 4, 1): datetime(2011, 4, 1), + datetime(2007, 4, 15): datetime(2011, 4, 1), + datetime(2007, 3, 1): datetime(2010, 4, 1), + datetime(2007, 12, 15): datetime(2011, 4, 1), + datetime(2012, 1, 31): datetime(2015, 4, 1), + }, + ) + ) + + offset_cases.append( + ( + YearBegin(-1, month=4), + { + datetime(2007, 4, 1): datetime(2006, 4, 1), + datetime(2007, 3, 1): datetime(2006, 4, 1), + datetime(2007, 12, 15): datetime(2007, 4, 1), + datetime(2012, 1, 31): datetime(2011, 4, 1), + }, + ) + ) + + offset_cases.append( + ( + YearBegin(-3, month=4), + { + datetime(2007, 4, 1): datetime(2004, 4, 1), + datetime(2007, 3, 1): datetime(2004, 4, 1), + datetime(2007, 12, 15): datetime(2005, 4, 1), + datetime(2012, 1, 31): datetime(2009, 4, 1), + }, + ) + ) + + @pytest.mark.parametrize("case", offset_cases) + def test_offset(self, case): + offset, cases = case + for base, expected in cases.items(): + assert_offset_equal(offset, base, expected) + + on_offset_cases = [ + (YearBegin(), datetime(2007, 1, 3), False), + (YearBegin(), datetime(2008, 1, 1), True), + (YearBegin(), datetime(2006, 12, 31), False), + (YearBegin(), datetime(2006, 1, 2), False), + ] + + @pytest.mark.parametrize("case", on_offset_cases) + def test_is_on_offset(self, case): + offset, dt, expected = case + assert_is_on_offset(offset, dt, expected) + + +class TestYearEnd: + def test_misspecified(self): + with pytest.raises(ValueError, match="Month must go from 1 to 12"): + YearEnd(month=13) + + offset_cases = [] + offset_cases.append( + ( + YearEnd(), + { + datetime(2008, 1, 1): datetime(2008, 12, 31), + datetime(2008, 6, 30): datetime(2008, 12, 31), + datetime(2008, 12, 31): datetime(2009, 12, 31), + datetime(2005, 12, 30): datetime(2005, 12, 31), + datetime(2005, 12, 31): datetime(2006, 12, 31), + }, + ) + ) + + offset_cases.append( + ( + YearEnd(0), + { + datetime(2008, 1, 1): datetime(2008, 12, 31), + datetime(2008, 6, 30): datetime(2008, 12, 31), + datetime(2008, 12, 31): datetime(2008, 12, 31), + datetime(2005, 12, 30): datetime(2005, 12, 31), + }, + ) + ) + + offset_cases.append( + ( + YearEnd(-1), + { + datetime(2007, 1, 1): datetime(2006, 12, 31), + datetime(2008, 6, 30): datetime(2007, 12, 31), + datetime(2008, 12, 31): datetime(2007, 12, 31), + datetime(2006, 12, 29): datetime(2005, 12, 31), + datetime(2006, 12, 30): datetime(2005, 12, 31), + datetime(2007, 1, 1): datetime(2006, 12, 31), + }, + ) + ) + + offset_cases.append( + ( + YearEnd(-2), + { + datetime(2007, 1, 1): datetime(2005, 12, 31), + datetime(2008, 6, 30): datetime(2006, 12, 31), + datetime(2008, 12, 31): datetime(2006, 12, 31), + }, + ) + ) + + @pytest.mark.parametrize("case", offset_cases) + def test_offset(self, case): + offset, cases = case + for base, expected in cases.items(): + assert_offset_equal(offset, base, expected) + + on_offset_cases = [ + (YearEnd(), datetime(2007, 12, 31), True), + (YearEnd(), datetime(2008, 1, 1), False), + (YearEnd(), datetime(2006, 12, 31), True), + (YearEnd(), datetime(2006, 12, 29), False), + ] + + @pytest.mark.parametrize("case", on_offset_cases) + def test_is_on_offset(self, case): + offset, dt, expected = case + assert_is_on_offset(offset, dt, expected) + + +class TestYearEndDiffMonth: + offset_cases = [] + offset_cases.append( + ( + YearEnd(month=3), + { + datetime(2008, 1, 1): datetime(2008, 3, 31), + datetime(2008, 2, 15): datetime(2008, 3, 31), + datetime(2008, 3, 31): datetime(2009, 3, 31), + datetime(2008, 3, 30): datetime(2008, 3, 31), + datetime(2005, 3, 31): datetime(2006, 3, 31), + datetime(2006, 7, 30): datetime(2007, 3, 31), + }, + ) + ) + + offset_cases.append( + ( + YearEnd(0, month=3), + { + datetime(2008, 1, 1): datetime(2008, 3, 31), + datetime(2008, 2, 28): datetime(2008, 3, 31), + datetime(2008, 3, 31): datetime(2008, 3, 31), + datetime(2005, 3, 30): datetime(2005, 3, 31), + }, + ) + ) + + offset_cases.append( + ( + YearEnd(-1, month=3), + { + datetime(2007, 1, 1): datetime(2006, 3, 31), + datetime(2008, 2, 28): datetime(2007, 3, 31), + datetime(2008, 3, 31): datetime(2007, 3, 31), + datetime(2006, 3, 29): datetime(2005, 3, 31), + datetime(2006, 3, 30): datetime(2005, 3, 31), + datetime(2007, 3, 1): datetime(2006, 3, 31), + }, + ) + ) + + offset_cases.append( + ( + YearEnd(-2, month=3), + { + datetime(2007, 1, 1): datetime(2005, 3, 31), + datetime(2008, 6, 30): datetime(2007, 3, 31), + datetime(2008, 3, 31): datetime(2006, 3, 31), + }, + ) + ) + + @pytest.mark.parametrize("case", offset_cases) + def test_offset(self, case): + offset, cases = case + for base, expected in cases.items(): + assert_offset_equal(offset, base, expected) + + on_offset_cases = [ + (YearEnd(month=3), datetime(2007, 3, 31), True), + (YearEnd(month=3), datetime(2008, 1, 1), False), + (YearEnd(month=3), datetime(2006, 3, 31), True), + (YearEnd(month=3), datetime(2006, 3, 29), False), + ] + + @pytest.mark.parametrize("case", on_offset_cases) + def test_is_on_offset(self, case): + offset, dt, expected = case + assert_is_on_offset(offset, dt, expected) + + +def test_add_out_of_pydatetime_range(): + # GH#50348 don't raise in Timestamp.replace + ts = Timestamp(np.datetime64("-20000-12-31")) + off = YearEnd() + + result = ts + off + # TODO(cython3): "arg: datetime" annotation will impose + # datetime limitations on Timestamp. The fused type below works in cy3 + # ctypedef fused datetimelike: + # _Timestamp + # datetime + # expected = Timestamp(np.datetime64("-19999-12-31")) + # assert result == expected + assert result.year in (-19999, 1973) + assert result.month == 12 + assert result.day == 31 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_api.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_api.py new file mode 100644 index 0000000000000000000000000000000000000000..42d055326c2a54f5fcd6ba2257201eaec2a3122b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_api.py @@ -0,0 +1,65 @@ +"""Tests that the tslibs API is locked down""" + +from pandas._libs import tslibs + + +def test_namespace(): + submodules = [ + "base", + "ccalendar", + "conversion", + "dtypes", + "fields", + "nattype", + "np_datetime", + "offsets", + "parsing", + "period", + "strptime", + "vectorized", + "timedeltas", + "timestamps", + "timezones", + "tzconversion", + ] + + api = [ + "BaseOffset", + "NaT", + "NaTType", + "iNaT", + "nat_strings", + "OutOfBoundsDatetime", + "OutOfBoundsTimedelta", + "Period", + "IncompatibleFrequency", + "Resolution", + "Tick", + "Timedelta", + "dt64arr_to_periodarr", + "Timestamp", + "is_date_array_normalized", + "ints_to_pydatetime", + "normalize_i8_timestamps", + "get_resolution", + "delta_to_nanoseconds", + "ints_to_pytimedelta", + "localize_pydatetime", + "tz_convert_from_utc", + "tz_convert_from_utc_single", + "to_offset", + "tz_compare", + "is_unitless", + "astype_overflowsafe", + "get_unit_from_dtype", + "periods_per_day", + "periods_per_second", + "guess_datetime_format", + "add_overflowsafe", + "get_supported_dtype", + "is_supported_dtype", + ] + + expected = set(submodules + api) + names = [x for x in dir(tslibs) if not x.startswith("__")] + assert set(names) == expected diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_array_to_datetime.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_array_to_datetime.py new file mode 100644 index 0000000000000000000000000000000000000000..82175c67764f84355dfb6d10a0477e9e0344e296 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_array_to_datetime.py @@ -0,0 +1,337 @@ +from datetime import ( + date, + datetime, + timedelta, + timezone, +) + +from dateutil.tz.tz import tzoffset +import numpy as np +import pytest + +from pandas._libs import ( + NaT, + iNaT, + tslib, +) +from pandas._libs.tslibs.dtypes import NpyDatetimeUnit +from pandas._libs.tslibs.np_datetime import OutOfBoundsDatetime + +from pandas import Timestamp +import pandas._testing as tm + +creso_infer = NpyDatetimeUnit.NPY_FR_GENERIC.value + + +class TestArrayToDatetimeResolutionInference: + # TODO: tests that include tzs, ints + + def test_infer_all_nat(self): + arr = np.array([NaT, np.nan], dtype=object) + result, tz = tslib.array_to_datetime(arr, creso=creso_infer) + assert tz is None + assert result.dtype == "M8[s]" + + def test_infer_homogeoneous_datetimes(self): + dt = datetime(2023, 10, 27, 18, 3, 5, 678000) + arr = np.array([dt, dt, dt], dtype=object) + result, tz = tslib.array_to_datetime(arr, creso=creso_infer) + assert tz is None + expected = np.array([dt, dt, dt], dtype="M8[us]") + tm.assert_numpy_array_equal(result, expected) + + def test_infer_homogeoneous_date_objects(self): + dt = datetime(2023, 10, 27, 18, 3, 5, 678000) + dt2 = dt.date() + arr = np.array([None, dt2, dt2, dt2], dtype=object) + result, tz = tslib.array_to_datetime(arr, creso=creso_infer) + assert tz is None + expected = np.array([np.datetime64("NaT"), dt2, dt2, dt2], dtype="M8[s]") + tm.assert_numpy_array_equal(result, expected) + + def test_infer_homogeoneous_dt64(self): + dt = datetime(2023, 10, 27, 18, 3, 5, 678000) + dt64 = np.datetime64(dt, "ms") + arr = np.array([None, dt64, dt64, dt64], dtype=object) + result, tz = tslib.array_to_datetime(arr, creso=creso_infer) + assert tz is None + expected = np.array([np.datetime64("NaT"), dt64, dt64, dt64], dtype="M8[ms]") + tm.assert_numpy_array_equal(result, expected) + + def test_infer_homogeoneous_timestamps(self): + dt = datetime(2023, 10, 27, 18, 3, 5, 678000) + ts = Timestamp(dt).as_unit("ns") + arr = np.array([None, ts, ts, ts], dtype=object) + result, tz = tslib.array_to_datetime(arr, creso=creso_infer) + assert tz is None + expected = np.array([np.datetime64("NaT")] + [ts.asm8] * 3, dtype="M8[ns]") + tm.assert_numpy_array_equal(result, expected) + + def test_infer_homogeoneous_datetimes_strings(self): + item = "2023-10-27 18:03:05.678000" + arr = np.array([None, item, item, item], dtype=object) + result, tz = tslib.array_to_datetime(arr, creso=creso_infer) + assert tz is None + expected = np.array([np.datetime64("NaT"), item, item, item], dtype="M8[us]") + tm.assert_numpy_array_equal(result, expected) + + def test_infer_heterogeneous(self): + dtstr = "2023-10-27 18:03:05.678000" + + arr = np.array([dtstr, dtstr[:-3], dtstr[:-7], None], dtype=object) + result, tz = tslib.array_to_datetime(arr, creso=creso_infer) + assert tz is None + expected = np.array(arr, dtype="M8[us]") + tm.assert_numpy_array_equal(result, expected) + + result, tz = tslib.array_to_datetime(arr[::-1], creso=creso_infer) + assert tz is None + tm.assert_numpy_array_equal(result, expected[::-1]) + + @pytest.mark.parametrize( + "item", [float("nan"), NaT.value, float(NaT.value), "NaT", ""] + ) + def test_infer_with_nat_int_float_str(self, item): + # floats/ints get inferred to nanos *unless* they are NaN/iNaT, + # similar NaT string gets treated like NaT scalar (ignored for resolution) + dt = datetime(2023, 11, 15, 15, 5, 6) + + arr = np.array([dt, item], dtype=object) + result, tz = tslib.array_to_datetime(arr, creso=creso_infer) + assert tz is None + expected = np.array([dt, np.datetime64("NaT")], dtype="M8[us]") + tm.assert_numpy_array_equal(result, expected) + + result2, tz2 = tslib.array_to_datetime(arr[::-1], creso=creso_infer) + assert tz2 is None + tm.assert_numpy_array_equal(result2, expected[::-1]) + + +class TestArrayToDatetimeWithTZResolutionInference: + def test_array_to_datetime_with_tz_resolution(self): + tz = tzoffset("custom", 3600) + vals = np.array(["2016-01-01 02:03:04.567", NaT], dtype=object) + res = tslib.array_to_datetime_with_tz(vals, tz, False, False, creso_infer) + assert res.dtype == "M8[ms]" + + vals2 = np.array([datetime(2016, 1, 1, 2, 3, 4), NaT], dtype=object) + res2 = tslib.array_to_datetime_with_tz(vals2, tz, False, False, creso_infer) + assert res2.dtype == "M8[us]" + + vals3 = np.array([NaT, np.datetime64(12345, "s")], dtype=object) + res3 = tslib.array_to_datetime_with_tz(vals3, tz, False, False, creso_infer) + assert res3.dtype == "M8[s]" + + def test_array_to_datetime_with_tz_resolution_all_nat(self): + tz = tzoffset("custom", 3600) + vals = np.array(["NaT"], dtype=object) + res = tslib.array_to_datetime_with_tz(vals, tz, False, False, creso_infer) + assert res.dtype == "M8[s]" + + vals2 = np.array([NaT, NaT], dtype=object) + res2 = tslib.array_to_datetime_with_tz(vals2, tz, False, False, creso_infer) + assert res2.dtype == "M8[s]" + + +@pytest.mark.parametrize( + "data,expected", + [ + ( + ["01-01-2013", "01-02-2013"], + [ + "2013-01-01T00:00:00.000000000", + "2013-01-02T00:00:00.000000000", + ], + ), + ( + ["Mon Sep 16 2013", "Tue Sep 17 2013"], + [ + "2013-09-16T00:00:00.000000000", + "2013-09-17T00:00:00.000000000", + ], + ), + ], +) +def test_parsing_valid_dates(data, expected): + arr = np.array(data, dtype=object) + result, _ = tslib.array_to_datetime(arr) + + expected = np.array(expected, dtype="M8[ns]") + tm.assert_numpy_array_equal(result, expected) + + +@pytest.mark.parametrize( + "dt_string, expected_tz", + [ + ["01-01-2013 08:00:00+08:00", 480], + ["2013-01-01T08:00:00.000000000+0800", 480], + ["2012-12-31T16:00:00.000000000-0800", -480], + ["12-31-2012 23:00:00-01:00", -60], + ], +) +def test_parsing_timezone_offsets(dt_string, expected_tz): + # All of these datetime strings with offsets are equivalent + # to the same datetime after the timezone offset is added. + arr = np.array(["01-01-2013 00:00:00"], dtype=object) + expected, _ = tslib.array_to_datetime(arr) + + arr = np.array([dt_string], dtype=object) + result, result_tz = tslib.array_to_datetime(arr) + + tm.assert_numpy_array_equal(result, expected) + assert result_tz == timezone(timedelta(minutes=expected_tz)) + + +def test_parsing_non_iso_timezone_offset(): + dt_string = "01-01-2013T00:00:00.000000000+0000" + arr = np.array([dt_string], dtype=object) + + with tm.assert_produces_warning(None): + # GH#50949 should not get tzlocal-deprecation warning here + result, result_tz = tslib.array_to_datetime(arr) + expected = np.array([np.datetime64("2013-01-01 00:00:00.000000000")]) + + tm.assert_numpy_array_equal(result, expected) + assert result_tz is timezone.utc + + +def test_parsing_different_timezone_offsets(): + # see gh-17697 + data = ["2015-11-18 15:30:00+05:30", "2015-11-18 15:30:00+06:30"] + data = np.array(data, dtype=object) + + msg = "parsing datetimes with mixed time zones will raise an error" + with tm.assert_produces_warning(FutureWarning, match=msg): + result, result_tz = tslib.array_to_datetime(data) + expected = np.array( + [ + datetime(2015, 11, 18, 15, 30, tzinfo=tzoffset(None, 19800)), + datetime(2015, 11, 18, 15, 30, tzinfo=tzoffset(None, 23400)), + ], + dtype=object, + ) + + tm.assert_numpy_array_equal(result, expected) + assert result_tz is None + + +@pytest.mark.parametrize( + "data", [["-352.737091", "183.575577"], ["1", "2", "3", "4", "5"]] +) +def test_number_looking_strings_not_into_datetime(data): + # see gh-4601 + # + # These strings don't look like datetimes, so + # they shouldn't be attempted to be converted. + arr = np.array(data, dtype=object) + result, _ = tslib.array_to_datetime(arr, errors="ignore") + + tm.assert_numpy_array_equal(result, arr) + + +@pytest.mark.parametrize( + "invalid_date", + [ + date(1000, 1, 1), + datetime(1000, 1, 1), + "1000-01-01", + "Jan 1, 1000", + np.datetime64("1000-01-01"), + ], +) +@pytest.mark.parametrize("errors", ["coerce", "raise"]) +def test_coerce_outside_ns_bounds(invalid_date, errors): + arr = np.array([invalid_date], dtype="object") + kwargs = {"values": arr, "errors": errors} + + if errors == "raise": + msg = "^Out of bounds nanosecond timestamp: .*, at position 0$" + + with pytest.raises(OutOfBoundsDatetime, match=msg): + tslib.array_to_datetime(**kwargs) + else: # coerce. + result, _ = tslib.array_to_datetime(**kwargs) + expected = np.array([iNaT], dtype="M8[ns]") + + tm.assert_numpy_array_equal(result, expected) + + +def test_coerce_outside_ns_bounds_one_valid(): + arr = np.array(["1/1/1000", "1/1/2000"], dtype=object) + result, _ = tslib.array_to_datetime(arr, errors="coerce") + + expected = [iNaT, "2000-01-01T00:00:00.000000000"] + expected = np.array(expected, dtype="M8[ns]") + + tm.assert_numpy_array_equal(result, expected) + + +@pytest.mark.parametrize("errors", ["ignore", "coerce"]) +def test_coerce_of_invalid_datetimes(errors): + arr = np.array(["01-01-2013", "not_a_date", "1"], dtype=object) + kwargs = {"values": arr, "errors": errors} + + if errors == "ignore": + # Without coercing, the presence of any invalid + # dates prevents any values from being converted. + result, _ = tslib.array_to_datetime(**kwargs) + tm.assert_numpy_array_equal(result, arr) + else: # coerce. + # With coercing, the invalid dates becomes iNaT + result, _ = tslib.array_to_datetime(arr, errors="coerce") + expected = ["2013-01-01T00:00:00.000000000", iNaT, iNaT] + + tm.assert_numpy_array_equal(result, np.array(expected, dtype="M8[ns]")) + + +def test_to_datetime_barely_out_of_bounds(): + # see gh-19382, gh-19529 + # + # Close enough to bounds that dropping nanos + # would result in an in-bounds datetime. + arr = np.array(["2262-04-11 23:47:16.854775808"], dtype=object) + msg = "^Out of bounds nanosecond timestamp: 2262-04-11 23:47:16, at position 0$" + + with pytest.raises(tslib.OutOfBoundsDatetime, match=msg): + tslib.array_to_datetime(arr) + + +@pytest.mark.parametrize( + "timestamp", + [ + # Close enough to bounds that scaling micros to nanos overflows + # but adding nanos would result in an in-bounds datetime. + "1677-09-21T00:12:43.145224193", + "1677-09-21T00:12:43.145224999", + # this always worked + "1677-09-21T00:12:43.145225000", + ], +) +def test_to_datetime_barely_inside_bounds(timestamp): + # see gh-57150 + result, _ = tslib.array_to_datetime(np.array([timestamp], dtype=object)) + tm.assert_numpy_array_equal(result, np.array([timestamp], dtype="M8[ns]")) + + +class SubDatetime(datetime): + pass + + +@pytest.mark.parametrize( + "data,expected", + [ + ([SubDatetime(2000, 1, 1)], ["2000-01-01T00:00:00.000000000"]), + ([datetime(2000, 1, 1)], ["2000-01-01T00:00:00.000000000"]), + ([Timestamp(2000, 1, 1)], ["2000-01-01T00:00:00.000000000"]), + ], +) +def test_datetime_subclass(data, expected): + # GH 25851 + # ensure that subclassed datetime works with + # array_to_datetime + + arr = np.array(data, dtype=object) + result, _ = tslib.array_to_datetime(arr) + + expected = np.array(expected, dtype="M8[ns]") + tm.assert_numpy_array_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_ccalendar.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_ccalendar.py new file mode 100644 index 0000000000000000000000000000000000000000..8dd1bd47e4728d1b35e84b14f29e0a255178ec9b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_ccalendar.py @@ -0,0 +1,63 @@ +from datetime import ( + date, + datetime, +) + +from hypothesis import given +import numpy as np +import pytest + +from pandas._libs.tslibs import ccalendar + +from pandas._testing._hypothesis import DATETIME_IN_PD_TIMESTAMP_RANGE_NO_TZ + + +@pytest.mark.parametrize( + "date_tuple,expected", + [ + ((2001, 3, 1), 60), + ((2004, 3, 1), 61), + ((1907, 12, 31), 365), # End-of-year, non-leap year. + ((2004, 12, 31), 366), # End-of-year, leap year. + ], +) +def test_get_day_of_year_numeric(date_tuple, expected): + assert ccalendar.get_day_of_year(*date_tuple) == expected + + +def test_get_day_of_year_dt(): + dt = datetime.fromordinal(1 + np.random.default_rng(2).integers(365 * 4000)) + result = ccalendar.get_day_of_year(dt.year, dt.month, dt.day) + + expected = (dt - dt.replace(month=1, day=1)).days + 1 + assert result == expected + + +@pytest.mark.parametrize( + "input_date_tuple, expected_iso_tuple", + [ + [(2020, 1, 1), (2020, 1, 3)], + [(2019, 12, 31), (2020, 1, 2)], + [(2019, 12, 30), (2020, 1, 1)], + [(2009, 12, 31), (2009, 53, 4)], + [(2010, 1, 1), (2009, 53, 5)], + [(2010, 1, 3), (2009, 53, 7)], + [(2010, 1, 4), (2010, 1, 1)], + [(2006, 1, 1), (2005, 52, 7)], + [(2005, 12, 31), (2005, 52, 6)], + [(2008, 12, 28), (2008, 52, 7)], + [(2008, 12, 29), (2009, 1, 1)], + ], +) +def test_dt_correct_iso_8601_year_week_and_day(input_date_tuple, expected_iso_tuple): + result = ccalendar.get_iso_calendar(*input_date_tuple) + expected_from_date_isocalendar = date(*input_date_tuple).isocalendar() + assert result == expected_from_date_isocalendar + assert result == expected_iso_tuple + + +@given(DATETIME_IN_PD_TIMESTAMP_RANGE_NO_TZ) +def test_isocalendar(dt): + expected = dt.isocalendar() + result = ccalendar.get_iso_calendar(dt.year, dt.month, dt.day) + assert result == expected diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_conversion.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_conversion.py new file mode 100644 index 0000000000000000000000000000000000000000..9d7a5e906c3c3771a7d909f0ded9397e5cee0d64 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_conversion.py @@ -0,0 +1,160 @@ +from datetime import datetime + +import numpy as np +import pytest +from pytz import UTC + +from pandas._libs.tslibs import ( + OutOfBoundsTimedelta, + astype_overflowsafe, + conversion, + iNaT, + timezones, + tz_convert_from_utc, + tzconversion, +) + +from pandas import ( + Timestamp, + date_range, +) +import pandas._testing as tm + + +def _compare_utc_to_local(tz_didx): + def f(x): + return tzconversion.tz_convert_from_utc_single(x, tz_didx.tz) + + result = tz_convert_from_utc(tz_didx.asi8, tz_didx.tz) + expected = np.vectorize(f)(tz_didx.asi8) + + tm.assert_numpy_array_equal(result, expected) + + +def _compare_local_to_utc(tz_didx, naive_didx): + # Check that tz_localize behaves the same vectorized and pointwise. + err1 = err2 = None + try: + result = tzconversion.tz_localize_to_utc(naive_didx.asi8, tz_didx.tz) + err1 = None + except Exception as err: + err1 = err + + try: + expected = naive_didx.map(lambda x: x.tz_localize(tz_didx.tz)).asi8 + except Exception as err: + err2 = err + + if err1 is not None: + assert type(err1) == type(err2) + else: + assert err2 is None + tm.assert_numpy_array_equal(result, expected) + + +def test_tz_localize_to_utc_copies(): + # GH#46460 + arr = np.arange(5, dtype="i8") + result = tz_convert_from_utc(arr, tz=UTC) + tm.assert_numpy_array_equal(result, arr) + assert not np.shares_memory(arr, result) + + result = tz_convert_from_utc(arr, tz=None) + tm.assert_numpy_array_equal(result, arr) + assert not np.shares_memory(arr, result) + + +def test_tz_convert_single_matches_tz_convert_hourly(tz_aware_fixture): + tz = tz_aware_fixture + tz_didx = date_range("2014-03-01", "2015-01-10", freq="h", tz=tz) + naive_didx = date_range("2014-03-01", "2015-01-10", freq="h") + + _compare_utc_to_local(tz_didx) + _compare_local_to_utc(tz_didx, naive_didx) + + +@pytest.mark.parametrize("freq", ["D", "YE"]) +def test_tz_convert_single_matches_tz_convert(tz_aware_fixture, freq): + tz = tz_aware_fixture + tz_didx = date_range("2018-01-01", "2020-01-01", freq=freq, tz=tz) + naive_didx = date_range("2018-01-01", "2020-01-01", freq=freq) + + _compare_utc_to_local(tz_didx) + _compare_local_to_utc(tz_didx, naive_didx) + + +@pytest.mark.parametrize( + "arr", + [ + pytest.param(np.array([], dtype=np.int64), id="empty"), + pytest.param(np.array([iNaT], dtype=np.int64), id="all_nat"), + ], +) +def test_tz_convert_corner(arr): + result = tz_convert_from_utc(arr, timezones.maybe_get_tz("Asia/Tokyo")) + tm.assert_numpy_array_equal(result, arr) + + +def test_tz_convert_readonly(): + # GH#35530 + arr = np.array([0], dtype=np.int64) + arr.setflags(write=False) + result = tz_convert_from_utc(arr, UTC) + tm.assert_numpy_array_equal(result, arr) + + +@pytest.mark.parametrize("copy", [True, False]) +@pytest.mark.parametrize("dtype", ["M8[ns]", "M8[s]"]) +def test_length_zero_copy(dtype, copy): + arr = np.array([], dtype=dtype) + result = astype_overflowsafe(arr, copy=copy, dtype=np.dtype("M8[ns]")) + if copy: + assert not np.shares_memory(result, arr) + elif arr.dtype == result.dtype: + assert result is arr + else: + assert not np.shares_memory(result, arr) + + +def test_ensure_datetime64ns_bigendian(): + # GH#29684 + arr = np.array([np.datetime64(1, "ms")], dtype=">M8[ms]") + result = astype_overflowsafe(arr, dtype=np.dtype("M8[ns]")) + + expected = np.array([np.datetime64(1, "ms")], dtype="M8[ns]") + tm.assert_numpy_array_equal(result, expected) + + +def test_ensure_timedelta64ns_overflows(): + arr = np.arange(10).astype("m8[Y]") * 100 + msg = r"Cannot convert 300 years to timedelta64\[ns\] without overflow" + with pytest.raises(OutOfBoundsTimedelta, match=msg): + astype_overflowsafe(arr, dtype=np.dtype("m8[ns]")) + + +class SubDatetime(datetime): + pass + + +@pytest.mark.parametrize( + "dt, expected", + [ + pytest.param( + Timestamp("2000-01-01"), Timestamp("2000-01-01", tz=UTC), id="timestamp" + ), + pytest.param( + datetime(2000, 1, 1), datetime(2000, 1, 1, tzinfo=UTC), id="datetime" + ), + pytest.param( + SubDatetime(2000, 1, 1), + SubDatetime(2000, 1, 1, tzinfo=UTC), + id="subclassed_datetime", + ), + ], +) +def test_localize_pydatetime_dt_types(dt, expected): + # GH 25851 + # ensure that subclassed datetime works with + # localize_pydatetime + result = conversion.localize_pydatetime(dt, UTC) + assert result == expected diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_fields.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_fields.py new file mode 100644 index 0000000000000000000000000000000000000000..da67c093b8f4dbaffba9e02f395bb830de33b489 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_fields.py @@ -0,0 +1,40 @@ +import numpy as np +import pytest + +from pandas._libs.tslibs import fields + +import pandas._testing as tm + + +@pytest.fixture +def dtindex(): + dtindex = np.arange(5, dtype=np.int64) * 10**9 * 3600 * 24 * 32 + dtindex.flags.writeable = False + return dtindex + + +def test_get_date_name_field_readonly(dtindex): + # https://github.com/vaexio/vaex/issues/357 + # fields functions shouldn't raise when we pass read-only data + result = fields.get_date_name_field(dtindex, "month_name") + expected = np.array(["January", "February", "March", "April", "May"], dtype=object) + tm.assert_numpy_array_equal(result, expected) + + +def test_get_date_field_readonly(dtindex): + result = fields.get_date_field(dtindex, "Y") + expected = np.array([1970, 1970, 1970, 1970, 1970], dtype=np.int32) + tm.assert_numpy_array_equal(result, expected) + + +def test_get_start_end_field_readonly(dtindex): + result = fields.get_start_end_field(dtindex, "is_month_start", None) + expected = np.array([True, False, False, False, False], dtype=np.bool_) + tm.assert_numpy_array_equal(result, expected) + + +def test_get_timedelta_field_readonly(dtindex): + # treat dtindex as timedeltas for this next one + result = fields.get_timedelta_field(dtindex, "seconds") + expected = np.array([0] * 5, dtype=np.int32) + tm.assert_numpy_array_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_libfrequencies.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_libfrequencies.py new file mode 100644 index 0000000000000000000000000000000000000000..effd3b4b8b4e5fa113f1b20506997efc54d3c9d2 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_libfrequencies.py @@ -0,0 +1,27 @@ +import pytest + +from pandas._libs.tslibs.parsing import get_rule_month + +from pandas.tseries import offsets + + +@pytest.mark.parametrize( + "obj,expected", + [ + ("W", "DEC"), + (offsets.Week().freqstr, "DEC"), + ("D", "DEC"), + (offsets.Day().freqstr, "DEC"), + ("Q", "DEC"), + (offsets.QuarterEnd(startingMonth=12).freqstr, "DEC"), + ("Q-JAN", "JAN"), + (offsets.QuarterEnd(startingMonth=1).freqstr, "JAN"), + ("Y-DEC", "DEC"), + (offsets.YearEnd().freqstr, "DEC"), + ("Y-MAY", "MAY"), + (offsets.YearEnd(month=5).freqstr, "MAY"), + ], +) +def test_get_rule_month(obj, expected): + result = get_rule_month(obj) + assert result == expected diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_liboffsets.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_liboffsets.py new file mode 100644 index 0000000000000000000000000000000000000000..c189a431146a7172862586ea3a015ad4f2676cf2 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_liboffsets.py @@ -0,0 +1,173 @@ +""" +Tests for helper functions in the cython tslibs.offsets +""" +from datetime import datetime + +import pytest + +from pandas._libs.tslibs.ccalendar import ( + get_firstbday, + get_lastbday, +) +import pandas._libs.tslibs.offsets as liboffsets +from pandas._libs.tslibs.offsets import roll_qtrday + +from pandas import Timestamp + + +@pytest.fixture(params=["start", "end", "business_start", "business_end"]) +def day_opt(request): + return request.param + + +@pytest.mark.parametrize( + "dt,exp_week_day,exp_last_day", + [ + (datetime(2017, 11, 30), 3, 30), # Business day. + (datetime(1993, 10, 31), 6, 29), # Non-business day. + ], +) +def test_get_last_bday(dt, exp_week_day, exp_last_day): + assert dt.weekday() == exp_week_day + assert get_lastbday(dt.year, dt.month) == exp_last_day + + +@pytest.mark.parametrize( + "dt,exp_week_day,exp_first_day", + [ + (datetime(2017, 4, 1), 5, 3), # Non-weekday. + (datetime(1993, 10, 1), 4, 1), # Business day. + ], +) +def test_get_first_bday(dt, exp_week_day, exp_first_day): + assert dt.weekday() == exp_week_day + assert get_firstbday(dt.year, dt.month) == exp_first_day + + +@pytest.mark.parametrize( + "months,day_opt,expected", + [ + (0, 15, datetime(2017, 11, 15)), + (0, None, datetime(2017, 11, 30)), + (1, "start", datetime(2017, 12, 1)), + (-145, "end", datetime(2005, 10, 31)), + (0, "business_end", datetime(2017, 11, 30)), + (0, "business_start", datetime(2017, 11, 1)), + ], +) +def test_shift_month_dt(months, day_opt, expected): + dt = datetime(2017, 11, 30) + assert liboffsets.shift_month(dt, months, day_opt=day_opt) == expected + + +@pytest.mark.parametrize( + "months,day_opt,expected", + [ + (1, "start", Timestamp("1929-06-01")), + (-3, "end", Timestamp("1929-02-28")), + (25, None, Timestamp("1931-06-5")), + (-1, 31, Timestamp("1929-04-30")), + ], +) +def test_shift_month_ts(months, day_opt, expected): + ts = Timestamp("1929-05-05") + assert liboffsets.shift_month(ts, months, day_opt=day_opt) == expected + + +def test_shift_month_error(): + dt = datetime(2017, 11, 15) + day_opt = "this should raise" + + with pytest.raises(ValueError, match=day_opt): + liboffsets.shift_month(dt, 3, day_opt=day_opt) + + +@pytest.mark.parametrize( + "other,expected", + [ + # Before March 1. + (datetime(2017, 2, 10), {2: 1, -7: -7, 0: 0}), + # After March 1. + (Timestamp("2014-03-15", tz="US/Eastern"), {2: 2, -7: -6, 0: 1}), + ], +) +@pytest.mark.parametrize("n", [2, -7, 0]) +def test_roll_qtrday_year(other, expected, n): + month = 3 + day_opt = "start" # `other` will be compared to March 1. + + assert roll_qtrday(other, n, month, day_opt, modby=12) == expected[n] + + +@pytest.mark.parametrize( + "other,expected", + [ + # Before June 30. + (datetime(1999, 6, 29), {5: 4, -7: -7, 0: 0}), + # After June 30. + (Timestamp(2072, 8, 24, 6, 17, 18), {5: 5, -7: -6, 0: 1}), + ], +) +@pytest.mark.parametrize("n", [5, -7, 0]) +def test_roll_qtrday_year2(other, expected, n): + month = 6 + day_opt = "end" # `other` will be compared to June 30. + + assert roll_qtrday(other, n, month, day_opt, modby=12) == expected[n] + + +def test_get_day_of_month_error(): + # get_day_of_month is not directly exposed. + # We test it via roll_qtrday. + dt = datetime(2017, 11, 15) + day_opt = "foo" + + with pytest.raises(ValueError, match=day_opt): + # To hit the raising case we need month == dt.month and n > 0. + roll_qtrday(dt, n=3, month=11, day_opt=day_opt, modby=12) + + +@pytest.mark.parametrize( + "month", + [3, 5], # (other.month % 3) < (month % 3) # (other.month % 3) > (month % 3) +) +@pytest.mark.parametrize("n", [4, -3]) +def test_roll_qtr_day_not_mod_unequal(day_opt, month, n): + expected = {3: {-3: -2, 4: 4}, 5: {-3: -3, 4: 3}} + + other = Timestamp(2072, 10, 1, 6, 17, 18) # Saturday. + assert roll_qtrday(other, n, month, day_opt, modby=3) == expected[month][n] + + +@pytest.mark.parametrize( + "other,month,exp_dict", + [ + # Monday. + (datetime(1999, 5, 31), 2, {-1: {"start": 0, "business_start": 0}}), + # Saturday. + ( + Timestamp(2072, 10, 1, 6, 17, 18), + 4, + {2: {"end": 1, "business_end": 1, "business_start": 1}}, + ), + # First business day. + ( + Timestamp(2072, 10, 3, 6, 17, 18), + 4, + {2: {"end": 1, "business_end": 1}, -1: {"start": 0}}, + ), + ], +) +@pytest.mark.parametrize("n", [2, -1]) +def test_roll_qtr_day_mod_equal(other, month, exp_dict, n, day_opt): + # All cases have (other.month % 3) == (month % 3). + expected = exp_dict.get(n, {}).get(day_opt, n) + assert roll_qtrday(other, n, month, day_opt, modby=3) == expected + + +@pytest.mark.parametrize( + "n,expected", [(42, {29: 42, 1: 42, 31: 41}), (-4, {29: -4, 1: -3, 31: -4})] +) +@pytest.mark.parametrize("compare", [29, 1, 31]) +def test_roll_convention(n, expected, compare): + assert liboffsets.roll_convention(29, n, compare) == expected[compare] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_np_datetime.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_np_datetime.py new file mode 100644 index 0000000000000000000000000000000000000000..02edf1a09387766d71097ea0baedc2640cfb824b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_np_datetime.py @@ -0,0 +1,222 @@ +import numpy as np +import pytest + +from pandas._libs.tslibs.dtypes import NpyDatetimeUnit +from pandas._libs.tslibs.np_datetime import ( + OutOfBoundsDatetime, + OutOfBoundsTimedelta, + astype_overflowsafe, + is_unitless, + py_get_unit_from_dtype, + py_td64_to_tdstruct, +) + +import pandas._testing as tm + + +def test_is_unitless(): + dtype = np.dtype("M8[ns]") + assert not is_unitless(dtype) + + dtype = np.dtype("datetime64") + assert is_unitless(dtype) + + dtype = np.dtype("m8[ns]") + assert not is_unitless(dtype) + + dtype = np.dtype("timedelta64") + assert is_unitless(dtype) + + msg = "dtype must be datetime64 or timedelta64" + with pytest.raises(ValueError, match=msg): + is_unitless(np.dtype(np.int64)) + + msg = "Argument 'dtype' has incorrect type" + with pytest.raises(TypeError, match=msg): + is_unitless("foo") + + +def test_get_unit_from_dtype(): + # datetime64 + assert py_get_unit_from_dtype(np.dtype("M8[Y]")) == NpyDatetimeUnit.NPY_FR_Y.value + assert py_get_unit_from_dtype(np.dtype("M8[M]")) == NpyDatetimeUnit.NPY_FR_M.value + assert py_get_unit_from_dtype(np.dtype("M8[W]")) == NpyDatetimeUnit.NPY_FR_W.value + # B has been deprecated and removed -> no 3 + assert py_get_unit_from_dtype(np.dtype("M8[D]")) == NpyDatetimeUnit.NPY_FR_D.value + assert py_get_unit_from_dtype(np.dtype("M8[h]")) == NpyDatetimeUnit.NPY_FR_h.value + assert py_get_unit_from_dtype(np.dtype("M8[m]")) == NpyDatetimeUnit.NPY_FR_m.value + assert py_get_unit_from_dtype(np.dtype("M8[s]")) == NpyDatetimeUnit.NPY_FR_s.value + assert py_get_unit_from_dtype(np.dtype("M8[ms]")) == NpyDatetimeUnit.NPY_FR_ms.value + assert py_get_unit_from_dtype(np.dtype("M8[us]")) == NpyDatetimeUnit.NPY_FR_us.value + assert py_get_unit_from_dtype(np.dtype("M8[ns]")) == NpyDatetimeUnit.NPY_FR_ns.value + assert py_get_unit_from_dtype(np.dtype("M8[ps]")) == NpyDatetimeUnit.NPY_FR_ps.value + assert py_get_unit_from_dtype(np.dtype("M8[fs]")) == NpyDatetimeUnit.NPY_FR_fs.value + assert py_get_unit_from_dtype(np.dtype("M8[as]")) == NpyDatetimeUnit.NPY_FR_as.value + + # timedelta64 + assert py_get_unit_from_dtype(np.dtype("m8[Y]")) == NpyDatetimeUnit.NPY_FR_Y.value + assert py_get_unit_from_dtype(np.dtype("m8[M]")) == NpyDatetimeUnit.NPY_FR_M.value + assert py_get_unit_from_dtype(np.dtype("m8[W]")) == NpyDatetimeUnit.NPY_FR_W.value + # B has been deprecated and removed -> no 3 + assert py_get_unit_from_dtype(np.dtype("m8[D]")) == NpyDatetimeUnit.NPY_FR_D.value + assert py_get_unit_from_dtype(np.dtype("m8[h]")) == NpyDatetimeUnit.NPY_FR_h.value + assert py_get_unit_from_dtype(np.dtype("m8[m]")) == NpyDatetimeUnit.NPY_FR_m.value + assert py_get_unit_from_dtype(np.dtype("m8[s]")) == NpyDatetimeUnit.NPY_FR_s.value + assert py_get_unit_from_dtype(np.dtype("m8[ms]")) == NpyDatetimeUnit.NPY_FR_ms.value + assert py_get_unit_from_dtype(np.dtype("m8[us]")) == NpyDatetimeUnit.NPY_FR_us.value + assert py_get_unit_from_dtype(np.dtype("m8[ns]")) == NpyDatetimeUnit.NPY_FR_ns.value + assert py_get_unit_from_dtype(np.dtype("m8[ps]")) == NpyDatetimeUnit.NPY_FR_ps.value + assert py_get_unit_from_dtype(np.dtype("m8[fs]")) == NpyDatetimeUnit.NPY_FR_fs.value + assert py_get_unit_from_dtype(np.dtype("m8[as]")) == NpyDatetimeUnit.NPY_FR_as.value + + +def test_td64_to_tdstruct(): + val = 12454636234 # arbitrary value + + res1 = py_td64_to_tdstruct(val, NpyDatetimeUnit.NPY_FR_ns.value) + exp1 = { + "days": 0, + "hrs": 0, + "min": 0, + "sec": 12, + "ms": 454, + "us": 636, + "ns": 234, + "seconds": 12, + "microseconds": 454636, + "nanoseconds": 234, + } + assert res1 == exp1 + + res2 = py_td64_to_tdstruct(val, NpyDatetimeUnit.NPY_FR_us.value) + exp2 = { + "days": 0, + "hrs": 3, + "min": 27, + "sec": 34, + "ms": 636, + "us": 234, + "ns": 0, + "seconds": 12454, + "microseconds": 636234, + "nanoseconds": 0, + } + assert res2 == exp2 + + res3 = py_td64_to_tdstruct(val, NpyDatetimeUnit.NPY_FR_ms.value) + exp3 = { + "days": 144, + "hrs": 3, + "min": 37, + "sec": 16, + "ms": 234, + "us": 0, + "ns": 0, + "seconds": 13036, + "microseconds": 234000, + "nanoseconds": 0, + } + assert res3 == exp3 + + # Note this out of bounds for nanosecond Timedelta + res4 = py_td64_to_tdstruct(val, NpyDatetimeUnit.NPY_FR_s.value) + exp4 = { + "days": 144150, + "hrs": 21, + "min": 10, + "sec": 34, + "ms": 0, + "us": 0, + "ns": 0, + "seconds": 76234, + "microseconds": 0, + "nanoseconds": 0, + } + assert res4 == exp4 + + +class TestAstypeOverflowSafe: + def test_pass_non_dt64_array(self): + # check that we raise, not segfault + arr = np.arange(5) + dtype = np.dtype("M8[ns]") + + msg = ( + "astype_overflowsafe values.dtype and dtype must be either " + "both-datetime64 or both-timedelta64" + ) + with pytest.raises(TypeError, match=msg): + astype_overflowsafe(arr, dtype, copy=True) + + with pytest.raises(TypeError, match=msg): + astype_overflowsafe(arr, dtype, copy=False) + + def test_pass_non_dt64_dtype(self): + # check that we raise, not segfault + arr = np.arange(5, dtype="i8").view("M8[D]") + dtype = np.dtype("m8[ns]") + + msg = ( + "astype_overflowsafe values.dtype and dtype must be either " + "both-datetime64 or both-timedelta64" + ) + with pytest.raises(TypeError, match=msg): + astype_overflowsafe(arr, dtype, copy=True) + + with pytest.raises(TypeError, match=msg): + astype_overflowsafe(arr, dtype, copy=False) + + def test_astype_overflowsafe_dt64(self): + dtype = np.dtype("M8[ns]") + + dt = np.datetime64("2262-04-05", "D") + arr = dt + np.arange(10, dtype="m8[D]") + + # arr.astype silently overflows, so this + wrong = arr.astype(dtype) + roundtrip = wrong.astype(arr.dtype) + assert not (wrong == roundtrip).all() + + msg = "Out of bounds nanosecond timestamp" + with pytest.raises(OutOfBoundsDatetime, match=msg): + astype_overflowsafe(arr, dtype) + + # But converting to microseconds is fine, and we match numpy's results. + dtype2 = np.dtype("M8[us]") + result = astype_overflowsafe(arr, dtype2) + expected = arr.astype(dtype2) + tm.assert_numpy_array_equal(result, expected) + + def test_astype_overflowsafe_td64(self): + dtype = np.dtype("m8[ns]") + + dt = np.datetime64("2262-04-05", "D") + arr = dt + np.arange(10, dtype="m8[D]") + arr = arr.view("m8[D]") + + # arr.astype silently overflows, so this + wrong = arr.astype(dtype) + roundtrip = wrong.astype(arr.dtype) + assert not (wrong == roundtrip).all() + + msg = r"Cannot convert 106752 days to timedelta64\[ns\] without overflow" + with pytest.raises(OutOfBoundsTimedelta, match=msg): + astype_overflowsafe(arr, dtype) + + # But converting to microseconds is fine, and we match numpy's results. + dtype2 = np.dtype("m8[us]") + result = astype_overflowsafe(arr, dtype2) + expected = arr.astype(dtype2) + tm.assert_numpy_array_equal(result, expected) + + def test_astype_overflowsafe_disallow_rounding(self): + arr = np.array([-1500, 1500], dtype="M8[ns]") + dtype = np.dtype("M8[us]") + + msg = "Cannot losslessly cast '-1500 ns' to us" + with pytest.raises(ValueError, match=msg): + astype_overflowsafe(arr, dtype, round_ok=False) + + result = astype_overflowsafe(arr, dtype, round_ok=True) + expected = arr.astype(dtype) + tm.assert_numpy_array_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_npy_units.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_npy_units.py new file mode 100644 index 0000000000000000000000000000000000000000..6d05dc79fbb2cf52688547b672365802463ce6f2 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_npy_units.py @@ -0,0 +1,27 @@ +import numpy as np + +from pandas._libs.tslibs.dtypes import abbrev_to_npy_unit +from pandas._libs.tslibs.vectorized import is_date_array_normalized + +# a datetime64 ndarray which *is* normalized +day_arr = np.arange(10, dtype="i8").view("M8[D]") + + +class TestIsDateArrayNormalized: + def test_is_date_array_normalized_day(self): + arr = day_arr + abbrev = "D" + unit = abbrev_to_npy_unit(abbrev) + result = is_date_array_normalized(arr.view("i8"), None, unit) + assert result is True + + def test_is_date_array_normalized_seconds(self): + abbrev = "s" + arr = day_arr.astype(f"M8[{abbrev}]") + unit = abbrev_to_npy_unit(abbrev) + result = is_date_array_normalized(arr.view("i8"), None, unit) + assert result is True + + arr[0] += np.timedelta64(1, abbrev) + result2 = is_date_array_normalized(arr.view("i8"), None, unit) + assert result2 is False diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_parse_iso8601.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_parse_iso8601.py new file mode 100644 index 0000000000000000000000000000000000000000..1992faae2ea6a687f8bd74b4e1e10ba53bb9e901 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_parse_iso8601.py @@ -0,0 +1,119 @@ +from datetime import datetime + +import pytest + +from pandas._libs import tslib + +from pandas import Timestamp + + +@pytest.mark.parametrize( + "date_str, exp", + [ + ("2011-01-02", datetime(2011, 1, 2)), + ("2011-1-2", datetime(2011, 1, 2)), + ("2011-01", datetime(2011, 1, 1)), + ("2011-1", datetime(2011, 1, 1)), + ("2011 01 02", datetime(2011, 1, 2)), + ("2011.01.02", datetime(2011, 1, 2)), + ("2011/01/02", datetime(2011, 1, 2)), + ("2011\\01\\02", datetime(2011, 1, 2)), + ("2013-01-01 05:30:00", datetime(2013, 1, 1, 5, 30)), + ("2013-1-1 5:30:00", datetime(2013, 1, 1, 5, 30)), + ("2013-1-1 5:30:00+01:00", Timestamp(2013, 1, 1, 5, 30, tz="UTC+01:00")), + ], +) +def test_parsers_iso8601(date_str, exp): + # see gh-12060 + # + # Test only the ISO parser - flexibility to + # different separators and leading zero's. + actual = tslib._test_parse_iso8601(date_str) + assert actual == exp + + +@pytest.mark.parametrize( + "date_str", + [ + "2011-01/02", + "2011=11=11", + "201401", + "201111", + "200101", + # Mixed separated and unseparated. + "2005-0101", + "200501-01", + "20010101 12:3456", + "20010101 1234:56", + # HHMMSS must have two digits in + # each component if unseparated. + "20010101 1", + "20010101 123", + "20010101 12345", + "20010101 12345Z", + ], +) +def test_parsers_iso8601_invalid(date_str): + msg = f'Error parsing datetime string "{date_str}"' + + with pytest.raises(ValueError, match=msg): + tslib._test_parse_iso8601(date_str) + + +def test_parsers_iso8601_invalid_offset_invalid(): + date_str = "2001-01-01 12-34-56" + msg = f'Timezone hours offset out of range in datetime string "{date_str}"' + + with pytest.raises(ValueError, match=msg): + tslib._test_parse_iso8601(date_str) + + +def test_parsers_iso8601_leading_space(): + # GH#25895 make sure isoparser doesn't overflow with long input + date_str, expected = ("2013-1-1 5:30:00", datetime(2013, 1, 1, 5, 30)) + actual = tslib._test_parse_iso8601(" " * 200 + date_str) + assert actual == expected + + +@pytest.mark.parametrize( + "date_str, timespec, exp", + [ + ("2023-01-01 00:00:00", "auto", "2023-01-01T00:00:00"), + ("2023-01-01 00:00:00", "seconds", "2023-01-01T00:00:00"), + ("2023-01-01 00:00:00", "milliseconds", "2023-01-01T00:00:00.000"), + ("2023-01-01 00:00:00", "microseconds", "2023-01-01T00:00:00.000000"), + ("2023-01-01 00:00:00", "nanoseconds", "2023-01-01T00:00:00.000000000"), + ("2023-01-01 00:00:00.001", "auto", "2023-01-01T00:00:00.001000"), + ("2023-01-01 00:00:00.001", "seconds", "2023-01-01T00:00:00"), + ("2023-01-01 00:00:00.001", "milliseconds", "2023-01-01T00:00:00.001"), + ("2023-01-01 00:00:00.001", "microseconds", "2023-01-01T00:00:00.001000"), + ("2023-01-01 00:00:00.001", "nanoseconds", "2023-01-01T00:00:00.001000000"), + ("2023-01-01 00:00:00.000001", "auto", "2023-01-01T00:00:00.000001"), + ("2023-01-01 00:00:00.000001", "seconds", "2023-01-01T00:00:00"), + ("2023-01-01 00:00:00.000001", "milliseconds", "2023-01-01T00:00:00.000"), + ("2023-01-01 00:00:00.000001", "microseconds", "2023-01-01T00:00:00.000001"), + ("2023-01-01 00:00:00.000001", "nanoseconds", "2023-01-01T00:00:00.000001000"), + ("2023-01-01 00:00:00.000000001", "auto", "2023-01-01T00:00:00.000000001"), + ("2023-01-01 00:00:00.000000001", "seconds", "2023-01-01T00:00:00"), + ("2023-01-01 00:00:00.000000001", "milliseconds", "2023-01-01T00:00:00.000"), + ("2023-01-01 00:00:00.000000001", "microseconds", "2023-01-01T00:00:00.000000"), + ( + "2023-01-01 00:00:00.000000001", + "nanoseconds", + "2023-01-01T00:00:00.000000001", + ), + ("2023-01-01 00:00:00.000001001", "auto", "2023-01-01T00:00:00.000001001"), + ("2023-01-01 00:00:00.000001001", "seconds", "2023-01-01T00:00:00"), + ("2023-01-01 00:00:00.000001001", "milliseconds", "2023-01-01T00:00:00.000"), + ("2023-01-01 00:00:00.000001001", "microseconds", "2023-01-01T00:00:00.000001"), + ( + "2023-01-01 00:00:00.000001001", + "nanoseconds", + "2023-01-01T00:00:00.000001001", + ), + ], +) +def test_iso8601_formatter(date_str: str, timespec: str, exp: str): + # GH#53020 + ts = Timestamp(date_str) + assert ts.isoformat(timespec=timespec) == exp diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_parsing.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_parsing.py new file mode 100644 index 0000000000000000000000000000000000000000..d8f23156bd4d41f7d3cc1434a7b56b245837535d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_parsing.py @@ -0,0 +1,414 @@ +""" +Tests for Timestamp parsing, aimed at pandas/_libs/tslibs/parsing.pyx +""" +from datetime import datetime +import re + +from dateutil.parser import parse as du_parse +from dateutil.tz import tzlocal +from hypothesis import given +import numpy as np +import pytest + +from pandas._libs.tslibs import ( + parsing, + strptime, +) +from pandas._libs.tslibs.parsing import parse_datetime_string_with_reso +from pandas.compat import ( + ISMUSL, + is_platform_windows, +) +import pandas.util._test_decorators as td + +import pandas._testing as tm +from pandas._testing._hypothesis import DATETIME_NO_TZ + + +@pytest.mark.skipif( + is_platform_windows() or ISMUSL, + reason="TZ setting incorrect on Windows and MUSL Linux", +) +def test_parsing_tzlocal_deprecated(): + # GH#50791 + msg = ( + "Parsing 'EST' as tzlocal.*" + "Pass the 'tz' keyword or call tz_localize after construction instead" + ) + dtstr = "Jan 15 2004 03:00 EST" + + with tm.set_timezone("US/Eastern"): + with tm.assert_produces_warning(FutureWarning, match=msg): + res, _ = parse_datetime_string_with_reso(dtstr) + + assert isinstance(res.tzinfo, tzlocal) + + with tm.assert_produces_warning(FutureWarning, match=msg): + res = parsing.py_parse_datetime_string(dtstr) + assert isinstance(res.tzinfo, tzlocal) + + +def test_parse_datetime_string_with_reso(): + (parsed, reso) = parse_datetime_string_with_reso("4Q1984") + (parsed_lower, reso_lower) = parse_datetime_string_with_reso("4q1984") + + assert reso == reso_lower + assert parsed == parsed_lower + + +def test_parse_datetime_string_with_reso_nanosecond_reso(): + # GH#46811 + parsed, reso = parse_datetime_string_with_reso("2022-04-20 09:19:19.123456789") + assert reso == "nanosecond" + + +def test_parse_datetime_string_with_reso_invalid_type(): + # Raise on invalid input, don't just return it + msg = "Argument 'date_string' has incorrect type (expected str, got tuple)" + with pytest.raises(TypeError, match=re.escape(msg)): + parse_datetime_string_with_reso((4, 5)) + + +@pytest.mark.parametrize( + "dashed,normal", [("1988-Q2", "1988Q2"), ("2Q-1988", "2Q1988")] +) +def test_parse_time_quarter_with_dash(dashed, normal): + # see gh-9688 + (parsed_dash, reso_dash) = parse_datetime_string_with_reso(dashed) + (parsed, reso) = parse_datetime_string_with_reso(normal) + + assert parsed_dash == parsed + assert reso_dash == reso + + +@pytest.mark.parametrize("dashed", ["-2Q1992", "2-Q1992", "4-4Q1992"]) +def test_parse_time_quarter_with_dash_error(dashed): + msg = f"Unknown datetime string format, unable to parse: {dashed}" + + with pytest.raises(parsing.DateParseError, match=msg): + parse_datetime_string_with_reso(dashed) + + +@pytest.mark.parametrize( + "date_string,expected", + [ + ("123.1234", False), + ("-50000", False), + ("999", False), + ("m", False), + ("T", False), + ("Mon Sep 16, 2013", True), + ("2012-01-01", True), + ("01/01/2012", True), + ("01012012", True), + ("0101", True), + ("1-1", True), + ], +) +def test_does_not_convert_mixed_integer(date_string, expected): + assert parsing._does_string_look_like_datetime(date_string) is expected + + +@pytest.mark.parametrize( + "date_str,kwargs,msg", + [ + ( + "2013Q5", + {}, + ( + "Incorrect quarterly string is given, " + "quarter must be between 1 and 4: 2013Q5" + ), + ), + # see gh-5418 + ( + "2013Q1", + {"freq": "INVLD-L-DEC-SAT"}, + ( + "Unable to retrieve month information " + "from given freq: INVLD-L-DEC-SAT" + ), + ), + ], +) +def test_parsers_quarterly_with_freq_error(date_str, kwargs, msg): + with pytest.raises(parsing.DateParseError, match=msg): + parsing.parse_datetime_string_with_reso(date_str, **kwargs) + + +@pytest.mark.parametrize( + "date_str,freq,expected", + [ + ("2013Q2", None, datetime(2013, 4, 1)), + ("2013Q2", "Y-APR", datetime(2012, 8, 1)), + ("2013-Q2", "Y-DEC", datetime(2013, 4, 1)), + ], +) +def test_parsers_quarterly_with_freq(date_str, freq, expected): + result, _ = parsing.parse_datetime_string_with_reso(date_str, freq=freq) + assert result == expected + + +@pytest.mark.parametrize( + "date_str", ["2Q 2005", "2Q-200Y", "2Q-200", "22Q2005", "2Q200.", "6Q-20"] +) +def test_parsers_quarter_invalid(date_str): + if date_str == "6Q-20": + msg = ( + "Incorrect quarterly string is given, quarter " + f"must be between 1 and 4: {date_str}" + ) + else: + msg = f"Unknown datetime string format, unable to parse: {date_str}" + + with pytest.raises(ValueError, match=msg): + parsing.parse_datetime_string_with_reso(date_str) + + +@pytest.mark.parametrize( + "date_str,expected", + [("201101", datetime(2011, 1, 1, 0, 0)), ("200005", datetime(2000, 5, 1, 0, 0))], +) +def test_parsers_month_freq(date_str, expected): + result, _ = parsing.parse_datetime_string_with_reso(date_str, freq="ME") + assert result == expected + + +@td.skip_if_not_us_locale +@pytest.mark.parametrize( + "string,fmt", + [ + ("20111230", "%Y%m%d"), + ("201112300000", "%Y%m%d%H%M"), + ("20111230000000", "%Y%m%d%H%M%S"), + ("20111230T00", "%Y%m%dT%H"), + ("20111230T0000", "%Y%m%dT%H%M"), + ("20111230T000000", "%Y%m%dT%H%M%S"), + ("2011-12-30", "%Y-%m-%d"), + ("2011", "%Y"), + ("2011-01", "%Y-%m"), + ("30-12-2011", "%d-%m-%Y"), + ("2011-12-30 00:00:00", "%Y-%m-%d %H:%M:%S"), + ("2011-12-30T00:00:00", "%Y-%m-%dT%H:%M:%S"), + ("2011-12-30T00:00:00UTC", "%Y-%m-%dT%H:%M:%S%Z"), + ("2011-12-30T00:00:00Z", "%Y-%m-%dT%H:%M:%S%z"), + ("2011-12-30T00:00:00+9", "%Y-%m-%dT%H:%M:%S%z"), + ("2011-12-30T00:00:00+09", "%Y-%m-%dT%H:%M:%S%z"), + ("2011-12-30T00:00:00+090", None), + ("2011-12-30T00:00:00+0900", "%Y-%m-%dT%H:%M:%S%z"), + ("2011-12-30T00:00:00-0900", "%Y-%m-%dT%H:%M:%S%z"), + ("2011-12-30T00:00:00+09:00", "%Y-%m-%dT%H:%M:%S%z"), + ("2011-12-30T00:00:00+09:000", None), + ("2011-12-30T00:00:00+9:0", "%Y-%m-%dT%H:%M:%S%z"), + ("2011-12-30T00:00:00+09:", None), + ("2011-12-30T00:00:00.000000UTC", "%Y-%m-%dT%H:%M:%S.%f%Z"), + ("2011-12-30T00:00:00.000000Z", "%Y-%m-%dT%H:%M:%S.%f%z"), + ("2011-12-30T00:00:00.000000+9", "%Y-%m-%dT%H:%M:%S.%f%z"), + ("2011-12-30T00:00:00.000000+09", "%Y-%m-%dT%H:%M:%S.%f%z"), + ("2011-12-30T00:00:00.000000+090", None), + ("2011-12-30T00:00:00.000000+0900", "%Y-%m-%dT%H:%M:%S.%f%z"), + ("2011-12-30T00:00:00.000000-0900", "%Y-%m-%dT%H:%M:%S.%f%z"), + ("2011-12-30T00:00:00.000000+09:00", "%Y-%m-%dT%H:%M:%S.%f%z"), + ("2011-12-30T00:00:00.000000+09:000", None), + ("2011-12-30T00:00:00.000000+9:0", "%Y-%m-%dT%H:%M:%S.%f%z"), + ("2011-12-30T00:00:00.000000+09:", None), + ("2011-12-30 00:00:00.000000", "%Y-%m-%d %H:%M:%S.%f"), + ("Tue 24 Aug 2021 01:30:48", "%a %d %b %Y %H:%M:%S"), + ("Tuesday 24 Aug 2021 01:30:48", "%A %d %b %Y %H:%M:%S"), + ("Tue 24 Aug 2021 01:30:48 AM", "%a %d %b %Y %I:%M:%S %p"), + ("Tuesday 24 Aug 2021 01:30:48 AM", "%A %d %b %Y %I:%M:%S %p"), + ("27.03.2003 14:55:00.000", "%d.%m.%Y %H:%M:%S.%f"), # GH50317 + ], +) +def test_guess_datetime_format_with_parseable_formats(string, fmt): + with tm.maybe_produces_warning( + UserWarning, fmt is not None and re.search(r"%d.*%m", fmt) + ): + result = parsing.guess_datetime_format(string) + assert result == fmt + + +@pytest.mark.parametrize("dayfirst,expected", [(True, "%d/%m/%Y"), (False, "%m/%d/%Y")]) +def test_guess_datetime_format_with_dayfirst(dayfirst, expected): + ambiguous_string = "01/01/2011" + result = parsing.guess_datetime_format(ambiguous_string, dayfirst=dayfirst) + assert result == expected + + +@td.skip_if_not_us_locale +@pytest.mark.parametrize( + "string,fmt", + [ + ("30/Dec/2011", "%d/%b/%Y"), + ("30/December/2011", "%d/%B/%Y"), + ("30/Dec/2011 00:00:00", "%d/%b/%Y %H:%M:%S"), + ], +) +def test_guess_datetime_format_with_locale_specific_formats(string, fmt): + result = parsing.guess_datetime_format(string) + assert result == fmt + + +@pytest.mark.parametrize( + "invalid_dt", + [ + "01/2013", + "12:00:00", + "1/1/1/1", + "this_is_not_a_datetime", + "51a", + "13/2019", + "202001", # YYYYMM isn't ISO8601 + "2020/01", # YYYY/MM isn't ISO8601 either + "87156549591102612381000001219H5", + ], +) +def test_guess_datetime_format_invalid_inputs(invalid_dt): + # A datetime string must include a year, month and a day for it to be + # guessable, in addition to being a string that looks like a datetime. + assert parsing.guess_datetime_format(invalid_dt) is None + + +@pytest.mark.parametrize("invalid_type_dt", [9, datetime(2011, 1, 1)]) +def test_guess_datetime_format_wrong_type_inputs(invalid_type_dt): + # A datetime string must include a year, month and a day for it to be + # guessable, in addition to being a string that looks like a datetime. + with pytest.raises( + TypeError, + match=r"^Argument 'dt_str' has incorrect type \(expected str, got .*\)$", + ): + parsing.guess_datetime_format(invalid_type_dt) + + +@pytest.mark.parametrize( + "string,fmt,dayfirst,warning", + [ + ("2011-1-1", "%Y-%m-%d", False, None), + ("2011-1-1", "%Y-%d-%m", True, None), + ("1/1/2011", "%m/%d/%Y", False, None), + ("1/1/2011", "%d/%m/%Y", True, None), + ("30-1-2011", "%d-%m-%Y", False, UserWarning), + ("30-1-2011", "%d-%m-%Y", True, None), + ("2011-1-1 0:0:0", "%Y-%m-%d %H:%M:%S", False, None), + ("2011-1-1 0:0:0", "%Y-%d-%m %H:%M:%S", True, None), + ("2011-1-3T00:00:0", "%Y-%m-%dT%H:%M:%S", False, None), + ("2011-1-3T00:00:0", "%Y-%d-%mT%H:%M:%S", True, None), + ("2011-1-1 00:00:00", "%Y-%m-%d %H:%M:%S", False, None), + ("2011-1-1 00:00:00", "%Y-%d-%m %H:%M:%S", True, None), + ], +) +def test_guess_datetime_format_no_padding(string, fmt, dayfirst, warning): + # see gh-11142 + msg = ( + rf"Parsing dates in {fmt} format when dayfirst=False \(the default\) " + "was specified. " + "Pass `dayfirst=True` or specify a format to silence this warning." + ) + with tm.assert_produces_warning(warning, match=msg): + result = parsing.guess_datetime_format(string, dayfirst=dayfirst) + assert result == fmt + + +def test_try_parse_dates(): + arr = np.array(["5/1/2000", "6/1/2000", "7/1/2000"], dtype=object) + result = parsing.try_parse_dates(arr, parser=lambda x: du_parse(x, dayfirst=True)) + + expected = np.array([du_parse(d, dayfirst=True) for d in arr]) + tm.assert_numpy_array_equal(result, expected) + + +def test_parse_datetime_string_with_reso_check_instance_type_raise_exception(): + # issue 20684 + msg = "Argument 'date_string' has incorrect type (expected str, got tuple)" + with pytest.raises(TypeError, match=re.escape(msg)): + parse_datetime_string_with_reso((1, 2, 3)) + + result = parse_datetime_string_with_reso("2019") + expected = (datetime(2019, 1, 1), "year") + assert result == expected + + +@pytest.mark.parametrize( + "fmt,expected", + [ + ("%Y %m %d %H:%M:%S", True), + ("%Y/%m/%d %H:%M:%S", True), + (r"%Y\%m\%d %H:%M:%S", True), + ("%Y-%m-%d %H:%M:%S", True), + ("%Y.%m.%d %H:%M:%S", True), + ("%Y%m%d %H:%M:%S", True), + ("%Y-%m-%dT%H:%M:%S", True), + ("%Y-%m-%dT%H:%M:%S%z", True), + ("%Y-%m-%dT%H:%M:%S%Z", False), + ("%Y-%m-%dT%H:%M:%S.%f", True), + ("%Y-%m-%dT%H:%M:%S.%f%z", True), + ("%Y-%m-%dT%H:%M:%S.%f%Z", False), + ("%Y%m%d", True), + ("%Y%m", False), + ("%Y", True), + ("%Y-%m-%d", True), + ("%Y-%m", True), + ], +) +def test_is_iso_format(fmt, expected): + # see gh-41047 + result = strptime._test_format_is_iso(fmt) + assert result == expected + + +@pytest.mark.parametrize( + "input", + [ + "2018-01-01T00:00:00.123456789", + "2018-01-01T00:00:00.123456", + "2018-01-01T00:00:00.123", + ], +) +def test_guess_datetime_format_f(input): + # https://github.com/pandas-dev/pandas/issues/49043 + result = parsing.guess_datetime_format(input) + expected = "%Y-%m-%dT%H:%M:%S.%f" + assert result == expected + + +def _helper_hypothesis_delimited_date(call, date_string, **kwargs): + msg, result = None, None + try: + result = call(date_string, **kwargs) + except ValueError as err: + msg = str(err) + return msg, result + + +@given(DATETIME_NO_TZ) +@pytest.mark.parametrize("delimiter", list(" -./")) +@pytest.mark.parametrize("dayfirst", [True, False]) +@pytest.mark.parametrize( + "date_format", + ["%d %m %Y", "%m %d %Y", "%m %Y", "%Y %m %d", "%y %m %d", "%Y%m%d", "%y%m%d"], +) +def test_hypothesis_delimited_date( + request, date_format, dayfirst, delimiter, test_datetime +): + if date_format == "%m %Y" and delimiter == ".": + request.applymarker( + pytest.mark.xfail( + reason="parse_datetime_string cannot reliably tell whether " + "e.g. %m.%Y is a float or a date" + ) + ) + date_string = test_datetime.strftime(date_format.replace(" ", delimiter)) + + except_out_dateutil, result = _helper_hypothesis_delimited_date( + parsing.py_parse_datetime_string, date_string, dayfirst=dayfirst + ) + except_in_dateutil, expected = _helper_hypothesis_delimited_date( + du_parse, + date_string, + default=datetime(1, 1, 1), + dayfirst=dayfirst, + yearfirst=False, + ) + + assert except_out_dateutil == except_in_dateutil + assert result == expected diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_period.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_period.py new file mode 100644 index 0000000000000000000000000000000000000000..715e2d3da88dbc69fd6a376f21b7bba78a46ca9f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_period.py @@ -0,0 +1,123 @@ +import numpy as np +import pytest + +from pandas._libs.tslibs import ( + iNaT, + to_offset, +) +from pandas._libs.tslibs.period import ( + extract_ordinals, + get_period_field_arr, + period_asfreq, + period_ordinal, +) + +import pandas._testing as tm + + +def get_freq_code(freqstr: str) -> int: + off = to_offset(freqstr, is_period=True) + # error: "BaseOffset" has no attribute "_period_dtype_code" + code = off._period_dtype_code # type: ignore[attr-defined] + return code + + +@pytest.mark.parametrize( + "freq1,freq2,expected", + [ + ("D", "h", 24), + ("D", "min", 1440), + ("D", "s", 86400), + ("D", "ms", 86400000), + ("D", "us", 86400000000), + ("D", "ns", 86400000000000), + ("h", "min", 60), + ("h", "s", 3600), + ("h", "ms", 3600000), + ("h", "us", 3600000000), + ("h", "ns", 3600000000000), + ("min", "s", 60), + ("min", "ms", 60000), + ("min", "us", 60000000), + ("min", "ns", 60000000000), + ("s", "ms", 1000), + ("s", "us", 1000000), + ("s", "ns", 1000000000), + ("ms", "us", 1000), + ("ms", "ns", 1000000), + ("us", "ns", 1000), + ], +) +def test_intra_day_conversion_factors(freq1, freq2, expected): + assert ( + period_asfreq(1, get_freq_code(freq1), get_freq_code(freq2), False) == expected + ) + + +@pytest.mark.parametrize( + "freq,expected", [("Y", 0), ("M", 0), ("W", 1), ("D", 0), ("B", 0)] +) +def test_period_ordinal_start_values(freq, expected): + # information for Jan. 1, 1970. + assert period_ordinal(1970, 1, 1, 0, 0, 0, 0, 0, get_freq_code(freq)) == expected + + +@pytest.mark.parametrize( + "dt,expected", + [ + ((1970, 1, 4, 0, 0, 0, 0, 0), 1), + ((1970, 1, 5, 0, 0, 0, 0, 0), 2), + ((2013, 10, 6, 0, 0, 0, 0, 0), 2284), + ((2013, 10, 7, 0, 0, 0, 0, 0), 2285), + ], +) +def test_period_ordinal_week(dt, expected): + args = dt + (get_freq_code("W"),) + assert period_ordinal(*args) == expected + + +@pytest.mark.parametrize( + "day,expected", + [ + # Thursday (Oct. 3, 2013). + (3, 11415), + # Friday (Oct. 4, 2013). + (4, 11416), + # Saturday (Oct. 5, 2013). + (5, 11417), + # Sunday (Oct. 6, 2013). + (6, 11417), + # Monday (Oct. 7, 2013). + (7, 11417), + # Tuesday (Oct. 8, 2013). + (8, 11418), + ], +) +def test_period_ordinal_business_day(day, expected): + # 5000 is PeriodDtypeCode for BusinessDay + args = (2013, 10, day, 0, 0, 0, 0, 0, 5000) + assert period_ordinal(*args) == expected + + +class TestExtractOrdinals: + def test_extract_ordinals_raises(self): + # with non-object, make sure we raise TypeError, not segfault + arr = np.arange(5) + freq = to_offset("D") + with pytest.raises(TypeError, match="values must be object-dtype"): + extract_ordinals(arr, freq) + + def test_extract_ordinals_2d(self): + freq = to_offset("D") + arr = np.empty(10, dtype=object) + arr[:] = iNaT + + res = extract_ordinals(arr, freq) + res2 = extract_ordinals(arr.reshape(5, 2), freq) + tm.assert_numpy_array_equal(res, res2.reshape(-1)) + + +def test_get_period_field_array_raises_on_out_of_range(): + msg = "Buffer dtype mismatch, expected 'const int64_t' but got 'double'" + with pytest.raises(ValueError, match=msg): + get_period_field_arr(-1, np.empty(1), 0) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_resolution.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_resolution.py new file mode 100644 index 0000000000000000000000000000000000000000..690962f1daa5eebd047d11297914eb36b494e0dc --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_resolution.py @@ -0,0 +1,57 @@ +import numpy as np +import pytest +import pytz + +from pandas._libs.tslibs import ( + Resolution, + get_resolution, +) +from pandas._libs.tslibs.dtypes import NpyDatetimeUnit + +import pandas._testing as tm + + +def test_get_resolution_nano(): + # don't return the fallback RESO_DAY + arr = np.array([1], dtype=np.int64) + res = get_resolution(arr) + assert res == Resolution.RESO_NS + + +def test_get_resolution_non_nano_data(): + arr = np.array([1], dtype=np.int64) + res = get_resolution(arr, None, NpyDatetimeUnit.NPY_FR_us.value) + assert res == Resolution.RESO_US + + res = get_resolution(arr, pytz.UTC, NpyDatetimeUnit.NPY_FR_us.value) + assert res == Resolution.RESO_US + + +@pytest.mark.parametrize( + "freqstr,expected", + [ + ("Y", "year"), + ("Q", "quarter"), + ("M", "month"), + ("D", "day"), + ("h", "hour"), + ("min", "minute"), + ("s", "second"), + ("ms", "millisecond"), + ("us", "microsecond"), + ("ns", "nanosecond"), + ], +) +def test_get_attrname_from_abbrev(freqstr, expected): + reso = Resolution.get_reso_from_freqstr(freqstr) + assert reso.attr_abbrev == freqstr + assert reso.attrname == expected + + +@pytest.mark.parametrize("freq", ["A", "H", "T", "S", "L", "U", "N"]) +def test_units_A_H_T_S_L_U_N_deprecated_from_attrname_to_abbrevs(freq): + # GH#52536 + msg = f"'{freq}' is deprecated and will be removed in a future version." + + with tm.assert_produces_warning(FutureWarning, match=msg): + Resolution.get_reso_from_freqstr(freq) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_strptime.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_strptime.py new file mode 100644 index 0000000000000000000000000000000000000000..d726006b03f6d43cb94a91518daabb2b29b757e0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_strptime.py @@ -0,0 +1,110 @@ +from datetime import ( + datetime, + timezone, +) + +import numpy as np +import pytest + +from pandas._libs.tslibs.dtypes import NpyDatetimeUnit +from pandas._libs.tslibs.strptime import array_strptime + +from pandas import ( + NaT, + Timestamp, +) +import pandas._testing as tm + +creso_infer = NpyDatetimeUnit.NPY_FR_GENERIC.value + + +class TestArrayStrptimeResolutionInference: + def test_array_strptime_resolution_all_nat(self): + arr = np.array([NaT, np.nan], dtype=object) + + fmt = "%Y-%m-%d %H:%M:%S" + res, _ = array_strptime(arr, fmt=fmt, utc=False, creso=creso_infer) + assert res.dtype == "M8[s]" + + res, _ = array_strptime(arr, fmt=fmt, utc=True, creso=creso_infer) + assert res.dtype == "M8[s]" + + @pytest.mark.parametrize("tz", [None, timezone.utc]) + def test_array_strptime_resolution_inference_homogeneous_strings(self, tz): + dt = datetime(2016, 1, 2, 3, 4, 5, 678900, tzinfo=tz) + + fmt = "%Y-%m-%d %H:%M:%S" + dtstr = dt.strftime(fmt) + arr = np.array([dtstr] * 3, dtype=object) + expected = np.array([dt.replace(tzinfo=None)] * 3, dtype="M8[s]") + + res, _ = array_strptime(arr, fmt=fmt, utc=False, creso=creso_infer) + tm.assert_numpy_array_equal(res, expected) + + fmt = "%Y-%m-%d %H:%M:%S.%f" + dtstr = dt.strftime(fmt) + arr = np.array([dtstr] * 3, dtype=object) + expected = np.array([dt.replace(tzinfo=None)] * 3, dtype="M8[us]") + + res, _ = array_strptime(arr, fmt=fmt, utc=False, creso=creso_infer) + tm.assert_numpy_array_equal(res, expected) + + fmt = "ISO8601" + res, _ = array_strptime(arr, fmt=fmt, utc=False, creso=creso_infer) + tm.assert_numpy_array_equal(res, expected) + + @pytest.mark.parametrize("tz", [None, timezone.utc]) + def test_array_strptime_resolution_mixed(self, tz): + dt = datetime(2016, 1, 2, 3, 4, 5, 678900, tzinfo=tz) + + ts = Timestamp(dt).as_unit("ns") + + arr = np.array([dt, ts], dtype=object) + expected = np.array( + [Timestamp(dt).as_unit("ns").asm8, ts.asm8], + dtype="M8[ns]", + ) + + fmt = "%Y-%m-%d %H:%M:%S" + res, _ = array_strptime(arr, fmt=fmt, utc=False, creso=creso_infer) + tm.assert_numpy_array_equal(res, expected) + + fmt = "ISO8601" + res, _ = array_strptime(arr, fmt=fmt, utc=False, creso=creso_infer) + tm.assert_numpy_array_equal(res, expected) + + def test_array_strptime_resolution_todaynow(self): + # specifically case where today/now is the *first* item + vals = np.array(["today", np.datetime64("2017-01-01", "us")], dtype=object) + + now = Timestamp("now").asm8 + res, _ = array_strptime(vals, fmt="%Y-%m-%d", utc=False, creso=creso_infer) + res2, _ = array_strptime( + vals[::-1], fmt="%Y-%m-%d", utc=False, creso=creso_infer + ) + + # 1s is an arbitrary cutoff for call overhead; in local testing the + # actual difference is about 250us + tolerance = np.timedelta64(1, "s") + + assert res.dtype == "M8[us]" + assert abs(res[0] - now) < tolerance + assert res[1] == vals[1] + + assert res2.dtype == "M8[us]" + assert abs(res2[1] - now) < tolerance * 2 + assert res2[0] == vals[1] + + def test_array_strptime_str_outside_nano_range(self): + vals = np.array(["2401-09-15"], dtype=object) + expected = np.array(["2401-09-15"], dtype="M8[s]") + fmt = "ISO8601" + res, _ = array_strptime(vals, fmt=fmt, creso=creso_infer) + tm.assert_numpy_array_equal(res, expected) + + # non-iso -> different path + vals2 = np.array(["Sep 15, 2401"], dtype=object) + expected2 = np.array(["2401-09-15"], dtype="M8[s]") + fmt2 = "%b %d, %Y" + res2, _ = array_strptime(vals2, fmt=fmt2, creso=creso_infer) + tm.assert_numpy_array_equal(res2, expected2) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_timedeltas.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_timedeltas.py new file mode 100644 index 0000000000000000000000000000000000000000..4784a6d0d600dcc77e359fb3d7d56301f78270d2 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_timedeltas.py @@ -0,0 +1,149 @@ +import re + +import numpy as np +import pytest + +from pandas._libs.tslibs.timedeltas import ( + array_to_timedelta64, + delta_to_nanoseconds, + ints_to_pytimedelta, +) + +from pandas import ( + Timedelta, + offsets, +) +import pandas._testing as tm + + +@pytest.mark.parametrize( + "obj,expected", + [ + (np.timedelta64(14, "D"), 14 * 24 * 3600 * 1e9), + (Timedelta(minutes=-7), -7 * 60 * 1e9), + (Timedelta(minutes=-7).to_pytimedelta(), -7 * 60 * 1e9), + (Timedelta(seconds=1234e-9), 1234), # GH43764, GH40946 + ( + Timedelta(seconds=1e-9, milliseconds=1e-5, microseconds=1e-1), + 111, + ), # GH43764 + ( + Timedelta(days=1, seconds=1e-9, milliseconds=1e-5, microseconds=1e-1), + 24 * 3600e9 + 111, + ), # GH43764 + (offsets.Nano(125), 125), + ], +) +def test_delta_to_nanoseconds(obj, expected): + result = delta_to_nanoseconds(obj) + assert result == expected + + +def test_delta_to_nanoseconds_error(): + obj = np.array([123456789], dtype="m8[ns]") + + with pytest.raises(TypeError, match=""): + delta_to_nanoseconds(obj) + + with pytest.raises(TypeError, match="float"): + delta_to_nanoseconds(1.5) + with pytest.raises(TypeError, match="int"): + delta_to_nanoseconds(1) + with pytest.raises(TypeError, match="int"): + delta_to_nanoseconds(np.int64(2)) + with pytest.raises(TypeError, match="int"): + delta_to_nanoseconds(np.int32(3)) + + +def test_delta_to_nanoseconds_td64_MY_raises(): + msg = ( + "delta_to_nanoseconds does not support Y or M units, " + "as their duration in nanoseconds is ambiguous" + ) + + td = np.timedelta64(1234, "Y") + + with pytest.raises(ValueError, match=msg): + delta_to_nanoseconds(td) + + td = np.timedelta64(1234, "M") + + with pytest.raises(ValueError, match=msg): + delta_to_nanoseconds(td) + + +@pytest.mark.parametrize("unit", ["Y", "M"]) +def test_unsupported_td64_unit_raises(unit): + # GH 52806 + with pytest.raises( + ValueError, + match=f"Unit {unit} is not supported. " + "Only unambiguous timedelta values durations are supported. " + "Allowed units are 'W', 'D', 'h', 'm', 's', 'ms', 'us', 'ns'", + ): + Timedelta(np.timedelta64(1, unit)) + + +def test_huge_nanoseconds_overflow(): + # GH 32402 + assert delta_to_nanoseconds(Timedelta(1e10)) == 1e10 + assert delta_to_nanoseconds(Timedelta(nanoseconds=1e10)) == 1e10 + + +@pytest.mark.parametrize( + "kwargs", [{"Seconds": 1}, {"seconds": 1, "Nanoseconds": 1}, {"Foo": 2}] +) +def test_kwarg_assertion(kwargs): + err_message = ( + "cannot construct a Timedelta from the passed arguments, " + "allowed keywords are " + "[weeks, days, hours, minutes, seconds, " + "milliseconds, microseconds, nanoseconds]" + ) + + with pytest.raises(ValueError, match=re.escape(err_message)): + Timedelta(**kwargs) + + +class TestArrayToTimedelta64: + def test_array_to_timedelta64_string_with_unit_2d_raises(self): + # check the 'unit is not None and errors != "coerce"' path + # in array_to_timedelta64 raises correctly with 2D values + values = np.array([["1", 2], [3, "4"]], dtype=object) + with pytest.raises(ValueError, match="unit must not be specified"): + array_to_timedelta64(values, unit="s") + + def test_array_to_timedelta64_non_object_raises(self): + # check we raise, not segfault + values = np.arange(5) + + msg = "'values' must have object dtype" + with pytest.raises(TypeError, match=msg): + array_to_timedelta64(values) + + +@pytest.mark.parametrize("unit", ["s", "ms", "us"]) +def test_ints_to_pytimedelta(unit): + # tests for non-nanosecond cases + arr = np.arange(6, dtype=np.int64).view(f"m8[{unit}]") + + res = ints_to_pytimedelta(arr, box=False) + # For non-nanosecond, .astype(object) gives pytimedelta objects + # instead of integers + expected = arr.astype(object) + tm.assert_numpy_array_equal(res, expected) + + res = ints_to_pytimedelta(arr, box=True) + expected = np.array([Timedelta(x) for x in arr], dtype=object) + tm.assert_numpy_array_equal(res, expected) + + +@pytest.mark.parametrize("unit", ["Y", "M", "ps", "fs", "as"]) +def test_ints_to_pytimedelta_unsupported(unit): + arr = np.arange(6, dtype=np.int64).view(f"m8[{unit}]") + + with pytest.raises(NotImplementedError, match=r"\d{1,2}"): + ints_to_pytimedelta(arr, box=False) + msg = "Only resolutions 's', 'ms', 'us', 'ns' are supported" + with pytest.raises(NotImplementedError, match=msg): + ints_to_pytimedelta(arr, box=True) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_timezones.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_timezones.py new file mode 100644 index 0000000000000000000000000000000000000000..28e4889983fb964167dd74623c8e4c4585c99a96 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_timezones.py @@ -0,0 +1,168 @@ +from datetime import ( + datetime, + timedelta, + timezone, +) + +import dateutil.tz +import pytest +import pytz + +from pandas._libs.tslibs import ( + conversion, + timezones, +) +from pandas.compat import is_platform_windows + +from pandas import Timestamp + + +def test_is_utc(utc_fixture): + tz = timezones.maybe_get_tz(utc_fixture) + assert timezones.is_utc(tz) + + +@pytest.mark.parametrize("tz_name", list(pytz.common_timezones)) +def test_cache_keys_are_distinct_for_pytz_vs_dateutil(tz_name): + tz_p = timezones.maybe_get_tz(tz_name) + tz_d = timezones.maybe_get_tz("dateutil/" + tz_name) + + if tz_d is None: + pytest.skip(tz_name + ": dateutil does not know about this one") + + if not (tz_name == "UTC" and is_platform_windows()): + # they both end up as tzwin("UTC") on windows + assert timezones._p_tz_cache_key(tz_p) != timezones._p_tz_cache_key(tz_d) + + +def test_tzlocal_repr(): + # see gh-13583 + ts = Timestamp("2011-01-01", tz=dateutil.tz.tzlocal()) + assert ts.tz == dateutil.tz.tzlocal() + assert "tz='tzlocal()')" in repr(ts) + + +def test_tzlocal_maybe_get_tz(): + # see gh-13583 + tz = timezones.maybe_get_tz("tzlocal()") + assert tz == dateutil.tz.tzlocal() + + +def test_tzlocal_offset(): + # see gh-13583 + # + # Get offset using normal datetime for test. + ts = Timestamp("2011-01-01", tz=dateutil.tz.tzlocal()) + + offset = dateutil.tz.tzlocal().utcoffset(datetime(2011, 1, 1)) + offset = offset.total_seconds() + + assert ts._value + offset == Timestamp("2011-01-01")._value + + +def test_tzlocal_is_not_utc(): + # even if the machine running the test is localized to UTC + tz = dateutil.tz.tzlocal() + assert not timezones.is_utc(tz) + + assert not timezones.tz_compare(tz, dateutil.tz.tzutc()) + + +def test_tz_compare_utc(utc_fixture, utc_fixture2): + tz = timezones.maybe_get_tz(utc_fixture) + tz2 = timezones.maybe_get_tz(utc_fixture2) + assert timezones.tz_compare(tz, tz2) + + +@pytest.fixture( + params=[ + (pytz.timezone("US/Eastern"), lambda tz, x: tz.localize(x)), + (dateutil.tz.gettz("US/Eastern"), lambda tz, x: x.replace(tzinfo=tz)), + ] +) +def infer_setup(request): + eastern, localize = request.param + + start_naive = datetime(2001, 1, 1) + end_naive = datetime(2009, 1, 1) + + start = localize(eastern, start_naive) + end = localize(eastern, end_naive) + + return eastern, localize, start, end, start_naive, end_naive + + +def test_infer_tz_compat(infer_setup): + eastern, _, start, end, start_naive, end_naive = infer_setup + + assert ( + timezones.infer_tzinfo(start, end) + is conversion.localize_pydatetime(start_naive, eastern).tzinfo + ) + assert ( + timezones.infer_tzinfo(start, None) + is conversion.localize_pydatetime(start_naive, eastern).tzinfo + ) + assert ( + timezones.infer_tzinfo(None, end) + is conversion.localize_pydatetime(end_naive, eastern).tzinfo + ) + + +def test_infer_tz_utc_localize(infer_setup): + _, _, start, end, start_naive, end_naive = infer_setup + utc = pytz.utc + + start = utc.localize(start_naive) + end = utc.localize(end_naive) + + assert timezones.infer_tzinfo(start, end) is utc + + +@pytest.mark.parametrize("ordered", [True, False]) +def test_infer_tz_mismatch(infer_setup, ordered): + eastern, _, _, _, start_naive, end_naive = infer_setup + msg = "Inputs must both have the same timezone" + + utc = pytz.utc + start = utc.localize(start_naive) + end = conversion.localize_pydatetime(end_naive, eastern) + + args = (start, end) if ordered else (end, start) + + with pytest.raises(AssertionError, match=msg): + timezones.infer_tzinfo(*args) + + +def test_maybe_get_tz_invalid_types(): + with pytest.raises(TypeError, match=""): + timezones.maybe_get_tz(44.0) + + with pytest.raises(TypeError, match=""): + timezones.maybe_get_tz(pytz) + + msg = "" + with pytest.raises(TypeError, match=msg): + timezones.maybe_get_tz(Timestamp("2021-01-01", tz="UTC")) + + +def test_maybe_get_tz_offset_only(): + # see gh-36004 + + # timezone.utc + tz = timezones.maybe_get_tz(timezone.utc) + assert tz == timezone(timedelta(hours=0, minutes=0)) + + # without UTC+- prefix + tz = timezones.maybe_get_tz("+01:15") + assert tz == timezone(timedelta(hours=1, minutes=15)) + + tz = timezones.maybe_get_tz("-01:15") + assert tz == timezone(-timedelta(hours=1, minutes=15)) + + # with UTC+- prefix + tz = timezones.maybe_get_tz("UTC+02:45") + assert tz == timezone(timedelta(hours=2, minutes=45)) + + tz = timezones.maybe_get_tz("UTC-02:45") + assert tz == timezone(-timedelta(hours=2, minutes=45)) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_to_offset.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_to_offset.py new file mode 100644 index 0000000000000000000000000000000000000000..8ca55648f3780e6f31621f7b5cfdcd1435b1a231 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_to_offset.py @@ -0,0 +1,219 @@ +import re + +import pytest + +from pandas._libs.tslibs import ( + Timedelta, + offsets, + to_offset, +) + + +@pytest.mark.parametrize( + "freq_input,expected", + [ + (to_offset("10us"), offsets.Micro(10)), + (offsets.Hour(), offsets.Hour()), + ("2h30min", offsets.Minute(150)), + ("2h 30min", offsets.Minute(150)), + ("2h30min15s", offsets.Second(150 * 60 + 15)), + ("2h 60min", offsets.Hour(3)), + ("2h 20.5min", offsets.Second(8430)), + ("1.5min", offsets.Second(90)), + ("0.5s", offsets.Milli(500)), + ("15ms500us", offsets.Micro(15500)), + ("10s75ms", offsets.Milli(10075)), + ("1s0.25ms", offsets.Micro(1000250)), + ("1s0.25ms", offsets.Micro(1000250)), + ("2800ns", offsets.Nano(2800)), + ("2SME", offsets.SemiMonthEnd(2)), + ("2SME-16", offsets.SemiMonthEnd(2, day_of_month=16)), + ("2SMS-14", offsets.SemiMonthBegin(2, day_of_month=14)), + ("2SMS-15", offsets.SemiMonthBegin(2)), + ], +) +def test_to_offset(freq_input, expected): + result = to_offset(freq_input) + assert result == expected + + +@pytest.mark.parametrize( + "freqstr,expected", [("-1s", -1), ("-2SME", -2), ("-1SMS", -1), ("-5min10s", -310)] +) +def test_to_offset_negative(freqstr, expected): + result = to_offset(freqstr) + assert result.n == expected + + +@pytest.mark.filterwarnings("ignore:.*'m' is deprecated.*:FutureWarning") +@pytest.mark.parametrize( + "freqstr", + [ + "2h20m", + "us1", + "-us", + "3us1", + "-2-3us", + "-2D:3h", + "1.5.0s", + "2SMS-15-15", + "2SMS-15D", + "100foo", + # Invalid leading +/- signs. + "+-1d", + "-+1h", + "+1", + "-7", + "+d", + "-m", + # Invalid shortcut anchors. + "SME-0", + "SME-28", + "SME-29", + "SME-FOO", + "BSM", + "SME--1", + "SMS-1", + "SMS-28", + "SMS-30", + "SMS-BAR", + "SMS-BYR", + "BSMS", + "SMS--2", + ], +) +def test_to_offset_invalid(freqstr): + # see gh-13930 + + # We escape string because some of our + # inputs contain regex special characters. + msg = re.escape(f"Invalid frequency: {freqstr}") + with pytest.raises(ValueError, match=msg): + to_offset(freqstr) + + +def test_to_offset_no_evaluate(): + msg = str(("", "")) + with pytest.raises(TypeError, match=msg): + to_offset(("", "")) + + +def test_to_offset_tuple_unsupported(): + with pytest.raises(TypeError, match="pass as a string instead"): + to_offset((5, "T")) + + +@pytest.mark.parametrize( + "freqstr,expected", + [ + ("2D 3h", offsets.Hour(51)), + ("2 D3 h", offsets.Hour(51)), + ("2 D 3 h", offsets.Hour(51)), + (" 2 D 3 h ", offsets.Hour(51)), + (" h ", offsets.Hour()), + (" 3 h ", offsets.Hour(3)), + ], +) +def test_to_offset_whitespace(freqstr, expected): + result = to_offset(freqstr) + assert result == expected + + +@pytest.mark.parametrize( + "freqstr,expected", [("00h 00min 01s", 1), ("-00h 03min 14s", -194)] +) +def test_to_offset_leading_zero(freqstr, expected): + result = to_offset(freqstr) + assert result.n == expected + + +@pytest.mark.parametrize("freqstr,expected", [("+1d", 1), ("+2h30min", 150)]) +def test_to_offset_leading_plus(freqstr, expected): + result = to_offset(freqstr) + assert result.n == expected + + +@pytest.mark.parametrize( + "kwargs,expected", + [ + ({"days": 1, "seconds": 1}, offsets.Second(86401)), + ({"days": -1, "seconds": 1}, offsets.Second(-86399)), + ({"hours": 1, "minutes": 10}, offsets.Minute(70)), + ({"hours": 1, "minutes": -10}, offsets.Minute(50)), + ({"weeks": 1}, offsets.Day(7)), + ({"hours": 1}, offsets.Hour(1)), + ({"hours": 1}, to_offset("60min")), + ({"microseconds": 1}, offsets.Micro(1)), + ({"microseconds": 0}, offsets.Nano(0)), + ], +) +def test_to_offset_pd_timedelta(kwargs, expected): + # see gh-9064 + td = Timedelta(**kwargs) + result = to_offset(td) + assert result == expected + + +@pytest.mark.parametrize( + "shortcut,expected", + [ + ("W", offsets.Week(weekday=6)), + ("W-SUN", offsets.Week(weekday=6)), + ("QE", offsets.QuarterEnd(startingMonth=12)), + ("QE-DEC", offsets.QuarterEnd(startingMonth=12)), + ("QE-MAY", offsets.QuarterEnd(startingMonth=5)), + ("SME", offsets.SemiMonthEnd(day_of_month=15)), + ("SME-15", offsets.SemiMonthEnd(day_of_month=15)), + ("SME-1", offsets.SemiMonthEnd(day_of_month=1)), + ("SME-27", offsets.SemiMonthEnd(day_of_month=27)), + ("SMS-2", offsets.SemiMonthBegin(day_of_month=2)), + ("SMS-27", offsets.SemiMonthBegin(day_of_month=27)), + ], +) +def test_anchored_shortcuts(shortcut, expected): + result = to_offset(shortcut) + assert result == expected + + +@pytest.mark.parametrize( + "freq_depr", + [ + "2ye-mar", + "2ys", + "2qe", + "2qs-feb", + "2bqs", + "2sms", + "2bms", + "2cbme", + "2me", + "2w", + ], +) +def test_to_offset_lowercase_frequency_deprecated(freq_depr): + # GH#54939 + depr_msg = f"'{freq_depr[1:]}' is deprecated and will be removed in a " + f"future version, please use '{freq_depr.upper()[1:]}' instead." + + with pytest.raises(FutureWarning, match=depr_msg): + to_offset(freq_depr) + + +@pytest.mark.parametrize( + "freq_depr", + [ + "2H", + "2BH", + "2MIN", + "2S", + "2Us", + "2NS", + ], +) +def test_to_offset_uppercase_frequency_deprecated(freq_depr): + # GH#54939 + depr_msg = f"'{freq_depr[1:]}' is deprecated and will be removed in a " + f"future version, please use '{freq_depr.lower()[1:]}' instead." + + with pytest.raises(FutureWarning, match=depr_msg): + to_offset(freq_depr) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_tzconversion.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_tzconversion.py new file mode 100644 index 0000000000000000000000000000000000000000..c1a56ffb71b020df338721e44d56d7e03479fef6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/tslibs/test_tzconversion.py @@ -0,0 +1,23 @@ +import numpy as np +import pytest +import pytz + +from pandas._libs.tslibs.tzconversion import tz_localize_to_utc + + +class TestTZLocalizeToUTC: + def test_tz_localize_to_utc_ambiguous_infer(self): + # val is a timestamp that is ambiguous when localized to US/Eastern + val = 1_320_541_200_000_000_000 + vals = np.array([val, val - 1, val], dtype=np.int64) + + with pytest.raises(pytz.AmbiguousTimeError, match="2011-11-06 01:00:00"): + tz_localize_to_utc(vals, pytz.timezone("US/Eastern"), ambiguous="infer") + + with pytest.raises(pytz.AmbiguousTimeError, match="are no repeated times"): + tz_localize_to_utc(vals[:1], pytz.timezone("US/Eastern"), ambiguous="infer") + + vals[1] += 1 + msg = "There are 2 dst switches when there should only be 1" + with pytest.raises(pytz.AmbiguousTimeError, match=msg): + tz_localize_to_utc(vals, pytz.timezone("US/Eastern"), ambiguous="infer") diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6670cb2986892f9bc169e5be61166cdb581017be Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/conftest.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/conftest.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b97200981a0d6cf2c0ec866f912c7ffab008e1b4 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/conftest.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_assert_almost_equal.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_assert_almost_equal.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04b39da631d00e1538e2a495ba0dc524a26703ff Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_assert_almost_equal.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_assert_attr_equal.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_assert_attr_equal.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b759746d0bc1f5860bfab12e6aceb34d04f795b4 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_assert_attr_equal.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_assert_categorical_equal.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_assert_categorical_equal.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bcc547e8d2f667a7d846efeef26a66024eae0d55 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_assert_categorical_equal.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_assert_extension_array_equal.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_assert_extension_array_equal.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94f6b4cac1c27b92c45eb3bea185bf1998aa09a0 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_assert_extension_array_equal.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_assert_frame_equal.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_assert_frame_equal.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e31a85c7ffe867f8e70b878903bec13023e030c1 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_assert_frame_equal.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_assert_index_equal.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_assert_index_equal.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c02d03fd4b30d269b2e81a8f682cdad8a447ad0d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_assert_index_equal.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_assert_interval_array_equal.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_assert_interval_array_equal.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e9f422b51186ee0902dddc2bfb0ce20cfe9bbed Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_assert_interval_array_equal.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_assert_numpy_array_equal.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_assert_numpy_array_equal.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ef8cd8cc1e0ab7f79c9f848daee274d039ae904 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_assert_numpy_array_equal.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_assert_produces_warning.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_assert_produces_warning.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c75d135a95b9b816dd72e9d344562178c8843bda Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_assert_produces_warning.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_assert_series_equal.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_assert_series_equal.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32232c8a4f7b7b27090f41f9f9ea2988a5380244 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_assert_series_equal.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_deprecate.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_deprecate.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4ed8632d7d0a3ea47e2f1c6f50561975968ba61 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_deprecate.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_deprecate_kwarg.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_deprecate_kwarg.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48dbd5d422deb59292f63d08f2ad097eb1cf47fb Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_deprecate_kwarg.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_deprecate_nonkeyword_arguments.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_deprecate_nonkeyword_arguments.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28efbd3adbec4feb40677fb41c01e294109c525e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_deprecate_nonkeyword_arguments.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_doc.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_doc.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7689993e8e24751859aeccfa7c8676f7fc30c40a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_doc.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_hashing.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_hashing.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0923311cdfeceb4b5d4bfb16f7f96e47ad57cc60 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_hashing.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_numba.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_numba.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dad8c7c7b80e722ee191292379c6ad5de6b69336 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_numba.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_rewrite_warning.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_rewrite_warning.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a9f5dbd9e81627c49da85f83046bc484030b908e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_rewrite_warning.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_shares_memory.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_shares_memory.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de16ef5ce22c708c9bb9eeef7889e44779706851 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_shares_memory.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_show_versions.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_show_versions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9139f68c195b534c2adaafaf45f47db7418ca8db Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_show_versions.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_util.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_util.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57a66a24062387f3c3df5f41107f923d85a5bddf Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_util.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_validate_args.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_validate_args.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2c3599f2f48044c74a71e245237f5c22f9fe197 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_validate_args.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_validate_args_and_kwargs.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_validate_args_and_kwargs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9620c422f43c2e4ae0117aeaaae0f27b03c23bd6 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_validate_args_and_kwargs.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_validate_inclusive.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_validate_inclusive.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c809ff016f690c6f198338a67cd38af2c03aa0ff Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_validate_inclusive.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_validate_kwargs.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_validate_kwargs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5215d5a9365bd6874ef16d1b4735e085b14b1dca Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/__pycache__/test_validate_kwargs.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/conftest.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..b68bcc93431d015a5b9bdc47bdd7e46dd531b703 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/conftest.py @@ -0,0 +1,26 @@ +import pytest + + +@pytest.fixture(params=[True, False]) +def check_dtype(request): + return request.param + + +@pytest.fixture(params=[True, False]) +def check_exact(request): + return request.param + + +@pytest.fixture(params=[True, False]) +def check_index_type(request): + return request.param + + +@pytest.fixture(params=[0.5e-3, 0.5e-5]) +def rtol(request): + return request.param + + +@pytest.fixture(params=[True, False]) +def check_categorical(request): + return request.param diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_assert_almost_equal.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_assert_almost_equal.py new file mode 100644 index 0000000000000000000000000000000000000000..4e692084f7352f873b8c7354e7651b432058a1a5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_assert_almost_equal.py @@ -0,0 +1,586 @@ +import numpy as np +import pytest + +from pandas import ( + NA, + DataFrame, + Index, + NaT, + Series, + Timestamp, +) +import pandas._testing as tm + + +def _assert_almost_equal_both(a, b, **kwargs): + """ + Check that two objects are approximately equal. + + This check is performed commutatively. + + Parameters + ---------- + a : object + The first object to compare. + b : object + The second object to compare. + **kwargs + The arguments passed to `tm.assert_almost_equal`. + """ + tm.assert_almost_equal(a, b, **kwargs) + tm.assert_almost_equal(b, a, **kwargs) + + +def _assert_not_almost_equal(a, b, **kwargs): + """ + Check that two objects are not approximately equal. + + Parameters + ---------- + a : object + The first object to compare. + b : object + The second object to compare. + **kwargs + The arguments passed to `tm.assert_almost_equal`. + """ + try: + tm.assert_almost_equal(a, b, **kwargs) + msg = f"{a} and {b} were approximately equal when they shouldn't have been" + pytest.fail(reason=msg) + except AssertionError: + pass + + +def _assert_not_almost_equal_both(a, b, **kwargs): + """ + Check that two objects are not approximately equal. + + This check is performed commutatively. + + Parameters + ---------- + a : object + The first object to compare. + b : object + The second object to compare. + **kwargs + The arguments passed to `tm.assert_almost_equal`. + """ + _assert_not_almost_equal(a, b, **kwargs) + _assert_not_almost_equal(b, a, **kwargs) + + +@pytest.mark.parametrize( + "a,b", + [ + (1.1, 1.1), + (1.1, 1.100001), + (np.int16(1), 1.000001), + (np.float64(1.1), 1.1), + (np.uint32(5), 5), + ], +) +def test_assert_almost_equal_numbers(a, b): + _assert_almost_equal_both(a, b) + + +@pytest.mark.parametrize( + "a,b", + [ + (1.1, 1), + (1.1, True), + (1, 2), + (1.0001, np.int16(1)), + # The following two examples are not "almost equal" due to tol. + (0.1, 0.1001), + (0.0011, 0.0012), + ], +) +def test_assert_not_almost_equal_numbers(a, b): + _assert_not_almost_equal_both(a, b) + + +@pytest.mark.parametrize( + "a,b", + [ + (1.1, 1.1), + (1.1, 1.100001), + (1.1, 1.1001), + (0.000001, 0.000005), + (1000.0, 1000.0005), + # Testing this example, as per #13357 + (0.000011, 0.000012), + ], +) +def test_assert_almost_equal_numbers_atol(a, b): + # Equivalent to the deprecated check_less_precise=True, enforced in 2.0 + _assert_almost_equal_both(a, b, rtol=0.5e-3, atol=0.5e-3) + + +@pytest.mark.parametrize("a,b", [(1.1, 1.11), (0.1, 0.101), (0.000011, 0.001012)]) +def test_assert_not_almost_equal_numbers_atol(a, b): + _assert_not_almost_equal_both(a, b, atol=1e-3) + + +@pytest.mark.parametrize( + "a,b", + [ + (1.1, 1.1), + (1.1, 1.100001), + (1.1, 1.1001), + (1000.0, 1000.0005), + (1.1, 1.11), + (0.1, 0.101), + ], +) +def test_assert_almost_equal_numbers_rtol(a, b): + _assert_almost_equal_both(a, b, rtol=0.05) + + +@pytest.mark.parametrize("a,b", [(0.000011, 0.000012), (0.000001, 0.000005)]) +def test_assert_not_almost_equal_numbers_rtol(a, b): + _assert_not_almost_equal_both(a, b, rtol=0.05) + + +@pytest.mark.parametrize( + "a,b,rtol", + [ + (1.00001, 1.00005, 0.001), + (-0.908356 + 0.2j, -0.908358 + 0.2j, 1e-3), + (0.1 + 1.009j, 0.1 + 1.006j, 0.1), + (0.1001 + 2.0j, 0.1 + 2.001j, 0.01), + ], +) +def test_assert_almost_equal_complex_numbers(a, b, rtol): + _assert_almost_equal_both(a, b, rtol=rtol) + _assert_almost_equal_both(np.complex64(a), np.complex64(b), rtol=rtol) + _assert_almost_equal_both(np.complex128(a), np.complex128(b), rtol=rtol) + + +@pytest.mark.parametrize( + "a,b,rtol", + [ + (0.58310768, 0.58330768, 1e-7), + (-0.908 + 0.2j, -0.978 + 0.2j, 0.001), + (0.1 + 1j, 0.1 + 2j, 0.01), + (-0.132 + 1.001j, -0.132 + 1.005j, 1e-5), + (0.58310768j, 0.58330768j, 1e-9), + ], +) +def test_assert_not_almost_equal_complex_numbers(a, b, rtol): + _assert_not_almost_equal_both(a, b, rtol=rtol) + _assert_not_almost_equal_both(np.complex64(a), np.complex64(b), rtol=rtol) + _assert_not_almost_equal_both(np.complex128(a), np.complex128(b), rtol=rtol) + + +@pytest.mark.parametrize("a,b", [(0, 0), (0, 0.0), (0, np.float64(0)), (0.00000001, 0)]) +def test_assert_almost_equal_numbers_with_zeros(a, b): + _assert_almost_equal_both(a, b) + + +@pytest.mark.parametrize("a,b", [(0.001, 0), (1, 0)]) +def test_assert_not_almost_equal_numbers_with_zeros(a, b): + _assert_not_almost_equal_both(a, b) + + +@pytest.mark.parametrize("a,b", [(1, "abc"), (1, [1]), (1, object())]) +def test_assert_not_almost_equal_numbers_with_mixed(a, b): + _assert_not_almost_equal_both(a, b) + + +@pytest.mark.parametrize( + "left_dtype", ["M8[ns]", "m8[ns]", "float64", "int64", "object"] +) +@pytest.mark.parametrize( + "right_dtype", ["M8[ns]", "m8[ns]", "float64", "int64", "object"] +) +def test_assert_almost_equal_edge_case_ndarrays(left_dtype, right_dtype): + # Empty compare. + _assert_almost_equal_both( + np.array([], dtype=left_dtype), + np.array([], dtype=right_dtype), + check_dtype=False, + ) + + +def test_assert_almost_equal_sets(): + # GH#51727 + _assert_almost_equal_both({1, 2, 3}, {1, 2, 3}) + + +def test_assert_almost_not_equal_sets(): + # GH#51727 + msg = r"{1, 2, 3} != {1, 2, 4}" + with pytest.raises(AssertionError, match=msg): + _assert_almost_equal_both({1, 2, 3}, {1, 2, 4}) + + +def test_assert_almost_equal_dicts(): + _assert_almost_equal_both({"a": 1, "b": 2}, {"a": 1, "b": 2}) + + +@pytest.mark.parametrize( + "a,b", + [ + ({"a": 1, "b": 2}, {"a": 1, "b": 3}), + ({"a": 1, "b": 2}, {"a": 1, "b": 2, "c": 3}), + ({"a": 1}, 1), + ({"a": 1}, "abc"), + ({"a": 1}, [1]), + ], +) +def test_assert_not_almost_equal_dicts(a, b): + _assert_not_almost_equal_both(a, b) + + +@pytest.mark.parametrize("val", [1, 2]) +def test_assert_almost_equal_dict_like_object(val): + dict_val = 1 + real_dict = {"a": val} + + class DictLikeObj: + def keys(self): + return ("a",) + + def __getitem__(self, item): + if item == "a": + return dict_val + + func = ( + _assert_almost_equal_both if val == dict_val else _assert_not_almost_equal_both + ) + func(real_dict, DictLikeObj(), check_dtype=False) + + +def test_assert_almost_equal_strings(): + _assert_almost_equal_both("abc", "abc") + + +@pytest.mark.parametrize( + "a,b", [("abc", "abcd"), ("abc", "abd"), ("abc", 1), ("abc", [1])] +) +def test_assert_not_almost_equal_strings(a, b): + _assert_not_almost_equal_both(a, b) + + +@pytest.mark.parametrize( + "a,b", [([1, 2, 3], [1, 2, 3]), (np.array([1, 2, 3]), np.array([1, 2, 3]))] +) +def test_assert_almost_equal_iterables(a, b): + _assert_almost_equal_both(a, b) + + +@pytest.mark.parametrize( + "a,b", + [ + # Class is different. + (np.array([1, 2, 3]), [1, 2, 3]), + # Dtype is different. + (np.array([1, 2, 3]), np.array([1.0, 2.0, 3.0])), + # Can't compare generators. + (iter([1, 2, 3]), [1, 2, 3]), + ([1, 2, 3], [1, 2, 4]), + ([1, 2, 3], [1, 2, 3, 4]), + ([1, 2, 3], 1), + ], +) +def test_assert_not_almost_equal_iterables(a, b): + _assert_not_almost_equal(a, b) + + +def test_assert_almost_equal_null(): + _assert_almost_equal_both(None, None) + + +@pytest.mark.parametrize("a,b", [(None, np.nan), (None, 0), (np.nan, 0)]) +def test_assert_not_almost_equal_null(a, b): + _assert_not_almost_equal(a, b) + + +@pytest.mark.parametrize( + "a,b", + [ + (np.inf, np.inf), + (np.inf, float("inf")), + (np.array([np.inf, np.nan, -np.inf]), np.array([np.inf, np.nan, -np.inf])), + ], +) +def test_assert_almost_equal_inf(a, b): + _assert_almost_equal_both(a, b) + + +objs = [NA, np.nan, NaT, None, np.datetime64("NaT"), np.timedelta64("NaT")] + + +@pytest.mark.parametrize("left", objs) +@pytest.mark.parametrize("right", objs) +def test_mismatched_na_assert_almost_equal_deprecation(left, right): + left_arr = np.array([left], dtype=object) + right_arr = np.array([right], dtype=object) + + msg = "Mismatched null-like values" + + if left is right: + _assert_almost_equal_both(left, right, check_dtype=False) + tm.assert_numpy_array_equal(left_arr, right_arr) + tm.assert_index_equal( + Index(left_arr, dtype=object), Index(right_arr, dtype=object) + ) + tm.assert_series_equal( + Series(left_arr, dtype=object), Series(right_arr, dtype=object) + ) + tm.assert_frame_equal( + DataFrame(left_arr, dtype=object), DataFrame(right_arr, dtype=object) + ) + + else: + with tm.assert_produces_warning(FutureWarning, match=msg): + _assert_almost_equal_both(left, right, check_dtype=False) + + # TODO: to get the same deprecation in assert_numpy_array_equal we need + # to change/deprecate the default for strict_nan to become True + # TODO: to get the same deprecation in assert_index_equal we need to + # change/deprecate array_equivalent_object to be stricter, as + # assert_index_equal uses Index.equal which uses array_equivalent. + with tm.assert_produces_warning(FutureWarning, match=msg): + tm.assert_series_equal( + Series(left_arr, dtype=object), Series(right_arr, dtype=object) + ) + with tm.assert_produces_warning(FutureWarning, match=msg): + tm.assert_frame_equal( + DataFrame(left_arr, dtype=object), DataFrame(right_arr, dtype=object) + ) + + +def test_assert_not_almost_equal_inf(): + _assert_not_almost_equal_both(np.inf, 0) + + +@pytest.mark.parametrize( + "a,b", + [ + (Index([1.0, 1.1]), Index([1.0, 1.100001])), + (Series([1.0, 1.1]), Series([1.0, 1.100001])), + (np.array([1.1, 2.000001]), np.array([1.1, 2.0])), + (DataFrame({"a": [1.0, 1.1]}), DataFrame({"a": [1.0, 1.100001]})), + ], +) +def test_assert_almost_equal_pandas(a, b): + _assert_almost_equal_both(a, b) + + +def test_assert_almost_equal_object(): + a = [Timestamp("2011-01-01"), Timestamp("2011-01-01")] + b = [Timestamp("2011-01-01"), Timestamp("2011-01-01")] + _assert_almost_equal_both(a, b) + + +def test_assert_almost_equal_value_mismatch(): + msg = "expected 2\\.00000 but got 1\\.00000, with rtol=1e-05, atol=1e-08" + + with pytest.raises(AssertionError, match=msg): + tm.assert_almost_equal(1, 2) + + +@pytest.mark.parametrize( + "a,b,klass1,klass2", + [(np.array([1]), 1, "ndarray", "int"), (1, np.array([1]), "int", "ndarray")], +) +def test_assert_almost_equal_class_mismatch(a, b, klass1, klass2): + msg = f"""numpy array are different + +numpy array classes are different +\\[left\\]: {klass1} +\\[right\\]: {klass2}""" + + with pytest.raises(AssertionError, match=msg): + tm.assert_almost_equal(a, b) + + +def test_assert_almost_equal_value_mismatch1(): + msg = """numpy array are different + +numpy array values are different \\(66\\.66667 %\\) +\\[left\\]: \\[nan, 2\\.0, 3\\.0\\] +\\[right\\]: \\[1\\.0, nan, 3\\.0\\]""" + + with pytest.raises(AssertionError, match=msg): + tm.assert_almost_equal(np.array([np.nan, 2, 3]), np.array([1, np.nan, 3])) + + +def test_assert_almost_equal_value_mismatch2(): + msg = """numpy array are different + +numpy array values are different \\(50\\.0 %\\) +\\[left\\]: \\[1, 2\\] +\\[right\\]: \\[1, 3\\]""" + + with pytest.raises(AssertionError, match=msg): + tm.assert_almost_equal(np.array([1, 2]), np.array([1, 3])) + + +def test_assert_almost_equal_value_mismatch3(): + msg = """numpy array are different + +numpy array values are different \\(16\\.66667 %\\) +\\[left\\]: \\[\\[1, 2\\], \\[3, 4\\], \\[5, 6\\]\\] +\\[right\\]: \\[\\[1, 3\\], \\[3, 4\\], \\[5, 6\\]\\]""" + + with pytest.raises(AssertionError, match=msg): + tm.assert_almost_equal( + np.array([[1, 2], [3, 4], [5, 6]]), np.array([[1, 3], [3, 4], [5, 6]]) + ) + + +def test_assert_almost_equal_value_mismatch4(): + msg = """numpy array are different + +numpy array values are different \\(25\\.0 %\\) +\\[left\\]: \\[\\[1, 2\\], \\[3, 4\\]\\] +\\[right\\]: \\[\\[1, 3\\], \\[3, 4\\]\\]""" + + with pytest.raises(AssertionError, match=msg): + tm.assert_almost_equal(np.array([[1, 2], [3, 4]]), np.array([[1, 3], [3, 4]])) + + +def test_assert_almost_equal_shape_mismatch_override(): + msg = """Index are different + +Index shapes are different +\\[left\\]: \\(2L*,\\) +\\[right\\]: \\(3L*,\\)""" + with pytest.raises(AssertionError, match=msg): + tm.assert_almost_equal(np.array([1, 2]), np.array([3, 4, 5]), obj="Index") + + +def test_assert_almost_equal_unicode(): + # see gh-20503 + msg = """numpy array are different + +numpy array values are different \\(33\\.33333 %\\) +\\[left\\]: \\[á, à, ä\\] +\\[right\\]: \\[á, à, å\\]""" + + with pytest.raises(AssertionError, match=msg): + tm.assert_almost_equal(np.array(["á", "à", "ä"]), np.array(["á", "à", "å"])) + + +def test_assert_almost_equal_timestamp(): + a = np.array([Timestamp("2011-01-01"), Timestamp("2011-01-01")]) + b = np.array([Timestamp("2011-01-01"), Timestamp("2011-01-02")]) + + msg = """numpy array are different + +numpy array values are different \\(50\\.0 %\\) +\\[left\\]: \\[2011-01-01 00:00:00, 2011-01-01 00:00:00\\] +\\[right\\]: \\[2011-01-01 00:00:00, 2011-01-02 00:00:00\\]""" + + with pytest.raises(AssertionError, match=msg): + tm.assert_almost_equal(a, b) + + +def test_assert_almost_equal_iterable_length_mismatch(): + msg = """Iterable are different + +Iterable length are different +\\[left\\]: 2 +\\[right\\]: 3""" + + with pytest.raises(AssertionError, match=msg): + tm.assert_almost_equal([1, 2], [3, 4, 5]) + + +def test_assert_almost_equal_iterable_values_mismatch(): + msg = """Iterable are different + +Iterable values are different \\(50\\.0 %\\) +\\[left\\]: \\[1, 2\\] +\\[right\\]: \\[1, 3\\]""" + + with pytest.raises(AssertionError, match=msg): + tm.assert_almost_equal([1, 2], [1, 3]) + + +subarr = np.empty(2, dtype=object) +subarr[:] = [np.array([None, "b"], dtype=object), np.array(["c", "d"], dtype=object)] + +NESTED_CASES = [ + # nested array + ( + np.array([np.array([50, 70, 90]), np.array([20, 30])], dtype=object), + np.array([np.array([50, 70, 90]), np.array([20, 30])], dtype=object), + ), + # >1 level of nesting + ( + np.array( + [ + np.array([np.array([50, 70]), np.array([90])], dtype=object), + np.array([np.array([20, 30])], dtype=object), + ], + dtype=object, + ), + np.array( + [ + np.array([np.array([50, 70]), np.array([90])], dtype=object), + np.array([np.array([20, 30])], dtype=object), + ], + dtype=object, + ), + ), + # lists + ( + np.array([[50, 70, 90], [20, 30]], dtype=object), + np.array([[50, 70, 90], [20, 30]], dtype=object), + ), + # mixed array/list + ( + np.array([np.array([1, 2, 3]), np.array([4, 5])], dtype=object), + np.array([[1, 2, 3], [4, 5]], dtype=object), + ), + ( + np.array( + [ + np.array([np.array([1, 2, 3]), np.array([4, 5])], dtype=object), + np.array( + [np.array([6]), np.array([7, 8]), np.array([9])], dtype=object + ), + ], + dtype=object, + ), + np.array([[[1, 2, 3], [4, 5]], [[6], [7, 8], [9]]], dtype=object), + ), + # same-length lists + ( + np.array([subarr, None], dtype=object), + np.array([[[None, "b"], ["c", "d"]], None], dtype=object), + ), + # dicts + ( + np.array([{"f1": 1, "f2": np.array(["a", "b"], dtype=object)}], dtype=object), + np.array([{"f1": 1, "f2": np.array(["a", "b"], dtype=object)}], dtype=object), + ), + ( + np.array([{"f1": 1, "f2": np.array(["a", "b"], dtype=object)}], dtype=object), + np.array([{"f1": 1, "f2": ["a", "b"]}], dtype=object), + ), + # array/list of dicts + ( + np.array( + [ + np.array( + [{"f1": 1, "f2": np.array(["a", "b"], dtype=object)}], dtype=object + ), + np.array([], dtype=object), + ], + dtype=object, + ), + np.array([[{"f1": 1, "f2": ["a", "b"]}], []], dtype=object), + ), +] + + +@pytest.mark.filterwarnings("ignore:elementwise comparison failed:DeprecationWarning") +@pytest.mark.parametrize("a,b", NESTED_CASES) +def test_assert_almost_equal_array_nested(a, b): + _assert_almost_equal_both(a, b) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_assert_attr_equal.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_assert_attr_equal.py new file mode 100644 index 0000000000000000000000000000000000000000..bbbb0bf2172b12f93c9f0f6a97751854d1566a99 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_assert_attr_equal.py @@ -0,0 +1,33 @@ +from types import SimpleNamespace + +import pytest + +from pandas.core.dtypes.common import is_float + +import pandas._testing as tm + + +def test_assert_attr_equal(nulls_fixture): + obj = SimpleNamespace() + obj.na_value = nulls_fixture + tm.assert_attr_equal("na_value", obj, obj) + + +def test_assert_attr_equal_different_nulls(nulls_fixture, nulls_fixture2): + obj = SimpleNamespace() + obj.na_value = nulls_fixture + + obj2 = SimpleNamespace() + obj2.na_value = nulls_fixture2 + + if nulls_fixture is nulls_fixture2: + tm.assert_attr_equal("na_value", obj, obj2) + elif is_float(nulls_fixture) and is_float(nulls_fixture2): + # we consider float("nan") and np.float64("nan") to be equivalent + tm.assert_attr_equal("na_value", obj, obj2) + elif type(nulls_fixture) is type(nulls_fixture2): + # e.g. Decimal("NaN") + tm.assert_attr_equal("na_value", obj, obj2) + else: + with pytest.raises(AssertionError, match='"na_value" are different'): + tm.assert_attr_equal("na_value", obj, obj2) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_assert_categorical_equal.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_assert_categorical_equal.py new file mode 100644 index 0000000000000000000000000000000000000000..d07bbcbc460a19ec943c1f8727e25835803cf0e4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_assert_categorical_equal.py @@ -0,0 +1,90 @@ +import pytest + +from pandas import Categorical +import pandas._testing as tm + + +@pytest.mark.parametrize( + "c", + [Categorical([1, 2, 3, 4]), Categorical([1, 2, 3, 4], categories=[1, 2, 3, 4, 5])], +) +def test_categorical_equal(c): + tm.assert_categorical_equal(c, c) + + +@pytest.mark.parametrize("check_category_order", [True, False]) +def test_categorical_equal_order_mismatch(check_category_order): + c1 = Categorical([1, 2, 3, 4], categories=[1, 2, 3, 4]) + c2 = Categorical([1, 2, 3, 4], categories=[4, 3, 2, 1]) + kwargs = {"check_category_order": check_category_order} + + if check_category_order: + msg = """Categorical\\.categories are different + +Categorical\\.categories values are different \\(100\\.0 %\\) +\\[left\\]: Index\\(\\[1, 2, 3, 4\\], dtype='int64'\\) +\\[right\\]: Index\\(\\[4, 3, 2, 1\\], dtype='int64'\\)""" + with pytest.raises(AssertionError, match=msg): + tm.assert_categorical_equal(c1, c2, **kwargs) + else: + tm.assert_categorical_equal(c1, c2, **kwargs) + + +def test_categorical_equal_categories_mismatch(): + msg = """Categorical\\.categories are different + +Categorical\\.categories values are different \\(25\\.0 %\\) +\\[left\\]: Index\\(\\[1, 2, 3, 4\\], dtype='int64'\\) +\\[right\\]: Index\\(\\[1, 2, 3, 5\\], dtype='int64'\\)""" + + c1 = Categorical([1, 2, 3, 4]) + c2 = Categorical([1, 2, 3, 5]) + + with pytest.raises(AssertionError, match=msg): + tm.assert_categorical_equal(c1, c2) + + +def test_categorical_equal_codes_mismatch(): + categories = [1, 2, 3, 4] + msg = """Categorical\\.codes are different + +Categorical\\.codes values are different \\(50\\.0 %\\) +\\[left\\]: \\[0, 1, 3, 2\\] +\\[right\\]: \\[0, 1, 2, 3\\]""" + + c1 = Categorical([1, 2, 4, 3], categories=categories) + c2 = Categorical([1, 2, 3, 4], categories=categories) + + with pytest.raises(AssertionError, match=msg): + tm.assert_categorical_equal(c1, c2) + + +def test_categorical_equal_ordered_mismatch(): + data = [1, 2, 3, 4] + msg = """Categorical are different + +Attribute "ordered" are different +\\[left\\]: False +\\[right\\]: True""" + + c1 = Categorical(data, ordered=False) + c2 = Categorical(data, ordered=True) + + with pytest.raises(AssertionError, match=msg): + tm.assert_categorical_equal(c1, c2) + + +@pytest.mark.parametrize("obj", ["index", "foo", "pandas"]) +def test_categorical_equal_object_override(obj): + data = [1, 2, 3, 4] + msg = f"""{obj} are different + +Attribute "ordered" are different +\\[left\\]: False +\\[right\\]: True""" + + c1 = Categorical(data, ordered=False) + c2 = Categorical(data, ordered=True) + + with pytest.raises(AssertionError, match=msg): + tm.assert_categorical_equal(c1, c2, obj=obj) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_assert_extension_array_equal.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_assert_extension_array_equal.py new file mode 100644 index 0000000000000000000000000000000000000000..674e9307d8bb982826f7c92da798ba8d1eee9fde --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_assert_extension_array_equal.py @@ -0,0 +1,126 @@ +import numpy as np +import pytest + +from pandas import ( + Timestamp, + array, +) +import pandas._testing as tm +from pandas.core.arrays.sparse import SparseArray + + +@pytest.mark.parametrize( + "kwargs", + [ + {}, # Default is check_exact=False + {"check_exact": False}, + {"check_exact": True}, + ], +) +def test_assert_extension_array_equal_not_exact(kwargs): + # see gh-23709 + arr1 = SparseArray([-0.17387645482451206, 0.3414148016424936]) + arr2 = SparseArray([-0.17387645482451206, 0.3414148016424937]) + + if kwargs.get("check_exact", False): + msg = """\ +ExtensionArray are different + +ExtensionArray values are different \\(50\\.0 %\\) +\\[left\\]: \\[-0\\.17387645482.*, 0\\.341414801642.*\\] +\\[right\\]: \\[-0\\.17387645482.*, 0\\.341414801642.*\\]""" + + with pytest.raises(AssertionError, match=msg): + tm.assert_extension_array_equal(arr1, arr2, **kwargs) + else: + tm.assert_extension_array_equal(arr1, arr2, **kwargs) + + +@pytest.mark.parametrize("decimals", range(10)) +def test_assert_extension_array_equal_less_precise(decimals): + rtol = 0.5 * 10**-decimals + arr1 = SparseArray([0.5, 0.123456]) + arr2 = SparseArray([0.5, 0.123457]) + + if decimals >= 5: + msg = """\ +ExtensionArray are different + +ExtensionArray values are different \\(50\\.0 %\\) +\\[left\\]: \\[0\\.5, 0\\.123456\\] +\\[right\\]: \\[0\\.5, 0\\.123457\\]""" + + with pytest.raises(AssertionError, match=msg): + tm.assert_extension_array_equal(arr1, arr2, rtol=rtol) + else: + tm.assert_extension_array_equal(arr1, arr2, rtol=rtol) + + +def test_assert_extension_array_equal_dtype_mismatch(check_dtype): + end = 5 + kwargs = {"check_dtype": check_dtype} + + arr1 = SparseArray(np.arange(end, dtype="int64")) + arr2 = SparseArray(np.arange(end, dtype="int32")) + + if check_dtype: + msg = """\ +ExtensionArray are different + +Attribute "dtype" are different +\\[left\\]: Sparse\\[int64, 0\\] +\\[right\\]: Sparse\\[int32, 0\\]""" + + with pytest.raises(AssertionError, match=msg): + tm.assert_extension_array_equal(arr1, arr2, **kwargs) + else: + tm.assert_extension_array_equal(arr1, arr2, **kwargs) + + +def test_assert_extension_array_equal_missing_values(): + arr1 = SparseArray([np.nan, 1, 2, np.nan]) + arr2 = SparseArray([np.nan, 1, 2, 3]) + + msg = """\ +ExtensionArray NA mask are different + +ExtensionArray NA mask values are different \\(25\\.0 %\\) +\\[left\\]: \\[True, False, False, True\\] +\\[right\\]: \\[True, False, False, False\\]""" + + with pytest.raises(AssertionError, match=msg): + tm.assert_extension_array_equal(arr1, arr2) + + +@pytest.mark.parametrize("side", ["left", "right"]) +def test_assert_extension_array_equal_non_extension_array(side): + numpy_array = np.arange(5) + extension_array = SparseArray(numpy_array) + + msg = f"{side} is not an ExtensionArray" + args = ( + (numpy_array, extension_array) + if side == "left" + else (extension_array, numpy_array) + ) + + with pytest.raises(AssertionError, match=msg): + tm.assert_extension_array_equal(*args) + + +@pytest.mark.parametrize("right_dtype", ["Int32", "int64"]) +def test_assert_extension_array_equal_ignore_dtype_mismatch(right_dtype): + # https://github.com/pandas-dev/pandas/issues/35715 + left = array([1, 2, 3], dtype="Int64") + right = array([1, 2, 3], dtype=right_dtype) + tm.assert_extension_array_equal(left, right, check_dtype=False) + + +def test_assert_extension_array_equal_time_units(): + # https://github.com/pandas-dev/pandas/issues/55730 + timestamp = Timestamp("2023-11-04T12") + naive = array([timestamp], dtype="datetime64[ns]") + utc = array([timestamp], dtype="datetime64[ns, UTC]") + + tm.assert_extension_array_equal(naive, utc, check_dtype=False) + tm.assert_extension_array_equal(utc, naive, check_dtype=False) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_assert_frame_equal.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_assert_frame_equal.py new file mode 100644 index 0000000000000000000000000000000000000000..79132591b15b3d58781c70c1a0ac5ae77c213521 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_assert_frame_equal.py @@ -0,0 +1,393 @@ +import pytest + +import pandas as pd +from pandas import DataFrame +import pandas._testing as tm + + +@pytest.fixture(params=[True, False]) +def by_blocks_fixture(request): + return request.param + + +@pytest.fixture(params=["DataFrame", "Series"]) +def obj_fixture(request): + return request.param + + +def _assert_frame_equal_both(a, b, **kwargs): + """ + Check that two DataFrame equal. + + This check is performed commutatively. + + Parameters + ---------- + a : DataFrame + The first DataFrame to compare. + b : DataFrame + The second DataFrame to compare. + kwargs : dict + The arguments passed to `tm.assert_frame_equal`. + """ + tm.assert_frame_equal(a, b, **kwargs) + tm.assert_frame_equal(b, a, **kwargs) + + +@pytest.mark.parametrize("check_like", [True, False]) +def test_frame_equal_row_order_mismatch(check_like, obj_fixture): + df1 = DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}, index=["a", "b", "c"]) + df2 = DataFrame({"A": [3, 2, 1], "B": [6, 5, 4]}, index=["c", "b", "a"]) + + if not check_like: # Do not ignore row-column orderings. + msg = f"{obj_fixture}.index are different" + with pytest.raises(AssertionError, match=msg): + tm.assert_frame_equal(df1, df2, check_like=check_like, obj=obj_fixture) + else: + _assert_frame_equal_both(df1, df2, check_like=check_like, obj=obj_fixture) + + +@pytest.mark.parametrize( + "df1,df2", + [ + (DataFrame({"A": [1, 2, 3]}), DataFrame({"A": [1, 2, 3, 4]})), + (DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}), DataFrame({"A": [1, 2, 3]})), + ], +) +def test_frame_equal_shape_mismatch(df1, df2, obj_fixture): + msg = f"{obj_fixture} are different" + + with pytest.raises(AssertionError, match=msg): + tm.assert_frame_equal(df1, df2, obj=obj_fixture) + + +@pytest.mark.parametrize( + "df1,df2,msg", + [ + # Index + ( + DataFrame.from_records({"a": [1, 2], "c": ["l1", "l2"]}, index=["a"]), + DataFrame.from_records({"a": [1.0, 2.0], "c": ["l1", "l2"]}, index=["a"]), + "DataFrame\\.index are different", + ), + # MultiIndex + ( + DataFrame.from_records( + {"a": [1, 2], "b": [2.1, 1.5], "c": ["l1", "l2"]}, index=["a", "b"] + ), + DataFrame.from_records( + {"a": [1.0, 2.0], "b": [2.1, 1.5], "c": ["l1", "l2"]}, index=["a", "b"] + ), + "MultiIndex level \\[0\\] are different", + ), + ], +) +def test_frame_equal_index_dtype_mismatch(df1, df2, msg, check_index_type): + kwargs = {"check_index_type": check_index_type} + + if check_index_type: + with pytest.raises(AssertionError, match=msg): + tm.assert_frame_equal(df1, df2, **kwargs) + else: + tm.assert_frame_equal(df1, df2, **kwargs) + + +def test_empty_dtypes(check_dtype): + columns = ["col1", "col2"] + df1 = DataFrame(columns=columns) + df2 = DataFrame(columns=columns) + + kwargs = {"check_dtype": check_dtype} + df1["col1"] = df1["col1"].astype("int64") + + if check_dtype: + msg = r"Attributes of DataFrame\..* are different" + with pytest.raises(AssertionError, match=msg): + tm.assert_frame_equal(df1, df2, **kwargs) + else: + tm.assert_frame_equal(df1, df2, **kwargs) + + +@pytest.mark.parametrize("check_like", [True, False]) +def test_frame_equal_index_mismatch(check_like, obj_fixture, using_infer_string): + if using_infer_string: + dtype = "string" + else: + dtype = "object" + msg = f"""{obj_fixture}\\.index are different + +{obj_fixture}\\.index values are different \\(33\\.33333 %\\) +\\[left\\]: Index\\(\\['a', 'b', 'c'\\], dtype='{dtype}'\\) +\\[right\\]: Index\\(\\['a', 'b', 'd'\\], dtype='{dtype}'\\) +At positional index 2, first diff: c != d""" + + df1 = DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}, index=["a", "b", "c"]) + df2 = DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}, index=["a", "b", "d"]) + + with pytest.raises(AssertionError, match=msg): + tm.assert_frame_equal(df1, df2, check_like=check_like, obj=obj_fixture) + + +@pytest.mark.parametrize("check_like", [True, False]) +def test_frame_equal_columns_mismatch(check_like, obj_fixture, using_infer_string): + if using_infer_string: + dtype = "string" + else: + dtype = "object" + msg = f"""{obj_fixture}\\.columns are different + +{obj_fixture}\\.columns values are different \\(50\\.0 %\\) +\\[left\\]: Index\\(\\['A', 'B'\\], dtype='{dtype}'\\) +\\[right\\]: Index\\(\\['A', 'b'\\], dtype='{dtype}'\\)""" + + df1 = DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}, index=["a", "b", "c"]) + df2 = DataFrame({"A": [1, 2, 3], "b": [4, 5, 6]}, index=["a", "b", "c"]) + + with pytest.raises(AssertionError, match=msg): + tm.assert_frame_equal(df1, df2, check_like=check_like, obj=obj_fixture) + + +def test_frame_equal_block_mismatch(by_blocks_fixture, obj_fixture): + obj = obj_fixture + msg = f"""{obj}\\.iloc\\[:, 1\\] \\(column name="B"\\) are different + +{obj}\\.iloc\\[:, 1\\] \\(column name="B"\\) values are different \\(33\\.33333 %\\) +\\[index\\]: \\[0, 1, 2\\] +\\[left\\]: \\[4, 5, 6\\] +\\[right\\]: \\[4, 5, 7\\]""" + + df1 = DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) + df2 = DataFrame({"A": [1, 2, 3], "B": [4, 5, 7]}) + + with pytest.raises(AssertionError, match=msg): + tm.assert_frame_equal(df1, df2, by_blocks=by_blocks_fixture, obj=obj_fixture) + + +@pytest.mark.parametrize( + "df1,df2,msg", + [ + ( + DataFrame({"A": ["á", "à", "ä"], "E": ["é", "è", "ë"]}), + DataFrame({"A": ["á", "à", "ä"], "E": ["é", "è", "e̊"]}), + """{obj}\\.iloc\\[:, 1\\] \\(column name="E"\\) are different + +{obj}\\.iloc\\[:, 1\\] \\(column name="E"\\) values are different \\(33\\.33333 %\\) +\\[index\\]: \\[0, 1, 2\\] +\\[left\\]: \\[é, è, ë\\] +\\[right\\]: \\[é, è, e̊\\]""", + ), + ( + DataFrame({"A": ["á", "à", "ä"], "E": ["é", "è", "ë"]}), + DataFrame({"A": ["a", "a", "a"], "E": ["e", "e", "e"]}), + """{obj}\\.iloc\\[:, 0\\] \\(column name="A"\\) are different + +{obj}\\.iloc\\[:, 0\\] \\(column name="A"\\) values are different \\(100\\.0 %\\) +\\[index\\]: \\[0, 1, 2\\] +\\[left\\]: \\[á, à, ä\\] +\\[right\\]: \\[a, a, a\\]""", + ), + ], +) +def test_frame_equal_unicode(df1, df2, msg, by_blocks_fixture, obj_fixture): + # see gh-20503 + # + # Test ensures that `tm.assert_frame_equals` raises the right exception + # when comparing DataFrames containing differing unicode objects. + msg = msg.format(obj=obj_fixture) + with pytest.raises(AssertionError, match=msg): + tm.assert_frame_equal(df1, df2, by_blocks=by_blocks_fixture, obj=obj_fixture) + + +def test_assert_frame_equal_extension_dtype_mismatch(): + # https://github.com/pandas-dev/pandas/issues/32747 + left = DataFrame({"a": [1, 2, 3]}, dtype="Int64") + right = left.astype(int) + + msg = ( + "Attributes of DataFrame\\.iloc\\[:, 0\\] " + '\\(column name="a"\\) are different\n\n' + 'Attribute "dtype" are different\n' + "\\[left\\]: Int64\n" + "\\[right\\]: int[32|64]" + ) + + tm.assert_frame_equal(left, right, check_dtype=False) + + with pytest.raises(AssertionError, match=msg): + tm.assert_frame_equal(left, right, check_dtype=True) + + +def test_assert_frame_equal_interval_dtype_mismatch(): + # https://github.com/pandas-dev/pandas/issues/32747 + left = DataFrame({"a": [pd.Interval(0, 1)]}, dtype="interval") + right = left.astype(object) + + msg = ( + "Attributes of DataFrame\\.iloc\\[:, 0\\] " + '\\(column name="a"\\) are different\n\n' + 'Attribute "dtype" are different\n' + "\\[left\\]: interval\\[int64, right\\]\n" + "\\[right\\]: object" + ) + + tm.assert_frame_equal(left, right, check_dtype=False) + + with pytest.raises(AssertionError, match=msg): + tm.assert_frame_equal(left, right, check_dtype=True) + + +def test_assert_frame_equal_ignore_extension_dtype_mismatch(): + # https://github.com/pandas-dev/pandas/issues/35715 + left = DataFrame({"a": [1, 2, 3]}, dtype="Int64") + right = DataFrame({"a": [1, 2, 3]}, dtype="Int32") + tm.assert_frame_equal(left, right, check_dtype=False) + + +def test_assert_frame_equal_ignore_extension_dtype_mismatch_cross_class(): + # https://github.com/pandas-dev/pandas/issues/35715 + left = DataFrame({"a": [1, 2, 3]}, dtype="Int64") + right = DataFrame({"a": [1, 2, 3]}, dtype="int64") + tm.assert_frame_equal(left, right, check_dtype=False) + + +@pytest.mark.parametrize( + "dtype", + [ + ("timedelta64[ns]"), + ("datetime64[ns, UTC]"), + ("Period[D]"), + ], +) +def test_assert_frame_equal_datetime_like_dtype_mismatch(dtype): + df1 = DataFrame({"a": []}, dtype=dtype) + df2 = DataFrame({"a": []}) + tm.assert_frame_equal(df1, df2, check_dtype=False) + + +def test_allows_duplicate_labels(): + left = DataFrame() + right = DataFrame().set_flags(allows_duplicate_labels=False) + tm.assert_frame_equal(left, left) + tm.assert_frame_equal(right, right) + tm.assert_frame_equal(left, right, check_flags=False) + tm.assert_frame_equal(right, left, check_flags=False) + + with pytest.raises(AssertionError, match="\\]""" + + with pytest.raises(AssertionError, match=msg): + tm.assert_numpy_array_equal(a, b) + + +def test_numpy_array_equal_identical_na(nulls_fixture): + a = np.array([nulls_fixture], dtype=object) + + tm.assert_numpy_array_equal(a, a) + + # matching but not the identical object + if hasattr(nulls_fixture, "copy"): + other = nulls_fixture.copy() + else: + other = copy.copy(nulls_fixture) + b = np.array([other], dtype=object) + tm.assert_numpy_array_equal(a, b) + + +def test_numpy_array_equal_different_na(): + a = np.array([np.nan], dtype=object) + b = np.array([pd.NA], dtype=object) + + msg = """numpy array are different + +numpy array values are different \\(100.0 %\\) +\\[left\\]: \\[nan\\] +\\[right\\]: \\[\\]""" + + with pytest.raises(AssertionError, match=msg): + tm.assert_numpy_array_equal(a, b) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_assert_produces_warning.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_assert_produces_warning.py new file mode 100644 index 0000000000000000000000000000000000000000..5c27a3ee79d4a82bce83eec56ab9d88e10dc06cd --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_assert_produces_warning.py @@ -0,0 +1,241 @@ +"""" +Test module for testing ``pandas._testing.assert_produces_warning``. +""" +import warnings + +import pytest + +from pandas.errors import ( + DtypeWarning, + PerformanceWarning, +) + +import pandas._testing as tm + + +@pytest.fixture( + params=[ + RuntimeWarning, + ResourceWarning, + UserWarning, + FutureWarning, + DeprecationWarning, + PerformanceWarning, + DtypeWarning, + ], +) +def category(request): + """ + Return unique warning. + + Useful for testing behavior of tm.assert_produces_warning with various categories. + """ + return request.param + + +@pytest.fixture( + params=[ + (RuntimeWarning, UserWarning), + (UserWarning, FutureWarning), + (FutureWarning, RuntimeWarning), + (DeprecationWarning, PerformanceWarning), + (PerformanceWarning, FutureWarning), + (DtypeWarning, DeprecationWarning), + (ResourceWarning, DeprecationWarning), + (FutureWarning, DeprecationWarning), + ], + ids=lambda x: type(x).__name__, +) +def pair_different_warnings(request): + """ + Return pair or different warnings. + + Useful for testing how several different warnings are handled + in tm.assert_produces_warning. + """ + return request.param + + +def f(): + warnings.warn("f1", FutureWarning) + warnings.warn("f2", RuntimeWarning) + + +@pytest.mark.filterwarnings("ignore:f1:FutureWarning") +def test_assert_produces_warning_honors_filter(): + # Raise by default. + msg = r"Caused unexpected warning\(s\)" + with pytest.raises(AssertionError, match=msg): + with tm.assert_produces_warning(RuntimeWarning): + f() + + with tm.assert_produces_warning(RuntimeWarning, raise_on_extra_warnings=False): + f() + + +@pytest.mark.parametrize( + "message, match", + [ + ("", None), + ("", ""), + ("Warning message", r".*"), + ("Warning message", "War"), + ("Warning message", r"[Ww]arning"), + ("Warning message", "age"), + ("Warning message", r"age$"), + ("Message 12-234 with numbers", r"\d{2}-\d{3}"), + ("Message 12-234 with numbers", r"^Mes.*\d{2}-\d{3}"), + ("Message 12-234 with numbers", r"\d{2}-\d{3}\s\S+"), + ("Message, which we do not match", None), + ], +) +def test_catch_warning_category_and_match(category, message, match): + with tm.assert_produces_warning(category, match=match): + warnings.warn(message, category) + + +def test_fail_to_match_runtime_warning(): + category = RuntimeWarning + match = "Did not see this warning" + unmatched = ( + r"Did not see warning 'RuntimeWarning' matching 'Did not see this warning'. " + r"The emitted warning messages are " + r"\[RuntimeWarning\('This is not a match.'\), " + r"RuntimeWarning\('Another unmatched warning.'\)\]" + ) + with pytest.raises(AssertionError, match=unmatched): + with tm.assert_produces_warning(category, match=match): + warnings.warn("This is not a match.", category) + warnings.warn("Another unmatched warning.", category) + + +def test_fail_to_match_future_warning(): + category = FutureWarning + match = "Warning" + unmatched = ( + r"Did not see warning 'FutureWarning' matching 'Warning'. " + r"The emitted warning messages are " + r"\[FutureWarning\('This is not a match.'\), " + r"FutureWarning\('Another unmatched warning.'\)\]" + ) + with pytest.raises(AssertionError, match=unmatched): + with tm.assert_produces_warning(category, match=match): + warnings.warn("This is not a match.", category) + warnings.warn("Another unmatched warning.", category) + + +def test_fail_to_match_resource_warning(): + category = ResourceWarning + match = r"\d+" + unmatched = ( + r"Did not see warning 'ResourceWarning' matching '\\d\+'. " + r"The emitted warning messages are " + r"\[ResourceWarning\('This is not a match.'\), " + r"ResourceWarning\('Another unmatched warning.'\)\]" + ) + with pytest.raises(AssertionError, match=unmatched): + with tm.assert_produces_warning(category, match=match): + warnings.warn("This is not a match.", category) + warnings.warn("Another unmatched warning.", category) + + +def test_fail_to_catch_actual_warning(pair_different_warnings): + expected_category, actual_category = pair_different_warnings + match = "Did not see expected warning of class" + with pytest.raises(AssertionError, match=match): + with tm.assert_produces_warning(expected_category): + warnings.warn("warning message", actual_category) + + +def test_ignore_extra_warning(pair_different_warnings): + expected_category, extra_category = pair_different_warnings + with tm.assert_produces_warning(expected_category, raise_on_extra_warnings=False): + warnings.warn("Expected warning", expected_category) + warnings.warn("Unexpected warning OK", extra_category) + + +def test_raise_on_extra_warning(pair_different_warnings): + expected_category, extra_category = pair_different_warnings + match = r"Caused unexpected warning\(s\)" + with pytest.raises(AssertionError, match=match): + with tm.assert_produces_warning(expected_category): + warnings.warn("Expected warning", expected_category) + warnings.warn("Unexpected warning NOT OK", extra_category) + + +def test_same_category_different_messages_first_match(): + category = UserWarning + with tm.assert_produces_warning(category, match=r"^Match this"): + warnings.warn("Match this", category) + warnings.warn("Do not match that", category) + warnings.warn("Do not match that either", category) + + +def test_same_category_different_messages_last_match(): + category = DeprecationWarning + with tm.assert_produces_warning(category, match=r"^Match this"): + warnings.warn("Do not match that", category) + warnings.warn("Do not match that either", category) + warnings.warn("Match this", category) + + +def test_match_multiple_warnings(): + # https://github.com/pandas-dev/pandas/issues/47829 + category = (FutureWarning, UserWarning) + with tm.assert_produces_warning(category, match=r"^Match this"): + warnings.warn("Match this", FutureWarning) + warnings.warn("Match this too", UserWarning) + + +def test_right_category_wrong_match_raises(pair_different_warnings): + target_category, other_category = pair_different_warnings + with pytest.raises(AssertionError, match="Did not see warning.*matching"): + with tm.assert_produces_warning(target_category, match=r"^Match this"): + warnings.warn("Do not match it", target_category) + warnings.warn("Match this", other_category) + + +@pytest.mark.parametrize("false_or_none", [False, None]) +class TestFalseOrNoneExpectedWarning: + def test_raise_on_warning(self, false_or_none): + msg = r"Caused unexpected warning\(s\)" + with pytest.raises(AssertionError, match=msg): + with tm.assert_produces_warning(false_or_none): + f() + + def test_no_raise_without_warning(self, false_or_none): + with tm.assert_produces_warning(false_or_none): + pass + + def test_no_raise_with_false_raise_on_extra(self, false_or_none): + with tm.assert_produces_warning(false_or_none, raise_on_extra_warnings=False): + f() + + +def test_raises_during_exception(): + msg = "Did not see expected warning of class 'UserWarning'" + with pytest.raises(AssertionError, match=msg): + with tm.assert_produces_warning(UserWarning): + raise ValueError + + with pytest.raises(AssertionError, match=msg): + with tm.assert_produces_warning(UserWarning): + warnings.warn("FutureWarning", FutureWarning) + raise IndexError + + msg = "Caused unexpected warning" + with pytest.raises(AssertionError, match=msg): + with tm.assert_produces_warning(None): + warnings.warn("FutureWarning", FutureWarning) + raise SystemError + + +def test_passes_during_exception(): + with pytest.raises(SyntaxError, match="Error"): + with tm.assert_produces_warning(None): + raise SyntaxError("Error") + + with pytest.raises(ValueError, match="Error"): + with tm.assert_produces_warning(FutureWarning, match="FutureWarning"): + warnings.warn("FutureWarning", FutureWarning) + raise ValueError("Error") diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_assert_series_equal.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_assert_series_equal.py new file mode 100644 index 0000000000000000000000000000000000000000..1878e7d8380648be2c5c57119ab0a28c45884a05 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_assert_series_equal.py @@ -0,0 +1,484 @@ +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + Categorical, + DataFrame, + Series, +) +import pandas._testing as tm + + +def _assert_series_equal_both(a, b, **kwargs): + """ + Check that two Series equal. + + This check is performed commutatively. + + Parameters + ---------- + a : Series + The first Series to compare. + b : Series + The second Series to compare. + kwargs : dict + The arguments passed to `tm.assert_series_equal`. + """ + tm.assert_series_equal(a, b, **kwargs) + tm.assert_series_equal(b, a, **kwargs) + + +def _assert_not_series_equal(a, b, **kwargs): + """ + Check that two Series are not equal. + + Parameters + ---------- + a : Series + The first Series to compare. + b : Series + The second Series to compare. + kwargs : dict + The arguments passed to `tm.assert_series_equal`. + """ + try: + tm.assert_series_equal(a, b, **kwargs) + msg = "The two Series were equal when they shouldn't have been" + + pytest.fail(msg=msg) + except AssertionError: + pass + + +def _assert_not_series_equal_both(a, b, **kwargs): + """ + Check that two Series are not equal. + + This check is performed commutatively. + + Parameters + ---------- + a : Series + The first Series to compare. + b : Series + The second Series to compare. + kwargs : dict + The arguments passed to `tm.assert_series_equal`. + """ + _assert_not_series_equal(a, b, **kwargs) + _assert_not_series_equal(b, a, **kwargs) + + +@pytest.mark.parametrize("data", [range(3), list("abc"), list("áàä")]) +def test_series_equal(data): + _assert_series_equal_both(Series(data), Series(data)) + + +@pytest.mark.parametrize( + "data1,data2", + [ + (range(3), range(1, 4)), + (list("abc"), list("xyz")), + (list("áàä"), list("éèë")), + (list("áàä"), list(b"aaa")), + (range(3), range(4)), + ], +) +def test_series_not_equal_value_mismatch(data1, data2): + _assert_not_series_equal_both(Series(data1), Series(data2)) + + +@pytest.mark.parametrize( + "kwargs", + [ + {"dtype": "float64"}, # dtype mismatch + {"index": [1, 2, 4]}, # index mismatch + {"name": "foo"}, # name mismatch + ], +) +def test_series_not_equal_metadata_mismatch(kwargs): + data = range(3) + s1 = Series(data) + + s2 = Series(data, **kwargs) + _assert_not_series_equal_both(s1, s2) + + +@pytest.mark.parametrize("data1,data2", [(0.12345, 0.12346), (0.1235, 0.1236)]) +@pytest.mark.parametrize("dtype", ["float32", "float64", "Float32"]) +@pytest.mark.parametrize("decimals", [0, 1, 2, 3, 5, 10]) +def test_less_precise(data1, data2, dtype, decimals): + rtol = 10**-decimals + s1 = Series([data1], dtype=dtype) + s2 = Series([data2], dtype=dtype) + + if decimals in (5, 10) or (decimals >= 3 and abs(data1 - data2) >= 0.0005): + msg = "Series values are different" + with pytest.raises(AssertionError, match=msg): + tm.assert_series_equal(s1, s2, rtol=rtol) + else: + _assert_series_equal_both(s1, s2, rtol=rtol) + + +@pytest.mark.parametrize( + "s1,s2,msg", + [ + # Index + ( + Series(["l1", "l2"], index=[1, 2]), + Series(["l1", "l2"], index=[1.0, 2.0]), + "Series\\.index are different", + ), + # MultiIndex + ( + DataFrame.from_records( + {"a": [1, 2], "b": [2.1, 1.5], "c": ["l1", "l2"]}, index=["a", "b"] + ).c, + DataFrame.from_records( + {"a": [1.0, 2.0], "b": [2.1, 1.5], "c": ["l1", "l2"]}, index=["a", "b"] + ).c, + "MultiIndex level \\[0\\] are different", + ), + ], +) +def test_series_equal_index_dtype(s1, s2, msg, check_index_type): + kwargs = {"check_index_type": check_index_type} + + if check_index_type: + with pytest.raises(AssertionError, match=msg): + tm.assert_series_equal(s1, s2, **kwargs) + else: + tm.assert_series_equal(s1, s2, **kwargs) + + +@pytest.mark.parametrize("check_like", [True, False]) +def test_series_equal_order_mismatch(check_like): + s1 = Series([1, 2, 3], index=["a", "b", "c"]) + s2 = Series([3, 2, 1], index=["c", "b", "a"]) + + if not check_like: # Do not ignore index ordering. + with pytest.raises(AssertionError, match="Series.index are different"): + tm.assert_series_equal(s1, s2, check_like=check_like) + else: + _assert_series_equal_both(s1, s2, check_like=check_like) + + +@pytest.mark.parametrize("check_index", [True, False]) +def test_series_equal_index_mismatch(check_index): + s1 = Series([1, 2, 3], index=["a", "b", "c"]) + s2 = Series([1, 2, 3], index=["c", "b", "a"]) + + if check_index: # Do not ignore index. + with pytest.raises(AssertionError, match="Series.index are different"): + tm.assert_series_equal(s1, s2, check_index=check_index) + else: + _assert_series_equal_both(s1, s2, check_index=check_index) + + +def test_series_invalid_param_combination(): + left = Series(dtype=object) + right = Series(dtype=object) + with pytest.raises( + ValueError, match="check_like must be False if check_index is False" + ): + tm.assert_series_equal(left, right, check_index=False, check_like=True) + + +def test_series_equal_length_mismatch(rtol): + msg = """Series are different + +Series length are different +\\[left\\]: 3, RangeIndex\\(start=0, stop=3, step=1\\) +\\[right\\]: 4, RangeIndex\\(start=0, stop=4, step=1\\)""" + + s1 = Series([1, 2, 3]) + s2 = Series([1, 2, 3, 4]) + + with pytest.raises(AssertionError, match=msg): + tm.assert_series_equal(s1, s2, rtol=rtol) + + +def test_series_equal_numeric_values_mismatch(rtol): + msg = """Series are different + +Series values are different \\(33\\.33333 %\\) +\\[index\\]: \\[0, 1, 2\\] +\\[left\\]: \\[1, 2, 3\\] +\\[right\\]: \\[1, 2, 4\\]""" + + s1 = Series([1, 2, 3]) + s2 = Series([1, 2, 4]) + + with pytest.raises(AssertionError, match=msg): + tm.assert_series_equal(s1, s2, rtol=rtol) + + +def test_series_equal_categorical_values_mismatch(rtol, using_infer_string): + if using_infer_string: + msg = """Series are different + +Series values are different \\(66\\.66667 %\\) +\\[index\\]: \\[0, 1, 2\\] +\\[left\\]: \\['a', 'b', 'c'\\] +Categories \\(3, string\\): \\[a, b, c\\] +\\[right\\]: \\['a', 'c', 'b'\\] +Categories \\(3, string\\): \\[a, b, c\\]""" + else: + msg = """Series are different + +Series values are different \\(66\\.66667 %\\) +\\[index\\]: \\[0, 1, 2\\] +\\[left\\]: \\['a', 'b', 'c'\\] +Categories \\(3, object\\): \\['a', 'b', 'c'\\] +\\[right\\]: \\['a', 'c', 'b'\\] +Categories \\(3, object\\): \\['a', 'b', 'c'\\]""" + + s1 = Series(Categorical(["a", "b", "c"])) + s2 = Series(Categorical(["a", "c", "b"])) + + with pytest.raises(AssertionError, match=msg): + tm.assert_series_equal(s1, s2, rtol=rtol) + + +def test_series_equal_datetime_values_mismatch(rtol): + msg = """Series are different + +Series values are different \\(100.0 %\\) +\\[index\\]: \\[0, 1, 2\\] +\\[left\\]: \\[1514764800000000000, 1514851200000000000, 1514937600000000000\\] +\\[right\\]: \\[1549065600000000000, 1549152000000000000, 1549238400000000000\\]""" + + s1 = Series(pd.date_range("2018-01-01", periods=3, freq="D")) + s2 = Series(pd.date_range("2019-02-02", periods=3, freq="D")) + + with pytest.raises(AssertionError, match=msg): + tm.assert_series_equal(s1, s2, rtol=rtol) + + +def test_series_equal_categorical_mismatch(check_categorical, using_infer_string): + if using_infer_string: + dtype = "string" + else: + dtype = "object" + msg = f"""Attributes of Series are different + +Attribute "dtype" are different +\\[left\\]: CategoricalDtype\\(categories=\\['a', 'b'\\], ordered=False, \ +categories_dtype={dtype}\\) +\\[right\\]: CategoricalDtype\\(categories=\\['a', 'b', 'c'\\], \ +ordered=False, categories_dtype={dtype}\\)""" + + s1 = Series(Categorical(["a", "b"])) + s2 = Series(Categorical(["a", "b"], categories=list("abc"))) + + if check_categorical: + with pytest.raises(AssertionError, match=msg): + tm.assert_series_equal(s1, s2, check_categorical=check_categorical) + else: + _assert_series_equal_both(s1, s2, check_categorical=check_categorical) + + +def test_assert_series_equal_extension_dtype_mismatch(): + # https://github.com/pandas-dev/pandas/issues/32747 + left = Series(pd.array([1, 2, 3], dtype="Int64")) + right = left.astype(int) + + msg = """Attributes of Series are different + +Attribute "dtype" are different +\\[left\\]: Int64 +\\[right\\]: int[32|64]""" + + tm.assert_series_equal(left, right, check_dtype=False) + + with pytest.raises(AssertionError, match=msg): + tm.assert_series_equal(left, right, check_dtype=True) + + +def test_assert_series_equal_interval_dtype_mismatch(): + # https://github.com/pandas-dev/pandas/issues/32747 + left = Series([pd.Interval(0, 1)], dtype="interval") + right = left.astype(object) + + msg = """Attributes of Series are different + +Attribute "dtype" are different +\\[left\\]: interval\\[int64, right\\] +\\[right\\]: object""" + + tm.assert_series_equal(left, right, check_dtype=False) + + with pytest.raises(AssertionError, match=msg): + tm.assert_series_equal(left, right, check_dtype=True) + + +def test_series_equal_series_type(): + class MySeries(Series): + pass + + s1 = Series([1, 2]) + s2 = Series([1, 2]) + s3 = MySeries([1, 2]) + + tm.assert_series_equal(s1, s2, check_series_type=False) + tm.assert_series_equal(s1, s2, check_series_type=True) + + tm.assert_series_equal(s1, s3, check_series_type=False) + tm.assert_series_equal(s3, s1, check_series_type=False) + + with pytest.raises(AssertionError, match="Series classes are different"): + tm.assert_series_equal(s1, s3, check_series_type=True) + + with pytest.raises(AssertionError, match="Series classes are different"): + tm.assert_series_equal(s3, s1, check_series_type=True) + + +def test_series_equal_exact_for_nonnumeric(): + # https://github.com/pandas-dev/pandas/issues/35446 + s1 = Series(["a", "b"]) + s2 = Series(["a", "b"]) + s3 = Series(["b", "a"]) + + tm.assert_series_equal(s1, s2, check_exact=True) + tm.assert_series_equal(s2, s1, check_exact=True) + + msg = """Series are different + +Series values are different \\(100\\.0 %\\) +\\[index\\]: \\[0, 1\\] +\\[left\\]: \\[a, b\\] +\\[right\\]: \\[b, a\\]""" + with pytest.raises(AssertionError, match=msg): + tm.assert_series_equal(s1, s3, check_exact=True) + + msg = """Series are different + +Series values are different \\(100\\.0 %\\) +\\[index\\]: \\[0, 1\\] +\\[left\\]: \\[b, a\\] +\\[right\\]: \\[a, b\\]""" + with pytest.raises(AssertionError, match=msg): + tm.assert_series_equal(s3, s1, check_exact=True) + + +def test_assert_series_equal_ignore_extension_dtype_mismatch(): + # https://github.com/pandas-dev/pandas/issues/35715 + left = Series([1, 2, 3], dtype="Int64") + right = Series([1, 2, 3], dtype="Int32") + tm.assert_series_equal(left, right, check_dtype=False) + + +def test_assert_series_equal_ignore_extension_dtype_mismatch_cross_class(): + # https://github.com/pandas-dev/pandas/issues/35715 + left = Series([1, 2, 3], dtype="Int64") + right = Series([1, 2, 3], dtype="int64") + tm.assert_series_equal(left, right, check_dtype=False) + + +def test_allows_duplicate_labels(): + left = Series([1]) + right = Series([1]).set_flags(allows_duplicate_labels=False) + tm.assert_series_equal(left, left) + tm.assert_series_equal(right, right) + tm.assert_series_equal(left, right, check_flags=False) + tm.assert_series_equal(right, left, check_flags=False) + + with pytest.raises(AssertionError, match=">> cumavg([1, 2, 3]) + 2 + """ + ), + method="cumavg", + operation="average", +) +def cumavg(whatever): + pass + + +@doc(cumsum, method="cummax", operation="maximum") +def cummax(whatever): + pass + + +@doc(cummax, method="cummin", operation="minimum") +def cummin(whatever): + pass + + +def test_docstring_formatting(): + docstr = dedent( + """ + This is the cumsum method. + + It computes the cumulative sum. + """ + ) + assert cumsum.__doc__ == docstr + + +def test_docstring_appending(): + docstr = dedent( + """ + This is the cumavg method. + + It computes the cumulative average. + + Examples + -------- + + >>> cumavg([1, 2, 3]) + 2 + """ + ) + assert cumavg.__doc__ == docstr + + +def test_doc_template_from_func(): + docstr = dedent( + """ + This is the cummax method. + + It computes the cumulative maximum. + """ + ) + assert cummax.__doc__ == docstr + + +def test_inherit_doc_template(): + docstr = dedent( + """ + This is the cummin method. + + It computes the cumulative minimum. + """ + ) + assert cummin.__doc__ == docstr diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_hashing.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_hashing.py new file mode 100644 index 0000000000000000000000000000000000000000..1e7fdd920e365cd49abf22732c573ca696d3b3d7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_hashing.py @@ -0,0 +1,417 @@ +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + DataFrame, + Index, + MultiIndex, + Series, + period_range, + timedelta_range, +) +import pandas._testing as tm +from pandas.core.util.hashing import hash_tuples +from pandas.util import ( + hash_array, + hash_pandas_object, +) + + +@pytest.fixture( + params=[ + Series([1, 2, 3] * 3, dtype="int32"), + Series([None, 2.5, 3.5] * 3, dtype="float32"), + Series(["a", "b", "c"] * 3, dtype="category"), + Series(["d", "e", "f"] * 3), + Series([True, False, True] * 3), + Series(pd.date_range("20130101", periods=9)), + Series(pd.date_range("20130101", periods=9, tz="US/Eastern")), + Series(timedelta_range("2000", periods=9)), + ] +) +def series(request): + return request.param + + +@pytest.fixture(params=[True, False]) +def index(request): + return request.param + + +def test_consistency(): + # Check that our hash doesn't change because of a mistake + # in the actual code; this is the ground truth. + result = hash_pandas_object(Index(["foo", "bar", "baz"])) + expected = Series( + np.array( + [3600424527151052760, 1374399572096150070, 477881037637427054], + dtype="uint64", + ), + index=["foo", "bar", "baz"], + ) + tm.assert_series_equal(result, expected) + + +def test_hash_array(series): + arr = series.values + tm.assert_numpy_array_equal(hash_array(arr), hash_array(arr)) + + +@pytest.mark.parametrize("dtype", ["U", object]) +def test_hash_array_mixed(dtype): + result1 = hash_array(np.array(["3", "4", "All"])) + result2 = hash_array(np.array([3, 4, "All"], dtype=dtype)) + + tm.assert_numpy_array_equal(result1, result2) + + +@pytest.mark.parametrize("val", [5, "foo", pd.Timestamp("20130101")]) +def test_hash_array_errors(val): + msg = "must pass a ndarray-like" + with pytest.raises(TypeError, match=msg): + hash_array(val) + + +def test_hash_array_index_exception(): + # GH42003 TypeError instead of AttributeError + obj = pd.DatetimeIndex(["2018-10-28 01:20:00"], tz="Europe/Berlin") + + msg = "Use hash_pandas_object instead" + with pytest.raises(TypeError, match=msg): + hash_array(obj) + + +def test_hash_tuples(): + tuples = [(1, "one"), (1, "two"), (2, "one")] + result = hash_tuples(tuples) + + expected = hash_pandas_object(MultiIndex.from_tuples(tuples)).values + tm.assert_numpy_array_equal(result, expected) + + # We only need to support MultiIndex and list-of-tuples + msg = "|".join(["object is not iterable", "zip argument #1 must support iteration"]) + with pytest.raises(TypeError, match=msg): + hash_tuples(tuples[0]) + + +@pytest.mark.parametrize("val", [5, "foo", pd.Timestamp("20130101")]) +def test_hash_tuples_err(val): + msg = "must be convertible to a list-of-tuples" + with pytest.raises(TypeError, match=msg): + hash_tuples(val) + + +def test_multiindex_unique(): + mi = MultiIndex.from_tuples([(118, 472), (236, 118), (51, 204), (102, 51)]) + assert mi.is_unique is True + + result = hash_pandas_object(mi) + assert result.is_unique is True + + +def test_multiindex_objects(): + mi = MultiIndex( + levels=[["b", "d", "a"], [1, 2, 3]], + codes=[[0, 1, 0, 2], [2, 0, 0, 1]], + names=["col1", "col2"], + ) + recons = mi._sort_levels_monotonic() + + # These are equal. + assert mi.equals(recons) + assert Index(mi.values).equals(Index(recons.values)) + + +@pytest.mark.parametrize( + "obj", + [ + Series([1, 2, 3]), + Series([1.0, 1.5, 3.2]), + Series([1.0, 1.5, np.nan]), + Series([1.0, 1.5, 3.2], index=[1.5, 1.1, 3.3]), + Series(["a", "b", "c"]), + Series(["a", np.nan, "c"]), + Series(["a", None, "c"]), + Series([True, False, True]), + Series(dtype=object), + DataFrame({"x": ["a", "b", "c"], "y": [1, 2, 3]}), + DataFrame(), + DataFrame(np.full((10, 4), np.nan)), + DataFrame( + { + "A": [0.0, 1.0, 2.0, 3.0, 4.0], + "B": [0.0, 1.0, 0.0, 1.0, 0.0], + "C": Index(["foo1", "foo2", "foo3", "foo4", "foo5"], dtype=object), + "D": pd.date_range("20130101", periods=5), + } + ), + DataFrame(range(5), index=pd.date_range("2020-01-01", periods=5)), + Series(range(5), index=pd.date_range("2020-01-01", periods=5)), + Series(period_range("2020-01-01", periods=10, freq="D")), + Series(pd.date_range("20130101", periods=3, tz="US/Eastern")), + ], +) +def test_hash_pandas_object(obj, index): + a = hash_pandas_object(obj, index=index) + b = hash_pandas_object(obj, index=index) + tm.assert_series_equal(a, b) + + +@pytest.mark.parametrize( + "obj", + [ + Series([1, 2, 3]), + Series([1.0, 1.5, 3.2]), + Series([1.0, 1.5, np.nan]), + Series([1.0, 1.5, 3.2], index=[1.5, 1.1, 3.3]), + Series(["a", "b", "c"]), + Series(["a", np.nan, "c"]), + Series(["a", None, "c"]), + Series([True, False, True]), + DataFrame({"x": ["a", "b", "c"], "y": [1, 2, 3]}), + DataFrame(np.full((10, 4), np.nan)), + DataFrame( + { + "A": [0.0, 1.0, 2.0, 3.0, 4.0], + "B": [0.0, 1.0, 0.0, 1.0, 0.0], + "C": Index(["foo1", "foo2", "foo3", "foo4", "foo5"], dtype=object), + "D": pd.date_range("20130101", periods=5), + } + ), + DataFrame(range(5), index=pd.date_range("2020-01-01", periods=5)), + Series(range(5), index=pd.date_range("2020-01-01", periods=5)), + Series(period_range("2020-01-01", periods=10, freq="D")), + Series(pd.date_range("20130101", periods=3, tz="US/Eastern")), + ], +) +def test_hash_pandas_object_diff_index_non_empty(obj): + a = hash_pandas_object(obj, index=True) + b = hash_pandas_object(obj, index=False) + assert not (a == b).all() + + +@pytest.mark.parametrize( + "obj", + [ + Index([1, 2, 3]), + Index([True, False, True]), + timedelta_range("1 day", periods=2), + period_range("2020-01-01", freq="D", periods=2), + MultiIndex.from_product( + [range(5), ["foo", "bar", "baz"], pd.date_range("20130101", periods=2)] + ), + MultiIndex.from_product([pd.CategoricalIndex(list("aabc")), range(3)]), + ], +) +def test_hash_pandas_index(obj, index): + a = hash_pandas_object(obj, index=index) + b = hash_pandas_object(obj, index=index) + tm.assert_series_equal(a, b) + + +def test_hash_pandas_series(series, index): + a = hash_pandas_object(series, index=index) + b = hash_pandas_object(series, index=index) + tm.assert_series_equal(a, b) + + +def test_hash_pandas_series_diff_index(series): + a = hash_pandas_object(series, index=True) + b = hash_pandas_object(series, index=False) + assert not (a == b).all() + + +@pytest.mark.parametrize( + "obj", [Series([], dtype="float64"), Series([], dtype="object"), Index([])] +) +def test_hash_pandas_empty_object(obj, index): + # These are by-definition the same with + # or without the index as the data is empty. + a = hash_pandas_object(obj, index=index) + b = hash_pandas_object(obj, index=index) + tm.assert_series_equal(a, b) + + +@pytest.mark.parametrize( + "s1", + [ + Series(["a", "b", "c", "d"]), + Series([1000, 2000, 3000, 4000]), + Series(pd.date_range(0, periods=4)), + ], +) +@pytest.mark.parametrize("categorize", [True, False]) +def test_categorical_consistency(s1, categorize): + # see gh-15143 + # + # Check that categoricals hash consistent with their values, + # not codes. This should work for categoricals of any dtype. + s2 = s1.astype("category").cat.set_categories(s1) + s3 = s2.cat.set_categories(list(reversed(s1))) + + # These should all hash identically. + h1 = hash_pandas_object(s1, categorize=categorize) + h2 = hash_pandas_object(s2, categorize=categorize) + h3 = hash_pandas_object(s3, categorize=categorize) + + tm.assert_series_equal(h1, h2) + tm.assert_series_equal(h1, h3) + + +def test_categorical_with_nan_consistency(): + c = pd.Categorical.from_codes( + [-1, 0, 1, 2, 3, 4], categories=pd.date_range("2012-01-01", periods=5, name="B") + ) + expected = hash_array(c, categorize=False) + + c = pd.Categorical.from_codes([-1, 0], categories=[pd.Timestamp("2012-01-01")]) + result = hash_array(c, categorize=False) + + assert result[0] in expected + assert result[1] in expected + + +def test_pandas_errors(): + msg = "Unexpected type for hashing" + with pytest.raises(TypeError, match=msg): + hash_pandas_object(pd.Timestamp("20130101")) + + +def test_hash_keys(): + # Using different hash keys, should have + # different hashes for the same data. + # + # This only matters for object dtypes. + obj = Series(list("abc")) + + a = hash_pandas_object(obj, hash_key="9876543210123456") + b = hash_pandas_object(obj, hash_key="9876543210123465") + + assert (a != b).all() + + +def test_df_hash_keys(): + # DataFrame version of the test_hash_keys. + # https://github.com/pandas-dev/pandas/issues/41404 + obj = DataFrame({"x": np.arange(3), "y": list("abc")}) + + a = hash_pandas_object(obj, hash_key="9876543210123456") + b = hash_pandas_object(obj, hash_key="9876543210123465") + + assert (a != b).all() + + +def test_df_encoding(): + # Check that DataFrame recognizes optional encoding. + # https://github.com/pandas-dev/pandas/issues/41404 + # https://github.com/pandas-dev/pandas/pull/42049 + obj = DataFrame({"x": np.arange(3), "y": list("a+c")}) + + a = hash_pandas_object(obj, encoding="utf8") + b = hash_pandas_object(obj, encoding="utf7") + + # Note that the "+" is encoded as "+-" in utf-7. + assert a[0] == b[0] + assert a[1] != b[1] + assert a[2] == b[2] + + +def test_invalid_key(): + # This only matters for object dtypes. + msg = "key should be a 16-byte string encoded" + + with pytest.raises(ValueError, match=msg): + hash_pandas_object(Series(list("abc")), hash_key="foo") + + +def test_already_encoded(index): + # If already encoded, then ok. + obj = Series(list("abc")).str.encode("utf8") + a = hash_pandas_object(obj, index=index) + b = hash_pandas_object(obj, index=index) + tm.assert_series_equal(a, b) + + +def test_alternate_encoding(index): + obj = Series(list("abc")) + a = hash_pandas_object(obj, index=index) + b = hash_pandas_object(obj, index=index) + tm.assert_series_equal(a, b) + + +@pytest.mark.parametrize("l_exp", range(8)) +@pytest.mark.parametrize("l_add", [0, 1]) +def test_same_len_hash_collisions(l_exp, l_add): + length = 2 ** (l_exp + 8) + l_add + idx = np.array([str(i) for i in range(length)], dtype=object) + + result = hash_array(idx, "utf8") + assert not result[0] == result[1] + + +def test_hash_collisions(): + # Hash collisions are bad. + # + # https://github.com/pandas-dev/pandas/issues/14711#issuecomment-264885726 + hashes = [ + "Ingrid-9Z9fKIZmkO7i7Cn51Li34pJm44fgX6DYGBNj3VPlOH50m7HnBlPxfIwFMrcNJNMP6PSgLmwWnInciMWrCSAlLEvt7JkJl4IxiMrVbXSa8ZQoVaq5xoQPjltuJEfwdNlO6jo8qRRHvD8sBEBMQASrRa6TsdaPTPCBo3nwIBpE7YzzmyH0vMBhjQZLx1aCT7faSEx7PgFxQhHdKFWROcysamgy9iVj8DO2Fmwg1NNl93rIAqC3mdqfrCxrzfvIY8aJdzin2cHVzy3QUJxZgHvtUtOLxoqnUHsYbNTeq0xcLXpTZEZCxD4PGubIuCNf32c33M7HFsnjWSEjE2yVdWKhmSVodyF8hFYVmhYnMCztQnJrt3O8ZvVRXd5IKwlLexiSp4h888w7SzAIcKgc3g5XQJf6MlSMftDXm9lIsE1mJNiJEv6uY6pgvC3fUPhatlR5JPpVAHNSbSEE73MBzJrhCAbOLXQumyOXigZuPoME7QgJcBalliQol7YZ9", + "Tim-b9MddTxOWW2AT1Py6vtVbZwGAmYCjbp89p8mxsiFoVX4FyDOF3wFiAkyQTUgwg9sVqVYOZo09Dh1AzhFHbgij52ylF0SEwgzjzHH8TGY8Lypart4p4onnDoDvVMBa0kdthVGKl6K0BDVGzyOXPXKpmnMF1H6rJzqHJ0HywfwS4XYpVwlAkoeNsiicHkJUFdUAhG229INzvIAiJuAHeJDUoyO4DCBqtoZ5TDend6TK7Y914yHlfH3g1WZu5LksKv68VQHJriWFYusW5e6ZZ6dKaMjTwEGuRgdT66iU5nqWTHRH8WSzpXoCFwGcTOwyuqPSe0fTe21DVtJn1FKj9F9nEnR9xOvJUO7E0piCIF4Ad9yAIDY4DBimpsTfKXCu1vdHpKYerzbndfuFe5AhfMduLYZJi5iAw8qKSwR5h86ttXV0Mc0QmXz8dsRvDgxjXSmupPxBggdlqUlC828hXiTPD7am0yETBV0F3bEtvPiNJfremszcV8NcqAoARMe", + ] + + # These should be different. + result1 = hash_array(np.asarray(hashes[0:1], dtype=object), "utf8") + expected1 = np.array([14963968704024874985], dtype=np.uint64) + tm.assert_numpy_array_equal(result1, expected1) + + result2 = hash_array(np.asarray(hashes[1:2], dtype=object), "utf8") + expected2 = np.array([16428432627716348016], dtype=np.uint64) + tm.assert_numpy_array_equal(result2, expected2) + + result = hash_array(np.asarray(hashes, dtype=object), "utf8") + tm.assert_numpy_array_equal(result, np.concatenate([expected1, expected2], axis=0)) + + +@pytest.mark.parametrize( + "data, result_data", + [ + [[tuple("1"), tuple("2")], [10345501319357378243, 8331063931016360761]], + [[(1,), (2,)], [9408946347443669104, 3278256261030523334]], + ], +) +def test_hash_with_tuple(data, result_data): + # GH#28969 array containing a tuple raises on call to arr.astype(str) + # apparently a numpy bug github.com/numpy/numpy/issues/9441 + + df = DataFrame({"data": data}) + result = hash_pandas_object(df) + expected = Series(result_data, dtype=np.uint64) + tm.assert_series_equal(result, expected) + + +def test_hashable_tuple_args(): + # require that the elements of such tuples are themselves hashable + + df3 = DataFrame( + { + "data": [ + ( + 1, + [], + ), + ( + 2, + {}, + ), + ] + } + ) + with pytest.raises(TypeError, match="unhashable type: 'list'"): + hash_pandas_object(df3) + + +def test_hash_object_none_key(): + # https://github.com/pandas-dev/pandas/issues/30887 + result = pd.util.hash_pandas_object(Series(["a", "b"]), hash_key=None) + expected = Series([4578374827886788867, 17338122309987883691], dtype="uint64") + tm.assert_series_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_numba.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_numba.py new file mode 100644 index 0000000000000000000000000000000000000000..27b68ff0f60447e6695d786de9a72ecbb59f7884 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_numba.py @@ -0,0 +1,12 @@ +import pytest + +import pandas.util._test_decorators as td + +from pandas import option_context + + +@td.skip_if_installed("numba") +def test_numba_not_installed_option_context(): + with pytest.raises(ImportError, match="Missing optional"): + with option_context("compute.use_numba", True): + pass diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_rewrite_warning.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_rewrite_warning.py new file mode 100644 index 0000000000000000000000000000000000000000..f847a06d8ea8d7fa75aac1de9025a5bd29bedf37 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_rewrite_warning.py @@ -0,0 +1,39 @@ +import warnings + +import pytest + +from pandas.util._exceptions import rewrite_warning + +import pandas._testing as tm + + +@pytest.mark.parametrize( + "target_category, target_message, hit", + [ + (FutureWarning, "Target message", True), + (FutureWarning, "Target", True), + (FutureWarning, "get mess", True), + (FutureWarning, "Missed message", False), + (DeprecationWarning, "Target message", False), + ], +) +@pytest.mark.parametrize( + "new_category", + [ + None, + DeprecationWarning, + ], +) +def test_rewrite_warning(target_category, target_message, hit, new_category): + new_message = "Rewritten message" + if hit: + expected_category = new_category if new_category else target_category + expected_message = new_message + else: + expected_category = FutureWarning + expected_message = "Target message" + with tm.assert_produces_warning(expected_category, match=expected_message): + with rewrite_warning( + target_message, target_category, new_message, new_category + ): + warnings.warn(message="Target message", category=FutureWarning) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_shares_memory.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_shares_memory.py new file mode 100644 index 0000000000000000000000000000000000000000..00a897d574a07ac262afa17de6752e5c95e3964e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_shares_memory.py @@ -0,0 +1,30 @@ +import pandas.util._test_decorators as td + +import pandas as pd +import pandas._testing as tm + + +def test_shares_memory_interval(): + obj = pd.interval_range(1, 5) + + assert tm.shares_memory(obj, obj) + assert tm.shares_memory(obj, obj._data) + assert tm.shares_memory(obj, obj[::-1]) + assert tm.shares_memory(obj, obj[:2]) + + assert not tm.shares_memory(obj, obj._data.copy()) + + +@td.skip_if_no("pyarrow") +def test_shares_memory_string(): + # GH#55823 + import pyarrow as pa + + obj = pd.array(["a", "b"], dtype="string[pyarrow]") + assert tm.shares_memory(obj, obj) + + obj = pd.array(["a", "b"], dtype="string[pyarrow_numpy]") + assert tm.shares_memory(obj, obj) + + obj = pd.array(["a", "b"], dtype=pd.ArrowDtype(pa.string())) + assert tm.shares_memory(obj, obj) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_show_versions.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_show_versions.py new file mode 100644 index 0000000000000000000000000000000000000000..72c9db23b210880793f37227c99e99e804800f08 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_show_versions.py @@ -0,0 +1,81 @@ +import json +import os +import re + +from pandas.util._print_versions import ( + _get_dependency_info, + _get_sys_info, +) + +import pandas as pd + + +def test_show_versions(tmpdir): + # GH39701 + as_json = os.path.join(tmpdir, "test_output.json") + + pd.show_versions(as_json=as_json) + + with open(as_json, encoding="utf-8") as fd: + # check if file output is valid JSON, will raise an exception if not + result = json.load(fd) + + # Basic check that each version element is found in output + expected = { + "system": _get_sys_info(), + "dependencies": _get_dependency_info(), + } + + assert result == expected + + +def test_show_versions_console_json(capsys): + # GH39701 + pd.show_versions(as_json=True) + stdout = capsys.readouterr().out + + # check valid json is printed to the console if as_json is True + result = json.loads(stdout) + + # Basic check that each version element is found in output + expected = { + "system": _get_sys_info(), + "dependencies": _get_dependency_info(), + } + + assert result == expected + + +def test_show_versions_console(capsys): + # gh-32041 + # gh-32041 + pd.show_versions(as_json=False) + result = capsys.readouterr().out + + # check header + assert "INSTALLED VERSIONS" in result + + # check full commit hash + assert re.search(r"commit\s*:\s[0-9a-f]{40}\n", result) + + # check required dependency + # 2020-12-09 npdev has "dirty" in the tag + # 2022-05-25 npdev released with RC wo/ "dirty". + # Just ensure we match [0-9]+\..* since npdev version is variable + assert re.search(r"numpy\s*:\s[0-9]+\..*\n", result) + + # check optional dependency + assert re.search(r"pyarrow\s*:\s([0-9]+.*|None)\n", result) + + +def test_json_output_match(capsys, tmpdir): + # GH39701 + pd.show_versions(as_json=True) + result_console = capsys.readouterr().out + + out_path = os.path.join(tmpdir, "test_json.json") + pd.show_versions(as_json=out_path) + with open(out_path, encoding="utf-8") as out_fd: + result_file = out_fd.read() + + assert result_console == result_file diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_util.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_util.py new file mode 100644 index 0000000000000000000000000000000000000000..dfb8587d3924e1441ac9da0aeeaa5585c6b4fe6c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_util.py @@ -0,0 +1,58 @@ +import os + +import pytest + +from pandas import ( + array, + compat, +) +import pandas._testing as tm + + +def test_numpy_err_state_is_default(): + expected = {"over": "warn", "divide": "warn", "invalid": "warn", "under": "ignore"} + import numpy as np + + # The error state should be unchanged after that import. + assert np.geterr() == expected + + +def test_convert_rows_list_to_csv_str(): + rows_list = ["aaa", "bbb", "ccc"] + ret = tm.convert_rows_list_to_csv_str(rows_list) + + if compat.is_platform_windows(): + expected = "aaa\r\nbbb\r\nccc\r\n" + else: + expected = "aaa\nbbb\nccc\n" + + assert ret == expected + + +@pytest.mark.parametrize("strict_data_files", [True, False]) +def test_datapath_missing(datapath): + with pytest.raises(ValueError, match="Could not find file"): + datapath("not_a_file") + + +def test_datapath(datapath): + args = ("io", "data", "csv", "iris.csv") + + result = datapath(*args) + expected = os.path.join(os.path.dirname(os.path.dirname(__file__)), *args) + + assert result == expected + + +def test_external_error_raised(): + with tm.external_error_raised(TypeError): + raise TypeError("Should not check this error message, so it will pass") + + +def test_is_sorted(): + arr = array([1, 2, 3], dtype="Int64") + tm.assert_is_sorted(arr) + + arr = array([4, 2, 3], dtype="Int64") + with pytest.raises(AssertionError, match="ExtensionArray are different"): + tm.assert_is_sorted(arr) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_validate_args.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_validate_args.py new file mode 100644 index 0000000000000000000000000000000000000000..eef0931ec28efd02e3db7a85b0b3260742c1ff2d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_validate_args.py @@ -0,0 +1,70 @@ +import pytest + +from pandas.util._validators import validate_args + + +@pytest.fixture +def _fname(): + return "func" + + +def test_bad_min_fname_arg_count(_fname): + msg = "'max_fname_arg_count' must be non-negative" + + with pytest.raises(ValueError, match=msg): + validate_args(_fname, (None,), -1, "foo") + + +def test_bad_arg_length_max_value_single(_fname): + args = (None, None) + compat_args = ("foo",) + + min_fname_arg_count = 0 + max_length = len(compat_args) + min_fname_arg_count + actual_length = len(args) + min_fname_arg_count + msg = ( + rf"{_fname}\(\) takes at most {max_length} " + rf"argument \({actual_length} given\)" + ) + + with pytest.raises(TypeError, match=msg): + validate_args(_fname, args, min_fname_arg_count, compat_args) + + +def test_bad_arg_length_max_value_multiple(_fname): + args = (None, None) + compat_args = {"foo": None} + + min_fname_arg_count = 2 + max_length = len(compat_args) + min_fname_arg_count + actual_length = len(args) + min_fname_arg_count + msg = ( + rf"{_fname}\(\) takes at most {max_length} " + rf"arguments \({actual_length} given\)" + ) + + with pytest.raises(TypeError, match=msg): + validate_args(_fname, args, min_fname_arg_count, compat_args) + + +@pytest.mark.parametrize("i", range(1, 3)) +def test_not_all_defaults(i, _fname): + bad_arg = "foo" + msg = ( + f"the '{bad_arg}' parameter is not supported " + rf"in the pandas implementation of {_fname}\(\)" + ) + + compat_args = {"foo": 2, "bar": -1, "baz": 3} + arg_vals = (1, -1, 3) + + with pytest.raises(ValueError, match=msg): + validate_args(_fname, arg_vals[:i], 2, compat_args) + + +def test_validation(_fname): + # No exceptions should be raised. + validate_args(_fname, (None,), 2, {"out": None}) + + compat_args = {"axis": 1, "out": None} + validate_args(_fname, (1, None), 2, compat_args) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_validate_args_and_kwargs.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_validate_args_and_kwargs.py new file mode 100644 index 0000000000000000000000000000000000000000..215026d648471c04cb8751506c03626fda73fc68 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_validate_args_and_kwargs.py @@ -0,0 +1,84 @@ +import pytest + +from pandas.util._validators import validate_args_and_kwargs + + +@pytest.fixture +def _fname(): + return "func" + + +def test_invalid_total_length_max_length_one(_fname): + compat_args = ("foo",) + kwargs = {"foo": "FOO"} + args = ("FoO", "BaZ") + + min_fname_arg_count = 0 + max_length = len(compat_args) + min_fname_arg_count + actual_length = len(kwargs) + len(args) + min_fname_arg_count + + msg = ( + rf"{_fname}\(\) takes at most {max_length} " + rf"argument \({actual_length} given\)" + ) + + with pytest.raises(TypeError, match=msg): + validate_args_and_kwargs(_fname, args, kwargs, min_fname_arg_count, compat_args) + + +def test_invalid_total_length_max_length_multiple(_fname): + compat_args = ("foo", "bar", "baz") + kwargs = {"foo": "FOO", "bar": "BAR"} + args = ("FoO", "BaZ") + + min_fname_arg_count = 2 + max_length = len(compat_args) + min_fname_arg_count + actual_length = len(kwargs) + len(args) + min_fname_arg_count + + msg = ( + rf"{_fname}\(\) takes at most {max_length} " + rf"arguments \({actual_length} given\)" + ) + + with pytest.raises(TypeError, match=msg): + validate_args_and_kwargs(_fname, args, kwargs, min_fname_arg_count, compat_args) + + +@pytest.mark.parametrize("args,kwargs", [((), {"foo": -5, "bar": 2}), ((-5, 2), {})]) +def test_missing_args_or_kwargs(args, kwargs, _fname): + bad_arg = "bar" + min_fname_arg_count = 2 + + compat_args = {"foo": -5, bad_arg: 1} + + msg = ( + rf"the '{bad_arg}' parameter is not supported " + rf"in the pandas implementation of {_fname}\(\)" + ) + + with pytest.raises(ValueError, match=msg): + validate_args_and_kwargs(_fname, args, kwargs, min_fname_arg_count, compat_args) + + +def test_duplicate_argument(_fname): + min_fname_arg_count = 2 + + compat_args = {"foo": None, "bar": None, "baz": None} + kwargs = {"foo": None, "bar": None} + args = (None,) # duplicate value for "foo" + + msg = rf"{_fname}\(\) got multiple values for keyword argument 'foo'" + + with pytest.raises(TypeError, match=msg): + validate_args_and_kwargs(_fname, args, kwargs, min_fname_arg_count, compat_args) + + +def test_validation(_fname): + # No exceptions should be raised. + compat_args = {"foo": 1, "bar": None, "baz": -2} + kwargs = {"baz": -2} + + args = (1, None) + min_fname_arg_count = 2 + + validate_args_and_kwargs(_fname, args, kwargs, min_fname_arg_count, compat_args) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_validate_inclusive.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_validate_inclusive.py new file mode 100644 index 0000000000000000000000000000000000000000..c1254c614ab305c447090b148ea6a036569f76e6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_validate_inclusive.py @@ -0,0 +1,40 @@ +import numpy as np +import pytest + +from pandas.util._validators import validate_inclusive + +import pandas as pd + + +@pytest.mark.parametrize( + "invalid_inclusive", + ( + "ccc", + 2, + object(), + None, + np.nan, + pd.NA, + pd.DataFrame(), + ), +) +def test_invalid_inclusive(invalid_inclusive): + with pytest.raises( + ValueError, + match="Inclusive has to be either 'both', 'neither', 'left' or 'right'", + ): + validate_inclusive(invalid_inclusive) + + +@pytest.mark.parametrize( + "valid_inclusive, expected_tuple", + ( + ("left", (True, False)), + ("right", (False, True)), + ("both", (True, True)), + ("neither", (False, False)), + ), +) +def test_valid_inclusive(valid_inclusive, expected_tuple): + resultant_tuple = validate_inclusive(valid_inclusive) + assert expected_tuple == resultant_tuple diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_validate_kwargs.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_validate_kwargs.py new file mode 100644 index 0000000000000000000000000000000000000000..dba447e30cf579c9f2f5c0bd917a4e0837143ed3 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/util/test_validate_kwargs.py @@ -0,0 +1,69 @@ +import pytest + +from pandas.util._validators import ( + validate_bool_kwarg, + validate_kwargs, +) + + +@pytest.fixture +def _fname(): + return "func" + + +def test_bad_kwarg(_fname): + good_arg = "f" + bad_arg = good_arg + "o" + + compat_args = {good_arg: "foo", bad_arg + "o": "bar"} + kwargs = {good_arg: "foo", bad_arg: "bar"} + + msg = rf"{_fname}\(\) got an unexpected keyword argument '{bad_arg}'" + + with pytest.raises(TypeError, match=msg): + validate_kwargs(_fname, kwargs, compat_args) + + +@pytest.mark.parametrize("i", range(1, 3)) +def test_not_all_none(i, _fname): + bad_arg = "foo" + msg = ( + rf"the '{bad_arg}' parameter is not supported " + rf"in the pandas implementation of {_fname}\(\)" + ) + + compat_args = {"foo": 1, "bar": "s", "baz": None} + + kwarg_keys = ("foo", "bar", "baz") + kwarg_vals = (2, "s", None) + + kwargs = dict(zip(kwarg_keys[:i], kwarg_vals[:i])) + + with pytest.raises(ValueError, match=msg): + validate_kwargs(_fname, kwargs, compat_args) + + +def test_validation(_fname): + # No exceptions should be raised. + compat_args = {"f": None, "b": 1, "ba": "s"} + + kwargs = {"f": None, "b": 1} + validate_kwargs(_fname, kwargs, compat_args) + + +@pytest.mark.parametrize("name", ["inplace", "copy"]) +@pytest.mark.parametrize("value", [1, "True", [1, 2, 3], 5.0]) +def test_validate_bool_kwarg_fail(name, value): + msg = ( + f'For argument "{name}" expected type bool, ' + f"received type {type(value).__name__}" + ) + + with pytest.raises(ValueError, match=msg): + validate_bool_kwarg(value, name) + + +@pytest.mark.parametrize("name", ["inplace", "copy"]) +@pytest.mark.parametrize("value", [True, False, None]) +def test_validate_bool_kwarg(name, value): + assert validate_bool_kwarg(value, name) == value diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2d02f03ff8e836868c2bf5283e83054bbddc6c9 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/conftest.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/conftest.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24db8974da86b84ed80c89585c4d1cc27d13ca76 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/conftest.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_api.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_api.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6102be0f8adcbe0d26920f624541b891722d5d1 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_api.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_apply.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_apply.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..981a7c92966cde94080965b1a00061cc06ec79be Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_apply.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_base_indexer.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_base_indexer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83877ad2a131bb671a3742a11d07813d89185a4b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_base_indexer.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_cython_aggregations.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_cython_aggregations.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a89c23bee4845bf418f61b6466da41c165723c7 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_cython_aggregations.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_dtypes.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_dtypes.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..796f119c3b4c9ecf89e9147428fc6654b006cb36 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_dtypes.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_ewm.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_ewm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad33f1d8497540dc93e05df3c93f76ceb460865f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_ewm.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_expanding.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_expanding.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19d6d72d8f1686922ced23a1ea5d1257c8d1cc51 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_expanding.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_groupby.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_groupby.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f96978f7a79b82dbd456e4e6e4e59bd6fb5ebc36 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_groupby.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_numba.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_numba.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3499e4a6b0634fec1e281580c0ff6fcdc12b7d38 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_numba.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_online.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_online.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a935c7b9e252bbc1c81f32e7b21e9aeddbd59e9 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_online.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_pairwise.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_pairwise.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5929fc48eed2e31ceba4ff6c3a711431f50df822 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_pairwise.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_rolling.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_rolling.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..232ae4a770407b9b33ce76849155032ca81174c6 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_rolling.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_rolling_functions.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_rolling_functions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6da9dfcbbd2fd3ad784aa187250eff647ad7919 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_rolling_functions.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_rolling_quantile.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_rolling_quantile.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a74c2041296b7455e048bfd0312c7ca78b7d81f2 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_rolling_quantile.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_rolling_skew_kurt.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_rolling_skew_kurt.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b462b4314b9c7ea0f6a51b09f692ec31b2580a8 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_rolling_skew_kurt.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_timeseries_window.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_timeseries_window.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ffe3f7f9816e2dc9f9e4124b5782967b92bdac48 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_timeseries_window.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_win_type.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_win_type.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1aa1ad612e3fb0d74ea20191e7b3491757e40a26 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/__pycache__/test_win_type.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/moments/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/moments/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/moments/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/moments/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47ff32dcb2605c00e6c694c244ae1e8e159415d6 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/moments/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/moments/__pycache__/conftest.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/moments/__pycache__/conftest.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a94f6a876b210c3c94b9a328f015bf38a0b34282 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/moments/__pycache__/conftest.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/moments/__pycache__/test_moments_consistency_ewm.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/moments/__pycache__/test_moments_consistency_ewm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65b648e1063b802a99867201985a223a2b57d9ea Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/moments/__pycache__/test_moments_consistency_ewm.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/moments/__pycache__/test_moments_consistency_expanding.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/moments/__pycache__/test_moments_consistency_expanding.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..910fa3ef6f012ef5231c2522ecaddf6cd0400f2a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/moments/__pycache__/test_moments_consistency_expanding.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/moments/__pycache__/test_moments_consistency_rolling.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/moments/__pycache__/test_moments_consistency_rolling.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ffb1f9e438e91c20b040b6cbd4270ee0667f277 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/moments/__pycache__/test_moments_consistency_rolling.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/moments/conftest.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/moments/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..fccf80c3c7a58d691b818709e51a9f6642956a33 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/moments/conftest.py @@ -0,0 +1,72 @@ +import itertools + +import numpy as np +import pytest + +from pandas import ( + DataFrame, + Series, + notna, +) + + +def create_series(): + return [ + Series(dtype=np.float64, name="a"), + Series([np.nan] * 5), + Series([1.0] * 5), + Series(range(5, 0, -1)), + Series(range(5)), + Series([np.nan, 1.0, np.nan, 1.0, 1.0]), + Series([np.nan, 1.0, np.nan, 2.0, 3.0]), + Series([np.nan, 1.0, np.nan, 3.0, 2.0]), + ] + + +def create_dataframes(): + return [ + DataFrame(columns=["a", "a"]), + DataFrame(np.arange(15).reshape((5, 3)), columns=["a", "a", 99]), + ] + [DataFrame(s) for s in create_series()] + + +def is_constant(x): + values = x.values.ravel("K") + return len(set(values[notna(values)])) == 1 + + +@pytest.fixture( + params=( + obj + for obj in itertools.chain(create_series(), create_dataframes()) + if is_constant(obj) + ), +) +def consistent_data(request): + return request.param + + +@pytest.fixture(params=create_series()) +def series_data(request): + return request.param + + +@pytest.fixture(params=itertools.chain(create_series(), create_dataframes())) +def all_data(request): + """ + Test: + - Empty Series / DataFrame + - All NaN + - All consistent value + - Monotonically decreasing + - Monotonically increasing + - Monotonically consistent with NaNs + - Monotonically increasing with NaNs + - Monotonically decreasing with NaNs + """ + return request.param + + +@pytest.fixture(params=[0, 2]) +def min_periods(request): + return request.param diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/moments/test_moments_consistency_ewm.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/moments/test_moments_consistency_ewm.py new file mode 100644 index 0000000000000000000000000000000000000000..49dee50954f4f42365d1ee4525fa48a3e18877fe --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/moments/test_moments_consistency_ewm.py @@ -0,0 +1,243 @@ +import numpy as np +import pytest + +from pandas import ( + DataFrame, + Series, + concat, +) +import pandas._testing as tm + + +def create_mock_weights(obj, com, adjust, ignore_na): + if isinstance(obj, DataFrame): + if not len(obj.columns): + return DataFrame(index=obj.index, columns=obj.columns) + w = concat( + [ + create_mock_series_weights( + obj.iloc[:, i], com=com, adjust=adjust, ignore_na=ignore_na + ) + for i in range(len(obj.columns)) + ], + axis=1, + ) + w.index = obj.index + w.columns = obj.columns + return w + else: + return create_mock_series_weights(obj, com, adjust, ignore_na) + + +def create_mock_series_weights(s, com, adjust, ignore_na): + w = Series(np.nan, index=s.index, name=s.name) + alpha = 1.0 / (1.0 + com) + if adjust: + count = 0 + for i in range(len(s)): + if s.iat[i] == s.iat[i]: + w.iat[i] = pow(1.0 / (1.0 - alpha), count) + count += 1 + elif not ignore_na: + count += 1 + else: + sum_wts = 0.0 + prev_i = -1 + count = 0 + for i in range(len(s)): + if s.iat[i] == s.iat[i]: + if prev_i == -1: + w.iat[i] = 1.0 + else: + w.iat[i] = alpha * sum_wts / pow(1.0 - alpha, count - prev_i) + sum_wts += w.iat[i] + prev_i = count + count += 1 + elif not ignore_na: + count += 1 + return w + + +def test_ewm_consistency_mean(all_data, adjust, ignore_na, min_periods): + com = 3.0 + + result = all_data.ewm( + com=com, min_periods=min_periods, adjust=adjust, ignore_na=ignore_na + ).mean() + weights = create_mock_weights(all_data, com=com, adjust=adjust, ignore_na=ignore_na) + expected = all_data.multiply(weights).cumsum().divide(weights.cumsum()).ffill() + expected[ + all_data.expanding().count() < (max(min_periods, 1) if min_periods else 1) + ] = np.nan + tm.assert_equal(result, expected.astype("float64")) + + +def test_ewm_consistency_consistent(consistent_data, adjust, ignore_na, min_periods): + com = 3.0 + + count_x = consistent_data.expanding().count() + mean_x = consistent_data.ewm( + com=com, min_periods=min_periods, adjust=adjust, ignore_na=ignore_na + ).mean() + # check that correlation of a series with itself is either 1 or NaN + corr_x_x = consistent_data.ewm( + com=com, min_periods=min_periods, adjust=adjust, ignore_na=ignore_na + ).corr(consistent_data) + exp = ( + consistent_data.max() + if isinstance(consistent_data, Series) + else consistent_data.max().max() + ) + + # check mean of constant series + expected = consistent_data * np.nan + expected[count_x >= max(min_periods, 1)] = exp + tm.assert_equal(mean_x, expected) + + # check correlation of constant series with itself is NaN + expected[:] = np.nan + tm.assert_equal(corr_x_x, expected) + + +def test_ewm_consistency_var_debiasing_factors( + all_data, adjust, ignore_na, min_periods +): + com = 3.0 + + # check variance debiasing factors + var_unbiased_x = all_data.ewm( + com=com, min_periods=min_periods, adjust=adjust, ignore_na=ignore_na + ).var(bias=False) + var_biased_x = all_data.ewm( + com=com, min_periods=min_periods, adjust=adjust, ignore_na=ignore_na + ).var(bias=True) + + weights = create_mock_weights(all_data, com=com, adjust=adjust, ignore_na=ignore_na) + cum_sum = weights.cumsum().ffill() + cum_sum_sq = (weights * weights).cumsum().ffill() + numerator = cum_sum * cum_sum + denominator = numerator - cum_sum_sq + denominator[denominator <= 0.0] = np.nan + var_debiasing_factors_x = numerator / denominator + + tm.assert_equal(var_unbiased_x, var_biased_x * var_debiasing_factors_x) + + +@pytest.mark.parametrize("bias", [True, False]) +def test_moments_consistency_var(all_data, adjust, ignore_na, min_periods, bias): + com = 3.0 + + mean_x = all_data.ewm( + com=com, min_periods=min_periods, adjust=adjust, ignore_na=ignore_na + ).mean() + var_x = all_data.ewm( + com=com, min_periods=min_periods, adjust=adjust, ignore_na=ignore_na + ).var(bias=bias) + assert not (var_x < 0).any().any() + + if bias: + # check that biased var(x) == mean(x^2) - mean(x)^2 + mean_x2 = ( + (all_data * all_data) + .ewm(com=com, min_periods=min_periods, adjust=adjust, ignore_na=ignore_na) + .mean() + ) + tm.assert_equal(var_x, mean_x2 - (mean_x * mean_x)) + + +@pytest.mark.parametrize("bias", [True, False]) +def test_moments_consistency_var_constant( + consistent_data, adjust, ignore_na, min_periods, bias +): + com = 3.0 + count_x = consistent_data.expanding(min_periods=min_periods).count() + var_x = consistent_data.ewm( + com=com, min_periods=min_periods, adjust=adjust, ignore_na=ignore_na + ).var(bias=bias) + + # check that variance of constant series is identically 0 + assert not (var_x > 0).any().any() + expected = consistent_data * np.nan + expected[count_x >= max(min_periods, 1)] = 0.0 + if not bias: + expected[count_x < 2] = np.nan + tm.assert_equal(var_x, expected) + + +@pytest.mark.parametrize("bias", [True, False]) +def test_ewm_consistency_std(all_data, adjust, ignore_na, min_periods, bias): + com = 3.0 + var_x = all_data.ewm( + com=com, min_periods=min_periods, adjust=adjust, ignore_na=ignore_na + ).var(bias=bias) + assert not (var_x < 0).any().any() + + std_x = all_data.ewm( + com=com, min_periods=min_periods, adjust=adjust, ignore_na=ignore_na + ).std(bias=bias) + assert not (std_x < 0).any().any() + + # check that var(x) == std(x)^2 + tm.assert_equal(var_x, std_x * std_x) + + cov_x_x = all_data.ewm( + com=com, min_periods=min_periods, adjust=adjust, ignore_na=ignore_na + ).cov(all_data, bias=bias) + assert not (cov_x_x < 0).any().any() + + # check that var(x) == cov(x, x) + tm.assert_equal(var_x, cov_x_x) + + +@pytest.mark.parametrize("bias", [True, False]) +def test_ewm_consistency_series_cov_corr( + series_data, adjust, ignore_na, min_periods, bias +): + com = 3.0 + + var_x_plus_y = ( + (series_data + series_data) + .ewm(com=com, min_periods=min_periods, adjust=adjust, ignore_na=ignore_na) + .var(bias=bias) + ) + var_x = series_data.ewm( + com=com, min_periods=min_periods, adjust=adjust, ignore_na=ignore_na + ).var(bias=bias) + var_y = series_data.ewm( + com=com, min_periods=min_periods, adjust=adjust, ignore_na=ignore_na + ).var(bias=bias) + cov_x_y = series_data.ewm( + com=com, min_periods=min_periods, adjust=adjust, ignore_na=ignore_na + ).cov(series_data, bias=bias) + # check that cov(x, y) == (var(x+y) - var(x) - + # var(y)) / 2 + tm.assert_equal(cov_x_y, 0.5 * (var_x_plus_y - var_x - var_y)) + + # check that corr(x, y) == cov(x, y) / (std(x) * + # std(y)) + corr_x_y = series_data.ewm( + com=com, min_periods=min_periods, adjust=adjust, ignore_na=ignore_na + ).corr(series_data) + std_x = series_data.ewm( + com=com, min_periods=min_periods, adjust=adjust, ignore_na=ignore_na + ).std(bias=bias) + std_y = series_data.ewm( + com=com, min_periods=min_periods, adjust=adjust, ignore_na=ignore_na + ).std(bias=bias) + tm.assert_equal(corr_x_y, cov_x_y / (std_x * std_y)) + + if bias: + # check that biased cov(x, y) == mean(x*y) - + # mean(x)*mean(y) + mean_x = series_data.ewm( + com=com, min_periods=min_periods, adjust=adjust, ignore_na=ignore_na + ).mean() + mean_y = series_data.ewm( + com=com, min_periods=min_periods, adjust=adjust, ignore_na=ignore_na + ).mean() + mean_x_times_y = ( + (series_data * series_data) + .ewm(com=com, min_periods=min_periods, adjust=adjust, ignore_na=ignore_na) + .mean() + ) + tm.assert_equal(cov_x_y, mean_x_times_y - (mean_x * mean_y)) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/moments/test_moments_consistency_expanding.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/moments/test_moments_consistency_expanding.py new file mode 100644 index 0000000000000000000000000000000000000000..7d2fa1ad5d21175dcfafe9a57dd8169fc4413360 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/moments/test_moments_consistency_expanding.py @@ -0,0 +1,144 @@ +import numpy as np +import pytest + +from pandas import Series +import pandas._testing as tm + + +def no_nans(x): + return x.notna().all().all() + + +def all_na(x): + return x.isnull().all().all() + + +@pytest.mark.parametrize("f", [lambda v: Series(v).sum(), np.nansum, np.sum]) +def test_expanding_apply_consistency_sum_nans(request, all_data, min_periods, f): + if f is np.sum: + if not no_nans(all_data) and not ( + all_na(all_data) and not all_data.empty and min_periods > 0 + ): + request.applymarker( + pytest.mark.xfail(reason="np.sum has different behavior with NaNs") + ) + expanding_f_result = all_data.expanding(min_periods=min_periods).sum() + expanding_apply_f_result = all_data.expanding(min_periods=min_periods).apply( + func=f, raw=True + ) + tm.assert_equal(expanding_f_result, expanding_apply_f_result) + + +@pytest.mark.parametrize("ddof", [0, 1]) +def test_moments_consistency_var(all_data, min_periods, ddof): + var_x = all_data.expanding(min_periods=min_periods).var(ddof=ddof) + assert not (var_x < 0).any().any() + + if ddof == 0: + # check that biased var(x) == mean(x^2) - mean(x)^2 + mean_x2 = (all_data * all_data).expanding(min_periods=min_periods).mean() + mean_x = all_data.expanding(min_periods=min_periods).mean() + tm.assert_equal(var_x, mean_x2 - (mean_x * mean_x)) + + +@pytest.mark.parametrize("ddof", [0, 1]) +def test_moments_consistency_var_constant(consistent_data, min_periods, ddof): + count_x = consistent_data.expanding(min_periods=min_periods).count() + var_x = consistent_data.expanding(min_periods=min_periods).var(ddof=ddof) + + # check that variance of constant series is identically 0 + assert not (var_x > 0).any().any() + expected = consistent_data * np.nan + expected[count_x >= max(min_periods, 1)] = 0.0 + if ddof == 1: + expected[count_x < 2] = np.nan + tm.assert_equal(var_x, expected) + + +@pytest.mark.parametrize("ddof", [0, 1]) +def test_expanding_consistency_var_std_cov(all_data, min_periods, ddof): + var_x = all_data.expanding(min_periods=min_periods).var(ddof=ddof) + assert not (var_x < 0).any().any() + + std_x = all_data.expanding(min_periods=min_periods).std(ddof=ddof) + assert not (std_x < 0).any().any() + + # check that var(x) == std(x)^2 + tm.assert_equal(var_x, std_x * std_x) + + cov_x_x = all_data.expanding(min_periods=min_periods).cov(all_data, ddof=ddof) + assert not (cov_x_x < 0).any().any() + + # check that var(x) == cov(x, x) + tm.assert_equal(var_x, cov_x_x) + + +@pytest.mark.parametrize("ddof", [0, 1]) +def test_expanding_consistency_series_cov_corr(series_data, min_periods, ddof): + var_x_plus_y = ( + (series_data + series_data).expanding(min_periods=min_periods).var(ddof=ddof) + ) + var_x = series_data.expanding(min_periods=min_periods).var(ddof=ddof) + var_y = series_data.expanding(min_periods=min_periods).var(ddof=ddof) + cov_x_y = series_data.expanding(min_periods=min_periods).cov(series_data, ddof=ddof) + # check that cov(x, y) == (var(x+y) - var(x) - + # var(y)) / 2 + tm.assert_equal(cov_x_y, 0.5 * (var_x_plus_y - var_x - var_y)) + + # check that corr(x, y) == cov(x, y) / (std(x) * + # std(y)) + corr_x_y = series_data.expanding(min_periods=min_periods).corr(series_data) + std_x = series_data.expanding(min_periods=min_periods).std(ddof=ddof) + std_y = series_data.expanding(min_periods=min_periods).std(ddof=ddof) + tm.assert_equal(corr_x_y, cov_x_y / (std_x * std_y)) + + if ddof == 0: + # check that biased cov(x, y) == mean(x*y) - + # mean(x)*mean(y) + mean_x = series_data.expanding(min_periods=min_periods).mean() + mean_y = series_data.expanding(min_periods=min_periods).mean() + mean_x_times_y = ( + (series_data * series_data).expanding(min_periods=min_periods).mean() + ) + tm.assert_equal(cov_x_y, mean_x_times_y - (mean_x * mean_y)) + + +def test_expanding_consistency_mean(all_data, min_periods): + result = all_data.expanding(min_periods=min_periods).mean() + expected = ( + all_data.expanding(min_periods=min_periods).sum() + / all_data.expanding(min_periods=min_periods).count() + ) + tm.assert_equal(result, expected.astype("float64")) + + +def test_expanding_consistency_constant(consistent_data, min_periods): + count_x = consistent_data.expanding().count() + mean_x = consistent_data.expanding(min_periods=min_periods).mean() + # check that correlation of a series with itself is either 1 or NaN + corr_x_x = consistent_data.expanding(min_periods=min_periods).corr(consistent_data) + + exp = ( + consistent_data.max() + if isinstance(consistent_data, Series) + else consistent_data.max().max() + ) + + # check mean of constant series + expected = consistent_data * np.nan + expected[count_x >= max(min_periods, 1)] = exp + tm.assert_equal(mean_x, expected) + + # check correlation of constant series with itself is NaN + expected[:] = np.nan + tm.assert_equal(corr_x_x, expected) + + +def test_expanding_consistency_var_debiasing_factors(all_data, min_periods): + # check variance debiasing factors + var_unbiased_x = all_data.expanding(min_periods=min_periods).var() + var_biased_x = all_data.expanding(min_periods=min_periods).var(ddof=0) + var_debiasing_factors_x = all_data.expanding().count() / ( + all_data.expanding().count() - 1.0 + ).replace(0.0, np.nan) + tm.assert_equal(var_unbiased_x, var_biased_x * var_debiasing_factors_x) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/moments/test_moments_consistency_rolling.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/moments/test_moments_consistency_rolling.py new file mode 100644 index 0000000000000000000000000000000000000000..be22338c00cb28fb4fbd1bfe7f4b6163e239a432 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/window/moments/test_moments_consistency_rolling.py @@ -0,0 +1,244 @@ +import numpy as np +import pytest + +from pandas import Series +import pandas._testing as tm + + +def no_nans(x): + return x.notna().all().all() + + +def all_na(x): + return x.isnull().all().all() + + +@pytest.fixture(params=[(1, 0), (5, 1)]) +def rolling_consistency_cases(request): + """window, min_periods""" + return request.param + + +@pytest.mark.parametrize("f", [lambda v: Series(v).sum(), np.nansum, np.sum]) +def test_rolling_apply_consistency_sum( + request, all_data, rolling_consistency_cases, center, f +): + window, min_periods = rolling_consistency_cases + + if f is np.sum: + if not no_nans(all_data) and not ( + all_na(all_data) and not all_data.empty and min_periods > 0 + ): + request.applymarker( + pytest.mark.xfail(reason="np.sum has different behavior with NaNs") + ) + rolling_f_result = all_data.rolling( + window=window, min_periods=min_periods, center=center + ).sum() + rolling_apply_f_result = all_data.rolling( + window=window, min_periods=min_periods, center=center + ).apply(func=f, raw=True) + tm.assert_equal(rolling_f_result, rolling_apply_f_result) + + +@pytest.mark.parametrize("ddof", [0, 1]) +def test_moments_consistency_var(all_data, rolling_consistency_cases, center, ddof): + window, min_periods = rolling_consistency_cases + + var_x = all_data.rolling(window=window, min_periods=min_periods, center=center).var( + ddof=ddof + ) + assert not (var_x < 0).any().any() + + if ddof == 0: + # check that biased var(x) == mean(x^2) - mean(x)^2 + mean_x = all_data.rolling( + window=window, min_periods=min_periods, center=center + ).mean() + mean_x2 = ( + (all_data * all_data) + .rolling(window=window, min_periods=min_periods, center=center) + .mean() + ) + tm.assert_equal(var_x, mean_x2 - (mean_x * mean_x)) + + +@pytest.mark.parametrize("ddof", [0, 1]) +def test_moments_consistency_var_constant( + consistent_data, rolling_consistency_cases, center, ddof +): + window, min_periods = rolling_consistency_cases + + count_x = consistent_data.rolling( + window=window, min_periods=min_periods, center=center + ).count() + var_x = consistent_data.rolling( + window=window, min_periods=min_periods, center=center + ).var(ddof=ddof) + + # check that variance of constant series is identically 0 + assert not (var_x > 0).any().any() + expected = consistent_data * np.nan + expected[count_x >= max(min_periods, 1)] = 0.0 + if ddof == 1: + expected[count_x < 2] = np.nan + tm.assert_equal(var_x, expected) + + +@pytest.mark.parametrize("ddof", [0, 1]) +def test_rolling_consistency_var_std_cov( + all_data, rolling_consistency_cases, center, ddof +): + window, min_periods = rolling_consistency_cases + + var_x = all_data.rolling(window=window, min_periods=min_periods, center=center).var( + ddof=ddof + ) + assert not (var_x < 0).any().any() + + std_x = all_data.rolling(window=window, min_periods=min_periods, center=center).std( + ddof=ddof + ) + assert not (std_x < 0).any().any() + + # check that var(x) == std(x)^2 + tm.assert_equal(var_x, std_x * std_x) + + cov_x_x = all_data.rolling( + window=window, min_periods=min_periods, center=center + ).cov(all_data, ddof=ddof) + assert not (cov_x_x < 0).any().any() + + # check that var(x) == cov(x, x) + tm.assert_equal(var_x, cov_x_x) + + +@pytest.mark.parametrize("ddof", [0, 1]) +def test_rolling_consistency_series_cov_corr( + series_data, rolling_consistency_cases, center, ddof +): + window, min_periods = rolling_consistency_cases + + var_x_plus_y = ( + (series_data + series_data) + .rolling(window=window, min_periods=min_periods, center=center) + .var(ddof=ddof) + ) + var_x = series_data.rolling( + window=window, min_periods=min_periods, center=center + ).var(ddof=ddof) + var_y = series_data.rolling( + window=window, min_periods=min_periods, center=center + ).var(ddof=ddof) + cov_x_y = series_data.rolling( + window=window, min_periods=min_periods, center=center + ).cov(series_data, ddof=ddof) + # check that cov(x, y) == (var(x+y) - var(x) - + # var(y)) / 2 + tm.assert_equal(cov_x_y, 0.5 * (var_x_plus_y - var_x - var_y)) + + # check that corr(x, y) == cov(x, y) / (std(x) * + # std(y)) + corr_x_y = series_data.rolling( + window=window, min_periods=min_periods, center=center + ).corr(series_data) + std_x = series_data.rolling( + window=window, min_periods=min_periods, center=center + ).std(ddof=ddof) + std_y = series_data.rolling( + window=window, min_periods=min_periods, center=center + ).std(ddof=ddof) + tm.assert_equal(corr_x_y, cov_x_y / (std_x * std_y)) + + if ddof == 0: + # check that biased cov(x, y) == mean(x*y) - + # mean(x)*mean(y) + mean_x = series_data.rolling( + window=window, min_periods=min_periods, center=center + ).mean() + mean_y = series_data.rolling( + window=window, min_periods=min_periods, center=center + ).mean() + mean_x_times_y = ( + (series_data * series_data) + .rolling(window=window, min_periods=min_periods, center=center) + .mean() + ) + tm.assert_equal(cov_x_y, mean_x_times_y - (mean_x * mean_y)) + + +def test_rolling_consistency_mean(all_data, rolling_consistency_cases, center): + window, min_periods = rolling_consistency_cases + + result = all_data.rolling( + window=window, min_periods=min_periods, center=center + ).mean() + expected = ( + all_data.rolling(window=window, min_periods=min_periods, center=center) + .sum() + .divide( + all_data.rolling( + window=window, min_periods=min_periods, center=center + ).count() + ) + ) + tm.assert_equal(result, expected.astype("float64")) + + +def test_rolling_consistency_constant( + consistent_data, rolling_consistency_cases, center +): + window, min_periods = rolling_consistency_cases + + count_x = consistent_data.rolling( + window=window, min_periods=min_periods, center=center + ).count() + mean_x = consistent_data.rolling( + window=window, min_periods=min_periods, center=center + ).mean() + # check that correlation of a series with itself is either 1 or NaN + corr_x_x = consistent_data.rolling( + window=window, min_periods=min_periods, center=center + ).corr(consistent_data) + + exp = ( + consistent_data.max() + if isinstance(consistent_data, Series) + else consistent_data.max().max() + ) + + # check mean of constant series + expected = consistent_data * np.nan + expected[count_x >= max(min_periods, 1)] = exp + tm.assert_equal(mean_x, expected) + + # check correlation of constant series with itself is NaN + expected[:] = np.nan + tm.assert_equal(corr_x_x, expected) + + +def test_rolling_consistency_var_debiasing_factors( + all_data, rolling_consistency_cases, center +): + window, min_periods = rolling_consistency_cases + + # check variance debiasing factors + var_unbiased_x = all_data.rolling( + window=window, min_periods=min_periods, center=center + ).var() + var_biased_x = all_data.rolling( + window=window, min_periods=min_periods, center=center + ).var(ddof=0) + var_debiasing_factors_x = ( + all_data.rolling(window=window, min_periods=min_periods, center=center) + .count() + .divide( + ( + all_data.rolling( + window=window, min_periods=min_periods, center=center + ).count() + - 1.0 + ).replace(0.0, np.nan) + ) + ) + tm.assert_equal(var_unbiased_x, var_biased_x * var_debiasing_factors_x) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0cfa7f93fe2a3c8ecb3aa6544f2f52b716bf345d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/__pycache__/_decorators.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/__pycache__/_decorators.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a247d5893db45b818fc8b9c1e26fae49d252bbc7 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/__pycache__/_decorators.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/__pycache__/_doctools.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/__pycache__/_doctools.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4dc9efaf780edb20896f78950c8f905a24603065 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/__pycache__/_doctools.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/__pycache__/_exceptions.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/__pycache__/_exceptions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b28f73ca27d34b1f4eac7db2ffb9a478e5af5a4 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/__pycache__/_exceptions.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/__pycache__/_print_versions.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/__pycache__/_print_versions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b29771dd48cfc94d44001aebd3f27c82cf8139bc Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/__pycache__/_print_versions.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/__pycache__/_test_decorators.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/__pycache__/_test_decorators.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b35f9ba516fc0bb1301278795f3d6e72060f811d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/__pycache__/_test_decorators.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/__pycache__/_tester.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/__pycache__/_tester.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc95b392a75c037cacf4bc0c32519548e551caf7 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/__pycache__/_tester.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/__pycache__/_validators.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/__pycache__/_validators.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..813a9704adecdbc512e19de5f8ef4dcfc25a8b03 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/__pycache__/_validators.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/version/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/version/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3a5efbbb09c1e5085bf0564d819be345edd8c827 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/version/__init__.py @@ -0,0 +1,579 @@ +# Vendored from https://github.com/pypa/packaging/blob/main/packaging/_structures.py +# and https://github.com/pypa/packaging/blob/main/packaging/_structures.py +# changeset ae891fd74d6dd4c6063bb04f2faeadaac6fc6313 +# 04/30/2021 + +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. Licence at LICENSES/PACKAGING_LICENSE +from __future__ import annotations + +import collections +from collections.abc import Iterator +import itertools +import re +from typing import ( + Callable, + SupportsInt, + Tuple, + Union, +) +import warnings + +__all__ = ["parse", "Version", "LegacyVersion", "InvalidVersion", "VERSION_PATTERN"] + + +class InfinityType: + def __repr__(self) -> str: + return "Infinity" + + def __hash__(self) -> int: + return hash(repr(self)) + + def __lt__(self, other: object) -> bool: + return False + + def __le__(self, other: object) -> bool: + return False + + def __eq__(self, other: object) -> bool: + return isinstance(other, type(self)) + + def __ne__(self, other: object) -> bool: + return not isinstance(other, type(self)) + + def __gt__(self, other: object) -> bool: + return True + + def __ge__(self, other: object) -> bool: + return True + + def __neg__(self: object) -> NegativeInfinityType: + return NegativeInfinity + + +Infinity = InfinityType() + + +class NegativeInfinityType: + def __repr__(self) -> str: + return "-Infinity" + + def __hash__(self) -> int: + return hash(repr(self)) + + def __lt__(self, other: object) -> bool: + return True + + def __le__(self, other: object) -> bool: + return True + + def __eq__(self, other: object) -> bool: + return isinstance(other, type(self)) + + def __ne__(self, other: object) -> bool: + return not isinstance(other, type(self)) + + def __gt__(self, other: object) -> bool: + return False + + def __ge__(self, other: object) -> bool: + return False + + def __neg__(self: object) -> InfinityType: + return Infinity + + +NegativeInfinity = NegativeInfinityType() + + +InfiniteTypes = Union[InfinityType, NegativeInfinityType] +PrePostDevType = Union[InfiniteTypes, tuple[str, int]] +SubLocalType = Union[InfiniteTypes, int, str] +LocalType = Union[ + NegativeInfinityType, + tuple[ + Union[ + SubLocalType, + tuple[SubLocalType, str], + tuple[NegativeInfinityType, SubLocalType], + ], + ..., + ], +] +CmpKey = tuple[ + int, tuple[int, ...], PrePostDevType, PrePostDevType, PrePostDevType, LocalType +] +LegacyCmpKey = tuple[int, tuple[str, ...]] +VersionComparisonMethod = Callable[ + [Union[CmpKey, LegacyCmpKey], Union[CmpKey, LegacyCmpKey]], bool +] + +_Version = collections.namedtuple( + "_Version", ["epoch", "release", "dev", "pre", "post", "local"] +) + + +def parse(version: str) -> LegacyVersion | Version: + """ + Parse the given version string and return either a :class:`Version` object + or a :class:`LegacyVersion` object depending on if the given version is + a valid PEP 440 version or a legacy version. + """ + try: + return Version(version) + except InvalidVersion: + return LegacyVersion(version) + + +class InvalidVersion(ValueError): + """ + An invalid version was found, users should refer to PEP 440. + + Examples + -------- + >>> pd.util.version.Version('1.') + Traceback (most recent call last): + InvalidVersion: Invalid version: '1.' + """ + + +class _BaseVersion: + _key: CmpKey | LegacyCmpKey + + def __hash__(self) -> int: + return hash(self._key) + + # Please keep the duplicated `isinstance` check + # in the six comparisons hereunder + # unless you find a way to avoid adding overhead function calls. + def __lt__(self, other: _BaseVersion) -> bool: + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key < other._key + + def __le__(self, other: _BaseVersion) -> bool: + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key <= other._key + + def __eq__(self, other: object) -> bool: + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key == other._key + + def __ge__(self, other: _BaseVersion) -> bool: + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key >= other._key + + def __gt__(self, other: _BaseVersion) -> bool: + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key > other._key + + def __ne__(self, other: object) -> bool: + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key != other._key + + +class LegacyVersion(_BaseVersion): + def __init__(self, version: str) -> None: + self._version = str(version) + self._key = _legacy_cmpkey(self._version) + + warnings.warn( + "Creating a LegacyVersion has been deprecated and will be " + "removed in the next major release.", + DeprecationWarning, + ) + + def __str__(self) -> str: + return self._version + + def __repr__(self) -> str: + return f"" + + @property + def public(self) -> str: + return self._version + + @property + def base_version(self) -> str: + return self._version + + @property + def epoch(self) -> int: + return -1 + + @property + def release(self) -> None: + return None + + @property + def pre(self) -> None: + return None + + @property + def post(self) -> None: + return None + + @property + def dev(self) -> None: + return None + + @property + def local(self) -> None: + return None + + @property + def is_prerelease(self) -> bool: + return False + + @property + def is_postrelease(self) -> bool: + return False + + @property + def is_devrelease(self) -> bool: + return False + + +_legacy_version_component_re = re.compile(r"(\d+ | [a-z]+ | \.| -)", re.VERBOSE) + +_legacy_version_replacement_map = { + "pre": "c", + "preview": "c", + "-": "final-", + "rc": "c", + "dev": "@", +} + + +def _parse_version_parts(s: str) -> Iterator[str]: + for part in _legacy_version_component_re.split(s): + mapped_part = _legacy_version_replacement_map.get(part, part) + + if not mapped_part or mapped_part == ".": + continue + + if mapped_part[:1] in "0123456789": + # pad for numeric comparison + yield mapped_part.zfill(8) + else: + yield "*" + mapped_part + + # ensure that alpha/beta/candidate are before final + yield "*final" + + +def _legacy_cmpkey(version: str) -> LegacyCmpKey: + # We hardcode an epoch of -1 here. A PEP 440 version can only have a epoch + # greater than or equal to 0. This will effectively put the LegacyVersion, + # which uses the defacto standard originally implemented by setuptools, + # as before all PEP 440 versions. + epoch = -1 + + # This scheme is taken from pkg_resources.parse_version setuptools prior to + # it's adoption of the packaging library. + parts: list[str] = [] + for part in _parse_version_parts(version.lower()): + if part.startswith("*"): + # remove "-" before a prerelease tag + if part < "*final": + while parts and parts[-1] == "*final-": + parts.pop() + + # remove trailing zeros from each series of numeric parts + while parts and parts[-1] == "00000000": + parts.pop() + + parts.append(part) + + return epoch, tuple(parts) + + +# Deliberately not anchored to the start and end of the string, to make it +# easier for 3rd party code to reuse +VERSION_PATTERN = r""" + v? + (?: + (?:(?P[0-9]+)!)? # epoch + (?P[0-9]+(?:\.[0-9]+)*) # release segment + (?P
                                          # pre-release
+            [-_\.]?
+            (?P(a|b|c|rc|alpha|beta|pre|preview))
+            [-_\.]?
+            (?P[0-9]+)?
+        )?
+        (?P                                         # post release
+            (?:-(?P[0-9]+))
+            |
+            (?:
+                [-_\.]?
+                (?Ppost|rev|r)
+                [-_\.]?
+                (?P[0-9]+)?
+            )
+        )?
+        (?P                                          # dev release
+            [-_\.]?
+            (?Pdev)
+            [-_\.]?
+            (?P[0-9]+)?
+        )?
+    )
+    (?:\+(?P[a-z0-9]+(?:[-_\.][a-z0-9]+)*))?       # local version
+"""
+
+
+class Version(_BaseVersion):
+    _regex = re.compile(r"^\s*" + VERSION_PATTERN + r"\s*$", re.VERBOSE | re.IGNORECASE)
+
+    def __init__(self, version: str) -> None:
+        # Validate the version and parse it into pieces
+        match = self._regex.search(version)
+        if not match:
+            raise InvalidVersion(f"Invalid version: '{version}'")
+
+        # Store the parsed out pieces of the version
+        self._version = _Version(
+            epoch=int(match.group("epoch")) if match.group("epoch") else 0,
+            release=tuple(int(i) for i in match.group("release").split(".")),
+            pre=_parse_letter_version(match.group("pre_l"), match.group("pre_n")),
+            post=_parse_letter_version(
+                match.group("post_l"), match.group("post_n1") or match.group("post_n2")
+            ),
+            dev=_parse_letter_version(match.group("dev_l"), match.group("dev_n")),
+            local=_parse_local_version(match.group("local")),
+        )
+
+        # Generate a key which will be used for sorting
+        self._key = _cmpkey(
+            self._version.epoch,
+            self._version.release,
+            self._version.pre,
+            self._version.post,
+            self._version.dev,
+            self._version.local,
+        )
+
+    def __repr__(self) -> str:
+        return f""
+
+    def __str__(self) -> str:
+        parts = []
+
+        # Epoch
+        if self.epoch != 0:
+            parts.append(f"{self.epoch}!")
+
+        # Release segment
+        parts.append(".".join([str(x) for x in self.release]))
+
+        # Pre-release
+        if self.pre is not None:
+            parts.append("".join([str(x) for x in self.pre]))
+
+        # Post-release
+        if self.post is not None:
+            parts.append(f".post{self.post}")
+
+        # Development release
+        if self.dev is not None:
+            parts.append(f".dev{self.dev}")
+
+        # Local version segment
+        if self.local is not None:
+            parts.append(f"+{self.local}")
+
+        return "".join(parts)
+
+    @property
+    def epoch(self) -> int:
+        _epoch: int = self._version.epoch
+        return _epoch
+
+    @property
+    def release(self) -> tuple[int, ...]:
+        _release: tuple[int, ...] = self._version.release
+        return _release
+
+    @property
+    def pre(self) -> tuple[str, int] | None:
+        _pre: tuple[str, int] | None = self._version.pre
+        return _pre
+
+    @property
+    def post(self) -> int | None:
+        return self._version.post[1] if self._version.post else None
+
+    @property
+    def dev(self) -> int | None:
+        return self._version.dev[1] if self._version.dev else None
+
+    @property
+    def local(self) -> str | None:
+        if self._version.local:
+            return ".".join([str(x) for x in self._version.local])
+        else:
+            return None
+
+    @property
+    def public(self) -> str:
+        return str(self).split("+", 1)[0]
+
+    @property
+    def base_version(self) -> str:
+        parts = []
+
+        # Epoch
+        if self.epoch != 0:
+            parts.append(f"{self.epoch}!")
+
+        # Release segment
+        parts.append(".".join([str(x) for x in self.release]))
+
+        return "".join(parts)
+
+    @property
+    def is_prerelease(self) -> bool:
+        return self.dev is not None or self.pre is not None
+
+    @property
+    def is_postrelease(self) -> bool:
+        return self.post is not None
+
+    @property
+    def is_devrelease(self) -> bool:
+        return self.dev is not None
+
+    @property
+    def major(self) -> int:
+        return self.release[0] if len(self.release) >= 1 else 0
+
+    @property
+    def minor(self) -> int:
+        return self.release[1] if len(self.release) >= 2 else 0
+
+    @property
+    def micro(self) -> int:
+        return self.release[2] if len(self.release) >= 3 else 0
+
+
+def _parse_letter_version(
+    letter: str, number: str | bytes | SupportsInt
+) -> tuple[str, int] | None:
+    if letter:
+        # We consider there to be an implicit 0 in a pre-release if there is
+        # not a numeral associated with it.
+        if number is None:
+            number = 0
+
+        # We normalize any letters to their lower case form
+        letter = letter.lower()
+
+        # We consider some words to be alternate spellings of other words and
+        # in those cases we want to normalize the spellings to our preferred
+        # spelling.
+        if letter == "alpha":
+            letter = "a"
+        elif letter == "beta":
+            letter = "b"
+        elif letter in ["c", "pre", "preview"]:
+            letter = "rc"
+        elif letter in ["rev", "r"]:
+            letter = "post"
+
+        return letter, int(number)
+    if not letter and number:
+        # We assume if we are given a number, but we are not given a letter
+        # then this is using the implicit post release syntax (e.g. 1.0-1)
+        letter = "post"
+
+        return letter, int(number)
+
+    return None
+
+
+_local_version_separators = re.compile(r"[\._-]")
+
+
+def _parse_local_version(local: str) -> LocalType | None:
+    """
+    Takes a string like abc.1.twelve and turns it into ("abc", 1, "twelve").
+    """
+    if local is not None:
+        return tuple(
+            part.lower() if not part.isdigit() else int(part)
+            for part in _local_version_separators.split(local)
+        )
+    return None
+
+
+def _cmpkey(
+    epoch: int,
+    release: tuple[int, ...],
+    pre: tuple[str, int] | None,
+    post: tuple[str, int] | None,
+    dev: tuple[str, int] | None,
+    local: tuple[SubLocalType] | None,
+) -> CmpKey:
+    # When we compare a release version, we want to compare it with all of the
+    # trailing zeros removed. So we'll use a reverse the list, drop all the now
+    # leading zeros until we come to something non zero, then take the rest
+    # re-reverse it back into the correct order and make it a tuple and use
+    # that for our sorting key.
+    _release = tuple(
+        reversed(list(itertools.dropwhile(lambda x: x == 0, reversed(release))))
+    )
+
+    # We need to "trick" the sorting algorithm to put 1.0.dev0 before 1.0a0.
+    # We'll do this by abusing the pre segment, but we _only_ want to do this
+    # if there is not a pre or a post segment. If we have one of those then
+    # the normal sorting rules will handle this case correctly.
+    if pre is None and post is None and dev is not None:
+        _pre: PrePostDevType = NegativeInfinity
+    # Versions without a pre-release (except as noted above) should sort after
+    # those with one.
+    elif pre is None:
+        _pre = Infinity
+    else:
+        _pre = pre
+
+    # Versions without a post segment should sort before those with one.
+    if post is None:
+        _post: PrePostDevType = NegativeInfinity
+
+    else:
+        _post = post
+
+    # Versions without a development segment should sort after those with one.
+    if dev is None:
+        _dev: PrePostDevType = Infinity
+
+    else:
+        _dev = dev
+
+    if local is None:
+        # Versions without a local segment should sort before those with one.
+        _local: LocalType = NegativeInfinity
+    else:
+        # Versions with a local segment need that segment parsed to implement
+        # the sorting rules in PEP440.
+        # - Alpha numeric segments sort before numeric segments
+        # - Alpha numeric segments sort lexicographically
+        # - Numeric segments sort numerically
+        # - Shorter versions sort before longer versions when the prefixes
+        #   match exactly
+        _local = tuple(
+            (i, "") if isinstance(i, int) else (NegativeInfinity, i) for i in local
+        )
+
+    return epoch, _release, _pre, _post, _dev, _local
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/version/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/version/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..524368a50e7398b92547ab736b4476ba62342068
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/version/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_awaits/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_awaits/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..14433803b70fef84970abb65689ab98b2b2df510
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_awaits/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dispatch/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dispatch/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cdd0cee3091fedaa56eda4e5d51118ef8d14fd96
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dispatch/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dispatch/__pycache__/python.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dispatch/__pycache__/python.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1de97bed9794904661242654903e89d90bf4fb14
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dispatch/__pycache__/python.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dc05fc2db7b6ae223d824264e92b2b4a1a71f1de
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/__pycache__/case.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/__pycache__/case.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9111ea351d5d5e6215501888b1fb8e8e2008d2da
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/__pycache__/case.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/__pycache__/gen_example.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/__pycache__/gen_example.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4aca7f267df726032c73cf5cb7eba13e97cceaad
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/__pycache__/gen_example.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/__pycache__/logging.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/__pycache__/logging.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..40f902411fb10682064d002b2a7762e5500627c0
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/__pycache__/logging.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..856236529635f8a223c4900ec13c96211507a49f
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/assume_constant_result.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/assume_constant_result.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a478a05dfc584fca476df13fadcd70177ee1979d
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/assume_constant_result.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/autograd_function.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/autograd_function.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5fcb78b38a422d2d589d97e2ae22da6e19ed8128
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/autograd_function.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/class_method.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/class_method.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f908d806262ab7292ed1ecf467927f615ac08c9f
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/class_method.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/cond_branch_class_method.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/cond_branch_class_method.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..479640ff3cacea0616138eaf9faa2ead30802b6f
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/cond_branch_class_method.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nested_function.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nested_function.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..95111524260c269cd75874edceba87ab32dac9bb
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nested_function.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nonlocal_variables.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nonlocal_variables.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4f843a944054f9f0c3eec5404f4d3e24d97b7f12
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nonlocal_variables.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/cond_closed_over_variable.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/cond_closed_over_variable.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..77bbda22e43a08aacb27c92a651f4eb7c764d5f9
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/cond_closed_over_variable.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/cond_operands.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/cond_operands.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e34a4a3aaf4b18a2f0ad0406f23921c3c28ae093
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/cond_operands.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/cond_predicate.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/cond_predicate.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ccacc06dac88f57e25278327bb1482a131e3fb96
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/cond_predicate.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/constrain_as_size_example.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/constrain_as_size_example.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..decc03dc84b4872d40acdc1a8e4ab243f18d3d81
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/constrain_as_size_example.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/constrain_as_value_example.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/constrain_as_value_example.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5404948c90d8f26fc5971df8aaa535d897a3a676
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/constrain_as_value_example.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/decorator.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/decorator.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..303bf791d939f8e66d43ec0f12d07e3af3383d4f
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/decorator.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dictionary.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dictionary.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a1779cff1e20558a7fca587b9e86d52dfe4b9594
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dictionary.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_assert.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_assert.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6038a6d578a8070538c7321b1b0b1197d1339326
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_assert.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_constructor.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_constructor.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e928189bc028beb0ac79638fa276c98eb360a29b
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_constructor.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_if_guard.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_if_guard.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..48062c95025f385ced3c659f45f83ee8c874a97d
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_if_guard.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_map.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_map.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..05061b3dedeb8e434075bb51d793a97f9f725804
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_map.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_round.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_round.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eb4df72162b8eff4801460e6684b06ad6f3a70ce
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_round.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_slicing.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_slicing.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..31c41f93ea82f30af549faae793103d6318a1861
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_slicing.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_view.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_view.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..80d528c9fc265e9349ab3d5b1abadf587b018143
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_view.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/fn_with_kwargs.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/fn_with_kwargs.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5ac41b19bd737c1896aeebf1b97c27286dd3880b
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/fn_with_kwargs.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/list_contains.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/list_contains.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5e9304c673b042c5201799135f97609f982d9941
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/list_contains.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/list_unpack.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/list_unpack.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1cbde169206312b0273510472ed240c6dba76c94
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/list_unpack.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/model_attr_mutation.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/model_attr_mutation.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3c657d10d6597cbe654b33e41b0a6d0952164f75
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/model_attr_mutation.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/nested_function.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/nested_function.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1ab3f6c051606dba249d2cb00be31a023e640402
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/nested_function.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/null_context_manager.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/null_context_manager.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..75243ae726ae0aaadb0133625d22bfeccaf0ba26
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/null_context_manager.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/optional_input.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/optional_input.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6145a99a9de161da2764838eca3484f77ab99d2f
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/optional_input.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/pytree_flatten.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/pytree_flatten.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ac3f2175699b2604dd8701a78bfce0038ddc1e78
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/pytree_flatten.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/scalar_output.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/scalar_output.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a3fe2ba90c229e1ddfbc131a2c7cc97c6a4c6541
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/scalar_output.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/specialized_attribute.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/specialized_attribute.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..12d2f60395088a2841d886befa1d535945e97d62
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/specialized_attribute.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/static_for_loop.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/static_for_loop.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4ac992b5162d899befeec2817b62ee827459e682
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/static_for_loop.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/static_if.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/static_if.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ed121e4c5780c572f22cc85b02abda28e04a2bff
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/static_if.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/tensor_setattr.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/tensor_setattr.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..42ace10c4b16add47e88b5a4d597682e27c8b5f7
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/tensor_setattr.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/type_reflection_method.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/type_reflection_method.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f04d0874d4471327b0fa497b9f2496553193dd9e
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/type_reflection_method.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/unsupported_operator.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/unsupported_operator.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..141e849777f9e66d9a0a4b893c9fbd59b15f74cb
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/unsupported_operator.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/user_input_mutation.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/user_input_mutation.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3cfca9f0b541db734aebf4a160cb99ef5557aa26
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/user_input_mutation.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/dynamic_shape_view.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/dynamic_shape_view.py
new file mode 100644
index 0000000000000000000000000000000000000000..c45d4aeebb0282a0f56c58a587b4bfe1655f50e3
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/dynamic_shape_view.py
@@ -0,0 +1,17 @@
+# mypy: allow-untyped-defs
+import torch
+
+class DynamicShapeView(torch.nn.Module):
+    """
+    Dynamic shapes should be propagated to view arguments instead of being
+    baked into the exported graph.
+    """
+
+    def forward(self, x):
+        new_x_shape = x.size()[:-1] + (2, 5)
+        x = x.view(*new_x_shape)
+        return x.permute(0, 2, 1)
+
+example_args = (torch.randn(10, 10),)
+tags = {"torch.dynamic-shape"}
+model = DynamicShapeView()
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/scalar_output.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/scalar_output.py
new file mode 100644
index 0000000000000000000000000000000000000000..86d3b4645330c47c3625736b695d635f4ab58c70
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/scalar_output.py
@@ -0,0 +1,23 @@
+# mypy: allow-untyped-defs
+import torch
+
+from torch.export import Dim
+
+x = torch.randn(3, 2)
+dim1_x = Dim("dim1_x")
+
+class ScalarOutput(torch.nn.Module):
+    """
+    Returning scalar values from the graph is supported, in addition to Tensor
+    outputs. Symbolic shapes are captured and rank is specialized.
+    """
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(self, x):
+        return x.shape[1] + 1
+
+example_args = (x,)
+tags = {"torch.dynamic-shape"}
+dynamic_shapes = {"x": {1: dim1_x}}
+model = ScalarOutput()
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/user_input_mutation.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/user_input_mutation.py
new file mode 100644
index 0000000000000000000000000000000000000000..3156b3a1bf2ec6f6361395de3dacb098ecf20c3a
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/user_input_mutation.py
@@ -0,0 +1,17 @@
+# mypy: allow-untyped-defs
+import torch
+
+
+class UserInputMutation(torch.nn.Module):
+    """
+    Directly mutate user input in forward
+    """
+
+    def forward(self, x):
+        x.mul_(2)
+        return x.cos()
+
+
+example_args = (torch.randn(3, 2),)
+tags = {"torch.mutation"}
+model = UserInputMutation()
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/serde/__pycache__/dynamic_shapes.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/serde/__pycache__/dynamic_shapes.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7b7d7c4c54c395e4d1243654e045920076e616c0
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/serde/__pycache__/dynamic_shapes.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/serde/__pycache__/schema.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/serde/__pycache__/schema.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7e9dbbadf7f0afd830b30dc7bb010f3c2d4dc060
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/serde/__pycache__/schema.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/serde/__pycache__/schema_check.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/serde/__pycache__/schema_check.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cf9ca496d7ca31845a8af46e11d2bfb50b547d26
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/serde/__pycache__/schema_check.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/analyze_preserves_zero_mask.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/analyze_preserves_zero_mask.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dcca5017b15397a4977a82fd900b6f9aa244cfd5
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/analyze_preserves_zero_mask.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/async_compile.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/async_compile.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cbe111c3d79df277bfeb7b2ce4229d62be7da91c
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/async_compile.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/augmented_graph_helper.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/augmented_graph_helper.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e68ecb86d99640f47f434679dd9378eb4b0c122e
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/augmented_graph_helper.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/choices.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/choices.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ced54b20f258a48ea4e569d6229abeca45d37e3c
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/choices.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/compile_fx_async.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/compile_fx_async.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b85b647755b529da54849fdc906a329082c7ba66
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/compile_fx_async.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/compile_fx_ext.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/compile_fx_ext.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..43787171ed1201a71c7dba01dcce3f9c83370185
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/compile_fx_ext.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/config.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/config.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..08a48963e3c3ea3d77a5cf2dc59012b3b9eec89f
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/config.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/cudagraph_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/cudagraph_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4732f8d8c8bfe998e40b43fdd7f12dc720208965
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/cudagraph_utils.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/freezing_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/freezing_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0c9a738af2ce222c3c9f365090665b26f944e6c1
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/freezing_utils.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/fx_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/fx_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5e1649c0cba0b31ad183434e30b7d56aad159118
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/fx_utils.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/index_propagation.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/index_propagation.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..346ba3b2d5404deff502991b65eadc335f4d5a86
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/index_propagation.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/inductor_prims.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/inductor_prims.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..302ab3cd0f2752426ab367191358daf10420a1d3
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/inductor_prims.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/remote_cache.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/remote_cache.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5d40434fea6442b3462bb2f7c42f6e6f762a1f5f
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/remote_cache.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/remote_gemm_autotune_cache.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/remote_gemm_autotune_cache.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ad5440cf42bac96e1eb1b9585de2879f783239c1
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/remote_gemm_autotune_cache.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/sizevars.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/sizevars.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4aa4c4a336cad9c28b864caadd934b76cc61c3cb
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/sizevars.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/standalone_compile.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/standalone_compile.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a43dd13ed7036a7881c42381d789be13e95e92af
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/standalone_compile.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/subgraph_lowering.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/subgraph_lowering.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..79cdceab9fd2c7d1b44d8e9cc6b9e6eeb3ac6161
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/subgraph_lowering.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/triton_bundler.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/triton_bundler.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cd1de893c420aec5342734b1a797d2351f16b4f3
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/triton_bundler.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/analysis/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/analysis/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/analysis/device_info.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/analysis/device_info.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d5edf1e7fd26d3f902d15af82a3d0c615d20c6f
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/analysis/device_info.py
@@ -0,0 +1,216 @@
+import logging
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import torch
+
+
+log = logging.getLogger(__name__)
+
+
+@dataclass(frozen=True)
+class DeviceInfo:
+    """
+    Theoretical Numbers from data sheet. If two numbers are given, Tensor/Matrix Core vs not,
+    then the higher number is reported. Sparsity is not considered.
+
+
+    Bandwidth numbers are tricky, because there are platform differences that may not show up in the profiler trace.
+    For example,
+    """
+
+    tops: dict[Union[torch.dtype, str], float]
+    dram_bw_gbs: float
+    dram_gb: float
+
+
+# Indexing is based on `torch.cuda.get_device_name()`
+# TODO investigate profiler support for tf32 and allow device to report correct number when it's turned on.
+_device_mapping: dict[str, DeviceInfo] = {
+    # Source:
+    # @lint-ignore https://www.nvidia.com/en-us/data-center/h100/
+    "NVIDIA H100": DeviceInfo(
+        tops={
+            torch.float64: 67.0,
+            torch.float32: 67.5,
+            "torch.tf32": 156.0,
+            torch.bfloat16: 1979.0,
+            torch.float16: 1979.0,
+            torch.float8_e8m0fnu: 3958.0,
+            torch.float8_e8m0fnu: 3958.0,
+            torch.float8_e4m3fnuz: 3958.0,
+            torch.float8_e5m2: 3958.0,
+            torch.float8_e5m2fnuz: 3958.0,
+            torch.float8_e8m0fnu: 3958.0,
+            torch.int8: 3958.0,
+        },
+        dram_bw_gbs=3350,
+        dram_gb=80,
+    ),
+    # Source:
+    # @lint-ignore https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/
+    # nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
+    "NVIDIA A100": DeviceInfo(
+        tops={
+            torch.float64: 19.5,
+            torch.float32: 19.5,
+            torch.bfloat16: 312.5,
+            torch.float16: 312.5,
+            # Not in datasheet: float8
+            torch.int8: 624.0,
+            "torch.tf32": 156.0,
+        },
+        dram_bw_gbs=2039.0,
+        dram_gb=80.0,
+    ),
+    # Source:
+    # @lint-ignore https://resources.nvidia.com/en-us-gpu-resources/l4-tensor-datasheet
+    "NVIDIA L4": DeviceInfo(
+        tops={
+            # This is a guess, not in datasheet
+            torch.float64: 15.1,
+            torch.float32: 30.3,
+            "torch.tf32": 120.0,
+            torch.bfloat16: 242.0,
+            torch.float16: 242.0,
+            torch.float8_e8m0fnu: 485.0,
+            torch.float8_e8m0fnu: 485.0,
+            torch.float8_e4m3fnuz: 485.0,
+            torch.float8_e5m2: 485.0,
+            torch.float8_e5m2fnuz: 485.0,
+            torch.float8_e8m0fnu: 485.0,
+            torch.int8: 485.0,
+        },
+        dram_bw_gbs=3350,
+        dram_gb=24,
+    ),
+    # Source:
+    # @lint-ignore https://www.amd.com/content/dam/amd/en/documents\
+    # /instinct-tech-docs/product-briefs/amd-instinct-mi350x-gpu-brochure.pdf
+    "AMD MI350X": DeviceInfo(
+        tops={
+            torch.float64: 72.1,
+            torch.float32: 144.2,
+            # not specified, fall back to float32 numbers
+            "torch.tf32": 144.2,
+            torch.bfloat16: 2309.6,
+            torch.float16: 2309.6,
+            torch.float8_e8m0fnu: 4614.0,
+            torch.float8_e8m0fnu: 4614.0,
+            torch.float8_e4m3fnuz: 4614.0,
+            torch.float8_e5m2: 4614.0,
+            torch.float8_e5m2fnuz: 4614.0,
+            torch.float8_e8m0fnu: 4614.0,
+            torch.int8: 4614.0,
+        },
+        dram_bw_gbs=8000.0,
+        dram_gb=288.0,
+    ),
+    # Source:
+    # @lint-ignore https://www.amd.com/content/dam/amd/en/documents\
+    # /instinct-tech-docs/data-sheets/amd-instinct-mi300a-data-sheet.pdf
+    "AMD MI300A": DeviceInfo(
+        tops={
+            torch.float64: 122.6,
+            torch.float32: 122.6,
+            "torch.tf32": 490.3,
+            torch.bfloat16: 980.6,
+            torch.float16: 980.6,
+            torch.float8_e8m0fnu: 1961.2,
+            torch.float8_e8m0fnu: 1961.2,
+            torch.float8_e4m3fnuz: 1961.2,
+            torch.float8_e5m2: 1961.2,
+            torch.float8_e5m2fnuz: 1961.2,
+            torch.float8_e8m0fnu: 1961.2,
+            torch.int8: 1961.2,
+        },
+        dram_bw_gbs=5300.0,
+        dram_gb=128.0,
+    ),
+    # Source:
+    # @lint-ignore https://www.amd.com/content/dam/amd/en/documents/\
+    # instinct-tech-docs/data-sheets/amd-instinct-mi300x-data-sheet.pdf
+    "AMD MI300X": DeviceInfo(
+        tops={
+            torch.float64: 163.4,
+            torch.float32: 163.4,
+            "torch.tf32": 653.7,
+            torch.bfloat16: 1307.4,
+            torch.float16: 1307.4,
+            torch.float8_e8m0fnu: 2614.9,
+            torch.float8_e8m0fnu: 2614.9,
+            torch.float8_e4m3fnuz: 2614.9,
+            torch.float8_e5m2: 2614.9,
+            torch.float8_e5m2fnuz: 2614.9,
+            torch.float8_e8m0fnu: 2614.9,
+            torch.int8: 2614.9,
+        },
+        dram_bw_gbs=5300.0,
+        dram_gb=192.0,
+    ),
+    # Source:
+    # @lint-ignore https://www.amd.com/content/dam/amd/\
+    # en/documents/instinct-business-docs/product-briefs/instinct-mi210-brochure.pdf
+    "AMD MI210X": DeviceInfo(
+        tops={
+            torch.float64: 45.3,
+            torch.float32: 45.3,
+            # not specified, fall back to float32 numbers
+            "torch.tf32": 45.3,
+            torch.bfloat16: 181.0,
+            torch.float16: 181.0,
+            # not specified, fall back to float16 numbers
+            torch.float8_e8m0fnu: 181.0,
+            torch.float8_e8m0fnu: 181.0,
+            torch.float8_e4m3fnuz: 181.0,
+            torch.float8_e5m2: 181.0,
+            torch.float8_e5m2fnuz: 181.0,
+            torch.float8_e8m0fnu: 181.0,
+            torch.int8: 181.0,
+        },
+        # pcie4.0x16
+        dram_bw_gbs=1600.0,
+        dram_gb=64.0,
+    ),
+}
+_device_mapping["AMD INSTINCT MI350X"] = _device_mapping["AMD MI350X"]
+_device_mapping["AMD INSTINCT MI300X"] = _device_mapping["AMD MI300X"]
+_device_mapping["AMD INSTINCT MI210X"] = _device_mapping["AMD MI210X"]
+
+
+def lookup_device_info(name: str) -> Optional[DeviceInfo]:
+    """
+    Problem: when diffing profiles between amd and nvidia, we don't have access to the device information
+    of the other one. Also, since the analysis is static, we should be able to do it on another device unrelated
+    to the recorded device. Therefore, _device_mapping statically contains the information for lots of devices.
+    If one is missing, please run DeviceInfo.get_device_info() and add it to _device_mapping.
+      name (str): name of the device to lookup. Should map onto torch.cuda.get_device_name().
+    """
+    return _device_mapping.get(name)
+
+
+def datasheet_tops(dtype: torch.dtype, is_tf32: bool = False) -> Optional[float]:
+    """
+    Get the theoretical TFLOPS of the device for a given dtype. This can throw an exception if the device
+    is not in the datasheet list above.
+    """
+    name: Optional[str] = torch.cuda.get_device_name()
+    if name is None:
+        log.info("No device found, returning None")
+        return None
+    device_info = lookup_device_info(name)
+    if device_info is None:
+        log_str = f"Device {name} not in datasheet, returning None"
+        log.info(log_str)
+        return None
+    if dtype not in device_info.tops:
+        log.info(
+            "Device %s does not have a datasheet entry for %s, returning None",
+            name,
+            dtype,
+        )
+        return None
+
+    return device_info.tops[
+        "torch.tf32" if dtype == torch.float32 and is_tf32 else dtype
+    ]
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/analysis/profile_analysis.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/analysis/profile_analysis.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a6ec39003bdb2447b72c9aed892e1db01474be0
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/analysis/profile_analysis.py
@@ -0,0 +1,823 @@
+import json
+import logging
+import math
+from collections import defaultdict
+from collections.abc import Callable
+from dataclasses import dataclass
+from typing import Any, Optional, Union
+
+import torch
+from torch._inductor.analysis.device_info import DeviceInfo, lookup_device_info
+from torch._inductor.utils import tabulate_2d, zip_dicts
+from torch.utils import _pytree as pytree
+from torch.utils._ordered_set import OrderedSet
+from torch.utils.flop_counter import flop_registry
+
+
+log = logging.getLogger(__name__)
+
+
+ATEN_PREFIX = "aten::"
+
+
+@dataclass
+class ProfileEvent:
+    category: str
+    key: str
+    self_device_time_ms: float
+    # the benchmark is run multiple times and we average the count across all the
+    # runs. It should be an integer but define a float just in case.
+    count: float
+
+
+# adapters convert the json trace into a format that works with flops_counter
+ArgsType = tuple[tuple[Any, ...], dict[Any, Any]]
+AdapterType = Callable[[tuple[Any, ...], tuple[Any, ...]], ArgsType]
+adapters_map: dict[str, AdapterType] = {}
+
+
+def parse_list(lst: str) -> list[int]:
+    lst = lst.replace("[", "").replace("]", "")
+    substrings = lst.split(",")
+
+    return [int(substring.strip()) for substring in substrings]
+
+
+def register_adapter(
+    aten: Union[str, list[str]],
+) -> Callable[
+    [AdapterType],
+    AdapterType,
+]:
+    def decorator(func: AdapterType) -> AdapterType:
+        # pyrefly: ignore [unknown-name]
+        global _adapters_map
+
+        if isinstance(aten, str):
+            adapters_map[aten] = func
+        else:
+            for at in aten:
+                adapters_map[at] = func
+        return func
+
+    return decorator
+
+
+@register_adapter(["_slow_conv2d_forward"])
+def _slow_conv2d_adapter(
+    shapes: tuple[Any, ...], concrete: tuple[Any, ...]
+) -> tuple[tuple[Any], dict[Any, Any]]:
+    tmp = list(shapes)
+    tmp.append(False)
+    tmp2 = list(concrete)
+    if len(tmp2) < 5:
+        raise ParseException("slow conv2d has less than 5 concrete inputs")
+    tmp2[3] = tmp2[4]
+    return conv_adapter(tuple(tmp), tuple(tmp2))
+
+
+@register_adapter(
+    ["convolution", "_convolution", "cudnn_convolution", "convolution_overrideable"]
+)
+def conv_adapter(
+    shapes: tuple[Any, ...], concrete: tuple[Any, ...]
+) -> tuple[tuple[Any], dict[Any, Any]]:
+    tmp = list(shapes)
+    if len(tmp) == 4:
+        transposed = False
+    elif len(tmp) > 6:
+        transposed = bool(tmp[6])
+        tmp[6] = transposed
+    else:
+        raise ParseException(f"Convolution has the wrong number of inputs: {len(tmp)}")
+
+    kwargs: dict[Any, Any] = {}
+    if not transposed:
+        # calculate output shape if not transposed.
+        def conv_out_dims(x: int, kernel: int, stride: int) -> int:
+            return (x - kernel) // stride + 1
+
+        stride = parse_list(concrete[3])
+        inp = shapes[0]
+        w = shapes[1]
+        out_x_y = [conv_out_dims(*args) for args in zip(inp[2:], w[2:], stride)]
+        out = [inp[0], w[0]] + out_x_y  # we only need the xy values
+        kwargs["out_val"] = out
+
+    return tuple(tmp), kwargs
+
+
+def default_adapter(
+    shapes: tuple[Any], concrete: tuple[Any]
+) -> tuple[tuple[Any], dict[Any, Any]]:
+    return shapes, {}
+
+
+@register_adapter("addmm")
+def addmm_adapter(
+    shapes: tuple[Any], concrete: tuple[Any]
+) -> tuple[tuple[Any], dict[Any, Any]]:
+    tmp = list(shapes)[:3]
+    return tuple(tmp), {}
+
+
+@register_adapter("bmm")
+def bmm_adapter(
+    shapes: tuple[Any], concrete: tuple[Any]
+) -> tuple[tuple[Any], dict[Any, Any]]:
+    tmp = list(shapes)
+    return tuple(tmp[:2]), {}
+
+
+@register_adapter("baddbmm")
+def baddbmm_adapter(
+    shapes: tuple[Any], concrete: tuple[Any]
+) -> tuple[tuple[Any], dict[Any, Any]]:
+    tmp = list(shapes)[:3]
+    return tuple(tmp), {}
+
+
+@register_adapter("mm")
+def mm_adapter(
+    shapes: tuple[Any], concrete: tuple[Any]
+) -> tuple[tuple[Any], dict[Any, Any]]:
+    return shapes, {}
+
+
+def _parse_kernel_name(name: str) -> Optional[str]:
+    """
+    parse the name of the kernel from the event name.
+    """
+    if name.startswith(ATEN_PREFIX):
+        return name[len(ATEN_PREFIX) :]
+    elif "conv" in name:
+        return "convolution"
+    elif "addmm" in name:
+        return "addmm"
+    elif "bmm" in name:
+        return "bmm"
+    elif "baddbmm" in name:
+        return "baddbmm"
+    elif "_mm" in name:
+        return "mm"
+    else:
+        return None
+
+
+def _calculate_flops(event: dict[str, Any]) -> int:
+    """
+    This function has to parse the kernel name, which is error prone. There doesn't seem to be another solution that
+    will support all the different backends that can generate kernels, so make sure to update this function when new
+    ops and backends are desired.
+    """
+    name = event["name"]
+    if "kernel_flop" in event["args"] and event["args"]["kernel_flop"] != 0:
+        return event["args"]["kernel_flop"]
+    op_name = _parse_kernel_name(name)
+    if op_name is None:
+        return 0
+
+    op_obj = getattr(torch.ops.aten, op_name, None)
+    if op_obj is None or op_obj not in flop_registry:
+        return 0
+
+    flop_function = flop_registry[op_obj]
+
+    if "Input Dims" not in event["args"] or "Concrete Inputs" not in event["args"]:
+        return 0
+    input_shapes = event["args"]["Input Dims"]
+    concrete = event["args"]["Concrete Inputs"]
+    if op_name in adapters_map:
+        try:
+            args, kwargs = adapters_map[op_name](input_shapes, concrete)
+        except ParseException as e:
+            msg = f"Failed to parse {op_name} with {e}"
+            log.warning(msg)
+            return 0
+    else:
+        try:
+            args, kwargs = default_adapter(input_shapes, concrete)
+        except ParseException as e:
+            msg = f"Failed to parse {op_name} with {e}"
+            log.warning(msg)
+            return 0
+    return flop_function(*args, **kwargs)
+
+
+def _get_size_from_string(type_string: str) -> int:
+    if not hasattr(torch, type_string):
+        return 1
+    else:
+        return getattr(torch, type_string).itemsize
+
+
+def _default_estimate_gb(event: dict[str, Any]) -> float:
+    sizes_and_types = zip(event["args"]["Input Dims"], event["args"]["Input type"])
+    bw = 0
+    for size, typ in sizes_and_types:
+        isize = _get_size_from_string(typ)
+        bw += isize * math.prod(pytree.tree_flatten(size)[0])
+    return bw / 1e9
+
+
+def _estimate_gb(event: dict[str, Any]) -> float:
+    """
+    Our best effort to estimate the gb, should be refactored soon with MemoryCounter.
+    """
+    name = event["name"]
+    if "kernel_num_gb" in event["args"] and event["args"]["kernel_num_gb"] != 0:
+        return event["args"]["kernel_num_gb"]
+    if "Input type" not in event["args"] or "Input Dims" not in event["args"]:
+        return 0
+    op_name = _parse_kernel_name(name)
+    if op_name is None:
+        return _default_estimate_gb(event)
+
+    op_obj = getattr(torch.ops.aten, op_name, None)
+    if op_obj is None:
+        return _default_estimate_gb(event)
+
+    if "Input Dims" not in event["args"] or "Concrete Inputs" not in event["args"]:
+        return _default_estimate_gb(event)
+    input_shapes = event["args"]["Input Dims"]
+
+    # NOTE these will be refactored into a similar object to FlopCounter soon
+    def mm_formula(M: int, N: int, K: int, size: int) -> int:
+        return 2 * (M * K + N * K + M * N) * size
+
+    if op_name == "addmm":
+        add_in_size = math.prod(pytree.tree_flatten(input_shapes[0])[0])
+        add_type_size = _get_size_from_string(event["args"]["Input type"][0])
+        M = input_shapes[1][0]
+        N = input_shapes[1][1]
+        assert input_shapes[1][1] == input_shapes[2][0]
+        K = input_shapes[2][1]
+        mul_type_size = _get_size_from_string(event["args"]["Input type"][1])
+        return (mm_formula(M, N, K, mul_type_size) + add_in_size * add_type_size) / 1e9
+    elif op_name == "mm":
+        M = input_shapes[0][0]
+        N = input_shapes[0][1]
+        assert input_shapes[0][1] == input_shapes[1][0]
+        K = input_shapes[1][1]
+        type_size = _get_size_from_string(event["args"]["Input type"][0])
+        return mm_formula(M, N, K, type_size) / 1e9
+    elif op_name == "baddbmm":
+        add_in_size = math.prod(pytree.tree_flatten(input_shapes[0])[0])
+        add_type_size = _get_size_from_string(event["args"]["Input type"][0])
+        B = input_shapes[0][0]
+        M = input_shapes[1][1]
+        N = input_shapes[1][2]
+        K = input_shapes[2][2]
+        mul_type_size = _get_size_from_string(event["args"]["Input type"][1])
+        return (
+            B * mm_formula(M, N, K, mul_type_size) + add_in_size * add_type_size
+        ) / 1e9
+    elif op_name == "bmm":
+        add_in_size = math.prod(pytree.tree_flatten(input_shapes[0])[0])
+        add_type_size = _get_size_from_string(event["args"]["Input type"][0])
+        B = input_shapes[0][0]
+        M = input_shapes[0][1]
+        N = input_shapes[0][2]
+        K = input_shapes[1][2]
+        mul_type_size = _get_size_from_string(event["args"]["Input type"][1])
+        return (
+            B * mm_formula(M, N, K, mul_type_size) + add_in_size * add_type_size
+        ) / 1e9
+    elif op_name in [
+        "convolution",
+        "_convolution",
+        "cudnn_convolution",
+        "_slow_conv2d_forward",
+    ]:
+        concrete = event["args"]["Concrete Inputs"]
+
+        def conv_out_dim(x: int, kernel: int, stride: int) -> int:
+            return (x - kernel) // stride + 1
+
+        stride = parse_list(
+            concrete[3] if op_name != "_slow_conv2d_forward" else concrete[4]
+        )
+        inp = input_shapes[0]
+        w = input_shapes[1]
+        out_x_y = [conv_out_dim(*args) for args in zip(inp[2:], w[2:], stride)]
+        out = [inp[0], w[0]] + out_x_y
+        # each output element reads in * w * w chunk
+        input_reads = out[0] * out[1] * out[2] * out[3] * inp[1] * w[2] * w[3]
+        # Assume weights are in cache, so only read once
+        weight_reads = w[0] * w[1] * w[2] * w[3]
+        return (input_reads + weight_reads) / 1e9
+
+    return _default_estimate_gb(event)
+
+
+def _create_extern_mapping(
+    data: dict[str, Any],
+) -> defaultdict[int, list[dict[str, Any]]]:
+    """
+    compute a mapping from external ids to non kernels, which contain the information we need to estimate flops etc
+    """
+    extern_mapping: defaultdict[int, list[dict[str, Any]]] = defaultdict(list)
+    for event in data["traceEvents"]:
+        if (
+            "args" not in event
+            or "External id" not in event["args"]
+            or event["cat"] != "cpu_op"
+        ):
+            continue
+        if len(extern_mapping[event["args"]["External id"]]) > 0:
+            raise ParseException("duplicate external id in event")
+        extern_mapping[event["args"]["External id"]].append(event)
+    return extern_mapping
+
+
+def _augment_trace_helper(data: dict[str, Any]) -> dict[str, Any]:
+    extern_mapping = _create_extern_mapping(data)
+
+    for event in data["traceEvents"]:
+        if "cat" not in event or event["cat"] != "kernel":
+            continue
+        if "args" not in event:
+            raise ParseException(f"kernel has no args: {event}")
+        if "External id" not in event["args"]:
+            event_str = f"kernel has no External id: {event}"
+            log.info(event_str)
+            continue
+
+        external_op = extern_mapping[event["args"]["External id"]][0]
+        flops = _calculate_flops(external_op)
+        if flops == 0:
+            flops = _calculate_flops(event)
+        external_op["args"]["kernel_flop"] = flops
+        external_op["args"]["kernel_num_gb"] = _estimate_gb(external_op)
+        event["args"]["kernel_flop"] = external_op["args"]["kernel_flop"]
+        event["args"]["kernel_num_gb"] = external_op["args"]["kernel_num_gb"]
+    return data
+
+
+_dtype_map = {
+    "float": torch.float,
+    "float32": torch.float,
+    "int": torch.int,
+    "int8": torch.int8,
+    "int16": torch.int16,
+    "int32": torch.int,
+    "long": torch.long,
+    "long int": torch.long,
+    "bfloat16": torch.bfloat16,
+    "float16": torch.float16,
+    "float64": torch.double,
+}
+
+
+@dataclass(frozen=True)
+class KernelStats:
+    flops: int
+    bw: float
+    latency: float  # us
+    achieved_flops: float
+    achieved_bandwidth: float
+
+
+KernelNameMap = defaultdict[str, OrderedSet[KernelStats]]
+
+
+@dataclass(frozen=False)
+class Device:
+    name: str
+    index: int
+    info: Optional[DeviceInfo]
+    stats: KernelNameMap
+
+    def __repr__(self) -> str:
+        return f"Device({self.name}, {self.index}): {self.info}"
+
+
+DeviceMap = dict[int, Device]
+Table = tuple[list[str], dict[str, list[str]]]
+
+
+class JsonProfile:
+    _devices: DeviceMap
+
+    def __init__(
+        self,
+        path: str,
+        benchmark_name: Optional[str] = None,
+        dtype: Optional[Union[torch.dtype, str]] = None,
+    ):
+        """
+        Convenience class for running common operations on chrome/perfetto json traces.
+        """
+        self.path = path
+        with open(path) as f:
+            self.data = json.load(f)
+            self.events = self.data["traceEvents"]
+        self.benchmark_name = benchmark_name
+        if dtype is None:
+            self.dtype = None
+        elif isinstance(dtype, torch.dtype):
+            # pyrefly: ignore [bad-assignment]
+            self.dtype = dtype
+        else:
+            # pyrefly: ignore [bad-assignment]
+            self.dtype = _dtype_map.get(dtype)
+        self._create_devices()
+
+    def convert_dtype(self, event: dict[str, Any]) -> Optional[torch.dtype]:
+        """
+        Each op has a list of dtypes for each input arg. We need to convert these into a single dtype for flop estimation.
+        Issues:
+         - converting the strings to concrete torch.dtypes
+         - What if we have float32, float, float16 all in the inputs? Our choice is to use the largest buffer dtype.
+        """
+
+        if (
+            "Input Dims" not in event["args"]
+            or "Input type" not in event["args"]
+            or "Concrete Inputs" not in event["args"]
+        ):
+            if "bfloat16" in event["name"]:
+                return torch.bfloat16
+            elif "float16" in event["name"]:
+                return torch.float16
+            else:
+                return None
+
+        input_sizes = event["args"]["Input Dims"]
+        input_types = event["args"]["Input type"]
+        concrete_inputs = event["args"]["Concrete Inputs"]
+        assert len(input_sizes) == len(input_types)
+        assert len(input_types) == len(concrete_inputs)
+
+        if len(input_sizes) == 0:
+            raise RuntimeError("Empty input_sizes and input_types")
+
+        biggest_size = 0
+        biggest_index = 0
+        for i in range(len(input_sizes)):
+            if concrete_inputs[i] != "":
+                # concrete inputs are usually small tensors, so we can just skip
+                continue
+            my_size = input_sizes[i]
+            total_size = sum(parse_list(my_size))
+            if total_size > biggest_size:
+                biggest_size = total_size
+                biggest_index = i
+        ret_type = input_types[biggest_index]
+        if ret_type in _dtype_map:
+            return _dtype_map[ret_type]
+        raise RuntimeError(f"Unknown type: {ret_type}. Please add to _dtype_map.")
+
+    def _create_devices(self) -> None:
+        self._devices = {}
+        for dev in self.data["deviceProperties"]:
+            name = dev["name"]
+            device_info = lookup_device_info(name)
+
+            if device_info is None:
+                log.info(
+                    "Unsupported device in profile: %s, please consider contributing to _device_mapping.",
+                    name,
+                )
+            self._devices[dev["id"]] = Device(
+                name, dev["id"], device_info, defaultdict(OrderedSet)
+            )
+
+    def calculate_flops(self, event: dict[str, Any]) -> int:
+        return _calculate_flops(event)
+
+    def estimate_gb(self, event: dict[str, Any]) -> float:
+        return _estimate_gb(event)
+
+    def augment_trace(self) -> None:
+        self.data = _augment_trace_helper(self.data)
+
+    def _compute_stats(self) -> None:
+        """populates the name -> stats map"""
+        for event in self.events:
+            if "cat" not in event or "args" not in event or event["cat"] != "kernel":
+                continue
+            if "device" not in event["args"]:
+                continue
+            dev_tmp = event["args"]["device"]
+            if dev_tmp not in self._devices:
+                continue
+            dev = self._devices[event["args"]["device"]]
+
+            dur = event["dur"]  # us
+            if "kernel_flop" in event["args"]:
+                assert dur != 0
+                # 1,000,000us/s * flop / us
+                op_flops = event["args"]["kernel_flop"] / (dur / 1e6)
+            else:
+                op_flops = 0
+
+            if "kernel_num_gb" in event["args"]:
+                assert dur != 0
+                # 1,000,000us/s * gb  = gb/s
+                op_gbps = event["args"]["kernel_num_gb"] / (dur / 1e6)
+            else:
+                op_gbps = 0
+
+            if dev.info is not None:
+                dtype = self.convert_dtype(event) or self.dtype
+                if dtype is None:
+                    raise RuntimeError(
+                        "dtype is not found on tensor and default dtype is not set"
+                    )
+                achieved_flops = 100 * op_flops / (1e12 * dev.info.tops[dtype])
+                achieved_bandwidth = 100 * op_gbps / dev.info.dram_bw_gbs
+            else:
+                achieved_flops = 0
+                achieved_bandwidth = 0
+
+            if "name" not in event["args"]:
+                continue
+            dev.stats[event["name"]].add(
+                KernelStats(
+                    flops=op_flops,
+                    bw=op_gbps,
+                    latency=dur,
+                    achieved_bandwidth=achieved_bandwidth,
+                    achieved_flops=achieved_flops,
+                )
+            )
+
+    def _create_single_table(self, dev: Device) -> Table:
+        """Create a table with the devices mapped to indices."""
+        headers = [
+            "Kernel Name",
+            "Kernel Count",
+            "FLOPS",
+            "Kernel Reads (GB)",
+            "Dur (us)",
+            "Achieved FLOPS %",
+            "Achieved Bandwidth %",
+        ]
+        rows: dict[str, list[str]] = {}
+
+        def safe_div_format(x: float, y: float) -> str:
+            if y == 0:
+                return "0.0"
+            return f"{x / y:.4f}"
+
+        for kernel_name, stats_set in dev.stats.items():
+            ker_count = 0
+            flops = 0
+            flops_count = 0
+            achieved_flops = 0.0
+            bw = 0.0
+            bw_count = 0
+            achieved_bandwidth = 0.0
+            latency = 0.0
+            for stats in stats_set:
+                if stats.flops != 0:
+                    flops += stats.flops
+                    achieved_flops += stats.achieved_flops
+                    flops_count += 1
+                if stats.bw != 0:
+                    bw += stats.bw
+                    achieved_bandwidth += stats.achieved_bandwidth
+                    bw_count += 1
+                latency += stats.latency
+                ker_count += 1
+            assert ker_count != 0
+            rows[kernel_name] = [
+                str(ker_count),
+                safe_div_format(flops, flops_count),
+                safe_div_format(bw, bw_count),
+                safe_div_format(latency, ker_count),
+                safe_div_format(achieved_flops, flops_count),
+                safe_div_format(achieved_bandwidth, bw_count),
+            ]
+
+        return headers, rows
+
+    def _create_tables(self, devs: DeviceMap) -> dict[int, Table]:
+        return {idx: self._create_single_table(dev) for idx, dev in devs.items()}
+
+    def _combine_tables(
+        self, table1: Table, table1_name: str, table2: Table, table2_name: str
+    ) -> Table:
+        new_headers = (
+            ["Kernel Name"]
+            + [f"{table1_name} {head}" for head in table1[0][1:]]
+            + [f"{table2_name} {head}" for head in table2[0][1:]]
+        )
+        t1_length = len(table1[0][1:])
+        t2_length = len(table2[0][1:])
+        new_rows = {}
+
+        for key, row1, row2 in zip_dicts(
+            table1[1],
+            table2[1],
+            d1_default=["Empty"] * t1_length,
+            d2_default=["Empty"] * t2_length,
+        ):
+            assert row1 is not None
+            assert row2 is not None
+            new_rows[key] = row1 + row2
+        return new_headers, new_rows
+
+    def report(
+        self, other: Optional["JsonProfile"] = None, name_limit: int = 40
+    ) -> str:
+        def create_ret(
+            table_headers: list[str], table_rows: dict[str, list[str]]
+        ) -> str:
+            table_flattened = [
+                [kernel_name[:name_limit], *kernel_vals]
+                for kernel_name, kernel_vals in table_rows.items()
+            ]
+            return tabulate_2d(table_flattened, headers=table_headers)
+
+        if other is not None:
+            self._compute_stats()
+            other._compute_stats()
+
+            self_tables = self._create_tables(self._devices)
+            other_tables = self._create_tables(other._devices)
+
+            self_name = (
+                self.benchmark_name if self.benchmark_name is not None else "Table 1"
+            )
+            other_name = (
+                other.benchmark_name if other.benchmark_name is not None else "Table 2"
+            )
+
+            ret = []
+            assert self._devices.keys() == other._devices.keys()
+            for device_idx, t1, t2 in zip_dicts(
+                self_tables, other_tables, d1_default=None, d2_default=None
+            ):
+                assert t1 is not None
+                assert t2 is not None
+                table_headers, table_rows = self._combine_tables(
+                    t1, self_name, t2, other_name
+                )
+                tab_string = create_ret(table_headers, table_rows)
+                # pyrefly: ignore [bad-argument-type]
+                ret.append(f"{self._devices[device_idx]}:\n{tab_string}")
+            return "\n".join(ret)
+        self._compute_stats()
+
+        self_tables = self._create_tables(self._devices)
+
+        ret = []
+        for idx, table in self_tables.items():
+            table_headers, table_rows = table
+            tab_string = create_ret(table_headers, table_rows)
+            # pyrefly: ignore [bad-argument-type]
+            ret.append(f"{self._devices[idx]}:\n{tab_string}")
+        return "\n".join(ret)
+
+    def dump(self, out: str) -> None:
+        with open(out, "w") as f:
+            json.dump(self.data, f)
+
+    def combine_with(self, other: "JsonProfile") -> "JsonProfile":
+        """
+        Combine this profile with another profile by merging their trace events.
+        Returns a new JsonProfile object with combined data.
+        """
+        # Create a new combined data structure
+        combined_data = {
+            "traceEvents": self.data["traceEvents"] + other.data["traceEvents"],
+            "deviceProperties": self.data.get("deviceProperties", []),
+        }
+
+        # Merge device properties, avoiding duplicates
+        other_device_props = other.data.get("deviceProperties", [])
+        existing_device_ids = OrderedSet(
+            [dev["id"] for dev in combined_data["deviceProperties"]]
+        )
+
+        for device_prop in other_device_props:
+            if device_prop["id"] not in existing_device_ids:
+                combined_data["deviceProperties"].append(device_prop)
+
+        # Copy any other top-level properties from the first profile
+        for key, value in self.data.items():
+            if key not in combined_data:
+                combined_data[key] = value
+
+        import os
+
+        # Create a temporary file to write the combined data
+        import tempfile
+
+        with tempfile.NamedTemporaryFile(
+            mode="w", suffix=".json", delete=False
+        ) as tmp_file:
+            json.dump(combined_data, tmp_file)
+            tmp_path = tmp_file.name
+
+        try:
+            # Create new JsonProfile from the combined data
+            combined_profile = JsonProfile(
+                tmp_path,
+                benchmark_name=f"{self.benchmark_name or 'Profile1'}_+_{other.benchmark_name or 'Profile2'}",
+                dtype=self.dtype or other.dtype,
+            )
+            return combined_profile
+        finally:
+            # Clean up temporary file
+            os.unlink(tmp_path)
+
+
+class ParseException(RuntimeError):
+    pass
+
+
+def main() -> None:
+    """
+    Main function for the profile analysis script.
+    """
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--diff",
+        nargs=5,
+        metavar=(
+            "input_file1",
+            "name1",
+            "input_file2",
+            "name2",
+            "dtype",
+        ),
+        help="Two json traces to compare with, specified as     ",
+    )
+    parser.add_argument(
+        "--name_limit",
+        type=int,
+        help="the maximum name size in the final report",
+    )
+    parser.add_argument(
+        "--augment_trace",
+        "-a",
+        nargs=3,
+        metavar=("input_file", "output_file", "dtype"),
+        help="Augment a trace with inductor meta information. Provide input and output file paths.",
+    )
+    parser.add_argument(
+        "--analysis",
+        nargs=2,
+        metavar=("input_file", "dtype"),
+        help="Run analysis on a single trace, specified as  ",
+    )
+    parser.add_argument(
+        "--combine",
+        nargs="+",
+        metavar=("input_files", "output_file"),
+        help="Combine multiple profiles into a single profile by merging trace events. Specify as  \
+ [input_file3 ...] . The last argument is the output file, all preceding arguments are \
+input files to combine.",
+    )
+    args = parser.parse_args()
+
+    if args.diff:
+        p1 = JsonProfile(args.diff[0], args.diff[1], dtype=args.diff[4])
+        p1.augment_trace()
+        p2 = JsonProfile(args.diff[2], args.diff[3], dtype=args.diff[4])
+        p2.augment_trace()
+        if args.name_limit:
+            print(p1.report(p2, name_limit=args.name_limit))
+        else:
+            print(p1.report(p2))
+    if args.analysis:
+        p1 = JsonProfile(
+            args.analysis[0],
+            dtype=args.analysis[1],
+        )
+        p1.augment_trace()
+        if args.name_limit:
+            print(p1.report(name_limit=args.name_limit))
+        else:
+            print(p1.report())
+    if args.augment_trace:
+        p = JsonProfile(args.augment_trace[0], dtype=args.augment_trace[2])
+        p.augment_trace()
+        p.dump(args.augment_trace[1])
+    if args.combine:
+        input_files = args.combine[:-1]  # All arguments except the last one
+        output_file = args.combine[-1]  # Last argument is the output file
+
+        if len(input_files) < 2:
+            print("Error: At least 2 input files are required for combining")
+            return
+
+        # Load the first profile
+        combined = JsonProfile(input_files[0], dtype=None)
+
+        # Iteratively combine with all other profiles
+        for input_file in input_files[1:]:
+            profile = JsonProfile(input_file, dtype=None)
+            combined = combined.combine_with(profile)
+
+        combined.dump(output_file)
+        print(f"Successfully combined {', '.join(input_files)} into {output_file}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/autoheuristic.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/autoheuristic.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c12ca77cf2db28bbe1fc10cca44774ced5c102f
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/autoheuristic.py
@@ -0,0 +1,316 @@
+import json
+import os
+from collections.abc import Callable
+from functools import partial
+from typing import Any, Optional
+
+import torch
+from torch._inductor.autoheuristic.autoheuristic_utils import (
+    AHContext,
+    AHMetadata,
+    AHOperation,
+    Choice,
+    CHOICE_COL,
+    Feedback,
+    FEEDBACK_COL,
+    get_metadata_str_from_log,
+)
+from torch._inductor.autoheuristic.learned_heuristic_controller import (
+    LearnedHeuristicController,
+)
+from torch._inductor.ir import ChoiceCaller
+from torch._inductor.runtime.runtime_utils import cache_dir
+from torch._inductor.utils import get_gpu_shared_memory
+
+
+class LocalFeedback:
+    """
+    To be able to collect data for a choice, a function providing feedback given a choice has to be provided.
+    LocalFeedback can be used when AutoHeuristic should immediately run the function to collect feedback for each choice
+    (see pad_mm.py, where the autotuning happens locally, for an example).
+    """
+
+    def __init__(self, feedback_fn: Callable[[Choice], Feedback]) -> None:
+        self.feedback_fn = feedback_fn
+
+    def __call__(self, choice: Choice) -> Feedback:
+        return self.feedback_fn(choice)
+
+
+class InconsistentMetadata(Exception):
+    """
+    Exception that is thrown when AutoHeuristic tries to log data to a file where the metadata stored in the file does
+    not match the metadata it would store if the file didn't exist.
+    """
+
+
+class AutoHeuristic:
+    """
+    AutoHeuristic is a framework that allows one to collect data, learn a heuristic (i.e. a regression tree) and
+    generate the heuristic to code. This class allows one to collect data. The collected data can then be used to train
+    a heuristic (see torchgen/autoheuristic/).
+    """
+
+    collected_feedback: dict[Choice, Feedback]
+
+    def __init__(
+        self,
+        fallback: Callable[[], Choice],
+        choices: list[Choice],
+        feedback: Optional[LocalFeedback],
+        context: AHContext,
+        name: str,
+        augment_context: Optional[list[AHOperation]] = None,
+        precondition: Optional[Callable[[AHMetadata, AHContext], bool]] = None,
+    ) -> None:
+        """
+        Initializes an instance of the AutoHeuristic class.
+
+        Args:
+            fallback: A callable that returns a Choice when the heuristic is unsure which choice to make, or
+            AutoHeuristic is in data collection mode.
+            choices: A list of possible choices the heuristic can make.
+            feedback: An instance of LocalFeedback that provides feedback for a given choice.
+            context: Context to store with each choice and feedback.
+            name: A string that identifies the heuristic.
+            augment_context: An optional list of AHOperation instances that augment the context.
+            precondition: A callable that returns a boolean indicating whether AutoHeuristic should run.
+        """
+        self.fallback = fallback
+        self.choices = choices
+        self.feedback = feedback
+        self.context = context
+        self.name = name
+        self.collected_feedback = {}
+        self.augment_context = augment_context
+        self.metadata = AHMetadata(
+            get_gpu_shared_memory(),
+            torch.cuda.get_device_capability(),
+            self.choices,
+            self.name,
+        )
+        self.precondition = precondition
+
+        if not self.satisfies_precondition():
+            return
+
+        if torch._inductor.config.autoheuristic_log_path == "DEFAULT":
+            self.log_path = self.get_default_log_path()
+        else:
+            self.log_path = torch._inductor.config.autoheuristic_log_path
+
+        if torch._inductor.config.collect_autoheuristic(self.name):
+            if self.feedback is not None:
+                for choice in self.choices:
+                    feedback_val = self.feedback(choice)
+                    self.save_data(choice, feedback_val)
+
+    def satisfies_precondition(self) -> bool:
+        return self.precondition is None or self.precondition(
+            self.metadata, self.context
+        )
+
+    def get_choice(self) -> Choice:
+        """
+        Returns the chosen option based on the value of autoheuristic_use.
+        If self.name is one of the comma separated strings in autoheuristic_use,
+        it queries a learned heuristic to make a decision. Otherwise, it returns the fallback option.
+        """
+
+        if not self.satisfies_precondition():
+            return self.fallback()
+
+        if torch._inductor.config.use_autoheuristic(self.name):
+            if self.augment_context is not None:
+                self.context.apply_operations(self.augment_context)
+            controller = LearnedHeuristicController(
+                self.metadata,
+                self.context,
+            )
+            decision = controller.get_decision()
+            if decision not in self.choices:
+                # TODO(AlnisM): We might want to allow this in the future
+                return self.fallback()
+            if decision is not None:
+                return decision
+        return self.fallback()
+
+    def get_top_k_choices(
+        self, top_k: int, always_included: Optional[list[str]] = None
+    ) -> Optional[list[Choice]]:
+        if not self.satisfies_precondition():
+            return None
+        if torch._inductor.config.use_autoheuristic(self.name):
+            if self.augment_context is not None:
+                self.context.apply_operations(self.augment_context)
+            controller = LearnedHeuristicController(
+                self.metadata,
+                self.context,
+            )
+            choices = controller.get_decisions_ranked(top_k)
+            if choices is None:
+                return None
+            if always_included is not None:
+                for choice in always_included:
+                    if choice not in choices:
+                        choices.append(choice)
+            return choices
+        return None
+
+    def get_collected_feedback(self, choice: Choice) -> Any:
+        return self.collected_feedback.get(choice, None)
+
+    @staticmethod
+    def get_device_identifier() -> str:
+        # a heuristic might work well for one GPU, but not for another
+        # we store the collected data per GPU model and learn a heuristic per GPU model
+
+        # TODO(AlnisM): just using the device name for now, but the same GPU model can have different names
+        device_name = torch.cuda.get_device_name().replace(" ", "_")
+        return device_name
+
+    def get_default_log_path(self) -> str:
+        device_name = self.get_device_identifier()
+        path = f"{cache_dir()}/autoheuristic/{device_name}/"
+        os.makedirs(path, exist_ok=True)
+        path += f"{self.name}.txt"
+        return path
+
+    def serialize_metadata(self) -> str:
+        metadata_dict = self.metadata.to_dict()
+        (
+            num_features,
+            cat_features,
+        ) = self.context.get_numerical_and_categorical_features()
+        metadata_dict["numerical_features"] = num_features
+        metadata_dict["categorical_features"] = cat_features
+        return json.dumps(metadata_dict)
+
+    def save_data(self, choice: Choice, feedback_val: Feedback) -> None:
+        self.collected_feedback[choice] = feedback_val
+        log_path = self.log_path
+
+        lines = []
+        log_exists = os.path.exists(log_path)
+        if log_exists:
+            # if log already exists, make sure it is consistent
+            metadata = self.serialize_metadata()
+            existing_metadata = get_metadata_str_from_log(self.log_path)
+            if existing_metadata != metadata:
+                raise InconsistentMetadata(
+                    "Given metadata does not match existing metadata"
+                )
+        else:
+            lines.append(self.serialize_metadata())
+            feature_header = self.context.get_feature_names_csv()
+            header = feature_header + "," + CHOICE_COL + "," + FEEDBACK_COL
+            lines.append(header)
+
+        line = ""
+        feature_values = self.context.get_feature_values_csv()
+        line += feature_values + "," + choice + "," + str(feedback_val)
+        lines.append(line)
+
+        with open(log_path, "a") as f:
+            f.write("\n".join(lines) + "\n")
+
+
+class AutoHeuristicSelectAlgorithm(AutoHeuristic):
+    """
+    AutoHeuristicSelectAlgorithm is a subclass of AutoHeuristic that allows one to collect data and learn a heuristic
+    when one wants to use AutoHeuristic for kernel choice selection.
+    """
+
+    def __init__(
+        self,
+        fallback: Callable[[], Optional[ChoiceCaller]],
+        choices: list[ChoiceCaller],
+        input_nodes: list[Any],
+        context: AHContext,
+        name: str,
+        augment_context: Optional[list[AHOperation]] = None,
+        precondition: Optional[Callable[[AHMetadata, AHContext], bool]] = None,
+    ) -> None:
+        """
+        The arguments choices, input_nodes and name have to match the ones used in the call to
+        autotune_select_algorithm(), e.g. if the following call is made
+        autotune_select_algorithm(name, choices, input_nodes, layout), the same name, choices and input_nodes
+        have to be used here.
+        """
+        self.input_nodes = input_nodes
+        self.choicestr2choice: dict[str, ChoiceCaller] = {}
+        for choice in choices:
+            self.choicestr2choice[choice.autoheuristic_id()] = choice
+        choices_str = list(self.choicestr2choice.keys())
+
+        def fallback_str() -> str:
+            fallback_choice = fallback()
+            if fallback_choice is None:
+                # TODO: Find a nicer way to handle this
+                return "unsure"
+            return fallback_choice.autoheuristic_id()
+
+        super().__init__(
+            fallback_str,
+            choices_str,
+            None,
+            context,
+            name,
+            augment_context,
+            precondition,
+        )
+
+        if (
+            torch._inductor.config.collect_autoheuristic(self.name)
+            and self.satisfies_precondition()
+        ):
+            self.register_global_feedback(input_nodes, choices)
+
+    def register_global_feedback(
+        self, input_nodes: list[Any], choices: list[ChoiceCaller]
+    ) -> None:
+        """
+        Registers a callback in select_algorithm, which is called with the timing of each choice.
+        """
+
+        from torch._inductor.select_algorithm import (
+            add_feedback_saver,
+            create_inputs_key,
+            create_precompile_key,
+        )
+
+        def store_global_feedback(
+            ah_inputs_key: str,
+            ah_precompile_key: str,
+            timings: dict[ChoiceCaller, float],
+            name: str,
+            input_nodes: list[Any],
+            choices: list[ChoiceCaller],
+        ) -> None:
+            current_inputs_key = create_inputs_key(input_nodes)
+            if current_inputs_key != ah_inputs_key:
+                return
+            current_precompile_key = create_precompile_key(
+                name, current_inputs_key, choices
+            )
+            if current_precompile_key != ah_precompile_key:
+                return
+            for choice, time in timings.items():
+                self.save_data(choice.autoheuristic_id(), time)
+
+        inputs_key = create_inputs_key(input_nodes)
+        precompile_key = create_precompile_key(self.name, inputs_key, choices)
+        feedback_saver = partial(store_global_feedback, inputs_key, precompile_key)
+        add_feedback_saver(feedback_saver)
+
+    def get_choice_caller(self) -> Optional[ChoiceCaller]:
+        choice = self.get_choice()
+        return self.choicestr2choice.get(choice, None)
+
+    def get_top_k_choices_caller(
+        self, top_k: int, always_included: Optional[list[str]] = None
+    ) -> Optional[list[ChoiceCaller]]:
+        choices = self.get_top_k_choices(top_k, always_included)
+        if choices is None:
+            return None
+        return [self.choicestr2choice[choice] for choice in choices]
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/autoheuristic_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/autoheuristic_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d0435fe44b4035a8f338f503340fc351019252a
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/autoheuristic_utils.py
@@ -0,0 +1,340 @@
+import functools
+from collections.abc import Callable
+from typing import Any
+
+import torch
+
+
+Feedback = float
+Choice = str
+Value = Any
+
+CHOICE_COL = "choice"
+FEEDBACK_COL = "feedback"
+
+
+class AHFeature:
+    """
+    The context, that AutoHeuristic stores, is a list of features. AutoHeuristic needs to know whether a feature is
+    categorical (i.e., not a continuous variable) to learn a machine learning model.
+    """
+
+    def __init__(self, name: str, value: Value, is_categorical: bool = False) -> None:
+        self.name = name
+        self.value = value
+        self.is_categorical = is_categorical
+
+
+class AHOperation:
+    """
+    AHOperation can be used to augment the data collected by AutoHeuristic.
+    One might for example store features like m, k, n, but also want to use
+    features like m*n, or k*n, to learn a heuristic. Instead of storing features
+    that can be created from the collected data, one can use AHOperation to
+    create new features from the collected data.
+    """
+
+    def __init__(
+        self, name: str, func: Callable[[Any], Value], is_categorical: bool = False
+    ) -> None:
+        self.name = name
+        self.func = func
+        self.is_categorical = is_categorical
+
+    def apply_operation(self, data: Any) -> None:
+        data[self.name] = self.func(data)
+
+
+class AHContext:
+    """
+    This class is used to specify which information AutoHeuristic should store. For each choice, AutoHeursitic will
+    store the context and the collected feedback. The context could be something like the shape of a tensor, i.e.,
+    information that will help to learn a heuristic.
+    """
+
+    features: list[AHFeature]
+    context_dict: dict[str, Value]
+
+    def __init__(self) -> None:
+        self.features = []
+        self.context_dict = {}
+
+    def add_feature(
+        self, name: str, value: Value, is_categorical: bool = False
+    ) -> None:
+        self.features.append(AHFeature(name, value, is_categorical=is_categorical))
+        self.context_dict[name] = value
+
+    def get_numerical_and_categorical_features(self) -> tuple[list[str], list[str]]:
+        numerical_features = []
+        categorical_features = []
+        for feature in self.features:
+            if feature.is_categorical:
+                categorical_features.append(feature.name)
+            else:
+                numerical_features.append(feature.name)
+
+        return numerical_features, categorical_features
+
+    def get_feature_names_csv(self) -> str:
+        return ",".join(feature.name for feature in self.features)
+
+    def get_feature_values_csv(self) -> str:
+        return ",".join(str(feature.value) for feature in self.features)
+
+    def get_value(self, name: str) -> Value:
+        return self.context_dict[name]
+
+    def apply_operations(self, operations: list[AHOperation]) -> None:
+        for op in operations:
+            op.apply_operation(self.context_dict)
+
+
+class AHMetadata:
+    def __init__(
+        self,
+        shared_memory: Any,
+        device_capa: tuple[int, int],
+        choices: list[Choice],
+        name: str,
+    ) -> None:
+        # use amount of shared_memory and device_capability to identify GPU
+        # TODO(AlnisM): there might be a better way to do this
+        self.shared_memory = shared_memory
+        self.device_capa = device_capa
+        self.choices = choices
+        self.name = name
+
+    def to_dict(self) -> dict[str, Value]:
+        return {
+            "shared_memory": self.shared_memory,
+            "device_capa": self.device_capa,
+            "name": self.name,
+        }
+
+
+def get_metadata_str_from_log(log_path: str) -> str:
+    with open(log_path, newline="") as file:
+        json_string = file.readline().strip()
+        return json_string
+
+
+def check_minsize(context: AHContext, minsize: int) -> bool:
+    return (
+        context.get_value("m") >= minsize
+        and context.get_value("k") >= minsize
+        and context.get_value("n") >= minsize
+    )
+
+
+def pad_mm_precondition(metadata: AHMetadata, context: AHContext) -> bool:
+    if metadata.shared_memory == 166912 and metadata.device_capa == (8, 0):
+        # A100 precondition
+        return check_minsize(context, 512)
+    elif metadata.shared_memory == 232448 and metadata.device_capa == (9, 0):
+        # H100 precondition
+        return check_minsize(context, 768)
+    return True
+
+
+def get_mixedmm_precondition(metadata: AHMetadata, context: AHContext) -> bool:
+    m = context.get_value("m")
+    k = context.get_value("k")
+    n = context.get_value("n")
+    if m > 128 or k < 1024 or n < 1024:
+        return False
+    mat1_iscontig = context.get_value("mat1_iscontig")
+    mat2_iscontig = context.get_value("mat2_iscontig")
+    return mat1_iscontig and not mat2_iscontig
+
+
+def get_mult_dims_ops() -> list[AHOperation]:
+    m_times_k_op = AHOperation("m*k", lambda data: data["m"] * data["k"])
+    m_times_n_op = AHOperation("m*n", lambda data: data["m"] * data["n"])
+    k_times_n_op = AHOperation("k*n", lambda data: data["k"] * data["n"])
+    return [m_times_k_op, m_times_n_op, k_times_n_op]
+
+
+def get_arith_intensity(data: Any) -> float:
+    m = data["m"]
+    k = data["k"]
+    n = data["n"]
+    if m == 0 or k == 0 or n == 0:
+        return 0.0
+    return m * k * n / (m * k + k * n + m * n)
+
+
+def pad_mm_operations() -> list[AHOperation]:
+    mult_dims_ops = get_mult_dims_ops()
+    k_div_m_times_n_op = AHOperation(
+        "k/(m*n)", lambda data: data["k"] / (data["m"] * data["n"])
+    )
+
+    def bfloat_perf_hit(data: Any) -> bool:
+        m = data["m"]
+        k = data["k"]
+        n = data["n"]
+        is_bfloat = str(data["mat1_dtype"]) == "torch.bfloat16"
+        return k > (m * 1024) and k > (n * 1024) and is_bfloat
+
+    bfloat_perf_hit_op = AHOperation(
+        "bfloat_perf_hit", bfloat_perf_hit, is_categorical=True
+    )
+
+    arith_intensity_op = AHOperation("arith_intensity", get_arith_intensity)
+    dims_need_padding_ops = get_dims_need_padding_ops()
+    dims_multiple_ops = get_dims_multiple_ops()
+    is_contig_ops = get_is_contig_ops()
+
+    ah_operations = mult_dims_ops + [
+        k_div_m_times_n_op,
+        bfloat_perf_hit_op,
+        arith_intensity_op,
+    ]
+    ah_operations.extend(dims_need_padding_ops)
+    ah_operations.extend(dims_multiple_ops)
+    ah_operations.extend(is_contig_ops)
+    return ah_operations
+
+
+def between_op(data: Any, dim: str, lower: int, upper: int) -> bool:
+    return data[dim] >= lower and data[dim] <= upper
+
+
+def between_ops() -> list[AHOperation]:
+    dims = ["m", "k", "n"]
+    limits = [(1, 16), (17, 32), (33, 64), (65, 128), (129, 256)]
+    ah_operations = []
+    for dim in dims:
+        for lower, upper in limits:
+            between_op_fn = functools.partial(
+                between_op, dim=dim, lower=lower, upper=upper
+            )
+            # using 'LEQ' instead of '<=' because '<=' cannot be exported to dot
+            between_op_name = f"{lower}LEQ{dim}LEQ{upper}"
+            ah_operations.append(
+                AHOperation(between_op_name, between_op_fn, is_categorical=True)
+            )
+    return ah_operations
+
+
+def pow2_op(data: Any, dim: str, exponent: int) -> bool:
+    return data[dim] == 2**exponent
+
+
+def mm_operations() -> list[AHOperation]:
+    mult_dims_ops = get_mult_dims_ops()
+    arith_intensity_op = AHOperation("arith_intensity", get_arith_intensity)
+    return mult_dims_ops + [arith_intensity_op]
+
+
+def mixed_mm_operations() -> list[AHOperation]:
+    return mm_operations() + between_ops()
+
+
+def is_multiple(data: Any, dim: str, mult: int) -> bool:
+    return data[dim] % mult == 0
+
+
+def get_dims_multiple_ops() -> list[AHOperation]:
+    multiples = [2, 4, 8, 16, 32]
+    dims = ["m", "k", "n"]
+    dims_multiple_ops = []
+    for dim in dims:
+        for mult in multiples:
+            is_multiple_fn = functools.partial(is_multiple, dim=dim, mult=mult)
+            dims_multiple_op = AHOperation(
+                f"{dim}_multiple_{mult}", is_multiple_fn, is_categorical=True
+            )
+            dims_multiple_ops.append(dims_multiple_op)
+    return dims_multiple_ops
+
+
+def get_dims_need_padding_ops() -> list[AHOperation]:
+    def mat1_innermost_needs_padding_fn(data: Any) -> bool:
+        mat1_stride_0 = data["mat1_stride_0"]
+        mat1_stride_1 = data["mat1_stride_1"]
+        m_padded_length = data["m_padded_length"]
+        k_padded_length = data["k_padded_length"]
+        mat1_innermost_needs_padding = False
+        if mat1_stride_0 == 1 and m_padded_length != 0:
+            mat1_innermost_needs_padding = True
+        if mat1_stride_1 == 1 and k_padded_length != 0:
+            mat1_innermost_needs_padding = True
+        return mat1_innermost_needs_padding
+
+    mat1_innermost_op = AHOperation(
+        "mat1_innermost_needs_padding",
+        mat1_innermost_needs_padding_fn,
+        is_categorical=True,
+    )
+
+    def mat2_innermost_needs_padding_fn(data: Any) -> bool:
+        mat2_stride_0 = data["mat2_stride_0"]
+        mat2_stride_1 = data["mat2_stride_1"]
+        k_padded_length = data["k_padded_length"]
+        n_padded_length = data["n_padded_length"]
+        mat2_innermost_needs_padding = False
+        if mat2_stride_0 == 1 and k_padded_length != 0:
+            mat2_innermost_needs_padding = True
+        if mat2_stride_1 == 1 and n_padded_length != 0:
+            mat2_innermost_needs_padding = True
+        return mat2_innermost_needs_padding
+
+    mat2_innermost_op = AHOperation(
+        "mat2_innermost_needs_padding",
+        mat2_innermost_needs_padding_fn,
+        is_categorical=True,
+    )
+
+    def num_dims_needs_padding_fn(data: Any) -> int:
+        m_padded_length = data["m_padded_length"]
+        k_padded_length = data["k_padded_length"]
+        n_padded_length = data["n_padded_length"]
+        num_dims_needs_padding = 0
+        if m_padded_length != 0:
+            num_dims_needs_padding += 1
+        if k_padded_length != 0:
+            num_dims_needs_padding += 1
+        if n_padded_length != 0:
+            num_dims_needs_padding += 1
+        return num_dims_needs_padding
+
+    num_dims_op = AHOperation("num_dims_needs_padding", num_dims_needs_padding_fn)
+    return [mat1_innermost_op, mat2_innermost_op, num_dims_op]
+
+
+def get_is_contig_ops() -> list[AHOperation]:
+    def mat1_is_contig_fn(data: Any) -> bool:
+        stride_0 = data["mat1_stride_0"]
+        stride_1 = data["mat1_stride_1"]
+        k = data["k"]
+        return stride_0 == k and stride_1 == 1
+
+    mat1_is_contig_op = AHOperation(
+        "mat1_iscontig", mat1_is_contig_fn, is_categorical=True
+    )
+
+    def mat2_is_contig_fn(data: Any) -> bool:
+        stride_0 = data["mat2_stride_0"]
+        stride_1 = data["mat2_stride_1"]
+        n = data["n"]
+        return stride_0 == n and stride_1 == 1
+
+    mat2_is_contig_op = AHOperation(
+        "mat2_iscontig", mat2_is_contig_fn, is_categorical=True
+    )
+
+    return [mat1_is_contig_op, mat2_is_contig_op]
+
+
+def context_add_strides(context: AHContext, name: str, stride: tuple[int, ...]) -> None:
+    for i, s in enumerate(stride):
+        context.add_feature(f"{name}_stride_{i}", s)
+
+
+def context_add_using_tf32(context: AHContext, dtype: torch.dtype) -> None:
+    using_tf32 = "not_float_32"
+    if dtype == torch.float32:
+        using_tf32 = torch.backends.cuda.matmul.allow_tf32
+    context.add_feature("using_tf32", using_tf32, is_categorical=True)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/learned_heuristic_controller.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/learned_heuristic_controller.py
new file mode 100644
index 0000000000000000000000000000000000000000..50c11eb9a712afafee7479987a6832e412cc393a
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/learned_heuristic_controller.py
@@ -0,0 +1,119 @@
+import importlib
+import inspect
+import pkgutil
+from collections import defaultdict
+from typing import Any, Optional
+
+from torch._inductor.autoheuristic.autoheuristic_utils import (
+    AHContext,
+    AHMetadata,
+    Choice,
+)
+from torch._inductor.autoheuristic.learnedheuristic_interface import LearnedHeuristic
+
+
+def find_and_instantiate_subclasses(
+    package_name: str, base_class: Any
+) -> list[LearnedHeuristic]:
+    instances = []
+
+    package = importlib.import_module(package_name)
+    for _, module_name, _ in pkgutil.walk_packages(
+        package.__path__, package.__name__ + "."
+    ):
+        try:
+            module_basename = module_name.split(".")[-1]
+            if not module_basename.startswith("_"):
+                # learned heuristics start with an underscore
+                continue
+            module = importlib.import_module(module_name)
+
+            # look for classes that are subclasses of base_class
+            for _name, obj in inspect.getmembers(module):
+                if (
+                    inspect.isclass(obj)
+                    and issubclass(obj, base_class)
+                    and obj != base_class
+                ):
+                    instance = obj()
+                    instances.append(instance)
+        except Exception as e:
+            print(f"Error processing module {module_name}: {e}")
+
+    return instances
+
+
+class LearnedHeuristicController:
+    """
+    Class that finds and instantiates all learned heuristics. It also provides
+    a way to get the decision of a learned heuristic.
+    """
+
+    existing_heuristics: dict[str, list[LearnedHeuristic]] = defaultdict(list)
+    """
+    A dictionary that stores all the learned heuristics for each optimization.
+    The key is the optimization name, and the value is a list of LearnedHeuristic objects.
+    """
+
+    heuristics_initialized: bool = False
+    """
+    A flag that indicates whether the learned heuristics have been initialized.
+    Set to true when the get_decision() function is called for the first time.
+    """
+
+    def __init__(
+        self,
+        metadata: AHMetadata,
+        context: AHContext,
+    ) -> None:
+        self.metadata = metadata
+        self.context = context
+
+    def get_heuristics(self, name: str) -> list[LearnedHeuristic]:
+        """
+        Returns a list of learned heuristics for the given optimization name.
+        """
+
+        if not LearnedHeuristicController.heuristics_initialized:
+            # learned heuristics are generated into the following package
+            learned_heuristics_package = "torch._inductor.autoheuristic.artifacts"
+
+            # learned heuristics have to be of type LearnedHeuristic
+            base_class = LearnedHeuristic
+            found_heuristics = find_and_instantiate_subclasses(
+                learned_heuristics_package, base_class
+            )
+
+            for learned_heuristic in found_heuristics:
+                opt_name = learned_heuristic.get_name()
+                LearnedHeuristicController.existing_heuristics[opt_name].append(
+                    learned_heuristic
+                )
+            LearnedHeuristicController.heuristics_initialized = True
+
+        return LearnedHeuristicController.existing_heuristics[name]
+
+    def get_decision(self) -> Optional[Choice]:
+        """
+        Returns the decision made by the learned heuristic or None if no heuristic was found or the heuristic is unsure
+        which choice to make.
+        """
+
+        heuristics = self.get_heuristics(self.metadata.name)
+        for heuristic in heuristics:
+            if heuristic.check_precondition(self.metadata, self.context):
+                return heuristic.get_decision(self.context, self.metadata.choices)
+        return None
+
+    def get_decisions_ranked(self, top_k: int) -> Optional[list[Choice]]:
+        heuristics = self.get_heuristics(self.metadata.name)
+        for heuristic in heuristics:
+            if heuristic.check_precondition(self.metadata, self.context):
+                choices = heuristic.get_decisions_ranked(self.context)
+                if choices is None:
+                    return None
+                avail_choices = [
+                    choice for choice in choices if choice in self.metadata.choices
+                ]
+                return avail_choices[:top_k]
+        return None
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/learnedheuristic_interface.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/learnedheuristic_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..84a941b076c314d9961af916a5a559e9948c0e00
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/learnedheuristic_interface.py
@@ -0,0 +1,89 @@
+import operator
+from typing import Optional
+
+from torch._inductor.autoheuristic.autoheuristic_utils import (
+    AHContext,
+    AHMetadata,
+    Choice,
+)
+
+
+class LearnedHeuristic:
+    """
+    LearnedHeuristic is a base class for all learned heuristics.
+    """
+
+    def __init__(self) -> None:
+        pass
+
+    def check_precondition(
+        self,
+        metadata: AHMetadata,
+        context: AHContext,
+    ) -> bool:
+        return True
+
+    def get_decision(
+        self, context: AHContext, choices: list[Choice]
+    ) -> Optional[Choice]:
+        return None
+
+    def get_confidence_threshold(self) -> float:
+        return 1.0
+
+    def get_name(self) -> str:
+        return ""
+
+    def get_decisions_ranked(self, context: AHContext) -> Optional[list[str]]:
+        return None
+
+
+class LearnedHeuristicRegression(LearnedHeuristic):
+    def get_feedback(self, context: AHContext, choice: Choice) -> float:
+        return 1.0
+
+    def get_decision(
+        self, context: AHContext, choices: list[Choice]
+    ) -> Optional[Choice]:
+        choice2feedback = {}
+        for choice in choices:
+            predicted_feedback = self.get_feedback(context, choice)
+            choice2feedback[choice] = predicted_feedback
+        sorted_choices_feedback = sorted(
+            choice2feedback.items(), key=operator.itemgetter(1)
+        )
+        highest_feedback = sorted_choices_feedback[-1][1]
+        second_highest_feedback = sorted_choices_feedback[-2][1]
+        if highest_feedback / second_highest_feedback > self.get_confidence_threshold():
+            return sorted_choices_feedback[-1][0]
+        # We are not sure which choice is the best one
+        return None
+
+
+class LearnedHeuristicDecision(LearnedHeuristic):
+    def get_choice(self, idx: int) -> Optional[str]:
+        return None
+
+    def get_decision(
+        self, context: AHContext, choices: list[Choice]
+    ) -> Optional[Choice]:
+        best_choices = self.get_best_choices(context)
+        if not best_choices:
+            return None
+        (best_choice_proba, best_choice_idx) = best_choices[0]
+        if best_choice_proba <= self.get_confidence_threshold():
+            return None
+        return self.get_choice(best_choice_idx)
+
+    def get_decisions_ranked(self, context: AHContext) -> Optional[list[str]]:
+        feedback_idx_list = self.get_best_choices(context)
+        if feedback_idx_list is None:
+            return None
+        choices = [
+            self.get_choice(feedback_idx[1]) for feedback_idx in feedback_idx_list
+        ]
+        choices = [choice for choice in choices if choice is not None]
+        return choices
+
+    def get_best_choices(self, context: AHContext) -> Optional[list[tuple[float, int]]]:
+        return []
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/aoti_hipify_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/aoti_hipify_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..eca4f85ced9260e9122db73366ab4a136b7cc4ab
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/aoti_hipify_utils.py
@@ -0,0 +1,36 @@
+import re
+
+import torch
+
+
+# It is not a good idea to directly apply hipify_torch to codegen, which will be vulnerable to cases like:
+#   "...
+#    from ..codecache import CudaKernelParamCache
+#   ..."
+# In such cases, we do not need to hipify_torch the original class/file name in codegen/codecache
+
+
+def maybe_hipify_code_wrapper(source_codes: str, force_hipify: bool = False) -> str:
+    if torch.version.hip is None and not force_hipify:
+        return source_codes
+
+    try:
+        from torch.utils.hipify.hipify_python import PYTORCH_MAP, PYTORCH_TRIE
+    except ImportError:
+        # hipify not available for non-AMD builds
+        return source_codes
+
+    def c2_repl(m: re.Match[str]) -> object:
+        return PYTORCH_MAP[m.group(0)]
+
+    # We need to redefine RE_PYTORCH_PREPROCESSOR here since in hipify_torch,
+    # it will apply positive lookbehind (?<=\W) to the pattern to avoid matching
+    # keyword at the beginning of code line. However, this can happen in codegen,
+    # which will cause the pattern to not match.
+
+    # Note that lookahead (?=\W) is still needed to keep hipification idomponent, for example
+    # we need to skip replacing "getStreamFromExternal" in "getStreamFromExternalMasqueradingAsCUDA"
+    RE_PYTORCH_PREPROCESSOR = re.compile(rf"({PYTORCH_TRIE.export_to_regex()})(?=\W)")
+
+    source_codes = RE_PYTORCH_PREPROCESSOR.sub(c2_repl, source_codes)  # type: ignore[arg-type]
+    return source_codes
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/block_analysis.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/block_analysis.py
new file mode 100644
index 0000000000000000000000000000000000000000..b47c8325e21545a9ca30f513a22b22480b4d6ab0
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/block_analysis.py
@@ -0,0 +1,192 @@
+import collections
+import functools
+import textwrap
+from typing import Optional
+
+import sympy
+from sympy import Expr, Symbol
+
+from torch.utils._sympy.functions import FloorDiv, ModularIndexing
+
+from ..utils import sympy_dot, sympy_subs
+from ..virtualized import V
+
+
+class BlockPatternMatcher:
+    """
+    Matches block indexing expressions.
+    """
+
+    _indexing_wild_signed_int = functools.partial(
+        sympy.Wild, properties=[lambda x: x.is_integer]
+    )
+    _indexing_wild_unsigned_int = functools.partial(
+        sympy.Wild, properties=[lambda x: x.is_integer and x.is_nonnegative]
+    )
+
+    @classmethod
+    def get_subexpr_involving_symbol(cls, expr: Expr, symbol: Symbol) -> Expr:
+        """
+        Given a sympy expression, return the subexpression comprised only of terms
+        involving the specified symbol.
+
+        For example, if `expr` is `x * 5 + x ** 2 + y * 2 + 5`, and `symbol` is `x`,
+        this returns `x * 5 + x ** 2`.
+        """
+        expr = cls._preprocess(expr)
+        return sympy.S.Zero + sum(
+            term for term in sympy.Add.make_args(expr) if symbol in term.free_symbols
+        )
+
+    @staticmethod
+    def get_slice_numels(dims: list[Expr]) -> list[Expr]:
+        """
+        Compute the cumulative size of each dimension's slice.
+        This proceeds from the last dim up to the second.
+        """
+        numels = collections.deque([sympy.S.One])
+        for dim in dims[:0:-1]:
+            numel = dim * numels[0]
+            numels.appendleft(numel)
+        return [*numels]
+
+    @staticmethod
+    def _preprocess(expr: Expr) -> Expr:
+        # Remove any Identity nodes, e.g. expand x + (5 * y) to x + 5 * y.
+        return expr.expand(identity=True)
+
+    @classmethod
+    def match_mod_div_block_expr(
+        cls,
+        index: Expr,
+        index_var: Symbol,
+        numel: Expr,
+        num_dims: int,
+    ) -> Optional[tuple[list[Expr], list[Expr], list[Expr]]]:
+        """
+        Matches modular indexing expressions, converting them to implied block dimensions and strides.
+        See triton.py for more information.
+        """
+        index = cls._preprocess(index)
+
+        # Pattern match to find the strides and offset.
+        wild_unsigned_int = functools.partial(
+            cls._indexing_wild_unsigned_int, exclude=[index_var]
+        )
+        wild_signed_int = functools.partial(
+            cls._indexing_wild_signed_int, exclude=[index_var]
+        )
+        dims: list[Expr] = [
+            wild_unsigned_int(f"dim_mod{idx}") for idx in range(num_dims)
+        ]
+        strides: list[Expr] = [
+            wild_signed_int(f"stride_mod{idx}") for idx in range(num_dims)
+        ]
+
+        # The first dimension's index is computed by division.
+        # The remaining are computed by modulo.
+        slice_numels = cls.get_slice_numels(dims[:num_dims])
+        block_index_exprs = [FloorDiv(index_var, slice_numels[0])] + [
+            ModularIndexing(index_var, numel, dim)
+            for dim, numel in zip(dims[1:], slice_numels[1:])
+        ]
+
+        # Calculate a linear index from block indices.
+        match_expr = sympy_dot(strides, block_index_exprs)
+
+        # Heuristic: if the number of dimensions is high, check that the minimum requirements
+        # are met before attempting an expensive full match. see triton.py:match_mod_div_block
+        # for more details. In short, here we check that each subexpression in sympy.Add contains
+        # only FloorDiv or ModularIndexing expressions.
+        if num_dims >= 5:
+            stride = sympy.symbols("stride", cls=wild_signed_int)
+            denom, other = sympy.symbols("denominator other", cls=wild_unsigned_int)
+            mod_div_pattern = stride * ModularIndexing(index_var, denom, other)
+            floor_div_pattern = stride * FloorDiv(index_var, denom)
+            first_dim_floor_div_matched = False
+            match_failed = False
+            for arg in sympy.Add.make_args(index):
+                if arg.match(floor_div_pattern):
+                    # There should only be a single FloorDiv(index, denom) expression
+                    # corresponding to the first dimension
+                    if first_dim_floor_div_matched:
+                        match_failed = True
+                        break
+                    first_dim_floor_div_matched = True
+                elif arg.match(mod_div_pattern):
+                    continue
+                else:
+                    match_failed = True
+                    break
+
+            if match_failed:
+                return None
+
+        # Pattern match.
+        match = index.match(match_expr)
+        if match is None:
+            return None
+
+        # Provide default values for unmatched dims and strides.
+        for dim in dims[1:]:
+            if dim not in match:
+                match[dim] = sympy.S.One
+        for stride in strides[1:]:
+            if stride not in match:
+                match[stride] = sympy.S.Zero
+
+        sizevars = V.graph.sizevars
+
+        def get_match(expr: Expr) -> Expr:
+            return sizevars.lookup_precomputed_size(match[expr])
+
+        # Replace wildcards with matched expressions.
+        dims = [dims[0]] + [get_match(dim) for dim in dims[1:]]
+        strides = [get_match(stride) for stride in strides]
+        slice_numels = cls.get_slice_numels(dims)
+        block_index_exprs = [sympy_subs(expr, match) for expr in block_index_exprs]
+
+        # The leading dimension is not directly matched in our expression.
+        # We solve for it by dividing the range tree numel by the product of
+        # all other dimensions. We quit if they are not known to be divisible.
+        assert dims[0] not in match, "Expected not to match the leading dimension!"
+        if not sizevars.statically_known_multiple_of(numel, slice_numels[0]):
+            return None
+        dims[0] = numel / slice_numels[0]
+
+        # Sanity check that we can recover the index from the matched subexpressions.
+        matched_index = sympy_dot(strides, block_index_exprs)
+        assert sizevars.statically_known_equals(
+            # New precomputed replacements may be generated when the `get_match` function
+            # above is called, but the `index` that is being matched has not been updated.
+            # So remove them when checking for equivalence e.g. if ps0=3*s0 and
+            # index=3*s0*expr, matched_index=ps0*expr, then index == matched_index
+            sizevars.remove_precomputed_replacements(matched_index),
+            sizevars.remove_precomputed_replacements(index),
+        ), textwrap.dedent(
+            f"""
+            Invalid match!
+            Index: {index}
+            Matched expression: {matched_index}
+            """
+        )
+
+        return dims, strides, block_index_exprs
+
+    @classmethod
+    def match_affine_block_expr(
+        cls,
+        index: Expr,
+        index_var: Symbol,
+    ) -> Optional[Expr]:
+        """
+        Matches simple expressions of the form stride * index, returning the
+        stride.
+        """
+        index = cls._preprocess(index)
+        stride = cls._indexing_wild_signed_int(name="stride", exclude=[index_var])
+        m = index.match(index_var * stride)
+        if m is None:
+            return None
+
+        return m[stride]
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/common.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..e27336af8eab90cf38d6799515df6f6992da0ee5
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/common.py
@@ -0,0 +1,2918 @@
+from __future__ import annotations
+
+import atexit
+import contextlib
+import dataclasses
+import enum
+import functools
+import itertools
+import logging
+import math
+import operator
+import os
+import re
+import tempfile
+from abc import ABC, abstractmethod
+from enum import auto, Enum
+from itertools import chain
+from typing import (
+    Any,
+    cast,
+    ClassVar,
+    Generic,
+    NamedTuple,
+    Optional,
+    TYPE_CHECKING,
+    Union,
+)
+from typing_extensions import Self, TypeVar
+
+import sympy
+
+import torch
+import torch.fx
+from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
+from torch.utils import _pytree as pytree
+from torch.utils._config_module import ConfigModule
+from torch.utils._ordered_set import OrderedSet
+from torch.utils._sympy.numbers import int_oo
+from torch.utils._sympy.printers import PythonPrinter as _PythonPrinter
+from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
+from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
+
+from .. import config, metrics
+from ..dtype_propagation import DtypePropagationOpsHandler
+from ..ops_handler import BasicMathOpsMixin, DefaultHandler
+from ..shape_propagation import ShapePropagationOpsHandler
+from ..utils import (
+    boolean_ops,
+    DeferredLineBase,
+    generate_assert,
+    get_current_backend,
+    IndentedBuffer,
+    ir_dataclass,
+    ScopedDict,
+    sympy_dot,
+    sympy_index_symbol,
+    sympy_subs,
+    triton_type,
+    unique,
+)
+from ..virtualized import (
+    NullHandler,
+    ops,
+    OpsHandler,
+    OpsValue,
+    ReductionType,
+    StoreMode,
+    V,
+)
+
+
+if TYPE_CHECKING:
+    from collections.abc import Callable, Iterator, MutableMapping, Sequence
+
+    from torch.fx import GraphModule
+
+    from ..custom_graph_pass import CustomGraphModulePass
+    from ..ir import Buffer, ChoiceCaller, FixedLayout, IRNode
+    from ..loop_body import LoopBody
+    from ..scheduler import BaseScheduling, Scheduler, SchedulerNode
+    from ..shape_propagation import BlockShapeType
+    from .wrapper import PythonWrapperCodegen
+
+    _T = TypeVar("_T")
+    SchedulingConstructor = Callable[[Optional[Scheduler]], BaseScheduling]
+    WrapperConstructor = type[PythonWrapperCodegen]
+    SymbolLike = Union[str, sympy.Symbol]
+
+    # OpVarT should really be Union[CSEVariable, str], however this
+    # causes typing errors in subclasses (defined in other files).
+    OpVarT = str
+
+schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")
+log = logging.getLogger(__name__)
+
+
+def data_type_logger(msg: str) -> None:
+    if schedule_log.isEnabledFor(logging.DEBUG):
+        schedule_log.debug("Data type propagation: %s", msg)
+
+
+@dataclasses.dataclass
+class FileBackedGraphModule:
+    """
+    Output of FX wrapper codegen. Exposes the same methods as ModuleType, but these
+    map back to a GraphModule instead of Python source.
+    """
+
+    gm: GraphModule
+    compiled_fn: Callable[..., Any]
+
+    def __post_init__(self) -> None:
+        # Write the code to a file for compatibility with debugging utilities.
+        # The file is deleted upon program termination.
+        self.tempfile = tempfile.NamedTemporaryFile(  # noqa: SIM115
+            mode="w+", suffix=".py", delete=False
+        )
+        atexit.register(os.remove, self.tempfile.name)
+        with self.tempfile as f:
+            f.write(self.value)
+
+    @property
+    def __file__(self) -> str:
+        return self.tempfile.name
+
+    def call(self, args: list[Any]) -> Any:
+        return self.compiled_fn(*args)
+
+    @property
+    def value(self) -> str:
+        return self.gm.code
+
+
+class WorkspaceZeroMode(enum.Enum):
+    UNINITIALIZED = 0
+    ZERO_ON_CALL = 1  # kernel may leave workspace dirty
+    ZERO_PER_GRAPH = 2  # must be re-zeroed by kernel
+
+    @staticmethod
+    def combine(a: WorkspaceZeroMode, b: WorkspaceZeroMode) -> WorkspaceZeroMode:
+        if a == b or b == WorkspaceZeroMode.UNINITIALIZED:
+            return a
+        if a == WorkspaceZeroMode.UNINITIALIZED:
+            return b
+        raise NotImplementedError(f"WorkspaceZeroMode.combine({a!r}, {b!r})")
+
+    @staticmethod
+    def from_bool(zero_fill: bool) -> WorkspaceZeroMode:
+        if zero_fill:
+            return WorkspaceZeroMode.ZERO_ON_CALL
+        return WorkspaceZeroMode.UNINITIALIZED
+
+
+class CodegenSymbol(ABC):
+    """
+    An IR object possibly corresponding to a variable in the wrapper code.
+    """
+
+    @abstractmethod
+    def get_name(self) -> str:
+        pass
+
+    @abstractmethod
+    def get_example(self) -> Union[torch.Tensor, sympy.Symbol]:
+        pass
+
+
+@ir_dataclass(frozen=True)
+class WorkspaceArg(CodegenSymbol):
+    """A temporary buffer used for a single kernel, then discarded.
+
+    Not registered as a traditional buffer since there are no users,
+    so it would be dead code eliminated.
+
+    Args:
+        nbytes: The size of the buffer in bytes.
+        zero_fill: Whether the buffer should be initialized to zero.
+
+    """
+
+    count: sympy.Expr
+    zero_mode: WorkspaceZeroMode
+    device: torch.device
+    outer_name: str
+    inner_name: str = "ws_ptr"
+    dtype: torch.dtype = torch.uint8
+
+    @staticmethod
+    def unique_name(prefix: str = "workspace_") -> str:
+        return f"{prefix}{next(V.graph.workspace_id)}"
+
+    @staticmethod
+    def can_join(a: WorkspaceArg, b: WorkspaceArg) -> bool:
+        return (
+            a.inner_name == b.inner_name and a.dtype == b.dtype and a.device == b.device
+        )
+
+    @staticmethod
+    def join(a: WorkspaceArg, b: WorkspaceArg) -> WorkspaceArg:
+        return WorkspaceArg(
+            count=a.count + b.count,
+            zero_mode=WorkspaceZeroMode.combine(a.zero_mode, b.zero_mode),
+            dtype=a.dtype,
+            device=a.device,
+            inner_name=a.inner_name,
+            outer_name=a.outer_name,
+        )
+
+    @staticmethod
+    def maximum(a: WorkspaceArg, b: WorkspaceArg) -> WorkspaceArg:
+        assert (
+            a.dtype == b.dtype and a.device == b.device and a.inner_name == b.inner_name
+        )
+        return WorkspaceArg(
+            count=sympy.Max(a.count, b.count),
+            zero_mode=WorkspaceZeroMode.combine(a.zero_mode, b.zero_mode),
+            dtype=a.dtype,
+            device=a.device,
+            inner_name=a.inner_name,
+            outer_name=a.outer_name,
+        )
+
+    # These methods let WorkspaceArg pretend it is a buffer to reuse allocation code
+    def get_device(self) -> torch.device:
+        return self.device
+
+    get_device_or_error = get_device
+
+    def get_dtype(self) -> torch.dtype:
+        return self.dtype
+
+    def get_example(self) -> Union[torch.Tensor, sympy.Symbol]:
+        return self.get_layout().get_example()
+
+    def get_layout(self) -> FixedLayout:
+        from ..ir import FixedLayout
+
+        return FixedLayout(
+            device=self.device,
+            dtype=self.dtype,
+            size=[self.count],
+            stride=[1],
+        )
+
+    @property
+    def layout(self) -> FixedLayout:
+        return self.get_layout()
+
+    get_output_spec = get_layout
+    maybe_get_output_spec = get_layout
+    maybe_get_layout = get_layout
+
+    def get_offset(self) -> sympy.Expr:
+        return sympy.S.Zero
+
+    def get_size(self) -> list[sympy.Expr]:
+        return [self.count]
+
+    def get_stride(self) -> list[sympy.Expr]:
+        return [sympy.S.One]
+
+    def get_name(self) -> str:
+        return self.outer_name
+
+    def get_is_pinned(self) -> bool:
+        return False
+
+    def get_inputs_that_alias_output(self) -> list[str]:
+        return []
+
+
+class TritonScratchWorkspace:
+    def __init__(self, size: int, generate_dtype_str: Callable[..., str]):
+        self.size = size
+        self._generate_dtype_str = generate_dtype_str
+
+    def generate_dtype_str(self) -> str:
+        return self._generate_dtype_str()
+
+
+@dataclasses.dataclass
+class TensorArg:
+    name: str
+    buffer: str
+    dtype: torch.dtype
+    offset: sympy.Expr = sympy.S.Zero  # c++ only
+    alias_of: Optional[str] = None  # halide only
+
+
+@dataclasses.dataclass
+class SizeArg:
+    name: str
+    expr: sympy.Expr
+
+    @property
+    def alias_of(self) -> Optional[str]:
+        return None
+
+
+@dataclasses.dataclass
+class ConstexprArg:
+    name: str
+
+
+@dataclasses.dataclass
+class TMADescriptorArg:
+    name: str
+    api_type: str  # "experimental" or "stable"
+    block_shape: Optional[list[sympy.Expr]]  # only needed for "stable"
+    dtype: Optional[torch.dtype]  # only needed for "stable"
+
+
+@dataclasses.dataclass
+class DeviceCodegen:
+    scheduling: SchedulingConstructor
+    wrapper_codegen: WrapperConstructor
+    cpp_wrapper_codegen: Optional[WrapperConstructor] = None
+    fx_wrapper_codegen: Optional[WrapperConstructor] = None
+
+
+KernelArgType = Union[WorkspaceArg, TensorArg, SizeArg, TMADescriptorArg, ConstexprArg]
+
+device_codegens: dict[str, DeviceCodegen] = {}
+
+
+class DeviceOpOverrides:
+    def import_get_raw_stream_as(self, name: str) -> str:
+        raise NotImplementedError
+
+    def set_device(self, device_idx: int) -> str:
+        raise NotImplementedError
+
+    def synchronize(self) -> str:
+        raise NotImplementedError
+
+    def device_guard(self, device_idx: int) -> str:
+        raise NotImplementedError
+
+    def cpp_device_guard(self) -> str:
+        raise NotImplementedError
+
+    def cpp_aoti_device_guard(self) -> str:
+        raise NotImplementedError
+
+    def cpp_stream_guard(self) -> str:
+        raise NotImplementedError
+
+    def cpp_aoti_stream_guard(self) -> str:
+        raise NotImplementedError
+
+    def cpp_getStreamFromExternal(self) -> str:
+        raise NotImplementedError
+
+    def kernel_header(self) -> str:
+        raise NotImplementedError
+
+    def kernel_driver(self) -> str:
+        raise NotImplementedError
+
+    def cpp_stream_type(self) -> str:
+        raise NotImplementedError
+
+    def aoti_get_stream(self) -> str:
+        raise NotImplementedError
+
+    def cpp_kernel_type(self) -> str:
+        raise NotImplementedError
+
+    def cpp_device_ptr(self) -> str:
+        raise NotImplementedError
+
+    def tma_descriptor_helpers(self) -> str:
+        raise NotImplementedError
+
+    def cpp_scratch(
+        self, idx: int, workspace: TritonScratchWorkspace, prefix: Optional[str] = None
+    ) -> Optional[tuple[list[str], str]]:
+        # optionally return (scratch definition, arg name)
+        raise NotImplementedError
+
+
+device_op_overrides_dict: dict[str, DeviceOpOverrides] = {}
+custom_backend_passes: dict[str, Optional[CustomGraphModulePass]] = {}
+custom_backend_codegen_configs: dict[str, Optional[ConfigModule]] = {}
+
+
+# The code generated by Inductor consists of two main parts: kernel code and wrapper code.
+# For any new backend looking to integrate with Inductor, customization of these two main
+# parts are necessary to generate its specific code.
+#
+# Kernel code generation is determined by different Scheduling. Consequently, a new
+# backend needs to provide a custom Scheduling for its unique kernel code generation. Currently,
+# CppScheduling and TritonScheduling serve the C++/OpenMP and Triton backends, respectively.
+#
+# For the Wrapper, Inductor provides a PythonWrapperCodegen class to generate the Python wrapper code
+# that bridges kernels. This allows out-of-tree backends to inherit from PythonWrapperCodegen,
+# and override specific member functions to create backend-specific Python wrapper code.
+#
+# Other classes, such as CppKernel and TritonKernel, used for code generation, typically form part
+# of the logic for either Scheduling or PythonWrapperCodegen. So the Scheduling and PythonWrapperCodegen interfaces
+# provide flexibility to the backend. A backend can choose to implement these classes from scratch,
+# or reuse them by extending and overriding as necessary. And Inductor provides the registration API,
+# register_backend_for_device, to equip a new backend at runtime.
+#
+# Intel has developed a new backend on top of Triton to support Intel GPUs, leveraging these interfaces.
+# This backend can be used as a reference:
+# https://github.com/intel/intel-extension-for-pytorch/blob/5dcc9d57e5422cf295e1a1ee97896d6b6a554a85/intel_extension_for_pytorch/_inductor/__init__.py#L9
+def register_backend_for_device(
+    device: str,
+    device_scheduling: SchedulingConstructor,
+    device_wrapper_codegen: WrapperConstructor,
+    device_cpp_wrapper_codegen: Optional[WrapperConstructor] = None,
+    device_fx_wrapper_codegen: Optional[WrapperConstructor] = None,
+    device_custom_pass: Optional[CustomGraphModulePass] = None,
+    device_custom_config: Optional[ConfigModule] = None,
+) -> None:
+    device_codegens[device] = DeviceCodegen(
+        device_scheduling,
+        device_wrapper_codegen,
+        device_cpp_wrapper_codegen,
+        device_fx_wrapper_codegen,
+    )
+    custom_backend_passes[device] = device_custom_pass
+    if device_custom_config:
+        assert (
+            isinstance(device_custom_config, ConfigModule)
+            and device_custom_config is not config
+        ), (
+            f"{device_custom_config=} cannot be the same as the default inductor config {config=}"
+        )
+    custom_backend_codegen_configs[device] = device_custom_config
+
+
+class BackendFeature(Enum):
+    FOREACH = auto()
+    BUCKETIZE = auto()
+    INPLACE_BUFFERS = auto()
+    MASKED_SCATTER_WITH_INDEX = auto()
+    SCAN = auto()
+    SORT = auto()
+    TUPLE_REDUCTION = auto()
+    PREFER_STORE_LOOP_ORDER = auto()
+    TRITON_TEMPLATES = auto()
+    REDUCE_TO_SINGLE_ELEMENT = auto()
+
+
+def get_backend_features(
+    device: Union[torch.device, str, None],
+) -> OrderedSet[BackendFeature]:
+    if device is None:
+        return OrderedSet()
+    init_backend_registration()
+    if isinstance(device, torch.device):
+        device_type = device.type
+    else:
+        assert isinstance(device, str), type(device)
+        device_type = device
+        device = torch.device(device_type)
+    scheduling_ctor = get_scheduling_for_device(device_type)
+    assert scheduling_ctor
+    scheduling = scheduling_ctor(None)
+    return scheduling.get_backend_features(device)
+
+
+def has_backend_feature(
+    device: Union[torch.device, str, None], feature: BackendFeature
+) -> bool:
+    """See also V.graph.has_feature"""
+    assert isinstance(feature, BackendFeature)
+    return feature in get_backend_features(device)
+
+
+def get_scheduling_for_device(device: str) -> Optional[SchedulingConstructor]:
+    return device_codegens[device].scheduling if device in device_codegens else None
+
+
+def get_wrapper_codegen_for_device(
+    device: str, cpp_wrapper: bool = False, fx_wrapper: bool = False
+) -> Optional[WrapperConstructor]:
+    if device in device_codegens:
+        wrapper_codegen_obj: DeviceCodegen = device_codegens[device]
+        if fx_wrapper:
+            return wrapper_codegen_obj.fx_wrapper_codegen
+        elif cpp_wrapper:
+            return wrapper_codegen_obj.cpp_wrapper_codegen
+        else:
+            return wrapper_codegen_obj.wrapper_codegen
+    return None
+
+
+def get_custom_backend_pass_for_device(device: str) -> Optional[CustomGraphModulePass]:
+    return custom_backend_passes.get(device)
+
+
+def get_custom_backend_config_for_device(device: str) -> Optional[ConfigModule]:
+    return custom_backend_codegen_configs.get(device)
+
+
+@functools.cache
+def init_backend_registration() -> None:
+    """
+    Register the backend for different devices, including the scheduling
+    for kernel code generation and the host side wrapper code generation.
+    """
+    from .cpp import CppScheduling
+    from .cpp_wrapper_cpu import CppWrapperCpu
+    from .cpp_wrapper_cpu_array_ref import CppWrapperCpuArrayRef
+    from .cpp_wrapper_gpu import CppWrapperGpu
+    from .cpp_wrapper_mps import CppWrapperMps
+    from .cuda_combined_scheduling import CUDACombinedScheduling
+    from .halide import HalideScheduling
+    from .mps import MetalScheduling
+    from .pallas import PallasScheduling
+    from .python_wrapper_mtia import PythonWrapperMtia
+    from .triton import TritonScheduling
+    from .wrapper import PythonWrapperCodegen
+    from .wrapper_fxir import WrapperFxCodegen
+
+    if get_scheduling_for_device("cpu") is None:
+        cpu_backends = {
+            "cpp": CppScheduling,
+            "halide": HalideScheduling,
+            "triton": TritonScheduling,
+            "pallas": PallasScheduling,
+        }
+        register_backend_for_device(
+            "cpu",
+            lambda scheduling: cpu_backends[config.cpu_backend](scheduling),
+            PythonWrapperCodegen,
+            CppWrapperCpuArrayRef
+            if config.aot_inductor.allow_stack_allocation
+            else CppWrapperCpu,
+            WrapperFxCodegen,
+        )
+
+    if get_scheduling_for_device("cuda") is None:
+        # CUDACombinedScheduling combines Triton and CUDA C++ scheduling for CUDA devices via delegation
+        cuda_backends = {
+            "triton": CUDACombinedScheduling,
+            "halide": HalideScheduling,
+            "pallas": PallasScheduling,
+        }
+        register_backend_for_device(
+            "cuda",
+            lambda scheduling: cuda_backends[config.cuda_backend](scheduling),
+            PythonWrapperCodegen,
+            CppWrapperGpu,
+            WrapperFxCodegen,
+        )
+
+    if get_scheduling_for_device("xpu") is None:
+        register_backend_for_device(
+            "xpu",
+            TritonScheduling,
+            PythonWrapperCodegen,
+            CppWrapperGpu,
+            WrapperFxCodegen,
+        )
+
+    if get_scheduling_for_device("mps") is None:
+        register_backend_for_device(
+            "mps",
+            MetalScheduling,
+            PythonWrapperCodegen,
+            CppWrapperMps,
+            WrapperFxCodegen,
+        )
+
+    if get_scheduling_for_device("mtia") is None:
+        register_backend_for_device(
+            "mtia",
+            TritonScheduling,
+            PythonWrapperMtia,
+            CppWrapperGpu,
+            WrapperFxCodegen,
+        )
+
+    private_backend = torch._C._get_privateuse1_backend_name()
+    if (
+        private_backend != "privateuseone"
+        and get_scheduling_for_device(private_backend) is None
+    ):
+        from torch.utils.backend_registration import _get_custom_mod_func
+
+        try:
+            device_scheduling = _get_custom_mod_func("Scheduling")
+            wrapper_codegen = _get_custom_mod_func("PythonWrapperCodegen")
+            cpp_wrapper_codegen = _get_custom_mod_func("CppWrapperCodegen")
+            fx_wrapper_codegen = _get_custom_mod_func("WrapperFxCodegen")
+            if device_scheduling and wrapper_codegen and cpp_wrapper_codegen:
+                register_backend_for_device(
+                    private_backend,
+                    device_scheduling,
+                    wrapper_codegen,
+                    cpp_wrapper_codegen,
+                    fx_wrapper_codegen,
+                )
+        except RuntimeError:
+            pass
+
+
+def index_prevent_reordering(
+    index: Sequence[sympy.Expr],
+    index_vars: Sequence[sympy.Expr],
+    sizes: Sequence[sympy.Expr],
+) -> list[sympy.Expr]:
+    from ..ir import FlexibleLayout
+
+    # added contiguous index prevents reordering
+    return [*index, sympy_dot(index_vars, FlexibleLayout.contiguous_strides(sizes))]
+
+
+def register_device_op_overrides(
+    device: str, device_op_overrides: DeviceOpOverrides
+) -> None:
+    device_op_overrides_dict[device] = device_op_overrides
+
+
+def get_device_op_overrides(device: str) -> DeviceOpOverrides:
+    assert isinstance(device, str), type(device)
+
+    if not device_op_overrides_dict:
+        from . import cpu_device_op_overrides, mps_device_op_overrides  # noqa: F401
+        from .cuda import device_op_overrides  # noqa: F401
+        from .mtia import device_op_overrides as mtia_op_overrides  # noqa: F401
+        from .xpu import device_op_overrides as xpu_op_overrides  # noqa: F401
+
+    return device_op_overrides_dict[device]
+
+
+DTYPE_TO_COMPUTATION_DTYPE: dict[torch.dtype, torch.dtype] = {
+    torch.bfloat16: torch.float,
+    torch.float16: torch.float,
+    **{
+        dtype: dtype
+        for dtype in [
+            torch.bool,
+            torch.float32,
+            torch.float64,
+            torch.int8,
+            torch.int16,
+            torch.int32,
+            torch.int64,
+            torch.uint8,
+            torch.uint16,
+            torch.uint32,
+            torch.uint64,
+        ]
+    },
+}
+
+
+def deduce_output_dtype_by_name(
+    op_name: str,
+    *args: Any,
+    **kwargs: Any,
+) -> Optional[torch.dtype]:
+    """
+    Given op name and a list of input dtypes, deduce the output dtype
+    """
+    if op_name in boolean_ops():
+        return torch.bool
+    elif op_name in (
+        "to_dtype",
+        "index_expr",
+    ):
+        return kwargs["dtype"] if "dtype" in kwargs else args[-1]
+    elif op_name in (
+        "rand",
+        "randn",
+    ):
+        return torch.float
+    elif op_name in (
+        "get_index",
+        "randint64",
+        "load_seed",
+    ):
+        return torch.int64
+    elif op_name == "reduction":
+        return kwargs["dtype"] if "dtype" in kwargs else args[1]
+    elif op_name == "constant":
+        return kwargs["dtype"] if "dtype" in kwargs else args[-1]
+    elif op_name in (
+        "load",
+        "store",
+        "store_reduction",
+    ):
+        buf_name = args[1]
+        return V.graph.get_dtype(buf_name)  # type: ignore[arg-type]
+    elif op_name == "to_dtype_bitcast":
+        return kwargs["dtype"] if "dtype" in kwargs else args[-2]
+    return None
+
+
+def check_dtype(
+    buffer: IndentedBuffer, var: CSEVariableType, dtype: torch.dtype
+) -> None:
+    backend = get_current_backend()
+    if config.test_configs.runtime_triton_dtype_assert and backend == "triton":
+        buffer.writeline(f"tl.static_assert({var}.dtype == {triton_type(dtype)})")
+    elif config.test_configs.static_cpp_dtype_assert and backend == "cpp":
+        from .cpp_utils import CppCSEVariable, DTYPE_TO_CPP
+
+        assert isinstance(var, CppCSEVariable), type(var)
+        if dtype == torch.bool:
+            if var.is_vec:
+                is_same_dt = f"IsVecMaskType::value"
+            else:
+                # operator&(bool, bool) returns int and it can be used as boolean in C++
+                is_same_dt = f"std::is_same_v || std::is_same_v"
+        else:
+            c_var_type = f"decltype({var})"
+            if var.is_vec:
+                c_var_type = f"typename {c_var_type}::value_type"
+            is_same_dt = f"std::is_same_v<{c_var_type}, {DTYPE_TO_CPP[dtype]}>"
+
+        buffer.writeline(f"static_assert({is_same_dt});")
+
+
+def check_shape(
+    buffer: IndentedBuffer, var: CSEVariableType, shape: BlockShapeType
+) -> None:
+    backend = get_current_backend()
+    assert shape is not None
+    if config.test_configs.runtime_triton_shape_assert and backend == "triton":
+        shape_str = (
+            ", ".join(str(d) for d in shape) if len(shape) != 1 else f"{shape[0]},"
+        )
+        buffer.writeline(f"tl.static_assert({var}.shape == ({shape_str}))")
+
+
+def check_nan(buffer: IndentedBuffer, var: CSEVariableType) -> None:
+    backend = get_current_backend()
+    if backend == "triton":
+        msg = "NaN or Inf found"
+        buffer.writeline(
+            f"tl.device_assert(({var} == {var}) & ({var} != float('inf')) & ({var} != float('-inf')), '{msg}')"
+        )
+
+
+class DataTypePropagation:
+    def __init__(self, body: LoopBody) -> None:
+        self.body = body
+        self.graphs: dict[Union[Callable[..., Any], str], Any] = {
+            "root": body.root_block.graph
+        }
+        for k, v in body.subblocks.items():
+            self.graphs[k] = v.graph
+
+    def deduce_node_dtype_by_inputs(self, node: torch.fx.Node) -> Optional[torch.dtype]:
+        inputs = node.all_input_nodes
+        input_nodes = [
+            n for n in inputs if isinstance(n, torch.fx.Node) and n.op != "placeholder"
+        ]
+        if len(input_nodes) == 0:
+            return None
+
+        all_input_nodes_propagated = all(
+            OptimizationContext.key in n.meta
+            and n.meta[OptimizationContext.key].dtype is not None
+            for n in input_nodes
+        )
+        if not all_input_nodes_propagated:
+            return None
+
+        return functools.reduce(
+            torch.promote_types,
+            [n.meta[OptimizationContext.key].dtype for n in input_nodes],
+        )
+
+    def deduce_node_dtype_by_subgraph(self, node: torch.fx.Node) -> torch.dtype:
+        sub_graph = self.graphs[node.target]
+        dtype = self.propagate_graph(sub_graph)
+        assert dtype
+        return dtype
+
+    def deduce_node_dtype(self, node: torch.fx.Node) -> Optional[torch.dtype]:
+        if node.op == "placeholder":
+            return None
+
+        if node.target == "output" and len(node.args) != 1:
+            # we can infer output node if it only have 1 arg
+            return None
+
+        if node.target is operator.getitem:
+            node_arg = node.args[0]
+            assert isinstance(node_arg, torch.fx.Node), type(node_arg)
+            return self.deduce_node_dtype(node_arg)
+
+        assert isinstance(node.target, str), type(node.target)
+
+        if node.target.startswith("masked_subblock"):
+            return self.deduce_node_dtype_by_subgraph(node)
+
+        if (
+            output_dtype := deduce_output_dtype_by_name(
+                node.target,
+                *node.args,
+                **node.kwargs,
+            )
+        ) is not None:
+            return output_dtype
+
+        return self.deduce_node_dtype_by_inputs(node)
+
+    def propagate_graph(self, graph: torch.fx.Graph) -> Optional[torch.dtype]:
+        assert graph.nodes
+        graph_dtype: Optional[torch.dtype] = None
+        # For masked_subblock, we use output's dtype to represent
+        # the dtype of this subgraph. For other cases, graph_dtype
+        # might be None
+        for node in graph.nodes:
+            if OptimizationContext.key in node.meta:
+                opt_ctx = node.meta[OptimizationContext.key]
+            else:
+                opt_ctx = OptimizationContext()
+
+            opt_ctx.dtype = self.deduce_node_dtype(node)
+            node.meta[OptimizationContext.key] = opt_ctx
+            if node.target == "output":
+                graph_dtype = opt_ctx.dtype
+        return graph_dtype
+
+    def propagate(self) -> Optional[torch.dtype]:
+        return self.propagate_graph(self.graphs["root"])
+
+    @classmethod
+    def propagate_loopbody(cls, body: LoopBody) -> Optional[torch.dtype]:
+        return cls(body).propagate()
+
+    @classmethod
+    def propagate_scheduler_node(cls, node: SchedulerNode) -> Optional[torch.dtype]:
+        from ..loop_body import LoopBody
+        from ..scheduler import SchedulerNode
+
+        assert isinstance(node, SchedulerNode), type(node)
+        assert isinstance(node._body, LoopBody), type(node._body)
+        return DataTypePropagation.propagate_loopbody(node._body)
+
+
+class PythonPrinter(_PythonPrinter):
+    def doprint(
+        self, expr: sympy.Expr, *, simplify: bool = True, p: bool = True
+    ) -> str:
+        # TODO: why are people passing strings to the printer here :think:
+        if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"):
+            expr = V.graph.sizevars.simplify(expr)
+        return super().doprint(expr)
+
+    def parenthesize(self, item: sympy.Expr, level: int, strict: bool = False) -> str:
+        if isinstance(item, sympy.Mod):
+            # use parenthesis to enforce precedence.
+            # in sympy 1.13.3, -2*Mod(x,y) becomes -2*x%y, which is wrong.
+            return f"({self._print(item)})"
+        else:
+            return super().parenthesize(item, level, strict)
+
+
+class OpDecompositions:
+    """
+    Decomposes inductor ops
+    """
+
+    @staticmethod
+    def identity(value: OpVarT) -> OpVarT:
+        # used to trigger cse
+        return value
+
+    @staticmethod
+    def reciprocal(x: OpVarT) -> OpVarT:
+        return ops.truediv(ops.constant(1, torch.int32), x)
+
+    @staticmethod
+    def square(x: OpVarT) -> OpVarT:
+        return ops.mul(x, x)
+
+    @staticmethod
+    def erfc(x: OpVarT) -> OpVarT:
+        return ops.sub(ops.constant(1, torch.float32), ops.erf(x))
+
+    @staticmethod
+    def erfcx(x: OpVarT) -> OpVarT:
+        return ops.mul(ops.exp(ops.square(x)), ops.erfc(x))
+
+    @staticmethod
+    def expm1(x: OpVarT) -> OpVarT:
+        return ops.sub(ops.exp(x), ops.constant(1, torch.float32))
+
+    @staticmethod
+    def log10(x: OpVarT) -> OpVarT:
+        return ops.mul(ops.log(x), ops.constant(1 / math.log(10), torch.float32))
+
+    @staticmethod
+    def log2(x: OpVarT) -> OpVarT:
+        return ops.mul(ops.log(x), ops.constant(1 / math.log(2), torch.float32))
+
+    @staticmethod
+    def exp2(x: OpVarT) -> OpVarT:
+        return ops.exp(ops.mul(x, ops.constant(math.log(2), torch.float32)))
+
+    @staticmethod
+    def log1p(x: OpVarT) -> OpVarT:
+        return ops.log(ops.add(x, ops.constant(1, torch.int32)))
+
+    @staticmethod
+    def sigmoid(x: OpVarT) -> OpVarT:
+        one = ops.constant(1, torch.int32)
+        return ops.truediv(one, ops.add(one, ops.exp(ops.neg(x))))
+
+    @staticmethod
+    def relu(x: OpVarT) -> OpVarT:
+        return ops.maximum(x, ops.constant(0, torch.int32))
+
+    @staticmethod
+    def fma(x: OpVarT, y: OpVarT, z: OpVarT) -> OpVarT:
+        # for backends that don't override this (halide)
+        return ops.add(ops.mul(x, y), z)
+
+    @staticmethod
+    def floor_to_int(a: OpVarT, dtype: torch.dtype) -> OpVarT:
+        return ops.to_dtype(ops.floor(a), dtype)
+
+    @staticmethod
+    def ceil_to_int(a: OpVarT, dtype: torch.dtype) -> OpVarT:
+        return ops.to_dtype(ops.ceil(a), dtype)
+
+    @staticmethod
+    def trunc_to_int(a: OpVarT, dtype: torch.dtype) -> OpVarT:
+        return ops.to_dtype(ops.trunc(a), dtype)
+
+    @staticmethod
+    def remainder(a: OpVarT, b: OpVarT) -> OpVarT:
+        r = ops.mod(a, b)
+        cond = ops.and_(
+            ops.ne(r, ops.constant(0, torch.int32)),
+            ops.ne(ops.signbit(r), ops.signbit(b)),
+        )
+        return ops.where(cond, ops.add(r, b), r)
+
+    @staticmethod
+    def round_to_int(a: OpVarT, dtype: torch.dtype) -> OpVarT:
+        return ops.to_dtype(ops.round(a), dtype)
+
+
+_RE_PAREN_NOT_NEEDED = re.compile(r"[a-z0-9_.]+|\([^)]*\)|", flags=re.IGNORECASE)
+
+
+def _all_in_parens(string: str) -> bool:
+    if string[0] != "(" or len(string) < 2:
+        return False
+    count = 1
+    for i, char in enumerate(string[1:]):
+        if char == "(":
+            count += 1
+        elif char == ")":
+            count -= 1
+        if count == 0 and i != len(string) - 2:
+            return False
+    assert count == 0
+    return True
+
+
+class OpOverrides(BasicMathOpsMixin, OpDecompositions, OpsHandler[Any]):
+    @staticmethod
+    def paren(string: OpVarT) -> OpVarT:
+        if (
+            isinstance(string, CSEVariable)
+            or _RE_PAREN_NOT_NEEDED.fullmatch(string)
+            or _all_in_parens(string)
+        ):
+            # don't put extra parens for strings that are already wrapped in parens
+            # pyrefly: ignore [bad-return]
+            return string
+        return f"({string})"
+
+    @staticmethod
+    def constant(value: Union[bool, float, int], dtype: torch.dtype) -> OpVarT:
+        return repr(value)
+
+    @staticmethod
+    def bitwise_not(x: OpVarT) -> OpVarT:
+        return f"~{OpOverrides.paren(x)}"
+
+    @staticmethod
+    def logical_not(a: OpVarT) -> OpVarT:
+        return f"{OpOverrides.paren(a)} == 0"
+
+    @staticmethod
+    def bitwise_and(x: OpVarT, y: OpVarT) -> OpVarT:
+        return f"{OpOverrides.paren(x)} & {OpOverrides.paren(y)}"
+
+    @staticmethod
+    def bitwise_or(x: OpVarT, y: OpVarT) -> OpVarT:
+        return f"{OpOverrides.paren(x)} | {OpOverrides.paren(y)}"
+
+    @staticmethod
+    def bitwise_xor(x: OpVarT, y: OpVarT) -> OpVarT:
+        return f"{OpOverrides.paren(x)} ^ {OpOverrides.paren(y)}"
+
+    @staticmethod
+    def bitwise_left_shift(x: OpVarT, y: OpVarT) -> OpVarT:
+        return f"{OpOverrides.paren(x)} << {OpOverrides.paren(y)}"
+
+    @staticmethod
+    def bitwise_right_shift(x: OpVarT, y: OpVarT) -> OpVarT:
+        return f"{OpOverrides.paren(x)} >> {OpOverrides.paren(y)}"
+
+    @staticmethod
+    def int_truediv(a: OpVarT, b: OpVarT) -> OpVarT:
+        # TODO: this is wrong
+        # TODO: an easy bandaid is to generate runtime asserts that it's
+        # <= 2**53, which is when this equation is correct
+        return ops.truediv(a, b)
+
+    @staticmethod
+    def load_seed(name: str, offset: OpVarT) -> OpVarT:
+        return ops.load(name, sympy.Integer(offset))
+
+    def indirect_indexing(
+        self,
+        var: OpVarT,
+        size: Union[sympy.Expr, int],
+        check: bool = True,
+        wrap_neg: bool = True,
+    ) -> sympy.Symbol:
+        return sympy_index_symbol(str(var))
+
+    def check_bounds(
+        self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
+    ) -> None:
+        raise NotImplementedError(
+            f"{type(self).__name__}: check_bounds should be handled by CSEProxy"
+        )
+
+    def load(self, name: str, index: sympy.Expr) -> OpVarT:
+        raise NotImplementedError(
+            f"{type(self).__name__}: load should be handled by CSEProxy"
+        )
+
+    def store(
+        self, name: str, index: sympy.Expr, value: OpVarT, mode: StoreMode = None
+    ) -> None:
+        raise NotImplementedError(
+            f"{type(self).__name__}: store should be handled by CSEProxy"
+        )
+
+    def device_assert_async(self, cond: CSEVariable, msg: str) -> None:
+        raise NotImplementedError(
+            f"{type(self).__name__}: device_assert_async should be handled by CSEProxy"
+        )
+
+    def store_reduction(self, name: str, index: sympy.Expr, value: OpVarT) -> None:
+        raise NotImplementedError(
+            f"{type(self).__name__}: store_reduction should be handled by CSEProxy"
+        )
+
+    def reduction(
+        self,
+        dtype: torch.dtype,
+        src_dtype: torch.dtype,
+        reduction_type: ReductionType,
+        value: Union[OpVarT, tuple[OpVarT, ...]],
+    ) -> Union[OpVarT, tuple[OpVarT, ...]]:
+        raise NotImplementedError(
+            f"{type(self).__name__}: reduction should be handled by CSEProxy"
+        )
+
+    def scan(
+        self,
+        dtypes: tuple[torch.dtype, ...],
+        combine_fn: Callable[
+            [tuple[OpVarT, ...], tuple[OpVarT, ...]],
+            tuple[OpVarT, ...],
+        ],
+        values: tuple[OpVarT, ...],
+    ) -> tuple[OpVarT, ...]:
+        raise NotImplementedError(
+            f"{type(self).__name__}: scan should be handled by CSEProxy"
+        )
+
+    def sort(
+        self,
+        dtypes: tuple[torch.dtype, ...],
+        values: tuple[OpVarT, ...],
+        stable: bool,
+        descending: bool,
+    ) -> tuple[OpVarT, ...]:
+        raise NotImplementedError(
+            f"{type(self).__name__}: sort should be handled by CSEProxy"
+        )
+
+    def bucketize(
+        self,
+        values: OpVarT,
+        boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr],
+        boundary_indices: OpVarT,
+        indexing_dtype: torch.dtype,
+        right: bool,
+        sorter: Optional[tuple[str, sympy.Expr]] = None,
+        sorter_indices: Optional[OpVarT] = None,
+    ) -> OpVarT:
+        raise NotImplementedError(
+            f"{type(self).__name__}: bucketize should be handled by CSEProxy"
+        )
+
+    def halide_clamp(self, value: OpVarT, size: sympy.Expr, check: bool) -> OpVarT:
+        raise NotImplementedError(
+            f"{type(self).__name__}: halide_clamp only implemented for Halide backend"
+        )
+
+    def dot(self, x: OpVarT, y: OpVarT) -> OpVarT:
+        raise NotImplementedError(
+            f"{type(self).__name__}: dot only implemented for Triton backend"
+        )
+
+    def inline_asm_elementwise(
+        self,
+        *inputs: OpVarT,
+        asm: str,
+        constraints: Optional[str] = None,
+        dtype: torch.dtype = torch.float32,
+        is_pure: bool = True,
+        pack: int = 1,
+    ) -> OpVarT:
+        raise NotImplementedError(
+            f"{type(self).__name__}: inline_asm_elementwise only implemented for Triton backend"
+        )
+
+    def output(self, *args: OpVarT) -> None:
+        raise AssertionError(
+            f"{type(self).__name__}: ops.output should not appear at codegen time"
+        )
+
+    def placeholder(self, index: int) -> OpVarT:
+        raise AssertionError(
+            f"{type(self).__name__}: ops.placeholder should not appear at codegen time"
+        )
+
+    @staticmethod
+    def _unimplemented(name: str) -> Callable[..., OpVarT]:
+        def unimplemented(self: OpOverrides, *args: Any, **kwargs: Any) -> OpVarT:
+            raise NotImplementedError(
+                f"{type(self).__name__} does not implement ops.{name}"
+            )
+
+        unimplemented.__name__ = name
+        unimplemented.is_unimplemented = True  # type: ignore[attr-defined]
+        return unimplemented
+
+    @classmethod
+    def _is_unimplemented(cls, name: str) -> bool:
+        fn = getattr(cls, name, None)
+        default_fn = getattr(OpsHandler, name, None)
+        return not fn or fn == default_fn or getattr(fn, "is_unimplemented", False)
+
+    @classmethod
+    def _initialize_pointwise_overrides(cls, target: str) -> None:
+        assert target in ("triton", "cpp", "cppvec", "halide", "mps"), target
+
+        for funcname, data in pointwise_overrides_data.items():
+            impl = getattr(data, target)
+            if impl is None:
+                if cls._is_unimplemented(funcname):
+                    setattr(cls, funcname, cls._unimplemented(funcname))
+            else:
+                assert funcname not in cls.__dict__, (
+                    f"multiple definitions of {funcname} on {cls.__name__}"
+                )
+                impl.__name__ = funcname
+                setattr(cls, funcname, staticmethod(impl))
+
+
+@dataclasses.dataclass
+class OverridesData:
+    name: str
+    cpp: Callable[..., str]
+    # None when not impl in libdevice/triton
+    triton: Optional[Callable[..., str]] = None
+    # None when not impl in aten/.../vec
+    cppvec: Optional[Callable[..., str]] = None
+    type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND = (
+        ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    )
+    halide: Optional[Callable[..., str]] = None
+    mps: Optional[Callable[..., str]] = None
+
+
+# NB: if you add a new special function, don't forget to update
+# torch._inductor.ops_handler too
+pointwise_overrides_data: dict[str, OverridesData] = dict(
+    airy_ai=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp=lambda x: f"airy_ai_forward({x})",
+        name="special_airy_ai",
+    ),
+    bessel_j0=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp=lambda x: f"bessel_j0_forward({x})",
+        triton=lambda x: f"libdevice.j0({x})",
+        name="special_bessel_j0",
+    ),
+    bessel_j1=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp=lambda x: f"bessel_j1_forward({x})",
+        triton=lambda x: f"libdevice.j1({x})",
+        name="special_bessel_j1",
+    ),
+    bessel_y0=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp=lambda x: f"bessel_y0_forward({x})",
+        triton=lambda x: f"libdevice.y0({x})",
+        name="special_bessel_y0",
+    ),
+    bessel_y1=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp=lambda x: f"bessel_y1_forward({x})",
+        triton=lambda x: f"libdevice.y1({x})",
+        name="special_bessel_y1",
+    ),
+    digamma=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp=lambda x: f"calc_digamma({x})",
+        cppvec=lambda x: f"{x}.digamma()",
+        name="digamma",
+    ),
+    # no cpp nor triton implementation for entr, it is defined as decomposition
+    # erf, erfc
+    erfcx=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp=lambda x: f"calc_erfcx({x})",
+        triton=lambda x: f"libdevice.erfcx({x})",
+        name="special_erfcx",
+    ),
+    fma=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp=lambda x, y, z: f"std::fma({x}, {y}, {z})",
+        cppvec=lambda x, y, z: f"fmadd({x}, {y}, {z})",
+        triton=lambda x, y, z: f"libdevice.fma({x}, {y}, {z})",
+        name="fma",
+    ),
+    # erfinv, exp2, expit, gammaln
+    igamma=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp=lambda x, y: f"calc_igamma({x}, {y})",
+        name="igamma",
+    ),
+    igammac=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp=lambda x, y: f"calc_igammac({x}, {y})",
+        name="igammac",
+    ),
+    gammainc=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp=lambda x, y: f"calc_igamma({x}, {y})",
+        name="special_gammainc",
+    ),
+    gammaincc=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp=lambda x, y: f"calc_igammac({x}, {y})",
+        name="special_gammaincc",
+    ),
+    i0=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp=lambda x: f"calc_i0({x})",
+        triton=lambda x: f"libdevice.cyl_bessel_i0({x})",
+        cppvec=lambda x: f"{x}.i0()",
+        name="i0",
+    ),
+    i0e=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp=lambda x: f"calc_i0e({x})",
+        cppvec=lambda x: f"{x}.i0e()",
+        name="special_i0e",
+    ),
+    i1=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp=lambda x: f"calc_i1({x})",
+        triton=lambda x: f"libdevice.cyl_bessel_i1({x})",
+        name="special_i1",
+    ),
+    i1e=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp=lambda x: f"calc_i1e({x})",
+        name="special_i1e",
+    ),
+    log_ndtr=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp=lambda x: f"calc_log_ndtr({x})",
+        name="special_log_ndtr",
+    ),
+    # logit
+    modified_bessel_i0=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp=lambda x: f"modified_bessel_i0_forward({x})",
+        triton=lambda x: f"libdevice.cyl_bessel_i0({x})",
+        name="special_modified_bessel_i0",
+    ),
+    modified_bessel_i1=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp=lambda x: f"modified_bessel_i1_forward({x})",
+        triton=lambda x: f"libdevice.cyl_bessel_i1({x})",
+        name="special_modified_bessel_i1",
+    ),
+    modified_bessel_k0=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp=lambda x: f"modified_bessel_k0_forward({x})",
+        name="special_modified_bessel_k0",
+    ),
+    modified_bessel_k1=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp=lambda x: f"modified_bessel_k1_forward({x})",
+        name="special_modified_bessel_k1",
+    ),
+    # multigamma
+    ndtr=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp=lambda x: f"calc_ndtr({x})",
+        name="special_ndtr",
+    ),
+    ndtri=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp=lambda x: f"calc_ndtri({x})",
+        name="special_ndtri",
+    ),
+    polygamma=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp=lambda x,
+        y: f"{x} == 0 ? calc_digamma({y}) : ({x} == 1 ? trigamma({y}) : calc_polygamma({y}, {x}))",
+        name="polygamma",
+    ),
+    # psi - alias to digamma
+    # round
+    scaled_modified_bessel_k0=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp=lambda x: f"scaled_modified_bessel_k0_forward({x})",
+        name="special_scaled_modified_bessel_k0",
+    ),
+    scaled_modified_bessel_k1=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp=lambda x: f"scaled_modified_bessel_k1_forward({x})",
+        name="special_scaled_modified_bessel_k1",
+    ),
+    # sinc
+    spherical_bessel_j0=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp=lambda x: f"spherical_bessel_j0_forward({x})",
+        name="special_spherical_bessel_j0",
+    ),
+    zeta=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp=lambda x, y: f"zeta({x}, {y})",
+        name="special_zeta",
+    ),
+    chebyshev_polynomial_t=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp=lambda x, y: f"chebyshev_polynomial_t_forward({x}, {y})",
+        name="special_chebyshev_polynomial_t",
+    ),
+    chebyshev_polynomial_u=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp=lambda x, y: f"chebyshev_polynomial_u_forward({x}, {y})",
+        name="special_chebyshev_polynomial_u",
+    ),
+    chebyshev_polynomial_v=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp=lambda x, y: f"chebyshev_polynomial_v_forward({x}, {y})",
+        name="special_chebyshev_polynomial_v",
+    ),
+    chebyshev_polynomial_w=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp=lambda x, y: f"chebyshev_polynomial_w_forward({x}, {y})",
+        name="special_chebyshev_polynomial_w",
+    ),
+    legendre_polynomial_p=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp=lambda x, y: f"legendre_polynomial_p_forward({x}, {y})",
+        name="special_legendre_polynomial_p",
+    ),
+    shifted_chebyshev_polynomial_t=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp=lambda x, y: f"shifted_chebyshev_polynomial_t_forward({x}, {y})",
+        name="special_shifted_chebyshev_polynomial_t",
+    ),
+    shifted_chebyshev_polynomial_u=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp=lambda x, y: f"shifted_chebyshev_polynomial_u_forward({x}, {y})",
+        name="special_shifted_chebyshev_polynomial_u",
+    ),
+    shifted_chebyshev_polynomial_v=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp=lambda x, y: f"shifted_chebyshev_polynomial_v_forward({x}, {y})",
+        name="special_shifted_chebyshev_polynomial_v",
+    ),
+    shifted_chebyshev_polynomial_w=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp=lambda x, y: f"shifted_chebyshev_polynomial_w_forward({x}, {y})",
+        name="special_shifted_chebyshev_polynomial_w",
+    ),
+    hermite_polynomial_h=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp=lambda x, y: f"hermite_polynomial_h_forward({x}, {y})",
+        name="special_hermite_polynomial_h",
+    ),
+    hermite_polynomial_he=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp=lambda x, y: f"hermite_polynomial_he_forward({x}, {y})",
+        name="special_hermite_polynomial_he",
+    ),
+    laguerre_polynomial_l=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp=lambda x, y: f"laguerre_polynomial_l_forward({x}, {y})",
+        name="special_laguerre_polynomial_l",
+    ),
+)
+
+
+def is_buffer_removed(name: str) -> bool:
+    return any(
+        name in x
+        for x in (
+            V.graph.removed_buffers,
+            V.kernel.removed_buffers,
+            V.graph.inplaced_to_remove,
+            V.kernel.inplaced_to_remove,
+        )
+    )
+
+
+class DeferredLine(DeferredLineBase):
+    """A line that can be 'unwritten' by adding name to V.graph.removed_buffers"""
+
+    def __init__(self, name: str, line: str):
+        super().__init__(line)
+        self.name = name
+        assert not isinstance(line, DeferredLineBase)
+
+    def __call__(self) -> Optional[str]:
+        if not is_buffer_removed(self.name):
+            return self.line
+        return None
+
+    def _new_line(self, line: str) -> DeferredLine:
+        return DeferredLine(self.name, line)
+
+
+class BracesBuffer(IndentedBuffer):
+    def indent(self, offset: int = 1) -> contextlib.AbstractContextManager[None]:
+        @contextlib.contextmanager
+        def ctx() -> Iterator[None]:
+            for _ in range(offset):
+                self.writeline("{")
+                self._indent += 1
+            for _ in range(-offset):
+                self._indent -= 1
+                self.writeline("}")
+            yield
+            for _ in range(-offset):
+                self.writeline("{")
+                self._indent += 1
+            for _ in range(offset):
+                self._indent -= 1
+                self.writeline("}")
+
+        return ctx()
+
+
+class InplacedBuffer(NamedTuple):
+    inner_name: str
+    other_names: list[str]
+
+
+@dataclasses.dataclass
+class ArgName:
+    name: str
+    # is_constexpr=True is used to attach a " : tl.constexpr" into the argument list
+    is_constexpr: bool = False
+
+    def full_name(self) -> str:
+        return f"{self.name}{' : tl.constexpr' if self.is_constexpr else ''}"
+
+
+class RemovedArg:
+    def __str__(self) -> str:
+        return "REMOVED"
+
+
+REMOVED = RemovedArg()
+
+
+class KernelArgs:
+    @staticmethod
+    def _lookup(
+        prefix: str,
+        odict: Union[dict[_T, Union[str, RemovedArg]], dict[_T, str]],
+        name: _T,
+    ) -> str:
+        result: Union[str, RemovedArg] = odict.get(name, REMOVED)
+        if isinstance(result, RemovedArg):
+            odict[name] = new_result = f"{prefix}{len(odict)}"
+            return new_result
+        return result
+
+    def __init__(self) -> None:
+        self.input_buffers: dict[str, str] = {}
+        self.output_buffers: dict[str, Union[str, RemovedArg]] = {}
+        self.inplace_buffers: dict[str, Union[InplacedBuffer, RemovedArg]] = {}
+        self.sizevars: dict[sympy.Expr, str] = {}
+        self.workspace_args: list[WorkspaceArg] = []
+
+    def __repr__(self) -> str:
+        return "KernelArgs({})".format(
+            ", ".join(
+                map(
+                    repr,
+                    [
+                        self.input_buffers,
+                        self.output_buffers,
+                        self.inplace_buffers,
+                        self.sizevars,
+                    ],
+                )
+            )
+        )
+
+    @staticmethod
+    def _buffer_is_marked_removed(name: Any) -> bool:
+        # this function is needed by MTIA
+        return isinstance(name, RemovedArg)
+
+    def input(self, name: str) -> str:
+        if V.graph.scheduler:
+            name = V.graph.scheduler.mutation_real_name.get(name, name)
+        assert name not in V.graph.removed_buffers, name
+        if name in self.output_buffers:
+            return cast(str, self.output_buffers[name])
+        if name in self.inplace_buffers:
+            return cast(InplacedBuffer, self.inplace_buffers[name]).inner_name
+        if name.startswith("seed"):
+            return self._lookup("seed", self.input_buffers, name)
+        return self._lookup("in_ptr", self.input_buffers, name)
+
+    def output(self, name: str) -> str:
+        if V.graph.scheduler:
+            name = V.graph.scheduler.mutation_real_name.get(name, name)
+        assert name not in V.graph.removed_buffers, name
+        if name in self.inplace_buffers:
+            return cast(InplacedBuffer, self.inplace_buffers[name]).inner_name
+        return self._lookup("out_ptr", self.output_buffers, name)
+
+    def make_inplace(self, input_name: str, output_name: str) -> None:
+        if input_name in V.graph.unaligned_buffers:
+            V.graph.unaligned_buffers.add(output_name)
+        assert output_name not in self.inplace_buffers, output_name
+        if input_name in self.inplace_buffers:
+            buf = self.inplace_buffers[input_name]
+            assert not isinstance(buf, RemovedArg)
+            buf.other_names.append(output_name)
+            self.inplace_buffers[output_name] = buf
+        else:
+            alive_buffers = [
+                val
+                for val in self.inplace_buffers.values()
+                if not isinstance(val, RemovedArg)
+            ]
+            removed_buffers = [
+                val
+                for val in self.inplace_buffers.values()
+                if isinstance(val, RemovedArg)
+            ]
+            inplace_buffer_idx = len(unique(alive_buffers)) + len(removed_buffers)
+            buf = InplacedBuffer(
+                f"in_out_ptr{inplace_buffer_idx}",
+                [input_name, output_name],
+            )
+            self.inplace_buffers[input_name] = buf
+            self.inplace_buffers[output_name] = buf
+
+    def workspace(
+        self, nelem: sympy.Expr, zero_fill: bool, dtype: torch.dtype = torch.uint8
+    ) -> tuple[str, str, int]:
+        """
+        Allocate or extend a workspace buffer of nelem elements.
+
+        This function manages the allocation of a workspace buffer. It either creates
+        a new WorkspaceArg or extends an existing one.
+
+        Note:
+        - Calling this function will in-place mutate the args by adding or updating
+        a WorkspaceArg.
+        - The codegen for generating the Python argdefs and call_defs will check
+        this field and allocate the buffer accordingly.
+        - A new argument "ws_ptr" will be present in the generated code.
+
+        Args:
+            nelem (sympy.Expr): The number of elements to allocate.
+            zero_fill (bool): Whether to initialize the buffer to zero.
+            dtype (torch.dtype): the dtype of the workspace tensor
+
+        Returns:
+            Tuple[str, str, int]: A tuple containing:
+                - "ws_ptr": A string identifier for the workspace pointer.
+                - "workspace_{i}": agraph level unique identifier for
+                    the workspace tensor.
+                - offset: An integer representing the item offset in the workspace.
+        """
+        arg = WorkspaceArg(
+            count=nelem,
+            zero_mode=WorkspaceZeroMode.from_bool(zero_fill),
+            device=V.graph.get_current_device_or_throw(),
+            outer_name=WorkspaceArg.unique_name(),
+            dtype=dtype,
+        )
+        for i, existing_arg in enumerate(self.workspace_args):
+            if WorkspaceArg.can_join(existing_arg, arg):
+                offset = existing_arg.count
+                self.workspace_args[i] = WorkspaceArg.join(existing_arg, arg)
+                return existing_arg.inner_name, existing_arg.outer_name, offset
+            assert (
+                existing_arg.inner_name != arg.inner_name
+                and existing_arg.outer_name != arg.outer_name
+            ), existing_arg
+        self.workspace_args.append(arg)
+        return arg.inner_name, arg.outer_name, 0
+
+    def semaphores(self, min_size: sympy.Expr) -> str:
+        """
+        Lazily allocate a graph-wide semaphores buffer with at least min_size.  This is a single buffer shared by
+        all kernels and zero initialized once at graph start.  Each kernel must leave the buffer zeroed on exit.
+
+        Warning: multiple calls to this function will return the same buffer.
+
+        Args:
+            min_size: the number of int32 semaphores required
+
+        Returns:
+            name of the semaphores buffer
+        """
+        current_device = V.graph.get_current_device_or_throw()
+        arg = WorkspaceArg(
+            count=min_size,
+            zero_mode=WorkspaceZeroMode.ZERO_PER_GRAPH,
+            dtype=torch.uint32,
+            inner_name="sem_ptr",
+            outer_name=f"semaphores_{current_device.type}_{current_device.index}",
+            device=current_device,
+        )
+        for existing_arg in self.workspace_args:
+            if existing_arg.inner_name == arg.inner_name:
+                assert arg == existing_arg, (arg, existing_arg)
+        self.workspace_args.append(arg)
+        return arg.inner_name
+
+    def seed_offset(self, name: str, value: int) -> str:
+        assert isinstance(value, int), (type(value), value)
+        # here we are lifting a constant integer into an arg to the kernel to try to get additional cache hits
+        value = sympy.Integer(value)
+        if value in self.sizevars:
+            return self.sizevars[value]
+        if name in self.sizevars.values():
+            name = (
+                f"{name}{sum(1 for v in self.sizevars.values() if v.startswith(name))}"
+            )
+        self.sizevars[value] = name
+        return name
+
+    def size(self, name: sympy.Symbol) -> str:
+        assert isinstance(name, sympy.Symbol), (type(name), name)
+        if name.name == "seed":
+            self.sizevars[name] = "seed"  # don't manage the name of seeds
+            return "seed"
+        return self._lookup("ks", self.sizevars, name)
+
+    def call_names(self) -> Iterator[str]:
+        return chain(
+            self.input_buffers.keys(), self.output_buffers.keys(), self.sizevars.keys()
+        )
+
+    def arg_name(self, name: str) -> Optional[str]:
+        """
+        Returns inner name of a given outer name.
+        """
+        inplaced = self.inplace_buffers.get(name, None)
+        if inplaced is not None and not isinstance(inplaced, RemovedArg):
+            return inplaced.inner_name
+        output_name = self.output_buffers.get(name, None)
+        if output_name is not None and not isinstance(output_name, RemovedArg):
+            return output_name
+        return self.input_buffers.get(name, None)
+
+    def wrap_ptr_arg(self, buf: str, dtype: torch.dtype) -> str:
+        return buf
+
+    def wrap_size_arg(self, size: SymbolLike) -> str:
+        return str(size)
+
+    def cpp_argdefs(
+        self, dtype_to_cpp_type: Optional[dict[torch.dtype, str]] = None
+    ) -> tuple[list[str], list[str], list[str]]:
+        from .cpp_utils import INDEX_TYPE
+
+        if dtype_to_cpp_type is None:
+            from .cpp_utils import DTYPE_TO_CPP
+
+            dtype_to_cpp_type = DTYPE_TO_CPP
+
+        call_args = []
+        arg_defs = []
+        arg_types = []
+        for inplaced in unique(self.inplace_buffers.values()):
+            if isinstance(inplaced, RemovedArg):
+                continue
+            outer = inplaced.other_names[-1]
+            inner = inplaced.inner_name
+            dtype = V.graph.get_dtype(outer)
+            cpp_dtype = dtype_to_cpp_type[dtype]
+            arg_defs.append(f"{cpp_dtype}* {inner}")
+            call_args.append(self.wrap_ptr_arg(outer, dtype))
+            arg_types.append(f"{cpp_dtype}*")
+        for outer, inner in self.input_buffers.items():
+            if outer in self.inplace_buffers:
+                continue
+            dtype = V.graph.get_dtype(outer)
+            cpp_dtype = dtype_to_cpp_type[dtype]
+            arg_defs.append(f"const {cpp_dtype}* {inner}")
+            call_args.append(self.wrap_ptr_arg(outer, dtype))
+            arg_types.append(f"const {cpp_dtype}*")
+        for outer, maybe_inner in self.output_buffers.items():
+            if outer in self.inplace_buffers or isinstance(maybe_inner, RemovedArg):
+                continue
+            dtype = V.graph.get_dtype(outer)
+            cpp_dtype = dtype_to_cpp_type[dtype]
+            arg_defs.append(f"{cpp_dtype}* {maybe_inner}")
+            call_args.append(self.wrap_ptr_arg(outer, dtype))
+            arg_types.append(f"{cpp_dtype}*")
+        for outer, inner in self.sizevars.items():
+            if isinstance(outer, sympy.Symbol) and symbol_is_type(
+                outer, (SymT.UNBACKED_FLOAT)
+            ):
+                arg_defs.append(f"const float {inner}")
+                arg_types.append("const float")
+            else:
+                arg_defs.append(f"const {INDEX_TYPE} {inner}")
+                arg_types.append(f"const {INDEX_TYPE}")
+            call_args.append(self.wrap_size_arg(outer))
+            if V.graph.wrapper_code:
+                V.graph.wrapper_code.ensure_size_computed(outer)
+        assert not self.workspace_args, "Workspace not supported on CPU "
+        return arg_defs, call_args, arg_types
+
+    def python_argdefs(
+        self,
+    ) -> tuple[list[ArgName], list[str], list[KernelArgType], list[Any]]:
+        arg_defs: list[ArgName] = []
+        call_args: list[str] = []
+        arg_types: list[Any] = []
+        precompile_args: list[KernelArgType] = []
+        for inplaced in unique(self.inplace_buffers.values()):
+            if isinstance(inplaced, RemovedArg):
+                continue
+            arg_defs.append(ArgName(inplaced.inner_name))
+            call_args.append(inplaced.other_names[-1])
+            arg_types.append(V.graph.get_dtype(inplaced.other_names[-1]))
+            precompile_args.append(
+                TensorArg(
+                    name=inplaced.inner_name,
+                    buffer=inplaced.other_names[-1],
+                    dtype=V.graph.get_dtype(inplaced.other_names[-1]),
+                )
+            )
+        for outer, inner in chain(
+            self.input_buffers.items(),
+            # pyrefly: ignore [bad-argument-type]
+            self.output_buffers.items(),
+        ):
+            if outer in self.inplace_buffers or isinstance(inner, RemovedArg):
+                continue
+            arg_defs.append(ArgName(inner))
+            call_args.append(outer)
+            arg_types.append(V.graph.get_dtype(outer))
+            precompile_args.append(
+                TensorArg(
+                    name=inner,
+                    buffer=outer,
+                    dtype=V.graph.get_dtype(outer),
+                )
+            )
+        for outer, inner in self.sizevars.items():
+            arg_defs.append(ArgName(inner))
+            call_args.append(outer)
+            arg_types.append(type(outer))
+            precompile_args.append(SizeArg(inner, outer))
+            if V.graph.wrapper_code:
+                V.graph.wrapper_code.ensure_size_computed(outer)
+        for arg in self.workspace_args:
+            arg_defs.append(ArgName(arg.inner_name))
+            call_args.append(arg.outer_name)
+            precompile_args.append(arg)
+            arg_types.append(arg.dtype)
+        return arg_defs, call_args, precompile_args, arg_types
+
+    def aliases(self) -> Iterator[tuple[str, str]]:
+        for inplaced in unique(self.inplace_buffers.values()):
+            if isinstance(inplaced, RemovedArg):
+                continue
+            for other in inplaced.other_names:
+                if (
+                    other in V.graph.inplaced_to_remove
+                    or other in V.kernel.inplaced_to_remove
+                ):
+                    continue
+                if other in self.input_buffers:
+                    yield self.input_buffers[other], inplaced.inner_name
+                if other in self.output_buffers:
+                    yield cast(str, self.output_buffers[other]), inplaced.inner_name
+
+    def is_removed(self, name: str) -> bool:
+        return isinstance(
+            self.output_buffers.get(name, REMOVED), RemovedArg
+        ) and isinstance(self.inplace_buffers.get(name, REMOVED), RemovedArg)
+
+    # Includes inplace buffers, excludes removed buffers.  Essentially,
+    # after you do a call into this kernel, which buffers actually contain
+    # updated data?  Modeled off of python_argdefs.
+    def live_output_buffers(self) -> OrderedSet[str]:
+        live_outs: OrderedSet[str] = OrderedSet()
+        for inplaced in unique(self.inplace_buffers.values()):
+            if isinstance(inplaced, RemovedArg):
+                continue
+            live_outs.add(inplaced.other_names[-1])
+        for outer, inner in self.output_buffers.items():
+            if outer in self.inplace_buffers or isinstance(inner, RemovedArg):
+                continue
+            live_outs.add(outer)
+        return live_outs
+
+
+class CSEVariable:
+    """A CSEVariable is just a name for an expression but it is useful to be able to annotate them on a backend dependent basis.
+    To do so, the backends can simply overload `Kernel.create_cse_var`
+    The "CSEVariable.update_on_args" method gives you a hook for annotations
+    See example of TritonCSEVariable in triton.py
+    """
+
+    def __init__(
+        self,
+        name: str,
+        bounds: ValueRanges[Any],
+        dtype: Optional[torch.dtype] = None,
+        shape: BlockShapeType = None,
+    ):
+        super().__init__()
+        assert isinstance(bounds, ValueRanges), type(bounds)
+        self.name = name
+        self.bounds = bounds
+        self.use_count = 1  # track how many times this expression is used
+        self.dtype = dtype
+        self.shape = shape
+
+    def __str__(self) -> str:
+        return self.name
+
+    def __hash__(self) -> int:
+        return hash(self.name)
+
+    def __eq__(self, other: object) -> bool:
+        return isinstance(other, CSEVariable) and other.name == self.name
+
+    def update_on_args(self, name: str, args: Any, kwargs: Any) -> None:
+        pass
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}({self.name!r})"
+
+
+AugmentedKeyT = TypeVar("AugmentedKeyT", default=str)
+CSEVariableType = TypeVar("CSEVariableType", bound=CSEVariable, default=CSEVariable)
+
+if TYPE_CHECKING:
+    ReductionCacheKey = tuple[
+        torch.dtype,
+        ReductionType,
+        Union[CSEVariable, tuple[CSEVariable, ...]],
+    ]
+
+
+class CSE(Generic[CSEVariableType, AugmentedKeyT]):
+    """Common subexpression elimination"""
+
+    def __init__(
+        self,
+        prefix: str = "",
+        suffix: str = "",
+        name_prefix: str = "tmp",
+        iter_buffers: Optional[itertools.count[int]] = None,
+        store_cache: Optional[MutableMapping[str, CSEVariableType]] = None,
+        reduction_cache: Optional[
+            MutableMapping[ReductionCacheKey, CSEVariableType]
+        ] = None,
+        varname_map: Optional[dict[str, CSEVariableType]] = None,
+    ):
+        self.prefix = prefix
+        self.suffix = suffix
+        self._cache: MutableMapping[AugmentedKeyT, CSEVariableType] = {}
+        self.name_prefix = name_prefix
+        self.store_cache: MutableMapping[str, CSEVariableType] = store_cache or {}
+        self.reduction_cache: MutableMapping[ReductionCacheKey, CSEVariableType] = (
+            reduction_cache or {}
+        )
+        self.iter_buffer_ids: itertools.count[int] = iter_buffers or itertools.count()
+        self.invalidated_stores: OrderedSet[str] = OrderedSet()
+        self.varname_map: dict[str, CSEVariableType] = varname_map or {}
+
+    def invalidate(self, keep_vars: OrderedSet[CSEVariable]) -> None:
+        for name, tmp in [*self.store_cache.items()]:
+            if tmp not in keep_vars:
+                del self.store_cache[name]
+                self.invalidated_stores.add(name)
+        if keep_vars:
+            self._cache = {k: v for k, v in self._cache.items() if v in keep_vars}
+        else:
+            self._cache = {}
+
+    def clone(self) -> Self:
+        return type(self)(
+            prefix=self.prefix,
+            suffix=self.suffix,
+            name_prefix=self.name_prefix,
+            iter_buffers=self.iter_buffer_ids,
+            store_cache=self.store_cache,
+            varname_map=self.varname_map,
+            reduction_cache=self.reduction_cache,
+        )
+
+    def scoped_copy(self) -> Self:
+        """Return a copy of using ScopedDict so changes to *_cache aren't visible in self"""
+        new_cse = self.clone()
+        new_cse._cache = ScopedDict(self._cache)
+        new_cse.reduction_cache = ScopedDict(self.reduction_cache)
+        new_cse.store_cache = ScopedDict(self.store_cache)
+        return new_cse
+
+    def augment_key(self, cache_key: str) -> AugmentedKeyT:
+        "Override this method to augment cache key with backend specifics"
+        return cast(AugmentedKeyT, cache_key)
+
+    def put(self, cache_key: str, val: CSEVariableType) -> None:
+        self._cache[self.augment_key(cache_key)] = val
+
+    def contains(self, cache_key: str) -> bool:
+        return self.augment_key(cache_key) in self._cache
+
+    def try_get(self, cache_key: str) -> Optional[CSEVariableType]:
+        return self._cache.get(self.augment_key(cache_key), None)
+
+    def get(self, cache_key: str) -> CSEVariableType:
+        return self._cache[self.augment_key(cache_key)]
+
+    def generate(
+        self,
+        buffer: IndentedBuffer,
+        expr: Union[str, CSEVariable, OpsValue, IndentedBuffer, DeferredLineBase],
+        *,
+        bounds: ValueRanges[Any] = ValueRanges.unknown(),
+        write: bool = True,
+        assignment: bool = True,
+        dtype: Optional[torch.dtype] = None,
+        shape: BlockShapeType = None,
+    ) -> CSEVariableType:
+        if isinstance(expr, OpsValue):
+            expr = expr.value
+
+        assert write or assignment
+        if isinstance(expr, CSEVariable):
+            # If the expressions were always created with all the information, we could
+            # assert expr.bounds == bounds, but sometimes the expression is created
+            # with the loose ValueRanges.unknown(), so we need to tighten the bounds
+            expr.bounds = expr.bounds.tighten(bounds)
+            expr.use_count += 1
+            return cast(CSEVariableType, expr)
+        elif isinstance(expr, IndentedBuffer):
+            cache_key = expr.getvalue()
+        elif isinstance(expr, DeferredLineBase):
+            cache_key = expr.line
+        else:
+            assert isinstance(expr, str)
+            cache_key = expr
+        var = self.try_get(cache_key)
+        if shape is None and not assignment:
+            # since there's no assignment to a variable, use any shape here
+            # other than None to avoid the unknown shape failures
+            shape = ()
+        if not var:
+            var = self.newvar(bounds, dtype, shape)
+            self.put(cache_key, var)
+            if write:
+                if V.kernel.current_node:
+                    V.kernel.current_node.codegen_originating_info(
+                        buffer, only_once=True
+                    )
+                if isinstance(expr, IndentedBuffer):
+                    if assignment:
+                        buffer.writeline(f"{self.prefix}{var} =")
+                    buffer.splice(expr)
+                    buffer.writeline(self.suffix)
+                elif isinstance(expr, DeferredLineBase):
+                    assert assignment
+                    buffer.writeline(
+                        expr._new_line(f"{self.prefix}{var} = {expr.line}{self.suffix}")
+                    )
+                else:
+                    if assignment:
+                        line = f"{self.prefix}{var} = {expr}{self.suffix}"
+                    else:
+                        line = f"{expr}{self.suffix}"
+                    buffer.writeline(line)
+
+                    # cpp backend cannot determine is_vec at this point
+                    if (
+                        assignment
+                        and (
+                            config.test_configs.runtime_triton_dtype_assert
+                            or config.test_configs.static_cpp_dtype_assert
+                        )
+                        and dtype is not None
+                        and get_current_backend() != "cpp"
+                    ):
+                        check_dtype(buffer, var, dtype)
+
+        else:
+            var.bounds = var.bounds.tighten(bounds)
+            var.use_count += 1
+
+        return var
+
+    def newvar(
+        self,
+        bounds: ValueRanges[Any] = ValueRanges.unknown(),
+        dtype: Optional[torch.dtype] = None,
+        shape: BlockShapeType = None,
+    ) -> CSEVariableType:
+        var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}"
+        var = V.kernel.create_cse_var(var_name, bounds, dtype, shape)
+        self.varname_map[var_name] = var
+        return var
+
+    def namedvar(
+        self,
+        name: str,
+        bounds: ValueRanges[Any] = ValueRanges.unknown(),
+        dtype: Optional[torch.dtype] = None,
+        shape: BlockShapeType = None,
+    ) -> CSEVariableType:
+        torch._check_value(
+            name not in self.varname_map, lambda: f"duplicate name: {name}"
+        )
+        var = V.kernel.create_cse_var(name, bounds, dtype, shape)
+        self.varname_map[name] = var
+        return var
+
+
+class CodeGen:
+    def __init__(self) -> None:
+        super().__init__()
+        self.exit_stack = contextlib.ExitStack()
+
+    def __enter__(self) -> Self:
+        self.exit_stack.__enter__()
+        return self
+
+    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+        self.exit_stack.__exit__(exc_type, exc_val, exc_tb)
+
+
+class Kernel(CodeGen, Generic[CSEVariableType]):
+    newvar_prefix: str = ""
+    suffix: str = ""
+    overrides: Optional[Callable[[], OpsHandler[Any]]] = None
+
+    def __init__(
+        self, args: Optional[KernelArgs] = None, increase_kernel_count: bool = True
+    ) -> None:
+        super().__init__()
+        if increase_kernel_count:
+            # pyrefly: ignore [bad-assignment]
+            metrics.generated_kernel_count += 1
+        self.args = args or KernelArgs()
+        self.loads = IndentedBuffer()
+        self.compute = IndentedBuffer()
+        self.stores = IndentedBuffer()
+
+        self.atomic_add_found = False
+        self.num_load = 0
+        self.num_store = 0
+        self.num_reduction = 0
+
+        self.cse: CSE[CSEVariableType, Any] = CSE(self.newvar_prefix, self.suffix)
+        self.must_keep_buffers: OrderedSet[str] = OrderedSet()
+        self.store_buffer_names: OrderedSet[str] = OrderedSet()
+        self._load_mask: Optional[str] = None
+        self._load_other: Union[None, int, float] = None
+        # OrderedSet in set_current_node
+        self.current_node: Optional[SchedulerNode] = None
+        self.node_to_bounds: Optional[dict[torch.fx.Node, ValueRanges[Any]]] = None
+
+        self.removed_buffers: OrderedSet[str] = OrderedSet()
+        self.inplaced_to_remove: OrderedSet[str] = OrderedSet()
+
+        # key: the buffer to write
+        # value: the buffer to read and whose memory can be reused for
+        #   the buffer specified by key
+        self.inplace_update_buffers: dict[str, str] = {}
+        # Set minimum number of elements processed per thread.
+        self.min_elem_per_thread = 1
+        self.kernel_name: Optional[str] = None
+
+    @contextlib.contextmanager
+    def set_current_node(self, node: SchedulerNode) -> Iterator[None]:
+        prior = self.current_node
+        self.current_node = node
+        self.node_to_bounds = node._body.bounds().get_bounds()
+        try:
+            yield
+        finally:
+            self.current_node = prior
+
+    @contextlib.contextmanager
+    def swap_buffers(
+        self,
+        lb: IndentedBuffer,
+        cb: Optional[IndentedBuffer] = None,
+        sb: Optional[IndentedBuffer] = None,
+    ) -> Iterator[None]:
+        if cb is None:
+            cb = lb
+        if disallow_stores := sb is None:
+            sb = IndentedBuffer()
+        loads = self.loads
+        compute = self.compute
+        stores = self.stores
+        cse = self.cse
+        self.loads = lb
+        self.compute = cb
+        self.stores = sb
+        self.cse = cse.scoped_copy()
+        try:
+            yield
+        finally:
+            self.loads = loads
+            self.compute = compute
+            self.stores = stores
+            self.cse = cse
+            # pyrefly: ignore [unbound-name]
+            if disallow_stores:
+                assert not sb, "unexpected store inside swap_buffers"
+
+    def load(self, name: str, index: sympy.Expr) -> CSEVariable:
+        raise NotImplementedError
+
+    def indirect_load(self, name: str, index: sympy.Expr) -> CSEVariable:
+        """A load the depends on an index we have read"""
+        prior = self.loads
+        try:
+            # put the load in the compute section as it might have deps
+            self.loads = self.compute
+            return self.load(name, index)
+        finally:
+            self.loads = prior
+
+    def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable) -> None:
+        raise NotImplementedError
+
+    def store(
+        self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
+    ) -> None:
+        raise NotImplementedError
+
+    def device_assert_async(self, cond: CSEVariable, msg: str) -> None:
+        raise NotImplementedError(
+            f"{type(self).__name__}: device_assert_async should be handled by CSEProxy"
+        )
+
+    def reduction(
+        self,
+        dtype: torch.dtype,
+        src_dtype: torch.dtype,
+        reduction_type: ReductionType,
+        value: Union[CSEVariable, tuple[CSEVariable, ...]],
+    ) -> Union[CSEVariable, tuple[CSEVariable, ...]]:
+        raise NotImplementedError
+
+    def partial_accumulate(
+        self,
+        name: str,
+        reduction_type: ReductionType,
+        value: CSEVariable,
+        extra_meta: dict[str, Any],
+    ) -> None:
+        raise NotImplementedError
+
+    def scan(
+        self,
+        dtypes: tuple[torch.dtype, ...],
+        combine_fn: Callable[
+            [tuple[CSEVariable, ...], tuple[CSEVariable, ...]], tuple[CSEVariable, ...]
+        ],
+        values: tuple[CSEVariable, ...],
+    ) -> tuple[CSEVariable, ...]:
+        raise NotImplementedError
+
+    def sort(
+        self,
+        dtypes: tuple[torch.dtype, ...],
+        values: tuple[CSEVariable, ...],
+        stable: bool,
+        descending: bool,
+    ) -> tuple[CSEVariable, ...]:
+        raise NotImplementedError
+
+    def var_ranges(self) -> dict[sympy.Symbol, sympy.Expr]:
+        raise NotImplementedError
+
+    def bucketize(
+        self,
+        values: CSEVariable,
+        boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr],
+        boundary_indices: CSEVariable,
+        indexing_dtype: torch.dtype,
+        right: bool,
+        sorter: Optional[tuple[str, sympy.Expr]] = None,
+        sorter_indices: Optional[CSEVariable] = None,
+    ) -> CSEVariable:
+        """
+        See [Note: Inductor bucketize op]
+        """
+        raise NotImplementedError
+
+    @property
+    def assert_function(self) -> str:
+        raise NotImplementedError
+
+    def indirect_assert(
+        self,
+        var: Union[CSEVariable, str],
+        lower: Optional[str],
+        upper: Optional[str],
+        mask: Optional[Union[CSEVariable, str]] = None,
+    ) -> str:
+        if isinstance(var, CSEVariable):
+            var = str(var)
+        assert isinstance(var, str), type(var)
+        assert lower is None or isinstance(lower, str)
+        assert upper is None or isinstance(upper, str)
+        if lower and upper:
+            # The conditions need to be in parens because of Python's operator precedence.
+            # It'd be less error-prone to use and/or/not, which is supported by triton
+            cond = f"({lower} <= {var}) & ({var} < {upper})"
+            cond_print = f"{lower} <= {var} < {upper}"
+        elif lower:
+            cond = f"{lower} <= {var}"
+            cond_print = cond
+        else:
+            assert upper
+            cond = f"{var} < {upper}"
+            cond_print = cond
+
+        if mask:
+            cond = f"({cond}) | ~({mask})"
+
+        return f'{self.assert_function}({cond}, "index out of bounds: {cond_print}")'
+
+    def check_bounds(
+        self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
+    ) -> None:
+        raise NotImplementedError
+
+    def index_to_str(self, index: sympy.Expr) -> str:
+        raise NotImplementedError
+
+    def __enter__(self) -> Self:
+        super().__enter__()
+        assert self.overrides
+        self.exit_stack.enter_context(
+            V.set_ops_handler(CSEProxy(self, self.overrides()))
+        )
+        self.exit_stack.enter_context(V.set_kernel_handler(self))
+        return self
+
+    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+        self.remove_kernel_local_buffers()
+        super().__exit__(exc_type, exc_val, exc_tb)
+
+    def remove_kernel_local_buffers(self) -> None:
+        """
+        Any buffers that are both created and have a last use in the
+        same kernel can be removed.
+
+        Note that V.graph.scheduler can be None when codegening triton template
+        kernels.
+        """
+        scheduler = V.graph.scheduler
+        if not scheduler:
+            return
+        fused_node_names = OrderedSet(
+            scheduler.name_to_buf[buf].defining_op_name()
+            for buf in self.store_buffer_names
+            if buf in scheduler.name_to_buf
+        )
+        names_to_remove: OrderedSet[str] = OrderedSet()
+        for name in self.store_buffer_names:
+            if (
+                name not in self.must_keep_buffers
+                and name not in self.args.input_buffers
+                and scheduler.can_buffer_be_removed_through_fusion(
+                    name, fused_node_names
+                )
+            ):
+                self.num_store -= 1
+                names_to_remove.add(name)
+
+        for name in names_to_remove:
+            if name in self.args.inplace_buffers:
+                buf = self.args.inplace_buffers[name]
+                if isinstance(buf, RemovedArg):
+                    continue
+                remove = all(n in names_to_remove for n in buf.other_names)
+                if remove:
+                    self.remove_inplace_buffer(name)
+                self.inplaced_to_remove.add(name)
+            else:
+                self.remove_buffer(name)
+
+    def remove_buffer(self, name: str) -> None:
+        # Assign a special value instead of deleting the entry
+        # because we still rely on output_buffers's length to
+        # generate unique arg name.
+        log.debug("remove_buffer(%r)", name)
+        self.args.output_buffers[name] = REMOVED
+        self.removed_buffers.add(name)
+
+    def remove_inplace_buffer(self, name: str) -> None:
+        log.debug("removing_inplace_buffer(%r)", name)
+        self.args.inplace_buffers[name] = REMOVED
+        self.removed_buffers.add(name)
+
+    def rename_indexing(
+        self, index: Union[list[sympy.Expr], tuple[sympy.Expr, ...], sympy.Expr]
+    ) -> sympy.Expr:
+        # adds the necessary kernel args for index expressions
+        # and renames variables in index expressions to kernel arg names
+        if isinstance(index, (list, tuple)):
+            return [self.rename_indexing(x) for x in index]
+        index = V.graph.sizevars.simplify(index)
+        sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name)
+        replacements = {
+            x: self.args.size(x)
+            for x in sorted_symbols
+            if symbol_is_type(
+                x,
+                (
+                    SymT.UNBACKED_INT,
+                    SymT.SIZE,
+                    SymT.PRECOMPUTED_SIZE,
+                    SymT.UNBACKED_FLOAT,
+                ),
+            )
+        }
+        return sympy_subs(index, replacements)
+
+    def create_cse_var(self, *args: Any, **kwargs: Any) -> CSEVariable:
+        return CSEVariable(*args, **kwargs)
+
+    def arg_name(self, node: IRNode) -> Optional[str]:
+        """
+        Returns arg name of a given input or output node.
+        """
+        if node is None:
+            return None
+        return self.args.arg_name(node.get_name())
+
+
+@dataclasses.dataclass
+class OptimizationContext:
+    key: ClassVar[str] = "opt_ctx"
+
+    dtype: Optional[torch.dtype] = None
+    ops_name: str = ""
+
+
+@functools.cache
+def jinja2_env() -> Any:
+    try:
+        import jinja2
+
+        return jinja2.Environment(
+            undefined=jinja2.StrictUndefined,
+        )
+    except ImportError:
+        return None
+
+
+class KernelTemplate:
+    """
+    Base class for defining kernel templates.
+
+    Children classes: TritonTemplate, CUDATemplate
+    """
+
+    @staticmethod
+    def indent_except_first(
+        source: str, num_indents: int, indents_spacing: int = 4
+    ) -> str:
+        lines = source.splitlines(True)
+        if len(lines) > 1:
+            lines[1:] = [
+                (" " * indents_spacing * num_indents) + line for line in lines[1:]
+            ]
+        return "".join(lines)
+
+    @staticmethod
+    def _template_from_string(source: str) -> Any:
+        env = jinja2_env()
+        if env is None:
+            return None
+        env.filters["indent_except_first"] = KernelTemplate.indent_except_first
+        from jinja2 import TemplateSyntaxError
+
+        try:
+            return env.from_string(source)
+        except TemplateSyntaxError as e:
+
+            class DetailedTemplateSyntaxError(TemplateSyntaxError):
+                def __init__(self, original_error: TemplateSyntaxError) -> None:
+                    super().__init__(
+                        # pyrefly: ignore [bad-argument-type]
+                        original_error.message,
+                        original_error.lineno,
+                        original_error.name,
+                        original_error.filename,
+                    )
+                    self.original_error = original_error
+
+                def __str__(self) -> str:
+                    error_info = f"Error in template at line {self.lineno}\n"
+                    error_info += f"Error message: {self.message}\n"
+                    if hasattr(self.original_error, "source"):
+                        # pyrefly: ignore [missing-attribute]
+                        lines = self.original_error.source.split("\n")
+                        error_info += "Context:\n"
+                        start = max(0, self.lineno - 2)
+                        end = min(len(lines), self.lineno + 2)
+                        for i in range(start, end):
+                            if i == self.lineno - 1:
+                                error_info += f"{i + 1}: --> {lines[i]}\n"
+                                if hasattr(self.original_error, "column"):
+                                    error_info += (
+                                        "     "
+                                        + " " * (self.original_error.column - 1)
+                                        + "^\n"
+                                    )
+                            else:
+                                error_info += f"{i + 1}:     {lines[i]}\n"
+                    return error_info
+
+            raise DetailedTemplateSyntaxError(e) from e
+
+    @staticmethod
+    def _fake_get_dtype(
+        fake_outs: Union[list[Buffer], Buffer],
+    ) -> Callable[[str], torch.dtype]:
+        _get_dtype_real = V.graph.get_dtype
+        if isinstance(fake_outs, (list, tuple)):
+            lookup = {buf.get_name(): buf.get_dtype() for buf in fake_outs}
+        else:
+            lookup = {fake_outs.get_name(): fake_outs.get_dtype()}
+
+        def get_dtype(name: str) -> torch.dtype:
+            result = lookup.get(name)
+            if result is not None:
+                return result
+            return _get_dtype_real(name)
+
+        return get_dtype
+
+    def __init__(self, name: str, hash: Optional[str] = None) -> None:
+        self.name = name
+        self._hash = hash
+
+    @property
+    def uid(self) -> str:
+        """
+        entry point to override for templates to ensure a uid e.g. through a prefix
+
+        the purpose of this is that every KernelTemplate/ExternKernelChoice is unique
+        in the system, but reproducible e.g. restarting pytorch should yield the same id
+        """
+        # TODO(coconutruben): add some central registration to assert on global uniqueness
+        return self.name
+
+    @property
+    def src_hash(self) -> Union[str, None]:
+        """
+        source hash for a Template.
+
+        Templates can optionally provide a src hash to make it easier to cache/validate that
+        a template has not changed from one version to another. Override this if that detection
+        is different for your specific Template
+        """
+        return self._hash
+
+    def choice_or_none(self, **kwargs: Any) -> Optional[ChoiceCaller]:
+        """
+        Maybe generates a new ChoiceCaller and returns it, or None if generation fails.
+
+        kwargs: Additional kwargs to be passed to self.generate() to generate a new ChoiceCaller.
+        """
+        temp_choices: list[Any] = []
+        result = self.maybe_append_choice(temp_choices, **kwargs)
+        if result is None and len(temp_choices) == 1:
+            return temp_choices[0]
+        return None
+
+    def maybe_append_choice(
+        self, choices: list[Any], **kwargs: Any
+    ) -> Optional[NotImplementedError]:
+        """
+        Maybe generates a new ChoiceCaller and appends it into existing choices.
+        Returns None if success, otherwise returns the error.
+
+        choices: A list of ChoiceCallers.
+        kwargs: Additional kwargs to be passed to self.generate() to generate a new ChoiceCaller.
+        """
+
+        try:
+            choices.append(self.generate(**kwargs))
+            return None
+        except NotImplementedError as e:
+            log.info(  # noqa: G200
+                "Cannot Append Choice: %s. KernelTemplate type is %s",
+                e,
+                type(self),
+                stack_info=log.getEffectiveLevel() < logging.INFO,
+            )
+            return e
+
+    def generate(self, **kwargs: Any) -> ChoiceCaller:
+        """
+        Generates a ChoiceCaller instance from the given arguments.
+        """
+
+        raise NotImplementedError
+
+
+class CSEProxy(DefaultHandler):
+    """A ops handler that proxies calls to `kernel` and its
+    handler and returns `CSEVariable`s with correct shape and dtype.
+    """
+
+    name = "CSEProxy"
+
+    def __init__(self, kernel: Kernel[Any], parent_handler: OpsHandler[Any]):
+        super().__init__()
+        from ..bounds import ValueRangeAnalysis
+
+        self.vr_analysis = ValueRangeAnalysis()
+        self.kernel = kernel
+        self.parent_handler = parent_handler
+
+    def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
+        bounds = self._bound_variable(name, *args, **kwargs)
+
+        value = getattr(self.parent_handler, name)(*args, **kwargs)
+        dtype_handler = DtypePropagationOpsHandler()
+        shape_handler = ShapePropagationOpsHandler()
+
+        backend = get_current_backend()
+
+        shape_op = getattr(shape_handler, name)
+        output_dtype = None
+        output_shape = None
+
+        if name == "masked" and backend == "triton":
+            output_dtype = value.dtype
+            output_shape = value.shape
+        elif name == "masked" and backend == "cpp":
+            output_dtype = V.interpreter.current_node.meta.get(
+                OptimizationContext.key, None
+            ).dtype
+            # TODO: fix me
+            output_shape = None
+        elif backend in ("triton", "cpp", "mps"):
+            dtype_op = getattr(dtype_handler, name)
+            output_dtype = dtype_op(*args, **kwargs)
+            output_shape = shape_op(*args, **kwargs)
+
+        if backend in ("triton", "cpp"):
+            # maybe there are some exceptions on mps?
+            assert output_dtype is not None
+
+        output_idx = 0
+
+        def do_cse(v: Union[str, CSEVariable]) -> CSEVariable:
+            # we tree_map over the output, so we need to fetch corresponding dtype
+            nonlocal output_idx
+            var_dtype: Optional[torch.dtype] = (
+                output_dtype[output_idx]
+                if isinstance(output_dtype, (list, tuple))
+                else output_dtype
+            )
+            var_shape: BlockShapeType = (
+                output_shape[output_idx]  # type: ignore[assignment]
+                if isinstance(output_shape, (list, tuple))
+                and len(output_shape) > 0
+                and isinstance(output_shape[0], (list, tuple))
+                else output_shape
+            )
+            output_idx += 1
+
+            # some cpp op implementations don't set the dtype
+            if isinstance(v, CSEVariable):
+                if backend == "cpp" and v.dtype is None:
+                    v.dtype = var_dtype
+                if v.shape is None:
+                    v.shape = var_shape
+
+            csevar = V.kernel.cse.generate(
+                V.kernel.compute,
+                v,
+                bounds=bounds,
+                dtype=output_dtype,
+                shape=output_shape,
+            )
+
+            csevar.update_on_args(name, args, kwargs)
+
+            if (
+                config.test_configs.runtime_triton_dtype_assert
+                or config.test_configs.static_cpp_dtype_assert
+            ):
+                assert var_dtype is not None
+                check_dtype(V.kernel.compute, csevar, var_dtype)
+
+            if config.test_configs.runtime_triton_shape_assert:
+                assert output_shape is not None
+                check_shape(V.kernel.compute, csevar, output_shape)
+
+            if config.runtime_triton_nan_asserts:
+                check_nan(V.kernel.compute, csevar)
+
+            return csevar
+
+        return pytree.tree_map(do_cse, value)
+
+    def _bound_variable(self, name: str, *args: Any, **kwargs: Any) -> ValueRanges[Any]:
+        """
+        If the variable comes from an FX node, we forward the bound we have already computed
+        Else, if the variable when codegen'ing another op, we try to compute its bounds
+        """
+        from ..bounds import ValueRangeAnalysis
+        from ..select_algorithm import TritonTemplateKernel
+        from .cuda.cuda_kernel import CUDATemplateKernel
+
+        if isinstance(V.kernel, TritonTemplateKernel):
+            return ValueRanges.unknown()
+
+        if isinstance(V.kernel, CUDATemplateKernel):
+            return ValueRanges.unknown()
+
+        if isinstance(V.interpreter, NullHandler):
+            return ValueRanges.unknown()
+
+        fx_node = V.interpreter.current_node
+        if fx_node.target == name and self.kernel.node_to_bounds is not None:
+            assert isinstance(self.kernel.node_to_bounds, dict), type(
+                self.kernel.node_to_bounds
+            )
+            return self.kernel.node_to_bounds.get(fx_node, ValueRanges.unknown())
+        elif config.compute_all_bounds and hasattr(ValueRangeAnalysis, name):
+            # These create lots of inner strings. We would need to compute the bounds at the ops
+            # We will also likely not get much from computing VRs on these nodes
+            if any(s in fx_node.target for s in ("set_indirect", "reduction", "scan")):
+                return ValueRanges.unknown()
+
+            # We assume that the inputs come from `ops.` and are not strings. If you want to generate
+            # intermediary strings, wrap them in CSE variables with properly initialised bounds.
+
+            # If there is no FX bound but we know how to compute one we do so
+            assert not kwargs
+
+            def arg_to_bound(x: Any) -> Any:
+                if isinstance(x, CSEVariable):
+                    return x.bounds
+                elif isinstance(x, sympy.Expr):
+                    return bound_sympy(x)
+                else:
+                    return x
+
+            arg_bounds = list(map(arg_to_bound, args))
+            return getattr(self.vr_analysis, name)(*arg_bounds)
+        return ValueRanges.unknown()
+
+    def indirect_indexing(
+        self,
+        var: CSEVariable,
+        size: Union[sympy.Expr, int],
+        check: bool = True,
+        wrap_neg: bool = True,
+    ) -> sympy.Symbol:
+        if isinstance(size, int):
+            size = sympy.Integer(size)
+        assert isinstance(size, sympy.Expr), (type(size), size)
+        # Skip CSE since this doesn't return an expression
+
+        if var.bounds.lower < 0:
+            if wrap_neg:
+                stm = ops.add(var, ops.index_expr(size, torch.long))
+                # Mixed negative and non-negative
+                if var.bounds.upper >= 0:
+                    lt = ops.lt(var, 0)
+                    stm = ops.where(lt, stm, var)
+            else:
+                stm = var
+
+            # Propagate bounds as we know how to compute them properly
+            new_bounds = ValueRanges.unknown()
+            if var.bounds != ValueRanges.unknown() and isinstance(size, sympy.Number):
+                # Take the negative part of the bound and add size to it
+                # Then take union of that and the positive part
+                # This is a tighter bound than that of a generic ops.where, as we have info on the cond
+                neg_bounds = var.bounds & ValueRanges(-int_oo, -1)
+                new_bounds = ValueRanges(
+                    neg_bounds.lower + size, neg_bounds.upper + size
+                )
+                # We don't have a good way of representing the empty range
+                if var.bounds.upper >= 0:
+                    pos = var.bounds & ValueRanges(0, int_oo)
+                    new_bounds = new_bounds | pos
+
+            var = self.kernel.cse.generate(
+                self.kernel.compute,
+                stm,
+                bounds=new_bounds,
+                dtype=var.dtype,
+                shape=var.shape,
+            )
+
+        sympy_var = self.parent_handler.indirect_indexing(var, size, check)
+        if generate_assert(check):
+            assert_lower = not (var.bounds.lower >= 0)
+            # value ranges cannot x < s when x and s are symbols
+            assert_upper = not isinstance(size, sympy.Number) or not (
+                var.bounds.upper < size
+            )
+            self.kernel.check_bounds(sympy_var, size, assert_lower, assert_upper)
+        return sympy_var
+
+    def check_bounds(
+        self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
+    ) -> None:
+        return self.kernel.check_bounds(expr, size, lower, upper)
+
+    def load(self, name: str, index: sympy.Expr) -> CSEVariable:
+        if name in self.kernel.cse.invalidated_stores:
+            # A load from an invalidated store requires us to
+            # keep the actual buffer around
+            V.kernel.must_keep_buffers.add(name)
+        if free_symbol_is_type(index, SymT.TMP):
+            return self.kernel.indirect_load(name, index)
+        store_cache = self.kernel.cse.store_cache
+        if name in store_cache:
+            return store_cache[name]
+        out = self.kernel.load(name, index)
+        # count load that is not in the store_cache, and also not in the
+        # cse cache.
+        if out.use_count == 1:
+            self.kernel.num_load += 1
+        return out
+
+    def _update_store_cache(self, name: str, value: CSEVariable) -> None:
+        self.kernel.cse.store_cache[name] = value
+        if self.kernel.current_node and name in V.graph.name_to_buffer:
+            buf = self.kernel.current_node.get_output(name)
+            for other_name in buf.get_mutations():
+                self.kernel.cse.store_cache[other_name] = value
+
+    def store(
+        self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
+    ) -> None:
+        self.kernel.store_buffer_names.add(name)
+        if mode is None:
+            self._update_store_cache(name, value)
+        if name not in V.graph.removed_buffers:
+            self.kernel.store(name, index, value, mode=mode)
+            self.kernel.num_store += 1
+
+    def device_assert_async(self, cond: CSEVariable, msg: str) -> None:
+        self.kernel.device_assert_async(cond, msg)
+
+    # pyrefly: ignore [bad-override]
+    def partial_accumulate(self, *args: Any) -> None:
+        self.kernel.partial_accumulate(*args)
+
+    def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable) -> None:
+        self.kernel.store_buffer_names.add(name)
+        self._update_store_cache(name, value)
+
+        if name not in V.graph.removed_buffers:
+            self.kernel.num_store += 1
+            return self.kernel.store_reduction(name, index, value)
+
+    def reduction(
+        self,
+        dtype: torch.dtype,
+        src_dtype: torch.dtype,
+        reduction_type: ReductionType,
+        value: Union[CSEVariable, tuple[CSEVariable, ...]],
+    ) -> Union[CSEVariable, tuple[CSEVariable, ...]]:
+        self.kernel.num_reduction += 1
+        return self.kernel.reduction(dtype, src_dtype, reduction_type, value)
+
+    def scan(
+        self,
+        dtypes: tuple[torch.dtype, ...],
+        combine_fn: Callable[
+            [tuple[CSEVariable, ...], tuple[CSEVariable, ...]],
+            tuple[CSEVariable, ...],
+        ],
+        values: tuple[CSEVariable, ...],
+    ) -> tuple[CSEVariable, ...]:
+        return self.kernel.scan(dtypes, combine_fn, values)
+
+    def sort(
+        self,
+        dtypes: tuple[torch.dtype, ...],
+        values: tuple[CSEVariable, ...],
+        stable: bool,
+        descending: bool,
+    ) -> tuple[CSEVariable, ...]:
+        return self.kernel.sort(dtypes, values, stable, descending)
+
+    def bucketize(
+        self,
+        values: CSEVariable,
+        boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr],
+        boundary_indices: CSEVariable,
+        indexing_dtype: torch.dtype,
+        right: bool,
+        sorter: Optional[tuple[str, sympy.Expr]] = None,
+        sorter_indices: Optional[CSEVariable] = None,
+    ) -> CSEVariable:
+        """
+        [Note: Inductor bucketize op]
+
+        Inputs:
+        -------
+        values: the values to be bucketized.
+        boundaries: a tuple containing
+          (a) the name of the boundaries tensor (which must be sorted, unless
+          the sorting tensor is present),
+          (b) the length of the tensor in the last dimension (i.e. the length of
+          one set of boundaries),
+          (c) the number of elements in the underlying storage (i.e. the length
+          of the flattened tensor, ignoring striding), and
+          (d) the stride of the tensor in the last dimension.
+        boundary_indices: indices into a flattened version of the boundaries
+        tensor, of the same size and shape as "values".  Each index points to
+        the first element in the set of boundaries to be used for the
+        corresponding value.
+        indexing_dtype: the dtype to use when indexing into the boundaries
+        tensor.  This must be int64 or int32.  This additionally specifies the
+        dtype of the return value.
+        right: see "Details" below.
+        sorter: an optional tuple containing
+          (a) the name of an optional sorting tensor, used to access unsorted
+          boundaries without reordering the boundaries tensor, and
+          (b) the stride of the tensor in the last dimension.
+        The values in the sorting tensor are used as indices into the *last*
+        dimension of the boundaries tensor, with all other indices matching.
+        The size of the sorting and boundaries tensors must be equivalent.
+        sorter_indices: must be present if the sorting array is present; see
+        "boundary_indices" for the equivalent definition for the boundaries
+        tensor.
+
+        Output:
+        -------
+        The buckets each value belongs in, within a given set of boundaries.  0
+        indicates a position before the first boundary, and len(boundaries_set)
+        represents a position after the last boundary.
+
+        Details:
+        --------
+        Given a value and a set of boundaries, calculate the bucket that each
+        value belongs to.  This works differently in 1-D and N-D cases.
+
+        for values [[-1, 0, 1, 2], [3, 4, 5, 9]], boundaries [0, 4, 4, 8], right=True
+        return =   [[ 0, 1, 1, 1], [1, 3, 3, 4]].
+
+        for values [[-1, 0, 1, 2], [3, 4, 5, 9]], boundaries [[0, 4], [4, 8]], right=True
+        return =   [[ 0, 1, 1, 1], [0, 1, 1, 2]]
+
+        Note that in the N-D boundaries case, the shape of "values" and
+        "boundaries" must match in every dimension _except_ the last.
+
+        When right == False, bucket i refers to range (boundaries[i], boundaries[i+1]].
+        When right == True,  bucket i refers to range [boundaries[i], boundaries[i+1]).
+
+        Boundaries must be non-decreasing, or a sorter must be provided which
+        would re-index offsets in a non-decreasing order (e.g. the second output
+        of torch.sort(offsets)).  Otherwise, the result is undefined.
+        """
+        return self.kernel.bucketize(
+            values,
+            boundaries,
+            boundary_indices,
+            indexing_dtype,
+            right,
+            sorter,
+            sorter_indices,
+        )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9c45cd32981418fe1121c47c78aaac35b0e65b2
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp.py
@@ -0,0 +1,5826 @@
+# mypy: allow-untyped-defs
+import contextlib
+import dataclasses
+import functools
+import itertools
+import math
+import operator
+import re
+import sys
+import warnings
+from collections.abc import Callable, Sequence
+from enum import Enum
+from typing import Any, cast, Optional, Union
+
+import sympy
+
+import torch
+import torch.fx
+from torch._inductor import dependencies
+from torch._prims_common import is_float_dtype, is_integer_dtype
+from torch.utils._ordered_set import OrderedSet
+from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing
+from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
+
+from ..._dynamo.utils import counters
+from .. import config, cpp_builder, cpu_vec_isa, ir, metrics
+from ..debug import set_kernel_post_grad_provenance_tracing
+from ..loop_body import LoopBody
+from ..scheduler import (
+    BaseSchedulerNode,
+    BaseScheduling,
+    ExternKernelSchedulerNode,
+    ForeachKernelSchedulerNode,
+    FusedSchedulerNode,
+    Scheduler,
+    SchedulerNode,
+)
+from ..utils import (
+    cache_on_self,
+    get_bounds_index_expr,
+    get_fused_kernel_name,
+    has_free_symbols,
+    is_multi_outputs_template,
+    is_welford_reduction,
+    parallel_num_threads,
+    Placeholder,
+    sympy_index_symbol,
+    sympy_index_symbol_with_prefix,
+    sympy_product,
+    sympy_subs,
+)
+from ..virtualized import NullKernelHandler, ops, OpsValue, V
+from .common import (
+    BackendFeature,
+    BracesBuffer,
+    CSE,
+    CSEVariable,
+    DataTypePropagation,
+    DeferredLine,
+    DTYPE_TO_COMPUTATION_DTYPE,
+    IndentedBuffer,
+    Kernel,
+    KernelArgs,
+    OpOverrides,
+    OptimizationContext,
+)
+from .cpp_utils import (
+    _get_dtype_from_loopbodies,
+    _get_loop_body,
+    cexpr,
+    cexpr_index,
+    codegen_rand,
+    CppCSEVariable,
+    DTYPE_TO_CPP,
+    get_promote_dtype,
+    INDEX_TYPE,
+    LocalBufferContext,
+    may_unify_binary_op_mask_type,
+    promote_args,
+    template_fusion_with_epilogues_supported,
+    unify_mask_base_type,
+    value_to_cpp,
+)
+
+
+_IS_WINDOWS = sys.platform == "win32"
+
+
+@functools.cache
+def get_export_declaration():
+    return "__declspec(dllexport)" if _IS_WINDOWS else ""
+
+
+schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")
+
+NATIVE_OMP_RTYPES = OrderedSet(["+", "*", "^", "||", "min", "max"])
+RTYPE_TO_CPP = {
+    "sum": "+",
+    "prod": "*",
+    "xor_sum": "^",
+    "min": "min",
+    "max": "max",
+    "argmin": "argmin",
+    "argmax": "argmax",
+    "any": "||",
+    "welford_reduce": "welford",
+    "welford_combine": "welford",
+}
+VECTORIZABLE_RTYPES = OrderedSet(
+    [
+        "max",
+        "min",
+        "sum",
+        "prod",
+        "xor_sum",
+        "welford_reduce",
+        "welford_combine",
+        "argmin",
+        "argmax",
+        "any",
+    ]
+)
+
+PYTHON_TO_CPP = {
+    "Tensor": "at::Tensor",
+    "int": "long",
+    "float": "double",
+    "bool": "bool",
+    "str": "std::string",
+    "ScalarType": "c10::ScalarType",
+    "MemoryFormat": "at::MemoryFormat",
+    "Layout": "at::Layout",
+    "Device": "at::Device",
+    "number": "at::Scalar",
+}
+
+CONTAINER_PYTHON_TO_CPP = {
+    "List": "std::vector",
+    "Optional": "std::optional",
+}
+
+DTYPE_LOWP_FP = [
+    torch.bfloat16,
+    torch.float16,
+]
+
+VECTORIZABLE_DTYPES: list[torch.dtype] = [
+    torch.float64,
+    torch.float,
+    torch.bfloat16,
+    torch.float16,
+    torch.bool,
+    torch.uint8,
+    torch.int8,
+    torch.int32,
+    torch.int64,
+    torch.float8_e4m3fn,
+    torch.float8_e5m2,
+]
+
+
+def reduction_init(reduction_type, dtype):
+    if dtype in DTYPE_LOWP_FP:
+        # Since load promotes all half-precision inputs to float, the initial
+        # constant for reduction must be promoted as well
+        dtype = torch.float32
+    if reduction_type in ("xor_sum", "sum", "any"):
+        return 0
+    if reduction_type == "prod":
+        return 1
+    if reduction_type in ("max", "argmax", "min", "argmin"):
+        cdtype = DTYPE_TO_CPP[dtype]
+        if dtype == torch.bool and reduction_type in ("argmin", "argmax"):
+            cdtype = DTYPE_TO_CPP[torch.float]
+        min_var = (
+            f"-std::numeric_limits<{cdtype}>::infinity()"
+            if is_float_dtype(dtype)
+            else f"std::numeric_limits<{cdtype}>::min()"
+        )
+        max_var = (
+            f"std::numeric_limits<{cdtype}>::infinity()"
+            if is_float_dtype(dtype)
+            else f"std::numeric_limits<{cdtype}>::max()"
+        )
+        init_var = min_var if reduction_type in ("max", "argmax") else max_var
+        return (
+            init_var
+            if reduction_type in ("max", "min")
+            else f"IndexValue<{cdtype}>{{0, {init_var}}}"
+        )
+    if is_welford_reduction(reduction_type):
+        return f"Welford<{DTYPE_TO_CPP[dtype]}>()"
+    raise AssertionError(reduction_type)
+
+
+def reduction_acc_type(reduction_type, dtype):
+    scalar_type = DTYPE_TO_CPP[DTYPE_TO_COMPUTATION_DTYPE[dtype]]
+    if is_welford_reduction(reduction_type):
+        return f"Welford<{scalar_type}>"
+    if reduction_type in ("argmin", "argmax"):
+        if dtype == torch.bool:
+            scalar_type = DTYPE_TO_CPP[torch.float]
+        return f"IndexValue<{scalar_type}>"
+    return scalar_type
+
+
+def reduction_combine(
+    reduction_type,
+    var,
+    next_value,
+    helper_val=None,
+    index: Optional[sympy.Symbol] = None,
+    src_dtype=None,
+):
+    is_bool = src_dtype == torch.bool
+    if reduction_type == "sum":
+        if helper_val:
+            return f"cascade_sum_combine({next_value}, &{helper_val})"
+        else:
+            conjunction = "|" if is_bool else "+"
+            return f"{var} {conjunction} {next_value}"
+    if reduction_type == "prod":
+        return f"{var} * {next_value}"
+    if reduction_type == "xor_sum":
+        return f"{var} ^ {next_value}"
+    if reduction_type == "any":
+        return f"{var} || {next_value}"
+    if reduction_type in ("min", "max"):
+        return f"{reduction_type}_propagate_nan({var}, {next_value})"
+    if reduction_type == "welford_reduce":
+        if helper_val:
+            return f"welford_combine({var}, {next_value}, &{helper_val})"
+        else:
+            return f"welford_combine({var}, {next_value})"
+    if reduction_type == "welford_combine":
+        if isinstance(next_value, tuple):
+            mean, m2, weight = next_value
+        else:
+            mean, m2, weight = reduction_project(reduction_type, next_value)
+        return f"welford_combine({var}, {{{mean}, {m2}, {weight}}})"
+    if reduction_type in ("argmin", "argmax"):
+        if (
+            hasattr(next_value, "dtype")
+            and next_value.dtype == torch.bool
+            and not next_value.is_vec
+        ):
+            if index is not None:
+                return f"{reduction_type}_combine({var}, static_cast({next_value}), {index})"
+            else:
+                return (
+                    f"{reduction_type}_combine({var}, static_cast({next_value}))"
+                )
+        if index is not None:
+            return f"{reduction_type}_combine({var}, {next_value}, {index})"
+        else:
+            return f"{reduction_type}_combine({var}, {next_value})"
+    raise AssertionError(reduction_type)
+
+
+def reduction_project(reduction_type, acc):
+    if is_welford_reduction(reduction_type):
+        return f"{acc}.mean", f"{acc}.m2", f"{acc}.weight"
+    elif reduction_type in ("argmin", "argmax"):
+        return f"{acc}.index"
+    return acc
+
+
+def move_code_under_inner_loop(
+    code: IndentedBuffer,
+    iter_var: sympy.Expr,
+    new_iter_var: str,
+    loop_start: sympy.Expr,
+    loop_end: sympy.Expr,
+) -> BracesBuffer:
+    r"""
+    f(iter_var) is transformed to f(new_iter_var) under the inner loop
+      \/
+    for (new_iter_var = loop_start; new_iter_var < loop_end; new_iter_var++) {
+        f(new_iter_var)
+    }
+    Please be careful while using this function,
+    as the variable defined in f(iter_var) will be invalid outside the for loop.
+    For example:
+    auto tmp0 = in_ptr[x0]; ->
+    for (new_x0 = start; new_x0 < end; new_x0++){
+        auto tmp0 = in_ptr[new_x0];
+    }
+    The tmp0 is invalid outside the loop.
+    """
+    transformed_code = BracesBuffer()
+    with contextlib.ExitStack() as stack:
+        transformed_code.writeline(
+            f"for ({INDEX_TYPE} {new_iter_var} = {cexpr_index(loop_start)};"
+            + f"{new_iter_var} < {cexpr_index(loop_end)}; {new_iter_var}++)"
+        )
+        stack.enter_context(transformed_code.indent())
+        for _, line in enumerate(code._lines):
+            assert isinstance(
+                line,
+                (
+                    str,
+                    DeferredLine,
+                ),
+            )
+            deferred_name = None
+            if isinstance(line, DeferredLine):
+                deferred_name = line.name
+                line = line.line
+            new_line = re.sub(r"\b" + f"{iter_var}" + r"\b", f"{new_iter_var}", line)
+            if deferred_name:
+                new_line = DeferredLine(deferred_name, new_line)  # type: ignore[assignment]
+            transformed_code.writeline(new_line)
+    return transformed_code
+
+
+def reduction_prefix_array(
+    acc_var: Union[str, CSEVariable],
+    acc_type: str,
+    reduction_type: str,
+    dtype: torch.dtype,
+    len: Union[str, int],
+    init_fn,
+):
+    """
+    MSVC don't support dynamic array(VLA). So we use std::unique_ptr here.
+    Ref: https://stackoverflow.com/questions/56555406/creating-dynamic-sized-array-using-msvc-c-compiler
+    MSVC is the only one compiler without VLA. support. Since MSVC can't get good performance here.
+    We just use unique_ptr make it works on MSVC.
+    For other compilers, we continue to use VLA to get best performance.
+    """
+    code_buffer = IndentedBuffer()
+    acc_decl = (
+        f"auto {acc_var}_arr = std::make_unique<{acc_type}[]>({len});"
+        if cpp_builder.is_msvc_cl()
+        else f"{acc_type} {acc_var}_arr[{len}];"
+    )
+    code_buffer.writeline(f"{acc_decl}")
+    code_buffer.writelines(
+        [
+            f"for (int i = 0; i < {len}; i++)",
+            "{",
+            f"    {acc_var}_arr[i] = {init_fn(reduction_type, dtype)};",
+            "}",
+        ],
+    )
+    return code_buffer
+
+
+def replace_acc_name(buffer: IndentedBuffer, name: str, new_name: str):
+    for i, line in enumerate(buffer._lines):
+        assert isinstance(
+            line,
+            (
+                str,
+                DeferredLine,
+            ),
+        )
+        if isinstance(line, DeferredLine):
+            line.line = re.sub(r"\b" + f"{name}" + r"\b", f"{new_name}", line.line)
+        else:
+            buffer._lines[i] = re.sub(r"\b" + f"{name}" + r"\b", f"{new_name}", line)
+
+
+def replace_cascade_sum_with_add(buffer: IndentedBuffer):
+    """
+    Replaces `acc = cascade_sum_combine(value, ...)` with `acc = acc + value;`
+    """
+
+    pattern = r"(.*?)\s*=\s*cascade_sum_combine\(([^,]+),.*?\);"
+    for i, line in enumerate(buffer._lines):
+        assert isinstance(
+            line,
+            (
+                str,
+                DeferredLine,
+            ),
+        )
+        content = line.line if isinstance(line, DeferredLine) else line
+        match = re.search(pattern, content)
+        if match:
+            acc, value = match.groups()
+            new_content = re.sub(pattern, f"{acc} = {acc} + {value};", content)
+            if isinstance(line, DeferredLine):
+                line.line = new_content
+            else:
+                buffer._lines[i] = new_content
+
+
+@functools.lru_cache
+def stride_at(index: sympy.Expr, var: sympy.Symbol):
+    if not index.has(var):
+        # see test_torchinductor_dynamic_shapes.py::test_full_boolean_dynamic_shapes_cpu
+        # which has tmp0 = ops.index_expr(s0 >= 1024, torch.bool) and fails below calculation.
+        # in this case, there is no dependencies between index and var.
+        return sympy.S.Zero
+    replacement = {var: var + 1}
+    new_index = sympy_subs(index, replacement)  # type: ignore[arg-type]
+    return sympy.simplify(new_index - index)
+
+
+@functools.lru_cache
+def simplify_index_in_vec_range(index: sympy.Expr, var: sympy.Expr, vec_length: int):
+    """
+    Simplifies the index expression within the range of a vectorized loop.
+    Given a vectorized loop variable `var` in the range of a loop with `vec_length`,
+    this function transforms the `index` into an equivalent form. It handles
+    simplifications for cases where `var` can be expressed as `vec_length * a + b`,
+    where `b` ranges from 0 to `vec_length - 1`. The function reduces occurrences
+    of `FloorDiv` and `ModularIndexing` in the `index` with best-effort optimizations.
+
+    NOTE:
+    The simplified index expression is intended for analysis purposes only, not
+    for code generation. It replaces `FloorDiv` and `ModularIndexing` with free variables
+    which are not dependent on the loop variable `var` in the vectorized range. Check
+    https://github.com/pytorch/pytorch/pull/117221#discussion_r1449746217 for more details.
+
+    Examples:
+    1. If `var` is `x3` and `vec_length` is 16, and `x3 = 16*a + b`, then
+       `FloorDiv(x3, div)` or `ModularIndexing(x3, div, mod)` becomes a free variable
+       when `div` is divisible by 16.
+    2. `ModularIndexing(x3, 1, mod)` can be simplified to `x3 + c` where `c` is a free
+       variable when `mod` is divisible by 16.
+    """
+
+    div_freevar_id = 0
+    mod_freevar_id = 0
+
+    def visit_indexing_div(divisor):
+        nonlocal div_freevar_id
+        result = FloorDiv(var, divisor)
+        if sympy.gcd(divisor, vec_length) == vec_length:
+            result = sympy.Symbol(f"{var}_div_c{div_freevar_id}")
+            div_freevar_id += 1
+        return result
+
+    def visit_modular_indexing(divisor, modulus):
+        nonlocal mod_freevar_id
+        result = ModularIndexing(var, divisor, modulus)
+        if sympy.gcd(divisor, vec_length) == vec_length:
+            result = sympy.Symbol(f"{var}_mod_c{mod_freevar_id}")
+            mod_freevar_id += 1
+        elif divisor == 1 and sympy.gcd(modulus, vec_length) == vec_length:
+            result = var + sympy.Symbol(f"{var}_mod_c{mod_freevar_id}")
+            mod_freevar_id += 1
+        return result
+
+    original_index = index
+
+    div = sympy.Wild("divisor", integer=True)
+    if index.has(FloorDiv):
+        index = index.replace(FloorDiv(var, div), visit_indexing_div)
+
+    mod = sympy.Wild("modulus", integer=True)
+    if index.has(ModularIndexing):
+        index = index.replace(ModularIndexing(var, div, mod), visit_modular_indexing)
+
+    index = sympy.simplify(index)
+    if index != original_index:
+        return simplify_index_in_vec_range(index, var, vec_length)
+
+    return index
+
+
+@functools.lru_cache
+def stride_at_vec_range(
+    index: sympy.Expr, var: sympy.Symbol, vec_length: Optional[int] = None
+):
+    if vec_length:
+        index = simplify_index_in_vec_range(index, var, vec_length)
+    return stride_at(index, var)
+
+
+@dataclasses.dataclass
+class ParallelDepth:
+    """
+    A class representing parallel depth.
+    Includes the starting depth of parallelism and the depth of parallelism.
+    """
+
+    parallel_depth: int
+    start_depth: int
+
+
+class OuterLoopFusedSchedulerNode(FusedSchedulerNode):
+    @classmethod
+    def fuse(  # type: ignore[override]
+        cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode, outer_loop_fusion_depth
+    ):
+        assert node1.scheduler is node2.scheduler
+        assert all(
+            type(node)
+            in (
+                OuterLoopFusedSchedulerNode,
+                SchedulerNode,
+                FusedSchedulerNode,
+            )
+            for node in (node1, node2)
+        )
+        if any(type(node) is OuterLoopFusedSchedulerNode for node in (node1, node2)):
+            return cls(
+                node1.scheduler,
+                # pyrefly: ignore [bad-argument-type]
+                (
+                    list(node1.get_outer_nodes())
+                    if type(node1) is OuterLoopFusedSchedulerNode
+                    else [
+                        node1,
+                    ]
+                )
+                + (
+                    list(node2.get_outer_nodes())
+                    if type(node2) is OuterLoopFusedSchedulerNode
+                    else [
+                        node2,
+                    ]
+                ),
+                outer_loop_fusion_depth,
+            )
+        else:
+            return cls(node1.scheduler, [node1, node2], outer_loop_fusion_depth)  # type: ignore[list-item]
+
+    def __init__(
+        self,
+        scheduler: "Scheduler",
+        outer_fused_nodes: list[Union[FusedSchedulerNode, SchedulerNode]],
+        outer_loop_fusion_depth,
+    ):
+        self.outer_fused_nodes: list[Union[FusedSchedulerNode, SchedulerNode]] = (
+            outer_fused_nodes
+        )
+        self.outer_loop_fusion_depth = outer_loop_fusion_depth
+        flatten_snodes = []
+        for _node in self.outer_fused_nodes:
+            assert isinstance(_node, (SchedulerNode, FusedSchedulerNode))
+            flatten_snodes.extend(list(_node.get_nodes()))
+        super().__init__(scheduler, flatten_snodes)  # type: ignore[arg-type]
+
+    def get_outer_nodes(self):
+        return self.outer_fused_nodes
+
+    def check_outer_fusion_loop_level_attr(
+        self, cpp_kernel_proxy_list, outer_loop_fusion_depth
+    ):
+        # This function ensures that the same tiling split is applied at each loop level within the outer loop fusion depth.
+        # In the fusion stage, we only examine nodes with same vars and reduce.
+        # However, for nodes with same vars and reduce, the loops may still have different tile splits.
+        # For example (test_expr_vec_non_contiguous in test_cpu_repro.py):
+        #   * buf0 tiling along the 2nd loop level, buf1 tiling along the 3rd loop level.
+        # If the check failed, we should fall back to standard loop codegen.
+        def _inner(
+            left_loop_nest: LoopNest,
+            right_loop_nest: LoopNest,
+            loop_fusion_depth: int,
+            current_checking_depth: int,
+        ) -> bool:
+            assert left_loop_nest.loops
+            assert right_loop_nest.loops
+            left_loop_level = left_loop_nest.loops[current_checking_depth]
+            right_loop_level = right_loop_nest.loops[current_checking_depth]
+            # Check if same loop level attr
+            outer_loops_attr_compare_list = [
+                "var",
+                "size",
+                "offset",
+                "steps",
+            ]
+            if not (
+                all(
+                    getattr(left_loop_level, attr_compare)
+                    == getattr(right_loop_level, attr_compare)
+                    for attr_compare in outer_loops_attr_compare_list
+                )
+            ):
+                return False
+
+            assert loop_fusion_depth >= 1
+            if (loop_fusion_depth := loop_fusion_depth - 1) > 0:
+                # Check next loop level attr
+                current_checking_depth = current_checking_depth + 1
+                assert current_checking_depth < len(left_loop_nest.loops)
+                assert current_checking_depth < len(right_loop_nest.loops)
+                if not _inner(
+                    left_loop_nest,
+                    right_loop_nest,
+                    loop_fusion_depth,
+                    current_checking_depth,
+                ):
+                    return False
+
+            return True
+
+        for idx in range(len(cpp_kernel_proxy_list) - 1):
+            left_loop_nest = cpp_kernel_proxy_list[idx].loop_nest
+            right_loop_nest = cpp_kernel_proxy_list[idx + 1].loop_nest
+            if not _inner(
+                left_loop_nest,
+                right_loop_nest,
+                outer_loop_fusion_depth,
+                0,
+            ):
+                return False
+
+        for cpp_kernel_proxy in cpp_kernel_proxy_list:
+            outer_ranges = functools.reduce(
+                operator.mul,
+                cpp_kernel_proxy.ranges[:outer_loop_fusion_depth],
+            )
+            # When the range of the first inner loop is much larger than the range of
+            # all outer loops, do not fuse outer loop and fallback to standard loop codegen,
+            # so that the inner loops with larger range have a chance to be parallelized.
+            # We set a conservative threshold here:
+            # First inner loop range / all outer loops range > 300.
+            if (
+                len(cpp_kernel_proxy.ranges) > outer_loop_fusion_depth
+                and isinstance(outer_ranges, sympy.Integer)
+                and isinstance(
+                    cpp_kernel_proxy.ranges[outer_loop_fusion_depth],
+                    sympy.Integer,
+                )
+                and outer_ranges * 300
+                < cpp_kernel_proxy.ranges[outer_loop_fusion_depth]
+            ):
+                return False
+
+        return True
+
+    def merge_outer_fusion_kernels(
+        self,
+        cpp_kernel_proxy_list,
+    ):
+        kernel_group = cpp_kernel_proxy_list[0].kernel_group
+        outer_loop_fused_kernel = OuterLoopFusedKernel(kernel_group)
+        outer_loop_fused_kernel.inner = [
+            proxy.loop_nest.from_loop_level(self.outer_loop_fusion_depth)
+            for proxy in cpp_kernel_proxy_list
+        ]
+        outer_fused_proxy = cpp_kernel_proxy_list[0]
+        outer_fused_proxy.loop_nest.kernel = outer_loop_fused_kernel
+        outer_fused_proxy.loop_nest.loops = outer_fused_proxy.loop_nest.loops[
+            : self.outer_loop_fusion_depth
+        ]
+        return outer_fused_proxy
+
+
+class RecordOptimizationContext:
+    def __init__(self, func_name: str = ""):
+        self.func_name = func_name
+        self.current_node: Optional[torch.fx.Node] = None
+        self.opt_ctx: Optional[OptimizationContext] = None
+
+    def __enter__(self):
+        assert V.interpreter
+        assert V.interpreter.current_node
+
+        self.current_node = V.interpreter.current_node
+        assert self.current_node is not None
+        if OptimizationContext.key in self.current_node.meta:
+            self.opt_ctx = self.current_node.meta[OptimizationContext.key]
+        else:
+            self.opt_ctx = OptimizationContext()
+        assert self.opt_ctx is not None
+        self.opt_ctx.ops_name = self.func_name
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        assert self.current_node
+        assert self.opt_ctx
+        self.current_node.meta[OptimizationContext.key] = self.opt_ctx
+
+    def get_opt_ctx(self):
+        return self.opt_ctx
+
+    def get_fx_node(self):
+        assert self.current_node
+        return self.current_node
+
+
+def decltype_promoted(*args):
+    assert not any(isinstance(arg, CppCSEVariable) and arg.is_vec for arg in args), (
+        "Promotion of vector types is not supported"
+    )
+
+    if (dt := get_promote_dtype(args)) is not None:
+        return DTYPE_TO_CPP[dt]
+    else:
+        return f"decltype({args[0]})"
+
+
+class CppOverrides(OpOverrides):
+    """Map element-wise ops to C++"""
+
+    @staticmethod
+    def add(a, b):
+        return f"{decltype_promoted(a, b)}({a} + {b})"
+
+    @staticmethod
+    def sub(a, b):
+        return f"{decltype_promoted(a, b)}({a} - {b})"
+
+    @staticmethod
+    def mul(a, b):
+        return f"{decltype_promoted(a, b)}({a} * {b})"
+
+    @staticmethod
+    def to_dtype(x, dtype, src_dtype=None, use_compute_types=True):
+        assert isinstance(x, CppCSEVariable)
+        if src_dtype is None:
+            src_dtype = x.dtype
+        expr = V.kernel.get_to_dtype_expr(x, dtype, src_dtype)
+        csevar = V.kernel.cse.generate(V.kernel.compute, expr)
+        csevar.update_on_args("to_dtype", (x, dtype), {"src_dtype": src_dtype})
+        if dtype in DTYPE_LOWP_FP and src_dtype == torch.float:
+            """
+            https://github.com/pytorch/pytorch/issues/115260
+            For FusedSchedulerNode[node1, node2], the node2 loads what node1 stores and the buffer is
+            in low-precision floating point data type. When the output of node1 also serves as the output of the
+            kernel, the result of nodes would be different from the case when output of node1 is not the output
+            of the kernel (where we don't need to insert `to_dtype` for legalization). To address the problem, on
+            storing the lowp node1 output, we also add the inverse dtype conversion to high precision data type
+            to the cse cache.
+
+            Example (pseudo code):
+                node1_output = ...
+                node1_output_lowp = to_dtype(node1_output, dtype=torch.bfloat16)
+                store(buf, node1_output_lowp)
+                node2_input_lowp = load(buf)
+                node2_input = to_dtype(node2_input_lowp, dtype=torch.float)
+
+            Without cse cache trick:
+                node1_output = ...
+                node1_output_lowp = to_dtype(node1_output, dtype=torch.bfloat16)
+                store(buf, node1_output_lowp)
+                node2_input_lowp = node_output_lowp # hit store cache
+                node2_input = to_dtype(node2_input_lowp, dtype=torch.float)
+
+            With cse cache trick:
+                node1_output = ...
+                node1_output_lowp = to_dtype(node1_output, dtype=torch.bfloat16)
+                # also add `to_dtype(node1_input_lowp, dtype=torch.float)` -> `node1_output` to cse cache
+                store(buf, node1_output_lowp)
+                node2_input_lowp = node_output_lowp # hit store cache
+                node2_input = node1_output # hit cse cache
+            """
+            V.kernel.cache_dtype_convert(x, src_dtype, csevar, dtype)
+        return csevar
+
+    @staticmethod
+    def to_dtype_bitcast(x, dtype, src_dtype):
+        assert dtype in DTYPE_TO_CPP, f"{dtype} missing from {__name__}.DTYPE_TO_CPP"
+        return f"c10::bit_cast<{DTYPE_TO_CPP[dtype]}>({x})"
+
+    @staticmethod
+    def abs(x):
+        return f"std::abs({x})"
+
+    @staticmethod
+    def sin(x):
+        return f"std::sin({x})"
+
+    @staticmethod
+    def cos(x):
+        return f"std::cos({x})"
+
+    @staticmethod
+    def neg(x):
+        return f"decltype({x})(-{x})"
+
+    @staticmethod
+    def exp(x):
+        # return f"Sleef_expf_u10({x})"
+        return f"std::exp({x})"
+
+    @staticmethod
+    def exp2(x):
+        return f"std::exp2({x})"
+
+    @staticmethod
+    def expm1(x):
+        return f"std::expm1({x})"
+
+    @staticmethod
+    def erf(x):
+        return f"std::erf({x})"
+
+    @staticmethod
+    def erfc(x):
+        return f"std::erfc({x})"
+
+    @staticmethod
+    def erfinv(x):
+        return f"calc_erfinv({x})"
+
+    @staticmethod
+    def sqrt(x):
+        return f"std::sqrt({x})"
+
+    @staticmethod
+    def rsqrt(x):
+        return f"1 / std::sqrt({x})"
+
+    @staticmethod
+    def log1p(x):
+        bug = config.cpp.inject_log1p_bug_TESTING_ONLY
+        if bug == "accuracy":
+            return f"{x} + decltype({x})(1)"
+        elif bug is None:
+            return f"std::log1p({x})"
+        else:
+            raise AssertionError(
+                f"unrecognized config cpp.inject_log1p_bug_TESTING_ONLY = {bug!r}"
+            )
+
+    @staticmethod
+    def tan(x):
+        return f"std::tan({x})"
+
+    @staticmethod
+    def tanh(x):
+        return f"std::tanh({x})"
+
+    @staticmethod
+    def signbit(x):
+        """
+        On windows std::signbit only support float type.
+        Ref: https://learn.microsoft.com/en-us/cpp/c-runtime-library/reference/signbit?view=msvc-170
+        """
+        return (
+            f"std::signbit(static_cast({x}))"
+            if _IS_WINDOWS
+            else f"std::signbit({x})"
+        )
+
+    @staticmethod
+    def pow(a, b):
+        return f"std::pow({a}, {b})"
+
+    @staticmethod
+    def log(x):
+        return f"std::log({x})"
+
+    @staticmethod
+    def round(x):
+        return f"std::nearbyint({x})"
+
+    @staticmethod
+    def floor(x):
+        return f"std::floor({x})"
+
+    @staticmethod
+    def floordiv(a, b):
+        # a and b are integer type
+        quot = f"{a} / {b}"
+        rem = f"{a} % {b}"
+        return f"(({a} < 0) != ({b} < 0) ? ({rem} != 0 ? {quot} - 1 : {quot}) : {quot})"
+
+    @staticmethod
+    def ceil(x):
+        return f"std::ceil({x})"
+
+    @staticmethod
+    def trunc(x):
+        return f"std::trunc({x})"
+
+    @staticmethod
+    def truncdiv(a, b):
+        # a and b are integer type
+        return f"{a} / {b}"
+
+    @staticmethod
+    def fmod(a, b):
+        return f"std::fmod({a}, {b})"
+
+    @staticmethod
+    def isinf(x):
+        return f"std::isinf({x})"
+
+    @staticmethod
+    def isnan(x):
+        return f"std::isnan({x})"
+
+    @staticmethod
+    def lgamma(x):
+        return f"std::lgamma({x})"
+
+    @staticmethod
+    def acos(x):
+        return f"std::acos({x})"
+
+    @staticmethod
+    def acosh(x):
+        return f"std::acosh({x})"
+
+    @staticmethod
+    def cosh(x):
+        return f"std::cosh({x})"
+
+    @staticmethod
+    def sinh(x):
+        return f"std::sinh({x})"
+
+    @staticmethod
+    def asin(x):
+        return f"std::asin({x})"
+
+    @staticmethod
+    def asinh(x):
+        return f"std::asinh({x})"
+
+    @staticmethod
+    def atan2(x, y):
+        return f"std::atan2({x}, {y})"
+
+    @staticmethod
+    def atan(x):
+        return f"std::atan({x})"
+
+    @staticmethod
+    def atanh(x):
+        return f"std::atanh({x})"
+
+    @staticmethod
+    def copysign(x, y):
+        return f"std::copysign({x}, {y})"
+
+    @staticmethod
+    def frexp(x):
+        cache_keys = f"frexp({x})[0]", f"frexp({x})[1]"
+        if all(V.kernel.cse.try_get(cache_key) is not None for cache_key in cache_keys):
+            return tuple(V.kernel.cse.try_get(cache_key) for cache_key in cache_keys)
+
+        code = BracesBuffer()
+        exponent = V.kernel.cse.newvar(dtype=torch.int32, shape=x.shape)
+        mantissa = V.kernel.cse.newvar(dtype=x.dtype, shape=x.shape)
+        code.writeline(f"int32_t {exponent};")
+        code.writeline(f"auto {mantissa} = std::frexp({x}, &{exponent});")
+        V.kernel.compute.splice(code)
+        cse_vars = (mantissa, exponent)
+        for cache_key, cse_var in zip(cache_keys, cse_vars):
+            V.kernel.cse.put(cache_key, cse_var)
+        return mantissa, exponent
+
+    @staticmethod
+    def hypot(x, y):
+        return f"std::hypot({x}, {y})"
+
+    @staticmethod
+    def log10(x):
+        return f"std::log10({x})"
+
+    @staticmethod
+    def log2(x):
+        return f"std::log2({x})"
+
+    @staticmethod
+    def nextafter(x, y):
+        return f"std::nextafter({x}, {y})"
+
+    @staticmethod
+    def relu(x):
+        bug = config.cpp.inject_relu_bug_TESTING_ONLY
+        if bug == "compile_error":
+            return "compile error!"
+        elif bug == "runtime_error":
+            return f"{x}; throw 1"
+        elif bug == "accuracy":
+            return f"{x} + decltype({x})(1)"
+        elif bug is None:
+            return f"std::max({x}, decltype({x})(0))"
+        else:
+            raise AssertionError(
+                f"unrecognized config cpp.inject_relu_bug_TESTING_ONLY = {bug!r}"
+            )
+
+    @staticmethod
+    def minimum(a, b):
+        return f"min_propagate_nan({a}, {b})"
+
+    @staticmethod
+    def maximum(a, b):
+        return f"max_propagate_nan({a}, {b})"
+
+    @staticmethod
+    def where(a, b, c):
+        return f"{a} ? {b} : {c}"
+
+    @staticmethod
+    def mod(a, b):
+        return f"mod({a}, {b})"
+
+    @staticmethod
+    def constant(val, dtype):
+        return value_to_cpp(val, DTYPE_TO_CPP[dtype])
+
+    @staticmethod
+    def index_expr(expr, dtype):
+        idx_str = cexpr(V.kernel.rename_indexing(expr))
+        var = V.kernel.cse.generate(
+            V.kernel.compute, idx_str, bounds=get_bounds_index_expr(expr)
+        )
+        return ops.to_dtype(var, dtype)
+
+    @staticmethod
+    def masked(mask, body, other):
+        code = BracesBuffer()
+
+        # Write masked operation into a lambda
+        body_var = V.kernel.cse.newvar()
+        code.writeline(f"auto {body_var} = [&]")
+        with V.kernel.swap_buffers(code), code.indent():
+            result = body()
+            code.writeline(f"return {result};")
+        code.writeline(";")
+        V.kernel.compute.splice(code)
+
+        # Use the lambda's return type as the type of other
+        other_code = value_to_cpp(other, f"decltype({body_var}())")
+        return f"{mask} ? {body_var}() : {other_code}"
+
+    @staticmethod
+    def logical_and(a, b):
+        return f"{a} && {b}"
+
+    @staticmethod
+    def logical_not(a):
+        return f"!{a}"
+
+    @staticmethod
+    def logical_or(a, b):
+        return f"{a} || {b}"
+
+    @staticmethod
+    def logical_xor(a, b):
+        return f"{a} != {b}"
+
+    @staticmethod
+    def bitwise_and(a, b):
+        return f"decltype({a})({a} & {b})"
+
+    @staticmethod
+    def bitwise_not(a):
+        return f"decltype({a})(~{a})"
+
+    @staticmethod
+    def bitwise_or(a, b):
+        return f"decltype({a})({a} | {b})"
+
+    @staticmethod
+    def bitwise_xor(a, b):
+        return f"decltype({a})({a} ^ {b})"
+
+    @staticmethod
+    def bitwise_left_shift(a, b):
+        code = BracesBuffer()
+        code.writeline("[&]()")
+        with code.indent():
+            scalar_t = DTYPE_TO_CPP[a.dtype]
+            code.writeline(
+                f"constexpr decltype({b}) max_shift = sizeof({scalar_t}) * CHAR_BIT;"
+            )
+            code.writeline(
+                f"if ((static_cast>({b}) < 0) || ({b} >= max_shift))"
+            )
+            with code.indent():
+                code.writeline(f"return decltype({a})(0);")
+            code.writeline(
+                f"return decltype({a})(static_cast>({a}) << {b});"
+            )
+        code.writeline("()")
+        return code
+
+    @staticmethod
+    def bitwise_right_shift(a, b):
+        code = BracesBuffer()
+        code.writeline("[&]()")
+        with code.indent():
+            scalar_t = DTYPE_TO_CPP[a.dtype]
+            code.writeline(
+                f"constexpr decltype({b}) max_shift = sizeof({scalar_t}) * CHAR_BIT - std::is_signed_v<{scalar_t}>;"
+            )
+            code.writeline(
+                f"if ((static_cast>({b}) < 0) || ({b} >= max_shift))"
+            )
+            with code.indent():
+                code.writeline(f"return decltype({a})({a} >> max_shift);")
+            code.writeline(f"return decltype({a})({a} >> {b});")
+        code.writeline("()")
+        return code
+
+    @staticmethod
+    def rand(seed: sympy.Expr, offset: sympy.Expr):
+        return f"normalized_rand_cpu({seed}, {offset})"
+
+    @staticmethod
+    def randn(seed: sympy.Expr, offset: sympy.Expr):
+        return f"randn_cpu({seed}, {offset})"
+
+    @staticmethod
+    def randint64(seed: sympy.Expr, offset: sympy.Expr, low, high):
+        return f"randint64_cpu({seed}, {offset}, {low}, {high})"
+
+    @staticmethod
+    def sigmoid(x):
+        return f"decltype({x})(1) / (decltype({x})(1) + std::exp(-{x}))"
+
+    @staticmethod
+    def sign(x):
+        code = BracesBuffer()
+        scalar_zero = f"decltype({x})(0)"
+        scalar_one = f"decltype({x})(1)"
+        code.writeline("[&]()")
+        with code.indent():
+            code.writeline(f"auto left = {x} > 0 ? {scalar_one} : {scalar_zero};")
+            code.writeline(f"auto right = {x} < 0 ? {scalar_one} : {scalar_zero};")
+            code.writeline("return left - right;")
+        code.writeline("()")
+        return code
+
+    def partial_accumulate(
+        self,
+        name: str,
+        reduction_type: str,
+        value: CSEVariable,
+        extra_meta: dict[str, Any],
+    ) -> None:
+        raise NotImplementedError
+
+
+CppOverrides._initialize_pointwise_overrides("cpp")
+
+
+class CppVecOverrides(CppOverrides):
+    """Map element-wise ops to aten vectorization C++"""
+
+    def __new__(cls, *args, **kargs):
+        self = super().__new__(cls)
+
+        def wrap(func):
+            # `CppVecKernel` generates both scalar ops and vector ops according to
+            # whether the inputs are scalars or vectors while all ops in `CppVecOverrides`
+            # (except for some ops explained below) assume the inputs are vectors. We wrap the ops in
+            # `CppVecOverrides` to broadcast scalar inputs to vectors if needed or fallback to
+            # `CppOverrides` when all inputs are scalars.
+            #
+            # Notes on ops handled separately in their own functions:
+            # `ops.masked`:
+            #     needs recursive handling of masked body.
+            # `ops.index_expr`:
+            #     needs to further analyze the dependency of the index expression on
+            #     the tiling itervar.
+            def wrapper(*args, **kwargs):
+                scalars = [
+                    arg
+                    for arg in args
+                    if isinstance(arg, (int, sympy.Expr))
+                    or (isinstance(arg, CppCSEVariable) and not arg.is_vec)
+                ]
+                vectors = [
+                    arg
+                    for arg in args
+                    if isinstance(arg, CppCSEVariable) and arg.is_vec
+                ]
+                new_args = list(args)
+                if scalars and vectors:
+                    new_args = []
+                    for arg in args:
+                        if isinstance(arg, (int, sympy.Expr)):
+                            if isinstance(arg, sympy.Expr) and not arg.is_number:
+                                arg = ops.index_expr(arg, torch.int64)
+                            else:
+                                arg = ops.constant(arg, torch.int64)
+                            arg = arg.value if isinstance(arg, OpsValue) else arg
+                        new_args.append(arg)
+
+                # DType Promotion
+                if vectors:
+                    # We have saw several data type mismatch issues related with index_expr in
+                    # the lowering phase of torch.int8. torch.int32, torch.int64.
+                    # 1. int32 and int64 in test_torchinductor.py::test_max_pool2d_with_indices_backward3_cpu
+                    # 2. int8 and int32 in test_torchinductor.py::test_max_pool2d5_cpu
+                    # 3. int32 and fp32 in test_torchinductor_dynamic_shapes.py::test_avg_pool2d8_dynamic_shapes_cpu
+                    if len(new_args) == 2:
+                        new_args = promote_args(new_args)
+                    elif func is CppVecOverrides.where:
+                        new_args[1:] = promote_args(new_args[1:])
+
+                # Broadcast scalar args to vector
+                if scalars and vectors:
+                    assert isinstance(V.kernel, CppVecKernel)
+                    new_args = [
+                        (
+                            V.kernel.broadcast(new_arg)
+                            if (
+                                isinstance(new_arg, CppCSEVariable)
+                                and not new_arg.is_vec
+                                and func
+                                not in [
+                                    CppVecOverrides.rand,
+                                    CppVecOverrides.randn,
+                                    CppVecOverrides.randint64,
+                                ]
+                            )
+                            else new_arg
+                        )
+                        for new_arg in new_args
+                    ]
+
+                if vectors:
+                    return func(*new_args, **kwargs)
+                else:
+                    # fallback to scalar ops
+                    scalar_ops = super(CppVecOverrides, self)
+                    scalar_func = getattr(scalar_ops, func.__name__)
+                    assert scalar_func is not None
+                    return scalar_func(*args, **kwargs)
+
+            return wrapper
+
+        for name, method in vars(CppVecOverrides).items():
+            if getattr(method, "__class__", None) is staticmethod and name not in [
+                "masked",
+                "index_expr",
+            ]:
+                setattr(self, name, wrap(method.__func__))
+
+        return self
+
+    @staticmethod
+    def add(a, b):
+        return f"{a} + {b}"
+
+    @staticmethod
+    def sub(a, b):
+        return f"{a} - {b}"
+
+    @staticmethod
+    def mul(a, b):
+        return f"{a} * {b}"
+
+    @staticmethod
+    def truediv(a, b):
+        return f"{a} / {b}"
+
+    @staticmethod
+    def abs(x):
+        return f"{x}.abs()"
+
+    @staticmethod
+    def sin(x):
+        return f"{x}.sin()"
+
+    @staticmethod
+    def cos(x):
+        return f"{x}.cos()"
+
+    @staticmethod
+    def exp(x):
+        return f"{x}.exp()"
+
+    @staticmethod
+    def exp2(x):
+        return f"{x}.exp2()"
+
+    @staticmethod
+    def expm1(x):
+        # decompose for a better performance
+        vec_one = f"decltype({x})(1)"
+        return f"{x}.exp() - {vec_one}"
+
+    @staticmethod
+    def erf(x):
+        return f"{x}.erf()"
+
+    @staticmethod
+    def erfc(x):
+        return f"{x}.erfc()"
+
+    @staticmethod
+    def erfinv(x):
+        return f"{x}.erfinv()"
+
+    @staticmethod
+    def sqrt(x):
+        return f"{x}.sqrt()"
+
+    @staticmethod
+    def eq(x, y):
+        assert isinstance(V.kernel, CppVecKernel)
+        assert isinstance(x, CppCSEVariable)
+        assert x.dtype is not None
+        return f"{V.kernel._get_mask_type(x.dtype)}({x} == {y})"
+
+    @staticmethod
+    def ne(x, y):
+        assert isinstance(V.kernel, CppVecKernel)
+        assert isinstance(x, CppCSEVariable)
+        if x.dtype == torch.bool:
+            assert y.dtype == torch.bool
+            x_cast, y_cast = unify_mask_base_type(V.kernel.compute, (x, y))
+            return f"{x_cast} != {y_cast}"
+        else:
+            assert x.dtype is not None
+            return f"{V.kernel._get_mask_type(x.dtype)}({x} != {y})"
+
+    @staticmethod
+    def lt(x, y):
+        assert isinstance(V.kernel, CppVecKernel)
+        assert isinstance(x, CppCSEVariable)
+        assert x.dtype is not None
+        return f"{V.kernel._get_mask_type(x.dtype)}({x} < {y})"
+
+    @staticmethod
+    def gt(x, y):
+        assert isinstance(V.kernel, CppVecKernel)
+        assert isinstance(x, CppCSEVariable)
+        assert x.dtype is not None
+        return f"{V.kernel._get_mask_type(x.dtype)}({x} > {y})"
+
+    @staticmethod
+    def le(x, y):
+        assert isinstance(V.kernel, CppVecKernel)
+        assert isinstance(x, CppCSEVariable)
+        assert x.dtype is not None
+        return f"{V.kernel._get_mask_type(x.dtype)}({x} <= {y})"
+
+    @staticmethod
+    def ge(x, y):
+        assert isinstance(V.kernel, CppVecKernel)
+        assert isinstance(x, CppCSEVariable)
+        assert x.dtype is not None
+        return f"{V.kernel._get_mask_type(x.dtype)}({x} >= {y})"
+
+    @staticmethod
+    def and_(x, y):
+        return f"{x} & {y}"
+
+    @staticmethod
+    def rsqrt(x):
+        return f"{x}.rsqrt()"
+
+    @staticmethod
+    def pow(a, b):
+        return f"{a}.pow({b})"
+
+    @staticmethod
+    def log(x):
+        return f"{x}.log()"
+
+    @staticmethod
+    def round(x):
+        return f"{x}.round()"
+
+    @staticmethod
+    def floor(x):
+        return f"{x}.floor()"
+
+    @staticmethod
+    def ceil(x):
+        return f"{x}.ceil()"
+
+    @staticmethod
+    def trunc(x):
+        return f"{x}.trunc()"
+
+    @staticmethod
+    def fmod(a, b):
+        return f"{a}.fmod({b})"
+
+    @staticmethod
+    def lgamma(x):
+        return f"{x}.lgamma()"
+
+    @staticmethod
+    def logical_and(a, b):
+        a, b = may_unify_binary_op_mask_type(a, b)
+        return f"{a} & {b}"
+
+    @staticmethod
+    def logical_not(a):
+        return f"~{a}"
+
+    @staticmethod
+    def logical_or(a, b):
+        a, b = may_unify_binary_op_mask_type(a, b)
+        return f"{a} | {b}"
+
+    @staticmethod
+    def logical_xor(a, b):
+        a, b = may_unify_binary_op_mask_type(a, b)
+        return f"{a} ^ {b}"
+
+    @staticmethod
+    def bitwise_and(a, b):
+        a, b = may_unify_binary_op_mask_type(a, b)
+        return f"{a} & {b}"
+
+    @staticmethod
+    def bitwise_not(a):
+        return f"~{a}"
+
+    @staticmethod
+    def bitwise_or(a, b):
+        a, b = may_unify_binary_op_mask_type(a, b)
+        return f"{a} | {b}"
+
+    @staticmethod
+    def bitwise_xor(a, b):
+        a, b = may_unify_binary_op_mask_type(a, b)
+        return f"{a} ^ {b}"
+
+    @staticmethod
+    def bitwise_left_shift(a, b):
+        return f"{a} << {b}"
+
+    @staticmethod
+    def bitwise_right_shift(a, b):
+        return f"{a} >> {b}"
+
+    @staticmethod
+    def load_seed(name, offset):
+        assert isinstance(V.kernel, CppVecKernel)
+        return f"{V.kernel.load(name, offset)}"
+
+    @staticmethod
+    def rand(seed, offset):
+        assert isinstance(V.kernel, CppVecKernel)
+        code = BracesBuffer()
+        rand_function = (
+            f"result[offset_idx] = normalized_rand_cpu({seed}, offset[offset_idx]);"
+        )
+        return codegen_rand(offset, code, rand_function)
+
+    @staticmethod
+    def randn(seed, offset):
+        assert isinstance(V.kernel, CppVecKernel)
+        code = BracesBuffer()
+        rand_function = f"result[offset_idx] = randn_cpu({seed}, offset[offset_idx]);"
+        return codegen_rand(offset, code, rand_function)
+
+    @staticmethod
+    def randint64(seed, offset, low, high):
+        assert isinstance(V.kernel, CppVecKernel)
+        code = BracesBuffer()
+        rand_function = f"result[offset_idx] = randint64_cpu({seed}, offset[offset_idx], {low}, {high});"
+        return codegen_rand(offset, code, rand_function, torch.int64)
+
+    @staticmethod
+    def remainder(a, b):
+        assert a.dtype == b.dtype, (
+            "remainder vec implementation expect the same inputs' dtype."
+        )
+        return f"{a} - ({CppVecOverrides.floordiv(a, b)}) * {b}"
+
+    @staticmethod
+    def tan(a):
+        return f"{a}.tan()"
+
+    @staticmethod
+    def tanh(a):
+        if config.cpp.use_decompose_tanh:
+            vec_one = f"decltype({a})(1)"
+            vec_two = f"decltype({a})(2)"
+            vec_minus_two = f"decltype({a})(-2)"
+            return (
+                f"{vec_two} / ({vec_one} + ({vec_minus_two} * {a}).exp()) - {vec_one}"
+            )
+        else:
+            return f"{a}.tanh()"
+
+    @staticmethod
+    def reciprocal(a):
+        return f"{a}.reciprocal()"
+
+    @staticmethod
+    def atan(x):
+        return f"{x}.atan()"
+
+    @staticmethod
+    def acos(x):
+        return f"{x}.acos()"
+
+    @staticmethod
+    def asin(x):
+        return f"{x}.asin()"
+
+    @staticmethod
+    def cosh(x):
+        return f"{x}.cosh()"
+
+    @staticmethod
+    def sinh(x):
+        return f"{x}.sinh()"
+
+    @staticmethod
+    def log10(x):
+        return f"{x}.log10()"
+
+    @staticmethod
+    def log2(x):
+        return f"{x}.log2()"
+
+    @staticmethod
+    def nextafter(x, y):
+        return f"{x}.nextafter({y})"
+
+    @staticmethod
+    def copysign(a, b):
+        return f"{a}.copysign({b})"
+
+    @staticmethod
+    def atan2(a, b):
+        return f"{a}.atan2({b})"
+
+    @staticmethod
+    def hypot(a, b):
+        return f"{a}.hypot({b})"
+
+    @staticmethod
+    def atanh(x):
+        # For real x, atanh(x) = 1/2 * log((1+x)/(1-x))
+        vec_one = f"decltype({x})(1)"
+        vec_one_half = f"decltype({x})(0.5)"
+        return f"{vec_one_half} * (({vec_one} + {x})/({vec_one} - {x})).log()"
+
+    @staticmethod
+    def asinh(x):
+        return f"{x}.asinh()"
+
+    @staticmethod
+    def acosh(x):
+        return f"{x}.acosh()"
+
+    @staticmethod
+    def relu(x):
+        bug = config.cpp.inject_relu_bug_TESTING_ONLY
+        if bug == "compile_error":
+            return "compile error!"
+        elif bug == "runtime_error":
+            return f"{x}; throw 1"
+        elif bug == "accuracy":
+            return f"{x} + decltype({x})(1)"
+        elif bug is None:
+            return f"at::vec::clamp_min({x}, decltype({x})(0))"
+        else:
+            raise AssertionError(
+                f"unrecognized config cpp.inject_relu_bug_TESTING_ONLY = {bug!r}"
+            )
+
+    # TODO: this seems to be dead
+    @staticmethod
+    def sigmoid(x):
+        return f"decltype({x})(1)/(decltype({x})(1) + {x}.neg().exp())"
+
+    @staticmethod
+    def neg(x):
+        return f"{x}.neg()"
+
+    @staticmethod
+    def floordiv(a, b):
+        if is_float_dtype(a.dtype):
+            assert a.dtype == b.dtype, (
+                "div_floor_floating_vec implementation expect the same inputs' dtype."
+            )
+            return f"div_floor_floating_vec({a}, {b})"
+        else:
+            assert all(is_integer_dtype(item.dtype) for item in [a, b])
+            # a and b are integer type
+            _t = f"decltype({a})"
+            if V.kernel._get_raw_num_vectors(b.dtype) < 1:
+                # Doing blend to set the remaining bits of b to non-zero
+                b = f"{_t}::blend<{(1 << V.kernel.tiling_factor) - 1}>({_t}(1), {b})"
+            quot = f"{a} / {b}"
+            has_rem = f"({a} % {b} != {_t}(0))"
+            is_neg = f"(({a} < {_t}(0)) != ({b} < {_t}(0)))"
+            return f"{_t}::blendv({quot}, {quot} - {_t}(1), {has_rem} & {is_neg})"
+
+    @staticmethod
+    def truncdiv(a, b):
+        # a and b are integer type
+        if V.kernel._get_raw_num_vectors(b.dtype) < 1:
+            # Doing blend to set the remaining bits of b to non-zero
+            _t = f"decltype({b})"
+            b = f"{_t}::blend<{(1 << V.kernel.tiling_factor) - 1}>({_t}(1), {b})"
+        return f"{a} / {b}"
+
+    @staticmethod
+    def minimum(a, b):
+        if a.dtype == torch.bool:
+            assert b.dtype == torch.bool
+            a_cast, b_cast = unify_mask_base_type(V.kernel.compute, (a, b))
+            return f"{a_cast} & {b_cast}"
+        else:
+            return f"at::vec::minimum({a}, {b})"
+
+    @staticmethod
+    def maximum(a, b):
+        if a.dtype == torch.bool:
+            assert b.dtype == torch.bool
+            a_cast, b_cast = unify_mask_base_type(V.kernel.compute, (a, b))
+            return f"{a_cast} | {b_cast}"
+        else:
+            return f"at::vec::maximum({a}, {b})"
+
+    @staticmethod
+    def square(a):
+        return f"{a} * {a}"
+
+    @staticmethod
+    def where(a, b, c):
+        assert isinstance(V.kernel, CppVecKernel)
+        if b.dtype == torch.bool:
+            assert c.dtype == torch.bool
+            blendv_a, blendv_b, blendv_c = unify_mask_base_type(
+                V.kernel.compute, (a, b, c)
+            )
+            return f"decltype({blendv_b})::blendv({blendv_c}, {blendv_b}, {blendv_a})"
+        else:
+            return f"decltype({b})::blendv({c}, {b}, {V.kernel._get_mask_cast(a, b.dtype)})"
+
+    @staticmethod
+    def sign(x):
+        code = BracesBuffer()
+        vec_zero = f"decltype({x})(0)"
+        vec_one = f"decltype({x})(1)"
+        blendv_l = f"decltype({x})::blendv({vec_zero}, {vec_one}, {vec_zero} < {x})"
+        blendv_r = f"decltype({x})::blendv({vec_zero}, {vec_one}, {x} < {vec_zero})"
+        code.writeline("[&]()")
+        with code.indent():
+            code.writeline(f"auto left = {blendv_l};")
+            code.writeline(f"auto right = {blendv_r};")
+            code.writeline("return left - right;")
+        code.writeline("()")
+        return code
+
+    @staticmethod
+    def to_dtype(x, dtype, src_dtype=None, use_compute_dtypes=True):
+        assert dtype in [
+            torch.bool,
+            torch.float64,
+            torch.float,
+            torch.bfloat16,
+            torch.float16,
+            torch.uint8,
+            torch.int8,
+            torch.int32,
+            torch.int64,
+            torch.float8_e4m3fn,
+            torch.float8_e5m2,
+        ], f"{__name__} does not support {dtype}"
+        assert isinstance(x, CppCSEVariable)
+        src_dtype = x.dtype
+        expr = V.kernel.get_to_dtype_expr(x, dtype, src_dtype)
+        csevar = V.kernel.cse.generate(V.kernel.compute, expr)
+        csevar.update_on_args("to_dtype", (x, dtype), {"src_dtype": src_dtype})
+        if dtype in DTYPE_LOWP_FP and src_dtype == torch.float:
+            V.kernel.cache_dtype_convert(x, src_dtype, csevar, dtype)
+        return csevar
+
+    @staticmethod
+    def log1p(x):
+        bug = config.cpp.inject_log1p_bug_TESTING_ONLY
+        if bug == "accuracy":
+            return f"{x} + decltype({x})(1)"
+        elif bug is None:
+            return f"{x}.log1p()"
+        else:
+            raise AssertionError(
+                f"unrecognized config cpp.inject_log1p_bug_TESTING_ONLY = {bug!r}"
+            )
+
+    @staticmethod
+    def masked(mask, body, other):
+        assert isinstance(V.kernel, CppVecKernel)
+        code = BracesBuffer()
+        var = V.kernel.cse.newvar()
+        with V.kernel.masked(mask) as new_mask:
+            code.writeline(f"auto {var} = [&]")
+            with V.kernel.swap_buffers(code), code.indent():
+                result = body()
+                code.writeline(f"return {result};")
+        code.writeline(";")
+        V.kernel.compute.splice(code)
+
+        dtype = result.dtype
+        body_code = f"{var}()"
+
+        def maskify_or_vecify(code):
+            return (
+                f"{V.kernel._get_mask_type()}::from({code})"
+                if dtype == torch.bool
+                else f"{V.kernel._get_vec_type(dtype)}({code})"
+            )
+
+        if result.is_vec:
+            body_code_vec = body_code
+        else:
+            body_code_vec = maskify_or_vecify(body_code)
+        other_code = value_to_cpp(other, DTYPE_TO_CPP[dtype])
+        # loading bool as VecMask
+        other_code_vec = maskify_or_vecify(other_code)
+        assert isinstance(new_mask, CppCSEVariable), new_mask
+        if new_mask.is_vec:
+            code = BracesBuffer()
+            code.writeline("[&]")
+            with V.kernel.swap_buffers(code), code.indent():
+                code.writeline(f"if ({new_mask}.all_zero())")
+                with code.indent():
+                    code.writeline(f"return {other_code_vec};")
+                code.writeline("else")
+                with code.indent():
+                    # Create cse variable to reuse kernel.overrides.where
+                    body_vec_var = V.kernel.cse.generate(
+                        V.kernel.compute,
+                        body_code_vec,
+                    )
+                    other_vec_var = V.kernel.cse.generate(
+                        V.kernel.compute,
+                        other_code_vec,
+                    )
+                    assert isinstance(body_vec_var, CppCSEVariable), body_vec_var
+                    assert isinstance(other_vec_var, CppCSEVariable), other_vec_var
+                    body_vec_var.dtype = dtype
+                    other_vec_var.dtype = dtype
+                    overrides: type[Union[CppOverrides, CppVecOverrides]] = (
+                        # pyrefly: ignore [bad-assignment]
+                        V.kernel.overrides
+                    )  # type: ignore[has-type]
+                    code.writeline(
+                        f"return {overrides.where(new_mask, body_vec_var, other_vec_var)};"
+                    )
+            code.writeline("()")
+            csevar = V.kernel.cse.generate(
+                V.kernel.compute,
+                code,
+            )
+            result.is_vec = True
+        elif result.is_vec:
+            csevar = V.kernel.cse.generate(
+                V.kernel.compute, f"{mask} ? {body_code_vec} : {other_code_vec}"
+            )
+        else:
+            csevar = V.kernel.cse.generate(
+                V.kernel.compute, f"{mask} ? {body_code} : {other_code}"
+            )
+        # `result` is explicitly added to the args for correct propagation
+        # of relevant itervars and vectorization status.
+        csevar.update_on_args("masked", (mask, body, other, result), {})
+        return csevar
+
+    @staticmethod
+    def index_expr(expr, dtype):
+        assert isinstance(V.kernel, CppVecKernel)
+        index = V.kernel.rename_indexing(expr)
+        tiling_var = V.kernel.itervars[V.kernel.tiling_idx]
+        stride = V.kernel._try_get_const_stride(index, tiling_var)
+        if stride == 0:
+            return CppOverrides.index_expr(expr, dtype)
+        elif stride is not None:
+            idx = V.kernel.cse.generate(
+                V.kernel.compute, cexpr(index), bounds=get_bounds_index_expr(expr)
+            )
+            value = ops.to_dtype(idx, dtype)
+            if isinstance(value, OpsValue):
+                value = value.value
+            csevar = V.kernel.arange(value, stride)
+        else:
+            csevar = V.kernel._load_or_store_non_contiguous(  # type: ignore[assignment]
+                None, index, dtype, V.kernel.compute
+            )
+        # pyrefly: ignore [missing-attribute]
+        csevar.update_on_args("index_expr", (expr, dtype), {})
+        return csevar
+
+    @staticmethod
+    def frexp(x):
+        cache_keys = f"frexp({x})[0]", f"frexp({x})[1]"
+        if all(V.kernel.cse.try_get(cache_key) is not None for cache_key in cache_keys):
+            return tuple(V.kernel.cse.try_get(cache_key) for cache_key in cache_keys)
+
+        cdtype = DTYPE_TO_CPP[x.dtype]
+        size = V.kernel.tail_size if V.kernel.tail_size else V.kernel.tiling_factor
+        code = BracesBuffer()
+        exponent = V.kernel.cse.newvar(dtype=torch.int32)
+        mantissa = V.kernel.cse.newvar(dtype=x.dtype)
+        exponent.update_on_args("frexp", (x,), kwargs={})
+        mantissa.update_on_args("frexp", (x,), kwargs={})
+        n_vec = V.kernel._get_num_vectors(x.dtype)
+        mantissa_t = (
+            f"at::vec::Vectorized<{cdtype}>"
+            if n_vec == 1
+            else f"at::vec::VectorizedN<{cdtype}, {n_vec}>"
+        )
+        code.writeline(
+            f"at::vec::Vectorized {exponent};"
+            if n_vec == 1
+            else f"at::vec::VectorizedN {exponent};"
+        )
+        code.writeline(f"{mantissa_t} {mantissa};")
+        code.writeline("[&]()")
+        with code.indent():
+            code.writeline(
+                f"__at_align__ std::array<{cdtype}, {V.kernel.tiling_factor}> tmpbuf;"
+            )
+            code.writeline(f"{x}.store(tmpbuf.data(), {cexpr_index(size)});")
+            code.writeline(
+                f"__at_align__ std::array tmpbuf_exponent;"
+            )
+            code.writeline(
+                f"__at_align__ std::array<{cdtype}, {V.kernel.tiling_factor}> tmpbuf_mantissa;"
+            )
+            code.writeline(f"for (int i = 0; i < {cexpr_index(size)}; i++)")
+            with code.indent():
+                code.writeline(
+                    "tmpbuf_mantissa[i] = std::frexp(tmpbuf[i], &tmpbuf_exponent[i]);"
+                )
+            code.writeline(
+                f"{exponent} = at::vec::Vectorized::loadu(tmpbuf_exponent.data(), {cexpr_index(size)});"
+                if n_vec == 1
+                else f"{exponent} = at::vec::VectorizedN::loadu(tmpbuf_exponent.data(), {cexpr_index(size)});"
+            )
+            code.writeline(
+                f"{mantissa} = {mantissa_t}::loadu(tmpbuf_mantissa.data(), {cexpr_index(size)});"
+            )
+        code.writeline("();")
+        V.kernel.compute.splice(code)
+        cse_vars = (mantissa, exponent)
+        for cache_key, cse_var in zip(cache_keys, cse_vars):
+            V.kernel.cse.put(cache_key, cse_var)
+        return mantissa, exponent
+
+    @classmethod
+    def _scalarize(cls, scalar_func):
+        def inner(*args, **kwargs):
+            assert not kwargs
+            kernel = V.kernel
+            assert isinstance(kernel, CppVecKernel)
+            code = BracesBuffer()
+            code.writeline("[&]()")
+            vec_dtype = args[0].dtype
+            n_vec = kernel._get_num_vectors(vec_dtype)
+            size = kernel.tail_size if kernel.tail_size else kernel.tiling_factor
+            scalar_args = []
+            cdtype = DTYPE_TO_CPP[vec_dtype]
+            output_mask = scalar_func.__name__ in (
+                "isinf",
+                "isnan",
+                "signbit",
+            )
+            octype = "bool" if output_mask else cdtype
+            octype = (
+                DTYPE_TO_CPP[args[-2]]
+                if (scalar_func.__name__ == "to_dtype_bitcast")
+                else octype
+            )
+            with code.indent():
+                for argidx, arg in enumerate(args):
+                    if isinstance(arg, CppCSEVariable):
+                        assert arg.is_vec
+                        assert arg.dtype == vec_dtype
+                        code.writeline(
+                            f"__at_align__ std::array<{cdtype}, {kernel.tiling_factor}> tmpbuf{argidx};"
+                        )
+                        code.writeline(
+                            f"{arg}.store(tmpbuf{argidx}.data(), {cexpr_index(size)});"
+                        )
+                        scalar_args.append(f"tmpbuf{argidx}[i]")
+                    else:
+                        scalar_args.append(arg)
+                code.writeline(
+                    f"__at_align__ std::array<{octype}, {kernel.tiling_factor}> tmpbuf_out;"
+                )
+                res = scalar_func(*scalar_args)
+                code.writeline(f"for (int i = 0; i < {cexpr_index(size)}; i++)")
+                with code.indent():
+                    code.writeline(f"tmpbuf_out[i] = {res};")
+                load_args = f"tmpbuf_out.data(), {cexpr_index(size)}"
+                if output_mask:
+                    load_fn = f"at::vec::VecMask<{cdtype},{n_vec}>::from"
+                elif n_vec == 1:
+                    load_fn = f"at::vec::Vectorized<{octype}>::loadu"
+                else:
+                    load_fn = f" at::vec::VectorizedN<{octype}, {n_vec}>::loadu"
+                code.writeline(f"return {load_fn}({load_args});")
+            code.writeline("()")
+            return code
+
+        return inner
+
+    @classmethod
+    def _initialize_scalarize(cls):
+        vec_vars = vars(CppVecOverrides)
+        for name, method in vars(CppOverrides).items():
+            if isinstance(method, staticmethod) and name not in vec_vars:
+                func = cls._scalarize(method.__func__)
+                func.__name__ = name
+                setattr(cls, name, staticmethod(func))
+
+
+CppVecOverrides._initialize_pointwise_overrides("cppvec")
+CppVecOverrides._initialize_scalarize()
+
+
+class CppTile2DOverrides(CppVecOverrides):
+    @staticmethod
+    def index_expr(expr, dtype):
+        assert isinstance(V.kernel, CppTile2DKernel)
+        expr = V.kernel.transform_indexing(expr)
+        return CppVecOverrides.index_expr(expr, dtype)
+
+
+class CppKernel(Kernel):
+    """
+    Base class for C++ kernel code generation in PyTorch Inductor.
+    This class is responsible for generating C++ code from the intermediate representation.
+
+    Args:
+        args: Kernel arguments used for code generation
+        num_threads: Number of threads for parallel execution
+    """
+
+    overrides = CppOverrides  # type: ignore[assignment]
+    sexpr = cexpr
+    newvar_prefix = "auto "
+    suffix = ";"
+
+    def __init__(self, args, num_threads):
+        super().__init__(args)
+        # Indicate when this kernel is active, for example
+        # {x0, {24, 26}} -> this kernel is active when x0 >= 24 and x0 < 26
+        self.active_ranges: dict[sympy.Expr, tuple[sympy.Expr, ...]] = {}
+        # Indicate this kernel will be moved under the inner for-loop
+        # See move_code_under_inner_loop
+        self.inner_itervars: list[sympy.Symbol] = []
+        self.call_ranges: Optional[tuple[sympy.Expr, ...]] = None
+        self.ranges: list[sympy.Expr] = []
+        self.itervars: list[sympy.Symbol] = []
+        self.reduction_depth = None
+        self.reduction_prefix = IndentedBuffer()
+        # We need this because when we run "reduction" nodes here, we lack
+        # "loop" information to decide whether we need a scalar init or an array init
+        # in the reduction prefix. Meanwhile, we have other information like
+        # reduction types and dtype to generate the reduction prefix. We record the information
+        # with a callable lambda function, and when we have enough information to finalize
+        # the reduction prefix, we can invoke the functions here with additional information.
+        self.reduction_prefix_generators: list[Callable] = []  # type: ignore[type-arg]
+        self.reduction_suffix = IndentedBuffer()
+        self.parallel_reduction_prefix = IndentedBuffer()
+        self.parallel_reduction_suffix = IndentedBuffer()
+        self.local_reduction_init = IndentedBuffer()
+        self.local_reduction_stores = IndentedBuffer()
+        self.is_reduction = False
+        self.non_parallel_reduction_prefix = IndentedBuffer()
+        self.non_parallel_reduction_suffix = IndentedBuffer()
+        self.reduction_cse = CSE(self.newvar_prefix, self.suffix, name_prefix="tmp_acc")
+        self.welford_helper_cse = CSE(
+            self.newvar_prefix, self.suffix, name_prefix="welford_helper"
+        )
+        self.cascade_helper_cse = CSE(
+            self.newvar_prefix, self.suffix, name_prefix="cascade_helper"
+        )
+        self.preloads = IndentedBuffer()
+        self.poststores = IndentedBuffer()
+        self.num_threads = num_threads  # num_threads the kernel specialized for
+        self.reduction_omp_dec: dict[tuple[str, str], str] = {}
+        self.reduction_var_names: list[str] = []
+
+    def _gen_parallel_reduction_buffers(
+        self,
+        acc,
+        acc_type,
+        reduction_type,
+        dtype,
+        reduction_combine_fn=reduction_combine,
+        reduction_init_fn=reduction_init,
+    ):
+        if config.cpp.dynamic_threads and not self.parallel_reduction_prefix:
+            self.parallel_reduction_prefix.writeline(
+                "int max_threads = omp_get_max_threads();"
+            )
+        acc_local = f"{acc}_local"
+        num_threads = (
+            "max_threads" if config.cpp.dynamic_threads else parallel_num_threads()
+        )
+        acc_local_in_array = f"{acc}_arr[tid]"
+        self.local_reduction_init.writeline(
+            f"{acc_type} {acc_local} = {reduction_init_fn(reduction_type, dtype)};"
+        )
+        self.parallel_reduction_prefix.splice(
+            reduction_prefix_array(
+                acc,
+                acc_type,
+                reduction_type,
+                dtype,
+                num_threads,
+                reduction_init_fn,
+            )
+        )
+        self.local_reduction_stores.writeline(f"{acc_local_in_array} = {acc_local};")
+        self.parallel_reduction_suffix.writelines(
+            [
+                f"for (int tid = 0; tid < {num_threads}; tid++)",
+                "{",
+                f"    {acc} = {reduction_combine_fn(reduction_type, acc, acc_local_in_array, src_dtype=dtype)};",
+                "}",
+            ],
+        )
+
+    def update_stores_with_parallel_reduction(self):
+        for var_name in self.reduction_var_names:
+            replace_acc_name(self.stores, var_name, f"{var_name}_local")
+
+    def gen_body(self, code: Optional[BracesBuffer] = None):
+        assert code is None
+        code = BracesBuffer()
+        with contextlib.ExitStack() as stack:
+            if hasattr(self, "codegen_inner_loops"):
+                code.splice(self.preloads)
+                self.codegen_inner_loops(code)
+                stack.enter_context(code.indent())
+            code.splice(self.loads)
+            code.splice(self.compute)
+            code.splice(self.stores)
+        if hasattr(self, "codegen_inner_loops"):
+            code.splice(self.poststores)
+
+        if self.inner_itervars:
+            for idx in self.inner_itervars:
+                start, end = self.active_ranges[idx]
+                code = move_code_under_inner_loop(code, idx, f"{idx}_tail", start, end)
+        return code
+
+    @contextlib.contextmanager
+    def masked(self, mask):
+        """Context manager to add an additional mask to loads and stores."""
+        prior = self._load_mask
+        if prior:
+            mask = ops.and_(mask, prior)
+            if isinstance(mask, OpsValue):
+                mask = mask.value
+                assert isinstance(mask, CppCSEVariable)
+                # see NOTE [dtype of CppCSEVariable]
+                # mask's dtype should be bool
+                mask.dtype = torch.bool
+
+        # pyrefly: ignore [bad-assignment]
+        self._load_mask = mask
+        try:
+            yield mask
+        finally:
+            self._load_mask = prior
+
+    def scale_index_with_offset(
+        self, index: sympy.Expr, scale=1, itervar_idx=-1, offset=0
+    ):
+        var = self.itervars[itervar_idx]
+        replacement = {var: var * scale + offset}
+        new_index = sympy_subs(index, replacement)
+        return new_index
+
+    def index_to_str(self, index: sympy.Expr) -> str:
+        """
+        Convert an index expr to a string that can be used in cpp code.
+        e.g. a sympy expression "s2" may actually appear as "ks1" in the cpp kernel.
+        """
+        return cexpr(self.rename_indexing(index))
+
+    def index_indirect_depends_on(self, index: sympy.Expr, itervar: sympy.Symbol):
+        """
+        Check if an index has free symbol CppCSEVariable that depends on `itervar`.
+        """
+        return any(
+            self.cse.varname_map[s.name].depends_on(itervar)  # type: ignore[attr-defined]
+            for s in index.free_symbols
+            if s.name in self.cse.varname_map  # type: ignore[attr-defined]
+            and isinstance(self.cse.varname_map[s.name], CppCSEVariable)  # type: ignore[attr-defined]
+        )
+
+    def index_depends_on(self, index: sympy.Expr, itervar: sympy.Symbol):
+        return itervar in index.free_symbols or self.index_indirect_depends_on(
+            index, itervar
+        )
+
+    def var_ranges(self):
+        return dict(zip(self.itervars, self.ranges))
+
+    def check_bounds(
+        self,
+        expr: sympy.Expr,
+        size: sympy.Expr,
+        lower: bool,
+        upper: bool,
+    ):
+        if not (lower or upper):
+            return
+
+        indirect = free_symbol_is_type(expr, SymT.TMP)
+        if indirect:
+            # indexing in compute
+            csevar = ops.index_expr(expr, torch.int64).value
+            buffer = V.kernel.compute
+        else:
+            # indexing in loads
+            prior_compute = V.kernel.compute
+            try:
+                V.kernel.compute = self.loads
+                csevar = ops.index_expr(expr, torch.int64).value
+            finally:
+                V.kernel.compute = prior_compute
+            buffer = self.loads
+
+        size_str = V.kernel.sexpr(self.rename_indexing(size)) if upper else None
+
+        line = self.indirect_assert(
+            csevar, "0" if lower else None, size_str, self._load_mask
+        )
+        self.cse.generate(buffer, line, assignment=False)
+
+    def load(self, name: str, index: sympy.Expr):
+        var = self.args.input(name)
+        index = self.rename_indexing(index)
+        line = f"{var}[{cexpr_index(index)}]"
+        csevar = self.cse.generate(self.loads, line, dtype=V.graph.get_dtype(name))
+        csevar.update_on_args("load", (self, name, index), {})
+        return csevar
+
+    def store(self, name, index, value, mode=None):
+        assert "buf" in name
+        var = self.args.output(name)
+        index = self.rename_indexing(index)
+        if mode is None:
+            line = f"{var}[{cexpr_index(index)}] = {value};"
+        elif mode == "atomic_add":
+            if not config.cpp.dynamic_threads and self.num_threads == 1:
+                line = f"{var}[{cexpr_index(index)}] += {value};"
+            else:
+                dtype = V.graph.get_dtype(name)
+                # mirroring static_cast(...) in load:
+                value = f"static_cast<{DTYPE_TO_CPP[dtype]}>({value})"
+                line = f"atomic_add(&{var}[{cexpr_index(index)}], {value});"
+        else:
+            raise NotImplementedError(f"store mode={mode}")
+        self.stores.writeline(DeferredLine(name, line))
+
+    def device_assert_async(self, cond, msg):
+        self.compute.writeline(
+            f'({cond} ? 0 : (throw std::runtime_error("{msg}"), 0));'
+        )
+
+    def _gen_reduction_prefix(
+        self,
+        acc: Union[CSEVariable, str],
+        acc_type: str,
+        rtype: str,
+        dtype: torch.dtype,
+        init_fn,
+    ):
+        # Generate reduction prefix
+        # If size is None, we will define and initialize a single reduction variable
+        # => float tmp_acc0 = 0;
+        # Otherwise, we will define and initialize a reduction array
+        # => float tmp_acc0_arr[size];
+        # => for (int i = 0; i < size; i++) tmp_acc0_arr[i] = 0;
+        def inner(size: Optional[int] = None):
+            if size is None:
+                return f"{acc_type} {acc} = {init_fn(rtype, dtype)};"
+            else:
+                return reduction_prefix_array(
+                    acc,
+                    acc_type,
+                    rtype,
+                    dtype,
+                    size,
+                    init_fn,
+                )
+
+        return inner
+
+    def finalize_reduction_prefix(self, size: Optional[int] = None):
+        for gen_fn in self.reduction_prefix_generators:
+            self.reduction_prefix.splice(gen_fn(size))
+
+    def need_use_acc_helper(self, reduction_type, dtype, use_scalar):
+        # Check if we need accumulate helper for the reduction operation.
+        # using accumulate helper generates the necessary code to improve precision for
+        # sum and welford
+        # Note: using helper has non-negligible impact on performance
+
+        if reduction_type == "welford_reduce":
+            return True
+
+        # TODO add supports for more data types when needed
+        if reduction_type == "sum" and dtype == torch.float:
+            assert self.call_ranges is not None
+            reduction_size = functools.reduce(
+                operator.mul, self.call_ranges[self.reduction_depth :]
+            )
+
+            # chunk size to balance accuracy and performance
+            chunk_size = 4096
+
+            # use acc helper If cannot get size_hint
+            try:
+                reduction_size_hint = V.graph.sizevars.size_hint(reduction_size)
+            except Exception:
+                return True
+
+            if reduction_size_hint > chunk_size:
+                # use helper if the reduction size is too large
+                V.graph.sizevars.check_lt(chunk_size, reduction_size)
+                return True
+            else:
+                V.graph.sizevars.check_leq(reduction_size, chunk_size)
+        return False
+
+    def _acc_helper_init(
+        self,
+        reduction_type,
+        helper_val,
+        helper_range,
+        dtype,
+        num_threads=None,
+        use_scalar=False,
+    ):
+        num_range_thread = (
+            CeilDiv(helper_range, num_threads) if num_threads else helper_range
+        )
+        num_range_thread_expr = cexpr_index(num_range_thread)
+        assert reduction_type in ["welford_reduce", "sum"]
+        chunk_size = 4096
+        num_chunks = CeilDiv(num_range_thread, chunk_size)
+        helper_type = (
+            "WelfordHelper"
+            if reduction_type == "welford_reduce"
+            else "CascadeSumHelper"
+        )
+        if use_scalar:
+            h_type = DTYPE_TO_CPP[dtype]
+        else:
+            h_type = (
+                self._get_vec_type(dtype)
+                if hasattr(self, "_get_vec_type")
+                else DTYPE_TO_CPP[dtype]
+            )
+        helper_init_line = (
+            f"{helper_type}<{h_type}, {chunk_size}> {helper_val}"
+            f"("
+            f"{num_range_thread_expr}"
+            f");"
+        )
+        if reduction_type == "sum":
+            return helper_init_line
+        if isinstance(num_chunks, sympy.Integer) and num_chunks <= 1:
+            # When the number of chunks <= 1, there is no need to use cascade summation to improve
+            # reduction accuracy. We can initialize a static WelfordHelper to improve performance.
+            return f"static {helper_init_line}"
+        else:
+            return helper_init_line
+
+    def _use_acc_helper(
+        self, reduction_type, acc, helper_val, helper_range, dtype, use_scalar=False
+    ):
+        num_threads = (
+            "max_threads" if config.cpp.dynamic_threads else parallel_num_threads()
+        )
+        self.non_parallel_reduction_prefix.writeline(
+            self._acc_helper_init(
+                reduction_type, helper_val, helper_range, dtype, None, use_scalar
+            )
+        )
+        self.local_reduction_init.writeline(
+            self._acc_helper_init(
+                reduction_type, helper_val, helper_range, dtype, num_threads, use_scalar
+            )
+        )
+        result = acc if use_scalar else f"{acc}_vec"
+        if reduction_type == "welford_reduce":
+            self.non_parallel_reduction_suffix.writeline(
+                f"{result} = welford_combine({result}, &{helper_val});"
+            )
+            self.local_reduction_stores.writeline(
+                f"{result}_local = welford_combine({result}_local, &{helper_val});"
+            )
+        else:
+            self.non_parallel_reduction_suffix.writeline(
+                f"{result} = cascade_sum_final(&{helper_val});"
+            )
+            self.local_reduction_stores.writeline(
+                f"{result}_local = cascade_sum_final(&{helper_val});"
+            )
+
+    def reduction(self, dtype, src_dtype, reduction_type, value):
+        argmax_or_argmin = reduction_type in ("argmax", "argmin")
+        reduction_key = src_dtype, reduction_type, value
+        if reduction_key in self.reduction_cse.reduction_cache:
+            return self.reduction_cse.reduction_cache[reduction_key]
+
+        acc = self.reduction_cse.generate(
+            self.loads, f"reduction {reduction_key}", write=False
+        )
+        self.reduction_var_names.append(f"{acc}")
+        self.is_reduction = True
+        init_dtype = src_dtype if argmax_or_argmin else dtype
+        acc_type = reduction_acc_type(reduction_type, init_dtype)
+        self.reduction_prefix_generators.append(
+            self._gen_reduction_prefix(
+                acc, acc_type, reduction_type, init_dtype, reduction_init
+            )
+        )
+
+        if self.need_use_acc_helper(reduction_type, dtype, True):
+            # use cascade_helper for vec kernel
+            reduction_size = functools.reduce(
+                operator.mul, self.ranges[self.reduction_depth :]
+            )
+            # use welford_helper/cascade_helper for vec kernel
+            if reduction_type == "welford_reduce":
+                helper_val = self.welford_helper_cse.generate(
+                    self.compute, f"reduction {reduction_key}", write=False
+                )
+            else:
+                helper_val = self.cascade_helper_cse.generate(
+                    self.compute, f"reduction {reduction_key}", write=False
+                )
+            # rename the helper variable to distinguish it from vectorized version
+            scalar_helper_val = f"scalar_{helper_val}"
+            self._use_acc_helper(
+                reduction_type,
+                acc,
+                scalar_helper_val,
+                reduction_size,
+                dtype,
+                use_scalar=True,
+            )
+            self.stores.writeline(
+                f"{acc} = {reduction_combine(reduction_type, acc, value, scalar_helper_val)};"
+            )
+        else:
+            assert self.reduction_depth is not None
+            index = self.itervars[self.reduction_depth]
+            for i in range(self.reduction_depth + 1, len(self.itervars)):
+                index = index * self.ranges[i] + self.itervars[i]
+            self.stores.writeline(
+                f"{acc} = {reduction_combine(reduction_type, acc, value, index=index)};"
+            )
+
+        self._gen_parallel_reduction_buffers(acc, acc_type, reduction_type, init_dtype)
+        result = reduction_project(reduction_type, acc)
+        self.reduction_cse.reduction_cache[reduction_key] = result
+        return result
+
+    def store_reduction(self, name, index, value):
+        index = self.rename_indexing(index)
+        var = self.args.output(name)
+        self.reduction_suffix.writeline(
+            DeferredLine(name, f"{var}[{cexpr_index(index)}] = {value};")
+        )
+
+    def set_ranges(self, lengths, reduction_lengths):
+        if self.call_ranges:
+            assert self.call_ranges == tuple(lengths) + tuple(reduction_lengths), (
+                f"{self.call_ranges} == {tuple(lengths)} + {tuple(reduction_lengths)}"
+            )
+            assert self.reduction_depth == len(lengths)
+        else:
+            self.call_ranges = tuple(lengths) + tuple(reduction_lengths)
+            self.ranges = [self.rename_indexing(x) for x in self.call_ranges]
+            self.itervars = [
+                sympy_index_symbol_with_prefix(SymT.XBLOCK, n)
+                for n in range(len(self.ranges))
+            ]
+            # pyrefly: ignore [bad-assignment]
+            self.reduction_depth = len(lengths)
+        return (
+            self.itervars[: self.reduction_depth],
+            self.itervars[self.reduction_depth :],
+        )
+
+    def size_hint(self):
+        assert self.call_ranges is not None
+        return V.graph.sizevars.size_hint(
+            sympy_product(self.call_ranges), fallback=8192
+        )
+
+    def codegen_loops_impl(self, loop_nest, code, worksharing):
+        assert isinstance(self, CppKernelProxy)
+        threads = parallel_num_threads()
+        assert self.call_ranges is not None
+        if isinstance(loop_nest.kernel, OuterLoopFusedKernel):
+            par_depth = loop_nest.kernel.decide_parallel_depth(
+                loop_nest.max_parallel_depth(), threads
+            )
+        else:
+            par_depth = self.decide_parallel_depth(
+                loop_nest.max_parallel_depth(), threads
+            )
+
+        is_reduction_loop = (
+            loop_nest.loops is not None
+            and loop_nest.loops[par_depth.start_depth].is_reduction
+        )
+        with contextlib.ExitStack() as stack:
+            if par_depth.parallel_depth:
+                if is_reduction_loop:
+                    # need to close the worksharing scope to define reduction vars outside it
+                    worksharing.close()
+                else:
+                    worksharing.parallel(threads)
+                loop_nest.mark_parallel(par_depth)
+            elif threads > 1:
+                if worksharing.single():
+                    stack.enter_context(code.indent())
+
+            def gen_kernel(_loop_nest: LoopNest):
+                def is_parallel_reduction():
+                    assert _loop_nest.loops
+                    root = _loop_nest.loops[par_depth.start_depth]
+                    return root.is_reduction and root.parallel
+
+                kernel = _loop_nest.get_kernel()
+                if isinstance(kernel, OuterLoopFusedKernel):
+                    for _loop_nest in kernel.inner:
+                        gen_loop_nest(_loop_nest)
+                else:
+                    assert isinstance(kernel, CppKernelProxy)
+                    if _loop_nest.loops is not None and is_parallel_reduction():
+                        kernel.update_stores_with_parallel_reduction()
+                    with contextlib.ExitStack() as stack:
+                        stack.enter_context(code.indent())
+                        kernel.gen_body(code)
+
+            def get_reduction_prefix_suffix(kernel, parallel=False, is_suffix=False):
+                if is_suffix:
+                    suffix = kernel.reduction_suffix
+                    if parallel:
+                        suffix = kernel.parallel_reduction_suffix + suffix
+                    else:
+                        suffix = kernel.non_parallel_reduction_suffix + suffix
+                    return suffix
+                else:
+                    prefix = kernel.reduction_prefix
+                    if parallel:
+                        prefix = prefix + kernel.parallel_reduction_prefix
+                    else:
+                        prefix = prefix + kernel.non_parallel_reduction_prefix
+                    return prefix
+
+            def gen_loop_with_reduction(
+                _loop_nest: LoopNest, depth: int = 0, in_reduction=False
+            ):
+                kernel = _loop_nest.get_kernel()
+                assert _loop_nest.loops
+                loop = _loop_nest.loops[depth]
+                with contextlib.ExitStack() as stack_outer:
+                    if loop.is_reduction and not in_reduction:
+                        reduction_prefix = get_reduction_prefix_suffix(
+                            kernel, loop.parallel, is_suffix=False
+                        )
+                        if reduction_prefix:
+                            stack_outer.enter_context(code.indent())
+                        code.splice(reduction_prefix)
+                    if is_reduction_loop and loop.parallel:
+                        worksharing.parallel(threads)
+                        if kernel.local_reduction_init:
+                            assert kernel.local_reduction_stores
+                            code.splice(kernel.local_reduction_init)
+
+                    gen_loop_at(_loop_nest, depth)
+
+                    if is_reduction_loop and loop.parallel:
+                        if kernel.local_reduction_stores:
+                            code.splice(kernel.local_reduction_stores)
+                        worksharing.close()
+                    if loop.is_reduction and not in_reduction:
+                        code.splice(
+                            get_reduction_prefix_suffix(
+                                kernel, loop.parallel, is_suffix=True
+                            )
+                        )
+
+            def gen_loop_at(_loop_nest: LoopNest, depth: int = 0):
+                with contextlib.ExitStack() as stack:
+                    assert _loop_nest.loops
+                    loop = _loop_nest.loops[depth]
+                    loop_lines = loop.lines()
+                    if loop_lines is None:
+                        return
+                    code.writelines(loop_lines)
+                    stack.enter_context(code.indent())
+                    gen_loop_nest(_loop_nest, depth + 1, loop.is_reduction)
+
+            def gen_loop_nest(
+                _loop_nest: LoopNest,
+                depth: int = 0,
+                in_reduction: bool = False,
+            ):
+                if _loop_nest.loops is None or depth == len(_loop_nest.loops):  # type: ignore[arg-type]
+                    gen_kernel(_loop_nest)
+                else:
+                    gen_loop_with_reduction(_loop_nest, depth, in_reduction)
+
+            stack.enter_context(code.indent())
+
+            if (
+                isinstance(loop_nest.kernel, OuterLoopFusedKernel)
+                and isinstance(V.local_buffer_context, LocalBufferContext)
+                and V.local_buffer_context.local_buffers
+            ):
+                # Allocate local buffer
+                local_buffers = V.local_buffer_context.local_buffers
+                for local_buffer in local_buffers.values():
+                    # For dynamic size, rename s to ks
+                    local_buf_size = sympy_product(
+                        [
+                            self.rename_indexing(size_val)
+                            for size_val in local_buffer.get_layout().size
+                        ]
+                    )
+                    local_buf_dtype = DTYPE_TO_CPP[local_buffer.get_layout().dtype]
+                    allocate = f"std::make_unique<{local_buf_dtype} []>({cexpr(local_buf_size)})"
+                    local_buffer_name = local_buffer.get_name()
+                    code.splice(
+                        f"std::unique_ptr<{local_buf_dtype} []> buf_{local_buffer_name} = {allocate};"
+                    )
+                    code.splice(
+                        f"{local_buf_dtype}* {local_buffer_name} = buf_{local_buffer_name}.get();"
+                    )
+            gen_loop_nest(loop_nest)
+
+    def codegen_loops(self, code, worksharing):
+        loop_nest = LoopNest.build(self)
+        self.codegen_loops_impl(loop_nest, code, worksharing)
+
+    @property
+    def assert_function(self) -> str:
+        if V.graph.aot_mode:
+            return "AOTI_TORCH_CHECK"
+        else:
+            return "TORCH_CHECK"
+
+    def decide_parallel_depth(self, max_parallel_depth, threads):
+        assert self.call_ranges is not None
+        ranges = self.call_ranges[
+            max_parallel_depth.start_depth : (
+                max_parallel_depth.start_depth + max_parallel_depth.parallel_depth
+            )
+        ]
+        seq = self.size_hint()
+        par = 1
+        depth = 0
+        for expr in ranges:
+            hint = V.graph.sizevars.size_hint(expr, fallback=8192)
+            if par >= 2 * threads or par == threads:
+                break
+            if seq // threads < config.cpp.min_chunk_size:
+                # not enough work
+                break
+            depth += 1
+            par *= hint
+            seq /= hint
+        # if we assume thread number is dynamic, make sure we
+        # have at least one parallel scope and let OMP runtime
+        # to manage the serial vs. parallel.
+        if config.cpp.dynamic_threads and depth == 0 and len(ranges) > 0:
+            depth = 1
+        return ParallelDepth(
+            parallel_depth=depth, start_depth=max_parallel_depth.start_depth
+        )
+
+    @contextlib.contextmanager
+    def write_to_suffix(self):
+        prior = (self.loads, self.compute, self.stores, self.cse)
+        self.loads = IndentedBuffer()
+        self.compute = IndentedBuffer()
+        self.stores = IndentedBuffer()
+        self.cse = self.cse.clone()
+        yield
+        self.reduction_suffix.splice(self.loads)
+        self.reduction_suffix.splice(self.compute)
+        self.reduction_suffix.splice(self.stores)
+        (self.loads, self.compute, self.stores, self.cse) = prior
+
+    def create_cse_var(self, *args, **kwargs):
+        return CppCSEVariable(*args, **kwargs)
+
+    def get_to_dtype_expr(self, src, dtype, src_dtype):
+        return f"c10::convert<{DTYPE_TO_CPP[dtype]}>({src})"
+
+    def cache_dtype_convert(self, dst, dst_dtype, src, src_dtype):
+        expr = self.get_to_dtype_expr(src, dst_dtype, src_dtype)
+        self.cse.put(expr, dst)
+
+    def codegen_conditions(
+        self,
+        code: BracesBuffer,
+        prefix: Optional[str] = None,
+        var: Optional[sympy.Symbol] = None,
+    ):
+        if prefix is None:
+            prefix = ""
+        if not self.active_ranges:
+            return True
+        conditions = []
+
+        def gen(start, end, var):
+            if start == end:
+                return False
+            var_id = None
+            for i, _var in enumerate(self.itervars):
+                if var == _var:
+                    var_id = i
+                    break
+            if (
+                type(self) is CppKernel
+                and var_id
+                and start == 0
+                and end == self.ranges[var_id]
+            ):
+                end = 1
+            # pyrefly: ignore [bad-argument-type]
+            conditions.append(f"{var} >= {cexpr_index(start)}")
+            # pyrefly: ignore [bad-argument-type]
+            conditions.append(f"{var} < {cexpr_index(end)}")
+            return True
+
+        if var is not None:
+            assert var in self.active_ranges
+            start, end = self.active_ranges[var]
+            if not gen(start, end, var):
+                return False
+        else:
+            for _var, _range in self.active_ranges.items():
+                start, end = _range
+                if not gen(start, end, _var):
+                    return False
+        joined_conditions = " && ".join(conditions)
+        if joined_conditions:
+            code.writeline(f"if({prefix}({joined_conditions}))")
+            return True
+        else:
+            return False
+
+
+class CppVecKernel(CppKernel):
+    overrides = CppVecOverrides  # type: ignore[assignment]
+
+    def __init__(
+        self,
+        args,
+        num_threads,
+        tiling_factor,
+        tiling_idx,
+        tail_size=None,
+    ):
+        super().__init__(args, num_threads)
+        self.vec_isa = cpu_vec_isa.pick_vec_isa()
+        assert self.vec_isa
+        assert tiling_factor > 0, "Expect pass in Non-Zero tiling_factor explicitly"
+        self.tiling_factor = tiling_factor
+        self.tiling_idx = tiling_idx
+        self.tail_size = tail_size
+        self.num_elems = tail_size if tail_size else tiling_factor
+
+    def _try_get_const_stride(self, index: sympy.Expr, itervar: sympy.Symbol):
+        if self.index_indirect_depends_on(index, itervar):
+            return None
+        for indirect_var in (
+            self.cse.varname_map[s.name]  # type: ignore[attr-defined]
+            for s in index.free_symbols
+            if symbol_is_type(s, SymT.TMP)
+        ):
+            assert isinstance(indirect_var, CppCSEVariable)
+            if indirect_var.is_vec:
+                return None
+        stride = stride_at_vec_range(index, itervar, self.tiling_factor)
+        return stride if stride.is_number else None
+
+    def _get_num_vectors(self, dtype: torch.dtype) -> int:
+        num_vectors = math.ceil(
+            self.tiling_factor * dtype.itemsize * 8 / self.vec_isa.bit_width()
+        )
+        assert num_vectors >= 1
+        return num_vectors
+
+    def _get_raw_num_vectors(self, dtype: torch.dtype) -> float:
+        # This utility function is used to check if the vector lanes has been
+        # fully utilized. For example, uint8 will only use 1/4 of the vector lanes.
+        return self.tiling_factor * dtype.itemsize * 8 / self.vec_isa.bit_width()
+
+    def _get_vec_type(self, dtype: torch.dtype) -> str:
+        num_vectors = self._get_num_vectors(dtype)
+        if num_vectors == 1:
+            return f"at::vec::Vectorized<{DTYPE_TO_CPP[dtype]}>"
+        else:
+            return f"at::vec::VectorizedN<{DTYPE_TO_CPP[dtype]},{num_vectors}>"
+
+    def _get_mask_type(self, dtype: torch.dtype = torch.float) -> str:
+        if dtype == torch.bool:
+            return ""
+        num_vectors = self._get_num_vectors(dtype)
+        return f"at::vec::VecMask<{DTYPE_TO_CPP[dtype]},{num_vectors}>"
+
+    def _get_mask_cast(self, mask: CppCSEVariable, dtype: torch.dtype) -> str:
+        assert mask.dtype == torch.bool, repr(mask)
+        num_vectors = self._get_num_vectors(dtype)
+        return f"{mask}.template cast<{DTYPE_TO_CPP[dtype]},{num_vectors}>()"
+
+    def _get_vec_load_line(
+        self,
+        var: str,
+        index: sympy.Expr,
+        dtype: torch.dtype,
+        load_mask: Optional[CppCSEVariable] = None,
+    ):
+        """
+        Get a load line str that loads a vector from `var` at `index` of type `dtype`.
+        If `load_mask` is not None, we do a masked load accordingly.
+        Notes on the `dtype`:
+        1. We always load `self.tiling_factor` number of elements regardless of the `dtype`.
+           It means we load half of the vector lanes for 16-bit data types and quarter of the
+           vector lanes for 8-bit data types.
+        2. `torch.bool` and `torch.uint8` could mean masks and we load them as float mask vectors.
+        """
+        cpp_type = DTYPE_TO_CPP[dtype]
+        num_vectors = self._get_num_vectors(dtype)
+        load_mask_str = None
+        if load_mask:
+            if not load_mask.is_vec:
+                # TODO: avoid hard-code torch.float
+                load_mask_str = f"{self._get_mask_type(torch.float)}::from({load_mask})"
+            else:
+                load_mask_str = f"{self._get_mask_cast(load_mask, torch.float)}"
+        loadbuf = f"{var} + {cexpr_index(index)}" if index != 0 else var
+        if dtype == torch.bool:
+            # TODO: should we consider load mask here?
+            line = f"{self._get_mask_type()}::from({loadbuf}, {cexpr_index(self.num_elems)})"
+        else:
+            line = (
+                f"{load_mask_str}.template loadu<{cpp_type},{num_vectors}>({loadbuf})"
+                if load_mask_str
+                else f"{self._get_vec_type(dtype)}::loadu({loadbuf}, {cexpr_index(self.num_elems)})"
+            )
+        return line
+
+    def _load_or_store_non_contiguous(
+        self,
+        var: Optional[str],
+        index: sympy.Expr,
+        dtype: torch.dtype,
+        buffer: Optional[IndentedBuffer] = None,
+        store_value: Optional[Union[str, CppCSEVariable]] = None,
+        accu_store: bool = False,
+    ) -> Optional[CppCSEVariable]:
+        """
+        Load or store a vector in a non-contiguous way. The vector is initialized from an array that is
+        filled in an inner loop over the tiling factor.
+        :param var: buffer to load from or store to, i.e. `var[transformed(index)]`. If None, we load the index
+                    as index expression, i.e. `transformed(index)`.
+        :param index: index into the `var` or the index expression by its own if `var` is None.
+                      The `index` could contain indirect indexing or the tiling itervar. When used in
+                      the inner loop, the index is transformed as follows:
+                      1. the index is linearized along the tiling dim.
+                      2. the indirect indexing vector variables are transformed into arrays over the tiling dim.
+        :param dtype: data type of `var` or `index` if `var` is None.
+        :param buffer: the code buffer to write the generated code to. If None, we write to `self.loads`.
+        :param store_value: the value to store. If None, we load the vector.
+        :param accu_store: whether accumulate the store_value to store_ptr. If True, a store_value should be provided
+        :return: a CppCSEVariable that represents the loaded vector or None if it is a store.
+        """
+        assert not store_value or var is not None, "store var must be provided"
+        if accu_store:
+            assert store_value
+        if buffer is None:
+            buffer = self.loads
+
+        def get_result_size(dtype: torch.dtype) -> int:
+            if dtype.itemsize < 4:
+                return self.num_elems * (4 // dtype.itemsize)
+            else:
+                return self.num_elems
+
+        def get_tiling_size(dtype: torch.dtype) -> int:
+            if dtype.itemsize < 4:
+                return self.tiling_factor * (4 // dtype.itemsize)
+            else:
+                return self.tiling_factor
+
+        def vec_to_array(vec_var: CppCSEVariable) -> CppCSEVariable:
+            assert vec_var.is_vec
+            code = BracesBuffer()
+            code.writeline("[&]")
+            with code.indent():
+                vec_dtype = vec_var.dtype
+                assert vec_dtype is not None
+                if vec_dtype == torch.bool:
+                    vec_dtype = torch.float
+                result_size = get_result_size(vec_dtype)
+                tiling_size = get_tiling_size(vec_dtype)
+                code.writeline(
+                    f"__at_align__ std::array<{DTYPE_TO_CPP[vec_dtype]}, {tiling_size}> tmpbuf;"
+                )
+                line = f"{vec_var}.store(tmpbuf.data(), {cexpr_index(result_size)});"
+                code.writeline(line)
+                code.writeline("return tmpbuf;")
+            code.writeline("()")
+            csevar = self.cse.generate(buffer, code)
+            assert isinstance(csevar, CppCSEVariable)
+            return csevar
+
+        code = BracesBuffer()
+        code.writeline("[&]")
+        with code.indent():
+            result_size = get_result_size(dtype)
+            tiling_size = get_tiling_size(dtype)
+            result_declare = (
+                f"__at_align__ std::array<{DTYPE_TO_CPP[dtype]}, {tiling_size}> tmpbuf;"
+            )
+            code.writeline(result_declare)
+            if store_value:
+                code.writeline(
+                    f"{store_value}.store(tmpbuf.data(), {cexpr_index(result_size)});"
+                )
+            itervar_inner = sympy_index_symbol(
+                f"{self.itervars[self.tiling_idx]}_inner"
+            )
+            replacements = {}
+            for indirect_var in (
+                self.cse.varname_map[s.name]  # type: ignore[attr-defined]
+                for s in index.free_symbols
+                if symbol_is_type(s, SymT.TMP)
+            ):
+                assert isinstance(indirect_var, CppCSEVariable)
+                if indirect_var.is_vec:
+                    array_var = vec_to_array(indirect_var)
+                    replacements[indirect_var] = f"{array_var}[{itervar_inner}]"
+            index = self.scale_index_with_offset(
+                index, itervar_idx=self.tiling_idx, offset=itervar_inner
+            )
+            load_mask = None
+            if self._load_mask is not None:
+                assert not store_value, "unexpected store with load mask"
+                assert isinstance(self._load_mask, CppCSEVariable), self._load_mask
+                if self._load_mask.is_vec:
+                    load_mask = f"{self._load_mask}.is_masked({itervar_inner})"
+                else:
+                    load_mask = f"{self._load_mask} != 0"
+            if cpp_builder.is_gcc():
+                code.writeline(f"#pragma GCC unroll {self.tiling_factor}")
+            else:
+                code.writeline(f"#pragma unroll {self.tiling_factor}")
+            code.writeline(
+                f"for (long {itervar_inner} = 0; "
+                + f"{itervar_inner} < {cexpr_index(self.num_elems)}; "
+                + f"{itervar_inner}++)"
+            )
+            with code.indent(), contextlib.ExitStack() as stack:
+                index_c = cexpr_index(index)
+                for indirect_var in replacements:
+                    index_c = re.sub(
+                        r"\b" + f"{indirect_var}" + r"\b",
+                        replacements[indirect_var],
+                        index_c,
+                    )
+                rhs = f"{var}[{index_c}]" if var is not None else f"{index_c}"
+                if load_mask:
+                    code.writeline(f"if ({load_mask})")
+                    stack.enter_context(code.indent())
+                if store_value:
+                    conjunction = "+=" if accu_store else "="
+                    code.writeline(f"{rhs} {conjunction} tmpbuf[{itervar_inner}];")
+                else:
+                    code.writeline(f"tmpbuf[{itervar_inner}] = {rhs};")
+            if not store_value:
+                load_line = self._get_vec_load_line("tmpbuf.data()", 0, dtype)  # type: ignore[arg-type]
+                code.writeline(f"return {load_line};")
+        code.writeline("()")
+        if store_value:
+            code.writeline(";")
+            buffer.splice(code)
+            return None
+        else:
+            csevar = self.cse.generate(buffer, code, dtype=dtype)
+            assert isinstance(csevar, CppCSEVariable)
+            csevar.is_vec = True
+            return csevar
+
+    def load(self, name: str, index: sympy.Expr):
+        var = self.args.input(name)
+        index = self.rename_indexing(index)
+        dtype = V.graph.get_dtype(name)
+        tiling_var = self.itervars[self.tiling_idx]
+        stride = self._try_get_const_stride(index, tiling_var)
+        if stride == 0:
+            # load scalar and lazily broadcast it on demand
+            return super().load(name, index)
+        elif stride == 1:
+            # load contiguously
+            line = self._get_vec_load_line(var, index, dtype, self._load_mask)  # type: ignore[arg-type]
+            csevar = self.cse.generate(self.loads, line, dtype=dtype)  # type: ignore[assignment]
+        else:
+            csevar = self._load_or_store_non_contiguous(var, index, dtype)  # type: ignore[assignment]
+        assert isinstance(csevar, CppCSEVariable)
+        csevar.update_on_args("load", (self, name, index), {})
+        csevar.is_vec = True
+        return csevar
+
+    def _get_store_line(
+        self,
+        value: Union[str, CppCSEVariable],
+        var: str,
+        index: sympy.Expr,
+        dtype: torch.dtype,
+        accu_store: bool = False,
+    ):
+        """
+        Get a store line buffer that stores `value` into `var` at `index` of `dtype`. It handles
+        both contiguous and non-contiguous store cases.
+        :param value: Vectorized type templaterized on `dtype`.
+        :param var: buffer to store into.
+        :index: index into the `var`.
+        """
+        # when value's type is str (e.g., welford reduction), caller should make sure
+        # it is a vector
+        assert isinstance(value, str) or (
+            isinstance(value, CppCSEVariable) and value.is_vec
+        ), value
+        tiling_var = self.itervars[self.tiling_idx]
+        var_expr = f"{var} + {cexpr_index(index)}"
+        stride = self._try_get_const_stride(index, tiling_var)
+        code = IndentedBuffer()
+        if stride == 1:
+            if accu_store:
+                load = (
+                    f"{self._get_vec_type(dtype)}::loadu({var_expr})"
+                    if dtype == torch.float and self.tail_size is None
+                    else f"{self._get_vec_type(dtype)}::loadu({var_expr}, {cexpr_index(self.num_elems)})"
+                )
+                value = f"({value} + {load})"
+            if dtype == torch.float and self.tail_size is None:
+                code.writeline(f"{value}.store({var_expr});")
+            else:
+                code.writeline(
+                    f"{value}.store({var_expr}, {cexpr_index(self.num_elems)});"
+                )
+        else:
+            self._load_or_store_non_contiguous(
+                var, index, dtype, buffer=code, store_value=value, accu_store=accu_store
+            )
+        return code
+
+    def store(self, name, index, value, mode=None):
+        assert "buf" in name
+        assert isinstance(value, CppCSEVariable), value
+        if not value.is_vec:
+            # this happens when we store a scalar into a vectorized buffer like "fill"
+            value = self.broadcast(value)
+        var = self.args.output(name)
+        index = self.rename_indexing(index)
+        dtype = V.graph.get_dtype(name)
+        if mode is None:
+            code = self._get_store_line(value, var, index, dtype)
+            self.stores.splice(code.map(lambda x: DeferredLine(name, x)))
+        elif mode == "atomic_add":
+            if not config.cpp.dynamic_threads and self.num_threads == 1:
+                code = self._get_store_line(
+                    f"{value}",
+                    var,
+                    index,
+                    dtype,
+                    accu_store=True,
+                )
+                self.stores.splice(code.map(lambda x: DeferredLine(name, x)))
+            else:
+                n_src = self._get_num_vectors(dtype)
+                n_idx = self._get_num_vectors(torch.int64)
+                cdtype = DTYPE_TO_CPP[dtype]
+                index = ops.index_expr(index, torch.int64).value
+                assert isinstance(index, CppCSEVariable) and index.is_vec
+                if self.tail_size:
+                    line = f"atomic_add_vec<{cdtype}, {n_idx}, {n_src}>({var}, {index}, {value}, {cexpr_index(self.tail_size)});"
+                else:
+                    line = f"atomic_add_vec<{cdtype}, {n_idx}, {n_src}>({var}, {index}, {value});"
+                self.stores.writeline(DeferredLine(name, line))
+        else:
+            raise NotImplementedError(f"store mode={mode}")
+
+    def reduction(self, dtype, src_dtype, reduction_type, value):
+        """
+        Perform vectorized reduction operation.
+
+        This method handles vectorized reduction for different reduction types.
+        It manages special cases for low-precision floating point types and
+        employs precision improvement techniques for certain reduction operations.
+
+        Args:
+            dtype: The output data type for the reduction result
+            src_dtype: The source data type of the input value
+            reduction_type: Type of reduction operation (sum, min, max, etc.)
+            value: The input value to reduce
+
+        Returns:
+            The result of the reduction operation
+        """
+        # Note: For argmax and argmin on bool type, we always convert bool to float.
+        # Fix issue: https://github.com/pytorch/pytorch/issues/143568
+        assert reduction_type in VECTORIZABLE_RTYPES
+        argmax_or_argmin = reduction_type in ("argmax", "argmin")
+        horizontal_reduction = self.tiling_idx >= self.reduction_depth
+        init_dtype = src_dtype if argmax_or_argmin else dtype
+        assert isinstance(value, CppCSEVariable), value
+
+        if not value.is_vec:
+            value = self.broadcast(value)
+
+        reduction_key = src_dtype, reduction_type, value
+        if reduction_key in self.reduction_cse.reduction_cache:
+            return self.reduction_cse.reduction_cache[reduction_key]
+
+        vec_ns = "at::vec"
+        vec = f"{vec_ns}::Vectorized<{DTYPE_TO_CPP[dtype]}>"
+        acc_type = reduction_acc_type(reduction_type, init_dtype)
+        acc_type_vec = self.reduction_acc_type_vec(reduction_type, init_dtype)
+
+        acc = self.reduction_cse.generate(
+            self.loads, f"reduction {reduction_key}", write=False
+        )
+        assert isinstance(acc, CppCSEVariable)
+        acc_vec = f"{acc}_vec"
+        masked_acc = f"masked_{acc}"
+        masked_acc_vec = f"masked_{acc_vec}"
+        self.reduction_var_names += [f"{acc}", acc_vec, masked_acc_vec]
+        self.is_reduction = True
+        self.reduction_prefix_generators.append(
+            self._gen_reduction_prefix(
+                acc, acc_type, reduction_type, init_dtype, reduction_init
+            )
+        )
+        self.reduction_prefix_generators.append(
+            self._gen_reduction_prefix(
+                acc_vec,
+                acc_type_vec,
+                reduction_type,
+                init_dtype,
+                self.reduction_init_vec,
+            )
+        )
+
+        use_acc_helper = self.need_use_acc_helper(reduction_type, dtype, False)
+        if use_acc_helper:
+            # use masked acc_vec for tail vec kernel
+            self.reduction_prefix_generators.append(
+                self._gen_reduction_prefix(
+                    masked_acc_vec,
+                    acc_type_vec,
+                    reduction_type,
+                    dtype,
+                    self.reduction_init_vec,
+                )
+            )
+
+            # use welford_helper/cascade_helper for vec kernel
+            assert self.reduction_depth is not None
+            reduction_size = functools.reduce(
+                operator.mul, self.ranges[self.reduction_depth :]
+            )
+            if reduction_type == "welford_reduce":
+                helper_val = self.welford_helper_cse.generate(
+                    self.compute, f"reduction {reduction_key}", write=False
+                )
+            else:
+                helper_val = self.cascade_helper_cse.generate(
+                    self.compute, f"reduction {reduction_key}", write=False
+                )
+            masked_helper_val = f"masked_{helper_val}"
+            helper_vec_range = (
+                (
+                    FloorDiv(reduction_size, self.ranges[self.tiling_idx])
+                    * FloorDiv(self.ranges[self.tiling_idx], self.tiling_factor)
+                    if self.tiling_idx >= self.reduction_depth
+                    else reduction_size
+                )
+                if FloorDiv(self.ranges[self.tiling_idx], self.tiling_factor)
+                else sympy.Integer(0)
+            )
+            masked_helper_vec_range = (
+                (
+                    FloorDiv(reduction_size, self.ranges[self.tiling_idx])
+                    if self.tiling_idx >= self.reduction_depth
+                    else reduction_size
+                )
+                if self.ranges[self.tiling_idx] % self.tiling_factor
+                else sympy.Integer(0)
+            )
+            # scalar helper for scalar welford_reduce/sum is also needed when vec kernel is included
+            scalar_helper_val = f"scalar_{helper_val}"
+            self._use_acc_helper(
+                reduction_type,
+                acc,
+                scalar_helper_val,
+                reduction_size,
+                dtype,
+                use_scalar=True,
+            )
+            self._use_acc_helper(
+                reduction_type, acc, helper_val, helper_vec_range, dtype
+            )
+            self._use_acc_helper(
+                reduction_type,
+                masked_acc,
+                masked_helper_val,
+                masked_helper_vec_range,
+                dtype,
+            )
+
+            # use masked acc_vec for tail vec kernel
+            acc_vec_ = masked_acc_vec if self.tail_size else acc_vec
+            helper_val_ = masked_helper_val if self.tail_size else helper_val
+            if reduction_type == "sum":
+                self.stores.writeline(
+                    f"{acc_vec_} = {self.reduction_combine_vec(reduction_type, acc_vec_, value, helper_val_)};"
+                )
+            else:
+                self.stores.writeline(
+                    f"{acc_vec_} = {self.reduction_combine_vec(reduction_type, acc_vec_, value, helper_val_)};"
+                )
+        else:
+            assert self.reduction_depth is not None
+            index = self.itervars[self.reduction_depth]
+            for i in range(self.reduction_depth + 1, len(self.itervars)):
+                index = index * self.ranges[i] + self.itervars[i]
+            kwargs = {
+                "next_value": value,
+                "index": index,
+                "horizontal_reduction": horizontal_reduction,
+                "src_dtype": src_dtype,
+            }
+            self.stores.writeline(
+                f"{acc_vec} = {self.reduction_combine_vec(reduction_type, acc_vec, **kwargs)};"
+            )
+        self._gen_parallel_reduction_buffers(
+            acc_vec,
+            acc_type_vec,
+            reduction_type,
+            init_dtype,
+            reduction_combine_fn=self.reduction_combine_vec,
+            reduction_init_fn=self.reduction_init_vec,
+        )
+        self._gen_parallel_reduction_buffers(
+            acc,
+            acc_type,
+            reduction_type,
+            init_dtype,
+            reduction_combine_fn=reduction_combine,
+            reduction_init_fn=reduction_init,
+        )
+        if use_acc_helper:
+            # use masked acc_vec for tail vec kernel
+            self._gen_parallel_reduction_buffers(
+                masked_acc_vec,
+                acc_type_vec,
+                reduction_type,
+                dtype,
+                reduction_combine_fn=self.reduction_combine_vec,
+                reduction_init_fn=self.reduction_init_vec,
+            )
+        tmpvar: Union[str, CSEVariable]
+        is_bool = dtype == torch.bool
+        if horizontal_reduction:
+            # Horizontal reduction
+            if is_welford_reduction(reduction_type):
+                assert self._get_num_vectors(dtype) in [
+                    1,
+                    2,
+                ], "Welford reduction does not support VectorizedN (N>2)"
+                next_value = f"welford_vec_reduce_all({acc_vec})"
+                masked_next_value = f"welford_vec_reduce_all({masked_acc_vec})"
+                self.reduction_suffix.writeline(
+                    f"{acc} = {reduction_combine(reduction_type, acc, masked_next_value)};"
+                )
+            elif argmax_or_argmin:
+                next_value = f"{reduction_type}_vec_reduce_all({acc_vec})"
+            elif is_bool:
+                if reduction_type in (
+                    "any",
+                    "sum",
+                    "max",
+                ):
+                    next_value = f"!{acc_vec}.all_zero()"
+                else:
+                    assert reduction_type == "min"
+                    next_value = f"{acc_vec}.all_masked()"
+            else:
+                reduce_all_body = (
+                    "{ return "
+                    + self.reduction_combine_vec(reduction_type, "x", "y")
+                    + "; }"
+                )
+                is_bool = dtype == torch.bool
+                # we are using at::vec::VecMask for bool
+                vec_dtype = torch.float if is_bool else dtype
+                vec = f"at::vec::Vectorized<{DTYPE_TO_CPP[vec_dtype]}>"
+                vec_reduce_all_func = f"at::vec::vec_reduce_all<{DTYPE_TO_CPP[vec_dtype]}, {self._get_num_vectors(vec_dtype)}>"
+                result_vec = f"{acc_vec}"
+                if use_acc_helper:
+                    assert reduction_type == "sum"
+                    result_vec = f"{acc_vec} + {masked_acc_vec}"
+                next_value = f"{vec_reduce_all_func}([]({vec}& x, {vec}& y) {reduce_all_body}, {result_vec})"
+
+            self.reduction_suffix.writeline(
+                f"{acc} = {reduction_combine(reduction_type, acc, next_value, src_dtype=src_dtype)};"
+            )
+            tmpvar = acc
+        else:
+            tmpvar = acc_vec
+            if is_welford_reduction(reduction_type):
+                masked_tmpvar = f"masked_{tmpvar}"
+                self.reduction_suffix.writeline(
+                    f"{tmpvar} = {reduction_combine(reduction_type, tmpvar, masked_tmpvar)};"
+                )
+            elif use_acc_helper:
+                assert reduction_type == "sum"
+                masked_tmpvar = f"masked_{tmpvar}"
+                self.reduction_suffix.writeline(
+                    f"{tmpvar} = {tmpvar} + {masked_tmpvar};"
+                )
+
+        result = reduction_project(reduction_type, tmpvar)
+        self.reduction_cse.reduction_cache[reduction_key] = result
+        return result
+
+    def store_reduction(self, name, index, value):
+        index = self.rename_indexing(index)
+        var = self.args.output(name)
+        out_dtype = V.graph.get_dtype(name)
+        if out_dtype.is_floating_point and out_dtype != torch.double:
+            dtype = torch.float
+        else:
+            dtype = out_dtype
+        out_num_vectors = V.kernel._get_num_vectors(out_dtype)
+        src_num_vectors = V.kernel._get_num_vectors(dtype)
+        code = IndentedBuffer()
+        if self.tiling_idx >= self.reduction_depth:
+            # Horizontal reduction
+            code.writeline(
+                f"{var}[{cexpr_index(index)}] = static_cast<{DTYPE_TO_CPP[out_dtype]}>({value});"
+            )
+        else:
+            # Vertical reduction
+            if out_dtype != dtype:
+                converted_value = (
+                    f"{DTYPE_TO_CPP[out_dtype].replace('::', '_')}_{value}"
+                )
+                if out_dtype == torch.bool:
+                    convert = f"{value}.template cast()"
+                else:
+                    if src_num_vectors == out_num_vectors == 1:
+                        convert = (
+                            f"at::vec::convert<{DTYPE_TO_CPP[out_dtype]}>({value})"
+                        )
+                    else:
+                        convert = (
+                            f"at::vec::convert<{DTYPE_TO_CPP[out_dtype]},"
+                            f"{out_num_vectors},{DTYPE_TO_CPP[dtype]},{src_num_vectors}>({value})"
+                        )
+                code.writeline(f"auto {converted_value} = {convert};")
+                value = converted_value
+            code.splice(self._get_store_line(value, var, index, out_dtype))
+        self.reduction_suffix.splice(code.map(lambda x: DeferredLine(name, x)))
+
+    def broadcast(self, scalar_var: CppCSEVariable) -> CppCSEVariable:
+        assert not scalar_var.is_vec
+        if scalar_var.dtype == torch.bool:
+            vec_var = self.cse.generate(
+                self.compute, f"{self._get_mask_type()}::from({scalar_var.name})"
+            )
+        else:
+            assert scalar_var.dtype is not None
+            vec_var = self.cse.generate(
+                self.compute,
+                f"{self._get_vec_type(scalar_var.dtype)}({scalar_var.name})",
+            )
+        assert isinstance(vec_var, CppCSEVariable)
+        vec_var.dtype = scalar_var.dtype
+        vec_var.dependent_itervars = scalar_var.dependent_itervars
+        vec_var.is_vec = True
+        return vec_var
+
+    def arange(self, index: CppCSEVariable, stride: sympy.Symbol) -> CppCSEVariable:
+        assert not index.is_vec
+        assert index.dtype is not None
+        csevar = self.cse.generate(
+            self.compute,
+            f"{self._get_vec_type(index.dtype)}::arange({index}, {stride})",
+        )
+        assert isinstance(csevar, CppCSEVariable)
+        csevar.dtype = index.dtype
+        csevar.is_vec = True
+        return csevar
+
+    def reduction_init_vec(self, reduction_type, dtype):
+        scalar_type = DTYPE_TO_COMPUTATION_DTYPE[dtype]
+        vec_type = self._get_vec_type(scalar_type)
+
+        if is_welford_reduction(reduction_type):
+            return f"Welford<{vec_type}>()"
+
+        if reduction_type in ("argmin", "argmax"):
+            cdtype = DTYPE_TO_CPP[scalar_type]
+            acc_type = self.reduction_acc_type_vec(reduction_type, dtype)
+            if reduction_type == "argmin":
+                val = (
+                    f"std::numeric_limits<{cdtype}>::infinity()"
+                    if is_float_dtype(dtype)
+                    else f"std::numeric_limits<{cdtype}>::max()"
+                )
+            else:
+                val = (
+                    f"-std::numeric_limits<{cdtype}>::infinity()"
+                    if is_float_dtype(dtype)
+                    else f"std::numeric_limits<{cdtype}>::min()"
+                )
+            return f"{acc_type}({val})"
+
+        if reduction_type == "any":
+            return f"{self._get_mask_type()}::from(0)"
+
+        scalar_init = reduction_init(reduction_type, dtype)
+        vec_init = f"{vec_type}({scalar_init})"
+        if dtype == torch.bool:
+            assert reduction_type in ("min", "max", "sum")
+            return f"{self._get_mask_type()}::from({scalar_init})"
+        return vec_init
+
+    def reduction_acc_type_vec(self, reduction_type, dtype):
+        scalar_type = DTYPE_TO_COMPUTATION_DTYPE[dtype]
+        vec_type = self._get_vec_type(scalar_type)
+        if is_welford_reduction(reduction_type):
+            return f"Welford<{vec_type}>"
+        if reduction_type in ("argmin", "argmax"):
+            n_src = self._get_num_vectors(scalar_type)
+            n_idx = self._get_num_vectors(torch.int64)
+            if dtype == torch.bool:
+                return f"IndexValueVec<{DTYPE_TO_CPP[torch.float]}, {n_src}, {n_idx}>"
+            return f"IndexValueVec<{DTYPE_TO_CPP[scalar_type]}, {n_src}, {n_idx}>"
+        if dtype == torch.bool:
+            assert reduction_type in ("min", "max", "any", "sum")
+            return f"{self._get_mask_type()}"
+        return vec_type
+
+    def reduction_combine_vec(
+        self,
+        reduction_type,
+        var,
+        next_value,
+        helper_val=None,
+        index: Optional[sympy.Symbol] = None,
+        horizontal_reduction: Optional[bool] = None,
+        src_dtype: Optional[torch.dtype] = torch.float32,
+    ):
+        is_bool = src_dtype == torch.bool
+        if reduction_type == "max":
+            if self.tail_size:
+                return f"max_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})"
+            else:
+                return (
+                    f"{var} | {next_value}"
+                    if is_bool
+                    else f"at::vec::maximum({var}, {next_value})"
+                )
+        elif reduction_type == "min":
+            if self.tail_size:
+                return f"min_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})"
+            else:
+                return (
+                    f"{var} & {next_value}"
+                    if is_bool
+                    else f"at::vec::minimum({var}, {next_value})"
+                )
+        elif reduction_type == "sum":
+            if helper_val:
+                if self.tail_size:
+                    return f"cascade_sum_combine({next_value}, {cexpr_index(self.tail_size)}, &{helper_val})"
+                else:
+                    return f"cascade_sum_combine({next_value}, &{helper_val})"
+            else:
+                if self.tail_size:
+                    return f"sum_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})"
+                else:
+                    conjunction = "|" if is_bool else "+"
+                    return f"{var} {conjunction} {next_value}"
+        elif reduction_type == "prod":
+            if self.tail_size:
+                return f"prod_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})"
+            else:
+                return f"{var} * {next_value}"
+        elif reduction_type == "xor_sum":
+            if self.tail_size:
+                return f"xor_sum_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})"
+            else:
+                return f"{var} ^ {next_value}"
+        elif reduction_type == "welford_reduce":
+            if helper_val:
+                if self.tail_size:
+                    return f"welford_combine({var}, {next_value}, {cexpr_index(self.tail_size)}, &{helper_val})"
+                else:
+                    return f"welford_combine({var}, {next_value}, &{helper_val})"
+            else:
+                if self.tail_size:
+                    return f"welford_combine({var}, {next_value}, {cexpr_index(self.tail_size)})"
+                else:
+                    return f"welford_combine({var}, {next_value})"
+        elif reduction_type == "welford_combine":
+            if isinstance(next_value, tuple):
+                # When reading a value from Inductor IR we have a tuple of variable names
+                mean, m2, weight = next_value
+            else:
+                # When combining intermediate accumulators we have a Welford struct
+                mean, m2, weight = reduction_project(reduction_type, next_value)
+            if self.tail_size:
+                return f"welford_combine({var}, {{{mean}, {m2}, {weight}}}, {cexpr_index(self.tail_size)})"
+            else:
+                return f"welford_combine({var}, {{{mean}, {m2}, {weight}}})"
+        elif reduction_type in ("argmin", "argmax"):
+            assert src_dtype is not None
+            cdtype = DTYPE_TO_CPP[src_dtype]
+            if src_dtype == torch.bool:
+                cdtype = DTYPE_TO_CPP[torch.float]
+            n_src = self._get_num_vectors(src_dtype)
+            n_idx = self._get_num_vectors(torch.int64)
+            t_extra = ""
+            arg_extra = ""
+            if index is not None:
+                assert horizontal_reduction is not None
+                t_extra = f", {str(horizontal_reduction).lower()}"
+                arg_extra = f", {index}"
+            if self.tail_size:
+                return (
+                    f"{reduction_type}_combine_vec<{cdtype}, {n_src}, {n_idx}{t_extra}>"
+                    f"({var}, {next_value}{arg_extra}, {cexpr_index(self.tail_size)})"
+                )
+            else:
+                return f"{reduction_type}_combine_vec<{cdtype}, {n_src}, {n_idx}{t_extra}>({var}, {next_value}{arg_extra})"
+        elif reduction_type == "any":
+            if isinstance(next_value, CppCSEVariable):
+                assert next_value.dtype == torch.bool
+                (next_value,) = unify_mask_base_type(V.kernel.compute, (next_value,))
+            if self.tail_size:
+                return f"any_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})"
+            else:
+                return f"{var} | {next_value}"
+        else:
+            raise NotImplementedError
+
+    def indirect_assert(self, var, lower, upper, mask=None):
+        assert isinstance(var, CppCSEVariable)
+        assert var.dtype is not None
+        if not var.is_vec:
+            if isinstance(mask, CppCSEVariable) and mask.is_vec:
+                mask = f"({mask}).all_masked()"
+            return super().indirect_assert(var, lower, upper, mask)
+        lower_scalar = lower
+        upper_scalar = upper
+        if lower:
+            lower = f"{self._get_vec_type(var.dtype)}({lower})"
+        if upper:
+            upper = f"{self._get_vec_type(var.dtype)}({upper})"
+        if lower and upper:
+            cond = f"({lower} <= {var}) & ({var} < {upper})"
+            cond_print = f"{lower_scalar} <= {var} < {upper_scalar}"
+        elif lower:
+            cond = f"{lower} <= {var}"
+            cond_print = f"{lower_scalar} <= {var}"
+        else:
+            assert upper
+            cond = f"{var} < {upper}"
+            cond_print = f"{var} < {upper_scalar}"
+        cond = f"{self._get_mask_type(var.dtype)}({cond})"
+        if mask:
+            if not mask.is_vec:
+                mask = f"{self._get_mask_type(var.dtype)}({mask})"
+            # We need not check when the mask is False
+            cond = f"({cond}) | ~({mask})"
+        if self.tail_size:
+            cond = (
+                f"{self._get_mask_type(var.dtype)}::set({self._get_mask_type(var.dtype)}::from(1)"
+                f", ({cond}), {cexpr_index(self.tail_size)})"
+            )
+        cond = f"({cond}).all_masked()"
+        return f'{self.assert_function}({cond}, "index out of bounds: {cond_print}")'
+
+    def get_to_dtype_expr(self, src, dtype, src_dtype):
+        assert isinstance(src, CppCSEVariable)
+        if not src.is_vec:
+            return super().get_to_dtype_expr(src, dtype, src_dtype)
+        src_cpp_type = DTYPE_TO_CPP[src_dtype]
+        src_num_vectors = self._get_num_vectors(src_dtype)
+        dst_cpp_type = DTYPE_TO_CPP[dtype]
+        dst_num_vectors = self._get_num_vectors(dtype)
+        expr = f"({src})"
+        if src_dtype != torch.bool and dtype == torch.bool:
+            expr = f"{self._get_mask_type(src_dtype)}::from<{src_cpp_type},{src_num_vectors}>({src})"
+        elif src_dtype == torch.bool and dtype != torch.bool:
+            expr = f"{src}.to<{dst_cpp_type},{dst_num_vectors}>()"
+        elif src_dtype != dtype:
+            if src_num_vectors == dst_num_vectors == 1:
+                expr = f"at::vec::convert<{dst_cpp_type}>({src})"
+            else:
+                expr = f"at::vec::convert<{dst_cpp_type},{dst_num_vectors},{src_cpp_type},{src_num_vectors}>({src})"
+        return expr
+
+
+class CppTile2DKernel(CppVecKernel):
+    """
+    A vector kernel that handles the 2d tiles with the tile size defined in `tiling_factor` on
+    the inner-most loop level and one of the outer loop level (`outer_tiling_idx`). When the data
+    tile is accessed in a contiguous way from the outer loop axis, a transposition is applied on the
+    tile to make the access contiguous from the inner-most loop axis. Then, the same vectorization
+    logic from its parent `CppVecKernel` is leveraged for load/store/compute. The transposed tile load
+    and store are generated into kernel.preloads and kernel.poststores buffers.
+
+    The loop structure looks like below:
+    for ...
+      for i_outer ...
+        for ...
+          for inner_most ...
+            // generated by CppTile2DKernel
+            float tmp0[16*16]; at::vec::transpose_mxn<...>(tmp0, in_ptr0 + ..., ...); // into kernel.preloads
+            float tmp1[16*16]; // into kernel.preloads
+            for i_inner ... { // the kernel inner loop
+              vectorized loads/compute/stores (e.g., load tmp0, store tmp1) // into kernel.loads/compute/stores
+            }
+            at::vec::transpose_mxn(out_ptr0 + ..., tmp1, ...) // into kernel.poststores
+          for inner_most ... (tail)
+            // generated by CppVecKernel
+            ...
+      for i_outer ... (tail)
+        for ...
+          for ...
+            // generated by CppKernel
+            ...
+    """
+
+    overrides = CppTile2DOverrides  # type: ignore[assignment]
+
+    def __init__(
+        self,
+        args,
+        num_threads,
+        tiling_factor,
+        tiling_indices,
+        inner_tail_size=None,
+        outer_tail_size=None,
+    ):
+        super().__init__(
+            args,
+            num_threads,
+            tiling_factor,
+            tiling_indices[1],
+            inner_tail_size,
+        )
+        self.tiling_indices = tiling_indices
+        self.inner_tail_size = inner_tail_size
+        self.outer_tail_size = outer_tail_size
+        self.inner_num_elems = inner_tail_size if inner_tail_size else tiling_factor
+        self.outer_num_elems = outer_tail_size if outer_tail_size else tiling_factor
+        self.inner_is_tiling_idx = True
+
+    def inner_itervar(self):
+        return sympy_index_symbol(f"{self.itervars[self.outer_idx]}_inner")
+
+    def need_vec_transpose(self, index):
+        outer_var = self.itervars[self.outer_idx]
+        inner_var = self.itervars[self.tiling_idx]
+        outer_stride = stride_at_vec_range(index, outer_var, self.tiling_factor)
+        inner_stride = stride_at_vec_range(index, inner_var, self.tiling_factor)
+        return (
+            self._load_mask is None  # TODO: support transposition with mask
+            and outer_stride == 1
+            and index.has(inner_var)
+            and not inner_stride.has(inner_var)
+            and not inner_stride.has(outer_var)
+        )
+
+    def gen_transposed_tile_load_store(
+        self, name, var, index, is_store, store_mode=None
+    ):
+        # transposed tile load/store outside the kernel inner loop
+        dtype = V.graph.get_dtype(name)
+        factor = self.tiling_factor
+        src = f"{var} + {cexpr_index(index)}"
+        dst = "__place_holder__"
+        ld_src = f"{cexpr_index(stride_at_vec_range(index, self.itervars[self.tiling_idx], self.tiling_factor))}"
+        ld_dst = f"{cexpr_index(self.num_elems)}"
+        if is_store:
+            src, dst = dst, src
+            ld_src, ld_dst = ld_dst, ld_src
+
+        need_define = True
+        if self.inner_is_tiling_idx ^ is_store:
+            M, N = self.inner_num_elems, self.outer_num_elems
+        else:
+            M, N = (
+                self.outer_num_elems,
+                self.inner_num_elems,
+            )
+        atomic_add = "true" if (is_store and (store_mode == "atomic_add")) else "false"
+        if (isinstance(M, sympy.Expr) and not M.is_number) or (
+            isinstance(N, sympy.Expr) and not N.is_number
+        ):
+            load_or_store = (
+                f"transpose_mxn<{DTYPE_TO_CPP[dtype]},{atomic_add}>"
+                f"({src}, {ld_src}, {dst}, {ld_dst}, {cexpr_index(M)}, {cexpr_index(N)});"
+            )
+        else:
+            load_or_store = (
+                f"transpose_mxn<{DTYPE_TO_CPP[dtype]},{cexpr_index(M)},{cexpr_index(N)},{atomic_add}>"
+                f"({src}, {ld_src}, {dst}, {ld_dst});"
+            )
+        if is_store:
+            tile_var = self.cse.newvar()
+        elif not self.cse.contains(load_or_store):
+            tile_var = self.cse.generate(self.preloads, load_or_store, write=False)
+        else:
+            need_define = False
+            tile_var = self.cse.get(load_or_store)
+
+        if need_define:
+            cpp_dtype = DTYPE_TO_CPP[dtype]
+            # tiling_factor might be smaller than the alignment of cpp_dtype, such as
+            # with a vector that only holds 4 elements due to NEON 128-bit vectors and
+            # cpp_dtype being a 64-bit integer.
+            alignas = f"alignas(std::max(std::size_t({factor}), alignof({cpp_dtype})))"
+            define_line = f"{alignas} {cpp_dtype} {tile_var}[{factor}*{factor}];"
+            self.preloads.writeline(define_line)
+
+        load_or_store = load_or_store.replace("__place_holder__", str(tile_var))
+        if is_store:
+            self.poststores.writeline(DeferredLine(name, load_or_store))
+        else:
+            self.preloads.writeline(load_or_store)
+
+        return tile_var
+
+    def load(self, name: str, index: sympy.Expr):
+        var = self.args.input(name)
+        index = self.rename_indexing(index)
+
+        inner = self.inner_itervar()
+        if self.need_vec_transpose(index):
+            tile_var = self.gen_transposed_tile_load_store(
+                name, var, index, is_store=False
+            )
+            # vector load inside the kernel inner loop
+            loadbuf = f"{tile_var} + {cexpr_index(inner * self.num_elems)}"
+            dtype = V.graph.get_dtype(name)
+            line = self._get_vec_load_line(loadbuf, 0, dtype)  # type: ignore[arg-type]
+            csevar = self.cse.generate(self.loads, line, dtype=dtype)
+            csevar.update_on_args("load", (self, name, index), {})
+            assert isinstance(csevar, CppCSEVariable)
+            csevar.is_vec = True
+            return csevar
+        else:
+            new_index = self.transform_indexing(index)
+            return super().load(name, new_index)
+
+    def store(self, name, index, value, mode=None):
+        assert "buf" in name
+        assert isinstance(value, CppCSEVariable), value
+        if not value.is_vec:
+            # this happens when we store a scalar into a vectorized buffer like "fill"
+            value = self.broadcast(value)
+
+        var = self.args.output(name)
+
+        inner = self.inner_itervar()
+        index = self.rename_indexing(index)
+        if self.need_vec_transpose(index):
+            tile_var = self.gen_transposed_tile_load_store(
+                name, var, index, is_store=True, store_mode=mode
+            )
+            # vector store inside the kernel inner loop
+            storebuf = f"{tile_var} + {cexpr_index(inner * self.num_elems)}"
+            if self.tail_size or V.graph.get_dtype(name) in DTYPE_LOWP_FP + [
+                torch.uint8,
+                torch.int8,
+                torch.float8_e4m3fn,
+                torch.float8_e5m2,
+            ]:
+                line = f"{value}.store({storebuf}, {cexpr_index(self.num_elems)});"
+            else:
+                line = f"{value}.store({storebuf});"
+            self.stores.writeline(DeferredLine(name, line))
+        else:
+            new_index = self.transform_indexing(index)
+            super().store(name, new_index, value, mode)
+
+    def codegen_inner_loops(self, code):
+        inner = self.inner_itervar()
+        if self.inner_is_tiling_idx:
+            code.writeline(
+                f"for (long {inner} = 0; {inner} < {cexpr_index(self.outer_num_elems)}; {inner}++)"
+            )
+        else:
+            code.writeline(
+                f"for (long {inner} = 0; {inner} < {cexpr_index(self.inner_num_elems)}; {inner}++)"
+            )
+
+    def set_ranges(self, group, reduction_group):
+        vars = super().set_ranges(group, reduction_group)
+        # do vertical reduction as the tail loop
+        self.outer_idx, self.tiling_idx = (
+            self.tiling_indices
+            if self.tiling_indices[1] < self.reduction_depth
+            else reversed(self.tiling_indices)
+        )
+        if self.tiling_idx == self.tiling_indices[0]:
+            self.tail_size = self.outer_tail_size
+            self.num_elems = self.outer_num_elems
+            self.inner_is_tiling_idx = False
+        else:
+            self.tail_size = self.inner_tail_size
+            self.num_elems = self.inner_num_elems
+            self.inner_is_tiling_idx = True
+        return vars
+
+    def transform_indexing(self, index: sympy.Expr) -> sympy.Expr:
+        return self.scale_index_with_offset(
+            index,
+            itervar_idx=self.outer_idx,
+            offset=self.inner_itervar(),
+        )
+
+
+def get_loop_body_lowp_fp(_body: LoopBody) -> tuple[Optional[torch.dtype], bool]:
+    """
+    Returns the low precision data type (torch.float16/torch.bfloat16) contained in the nodes
+    and if all the nodes can codegen with this data type without converting to float.
+    Otherwise returns None and True.
+    """
+    sub_blocks = [_body.root_block] + list(_body.subblocks.values())
+
+    _lowp_fp_type: Optional[torch.dtype] = None
+    _use_fp32 = False
+    for sub_block in sub_blocks:
+        for _node in sub_block.graph.nodes:
+            if _node.op == "placeholder" or _node.target in (
+                "get_index",
+                "index_expr",
+            ):
+                continue
+
+            # Fast path if all operations can support bf16/fp16 without converting to fp32
+            if _node.target not in [
+                "load",
+                "store",
+                "abs",
+                "neg",
+                "output",
+            ]:
+                _use_fp32 = True
+
+            if hasattr(_node, "meta") and _node.meta:
+                assert OptimizationContext.key in _node.meta
+                opt_ctx: OptimizationContext = _node.meta[OptimizationContext.key]
+                if not opt_ctx.dtype or opt_ctx.dtype not in DTYPE_LOWP_FP:
+                    _use_fp32 = True
+                elif _lowp_fp_type is not None:
+                    if _lowp_fp_type != opt_ctx.dtype:
+                        warnings.warn("bf16 and fp16 are mixed in the scheduler node.")
+                else:
+                    _lowp_fp_type = opt_ctx.dtype
+            else:
+                _use_fp32 = True
+
+    return _lowp_fp_type, _use_fp32
+
+
+class TilingSelect:
+    """
+    Implement the heuristic to select the tiling factors and tiling indices.
+    In the future, we can implement advanced heuristic in a subclass.
+    """
+
+    def select_tiling(
+        self,
+        fn_list,
+        var_sizes_list,
+    ) -> tuple[list[int], list[int]]:
+        # TODO(jgong5): support alternative tiling factors and data types
+        loop_bodies = _get_loop_body(fn_list)
+        all_dtypes = _get_dtype_from_loopbodies(loop_bodies)
+        assert all_dtypes
+        if any(dtype not in VECTORIZABLE_DTYPES for dtype in all_dtypes):
+            return [], []
+        dtype = torch.float
+        _lowp_fp_dtype = get_loop_body_lowp_fp(loop_bodies[0])[0]
+        if _lowp_fp_dtype and all(
+            (get_loop_body_lowp_fp(loop_body)[0] == _lowp_fp_dtype)
+            for loop_body in loop_bodies[1:]
+        ):
+            dtype = _lowp_fp_dtype
+
+        tiling_factor = cpu_vec_isa.pick_vec_isa().nelements(dtype=dtype)
+        tiling_indices = self._select_tiling_indices(
+            fn_list, var_sizes_list, tiling_factor
+        )
+
+        if tiling_indices:
+            group, reduction_group = max(
+                var_sizes_list, key=lambda sizes: len(sizes[1])
+            )
+            call_ranges = tuple(group) + tuple(reduction_group)
+
+            if config.cpp.enable_tiling_heuristics:
+
+                def _try_get_stride(
+                    index,
+                    itervars,
+                    tiling_factor,
+                    tiling_indices,
+                ):
+                    itervar = itervars[tiling_indices[0]]
+                    stride = stride_at_vec_range(index, itervar, tiling_factor)
+                    return stride if stride.is_number else None
+
+                def _update_negative_op_count(
+                    node_name, non_contig_indexing_op_counter
+                ):
+                    if node_name not in non_contig_indexing_op_counter:
+                        non_contig_indexing_op_counter[node_name] = 1
+                    else:
+                        non_contig_indexing_op_counter[node_name] += 1
+
+                def _is_valid_indices(
+                    itervars,
+                    tiling_indices,
+                ):
+                    return (
+                        len(tiling_indices) == 1
+                        and len(itervars) > 0
+                        and (
+                            tiling_indices[0]
+                            if tiling_indices[0] >= 0
+                            else tiling_indices[0] + len(itervars)
+                        )
+                        < len(itervars)
+                    )
+
+                itervars = [
+                    sympy_index_symbol_with_prefix(SymT.XBLOCK, n)
+                    for n in range(len(call_ranges))
+                ]
+                reduction_depth = len(group)
+                vars, reduction_vars = (
+                    itervars[:reduction_depth],
+                    itervars[reduction_depth:],
+                )
+                op_counter: dict[str, int] = {}
+                # ops may cause overhead with vectorization, like non-contiguous
+                # index_expr, load, store
+                non_contig_indexing_op_counter: dict[str, int] = {}
+                for _body in loop_bodies:
+                    sub_blocks = [_body.root_block] + list(_body.subblocks.values())
+                    for sub_block in sub_blocks:
+                        for _node in sub_block.graph.nodes:
+                            if _node.target in ["index_expr", "load", "store"]:
+                                # get the index and replace prefix from z to x
+                                arg_idx = 1 if _node.target == "index_expr" else 2
+                                index = sub_block.body.indexing_from_args(
+                                    (vars, reduction_vars)
+                                )[_node.args[arg_idx].args[0]]
+                                if _is_valid_indices(itervars, tiling_indices):
+                                    stride = _try_get_stride(
+                                        index, itervars, tiling_factor, tiling_indices
+                                    )
+                                    if (
+                                        stride is None
+                                        if _node.target == "index_expr"
+                                        else stride not in [0, 1]
+                                    ):
+                                        _update_negative_op_count(
+                                            _node.target, non_contig_indexing_op_counter
+                                        )
+                            if isinstance(_node.target, str) and not (
+                                _node.target.startswith("masked_subblock")
+                                or _node.target
+                                in ["ops", "output", "constant", "get_index"]
+                            ):
+                                if _node.target not in op_counter:
+                                    op_counter[_node.target] = 1
+                                else:
+                                    op_counter[_node.target] += 1
+
+                op_num = sum(op_counter.values())
+                non_contig_indexing_op_num = sum(
+                    non_contig_indexing_op_counter.values()
+                )
+                ratio_threshold = 0.12
+                quantity_threshold = 35
+                if non_contig_indexing_op_num >= quantity_threshold or (
+                    op_num > 0
+                    and non_contig_indexing_op_num / op_num >= ratio_threshold
+                ):
+                    # Too many non-contiguous load/store/index_expr which hurts the
+                    # vectorization performance. Disable vectorization when exceeding
+                    # the thresholds.
+                    return [], []
+
+                if (
+                    not reduction_group
+                    and group
+                    and len(tiling_indices) == 1
+                    and not has_free_symbols(
+                        [
+                            group[tiling_indices[0]],
+                        ]
+                    )
+                    and group[tiling_indices[0]] < tiling_factor / 4
+                    and op_num < 10
+                ):
+                    # We found that when the number of elements in the inner loop range is
+                    # relatively small(< tiling_factor / 4) and the number of operations is
+                    # not large(< 10), vectorization is not efficient.
+                    # And found that `#pragma GCC ivdep` has better performance than
+                    # `#pragma omp simd simdlen(8)` for these cases.
+                    return [], []
+
+            if dtype in DTYPE_LOWP_FP:
+                # For lower precision data type, if the call_range is not long enough,
+                # use tiling_factor // 2 for better performance
+                factor_lowp = cpu_vec_isa.pick_vec_isa().nelements(dtype=dtype)
+                for tiling_indice in tiling_indices:
+                    if tiling_indice < 0:
+                        tiling_indice = tiling_indice + len(call_ranges)
+                    if tiling_indice < 0 or tiling_indice >= len(call_ranges):
+                        continue
+                    if has_free_symbols(call_ranges):
+                        call_range = V.graph.sizevars.size_hint(
+                            call_ranges[tiling_indice], fallback=0
+                        )
+                        if call_range < factor_lowp:
+                            V.graph.sizevars.check_lt(call_range, factor_lowp)  # type: ignore[arg-type]
+                            tiling_factor = factor_lowp // 2
+                            break
+                    elif call_ranges[tiling_indice] < factor_lowp:
+                        tiling_factor = factor_lowp // 2
+                        break
+
+            if len(tiling_indices) == 1:
+                return [tiling_factor], tiling_indices
+            if len(tiling_indices) == 2:
+                return [tiling_factor, tiling_factor], tiling_indices
+        return [], []
+
+    def _select_tiling_indices(
+        self,
+        fn_list,
+        var_sizes_list,
+        tiling_factor,
+    ):
+        all_index = []
+        for fn, var_sizes in zip(fn_list, var_sizes_list):
+            rw = dependencies.extract_read_writes(fn, *var_sizes)
+            all_index += [dep.index for dep in itertools.chain(rw.reads, rw.writes)]
+        contig_vars = OrderedSet[int]()
+        contig_vars_list = []
+        non_contig_stride_const = OrderedSet[int]()
+        non_contig_stride_other = OrderedSet[int]()
+        for index in all_index:
+            for var in index.free_symbols:
+                if not re.search(r"^d\d+$", var.name):
+                    continue
+                stride = stride_at_vec_range(index, var, tiling_factor)
+                if stride == 0:
+                    continue
+                elif stride == 1:
+                    contig_vars.add(int(var.name[1:]))
+                    contig_vars_list.append(int(var.name[1:]))
+                elif all(symbol_is_type(s, SymT.SIZE) for s in stride.free_symbols):
+                    non_contig_stride_const.add(int(var.name[1:]))
+                else:
+                    non_contig_stride_other.add(int(var.name[1:]))
+        contig_only = contig_vars - non_contig_stride_const - non_contig_stride_other
+        group, reduction_group = max(var_sizes_list, key=lambda sizes: len(sizes[1]))
+        num_itervars = len(group) + len(reduction_group)
+        if len(contig_vars) == 0:
+            # no contiguous vars
+            return [num_itervars - 1]
+        if contig_only:
+            return sorted(contig_only)[-1:]
+        contig_and_const_stride = (
+            contig_vars & non_contig_stride_const
+        ) - non_contig_stride_other
+        contig_vars_sorted = sorted(contig_vars)
+        if (
+            len(contig_vars_sorted) == 2
+            and contig_vars_sorted[-1] in contig_and_const_stride
+            and contig_vars_sorted[-1] == num_itervars - 1
+        ):
+            return contig_vars_sorted
+        return sorted(contig_vars_sorted, key=contig_vars_list.count)[-1:]
+
+
+class CppKernelProxy(CppKernel):
+    # Subclass CppKernel, CppVecKernel, etc., to customize code generation.
+    # Override CppOverrides or CppVecOverrides to emit custom ops.
+    # Earlier, this meant copying codegen_functions() to use your subclasses.
+    # Now, use kernel_cls and vec_kernel_cls class attributes instead.
+    # This lets CppKernelProxy subclasses inject custom behavior cleanly.
+    # No need to duplicate codegen_functions() just to swap kernel classes.
+    kernel_cls: type[CppKernel] = CppKernel
+    vec_kernel_cls: type[CppVecKernel] = CppVecKernel
+    tile2d_kernel_cls: type[CppTile2DKernel] = CppTile2DKernel
+
+    def __init__(self, kernel_group):
+        super().__init__(kernel_group.args, kernel_group.ws.num_threads)
+        self.kernel_group = kernel_group
+        self.loop_nest = None
+        self.call_ranges = None
+        self.picked_vec_isa: cpu_vec_isa.VecISA = cpu_vec_isa.pick_vec_isa()
+        self.kernels: list[CppKernel] = []
+
+    def data_type_propagation(self, nodes):
+        for _node in nodes:
+            assert isinstance(_node, SchedulerNode)
+            DataTypePropagation.propagate_scheduler_node(_node)
+
+    # Check if all the nodes of a given fx graph can support BF16/FP16
+    def is_lowp_fp_scheduler(self, scheduler_node: SchedulerNode):
+        if not isinstance(scheduler_node._body, LoopBody):
+            return True
+        # Propagate the dtype to check if all the fx node is bf16/fp16
+        DataTypePropagation.propagate_scheduler_node(scheduler_node)
+        return (
+            get_loop_body_lowp_fp(scheduler_node._body)[0] is not None
+            and not get_loop_body_lowp_fp(scheduler_node._body)[1]
+        )
+
+    def legalize_lowp_fp_dtype_loopbody(self, loop_body: LoopBody):
+        def add_to_dtype(sub_graph: torch.fx.Graph):
+            def get_input_dtype(node: torch.fx.Node) -> Optional[torch.dtype]:
+                """Get input dtype for nodes that may consumes lowp fp dt"""
+                if node.target == "store":
+                    return V.graph.get_dtype(node.args[1])  # type: ignore[arg-type]
+                elif node.target == "to_dtype_bitcast":
+                    return node.args[-1]  # type: ignore[return-value]
+                elif node.target == "to_dtype":
+                    if len(node.args) > 3:
+                        return node.args[3]  # type: ignore[return-value]
+                    else:
+                        return node.kwargs.get("src_dtype", None)  # type: ignore[return-value]
+                else:
+                    return None
+
+            def get_output_dtype(node: torch.fx.Node) -> Optional[torch.dtype]:
+                """Get output dtype for nodes that may produce lowp fp dt"""
+                if node.target == "load":
+                    assert len(node.args) == 3
+                    return V.graph.get_dtype(node.args[1])  # type: ignore[arg-type]
+                elif node.target in ["to_dtype", "constant", "index_expr"]:
+                    return node.args[-1]  # type: ignore[return-value]
+                elif node.target == "to_dtype_bitcast":
+                    return node.args[2]  # type: ignore[return-value]
+                else:
+                    return None
+
+            def is_lowp_fp_source(node: torch.fx.Node, dt: torch.dtype):
+                """Check if the given node produces output with expected low precision floating point data type."""
+                assert dt in DTYPE_LOWP_FP
+                return get_output_dtype(node) == dt
+
+            def is_lowp_fp_sink(node: torch.fx.Node, dt: torch.dtype):
+                """Check if the given node accept input with expected low precision floating point data type."""
+                assert dt in DTYPE_LOWP_FP
+                if input_dtype := get_input_dtype(node):
+                    return input_dtype == dt
+                elif node.target == "to_dtype":
+                    # The `src_dtype` of a `to_dtype` node might miss, in which case the node accept any input dtype.
+                    return True
+                else:
+                    return False
+
+            def is_lowp_fp_source_no_promote(node: torch.fx.Node, dt: torch.dtype):
+                """Check if the node is a lowp fp sources which are all directly fed to ops that accepts lowp fp input
+                thus no need to promote to float
+                """
+                return is_lowp_fp_source(node, dt) and all(
+                    is_lowp_fp_sink(user, dt) for user in node.users
+                )
+
+            sub_graph_nodes = list(sub_graph.nodes)
+            to_lowp_fp_legalized_nodes = []
+            for _node in sub_graph_nodes:
+                if (
+                    _node.target in ["load", "index_expr"]
+                    and (dt := get_output_dtype(_node)) in DTYPE_LOWP_FP
+                ):
+                    # No need to promote to float if all users are ops that accepts lowp fp input
+                    # pyrefly: ignore [bad-argument-type]
+                    if all(is_lowp_fp_sink(user, dt) for user in _node.users):
+                        continue
+                    ops = _node.args[0]
+                    with sub_graph.inserting_after(_node):
+                        to_type_node = sub_graph.call_method(
+                            "to_dtype", args=(ops, _node, torch.float)
+                        )
+                        _node.replace_all_uses_with(
+                            to_type_node, lambda n: n is not to_type_node
+                        )
+                        # pyrefly: ignore [bad-assignment]
+                        metrics.cpp_to_dtype_count += 1
+                elif (
+                    _node.target == "store"
+                    and (dt := get_input_dtype(_node)) in DTYPE_LOWP_FP
+                ):
+                    ops, name, _, value_var, _ = _node.args
+                    # pyrefly: ignore [bad-argument-type]
+                    if is_lowp_fp_source_no_promote(value_var, dt):
+                        continue
+                    dtype = V.graph.get_dtype(name)
+                    with sub_graph.inserting_before(_node):
+                        to_type_node = sub_graph.call_method(
+                            "to_dtype", args=(ops, value_var, dtype)
+                        )
+                        _node.replace_input_with(value_var, to_type_node)
+                        # pyrefly: ignore [bad-assignment]
+                        metrics.cpp_to_dtype_count += 1
+                elif _node.target == "reduction":
+                    (
+                        ops,
+                        dtype,
+                        src_dtype,
+                        reduction_type,
+                        value,
+                    ) = _node.args
+                    if src_dtype in DTYPE_LOWP_FP:
+                        # Since we always convert the load/store value to float if the tensor is bfloat16/float16.
+                        # Therefore, the reduction should never work with bfloat16/float16 value. Hence, we update
+                        # the bfloat16/float16 reduction by
+                        #     1) updating the src_dtype to float
+                        # and 2) updating the dtype to float if it is bfloat16/float16.
+                        assert dtype in [
+                            torch.float,
+                            torch.bfloat16,
+                            torch.float16,
+                            torch.int64,
+                        ]
+                        _node.args = (
+                            ops,
+                            torch.float if dtype in DTYPE_LOWP_FP else dtype,
+                            torch.float,
+                            reduction_type,
+                            value,
+                        )
+                elif _node.target == "constant" and _node.args[-1] in DTYPE_LOWP_FP:
+                    # No need to promote to float if all users are ops that accepts lowp fp input
+                    (ops, value, dt) = _node.args
+                    if all(is_lowp_fp_sink(user, dt) for user in _node.users):  # type: ignore[arg-type]
+                        continue
+                    _node.args = (ops, value, torch.float)
+                elif _node.target == "to_dtype" and _node.args[-1] in DTYPE_LOWP_FP:
+                    # No need to promote to float if all users are ops that accepts lowp fp input
+                    (ops, x, dt) = _node.args
+                    if all(is_lowp_fp_sink(user, dt) for user in _node.users):  # type: ignore[arg-type]
+                        continue
+                    # The legalization always loads the BF16/FP16 tensor as FP32 for computation
+                    # and converts back to BF16/FP16 after the computation.
+                    # Hence, there should be no computation w/ BF16/FP16.
+                    # Therefore, we update the to_dtype by replacing the bf16/fp16 dtype with fp32.
+                    # Save the legalized to_dtype node for the elimination(eliminate_to_dtype step):
+                    #  1) Eliminate the redundant to_dtype node if we have a pattern as follows:
+                    #     graph():
+                    #       %lowp_fp_legalized = call_method[target=to_dtype](args = (%ops, %input, torch.float))
+                    #       %to_dtype2 = call_method[target=to_dtype](args = (%ops, %lowp_fp_legalized, torch.bfloat16/float16))
+                    # Regarding the first to_dtype, it is redundant because
+                    # the second to_type also converts to the torch.bfloat16/torch.float16.
+                    # Hence, we remove the first to_type.
+                    to_lowp_fp_legalized_nodes.append(_node)
+                    _node.args = (ops, x, torch.float)
+                elif _node.target == "to_dtype_bitcast":
+                    (ops, value_var, dtype, src_dtype) = _node.args
+
+                    # to_dtype_bitcast act as a lowp fp sink:
+                    # c10::bit_cast requires the source and target have the same bitwidth. Because the input tensor's
+                    # dtype could be promoted, e.g. from float16 to float, we have to cast the tensor to its original
+                    # source dtype before invoking bit_cast.
+                    if src_dtype in DTYPE_LOWP_FP:
+                        # No need to promote to float if it is a user of a lowp fp sources
+                        # which are all directly fed to ops that accepts lowp fp input
+                        if not is_lowp_fp_source_no_promote(value_var, src_dtype):
+                            with sub_graph.inserting_before(_node):
+                                to_type_node = sub_graph.call_method(
+                                    "to_dtype", args=(ops, value_var, src_dtype)
+                                )
+                                _node.replace_input_with(value_var, to_type_node)
+                                # pyrefly: ignore [bad-assignment]
+                                metrics.cpp_to_dtype_count += 1
+
+                    # to_dtype_bitcast act as a lowp fp source:
+                    # We also need to convert the bit-casted tensor back to float to make sure we keep using higher
+                    # precision values for the rest of the computation.
+                    if dtype in DTYPE_LOWP_FP:
+                        # No need to promote to float if all users are ops that accepts lowp fp input
+                        if not (
+                            all(is_lowp_fp_sink(user, dtype) for user in _node.users)
+                        ):
+                            ops = _node.args[0]
+                            with sub_graph.inserting_after(_node):
+                                to_type_node = sub_graph.call_method(
+                                    "to_dtype", args=(ops, _node, torch.float)
+                                )
+                                _node.replace_all_uses_with(
+                                    to_type_node, lambda n: n is not to_type_node
+                                )
+                                # pyrefly: ignore [bad-assignment]
+                                metrics.cpp_to_dtype_count += 1
+
+            def eliminate_to_dtype(sub_graph: torch.fx.Graph):
+                def _eliminate_duplicate_to_node(sub_graph: torch.fx.Graph):
+                    # Eliminate the redundant to_dtype node. Let's consider a pattern as follows:
+                    #   graph():
+                    #     %to_dtype1 = call_method[target=to_dtype](args = (%ops, %input, torch.float), kwargs = {})
+                    #     %to_dtype2 = call_method[target=to_dtype](args = (%ops, %to_dtype1, torch.float), kwargs = {})
+                    # Regarding the first to_dtype, it is redundant because the second to_type also converts to the
+                    # torch.float. Hence, we remove the first to_type
+                    def _used_by_to(to_node: torch.fx.Node):
+                        return all(usr.target == "to_dtype" for usr in to_node.users)
+
+                    all_to_nodes = [
+                        node for node in sub_graph.nodes if node.target == "to_dtype"
+                    ]
+                    all_to_nodes_and_users = [
+                        {node: node.users} for node in all_to_nodes if _used_by_to(node)
+                    ]
+                    for node_users in all_to_nodes_and_users:
+                        for node, users in node_users.items():
+                            if node in sub_graph.nodes and (
+                                all(usr.args[-1] == node.args[-1] for usr in users)
+                                or (
+                                    node in to_lowp_fp_legalized_nodes
+                                    and all(
+                                        usr.args[-1] in DTYPE_LOWP_FP for usr in users
+                                    )
+                                )
+                            ):
+                                val_node = node.all_input_nodes[-1]
+                                node.replace_all_uses_with(val_node)
+                                sub_graph.erase_node(node)
+
+                    # For debug mode, the graph of LoopBody will attach a new GraphModule as
+                    # owning_module for debugging while the release mode will not. The lint will
+                    # check whether the graph has owning_module to decide if it needs to check
+                    # call_module. LoopBody might contain get_index as a module call. But it
+                    # is just a function. Hence, it cannot pass the lint check for debug mode.
+                    # We bypass the check if the owning_module is None. Eventually, we should call
+                    # get_index via call_function but not call_module.
+                    if sub_graph.owning_module is None:
+                        sub_graph.lint()
+
+                _eliminate_duplicate_to_node(sub_graph)
+
+            eliminate_to_dtype(sub_graph)
+
+        sub_blocks = [loop_body.root_block] + list(loop_body.subblocks.values())
+        for sub_block in sub_blocks:
+            add_to_dtype(sub_block.graph)
+
+    def legalize_lowp_fp_dtype(self, nodes):
+        if all(
+            isinstance(_node, SchedulerNode) and self.is_lowp_fp_scheduler(_node)
+            for _node in nodes
+        ):
+            # Mark the load node to load bf16/fp16
+            for _node in nodes:
+                sub_blocks = [_node._body.root_block] + list(
+                    _node._body.subblocks.values()
+                )
+                for sub_block in sub_blocks:
+                    for fx_node in sub_block.graph.nodes:
+                        if fx_node.target in ["load", "store"]:
+                            assert fx_node.meta
+                            assert OptimizationContext.key in fx_node.meta
+                            opt_ctx: OptimizationContext = fx_node.meta[
+                                OptimizationContext.key
+                            ]
+                            assert opt_ctx.dtype in DTYPE_LOWP_FP
+
+            # Bypass the legalization as the kernel can run with bf16/fp16 directly
+            return
+
+        for _node in nodes:
+            assert isinstance(_node, SchedulerNode)
+            assert isinstance(_node._body, LoopBody)
+            body: LoopBody = _node._body
+            if not body.is_memory_copy():
+                self.legalize_lowp_fp_dtype_loopbody(body)
+
+    def codegen_functions(self, fn_list, var_sizes_list):
+        assert len(fn_list) == len(var_sizes_list)
+        kernel_group = self.kernel_group
+        group, reduction_group = max(var_sizes_list, key=lambda sizes: len(sizes[1]))
+
+        self.set_ranges(group, reduction_group)
+
+        def codegen_kernel(cls, *args):
+            with kernel_group.new_kernel(cls, *args) as kernel:
+                # Ugly hack to maintain the metrics kernel count since
+                # we only count in CppKernelProxy, not those contained in it
+                # pyrefly: ignore [bad-assignment]
+                metrics.generated_kernel_count -= 1
+
+                run(kernel)
+                return kernel
+
+        def run(kernel):
+            vars, reduction_vars = kernel.set_ranges(group, reduction_group)
+            in_suffix = False
+            for fn, var_sizes in zip(fn_list, var_sizes_list):
+                if var_sizes in [
+                    (group, reduction_group),
+                    (tuple(itertools.chain(group, reduction_group)), ()),
+                ]:
+                    assert not in_suffix
+                    fn(vars, reduction_vars)
+                else:
+                    in_suffix = True
+                    assert var_sizes == (
+                        group,
+                        (),
+                    ), f"unexpected group: {var_sizes} != {group}, {reduction_group}"
+                    # we can fuse in some extra pointwise into the suffix
+                    with kernel.write_to_suffix():
+                        fn(vars, ())
+
+        scalar_kernel = codegen_kernel(self.kernel_cls)
+        V.graph.removed_buffers |= scalar_kernel.removed_buffers
+        V.graph.inplaced_to_remove |= scalar_kernel.inplaced_to_remove
+        self.loop_nest = LoopNest.build(scalar_kernel)
+
+        if not self.picked_vec_isa or not self.itervars:
+            self.kernels = [scalar_kernel]
+            self.aggregate_reduction_buffers(False, None)
+            self.loop_nest.set_kernel(self)
+            return
+
+        # Kernels share the same global contexts like V.graph.wrapper_code, V.kernel.args.
+        # But the generated scalar kernel has updated these global contexts. Hence, the other kernels
+        # should not do this again to avoid context conflict. By now, we only control the
+        # config.inplace_buffers. In the future, we could maintain more contexts.
+        with torch._inductor.config.patch(inplace_buffers=False):
+            tiling_select = TilingSelect()
+            tiling_factors, tiling_indices = tiling_select.select_tiling(
+                fn_list, var_sizes_list
+            )
+            assert len(tiling_factors) == len(tiling_indices)
+            _inner_loop_reduction_outer_not = False
+            _outer_loop = None
+            if tiling_indices:
+                inner_loop_reduction = False
+                outer_loop_level = tiling_indices[0]
+                inner_loop_level = outer_loop_level + 1
+                if len(self.loop_nest.loops) > inner_loop_level:
+                    inner_loop_reduction = self.loop_nest.loops[
+                        inner_loop_level
+                    ].is_reduction
+                    outer_loop_reduction = self.loop_nest.loops[
+                        outer_loop_level
+                    ].is_reduction
+                    _inner_loop_reduction_outer_not = (
+                        inner_loop_reduction and not outer_loop_reduction
+                    )
+
+            if len(tiling_indices) == 1:
+                # pyrefly: ignore [bad-assignment]
+                metrics.generated_cpp_vec_kernel_count += 1
+                loop = self.loop_nest.tile(tiling_indices[0], factor=tiling_factors[0])
+                vec_kernel = codegen_kernel(
+                    self.vec_kernel_cls, tiling_factors[0], tiling_indices[0]
+                )
+                tail_size = loop.size - loop.tiled_size
+                vec_kernel.active_ranges = {loop.var: (0, loop.tiled_size)}
+                if config.cpp.enable_loop_tail_vec:
+                    tail_kernel = codegen_kernel(
+                        self.vec_kernel_cls,
+                        tiling_factors[0],
+                        tiling_indices[0],
+                        tail_size,
+                    )
+                else:
+                    tail_kernel = scalar_kernel
+                    scalar_kernel.inner_itervars = [loop.var]
+                tail_kernel.active_ranges = {loop.var: (loop.tiled_size, loop.size)}
+                self.kernels = [vec_kernel, tail_kernel]
+                _outer_loop = loop
+            elif len(tiling_indices) == 2:
+                assert (
+                    tiling_indices[1] == len(self.itervars) - 1
+                    and tiling_factors[0] == tiling_factors[1]
+                )
+
+                # pyrefly: ignore [bad-assignment]
+                metrics.generated_cpp_vec_kernel_count += 2
+                outer_loop = self.loop_nest.tile(
+                    tiling_indices[0], factor=tiling_factors[0]
+                )
+                outer_ranges = {
+                    "main": (0, outer_loop.tiled_size),
+                    "tail": (outer_loop.tiled_size, outer_loop.size),
+                }
+                outer_tail_size = outer_loop.size - outer_loop.tiled_size
+                inner_loop = self.loop_nest.tile(
+                    tiling_indices[1], factor=tiling_factors[0]
+                )
+                inner_ranges = {
+                    "main": (0, inner_loop.tiled_size),
+                    "tail": (inner_loop.tiled_size, inner_loop.size),
+                }
+                inner_tail_size = inner_loop.size - inner_loop.tiled_size
+                tile2d_kernel = codegen_kernel(
+                    self.tile2d_kernel_cls,
+                    tiling_factors[0],
+                    tiling_indices,
+                )
+                tile2d_kernel.active_ranges = {
+                    outer_loop.var: outer_ranges["main"],
+                    inner_loop.var: inner_ranges["main"],
+                }
+                tail_kernel = []
+                if config.cpp.enable_loop_tail_vec:
+                    for outer_r, inner_r in (
+                        ("main", "tail"),
+                        ("tail", "main"),
+                        ("tail", "tail"),
+                    ):
+                        _inner_tail_size = (
+                            inner_tail_size if inner_r == "tail" else None
+                        )
+                        _outer_tail_size = (
+                            outer_tail_size if outer_r == "tail" else None
+                        )
+                        kernel = codegen_kernel(
+                            self.tile2d_kernel_cls,
+                            tiling_factors[0],
+                            tiling_indices,
+                            _inner_tail_size,
+                            _outer_tail_size,
+                        )
+                        kernel.active_ranges = {
+                            outer_loop.var: outer_ranges[outer_r],
+                            inner_loop.var: inner_ranges[inner_r],
+                        }
+                        tail_kernel.append(kernel)
+                else:
+                    vec_kernel = codegen_kernel(
+                        self.vec_kernel_cls, tiling_factors[0], tiling_indices[0]
+                    )
+                    vec_kernel.active_ranges = {
+                        outer_loop.var: outer_ranges["main"],
+                        inner_loop.var: inner_ranges["tail"],
+                    }
+                    vec_kernel.inner_itervars = [inner_loop.var]
+                    tail_kernel.append(vec_kernel)
+                    scalar_kernel.active_ranges = {
+                        outer_loop.var: outer_ranges["tail"],
+                        inner_loop.var: (0, inner_loop.size),
+                    }
+                    scalar_kernel.inner_itervars = [inner_loop.var, outer_loop.var]
+                    tail_kernel.append(scalar_kernel)
+                self.kernels = [tile2d_kernel] + tail_kernel
+                _outer_loop = outer_loop
+            else:
+                self.kernels = [scalar_kernel]
+            self.aggregate_reduction_buffers(
+                _inner_loop_reduction_outer_not, _outer_loop
+            )
+            self.loop_nest.set_kernel(self)
+
+    def codegen_loop_bodies(self, loop_bodies, var_sizes_list):
+        for body in loop_bodies:
+            self.legalize_lowp_fp_dtype_loopbody(body)
+            DataTypePropagation.propagate_loopbody(body)
+        self.codegen_functions(loop_bodies, var_sizes_list)
+
+    def codegen_nodes(self, nodes: list[SchedulerNode]):
+        # Legalize BF16 node by adding to_dtype explicitly
+        self.legalize_lowp_fp_dtype(nodes)
+        self.data_type_propagation(nodes)
+        assert len(nodes) >= 1
+
+        def fn(node, *index_vars):
+            node.decide_inplace_update()
+            node.mark_run()
+            if isinstance(V.kernel, NullKernelHandler):
+                return node._body(*index_vars)
+            else:
+                return node.codegen(index_vars)
+
+        fn_list = [functools.partial(fn, node) for node in nodes]
+
+        if (
+            isinstance(V.local_buffer_context, LocalBufferContext)
+            and V.local_buffer_context.local_buffers
+        ):
+
+            def wrap_fn(fn):
+                wrapped_fn = V.local_buffer_context.localize_function(
+                    fn,
+                )
+                wrapped_fn.original_fn = fn
+                return wrapped_fn
+
+            fn_list = [wrap_fn(fn) for fn in fn_list]
+
+        var_sizes_list = [node.group[1] for node in nodes]
+        self.codegen_functions(fn_list, var_sizes_list)
+
+    def codegen_loops(self, code, worksharing):
+        self.codegen_loops_impl(self.loop_nest, code, worksharing)
+
+    def update_stores_with_parallel_reduction(self):
+        for kernel in self.kernels:
+            kernel.update_stores_with_parallel_reduction()
+
+    def gen_body(self, code: Optional[BracesBuffer] = None):
+        assert code is not None
+        if_prefix = "C10_LIKELY"
+        for kernel in self.kernels:
+            with contextlib.ExitStack() as stack:
+                if kernel.codegen_conditions(code, if_prefix):
+                    if_prefix = "C10_UNLIKELY"
+                    stack.enter_context(code.indent())
+                    code.splice(kernel.gen_body())
+
+    def aggregate_reduction_buffers(
+        self, inner_loop_reduction_outer_not: bool, outer_loop: Optional["LoopLevel"]
+    ):
+        """
+        CppKernel/CppVecKernel/CppTile2dKernel have reduction buffers themselves.
+        Here, we decide how to aggregate them together and place new reduction buffers
+        under CppKernelProxy.
+        """
+
+        def aggregate_reduction_prefix_suffix(outer_loop: "LoopLevel"):
+            assert len(self.kernels) >= 2
+            main_loop_kernel = self.kernels[0]
+            tail_loop_kernel = self.kernels[-1]
+            assert isinstance(main_loop_kernel, self.vec_kernel_cls)
+
+            # Prefix
+            if type(tail_loop_kernel) is self.kernel_cls:
+                # if tail loop kernel is a scalar kernel, we need to extend tmp_acc -> tmp_acc_arr[] to
+                # hold the temporary inner loop acc result for outer tail loop
+                tail_loop_kernel.finalize_reduction_prefix(
+                    main_loop_kernel.tiling_factor
+                )
+                main_loop_kernel.finalize_reduction_prefix()
+                self.reduction_prefix.splice(
+                    tail_loop_kernel.reduction_prefix
+                    + main_loop_kernel.reduction_prefix
+                )
+            else:
+                main_loop_kernel.finalize_reduction_prefix()
+                self.reduction_prefix.splice(main_loop_kernel.reduction_prefix)
+
+            # Suffix
+            suffix_buf = BracesBuffer()
+            with contextlib.ExitStack() as stack:
+                if main_loop_kernel.codegen_conditions(
+                    suffix_buf, "C10_LIKELY", outer_loop.var
+                ):
+                    stack.enter_context(suffix_buf.indent())
+                    suffix_buf.splice(main_loop_kernel.reduction_suffix)
+            with contextlib.ExitStack() as stack:
+                if tail_loop_kernel.codegen_conditions(
+                    suffix_buf, "C10_UNLIKELY", outer_loop.var
+                ):
+                    stack.enter_context(suffix_buf.indent())
+                    if type(tail_loop_kernel) is self.kernel_cls:
+                        reduction_vars = tail_loop_kernel.reduction_var_names
+                        for name in reduction_vars:
+                            new_name = f"{name}_arr[{outer_loop.var}_tail - {cexpr_index(outer_loop.tiled_size)}]"
+                            replace_acc_name(tail_loop_kernel.stores, name, new_name)
+                            replace_acc_name(
+                                tail_loop_kernel.reduction_suffix, name, new_name
+                            )
+                        # If tail loop kernel is a scalar kernel, use direct sum instead of cascade_sum_combine
+                        # as the reduction vars are extended: tmp_acc -> tmp_acc_arr[].
+                        replace_cascade_sum_with_add(tail_loop_kernel.stores)
+                        suffix_buf.splice(
+                            move_code_under_inner_loop(
+                                tail_loop_kernel.reduction_suffix,
+                                outer_loop.var,
+                                f"{outer_loop.var}_tail",
+                                outer_loop.tiled_size,
+                                outer_loop.size,
+                            )
+                        )
+                    else:
+                        suffix_buf.splice(tail_loop_kernel.reduction_suffix)
+            self.reduction_suffix = suffix_buf
+
+        main_kernel = self.kernels[0]
+        if inner_loop_reduction_outer_not:
+            assert outer_loop
+            aggregate_reduction_prefix_suffix(outer_loop)
+        else:
+            main_kernel.finalize_reduction_prefix()
+            self.reduction_prefix.splice(main_kernel.reduction_prefix)
+            self.reduction_suffix.splice(main_kernel.reduction_suffix)
+        self.parallel_reduction_prefix.splice(main_kernel.parallel_reduction_prefix)
+        self.parallel_reduction_suffix.splice(main_kernel.parallel_reduction_suffix)
+        self.local_reduction_init.splice(main_kernel.local_reduction_init)
+        self.local_reduction_stores.splice(main_kernel.local_reduction_stores)
+        self.non_parallel_reduction_prefix.splice(
+            main_kernel.non_parallel_reduction_prefix
+        )
+        self.non_parallel_reduction_suffix.splice(
+            main_kernel.non_parallel_reduction_suffix
+        )
+
+
+class OuterLoopFusedKernel(CppKernel):
+    def __init__(self, kernel_group):
+        super().__init__(kernel_group.args, kernel_group.ws.num_threads)
+        self.inner: list[LoopNest] = []
+
+    def decide_parallel_depth(self, max_parallel_depth, threads):
+        kernels_parallel_depth = []
+        nested_kernels: list[CppKernel] = [
+            loop_nest.get_kernel() for loop_nest in self.inner
+        ]
+        # TODO(leslie-fang-intel): only enable parallel within all outer loop levels.
+        for kernel in nested_kernels:
+            # For any ScalarKernel, VecKernel, or Tile2DKernel,
+            # they should all have the same call_ranges
+            call_ranges = kernel.call_ranges
+            assert call_ranges is not None
+            kernels_parallel_depth.append(
+                kernel.decide_parallel_depth(
+                    ParallelDepth(
+                        parallel_depth=(
+                            len(call_ranges) - max_parallel_depth.start_depth
+                        ),
+                        start_depth=max_parallel_depth.start_depth,
+                    ),
+                    threads,
+                ).parallel_depth
+            )
+        return ParallelDepth(
+            parallel_depth=min(
+                max_parallel_depth.parallel_depth, max(kernels_parallel_depth)
+            ),
+            start_depth=max_parallel_depth.start_depth,
+        )
+
+
+class ReasonFusedNodes(Enum):
+    SAME_VARS_REDUCE = "same_vars_reduce"
+    COMPATIBLE_REDUCTION = "compatible_reduction"
+    COMPATIBLE_RANGES_NO_REDUCTION = "compatible_ranges_no_reduction"
+
+
+class CppScheduling(BaseScheduling):
+    # Subclass CppKernelProxy to customize codegen without copying codegen_node().
+    # Use kernel_proxy_cls to inject custom proxies in CppScheduling subclasses.
+    # Avoid duplicating codegen_node() just to swap in a custom kernel proxy class.
+    kernel_proxy_cls: type[CppKernelProxy] = CppKernelProxy
+    # ctypes limits the number of args to 1024, refer to:
+    # https://github.com/python/cpython/commit/a285af7e626d1b81cf09f8b2bf7656f100bc1237
+    # We set a conservative threshold here.
+    MAX_FUSED_KERNEL_ARGS_NUM = 500
+    backend_features = OrderedSet(
+        [
+            BackendFeature.INPLACE_BUFFERS,
+            BackendFeature.REDUCE_TO_SINGLE_ELEMENT,
+        ]
+    )
+
+    @classmethod
+    def get_backend_features(cls, device: torch.device) -> OrderedSet[BackendFeature]:
+        return cls.backend_features
+
+    def __init__(self, scheduler):
+        super().__init__(scheduler)
+        if scheduler:
+            self.reset_kernel_group()
+        self._ready_to_flush = False
+
+    def _set_flush_status(self, status: bool):
+        self._ready_to_flush = status
+
+    def group_fn(self, sizes):
+        return tuple(tuple(map(V.graph.sizevars.simplify, s)) for s in sizes)
+
+    def reset_kernel_group(self):
+        self.kernel_group = KernelGroup()
+
+    def fuse(self, node1, node2):
+        if node1.is_foreach() or node2.is_foreach():
+            return ForeachKernelSchedulerNode.fuse(node1, node2)
+        elif node1.is_template():
+            assert not node2.is_template()
+            return FusedSchedulerNode.fuse(node1, node2)
+        else:
+            if (
+                self._why_fuse_nodes(node1, node2)
+                == ReasonFusedNodes.COMPATIBLE_RANGES_NO_REDUCTION
+            ):
+                assert isinstance(node1, (SchedulerNode, FusedSchedulerNode))
+                assert isinstance(node2, (SchedulerNode, FusedSchedulerNode))
+
+                _, (vars1, reduce1) = node1.group
+                _, (vars2, reduce2) = node2.group
+                assert reduce1 == () and reduce2 == (), (reduce1, reduce2)
+
+                def get_indexing_ranges_exprs(node):
+                    if isinstance(node, FusedSchedulerNode):
+                        assert len(node.snodes) > 0, node.snodes
+                        var_ranges = None
+                        indexing_exprs = OrderedSet[Any]()
+                        for snode in node.snodes:
+                            v, exprs = get_indexing_ranges_exprs(snode)
+                            if var_ranges is None:
+                                var_ranges = v
+                            assert var_ranges == v, (var_ranges, v, node.snodes)
+                            indexing_exprs.update(exprs)
+                        return var_ranges, list(indexing_exprs)
+                    else:
+                        assert isinstance(node, SchedulerNode)
+                        comp_buffer = node.node
+                        assert isinstance(comp_buffer, ir.ComputedBuffer)
+                        _, body, _ = comp_buffer.get_default_sizes_body()
+                        return body.var_ranges, list(body.indexing_exprs.values())
+
+                node_to_recomp = node1 if len(vars1) < len(vars2) else node2
+                assert isinstance(node_to_recomp, SchedulerNode)
+
+                ref_node = node2 if len(vars1) < len(vars2) else node1
+
+                ref_indexing_constraints = get_indexing_ranges_exprs(ref_node)
+
+                node_to_recomp.recompute_size_and_body(
+                    extra_indexing_constraints=ref_indexing_constraints
+                )
+
+                _, (vars1, _) = node1.group
+                _, (vars2, _) = node2.group
+
+                if vars1 == vars2:
+                    return FusedSchedulerNode.fuse(node1, node2)
+
+                # recompute ref_node if its ranges are also changed
+                node_to_recomp_indexing_constraints = get_indexing_ranges_exprs(
+                    node_to_recomp
+                )
+                if isinstance(ref_node, SchedulerNode):
+                    ref_node.recompute_size_and_body(
+                        extra_indexing_constraints=node_to_recomp_indexing_constraints
+                    )
+                else:
+                    assert isinstance(ref_node, FusedSchedulerNode)
+                    for snode in ref_node.snodes:
+                        assert isinstance(snode, SchedulerNode)
+                        snode.recompute_size_and_body(
+                            extra_indexing_constraints=node_to_recomp_indexing_constraints
+                        )
+                    ref_node = FusedSchedulerNode(ref_node.scheduler, ref_node.snodes)
+
+                _, (vars1, _) = node1.group
+                _, (vars2, _) = node2.group
+                assert vars1 == vars2, (vars1, vars2)
+                return FusedSchedulerNode.fuse(node1, node2)
+            elif self.can_fuse_vertical_outer_loop(node1, node2):
+                return OuterLoopFusedSchedulerNode.fuse(
+                    node1, node2, self._get_outer_loop_fusion_depth(node1, node2)
+                )
+            else:
+                return FusedSchedulerNode.fuse(node1, node2)
+
+    def _why_fuse_nodes(self, node1, node2) -> Optional[ReasonFusedNodes]:
+        _, (vars1, reduce1) = node1.group
+        _, (vars2, reduce2) = node2.group
+
+        if vars1 == vars2 and reduce1 == reduce2:
+            return ReasonFusedNodes.SAME_VARS_REDUCE
+        if reduce1 == () and vars1 == vars2 + reduce2:
+            return ReasonFusedNodes.COMPATIBLE_REDUCTION
+        if self._can_fuse_nodes_with_compatible_ranges(node1, node2):
+            return ReasonFusedNodes.COMPATIBLE_RANGES_NO_REDUCTION
+        # TODO(jansel): allow fusion pointwise (vars1, ()) suffix?
+        return None
+
+    def _can_fuse_nodes_with_compatible_ranges(self, node1, node2):
+        # Here we try to fuse SchedulerNode/FusedSchedulerNode with compatible ranges
+        # e.g. (s0, s1, s2) and (s0 * s1 * s2)
+        _, (vars1, reduce1) = node1.group
+        _, (vars2, reduce2) = node2.group
+
+        c1 = reduce1 == () and reduce2 == ()
+        c2 = math.prod(vars1) == math.prod(vars2)
+        c3 = len(vars1) == 1 or len(vars2) == 1
+        if not (c1 and c2 and c3):
+            return False
+
+        node_to_recomp = node1 if len(vars1) < len(vars2) else node2
+        ref_node = node2 if len(vars1) < len(vars2) else node1
+
+        # We can not recompute sizes and body for nodes other than SchedulerNode
+        # TODO: we can extend fusion support with compatible ranges for FusedSchedulerNode
+        if isinstance(node_to_recomp, FusedSchedulerNode):
+            return False
+
+        # It may happen that node1 and node2 compatible number of elements
+        # but different original ranges, for example:
+        # {d0: s0, d1: s1, d2: s2} vs {d0: s0*s1*s2}
+        # See https://github.com/pytorch/pytorch/pull/120077/files#r1500427848 for more details
+        # TODO: we can fix if it allows us to CSE at least one of the variables
+
+        assert isinstance(node_to_recomp, SchedulerNode)
+        if isinstance(node_to_recomp.node, ir.TemplateBuffer):
+            return False
+        assert isinstance(node_to_recomp.node, ir.ComputedBuffer)
+        # node.data.get_size() is a cheaper version of node.get_read_writes().var_ranges
+        # but without variable name
+        ranges2 = node_to_recomp.node.data.get_size()
+        ranges1 = None
+        if isinstance(ref_node, FusedSchedulerNode):
+            ranges_set = OrderedSet[tuple[Any, ...]]()
+            for snode in ref_node.snodes:
+                if isinstance(snode.node, ir.TemplateBuffer):
+                    break
+                assert isinstance(snode.node, ir.ComputedBuffer)
+                ranges_set.add(tuple(snode.node.data.get_size()))
+
+            if len(ranges_set) != 1:
+                return False
+
+            ranges1 = list(next(iter(ranges_set)))
+        else:
+            assert isinstance(ref_node, SchedulerNode)
+            assert isinstance(ref_node.node, ir.ComputedBuffer)
+            ranges1 = ref_node.node.data.get_size()  # type: ignore[assignment]
+
+        if ranges1 != ranges2:
+            return False
+
+        return True
+
+    def _can_fuse_horizontal_impl(self, node1, node2):
+        assert isinstance(node1, (FusedSchedulerNode, SchedulerNode))
+        assert isinstance(node2, (FusedSchedulerNode, SchedulerNode))
+        if any(
+            isinstance(node, OuterLoopFusedSchedulerNode) for node in (node1, node2)
+        ):
+            return False
+        return self._why_fuse_nodes(node1, node2) is not None
+
+    def can_fuse_horizontal(self, node1, node2):
+        if node1.is_template() or node2.is_template():
+            return False
+        if (
+            len(node1.get_nodes()) + len(node2.get_nodes())
+            > config.cpp.max_horizontal_fusion_size
+        ):
+            return False
+
+        return self._can_fuse_horizontal_impl(node1, node2)
+
+    def can_fuse_multi_outputs_template(
+        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
+    ) -> bool:
+        if template_buf := node1.get_template_node():
+            return (
+                isinstance(template_buf.layout, ir.MultiOutputLayout)
+                and isinstance(node2.node, ir.MultiOutput)
+                and len(node2.node.inputs) == 1
+                and node2.node.inputs[0].get_name() == template_buf.name  # type: ignore[union-attr]
+            )
+        return False
+
+    def _get_outer_loop_fusion_depth(self, node1, node2):
+        DISABLE_OUTER_LOOP_FUSION = 0
+        if not all(
+            type(node)
+            in (OuterLoopFusedSchedulerNode, FusedSchedulerNode, SchedulerNode)
+            for node in (node1, node2)
+        ):
+            return DISABLE_OUTER_LOOP_FUSION
+
+        _node1 = (
+            node1.get_outer_nodes()[-1]
+            if isinstance(node1, OuterLoopFusedSchedulerNode)
+            else node1
+        )
+        assert isinstance(_node1, (FusedSchedulerNode, SchedulerNode))
+        _node2 = (
+            node2.get_outer_nodes()[0]
+            if isinstance(node2, OuterLoopFusedSchedulerNode)
+            else node2
+        )
+        assert isinstance(_node2, (FusedSchedulerNode, SchedulerNode))
+
+        _, (vars1, reduce1) = _node1.group
+        _, (vars2, reduce2) = _node2.group
+        if vars1 == () and vars2 == () and reduce1 != () and reduce2 != ():
+            # Reduction only
+            return DISABLE_OUTER_LOOP_FUSION
+        if all(type(node) is OuterLoopFusedSchedulerNode for node in (node1, node2)):
+            return (
+                node1.outer_loop_fusion_depth
+                if node1.outer_loop_fusion_depth == node2.outer_loop_fusion_depth
+                else DISABLE_OUTER_LOOP_FUSION
+            )
+        outer_loop_fusion_depth = min(len(vars1), len(vars2))
+        if (
+            outer_loop_fusion_depth >= 1
+            and vars1[:outer_loop_fusion_depth] == vars2[:outer_loop_fusion_depth]
+        ):
+            if any(
+                type(node) is OuterLoopFusedSchedulerNode for node in (node1, node2)
+            ):
+                _compare_node = (
+                    node1 if type(node1) is OuterLoopFusedSchedulerNode else node2
+                )
+                if _compare_node.outer_loop_fusion_depth == outer_loop_fusion_depth:
+                    # Same outer loop fusion depth as prev nodes in OuterLoopFusedSchedulerNode
+                    return outer_loop_fusion_depth
+                else:
+                    return DISABLE_OUTER_LOOP_FUSION
+            else:
+                # First 2 nodes to generate OuterLoopFusedSchedulerNode
+                return outer_loop_fusion_depth
+        return DISABLE_OUTER_LOOP_FUSION
+
+    def can_fuse_vertical_outer_loop(self, node1, node2):
+        return (
+            not node1.is_template()
+            and not node2.is_template()
+            and node1.get_operation_names() & node2.ancestors
+            and not (
+                self._can_fuse_horizontal_impl(node1, node2)
+                and not node1.is_reduction()
+            )
+            and self._get_outer_loop_fusion_depth(node1, node2) >= 1
+        )
+
+    def get_fusion_pair_priority(self, node1, node2):
+        if self.can_fuse_vertical_outer_loop(node1, node2):
+            # Outer loop fusion with lower priority
+            return 1
+        else:
+            return 0
+
+    def can_fuse_vertical(self, node1, node2):
+        if node2.is_template():
+            # TODO(jgong5): support pre-op fusion with template
+            return False
+        if node1.is_template():
+            template_fusion_supported, _ = template_fusion_with_epilogues_supported(
+                node1, [node2]
+            )
+            return not node2.is_reduction() and template_fusion_supported
+        return (
+            self._can_fuse_horizontal_impl(node1, node2) and not node1.is_reduction()
+        ) or self.can_fuse_vertical_outer_loop(node1, node2)
+
+    def try_loop_split(self, nodes: list[SchedulerNode]):
+        """
+        Apply loop split optimization.
+        When one of the indexing_exprs contains a division, we eliminate the division by splitting the loop
+        to avoid non-contiguous loads, subject to the following conditions:
+            1. No reduction and no mudular index for all nodes.
+            2. The indexing_exprs of all nodes contain only one (or more, but all the same) division,
+               where the divisor is an integer and not too small (the divisor > 8), the dividend is
+               one of the iter_vars, and this var, i.e. the dimension that needs to be split, is
+               contiguous in all other indexing_exprs.
+
+        For example, if the node's var_ranges: {z0: 2, z1: 9216, z2: 960} and indexing_exprs:
+        {'index0': 8847360*z0 + 960*z1 + z2, 'index1': 32*z0 + (z2//30), 'index2': z2},
+        we will split z2 -> 30*z2 + z3, then the node's var_ranges will be changed to
+        {z0: 2, z1: 9216, z2: 32, z3: 30} and indexing_exprs will be changed to
+        {'index0': 8847360*z0 + 960*z1 + 30*z2 + z3, 'index1': 32*z0 + z2, 'index2': 30*z2 + z3}.
+        """
+
+        # No reduction and no mudular
+        if any(
+            len(node.group[1][1]) != 0
+            or any(
+                expr.has(ModularIndexing) for expr in node._body.indexing_exprs.values()
+            )
+            for node in nodes
+        ):
+            return nodes
+
+        split_var = None
+        split_number = None
+        num_div = 0
+        div_expr_ = None
+        match_div = False
+        matched_node = None
+
+        for node in nodes:
+            assert isinstance(node.node, ir.ComputedBuffer)
+            _, original_body, _ = node.node.get_default_sizes_body()
+            for name, expr in original_body.indexing_exprs.items():
+                if not isinstance(expr, sympy.Expr):
+                    continue
+                for div_expr in expr.find(FloorDiv):
+                    if (
+                        any(div_expr.has(var) for var in original_body.iter_vars)
+                        and div_expr != div_expr_
+                    ):
+                        div_expr_ = div_expr
+                        num_div += 1
+                    if num_div > 1:
+                        return nodes
+                    if (
+                        isinstance(div_expr.args[1], sympy.core.numbers.Integer)
+                        and div_expr.args[0] in original_body.iter_vars
+                        and name is not None
+                        and all(
+                            stride_at_vec_range(expr_, div_expr.args[0]) in (0, 1)
+                            for name_, expr_ in original_body.indexing_exprs.items()
+                            if name_ != name
+                        )
+                        and div_expr.args[1] > 8
+                    ):
+                        split_var = div_expr.args[0]
+                        split_number = div_expr.args[1]
+                        match_div = True
+                        matched_node = node
+
+        # Only one node contains a division, and the split dimension is contiguous in all other indexing_exprs.
+        if not match_div:
+            return nodes
+
+        extra_indexing_constraints = None
+
+        def loop_split(sizes, body, vars):
+            index_size, reduce_size = sizes
+            index_vars, reduce_vars = vars
+            split_idx = index_vars.index(split_var)
+            new_index_size = index_size.copy()
+            new_index_size[split_idx] = index_size[split_idx] // split_number
+            new_index_size.insert(split_idx + 1, split_number)
+            (new_index_vars, _), var_ranges = dependencies.index_vars_no_squeeze(
+                new_index_size, reduce_size, prefix="y"
+            )
+            iter_vars = new_index_vars.copy()
+            divisor_var = iter_vars.pop(split_idx + 1)
+            iter_vars[split_idx] = split_number * iter_vars[split_idx] + divisor_var
+            body = ir.LoopBody(
+                body, [iter_vars, reduce_vars], var_ranges, new_index_vars, reduce_vars
+            )
+            nonlocal extra_indexing_constraints
+            if not extra_indexing_constraints:
+                extra_indexing_constraints = (
+                    body.var_ranges,
+                    list(body.indexing_exprs.values()),
+                )
+            return (
+                (new_index_size, reduce_size),
+                body,
+                (new_index_vars, reduce_vars),
+            )
+
+        # Here decide the final loop order
+        for node in nodes:
+            if node == matched_node:
+                node.recompute_size_and_body(recompute_sizes_body_func=loop_split)
+        for node in nodes:
+            if node != matched_node:
+                node.recompute_size_and_body(
+                    extra_indexing_constraints=extra_indexing_constraints,
+                    recompute_sizes_body_func=loop_split,
+                )
+
+        return nodes
+
+    def codegen_outer_loop_node(
+        self,
+        node: OuterLoopFusedSchedulerNode,
+    ):
+        """
+        Generate the code for the outer loop fused scheduler node.
+        1. Codegen with fused outer loop: depends on the analysis of
+            the outer loop fused scheduler node, with or without the local buffer.
+        2. If failed, fallback to standard codegen.
+        """
+        kernel_group = self.kernel_group
+        generated_cpp_vec_kernel_count = metrics.generated_cpp_vec_kernel_count
+        cpp_kernel_proxy_list: list[self.kernel_proxy_cls] = []  # type: ignore[name-defined]
+        nodes_list: list[list[SchedulerNode]] = []
+        assert isinstance(node, OuterLoopFusedSchedulerNode)
+
+        def try_outer_loop_fusion_with_local_buf(node: OuterLoopFusedSchedulerNode):
+            """
+            Codegen code with fused outer loop and local Buffer.
+            """
+            assert isinstance(node, OuterLoopFusedSchedulerNode)
+            cpp_kernel_proxy_list.clear()
+            nodes_list.clear()
+
+            def get_call_ranges(node: BaseSchedulerNode):
+                assert isinstance(node, (SchedulerNode, FusedSchedulerNode))
+                nodes: list[SchedulerNode] = node.get_nodes()  # type: ignore[assignment]
+                _, (group, reduction_group) = max(
+                    nodes, key=lambda x: int(x.is_reduction())
+                ).group
+                call_ranges = tuple(group) + tuple(reduction_group)
+                return call_ranges
+
+            local_buffers: list[ir.Buffer] = []
+            # Map local buffer name to a list of global buffers
+            local_to_global_buffers: dict[str, list[ir.Buffer]] = {}
+            if all(
+                len(get_call_ranges(_node)) == node.outer_loop_fusion_depth + 1
+                for _node in node.get_outer_nodes()
+            ):
+                # Ref to the typical case of local buffer in
+                # https://github.com/pytorch/pytorch/blob/1115a25c36340554442f28f9570abd42f0aface2/aten/src/ATen/native/cpu/SoftMaxKernel.cpp#L159 # noqa: B950
+                # where the buffer is with size of last dim and contiguous.
+                # Only support this typical case at first.
+                visited_scheduler_nodes: OrderedSet[str] = OrderedSet()
+                for scheduler_node in node.get_nodes():
+                    # all users inside same OuterLoopFusedSchedulerNode
+                    assert isinstance(scheduler_node, SchedulerNode)
+                    visited_scheduler_nodes.add(scheduler_node.get_name())
+                    if (
+                        scheduler_node.is_reduction()
+                        or len(scheduler_node.get_outputs()) != 1
+                    ):
+                        continue
+
+                    scheduler_buffer = scheduler_node.get_outputs()[0]
+                    if all(
+                        user.node in node.get_nodes() for user in scheduler_buffer.users
+                    ):
+                        global_buffer = scheduler_buffer.node
+                        assert isinstance(global_buffer, ir.ComputedBuffer)
+                        global_buffer_layout = global_buffer.get_layout()
+                        size_offset = node.outer_loop_fusion_depth - len(
+                            get_call_ranges(scheduler_node)
+                        )
+
+                        def is_all_write_read_contiguous():
+                            contiguous_index_expr = 0
+                            stride = 1
+                            for var, range in reversed(
+                                # pyrefly: ignore [missing-attribute]
+                                scheduler_node._body.var_ranges.items()
+                            ):
+                                contiguous_index_expr += stride * var
+                                stride *= range
+                            # pyrefly: ignore [missing-attribute]
+                            write_index_expr = scheduler_node._body.get_write_expr(
+                                scheduler_buffer.get_name()
+                            )
+
+                            def is_contiguous_index(x):
+                                return x == contiguous_index_expr
+
+                            return is_contiguous_index(write_index_expr) and all(
+                                isinstance(user.node, SchedulerNode)
+                                and is_contiguous_index(
+                                    user.node._body.get_read_expr(
+                                        scheduler_buffer.get_name()
+                                    ),
+                                )
+                                for user in scheduler_buffer.users
+                            )
+
+                        if not (
+                            global_buffer_layout.is_contiguous()
+                            and is_all_write_read_contiguous()
+                        ):
+                            continue
+                        # Local Buffer is a view of global buffer
+                        local_buffer_stride: list[int] = []
+                        stride = global_buffer_layout.stride[-1]
+                        local_buffer_size = get_call_ranges(scheduler_node)[
+                            size_offset:
+                        ]
+                        for sz in reversed(local_buffer_size):
+                            local_buffer_stride.insert(0, stride)
+                            stride *= sz
+                        local_buffer_layout = ir.FixedLayout(
+                            global_buffer_layout.device,
+                            global_buffer_layout.dtype,
+                            local_buffer_size,
+                            local_buffer_stride,
+                        )
+
+                        def try_share_local_buffer(local_buffer_layout, local_buffers):
+                            for local_buf in local_buffers:
+                                if local_buffer_layout == local_buf.layout and all(
+                                    all(
+                                        user.node.get_name() in visited_scheduler_nodes
+                                        for user in V.graph.scheduler.name_to_buf[
+                                            global_buffer.name
+                                        ].users
+                                    )
+                                    for global_buffer in local_to_global_buffers[
+                                        local_buf.name
+                                    ]
+                                    if global_buffer.name is not None
+                                ):
+                                    return local_buf
+                            return None
+
+                        local_buf_prefix = "local_buffer_data"
+                        # Share existing local buffer
+                        local_buffer_used = try_share_local_buffer(
+                            local_buffer_layout, local_buffers
+                        )
+                        if not local_buffer_used:
+                            # Create new local buffer
+                            local_buffer_used = ir.Buffer(
+                                name=f"{local_buf_prefix}_{len(local_buffers)}",
+                                layout=local_buffer_layout,
+                            )
+                            local_buffers.append(local_buffer_used)
+                            local_to_global_buffers[local_buffer_used.name] = []  # type: ignore[index]
+                        # pyrefly: ignore [index-error]
+                        local_to_global_buffers[local_buffer_used.name].append(
+                            global_buffer,
+                        )
+
+            with LocalBufferContext(kernel_group.args) as scope:
+                if len(local_buffers) > 0:
+                    for local_buffer in local_buffers:
+                        assert local_buffer.name is not None
+                        scope.add_local_buffer(
+                            local_buffer, local_to_global_buffers[local_buffer.name]
+                        )
+                for _node in node.get_outer_nodes():
+                    assert isinstance(_node, (FusedSchedulerNode, SchedulerNode))
+                    cpp_kernel_proxy = self.kernel_proxy_cls(kernel_group)
+                    cpp_kernel_proxy.codegen_nodes(_node.get_nodes())  # type: ignore[arg-type]
+                    cpp_kernel_proxy_list.append(cpp_kernel_proxy)
+                    nodes_list.append(_node.get_nodes())  # type: ignore[arg-type]
+
+                if not node.check_outer_fusion_loop_level_attr(
+                    cpp_kernel_proxy_list, node.outer_loop_fusion_depth
+                ):
+                    for removed_buffer in scope.removed_buffers:
+                        # Restore the removed buffers by this context before
+                        # fallback to codegen without using Local Buffer
+                        V.graph.removed_buffers.remove(removed_buffer)
+                    return False
+                metrics.cpp_outer_loop_fused_inner_counts.append(
+                    metrics.CppOuterLoopFusedCount(
+                        len(cpp_kernel_proxy_list),
+                        local_buffer_number=len(scope.local_buffers),
+                    )
+                )
+                outer_fusion_cpp_kernel_proxy = node.merge_outer_fusion_kernels(
+                    cpp_kernel_proxy_list,
+                )
+                kernel_group.finalize_kernel(
+                    outer_fusion_cpp_kernel_proxy,
+                    [*itertools.chain.from_iterable(nodes_list)],
+                )
+
+            return True
+
+        if not try_outer_loop_fusion_with_local_buf(node):
+            # Reset generated_cpp_vec_kernel_count to codegen again
+            metrics.generated_cpp_vec_kernel_count = generated_cpp_vec_kernel_count
+            cpp_kernel_proxy_list.clear()
+            nodes_list.clear()
+            # Similar as comment in
+            # https://github.com/pytorch/pytorch/blob/469383755fe416eb1c41fa724762ad3eaecdff07/torch/_inductor/codegen/cpp.py#L3269-L3272
+            # Kernels share the same global contexts like V.graph.wrapper_code, V.kernel.args.
+            with torch._inductor.config.patch(inplace_buffers=False):
+                for _node in node.get_outer_nodes():
+                    assert isinstance(_node, (FusedSchedulerNode, SchedulerNode))
+                    _nodes: list[SchedulerNode] = _node.get_nodes()  # type: ignore[assignment]
+                    cpp_kernel_proxy = self.kernel_proxy_cls(kernel_group)
+                    cpp_kernel_proxy.codegen_nodes(_nodes)
+                    kernel_group.finalize_kernel(cpp_kernel_proxy, _nodes)
+
+    def codegen_node(
+        self,
+        node: Union[OuterLoopFusedSchedulerNode, FusedSchedulerNode, SchedulerNode],
+    ):
+        """
+        Turn an set of pre-fused nodes into a C++ kernel.
+        """
+        kernel_group = self.kernel_group
+
+        if isinstance(node, OuterLoopFusedSchedulerNode):
+            self.codegen_outer_loop_node(node)
+        else:
+            nodes: list[SchedulerNode] = node.get_nodes()  # type: ignore[assignment]
+            nodes = self.try_loop_split(nodes)
+            cpp_kernel_proxy = self.kernel_proxy_cls(kernel_group)
+            cpp_kernel_proxy.codegen_nodes(nodes)
+            kernel_group.finalize_kernel(cpp_kernel_proxy, nodes)
+
+        args_num = self._get_scheduled_num_args()
+        if args_num > CppScheduling.MAX_FUSED_KERNEL_ARGS_NUM:
+            self._set_flush_status(True)
+
+    def is_cpp_template(self, node: BaseSchedulerNode) -> bool:
+        return isinstance(node, SchedulerNode) and isinstance(
+            node.node, ir.CppTemplateBuffer
+        )
+
+    def codegen_template(
+        self,
+        template_node: BaseSchedulerNode,
+        epilogue_nodes: Sequence[BaseSchedulerNode],
+        prologue_nodes: Sequence[BaseSchedulerNode],
+    ):
+        """
+        Codegen a CPP template, possibly with fused epilogues
+        """
+        assert not prologue_nodes
+
+        # remove MultiOutput from epilogue_nodes
+        epilogue_nodes = [
+            epilogue_node
+            for epilogue_node in epilogue_nodes
+            if isinstance(epilogue_node, (SchedulerNode, FusedSchedulerNode))
+        ]
+        # The counter cpp_templated_kernel_counter is used for verifying if a
+        # a templated kernel was successfully compiled in a UT
+        counters["inductor"]["cpp_templated_kernel_counter"] += 1
+        counters["inductor"]["cpp_epilogue_fusion_counter"] += len(epilogue_nodes)
+        assert self.is_cpp_template(template_node), (
+            "Template node passed to CppScheduler.codegen_template must be a SchedulerNode that wraps a CppTemplateBuffer"
+        )
+        template_node = cast(SchedulerNode, template_node)
+        _, (_, rnumel) = template_node.group
+        assert rnumel == ()
+        ctb: ir.CppTemplateBuffer = cast(ir.CppTemplateBuffer, template_node.node)
+        epilogue_ir_nodes: list[Optional[ir.Operation]] = [
+            n.node for n in epilogue_nodes
+        ]
+        assert all(isinstance(n, ir.ComputedBuffer) for n in epilogue_ir_nodes), (
+            "Epilogue nodes must all be instances of ir.ComputedBuffer"
+        )
+
+        def template_buffer_has_other_users(
+            template_buffer, outputs_by_name, epilogue_nodes
+        ):
+            if not epilogue_nodes:
+                return False
+
+            assert template_buffer.get_name() in outputs_by_name
+            users = outputs_by_name[template_buffer.get_name()].users
+            return not all(
+                isinstance(user.node, BaseSchedulerNode)
+                and user.node.node in epilogue_nodes
+                for user in users
+            )
+
+        flag_template_buffer_has_other_users = template_buffer_has_other_users(
+            ctb, template_node.outputs_by_name, epilogue_ir_nodes
+        )
+        kernel, render = ctb.make_kernel_render(  # type: ignore[misc]
+            ctb,
+            flag_template_buffer_has_other_users=flag_template_buffer_has_other_users,
+            epilogue_nodes=epilogue_ir_nodes,
+        )
+        with kernel:
+            if not is_multi_outputs_template(template_node.node):
+                template_node.mark_run()  # type: ignore[attr-defined]
+            for node in epilogue_nodes:
+                node.mark_run()  # type: ignore[attr-defined]
+            src_code = render()
+
+        with V.set_kernel_handler(kernel):
+            node_schedule = [template_node, *epilogue_nodes]
+            kernel_name = self.define_kernel(src_code, node_schedule, kernel.args)
+
+        if is_multi_outputs_template(template_node.node):
+            # For multi outputs template, allocate buffers for each output after the epilogue
+            # codegen to which determines if the buffer has been removed.
+            assert len(template_node.outputs) == 1, (
+                "Multi outputs template should be with 1 output template buffer of MultiOutputLayout"
+            )
+            for user in template_node.outputs[0].users:
+                assert isinstance(user.node, ExternKernelSchedulerNode), (
+                    "Multi outputs template should be with ExternKernelSchedulerNode"
+                )
+                assert isinstance(user.node.node, ir.MultiOutput), (
+                    "Multi outputs template has multi users with MultiOutput"
+                )
+                user.node.mark_run()
+
+        self.codegen_comment(node_schedule, kernel_name)
+        kernel.call_kernel(kernel_name, ctb)
+        V.graph.removed_buffers |= kernel.removed_buffers
+        self.free_buffers_in_scheduler()
+
+    def _get_scheduled_num_args(self):
+        return self.kernel_group.get_num_args()
+
+    def ready_to_flush(self):
+        return self._ready_to_flush
+
+    def codegen_sync(self):
+        pass
+
+    def define_kernel(self, src_code, nodes, kernel_args=None):
+        wrapper = V.graph.wrapper_code
+        if src_code in wrapper.src_to_kernel:
+            kernel_name = wrapper.src_to_kernel[src_code]
+        else:
+            fused_name = (
+                get_fused_kernel_name(nodes, config.cpp.descriptive_names)
+                if config.cpp.descriptive_names
+                else ""
+            )
+            kernel_name = "_".join(["cpp", fused_name, wrapper.next_kernel_suffix()])
+            wrapper.src_to_kernel[src_code] = kernel_name
+            kernel_decl_name = kernel_name if V.graph.cpp_wrapper else "kernel"
+            src_code = src_code.replace(str(Placeholder.KERNEL_NAME), kernel_decl_name)
+            src_code = src_code.replace(str(Placeholder.DESCRIPTIVE_NAME), kernel_name)
+            # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does
+            # not use BracesBuffer, so we have no good indicator of a C++ buffer atm.
+            src_code = src_code.replace("#pragma CMT", "//")
+
+            # Get the lines in the source code representing the function definition,
+            # excluding the first line including cpp_prefix.h.
+            first_char = src_code.rfind('extern "C"')
+            last_char = src_code.find(")", first_char)
+            if _IS_WINDOWS:
+                # get_export_declaration introduced one more ')' in Windows
+                last_char = src_code.find(")", last_char + 1)
+            kernel_definition = f"{src_code[first_char : last_char + 1]};\n"
+
+            compile_wrapper = IndentedBuffer()
+            args = self.kernel_group.args if kernel_args is None else kernel_args
+            _, _, arg_types = args.cpp_argdefs()
+            if not V.graph.cpp_wrapper:
+                compile_wrapper.writeline(
+                    f"async_compile.cpp_pybinding({arg_types!r}, r'''"
+                )
+            compile_wrapper.splice(src_code, strip=True)
+            if not V.graph.cpp_wrapper:
+                compile_wrapper.writeline("''')")
+            wrapper.define_kernel(
+                kernel_name,
+                compile_wrapper.getvalue(),
+                gpu=False,
+                cpp_definition=kernel_definition,
+            )
+        return kernel_name
+
+    def flush(self):
+        src_code = self.kernel_group.codegen_group()
+        if src_code:
+            kernel_name = self.define_kernel(
+                src_code, self.kernel_group.scheduled_nodes
+            )
+            self.codegen_comment(self.kernel_group.scheduled_nodes, kernel_name)
+            if config.cpp.enable_kernel_profile:
+                V.graph.wrapper_code.write_kernel_context_guard_begin()
+                V.graph.wrapper_code.write_kernel_context_guard(
+                    kernel_name,
+                    self.kernel_group.scheduled_nodes,  # type: ignore[arg-type]
+                )
+            self.kernel_group.call_kernel(V.graph.wrapper_code, kernel_name)
+            if config.cpp.enable_kernel_profile:
+                V.graph.wrapper_code.write_kernel_context_guard_end()
+
+        self.reset_kernel_group()
+        self._set_flush_status(False)
+
+    def codegen_comment(self, node_schedule, kernel_name=None):
+        # below add provenance tracing info for cpu CppKernel types
+        wrapper = V.graph.wrapper_code
+        debug_handle = set_kernel_post_grad_provenance_tracing(
+            node_schedule,  # type: ignore[arg-type]
+            # pyrefly: ignore [bad-argument-type]
+            kernel_name,
+        )
+        wrapper.write_provenance_debug_handle(kernel_name, debug_handle)
+
+
+class KernelGroup:
+    def __init__(self):
+        super().__init__()
+        self.args = KernelArgs()
+        self.loops_code = BracesBuffer()
+        self.ws = WorkSharing(self.loops_code)
+        self.stack = contextlib.ExitStack()
+        self.stack.enter_context(self.ws)
+        self.scheduled_nodes = []
+
+    def new_kernel(self, cls, *args):
+        return cls(self.args, parallel_num_threads(), *args)
+
+    def finalize_kernel(self, new_kernel, nodes):
+        self.scheduled_nodes += nodes
+        code = self.loops_code
+        ws = self.ws
+        new_kernel.codegen_loops(code, ws)
+
+    def get_num_args(self):
+        arg_defs, _call_args, _arg_types = self.args.cpp_argdefs()
+        args_num = len(arg_defs)
+        return args_num
+
+    def codegen_group(self, name=None) -> str:
+        self.stack.close()
+        if not self.scheduled_nodes:
+            return ""
+        code = BracesBuffer()
+        # 1. Include header files
+        # TODO: support kernel profile on other platforms
+        enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [
+            "linux",
+            "win32",
+        ]
+        if enable_kernel_profile:
+            code.writelines(["#include "])
+        code.writeline("#include ")
+
+        # 2. Function definition
+        kernel_decl_name = str(Placeholder.KERNEL_NAME) if name is None else name
+        kernel_name = str(Placeholder.DESCRIPTIVE_NAME) if name is None else name
+        arg_defs, _, _ = self.args.cpp_argdefs()
+        arg_defs = ",\n".ljust(25).join(arg_defs)
+        func_export_decl = get_export_declaration()
+        inline_attr = (
+            "C10_ALWAYS_INLINE_ATTRIBUTE" if config.cpp.force_inline_kernel else ""
+        )
+        code.writeline(
+            f'extern "C" {func_export_decl} void {inline_attr} {kernel_decl_name}({arg_defs})'
+        )
+
+        # 3. Function body
+        with code.indent():
+            if enable_kernel_profile:
+                graph_id = V.graph.graph_id
+                prefix = "graph_" + str(graph_id) + "_" if graph_id is not None else ""
+                code.writelines(
+                    [
+                        (
+                            "torch::aot_inductor::RAIIAtenRecordFunctionHandle "
+                            f'record_{prefix + kernel_name}_("{prefix + kernel_name}", nullptr);'
+                        )
+                    ]
+                )
+            for old, new in self.args.aliases():
+                code.writeline(f"auto {old} = {new};")
+            code.splice(self.loops_code)
+        return code.getvalue()
+
+    def call_kernel(self, wrapper, kernel_name):
+        _, call_args, arg_types = self.args.cpp_argdefs()
+        wrapper.generate_kernel_call(
+            kernel_name,
+            call_args,
+            triton=False,
+            arg_types=arg_types,
+        )
+
+
+class WorkSharing:
+    def __init__(self, code):
+        self.code = code
+        self.in_parallel = False
+        self.num_threads = None
+        self.stack = contextlib.ExitStack()
+
+    def parallel(self, threads):
+        if self.in_parallel and threads != self.num_threads:
+            # wrong number of threads
+            self.close()
+        if not self.in_parallel:
+            self.num_threads = threads
+            self.in_parallel = True
+            if config.cpp.dynamic_threads:
+                self.code.writeline("#pragma omp parallel")
+            else:
+                self.code.writeline(f"#pragma omp parallel num_threads({threads})")
+            self.stack.enter_context(self.code.indent())
+            self.code.writeline(
+                "int tid = omp_get_thread_num();",
+            )
+
+    def single(self):
+        if self.in_parallel:
+            self.code.writeline("#pragma omp single")
+        return self.in_parallel
+
+    def close(self):
+        self.stack.close()
+        self.in_parallel = False
+
+    def __enter__(self):
+        self.stack.__enter__()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.stack.__exit__(exc_type, exc_val, exc_tb)
+
+
+@dataclasses.dataclass
+class LoopLevel:
+    var: Optional[sympy.Expr] = None
+    size: Optional[sympy.Expr] = None
+    offset: sympy.Expr = sympy.S.Zero
+    # Note [tiled_size]
+    # We may do loop-tiling at this loop level.
+    # When var is in [offset, tiled_size), we will perform the vectorization kernel.
+    # When var is in [tiled_size, size), we will perform the scalar or masked vectorization kernel.
+    # for (var = offset; var < size; var += steps) {
+    #     if (var >= offset && var < tiled_size) vec_loop_body();
+    #     if (var >= tiled_size && var < size) scalar_or_maskvec_loop_body();
+    # }
+    tiled_size: sympy.Expr = sympy.S.Zero
+    steps: sympy.Expr = sympy.S.One
+    parallel: int = 0
+    simd_omp: bool = False
+    simd_vec: bool = False
+    collapsed: bool = False
+    is_reduction: bool = False
+
+    def __post_init__(self):
+        # Regarding the C++/OpenMP backend, `cpu_vec_isa.pick_vec_isa()` to check
+        # vectorization ISA is a time-consuming and one-shot operation. It leads
+        # to taking a longer time to import `codegen.cpp` package because the
+        # `LoopLevel` of the package is decorated by `@dataclasses.dataclass` while
+        # the decorator will invoke `cpu_vec_isa.pick_vec_isa()` to initialize the
+        # `simd_nelements` of the `LoopLevel`. It might introduce additional compilation
+        # overhead to the Triton backend. Therefore, we moved the `simd_nelements` to
+        # `__post_init__`
+        picked_vec_isa: cpu_vec_isa.VecISA = cpu_vec_isa.pick_vec_isa()
+        self.simd_nelements: int = picked_vec_isa.nelements() if picked_vec_isa else 0
+
+    def tile(self, factor):
+        sympy_factor = sympy.Integer(factor)
+        loop = LoopLevel(self.var, self.size)
+        loop.steps = sympy_factor
+        loop.simd_vec = True
+        loop.tiled_size = FloorDiv(loop.size, sympy_factor) * sympy_factor
+        loop.parallel = self.parallel
+        loop.collapsed = False
+        loop.is_reduction = self.is_reduction
+        return loop
+
+    def lines(self):
+        offset_expr = cexpr_index(self.offset)
+        size_expr = cexpr_index(self.size)
+        if config.cpp.no_redundant_loops and offset_expr == size_expr:
+            return None
+        simd = (
+            f"simd simdlen({self.simd_nelements}) "
+            if self.simd_omp and self.simd_nelements > 1
+            else ""
+        )
+        if self.parallel:
+            # TODO(jansel): look into chunk size and other schedules
+            line1 = "#pragma omp for"
+            if self.parallel > 1:
+                line1 += f" collapse({self.parallel})"
+            if self.simd_omp:
+                line1 = line1.replace(" for ", f" for {simd}")
+        elif self.simd_vec:
+            line1 = ""
+        elif self.simd_omp:
+            line1 = f"#pragma omp {simd}"
+        elif not self.is_reduction and cpp_builder.is_gcc():
+            line1 = "#pragma GCC ivdep"
+        else:
+            line1 = ""
+        offset_str = f"{INDEX_TYPE} {self.var}={offset_expr}"
+        size_str = f"{self.var}<{size_expr}"
+        if self.steps.is_number:
+            steps_str = f"{self.var}+={cexpr_index(self.steps)}"
+        else:
+            # If the step size is 0, change it to 1 because a step size of 0
+            # will cause floating point exception (core dump) during parallelization.
+            steps_str = (
+                f"{self.var}+=({cexpr_index(self.steps)} == 0 ? "
+                f"1 : {cexpr_index(self.steps)})"
+            )
+        line2 = f"for({offset_str}; {size_str}; {steps_str})"
+        if self.collapsed or not line1:
+            return [line2]
+        return [line1, line2]
+
+
+@dataclasses.dataclass
+class LoopNest:
+    """
+    A loop-nest-like structure. It is built with the `build` method
+    as a loop nest and then will perform loop-tiling at some depth.
+
+    A typical case is for vectorization, where we typically do loop-tiling
+    at the innermost loop level. A more complicated case is when we do
+    2D tiling at both the innermost and outer levels.
+    """
+
+    loops: Optional[list[LoopLevel]] = None
+    kernel: Optional[CppKernel] = None
+
+    @staticmethod
+    def build(kernel: CppKernel):
+        """Build a LoopNest with the given `kernel` as the leaf"""
+        itervars = kernel.itervars
+        ranges = kernel.ranges
+        reduction_depth = kernel.reduction_depth
+        assert reduction_depth is not None
+
+        loops: Optional[list[LoopLevel]] = None
+        for loop_idx, (var, size) in enumerate(zip(itervars, ranges)):
+            loop = LoopLevel(var, size)
+            if not loops:
+                loops = [loop]
+            else:
+                loops.append(loop)
+            if loop_idx >= reduction_depth:
+                loop.is_reduction = kernel.is_reduction
+
+        loop_nest = LoopNest(loops)
+        return loop_nest
+
+    def __bool__(self):
+        return bool(self.loops)
+
+    @cache_on_self
+    def max_parallel_depth(self):
+        """
+        Maximal allowed depth for parallelism: All reduction or non-reduction levels.
+        When the range of the first inner loop beyond the maximum parallel depth is much
+        larger than the range of all outer loops within the maximum parallel depth,
+        change the starting depth of parallelism to the first inner loop and recalculate
+        the maximum parallel depth.
+        """
+        if self.loops is None:
+            return ParallelDepth(parallel_depth=0, start_depth=0)
+
+        start_depth = 0
+        max_depth = 0
+        is_reduction = self.loops[0].is_reduction
+        num_steps = sympy.Integer(1)
+        for loop in self.loops:
+            if loop.is_reduction != is_reduction:
+                break
+            num_steps = num_steps * FloorDiv(loop.size, loop.steps)
+            max_depth += 1
+
+        def get_simd_vec_depth(loops):
+            # Return the first loop level which is simd_vec
+            for i, loop in enumerate(loops):
+                if loop.simd_vec:
+                    return i
+            return None
+
+        simd_vec_depth = get_simd_vec_depth(self.loops)
+
+        def has_scalar_kernel(loop_nest: LoopNest):
+            assert isinstance(loop_nest.kernel, CppKernelProxy)
+            return any(
+                not isinstance(kernel, CppVecKernel)
+                for kernel in loop_nest.kernel.kernels
+            )
+
+        # When the number of steps of the first inner loop is much larger than the number of steps of
+        # all outer loops, change `start_depth` to the first inner loop and recalculate `max_depth`.
+        if (
+            max_depth < len(self.loops)
+            and isinstance(num_steps, sympy.Integer)
+            and isinstance(self.loops[max_depth].size, sympy.Integer)
+            and num_steps * 300
+            < FloorDiv(self.loops[max_depth].size, self.loops[max_depth].steps)
+            and not (
+                # Disable parallel reduction under the vec loop
+                simd_vec_depth is not None
+                and max_depth > simd_vec_depth
+                and self.loops[max_depth].is_reduction
+                and has_scalar_kernel(self)
+            )
+        ):
+            start_depth = max_depth
+            max_depth = 0
+            is_reduction = self.loops[start_depth].is_reduction
+            for i in range(start_depth, len(self.loops)):
+                if self.loops[i].is_reduction != is_reduction:
+                    break
+                max_depth += 1
+        return ParallelDepth(parallel_depth=max_depth, start_depth=start_depth)
+
+    def mark_parallel(self, par_depth):
+        assert par_depth.parallel_depth <= self.max_parallel_depth().parallel_depth, (
+            "Parallel depth cannot exceed the maximal allowed parallel depth"
+        )
+        assert self.loops is not None
+        assert len(self.loops) >= par_depth.parallel_depth
+        loop = self.loops[par_depth.start_depth]
+        loop.parallel = par_depth.parallel_depth
+        if loop.is_reduction:
+            # pyrefly: ignore [bad-assignment]
+            metrics.parallel_reduction_count += 1
+        for i in range(par_depth.start_depth + 1, par_depth.parallel_depth):
+            self.loops[i].collapsed = True
+
+    def tile(self, depth, factor):
+        """
+        Do loop-tiling at the `depth` level with `factor`.
+            for (x0 = 0; x0 < x0_end; x0++)
+            ->
+            for (x0 = 0; x0 < x0_end; x0 += factor)
+        See details in Note [tiled_size].
+        """
+        assert self.loops
+        self.loops[depth] = self.loops[depth].tile(factor)
+        return self.loops[depth]
+
+    def get_kernel(self) -> CppKernel:
+        assert self.kernel
+        return self.kernel
+
+    def set_kernel(self, kernel):
+        self.kernel = kernel
+
+    def from_loop_level(self, level: int):
+        assert self.loops
+        assert len(self.loops) >= level
+        loops = None if level == len(self.loops) else self.loops[level:]
+        return LoopNest(loops, self.kernel)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_bmm_template.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_bmm_template.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4a7c2ef1640690bf751e08b1c1e4d33a3c147b4
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_bmm_template.py
@@ -0,0 +1,263 @@
+# mypy: allow-untyped-defs
+import contextlib
+import itertools
+from collections.abc import Callable
+from typing import Any, Optional
+from unittest.mock import patch
+
+import sympy
+
+from .. import ir
+from ..select_algorithm import PartialRender
+from ..virtualized import V
+from .common import ArgName
+from .cpp_gemm_template import CppGemmTemplate, GEMM_TEMPLATE
+from .cpp_micro_gemm import LayoutType
+from .cpp_template_kernel import CppTemplateKernel
+from .cpp_utils import DTYPE_TO_CPP, GemmBlocking
+
+
+# We pass all sizevars present in BY to the GEMM templates so variables are not renamed in the BMM definition
+GEMM_SINGLE_THREAD_MM_STUB = r"""
+{{kernel.def_kernel(
+    inputs={"X": X, "W": W},
+    outputs={"Y": Y_2d},
+    aliases=aliases,
+    function_name=kernel_name+"_single_thread_mm",
+    extra_sizevars=BY_sizevars + [b_index],
+    placeholder="")}}"""
+
+GEMM_THREADED_MM_STUB = r"""
+{{kernel.def_kernel(
+    inputs={"X": X, "W": W},
+    outputs={"Y": Y_2d},
+    aliases=aliases,
+    function_name=kernel_name+"_threaded_mm",
+    extra_sizevars=BY_sizevars + [b_index],
+    placeholder="")}}"""
+
+BMM_TEMPLATE = r"""
+{{ template.codegen_microkernel_def() }}
+{{ template.codegen_single_thread_gemm() }}
+{{ template.codegen_multi_thread_gemm() }}
+
+extern "C"
+{{kernel.def_kernel(inputs={"X": BX, "W": BW}, outputs={"Y": BY}, aliases=aliases)}}
+{
+    const int64_t B = {{kernel.size(BY_2d, 0)}};
+    {%- if num_threads > 1 %}
+    constexpr int64_t num_threads = {{num_threads}};
+    int64_t B_single_thread_block = (B / num_threads) * num_threads;
+
+    #pragma omp parallel for num_threads({{num_threads}})
+    {%- else %}
+    int64_t B_single_thread_block = B;
+    {%- endif %}
+    for (int64_t b_start = 0; b_start < B_single_thread_block; ++b_start) {
+        {{template.get_gemm_function_call(
+            kernel,
+            kernel_name+"_single_thread_mm",
+            "",
+            b_index="b_start",
+        )}}
+    }
+    for (int64_t b_start = B_single_thread_block; b_start < B; ++b_start) {
+        {{template.get_gemm_function_call(
+            kernel,
+            kernel_name+"_threaded_mm",
+            "",
+            b_index="b_start",
+        )}}
+    }
+}
+"""
+
+
+class CppBmmTemplate(CppGemmTemplate):
+    def __init__(
+        self,
+        input_nodes,
+        layout: ir.Layout,
+        num_threads: int,
+        register_blocking: GemmBlocking,
+        beta=1,
+        alpha=1,
+        has_bias=False,
+        epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
+        should_block_weights: bool = False,
+        name="bmm",
+    ):
+        """
+        In order to simplify the implementation and increase code reuse, the BMM template implements
+        two versions of the GEMM kernel: a single-threaded version and a multi-threaded version.
+        GEMM kernels are called in a loop over the batch dimension, with single-threaded GEMM calls
+        for all but the last (B % num_threads), which are handled by the multi-threaded GEMM kernel.
+
+        We use an extra sizevar `b_index` to index the batch dimension, which we pass into the GEMM
+        template as a sympy.Symbol. This allows us to slice the 3D batch tensors in the GEMM template
+        without any changes to the GEMM template itself.
+        """
+        super().__init__(
+            input_nodes,
+            layout,
+            num_threads,
+            register_blocking,
+            beta=beta,
+            alpha=alpha,
+            has_bias=has_bias,
+            epilogue_creator=epilogue_creator,
+            should_block_weights=should_block_weights,
+            name=name,
+        )
+        self.b_index = sympy.Symbol("s_b_index", integer=True, nonnegative=True)
+
+    @staticmethod
+    def get_padded_size(n, block_n, k, should_block_weight):
+        if should_block_weight:
+            # Tensor is constant or not contiguous, so we will pad and block
+            new_size, padded_n = CppGemmTemplate.get_padded_size(
+                n, block_n, k, should_block_weight
+            )
+            # Add the new batch dimension
+            new_size.insert(0, -1)
+            return new_size, padded_n
+        else:
+            new_size = [-1, k, n]
+            return new_size, n
+
+    @staticmethod
+    def check_if_block_weight(W, micro_gemm):
+        assert isinstance(W, ir.IRNode)
+        _, n = W.get_size()[-2:]
+        result = (
+            not W.get_layout().is_contiguous()
+            or W.get_name() in V.graph.constants
+            or (
+                n % micro_gemm.register_blocking.block_n != 0
+                and micro_gemm.get_b_layout != LayoutType.NORMAL
+            )
+        )
+        return result
+
+    def get_gemm_function_call(
+        self,
+        kernel: CppTemplateKernel,
+        function_name: str,
+        placeholder: str,
+        b_index: str,
+    ) -> str:
+        """
+        Similar to 'def_kernel' in cpp_template_kernel, but instead of generating a function definition,
+        generate a function call for the GEMM kernel.
+        Args:
+            placeholder: The string to replace the function call with
+            b_index: The index for slicing the 3D batch tensors
+        """
+
+        def hook():
+            arg_defs, call_args, _, _ = kernel.args.python_argdefs()
+            for i, buf in enumerate(call_args):
+                if buf == self.b_index:
+                    arg_defs[i] = ArgName(b_index)
+            call = f"{function_name}({', '.join(x.full_name() for x in arg_defs)});"
+            return call
+
+        assert placeholder not in kernel.render_hooks
+        kernel.render_hooks[placeholder] = hook
+        return placeholder
+
+    def get_default_reindexers(self, epilogue_nodes):
+        def reindexer(args):
+            # if epilogue nodes exist, they have 3D ranges but args are 2D, so add 0 index
+            return [self.b_index] + args
+
+        return [reindexer] * len(epilogue_nodes)
+
+    def get_options(
+        self,
+        kernel: CppTemplateKernel,
+        template_buffer_node: Optional[ir.CppTemplateBuffer] = None,
+        flag_template_buffer_has_other_users: Optional[bool] = None,
+        epilogue_nodes: Optional[list[ir.IRNode]] = None,
+        **kwargs,
+    ) -> dict[str, Any]:
+        options = super().get_options(
+            kernel=kernel,
+            template_buffer_node=template_buffer_node,
+            flag_template_buffer_has_other_users=flag_template_buffer_has_other_users,
+            epilogue_nodes=epilogue_nodes,
+            **kwargs,
+        )
+
+        BX, BW, BY = options["X"], options["W"], options["Y"]
+        options["BX"], options["BW"], options["BY"] = BX, BW, BY
+        options["BY_2d"] = options["Y_2d"]
+        for kword in ["X", "W", "GemmOut", "Y_2d"]:
+            options[kword] = kernel.select(options[kword], 0, self.b_index)
+        for kword in ["X", "W", "Y_2d"]:
+            options[kword + "_dtype"] = DTYPE_TO_CPP[options[kword].dtype]
+        options["b_index"] = self.b_index
+        options["BY_sizevars"] = [
+            s
+            for sym in itertools.chain(BY.get_size(), BY.get_stride())
+            if isinstance(sym, sympy.Expr)
+            for s in sym.free_symbols
+        ]
+        options["kernel_name"] = kernel.kernel_name
+
+        return options
+
+    def render(  # type: ignore[override, return]
+        self,
+        kernel: CppTemplateKernel,
+        template_buffer_node: Optional[ir.CppTemplateBuffer] = None,
+        flag_template_buffer_has_other_users: Optional[bool] = None,
+        epilogue_nodes: Optional[list[ir.IRNode]] = None,
+        **kwargs,
+    ) -> str:
+        options = self.get_options(
+            kernel=kernel,
+            template_buffer_node=template_buffer_node,
+            flag_template_buffer_has_other_users=flag_template_buffer_has_other_users,
+            epilogue_nodes=epilogue_nodes,
+            **kwargs,
+        )
+        self.render_options = options
+
+        with contextlib.ExitStack() as stack:
+            for buf in options["fake_buffers"]:
+                stack.enter_context(
+                    patch.object(V.graph, "get_dtype", self._fake_get_dtype(buf))
+                )
+            result = self._template_from_string(BMM_TEMPLATE).render(**options)
+
+            # Finalize the function definitions for the gemm routines
+            sub_mm_hooks = {
+                name: hook
+                for name, hook in kernel.render_hooks.items()
+                if "FOR_BMM" in name
+            }
+            result = PartialRender(result, sub_mm_hooks).finalize_all()
+            for name in sub_mm_hooks:
+                del kernel.render_hooks[name]
+            del kernel.args.sizevars[options["b_index"]]
+            return result
+
+    def codegen_single_thread_gemm(self):
+        stub = self._template_from_string(GEMM_SINGLE_THREAD_MM_STUB).render(
+            self.render_options
+        )
+        return stub + self._template_from_string(GEMM_TEMPLATE).render(
+            {**self.render_options, "num_threads": 1}
+        )
+
+    def codegen_multi_thread_gemm(self):
+        stub = self._template_from_string(GEMM_THREADED_MM_STUB).render(
+            self.render_options
+        )
+        return stub + self._template_from_string(GEMM_TEMPLATE).render(
+            self.render_options
+        )
+
+    def codegen_gemm_stub_def(self):
+        return ""
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_flex_attention_template.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_flex_attention_template.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1ceecf7f7c9ea8081660c21a8ddf96254c98a68
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_flex_attention_template.py
@@ -0,0 +1,1090 @@
+# mypy: allow-untyped-defs
+import contextlib
+import logging
+import re
+from typing import Optional
+from unittest.mock import patch
+
+import sympy
+
+import torch
+import torch.utils
+
+from ...utils._ordered_set import OrderedSet
+from .. import ir
+from ..ir import TensorBox
+from ..select_algorithm import DataProcessorTemplateWrapper
+from ..utils import parallel_num_threads
+from ..virtualized import V
+from .cpp_template import CppTemplate
+from .cpp_utils import GemmBlocking
+
+
+log = logging.getLogger(__name__)
+
+# TODO: reuse cpp codegen to generate below pointwise/reduction kernels
+SOFTMAX_FUSIONS = r"""
+// 1) out = exp(a - val)
+// 2) val = sum(out)
+template 
+inline void {{kernel_name}}_exp_reduce_sum_fusion_kernel(
+    T1* a,
+    const int& size,
+    T2* out,
+    T1& val) {
+  auto vec_size = at::vec::Vectorized::size();
+  auto vec_max = at::vec::Vectorized(val);
+  T1 tmp_sum = 0;
+  auto vec_tmp_sum = at::vec::Vectorized(tmp_sum);
+  for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) {
+    auto tmp0 = at::vec::Vectorized::loadu(a + i);
+    auto tmp1 = tmp0 - vec_max;
+    auto tmp2 = tmp1.exp_u20();
+    vec_tmp_sum += tmp2;
+    at::native::_store(out + i, tmp2);
+  }
+  tmp_sum = at::vec::vec_reduce_all(
+      [](at::vec::Vectorized& x, at::vec::Vectorized& y) {
+        return x + y;
+      },
+      vec_tmp_sum);
+  for (long i = vec_size * (size / vec_size); i < size; i++) {
+    auto tmp0 = a[i];
+    auto tmp1 = tmp0 - val;
+    auto tmp2 = exp(tmp1);
+    tmp_sum += tmp2;
+    out[i] = tmp2;
+  }
+  val = tmp_sum;
+}
+
+// 1) out = a * scale
+// 2) max = max(out)
+template 
+inline void {{kernel_name}}_mul_reduce_max_fusion_kernel(
+    const scalar_t* a,
+    const scalar_t& scale,
+    const int& size,
+    scalar_t* out,
+    scalar_t& max) {
+  auto vec_size = at::vec::Vectorized::size();
+  auto vec_scale = at::vec::Vectorized(scale);
+  scalar_t tmp_max = -std::numeric_limits::infinity();
+  auto vec_tmp_max = at::vec::Vectorized(tmp_max);
+  for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) {
+    auto tmp0 = at::vec::Vectorized::loadu(a + i);
+    auto tmp1 = tmp0 * vec_scale;
+    vec_tmp_max = at::vec::maximum(vec_tmp_max, tmp1);
+    at::native::_store(out + i, tmp1);
+  }
+  for (long i = vec_size * (size / vec_size); i < size; i++) {
+    auto tmp0 = a[i];
+    auto tmp1 = tmp0 * scale;
+    tmp_max = std::max(tmp_max, tmp1);
+    out[i] = tmp1;
+  }
+  max = std::max(
+      tmp_max,
+      at::vec::vec_reduce_all(
+          [](at::vec::Vectorized& x, at::vec::Vectorized& y) {
+            return at::vec::maximum(x, y);
+          },
+          vec_tmp_max));
+}
+
+template 
+static inline scalar_t* {{kernel_name}}_conditional_data_ptr(scalar_t* ptr, scalar_t* ptr2) {
+  TORCH_CHECK(ptr2 == nullptr);
+  return ptr;
+}
+
+template , int> = 0>
+static inline scalar_t* {{kernel_name}}_conditional_data_ptr(float* ptr, scalar_t* ptr2) {
+  return ptr2;
+}
+
+template 
+inline void {{kernel_name}}_fill_stub(scalar_t* data, scalar_t val, int64_t size) {
+  using Vec = at::vec::Vectorized;
+  Vec data_vec = Vec(val);
+  int64_t d = 0;
+  for (; d < size - (size % Vec::size()); d += Vec::size()) {
+    data_vec.store(data + d);
+  }
+  #if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE)
+  # pragma unroll
+  #endif
+  for (; d < size; d++) {
+    data[d] = val;
+  }
+}
+
+// out = a * scale
+template 
+inline void {{kernel_name}}_mul_scale_kernel(
+    scalar_t* a,
+    scalar_t scale,
+    int64_t size) {
+  auto vec_size = at::vec::Vectorized::size();
+  auto vec_scale = at::vec::Vectorized(scale);
+  for (int64_t i = 0; i < vec_size * (size / vec_size); i += vec_size) {
+    auto tmp0 = at::vec::Vectorized::loadu(a + i);
+    auto tmp1 = tmp0 * vec_scale;
+    at::native::_store(a + i, tmp1);
+  }
+  for (int64_t i = vec_size * (size / vec_size); i < size; i++) {
+    auto tmp0 = a[i];
+    auto tmp1 = tmp0 * scale;
+    a[i] = tmp1;
+  }
+}
+
+"""
+
+BRGEMM_PACK_FUNCTIONS = r"""
+template 
+inline void {{kernel_name}}_copy_value_with_pad(
+    const scalar_t* value_ptr,
+    scalar_t* dst_ptr,
+    int64_t rows,
+    int64_t cols,
+    int64_t prows,
+    int64_t pcols,
+    int64_t ldi) {
+  auto vec_size = at::vec::Vectorized::size();
+  int64_t i = 0;
+  for (; i < rows; i++) {
+    int64_t j = 0;
+    for (; j < cols - (cols % vec_size); j += vec_size) {
+      auto vec_v =
+          at::vec::Vectorized::loadu(value_ptr + i * ldi + j);
+      vec_v.store(dst_ptr + i * pcols + j);
+    }
+
+    if (j < cols) {
+      auto vec_v = at::vec::Vectorized::loadu(
+          value_ptr + i * ldi + j, cols - j);
+      vec_v.store(dst_ptr + i * pcols + j, cols - j);
+    }
+
+    // col padding
+    auto psize = pcols - cols;
+    if (psize > 0) {
+      auto zero_vec = at::vec::Vectorized(0);
+      int64_t pj = 0;
+      for (; pj < psize - (psize % vec_size); pj += vec_size) {
+        zero_vec.store(dst_ptr + i * pcols + cols + pj);
+      }
+      if (pj < psize) {
+        zero_vec.store(dst_ptr + i * pcols + cols + pj, psize - pj);
+      }
+    }
+  }
+  // row padding
+  for (; i < prows; i++) {
+    auto zero_vec = at::vec::Vectorized(0);
+    int64_t j = 0;
+    for (; j < pcols - (pcols % vec_size); j += vec_size) {
+      zero_vec.store(dst_ptr + i * pcols + j);
+    }
+    if (j < pcols) {
+      zero_vec.store(dst_ptr + i * pcols + j, pcols - j);
+    }
+
+  }
+}
+"""
+
+MICRO_GEMM_TEMPLATE = r"""
+GEMM_DEFINE
+"""
+
+ALLOCATE_BUFFER = r"""
+  int64_t {{buffer_name}}_dtype_itemsize = c10::is_reduced_floating_point_v<{{buffer_dtype}}> ? 2 : 4;
+  auto& {{buffer_name}}_allocator = *at::getCPUAllocator();
+  auto {{buffer_name}}_work_data = {{buffer_name}}_allocator.allocate({{buffer_size}}*{{buffer_name}}_dtype_itemsize);
+  void* {{buffer_name}}_data_ptr = {{buffer_name}}_work_data.get();
+  {{buffer_dtype}}* {{buffer_name}} = ({{buffer_dtype}}*){{buffer_name}}_data_ptr;
+"""
+
+FLEX_ATTENTION_TEMPLATE = r"""
+{{template.header().getvalue()}}
+#include 
+#include 
+#include 
+{{template.codegen_micro_gemm(kernel.kernel_name)}}
+{{template.codegen_softmax_fusion(kernel.kernel_name)}}
+{{template.codegen_brgemm_pack_function(kernel.kernel_name)}}
+{%- set kernel_args = {"query": query, "key": key, "value": value,
+                       "kv_num_blocks": kv_num_blocks, "kv_indices": kv_indices,
+                       "full_kv_num_blocks": full_kv_num_blocks, "full_kv_indices": full_kv_indices } %}
+{%- set kernel_args = template.update_kernel_args(kernel_args) %}
+
+extern "C"
+{{kernel.def_kernel(inputs=kernel_args, outputs={"output": output}, extra_sizevars=template.extra_sizevars)}}
+{
+  {{ kernel.maybe_codegen_profile() }}
+  int64_t qBlockSize = {{qBlockSize}};
+  int64_t kvBlockSize = {{kvBlockSize}};
+  int64_t num_thread = {{num_thread}};
+
+  // dtypes of kernel and internal buffers
+  using scalar_t = {{kernel.dtype(query)}};
+  constexpr bool is_reduced_type = c10::is_reduced_floating_point_v;
+  using accum_t = at::opmath_type<{{kernel.dtype(query)}}>;
+  using Vec = at::vec::Vectorized;
+  accum_t scaling_factor = {{scale}};
+  int64_t batchSize = {{kernel.size(query, 0)}};
+  int64_t qSize = {{kernel.size(query, 1)}};
+  int64_t num_head = {{kernel.size(query, 2)}};
+  int64_t headSize = {{kernel.size(query, 3)}};
+  int64_t batchSize_k = {{kernel.size(key, 0)}};
+  int64_t num_head_k = {{kernel.size(key, 2)}};
+  int64_t headSize_v = {{kernel.size(value, 3)}};
+  bool is_broadcast_bs_kv = batchSize != batchSize_k;
+  bool is_broadcast_head_kv = num_head != num_head_k;
+  int64_t gqa_shards = num_head / num_head_k;
+  int64_t bs_shards = batchSize / batchSize_k;
+
+  int64_t batchSize_kvi = {{kernel.size(kv_indices, 0)}};
+  int64_t num_head_kvi = {{kernel.size(kv_indices, 1)}};
+  int64_t block_num_kvi = {{kernel.size(kv_indices, 3)}};
+  bool is_broadcast_bs_kvi = batchSize != batchSize_kvi;
+  bool is_broadcast_head_kvi = num_head != num_head_kvi;
+  int64_t gqa_shards_kvi = num_head / num_head_kvi;
+  int64_t bs_shards_kvi = batchSize / batchSize_kvi;
+
+  int64_t kviStrideB = {{kernel.stride(kv_indices, 0)}};
+  int64_t kviStrideH = {{kernel.stride(kv_indices, 1)}};
+  int64_t kviStrideQ = {{kernel.stride(kv_indices, 2)}};
+
+  int64_t num_kviStrideB = {{kernel.stride(kv_num_blocks, 0)}};
+  int64_t num_kviStrideH = {{kernel.stride(kv_num_blocks, 1)}};
+
+{%- if has_full_kv_block %}
+  int64_t full_kviStrideB = {{kernel.stride(full_kv_indices, 0)}};
+  int64_t full_kviStrideH = {{kernel.stride(full_kv_indices, 1)}};
+  int64_t full_kviStrideQ = {{kernel.stride(full_kv_indices, 2)}};
+
+  int64_t full_num_kviStrideB = {{kernel.stride(full_kv_num_blocks, 0)}};
+  int64_t full_num_kviStrideH = {{kernel.stride(full_kv_num_blocks, 1)}};
+  auto full_kv_indices_data = full_kv_indices;
+  auto full_kv_num_blocks_data = full_kv_num_blocks;
+{%- endif %}
+
+  auto kv_num_blocks_data = kv_num_blocks;
+  auto kv_indices_data = kv_indices;
+
+  // Strides
+  int64_t qStrideB = {{kernel.stride(query, 0)}};
+  int64_t qStrideM = {{kernel.stride(query, 1)}};
+  int64_t qStrideH = {{kernel.stride(query, 2)}};
+  int64_t kStrideB = {{kernel.stride(key, 0)}};
+  int64_t kStrideN = {{kernel.stride(key, 1)}};
+  int64_t kStrideH = {{kernel.stride(key, 2)}};
+  int64_t vStrideB = {{kernel.stride(value, 0)}};
+  int64_t vStrideN = {{kernel.stride(value, 1)}};
+  int64_t vStrideH = {{kernel.stride(value, 2)}};
+  int64_t oStrideB = {{kernel.stride(output, 0)}};
+  int64_t oStrideM = {{kernel.stride(output, 2)}};
+  int64_t oStrideH = {{kernel.stride(output, 1)}};
+
+  int64_t kvSize = {{kernel.size(key, 1)}};
+
+  int64_t qSplitSize = qBlockSize;
+  int64_t kvSplitSize = kvBlockSize;
+
+
+  qSplitSize = qSplitSize > qSize ? qSize : qSplitSize;
+  kvSplitSize = kvSplitSize > kvSize ? kvSize : kvSplitSize;
+  int64_t qSlice = (qSize + qSplitSize - 1) / qSplitSize;
+  int64_t kvSlice = (kvSize + kvSplitSize - 1) / kvSplitSize;
+  int64_t kvTail = (kvSize - 1) % kvSplitSize + 1;
+
+  bool need_pack = false;
+  // Whether pack is needed for BFloat16/Half
+  if (is_reduced_type) {
+    // check platform ability
+    need_pack = std::is_same_v ? at::native::cpublas::could_pack(at::kBFloat16)
+                                                       : at::native::cpublas::could_pack(at::kHalf);
+  }
+  if (need_pack) {
+    // When the number of gemm is greater than the number of pack,
+    // the pack overhead can be overlapped.
+    int64_t thresh_size = 64;
+    need_pack = kvSize >= thresh_size && qSize >= thresh_size;
+    if (need_pack) {
+      double pack_size = batchSize * num_head * kvSize * headSize;
+      double qs_per_thread = (batchSize * num_head * qSlice + num_thread - 1) / num_thread;
+      double gemm_size_per_thread = qs_per_thread * qSplitSize * kvSize * headSize;
+      need_pack = gemm_size_per_thread / pack_size >= 4;
+    }
+  }
+  // Pad is needed for packing when K is not even
+  bool headSize_even = headSize % 2 == 0;
+  int64_t eheadSize = need_pack && !headSize_even ? headSize + 1: headSize;
+  int64_t ekvSplitSize = need_pack && (kvSplitSize % 2 != 0) ? kvSplitSize + 1 : kvSplitSize;
+  int64_t ekvTail = need_pack && (kvTail % 2 != 0) ? kvTail + 1 : kvTail;
+  int64_t kv_padding_size = (kvSize - 1) / kvSplitSize * ekvSplitSize + ekvTail;
+
+  // Allocate per thread temp buf (accumulate type)
+  int64_t _size_per_thread =
+      /* qk     */ qSplitSize * kvSplitSize +
+      /* qk_max */ qSplitSize +
+      /* qk_sum */ qSplitSize +
+      /* dst    */ qSplitSize * headSize_v;
+
+  // Inputs/outputs buffers
+  const scalar_t* q_data = query;
+  const scalar_t* k_data = key;
+  const scalar_t* v_data = value;
+  scalar_t* out_data = output;
+
+  // Buffers to store accum results, padding query and transpose/packing key/value
+  {{template.codegen_allocate_buffer("buf_data", "accum_t", "num_thread*_size_per_thread")}}
+  {{template.codegen_allocate_buffer("buf_reduced_data", "scalar_t", "num_thread*qSplitSize*ekvSplitSize")}}
+  {{template.codegen_allocate_buffer("key_reorder_ptr", "scalar_t", "batchSize_k*num_head_k*eheadSize*kvSize")}}
+  {{template.codegen_allocate_buffer("value_reorder_ptr", "scalar_t", "batchSize_k*num_head_k*kv_padding_size*headSize_v")}}
+  {{template.codegen_allocate_buffer("transpose_buffer_ptr", "scalar_t", "num_thread*kvSplitSize*headSize")}}
+  {{template.codegen_allocate_buffer("query_padding_ptr", "scalar_t", "num_thread*qSplitSize*eheadSize")}}
+  if (need_pack) {
+    // Pack K, V
+    at::parallel_for(0, batchSize_k * num_head_k * kvSlice, 1, [&](int64_t begin, int64_t end) {
+      int ompIdx = at::get_thread_num();
+      int64_t i = 0, j = 0, l = 0, n = 0;
+      scalar_t* transpose_ptr = need_pack? transpose_buffer_ptr + ompIdx * kvSplitSize * headSize : nullptr;
+      at::native::data_index_init(begin, i, batchSize_k, j, num_head_k, l, kvSlice);
+      for ([[maybe_unused]] auto z : c10::irange(begin, end)) {
+        n = l * kvSplitSize;
+        int64_t cur_kvSplitSize = std::min(kvSplitSize, kvSize - n);
+        auto k_addr =
+              k_data + i * kStrideB + j * kStrideH + n * kStrideN;
+        auto v_addr =
+              v_data + i * vStrideB + j * vStrideH + n * vStrideN;
+        // transpose [cur_kvSplitSize, headSize] -> [headSize, cur_kvSplitSize]
+        at::native::utils::transpose(
+          cur_kvSplitSize,
+          headSize,
+          /* src_ptr */
+          reinterpret_cast(k_addr),
+          /* ld_src */ kStrideN,
+          /* dst */ reinterpret_cast(transpose_ptr),
+          /* ld_dst */ cur_kvSplitSize);
+
+        // Pack [headSize, cur_kvSplitSize]
+        at::vec::pack_vnni2(
+          /* src */ reinterpret_cast(transpose_ptr),
+          /* dst */ reinterpret_cast(key_reorder_ptr + i * num_head_k * eheadSize * kvSize +
+                  j * eheadSize * kvSize + n * eheadSize),
+          /* ld_src */ cur_kvSplitSize,
+          /* K */ headSize,
+          /* N */ cur_kvSplitSize);
+
+        // Pack [cur_kvSplitSize, headSize_v]
+        at::vec::pack_vnni2(
+          /* src */ reinterpret_cast(v_addr),
+          /* dst */ reinterpret_cast(value_reorder_ptr +
+                  i * num_head_k * kv_padding_size * headSize_v +
+                  j * kv_padding_size * headSize_v + n * headSize_v),
+          /* ld_src */ vStrideN,
+          /* K */ cur_kvSplitSize,
+          /* N */ headSize_v);
+      // Move to the next query
+      at::native::data_index_step(i, batchSize_k, j, num_head_k, l, kvSlice);
+      }
+    });
+  }
+  // Attention loop below
+  at::parallel_for(0, batchSize * num_head * qSlice, 1, [&](int64_t begin, int64_t end) {
+    int64_t i = 0, j = 0, k = 0;
+    at::native::data_index_init(begin, i, batchSize, j, num_head, k, qSlice);
+    int ompIdx = at::get_thread_num();
+    accum_t* buf_ptr = buf_data + ompIdx * _size_per_thread;
+    accum_t* qk_data = buf_ptr;
+    accum_t* qk_max_data = qk_data + qSplitSize * kvSplitSize;
+    accum_t* qk_sum_data = qk_max_data + qSplitSize;
+    accum_t* dst_data = qk_sum_data + qSplitSize;
+    scalar_t *qk_reduced_data =
+        is_reduced_type
+            ? buf_reduced_data + ompIdx * qSplitSize * ekvSplitSize
+            : nullptr;
+    scalar_t* query_t_padding_ptr = (!headSize_even && need_pack)
+            ? query_padding_ptr + ompIdx * qSplitSize * eheadSize
+            : nullptr;
+
+    for ([[maybe_unused]] auto z : c10::irange(begin, end)) {
+      auto i_kvi = is_broadcast_bs_kvi ? i/bs_shards_kvi : i;
+      auto j_kvi = is_broadcast_head_kvi ? j/gqa_shards_kvi : j;
+      auto kv_logical_num_data = kv_num_blocks_data + i_kvi * num_kviStrideB +
+                              j_kvi * num_kviStrideH + k;
+      int kv_indice_num = *kv_logical_num_data;
+      std::vector kv_indice_list(kv_indice_num);
+      for(int kv_i = 0; kv_i < kv_indice_num; kv_i++){
+        auto kv_logical_data = kv_indices_data + i_kvi * kviStrideB +
+                                  j_kvi * kviStrideH + k*kviStrideQ + kv_i;
+        kv_indice_list[kv_i] = *kv_logical_data;
+      }
+      bool is_skip_kv = kv_indice_num > 0 ? false : true;
+{%- if has_full_kv_block %}
+      auto full_kv_logical_num_data = full_kv_num_blocks_data + i_kvi * num_kviStrideB +
+                              j_kvi * num_kviStrideH + k;
+      int full_kv_indice_num = *full_kv_logical_num_data;
+      std::vector full_kv_indice_list(full_kv_indice_num);
+      for(int kv_i = 0; kv_i < full_kv_indice_num; kv_i++){
+        auto full_kv_logical_data = full_kv_indices_data + i_kvi * full_kviStrideB +
+                                  j_kvi * full_kviStrideH + k*full_kviStrideQ + kv_i;
+        full_kv_indice_list[kv_i] = *full_kv_logical_data;
+      }
+      is_skip_kv = kv_indice_num + full_kv_indice_num > 0 ? false : true;
+{%- endif %}
+      int64_t m = k * qSplitSize;
+      int64_t cur_qSplitSize = std::min(qSplitSize, qSize - m);
+      if (!is_skip_kv){
+        // Initialize max and sum
+        {{kernel.kernel_name}}_fill_stub(qk_max_data,
+            -std::numeric_limits::infinity(), cur_qSplitSize);
+        {{kernel.kernel_name}}_fill_stub(qk_sum_data,
+            static_cast(0), cur_qSplitSize);
+
+        if (!headSize_even && need_pack) {
+          // Pad query if headSize is not even
+          {{kernel.kernel_name}}_copy_value_with_pad(
+            q_data + i * qStrideB + j * qStrideH + m * qStrideM,
+            query_t_padding_ptr,
+            cur_qSplitSize,
+            headSize,
+            cur_qSplitSize,
+            eheadSize,
+            qStrideM
+          );
+        }
+      }
+
+{%- if has_full_kv_block %}
+      for (int64_t n_idx = 0; n_idx < kv_indice_num + full_kv_indice_num ; n_idx += 1) {
+        auto n = n_idx < kv_indice_num ? kv_indice_list[n_idx]*kvSplitSize : full_kv_indice_list[n_idx - kv_indice_num]*kvSplitSize;
+{%- else %}
+      for (int64_t n_idx = 0; n_idx < kv_indice_num ; n_idx += 1) {
+        auto n = kv_indice_list[n_idx]*kvSplitSize;
+{%- endif %}
+
+        auto cur_n = n/kvSplitSize;
+        int64_t cur_kvSplitSize = std::min(kvSplitSize, kvSize - n);
+        int64_t cur_ekvSplitSize = (need_pack && cur_kvSplitSize % 2 != 0) ? cur_kvSplitSize + 1 : cur_kvSplitSize;
+
+        // Calculate scale * q @ k.T
+        auto i_kv = is_broadcast_bs_kv ? i/bs_shards : i;
+        auto j_kv = is_broadcast_head_kv ? j/gqa_shards : j;
+
+        if (!need_pack) {
+          auto k_addr =
+              k_data + i_kv * kStrideB + j_kv * kStrideH + n * kStrideN;
+
+          {{kernel.kernel_name}}_kernel_micro_gemm_transpose_b(false)>(
+              q_data + i * qStrideB + j * qStrideH +
+                  m * qStrideM,
+              k_addr,
+              qk_data,
+              cur_qSplitSize,
+              cur_kvSplitSize,
+              headSize,
+              qStrideM,
+              kStrideN,
+              cur_kvSplitSize);
+
+        } else {
+          at::native::cpublas::brgemm(
+              cur_qSplitSize,
+              cur_kvSplitSize,
+              eheadSize,
+              headSize_even ? qStrideM : eheadSize,
+              cur_kvSplitSize,
+              cur_kvSplitSize,
+              false,
+              !headSize_even
+                  ? query_t_padding_ptr
+                  : q_data + i * qStrideB + j * qStrideH + m * qStrideM,
+              key_reorder_ptr + i_kv * num_head_k * eheadSize * kvSize +
+                  j_kv * eheadSize * kvSize + n * eheadSize,
+              qk_data,
+              need_pack);
+        }
+
+        {{kernel.kernel_name}}_mul_scale_kernel(qk_data, scaling_factor, cur_qSplitSize*cur_kvSplitSize);
+
+{%- if score_mod and mask_mod %}
+        // TODO: reduce the number of calls of q_idx and kv_idx initialization
+        std::vector q_idx(cur_qSplitSize);
+        for (int64_t i = 0; i < cur_qSplitSize; ++i) {
+          q_idx[i] = m + i;
+        }
+
+        std::vector kv_idx(cur_kvSplitSize);
+        for (int64_t i = 0; i < cur_kvSplitSize; ++i) {
+          kv_idx[i] = n + i;
+        }
+
+        std::vector b_idx = {i};
+        std::vector h_idx = {j};
+
+        accum_t* in_ptr0 = qk_data;
+
+        auto in_ptr1 = b_idx.data();
+        auto in_ptr2 = h_idx.data();
+        auto in_ptr3 = q_idx.data();
+        auto in_ptr4 = kv_idx.data();
+
+        // apply score mod function
+        {
+            {{ template.generate_other_buffer("score_others", 0, "len_score_other", kernel.args) }}
+            accum_t* out_ptr{{score_buf_idx}} = in_ptr0;
+            {{ template.modification(score_mod, score_buf_name, score_buf_idx)|indent(12, false) }}
+        }
+
+        if ((std::find(kv_indice_list.begin(), kv_indice_list.end(), cur_n) != kv_indice_list.end()) ){
+          // Apply block mask, fill unused with -inf
+          {
+              {{ template.generate_other_buffer("mask_others", -1, "len_mask_other", kernel.args) }}
+              accum_t* out_ptr{{mask_buf_idx}} = in_ptr0;
+              {{ template.modification(mask_mod, mask_buf_name, mask_buf_idx)|indent(12, false) }}
+          }
+        }
+
+{%- endif %}
+        // Update coefficients with Softmax
+        accum_t tmp_max = 0, tmp_sum = 0, exp_tmp = 0;
+        for (int64_t row = 0; row < cur_qSplitSize; ++row) {
+          // apply scaling factor and max per row in fusion
+          {{kernel.kernel_name}}_mul_reduce_max_fusion_kernel(
+              qk_data + row * cur_kvSplitSize,
+              static_cast(1),
+              cur_kvSplitSize,
+              qk_data + row * cur_kvSplitSize,
+              tmp_max);
+          tmp_max = qk_max_data[row] > tmp_max ? qk_max_data[row] : tmp_max;
+          if (tmp_max == -std::numeric_limits::infinity()) {
+            // to avoid `nan = exp2f(-inf - (-inf))`
+            {{kernel.kernel_name}}_fill_stub(
+              {{kernel.kernel_name}}_conditional_data_ptr(qk_data, qk_reduced_data) + row * cur_ekvSplitSize,
+              static_cast(0), cur_kvSplitSize);
+          } else {
+            tmp_sum = tmp_max;
+            // qk <- exp(qk - max) and sum per row
+            {{kernel.kernel_name}}_exp_reduce_sum_fusion_kernel(
+              qk_data + row * cur_kvSplitSize, cur_kvSplitSize,
+              {{kernel.kernel_name}}_conditional_data_ptr(qk_data, qk_reduced_data) + row * cur_ekvSplitSize,
+              tmp_sum);
+            // exp_tmp <- exp(max[row] - max)
+            exp_tmp = std::exp(qk_max_data[row] - tmp_max);
+            // sum[row] <- sum + exp_tmp * sum[row]
+            qk_sum_data[row] = tmp_sum + exp_tmp * qk_sum_data[row];
+            // max[row] <- max
+            qk_max_data[row] = tmp_max;
+            // dst <- dst * exp_tmp
+            if (n_idx > 0) {
+              at::vec::map(
+              [exp_tmp](Vec x) { return x * Vec(exp_tmp); },
+              dst_data + row * headSize_v,
+              dst_data + row * headSize_v,
+              headSize_v);
+            }
+          }
+          if (need_pack && cur_kvSplitSize % 2 != 0) {
+            // Pad: [qSplitSize, cur_kvSplitSize] -> [qSplitSize, cur_kvSplitSize + 1]
+            *(qk_reduced_data + row * (1 + cur_kvSplitSize) + cur_kvSplitSize) = scalar_t(0);
+          }
+        }
+        // Calculate Softmax(q @ k.T) @ v
+        if (!need_pack) {
+          auto v_addr =
+              v_data + i_kv * vStrideB + j_kv * vStrideH + n * vStrideN;
+          // Fallback Half brgemm is slower than micro gemm
+          if (!std::is_same_v) {
+            at::native::cpublas::brgemm(
+                  cur_qSplitSize,
+                  headSize_v,
+                  cur_ekvSplitSize,
+                  cur_ekvSplitSize,
+                  vStrideN,
+                  headSize_v,
+                  n_idx > 0,
+                  {{kernel.kernel_name}}_conditional_data_ptr(qk_data, qk_reduced_data),
+                  v_addr,
+                  dst_data,
+                  need_pack);
+          } else {
+            if (n_idx > 0) {
+              {{kernel.kernel_name}}_kernel_micro_gemm(true)>(
+                {{kernel.kernel_name}}_conditional_data_ptr(qk_data, qk_reduced_data),
+                v_addr,
+                dst_data,
+                cur_qSplitSize,
+                headSize_v,
+                cur_ekvSplitSize,
+                cur_ekvSplitSize,
+                vStrideN,
+                headSize_v);
+            } else {
+              {{kernel.kernel_name}}_kernel_micro_gemm(false)>(
+                {{kernel.kernel_name}}_conditional_data_ptr(qk_data, qk_reduced_data),
+                v_addr,
+                dst_data,
+                cur_qSplitSize,
+                headSize_v,
+                cur_ekvSplitSize,
+                cur_ekvSplitSize,
+                vStrideN,
+                headSize_v);
+            }
+          }
+        } else {
+          int64_t psize = n / kvSplitSize * ekvSplitSize;
+          at::native::cpublas::brgemm(
+              cur_qSplitSize,
+              headSize_v,
+              cur_ekvSplitSize,
+              cur_ekvSplitSize,
+              headSize_v,
+              headSize_v,
+              n_idx > 0,
+              qk_reduced_data,
+              value_reorder_ptr +
+                  i_kv * num_head_k * kv_padding_size * headSize_v +
+                  j_kv * kv_padding_size * headSize_v + psize * headSize_v,
+              dst_data,
+              need_pack);
+        }
+      }
+
+      // dst <- dst / sum[row]
+      // reorder MHA output with strides
+      for (int64_t row = 0; row < cur_qSplitSize; ++row) {
+        // Row sums for full masked out rows are 0, we set them to 1
+        // in order to avoid NaNs in the output and instead set fully
+        // masked out rows to 0
+        qk_max_data[row] = qk_max_data[row] == -std::numeric_limits::infinity() ? 0 : qk_max_data[row];
+        qk_sum_data[row] = qk_sum_data[row] == 0 ? 1 : qk_sum_data[row];
+        accum_t sum_reciprocal = 1 / qk_sum_data[row];
+        at::vec::map(
+            [sum_reciprocal, is_skip_kv](Vec x) { return  is_skip_kv ? Vec(0.0) : x * Vec(sum_reciprocal); },
+            out_data + i * oStrideB + j * oStrideH + m * oStrideM + row * oStrideM,
+            dst_data + row * headSize_v,
+            headSize_v);
+      }
+
+      // Move to the next query
+      at::native::data_index_step(i, batchSize, j, num_head, k, qSlice);
+    }
+
+    at::native::cpublas::brgemm_release(need_pack);
+
+  });
+}
+"""
+
+
+class CppFlexAttentionTemplate(CppTemplate):
+    def __init__(
+        self,
+        input_nodes,
+        layout: ir.Layout,
+        scale,
+        score_mod,
+        mask_mod,
+        kv_block_size,
+        q_block_size,
+        has_other_buffer,
+        no_full_kv_block,
+        fake_buffers,
+        len_score_other,
+        len_mask_other,
+        kernel_input_name_to_buffer,
+        block_vars,
+    ) -> None:
+        assert layout.dtype in [torch.float, torch.bfloat16, torch.float16]
+        super().__init__("flex_attention", input_nodes, layout, parallel_num_threads())
+        self.scale = scale
+        self.score_mod = score_mod
+        self.mask_mod = mask_mod
+        self.score_buf_name = (
+            V.graph.register_buffer(self.score_mod) if self.score_mod else None
+        )
+        self.mask_buf_name = (
+            V.graph.register_buffer(self.mask_mod) if self.mask_mod else None
+        )
+
+        def get_idx(buf_name):
+            match = re.search(r"\d+", buf_name)
+            assert match, f"incorrect score buf name: {buf_name}"
+            return match.group()
+
+        self.score_buf_idx = (
+            get_idx(self.score_buf_name) if self.score_buf_name else None
+        )
+        self.mask_buf_idx = get_idx(self.mask_buf_name) if self.mask_buf_name else None
+        self.kv_block_size = kv_block_size
+        self.q_block_size = q_block_size
+        self.has_other_buffer = has_other_buffer
+        self.no_full_kv_block = no_full_kv_block
+        self.other_buffer_input_offset = 2
+        if self.no_full_kv_block:
+            self.other_buffer_input_offset = 0
+        self.fake_buffers = fake_buffers
+        self.len_score_other = len_score_other
+        self.len_mask_other = len_mask_other
+        self.kernel_input_name_to_buffer = kernel_input_name_to_buffer
+        self.block_vars = block_vars
+        self.extra_sizevars = list(
+            OrderedSet(
+                val
+                for val in self.kernel_input_name_to_buffer.values()
+                if isinstance(val, sympy.Symbol)
+            )
+        )
+        self.other_buf_start_idx = 5
+        self.score_mod_other_buffers = (
+            self.input_nodes[
+                self.other_buf_start_idx
+                + self.other_buffer_input_offset : self.other_buf_start_idx
+                + self.other_buffer_input_offset
+                + self.len_score_other
+            ]
+            if self.has_other_buffer
+            else None
+        )
+        self.mask_mod_other_buffers = (
+            self.input_nodes[
+                self.other_buf_start_idx
+                + self.other_buffer_input_offset
+                + self.len_score_other :
+            ]
+            if self.has_other_buffer
+            else None
+        )
+        self.other_ptr_data = {}  # type: ignore[var-annotated]
+
+    def update_kernel_args(self, kernel_args):
+        kernel_args.update(
+            {
+                key: value
+                for key, value in self.kernel_input_name_to_buffer.items()
+                if not isinstance(value, sympy.Symbol)
+            }
+        )
+        return kernel_args
+
+    def generate_other_buffer(self, buf_list, start_offset, len_attr, kernel_args):
+        kernel_input_name_to_buffer_name = {
+            key: value if isinstance(value, sympy.Symbol) else value.get_name()
+            for key, value in self.kernel_input_name_to_buffer.items()
+        }
+
+        def get_arg(name):
+            return kernel_input_name_to_buffer_name.get(name)
+
+        def get_arg_name(name):
+            if isinstance(get_arg(name), sympy.Symbol):
+                return kernel_args.sizevars.get(get_arg(name))
+            return kernel_args.input_buffers.get(get_arg(name))
+
+        if not self.has_other_buffer:
+            return ""
+
+        if start_offset == -1:
+            start_offset = self.len_score_other
+
+        length = getattr(self, len_attr)
+        for i in range(length):
+            pointer = f"in_ptr{self.other_buf_start_idx + start_offset + i}"
+            buffer_key = f"{buf_list}_{i}"
+            if pointer not in self.other_ptr_data:
+                self.other_ptr_data[pointer] = (
+                    get_arg_name(buffer_key),
+                    get_arg(buffer_key),
+                )
+
+        return "\n".join(
+            f"auto {ptr} = {name};" for ptr, (name, _) in self.other_ptr_data.items()
+        )
+
+    def modification(self, subgraph_buffer, output_name, output_idx):
+        assert isinstance(subgraph_buffer, ir.ComputedBuffer)
+        subgraph_buffer_data = subgraph_buffer.data
+        from ..loop_body import LoopBody
+        from ..utils import sympy_index_symbol_with_prefix, SymT
+        from ..virtualized import V
+        from .cpp import CppKernelProxy, KernelGroup, ParallelDepth
+
+        kernel_group = KernelGroup()
+        kernel_input_args = {
+            "score": "in_ptr0",
+            "b": "in_ptr1",
+            "h": "in_ptr2",
+            "q_idx": "in_ptr3",
+            "kv_idx": "in_ptr4",
+        }
+        if self.has_other_buffer:
+            kernel_input_args.update(
+                {arg: ptr for ptr, (_, arg) in self.other_ptr_data.items()}
+            )
+
+        kernel_output_args = {output_name: f"out_ptr{output_idx}"}
+
+        args = kernel_group.args
+        for name, inp in kernel_input_args.items():
+            args.input_buffers[name] = inp
+
+        for name, inp in kernel_output_args.items():
+            args.output_buffers[name] = inp
+
+        for name in self.extra_sizevars:
+            args.sizevars[name] = f"k{name}"
+
+        kernel_group.args = args
+
+        cpp_kernel_proxy = CppKernelProxy(kernel_group)
+        bodies = []
+        var_sizes_list = []
+        var_sizes = tuple(subgraph_buffer.get_size())
+        var_ranges = {
+            sympy_index_symbol_with_prefix(SymT.INDEX, i): sz
+            for i, sz in enumerate(var_sizes)
+        }
+
+        dst_layout = subgraph_buffer.get_layout()
+        output_index = dst_layout.make_indexer()([*var_ranges.keys()])
+
+        def fn(*args):
+            V.ops.store(
+                output_name,
+                output_index,
+                subgraph_buffer_data.make_loader()(args).value,
+            )
+
+        body = LoopBody(
+            fn,
+            (list(var_ranges.keys())),
+            var_ranges,
+            list(var_ranges.keys()),
+            tuple(),
+        )
+
+        from ..loop_body import MemoryUsageType
+
+        assert all(
+            mem.buffer_name in kernel_group.args.input_buffers
+            for mem in body.memory_usage[MemoryUsageType.LOAD]
+        ), (
+            "All the buffers in the score and mask subgraph should be in kernel_group.args.input_buffers"
+        )
+
+        bodies.append(body)
+        var_sizes_list.append((var_sizes, ()))
+
+        cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list)
+
+        def max_parallel_depth():
+            return ParallelDepth(parallel_depth=0, start_depth=0)
+
+        # This loop is not parallelized since it is not the outermost loop.
+        with patch.object(
+            cpp_kernel_proxy.loop_nest, "max_parallel_depth", max_parallel_depth
+        ):
+            kernel_group.finalize_kernel(cpp_kernel_proxy, [])
+        output_code = kernel_group.loops_code.getvalue()
+
+        var_q_symbol, var_kv_symbol = self.block_vars
+        # See [Note] Handle the case where the split sizes are not statically known.
+        # We don't know the value of qBlockSize and rkvBlockSize during compilation time
+        # thus we've represented them by symbols.
+        # We change the symbol strings back to "cur_qSplitSize" and "cur_kvSplitSize"
+        # in the generated code thus they'll be filled with the real value during runtime.
+        if var_q_symbol in kernel_group.args.sizevars:
+            output_code = output_code.replace(
+                kernel_group.args.sizevars[var_q_symbol], "cur_qSplitSize"
+            )
+        if var_kv_symbol in kernel_group.args.sizevars:
+            output_code = output_code.replace(
+                kernel_group.args.sizevars[var_kv_symbol], "cur_kvSplitSize"
+            )
+
+        return output_code
+
+    @staticmethod
+    def add_choices(
+        choices,
+        input_nodes,
+        layout,
+        scale,
+        score_mod,
+        mask_mod,
+        kv_block_size,
+        q_block_size,
+        has_other_buffer,
+        no_full_kv_block,
+        fake_buffers,
+        len_score_other,
+        len_mask_other,
+        kernel_input_name_to_buffer,
+        block_vars,
+    ):
+        def preprocessor(input_nodes, layout):
+            return input_nodes, layout
+
+        def postprocessor(output):
+            return output
+
+        template = DataProcessorTemplateWrapper(
+            CppFlexAttentionTemplate,
+            preprocessor,
+            postprocessor,
+            input_nodes=input_nodes,
+            layout=layout,
+            scale=scale,
+            score_mod=score_mod,
+            mask_mod=mask_mod,
+            kv_block_size=kv_block_size,
+            q_block_size=q_block_size,
+            has_other_buffer=has_other_buffer,
+            no_full_kv_block=no_full_kv_block,
+            fake_buffers=fake_buffers,
+            len_score_other=len_score_other,
+            len_mask_other=len_mask_other,
+            kernel_input_name_to_buffer=kernel_input_name_to_buffer,
+            block_vars=block_vars,
+        )
+        template.maybe_append_choice(choices)
+        return template
+
+    def apply_score_mod(self, score, b, h, q_idx, kv_idx):
+        return self.score_mod.graph_module(score, b, h, q_idx, kv_idx).item()
+
+    def render(  # type: ignore[override,return]
+        self,
+        kernel,
+        template_buffer_node: Optional[ir.CppTemplateBuffer] = None,
+        epilogue_nodes: Optional[list[ir.IRNode]] = None,
+        **kwargs,
+    ) -> str:
+        if epilogue_nodes is not None and epilogue_nodes != []:
+            raise NotImplementedError(
+                "Unsupported for `epilogue_nodes` in CppFlexAttentionTemplate."
+            )
+        # Query (Batch x Num_heads  x Q_seq_len  x Dim_per_head)
+        #     -> (Batch x Q_seq_len  x Num_heads  x Dim_per_head)
+        #  Key   (Batch x Num_heads  x KV_seq_len x Dim_per_head)
+        #     -> (Batch x KV_seq_len x Num_heads  x Dim_per_head)
+        #  Value (Batch x Num_heads  x KV_seq_len x Dim_per_head)
+        #     -> (Batch x KV_seq_len x Num_heads  x Dim_per_head)
+
+        query = kernel.permute(self.input_nodes[0], [0, 2, 1, 3])
+        key = kernel.permute(self.input_nodes[1], [0, 2, 1, 3])
+        value = kernel.permute(self.input_nodes[2], [0, 2, 1, 3])
+        self.accumulate_dtype = torch.float
+        self.input_dtype = query.layout.dtype
+
+        num_threads = parallel_num_threads()
+        assert isinstance(self.output_node, ir.IRNode)
+        buf_out: ir.IRNode = TensorBox.create(self.output_node)
+        if template_buffer_node is not None:
+            buf_out = template_buffer_node
+        options = dict(
+            query=query,
+            key=key,
+            value=value,
+            kv_num_blocks=self.input_nodes[3],
+            kv_indices=self.input_nodes[4],
+            full_kv_num_blocks=(
+                self.input_nodes[5] if not self.no_full_kv_block else None
+            ),
+            full_kv_indices=self.input_nodes[6] if not self.no_full_kv_block else None,
+            score_mod_other_buffers=self.score_mod_other_buffers,
+            mask_mod_other_buffers=self.mask_mod_other_buffers,
+            scale=self.scale,
+            has_full_kv_block=not self.no_full_kv_block,
+            accumulate_dtype=self.accumulate_dtype,
+            query_dtype=self.input_dtype,
+            kvBlockSize=self.kv_block_size,
+            qBlockSize=self.q_block_size,
+            template=self,
+            output=buf_out,
+            kernel=kernel,
+            num_thread=num_threads,
+            score_mod=self.score_mod,
+            mask_mod=self.mask_mod,
+            score_buf_name=self.score_buf_name,
+            mask_buf_name=self.mask_buf_name,
+            score_buf_idx=self.score_buf_idx,
+            mask_buf_idx=self.mask_buf_idx,
+        )
+        with contextlib.ExitStack() as stack:
+            for buf in self.fake_buffers:
+                stack.enter_context(
+                    patch.object(V.graph, "get_dtype", self._fake_get_dtype(buf))
+                )
+            return self._template_from_string(FLEX_ATTENTION_TEMPLATE).render(**options)
+
+    def codegen_softmax_fusion(self, kernel_name: str):
+        # TODO: use inductor IR to rewrite those fusions
+        return self._template_from_string(SOFTMAX_FUSIONS).render(
+            dict(kernel_name=kernel_name)
+        )
+
+    def codegen_brgemm_pack_function(self, kernel_name: str):
+        # TODO: make them general for common bmm templates
+        return self._template_from_string(BRGEMM_PACK_FUNCTIONS).render(
+            dict(kernel_name=kernel_name)
+        )
+
+    def codegen_allocate_buffer(self, buffer_name: str, buffer_dtype, buffer_size):
+        return self._template_from_string(ALLOCATE_BUFFER).render(
+            dict(
+                buffer_name=buffer_name,
+                buffer_dtype=buffer_dtype,
+                buffer_size=buffer_size,
+            )
+        )
+
+    def micro_gemm_define(self, kernel_name: str):
+        from torch._inductor.codegen.cpp_gemm_template import (
+            CppTemplateKernel,
+            parallel_num_threads,
+        )
+        from torch._inductor.codegen.cpp_micro_gemm import CppMicroGemmFP32Vec
+        from torch._inductor.virtualized import V
+
+        micro_gemm_trans = CppMicroGemmFP32Vec(
+            kernel_name + "_kernel_micro_gemm_transpose_b",
+            self.input_dtype,
+            self.input_dtype,
+            self.accumulate_dtype,
+            self.accumulate_dtype,
+            GemmBlocking(1, 16, 1),
+            1,
+            True,
+            True,
+        )
+
+        micro_gemm = CppMicroGemmFP32Vec(
+            kernel_name + "_kernel_micro_gemm",
+            self.input_dtype,
+            self.input_dtype,
+            self.accumulate_dtype,
+            self.accumulate_dtype,
+            GemmBlocking(1, 16, 1),
+            1,
+            True,
+            False,
+        )
+
+        with V.set_graph_handler(V.graph):
+            kernel = CppTemplateKernel("cpp_micro_gemm", parallel_num_threads())
+            code_trans = micro_gemm_trans.codegen_define(kernel)
+            code = micro_gemm.codegen_define(kernel)
+        return code + code_trans
+
+    def codegen_micro_gemm(self, kernel_name: str):
+        micro_gemm = self.micro_gemm_define(kernel_name)
+        GEMM_SOURCE_CODE = MICRO_GEMM_TEMPLATE.replace("GEMM_DEFINE", micro_gemm)
+        return self._template_from_string(GEMM_SOURCE_CODE).render()
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_gemm_template.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_gemm_template.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b15ef253a4d0bf61e7449fd77d15f7107997019
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_gemm_template.py
@@ -0,0 +1,1819 @@
+# mypy: allow-untyped-defs
+import contextlib
+import logging
+import math
+from collections.abc import Callable
+from functools import lru_cache
+from typing import Any, cast, Optional, TypeVar, Union
+from unittest.mock import patch
+
+import torch
+import torch.utils
+from torch.utils._ordered_set import OrderedSet
+
+from ..._dynamo.utils import counters
+from .. import config, ir, lowering as L
+from ..kernel.mm_common import mm_args
+from ..select_algorithm import DataProcessorTemplateWrapper
+from ..utils import (
+    has_free_symbols,
+    is_same_mkldnn_tensor,
+    is_same_tensor,
+    parallel_num_threads,
+)
+from ..virtualized import ops, V
+from .cpp import get_export_declaration
+from .cpp_micro_gemm import (
+    CppMicroBrgemm,
+    CppMicroGemm,
+    CppMicroGemmAMX,
+    CppMicroGemmFP32Vec,
+    create_micro_gemm,
+    is_int8_woq_gemm_small_m_dim_corner_case,
+    LayoutType,
+)
+from .cpp_template import CppTemplate
+from .cpp_template_kernel import CppTemplateKernel
+from .cpp_utils import (
+    create_epilogue_with_attr,
+    DTYPE_TO_CPP,
+    GemmBlocking,
+    get_gemm_template_output_and_compute_dtype,
+)
+
+
+log = logging.getLogger(__name__)
+
+GEMM_TEMPLATE_INIT_BLOCKING_BASIC_BLOCK = r"""
+    constexpr int64_t num_threads = {{num_threads}};
+    constexpr int64_t N = {{N}};
+    constexpr int64_t K = {{K}};
+    constexpr int64_t Mr = {{micro_gemm.register_blocking.block_m}};
+    constexpr int64_t Nr = {{micro_gemm.register_blocking.block_n}};
+    constexpr int64_t Kr = {{micro_gemm.register_blocking.block_k}};
+    constexpr int64_t Nr_blocks = (N + Nr - 1) / Nr;
+    constexpr int64_t Kr_blocks = (K + Kr - 1) / Kr;
+{%- if is_dynamic_M %}
+    const int64_t M = {{kernel.size(GemmOut, 0)}};
+    const int64_t Mr_blocks = (M + Mr - 1) / Mr;
+{%- else %}
+    constexpr int64_t M = {{kernel.size(GemmOut, 0)}};
+    constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr;
+{%- endif %}
+"""
+
+GEMM_TEMPLATE_INIT_BLOCKING_EXTENDED = r"""
+{%- if is_dynamic_M %}
+    {%- if num_threads > 1 %}
+    int64_t Mt_blocks, Nt_blocks, Kt_blocks;
+    mm_get_thread_blocking(num_threads, {{config.cpp.gemm_max_k_slices}}, M, N, K, Mr, Nr, Kr, Mt_blocks, Nt_blocks, Kt_blocks);
+    {%- else %}
+    const auto Mt_blocks = Mr_blocks;
+    const auto Nt_blocks = Nr_blocks;
+    const auto Kt_blocks = Kr_blocks;
+    {%- endif %}
+    int64_t Mc_blocks, Nc_blocks, Kc_blocks;
+    uint32_t L1_cache_size = {{L1_cache_size}};
+    uint32_t L2_cache_size = {{L2_cache_size}};
+    mm_get_cache_blocking<{{kernel.dtype(X)}}, {{kernel.dtype(W)}}>(
+        num_threads,
+        M,
+        N,
+        K,
+        Mr,
+        Nr,
+        Kr,
+        Mt_blocks,
+        Nt_blocks,
+        Kt_blocks,
+        Mc_blocks,
+        Nc_blocks,
+        Kc_blocks,
+        L1_cache_size,
+        L2_cache_size
+    );
+    const int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks;
+    const int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks;
+    const int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks;
+    const int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks;
+    const int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
+{%- else %}
+    constexpr int64_t Mt_blocks = {{template.thread_blocking(num_threads).block_m}};
+    constexpr int64_t Nt_blocks = {{template.thread_blocking(num_threads).block_n}};
+    constexpr int64_t Kt_blocks = {{template.thread_blocking(num_threads).block_k}};
+    constexpr int64_t Mc_blocks = {{template.cache_blocking(num_threads).block_m}};
+    constexpr int64_t Nc_blocks = {{template.cache_blocking(num_threads).block_n}};
+    constexpr int64_t Kc_blocks = {{template.cache_blocking(num_threads).block_k}};
+    constexpr int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks;
+    constexpr int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks;
+    constexpr int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks;
+    constexpr int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks;
+    constexpr int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
+{%- endif %}
+{%- if is_woq_int4 %}
+    int64_t group_size = *q_group_size;
+{%- endif %}
+
+    // make sure all partitions are assigned
+    {{kernel.assert_function}}(
+        Mt_blocks * Nt_blocks * Kt_blocks * {{num_threads}} >= Mr_blocks * Nr_blocks * Kr_blocks,
+        "Not all partitions are assigned."
+    );
+"""
+
+GEMM_TEMPLATE_MULTI_THREADS_PARAMS = r"""
+const int tid = omp_get_thread_num();
+const int64_t k_group_id = tid / num_Kt_blocks;
+const int64_t k_slice_id = tid % num_Kt_blocks;
+const int64_t n_group_id = k_group_id / num_Nt_blocks;
+const int64_t n_slice_id = k_group_id % num_Nt_blocks;
+const int64_t k_block_start = k_slice_id * Kt_blocks;
+const int64_t k_block_end = std::min(k_block_start + Kt_blocks, Kr_blocks);
+const int64_t n_block_start = n_slice_id * Nt_blocks;
+const int64_t n_block_end = std::min(n_block_start + Nt_blocks, Nr_blocks);
+const int64_t m_block_start = std::min(n_group_id * Mt_blocks, Mr_blocks);
+const int64_t m_block_end = std::min(m_block_start + Mt_blocks, Mr_blocks);
+const int64_t num_Mc_blocks_per_thread = (m_block_end - m_block_start + Mc_blocks - 1) / Mc_blocks;
+"""
+
+GEMM_TEMPLATE_SINGLE_THREAD_PARAMS = r"""
+constexpr int tid = 0;
+constexpr int64_t k_group_id = 0;
+constexpr int64_t k_slice_id = 0;
+constexpr int64_t n_group_id = 0;
+constexpr int64_t n_slice_id = 0;
+constexpr int64_t m_block_start = 0;
+constexpr int64_t n_block_start = 0;
+constexpr int64_t n_block_end = Nr_blocks;
+constexpr int64_t k_block_start = 0;
+constexpr int64_t k_block_end = Kr_blocks;
+{%- if is_dynamic_M %}
+const int64_t num_Mc_blocks_per_thread = num_Mc_blocks;
+const int64_t m_block_end = Mr_blocks;
+{%- else %}
+constexpr int64_t num_Mc_blocks_per_thread = num_Mc_blocks;
+constexpr int64_t m_block_end = Mr_blocks;
+{%- endif %}
+"""
+
+GEMM_TEMPLATE_M_LOOP_PARAMS = r"""
+const int64_t my_mc_block_id = (mc_block_id + n_slice_id) % num_Mc_blocks_per_thread;
+const int64_t mc = m_block_start + my_mc_block_id * Mc_blocks;
+const int64_t m_start = mc * Mr;
+const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M);
+const int64_t m_size = m_end - m_start;
+"""
+
+GEMM_TEMPLATE_N_LOOP_PARAMS = r"""
+const int64_t n_start = nc * Nr;
+const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N);
+const int64_t n_size = n_end - n_start;
+// NB: assume we pad N, nc_block_end won't exceed padded N here.
+const int64_t nc_block_end = std::min(nc + Nc_blocks, n_block_end);
+"""
+
+GEMM_TEMPLATE_MICROKERNEL_DEF = r"""
+{{template.header().getvalue()}}
+
+{{micro_gemm.codegen_define(kernel)}}
+"""
+
+GEMM_TEMPLATE_STUB_DEF = r"""
+{%- if x_scale is not none %}
+    {%- set kernel_args = {"X": X, "W": W, "inp": inp, "x_scale": x_scale, "x_zp": x_zp, "w_scale": w_scale, "w_zp": w_zp,} %}
+{%- elif is_woq_int4 %}
+    {%- set kernel_args = {"X": X, "W": W, "q_group_size": q_group_size, "qscale_and_zeros": qscale_and_zeros} %}
+{%- else %}
+    {%- set kernel_args = {"X": X, "W": W, "inp": inp} %}
+{%- endif %}
+
+extern "C" {{export_declaration}}
+{{kernel.def_kernel(inputs=kernel_args, outputs={"Y": Y}, aliases=aliases)}}
+"""
+
+GEMM_TEMPLATE = r"""
+{{ template.codegen_gemm_stub_def() }}
+{
+    {{ kernel.maybe_codegen_profile() }}
+    {{ template.codegen_blocks(
+        num_threads, N, K, micro_gemm, is_dynamic_M, kernel, GemmOut, config, L1_cache_size, L2_cache_size, X, W
+    ) }}
+
+{%- if maybe_k_slicing %}
+    std::unique_ptr[]> local_buf_ptrs;
+    if (num_Kt_blocks > 1) {
+        local_buf_ptrs.reset(new std::unique_ptr<{{DTYPE_TO_CPP[acc_buf_dtype]}}[]>[num_Mc_blocks * num_Nc_blocks * num_Kt_blocks]);
+    }
+{%- endif %}
+
+{%- if num_threads > 1 %}
+    #pragma omp parallel num_threads({{num_threads}})
+    {
+        {{ template.codegen_multi_threads_params()|indent(8, false) }}
+{%- else %}
+    {
+        {{ template.codegen_single_thread_params(is_dynamic_M)|indent(8, false) }}
+{%- endif %}
+        {{ micro_gemm.codegen_init(kernel) }}
+{%- if use_local_acc %}
+    {%- set acc_buf_name = "local_acc_buf" %}
+        {{ kernel.define_buffer(acc_buf_name, ["Mc_blocks*Mr", "Nc_blocks*Nr"], acc_buf_dtype) }}
+{%- endif %}
+        for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) {
+            {{ template.codegen_m_loop_params()|indent(12, false) }}
+            for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
+                {{ template.codegen_n_loop_params()|indent(16, false) }}
+{%- if use_local_acc %}
+    {%- set acc = kernel.local_buffers[acc_buf_name] %}
+                {{ kernel.reinit_buffer_if_null(acc_buf_name) }}
+{%- else %}
+    {%- set acc = kernel.slice_nd(GemmOut, [("m_start", "m_end"), ("n_start", "n_end")]) %}
+{%- endif %}
+                for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) {
+                    int64_t k_start = kc * Kr;
+                    int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K);
+{%- set tile_X = kernel.slice_nd(X, [("m_start", "m_end"), ("k_start", "k_end")]) %}
+                    for (int64_t nci = nc; nci < nc_block_end; nci++) {
+{%- set acc_slice = kernel.slice_nd(acc, [("0", "m_end - m_start"), ("(nci - nc)*Nr", "(nci - nc + 1)*Nr")]) %}
+{%- if template.should_block_weights and not is_woq_int4 %}
+{%- set tile_W_3d = kernel.slice_nd(W, [("nci", "nci + 1"), ("k_start", "k_end"), ()]) %}
+{%- set tile_W = kernel.view(tile_W_3d, ["k_end - k_start", micro_gemm.register_blocking.block_n]) %}
+{%- else %}
+    {%- if is_woq_int4 %}
+        {%- set tile_W = kernel.slice_nd(W, [("nci * Nr", "(nci + 1) * Nr"), ("k_start * Nr / 2", "k_end * Nr / 2")]) %}
+        {%- set tile_qparam = kernel.slice_nd(
+            qscale_and_zeros, [("k_start // group_size", "k_end // group_size"), ("nci * Nr", "(nci + 1) * Nr"), ()]) %}
+    {%- else %}
+        {%- set tile_W = kernel.slice_nd(W, [("k_start", "k_end"), ("n_start", "n_start + n_size")]) %}
+        {%- set tile_qparam = None %}
+    {%- endif %}
+{%- endif %}
+                        if (kc == k_block_start) {
+                            {{ micro_gemm.codegen_call(kernel,
+                                                       tile_X,
+                                                       tile_W,
+                                                       acc_slice,
+                                                       accum=False,
+                                                       qscale_and_zeros=tile_qparam)|indent(28, false)
+                            }}
+                        } else {
+                            {{ micro_gemm.codegen_call(kernel,
+                                                       tile_X,
+                                                       tile_W,
+                                                       acc_slice,
+                                                       accum=True,
+                                                       qscale_and_zeros=tile_qparam)|indent(28, false)
+                            }}
+                        }
+                    }
+                }
+{%- if maybe_k_slicing %}
+                if (num_Kt_blocks > 1) {
+                    const int64_t mxn_cache_block_id = (mc / Mc_blocks) * num_Nc_blocks + nc;
+                    local_buf_ptrs[mxn_cache_block_id * num_Kt_blocks + k_slice_id].reset(
+                        {{ kernel.release_buffer(acc_buf_name) }});
+                } else
+{%- endif %}
+                {
+{%- set tile_Y = kernel.slice_nd(Y_2d, [("m_start", "m_end"), ("n_start", "n_end")]) %}
+{%- set tile_acc = kernel.slice_nd(acc, [("0", "m_end - m_start"), ("0", "n_end - n_start")]) %}
+                    {{ kernel.store_output(
+                        tile_Y, tile_acc, GemmOut, epilogue_nodes, offsets=("m_start", "n_start"), reindexers=reindexers
+                    )|indent(20, false)
+                    }}
+                }
+            }
+        }
+{%- if maybe_k_slicing %}
+        if (num_Kt_blocks > 1) {
+            #pragma omp barrier
+            for (int64_t mc = m_block_start; mc < m_block_end; mc += Mc_blocks) {
+                // We slice M-dim and each thread in the k-slicing group works on a slice
+                const int64_t m_start_unsliced = mc * Mr;
+                const int64_t m_end_unsliced = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M);
+                const int64_t m_size_unsliced = m_end_unsliced - m_start_unsliced;
+                const int64_t m_slice_size = (m_size_unsliced + num_Kt_blocks - 1) / num_Kt_blocks;
+                const int64_t m_start = std::min(m_start_unsliced + m_slice_size * k_slice_id, m_end_unsliced);
+                const int64_t m_end = std::min(m_start_unsliced + m_slice_size * (k_slice_id + 1), m_end_unsliced);
+                const int64_t m_size = m_end - m_start;
+                const int64_t m_offset = m_start - m_start_unsliced;
+                for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
+                    const int64_t n_start = nc * Nr;
+                    const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N);
+                    const int64_t n_size = n_end - n_start;
+                    const int64_t mxn_cache_block_id = (mc / Mc_blocks) * num_Nc_blocks + nc;
+                    auto {{acc_buf_name}} = local_buf_ptrs[mxn_cache_block_id * num_Kt_blocks].get();
+                    for (int64_t other_slice = 1; other_slice < num_Kt_blocks; other_slice++) {
+                        auto other_acc = local_buf_ptrs[mxn_cache_block_id * num_Kt_blocks + other_slice].get();
+                        for (int64_t m = m_offset; m < m_offset + m_size; m++) {
+                            #pragma omp simd
+                            for (int64_t n = 0; n < n_size; n++) {
+                                {{acc_buf_name}}[m*Nr + n] += other_acc[m*Nr + n];
+                            }
+                        }
+                    }
+    {%- set tile_acc_m_slice = kernel.slice_nd(tile_acc, [("m_offset", "m_offset + m_end - m_start"), ()]) %}
+                    {{ kernel.store_output(
+                        tile_Y, tile_acc_m_slice, GemmOut, epilogue_nodes, offsets=("m_start", "n_start"), reindexers=reindexers
+                    )|indent(20, false)
+                    }}
+                }
+            }
+        }
+{%- endif %}
+        {{ micro_gemm.codegen_finalize(kernel) }}
+    }
+}
+"""
+
+SMALL_M_GEMM_TEMPLATE = r"""
+{{ template.codegen_gemm_stub_def() }}
+{
+    {{ kernel.maybe_codegen_profile() }}
+    {{ template.codegen_blocks(
+        num_threads, N, K, micro_gemm, is_dynamic_M, kernel, GemmOut, config, L1_cache_size, L2_cache_size, X, W
+    ) }}
+    # pragma omp parallel
+    {
+        #pragma omp for nowait
+        for (int64_t nr_block_id = 0; nr_block_id < Nr_blocks; nr_block_id++) {
+            // Handle one output M * Nr block in each thread
+            int64_t n_start = nr_block_id * Nr;
+            int64_t n_end = (nr_block_id + 1) * Nr;
+{%- if use_local_acc %}
+    {%- set acc_buf_name = "local_acc_buf" %}
+            {{ kernel.define_stack_allocated_buffer(acc_buf_name, ["M", "Nr"], acc_buf_dtype) }}
+    {%- set acc = kernel.local_buffers[acc_buf_name] %}
+{%- else %}
+    {%- set acc = kernel.slice_nd(GemmOut, [(0, "M"), ("n_start", "n_end")]) %}
+{%- endif %}
+            for (int64_t kr_block_id = 0; kr_block_id < Kr_blocks; kr_block_id++) {
+                // this loop is not parallelized
+                int64_t k_start = kr_block_id * Kr;
+                int64_t k_end = std::min((kr_block_id + 1) * Kr, K);
+{%- set tile_X = kernel.slice_nd(X, [(0, "M"), ("k_start", "k_end")]) %}
+{%- set tile_W_3d = kernel.slice_nd(W, [("nr_block_id", "nr_block_id + 1"), ("k_start", "k_end"), ()]) %}
+{%- set tile_W = kernel.view(tile_W_3d, ["k_end - k_start", micro_gemm.register_blocking.block_n]) %}
+                if C10_UNLIKELY(kr_block_id == 0) {
+                    {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc, accum=False, prefetch=True)|indent(20, false) }}
+                } else if C10_UNLIKELY(k_end == K) {
+                    {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc, accum=True, prefetch=False)|indent(20, false) }}
+                } else {
+                    {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc, accum=True, prefetch=True)|indent(20, false) }}
+                }
+            }
+{%- set tile_Y = kernel.slice_nd(Y_2d, [("0", "M"), ("n_start", "n_end")]) %}
+{%- set tile_acc = kernel.slice_nd(acc, [("0", "M"), ("0", "n_end - n_start")]) %}
+            {{ kernel.store_output(
+                tile_Y, tile_acc, GemmOut, epilogue_nodes, offsets=("0", "n_start"), reindexers=reindexers
+            )|indent(20, false) }}
+        }
+    }
+}
+"""
+
+
+def _is_int8_gemm(inputs):
+    return (
+        isinstance(inputs[0], ir.IRNode)
+        and inputs[0].get_dtype() in [torch.uint8, torch.int8]
+    ) or (
+        isinstance(inputs[0], torch.Tensor)
+        and inputs[0].dtype in [torch.uint8, torch.int8]
+    )
+
+
+def get_padded_n(n, block_n):
+    return (n + block_n - 1) // block_n * block_n
+
+
+_T = TypeVar("_T", ir.IRNode, torch.Tensor)
+
+
+def transpose_w(W: _T, trans_w: bool) -> _T:
+    """
+    Transpose W based on the trans_w flag.
+    """
+    if isinstance(W, ir.IRNode):
+        if trans_w:
+            if not isinstance(W, ir.TensorBox):
+                # pyrefly: ignore [bad-assignment]
+                W = ir.TensorBox(W)
+            W = L.permute(W, [1, 0])
+    else:
+        if trans_w:
+            assert isinstance(W, torch.Tensor)
+            # pyrefly: ignore [bad-assignment]
+            W = W.transpose(0, 1)
+    # pyrefly: ignore [bad-return]
+    return W
+
+
+def expand_bias(B: Optional[_T], X: _T) -> Optional[_T]:
+    """
+    Expand Bias to the same size of X.
+    """
+    if B is not None:
+        if isinstance(B, ir.IRNode):
+            if not isinstance(B, ir.TensorBox):
+                # pyrefly: ignore [bad-assignment]
+                B = ir.TensorBox(B)
+            assert hasattr(X, "get_size")
+            # pyrefly: ignore [missing-attribute]
+            B = L.expand(B, (X.get_size()[0], B.get_size()[-1]))
+        else:
+            assert isinstance(B, torch.Tensor)
+            assert isinstance(X, torch.Tensor)
+            # pyrefly: ignore [bad-assignment]
+            B = B.expand(X.shape[0], B.shape[-1])
+    return B
+
+
+def prune_tensors(input_nodes: list[ir.IRNode], new_input_nodes: list[ir.IRNode]):
+    """
+    Prune unused tensors from `V.graph` since the GEMM Template use new packed weight.
+    """
+
+    def share_storage(base_tensor: torch.Tensor, comp_tensor: torch.Tensor):
+        return base_tensor.is_mkldnn == comp_tensor.is_mkldnn and (
+            is_same_tensor(base_tensor, comp_tensor)
+            or is_same_mkldnn_tensor(base_tensor, comp_tensor)
+        )
+
+    def get_candidates(input_nodes, new_input_nodes):
+        # Only Constant Buffer like weight and bias might be changed in GEMM Template.
+        # The Inductor IR Node may changed, but still share the storage. For example:
+        # bias in bfloat16 case which only do the expand
+        return [
+            node
+            for node in input_nodes
+            if (
+                node not in new_input_nodes
+                and isinstance(node, (ir.TensorBox, ir.StorageBox))
+                and node.get_name() in V.graph.constants
+                and not any(
+                    (
+                        isinstance(new_node, (ir.TensorBox, ir.StorageBox))
+                        and new_node.get_name() in V.graph.constants
+                        and share_storage(
+                            V.graph.constants[node.get_name()],
+                            V.graph.constants[new_node.get_name()],
+                        )
+                    )
+                    for new_node in new_input_nodes
+                )
+            )
+        ]
+
+    for candidate_node in get_candidates(input_nodes, new_input_nodes):
+        # By using the new packed weight for the GEMM template, we can prune the
+        # old weight if it has no other users. This saves memory but makes the FX graph
+        # non-retraceable. To support retracing, we can add a repack node to the
+        # FX graph. For example:
+        # mkldnn._linear_pointwise <- repack_linear_wgt <- packed_wgt_for_template
+        candidate_tensor_users = 0
+        candidate_tensor = V.graph.constants[candidate_node.get_name()]
+        for node in reversed(V.graph.graph.nodes):
+            # Case may happen when the candidate tensor is used by more than 1 get_attr node
+            # https://github.com/pytorch/pytorch/issues/134998
+            if node.op == "get_attr" and hasattr(
+                V.graph.module, node.target
+            ):  # candidate tensor might already be deleted
+                comp_tensor = getattr(V.graph.module, node.target)
+                if isinstance(comp_tensor, torch.Tensor) and share_storage(
+                    candidate_tensor, comp_tensor
+                ):
+                    candidate_tensor_users += 1
+
+        for node in reversed(V.graph.graph.nodes):
+            # The get_attr node has only 1 user fx node
+            # The candidate tensor has been used by only 1 get_attr node
+            if (
+                node.op == "get_attr"
+                and node.target == candidate_node.get_name()
+                and len(node.users) == 1
+                and candidate_tensor_users == 1
+            ):
+                del V.graph.constants[node.target]
+                delattr(V.graph.module, node.target)
+                delattr(V.graph.graph.owning_module, node.target)
+                counters["inductor"]["select_algorithm_weight_prune"] += 1
+
+
+def gen_2d_view_of_epilogue_buf(
+    Y: ir.Buffer,
+    template_buffer: ir.Buffer,
+    epilogue_nodes: list[ir.IRNode],
+    reindexers: list[Optional[Callable[[list[Any]], list[Any]]]],
+    default_reindexers: list[Optional[Callable[[list[Any]], list[Any]]]],
+) -> tuple[
+    Union[ir.Buffer, ir.ReinterpretView],
+    list[Optional[Callable[[list[Any]], list[Any]]]],
+]:
+    """
+    The dimension and the indexing could be different between the GEMM output, i.e. `template_buffer`, which is
+    2D with MxN) and the output from the template after epilogues, i.e. `Y`. In the GEMM template code,
+    we are not aware of the dimension and the indexing of the epilogues and always work on 2D tiles according to
+    the indexing of the GEMM output.
+    In this function, we return a 2D buffer (`Y_2d`) according to GEMM output (reinterpreted from `Y` if needed) and
+    build a reindexer that converts the indexing of `Y` into `Y_2d`.
+    """
+    Y_2d: Union[ir.Buffer, ir.ReinterpretView] = Y
+    if (
+        Y.get_size() == template_buffer.get_size()
+        and Y.get_stride() == template_buffer.get_stride()
+    ):
+        reindexers.extend(default_reindexers)
+        Y_2d = Y
+    else:
+
+        def get_reindexer(epilogue_node, default_reindexer=None):
+            # From template_buffer to epilogue_node_ordered (ordered by stride decreasingly, in dense format), for example:
+            #   template_buffer:
+            #       size (324, 512), stride (512, 1)
+            #   epilogue_node_ordered (ordered by stride decreasingly, in dense format):
+            #       size (1, 18, 18, 512), stride (165888, 9216, 512, 1)
+            stride_order = list(
+                ir.get_stride_order(
+                    V.graph.sizevars.size_hints(epilogue_node.get_stride())
+                )
+            )
+            fill_order = ir.stride_order2fill_order(stride_order)
+            reversed_fill_order = list(reversed(fill_order))
+            size_with_stride_ordered_decreasingly = [
+                epilogue_node.get_size()[i] for i in reversed_fill_order
+            ]
+            reshape_reindex = ir.View.dynamic_reshape_indexer(
+                size_with_stride_ordered_decreasingly,
+                template_buffer.get_size(),
+            )
+            if default_reindexer:
+                reshape_reindex = ir.fuse_reindexing(reshape_reindex, default_reindexer)
+
+            # From epilogue_node_ordered (ordered by stride decreasingly, in dense format) to epilogue_node, for example:
+            #   epilogue_node_ordered (ordered by stride decreasingly, in dense format):
+            #       size (1, 18, 18, 512), stride (165888, 9216, 512, 1)
+            #   epilogue_node:
+            #       size (1, 18, 18, 512), stride (165888, 1, 9216, 512)
+            from_stride_ordered_decreasingly_to_epilogue_node_order = [
+                (len(stride_order) - 1) - stride_order[i]
+                for i in range(len(stride_order))
+            ]
+            stride_reindex = ir.same_reorder(
+                from_stride_ordered_decreasingly_to_epilogue_node_order
+            )
+
+            reindexer = ir.fuse_reindexing(stride_reindex, reshape_reindex)  # type: ignore[var-annotated]
+            return reindexer
+
+        if default_reindexers is None:
+            default_reindexers = [None] * len(epilogue_nodes)
+        new_reindexers = [
+            get_reindexer(epilogue_node, default_reindexer)
+            for epilogue_node, default_reindexer in zip(
+                epilogue_nodes, default_reindexers
+            )
+        ]
+        reindexers.extend(new_reindexers)
+        if isinstance(Y, ir.BaseView):
+            storage = ir.StorageBox(Y.unwrap_view())
+        else:
+            assert isinstance(Y, ir.Buffer)
+            storage = ir.StorageBox(Y)
+        Y_2d = ir.ReinterpretView(data=storage, layout=template_buffer.get_layout())
+    return Y_2d, reindexers
+
+
+class CppGemmTemplate(CppTemplate):
+    """
+    GEMM Template for Inductor CPP Backend.
+    """
+
+    def __init__(
+        self,
+        input_nodes,
+        layout: ir.Layout,
+        num_threads: int,
+        register_blocking: GemmBlocking,
+        beta=1,
+        alpha=1,
+        has_bias=False,
+        epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
+        should_block_weights: bool = True,
+        name="packed_gemm",
+    ) -> None:
+        assert layout.dtype in [torch.float, torch.bfloat16, torch.half, torch.uint8]
+        super().__init__(
+            name,
+            input_nodes,
+            layout,
+            num_threads,
+            epilogue_creator=epilogue_creator,
+        )
+        self.beta = beta
+        self.alpha = alpha
+        self.has_bias = has_bias
+        self.register_blocking = register_blocking
+        m, n = layout.size[-2:]
+        k = input_nodes[0].get_size()[-1]
+        self.m, self.n, self.k = m, n, k
+        self.padded_n = get_padded_n(n, self.register_blocking.block_n)
+        self.is_dynamic_M = has_free_symbols((m,))
+        self.should_block_weights = should_block_weights
+        self.thread_blocking = self.make_thread_blocking_cache()
+        self.cache_blocking = self.make_cache_blocking_cache()
+
+    def make_thread_blocking_cache(self):
+        cache = lru_cache()(self._thread_blocking)
+
+        def thread_blocking(num_threads: int) -> GemmBlocking:
+            return cache(num_threads)
+
+        return thread_blocking
+
+    def _thread_blocking(self, num_threads: int) -> GemmBlocking:
+        """
+        NOTE [Thread blocking in Cpp GEMM]
+        We use simple heuristics to decide the thread blocking:
+        1. Make sure all threads are occupied as much as possible.
+        2. For (m, n) blocks, favor more square-sized thread blocks for better data reuse.
+        3. If (m, n) blocks cannot occupy all the threads, we consider k-slicing.
+        TODO(jgong5): allow tuning various blocking options
+        """
+
+        def get_factors(number):
+            factors = []
+            for i in range(int(number**0.5), 0, -1):
+                if number % i == 0:
+                    factors.append(number // i)
+                    factors.append(i)
+            return factors
+
+        def get_blocking(m_factor, n_factor, k_factor, m_blocks, n_blocks, k_blocks):
+            thread_block_k = math.ceil(k_blocks / k_factor)
+            thread_block_n = math.ceil(n_blocks / n_factor)
+            thread_block_m = math.ceil(m_blocks / m_factor)
+            return GemmBlocking(thread_block_m, thread_block_n, thread_block_k)
+
+        assert not self.is_dynamic_M, (
+            "Unable to determine thread blocking for dynamic M."
+        )
+        register_blocking = self.register_blocking
+        m_blocks = math.ceil(self.m / register_blocking.block_m)
+        n_blocks = math.ceil(self.n / register_blocking.block_n)
+        k_blocks = math.ceil(self.k / register_blocking.block_k)
+        factors = get_factors(num_threads)
+        assert len(factors) > 0
+
+        if config.cpp.gemm_thread_factors is not None:
+            factors = [int(i) for i in config.cpp.gemm_thread_factors.split(",")]
+            assert len(factors) == 3
+            assert math.prod(factors) == self.num_threads
+            return get_blocking(
+                factors[0], factors[1], factors[2], m_blocks, n_blocks, k_blocks
+            )
+
+        # we favor square-sized thread blocks for good data reuse
+        def get_better_blocking(blocking, best_blocking):
+            if best_blocking is None:
+                best_blocking = blocking
+            else:
+                block_m_size = blocking.block_m * register_blocking.block_m
+                block_n_size = blocking.block_n * register_blocking.block_n
+                best_block_m_size = best_blocking.block_m * register_blocking.block_m
+                best_block_n_size = best_blocking.block_n * register_blocking.block_n
+                if blocking.block_k > best_blocking.block_k:
+                    best_blocking = blocking
+                elif (
+                    blocking.block_k == best_blocking.block_k
+                    and block_m_size + block_n_size
+                    < best_block_m_size + best_block_n_size
+                ):
+                    best_blocking = blocking
+            return best_blocking
+
+        best_blocking = None
+        # check if we can have a thread-blocking to occupy all threads without k-slicing
+        for n_factor in factors:
+            m_factor = num_threads // n_factor
+            if n_blocks >= n_factor and m_blocks >= m_factor:
+                blocking = get_blocking(
+                    m_factor, n_factor, 1, m_blocks, n_blocks, k_blocks
+                )
+                best_blocking = get_better_blocking(blocking, best_blocking)
+
+        if best_blocking is None:
+            for k_factor in factors:
+                if k_blocks >= k_factor and (
+                    config.cpp.gemm_max_k_slices == 0
+                    or k_factor <= config.cpp.gemm_max_k_slices
+                ):
+                    n_factors = get_factors(num_threads // k_factor)
+                    for n_factor in n_factors:
+                        m_factor = (num_threads // k_factor) // n_factor
+                        if n_blocks >= n_factor and m_blocks >= m_factor:
+                            blocking = get_blocking(
+                                m_factor,
+                                n_factor,
+                                k_factor,
+                                m_blocks,
+                                n_blocks,
+                                k_blocks,
+                            )
+                            best_blocking = get_better_blocking(blocking, best_blocking)
+
+        if best_blocking is None:
+            for n_factor in factors:
+                m_factor = num_threads // n_factor
+                if n_blocks >= n_factor or m_blocks >= m_factor:
+                    blocking = get_blocking(
+                        m_factor, n_factor, 1, m_blocks, n_blocks, k_blocks
+                    )
+                    best_blocking = get_better_blocking(blocking, best_blocking)
+
+        assert best_blocking is not None
+        return best_blocking
+
+    def make_cache_blocking_cache(self):
+        cache = lru_cache()(self._cache_blocking)
+
+        def cache_blocking(num_threads: int) -> GemmBlocking:
+            return cache(num_threads)
+
+        return cache_blocking
+
+    def _cache_blocking(self, num_threads: int) -> GemmBlocking:
+        def get_cache_blocking(register_blocking, thread_blocking):
+            Mr = register_blocking.block_m
+            Nr = register_blocking.block_n
+            Kr = register_blocking.block_k
+
+            Mt_blocks = thread_blocking.block_m
+            Nt_blocks = thread_blocking.block_n
+            Kt_blocks = thread_blocking.block_k
+
+            if config.cpp.gemm_cache_blocking is not None:
+                blockings = [int(i) for i in config.cpp.gemm_cache_blocking.split(",")]
+                assert len(blockings) == 3
+                Mc_blocks, Nc_blocks, Kc_blocks = blockings
+                return (
+                    min(Mc_blocks, Mt_blocks),
+                    min(Nc_blocks, Nt_blocks),
+                    min(Kc_blocks, Kt_blocks),
+                )
+
+            # The ratios below are empirically determined to decide
+            # the effective sizes of L1 and L2.
+            # TODO: tune the factor here
+            L1_limit_factor = 0.8
+            L2_limit_factor = 0.5
+
+            L1_cache_size = (
+                torch._C._cpu._L1d_cache_size()
+            )  # per core cache size in Bytes
+            assert L1_cache_size > 0, (
+                f"Expect L1_cache_size > 0 but got {L1_cache_size}"
+            )
+            L1 = L1_cache_size * L1_limit_factor
+
+            L2_cache_size = (
+                torch._C._cpu._L2_cache_size()
+            )  # per core cache size in Bytes
+            assert L2_cache_size > 0, (
+                f"Expect L2_cache_size > 0 but got {L2_cache_size}"
+            )
+            L2 = L2_cache_size * L2_limit_factor
+
+            def get_num_byte(dtype):
+                return torch.tensor([], dtype=dtype).element_size()
+
+            dtype_A = self.input_nodes[0].get_dtype()
+            dtype_B = self.input_nodes[1].get_dtype()
+            num_byte_A = get_num_byte(dtype_A)
+            num_byte_B = get_num_byte(dtype_B)
+            if dtype_A is torch.bfloat16 and dtype_B is torch.int8 and Kr != 1:
+                # We will cache dequantized weights (BF16) in L1D for AMX micro-kernel.
+                # In this case, the choice of the micro-kernel being used can't be decoupled from
+                # the cache blocking.
+                # TODO: Decouple the choice of micro-kernel from cache blocking
+                num_byte_B *= num_byte_A
+
+            # NOTE [CPP GEMM Cache Blocking Algorithm]
+            # Our overall strategy is to
+            # 1) Make cache blocks of B L1-reside and reused by multiple rows of A, i.e. Mc.
+            #    Here, B is Kc x Nr where Nr is a single register block. We use L1 size to
+            #    decide Kc. We want to make Mc large enough to better reuse B.
+            # 2) Make cache blocks of A L2-reside, which would limit Mc. We want to reuse A
+            #    along N, where we have two sub-strategies (see notes below) to decide Mc and Nc.
+
+            # Step 1: Decide Kc assuming B block is L1-reside.
+            size_cache_B = Kr * Kt_blocks * Nr * num_byte_B
+
+            Kc_blocks = Kt_blocks
+            if size_cache_B > L1:
+                Kc_blocks = math.floor(L1 / (Kr * Nr * num_byte_B))
+
+            if (
+                config.cpp.use_small_dequant_buffer
+                and dtype_A is torch.bfloat16
+                and Mt_blocks == 1
+            ):
+                if dtype_B is torch.uint8:
+                    # A16W4
+                    # Make a small dequant_B buffer for woq int4 [q_group_size, Nr]
+                    # Since when Mt_blocks == 1, L1-reside B block can't be reused by A.
+                    if Kc_blocks * Kr >= self.q_group_size():
+                        Kc_blocks = self.q_group_size() // Kr
+
+                elif dtype_B is torch.int8:
+                    # A16W8
+                    # Make A, B, C buffer in L1
+                    A_buf_size_div_K = self.m * num_byte_A
+                    B_buf_size_div_K = Nr * num_byte_B
+                    # assume acc in float32/int32 and Mc_blocks = Nc_blocks = 1
+                    C_buf_size = Mr * Nr * 4
+                    K_block_size = (L1 - C_buf_size) // (
+                        A_buf_size_div_K + B_buf_size_div_K
+                    )
+                    if Kc_blocks * Kr >= K_block_size:
+                        Kc_blocks = (K_block_size + Kr - 1) // Kr
+
+            # Step 2: Decide Mc assuming A block is L2-reside.
+            min_Mc_ratio = 2  # TODO(jgong5): something to tune?
+            min_Mc_blocks = math.ceil(min_Mc_ratio * Mr / Nr)
+            assert min_Mc_blocks >= 1
+            Kt_bytes = Kt_blocks * Kr * num_byte_A
+            if min_Mc_blocks * Mr * Kt_bytes < L2:
+                # Strategy 1: A (Mc x Kt) resides in L2 and reused by all Nt
+                # when Nc_blocks is kept 1. Mc should be large enough (>= min_Mc_blocks)
+                # to reuse B (Kc x Nr) in L1. This makes C (Mc x Nr) small enough to reside
+                # in L1.
+                Mc_blocks = min(Mt_blocks, math.floor(L2 / (Mr * Kt_bytes)))
+                Nc_blocks = 1
+            else:
+                # Strategy 2: Kt is too large to hold A (Mc x Kt) in L2, we reuse
+                # A (Mc x Kc) in L2 by B (Kc x Nc). C (Mc x Nc) resides in L2.
+                Mc_blocks = Mt_blocks
+                Nc_blocks = min(math.ceil(Mc_blocks * Mr / Nr), Nt_blocks)
+                Nc_bytes = Nc_blocks * Nr * 4  # assume C or acc is float32/int32
+                Kc_bytes = Kc_blocks * Kr * num_byte_A
+                if Mc_blocks * Mr * (Kc_bytes + Nc_bytes) > L2:
+                    # The following is the solution for 4*Mc*Nc + Mc*Kc_bytes = L2,
+                    # assuming Mc == Nc for good data reuse.
+                    M_max = (math.sqrt(Kc_bytes * Kc_bytes + 16 * L2) - Kc_bytes) / 8
+                    if M_max < Mc_blocks * Mr:
+                        Mc_blocks = math.floor(M_max / Mr)
+                        Nc_blocks = min(math.ceil(Mc_blocks * Mr / Nr), Nt_blocks)
+
+            return Mc_blocks, Nc_blocks, Kc_blocks
+
+        assert not self.is_dynamic_M, (
+            "Unable to determine cache blocking for dynamic M."
+        )
+        register_blocking = self.register_blocking
+        thread_blocking = self.thread_blocking(num_threads)
+
+        return GemmBlocking(*get_cache_blocking(register_blocking, thread_blocking))
+
+    def log_blockings(self):
+        log.debug(f"Register blocking: {self.register_blocking}")  # noqa: G004
+        if self.is_dynamic_M:
+            # thread and cache blockings are determined at runtime for dynamic shapes
+            return
+        log.debug(
+            f"Cache blocking: {self.cache_blocking(self.num_threads)}"  # noqa: G004
+        )
+        thread_blocking = self.thread_blocking(self.num_threads)
+        log.debug(f"Thread blocking: {thread_blocking}")  # noqa: G004
+
+        def get_occupancy():
+            m_blocks = math.ceil(self.m / self.register_blocking.block_m)
+            n_blocks = math.ceil(self.n / self.register_blocking.block_n)
+            k_blocks = math.ceil(self.k / self.register_blocking.block_k)
+            m = math.ceil(m_blocks / thread_blocking.block_m)
+            n = math.ceil(n_blocks / thread_blocking.block_n)
+            k = math.ceil(k_blocks / thread_blocking.block_k)
+            return (m, n, k)
+
+        log.debug(
+            f"Number of threads: {self.num_threads}, occupancy: {get_occupancy()}"  # noqa: G004
+        )
+
+    def maybe_k_slicing(self):
+        if self.num_threads == 1:
+            return False
+        if self.is_dynamic_M:
+            # TODO(jgong5): perhaps use size hint to decide?
+            return True
+        register_blocking = self.register_blocking
+        k_blocks = math.ceil(self.k / register_blocking.block_k)
+        thread_blocking = self.thread_blocking(self.num_threads)
+        return k_blocks > thread_blocking.block_k
+
+    @classmethod
+    def add_choices(
+        cls,
+        choices,
+        layout,
+        input_nodes,
+        beta=1,
+        alpha=1,
+        has_bias=False,
+        trans_w=False,
+        input_indices=None,
+        epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
+        act_mapping: Optional[dict[int, ir.IRNode]] = None,
+    ):
+        """
+        Add choices for the GEMM template.
+        """
+        # Fast path to save the epilogue calculation when x_scale/x_zp/w_scale are constant
+        use_int8_fast_compensation_path = _is_int8_gemm(input_nodes) and all(
+            (
+                isinstance(input_nodes[idx], ir.TensorBox)
+                and isinstance(input_nodes[idx].data.data, ir.ConstantBuffer)
+            )
+            for idx in [1, 2, 4]
+        )
+
+        if input_indices is None:
+            input_indices = list(range(len(input_nodes)))
+
+        def reorder_and_filter(inputs, layout_or_out):
+            if has_bias:
+                assert len(input_indices) >= 3
+                # Assume the input order is [inp, x, w] and we reorder it to [x, w, inp]
+                inp_idx = input_indices[0]
+                x_idx = input_indices[1]
+                w_idx = input_indices[2]
+                return [
+                    inputs[x_idx],
+                    inputs[w_idx],
+                    inputs[inp_idx],
+                    *[inputs[idx] for idx in input_indices[3:]],
+                ], layout_or_out
+            elif len(inputs) >= len(input_indices):
+                assert len(input_indices) >= 2
+                return [inputs[idx] for idx in input_indices], layout_or_out
+            else:
+                # For when input is used for x and w, i.e. X@X.T or similar
+                # Assumes the first input is the only input
+                assert len(inputs) == 1
+                return [inputs[0]] * len(input_indices), layout_or_out
+
+        new_inputs, new_layout = reorder_and_filter(input_nodes, layout)
+        is_mkldnn_wgt = (
+            new_inputs[1].get_name() in V.graph.constants
+            and V.graph.constants[new_inputs[1].get_name()].is_mkldnn
+        )
+        if is_mkldnn_wgt:
+            # It shouldn't happen as viewing an mkldnn tensor, we can extend the
+            # implementation if it does.
+            assert not isinstance(new_inputs[1], ir.BaseView)
+        # Note that the layout of MKLDNN Tensor is with the wrong stride
+        view_size = new_inputs[1].layout.size
+        view_stride = new_inputs[1].layout.stride
+        view_offset = new_inputs[1].layout.offset
+
+        def maybe_to_dense(inputs, layout_or_out):
+            new_inputs = list(inputs)
+            if isinstance(inputs[1], torch.Tensor):
+                W = inputs[1]
+                new_inputs[1] = W.to_dense() if W.is_mkldnn else W
+            return new_inputs, layout_or_out
+
+        def normalize_shapes(inputs, layout_or_out):
+            new_inputs = list(inputs)
+            if not is_mkldnn_wgt and isinstance(new_inputs[1], torch.Tensor):
+                if has_free_symbols(view_size):
+                    # If batch size B is dynamic, we need to set the batch size and possibly stride
+                    assert not has_free_symbols(view_size[1:])
+                    view_size[:] = V.graph.sizevars.size_hints(view_size)
+                    view_stride[:] = V.graph.sizevars.size_hints(view_stride)
+                # With the assumptation that W is the storage of unwrap view
+                # thus view it back here
+                new_inputs[1] = new_inputs[1].as_strided(
+                    view_size, view_stride, view_offset
+                )
+
+            if not trans_w:
+                return new_inputs, layout_or_out
+            X = new_inputs[0]
+            W = new_inputs[1]
+            B = new_inputs[2] if has_bias else None
+            W = transpose_w(W, trans_w)
+            B = expand_bias(B, X)  # type:ignore[arg-type]
+            new_inputs[1] = W
+            if B is not None:
+                new_inputs[2] = B
+            return new_inputs, layout_or_out
+
+        # TODO(jgong5): decide proper number of threads per problem size
+        num_threads = parallel_num_threads()
+        new_inputs, _ = normalize_shapes(*maybe_to_dense(new_inputs, new_layout))
+        m, n, k, *_ = mm_args(
+            new_inputs[0],
+            new_inputs[1],
+            mat2_transposed=cls.is_woq_int4(),
+            use_4x2_dim=cls.is_woq_int4(),
+        )
+        output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype(
+            new_inputs[0].get_dtype()
+        )
+        micro_gemm = create_micro_gemm(
+            "micro_gemm",
+            m,
+            n,
+            k,
+            input_dtype=new_inputs[0].get_dtype(),
+            input2_dtype=new_inputs[1].get_dtype(),
+            output_dtype=output_dtype,
+            compute_dtype=compute_dtype,
+            alpha=alpha,
+            num_threads=num_threads,
+            use_ref=not cls.is_woq_int4(),
+            q_group_size=cls.q_group_size(),
+        )
+        assert micro_gemm is not None
+        pre_block_weights = cls.check_if_block_weight(new_inputs[1], micro_gemm)
+        micro_gemm.use_local_vnni_blocking(not pre_block_weights)
+        only_one_input = (
+            input_nodes[0] == input_nodes[1] if len(input_nodes) > 1 else False
+        ) and not pre_block_weights  # If weights are blocked, use the second input
+
+        def preprocessor(inputs, layout):
+            new_inputs, new_layout = normalize_shapes(
+                *maybe_to_dense(*reorder_and_filter(inputs, layout))
+            )
+            if only_one_input and isinstance(new_inputs[0], torch.Tensor):
+                return new_inputs[1:], new_layout
+            return cls.prep_weight(
+                new_inputs,
+                new_layout,
+                # pyrefly: ignore [bad-argument-type]
+                micro_gemm,
+                pre_block_weights,
+                use_int8_fast_compensation_path,
+            )
+
+        def postprocessor(output):
+            if isinstance(output, ir.TensorBox):
+                # prepack the weight as input to the template buffer
+                template_buffer = ir.InputsKernel.unwrap_storage_for_input(output)
+                assert isinstance(template_buffer, ir.CppTemplateBuffer)
+                new_input_nodes, _ = reorder_and_filter(input_nodes, layout)
+
+                W_node = new_input_nodes[1]
+                if W_node.get_name() not in V.graph.constants:
+                    return output
+                W = V.graph.constants[W_node.get_name()]
+                new_input_nodes[1] = W
+                new_input_nodes, new_layout = normalize_shapes(
+                    *maybe_to_dense(new_input_nodes, layout)
+                )
+                new_input_nodes, _ = cls.prep_weight(
+                    new_input_nodes,
+                    new_layout,
+                    # pyrefly: ignore [bad-argument-type]
+                    micro_gemm,
+                    pre_block_weights,
+                    use_int8_fast_compensation_path,
+                    skip_int8_compensation=True,
+                )
+                W_packed = new_input_nodes[1]
+                W_packed_constant = V.graph.add_tensor_constant(W_packed)
+                new_input_nodes[1] = W_packed_constant
+
+                # Prune unused tensors
+                prune_tensors(input_nodes, new_input_nodes)
+
+                template_buffer.inputs[1] = ir.InputsKernel.unwrap_storage_for_input(
+                    W_packed_constant
+                )
+            return output
+
+        template = DataProcessorTemplateWrapper(
+            cls,
+            preprocessor,
+            postprocessor,
+            input_nodes=input_nodes,
+            layout=layout,
+            num_threads=num_threads,
+            register_blocking=micro_gemm.register_blocking,
+            beta=beta,
+            alpha=alpha,
+            has_bias=has_bias,
+            epilogue_creator=epilogue_creator,
+            should_block_weights=pre_block_weights,
+            name=micro_gemm.__class__.__name__,
+        )
+        template.maybe_append_choice(choices)
+        return template
+
+    @staticmethod
+    def get_padded_size(n, block_n, k, should_block_weight):
+        padded_n = get_padded_n(n, block_n)
+        # We assume that all GEMM weight tensors should be blocked and padded
+        new_size = [padded_n // block_n, k, block_n]
+        return new_size, padded_n
+
+    @staticmethod
+    def _maybe_remove_storage_offset(node: ir.IRNode):
+        if node.get_layout().offset == 0:
+            return node
+        # node may be contiguous but still have a non-zero storage offset.
+        # GEMM_TEMPLATE emits code like:
+        #   W.data_ptr[node.offset + ...]
+        # but runtime W.data_ptr (after normalize_shapes()) already includes this offset.
+        # To avoid double-offsetting, we remove the offset in the node also in the generated code.
+        #   W.data_ptr[...]
+        return ir.ExternKernel.copy_input(node)
+
+    @classmethod
+    def prep_weight(
+        cls,
+        inputs,
+        layout: ir.Layout,
+        micro_gemm: CppMicroGemm,
+        should_block_weight: bool,
+        use_int8_fast_compensation_path: bool = False,
+        skip_int8_compensation: bool = False,
+    ):
+        """
+        NOTE Weight prep consists of 2 separate steps:
+        1. Blocking the weight tensor into a 3D shape: [n//block_n, k, block_n]
+           This is always done if the weight tensor is constant, i.e. for all GEMM and some BMM.
+           For BMM, we also block non-contiguous weight tensors, since they would be reshaped anyway.
+           This assumes that blocked, contiguous weights will be more efficient for the GEMM kernel,
+           and is worth the overhead of reshape and blocking.
+
+           This blocking includes additional padding, when n is not a multiple of block_n.
+           This padding allows a more efficient microkernel implementation. For BMM, this is only done
+           if reshape would happen anyway, i.e.  if the weight tensor is constant, is not contiguous,
+           or is using AMX VNNI layout.
+        2. Packing the weight tensor into a VNNI-friendly shape. For constant input,
+           this is done at the same time as the weight blocking.
+
+        At compile time, the constant weight tensors are blocked and packed. For non-constant tensors (e.g. BMM)
+        which will be blocked (non-contiguous or VNNI-layout tensors), the weight tensor is blocked and packed at runtime.
+
+        CppBmmTemplate overrides the methods get_padded_size, and block_weight in order to accommodate
+        an additional dimension for the batch size and to determine if the weight tensor should be blocked.
+        """
+        W = inputs[1]
+        new_inputs = list(inputs)
+        if cls.is_woq_int4():
+            assert (
+                len(W.get_size()) == 2
+                if isinstance(W, ir.IRNode)
+                else len(W.shape) == 2
+            )
+            n, k = W.get_size() if isinstance(W, ir.IRNode) else W.shape
+        else:
+            k, n = W.get_size()[-2:] if isinstance(W, ir.IRNode) else W.shape[-2:]
+        _, block_n, _ = micro_gemm.register_blocking
+        new_size, padded_n = cls.get_padded_size(n, block_n, k, should_block_weight)
+        padding = padded_n - n
+
+        if should_block_weight and not cls.is_woq_int4():
+            blocked_w = cls.block_weight(W, new_size, padding)
+            new_inputs[1] = cls.pack_vnni_weight(blocked_w, micro_gemm, new_size)
+        elif should_block_weight:
+            assert cls.is_woq_int4()
+            new_inputs[1] = cls.block_weight(W, new_size, padding)
+        elif isinstance(W, ir.IRNode):
+            # Require W layout to be fixed & contiguous, happens inplace.
+            ir.ExternKernel.require_contiguous(W)
+            new_inputs[1] = cls._maybe_remove_storage_offset(W)
+
+        if not skip_int8_compensation and _is_int8_gemm(new_inputs):
+            BCompensate = None
+            x_w_scale = None
+
+            def _get_compensation_node(W, use_int8_fast_compensation_path):
+                BCompensate = V.graph.add_tensor_constant(
+                    V.graph.constants[W.get_name() + "_BMatrixCompens"],
+                    W.get_name() + "_BMatrixCompens",
+                )
+                x_w_scale = None
+                if use_int8_fast_compensation_path:
+                    x_w_scale = V.graph.add_tensor_constant(
+                        V.graph.constants[W.get_name() + "_x_w_compens"],
+                        W.get_name() + "_x_w_compens",
+                    )
+                return BCompensate, x_w_scale
+
+            if use_int8_fast_compensation_path:
+                # new_inputs has been reordered: [x, w, optional[bias], x_scale, x_zp, w_scale, w_zp]
+                x_scale = new_inputs[-4]
+                x_zp = new_inputs[-3]
+                w_scale = new_inputs[-2]
+                if isinstance(W, ir.IRNode):
+                    BCompensate, x_w_scale = _get_compensation_node(
+                        W, use_int8_fast_compensation_path
+                    )
+                else:
+                    # Use the original W, not the blocked_w in new_inputs[1] to calculate BCompensate
+                    BCompensate = torch.sum(W.to_dense().to(torch.float), dim=0)  # type: ignore[assignment]
+                    assert all(
+                        isinstance(item, torch.Tensor)
+                        for item in (x_scale, x_zp, w_scale)
+                    )
+                    BCompensate = BCompensate * x_scale * w_scale * x_zp
+                    x_w_scale = x_scale * w_scale
+                new_inputs.append(BCompensate)
+                new_inputs.append(x_w_scale)
+            else:
+                if isinstance(W, ir.IRNode):
+                    BCompensate, _ = _get_compensation_node(
+                        W, use_int8_fast_compensation_path
+                    )
+                else:
+                    # Use the original W, not the blocked_w in new_inputs[1] to calculate BCompensate
+                    BCompensate = torch.sum(W.to_dense().to(torch.float), dim=0)  # type: ignore[assignment]
+                new_inputs.append(BCompensate)
+        return new_inputs, layout
+
+    @staticmethod
+    def check_if_block_weight(W, micro_gemm):
+        return True
+
+    @classmethod
+    def block_weight(cls, W, new_size, padding):
+        # These are separated into two methods to allow subclasses to override them separately
+        if isinstance(W, ir.IRNode):
+            if W.get_name() in V.graph.constants:
+                # Create a new buffer, representing the constant blocked tensor
+                blocked_w = ir.Buffer(
+                    name=W.get_name(),  # Borrow the registered buffer name
+                    layout=ir.FixedLayout(
+                        W.get_device_or_error(),
+                        W.get_dtype(),
+                        new_size,
+                        ir.FlexibleLayout.contiguous_strides(new_size),
+                        0,
+                    ),
+                )
+            else:
+                if not isinstance(W, ir.TensorBox):
+                    W = ir.TensorBox(W)
+                permute_dims = list(range(len(new_size)))
+                permute_dims[-2], permute_dims[-3] = permute_dims[-3], permute_dims[-2]
+                permute_size = list(new_size)
+                permute_size[-2], permute_size[-3] = permute_size[-3], permute_size[-2]
+                blocked_w = L.constant_pad_nd(W, (0, padding))
+                blocked_w = L.permute(
+                    L.view(blocked_w, permute_size),  # type: ignore[arg-type]
+                    permute_dims,
+                )
+        else:
+            assert isinstance(W, torch.Tensor)
+            # Pad the weight tensor and reshape it into a 3D blocked shape
+            blocked_size = list(new_size)
+            blocked_size[-2], blocked_size[-3] = blocked_size[-3], blocked_size[-2]
+            blocked_w = (
+                torch.nn.functional.pad(W, (0, padding))  # type: ignore[assignment]
+                .reshape(*blocked_size)
+                .transpose(-3, -2)
+                .contiguous()
+            )
+        return blocked_w
+
+    @classmethod
+    def pack_vnni_weight(cls, W, micro_gemm, new_size):
+        # WOQ INT4 weights are reordered in microkernel so do not pack them here
+        should_pack = (
+            micro_gemm.get_b_layout() != LayoutType.NORMAL
+            and not micro_gemm.is_woq_int4()
+        )
+
+        # These are separated into two methods to allow subclasses to override them separately
+        if isinstance(W, ir.IRNode):
+            if isinstance(W, ir.Buffer) and W.get_name() in V.graph.constants:
+                return W
+            k = new_size[-2]
+            if not isinstance(W, ir.TensorBox):
+                W = ir.TensorBox(W)
+            if should_pack:
+                permute_dims = list(range(len(new_size) + 1))
+                permute_dims[-1], permute_dims[-2] = permute_dims[-2], permute_dims[-1]
+                vnni_size = 4 if micro_gemm.get_b_layout() == LayoutType.VNNI4 else 2
+                vnni_view_size = list(new_size)
+                vnni_view_size[-2] = k // vnni_size
+                vnni_view_size.insert(-1, vnni_size)
+                W = L.view(
+                    L.permute(L.view(W, vnni_view_size), permute_dims),
+                    new_size,
+                )
+            W = ir.ExternKernel.realize_input(W)
+            W = ir.ExternKernel.require_contiguous(W)
+            return W
+        else:
+            k = new_size[-2]
+            # Apply VNNI packing to the weight tensor
+            if should_pack:
+                # TODO: Move VNNI weight packing for non-constant tensors into the template,
+                # to improve cache locality and avoid full-tensor copy.
+                layout_str = (
+                    "VNNI4"
+                    if micro_gemm.get_b_layout() == LayoutType.VNNI4
+                    else "VNNI2"
+                )
+                assert micro_gemm.get_b_layout() in [
+                    LayoutType.VNNI2,
+                    LayoutType.VNNI4,
+                ], f"We only support {layout_str} for now"
+                vnni_size = 4 if micro_gemm.get_b_layout() == LayoutType.VNNI4 else 2
+                assert k % vnni_size == 0, (
+                    f"k should be divisible by vnni_size for {layout_str} layout"
+                )
+                vnni_view_size = list(new_size)
+                vnni_view_size[-2] = k // vnni_size
+                vnni_view_size.insert(-1, vnni_size)
+                W = W.view(vnni_view_size).transpose(-1, -2).contiguous().view(new_size)
+            # normalize stride to be "contiguous_strides" per size
+            # this avoids the problems in L.view during template codegen
+            new_stride = [1]
+            for sz in reversed(W.shape[1:]):
+                new_stride.insert(0, new_stride[0] * sz)
+            W = W.as_strided(W.shape, new_stride)
+            return W
+
+    def get_default_reindexers(self, epilogue_nodes):
+        return [None] * len(epilogue_nodes)
+
+    def get_options(
+        self,
+        kernel: CppTemplateKernel,
+        template_buffer_node: Optional[ir.CppTemplateBuffer] = None,
+        flag_template_buffer_has_other_users: Optional[bool] = None,
+        epilogue_nodes: Optional[list[ir.IRNode]] = None,
+    ) -> dict[str, Any]:
+        assert len(self.input_nodes) >= 2
+
+        int8_gemm = self.input_nodes[0].get_dtype() in [torch.uint8, torch.int8]
+        x_scale = None
+        x_zp = None
+        w_scale = None
+        w_zp = None
+        inp = None
+        q_group_size_node = None
+        qscale_and_zeros = None
+        if int8_gemm:
+            X, W = self.input_nodes[0], self.input_nodes[1]
+            bias_idx = 2 if self.has_bias else 1
+            inp = self.input_nodes[bias_idx] if self.has_bias else None
+            x_scale = self.input_nodes[bias_idx + 1]
+            x_zp = self.input_nodes[bias_idx + 2]
+            w_scale = self.input_nodes[bias_idx + 3]
+            w_zp = self.input_nodes[bias_idx + 4]
+            Y = self.output_node
+        elif self.is_woq_int4():
+            X, W = self.input_nodes[0], self.input_nodes[1]
+            Y = self.output_node
+            q_group_size_node = self.input_nodes[2]
+            qscale_and_zeros = self.input_nodes[3]
+        else:
+            X, W = self.input_nodes[0], self.input_nodes[1]
+            Y = self.output_node
+            inp = self.input_nodes[2] if self.has_bias else None
+
+        template_buffer_has_other_users = None
+
+        if template_buffer_node is not None:
+            # Use the updated prepacked weight buffer
+            W = template_buffer_node.inputs[1]
+            Y = template_buffer_node
+
+            assert flag_template_buffer_has_other_users is not None
+            template_buffer_has_other_users = flag_template_buffer_has_other_users
+
+        template_buffer = Y
+        gemm_output_buffer = template_buffer
+
+        epilogues: list[ir.IRNode] = []
+        reindexers: list[Optional[Callable[[list[Any]], list[Any]]]] = []
+        epilogue_creators: list[Callable[[ir.Buffer], ir.Pointwise]] = []
+        fake_buffers: list[ir.Buffer] = []
+        Y_aliases: OrderedSet[str] = OrderedSet()
+
+        use_local_acc = (
+            self.layout.dtype != torch.float
+            or template_buffer_has_other_users
+            or int8_gemm
+            or self.padded_n != self.n
+            or self.maybe_k_slicing()
+            or (epilogue_nodes and epilogue_nodes[-1].get_dtype() != self.layout.dtype)
+        )
+
+        # TODO(jgong5): for int8 gemm, bias-add is handled outside of gemm template,
+        # but we'd better move it here to align with fp.
+        if inp is not None and self.beta != 0 and not int8_gemm:
+            # add an epilogue for bias add
+            def _bias_add_epilogue(buf):
+                return create_epilogue_with_attr(
+                    buf, "bias_add", other=inp, beta=self.beta, dtype=self.layout.dtype
+                )
+
+            epilogue_creators.append(_bias_add_epilogue)
+
+        if self.epilogue_creator is not None:
+            epilogue_creators.append(self.epilogue_creator)
+
+        # When the GEMM output buffer is localized but it has users other than the epilogue nodes,
+        # we need to copy the value in the GEMM output local buffer to a global buffer.
+        def need_copy_from_local_to_global_buffer_epilogue(
+            use_local_acc, template_buffer_has_other_users, epilogue_creators
+        ):
+            # The GEMM output buffer is a global buffer, thus copy is not needed.
+            if not use_local_acc:
+                return False
+
+            # The possible value of template_buffer_has_other_users is (None, False, True)
+            # It is None when generating the gemm template during autotune and it will have value during scheduler codegen.
+            # extra copy_from_local_to_global_buffer_epilogue is not needed in either of the below two cases:
+            #   1. template_buffer_has_other_users is None (i.e. when doing the codegen during autotune)
+            #   2. template_buffer_has_other_users is False, which means it's safe to keep the value in the
+            #       GEMM output buffer in local buffer only (no users outside of the epilogues will use its value).
+            if not template_buffer_has_other_users:
+                return False
+
+            # When bias is not None or self.epilogue_creator is not None,
+            # there will be epilogue_creators after the GEMM.
+            # The GEMM output buffer is localized while
+            # the output buffer of the epilogue_creators is a global buffer.
+            if epilogue_creators:
+                return False
+
+            return True
+
+        if need_copy_from_local_to_global_buffer_epilogue(
+            use_local_acc, template_buffer_has_other_users, epilogue_creators
+        ):
+
+            def copy_from_local_to_global_buffer_epilogue(input_buffer: ir.Buffer):
+                dtype = self.layout.dtype
+                input_loader = input_buffer.make_loader()
+
+                def copy_inner(index):
+                    input = input_loader(index)
+                    result = ops.to_dtype(input, dtype)
+                    return result
+
+                return ir.Pointwise(
+                    device=input_buffer.get_device_or_error(),
+                    dtype=self.layout.dtype,
+                    inner_fn=copy_inner,
+                    ranges=input_buffer.get_size(),
+                )
+
+            epilogue_creators.append(copy_from_local_to_global_buffer_epilogue)
+
+        # NOTE [How CPP GEMM template epilogues are organized]
+        #   gemm_output_buffer
+        #     --> zero or more in-template epilogues (created by `epilogue_creators`) -->
+        #   template_buffer
+        #     --> zero or more out-of-template epilogues (`epilogue_nodes`) -->
+        #   Y
+        if epilogue_creators:
+            assert isinstance(template_buffer, ir.IRNode)
+            gemm_output_name = f"{template_buffer.get_name()}_GemmOut"
+            gemm_output_buffer = ir.Buffer(
+                name=gemm_output_name,
+                # pyrefly: ignore [missing-attribute]
+                layout=template_buffer.layout,
+            )
+            current_input_buffer = gemm_output_buffer
+            for i, creator in enumerate(epilogue_creators):
+                if i == len(epilogue_creators) - 1:
+                    buffer_name = template_buffer.get_name()
+                else:
+                    buffer_name = f"{gemm_output_name}_epilogue_{i}"
+                epilogues.append(
+                    ir.ComputedBuffer(
+                        name=buffer_name,
+                        # pyrefly: ignore [missing-attribute]
+                        layout=template_buffer.layout,
+                        data=creator(current_input_buffer),
+                    )
+                )
+                fake_buffers.append(current_input_buffer)
+                Y_aliases.add(current_input_buffer.get_name())
+                reindexers.append(None)
+                if i < len(epilogue_creators) - 1:
+                    current_input_buffer = ir.Buffer(
+                        name=buffer_name,
+                        # pyrefly: ignore [missing-attribute]
+                        layout=template_buffer.layout,
+                    )
+
+        assert isinstance(Y, (ir.Buffer, ir.ReinterpretView))
+        Y_2d: Union[ir.Buffer, ir.ReinterpretView] = Y
+
+        if epilogue_nodes:
+            if not template_buffer_has_other_users:
+                assert isinstance(template_buffer, ir.IRNode)
+                Y_aliases.add(template_buffer.get_name())
+            epilogues.extend(epilogue_nodes)
+            assert Y.get_numel() == epilogues[-1].get_numel()
+            Y = cast(ir.Buffer, epilogues[-1])
+            assert isinstance(template_buffer, ir.Buffer)
+            Y_2d, reindexers = gen_2d_view_of_epilogue_buf(
+                Y,
+                template_buffer,
+                epilogue_nodes,
+                reindexers,
+                default_reindexers=self.get_default_reindexers(epilogue_nodes),
+            )
+
+        output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype(
+            X.get_dtype()
+        )
+        micro_gemm = create_micro_gemm(
+            f"{kernel.kernel_name}_micro_gemm",
+            self.m,
+            self.n,
+            self.k,
+            input_dtype=X.get_dtype(),
+            # pyrefly: ignore [missing-attribute]
+            input2_dtype=W.get_dtype(),
+            output_dtype=output_dtype,
+            compute_dtype=compute_dtype,
+            alpha=self.alpha,
+            num_threads=self.num_threads,
+            use_ref=not self.is_woq_int4(),
+            q_group_size=self.q_group_size(),
+        )
+        assert micro_gemm is not None
+        micro_gemm.use_local_vnni_blocking(not self.should_block_weights)
+        assert self.register_blocking == micro_gemm.register_blocking
+        self.log_blockings()
+        if isinstance(micro_gemm, CppMicroGemmAMX):
+            counters["inductor"]["cpp_micro_gemm_amx_counter"] += 1
+        if isinstance(micro_gemm, CppMicroBrgemm):
+            counters["inductor"]["cpp_micro_brgemm_counter"] += 1
+
+        L1_cache_size = torch._C._cpu._L1d_cache_size()  # per core cache size in Bytes
+        assert L1_cache_size > 0, f"Expect L1_cache_size > 0 but got {L1_cache_size}"
+
+        L2_cache_size = torch._C._cpu._L2_cache_size()  # per core cache size in Bytes
+        assert L2_cache_size > 0, f"Expect L2_cache_size > 0 but got {L2_cache_size}"
+
+        options = dict(
+            X=X,
+            W=W,
+            inp=inp,
+            Y=Y,
+            N=self.n,
+            K=self.k,
+            PADDED_N=self.padded_n,
+            GemmOut=gemm_output_buffer,
+            aliases={alias: Y.get_name() for alias in Y_aliases},
+            beta=self.beta,
+            alpha=self.alpha,
+            num_threads=self.num_threads,
+            micro_gemm=micro_gemm,
+            is_dynamic_M=self.is_dynamic_M,
+            template=self,
+            kernel=kernel,
+            export_declaration=get_export_declaration(),
+            epilogue_nodes=epilogues,
+            reindexers=reindexers,
+            Y_2d=Y_2d,
+            use_local_acc=use_local_acc,
+            maybe_k_slicing=self.maybe_k_slicing(),
+            x_scale=x_scale,
+            x_zp=x_zp,
+            w_scale=w_scale,
+            w_zp=w_zp,
+            acc_buf_dtype=torch.int32 if int8_gemm else torch.float,
+            DTYPE_TO_CPP=DTYPE_TO_CPP,
+            L1_cache_size=L1_cache_size,
+            L2_cache_size=L2_cache_size,
+            config=config,
+            fake_buffers=fake_buffers,
+            is_woq_int4=self.is_woq_int4(),
+            q_group_size=q_group_size_node,
+            qscale_and_zeros=qscale_and_zeros,
+        )
+        return options
+
+    def is_int8_woq_gemm_small_m_dim(
+        self,
+        X: ir.ReinterpretView,
+        W: ir.ReinterpretView,
+        N,
+        K,
+        micro_gemm,
+    ):
+        """Use SMALL_M_GEMM_TEMPLATE"""
+        return (
+            isinstance(micro_gemm, CppMicroGemmFP32Vec)
+            and is_int8_woq_gemm_small_m_dim_corner_case(
+                micro_gemm, X.get_size()[0], N, K
+            )
+            and X.get_dtype() is torch.bfloat16
+            and W.get_dtype() is torch.int8
+        )
+
+    def render(  # type: ignore[override, return]
+        self,
+        kernel: CppTemplateKernel,
+        template_buffer_node: Optional[ir.CppTemplateBuffer] = None,
+        flag_template_buffer_has_other_users: Optional[bool] = None,
+        epilogue_nodes: Optional[list[ir.IRNode]] = None,
+        **kwargs,
+    ) -> str:
+        options = self.get_options(
+            kernel=kernel,
+            template_buffer_node=template_buffer_node,
+            flag_template_buffer_has_other_users=flag_template_buffer_has_other_users,
+            epilogue_nodes=epilogue_nodes,
+        )
+        self.render_options = options
+
+        with contextlib.ExitStack() as stack:
+            for buf in options["fake_buffers"]:
+                stack.enter_context(
+                    patch.object(V.graph, "get_dtype", self._fake_get_dtype(buf))
+                )
+            if not options["is_dynamic_M"] and self.is_int8_woq_gemm_small_m_dim(
+                options["X"],
+                options["W"],
+                options["N"],
+                options["K"],
+                options["micro_gemm"],
+            ):
+                template_str = SMALL_M_GEMM_TEMPLATE
+            else:
+                template_str = GEMM_TEMPLATE
+            return self._template_from_string(template_str).render(**options)
+
+    def codegen_blocks(
+        self,
+        num_threads,
+        N,
+        K,
+        micro_gemm,
+        is_dynamic_M,
+        kernel,
+        GemmOut,
+        config,
+        L1_cache_size,
+        L2_cache_size,
+        X,
+        W,
+    ):
+        options = dict(
+            num_threads=num_threads,
+            N=N,
+            K=K,
+            micro_gemm=micro_gemm,
+            is_dynamic_M=is_dynamic_M,
+            kernel=kernel,
+            GemmOut=GemmOut,
+            config=config,
+            L1_cache_size=L1_cache_size,
+            L2_cache_size=L2_cache_size,
+            template=self,
+            X=X,
+            W=W,
+            is_woq_int4=self.is_woq_int4(),
+        )
+        template_str = GEMM_TEMPLATE_INIT_BLOCKING_BASIC_BLOCK
+        if not (
+            not is_dynamic_M
+            and self.is_int8_woq_gemm_small_m_dim(X, W, N, K, micro_gemm)
+        ):
+            template_str += GEMM_TEMPLATE_INIT_BLOCKING_EXTENDED
+        return self._template_from_string(template_str).render(options)
+
+    def codegen_microkernel_def(self):
+        return self._template_from_string(GEMM_TEMPLATE_MICROKERNEL_DEF).render(
+            self.render_options
+        )
+
+    def codegen_gemm_stub_def(self):
+        microkernel = self.codegen_microkernel_def()
+        return microkernel + self._template_from_string(GEMM_TEMPLATE_STUB_DEF).render(
+            self.render_options
+        )
+
+    def codegen_multi_threads_params(self):
+        return self._template_from_string(GEMM_TEMPLATE_MULTI_THREADS_PARAMS).render()
+
+    def codegen_single_thread_params(self, is_dynamic_M):
+        options = dict(
+            is_dynamic_M=is_dynamic_M,
+        )
+        return self._template_from_string(GEMM_TEMPLATE_SINGLE_THREAD_PARAMS).render(
+            options
+        )
+
+    def codegen_m_loop_params(self):
+        return self._template_from_string(GEMM_TEMPLATE_M_LOOP_PARAMS).render()
+
+    def codegen_n_loop_params(self):
+        return self._template_from_string(GEMM_TEMPLATE_N_LOOP_PARAMS).render()
+
+    @classmethod
+    def is_woq_int4(cls):
+        return False
+
+    @classmethod
+    def q_group_size(cls):
+        return None
+
+
+class CppWoqInt4GemmTemplateMeta(type):
+    def __getitem__(cls, q_group_size):
+        class CppWoqInt4GemmTemplateInstance(CppGemmTemplate):
+            def __init__(
+                self,
+                *args,
+                **kwargs,
+            ) -> None:
+                super().__init__(
+                    *args,
+                    **kwargs,
+                )
+
+            @classmethod
+            def is_woq_int4(cls):
+                return True
+
+            @classmethod
+            def q_group_size(cls):
+                return q_group_size
+
+            @staticmethod
+            def check_if_block_weight(W, micro_gemm):
+                # For WOQ INT4, weight is already packed
+                # However, for AMX microkernel, we want to change the blocking of weight
+                from .cpp_micro_gemm import CppMicroGemmWoQInt4Amx
+
+                return isinstance(micro_gemm, CppMicroGemmWoQInt4Amx)
+
+            @classmethod
+            def block_weight(cls, W, new_size, padding):
+                # This method is called only if AMX microkernels are used.
+                # In this case, we unpack and repack weight so that block_n=32
+                # the format of packed weight is described here:
+                # https://github.com/pytorch/pytorch/blob/32eee8ed225d9f10fbbcb38c24b8b44c24c0c97c/aten/src/ATen/native/cpu/int4mm_kernel.cpp#L583
+                if isinstance(W, ir.IRNode):
+                    # in this case, we do nothing
+                    ir.ExternKernel.require_contiguous(W)
+                    blocked_w = W
+                else:
+                    # in this case, we unpack and repack weight
+                    assert isinstance(W, torch.Tensor)
+                    assert W.dim() == 2
+                    N = W.size(0)
+                    K = W.size(-1) * 2
+                    G = cls.q_group_size()
+                    # x and qscales_and_zeros are in bfloat16 instead of float to use the optimized kernel
+                    # so that the unpacking process is faster
+                    x = torch.eye(K).bfloat16()
+                    # Here we use scale=1 and qzero=8 because we want to unpack weight
+                    # without dequantizing it. The qzero here is 8 instead of 0 because
+                    # int4 values are converted to [-7, 8] in the _weight_int4pack_mm_for_cpu kernel:
+                    # https://github.com/pytorch/pytorch/blob/32eee8ed225d9f10fbbcb38c24b8b44c24c0c97c/aten/src/ATen/native/cpu/int4mm_kernel.cpp#L95
+                    qscales_and_zeros = (
+                        torch.tensor([1.0, 8.0])
+                        .bfloat16()
+                        .expand(K // G, N, 2)
+                        .contiguous()
+                    )
+                    # shape: [K, N]
+                    unpacked_w = torch.ops.aten._weight_int4pack_mm_for_cpu(
+                        x,
+                        W,
+                        G,
+                        qscales_and_zeros,
+                    ).to(torch.uint8)
+                    block_n = 32
+                    # shape: [N // block_n, K, block_n]
+                    w_blocked = (
+                        unpacked_w.view(K, N // block_n, block_n)
+                        .permute(1, 0, 2)
+                        .contiguous()
+                    )
+                    # pack 2 int4 -> 1 int8
+                    # block_n: [a0, a1, ..., a15, b0, b1, ..., b15]
+                    # -> [(a0 & 0xf) | (b0 << 4), (a1 & 0xf) | (b1 << 4), ...]
+                    # shape: [N // block_n, K, 2, block_n // 2]
+                    w_blocked = w_blocked.view(N // block_n, K, 2, block_n // 2)
+                    # shape: [N // block_n, K, block_n // 2]
+                    w_blocked_packed = (w_blocked[:, :, 0, :] & 0xF) | (
+                        w_blocked[:, :, 1, :] << 4
+                    )
+                    # shape: [N, K // 2]
+                    blocked_w = w_blocked_packed.view(N, K // 2)
+
+                return blocked_w
+
+        return CppWoqInt4GemmTemplateInstance
+
+
+class CppWoqInt4GemmTemplate(metaclass=CppWoqInt4GemmTemplateMeta):
+    pass
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_grouped_gemm_template.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_grouped_gemm_template.py
new file mode 100644
index 0000000000000000000000000000000000000000..abea505b2d069a26c2d1ed181e217a88fb61d0d4
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_grouped_gemm_template.py
@@ -0,0 +1,511 @@
+import contextlib
+import logging
+from collections.abc import Callable
+from typing import Any, cast, Optional, TypeVar
+from unittest.mock import patch
+
+import torch
+import torch.utils
+from torch.utils._ordered_set import OrderedSet
+
+from ..._dynamo.utils import counters
+from .. import config, ir
+from ..kernel.mm_common import mm_args
+from ..select_algorithm import ChoiceCaller, DataProcessorTemplateWrapper
+from ..utils import parallel_num_threads
+from ..virtualized import V
+from .cpp import get_export_declaration
+from .cpp_gemm_template import (
+    CppGemmTemplate,
+    expand_bias,
+    gen_2d_view_of_epilogue_buf,
+    prune_tensors,
+    transpose_w,
+)
+from .cpp_micro_gemm import CppMicroGemmAMX, create_micro_gemm
+from .cpp_template_kernel import CppTemplateKernel
+from .cpp_utils import (
+    create_epilogue_with_attr,
+    DTYPE_TO_CPP,
+    GemmBlocking,
+    get_gemm_template_output_and_compute_dtype,
+)
+
+
+log = logging.getLogger(__name__)
+
+GEMM_TEMPLATE = r"""
+{{template.header().getvalue()}}
+{{micro_gemm.codegen_define(kernel)}}
+
+extern "C" {{export_declaration}}
+{{kernel.def_kernel(inputs=kernel_args, outputs=Y_list, aliases=aliases)}}
+{
+    {{kernel.maybe_codegen_profile()}}
+    {{ template.codegen_blocks(
+        num_threads, N, K, micro_gemm, is_dynamic_M, kernel, GemmOuts[0], config, L1_cache_size, L2_cache_size, X_list[0], W_list[0]
+    ) }}
+{%- if num_threads > 1 %}
+    #pragma omp parallel num_threads({{num_threads}})
+    {
+        {{ template.codegen_multi_threads_params()|indent(8, false) }}
+{%- else %}
+    {
+        {{ template.codegen_single_thread_params(is_dynamic_M)|indent(8, false) }}
+{%- endif %}
+        {{ micro_gemm.codegen_init(kernel) }}
+{%- set acc_buf_name_list=[] %}
+{%- set acc_buf_name_prefix = "local_acc_buf_" %}
+{%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
+    {%- set acc_buf_name = acc_buf_name_prefix + gemm_idx|string %}
+    {{ kernel.define_buffer(acc_buf_name, ["Mc_blocks*Mr", "Nc_blocks*Nr"], acc_buf_dtype) }}
+    {%- set acc_buf_name_list=acc_buf_name_list.append(acc_buf_name) %}
+{%- endfor %}
+        for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) {
+            {{ template.codegen_m_loop_params()|indent(12, false) }}
+            for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
+                {{ template.codegen_n_loop_params()|indent(16, false) }}
+{%- set acc_list=[] %}
+{%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
+    {%- set acc_list = acc_list.append( kernel.local_buffers[acc_buf_name_list[gemm_idx]] ) %}
+    {{ kernel.reinit_buffer_if_null(acc_buf_name_list[gemm_idx]) }}
+{%- endfor %}
+                for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) {
+                    int64_t k_start = kc * Kr;
+                    int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K);
+{%- set tile_X_list=[] %}
+{%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
+    {%- set tile_X_list = tile_X_list.append( kernel.slice_nd(X_list[gemm_idx], [("m_start", "m_end"), ("k_start", "k_end")]) ) %}
+{%- endfor %}
+                    for (int64_t nci = nc; nci < nc_block_end; nci++) {
+{%- set tile_W_3d_list=[] %}
+{%- set tile_W_list=[] %}
+{%- set acc_slice_list=[] %}
+{%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
+    {%- set acc_slice_list = acc_slice_list.append(
+        kernel.slice_nd(acc_list[gemm_idx], [("0", "m_end - m_start"), ("(nci - nc)*Nr", "(nci - nc + 1)*Nr")])
+    ) %}
+    {%- set tile_W_3d_list = tile_W_3d_list.append(
+        kernel.slice_nd(W_list[gemm_idx], [("nci", "nci + 1"), ("k_start", "k_end"), ()])
+    ) %}
+{%- endfor %}
+{%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
+    {%- set tile_W_list = tile_W_list.append(
+        kernel.view(tile_W_3d_list[gemm_idx], ["k_end - k_start", micro_gemm.register_blocking.block_n])
+    ) %}
+{%- endfor %}
+                        if (kc == k_block_start) {
+                            {%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
+                                {{ micro_gemm.codegen_call(
+                                    kernel, tile_X_list[gemm_idx], tile_W_list[gemm_idx], acc_slice_list[gemm_idx], accum=False
+                                )|indent(28, false) }}
+                            {%- endfor %}
+                        } else {
+                            {%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
+                                {{ micro_gemm.codegen_call(
+                                    kernel, tile_X_list[gemm_idx], tile_W_list[gemm_idx], acc_slice_list[gemm_idx], accum=True
+                                )|indent(28, false) }}
+                            {%- endfor %}
+                        }
+                    }
+                }
+                {
+{%- set tile_acc_list = [] %}
+{%- set tile_Y_list = [] %}
+{%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
+    {%- set tile_acc_list = tile_acc_list.append(
+        kernel.slice_nd(acc_list[gemm_idx], [("0", "m_end - m_start"), ("0", "n_end - n_start")])
+    ) %}
+    {%- set tile_Y_list = tile_Y_list.append(
+        kernel.slice_nd(Y_2d_list[gemm_idx], [("m_start", "m_end"), ("n_start", "n_end")])
+    ) %}
+{%- endfor %}
+                    {{ kernel.store_outputs(
+                        tile_Y_list,
+                        tile_acc_list,
+                        GemmOuts,
+                        epilogue_nodes,
+                        offsets=("m_start", "n_start"),
+                        reindexers=reindexers,
+                        multi_output_buffers=multi_output_buffers
+                    )|indent(20, false)
+                    }}
+                }
+            }
+        }
+        {{ micro_gemm.codegen_finalize(kernel) }}
+    }
+}
+"""
+
+
+def get_deduplicated_act(act_mapping: dict[int, ir.IRNode]) -> list[ir.IRNode]:
+    act_deduplicated = []
+    act_deduplicated_name: OrderedSet[str] = OrderedSet()
+    for act_idx in range(len(act_mapping.values())):
+        act = act_mapping[act_idx]
+        if act.get_name() not in act_deduplicated_name:
+            act_deduplicated.append(act)
+            act_deduplicated_name.add(act.get_name())
+    return act_deduplicated
+
+
+class CppGroupedGemmTemplate(CppGemmTemplate):
+    def __init__(
+        self,
+        input_nodes: list[ir.IRNode],
+        layout: ir.Layout,
+        num_threads: int,
+        register_blocking: GemmBlocking,
+        beta: int = 1,
+        alpha: int = 1,
+        has_bias: bool = False,
+        epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
+        act_mapping: Optional[dict[int, ir.IRNode]] = None,
+        gemm_grouped_num: int = 1,
+    ) -> None:
+        """
+        Template for Group of GEMMs:
+        * Each GEMM has the same dimensions (m, n, k) and the same leading dimensions (lda, ldb, ldc)
+          for their A, B, and C matrices.
+        * Each GEMM has distinct or shared activations, has distinct weight, has unique bias or no bias, has distinct epilogues.
+        * In the current implementation, the outputs of all GEMMs are accumulated using pointwise epilogues.
+          This behavior can be extended in the future if needed.
+        """
+        super().__init__(
+            input_nodes,
+            layout,
+            num_threads,
+            register_blocking,
+            beta,
+            alpha,
+            has_bias,
+            epilogue_creator,
+        )
+        self.act_mapping = act_mapping
+        self.gemm_grouped_num = gemm_grouped_num
+        # pyrefly: ignore [bad-override]
+        self.output_node: list[ir.Buffer] = [
+            ir.Buffer(name="buf_out" + str(idx), layout=layout)
+            for idx in range(gemm_grouped_num)
+        ]
+
+    @classmethod
+    # pyrefly: ignore [bad-override]
+    def add_choices(
+        cls,
+        choices: list[ChoiceCaller],
+        layout: ir.Layout,
+        input_nodes: list[ir.IRNode],
+        beta: int = 1,
+        alpha: int = 1,
+        has_bias: tuple[bool, ...] = (False, False),
+        trans_w: bool = False,
+        input_indices: Optional[list[int]] = None,
+        epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
+        act_mapping: Optional[dict[int, ir.IRNode]] = None,  # gemm idx to its act buf
+    ) -> DataProcessorTemplateWrapper:
+        # Input nodes order: x, optional[x1], ... w0, w1, ... optional[b0], optional[b1], ...
+        gemm_grouped_num = len(has_bias)
+        assert act_mapping
+        act_deduplicated = get_deduplicated_act(act_mapping)
+        wgt_start_idx = len(act_deduplicated)
+        bias_start_idx = wgt_start_idx + gemm_grouped_num
+        input_indices = list(range(len(input_nodes)))
+
+        _T = TypeVar("_T", ir.IRNode, torch.Tensor)
+        _U = TypeVar("_U", ir.Layout, torch.Tensor)
+
+        def reorder_and_filter(
+            inputs: list[_T],
+            layout_or_out: _U,
+        ) -> tuple[list[_T], _U]:
+            assert input_indices is not None, "input_indices must be set"
+            return [inputs[idx] for idx in input_indices], layout_or_out
+
+        new_inputs, new_layout = reorder_and_filter(input_nodes, layout)
+
+        def maybe_to_dense(
+            inputs: list[_T],
+            layout_or_out: _U,
+        ) -> tuple[list[_T], _U]:
+            new_inputs = list(inputs)
+            for idx in range(wgt_start_idx, wgt_start_idx + gemm_grouped_num):
+                if isinstance(inputs[idx], torch.Tensor):
+                    W = inputs[idx]
+                    assert isinstance(W, torch.Tensor), "W must be a torch.Tensor"
+                    # pyrefly: ignore [unsupported-operation]
+                    new_inputs[idx] = W.to_dense() if W.is_mkldnn else W
+            return new_inputs, layout_or_out
+
+        def normalize_shapes(
+            inputs: list[_T],
+            layout_or_out: _U,
+        ) -> tuple[list[_T], _U]:
+            new_inputs: list[_T] = list(inputs)
+            if not trans_w:
+                return new_inputs, layout_or_out
+            X = new_inputs[0]
+            for wgt_idx in range(wgt_start_idx, wgt_start_idx + gemm_grouped_num):
+                new_input = new_inputs[wgt_idx]
+                new_inputs[wgt_idx] = transpose_w(new_input, trans_w)
+            for bias_idx in range(bias_start_idx, len(new_inputs)):
+                # pyrefly: ignore [bad-argument-type]
+                new_bias = expand_bias(new_inputs[bias_idx], X)
+                assert new_bias is not None
+                # pyrefly: ignore [unsupported-operation]
+                new_inputs[bias_idx] = new_bias
+            return new_inputs, layout_or_out
+
+        num_threads = parallel_num_threads()
+        new_inputs, _ = normalize_shapes(*maybe_to_dense(new_inputs, new_layout))
+        m, n, k, *_ = mm_args(new_inputs[0], new_inputs[wgt_start_idx])
+        output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype(
+            new_inputs[0].get_dtype()
+        )
+        micro_gemm = create_micro_gemm(
+            "micro_gemm",
+            m,
+            n,
+            k,
+            input_dtype=new_inputs[0].get_dtype(),
+            input2_dtype=new_inputs[wgt_start_idx].get_dtype(),
+            output_dtype=output_dtype,
+            compute_dtype=compute_dtype,
+            alpha=alpha,
+            num_threads=num_threads,
+        )
+        assert micro_gemm is not None
+        _, block_n, _ = micro_gemm.register_blocking
+        new_size, padded_n = cls.get_padded_size(
+            n, block_n, k, should_block_weight=True
+        )
+        padding = padded_n - n
+
+        def pack_weight(
+            inputs: list[_T],
+            layout_or_out: _U,
+        ) -> tuple[list[_T], _U]:
+            new_W_list = []
+            new_inputs = list(inputs)
+            W_list = new_inputs[wgt_start_idx : wgt_start_idx + gemm_grouped_num]
+            for W in W_list:
+                blocked_w = cls.block_weight(W, new_size, padding)
+                new_W_list.append(cls.pack_vnni_weight(blocked_w, micro_gemm, new_size))
+            new_inputs[wgt_start_idx : wgt_start_idx + gemm_grouped_num] = new_W_list
+            return new_inputs, layout_or_out
+
+        def preprocessor(
+            inputs: list[_T],
+            layout: _U,
+        ) -> tuple[list[_T], _U]:
+            return pack_weight(
+                *normalize_shapes(*maybe_to_dense(*reorder_and_filter(inputs, layout)))
+            )
+
+        def postprocessor(output: _T) -> _T:
+            if isinstance(output, ir.TensorBox):
+                template_buffer = ir.InputsKernel.unwrap_storage_for_input(output)
+                assert isinstance(template_buffer, ir.CppTemplateBuffer)
+                new_input_nodes, _ = reorder_and_filter(input_nodes, layout)
+                W_nodes = new_input_nodes[
+                    wgt_start_idx : wgt_start_idx + gemm_grouped_num
+                ]
+                W_tensor = []
+                for W_node in W_nodes:
+                    assert W_node.get_name() in V.graph.constants
+                    # pyrefly: ignore [bad-argument-type]
+                    W_tensor.append(V.graph.constants[W_node.get_name()])
+                new_input_nodes[wgt_start_idx : wgt_start_idx + gemm_grouped_num] = (
+                    W_tensor  # type: ignore[assignment]
+                )
+                new_input_nodes, _ = pack_weight(
+                    *normalize_shapes(*maybe_to_dense(new_input_nodes, layout))
+                )
+                # Prune unused tensors
+                prune_tensors(input_nodes, new_input_nodes)
+                for idx in range(wgt_start_idx, wgt_start_idx + gemm_grouped_num):
+                    W_packed = new_input_nodes[idx]
+                    assert isinstance(W_packed, torch.Tensor)
+                    W_packed_constant = V.graph.add_tensor_constant(W_packed)
+                    template_buffer.inputs[idx] = (
+                        ir.InputsKernel.unwrap_storage_for_input(W_packed_constant)
+                    )
+            # pyrefly: ignore [bad-return]
+            return output
+
+        template = DataProcessorTemplateWrapper(
+            CppGroupedGemmTemplate,
+            preprocessor,
+            postprocessor,
+            input_nodes=input_nodes,
+            layout=layout,
+            num_threads=num_threads,
+            register_blocking=micro_gemm.register_blocking,
+            beta=beta,
+            alpha=alpha,
+            has_bias=has_bias,
+            epilogue_creator=epilogue_creator,
+            act_mapping=act_mapping,
+            gemm_grouped_num=gemm_grouped_num,
+        )
+        template.maybe_append_choice(choices)
+        return template
+
+    def render(  # type: ignore[override,return,no-untyped-def]
+        self,
+        kernel: CppTemplateKernel,
+        template_buffer_node: Optional[ir.CppTemplateBuffer] = None,
+        flag_template_buffer_has_other_users: Optional[bool] = None,
+        epilogue_nodes: Optional[list[ir.IRNode]] = None,
+        **kwargs,
+    ) -> str:
+        assert self.act_mapping
+        act_deduplicated = get_deduplicated_act(self.act_mapping)
+        wgt_start_idx = len(act_deduplicated)
+        bias_start_idx = wgt_start_idx + self.gemm_grouped_num
+        X_list = list(self.act_mapping.values())
+        W_list = self.input_nodes[wgt_start_idx : wgt_start_idx + self.gemm_grouped_num]
+        inp_list = []
+        cur_idx = bias_start_idx
+        for inp_idx in range(self.gemm_grouped_num):
+            inp = None
+            # pyrefly: ignore [index-error]
+            if self.has_bias[inp_idx]:
+                inp = self.input_nodes[cur_idx]
+                cur_idx += 1
+            inp_list.append(inp)
+
+        Y_list = self.output_node
+        multi_output_buffers = None
+        if template_buffer_node is not None:
+            W_list = template_buffer_node.inputs[
+                wgt_start_idx : wgt_start_idx + self.gemm_grouped_num
+            ]
+            assert isinstance(template_buffer_node.outputs, list)
+            Y_list = template_buffer_node.outputs
+            counters["inductor"]["cpp_grouped_gemm_template"] += 1
+            multi_output_buffers = template_buffer_node.outputs
+
+        template_buffer = Y_list[0]
+        fake_buffers: list[ir.Buffer] = []
+        Y_2d_list = Y_list
+        output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype(
+            X_list[0].get_dtype()
+        )
+        micro_gemm = create_micro_gemm(
+            f"{kernel.kernel_name}_micro_gemm",
+            self.m,
+            self.n,
+            self.k,
+            input_dtype=X_list[0].get_dtype(),
+            # pyrefly: ignore [missing-attribute]
+            input2_dtype=W_list[0].get_dtype(),
+            output_dtype=output_dtype,
+            compute_dtype=compute_dtype,
+            alpha=self.alpha,
+            num_threads=self.num_threads,
+        )
+        assert micro_gemm is not None
+        assert self.register_blocking == micro_gemm.register_blocking
+        self.log_blockings()
+        if isinstance(micro_gemm, CppMicroGemmAMX):
+            counters["inductor"]["cpp_micro_gemm_amx_counter"] += 1
+
+        L1_cache_size = torch._C._cpu._L1d_cache_size()  # per core cache size in Bytes
+        assert L1_cache_size > 0, f"Expect L1_cache_size > 0 but got {L1_cache_size}"
+
+        L2_cache_size = torch._C._cpu._L2_cache_size()  # per core cache size in Bytes
+        assert L2_cache_size > 0, f"Expect L2_cache_size > 0 but got {L2_cache_size}"
+
+        epilogues: list[ir.IRNode] = []
+        reindexers: list[Optional[Callable[[list[Any]], list[Any]]]] = []
+        gemm_output_buffers: list[ir.Buffer] = []
+        for out_buf_idx in range(self.gemm_grouped_num):
+            gemm_output_name = f"{template_buffer.get_name()}_GemmOut" + str(
+                out_buf_idx
+            )
+            gemm_output_buffers.append(
+                ir.Buffer(name=gemm_output_name, layout=template_buffer.layout)
+            )
+
+        assert not self.epilogue_creator, (
+            "epilogue_creator is not supported yet in Grouped GEMM Template"
+        )
+
+        kernel_args: dict[str, Optional[ir.IRNode]] = {}
+        for x_idx in range(wgt_start_idx):
+            kernel_args["X" + str(x_idx)] = act_deduplicated[x_idx]
+        for w_idx in range(self.gemm_grouped_num):
+            # pyrefly: ignore [unsupported-operation]
+            kernel_args["W" + str(w_idx)] = W_list[w_idx]
+        for inp_idx in range(self.gemm_grouped_num):
+            kernel_args["inp" + str(inp_idx)] = inp_list[inp_idx]
+
+        def _bias_add_epilogue(buf: ir.IRNode, inp: ir.IRNode) -> ir.Pointwise:
+            return create_epilogue_with_attr(
+                buf, "bias_add", other=inp, beta=self.beta, dtype=self.layout.dtype
+            )
+
+        for gemm_idx, inp in enumerate(inp_list):
+            if inp:
+                buffer_name = Y_list[gemm_idx].get_name()
+                epilogues.append(
+                    ir.ComputedBuffer(
+                        name=buffer_name,
+                        layout=template_buffer.layout,
+                        data=_bias_add_epilogue(gemm_output_buffers[gemm_idx], inp),
+                    )
+                )
+                reindexers.append(None)
+
+        if epilogue_nodes:
+            epilogues.extend(epilogue_nodes)
+            for epilogue_node in epilogue_nodes:
+                Y = cast(ir.Buffer, epilogue_node)
+                _, reindexers = gen_2d_view_of_epilogue_buf(
+                    Y,
+                    template_buffer,
+                    [
+                        epilogue_node,
+                    ],
+                    reindexers,
+                    default_reindexers=[
+                        None,
+                    ],
+                )
+
+        options = dict(
+            N=self.n,
+            K=self.k,
+            PADDED_N=self.padded_n,
+            aliases={},
+            beta=self.beta,
+            alpha=self.alpha,
+            num_threads=self.num_threads,
+            micro_gemm=micro_gemm,
+            is_dynamic_M=self.is_dynamic_M,
+            template=self,
+            kernel=kernel,
+            export_declaration=get_export_declaration(),
+            acc_buf_dtype=torch.float,
+            DTYPE_TO_CPP=DTYPE_TO_CPP,
+            L1_cache_size=L1_cache_size,
+            L2_cache_size=L2_cache_size,
+            config=config,
+            epilogue_nodes=epilogues,
+            GemmOuts=gemm_output_buffers,
+            reindexers=reindexers,
+            kernel_args=kernel_args,
+            X_list=X_list,
+            W_list=W_list,
+            gemm_grouped_num=self.gemm_grouped_num,
+            Y_list={"Y" + str(idx): Y for idx, Y in enumerate(Y_list)},
+            Y_2d_list=Y_2d_list,
+            multi_output_buffers=multi_output_buffers,
+        )
+        with contextlib.ExitStack() as stack:
+            stack.enter_context(
+                patch.object(V.graph, "get_dtype", self._fake_get_dtype(fake_buffers))
+            )
+            return self._template_from_string(GEMM_TEMPLATE).render(**options)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_micro_gemm.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_micro_gemm.py
new file mode 100644
index 0000000000000000000000000000000000000000..39c026949fb13d541191b7462ad8f5666f09c098
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_micro_gemm.py
@@ -0,0 +1,2232 @@
+# mypy: allow-untyped-defs
+import dataclasses
+import operator
+import sys
+from collections.abc import Callable
+from enum import Enum
+from typing import Optional
+
+import torch
+
+from .. import cpp_builder, ir
+from ..cpu_vec_isa import (
+    pick_vec_isa,
+    VecAMX,
+    VecAVX2,
+    VecAVX512,
+    VecAVX512VNNI,
+    VecISA,
+    VecNEON,
+    VecSVE256,
+)
+from ..utils import IndentedBuffer, parallel_num_threads
+from ..virtualized import V
+from .common import KernelTemplate
+from .cpp_template_kernel import CppTemplateKernel
+from .cpp_utils import DTYPE_TO_CPP, GemmBlocking, value_to_cpp
+
+
+class LayoutType(Enum):
+    NORMAL = 0
+    VNNI2 = 1
+    VNNI4 = 2
+
+
+_IS_WINDOWS = sys.platform == "win32"
+
+
+def get_restrict_keyword() -> str:
+    if _IS_WINDOWS:
+        # https://learn.microsoft.com/en-us/cpp/cpp/extension-restrict?view=msvc-170
+        return "__restrict"
+    else:
+        return "__restrict__"
+
+
+class CppMicroGemm:
+    """
+    A class that codegens a kernel that computes small-sized matrix multiplication.
+
+    A micro GEMM kernel is responsible for register blocking, instruction selection,
+    and other CPU architecture-specific optimizations.
+
+    The subclasses need to override `codegen_define` to define the kernel function
+    that is called by the code generated by `codegen_call`.
+    """
+
+    # TODO(jgong5): support constant shapes and lds as template args.
+    DECLARE_KERNEL = r"""
+template 
+inline void {{kernel_name}}(
+{%- if kernel_extra_args_declare %}
+    {{kernel_extra_args_declare}}
+{%- endif %}
+    const {{input_t}}* {{restrict_keyword}} A,
+    const {{input2_t}}* {{restrict_keyword}} B,
+    {{output_t}}* {{restrict_keyword}} C,
+    int64_t M,
+    int64_t N,
+    int64_t K,
+    int64_t lda,
+    int64_t ldb,
+    int64_t ldc
+)
+"""
+
+    def __init__(
+        self,
+        name,
+        input_dtype,
+        input2_dtype,
+        output_dtype,
+        compute_dtype,
+        register_blocking,
+        alpha=1,
+    ) -> None:
+        self.name = name
+        self.input_dtype = input_dtype
+        assert input2_dtype is not None
+        self.input2_dtype = input2_dtype
+        self.output_dtype = output_dtype
+        self.compute_dtype = compute_dtype
+        self.register_blocking = register_blocking
+        self.alpha = alpha
+        self.pack_vnni_B_locally = False
+
+    def get_common_options(self):
+        if self.input_dtype in [torch.uint8, torch.int8]:
+            assert self.compute_dtype == torch.int32
+            assert self.output_dtype == torch.int32
+            assert self.input2_dtype == torch.int8
+        return {
+            "torch": torch,
+            "kernel_name": self.name,
+            "input_dtype": self.input_dtype,
+            "input2_dtype": self.input2_dtype,
+            "output_dtype": self.output_dtype,
+            "compute_dtype": self.compute_dtype,
+            "input_t": DTYPE_TO_CPP[self.input_dtype],
+            "input2_t": DTYPE_TO_CPP[self.input2_dtype],
+            "output_t": DTYPE_TO_CPP[self.output_dtype],
+            "compute_t": DTYPE_TO_CPP[self.compute_dtype],
+            "alpha": self.alpha,
+            "kernel_extra_args_declare": self.get_kernel_extra_args_declare(),
+            "int8_gemm": self.input_dtype in [torch.uint8, torch.int8],
+            "vnni_size": 4 if self.input_dtype in [torch.uint8, torch.int8] else 2,
+            "restrict_keyword": get_restrict_keyword(),
+            "pack_vnni_B_locally": self.pack_vnni_B_locally,
+            "template": self,
+            "is_woq_int4": self.is_woq_int4(),
+        }
+
+    def get_kernel_declaration(self):
+        options = self.get_common_options()
+        return KernelTemplate._template_from_string(self.DECLARE_KERNEL).render(options)
+
+    def get_kernel_extra_args_declare(self) -> str:
+        return ""
+
+    def get_kernel_extra_args(self, **kwargs) -> list[str]:
+        return []
+
+    def codegen_define(self, kernel: CppTemplateKernel) -> str:
+        raise NotImplementedError
+
+    def codegen_call(
+        self,
+        kernel: CppTemplateKernel,
+        A: ir.Buffer,
+        B: ir.Buffer,
+        C: ir.Buffer,
+        accum: bool,
+        prefetch: bool = False,
+        **kwargs_for_extra_args,
+    ) -> str:
+        """
+        Generate the code for calling the templated kernel that computes
+        `C += alpha * A @ B` if `accum` is True, or `C = alpha * A @ B` otherwise.
+        """
+        A_ptr = f"&({kernel.index(A, [0, 0])})"
+        B_ptr = f"&({kernel.index(B, [0, 0])})"
+        C_ptr = f"&({kernel.index(C, [0, 0])})"
+        M = kernel.size(C, 0)
+        N = kernel.size(C, 1)
+        K = kernel.size(A, 1)
+        lda = kernel.stride(A, 0)
+        ldb = kernel.stride(B, 0)
+        ldc = kernel.stride(C, 0)
+        res = IndentedBuffer()
+        res.writeline(
+            f"{self.name}<{value_to_cpp(accum, 'bool')}, {value_to_cpp(prefetch, 'bool')}>("
+        )
+        with res.indent():
+            kwargs_for_extra_args.update({"kernel": kernel})
+            extra_args = self.get_kernel_extra_args(**kwargs_for_extra_args)
+            for arg in extra_args:
+                res.writeline(arg)
+            res.writeline(f"{A_ptr},")
+            res.writeline(f"{B_ptr},")
+            res.writeline(f"{C_ptr},")
+            res.writeline(f"{M},")
+            res.writeline(f"{N},")
+            res.writeline(f"{K},")
+            res.writeline(f"{lda},")
+            res.writeline(f"{ldb},")
+            res.writeline(f"{ldc}")
+        res.writeline(");")
+        return res.getvalue()
+
+    def use_local_vnni_blocking(self, should_block_weight: bool):
+        self.pack_vnni_B_locally = should_block_weight
+
+    def codegen_init(
+        self,
+        kernel: CppTemplateKernel,
+    ) -> str:
+        return ""
+
+    def codegen_finalize(
+        self,
+        kernel: CppTemplateKernel,
+    ) -> str:
+        return ""
+
+    def get_b_layout(self) -> LayoutType:
+        return LayoutType.NORMAL
+
+    ALLOCATE_WEIGHT_BUFFER = r"""
+    {%- if is_msvc_compiler %}
+    // MSVC doesn't support stack-allocated dynamic-sized arrays, so using heap memory here.
+    auto heap_deq_b_buf_ptr = std::make_unique<{{buffer_dtype}}[]>({{buffer_size}});
+    {{buffer_dtype}}* {{buffer_name}} = heap_deq_b_buf_ptr.get();
+    {%- else %}
+    // It's safe to use a stack-allocated array since the blocking strategy would
+    // require us to allocate an array that's smaller than the size of L1D cache,
+    // and the default per thread max stack size on Linux is quite higher,
+    // so we need not worry about stack overflow.
+    alignas(4096) {{buffer_dtype}} {{buffer_name}}[{{buffer_size}}];
+    {%- endif %}
+"""
+
+    def codegen_allocate_weight_buffer(
+        self, buffer_name: str, buffer_dtype: str, *size_args
+    ) -> str:
+        buffer_size = " * ".join(map(str, size_args))
+        return KernelTemplate._template_from_string(self.ALLOCATE_WEIGHT_BUFFER).render(
+            {
+                "buffer_name": buffer_name,
+                "buffer_dtype": buffer_dtype,
+                "buffer_size": buffer_size,
+                "is_msvc_compiler": cpp_builder.is_msvc_cl(),
+            }
+        )
+
+    def is_woq_int4(self):
+        return False
+
+
+@dataclasses.dataclass
+class CppMicroGemmConfig:
+    input_dtype: torch.dtype
+    input2_dtype: torch.dtype
+    output_dtype: torch.dtype
+    compute_dtype: torch.dtype
+    vec_isa_cls: type[VecISA]
+    register_blocking: GemmBlocking
+    extra_check: Optional[Callable[..., bool]] = None
+
+
+micro_gemm_configs: dict[type[CppMicroGemm], list[CppMicroGemmConfig]] = {}
+
+
+def register_micro_gemm(*configs):
+    def inner(cls):
+        assert cls not in micro_gemm_configs, (
+            f"Duplicate micro_gemm registration for {cls}"
+        )
+        assert len(configs) > 0, f"No micro_gemm configs provided for {cls}"
+        micro_gemm_configs[cls] = list(configs)
+        return cls
+
+    return inner
+
+
+def generate_gemm_config(
+    vec_isa_cls,
+    register_blockings,
+    input_dtype=torch.float,
+    input2_dtype=None,
+    output_dtype=None,
+    compute_dtype=None,
+    extra_check=None,
+):
+    if output_dtype is None:
+        output_dtype = input_dtype
+    if compute_dtype is None:
+        compute_dtype = output_dtype
+    if input2_dtype is None:
+        input2_dtype = input_dtype
+    return [
+        CppMicroGemmConfig(
+            input_dtype,
+            input2_dtype,
+            output_dtype,
+            compute_dtype,
+            vec_isa_cls,
+            GemmBlocking(*blocking),
+            extra_check,
+        )
+        for blocking in register_blockings
+    ]
+
+
+class CppMicroGemmRef(CppMicroGemm):
+    """
+    A reference implementation of the CppMicroGemm class with naive C++ code.
+    It is used for correctness debugging.
+    """
+
+    TEMPLATE_ENTRY = r"""
+{{declare_kernel}} {
+    for (int64_t m = 0; m < M; ++m) {
+        for (int64_t n = 0; n < N; ++n) {
+            {{compute_t}} result = accum ? C[m * ldc + n] : 0;
+            for (int64_t k = 0; k < K; ++k) {
+                result += ({{compute_t}})A[m * lda + k] * ({{compute_t}})B[k * ldb + n] * {{alpha}};
+            }
+            C[m * ldc + n] = result;
+        }
+    }
+}
+"""
+
+    def __init__(
+        self, name, input_dtype, input2_dtype, output_dtype, compute_dtype, alpha
+    ) -> None:
+        super().__init__(
+            name,
+            input_dtype,
+            input2_dtype,
+            output_dtype,
+            compute_dtype,
+            GemmBlocking(1, 1, 1),
+            alpha,
+        )
+
+    def codegen_define(self, kernel: CppTemplateKernel) -> str:
+        options = {
+            "declare_kernel": self.get_kernel_declaration(),
+            **self.get_common_options(),
+        }
+        return KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render(options)
+
+
+def is_int8_woq_gemm_small_m_dim_corner_case(config, m, n, k):
+    return (
+        k % config.register_blocking.block_k == 0
+        and n % config.register_blocking.block_n == 0
+        and m < 16
+    )
+
+
+# extra check for small M dimension for int8 WoQ case
+def check_int8_woq_small_m_dim(config, m, n, k, alpha, num_threads, **kwargs):
+    return is_int8_woq_gemm_small_m_dim_corner_case(config, m, n, k) and not kwargs.get(
+        "dynamic_M", False
+    )
+
+
+# For int8 WoQ GEMM with small M, we use different blockings that shouldn't be used otherwise
+def do_not_use_with_small_m_for_int8_woq(config, m, n, k, alpha, num_threads, **kwargs):
+    return not check_int8_woq_small_m_dim(config, m, n, k, alpha, num_threads, **kwargs)
+
+
+@register_micro_gemm(
+    *generate_gemm_config(
+        VecAVX512,
+        [(8, 48, 1), (8, 32, 1), (16, 16, 1)],
+        input_dtype=torch.float,
+    ),
+    *generate_gemm_config(
+        VecAVX512,
+        [(8, 48, 1), (8, 32, 1), (16, 16, 1)],
+        input_dtype=torch.bfloat16,
+        output_dtype=torch.float,
+    ),
+    *generate_gemm_config(
+        VecAVX512,
+        [(8, 48, 1), (8, 32, 1), (16, 16, 1)],
+        input_dtype=torch.half,
+        output_dtype=torch.float,
+    ),
+    *generate_gemm_config(
+        VecAVX512,
+        [(8, 48, 1), (8, 32, 1), (16, 16, 1)],
+        input_dtype=torch.bfloat16,
+        input2_dtype=torch.int8,
+        output_dtype=torch.float,
+        compute_dtype=torch.float,
+        extra_check=do_not_use_with_small_m_for_int8_woq,
+    ),
+    *generate_gemm_config(
+        VecAVX512,
+        [
+            (4, 32, 64),
+            (8, 32, 64),
+        ],
+        input_dtype=torch.bfloat16,
+        input2_dtype=torch.int8,
+        output_dtype=torch.float,
+        compute_dtype=torch.float,
+        extra_check=check_int8_woq_small_m_dim,
+    ),
+    *generate_gemm_config(
+        VecAVX2,
+        [(4, 24, 1), (4, 16, 1), (8, 8, 1)],
+        input_dtype=torch.float,
+    ),
+    *generate_gemm_config(
+        VecAVX2,
+        [(4, 24, 1), (4, 16, 1), (8, 8, 1)],
+        input_dtype=torch.bfloat16,
+        output_dtype=torch.float,
+    ),
+    *generate_gemm_config(
+        VecAVX2,
+        [(4, 24, 1), (4, 16, 1), (8, 8, 1)],
+        input_dtype=torch.half,
+        output_dtype=torch.float,
+    ),
+    *generate_gemm_config(
+        VecAVX2,
+        [(4, 24, 1), (4, 16, 1), (8, 8, 1)],
+        input_dtype=torch.bfloat16,
+        input2_dtype=torch.int8,
+        output_dtype=torch.float,
+        compute_dtype=torch.float,
+        extra_check=do_not_use_with_small_m_for_int8_woq,
+    ),
+    *generate_gemm_config(
+        VecAVX2,
+        [
+            (2, 16, 64),
+            (4, 16, 64),
+        ],
+        input_dtype=torch.bfloat16,
+        input2_dtype=torch.int8,
+        output_dtype=torch.float,
+        compute_dtype=torch.float,
+        extra_check=check_int8_woq_small_m_dim,
+    ),
+    *generate_gemm_config(
+        VecNEON,
+        [(4, 24, 1), (4, 16, 1), (8, 8, 1)],
+        input_dtype=torch.float,
+        input2_dtype=torch.float,
+        output_dtype=torch.float,
+        compute_dtype=torch.float,
+    ),
+    *generate_gemm_config(
+        VecSVE256,
+        [(4, 24, 1), (4, 16, 1), (8, 8, 1)],
+        input_dtype=torch.float,
+        input2_dtype=torch.float,
+        output_dtype=torch.float,
+        compute_dtype=torch.float,
+    ),
+)
+class CppMicroGemmFP32Vec(CppMicroGemm):
+    """
+    This class generates the code for micro gemm using fp32 vec instructions for compute.
+    It supports input types of torch.float, torch.bfloat16, and torch.half with fp32 output.
+    The output of the microkernel is in FP32, but it would be converted to BF16/FP16 in the template,
+    if the desired output is BF16/FP16.
+    """
+
+    TEMPLATE_ENTRY = r"""
+{{declare_kernel}} {
+    using Vectorized = at::vec::Vectorized<{{compute_t}}>;
+    constexpr auto VLEN = Vectorized::size();
+    {{kernel.assert_function}}({{block_n}} % VLEN == 0, "block_n dimension must be multiple of Vector size");
+    {{kernel.assert_function}}(K % {{block_k}} == 0, "K dimension must be multiple of {{block_k}}");
+    // TODO(jgong5): loop unroll for M and N
+    for (int64_t m = 0; m < M; m += {{block_m}}) {
+        int64_t block_m = std::min(M - m, {{block_m}});
+        for (int64_t n = 0; n < N; n += {{block_n}}) {
+            int64_t block_n = std::min(N - n, {{block_n}});
+            if (block_m == {{block_m}} && block_n == {{block_n}}) {
+{%- if not trans_b %}
+                {{kernel_name}}_kernel<{{block_m}}, {{block_n}}, accum, prefetch>(
+{%- else %}
+                {{kernel_name}}_transpose_b_kernel<{{block_m}}, {{block_n}}, accum, prefetch>(
+{%- endif %}
+                    A + m * lda,
+{%- if not trans_b %}
+                    B + n,
+{%- else %}
+                    B + n * ldb,
+{%- endif %}
+                    C + m * ldc + n,
+                    K,
+                    lda,
+                    ldb,
+                    ldc
+                );
+{%- if tail_n %}
+            } else if (block_n == {{block_n}}){
+{%- else %}
+            } else {
+{%- endif %}
+                switch (block_m) {
+{%- for b in range(block_m - 1, 0, -1) %}
+                case {{b}}:
+    {%- if not trans_b %}
+                    {{kernel_name}}_kernel<{{b}}, {{block_n}}, accum, prefetch>(
+    {%- else %}
+                    {{kernel_name}}_transpose_b_kernel<{{b}}, {{block_n}}, accum, prefetch>(
+    {%- endif %}
+                        A + m * lda,
+    {%- if not trans_b %}
+                        B + n,
+    {%- else %}
+                        B + n * ldb,
+    {%- endif %}
+                        C + m * ldc + n,
+                        K,
+                        lda,
+                        ldb,
+                        ldc
+                    );
+                    break;
+{%- endfor %}
+                default:
+                    {{kernel.assert_function}}(false, "Unsupported block_m: {{block_m}}");
+                }
+
+{%- if tail_n %}
+            } else {
+                switch (block_m) {
+    {%- for b in range(block_m, 0, -1) %}
+                case {{b}}:
+        {%- if not trans_b %}
+                    {{kernel_name}}_ntail_kernel<{{b}}, {{block_n}}, accum, prefetch>(
+        {%- else %}
+                    {{kernel_name}}_ntail_transpose_b_kernel<{{b}}, {{block_n}}, accum, prefetch>(
+        {%- endif %}
+                        A + m * lda,
+        {%- if not trans_b %}
+                        B + n,
+        {%- else %}
+                        B + n * ldb,
+        {%- endif %}
+                        C + m * ldc + n,
+                        block_n,
+                        K,
+                        lda,
+                        ldb,
+                        ldc
+                    );
+                    break;
+    {%- endfor %}
+                default:
+                    {{kernel.assert_function}}(false, "Unsupported block_m: {{block_m}}");
+                }
+            }
+{%- else %}
+            }
+{%- endif %}
+        }
+    }
+}
+"""
+
+    TEMPLATE_KERNEL = r"""
+
+template 
+{%- if not trans_b %}
+    {%- if tail_n %}
+inline void {{kernel_name}}_ntail_kernel(
+    {%- else %}
+inline void {{kernel_name}}_kernel(
+    {%- endif %}
+{%- else %}
+    {%- if tail_n %}
+inline void {{kernel_name}}_ntail_transpose_b_kernel(
+    {%- else %}
+inline void {{kernel_name}}_transpose_b_kernel(
+    {%- endif %}
+{%- endif %}
+    const {{input_t}}* {{restrict_keyword}} A,
+    const {{input2_t}}* {{restrict_keyword}} B,
+    {{output_t}}* {{restrict_keyword}} C,
+{%- if tail_n %}
+    int64_t N,
+{%- endif %}
+    int64_t K,
+    int64_t lda,
+    int64_t ldb,
+    int64_t ldc
+) {
+    using Vectorized = at::vec::Vectorized<{{compute_t}}>;
+{%- if input2_dtype in [torch.bfloat16, torch.float16] %}
+    using VectorizedIn = at::vec::Vectorized<{{input_t}}>;
+{%- endif %}
+
+{%- if not trans_b %}
+    constexpr auto VLEN = Vectorized::size();
+    constexpr auto ROWS = BLOCK_M;
+    constexpr auto COLS = BLOCK_N / VLEN;
+
+    Vectorized va;
+    at::vec::VectorizedN<{{compute_t}}, COLS> vb;
+    at::vec::VectorizedN<{{compute_t}}, ROWS*COLS> vc;
+
+    {%- if tail_n %}
+    int64_t rCOLS = (N + VLEN - 1) / VLEN;
+    int ntail = N % VLEN;
+    {%- endif %}
+    auto loadc = [&](auto i) {
+        if constexpr (accum) {
+            constexpr int row = i / COLS;
+            constexpr int col = i % COLS;
+    {%- if tail_n %}
+            int load_size = (col == rCOLS - 1 && ntail != 0) ? ntail : VLEN;
+            if (col < rCOLS) {
+                vc[i] = Vectorized::loadu(C + row * ldc + col * VLEN, load_size);
+            }
+    {%- else %}
+            vc[i] = Vectorized::loadu(C + row * ldc + col * VLEN);
+    {%- endif %}
+        } else {
+            vc[i] = Vectorized(0.0f);
+        }
+    };
+    c10::ForcedUnroll{}(loadc);
+
+    auto compute = [&, COLS](auto i, int k) {
+        constexpr int row = i / COLS;
+        constexpr int col = i % COLS;
+    {%- if tail_n %}
+        int load_size = (col == rCOLS - 1 && ntail != 0) ? ntail : VLEN;
+    {%- endif %}
+        if constexpr (col == 0) {
+    {%- if alpha != 1 %}
+            va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + k]) * {{alpha}});
+    {%- else %}
+            va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + k]));
+    {%- endif %}
+        }
+
+        if constexpr (row == 0) {
+    {%- if tail_n %}
+            if (col < rCOLS) {
+        {%- if input2_dtype in [torch.bfloat16, torch.float16] %}
+                auto b = VectorizedIn::loadu(B + k * ldb + col * VLEN, load_size);
+                vb[col] = at::vec::convert<{{compute_t}}>(b);
+        {%- elif input2_dtype == torch.int8 %}
+            // Convert VLEN int8 elements to int32, and then fp32
+                auto b32 = at::vec::convert_to_int32(B + k * ldb + col * VLEN, load_size);
+                vb[col] = at::vec::convert(b32);
+        {%- else %}
+                vb[col] = Vectorized::loadu(B + k * ldb + col * VLEN, load_size);
+        {%- endif %}
+            } else {
+                vb[col] = Vectorized(0.0f);
+            }
+
+    {%- else %}
+
+        {%- if input2_dtype in [torch.bfloat16, torch.float16] %}
+            auto b = VectorizedIn::loadu(B + k * ldb + col * VLEN, VLEN);
+            vb[col] = at::vec::convert<{{compute_t}}>(b);
+        {%- elif input2_dtype == torch.int8 %}
+            // Convert VLEN int8 elements to int32, and then fp32
+            auto b32 = at::vec::convert_to_int32(B + k * ldb + col * VLEN);
+            if constexpr (prefetch) {
+              _mm_prefetch(B + (k + {{block_k}}) * ldb + col * VLEN, _MM_HINT_T0);
+            }
+            vb[col] = at::vec::convert(b32);
+        {%- else %}
+            vb[col] = Vectorized::loadu(B + k * ldb + col * VLEN);
+        {%- endif %}
+    {%- endif %}
+
+        }
+
+        constexpr int idx = row * COLS + col;
+    {%- if tail_n %}
+        if (col < rCOLS) {
+            vc[idx] = at::vec::fmadd(va, vb[col], vc[idx]);
+        }
+    {%- else %}
+        vc[idx] = at::vec::fmadd(va, vb[col], vc[idx]);
+    {%- endif %}
+    };
+
+    for (int k = 0; k < K; ++k) {
+        c10::ForcedUnroll{}(compute, k);
+    }
+
+    // store to C
+    auto storec = [&](auto i) {
+        constexpr int row = i / COLS;
+        constexpr int col = i % COLS;
+    {%- if tail_n %}
+        int store_size = (col == rCOLS - 1 && ntail != 0) ? ntail : VLEN;
+        if (col < rCOLS) {
+            vc[i].store(C + row * ldc + col * VLEN, store_size);
+        }
+    {%- else %}
+        vc[i].store(C + row * ldc + col * VLEN);
+    {%- endif %}
+    };
+    c10::ForcedUnroll{}(storec);
+
+{%- else %}
+    // Use 2 implementations for the transposed B:
+    // First implementation:
+    //   Transpose first and then perform outer product calculation in sub-blocks,
+    //   which introduces an additional transpose overhead of [K, N] compared to the non-transpose version.
+    // Second implementation:
+    //   Directly perform inner product calculation in sub-blocks,
+    //   which introduces an additional vector reduction of [M, N] compared to the non-tranpose version.
+    // Therefore, when M * N / (K * N) is large, the first implementation has better performance.
+    {%- if tail_n %}
+    if (K % Vectorized::size() == 0 && N % Vectorized::size() == 0 && 24 * BLOCK_M > K) {
+    {%- else %}
+    if (K % Vectorized::size() == 0 && 24 * BLOCK_M > K) {
+    {%- endif %}
+        // First implementation:
+        constexpr auto VLEN = Vectorized::size();
+        constexpr auto ROWS = BLOCK_M;
+        constexpr auto COLS = BLOCK_N / VLEN;
+        int _K = K / VLEN;
+        Vectorized va;
+        at::vec::VectorizedN<{{compute_t}}, VLEN> vb;
+        at::vec::VectorizedN<{{compute_t}}, ROWS*COLS> vc;
+        auto loadc = [&](auto i) {
+            if constexpr (accum) {
+                constexpr int row = i / COLS;
+                constexpr int col = i % COLS;
+                vc[i] = Vectorized::loadu(C + row * ldc + col * VLEN);
+            } else {
+                vc[i] = Vectorized(0.0f);
+            }
+        };
+        c10::ForcedUnroll{}(loadc);
+        auto unroll_loadB = [&](auto i, const {{input2_t}}* {{restrict_keyword}} src_ptr) {
+    {%- if input2_dtype in [torch.bfloat16, torch.float16] %}
+            auto b = VectorizedIn::loadu(src_ptr + i * ldb, VLEN);
+            vb[i] = at::vec::convert<{{compute_t}}>(b);
+    {%- elif input2_dtype == torch.int8 %}
+            auto b32 = at::vec::convert_to_int32(src_ptr + i * ldb, VLEN);
+            vb[i] = at::vec::convert(b32);
+    {%- else %}
+            vb[i] = Vectorized::loadu(src_ptr + i * ldb, VLEN);
+    {%- endif %}
+        };
+        auto compute_trans = [&, COLS](auto i, int k) {
+            constexpr int row = i % ROWS;
+            constexpr int col = i / ROWS;
+            constexpr int e_col = col * VLEN;
+            int idk = k * VLEN;
+            if constexpr (row == 0) {
+                c10::ForcedUnroll{}(unroll_loadB, B + e_col * ldb + idk);
+                at::vec::transpose_block(vb);
+            }
+            constexpr int idx = row * COLS + col;
+            {{kernel.unroll_pragma(16)}}
+            for (int j = 0; j < VLEN; j++) {
+    {%- if alpha != 1 %}
+                va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + idk + j]) * {{alpha}});
+    {%- else %}
+                va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + idk + j]));
+    {%- endif %}
+                vc[idx] = at::vec::fmadd(va, vb[j], vc[idx]);
+            }
+        };
+        for (int k = 0; k < _K; ++k) {
+            c10::ForcedUnroll{}(compute_trans, k);
+        }
+        // store to C
+        auto storec = [&](auto i) {
+            constexpr int row = i / COLS;
+            constexpr int col = i % COLS;
+            vc[i].store(C + row * ldc + col * VLEN);
+        };
+        c10::ForcedUnroll{}(storec);
+    } else {
+        // Second implementation
+    {%- if input2_dtype in [torch.bfloat16, torch.float16] %}
+        constexpr auto VLEN = VectorizedIn::size();
+    {%- else %}
+        constexpr auto VLEN = Vectorized::size();
+    {%- endif %}
+        int _K = (K + VLEN - 1) / VLEN;
+        // sub-block size of BLOCK_N and BLOCK_M
+        constexpr int sM = {{sub_block_m}};
+        constexpr int sN = {{sub_block_n}};
+    {%- if tail_n %}
+        int bN = (N + sN - 1) / sN;
+    {%- else %}
+        constexpr int bN = (BLOCK_N + sN - 1) / sN;
+    {%- endif %}
+        constexpr int bM = (BLOCK_M + sM - 1) / sM;
+
+    {%- if input2_dtype in [torch.bfloat16, torch.float16] %}
+        at::vec::VectorizedN<{{compute_t}}, 2> va;
+        at::vec::VectorizedN<{{compute_t}}, 2 * sN> vb;
+    {%- else %}
+        at::vec::Vectorized<{{compute_t}}> va;
+        at::vec::VectorizedN<{{compute_t}}, sN> vb;
+    {%- endif %}
+        at::vec::VectorizedN<{{compute_t}}, sN * sM> vmid;
+
+    {%- if tail_n %}
+        int ntail = N % sN;
+    {%- else %}
+        constexpr int ntail = BLOCK_N % sN;
+    {%- endif %}
+        constexpr int mtail = BLOCK_M % sM;
+        int ktail = K % VLEN;
+
+        auto compute_trans = [&](int m, int n, int k) {
+    {%- if tail_n %}
+            int e_n = (n == bN - 1 && ntail != 0) ? (N - n * sN) : sN;
+    {%- else %}
+            int e_n = (n == bN - 1 && ntail != 0) ? (BLOCK_N - n * sN) : sN;
+    {%- endif %}
+            int e_m = (m == bM - 1 && mtail != 0) ? (BLOCK_M - m * sM) : sM;
+            int e_k = (k == _K - 1 && ktail != 0) ? (K - k * VLEN) : VLEN;
+            {{kernel.unroll_pragma(sub_block_n)}}
+            for (int i = 0; i < e_n; i++) {
+    {%- if input2_dtype in [torch.bfloat16, torch.float16] %}
+                auto b = VectorizedIn::loadu(B + (sN * n + i) * ldb + k * VLEN, e_k);
+                std::tie(vb[2 * i], vb[2 * i + 1]) = at::vec::convert_to_float<{{input_t}}>(b);
+    {%- elif input2_dtype == torch.int8 %}
+                auto b32 = at::vec::convert_to_int32(B + (sN * n + i) * ldb + k * VLEN, e_k);
+                vb[i] = at::vec::convert(b32);
+    {%- else %}
+                vb[i] = Vectorized::loadu(B + (sN * n + i) * ldb + k * VLEN, e_k);
+    {%- endif %}
+            }
+
+            {{kernel.unroll_pragma(sub_block_m)}}
+            for (int s = 0; s < e_m; s++) {
+    {%- if input2_dtype in [torch.bfloat16, torch.float16] %}
+                auto a = VectorizedIn::loadu(A + (sM * m + s) * lda + k * VLEN, e_k);
+                std::tie(va[0], va[1]) = at::vec::convert_to_float<{{input_t}}>(a);
+    {%- elif input2_dtype == torch.int8 %}
+                auto a32 = at::vec::convert_to_int32(A + (sM * m + s) * lda + k * VLEN, e_k);
+                va = at::vec::convert(a32);
+    {%- else %}
+                va = Vectorized::loadu(A + (sM * m + s) * lda + k * VLEN, e_k);
+    {%- endif %}
+
+    {%- if alpha != 1 %}
+                va = va * Vectorized({{alpha}});
+    {%- endif %}
+                if (k == 0) {
+                    {{kernel.unroll_pragma(sub_block_n)}}
+                    for (int i = 0; i < e_n; i++) {
+    {%- if input2_dtype in [torch.bfloat16, torch.float16] %}
+                        vmid[sN * s + i] = at::vec::fmadd(va[0], vb[2 * i], Vectorized(0.0f));
+                        vmid[sN * s + i] = at::vec::fmadd(va[1], vb[2 * i + 1], vmid[sN * s + i]);
+    {%- else %}
+                        vmid[sN * s + i] = at::vec::fmadd(va, vb[i], Vectorized(0.0f));
+    {%- endif %}
+                    }
+                } else {
+                    {{kernel.unroll_pragma(sub_block_n)}}
+                    for (int i = 0; i < e_n; i++) {
+    {%- if input2_dtype in [torch.bfloat16, torch.float16] %}
+                        vmid[sN * s + i] = at::vec::fmadd(va[0], vb[2 * i], vmid[sN * s + i]);
+                        vmid[sN * s + i] = at::vec::fmadd(va[1], vb[2 * i + 1], vmid[sN * s + i]);
+    {%- else %}
+                        vmid[sN * s + i] = at::vec::fmadd(va, vb[i], vmid[sN * s + i]);
+    {%- endif %}
+                    }
+                }
+            }
+
+            // store to C
+            if (k == _K - 1) {
+                {{kernel.unroll_pragma(sub_block_m)}}
+                for (int s = 0; s < e_m; s++) {
+                    {{kernel.unroll_pragma(sub_block_n)}}
+                    for (int i = 0; i < e_n; i++) {
+                        auto v = at::vec::vec_reduce_all([](Vectorized& x, Vectorized& y) { return x + y; }, vmid[sN * s + i]);
+                        if constexpr (accum) {
+                            auto c = *(C + (sM * m + s) * ldc + sN * n + i);
+                            *(C + (sM * m + s) * ldc + sN * n + i) = c + v;
+                        } else {
+                            *(C + (sM * m + s) * ldc + sN * n + i) = v;
+                        }
+                    }
+                }
+            }
+        };
+
+        for (int n = 0; n < bN; ++n) {
+            for (int m = 0; m < bM; ++m) {
+                for (int k = 0; k < _K; ++k) {
+                    compute_trans(m, n, k);
+                }
+            }
+        }
+    }
+{%- endif %}
+}
+"""
+
+    # set trans_b to generate gemm that supports transposed B matrix
+    # set tail_n to support the tail of N
+    # TODO add trans_b support for other micro gemms
+    # and move setting of trans_b to the init of CppMicroGemm
+    def __init__(
+        self,
+        name,
+        input_dtype,
+        input2_dtype,
+        output_dtype,
+        compute_dtype,
+        register_blocking,
+        alpha=1,
+        tail_n=False,
+        trans_b=False,
+    ) -> None:
+        super().__init__(
+            name,
+            input_dtype,
+            input2_dtype,
+            output_dtype,
+            compute_dtype,
+            register_blocking,
+            alpha,
+        )
+        self.tail_n = tail_n
+        # trans_b is only supported on platforms that
+        # support avx512 or avx2 since transpose_block is
+        # only implemented on these platforms
+        if trans_b:
+            vec_isa = pick_vec_isa()
+            assert issubclass(vec_isa.__class__, VecAVX512) or issubclass(
+                vec_isa.__class__, VecAVX2
+            )
+        self.trans_b = trans_b
+
+    def codegen_define(self, kernel: CppTemplateKernel) -> str:
+        options = {
+            "declare_kernel": self.get_kernel_declaration(),
+            "kernel": kernel,
+            "block_m": self.register_blocking.block_m,
+            "block_n": self.register_blocking.block_n,
+            "block_k": self.register_blocking.block_k,
+            "trans_b": False,
+            "tail_n": False,
+            "restrict_keyword": get_restrict_keyword(),
+            **self.get_common_options(),
+        }
+        if self.trans_b:
+            # TODO supports tuning of sub_block_m/sub_block_n
+            # to get better performance for specific shapes
+            sub_block_m = min(1, self.register_blocking.block_m)
+            sub_block_n = min(4, self.register_blocking.block_n)
+            # update options to generate kernel with trans_b and sub-block size
+            options.update(
+                {
+                    "trans_b": self.trans_b,
+                    "sub_block_m": sub_block_m,
+                    "sub_block_n": sub_block_n,
+                }
+            )
+        result = KernelTemplate._template_from_string(self.TEMPLATE_KERNEL).render(
+            options
+        )
+        # update options to generate the kernel for the tail of N
+        if self.tail_n:
+            options.update(
+                {
+                    "tail_n": self.tail_n,
+                }
+            )
+            result += KernelTemplate._template_from_string(self.TEMPLATE_KERNEL).render(
+                options
+            )
+        result += KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render(
+            options
+        )
+        return result
+
+
+def check_vnni_extra(config, m, n, k, alpha, num_threads, **kwargs):
+    assert config.input_dtype == torch.uint8 and config.input2_dtype == torch.int8
+    vnni_size = 4
+    return k % vnni_size == 0
+
+
+@register_micro_gemm(
+    *generate_gemm_config(
+        VecAVX512VNNI,
+        # (block_m, block_n, block_k)
+        [(6, 64, 4)],
+        input_dtype=torch.uint8,
+        input2_dtype=torch.int8,
+        output_dtype=torch.int32,
+        compute_dtype=torch.int32,
+        extra_check=check_vnni_extra,
+    ),
+)
+class CppMicroGemmAVX512VNNI(CppMicroGemm):
+    """
+    This class generates the code for micro gemm using AVX512 VNNI instructions for compute.
+    It supports u8s8s32 GEMM only.
+    AVX512_VNNI ISA has been available since the 3rd gen of Intel Xeon.
+    """
+
+    TEMPLATE_ENTRY = r"""
+{{declare_kernel}} {
+    {{kernel.assert_function}}(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}");
+    {{kernel.assert_function}}(K % {{vnni_size}} == 0, "K dimension must be multiple of {{vnni_size}}");
+    constexpr int64_t M_BLOCK = {{block_m}};
+    const int64_t M_TAIL = M % M_BLOCK;
+    const int64_t M_MAIN = M - M_TAIL;
+    for (int64_t m = 0; m < M_MAIN; m += M_BLOCK) {
+        for (int64_t n = 0; n < N; n += {{block_n}}) {
+            {{kernel_name}}_kernel(
+                A + m * lda,
+                B + n,
+                C + m * ldc + n,
+                K,
+                lda,
+                ldb,
+                ldc
+            );
+        }
+    }
+    if (M_TAIL > 0) {
+        switch (M_TAIL) {
+{%- for m_tail in range(block_m - 1, 0, -1) %}
+            case ({{m_tail}}):
+                for (int64_t n = 0; n < N; n += {{block_n}}) {
+                    {{kernel_name}}_kernel<{{m_tail}}, {{block_n}}, accum>(
+                        A + M_MAIN * lda,
+                        B + n,
+                        C + M_MAIN * ldc + n,
+                        K,
+                        lda,
+                        ldb,
+                        ldc
+                    );
+                }
+                break;
+{%- endfor %}
+            default:
+                {{kernel.assert_function}}(false, "Unsupported M_TAIL: {}", M_TAIL);
+        } // switch M_TAIL
+    } // if M_TAIL
+}
+"""
+
+    TEMPLATE_KERNEL = r"""
+template 
+inline void {{kernel_name}}_kernel(
+    const {{input_t}}* {{restrict_keyword}} A,
+    const {{input2_t}}* {{restrict_keyword}} B,
+    {{output_t}}* {{restrict_keyword}} C,
+    int64_t K,
+    int64_t lda,
+    int64_t ldb,
+    int64_t ldc
+) {
+    constexpr const int COLS = N / {{vec_len}};
+    __m512i va;
+    __m512i vb[COLS];
+    __m512i vc[M * COLS];
+
+    c10::ForcedUnroll{}([&](auto i) { vc[i] = _mm512_setzero_epi32(); });
+
+    auto compute = [&](auto i, int k) {
+        constexpr const int row = i / COLS;
+        constexpr const int col = i % COLS;
+
+        if constexpr (col == 0) {
+            va = _mm512_set1_epi32(*(int32_t*)(A + row * lda + k));
+        }
+
+        if constexpr (row == 0) {
+            // B block in VNNI layout: [K / {{vnni_size}}, N, {{vnni_size}}]
+            int64_t offset = k * ldb + col * {{vec_len}} * {{vnni_size}};
+            vb[col] = _mm512_loadu_si512((__m512i const*)(B + offset));
+        }
+        vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]);
+    };
+
+    // Accumulate along k
+    constexpr const int k_unroll = 2;
+    int k = 0;
+    int k_limit = K / {{vnni_size}} / k_unroll;
+    for (; k < k_limit; k++) {
+        c10::ForcedUnroll{}(
+            [&](auto i) {
+                c10::ForcedUnroll{}(compute, {{vnni_size}} * (k * k_unroll + i));
+            }
+        );
+    }
+    k *= {{vnni_size}} * k_unroll;
+    for (; k < K; k += {{vnni_size}}) {
+        c10::ForcedUnroll{}(compute, k);
+    }
+
+    // Store to C
+    auto store_c = [&](auto i) {
+        constexpr const int row = i / COLS;
+        constexpr const int col = i % COLS;
+        if constexpr (accum) {
+            __m512i vc_old = _mm512_loadu_si512((__m512i const*)(C + row * ldc + col * {{vec_len}}));
+            vc[i] = _mm512_add_epi32(vc[i], vc_old);
+        }
+        _mm512_storeu_si512((__m512i*)(C + row * ldc + col * {{vec_len}}), vc[i]);
+    };
+    c10::ForcedUnroll{}(store_c);
+}
+"""
+
+    def __init__(
+        self,
+        name,
+        input_dtype,
+        input2_dtype,
+        output_dtype,
+        compute_dtype,
+        register_blocking,
+        alpha=1,
+    ) -> None:
+        super().__init__(
+            name,
+            input_dtype,
+            input2_dtype,
+            output_dtype,
+            compute_dtype,
+            register_blocking,
+            alpha,
+        )
+        assert input_dtype == torch.uint8 and input2_dtype == torch.int8, (
+            f"Only u8s8s32 GEMM is supported by AVX512VNNI microkernel, got A:{input_dtype}, B:{input2_dtype}, C:{output_dtype}."
+        )
+
+    def codegen_define(self, kernel: CppTemplateKernel) -> str:
+        options = {
+            "declare_kernel": self.get_kernel_declaration(),
+            "kernel": kernel,
+            "block_m": self.register_blocking.block_m,
+            "block_n": self.register_blocking.block_n,
+            "block_k": self.register_blocking.block_k,
+            "restrict_keyword": get_restrict_keyword(),
+            "vec_len": 16,  # = 512 / 32 for C
+            **self.get_common_options(),
+        }
+        return KernelTemplate._template_from_string(self.TEMPLATE_KERNEL).render(
+            options
+        ) + KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render(options)
+
+    def get_b_layout(self):
+        return LayoutType.VNNI4
+
+
+# extra check for CppMicroGemmAMX
+def check_amx_extra(config, m, n, k, alpha, num_threads, **kwargs):
+    vnni_size = 4 if config.input_dtype in [torch.uint8, torch.int8] else 2
+    return k % vnni_size == 0 and alpha == 1
+
+
+def check_int8_bf16_amx_extra(config, m, n, k, alpha, num_threads, **kwargs):
+    # We need avx512_bf16 to dequant int8 to bf16
+    vec_isa = kwargs.get("vec_isa")
+    assert vec_isa is not None
+    return vec_isa.is_avx512_bf16_supported() and check_amx_extra(
+        config, m, n, k, alpha, num_threads, **kwargs
+    )
+
+
+# amx_fp16 need to be checked separately since it is not always supported when amx is supported
+def check_amx_fp16_extra(config, m, n, k, alpha, num_threads, **kwargs):
+    assert config.input_dtype == torch.float16 and config.output_dtype == torch.float
+    vec_isa = kwargs.get("vec_isa")
+    assert vec_isa is not None
+    vnni_size = 2
+    return vec_isa.is_amx_fp16_supported() and k % vnni_size == 0 and alpha == 1
+
+
+@register_micro_gemm(
+    *generate_gemm_config(
+        VecAMX,
+        [(32, 32, 64), (48, 16, 64)],
+        input_dtype=torch.int8,
+        input2_dtype=torch.int8,
+        output_dtype=torch.int32,
+        compute_dtype=torch.int32,
+        extra_check=check_amx_extra,
+    ),
+    *generate_gemm_config(
+        VecAMX,
+        [(32, 32, 32), (48, 16, 32)],
+        input_dtype=torch.bfloat16,
+        input2_dtype=torch.int8,
+        output_dtype=torch.float,
+        compute_dtype=torch.float,
+        extra_check=check_int8_bf16_amx_extra,
+    ),
+    *generate_gemm_config(
+        VecAMX,
+        [(32, 16, 32), (32, 32, 32), (48, 16, 32), (16, 48, 32)],
+        input_dtype=torch.bfloat16,
+        output_dtype=torch.float,
+        extra_check=check_amx_extra,
+    ),
+    *generate_gemm_config(
+        VecAMX,
+        [(32, 32, 32), (48, 16, 32), (16, 48, 32)],
+        input_dtype=torch.float16,
+        output_dtype=torch.float,
+        extra_check=check_amx_fp16_extra,
+    ),
+    *generate_gemm_config(
+        VecAMX,
+        [(32, 32, 64), (48, 16, 64)],
+        input_dtype=torch.uint8,
+        input2_dtype=torch.int8,
+        output_dtype=torch.int32,
+        compute_dtype=torch.int32,
+        extra_check=check_amx_extra,
+    ),
+)
+class CppMicroGemmAMX(CppMicroGemm):
+    """
+    This class generates the code for micro gemm using Advanced Matrix extension (AMX)
+    instructions available in 4th generation Intel Xeon for compute.
+    It supports input types of torch.bfloat16 with fp32 output.
+    """
+
+    TEMPLATE_ENTRY = r"""
+{{declare_kernel}} {
+    {{kernel.assert_function}}(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}");
+    {{kernel.assert_function}}(K % 2 == 0, "K dimension must be multiple of 2");
+{%- if pack_vnni_B_locally %}
+    {{template.codegen_allocate_weight_buffer("packed_B_buf", input2_t, "K", block_n)}}
+{%- endif %}
+{%- if use_cached_dequantized_B %}
+    // Create a stack-allocated buffer for tiles of B.
+    // Except maybe for the tail-case, an AMX tile of B has 16x32 BF16 elements.
+    // we cache K * {{block_n}} elements of dequantized B
+    {{template.codegen_allocate_weight_buffer("dequantized_B_buf", input_t, "K", block_n)}}
+    const auto buf_size = K * {{block_n}};
+    auto load_dequantized_B = [&](int base_idx) {
+        // Load a tile of B & cache it in L1D.
+        {{input2_t}}* base_addr = const_cast<{{input2_t}}*>(B) + base_idx;
+        for (int idx_dq = 0, idx_q = 0; idx_dq < buf_size; idx_q += ldb, idx_dq += {{block_n}}) {
+        {%- for vec_idx in range(0, block_n, 32) %}
+            _mm_prefetch(base_addr + idx_q + 64 * ldb, _MM_HINT_T0);
+            {%- if (block_n - vec_idx) >= 32 %}
+            // 1) Load 32 x int8
+            __m256i v8  = _mm256_loadu_si256((const __m256i*)(base_addr + idx_q + {{vec_idx}}));
+            // 2) Extract two halves
+            __m128i v8_lo = _mm256_extracti128_si256(v8, 0);
+            __m128i v8_hi = _mm256_extracti128_si256(v8, 1);
+            // 3) Widen each half to i32
+            __m512i v32_lo = _mm512_cvtepi8_epi32(v8_lo);
+            __m512i v32_hi = _mm512_cvtepi8_epi32(v8_hi);
+            // 4) Convert to f32
+            __m512 f_lo = _mm512_cvtepi32_ps(v32_lo);
+            __m512 f_hi = _mm512_cvtepi32_ps(v32_hi);
+            // 5) f32 -> bf16 (round-to-nearest-even) and pack 32 lanes to 512b
+            // Packs the second operand (f_lo) into the lower 16 bf16 lanes and the first (f_hi) into the upper 16.
+            __m512i bf = (__m512i)_mm512_cvtne2ps_pbh(f_hi, f_lo);
+            // 6) Store 32 x bf16 (512 bits)
+            _mm512_storeu_si512((__m512i*)(dequantized_B_buf + idx_dq + {{vec_idx}}), bf);
+            {%- elif (block_n - vec_idx) >= 16 %}
+            // 1) Load 16 x int8 (128 bits)
+            __m128i v8 = _mm_loadu_si128((const __m128i*)(base_addr + idx_q + {{vec_idx}}));
+            // 2) Widen: 16 x i8 -> 16 x i32
+            __m512i v32 = _mm512_cvtepi8_epi32(v8);
+            // 3) Convert to f32
+            __m512 f32 = _mm512_cvtepi32_ps(v32);
+            // 4) Convert f32 -> bf16 (round-to-nearest-even)
+            __m256i bf16 = (__m256i)_mm512_cvtneps_pbh(f32);
+            // 5) Store 16 x bf16 (256 bits)
+            _mm256_storeu_si256((__m256i*)(dequantized_B_buf + idx_dq + {{vec_idx}}), bf16);
+            {%- else %}
+            auto b_int8_tail = at::vec::Vectorized::loadu(
+                base_addr + idx_q + {{block_n - (block_n % 32)}},
+                static_cast({{block_n % 32}})
+            );
+            auto b_bf16_tail = at::vec::convert<{{input_t}}>(b_int8_tail);
+            b_bf16_tail.store(
+                dequantized_B_buf + idx_dq + {{block_n - (block_n % 32)}},
+                static_cast({{block_n % 32}})
+            );
+            {%- endif %}
+        {%- endfor %}
+        }
+    };
+{%- endif %}
+// The ldb would not be block_n if N != block_n
+{%- if use_cached_dequantized_B or pack_vnni_B_locally %}
+    const int64_t updated_ldb = {{block_n}};
+{%- else %}
+    const int64_t updated_ldb = ldb;
+{%- endif %}
+    // TODO(jgong5): loop unroll for M and N
+    for (int64_t n = 0; n < N; n += {{block_n}}) {
+{%- if pack_vnni_B_locally %}
+        // Pack non-constant weights into VNNI interleaved format in packed_B_buf
+        at::vec::pack_vnni2(B + n, packed_B_buf, ldb, K, {{block_n}});
+{%- elif use_cached_dequantized_B %}
+        // Dequantize K * block_n int8 B elements into BF16
+        load_dequantized_B(n);
+{%- endif %}
+        for (int64_t m = 0; m < M; m += {{block_m}}) {
+            int64_t block_m = std::min(M - m, {{block_m}});
+            int64_t m_tail = m;
+{%- for num_rows in range(block_m, 0, -16) %}
+    {%- if num_rows != block_m %}
+            else
+    {%- endif %}
+            if (block_m >= {{num_rows}}) {
+                {{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}(
+                    amx_state,
+                    A + m * lda,
+{%- if use_cached_dequantized_B %}
+                    dequantized_B_buf,
+{%- elif pack_vnni_B_locally %}
+                    packed_B_buf,
+{%- else %}
+                    B + n,
+{%- endif %}
+                    C + m * ldc + n,
+                    K,
+                    lda,
+                    updated_ldb,
+                    ldc,
+                    16
+                );
+                block_m -= {{num_rows}};
+                m_tail += {{num_rows}};
+            }
+{%- endfor %}
+            if (block_m > 0) {
+                {{kernel_name}}_amx_kernel_16_{{num_columns}}(
+                    amx_state,
+                    A + m_tail * lda,
+{%- if use_cached_dequantized_B %}
+                    dequantized_B_buf,
+{%- elif pack_vnni_B_locally %}
+                    packed_B_buf,
+{%- else %}
+                    B + n,
+{%- endif %}
+                    C + m_tail * ldc + n,
+                    K,
+                    lda,
+                    updated_ldb,
+                    ldc,
+                    block_m
+                );
+            }
+        }
+    }
+}
+"""
+
+    TEMPLATE_KERNEL = r"""
+
+template 
+inline void {{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}(
+    AMXState& amx_state,
+    const {{input_t}}* {{restrict_keyword}} A,
+{%- if use_cached_dequantized_B %}
+    const {{input_t}}* {{restrict_keyword}} B,
+{%- else %}
+    const {{input2_t}}* {{restrict_keyword}} B,
+{%- endif %}
+    {{output_t}}* {{restrict_keyword}} C,
+    int64_t K,
+    int64_t lda,
+    int64_t ldb,
+    int64_t ldc,
+    uint8_t tilecfg_rows
+) {
+    // TODO(jgong5): add prefetch hint for A, B, C
+    auto loadconfig = [](const amx_tilecfg& cfg) {
+        _tile_loadconfig(&cfg);
+    };
+    const auto last_k_offset = K / {{block_k}} * {{block_k}};
+    const auto tail_k_size = K - last_k_offset;
+    if C10_LIKELY (last_k_offset > 0) {
+        amx_state.configure(tilecfg_rows, 64, {{num_rows}} / 16, {{num_columns}}, loadconfig);
+    } else {
+        amx_state.configure(tilecfg_rows, tail_k_size * sizeof({{input_t}}), {{num_rows}} / 16, {{num_columns}}, loadconfig);
+    }
+    auto load_c = [&]() {
+{%- for tile_row in range(num_rows // 16) %}
+    {%- for tile_col in range(num_columns) %}
+        {%- set tile_idx = tile_row * num_columns + tile_col %}
+        _tile_loadd({{tile_idx}}, C + {{tile_row * 16}} * ldc + {{tile_col * 16}}, ldc * sizeof({{output_t}}));
+    {%- endfor %}
+{%- endfor %}
+    };
+    auto zero_c = [&]() {
+{%- for tile_row in range(num_rows // 16) %}
+    {%- for tile_col in range(num_columns) %}
+        {%- set tile_idx = tile_row * num_columns + tile_col %}
+        _tile_zero({{tile_idx}});
+    {%- endfor %}
+{%- endfor %}
+    };
+
+    if constexpr (accum) {
+        load_c();
+    } else {
+        zero_c();
+    }
+
+    auto compute = [&](int k) {
+{%- set tile_offset_a = num_rows // 16 * num_columns %}
+{%- set tile_offset_b = tile_offset_a + num_rows // 16 %}
+{%- for tile_row in range(num_rows // 16) %}
+    {%- for tile_col in range(num_columns) %}
+        {%- set tile_idx_a = tile_offset_a + tile_row %}
+        {%- set tile_idx_b = tile_offset_b + tile_col %}
+        {%- set tile_idx_c = tile_row * num_columns + tile_col %}
+        {%- if tile_col == 0 %}
+        _tile_stream_loadd({{tile_idx_a}}, A + {{tile_row * 16}} * lda + k, lda * sizeof({{input_t}}));
+        {%- endif %}
+        {%- if tile_row == 0 %}
+        _tile_loadd({{tile_idx_b}}, B + k * ldb + {{tile_col * 16 * vnni_size}}, ldb * {{vnni_size}} * sizeof({{input_t}}));
+        {%- endif %}
+        {%- if int8_gemm %}
+            {%- if input_dtype == torch.int8 %}
+        _tile_dpbssd({{tile_idx_c}}, {{tile_idx_a}}, {{tile_idx_b}});
+            {%- else %}
+        _tile_dpbusd({{tile_idx_c}}, {{tile_idx_a}}, {{tile_idx_b}});
+            {%- endif %}
+        {%- else %}
+            {%- if input_dtype == torch.float16 %}
+        _tile_dpfp16ps({{tile_idx_c}}, {{tile_idx_a}}, {{tile_idx_b}});
+            {%- else %}
+        _tile_dpbf16ps({{tile_idx_c}}, {{tile_idx_a}}, {{tile_idx_b}});
+            {%- endif %}
+        {%- endif %}
+    {%- endfor %}
+{%- endfor %}
+    };
+
+    {{kernel.unroll_pragma(4)}}
+    for (int k = 0; k < last_k_offset; k += {{block_k}}) {
+        compute(k);
+    }
+
+    auto store_c = [&]() {
+    // store to C
+{%- for tile_row in range(num_rows // 16) %}
+    {%- for tile_col in range(num_columns) %}
+        {%- set tile_idx = tile_row * num_columns + tile_col %}
+        _tile_stored({{tile_idx}}, C + {{tile_row * 16}} * ldc + {{tile_col * 16}}, ldc * sizeof({{output_t}}));
+    {%- endfor %}
+{%- endfor %}
+    };
+
+    // TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead
+    if C10_UNLIKELY (tail_k_size > 0) {
+        if C10_LIKELY (last_k_offset > 0) {
+            store_c();
+            amx_state.configure(tilecfg_rows, tail_k_size * sizeof({{input_t}}), {{num_rows}} / 16, {{num_columns}}, loadconfig);
+            load_c();
+        }
+        compute(last_k_offset);
+    }
+
+    store_c();
+}
+"""
+
+    def codegen_define(self, kernel: CppTemplateKernel) -> str:
+        block_m, block_n, block_k = self.register_blocking
+        assert block_m % 16 == 0, "Only support block_m % 16 == 0 for AMX"
+        assert block_n % 16 == 0, "Only support block_n % 16 == 0 for AMX"
+        if self.input_dtype in [torch.uint8, torch.int8]:
+            assert block_k == 64, "Only support block_k = 64 for AMX INT8"
+        else:
+            assert block_k == 32, "Only support block_k = 32 for AMX Bfloat16/Float16"
+        num_columns = block_n // 16
+        options = {
+            "declare_kernel": self.get_kernel_declaration(),
+            "use_cached_dequantized_B": (
+                self.input_dtype == torch.bfloat16
+                and self.input2_dtype in [torch.int8, torch.uint8]
+            ),
+            "kernel": kernel,
+            "block_m": block_m,
+            "block_n": block_n,
+            "block_k": block_k,
+            "num_columns": num_columns,
+            "restrict_keyword": get_restrict_keyword(),
+            **self.get_common_options(),
+        }
+        result = ""
+        for num_rows in range(block_m, 0, -16):
+            amx_kernel_options = {**options, "num_rows": num_rows}
+            result += KernelTemplate._template_from_string(self.TEMPLATE_KERNEL).render(
+                amx_kernel_options
+            )
+        result += KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render(
+            options
+        )
+        return result
+
+    def codegen_init(
+        self,
+        kernel: CppTemplateKernel,
+    ) -> str:
+        return "AMXState amx_state;"
+
+    def codegen_finalize(
+        self,
+        kernel: CppTemplateKernel,
+    ) -> str:
+        return "amx_state.release([]() { _tile_release(); });"
+
+    def get_kernel_extra_args_declare(self) -> str:
+        return "AMXState& amx_state,"
+
+    def get_kernel_extra_args(self, **kwargs) -> list[str]:
+        return ["amx_state,"]
+
+    def get_b_layout(self):
+        if self.input_dtype in [torch.uint8, torch.int8]:
+            return LayoutType.VNNI4
+        else:
+            return LayoutType.VNNI2
+
+
+# extra check for CppMicroBrgemm
+def check_brgemm_extra(config, m, n, k, alpha, num_threads, **kwargs):
+    assert config.input_dtype == torch.half and config.output_dtype == torch.float
+    vnni_size = 2
+    # use brgemm for Half when amx_fp16 is supported
+    return torch.cpu._is_amx_fp16_supported() and k % vnni_size == 0 and alpha == 1
+
+
+@register_micro_gemm(
+    *generate_gemm_config(
+        VecAMX,
+        [(32, 32, 32), (48, 16, 32), (16, 48, 32)],
+        input_dtype=torch.half,
+        output_dtype=torch.float,
+        extra_check=check_brgemm_extra,
+    ),
+)
+class CppMicroBrgemm(CppMicroGemm):
+    """
+    This class generates the code for micro gemm using oneDNN brgemm.
+    It supports input types of torch.half.
+    """
+
+    TEMPLATE_ENTRY = r"""
+#include 
+{{declare_kernel}} {
+{%- if pack_vnni_B_locally %}
+    {{template.codegen_allocate_weight_buffer("packed_B_buf", input2_t, "K * N")}}
+    at::vec::pack_vnni2(B, packed_B_buf, ldb, K, N);
+{%- endif %}
+    at::native::cpublas::brgemm(
+      M, N, K,
+    {%- if pack_vnni_B_locally %}
+      lda, N, ldc,
+    {%- else %}
+      lda, ldb, ldc,
+    {%- endif %}
+      accum,
+      A,
+    {%- if pack_vnni_B_locally %}
+      packed_B_buf,
+    {%- else %}
+      B,
+    {%- endif %}
+      C);
+}
+"""
+
+    def codegen_define(self, kernel: CppTemplateKernel) -> str:
+        options = {
+            "declare_kernel": self.get_kernel_declaration(),
+            "kernel": kernel,
+            "block_m": self.register_blocking.block_m,
+            "block_n": self.register_blocking.block_n,
+            "block_k": self.register_blocking.block_k,
+            "restrict_keyword": get_restrict_keyword(),
+            **self.get_common_options(),
+        }
+        result = ""
+        result += KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render(
+            options
+        )
+        return result
+
+    def codegen_finalize(
+        self,
+        kernel: CppTemplateKernel,
+    ) -> str:
+        return "at::native::cpublas::brgemm_release();"
+
+    def get_b_layout(self):
+        assert self.input_dtype == torch.half and torch.cpu._is_amx_fp16_supported()
+        return LayoutType.VNNI2
+
+
+def check_woq_int4_extra(config, m, n, k, alpha, num_threads, **kwargs):
+    if alpha != 1:
+        return False
+    q_group_size = kwargs.get("q_group_size")
+    assert q_group_size is not None
+    if (
+        q_group_size not in [32, 64, 128]
+        or k % q_group_size != 0
+        or config.register_blocking.block_k > q_group_size
+    ):
+        return False
+    return k % config.register_blocking.block_k == 0 and n % 64 == 0
+
+
+@register_micro_gemm(
+    # TODO: support float/half input
+    *generate_gemm_config(
+        VecAVX512,
+        [(4, 64, 32), (4, 64, 64), (4, 64, 128)],
+        input_dtype=torch.bfloat16,
+        input2_dtype=torch.uint8,
+        output_dtype=torch.float,
+        compute_dtype=torch.float,
+        extra_check=check_woq_int4_extra,
+    ),
+)
+class CppMicroGemmWoQInt4Avx512(CppMicroGemmFP32Vec):
+    """
+    This class generates the code for WoQ int4 micro gemm using AVX512 intrinsics.
+    It is based on the corresponding ATen kernel.
+    Shape of packed weight = [N // 64, K, 32], viewed as [N, K // 2]
+    Shape of packed ScalesAndZeros = [K // group_size, N, 2]
+    """
+
+    TEMPLATE_ENTRY = r"""
+{{declare_kernel}} {
+    {{kernel.assert_function}}(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}");
+    {{kernel.assert_function}}(K % {{block_k}} == 0, "K dimension must be multiple of {{block_k}}");
+    auto group_size = q_group_size;
+    for (int64_t m = 0; m < M; m += {{block_m}}) {
+        int64_t block_m = std::min(M - m, {{block_m}});
+        for (int64_t n = 0; n < N; n += {{block_n}}) {
+            if (block_m == {{block_m}}) {
+                {{kernel_name}}_kernel<{{block_m}}, {{block_n}}, accum>(
+                    A + m * lda,
+                    reinterpret_cast(B) + n * ldb,
+                    C + m * ldc + n,
+                    K,
+                    lda,
+                    /* ldb */ {{block_n}} / 2,
+                    ldc,
+                    group_size,
+                    ScaleAndZeros + n * 2,
+                    lds,
+                    k_start
+                );
+            } else {
+                switch (block_m) {
+                {%- for b in range(block_m - 1, 0, -1) %}
+                case {{b}}:
+                    {{kernel_name}}_kernel<{{b}}, {{block_n}}, accum>(
+                        A + m * lda,
+                        reinterpret_cast(B) + n * ldb,
+                        C + m * ldc + n,
+                        K,
+                        lda,
+                        /* ldb */ {{block_n}} / 2,
+                        ldc,
+                        group_size,
+                        ScaleAndZeros + n * 2,
+                        lds,
+                        k_start
+                    );
+                    break;
+                {%- endfor %}
+                default:
+                    {{kernel.assert_function}}(false, "Unsupported block_m: ", block_m);
+                }
+            }
+        }
+    }
+}
+"""
+
+    TEMPLATE_KERNEL = r"""
+inline bool {{kernel_name}}_is_block_start(int index, int k_start, int group_size) {
+  return (k_start + index) % group_size == 0;
+}
+
+inline __m128i {{kernel_name}}_convert_int4_to_int8(const uint8_t* data) {
+  __m128i tmp = _mm_loadu_si64((const __m128i*)data);
+  __m128i bytes = _mm_cvtepu8_epi16(tmp);
+  const __m128i lowMask = _mm_set1_epi8(0xF);
+  __m128i high = _mm_andnot_si128(lowMask, bytes);
+  __m128i low = _mm_and_si128(lowMask, bytes);
+  high = _mm_slli_epi16(high, 4);
+  bytes = _mm_or_si128(low, high);
+  return bytes;
+}
+
+template 
+inline void {{kernel_name}}_kernel(
+    const {{input_t}}* {{restrict_keyword}} A,
+    const uint8_t* {{restrict_keyword}} B,
+    {{output_t}}* {{restrict_keyword}} C,
+    int64_t K,
+    int64_t lda,
+    int64_t ldb,
+    int64_t ldc,
+    int64_t q_group_size,
+    const at::BFloat16* {{restrict_keyword}} ScaleAndZeros,
+    int64_t lds, // leading dimension of ScaleAndZeros
+    int64_t k_start) {
+  constexpr int BLOCK_K = {{block_k}};
+  constexpr int ROWS = BLOCK_M;
+  constexpr int COLS = BLOCK_N / 16;
+
+  const int PREFETCH_SIZE_K = 16 * 4;
+  const int PREFETCH_SIZE_KB = (PREFETCH_SIZE_K + BLOCK_K - 1) / BLOCK_K;
+
+  // number of blocks on K
+  const int KB = K / BLOCK_K;
+
+  __m512 va;
+  __m512 vb[COLS];
+  __m512 vc[ROWS * COLS];
+  __m512 scale[COLS];
+  __m512 zero[COLS];
+
+  // Lookup table to de-quantize int4 values to bf16.
+  // Values are dequantized as truly int4 [-8, 7] range;
+  //
+  // dequant = (bf16(int4_value) * bf16_scale) + bf16_zero
+  //
+  static const __m512 lut = _mm512_set_ps(
+      7.0f, 6.0f, 5.0f, 4.0f,
+      3.0f, 2.0f, 1.0f, 0.0f,
+      -1.0f, -2.0f, -3.0f, -4.0f,
+      -5.0f, -6.0f, -7.0f, -8.0f);
+
+  // index for transpose
+  static const __m512i idx1 = _mm512_set_epi32(
+      30, 28, 26, 24, 22, 20, 18, 16,
+      14, 12, 10, 8, 6, 4, 2, 0);
+  static const __m512i idx2 = _mm512_set_epi32(
+      31, 29, 27, 25, 23, 21, 19, 17,
+      15, 13, 11, 9, 7, 5, 3, 1);
+
+  // load scale and zero point
+  auto load_scale_and_zeros = [&](int i, int _kb) {
+    // load 2x bfloat16 vector
+    __m512i t = _mm512_loadu_si512((__m512i*)(ScaleAndZeros + _kb * lds + 32 * i));
+    _mm_prefetch(ScaleAndZeros + (_kb + PREFETCH_SIZE_KB) * lds + 32 * i, _MM_HINT_T0);
+
+    // convert to 2x f32 vector
+    __m512 a, b;
+    at::vec::cvtbf16_fp32(t, a, b);
+
+    // transpose scale_and_zero from {16, 2} to {2, 16}
+    // inputs:
+    //   a: {s0, z0, s1, z1, ..., s7, z7}
+    //   b: {s8, z8, s9, z9, ..., s15, z15}
+    // output:
+    //   scale: {s0, s1, s2, ..., s15}
+    //   zero:  {z0, z1, z2, ..., z15}
+    scale[i] = _mm512_mask_permutex2var_ps(a, 0xffff, idx1, b);
+    zero[i] = _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b);
+  };
+
+  auto loadc = [&](auto i) {
+    if constexpr (accum) {
+       constexpr int row = i / COLS;
+       constexpr int col = i % COLS;
+       vc[i] = _mm512_loadu_ps(C + row * ldc + col * 16);
+    } else {
+      vc[i] = _mm512_setzero_ps();
+    }
+  };
+  c10::ForcedUnroll{}(loadc);
+
+  auto compute = [&, COLS](auto i, int k) {
+    constexpr  int row = i / COLS;
+    constexpr  int col = i % COLS;
+
+    if constexpr (col == 0) {
+      float aa = static_cast(A[row * lda + k]);
+      _mm_prefetch(A + row * lda + k + PREFETCH_SIZE_K, _MM_HINT_T0);
+      va = _mm512_set1_ps(aa);
+    }
+
+    if constexpr (row == 0) {
+      if constexpr (COLS == 4) {
+        // when BLOCK_N = 64, handle each row at a time
+        // to reduce de-quantize overhead.
+        if constexpr (col == 0) {
+          __m256i b4 = _mm256_loadu_si256((__m256i*)(B + k * ldb));
+          _mm_prefetch(B + (k + PREFETCH_SIZE_K) * ldb, _MM_HINT_T0);
+
+          __m512i b32 = _mm512_cvtepu8_epi32(_mm256_castsi256_si128(b4));
+          vb[0] = _mm512_permutexvar_ps(b32, lut);
+          vb[0] = _mm512_fmadd_ps(vb[0], scale[0], zero[0]);
+          vb[2] = _mm512_permutexvar_ps(_mm512_srli_epi32(b32, 4), lut);
+          vb[2] = _mm512_fmadd_ps(vb[2], scale[2], zero[2]);
+
+          b32 = _mm512_cvtepu8_epi32(_mm256_extracti128_si256(b4, 1));
+          vb[1] = _mm512_permutexvar_ps(b32, lut);
+          vb[1] = _mm512_fmadd_ps(vb[1], scale[1], zero[1]);
+          vb[3] = _mm512_permutexvar_ps(_mm512_srli_epi32(b32, 4), lut);
+          vb[3] = _mm512_fmadd_ps(vb[3], scale[3], zero[3]);
+        }
+      } else {
+        __m128i b8 = {{kernel_name}}_convert_int4_to_int8(B + k * ldb + col * 8);
+        __m512i b32 = _mm512_cvtepu8_epi32(b8);
+        vb[col] = _mm512_permutexvar_ps(b32, lut);
+        vb[col] = _mm512_fmadd_ps(vb[col], scale[col], zero[col]);
+      }
+    }
+
+    constexpr int idx = row * COLS + col;
+    vc[idx] = _mm512_fmadd_ps(va, vb[col], vc[idx]);
+  };
+
+  for (int k = 0, kb = 0; k < K; ++k) {
+    if ({{kernel_name}}_is_block_start(k, k_start, q_group_size)) {
+      c10::ForcedUnroll{}(load_scale_and_zeros, kb++);
+    }
+    c10::ForcedUnroll{}(compute, k);
+  }
+
+  //store to C
+  auto storec = [&, COLS](auto i) {
+    constexpr int row = i / COLS;
+    constexpr int col = i % COLS;
+    _mm512_storeu_ps(C + row * ldc + col * 16, vc[i]);
+  };
+  c10::ForcedUnroll{}(storec);
+}
+"""
+
+    def get_kernel_extra_args_declare(self) -> str:
+        return (
+            "const int64_t q_group_size,\n"
+            "    const at::BFloat16* __restrict__ ScaleAndZeros,\n"
+            "    const int64_t lds,\n"
+            "    int64_t k_start,"
+        )
+
+    def get_kernel_extra_args(self, **kwargs) -> list[str]:
+        assert "kernel" in kwargs
+        assert "qscale_and_zeros" in kwargs
+        kernel = kwargs["kernel"]
+        qscale_and_zeros = kwargs["qscale_and_zeros"]
+        return [
+            "group_size,",
+            f"&({kernel.index(qscale_and_zeros, [0, 0, 0])}),",
+            "N * 2,",  # lds
+            "k_start,",
+        ]
+
+    def is_woq_int4(self):
+        return True
+
+
+@register_micro_gemm(
+    *generate_gemm_config(
+        VecAMX,
+        [  # (block_m, block_n, block_k)
+            (16, 32, 32),
+            (32, 32, 32),
+        ],
+        input_dtype=torch.bfloat16,
+        input2_dtype=torch.uint8,
+        output_dtype=torch.float,
+        compute_dtype=torch.float,
+        extra_check=check_amx_extra,
+    ),
+)
+class CppMicroGemmWoQInt4Amx(CppMicroGemmAMX):
+    """
+    This class generates the code for WoQ int4 micro gemm using AMX intrinsics,
+    which are available on 4th and newer generations of Intel Xeon.
+    Shape of packed weight = [N // 32, K, 16], viewed as [N, K // 2]
+    Shape of packed ScalesAndZeros = [K // group_size, N, 2]
+    Reuse TEMPLATE_KERNEL of CppMicroGemmAMX.
+    """
+
+    TEMPLATE_ENTRY = r"""
+inline bool {{kernel_name}}_is_block_start(int index, int k_start, int group_size) {
+  // check if (k_start + index) % group_size == 0, assuming group_size = 32/64/128
+  return ((k_start + index) & (group_size - 1)) == 0;
+}
+
+{{declare_kernel}} {
+    {{kernel.assert_function}}(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}");
+    {{kernel.assert_function}}(K % 2 == 0, "K dimension must be multiple of 2");
+    {{kernel.assert_function}}({{block_n}} == 32, "block_n must be 32 for WOQ int4");
+
+    // Create a stack-allocated buffer for tiles of B.
+    // Except maybe for the tail-case, an AMX tile of B has 16x32 BF16 elements.
+    // we cache K * {{block_n}} elements of dequantized B
+    {{template.codegen_allocate_weight_buffer("dequantized_B_buf", input_t, "K", block_n)}}
+
+    constexpr int BLOCK_K = {{block_k}};
+    constexpr int64_t BLOCK_N = {{block_n}};
+    constexpr int COLS = BLOCK_N / 16;
+    const int PREFETCH_SIZE_K = 16 * 4;
+    const int PREFETCH_SIZE_KB = (PREFETCH_SIZE_K + BLOCK_K - 1) / BLOCK_K;
+    const int KB = K / BLOCK_K;
+
+    __m512i b32[COLS * 2];
+    __m512 vb[COLS * 2];
+    __m512 scale[COLS];
+    __m512 zero[COLS];
+
+    // Lookup table to de-quantize int4 values to bf16.
+    // Values are dequantized as truly int4 [-8, 7] range;
+    //
+    // dequant = (bf16(int4_value) * bf16_scale) + bf16_zero
+    //
+    static const __m512 lut = _mm512_set_ps(
+        7.0f, 6.0f, 5.0f, 4.0f,
+        3.0f, 2.0f, 1.0f, 0.0f,
+        -1.0f, -2.0f, -3.0f, -4.0f,
+        -5.0f, -6.0f, -7.0f, -8.0f);
+
+    // index for transpose
+    static const __m512i idx1 = _mm512_set_epi32(
+        30, 28, 26, 24, 22, 20, 18, 16,
+        14, 12, 10, 8, 6, 4, 2, 0);
+    static const __m512i idx2 = _mm512_set_epi32(
+        31, 29, 27, 25, 23, 21, 19, 17,
+        15, 13, 11, 9, 7, 5, 3, 1);
+
+    // Indices for VNNI layout conversion
+    __m512i idx_low = _mm512_set_epi32(
+        0x17,
+        0x07,
+        0x16,
+        0x06,
+        0x15,
+        0x05,
+        0x14,
+        0x04,
+        0x13,
+        0x03,
+        0x12,
+        0x02,
+        0x11,
+        0x01,
+        0x10,
+        0x00);
+    __m512i idx_high = _mm512_set_epi32(
+        0x1f,
+        0x0f,
+        0x1e,
+        0x0e,
+        0x1d,
+        0x0d,
+        0x1c,
+        0x0c,
+        0x1b,
+        0x0b,
+        0x1a,
+        0x0a,
+        0x19,
+        0x09,
+        0x18,
+        0x08);
+
+    // load scale and zero point
+    auto load_scale_and_zeros = [&](int i, int _kb) {
+        // load 2x bfloat16 vector
+        __m512i t = _mm512_loadu_si512((__m512i*)(ScaleAndZeros + _kb * lds + 32 * i));
+        _mm_prefetch(ScaleAndZeros + (_kb + PREFETCH_SIZE_KB) * lds + 32 * i, _MM_HINT_T0);
+
+        // convert to 2x f32 vector
+        __m512 a, b;
+        at::vec::cvtbf16_fp32(t, a, b);
+
+        // transpose scale_and_zero from {16, 2} to {2, 16}
+        // inputs:
+        //   a: {s0, z0, s1, z1, ..., s7, z7}
+        //   b: {s8, z8, s9, z9, ..., s15, z15}
+        // output:
+        //   scale: {s0, s1, s2, ..., s15}
+        //   zero:  {z0, z1, z2, ..., z15}
+        scale[i] = _mm512_mask_permutex2var_ps(a, 0xffff, idx1, b);
+        zero[i] = _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b);
+    };
+
+    // Dequantize a B block of 2 * block_n into bf16
+    // So, it handles k and k+1 at the same time
+    auto dequantize_B = [&](int n) {
+        constexpr int64_t ldb_int4 = BLOCK_N / 2; // 16
+        for (int k = 0, kb = 0; k < K; k += 2) {
+            // Since block_k must be 32 for AMX microkernels, k_start may not be
+            // a multiple of q_group_size. In that case, we need to load scales
+            // and zero points immediately when k == 0 here
+            if ({{kernel_name}}_is_block_start(k, k_start, q_group_size) || k == 0) {
+                c10::ForcedUnroll{}(load_scale_and_zeros, kb++);
+            }
+
+            _mm_prefetch(B + (k + PREFETCH_SIZE_K) * ldb_int4, _MM_HINT_T0);
+
+            // load 256 bits = 64 elements in int4
+            __m128i b4 = _mm_loadu_si128((__m128i*)(B + n / 2 * K + k * ldb_int4));
+            b32[0] = _mm512_cvtepu8_epi32(b4);
+            b32[1] = _mm512_srli_epi32(b32[0], 4);
+            vb[0] = _mm512_permutexvar_ps(b32[0] , lut);
+            vb[0] = _mm512_fmadd_ps(vb[0], scale[0], zero[0]);
+            vb[1] = _mm512_permutexvar_ps(b32[1], lut);
+            vb[1] = _mm512_fmadd_ps(vb[1], scale[1], zero[1]);
+
+            __m128i b4_2 = _mm_loadu_si128((__m128i*)(B + n / 2 * K + (k + 1) * ldb_int4));
+            b32[0 + COLS] = _mm512_cvtepu8_epi32(b4_2);
+            b32[1 + COLS] = _mm512_srli_epi32(b32[0 + COLS], 4);
+            vb[0 + COLS] = _mm512_permutexvar_ps(b32[0 + COLS] , lut);
+            vb[0 + COLS] = _mm512_fmadd_ps(vb[0 + COLS], scale[0], zero[0]);
+            vb[1 + COLS] = _mm512_permutexvar_ps(b32[1 + COLS], lut);
+            vb[1 + COLS] = _mm512_fmadd_ps(vb[1 + COLS], scale[1], zero[1]);
+
+            for (int i = 0; i < COLS; i++) {
+                // convert to VNNI
+                auto low = _mm512_permutex2var_ps(vb[i], idx_low, vb[i + COLS]);
+                auto high = _mm512_permutex2var_ps(vb[i], idx_high, vb[i + COLS]);
+                // convert lower 16 float32 values to bfloat16
+                auto v0_bf16 = reinterpret_cast<__m256i>(_mm512_cvtneps_pbh(low));
+                // convert higher 16 float32 values to bfloat16
+                auto v1_bf16 = reinterpret_cast<__m256i>(_mm512_cvtneps_pbh(high));
+                // combine the lower 16 and higher 16 bfloat16 values
+                auto v = _mm512_castsi256_si512(v0_bf16);
+                v = _mm512_inserti64x4(v, v1_bf16, 1);
+                // store the VNNI format bfloat16 values
+                {{input_t}}* addr = dequantized_B_buf + k * 32 + (i % 2) * 32;
+                _mm512_storeu_si512(addr, v);
+            }
+        }
+    };
+
+    for (int64_t n = 0; n < N; n += {{block_n}}) {
+        // Dequantize K * block_n int8 B elements into BF16
+        dequantize_B(n);
+        for (int64_t m = 0; m < M; m += {{block_m}}) {
+            int64_t block_m = std::min(M - m, {{block_m}});
+            int64_t m_tail = m;
+        {%- for num_rows in range(block_m, 0, -16) %}
+            {%- if num_rows != block_m %}
+            else
+        {%- endif %}
+            if (block_m >= {{num_rows}}) {
+                {{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}(
+                    amx_state,
+                    A + m * lda,
+                    dequantized_B_buf + n * K,
+                    C + m * ldc + n,
+                    K,
+                    lda,
+                    {{block_n}},
+                    ldc,
+                    16
+                );
+                block_m -= {{num_rows}};
+                m_tail += {{num_rows}};
+            }
+        {%- endfor %}
+            if (block_m > 0) {
+                {{kernel_name}}_amx_kernel_16_{{num_columns}}(
+                    amx_state,
+                    A + m_tail * lda,
+                    dequantized_B_buf + n * K,
+                    C + m_tail * ldc + n,
+                    K,
+                    lda,
+                    {{block_n}},
+                    ldc,
+                    block_m
+                );
+            }
+        } // for m
+    } // for n
+}
+"""
+
+    def get_kernel_extra_args_declare(self) -> str:
+        return (
+            "AMXState& amx_state,\n"
+            "    const int64_t q_group_size,\n"
+            "    const c10::BFloat16* __restrict__ ScaleAndZeros,\n"
+            "    const int64_t lds,\n"
+            "    int64_t k_start,"
+        )
+
+    def get_kernel_extra_args(self, **kwargs) -> list[str]:
+        assert "kernel" in kwargs
+        assert "qscale_and_zeros" in kwargs
+        kernel = kwargs["kernel"]
+        qscale_and_zeros = kwargs["qscale_and_zeros"]
+        return [
+            "amx_state,",
+            "group_size,",
+            f"&({kernel.index(qscale_and_zeros, [0, 0, 0])}),",
+            "N * 2,",  # lds
+            "k_start,",
+        ]
+
+    def is_woq_int4(self):
+        return True
+
+
+def create_micro_gemm(
+    name,
+    m,
+    n,
+    k,
+    input_dtype,
+    input2_dtype,
+    output_dtype=None,
+    compute_dtype=None,
+    alpha=1,
+    num_threads=-1,
+    use_ref=True,
+    q_group_size=None,
+) -> Optional[CppMicroGemm]:
+    """
+    Based on the provided info, try to find the config of the micro-kernel that would
+    deliver the best performance in terms of lower latency for this case.
+    """
+
+    def create_from_config(cls, config: CppMicroGemmConfig):
+        return cls(
+            name,
+            config.input_dtype,
+            config.input2_dtype,
+            config.output_dtype,
+            config.compute_dtype,
+            config.register_blocking,
+            alpha,
+        )
+
+    def skip_amx_kernel_for_woq(dynamic_M):
+        # For WoQ GEMM, AMX micro-kernel may not perform well if m is small.
+        # Exception: for dynamic shapes, we consider using the AMX micro-kernel.
+        if (
+            dynamic_M
+            or input_dtype != torch.bfloat16
+            or input2_dtype not in [torch.int8, torch.uint8]
+        ):
+            return False
+        m_threshold = 5
+        return m < m_threshold
+
+    assert isinstance(n, int) or n.is_number, n
+    assert isinstance(k, int) or k.is_number, k
+    from ..utils import has_free_symbols
+
+    dynamic_M = has_free_symbols((m,))
+    m = V.graph.sizevars.size_hint(m, fallback=1) if dynamic_M else m
+    assert isinstance(m, int) or m.is_number, m
+    if output_dtype is None:
+        output_dtype = input_dtype
+    if compute_dtype is None:
+        compute_dtype = output_dtype
+    if num_threads < 0:
+        num_threads = parallel_num_threads()
+    vec_isa = pick_vec_isa()
+    matched_configs = []
+    for cls, configs in micro_gemm_configs.items():
+        for config in configs:
+            if not issubclass(vec_isa.__class__, config.vec_isa_cls):
+                continue
+            if (
+                config.input_dtype == input_dtype
+                and config.compute_dtype == compute_dtype
+                and config.input2_dtype == input2_dtype
+                and config.output_dtype == output_dtype
+                # The output_dtype here is the output dtype of the micro-kernel.
+                # In some cases, the actual output dtype of the op for which the micro-kernel
+                # is being created would be same as that of the activation, but the micro-kernels
+                # compute output in Float/int32, which is converted in the GEMM template. This is
+                # subject to change in the future.
+            ):
+                if config.extra_check is not None and not config.extra_check(
+                    config,
+                    m,
+                    n,
+                    k,
+                    alpha,
+                    num_threads,
+                    dynamic_M=dynamic_M,
+                    q_group_size=q_group_size,
+                    vec_isa=vec_isa,
+                ):
+                    continue
+                block_m, block_n, block_k = config.register_blocking
+                if config.vec_isa_cls == VecAMX and skip_amx_kernel_for_woq(dynamic_M):
+                    continue
+                # Criteria on the ranking of configurations
+                # 1. ISA: AMX > VNNI > VEC
+                # 2. Dividable by block sizes (block_m, block_n, block_k)
+                # 3. Number of mxn blocks is large enough to occupy all the threads
+                # 4. Register blocks are larger
+                isa_score = 0
+                if config.vec_isa_cls == VecAMX:
+                    isa_score += 2
+                elif config.vec_isa_cls == VecAVX512VNNI:
+                    isa_score += 1
+                dividable_score = 0
+                if m % block_m == 0:
+                    dividable_score += 1
+                if n % block_n == 0:
+                    dividable_score += 1
+                if k % block_k == 0:
+                    dividable_score += 1
+                occupancy_score = 0
+                n_blocks = (n + block_n - 1) // block_n
+                total_mxn_blocks = n_blocks * ((m + block_m - 1) // block_m)
+                if n_blocks >= num_threads:
+                    occupancy_score += 1
+                if total_mxn_blocks >= num_threads:
+                    occupancy_score += 1
+                register_bytes = (
+                    block_m * block_n * config.compute_dtype.itemsize
+                    + (block_m * block_k + block_k * block_n)
+                    * config.input_dtype.itemsize
+                )
+                size_score = register_bytes
+                # if number of mxn blocks can not occupy all the threads,
+                # we favor smaller register blocks.
+                if occupancy_score == 0:
+                    size_score = 0 - register_bytes
+                matched_configs.append(
+                    (
+                        (isa_score, dividable_score, occupancy_score, size_score),
+                        cls,
+                        config,
+                    )
+                )
+    if len(matched_configs) == 0:
+        if use_ref:
+            return CppMicroGemmRef(
+                name, input_dtype, input2_dtype, output_dtype, compute_dtype, alpha
+            )
+        else:
+            return None
+    # TODO(jgong5): allow autotuning on choices of configs
+    return create_from_config(*max(matched_configs, key=operator.itemgetter(0))[1:])
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_template.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_template.py
new file mode 100644
index 0000000000000000000000000000000000000000..c01ca4363685deff18328b48026fa2d33f92e29f
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_template.py
@@ -0,0 +1,140 @@
+# mypy: allow-untyped-defs
+import ctypes
+import functools
+import itertools
+import logging
+import sys
+from collections.abc import Callable, Iterable
+from typing import Optional, Union
+from unittest.mock import patch
+
+import sympy
+
+from .. import config, ir
+from ..autotune_process import CppBenchmarkRequest, TensorMeta
+from ..utils import IndentedBuffer, Placeholder, unique
+from ..virtualized import V
+from .common import KernelTemplate
+from .cpp_template_kernel import CppTemplateCaller, CppTemplateKernel
+
+
+log = logging.getLogger(__name__)
+
+
+class CppTemplate(KernelTemplate):
+    index_counter = itertools.count()
+
+    def __init__(
+        self,
+        name: str,
+        input_nodes,
+        layout: ir.Layout,
+        num_threads: int,
+        epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
+    ) -> None:
+        super().__init__(name)
+        self.input_nodes = input_nodes
+        self.index = next(self.index_counter)
+        self.output_node: Union[ir.Buffer, list[ir.Buffer]] = ir.Buffer(
+            name=f"buf_out{self.index}", layout=layout
+        )
+        self.layout = layout
+        self.num_threads = num_threads
+        self.epilogue_creator = epilogue_creator
+
+    def generate(self, **kwargs):
+        kernel_name = f"cpp_{self.name}"
+        with (
+            patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)),
+            patch.object(ir.FlexibleLayout, "allow_indexing", True),
+            V.graph.set_current_device(self.layout.device),
+            CppTemplateKernel(
+                kernel_name=kernel_name, num_threads=self.num_threads
+            ) as kernel,
+        ):
+            code = kernel.render(self, **kwargs)
+            _, call_args, _, _ = kernel.args.python_argdefs()
+            log.debug("Generated Code:\n%s", code)
+            log.debug(
+                "Args: cpp_argdefs: %s, python_argdefs: %s",
+                kernel.args.cpp_argdefs(),
+                kernel.args.python_argdefs(),
+            )
+
+        expected_args = list(
+            unique(input_node.get_name() for input_node in self.input_nodes)
+        )
+        if isinstance(self.output_node, Iterable):
+            expected_args.extend([node.get_name() for node in self.output_node])
+        else:
+            expected_args.extend([self.output_node.get_name()])
+        assert list(call_args)[: len(expected_args)] == expected_args, (
+            call_args,
+            expected_args,
+        )
+        extra_args = V.graph.sizevars.size_hints(
+            map(sympy.expand, call_args[len(expected_args) :])
+        )
+        # Cast the size hint from int to ctypes.c_ulonglong explicitly
+        # since in cpp kernel, we bind it to C long
+        extra_args = tuple(ctypes.c_ulonglong(x) for x in extra_args)
+
+        kernel_hash_name = f"cpp_{self.name}_{self.index}"
+
+        # Create the BenchmarkRequest for CPP
+        bmreq = CppBenchmarkRequest(
+            kernel_name=kernel_name,
+            input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes),
+            # pyrefly: ignore [bad-argument-type]
+            output_tensor_meta=TensorMeta.from_irnodes(self.output_node),
+            extra_args=extra_args,
+            source_code=code,
+        )
+
+        def make_kernel_render(
+            template_node: ir.CppTemplateBuffer,
+            flag_template_buffer_has_other_users: bool,
+            epilogue_nodes: Optional[list[ir.IRNode]] = None,
+        ):
+            kernel = CppTemplateKernel(
+                kernel_name=str(Placeholder.KERNEL_NAME), num_threads=self.num_threads
+            )
+            render = functools.partial(
+                kernel.render,
+                self,
+                template_buffer_node=template_node,
+                flag_template_buffer_has_other_users=flag_template_buffer_has_other_users,
+                epilogue_nodes=epilogue_nodes,
+                **kwargs,
+            )
+            return kernel, render
+
+        return CppTemplateCaller(
+            kernel_hash_name,
+            self.name,
+            self.input_nodes,
+            # pyrefly: ignore [index-error]
+            self.output_node[0].get_layout()
+            if isinstance(self.output_node, Iterable)
+            else self.output_node.get_layout(),
+            make_kernel_render,
+            bmreq,
+            self,
+        )
+
+    def header(self) -> IndentedBuffer:
+        res = IndentedBuffer()
+        res.writeline("#include ")
+        # TODO: add c10::ForcedUnroll test to test_aoti_abi_check
+        res.splice("""#include """)
+        res.splice("""#include """)
+        enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [
+            "linux",
+            "win32",
+        ]
+        if enable_kernel_profile:
+            res.writelines(["#include "])
+        return res
+
+    def render(self, **kwargs) -> str:
+        raise NotImplementedError
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_template_kernel.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_template_kernel.py
new file mode 100644
index 0000000000000000000000000000000000000000..1434398eac8a7e095a6ebb7d9c17e81cde8db11e
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_template_kernel.py
@@ -0,0 +1,621 @@
+# mypy: allow-untyped-defs
+import itertools
+from collections.abc import Callable, Iterable
+from typing import Any, Optional, Union
+from unittest.mock import patch
+
+import sympy
+from sympy.parsing.sympy_parser import parse_expr
+
+import torch
+from torch._inductor.utils import do_bench_using_profiling
+from torch.utils._ordered_set import OrderedSet
+from torch.utils._sympy.symbol import SymT
+
+from .. import config, cpp_builder, ir, lowering as L
+from ..autotune_process import CppBenchmarkRequest
+from ..loop_body import LoopBody
+from ..select_algorithm import PartialRender
+from ..utils import sympy_index_symbol, sympy_index_symbol_with_prefix
+from ..virtualized import V
+from .common import REMOVED
+from .cpp import CppKernel, CppKernelProxy, KernelGroup, ParallelDepth
+from .cpp_utils import cexpr_index, DTYPE_TO_CPP, LocalBufferContext
+
+
+def parse_expr_with_index_symbols(expr):
+    if isinstance(expr, sympy.Expr):
+        return expr
+    elif isinstance(expr, (list, tuple)):
+        return [parse_expr_with_index_symbols(e) for e in expr]
+    else:
+        expr = parse_expr(str(expr))
+        int_symbols = {sym: sympy_index_symbol(sym.name) for sym in expr.free_symbols}
+        return expr.subs(int_symbols)
+
+
+def wrap_with_tensorbox(node) -> Union[ir.TensorBox, ir.ShapeAsConstantBuffer]:
+    return (
+        ir.TensorBox.create(node) if isinstance(node, ir.Buffer) else ir.TensorBox(node)
+    )
+
+
+class CppTemplateKernel(CppKernel):
+    def __init__(self, kernel_name, num_threads):
+        super().__init__(None, num_threads)
+        self.kernel_name = kernel_name
+        self.render_hooks = {}
+        self.local_buffers = {}
+
+    def render(self, template, **kwargs):
+        return PartialRender(
+            template.render(kernel=self, **kwargs), self.render_hooks
+        ).finalize_all()
+
+    def def_kernel(
+        self,
+        inputs: dict[str, ir.Buffer],
+        outputs: dict[str, ir.Buffer],
+        aliases: Optional[dict[str, str]] = None,
+        function_name: str = "",
+        extra_sizevars: Optional[list[sympy.Expr]] = None,
+        placeholder: str = "",
+    ) -> str:
+        if len(function_name) == 0:
+            function_name = str(self.kernel_name)
+        for name, inp in inputs.items():
+            if inp is not None:
+                self.args.input_buffers[inp.get_name()] = name
+        for name, out in outputs.items():
+            self.args.output_buffers[out.get_name()] = name
+        if aliases is not None:
+            for alias, orig in aliases.items():
+                if orig in self.args.input_buffers:
+                    self.args.input_buffers[alias] = self.args.input_buffers[orig]
+                if orig in self.args.output_buffers:
+                    self.args.output_buffers[alias] = self.args.output_buffers[orig]
+
+        unique_sizevars = OrderedSet(
+            s
+            for input in inputs.values()
+            if input is not None
+            for sym in itertools.chain(input.get_size(), input.get_stride())
+            if isinstance(sym, sympy.Expr)
+            for s in sym.free_symbols
+        )
+        unique_sizevars.update(
+            s
+            for sym in extra_sizevars or []
+            if isinstance(sym, sympy.Expr)
+            for s in sym.free_symbols
+        )
+        unique_sizevars.update(
+            s
+            for output in outputs.values()
+            for sym in itertools.chain(output.get_size(), output.get_stride())
+            if isinstance(sym, sympy.Expr)
+            for s in sym.free_symbols
+        )
+        sizevars = sorted(unique_sizevars, key=str)
+        for sizevar in sizevars:
+            self.args.sizevars[sizevar] = f"k{sizevar}"
+
+        def hook():
+            # remove all aliases before generate function definition
+            if aliases is not None:
+                for alias in aliases:
+                    if alias in self.args.input_buffers:
+                        raise AssertionError(
+                            f"input_buffers cannot be removed: {alias}"
+                        )
+                    if alias in self.args.output_buffers:
+                        self.args.output_buffers[alias] = REMOVED
+            cpp_argdefs, _, _ = self.args.cpp_argdefs()
+            return f"void {function_name}({', '.join(cpp_argdefs)})"
+
+        assert placeholder not in self.render_hooks
+        self.render_hooks[placeholder] = hook
+        return placeholder
+
+    def call_kernel(self, name: str, node: ir.CppTemplateBuffer):
+        wrapper = V.graph.wrapper_code
+        _, call_args, arg_types = self.args.cpp_argdefs()
+        wrapper.generate_kernel_call(name, call_args, triton=False, arg_types=arg_types)
+
+    def dtype(self, node: ir.Buffer) -> str:
+        return DTYPE_TO_CPP[node.get_dtype()]
+
+    def acc_dtype(self, node: ir.Buffer) -> str:
+        if node.get_dtype() in [torch.float32, torch.bfloat16, torch.half]:
+            return "float"
+        else:
+            raise NotImplementedError(f"Unsupported dtype: {node.get_dtype()}")
+
+    def size(self, node: ir.Buffer, dim: int) -> str:
+        return cexpr_index(self.rename_indexing(node.get_size()[dim]))
+
+    def stride(self, node: ir.Buffer, dim: int) -> str:
+        return cexpr_index(self.rename_indexing(node.get_stride()[dim]))
+
+    def index(self, node: ir.Buffer, indices: list[Any]) -> str:
+        indexer = node.get_layout().as_fixed().make_indexer()
+        index = indexer(parse_expr_with_index_symbols(indices))
+        index = self.rename_indexing(index)
+        outer_name = node.get_name()
+        inner_name = (
+            outer_name
+            if outer_name in self.local_buffers
+            else self.args.input(node.get_name())
+        )
+        return f"{inner_name}[{cexpr_index(index)}]"
+
+    def slice_nd(self, node, ranges: list[tuple[Any, Any]]) -> ir.ReinterpretView:
+        """
+        Slice the given node with a list of ranges (start and end) corresponding to its dims.
+        The dim is not sliced if the corresponding range is empty.
+        """
+        assert len(ranges) == len(node.get_size()), f"{ranges=}, {node=}"
+        sliced = wrap_with_tensorbox(node)
+        for dim, _range in enumerate(ranges):
+            if len(_range) == 0:
+                continue
+            assert len(_range) == 2
+            start, end = parse_expr_with_index_symbols(_range)
+            sliced = L.slice_(sliced, dim, start, end, clamp=False)
+        assert isinstance(sliced, ir.TensorBox)
+        assert isinstance(sliced.data, ir.ReinterpretView), sliced.data
+        return sliced.data
+
+    def select(self, node, dim: int, idx: int) -> ir.ReinterpretView:
+        # We avoid using L.select here because we need clamp=False so the dim after slicing
+        # is 1 instead of a sympy expression of symbol - dim_size.
+        node = wrap_with_tensorbox(node)
+        idx = ir.View.handle_negative_index(idx, node.get_size()[dim])
+        sliced = L.squeeze(L.slice_(node, dim, idx, idx + 1, clamp=False), dim)
+        assert isinstance(sliced.data, ir.ReinterpretView), sliced.data
+        return sliced.data
+
+    def view(self, node, sizes: list[Any]) -> ir.IRNode:
+        node = wrap_with_tensorbox(node)
+        sizes = parse_expr_with_index_symbols(sizes)
+        return L.view(node, sizes).data  # type: ignore[arg-type]
+
+    def permute(self, node, dims):
+        node = wrap_with_tensorbox(node)
+        permuted = L.permute(node, dims).data
+        assert isinstance(permuted, ir.ReinterpretView)
+        return permuted
+
+    def maybe_codegen_profile(self) -> str:
+        if config.cpp.enable_kernel_profile:
+            graph_id = V.graph.graph_id
+            prefix = "graph_" + str(graph_id) + "_" if graph_id is not None else ""
+            handle_str = (
+                "torch::aot_inductor::RAIIAtenRecordFunctionHandle "
+                f'record_{prefix}{self.kernel_name}_("{prefix}{self.kernel_name}", nullptr);'
+            )
+            return handle_str
+        else:
+            return ""
+
+    def unroll_pragma(self, unroll):
+        if cpp_builder.is_gcc():
+            return f"#pragma GCC unroll {unroll}"
+        else:
+            return f"#pragma unroll {unroll}"
+
+    def define_buffer(self, name, sizes: list[Any], dtype=torch.float) -> str:
+        """Define kernel local buffer"""
+        sizes = parse_expr_with_index_symbols(sizes)
+        buf = ir.Buffer(
+            name=name, layout=ir.FixedLayout(torch.device("cpu"), dtype, sizes)
+        )
+        self.local_buffers[name] = buf
+        ctype = f"{DTYPE_TO_CPP[dtype]}"
+        numel = f"{cexpr_index(buf.get_numel())}"
+        return f"auto _{name} = std::make_unique<{ctype}[]>({numel}); auto {name} = _{name}.get();"
+
+    def define_stack_allocated_buffer(
+        self, name, sizes: list[Any], dtype=torch.float
+    ) -> str:
+        """Define stack-allocated buffer"""
+        sizes = parse_expr_with_index_symbols(sizes)
+        buf = ir.Buffer(
+            name=name, layout=ir.FixedLayout(torch.device("cpu"), dtype, sizes)
+        )
+        self.local_buffers[name] = buf
+        ctype = f"{DTYPE_TO_CPP[dtype]}"
+        numel = f"{cexpr_index(buf.get_numel())}"
+        return f"alignas(64) {ctype} _{name}[{numel}]; {ctype}* {name} = _{name};"
+
+    def reinit_buffer_if_null(self, name):
+        """Reinit the previously defined local buffer if it is null"""
+        assert name in self.local_buffers
+        buf = self.local_buffers[name]
+        ctype = f"{DTYPE_TO_CPP[buf.layout.dtype]}"
+        numel = f"{cexpr_index(buf.get_numel())}"
+        return f"if (_{name} == nullptr) {{ _{name} = std::make_unique<{ctype}[]>({numel}); {name} = _{name}.get(); }}"
+
+    def release_buffer(self, name):
+        """Codegen the code to release the ownership of a local buffer to others"""
+        assert name in self.local_buffers
+        return f"_{name}.release()"
+
+    def store_pointwise_nodes(
+        self,
+        dst: ir.Buffer,
+        nodes: list[ir.IRNode],
+        offsets: Optional[list[sympy.Expr]] = None,
+        reindexers: Optional[list[Optional[Callable[[list[Any]], list[Any]]]]] = None,
+    ) -> str:
+        var_sizes = (tuple(dst.get_size()), ())
+        var_ranges = {
+            sympy_index_symbol_with_prefix(SymT.INDEX, i): sz
+            for i, sz in enumerate(var_sizes[0])
+        }
+        if not offsets:
+            offsets = [sympy.S.Zero] * len(var_sizes[0])
+        if not reindexers:
+            reindexers = [None] * len(nodes)
+        assert len(offsets) == len(var_sizes[0])
+        output_index = dst.get_layout().make_indexer()([*var_ranges.keys()])
+        kernel_group = KernelGroup()
+        kernel_group.args = self.args
+        cpp_kernel_proxy = CppKernelProxy(kernel_group)
+        bodies = []
+        var_sizes_list = []
+        for i, node in enumerate(nodes):
+            output_name = node.get_name() if i < len(nodes) - 1 else dst.get_name()
+            node = node.data if isinstance(node, ir.ComputedBuffer) else node
+            assert isinstance(node, ir.Pointwise), node
+
+            def fn(*args):
+                assert len(args) == 2
+                assert len(args[0]) == len(var_sizes[0])
+                assert len(args[1]) == 0
+                new_args = [arg + offset for arg, offset in zip(args[0], offsets)]  # type: ignore[arg-type]
+                if reindexers[i] is not None:
+                    new_args = reindexers[i](new_args)  # type: ignore[misc]
+                V.ops.store(
+                    output_name,
+                    output_index,
+                    node.make_loader()(new_args).value,
+                )
+
+            body = LoopBody(
+                fn,
+                (list(var_ranges.keys()), ()),
+                var_ranges,
+                list(var_ranges.keys()),
+                tuple(),
+            )
+            bodies.append(body)
+            var_sizes_list.append(var_sizes)
+
+        cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list)
+
+        def max_parallel_depth():
+            return ParallelDepth(parallel_depth=0, start_depth=0)
+
+        # This loop is not parallelized since it is not the outermost loop.
+        with patch.object(
+            cpp_kernel_proxy.loop_nest, "max_parallel_depth", max_parallel_depth
+        ):
+            kernel_group.finalize_kernel(cpp_kernel_proxy, [])
+        return kernel_group.loops_code.getvalue()
+
+    def store_grouped_gemm_pointwise_nodes(
+        self,
+        dst: tuple[ir.Buffer],
+        nodes: list[ir.IRNode],
+        offsets: list[sympy.Expr],
+        reindexers: list[Optional[Callable[[list[Any]], list[Any]]]],
+        output_names: list[str],
+    ) -> str:
+        ref_dst = dst[0]
+        var_sizes = (tuple(ref_dst.get_size()), ())
+        var_ranges = {
+            sympy_index_symbol_with_prefix(SymT.INDEX, i): sz
+            for i, sz in enumerate(var_sizes[0])
+        }
+        assert offsets, "offsets should be set outside"
+        assert all(len(offset) == len(var_sizes[0]) for offset in offsets)
+        output_index = ref_dst.get_layout().make_indexer()([*var_ranges.keys()])
+        kernel_group = KernelGroup()
+        kernel_group.args = self.args
+        cpp_kernel_proxy = CppKernelProxy(kernel_group)
+        bodies = []
+        var_sizes_list = []
+        for i, node in enumerate(nodes):
+            output_name = output_names[i]
+            node = node.data if isinstance(node, ir.ComputedBuffer) else node
+            assert isinstance(node, ir.Pointwise), node
+
+            def fn(*args):
+                assert len(args) == 2
+                assert len(args[0]) == len(var_sizes[0])
+                assert len(args[1]) == 0
+                new_args = [arg + offset for arg, offset in zip(args[0], offsets[i])]  # type: ignore[arg-type]
+                if reindexers[i] is not None:
+                    new_args = reindexers[i](new_args)  # type: ignore[misc]
+                V.ops.store(
+                    output_name,
+                    output_index,
+                    node.make_loader()(new_args).value,
+                )
+
+            body = LoopBody(
+                fn,
+                (list(var_ranges.keys()), ()),
+                var_ranges,
+                list(var_ranges.keys()),
+                tuple(),
+            )
+            bodies.append(body)
+            var_sizes_list.append(var_sizes)
+
+        cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list)
+
+        def max_parallel_depth():
+            return ParallelDepth(parallel_depth=0, start_depth=0)
+
+        # This loop is not parallelized since it is not the outermost loop.
+        with patch.object(
+            cpp_kernel_proxy.loop_nest, "max_parallel_depth", max_parallel_depth
+        ):
+            kernel_group.finalize_kernel(cpp_kernel_proxy, [])
+        return kernel_group.loops_code.getvalue()
+
+    def store_output(
+        self,
+        dst: ir.Buffer,
+        src: ir.Buffer,
+        orig_src: Optional[ir.Buffer] = None,
+        epilogue_nodes: Optional[list[ir.IRNode]] = None,
+        offsets: Optional[list[Any]] = None,
+        reindexers: Optional[list[Optional[Callable[[list[Any]], list[Any]]]]] = None,
+    ):
+        """
+        Store the `src` buffer to the `dst` buffer. The size of `src` and `dst` should match.
+        If `epilogue_nodes` is provided, the `src` buffer is firstly computed with the epilogues
+        before stored to `dst`. The `epilogues_nodes` are all pointwise.
+
+        Notes:
+        1. `src` and `dst` buffer could be the same buffer in which case we are doing in-place compute
+           and stores. In case `epilogue_nodes` are not provided, we do nothing.
+        2. The `epilogue_nodes`, if exist, have computations on `src` before storing to `dst` but since
+           they come form the original Inductor IR, they might need to be adjusted before working with
+           `src` and `dst` as outlined below:
+           a) `src` or `dst` buffer could be a sub-slice of the ranges the `epilogue_nodes`work on.
+              In this case, the `offsets` could be provided to adjust the indices passed to
+              `epilogue_nodes` during codegen and the data ranges are also configured according to
+              the sizes of `src` and `dst`.
+           b) `dst` might be indexed in a different way as the `epilogue_nodes`, hence a `reindexer` is
+              needed on the indices to `epilogue_nodes` to match the indexing of `dst`.
+           c) If `src` is local, we need to add a local buffer for it and localize the `orig_src` buffer
+              in `epilogue_nodes` with `src`.
+        """
+        assert isinstance(dst, (ir.Buffer, ir.ReinterpretView))
+        assert dst.get_size() == src.get_size(), f"{dst=}, {src=}"
+        if offsets:
+            offsets = parse_expr_with_index_symbols(offsets)
+        if epilogue_nodes:
+            with LocalBufferContext(self.args) as scope:
+                assert orig_src is not None
+                if orig_src.get_name() != src.get_name():
+                    scope.add_local_buffer(
+                        src,
+                        [
+                            orig_src,
+                        ],
+                    )
+                    epilogue_nodes = scope.localize_nodes(epilogue_nodes)
+                return self.store_pointwise_nodes(
+                    # pyrefly: ignore [bad-argument-type]
+                    dst,
+                    epilogue_nodes,  # type: ignore[arg-type]
+                    offsets,
+                    reindexers,
+                )
+        else:
+            if dst.get_name() != src.get_name():
+                # src is local
+                copy = L.copy(dst, src).data.data
+                with LocalBufferContext(self.args) as scope:
+                    scope.add_local_buffer(src)
+                    # pyrefly: ignore [bad-argument-type]
+                    return self.store_pointwise_nodes(dst, [copy])
+            else:
+                assert dst.layout == src.layout, f"{dst=}, {src=}"
+                return ""
+
+    def store_outputs(
+        self,
+        dst: tuple[ir.Buffer],
+        src: tuple[ir.IRNode],
+        orig_src: Optional[tuple[ir.IRNode]] = None,
+        epilogue_nodes: Optional[list[ir.IRNode]] = None,
+        offsets: Optional[list[Any]] = None,
+        reindexers: Optional[list[Optional[Callable[[list[Any]], list[Any]]]]] = None,
+        multi_output_buffers: Optional[tuple[ir.MultiOutput, ...]] = None,
+    ):
+        assert isinstance(dst, Iterable)
+        assert all(_dst.get_size() == _src.get_size() for _src, _dst in zip(src, dst))
+        if offsets:
+            offsets = parse_expr_with_index_symbols(offsets)
+        gemm_num = len(src)
+        final_offsets = []
+        output_names = []
+        if epilogue_nodes:
+            if not reindexers:
+                reindexers = [None] * len(epilogue_nodes)
+            with LocalBufferContext(self.args) as scope:
+                assert orig_src is not None
+                localize_epilogue_nodes = []
+                all_read_names = []
+                for epilogue in epilogue_nodes:
+                    all_read_names.extend(list(epilogue.get_read_names()))
+                localize_epilogue_nodes.extend(scope.localize_nodes(epilogue_nodes))
+                final_offsets.extend([offsets] * len(localize_epilogue_nodes))
+                output_names.extend(
+                    [node.get_name() for node in localize_epilogue_nodes]
+                )
+                for gemm_idx in range(gemm_num):
+                    if orig_src[gemm_idx].get_name() != src[gemm_idx].get_name():
+                        if orig_src[gemm_idx].get_name() in all_read_names or (
+                            multi_output_buffers
+                            and multi_output_buffers[gemm_idx].get_name()
+                            in all_read_names
+                        ):
+                            # If any of the Epilogue nodes use this GEMM output, let's localize the GEMM output
+                            global_buffers = [orig_src[gemm_idx]]
+                            if (
+                                multi_output_buffers
+                                and multi_output_buffers[gemm_idx].get_name()
+                                in all_read_names
+                                and orig_src[gemm_idx].get_name() not in all_read_names
+                            ):
+                                # Epilogue might directly read the MultiOutput, Locallize MultiOutput to the local Buffer
+                                # if this MultiOutput has not been stored by in-template epilogue
+                                # otherwise, use the cse store cache if it will be stored before used
+                                global_buffers.append(multi_output_buffers[gemm_idx])
+                            scope.add_local_buffer(
+                                src[gemm_idx],
+                                global_buffers,
+                            )
+                        else:
+                            scope.add_local_buffer(src[gemm_idx])
+                            localize_epilogue_nodes.extend(
+                                [L.copy(dst[gemm_idx], src[gemm_idx]).data.data]
+                            )
+                            reindexers.append(None)
+                            output_names.append(dst[gemm_idx].get_name())
+                            final_offsets.append(
+                                [sympy.S.Zero] * len(dst[gemm_idx].get_size())
+                            )
+                res = self.store_grouped_gemm_pointwise_nodes(
+                    dst,
+                    localize_epilogue_nodes,
+                    final_offsets,
+                    reindexers,
+                    output_names=output_names,
+                )
+                for gemm_idx in range(gemm_num):
+                    if (
+                        multi_output_buffers
+                        and multi_output_buffers[gemm_idx].get_name() in all_read_names
+                    ):
+                        # If the MultiOutput is used in the Epilogue, let's remove it from args
+                        multi_output_name = multi_output_buffers[gemm_idx].get_name()
+                        if (
+                            multi_output_name in self.args.output_buffers
+                            and self.args.output_buffers[multi_output_name]
+                            is not REMOVED
+                        ):
+                            self.remove_buffer(multi_output_name)
+                return res
+        else:
+            if dst[0].get_name() != src[0].get_name():
+                copy_list = []
+                with LocalBufferContext(self.args) as scope:
+                    for _src, _dst in zip(src, dst):
+                        copy_list.extend([L.copy(_dst, _src).data.data])
+                        scope.add_local_buffer(_src)
+                        output_names.append(_dst.get_name())
+                        final_offsets.append([sympy.S.Zero] * len(_dst.get_size()))
+                    reindexers = [None] * len(copy_list)
+                    return self.store_grouped_gemm_pointwise_nodes(
+                        dst,
+                        nodes=copy_list,
+                        offsets=final_offsets,
+                        reindexers=reindexers,
+                        output_names=output_names,
+                    )
+            else:
+                assert all(
+                    _src.get_name() == _dst.get_name() for _src, _dst in zip(src, dst)
+                )
+                assert all(
+                    _src.get_layout() == _dst.get_layout()
+                    for _src, _dst in zip(src, dst)
+                )
+                return ""
+
+    def check_bounds(self, expr, size, lower, upper):
+        # CppTemplateKernel does not need codegen related operations
+        return
+
+
+class CppTemplateCaller(ir.ChoiceCaller):
+    """
+    CppTemplateCaller
+
+    This class represents a caller for CPP template kernels. It is a subclass of ir.ChoiceCaller.
+    Attributes:
+        name (str): The name of the caller.
+        category (str): The category of the caller.
+        bmreq (CppBenchmarkRequest): The benchmark request for the caller.
+        template_buffer (ir.CppTemplateBuffer): The template buffer for the caller.
+    """
+
+    def __init__(
+        self,
+        name: str,
+        category: str,
+        input_nodes: list[ir.Buffer],
+        layout: ir.Layout,
+        make_kernel_render: Callable[
+            [
+                ir.CppTemplateBuffer,
+                bool,
+                Optional[list[ir.IRNode]],
+            ],
+            str,
+        ],
+        bmreq: CppBenchmarkRequest,
+        template: "CppTemplate",  # type: ignore[name-defined]  # noqa: F821
+        info_kwargs: Optional[
+            dict[str, Union[ir.PrimitiveInfoType, list[ir.PrimitiveInfoType]]]
+        ] = None,
+    ):
+        super().__init__(name, input_nodes, layout, description="")
+        self.category = category
+        self.make_kernel_render = make_kernel_render
+        self.bmreq = bmreq
+        self.template = template
+        self.info_kwargs = info_kwargs
+
+    def precompile(self) -> None:
+        assert self.bmreq is not None
+        self.bmreq.precompile()
+
+    def benchmark(self, *args, out) -> float:
+        assert self.bmreq is not None
+        if config.profile_bandwidth_with_do_bench_using_profiling:
+            algo = self.bmreq.make_run_fn(*args, out=out)
+            return do_bench_using_profiling(algo)
+        return self.bmreq.benchmark(*args, out=out)
+
+    def hash_key(self) -> str:
+        return "-".join(
+            [
+                self.category,
+                self.bmreq.hash_key,
+            ]
+        )
+
+    def info_dict(
+        self,
+    ) -> dict[str, Union[ir.PrimitiveInfoType, list[ir.PrimitiveInfoType]]]:
+        return {"backend": "CPP", "op_type": "unknown"}
+
+    def output_node(self) -> Union[ir.TensorBox, ir.ShapeAsConstantBuffer]:
+        return ir.TensorBox.create(
+            ir.CppTemplateBuffer(
+                layout=self.layout,
+                inputs=self.input_nodes,
+                make_kernel_render=self.make_kernel_render,
+                template=self.template,
+                choice=self,
+            )
+        )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef2bcede213b8f90e66ae18e40e9e18a5e24652e
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_utils.py
@@ -0,0 +1,787 @@
+# mypy: allow-untyped-defs
+import contextlib
+import dataclasses
+import functools
+import math
+import sys
+from collections import namedtuple
+from collections.abc import Callable, Sequence
+from typing import Any, Optional
+from unittest.mock import patch
+
+import sympy
+
+import torch
+from torch._prims_common import is_integer_dtype
+from torch.utils._ordered_set import OrderedSet
+from torch.utils._sympy.printers import CppPrinter as _CppPrinter
+from torch.utils._sympy.symbol import symbol_is_type, SymT
+from torch.utils._sympy.value_ranges import ValueRanges
+
+from .. import ir
+from ..dependencies import Dep
+from ..loop_body import LoopBody
+from ..scheduler import BaseSchedulerNode, SchedulerBuffer
+from ..shape_propagation import BlockShapeType
+from ..utils import IndentedBuffer, sympy_index_symbol_with_prefix, sympy_subs
+from ..virtualized import ops, OpsValue, V
+from .common import CSEVariable, Kernel, KernelArgs, OptimizationContext
+
+
+DTYPE_TO_CPP = {
+    torch.float32: "float",
+    torch.float64: "double",
+    torch.float16: "at::Half",
+    torch.int64: "int64_t",
+    torch.int32: "int32_t",
+    torch.int16: "int16_t",
+    torch.int8: "int8_t",
+    torch.uint64: "uint64_t",
+    torch.uint32: "uint32_t",
+    torch.uint16: "uint16_t",
+    torch.uint8: "uint8_t",
+    torch.bool: "bool",
+    torch.bfloat16: "at::BFloat16",
+    torch.complex32: "at::complex",
+    torch.complex64: "at::complex",
+    torch.complex128: "at::complex",
+    torch.float8_e4m3fn: "at::Float8_e4m3fn",
+    torch.float8_e5m2: "at::Float8_e5m2",
+    torch.float8_e4m3fnuz: "at::Float8_e4m3fnuz",
+    torch.float8_e5m2fnuz: "at::Float8_e5m2fnuz",
+}
+
+DTYPE_TO_ATEN = {
+    torch.float32: "at::kFloat",
+    torch.float64: "at::kDouble",
+    torch.float16: "at::kHalf",
+    torch.int64: "at::kLong",
+    torch.int32: "at::kInt",
+    torch.int16: "at::kShort",
+    torch.int8: "at::kChar",
+    torch.uint64: "at::kUInt64",
+    torch.uint32: "at::kUInt32",
+    torch.uint16: "at::kUInt16",
+    torch.uint8: "at::kByte",
+    torch.uint32: "at::kUInt32",
+    torch.uint64: "at::kUInt64",
+    torch.bool: "at::kBool",
+    torch.bfloat16: "at::kBFloat16",
+    torch.complex32: "at::kComplexHalf",
+    torch.complex64: "at::kComplexFloat",
+    torch.complex128: "at::kComplexDouble",
+    torch.float8_e4m3fn: "at::kFloat8_e4m3fn",
+    torch.float8_e5m2: "at::kFloat8_e5m2",
+    torch.float8_e4m3fnuz: "at::kFloat8_e4m3fnuz",
+    torch.float8_e5m2fnuz: "at::kFloat8_e5m2fnuz",
+}
+
+DEVICE_TO_ATEN = {
+    "meta": "at::kMeta",
+    "cpu": "at::kCPU",
+    "cuda": "at::kCUDA",
+    "xpu": "at::kXPU",
+    "mps": "at::kMPS",
+}
+
+LAYOUT_TO_ATEN = {
+    torch.strided: "at::kStrided",
+    torch._mkldnn: "at::kMkldnn",  # type: ignore[attr-defined]
+}
+
+# matches c10/core/DeviceType.h
+DEVICE_TO_INT = {"cpu": 0, "cuda": 1}
+
+_IS_WINDOWS = sys.platform == "win32"
+
+INDEX_TYPE = "int64_t"
+
+GemmBlocking = namedtuple("GemmBlocking", ["block_m", "block_n", "block_k"])
+
+
+def get_promote_dtype(args):
+    return (
+        functools.reduce(
+            torch.promote_types,  # type: ignore[arg-type]
+            [n.dtype for n in args if isinstance(n, CppCSEVariable)],
+        )
+        if all(n.dtype is not None for n in args if isinstance(n, CppCSEVariable))
+        else None  # not enough info to calculate the promote dtype
+    )
+
+
+def promote_args(new_args):
+    def promote_arg(arg, promote_type):
+        if (
+            isinstance(arg, CppCSEVariable)
+            and arg.dtype
+            and promote_type
+            and arg.dtype != promote_type
+        ):
+            arg = ops.to_dtype(arg, promote_type)
+            arg = arg.value if isinstance(arg, OpsValue) else arg
+            arg.dtype = promote_type
+        return arg
+
+    promote_type = get_promote_dtype(new_args)
+    promote_fn = functools.partial(
+        promote_arg,
+        promote_type=promote_type,
+    )
+    if (
+        all(
+            new_arg.dtype is not None
+            for new_arg in new_args
+            if isinstance(new_arg, CppCSEVariable)
+        )
+        and promote_type
+    ):
+        new_args = list(map(promote_fn, new_args))
+    return new_args
+
+
+class CppCSEVariable(CSEVariable):
+    def __init__(
+        self,
+        name,
+        bounds: ValueRanges[Any],
+        dtype: Optional[torch.dtype] = None,
+        shape: BlockShapeType = None,
+    ) -> None:
+        super().__init__(name, bounds, dtype, shape=shape)
+        self.is_vec = False
+        self.dependent_itervars = OrderedSet[sympy.Symbol]()
+
+    def __repr__(self) -> str:
+        return (
+            f"CppCSEVariable(name: {self.name}, bounds: {self.bounds}, is_vec: {self.is_vec}, dtype: {self.dtype}, "
+            f"dependent_itervars: {self.dependent_itervars})"
+        )
+
+    def update_on_args(self, name, args, kwargs):
+        if name == "load":
+            # args[2] is index
+            self._set_dependent_itervars(args[2])
+        else:
+            # propagate relevant itervars and is_vec from args
+            self.dependent_itervars.update(
+                *[
+                    arg.dependent_itervars
+                    for arg in args
+                    if isinstance(arg, CppCSEVariable)
+                ]
+            )
+            if name == "index_expr":
+                self._set_dependent_itervars(args[0])
+            if any(arg.is_vec for arg in args if isinstance(arg, CppCSEVariable)):
+                self.is_vec = True
+
+    def _set_dependent_itervars(self, index: sympy.Expr):
+        """
+        Set the relevant itervars for this variable based on the `index` expression.
+        This includes the itervars directly used in the `index` as well as relevant itervars
+        of other cse variables used in the `index`.
+        """
+        for s in index.free_symbols:
+            if s in V.kernel.itervars:
+                self.dependent_itervars.add(s)  # type: ignore[arg-type]
+            elif s.name in V.kernel.cse.varname_map:  # type: ignore[attr-defined]
+                self.dependent_itervars.update(
+                    V.kernel.cse.varname_map[s.name].dependent_itervars  # type: ignore[attr-defined]
+                )
+
+    def depends_on(self, itervar: sympy.Symbol):
+        return itervar in self.dependent_itervars
+
+
+class CppPrinter(_CppPrinter):
+    def doprint(self, expr, *, simplify: bool = True, p=True):
+        # TODO: why are people passing strings to the printer here :think:
+        if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"):
+            expr = V.graph.sizevars.simplify(expr)
+        return super().doprint(expr)
+
+    def parenthesize(self, item: sympy.Expr, level: int, strict: bool = False) -> str:
+        if isinstance(item, sympy.Mod):
+            # use parenthesis to enforce precedence.
+            # in sympy 1.13.3, -2*Mod(x,y) becomes -2*x%y, which is wrong.
+            return f"({self._print(item)})"
+        else:
+            return super().parenthesize(item, level, strict)
+
+
+# A function to print, useful for printing sympy symbols.
+cexpr = CppPrinter().doprint
+
+
+def cexpr_index(index):
+    return f"static_cast<{INDEX_TYPE}>({cexpr(index)})"
+
+
+def value_to_cpp(value, cpp_type):
+    if value == float("-inf"):
+        return f"-std::numeric_limits<{cpp_type}>::infinity()"
+    elif value == float("inf"):
+        return f"std::numeric_limits<{cpp_type}>::infinity()"
+    elif isinstance(value, bool):
+        return f"static_cast<{cpp_type}>({str(value).lower()})"
+    elif math.isnan(value):
+        return f"std::numeric_limits<{cpp_type}>::quiet_NaN()"
+    else:
+        return f"static_cast<{cpp_type}>({repr(value)})"
+
+
+def rewrite_index_for_function(
+    localize_buffer_handler: "LocalizeBufferHandler",
+    index: sympy.Expr,
+    global_buf_name: str,
+):
+    # Local buffer at the inner dimensions
+    snode = V.graph.scheduler.name_to_buf[global_buf_name].defining_op
+    assert snode is not None
+    local_buf = localize_buffer_handler.global_to_local[global_buf_name]
+    scheduler_nodes = snode.get_nodes()
+    _, (group, reduction_group) = max(
+        scheduler_nodes, key=lambda x: int(x.is_reduction())
+    ).group
+    call_ranges = tuple(group) + tuple(reduction_group)
+    indices_to_keep = [
+        f"x{len(call_ranges) - (idx + 1)}"
+        for idx in range(len(local_buf.get_layout().size))
+    ]
+    sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name)  # type: ignore[attr-defined]
+    replacements = {}
+    for x in sorted_symbols:
+        if x.name.startswith("x") and x.name not in indices_to_keep:  # type: ignore[attr-defined]
+            # Only keep index used by local buffer
+            replacements[x] = sympy.core.numbers.Zero()
+    index = sympy_subs(index, replacements)  # type: ignore[arg-type]
+    return index
+
+
+def rewrite_index_for_nodes(
+    localize_buffer_handler: "LocalizeBufferHandler",
+    index: sympy.Expr,
+    global_buf_name: str,
+):
+    used_vars = OrderedSet(
+        s for s in index.free_symbols if symbol_is_type(s, SymT.INDEX)
+    )
+    index_vars = []
+    local_buf = localize_buffer_handler.global_to_local[global_buf_name]
+    for i in range(len(local_buf.get_size())):
+        var = sympy_index_symbol_with_prefix(SymT.INDEX, i)
+        index_vars.append(var if var in used_vars else 0)
+    index = local_buf.get_layout().make_indexer()(index_vars)
+    return index
+
+
+class LocalizeBufferHandler(V.WrapperHandler):  # type: ignore[name-defined]
+    def __init__(
+        self,
+        inner,
+        global_to_local: dict[str, ir.Buffer],
+        rewrite_index: Callable[["LocalizeBufferHandler", sympy.Expr, str], sympy.Expr],
+    ) -> None:
+        super().__init__(inner)
+        self.global_to_local = global_to_local
+        self.rewrite_index = rewrite_index
+
+    def localize(self, name: str, index: sympy.Expr):
+        if self.global_to_local and name in self.global_to_local:
+            assert self.rewrite_index is not None
+            index = self.rewrite_index(self, index, name)
+            name = self.global_to_local[name].get_name()
+        return name, index
+
+    def load(self, name: str, index: sympy.Expr):
+        return self._inner.load(*self.localize(name, index))
+
+    def store(self, name, index, value, mode=None):
+        local_buffer_name, local_buffer_index = self.localize(name, index)
+        res = self._inner.store(local_buffer_name, local_buffer_index, value, mode)
+        if (
+            self.global_to_local
+            and name in self.global_to_local
+            and isinstance(V.kernel, Kernel)
+        ):
+            # Remove name of local buffer from Kernel.store_buffer_names
+            # local_buffer_name is added to Kernel.store_buffer_names in Kernel.CSEProxy.store.
+            V.kernel.store_buffer_names.discard(local_buffer_name)
+        return res
+
+    def store_reduction(self, name, index, value):
+        # pyrefly: ignore [bad-argument-count]
+        return self._inner.store_reduction(*self.localize(name, index), value)
+
+
+class LocalBufferContext:
+    """
+    This class creates a context that helps to generate code involving Inductor IR with
+    function local buffers. These buffers are constructed during the codegen process and
+    are used to store intermediate results such as local accumulators. We do not want to
+    add them to `V.graph` since they are not global and we do not want to add them as
+    function arguments either. So we patch the codegen processes under this scope to support
+    these buffers without exposure to the outside world.
+    """
+
+    def __init__(self, kernel_args: KernelArgs) -> None:
+        self.kernel_args = kernel_args
+        self.exit_stack = contextlib.ExitStack()
+        # map local buffer name to local buffer
+        self.local_buffers: dict[str, ir.Buffer] = {}
+        # map global buffer name to global buffer
+        self.global_buffers: dict[str, ir.Buffer] = {}
+        # map global buffer name to local buffer
+        self.global_to_local: dict[str, ir.Buffer] = {}
+        # record the global buffers that are removed by this LocalBufferContext
+        self.removed_buffers: OrderedSet[str] = OrderedSet()
+
+    def __enter__(self):
+        self.exit_stack.__enter__()
+        original_get_dtype = V.graph.get_dtype
+
+        def get_dtype(name):
+            if name in self.local_buffers:
+                return self.local_buffers[name].get_dtype()
+            return original_get_dtype(name)
+
+        self.exit_stack.enter_context(patch.object(V.graph, "get_dtype", get_dtype))
+
+        original_input = self.kernel_args.input
+
+        def input(name):
+            if name in self.local_buffers:
+                return name
+            return original_input(name)
+
+        self.exit_stack.enter_context(patch.object(self.kernel_args, "input", input))
+
+        original_output = self.kernel_args.output
+
+        def output(name):
+            if name in self.local_buffers:
+                return name
+            return original_output(name)
+
+        self.exit_stack.enter_context(patch.object(self.kernel_args, "output", output))
+
+        # Set current LocalBufferContext into V
+        self.exit_stack.enter_context(V.set_local_buffer_context(self))
+
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.local_buffers.clear()
+        self.exit_stack.__exit__(exc_type, exc_val, exc_tb)
+
+    def add_local_buffer(
+        self, local_buffer: ir.Buffer, global_buffers: Optional[list[ir.Buffer]] = None
+    ):
+        assert local_buffer.get_name() not in self.local_buffers
+        self.local_buffers[local_buffer.get_name()] = local_buffer
+        if global_buffers:
+            for global_buffer in global_buffers:
+                global_buffer_name = global_buffer.get_name()
+                assert (
+                    global_buffer_name not in self.global_buffers
+                    and global_buffer_name not in self.global_to_local
+                )
+                self.global_buffers[global_buffer_name] = global_buffer
+                self.global_to_local[global_buffer_name] = local_buffer
+                if global_buffer_name not in V.graph.removed_buffers:
+                    # Record the global buffers that are removed by this LocalBufferContext
+                    # since which may need to restore. Refer to issue:
+                    # https://github.com/pytorch/pytorch/issues/144186
+                    self.removed_buffers.add(global_buffer_name)
+                    V.graph.removed_buffers.add(global_buffer_name)
+
+    def localize_function(
+        self,
+        fn: Callable[..., Any],
+        rewrite_index: Callable[
+            ["LocalizeBufferHandler", sympy.Expr, str], sympy.Expr
+        ] = rewrite_index_for_function,
+    ):
+        def inner(*args, **kwargs):
+            with V.set_ops_handler(
+                LocalizeBufferHandler(
+                    V.get_ops_handler(),
+                    global_to_local=self.global_to_local,
+                    rewrite_index=rewrite_index,
+                )
+            ):
+                return fn(*args, **kwargs)
+
+        return inner
+
+    def localize_nodes(
+        self,
+        nodes: list[ir.IRNode],
+        rewrite_index: Callable[
+            ["LocalizeBufferHandler", sympy.Expr, str], sympy.Expr
+        ] = rewrite_index_for_nodes,
+    ) -> list[ir.IRNode]:
+        """
+        Given `local_buf` and `global_buf` registered in current `LocalBufferContext`
+        though the method of `add_local_buffer`, localizes the `global_buf` to `local_buf`
+        for the given `nodes` and returns a new list of IR nodes that work on `local_buf`
+        instead of `global_buf`, i.e., all the loads and stores are redirected to
+        `local_buf`. This helps the fused loops to work on smaller-sized local buffers
+        for better data locality.
+
+        The data access of `local_buf` is assumed to be contiguous with the
+        same order as the `global_buf`.
+        """
+        assert len(nodes) > 0
+
+        def wrap_inner_fn_for_node(node: ir.IRNode):
+            loops = node.data if isinstance(node, ir.ComputedBuffer) else node
+            assert isinstance(loops, ir.Loops)
+            new_inner_fn = self.localize_function(
+                loops.inner_fn,
+                rewrite_index,
+            )
+
+            new_loops = dataclasses.replace(loops, inner_fn=new_inner_fn)
+            if isinstance(node, ir.ComputedBuffer):
+                new_node = ir.ComputedBuffer(
+                    name=node.get_name(), layout=node.get_layout(), data=new_loops
+                )
+            else:
+                new_node = new_loops  # type: ignore[assignment]
+
+            return new_node
+
+        return [wrap_inner_fn_for_node(node) for node in nodes]
+
+
+def unify_mask_base_type(
+    buffer: IndentedBuffer,
+    vars: tuple[CSEVariable, ...],
+    dtype=torch.float,
+):
+    """
+    Given list of cse variables,
+    Cast each to new mask base dtype and return casted cse variable.
+    """
+    new_vars = (
+        V.kernel.cse.generate(
+            buffer,
+            f"{V.kernel._get_mask_cast(var, dtype)}",
+        )
+        for var in vars
+    )
+    return new_vars
+
+
+def may_unify_binary_op_mask_type(a, b):
+    """
+    Given two cse variables, when dtype is bool, unify them to the same mask dtype and return casted cse variable.
+    """
+    if a.dtype == torch.bool:
+        assert b.dtype == torch.bool
+        mask_dtype = torch.int32
+        return unify_mask_base_type(V.kernel.compute, (a, b), mask_dtype)
+    return a, b
+
+
+def codegen_rand(offset, code, rand_function, dst_dtype=torch.float32):
+    assert is_integer_dtype(offset.dtype)
+    code.writeline("[&]()")
+    with code.indent():
+        code.writeline(
+            f"{DTYPE_TO_CPP[offset.dtype]} offset[{V.kernel.tiling_factor}];"
+        )
+        code.writeline(f"{DTYPE_TO_CPP[dst_dtype]} result[{V.kernel.tiling_factor}];")
+        code.writeline(f"{offset}.store(offset);")
+        code.writeline(
+            f"for( {DTYPE_TO_CPP[offset.dtype]} offset_idx = 0; offset_idx < {V.kernel.tiling_factor}; offset_idx++ )"
+        )
+        with code.indent():
+            code.writeline(rand_function)
+        num_vectors = V.kernel._get_num_vectors(dtype=dst_dtype)
+        if num_vectors == 1:
+            code.writeline(
+                f"return at::vec::Vectorized<{DTYPE_TO_CPP[dst_dtype]}>::loadu(result);"
+            )
+        else:
+            code.writeline(
+                f"return at::vec::VectorizedN<{DTYPE_TO_CPP[dst_dtype]}, {num_vectors}>::loadu(result);"
+            )
+    code.writeline("()")
+    return code
+
+
+def get_gemm_template_output_and_compute_dtype(input_dtype):
+    if input_dtype in [torch.uint8, torch.int8]:
+        return (torch.int32, torch.int32)
+    else:
+        return (torch.float32, torch.float32)
+
+
+def create_epilogue_with_attr(input_buffer, attr, **kwargs):
+    input_loader = input_buffer.make_loader()
+    dtype = input_buffer.get_dtype()
+    if attr == "relu":
+
+        def inner_fn(index):
+            input = input_loader(index)
+            zero = ops.constant(0, dtype)
+            return ops.maximum(input, zero)
+
+    elif attr == "gelu":
+        assert "algorithm" in kwargs
+        if kwargs["algorithm"] == "none":
+
+            def inner_fn(index):
+                input = input_loader(index)
+                if dtype != torch.float:
+                    input = ops.to_dtype(input, torch.float)
+                half = ops.constant(0.5, torch.float)
+                one = ops.constant(1.0, torch.float)
+                const = ops.constant(0.7071067811865476, torch.float)
+                result = input * half * (ops.erf(input * const) + one)
+                if dtype != torch.float:
+                    result = ops.to_dtype(result, dtype)
+                return result
+
+        else:
+            assert kwargs["algorithm"] == "tanh"
+
+            def inner_fn(index):
+                input = input_loader(index)
+                if dtype != torch.float:
+                    input = ops.to_dtype(input, torch.float)
+                half = ops.constant(0.5, torch.float)
+                one = ops.constant(1.0, torch.float)
+                const1 = ops.constant(0.7978845608028654, torch.float)
+                const2 = ops.constant(0.044715, torch.float)
+                result = (
+                    half
+                    * input
+                    * (
+                        one
+                        + ops.tanh(const1 * (input + const2 * input * input * input))
+                    )
+                )
+                if dtype != torch.float:
+                    result = ops.to_dtype(result, dtype)
+                return result
+
+    elif attr == "swish":
+
+        def inner_fn(index):
+            input = input_loader(index)
+            result = input * ops.sigmoid(input)
+            return result
+
+    elif attr == "sigmoid":
+
+        def inner_fn(index):
+            return ops.sigmoid(input_loader(index))
+
+    elif attr == "tanh":
+
+        def inner_fn(index):
+            return ops.tanh(input_loader(index))
+
+    elif attr == "hardswish" or attr == "hardsigmoid":
+
+        def hardsigmoid_float(input):
+            zero = ops.constant(0, torch.float)
+            six = ops.constant(6, torch.float)
+            three = ops.constant(3, torch.float)
+            one_over_six = ops.constant(0.16666666666666666, torch.float)
+            max = ops.maximum(input + three, zero)
+            min = ops.minimum(max, six)
+            return min * one_over_six
+
+        def inner_fn(index):
+            input = input_loader(index)
+            if dtype != torch.float:
+                input = ops.to_dtype(input, torch.float)
+            result = hardsigmoid_float(input)
+            if attr == "hardswish":
+                result = input * result
+            if dtype != torch.float:
+                result = ops.to_dtype(result, dtype)
+            return result
+
+    elif attr == "leaky_relu":
+        assert "scalars" in kwargs
+        assert len(kwargs["scalars"]) == 1
+        negative_slope = kwargs["scalars"][0]
+
+        def inner_fn(index):
+            input = input_loader(index)
+            if dtype != torch.float:
+                input = ops.to_dtype(input, torch.float)
+            zero = ops.constant(0, torch.float)
+            result = ops.where(
+                input > zero, input, input * ops.constant(negative_slope, torch.float)
+            )
+            if dtype != torch.float:
+                result = ops.to_dtype(result, dtype)
+            return result
+
+    elif attr == "hardtanh":
+        assert "scalars" in kwargs
+        assert len(kwargs["scalars"]) == 2
+        min_value = kwargs["scalars"][0]
+        max_value = kwargs["scalars"][1]
+
+        def inner_fn(index):
+            input = input_loader(index)
+            if dtype != torch.float:
+                input = ops.to_dtype(input, torch.float)
+            result = ops.minimum(
+                ops.maximum(input, ops.constant(min_value, torch.float)),
+                ops.constant(max_value, torch.float),
+            )
+            if dtype != torch.float:
+                result = ops.to_dtype(result, dtype)
+            return result
+
+    elif attr in ["add", "sub", "mul"]:
+        assert "other" in kwargs
+        other = kwargs["other"]
+        num_input_dims = len(input_buffer.get_size())
+        num_other_dims = len(other.get_size())
+        dims_diff = num_input_dims - num_other_dims
+        other_loader = other.make_loader()
+
+        def inner_fn(index):
+            op = getattr(ops, attr)
+            if dims_diff != 0:
+                return op(input_loader(index), other_loader(index[dims_diff:]))
+            else:
+                return op(input_loader(index), other_loader(index))
+
+    elif attr == "bias_add":
+        assert "other" in kwargs
+        assert "beta" in kwargs
+        assert "dtype" in kwargs
+        beta = kwargs["beta"]
+        other = kwargs["other"]
+        dtype = kwargs["dtype"]
+        bias_loader = other.make_loader()
+
+        def inner_fn(index):
+            bias = bias_loader(index)
+            input = input_loader(index)
+            if beta != 1:
+                result = ops.constant(beta, torch.float) * bias + input
+            else:
+                result = bias + input
+            return result
+
+    else:
+        raise ValueError(f"Unsupported epilogue attribute: {attr}")
+    return ir.Pointwise(
+        device=input_buffer.get_device(),
+        dtype=dtype,
+        inner_fn=inner_fn,
+        ranges=input_buffer.get_size(),
+    )
+
+
+def _get_loop_body(fn_list):
+    if all(isinstance(fn, LoopBody) for fn in fn_list):
+        loop_bodies = fn_list
+    else:
+        if hasattr(fn_list[0], "original_fn"):
+            # For the case of local buffer, we wrap the fn with localize_function
+            assert all(hasattr(fn, "original_fn") for fn in fn_list)
+            assert all(
+                isinstance(fn.original_fn.args[0]._body, LoopBody) for fn in fn_list
+            )
+            loop_bodies = [fn.original_fn.args[0]._body for fn in fn_list]
+        else:
+            assert all(isinstance(fn, functools.partial) for fn in fn_list)
+            assert all(isinstance(fn.args[0]._body, LoopBody) for fn in fn_list)
+            loop_bodies = [fn.args[0]._body for fn in fn_list]
+    assert loop_bodies is not None
+    return loop_bodies
+
+
+def _get_dtype_from_loopbodies(loop_bodies):
+    dtypes = OrderedSet[torch.dtype]()
+    for loop_body in loop_bodies:
+        graphs = [loop_body.root_block.graph] + [
+            body.graph for body in list(loop_body.subblocks.values())
+        ]
+        for graph in graphs:
+            for node in graph.nodes:
+                if node.op != "call_method":
+                    continue
+                dtypes.add(node.meta[OptimizationContext.key].dtype)
+    return dtypes
+
+
+def template_fusion_with_epilogues_supported(
+    template: BaseSchedulerNode, epilogues: list[BaseSchedulerNode]
+) -> tuple[bool, bool]:
+    def _get_indexes_of_template_buf_read(
+        epilogue_node: ir.Operation, template_buf_names: list[str]
+    ) -> list[sympy.Expr]:
+        return [
+            read.index
+            for read in epilogue_node.get_reads()
+            if read.name in template_buf_names
+        ]
+
+    def _check_supported_and_same_indexes(
+        index_of_template_buf_read: Sequence[sympy.Expr],
+        epilogue_writes: OrderedSet[Dep],
+    ) -> tuple[bool, bool]:
+        num_indexes = len(OrderedSet(index_of_template_buf_read))
+
+        if num_indexes > 1:
+            same_index = False
+            supported = False  # Different read indexes not supported
+        elif num_indexes == 0:
+            same_index = True
+            supported = True  # No reads, automatically supported
+        elif num_indexes == 1:
+            iotbr = index_of_template_buf_read[0]
+            same_index = all(write.index == iotbr for write in epilogue_writes)
+            # TODO: Add support of fusion when the read of template buffer and the write of epilogue output
+            # in the epilogue node don't have the same index and change supported to True
+            supported = same_index
+        else:
+            raise AssertionError("Should not reach here")
+
+        return supported, same_index
+
+    def _template_fusion_supported(
+        template_outputs: Sequence[SchedulerBuffer], epilogue_nodes: list[ir.Operation]
+    ) -> tuple[bool, bool]:
+        template_buf_names = [x.get_name() for x in template_outputs]
+        indexes_of_template_buf_reads = [
+            _get_indexes_of_template_buf_read(epilogue_node, template_buf_names)
+            for epilogue_node in epilogue_nodes
+        ]
+        epilogue_nodes_writes = [
+            epilogue_node.get_read_writes().writes for epilogue_node in epilogue_nodes
+        ]
+
+        results = [
+            _check_supported_and_same_indexes(reads, writes)
+            for reads, writes in zip(
+                indexes_of_template_buf_reads, epilogue_nodes_writes
+            )
+        ]
+        supported, same_indexes = zip(*results)
+        return all(supported), all(same_indexes)
+
+    assert template.is_template()
+    template_outputs = template.get_outputs()
+
+    epilogue_nodes = [
+        n.node
+        for epilogue in epilogues
+        for n in epilogue.get_nodes()
+        if n.node is not None
+    ]
+    return _template_fusion_supported(template_outputs, epilogue_nodes)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..16522d9832ec0e9e8ce7686fe5537e3c4a647410
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu.py
@@ -0,0 +1,3010 @@
+# mypy: allow-untyped-defs
+from __future__ import annotations
+
+import ctypes
+import functools
+import math
+import os
+import sys
+import textwrap
+from itertools import chain, count
+from typing import Any, Optional, Protocol, TYPE_CHECKING, Union
+
+import sympy
+
+import torch
+import torch._higher_order_ops.torchbind
+import torch._inductor.async_compile  # noqa: F401 required to warm up AsyncCompile pools
+import torch._ops
+from torch._inductor.runtime.runtime_utils import dynamo_timed
+from torch.fx.experimental.symbolic_shapes import ConvertIntKey, DivideByKey, SymTypes
+from torch.utils._ordered_set import OrderedSet
+from torch.utils._sympy.symbol import symbol_is_type, SymT
+
+from .. import config, cpp_builder, ir
+from ..ir import ExternKernel
+from ..utils import _align, DeferredLineBase, LineContext, normalize_name
+from ..virtualized import V
+from .aoti_hipify_utils import maybe_hipify_code_wrapper
+from .common import get_device_op_overrides, IndentedBuffer, Kernel
+from .cpp_utils import cexpr, DEVICE_TO_ATEN, DEVICE_TO_INT, DTYPE_TO_ATEN, DTYPE_TO_CPP
+from .wrapper import (
+    codegen_reinterpret_view_helper,
+    EnterSubgraphLine,
+    ExitSubgraphLine,
+    PythonWrapperCodegen,
+    SymbolicCallArg,
+)
+
+
+if TYPE_CHECKING:
+    from collections.abc import Callable, Sequence
+
+    from ..graph import GraphLowering
+
+    # At most, the list nesting can go one layer deep.
+    _OUTPUT_ARGS_TYPE = list[Union[Optional[str], list[Optional[str]]]]
+
+    from ..scheduler import BaseSchedulerNode
+
+
+class HasWriteLine(Protocol):
+    def writeline(self, line: Union[LineContext, DeferredLineBase, str]) -> None: ...
+
+
+class CppWrapperCpu(PythonWrapperCodegen):
+    """
+    Generates cpp wrapper for running on CPU and calls cpp kernels
+    """
+
+    def __init__(self):
+        if not hasattr(self, "device"):
+            self.device = "cpu"
+        # must be initialized prior to calling super().__init__()
+        self.included_devices: OrderedSet[str] = OrderedSet()
+        self.model_class_name_suffix = (
+            ""
+            if config.aot_inductor.dynamic_linkage
+            else config.aot_inductor.model_name_for_generated_files
+        )
+        self.aoti_model_class_name = f"AOTInductorModel{self.model_class_name_suffix}"
+
+        super().__init__()
+
+        self.declare = "auto "
+        self.declare_maybe_reference = "decltype(auto) "
+        self.ending = ";"
+        self.comment = "//"
+        self.none_str = "nullptr"
+        self.supports_intermediate_hooks = False
+        self.kernel_callsite_id = count()
+        self.int_array_id = count()  # for int array local variable declarations
+        self.declared_int_array_vars: OrderedSet[str] = OrderedSet()
+        self.tmp_tensor_id = count()  # for tmp tensor local variable declarations
+        self.arg_var_id = count()
+        self.used_cached_devices: OrderedSet[str] = OrderedSet()
+        self.used_cached_dtypes: OrderedSet[str] = OrderedSet()
+        self.used_cached_layouts: OrderedSet[str] = OrderedSet()
+        self.used_cached_memory_formats: OrderedSet[str] = OrderedSet()
+        self.used_cond_predicate: OrderedSet[str] = OrderedSet()
+        self.cached_output_id = count()
+        self.scalar_to_tensor_id = count()
+        self.custom_op_wrapper_loaded = False
+        # For GEMM kernels that must be initialized and are resolved at linking.
+        self.initialized_kernels: dict[str, Kernel] = {}
+        self.device_codegen = get_device_op_overrides(self.device)
+        # only need to include each header once
+        self.include_extra_header = functools.lru_cache(None)(  # type: ignore[method-assign]
+            self._include_extra_header
+        )
+        self.codegen_int_array_var_cache = {}
+
+    @staticmethod
+    def create(
+        is_subgraph: bool,
+        subgraph_name: Optional[str],
+        parent_wrapper: Optional[PythonWrapperCodegen],
+        partition_signatures: Optional[ir.GraphPartitionSignature] = None,
+    ):
+        # TODO - support subgraph codegen by lifting functions. Check the
+        # comment at CppWrapperCpu `codegen_subgraph` function.
+        return CppWrapperCpu()
+
+    @staticmethod
+    def _generate_temporary_array_pointer(
+        c_type: str, elements: Sequence[str], *, force_mutable: bool = False
+    ) -> str:
+        """Get a pointer to an array that only exists for the duration of the C++
+        statement it's used in."""
+        # If the c_type is already a pointer, return a mutable pointer to the array.
+        # Otherwise, return a const pointer.  In the C-shim API, pointer types are only
+        # const-qualified with respect to the underlying value, not any nested pointers.
+        # e.g. const double** is possible, but not const double* const*.  This means
+        # that an array containing pointers must _already_ be properly const-qualified
+        # by the c_type, and not add additional const-ness.
+        # MSVC does not support implicitly converting a const iterator to a const pointer.
+        ptr_call = (
+            "data()"
+            if force_mutable or c_type.endswith("*") or cpp_builder.is_msvc_cl()
+            else "cbegin()"
+        )
+        return (
+            f"std::array<{c_type}, {len(elements)}>{{{', '.join(elements)}}}.{ptr_call}"
+        )
+
+    def _generate_kernel_call_helper(
+        self,
+        kernel_name: str,
+        call_args,
+        *,
+        device=None,
+        triton=True,
+        arg_types=None,
+        raw_keys=None,
+        raw_args=None,
+        triton_meta=None,
+        graph_name="",
+        original_fxnode_name=None,
+    ):
+        """
+        Generates kernel call code.
+
+        triton: Defines whether the GPU backend uses Triton for codegen.
+                Otherwise it uses the CUDA language for codegen.
+                Only valid when cuda == True.
+        """
+        assert arg_types is not None and len(call_args) == len(arg_types), (
+            "Mismatch call_args and arg_types in generate_kernel_call:\n"
+            f"call_args: {call_args}\n"
+            f"arg_types: {arg_types}"
+        )
+        new_args = []
+        for idx, arg in enumerate(call_args):
+            if isinstance(arg_types[idx], str) and "*" in arg_types[idx]:
+                new_args.append(f"({arg_types[idx]})({arg}.data_ptr())")
+            else:
+                # arg is a scalar - ensure it's a string for C++ codegen
+                # With Triton support, arg might be a SymPy expression or other type
+                new_args.append(str(arg) if not isinstance(arg, str) else arg)
+        # debug printer related logic for cpp kernel type.
+        debug_printer_manager = V.graph.wrapper_code.debug_printer
+        debug_printer_manager.set_printer_args(
+            call_args,
+            kernel_name,
+            None,
+            None,
+            "cpp",
+        )
+        with debug_printer_manager:
+            self.writeline(self.wrap_kernel_call(kernel_name, new_args))
+
+    def write_constant(self, name, hashed):
+        # include a hash so our code cache gives different constants different files
+        self.header.writeline(f"// {name} {hashed}")
+
+    @staticmethod
+    def get_device_include_path(device: str) -> str:
+        if V.graph.aot_mode:
+            return f"#include "
+        return f"#include "
+
+    def add_device_include(self, device: str) -> None:
+        if device in self.included_devices:
+            return
+
+        self.included_devices.add(device)
+
+        # Add the default header for this device, plus any C-shim extensions that are
+        # present.
+        self.header.splice(self.get_device_include_path(device))
+        extend_aoti_c_shim_include = (
+            f"torch/csrc/inductor/aoti_torch/generated/extend/c_shim_{self.device}.h"
+        )
+        extend_aoti_c_shim_path = os.path.join(
+            os.path.dirname(torch.__file__),
+            "include",
+            extend_aoti_c_shim_include,
+        )
+        if os.path.exists(extend_aoti_c_shim_path):
+            self.header.splice(f"#include <{extend_aoti_c_shim_include}>")
+
+    def write_header(self):
+        if V.graph.is_const_graph:
+            # We do not write header for constant graph, it will be written by main module.
+            return
+
+        if not V.graph.aot_mode:
+            self.header.splice(
+                """
+                import torch
+                from torch._inductor.codecache import CppWrapperCodeCache
+
+                cpp_wrapper_src = (
+                r'''
+                """
+            )
+
+        for device in V.graph.device_types:
+            if device != "meta":
+                self.add_device_include(device)
+
+        if V.graph.aot_mode:
+            if config.aot_inductor.dynamic_linkage:
+                with open(
+                    os.path.join(
+                        os.path.dirname(__file__), "aoti_runtime", "interface.cpp"
+                    )
+                ) as f:
+                    self.header.splice(f.read())
+            else:
+                # we produce a separate model header for each model in static linkage
+                self.header.splice(f"""#include \"{self.model_class_name_suffix}.h\"""")
+            self.header.splice("\n")
+
+        if config.cpp.enable_kernel_profile:
+            self.header.splice(
+                "#include "
+            )
+            self.header.splice(
+                """
+                namespace torch::aot_inductor {
+                thread_local KernelContext* tls_kernel_context = nullptr;
+                }
+                """
+            )
+
+    def _include_extra_header(self, header: str):
+        # This is needed for cpp to python dtype conversion
+        self.header.splice(f"#include <{header}>")
+
+    def mark_output_type(self):
+        # mark output type to unwrap tensor back to python scalar
+        from ..ir import ShapeAsConstantBuffer
+
+        output_is_tensor = {}
+        for idx, x in enumerate(V.graph.graph_outputs):
+            if isinstance(x, ShapeAsConstantBuffer):
+                output_is_tensor[idx] = False
+            else:
+                output_is_tensor[idx] = True
+
+        self.output_is_tensor = output_is_tensor
+
+    def write_prefix(self):
+        if V.graph.is_const_graph:
+            # We do not write prefix for constant graph, it will be written by main module.
+            return
+        if config.aot_inductor.custom_ops_to_c_shims:
+            # custom_ops_to_c_shims contains declaration of custom ops with C shim.
+            # TODO: this could be auto-generated from a passed-in custom op schema
+            custom_c_shims = list(
+                chain(*config.aot_inductor.custom_ops_to_c_shims.values())
+            )
+            declarations = "\n".join(
+                [f"extern {textwrap.dedent(shim)};" for shim in custom_c_shims]
+            )
+            self.prefix.splice(
+                f"""
+                extern "C" {{
+                    {declarations}
+                }}
+                """
+            )
+        if V.graph.aot_mode:
+            self.prefix.writeline("namespace torch::aot_inductor {")
+
+    def write_input_output_info(
+        self,
+        info_kind: str,
+        idx: int,
+        name: str,
+    ):
+        self.prefix.writeline(f"""{info_kind}[{idx}].name = "{name}";""")
+
+    def codegen_input_symbol_assignment(
+        self,
+        name: str,
+        value: ir.TensorBox,
+        bound_vars: OrderedSet[sympy.Symbol],
+    ):
+        code = self.prefix
+
+        @functools.cache
+        def sizeof(name):
+            self.codegen_input_size_var_decl(code, name)
+            return f"{name}_size"
+
+        @functools.cache
+        def strideof(name):
+            self.codegen_input_stride_var_decl(code, name)
+            return f"{name}_stride"
+
+        def codegen_symbol(
+            sym_or_exp: Union[sympy.Symbol, sympy.Expr],
+            base_name: str,
+            name_fn: Callable[[str], str],
+            dim: int,
+        ):
+            if isinstance(sym_or_exp, sympy.Symbol):
+                if sym_or_exp in bound_vars:
+                    return
+                code.writeline(f"int64_t {sym_or_exp} = {name_fn(base_name)}[{dim}];")
+                bound_vars.add(sym_or_exp)
+            elif isinstance(sym_or_exp, sympy.Expr):
+                undefined_symbols = [
+                    sym for sym in sym_or_exp.free_symbols if sym not in bound_vars
+                ]
+                if len(undefined_symbols) != 1:
+                    # Skip if expression contains no symbols or if multiple
+                    # symbols exists since we assume each base symbol is defined
+                    # by other codegen_symbol calls.
+                    return
+
+                from torch.utils._sympy.solve import try_solve
+
+                free_symbol = undefined_symbols.pop()
+                base_name = name_fn(base_name)
+                # Use a size symbol to solve the free symbol
+                size_symbol = sympy.Symbol(f"{base_name}_{dim}", integer=True)
+                code.writeline(f"int64_t {size_symbol} = {base_name}[{dim}];")
+                solution = try_solve(sympy.Eq(sym_or_exp, size_symbol), free_symbol)
+                if solution is not None:
+                    code.writeline(f"int64_t {free_symbol} = {cexpr(solution[1])};")
+                    bound_vars.add(free_symbol)
+                else:
+                    raise AssertionError(
+                        str(sympy.Eq(sym_or_exp, size_symbol)) + " is not solvable"
+                    )
+
+        if isinstance(value, sympy.Expr):
+            if not isinstance(value, sympy.Symbol) or value in bound_vars:
+                return
+            if value.is_integer:
+                decl = "int64_t"
+            elif value.is_float:
+                decl = "double"
+            else:
+                raise AssertionError("Unexpected symbol type")
+            code.writeline(f"{decl} {value} = {name};")
+            bound_vars.add(value)
+        elif isinstance(value, ir.TensorBox):
+            for dim, size in enumerate(value.get_size()):
+                codegen_symbol(size, name, sizeof, dim)
+            for dim, stride in enumerate(value.get_stride()):
+                codegen_symbol(stride, name, strideof, dim)
+        elif isinstance(value, ir.TorchBindObject):
+            # torchbind objects are loaded in proxy executor
+            pass
+        else:
+            raise AssertionError(f"Unknown value type: {type(value)}")
+
+    def generate_input_output_runtime_checks(self):
+        """
+        In debug_compile mode, we generate checks to ensure the dtype/shape/stride/device of each
+        real input/output tensor match ones provided at compile time via sample
+        input/output.
+        """
+
+        def gen_check(handle_kind, idx, name, tensor):
+            # Wrap AtenTensorHandle with ConstantHandle for cleaner utility function access
+            self.prefix.writeline(
+                f"ConstantHandle {name} = ConstantHandle({handle_kind}[{idx}]);"
+            )
+            self.codegen_tensor_dtype_var_decl(self.prefix, name)
+            expected_dtype_name = DTYPE_TO_ATEN[tensor.dtype]
+            dtype_str = str(tensor.dtype).split(".")[-1]
+            self.prefix.splice(
+                f"""
+                    int32_t {name}_expected_dtype = aoti_torch_dtype_{dtype_str}();
+                    if ({name}_expected_dtype != {name}_dtype) {{
+                        std::stringstream ss;
+                        ss << "{handle_kind}[{idx}]: unmatched dtype, "
+                           << "expected: " << {name}_expected_dtype << "({expected_dtype_name}), "
+                           << "but got: " << {name}_dtype << "\\n";
+                        throw std::runtime_error(ss.str());
+                    }}
+                """
+            )
+            self.codegen_input_size_var_decl(self.prefix, name)
+            for dim_idx, d in enumerate(tensor.get_size()):
+                if isinstance(d, (int, sympy.Integer)):
+                    self.prefix.splice(
+                        f"""
+                            if ({d} != {name}_size[{dim_idx}]) {{
+                                std::stringstream ss;
+                                ss << "{handle_kind}[{idx}]: unmatched dim value at {dim_idx}, "
+                                   << "expected: {d}, " << "but got: " << {name}_size[{dim_idx}]
+                                   << "\\n";
+                                throw std::runtime_error(ss.str());
+                            }}
+                        """
+                    )
+                else:
+                    from torch.utils._sympy.value_ranges import bound_sympy
+
+                    sym_range = bound_sympy(d, V.graph.sizevars.shape_env.var_to_range)
+                    if config.aot_inductor.check_lowerbound and not math.isinf(
+                        sym_range.lower
+                    ):
+                        self.prefix.splice(
+                            f"""
+                                if ({name}_size[{dim_idx}] < {sym_range.lower}) {{
+                                    std::stringstream ss;
+                                    ss << "{handle_kind}[{idx}]: dim value is too small at {dim_idx}, "
+                                       << "expected it to be >= {sym_range.lower}, " << "but got: "
+                                       << {name}_size[{dim_idx}] << "\\n";
+                                    throw std::runtime_error(ss.str());
+                                }}
+                            """
+                        )
+                    if not math.isinf(sym_range.upper):
+                        # Limit upper bound to max C long long value (2^63 - 1)
+                        max_long_long = ctypes.c_longlong(2**63 - 1).value
+                        upper_bound = min(sym_range.upper, max_long_long)
+                        self.prefix.splice(
+                            f"""
+                                if ({name}_size[{dim_idx}] > {upper_bound}) {{
+                                    std::stringstream ss;
+                                    ss << "{handle_kind}[{idx}]: dim value is too large at {dim_idx}, "
+                                       << "expected to be <= {upper_bound}, " << "but got: "
+                                       << {name}_size[{dim_idx}] << "\\n";
+                                    throw std::runtime_error(ss.str());
+                                }}
+                            """
+                        )
+
+            self.codegen_input_stride_var_decl(self.prefix, name)
+            for stride_idx, s in enumerate(tensor.get_stride()):
+                if not isinstance(s, (int, sympy.Integer)):
+                    continue
+                self.prefix.splice(
+                    f"""
+                        if ({s} != {name}_stride[{stride_idx}]) {{
+                            std::stringstream ss;
+                            ss << "{handle_kind}[{idx}]: unmatched stride value at {stride_idx}, "
+                               << "expected: {s}, " << "but got: " << {name}_stride[{stride_idx}]
+                               << "\\n";
+                            throw std::runtime_error(ss.str());
+                        }}
+                    """
+                )
+
+            # check input device type
+            if isinstance(tensor, ir.TensorBox):
+                tensor_device = tensor.get_device()
+                if tensor_device is not None:
+                    expected_device_type = DEVICE_TO_INT.get(tensor_device.type)
+                    if expected_device_type is not None:
+                        self.codegen_input_device_type_var_decl(self.prefix, name)
+                        device_type_str = str(tensor_device.type)
+                        self.prefix.splice(
+                            f"""
+                                int32_t {name}_expected_device_type = {expected_device_type};
+                                if ({name}_expected_device_type != {name}_device_type) {{
+                                    std::stringstream ss;
+                                    ss << "{handle_kind}[{idx}]: unmatched device type, "
+                                    << "expected: " << {name}_expected_device_type << "{expected_device_type}({device_type_str}), "
+                                    << "but got: " << {name}_device_type << "\\n";
+                                    throw std::runtime_error(ss.str());
+                                }}
+                            """
+                        )
+
+        # Create a separate function for each input check to avoid "too big to optimize" error
+        for idx, (name, tensor) in enumerate(V.graph.graph_inputs.items()):
+            self.prefix.splice(
+                f"""
+                AOTI_NOINLINE static void check_input_{idx}(
+                    AtenTensorHandle* input_handles
+                ) {{
+                """
+            )
+            with self.prefix.indent():
+                gen_check("input_handles", idx, name, tensor)
+            self.prefix.writeline("}")
+
+        # force noinline to avoid any potential compilation slowdown due to aggressive
+        # inline done by the host compiler
+        self.prefix.splice(
+            """
+            static bool _check_aoti_runtime_check_inputs_env() {
+                const static char* env_var_value = getenv("AOTI_RUNTIME_CHECK_INPUTS");
+                const static bool result = env_var_value != nullptr && env_var_value[0] != '0';
+                return result;
+            }
+
+            AOTI_NOINLINE static void __check_inputs_outputs(
+                AtenTensorHandle* input_handles,
+                AtenTensorHandle* output_handles) {
+                if (!_check_aoti_runtime_check_inputs_env()){
+                    return;
+                }
+            """
+        )
+        with self.prefix.indent():
+            for idx in range(len(V.graph.graph_inputs)):
+                self.prefix.writeline(f"check_input_{idx}(input_handles);")
+        self.prefix.writeline("}")
+
+    def write_wrapper_decl(self):
+        inputs_len = len(V.graph.graph_inputs.keys())
+        if V.graph.aot_mode:
+            self.codegen_additional_funcs()
+
+            if V.graph.const_module:
+                self.header.splice(V.graph.const_module.wrapper_code.header)
+
+                assert V.graph.const_wrapper_code is not None
+                self.prefix.splice(V.graph.const_wrapper_code)
+
+                assert V.graph.const_kernel_code is not None
+                self.kernel_declarations.splice(V.graph.const_kernel_code)
+
+            if V.graph.is_const_graph:
+                self.prefix.splice(
+                    f"""
+                    void {self.aoti_model_class_name}::_const_run_impl(
+                        std::vector& output_handles,
+                        DeviceStreamType stream,
+                        AOTIProxyExecutorHandle proxy_executor
+                    ) {{
+                    """
+                )
+            else:
+                if not config.aot_inductor.use_runtime_constant_folding:
+                    # If we do not split the constant graph, we'll just create
+                    # an empty implementation when wrapping the main module.
+                    self.prefix.splice(
+                        f"""
+                        void {self.aoti_model_class_name}::_const_run_impl(
+                            std::vector& output_handles,
+                            DeviceStreamType stream,
+                            AOTIProxyExecutorHandle proxy_executor
+                        ) {{}}
+
+                        """
+                    )
+
+                run_impl_proto = f"""
+                    void {self.aoti_model_class_name}::run_impl(
+                        AtenTensorHandle*
+                            input_handles, // array of input AtenTensorHandle; handles
+                                            // are stolen; the array itself is borrowed
+                        AtenTensorHandle*
+                            output_handles, // array for writing output AtenTensorHandle; handles
+                                            // will be stolen by the caller; the array itself is
+                                            // borrowed
+                        DeviceStreamType stream,
+                        AOTIProxyExecutorHandle proxy_executor
+                    ) {{
+                        __check_inputs_outputs(input_handles, output_handles);
+                    """
+
+                self.generate_input_output_runtime_checks()
+                self.prefix.splice(run_impl_proto)
+        else:
+            # cpp entry function for JIT with cpp wrapper
+            self.prefix.splice(
+                """
+                void inductor_entry_impl(
+                    AtenTensorHandle*
+                        input_handles, // array of input AtenTensorHandle; handles
+                                        // are stolen; the array itself is borrowed
+                    AtenTensorHandle*
+                        output_handles  // array for writing output AtenTensorHandle; handles
+                                        // will be stolen by the caller; the array itself is
+                                        // borrowed)
+                ) {
+                """
+            )
+        with self.prefix.indent():
+            # assign inputs and outputs in both cases so the later codegen can be simplified
+            if not V.graph.is_const_graph:
+                if V.graph.aot_mode:
+                    num_args = len(V.graph.graph_inputs)
+                else:
+                    # Weights are promoted in the JIT mode
+                    num_args = len(V.graph.graph_inputs) + len(V.graph.constants)
+                    # release GIL to support multiple instances inference (in different threads of the same process)
+                    self.prefix.splice("py::gil_scoped_release_simple release;")
+
+                self.prefix.splice(
+                    f"""
+                        auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, {num_args});
+                    """
+                )
+
+            if inputs_len != 0:
+                for idx, input_key in enumerate(V.graph.graph_inputs.keys()):
+                    # unwrap input tensor back to scalar
+                    if isinstance(V.graph.graph_inputs[input_key], sympy.Expr):
+                        from ..graph import may_get_constant_buffer_dtype
+
+                        dtype = may_get_constant_buffer_dtype(
+                            V.graph.graph_inputs[input_key]  # type: ignore[arg-type]
+                        )
+                        assert dtype is not None, (
+                            "Fails to get the dtype of the sympy.Expr"
+                        )
+                        self.codegen_tensor_item(
+                            dtype, f"inputs[{idx}]", input_key, self.prefix
+                        )
+                    else:
+                        self.prefix.writeline(
+                            f"auto {input_key} = std::move(inputs[{idx}]);"
+                        )
+                # debug printing for all input args to AOTI model
+                debug_printer_manager = V.graph.wrapper_code.debug_printer
+                debug_printer_manager.codegen_model_inputs_value_print(
+                    input_args_to_print=[
+                        input_key
+                        for input_key in V.graph.graph_inputs
+                        if input_key.startswith("arg")
+                    ]
+                )
+
+            assert all(
+                isinstance(v, torch.Tensor) for v in list(V.graph.constants.values())
+            ), "Expect all constants to be Tensor"
+            for idx, constants_key in enumerate(V.graph.constants.keys()):
+                if V.graph.aot_mode:
+                    # Weights are stored in constants_ and owned by ConstantHandle there.
+                    # Don't call std::move here because it will cause constants_ to lose the ownership.
+                    self.prefix.writeline(
+                        f"""[[maybe_unused]] auto& {constants_key} = constants_->at({idx});"""
+                    )
+                else:
+                    # Append constants as inputs to the graph
+                    constants_idx = inputs_len + idx
+                    self.prefix.writeline(
+                        f"[[maybe_unused]] auto {constants_key} = std::move(inputs[{constants_idx}]);"
+                    )
+
+            self.codegen_inputs()
+
+            if V.graph.aot_mode:
+                if not V.graph.is_const_graph:
+                    self.prefix.writeline("inputs.clear();")
+                self.prefix.writeline(
+                    "[[maybe_unused]] auto& kernels = static_cast(*this->kernels_.get());"
+                )
+
+    def codegen_tensor_dtype_var_decl(self, code: IndentedBuffer, name):
+        code.writeline(f"int32_t {name}_dtype;")
+        code.writeline(
+            f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype({name}, &{name}_dtype));"
+        )
+
+    def codegen_input_size_var_decl(self, code: IndentedBuffer, name):
+        code.writeline(f"auto {name}_size = {name}.sizes();")
+
+    def codegen_input_stride_var_decl(self, code: IndentedBuffer, name):
+        code.writeline(f"auto {name}_stride = {name}.strides();")
+
+    def codegen_input_device_type_var_decl(self, code: IndentedBuffer, name):
+        code.writeline(f"int32_t {name}_device_type;")
+        code.writeline(
+            f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type({name}, &{name}_device_type));"
+        )
+
+    def codegen_additional_funcs(self):
+        pass
+
+    def codegen_model_kernels(self):
+        self.prefix.writeline("namespace {")
+
+        # Tell compiler we need to link with the non-mangled symbols
+        for kernel in self.initialized_kernels.values():
+            assert hasattr(kernel, "get_signature"), (
+                f"{kernel} must have get_signature implemented"
+            )
+            signature = kernel.get_signature()
+            self.prefix.writeline(f'extern "C" {signature};')
+
+        self.prefix.writeline(
+            "class AOTInductorModelKernels : public AOTInductorModelKernelsBase {"
+        )
+        self.prefix.writeline("  public:")
+        declare_kernel = OrderedSet(self.src_to_kernel.values()) - OrderedSet(
+            self.initialized_kernels.keys()
+        )
+        declare_kernel.update(
+            entry[0] for entry in self.user_defined_kernel_cache.values()
+        )
+        if V.graph.const_module:
+            declare_kernel.update(
+                V.graph.const_module.wrapper_code.src_to_kernel.values()
+            )
+        for kernel in sorted(declare_kernel):
+            self.prefix.writeline(
+                maybe_hipify_code_wrapper(
+                    f"    {self.device_codegen.cpp_kernel_type()} {kernel}{{nullptr}};"
+                )
+            )
+        for name, kernel in self.initialized_kernels.items():
+            assert hasattr(kernel, "get_signature"), (
+                f"{kernel} must have get_signature implemented"
+            )
+            kernel_ptr = f"(*{name})"
+            signature = kernel.get_signature().replace(name, kernel_ptr)
+            self.prefix.writeline(f"    {signature} = torch::aot_inductor::{name};")
+        self.prefix.writeline("};")
+        self.prefix.writeline("}  // namespace\n\n")
+
+        if config.aot_inductor.embed_kernel_binary:
+            self.prefix.writeline('extern "C" {')
+            for name in sorted(declare_kernel):
+                self.prefix.writeline(
+                    f"    extern const unsigned char __{name}_start[];"
+                )
+                if torch.xpu.is_available():
+                    self.prefix.writeline(
+                        f"    extern const unsigned char __{name}_end[];"
+                    )
+            self.prefix.writeline("}")
+
+    # MSVC string was longer than the limit of 16380 single-byte characters.
+    # https://learn.microsoft.com/en-us/cpp/error-messages/compiler-errors-1/compiler-error-c2026
+    MSVC_C2026_MAX_STRING_LENGTH = 16000
+
+    def codegen_write_arg_with_large_length_string(
+        self,
+        arg_name: str,
+        arg_str_val: str,
+        max_truncate_length: int = MSVC_C2026_MAX_STRING_LENGTH,
+    ):
+        def truncate_string(s: str, length: int) -> list[str]:
+            return [s[i : i + length] for i in range(0, len(s), length)]
+
+        if len(arg_str_val) > max_truncate_length:
+            truncated_strs = truncate_string(arg_str_val, max_truncate_length)
+            self.prefix.writeline(f"{arg_name} =")
+            for truncate_str in truncated_strs:
+                self.prefix.writeline(f'R"({truncate_str})"')
+            self.prefix.writeline(";")
+        else:
+            self.prefix.writeline(f'{arg_name} = R"({arg_str_val})";')
+
+    def codegen_model_constructor(self):
+        """
+        // Generated code example
+        AOTInductorModel::AOTInductorModel()
+            : AOTInductorModelBase(4, 1) {
+        inputs_info_[0].name = "input0";
+        inputs_info_[0].dtype = "torch.float16";
+        ...
+        constants_info_[0].name = "L__self___weight";
+        constants_info_[0].dtype = at::kFloat;
+        constants_info_[0].offset = 0;
+        constants_info_[0].data_size = 8192;
+        constants_info_[0].shape = {64, 32};
+        constants_info_[0].stride = {32, 1};
+        ...
+        outputs_info_[0].name = "output0";
+        outputs_info_[0].dtype = "torch.float16";
+        }
+        """
+
+        num_inputs = len(V.graph.graph_inputs)
+        num_outputs = len(V.graph.graph_outputs)
+        num_constants = len(V.graph.constants)
+        include_weights = (
+            "true"
+            if config.aot_inductor.package_constants_in_so
+            and config.aot_inductor.package_constants_on_disk_format != "binary_blob"
+            else "false"
+        )
+        self.prefix.splice(
+            f"""
+            {self.aoti_model_class_name}::{self.aoti_model_class_name}(std::shared_ptr constants_map,
+                                               std::shared_ptr> constants_array,
+                                               const std::string& device_str,
+                                               std::optional cubin_dir)
+                : AOTInductorModelBase({num_inputs},
+                                       {num_outputs},
+                                       {num_constants},
+                                       device_str,
+                                       std::move(cubin_dir),
+                                       {include_weights}) {{
+            """
+        )
+
+        with self.prefix.indent():
+            for idx, (name, inp) in enumerate(V.graph.graph_inputs.items()):
+                assert not isinstance(inp, sympy.Expr), (
+                    f"input {name=} cannot be symbolic"
+                )
+                self.write_input_output_info("inputs_info_", idx, name)
+
+            all_cuda = all(
+                V.graph.get_original_value_of_constant(name).is_cuda
+                for name in V.graph.constants
+                if name not in V.graph.folded_constants
+            )
+            for idx, name in enumerate(V.graph.constants.keys()):
+                tensor = V.graph.get_original_value_of_constant(name)
+                assert isinstance(tensor, torch.Tensor)
+                self.prefix.writeline(f"""constants_info_[{idx}].name = "{name}";""")
+                self.prefix.writeline(
+                    f"constants_info_[{idx}].dtype = static_cast({self.codegen_dtype(tensor.dtype)});"
+                )
+                self.prefix.writeline(
+                    f"constants_info_[{idx}].offset = {tensor.storage_offset()};"
+                )
+
+                # If constants to serialize contain cpu tensors, we always align data_size it to 64.
+                # When loading the constants, the valid data will depends on the size
+                # not the data_size so there won't be correctness issue.
+                data_size = (
+                    torch.ops.mkldnn._nbytes(tensor)
+                    if tensor.is_mkldnn
+                    else tensor.untyped_storage().nbytes()
+                )
+                self.prefix.writeline(
+                    f"constants_info_[{idx}].data_size = {data_size if all_cuda else _align(data_size)};"
+                )
+
+                from_folded = "true" if name in V.graph.folded_constants else "false"
+                self.prefix.writeline(
+                    f"constants_info_[{idx}].from_folded = {from_folded};"
+                )
+
+                if name in V.graph.folded_constants:
+                    constant_type_str = "FoldedConstant"
+                elif name.startswith("_tensor_constant"):
+                    constant_type_str = "TensorConstant"
+                elif any(
+                    name == normalize_name(parameter_name)
+                    for parameter_name in V.graph.named_parameters
+                ):
+                    constant_type_str = "Parameter"
+                elif any(
+                    name == normalize_name(buffer_name)
+                    for buffer_name in V.graph.named_buffers
+                ):
+                    constant_type_str = "Buffer"
+                else:
+                    constant_type_str = "Unknown"
+                self.prefix.writeline(
+                    f"constants_info_[{idx}].type = static_cast(torch::aot_inductor::ConstantType::{constant_type_str});"
+                )
+
+                size_str = ", ".join([str(s) for s in tensor.size()])
+                self.prefix.writeline(f"constants_info_[{idx}].shape = {{{size_str}}};")
+
+                stride_str = ", ".join([str(s) for s in tensor.stride()])
+                self.prefix.writeline(
+                    f"constants_info_[{idx}].stride = {{{stride_str}}};"
+                )
+                self.prefix.writeline(
+                    f"constants_info_[{idx}].layout = static_cast({self.codegen_layout(tensor.layout)});"
+                )
+
+                if tensor.is_mkldnn:
+                    opaque_metadata_tensor = torch.ops.mkldnn._get_mkldnn_serialized_md(
+                        tensor
+                    )
+                    assert opaque_metadata_tensor.dim() == 1, (
+                        "Expect opaque_metadata_tensor to be 1-D"
+                    )
+
+                    opaque_metadata_list = opaque_metadata_tensor.tolist()
+                    opaque_metadata_str = self.codegen_shape_tuple(opaque_metadata_list)
+                    self.prefix.writeline(
+                        f"constants_info_[{idx}].opaque_metadata = {opaque_metadata_str};"
+                    )
+                if name in V.graph.dynamo_flat_name_to_original_fqn:
+                    original_fqn = V.graph.dynamo_flat_name_to_original_fqn.get(
+                        name, name
+                    )
+                elif name in V.graph.allocated_constant_name:
+                    original_fqn = V.graph.allocated_constant_name[name]
+                else:
+                    raise AssertionError("original_fqn must be set for constant")
+                self.prefix.writeline(
+                    f"""constants_info_[{idx}].original_fqn = "{original_fqn}";"""
+                )
+            self.prefix.writeline("update_constants_map(std::move(constants_map));")
+            self.prefix.writeline("update_constants_array(std::move(constants_array));")
+
+            def escape_string(x):
+                return (
+                    x.replace("\\", "\\\\")
+                    .replace('"', '\\"')
+                    .replace("\n", "\\n")
+                    .replace("\t", "\\t")
+                )
+
+            # Origin code: self.prefix.writeline(f'in_spec_ = R"({config.aot_inductor.serialized_in_spec})";')
+            # Fix msvc C2026 error via codegen_write_arg_with_large_length_string
+            self.codegen_write_arg_with_large_length_string(
+                arg_name="in_spec_", arg_str_val=config.aot_inductor.serialized_in_spec
+            )
+            # Origin code: self.prefix.writeline(f'out_spec_ = R"({config.aot_inductor.serialized_out_spec})";')
+            # Fix msvc C2026 error via codegen_write_arg_with_large_length_string
+            self.codegen_write_arg_with_large_length_string(
+                arg_name="out_spec_",
+                arg_str_val=config.aot_inductor.serialized_out_spec,
+            )
+
+            for idx, output in enumerate(V.graph.graph_outputs):
+                assert not isinstance(output, sympy.Expr), (
+                    f"output {name=} cannot be symbolic"
+                )
+                name = f"output{idx}"
+                self.write_input_output_info("outputs_info_", idx, name)
+
+            self.prefix.writeline(
+                "this->kernels_ = std::make_unique();"
+            )
+
+        self.prefix.writeline("}")
+
+    def codegen_const_run_driver(self):
+        """
+        // Generated code example
+        std::unordered_map AOTInductorModel::const_run_impl(
+            DeviceStreamType stream,
+            AOTIProxyExecutorHandle proxy_executor,
+            bool initialization
+        ) {
+            std::unordered_map folded_constants_map;
+            std::vector output_handles;
+            // build up output_handles over here.
+            _const_run_impl(output_handles, stream, proxy_executor);
+            // build up folded_constants_map
+            return folded_constants_map;
+        }
+        """
+
+        self.prefix.splice(
+            f"""
+            std::unordered_map {self.aoti_model_class_name}::const_run_impl(
+                DeviceStreamType stream,
+                AOTIProxyExecutorHandle proxy_executor,
+                bool initialization
+            ) {{
+            """
+        )
+        if not config.aot_inductor.use_runtime_constant_folding:
+            self.prefix.splice(
+                """
+                    if (!initialization) {
+                        std::cerr << "[WARNING] Calling constant_folding in model, but compiled with config: "
+                                  << "aot_inductor.use_runtime_constant_folding=False\\n";
+                    }
+                    return {};
+                }
+                """
+            )
+            return
+
+        with self.prefix.indent():
+            # This is a mapping to the index of constant folding graph's output
+            const_index_mapping: list[Optional[tuple[int, str]]] = [None] * len(
+                V.graph.const_output_index
+            )
+            for idx, (name, _) in enumerate(V.graph.constants.items()):
+                if name in V.graph.const_output_index:
+                    const_index_mapping[V.graph.const_output_index[name]] = (idx, name)  # type: ignore[call-overload]
+            assert None not in const_index_mapping, (
+                "Not all constant gets mapped for constant folding graph."
+            )
+
+            self.prefix.writeline(
+                f"""
+                std::unordered_map folded_constants_map;
+                folded_constants_map.reserve({len(const_index_mapping)});
+                std::vector output_handles({len(const_index_mapping)});
+                """
+            )
+
+            self.prefix.splice(
+                """
+                // The below assignment of output_handles to constants is not used directly.
+                // It's only used to memo the correspondence of handle and constants.
+                """
+            )
+
+            for output_idx, (const_idx, _) in enumerate(const_index_mapping):  # type: ignore[misc]
+                self.prefix.writeline(
+                    f"output_handles[{output_idx}] = constants_->at({const_idx});"
+                )
+
+            self.prefix.writeline(
+                "_const_run_impl(output_handles, stream, proxy_executor);"
+            )
+
+            for output_idx, (_, const_name) in enumerate(const_index_mapping):  # type: ignore[misc]
+                self.prefix.writeline(
+                    f'folded_constants_map["{const_name}"] = output_handles[{output_idx}];'
+                )
+            self.prefix.writeline("return folded_constants_map;")
+
+        self.prefix.writeline("}")
+
+    def generate(self, is_inference):
+        with dynamo_timed("CppWrapperCpu.generate", log_pt2_compile_event=True):
+            self.write_wrapper_decl()
+            return super().generate(is_inference)
+
+    def finalize_prefix(self):
+        prior = self.prefix
+        self.prefix = aot_mode_decls = IndentedBuffer()
+        if V.graph.aot_mode and not V.graph.is_const_graph:
+            aot_mode_decls.writeline("namespace torch::aot_inductor {")
+            self.codegen_model_kernels()
+            self.codegen_model_constructor()
+            self.codegen_const_run_driver()
+            aot_mode_decls.writeline("} // namespace torch::aot_inductor")
+            aot_mode_decls.writeline("using namespace torch::aot_inductor;")
+
+        self.prefix = cache_decls = IndentedBuffer()
+        for dtype in self.used_cached_dtypes:
+            cache_decls.writeline(f"CACHE_TORCH_DTYPE({dtype});")
+        for device in self.used_cached_devices:
+            cache_decls.writeline(f"CACHE_TORCH_DEVICE({device});")
+        for layout in self.used_cached_layouts:
+            cache_decls.writeline(f"CACHE_TORCH_LAYOUT({layout});")
+        for memory_format in self.used_cached_memory_formats:
+            cache_decls.writeline(f"CACHE_TORCH_MEMORY_FORMAT({memory_format});")
+
+        self.prefix.splice(aot_mode_decls)
+        self.prefix.splice(prior)
+
+    def _define_kernel_helper(
+        self,
+        kernel_name: str,
+        kernel_body: str,
+        metadata: Optional[str] = None,
+        gpu: bool = False,
+        cpp_definition: Optional[str] = None,
+    ):
+        if cpp_definition is not None:
+            self.header.splice(cpp_definition)
+            self.kernel_declarations.splice(f"\n{kernel_body}\n")
+        else:
+            self.header.splice(f"\n{kernel_body}\n")
+
+    def codegen_scalar_to_tensor(self, output: str):
+        name = f"scalar_to_tensor_{next(self.scalar_to_tensor_id)}"
+        self.wrapper_call.writeline(
+            f"RAIIAtenTensorHandle {name} = scalar_to_tensor_handle({output});"
+        )
+        return name
+
+    def codegen_tensor_item(
+        self, dtype: torch.dtype, tensor: str, scalar: str, indented_buffer=None
+    ):
+        dtype_str = str(dtype).split(".")[-1]
+        writer = indented_buffer or self
+
+        if dtype == torch.float16 or dtype == torch.bfloat16:
+            scalar_tmp = f"{scalar}_tmp"
+            writer.writeline(f"{DTYPE_TO_CPP[dtype]} {scalar_tmp};")
+            writer.writeline(
+                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar_tmp}));"
+            )
+            writer.writeline(f"float {scalar} = float({scalar_tmp});")
+        else:
+            writer.writeline(f"{DTYPE_TO_CPP[dtype]} {scalar};")
+            writer.writeline(
+                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar}));"
+            )
+
+    def generate_return(self, output_refs: list[str]):
+        cst_names = V.graph.constants.keys()
+        output2idx: dict[str, int] = {}
+
+        # If any output ref represents an rvalue tensor, materialize it to an lvalue
+        # RAIIAtenTensorHandle first.  This prevents situations where the code for the
+        # rvalue tensor references tensor handles whose contents are modified below.
+        output_refs = [
+            self.create_tmp_raii_handle_var_if_needed(o, self.wrapper_call)
+            for o in output_refs
+        ]
+
+        for idx, output in enumerate(output_refs):
+            if output == "nullptr":
+                continue
+
+            is_constant_buffer = output in cst_names
+            output_buffer = V.graph.graph_outputs[idx]
+            if isinstance(output_buffer, ir.BaseView):
+                output_storage = output_buffer.unwrap_view()
+                assert isinstance(output_storage, (ir.BaseView, ir.MutableBox))
+                if isinstance(output_storage.data, ir.ConstantBuffer):
+                    is_constant_buffer = True
+
+            if isinstance(output_buffer, ir.ShapeAsConstantBuffer):
+                # Need to wrap scalar into tensor as the main function returns a vector of tensors
+                output_tensor = self.codegen_scalar_to_tensor(output)
+                self.wrapper_call.writeline(
+                    f"output_handles[{idx}] = {output_tensor}.release();"
+                )
+                continue
+
+            if is_constant_buffer:
+                # See NOTE(return_constant) above.
+                self.wrapper_call.writeline(
+                    f"aoti_torch_clone({output}, &output_handles[{idx}]);"
+                )
+            else:
+                if output in output2idx:
+                    src_idx = output2idx[output]
+                    self.wrapper_call.writeline(
+                        f"output_handles[{idx}] = output_handles[{src_idx}];"
+                    )
+                else:
+                    self.wrapper_call.writeline(
+                        f"output_handles[{idx}] = {output}.release();"
+                    )
+
+            if output not in output2idx:
+                output2idx[output] = idx
+
+    def generate_before_suffix(self, result):
+        if not V.graph.is_const_graph:
+            if V.graph.aot_mode:
+                result.writeline(f"}} // {self.aoti_model_class_name}::run_impl")
+            else:
+                result.writeline("} // inductor_entry_impl")
+
+    def generate_end(self, result):
+        """Generates the end of the code block, and any code needed to call it."""
+        if V.graph.aot_mode:
+            if V.graph.is_const_graph:
+                result.writeline(f"}} // {self.aoti_model_class_name}::_const_run_impl")
+            else:
+                result.writeline("} // namespace torch::aot_inductor\n\n\n")
+            return
+
+        if config.cpp_wrapper_build_separate:
+            # Close the wrapper code block, then write any kernel definitions.
+            result.splice("'''\n)")
+            if self.kernel_declarations:
+                result.splice("\nkernel_src = (\nr'''")
+                result.splice(self.kernel_declarations.getvalue())
+                result.splice("'''\n)")
+            else:
+                result.splice(
+                    """
+                    kernel_src = ''
+                    """
+                )
+        else:
+            # Merge main code and kernel code
+            result.splice(self.kernel_declarations.getvalue())
+            self.kernel_declarations.clear()
+            # Close the wrapper code block
+            result.splice("'''\n)")
+
+        kernel_code = "kernel_src" if config.cpp_wrapper_build_separate else "None"
+        # Cpp entry function for JIT with cpp wrapper
+        result.splice(
+            f"""
+            inductor_entry = CppWrapperCodeCache.load_pybinding(
+                argtypes=["std::vector"],
+                main_code=cpp_wrapper_src,
+                device_type="{self.device}",
+                num_outputs={len(V.graph.graph_outputs)},
+                kernel_code={kernel_code},
+            )
+            """
+        )
+
+        wrapper_body = "input_tensors = [arg if isinstance(arg, torch.Tensor) else torch.tensor(arg, device='cpu') for arg in args]"
+        if V.graph.constants:
+            # Append constants to the input args for cpp wrapper.
+            # Python wrapper directly gets the value inside the wrapper call
+            # as a global variable passed when calling exec(code, mod.__dict__, mod.__dict__).
+            # For cpp wrapper, we need to pass this python value to the inductor_entry_impl function explicitly.
+            assert all(
+                isinstance(v, torch.Tensor) for v in list(V.graph.constants.values())
+            ), "Expect all constants to be Tensor"
+            constants_str = f"[{', '.join(V.graph.constants.keys())}]"
+            wrapper_body += f"""
+                    constants_tensor = {constants_str}
+                    input_tensors.extend(constants_tensor)
+            """
+        # Convert vector of at::Tensor to vector of AtenTensorHandle.
+        # If we pass at::Tensor, the compilation will be too slow.
+        wrapper_body += """
+                    input_handles = torch._C._aoti.unsafe_alloc_void_ptrs_from_tensors(input_tensors)
+        """
+        # Release the inputs for memory reuse.
+        wrapper_body += """
+                    args.clear()
+                    del input_tensors
+        """
+
+        # unwrap output tensor back to python scalar
+        if all(x for x in self.output_is_tensor.values()):
+            # If no ShapeAsConstantBuffer in the output, directly return the output as tensors
+            outputs_str = "output_tensors"
+        else:
+            outputs = [
+                (
+                    f"output_tensors[{i}]"
+                    if self.output_is_tensor[i]
+                    else f"output_tensors[{i}].item()"
+                )
+                for i in range(len(V.graph.graph_outputs))
+            ]
+            outputs_str = f"[{', '.join(outputs)}]"
+        wrapper_body += f"""
+                    output_handles = f(input_handles)
+                    output_tensors = torch._C._aoti.alloc_tensors_by_stealing_from_void_ptrs(output_handles)
+                    return {outputs_str}
+        """
+
+        # Wrap the func to support setting result._boxed_call = True
+        result.splice(
+            f"""
+            def _wrap_func(f):
+                def g(args):
+                    {wrapper_body}
+                return g
+
+            call = _wrap_func(inductor_entry)
+            """
+        )
+
+    @staticmethod
+    def get_c_shim_func_name(kernel: str, device: str) -> str:
+        if kernel.startswith("aoti_torch_"):
+            return kernel
+
+        assert "::" in kernel, "Cpp kernel name: " + kernel + " does not contain '::'"
+        kernel_tokens = kernel.split("::")
+        kernel_suffix = kernel_tokens[-1]
+        if kernel_suffix == "call":
+            kernel_suffix = kernel_tokens[-2]
+
+        shim_fn = f"aoti_torch_{device}_{kernel_suffix}"
+        return shim_fn
+
+    def generate_c_shim_extern_kernel_call(
+        self,
+        kernel: str,
+        args: list[str],
+        device: str,
+        *,
+        debug_args: Optional[list[str]] = None,
+        stack_traces: Optional[OrderedSet[str]] = None,
+    ) -> None:
+        """debug_args kwarg allows CppWrapperCpuArrayRef to pass in wrapped arguments in
+        place of args while preserving debug printer output."""
+        # We can do this unconditionally, since we cache this call.
+        self.add_device_include(device)
+
+        debug_printer_manager = V.graph.wrapper_code.debug_printer
+        debug_printer_manager.set_printer_args(
+            debug_args if debug_args is not None else args, kernel, None, None, "extern"
+        )
+        enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [
+            "linux",
+            "win32",
+        ]
+        with debug_printer_manager:
+            shim_fn = self.get_c_shim_func_name(kernel, device)
+            shim_fn_codes = [
+                f"AOTI_TORCH_ERROR_CODE_CHECK({shim_fn}({', '.join(args)}));"
+            ]
+            if enable_kernel_profile:
+                stack_trace_str = 'R"('
+                if stack_traces:
+                    for stack_trace in stack_traces:
+                        for line in stack_trace.split("\n"):
+                            stack_trace_str += f"\n{line}"
+                        stack_trace_str += "\n"
+                stack_trace_str += ')"'
+
+                shim_fn_codes = [
+                    "{",
+                    f"""KernelContextGuard _ctx("{shim_fn}", {stack_trace_str});""",
+                    f"""RAIIAtenRecordFunctionHandle record_{shim_fn}_("{shim_fn}", nullptr);""",
+                    shim_fn_codes[0],
+                    "}",
+                ]
+            self.writelines(shim_fn_codes)
+
+    def generate_c_shim_extern_kernel_alloc(
+        self, extern_kernel: ir.ExternKernelAlloc, args: list[str]
+    ) -> None:
+        # registered output buffer name
+        name = extern_kernel.name
+        output_handle_name = f"{name}_handle"
+        is_inplace = (
+            isinstance(extern_kernel.op_overload, torch._ops.OpOverload)
+            and torch.Tag.inplace_view in extern_kernel.op_overload.tags
+        )
+
+        if not is_inplace:
+            self.writeline(f"AtenTensorHandle {output_handle_name};")
+            args = [*args, f"&{output_handle_name}"]
+
+        device = d.type if (d := extern_kernel.get_device()) else self.device
+
+        self.generate_c_shim_extern_kernel_call(
+            extern_kernel.get_kernel_name(), args, device
+        )
+
+        if extern_kernel.python_kernel_name in (
+            "torch.ops._c10d_functional.all_reduce_.default",
+            "torch.ops._c10d_functional.wait_tensor.default",
+        ):
+            # all_reduce_ is an inplace op and its returned tensor is not used anywhere.
+            # wait_tensor returns its input without any modification and the returned tensor is not used anywhere.
+            # In both cases, we can immediately delete the returned AtenTensorHandle to reduce its lifetime.
+            self.writeline(
+                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_delete_tensor_object({output_handle_name}));"
+            )
+        elif not is_inplace:
+            self.writeline(f"RAIIAtenTensorHandle {name}({output_handle_name});")
+
+    def _generate_extern_kernel_alloc_helper(self, extern_kernel, args):
+        if getattr(extern_kernel, "outputs", None):
+            # ir.ExternKernelAlloc may have outputs if it returns a tuple
+            self.generate_c_shim_fallback_kernel(extern_kernel, args)
+        else:
+            self.generate_c_shim_extern_kernel_alloc(extern_kernel, args)
+
+    def generate_c_shim_fallback_kernel(
+        self, fallback_kernel: ir.FallbackKernel, args: list[str]
+    ) -> None:
+        output_args = []
+        output_raii_handles = []
+        output_name_base = fallback_kernel.get_name()
+        for idx, output in enumerate(fallback_kernel.outputs):
+            if isinstance(output, ir.MultiOutput):
+                # TODO: handle integer output (e.g., as in attention)
+                name = f"{output.get_name()}"
+                output_handle_name = f"{name}_handle"
+                if output.indices:
+                    assert output.indices[0][1] == idx, (
+                        f"expected {output.indices[0][1]=} == {idx=} for {output_name_base=}"
+                    )
+                self.writeline(f"AtenTensorHandle {output_handle_name};")
+                output_args.append(f"&{output_handle_name}")
+                output_raii_handles.append(
+                    f"RAIIAtenTensorHandle {name}({output_handle_name});"
+                )
+            elif isinstance(output, int):
+                output_name = f"{output_name_base}_{idx}"
+                self.writeline(f"int64_t {output_name} = {output};")
+                output_args.append(f"&{output_name}")
+            elif isinstance(output, sympy.Expr):
+                output_name = f"{output_name_base}_{idx}"
+                self.writeline(f"auto {output_name} = {cexpr(output)};")
+                output_args.append(f"&{output_name}")
+            elif output is None:
+                output_args.append("nullptr")
+            else:
+                raise NotImplementedError(f"unsupported type of {output=}")
+        args = args + output_args
+        device = d.type if (d := fallback_kernel.get_device()) else self.device
+
+        self.generate_c_shim_extern_kernel_call(
+            fallback_kernel.cpp_kernel_name,  # type: ignore[arg-type]
+            args,
+            device,
+        )
+        for raii_handle in output_raii_handles:
+            self.writeline(raii_handle)
+
+    def _generate_extern_kernel_out_helper(
+        self,
+        kernel: str,
+        out: str,
+        out_view: Optional[str],
+        args: list[str],
+        device: str,
+        stack_traces: Optional[OrderedSet[str]] = None,
+    ) -> None:
+        if out_view:
+            out_name = f"{out}_as_strided"
+            self.writeline(f"auto {out_name} = {out_view};")
+            args.insert(0, out_name)
+        else:
+            args.insert(0, out)
+
+        self.generate_c_shim_extern_kernel_call(
+            kernel, args, device, stack_traces=stack_traces
+        )
+
+    def _get_scatter_reduce_enum(self, reduce):
+        # Follow aten/src/ATen/native/ReductionType.h:get_operator_enum
+        get_operator_enum = {"add": "sum", "multiply": "prod"}
+        if reduce in get_operator_enum:
+            reduce = get_operator_enum[reduce]
+
+        return reduce
+
+    def _generate_scatter_fallback(
+        self,
+        output,
+        inputs,
+        cpp_kernel_name,
+        python_kernel_name,
+        src_is_tensor,
+        reduce,
+        kwargs,
+        device,
+    ):
+        reduce = self._get_scatter_reduce_enum(reduce)
+
+        # call the ABI shim function instead of the ATen one
+        self.add_device_include(device)
+        cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name, device)
+        # TODO: consider remove "_out" and add missing inplace variants to fallback_ops.py
+        cpp_kernel_name = cpp_kernel_name.replace("__", "_") + "_out"
+        inputs_wrapped = [str(x) for x in inputs]
+        line = f"{cpp_kernel_name}({output}, {','.join(inputs_wrapped)}"
+
+        if python_kernel_name.startswith("aten.scatter_reduce"):
+            line += f", {','.join(kwargs)}"
+        else:
+            if src_is_tensor:
+                if reduce:
+                    line += f", {V.graph.wrapper_code.val_to_arg_str(reduce)}"
+            else:
+                assert reduce is None, (
+                    "Expect reduce to be None for aten.scatter_ with scalar src"
+                )
+        line += ");"
+        self.writeline(line)
+
+    def _generate_index_put_fallback(self, kernel, x, indices, values, accumulate):
+        # TODO: update aoti_torch_index_put_out in ir.py to use autogen out version
+        # See the comment in codegen_reinterpret_view about why having something like
+        # RAIIAtenTensorHandle(tmp_tensor_handle_2) in a tmp array can cause the corresponding
+        # tensor prematurely deallocated, thus the temporary array trick here.
+        indices_str = self._generate_temporary_array_pointer(
+            "AtenTensorHandle", indices
+        )
+        args = [
+            x,
+            indices_str,
+            str(len(indices)),
+            values,
+            accumulate,
+        ]
+        args.insert(0, x)  # set x as the output tensor, this fallback mutates x.
+        self.writeline(self.wrap_kernel_call(kernel, args))
+
+    def add_benchmark_harness(self, output):
+        if V.graph.aot_mode:
+            return
+        super().add_benchmark_harness(output)
+
+    def codegen_cpp_sizevar(self, x: sympy.Expr, *, simplify: bool = True) -> str:
+        return cexpr(V.graph.sizevars.simplify(x) if simplify else x)
+
+    def codegen_sizevar(self, x: sympy.Expr) -> str:
+        return self.codegen_cpp_sizevar(x)
+
+    def codegen_tuple_access(self, basename: str, name: str, index: str) -> str:
+        # in the abi_compatible mode, outputs are returned via arguments
+        return name
+
+    def codegen_shape_tuple(self, shape: Sequence[sympy.Expr]) -> str:
+        parts = [*map(self.codegen_sizevar, shape)]
+        if len(parts) == 0:
+            return "{}"
+        if len(parts) == 1:
+            return f"{{{parts[0]}, }}"
+        return f"{{{', '.join(parts)}}}"
+
+    def ensure_size_computed(self, sym: sympy.Symbol):
+        if isinstance(sym, sympy.Symbol) and symbol_is_type(sym, SymT.PRECOMPUTED_SIZE):
+            if sym in self.computed_sizes:
+                return
+            self.computed_sizes.add(sym)
+            expr = V.graph.sizevars.inv_precomputed_replacements[sym]
+            self.writeline(f"int64_t {sym} = {cexpr(expr)};")
+
+    def _generate_symbolic_call_arg_helper(
+        self, arg: SymbolicCallArg, graph: GraphLowering
+    ) -> None:
+        if (arg.inner, graph) not in self.kernel_numel_expr:
+            # declare expr once in each graph (scope)
+            self.kernel_numel_expr.add((arg.inner, graph))
+            self.writeline(f"int64_t {arg.inner} = {cexpr(arg.inner_expr)};")
+        else:
+            self.writeline(f"{arg.inner} = {cexpr(arg.inner_expr)};")
+
+    def _codegen_dynamic_scalar(self, node):
+        (data,) = (t.codegen_reference() for t in node.inputs)
+        self.codegen_tensor_item(node.inputs[0].get_dtype(), data, f"{node.sym}_raw")
+
+        if len(node.keypath) == 0:
+            self.writeline(f"auto {node.sym} = {node.sym}_raw;")
+        elif len(node.keypath) == 1 and isinstance(node.keypath[0], ConvertIntKey):
+            self.writeline(f"int64_t {node.sym} = {node.sym}_raw ? 1 : 0;")
+        elif len(node.keypath) == 1 and isinstance(node.keypath[0], DivideByKey):
+            # TODO: assert divisibility here
+            self.writeline(
+                f"int64_t {node.sym} = {node.sym}_raw / {node.keypath[0].divisor};"
+            )
+        else:
+            raise AssertionError(f"unrecognized keypath {node.keypath}")
+
+        # record in unbacked_symbol_decls so we won't generate a declaration of the symbol again
+        self.unbacked_symbol_decls.add(str(node.sym))
+
+    def codegen_dynamic_select_index(self, node, clamp):
+        index_cpp_str = self.val_to_arg_str_for_prim_type(node.index, int)
+        size_cpp_str = self.val_to_arg_str_for_prim_type(node.size, int)
+
+        # codegen index
+        sym = node.unbacked_offset_symbol
+        index_str = (
+            f"{index_cpp_str} < 0 ? {index_cpp_str} + "
+            f"{self.val_to_arg_str_for_prim_type(node.size, int)}: {index_cpp_str}"
+        )
+        self.writeline(f"auto {sym}_index = {index_str};")
+        index_str_clamped = (
+            f"{sym}_index < 0 ? 0 : ({sym}_index > {size_cpp_str} ? {size_cpp_str} : {sym}_index)"
+            if clamp
+            else f"{sym}_index"
+        )
+        self.writeline(f"auto {sym}_index_clamped = {index_str_clamped};")
+        self.writeline(
+            f"auto {sym} = {self.val_to_arg_str_for_prim_type(node.base_offset, int)} + "
+            f"{self.val_to_arg_str_for_prim_type(node.base_dim_stride, int)} * {sym}_index_clamped;"
+        )
+        # record in unbacked_symbol_decls so we won't generate a declaration of the symbol again
+        self.unbacked_symbol_decls.add(str(sym))
+
+    def codegen_dynamic_slice_size(self, node):
+        start_cpp_str = self.val_to_arg_str_for_prim_type(node.start, int)
+        end_cpp_str = self.val_to_arg_str_for_prim_type(node.end, int)
+        size_cpp_str = self.val_to_arg_str_for_prim_type(node.size, int)
+        step_cpp_str = self.val_to_arg_str_for_prim_type(node.step, int)
+        sym = node.unbacked_size_symbol
+
+        def codegen_clamp(index_str, start=True):
+            suf = "st" if start else "en"
+            index_ = f"{sym}_{suf}_index"
+            self.writeline(
+                f"int64_t {index_} = {index_str} < 0 ? {index_str} + {size_cpp_str} : {index_str};"
+            )
+            self.writeline(
+                f"int64_t {sym}_{suf}_cl = {index_} < 0 ? 0 : ({index_} > {size_cpp_str} ? {size_cpp_str} : {index_});"
+            )
+
+        codegen_clamp(start_cpp_str, start=True)
+        codegen_clamp(end_cpp_str, start=False)
+        if node.step == 1:
+            step_str = f"{sym}_en_cl - {sym}_st_cl"
+        else:
+            step_str = (
+                f"({sym}_en_cl - {sym}_st_cl + {step_cpp_str} - 1) / {step_cpp_str}"
+            )
+        self.writeline(f"int64_t {sym}_with_step = {step_str};")
+        self.writeline(f"int64_t {sym} = {sym}_with_step < 0 ? 0 : {sym}_with_step;")
+        self.unbacked_symbol_decls.add(str(sym))
+
+    def make_buffer_free(self, buffer):
+        return (
+            ""
+            if isinstance(buffer.get_output_spec(), ir.MultiOutputLayout)
+            or isinstance(buffer, ir.TMADescriptor)
+            else f"{buffer.get_name()}.reset();"
+        )
+
+    def make_free_by_names(self, names_to_del: list[str]):
+        return " ".join(f"{name}.reset();" for name in names_to_del)
+
+    def codegen_exact_buffer_reuse(self, old_name: str, new_name: str, del_line: str):
+        return f"auto {new_name} = std::move({old_name});  // reuse"
+
+    def generate_profiler_mark_wrapper_call(self, stack):
+        self.wrapper_call.writeline(
+            'RAIIAtenRecordFunctionHandle record_inductor_wrapper_call_("inductor_wrapper_call", nullptr);'
+        )
+
+    def generate_start_graph(self):
+        pass
+
+    def generate_end_graph(self):
+        pass
+
+    def generate_inf_and_nan_checker(self, nodes):
+        for buf in nodes.get_names():
+            # TODO: Add buf name directly into check_inf_and_nan.
+            self.writeline(
+                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_check_inf_and_nan({buf}));"
+            )
+
+    def codegen_device(self, device):
+        assert device.type in DEVICE_TO_ATEN, (
+            device.type + " not found in DEVICE_TO_ATEN"
+        )
+        device_str = DEVICE_TO_ATEN[device.type][5:].lower()  # remove "at::k"
+        self.used_cached_devices.add(device_str)
+        return f"cached_torch_device_type_{device_str}, {device.index if device.index else 0}"
+
+    def codegen_dtype(self, dtype):
+        dtype_str = str(dtype).split(".")[-1]
+        self.used_cached_dtypes.add(dtype_str)
+        return f"cached_torch_dtype_{dtype_str}"
+
+    def codegen_layout(self, layout):
+        layout_str = str(layout).split(".")[-1]
+        self.used_cached_layouts.add(layout_str)
+        return f"cached_torch_layout_{layout_str}"
+
+    def codegen_memory_format(self, memory_format):
+        memory_format_str = str(memory_format).split(".")[-1]
+        self.used_cached_memory_formats.add(memory_format_str)
+        return f"cached_torch_memory_format_{memory_format_str}"
+
+    def codegen_int_array_var(
+        self,
+        int_array: str,
+        writeline: Callable[..., None],
+        known_statically=False,
+        graph=None,  # for per-graph caching
+    ) -> str:
+        # Use id(graph) for caching to avoid circular references
+        cache_key = (
+            int_array,
+            id(writeline),
+            known_statically,
+            id(graph) if graph else None,
+        )
+        if cache_key not in self.codegen_int_array_var_cache:
+            self.codegen_int_array_var_cache[cache_key] = (
+                self._codegen_int_array_var_impl(int_array, writeline, known_statically)
+            )
+
+        return self.codegen_int_array_var_cache[cache_key]
+
+    def _codegen_int_array_var_impl(
+        self,
+        int_array: str,
+        writeline: Callable[..., None],
+        known_statically: bool,
+    ) -> str:
+        # Used for size/stride declaration
+        #
+        # Because the memory planning is done in two passes (see the implementation
+        # of self.generate), the writeline behavior is different in the two passes.
+        # As a result, the emitted int array declarations may appear in a later
+        # position of the generated code, so the second pass codegen should not
+        # reuse int array declarations generated in the first pass.
+        # This is why writeline needs to explicitly passed in as a parameter.
+        var = f"int_array_{next(self.int_array_id)}"
+        ctype = "int64_t"
+        if int_array == "{}":
+            #  An array of unknown bound cannot be initialized with {}.
+            if known_statically:
+                if config.cpp.use_constexpr_for_int_array:
+                    writeline(f"static constexpr {ctype} *{var}=nullptr;")
+                else:
+                    writeline(f"static const {ctype} *{var}=nullptr;")
+            else:
+                writeline(f"const {ctype} *{var}=nullptr;")
+        else:
+            if var not in self.declared_int_array_vars:
+                self.declared_int_array_vars.add(var)
+                if known_statically:
+                    if config.cpp.use_constexpr_for_int_array:
+                        writeline(f"static constexpr {ctype} {var}[] = {int_array};")
+                    else:
+                        writeline(f"static const {ctype} {var}[] = {int_array};")
+                else:
+                    writeline(f"const {ctype} {var}[] = {int_array};")
+        return var
+
+    def make_buffer_allocation(self, buffer):
+        return self.make_allocation(
+            buffer.get_name(),
+            buffer.get_device(),
+            buffer.get_dtype(),
+            buffer.get_size(),
+            buffer.get_stride(),
+            V.graph.get_allocation_size(buffer),
+            buffer.get_is_pinned(),
+        )
+
+    def make_allocation(
+        self, name, device, dtype, shape, stride, allocation_shape=None, is_pinned=False
+    ):
+        if allocation_shape is None:
+            allocation_shape = shape
+
+        orig_stride = stride
+        device_str = self.codegen_device(device)
+        dtype_code = self.codegen_dtype(dtype)
+        size = self.codegen_shape_tuple(shape)
+        allocation_size = self.codegen_shape_tuple(allocation_shape)
+        stride = self.codegen_shape_tuple(orig_stride)
+
+        size_array_var = self.codegen_int_array_var(
+            size,
+            self.wrapper_call.writeline,
+            known_statically=self.is_statically_known_list_of_ints(shape),
+            graph=self.get_codegened_graph(),
+        )
+
+        if allocation_size != size:
+            allocation_size_array_var = self.codegen_int_array_var(
+                allocation_size,
+                self.wrapper_call.writeline,
+                known_statically=self.is_statically_known_list_of_ints(
+                    allocation_shape
+                ),
+                graph=self.get_codegened_graph(),
+            )
+        else:
+            allocation_size_array_var = size_array_var
+
+        stride_array_var = self.codegen_int_array_var(
+            stride,
+            self.wrapper_call.writeline,
+            known_statically=self.is_statically_known_list_of_ints(orig_stride),
+            graph=self.get_codegened_graph(),
+        )
+        device_type, device_id = device_str.split(",")
+        device_idx = "this->device_idx_" if V.graph.aot_mode else device_id
+
+        handle_name = f"{name}_handle"
+        args = [
+            str(len(shape)),
+            allocation_size_array_var,
+            stride_array_var,
+            dtype_code,
+            device_type,
+            device_idx,
+            f"&{handle_name}",
+        ]
+
+        self.wrapper_call.writeline(f"AtenTensorHandle {handle_name};")
+        pinned_str = "_pinned" if is_pinned else ""
+        self.wrapper_call.writeline(
+            f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided{pinned_str}({', '.join(args)}));"
+        )
+
+        if allocation_size != size:
+            old_handle_name, handle_name = handle_name, f"{name}_handle_restrided"
+            self.wrapper_call.writeline(f"AtenTensorHandle {handle_name};")
+            args = [
+                old_handle_name,
+                size_array_var,
+                stride_array_var,
+                f"&{handle_name}",
+            ]
+            self.wrapper_call.writeline(
+                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_as_strided({', '.join(args)}));"
+            )
+            self.wrapper_call.writeline(
+                f"wrap_with_raii_handle_if_needed({old_handle_name});"
+            )
+
+        return f"RAIIAtenTensorHandle {name}({handle_name});"
+
+    def codegen_alloc_from_pool(
+        self, name, offset, dtype, shape, stride
+    ) -> tuple[str, list[str]]:
+        size = self.codegen_shape_tuple(shape)
+        stride = self.codegen_shape_tuple(stride)
+        tmp_name = f"tmp_tensor_handle_{next(self.tmp_tensor_id)}"
+        args = [
+            name,
+            cexpr(offset),  # bytes not numel
+            self.codegen_dtype(dtype),
+            str(len(shape)),
+            self.codegen_int_array_var(
+                size, self.wrapper_call.writeline, graph=self.get_codegened_graph()
+            ),
+            self.codegen_int_array_var(
+                stride, self.wrapper_call.writeline, graph=self.get_codegened_graph()
+            ),
+            f"&{tmp_name}",
+        ]
+        # We return the lines instead of writing here because writing here is bug prune.
+        # If you write aoti_torch__alloc_from_pool lines, you must write the RAIIAtenTensorHandle
+        # as well, otherwise you get memory leaks
+        allocations_to_write = [
+            f"AtenTensorHandle {tmp_name};",
+            f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool({', '.join(args)}));",
+        ]
+        return f"RAIIAtenTensorHandle({tmp_name})", allocations_to_write
+
+    def codegen_reinterpret_view(
+        self,
+        data,
+        size,
+        stride,
+        offset,
+        writeline: Callable[..., None],
+        dtype=None,
+    ) -> str:
+        """Returns a newly-created, temporary RAII tensor handle containing the
+        reinterpreted tensor data.  Callers of this function are responsible for saving
+        the handle if persistent access is needed."""
+
+        d_size, d_stride, d_offset, d_dtype, collapsible = (
+            codegen_reinterpret_view_helper(data)
+        )
+
+        dim = str(len(size))
+        original_offset = offset
+        offset = self.codegen_sizevar(offset)
+        call_strs = []
+        final_tensor_str = None
+
+        def create_reinterpret_call() -> str:
+            args = [
+                f"{data.get_name()}",
+                dim,
+                self.codegen_int_array_var(
+                    self.codegen_shape_tuple(size),
+                    writeline,
+                    known_statically=self.is_statically_known_list_of_ints(size),
+                    graph=self.get_codegened_graph(),
+                ),
+                self.codegen_int_array_var(
+                    self.codegen_shape_tuple(stride),
+                    writeline,
+                    known_statically=self.is_statically_known_list_of_ints(stride),
+                    graph=self.get_codegened_graph(),
+                ),
+                offset,
+            ]
+            return f"wrap_with_raii_handle_if_needed(reinterpret_tensor_wrapper({', '.join(args)}))"
+
+        def create_dtypeview_call(reinterpret_call: str) -> tuple[str, list[str]]:
+            tmp_AtenTensorHandle = f"tmp_{data.get_name()}_{next(self.tmp_tensor_id)}"
+            tmp_call_strs = [f"AtenTensorHandle {tmp_AtenTensorHandle};"]
+            device_name = data.layout.device.type
+            dtypeview_function = f"aoti_torch_{device_name}_view_dtype"
+            tmp_call_strs.append(
+                f"AOTI_TORCH_ERROR_CODE_CHECK({dtypeview_function}"
+                f"({reinterpret_call}, {self.codegen_dtype(dtype)}, &{tmp_AtenTensorHandle}));"
+            )
+            return f"RAIIAtenTensorHandle({tmp_AtenTensorHandle})", tmp_call_strs
+
+        def create_new_tensor_handle() -> tuple[str, list[str]]:
+            tmp_AtenTensorHandle = f"tmp_{data.get_name()}_{next(self.tmp_tensor_id)}"
+            tmp_call_strs = [
+                f"AtenTensorHandle {tmp_AtenTensorHandle};",
+                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_tensor_handle({data.get_name()}, &{tmp_AtenTensorHandle}));",
+            ]
+            return f"RAIIAtenTensorHandle({tmp_AtenTensorHandle})", tmp_call_strs
+
+        collapsed = collapsible and original_offset == d_offset
+        if collapsed:
+            same_layout = size == d_size and stride == d_stride
+            base_dtype = d_dtype
+        else:
+            same_layout = (
+                size == data.layout.size
+                and stride == data.layout.stride
+                and original_offset == data.layout.offset
+            )
+            base_dtype = data.dtype
+
+        if same_layout:
+            # pure dtypeview
+            if dtype is not None and dtype != base_dtype:
+                final_tensor_str, tmp_call_strs = create_dtypeview_call(data.get_name())
+            else:
+                final_tensor_str, tmp_call_strs = create_new_tensor_handle()
+            call_strs.extend(tmp_call_strs)
+        else:
+            # firstly create reinterpretview
+            final_tensor_str = create_reinterpret_call()
+            if dtype is not None and dtype != base_dtype:
+                # wrap it with dtypeview
+                final_tensor_str, tmp_call_strs = create_dtypeview_call(
+                    final_tensor_str
+                )
+                call_strs.extend(tmp_call_strs)
+
+        for line in call_strs:
+            writeline(line)
+
+        # NB, the return handle here represents a temporary tensor, which will be automatically
+        # released.
+        # Here's a sample usage in the cpp wrapper code:
+        # ```
+        # aoti_torch_addmm_out(
+        #     buf1,
+        #     arg1_1,
+        #     RAIIAtenTensorHandle(tmp_tensor_handle_0),
+        #     buf0,
+        #     1L,
+        #     1L));
+        # ```
+        # RAIIAtenTensorHandle(tmp_tensor_handle_0) will be released after the call to addmm_out.
+        # This could be problematic when it's used in a different pattern, for example:
+        # ````
+        # AtenTensorHandle tensor_args[] = {RAIIAtenTensorHandle(tmp_tensor_handle_2), buf5, buf6};
+        # aoti_torch_proxy_executor_call_function(..., tensor_args);
+        # ````
+        # RAIIAtenTensorHandle(tmp_tensor_handle_2) will be invalid when it's used in the latter
+        # kernel call.
+        #
+        # This is solved by updating the proxy_executor invocation to
+        # ```
+        # aoti_torch_proxy_executor_call_function(...,
+        #     std::array{
+        #         RAIIAtenTensorHandle(tmp_tensor_handle_2), buf5, buf6
+        #     }.cbegin()
+        # );
+        # ```
+        return final_tensor_str
+
+    def codegen_device_copy(self, src, dst, non_blocking: Union[bool, str]):
+        """This function is overridden by cpp_wrapper_cpu_array_ref, so we don't need to
+        handle cases where dst is not an AtenTensorHandle."""
+        self.writeline(
+            f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_copy_({dst}, {src}, {non_blocking}));"
+        )
+
+    def codegen_multi_output(self, node: ir.MultiOutput):
+        # in the abi_compatible mode, outputs are retrieved by passing
+        # output pointers, so we skip its codegen here.
+        pass
+
+    def codegen_subgraph_prefix(self, subgraph, outer_inputs, outer_outputs):
+        assert len(subgraph.graph.graph_inputs) == len(outer_inputs)
+
+        for (inner_input, inner_input_val), outer_input in zip(
+            subgraph.graph.graph_inputs.items(), outer_inputs
+        ):
+            if not isinstance(inner_input_val, ir.TensorBox):
+                continue
+
+            # in ABI-compatible mode, we copy the underlying at::Tensor of the conditional
+            # input (outer_input) into another at::Tensor to be used as a subgraph input
+            # (inner_input) in the nested scope. we can't std::move here, as the codegened
+            # outer input may be an expression / rvalue (e.g., reinterpret_view(x)), so we
+            # can't necessarily std::move it back to the origin (x).
+            self.writeline(f"AtenTensorHandle {inner_input}_handle;")
+            self.writeline(
+                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors_out({outer_input}, &{inner_input}_handle));"
+            )
+            self.writeline(f"RAIIAtenTensorHandle {inner_input}({inner_input}_handle);")
+
+    def codegen_subgraph_suffix(self, subgraph, outer_inputs, outer_outputs):
+        for inner_output, outer_output in zip(
+            subgraph.graph.graph_outputs, outer_outputs
+        ):
+            src = inner_output.codegen_reference()
+            if not isinstance(inner_output, ir.ShapeAsConstantBuffer):
+                # in ABI-compatible mode, we need to std::move subgraph output (inner_output)
+                # to the conditional output (outer_output), as RAIIAtenTensorHandle's copy
+                # constructor is deleted.
+                src = f"std::move({src})"
+                # in case the outer_output carried a value
+                # before (e.g., in the while_loop codegen)
+                self.writeline(f"{outer_output}.reset();")
+            self.writeline(f"{outer_output} = {src};")
+
+    def codegen_invoke_subgraph(self, invoke_subgraph):
+        raise NotImplementedError(
+            "codegen invoke_subgraph is not implemented for cpp wrapper"
+        )
+
+    def codegen_conditional(self, conditional):
+        outer_inputs = [f"{buf.codegen_reference()}" for buf in conditional.operands]
+        outer_outputs = []
+        for out in conditional.outputs:
+            # in ABI-compatible mode, ir.MultiOutput is not codegened,
+            # hence pre-declare output variables directly and separately
+            self.writeline(f"RAIIAtenTensorHandle {out.get_name()};")
+            outer_outputs.append(out.get_name())
+
+        if not isinstance(conditional.predicate, ir.ShapeAsConstantBuffer):
+            # in ABI-compatible mode, we need to use the ABI shim function
+            # to extract a C++ bool from the underlying scalar bool Tensor
+            predicate = f"{conditional.predicate.get_name()}_scalar"
+            if predicate not in self.used_cond_predicate:
+                self.codegen_tensor_item(
+                    torch.bool,
+                    conditional.predicate.codegen_reference(),
+                    predicate,
+                )
+                self.used_cond_predicate.add(predicate)
+        else:
+            # the predicate is not a Tensor: SymBool or Python bool
+            predicate = conditional.predicate.codegen_reference()
+
+        self.writeline(f"if ({predicate}) {{")
+        self.writeline(EnterSubgraphLine(self, conditional.true_subgraph.graph))
+        self.codegen_subgraph(conditional.true_subgraph, outer_inputs, outer_outputs)
+        self.writeline(ExitSubgraphLine(self))
+        self.writeline("} else {")
+        self.writeline(EnterSubgraphLine(self, conditional.false_subgraph.graph))
+        self.codegen_subgraph(conditional.false_subgraph, outer_inputs, outer_outputs)
+        self.writeline(ExitSubgraphLine(self))
+        self.writeline("}")
+
+    def codegen_subgraph(self, subgraph, outer_inputs, outer_outputs):
+        # TODO (desertfire) - This function is the old way of supporting
+        # subgraph codegen by inlining subgraphs in the output code. For python
+        # wrapper, we have moved to lifting subgraphs as functions, supported by
+        # PythonWrapperCode `codegen_subgraph` function. We should perhaps
+        # support lifting of subgraphs as functions for cpp wrapper as well.
+        try:
+            self.push_codegened_graph(subgraph.graph)
+            self.writeline(f"// subgraph: {subgraph.name}")
+            self.codegen_subgraph_prefix(subgraph, outer_inputs, outer_outputs)
+            parent_graph = V.graph
+            with V.set_graph_handler(subgraph.graph):
+                subgraph.graph.codegen_subgraph(
+                    parent_graph=parent_graph,
+                )
+            self.codegen_subgraph_suffix(subgraph, outer_inputs, outer_outputs)
+        finally:
+            self.pop_codegened_graph()
+
+    def codegen_while_loop(self, while_loop, stack_output=False):
+        if stack_output:
+            raise NotImplementedError("NYI cpp wrapper for while_loop_stack_output")
+        is_bool_pred = isinstance(
+            while_loop.cond_subgraph.graph.graph_outputs[0], ir.ShapeAsConstantBuffer
+        )
+        name = while_loop.get_name()
+        outer_carried_inputs = [
+            buf.codegen_reference() for buf in while_loop.carried_inputs
+        ]
+        outer_additional_inputs = [
+            buf.codegen_reference() for buf in while_loop.additional_inputs
+        ]
+        cond_result_name = f"{name}_cond_result"
+        if is_bool_pred:
+            self.writeline(f"bool {cond_result_name};")
+        else:
+            self.writeline(f"RAIIAtenTensorHandle {cond_result_name};")
+
+        cond_outer_inputs = []
+        for inp, out in zip(outer_carried_inputs, while_loop.outputs):
+            # in ABI-compatible mode, the carried inputs are codegened
+            # as buffers outside the while loop and set to the initial
+            # values. at the end of each while_loop iteration, they
+            # will be assigned the carried values.
+            out_name = out.get_name()
+            self.writeline(f"AtenTensorHandle {out_name}_handle;")
+            self.writeline(
+                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors_out({inp}, &{out_name}_handle));"
+            )
+            self.writeline(f"RAIIAtenTensorHandle {out_name}({out_name}_handle);")
+            cond_outer_inputs.append(out_name)
+
+        # additional inputs will be assigned within the while_loop
+        # iteration directly from the corresponding outer graph buffers
+        cond_outer_inputs.extend(outer_additional_inputs)
+
+        cond_outer_outputs = [cond_result_name]
+        body_outer_inputs = list(cond_outer_inputs)
+        body_outer_outputs = body_outer_inputs[: len(outer_carried_inputs)]
+
+        self.writeline("while (1) {")
+        self.writeline(EnterSubgraphLine(self, while_loop.cond_subgraph.graph))
+        self.codegen_subgraph(
+            while_loop.cond_subgraph, cond_outer_inputs, cond_outer_outputs
+        )
+
+        if is_bool_pred:
+            cond_result = f"{cond_result_name}"
+        else:
+            cond_result = f"{cond_result_name}_scalar"
+            self.codegen_tensor_item(torch.bool, cond_result_name, cond_result)
+        self.writeline(f"if (!{cond_result}) break;")
+
+        self.writeline(ExitSubgraphLine(self))
+        self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph))
+        self.codegen_subgraph(
+            while_loop.body_subgraph, body_outer_inputs, body_outer_outputs
+        )
+        self.writeline(ExitSubgraphLine(self))
+        self.writeline("}")
+
+    def generate_extern_kernel_args_decl_if_needed(
+        self,
+        op_overload: Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator],
+        raw_args: Sequence[Any],
+        output_args: _OUTPUT_ARGS_TYPE,
+        raw_outputs: Sequence[ir.Buffer],
+    ):
+        """
+        Generates declarations for external kernel arguments if needed, based on the provided
+        operator and its arguments. It processes both input and output arguments, categorizing
+        them into tensor and integer arguments for further code generation.
+        """
+        schema = None
+        if isinstance(op_overload, torch._higher_order_ops.torchbind.CallTorchBind):
+            obj = raw_args[0]
+            method = raw_args[1]
+            schema = op_overload.schema(obj, method)
+        else:
+            assert isinstance(op_overload, torch._ops.OpOverload), type(op_overload)
+            schema = op_overload._schema
+        assert schema is not None
+        arg_types = [x.real_type for x in schema.arguments]
+        return_types = [x.type for x in schema.returns]
+
+        new_tensor_args = []
+        new_int_args = []
+
+        def fill_args(arg, arg_type):
+            static_arg_types = (
+                torch.FloatType,
+                torch.BoolType,
+                torch.StringType,
+                torch.Type,
+                torch.DeviceObjType,
+            )
+            inductor_tensor_buffers = (
+                ir.Buffer,
+                ir.ReinterpretView,
+            )
+
+            if isinstance(arg_type, torch.TensorType):
+                assert isinstance(arg, inductor_tensor_buffers), f"got {type(arg)}"
+                new_tensor_args.append(f"{arg.codegen_reference()}")
+            elif isinstance(arg_type, torch.IntType):
+                # int
+                new_int_args.append(str(arg))
+            elif isinstance(arg_type, torch.SymIntType):
+                # SymInt
+                expr = arg.node.expr if isinstance(arg, torch.SymInt) else arg
+                new_int_args.append(cexpr(expr))
+            elif isinstance(arg_type, torch.NumberType):
+                # Scalar of type int
+                assert isinstance(arg, (int, float, bool))
+                # Only treat int Scalar as dynamic
+                if isinstance(arg, int):
+                    new_int_args.append(str(arg))
+            elif isinstance(arg, ir.TorchBindObject):
+                # torchbind objects are loaded in proxy executor
+                pass
+            elif isinstance(arg_type, torch.ListType):
+                assert isinstance(arg, (list, tuple))
+
+                # List[Tensor]
+                if isinstance(arg_type.getElementType(), torch.TensorType):
+                    new_tensor_args.extend([f"{a.codegen_reference()}" for a in arg])
+                # List[Optional[Tensor]]
+                elif isinstance(
+                    arg_type.getElementType(), torch.OptionalType
+                ) and isinstance(
+                    arg_type.getElementType().getElementType(), torch.TensorType
+                ):
+                    new_tensor_args.extend(
+                        [f"{a.codegen_reference()}" for a in arg if a is not None]
+                    )
+                # List[int]
+                elif isinstance(arg_type.getElementType(), torch.IntType):
+                    new_int_args.extend([str(a) for a in arg])
+                # List[SymInt]
+                elif isinstance(arg_type.getElementType(), torch.SymIntType):
+                    expressions = [
+                        a.node.expr if isinstance(a, torch.SymInt) else a for a in arg
+                    ]
+                    new_int_args.extend([cexpr(expr) for expr in expressions])
+                # List[Scalar]
+                elif isinstance(arg_type.getElementType(), torch.NumberType):
+                    # Only treat int Scalar as dynamic
+                    is_int_type = [isinstance(a, int) for a in arg]
+                    if any(is_int_type):
+                        assert all(is_int_type), (
+                            "AOTInductor only supports int scalars of the same type"
+                        )
+                        new_int_args.extend([str(a) for a in arg])
+                else:
+                    assert isinstance(
+                        arg_type.getElementType(),
+                        static_arg_types,  # type: ignore[arg-type]
+                    ), (
+                        f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}"
+                    )
+            else:
+                assert isinstance(
+                    arg_type,
+                    static_arg_types,  # type: ignore[arg-type]
+                ), (
+                    f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}"
+                )
+
+        for arg, arg_type in zip(raw_args, arg_types):
+            if arg is not None:
+                if isinstance(arg_type, torch.OptionalType):
+                    fill_args(arg, arg_type.getElementType())
+                else:
+                    fill_args(arg, arg_type)
+
+        def fill_output_arg(
+            arg: str, return_type: torch.JitType, is_mutated_output: bool
+        ) -> None:
+            if isinstance(return_type, torch.TensorType):
+                if not is_mutated_output:
+                    self.writeline(f"AtenTensorHandle {arg}_handle;  // output buffer")
+                    self.writeline(
+                        f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&{arg}_handle));"
+                    )
+                    self.writeline(f"RAIIAtenTensorHandle {arg}({arg}_handle);")
+                new_tensor_args.append(f"{arg}")
+            elif isinstance(return_type, torch.SymIntType):
+                raise NotImplementedError("NYI support for return type: SymInt")
+            elif isinstance(return_type, torch.ListType) and isinstance(
+                return_type.getElementType(), torch.SymIntType
+            ):
+                raise NotImplementedError("NYI support for return type: List[SymInt]")
+            else:
+                raise AssertionError(f"Unsupported return type found: {return_type}")
+
+        # TODO: Only support None and tensor(s) returns for now, SymInt is not implemented yet
+        for return_type in return_types:
+            if isinstance(
+                return_type, (torch.TensorType, torch.NoneType, torch.IntType)
+            ):
+                pass
+            elif isinstance(return_type, torch.OptionalType):
+                assert isinstance(return_type.getElementType(), torch.TensorType)
+            elif isinstance(return_type, torch.ListType):
+                assert isinstance(return_type.getElementType(), torch.TensorType)
+            else:
+                raise NotImplementedError(
+                    f"return type {return_type} is not yet supported."
+                )
+
+        for output_arg, raw_output_arg in zip(output_args, raw_outputs):  # type: ignore[arg-type]
+            # None output is supported, but Optional return types are not yet supported
+            if output_arg is None:
+                continue
+            elif isinstance(raw_output_arg, int):
+                new_int_args.append(str(raw_output_arg))
+            elif isinstance(output_arg, list):
+                for out in output_arg:
+                    assert out is not None, out
+                    fill_output_arg(
+                        out,
+                        torch.TensorType.get(),
+                        isinstance(raw_output_arg, ir.MutationOutput),
+                    )
+            else:
+                fill_output_arg(
+                    output_arg,
+                    torch.TensorType.get(),
+                    isinstance(raw_output_arg, ir.MutationOutput),
+                )
+
+        return new_tensor_args, new_int_args
+
+    @staticmethod
+    def _compatible_with_stableivalue(op: torch._ops.OpOverload) -> bool:
+        """Returns true if op_overload._schema only utilizes types supported by the AOT
+        C-shim *internal* function to_ivalue.  to_ivalue is an implementation detail, so
+        these types are not guaranteed to be supported long-term.  When generating code
+        for cpp_wrapper mode, we don't have to be forward-compatible, so changing this
+        function's implementation in future is fine."""
+        supported_types = (
+            torch.BoolType,
+            torch.DeviceObjType,
+            torch.FloatType,
+            # ScalarTypeType, LayoutType, and MemoryFormatType are seen as IntType
+            # when queried via torch.JitType.type.
+            torch.IntType,
+            torch.TensorType,
+        )
+
+        def type_supported(t: torch.JitType) -> bool:
+            if isinstance(t, torch.OptionalType):
+                return type_supported(t.getElementType())
+            return isinstance(t, supported_types)
+
+        return all(
+            type_supported(a.type)
+            for a in chain(op._schema.arguments, op._schema.returns)
+        )
+
+    def generate_fallback_kernel_with_runtime_lookup(
+        self,
+        buf_name: str,
+        python_kernel_name: str,
+        get_args: Callable[[], Sequence[str]],
+        op_overload: Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator],
+        raw_args: Sequence[Any],
+        outputs: Sequence[ir.Buffer],
+    ) -> None:
+        """Generate a call to a kernel not contained in the C-shim.  This results in
+        different code paths for AOT Inductor vs cpp_wrapper Inductor mode."""
+
+        def extract_output_name(
+            out: Optional[Union[ir.Buffer, Sequence[ir.Buffer]]],
+        ) -> Union[Optional[str], _OUTPUT_ARGS_TYPE]:
+            if out is None:
+                return None
+            if isinstance(out, (ir.MultiOutput, ir._CollectiveKernel)):
+                return out.get_name()
+            if isinstance(out, ir.MutationOutput):
+                mutated_buf_names = out.get_mutation_names()
+                assert (
+                    isinstance(mutated_buf_names, list) and len(mutated_buf_names) == 1
+                ), "Expect only one mutated buffer in MutationOutput"
+                return mutated_buf_names[0]
+            if isinstance(out, (list, tuple)):
+                return [extract_output_name(o) for o in out]  # type: ignore[misc]
+            if isinstance(out, int):
+                return str(out)
+            raise AssertionError(f"Unexpected output: {type(out)}")
+
+        if isinstance(op_overload, torch._ops.HigherOrderOperator):
+            assert isinstance(
+                op_overload, torch._higher_order_ops.torchbind.CallTorchBind
+            ), type(op_overload)
+            assert len(raw_args) > 1
+            obj = raw_args[0]
+            method = raw_args[1]
+            return_schema = op_overload.schema(obj, method).returns
+        else:
+            return_schema = op_overload._schema.returns
+
+        # output_args has the same pytree structure as outputs
+        if not return_schema:
+            # kernel does not return a value
+            output_args: _OUTPUT_ARGS_TYPE = []
+        elif isinstance(output_name := extract_output_name(outputs), str):
+            output_args = [output_name]
+        else:
+            # If the schema indicates a return value, we should have a non-None value by
+            # this point.
+            assert isinstance(output_name, list), type(output_name)
+            output_args = output_name
+
+        # In AOT mode, we use a ProxyExecutor to run fallback kernels.
+        if V.graph.aot_mode:
+            self.generate_fallback_kernel_with_runtime_lookup_aot(
+                op_overload,
+                raw_args,
+                output_args,
+                outputs,
+            )
+            return
+
+        assert isinstance(op_overload, torch._ops.OpOverload), type(op_overload)
+        for output in output_args:
+            assert output is None or isinstance(output, str), (
+                "fallback kernels with runtime lookup currently only support tensor "
+                "returns, not more complicated types (such as list-of-list-of-tensor)"
+            )
+
+        # In non-AOT mode, we use aoti_torch_call_dispatcher if all the inputs and
+        # outputs of the op can be represented with StableIValue.  This avoids the
+        # overhead of calling back into Python, and covers most remaining fallback ops.
+        if self._compatible_with_stableivalue(op_overload):
+            self.generate_fallback_kernel_with_runtime_lookup_nopython(
+                get_args,
+                op_overload,
+                output_args,  # type: ignore[arg-type]
+                outputs,
+            )
+            return
+
+        # Otherwise, we call back into Python, which has some extra runtime overhead,
+        # but handles situations like list[Tensor] (currently unrepresentable via
+        # StableIValue).
+        self.generate_fallback_kernel_with_runtime_lookup_python(
+            buf_name,
+            python_kernel_name,
+            op_overload,
+            raw_args,
+            output_args,  # type: ignore[arg-type]
+            outputs,
+        )
+
+    def generate_scoped_gil_acquire(self, declarations_before_scope, lines_in_scope):
+        scoped_lines = IndentedBuffer()
+        for declaration in declarations_before_scope:
+            scoped_lines.writeline(declaration)
+
+        scoped_lines.writeline("{")
+        with scoped_lines.indent():
+            scoped_lines.writeline("py::gil_scoped_acquire_simple acquire;")
+            scoped_lines.writelines(lines_in_scope.split("\n"))
+        scoped_lines.writelines("}")
+        return scoped_lines._lines
+
+    def load_custom_op_wrapper(self):
+        # TODO: need to support control flow
+        if self.custom_op_wrapper_loaded:
+            return
+
+        lines = """
+RAIIPyObject codecache_module(PyImport_ImportModule("torch._inductor.codecache"));
+if (!codecache_module) {
+    throw std::runtime_error("Failed to load torch._inductor.codecache");
+}
+custom_op_wrapper = PyObject_GetAttrString(codecache_module, "custom_op_wrapper");
+if (!custom_op_wrapper) {
+    throw std::runtime_error("Failed to load torch._inductor.codecache.custom_op_wrapper");
+}"""
+
+        declarations_before_scope = ["RAIIPyObject custom_op_wrapper;"]
+        scope_gil_acquire = self.generate_scoped_gil_acquire(
+            declarations_before_scope, lines
+        )
+        self.writelines(scope_gil_acquire)
+
+        self.custom_op_wrapper_loaded = True
+
+    def generate_float_value(self, val):
+        assert isinstance(val, float)
+        if val == float("inf"):
+            return "std::numeric_limits::infinity()"
+        elif val == float("-inf"):
+            return "-std::numeric_limits::infinity()"
+        elif math.isnan(val):
+            return "std::numeric_limits::quiet_NaN()"
+        else:
+            return f"{val}"
+
+    def generate_py_arg(self, py_args_var, idx, raw_arg, arg_type):
+        def generate_py_arg_inner(lines, raw_arg, arg_type):
+            def handle_scalar(scalar):
+                if isinstance(scalar, int):
+                    return f"PyLong_FromLongLong({scalar})"
+                if isinstance(scalar, float):
+                    return f"PyFloat_FromDouble({self.generate_float_value(scalar)})"
+                if isinstance(scalar, bool):
+                    return f"PyBool_FromLong({1 if scalar else 0})"
+                if isinstance(scalar, complex):
+                    real = self.generate_float_value(scalar.real)
+                    imag = self.generate_float_value(scalar.imag)
+                    return f"PyComplex_FromDoubles({real}, {imag})"
+                if isinstance(scalar, SymTypes):
+                    scalar_var = cexpr(scalar.node.expr)
+                    if isinstance(scalar, torch.SymBool):
+                        return f"PyBool_FromLong({scalar_var})"
+                    if isinstance(scalar, torch.SymFloat):
+                        return f"PyFloat_FromDouble({scalar_var})"
+                    return f"PyLong_FromLongLong({scalar_var})"
+                raise NotImplementedError(
+                    f"scalar {scalar}, {type(scalar)} cannot be handled by handle_scalar"
+                )
+
+            if raw_arg is None:
+                # Py_None is a singleton, so we have to explicitly incref it here
+                lines.append("Py_INCREF(Py_None);\n")
+                return "Py_None"
+            elif isinstance(arg_type, torch.TensorType):
+                # In some cases, scalar arguments may be passed in place of tensors.
+                if not hasattr(raw_arg, "codegen_reference"):
+                    return handle_scalar(raw_arg)
+
+                # Store AtenTensorHandle as void*.  All Python args are constructed in a
+                # nested scope, so this handle will self-destruct after the function
+                # call.
+                base_handle = self.create_tmp_raii_handle_var_if_needed(
+                    raw_arg.codegen_reference(), lines
+                )
+                return f"PyCapsule_New(reinterpret_cast({base_handle}.get()), NULL, NULL)"
+            elif isinstance(arg_type, torch.OptionalType):
+                return generate_py_arg_inner(lines, raw_arg, arg_type.getElementType())
+            elif isinstance(arg_type, torch.IntType):
+                # int
+                return f"PyLong_FromLongLong({raw_arg})"
+            elif isinstance(arg_type, torch.SymIntType):
+                # SymInt
+                expr = (
+                    raw_arg.node.expr if isinstance(raw_arg, torch.SymInt) else raw_arg
+                )
+                return f"PyLong_FromLongLong({cexpr(expr)})"
+            elif isinstance(arg_type, torch.FloatType):
+                return f"PyFloat_FromDouble({self.generate_float_value(raw_arg)})"
+            elif isinstance(arg_type, torch.BoolType):
+                return f"PyBool_FromLong({1 if raw_arg else 0})"
+            elif isinstance(arg_type, torch.StringType):
+                return f'PyUnicode_FromString("{raw_arg}")'
+            elif isinstance(arg_type, torch.NumberType):
+                # Union[bool, int, float, complex]
+                # torch/_prims_common/__init__.py
+                return handle_scalar(raw_arg)
+            elif isinstance(raw_arg, torch.device):
+                device_str, device_index = self.codegen_device(raw_arg).split(", ")
+                return f"THPDevice_New(c10::Device(static_cast({device_str}), {device_index}))"
+            elif isinstance(raw_arg, torch.dtype):
+                return f"Py_NewRef(torch::getTHPDtype(static_cast({self.codegen_dtype(raw_arg)})))"
+            elif isinstance(raw_arg, torch.layout):
+                return f"Py_NewRef(torch::getTHPLayout(static_cast({self.codegen_layout(raw_arg)})))"
+            elif isinstance(raw_arg, torch.memory_format):
+                return (
+                    "Py_NewRef(torch::utils::getTHPMemoryFormat(static_cast("
+                    f"{self.codegen_memory_format(raw_arg)})))"
+                )
+            else:
+                raise NotImplementedError(
+                    f"arg type {arg_type} is not yet supported by custom_op_wrapper"
+                )
+
+        lines = []
+        if isinstance(arg_type, torch.ListType):
+            assert isinstance(raw_arg, (list, tuple)), str(raw_arg) + " is not a list"
+            lines.append(
+                f"PyObject* {py_args_var}_{idx} = PyList_New({len(raw_arg)});\n"
+            )
+            for i, elem in enumerate(raw_arg):
+                lines.append(
+                    f"PyList_SetItem({py_args_var}_{idx}, {i}, {generate_py_arg_inner(lines, elem, arg_type.getElementType())});\n"
+                )
+            lines.append(
+                f"PyTuple_SetItem({py_args_var}, {idx}, {py_args_var}_{idx});\n"
+            )
+        else:
+            lines.append(
+                f"PyTuple_SetItem({py_args_var}, {idx}, {generate_py_arg_inner(lines, raw_arg, arg_type)});\n"
+            )
+        return "".join(lines)
+
+    def generate_fallback_kernel_with_runtime_lookup_nopython(
+        self,
+        get_args: Callable[[], Sequence[str]],
+        op_overload: torch._ops.OpOverload,
+        output_args: Sequence[Optional[str]],
+        raw_outputs: Sequence[ir.Buffer],
+    ) -> None:
+        """Generate fallback kernel calls with runtime (non-AOT) dispatch.  This can
+        only be called in cpp_wrapper mode, and assumes that the input is a non-None
+        OpOverload.
+
+        In the future, we may switch over to directly calling c10::Dispatcher if we need
+        to support more datatypes."""
+        if raw_outputs:
+            declarations_before_scope = [
+                f"RAIIAtenTensorHandle {output_arg};"
+                for output_arg, raw_output_arg in zip(output_args, raw_outputs)  # type: ignore[arg-type]
+                if output_arg is not None
+                and not isinstance(raw_output_arg, ir.MutationOutput)
+            ]
+        else:
+            declarations_before_scope = [
+                f"RAIIAtenTensorHandle {output_arg};"
+                for output_arg in output_args  # type: ignore[arg-type]
+                if output_arg is not None
+            ]
+
+        dispatch_lines = IndentedBuffer()
+        dispatch_lines.writelines(declarations_before_scope)
+        dispatch_lines.writeline("{")
+
+        with dispatch_lines.indent():
+            tmp_var_number = count()
+
+            def parse_arg(arg_type: torch.JitType, codegen_arg: str) -> str:
+                # Strip off any temporary references; we're in an indented context, so
+                # any saved-off variables will be auto-destroyed.
+                new_codegen_arg = codegen_arg.removeprefix("&temporary_reference(")
+                if new_codegen_arg != codegen_arg:
+                    # If we removed temporary_reference, there's a good chance the
+                    # variable ends with get() (which would retrieve an ATenTensorHandle
+                    # from a temporary RAII handle).  Strip that off too, since we're
+                    # going to save this in a temporary RAII handle.
+                    if codegen_arg.endswith(".get())"):
+                        codegen_arg = new_codegen_arg.removesuffix(".get())")
+                    else:
+                        codegen_arg = new_codegen_arg.removesuffix(")")
+
+                if isinstance(arg_type, torch.OptionalType):
+                    # If we have a pointer to a variable, strip it off and let
+                    # from handle any internal pointers.
+                    codegen_arg = codegen_arg.removeprefix("&")
+
+                    if codegen_arg == "nullptr":
+                        return "torch::stable::detail::from(std::nullopt)"
+
+                    var_name = f"tmp_var_{next(tmp_var_number)}"
+                    dispatch_lines.writeline(
+                        f"std::optional {var_name}{{{parse_arg(arg_type.getElementType(), codegen_arg)}}};"
+                    )
+                    return f"torch::stable::detail::from({var_name})"
+
+                raii_var = self.create_tmp_raii_handle_var_if_needed(
+                    codegen_arg, dispatch_lines
+                )
+                temp_handle = raii_var != codegen_arg
+
+                if isinstance(arg_type, torch.TensorType):
+                    if not temp_handle:
+                        # If the RAII tensor being referenced _isn't_ a temporary,
+                        # scoped to this fallback call, then create a new handle
+                        # referencing it which from can steal.
+                        var_name = f"tmp_var_{next(tmp_var_number)}"
+                        dispatch_lines.writeline(f"AtenTensorHandle {var_name};")
+                        dispatch_lines.writeline(
+                            f"aoti_torch_new_tensor_handle({raii_var}, &{var_name});"
+                        )
+                        return f"torch::stable::detail::from({var_name})"
+                    # If the RAII tensor _is_ a temporary scoped to this fallback call,
+                    # simply release and steal the handle.
+                    return f"torch::stable::detail::from({raii_var}.release())"
+                return f"torch::stable::detail::from({codegen_arg})"
+
+            codegen_args = get_args()
+            ivalue_args = (
+                parse_arg(a.type, c)
+                for a, c in zip(op_overload._schema.arguments, codegen_args)
+            )
+            array_len = max(len(codegen_args), len(output_args))
+            dispatch_lines.writeline(
+                f"std::array dispatch_vars{{{', '.join(ivalue_args)}}};"
+            )
+            dispatch_lines.writeline("AOTI_TORCH_ERROR_CODE_CHECK(")
+            with dispatch_lines.indent():
+                dispatch_lines.writeline(
+                    f'aoti_torch_call_dispatcher("{op_overload._schema.name}", "{op_overload._schema.overload_name}", dispatch_vars.data())'  # noqa: B950
+                )
+            dispatch_lines.writeline(");")
+
+            if len(output_args) == 1 and (output := output_args[0]) is not None:
+                # result is a single tensor
+                dispatch_lines.writeline(
+                    f"{output} = torch::stable::detail::to(dispatch_vars[0]);"
+                )
+            else:
+                # result is a tuple of tensors
+                for idx, output_arg in enumerate(output_args):
+                    if output_arg is None:
+                        continue
+                    dispatch_lines.writeline(
+                        f"{output_arg} = torch::stable::detail::to(dispatch_vars[{idx}]);"
+                    )
+
+        dispatch_lines.writeline("}")
+        self.writelines(dispatch_lines.getvalue().splitlines())
+
+    def generate_fallback_kernel_with_runtime_lookup_python(
+        self,
+        buf_name: str,
+        python_kernel_name: str,
+        op_overload: torch._ops.OpOverload,
+        raw_args: Sequence[Any],
+        output_args: Sequence[Optional[str]],
+        raw_outputs: Sequence[ir.Buffer],
+    ) -> None:
+        """Generate fallback kernel calls with runtime (non-AOT) dispatch.  This can
+        only be called in cpp_wrapper mode, and assumes that the input is a non-None
+        OpOverload.
+
+        This function calls into Python to dispatch, which allows it to handle datatypes
+        that cannot be contained in StableIValue, at the cost of some performance."""
+        self.load_custom_op_wrapper()
+
+        num_args = len(raw_args)
+        py_args_var = f"py_args_{next(self.arg_var_id)}"
+        # First arg is always the python op name
+        lines = textwrap.dedent(
+            f"""
+            RAIIPyObject {py_args_var}(PyTuple_New({num_args + 1}));
+            if (!{py_args_var}) {{
+                throw std::runtime_error("PyTuple_New {py_args_var} failed");
+            }}
+            PyTuple_SetItem({py_args_var}, 0, PyUnicode_FromString("{python_kernel_name}"));
+            """
+        )
+
+        for idx, (raw_arg, schema_arg) in enumerate(
+            zip(raw_args, op_overload._schema.arguments)
+        ):
+            lines += self.generate_py_arg(
+                py_args_var, idx + 1, raw_arg, schema_arg.real_type
+            )
+
+        lines += textwrap.dedent(
+            f"""
+            // Call the custom op in Python
+            RAIIPyObject py_{buf_name}(PyObject_CallObject(custom_op_wrapper, {py_args_var}));
+            if (!py_{buf_name}) {{
+                if (PyErr_Occurred()) {{
+                    return;
+                }}
+                throw std::runtime_error("PyObject_CallObject {python_kernel_name} failed");
+            }}
+            """
+        )
+
+        if len(output_args) == 1 and (output := output_args[0]) is not None:
+            # result is a single tensor
+            lines += f"{output} = reinterpret_cast(PyCapsule_GetPointer(py_{buf_name}.get(), NULL));\n"
+        else:
+            # result is a tuple of tensors
+            for idx, output_arg in enumerate(output_args):
+                if output_arg is None:
+                    continue
+                lines += f"{output_arg} = reinterpret_cast(PyCapsule_GetPointer(PyList_GET_ITEM(py_{buf_name}.get(), {idx}), NULL));\n"  # noqa: B950
+
+        if raw_outputs:
+            declarations_before_scope = [
+                f"RAIIAtenTensorHandle {output_arg};"
+                for output_arg, raw_output_arg in zip(output_args, raw_outputs)  # type: ignore[arg-type]
+                if output_arg is not None
+                and not isinstance(raw_output_arg, ir.MutationOutput)
+            ]
+        else:
+            declarations_before_scope = [
+                f"RAIIAtenTensorHandle {output_arg};"
+                for output_arg in output_args  # type: ignore[arg-type]
+                if output_arg is not None
+            ]
+        scope_gil_acquire = self.generate_scoped_gil_acquire(
+            declarations_before_scope, lines
+        )
+        self.writelines(scope_gil_acquire)
+
+    def generate_fallback_kernel_with_runtime_lookup_aot(
+        self,
+        op_overload: Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator],
+        raw_args: Sequence[Any],
+        output_args: _OUTPUT_ARGS_TYPE,
+        raw_outputs: Sequence[ir.Buffer],
+    ) -> None:
+        (
+            tensor_call_args,
+            int_call_args,
+        ) = self.generate_extern_kernel_args_decl_if_needed(
+            op_overload,
+            raw_args,
+            output_args,
+            raw_outputs,
+        )
+        # force both temporary arrays to generate mutable data pointers, since the proxy
+        # executor signature requires that datatype
+        int_call_str = self._generate_temporary_array_pointer(
+            "int64_t", int_call_args, force_mutable=True
+        )
+        tensor_call_str = self._generate_temporary_array_pointer(
+            "AtenTensorHandle", tensor_call_args, force_mutable=True
+        )
+
+        extern_kernel_node_index = len(V.extern_kernel_nodes) - 1
+        self.writeline(
+            f"aoti_torch_proxy_executor_call_function(proxy_executor, "
+            f"{extern_kernel_node_index}, "
+            f"{len(int_call_args)}, "
+            f"{int_call_str}, "
+            f"{len(tensor_call_args)}, "
+            f"{tensor_call_str});"
+        )
+
+    def generate_reset_kernel_saved_flags(self):
+        pass
+
+    def generate_save_uncompiled_kernels(self):
+        pass
+
+    def c_type_for_prim_type(self, val, type_) -> str:
+        if isinstance(type_, torch.OptionalType):
+            return f"{self.c_type_for_prim_type(val, type_.getElementType())}*"
+        elif isinstance(type_, torch.TensorType):
+            return "AtenTensorHandle"
+        elif isinstance(type_, (torch.IntType, torch.SymIntType)):
+            return "int64_t"
+        elif isinstance(
+            type_, (torch.BoolType, torch.SymBoolType, torch.EnumType)
+        ) or repr(type_) in ("Layout", "MemoryFormat", "ScalarType"):
+            return "int32_t"
+        elif isinstance(type_, torch.FloatType):
+            return "double"
+        elif isinstance(type_, torch.NumberType):
+            if isinstance(val, bool):
+                return "int32_t"
+            elif isinstance(val, (int, float)):
+                return "double"
+            elif val is None:
+                # This could happen when val is an optional value
+                return "double"
+            else:
+                raise AssertionError(
+                    f"Unexpected type in c_type_for_prim_type: {type_=}"
+                )
+        elif isinstance(type_, torch.StringType):
+            return "const char*"
+        else:
+            raise AssertionError(f"Unexpected type in c_type_for_prim_type: {type_=}")
+
+    def val_to_arg_str_for_prim_type(self, val, type_) -> str:
+        # TODO: not using type_ as the first step of refactoring. Will update this later.
+        if isinstance(val, bool):
+            return "1" if val else "0"
+        elif isinstance(val, int):
+            # uint64_t is long on Linux, but long long on MacOS and Windows
+            return f"{val}LL" if sys.platform in ["darwin", "win32"] else f"{val}L"
+        elif isinstance(val, complex):
+            return f"c10::complex{{ {self.generate_float_value(val.real)}, {self.generate_float_value(val.imag)} }}"
+        elif isinstance(val, str):
+            return f'"{val}"'
+        elif isinstance(
+            val, (ir.Buffer, ir.ReinterpretView, ir.StorageBox, ir.TensorBox)
+        ):
+            return val.codegen_reference()
+        elif isinstance(val, torch.device):
+            return self.codegen_device(val)
+        elif isinstance(val, torch.dtype):
+            return self.codegen_dtype(val)
+        elif isinstance(val, torch.layout):
+            return self.codegen_layout(val)
+        elif isinstance(val, torch.memory_format):
+            return self.codegen_memory_format(val)
+        elif isinstance(val, float):
+            return self.generate_float_value(val)
+        elif isinstance(val, (list, tuple)):
+            # FIXME: This happens because type_ is not always properly set to torch.ListType
+            return f"{{{', '.join(self.val_to_arg_str(x, None) for x in val)}}}"
+        elif isinstance(val, SymTypes):
+            return cexpr(val.node.expr)
+        elif isinstance(val, sympy.Expr):
+            return cexpr(val)
+        else:
+            return repr(val)
+
+    def val_to_arg_str(self, val, type_=None) -> str:
+        if val is None:
+            # None needs special care. It either represent nullopt or an empty tensor
+            if type_ is None or isinstance(type_, torch.OptionalType):
+                if type_ is not None and isinstance(
+                    type_.getElementType(),
+                    (
+                        torch.DeviceObjType,
+                        torch.ListType,
+                        torch.TupleType,
+                    ),
+                ):
+                    return "nullptr, 0"
+                return "nullptr"
+
+            if isinstance(type_, torch.TensorType):
+                # create an empty tensor, the equivalent of at::Tensor()
+                var_name = f"var_{next(self.arg_var_id)}"
+                self.writeline(f"AtenTensorHandle {var_name}_handle;")
+                self.writeline(
+                    f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&{var_name}_handle));"
+                )
+                self.writeline(f"RAIIAtenTensorHandle {var_name}({var_name}_handle);")
+                return var_name
+
+            raise AssertionError("Can not map None to a known data type")
+
+        if isinstance(type_, torch.OptionalType):
+            element_type = type_.getElementType()
+            arg_str = self.val_to_arg_str(val, element_type)
+            # Handle optional iterables as a special case.  Utilize the
+            # temporary_reference function to avoid saving them off and increasing
+            # memory usage.
+            if isinstance(element_type, (torch.ListType, torch.TupleType)):
+                main_value, aux = arg_str.rsplit(", ", maxsplit=1)
+                return f"&temporary_reference({main_value}), {aux}"
+
+            # Handle optional tensors as a special case, as above.
+            if isinstance(element_type, torch.TensorType):
+                base_handle = self.val_to_arg_str(val, element_type)
+                return f"&temporary_reference({base_handle}.get())"
+
+            var_name = f"var_{next(self.arg_var_id)}"
+            if isinstance(element_type, torch.DeviceObjType):
+                main_value, aux = arg_str.rsplit(", ", maxsplit=1)
+                self.writeline(f"auto {var_name} = {main_value};")
+                return f"&{var_name}, {aux}"
+
+            self.writeline(
+                f"{self.c_type_for_prim_type(val, element_type)} {var_name} = {arg_str};"
+            )
+            return f"&{var_name}"
+
+        if isinstance(type_, (torch.ListType, torch.TupleType)):
+            assert isinstance(val, (list, tuple)), (
+                f"{val} does not match with arg type {type_}"
+            )
+            element_type = type_.getElementType()
+
+            if len(val) == 0:
+                # Zero-size array is not supported in the C or C++ standard, so return a
+                # nullptr.
+                return "nullptr, 0"
+
+            result = [self.val_to_arg_str(x, element_type) for x in val]
+            if isinstance(element_type, torch.TensorType):
+                result = [f"{t}.get()" for t in result]
+
+            c_type = self.c_type_for_prim_type(val[0], element_type)
+            # see the comment in self._generate_temporary_array_pointer for an
+            # explanation of why this c_type gets modified
+            if isinstance(element_type, torch.OptionalType) and not c_type.startswith(
+                "const"
+            ):
+                c_type = f"const {c_type}"
+
+            # need to pass the array length, because we can't use the std::array member
+            # function
+            return (
+                f"{self._generate_temporary_array_pointer(c_type, result)}, {len(val)}"
+            )
+
+        val_is_scalar = isinstance(val, (bool, complex, float, int, *SymTypes))
+        if isinstance(type_, torch.TensorType) and val_is_scalar:
+            val_str = self.val_to_arg_str_for_prim_type(val, None)
+            return self.codegen_scalar_to_tensor(val_str)
+
+        return self.val_to_arg_str_for_prim_type(val, type_)
+
+    def create_tmp_raii_handle_var_if_needed(
+        self, handle: str, writer: Optional[Union[HasWriteLine, list[str]]] = None
+    ) -> str:
+        """If the input handle is an rvalue RAII tensor, creates an lvalue variable for
+        it in writer.  Returns a variable name that can be used to access handle."""
+        if not handle.startswith(
+            (
+                "borrow_arrayref_tensor_as_tensor(",
+                "copy_arrayref_tensor_to_tensor(",
+                "wrap_with_raii_handle_if_needed(",
+                "RAIIAtenTensorHandle(",
+            )
+        ):
+            return handle
+
+        tmp_var_name = f"var_{next(self.arg_var_id)}"
+        call_str = f"auto {tmp_var_name} = {handle};"
+
+        writer = writer if writer is not None else self
+        if isinstance(writer, list):
+            writer.append(call_str)
+        else:
+            writer.writeline(call_str)
+
+        return tmp_var_name
+
+    def write_kernel_context_guard_begin(
+        self,
+    ):
+        # Beginning of a kernel context guarded block.
+        # The block looks like this:
+        # {
+        # KernelContextGuard _ctx("{kernel_name}", {stack_trace_str});
+        # ... operations...
+        # }
+        self.writeline("{")
+
+    def write_kernel_context_guard_end(
+        self,
+    ):
+        # End of a kernel context guarded block.
+        self.writeline("}")
+
+    def write_kernel_context_guard(
+        self,
+        kernel_name: str,
+        node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel],
+    ):
+        def aggregate_stack_traces(
+            node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel],
+        ) -> OrderedSet[str]:
+            if isinstance(node_schedule, list):
+                return functools.reduce(
+                    lambda a, b: a | b,
+                    [
+                        # pyrefly: ignore [missing-attribute]
+                        node.node.get_stack_traces()
+                        for node in node_schedule
+                        if hasattr(node, "node") and node.node
+                    ],
+                    OrderedSet(),
+                )
+            elif isinstance(node_schedule, ExternKernel):
+                return node_schedule.get_stack_traces()
+            else:
+                return OrderedSet()
+
+        stack_trace_str = 'R"('
+        stack_traces = aggregate_stack_traces(node_schedule)
+
+        for stack_trace in stack_traces:
+            for line in stack_trace.split("\n"):
+                stack_trace_str += f"\n{line}"
+            stack_trace_str += "\n"
+        stack_trace_str += ')"'
+        self.writeline(f'KernelContextGuard _ctx("{kernel_name}", {stack_trace_str});')
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0c9aef609ba483ad9178f0653f52a20b1b2ea2f
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py
@@ -0,0 +1,897 @@
+# mypy: allow-untyped-defs
+from collections.abc import Callable, Sequence
+from typing import Any, Optional, Union
+
+import sympy
+
+import torch
+import torch._inductor.async_compile  # noqa: F401 required to warm up AsyncCompile pools
+import torch._ops
+
+from .. import config, ir
+from ..utils import sympy_product
+from ..virtualized import V
+from .cpp_utils import DTYPE_TO_CPP
+from .cpp_wrapper_cpu import CppWrapperCpu
+from .wrapper import (
+    BufferLike,
+    EnterSubgraphLine,
+    ExitSubgraphLine,
+    MemoryPlanningLine,
+    MemoryPlanningState,
+    PythonWrapperCodegen,
+)
+
+
+BufferName = str
+
+# Default thread stack sizes vary by platform:
+# - Linux: 8 MB
+# - macOS: 512 KB
+# - Windows: 1 MB
+# Just pick something comfortably smaller than the smallest for now.
+MAX_STACK_ALLOCATION_SIZE = 1024 * 100
+
+
+class CppWrapperCpuArrayRef(CppWrapperCpu):
+    """
+    Generates cpp wrapper for running on CPU and calls cpp kernels
+
+    This class is forked from CppWrapperCpu, with a difference that tensors may be
+    represented as ArrayRef, see torch/csrc/inductor/aoti_runtime/arrayref_tensor.h
+    """
+
+    def __init__(self):
+        super().__init__()
+        assert self.device == "cpu", "ArrayRefTensor only supported on CPU!"
+        self.allow_stack_allocation = config.aot_inductor.allow_stack_allocation
+        self.stack_allocated_buffers: dict[BufferName, BufferLike] = {}
+
+    @staticmethod
+    def create(
+        is_subgraph: bool,
+        subgraph_name: Optional[str],
+        parent_wrapper: Optional[PythonWrapperCodegen],
+        partition_signatures: Optional[ir.GraphPartitionSignature] = None,
+    ):
+        # TODO - support subgraph codegen by lifting functions. Check the
+        # comment at CppWrapperCpu `codegen_subgraph` function.
+        return CppWrapperCpuArrayRef()
+
+    @staticmethod
+    def get_input_cpp_type(input):
+        assert config.aot_inductor.use_minimal_arrayref_interface
+
+        if isinstance(input, sympy.Expr):
+            from ..graph import may_get_constant_buffer_dtype
+
+            dtype = may_get_constant_buffer_dtype(input)
+            assert dtype is not None, f"Failed to get the dtype of sympy.Expr: {input}"
+            return DTYPE_TO_CPP[dtype]
+        return f"ArrayRefTensor<{DTYPE_TO_CPP[input.get_dtype()]}>"
+
+    @staticmethod
+    def get_device_include_path(device: str) -> str:
+        assert device == "cpu", "ArrayRef only supported on CPU!"
+        if V.graph.aot_mode:
+            return "#include "
+        return "#include "
+
+    def codegen_input_numel_asserts(self):
+        for name, buf in V.graph.graph_inputs.items():
+            if isinstance(buf, sympy.Expr):
+                continue
+
+            # comparing strides for 0 size tensor is tricky. Ignore them for now.
+            if sympy_product(buf.get_size()) == 0:
+                continue
+            numel = buf.get_numel()
+            self.prefix.writeline(f"assert_numel({name}, {numel});")
+
+    def generate_extern_kernel_alloc(self, *args, **kwargs):
+        # Disable stack allocation for extern kernels.
+        self.allow_stack_allocation = False
+        super().generate_extern_kernel_alloc(*args, **kwargs)
+
+    def generate_extern_kernel_out(self, *args, **kwargs):
+        # Disable stack allocation for extern kernels.
+        self.allow_stack_allocation = False
+        super().generate_extern_kernel_out(*args, **kwargs)
+
+    def generate_fallback_kernel(self, node: ir.FallbackKernel) -> None:
+        # Disable stack allocation for extern kernels.
+        self.allow_stack_allocation = False
+        super().generate_fallback_kernel(node)
+
+    def _generate_kernel_call_helper(
+        self,
+        kernel_name: str,
+        call_args,
+        *,
+        device=None,
+        triton=True,
+        arg_types=None,
+        raw_keys=None,
+        raw_args=None,
+        triton_meta=None,
+        graph_name="",
+        original_fxnode_name=None,
+    ):
+        """
+        Generates kernel call code.
+
+        triton: Defines whether the GPU backend uses Triton for codegen.
+                Otherwise it uses the CUDA language for codegen.
+                Only valid when cuda == True.
+        """
+        assert not triton, (
+            "CppWrapperCpuArrayRef.generate_kernel_call does not support GPU"
+        )
+        assert arg_types is not None and len(call_args) == len(arg_types), (
+            "Mismatch call_args and arg_types in generate_kernel_call"
+        )
+        new_args = []
+        for idx, arg in enumerate(call_args):
+            if "*" in arg_types[idx]:
+                var_name = f"var_{next(self.arg_var_id)}"
+                self.writeline(f"auto* {var_name} = get_data_ptr_wrapper({arg});")
+                new_args.append(f"({arg_types[idx]})({var_name})")
+            else:
+                # arg is a scalar
+                new_args.append(arg)
+        # debug printer related logic for cpp kernel type.
+        debug_printer_manager = V.graph.wrapper_code.debug_printer
+        debug_printer_manager.set_printer_args(
+            call_args,
+            kernel_name,
+            None,
+            None,
+            "cpp",
+        )
+        with debug_printer_manager:
+            self.writeline(self.wrap_kernel_call(kernel_name, new_args))
+
+    def write_wrapper_decl(self):
+        inputs_len = len(V.graph.graph_inputs.keys())
+        if V.graph.aot_mode:
+            if (
+                config.aot_inductor.use_minimal_arrayref_interface
+                and not V.graph.is_const_graph
+            ):
+                input_cpp_types = ", ".join(
+                    f"{CppWrapperCpuArrayRef.get_input_cpp_type(x)}"
+                    for x in V.graph.graph_inputs.values()
+                )
+                output_arrayref_types = ", ".join(
+                    f"ArrayRefTensor<{DTYPE_TO_CPP[x.get_dtype()]}>"
+                    for x in V.graph.graph_outputs
+                )
+
+                self.prefix.splice(
+                    f"""
+                    using AOTInductorModelInputs = std::tuple<{input_cpp_types}>;
+                    using AOTInductorModelOutputs = std::tuple<{output_arrayref_types}>;
+                    """
+                )
+
+            if V.graph.const_module:
+                self.header.splice(V.graph.const_module.wrapper_code.header)
+
+                assert V.graph.const_wrapper_code is not None
+                self.prefix.splice(V.graph.const_wrapper_code)
+
+                assert V.graph.const_kernel_code is not None
+                self.kernel_declarations.splice(V.graph.const_kernel_code)
+
+            if V.graph.is_const_graph:
+                self.prefix.splice(
+                    """
+                    void AOTInductorModel::_const_run_impl(
+                        std::vector& output_handles,
+                        DeviceStreamType stream,
+                        AOTIProxyExecutorHandle proxy_executor
+                    ) {
+                    """
+                )
+            else:
+                if not config.aot_inductor.use_runtime_constant_folding:
+                    # If we do not split the constant graph, we'll just create
+                    # an empty implementation when wrapping the main module.
+                    self.prefix.splice(
+                        """
+                        void AOTInductorModel::_const_run_impl(
+                            std::vector& output_handles,
+                            DeviceStreamType stream,
+                            AOTIProxyExecutorHandle proxy_executor
+                        ) {}
+
+                        """
+                    )
+
+                run_impl_proto = """
+                    void AOTInductorModel::run_impl(
+                        AtenTensorHandle*
+                            input_handles, // array of input AtenTensorHandle; handles
+                                            // are stolen; the array itself is borrowed
+                        AtenTensorHandle*
+                            output_handles, // array for writing output AtenTensorHandle; handles
+                                            // will be stolen by the caller; the array itself is
+                                            // borrowed
+                        DeviceStreamType stream,
+                        AOTIProxyExecutorHandle proxy_executor
+                    ) {
+                    """
+
+                self.generate_input_output_runtime_checks()
+                run_impl_proto += """
+                    __check_inputs_outputs(input_handles, output_handles);
+                """
+
+                if config.aot_inductor.use_minimal_arrayref_interface:
+                    self.prefix.splice(
+                        """
+                        template <>
+                        AOTInductorModelOutputs AOTInductorModel::run_impl_minimal_arrayref_interface<
+                          AOTInductorModelInputs, AOTInductorModelOutputs>(
+                            const AOTInductorModelInputs& inputs,
+                            DeviceStreamType stream,
+                            AOTIProxyExecutorHandle proxy_executor
+                        ) {
+                        """
+                    )
+                    self.suffix.splice(run_impl_proto)
+                    self.suffix.splice(
+                        """
+                            AOTInductorModelInputs inputs;
+                            convert_handles_to_inputs(input_handles, inputs);
+                            auto outputs = run_impl_minimal_arrayref_interface(
+                                inputs, stream, proxy_executor);
+                            // NOTE: outputs is full of ArrayRef to thread_local storage. If in the future we need this
+                            // interface to perform well for a DSO using the minimal arrayref interface, all we need
+                            // to do is provide ThreadLocalCachedTensor for each one!
+                            convert_outputs_to_handles(outputs, output_handles);
+                        }
+                    """
+                    )
+
+                    self.suffix.splice(
+                        """
+                        extern "C" AOTIRuntimeError AOTInductorModelRunMinimalArrayrefInterface(
+                            AOTInductorModelHandle model_handle,
+                            const AOTInductorModelInputs& inputs,
+                            AOTInductorModelOutputs& outputs) {
+                          auto model = reinterpret_cast(model_handle);
+                          CONVERT_EXCEPTION_TO_ERROR_CODE({
+                              outputs = model->run_impl_minimal_arrayref_interface(
+                                  inputs,
+                                  (torch::aot_inductor::DeviceStreamType)nullptr,
+                                  nullptr);
+                          })
+                        }
+                    """
+                    )
+                else:
+                    self.prefix.splice(run_impl_proto)
+        else:
+            # cpp entry function for JIT with cpp wrapper
+            self.prefix.splice(
+                """
+                void inductor_entry_impl(
+                    AtenTensorHandle*
+                        input_handles, // array of input AtenTensorHandle; handles
+                                        // are stolen; the array itself is borrowed
+                    AtenTensorHandle*
+                        output_handles  // array for writing output AtenTensorHandle; handles
+                                        // will be stolen by the caller; the array itself is
+                                        // borrowed)
+                ) {
+                """
+            )
+        with self.prefix.indent():
+            # assign inputs and outputs in both cases so the later codegen can be simplified
+            if not config.aot_inductor.use_minimal_arrayref_interface:
+                if not V.graph.is_const_graph:
+                    if V.graph.aot_mode:
+                        num_args = len(V.graph.graph_inputs)
+                    else:
+                        # Weights are promoted in the JIT mode
+                        num_args = len(V.graph.graph_inputs) + len(V.graph.constants)
+                        # release GIL to support multiple instances inference (in different threads of the same process)
+                        self.prefix.splice("py::gil_scoped_release_simple release;")
+
+                    self.prefix.splice(
+                        f"""
+                            auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, {num_args});
+                        """
+                    )
+
+            if inputs_len != 0:
+                for idx, input_key in enumerate(V.graph.graph_inputs.keys()):
+                    if config.aot_inductor.use_minimal_arrayref_interface:
+                        self.prefix.writeline(
+                            f"auto {input_key} = std::get<{idx}>(inputs);"
+                        )
+                        continue
+                    # unwrap input tensor back to scalar
+                    if isinstance(V.graph.graph_inputs[input_key], sympy.Expr):
+                        from ..graph import may_get_constant_buffer_dtype
+
+                        dtype = may_get_constant_buffer_dtype(
+                            V.graph.graph_inputs[input_key]  # type: ignore[arg-type]
+                        )
+                        assert dtype is not None, (
+                            "Fails to get the dtype of the sympy.Expr"
+                        )
+                        self.codegen_tensor_item(
+                            dtype, f"inputs[{idx}]", input_key, self.prefix
+                        )
+                    else:
+                        self.prefix.writeline(
+                            f"auto {input_key} = std::move(inputs[{idx}]);"
+                        )
+
+            assert all(
+                isinstance(v, torch.Tensor) for v in list(V.graph.constants.values())
+            ), "Expect all constants to be Tensor"
+            for idx, constants_key in enumerate(V.graph.constants.keys()):
+                if V.graph.aot_mode:
+                    # Weights are stored in constants_ and owned by RAIIAtenTensorHandle there.
+                    # Don't call std::move here because it will cause constants_ to lose the ownership.
+                    self.prefix.writeline(
+                        f"""auto {constants_key} = constants_->at({idx});"""
+                    )
+                else:
+                    # Append constants as inputs to the graph
+                    constants_idx = inputs_len + idx
+                    self.prefix.writeline(
+                        f"auto {constants_key} = std::move(inputs[{constants_idx}]);"
+                    )
+
+            self.codegen_inputs()
+
+            if V.graph.aot_mode:
+                if not V.graph.is_const_graph:
+                    if config.aot_inductor.use_minimal_arrayref_interface:
+                        # TODO: input shape checking for regular tensor interface as well?
+                        self.codegen_input_numel_asserts()
+                    else:
+                        self.prefix.writeline("inputs.clear();")
+                self.prefix.writeline(
+                    "[[maybe_unused]] auto& kernels = static_cast(*this->kernels_.get());"
+                )
+
+    def generate_return(self, output_refs: list[str]):
+        cst_names = V.graph.constants.keys()
+        arr_iface = (
+            not V.graph.is_const_graph
+            and config.aot_inductor.use_minimal_arrayref_interface
+        )  # For brevity.
+
+        def use_thread_local_cached_output_tensor(idx, output):
+            cached_output_name = f"cached_output_{next(self.cached_output_id)}"
+            cache_type = "Array" if arr_iface else "Tensor"
+            self.wrapper_call.writeline(
+                f"thread_local ThreadLocalCachedOutput{cache_type}> "
+                f"{cached_output_name}({output});"
+            )
+            if arr_iface:
+                self.wrapper_call.writeline(
+                    f"{cached_output_name}.copy_data_from({output});"
+                )
+                output_entry = f"std::get<{idx}>(output_arrayref_tensors)"
+                element_type = f"std::decay_t"
+                self.wrapper_call.writeline(
+                    f"{output_entry} = {cached_output_name}.arrayref_tensor<{element_type}>();"
+                )
+            else:
+                self.wrapper_call.writeline(
+                    f"{cached_output_name}.copy_data_from({output});"
+                )
+                self.wrapper_call.writeline(
+                    f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&output_handles[{idx}]));"
+                )
+                self.wrapper_call.writeline(
+                    f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors({cached_output_name}.tensor(), "
+                    f"output_handles[{idx}]));"
+                )
+
+        if arr_iface:
+            self.wrapper_call.writeline(
+                "AOTInductorModelOutputs output_arrayref_tensors;"
+            )
+
+        output2idx: dict[str, int] = {}
+        for idx, output in enumerate(output_refs):
+            if output == "nullptr":
+                continue
+
+            is_constant_buffer = output in cst_names
+            output_buffer = V.graph.graph_outputs[idx]
+            if isinstance(output_buffer, ir.BaseView):
+                output_storage = output_buffer.unwrap_view()
+                assert isinstance(output_storage, (ir.BaseView, ir.MutableBox))
+                if isinstance(output_storage.data, ir.ConstantBuffer):
+                    is_constant_buffer = True
+
+            if isinstance(output_buffer, ir.ShapeAsConstantBuffer):
+                # Need to wrap scalar into tensor as the main function returns a vector of tensors
+                output_tensor = self.codegen_scalar_to_tensor(output)
+                self.wrapper_call.writeline(
+                    f"output_handles[{idx}] = {output_tensor}.release();"
+                )
+                continue
+
+            output_is_tensor_handle_expr = (
+                f"std::is_same_v,"
+                "RAIIAtenTensorHandle> || "
+                f"std::is_same_v,"
+                "AtenTensorHandle> || "
+                f"std::is_same_v,"
+                "ConstantHandle>"
+            )
+            self.wrapper_call.writeline(
+                f"if constexpr ({output_is_tensor_handle_expr}) {{"
+            )
+            with self.wrapper_call.indent():
+                if arr_iface:
+                    cached_output_name = f"cached_output_{next(self.cached_output_id)}"
+                    self.wrapper_call.writeline(
+                        f"thread_local RAIIAtenTensorHandle {cached_output_name};"
+                    )
+                    if is_constant_buffer:
+                        # NOTE(return_constant): In some rare cases where we return
+                        # a constant, we have to return a copy of this constant,
+                        # because (1) constants are not owned by the Model instance
+                        # (2) constants remain the same cross inference runs,
+                        # assuming they are not updated at runtime Basically, we
+                        # cannot release or transfer the ownership of any original
+                        # constant to the user.
+                        self.wrapper_call.writeline(
+                            f"AtenTensorHandle {cached_output_name}_tmp;"
+                        )
+                        self.wrapper_call.writeline(
+                            f"aoti_torch_clone({output}, &{cached_output_name}_tmp);"
+                        )
+                        self.wrapper_call.writeline(
+                            f"{cached_output_name} = {cached_output_name}_tmp;"
+                        )
+                    else:
+                        self.wrapper_call.writeline(
+                            f"{cached_output_name} = {output}.release();"
+                        )
+                    self.wrapper_call.writeline(
+                        f"convert_handle_to_arrayref_tensor({cached_output_name}, "
+                        f"std::get<{idx}>(output_arrayref_tensors));"
+                    )
+                else:
+                    if is_constant_buffer:
+                        # See NOTE(return_constant) above.
+                        self.wrapper_call.writeline(
+                            f"aoti_torch_clone({output}, &output_handles[{idx}]);"
+                        )
+                    else:
+                        if output in output2idx:
+                            src_idx = output2idx[output]
+                            self.wrapper_call.writeline(
+                                f"output_handles[{idx}] = output_handles[{src_idx}];"
+                            )
+                        else:
+                            self.wrapper_call.writeline(
+                                f"output_handles[{idx}] = {output}.release();"
+                            )
+            self.wrapper_call.writeline("} else {")
+            with self.wrapper_call.indent():
+                use_thread_local_cached_output_tensor(idx, output)
+            self.wrapper_call.writeline("}")
+
+            if output not in output2idx:
+                output2idx[output] = idx
+        if arr_iface:
+            self.wrapper_call.writeline("return output_arrayref_tensors;")
+
+    def memory_plan(self):
+        from .memory_planning import MemoryPlanner
+
+        self.lines = MemoryPlanner(self).plan(self.lines)
+        # TODO: integrate memory planning & stack allocation?
+        self.allow_stack_allocation = False
+
+    def memory_plan_reuse(self):
+        out_names = V.graph.get_output_names()
+
+        while (
+            self.lines
+            and isinstance(self.lines[-1], MemoryPlanningLine)
+            # TODO: this seems legit, NullLine has no node
+            and self.lines[-1].node.name not in out_names  # type: ignore[attr-defined]
+        ):
+            # these lines will be pointless
+            self.lines.pop()
+
+        # codegen allocations in two passes
+        planning_states = [MemoryPlanningState()]
+        past_planning_states = []
+        for i in range(len(self.lines)):
+            line = self.lines[i]
+            if isinstance(line, MemoryPlanningLine):
+                self.lines[i] = line.plan(planning_states[-1])
+            elif isinstance(line, EnterSubgraphLine):
+                planning_states.append(MemoryPlanningState())
+            elif isinstance(line, ExitSubgraphLine):
+                past_planning_states.append(planning_states.pop())
+        past_planning_states.append(planning_states.pop())
+        assert len(planning_states) == 0
+
+        # conservatively use the sum of all allocated buffer sizes
+        # in potentially nested scopes as the total allocated size
+        total_allocated_buffer_size = sum(
+            s.total_allocated_buffer_size for s in past_planning_states
+        )
+
+        self.allow_stack_allocation = (
+            self.allow_stack_allocation is not False
+            and config.aot_inductor.allow_stack_allocation
+            and total_allocated_buffer_size <= MAX_STACK_ALLOCATION_SIZE
+        )
+
+    def can_stack_allocate_buffer(self, buffer):
+        return (
+            self.allow_stack_allocation
+            and buffer.get_device().type == "cpu"
+            and self.can_prove_buffer_has_static_shape(buffer)
+            and ir.is_contiguous_strides_for_shape(
+                buffer.get_stride(), buffer.get_size()
+            )
+        )
+
+    def make_buffer_free(self, buffer):
+        return (
+            ""
+            if isinstance(buffer.get_output_spec(), ir.MultiOutputLayout)
+            or (V.graph.aot_mode and buffer.get_name() in self.stack_allocated_buffers)
+            or (
+                config.aot_inductor.use_minimal_arrayref_interface
+                and V.graph.aot_mode
+                and buffer.get_name() in V.graph.graph_inputs
+            )
+            else f"{buffer.get_name()}.reset();"
+        )
+
+    def make_buffer_allocation(self, buffer):
+        return self.make_allocation(
+            buffer.get_name(),
+            buffer.get_device(),
+            buffer.get_dtype(),
+            buffer.get_size(),
+            buffer.get_stride(),
+            buffer if self.can_stack_allocate_buffer(buffer) else None,
+            buffer.get_is_pinned(),
+        )
+
+    def make_allocation(
+        self,
+        name,
+        device,
+        dtype,
+        shape,
+        stride,
+        buffer_if_can_stack_allocate=None,
+        is_pinned=False,
+    ):
+        orig_stride = stride
+        device_str = self.codegen_device(device)
+        dtype_code = self.codegen_dtype(dtype)
+        size = self.codegen_shape_tuple(shape)
+        stride = self.codegen_shape_tuple(orig_stride)
+        size_array_var = self.codegen_int_array_var(
+            size,
+            self.wrapper_call.writeline,
+            known_statically=self.is_statically_known_list_of_ints(shape),
+            graph=self.get_codegened_graph(),
+        )
+        stride_array_var = self.codegen_int_array_var(
+            stride,
+            self.wrapper_call.writeline,
+            known_statically=self.is_statically_known_list_of_ints(orig_stride),
+            graph=self.get_codegened_graph(),
+        )
+        device_type, device_id = device_str.split(",")
+        device_idx = "this->device_idx_" if V.graph.aot_mode else device_id
+        if buffer_if_can_stack_allocate is not None:
+            self.stack_allocated_buffers[name] = buffer_if_can_stack_allocate
+            cpp_type = DTYPE_TO_CPP[dtype]
+            numel = buffer_if_can_stack_allocate.get_numel()
+            # Note: we don't zero storage because empty_strided doesn't zero either.
+            self.wrapper_call.writeline(f"{cpp_type} {name}_storage[{numel}];")
+            args = [
+                f"{name}_storage",
+                size_array_var,
+                stride_array_var,
+                device_type,
+                device_idx,
+            ]
+            return f"ArrayRefTensor<{cpp_type}> {name}({', '.join(args)});"
+
+        args = [
+            str(len(shape)),
+            size_array_var,
+            stride_array_var,
+            dtype_code,
+            device_type,
+            device_idx,
+            f"&{name}_handle",
+        ]
+
+        self.wrapper_call.writeline(f"AtenTensorHandle {name}_handle;")
+        pinned_str = "_pinned" if is_pinned else ""
+        self.wrapper_call.writeline(
+            f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided{pinned_str}({', '.join(args)}));"
+        )
+
+        return f"RAIIAtenTensorHandle {name}({name}_handle);"
+
+    def make_buffer_reuse(self, old: BufferLike, new: BufferLike, delete_old: bool):
+        assert old.get_dtype() == new.get_dtype()
+        old_name = old.get_name()
+        new_name = new.get_name()
+        del_line = ";"
+        if old_name not in V.graph.get_output_names() and delete_old:
+            del_line = f"; {self.make_buffer_free(old)}"
+
+        if old.get_size() == new.get_size() and old.get_stride() == new.get_stride():
+            if old_name in self.stack_allocated_buffers:
+                self.stack_allocated_buffers[new_name] = new
+            return self.codegen_exact_buffer_reuse(old_name, new_name, del_line)
+
+        reinterpret_view = self.codegen_reinterpret_view(
+            old, new.get_size(), new.get_stride(), 0, self.wrapper_call.writeline
+        )
+        if reinterpret_view in self.stack_allocated_buffers:
+            self.stack_allocated_buffers[new_name] = new
+            # The only way to get into this case is via an exact buffer reuse, since all
+            # other options result in a new tensor handle.
+            return self.codegen_exact_buffer_reuse(old_name, new_name, del_line)
+        return f"{self.declare}{new_name} = {reinterpret_view}{del_line}  // reuse"
+
+    def _assert_safe_to_use_borrow_arrayref_tensor_as_tensor(self):
+        # Borrowing arguments to shim functions is only safe because we know
+        # that the arguments can't be stack-allocated. Otherwise, to be sure
+        # we can't return a dangling pointer, we need to either 1) be
+        # certain that the shim function cannot return an alias of a
+        # borrowed argument, or 2) be certain that the returned Tensor from
+        # the shim function cannot escape.
+        assert self.is_safe_to_use_borrow_arrayref_tensor_as_tensor(), (
+            "borrowing arguments to shim functions is unsafe with "
+            "stack allocation on! (see comment above this assertion)"
+        )
+
+    def is_safe_to_use_borrow_arrayref_tensor_as_tensor(self):
+        return not self.allow_stack_allocation and not self.stack_allocated_buffers
+
+    def generate_c_shim_extern_kernel_call(
+        self, kernel: str, args: list[str], device: str, **_
+    ) -> None:
+        # In the abi_compatible mode, we call fallback aten ops through a C shim layer
+        # Setting self.allow_stack_allocation to False because the exchange between
+        # ArrayRefTensor and at::Tensor is still fragile.
+        self.allow_stack_allocation = False
+
+        wrapped_args = []
+        for arg in args:
+            # We only really *need* borrow_arrayref_tensor_as_tensor for
+            # ArrayRefTensors. The code flowing into here uses `0` for nullptr, which
+            # borrow_arrayref_tensor_as_tensor would blindly coerce to int, so just
+            # avoid wrapping integers.  Name matching is to find tensor is hacky, but
+            # fixing all the ArrayRefTensor issues is not a priority for now.
+            if isinstance(arg, str) and arg.startswith(
+                ("buf", "arg", "wrap_with_raii_handle_if_needed")
+            ):
+                self._assert_safe_to_use_borrow_arrayref_tensor_as_tensor()
+                arg = f"borrow_arrayref_tensor_as_tensor({arg})"
+            wrapped_args.append(arg)
+
+        super().generate_c_shim_extern_kernel_call(
+            kernel, wrapped_args, device, debug_args=args
+        )
+
+    def generate_scatter_fallback(self, node: ir.ScatterFallback):
+        # No stack allocation when there is a fallback op
+        self.allow_stack_allocation = False
+        super().generate_scatter_fallback(node)
+
+    def _generate_scatter_fallback(
+        self,
+        output,
+        inputs,
+        cpp_kernel_name,
+        python_kernel_name,
+        src_is_tensor,
+        reduce,
+        kwargs,
+        device,
+    ):
+        reduce = self._get_scatter_reduce_enum(reduce)
+
+        # call the ABI shim function instead of the ATen one
+        self.add_device_include(device)
+        cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name, device)
+
+        # TODO: consider remove "_out" and add missing inplace variants to fallback_ops.py
+        cpp_kernel_name = cpp_kernel_name.replace("__", "_") + "_out"
+        self._assert_safe_to_use_borrow_arrayref_tensor_as_tensor()
+        inputs_wrapped = [
+            (f"borrow_arrayref_tensor_as_tensor({x})" if isinstance(x, str) else str(x))
+            for x in inputs
+        ]
+        line = f"{cpp_kernel_name}(borrow_arrayref_tensor_as_tensor({output}), {','.join(inputs_wrapped)}"
+
+        if python_kernel_name.startswith("aten.scatter_reduce"):
+            line += f", {','.join(kwargs)}"
+        else:
+            if src_is_tensor:
+                if reduce:
+                    line += f", {V.graph.wrapper_code.val_to_arg_str(reduce)}"
+            else:
+                assert reduce is None, (
+                    "Expect reduce to be None for aten.scatter_ with scalar src"
+                )
+        line += ");"
+        self.writeline(line)
+
+    def generate_index_put_fallback(self, node: ir.IndexPutFallback) -> None:
+        # No stack allocation when there is a fallback op
+        self.allow_stack_allocation = False
+        super().generate_index_put_fallback(node)
+
+    def _generate_index_put_fallback(self, kernel, x, indices, values, accumulate):
+        self._assert_safe_to_use_borrow_arrayref_tensor_as_tensor()
+        # TODO: update aoti_torch_index_put_out in ir.py to use autogen out version
+        # See the comment in codegen_reinterpret_view about why having something like
+        # RAIIAtenTensorHandle(tmp_tensor_handle_2) in a tmp array can cause the corresponding
+        # tensor prematurely deallocated, thus the temporary array trick here.
+        indices_str = self._generate_temporary_array_pointer(
+            "AtenTensorHandle",
+            [f"borrow_arrayref_tensor_as_tensor({i})" for i in indices],
+        )
+        args = [
+            f"borrow_arrayref_tensor_as_tensor({x})",
+            indices_str,
+            str(len(indices)),
+            f"borrow_arrayref_tensor_as_tensor({values})",
+            accumulate,
+        ]
+        args.insert(
+            0, f"borrow_arrayref_tensor_as_tensor({x})"
+        )  # set x as the output tensor, this fallback mutates x.
+        self.writeline(self.wrap_kernel_call(kernel, args))
+
+    def generate_fallback_kernel_with_runtime_lookup(
+        self,
+        buf_name: str,
+        python_kernel_name: str,
+        get_args: Callable[[], Sequence[str]],
+        op_overload: Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator],
+        raw_args: Sequence[Any],
+        outputs: Sequence[ir.Buffer],
+    ) -> None:
+        # No stack allocation when there is a fallback op
+        self.allow_stack_allocation = False
+        super().generate_fallback_kernel_with_runtime_lookup(
+            buf_name, python_kernel_name, get_args, op_overload, raw_args, outputs
+        )
+
+    def codegen_device_copy(self, src, dst, non_blocking: Union[bool, str]):
+        # aoti_torch_tensor_copy_ takes AtenTensorHandle as input,
+        # while stack-allocation results in ArrayRefTensor
+        # so disable stack allocation here
+        self.allow_stack_allocation = False
+        self.writeline(
+            f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_copy_(expensive_copy_to_tensor_if_needed({dst}), {src}, {non_blocking}));"
+        )
+
+    def codegen_reinterpret_view(
+        self,
+        data,
+        size,
+        stride,
+        offset,
+        writeline: Callable[..., None],
+        dtype=None,
+    ) -> str:
+        """Returns a newly-created, temporary RAII tensor handle containing the
+        reinterpreted tensor data.  Callers of this function are responsible for saving
+        the handle if persistent access is needed."""
+        dim = str(len(size))
+
+        def create_reinterpret_call() -> str:
+            args = [
+                f"{data.get_name()}",
+                dim,
+                self.codegen_int_array_var(
+                    self.codegen_shape_tuple(size),
+                    writeline,
+                    known_statically=self.is_statically_known_list_of_ints(size),
+                    graph=self.get_codegened_graph(),
+                ),
+                self.codegen_int_array_var(
+                    self.codegen_shape_tuple(stride),
+                    writeline,
+                    known_statically=self.is_statically_known_list_of_ints(stride),
+                    graph=self.get_codegened_graph(),
+                ),
+                offset,
+            ]
+            return f"wrap_with_raii_handle_if_needed(reinterpret_tensor_wrapper({', '.join(args)}))"
+
+        def create_new_tensor_handle() -> tuple[str, list[str]]:
+            # Calling reset() on ArrayRefTensor does nothing, since the array is
+            # const-allocated on the stack.  Thus, it's safe to return a reference to
+            # the original array.
+            if (name := data.get_name()) in self.stack_allocated_buffers:
+                return name, []
+
+            tmp_AtenTensorHandle = f"tmp_{name}_{next(self.tmp_tensor_id)}"
+            tmp_call_strs = [
+                f"AtenTensorHandle {tmp_AtenTensorHandle};",
+                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_tensor_handle({data.get_name()}, &{tmp_AtenTensorHandle}));",
+            ]
+            return f"RAIIAtenTensorHandle({tmp_AtenTensorHandle})", tmp_call_strs
+
+        if (
+            size == data.layout.size
+            and stride == data.layout.stride
+            and offset == data.layout.offset
+            and (dtype is None or dtype == data.dtype)
+        ):
+            final_tensor_str, call_strs = create_new_tensor_handle()
+            for line in call_strs:
+                writeline(line)
+            return final_tensor_str
+
+        return super().codegen_reinterpret_view(
+            data, size, stride, offset, writeline, dtype
+        )
+
+    def val_to_arg_str(self, val, type_=None) -> str:
+        if (
+            val is not None
+            and isinstance(type_, torch.OptionalType)
+            and isinstance(type_.getElementType(), torch.TensorType)
+        ):
+            # Handle optional tensors as a special case, as in the parent class.
+            base_handle = self.val_to_arg_str(val, torch.TensorType)
+            if config.aot_inductor.use_minimal_arrayref_interface:
+                if self.is_safe_to_use_borrow_arrayref_tensor_as_tensor():
+                    base_handle = f"borrow_arrayref_tensor_as_tensor({base_handle})"
+                else:
+                    base_handle = f"copy_arrayref_tensor_to_tensor({base_handle})"
+            return f"&temporary_reference({base_handle}.get())"
+
+        return super().val_to_arg_str(val, type_)
+
+    def codegen_tensor_item(
+        self, dtype: torch.dtype, tensor: str, scalar: str, indented_buffer=None
+    ):
+        dtype_str = str(dtype).split(".")[-1]
+        writer = indented_buffer or self
+
+        if dtype == torch.float16 or dtype == torch.bfloat16:
+            scalar_tmp = f"{scalar}_tmp"
+            writer.writeline(f"{DTYPE_TO_CPP[dtype]} {scalar_tmp};")
+
+            # We know that item_ doesn't alias the input, so borrowing should be safe.
+            tensor = f"borrow_arrayref_tensor_as_tensor({tensor})"
+
+            writer.writeline(
+                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar_tmp}));"
+            )
+            writer.writeline(f"float {scalar} = float({scalar_tmp});")
+        else:
+            writer.writeline(f"{DTYPE_TO_CPP[dtype]} {scalar};")
+
+            # We know that item_ doesn't alias the input, so borrowing should be safe.
+            tensor = f"borrow_arrayref_tensor_as_tensor({tensor})"
+
+            writer.writeline(
+                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar}));"
+            )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_gpu.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_gpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..42c082d9d92af7585c1d56dd35a1b79ba55f9ede
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_gpu.py
@@ -0,0 +1,891 @@
+# mypy: allow-untyped-defs
+from __future__ import annotations
+
+import dataclasses
+import re
+import sys
+from itertools import count, zip_longest
+from typing import Any, Optional, Union
+from typing_extensions import Self
+
+import sympy
+
+import torch
+from torch import dtype as torch_dtype
+from torch._inductor.codecache import get_cpp_wrapper_cubin_path_name
+from torch._inductor.runtime.runtime_utils import dynamo_timed
+
+from .. import config
+from ..codecache import CudaKernelParamCache
+from ..ir import (
+    GraphPartitionSignature,
+    TensorBox,
+    TMADescriptorExperimental,
+    TMADescriptorStable,
+)
+from ..utils import cache_on_self, get_gpu_type, GPU_ALIGN_BYTES, IndentedBuffer
+from ..virtualized import V
+from .aoti_hipify_utils import maybe_hipify_code_wrapper
+from .common import get_device_op_overrides, TritonScratchWorkspace
+from .cpp_utils import cexpr
+from .cpp_wrapper_cpu import CppWrapperCpu
+from .multi_kernel import MultiKernelCall
+from .triton_utils import should_unwrap_unspec_arg
+from .wrapper import PythonWrapperCodegen, SymbolicCallArg
+
+
+_cpp_string_literal_escapes = {
+    "\\": "\\\\",
+    '"': '\\"',
+    "\n": "\\n",
+    "\t": "\\t",
+    "\r": "\\r",
+}
+_cpp_string_literal_pattern = re.compile(r'["\\\n\t\r]')
+
+
+def cpp_string_literal(s: str) -> str:
+    escaped = _cpp_string_literal_pattern.sub(
+        lambda match: _cpp_string_literal_escapes[match.group(0)], s
+    )
+    return f'"{escaped}"'
+
+
+@dataclasses.dataclass
+class DeferredTritonCallWrapper:
+    """
+    When using cpp wrapper, GPU kernel load and launch needs to wait for Triton kernels
+    to be tuned and stored as cubin files, so use a deferred generating the final wrapper around
+    the triton kernel until right before the prefix is written.
+    """
+
+    wrapper_name: str
+    kernel_name: str
+    kernel_name_to_body: dict[str, str]
+    arg_types: list[Any]
+
+    def generate(self, wrapper: CppWrapperGpu):
+        """
+        Generate the GPU kernel definition, as well as load and launch code.
+        """
+        prefix = wrapper.prefix
+        if self.kernel_name.startswith("multi_kernel_"):
+            # MultiKernel will select one kernel after running the autotune block
+            self.kernel_name = MultiKernelCall.lookup_choice(self.kernel_name)
+        params = CudaKernelParamCache.get(self.kernel_name)
+        assert params, f"CudaKernelParamCache not populated for {self.kernel_name}"
+        def_args = params["def_args"]
+        arg_types = self.arg_types
+        inductor_meta = params["inductor_meta"]
+
+        if "extra_launcher_args" in inductor_meta and len(def_args) > len(arg_types):
+            # extra_launcher_args should already be in def_args
+            assert len(def_args) == len(arg_types) - len(
+                inductor_meta["extra_launcher_args"]
+            )
+            arg_types = arg_types + [SymbolicCallArg] * len(
+                inductor_meta["extra_launcher_args"]
+            )
+
+        if not V.graph.aot_mode:
+            prefix.writeline(
+                maybe_hipify_code_wrapper(
+                    f"static {wrapper.device_codegen.cpp_kernel_type()} {self.kernel_name} = nullptr;"
+                )
+            )
+            kernel_var_name = self.kernel_name
+        else:
+            kernel_var_name = f"kernels_.{self.kernel_name}"
+
+        # tensors can be RAIIAtenTensorHandle or ConstantHandle, so make them template types
+        template_types = [
+            f"typename {name}_type_"
+            for name, arg_type in zip(def_args, arg_types)
+            if isinstance(arg_type, (torch_dtype, UnwrapUnspecArg))
+        ]
+        if V.graph.aot_mode:
+            template_types.append("typename kernels_type_")
+        if template_types:
+            prefix.writeline(f"template <{', '.join(template_types)}>")
+        prefix.writeline(f"static inline void {self.wrapper_name}(")
+        with prefix.indent():
+            assert len(def_args) == len(arg_types), (def_args, arg_types)
+            for name, arg_type in zip(def_args, arg_types):
+                if isinstance(arg_type, (torch_dtype, UnwrapUnspecArg)):
+                    prefix.writeline(f"const {name}_type_& {name},")
+                elif issubclass(arg_type, (SymbolicCallArg, sympy.Expr, int)):
+                    prefix.writeline(f"int64_t {name},")
+                elif arg_type is float:
+                    prefix.writeline(f"float {name},")
+                elif arg_type is bool:
+                    prefix.writeline(f"bool {name},")
+                else:
+                    raise ValueError(f"Unexpected arg type {arg_type}")
+            prefix.writeline("int32_t device_idx_,")
+            prefix.writeline(
+                maybe_hipify_code_wrapper(
+                    f"{wrapper.device_codegen.cpp_stream_type()} stream_,"
+                )
+            )
+            if V.graph.aot_mode:
+                prefix.writeline("kernels_type_& kernels_,")
+            prefix.writeline(
+                "const std::optional& cubin_dir_ = std::nullopt"
+            )
+        prefix.writeline("){")
+        with prefix.indent():
+            if V.graph.aot_mode:
+                # Emit the original Triton kernel for debugging purposes
+                prefix.writeline("/*")
+                prefix.splice(self.kernel_name_to_body[self.kernel_name])
+                prefix.writeline("*/")
+            self.generate_grid(prefix, inductor_meta, params)
+            self.generate_load_kernel(prefix, kernel_var_name, params)
+            self.generate_launch_kernel(prefix, wrapper, kernel_var_name, params)
+        prefix.writeline("}")
+
+        if not config.aot_inductor.embed_kernel_binary:
+            # Ensure the cubin file is included in the package
+            V.graph.wrapper_code.additional_files.append(
+                params[get_cpp_wrapper_cubin_path_name()]
+            )
+
+    def generate_grid(
+        self,
+        prefix: IndentedBuffer,
+        inductor_meta: dict[str, Any],
+        params: dict[str, Any],
+    ):
+        from ..runtime.triton_heuristics import GridExpr
+
+        grid = GridExpr.from_meta(inductor_meta, params["config"], mode="cpp")
+        for line in grid.prefix:
+            prefix.writeline(line)
+        prefix.splice(
+            f"""\
+            uint32_t grid_0 = {grid.x_grid};
+            uint32_t grid_1 = {grid.y_grid};
+            uint32_t grid_2 = {grid.z_grid};
+            """
+        )
+        prefix.writeline("if (grid_0 == 0 || grid_1 == 0 || grid_2 == 0) return;")
+
+    def generate_load_kernel(self, prefix, kernel_var_name, params):
+        prefix.writeline(f"if ({kernel_var_name} == nullptr) {{")
+        with prefix.indent():
+            embed_kernel_args = [f"__{params['inductor_meta']['kernel_name']}_start"]
+            if torch.xpu.is_available():
+                # XPU needs the end address of the kernel to calculate the size of the kernel binary.
+                embed_kernel_args.append(
+                    f"__{params['inductor_meta']['kernel_name']}_end"
+                )
+
+            load_kernel_args = (
+                [
+                    *embed_kernel_args,
+                    cpp_string_literal(params["mangled_name"]),
+                    str(params["shared_mem"]),
+                ]
+                if V.graph.aot_mode and config.aot_inductor.embed_kernel_binary
+                else [
+                    cpp_string_literal(params[get_cpp_wrapper_cubin_path_name()]),
+                    cpp_string_literal(params["mangled_name"]),
+                    str(params["shared_mem"]),
+                    "cubin_dir_",
+                ]
+            )
+            prefix.writeline(
+                f"{kernel_var_name} = loadKernel({', '.join(load_kernel_args)}); "
+            )
+        prefix.writeline("}")
+
+    def generate_launch_kernel(self, prefix, wrapper, kernel_var_name, params):
+        """
+        Generate the GPU kernel launching code.
+        This is where all the call args being sorted out and generated.
+        If enable_kernel_profile is enabled, all args related information would be packed in this function.
+        """
+        triton_meta = params["triton_meta"]
+        assert len(self.arg_types) == len(params["def_args"]), (
+            self.arg_types,
+            params["def_args"],
+        )
+        arg_type_loookup = dict(zip(params["def_args"], self.arg_types))
+        # difference between Python and C++ wrapper: C++ wrapper strips out equal_to_1 constants
+        call_args = [
+            name for name in params["call_args"] if name not in triton_meta["constants"]
+        ]
+        arg_types = [arg_type_loookup[name] for name in call_args]
+        arg_signatures = [triton_meta["signature"][name] for name in call_args]
+        scratch_spaces = {
+            name: params[name]
+            for name in ["global_scratch", "profile_scratch"]
+            if params.get(name, None) is not None
+        }
+        call_args_str = wrapper.generate_args_decl(
+            prefix,
+            call_args,
+            arg_types,
+            arg_signatures,
+            scratch_spaces=scratch_spaces,
+        )
+        prefix.writeline(f"void* kernel_args_[] = {{{call_args_str}}};")
+        launch_kernel_args = [
+            kernel_var_name,
+            "grid_0",
+            "grid_1",
+            "grid_2",
+            str(params["num_warps"]),
+            str(params["shared_mem"]),
+            "kernel_args_",
+            "stream_",
+        ]
+        if wrapper.device == "xpu":
+            launch_kernel_args.append(str(params["threads_per_warp"]))
+
+        enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [
+            "linux",
+            "win32",
+        ]
+        if enable_kernel_profile:
+            normalized_kernel_name = re.sub(r"[^a-zA-Z0-9_]", "_", f"{kernel_var_name}")
+            prefix.writeline("{")
+            with prefix.indent():
+                prefix.writelines(
+                    [
+                        f"std::unordered_map kwargs_{normalized_kernel_name};",
+                        "",
+                    ]
+                )
+                # Add launch args info
+                record_launch_kernel_args = [
+                    ("grid_0", "grid_0"),
+                    ("grid_1", "grid_1"),
+                    ("grid_2", "grid_2"),
+                    ("num_warps", str(params["num_warps"])),
+                    ("shared_mem", str(params["shared_mem"])),
+                ]
+                for k, v in record_launch_kernel_args:
+                    arg_name = f"{normalized_kernel_name}_{k}"
+                    prefix.writelines(
+                        [
+                            f"// Create c10::IValue for {k}",
+                            f"C10IValueHandle tmp_{arg_name};",
+                            f"aoti_torch_int64_to_ivalue({v}, &tmp_{arg_name});",
+                            f"RAIIC10IValueHandle RAII_{arg_name}(tmp_{arg_name});",
+                            f'kwargs_{normalized_kernel_name}.emplace("{k}", RAII_{arg_name});',
+                        ]
+                    )
+
+                # Add input info (This copies the logic from args_decl)
+                signature2dtype = {
+                    "i32": "int32_t",
+                    "i64": "int64_t",
+                    "fp32": "float",
+                }
+
+                def signature_is_tma_desc(sig):
+                    if not sig:
+                        return False
+                    if sig == "nvTmaDesc":
+                        return True
+                    if sig.startswith("tensordesc<"):
+                        return True
+                    return False
+
+                curr_arg_id = -1
+                total_args = []
+                ordered_argsname = []
+
+                def write_dummy_scalar_ivalue(arg_name):
+                    # We only care about the shape, therefore we create a dummy scalar here.
+                    prefix.writelines(
+                        [
+                            f"// Create c10::IValue for arg_{curr_arg_id}",
+                            f"C10IValueHandle tmp_{arg_name};",
+                            f"aoti_torch_int64_to_ivalue(0, &tmp_{arg_name});",
+                            f"RAIIC10IValueHandle RAII_{arg_name}(tmp_{arg_name});",
+                        ]
+                    )
+                    # pyrefly: ignore [bad-argument-type]
+                    total_args.append(f"tmp_{arg_name}")
+
+                def process_args_for_input_shape(arg, arg_type, arg_signature=None):
+                    nonlocal curr_arg_id
+                    curr_arg_id += 1
+                    arg_name = f"{normalized_kernel_name}_arg_{curr_arg_id}"
+                    # ignore tma descriptors, as host-side TMA descriptors need
+                    # to be passed to the compiled Triton kernel by value
+                    if isinstance(
+                        arg_type, UnwrapUnspecArg
+                    ) and not signature_is_tma_desc(arg_signature):
+                        write_dummy_scalar_ivalue(arg_name)
+                    elif isinstance(
+                        arg_type, torch_dtype
+                    ) and not signature_is_tma_desc(arg_signature):
+                        # This is an at::Tensor.
+                        prefix.writelines(
+                            [
+                                f"// Create c10::IValue for arg_{curr_arg_id}",
+                                f"C10IValueHandle tmp_{arg_name};",
+                                f"aoti_torch_tensor_to_ivalue({arg}, &tmp_{arg_name});",
+                                f"RAIIC10IValueHandle RAII_{arg_name}(tmp_{arg_name});",
+                            ]
+                        )
+                        # pyrefly: ignore [bad-argument-type]
+                        total_args.append(f"tmp_{arg_name}")
+                    elif (
+                        isinstance(arg_type, type(SymbolicCallArg))
+                        and arg_signature is not None
+                        and arg_signature in signature2dtype
+                    ) or arg_type in (sympy.Integer, int, sympy.Float, float):
+                        write_dummy_scalar_ivalue(arg_name)
+                    elif arg_signature and arg_signature.startswith("tensordesc<"):
+                        # Skip tma related args
+                        pass
+                    else:
+                        write_dummy_scalar_ivalue(arg_name)
+
+                # Add input name and shape information
+                for arg, arg_type, arg_signature in zip_longest(
+                    call_args, arg_types, arg_signatures
+                ):
+                    # pyrefly: ignore [bad-argument-type]
+                    ordered_argsname.append(f'"{arg}"')
+                    process_args_for_input_shape(arg, arg_type, arg_signature)
+
+                # Add input name into kwargs
+                name_var = f"{normalized_kernel_name}_input_names"
+                prefix.writelines(
+                    [
+                        "// Create c10::IValue for input names",
+                        f"C10IValueHandle tmp_{name_var};",
+                        f"std::vector {name_var}({{{', '.join(ordered_argsname)}}});",
+                        f"aoti_torch_strlist_to_ivalue({name_var}.data(), {len(ordered_argsname)}, &tmp_{name_var});",
+                        f"RAIIC10IValueHandle RAII_{name_var}(tmp_{name_var});",
+                        f'kwargs_{normalized_kernel_name}.emplace("Input Args", RAII_{name_var});',
+                    ]
+                )
+
+                inputs_info_ = f"{normalized_kernel_name}_inputs_info_"
+                # We pass in the non-RAII handles, since C10 doesn't automatically free them.
+                # The RAII will make sure they get freed when they are out of scope.
+                tmp_args = ",".join(total_args)
+                prefix.writelines(
+                    [
+                        "// Aggregate all c10::IValue for inputs",
+                        f"std::vector {inputs_info_}({{{tmp_args}}});",
+                    ]
+                )
+
+                # Start recording Function
+                prefix.writelines(
+                    [
+                        "",
+                        (
+                            "torch::aot_inductor::RAIIAtenRecordFunctionHandle "
+                            f"record_{normalized_kernel_name}_"
+                            f'("{kernel_var_name}", '
+                            f"reinterpret_cast(&kwargs_{normalized_kernel_name}), "
+                            f"{inputs_info_});"
+                        ),
+                        "",
+                        f"launchKernel({', '.join(launch_kernel_args)});",
+                    ]
+                )
+            prefix.writeline("}")
+        else:
+            prefix.writeline(f"launchKernel({', '.join(launch_kernel_args)});")
+
+
+class CppWrapperGpu(CppWrapperCpu):
+    """
+    Generates cpp wrapper for running on GPU and calls CUDA kernels
+    """
+
+    def __init__(self) -> None:
+        self.device = get_gpu_type()
+        self.device_codegen = get_device_op_overrides(self.device)
+        super().__init__()
+        self.grid_id = count()
+        self._kernel_name_to_body: dict[str, str] = {}
+        self._triton_call_wrappers: dict[str, DeferredTritonCallWrapper] = {}
+        self.autotune_input_prefix = "_REAL_AUTOTUNE_INPUT"
+
+    @staticmethod
+    def create(
+        is_subgraph: bool,
+        subgraph_name: Optional[str],
+        parent_wrapper: Optional[PythonWrapperCodegen],
+        partition_signatures: Optional[GraphPartitionSignature] = None,
+    ):
+        # TODO - support subgraph codegen by lifting functions. Check the
+        # comment at CppWrapperCpu `codegen_subgraph` function.
+        return CppWrapperGpu()
+
+    def write_header(self):
+        if V.graph.is_const_graph:
+            # We do not write header for constant graph, it will be written by main module.
+            return
+
+        super().write_header()
+        self.header.splice(
+            maybe_hipify_code_wrapper(self.device_codegen.kernel_driver())
+        )
+
+    @cache_on_self
+    def write_tma_descriptor_helpers_once(self):
+        self.header.splice(self.device_codegen.tma_descriptor_helpers())
+
+    def write_get_raw_stream(self, device_idx: int, graph_name: str) -> str:
+        name = f"stream{device_idx}"
+        self.writeline(
+            maybe_hipify_code_wrapper(
+                f"{self.device_codegen.cpp_stream_type()} {name};"
+            )
+        )
+        self.writeline(
+            f"AOTI_TORCH_ERROR_CODE_CHECK({self.device_codegen.aoti_get_stream()}({device_idx}, (void**)&{name}));"
+        )
+        return name
+
+    def get_autotuning_input_name(self, idx):
+        return f"{self.autotune_input_prefix}_{idx}"
+
+    def codegen_inputs(self):
+        # See Note: [Input Alignment handling in Inductor]
+        #
+        # JIT Inductor does not guard on input alignment. It relies on copy_misaligned_inputs to
+        # copy misaligned inputs to aligned buffers. For AOTInductor, we need to do the same in cpp.
+
+        if config.is_fbcode():
+            # TODO: This is added because FC. Remove this once the newly added shim symbols,
+            # e.g. aoti_torch_clone_preserve_strides, have landed
+            return super().codegen_inputs()
+
+        if V.graph.aot_mode and V.graph.inputs_to_check:
+            for idx in V.graph.inputs_to_check:
+                input_name = V.graph.graph_input_names[idx]
+                assert input_name in V.graph.graph_inputs, (
+                    f"{input_name} not found in graph inputs"
+                )
+                value = V.graph.graph_inputs[input_name]
+                assert isinstance(value, TensorBox), (
+                    f"{input_name} is expected to be tensor but found as {type(value)}"
+                )
+                warn_msg = (
+                    f"Input {idx} was compiled as {GPU_ALIGN_BYTES}-bytes aligned, "
+                    "but it is not aligned at run time. Copying to an aligned tensor "
+                    "to guarantee correctness, but expect a performance hit."
+                )
+                self.prefix.splice(
+                    f"""
+                    if ((reinterpret_cast({input_name}.data_ptr()) & ({GPU_ALIGN_BYTES} -1)) != 0) {{
+                        AOTI_TORCH_WARN("{warn_msg}");
+                        AtenTensorHandle {input_name}_aligned;
+                        aoti_torch_clone_preserve_strides({input_name}, &{input_name}_aligned);
+                        {input_name} = std::move(RAIIAtenTensorHandle({input_name}_aligned));
+                    }}
+                    """
+                )
+
+        super().codegen_inputs()
+
+    def _define_kernel_helper(
+        self,
+        kernel_name: str,
+        kernel_body: str,
+        metadata: Optional[str] = None,
+        gpu: bool = True,
+        cpp_definition: Optional[str] = None,
+    ):
+        if gpu:
+            self._kernel_name_to_body[kernel_name] = kernel_body
+            if config.triton.autotune_at_compile_time:
+                # Call PythonWrapperCodegen to create the autotune code block
+                PythonWrapperCodegen._define_kernel_helper(
+                    self, kernel_name, kernel_body, metadata, gpu, cpp_definition
+                )
+        else:
+            return CppWrapperCpu._define_kernel_helper(
+                self, kernel_name, kernel_body, metadata, gpu, cpp_definition
+            )
+
+    def generate(self, is_inference):
+        with dynamo_timed("CppWrapperGpu.generate", log_pt2_compile_event=True):
+            return super().generate(is_inference)
+
+    def finalize_prefix(self):
+        """Define the triton kernels now that autotuning is finished"""
+        old_prefix = self.prefix  # new content should go at start of prefix
+
+        # Generating triton kernel callers can modify the prefix (cached dtypes),
+        # so do this before running finalize_prefix(), but put the generated code
+        # after the finalize_prefix() code.
+        self.prefix = IndentedBuffer()
+        for kernel in self._triton_call_wrappers.values():
+            self.prefix.writeline("\n")
+            kernel.generate(self)
+        triton_prefix = self.prefix
+
+        self.prefix = IndentedBuffer()
+        super().finalize_prefix()
+
+        self.prefix.splice(triton_prefix)
+
+        self.prefix.writeline("\n")
+        self.prefix.splice(old_prefix)
+
+    def generate_tma_descriptor(self, desc):
+        self.write_tma_descriptor_helpers_once()
+
+        if isinstance(desc, TMADescriptorExperimental):
+            self._generate_experimental_tma_descriptor(desc)
+        else:
+            assert isinstance(desc, TMADescriptorStable)
+            self._generate_stable_tma_descriptor(desc)
+
+    def _generate_experimental_tma_descriptor(self, desc):
+        # generate data pointer for the source tensor
+        source = self.generate_args_decl(
+            code=self,
+            call_args=[self.val_to_arg_str(desc.tensor)],
+            arg_types=[desc.tensor.get_dtype()],
+            arg_signatures=[None],
+            # these args are passed to initNDTMADescriptor, which is NOT a triton kernel
+            is_triton_kernel=False,
+        )
+
+        desc_name = desc.name
+        self.writeline(f"alignas(64) CUtensorMap {desc_name};")
+
+        # `source` is in the form of `&var_x`, where `var_x` is the data pointer
+        # (CUdeviceptr); we dereference `source` and cast to `void*` to pass to
+        # the data pointer of the source tensor to the helper function
+        # `init{1,2}DTMADescriptor`
+        ptr = f"reinterpret_cast(*({source}))"
+        dims = ", ".join(self.val_to_arg_str(dim) for dim in desc.dims)
+        block_dims = ", ".join(self.val_to_arg_str(dim) for dim in desc.block_dims)
+        element_size = self.val_to_arg_str(desc.element_size)
+        fn = f"init{desc.rank}DTMADescriptor"
+        args = f"&{desc_name}, {ptr}, {dims}, {block_dims}, {element_size}"
+        self.writeline(f"{fn}({args});")
+
+    def _generate_stable_tma_descriptor(self, desc):
+        source = self.generate_args_decl(
+            code=self,
+            call_args=[self.val_to_arg_str(desc.tensor)],
+            arg_types=[desc.tensor.get_dtype()],
+            arg_signatures=[None],
+            # these args are passed to initNDTMADescriptor, which is NOT a triton kernel
+            is_triton_kernel=False,
+        )
+
+        desc_name = desc.name
+        # Pack the relevant information into a StableTMADescriptor struct.
+        # See [Note: AOTI TMA Stable handling] for more details.
+        self.writeline(f"alignas(64) StableTMADescriptor {desc_name};")
+
+        def fill_array(name, values):
+            for i, val in enumerate(values):
+                self.writeline(f"{name}[{i}] = {val};")
+
+        ptr = f"reinterpret_cast(*({source}))"
+        rank = len(desc.tensor.get_size())
+
+        fill_array(f"{desc_name}.block_shape", desc.block_shape)
+        fill_array(f"{desc_name}.global_shape", desc.tensor.get_size())
+        fill_array(f"{desc_name}.strides", desc.tensor.get_stride())
+
+        element_size = self.val_to_arg_str(desc.tensor.get_dtype().itemsize)
+        fn = "initTMADescriptor"
+        args = ", ".join(
+            str(x)
+            for x in [
+                f"&{desc_name}.m",
+                ptr,
+                element_size,
+                rank,
+                f"{desc_name}.block_shape",
+                f"{desc_name}.global_shape",
+                f"{desc_name}.strides",
+            ]
+        )
+        self.writeline(f"{fn}({args});")
+
+    def generate_args_decl(
+        self,
+        code: Union[IndentedBuffer, Self],
+        call_args,
+        arg_types,
+        arg_signatures,
+        is_triton_kernel=True,
+        scratch_spaces: Optional[dict[str, int]] = None,
+    ):
+        """
+        Generates any declarations of args to pass into a kernel call, and then returns the arg names.
+
+        In more detail:
+        * declarations: e.g. this function has a side effect of generating lines like `auto var_0 = ...;`
+        * returns: a string with the list of args, e.g. "var_0, var_1"
+
+        call_args: list of call arguments
+        arg_types: list of argument types
+        arg_signatures: list with signatures of all the args
+        is_triton_kernel: whether these are passed into a triton kernel or not. In particular,
+                          calls to triton kernels will have an additional global scratch space
+                          arg injected at the front of the arg list.
+        """
+        new_args: list[str] = []
+
+        # Add more cases for other types as needed
+        signature2dtype = {
+            "i32": "int32_t",
+            "i64": "int64_t",
+            "fp32": "float",
+        }
+
+        def signature_is_tma_desc(sig):
+            if not sig:
+                return False
+            if sig == "nvTmaDesc":
+                return True
+            if sig.startswith("tensordesc<"):
+                return True
+            return False
+
+        def process_tma_stable_arg(arg, arg_type, arg_signature, var_name):
+            # [Note: AOTI TMA Stable handling]
+            # For most args, a single arg passed to the python triton interface
+            # maps to a single arg in the cubin interface. However, for host-side
+            # TMA descriptors, a single python arg turns into 1 + 2 * N args in the
+            # cubin interface (where N is the rank).
+            #
+            # To do this: at TMA codegen time (for aoti), we generate a struct
+            # (StableTMADescriptor) containing the necessary information; and then
+            # when we call the function (i.e. here), we unpack the struct members.
+            code.writeline(f"auto {var_name} = {cexpr(arg)};")
+
+            result = []
+            result.append(f"&{var_name}.m")
+
+            # from https://github.com/triton-lang/triton/blob/16961b79bdac1b774b42d44e52fd55a266ec2866/third_party/nvidia/backend/driver.py#L111  # noqa: B950
+            match = re.match("tensordesc<([^[>]*)\\[([^]]*)\\]", arg_signature)
+            assert match is not None
+            shape = match.group(2)
+            ndim = shape.count(",") + 1
+
+            for i in range(ndim):
+                result.append(f"&{var_name}.block_shape[{i}]")
+
+            for i in range(ndim):
+                result.append(f"&{var_name}.strides[{i}]")
+
+            return result
+
+        def process_args(arg, arg_type, arg_signature=None):
+            var_name = f"var_{next(self.arg_var_id)}"
+            # ignore tma descriptors, as host-side TMA descriptors need
+            # to be passed to the compiled Triton kernel by value
+            if isinstance(arg_type, UnwrapUnspecArg) and not signature_is_tma_desc(
+                arg_signature
+            ):
+                self.codegen_tensor_item(
+                    arg_type.dtype,
+                    arg,
+                    var_name,
+                    indented_buffer=code,
+                )
+                new_args.append(f"&{var_name}")
+            elif isinstance(arg_type, torch_dtype) and not signature_is_tma_desc(
+                arg_signature
+            ):
+                device_ptr_type = self.device_codegen.cpp_device_ptr()
+                code.writeline(
+                    maybe_hipify_code_wrapper(
+                        f"{device_ptr_type} {var_name} = reinterpret_cast<{device_ptr_type}>({arg}.data_ptr());"
+                    )
+                )
+                new_args.append(f"&{var_name}")
+            # For symbolic call arguments, examine the arg signatures from triton meta
+            # to explicitly cast to the right type
+            # Reason: `auto` can infer unexpected type against kernel input signature.
+            elif (
+                isinstance(arg_type, type(SymbolicCallArg))
+                and arg_signature is not None
+                and arg_signature in signature2dtype
+            ):
+                code.writeline(
+                    f"{signature2dtype[arg_signature]} {var_name} = {cexpr(arg)};"
+                )
+                new_args.append(f"&{var_name}")
+            elif arg_type in (sympy.Integer, int):
+                code.writeline(f"int {var_name} = {cexpr(arg)};")
+                new_args.append(f"&{var_name}")
+            elif arg_type in (sympy.Float, float):
+                code.writeline(f"float {var_name} = {cexpr(arg)};")
+                new_args.append(f"&{var_name}")
+            elif arg_signature and arg_signature.startswith("tensordesc<"):
+                new_args.extend(
+                    process_tma_stable_arg(arg, arg_type, arg_signature, var_name)
+                )
+            else:
+                code.writeline(f"auto {var_name} = {cexpr(arg)};")
+                new_args.append(f"&{var_name}")
+
+        for arg, arg_type, arg_signature in zip_longest(
+            call_args, arg_types, arg_signatures
+        ):
+            process_args(arg, arg_type, arg_signature)
+
+        for scratch_name, workspace_size in (scratch_spaces or {}).items():
+            if (
+                is_triton_kernel
+                and (
+                    scratch := self.device_codegen.cpp_scratch(
+                        next(self.arg_var_id),
+                        workspace=TritonScratchWorkspace(
+                            size=workspace_size,
+                            generate_dtype_str=(
+                                lambda: self.codegen_dtype(torch.uint8)
+                            ),
+                        ),
+                        prefix=scratch_name,
+                    )
+                )
+                is not None
+            ):
+                scratch_def, scratch_var = scratch
+                code.writelines([maybe_hipify_code_wrapper(x) for x in scratch_def])
+                new_args.append(f"&{scratch_var}")
+
+        return ", ".join(new_args)
+
+    def _generate_kernel_call_helper(
+        self,
+        kernel_name: str,
+        call_args,
+        *,
+        device=None,
+        triton=True,
+        arg_types=None,
+        raw_keys=None,
+        raw_args=None,
+        triton_meta=None,
+        graph_name="",
+        original_fxnode_name=None,
+    ):
+        """
+        Override the default value of argument 'gpu' to True here.
+        generate_kernel_call can still be called with gpu=False because of
+        a mix of cpu kernels and gpu kernels.
+        """
+        device = device or V.graph.get_current_device_or_throw()
+        if device.type == "cpu":
+            # Even in CppWrapperGpu, we may see cpp kernels
+            return CppWrapperCpu._generate_kernel_call_helper(
+                self,
+                kernel_name,
+                call_args,
+                device=device,
+                triton=triton,
+                arg_types=arg_types,
+                raw_keys=raw_keys,
+                raw_args=raw_args,
+                triton_meta=triton_meta,
+            )
+
+        if (
+            triton
+            and config.triton.autotune_at_compile_time
+            and kernel_name not in self.kernel_autotune_names
+        ):
+            # Call PythonWrapperCodegen to create the autotune code block
+            PythonWrapperCodegen._generate_kernel_call_helper(
+                self,
+                kernel_name,
+                call_args,
+                device=device,
+                triton=triton,
+                arg_types=arg_types,
+                raw_keys=raw_keys,
+                raw_args=raw_args,
+                triton_meta=triton_meta,
+                original_fxnode_name=original_fxnode_name,
+            )
+
+        stream = (
+            "stream"
+            if V.graph.aot_mode
+            else self.write_get_raw_stream(device.index, graph_name)
+        )
+
+        if triton:
+            call_args, arg_types = self.prepare_triton_wrapper_args(
+                call_args,
+                # pyrefly: ignore [bad-argument-type]
+                arg_types,
+            )
+            wrapper_name = f"call_{kernel_name}"
+            if wrapper_name not in self._triton_call_wrappers:
+                self._triton_call_wrappers[wrapper_name] = DeferredTritonCallWrapper(
+                    wrapper_name,
+                    kernel_name,
+                    self._kernel_name_to_body,
+                    arg_types,
+                )
+            device_idx = "this->device_idx_" if V.graph.aot_mode else str(device.index)
+            call_args.append(device_idx)
+            call_args.append(stream)
+            if V.graph.aot_mode:
+                call_args.append("kernels")
+                call_args.append("this->cubin_dir_")
+            debug_printer_manager = V.graph.wrapper_code.debug_printer
+            debug_printer_manager.set_printer_args(
+                call_args[: len(arg_types)], kernel_name, arg_types, None
+            )
+            with debug_printer_manager:
+                self.writeline(f"{wrapper_name}({', '.join(call_args)});")
+        else:
+            casted = []
+            # pyrefly: ignore [no-matching-overload]
+            for arg_type, arg in zip(arg_types, call_args):
+                new_arg = arg
+                if arg_type.endswith("*") and arg != "nullptr":
+                    new_arg = f"{arg}.data_ptr()"
+                # pyrefly: ignore [bad-argument-type]
+                casted.append(f"({arg_type}){cexpr(new_arg)}")
+            call_args_str = ", ".join(casted)
+            self.writeline(f"kernels.{kernel_name}({call_args_str}, {stream});")
+
+    @staticmethod
+    def prepare_triton_wrapper_args(
+        call_args: list[Any], arg_types: list[Any]
+    ) -> tuple[list[Any], list[Any]]:
+        assert len(call_args) == len(arg_types), (call_args, arg_types)
+        new_args = []
+        new_args_types = []
+        for arg, arg_type in zip(call_args, arg_types):
+            if isinstance(arg, str):
+                if isinstance(arg_type, torch_dtype) and should_unwrap_unspec_arg(arg):
+                    # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar
+                    arg_type = UnwrapUnspecArg(dtype=arg_type)
+                new_args.append(arg)
+            elif isinstance(arg, bool):
+                new_args.append(str(arg).lower())
+            elif isinstance(arg, (int, float, SymbolicCallArg)):
+                new_args.append(str(arg))
+            else:
+                new_args.append(cexpr(V.graph.sizevars.simplify(arg)))
+            new_args_types.append(arg_type)
+        return new_args, new_args_types
+
+    def make_zero_buffer(self, name):
+        return f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_zero_({name}.get()));"
+
+
+@dataclasses.dataclass
+class UnwrapUnspecArg:
+    """Marker that we need to call .item() on the tensor"""
+
+    dtype: torch_dtype
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_mps.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_mps.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a5638f37b7856927f612a37c867c46c3e76785e
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_mps.py
@@ -0,0 +1,301 @@
+from typing import Any, Optional
+
+import sympy
+
+import torch
+from torch.utils._ordered_set import OrderedSet
+
+from ..ir import GraphPartitionSignature
+from ..virtualized import V
+from .cpp_wrapper_cpu import CppWrapperCpu
+from .cpp_wrapper_gpu import CppWrapperGpu
+from .wrapper import KernelCallLine, PythonWrapperCodegen
+
+
+class CppWrapperMps(CppWrapperGpu):
+    """
+    Generates cpp wrapper for running on MPS and calls metal kernels
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        self._used_kernel_names: OrderedSet[str] = OrderedSet()
+        self._lambda_counter: int = 0
+
+    @staticmethod
+    def create(
+        is_subgraph: bool,
+        subgraph_name: Optional[str],
+        parent_wrapper: Optional[PythonWrapperCodegen],
+        partition_signatures: Optional[GraphPartitionSignature] = None,
+    ) -> "CppWrapperMps":
+        return CppWrapperMps()
+
+    def _generate_kernel_call_helper(
+        self,
+        kernel_name: str,
+        call_args: list[str],
+        *,
+        device: Optional[torch.device] = None,
+        triton: bool = True,
+        arg_types: Optional[tuple[Any, ...]] = None,
+        raw_keys: Optional[tuple[Any, ...]] = None,
+        raw_args: Optional[tuple[Any, ...]] = None,
+        triton_meta: Optional[dict[str, Any]] = None,
+        graph_name: str = "",
+        original_fxnode_name: Optional[str] = None,
+    ) -> None:
+        """
+        Generates MPS kernel call code. It should look something like:
+        ```
+        auto mps_lib_0_lambda = [&](AOTIMetalKernelFunctionHandle handle) {
+            aoti_torch_mps_start_encoding(handle);
+            aoti_torch_mps_set_arg_tensor(handle, 0, buf0);
+            aoti_torch_mps_set_arg_tensor(handle, 1, arg0_1);
+            aoti_torch_mps_set_arg_tensor(handle, 2, arg1_1);
+            aoti_torch_mps_dispatch_single(handle, static_cast(10LL));
+        };
+
+        std::function mps_lib_0_func_wrapper = mps_lib_0_lambda;
+        aoti_torch_mps_run_command_block(get_mps_lib_0_handle(), aoti_torch_mps_shared_callback, &mps_lib_0_func_wrapper);
+        ```
+        """
+        device = device or V.graph.get_current_device_or_throw()
+        if device.type == "cpu":
+            # Even in CppWrapperGpu, we may see cpp kernels
+            return CppWrapperCpu._generate_kernel_call_helper(
+                self,
+                kernel_name,
+                call_args,
+                device=device,
+                triton=triton,
+                arg_types=arg_types,
+                raw_keys=raw_keys,
+                raw_args=raw_args,
+                triton_meta=triton_meta,
+            )
+
+        assert device.type == "mps"
+
+        assert arg_types is not None
+
+        new_args = []
+        for idx, (arg, arg_type) in enumerate(zip(call_args[:-2], arg_types[:-2])):
+            if isinstance(arg_type, torch.dtype):
+                new_args.append(f"aoti_torch_mps_set_arg_tensor(handle, {idx}, {arg});")
+            elif arg_type in (int, sympy.core.symbol.Symbol):
+                new_args.append(f"aoti_torch_mps_set_arg_int(handle, {idx}, {arg});")
+            else:
+                raise NotImplementedError(
+                    f"Unsupported arg type {arg_type} for arg {arg} for kernel {kernel_name}"
+                )
+
+        threads, group_size = call_args[-2], call_args[-1]
+        if threads is None:
+            raise NotImplementedError("No threads or group_size provided")
+
+        # Check if threads is a single value or an array-like structure
+        threads_str = str(threads)
+        is_single_value = (
+            threads_str.startswith("{")
+            and threads_str.endswith("}")
+            and threads_str.count(",") == 0
+        ) or not threads_str.startswith(("{", "["))
+
+        if is_single_value:
+            # Extract single value from braces if present
+            if threads_str.startswith("{") and threads_str.endswith("}"):
+                single_value = threads_str[1:-1].strip()  # Remove braces
+            else:
+                single_value = threads_str
+
+            if group_size is None:
+                new_args.append(
+                    f"aoti_torch_mps_dispatch_single(handle, {single_value});"
+                )
+            else:
+                # Extract group size value if it's also in braces
+                group_size_str = str(group_size)
+                if group_size_str.startswith("{") and group_size_str.endswith("}"):
+                    group_size_value = group_size_str[1:-1].strip()
+                else:
+                    group_size_value = group_size_str
+                new_args.append(
+                    f"aoti_torch_mps_dispatch_single_with_group_size(handle, {single_value}, {group_size_value});"
+                )
+        else:
+            # Handle array case - need to convert initializer list to array
+            # Use kernel name to make variable names unique
+            threads_var = f"{kernel_name}_threads_array"
+            group_size_var = f"{kernel_name}_group_size_array"
+
+            # Extract array size from the initializer list string
+            def get_array_size(array_str: str) -> int:
+                # Remove braces and whitespace
+                content = array_str.strip()
+                if content.startswith("{") and content.endswith("}"):
+                    content = content[1:-1].strip()
+
+                if not content:  # Empty array
+                    return 0
+
+                # Count elements by counting commas, accounting for nested structures
+                depth = 0
+                comma_count = 0
+                for char in content:
+                    if char in "({[<":
+                        depth += 1
+                    elif char in ")}]>":
+                        depth -= 1
+                    elif char == "," and depth == 0:
+                        comma_count += 1
+
+                return comma_count + 1  # Number of elements = commas + 1
+
+            threads_size = get_array_size(threads_str)
+
+            if group_size is None:
+                new_args.append("{")
+                new_args.append(f"    uint64_t {threads_var}[] = {threads};")
+                new_args.append(
+                    f"    aoti_torch_mps_dispatch_array(handle, {threads_var}, {threads_size});"
+                )
+                new_args.append("}")
+            else:
+                group_size_str = str(group_size)
+                group_size_size = get_array_size(group_size_str)
+                new_args.append("{")
+                new_args.append(f"    uint64_t {threads_var}[] = {threads};")
+                new_args.append(f"    uint64_t {group_size_var}[] = {group_size};")
+                dispatch_args = f"handle, {threads_var}, {threads_size}, {group_size_var}, {group_size_size}"
+                new_args.append(
+                    f"    aoti_torch_mps_dispatch_array_with_group_size({dispatch_args});"
+                )
+                new_args.append("}")
+
+        # debug printer related logic for cpp kernel type.
+        debug_printer_manager = V.graph.wrapper_code.debug_printer
+        debug_printer_manager.set_printer_args(
+            call_args[:-2],
+            kernel_name,
+            None,
+            None,
+            "cpp",
+        )
+        with debug_printer_manager:
+            self.write_mps_kernel_call(kernel_name, new_args)
+
+    def write_mps_kernel_call(self, name: str, call_args: list[str]) -> None:
+        # Generate unique variable names to avoid duplicate declarations
+        # when the same MPS lib is used multiple times
+        unique_suffix = self._lambda_counter
+        self._lambda_counter += 1
+
+        lambda_name = f"{name}_lambda_{unique_suffix}"
+        wrapper_name = f"{name}_func_wrapper_{unique_suffix}"
+
+        # Generate the function call code (in current location)
+        # Create lambda that captures by reference and pass its pointer through void*
+        self.writeline(
+            f"auto {lambda_name} = [&](AOTIMetalKernelFunctionHandle handle) {{"
+        )
+        self.writeline("    aoti_torch_mps_start_encoding(handle);")
+
+        # Output call args directly since we're capturing by reference
+        for call_arg in call_args:
+            self.writeline(f"    {call_arg}")
+        self.writeline("};")
+        self.writeline("")
+
+        # Pass lambda pointer through void*
+        self.writeline(
+            f"std::function {wrapper_name} = {lambda_name};"
+        )
+        self.writeline(
+            f"aoti_torch_mps_run_command_block(get_{name}_handle(), aoti_torch_mps_shared_callback, &{wrapper_name});"
+        )
+
+    @staticmethod
+    def get_device_include_path(device: str) -> str:
+        assert V.graph.aot_mode
+        return (
+            "#include \n"
+            "#include "
+        )
+
+    def codegen_additional_funcs(self) -> None:
+        """
+        Generate thread-safe lazy singleton pattern for MPS shader libraries with RAII cleanup.
+
+        The generated code will look like:
+        ```
+        AOTIMetalKernelFunctionHandle get_mps_lib_0_handle() {
+            static auto kernel_handle = []() {
+                AOTIMetalShaderLibraryHandle lib_handle = nullptr;
+                AOTIMetalKernelFunctionHandle kern_handle = nullptr;
+
+                aoti_torch_mps_create_shader_library(mps_lib_0_source, &lib_handle);
+                aoti_torch_mps_get_kernel_function(lib_handle, "generated_kernel", &kern_handle);
+
+                // RAII wrapper with custom deleter
+                auto lib_deleter = [](AOTIMetalShaderLibraryHandle h) {
+                    if (h) aoti_torch_mps_delete_shader_library(h);
+                };
+
+                using LibDeleter = decltype(lib_deleter);
+                using LibPtr = std::unique_ptr;
+
+                // Return pair of kernel handle and library smart pointer for cleanup
+                return std::make_pair(kern_handle, LibPtr(lib_handle, lib_deleter));
+            }();
+            return kernel_handle.first;
+        }
+        ```
+        """
+
+        # Add shimified handles and functions
+        shader_libraries: OrderedSet[str] = OrderedSet()
+        for line in self.lines:
+            if not isinstance(line, KernelCallLine):
+                continue
+            if line.device.type != "mps":
+                continue
+
+            # Extract library name from kernel name (e.g., "mps_lib_0" from kernel calls)
+            if line.kernel_name not in self._used_kernel_names:
+                self._used_kernel_names.add(line.kernel_name)
+                shader_libraries.add(line.kernel_name)
+
+        # NOTE: For shimified version, we expect the shader source constant to be generated
+        # by the existing MPS shader generation process, but instead of instantiating the
+        # DynamicMetalShaderLibrary directly, we'll use our shim functions.
+        # The existing codegen should produce something like:
+        # const char* mps_lib_0_source = R"MTL(...shader_source...)MTL";
+        # instead of:
+        # at::native::mps::DynamicMetalShaderLibrary mps_lib_0(R"MTL(...shader_source...)MTL");
+
+        # Generate thread-safe lazy singleton with RAII for each library
+        for lib_name in shader_libraries:
+            self.prefix.splice(f"""
+AOTIMetalKernelFunctionHandle get_{lib_name}_handle() {{
+    static auto kernel_handle = []() {{
+        AOTIMetalShaderLibraryHandle lib_handle = nullptr;
+        AOTIMetalKernelFunctionHandle kern_handle = nullptr;
+
+        aoti_torch_mps_create_shader_library({lib_name}_source, &lib_handle);
+        aoti_torch_mps_get_kernel_function(lib_handle, "generated_kernel", &kern_handle);
+
+        // RAII wrapper with custom deleter
+        auto lib_deleter = [](AOTIMetalShaderLibraryHandle h) {{
+            if (h) aoti_torch_mps_delete_shader_library(h);
+        }};
+
+        using LibDeleter = decltype(lib_deleter);
+        using LibPtr = std::unique_ptr;
+
+        // Return pair of kernel handle and library smart pointer for cleanup
+        return std::make_pair(kern_handle, LibPtr(lib_handle, lib_deleter));
+    }}();
+    return kernel_handle.first;
+}}
+""")
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpu_device_op_overrides.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpu_device_op_overrides.py
new file mode 100644
index 0000000000000000000000000000000000000000..ccada837abbd4dbdaf16984c0a44ff7f90cedc04
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpu_device_op_overrides.py
@@ -0,0 +1,30 @@
+from __future__ import annotations
+
+from textwrap import dedent
+
+from .common import DeviceOpOverrides, register_device_op_overrides
+
+
+class CpuDeviceOpOverrides(DeviceOpOverrides):
+    def import_get_raw_stream_as(self, name: str) -> str:
+        return dedent(
+            """
+            def get_raw_stream(_):
+                return 0
+            """
+        )
+
+    def cpp_kernel_type(self) -> str:
+        return "void*"
+
+    def set_device(self, device_idx: int) -> str:
+        return "pass"
+
+    def synchronize(self) -> str:
+        return "pass"
+
+    def device_guard(self, device_idx: int) -> str:
+        return "pass"
+
+
+register_device_op_overrides("cpu", CpuDeviceOpOverrides())
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda_combined_scheduling.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda_combined_scheduling.py
new file mode 100644
index 0000000000000000000000000000000000000000..8779a9e86cda65cd7859a4a693ff2fb6a1ddba70
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda_combined_scheduling.py
@@ -0,0 +1,162 @@
+# mypy: allow-untyped-defs
+from __future__ import annotations
+
+from typing import Any, Optional, TYPE_CHECKING, Union
+
+from ..scheduler import (
+    BaseSchedulerNode,
+    BaseScheduling,
+    FusedSchedulerNode,
+    Scheduler,
+    SchedulerNode,
+)
+from .cuda.cuda_cpp_scheduling import CUDACPPScheduling
+from .cutedsl.cutedsl_scheduling import CuteDSLScheduling
+from .rocm.rocm_cpp_scheduling import ROCmCPPScheduling
+from .triton import TritonScheduling
+
+
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+    from typing import TypeAlias
+
+    from sympy import Expr
+
+    import torch
+    from torch.utils._ordered_set import OrderedSet
+
+    from .common import BackendFeature
+
+    _IntLike: TypeAlias = Union[int, Expr]
+
+
+class CUDACombinedScheduling(BaseScheduling):
+    """
+    Scheduler for CUDA Kernels, which delegates calls as appropriate
+    to the CUDA-C++ and Triton Schedulers, which both work for CUDA devices
+    and use a unified-wrapper for codegen.
+
+    If Scheduling code needs to be specialized for the case of mixed Triton / CUDA C++ code,
+    this would also be the place to do it.
+    """
+
+    def __init__(self, scheduler: Optional[Scheduler]) -> None:
+        super().__init__(scheduler)
+        self._triton_scheduling = TritonScheduling(scheduler)
+        self._cuda_cpp_scheduling = CUDACPPScheduling(scheduler)
+        self._rocm_cpp_scheduling = ROCmCPPScheduling(scheduler)
+        self._cutedsl_scheduling = CuteDSLScheduling(scheduler)
+
+    def get_backend_features(self, device: torch.device) -> OrderedSet[BackendFeature]:
+        return self._triton_scheduling.get_backend_features(device)
+
+    def choose_node_backend(self, node: BaseSchedulerNode) -> BaseScheduling:
+        if self._cuda_cpp_scheduling.is_cuda_cpp_template(node):
+            return self._cuda_cpp_scheduling
+        if self._rocm_cpp_scheduling.is_rocm_cpp_template(node):
+            return self._rocm_cpp_scheduling
+        if self._cutedsl_scheduling.is_cutedsl_template(node):
+            return self._cutedsl_scheduling
+        return self._triton_scheduling
+
+    def can_fuse_vertical(
+        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
+    ) -> bool:
+        if self._cuda_cpp_scheduling.can_fuse_vertical(node1, node2):
+            return True
+        elif self._cuda_cpp_scheduling.is_cuda_cpp_template(
+            node1
+        ) or self._cuda_cpp_scheduling.is_cuda_cpp_template(node2):
+            return False
+        # CuteDSL doesn't support vertical fusion currently
+        elif self._cutedsl_scheduling.is_cutedsl_template(
+            node1
+        ) or self._cutedsl_scheduling.is_cutedsl_template(node2):
+            return False
+        return self._triton_scheduling.can_fuse_vertical(node1, node2)
+
+    def can_fuse_horizontal(
+        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
+    ) -> bool:
+        for node in (node1, node2):
+            if self._cuda_cpp_scheduling.is_cuda_cpp_template(node):
+                return self._cuda_cpp_scheduling.can_fuse_horizontal(
+                    node1, node2
+                )  # always False at the moment
+            if self._cutedsl_scheduling.is_cutedsl_template(node):
+                return self._cutedsl_scheduling.can_fuse_horizontal(
+                    node1, node2
+                )  # always False at the moment
+        return self._triton_scheduling.can_fuse_horizontal(node1, node2)
+
+    def group_fn(
+        self, sizes: Sequence[Sequence[_IntLike]]
+    ) -> tuple[tuple[_IntLike, ...], ...]:
+        return self._triton_scheduling.group_fn(sizes)
+
+    def codegen_template(
+        self,
+        template_node: BaseSchedulerNode,
+        epilogue_nodes: Sequence[BaseSchedulerNode],
+        prologue_nodes: Sequence[BaseSchedulerNode],
+    ) -> Optional[str]:
+        if self._cuda_cpp_scheduling.is_cuda_cpp_template(template_node):
+            assert not prologue_nodes
+            return self._cuda_cpp_scheduling.codegen_template(
+                template_node, epilogue_nodes, prologue_nodes
+            )
+        elif self._rocm_cpp_scheduling.is_rocm_cpp_template(template_node):
+            assert not epilogue_nodes
+            assert not prologue_nodes
+            return self._rocm_cpp_scheduling.codegen_template(
+                template_node, epilogue_nodes, prologue_nodes
+            )
+        elif self._cutedsl_scheduling.is_cutedsl_template(template_node):
+            # TODO remove this when we add epilogue support
+            assert not epilogue_nodes
+            assert not prologue_nodes
+            return self._cutedsl_scheduling.codegen_template(
+                template_node, epilogue_nodes, prologue_nodes
+            )
+        else:
+            return self._triton_scheduling.codegen_template(
+                template_node, epilogue_nodes, prologue_nodes
+            )
+
+    def codegen_mix_order_reduction(self, node):
+        return self._triton_scheduling.codegen_mix_order_reduction(node)
+
+    def codegen_node(self, node: Union[FusedSchedulerNode, SchedulerNode]) -> None:
+        return self._triton_scheduling.codegen_node(node)
+
+    def codegen_sync(self) -> None:
+        return self._triton_scheduling.codegen_sync()
+
+    def flush(self) -> None:
+        return self._triton_scheduling.flush()
+
+    def codegen_combo_kernel(self, *args: Any, **kwargs: Any) -> None:
+        return self._triton_scheduling.codegen_combo_kernel(*args, **kwargs)
+
+    def benchmark_fused_nodes(
+        self, nodes: Sequence[BaseSchedulerNode]
+    ) -> tuple[float, str]:
+        return self._triton_scheduling.benchmark_fused_nodes(nodes)
+
+    def benchmark_codegened_module(self, module):
+        return self._triton_scheduling.benchmark_codegened_module(module)
+
+    def generate_kernel_code_from_nodes(
+        self,
+        nodes: Sequence[Any],
+        benchmark_kernel: bool = False,
+        hint_override: Optional[int] = None,
+    ) -> str:
+        return self._triton_scheduling.generate_kernel_code_from_nodes(
+            nodes, benchmark_kernel, hint_override=hint_override
+        )
+
+    def benchmark_combo_kernel(
+        self, node_list: Sequence[BaseSchedulerNode]
+    ) -> tuple[float, float, list[Optional[str]]]:
+        return self._triton_scheduling.benchmark_combo_kernel(node_list)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/debug_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/debug_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b465e3d1ffab27bf67fca9a54e8eb6da6f9843d
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/debug_utils.py
@@ -0,0 +1,290 @@
+# mypy: allow-untyped-defs
+from __future__ import annotations
+
+import functools
+import logging
+import os
+from enum import Enum
+from typing import Optional, TYPE_CHECKING
+
+import torch
+from torch import dtype as torch_dtype
+
+from .. import config
+from ..virtualized import V
+from .multi_kernel import MultiKernel
+
+
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
+
+log = logging.getLogger(__name__)
+
+
+def _print_debugging_tensor_value_info(msg, arg):
+    # helper for printing debugging stats for intermediate tensor values
+    # at jit inductor level codegen
+    max_numel_to_print = 64
+    print(msg)
+    if not isinstance(arg, torch.Tensor):
+        print("Value: ", arg)
+        return
+    numel = arg.float().numel()
+    # print the debug printing stats
+    if numel <= max_numel_to_print:
+        print(arg)
+    print("Number of elements: ", numel)
+    print("Size: ", arg.float().size())
+    print("Dtype: ", arg.float().mean().item())
+    print("Mean: ", arg.float().mean().item())
+    print("Min: ", arg.float().min().item())
+    print("Max: ", arg.float().max().item())
+    print("Std: ", arg.float().std().item())
+
+
+# AOTI debug printing related configs
+class IntermediateValueDebuggingLevel(Enum):
+    # OFF: No intermediate tensor value debug info will be printed or saved.
+    OFF = "0"
+    # LEVEL 1: Save all intermediate tensor values to individual `.pt` files. No debug printing will be displayed.
+    SAVE_ONLY = "1"
+    # LEVEL 2: Print all intermediate tensor values by default to the console. No debug saving will be performed.
+    PRINT_ONLY = "2"
+    # LEVEL 3: Print all kernel names to the console only. No debug saving/printing for input tensor value info will be performed.
+    # This mode can be helpful in cases when you just want to pinpointing what kernel is running into a CUDA IMA issue, etc.
+    PRINT_KERNEL_NAMES_ONLY = "3"
+
+
+class DebugPrinterManager:
+    def __init__(
+        self,
+        debug_printer_level,
+        use_array_ref: bool,
+        writeline: Optional[Callable[..., None]] = None,
+        args_to_print_or_save: Optional[list[str]] = None,
+        kernel_name: str = "",
+        kernel=None,
+        arg_signatures: Optional[list[type]] = None,
+        kernel_type=None,
+    ):
+        self.debug_printer_level = IntermediateValueDebuggingLevel(debug_printer_level)
+        self.use_array_ref = use_array_ref
+        if args_to_print_or_save is None:
+            args_to_print_or_save = []
+        self.args_to_print_or_save = args_to_print_or_save
+        self.kernel_name = kernel_name
+        self.arg_signatures: Optional[list[type]] = None
+        self.kernel = kernel
+        self.filtered_kernel_names_to_print = self._get_debug_filtered_kernel_names()
+        self.kernel_type = None
+
+    def __enter__(self):
+        self._perform_debug_print_or_save_helper(
+            self.args_to_print_or_save,
+            self.kernel_name,
+            before_launch=True,
+            arg_signatures=self.arg_signatures,
+        )
+
+    def __exit__(self, args_to_print_or_save, kernel_name, arg_signatures):
+        self._perform_debug_print_or_save_helper(
+            args_to_print_or_save,
+            kernel_name,
+            before_launch=False,
+            arg_signatures=arg_signatures,
+        )
+
+    def _perform_debug_print_or_save_helper(
+        self,
+        args_to_print_or_save,
+        kernel_name,
+        before_launch,
+        arg_signatures: Optional[list[type]] = None,
+    ):
+        if self.debug_printer_level == IntermediateValueDebuggingLevel.OFF:
+            return
+        if self.debug_printer_level == IntermediateValueDebuggingLevel.SAVE_ONLY:
+            # by default save all the tensor values before launch
+            self.codegen_intermediate_tensor_value_save(
+                self.args_to_print_or_save,
+                self.kernel_name,
+                before_launch,
+                arg_signatures=self.arg_signatures,
+            )
+        if self.debug_printer_level == IntermediateValueDebuggingLevel.PRINT_ONLY:
+            # by default print all the tensor values before launch
+            self.codegen_intermediate_tensor_value_print(
+                self.args_to_print_or_save,
+                self.kernel_name,
+                before_launch,
+                arg_signatures=self.arg_signatures,
+            )
+        if (
+            self.debug_printer_level
+            == IntermediateValueDebuggingLevel.PRINT_KERNEL_NAMES_ONLY
+        ):
+            # Print all kernel names to the console only
+            self.codegen_intermediate_tensor_value_print(
+                [],
+                self.kernel_name,
+                before_launch,
+            )
+
+    @functools.lru_cache  # noqa: B019
+    def _get_debug_filtered_kernel_names(self) -> list[str]:
+        if config.aot_inductor.filtered_kernel_names is None:
+            return []
+        return [
+            x.strip()
+            for x in config.aot_inductor.filtered_kernel_names.lower().split(",")
+        ]
+
+    def set_printer_args(
+        self,
+        args_to_print_or_save: list[str],
+        kernel_name: str,
+        arg_signatures: Optional[list[type]],
+        kernel,
+        kernel_type=None,
+    ):
+        # Note: MultiKernel debug printing is not supported for now
+        if isinstance(kernel, MultiKernel):
+            log.info(
+                "MultiKernel type is not supported in AOTI debug printer tool yet."
+            )
+            self.debug_printer_level = IntermediateValueDebuggingLevel.OFF
+
+        self.kernel_type = kernel_type
+        # Note: if the kernel type is an extern kernel (or cpp kernel), we do a special handling to
+        # get the list of args_to_print_or_save
+        # TODO: Find a more reliable way to detect kernel args types to print for extern kernel calls
+        if kernel_type == "extern":
+            args_to_print_or_save_extern = [
+                arg
+                for arg in args_to_print_or_save
+                if isinstance(arg, str) and arg.startswith(("buf", "arg"))
+            ]
+            self.args_to_print_or_save = args_to_print_or_save_extern
+        elif kernel_type == "cpp":
+            self.args_to_print_or_save = [
+                (
+                    f"copy_arrayref_tensor_to_tensor({arg})"
+                    if self.use_array_ref
+                    else arg
+                )
+                for arg in args_to_print_or_save
+                if isinstance(arg, str) and arg.startswith(("buf", "arg"))
+            ]
+        else:
+            self.args_to_print_or_save = args_to_print_or_save
+        self.kernel_name = kernel_name
+        self.arg_signatures = arg_signatures
+        self.kernel = kernel
+
+    def codegen_model_inputs_value_print(self, input_args_to_print: list[str]) -> None:
+        if self.debug_printer_level != IntermediateValueDebuggingLevel.PRINT_ONLY:
+            return
+        for arg in input_args_to_print:
+            if V.graph.cpp_wrapper:
+                V.graph.wrapper_code.prefix.writeline(
+                    f'aoti_torch_print_tensor_handle({arg}, "aoti_model_inputs - {arg}");'
+                )
+
+    def codegen_intermediate_tensor_value_save(
+        self,
+        args_to_save,
+        kernel_name,
+        before_launch=True,
+        arg_signatures: Optional[list[type]] = None,
+    ) -> None:
+        for i, arg in enumerate(args_to_save):
+            if arg_signatures is not None and not isinstance(
+                arg_signatures[i], torch_dtype
+            ):
+                # infer from the arg data type (has torch.dtype) to see if it is a tensor type
+                continue
+            launch_prefix = "before_launch" if before_launch else "after_launch"
+            if V.graph.cpp_wrapper:
+                V.graph.wrapper_code.writeline(
+                    f'aoti_torch_save_tensor_handle({arg}, "{arg}", "{launch_prefix}", "{kernel_name}");'
+                )
+            else:
+                cwd = os.getcwd()
+                saved_dir = cwd + "/tmp/jit_inductor/"
+                if not os.path.exists(saved_dir):
+                    log.info(
+                        "Creating directory to save inductor intermediate tensor values."
+                    )
+                    os.makedirs(saved_dir)
+                # Save the model to the directory
+                saved_path = saved_dir + f"{launch_prefix}_{kernel_name}_{arg}.pt"
+                log.info(
+                    "Saved intermediate tensor %s for %s to %s",
+                    arg,
+                    kernel_name,
+                    saved_path,
+                )
+                line = f"torch.save({arg}, '{saved_path}')"
+                V.graph.wrapper_code.writeline(line)
+
+    def codegen_intermediate_tensor_value_print(
+        self,
+        args_to_print,
+        kernel_name,
+        before_launch=True,
+        arg_signatures: Optional[list[type]] = None,
+    ) -> None:
+        launch_prefix = "before_launch" if before_launch else "after_launch"
+
+        # if the debug printing level is PRINT_KERNEL_NAMES_ONLY
+        # we only print the kernel name to the console
+        if (
+            self.debug_printer_level
+            == IntermediateValueDebuggingLevel.PRINT_KERNEL_NAMES_ONLY
+        ):
+            if V.graph.cpp_wrapper:
+                V.graph.wrapper_code.writeline(
+                    f'printf("[ {launch_prefix}: {kernel_name} ]\\n");'
+                )
+            return
+
+        if self.debug_printer_level != IntermediateValueDebuggingLevel.PRINT_ONLY:
+            return
+        for i, arg in enumerate(args_to_print):
+            # when debug printing is enabled i.e. IntermediateValueDebuggingLevel.PRINT_ONLY,
+            # check if filtered kernel name list is provided
+            if (
+                len(self.filtered_kernel_names_to_print) > 0
+                and kernel_name.lower() not in self.filtered_kernel_names_to_print
+            ):
+                continue
+            if V.graph.cpp_wrapper:
+                if arg_signatures is not None and isinstance(
+                    arg_signatures[i], torch_dtype
+                ):
+                    # infer from the arg data type (has torch.dtype) to see if it is a tensor type
+                    V.graph.wrapper_code.writeline(
+                        f'aoti_torch_print_tensor_handle({arg}, "{launch_prefix} - {kernel_name} - {arg}");'
+                    )
+                elif arg_signatures is not None and isinstance(
+                    arg_signatures[i],
+                    (
+                        type(torch._inductor.codegen.wrapper.SymbolicCallArg),
+                        type(int),
+                        type(float),
+                        type(bool),
+                    ),
+                ):
+                    V.graph.wrapper_code.writeline(
+                        f'printf("[  {launch_prefix} - {kernel_name} - {arg}: %ld  ]", {arg}); printf("\\\\n");'
+                    )
+                else:
+                    if arg_signatures is None and self.kernel_type in ("cpp", "extern"):
+                        V.graph.wrapper_code.writeline(
+                            f'aoti_torch_print_tensor_handle({arg}, "{launch_prefix} - {kernel_name} - {arg}");'
+                        )
+            else:
+                V.graph.wrapper_code.writeline(
+                    f'_print_debugging_tensor_value_info("inductor: {launch_prefix} - {kernel_name} - {arg}", {arg})'
+                )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/halide.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/halide.py
new file mode 100644
index 0000000000000000000000000000000000000000..e47e8e6d7841d4b70b7b41f2298bcd083fe2b8ec
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/halide.py
@@ -0,0 +1,1732 @@
+# mypy: allow-untyped-defs
+from __future__ import annotations
+
+import dataclasses
+import functools
+import itertools
+import logging
+import re
+from collections import defaultdict
+from math import inf
+from typing import Any, cast, Optional, TYPE_CHECKING, Union
+
+import sympy
+
+import torch
+import torch._logging
+
+from ..._prims_common import is_integer_dtype
+from ...utils._ordered_set import OrderedSet
+from ...utils._sympy.functions import FloorDiv, ModularIndexing
+from ...utils._sympy.symbol import symbol_is_type, SymT
+from ...utils._sympy.value_ranges import ValueRanges
+from .. import config, ir
+from ..codecache import HalideCodeCache
+from ..ir import get_reduction_combine_fn
+from ..metrics import is_metric_table_enabled, log_kernel_metadata
+from ..ops_handler import AddParenHandler
+from ..runtime.hints import HalideInputSpec, HalideMeta
+from ..utils import (
+    get_bounds_index_expr,
+    get_kernel_metadata,
+    parallel_num_threads,
+    sympy_index_symbol,
+    sympy_subs,
+)
+from ..virtualized import _ops as ops, V
+from .common import (
+    BackendFeature,
+    CSEVariable,
+    DeferredLine,
+    IndentedBuffer,
+    KernelArgType,
+    OpOverrides,
+    PythonPrinter,
+    SizeArg,
+    TensorArg,
+)
+from .cpp import DTYPE_TO_CPP
+from .cpp_utils import cexpr
+from .simd import constant_repr, SIMDKernel, SIMDScheduling
+
+
+if TYPE_CHECKING:
+    from collections.abc import Callable, Sequence
+
+    from ..ops_handler import ReductionType, StoreMode
+    from ..shape_propagation import BlockShapeType
+
+log = logging.getLogger(__name__)
+
+
+def halide_constant(val):
+    if isinstance(val, int) and not (-2147483648 <= val <= 2147483647):
+        info = torch.iinfo(torch.int64)
+        if val == info.min:
+            return "hl.Int(64).min()"
+        if val == info.max:
+            return "hl.Int(64).max()"
+        return f"hl.i64({val!r})"
+    if isinstance(val, float):
+        return f"hl.f64({constant_repr(val)})"
+    return repr(val)
+
+
+class Unsupported(RuntimeError):
+    def __init__(self, thing) -> None:
+        super().__init__(f"halide backend does not support: {thing}")
+
+
+class HalidePrinter(PythonPrinter):
+    @staticmethod
+    def cast_index(expr):
+        return f"hl.cast({V.kernel.index_dtype}, {expr})"
+
+    @staticmethod
+    def cast_float(expr):
+        return f"hl.cast(hl.Float(32), {expr})"
+
+    def _print_Float(self, expr):
+        return f"hl.f32({expr})"
+
+    def _print_ToFloat(self, expr):
+        assert len(expr.args) == 1
+        return f"hl.f32({self._print(expr.args[0])})"
+
+    def _print_floor(self, expr):
+        assert len(expr.args) == 1
+        return self.cast_index(f"hl.floor({self._print(expr.args[0])})")
+
+    _print_FloorToInt = _print_floor
+
+    def _print_Trunc(self, expr):
+        assert len(expr.args) == 1
+        return self.cast_index(f"hl.trunc({self._print(expr.args[0])})")
+
+    _print_TruncToInt = _print_Trunc
+
+    def _print_ceiling(self, expr):
+        assert len(expr.args) == 1
+        return self.cast_index(f"hl.ceil({self._print(expr.args[0])})")
+
+    def _helper_sqrt(self, expr):
+        return f"hl.sqrt({self.cast_float(self._print(expr))})"
+
+    def _print_Where(self, expr):
+        c = self.doprint(expr.args[0])
+        p = self.doprint(expr.args[1])
+        q = self.doprint(expr.args[2])
+        return f"hl.select({c}, {p}, {q})"
+
+    def _print_Min(self, expr):
+        if len(expr.args) == 1:
+            return self._print(expr.args[0])
+
+        mid = len(expr.args) // 2
+        a = self._print(sympy.Min(*expr.args[:mid]))
+        b = self._print(sympy.Min(*expr.args[mid:]))
+        return f"hl.min({a}, {b})"
+
+    def _print_Max(self, expr):
+        if len(expr.args) == 1:
+            return self._print(expr.args[0])
+
+        mid = len(expr.args) // 2
+        a = self._print(sympy.Max(*expr.args[:mid]))
+        b = self._print(sympy.Max(*expr.args[mid:]))
+
+        return f"hl.max({a}, {b})"
+
+    def _print_Abs(self, expr):
+        assert len(expr.args) == 1
+        return self.cast_index(f"hl.abs({self._print(expr.args[0])})")
+
+    def _print_OpaqueUnaryFn_cos(self, expr):
+        assert len(expr.args) == 1
+        return f"hl.cos({self._print(expr.args[0])})"
+
+    def _print_OpaqueUnaryFn_cosh(self, expr):
+        assert len(expr.args) == 1
+        return f"hl.cosh({self._print(expr.args[0])})"
+
+    def _print_OpaqueUnaryFn_acos(self, expr):
+        assert len(expr.args) == 1
+        return f"hl.acos({self._print(expr.args[0])})"
+
+    def _print_OpaqueUnaryFn_sin(self, expr):
+        assert len(expr.args) == 1
+        return f"hl.sin({self._print(expr.args[0])})"
+
+    def _print_OpaqueUnaryFn_sinh(self, expr):
+        assert len(expr.args) == 1
+        return f"hl.sinh({self._print(expr.args[0])})"
+
+    def _print_OpaqueUnaryFn_asin(self, expr):
+        assert len(expr.args) == 1
+        return f"hl.asin({self._print(expr.args[0])})"
+
+    def _print_OpaqueUnaryFn_tan(self, expr):
+        assert len(expr.args) == 1
+        return f"hl.tan({self._print(expr.args[0])})"
+
+    def _print_OpaqueUnaryFn_tanh(self, expr):
+        assert len(expr.args) == 1
+        return f"hl.tanh({self._print(expr.args[0])})"
+
+    def _print_OpaqueUnaryFn_atan(self, expr):
+        assert len(expr.args) == 1
+        return f"hl.atan({self._print(expr.args[0])})"
+
+    def _print_OpaqueUnaryFn_log2(self, expr):
+        raise NotImplementedError("log2")
+
+    def _print_FloorDiv(self, expr):
+        if expr.is_integer:
+            return super()._print_FloorDiv(expr)
+
+        x, div = expr.args
+        x = self.cast_float(self.doprint(x))
+        div = self.cast_float(self.doprint(div))
+        return self.cast_index(f"hl.floor({x} / {div})")
+
+    def _print_Round(self, expr):
+        assert len(expr.args) == 1
+        return self.cast_index(f"hl.round({self._print(expr.args[0])})")
+
+    _print_RoundToInt = _print_Round
+
+    def _print_IntTrueDiv(self, expr):
+        a, b = expr.args
+        # force a cast to float
+        return f"({a}) / ({b}+hl.f32(0))"
+
+    def _print_RoundDecimal(self, expr):
+        val, n = expr.args
+        val = self._print(val)
+        n = int(n)
+        return f"hl.f32({10.0 ** (-n)!r})*hl.round(({val})*hl.f32({10.0**n!r}))"
+
+
+texpr = HalidePrinter().doprint
+pexpr = PythonPrinter().doprint
+
+
+_halide_type = {
+    torch.bool: "hl.Bool()",
+    torch.bfloat16: "hl.BFloat(16)",
+    torch.float16: "hl.Float(16)",
+    torch.float32: "hl.Float(32)",
+    torch.float64: "hl.Float(64)",
+    torch.int8: "hl.Int(8)",
+    torch.int16: "hl.Int(16)",
+    torch.int32: "hl.Int(32)",
+    torch.int64: "hl.Int(64)",
+    torch.uint8: "hl.UInt(8)",
+    torch.uint16: "hl.UInt(16)",
+    torch.uint32: "hl.UInt(32)",
+    torch.uint64: "hl.UInt(64)",
+}
+
+
+def halide_type(dtype):
+    return _halide_type[dtype]
+
+
+def halide_acc_type(dtype):
+    if is_integer_dtype(dtype) and dtype.is_signed and dtype != torch.int64:
+        dtype = torch.int32
+    if dtype in (torch.float16, torch.bfloat16):
+        dtype = torch.float32
+    return halide_type(dtype)
+
+
+class HalideOverrides(OpOverrides):
+    @staticmethod
+    def to_dtype(
+        x,
+        dtype: torch.dtype,
+        src_dtype: Optional[torch.dtype] = None,
+        use_compute_types=True,
+    ):
+        if dtype == torch.bool:
+            return f"({x} != 0)"
+        return f"hl.cast({halide_type(dtype)}, {x})"
+
+    @staticmethod
+    def to_dtype_bitcast(x, dtype: torch.dtype, src_dtype: torch.dtype):
+        if src_dtype in (torch.float16, torch.bfloat16):
+            x = f"hl.cast({halide_type(src_dtype)}, {x})"  # body compute is upcast to fp32
+        line = f"hl.reinterpret({halide_type(dtype)}, {x})"
+        if dtype in (torch.float16, torch.bfloat16):
+            line = f"hl.cast(hl.Float(32), {line})"
+        return line
+
+    @classmethod
+    def constant(cls, value, dtype):
+        return cls.to_dtype(halide_constant(value), dtype)
+
+    @staticmethod
+    def abs(x):
+        return f"hl.abs({x})"
+
+    @staticmethod
+    def exp(x):
+        if not hasattr(x, "name"):
+            return f"hl.exp({x})"
+        return f"hl.fast_exp(hl.cast(hl.Float(32), {x})) if {x.name}.type().bits() <= 32 else hl.exp({x})"
+
+    @staticmethod
+    def sqrt(x):
+        return f"hl.sqrt({x})"
+
+    @staticmethod
+    def minimum(a, b):
+        # return f"hl.min({a}, {b})"  <== handles nan wrong
+        if not hasattr(a, "name"):
+            return f"hl.min({a}, {b})"
+        b = f"hl.cast({a.name}.type(), {b})"
+        return f"hl.select(({a}<{b})|hl.is_nan({a}), {a}, {b}) if {a.name}.type().is_float() else hl.min({a}, {b})"
+
+    @staticmethod
+    def maximum(a, b):
+        # return f"hl.max({a}, {b})"  <== handles nan wrong
+        if not hasattr(a, "name"):
+            return f"hl.max({a}, {b})"
+        b = f"hl.cast({a.name}.type(), {b})"
+        return f"hl.select(({a}>{b})|hl.is_nan({a}), {a}, {b}) if {a.name}.type().is_float() else hl.max({a}, {b})"
+
+    @staticmethod
+    def where(a, b, c):
+        if hasattr(b, "name"):
+            c = f"hl.cast({b.name}.type(), {c})"
+        return f"hl.select({a}, {b}, {c})"
+
+    @staticmethod
+    def cos(x):
+        return f"hl.cos({x})"
+
+    @staticmethod
+    def sin(x):
+        return f"hl.sin({x})"
+
+    @staticmethod
+    def lgamma(x):
+        raise Unsupported("lgamma")
+
+    @staticmethod
+    def erf(x):
+        return f"hl.erf({x})"
+
+    @staticmethod
+    def cosh(x):
+        return f"hl.cosh({x})"
+
+    @staticmethod
+    def sinh(x):
+        return f"hl.sinh({x})"
+
+    @staticmethod
+    def acos(x):
+        return f"hl.acos({x})"
+
+    @staticmethod
+    def acosh(x):
+        return f"hl.acosh({x})"
+
+    @staticmethod
+    def asin(x):
+        return f"hl.asin({x})"
+
+    @staticmethod
+    def asinh(x):
+        return f"hl.asinh({x})"
+
+    @staticmethod
+    def atan2(x, y):
+        return f"hl.atan2({x}, {y})"
+
+    @staticmethod
+    def atan(x):
+        return f"hl.atan({x})"
+
+    @staticmethod
+    def atanh(x):
+        return f"hl.atanh({x})"
+
+    @staticmethod
+    def copysign(x, y):
+        raise Unsupported("copysign")
+
+    @staticmethod
+    def erfinv(x):
+        raise Unsupported("erfinv")
+
+    @staticmethod
+    def hypot(x, y):
+        return f"hl.hypot({x}, {y})"
+
+    @staticmethod
+    def nextafter(x, y):
+        raise Unsupported("nextafter")
+
+    @staticmethod
+    def logical_and(a, b):
+        return f"{a} & {b}"
+
+    @staticmethod
+    def logical_not(a):
+        return f"{a} == 0"
+
+    @staticmethod
+    def logical_or(a, b):
+        return f"{a} | {b}"
+
+    @staticmethod
+    def logical_xor(a, b):
+        return f"({a} ^ {b})"
+
+    @staticmethod
+    def bitwise_and(a, b):
+        return f"{a} & {b}"
+
+    @staticmethod
+    def bitwise_not(a):
+        return f"~{a}"
+
+    @staticmethod
+    def bitwise_or(a, b):
+        return f"{a} | {b}"
+
+    @staticmethod
+    def bitwise_xor(a, b):
+        return f"{a} ^ {b}"
+
+    @staticmethod
+    def bitwise_left_shift(a, b):
+        return f"{a} << {b}"
+
+    @staticmethod
+    def bitwise_right_shift(a, b):
+        return f"{a} >> {b}"
+
+    @staticmethod
+    def rand(seed, offset):
+        return f"halide_helpers.rand({seed}, {offset})"
+
+    @staticmethod
+    def randn(seed, offset):
+        return f"halide_helpers.randn({seed}, {offset})"
+
+    @staticmethod
+    def randint64(seed, offset, low, high):
+        return f"halide_helpers.randint64({seed}, {offset}, {low}, {high})"
+
+    @staticmethod
+    def load_seed(name, offset):
+        return f"{ops.load(name, 0)} + {V.kernel.args.seed_offset('load_seed_offset', offset)}"
+
+    @staticmethod
+    def rsqrt(x):
+        # return f"hl.fast_inverse_sqrt({x})"  <== accuracy issues
+        return f"1./hl.sqrt({x})"
+
+    @staticmethod
+    def tan(x):
+        return f"hl.tan({x})"
+
+    @staticmethod
+    def tanh(x):
+        return f"hl.tanh({x})"
+
+    @staticmethod
+    def signbit(x):
+        return f"(hl.reinterpret(hl.UInt(32), hl.cast(hl.Float(32), {x})) >> 31) != 0"
+
+    @staticmethod
+    def fmod(a, b):
+        # TODO(jansel): find a better way to do this, builtin % has wrong sign
+        return f"{a} - hl.trunc({a}/{b})*{b}"
+
+    @staticmethod
+    def pow(a, b):
+        return f"hl.pow({a}, {b})"  # hl.fast_pow fails accuracy
+
+    @staticmethod
+    def log(x):
+        return f"hl.log({x})"  # hl.fast_log fails accuracy
+
+    @staticmethod
+    def log2(x):
+        raise NotImplementedError("log2")
+
+    @staticmethod
+    def isinf(x):
+        # workaround https://github.com/halide/Halide/issues/8309
+        return f"hl.is_inf(hl.cast(hl.Float(32), {x}))"
+
+    @staticmethod
+    def isnan(x):
+        # workaround https://github.com/halide/Halide/issues/8309
+        return f"hl.is_nan(hl.cast(hl.Float(32), {x}))"
+
+    @staticmethod
+    def round(x):
+        return f"hl.round({x})"
+
+    @staticmethod
+    def floor(x):
+        return f"hl.floor({x})"
+
+    @staticmethod
+    def int_truediv(a, b):
+        return f"({a}) / ({b} + hl.f32(0))"
+
+    @staticmethod
+    def floordiv(a, b):
+        # TODO(jansel): find a better ways to do this, the select-based trick from triton.py didn't work
+        return (
+            f"hl.floor(hl.cast(hl.Float(max(32, {a.name}.type().bits())), {a}) / {b})"
+        )
+
+    @classmethod
+    def sign(cls, x):
+        left = ops.to_dtype(ops.lt("0", x), torch.int8)
+        right = ops.to_dtype(ops.lt(x, "0"), torch.int8)
+        sub = ops.sub(left, right)
+        return f"hl.cast({x.name}.type(), {sub})"
+
+    @staticmethod
+    def trunc(x):
+        return f"hl.trunc({x})"
+
+    @staticmethod
+    def truncdiv(a, b):
+        # this causes crashes with floating point exception, see test_div_zero_dim_cpu
+        # return f"hl.div_round_to_zero({a}, {b})"
+        return (
+            f"hl.trunc(hl.cast(hl.Float(max(32, {a.name}.type().bits())), {a}) / {b})"
+        )
+
+    @staticmethod
+    def ceil(x):
+        return f"hl.ceil({x})"
+
+    @staticmethod
+    def relu(x):
+        return f"hl.max({x}, 0)"
+
+    @classmethod
+    def index_expr(cls, expr, dtype):
+        index = V.kernel.prepare_indexing(expr)
+        var = V.kernel.genfunc(
+            V.kernel.index_to_str(index),
+            V.kernel.used_dims_from_index(index),
+            bounds=get_bounds_index_expr(expr),
+        )
+        if dtype not in (torch.int32, torch.int64):
+            return ops.to_dtype(var, dtype)
+        return var
+
+    @classmethod
+    def indirect_indexing(cls, index_var, size, check=True, wrap_neg=True):
+        # TODO(jansel): Halide only supports 32-bit indexing, we should error on overflow
+        index_var = ops.to_dtype(index_var, torch.int32)
+        index_var = ops.halide_clamp(index_var, size, check)
+        index_var.indirect_indexing_size = size
+        return sympy_index_symbol(str(index_var))
+
+    @classmethod
+    def halide_clamp(cls, value, size, check):
+        end = V.kernel.kexpr(V.kernel.rename_indexing(size) - 1)
+        if not isinstance(size, (int, sympy.Integer)):
+            end = f"hl.cast({value.name}.type(), {end})"
+        # Skip unsafe_promise_clamped to workaround: https://github.com/halide/Halide/issues/8261#issuecomment-2148835692
+        # return f"hl.unsafe_promise_clamped({value}, 0, {end})"
+        return f"hl.clamp({value}, 0, {end})"
+
+    @staticmethod
+    def masked(mask, body, other):
+        with V.kernel.mask_loads(mask, other) as new_mask:
+            result = body()
+
+        if result.bounds.is_bool:
+            other = bool(other)
+
+        # Take dtype from result to prevent accidental promotion
+        other = V.kernel.genfunc(
+            f"hl.cast({result.name}.type(), {halide_constant(other)})",
+            [],
+            bounds=ValueRanges.wrap(other),
+            shape=result.shape,
+        )
+        # TODO(jansel): look into removing the where in the same places triton does
+        return ops.where(new_mask, result, other)
+
+    @staticmethod
+    def frexp(x):
+        raise NotImplementedError("frexp")
+
+    @staticmethod
+    def device_assert_async(cond, msg):
+        raise NotImplementedError("device_assert_async")
+
+    @staticmethod
+    # pyrefly: ignore [bad-override]
+    def partial_accumulate(
+        name: str,
+        reduction_type: str,
+        value: CSEVariable,
+        extra_meta: dict[str, Any],
+    ) -> None:
+        raise NotImplementedError
+
+
+HalideOverrides._initialize_pointwise_overrides("halide")
+
+
+class HalideCSEVariable(CSEVariable):
+    undefined_re = re.compile(r"\b(tmp\d+)\[\?\]")
+
+    def __init__(
+        self,
+        name,
+        bounds: ValueRanges[Any],
+        dtype: Optional[torch.dtype] = None,
+        shape: BlockShapeType = None,
+    ) -> None:
+        super().__init__(name, bounds, dtype, shape=shape)
+        self.used_dims: Optional[list[sympy.Symbol]] = None
+
+    def update_on_args(self, name, args, kwargs):
+        used = OrderedSet(self.used_dims or ())
+        for arg in itertools.chain(args, kwargs.values()):
+            if isinstance(arg, HalideCSEVariable):
+                assert arg.used_dims is not None, (name, arg, args)
+                used.update(arg.used_dims)
+        self.used_dims = V.kernel.sort_used_dims(used)
+
+    def index_str(self, dims):
+        if len(dims) == 0:
+            return f"{self.name}[()]"
+        # Reversed since Halide is column major
+        return f"{self.name}[{', '.join(map(str, dims))}]"
+
+    def __str__(self) -> str:
+        if self.used_dims is None:
+            # This will get recomputed and replaced in codegen_kernel()
+            return f"{self.name}[?]"
+        return self.index_str(self.used_dims)
+
+    def subs_str(self, replacements):
+        assert self.used_dims is not None and all(
+            isinstance(x, sympy.Expr) for x in self.used_dims
+        )
+        return self.index_str([replacements.get(n, n) for n in self.used_dims])
+
+
+@dataclasses.dataclass
+class DimensionInfo:
+    expr: Optional[sympy.Expr]
+    size: sympy.Expr
+    stride: sympy.Expr
+
+    def __init__(self, expr, size, stride) -> None:
+        super().__init__()
+        if V.graph.sizevars.statically_known_lt(stride, 0):
+            stride = -stride
+            expr = -expr
+        self.expr = expr
+        self.size = size
+        self.stride = stride
+
+    def index_str(self, replacements=None, zero_vars=False):
+        assert self.expr is not None
+        expr = self.expr
+        if zero_vars and expr == 0:
+            return "hl.Var()"
+        if replacements:
+            replacements = {**replacements}
+            # pyrefly: ignore [missing-attribute]
+            for sym in expr.free_symbols:
+                if symbol_is_type(sym, SymT.TMP):
+                    assert isinstance(sym, sympy.Symbol)
+                    var = V.kernel.lookup_cse_var(sym.name)
+                    assert isinstance(var, HalideCSEVariable)
+                    replacements[sym] = sympy_index_symbol(var.subs_str(replacements))
+            expr = sympy_subs(expr, replacements)
+        return V.kernel.index_to_str(expr)
+
+
+def eq(left, right):
+    if V.graph.sizevars.statically_known_equals(left, right):
+        return True
+    try:
+        a = V.graph.sizevars.size_hint_or_throw(left)
+        b = V.graph.sizevars.size_hint_or_throw(right)
+    except TypeError:  # unbacked symints
+        return False
+    if a == b:
+        V.graph.sizevars.check_equals(left, right)
+    return a == b
+
+
+def lt(left, right):
+    if V.graph.sizevars.statically_known_lt(left, right):
+        return True
+    try:
+        a = V.graph.sizevars.size_hint_or_throw(left)
+        b = V.graph.sizevars.size_hint_or_throw(right)
+    except TypeError:  # unbacked symints
+        gcd = sympy.gcd(left, right)
+        if gcd == left:
+            return left != right
+        return False
+    if a < b:
+        V.graph.sizevars.check_lt(left, right)
+    return a < b
+
+
+class HalideKernel(SIMDKernel):
+    overrides = HalideOverrides  # type: ignore[assignment]
+    kexpr: Callable[[sympy.Expr], str] = texpr
+
+    def __init__(
+        self,
+        tiling: dict[str, sympy.Expr],
+        **kwargs,
+    ) -> None:
+        super().__init__(tiling, **kwargs)
+        # For halide, we just write directly to the body
+        self.compute = self.body
+        self.loads = self.body
+        self.stores = self.body
+        self.indexing_code_dom = IndentedBuffer()
+        self.needs_dom_indexing = self.inside_reduction
+        self.has_reduction = self.inside_reduction
+        self.buffer_dimensions: dict[str, list[DimensionInfo]] = {}
+        self.buffer_offsets: dict[str, sympy.Expr] = {}
+        # {h0: size1, h1: size2, ...}
+        self.halide_vars: dict[sympy.Symbol, sympy.Expr] = {}
+        # {x0: h0, x1: h1+10*h2, ...}
+        self.index_replacements: dict[sympy.Expr, sympy.Expr] = {}
+        # {h1: hr1, ...}
+        self.reduction_renames: dict[sympy.Symbol, sympy.Symbol] = {}
+        # {"i": {h0: hi0}, "o": ...}
+        self.dom_renames: dict[str, dict[sympy.Symbol, sympy.Symbol]] = {}
+        # {"in_ptr0": ["in_ptr0_view0"], ...}
+        self.buffer_aliases: dict[str, list[str]] = defaultdict(list)
+        self.has_indirect_indexing = False
+
+    def dtype_to_str(self, dtype: torch.dtype) -> str:
+        return halide_type(dtype)
+
+    # pyrefly: ignore [bad-override]
+    def create_cse_var(self, name, bounds=None, dtype=None, shape=None):
+        self.body.writeline(f"{name} = hl.Func({name!r})")
+        # pyrefly: ignore [bad-argument-type]
+        return HalideCSEVariable(name, bounds, dtype, shape)
+
+    def finalize_indexing(self, indices: Sequence[sympy.Expr]):
+        """
+        Hook called right before codegen with every index that will be
+        used in the fused kernel.
+
+        This populates self.halide_vars/index_replacements/reduction_renames which is an alternate indexing
+        scheme that avoids using divide and modulus.  Instead of xindex/yindex/rindex
+        we base indexing on a larger number of vars whose product combines to those.
+
+        This function populates self.halide_vars, self.index_replacements, and self.reduction_renames
+        """
+        assert not (
+            self.index_replacements or self.halide_vars or self.reduction_renames
+        )
+        size_hint = functools.partial(V.graph.sizevars.size_hint, fallback=inf)  # type: ignore[arg-type]
+        # pyrefly: ignore [bad-assignment]
+        indices = dict.fromkeys(map(super().prepare_indexing, indices))
+        all_used_symbols = OrderedSet[Any]()
+        sym_to_node = {
+            n.symbol(): n
+            for n in itertools.chain.from_iterable(
+                [tree.nodes.values() for tree in self.range_trees]
+            )
+        }
+
+        def simplify(expr):
+            return sympy.simplify(
+                V.graph.sizevars.remove_precomputed_replacements(expr)
+            )
+
+        def visit_modular_indexing(base, divisor, modulus):
+            if base in sym_to_node:
+                node = sym_to_node[base]
+                all_used_symbols.add(
+                    node.root.lookup(
+                        node.divisor * divisor,
+                        V.graph.sizevars.evaluate_min(
+                            modulus, FloorDiv(node.length, divisor)
+                        ),
+                    ).symbol()
+                )
+
+        def visit_floor_div(base, divisor):
+            if base in sym_to_node:
+                node = sym_to_node[base]
+                all_used_symbols.add(
+                    node.root.lookup(
+                        node.divisor * divisor,
+                        FloorDiv(node.length, divisor),
+                    ).symbol()
+                )
+
+        # first figure out all_used_symbols to do dead symbol elimination
+        for index in indices:
+            if index.has(ModularIndexing):
+                index.replace(
+                    ModularIndexing(
+                        sympy.Wild("base"),
+                        sympy.Wild("divisor"),
+                        sympy.Wild("modulus"),
+                    ),
+                    visit_modular_indexing,
+                )
+            if index.has(FloorDiv):
+                index.replace(
+                    FloorDiv(
+                        sympy.Wild("base"),
+                        sympy.Wild("divisor"),
+                    ),
+                    visit_floor_div,
+                )
+            all_used_symbols.update(super().prepare_indexing(index).free_symbols)
+
+        self.has_indirect_indexing = any(
+            symbol_is_type(sym, SymT.INDIRECT) for sym in all_used_symbols
+        )
+
+        had_fallback = False
+        for tree in reversed(self.range_trees):
+            nodes = [n for n in tree.nodes.values() if n.symbol() in all_used_symbols]
+            nodes.sort(key=lambda n: size_hint(n.divisor))
+            if not nodes:
+                nodes.append(tree.lookup(1, tree.numel))
+            handled_count = 0
+            divisor = sympy.S.One
+            added_sym_size = []
+            # decide on a minimal set of symbols and put them in self.halide_vars
+            while handled_count < len(nodes) and not eq(tree.numel, divisor):
+                sizes_to_add = [
+                    simplify(n.length) for n in nodes if eq(n.divisor, divisor)
+                ]
+                handled_count += len(sizes_to_add)
+                assert sizes_to_add, nodes
+                end = divisor * functools.reduce(
+                    V.graph.sizevars.evaluate_max, sizes_to_add
+                )
+                sizes_to_add.extend(
+                    [
+                        simplify(n.divisor / divisor)
+                        for n in nodes
+                        if lt(divisor, n.divisor) and lt(n.divisor, end)
+                    ]
+                )
+                while sizes_to_add:
+                    next_size = functools.reduce(sympy.gcd, sizes_to_add)
+                    if eq(next_size, 1):
+                        # sizes share no common factors, e.g [2, 21, 42, 441, 889056]
+                        # TODO(jansel): we should just prevent fusion in cases that hit this
+                        next_size = simplify(tree.numel / divisor)
+                        assert not eq(next_size, 1)
+                        sizes_to_add = []
+                        handled_count = len(nodes)
+                        had_fallback = True
+                    sym = sympy_index_symbol(f"h{len(self.halide_vars)}")
+                    # pyrefly: ignore [missing-argument]
+                    if tree.is_reduction:
+                        self.reduction_renames[sym] = sympy_index_symbol(
+                            f"hr{len(self.halide_vars)}"
+                        )
+                    self.halide_vars[sym] = next_size
+                    added_sym_size.append((sym, next_size))
+                    divisor *= next_size
+                    new_sizes = [n.length for n in nodes if eq(n.divisor, divisor)]
+                    handled_count += len(new_sizes)
+                    prior_len = len(sizes_to_add)
+                    sizes_to_add = [
+                        sympy.simplify(s / next_size)
+                        for s in sizes_to_add
+                        if not eq(s, next_size)
+                    ]
+                    assert len(sizes_to_add) < prior_len or prior_len == 0
+                    sizes_to_add.extend(new_sizes)
+
+            # create a mapping to the new set of symbols in self.index_replacements
+            for node in nodes:
+                try:
+                    idx = 0
+                    divisor = 1
+                    while not eq(node.divisor, divisor):
+                        sym, size = added_sym_size[idx]
+                        idx += 1
+                        divisor *= size
+                    length = 1
+                    expr = sympy.S.Zero
+                    while not eq(node.length, length):
+                        sym, size = added_sym_size[idx]
+                        idx += 1
+                        expr += length * sym
+                        length *= size
+                    self.index_replacements[node.symbol()] = expr
+                except IndexError:
+                    assert had_fallback
+                    full_index = sympy.S.Zero
+                    stride = sympy.S.One
+                    for sym, size in added_sym_size:
+                        full_index += stride * sym
+                        stride *= size
+                    self.index_replacements[node.symbol()] = (
+                        V.graph.sizevars.simplify_with_ranges(
+                            ModularIndexing(full_index, node.divisor, node.length),
+                            self.halide_vars,  # type: ignore[arg-type]
+                        )
+                    )
+
+        # codegen the variable definitions
+        for sym in self.halide_vars:
+            self.indexing_code.writeline(f"{sym} = hl.Var({sym.name!r})")
+        if self.reduction_renames:
+            self.codegen_rdom(
+                "rdom",
+                {rv: self.halide_vars[v] for v, rv in self.reduction_renames.items()},
+            )
+
+    def setup_dom_indexing(self):
+        """RDom based indexing uses explicit iteration ranges for Func updates"""
+        prefix = "i" if self.inside_reduction else "o"
+        if prefix in self.dom_renames:
+            return self.dom_renames[prefix]
+
+        renames = {}
+        for var in self.halide_vars:
+            if not self.inside_reduction and var in self.reduction_renames:
+                continue
+            m = re.match(r"^h(\d+)$", var.name)
+            assert m
+            renames[var] = sympy_index_symbol(f"h{prefix}{m.group(1)}")
+
+        self.codegen_rdom(
+            f"{prefix}dom", {rv: self.halide_vars[v] for v, rv in renames.items()}
+        )
+
+        self.dom_renames[prefix] = renames
+        return renames
+
+    def codegen_rdom(self, name, vars):
+        rsizes = [
+            f"hl.Range(0, {self.kexpr(self.rename_indexing(size))})"
+            for size in vars.values()
+        ]
+        self.indexing_code.writeline(f"{name} = hl.RDom([{', '.join(rsizes)}])")
+        for i, rsym in enumerate(vars.keys()):
+            self.indexing_code.writeline(f"{rsym} = {name}[{i}]")
+
+    def prepare_indexing(
+        self,
+        index: sympy.Expr,
+    ):
+        index = super().prepare_indexing(index)
+        index = sympy_subs(index, self.index_replacements)
+        return V.graph.sizevars.simplify_with_ranges(index, self.halide_vars)  # type: ignore[arg-type]
+
+    def sym_size(self, sym):
+        """The size of an index symbol"""
+        if symbol_is_type(sym, SymT.TMP):
+            return self.lookup_cse_var(sym.name).indirect_indexing_size
+        return self.halide_vars[sym]
+
+    def indexing_to_dimensions(self, var: str, index: sympy.Expr, is_store: bool):
+        """Convert address-based indexing into dimensions using self.halide_vars"""
+        symbols = []
+        for sym in sorted(index.free_symbols, key=lambda x: x.name):  # type: ignore[attr-defined]
+            if symbol_is_type(sym, (SymT.HALIDE, SymT.TMP)):
+                symbols.append(sym)
+            else:
+                assert symbol_is_type(
+                    sym,
+                    (
+                        SymT.UNBACKED_INT,
+                        SymT.SIZE,
+                        SymT.PRECOMPUTED_SIZE,
+                    ),
+                ), sym
+
+        # group the expression by variables used
+        offset = sympy.S.Zero
+        split_expr = dict.fromkeys(symbols, sympy.S.Zero)
+        split_failed: list[tuple[list[sympy.Symbol], sympy.Expr]] = []
+        index = sympy.expand(self.rename_indexing(index))
+        for part in index.args if isinstance(index, sympy.Add) else [index]:
+            part_vars = [v for v in part.free_symbols if v in split_expr]
+            if len(part_vars) == 0:
+                offset += part
+            elif len(part_vars) == 1:
+                split_expr[part_vars[0]] += part
+            else:
+                new_split_failed = []
+                for i in range(len(split_failed)):
+                    assert split_failed[i] is not None
+                    other_vars, other_part = split_failed[i]
+                    if OrderedSet(other_vars) & OrderedSet(part_vars):
+                        part_vars.extend([v for v in other_vars if v not in part_vars])
+                        part += other_part
+                    else:
+                        new_split_failed.append((other_vars, other_part))
+                split_failed = [*new_split_failed, (part_vars, part)]
+
+        def expr_to_dimension(expr, syms):
+            expr = sympy.factor(expr)
+            if len(syms) == 1:
+                stride_wild = sympy.Wild("wild", exclude=symbols)
+                m = expr.match(stride_wild * syms[0])
+                if m:
+                    return DimensionInfo(
+                        syms[0], self.sym_size(syms[0]), m[stride_wild]
+                    )
+            assert not is_store, expr
+            length = sympy.simplify(
+                sympy_subs(expr, {sym: self.sym_size(sym) - 1 for sym in syms}) + 1
+            )
+            stride = sympy.S.One
+            if isinstance(expr, sympy.Mul):
+                for term in expr.args:
+                    if isinstance(term, sympy.Integer):
+                        stride *= term
+                        expr = sympy.simplify(expr / term)
+                        length = sympy.simplify(sympy.ceiling(length / term))
+            return DimensionInfo(expr, length, stride)
+
+        # try to turn each group into a strided access
+        dims = []
+        for syms, expr in split_failed:
+            for v in syms:
+                expr += split_expr.pop(v)
+            dims.append(expr_to_dimension(expr, syms))
+        for sym, expr in split_expr.items():
+            dims.append(expr_to_dimension(expr, [sym]))
+        dims.sort(key=lambda d: V.graph.sizevars.size_hint(d.stride, fallback=inf))  # type: ignore[arg-type]
+
+        if not dims:  # scalar load/store
+            if self.has_indirect_indexing:
+                # workaround https://github.com/halide/Halide/issues/8338
+                dims.append(DimensionInfo(sympy.S.Zero, 1, 1))
+        elif not V.graph.sizevars.statically_known_equals(dims[0].stride, 1):
+            # Halide assumes dimension 0 is stride == 1, so add a dummy dimension
+            dims.insert(
+                0, DimensionInfo(sympy.S.Zero, 1 if is_store else dims[0].stride, 1)
+            )
+
+        if dims and not is_store:
+            if var in self.buffer_offsets and V.graph.sizevars.statically_known_geq(
+                offset, self.buffer_offsets[var]
+            ):
+                # reuse the existing offset to avoid needing an input alias
+                self.apply_offset_to_dimension(dims, offset - self.buffer_offsets[var])
+                offset = self.buffer_offsets[var]
+            elif V.graph.sizevars.statically_known_gt(
+                offset, 0
+            ):  # TODO(jansel): negative offsets
+                # roll the offset into the dimensions for cleaner indexing
+                self.apply_offset_to_dimension(dims, offset)
+                offset = 0
+
+        orig_var = var
+        for i in itertools.count():
+            if self.install_dims(var, dims, offset, is_store):
+                return var, dims
+            assert not is_store
+            var = f"{orig_var}_view{i}"
+            if var not in self.buffer_aliases[orig_var]:
+                self.buffer_aliases[orig_var].append(var)
+
+    def install_dims(self, var, dims, offset, is_store):
+        """Try to set self.buffer_dimensions[var], return True on success"""
+        if var not in self.buffer_dimensions:
+            self.buffer_dimensions[var] = dims
+            self.buffer_offsets[var] = offset
+            return True
+        if self.buffer_offsets[var] != offset or len(
+            self.buffer_dimensions[var]
+        ) != len(dims):
+            return False
+        if is_store:
+            return self.buffer_dimensions[var] == dims
+        for old, new in zip(self.buffer_dimensions[var], dims):
+            if old.stride != new.stride:
+                return False
+            if old.size != new.size or old.expr != new.expr:
+                old.size = V.graph.sizevars.evaluate_max(old.size, new.size)
+                old.expr = None
+        return True
+
+    def apply_offset_to_dimension(self, dims, offset):
+        if offset == 0:
+            return
+        for i in reversed(range(len(dims))):
+            if dims[i].stride == 1 or V.graph.sizevars.statically_known_geq(
+                offset, dims[i].stride
+            ):
+                part = FloorDiv(offset, dims[i].stride)
+                offset -= part * dims[i].stride
+                dims[i].expr += part
+        assert offset == 0
+
+    def used_dims_from_index(self, index: sympy.Expr):
+        """Detect which range trees are used to populate HalideCSEVariable.used_dims"""
+        used_dims = OrderedSet[sympy.Symbol]()
+        for sym in index.free_symbols:
+            assert isinstance(sym, sympy.Symbol)
+            if symbol_is_type(sym, SymT.TMP):
+                # indirect indexing
+                cse_var = self.lookup_cse_var(sym.name)
+                assert (
+                    isinstance(cse_var, HalideCSEVariable)
+                    and cse_var.used_dims is not None
+                )
+                used_dims.update(cse_var.used_dims)
+            elif symbol_is_type(sym, SymT.HALIDE):
+                used_dims.add(sym)
+            elif symbol_is_type(
+                sym, (SymT.UNBACKED_INT, SymT.SIZE, SymT.PRECOMPUTED_SIZE, SymT.INDEX)
+            ):
+                pass
+            else:
+                raise NotImplementedError(f"unhandled symbol {sym}")
+        return self.sort_used_dims(used_dims)
+
+    def sort_used_dims(self, used_dims):
+        assert all(isinstance(x, sympy.Expr) for x in used_dims)
+        ordered = [
+            sym
+            for sym in itertools.chain(
+                self.halide_vars, self.reduction_renames.values()
+            )
+            if sym in used_dims
+        ]
+        assert len(ordered) == len(used_dims)
+        return ordered
+
+    def make_index_str(self, dims, replacements=None, zero_vars=False):
+        index_str = ", ".join(d.index_str(replacements, zero_vars) for d in dims)
+        if len(dims) == 0:
+            index_str = "()"
+        elif len(dims) == 1:
+            # workaround for https://github.com/halide/Halide/issues/8299
+            index_str = f"{index_str},"
+        return index_str
+
+    def load(self, name: str, index: sympy.Expr):
+        """Codegen a load from an InputBuffer"""
+        var = self.args.input(name)
+        index = self.prepare_indexing(index)
+        var, dims = self.indexing_to_dimensions(var, index, False)
+        line = f"{var}[{self.make_index_str(dims)}]"
+        dtype = V.graph.get_dtype(name)
+        if dtype in (torch.float16, torch.bfloat16):
+            dtype = torch.float32
+            line = f"hl.cast(hl.Float(32), {line})"
+
+        if self._load_mask:
+            assert (
+                isinstance(self._load_mask, HalideCSEVariable)
+                and self._load_mask.used_dims is not None
+            )
+            used_dims = OrderedSet(
+                (*self.used_dims_from_index(index), *self._load_mask.used_dims)
+            )
+            result = self.newfunc(self.sort_used_dims(used_dims))
+            if result.used_dims:
+                self.body.writeline(f"{result.name}_mask = hl.RDom([hl.Range(0, 1)])")
+                self.body.writeline(f"{result.name}_mask.where({self._load_mask})")
+                other = self.kexpr(self._load_other or 0)  # type: ignore[arg-type]
+                self.body.writeline(
+                    f"{result} = hl.cast({halide_type(dtype)}, {other})"
+                )
+                self.body.writeline(
+                    f"{result} = {line} + hl.cast({halide_type(dtype)}, {result.name}_mask)"
+                )
+            else:
+                # scalar case
+                self.body.writeline(
+                    f"{result} = hl.select({self._load_mask}, {line}, hl.cast({halide_type(dtype)}, 0))"
+                )
+            return result
+        else:
+            return self.genfunc(line, self.used_dims_from_index(index))
+
+    def lookup_cse_var(self, name: str):
+        return self.cse.varname_map[re.sub(r"\[.*", "", name)]
+
+    def store(
+        self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
+    ) -> None:
+        """Codegen a store to an OutputBuffer"""
+        assert isinstance(value, HalideCSEVariable)
+        var = self.args.output(name)
+        index = self.prepare_indexing(index)
+        var, dims = self.indexing_to_dimensions(var, index, True)
+        if self.is_indirect_indexing(index) or mode is not None:
+            replacements = self.setup_dom_indexing()
+            index_str = self.make_index_str(dims, replacements)
+            value_str = value.subs_str(replacements)
+            undef_dims = (", ".join(["hl.Var()"] * len(dims))) or "()"
+            self.body.writeline(
+                DeferredLine(name, f"{var}[{undef_dims}] = hl.undef({var}.type())")
+            )
+        else:
+            index_str = self.make_index_str(dims, zero_vars=True)
+            value_str = str(value)
+
+        dtype = V.graph.get_dtype(name)
+        if mode is None:
+            line = f"{var}[{index_str}] = hl.cast({halide_type(dtype)}, {value_str})"
+        elif mode == "atomic_add":
+            line = f"{var}[{index_str}] += hl.cast({halide_type(dtype)}, {value_str})"
+        else:
+            raise NotImplementedError(f"store mode={mode}")
+        self.body.writeline(DeferredLine(name, line))
+
+    def reduction(
+        self,
+        dtype: torch.dtype,
+        src_dtype: torch.dtype,
+        reduction_type: ReductionType,
+        value: Union[CSEVariable, tuple[CSEVariable, ...]],
+    ) -> Union[CSEVariable, tuple[CSEVariable, ...]]:
+        """Codegen a reduction operation"""
+        assert self.inside_reduction
+        assert not self._load_mask
+        cache_key = (src_dtype, reduction_type, value)
+        if cache_key in self.cse.reduction_cache:
+            return self.cse.reduction_cache[cache_key]
+
+        if isinstance(value, tuple):
+            assert reduction_type == "welford_combine"
+            self.cse.reduction_cache[cache_key] = result_tuple = (
+                self.welford_combine_impl(*value)
+            )
+            return result_tuple
+
+        assert isinstance(value, HalideCSEVariable) and value.used_dims is not None
+        reduction_vars = OrderedSet(self.reduction_renames)
+        result_var = self.newfunc(
+            [v for v in value.used_dims if v not in reduction_vars],
+        )
+        if reduction_vars - OrderedSet(value.used_dims):
+            value = self.genfunc(
+                f"{value}",
+                self.sort_used_dims(OrderedSet((*value.used_dims, *reduction_vars))),
+                shape=value.shape,
+            )
+        value_str = value.subs_str(self.reduction_renames)
+        default = ir.Reduction.default_accumulator(reduction_type, src_dtype)
+        acc_type = halide_acc_type(dtype)
+
+        if reduction_type in ("argmax", "argmin"):
+            index = f"{result_var.name}_{reduction_type}"
+            self.body.writeline(f"{index} = hl.{reduction_type}(rdom, {value_str})")
+            # turn the N-D argmax index into a 1-D one
+            parts = []
+            stride = 1
+            for i, sym in enumerate(self.reduction_renames):
+                # pyrefly: ignore [bad-argument-type]
+                parts.append(f"{index}[{i}]")
+                if stride != 1:
+                    # pyrefly: ignore [unsupported-operation]
+                    parts[-1] += f"*{stride}"
+                stride *= self.halide_vars[sym]
+            self.body.writeline(f"{result_var} = {' + '.join(parts)}")
+        elif reduction_type == "welford_reduce":
+            # TODO(jansel): implement welford_reduce without fallback
+            result_var = self.welford_reduce_fallback(dtype, value)
+        else:
+            combine_fn = get_reduction_combine_fn(reduction_type, acc_type)
+            with V.set_ops_handler(AddParenHandler(HalideOverrides())):
+                combine_str = combine_fn(result_var, value_str)  # type: ignore[arg-type]
+            default_str = f"hl.cast({acc_type}, {halide_constant(default)})"
+            self.body.writeline(f"{result_var} = {default_str}")
+            self.body.writeline(f"{result_var} = {combine_str}")
+
+        self.cse.reduction_cache[cache_key] = result_var
+        return result_var
+
+    def welford_combine_impl(self, mean, m2, weight):
+        assert isinstance(mean, HalideCSEVariable) and mean.used_dims is not None
+        assert isinstance(m2, HalideCSEVariable) and m2.used_dims is not None
+        assert isinstance(weight, HalideCSEVariable) and weight.used_dims is not None
+        used_dims = OrderedSet(
+            (*mean.used_dims, *m2.used_dims, *weight.used_dims) or self.halide_vars
+        )
+        used_dims -= OrderedSet(self.reduction_renames)
+        result_var = self.newfunc(self.sort_used_dims(used_dims))
+        default = [f"hl.cast({x.name}.type(), 0)" for x in (mean, m2, weight)]
+        pfx = result_var.name
+        self.body.writeline(f"{result_var} = hl.Tuple([{', '.join(default)}])")
+        self.body.writeline(f"{pfx}_mean_1 = {result_var}[0]")
+        self.body.writeline(f"{pfx}_m2_1 = {result_var}[1]")
+        self.body.writeline(f"{pfx}_weight_1 = {result_var}[2]")
+        self.body.writeline(f"{pfx}_mean_2 = {mean.subs_str(self.reduction_renames)}")
+        self.body.writeline(f"{pfx}_m2_2 = {m2.subs_str(self.reduction_renames)}")
+        self.body.writeline(
+            f"{pfx}_weight_2 = {weight.subs_str(self.reduction_renames)}"
+        )
+        self.body.writeline(f"{pfx}_delta = {pfx}_mean_2 - {pfx}_mean_1")
+        self.body.writeline(f"{pfx}_new_weight = {pfx}_weight_1 + {pfx}_weight_2")
+        self.body.writeline(
+            f"{pfx}_w2_over_w = hl.select({pfx}_new_weight == 0.0, 0.0, {pfx}_weight_2 / {pfx}_new_weight)"
+        )
+        update = [
+            f"{pfx}_mean_1 + {pfx}_delta * {pfx}_w2_over_w",
+            f"{pfx}_m2_1 + {pfx}_m2_2 + {pfx}_delta * {pfx}_delta * {pfx}_weight_1 * {pfx}_w2_over_w",
+            f"{pfx}_new_weight",
+        ]
+        self.body.writeline(f"{result_var} = hl.Tuple([{', '.join(update)}])")
+
+        unpacked = []
+        for i in range(3):
+            unpacked.append(self.newfunc(result_var.used_dims))
+            self.body.writeline(f"{unpacked[-1]} = {result_var}[{i}]")
+        return tuple(unpacked)
+
+    def scan(
+        self,
+        dtypes: tuple[torch.dtype, ...],
+        combine_fn: Callable[
+            [tuple[CSEVariable, ...], tuple[CSEVariable, ...]], tuple[CSEVariable, ...]
+        ],
+        values_orig: tuple[CSEVariable, ...],
+    ) -> tuple[CSEVariable, ...]:
+        assert self.inside_reduction
+        assert len(dtypes) == len(values_orig)
+        values: list[HalideCSEVariable] = []
+        all_used_dims = OrderedSet[sympy.Symbol]()
+
+        for value in values_orig:
+            assert isinstance(value, HalideCSEVariable) and value.used_dims is not None
+            if OrderedSet(value.used_dims) & OrderedSet(self.reduction_renames):
+                values.append(value)
+            else:
+                values.append(
+                    self.genfunc(
+                        f"{value}",
+                        [*value.used_dims, [*self.reduction_renames][:1]],
+                        shape=value.shape,
+                    )
+                )
+            all_used_dims.update(value.used_dims)
+        result_var = self.newfunc(self.sort_used_dims(all_used_dims))
+        assert result_var.used_dims and OrderedSet(result_var.used_dims) & OrderedSet(
+            self.reduction_renames
+        )
+        initial = [
+            f"hl.cast({halide_acc_type(dtype)}, {value})"
+            for dtype, value in zip(dtypes, values)
+        ]
+
+        length = self.kexpr(self.rename_indexing(self.range_trees[-1].numel))
+        scan_dom = f"{result_var.name}_rdom"
+        scan = f"{scan_dom}.x"
+        self.body.writeline(f"{scan_dom} = hl.RDom([hl.Range(1, {length})])")
+
+        assert len(self.reduction_renames) == 1, (
+            "multi-dimensional scan not implemented"
+        )
+        (scan_var,) = [*self.reduction_renames]  # type: ignore[misc]
+        scan_renames_cur = {scan_var: sympy_index_symbol(scan)}
+        scan_renames_pri = {scan_var: sympy_index_symbol(scan) - 1}
+
+        if len(values) == 1:
+
+            def maybe_tuple(x):
+                return x[0]
+
+            read_left = [result_var.subs_str(scan_renames_pri)]
+            read_right = [result_var.subs_str(scan_renames_cur)]
+        else:
+
+            def maybe_tuple(x):
+                return f"hl.Tuple([{', '.join(x)}])"
+
+            read_left = [
+                result_var.subs_str(scan_renames_pri) + f"[{i}]"
+                for i in range(len(values))
+            ]
+            read_right = [
+                result_var.subs_str(scan_renames_cur) + f"[{i}]"
+                for i in range(len(values))
+            ]
+
+        self.body.writeline(f"{result_var} = {maybe_tuple(initial)}")
+
+        # Disable CSE for update fn
+        with V.set_ops_handler(AddParenHandler(HalideOverrides())):
+            combine_str = combine_fn(read_left, read_right)  # type: ignore[arg-type]
+        self.body.writeline(
+            f"{result_var.subs_str(scan_renames_cur)} = {maybe_tuple(combine_str)}"
+        )
+
+        if len(values) == 1:
+            return (result_var,)
+
+        unpack_vars = [self.newfunc(self.sort_used_dims(all_used_dims)) for _ in values]
+        for i, v in enumerate(unpack_vars):
+            self.body.writeline(f"{v} = {result_var}[{i}]")
+        return tuple(unpack_vars)
+
+    def genfunc(
+        self,
+        line,
+        used_dims,
+        *,
+        bounds=ValueRanges.unknown(),
+        shape: BlockShapeType = None,
+    ) -> HalideCSEVariable:
+        var = self.cse.generate(self.body, line, bounds=bounds, shape=shape)
+        assert isinstance(var, HalideCSEVariable)
+        var.used_dims = used_dims
+        return var
+
+    def newfunc(self, used_dims, *, shape: BlockShapeType = None) -> HalideCSEVariable:
+        var = self.cse.newvar(shape=shape)
+        assert isinstance(var, HalideCSEVariable)
+        var.used_dims = used_dims
+        return var
+
+    def halide_buffer_numel(self, name: str):
+        """
+        We map all tensors to 1D buffers in Halide since Halide has trouble representing some strides that PyTorch
+        supports.  If there are gaps in the underlying layout the numel we pass to Halide includes the gaps while
+        PyTorch's numel excludes them.
+        """
+        return V.graph.get_buffer(name).get_layout().storage_size()
+
+    def halide_argdefs(self):
+        """
+        Halide requires scalar inputs before outputs, so need to reorder args.
+        """
+
+        def arg_order(arg_tuple):
+            _call_str, arg = arg_tuple
+            if isinstance(arg, SizeArg):
+                return 1  # this would normally be at the end, move it to middle
+            elif "out_ptr" in arg.name:
+                return 2
+            else:
+                assert "in_ptr" in arg.name
+                return 0
+
+        result: list[tuple[Optional[str], KernelArgType]] = []
+        _, a, b, _ = self.args.python_argdefs()
+        for call_str, arg in sorted(zip(a, b), key=arg_order):
+            result.append((call_str, arg))
+            if isinstance(arg, TensorArg):
+                assert arg.offset == 0 and arg.alias_of is None
+                result.extend(
+                    (
+                        None,
+                        TensorArg(
+                            alias,
+                            arg.buffer,
+                            arg.dtype,
+                            arg.offset,
+                            alias_of=arg.name,
+                        ),
+                    )
+                    for alias in self.buffer_aliases.get(arg.name, ())
+                )
+        return result
+
+    def halide_kernel_meta(self) -> HalideMeta:
+        """Compute metadata required by codecache.py"""
+        argtypes = []
+        for _, arg in self.halide_argdefs():
+            if isinstance(arg, SizeArg):
+                shape = None
+                stride = None
+                offset = None
+                dtype = "long"
+            else:
+                shape = [
+                    cexpr(self.rename_indexing(x.size))
+                    for x in self.buffer_dimensions[arg.name]
+                ]
+                stride = [
+                    cexpr(self.rename_indexing(x.stride))
+                    for x in self.buffer_dimensions[arg.name]
+                ]
+                assert len(shape) == len(stride)
+                offset = cexpr(self.buffer_offsets[arg.name])
+                dtype = f"{DTYPE_TO_CPP[arg.dtype]}*"
+            argtypes.append(
+                HalideInputSpec(
+                    dtype,
+                    arg.name,
+                    shape=shape,
+                    stride=stride,
+                    offset=offset,
+                    alias_of=arg.alias_of,
+                )
+            )
+
+        current_device = V.graph.get_current_device_or_throw()
+        if current_device.type == "cpu":
+            target = [config.halide.cpu_target]
+            scheduler = config.halide.scheduler_cpu
+            scheduler_flags = {
+                "parallelism": parallel_num_threads(),
+            }
+            cuda_device = None
+        else:
+            assert current_device.type == "cuda", "only cpu/cuda supported"
+            assert current_device.index <= 0, "only default device supported"
+            target = [config.halide.gpu_target]
+            scheduler = config.halide.scheduler_cuda
+            capability = torch.cuda.get_device_properties(current_device)
+            if "cuda_capability" not in target[0]:
+                for major, minor in [(8, 6), (8, 0), (7, 5), (7, 0), (6, 1)]:
+                    if capability.major >= major and capability.minor >= minor:
+                        target.append(f"cuda_capability_{major}{minor}")
+                        break
+            target.append("user_context")
+            scheduler_flags = {
+                "parallelism": capability.multi_processor_count,
+                # TODO(jansel): explore other flags, see:
+                # grep parser.parse ~/Halide/src/autoschedulers/anderson2021/AutoSchedule.cpp
+            }
+            cuda_device = max(0, current_device.index)
+
+        # strict_float is requires for correctness
+        target.append("strict_float")
+
+        # without this we will initialize cuda once per kernel and hit errors
+        target.append("no_runtime")
+
+        if not config.halide.asserts:
+            target.append("no_asserts")
+
+        if config.halide.debug:
+            target.append("debug")
+
+        if "64" in self.index_dtype:
+            # TODO(jansel): it is unclear if this does anything, since input sizes are still int32
+            target.append("large_buffers")
+
+        return HalideMeta(
+            argtypes,
+            target="-".join(target),
+            scheduler=scheduler,
+            scheduler_flags=scheduler_flags,  # type: ignore[arg-type]
+            cuda_device=cuda_device,
+        )
+
+    def codegen_kernel(self, name=None):
+        """Called at the end to generate a final kernel string"""
+        if self.args.inplace_buffers:
+            raise Unsupported("inplace_buffers")
+        meta = self.halide_kernel_meta()  # ensure needed args are added early
+        code = IndentedBuffer()
+        code.splice(
+            """
+            import halide as hl
+            from torch._inductor.runtime import halide_helpers
+            from math import inf, nan
+
+            @hl.generator(name="kernel")
+            class Kernel:
+        """,
+            strip=True,
+        )
+        code.do_indent()
+        for _, arg in self.halide_argdefs():
+            if isinstance(arg, SizeArg):
+                code.writeline(f"{arg.name} = hl.InputScalar({self.index_dtype})")
+            else:
+                assert arg.buffer, arg
+                argcls = "hl.OutputBuffer" if "out" in arg.name else "hl.InputBuffer"
+                argtype = halide_type(arg.dtype)
+                ndim = len(self.buffer_dimensions[arg.name])
+                code.writeline(f"{arg.name} = {argcls}({argtype}, {ndim})")
+        code.splice(
+            """
+            def generate(g):
+        """
+        )
+        code.do_indent()
+        for _, arg in self.halide_argdefs():
+            code.writeline(f"{arg.name} = g.{arg.name}")
+        for old, new in self.args.aliases():
+            code.writeline(f"{old} = {new}")
+        code.splice(self.indexing_code)
+
+        def update_index(m):
+            var = cast(HalideCSEVariable, self.cse.varname_map[m.group(1)])
+            assert var.used_dims is not None, var
+            return str(var)
+
+        for line in self.body._lines:
+            if isinstance(line, str):
+                # fill in missing indices
+                line = HalideCSEVariable.undefined_re.sub(update_index, line)
+            code.writeline(line)
+        code.writeline("")
+        code.writeline("assert g.using_autoscheduler()")
+
+        for _, arg in self.halide_argdefs():
+            # fallback=1 below because halide requires buffers to be at least as large as the estimates
+            # This causes crashes if our estimate is greater than the vector length
+            # https://github.com/halide/Halide/issues/3103
+            if isinstance(arg, SizeArg):
+                hint = V.graph.sizevars.size_hint(arg.expr, fallback=1)
+                code.writeline(f"{arg.name}.set_estimate({hint})")
+            else:
+                dims = self.buffer_dimensions[arg.name]
+                range_hints = []
+                for i, dim in enumerate(dims):
+                    hint = self._autoscheduler_workarounds(
+                        V.graph.sizevars.size_hint(dim.size, fallback=1), dims
+                    )
+                    # pyrefly: ignore [bad-argument-type]
+                    range_hints.append(f"hl.Range(0, {hint})")
+                    if "out" not in arg.name:
+                        code.writeline(f"{arg.name}.dim({i}).set_min(0)")
+                        try:
+                            code.writeline(
+                                f"{arg.name}.dim({i}).set_stride({int(dim.stride)})"
+                            )
+                        except TypeError:
+                            pass  # not integer
+                        try:
+                            code.writeline(
+                                f"{arg.name}.dim({i}).set_extent({int(dim.size)})"
+                            )
+                        except TypeError:
+                            pass  # not integer
+                code.writeline(f"{arg.name}.set_estimates([{', '.join(range_hints)}])")
+
+        code.do_unindent(2)
+        code.splice(
+            """
+            if __name__ == "__main__":
+                hl.main()
+            """.rstrip(),
+        )
+        if meta.scheduler:
+            code.splice(
+                f"""
+                else:
+                    hl.load_plugin({HalideCodeCache.find_libautoschedule(meta.scheduler)!r})
+                    target = hl.Target({meta.target!r})
+                    autoscheduler = hl.AutoschedulerParams({meta.scheduler!r}, {meta.scheduler_flags!r})
+                    with hl.GeneratorContext(target, autoscheduler):
+                        gen = Kernel()
+                        pipeline = gen._build_pipeline()
+                        # gen.compile_to_callable() does not run the autoscheduler
+                        pipeline.apply_autoscheduler(target, autoscheduler)
+                        kernel = pipeline.compile_to_callable([
+                                gen._get_input_parameter(a.name)._to_argument()
+                                for a in gen._get_arginfos()
+                                if a.dir == hl.ArgInfoDirection.Input
+                            ], target)
+                """,
+                strip=True,
+            )
+        else:
+            code.splice(
+                f"""
+                  else:
+                      with hl.GeneratorContext(hl.Target({meta.target!r})):
+                          kernel = Kernel().compile_to_callable()
+                  """,
+                strip=True,
+            )
+        return code.getvalue()
+
+    @staticmethod
+    def _autoscheduler_workarounds(n, dims):
+        if (
+            len(dims) == 1
+            and config.halide.scheduler_cuda == "Anderson2021"
+            and V.graph.get_current_device_or_throw().type == "cuda"
+        ):
+            # workaround https://github.com/halide/Halide/issues/8246
+            n = max(2, n)
+        return n
+
+    def call_kernel(self, name: str, node=None, deallocate_ws: bool = True):
+        """Codegen a call to this kernel"""
+        wrapper = V.graph.wrapper_code
+        call_args = [f"{n}" for n, arg in self.halide_argdefs() if arg.alias_of is None]
+        current_device = V.graph.get_current_device_or_throw()
+        if current_device.type == "cuda":
+            stream_name = wrapper.write_get_raw_stream(
+                current_device.index, V.graph.name
+            )
+            call_args.append(stream_name)
+        wrapper.generate_kernel_call(
+            name,
+            call_args,
+            device=current_device,
+            triton=False,
+        )
+
+    def generate_assert(self, check):
+        return False  # TODO(jansel): support asserts
+
+    def check_bounds(
+        self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
+    ):
+        pass  # TODO(jansel): support asserts
+
+
+class HalideScheduling(SIMDScheduling):
+    kernel_type = HalideKernel  # type: ignore[arg-type,assignment]
+
+    @classmethod
+    def get_backend_features(cls, device: torch.device) -> OrderedSet[BackendFeature]:
+        result = OrderedSet(
+            [
+                BackendFeature.TUPLE_REDUCTION,
+                BackendFeature.PREFER_STORE_LOOP_ORDER,
+                BackendFeature.REDUCE_TO_SINGLE_ELEMENT,
+            ]
+        )
+        if config.halide.scan_kernels:
+            result.add(BackendFeature.SCAN)
+        return result
+
+    def define_kernel(self, src_code, node_schedule, kernel):
+        """Codegen kernel definition to go in output wrapper code"""
+        wrapper = V.graph.wrapper_code
+        if src_code in wrapper.src_to_kernel:
+            kernel_name = wrapper.src_to_kernel[src_code]
+        else:
+            kernel_name = f"halide_kernel_{wrapper.next_kernel_suffix()}"
+            wrapper.src_to_kernel[src_code] = kernel_name
+            wrapper.add_import_once(
+                "from torch._inductor.runtime.hints import HalideMeta, HalideInputSpec"
+            )
+
+            compile_wrapper = IndentedBuffer()
+            compile_wrapper.writeline(
+                f"async_compile.halide({kernel.halide_kernel_meta()!r}, '''"
+            )
+            compile_wrapper.splice(src_code, strip=True)
+            compile_wrapper.writeline("''')")
+
+            origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper)
+            metadata_comment = f"{origins}\n{detailed_origins}"
+            wrapper.define_kernel(
+                kernel_name, compile_wrapper.getvalue(), metadata_comment
+            )
+            if is_metric_table_enabled("kernel_metadata"):
+                log_kernel_metadata(kernel_name, "", src_code)
+
+        return kernel_name
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/memory_planning.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/memory_planning.py
new file mode 100644
index 0000000000000000000000000000000000000000..12d7500975e5b93c6c837a48821ef737df6a3f19
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/memory_planning.py
@@ -0,0 +1,816 @@
+# mypy: allow-untyped-defs
+from __future__ import annotations
+
+import collections
+import dataclasses
+import itertools
+import pprint
+from typing import Any, Optional, Protocol, TYPE_CHECKING
+
+import sympy
+
+import torch
+from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
+from torch.utils._ordered_set import OrderedSet
+
+from .. import config
+from ..utils import _align, align, cache_on_self, CachedMethod, IndentedBuffer
+from ..virtualized import V
+from .wrapper import (
+    AllocateLine,
+    BufferLike,
+    FreeIfNotReusedLine,
+    MemoryPlanningLine,
+    NullLine,
+    ReuseLine,
+)
+
+
+if TYPE_CHECKING:
+    from collections.abc import Iterable
+
+
+@dataclasses.dataclass
+class LiveRange:
+    """
+    A range where a given tensor is live.  Begin and end are both counters
+    representing points in the program of grouped memory operations.
+    Begin is inclusive, end is exclusive.
+
+    Invariant: begin <= end
+    """
+
+    begin: float  # int | +/-inf
+    end: float  # int | +/-inf
+
+    def contains(self, other: LiveRange):
+        """Is other entirely within self"""
+        return self.begin <= other.begin and other.end <= self.end
+
+    def join(self, other: LiveRange):
+        """Combine two ranges using a union operation"""
+        return LiveRange(min(self.begin, other.begin), max(self.end, other.end))
+
+    def __len__(self):
+        return self.end - self.begin
+
+
+class LiveRanges:
+    """
+    A collection of LiveRange regions, allowing for non-contiguous
+    live regions.
+
+    Invariant: LiveRanges.ranges is in sorted order and non-overlapping
+    """
+
+    def __init__(self, ranges: Iterable[LiveRange]):
+        ranges = [*sorted(ranges, key=lambda x: x.begin)]
+        self.ranges = ranges[:1]
+        for r in ranges[1:]:
+            assert self.ranges[-1].begin <= r.begin
+            if self.ranges[-1].end >= r.begin:
+                self.ranges[-1] = LiveRange.join(self.ranges[-1], r)
+            else:
+                self.ranges.append(r)
+
+    def overlaps(self, other: LiveRanges):
+        """Check if any pair of ranges in self and other overlap"""
+        left = collections.deque(self.ranges)
+        right = collections.deque(other.ranges)
+        while left and right:
+            if left[0].begin > right[0].begin:
+                left, right = right, left
+            assert left[0].begin <= right[0].begin
+            if left[0].end > right[0].begin:
+                return True
+            left.popleft()
+        return False
+
+    @property
+    def begin(self):
+        return self.ranges[0].begin
+
+    @property
+    def end(self):
+        return self.ranges[-1].end
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}([{', '.join(map(repr, self.ranges))}])"
+
+
+class AllocationTreeNode:
+    """
+    Abstract base class for nodes in allocation pool.
+    """
+
+    def allocate(self, block: Allocation, is_last: bool) -> bool:
+        """
+        Try to assign block to a memory location in this bool.  Return True if
+        an assignment was made.
+        """
+        return False
+
+    def get_live_ranges(self) -> LiveRanges:
+        """Aggregate LiveRanges for all objects below this in tree"""
+        raise NotImplementedError
+
+    def get_size_hint(self) -> int:
+        """Number of bytes used for example inputs"""
+        raise NotImplementedError
+
+    def get_symbolic_size(self) -> sympy.Expr:
+        """Number of bytes needed at runtime"""
+        raise NotImplementedError
+
+    def finalize(self, pool, offset) -> AllocationTreeNode:
+        """Called after all allocations have been made"""
+        return self
+
+    def is_empty(self):
+        return False
+
+
+@dataclasses.dataclass
+class Allocation(AllocationTreeNode):
+    """
+    Represents memory allocated to a given node in the allocation pool.
+    """
+
+    node: BufferLike
+    live_range: LiveRange
+    size_hint: int
+    symbolic_size: sympy.Expr
+    allocated: bool = False
+    pool: Optional[AllocationPool] = None
+    offset: Optional[sympy.Expr] = None
+    earliest_available: Optional[float] = None
+
+    def __post_init__(self) -> None:
+        has_unbacked_sym = False
+        for s in self.node.get_layout().size:
+            if free_unbacked_symbols(s):
+                has_unbacked_sym = True
+                break
+
+        if has_unbacked_sym:
+            self.earliest_available = self.get_live_ranges().begin
+
+    @property
+    def device(self):
+        return self.node.get_device()
+
+    def get_live_ranges(self):
+        return LiveRanges([self.live_range])
+
+    def get_size_hint(self):
+        return self.size_hint
+
+    def get_symbolic_size(self):
+        return self.symbolic_size
+
+    def mark_allocated(self):
+        assert not self.allocated
+        self.allocated = True
+
+    def finalize(self, pool, offset):
+        assert self.pool is None and self.offset is None
+        self.pool = pool
+        self.offset = offset
+        return self
+
+    def codegen_alloc_from_pool(self, wrapper):
+        assert self.pool
+        node = self.node
+        shape = tuple(node.get_size())
+        stride = tuple(node.get_stride())
+        return wrapper.codegen_alloc_from_pool(
+            self.pool.name, self.offset, node.get_dtype(), shape, stride
+        )
+
+    def __repr__(self):
+        return (
+            f"{self.__class__.__name__}("
+            f"node={self.node.get_name()}, "
+            f"live_range={self.live_range}, "
+            f"size_hint={self.size_hint}, "
+            f"symbolic_size={self.symbolic_size}, "
+            f"pool={self.pool.name if self.pool else None}, "
+            f"offset={self.offset})"
+        )
+
+    def get_earliest_available(self):
+        return self.earliest_available
+
+
+@dataclasses.dataclass
+class Empty(AllocationTreeNode):
+    """
+    Placeholder to represent empty space in the allocation pool.
+    Only exists to get the size_hint correct in parent nodes.
+    """
+
+    size_hint: int
+
+    def get_live_ranges(self):
+        return LiveRanges([])
+
+    def get_size_hint(self):
+        return self.size_hint
+
+    def get_symbolic_size(self):
+        return 0
+
+    def is_empty(self):
+        return True
+
+
+class MemorySplitProtocol(Protocol):
+    get_live_ranges: CachedMethod[[], LiveRanges]
+    get_size_hint: CachedMethod[[], int]
+    get_symbolic_size: CachedMethod[[], sympy.Expr]
+
+    def _allocate(self, block: Allocation, is_last: bool) -> bool: ...
+
+
+class ClearCacheOnAllocateMixin(MemorySplitProtocol):
+    """
+    Helper to assist in caching get_live_ranges, get_size_hint, and
+    get_symbolic_size.
+    """
+
+    def allocate(self, block: Allocation, is_last: bool):
+        is_allocated = self._allocate(block, is_last)
+        if is_allocated:
+            self.clear_cache()
+        return is_allocated
+
+    def clear_cache(self):
+        self.get_live_ranges.clear_cache(self)
+        self.get_size_hint.clear_cache(self)
+        self.get_symbolic_size.clear_cache(self)
+
+
+@dataclasses.dataclass
+class TemporalSplit(ClearCacheOnAllocateMixin, AllocationTreeNode):
+    """
+    Contains a list of allocations not overlapping in LiveRanges.
+
+    Invariant: no pair (a,b) in self.allocations will have:
+         a.get_live_ranges().overlaps(b.get_live_ranges())
+    """
+
+    allocations: list[AllocationTreeNode]
+
+    def _allocate(self, block: Allocation, is_last: bool):
+        slot_size = self.get_size_hint()
+        block_size = block.get_size_hint()
+        if not is_last and block_size > slot_size:
+            return False  # doesn't fit
+
+        block_live = block.get_live_ranges()
+        overlapping = [
+            s for s in self.allocations if s.get_live_ranges().overlaps(block_live)
+        ]
+        if len(overlapping) > 1:
+            # TODO(jansel): we could try harder here by merging overlapping in space
+            return False
+        elif len(overlapping) == 1:
+            return overlapping[0].allocate(block, is_last)
+        else:
+            block.mark_allocated()
+
+            if len(self.allocations) == 1 and isinstance(self.allocations[-1], Empty):
+                self.allocations.pop()
+
+            if slot_size == block_size:
+                # perfect fit
+                self.allocations.append(block)
+            elif slot_size > block_size:
+                self.allocations.append(
+                    SpatialSplit.create(block, slot_size - block_size)
+                )
+            else:  # grow this allocation
+                assert is_last
+                self.allocations = [
+                    *(
+                        SpatialSplit.create(a, block_size - slot_size)
+                        for a in self.allocations
+                    ),
+                    block,
+                ]
+            return True
+
+    @cache_on_self
+    def get_live_ranges(self) -> LiveRanges:
+        return LiveRanges(
+            itertools.chain.from_iterable(
+                x.get_live_ranges().ranges for x in self.allocations
+            )
+        )
+
+    @cache_on_self
+    def get_size_hint(self) -> int:
+        if not self.allocations:
+            return 0
+        return max(x.get_size_hint() for x in self.allocations)
+
+    @cache_on_self
+    def get_symbolic_size(self) -> sympy.Expr:
+        if not self.allocations:
+            return 0  # type: ignore[return-value]
+        return sympy.Max(*[x.get_symbolic_size() for x in self.allocations])
+
+    def is_empty(self):
+        return len(self.allocations) == 1 and self.allocations[0].is_empty()
+
+    def finalize(self, pool, offset):
+        self.allocations = [block.finalize(pool, offset) for block in self.allocations]
+        self.clear_cache()
+        if len(self.allocations) == 1:
+            return self.allocations[0]
+        return self
+
+
+@dataclasses.dataclass
+class SpatialSplit(ClearCacheOnAllocateMixin, AllocationTreeNode):
+    """
+    Contains two allocations, left and right, that do not overlap in space.
+    Right will be allocated immediately after left in memory.
+    """
+
+    left: TemporalSplit
+    right: TemporalSplit
+
+    @staticmethod
+    def create(left, extra_space):
+        assert isinstance(left, AllocationTreeNode)
+        assert isinstance(extra_space, int) and extra_space >= 1
+        return SpatialSplit(TemporalSplit([left]), TemporalSplit([Empty(extra_space)]))
+
+    def _allocate(self, block: Allocation, is_last: bool):
+        return self.left.allocate(block, False) or self.right.allocate(block, is_last)
+
+    @cache_on_self
+    def get_live_ranges(self):
+        return LiveRanges(
+            itertools.chain(
+                self.left.get_live_ranges().ranges, self.right.get_live_ranges().ranges
+            )
+        )
+
+    @cache_on_self
+    def get_size_hint(self) -> int:
+        return _align(self.left.get_size_hint()) + self.right.get_size_hint()
+
+    @cache_on_self
+    def get_symbolic_size(self) -> sympy.Expr:
+        return align(self.left.get_symbolic_size()) + self.right.get_symbolic_size()
+
+    def finalize(self, pool, offset):
+        self.left = self.left.finalize(pool, offset)
+        self.right = self.right.finalize(
+            pool, offset + align(self.left.get_symbolic_size())
+        )
+        self.clear_cache()
+        if self.right.is_empty():
+            return self.left
+        return self
+
+
+@dataclasses.dataclass
+class AllocationPool:
+    """
+    Represents a pool of allocations that will be generated by a single
+    call to torch.empty.
+    """
+
+    device: torch.device
+    root: TemporalSplit
+    can_expand: bool = True
+    restrict_live_range: Optional[LiveRange] = None
+    name: Optional[str] = None
+    names_to_del: list[str] = dataclasses.field(default_factory=list)
+    creation_cache: dict[str, str] = dataclasses.field(default_factory=dict)
+
+    def __post_init__(self) -> None:
+        for block in self.root.allocations:
+            if isinstance(block, Allocation):
+                self.update_restrict_live_range(block)
+
+    def allocate(self, block: Allocation, is_last: bool):
+        if (
+            self.restrict_live_range is not None
+            and not self.restrict_live_range.contains(block.live_range)
+        ):
+            return False
+
+        block_earliest_available = block.get_earliest_available()
+        pool_begin = self.root.get_live_ranges().begin
+        if block_earliest_available and block_earliest_available > pool_begin:
+            return False
+
+        is_last = self.can_expand and is_last
+        if self.root.allocate(block, is_last):
+            self.update_restrict_live_range(block)
+            return True
+
+        if is_last:
+            return self.allocate_at_end(block)
+
+        return False
+
+    def update_restrict_live_range(self, block: Allocation):
+        if block_earliest_available := block.get_earliest_available():
+            if self.restrict_live_range is None:
+                self.restrict_live_range = LiveRange(
+                    block_earliest_available, float("inf")
+                )
+            else:
+                self.restrict_live_range = LiveRange(
+                    min(self.restrict_live_range.begin, block_earliest_available),
+                    self.restrict_live_range.end,
+                )
+
+    def allocate_at_end(self, block):
+        block.mark_allocated()
+        self.root = TemporalSplit([SpatialSplit(self.root, TemporalSplit([block]))])
+        self.update_restrict_live_range(block)
+        return True
+
+    def finalize(self, name):
+        assert not self.name
+        self.name = name
+        self.names_to_del.append(name)
+        self.root.finalize(self, 0)
+
+    def codegen_create(self, wrapper, code: IndentedBuffer):
+        assert self.name
+        nbytes = self.root.get_symbolic_size()
+        for block in self.root.allocations:
+            if isinstance(block, Allocation) and nbytes == block.get_symbolic_size():
+                node = block.node
+                code.writeline(
+                    wrapper.make_allocation(
+                        self.name,
+                        device=self.device,
+                        dtype=node.get_dtype(),
+                        shape=tuple(node.get_size()),
+                        stride=tuple(node.get_stride()),
+                    )
+                )
+                return
+        else:
+            code.writeline(
+                wrapper.make_allocation(
+                    self.name,
+                    device=self.device,
+                    dtype=torch.uint8,
+                    shape=(nbytes,),
+                    stride=(1,),
+                )
+            )
+
+    def codegen_destroy(self, wrapper, code: IndentedBuffer):
+        code.writeline(wrapper.make_free_by_names(self.names_to_del))
+
+    def __eq__(self, other):
+        return self is other
+
+    def __hash__(self):
+        return id(self)
+
+
+@dataclasses.dataclass
+class AllocationPools:
+    """
+    Collection of many AllocationPool objects grouped by device.
+    """
+
+    device_to_pools: dict[torch.device, list[AllocationPool]] = dataclasses.field(
+        default_factory=dict
+    )
+
+    def get_pools(self, block):
+        if block.device not in self.device_to_pools:
+            self.device_to_pools[block.device] = []
+        return self.device_to_pools[block.device]
+
+    def allocate(self, block: Allocation):
+        pools = self.get_pools(block)
+
+        for pool in pools:
+            if pool.allocate(block, is_last=pool is pools[-1]):
+                return
+
+        # everything is full, make a new pool
+        pools.append(
+            AllocationPool(
+                block.device,
+                TemporalSplit([block]),
+                can_expand=config.memory_pool != "none",
+            )
+        )
+        block.mark_allocated()
+
+    def allocate_output(self, block: Allocation):
+        """Outputs get different pools so memory gets freed properly"""
+        pools = self.get_pools(block)
+        if pools and config.memory_pool in ("outputs", "combined"):
+            pools[-1].allocate_at_end(block)
+        else:
+            # create a new pool
+            block.mark_allocated()
+            pools.append(
+                AllocationPool(
+                    block.device,
+                    TemporalSplit([block]),
+                    can_expand=config.memory_pool == "combined",
+                )
+            )
+
+    def finalize(self):
+        """Called at the end of allocation process"""
+        for i, pool in enumerate(
+            itertools.chain.from_iterable(self.device_to_pools.values())
+        ):
+            pool.finalize(f"pool{i}")
+
+    def pprint(self):
+        for pool in itertools.chain.from_iterable(self.device_to_pools.values()):
+            print()
+            print(pool.name)
+            print(pool.root.get_live_ranges())
+            pprint.pprint(pool.root)
+
+
+class BufferGroup:
+    """
+    Due to inplace reuse an allocated buffer can have many names.
+    This tracks these collections of buffers sharing underlying memory.
+    """
+
+    def __init__(self, node: BufferLike):
+        self.node = node
+        self.names = [node.get_name()]
+        self.is_output = False
+        self.allocation: Optional[Allocation] = None
+        self.live_range = LiveRange(float("inf"), -float("inf"))
+
+    def update_usage(self, timestep: int):
+        """Expand self.live_range to include timestep"""
+        self.live_range = LiveRange(
+            min(timestep, self.live_range.begin),
+            max(timestep, self.live_range.end),
+        )
+
+    def sym_nbytes(self):
+        return self.node.get_layout().storage_size() * self.node.get_dtype().itemsize
+
+    def make_allocation(self):
+        assert not self.allocation, "multiple allocations"
+        assert isinstance(self.live_range.begin, int), "live ranges not computed"
+        nbytes = self.sym_nbytes()
+        # For now, fallback value will be used if we encounter an unbacked SymInt. The longer-term plan is to have
+        # size_hint() use better heuristics for unbackeds, at which point the fallback value will be ignored.
+        size_hint = V.graph.sizevars.size_hint(nbytes, fallback=64)
+        self.allocation = Allocation(
+            self.node,
+            self.live_range,
+            size_hint=size_hint,
+            symbolic_size=nbytes,
+        )
+
+    def __repr__(self):
+        return (
+            f"{self.__class__.__name__}({self.names!r}, is_output={self.is_output}, "
+            f"live_range={self.live_range}"
+        )
+
+
+@dataclasses.dataclass
+class PoolMemoryPlanningLine(MemoryPlanningLine):
+    """Abstract base class for {Alloc,Dealloc}FromPoolLine"""
+
+    group: BufferGroup
+    timestep: Optional[int] = None
+
+    @property
+    def node(self):
+        return self.group.node
+
+
+@dataclasses.dataclass
+class AllocFromPoolLine(PoolMemoryPlanningLine):
+    """Similar to AllocationLine, but takes memory from a pool"""
+
+    is_first_pool_usage: bool = False
+
+    def codegen(self, code: IndentedBuffer):
+        allocation = self.group.allocation
+        assert allocation and allocation.pool
+        pool = allocation.pool
+        name = self.node.get_name()
+
+        if self.is_first_pool_usage:
+            pool.codegen_create(self.wrapper, code)
+
+        pool.names_to_del.extend(self.group.names)
+        alloc_from_pool, allocation_lines_to_write = allocation.codegen_alloc_from_pool(
+            self.wrapper
+        )
+        code.writelines(allocation_lines_to_write)
+        if alloc_from_pool in pool.creation_cache:
+            code.writeline(
+                self.wrapper.make_tensor_alias(
+                    name, pool.creation_cache[alloc_from_pool], "alloc"
+                )
+            )
+        else:
+            pool.creation_cache[alloc_from_pool] = name
+            code.writeline(
+                f"{self.wrapper.declare}{name} = {alloc_from_pool}{self.wrapper.ending}"
+            )
+
+
+@dataclasses.dataclass
+class DeallocFromPoolLine(PoolMemoryPlanningLine):
+    """Similar to FreeIfNotReusedLine, but takes memory from a pool"""
+
+    is_last_pool_usage: bool = False
+
+    def codegen(self, code: IndentedBuffer):
+        if self.is_last_pool_usage:
+            assert self.group.allocation and self.group.allocation.pool
+            self.group.allocation.pool.codegen_destroy(self.wrapper, code)
+
+
+@dataclasses.dataclass
+class MemoryPlanner:
+    """
+    Coordination object to run memory planning passes during wrapper
+    codegen.
+    """
+
+    wrapper: Any
+    pools: AllocationPools = dataclasses.field(default_factory=AllocationPools)
+    buffer_groups: Optional[list[BufferGroup]] = None
+
+    def plan(self, lines: list[Any]) -> list[Any]:
+        """Call all the memory planning passes in sequence"""
+        lines = [*lines]
+        self.drop_removed_buffers(lines)
+        self.convert_to_pool_lines(lines)
+        self.compute_live_ranges(lines)
+        self.allocate_groups()
+        self.mark_first_last_usage(lines)
+        return lines
+
+    def drop_removed_buffers(self, lines):
+        """
+        Replace any memory planning lines in V.graph.removed_buffers with NullLine
+        """
+        # drop any removed buffers
+        for i, line in enumerate(lines):
+            if isinstance(line, (AllocateLine, FreeIfNotReusedLine, ReuseLine)):
+                if line.node.get_name() in V.graph.removed_buffers:
+                    lines[i] = NullLine(self.wrapper)
+
+    def compute_buffer_groups(self, lines):
+        """
+        Populates self.buffer_groups with BufferGroup objects that join
+        allocations with common storage (due to inplace reuse) into a
+        single object.
+        """
+        name_to_group = {}
+        for line in lines:
+            if isinstance(line, AllocateLine):
+                name = line.node.get_name()
+                assert name not in name_to_group
+                name_to_group[name] = BufferGroup(line.node)
+            elif isinstance(line, ReuseLine):
+                old_name = line.node.get_name()
+                new_name = line.reused_as.get_name()
+                assert new_name not in name_to_group
+                # TODO(jansel): we should support reusing buffers created via ExternKernelAlloc
+                if old_name in name_to_group:
+                    name_to_group[old_name].names.append(new_name)
+                    name_to_group[new_name] = name_to_group[old_name]
+
+        outputs = OrderedSet(V.graph.get_output_names())
+        unique_groups = [*{id(g): g for g in name_to_group.values()}.values()]
+        for group in unique_groups:
+            group.is_output = any(x in outputs for x in group.names)
+
+        assert self.buffer_groups is None
+        self.buffer_groups = unique_groups
+        return name_to_group
+
+    def convert_to_pool_lines(self, lines):
+        """
+        Convert AllocateLine/FreeIfNotReusedLine/ReuseLine into their
+        pool-based counterparts.
+        """
+        name_to_group = self.compute_buffer_groups(lines)
+        for i, line in enumerate(lines):
+            if isinstance(line, AllocateLine):
+                if line.node.get_name() in name_to_group:
+                    lines[i] = AllocFromPoolLine(
+                        self.wrapper, name_to_group[line.node.get_name()]
+                    )
+            elif isinstance(line, FreeIfNotReusedLine):
+                assert not line.is_reused
+                if line.node.get_name() in name_to_group:
+                    lines[i] = DeallocFromPoolLine(
+                        self.wrapper, name_to_group[line.node.get_name()]
+                    )
+            elif isinstance(line, ReuseLine):
+                if line.node.get_name() in name_to_group:
+                    line.delete_old = False
+
+    def compute_live_ranges(self, lines):
+        """Populate every BufferGroup.live_ranges field based on first/last usage"""
+        timestep = 0
+        worklist = collections.deque(lines)
+        while worklist:
+            if isinstance(worklist[0], MemoryPlanningLine):
+                timestep += 1
+                while worklist and isinstance(worklist[0], MemoryPlanningLine):
+                    line = worklist.popleft()
+                    if isinstance(line, PoolMemoryPlanningLine):
+                        line.group.update_usage(timestep)
+                        line.timestep = timestep
+            else:
+                worklist.popleft()
+
+        timestep += 1
+        assert self.buffer_groups is not None
+        for group in self.buffer_groups:
+            if group.is_output:
+                group.update_usage(timestep)
+
+    def allocate_groups(self):
+        """
+        Assign every allocation to a specific location in a specific AllocationPool.
+        """
+        assert config.memory_pool in ("none", "intermediates", "outputs", "combined")
+        assert self.buffer_groups is not None
+
+        for group in self.buffer_groups:
+            group.make_allocation()
+
+        outputs: list[Allocation] = []
+        intermediates: list[Allocation] = []
+        for group in self.buffer_groups:
+            assert group.allocation
+            if group.is_output and config.memory_pool != "combined":
+                outputs.append(group.allocation)
+            else:
+                intermediates.append(group.allocation)
+
+        for block in sorted(
+            outputs,
+            key=lambda x: (
+                x.size_hint,
+                -len(x.live_range),
+            ),
+        ):
+            self.pools.allocate_output(block)
+
+        for block in sorted(
+            intermediates,
+            key=lambda x: (
+                -x.size_hint,
+                -len(x.live_range),
+            ),
+        ):
+            self.pools.allocate(block)
+
+        self.pools.finalize()
+
+    def mark_first_last_usage(self, lines):
+        """
+        Populate the AllocFromPoolLine.is_first_pool_usage and
+        DeallocFromPoolLine.is_last_pool_usage fields so that pools
+        are created/destroyed.
+        """
+        seen = OrderedSet[AllocationPool]()
+        for line in lines:
+            if isinstance(line, AllocFromPoolLine):
+                assert line.group.allocation
+                pool = line.group.allocation.pool
+                assert pool is not None
+                if pool not in seen:
+                    line.is_first_pool_usage = True
+                    seen.add(pool)
+
+        seen = OrderedSet[AllocationPool]()
+        for line in reversed(lines):
+            if isinstance(line, DeallocFromPoolLine):
+                assert line.group.allocation
+                pool = line.group.allocation.pool
+                assert pool is not None
+                if pool not in seen:
+                    line.is_last_pool_usage = (
+                        pool.root.get_live_ranges().end <= line.timestep
+                    )
+                    seen.add(pool)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/mps.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/mps.py
new file mode 100644
index 0000000000000000000000000000000000000000..84165fea6e3803e6f4feaa33d8bbb5ae4af6be26
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/mps.py
@@ -0,0 +1,1097 @@
+# This is not a feature-complete compiler backend
+# Just an early prototype that shows that one can compile elementwise ops into a Metal shader
+from __future__ import annotations
+
+import functools
+import itertools
+import logging
+import math
+from pathlib import Path
+from typing import Any, Optional, TYPE_CHECKING
+
+import sympy
+from sympy.printing.precedence import PRECEDENCE
+
+import torch
+from torch.utils._cpp_embed_headers import _embed_headers
+from torch.utils._ordered_set import OrderedSet
+from torch.utils._sympy.printers import CppPrinter, ExprPrinter as ExprPrinter_
+from torch.utils._sympy.value_ranges import ValueRanges
+
+from ..utils import ceildiv, get_bounds_index_expr, get_kernel_metadata
+from ..virtualized import ops, OpsWrapper, V
+from .common import (
+    CSEVariable,
+    DeferredLine,
+    DTYPE_TO_COMPUTATION_DTYPE,
+    IndentedBuffer,
+    OpOverrides,
+    PythonPrinter,
+)
+from .simd import IterationRangesEntry, SIMDKernel, SIMDScheduling
+
+
+if TYPE_CHECKING:
+    from typing import Union
+
+    from ..ops_handler import ReductionType, StoreMode
+    from ..scheduler import Scheduler, SchedulerNode
+    from .common import OpVarT
+
+log = logging.getLogger(__name__)
+
+DTYPE_TO_METAL = {
+    torch.bool: "bool",
+    torch.int8: "char",
+    torch.int16: "short",
+    torch.int32: "int",
+    torch.int64: "long",
+    torch.uint8: "uchar",
+    torch.float: "float",
+    torch.half: "half",
+    torch.bfloat16: "bfloat",
+}
+
+
+def value_to_metal(val: Union[float, int, bool, str, CSEVariable]) -> str:
+    if isinstance(val, float):
+        if val == torch.inf:
+            return "HUGE_VALF"
+        elif val == -torch.inf:
+            return "-HUGE_VALF"
+        elif val != val:  # Only float that not equal to self is nan
+            return "NAN"
+        return str(val)
+    elif isinstance(val, bool):
+        return "true" if val else "false"
+    return str(val)
+
+
+class MetalExprPrinter(ExprPrinter_):
+    """Converts sympy expression to Metal code snippet"""
+
+    def _print_FloorDiv(self, expr: sympy.Expr) -> str:
+        x, div = expr.args
+        x = self.doprint(x)
+        div = self.doprint(div)
+        if expr.is_integer:
+            return f"c10::metal::floor_divide({x}, {div})"
+        return f"metal::floor({x}) / ({div})"
+
+    def _print_ModularIndexing(self, expr: sympy.Expr) -> str:
+        x, div, mod = expr.args
+        x = self.doprint(x)
+        if div != 1:
+            div = self.doprint(div)
+            if expr.is_integer:
+                x = f"({x}) / ({div})"
+            else:
+                x = f"metal::floor({x}) / ({div})"
+        mod = self.doprint(mod)
+        return f"({x}) % ({mod})"
+
+    def _print_Min(self, expr: sympy.Expr) -> str:
+        if len(expr.args) != 2:
+            raise RuntimeError("metal::min only supported for 2 args")
+        a, b = map(self._print, expr.args)
+        typecast_a = f"static_cast({a})"
+        typecast_b = f"static_cast({b})"
+        return f"metal::min({typecast_a}, {typecast_b})"
+
+    def _print_Max(self, expr: sympy.Expr) -> str:
+        if len(expr.args) != 2:
+            raise RuntimeError("metal::max only supported for 2 args")
+        a, b = map(self._print, expr.args)
+        typecast_a = f"static_cast({a})"
+        typecast_b = f"static_cast({b})"
+        return f"metal::max({typecast_a}, {typecast_b})"
+
+    def _print_Abs(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return f"metal::abs({self._print(expr.args[0])})"
+
+    def _print_RoundToInt(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return f"static_cast(metal::rint({self._print(expr.args[0])}))"
+
+    def _print_RoundDecimal(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 2
+        number, ndigits = expr.args
+        if number.is_integer:
+            # ndigits < 0 should have been filtered by the sympy function
+            assert ndigits < 0
+            raise ValueError(
+                f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}."
+            )
+        number_str = self.parenthesize(number, PRECEDENCE["Mul"])
+        return f"static_cast(metal::rint(1e{ndigits} * {number_str}) * 1e{-ndigits})"
+
+    def _print_IntTrueDiv(self, expr: sympy.Expr) -> str:
+        lhs, rhs = expr.args
+        # TODO: This is only accurate up to 2**23
+        return f"static_cast({self._print(lhs)}) / static_cast({self._print(rhs)})"
+
+    def _print_PowByNatural(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 2
+        x, y = map(self.doprint, expr.args)
+        return f"metal::pow(static_cast({x}), static_cast({y}))"
+
+    def _print_ToFloat(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        x = self.doprint(expr.args[0])
+        return f"static_cast({x})"
+
+    def _print_Float(self, expr: sympy.Expr) -> str:
+        if expr.is_integer:
+            # sympy considers 0.0 to be integer, but Metal doesn't.
+            # this workaround prints the float as an integer
+            # xref: https://github.com/sympy/sympy/issues/26620
+            return str(int(expr))
+        else:
+            return str(expr)
+
+    def _print_FloorToInt(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        x = self.doprint(expr.args[0])
+        return f"static_cast(metal::floor(static_cast({x})))"
+
+    _print_floor = _print_FloorToInt
+
+    def _print_TruncToInt(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        x = self.doprint(expr.args[0])
+        return f"static_cast(metal::trunc({x}))"
+
+    def _print_OpaqueUnaryFn_log2(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        x = self.doprint(expr.args[0])
+        return f"metal::log2({x})"
+
+    def _print_Where(self, expr: sympy.Expr) -> str:
+        c, p, q = (
+            self.parenthesize(arg, PRECEDENCE["Atom"] - 0.5) for arg in expr.args
+        )
+        return f"{c} ? {p} : {q}"
+
+
+class MetalOverrides(OpOverrides):
+    """Implements Metal-specific overrides for ops. Base class emits Python-friendly overrides."""
+
+    @staticmethod
+    def to_dtype(
+        x: CSEVariable,
+        dtype: torch.dtype,
+        src_dtype: Optional[torch.dtype] = None,
+        use_compute_types: bool = True,
+    ) -> str:
+        if dtype == torch.double:
+            log.warning(
+                "float64 cast requested, probably from tensorify_python_scalars"
+            )
+            return f"static_cast({x})"
+        return f"static_cast<{DTYPE_TO_METAL[dtype]}>({x})"
+
+    @staticmethod
+    def to_dtype_bitcast(
+        x: CSEVariable, dtype: torch.dtype, src_dtype: torch.dtype
+    ) -> str:
+        return f"as_type<{DTYPE_TO_METAL[dtype]}>(static_cast<{DTYPE_TO_METAL[src_dtype]}>({x}))"
+
+    @staticmethod
+    def constant(val: Union[bool, float, int], dtype: torch.dtype) -> str:
+        return value_to_metal(val)
+
+    @staticmethod
+    def index_expr(expr: sympy.Expr, dtype: torch.dtype) -> str:
+        idx_str = V.kernel.index_to_str(V.kernel.prepare_indexing(expr))
+        var = V.kernel.cse.generate(
+            V.kernel.compute, idx_str, bounds=get_bounds_index_expr(expr)
+        )
+        return ops.to_dtype(var, dtype)
+
+    @staticmethod
+    def masked(mask: CSEVariable, body: sympy.Expr, other: CSEVariable) -> str:
+        # TODO: Type annotation for other is wrong, it's often float or int
+        with V.kernel.mask_loads(mask, other) as new_mask:
+            result = body()
+
+        if result.bounds.is_bool:
+            other = bool(other)  # type: ignore[assignment]
+
+        return ops.where(new_mask, result, other)
+
+    @staticmethod
+    def where(a: OpVarT, b: OpVarT, c: OpVarT) -> str:
+        return f"{a} ? {b} : {value_to_metal(c)}"
+
+    @staticmethod
+    def remainder(a: OpVarT, b: OpVarT) -> str:
+        return f"c10::metal::remainder({a}, {b})"
+
+    @staticmethod
+    def maximum(a: CSEVariable, b: CSEVariable) -> str:
+        typecast_a = f"static_cast({a})"
+        typecast_b = f"static_cast({b})"
+        return f"c10::metal::max({typecast_a}, {typecast_b})"
+
+    @staticmethod
+    def minimum(a: CSEVariable, b: CSEVariable) -> str:
+        typecast_a = f"static_cast({a})"
+        typecast_b = f"static_cast({b})"
+        return f"c10::metal::min({typecast_a}, {typecast_b})"
+
+    @staticmethod
+    def logical_or(a: CSEVariable, b: CSEVariable) -> str:
+        return f"{a} || {b}"
+
+    @staticmethod
+    def logical_and(a: CSEVariable, b: CSEVariable) -> str:
+        return f"{a} && {b}"
+
+    @staticmethod
+    def isnan(x: CSEVariable) -> str:
+        return f"metal::isnan({x})"
+
+    @staticmethod
+    def isinf(x: CSEVariable) -> str:
+        return f"metal::isinf({x})"
+
+    @staticmethod
+    def log(x: CSEVariable) -> str:
+        return f"metal::log({x})"
+
+    @staticmethod
+    def exp(x: CSEVariable) -> str:
+        return f"metal::exp({x})"
+
+    @staticmethod
+    def abs(x: CSEVariable) -> str:
+        return f"metal::abs({x})"
+
+    @staticmethod
+    def signbit(x: CSEVariable) -> str:
+        return f"metal::signbit({x})"
+
+    @staticmethod
+    def sin(x: CSEVariable) -> str:
+        return f"metal::precise::sin({x})"
+
+    @staticmethod
+    def sinc(x: CSEVariable) -> str:
+        return f"c10::metal::sinc({x})"
+
+    @staticmethod
+    def cos(x: CSEVariable) -> str:
+        return f"metal::precise::cos({x})"
+
+    @staticmethod
+    def tan(x: CSEVariable) -> str:
+        return f"metal::tan({x})"
+
+    @staticmethod
+    def asin(x: CSEVariable) -> str:
+        return f"metal::asin({x})"
+
+    @staticmethod
+    def acos(x: CSEVariable) -> str:
+        return f"metal::acos({x})"
+
+    @staticmethod
+    def atan(x: CSEVariable) -> str:
+        return f"metal::atan({x})"
+
+    @staticmethod
+    def atan2(x: CSEVariable, y: CSEVariable) -> str:
+        return f"::metal::atan2({x}, {y})"
+
+    @staticmethod
+    def sqrt(x: CSEVariable) -> str:
+        return f"metal::sqrt({x})"
+
+    @staticmethod
+    def neg(x: CSEVariable) -> str:
+        # TODO: Does it rely on undefined behavior?
+        # If so, add special logic for unsigned types
+        return f"static_cast(-{x})"
+
+    @staticmethod
+    def rsqrt(x: CSEVariable) -> str:
+        return f"metal::rsqrt({x})"
+
+    @staticmethod
+    def tanh(x: CSEVariable) -> str:
+        return f"metal::tanh({x})"
+
+    @staticmethod
+    def atanh(x: CSEVariable) -> str:
+        return f"metal::atanh({x})"
+
+    @staticmethod
+    def floordiv(a: CSEVariable, b: CSEVariable) -> str:
+        # a and b must be of integer type
+        return f"c10::metal::floor_divide({a}, {b})"
+
+    @staticmethod
+    def floor(x: CSEVariable) -> str:
+        return f"metal::floor({x})"
+
+    @staticmethod
+    def sign(x: CSEVariable) -> str:
+        return f"metal::sign({x})"
+
+    @staticmethod
+    def fmod(a: CSEVariable, b: CSEVariable) -> str:
+        typecast_a = f"static_cast({a})"
+        typecast_b = f"static_cast({b})"
+        return f"metal::fmod({typecast_a}, {typecast_b})"
+
+    @staticmethod
+    def trunc(x: CSEVariable) -> str:
+        return f"metal::trunc({x})"
+
+    @staticmethod
+    def truncdiv(a: CSEVariable, b: CSEVariable) -> str:
+        quot = f"{a} / {b}"
+        if (a.dtype is not None and a.dtype.is_floating_point) or (
+            b.dtype is not None and b.dtype.is_floating_point
+        ):
+            return f"metal::trunc({quot})"
+        return quot
+
+    @staticmethod
+    def ceil(x: CSEVariable) -> str:
+        return f"metal::ceil({x})"
+
+    @staticmethod
+    def rand(seed: CSEVariable, offset: CSEVariable) -> str:
+        V.kernel.headers.add("random")
+        return f"c10::metal::rand({seed}, {offset})"
+
+    @staticmethod
+    def randn(seed: CSEVariable, offset: CSEVariable) -> str:
+        V.kernel.headers.add("random")
+        return f"c10::metal::randn({seed}, {offset})"
+
+    @staticmethod
+    def randint64(
+        seed: CSEVariable, offset: CSEVariable, low: CSEVariable, high: CSEVariable
+    ) -> str:
+        V.kernel.headers.add("random")
+        return f"c10::metal::randint64({seed}, {offset}, {low}, {high})"
+
+    @staticmethod
+    def round(x: CSEVariable) -> str:
+        return f"metal::rint({x})"
+
+    @staticmethod
+    def pow(a: CSEVariable, b: CSEVariable) -> str:
+        cast_a = f"static_cast({a})"
+        cast_b = f"static_cast({b})"
+        return f"metal::pow({cast_a}, {cast_b})"
+
+    def _special_unary(self, a: CSEVariable, name: str) -> str:
+        V.kernel.headers.add("special_math")
+        return f"c10::metal::{name}({a})"
+
+    def _special_binary(self, a: CSEVariable, b: CSEVariable, name: str) -> str:
+        V.kernel.headers.add("special_math")
+        return f"c10::metal::{name}({a}, {b})"
+
+    @classmethod
+    def _initialize_special_ops(cls) -> None:
+        # Unary special ops
+        for name in [
+            "erf",
+            "erfinv",
+            "i0",
+            "i0e",
+            "i1",
+            "i1e",
+            "digamma",
+            "spherical_bessel_j0",
+        ]:
+            setattr(cls, name, functools.partialmethod(cls._special_unary, name=name))
+
+        cls.lgamma = functools.partialmethod(cls._special_unary, name="log_gamma")  # type: ignore[assignment]
+
+        # Unary special ops with forward in method name
+        for name in [
+            "bessel_j0",
+            "bessel_j1",
+            "bessel_y0",
+            "bessel_y1",
+            "modified_bessel_i0",
+            "modified_bessel_i1",
+            "modified_bessel_k0",
+            "modified_bessel_k1",
+            "scaled_modified_bessel_k0",
+            "scaled_modified_bessel_k1",
+        ]:
+            setattr(
+                cls,
+                name,
+                functools.partialmethod(cls._special_unary, name=name + "_forward"),
+            )
+
+        # Binary special ops
+        for name in [
+            "polygamma",
+            "igamma",
+            "igammac",
+            "zeta",
+        ]:
+            setattr(cls, name, functools.partialmethod(cls._special_binary, name=name))
+
+        # Binary special ops with forward in method name
+        for name in [
+            "chebyshev_polynomial_t",
+            "chebyshev_polynomial_u",
+            "chebyshev_polynomial_v",
+            "chebyshev_polynomial_w",
+            "hermite_polynomial_h",
+            "hermite_polynomial_he",
+            "shifted_chebyshev_polynomial_t",
+            "shifted_chebyshev_polynomial_u",
+            "shifted_chebyshev_polynomial_v",
+            "shifted_chebyshev_polynomial_w",
+        ]:
+            setattr(
+                cls,
+                name,
+                functools.partialmethod(cls._special_binary, name=name + "_forward"),
+            )
+
+
+MetalOverrides._initialize_pointwise_overrides("mps")
+MetalOverrides._initialize_special_ops()
+
+
+class MetalKernel(SIMDKernel):
+    """Implement Metal codegen based on the SIMDKernel abstraction"""
+
+    overrides = MetalOverrides  # type: ignore[assignment]
+    suffix = ";"
+    newvar_prefix = "auto "
+    max_threadgroup_size = 1024
+    simd_group_size = 32
+    pexpr = PythonPrinter().doprint
+    cexpr = CppPrinter().doprint
+    sexpr = MetalExprPrinter().doprint
+    kexpr = sexpr
+    headers: OrderedSet[str] = OrderedSet(["utils"])
+    multistage_reduction_entry: list[IterationRangesEntry] = []
+
+    def __init__(
+        self,
+        tiling: dict[str, sympy.Expr],
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(tiling, **kwargs)
+        self.acc_var_ids = itertools.count()
+
+    def dtype_to_str(self, dtype: torch.dtype) -> str:
+        return DTYPE_TO_METAL[dtype]
+
+    def load(self, name: str, index: sympy.Expr) -> CSEVariable:
+        """Codegen a load from an InputBuffer"""
+        var = self.args.input(name)
+        index = self.prepare_indexing(index)
+        dtype = V.graph.get_dtype(name)
+        line = f"{var}[{self.index_to_str(index)}]"
+        if dtype in [torch.float16, torch.bfloat16]:
+            # TODO(NS): Figure out the right balance between optype casts
+            # op_math_t for half-precision floats should be float32
+            # Otherwise it can lead to a correctness issues with eager
+            line = f"static_cast({line})"
+            dtype = torch.float32
+        return self.cse.generate(self.loads, line, dtype=dtype)
+
+    def store(
+        self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
+    ) -> None:
+        var = self.args.output(name)
+        index = self.prepare_indexing(index)
+        dtype_str = self.dtype_to_str(V.graph.get_dtype(name))
+        cast_val = f"static_cast<{dtype_str}>({value})"
+        if mode is None:
+            line = f"{var}[{self.index_to_str(index)}] = {cast_val};"
+        elif mode == "atomic_add":
+            self.headers.add("atomic")
+            atomic_type = f"c10::metal::AtomicType<{dtype_str}>"
+            cast_var = f"reinterpret_cast({var})"
+            line = f"{atomic_type}::atomic_add({cast_var}, {self.index_to_str(index)}, {cast_val});"
+        else:
+            raise RuntimeError(f"Unimplemented store mode {mode}")
+        if self.inside_reduction:
+            self.compute.writeline(DeferredLine(name, line))
+        else:
+            self.stores.writeline(DeferredLine(name, line))
+
+    def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable) -> None:
+        var = self.args.output(name)
+        index = self.prepare_indexing(index)
+        dtype_str = self.dtype_to_str(V.graph.get_dtype(name))
+        # pyrefly: ignore [missing-argument]
+        reduction_dim = next(t for t in self.range_trees if t.is_reduction)
+        # Only one thread in the reduction group needs to store the results
+        line = f"{var}[{self.index_to_str(index)}] = static_cast<{dtype_str}>({value});"
+        line = f"if ({reduction_dim.name} == 0) {line}"
+        self.stores.writeline(DeferredLine(name, line))
+
+    def _new_idxvar(
+        self,
+        dtype: Union[str | torch.dtype],
+        elem_count: Optional[int] = None,
+        default_value: Optional[Any] = None,
+        is_threadgroup: bool = True,
+        bounds: ValueRanges[Any] = ValueRanges.unknown(),
+    ) -> CSEVariable:
+        if isinstance(dtype, torch.dtype):
+            dtype = self.dtype_to_str(dtype)
+        var_name = f"tmp_acc_{next(self.acc_var_ids)}"
+        var = V.kernel.create_cse_var(var_name, bounds, dtype)
+        var_def = "threadgroup " if is_threadgroup else ""
+        var_def += f"{dtype} {var_name}"
+        if elem_count:
+            var_def += f"[{self.sexpr(elem_count)}]"
+        if default_value is not None:
+            assert not is_threadgroup, "Thread group var can not have default value"
+            var_def += f" = {default_value}"
+        self.indexing_code.writeline(var_def + self.suffix)
+        return var
+
+    def reduction(
+        self,
+        dtype: torch.dtype,
+        src_dtype: torch.dtype,
+        reduction_type: ReductionType,
+        value: Union[CSEVariable, tuple[CSEVariable, ...]],
+    ) -> Union[CSEVariable, tuple[CSEVariable, ...]]:
+        "Caching wrapper around _reduction_nocache"
+        cache_key = (src_dtype, reduction_type, value)
+        # Return cached reduction
+        if cache_key in self.cse.reduction_cache:
+            return self.cse.reduction_cache[cache_key]
+        result = self._reduction_nocache(dtype, src_dtype, reduction_type, value)
+        self.cse.reduction_cache[cache_key] = result  # type: ignore[assignment]
+        return result
+
+    def _reduction_nocache(
+        self,
+        dtype: torch.dtype,
+        src_dtype: torch.dtype,
+        reduction_type: ReductionType,
+        value: Union[CSEVariable, tuple[CSEVariable, ...]],
+    ) -> Union[CSEVariable, tuple[CSEVariable, ...]]:
+        """Codegen a reduction operation.
+        Only sum and prod operations are somewhat reasonable optimized"""
+        assert self.inside_reduction
+        assert not self._load_mask
+
+        def _unwrap_helper(res3: CSEVariable) -> tuple[CSEVariable, ...]:
+            # Uwraps vec3 dtype into individual components
+            return OpsWrapper._unwrap(
+                [CSEVariable(f"{res3}.{t}", res3.bounds, res3.dtype) for t in "xyz"]
+            )
+
+        # Establish reduction buffer size and index expression
+        reduction_idx = ""
+        acc_buf_size = 1
+        for rd in self.range_trees:
+            # pyrefly: ignore [missing-argument]
+            if not rd.is_reduction:
+                continue
+            if reduction_idx:
+                reduction_idx += " + "
+            reduction_idx += f"{rd.name} * {acc_buf_size}"
+
+            if isinstance(rd.numel, sympy.Integer):
+                acc_buf_size *= rd.numel
+            else:
+                acc_buf_size *= sympy.Symbol(
+                    f"{rd.prefix}numel", integer=True, positive=True
+                )
+
+        acc_buf_size = sympy.Min(acc_buf_size, self.max_threadgroup_size)
+        acc_buf_size_str = self.sexpr(acc_buf_size)
+        shmem_buf_size = (
+            ceildiv(acc_buf_size, self.simd_group_size)
+            if isinstance(acc_buf_size, sympy.Integer)
+            else self.simd_group_size
+        )
+
+        if reduction_type == "any":
+            acc = self._new_idxvar(dtype)
+            self.indexing_code.writeline(f"{acc} = false;")
+            self.indexing_code.writeline(
+                "threadgroup_barrier(metal::mem_flags::mem_threadgroup);"
+            )
+            self.compute.splice(
+                f"""
+                if ({value}) {{
+                    {acc} = true;
+                }}
+            """
+            )
+            self.stores.writeline(
+                "threadgroup_barrier(metal::mem_flags::mem_threadgroup);"
+            )
+            return acc
+
+        self.headers.add("reduction_utils")
+
+        if reduction_type in ["prod", "sum"]:
+            acc_dtype = DTYPE_TO_COMPUTATION_DTYPE[src_dtype]
+            acc_buf = self._new_idxvar(acc_dtype, shmem_buf_size)
+            if not self.multistage_reduction_entry:
+                val = value
+            else:
+                default_val, reduction_op = (
+                    (0, "+") if reduction_type == "sum" else (1, "*")
+                )
+                val = self._new_idxvar(
+                    acc_dtype, default_value=default_val, is_threadgroup=False
+                )
+                self.compute.splice(f"{val} {reduction_op}= {value};")
+
+            return self.cse.generate(
+                self.stores,
+                f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {val}, {reduction_idx}, {acc_buf_size_str})",
+                dtype=DTYPE_TO_COMPUTATION_DTYPE[dtype],
+            )
+        if reduction_type in ["max", "min"]:
+            acc_buf = self._new_idxvar(src_dtype, shmem_buf_size)
+            src_metal_type = DTYPE_TO_METAL[src_dtype]
+            cast_value = f"static_cast<{src_metal_type}>({value})"
+            if not self.multistage_reduction_entry:
+                val = cast_value  # type: ignore[assignment]
+            else:
+                lim_fn = "lowest" if reduction_type.endswith("max") else "max"
+                limit_val = f"::metal::numeric_limits<{src_metal_type}>::{lim_fn}()"
+                val = self._new_idxvar(
+                    src_dtype, default_value=limit_val, is_threadgroup=False
+                )
+                self.compute.splice(
+                    f"{val} = ::c10::metal::{reduction_type}({val}, {cast_value});"
+                )
+            return self.cse.generate(
+                self.stores,
+                f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {val}, {reduction_idx}, {acc_buf_size_str})",
+                dtype=DTYPE_TO_COMPUTATION_DTYPE[dtype],
+            )
+        if reduction_type in ["argmin", "argmax"]:
+            data_acc_buf = self._new_idxvar(src_dtype, shmem_buf_size)
+            idx_acc_buf = self._new_idxvar(dtype, shmem_buf_size)
+            src_metal_type = DTYPE_TO_METAL[src_dtype]
+            cast_value = f"static_cast<{src_metal_type}>({value})"
+            if not self.multistage_reduction_entry:
+                val = cast_value  # type: ignore[assignment]
+                idx_val = f"static_cast<{DTYPE_TO_METAL[dtype]}>({reduction_idx})"
+            else:
+                lim_fn = "lowest" if reduction_type.endswith("max") else "max"
+                limit_val = f"::metal::numeric_limits<{src_metal_type}>::{lim_fn}()"
+                val = self._new_idxvar(
+                    src_dtype, default_value=limit_val, is_threadgroup=False
+                )
+                idx_val = self._new_idxvar(dtype, default_value=0, is_threadgroup=False)  # type: ignore[assignment]
+                idx_var = next(
+                    t
+                    for t in self.range_tree_nodes.values()
+                    # pyrefly: ignore [missing-argument]
+                    if t.is_reduction
+                )
+                cmp_op = ">" if reduction_type == "argmax" else "<"
+                nan_suffix = (
+                    f" || ::metal::isnan({value}) "
+                    if src_dtype.is_floating_point
+                    else ""
+                )
+                self.compute.splice(f"""
+                if ({value} {cmp_op} {val}{nan_suffix}) {{
+                    {val} = {value};
+                    {idx_val} = {idx_var.name};
+                }}
+                """)
+            return self.cse.generate(
+                self.stores,
+                f"c10::metal::threadgroup_{reduction_type}({data_acc_buf}, {idx_acc_buf}, "
+                f"{val}, {idx_val}, {reduction_idx}, {acc_buf_size_str})",
+                dtype=dtype,
+            )
+        if reduction_type == "welford_reduce":
+            if not self.multistage_reduction_entry:
+                acc_buf = self._new_idxvar(src_dtype, acc_buf_size)
+                self.compute.splice(f"{acc_buf}[{reduction_idx}] = {value};")
+                wf_res = self.cse.generate(
+                    self.compute,
+                    f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size_str})",
+                    dtype=torch.float32,
+                )
+                return _unwrap_helper(wf_res)
+            acc_buf = self._new_idxvar("float3", acc_buf_size)
+            acc_thread_var = f"{acc_buf}[{reduction_idx}]"
+            self.indexing_code.splice(f"{acc_thread_var} = 0.0;")
+            self.compute.writeline(
+                f"{acc_thread_var} = ::c10::metal::welford_combine({acc_thread_var}, float3({value}, 0.0, 1.0));"
+            )
+            wf_res = self.cse.generate(
+                self.stores,
+                f"c10::metal::threadgroup_welford_combine({acc_buf}, {acc_buf_size})",
+                dtype=torch.float32,
+            )
+            return _unwrap_helper(wf_res)
+        if reduction_type == "welford_combine":
+            assert isinstance(value, tuple), "Input to welford combine must be tuple"
+            acc_buf = self._new_idxvar("float3", acc_buf_size)
+            acc_thread_var = f"{acc_buf}[{reduction_idx}]"
+            inp_value = f"float3({value[0]}, {value[1]}, {value[2]})"
+            self.indexing_code.splice(f"{acc_thread_var} = 0.0;")
+            if self.multistage_reduction_entry:
+                self.indexing_code.splice(f"{acc_thread_var} = 0.0;")
+                self.compute.writeline(
+                    f"{acc_thread_var} = ::c10::metal::welford_combine({acc_thread_var}, {inp_value});"
+                )
+            else:
+                self.compute.writeline(f"{acc_thread_var} = {inp_value};")
+            wf_res = self.cse.generate(
+                self.stores if self.multistage_reduction_entry else self.compute,
+                f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size_str})",
+                dtype=torch.float32,
+            )
+            return _unwrap_helper(wf_res)
+        raise NotImplementedError(reduction_type)
+
+    def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry) -> None:
+        index_expr = self.rename_indexing(entry.expr)
+        index_str = self.sexpr(index_expr)  # type: ignore[misc]
+
+        # pyrefly: ignore [missing-argument]
+        if not entry.is_reduction or (
+            isinstance(entry.root.numel, sympy.Integer)
+            and entry.root.numel <= self.max_threadgroup_size
+        ):
+            self.indexing_code.writeline(
+                f"{self.index_dtype} {entry.name} = {index_str};"
+            )
+            return
+
+        acc_size = (
+            entry.root.numel
+            if isinstance(entry.root.numel, sympy.Integer)
+            else sympy.Symbol(f"{entry.root.prefix}numel", integer=True, positive=True)
+        )
+
+        self.multistage_reduction_entry.append(entry)
+        # When reducing the tensor whose size exceeds max threadgroup size
+        # loop over extra indices per reduction thread and perform part of the operation
+        # using values in the shared memory
+
+        # Use floats so that it doesn't do integer division
+        loop_size = (acc_size + float(self.max_threadgroup_size - 1)) // float(
+            self.max_threadgroup_size
+        )
+        loop_size_str = self.sexpr(loop_size)
+
+        self.body.writeline(
+            f"for(auto {entry.name}_cnt = 0; {entry.name}_cnt < {loop_size_str}; ++{entry.name}_cnt) {{"
+        )
+        with self.body.indent():
+            if isinstance(acc_size, sympy.Symbol):
+                self.body.writeline(
+                    f"{self.index_dtype} {entry.name} = {self.max_threadgroup_size} * {entry.name}_cnt + {index_str};"
+                )
+            else:
+                self.body.writeline(
+                    f"{self.index_dtype} {entry.name} = {loop_size_str} * {index_str} + {entry.name}_cnt;"
+                )
+
+            # Check that reduction is performed only within tensor boundary
+            if (
+                isinstance(acc_size, sympy.Symbol)
+                or loop_size * self.max_threadgroup_size != acc_size
+            ):
+                self.body.writeline(f"if ({entry.name} >= {acc_size}) break;")
+
+    def codegen_body(self) -> None:
+        """
+        Concat output code from index_code, loads, compute, stores,
+        suffix into self.body.
+
+        For pointwise kernels, this is called just once at the end.
+
+        For reduction kernels, this generates a loop over the reduction
+        axis.
+        """
+        if self.multistage_reduction_entry:
+            with self.body.indent():
+                self.body.splice(self.loads)
+                self.body.splice(self.compute)
+            self.body.writeline("}" * len(self.multistage_reduction_entry))
+            # Invalidate variables instantiated inside loop
+            # But results of reduction alive. Reduction cache values can be
+            # either CSEVariable or tuple of CSEVariables, in which case all
+            # variables in the tuple must be preserved
+            self.cse.invalidate(
+                OrderedSet(
+                    v
+                    for item in self.cse.reduction_cache.values()
+                    for v in (item if isinstance(item, tuple) else (item,))
+                )
+            )
+            # And loop codegen
+            while self.multistage_reduction_entry:
+                self.multistage_reduction_entry.pop().cache_clear()
+        else:
+            self.body.splice(self.loads)
+            self.body.splice(self.compute)
+        self.body.splice(self.stores)
+        self.loads.clear()
+        self.compute.clear()
+        self.stores.clear()
+
+    def codegen_kernel(self, name: Optional[str] = None) -> str:
+        """Called at the end to generate a final kernel string"""
+        self.codegen_body()
+        code = IndentedBuffer()
+
+        if V.graph.cpp_wrapper:
+            code.writeline('(R"MTL(')
+        else:
+            code.writeline("compile_mps_shader('''")
+
+        idx_vars = self.active_range_trees()
+        with code.indent():
+            if not V.graph.cpp_wrapper:
+                for header in self.headers:
+                    code.writeline(f"#include ")
+            else:
+                headers = [
+                    f"#include " for header in self.headers
+                ]
+                header_contents = _embed_headers(
+                    headers,
+                    [Path(__file__).parent.parent.parent / "include"],
+                    OrderedSet(),  # type: ignore[arg-type]
+                )
+                code.writeline(header_contents)
+
+            if self.inside_reduction:
+                total_reduction_size = math.prod(
+                    t.numel
+                    for t in self.range_trees
+                    # pyrefly: ignore [missing-argument]
+                    if t.is_reduction
+                )
+                # If using dynamic shapes, set the threadgroup size to be the
+                # max possible size
+                threadgroup_size = (
+                    min(total_reduction_size, self.max_threadgroup_size)
+                    if isinstance(total_reduction_size, sympy.Integer)
+                    else self.max_threadgroup_size
+                )
+                code.writeline(
+                    f"[[max_total_threads_per_threadgroup({threadgroup_size})]]"
+                )
+            code.writeline("kernel void generated_kernel(")
+            with code.indent():
+                for outer, inner in self.args.output_buffers.items():
+                    if outer in self.removed_buffers:
+                        continue
+                    dtype_str = self.dtype_to_str(V.graph.get_dtype(outer))
+                    code.writeline(f"device {dtype_str}* {inner},")
+                for outer, inner in self.args.input_buffers.items():
+                    dtype = V.graph.get_dtype(outer)
+                    # MPS does not support float64, but scalar inputs are fine
+                    if dtype == torch.float64:
+                        outer_buf = V.graph.try_get_buffer(outer)
+                        if outer_buf is None or outer_buf.get_size() != []:
+                            raise RuntimeError("float64 is not supported by MPS")
+                        dtype_str = "float"
+                    else:
+                        dtype_str = self.dtype_to_str(dtype)
+                    code.writeline(f"constant {dtype_str}* {inner},")
+                for inner in self.args.sizevars.values():
+                    code.writeline(f"constant long& {inner},")
+
+                # Write dynamic values as inputs
+                for idx_var in idx_vars:
+                    if isinstance(idx_var.numel, sympy.Integer):
+                        pass
+                    else:
+                        code.writeline(f"constant long& {idx_var.prefix}numel,")
+
+                assert len(idx_vars) < 4, "Up to 3 index variables are supported"
+                thread_pos_dtype = (
+                    f"uint{len(idx_vars)}" if len(idx_vars) > 1 else "uint"
+                )
+                thread_pos_var_name = (
+                    idx_vars[0].name if len(idx_vars) == 1 else "thread_pos"
+                )
+                thread_pos_suffix = "," if self.inside_reduction else ""
+                code.writeline(
+                    f"{thread_pos_dtype} {thread_pos_var_name} [[thread_position_in_grid]]{thread_pos_suffix}"
+                )
+                if self.inside_reduction:
+                    code.writeline(
+                        f"{thread_pos_dtype} group_pos [[thread_position_in_threadgroup]]"
+                    )
+            code.writeline(") {")
+            with code.indent():
+                if len(idx_vars) > 1:
+                    for idx, var in enumerate(idx_vars):
+                        code.writeline(
+                            f"auto {var.name} = thread_pos.{chr(120 + idx)};"
+                        )
+                code.splice(self.indexing_code)
+                code.splice(self.body)
+            code.writeline("}")
+
+        if V.graph.cpp_wrapper:
+            code.writeline(')MTL");')
+        else:
+            code.writeline("''')")
+
+        return code.getvalue()
+
+    def call_kernel(
+        self, name: str, node: Any = None, deallocate_ws: bool = True
+    ) -> None:
+        """
+        Codegens a call to this kernel
+        """
+        wrapper = V.graph.wrapper_code
+        # Make sure sizevars has been computed
+        for v in self.args.sizevars:
+            wrapper.ensure_size_computed(v)
+
+        _, call_args, _, arg_types = self.args.python_argdefs()
+        arg_name_to_type = {
+            str(call_arg): arg_type for call_arg, arg_type in zip(call_args, arg_types)
+        }
+
+        args = [*self.args.output_buffers.keys(), *self.args.input_buffers.keys()]
+        args = [arg for arg in args if arg not in self.removed_buffers]
+        args += [str(v) for v in self.args.sizevars]
+        arg_types = [arg_name_to_type[arg] for arg in args]
+
+        # Add any dynamic ints as inputs
+        for tree in self.range_trees:
+            if isinstance(tree.numel, (sympy.Integer, int)):
+                # Don't need to pass in integers as inputs
+                continue
+            elif isinstance(tree.numel, sympy.Symbol):
+                expr = tree.numel
+            else:
+                expr = V.graph.wrapper_code.generate_numel_expr(name, tree).inner
+
+            # pyrefly: ignore [missing-argument]
+            if not tree.is_reduction or self.inside_reduction:
+                args.append(str(expr))
+                arg_types.append(int)
+
+        expr_printer = self.cexpr if V.graph.cpp_wrapper else self.pexpr
+
+        def format_threads(threads: list[str], kwarg: str) -> str:
+            if V.graph.cpp_wrapper:
+                threads = [f"static_cast({t})" for t in threads]
+                return f"{{{', '.join(threads)}}}"
+            else:
+                return f"{kwarg}=[{', '.join(threads)}]"
+
+        # For reduction kernels, limit the maximum size over reduction dimensions to
+        # a maximum threadgroup size
+        if len(self.active_range_trees()) > 0:
+            threads = [
+                expr_printer(
+                    sympy.Min(v.numel, self.max_threadgroup_size)  # type: ignore[misc]
+                    # pyrefly: ignore [missing-argument]
+                    if v.is_reduction
+                    else v.numel
+                )
+                for v in self.active_range_trees()
+            ]
+
+            args.append(format_threads(threads, "threads"))
+            arg_types.append(list)
+        else:
+            if V.graph.cpp_wrapper:
+                raise RuntimeError("We should always have threads?")
+
+        if self.inside_reduction:
+            threads = [
+                expr_printer(sympy.Min(v.numel, self.max_threadgroup_size))  # type: ignore[misc]
+                # pyrefly: ignore [missing-argument]
+                if v.is_reduction
+                else "1"
+                for v in self.active_range_trees()
+            ]
+            args.append(format_threads(threads, "group_size"))
+            arg_types.append(list)
+        else:
+            if V.graph.cpp_wrapper:
+                # Add a None so that we always have a group_size in the
+                # arguments. We won't use it if the value is None.
+                args += [None]  # type: ignore[list-item]
+                arg_types.append(None)
+
+        wrapper.generate_kernel_call(
+            name,
+            args,
+            device=torch.device("mps"),
+            triton=False,
+            arg_types=arg_types,
+        )
+
+    def check_bounds(
+        self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
+    ) -> None:
+        if not (lower or upper):
+            return
+        # TODO(malfet): support asserts
+        # See https://github.com/pytorch/pytorch/issues/144634
+        expr_str = self.index_to_str(expr)
+        lower_expr = f"{expr_str} < 0" if lower else ""
+        # TODO(malfet): Is upper bound inclusive or exclusive?
+        upper_expr = f"{expr_str} > {self.index_to_str(size)}" if upper else ""
+        if lower and upper:
+            line = f"if (({lower_expr}) && ({upper_expr})) return"
+        else:
+            line = f"if ({lower_expr}{upper_expr}) return"
+        self.cse.generate(self.compute, line, assignment=False)
+
+
+class MetalScheduling(SIMDScheduling):
+    kernel_type = MetalKernel  # type: ignore[assignment]
+
+    def __init__(self, scheduler: Optional[Scheduler]) -> None:
+        super().__init__(scheduler)
+        wrapper = V.graph.wrapper_code
+        if wrapper is not None:
+            if not V.graph.cpp_wrapper:
+                wrapper.header.splice(
+                    "from torch._inductor.runtime.runtime_utils import compile_mps_shader"
+                )
+
+    def define_kernel(
+        self, src_code: str, node_schedule: list[SchedulerNode], kernel: MetalKernel
+    ) -> str:
+        wrapper = V.graph.wrapper_code
+        if src_code in wrapper.src_to_kernel:
+            kernel_name = wrapper.src_to_kernel[src_code]
+        else:
+            # TODO: Merge multiple kernels into a single library
+            # Either using MultiKernel concept or overriding SIMDScheduling.codegen_node_scheduling
+            mps_lib_name = f"mps_lib_{wrapper.next_kernel_suffix()}"
+
+            kernel_name = f"{mps_lib_name}"
+            wrapper.src_to_kernel[src_code] = kernel_name
+
+            if V.graph.cpp_wrapper:
+                # For shimified version, generate source constant instead of direct instantiation
+                src_code = f"const char* {mps_lib_name}_source = " + src_code
+
+            origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper)
+            metadata_comment = f"{origins}\n{detailed_origins}"
+            wrapper.define_kernel(mps_lib_name, src_code, metadata_comment, gpu=False)
+
+        return kernel_name
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/mps_device_op_overrides.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/mps_device_op_overrides.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b4ddb163ef4f9957e1a64a4ab25ec865e8206b5
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/mps_device_op_overrides.py
@@ -0,0 +1,24 @@
+from __future__ import annotations
+
+from .common import DeviceOpOverrides, register_device_op_overrides
+
+
+class MPSDeviceOpOverrides(DeviceOpOverrides):
+    def device_guard(self, device_idx: int) -> str:
+        assert device_idx == 0
+        return "torch._ops.contextlib.nullcontext()"
+
+    def set_device(self, device_idx: int) -> str:
+        assert device_idx == 0
+        return "pass  # MPS set device"
+
+    def kernel_driver(self) -> str:
+        return """
+            #include 
+        """
+
+    def cpp_kernel_type(self) -> str:
+        return "MTLFunction_t"
+
+
+register_device_op_overrides("mps", MPSDeviceOpOverrides())
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/multi_kernel.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/multi_kernel.py
new file mode 100644
index 0000000000000000000000000000000000000000..094164a1f08ca8db973047a90d182fb943795b88
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/multi_kernel.py
@@ -0,0 +1,612 @@
+# mypy: allow-untyped-defs
+import functools
+import logging
+import math
+import os
+import pathlib
+from typing import Any, Optional, Union
+
+from torch._inductor.ir import MultiTemplateBuffer
+from torch._inductor.metrics import get_metric_table, is_metric_table_enabled
+from torch.utils._ordered_set import OrderedSet
+
+from .. import config
+from ..codecache import code_hash, CodeCacheFuture, get_path, write_atomic
+from ..runtime.benchmarking import benchmarker
+from ..utils import cache_on_self, IndentedBuffer
+from ..virtualized import V
+from .common import TensorArg, WorkspaceArg
+
+
+log = logging.getLogger(__name__)
+
+
+class MultiKernelState:
+    """
+    Maintain state of multi-kernel compilation so we don't define duplicated
+    multi-kernel for the same set of sub-kernels.
+
+    V.graph.wrapper_code has a reference to MultiKernelState instance.
+    """
+
+    def __init__(self):
+        self.subkernel_to_kernel_name = {}
+        self.kernel_defs = IndentedBuffer()
+
+    def define_kernel(
+        self,
+        kernels: list[Any],
+        kernel_shape_keys: Optional[
+            list[Union[None, tuple[tuple[int, ...], ...]]]
+        ] = None,
+    ) -> str:
+        """
+        Previously we name the multi kernel as "multi_kernel_{kernel_names[0]}".
+        This has some minor issue.
+
+        E.g. for persistent reduction https://gist.github.com/shunting314/39e7c00ff8bb2055942ed5a3255d61ca ,
+        there are 2 flavors of non-persistent reduction:
+          https://gist.github.com/shunting314/056d43d35907e87efb883970b35c17d4
+        and
+          https://gist.github.com/shunting314/02ee753b65c513c54e695626afe682bd
+
+        The only different is cache eviction policy.
+
+        We should name the multi-kernel differently in these 2 cases.
+
+        kernels:
+            A list of kernels
+        kernel_shape_keys:
+            Specified for size-hint multi-kernels.
+            Each list element is a shape key, corresponding to the concrete input & output size hints each kernel was tuned for.
+        """
+        # Prevent circular import
+        from ..select_algorithm import TritonTemplateKernel
+
+        kernel_names = tuple(k.kernel_name for k in kernels)
+        if kernel_names in self.subkernel_to_kernel_name:
+            return self.subkernel_to_kernel_name[kernel_names]
+
+        # name the multi kernel based on the first kernel
+        multi_kernel_name = f"multi_kernel_{len(self.subkernel_to_kernel_name)}"
+        self.subkernel_to_kernel_name[kernel_names] = multi_kernel_name
+
+        if V.graph.cpp_wrapper and not config.triton.autotune_at_compile_time:
+            # we should not generate any python code for multi-kernel during
+            # the second pass of cpp-wrapper.
+            return multi_kernel_name
+
+        arg_index: dict[int, list[slice]] = {}
+        _, call_args, _, arg_types = kernels[0].args.python_argdefs()
+        if isinstance(kernels[0], TritonTemplateKernel) and isinstance(
+            kernels[0].output_node, MultiTemplateBuffer
+        ):
+            for i, kernel in enumerate(kernels):
+                additional_call_args, _ = kernel.additional_call_args_and_types()
+                if i not in arg_index:
+                    arg_index[i] = []
+                arg_index[i].append(slice(0, len(call_args)))
+                arg_index[i].append(
+                    slice(
+                        len(call_args) + i * len(additional_call_args),
+                        len(call_args) + (i + 1) * len(additional_call_args),
+                    )
+                )
+        else:
+            kernels[0].add_numel_to_call_args(multi_kernel_name, call_args, arg_types)
+            for i in range(len(kernels)):
+                arg_index[i] = [slice(0, len(call_args))]
+
+        keyed_by_sizes = kernel_shape_keys is not None
+        buf = self.kernel_defs
+        buf.writeline("")
+        buf.writeline("arg_index = {")
+        for key, slice_list in arg_index.items():
+            slice_reprs = ", ".join(repr(s) for s in slice_list)
+            buf.writeline(f"    {key}: [{slice_reprs}],")
+        buf.writeline("}")
+
+        if not keyed_by_sizes:  # no size hint keys, just call with list of kernels
+            buf.writeline(
+                f"{multi_kernel_name} = async_compile.multi_kernel({multi_kernel_name!r}, ["
+            )
+            with buf.indent():
+                for name in kernel_names:
+                    buf.writeline(f"{name},")
+            buf.writeline("], arg_index=arg_index)")
+        else:  # call with dict[size hint key, kernel]
+            assert isinstance(kernels[0], TritonTemplateKernel)
+            assert isinstance(kernel_shape_keys, list)
+            assert len(kernels) == len(kernel_shape_keys)
+            buf.writeline(
+                f"{multi_kernel_name} = async_compile.size_hint_multi_kernel({multi_kernel_name!r}, {{"
+            )
+            with buf.indent():
+                for shape_key, name in zip(kernel_shape_keys, kernel_names):
+                    buf.writeline(f"{shape_key}: {name},")
+            buf.writeline("}, arg_index=arg_index)")
+
+        if config.triton.autotune_at_compile_time:
+            V.graph.wrapper_code.src_to_kernel["\n".join(kernel_names)] = (
+                multi_kernel_name
+            )
+
+        return multi_kernel_name
+
+
+class MultiKernel:
+    """
+    This class maintains the compile time state for multi kernels.
+
+    Assume we do codegen for a MultiKernel encapsulating kernel1 and kernel2.
+    The generated definition for the multi-kernel will looks like:
+    ```
+    multi_kernel_kernel1 = MultiKernelCall(
+        [kernel1, kernel2], multi_kernel_definition_code
+    )
+    ```
+
+    Here is a concrete example: https://gist.github.com/shunting314/d9f3fb6bc6cee3dbae005825ca196d39
+    """
+
+    def __init__(self, kernels):
+        assert len(kernels) >= 2
+
+        self.kernels = kernels
+        self.kernel_name = V.graph.wrapper_code.multi_kernel_state.define_kernel(
+            kernels
+        )
+
+        # need this since some code in inductor check if the kernel object has an args
+        # attribute to decide if it's a non-null kernel.
+        self.args = object()
+
+    @staticmethod
+    def _merge_workspace_args(left: list[WorkspaceArg], right: list[WorkspaceArg]):
+        if left == right:
+            return left
+        result = {x.inner_name: x for x in left}
+        for arg in right:
+            if arg.inner_name in result:
+                result[arg.inner_name] = WorkspaceArg.maximum(
+                    result[arg.inner_name], arg
+                )
+            else:
+                result[arg.inner_name] = arg
+        return [*result.values()]
+
+    @staticmethod
+    def merge_workspaces_inplace(kernels):
+        if len(kernels) < 2:
+            return
+        # All kernels must share the same workspace
+        workspace_args = functools.reduce(
+            MultiKernel._merge_workspace_args,
+            [kernel.args.workspace_args for kernel in kernels],
+        )
+        for kernel in kernels:
+            kernel.args.workspace_args = workspace_args
+        return workspace_args
+
+    def call_kernel(self, kernel_name):
+        """
+        Collect the union of arguments from all subkernels as the arguments
+        for the multi-kernel.
+        """
+        # Prevent circular import
+        from ..select_algorithm import TritonTemplateKernel
+
+        assert kernel_name == self.kernel_name
+        V.graph.wrapper_code.write_triton_header_once()
+        _, call_args, _, arg_types = self.kernels[0].args.python_argdefs()
+        for kernel in self.kernels[1:]:
+            _, other_call_args, _, other_arg_types = kernel.args.python_argdefs()
+            assert call_args == other_call_args, (call_args, other_call_args)
+            assert arg_types == other_arg_types
+
+        if V.graph.cpp_wrapper and not config.triton.autotune_at_compile_time:
+            # for the second pass of cpp-wrapper codegen, we should call
+            # the fast kernel directly
+            kernel_name = MultiKernelCall.lookup_choice(self.kernel_name)
+
+        if isinstance(self.kernels[0], TritonTemplateKernel) and isinstance(
+            self.kernels[0].output_node, MultiTemplateBuffer
+        ):
+            # For matmuls the grid arguments are passed in as additional arguments
+            # to the kernel run method. These grids change based on the various
+            # parameters of the matmul. So we need to pass each kernel's grid into
+            # the multi call kernel.
+            multi_call_args = call_args
+            multi_call_arg_types = arg_types
+            for kernel in self.kernels:
+                additional_call_args, additional_arg_types = (
+                    kernel.additional_call_args_and_types()
+                )
+                multi_call_args.extend(list(additional_call_args))
+                multi_call_arg_types.extend(list(additional_arg_types))
+        else:
+            # numels for all subkernels should be the same. Use kernels[0] here
+            self.kernels[0].add_numel_to_call_args(kernel_name, call_args, arg_types)
+            multi_call_args = call_args
+            multi_call_arg_types = arg_types
+
+        for ws in self.kernels[0].args.workspace_args:
+            V.graph.wrapper_code.generate_workspace_allocation(ws)
+
+        if V.graph.cpp_wrapper:
+            # We have already selected the best kernel at compile time
+            # so we only have one set of call args. NB: this currently
+            # doesn't work with MultiTemplateBuffer kernels. @bobrenjc93
+            # will add it in a subsequent PR.
+            V.graph.wrapper_code.generate_kernel_call(
+                kernel_name, call_args, arg_types=arg_types
+            )
+        else:
+            V.graph.wrapper_code.generate_kernel_call(
+                kernel_name, multi_call_args, arg_types=multi_call_arg_types
+            )
+
+        for ws in reversed(self.kernels[0].args.workspace_args):
+            V.graph.wrapper_code.generate_workspace_deallocation(ws)
+
+    def codegen_nan_check(self):
+        wrapper = V.graph.wrapper_code
+        seen: OrderedSet[str] = OrderedSet()
+        for k in self.kernels:
+            _, call_args, precompile_args, _ = k.args.python_argdefs()
+            for arg, precompile_arg in zip(call_args, precompile_args):
+                if arg in seen:
+                    continue
+                seen.add(arg)
+                if isinstance(precompile_arg, TensorArg):
+                    line = f"assert not {arg}.isnan().any().item()"
+                    wrapper.writeline(line)
+                    line = f"assert not {arg}.isinf().any().item()"
+                    wrapper.writeline(line)
+
+    @property
+    def removed_buffers(self):
+        return OrderedSet.intersection(*[k.removed_buffers for k in self.kernels])
+
+    @property
+    def inplaced_to_remove(self):
+        return OrderedSet.intersection(*[k.inplaced_to_remove for k in self.kernels])
+
+    @property
+    @cache_on_self
+    def inplace_update_buffers(self):
+        """
+        Make sure all kernels have the same inplace update mappings.
+        """
+        for k in self.kernels[1:]:
+            assert k.inplace_update_buffers == self.kernels[0].inplace_update_buffers
+        return self.kernels[0].inplace_update_buffers
+
+    def warn_mix_layout(self, kernel_name: str):
+        pass
+
+
+class MultiKernelCall:
+    """
+    This class is called at run time to actually run the kernel
+    """
+
+    def __init__(self, multi_kernel_name, kernels, arg_index):
+        assert len(kernels) >= 1
+        self._kernels = kernels
+        self.multi_kernel_name = multi_kernel_name
+
+        self.disable_cache = os.environ.get(
+            "TORCHINDUCTOR_DISABLE_MULTI_KERNEL_CACHE"
+        ) == "1" or is_metric_table_enabled("persistent_red_perf")
+
+        self.picked_kernel = None
+        self.arg_index = arg_index
+        if config.triton.multi_kernel > 1:
+            # manually force a subkernel to ease perf testing
+            picked_by_config = config.triton.multi_kernel - 2
+            assert picked_by_config < len(self._kernels)
+            # pyrefly: ignore [bad-assignment]
+            self.picked_kernel = picked_by_config
+        elif not self.disable_cache:
+            self.load_cache()
+
+        self._recorded = False
+
+    def cache_file_path(self):
+        key = code_hash(
+            ",".join(
+                [
+                    f"{k.fn.cache_key}{k.size_hints!r}{k.triton_meta!r}"
+                    for k in self.kernels
+                ]
+            )
+        )
+        _, _, path = get_path(key, "picked_kernel")
+        return pathlib.Path(path)
+
+    def load_cache(self):
+        assert self.picked_kernel is None
+        path = self.cache_file_path()
+        if path.exists():
+            with path.open() as fd:
+                # pyrefly: ignore [bad-assignment]
+                self.picked_kernel = int(fd.read())
+                # pyrefly: ignore [unsupported-operation]
+                assert self.picked_kernel >= 0 and self.picked_kernel < len(
+                    self._kernels
+                )
+                log.debug(
+                    "Load picked kernel %d from cache file %s", self.picked_kernel, path
+                )
+
+    def store_cache(self):
+        assert self.picked_kernel is not None
+        path = self.cache_file_path()
+        path.parent.mkdir(parents=True, exist_ok=True)
+
+        write_atomic(path, str(self.picked_kernel))
+        log.debug("Store picked kernel %d to cache file %s", self.picked_kernel, path)
+
+    @property
+    def kernels(self):
+        """
+        Read results from future.
+
+        This should be called after parallel compilation is done.
+        In case you call this before compilation is done,
+        it may slow down the parallel compilation.
+        """
+        for i, kernel in enumerate(self._kernels):
+            if isinstance(kernel, CodeCacheFuture):
+                self._kernels[i] = kernel.result()
+
+        return self._kernels
+
+    def benchmark_sub_kernels(self, *args, **kwargs):
+        """
+        Benchmark all the sub kernels and return the execution time
+        (in milliseconds) for each of time.
+
+        Unit test may mock this method to force a specific kernel to
+        be picked.
+        """
+
+        def wrap_fn(kernel, index):
+            def inner():
+                filtered_args = self._get_filtered_args(args, index)
+                args_clone, kwargs_clone = kernel.clone_args(*filtered_args, **kwargs)
+                return kernel.run(*args_clone, **kwargs_clone)
+
+            return inner
+
+        return [
+            benchmarker.benchmark(
+                wrap_fn(kernel, index),
+                # Currently the kernel type must be a CachingAutotuner
+                device=kernel.device_props.type,
+                rep=40,
+            )
+            for index, kernel in enumerate(self.kernels)
+        ]
+
+    def _get_filtered_args(self, args, index):
+        """
+        We pass in all arguments to all kernels into the MultiKernelCall
+        so when invoking a particular kernel we need to filter to only the
+        arguments for that specific kernel.
+        """
+
+        # This is sometimes invoked at runtime where V.graph is
+        # a NullHandler
+        if hasattr(V.graph, "cpp_wrapper") and V.graph.cpp_wrapper:
+            # for cpp-wrapper, we should not filter the args since
+            # we already have chosen a single kernel and arg set.
+            return args
+        return [item for s in self.arg_index[index] for item in args[s]]
+
+    # record_choice and lookup_choice are helper functions for cpp-wrapper
+    # codegen. The first pass use record_choice to keep the choice and
+    # the second pass do lookup by calling lookup_choice.
+    #
+    # An alternative that reused the multi-kernel cache does not work well
+    # since during codegen of the second pass, it's very hard to know the
+    # path for the cache file. Also reading the cache file need do some IO
+    # which can be slower.
+    @staticmethod
+    def record_choice(multi_kernel_name: str, picked_kernel_name: str):
+        """
+        Record the multi-kernel choice for cpp-wrapper after autotuning
+
+        We should do nothing if this function is not called during codegen.
+        """
+        from torch._inductor.graph import GraphLowering
+
+        if not isinstance(V.graph, GraphLowering):
+            return
+
+        if not V.graph.record_multi_kernel_choice:
+            return
+
+        V.graph.multi_kernel_to_choice[multi_kernel_name] = picked_kernel_name
+
+    @staticmethod
+    def lookup_choice(multi_kernel_name: str) -> str:
+        # this should always been done during cpp-wrapper codegen
+        assert (
+            V.graph.record_multi_kernel_choice
+            and multi_kernel_name in V.graph.multi_kernel_to_choice
+        )
+        # there should be no miss
+        return V.graph.multi_kernel_to_choice[multi_kernel_name]
+
+    def run(self, *args, **kwargs):
+        if self.picked_kernel is None:
+            timings = self.benchmark_sub_kernels(*args, **kwargs)
+            self.picked_kernel = timings.index(min(timings))
+            k0 = self.kernels[0]
+            log.debug(
+                "pick %dth sub-kernel in %s. Size hints %s. Reduction hint %s. Timings %s",
+                self.picked_kernel,
+                [k.inductor_meta.get("kernel_name") for k in self.kernels],
+                k0.size_hints,
+                k0.inductor_meta.get("reduction_hint"),
+                timings,
+            )
+            get_metric_table("persistent_red_perf").add_row(
+                functools.partial(self._metrics_table_row, timings)
+            )
+
+            if not self.disable_cache:
+                self.store_cache()
+
+        if not self._recorded:
+            self._recorded = True
+            picked_kernel_name = self.kernels[self.picked_kernel].inductor_meta.get(
+                "kernel_name"
+            )
+            assert picked_kernel_name is not None
+            self.record_choice(self.multi_kernel_name, picked_kernel_name)
+
+        run = self.kernels[self.picked_kernel].run  # type: ignore[method-assign]
+        filtered_args = self._get_filtered_args(args, self.picked_kernel)
+        run(*filtered_args, **kwargs)
+
+    def _metrics_table_row(self, timings):
+        def get_kernel_path(k):
+            return k.fn.fn.__code__.co_filename
+
+        k0 = self.kernels[0]
+        row = {
+            "size_hints": k0.size_hints,
+            "reduction_hint": k0.inductor_meta.get("reduction_hint"),
+        }
+        max_kernels = 4
+        assert len(timings) <= max_kernels
+        for i in range(max_kernels):
+            if i < len(self.kernels):
+                row[f"kernel{i}_path"] = get_kernel_path(self.kernels[i])
+                row[f"kernel{i}_latency"] = timings[i]
+            else:
+                row[f"kernel{i}_path"] = ""
+                row[f"kernel{i}_latency"] = ""
+        return row
+
+
+class SizeHintMultiKernel(MultiKernel):
+    """
+    Version of multi-kernel that generates kernels based on specified size hints.
+    Currently only performs 1-d search over hints; doesn't perform combinatorial n-d search
+    if n > 1 dynamic dimensions are specified.
+
+    e.g. matmul([s0, s1], [s1, s2]) with size-hints [64, 256] only generates 2 kernels,
+    based on tuning shapes ([64, 64], [64, 64]) and ([256, 256], [256, 256])
+    """
+
+    def __init__(self, kernels):
+        assert isinstance(kernels, dict) and len(kernels) >= 1
+
+        self.kernels, self.kernel_shape_keys = [], []
+        for shape_key, kernel in kernels.items():
+            self.kernels.append(kernel)
+            self.kernel_shape_keys.append(shape_key)
+        self.kernel_name = V.graph.wrapper_code.multi_kernel_state.define_kernel(
+            self.kernels, self.kernel_shape_keys
+        )
+
+        # need this since some code in inductor check if the kernel object has an args
+        # attribute to decide if it's a non-null kernel.
+        self.args = object()
+
+
+class SizeHintMultiKernelCall(MultiKernelCall):
+    """
+    Runtime class for size-hint multi-kernels.
+    Instead of having a plain list of kernels to benchmark over, keys them by input & output shapes,
+    and optionally perform shape-based selection. The pre-generated kernel is chosen based on the shape keys,
+    with the heuristic being log2 l1 distance between the pre-generated / runtime input & output shapes.
+    """
+
+    def __init__(self, multi_kernel_name, kernels, arg_index):
+        super().__init__(multi_kernel_name, list(kernels.values()), arg_index)
+        self._kernel_hints = list(kernels.keys())
+
+        # Caches results for unique shapes.
+        self._shape_cache = {}
+
+    def _get_shape_cache_key(self, *args, **kwargs):
+        """
+        Generate a cache key based on tensor shapes for shape-specialized dispatch.
+        """
+        shapes = []
+        for arg in args:
+            if hasattr(arg, "shape"):
+                shapes.append(tuple(arg.shape))
+        return tuple(shapes)
+
+    def _get_cached_shape_choice(self, cache_key):
+        """
+        Get cached kernel choice for a specific shape.
+        """
+        return self._shape_cache.get(cache_key)
+
+    def _cache_shape_choice(self, cache_key, kernel_idx):
+        """
+        Cache kernel choice for a specific shape.
+        """
+        self._shape_cache[cache_key] = kernel_idx
+
+    def _dist_heuristic(self, k1, k2):
+        """
+        log2 L1 distance heuristic for kernel selection.
+        """
+
+        def dist(x, y):
+            lx = math.log2(x) if x > 0 else -1
+            ly = math.log2(y) if y > 0 else -1
+            return abs(lx - ly)
+
+        out = 0
+        for s1, s2 in zip(k1, k2):
+            out += sum(dist(x, y) for x, y in zip(s1, s2))
+        return out
+
+    def run(self, *args, **kwargs):
+        cache_key = self._get_shape_cache_key(*args, **kwargs)
+        cached_choice = self._get_cached_shape_choice(cache_key)
+        if cached_choice is not None:
+            self.picked_kernel = cached_choice
+            log.debug(
+                "using cached shape-specialized choice %dth sub-kernel in %s. Cache key: %s",
+                self.picked_kernel,
+                [k.inductor_meta.get("kernel_name") for k in self.kernels],
+                cache_key,
+            )
+        else:
+            self._select_kernel_by_shape(*args, **kwargs)
+
+        if not self._recorded:
+            self._recorded = True
+            picked_kernel_name = self.kernels[self.picked_kernel].inductor_meta.get(
+                "kernel_name"
+            )
+            assert picked_kernel_name is not None
+            self.record_choice(self.multi_kernel_name, picked_kernel_name)
+
+        run = self.kernels[self.picked_kernel].run  # type: ignore[method-assign]
+        filtered_args = self._get_filtered_args(args, self.picked_kernel)
+        run(*filtered_args, **kwargs)
+
+    def _select_kernel_by_shape(self, *args, **kwargs):
+        """
+        Benchmark kernels for a particular shape and return the
+        best kernel for this shape.
+        """
+        shape_key = self._get_shape_cache_key(*args, **kwargs)
+        dists = [
+            self._dist_heuristic(shape_key, key) if key is not None else 2**62
+            for key in self._kernel_hints
+        ]
+        # pyrefly: ignore [bad-assignment]
+        self.picked_kernel = dists.index(min(dists))
+        self._cache_shape_choice(shape_key, self.picked_kernel)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/pallas.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/pallas.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca955ba5f351839209b9be5d3d1312dc69efb48f
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/pallas.py
@@ -0,0 +1,1840 @@
+from __future__ import annotations
+
+import hashlib
+import math
+from typing import Any, Optional, TYPE_CHECKING, Union
+
+import sympy  # noqa: TC002
+
+import torch  # noqa: TC001
+from torch.utils._ordered_set import OrderedSet
+from torch.utils._pallas import has_tpu_pallas
+
+from .. import config
+from ..runtime.runtime_utils import torch_dtype_to_jax
+from ..utils import get_fused_kernel_name, get_kernel_metadata
+from ..virtualized import V
+from .block_analysis import BlockPatternMatcher
+from .common import BackendFeature, CSEVariable, IndentedBuffer, OpOverrides
+from .simd import pexpr, SIMDKernel, SIMDScheduling
+
+
+if TYPE_CHECKING:
+    from collections.abc import Callable, Sequence
+
+    from ..ir import IRNode
+    from ..ops_handler import ReductionType
+    from ..scheduler import BaseSchedulerNode
+
+
+# Main function suffix used in generated Pallas code
+MAIN_SUFFIX = "main"
+
+# Logger for Pallas kernel code
+kernel_code_log = torch._logging.getArtifactLogger(__name__, "kernel_code")
+
+
+class PallasKernelWrapper:
+    """Wrapper to provide .run() interface for Pallas kernels"""
+
+    def __init__(
+        self, kernel_fn: Callable[..., Any], kernel_path: Optional[str] = None
+    ):
+        self.kernel_fn = kernel_fn
+        self.kernel_path = kernel_path
+        kernel_code_log.info("Pallas kernel path: %s", kernel_path)
+
+    def run(self, *args, stream=None, **kwargs):
+        """
+        Execute the Pallas kernel.
+
+        Args:
+            *args: Arguments to pass to the kernel function
+            stream: CUDA stream to pass to the kernel function
+            **kwargs: Additional keyword arguments for the kernel
+
+        Returns:
+            Result of the kernel execution
+        """
+        return self.kernel_fn(*args, stream=stream, **kwargs)
+
+
+class Unsupported(RuntimeError):
+    """Exception raised when an operation is not supported by the Pallas backend."""
+
+
+class PallasKernelOverrides(OpOverrides):
+    """
+    Map element-wise ops to JAX/Pallas operations.
+
+    For now, we use the default Python operators which are compatible
+    with JAX numpy broadcasting semantics.
+    """
+
+    @staticmethod
+    def sin(x: str) -> str:
+        return f"jnp.sin({x})"
+
+    @staticmethod
+    def cos(x: str) -> str:
+        return f"jnp.cos({x})"
+
+    @staticmethod
+    def tan(x: str) -> str:
+        return f"jnp.tan({x})"
+
+    @staticmethod
+    def sinh(x: str) -> str:
+        return f"jnp.sinh({x})"
+
+    @staticmethod
+    def cosh(x: str) -> str:
+        return f"jnp.cosh({x})"
+
+    @staticmethod
+    def tanh(x: str) -> str:
+        return f"jnp.tanh({x})"
+
+    @staticmethod
+    def asin(x: str) -> str:
+        return f"jnp.arcsin({x})"
+
+    @staticmethod
+    def acos(x: str) -> str:
+        return f"jnp.arccos({x})"
+
+    @staticmethod
+    def atan(x: str) -> str:
+        return f"jnp.arctan({x})"
+
+    @staticmethod
+    def exp(x: str) -> str:
+        return f"jnp.exp({x})"
+
+    @staticmethod
+    def exp2(x: str) -> str:
+        return f"jnp.exp2({x})"
+
+    @staticmethod
+    def expm1(x: str) -> str:
+        return f"jnp.expm1({x})"
+
+    @staticmethod
+    def log(x: str) -> str:
+        return f"jnp.log({x})"
+
+    @staticmethod
+    def log10(x: str) -> str:
+        return f"jnp.log10({x})"
+
+    @staticmethod
+    def log2(x: str) -> str:
+        return f"jnp.log2({x})"
+
+    @staticmethod
+    def log1p(x: str) -> str:
+        return f"jnp.log1p({x})"
+
+    @staticmethod
+    def sqrt(x: str) -> str:
+        return f"jnp.sqrt({x})"
+
+    @staticmethod
+    def rsqrt(x: str) -> str:
+        return f"(1.0 / jnp.sqrt({x}))"
+
+    @staticmethod
+    def abs(x: str) -> str:
+        return f"jnp.abs({x})"
+
+    @staticmethod
+    def neg(x: str) -> str:
+        return f"(-{x})"
+
+    @staticmethod
+    def floor(x: str) -> str:
+        return f"jnp.floor({x})"
+
+    @staticmethod
+    def ceil(x: str) -> str:
+        return f"jnp.ceil({x})"
+
+    @staticmethod
+    def trunc(x: str) -> str:
+        return f"jnp.trunc({x})"
+
+    @staticmethod
+    def round(x: str) -> str:
+        return f"jnp.round({x})"
+
+    @staticmethod
+    def sigmoid(x: str) -> str:
+        return f"(1.0 / (1.0 + jnp.exp(-{x})))"
+
+    @staticmethod
+    def relu(x: str) -> str:
+        return f"jnp.maximum({x}, 0)"
+
+    @staticmethod
+    def pow(a: str, b: str) -> str:
+        return f"jnp.power({a}, {b})"
+
+    @staticmethod
+    def maximum(a: str, b: str) -> str:
+        return f"jnp.maximum({a}, {b})"
+
+    @staticmethod
+    def minimum(a: str, b: str) -> str:
+        return f"jnp.minimum({a}, {b})"
+
+    @staticmethod
+    def where(cond: str, a: str, b: str) -> str:
+        return f"jnp.where({cond}, {a}, {b})"
+
+    @staticmethod
+    def to_dtype(
+        x: str,
+        dtype: torch.dtype,
+        src_dtype: Optional[torch.dtype] = None,
+        use_compute_types: bool = True,
+    ) -> str:
+        jax_dtype = torch_dtype_to_jax(dtype)
+        # Wrap in jnp.asarray to handle scalars from integer indexing
+        return f"jnp.asarray({x}).astype({jax_dtype})"
+
+    @staticmethod
+    def to_dtype_bitcast(x: str, dtype: torch.dtype, src_dtype: torch.dtype) -> str:
+        """Bitcast a value from one dtype to another with the same size."""
+        jax_dtype = torch_dtype_to_jax(dtype)
+        jax_src_dtype = torch_dtype_to_jax(src_dtype)
+        # First ensure the value is the correct source dtype, then bitcast
+        return f"jax.lax.bitcast_convert_type(jnp.asarray({x}).astype({jax_src_dtype}), {jax_dtype})"
+
+    @staticmethod
+    def index_expr(expr: sympy.Expr, dtype: torch.dtype) -> str:
+        """Convert a sympy expression to a JAX array indexing expression."""
+        from ..utils import get_bounds_index_expr
+
+        idx_str = V.kernel.kexpr(V.kernel.prepare_indexing(expr))
+        var = V.kernel.cse.generate(
+            V.kernel.compute, idx_str, bounds=get_bounds_index_expr(expr)
+        )
+        return PallasKernelOverrides.to_dtype(var, dtype)
+
+    @staticmethod
+    def constant(val, dtype: torch.dtype) -> str:
+        """Convert a constant value to JAX representation."""
+        jax_dtype = torch_dtype_to_jax(dtype)
+        if dtype == torch.bool:
+            return "True" if val else "False"
+        # Handle special float values
+        if isinstance(val, float):
+            if math.isnan(val):
+                return "jnp.nan"
+            if math.isinf(val):
+                return "jnp.inf" if val > 0 else "-jnp.inf"
+        return f"jnp.array({val}, dtype={jax_dtype})"
+
+    @staticmethod
+    def real(x: str) -> str:
+        return f"jnp.real({x})"
+
+    @staticmethod
+    def imag(x: str) -> str:
+        return f"jnp.imag({x})"
+
+    @staticmethod
+    def conj(x: str) -> str:
+        return f"jnp.conj({x})"
+
+    @staticmethod
+    def angle(x: str) -> str:
+        return f"jnp.angle({x})"
+
+    @staticmethod
+    def view_as_real(x: str) -> str:
+        """View complex tensor as real tensor with extra dimension."""
+        return f"jnp.stack([jnp.real({x}), jnp.imag({x})], axis=-1)"
+
+    @staticmethod
+    def view_as_complex(x: str) -> str:
+        """View real tensor as complex tensor."""
+        return f"({x}[..., 0] + 1j * {x}[..., 1])"
+
+    # Comparison operations
+    @staticmethod
+    def eq(a: str, b: str) -> str:
+        return f"({a} == {b})"
+
+    @staticmethod
+    def ne(a: str, b: str) -> str:
+        return f"({a} != {b})"
+
+    @staticmethod
+    def lt(a: str, b: str) -> str:
+        return f"({a} < {b})"
+
+    @staticmethod
+    def le(a: str, b: str) -> str:
+        return f"({a} <= {b})"
+
+    @staticmethod
+    def gt(a: str, b: str) -> str:
+        return f"({a} > {b})"
+
+    @staticmethod
+    def isnan(x: str) -> str:
+        return f"jnp.isnan({x})"
+
+    @staticmethod
+    def isinf(x: str) -> str:
+        return f"jnp.isinf({x})"
+
+    @staticmethod
+    def isfinite(x: str) -> str:
+        return f"jnp.isfinite({x})"
+
+    @staticmethod
+    def ge(a: str, b: str) -> str:
+        return f"({a} >= {b})"
+
+    # Logical operations
+    @staticmethod
+    def logical_and(a: str, b: str) -> str:
+        return f"jnp.logical_and({a}, {b})"
+
+    @staticmethod
+    def logical_or(a: str, b: str) -> str:
+        return f"jnp.logical_or({a}, {b})"
+
+    @staticmethod
+    def logical_not(x: str) -> str:
+        return f"jnp.logical_not({x})"
+
+    @staticmethod
+    def logical_xor(a: str, b: str) -> str:
+        return f"jnp.logical_xor({a}, {b})"
+
+    # Math operations
+    @staticmethod
+    def atan2(a: str, b: str) -> str:
+        return f"jnp.arctan2({a}, {b})"
+
+    @staticmethod
+    def hypot(a: str, b: str) -> str:
+        return f"jnp.hypot({a}, {b})"
+
+    @staticmethod
+    def fmod(a: str, b: str) -> str:
+        return f"jnp.fmod({a}, {b})"
+
+    @staticmethod
+    def remainder(a: str, b: str) -> str:
+        return f"jnp.remainder({a}, {b})"
+
+    @staticmethod
+    def truncdiv(a: str, b: str) -> str:
+        # Truncated division (rounds toward zero)
+        # For integers: sign(a)*sign(b) * (abs(a) // abs(b))
+        return f"(jnp.sign({a}) * jnp.sign({b}) * (jnp.abs({a}) // jnp.abs({b}))).astype({a}.dtype)"
+
+    @staticmethod
+    def floordiv(a: str, b: str) -> str:
+        return f"({a} // {b})"
+
+    @staticmethod
+    def clamp(x: str, min_val: str, max_val: str) -> str:
+        return f"jnp.clip({x}, {min_val}, {max_val})"
+
+    @staticmethod
+    def clip(x: str, min_val: str, max_val: str) -> str:
+        return f"jnp.clip({x}, {min_val}, {max_val})"
+
+    # Sign operations
+    @staticmethod
+    def sign(x: str) -> str:
+        return f"jnp.sign({x})"
+
+    @staticmethod
+    def signbit(x: str) -> str:
+        return f"jnp.signbit({x})"
+
+    # Special math functions
+    @staticmethod
+    def erf(x: str) -> str:
+        return f"jax.scipy.special.erf({x})"
+
+    @staticmethod
+    def erfc(x: str) -> str:
+        return f"jax.scipy.special.erfc({x})"
+
+    @staticmethod
+    def erfinv(x: str) -> str:
+        return f"jax.scipy.special.erfinv({x})"
+
+    @staticmethod
+    def lgamma(x: str) -> str:
+        return f"jax.scipy.special.gammaln({x})"
+
+    @staticmethod
+    def digamma(x: str) -> str:
+        return f"jax.scipy.special.digamma({x})"
+
+    @staticmethod
+    def bessel_j0(x: str) -> str:
+        # bessel_jn requires float64 and has numerical issues at x=0 (returns NaN)
+        # bessel_jn(x, v=n) returns array of shape (n+1, ...) with J_0 to J_n
+        # Handle by: convert to float64, compute, handle x=0, convert back
+        # J0(0) = 1.0
+        return (
+            f"jnp.where({x}.astype(jnp.float64) == 0.0, 1.0, "
+            f"jax.scipy.special.bessel_jn({x}.astype(jnp.float64), v=0)[0])"
+            f".astype({x}.dtype)"
+        )
+
+    @staticmethod
+    def bessel_j1(x: str) -> str:
+        # bessel_jn requires float64 and has numerical issues at x=0 (returns NaN)
+        # bessel_jn(x, v=n) returns array of shape (n+1, ...) with J_0 to J_n
+        # Handle by: convert to float64, compute, handle x=0, convert back
+        # J1(0) = 0.0
+        return (
+            f"jnp.where({x}.astype(jnp.float64) == 0.0, 0.0, "
+            f"jax.scipy.special.bessel_jn({x}.astype(jnp.float64), v=1)[1])"
+            f".astype({x}.dtype)"
+        )
+
+    @staticmethod
+    def modified_bessel_i0(x: str) -> str:
+        # Modified Bessel function of the first kind I_0(x)
+        # I_0(x) = bessel_i0e(x) * exp(|x|) where bessel_i0e is the scaled version
+        return f"jax.lax.bessel_i0e({x}) * jnp.exp(jnp.abs({x}))"
+
+    @staticmethod
+    def modified_bessel_i1(x: str) -> str:
+        # Modified Bessel function of the first kind I_1(x)
+        # I_1(x) = bessel_i1e(x) * exp(|x|) where bessel_i1e is the scaled version
+        return f"jax.lax.bessel_i1e({x}) * jnp.exp(jnp.abs({x}))"
+
+    @staticmethod
+    def spherical_bessel_j0(x: str) -> str:
+        # Spherical Bessel function of the first kind j_0(x) = sin(x) / x
+        # Handle x=0: j_0(0) = 1
+        return f"jnp.where({x} == 0.0, 1.0, jnp.sin({x}) / {x})"
+
+    @staticmethod
+    def i0(x: str) -> str:
+        # Modified Bessel function I_0 (same as modified_bessel_i0)
+        return f"jax.lax.bessel_i0e({x}) * jnp.exp(jnp.abs({x}))"
+
+    @staticmethod
+    def i0e(x: str) -> str:
+        # Exponentially scaled modified Bessel function I_0
+        return f"jax.lax.bessel_i0e({x})"
+
+    @staticmethod
+    def i1(x: str) -> str:
+        # Modified Bessel function I_1 (same as modified_bessel_i1)
+        return f"jax.lax.bessel_i1e({x}) * jnp.exp(jnp.abs({x}))"
+
+    @staticmethod
+    def i1e(x: str) -> str:
+        # Exponentially scaled modified Bessel function I_1
+        return f"jax.lax.bessel_i1e({x})"
+
+    @staticmethod
+    def gammainc(x: str, y: str) -> str:
+        # Regularized lower incomplete gamma function P(a, x)
+        # Note: PyTorch uses gammainc(input, other) where input is a (shape param)
+        return f"jax.scipy.special.gammainc({x}, {y})"
+
+    @staticmethod
+    def gammaincc(x: str, y: str) -> str:
+        # Regularized upper incomplete gamma function Q(a, x)
+        return f"jax.scipy.special.gammaincc({x}, {y})"
+
+    @staticmethod
+    def igamma(x: str, y: str) -> str:
+        # Regularized lower incomplete gamma function (alias for gammainc)
+        return f"jax.scipy.special.gammainc({x}, {y})"
+
+    @staticmethod
+    def igammac(x: str, y: str) -> str:
+        # Regularized upper incomplete gamma function (alias for gammaincc)
+        return f"jax.scipy.special.gammaincc({x}, {y})"
+
+    @staticmethod
+    def polygamma(x: str, y: str) -> str:
+        # Polygamma function psi^(n)(x), x is order n, y is the value
+        # Note: JAX uses polygamma(n, x) where n is integer order
+        return f"jax.scipy.special.polygamma({x}.astype(jnp.int32), {y})"
+
+    @staticmethod
+    def ndtri(x: str) -> str:
+        # Inverse of the standard normal CDF
+        return f"jax.scipy.special.ndtri({x})"
+
+    @staticmethod
+    def zeta(x: str, y: str) -> str:
+        # Hurwitz zeta function zeta(x, q) = sum_{k=0}^inf 1/(k+q)^x
+        return f"jax.scipy.special.zeta({x}, {y})"
+
+    @staticmethod
+    def xlogy(x: str, y: str) -> str:
+        # x * log(y), with proper handling of x=0
+        return f"jax.scipy.special.xlogy({x}, {y})"
+
+    @staticmethod
+    def xlog1py(x: str, y: str) -> str:
+        # x * log1p(y), with proper handling of x=0
+        return f"jax.scipy.special.xlog1py({x}, {y})"
+
+    @staticmethod
+    def chebyshev_polynomial_t(x: str, n: str) -> str:
+        # Chebyshev polynomial of the first kind T_n(x)
+        # For |x| <= 1: T_n(x) = cos(n * arccos(x))
+        # For x > 1: T_n(x) = cosh(n * arccosh(x))
+        # For x < -1: T_n(x) = (-1)^n * cosh(n * arccosh(-x))
+        return (
+            f"jnp.where(jnp.abs({x}) <= 1, "
+            f"jnp.cos({n} * jnp.arccos(jnp.clip({x}, -1, 1))), "
+            f"jnp.where({x} > 1, "
+            f"jnp.cosh({n} * jnp.arccosh(jnp.maximum({x}, 1.0))), "
+            f"((-1.0) ** {n}) * jnp.cosh({n} * jnp.arccosh(jnp.maximum(-{x}, 1.0)))))"
+        )
+
+    @staticmethod
+    def chebyshev_polynomial_u(x: str, n: str) -> str:
+        # Chebyshev polynomial of the second kind U_n(x)
+        # For |x| < 1: U_n(x) = sin((n+1) * arccos(x)) / sqrt(1 - x^2)
+        # For x = 1: U_n(1) = n+1
+        # For x = -1: U_n(-1) = (-1)^n * (n+1)
+        # For x > 1: U_n(x) = sinh((n+1) * arccosh(x)) / sqrt(x^2 - 1)
+        # For x < -1: U_n(x) = (-1)^n * U_n(-x) (symmetry)
+        return (
+            f"jnp.where(jnp.abs({x}) < 1, "
+            f"jnp.sin(({n} + 1) * jnp.arccos(jnp.clip({x}, -1, 1))) / "
+            f"jnp.sqrt(jnp.maximum(1 - {x}**2, 1e-10)), "
+            f"jnp.where({x} >= 1, "
+            f"jnp.where({x} == 1, {n} + 1.0, "
+            f"jnp.sinh(({n} + 1) * jnp.arccosh(jnp.maximum({x}, 1.0))) / "
+            f"jnp.sqrt(jnp.maximum({x}**2 - 1, 1e-10))), "
+            f"jnp.where({x} == -1, ((-1.0) ** {n}) * ({n} + 1.0), "
+            f"((-1.0) ** {n}) * jnp.sinh(({n} + 1) * jnp.arccosh(jnp.maximum(-{x}, 1.0))) / "
+            f"jnp.sqrt(jnp.maximum({x}**2 - 1, 1e-10)))))"
+        )
+
+    @staticmethod
+    def chebyshev_polynomial_v(x: str, n: str) -> str:
+        # Chebyshev polynomial of the third kind V_n(x)
+        # V_n(x) = (T_n(x) - T_{n+1}(x)) / (1 - x) for x != 1
+        # V_n(1) = 1, recurrence: V_0 = 1, V_1 = 2x - 1, V_n = 2x*V_{n-1} - V_{n-2}
+        # Explicit: V_0 = 1, V_1 = 2x-1, V_2 = 4x^2-2x-1, V_3 = 8x^3-4x^2-4x+1
+        return (
+            f"jnp.where({n} == 0, jnp.ones_like({x}), "
+            f"jnp.where({n} == 1, 2*{x} - 1, "
+            f"jnp.where({n} == 2, 4*{x}**2 - 2*{x} - 1, "
+            f"jnp.where({n} == 3, 8*{x}**3 - 4*{x}**2 - 4*{x} + 1, "
+            f"jnp.where({n} == 4, 16*{x}**4 - 8*{x}**3 - 12*{x}**2 + 4*{x} + 1, "
+            f"jnp.where({n} == 5, 32*{x}**5 - 16*{x}**4 - 32*{x}**3 + 12*{x}**2 + 6*{x} - 1, "
+            f"jnp.zeros_like({x})))))))"
+        )
+
+    @staticmethod
+    def chebyshev_polynomial_w(x: str, n: str) -> str:
+        # Chebyshev polynomial of the fourth kind W_n(x)
+        # W_n(x) = (T_n(x) + T_{n+1}(x)) / (1 + x) for x != -1
+        # W_n(-1) = (-1)^n, recurrence: W_0 = 1, W_1 = 2x + 1, W_n = 2x*W_{n-1} - W_{n-2}
+        # Explicit: W_0 = 1, W_1 = 2x+1, W_2 = 4x^2+2x-1, W_3 = 8x^3+4x^2-4x-1
+        return (
+            f"jnp.where({n} == 0, jnp.ones_like({x}), "
+            f"jnp.where({n} == 1, 2*{x} + 1, "
+            f"jnp.where({n} == 2, 4*{x}**2 + 2*{x} - 1, "
+            f"jnp.where({n} == 3, 8*{x}**3 + 4*{x}**2 - 4*{x} - 1, "
+            f"jnp.where({n} == 4, 16*{x}**4 + 8*{x}**3 - 12*{x}**2 - 4*{x} + 1, "
+            f"jnp.where({n} == 5, 32*{x}**5 + 16*{x}**4 - 32*{x}**3 - 12*{x}**2 + 6*{x} + 1, "
+            f"jnp.zeros_like({x})))))))"
+        )
+
+    @staticmethod
+    def shifted_chebyshev_polynomial_t(x: str, n: str) -> str:
+        # Shifted Chebyshev polynomial of the first kind T*_n(x) = T_n(2x - 1)
+        # T_n(y) where y = 2x - 1
+        # Use same formula as chebyshev_polynomial_t
+        y = f"(2 * {x} - 1)"
+        return (
+            f"jnp.where(jnp.abs({y}) <= 1, "
+            f"jnp.cos({n} * jnp.arccos(jnp.clip({y}, -1, 1))), "
+            f"jnp.where({y} > 1, "
+            f"jnp.cosh({n} * jnp.arccosh(jnp.maximum({y}, 1.0))), "
+            f"((-1.0) ** {n}) * jnp.cosh({n} * jnp.arccosh(jnp.maximum(-{y}, 1.0)))))"
+        )
+
+    @staticmethod
+    def shifted_chebyshev_polynomial_u(x: str, n: str) -> str:
+        # Shifted Chebyshev polynomial of the second kind U*_n(x) = U_n(2x - 1)
+        # Use same formula as chebyshev_polynomial_u
+        y = f"(2 * {x} - 1)"
+        return (
+            f"jnp.where(jnp.abs({y}) < 1, "
+            f"jnp.sin(({n} + 1) * jnp.arccos(jnp.clip({y}, -1, 1))) / "
+            f"jnp.sqrt(jnp.maximum(1 - ({y})**2, 1e-10)), "
+            f"jnp.where({y} >= 1, "
+            f"jnp.where({y} == 1, {n} + 1.0, "
+            f"jnp.sinh(({n} + 1) * jnp.arccosh(jnp.maximum({y}, 1.0))) / "
+            f"jnp.sqrt(jnp.maximum({y}**2 - 1, 1e-10))), "
+            f"jnp.where({y} == -1, ((-1.0) ** {n}) * ({n} + 1.0), "
+            f"((-1.0) ** {n}) * jnp.sinh(({n} + 1) * jnp.arccosh(jnp.maximum(-{y}, 1.0))) / "
+            f"jnp.sqrt(jnp.maximum({y}**2 - 1, 1e-10)))))"
+        )
+
+    @staticmethod
+    def shifted_chebyshev_polynomial_v(x: str, n: str) -> str:
+        # Shifted Chebyshev polynomial of the third kind V*_n(x) = V_n(2x - 1)
+        y = f"(2 * {x} - 1)"  # shifted variable
+        return (
+            f"jnp.where({n} == 0, jnp.ones_like({x}), "
+            f"jnp.where({n} == 1, 2*{y} - 1, "
+            f"jnp.where({n} == 2, 4*{y}**2 - 2*{y} - 1, "
+            f"jnp.where({n} == 3, 8*{y}**3 - 4*{y}**2 - 4*{y} + 1, "
+            f"jnp.where({n} == 4, 16*{y}**4 - 8*{y}**3 - 12*{y}**2 + 4*{y} + 1, "
+            f"jnp.where({n} == 5, 32*{y}**5 - 16*{y}**4 - 32*{y}**3 + 12*{y}**2 + 6*{y} - 1, "
+            f"jnp.zeros_like({x})))))))"
+        )
+
+    @staticmethod
+    def shifted_chebyshev_polynomial_w(x: str, n: str) -> str:
+        # Shifted Chebyshev polynomial of the fourth kind W*_n(x) = W_n(2x - 1)
+        y = f"(2 * {x} - 1)"  # shifted variable
+        return (
+            f"jnp.where({n} == 0, jnp.ones_like({x}), "
+            f"jnp.where({n} == 1, 2*{y} + 1, "
+            f"jnp.where({n} == 2, 4*{y}**2 + 2*{y} - 1, "
+            f"jnp.where({n} == 3, 8*{y}**3 + 4*{y}**2 - 4*{y} - 1, "
+            f"jnp.where({n} == 4, 16*{y}**4 + 8*{y}**3 - 12*{y}**2 - 4*{y} + 1, "
+            f"jnp.where({n} == 5, 32*{y}**5 + 16*{y}**4 - 32*{y}**3 - 12*{y}**2 + 6*{y} + 1, "
+            f"jnp.zeros_like({x})))))))"
+        )
+
+    @staticmethod
+    def hermite_polynomial_h(x: str, n: str) -> str:
+        # Physicist's Hermite polynomial H_n(x)
+        # H_n(x) = 2^n * x^n - n*(n-1)/2 * 2^(n-2) * x^(n-2) + ...
+        # Use explicit formula: H_n(x) = n! * sum_{m=0}^{n//2} (-1)^m / (m! * (n-2m)!) * (2x)^(n-2m)
+        # For simplicity, use the relation: H_n(x) = 2^(n/2) * He_n(x * sqrt(2)) where He is probabilist's
+        # Actually simpler: use recurrence or closed form
+        # H_0 = 1, H_1 = 2x, H_2 = 4x^2 - 2, H_3 = 8x^3 - 12x
+        return (
+            f"jnp.where({n} == 0, jnp.ones_like({x}), "
+            f"jnp.where({n} == 1, 2 * {x}, "
+            f"jnp.where({n} == 2, 4 * {x}**2 - 2, "
+            f"jnp.where({n} == 3, 8 * {x}**3 - 12 * {x}, "
+            f"jnp.where({n} == 4, 16 * {x}**4 - 48 * {x}**2 + 12, "
+            f"jnp.where({n} == 5, 32 * {x}**5 - 160 * {x}**3 + 120 * {x}, "
+            f"jnp.zeros_like({x})))))))"  # Fallback for higher n
+        )
+
+    @staticmethod
+    def hermite_polynomial_he(x: str, n: str) -> str:
+        # Probabilist's Hermite polynomial He_n(x)
+        # He_0 = 1, He_1 = x, He_2 = x^2 - 1, He_3 = x^3 - 3x
+        return (
+            f"jnp.where({n} == 0, jnp.ones_like({x}), "
+            f"jnp.where({n} == 1, {x}, "
+            f"jnp.where({n} == 2, {x}**2 - 1, "
+            f"jnp.where({n} == 3, {x}**3 - 3 * {x}, "
+            f"jnp.where({n} == 4, {x}**4 - 6 * {x}**2 + 3, "
+            f"jnp.where({n} == 5, {x}**5 - 10 * {x}**3 + 15 * {x}, "
+            f"jnp.zeros_like({x})))))))"  # Fallback for higher n
+        )
+
+    @staticmethod
+    def laguerre_polynomial_l(x: str, n: str) -> str:
+        # Laguerre polynomial L_n(x)
+        # L_0 = 1, L_1 = 1 - x, L_2 = (x^2 - 4x + 2)/2, L_3 = (-x^3 + 9x^2 - 18x + 6)/6
+        return (
+            f"jnp.where({n} == 0, jnp.ones_like({x}), "
+            f"jnp.where({n} == 1, 1 - {x}, "
+            f"jnp.where({n} == 2, ({x}**2 - 4*{x} + 2) / 2, "
+            f"jnp.where({n} == 3, (-{x}**3 + 9*{x}**2 - 18*{x} + 6) / 6, "
+            f"jnp.where({n} == 4, ({x}**4 - 16*{x}**3 + 72*{x}**2 - 96*{x} + 24) / 24, "
+            f"jnp.where({n} == 5, (-{x}**5 + 25*{x}**4 - 200*{x}**3 + 600*{x}**2 - 600*{x} + 120) / 120, "
+            f"jnp.zeros_like({x})))))))"  # Fallback for higher n
+        )
+
+    @staticmethod
+    def legendre_polynomial_p(x: str, n: str) -> str:
+        # Legendre polynomial P_n(x)
+        # P_0 = 1, P_1 = x, P_2 = (3x^2 - 1)/2, P_3 = (5x^3 - 3x)/2
+        return (
+            f"jnp.where({n} == 0, jnp.ones_like({x}), "
+            f"jnp.where({n} == 1, {x}, "
+            f"jnp.where({n} == 2, (3 * {x}**2 - 1) / 2, "
+            f"jnp.where({n} == 3, (5 * {x}**3 - 3 * {x}) / 2, "
+            f"jnp.where({n} == 4, (35 * {x}**4 - 30 * {x}**2 + 3) / 8, "
+            f"jnp.where({n} == 5, (63 * {x}**5 - 70 * {x}**3 + 15 * {x}) / 8, "
+            f"jnp.zeros_like({x})))))))"  # Fallback for higher n
+        )
+
+    # Reciprocal and square
+    @staticmethod
+    def reciprocal(x: str) -> str:
+        return f"jnp.reciprocal({x})"
+
+    @staticmethod
+    def square(x: str) -> str:
+        return f"jnp.square({x})"
+
+    # Additional operations
+    @staticmethod
+    def fma(a: str, b: str, c: str) -> str:
+        """Fused multiply-add: a * b + c"""
+        return f"jnp.fma({a}, {b}, {c})"
+
+    @staticmethod
+    def copysign(a: str, b: str) -> str:
+        return f"jnp.copysign({a}, {b})"
+
+    @staticmethod
+    def nextafter(a: str, b: str) -> str:
+        return f"jnp.nextafter({a}, {b})"
+
+    @staticmethod
+    def ldexp(a: str, b: str) -> str:
+        return f"jnp.ldexp({a}, {b})"
+
+    @staticmethod
+    def frexp(x: str) -> str:
+        return f"jnp.frexp({x})"
+
+    @staticmethod
+    def modf(x: str) -> str:
+        return f"jnp.modf({x})"
+
+    # Bitwise operations
+    @staticmethod
+    def bitwise_and(a: str, b: str) -> str:
+        return f"jnp.bitwise_and({a}, {b})"
+
+    @staticmethod
+    def bitwise_or(a: str, b: str) -> str:
+        return f"jnp.bitwise_or({a}, {b})"
+
+    @staticmethod
+    def bitwise_xor(a: str, b: str) -> str:
+        return f"jnp.bitwise_xor({a}, {b})"
+
+    @staticmethod
+    def bitwise_not(x: str) -> str:
+        return f"jnp.bitwise_not({x})"
+
+    @staticmethod
+    def left_shift(a: str, b: str) -> str:
+        return f"jnp.left_shift({a}, {b})"
+
+    @staticmethod
+    def right_shift(a: str, b: str) -> str:
+        return f"jnp.right_shift({a}, {b})"
+
+
+class PallasKernel(SIMDKernel):
+    """
+    Pallas kernel for elementwise operations with support for strided/scatter access.
+
+    Strategy:
+    - Convert index expressions to JAX-compatible array slicing
+    - Load/store using indexed access: "in_ptrX[slice]" or full-array "in_ptrX[...]"
+    - Compute expression with Python operators (compatible with jax.numpy broadcasting)
+    - Generate Python code that defines a Pallas kernel and a host entrypoint.
+    - Use async_compile.pallas path to compile and load Python code.
+
+    For GPU (Triton backend):
+    - Use masked loads/stores with power-of-2 block sizes to handle non-power-of-2 shapes
+    """
+
+    overrides = PallasKernelOverrides  # type: ignore[assignment]
+    kexpr: Callable[[sympy.Expr], str] = pexpr  # Use Python expression printer
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # Determine device type once at initialization
+        device = V.graph.get_current_device_or_throw()
+        self.is_gpu = device.type == "cuda"
+        self.use_masked_ops: bool | None = None
+        self.tensor_masks = {}  # Map tensor name to mask variable name
+        # Track which output param each store uses: list of (out_ptr_name, store_line)
+        self.store_with_output: list[tuple[str, str]] = []
+        # Track load index expressions for argmax/argmin axis detection
+        self.load_index_exprs: dict[str, sympy.Expr] = {}
+
+    def check_bounds(
+        self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
+    ) -> None:
+        """Check array bounds for indirect indexing."""
+        # For now, skip explicit bounds checking as JAX/Pallas handles this internally
+        # TODO: Implement explicit bounds checking with assertions if needed
+
+    def _get_index_str(self, index: sympy.Expr) -> str:
+        """
+        Convert an index expression to a string suitable for Pallas indexing.
+
+        Pallas operates on full arrays, so we need to convert index expressions
+        to JAX array slicing. For example:
+        - x0 -> "..." (contiguous access, full array)
+        - 2*x0 -> "::2" (strided access with stride 2)
+        - 2*x0 + 1 -> "1::2" (strided access with offset 1, stride 2)
+
+        Args:
+            index: The indexing expression to convert
+
+        Returns:
+            The indexing string to use in generated code
+        """
+        # Prepare and simplify the index
+        prepared_index = self.prepare_indexing(index)
+
+        # For simple single-symbol access (contiguous case), we can use [...]
+        # which is more efficient as it operates on the entire array at once
+        if isinstance(prepared_index, sympy.Symbol):
+            return "..."
+        elif prepared_index.is_Integer:
+            # Scalar index
+            return str(prepared_index)
+        else:
+            # Complex expression (strided/scatter access)
+            # Try to extract stride and offset for common patterns
+            return self._convert_to_jax_slice(prepared_index)
+
+    def _convert_to_jax_slice(self, index: sympy.Expr) -> str:
+        """
+        Convert a sympy index expression to JAX slice notation.
+
+        Handles common patterns like:
+        - stride*var -> ::stride
+        - stride*var + offset -> offset::stride
+
+        For more complex patterns, falls back to explicit indexing.
+        Uses BlockPatternMatcher for robust pattern matching.
+        """
+        # Get the iteration variables for this kernel
+        if not self.range_trees:
+            return "..."
+
+        # Simplify the index
+        index = V.graph.sizevars.simplify(index)
+        free_symbols = index.free_symbols
+
+        # Get iteration variables from range_tree_nodes
+        iter_vars = OrderedSet(self.range_tree_nodes.keys())
+
+        # Find which iteration variable(s) are used
+        used_vars = free_symbols & iter_vars
+
+        if len(used_vars) == 0:
+            # No iteration variables, this is a constant index
+            return str(index)
+        elif len(used_vars) == 1:
+            # Single iteration variable - try to extract stride and offset using BlockPatternMatcher
+            var = next(iter(used_vars))
+
+            # Get the subexpression involving this variable
+            var_expr = BlockPatternMatcher.get_subexpr_involving_symbol(index, var)
+
+            # Try to match affine pattern: stride * var
+            stride = BlockPatternMatcher.match_affine_block_expr(var_expr, var)
+
+            if stride is not None:
+                # Extract the constant offset (terms not involving var)
+                offset = index - var_expr
+                offset = V.graph.sizevars.simplify(offset)
+
+                # Generate JAX slice notation
+                if stride == 1 and offset == 0:
+                    # Contiguous access
+                    return "..."
+                elif offset == 0:
+                    # Pure stride: ::stride
+                    stride_str = self.kexpr(stride)
+                    return f"::{stride_str}"
+                else:
+                    # Offset + stride: offset::stride
+                    offset_str = self.kexpr(offset)
+                    stride_str = self.kexpr(stride)
+                    return f"{offset_str}::{stride_str}"
+            else:
+                # Couldn't match affine pattern, fall back to original logic
+                offset = index - var_expr
+                offset = V.graph.sizevars.simplify(offset)
+                if offset == 0 and var_expr == var:
+                    # Just the variable itself, unit stride
+                    return "..."
+        elif len(used_vars) > 1:
+            # Multi-dimensional indexing
+            # For contiguous multi-dim access, all terms should have unit stride
+            all_unit_stride = True
+            for var in used_vars:
+                var_expr = BlockPatternMatcher.get_subexpr_involving_symbol(index, var)
+                stride = BlockPatternMatcher.match_affine_block_expr(var_expr, var)
+                if stride != 1:
+                    all_unit_stride = False
+                    break
+
+            if all_unit_stride:
+                # Contiguous multi-dimensional access
+                return "..."
+            else:
+                # Strided multi-dimensional access - requires advanced indexing
+                # For now, use ellipsis which may work for many cases
+                # TODO: Implement proper multi-dimensional strided indexing
+                return "..."
+
+        # For complex cases, raise an error
+        return self._generate_index_array(index)
+
+    def _generate_index_array(self, index: sympy.Expr) -> str:
+        """
+        Generate JAX code to compute an index array for complex indexing patterns.
+
+        For very complex patterns that can't be expressed as simple slices,
+        we need to compute the indices explicitly. This is not yet fully implemented.
+        """
+        # For now, raise an error for complex patterns
+        # TODO: Implement advanced indexing support
+        raise Unsupported(
+            f"Pallas backend does not yet support complex indexing pattern: {index}"
+        )
+
+    def _has_iteration_vars(self, index: sympy.Expr) -> bool:
+        """Check if index expression contains iteration variables (x0, x1, etc.)."""
+        free_symbols = index.free_symbols
+        iter_vars = OrderedSet(self.range_tree_nodes.keys())
+        return bool(free_symbols & iter_vars)
+
+    def _has_indirect_vars(self, index: sympy.Expr) -> bool:
+        """Check if index expression contains indirect variables (tmp0, tmp1, etc.)."""
+        free_symbols = index.free_symbols
+        for sym in free_symbols:
+            if str(sym).startswith("tmp"):
+                return True
+        return False
+
+    def _get_index_expr(self, index: sympy.Expr) -> tuple[str, bool]:
+        """
+        Get the index expression string and whether it needs flattening.
+
+        Returns:
+            Tuple of (index_str, needs_flatten) where needs_flatten indicates
+            if the buffer should be flattened before indexing (for mixed indexing).
+        """
+        has_indirect = self._has_indirect_vars(index)
+        has_iter_vars = self._has_iteration_vars(index)
+
+        if has_indirect and has_iter_vars:
+            return self._handle_mixed_indexing(index), True
+        elif has_indirect:
+            return self.kexpr(index), False
+        else:
+            return self._get_index_str(index), False
+
+    def _determine_masked_ops_for_kernel(self) -> bool:
+        """
+        Determine if we should use masked ops for this entire kernel.
+
+        Masked ops with pl.ds(block_size) flatten tensors to 1D, which works when:
+        1. We're on GPU (CUDA backend uses Triton which requires power-of-2 sizes)
+        2. All tensors are already 1D (so flattening doesn't change dimensionality)
+        3. All tensors have the same size (so broadcasting works correctly)
+
+        With per-tensor masks, each tensor gets its own mask based on its size.
+
+        This should be called once in codegen_kernel() before generating the kernel body.
+        """
+        if not self.is_gpu:
+            return False
+
+        # Get all buffer sizes
+        # We need ALL buffers - inputs, outputs, and intermediates
+        all_buffer_names = OrderedSet()
+
+        # Get input buffers from args
+        all_buffer_names.update(self.args.input_buffers.keys())
+        # Get output buffers from args
+        all_buffer_names.update(self.args.output_buffers.keys())
+        # Also get any intermediate buffers from the graph
+        all_buffer_names.update(V.graph.name_to_buffer.keys())
+
+        # Get shapes and sizes for all buffers
+        buf_info = []
+        for buf_name in all_buffer_names:
+            try:
+                buf = V.graph.get_buffer(buf_name)
+                size = buf.get_size()
+                shape = tuple(int(s) if hasattr(s, "__int__") else s for s in size)
+                # Calculate flattened size
+                total_size = 1
+                for s in size:
+                    if hasattr(s, "__int__"):
+                        total_size *= int(s)
+                    else:
+                        total_size *= s
+                buf_info.append((buf_name, shape, total_size))
+            except Exception:
+                pass
+
+        # Only use masked ops if:
+        # 1. All buffers are 1D (single-element shape tuples)
+        # 2. All buffers have the same size
+        # This ensures that pl.ds(block_size) flattening works correctly
+        # and masks can be properly applied without broadcasting issues.
+        if buf_info and len(buf_info) > 0:
+            # Check if all are 1D
+            all_1d = all(len(shape) == 1 for _, shape, _ in buf_info)
+            if not all_1d:
+                return False
+
+            # Check if all have the same size
+            first_size = buf_info[0][2]
+            all_same_size = all(size == first_size for _, _, size in buf_info)
+            return all_same_size
+
+        return False
+
+    def _get_or_create_mask(self, buf_name: str) -> str:
+        """Get or create a unique mask variable for a buffer."""
+        if buf_name not in self.tensor_masks:
+            mask_var = f"mask_{buf_name}"
+            self.tensor_masks[buf_name] = mask_var
+        return self.tensor_masks[buf_name]
+
+    def load(self, name: str, index: sympy.Expr) -> CSEVariable:  # type: ignore[override]
+        buf = self.args.input(name)
+        dtype = V.graph.get_dtype(name)
+
+        # Track the load index expression for argmax/argmin axis detection
+        self.load_index_exprs[name] = index
+
+        # Determine masked ops strategy on first load/store if not yet determined
+        if self.use_masked_ops is None:
+            self.use_masked_ops = self._determine_masked_ops_for_kernel()
+
+        index_str, needs_flatten = self._get_index_expr(index)
+
+        # Build load expression using string concatenation
+        use_masked = index_str == "..." and not needs_flatten and self.use_masked_ops
+
+        if use_masked:
+            # GPU masked load: flatten tensor and apply per-tensor mask
+            mask_var = self._get_or_create_mask(name)
+            load_expr = f"pltriton.load({buf}.at[pl.ds(block_size)], mask={mask_var})"
+        elif needs_flatten:
+            # Flatten then index for non-contiguous access
+            load_expr = f"{buf}[...].flatten()[{index_str}]"
+        else:
+            # Direct indexing for contiguous access
+            load_expr = f"{buf}[{index_str}]"
+
+        return self.cse.generate(
+            self.compute,
+            load_expr,
+            dtype=dtype,
+        )
+
+    def _handle_mixed_indexing(self, index: sympy.Expr) -> str:
+        """
+        Handle indexing with both indirect variables and iteration variables.
+
+        For example, x[indices, :] generates index = i0 + stride * tmp0
+        where tmp0 is loaded from indices and i0 is the iteration variable.
+
+        We need to convert this to JAX advanced indexing with proper broadcasting.
+        When there are multiple iteration variables, they need different shapes
+        to form an outer product (grid) rather than broadcasting together.
+        """
+        # Get iteration variables
+        iter_vars = OrderedSet(self.range_tree_nodes.keys())
+        free_symbols = index.free_symbols
+        used_iter_vars_set = free_symbols & iter_vars
+
+        if len(used_iter_vars_set) == 0:
+            return self.kexpr(index)
+
+        # Sort iteration variables by their coefficient (stride) in the index expression.
+        # Variables with larger strides correspond to earlier output dimensions.
+        def get_coefficient(var):
+            """Extract the coefficient of a variable in the index expression."""
+            coeff = index.coeff(var)
+            if coeff == 0:
+                # Variable appears in a more complex form, try differentiation
+                coeff = sympy.diff(index, var)
+            # Convert to int if possible for sorting
+            try:
+                return int(coeff)
+            except (TypeError, ValueError):
+                return 0
+
+        used_iter_vars = sorted(used_iter_vars_set, key=get_coefficient, reverse=True)
+        iter_coeffs = [get_coefficient(var) for var in used_iter_vars]
+
+        index_str = self.kexpr(index)
+        indirect_var_syms = [s for s in free_symbols if str(s).startswith("tmp")]
+        indirect_vars = [str(sym) for sym in indirect_var_syms]
+
+        # Get coefficients for indirect vars to determine output ordering
+        indirect_coeffs = {str(s): get_coefficient(s) for s in indirect_var_syms}
+
+        # Build a sorted list of all components by coefficient (descending)
+        # Each component is (coeff, type, var) where type is 'iter' or 'indirect'
+        all_components = []
+        for var in used_iter_vars:
+            all_components.append((get_coefficient(var), "iter", var))
+        for sym in indirect_var_syms:
+            all_components.append((get_coefficient(sym), "indirect", sym))
+        all_components.sort(key=lambda x: x[0], reverse=True)
+
+        # Calculate trailing dims needed for each component
+        # Each component needs trailing dims for all subsequent iter vars
+        # plus trailing dims for all dimensions of subsequent indirect vars
+        # For simplicity, assume each indirect var contributes some dimensions
+        # that will be handled by the reshape at store time
+
+        # For iter vars, we need to count how many dimensions come after in the output
+        for i, var in enumerate(used_iter_vars):
+            var_name = str(var)
+            if var in self.range_tree_nodes:
+                range_entry = self.range_tree_nodes[var]
+                range_size = range_entry.length
+                var_coeff = get_coefficient(var)
+
+                arange_expr = f"jnp.arange({self.kexpr(range_size)})"
+
+                # Count trailing dims needed:
+                # - One for each subsequent iter var (with smaller coeff)
+                # - One for each dimension of indirect vars with smaller coeff
+                # For indirect vars, assume each contributes 2 dims (common case)
+                # The actual reshape at store time will fix any shape mismatches
+                n_trailing_iter = sum(1 for c in iter_coeffs if c < var_coeff)
+                n_trailing_indirect = sum(
+                    2 for c in indirect_coeffs.values() if c < var_coeff
+                )
+                n_trailing = n_trailing_iter + n_trailing_indirect
+
+                if n_trailing > 0:
+                    trailing_dims = ", None" * n_trailing
+                    arange_expr = f"{arange_expr}[:{trailing_dims}]"
+
+                index_str = index_str.replace(var_name, arange_expr)
+
+        # Reshape indirect variables for proper broadcasting.
+        for indirect_var in indirect_vars:
+            indirect_coeff = indirect_coeffs[indirect_var]
+
+            # Count dims needed before and after this indirect var
+            n_leading = sum(1 for c in iter_coeffs if c > indirect_coeff)
+            n_trailing = sum(1 for c in iter_coeffs if c < indirect_coeff)
+
+            # Build the indexing expression with leading Nones, ellipsis, trailing Nones
+            if n_leading > 0 and n_trailing > 0:
+                leading_nones = "None, " * n_leading
+                trailing_nones = ", None" * n_trailing
+                reshape_expr = f"{indirect_var}[{leading_nones}...{trailing_nones}]"
+            elif n_leading > 0:
+                leading_nones = "None, " * n_leading
+                reshape_expr = f"{indirect_var}[{leading_nones}...]"
+            elif n_trailing > 0:
+                trailing_nones = ", None" * n_trailing
+                reshape_expr = f"{indirect_var}[...{trailing_nones}]"
+            else:
+                reshape_expr = indirect_var
+
+            index_str = index_str.replace(indirect_var, reshape_expr)
+
+        return index_str
+
+    def store(
+        self, name: str, index: sympy.Expr, value: CSEVariable, mode: Any = None
+    ) -> None:  # type: ignore[override]
+        if mode is not None:
+            raise Unsupported("pallas store mode not supported")
+        out = self.args.output(name)
+        self.store_buffer_names.add(name)
+
+        # Determine masked ops strategy on first load/store if not yet determined
+        if self.use_masked_ops is None:
+            self.use_masked_ops = self._determine_masked_ops_for_kernel()
+
+        # Check if this is a scalar output (reduction to scalar)
+        # Only shape () is a true scalar, not (1,) which is a 1-element tensor
+        try:
+            buf = V.graph.get_buffer(name)
+            output_shape = buf.get_size()
+            is_scalar = len(output_shape) == 0
+        except Exception:
+            output_shape = ()
+            is_scalar = False
+
+        if is_scalar:
+            # For scalar outputs, use [...] to assign the entire scalar
+            store_expr = f"{out}[...] = {value}"
+        else:
+            index_str, needs_flatten = self._get_index_expr(index)
+
+            # Build store expression using string concatenation
+            use_masked = (
+                index_str == "..." and not needs_flatten and self.use_masked_ops
+            )
+
+            if use_masked:
+                # GPU masked store: flatten tensor and apply per-tensor mask
+                mask_var = self._get_or_create_mask(name)
+                store_expr = f"pltriton.store({out}.at[pl.ds(block_size)], {value}, mask={mask_var})"
+            elif index_str == "...":
+                # When storing the full array, reshape to match the output shape.
+                # This handles:
+                # - Mixed indexing producing flat results needing reshape
+                # - Squeeze operations where value has more dims than output
+                # - If shapes already match, reshape is a no-op.
+                # Use the output array's shape at runtime to avoid issues with
+                # symbolic sizes not being defined in the kernel.
+                store_expr = f"{out}[...] = {value}.reshape({out}.shape)"
+            else:
+                # Direct indexed assignment
+                store_expr = f"{out}[{index_str}] = {value}"
+
+        self.stores.writeline(store_expr)
+        # Track which output param this store uses for filtering in codegen_kernel
+        self.store_with_output.append((out, store_expr))
+
+    def reduction(
+        self,
+        dtype: torch.dtype,
+        src_dtype: torch.dtype,
+        reduction_type: ReductionType,
+        value: Union[CSEVariable, tuple[CSEVariable, ...]],
+    ) -> Union[CSEVariable, tuple[CSEVariable, ...]]:  # type: ignore[override]
+        """
+        Generate code for reduction operations in JAX/Pallas.
+
+        Reductions in Pallas work by:
+        1. Loading the input data into the kernel
+        2. Applying JAX reduction operations (jnp.sum, jnp.max, etc.)
+        3. Storing the reduced result
+
+        The reduction happens over the loaded block of data.
+        """
+        assert self.inside_reduction
+
+        if isinstance(value, tuple):
+            raise Unsupported(
+                "Tuple reductions (e.g., welford_combine) not supported in Pallas backend"
+            )
+
+        # Check if this reduction is already cached
+        cache_key = (src_dtype, reduction_type, value)
+        if cache_key in self.cse.reduction_cache:
+            return self.cse.reduction_cache[cache_key]
+
+        # Map reduction types to JAX functions
+        reduction_ops = {
+            "sum": "jnp.sum",
+            "prod": "jnp.prod",  # CPU only - not supported in Pallas GPU (Triton) backend
+            "max": "jnp.max",
+            "min": "jnp.min",
+            "any": "jnp.any",
+            "argmax": "jnp.argmax",
+            "argmin": "jnp.argmin",
+        }
+
+        # Determine if this is a partial reduction (has pointwise dimensions)
+        # or a full reduction to scalar
+        pointwise_prefixes = OrderedSet(["x", "y", "z"])
+        has_pointwise = any(p in self.numels for p in pointwise_prefixes)
+
+        # Get the individual pointwise dimension sizes from range_tree_nodes
+        pointwise_sizes = []
+        for var, entry in sorted(
+            self.range_tree_nodes.items(), key=lambda x: str(x[0])
+        ):
+            if not entry.prefix.startswith("r"):
+                try:
+                    pointwise_sizes.append(int(entry.length))
+                except (TypeError, ValueError):
+                    pointwise_sizes = None
+                    break
+
+        # Get the pointwise and reduction numels
+        pointwise_numel = 1
+        for p in pointwise_prefixes:
+            if p in self.numels:
+                numel = self.numels[p]
+                try:
+                    pointwise_numel *= int(numel)
+                except (TypeError, ValueError):
+                    pointwise_numel = None
+                    break
+
+        reduction_numel = 1
+        for p in self.numels:
+            if p.startswith("r"):
+                numel = self.numels[p]
+                try:
+                    reduction_numel *= int(numel)
+                except (TypeError, ValueError):
+                    reduction_numel = None
+                    break
+
+        # Count the number of pointwise and reduction dimensions
+        n_reduction_dims = sum(
+            1
+            for var, entry in self.range_tree_nodes.items()
+            if entry.prefix.startswith("r")
+        )
+
+        if reduction_type == "xor_sum":
+            if has_pointwise and pointwise_numel and reduction_numel:
+                reduction_expr = f"jnp.bitwise_xor.reduce({value}.reshape({pointwise_numel}, -1), axis=-1)"
+            else:
+                reduction_expr = f"jnp.bitwise_xor.reduce({value})"
+        elif reduction_type in ("argmax", "argmin"):
+            # For argmax/argmin, we need to preserve the axis information
+            # because the result is indices, not values.
+            reduction_op = reduction_ops[reduction_type]
+            # Check if this is a true partial reduction (pointwise numel > 1)
+            # When pointwise_numel == 1, it's effectively a full reduction to scalar
+            is_partial_reduction = (
+                has_pointwise and pointwise_numel and pointwise_numel > 1
+            )
+            if is_partial_reduction and n_reduction_dims > 0:
+                # Partial reduction: determine the reduction axis from load index
+                # The reduction variable's coefficient in the index expression tells us its stride
+                # Higher stride = outer axis (lower axis number in row-major order)
+                reduction_axis = 0  # Default to axis 0
+                if self.load_index_exprs:
+                    # Get the first load index expression
+                    load_index = next(iter(self.load_index_exprs.values()))
+                    # Find the reduction variable (starts with 'r')
+                    reduction_vars = [
+                        var
+                        for var, entry in self.range_tree_nodes.items()
+                        if entry.prefix.startswith("r")
+                    ]
+                    if reduction_vars:
+                        r_var = reduction_vars[0]
+                        # Get the coefficient (stride) of the reduction variable
+                        r_coeff = load_index.coeff(r_var)
+                        try:
+                            r_stride = int(r_coeff) if r_coeff != 0 else 1
+                        except (TypeError, ValueError):
+                            r_stride = 1
+                        # Get pointwise variable
+                        pw_vars = [
+                            var
+                            for var, entry in self.range_tree_nodes.items()
+                            if not entry.prefix.startswith("r")
+                        ]
+                        if pw_vars:
+                            pw_var = pw_vars[0]
+                            pw_coeff = load_index.coeff(pw_var)
+                            try:
+                                pw_stride = int(pw_coeff) if pw_coeff != 0 else 1
+                            except (TypeError, ValueError):
+                                pw_stride = 1
+                            # Higher stride = earlier (outer) axis
+                            # For 2D: axis 0 has stride = dim1_size, axis 1 has stride = 1
+                            reduction_axis = 0 if r_stride > pw_stride else 1
+                if n_reduction_dims == 1:
+                    reduction_expr = f"{reduction_op}({value}, axis={reduction_axis})"
+                else:
+                    # Multiple reduction dims - reduce over all of them
+                    axes = tuple(range(n_reduction_dims))
+                    reduction_expr = f"{reduction_op}({value}, axis={axes})"
+            else:
+                # Full reduction to scalar
+                reduction_expr = f"{reduction_op}({value})"
+        elif reduction_type in reduction_ops:
+            if (
+                has_pointwise
+                and pointwise_numel
+                and reduction_numel
+                and pointwise_sizes
+            ):
+                # For partial reductions, we need to:
+                # 1. Move pointwise axes to the front and reduction axes to the back
+                # 2. Reshape to (pointwise_numel, reduction_numel)
+                # 3. Reduce over the last axis
+                #
+                # We use moveaxis to reorder: first move axes matching pointwise sizes
+                # to the front, then the remaining (reduction) axes go to the back.
+                # Finally reshape and reduce.
+                #
+                # Generate code to dynamically determine and reorder axes:
+                pw_sizes_str = str(pointwise_sizes)
+                reduction_op = reduction_ops[reduction_type]
+                reduction_expr = (
+                    f"(lambda v: (lambda pw_sizes: "
+                    f"{reduction_op}(v.reshape(-1, {reduction_numel}), axis=-1) "
+                    f"if v.ndim == 2 else "
+                    f"(lambda input_shape, pw_axes: "
+                    f"{reduction_op}("
+                    f"jnp.moveaxis(v, pw_axes, list(range(len(pw_axes)))).reshape({pointwise_numel}, -1), axis=-1)"
+                    f")("
+                    f"v.shape, "
+                    f"[i for i, s in enumerate(v.shape) if s in pw_sizes][:len(pw_sizes)]"
+                    f")"
+                    f")({pw_sizes_str}))({value})"
+                )
+            else:
+                # Full reduction to scalar
+                reduction_expr = f"{reduction_ops[reduction_type]}({value})"
+        else:
+            raise Unsupported(
+                f"Reduction type '{reduction_type}' not yet supported in Pallas backend. "
+                f"Supported types: {list(reduction_ops.keys())}, xor_sum"
+            )
+
+        # Generate CSE variable for the reduction result
+        result = self.cse.generate(
+            self.compute,
+            reduction_expr,
+            dtype=dtype,
+        )
+
+        # Cache the result
+        self.cse.reduction_cache[cache_key] = result
+        return result
+
+    @staticmethod
+    def _buffer_is_contiguous(buffer_name: str) -> bool:
+        buf = V.graph.get_buffer(buffer_name)
+        layout = buf.get_layout()
+        return layout.is_contiguous()
+
+    def codegen_kernel(self, name: Optional[str] = None) -> str:  # type: ignore[override]
+        """
+        Generate the complete Pallas kernel code as a Python string.
+
+        This includes:
+        - Import statements for JAX/Pallas
+        - The kernel function that operates on refs
+        - The main wrapper function that handles PyTorch<->JAX conversions via DLPack
+
+        Args:
+            name: Optional kernel name (will use placeholder if not provided)
+
+        Returns:
+            str: Complete Python source code for the Pallas kernel
+        """
+        code = IndentedBuffer()
+
+        # Define the Pallas kernel: accepts refs, uses broadcasted expressions
+        arg_defs, _, _, _ = self.args.python_argdefs()
+        kernel_params = [a.name for a in arg_defs]
+        pure_out_params = [p for p in kernel_params if p.startswith("out_ptr")]
+        output_params = [
+            p for p in kernel_params if p.startswith(("out_ptr", "in_out_ptr"))
+        ]
+        if not output_params:
+            raise RuntimeError("Pallas backend requires at least one output buffer")
+
+        output_buffer_lookup = {
+            inner: outer
+            for outer, inner in self.args.output_buffers.items()
+            if isinstance(inner, str)
+        }
+
+        kernel_name = name or ""
+        interpret_is_cpu = V.graph.get_current_device_or_throw().type == "cpu"
+        is_tpu = torch._inductor.config._debug_cpu_to_tpu_pallas
+        if is_tpu:
+            if not torch._inductor.config.pallas_take_first_jax_device_only:
+                raise RuntimeError(
+                    "Pallas backend currently only supports using the first JAX device."
+                )
+            if not has_tpu_pallas():
+                raise RuntimeError(
+                    "PALLAS_TARGET_TPU is set, but no TPU device was found. "
+                    "Please make sure that you have a TPU available and that JAX is configured correctly."
+                )
+        interpret_literal = "True" if interpret_is_cpu else "False"
+
+        # For GPU (Triton backend), import pltriton for masked loads/stores
+        # Import math at module level if we'll use it for masked ops
+        imports = (
+            """
+            import functools
+            """
+            + ("import math\n            " if self.use_masked_ops else "")
+            + """import torch
+            import jax
+            import jax.numpy as jnp
+            from jax.experimental import pallas as pl
+            from torch._inductor.runtime.runtime_utils import torch_dtype_to_jax_runtime
+            """
+            + (
+                "\n            from jax.experimental.pallas import triton as pltriton"
+                if not interpret_is_cpu
+                else ""
+            )
+            + (
+                "\n            from torch._inductor.runtime.runtime_utils import next_power_of_2"
+                if self.use_masked_ops
+                else ""
+            )
+        )
+        code.splice(imports, strip=True)
+
+        aliasable_flags: dict[str, bool] = {}
+        for param in pure_out_params:
+            buffer_name = output_buffer_lookup.get(param)
+            is_contiguous = buffer_name is not None and self._buffer_is_contiguous(
+                buffer_name
+            )
+            aliasable_flags[param] = (not interpret_is_cpu) and is_contiguous
+        alias_params = [
+            f"{param}_alias" for param in pure_out_params if aliasable_flags[param]
+        ]
+        pointer_tail = [
+            p for p in kernel_params if p.startswith(("in_out_ptr", "in_ptr"))
+        ]
+        kernel_input_params = alias_params + pointer_tail
+        full_kernel_params = alias_params + kernel_params
+        non_alias_out_set = OrderedSet(
+            [name for name, flag in aliasable_flags.items() if not flag]
+        )
+        copy_output_indices = [
+            idx for idx, name in enumerate(output_params) if name in non_alias_out_set
+        ]
+        self.aliasable_out_ptrs = aliasable_flags
+
+        # For GPU with masked ops, add block_size as keyword-only parameter
+        kernel_signature = (
+            f"def {kernel_name}_kernel({', '.join(full_kernel_params)}"
+            + (", *, block_size" if self.use_masked_ops else "")
+            + "):"
+        )
+        code.writeline(kernel_signature)
+        with code.indent():
+            # For masked ops on GPU, generate per-tensor masks at the start
+            if self.use_masked_ops and self.tensor_masks:
+                # Create a mapping from buffer name to parameter name
+                buf_to_param = {}
+                for outer, inner in self.args.input_buffers.items():
+                    buf_to_param[outer] = inner if isinstance(inner, str) else outer
+                for outer, inner in self.args.output_buffers.items():
+                    buf_to_param[outer] = inner if isinstance(inner, str) else outer
+
+                # Generate a mask for each tensor that was accessed
+                for buf_name, mask_var in sorted(self.tensor_masks.items()):
+                    param_name = buf_to_param.get(buf_name, buf_name)
+                    # Find the corresponding parameter in kernel_params
+                    matching_param = None
+                    for p in kernel_params:
+                        # Check if this parameter corresponds to the buffer
+                        if param_name == p or buf_name in str(p):
+                            matching_param = p
+                            break
+
+                    if matching_param:
+                        # Calculate flattened size for this tensor
+                        code.writeline(f"# Mask for {buf_name}")
+                        code.writeline(f"{mask_var}_size = {matching_param}.size")
+                        code.writeline(
+                            f"{mask_var} = jnp.arange(block_size) < {mask_var}_size"
+                        )
+
+            # Generate iteration variables as jnp.arange arrays
+            # These are used by index_expr operations like torch.arange
+            # Skip on GPU with masked ops - iteration vars would create non-power-of-2 arrays
+            # which are not supported by Pallas Triton backend
+            if self.range_tree_nodes and not self.use_masked_ops:
+                code.writeline("# Define iteration variables as JAX arrays")
+                # Get the first output buffer's shape for reshaping
+                first_output_shape = None
+                first_output_numel = None
+                if output_params:
+                    first_out_param = output_params[0]
+                    first_out_buf_name = output_buffer_lookup.get(first_out_param)
+                    if first_out_buf_name:
+                        try:
+                            buf = V.graph.get_buffer(first_out_buf_name)
+                            size = buf.get_size()
+                            first_output_shape = tuple(
+                                int(s) if hasattr(s, "__int__") else s for s in size
+                            )
+                            first_output_numel = 1
+                            for s in first_output_shape:
+                                first_output_numel *= s
+                        except Exception:
+                            pass
+
+                for var_sym, entry in self.range_tree_nodes.items():
+                    var_name = str(var_sym)
+                    length = entry.length
+                    length_str = self.kexpr(length)
+                    # If the iteration variable length matches the output numel,
+                    # reshape it to match the output shape for proper broadcasting
+                    try:
+                        length_val = int(length) if hasattr(length, "__int__") else None
+                    except (TypeError, ValueError):
+                        length_val = None
+
+                    # Skip symbolic lengths - jnp.arange requires concrete values
+                    # This happens with dynamic shapes
+                    if length_val is None:
+                        continue
+
+                    if (
+                        first_output_shape
+                        and len(first_output_shape) > 1
+                        and length_val == first_output_numel
+                    ):
+                        shape_str = ", ".join(str(s) for s in first_output_shape)
+                        code.writeline(
+                            f"{var_name} = jnp.arange({length_str}).reshape({shape_str})"
+                        )
+                    else:
+                        code.writeline(f"{var_name} = jnp.arange({length_str})")
+
+            # Emit compute (CSE) and store lines; they reference *_ptr[index] directly.
+            for line in self.compute._lines:
+                code.writeline(str(line))
+            # Filter stores to only emit those for outputs that are in kernel params.
+            # This handles cases where an intermediate value was stored but the buffer
+            # was later optimized away (not passed to the kernel).
+            for out_ptr, store_line in self.store_with_output:
+                if out_ptr in full_kernel_params:
+                    code.writeline(store_line)
+
+        jit_wrapper_name = f"{kernel_name}_jit_wrapper"
+        donate_indices = []
+        for idx, name in enumerate(kernel_input_params):
+            if (name in alias_params) or name.startswith("in_out_ptr"):
+                donate_indices.append(idx + 2)
+        if donate_indices:
+            donate_literal = "(" + ", ".join(str(x) for x in donate_indices) + ",)"
+        else:
+            donate_literal = "()"
+        code.writeline(
+            "@functools.partial("
+            "jax.jit, static_argnums=(0, 1), donate_argnums="
+            f"{donate_literal})"
+        )
+        code.writeline(
+            f"def {jit_wrapper_name}(out_shapes, out_dtypes, {', '.join(kernel_input_params)}):"
+        )
+        with code.indent():
+            code.writeline("out_specs = tuple(")
+            code.writeline("    jax.ShapeDtypeStruct(shape, dtype)")
+            code.writeline("    for shape, dtype in zip(out_shapes, out_dtypes)")
+            code.writeline(")")
+
+            # For masked ops, calculate block_size as next power of 2 of max flattened size
+            if self.use_masked_ops:
+                code.writeline(
+                    "# Calculate block_size as next power of 2 for Triton backend"
+                )
+                code.writeline("# Find maximum flattened size across all tensors")
+                code.writeline("max_size = 0")
+                # Calculate size for all input tensors
+                for param in kernel_input_params:
+                    code.writeline(f"max_size = max(max_size, {param}.size)")
+                # Also consider output shapes
+                code.writeline("for shape in out_shapes:")
+                code.writeline(
+                    "    tensor_size = shape[0] if len(shape) == 1 else math.prod(shape)"
+                )
+                code.writeline("    max_size = max(max_size, tensor_size)")
+                code.writeline("block_size = next_power_of_2(max_size)")
+
+            alias_pairs: list[tuple[int, int]] = []
+            for out_idx, name in enumerate(output_params):
+                if name.startswith("out_ptr"):
+                    if aliasable_flags.get(name, False):
+                        alias_name = f"{name}_alias"
+                        input_idx = kernel_input_params.index(alias_name)
+                        alias_pairs.append((input_idx, out_idx))
+                else:
+                    input_idx = kernel_input_params.index(name)
+                    alias_pairs.append((input_idx, out_idx))
+            alias_map_literal = ", ".join(f"{i}: {o}" for (i, o) in alias_pairs)
+
+            # For masked ops, wrap kernel with functools.partial to pass block_size
+            kernel_arg = (
+                f"functools.partial({kernel_name}_kernel, block_size=block_size),"
+                if self.use_masked_ops
+                else f"{kernel_name}_kernel,"
+            )
+            code.writeline("return pl.pallas_call(")
+            code.writeline("    " + kernel_arg)
+
+            code.writeline("    out_shape=out_specs,")
+            code.writeline(f"    interpret={interpret_literal},")
+            code.writeline("    grid=(1,),")
+            code.writeline(
+                f"    input_output_aliases={{ {alias_map_literal} }},"
+                if alias_pairs
+                else "    input_output_aliases={},"
+            )
+            code.writeline(")(")
+            if kernel_input_params:
+                code.writeline(f"    {', '.join(kernel_input_params)},")
+            code.writeline(")")
+
+        main_name = f"{kernel_name}_main"
+        code.writeline(
+            f"def {main_name}({', '.join(full_kernel_params)}, stream=None):"
+        )
+        with code.indent():
+            code.writeline("# Enable JAX x64 mode for float64/int64 support")
+            code.writeline("jax.config.update('jax_enable_x64', True)")
+            if alias_params:
+                code.writeline("# Convert Torch -> JAX for donated outputs")
+                for alias_name in alias_params:
+                    # TODO: The `jax.device_put` path is a temporary workaround for a Mosaic compiler bug
+                    # that occurs with DLPack. Once TorchTPU provides a direct method for placing a
+                    # `torch.Tensor` on a TPU device, this should be reverted to use the
+                    #  `jax.dlpack.from_dlpack` path.
+                    if is_tpu:
+                        code.writeline(
+                            f"{alias_name}_jax = jax.device_put({alias_name}.cpu().numpy(), device=jax.devices('tpu')[0])"
+                        )
+                    else:
+                        code.writeline(
+                            f"{alias_name}_jax = jax.dlpack.from_dlpack({alias_name}.detach())"
+                        )
+            code.writeline("# Convert Torch -> JAX for in-place tensors")
+            for ptr in pointer_tail:
+                if ptr.startswith("in_out_ptr"):
+                    if is_tpu:
+                        code.writeline(
+                            f"{ptr}_jax = jax.device_put({ptr}.cpu().numpy(), device=jax.devices('tpu')[0])"
+                        )
+                    else:
+                        code.writeline(
+                            f"{ptr}_jax = jax.dlpack.from_dlpack({ptr}.detach())"
+                        )
+            code.writeline("# Convert Torch -> JAX for inputs")
+            for ptr in pointer_tail:
+                if ptr.startswith("in_ptr"):
+                    if is_tpu:
+                        code.writeline(
+                            f"{ptr}_jax = jax.device_put({ptr}.cpu().numpy(), device=jax.devices('tpu')[0])"
+                        )
+                    else:
+                        code.writeline(
+                            f"{ptr}_jax = jax.dlpack.from_dlpack({ptr}.detach().contiguous())"
+                        )
+
+            code.writeline("# Prepare output metadata from PyTorch tensor")
+            code.writeline(
+                "out_shapes = ("
+                + ", ".join([f"tuple({name}.shape)" for name in output_params])
+                + ",)"
+            )
+            code.writeline(
+                "out_dtypes = ("
+                + ", ".join(
+                    [
+                        f"torch_dtype_to_jax_runtime({name}.dtype)"
+                        for name in output_params
+                    ]
+                )
+                + ",)"
+            )
+            arg_name_map: dict[str, str] = {}
+            for alias_name in alias_params:
+                arg_name_map[alias_name] = f"{alias_name}_jax"
+            for ptr in pointer_tail:
+                arg_name_map[ptr] = f"{ptr}_jax"
+
+            if kernel_input_params:
+                alias_args_str = ", ".join(
+                    arg_name_map[name] for name in kernel_input_params
+                )
+                code.writeline(
+                    f"res = {jit_wrapper_name}(out_shapes, out_dtypes, {alias_args_str})"
+                )
+            else:
+                code.writeline(f"res = {jit_wrapper_name}(out_shapes, out_dtypes)")
+            if copy_output_indices:
+                code.writeline(
+                    "result_values = res if isinstance(res, tuple) else (res,)"
+                )
+                for idx in copy_output_indices:
+                    name = output_params[idx]
+                    if is_tpu:
+                        code.writeline(
+                            f"res_cpu = jax.device_get(result_values[{idx}])"
+                        )
+                        code.writeline(f"{name}.copy_(torch.from_dlpack(res_cpu))")
+                    else:
+                        code.writeline(
+                            f"{name}.copy_(torch.from_dlpack(result_values[{idx}]))"
+                        )
+
+        return code.getvalue()
+
+    def call_kernel(self, name: str, node: Optional[IRNode] = None) -> None:  # type: ignore[override]
+        """Generate the Python code that calls this Pallas kernel."""
+        wrapper = V.graph.wrapper_code
+        arg_defs, call_args, _, _ = self.args.python_argdefs()
+        kernel_param_names = [a.name for a in arg_defs]
+        pure_out_params = [p for p in kernel_param_names if p.startswith("out_ptr")]
+        call_arg_strs = list(map(str, call_args))
+        aliasable = getattr(self, "aliasable_out_ptrs", {})
+        alias_call_args = [
+            call_arg_strs[kernel_param_names.index(p)]
+            for p in pure_out_params
+            if aliasable.get(p, False)
+        ]
+
+        # Generate kernel call: kernel_name.run(arg1, arg2, ...)
+        # Note: async_compile.pallas loads {name}_main function and wraps it in PallasKernelWrapper
+        # which exposes a run() method
+        kernel_call = f"{name}.run({', '.join(alias_call_args + call_arg_strs)})"
+        wrapper.writeline(kernel_call)
+
+
+class PallasScheduling(SIMDScheduling):
+    kernel_type = PallasKernel  # type: ignore[assignment]
+
+    @classmethod
+    def get_backend_features(cls, device: torch.device) -> OrderedSet[BackendFeature]:
+        # Pallas/JAX can handle reductions to single elements efficiently
+        # without requiring split reductions
+        return OrderedSet([BackendFeature.REDUCE_TO_SINGLE_ELEMENT])
+
+    def define_kernel(
+        self,
+        src_code: str,
+        node_schedule: Sequence[BaseSchedulerNode],
+        kernel: PallasKernel,
+    ) -> str:  # type: ignore[override]
+        wrapper = V.graph.wrapper_code
+        if src_code in wrapper.src_to_kernel:
+            return wrapper.src_to_kernel[src_code]
+
+        fused_name = (
+            get_fused_kernel_name(node_schedule, config.triton.descriptive_names)
+            if config.triton.descriptive_names
+            else ""
+        )
+        kernel_hash = hashlib.sha256(src_code.encode("utf-8")).hexdigest()[:8]
+        if fused_name == "fused":
+            kernel_name = f"pallas_{kernel_hash}"
+        else:
+            kernel_name = f"pallas_{fused_name}_{kernel_hash}"
+        wrapper.src_to_kernel[src_code] = kernel_name
+
+        # Replace placeholder if any
+        src_code = src_code.replace("", kernel_name)
+
+        compile_wrapper = IndentedBuffer()
+        compile_wrapper.writeline(f"async_compile.pallas({kernel_name!r}, r'''")
+        compile_wrapper.splice(src_code, strip=True)
+        compile_wrapper.writeline("''')")
+
+        origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper)
+        metadata_comment = f"{origins}\n{detailed_origins}"
+        wrapper.define_kernel(kernel_name, compile_wrapper.getvalue(), metadata_comment)
+
+        return kernel_name
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/python_wrapper_mtia.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/python_wrapper_mtia.py
new file mode 100644
index 0000000000000000000000000000000000000000..00833e1de702ca9922b41c53defc88c92fa6d350
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/python_wrapper_mtia.py
@@ -0,0 +1,34 @@
+from typing import Optional
+from typing_extensions import override
+
+from torch._inductor import ir
+
+from .wrapper import PythonWrapperCodegen
+
+
+class PythonWrapperMtia(PythonWrapperCodegen):
+    """
+    A thin wrapper of PythonWrapperCodegen with MTIA specific logic
+    """
+
+    @override
+    def write_header(self) -> None:
+        super().write_header()
+
+        # MITA specific imports
+        self.imports.splice("import mtia.host_runtime.torch_mtia.dynamic_library")
+
+    @override
+    @staticmethod
+    def create(
+        is_subgraph: bool,
+        subgraph_name: Optional[str],
+        parent_wrapper: Optional[PythonWrapperCodegen],
+        partition_signatures: Optional[ir.GraphPartitionSignature] = None,
+    ) -> PythonWrapperCodegen:
+        if is_subgraph:
+            # Delegate to the parent class to handle the case of subgraph
+            return PythonWrapperCodegen.create(
+                is_subgraph, subgraph_name, parent_wrapper, partition_signatures
+            )
+        return PythonWrapperMtia()
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/segmented_tree.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/segmented_tree.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6e5c86d18109d36ee8b9595d0bce48685845f54
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/segmented_tree.py
@@ -0,0 +1,242 @@
+from collections.abc import Callable
+from typing import Generic, Optional, TypeVar
+
+
+T = TypeVar("T")
+
+
+def _value_or(opt: Optional[T], default: T) -> T:
+    return opt if opt is not None else default
+
+
+class SegmentedTree(Generic[T]):
+    def __init__(
+        self,
+        values: list[T],
+        update_op: Callable[[T, T], T],
+        summary_op: Callable[[T, T], T],
+        identity_element: T,
+    ):
+        """
+        Initialize a segment tree with the given values and operations.
+
+        Args:
+            values: list of initial values
+            update_op: Function to apply when updating a value (e.g., addition)
+            summary_op: Function to summarize two values (e.g., min, max, sum)
+            identity_element: Identity element for the summary_op (e.g., 0 for sum, float('inf') for min)
+
+        Raises:
+            ValueError: If the input values list is empty
+        """
+        if not values:
+            raise ValueError("Cannot create a segment tree with empty values list")
+
+        self.n = len(values)
+        self.update_op = update_op
+        self.summary_op = summary_op
+        self.identity = identity_element
+
+        # Size of segment tree array (next power of 2 * 2)
+        # The tree follows a standard heap layout where
+        # node `n`'s children are at `2*n` and `2*n+1`.
+        # Index 0 is unused.
+        self.size = 1
+        while self.size < self.n:
+            self.size *= 2
+        self.size *= 2
+
+        # Initialize tree and lazy arrays
+        self.tree = [identity_element] * self.size
+        # The lazy array contains updates to the given node
+        # Upon update, we only push updates to the top-most
+        # nodes that fully receive the update. We then
+        # propagate the update down as required (i.e., when
+        # we receive an interval query that neither fully
+        # contains the node nor fully doesn't contain the
+        # node
+        self.lazy: list[Optional[T]] = [None] * self.size
+
+        # Build the tree
+        self._build(values, 1, 0, self.n - 1)
+
+    def _build(self, values: list[T], node: int, start: int, end: int) -> None:
+        """
+        Build the segment tree recursively.
+
+        Args:
+            values: Original array of values
+            node: Current node index in the segment tree
+            start: Start index of the segment
+            end: End index of the segment
+        """
+        if start == end:
+            # Leaf node
+            if start < len(values):
+                self.tree[node] = values[start]
+            return
+
+        mid = (start + end) // 2
+        left_child = 2 * node
+        right_child = 2 * node + 1
+
+        # Recursively build left and right subtrees
+        self._build(values, left_child, start, mid)
+        self._build(values, right_child, mid + 1, end)
+
+        # Update current node with summary of children
+        self.tree[node] = self.summary_op(self.tree[left_child], self.tree[right_child])
+
+    def _children(self, node: int) -> list[int]:
+        return [2 * node, 2 * node + 1]
+
+    def _push_lazy(self, node: int, start: int, end: int) -> None:
+        """
+        Push lazy updates down to children.
+
+        Args:
+            node: Current node index
+            start: Start index of the segment
+            end: End index of the segment
+        """
+        lazy_node = self.lazy[node]
+        if lazy_node is None:
+            return
+
+        # Apply lazy update to current node
+        self.tree[node] = self.update_op(self.tree[node], lazy_node)
+
+        if start != end:  # Not a leaf node
+            # Propagate to children
+            for child in self._children(node):
+                self.lazy[child] = self.update_op(
+                    _value_or(self.lazy[child], self.identity), lazy_node
+                )
+
+        # Clear the lazy value
+        self.lazy[node] = None
+
+    def _update_range_helper(
+        self, node: int, start: int, end: int, left: int, right: int, value: T
+    ) -> None:
+        """
+        Helper method to update a range of values in the segment tree.
+
+        Args:
+            node: Current node index
+            start: Start index of the current segment
+            end: End index of the current segment
+            left: Start index of the range to update
+            right: End index of the range to update
+            value: Value to apply to the range
+        """
+        # Push lazy updates before processing this node
+        self._push_lazy(node, start, end)
+
+        # No overlap
+        if start > right or end < left:
+            return
+
+        # Complete overlap
+        if start >= left and end <= right:
+            # Apply update to current node
+            self.lazy[node] = value
+            self._push_lazy(node, start, end)
+            return
+
+        # Partial overlap, recurse to children
+        mid = (start + end) // 2
+        left_child = 2 * node
+        right_child = 2 * node + 1
+
+        self._update_range_helper(left_child, start, mid, left, right, value)
+        self._update_range_helper(right_child, mid + 1, end, left, right, value)
+
+        # Update current node based on children
+        self.tree[node] = self.summary_op(self.tree[left_child], self.tree[right_child])
+
+    def _query_range_helper(
+        self, node: int, start: int, end: int, left: int, right: int
+    ) -> T:
+        """
+        Helper method to query a range of values in the segment tree.
+
+        Args:
+            node: Current node index
+            start: Start index of the current segment
+            end: End index of the current segment
+            left: Start index of the range to query
+            right: End index of the range to query
+
+        Returns:
+            Summary value for the range
+        """
+        # No overlap
+        if start > right or end < left:
+            return self.identity
+
+        # Push lazy updates before processing this node
+        self._push_lazy(node, start, end)
+
+        # Complete overlap
+        if start >= left and end <= right:
+            return self.tree[node]
+
+        # Partial overlap, recurse to children
+        mid = (start + end) // 2
+        left_child = 2 * node
+        right_child = 2 * node + 1
+
+        left_result = self._query_range_helper(left_child, start, mid, left, right)
+        right_result = self._query_range_helper(right_child, mid + 1, end, left, right)
+
+        # Combine results from children
+        return self.summary_op(left_result, right_result)
+
+    def update_range(self, start: int, end: int, value: T) -> None:
+        """
+        Update a range of values in the segment tree.
+
+        Args:
+            start: Start index of the range to update (inclusive)
+            end: End index of the range to update (inclusive)
+            value: Value to apply to the range
+
+        Raises:
+            ValueError: If start > end or indices are out of bounds
+        """
+        if start > end:
+            raise ValueError("Start index must be less than or equal to end index")
+
+        if start < 0 or start >= self.n:
+            raise ValueError(f"Start index {start} out of bounds [0, {self.n - 1}]")
+
+        if end < 0 or end >= self.n:
+            raise ValueError(f"End index {end} out of bounds [0, {self.n - 1}]")
+
+        self._update_range_helper(1, 0, self.n - 1, start, end, value)
+
+    def summarize_range(self, start: int, end: int) -> T:
+        """
+        Query a range of values in the segment tree.
+
+        Args:
+            start: Start index of the range to query (inclusive)
+            end: End index of the range to query (inclusive)
+
+        Returns:
+            Summary value for the range according to the summary operation
+
+        Raises:
+            ValueError: If start > end or indices are out of bounds
+        """
+        if start > end:
+            raise ValueError("Start index must be less than or equal to end index")
+
+        if start < 0 or start >= self.n:
+            raise ValueError(f"Start index {start} out of bounds [0, {self.n - 1}]")
+
+        if end < 0 or end >= self.n:
+            raise ValueError(f"End index {end} out of bounds [0, {self.n - 1}]")
+
+        return self._query_range_helper(1, 0, self.n - 1, start, end)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/simd.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/simd.py
new file mode 100644
index 0000000000000000000000000000000000000000..aff8966e5af7167eb0e2b6f7133d3397149c7436
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/simd.py
@@ -0,0 +1,3127 @@
+# mypy: allow-untyped-defs
+from __future__ import annotations
+
+import collections
+import contextlib
+import dataclasses
+import functools
+import itertools
+import logging
+import math
+import operator
+import textwrap
+from collections import Counter
+from typing import Any, Generic, Optional, TYPE_CHECKING, Union
+from typing_extensions import TypeVar
+
+import sympy
+
+import torch
+import torch._logging
+from torch._inductor import metrics
+from torch._inductor.ir import MultiTemplateBuffer
+from torch._inductor.tiling_utils import analyze_memory_coalescing
+from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
+from torch.fx.immutable_collections import immutable_dict
+from torch.utils._ordered_set import OrderedSet
+from torch.utils._sympy.functions import FloorDiv, Identity, ModularIndexing
+from torch.utils._sympy.symbol import (
+    free_symbol_is_type,
+    prefix_str,
+    symbol_is_type,
+    SymT,
+)
+
+from ..._dynamo.utils import counters
+from .. import config, ir, scheduler
+from ..analyze_preserves_zero_mask import prologue_preserves_zero_mask
+from ..codecache import code_hash, PyCodeCache
+from ..dependencies import MemoryDep, StarDep, WeakDep
+
+
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
+    from ..ir import IRNode
+
+from ..optimize_indexing import indexing_dtype_strength_reduction
+from ..runtime.coordinate_descent_tuner import CoordescTuner
+from ..runtime.hints import DeviceProperties
+from ..runtime.runtime_utils import green_text, last_power_of_2, yellow_text
+from ..scheduler import BaseSchedulerNode, BaseScheduling, WhyNoFuse
+from ..utils import (
+    cache_property_on_self,
+    expr_fits_within_32bit,
+    get_dtype_size,
+    IndentedBuffer,
+    Placeholder,
+    prefix_is_reduction,
+    sympy_index_symbol,
+    sympy_product,
+    sympy_subs,
+    unique,
+)
+from ..virtualized import ops, OpsWrapper, V
+from .block_analysis import BlockPatternMatcher
+from .common import CSEVariable, index_prevent_reordering, Kernel, PythonPrinter
+from .multi_kernel import MultiKernel, SizeHintMultiKernel
+from .simd_kernel_features import (
+    DisableReduction,
+    EnableReduction,
+    NodeScheduleEntry,
+    NodeScheduleMarker,
+    SIMDKernelFeatures,
+)
+
+
+if TYPE_CHECKING:
+    from collections.abc import Iterable, Iterator, Sequence
+
+    from torch._inductor.tiling_utils import CoalesceVarAnalysis
+
+
+log = logging.getLogger(__name__)
+perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
+schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")
+fusion_log = torch._logging.getArtifactLogger(__name__, "fusion")
+
+
+pexpr = PythonPrinter().doprint
+
+all_prefixes = OrderedSet(["z", "y", "x", "r0_", "r1_"])
+
+
+def get_max_tiles(default: int = 2) -> int:
+    max_tiles = torch._inductor.config.triton.max_tiles
+    return max_tiles if max_tiles is not None else default
+
+
+@dataclasses.dataclass
+class IterationRanges:
+    """
+    Each range tree represents multiple sets of iteration indexing
+    in a single tiled dimension in the output kernel.
+
+    If you have two loops ranges one (4, 3, 2) and another (4, 6),
+    then the range tree will be:
+            4 (i0)
+        3 (i1)  6 (i3)
+        2 (i2)
+    Where i0 is shared between both loops, but then the split into
+    different indexing vars.  All loop ranges must iterate over
+    the same number of elements.
+    """
+
+    def __init__(
+        self,
+        name: str,
+        var_list: list[sympy.Symbol],
+        var_ranges: dict[sympy.Symbol, sympy.Expr],
+        numel: sympy.Expr,
+        prefix: str,
+        *,
+        kernel: SIMDKernel,
+        divisor=sympy.S.One,
+        length=sympy.S.One,
+        root: IterationRangesRoot,
+    ) -> None:
+        super().__init__()
+        self.name = name
+        self.var_list = var_list
+        self.var_ranges = var_ranges
+        self.numel = numel
+        self.prefix = prefix
+        self.divisor = divisor
+        self.length = length
+        self.kernel = kernel
+        self.root = root
+
+    @property
+    @cache_property_on_self
+    def is_reduction(self) -> bool:
+        return prefix_is_reduction(self.prefix)
+
+    def symbol(self) -> sympy.Symbol:
+        return sympy_index_symbol(self.name)
+
+    @property
+    @cache_property_on_self
+    def symt(self) -> SymT:
+        prefix_to_symt = {prefix: symt for symt, prefix in prefix_str.items()}
+        return prefix_to_symt[self.prefix]
+
+
+class IterationRangesRoot(IterationRanges):
+    """
+    Root of a iteration range tree that represents a single
+    tiled dimension in the output kernel. It contains multiple
+    sets of iteration represented with IterationRangesEntry.
+    """
+
+    def __init__(
+        self,
+        name: str,
+        numel: sympy.Expr,
+        prefix: str,
+        index: int,
+        kernel: SIMDKernel,
+        pid_cache: Optional[dict[str, str]] = None,
+        *,
+        is_loop: bool,
+        tensor_dim: Optional[int],
+        grid_dim: Optional[int],
+        has_zdim: bool,
+    ) -> None:
+        if pid_cache is None:
+            pid_cache = {}
+        super().__init__(
+            name=name,
+            var_list=[],
+            var_ranges={},
+            numel=numel,
+            prefix=prefix,
+            kernel=kernel,
+            root=self,
+        )
+        self.index = index
+        # Store all the nodes in one flat list
+        self.nodes: dict[sympy.Expr, IterationRangesEntry] = {}
+        # This is for re-ordering program ID in triton mm template
+        # pid_cache["tl.program_id(0)"] = pid_m
+        self.pid_cache: dict[str, str] = pid_cache
+
+        # True if the dimension is implemented as a single program looping over
+        # the full dimension (currently only used for non-persistent reduction)
+        # pyrefly: ignore [missing-argument]
+        assert not is_loop or (self.is_reduction and grid_dim is None)
+        self.is_loop = is_loop
+        # Index of corresponding dimension on triton tensors
+        self.tensor_dim = tensor_dim
+        # Index of corresponding dimension in the triton grid
+        self.grid_dim = grid_dim
+        self.has_zdim = has_zdim
+
+    def __repr__(self) -> str:
+        return f"IterationRangesRoot({self.name!r}, {self.numel}, ...)"
+
+    def cache_clear(self) -> None:
+        for node in self.nodes.values():
+            node.cache_clear()
+
+    def index_sym(self) -> sympy.Symbol:
+        return sympy_index_symbol(f"{self.prefix}index")
+
+    def lookup(self, divisor: sympy.Expr, length: sympy.Expr) -> IterationRangesEntry:
+        """
+        Lookup a given RangeTreeEntry, creating it if needed
+        """
+        if V.graph.sizevars.statically_known_equals(divisor * length, self.numel):
+            expr = FloorDiv(self.index_sym(), divisor)
+        else:
+            expr = ModularIndexing(self.index_sym(), divisor, length)
+
+        if expr not in self.nodes:
+            node = IterationRangesEntry(
+                f"{self.prefix}{next(V.kernel.iter_vars_count)}",
+                divisor,
+                length,
+                expr,
+                self,
+            )
+            V.kernel.range_tree_nodes[node.symbol()] = node
+            self.var_list.append(node.symbol())
+            self.var_ranges[node.symbol()] = length
+            self.nodes[expr] = node
+        return self.nodes[expr]
+
+    def construct_entries(
+        self, lengths: list[sympy.Expr]
+    ) -> list[IterationRangesEntry]:
+        divisor = sympy.S.One
+        itervars = []
+        for length in reversed(lengths):
+            itervars.append(self.lookup(divisor, length))
+            divisor = divisor * length
+        return [*reversed(itervars)]
+
+    def construct(self, lengths: list[sympy.Expr]) -> list[sympy.Symbol]:
+        return [e.symbol() for e in self.construct_entries(lengths)]
+
+    def vars_and_sizes(
+        self, index: sympy.Expr
+    ) -> tuple[list[sympy.Symbol], list[sympy.Expr]]:
+        """Figure out vars from this tree used in index"""
+
+        def get_sort_key(x: IterationRangesEntry) -> tuple[int, bool]:
+            """
+            Gets the key for sorting nodes. When two nodes have the
+            same divisor, the node with length as 1 should be handled
+            first so the current divisor is not changed after multiplied
+            node.length. Returns `not length_is_one_hint` for ascending
+            sort.
+            """
+            divisor_hint = V.graph.sizevars.size_hint(
+                x.divisor, fallback=config.unbacked_symint_fallback
+            )
+            length_is_one_hint = (
+                V.graph.sizevars.size_hint(
+                    x.length, fallback=config.unbacked_symint_fallback
+                )
+                == 1
+            )
+            return (divisor_hint, not length_is_one_hint)
+
+        nodes = [V.kernel.range_tree_nodes.get(s) for s in index.free_symbols]
+        nodes = [n for n in nodes if n and n.prefix == self.prefix]
+        nodes.sort(key=lambda x: get_sort_key(x))
+        divisor = sympy.S.One
+        index_vars = []
+        sizes = []
+
+        def add(node):
+            nonlocal divisor
+            index_vars.append(node.symbol())
+            sizes.append(node.length)
+            divisor = divisor * node.length
+
+        for node in nodes:
+            if not V.graph.sizevars.statically_known_equals(node.divisor, divisor):
+                # fill in unused index var
+                add(self.lookup(divisor, FloorDiv(node.divisor, divisor)))
+                divisor = node.divisor
+            add(node)
+        if not V.graph.sizevars.statically_known_equals(self.numel, divisor):
+            # fill in unused index var
+            add(self.lookup(divisor, FloorDiv(self.numel, divisor)))
+
+        return [*reversed(index_vars)], [*reversed(sizes)]
+
+
+class IterationRangesEntry(IterationRanges):
+    def __init__(
+        self,
+        name: str,
+        divisor: sympy.Expr,
+        length: sympy.Expr,
+        expr: sympy.Expr,
+        parent: IterationRanges,
+    ) -> None:
+        super().__init__(
+            name=name,
+            numel=parent.numel / length,
+            var_list=parent.var_list,
+            var_ranges=parent.var_ranges,
+            prefix=parent.prefix,
+            divisor=divisor,
+            length=length,
+            kernel=parent.kernel,
+            root=parent.root,
+        )
+        self.parent = parent
+        self.codegen = functools.lru_cache(None)(self._codegen)
+        self.expr = expr
+
+    def __repr__(self) -> str:
+        return f"IterationRangesEntry({self.name}, {self.divisor}, {self.length}, {self.expr}, {self.var_ranges})"
+
+    def set_name(self, name: str) -> None:
+        self.codegen = lambda: name  # type: ignore[assignment]
+        self.codegen.cache_clear = lambda: None  # type: ignore[method-assign]
+        self.name = name
+
+    def cache_clear(self) -> None:
+        self.codegen.cache_clear()
+
+    def _codegen(self) -> str:
+        V.kernel.codegen_iteration_ranges_entry(self)
+        return self.name
+
+    def precomputed_args(self) -> list[sympy.Expr]:
+        # for dynamic shapes, find parts of indexing expressions that have to be precomputed
+        precomputed_args: list[sympy.Expr] = []
+        if isinstance(self.expr, sympy.Symbol):
+            return precomputed_args
+        assert isinstance(self.expr, (FloorDiv, ModularIndexing)), type(self.expr)
+        for arg in self.expr.args[1:]:
+            if not isinstance(arg, (sympy.Integer, sympy.Symbol)):
+                symbols = arg.free_symbols
+                if len(symbols) > 0 and all(
+                    symbol_is_type(s, SymT.SIZE) for s in symbols
+                ):
+                    precomputed_args.append(arg)
+        return precomputed_args
+
+    def __hash__(self) -> int:
+        return hash(self.name)
+
+    def __eq__(self, other: object) -> bool:
+        assert isinstance(other, IterationRangesEntry)
+        return self.name == other.name
+
+
+def constant_repr(value: Union[int, float]) -> str:
+    if value == float("inf"):
+        return 'float("inf")'
+    elif value == float("-inf"):
+        return 'float("-inf")'
+    elif math.isnan(value):
+        return 'float("nan")'
+    return repr(value)
+
+
+CSEVariableType = TypeVar("CSEVariableType", bound=CSEVariable, default=CSEVariable)
+
+
+@dataclasses.dataclass
+class PartialAccumulate:
+    buffer_name: str
+    reduction_type: str
+    value: Any
+
+
+class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
+    """
+    Common base class for Triton/Halide codegen which both use flattened indexing rather than loop nests.
+    """
+
+    sexpr: Callable[[sympy.Expr], str] = pexpr
+    kexpr: Callable[[sympy.Expr], str]
+    allow_block_ptr: bool = False
+    # pyrefly: ignore [bad-override]
+    kernel_name: str
+
+    def __init__(
+        self,
+        tiling: dict[str, sympy.Expr],
+        features: SIMDKernelFeatures,
+        pid_cache: Optional[dict[str, str]] = None,
+        override_persistent_reduction: Optional[bool] = None,
+        override_cooperative_reduction: Optional[bool] = None,
+        tiling_scores: Optional[dict[str, sympy.Expr]] = None,
+        mix_order_reduction: bool = False,
+    ) -> None:
+        if pid_cache is None:
+            pid_cache = {}
+        super().__init__()
+        self.features = features
+        self.mutations = features.get_mutations()
+        self.body = IndentedBuffer()
+        self.indexing_code = IndentedBuffer()
+        self.numels = {
+            prefix: V.graph.sizevars.simplify(val) for prefix, val in tiling.items()
+        }
+        self.range_trees: list[IterationRangesRoot] = []
+        self.range_tree_nodes: dict[sympy.Symbol, IterationRangesEntry] = {}
+        self.iter_vars_count = itertools.count()
+        self.inside_reduction = features.is_reduction()
+        self.cooperative_reduction: bool = (
+            override_cooperative_reduction
+            if override_cooperative_reduction is not None
+            else self.should_use_cooperative_reduction()
+        )
+        self.tiling_scores: Optional[dict[str, sympy.Expr]] = tiling_scores
+        self.tiling: dict[str, sympy.Expr] = tiling
+        self.persistent_reduction: bool = (
+            override_persistent_reduction
+            if override_persistent_reduction is not None
+            else self.should_use_persistent_reduction()
+        )
+        self.mix_order_reduction: bool = mix_order_reduction
+        self.no_x_dim = self.want_no_x_dim()
+        self.code_hash: Optional[str] = None
+        # Info to enable multiple store_output calls for epilogue subtiling
+        self.store_output_ctr = itertools.count()
+        self.is_native_matmul = False
+        if config.triton.native_matmul:
+            for node in self.features.node_schedule:
+                if (
+                    isinstance(node, scheduler.SchedulerNode)
+                    and isinstance(node.node, ir.ComputedBuffer)
+                    and node.node.get_reduction_type() == "dot"
+                ):
+                    self.is_native_matmul = True
+                    break
+
+        # define this in a closure to make cache local to object
+        @functools.cache
+        def simplify_indexing(index: sympy.Expr):
+            index = V.graph.sizevars.simplify_with_ranges(index, self.var_ranges())
+            for tree in self.range_trees:
+                index = self.combine_contiguous_dims(index, tree)
+
+            return self.combine_modular_indexing_pairs(index)
+
+        self.simplify_indexing = simplify_indexing
+        self.initialize_range_tree(pid_cache)
+
+        self.rsplit_size = 0
+        self.saved_partial_accumulate: list[PartialAccumulate] = []
+
+    def _get_store_output_subgraph_name(self, i: int) -> str:
+        return f""
+
+    def get_store_output_count(self):
+        total = next(self.store_output_ctr)
+        self.store_output_ctr = itertools.count(start=total - 1, step=1)
+        return total
+
+    @property
+    @cache_property_on_self
+    def num_reduction_dims(self) -> int:
+        return sum(prefix_is_reduction(prefix) for prefix in self.numels)
+
+    def dtype_to_str(self, dtype: torch.dtype) -> str:
+        raise NotImplementedError
+
+    def get_index_dtype_as_torch_dtype(self) -> torch.dtype:
+        return self.features.select_index_dtype()
+
+    @property
+    def index_dtype(self) -> str:
+        return self.dtype_to_str(self.get_index_dtype_as_torch_dtype())
+
+    def want_no_x_dim(self) -> bool:
+        return False
+
+    def construct_range_trees(
+        self,
+        pid_cache: Optional[dict[str, str]],
+        inside_reduction: bool,
+        is_reduction: bool,
+        numels: dict[str, sympy.Expr],
+        no_x_dim: bool,
+    ) -> list[IterationRangesRoot]:
+        active_prefixes = OrderedSet(
+            prefix for prefix in all_prefixes if prefix in numels
+        )
+        no_r_dim = not inside_reduction or not is_reduction
+
+        def filtered_index_map(seq, mask) -> dict[Any, int]:
+            return {
+                val: idx for idx, val in enumerate(val for val in seq if val in mask)
+            }
+
+        grid_dims = ["x", "y", "z"]
+        pointwise_tensor_dims = list(reversed(grid_dims))
+        reduction_dims = ["r0_", "r1_"]
+        if no_x_dim:
+            tensor_dims = reduction_dims
+        elif no_r_dim:
+            tensor_dims = pointwise_tensor_dims
+        else:
+            tensor_dims = pointwise_tensor_dims + reduction_dims
+
+        # Filter out unused tensor dims.
+        # Convert to dicts for O(1) index lookup.
+        tensor_dim_map = filtered_index_map(tensor_dims, active_prefixes)
+        grid_dim_map = filtered_index_map(grid_dims, all_prefixes)
+
+        range_trees = []
+        for i, prefix in enumerate(active_prefixes):
+            is_reduction = prefix_is_reduction(prefix)
+            tensor_dim = tensor_dim_map.get(prefix)
+            grid_dim = grid_dim_map.get(prefix)
+            index = i if grid_dim is None else grid_dim
+            range_trees.append(
+                IterationRangesRoot(
+                    f"{prefix}index",
+                    numels[prefix],
+                    prefix,
+                    index,
+                    self,  # type: ignore[arg-type]
+                    pid_cache=pid_cache,
+                    is_loop=is_reduction and not self.persistent_reduction,
+                    tensor_dim=tensor_dim,
+                    grid_dim=grid_dim,
+                    has_zdim="z" in numels,
+                )
+            )
+        return range_trees
+
+    def initialize_range_tree(self, pid_cache: dict[str, str]) -> None:
+        range_trees = self.construct_range_trees(
+            pid_cache,
+            self.inside_reduction,
+            self.features.is_reduction(),
+            self.numels,
+            self.no_x_dim,
+        )
+        self.range_trees.extend(range_trees)
+
+    def finalize_indexing(self, indices: Sequence[sympy.Expr]) -> None:
+        """
+        Hook called right before codegen with every index that will be
+        used in the fused kernel.
+        """
+
+    def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable) -> None:
+        prior = self.inside_reduction
+        self.inside_reduction = False
+        try:
+            return self.store(name, index, value)
+        finally:
+            self.inside_reduction = prior
+
+    def should_use_cooperative_reduction(self) -> bool:
+        return False  # defined in subclass
+
+    def should_use_persistent_reduction(self) -> bool:
+        return False  # defined in subclass
+
+    def var_ranges(self) -> dict[sympy.Symbol, sympy.Expr]:
+        return dict(
+            itertools.chain.from_iterable(
+                tree.var_ranges.items() for tree in self.range_trees
+            )
+        )
+
+    def triton_tensor_ndim(self) -> int:
+        return sum(int(tree.tensor_dim is not None) for tree in self.range_trees)
+
+    def indexing_size_str(self, i: int) -> str:
+        sizes = ["None"] * self.triton_tensor_ndim()
+        sizes[i] = ":"
+        return f"[{', '.join(sizes)}]"
+
+    def dense_size_list(self) -> list[str]:
+        sizes = ["1"] * self.triton_tensor_ndim()
+        for tree in self.range_trees:
+            if tree.tensor_dim is None:
+                continue
+
+            # pyrefly: ignore [missing-argument]
+            if not tree.is_reduction or self.inside_reduction:
+                sizes[tree.tensor_dim] = f"{tree.prefix.upper()}BLOCK"
+        return sizes
+
+    def create_constant_mask(self, entry) -> str:
+        x = entry.prefix
+        if entry.tensor_dim is None:
+            sizestr = self.dense_size_str()
+            return f"{x}mask = tl.full({sizestr}, True, tl.int1)"
+        sizes = ["None"] * self.triton_tensor_ndim()
+        sizes[entry.tensor_dim] = ":"
+        suffix = ", ".join(sizes)
+        out = f"{x}mask = tl.full([{x.upper()}BLOCK], True, tl.int1)[{suffix}]"
+        return out
+
+    def dense_size_str(self) -> str:
+        sizes = self.dense_size_list()
+        return f"[{', '.join(sizes)}]"
+
+    def combine_modular_indexing_pairs(self, index: sympy.Expr) -> sympy.Expr:
+        if not isinstance(index, ModularIndexing):
+            return index
+        x = index.args[0]
+        if (tree_node := self.range_tree_nodes.get(x)) is None:
+            return index
+        new_index = sympy_subs(index, {x: tree_node.expr})
+        new_index = V.graph.sizevars.combine_modular_indexing_pairs(new_index)
+        # the index now contains xindex/etc, which is nonstandard, fix it up
+        return sympy_subs(
+            new_index,
+            {
+                tree_node.root.index_sym(): tree_node.root.lookup(
+                    sympy.S.One, tree_node.root.numel
+                ).symbol()
+            },
+        )
+
+    def combine_contiguous_dims(
+        self, index: sympy.Expr, tree: IterationRangesRoot
+    ) -> sympy.Expr:
+        if expand_res := V.graph.sizevars.expand_floor_div(index):
+            new_index, denominator = expand_res  # type: ignore[misc]
+            return FloorDiv(self._combine_contiguous_dims(new_index, tree), denominator)
+        else:
+            return self._combine_contiguous_dims(index, tree)
+
+    def _combine_contiguous_dims(
+        self, index: sympy.Expr, tree: IterationRangesRoot
+    ) -> sympy.Expr:
+        """
+        More aggressive simplification to merge contiguous dims
+        """
+        if isinstance(index, (sympy.Integer, sympy.Symbol)):
+            return index
+        index_vars, sizes = tree.vars_and_sizes(index)
+        if len(sizes) <= 1:
+            return index
+        new_sizes, reindex, _prune = V.graph.sizevars._simplify_loops(
+            index_vars, sizes, index_prevent_reordering([index], index_vars, sizes)
+        )
+        if new_sizes == sizes:
+            return index
+        new_index_vars = tree.construct(new_sizes)
+        new_index = sympy_subs(index, dict(zip(index_vars, reindex(new_index_vars))))
+        return new_index
+
+    def disable_reduction(self) -> contextlib.AbstractContextManager[None]:
+        should_flush = self.range_trees[-1].is_loop or self.cooperative_reduction
+
+        @contextlib.contextmanager
+        def ctx():
+            if not self.features.is_reduction():
+                assert not self.inside_reduction
+                yield
+                return
+            if should_flush:
+                # calling codegen_body() will flush all the pending buffers
+                # and write out a reduction loop
+                self.codegen_body()
+            self.inside_reduction = False
+            try:
+                yield
+                if should_flush:
+                    # flush out any code before opening the next loop
+                    self.codegen_body()
+            finally:
+                self.inside_reduction = True
+
+        return ctx()
+
+    def set_ranges(self, *lengths: sympy.Expr) -> list[sympy.Symbol]:
+        assert len(lengths) == len(self.range_trees)
+        return [
+            ranges.construct(length)
+            for length, ranges in zip(lengths, self.range_trees)
+        ]
+
+    @staticmethod
+    def _split_iteration_ranges(
+        groups: Iterable[sympy.Expr], lengths: Sequence[Sequence[sympy.Expr]]
+    ) -> tuple[
+        list[list[sympy.Expr]], list[list[Callable[[list[sympy.Expr]], sympy.Expr]]]
+    ]:
+        # Special case: if a node's sizes are ([], []), there's nothing to split.
+        if all(len(length) == 0 for length in lengths):
+            return [[] for group in groups], []
+
+        sv = V.graph.sizevars
+        new_ranges: list[list[sympy.Expr]] = [[] for _ in groups]
+        remaining = [sv.simplify(g) for g in groups]
+        var_count = itertools.count()
+
+        def add_range(i: int, expr: sympy.Expr) -> int:
+            expr = sv.simplify(expr)
+            if not sv.statically_known_multiple_of(remaining[i], expr):
+                raise CantSplit
+            # guard on the last item out
+            remaining[i] = FloorDiv(remaining[i], expr)
+            new_ranges[i].append(expr)
+            return next(var_count)
+
+        def make_combined(
+            sizes: list[sympy.Expr], idxs: list[int]
+        ) -> Callable[[list[sympy.Expr]], sympy.Expr]:
+            """
+            Builds the nested expression:
+              ((...((s1*v[i1] + v[i2]) * s2 + v[i3]) ... ) * sk + v[i(k+1)])
+            """
+            assert len(idxs) == len(sizes) + 1
+
+            def getter(flat_vars: list[sympy.Expr]) -> sympy.Expr:
+                expr = flat_vars[idxs[0]]
+                for s, idx in zip(sizes, idxs[1:]):
+                    expr = s * expr + flat_vars[idx]
+                return expr
+
+            return getter
+
+        return_getters_groups = []
+        current_group = 0
+        for length_group in lengths:
+            return_getters = []
+            for size in length_group:
+                if sv.statically_known_equals(size, 1):  # type: ignore[arg-type]
+                    return_getters.append(lambda _: sympy.S.Zero)
+                    continue
+
+                while current_group < len(remaining) and sv.statically_known_equals(
+                    remaining[current_group],
+                    1,  # type: ignore[arg-type]
+                ):
+                    # scroll to next group with remaining elements
+                    current_group += 1
+
+                # During native matmul on bmm, we enforce tiling order (z, y, x, r).
+                # When fusing a bmm node with loop (z, y, x, r) with a pw node
+                # of shape (z*y*x, 1), we need to split the pw iteration range
+                # into three dimensions.
+                # The group becomes [z, y, x, 1], with lengths ([z*y*x], []).
+                # In this case, we decompose the combined size z*y*x into three
+                # consecutive groups. Previously, _split_iteration_ranges supported
+                # splitting into at most two dimensions, but we now extend it to do
+                # three splits when the total size is divisible by all three.
+
+                # is group having (z,y,x,r=1) form?
+                is_bmm_then_pw = len(remaining) == 4 and remaining[-1] == 1
+                if (
+                    current_group + 2 < len(remaining)
+                    and sv.statically_known_gt(
+                        size, remaining[current_group] * remaining[current_group + 1]
+                    )
+                    and is_bmm_then_pw
+                ):
+                    # need to break size in three
+                    if not sv.statically_known_multiple_of(
+                        size, remaining[current_group] * remaining[current_group + 1]
+                    ):
+                        raise CantSplit
+
+                    size1 = remaining[current_group]
+                    size2 = remaining[current_group + 1]
+                    size3 = FloorDiv(size, size1 * size2)
+                    return_getters.append(
+                        make_combined(
+                            [size2, size3],
+                            [
+                                add_range(current_group, size1),
+                                add_range(current_group + 1, size2),
+                                add_range(current_group + 2, size3),
+                            ],
+                        )
+                    )
+
+                # Two-dimensional tiling: split size across current_group and next group.
+                elif current_group + 1 < len(remaining) and (
+                    sv.statically_known_gt(size, remaining[current_group])
+                    or
+                    # statically_known_gt(size, remaining) may return False for symbolic
+                    # expressions like 64*u0 vs u0, because both could be 0. Similarly for
+                    # backed expressions like s25*(((s70 - 5)//4)) - s25 and
+                    # (s25*(((s70 - 5)//4)) - s25)*64.
+                    # We want to assume tensor sizes are not 0 and pass the gt
+                    # using the following logic.
+                    #
+                    # if A//B = C and C >= 1
+                    # then A = B * C + R
+                    # and assuming A!=0
+                    # A must be > B .
+                    #
+                    sv.statically_known_gt(FloorDiv(size, remaining[current_group]), 1)
+                ):
+                    # need to break size in two
+                    if not sv.statically_known_multiple_of(
+                        size, remaining[current_group]
+                    ):
+                        raise CantSplit
+
+                    size1 = remaining[current_group]
+                    size2 = FloorDiv(size, remaining[current_group])
+                    return_getters.append(
+                        make_combined(
+                            [size2],
+                            [
+                                add_range(current_group, size1),
+                                add_range(current_group + 1, size2),
+                            ],
+                        )
+                    )
+                else:
+                    if current_group < len(remaining):
+                        return_getters.append(
+                            operator.itemgetter(add_range(current_group, size))
+                        )
+            return_getters_groups.append(return_getters)
+
+        assert all(V.graph.sizevars.size_hint(s) == 1 for s in remaining), (
+            f"failed to set ranges {remaining} {lengths}"
+        )
+        return new_ranges, return_getters_groups
+
+    @classmethod
+    def prepare_split_iteration_lengths(
+        cls,
+        groups: Iterable[sympy.Expr],
+        lengths: Sequence[Sequence[sympy.Expr]],
+        reduction_numel: sympy.Expr = sympy.S.One,
+    ) -> Sequence[Sequence[sympy.Expr]]:
+        "Fill in the reduction numel of lengths if missing"
+        sizevars = V.graph.sizevars
+        if len(lengths[1]) == 0 and (
+            not sizevars.statically_known_equals(reduction_numel, sympy.S.One)
+            and sizevars.statically_known_equals(
+                sympy_product(groups),
+                sympy_product(lengths[0]) * reduction_numel,
+            )
+        ):
+            return (lengths[0], [reduction_numel])
+
+        return lengths
+
+    @classmethod
+    def is_compatible(
+        cls,
+        groups: Iterable[sympy.Expr],
+        lengths: Sequence[Sequence[sympy.Expr]],
+        reduction_numel: sympy.Expr = sympy.S.One,
+    ) -> bool:
+        lengths = cls.prepare_split_iteration_lengths(groups, lengths, reduction_numel)
+
+        try:
+            cls._split_iteration_ranges(groups, lengths)
+            return True
+        except CantSplit:
+            return False
+
+    def split_and_set_ranges(
+        self, lengths: Sequence[Sequence[sympy.Expr]]
+    ) -> list[list[sympy.Expr]]:
+        """
+        Split and set iteration ranges for the kernel based on the provided lengths.
+
+        This method maps the kernel's tiling structure to the node's iteration space,
+        handling both pointwise and reduction dimensions appropriately.
+
+        Args:
+            lengths: A sequence of sequences of symbolic expressions representing
+                    the sizes of different dimensions for each node.
+
+        Returns:
+            A list of lists of symbolic expressions representing the mapped
+            iteration variables for each dimension.
+        """
+        # Create a dictionary mapping each range tree prefix to its total number of elements
+        tiling = {rt.prefix: rt.numel for rt in self.range_trees}
+
+        # If we're not inside a reduction loop, set all reduction dimensions to 1
+        # This effectively disables reduction dimensions when not needed
+        if not self.inside_reduction:
+            for prefix in tiling:
+                if prefix_is_reduction(prefix):
+                    tiling[prefix] = sympy.S.One
+
+        # Extract the values from the tiling dictionary to create groups
+        groups = [*tiling.values()]
+
+        # Map the kernel's group structure to the node's sizes and set the ranges
+        # using the set_ranges method, returning the resulting iteration variables
+        return self.map_kernel_groups_to_node_sizes(groups, lengths, self.set_ranges)
+
+    @classmethod
+    def map_kernel_groups_to_node_sizes(
+        cls,
+        groups: Sequence[sympy.Expr],
+        lengths: Sequence[Sequence[sympy.Expr]],
+        set_ranges,
+    ) -> list[list[sympy.Expr]]:
+        """
+        We may want to fuse `for i0 in s0*s1` into a tiled kernel with groups (s0, s1).
+
+        To do this we need to split up the iteration space of i0 into something like:
+            for i1 in s0:
+              for i2 in s1:
+                i0 = i1*s1 + i2
+                ....
+
+        This function matches and resplits lengths to the groups of
+        this kernel to enable tiled + non-tiled fusions.
+        """
+        if len(lengths) == len(groups) and all(
+            V.graph.sizevars.simplify(sympy_product(x) - g) == 0
+            for x, g in zip(lengths, groups)
+        ):
+            return set_ranges(*lengths)
+
+        new_ranges, return_getters_groups = cls._split_iteration_ranges(groups, lengths)
+        itervars = [*itertools.chain.from_iterable(set_ranges(*new_ranges))]
+        return [[fn(itervars) for fn in fns] for fns in return_getters_groups]
+
+    def is_indirect_indexing(self, index: sympy.Expr) -> bool:
+        # tmpX  means indirect indexing
+        return free_symbol_is_type(index, SymT.TMP)
+
+    def is_broadcasted(self, index: sympy.Expr) -> bool:
+        # Note. This may not be correct when there is indirect indexing
+        if self.is_indirect_indexing(index):
+            return False
+
+        index_numels = [1] * len(self.numels)
+        for symbol in index.free_symbols:
+            if symbol not in self.range_tree_nodes:
+                # Non-iterated variables, e.g. strides
+                continue
+            entry = self.range_tree_nodes[symbol]  # type: ignore[index]
+            assert isinstance(entry.parent, IterationRangesRoot)
+            index_numels[entry.parent.index] *= entry.length
+
+        # If the index variables only iterate over a subset of the kernel
+        # numels, then it must be broadcasted.
+        simplify = V.graph.sizevars.simplify
+        return any(
+            simplify(idx_range) != simplify(iter_range)  # type: ignore[arg-type]
+            for idx_range, iter_range in zip(index_numels, self.numels.values())
+        )
+
+    def index_to_str(self, index: sympy.Expr) -> str:
+        """
+        Convert an index expr to a string that can be used in output code.
+        e.g. a sympy expression "s2" may actually appear as "ks1" in the generated kernel.
+
+        Index expressions often need to be passed in as arguments to the triton kernel.
+        Rename_indexing and codegen_indexing keep track of the needed indices and add
+        new parameters to the function signature.
+        """
+        if isinstance(index, list):
+            return f"[{', '.join(map(self.index_to_str, index))}]"
+        return self.kexpr(self.rename_indexing(index))  # type: ignore[call-arg]
+
+    def prepare_indexing(
+        self,
+        index: sympy.Expr,
+    ) -> sympy.Expr:
+        index = self.simplify_indexing(index)
+        index = sympy_subs(index, V.graph.sizevars.precomputed_replacements)
+        # if simple replacements didn't get rid of floor/ceil, try full subs
+        if len(index.atoms(sympy.floor)) or len(index.atoms(sympy.ceiling)):
+            index = index.subs(V.graph.sizevars.precomputed_replacements)
+        # last resort, if no range vars are in the expr, hoist it
+        # TODO instead of trying to blindly find complicated exprs, we should hoist the
+        # inputs/outputs sizes and strides, but at the time indexing is generated
+        # kernel inputs and outputs are not set yet, we'd need a deeper refactor
+        # to do it this way
+
+        if len(index.atoms(sympy.ceiling)):
+            for a in index.atoms(sympy.ceiling):
+                # for nested exprs, atoms yields top level first (?)
+                # so if everything goes fine, lower level replacements will come up empty
+                symbols = a.free_symbols
+                if len(symbols) > 0 and all(
+                    symbol_is_type(s, (SymT.SIZE, SymT.PRECOMPUTED_SIZE))
+                    for s in symbols
+                ):
+                    replacements = {a: V.graph.sizevars.lookup_precomputed_size(a)}
+                    index = sympy_subs(index, replacements)
+
+        simp_index = self.simplify_indexing(index)
+
+        # Now that we are done simplifying we can unwrap Identity so that downstream handling
+        # for its contained expression will work. previously, tl.full wrapping of sympy.Integer
+        # would not occur
+        simp_index = (
+            simp_index if not isinstance(simp_index, Identity) else simp_index.args[0]
+        )
+
+        return self.codegen_indexing(simp_index)
+
+    def active_range_trees(self) -> list[IterationRangesRoot]:
+        return [
+            t
+            for t in self.range_trees
+            # pyrefly: ignore [missing-argument]
+            if not t.is_reduction or self.inside_reduction
+        ]
+
+    def codegen_indexing(self, expr: sympy.Expr) -> sympy.Expr:
+        expr = V.graph.sizevars.simplify_with_ranges(expr, self.var_ranges())
+        for sym in sorted(expr.free_symbols, key=str):
+            if sym in self.range_tree_nodes:
+                # if indexing expression is complicated, we precompute it on the host side
+                # and send the result as a kernel argument
+                replacements = {}
+                for ps in self.range_tree_nodes[sym].precomputed_args():  # type: ignore[index]
+                    replacements[ps] = V.graph.sizevars.lookup_precomputed_size(ps)
+                if len(replacements) > 0:
+                    self.range_tree_nodes[sym].expr = sympy_subs(  # type: ignore[index]
+                        self.range_tree_nodes[sym].expr,
+                        replacements,  # type: ignore[index]
+                    )
+                self.range_tree_nodes[sym].codegen()  # type: ignore[index]
+        return expr
+
+    def codegen_nan_check(self) -> None:
+        raise NotImplementedError("NYI: codegen_nan_check")
+
+    def deallocate_workspaces(self):
+        wrapper = V.graph.wrapper_code
+        for ws in reversed(self.args.workspace_args):
+            wrapper.generate_workspace_deallocation(ws)
+
+    def call_kernel(
+        self, name: str, node: Optional[IRNode] = None, deallocate_ws: bool = True
+    ) -> None:
+        raise NotImplementedError("NYI: call_kernel")
+
+    @contextlib.contextmanager
+    def mask_loads(
+        self, mask: Union[str, OpsWrapper], value: Union[int, float]
+    ) -> Iterator[str]:
+        """Context manager to add an additional mask to tl.load/store"""
+        prior = self._load_mask
+        prior_val = self._load_other
+        if prior:
+            mask = ops.logical_and(mask, prior)
+
+        mask = OpsWrapper._unwrap(mask)
+        self._load_mask = mask
+        self._load_other = value
+        try:
+            # TODO(jansel): do we need a reshape here?
+            yield mask
+        finally:
+            self._load_mask = prior
+            self._load_other = prior_val
+
+    def get_strides_of_load(self, index: sympy.Expr) -> dict[sympy.Symbol, sympy.Expr]:
+        """
+        This gets the stride of the index for each of the tiling variables
+        (technically, it does it at index 0)
+
+        For example, if
+        xindex = x0 + 512*x1 + 1024*r0
+        x0 = (xindex//512)
+        x1 = (xindex % 512)
+        r0 = rindex // 1024
+
+        this function would return
+        {xindex: 512, rindex: 1024}
+        """
+        index_to_tile_indexes = {k: v.expr for k, v in self.range_tree_nodes.items()}
+        index_in_tile_vars = sympy_subs(index, index_to_tile_indexes)  # type: ignore[arg-type]
+        strides = {}
+        for range_tree in self.range_trees:
+            s = sympy_index_symbol(range_tree.name)
+            strides[s] = sympy_subs(index_in_tile_vars, {s: 1}) - sympy_subs(
+                index_in_tile_vars, {s: 0}
+            )
+        return strides
+
+    @staticmethod
+    def _map_tuple_or_scalar(fn, value):
+        if isinstance(value, tuple):
+            return tuple(map(fn, value))
+        return fn(value)
+
+    def estimate_flops(self) -> Optional[int]:
+        flops = [
+            node.estimate_flops()
+            for node in NodeScheduleMarker.only_nodes(self.features.node_schedule)
+        ]
+        return sum(filter(None, flops))
+
+    def estimate_kernel_num_bytes(self):
+        """
+        Try the best to estimate the total size (in bytes) of the
+        kernel's inputs and outputs, which is used for estimating the memory
+        throughput of this kernel. This information is used for checking how
+        far we are from the peak memory bandwidth. It's important that
+        we want to avoid overestimating the sizes of the inputs and outputs,
+        because it can wrongfully give us a very large memory traffic value,
+        which may be even larger than the theoretical bandwidth and thus
+        become very misleading. This is particularly problematic for cases
+        where we slice some inputs. In those cases, we should only count
+        the size of the "slices" instead of the original inputs, because
+        only the slices contribute to the real memory traffic.
+        """
+        nbytes = []
+        ninplace_args = len(unique(self.args.inplace_buffers.values()))
+        _, call_args, _, _ = self.args.python_argdefs()
+        buf_accesses = self.features.buf_accesses()
+
+        # For pointwise and reduction kernels, this is the upper-bound numels
+        # for the output buffer.
+        # FIXME: This is not exactly right for cases like below:
+        #    def foo(tensor0, tensor1):
+        #        x0 = narrow(tensor0)
+        #        return cat(x0, tensor1)
+        # For this example, we will end up overestimate the size for the
+        # slice s0. Potentially, we could have precise inputs information
+        # if we maintained the original inputs of the Pointwise kernel created
+        # for the "cat". However, I think it might be a bit overwhelming that
+        # we add such complexity only for handling some particular cases for
+        # benchmarking.
+        out_numel = V.graph.sizevars.size_hint(
+            sympy_product(self.numels.values()),
+            fallback=config.unbacked_symint_fallback,
+        )
+        for i, arg in enumerate(call_args):
+            # "buf" may be narrowed. In this case, the number of memory accesses
+            # should be estimated based on the reinterpreted layout.
+            # On the other hand, buf may be broadcasted. In this case,
+            # counting the size of the underline storage would give us
+            # a better estimation in terms of memory accesses.
+            if arg not in buf_accesses:
+                nbytes.append(0)
+                continue
+            arg_numel = V.graph.get_numel(arg)
+            buf_size = V.graph.sizevars.size_hint(
+                arg_numel, fallback=config.unbacked_symint_fallback
+            )
+            if buf_size > out_numel:
+                # This arg points to a buf that has been sliced.
+                # We need to count each individual slice to have
+                # a better estimation.
+                indices = OrderedSet[Any]()
+                no_index_dep_count = 0
+                for dep in buf_accesses[arg]:
+                    if isinstance(dep, (StarDep, WeakDep)):
+                        indices.add(f"no_index_dep_{no_index_dep_count}")
+                        no_index_dep_count += 1
+                    else:
+                        indices.add(dep.index)
+                numel = len(indices) * out_numel
+            else:
+                numel = buf_size
+            dtype = V.graph.get_dtype(arg)
+            dtype_size = get_dtype_size(dtype)
+            # pyrefly: ignore [bad-argument-type]
+            nbytes.append(numel * dtype_size * (1 + int(i < ninplace_args)))
+        return sum(nbytes)
+
+    def warn_mix_layout(self, kernel_name):
+        """
+        Print message if the kernel have mixed layout inputs.
+        Only care about 4D tensor for now.
+        """
+        if (
+            len(self.args.input_buffers) == 1
+            and len(self.args.output_buffers) == 1
+            and len(self.args.inplace_buffers) == 0
+        ):
+            # even if input buffer and output buffer have different layout,
+            # this can be a layout conversion kernel. No need to warn for
+            # the mix layouts.
+            return
+
+        argdefs, call_args, _signature, _ = self.args.python_argdefs()
+        uniform_stride_order = None
+        # pyrefly: ignore [bad-assignment]
+        for arg_name in call_args:
+            buf = V.graph.try_get_buffer(arg_name)
+            if not buf:
+                continue
+            layout = buf.get_layout()
+            if len(layout.size) == 4:
+                # ignore the tensor if only 1 dimension is non-zero
+                if len([x for x in layout.size if x == 1]) == 3:
+                    continue
+                stride_order = ir.get_stride_order(layout.stride)
+                if uniform_stride_order is None:
+                    uniform_stride_order = stride_order
+                elif uniform_stride_order != stride_order:
+                    msg = yellow_text(
+                        f"Expected stride order {uniform_stride_order}, but found stride order"
+                        + f" {stride_order} for kernel {kernel_name}"
+                    )
+                    log.warning(msg)
+
+                    stride_order_list = [
+                        ir.get_stride_order(
+                            V.graph.get_buffer(name).get_layout().stride
+                        )
+                        if V.graph.try_get_buffer(name)
+                        else None
+                        for name in call_args
+                    ]
+                    size_list = [
+                        V.graph.get_buffer(name).get_layout().size
+                        if V.graph.try_get_buffer(name)
+                        else None
+                        for name in call_args
+                    ]
+                    source_list = [
+                        "GraphInput"
+                        if name in V.graph.graph_inputs
+                        else "IntermediateBuffer"
+                        if name in V.graph.name_to_buffer
+                        else None
+                        for name in call_args
+                    ]
+
+                    argdef_names = [x.name for x in argdefs]
+                    msg = yellow_text(
+                        f"  param names {argdef_names}\n  buf names {call_args}\n  strides {stride_order_list}"
+                        + f"\n  sizes {size_list}\n  sources {source_list}\n"
+                    )
+                    log.warning(msg)
+                    return
+        msg = green_text(
+            f"All the inputs for the triton kernel {kernel_name} have uniform layout"
+        )
+        log.warning(msg)
+
+    def welford_reduce_fallback(self, dtype, value):
+        sum_ = ops.reduction(dtype, dtype, "sum", value)
+        self.inside_reduction = False
+        rnumel = ops.index_expr(self.features.reduction_numel, dtype)
+        mean = ops.truediv(sum_, rnumel)
+
+        self.inside_reduction = True
+        dx = ops.sub(value, mean)
+        dx2 = ops.mul(dx, dx)
+        m2 = ops.reduction(dtype, dtype, "sum", dx2)
+        return OpsWrapper._unwrap((mean, m2, rnumel))
+
+    def prepare_softmax_twopass_fallback(self, dtype, value):
+        vmax = ops.reduction(dtype, dtype, "max", value)
+        sub = ops.sub(value, vmax)
+        exp = ops.exp(sub)
+        vsum = ops.reduction(dtype, dtype, "sum", exp)
+        return OpsWrapper._unwrap((vmax, vsum))
+
+    def codegen_kernel(self):
+        raise NotImplementedError
+
+    def codegen_body(self):
+        pass
+
+    def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry):
+        pass
+
+
+class SIMDScheduling(BaseScheduling):
+    """
+    Single Instruction Multiple Data parent class used for fusion across
+    multiple different backends.
+    """
+
+    kernel_type: type[Any] = SIMDKernel  # override in subclass
+
+    def group_fn(self, sizes):
+        return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes)
+
+    def can_fuse(self, node1, node2):
+        """
+        Hook called by Scheduler to determine if the Triton backend
+        can fuse node1 and node2.  These nodes might already be
+        FusedSchedulerNodes.
+        """
+        if isinstance(node1, scheduler.ForeachKernelSchedulerNode) or isinstance(
+            node2, scheduler.ForeachKernelSchedulerNode
+        ):
+            return scheduler.ForeachKernelSchedulerNode.can_fuse(node1, node2)
+
+        _, (numel1, rnumel1) = node1.group
+        _, (numel2, rnumel2) = node2.group
+        why = WhyNoFuse(node1, node2)
+
+        if node1.is_split_scan() and not node2.is_split_scan():
+            if node2.is_reduction():
+                why("Split scan cannot fuse with reductions")
+        elif node2.is_split_scan() and not node1.is_split_scan():
+            if node1.is_reduction():
+                why("Split scan cannot fuse with reductions")
+
+        if node1.is_reduction() and node2.is_reduction():
+            reduction_can_fuse = numel1 == numel2 and rnumel1 == rnumel2
+            if not reduction_can_fuse:
+                from torch._inductor.scheduler import MixOrderReduction
+
+                reduction_can_fuse = MixOrderReduction.can_fuse(node1, node2)
+
+            if not reduction_can_fuse:
+                why(
+                    "numel/rnumel mismatch (reduce) (%s, %s), (%s, %s)",
+                    numel1,
+                    numel2,
+                    rnumel1,
+                    rnumel2,
+                )
+
+            if reduction_can_fuse and (
+                node1.is_native_matmul() or node2.is_native_matmul()
+            ):
+                # Ensure node1 is always the native matmul side
+                if not node1.is_native_matmul():
+                    node1, node2 = node2, node1
+
+                # 1. A native matmul node keeps its original loop order.
+                #    For example: C[z,y,x] = torch.bmm(A[z,y,r], B[z,r,x]) keeps (z,y,x) order.
+                #    (see simplify_and_reorder in ir.py)
+                #
+                # 2. Triton kernels with native matmul always tile loops as (z,y,x)
+                #    (see get_tiling_and_scores in this file)
+                #
+                # 3. If a candidate node (node2) uses a different loop order (e.g., (z,x,y,r)),
+                #    its tiling is incompatible with native matmul tiling (z,y,x,r).
+                #    This means _split_iteration_ranges will fail, so these nodes should not be fused.
+                tiling = self.select_tiling(node1.get_nodes(), numel1, rnumel1)
+                if not all(
+                    SIMDKernel.is_compatible(
+                        tiling.values(), n2.get_ranges(), reduction_numel=rnumel1
+                    )
+                    for n2 in node2.get_nodes()
+                ):
+                    why("invalid loop order and tiling for native matmul")
+                    return False
+
+            return reduction_can_fuse
+
+        if not node1.is_reduction() and not node2.is_reduction():
+            if not (numel1 == numel2 and rnumel1 == rnumel2):
+                if not node2.is_template():
+                    why(
+                        "numel/rnumel mismatch (non-reduce) (%s, %s), (%s, %s)",
+                        numel1,
+                        numel2,
+                        rnumel1,
+                        rnumel2,
+                    )
+                    return False
+                else:
+                    # prologue fusion input sizes differ from output group
+                    # fuse so long as this node matches the group of existing prologue nodes
+                    for node in node2.get_nodes():
+                        # dont need to check epilogue nodes for prologue fusion, break after template
+                        if node.is_template():
+                            break
+                        # we would have already restricted prologue from fusing if it had multiple
+                        # uses, so it must be fusing into this node
+                        if not node.used_buffer_names() & node1.get_buffer_names():
+                            continue
+                        _, (pro_numel, pro_rnumel) = node.group
+                        if not (numel1 == pro_numel and rnumel1 == pro_rnumel):
+                            why(
+                                "numel/rnumel mismatch prologue mismatch (%s, %s), (%s, %s)",
+                                numel1,
+                                pro_numel,
+                                rnumel1,
+                                pro_rnumel,
+                            )
+                            return False
+
+            for n in (node1, node2):
+                if n.is_template():
+                    return True
+
+            # check for a bad combined tiling
+            tiling1 = self.select_tiling(node1.get_nodes(), numel1, rnumel1)
+            tiling2 = self.select_tiling(node2.get_nodes(), numel1, rnumel1)
+            tiling3 = self.select_tiling(
+                node1.get_nodes() + node2.get_nodes(), numel1, rnumel1
+            )
+            if config.triton.tiling_prevents_pointwise_fusion:
+                cond = True
+                if len(tiling1) > 2:
+                    if len(tiling2) > 2:
+                        cond = tiling1 == tiling2 == tiling3
+                    else:
+                        cond = tiling1 == tiling3
+                elif len(tiling2) > 2:
+                    cond = tiling2 == tiling3
+                if not cond:
+                    why(
+                        "tiling mismatch (%s, %s, %s)",
+                        tiling1,
+                        tiling2,
+                        tiling3,
+                    )
+                    return False
+
+            return True
+
+        if not node1.is_reduction() and node2.is_reduction():
+            assert rnumel1 == 1 and rnumel2 != 1
+            if numel1 == numel2 * rnumel2:
+                if not all(
+                    SIMDKernel.is_compatible((numel2, rnumel2), n.get_ranges())
+                    for n in node1.get_nodes()
+                ):
+                    why("nodes numel/rnumel incompatibility")
+                    return False
+                if (
+                    config.triton.tiling_prevents_reduction_fusion
+                    and not node1.is_template()
+                ):
+                    is_reduction_tiling_valid = tuple(
+                        self.select_tiling(node1.get_nodes(), numel1).values()
+                    ) in (
+                        (numel1, 1),
+                        (numel2, rnumel2, 1),
+                    )
+                    if not is_reduction_tiling_valid:
+                        why("invalid tiling for reduction")
+                    return is_reduction_tiling_valid
+                return True
+
+            if numel1 != numel2:
+                why("nodes numel incompatibility")
+            return numel1 == numel2
+
+        assert node1.is_reduction() and not node2.is_reduction()
+        # swap args to hit the case above
+        return self.can_fuse_horizontal(node2, node1)
+
+    can_fuse_vertical = can_fuse
+    can_fuse_horizontal = can_fuse
+
+    def generate_node_schedule(self, nodes, numel, rnumel):
+        node_schedule: list[Any] = []
+        done = OrderedSet[scheduler.BaseSchedulerNode]()
+        # Writes with a reduced shape, meaning they are only present once the
+        # reduction loop has ended
+        not_ready_yet_nodes: OrderedSet[str] = OrderedSet()
+        current_loop_buffer_usage: OrderedSet[str] = OrderedSet()
+        maybe_split_index: Optional[int] = None
+
+        def fits_in_main_body(n):
+            _, (node_numel, node_rnumel) = n.group
+            return (node_numel == numel and node_rnumel == rnumel) or (
+                node_numel == numel * rnumel and node_rnumel == 1
+            )
+
+        def fits_outside_reduction(n):
+            _, (node_numel, node_rnumel) = n.group
+            return node_numel == numel and node_rnumel == 1 and rnumel != 1
+
+        def expect_improved_memory_usage(n):
+            for read in n.read_writes.reads:
+                if read.name in current_loop_buffer_usage:
+                    return True
+            return False
+
+        def schedule_node_in_loop(n):
+            done.add(n)
+            node_schedule.append(n)
+            current_loop_buffer_usage.update([x.name for x in n.read_writes.reads])
+
+            # A scan is modelled as a reduction in the scheduler but has a
+            # full sized output that can be used inside the loop body
+            if (
+                n.is_reduction()
+                and isinstance(n, scheduler.SchedulerNode)
+                and isinstance(n.node, ir.ComputedBuffer)
+                and not isinstance(n.node.data, ir.Scan)
+            ):
+                not_ready_yet_nodes.add(n.get_name())
+            else:  # this node is available within the loop
+                current_loop_buffer_usage.update([x.name for x in n.read_writes.writes])
+
+        @contextlib.contextmanager
+        def end_current_reduction_loop():
+            nonlocal maybe_split_index
+            if node_schedule and node_schedule[-1] is EnableReduction:
+                node_schedule.pop()
+            else:
+                node_schedule.append(DisableReduction)
+            if maybe_split_index:
+                node_schedule.insert(maybe_split_index, DisableReduction)
+                node_schedule.insert(maybe_split_index + 1, EnableReduction)
+                maybe_split_index = None
+            yield
+            node_schedule.append(EnableReduction)
+            not_ready_yet_nodes.clear()
+            current_loop_buffer_usage.clear()
+
+        def requires_closing_previous_reduction(node, node_schedule):
+            if rnumel == 1:
+                return False
+            if not not_ready_yet_nodes & node.ancestors:
+                return False
+            assert node_schedule and not isinstance(
+                node_schedule[-1], (EnableReduction, DisableReduction)
+            )
+            return bool(not_ready_yet_nodes)
+
+        for node in nodes:
+            if node in done:
+                continue
+            done.add(node)
+
+            if fits_in_main_body(node):
+                if requires_closing_previous_reduction(node, node_schedule):
+                    with end_current_reduction_loop():
+                        pass  # need to start a new reduction loop
+
+                if current_loop_buffer_usage and not expect_improved_memory_usage(node):
+                    # If we don't improve memory usage, then it is better to split into two loops
+                    maybe_split_index = maybe_split_index or len(node_schedule)
+                else:
+                    # Memory usage got improved, cancel the loop split
+                    maybe_split_index = None
+
+                schedule_node_in_loop(node)
+            elif fits_outside_reduction(node):
+                with end_current_reduction_loop():
+                    node_schedule.append(node)
+            else:
+                raise NotImplementedError(
+                    f"unexpected group: ({numel}, {rnumel}) != {node.group[1]}"
+                )
+
+        return node_schedule
+
+    def codegen_mix_order_reduction(self, node):
+        node1, node2 = node.node1, node.node2
+
+        # Make sure there are no producer/consumer relationship
+        assert not (node1.ancestors & node2.get_operation_names()) and not (
+            node2.ancestors & node1.get_operation_names()
+        )
+
+        self._codegen_mix_order_reduction(node1, node2)
+
+    def _split_mix_order_reduction_epilogue(self, node):
+        # TODO: do more validation here
+        nodes = node.get_nodes()
+        reductions = []
+        epilogues = []
+        for node in nodes:
+            if node.is_reduction():
+                reductions.append(node)
+            else:
+                epilogues.append(node)
+        return reductions, epilogues
+
+    def _generate_kernel_code_for_mix_order_reduction(
+        self, kernel_features, split_size, for_benchmark
+    ):
+        """
+        for_benchmark:
+            True if the generated code is for benchmarking. We need make
+            sure benchmark harness code is generated.
+        """
+        numel, rnumel = kernel_features.numel, kernel_features.reduction_numel
+        node_schedule = kernel_features.node_schedule
+
+        kernel = self.create_kernel_choices(
+            kernel_features,
+            [{"x": numel, "r0_": rnumel}],
+            {
+                "features": kernel_features,
+                "tiling_scores": None,
+                "mix_order_reduction": True,
+                "override_persistent_reduction": True,
+            },
+        )[0]
+        assert kernel.persistent_reduction
+        assert kernel.mix_order_reduction
+        kernel.rsplit_size = split_size
+        self.codegen_node_schedule_with_kernel(node_schedule, kernel)
+
+        # allocate workspace for this kernel
+        _, ws_name, ws_off = kernel.args.workspace(
+            len(kernel.saved_partial_accumulate)
+            * kernel.numels["r0_"]
+            * ((kernel.numels["x"] + kernel.rsplit_size - 1) // kernel.rsplit_size),
+            False,
+            dtype=torch.float,
+        )
+        assert ws_off == 0, f"{ws_off=}"
+        with kernel:
+            kernel.codegen_body()
+
+        stack = contextlib.ExitStack()
+        with V.set_kernel_handler(kernel), stack:
+            if for_benchmark:
+                stack.enter_context(config.patch(benchmark_kernel=True))
+            src_code = kernel.codegen_kernel()
+
+        if for_benchmark:
+            # only do this if we are doing benchmarking.
+            # When we are generating final code, the kernel name
+            # should be decided differently with node type, fx node name
+            # etc.
+            src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_")
+        return kernel, ws_name, src_code
+
+    def benchmark_codegened_module(
+        self, mod, n_spills_threshold=8, node_names: Optional[OrderedSet[str]] = None
+    ) -> tuple[float, str]:
+        raise NotImplementedError
+
+    def _codegen_mix_order_reduction(self, node1, node2):
+        numel, rnumel = scheduler.MixOrderReduction.get_numel_rnumel(node1)
+
+        if not V.graph.sizevars.evaluate_expr(sympy.Gt(numel, rnumel)):
+            return self._codegen_mix_order_reduction(node2, node1)
+
+        def _pick_split_size():
+            # the overridden has highest priority
+            if config.triton.mix_order_reduction_split_size is not None:
+                return config.triton.mix_order_reduction_split_size
+
+            # heuristics based on number of SMs
+            device_prop = DeviceProperties.create(node1.get_device())
+            num_sm = device_prop.multi_processor_count
+            estimated_num_splits = num_sm * 8
+
+            # split_size is decided based on hint
+            numel_hint = V.graph.sizevars.size_hint(numel)
+            split_size = max(last_power_of_2(numel_hint // estimated_num_splits), 16)
+            split_size = min(split_size, 128)
+            return split_size
+
+        split_size = _pick_split_size()
+
+        # pyrefly: ignore [bad-assignment]
+        metrics.codegen_mix_order_reduction += 1
+
+        assert V.graph.sizevars.evaluate_expr(sympy.Gt(numel, rnumel))
+
+        # split epilogue out of node2
+        node2_reductions, node2_epilogue = self._split_mix_order_reduction_epilogue(
+            node2
+        )
+
+        converted_nodes = []
+        for subnode in node2_reductions:
+            subnode.cancel_reduction_split()
+            converted = subnode.extract_pw_from_reduction()
+            converted.swap_pw_red_dimension()
+            converted_nodes.append(converted)
+        node_schedule = self.generate_node_schedule(
+            node1.get_nodes() + converted_nodes, numel, rnumel
+        )
+        kernel_features = SIMDKernelFeatures(node_schedule, numel, rnumel)
+
+        # The autotuning is skipped in deterministic mode
+        if (
+            not torch._inductor.config.deterministic
+            and config.triton.mix_order_reduction_split_size is None
+            and (
+                config.triton.mix_order_reduction_autotune_split_size
+                or config.max_autotune
+                or config.coordinate_descent_tuning
+            )
+        ):
+
+            def _bench(candidate_split_size):
+                _, _, src_code = self._generate_kernel_code_for_mix_order_reduction(
+                    kernel_features,
+                    split_size=candidate_split_size,
+                    for_benchmark=True,
+                )
+                mod = PyCodeCache.load(src_code)
+                ms, _ = self.benchmark_codegened_module(mod)
+                return ms
+
+            split_size = CoordescTuner.autotune_single_field(
+                _bench,
+                split_size,
+                8,
+            )
+
+        kernel, ws_name, src_code = self._generate_kernel_code_for_mix_order_reduction(
+            kernel_features,
+            split_size=split_size,
+            for_benchmark=False,
+        )
+
+        # rename intermediate reduction output to final reduction
+        # output
+        is_split_reduction = bool(node2_reductions[0].node._split_size)
+        rename = {}
+        if is_split_reduction:
+            for subnode in node2_reductions:
+                bufname = subnode.get_outputs()[0].node.get_name()
+                username = (
+                    subnode.get_outputs()[0]
+                    .users[0]
+                    .node.get_outputs()[0]
+                    .node.get_name()
+                )
+                rename[bufname] = username
+                assert self.scheduler
+                self.scheduler.removed_ops.add(
+                    subnode.get_outputs()[0].users[0].node.get_name()
+                )
+                V.graph.removed_buffers.add(bufname)
+
+            for partial_accum in kernel.saved_partial_accumulate:
+                partial_accum.buffer_name = rename.get(
+                    partial_accum.buffer_name, partial_accum.buffer_name
+                )
+
+        kernel_name = self.define_kernel(src_code, node_schedule, kernel)
+        kernel.kernel_name = kernel_name
+        kernel.code_hash = code_hash(src_code)
+
+        with V.set_kernel_handler(kernel):
+            for node in kernel_features.scheduler_nodes():
+                # No need to allocate buffer for split reduction
+                # since we are gonna to allocate workspace to store the
+                # intermediate reduction reduction
+                if node.get_outputs()[0].node.get_name() not in rename:
+                    node.mark_run()
+
+        V.graph.wrapper_code.make_comment("# Call mix order reduction kernel")
+        self.codegen_comment(node_schedule, None)
+        # workspace args is still needed after the call
+        kernel.call_kernel(kernel.kernel_name, deallocate_ws=False)
+        V.graph.removed_buffers |= kernel.removed_buffers
+        V.graph.inplaced_to_remove |= kernel.inplaced_to_remove
+
+        # a extra round of reduction
+        assert len(converted_nodes) == len(kernel.saved_partial_accumulate)
+        nsplit = V.graph.wrapper_code.codegen_python_sizevar(
+            (numel + split_size - 1) // split_size
+        )
+        for idx, partial_accum in enumerate(kernel.saved_partial_accumulate):
+            buffer_name = partial_accum.buffer_name
+
+            stride_str = f"{nsplit} * {rnumel}"
+            start = f"{idx} * {stride_str}"
+            end = f"({idx} + 1) * {stride_str}"
+            reduction_type2op = {
+                "min": "amin",
+                "max": "amax",
+            }
+            opname = reduction_type2op.get(
+                partial_accum.reduction_type, partial_accum.reduction_type
+            )
+
+            V.graph.wrapper_code.writeline(
+                f"{buffer_name} = {ws_name}[{start} : {end}].view({nsplit}, {rnumel}).{opname}(dim=0)",
+            )
+            # mark the buffer as allocated, so we don't try to allocate
+            # it again when it's later used
+            V.graph.wrapper_code.allocated.add(buffer_name)
+
+        kernel.deallocate_workspaces()
+
+        if node2_epilogue:
+            self._codegen_nodes(node2_epilogue)
+
+        self.free_buffers_in_scheduler()
+
+    def _codegen_nodes(
+        self,
+        nodes: Sequence[scheduler.SchedulerNode],
+        coalesce_analysis: Optional[CoalesceVarAnalysis] = None,
+    ):
+        assert self.scheduler
+        nodes = [
+            node for node in nodes if node.get_name() not in self.scheduler.removed_ops
+        ]
+        if not nodes:
+            return
+        _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group
+
+        node_schedule = self.generate_node_schedule(nodes, numel, rnumel)
+        schedule_log.debug("Schedule:\n %s", node_schedule)
+
+        return self.codegen_node_schedule(
+            SIMDKernelFeatures(node_schedule, numel, rnumel, coalesce_analysis)
+        )
+
+    def codegen_node(
+        self, node: Union[scheduler.FusedSchedulerNode, scheduler.SchedulerNode]
+    ):
+        """
+        Given a set of pre-fused nodes, generate a Triton kernel.
+        """
+        assert self.scheduler
+        nodes = [
+            node
+            for node in node.get_nodes()
+            if node.get_name() not in self.scheduler.removed_ops
+        ]
+        if len(nodes) == 0:
+            return
+
+        if torch._inductor.config.triton.coalesce_tiling_analysis:
+            if len(nodes) != len(node.get_nodes()):
+                assert self.scheduler
+                node = scheduler.FusedSchedulerNode(self.scheduler, nodes)
+            coalesce_analysis = analyze_memory_coalescing(node)
+        else:
+            coalesce_analysis = None
+
+        return self._codegen_nodes(nodes, coalesce_analysis)  # type: ignore[arg-type]
+
+    @staticmethod
+    def can_use_32bit_indexing(
+        numel: sympy.Expr,
+        buffers: Iterable[
+            Union[ir.Buffer, ir.TensorBox, ir.TorchBindObject, ir.IRNode]
+        ],
+    ) -> bool:
+        int_max = torch.iinfo(torch.int32).max
+
+        if not expr_fits_within_32bit(numel):
+            return False
+
+        # Any use of a MultiOutputLayout will create a buffer with a
+        # Layout whose sizes are accounted for
+        buf_sizes = [
+            buf.get_layout().storage_size()
+            for buf in buffers
+            if buf.has_tensor_output()
+        ]
+
+        for buf in buffers:
+            if not buf.has_tensor_output() and isinstance(buf, ir.MutationOutput):
+                mutated_bufs = buf.get_mutation_buffers()
+                buf_sizes += [
+                    buf.get_layout().storage_size()
+                    for buf in mutated_bufs
+                    if buf.has_tensor_output()
+                ]
+
+        if not all(expr_fits_within_32bit(size) for size in buf_sizes):
+            return False
+
+        # Only install guards for 32-bit indexing as there is no correctness
+        # issue with using 64-bit for everything
+        V.graph.sizevars.check_leq(numel, int_max)  # type: ignore[arg-type]
+        for size in buf_sizes:
+            V.graph.sizevars.check_leq(size, int_max)  # type: ignore[arg-type]
+        return True
+
+    def codegen_node_schedule(self, kernel_features: SIMDKernelFeatures):
+        """
+        Generate code for nodes in kernel_features
+        """
+        node_schedule = kernel_features.node_schedule
+
+        tiling, tiling_score = self.get_tiling_and_scores(
+            node_schedule,
+            kernel_features.numel,
+            kernel_features.reduction_numel,
+            kernel_features.coalesce_analysis,
+        )
+        kernels = self.create_kernel_choices(
+            kernel_features,
+            [tiling],
+            {"features": kernel_features, "tiling_scores": tiling_score},
+        )
+        for kernel in kernels:
+            self.codegen_node_schedule_with_kernel(node_schedule, kernel)
+        MultiKernel.merge_workspaces_inplace(kernels)
+        for kernel in kernels:
+            with V.set_kernel_handler(kernel):
+                src_code = kernel.codegen_kernel()
+            kernel_name = self.define_kernel(src_code, node_schedule, kernel)
+            log.debug("Generating kernel code with kernel_name: %s", kernel_name)
+            kernel.kernel_name = kernel_name
+            kernel.code_hash = code_hash(src_code)
+        del kernel
+
+        final_kernel: Union[SIMDKernel, MultiKernel]
+        if len(kernels) > 1:
+            final_kernel = MultiKernel(kernels)
+        else:
+            (final_kernel,) = kernels
+
+        with V.set_kernel_handler(final_kernel):
+            for node in kernel_features.scheduler_nodes():
+                node.mark_run()
+
+        # filter out NodeScheduleMarker
+        base_scheduler_nodes = [
+            node for node in node_schedule if isinstance(node, BaseSchedulerNode)
+        ]
+        self.codegen_comment(base_scheduler_nodes, final_kernel.kernel_name)
+        if config.cpp.enable_kernel_profile:
+            V.graph.wrapper_code.write_kernel_context_guard_begin()
+            V.graph.wrapper_code.write_kernel_context_guard(
+                final_kernel.kernel_name,
+                base_scheduler_nodes,  # type: ignore[arg-type]
+            )
+        final_kernel.call_kernel(final_kernel.kernel_name)
+        if config.cpp.enable_kernel_profile:
+            V.graph.wrapper_code.write_kernel_context_guard_end()
+
+        if config.nan_asserts:
+            final_kernel.codegen_nan_check()
+        if config.warn_mix_layout:
+            final_kernel.warn_mix_layout(kernels[0].kernel_name)
+
+        V.graph.removed_buffers |= final_kernel.removed_buffers
+        V.graph.inplaced_to_remove |= final_kernel.inplaced_to_remove
+
+        if (
+            V.graph.wrapper_code.supports_intermediate_hooks  # type: ignore[has-type]
+            and config.generate_intermediate_hooks
+        ):
+            # Not every node in the schedule will actually be live on output;
+            # we can't check dead buffers.
+            live_outs = kernels[0].args.live_output_buffers()
+            for node in kernel_features.scheduler_nodes():
+                name = node.get_name()
+                if name not in live_outs:
+                    continue
+                assert node.node is not None
+                origin_node = node.node.get_origin_node()
+                if origin_node is not None:
+                    counters["inductor"]["intermediate_hooks"] += 1
+                    V.graph.wrapper_code.writeline(
+                        f"run_intermediate_hooks({origin_node.name!r}, {name})"
+                    )
+
+        self.free_buffers_in_scheduler()
+
+    def create_kernel_choices(
+        self, kernel_features: SIMDKernelFeatures, kernel_args, kernel_kwargs
+    ) -> list[SIMDKernel]:
+        return [
+            self.kernel_type(
+                *kernel_args,
+                **kernel_kwargs,
+            )
+        ]
+
+    def codegen_node_schedule_with_kernel(self, node_schedule, kernel):
+        with kernel:
+            stack = contextlib.ExitStack()
+            all_indexing = {}
+
+            # First pass to collect indexing and decide inplace updates
+            for node in node_schedule:
+                if node is DisableReduction:
+                    stack.enter_context(kernel.disable_reduction())
+                elif node is EnableReduction:
+                    stack.close()
+                else:
+                    node.decide_inplace_update()
+                    index_vars = kernel.split_and_set_ranges(node.get_ranges())
+                    all_indexing.update(
+                        dict.fromkeys(
+                            node._body.indexing_from_args(index_vars).values()
+                        )
+                    )
+
+            kernel.finalize_indexing(all_indexing.keys())
+
+            # Second pass to do codegen
+            for node in node_schedule:
+                if node is DisableReduction:
+                    stack.enter_context(kernel.disable_reduction())
+                elif node is EnableReduction:
+                    stack.close()
+                else:
+                    # TODO - use split ranges ?
+                    indexing_dtype_strength_reduction(node._body)
+                    index_vars = kernel.split_and_set_ranges(node.get_ranges())
+                    node.codegen(index_vars)
+
+    def _codegen_single_template(
+        self,
+        kernel,
+        render,
+        template_node,
+        epilogue_nodes,
+        prologue_nodes,
+        *,
+        only_gen_src_code=False,
+    ):
+        """
+        Helper method to codegen a single template kernel variant
+        """
+        buf_name_to_prologue_group = {}
+        template_reads = template_node.used_buffer_names()
+        prologue_group = []
+        for prologue in prologue_nodes:
+            names = prologue.get_buffer_names()
+            prologue_group.append(prologue)
+            # this must be the end of a prologue group
+            if names & template_reads:
+                assert len(names) == 1
+                buf_name_to_prologue_group[next(iter(names))] = prologue_group
+                kernel.prologue_fused_inputs.add(next(iter(names)))
+                prologue_group = []
+
+        # all prologue groups should have finalized with use in template
+        assert len(prologue_group) == 0
+
+        with kernel:
+            if not only_gen_src_code:
+                # prologue nodes can only be fused if their only use is in the template,
+                # so they are necessarily not allocated
+                for node in [template_node, *epilogue_nodes]:
+                    node.mark_run()
+
+            partial_code = render()
+
+            num_store_subgraphs = kernel.get_store_output_count()
+            for i in range(num_store_subgraphs):
+                subgraph_name = kernel._get_store_output_subgraph_name(i)
+                with kernel.set_subgraph_body(subgraph_name):
+                    for node in epilogue_nodes:
+                        node.codegen(kernel.split_and_set_ranges(node.get_ranges()))
+                    kernel.cse.invalidate(OrderedSet())
+
+            for input_name, buffer in kernel.named_input_nodes.items():
+                subgraph_name = f""
+                if prologue_group := buf_name_to_prologue_group.get(
+                    buffer.get_name(), []
+                ):
+                    can_codegen_without_upcast = all(
+                        p_n.can_codegen_without_upcasts() for p_n in prologue_group
+                    )
+
+                    # TODO - this doesn't work with libdevice calls, potentially other bugs
+                    # upcasting to fp32 and downcasting gives large slowdown
+                    with config.patch(
+                        "triton.codegen_upcast_to_fp32", not can_codegen_without_upcast
+                    ):
+                        with kernel.set_subgraph_body(subgraph_name):
+                            for prologue_node in prologue_group:
+                                if (
+                                    len(prologue_node.get_buffer_names()) == 1
+                                    and len(prologue_group) == 1
+                                ):
+                                    if prologue_preserves_zero_mask(prologue_node):
+                                        kernel.prologue_fused_inputs_preserve_zero |= (
+                                            prologue_node.get_buffer_names()
+                                        )
+
+                                prologue_node.codegen(
+                                    kernel.split_and_set_ranges(
+                                        prologue_node.get_ranges()
+                                    )
+                                )
+                            kernel.cse.invalidate(OrderedSet())
+
+        # Template hooks must be finalised after kernel.remove_kernel_local_buffers
+        # is called (this is called when the kernel context is exited above), and when
+        # the kernel handler is set (as below). This is because the hooks may add
+        # DeferredLine type lines, which preclude lines involving buffers that have
+        # been removed
+
+        # finalize must be called after adding epilogue above
+        with V.set_kernel_handler(kernel):
+            if not isinstance(partial_code, str):
+                # This is used to calculate flops in TritonTemplateKernels
+                with ir.IRNode.current_origins(template_node.node.origins):
+                    partial_code.finalize_hook("")
+                partial_code.finalize_hook("", strict=False)
+
+            # TODO: Maybe unify CUDATemplateKernel to also use PartialRender for flexible epilogue fusion.
+
+            for input_name in kernel.named_input_nodes:
+                subgraph_name = f""
+                # pyrefly: ignore [missing-attribute]
+                partial_code.finalize_hook(subgraph_name, strict=False)
+
+            num_store_subgraphs = kernel.get_store_output_count()
+            for i in range(num_store_subgraphs):
+                subgraph_name = kernel._get_store_output_subgraph_name(i)
+                # pyrefly: ignore [missing-attribute]
+                partial_code.finalize_hook(subgraph_name)
+
+            if isinstance(partial_code, str):
+                src_code = partial_code
+            else:
+                # Ensure all hooks are finalized before the kernel is defined.
+                # Note: some of these hooks may have been registered by a kernel subclass
+                src_code = partial_code.finalize_remaining()
+
+            node_schedule = [*prologue_nodes, template_node, *epilogue_nodes]
+
+            if config.benchmark_kernel:
+                num_gb = kernel.estimate_kernel_num_bytes() / 1e9
+                src_code = (
+                    f"{kernel.imports_for_benchmark_kernel()}\n"
+                    f"{src_code}\n"
+                    f"{kernel.codegen_kernel_benchmark(num_gb).getvalue()}"
+                )
+
+            if only_gen_src_code:
+                return src_code
+
+            kernel.kernel_name = self.define_kernel(src_code, node_schedule, kernel)
+
+            return kernel
+
+    def _get_multikernel_shapes(
+        self, node: MultiTemplateBuffer
+    ) -> tuple[tuple[int, ...], ...]:
+        from ..ir import IRNode
+
+        def get_size(arg):
+            if not isinstance(arg, IRNode):
+                return None
+            if isinstance(arg, ir.BaseView):  # triton templates want the base tensor.
+                arg = arg.unwrap_view()
+            if (size := arg.maybe_get_size()) is None:
+                return None
+            return tuple(s for s in size)
+
+        out = []
+        for arg in list(node.inputs) + [node]:
+            if isinstance(arg, (list, tuple)):
+                out.append(tuple(get_size(_arg) for _arg in arg))
+            else:
+                out.append(get_size(arg))
+        return tuple(out)
+
+    def _kernel_has_dynamic_shapes(self, node: MultiTemplateBuffer) -> bool:
+        shapes = self._get_multikernel_shapes(node)
+        return any(
+            any(
+                isinstance(s, sympy.Expr) and not isinstance(s, sympy.Integer)
+                for s in shape
+            )
+            for shape in shapes
+        )
+
+    def _make_shape_cache_key(
+        self, node: MultiTemplateBuffer, hint: int
+    ) -> tuple[tuple[int, ...], ...]:
+        """
+        Returns cache key for hint-based multi-graph; key is tuple of shapes with hint filled in.
+        """
+        shapes = self._get_multikernel_shapes(node)
+        return tuple(
+            tuple(
+                hint
+                if isinstance(s, sympy.Expr) and not isinstance(s, sympy.Integer)
+                else s
+                for s in shape
+            )
+            for shape in shapes
+        )
+
+    def codegen_template(
+        self,
+        template_node,
+        epilogue_nodes,
+        prologue_nodes,
+        *,
+        only_gen_src_code=False,
+        hint_override: Optional[int] = None,
+    ) -> Optional[str]:
+        """
+        Codegen a triton template with multi-kernel dispatch support
+
+        If `only_gen_src_code=True` the src code will be returned instead of being
+        codegenned into the wrapper
+        """
+
+        _, (_numel, rnumel) = template_node.group
+        assert rnumel == 1
+
+        if (
+            isinstance(template_node.node, MultiTemplateBuffer)
+            and template_node.node._make_kernel_renders
+            and len(template_node.node._make_kernel_renders) > 1
+            and self._kernel_has_dynamic_shapes(template_node.node)
+        ):
+            kernels = {}
+            src_codes = []
+
+            for (
+                size_hint,
+                make_kernel_render,
+            ) in template_node.node._make_kernel_renders.items():
+                kernel, render = make_kernel_render(
+                    template_node.node, hint_override=hint_override
+                )
+
+                if only_gen_src_code:
+                    src_code = self._codegen_single_template(
+                        kernel,
+                        render,
+                        template_node,
+                        epilogue_nodes,
+                        prologue_nodes,
+                        only_gen_src_code=True,
+                    )
+                    assert isinstance(src_code, str)
+                    # pyrefly: ignore [bad-argument-type]
+                    src_codes.append(src_code)
+                else:
+                    if size_hint is None:
+                        continue  # skip kernel generation based on real runtime value; only use hints
+                    kernel = self._codegen_single_template(
+                        kernel,
+                        render,
+                        template_node,
+                        epilogue_nodes,
+                        prologue_nodes,
+                        only_gen_src_code=False,
+                    )
+                    shape_cache_key = (
+                        None
+                        if size_hint is None
+                        else self._make_shape_cache_key(template_node.node, size_hint)
+                    )
+                    kernels[shape_cache_key] = kernel
+
+            if only_gen_src_code:
+                return "\n\n".join(src_codes)
+
+            MultiKernel.merge_workspaces_inplace(list(kernels.values()))
+            multi_kernel = SizeHintMultiKernel(kernels)
+            node_schedule = [*prologue_nodes, template_node, *epilogue_nodes]
+            self.codegen_comment(node_schedule, multi_kernel.kernel_name)
+            multi_kernel.call_kernel(multi_kernel.kernel_name)
+            V.graph.removed_buffers |= multi_kernel.removed_buffers
+            V.graph.inplaced_to_remove |= multi_kernel.inplaced_to_remove
+            self.free_buffers_in_scheduler()
+            return None
+        else:
+            kernel, render = template_node.node.make_kernel_render(
+                template_node.node, hint_override=hint_override
+            )
+
+            if only_gen_src_code:
+                return self._codegen_single_template(
+                    kernel,
+                    render,
+                    template_node,
+                    epilogue_nodes,
+                    prologue_nodes,
+                    only_gen_src_code=True,
+                )
+            else:
+                kernel = self._codegen_single_template(
+                    kernel,
+                    render,
+                    template_node,
+                    epilogue_nodes,
+                    prologue_nodes,
+                    only_gen_src_code=False,
+                )
+
+                node_schedule = [*prologue_nodes, template_node, *epilogue_nodes]
+                self.codegen_comment(node_schedule, kernel.kernel_name)
+                kernel.call_kernel(kernel.kernel_name, template_node.node)
+
+                V.graph.removed_buffers |= kernel.removed_buffers
+                V.graph.inplaced_to_remove |= kernel.inplaced_to_remove
+                self.free_buffers_in_scheduler()
+                return None
+
+    def codegen_sync(self):
+        V.graph.wrapper_code.writeline(V.graph.device_ops.synchronize())
+
+    def generate_combo_kernel_code(
+        self,
+        subkernel_nodes: list[BaseSchedulerNode],
+        custom_part_algorithm: bool,
+        enable_autotune: bool,
+        mixed_sizes: bool,
+        only_gen_src_code: bool = False,
+    ) -> list[tuple[str, Any, Any]]:
+        from .triton_combo_kernel import ComboKernel
+
+        fused_node_lists = [node.get_nodes() for node in subkernel_nodes]
+        subkernel_map, node_schedule_map = {}, {}
+        for pn, nodes in zip(subkernel_nodes, fused_node_lists):
+            _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group
+            node_schedule = self.generate_node_schedule(nodes, numel, rnumel)
+            tiling = self.select_tiling(node_schedule, numel, rnumel)
+            node_schedule_map[pn] = node_schedule, tiling, numel, rnumel
+            subkernel_map[pn] = ComboKernel.create_triton_kernel(
+                tiling,
+                features=SIMDKernelFeatures(node_schedule, numel, rnumel),
+                optimize_mask=not mixed_sizes,
+            )
+
+        partitions = ComboKernel.horizontal_partition(
+            nodes=subkernel_nodes,
+            triton_scheduling=self,
+            custom_algorithm=custom_part_algorithm,
+            kernel_map=subkernel_map,
+            node_info_map=node_schedule_map,
+        )
+        log.debug(
+            "ComboKernels: %d nodes partitioned into %s groups",
+            len(subkernel_nodes),
+            [len(p) for p in partitions],
+        )
+        kernel_code_list = []
+        for node_group in partitions:
+            if len(node_group) == 0:
+                continue
+            kernel = ComboKernel(
+                enable_autotune=enable_autotune,
+                mixed_sizes=mixed_sizes,
+            )
+
+            for pn in node_group:
+                self.codegen_node_schedule_with_kernel(
+                    node_schedule_map[pn][0],
+                    kernel.create_sub_kernel(subkernel_map[pn]),
+                )
+                subkernel = subkernel_map[pn]
+                node_schedule = node_schedule_map[pn][0]
+                if not only_gen_src_code:
+                    with V.set_kernel_handler(subkernel):  # type: ignore[call-arg]
+                        for node in NodeScheduleMarker.only_nodes(node_schedule):
+                            node.mark_run()
+                V.graph.removed_buffers |= subkernel.removed_buffers
+                V.graph.inplaced_to_remove |= subkernel.inplaced_to_remove
+
+            src_code = kernel.codegen_kernel()
+            kernel_code_list.append((src_code, kernel, node_group))
+        return kernel_code_list
+
+    def codegen_combo_kernel(self, combo_kernel_node):
+        subkernel_nodes = combo_kernel_node.get_subkernel_nodes()
+        custom_part_algorithm = combo_kernel_node.use_custom_partition_algo
+        enable_autotune = combo_kernel_node.enable_autotune
+        mixed_sizes = config.combo_kernel_allow_mixed_sizes > 1 or (
+            config.combo_kernel_allow_mixed_sizes == 1 and custom_part_algorithm
+        )
+
+        kernel_code_list = self.generate_combo_kernel_code(
+            subkernel_nodes, custom_part_algorithm, enable_autotune, mixed_sizes
+        )
+
+        for src_code, kernel, _ in kernel_code_list:
+            kernel_name = self.define_kernel(src_code, [combo_kernel_node], kernel)
+            self.codegen_comment(combo_kernel_node.snodes, kernel_name)
+            log.debug("ComboKernels: generated kernel %s.", kernel_name)
+            kernel.call_kernel(V.graph.wrapper_code, kernel_name)
+
+        self.free_buffers_in_scheduler()
+
+    @classmethod
+    @functools.lru_cache(32)
+    def candidate_tilings(cls, node, numel, reduction_numel) -> list[CandidateTiling]:
+        is_pointwise = reduction_numel == 1
+
+        def tile_ranges(is_pointwise: bool, ranges, rw) -> list[CandidateTiling]:
+            """
+            Compute tiling candidates by dividing up the iteration ranges.
+            """
+            assert len(rw.range_vars) == len(ranges), f"{rw.range_vars=} {ranges=}"
+
+            # isinstance(dep, MemoryDep): this filters out StarDeps. StarDeps refer to reads
+            # that need to access the entire tensor; they don't contribute read indexing
+            # information (and practically, they don't have dep.index so they can't be used
+            # for stride_hints below
+            dep_sources = [rw.reads, rw.writes]
+            assert all(
+                isinstance(dep, (MemoryDep, StarDep))
+                for dep in itertools.chain.from_iterable(dep_sources)
+            )
+            deps = [
+                dep
+                for dep in itertools.chain.from_iterable(dep_sources)
+                if dep.name not in V.graph.removed_buffers
+                and isinstance(dep, MemoryDep)
+            ]
+            write_names = OrderedSet([dep.name for dep in rw.writes])
+
+            def collapse_ranges(ranges: Sequence[sympy.Expr]) -> sympy.Expr:
+                return V.graph.sizevars.simplify(sympy_product(ranges))
+
+            # Default to no tiling.
+            tilings = [
+                CandidateTiling(
+                    tiling=cls.create_partial_tiling(
+                        [collapse_ranges(ranges)], is_pointwise
+                    ),
+                    name="none",
+                    score=0,
+                )
+            ]
+
+            # Find non-trivial tiling candidates.
+            for dep in deps:
+                strides = V.graph.sizevars.stride_hints(dep.index, rw.range_vars)
+                assert len(strides) == len(ranges)
+                try:
+                    split = strides.index(1) + 1
+                    if split == len(ranges):
+                        continue
+                    if all(s == 0 for s in strides[split:]):
+                        # if this is a broadcasted tensor and all dimensions after split are broadcast,
+                        # this is not a real split
+                        continue
+
+                except ValueError:
+                    continue
+
+                tiled_groups = (
+                    collapse_ranges(ranges[:split]),
+                    collapse_ranges(ranges[split:]),
+                )
+
+                # score by number of elements
+                score = V.graph.sizevars.size_hint(
+                    sympy_product(
+                        size for size, stride in zip(ranges, strides) if stride != 0
+                    )
+                )
+                if dep.name in write_names:
+                    # ngimel said contiguous writes is more important than reads
+                    score *= 2
+                if CandidateTiling.is_good_size(tiled_groups[0]):
+                    score *= 2
+                if CandidateTiling.is_good_size(tiled_groups[1]):
+                    score *= 2
+
+                if (
+                    V.graph.sizevars.size_hint(
+                        score - sympy_product(itertools.chain(ranges, reduction_ranges))
+                    )
+                    >= 0
+                ):
+                    tilings.append(
+                        CandidateTiling(
+                            tiling=cls.create_partial_tiling(
+                                [
+                                    collapse_ranges(ranges[:split]),
+                                    collapse_ranges(ranges[split:]),
+                                ],
+                                reduction_numel,
+                            ),
+                            score=score,
+                            name=dep.name,
+                        )
+                    )
+
+            return tilings
+
+        pointwise_ranges, reduction_ranges = node.get_ranges()
+        if (
+            len(pointwise_ranges) <= 1
+            and len(reduction_ranges) <= 1
+            or free_unbacked_symbols(pointwise_ranges + reduction_ranges)
+        ):
+            return []
+
+        # Tile either pointwise or reduction dims.
+        pointwise_ranges, reduction_ranges = node.get_ranges()
+        partial_tilings = tile_ranges(
+            is_pointwise,
+            pointwise_ranges if is_pointwise else reduction_ranges,
+            node.pointwise_or_reduction_read_writes(is_pointwise),
+        )
+
+        # Fill in the missing ranges.
+        full_tilings = [
+            CandidateTiling(
+                tiling=cls.complete_partial_tiling(
+                    tiling.tiling, numel, reduction_numel
+                ),
+                score=tiling.score,
+                name=tiling.name,
+            )
+            for tiling in partial_tilings
+        ]
+
+        return full_tilings
+
+    @classmethod
+    def create_tiling(
+        cls, pw_tiling: Sequence[sympy.Expr], reduction_tiling: Sequence[sympy.Expr]
+    ) -> immutable_dict[str, sympy.Expr]:
+        """
+        Create a tiling dict from pointwise and reduction splits.
+        """
+        pw_prefixes = ["z", "y", "x"][-len(pw_tiling) :]
+        reduction_prefixes = ["r0_", "r1_"][: len(reduction_tiling)]
+        return immutable_dict(
+            [*zip(pw_prefixes, pw_tiling), *zip(reduction_prefixes, reduction_tiling)]
+        )
+
+    @classmethod
+    def create_partial_tiling(
+        cls,
+        tiling: Sequence[sympy.Expr],
+        is_pointwise: bool,
+    ) -> immutable_dict[str, sympy.Expr]:
+        return cls.create_tiling(
+            tiling if is_pointwise else [],
+            tiling if not is_pointwise else [],
+        )
+
+    @classmethod
+    def complete_partial_tiling(
+        cls,
+        tiling: dict[str, sympy.Expr],
+        numel: sympy.Expr,
+        reduction_numel: sympy.Expr,
+    ) -> immutable_dict[str, sympy.Expr]:
+        """
+        Given a tiling for only pointwise or reduction dimensions, adds the missing one.
+        """
+        splits = list(tiling.values())
+        is_pointwise = "x" in tiling
+
+        total_numel = numel * reduction_numel
+        missing_tiling = [total_numel / sympy_product(splits)]
+
+        tiling_args = (
+            (splits, missing_tiling) if is_pointwise else (missing_tiling, splits)
+        )
+        return cls.create_tiling(*tiling_args)
+
+    @classmethod
+    def get_nd_tilings(
+        cls,
+        node_schedule,
+        pointwise_numel,
+        reduction_numel,
+    ) -> list[immutable_dict[str, sympy.Expr]]:
+        """
+        Creates N-dimensional tiling candidates, attempting to simplify loads/stores
+        by tiling the kernel into higher dimensions.
+
+        Returns a list of tilings ranked by dimensionality.
+        """
+        is_pointwise = reduction_numel == 1
+        tilings = OrderedSet[immutable_dict[str, sympy.Expr]]()
+        for node in EnableReduction.filter(node_schedule):
+            if not isinstance(node, scheduler.SchedulerNode):
+                continue
+
+            # If this is a reduction schedule, skip nodes which are missing their
+            # reduction ranges.
+            node_ranges = node.get_ranges()
+            if not is_pointwise and len(node_ranges[1]) == 0:
+                continue
+
+            # Use the node ranges as the default tiling candidate.
+            ranges_to_tile = node_ranges[0 if is_pointwise else 1]
+            node_tilings = [ranges_to_tile]
+
+            # Search the indexing expressions for more candidates.
+            # If we see modular indexing, try to subdivide ranges into their implied
+            # block shape.
+            memory_deps = [
+                dep
+                for dep in node.read_writes.reads_and_writes()
+                if isinstance(dep, MemoryDep) and len(dep.ranges) > 0
+            ]
+            for dep in memory_deps:
+                # Attempt to partition variable ranges into pointwise and reduction groups.
+                # To achieve this, merge the leading ranges until we reach the pointwise numel.
+                all_var_ranges = [*dep.ranges.items()]
+                pointwise_vars_numel = sympy.S.One
+                sizevars = V.graph.sizevars
+                pointwise_end_idx = 0
+                for idx, (_var, numel) in enumerate(all_var_ranges):
+                    pointwise_vars_numel *= numel
+                    pointwise_end_idx = idx
+                    if sizevars.statically_known_geq(
+                        pointwise_vars_numel, pointwise_numel
+                    ):
+                        break
+
+                # Reject the split if it does not match the total pointwise numel.
+                if not sizevars.statically_known_equals(
+                    pointwise_vars_numel, pointwise_numel
+                ):
+                    continue
+
+                # Partition var ranges into pointwise and reduction splits.
+                reduction_start_idx = pointwise_end_idx + 1
+                var_ranges = (
+                    all_var_ranges[:reduction_start_idx]
+                    if is_pointwise
+                    else all_var_ranges[reduction_start_idx:]
+                )
+
+                # Pattern match the subexpression pertaining to each index variable.
+                index_tiling = []
+                for var, numel in var_ranges:
+                    index = BlockPatternMatcher.get_subexpr_involving_symbol(
+                        dep.index, var
+                    )
+
+                    # Heuristic to bound the maximum dimensionality of the block.
+                    num_dims = max(
+                        2,
+                        index.count(FloorDiv) + index.count(ModularIndexing),
+                        len(ranges_to_tile),
+                    )
+
+                    # Attempt to pattern match the index expr.
+                    # Failed matches default to the full range.
+                    match_result = BlockPatternMatcher.match_mod_div_block_expr(
+                        index, var, numel, num_dims
+                    )
+                    dims = match_result[0] if match_result is not None else [numel]
+                    index_tiling.extend(dims)
+
+                # Prune dimensions of size 1.
+                index_tiling = [
+                    dim
+                    for dim in index_tiling
+                    if not V.graph.sizevars.statically_known_equals(dim, sympy.S.One)
+                ]
+
+                if len(index_tiling) > 0:
+                    node_tilings.append(index_tiling)
+
+            # Flatten leading dimensions, assigning labels to each dim.
+            for node_tiling in node_tilings:
+                num_leading_dims = max(0, len(node_tiling) - get_max_tiles(2))
+                first_trailing_dim = num_leading_dims + 1
+                collapsed_leading_dim = sympy_product(node_tiling[:first_trailing_dim])
+                collapsed_splits = (collapsed_leading_dim,) + tuple(
+                    node_tiling[first_trailing_dim:]
+                )
+                tilings.add(
+                    cls.complete_partial_tiling(
+                        cls.create_partial_tiling(collapsed_splits, is_pointwise),
+                        pointwise_numel,
+                        reduction_numel,
+                    )
+                )
+
+        # Rank tilings by the number of dimensions. E.g., prefer 2D to 1D.
+        # Since this is a stable sort, ties are broken by schedule order.
+        ranked_tilings = sorted(
+            tilings,
+            key=len,
+            reverse=True,
+        )
+
+        return ranked_tilings
+
+    @classmethod
+    def compute_tiling_strategy(
+        cls,
+        node_schedule: list[NodeScheduleEntry],
+        pointwise_numel: sympy.Expr,
+        reduction_numel: sympy.Expr,
+        coalesce_analysis: CoalesceVarAnalysis,
+    ) -> tuple[dict[str, sympy.Expr], Optional[dict[str, sympy.Expr]]]:
+        """
+        Generates a tiling, and a score of each tile according to each tile's coalesced memory accesses.
+        """
+        tiling_var: Optional[sympy.Expr] = (
+            None
+            if not coalesce_analysis.suggested_split
+            else coalesce_analysis.suggested_split.var
+        )
+
+        all_iter_vars = coalesce_analysis.norm_read_writes.index_vars
+        all_red_vars = coalesce_analysis.norm_read_writes.reduce_vars
+        ranges = coalesce_analysis.norm_read_writes.var_ranges
+
+        pw_ranges = [ranges[v] for v in all_iter_vars]
+        red_ranges = [ranges[v] for v in all_red_vars]
+
+        torch._check(
+            sympy_product(pw_ranges) == pointwise_numel,
+            lambda: f"{pw_ranges}, {pointwise_numel}, {node_schedule}",
+        )
+
+        torch._check(
+            sympy_product(red_ranges) == reduction_numel,
+            lambda: f"{red_ranges}, {reduction_numel}, {node_schedule}",
+        )
+
+        # score of a pointwise or reduction split
+        scored_sub_split: dict[Any, tuple[list[int], list[int]]] = {}
+
+        score_split: list[
+            tuple[tuple[list[int], list[int]], tuple[list[int], list[int]]]
+        ] = []
+
+        def process_node_vars(
+            vars_to_use: tuple[sympy.Expr, ...] = (),
+            use_split_var: bool = False,
+            is_pointwise: bool = False,
+        ) -> tuple[list[int], list[int]]:
+            """
+            Generate a tiling, and a tiling score, given vars to use as splits.
+            """
+
+            ranges = pw_ranges if is_pointwise else red_ranges
+            target_numel = pointwise_numel if is_pointwise else reduction_numel
+            # Some kernels have no reduction ranges, and a reduction numel of 1
+            if not ranges:
+                if target_numel:
+                    return ([target_numel], [])
+                else:
+                    return ([], [])
+
+            key = (repr(vars_to_use), use_split_var, is_pointwise)
+            if out := scored_sub_split.get(key):
+                return out
+
+            splitting_vars = all_iter_vars if is_pointwise else all_red_vars
+
+            splits = []
+            split_scores = []
+            prod = 1
+            prev_var_coalesced_score = 0
+
+            # iterate from non-dense to dense
+            for v, v_range in zip(splitting_vars, ranges):
+                if v not in vars_to_use:
+                    prod *= v_range
+                    prev_var_coalesced_score = coalesce_analysis.coalesced_by_var.get(
+                        v, 0
+                    )
+                    continue
+
+                if use_split_var and v == tiling_var:
+                    var_tiling = coalesce_analysis.suggested_split
+                    assert var_tiling is not None
+
+                    tile = var_tiling.tiling_factor
+                    remainder = FloorDiv(v_range, var_tiling.tiling_factor)
+
+                    splits.append(prod * remainder)
+                    split_scores.append(var_tiling.score)
+
+                    splits.append(tile)
+                    split_scores.append(coalesce_analysis.coalesced_by_var.get(v, 0))
+
+                    prod = 1
+                    prev_var_coalesced_score = 0
+
+                    continue
+
+                prod *= v_range
+                splits.append(prod)
+                split_scores.append(coalesce_analysis.coalesced_by_var.get(v, 0))
+                prod = 1
+
+            if prod != 1 or (is_pointwise and len(splits) == 0):
+                splits.append(prod)
+                split_scores.append(prev_var_coalesced_score)
+
+            # penalize splits that leave small blocks
+            # where we can't fully utilize full memory transaction
+            # TODO: incorporate exact bitwidth, and read/write
+            # coalesced write is 2x more important
+            for i in range(len(splits)):
+                s = V.graph.sizevars.size_hint(splits[i], fallback=32)
+                s = min(s, 8)
+                split_scores[i] = int(split_scores[i] * s / 8)
+
+            scored_sub_split[key] = (splits, split_scores)
+            return (splits, split_scores)
+
+        # add the default tiling
+        score_split.append(
+            (
+                process_node_vars(is_pointwise=True),
+                process_node_vars(is_pointwise=False),
+            )
+        )
+
+        if tiling_var:
+            score_split.append(
+                (
+                    process_node_vars(
+                        (tiling_var,), use_split_var=True, is_pointwise=True
+                    ),
+                    process_node_vars(is_pointwise=False),
+                )
+            )
+
+        # TODO, add tests, reduction splits if config.triton.tile_reductions
+        # TODO: we should ignore tiny increases in score for extra splits
+        overlapping_iter_vars = (
+            all_iter_vars & coalesce_analysis.coalesced_by_var.keys()
+        )
+        for v in overlapping_iter_vars:
+            score_split.append(
+                (
+                    process_node_vars((v,), is_pointwise=True),
+                    process_node_vars(is_pointwise=False),
+                )
+            )
+
+        if get_max_tiles(default=3) == 3 and reduction_numel == 1:
+            for vars_to_use in itertools.combinations(overlapping_iter_vars, 2):
+                score_split.append(
+                    (
+                        process_node_vars(vars_to_use, is_pointwise=True),
+                        process_node_vars(is_pointwise=False),
+                    )
+                )
+
+        tilings: list[tuple[CandidateTiling, immutable_dict[str, sympy.Expr]]] = []
+        for (pw_split, pw_score), (red_split, red_score) in score_split:
+            candidate = CandidateTiling(
+                cls.create_tiling(pw_split, red_split),
+                score=sum(pw_score) + sum(red_score),
+            )
+            tiling_score = cls.create_tiling(pw_score, red_score)
+            tilings.append((candidate, tiling_score))
+
+        default_tiling = cls.create_tiling([pointwise_numel], [reduction_numel])
+
+        # add a slight penalty for longer tilings that dont increase score much,
+        # and are poor sizes
+        bad_size_additional_tiling_penalty = 1.025
+        good_size_tiling_penalty = 1.005
+
+        total_uncoalesced = sum(coalesce_analysis.uncoalesced_addrs.values())
+
+        def score_mod(t):
+            score_factor = 1.0
+            for tile_size in t[0].tiling.values():
+                if not CandidateTiling.is_good_size(tile_size):
+                    score_factor = score_factor / bad_size_additional_tiling_penalty
+                else:
+                    score_factor = score_factor / good_size_tiling_penalty
+
+            # Add uncoalesced memory score to prevent small coalesced benefits
+            # from dominating large amounts of uncoalesced memory
+            uncoalesced_penalty = total_uncoalesced * 0.05
+
+            return -(t[0].score + uncoalesced_penalty) * score_factor
+
+        # apply penalty for longer tilings that dont increase score much
+        for cand, tiling_score in sorted(tilings, key=score_mod):
+            if (
+                cls.tiling_is_compatible(
+                    node_schedule, pointwise_numel, reduction_numel, cand.tiling
+                )
+                or cand.tiling == default_tiling
+            ):
+                # we always include default reduction numel == 1, dont include
+                tiling_len = len(cand.tiling) - (1 if reduction_numel == 1 else 0)
+                if tiling_len > get_max_tiles(default=3):
+                    perf_hint_log.info(
+                        "Found optimal tiling with %s tiles but torch._inductor.config.triton.max_tiles "
+                        "set to %s. Consider increasing",
+                        tiling_len,
+                        torch._inductor.config.triton.max_tiles,
+                    )
+                    continue
+
+                return cand.tiling, tiling_score
+
+            # surprisingly, the default tiling is not always read as compatible by `tiling_is_compatible`
+            # TODO - look into, occurs with dynamic shapes often
+            if cand.tiling == default_tiling:
+                return cand.tiling, tiling_score
+
+        return default_tiling, None
+
+    @classmethod
+    def tiling_is_compatible(
+        cls,
+        node_schedule: list[NodeScheduleEntry],
+        numel: sympy.Expr,
+        reduction_numel: sympy.Expr,
+        tiling: dict[str, sympy.Expr],
+    ):
+        assert isinstance(tiling, dict)
+        return all(
+            SIMDKernel.is_compatible(
+                tiling.values(), node.get_ranges(), reduction_numel=reduction_numel
+            )
+            for node in node_schedule
+            if isinstance(node, scheduler.SchedulerNode)
+        )
+
+    @classmethod
+    def get_first_compatible_tiling(
+        cls,
+        node_schedule: list[NodeScheduleEntry],
+        numel: sympy.Expr,
+        reduction_numel: sympy.Expr,
+        ranked_tilings: list[dict[str, sympy.Expr]],
+    ):
+        for tiling in ranked_tilings:
+            if cls.tiling_is_compatible(node_schedule, numel, reduction_numel, tiling):
+                return tiling
+
+        return None
+
+    @classmethod
+    def select_tiling(
+        cls,
+        node_schedule,
+        numel,
+        reduction_numel=sympy.S.One,
+        coalesce_analysis: Optional[CoalesceVarAnalysis] = None,
+    ) -> dict[str, sympy.Expr]:
+        return cls.get_tiling_and_scores(
+            node_schedule, numel, reduction_numel, coalesce_analysis
+        )[0]
+
+    @classmethod
+    def get_tiling_and_scores(
+        cls,
+        node_schedule,
+        numel,
+        reduction_numel=sympy.S.One,
+        coalesce_analysis: Optional[CoalesceVarAnalysis] = None,
+    ) -> tuple[dict[str, sympy.Expr], Optional[dict[str, sympy.Expr]]]:
+        """
+        Heuristics to decide how to tile kernels.
+        Currently, we tile based on stride-1 dimensions.
+
+        Returns:
+            `(tile1, tile2, reduction_numel)` s.t. `tile1 * tile2 == numel`
+
+        """
+        # If this is a reduction, only tile reduction dims.
+        is_pointwise = reduction_numel == 1
+
+        # Tiled reductions are gated by a config flag.
+        default_tiling = cls.create_tiling([numel], [reduction_numel])
+
+        # Force tiling compatible with matmul dimensions
+        # when natively generating matmul without template calls.
+        for node in EnableReduction.filter(node_schedule):
+            if isinstance(node.node, ir.ComputedBuffer):
+                if (
+                    node.node.get_reduction_type() == "dot"
+                    and config.triton.native_matmul
+                ):
+                    # A[M,K] @ B[K,N]
+                    # force tiling to be {'y':M, 'x':N, 'r0_':K}
+                    node_ranges = node.get_ranges()
+                    range_y_x = node_ranges[0]  # (M,N)
+                    range_r = node_ranges[1]  # (K)
+                    tiling = cls.create_tiling(range_y_x, range_r)
+                    return tiling, None
+
+        # # TODO: enable by default
+        if (
+            torch._inductor.config.triton.coalesce_tiling_analysis
+            and coalesce_analysis
+            and not config.triton.prefer_nd_tiling
+        ):
+            return cls.compute_tiling_strategy(
+                node_schedule, numel, reduction_numel, coalesce_analysis
+            )
+
+        if (not is_pointwise and not config.triton.tile_reductions) or get_max_tiles(
+            default=2
+        ) <= 1:
+            # Emit a perf hint in case we miss an opportunity to tile a reduction.
+            if perf_hint_log.level <= logging.WARNING:
+                for node in EnableReduction.filter(node_schedule):
+                    if (
+                        not config.triton.tile_reductions
+                        and len(cls.candidate_tilings(node, numel, reduction_numel)) > 0
+                    ):
+                        perf_hint_log.info(
+                            textwrap.dedent(
+                                """
+                                Reduction over non-contiguous dims.
+                                Consider setting config.triton.tile_reductions to True.
+                                """
+                            )
+                        )
+                        break
+
+            return default_tiling, None
+
+        seen_names: OrderedSet[str] = OrderedSet()
+        candidate_tiles: Counter[CandidateTiling] = collections.Counter()
+        for node in EnableReduction.filter(node_schedule):
+            for candidate_tiling in cls.candidate_tilings(node, numel, reduction_numel):
+                if candidate_tiling.name in seen_names:
+                    continue
+                elif candidate_tiling.name is not None:
+                    seen_names.add(candidate_tiling.name)
+                candidate_tiles[candidate_tiling] += candidate_tiling.score
+
+        ranked_tilings: list[dict[str, sympy.Expr]] = [
+            candidate_tiling.tiling
+            for candidate_tiling, score in candidate_tiles.most_common()
+        ]
+
+        if get_max_tiles(default=2) >= 3 and is_pointwise:
+            # Consider adding a third dimension of tiling, but only
+            # when a1 is a multiple of b1; otherwise, you have a lot
+            # of stragglers which is annoying to generate code for.
+            #
+            # NB: More than three max tiles is not enabled by default.
+
+            def convert_tiling_to_3d(
+                tiling0: dict[str, sympy.Expr], tiling1: dict[str, sympy.Expr]
+            ) -> Optional[dict[str, sympy.Expr]]:
+                a0, a1 = tiling0["x"], tiling0.get("y", 1)
+                b0, b1 = tiling1["x"], tiling1.get("y", 1)
+
+                if (
+                    free_unbacked_symbols([a1, b1])
+                    or V.graph.sizevars.size_hint(a1 - b1) == 0
+                ):
+                    return None
+                if V.graph.sizevars.size_hint(a1 - b1) < 0:
+                    # swap so a0 is bigger
+                    (a0, a1), (b0, b1) = (b0, b1), (a0, a1)
+
+                assert V.graph.sizevars.size_hint(a1 - b1) > 0
+                if not V.graph.sizevars.statically_known_multiple_of(a1, b1):
+                    return None
+
+                new_tiling = {
+                    "z": a0,
+                    "y": FloorDiv(a1, b1),
+                    "x": b1,
+                    "r0_": tiling0["r0_"],
+                }
+
+                return new_tiling
+
+            for i in range(1, len(ranked_tilings)):
+                new_3d_tiling = convert_tiling_to_3d(
+                    ranked_tilings[0], ranked_tilings[i]
+                )
+                if new_3d_tiling is not None:
+                    ranked_tilings = [new_3d_tiling] + ranked_tilings
+                    break  # only 1 choice for now
+
+        if len(ranked_tilings) > 1:
+            perf_hint_log.info("possibly bad tiling: %s", ranked_tilings)
+
+        # Optionally, prefer tiling into as many dimensions as possible.
+        # pyrefly: ignore [unbound-name]
+        if config.triton.prefer_nd_tiling:
+            ranked_tilings = (
+                cls.get_nd_tilings(node_schedule, numel, reduction_numel)
+                + ranked_tilings
+            )
+
+        if tiling := cls.get_first_compatible_tiling(
+            node_schedule, numel, reduction_numel, ranked_tilings
+        ):
+            return tiling, None
+
+        return default_tiling, None
+
+    def flush(self):
+        pass
+
+    def ready_to_flush(self) -> bool:
+        return False
+
+    def generate_kernel_code_from_nodes(
+        self, nodes, benchmark_kernel=False, hint_override: Optional[int] = None
+    ):
+        if not any(n.is_template() for n in nodes):
+            _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group
+            node_schedule = self.generate_node_schedule(nodes, numel, rnumel)
+            tiling = self.select_tiling(node_schedule, numel, rnumel)
+            kernel = self.kernel_type(
+                tiling,
+                features=SIMDKernelFeatures(node_schedule, numel, rnumel),
+            )
+            self.codegen_node_schedule_with_kernel(node_schedule, kernel)
+            with (
+                config.patch("benchmark_kernel", benchmark_kernel),
+                V.set_kernel_handler(kernel),
+            ):
+                src_code = kernel.codegen_kernel()
+        else:
+            prologue, template, epilogue = nodes[0].get_prologue_template_epilogue(
+                nodes
+            )
+            with config.patch("benchmark_kernel", benchmark_kernel):
+                src_code = self.codegen_template(
+                    template,
+                    epilogue,
+                    prologue,
+                    only_gen_src_code=True,
+                    hint_override=hint_override,
+                )
+
+        # pyrefly: ignore [missing-attribute]
+        src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_")
+        return src_code
+
+    def define_kernel(self, src_code, node_schedule, kernel):
+        raise NotImplementedError
+
+
+@dataclasses.dataclass(frozen=True)
+class CandidateTiling:
+    tiling: dict[str, sympy.Expr]
+    score: int  # higher is better
+    name: Optional[str] = None
+
+    @staticmethod
+    def is_good_size(s):
+        """Somewhat arbitrary heuristic used to boost scores for some sizes"""
+        s = V.graph.sizevars.size_hint(s)
+        return s >= 32 and (s % 32 == 0)
+
+
+class CantSplit(Exception):
+    pass
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/simd_kernel_features.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/simd_kernel_features.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cb38dda5a3660e090adc7013da94577507e8a89
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/simd_kernel_features.py
@@ -0,0 +1,620 @@
+from __future__ import annotations
+
+import collections
+import dataclasses
+import functools
+import itertools
+import typing
+from typing import Any, Optional, Union
+
+import sympy
+
+import torch
+
+from ...utils._ordered_set import OrderedSet
+from ...utils._sympy.functions import FloorDiv, ModularIndexing
+from ...utils._sympy.symbol import make_symbol, SymT
+from ..dependencies import Dep, extract_loop_body_with_args, MemoryDep
+from ..runtime.hints import ReductionHint
+from ..scheduler import SchedulerNode
+from ..utils import cache_on_self
+from ..virtualized import V
+
+
+if typing.TYPE_CHECKING:
+    from collections.abc import Iterable, Sequence
+
+    from torch._inductor.tiling_utils import CoalesceVarAnalysis
+
+
+class NodeScheduleMarker:
+    @staticmethod
+    def only_nodes(it: Iterable[NodeScheduleEntry]) -> Iterable[SchedulerNode]:
+        for item in it:
+            if not (item is DisableReduction or item is EnableReduction):
+                yield item  # type: ignore[misc]
+
+    @staticmethod
+    def is_reduction() -> bool:
+        return False
+
+
+NodeScheduleEntry = Union[SchedulerNode, type[NodeScheduleMarker]]
+
+
+class DisableReduction(NodeScheduleMarker):
+    """
+    Marker to invoke `kernel.disable_reduction()`.  This closes a
+    reduction loop and allows for pointwise ops to occur on the output
+    of a reduction.
+    """
+
+
+class EnableReduction(NodeScheduleMarker):
+    """
+    Marker to end a DisableReduction block.
+    """
+
+    @staticmethod
+    def filter(node_schedule: list[NodeScheduleEntry]) -> Iterable[SchedulerNode]:
+        """
+        Get the nodes from node_schedule skipping those in a
+        DisableReduction block.
+        """
+        disabled = False
+        for node in node_schedule:
+            if node in (EnableReduction, DisableReduction):
+                # Don't tile stuff outside the main reduction loop
+                disabled = node is DisableReduction
+            elif disabled:
+                pass
+            else:
+                yield node  # type: ignore[misc]
+
+
+class SIMDKernelFeatures:
+    """
+    An ordered schedule of nodes that will become a single kernel.
+    """
+
+    def __init__(
+        self,
+        node_schedule: list[NodeScheduleEntry],
+        numel: sympy.Expr,
+        reduction_numel: sympy.Expr = sympy.S.One,
+        coalesce_analysis: Optional[CoalesceVarAnalysis] = None,
+    ):
+        self.node_schedule = node_schedule
+        # numel excludes reduction_numel
+        self.numel: sympy.Expr = V.graph.sizevars.simplify(numel)
+        self.reduction_numel: sympy.Expr = V.graph.sizevars.simplify(reduction_numel)
+        self._stats_cache: dict[tuple[sympy.Expr, ...], MemoryStats] = {}
+        self.coalesce_analysis = coalesce_analysis
+
+    @cache_on_self
+    def is_reduction(self) -> bool:
+        return self.reduction_numel != 1
+
+    @cache_on_self
+    def scheduler_nodes(self) -> Iterable[SchedulerNode]:
+        return tuple(NodeScheduleMarker.only_nodes(self.node_schedule))
+
+    def reduction_nodes(self) -> list[SchedulerNode]:
+        return [n for n in self.scheduler_nodes() if n.is_reduction()]
+
+    @cache_on_self
+    def buf_accesses(self) -> dict[str, list[Dep]]:
+        """only needed for config.benchmark_kernel"""
+        buf_accesses = collections.defaultdict(list)
+        for node in self.scheduler_nodes():
+            for access in node.read_writes.reads | node.read_writes.writes:
+                buf_accesses[access.name].append(access)
+        return buf_accesses
+
+    @cache_on_self
+    def op_counts(self) -> collections.Counter[str]:
+        counts: collections.Counter[str] = collections.Counter()
+        for node in self.scheduler_nodes():
+            counts.update(node._body.op_counts)
+        return counts
+
+    def contains_op(self, op_name: str) -> bool:
+        """True if V.ops.{op_name} is used in node_schedule"""
+        return bool(self.op_counts().get(op_name))
+
+    def get_mutations(self) -> OrderedSet[str]:
+        mutations: OrderedSet[str] = OrderedSet()
+        for node in self.scheduler_nodes():
+            for buf in node.get_outputs():
+                mutations.update(buf.get_mutations())
+        return mutations
+
+    @cache_on_self
+    def select_index_dtype(self) -> torch.dtype:
+        # Gather all used buffer names
+        buffer_names: OrderedSet[str] = OrderedSet()
+        for node in self.scheduler_nodes():
+            buffer_names.update(node.get_buffer_names())
+            buffer_names.update(node.used_buffer_names())
+        buffers = [V.graph.get_buffer(name) for name in buffer_names]
+
+        # In theory we can separately check xnumel and rnumel are <= int_max
+        # but some indexers do use the full linear index so we need to be
+        # conservative here.
+        total_numel = self.numel * self.reduction_numel
+
+        from .simd import SIMDScheduling
+
+        if SIMDScheduling.can_use_32bit_indexing(total_numel, buffers):
+            return torch.int32
+        return torch.int64
+
+    @cache_on_self
+    def get_reduction_hint(self) -> ReductionHint:
+        reductions = self.reduction_nodes()
+        if len(reductions) > 0:
+            hints = [self.reduction_hint(n) for n in reductions]
+            if hints.count(hints[0]) == len(hints):
+                reduction_hint_val = hints[0]
+            else:
+                reduction_hint_val = ReductionHint.DEFAULT
+
+            if (
+                reduction_hint_val == ReductionHint.INNER
+                and self.has_non_contiguous_pw_in_reduction_kernel()
+            ):
+                reduction_hint_val = ReductionHint.DEFAULT
+        else:
+            reduction_hint_val = ReductionHint.DEFAULT
+        return reduction_hint_val
+
+    @cache_on_self
+    def buffer_read_counts(self) -> dict[str, int]:
+        """Counts how many times each buffer is read within the kernel"""
+        read_counts: dict[str, int] = collections.defaultdict(int)
+
+        for node in self.scheduler_nodes():
+            # node.read_writes.reads contains MemoryDep objects for each read
+            for read_dep in node.read_writes.reads:
+                read_counts[read_dep.name] += 1
+
+        return dict(read_counts)  # Convert defaultdict to regular dict
+
+    def has_non_contiguous_pw_in_reduction_kernel(self) -> bool:
+        pointwise_nodes = [
+            n
+            for n in self.scheduler_nodes()
+            if not n.is_reduction()
+            and n.group[1][0] == self.numel * self.reduction_numel
+        ]
+        for node in pointwise_nodes:
+            # An index can be an integer when loading a random seed.
+            if not all(
+                not isinstance(dep, MemoryDep)
+                or dep.is_contiguous()
+                or isinstance(dep.index, (sympy.Integer, int))
+                or dep.stride1_for_last_dim()
+                for dep in itertools.chain(
+                    node.read_writes.reads, node.read_writes.writes
+                )
+            ):
+                return True
+        return False
+
+    @staticmethod
+    def reduction_hint(node: Any) -> ReductionHint:
+        assert node.is_reduction()
+        if node.node.data.reduction_hint != ReductionHint.INNER and all(
+            dep.is_contiguous()
+            for dep in itertools.chain(node.read_writes.reads, node.read_writes.writes)
+        ):
+            return ReductionHint.INNER
+        else:
+            return node.node.data.reduction_hint
+
+    def memory_stats(
+        self, groups_dict: Optional[dict[str, sympy.Expr]] = None
+    ) -> MemoryStats:
+        """Analysis to generate features that can be used in heuristics"""
+        if groups_dict is None:
+            groups = (self.numel, self.reduction_numel)
+        elif groups_dict.keys() == OrderedSet(["x", "r0_"]):
+            groups = (groups_dict["x"], groups_dict["r0_"])
+        else:
+            raise NotImplementedError(f"groups_dict={groups_dict!r}")
+        result = self._stats_cache.get(groups)
+        if result is None:
+            self._stats_cache[groups] = result = MemoryStats.compute(
+                MemoryEstimator(self, groups)
+            )
+        return result
+
+
+class MemoryEstimator:
+    """
+    Estimate various properties of the kernel for use in heuristics.
+    We simulate the memory effects of CSE/buffer elimination in codegen.
+    """
+
+    kernel_sizes: tuple[sympy.Expr, ...]
+    outside_loop: MemoryEstimate
+    loops: list[MemoryEstimate]
+    persistent: MemoryEstimate
+    symbols: list[sympy.Symbol]
+
+    def __init__(self, features: SIMDKernelFeatures, groups: Sequence[sympy.Expr]):
+        self.features = features
+        self.inside_reduction = features.is_reduction()
+        self.store_buffer_names: OrderedSet[str] = OrderedSet()
+        self.must_keep_buffers: OrderedSet[str] = OrderedSet()
+        self.num_reductions_dims = 1
+        self.groups = groups
+        self.symbols = [make_symbol(SymT.INDEX, i) for i in range(len(groups))]
+        # We are doing two estimates simultaneously:
+        # 1) the first is a for a non-persistent (aka looped) reduction, using self.outside_loop/self.loops
+        # we add an item to loops each corresponding to each reduction loop in the kernel
+        # outside_loop is only used for broadcasting or point-wise ops that don't use the reduction dimension
+        # 2) the second is for a persistent kernel, using self.persistent
+        # persistent kernels don't have loops, so we only have one MemoryEstimate()
+        # for point-wise ops the two estimates will be the same, they matter for reductions only
+        self.outside_loop = MemoryEstimate()
+        self.loops = [MemoryEstimate()]
+        self.persistent = MemoryEstimate()
+        self.simulate_codegen()
+        self.remove_kernel_local()
+
+    def simulate_codegen(self) -> None:
+        from .simd import SIMDKernel
+
+        kernel_size_outside_loop = (*self.groups[:-1], sympy.S.One)
+        kernel_size_inside_loop = tuple(self.groups)
+        self.kernel_sizes = kernel_size_inside_loop
+
+        for node in self.features.node_schedule:
+            if node is DisableReduction:
+                self.inside_reduction = False
+                self.kernel_sizes = kernel_size_outside_loop
+                continue
+            elif node is EnableReduction:
+                self.inside_reduction = True
+                self.kernel_sizes = kernel_size_inside_loop
+                self.loops.append(MemoryEstimate())
+                continue
+            assert isinstance(node, SchedulerNode)
+            rw = extract_loop_body_with_args(
+                node._body,
+                SIMDKernel.map_kernel_groups_to_node_sizes(
+                    self.kernel_sizes, node.get_ranges(), self.set_ranges
+                ),
+                dict(zip(self.symbols, self.kernel_sizes)),
+            )
+
+            for dep in rw._reads:
+                if not isinstance(dep, MemoryDep):
+                    continue
+                dep = dep.simplify_with_ranges()
+                if not self.persistent.writes.get(dep.name):  # cache miss?
+                    self.persistent.reads[dep.name].add(dep)
+                # the cache behavior of looped kernels is more complex than the persistent case above
+                # some operations are lifted outside the loop (if they don't use the reduction dimension)
+                # other operations are inside the loop, and can only be reused within the same loop
+                if not (
+                    self.outside_loop.writes.get(dep.name)
+                    or self.loops[-1].writes.get(dep.name)
+                ):
+                    self.scope(dep).reads[dep.name].add(dep)
+                    if dep.name in self.store_buffer_names and self.loops[-1].reads.get(
+                        dep.name
+                    ):
+                        self.must_keep_buffers.add(dep.name)
+
+            for dep in rw._writes:
+                if not isinstance(dep, MemoryDep):
+                    continue
+                dep = dep.simplify_with_ranges()
+                self.store_buffer_names.add(dep.name)
+                self.persistent.writes[dep.name].add(dep)
+                self.scope(dep).writes[dep.name].add(dep)
+
+    def remove_kernel_local(self) -> None:
+        # Remove any kernel-local buffers
+        fused_node_names = OrderedSet(
+            [n.get_name() for n in self.features.scheduler_nodes()]
+        )
+        for name in self.store_buffer_names:
+            if not self.persistent.reads.get(
+                name
+            ) and V.graph.scheduler.can_buffer_be_removed_through_fusion(
+                name, fused_node_names
+            ):
+                self.persistent.remove(name)
+                if name not in self.must_keep_buffers:
+                    # we can also remove this from the looped kernel
+                    self.outside_loop.remove(name)
+                    for loop in self.loops:
+                        loop.remove(name)
+
+        if not self.loops[-1]:
+            self.loops.pop()  # for pointwise ops
+
+    def scope(self, dep: MemoryDep) -> MemoryEstimate:
+        """Determine how a read/write should be categorized"""
+        if self.inside_reduction and (
+            self.has_reduction_var(dep.index) or dep.is_indirect()
+        ):
+            return self.loops[-1]
+        return self.outside_loop
+
+    def has_reduction_var(self, index: sympy.Expr) -> bool:
+        for sym in self.symbols[-self.num_reductions_dims :]:
+            if isinstance(sym, sympy.Symbol) and sym in index.free_symbols:
+                return True
+        return False
+
+    def set_ranges(self, *lengths: list[list[sympy.Expr]]) -> list[list[sympy.Expr]]:
+        assert len(self.kernel_sizes) == len(lengths)
+        return [
+            self.make_flat_range(sym, numel, length)
+            for sym, numel, length in zip(self.symbols, self.kernel_sizes, lengths)
+        ]
+
+    @staticmethod
+    def make_flat_range(
+        sym: sympy.Symbol, numel: sympy.Expr, lengths: list[sympy.Expr]
+    ) -> list[sympy.Expr]:
+        if len(lengths) == 1 and numel == lengths[0]:
+            return [sym]
+        divisor = sympy.S.One
+        itervars = []
+        for length in reversed(lengths):
+            if V.graph.sizevars.statically_known_equals(divisor * length, numel):
+                expr = FloorDiv(sym, divisor)
+            else:
+                expr = ModularIndexing(sym, divisor, length)
+            itervars.append(expr)
+            divisor = divisor * length
+        return [*reversed(itervars)]
+
+
+@dataclasses.dataclass
+class MemoryEstimate:
+    """Tracks the memory usage of a single loop in the generated kernel"""
+
+    reads: dict[str, OrderedSet[MemoryDep]] = dataclasses.field(
+        default_factory=functools.partial(collections.defaultdict, OrderedSet)
+    )
+    writes: dict[str, OrderedSet[MemoryDep]] = dataclasses.field(
+        default_factory=functools.partial(collections.defaultdict, OrderedSet)
+    )
+
+    def remove(self, name: str) -> None:
+        self.reads.pop(name, None)
+        self.writes.pop(name, None)
+
+    def __bool__(self) -> bool:
+        return bool(self.reads or self.writes)
+
+    def __repr__(self) -> str:
+        return f"""MemoryEstimate(
+            reads={[*itertools.chain.from_iterable(self.reads.values())]!r},
+            writes={[*itertools.chain.from_iterable(self.writes.values())]!r}
+        )"""
+
+
+@dataclasses.dataclass
+class StatsForDim:
+    """Memory usage stats for a block dimension in the generated kernel (different from user dimensions)"""
+
+    # the number of load/store ops
+    count_per_thread_contiguous: int = 0
+    count_per_thread_broadcast: int = 0
+    count_per_thread_non_contiguous: int = 0  # excludes broadcast
+
+    # total bytes in each load/store op for a single element
+    bytes_per_thread_contiguous: int = 0
+    bytes_per_thread_broadcast: int = 0
+    bytes_per_thread_non_contiguous: int = 0  # excludes broadcast
+
+    # total bytes read by entire kernel
+    bytes_contiguous_or_broadcast: sympy.Expr = sympy.S.Zero
+    bytes_non_contiguous: sympy.Expr = sympy.S.Zero
+
+    def __add__(self, other: typing.Self) -> StatsForDim:
+        return StatsForDim(
+            count_per_thread_contiguous=self.count_per_thread_contiguous
+            + other.count_per_thread_contiguous,
+            count_per_thread_broadcast=self.count_per_thread_broadcast
+            + other.count_per_thread_broadcast,
+            count_per_thread_non_contiguous=self.count_per_thread_non_contiguous
+            + other.count_per_thread_non_contiguous,
+            bytes_per_thread_contiguous=self.bytes_per_thread_contiguous
+            + other.bytes_per_thread_contiguous,
+            bytes_per_thread_broadcast=self.bytes_per_thread_broadcast
+            + other.bytes_per_thread_broadcast,
+            bytes_per_thread_non_contiguous=self.bytes_per_thread_non_contiguous
+            + other.bytes_per_thread_non_contiguous,
+            bytes_contiguous_or_broadcast=self.bytes_contiguous_or_broadcast
+            + other.bytes_contiguous_or_broadcast,
+            bytes_non_contiguous=self.bytes_non_contiguous + other.bytes_non_contiguous,
+        )
+
+    @property
+    def count_per_thread(self) -> int:
+        return (
+            self.count_per_thread_contiguous
+            + self.count_per_thread_broadcast
+            + self.count_per_thread_non_contiguous
+        )
+
+    @property
+    def bytes_per_thread(self) -> int:
+        return (
+            self.bytes_per_thread_contiguous
+            + self.bytes_per_thread_broadcast
+            + self.bytes_per_thread_non_contiguous
+        )
+
+    @property
+    def bytes(self) -> sympy.Expr:
+        return self.bytes_contiguous_or_broadcast + self.bytes_non_contiguous
+
+    @property
+    def contiguous_score(self) -> float:
+        return 1.0 - self.count_per_thread_non_contiguous / max(
+            self.count_per_thread, 1
+        )
+
+
+@dataclasses.dataclass
+class StatsForLoop:
+    """Memory usage stats for single loop in the generated kernel"""
+
+    # load/store ops
+    count_per_thread: int = 0
+    bytes_per_thread: int = 0
+
+    def __add__(self, other: typing.Self) -> StatsForLoop:
+        return StatsForLoop(
+            count_per_thread=self.count_per_thread + other.count_per_thread,
+            bytes_per_thread=self.bytes_per_thread + other.bytes_per_thread,
+        )
+
+
+@dataclasses.dataclass
+class StatsForReadsOrWrites:
+    """Memory usage stats that are collected for reads/writes/both"""
+
+    dim: list[StatsForDim]
+    loop: list[StatsForLoop]
+    # total bytes contiguous in any dimension
+    bytes_contiguous_or_broadcast: sympy.Expr = sympy.S.Zero
+    bytes_non_contiguous: sympy.Expr = sympy.S.Zero
+
+    def __add__(self, other: typing.Self) -> StatsForReadsOrWrites:
+        assert len(self.dim) == len(other.dim)
+        assert len(self.loop) == len(other.loop)
+        return StatsForReadsOrWrites(
+            dim=[a + b for a, b in zip(self.dim, other.dim)],
+            loop=[a + b for a, b in zip(self.loop, other.loop)],
+            bytes_contiguous_or_broadcast=self.bytes_contiguous_or_broadcast
+            + self.bytes_contiguous_or_broadcast,
+            bytes_non_contiguous=self.bytes_non_contiguous + other.bytes_non_contiguous,
+        )
+
+    @property
+    def count_per_thread(self) -> int:
+        return self.dim[0].count_per_thread
+
+    @property
+    def bytes_per_thread(self) -> int:
+        return self.dim[0].bytes_per_thread
+
+    @property
+    def bytes(self) -> sympy.Expr:
+        return self.bytes_contiguous_or_broadcast + self.bytes_non_contiguous
+
+    @classmethod
+    def compute(
+        cls,
+        loop_deps: list[dict[str, OrderedSet[MemoryDep]]],
+        index_symbols: list[sympy.Symbol],
+    ) -> typing.Self:
+        ndim = len(index_symbols)
+        result = cls(dim := [StatsForDim() for _ in range(ndim)], [])
+        for dep_group in loop_deps:
+            result.loop.append(loop_stats := StatsForLoop())
+            for name, deps in dep_group.items():
+                assert deps
+                contiguous_or_broadcast = [True] * ndim
+                numel = sympy.S.Zero
+                itemsize = V.graph.get_dtype(name).itemsize
+                loop_stats.count_per_thread += len(deps)
+                loop_stats.bytes_per_thread += itemsize * len(deps)
+                for dep in deps:
+                    strides: list[sympy.Expr] = V.graph.sizevars.stride_vars(
+                        dep.index, index_symbols
+                    )
+                    for i in range(ndim):
+                        if V.graph.sizevars.statically_known_equals(strides[i], 1):
+                            dim[i].count_per_thread_contiguous += 1
+                            dim[i].bytes_per_thread_contiguous += itemsize
+                        elif (
+                            V.graph.sizevars.statically_known_equals(strides[i], 0)
+                            and not dep.is_indirect()
+                        ):
+                            dim[i].count_per_thread_broadcast += 1
+                            dim[i].bytes_per_thread_broadcast += itemsize
+                        else:
+                            dim[i].count_per_thread_non_contiguous += 1
+                            dim[i].bytes_per_thread_non_contiguous += itemsize
+                            contiguous_or_broadcast[i] = False
+                    numel += dep.get_numel()
+                if len(deps) > 1:
+                    # can't read more elements than exist in the buffer
+                    numel = sympy.Min(numel, V.graph.get_numel(name))
+                nbytes = numel * itemsize
+                for i in range(ndim):
+                    if contiguous_or_broadcast[i]:
+                        dim[i].bytes_contiguous_or_broadcast += nbytes
+                    else:
+                        dim[i].bytes_non_contiguous += nbytes
+                if any(contiguous_or_broadcast):
+                    result.bytes_contiguous_or_broadcast += nbytes
+                else:
+                    result.bytes_non_contiguous += nbytes
+        if len(result.loop) > 1:
+            # the first loop represent the "outside of the loop" compute which could be long lived
+            result.loop = [result.loop[0] + x for x in result.loop[1:]]
+        return result
+
+
+@dataclasses.dataclass
+class StatsForKernelType:
+    """Memory usage stats that are collected for both persistent and looped kernels"""
+
+    reads: StatsForReadsOrWrites
+    writes: StatsForReadsOrWrites
+    memory: StatsForReadsOrWrites
+
+    @classmethod
+    def compute(
+        cls, loops: list[MemoryEstimate], estimator: MemoryEstimator
+    ) -> typing.Self:
+        reads = StatsForReadsOrWrites.compute(
+            [loop.reads for loop in loops], estimator.symbols
+        )
+        writes = StatsForReadsOrWrites.compute(
+            [loop.writes for loop in loops], estimator.symbols
+        )
+        return cls(
+            reads=reads,
+            writes=writes,
+            memory=reads + writes,
+        )
+
+
+@dataclasses.dataclass
+class MemoryStats:
+    """Memory usage stats collected for each generated kernel"""
+
+    persistent: StatsForKernelType
+    looped: StatsForKernelType
+
+    def get(self, persistent: bool) -> StatsForKernelType:
+        return self.persistent if persistent else self.looped
+
+    @classmethod
+    def compute(cls, estimator: MemoryEstimator) -> typing.Self:
+        persistent = StatsForKernelType.compute([estimator.persistent], estimator)
+        if len(estimator.loops) == 1 and not (
+            estimator.outside_loop and estimator.loops[0]
+        ):
+            looped = persistent  # loops/persistent is the same in this common case
+        else:
+            looped = StatsForKernelType.compute(
+                [estimator.outside_loop, *estimator.loops], estimator
+            )
+        return cls(
+            persistent=persistent,
+            looped=looped,
+        )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/subgraph.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/subgraph.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b931fb3bf47e74596e9053f58177a6faa180edd
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/subgraph.py
@@ -0,0 +1,433 @@
+import itertools
+import logging
+from collections.abc import Callable
+from typing import Any, Union
+
+import torch
+import torch._inductor.config as config
+from torch._inductor import ir
+from torch._inductor.codegen.common import KernelTemplate
+from torch._inductor.ir import (
+    Buffer,
+    FixedLayout,
+    get_free_symbols,
+    get_symbolic_inputs,
+    gm_original_output_strides,
+    ir_node_to_tensor,
+    Layout,
+)
+from torch._inductor.runtime.benchmarking import benchmarker
+from torch._inductor.utils import do_bench_using_profiling
+from torch._inductor.virtualized import V
+
+
+log = logging.getLogger(__name__)
+
+
+def inline_subgraph_to_ir_nodes(
+    gm: torch.fx.GraphModule, inputs: list[Any], name: str
+) -> Any:
+    """Inline a subgraph by converting its FX operations to individual IR nodes.
+
+    This converts a subgraph to multiple ComputedBuffer nodes (fusable),
+    enabling epilogue fusion with subsequent operations.
+
+    Returns:
+        TensorBox containing the final operation result as individual IR nodes
+    """
+    from torch._inductor.lowering import process_subgraph_nodes
+
+    return process_subgraph_nodes(gm, inputs)
+
+
+class SubgraphChoiceCaller(ir.ChoiceCaller):
+    """
+    Represents a Subgraph Autotuning choice, and the subgraph can be any arbitrary
+    GraphModule. Compiles the Subgraph down to a module for benchmarking.
+    """
+
+    def __init__(
+        self,
+        name: str,
+        input_nodes: list[Buffer],
+        layout: Layout,
+        description: str,
+        make_fx_graph: Callable[..., Any],
+    ) -> None:
+        super().__init__(name, input_nodes, layout, description)
+
+        self.example_inputs = []
+        with V.fake_mode:
+            for inp in self.input_nodes:
+                # Here there will be no unbacked symbols, as SubgraphBuffer does not support them
+                assert len(get_free_symbols(inp.get_size(), unbacked_only=True)) == 0
+                assert len(get_free_symbols(inp.get_stride(), unbacked_only=True)) == 0
+
+                inp.data.freeze_layout()  # type: ignore[attr-defined]
+                self.example_inputs.append(ir_node_to_tensor(inp))
+
+        self.gm = make_fx_graph(*self.example_inputs)
+        gm_original_output_strides(self.gm)
+
+        self.sym_inputs = get_symbolic_inputs(self.input_nodes)
+
+        # Cache compiled module to avoid recompiling on every benchmark call
+        self._compiled_module: Any = None
+        self._compiled_sym_inputs: list[Any] | None = None
+
+    def __str__(self) -> str:
+        return f"SubgraphCaller({self.name})"
+
+    def _compile_for_benchmarking(self, *args: list[Any]) -> tuple[Any, list[Any]]:
+        """
+        Compile the subgraph for benchmarking and return (module, sym_inputs).
+
+        TODO: Add precompile() method to enable parallel compilation of all choices
+        before benchmarking.
+        """
+        import torch._inductor.config as inductor_config
+        from torch._inductor.graph import GraphLowering
+
+        safe_name = self.name.replace("::", "_").replace(".", "_")
+
+        bm_graph_lowering = GraphLowering(
+            gm=self.gm,
+            example_inputs=self.example_inputs,
+            shape_env=V.graph._shape_env,
+            cpp_wrapper=V.graph.cpp_wrapper,
+            aot_mode=V.graph.aot_mode,
+            extern_node_serializer=V.graph.extern_node_serializer,
+            is_inference=V.graph.is_inference,
+            is_backward=V.graph.is_backward,
+            name=f"benchmark_{safe_name}",
+        )
+
+        for sym_inp in self.sym_inputs:
+            bm_graph_lowering.graph_inputs[sym_inp.name] = sym_inp
+            bm_graph_lowering.graph_input_names.append(sym_inp.name)
+
+        sym_inputs = [
+            # pyrefly: ignore [no-matching-overload]
+            int(V.graph.sizevars.shape_env.size_hint(sym_var))
+            for sym_var in self.sym_inputs
+        ]
+
+        if len(sym_inputs) == 0:
+            # Sanity check that args are same layout as example inputs
+            # Only do it if there are no symbolic inputs, otherwise
+            # the dynamic dim will be realized to the same size as args
+            for ar, example_inp in zip(args, self.example_inputs):
+                # Sanity check that args are same layout as example inputs
+                if isinstance(ar, torch.Tensor):
+                    assert isinstance(example_inp, torch.Tensor)
+                    assert ar.shape == example_inp.shape
+                    assert ar.stride() == example_inp.stride()
+
+        with V.set_graph_handler(bm_graph_lowering):
+            # Don't bother autotuning on Triton here
+            with inductor_config.patch(
+                max_autotune=False,
+                max_autotune_gemm=False,
+                max_autotune_gemm_backends="ATEN",
+            ):
+                bm_graph_lowering.run(*self.example_inputs)
+                mod = bm_graph_lowering.compile_to_module()
+
+        return mod, sym_inputs
+
+    def benchmark(self, *args: list[Any], out: torch.Tensor) -> float:
+        """
+        Regular benchmarking: compile and use benchmarker with warmup/rep.
+        """
+        if self._compiled_module is None:
+            mod, sym_inputs = self._compile_for_benchmarking(*args)
+            self._compiled_module = mod
+            self._compiled_sym_inputs = sym_inputs
+        else:
+            mod = self._compiled_module
+            sym_inputs = self._compiled_sym_inputs
+            assert sym_inputs is not None  # Type narrowing
+
+        bm_func = mod.call
+        if config.profile_bandwidth_with_do_bench_using_profiling:
+            return do_bench_using_profiling(lambda: bm_func([*sym_inputs, *args]))
+        return benchmarker.benchmark(
+            # Shallow clone args since bm_func may clear args
+            lambda: bm_func([*sym_inputs, *args]),
+            device=benchmarker.infer_device(*sym_inputs, *args),
+        )
+
+    def benchmark_collective(self, *args: list[Any], out: torch.Tensor) -> None:
+        """
+        Only run once with cached compiled module.
+        Called by benchmark_collective_choice which handles warmup
+        and timing with barrier synchronization across all ranks.
+        """
+        if self._compiled_module is None:
+            mod, sym_inputs = self._compile_for_benchmarking(*args)
+            self._compiled_module = mod
+            self._compiled_sym_inputs = sym_inputs
+        else:
+            mod = self._compiled_module
+            sym_inputs = self._compiled_sym_inputs
+            assert sym_inputs is not None  # Type narrowing
+
+        bm_func = mod.call
+        bm_func([*sym_inputs, *args])
+
+    def hash_key(self) -> str:
+        return "-".join(
+            [
+                self.name.rsplit("_", 1)[0],
+                *[str(inp.get_size()) for inp in self.input_nodes],
+                *[str(inp.get_stride()) for inp in self.input_nodes],
+                str(self.gm.graph),
+            ]
+        )
+
+    def output_node(self) -> Union[ir.TensorBox, ir.ShapeAsConstantBuffer]:
+        return ir.TensorBox.create(
+            ir.SubgraphBuffer(
+                layout=self.layout,
+                input_nodes=self.input_nodes,
+                gm=self.gm,
+                example_inputs=self.example_inputs,
+                subgraph_name=self.name,
+            )
+        )
+
+    def info_dict(self) -> dict[str, Any]:
+        """Information returned here is logged to the autotune log file when that is enabled."""
+        return {
+            "backend": "subgraph",
+            "kernel_name": self.name,
+        }
+
+    def autoheuristic_id(self) -> str:
+        return f"subgraph_{self.name}"
+
+
+class SubgraphTemplate(KernelTemplate):
+    """
+    A template for subgraph evaluation to be used in autotuning.
+
+    This class allows creating customized subgraphs that can be appended
+    as choices during the autotuning process, enabling the selection of
+    optimal implementations for complex operations.
+    """
+
+    index_counter = itertools.count()
+
+    def __init__(
+        self,
+        name: str,
+    ):
+        """
+        Initialize a subgraph template.
+
+        Args:
+            name: The name of this template
+            graph: The FX graph
+        """
+        super().__init__(name=name)
+
+    def generate(  # type: ignore[override]
+        self,
+        name: str,
+        input_nodes: list[Buffer],
+        layout: Layout,
+        make_fx_graph: Callable[..., Any],
+        description: str = "",
+        **kwargs: Any,
+    ) -> SubgraphChoiceCaller:
+        """
+        Generate a SubgraphChoiceCaller instance for autotuning.
+
+        Args:
+            name: The name for this subgraph choice
+            input_nodes: List of input nodes to the subgraph
+            layout: Memory layout information for the output
+            make_fx_graph: Callable that creates the FX graph for this subgraph
+            description: Optional description of this choice
+            **kwargs: Additional keyword arguments
+
+        Returns:
+            SubgraphChoiceCaller: A callable object that can be used for autotuning
+        """
+
+        return SubgraphChoiceCaller(
+            name=f"{name}_{next(SubgraphTemplate.index_counter)}",
+            input_nodes=input_nodes,
+            layout=layout,
+            description=description,
+            make_fx_graph=make_fx_graph,
+        )
+
+    def generate_custom_op_choices(
+        self,
+        name: str,
+        decompositions: list[Callable[..., Any]],
+        input_nodes: list[Buffer],
+        non_tensor_args: list[dict[str, Any]],
+        default_impl: Callable[..., Any] | None = None,
+    ) -> list[SubgraphChoiceCaller]:
+        """
+        Generate multiple SubgraphChoiceCaller instances for custom op autotuning.
+
+        This method extends SubgraphTemplate to support custom op decompositions,
+        allowing multiple implementations to compete in autotuning.
+
+        Args:
+            name: Base name for the choices
+            decompositions: List of decomposition functions to compete in autotuning
+            input_nodes: List of tensor inputs. All tensor arguments must be passed here.
+            non_tensor_args: List of non-tensor kwargs only, one dict per corresponding decomposition.
+            default_impl: Default implementation for layout inference
+
+        Returns:
+            List of SubgraphChoiceCaller instances for autotuning
+        """
+        if not decompositions:
+            return []
+
+        assert len(decompositions) == len(non_tensor_args), (
+            f"decompositions and non_tensor_args must have same length, "
+            f"got {len(decompositions)} decompositions and {len(non_tensor_args)} kwargs"
+        )
+
+        # Infer layouts and ensure layout consistency for fair autotuning comparison
+        layouts = [
+            self._infer_custom_op_layout(input_nodes, decomp, kwargs, default_impl)
+            for decomp, kwargs in zip(decompositions, non_tensor_args)
+        ]
+
+        # Validate all decompositions produce equivalent layouts for fair comparison
+        self._validate_layout_equivalence(name, decompositions, layouts)
+        layout = layouts[0]  # All layouts are now validated to be equivalent
+
+        choices: list[SubgraphChoiceCaller] = []
+        for decomp, decomp_kwargs in zip(decompositions, non_tensor_args):
+            # Create make_fx_graph function for this decomposition
+            import functools
+
+            def make_fx_graph(
+                *args: Any,
+                decomp: Callable[..., Any] = decomp,
+                decomp_kwargs: dict[str, Any] = decomp_kwargs,
+            ) -> Any:
+                # decomp_kwargs contains all merged parameters: CustomOpConfig params + runtime kwargs
+                from torch.fx.experimental.proxy_tensor import make_fx
+
+                from ..decomposition import select_decomp_table
+
+                decomposition_table = select_decomp_table()
+
+                return make_fx(
+                    functools.partial(decomp, **decomp_kwargs),
+                    decomposition_table=decomposition_table,
+                )(*args)
+
+            # Generate descriptive name for this variant
+            variant_name = self._generate_variant_name(decomp, decomp_kwargs)
+
+            choice = self.generate(
+                name=f"{name}_{variant_name}",
+                input_nodes=input_nodes,
+                layout=layout,
+                make_fx_graph=make_fx_graph,
+                description=f"CustomOp {decomp.__name__}",
+            )
+            choices.append(choice)
+
+        return choices
+
+    def _generate_variant_name(
+        self, decomp: Callable[..., Any], kwargs: dict[str, Any]
+    ) -> str:
+        """Generate a descriptive name for a decomposition variant with its parameters."""
+        base_name = decomp.__name__
+        if not kwargs:
+            return base_name
+        param_suffix = "_".join(f"{k}_{v}" for k, v in sorted(kwargs.items()))
+        return f"{base_name}_{param_suffix}"
+
+    def _validate_non_tensor_kwargs(self, kwargs: dict[str, Any]) -> None:
+        """Validate that kwargs contains only non-tensor arguments."""
+        for key, value in kwargs.items():
+            assert not isinstance(value, (torch.Tensor, Buffer)), (
+                f"kwargs['{key}'] contains tensor {type(value)}. "
+                f"Tensor arguments should be in input_nodes, not kwargs. "
+                f"Only scalar/non-tensor parameters should be in kwargs."
+            )
+
+    def _validate_layout_equivalence(
+        self,
+        op_name: str,
+        decompositions: list[Callable[..., Any]],
+        layouts: list[Layout],
+    ) -> None:
+        """Ensure all layouts have consistent stride, device, dtype, and sizes for fair autotuning."""
+        if not layouts:
+            return
+
+        reference = layouts[0]
+        for i, layout in enumerate(layouts[1:], start=1):
+            if (layout.device, layout.dtype, layout.size, layout.stride) != (
+                reference.device,
+                reference.dtype,
+                reference.size,
+                reference.stride,
+            ):
+                raise AssertionError(
+                    f"Layout mismatch in custom op '{op_name}': "
+                    f"decomposition '{decompositions[i].__name__}' produces "
+                    f"({layout.device}, {layout.dtype}, {layout.size}, {layout.stride}) "
+                    f"but '{decompositions[0].__name__}' produces "
+                    f"({reference.device}, {reference.dtype}, {reference.size}, {reference.stride})"
+                )
+
+    def _infer_custom_op_layout(
+        self,
+        input_nodes: list[Buffer],
+        function_decomposition: Callable[..., Any],
+        kwargs: dict[str, Any],
+        default_impl: Callable[..., Any] | None = None,
+    ) -> Layout:
+        """Infer output layout for custom ops using the default implementation when available.
+        Note that the Subgraph assumes custom ops return exactly one tensor output.
+        TODO: Add support for multiple output custom ops.
+        """
+        import functools
+
+        from torch._inductor.virtualized import V
+
+        # Assert kwargs contain only non-tensor arguments
+        self._validate_non_tensor_kwargs(kwargs)
+
+        with V.fake_mode:
+            example_inputs = []
+            for inp in input_nodes:
+                raw_shape = inp.get_size()
+                concrete_shape = V.graph.sizevars.size_hints(
+                    raw_shape, fallback=config.unbacked_symint_fallback
+                )
+                fake_tensor = torch.empty(
+                    concrete_shape, dtype=inp.get_dtype(), device=inp.get_device()
+                )
+                example_inputs.append(fake_tensor)
+
+            fn = functools.partial(function_decomposition, **kwargs)
+            output = fn(*example_inputs)
+
+            # Assert single output
+            assert isinstance(output, torch.Tensor), (
+                f"Expected single tensor output, got {type(output)}. "
+                f"Multi-output custom ops not yet supported in autotuning."
+            )
+
+            return FixedLayout(
+                device=output.device,
+                dtype=output.dtype,
+                size=output.shape,
+                stride=output.stride(),
+            )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/triton.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/triton.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d6b671f15db31bfd8e2ddea7e556e9867b93227
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/triton.py
@@ -0,0 +1,6263 @@
+# mypy: allow-untyped-defs
+from __future__ import annotations
+
+import collections
+import contextlib
+import dataclasses
+import functools
+import itertools
+import logging
+import math
+import operator
+import os
+import textwrap
+from abc import abstractmethod
+from collections.abc import Callable, Iterable, Sequence
+from functools import lru_cache
+from typing import Any, cast, Optional, TYPE_CHECKING, TypeVar, Union
+
+import sympy
+from sympy.printing.precedence import PRECEDENCE
+
+import torch
+import torch._logging
+import torch.utils._pytree as pytree
+from torch._dynamo.device_interface import get_interface_for_device
+from torch._dynamo.utils import identity, preserve_rng_state
+from torch._prims_common import is_integer_dtype
+from torch.utils._ordered_set import OrderedSet
+from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing
+from torch.utils._triton import (
+    get_triton_version,
+    has_triton_package,
+    has_triton_stable_tma_api,
+)
+
+from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_type, SymT
+from ...utils._sympy.value_ranges import ValueRanges
+from .. import config, ir, metrics, utils
+from ..async_compile import AsyncCompile
+from ..codecache import code_hash, get_path, PyCodeCache, write_atomic
+from ..debug import set_kernel_post_grad_provenance_tracing
+from ..ops_handler import DefaultHandler
+from ..runtime import triton_heuristics
+from ..runtime.benchmarking import benchmarker
+from ..runtime.hints import (
+    AutotuneHint,
+    DeviceProperties,
+    TRITON_MAX_BLOCK,
+    TRITON_MAX_RSPLIT,
+)
+from ..runtime.runtime_utils import get_max_y_grid, next_power_of_2
+from ..scheduler import BaseSchedulerNode, FusedSchedulerNode, Scheduler, SchedulerNode
+from ..shape_propagation import get_broadcasted_shape
+from ..utils import (
+    cache_on_self,
+    DelayMaybeLine,
+    DelayReplaceLine,
+    get_bounds_index_expr,
+    get_fused_kernel_name,
+    get_kernel_metadata,
+    is_welford_reduction,
+    Placeholder,
+    prefix_is_reduction,
+    sympy_dot,
+    sympy_product,
+    sympy_subs,
+    triton_type,
+    triton_version_uses_attrs_dict,
+    upcast_compute_type,
+)
+from ..virtualized import _ops as ops, ReductionType, StoreMode, V
+from ..wrapper_benchmark import get_kernel_category_by_source_code
+from .block_analysis import BlockPatternMatcher
+from .common import (
+    ArgName,
+    BackendFeature,
+    ConstexprArg,
+    CSE,
+    CSEVariable,
+    DeferredLine,
+    IndentedBuffer,
+    InplacedBuffer,
+    is_buffer_removed,
+    OpOverrides,
+    PythonPrinter,
+    RemovedArg,
+    SizeArg,
+    TensorArg,
+    WorkspaceArg,
+    WorkspaceZeroMode,
+)
+from .simd import (
+    constant_repr,
+    IterationRanges,
+    IterationRangesEntry,
+    IterationRangesRoot,
+    PartialAccumulate,
+    SIMDKernel,
+    SIMDScheduling,
+)
+from .triton_utils import (
+    config_of,
+    equal_1_arg_indices,
+    non_constexpr_signature,
+    should_unwrap_unspec_arg,
+    signature_to_meta,
+)
+from .wrapper import SymbolicCallArg
+
+
+if TYPE_CHECKING:
+    from types import ModuleType
+
+    from torch._inductor.dtype_propagation import DtypePropagationOpsHandler
+    from torch.fx.experimental.symbolic_shapes import ShapeEnv
+
+    from ..ir import IRNode
+    from .common import BlockShapeType
+    from .simd_kernel_features import SIMDKernelFeatures
+
+    _T = TypeVar("_T")
+
+log = logging.getLogger(__name__)
+perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
+schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")
+fusion_log = torch._logging.getArtifactLogger(__name__, "fusion")
+async_compile = AsyncCompile()
+
+
+def get_triton_reduction_function(reduction_type):
+    use_helper = reduction_type in ("any", "max", "min", "prod")
+    module = "triton_helpers" if use_helper else "tl"
+    if reduction_type in ("max", "min"):
+        return f"{module}.{reduction_type}2"
+    else:
+        return f"{module}.{reduction_type}"
+
+
+def is_sympy_integer_like(expr: object):
+    """ "
+    Is this expression a Sympy Integer or is it an integer sympy Expr
+    containing no free symbols. The latter case can happen with Identity expr.
+    """
+    if not isinstance(expr, sympy.Expr):
+        return False
+    return isinstance(expr, sympy.Integer) or (
+        expr.is_integer and len(expr.free_symbols) == 0
+    )
+
+
+class OpDtypeSupport:
+    """
+    Some Triton ops such as libdevice and tl.math only support float32 and float64.
+    This class records which dtypes are supported by specific IR ops.
+    """
+
+    supported_dtypes: dict[str, OrderedSet[torch.dtype]] = {}
+    convert_outputs: dict[str, bool] = {}
+
+    @classmethod
+    def register_upcast(cls, func: Callable[..., str], convert_output: bool) -> None:
+        op_name = func.__name__
+        cls.supported_dtypes[op_name] = OrderedSet([torch.float32, torch.float64])
+        cls.convert_outputs[op_name] = convert_output
+
+
+@lru_cache(None)
+def gen_attr_descriptor_import() -> str:
+    """
+    import AttrsDescriptor if the triton version is new enough to have this
+    class defined.
+    """
+    if not has_triton_package():
+        return ""
+
+    import triton.compiler.compiler
+
+    # Note: this works because triton.compiler.compiler imports AttrsDescriptor from triton.backends.compiler
+    # When support for the legacy AttrsDescriptor is removed then this import path should be changed.
+    if hasattr(triton.compiler.compiler, "AttrsDescriptor"):
+        return "from triton.compiler.compiler import AttrsDescriptor"
+    else:
+        return ""
+
+
+@lru_cache(None)
+def gen_common_triton_imports() -> str:
+    imports = IndentedBuffer()
+    imports.splice(
+        """
+        import triton
+        import triton.language as tl
+        """
+    )
+    if attr_desc := gen_attr_descriptor_import():
+        imports.writeline(attr_desc)
+
+    imports.splice(
+        """
+        from torch._inductor.runtime import triton_helpers, triton_heuristics
+        from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+        from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+        """
+    )
+    return imports.getvalue()
+
+
+class TritonSymbols:
+    """
+    Stores sympy.Symbol instances and constants associated with triton codegen.
+    """
+
+    reduction_types = OrderedSet([SymT.R0_INDEX, SymT.R1_INDEX])
+    block_types = OrderedSet([SymT.XBLOCK, SymT.YBLOCK, SymT.ZBLOCK, *reduction_types])
+
+    block_offsets = {
+        symt: sympy.Symbol(f"{prefix_str[symt]}offset", integer=True, nonnegative=True)
+        for symt in block_types
+    }
+
+    block_sizes = {
+        symt: sympy.Symbol(
+            f"{prefix_str[symt].upper()}BLOCK", integer=True, positive=True
+        )
+        for symt in block_types
+    }
+
+    @classmethod
+    def get_block_shape(cls, expr: sympy.Expr) -> BlockShapeType:
+        # return block shape of sympy Expression
+        # e.g.,
+        # tmp13 = y1
+        # tmp14 = x0 - tmp13
+        #
+        # get_block_shape(y1) = (YBLOCK,1,1)
+        # get_block_shape(x0-tmp13) = (YBLOCK,XBLOCK,1)
+
+        expr_shape: BlockShapeType = ()
+        expr_vars = expr.free_symbols
+        for var in expr_vars:
+            if symbol_is_type(var, SymT.TMP):
+                cse_var = V.kernel.cse.varname_map[var.name]
+                var_shape = cse_var.shape
+            elif symbol_is_type(
+                var,
+                (
+                    SymT.UNBACKED_INT,
+                    SymT.SIZE,
+                    SymT.PRECOMPUTED_SIZE,
+                    SymT.INDEX,
+                    SymT.FLOAT,
+                    SymT.UNBACKED_FLOAT,
+                ),
+            ):
+                var_shape = ()
+            else:
+                symbol_matches = [
+                    symt for symt in cls.block_types if symbol_is_type(var, symt)
+                ]
+                assert len(symbol_matches) == 1, f"Ambiguous type: {var.name}"
+
+                sym = symbol_matches[0]
+                ndim = V.kernel.triton_tensor_ndim()
+                shape = ["1"] * ndim
+
+                tree_match = [
+                    tree
+                    for tree in V.kernel.active_range_trees()
+                    if prefix_str[sym] == tree.prefix
+                ]
+                assert len(tree_match) == 1, "# of Match expected to 1"
+
+                shape[tree_match[0].tensor_dim] = str(cls.get_block_size(tree_match[0]))
+                var_shape = tuple(shape)
+
+            # Union current variable shape
+            expr_shape = get_broadcasted_shape(expr_shape, var_shape)
+
+        assert expr_shape is not None
+
+        return expr_shape
+
+    @classmethod
+    def get_block_size(cls, tree: IterationRanges) -> sympy.Symbol:
+        return cls.block_sizes[tree.symt]
+
+    @classmethod
+    def get_block_offset(cls, tree: IterationRanges) -> sympy.Symbol:
+        return cls.block_offsets[tree.symt]
+
+
+@dataclasses.dataclass
+class IndexingOptions:
+    index_str: str
+    mask_vars: OrderedSet[str]
+    expand_str: Optional[str]
+    _has_rindex: bool
+    index: sympy.Expr
+    expand_shape: Optional[Sequence[Union[int, str]]]
+
+    def has_mask(self) -> bool:
+        return bool(self.mask_vars)
+
+    def has_indirect(self) -> bool:
+        return free_symbol_is_type(self.index, SymT.TMP)
+
+    def has_rindex(self) -> bool:
+        return self._has_rindex
+
+    def has_tmpmask(self) -> bool:
+        return any(str(mask).startswith("tmp") for mask in self.mask_vars)
+
+    def has_rmask(self) -> bool:
+        return any(str(mask).startswith("r") for mask in self.mask_vars)
+
+    @property
+    def mask_str(self) -> str:
+        # The sorted call is added to make sure the order is still
+        # deterministic if self.mask_vars contains mix of string
+        # and TritonCSEVariable
+        return (
+            " & ".join(sorted(map(str, self.mask_vars))) if self.mask_vars else "None"
+        )
+
+
+@dataclasses.dataclass
+class BlockDescriptorOptions:
+    """
+    This is a base class that describes a block descriptor used in Triton kernels.
+    It can be used to create either a tensor descriptor (with TensorDescriptorOptions)
+    or a block pointer (with BlockPtrOptions).
+    """
+
+    params: BlockParameters
+    constant_offset: sympy.Expr
+    order: list[int]
+    mask_vars: OrderedSet[str]
+    broadcast_shape: Sequence[sympy.Expr]
+    broadcasting_dims: list[bool]
+    final_shape: Sequence[sympy.Expr]
+    # If the BlockParameters have been sorted using a particular stride order
+    # transpose load / store blocks at runtime using the information in
+    # stride_sorter.
+    stride_sorter: BlockParameters.StrideSorter
+    _boundary_check: Optional[list[int]] = None
+    # Can we safely lift the constructor
+    # to the top of the kernel?
+    can_lift: bool = False
+
+    @property
+    def shape(self) -> list[sympy.Expr]:
+        return self.params.shape
+
+    @property
+    def block_shape(self) -> list[sympy.Expr]:
+        return self.params.block_shape
+
+    @property
+    def strides(self) -> list[sympy.Expr]:
+        return self.params.strides
+
+    @property
+    def offsets(self) -> list[sympy.Expr]:
+        return self.params.offsets
+
+    @classmethod
+    def create(
+        cls,
+        *,
+        params: BlockParameters,
+        constant_offset: sympy.Expr,
+        range_trees: list[IterationRangesRoot],
+        mask_vars: OrderedSet[str],
+        get_max_block: Callable[[str], int],
+        stride_sorter_cls: type[BlockParameters.StrideSorter],
+        can_lift: bool = False,
+    ) -> BlockDescriptorOptions:
+        """Helper to create a BlockDescriptorOptions instance"""
+
+        sizevars = V.graph.sizevars
+
+        def lookup_size(exprs: Iterable[sympy.Expr]) -> list[sympy.Expr]:
+            return [sizevars.lookup_precomputed_size(expr) for expr in exprs]
+
+        # Look up precomputed sizes
+        params.shape = lookup_size(params.shape)
+        params.strides = lookup_size(params.strides)
+
+        # Strip out dimensions of size 1.
+        # Size 1 dimensions are redundant since the triton kernel shape
+        # will be e.g. [YBLOCK, XBLOCK], so tl.reshape would just remove these
+        # dimensions anyway
+        singleton_dims = [
+            sizevars.statically_known_equals(dim, 1) for dim in params.block_shape
+        ]
+        if all(singleton_dims):
+            # Handle a pure singletons, e.g. [1, 1]
+            singleton_dims[-1] = False
+
+        # Drop singleton dimensions from the block descriptor.
+        params = params.remove_dims(singleton_dims)
+
+        # Maybe reorder dimensions based on strides
+        # with tl.trans applied at load / store time
+        params, stride_sorter = params.maybe_sort_with_stride_order(
+            stride_sorter_cls=stride_sorter_cls, shape_env=V.graph._shape_env
+        )
+
+        # Strip out dimensions of stride 0.
+        # These will be restored with tl.broadcast_to.
+        broadcasting_dims = [
+            sizevars.statically_known_equals(stride, 0) for stride in params.strides
+        ]
+
+        # Record the post-broadcast shape before broadcasting dims are removed.
+        # The pre-broadcast shape is identical to this, except broadcasting dims are
+        # replaced with 1.
+        broadcast_shape = params.block_shape
+
+        # Drop broadcasting dims from the block descriptor.
+        params = params.remove_dims(broadcasting_dims)
+
+        # Compute the final shape, adjusting for special kernel types.
+        final_shape = [TritonSymbols.get_block_size(tree) for tree in range_trees]
+        if V.kernel.no_x_dim:
+            assert range_trees[0].prefix == "x"
+            final_shape.pop(0)
+
+        reduction_ndim = V.kernel.num_reduction_dims
+        if (
+            not V.kernel.inside_reduction
+            and len(params.strides) == len(V.kernel.numels) - reduction_ndim
+            and V.kernel.features.is_reduction()
+        ):
+            # Need to expand rank to match the rank used inside the reduction loop
+            final_shape += [sympy.S.One] * reduction_ndim
+
+        try:
+            # Get permutation to sort strides in ascending order.
+            # This is used as the order argument in tl.make_block_ptr
+            order = utils.argsort_sym(V.graph._shape_env, params.strides)
+        except AssertionError:
+            # Symbolic shapes, failed to evaluate comparison expression
+            order = list(reversed(range(len(params.strides))))
+
+        result = cls(
+            params=params,
+            constant_offset=V.graph.sizevars.lookup_precomputed_size(constant_offset),
+            order=order,
+            mask_vars=mask_vars,
+            final_shape=final_shape,
+            broadcast_shape=broadcast_shape,
+            broadcasting_dims=broadcasting_dims,
+            stride_sorter=stride_sorter,
+            can_lift=can_lift,
+        )
+        result.compute_boundary_check(get_max_block, range_trees)
+        return result
+
+    def replace_offset(
+        self, expr: sympy.Expr, replacement: sympy.Expr, symt: SymT
+    ) -> sympy.Expr:
+        """
+        Replaces instances of {symt}_offset with the new expression.
+        """
+        roffset = TritonSymbols.block_offsets[symt]
+        return sympy_subs(expr, {roffset: replacement})
+
+    def remove_roffsets(self, expr: sympy.Expr) -> sympy.Expr:
+        for symt in TritonSymbols.reduction_types:
+            expr = self.replace_offset(expr, sympy.Integer(0), symt)
+        return expr
+
+    def compute_boundary_check(
+        self,
+        get_max_block: Callable[[str], int],
+        range_trees: list[IterationRangesRoot],
+    ) -> None:
+        """List of indices to pass to tl.load(boundary_check=...)"""
+        sizevars = V.graph.sizevars
+
+        # Substitute maximum block sizes in shape expressions.
+        # This works in multiple_of checks because block sizes are powers of 2.
+        block_to_max: dict[sympy.Expr, Any] = {
+            TritonSymbols.block_sizes[t.symt]: get_max_block(prefix_str[t.symt])
+            for t in range_trees
+        }
+
+        # Also see Note: Constant mask optimisation
+        # if ynumel / YBLOCK > max_ygrid, then the z dimension is used to handle
+        # the remaining programs that cannot fit into the y dimension. This means
+        # it's possible that more than the required number of programs are launched,
+        # possibly leading to out-of-bounds accesses. So even if ynumel divides YBLOCK,
+        # boundary checking is required in the dimensions that are based on YBLOCK
+        # e.g. for [YBLOCK // 16, YBLOCK, XBLOCK] dimensions 0 and 1 need boundary
+        # checks when max_ygrid is exceeded.
+        needs_overflow_grid = any(map(V.kernel.needs_yz_grid_overflow, range_trees))
+        self._boundary_check = [
+            idx
+            for idx in range(len(self.shape))
+            if (
+                not sizevars.statically_known_equals(self.strides[idx], sympy.S.Zero)
+                and (
+                    (
+                        needs_overflow_grid
+                        and TritonSymbols.block_sizes[SymT.YBLOCK]
+                        in self.block_shape[idx].free_symbols
+                    )
+                    or (
+                        not sizevars.statically_known_multiple_of(
+                            self.shape[idx], self.block_shape[idx]
+                        )
+                        and not sizevars.statically_known_multiple_of(
+                            self.shape[idx],
+                            sympy_subs(self.block_shape[idx], block_to_max),
+                        )
+                    )
+                )
+                and not (
+                    V.kernel.no_x_dim
+                    and self.block_shape[idx] == TritonSymbols.block_sizes[SymT.XBLOCK]
+                )
+            )
+        ]
+
+    def boundary_check(self) -> list[int]:
+        assert self._boundary_check is not None
+        return self._boundary_check
+
+    def has_indirect(self) -> bool:
+        return False  # block_ptr can't do indirect indexing
+
+    def has_rindex(self) -> bool:
+        return any(
+            free_symbol_is_type(expr, TritonSymbols.reduction_types)
+            for expr in self.block_shape
+        )
+
+    def has_rmask(self) -> bool:
+        return self.has_rindex()
+
+    def has_tmpmask(self) -> bool:
+        return False  # block_ptr can't do indirect indexing
+
+    def has_mask(self) -> bool:
+        return bool(self.boundary_check())
+
+    def codegen_broadcast_and_reshape(
+        self,
+        value: str,
+        initial_shape: Sequence[sympy.Expr],
+        final_shape: Sequence[sympy.Expr],
+        allow_implicit: bool,
+        for_store: bool,
+    ) -> str:
+        """
+        Generate a broadcast and a reshape for the block descriptor.
+        This restores stride-0 dimensions which were removed from the block descriptor.
+
+        Transposes are also applied to the input using self.stride_sorter:
+        if for_store is True:
+            - First Broadcast the value. Since self.broadcast_shape is stored in
+            descending stride order, it must be reverted to the original order
+            since the input value does not have dims with descending strides
+            - After, transpose the broadcasted value so that dimensions are in
+            descending stride order
+            - Finally reshape to the block shape
+        else (for load):
+            - First broadcast the value to self.broadcast_shape (strides are descending)
+            - Then transpose the value so that dimensions no longer have descending strides
+            - Finally reshape the block to the final kernel tile shape
+        """
+        broadcast_shape = self.broadcast_shape
+        broadcasting_dims = self.broadcasting_dims
+
+        # If the block parameters have been sorted by descending strides,
+        # permute the broadcasting parameters so that they are compatible
+        # with the value being stored. This is because the dimensions
+        # of the value being stored are not sorted in descending stride order,
+        # but the broadcasting parameters are based on the dims in sorted order
+        if for_store:
+            broadcast_shape = self.stride_sorter.revert(self.broadcast_shape)
+            broadcasting_dims = self.stride_sorter.revert(self.broadcasting_dims)
+
+        # Reshape to add singletons.
+        pre_broadcast_shape = [
+            sympy.S.One if is_broadcasting else dim
+            for dim, is_broadcasting in zip(broadcast_shape, broadcasting_dims)
+        ]
+        value = triton_reshape(value, initial_shape, pre_broadcast_shape)
+
+        if (
+            not self.stride_sorter.is_identity
+            and not for_store
+            and len(pre_broadcast_shape) == len(final_shape)
+        ):
+            # If all we need to do is transpose to match the final shape
+            # with implicit broadcasting then we don't need an explicit broadcast
+            # unless the caller requests it. So just test implicit broadcast support
+            # with the transposed pre broadcast shape
+            pre_broadcast_shape = self.stride_sorter.revert(pre_broadcast_shape)
+
+        # Broadcast singletons.
+        # For loads, we can often implicitly broadcast singleton dimensions.
+        # We need an explicit broadcast for stores, or if the final reshape does more
+        # than add singletons.
+        sizevars = V.graph.sizevars
+        supports_implicit_broadcast = allow_implicit and (
+            len(pre_broadcast_shape) == len(final_shape)
+            and all(
+                sizevars.statically_known_equals(pre_dim, 1)
+                or sizevars.statically_known_equals(pre_dim, post_dim)
+                for pre_dim, post_dim in zip(pre_broadcast_shape, final_shape)
+            )
+        )
+
+        if any(self.broadcasting_dims) and not supports_implicit_broadcast:
+            value = (
+                f"tl.broadcast_to({value}, {V.kernel.index_to_str(broadcast_shape)})"
+            )
+
+        old_shape = self.broadcast_shape
+        if not self.stride_sorter.is_identity:
+            # if for_store the transform is
+            #   (non-descending strides) broadcasted kernel tile shape
+            #       -> (descending strides) block descriptor shape
+            # o/w if loading the transform is
+            #   (descending strides) ((maybe implicitly) broadcasted block shape
+            #       -> (non-descending) (maybe implicitly) broadcasted kernel tile shape
+            permute_dims = (
+                self.stride_sorter.sort_idx
+                if for_store
+                else self.stride_sorter.revert_sort_idx
+            )
+            value = f"tl.trans({value}, {permute_dims})"
+            old_shape = (
+                self.broadcast_shape
+                if for_store
+                else self.stride_sorter.revert(self.broadcast_shape)
+            )
+
+        # Reshape to the final shape.
+        value = triton_reshape(value, old_shape, final_shape)
+
+        return value
+
+
+@dataclasses.dataclass
+class TensorDescriptorOptions(BlockDescriptorOptions):
+    def format(self, name: str, roffset=True) -> str:
+        """
+        Codegen a call to tl.make_tensor_descriptor()
+
+        Args:
+            name: variable name for pointer
+            roffset: unused, but kept for compatibility with BlockPtrOptions.format()
+
+        Returns:
+            "tl.make_tensor_descriptor(...)"
+        """
+
+        f = V.kernel.index_to_str
+        args = [
+            (
+                f"{name} + ({f(self.constant_offset)})"
+                if self.constant_offset != 0
+                else name
+            ),
+            f"shape={f(self.shape)}",
+            f"strides={f(self.strides)}",
+            f"block_shape={f(self.block_shape)}",
+        ]
+
+        return f"tl.make_tensor_descriptor({', '.join(args)})"
+
+
+@dataclasses.dataclass
+class BlockPtrOptions(BlockDescriptorOptions):
+    def replace_offset(
+        self, expr: sympy.Expr, replacement: sympy.Expr, symt: SymT
+    ) -> sympy.Expr:
+        """
+        Replaces instances of {symt}_offset with the new expression.
+        """
+        roffset = TritonSymbols.block_offsets[symt]
+        return sympy_subs(expr, {roffset: replacement})
+
+    def remove_roffsets(self, expr: sympy.Expr) -> sympy.Expr:
+        for symt in TritonSymbols.reduction_types:
+            expr = self.replace_offset(expr, sympy.Integer(0), symt)
+        return expr
+
+    def format(self, name: str, roffset=True) -> str:
+        """
+        Codegen a call to tl.make_block_ptr()
+
+        Args:
+            name: variable name for pointer
+            roffset: should rn_offset be included in offsets=..., for use with tl.advance()
+
+        Returns:
+            "tl.make_block_ptr(...)"
+        """
+        f = V.kernel.index_to_str
+        offsets = [*self.offsets]
+        if not roffset:
+            offsets = [self.remove_roffsets(offset) for offset in offsets]
+        args = [
+            (
+                f"{name} + ({f(self.constant_offset)})"
+                if self.constant_offset != 0
+                else name
+            ),
+            f"shape={f(self.shape)}",
+            f"strides={f(self.strides)}",
+            f"block_shape={f(self.block_shape)}",
+            f"order={f(self.order)}",
+            f"offsets={f(offsets)}",
+        ]
+        return f"tl.make_block_ptr({', '.join(args)})"
+
+    def advance_roffset(self, symt: SymT) -> sympy.Expr:
+        """
+        Codegen string to pass to tl.advance(name, ...).
+
+        Advance is the difference between offsets in each loop iteration.
+        To compute it, we replace rN_offset with multiples of RN_BLOCK.
+        Since we expect rN_offset to vary in range(0, rN_numel, RN_BLOCK), the first
+        iteration has rN_offset=0, while the second has rN_offset=RN_BLOCK.
+        """
+        rblock = TritonSymbols.block_sizes[symt]
+        advance = [
+            (
+                self.replace_offset(offset, rblock, symt)
+                - self.replace_offset(offset, sympy.S.Zero, symt)
+            )
+            for offset in self.offsets
+        ]
+        return advance
+
+
+def triton_reshape(
+    value: str, old_shape: Sequence[sympy.Expr], new_shape: Sequence[sympy.Expr]
+) -> str:
+    """Workaround https://github.com/triton-lang/triton/issues/2836"""
+    assert isinstance(old_shape, list) and isinstance(new_shape, list)
+
+    old_shape_str = [V.kernel.index_to_str(shape) for shape in old_shape]
+    new_shape_str = [V.kernel.index_to_str(shape) for shape in new_shape]
+
+    if old_shape_str == new_shape_str:
+        return value
+    if [s for s in new_shape_str if s != "1"] != old_shape_str:
+        return f"tl.reshape({value}, [{', '.join(new_shape_str)}])"
+    # rewrite to [:, None] syntax, which is less buggy
+    idx = 0
+    expand = []
+    for size in new_shape_str:
+        if idx < len(old_shape_str) and size == old_shape_str[idx]:
+            expand.append(":")
+            idx += 1
+        else:
+            assert size == "1"
+            expand.append("None")
+    assert idx == len(old_shape_str)
+    return f"{value}[{', '.join(expand)}]"
+
+
+def enable_pdl_codegen():
+    if not torch._inductor.config.triton.enable_pdl:
+        return False
+    major, _ = torch.cuda.get_device_capability(torch.cuda.current_device())
+    return major >= 9
+
+
+# NB: Inheriting from PythonPrinter is somewhat dangerous, because there are a
+# number of operators which Triton "implements", but in a way that is
+# inconsistent with Python semantics (and consistent with C semantics).  We
+# must override all of these, or it is potential silent correctness problem
+class TritonPrinter(PythonPrinter):
+    def _print_TruncToInt(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return (
+            f"libdevice.trunc({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
+        )
+
+    def _print_Float(self, expr: sympy.Expr) -> str:
+        if expr.is_integer:
+            # sympy considers 0.0 to be integer, but triton doesn't.
+            # this workaround prints the float as an integer
+            # xref: https://github.com/sympy/sympy/issues/26620
+            ret = str(int(expr))
+        elif config.is_fbcode() and torch.version.hip:
+            ret = f"{expr}"
+        else:
+            ret = f"tl.full([], {expr}, tl.float64)"
+        return ret
+
+    def _print_ToFloat(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        s = self.parenthesize(expr.args[0], PRECEDENCE["Atom"] - 0.5)
+        return f"{s}.to(tl.float64)"
+
+    def _print_PythonMod(self, expr: sympy.Expr) -> str:
+        quot, div = expr.args
+        if quot.is_nonnegative and div.is_nonnegative:
+            return self.stringify(expr.args, " % ", PRECEDENCE["Atom"] - 0.5)
+        quot_s = self._print(quot)
+        div_s = self._print(div)
+        return f"triton_helpers.remainder_integer({quot_s}, {div_s})"
+
+    def _print_FloorDiv(self, expr: sympy.Expr) -> str:
+        assert expr.is_integer
+        quot, div = expr.args
+        if quot.is_nonnegative and div.is_nonnegative:
+            return self.stringify(expr.args, " // ", PRECEDENCE["Atom"] - 0.5)
+        quot_s = self._print(quot)
+        div_s = self._print(div)
+        return f"triton_helpers.div_floor_integer({quot_s},  {div_s})"
+
+    # TODO: This is wrong, when lhs, rhs > 2**53, Python does a higher
+    # precision algorithm, which we would need to replicate here
+    def _print_IntTrueDiv(self, expr: sympy.Expr) -> str:
+        return self.stringify(expr.args, " / ", PRECEDENCE["Atom"] - 0.5)
+
+    # NB: sympy.floor/ceiling produce integers, so we have to do the
+    # conversion to index dtype
+    def _print_floor(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return (
+            f"libdevice.floor({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
+        )
+
+    def _print_FloorToInt(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return (
+            f"libdevice.floor({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
+        )
+
+    def _print_ceiling(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return f"libdevice.ceil({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
+
+    def _print_CeilToInt(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return f"libdevice.ceil({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
+
+    def _helper_sqrt(self, expr: sympy.Expr) -> str:
+        # work around for https://github.com/pytorch/pytorch/issues/165738
+        if torch.xpu.is_available():
+            return f"libdevice.sqrt(({self._print(expr)}).to(tl.float32))"
+        return f"tl.sqrt_rn(({self._print(expr)}).to(tl.float32))"
+
+    def _print_FloatPow(self, expr: sympy.Expr) -> str:
+        return (
+            f"libdevice.pow({self._print(expr.args[0])}, {self._print(expr.args[1])})"
+        )
+
+    def _print_PowByNatural(self, expr: sympy.Expr) -> str:
+        if expr.args[0].is_Integer:
+            return f"libdevice.pow({float(expr.args[0])}, {self._print(expr.args[1])})"
+        return (
+            f"libdevice.pow({self._print(expr.args[0])}, {self._print(expr.args[1])})"
+        )
+
+    def _print_Where(self, expr: sympy.Expr) -> str:
+        c = self.doprint(expr.args[0])
+        p = self.doprint(expr.args[1])
+        q = self.doprint(expr.args[2])
+        return f"tl.where({c}, {p}, {q})"
+
+    def _print_min_max_helper(self, expr: sympy.Expr, cmp: str) -> str:
+        """
+        Helper for max/min code generation.
+        cmp: > or <
+        """
+        if len(expr.args) == 1:
+            return self._print(expr.args[0])
+
+        mid = len(expr.args) // 2
+        cls = type(expr)
+        a = self._print(cls(*expr.args[:mid]))
+        b = self._print(cls(*expr.args[mid:]))
+
+        # Use a macro so we can propagate constexprs.
+        # https://github.com/triton-lang/triton/issues/3815
+        a, b = tuple(f"({x})" for x in (a, b))
+        assert cmp in (">", "<"), f"Unexpected comparator: '{cmp}'"
+        return f"({a} * ({a} {cmp}= {b}) + {b} * ({b} {cmp} {a}))"
+
+    def _print_Min(self, expr: sympy.Expr) -> str:
+        return self._print_min_max_helper(expr, "<")
+
+    def _print_Max(self, expr: sympy.Expr) -> str:
+        return self._print_min_max_helper(expr, ">")
+
+    def _print_Abs(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return f"tl_math.abs({self._print(expr.args[0])})"
+
+    def _print_OpaqueUnaryFn_cos(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return f"libdevice.cos(({self._print(expr.args[0])}).to(tl.float32))"
+
+    def _print_OpaqueUnaryFn_cosh(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return f"libdevice.cosh(({self._print(expr.args[0])}).to(tl.float32))"
+
+    def _print_OpaqueUnaryFn_acos(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return f"libdevice.acos(({self._print(expr.args[0])}).to(tl.float32))"
+
+    def _print_OpaqueUnaryFn_sin(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return f"libdevice.sin(({self._print(expr.args[0])}).to(tl.float32))"
+
+    def _print_OpaqueUnaryFn_sinh(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return f"libdevice.sinh(({self._print(expr.args[0])}).to(tl.float32))"
+
+    def _print_OpaqueUnaryFn_asin(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return f"libdevice.asin(({self._print(expr.args[0])}).to(tl.float32))"
+
+    def _print_OpaqueUnaryFn_tan(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return f"libdevice.tan(({self._print(expr.args[0])}).to(tl.float32))"
+
+    def _print_OpaqueUnaryFn_tanh(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return f"libdevice.tanh(({self._print(expr.args[0])}).to(tl.float32))"
+
+    def _print_OpaqueUnaryFn_atan(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return f"libdevice.atan(({self._print(expr.args[0])}).to(tl.float32))"
+
+    def _print_OpaqueUnaryFn_log2(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return f"libdevice.log2(({self._print(expr.args[0])}).to(tl.float32))"
+
+    def _print_RoundToInt(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return (
+            f"libdevice.llrint({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
+        )
+
+    def _print_RoundDecimal(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 2
+        number, ndigits = expr.args
+        if number.is_integer:
+            # ndigits < 0 should have been filtered by the sympy function
+            assert ndigits < 0
+            raise ValueError(
+                f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}."
+            )
+
+        number_str = self.parenthesize(number, PRECEDENCE["Mul"])
+        return f"libdevice.nearbyint(1e{ndigits} * {number_str}) * 1e{-ndigits}"
+
+
+texpr = TritonPrinter().doprint
+
+
+def triton_compute_type(dtype: torch.dtype) -> str:
+    """Convert torch.dtype to triton type and upcast [b]float16 to float32"""
+    return triton_type(upcast_compute_type(dtype))
+
+
+def triton_store_type(dtype: torch.dtype) -> str:
+    """Convert torch.dtype to triton type, with fix for storing tl.bool"""
+    if dtype == torch.bool:
+        dtype = torch.int8
+    return triton_type(dtype)
+
+
+def upcast_acc_dtype(dtype: torch.dtype) -> torch.dtype:
+    """Implicit upcasts used for Triton reduction types"""
+    if is_integer_dtype(dtype) and dtype.is_signed and dtype.itemsize <= 4:
+        return torch.int32
+    return upcast_compute_type(dtype)
+
+
+def triton_acc_type(dtype: torch.dtype) -> str:
+    """Convert torch.dtype to triton type, with reduction upcasts"""
+    return triton_compute_type(upcast_acc_dtype(dtype))
+
+
+def low_precision_fp(dtype: torch.dtype) -> bool:
+    return dtype.itemsize <= 2 and dtype.is_floating_point
+
+
+def low_precision_fp_var(var: Union[CSEVariable, Any]) -> bool:
+    if not isinstance(var, CSEVariable):
+        return False
+
+    dtype = var.dtype
+    return low_precision_fp(dtype) if isinstance(dtype, torch.dtype) else False
+
+
+class TritonCSEVariable(CSEVariable):
+    def __init__(
+        self,
+        name: str,
+        bounds: ValueRanges[Any],
+        dtype: torch.dtype,
+        shape: BlockShapeType = None,
+    ) -> None:
+        super().__init__(name, bounds, dtype, shape=shape)
+        # We'll use this to track which masks the variable needs when used for indirect indexing
+        self.mask_vars: OrderedSet[str] = OrderedSet()
+        assert dtype is not None, "TritonCSEVariable must have dtype"
+        assert shape is not None, "TritonCSEVariable must have shape"
+
+    def update_on_args(self, name, args, kwargs):
+        for arg in args:
+            if isinstance(arg, TritonCSEVariable):
+                self.mask_vars.update(arg.mask_vars)
+            elif isinstance(arg, sympy.Symbol):
+                # most of the time index vars don't need masks associated with them
+                # however, when index vars are used to compute indices for indirect reads
+                # those reads should subsequently be masked,
+                for symt in TritonSymbols.block_types:
+                    if symbol_is_type(arg, symt):
+                        self.mask_vars.update([f"{prefix_str[symt]}mask"])
+                        break
+
+
+def get_dtype_handler() -> DtypePropagationOpsHandler:
+    from torch._inductor.dtype_propagation import DtypePropagationOpsHandler
+
+    return DtypePropagationOpsHandler()
+
+
+def maybe_upcast_float32(convert_output: bool = True) -> Callable[[_T], _T]:
+    """
+    Codegen helper to upcast arguments to float32, depending on the config and dtype.
+    This decorates tl.math/libdevice codegen functions.
+    """
+
+    def needs_upcast(var) -> bool:
+        return (
+            not config.triton.codegen_upcast_to_fp32
+            and isinstance(var, CSEVariable)
+            and var.dtype in (torch.float16, torch.bfloat16)
+        )
+
+    def maybe_upcast_arg(var) -> str:
+        upcast_string = ".to(tl.float32)" if needs_upcast(var) else ""
+        return f"{var}{upcast_string}"
+
+    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
+        # Record that this function only supports float32 and float64.
+        OpDtypeSupport.register_upcast(func, convert_output)
+
+        def wrapped(*args, **kwargs) -> str:
+            # Optionally upcast args to float32.
+            upcast_args = [maybe_upcast_arg(arg) for arg in args]
+            upcast_kwargs = {key: maybe_upcast_arg(val) for key, val in kwargs.items()}
+
+            # Call the decorated function, optionally downcasting the result.
+            result = func(*upcast_args, **upcast_kwargs)
+            any_needs_upcast = convert_output and any(
+                needs_upcast(var) for var in itertools.chain(args, kwargs.values())
+            )
+            result_dtype = (
+                None
+                if not any_needs_upcast
+                else getattr(get_dtype_handler(), func.__name__)(*args, **kwargs)
+            )
+            needs_downcast = result_dtype not in (torch.float32, None)
+            downcast_string = (
+                f".to({triton_type(result_dtype)})"
+                if needs_downcast and result_dtype is not None
+                else ""
+            )
+            return f"{result}{downcast_string}"
+
+        return wrapped
+
+    return decorator  # type: ignore[return-value]
+
+
+class TritonOverrides(OpOverrides):
+    """Map element-wise ops to Triton e.g., ops.to_dtype(x,...) -> x.to(...)"""
+
+    _LOG_2_E = math.log2(math.e)
+
+    @staticmethod
+    def to_dtype(
+        x,
+        dtype: torch.dtype,
+        src_dtype: Optional[torch.dtype] = None,
+        use_compute_types=True,
+    ):
+        def _get_min_elements_per_thread(
+            src_dtype: torch.dtype, dst_dtype: torch.dtype
+        ) -> int:
+            if src_dtype == dst_dtype:
+                # No data type conversion is needed. No requirements on min_elem_per_thread.
+                return 0
+
+            # fp8 data type conversions has min_elem_per_thread requirements.
+            # Refer to Triton implementations here:
+            # https://github.com/triton-lang/triton/blob/10f59d8ce04052521c1bc0cb3a3f8b98918fc7e3/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp#L10.
+            fp8_dtypes = (
+                torch.float8_e4m3fn,
+                torch.float8_e5m2,
+            )
+            # Triton doesn't support type conversions between fp8_e4m3 and fp8_e5m2.
+            assert not (
+                src_dtype in fp8_dtypes
+                and dst_dtype in fp8_dtypes
+                and src_dtype != dst_dtype
+            ), "Conversions between float8_e5m2 and float8_e4m3fn is not supported!"
+            if src_dtype == torch.float8_e5m2 or dst_dtype == torch.float8_e5m2:
+                return 4
+            if src_dtype == torch.float8_e4m3fn or dst_dtype == torch.float8_e4m3fn:
+                return 2
+            # No requirements on min_elem_per_thread.
+            return 0
+
+        if src_dtype is not None:
+            # Both dtype and src_dtype are set. This is used by torch to(dtype=dtype).
+            # It takes the maximum min_elem_per_thread if there are multiple fp8 conversions
+            # in the same kernel.
+            V.kernel.min_elem_per_thread = max(
+                _get_min_elements_per_thread(src_dtype, dtype),
+                V.kernel.min_elem_per_thread,
+            )
+
+        if dtype == torch.bool:
+            return f"({x} != 0)"
+        elif dtype == torch.uint8 and (
+            src_dtype is not None and src_dtype.is_floating_point or src_dtype is None
+        ):
+            # to work around llvm uint conversion semantics that produces 0's for negative
+            # values when converting from floating types.
+            # optimization - if source type is known and it's not a floating type, then
+            # do not apply conversion to the intermediate type.
+            return f"{x}.to(tl.int16).to(tl.uint8)"
+
+        if use_compute_types:
+            out_dtype = triton_compute_type(dtype)
+        else:
+            out_dtype = triton_store_type(dtype)
+
+        return f"{x}.to({out_dtype})"
+
+    @staticmethod
+    def to_dtype_bitcast(x, dtype: torch.dtype, src_dtype: torch.dtype):
+        assert src_dtype.itemsize == dtype.itemsize
+        # We may promote float16 or bfloat16 to float32 and cause the
+        # bitwidth of dtype to be different from the input tensor (i.e. float32).
+        # In such as case, we will have to convert the input tensor to
+        # its src_type, perform bitcast, and then convert the bit-casted
+        # tensor back to float to ensure we use values with the right precision.
+        if x.dtype != src_dtype:
+            x = f"{x}.to({triton_type(src_dtype)})"
+
+        out = f"{x}.to({triton_type(dtype)}, bitcast=True)"
+        if upcast_compute_type(dtype) != dtype:
+            out = f"{out}.to({triton_type(upcast_compute_type(dtype))})"
+
+        return out
+
+    @staticmethod
+    def _shaped_constant(value, dtype, shape):
+        type_ = torch._prims_common.dtype_to_type(dtype)
+        triton_val = constant_repr(type_(value))
+        triton_type = triton_compute_type(dtype)
+
+        if triton_type == "tl.float32":
+            # Float constants are always f32 in triton
+            return triton_val
+
+        # NOTE: We use a tensor here in order to get the expected type.
+        # Otherwise, e.g. float64 constants would be truncated to float32.
+        if value < 0 and not dtype.is_signed:
+            triton_signed_type = f"tl.{triton_type[4:]}"
+            return f"tl.full({shape}, {triton_val}, {triton_signed_type}).to({triton_type})"
+        else:
+            return f"tl.full({shape}, {triton_val}, {triton_type})"
+
+    @classmethod
+    def constant(cls, value, dtype):
+        return cls._shaped_constant(value, dtype, shape=[])
+
+    @staticmethod
+    @maybe_upcast_float32()
+    def abs(x):
+        return f"tl_math.abs({x})"
+
+    # TODO - register these ops as having divergent dtype
+    # output if doing graph pass to remove consecutive casts
+
+    @staticmethod
+    def truediv(x, y):
+        x_dtype = getattr(x, "dtype", None)
+        y_dtype = getattr(y, "dtype", None)
+
+        if (
+            x_dtype == torch.float32
+            and y_dtype == torch.float32
+            and config.emulate_divison_rounding
+        ):
+            # x / y in Triton is lowered to div.full which is approx
+            # we want div_rn to adhere with eager
+            out = f"triton.language.div_rn({x}, {y})"
+        else:
+            out = f"({x} / {y})"
+
+        # Workaround here since the functionality of div_rn has not ready on XPU.
+        # TODO: remove this workaround after https://github.com/intel/intel-xpu-backend-for-triton/issues/5306
+        # resolved.
+        if torch.xpu.is_available():
+            out = f"({x} / {y})"
+
+        if low_precision_fp_var(x) or low_precision_fp_var(y):
+            out_dtype = get_dtype_handler().truediv(x, y)
+            if out_dtype in (torch.float16, torch.float32):
+                out = f"{out}.to({triton_type(out_dtype)})"
+
+        return out
+
+    @staticmethod
+    def mod(x, y):
+        out = f"({x} % {y})"
+        if low_precision_fp_var(x) or low_precision_fp_var(y):
+            out_dtype = get_dtype_handler().mod(x, y)
+            if out_dtype in (torch.float16, torch.float32):
+                out = f"{out}.to({triton_type(out_dtype)})"
+        return out
+
+    @staticmethod
+    @maybe_upcast_float32()
+    def exp(x):
+        """
+        When use_fast_math, use the ftz (flushing to zero) variant
+        of exponent computation.
+
+        Check https://github.com/triton-lang/triton/issues/5735 for
+        more details.
+        """
+        if config.use_fast_math:
+            return f"tl_math.exp({x})"
+        else:
+            return f"libdevice.exp({x})"
+
+    @staticmethod
+    @maybe_upcast_float32()
+    def exp2(x):
+        return f"libdevice.exp2({x})"
+
+    @staticmethod
+    @maybe_upcast_float32()
+    def expm1(x):
+        return f"libdevice.expm1({x})"
+
+    @staticmethod
+    @maybe_upcast_float32()
+    def sqrt(x):
+        # work around for https://github.com/pytorch/pytorch/issues/165738
+        if torch.xpu.is_available():
+            return f"libdevice.sqrt({x})"
+        return f"tl.sqrt_rn({x})"
+
+    @staticmethod
+    def relu(x):
+        bug = config.triton.inject_relu_bug_TESTING_ONLY
+        if bug == "compile_error":
+            return "compile error!"
+        elif bug == "runtime_error":
+            # NB: this only triggers runtime error as long as input
+            # is not all zero
+            return f'triton_helpers.device_assert_then({x} == 0, "injected assert fail", {x})'
+        elif bug == "accuracy":
+            return f"{x} + 1"
+        elif bug is None:
+            return ops.maximum(ops.constant(0, torch.int32), x)
+        else:
+            raise AssertionError(
+                f"unrecognized config triton.inject_relu_bug_TESTING_ONLY = {bug!r}"
+            )
+
+    @staticmethod
+    def minimum(a, b):
+        return f"triton_helpers.minimum({a}, {b})"
+
+    @staticmethod
+    def maximum(a, b):
+        return f"triton_helpers.maximum({a}, {b})"
+
+    @staticmethod
+    def where(a, b, c):
+        return f"tl.where({a}, {b}, {c})"
+
+    @staticmethod
+    def dot(a, b):
+        """
+        Triton code generation for lowering ops.dot to tl.dot.
+
+        The logic is as follows:
+
+        1. Downcasting for performance
+           If the data was previously upcasted to fp32, we downcast back to the
+           original dtype (e.g., fp16 or bf16) for better performance. While
+           surrounding operations may run in fp32, matmul itself is executed at the
+           original precision to optimize throughput.
+
+        2. Handling non-constant reduction masks
+           If the reduction mask is not constant and there was any operation between
+           tl.load and tl.dot, we zero out regions outside the mask using
+           tl.where(r0_mask, val, 0).
+           This ensures that values outside the mask do not contribute to the dot
+           product, preventing incorrect results.
+
+        3. Shape alignment for tl.dot
+           We massage shapes to match the tl.dot requirement of (Y, R) x (R, X).
+           Current codegen eagerly broadcasts tl.arange to create unique axes. We
+           reshape, transpose, or broadcast to align with the (Y, R) x (R, X) shape.
+           We avoid using 3D dot ((Z, Y, R) x (Z, R, X)) because 3D tl.dot has
+           poor performance. During batched matmul (bmm), we keep ZBLOCK=1 and call
+           the 2D dot kernel instead.
+        """
+        assert V.kernel.is_native_matmul
+        orig_a, orig_b = a, b
+
+        def is_where_needed(var):
+            # Skip if the variable doesn't have a reduction mask
+            if not any(map(prefix_is_reduction, var.mask_vars)):
+                return False
+
+            reduction_range = V.kernel.range_trees[-1]
+            assert reduction_range.is_reduction
+
+            # Skip if reduction mask was already constant
+            if V.kernel._has_constant_mask(reduction_range):
+                return False
+
+            # Skip if the variable is already zeroed outside the mask
+            # (e.g., from tl.load(..., other=0.0))
+            # TODO : track the value of outside of mask region with cse
+            for k, v in V.kernel.cse._cache.items():
+                if v == var and "tl.load" in k and "other=0.0" in k:
+                    return False
+
+            return True
+
+        def where_cond(var):
+            default = ir.Reduction.default_value("dot", var.dtype)
+            reduction_mask = [
+                f"{tree.prefix}mask"
+                for tree in V.kernel.range_trees
+                if tree.is_reduction
+            ]
+
+            assert len(reduction_mask) == 1, "don't tile reduction when native matmul"
+
+            where_var = TritonKernelOverrides.where(reduction_mask[0], var, default)
+            return V.kernel.cse.generate(
+                V.kernel.compute, where_var, dtype=var.dtype, shape=var.shape
+            )
+
+        # When computing expressions like ((A+1) @ (B+2)),
+        # native codegen will do
+        #
+        # a = tl.load(..., r0_mask, other=0.0)
+        # b = tl.load(..., r0_mask, other=0.0)
+        # tmp0 = a+1
+        # tmp1 = b+2
+        # tmp2 = tl.dot(tmp0, tmp1)
+        #
+        # This produces incorrect results because outside of r0_mask is not zero.
+        # So before calling tl.dot, apply tl.where to zero out values properly.
+        # TODO: Optimize - We don't need both operands to be zeroed except NaN * 0
+        if is_where_needed(orig_a):
+            a = where_cond(a)
+        if is_where_needed(orig_b):
+            b = where_cond(b)
+
+        def reshape_transpose_broadcast_for_dot(
+            value,
+            initial_shape: Sequence[sympy.Expr],
+            final_shape: Sequence[sympy.Expr],
+        ) -> str:
+            """
+            Generate a reshape, transpose, and broadcast for the tl.dot.
+            tl.dot requires specific shape requirement : (Y,R) x (R,X)
+            but the current triton codegen eagerly broadcast the tl.arange so
+            it needs to be reshaped to meet the requirement.
+
+            This is done by three steps.
+            1. remove the empty dimension (dim with size 1) and make it 2d with tl.reshape
+            2. permute the dimension if needed (e.g., (X,R) -> (R,X)) with tl.trans
+            3. broadcast if needed with broadcast_to.
+                - This shows up when matmul operand is broadcasted with torch.expand/repeat.
+                - e.g., torch.rand((16,)).expand(16,16) @ B
+
+            e.g., (Y,1,R), (Y,R) -> tl.reshape(var, (Y,R))
+            e.g., (1,X,R), (R,X) -> tl.trans(tl.reshape(var, (X,R)))
+            e.g., (1,X,1), (R,X) -> tl.broadcast_to(tl.trans(tl.reshape(var, (X,1))), (R,X))
+
+            TODO : eventually we want to remove this function when lazy broadcasting arrives
+            """
+
+            # Triton 3d dot is slower than 2d dot, so we want to keep block shape in 2d
+            # by fixing ZBLOCK=1 in the autotune config
+            if ZBLOCK in initial_shape:
+                initial_shape = ["1" if dim == ZBLOCK else dim for dim in initial_shape]
+
+            if final_shape == [YBLOCK, RBLOCK]:
+                assert XBLOCK not in initial_shape, (
+                    "left tl.dot operand cannot depend on x"
+                )
+
+                shape_2d = ["1", "1"]
+                if YBLOCK in initial_shape:
+                    shape_2d[0] = YBLOCK
+                if RBLOCK in initial_shape:
+                    shape_2d[1] = RBLOCK
+
+                # reshape it into 2d
+                value = triton_reshape(value, initial_shape, shape_2d)
+
+                # broadcast if needed
+                broadcast_needed = shape_2d != [YBLOCK, RBLOCK]
+                if broadcast_needed:
+                    value = f"tl.broadcast_to({value}, ({YBLOCK}, {RBLOCK}))"
+
+            elif final_shape == [RBLOCK, XBLOCK]:
+                assert YBLOCK not in initial_shape, (
+                    "right tl.dot operand cannot depend on y"
+                )
+
+                shape_2d = ["1", "1"]
+                if XBLOCK in initial_shape:
+                    shape_2d[0] = XBLOCK
+                if RBLOCK in initial_shape:
+                    shape_2d[1] = RBLOCK
+
+                # reshape it into 2d (X,R)
+                value = triton_reshape(value, initial_shape, shape_2d)
+
+                # transpose to (R,X)
+                value = f"tl.trans({value})"
+
+                # broadcast if needed
+                broadcast_needed = shape_2d != [XBLOCK, RBLOCK]
+                if broadcast_needed:
+                    value = f"tl.broadcast_to({value}, ({RBLOCK}, {XBLOCK}))"
+            else:
+                raise NotImplementedError
+
+            return value
+
+        assert len(V.kernel.dense_size_list()) >= 3, "tl.dot can only do mm and bmm"
+
+        XBLOCK = str(TritonSymbols.block_sizes[SymT.XBLOCK])
+        YBLOCK = str(TritonSymbols.block_sizes[SymT.YBLOCK])
+        ZBLOCK = str(TritonSymbols.block_sizes[SymT.ZBLOCK])
+        RBLOCK = str(TritonSymbols.block_sizes[SymT.R0_INDEX])
+
+        a = V.kernel.cse.generate(
+            V.kernel.compute,
+            reshape_transpose_broadcast_for_dot(a, list(a.shape), [YBLOCK, RBLOCK]),
+            dtype=a.dtype,
+            shape=(YBLOCK, RBLOCK),
+        )
+
+        b = V.kernel.cse.generate(
+            V.kernel.compute,
+            reshape_transpose_broadcast_for_dot(b, list(b.shape), [RBLOCK, XBLOCK]),
+            dtype=b.dtype,
+            shape=(RBLOCK, XBLOCK),
+        )
+
+        if torch.backends.cuda.matmul.fp32_precision == "tf32":
+            input_precision = "tf32"
+        else:
+            input_precision = "ieee"
+
+        return f'tl.dot({a}, {b}, input_precision="{input_precision}")'
+
+    @staticmethod
+    def inline_asm_elementwise(
+        *inputs, asm, constraints=None, dtype=torch.float32, is_pure=True, pack=1
+    ):
+        triton_type = triton_compute_type(dtype)
+        input_refs = ", ".join([str(i) for i in inputs])
+        if constraints is None:
+            constraints = ", ".join(["=r"] + ["r" for _ in inputs])
+        return f"tl.inline_asm_elementwise('{asm}', '{constraints}', [{input_refs}], dtype={triton_type}, is_pure={is_pure}, pack={pack})"  # noqa: B950
+
+    @staticmethod
+    @maybe_upcast_float32()
+    def cos(x):
+        return f"tl_math.cos({x})"
+
+    @staticmethod
+    @maybe_upcast_float32()
+    def sin(x):
+        return f"tl_math.sin({x})"
+
+    @classmethod
+    def index_expr(cls, expr, dtype):
+        raise NotImplementedError("ops.index_expr not implemented outside a kernel")
+
+    @staticmethod
+    def masked(mask, body, other):
+        raise NotImplementedError("ops.masked not implemented outside a kernel")
+
+    @staticmethod
+    @maybe_upcast_float32()
+    def lgamma(x):
+        return f"libdevice.lgamma({x})"
+
+    @staticmethod
+    @maybe_upcast_float32()
+    def erf(x):
+        return f"libdevice.erf({x})"
+
+    @staticmethod
+    @maybe_upcast_float32()
+    def cosh(x):
+        return f"libdevice.cosh({x})"
+
+    @staticmethod
+    @maybe_upcast_float32()
+    def sinh(x):
+        return f"libdevice.sinh({x})"
+
+    @staticmethod
+    @maybe_upcast_float32()
+    def acos(x):
+        return f"libdevice.acos({x})"
+
+    @staticmethod
+    @maybe_upcast_float32()
+    def acosh(x):
+        return f"libdevice.acosh({x})"
+
+    @staticmethod
+    @maybe_upcast_float32()
+    def asin(x):
+        return f"libdevice.asin({x})"
+
+    @staticmethod
+    @maybe_upcast_float32()
+    def asinh(x):
+        return f"libdevice.asinh({x})"
+
+    @staticmethod
+    @maybe_upcast_float32()
+    def atan2(x, y):
+        return f"libdevice.atan2({x}, {y})"
+
+    @staticmethod
+    @maybe_upcast_float32()
+    def atan(x):
+        return f"libdevice.atan({x})"
+
+    @staticmethod
+    @maybe_upcast_float32()
+    def atanh(x):
+        return f"libdevice.atanh({x})"
+
+    @staticmethod
+    @maybe_upcast_float32()
+    def copysign(x, y):
+        return f"libdevice.copysign({x}, {y})"
+
+    @staticmethod
+    @maybe_upcast_float32()
+    def erfc(x):
+        return f"libdevice.erfc({x})"
+
+    @staticmethod
+    @maybe_upcast_float32()
+    def erfinv(x):
+        return f"libdevice.erfinv({x})"
+
+    @staticmethod
+    @maybe_upcast_float32()
+    def hypot(x, y):
+        return f"libdevice.hypot({x}, {y})"
+
+    @staticmethod
+    @maybe_upcast_float32()
+    def log10(x):
+        return f"libdevice.log10({x})"
+
+    @staticmethod
+    @maybe_upcast_float32()
+    def log2(x):
+        return f"libdevice.log2({x})"
+
+    @staticmethod
+    @maybe_upcast_float32()
+    def nextafter(x, y):
+        return f"libdevice.nextafter({x}, {y})"
+
+    @staticmethod
+    def logical_and(a, b):
+        return f"{a} & {b}"
+
+    @staticmethod
+    def logical_not(a):
+        return f"{a} == 0"
+
+    @staticmethod
+    def logical_or(a, b):
+        return f"{a} | {b}"
+
+    @staticmethod
+    def logical_xor(a, b):
+        return f"({a} ^ {b})"
+
+    @staticmethod
+    def bitwise_and(a, b):
+        return f"{a} & {b}"
+
+    @staticmethod
+    def bitwise_not(a):
+        return f"~{a}"
+
+    @staticmethod
+    def bitwise_or(a, b):
+        return f"{a} | {b}"
+
+    @staticmethod
+    def bitwise_xor(a, b):
+        return f"{a} ^ {b}"
+
+    @staticmethod
+    def bitwise_left_shift(a, b):
+        return f"{a} << {b}"
+
+    @staticmethod
+    def bitwise_right_shift(a, b):
+        return f"{a} >> {b}"
+
+    @staticmethod
+    def rand(seed, offset):
+        offset = f"({offset}).to(tl.uint32)"
+        return f"tl.rand({seed}, {offset})"
+
+    @staticmethod
+    def randn(seed, offset):
+        offset = f"({offset}).to(tl.uint32)"
+        return f"tl.randn({seed}, {offset})"
+
+    @staticmethod
+    def randint64(seed, offset, low, high):
+        offset = f"({offset}).to(tl.uint32)"
+        return f"triton_helpers.randint64({seed}, {offset}, {low}, {high})"
+
+    @staticmethod
+    def load_seed(name, offset):
+        raise NotImplementedError("ops.load_seed not implemented outside a kernel")
+
+    @staticmethod
+    @maybe_upcast_float32()
+    def rsqrt(x):
+        return f"libdevice.rsqrt({x})"
+
+    @staticmethod
+    @maybe_upcast_float32()
+    def log1p(x):
+        return f"libdevice.log1p({x})"
+
+    @staticmethod
+    @maybe_upcast_float32()
+    def tan(x):
+        return f"libdevice.tan({x})"
+
+    @staticmethod
+    @maybe_upcast_float32()
+    def tanh(x):
+        cse_var = V.kernel.cse.varname_map.get(x)
+        if cse_var and hasattr(cse_var, "dtype"):
+            dtype = cse_var.dtype
+        else:
+            dtype = None
+        if (
+            config.use_fast_math
+            and torch.version.hip
+            and get_triton_version() > (3, 5)
+            and dtype != torch.float64
+            and dtype is not None
+        ):
+            # Requires upstream Triton 3.6+ for latest fast_tanhf support
+            # https://github.com/triton-lang/triton/pull/8551
+            return f"libdevice.fast_tanhf({x})"
+        else:
+            return f"libdevice.tanh({x})"
+
+    @staticmethod
+    @maybe_upcast_float32()
+    def sigmoid(x):
+        return f"tl.sigmoid({x})"
+
+    @staticmethod
+    def signbit(x):
+        # XX: This is wrong for the value -0.0 in floating point
+        return (
+            f"(libdevice.signbit({x}) != 0) if ({x}).dtype is tl.float32 else {x} < 0"
+        )
+
+    @staticmethod
+    @maybe_upcast_float32()
+    def fmod(a, b):
+        return f"libdevice.fmod({a}, {b})"
+
+    @staticmethod
+    @maybe_upcast_float32()
+    def pow(a, b):
+        return f"libdevice.pow({a}, {b})"
+
+    @staticmethod
+    @maybe_upcast_float32()
+    def log(x):
+        return f"tl_math.log({x})"
+
+    @staticmethod
+    @maybe_upcast_float32(convert_output=False)
+    def isinf(x):
+        return f"libdevice.isinf({x}).to(tl.int1)"
+
+    @staticmethod
+    @maybe_upcast_float32(convert_output=False)
+    def isnan(x):
+        return f"libdevice.isnan({x}).to(tl.int1)"
+
+    @staticmethod
+    @maybe_upcast_float32()
+    def round(x):
+        return f"libdevice.nearbyint({x})"
+
+    @staticmethod
+    @maybe_upcast_float32()
+    def floor(x):
+        return f"libdevice.floor({x})"
+
+    @staticmethod
+    def floordiv(a, b):
+        # See the comment in lowering.div_mode. a and b are integer type.
+        # Similar to div_floor_kernel_cuda in pytorch core.
+        # Notice that // in triton behaves as truncdiv instead of floordiv
+        quot = f"{a} // {b}"
+        rem = f"{a} % {b}"
+        return f"tl.where(({a} < 0) != ({b} < 0), tl.where({rem} != 0, {quot} - 1, {quot}), {quot})"
+
+    @staticmethod
+    def sign(x):
+        z = ops.constant(0, torch.int32)
+        left = ops.to_dtype((ops.lt(z, x)), torch.int8)
+        right = ops.to_dtype((ops.lt(x, z)), torch.int8)
+        sub = ops.sub(left, right)
+        return f"{sub}.to({x}.dtype)"
+
+    @staticmethod
+    @maybe_upcast_float32()
+    def trunc(x):
+        return f"libdevice.trunc({x})"
+
+    @staticmethod
+    def truncdiv(a, b):
+        # See the comment in lowering.div_mode. a and b are integer type.
+        # Notice that // in triton behaves as truncdiv instead of floordiv
+        return f"{a} // {b}"
+
+    @staticmethod
+    @maybe_upcast_float32()
+    def ceil(x):
+        return f"libdevice.ceil({x})"
+
+
+TritonOverrides._initialize_pointwise_overrides("triton")
+
+
+class TritonKernelOverrides(TritonOverrides):
+    """Map element-wise ops to Triton within a TritonKernel
+
+    Unlike TritonOverrides, these assume the code is going to be inserted into
+    the body of the main triton kernel and so it may use indexing and mask
+    variables which are assumed to already be defined in the current scope.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # happens in __init__ unlike _initialize_pointwise_overrides
+        # because the libdevice registrations are populated during lowerings
+        self._setup_libdevice_routing()
+
+    @classmethod
+    @functools.cache
+    def _setup_libdevice_routing(cls):
+        """Set up routing to libdevice implementations for fp64 inputs."""
+
+        from torch._inductor.codegen.common import OpDecompositions
+
+        for fn_name in torch._inductor.utils.op_requires_libdevice_fp64:
+            assert hasattr(cls, fn_name)
+            original_impl = getattr(cls, fn_name)
+
+            def decomposition_router(x, _original_impl, _fn_name):
+                if x.dtype != torch.float64:
+                    return _original_impl(x)
+                else:
+                    return getattr(OpDecompositions, _fn_name)(x).value
+
+            if fn_name == "sigmoid":
+                assert hasattr(OpDecompositions, "sigmoid")
+                fn = functools.partial(
+                    decomposition_router, _original_impl=original_impl, _fn_name=fn_name
+                )
+                fn.__name__ = fn_name  # type: ignore[attr-defined]
+                setattr(cls, fn_name, staticmethod(fn))
+                continue
+
+            def dtype_router(x, _original_impl, _fn_name):
+                if x.dtype == torch.float64:
+                    return f"libdevice.{_fn_name}({x})"
+                else:
+                    return _original_impl(x)
+
+            fn = functools.partial(
+                dtype_router, _original_impl=original_impl, _fn_name=fn_name
+            )
+            fn.__name__ = fn_name  # type: ignore[attr-defined]
+            setattr(cls, fn_name, staticmethod(fn))
+
+    @classmethod
+    def constant(cls, value, dtype):
+        # NOTE: Cannot use shape=[] as it's not supported by triton-rocm
+        # We could use shape=[1] instead but starting with the correct
+        # ndim avoids extra `tt.expand_dim` ops appearing in the triton IR.
+        ndim = V.kernel.triton_tensor_ndim()
+        shape = [1] * ndim
+        return cls._shaped_constant(value, dtype, shape=shape)
+
+    @classmethod
+    def index_expr(cls, expr, dtype):
+        indexing = V.kernel.indexing(
+            expr, block_ptr=False, tma_compatibility_checker=None
+        )
+        assert isinstance(indexing, IndexingOptions)
+
+        shape: BlockShapeType
+        if indexing.expand_shape:
+            shape = indexing.expand_shape
+        else:
+            shape = TritonSymbols.get_block_shape(indexing.index)
+
+        # Our sympy expr printing casts to the current kernel index dtype.
+        # we only respect non int32-int64 dtypes and otherwise use current kernel indexing dtype
+        index_dtype = V.kernel.get_index_dtype_as_torch_dtype()
+        dtype = dtype if dtype not in (torch.int32, torch.int64) else index_dtype
+
+        # after we emit this var we cast it to the correct dtype
+        orig = config.test_configs.runtime_triton_dtype_assert
+        try:
+            config.test_configs.runtime_triton_dtype_assert = False
+            var = V.kernel.cse.generate(
+                V.kernel.compute,
+                indexing.index_str,
+                bounds=get_bounds_index_expr(expr),
+                dtype=dtype,
+                shape=shape,
+            )
+        finally:
+            config.test_configs.runtime_triton_dtype_assert = orig
+
+        if dtype not in (torch.int32, torch.int64):
+            var = V.kernel.cse.generate(
+                V.kernel.compute,
+                cls.to_dtype(var, dtype),
+                dtype=upcast_compute_type(dtype),
+                shape=var.shape,
+            )
+        else:
+            # TODO: we are not always consistent in enforcing that the output of the index expr printing
+            # results in the indexing dtype. So if we detect that we have an input which might type promote
+            # to a dtype other than indexing dtype, add a cast.
+            # Trying to avoid
+            dtype = index_dtype
+            for index_var in expr.free_symbols:
+                if symbol_is_type(index_var, SymT.TMP):
+                    dtype = torch.promote_types(
+                        dtype, V.kernel.cse.varname_map[index_var.name].dtype
+                    )
+
+            if dtype != index_dtype:
+                var = V.kernel.cse.generate(
+                    V.kernel.compute,
+                    cls.to_dtype(var, index_dtype),
+                    dtype=index_dtype,
+                    shape=var.shape,
+                )
+
+        var.mask_vars = indexing.mask_vars
+        return var
+
+    @staticmethod
+    def masked(mask, body, other):
+        if mask is not None and torch.version.hip is not None:
+            mask = V.kernel.cse.generate(
+                V.kernel.compute,
+                f"{mask}.to(tl.int1)",
+                dtype=torch.bool,
+                shape=mask.shape,
+            )
+
+        nodes = body.graph.find_nodes(op="output")
+        assert nodes, "graph for body does not contain an output"
+
+        need_where = False
+        # If we have a tl.load with a masking operator and no other value
+        # we can add the mask here and the other value to the tl.load
+        # operator to save the branching cost.
+        for node in nodes:
+            for arg in node.args:
+                if arg.target != "load" or should_unwrap_unspec_arg(arg.args[1]):
+                    need_where = True
+                    break
+
+        value = None if need_where else other
+
+        with V.kernel.mask_loads(mask, value=value) as new_mask:
+            result = body()
+
+        if need_where:
+            # Remove once CSEVariables track the dtype
+            if result.bounds.is_bool:
+                other = bool(other)
+            # Take dtype from result to prevent accidental promotion
+            other = V.kernel.cse.generate(
+                V.kernel.compute,
+                f"tl.full({result}.shape, {constant_repr(other)}, {result}.dtype)",
+                bounds=ValueRanges.wrap(other),
+                dtype=result.dtype,
+                shape=result.shape,
+            )
+            ret = ops.where(new_mask, result, other)
+        else:
+            ret = result
+
+        ret.mask_vars.discard(new_mask)
+        return ret
+
+    @staticmethod
+    def load_seed(name, offset):
+        var = V.kernel.args.input(name)
+        return (
+            f"tl.load({var} + {V.kernel.args.seed_offset('load_seed_offset', offset)})"
+        )
+
+    @staticmethod
+    def frexp(x):
+        cache_key = f"frexp({x})"
+        if cse_val := V.kernel.cse.try_get(cache_key):
+            return cse_val
+
+        mantissa = V.kernel.cse.newvar(dtype=x.dtype, shape=x.shape)
+        exponent = V.kernel.cse.newvar(dtype=torch.int32, shape=x.shape)
+        V.kernel.compute.writeline(
+            f"{mantissa}, {exponent} = triton_helpers.frexp({x})"
+        )
+        V.kernel.cse.put(cache_key, (mantissa, exponent))
+        return (mantissa, exponent)
+
+    @staticmethod
+    def partial_accumulate(
+        name: str,
+        reduction_type: str,
+        value: CSEVariable,
+        extra_meta: dict[str, Any],
+    ) -> None:
+        raise NotImplementedError
+
+
+class HelperFunctions:
+    """An ordered set of helper functions."""
+
+    _templates_seen: dict[str, str]  # Template code to function name
+    finalized_helpers: list[str]
+
+    def __init__(self) -> None:
+        self._templates_seen = {}
+        self.finalized_helpers = []
+
+    def add(self, template_code: str, *, base_name="_triton_helper_fn") -> str:
+        """This accepts a function definition with the function name
+        left as a format specifier e.g.
+
+            @triton.jit
+            def {name}(arg0, arg1):
+                return arg0 + arg1
+
+        We add the templated code to the function set and return the name
+        assigned to that function.
+
+        """
+        existing_name = self._templates_seen.get(template_code)
+        if existing_name is not None:
+            # Don't duplicate existing helpers
+            return existing_name
+
+        name = f"{base_name}{len(self.finalized_helpers)}"
+        self._templates_seen[template_code] = name
+        self.finalized_helpers.append(template_code.format(name=name))
+        return name
+
+    def __iter__(self):
+        return iter(self.finalized_helpers)
+
+    def __getitem__(self, idx):
+        return self.finalized_helpers[idx]
+
+
+@dataclasses.dataclass
+class BlockParameters:
+    """
+    Class representing ND block dimensions, for block pointer analysis.
+    """
+
+    shape: list[sympy.Expr] = dataclasses.field(default_factory=list)
+    block_shape: list[sympy.Expr] = dataclasses.field(default_factory=list)
+    strides: list[sympy.Expr] = dataclasses.field(default_factory=list)
+    offsets: list[sympy.Expr] = dataclasses.field(default_factory=list)
+
+    @dataclasses.dataclass
+    class StrideSorter:
+        original_strides: list[int]
+        sort_idx: list[int]
+        revert_sort_idx: list[int] = dataclasses.field(init=False)
+
+        def __post_init__(self):
+            assert len(self.original_strides) > 0
+            assert len(self.sort_idx) == len(self.original_strides)
+
+            identity_sort_idx = list(range(len(self.original_strides)))
+            self._is_identity = self.sort_idx == identity_sort_idx
+
+            # Set revert_sort_idx
+            sorted_dims_by_strides_map = {k: i for i, k in enumerate(self.sort_idx)}
+            self.revert_sort_idx = [
+                sorted_dims_by_strides_map[i]
+                for i in range(len(sorted_dims_by_strides_map))
+            ]
+
+        @property
+        def is_identity(self):
+            return self._is_identity
+
+        @classmethod
+        @abstractmethod
+        def create(
+            cls, original_strides: list[Union[int, sympy.Expr]], shape_env: ShapeEnv
+        ) -> BlockParameters.StrideSorter:
+            """Create a `StrideSorter` that can be used to sort block parameters."""
+
+        def sort(self, attr):
+            if not self.is_identity:
+                return [attr[i] for i in self.sort_idx]
+            return attr
+
+        def revert(self, attr):
+            if not self.is_identity:
+                return [attr[i] for i in self.sort_idx]
+            return attr
+
+    @dataclasses.dataclass
+    class IdentityStrideSorter(StrideSorter):
+        def __post_init__(self):
+            super().__post_init__()
+
+        @classmethod
+        def create(
+            cls, original_strides: list[Union[int, sympy.Expr]], shape_env: ShapeEnv
+        ) -> BlockParameters.StrideSorter:
+            return cls(
+                original_strides=original_strides,
+                sort_idx=list(range(len(original_strides))),
+            )
+
+    @dataclasses.dataclass
+    class TensorDecriptorStrideSorter(StrideSorter):
+        """
+        Sorts BlockParameters dimensions with strides in descending order.
+        """
+
+        def __post_init__(self):
+            super().__post_init__()
+
+        @classmethod
+        def create(
+            cls, original_strides: list[Union[int, sympy.Expr]], shape_env: ShapeEnv
+        ) -> BlockParameters.StrideSorter:
+            """
+            If the strides are not all known constants or if the strides are already
+            sorted in descending order, return identity sort.
+
+            For example if block_shape @ strides is [ZBLOCK, XBLOCK, YBLOCK] @ [8, 1, 16]
+            The indices to sort the strides in descending order will be [2, 0, 1].
+            The indices to revert back to the original order will be [1, 2, 0].
+            """
+            identity_sort = list(range(len(original_strides)))
+            try:
+                # TODO: even if the strides are not in descending order the strides
+                # may be tensor descriptor compliant
+                # i.e. innermost stride == 1 and outer strides 16 byte aligned
+                # We should benchmark the effect of applying a transpose to these
+                # cases vs leaving them unsorted.
+                sort_idx = utils.argsort_sym(shape_env, original_strides, reverse=True)
+            except AssertionError:
+                # Symbolic shapes, failed to evaluate comparison expression
+                sort_idx = identity_sort
+
+            return cls(
+                original_strides=original_strides,
+                sort_idx=sort_idx,
+            )
+
+    def __add__(self, other: BlockParameters) -> BlockParameters:
+        """
+        Concatenates block parameters.
+        """
+        cls = type(self)
+        a, b = tuple(dataclasses.asdict(x) for x in (self, other))
+        return cls(**{key: a[key] + b[key] for key in a})
+
+    def maybe_sort_with_stride_order(
+        self, stride_sorter_cls: type[StrideSorter], shape_env: ShapeEnv
+    ) -> tuple[BlockParameters, BlockParameters.StrideSorter]:
+        """
+        Sort `BlockParameter` with stride_sorter_cls. Returns block parameters
+        as well as a `StrideSorter` which contains information on how the sort
+        can be reverted.
+        """
+        stride_sorter = stride_sorter_cls.create(self.strides, shape_env=shape_env)
+        params = BlockParameters(
+            **{
+                key: stride_sorter.sort(val)
+                for key, val in dataclasses.asdict(self).items()
+            }
+        )
+        return params, stride_sorter
+
+    def remove_dims(self, removable_dims: list[bool]) -> BlockParameters:
+        """
+        Remove dimensions where removable_dims is True.
+        """
+
+        def filter_dims(it):
+            return [
+                item
+                for item, is_removable in zip(it, removable_dims)
+                if not is_removable
+            ]
+
+        return BlockParameters(
+            **{key: filter_dims(val) for key, val in dataclasses.asdict(self).items()},
+        )
+
+
+class CooperativeReductionWorkspaceCache:
+    """
+    The scratch space used for cooperative reductions can be reused
+    after two reduction loops.  This keeps track of what can be reused.
+    """
+
+    def __init__(self, args):
+        self.args = args
+        self.current_loop = []
+        self.prior_loop = []
+        self.ready_for_reuse = collections.defaultdict(collections.deque)
+        self.loop_count = 0
+        self.store_count = 0
+
+    def allocate(self, nbytes: sympy.Expr):
+        cached = self.ready_for_reuse.get(nbytes)
+        if cached:
+            return cached.popleft()
+        ws_name, _, ws_offset = self.args.workspace(nbytes, False)
+        self.current_loop.append((nbytes, ws_name, ws_offset))
+        return (ws_name, ws_offset)
+
+    def on_loop_end(self):
+        # Buffers can be reused after 2 loop ends
+        for nbytes, ws_name, ws_offset in self.prior_loop:
+            self.ready_for_reuse[nbytes].append((ws_name, ws_offset))
+        self.prior_loop = self.current_loop
+        self.current_loop = []
+        self.loop_count += 1
+
+    def increment_store_count(self):
+        prior = self.store_count
+        self.store_count += 1
+        return prior
+
+
+@dataclasses.dataclass
+class FixedTritonConfig:
+    config: dict[str, int]
+
+    def __getitem__(self, item):
+        return self.config[item]
+
+    def __contains__(self, item):
+        return item in self.config
+
+
+class TritonCSE(CSE[TritonCSEVariable, Union[str, tuple[str, str]]]):
+    """
+    Subclasses CSE to apply the current load mask to the cache key to avoid CSEing
+    variables across separate masked blocks.
+    """
+
+    def augment_key(self, cache_key: str) -> Union[str, tuple[str, str]]:
+        if mask := V.kernel._load_mask:
+            return (cache_key, mask.name)
+        else:
+            return cache_key
+
+
+@dataclasses.dataclass
+class TMACompatibilityChecker:
+    """
+    Checks if the TMA API can be used for load / store triton operations.
+    """
+
+    kernel: TritonKernel
+    dtype: torch.dtype
+    for_store: bool
+    force: bool
+
+    def __post_init__(self):
+        self.failed_debug_prefix = "Cannot use TMA descriptor for load / store since: "
+
+    # Also see Note: TMA API Restrictions for the below
+    def can_use_tma(
+        self,
+    ) -> bool:
+        if self.force:
+            return True
+        if not (
+            V.graph.get_current_device_or_throw().type == "cuda"
+            and torch.cuda.get_device_capability()[0] >= 9
+            and config.triton.use_tensor_descriptor
+            and config.assume_aligned_inputs
+            and has_triton_stable_tma_api()
+            # For CUDA The base ptr needs to be aligned
+        ):
+            log.debug(
+                (
+                    "%s Requires triton>=3.4.0, a CUDA device with cc>=9.0 and"
+                    " `use_tensor_descriptor` and `assume_aligned_inputs` options enabled"
+                ),
+                self.failed_debug_prefix,
+            )
+            return False
+
+        # `no_x_dim` => XBLOCK=1, and for reductions this means only one element
+        # is to be stored . However the TMA API requires that
+        # the store will be 16 byte aligned, which is not attainable with a single
+        # element
+        if self.for_store and self.kernel.no_x_dim:
+            log.debug(
+                "%s stores with `no_x_dim` cannot load 16 bytes.",
+                self.failed_debug_prefix,
+            )
+            return False
+
+        return True
+
+    def are_block_parameters_compatible(
+        self,
+        block_params: BlockParameters,
+    ) -> bool:
+        """
+        Check if the block parameters are valid for TMA.
+        If force, we allow relying on symbolic hints equivalent
+        to what we check for Triton templates.
+        """
+        if self.force:
+            strides = [
+                V.graph.sizevars.symbolic_hint(st) for st in block_params.strides
+            ]
+        else:
+            strides = block_params.strides
+
+        # The TMA API requires that the innermost stride is 1
+        # and that the outer strides are 16 byte aligned
+        if not V.graph.sizevars.statically_known_equals(strides[-1], sympy.Integer(1)):
+            log.debug(
+                "%s TMA API requires innermost stride to be 1. Strides are: %s",
+                self.failed_debug_prefix,
+                strides,
+            )
+            return False
+
+        element_size = self.dtype.itemsize
+        for stride in strides[:-1]:
+            if not V.graph.sizevars.statically_known_equals(
+                ModularIndexing(stride * element_size, 1, sympy.Integer(16)),
+                sympy.Integer(0),
+            ):
+                log.debug(
+                    "%s TMA API requires outer strides to be 16 byte aligned. Dtype bytes: %d, strides: %s",
+                    self.failed_debug_prefix,
+                    element_size,
+                    strides,
+                )
+                return False
+
+        # Now compute the minimum value of the block type that is used
+        # in the innermost block size that can guarantee that 16 bytes of data
+        # can be loaded / stored.
+        # Start with finding the innermost block type
+        innermost_block_shape = block_params.block_shape[-1]
+
+        # Pure singleton case
+        if V.graph.sizevars.statically_known_equals(
+            innermost_block_shape, sympy.Integer(1)
+        ):
+            log.debug(
+                "%s innermost block shape cannot load 16 bytes. Block shape: %s",
+                self.failed_debug_prefix,
+                block_params.block_shape,
+            )
+            return False
+
+        innermost_block_type = None
+        innermost_block_symt = None
+        for block_type_str in innermost_block_shape.free_symbols:
+            for block_symt in TritonSymbols.block_types:
+                if symbol_is_type(block_type_str, block_symt):
+                    innermost_block_type = block_type_str
+                    innermost_block_symt = block_symt
+                    break
+
+        assert innermost_block_type and innermost_block_symt, (
+            f"{innermost_block_shape} expr must contain a single block type from {TritonSymbols.block_types}"
+        )
+
+        # For persistent reductions, the reduction block sizes are fixed at compile time
+        if self.kernel.persistent_reduction and not self.for_store:
+            # For a discontiguous tensor, a 1D block will be split across several
+            # dimensions, e.g. R0_BLOCK:
+            # block_shape=[XBLOCK, ((R0_BLOCK + 31)//32), Min(1, ((R0_BLOCK + 31)//32)), Min(32, R0_BLOCK)]
+            # The persistent R0_BLOCK will be a power of 2 that is at least r0_numel So it
+            # should be guaranteed that Min(32, R0_BLOCK) * element_size >= 16
+            innermost_tree_prefix = prefix_str[innermost_block_symt]
+            tree_numel = None
+            for t in self.kernel.range_trees:
+                if t.is_reduction:
+                    if t.prefix == innermost_tree_prefix:
+                        tree_numel = t.numel
+                        break
+            assert tree_numel is not None
+            persistent_rblock = self.kernel._get_persistent_RBLOCK(tree_numel)
+            innermost_block_bytes = (
+                innermost_block_shape.subs({innermost_block_type: persistent_rblock})
+                * element_size
+            )
+            if not V.graph.sizevars.statically_known_geq(
+                innermost_block_bytes, sympy.Integer(16)
+            ):
+                log.debug(
+                    "%s persistent reduction innermost block shape cannot load 16 bytes. Block shape: %s, persistent RBLOCK: %d",
+                    self.failed_debug_prefix,
+                    block_params.block_shape,
+                    persistent_rblock,
+                )
+                return False
+
+        else:
+            # E.g. if the innermost block shape is Min(2, XBLOCK)
+            # then the TMA API can only be used if the dtype has an 8 byte element
+            # size so that 16 bytes of data can be loaded in the innermost dimension
+            try:
+
+                def indexing_div_rep(
+                    x: sympy.Expr,
+                    y: sympy.Expr,
+                    z: Optional[sympy.Expr] = None,
+                ) -> sympy.Expr:
+                    div = x / y
+                    if z:
+                        div = div % z
+                    return div
+
+                solve_expr = innermost_block_shape * element_size - 16
+                # Sympy cannot handle FloorDiv and ModularIndexing well, so simplify
+                solve_expr_simplified = solve_expr.replace(
+                    FloorDiv, indexing_div_rep
+                ).replace(ModularIndexing, indexing_div_rep)
+                min_block_size = next_power_of_2(
+                    int(
+                        sympy.nsolve(
+                            solve_expr_simplified,
+                            innermost_block_type,
+                            1,
+                        )
+                    )
+                )
+
+                # TODO: min block size may be too large / introduce redundancy
+                if min_block_size > self.kernel.max_block(
+                    prefix_str[innermost_block_symt]
+                ):
+                    log.debug(
+                        "%s the minimum block size to satisfy expression %s is too large: %d",
+                        self.failed_debug_prefix,
+                        solve_expr_simplified,
+                        min_block_size,
+                    )
+                    return False
+
+                block_type_str = self.kernel.index_to_str(innermost_block_type)
+                # Check block sizes if the user has provided a fixed triton config
+                if self.kernel.fixed_config:
+                    if min_block_size > self.kernel.fixed_config[block_type_str]:
+                        log.debug(
+                            "%s For block %s, fixed config block size %d is smaller "
+                            "than the minimum required: %d",
+                            self.failed_debug_prefix,
+                            block_type_str,
+                            self.kernel.fixed_config[block_type_str],
+                            min_block_size,
+                        )
+                        return False
+                else:
+                    # Update the minimum block sizes that are passed to triton
+                    # heuristics
+                    self.kernel.tma_min_block_sizes[block_type_str] = max(
+                        min_block_size,
+                        self.kernel.tma_min_block_sizes.get(block_type_str, 1),
+                    )
+
+            except ValueError:
+                log.debug(
+                    "%s innermost block shape cannot load 16 bytes. Block params: %s",
+                    self.failed_debug_prefix,
+                    block_params.block_shape,
+                )
+                return False
+
+        return True
+
+    def can_lift(self) -> bool:
+        """
+        Can you lift the make_tensor_descriptor
+        call to the top of the kernel? This requires
+        being certain that all of the shape, stride,
+        and block_shape information is handled in arguments
+        or top level definitions.
+
+        Right now we assume this is always possible if you force TMA.
+        """
+        return self.force
+
+
+class TritonKernel(SIMDKernel[TritonCSEVariable]):
+    """A class to represent a triton kernel and helpers to generate
+    triton kernel programmatically
+    """
+
+    overrides = TritonKernelOverrides  # type: ignore[assignment]
+    helper_functions: HelperFunctions
+    kexpr: Callable[[sympy.Expr], str] = texpr
+    allow_block_ptr = True
+    tma_compatibility_checker_cls = TMACompatibilityChecker
+    transpose_discontiguous_tensor_descriptors_override: Optional[bool] = None
+
+    def __init__(
+        self,
+        tiling: dict[str, sympy.Expr],
+        min_elem_per_thread=0,
+        optimize_mask=True,
+        fixed_config: Optional[FixedTritonConfig] = None,
+        hint_override: Optional[int] = None,
+        **kwargs,
+    ) -> None:
+        self.optimize_mask: bool = optimize_mask
+        self.fixed_config = fixed_config
+        super().__init__(tiling, **kwargs)
+        self.cse = TritonCSE(self.newvar_prefix, self.suffix)
+        # Cache of values that can be reused for the prologue.
+        self.prologue_cache: dict[str, str] = {}
+        self.prologue: IndentedBuffer = IndentedBuffer()
+        self.post_loop_combine: IndentedBuffer = IndentedBuffer()
+        self.post_loop_store: IndentedBuffer = IndentedBuffer()
+        self.outside_loop_vars = OrderedSet[Any]()
+        self.min_elem_per_thread = min_elem_per_thread
+        self.block_ptr_id = itertools.count()
+        self.block_ptr_to_buffer = dict[str, str]()
+        self.helper_functions = HelperFunctions()
+        self.pointer_advancements: dict[SymT, dict[str, list[sympy.Expr]]] = (
+            collections.defaultdict(dict)
+        )
+        self.tma_min_block_sizes = dict[str, int]()
+        self.hint_override = hint_override
+        self._load_counts: collections.Counter[str] = collections.Counter()
+        self._load_index = 0
+
+        # A set of autotuning hints to pass as part of triton_meta
+        self.autotune_hints = OrderedSet[AutotuneHint]()
+        self.triton_meta: Optional[dict[str, Any]] = None
+
+        if self.inside_reduction:
+            self.codegen_reduction_numels(self.body)
+
+        if self.cooperative_reduction:
+            self.init_cooperative_reduction()
+
+        self.codegen_range_tree()
+
+        if self.cooperative_reduction:
+            self.init_cooperative_reduction_mask()
+
+        self.has_load_with_contiguous_rdim = False
+        # We track the store name since a store can be canceled later
+        self.stores_with_contiguous_rdim: list[str] = []
+
+    @staticmethod
+    def _has_stride1_on_rdim(index) -> bool:
+        # These analysis is only needed in deterministic mode so far
+        # to filter triton configs. Return false immediately to avoid
+        # increasing compilation time when the mode is off.
+        if not (
+            config.deterministic or config.test_configs.force_filter_reduction_configs
+        ):
+            return False
+        support_vars = index.free_symbols
+        reduce_vars = [
+            var
+            for var in support_vars
+            if symbol_is_type(var, TritonSymbols.reduction_types)
+        ]
+
+        if len(reduce_vars) == 0:
+            return False
+
+        # for expression "x0 + 150528*((x1//(s27*s38))) + 3*(ModularIndexing(x1, 1, s38)) + 672*(ModularIndexing(x1, s38, s27))"
+        # stride_vars will results in DivisionByZero error
+        try:
+            stride_vars = V.graph.sizevars.stride_vars(index, reduce_vars, support_vars)
+        except ZeroDivisionError:
+            return False
+
+        return any(stride == 1 for stride in stride_vars)
+
+    @property
+    def has_store_with_contiguous_rdim(self) -> bool:
+        return not all(
+            is_buffer_removed(name) for name in self.stores_with_contiguous_rdim
+        )
+
+    def dtype_to_str(self, dtype: torch.dtype) -> str:
+        return triton_type(dtype)
+
+    def should_use_cooperative_reduction(self) -> bool:
+        return self.inside_reduction and V.choices.should_use_cooperative_reduction(
+            self.features
+        )
+
+    def init_cooperative_reduction(self):
+        """One time setup code for cooperative reductions."""
+        assert self.cooperative_reduction
+
+        # shift all the grids over since tl.program_id(0) is for rsplit
+        for tree in self.range_trees:
+            if tree.grid_dim is not None:
+                tree.grid_dim += 1
+
+        sem_count = self.numels["x"]
+        if self.fixed_config:
+            sem_count = CeilDiv(sem_count, self.fixed_config["XBLOCK"])
+        self.semaphores_name = self.args.semaphores(sem_count)
+        self.cooperative_reduction_workspace_cache = CooperativeReductionWorkspaceCache(
+            self.args
+        )
+        self.body.splice(
+            """\
+            RSPLIT_NEXT_POWER_OF_2: tl.constexpr = triton_helpers.constexpr_next_power_of_2(RSPLIT)
+            RSPLIT_IS_POWER_OF_2: tl.constexpr = RSPLIT == RSPLIT_NEXT_POWER_OF_2
+            HAS_RSPLIT: tl.constexpr = RSPLIT > 1
+            rsplit_id = tl.program_id(0)
+            num_rblocks = (rnumel + RBLOCK - 1) // RBLOCK
+            rsplit_chunk = (num_rblocks + RSPLIT - 1) // RSPLIT * RBLOCK
+            rsplit_start = rsplit_chunk * rsplit_id
+            rsplit_end = rsplit_chunk * (rsplit_id + 1)
+            """,
+        )
+        if any(
+            not self._has_constant_mask(tree)
+            for tree in self.range_trees
+            if tree.is_reduction
+        ):
+            self.body.writeline(
+                "rsplit_end = tl.where(rsplit_end < rnumel, rsplit_end, rnumel)"
+            )
+
+    def init_cooperative_reduction_mask(self):
+        rsplit_arange = "tl.arange(0, RSPLIT_NEXT_POWER_OF_2)"
+        if not self.no_x_dim:
+            rsplit_arange = f"{rsplit_arange}[None, :]"
+        self.body.writeline(f"rsplit_arange = {rsplit_arange}")
+
+        if self._has_constant_xmask():
+            self.body.splice(
+                """\
+                if RSPLIT_IS_POWER_OF_2:
+                    rsplit_mask: tl.constexpr = None
+                else:
+                    rsplit_mask = rsplit_arange < RSPLIT
+                """
+            )
+        else:
+            assert not self.no_x_dim
+            self.body.writeline(
+                "rsplit_mask = xmask if RSPLIT_IS_POWER_OF_2 else ((rsplit_arange < RSPLIT) & xmask)"
+            )
+
+    def codegen_range_tree(self):
+        for tree in self.range_trees:
+            # reduction indexing goes inside a loop
+            if not tree.is_loop:
+                self.iteration_ranges_codegen_header(tree, self.body)
+            elif self.inside_reduction:
+                # workaround for this issue:
+                # https://gist.github.com/jansel/6527126f781559095c5531f98a4235a7
+                self.body.writeline(
+                    f"{tree.prefix}base = {self.iteration_ranges_ranges_code(tree)}"
+                )
+
+        if self.inside_reduction:
+            if any(tree.is_loop for tree in self.range_trees):
+                # If the kernel contains loops, compute rbase.
+                rn_bases = self._get_reduction_symbols(
+                    "base", integer=True, nonnegative=True
+                )
+                rbase = self._flatten_reduction_indices(rn_bases)
+                self.body.splice(f"rbase = {self.index_to_str(rbase)}")
+            else:
+                # For looped reductions, indexing is deferred to the innermost loop.
+                self.codegen_reduction_indices(self.body)
+
+    def need_numel_args(self):
+        """
+        Indicate whether we need provide numel as arguments for the generated
+        kernel calls in the benchmark.
+
+        Should be true for pointwise/reduction kernels but false for triton
+        matmul kernels.
+        """
+        return True
+
+    def should_use_persistent_reduction(self) -> bool:
+        return self.inside_reduction and V.choices.should_use_persistent_reduction(
+            self.features, self.cooperative_reduction
+        )
+
+    def want_no_x_dim(self):
+        return (
+            self.persistent_reduction
+            and len(self.numels) == self.num_reduction_dims + 1
+            and self.fixed_config
+            and self.fixed_config["XBLOCK"] == 1
+        )
+
+    @property
+    def assert_function(self) -> str:
+        return "tl.device_assert"
+
+    def indexing(
+        self,
+        index: sympy.Expr,
+        *,
+        copy_shape: Optional[Union[str, tuple[str]]] = None,
+        dense_indexing=False,
+        override_mask=None,
+        block_ptr=False,
+        tma_compatibility_checker: Optional[TMACompatibilityChecker] = None,
+    ):
+        """
+        Compute the index and mask to pass to tl.load() or tl.store()
+        """
+        index = self.prepare_indexing(index)
+        index_vars = index.free_symbols
+        has_rindex = False
+
+        mask_vars: OrderedSet[str] = OrderedSet()
+        for var in sorted(index_vars, key=operator.attrgetter("name")):
+            assert isinstance(var, sympy.Symbol)
+            has_rindex = has_rindex or symbol_is_type(
+                var, TritonSymbols.reduction_types
+            )
+            if override_mask:
+                pass
+            elif symbol_is_type(var, SymT.TMP):
+                # indirect indexing
+                cse_var = self.cse.varname_map[var.name]
+                mask_vars.update(cse_var.mask_vars)
+            elif symbol_is_type(
+                var,
+                (
+                    SymT.UNBACKED_INT,
+                    SymT.SIZE,
+                    SymT.PRECOMPUTED_SIZE,
+                    SymT.INDEX,
+                    SymT.FLOAT,
+                    SymT.UNBACKED_FLOAT,
+                ),
+            ):
+                pass
+            else:
+                # var is one of xN, yN, r0_N or r1_N
+                prefix_matches = [
+                    prefix_str[symt]
+                    for symt in TritonSymbols.block_types
+                    if symbol_is_type(var, symt)
+                ]
+                if len(prefix_matches) == 0:
+                    pass
+                assert len(prefix_matches) == 1, f"Ambiguous type: {var.name}"
+                mask_vars.add(f"{prefix_matches[0]}mask")
+
+        need_dense = (
+            config.triton.dense_indexing
+            or dense_indexing
+            or self._load_mask is not None
+        ) and index != 0
+
+        have_dense = True
+        have_loop_vars = False
+        dense_mask_vars: OrderedSet[str] = OrderedSet()
+
+        for tree in self.active_range_trees():
+            if index_vars.intersection(tree.var_list):
+                have_loop_vars = True
+            else:
+                have_dense = False
+            dense_mask_vars.add(f"{tree.prefix}mask")
+
+        if (
+            (
+                (block_ptr and self.allow_block_ptr and config.triton.use_block_ptr)
+                or (
+                    tma_compatibility_checker
+                    and tma_compatibility_checker.can_use_tma()
+                )
+            )
+            and not override_mask
+            and not self._load_mask
+            and len(mask_vars - dense_mask_vars) == 0
+            and not self.is_indirect_indexing(index)
+            and have_loop_vars
+            # workaround https://github.com/triton-lang/triton/issues/2821
+            and self.index_dtype == "tl.int32"
+        ):
+
+            def match_affine_block(
+                index: sympy.Expr, range_tree: IterationRangesRoot
+            ) -> Optional[BlockParameters]:
+                """
+                Matches expressions of the form:
+                    idx = s * xindex
+
+                This implies stride (s,), and shape (XBLOCK,).
+                """
+                stride = BlockPatternMatcher.match_affine_block_expr(
+                    index, range_tree.symbol()
+                )
+                if stride is None:
+                    return None
+
+                return BlockParameters(
+                    shape=[range_tree.numel],
+                    block_shape=[TritonSymbols.get_block_size(range_tree)],
+                    strides=[stride],
+                    offsets=[TritonSymbols.get_block_offset(range_tree)],
+                )
+
+            def match_mod_div_block(
+                index: sympy.Expr, range_tree: IterationRangesRoot
+            ) -> Optional[BlockParameters]:
+                """
+                Matches higher-dimensional blocks coming from FloorDiv and ModularIndexing.
+
+                Example expression to match:
+                   sN * ((rindex//(d1 * ... * d(N-1))))
+                       + s1 * ModularIndexing(rindex, 1, d1)
+                       + ...
+                       + s(N-1) * ModularIndexing(rindex, d1 * ... * d(N-2), d(N-1))
+
+                This iterates over a block of shape (dN, ..., d1) and stride
+                (sN, ..., s1). (d1,...,d(N-1)) and (s1,...,sN) are
+                wildcards that we match.
+
+                Note that dN does not appear in the expression, but we solve for it
+                using range tree numels and the other dims.
+                """
+
+                index_var = range_tree.symbol()
+
+                # Bound the possible number of dims. We use the following heuristics:
+                # - At least one dim for each range tree node.
+                # - At least one dim for every FloorDiv or ModularIndexing op.
+                # - At least 2 dims to pattern match.
+                denom, modulo = sympy.symbols(
+                    "denom modulo",
+                    cls=functools.partial(sympy.Wild, exclude=[index_var]),
+                )
+                num_dims = max(
+                    2,
+                    # range_tree.nodes only includes the entries for the range tree
+                    # len(range_tree.nodes) <= self.range_tree_nodes
+                    len(range_tree.nodes),
+                    (
+                        index.count(FloorDiv(index_var, denom))
+                        + index.count(ModularIndexing(index_var, denom, modulo))
+                    ),
+                )
+
+                match_result = BlockPatternMatcher.match_mod_div_block_expr(
+                    index, index_var, range_tree.numel, num_dims
+                )
+                if match_result is None:
+                    return None
+
+                (
+                    dims,
+                    strides,
+                    block_index_exprs,
+                ) = match_result
+                slice_numels = BlockPatternMatcher.get_slice_numels(dims)
+
+                # Check for applicable iteration range sizes.
+                # When mapping a 1D block into an ND one, we need to know that
+                # the number of elements is not changed. This means the slice numels of
+                # the ND iteration range must evenly divide the length of the 1D block.
+                # There are two cases where we can guarantee this:
+                #  1. Numels are powers of 2. If numel == 2 ** n, and we know XBLOCK == 2 ** m,
+                #     with n and m integers, then either numel is a multiple of XBLOCK, or numel
+                #     is less than XBLOCK. (If numel is less than XBLOCK, we round up to 1 below.)
+                #  2. Numels are multiples of the maximum possible block size.
+                sizevars = V.graph.sizevars
+                max_block = self.max_block(range_tree.prefix)
+                if any(
+                    not sizevars.statically_known_multiple_of(numel, max_block)
+                    and not sizevars.statically_known_power_of_2(numel)
+                    for numel in slice_numels
+                ):
+                    return None
+
+                # Compute the ND block shape from the linear block size.
+                # Use CielDiv to round leading dimensions up to 1.
+                # Non-leading dimensions are clamped to the size of the iteration range,
+                # while the leading dimension can exceed this to accommodate a larger
+                # block size.
+                linear_block_size = TritonSymbols.get_block_size(range_tree)
+                block_shape: list[sympy.Expr] = [
+                    CeilDiv(linear_block_size, slice_numels[0])
+                ] + [
+                    sympy.Min(CeilDiv(linear_block_size, numel), dim)
+                    for numel, dim in zip(slice_numels[1:], dims[1:])
+                ]
+
+                # Compute block offsets from {xyzr}offset and the matched expressions.
+                block_offsets: list[sympy.Expr] = [
+                    sympy_subs(
+                        expr, {index_var: TritonSymbols.get_block_offset(range_tree)}
+                    )
+                    for expr in block_index_exprs
+                ]
+
+                return BlockParameters(
+                    shape=dims,
+                    block_shape=block_shape,
+                    strides=strides,
+                    offsets=block_offsets,
+                )
+
+            def match_block_subexpr(
+                expr: sympy.Expr, range_tree: IterationRangesRoot
+            ) -> Optional[BlockParameters]:
+                """
+                Match a block indexing subexpression involving a single range tree.
+                """
+                for match_func in (
+                    match_affine_block,
+                    match_mod_div_block,
+                ):
+                    match = match_func(expr, range_tree)
+                    if match is not None:
+                        return match
+
+                return None
+
+            def match_block_expr() -> Optional[BlockDescriptorOptions]:
+                index_relative_to_xyr_index = sympy_subs(
+                    index, {v: t.expr for v, t in self.range_tree_nodes.items()}
+                )
+                range_trees = self.active_range_trees()
+
+                # Partition the index into subexpressions pertaining to each range tree.
+                # For example xindex * 5 + r0_index * 3 is partitioned to
+                # (xindex * 5, r0_index * 3).
+                index_subexprs = [
+                    BlockPatternMatcher.get_subexpr_involving_symbol(
+                        index_relative_to_xyr_index, tree.symbol()
+                    )
+                    for tree in range_trees
+                ]
+
+                # Match each range tree's subexpression separately.
+                range_symbols = OrderedSet(tree.symbol() for tree in range_trees)
+                block_params = BlockParameters()
+                for tree, subexpr in zip(range_trees, index_subexprs):
+                    # Reject mixed terms, e.g. xindex * r0_index.
+                    # NB: the zero expression is allowed, for broadcasting.
+                    if len(range_symbols.intersection(subexpr.free_symbols)) > 1:
+                        return None
+
+                    # Match the subexpression for this range tree.
+                    params = match_block_subexpr(subexpr, tree)
+                    if params is None:
+                        return None
+                    block_params += params
+
+                # Collect leftover terms as a constant offset.
+                offset = index_relative_to_xyr_index - sum(index_subexprs)
+
+                # Form the block pointer or TMA descriptor.
+                self.filter_masks(mask_vars)
+
+                options_class = (
+                    BlockPtrOptions
+                    if config.triton.use_block_ptr
+                    else TensorDescriptorOptions
+                )
+                nonlocal tma_compatibility_checker
+                stride_sorter_cls: type[BlockParameters.StrideSorter]
+                if config.triton.use_block_ptr:
+                    can_lift = False
+                    stride_sorter_cls = BlockParameters.IdentityStrideSorter
+                else:
+                    tma_compatibility_checker = cast(
+                        TMACompatibilityChecker, tma_compatibility_checker
+                    )
+                    can_lift = tma_compatibility_checker.can_lift()
+
+                    if (
+                        self.transpose_discontiguous_tensor_descriptors_override
+                        is not None
+                    ):
+                        transpose_contiguous = (
+                            self.transpose_discontiguous_tensor_descriptors_override
+                        )
+                    else:
+                        transpose_contiguous = (
+                            config.triton.transpose_discontiguous_tensor_descriptor
+                        )
+
+                    # For templates:
+                    # Only try transpose if we know the output shape
+                    # in case we need to transpose the data.
+                    if hasattr(self, "template_out_shape"):
+                        transpose_contiguous &= copy_shape is not None
+
+                    stride_sorter_cls = (
+                        BlockParameters.TensorDecriptorStrideSorter
+                        if transpose_contiguous
+                        else BlockParameters.IdentityStrideSorter
+                    )
+
+                options = options_class.create(
+                    params=block_params,
+                    constant_offset=offset,
+                    range_trees=range_trees,
+                    mask_vars=mask_vars,
+                    get_max_block=self.max_block,
+                    can_lift=can_lift,
+                    stride_sorter_cls=stride_sorter_cls,
+                )
+                if options_class == TensorDescriptorOptions:
+                    tma_compatibility_checker = cast(
+                        TMACompatibilityChecker, tma_compatibility_checker
+                    )
+                    if not tma_compatibility_checker.are_block_parameters_compatible(
+                        options.params
+                    ):
+                        return None
+
+                return options
+
+            # Return a block pointer, if indexing matches the pattern.
+            options = match_block_expr()
+            if options is not None:
+                return options
+        expand_str = None
+        expand_shape: BlockShapeType = None
+        index_str = self.index_to_str(index)
+
+        def _get_expand_str():
+            if copy_shape:
+                if isinstance(copy_shape, str):
+                    return f"{copy_shape}.shape", None
+                else:
+                    return "[" + ", ".join(str(c) for c in copy_shape) + "]", copy_shape
+            else:
+                return self.dense_size_str(), tuple(self.dense_size_list())
+
+        if is_sympy_integer_like(index):
+            # Integer indexing produces a size-1 scalar tensor with the same shape
+            # as the dense dimension. E.g, if dense_size = [YBLOCK, XBLOCK, R0_BLOCK],
+            # then we create tl.full([1, 1, 1], int).
+            #
+            # Exceptions:
+            # 1. If copy_shape is explicitly provided, use copy_shape expansion instead.
+            # 2. If the dense tensor has only one dimension (e.g., [XBLOCK]),
+            #    broadcasting does not apply. For example:
+            #        tl.arange(0, XBLOCK) + tl.full([1], int)  # -> broadcasting error
+            #    In this case, we fall back to dense indexing:
+            #        tl.full([XBLOCK], int)
+            if copy_shape or len(self.dense_size_list()) == 1:
+                expand_str, expand_shape = _get_expand_str()
+            else:
+                expand_str = str([1] * len(self.dense_size_list()))
+                expand_shape = tuple([1] * len(self.dense_size_list()))
+
+            index_str = f"tl.full({expand_str}, {index_str}, tl.int32)"
+            if self.fixed_config and not self._has_constant_xmask():
+                mask_vars = OrderedSet(["xmask"])
+            else:
+                mask_vars = OrderedSet()
+            if self._load_mask:
+                mask_vars.add(self._load_mask)
+            return IndexingOptions(
+                index_str,
+                mask_vars,
+                expand_str,
+                has_rindex,
+                index,
+                expand_shape=expand_shape,
+            )
+
+        if need_dense and not have_dense:
+            if self.inside_reduction and self.is_native_matmul:
+                # This avoids full broadcasting (need_dense) when performing native matmul.
+                # For example, self._load_mask previously required tl.broadcast_to() in index_str.
+                # Due to the restrictions of tl.dot semantics, we only want to expand the block
+                # shape for the necessary axes.
+                #
+                # Previously:
+                #   tmp1 = tl.load(ptr + tl.broadcast_to(r0, [YBLOCK, XBLOCK, R0_BLOCK]),
+                #                  r0_mask & tmp0 & xmask)
+                #
+                # Now:
+                #   tmp1 = tl.load(ptr + tl.broadcast_to(r0, [1, 1, R0_BLOCK]),
+                #                  r0_mask & tmp0 & xmask)
+                #
+                # We achieve this by determining the required block shape through mask inspection.
+                # When a temporary variable appears in the mask (e.g., self._load_mask), we retrieve
+                # its true shape by inspecting tmp.mask_vars tracked by TritonCSEVariable.
+                #
+                # Caution: it may miss the correct block shape if the specific mask was constant
+                # and thus not tracked in TritonCSEVariable.mask_vars.
+                #
+                # TODO: Once the shape propagation PR lands, reimplement this logic:
+                #       https://github.com/pytorch/pytorch/pull/152198
+                mask_shape = mask_vars.copy()
+                if self._load_mask:
+                    mask_shape.add(self._load_mask)
+
+                xyzr = OrderedSet(["xmask", "ymask", "zmask", "r0_mask"])
+                while not mask_shape.issubset(xyzr):
+                    tmp_masks = mask_shape.difference(xyzr)
+                    tmp = tmp_masks.pop()
+                    assert isinstance(tmp, TritonCSEVariable)
+                    mask_shape.discard(tmp)
+                    mask_shape.update(tmp.mask_vars)
+
+                # e.g., expand_list becomes ['ZBLOCK', 1, 1, 'R0_BLOCK']
+                expand_list = ["1"] * len(self.dense_size_list())
+                for mask in mask_shape:
+                    assert isinstance(mask, str)
+                    for tree in self.active_range_trees():
+                        if mask.startswith(tree.prefix):
+                            dim = tree.tensor_dim
+                            assert isinstance(dim, int)
+                            expand_list[dim] = self.dense_size_list()[dim]
+
+                expand_str = "[" + ",".join(map(str, expand_list)) + "]"
+                expand_shape = tuple(expand_list)
+                index_str = f"tl.broadcast_to({index_str}, {expand_str})"
+            else:
+                expand_str, expand_shape = _get_expand_str()
+                index_str = f"tl.broadcast_to({index_str}, {expand_str})"
+                mask_vars = dense_mask_vars
+        elif not have_loop_vars and copy_shape:
+            expand_shape_str, expand_shape = _get_expand_str()
+            index_str = f"tl.broadcast_to({index_str}, {expand_shape_str})"
+            mask_vars = dense_mask_vars
+
+        if expand_shape is None:
+            if need_dense or have_dense:
+                _, expand_shape = _get_expand_str()
+            else:
+                expand_shape = ()
+
+        if override_mask:
+            mask_vars = OrderedSet([override_mask])
+
+        if self._load_mask:
+            mask_vars.add(self._load_mask)
+
+        self.filter_masks(mask_vars)
+
+        return IndexingOptions(
+            index_str,
+            mask_vars,
+            expand_str,
+            has_rindex,
+            index,
+            expand_shape=expand_shape,
+        )
+
+    def codegen_block_ptr(
+        self,
+        name: str,
+        var: str,
+        indexing: Union[BlockPtrOptions, TensorDescriptorOptions],
+        other="",
+    ) -> tuple[str, str]:
+        """Generate a block pointer or tensor descriptor for Triton kernel operations.
+
+        This method creates either a block pointer (for regular Triton operations) or
+        a tensor descriptor (for TMA operations) based on the indexing type. It handles
+        caching and reuse of descriptors for performance optimization.
+
+        Args:
+            name: The name of the buffer/tensor being accessed
+            var: The variable name for the pointer
+            indexing: Block pointer options or tensor descriptor options containing
+                     indexing information and boundary check settings
+            other: Additional parameters string (e.g., padding options)
+
+        Returns:
+            A tuple containing:
+            - block_descriptor: The generated block pointer or tensor descriptor variable name
+            - other: Modified additional parameters string with boundary check options
+        """
+        check = indexing.boundary_check()
+        if isinstance(indexing, TensorDescriptorOptions):
+            if check and other:
+                # The TMA API currently does not support padding values
+                # but the default is zero
+                assert other == ", other=0.0"
+                other = ""
+        else:
+            if not check:
+                # workaround https://github.com/triton-lang/triton/issues/2813
+                other = ""
+            elif other:
+                assert other == ", other=0.0"
+                other = f", boundary_check={check!r}, padding_option='zero'"
+            else:
+                other = f", boundary_check={check!r}"
+
+        if (
+            self.inside_reduction
+            and self.range_trees[-1].is_loop
+            and indexing.has_rindex()
+        ) or indexing.can_lift:
+            if indexing.can_lift and var in self.prologue_cache:
+                # Check for epilogue subtiling to reuse the same
+                # tensor descriptor.
+                block_descriptor = self.prologue_cache[var]
+            else:
+                block_ptr_line = indexing.format(var, roffset=False)
+                block_var = self.cse.try_get(block_ptr_line)
+
+                # Early return if block descriptor already exists
+                if block_var:
+                    return str(block_var), other
+
+                block_descriptor_id = next(self.block_ptr_id)
+                if isinstance(indexing, BlockPtrOptions):
+                    block_descriptor = f"block_ptr{block_descriptor_id}"
+                else:
+                    block_descriptor = f"tma_descriptor{block_descriptor_id}"
+                named_var = self.cse.namedvar(
+                    block_descriptor, dtype=torch.uint64, shape=[]
+                )
+                self.cse.put(block_ptr_line, named_var)
+
+                line_body = DeferredLine(name, f"{block_descriptor} = {block_ptr_line}")
+                if indexing.can_lift:
+                    self.prologue.writeline(line_body)
+                    # Cache the descriptor for epilogue subtiling
+                    self.prologue_cache[var] = block_descriptor
+                else:
+                    self.body.writeline(line_body)
+
+                if isinstance(indexing, BlockPtrOptions):
+                    # Store for later use. If the buffer is removed the below advancements
+                    # are no longer necessary
+                    self.block_ptr_to_buffer[block_descriptor] = name
+
+                    # Generate block pointer advancements, for later use.
+                    for symt in TritonSymbols.reduction_types:
+                        advance_offsets = indexing.advance_roffset(symt)
+
+                        # Ignore identity advancements.
+                        if all(
+                            V.graph.sizevars.statically_known_equals(
+                                offset, sympy.Integer(0)
+                            )
+                            for offset in advance_offsets
+                        ):
+                            continue
+
+                        advancements = self.pointer_advancements[symt]
+                        assert block_descriptor not in advancements, (
+                            f"duplicate advancement for pointer '{block_descriptor}' at type '{symt}'"
+                        )
+                        advancements[block_descriptor] = advance_offsets
+        else:
+            block_descriptor = indexing.format(var)
+        return block_descriptor, other
+
+    def codegen_block_ptr_store_line(self, name, indexing, block_ptr, value, other=""):
+        # Stores require an explicit broadcast. We do this in two phases:
+        #  1. Broadcast the operand to the final shape of the range trees, e.g. [ZBLOCK,
+        #     YBLOCK, XBLOCK]. This protects against implicit broadcasting from loads.
+        #  2. In case the block pointer / tma descriptor has different dimensionality, broadcast/reshape the
+        #     result to the shape of the pointer.
+        value = f"tl.broadcast_to({value}, {indexing.final_shape})"
+
+        # These dims no longer need broadcasting.
+        for idx, (dim, broadcast_dim) in enumerate(
+            zip(indexing.final_shape, indexing.broadcast_shape)
+        ):
+            if V.graph.sizevars.statically_known_equals(dim, broadcast_dim):
+                indexing.broadcasting_dims[idx] = False
+
+        value = indexing.codegen_broadcast_and_reshape(
+            value,
+            indexing.final_shape,
+            indexing.block_shape,
+            allow_implicit=False,
+            for_store=True,
+        )
+
+        # workaround https://github.com/triton-lang/triton/issues/2814
+        value = f"{value}.to({triton_store_type(V.graph.get_dtype(name))})"
+        if isinstance(indexing, BlockPtrOptions):
+            return f"tl.store({block_ptr}, {value}{other})"
+        return f"{block_ptr}.store({V.kernel.index_to_str(indexing.offsets)}, {value})"
+
+    def check_bounds(
+        self,
+        expr: sympy.Expr,
+        size: sympy.Expr,
+        lower: bool,
+        upper: bool,
+    ):
+        if not (lower or upper):
+            return
+
+        assert isinstance(expr, sympy.Expr)
+        indexing = self.indexing(expr, block_ptr=False, tma_compatibility_checker=None)
+        assert isinstance(indexing, IndexingOptions)
+
+        index_str = indexing.index_str
+        mask_str = indexing.mask_str if indexing.has_mask() else None
+        size_str = texpr(self.rename_indexing(size)) if upper else None
+
+        # expr is already wrapped
+        line = self.indirect_assert(
+            index_str, "0" if lower else None, size_str, mask_str
+        )
+
+        buffer = self.get_load_buffer(indexing)
+        self.cse.generate(buffer, line, assignment=False, dtype=torch.int32)
+
+    def get_load_buffer(self, indexing):
+        if indexing.has_indirect() or indexing.has_tmpmask():
+            # Masked loads must come after the mask is computed
+            return self.compute
+        elif (
+            self.inside_reduction
+            and self.range_trees[-1].is_loop
+            and not indexing.has_rindex()
+        ):
+            # can lift a common load outside of reduction loop
+            # One exception is when this is an indirect_load.
+            return self.body
+        else:
+            return self.loads
+
+    def _handle_pdl_before_load(self, wait_buffer):
+        GDC_WAIT = "tl.extra.cuda.gdc_wait()"
+        self._load_index += 1
+        if self.inside_reduction:
+            wait_buffer = self.body
+        if enable_pdl_codegen():
+            if self._load_index == 1:
+                wait_buffer.writeline(GDC_WAIT)
+
+    def _handle_pdl_after_load(self, launch_buffer, result_var):
+        GDC_LAUNCH = "tl.extra.cuda.gdc_launch_dependents()"
+        if self.inside_reduction:
+            launch_buffer = self.post_loop_combine
+        if enable_pdl_codegen():
+            current_load_index = self._load_index
+            launch_if_last_load = DelayMaybeLine(
+                lambda: current_load_index == self._load_index,
+                f"0; {GDC_LAUNCH} # gdc launch for {result_var}",
+            )
+            self.cse.generate(launch_buffer, launch_if_last_load, dtype=torch.int32)
+
+    def partial_accumulate(
+        self, name: str, reduction_type, val, extra_meta: dict[str, Any]
+    ):
+        self.saved_partial_accumulate.append(
+            PartialAccumulate(name, reduction_type, val)
+        )
+
+    def load(self, name: str, index: sympy.Expr):
+        """
+        Load from the memory location 'name', offset by some indexing expression 'index'.
+        """
+        var = self.args.input(name)
+        load_counts = self._load_counts
+        load_counts[name] += 1
+        make_line: Callable[[str], Union[str, DelayReplaceLine]] = identity
+        indirect_indexing = self.is_indirect_indexing(index)
+        original_index = index
+        dtype = V.graph.get_dtype(name)
+        indexing = self.indexing(
+            index,
+            block_ptr=True,
+            tma_compatibility_checker=self.tma_compatibility_checker_cls(
+                self,
+                dtype,
+                for_store=False,
+                force=False,
+            ),
+        )
+
+        if isinstance(indexing, IndexingOptions) and self._has_stride1_on_rdim(
+            indexing.index
+        ):
+            self.has_load_with_contiguous_rdim = True
+
+        has_rindex = indexing.has_rindex()
+        has_tmpmask = indexing.has_tmpmask()
+
+        # Keep the variable in cache if were going to reuse it. Equiv., if any of the following hold
+        #  1) We are doing broadcasting
+        #  2) It is a non-coalesced load. The intuition is that if it's
+        #  non-coalesced, we will likely load each element multiple times in
+        #  practice.
+        #  3) It will be used later and it won't be CSE'd. Equiv., if all the following hold
+        #   3.1) We are in a reduction loop
+        #   3.2) Its not its last use
+        #   3.3) This load will not be lifted to the body
+        #
+        is_coalesced = any(
+            i == 1 for i in self.get_strides_of_load(original_index).values()
+        )
+        if self.is_broadcasted(original_index):
+            ep = ", eviction_policy='evict_last'"
+        elif not is_coalesced:
+            ep = ", eviction_policy='evict_last'"
+        elif self.inside_reduction and self.range_trees[-1].is_loop:
+
+            def decide_later():
+                if load_counts[name] > expected_count and (
+                    has_rindex or indirect_indexing
+                ):
+                    return "evict_last"
+                return "evict_first"
+
+            expected_count = load_counts[name]
+            ep = ", eviction_policy=''"
+            make_line = functools.partial(DelayReplaceLine, "", decide_later)
+        else:
+            ep = ""
+
+        if (has_tmpmask or has_rindex) and indexing.has_mask():
+            if self._load_other:
+                other = f", other={constant_repr(self._load_other)}"
+            else:
+                other = ", other=0.0"
+        else:
+            other = ""
+
+        """Check if the buffer we're about to load, has
+        more than one read dependency
+        NOTE: enabled with env variable TORCHINDUCTOR_SKIP_L1
+        """
+        has_read_deps = True
+        if config.triton.skip_l1_cache:
+            buffer_read_counts = self.features.buffer_read_counts()
+            has_read_deps = buffer_read_counts[name] > 1
+        """Skip L1 cache if we're (pretty?) sure the data is used only once
+        """
+        skip_l1_cache = (
+            not self.is_broadcasted(original_index)
+            and not self.inside_reduction
+            and not has_read_deps
+            and is_coalesced  # for indirect loads is_coalesced is False?
+        )
+        cachemod = ""
+        if skip_l1_cache:
+            cachemod = ", cache_modifier='.cg'"
+
+        append_broadcast = None
+        shape: BlockShapeType = None
+
+        if should_unwrap_unspec_arg(name):
+            line = var
+            # unwrapped bf16/fp16 0d tensors are passed in as float32 scalars
+            # see triton_utils.py:signature_of
+            if dtype in (torch.float16, torch.bfloat16):
+                if config.triton.codegen_upcast_to_fp32:
+                    dtype = torch.float32
+                else:
+                    line += f".to({triton_type(dtype)})"
+            shape = ()
+
+        else:
+            if isinstance(indexing, (BlockPtrOptions, TensorDescriptorOptions)):
+                block_descriptor, other = self.codegen_block_ptr(
+                    name, var, indexing, other
+                )
+                if isinstance(indexing, BlockPtrOptions):
+                    line = f"tl.load({block_descriptor}{other}{ep}{cachemod})"
+                else:
+                    line = f"{block_descriptor}.load({V.kernel.index_to_str(indexing.offsets)})"
+                line = indexing.codegen_broadcast_and_reshape(
+                    line,
+                    indexing.block_shape,
+                    indexing.final_shape,
+                    allow_implicit=True,
+                    for_store=False,
+                )
+                shape = indexing.final_shape
+            elif is_sympy_integer_like(original_index):
+                line = f"tl.load({var} + ({original_index}))"
+                append_broadcast = indexing.expand_str
+                shape = ()
+            else:
+                line = f"tl.load({var} + ({indexing.index_str}), {indexing.mask_str}{ep}{other}{cachemod})"
+
+                # The block shape of tl.load depends on the indexing expression.
+                # Inferring shape solely from the mask may miss cases where the mask is constant.
+                # Inferring from indexing.expand_shape alone may also fail when dense indexing is absent.
+                # so, iterate over variables in the indexexpr to accurately infer the block shape.
+                if indexing.expand_shape:
+                    shape = indexing.expand_shape
+                else:
+                    shape = TritonSymbols.get_block_shape(indexing.index)
+
+            if (
+                dtype in (torch.float16, torch.bfloat16)
+                and config.triton.codegen_upcast_to_fp32
+            ):
+                line += ".to(tl.float32)"
+                dtype = torch.float32
+            if dtype == torch.bool and torch.version.hip is None:
+                # Workaround for https://github.com/triton-lang/triton/issues/2151
+                # tl.load returns int8 when loading from pointer to int1
+                # NOTE: Currently causes hangs on bool UTs for ROCm
+                line += ".to(tl.int1)"
+                dtype = torch.bool
+
+        load_buffer = self.get_load_buffer(indexing)
+        self._handle_pdl_before_load(load_buffer)
+        result_var = self.cse.generate(
+            load_buffer, make_line(line), dtype=dtype, shape=shape
+        )
+        self._handle_pdl_after_load(load_buffer, result_var)
+        if result_var.use_count > 1:
+            load_counts[name] -= 1  # don't double count cache hit
+        assert isinstance(result_var, TritonCSEVariable)
+        result_var.mask_vars = indexing.mask_vars  # type: ignore[assignment]
+
+        if append_broadcast:
+            line = f"tl.broadcast_to({result_var}, {append_broadcast})"
+            result_var = self.cse.generate(
+                load_buffer, line, dtype=dtype, shape=indexing.expand_shape
+            )
+            if indexing.mask_vars:
+                if dtype.is_floating_point:
+                    zero = "0.0"
+                elif dtype == torch.bool:
+                    zero = "True"
+                else:
+                    zero = "0"
+                other_val = (
+                    constant_repr(self._load_other) if self._load_other else zero
+                )
+                line = f"tl.where({indexing.mask_str}, {result_var}, {other_val})"
+                result_var = self.cse.generate(
+                    load_buffer, line, dtype=dtype, shape=result_var.shape
+                )
+
+        if not self.inside_reduction or (not indexing.has_rmask() and not has_rindex):
+            self.outside_loop_vars.add(result_var)
+
+        return result_var
+
+    def store(
+        self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
+    ) -> None:
+        """
+        store the 'value' to the memory location 'name', offset by some indexing expression 'index'.
+        """
+
+        var = self.args.output(name)
+        original_index = index
+        dtype = V.graph.get_dtype(name)
+
+        tma_compatibility_checker = None
+        if mode is None or mode == "tma":
+            force = mode == "tma"
+            tma_compatibility_checker = self.tma_compatibility_checker_cls(
+                self,
+                dtype,
+                for_store=True,
+                force=force,
+            )
+        indexing = self.indexing(
+            index,
+            dense_indexing=True,
+            block_ptr=mode is None,
+            tma_compatibility_checker=tma_compatibility_checker,
+        )
+
+        if isinstance(indexing, IndexingOptions) and self._has_stride1_on_rdim(
+            indexing.index
+        ):
+            self.stores_with_contiguous_rdim.append(name)
+
+        # Guard against write-after-read corruption in triton.
+        # See # https://github.com/triton-lang/triton/issues/1615
+        # This triton bug means that a load which is broadcasted over multiple
+        # warps may see the result of a store that happens later in the triton
+        # program. The workaround is to add a barrier before storing, which
+        # enforces that all warps have already read the data.
+        is_inplace = name in self.args.inplace_buffers
+        is_broadcasted = self.is_broadcasted(original_index)
+        if is_inplace and is_broadcasted:
+            self.stores.writeline(DeferredLine(name, "tl.debug_barrier()"))
+
+        if isinstance(indexing, (BlockPtrOptions, TensorDescriptorOptions)):
+            block_descriptor, other = self.codegen_block_ptr(name, var, indexing)
+            # block_ptr / tma descriptor stores don't do implicit casting
+            line = self.codegen_block_ptr_store_line(
+                name, indexing, block_descriptor, value, other
+            )
+        elif mode is None:
+            # If indexing is an integer and value has block shape larger than one,
+            # broadcasting fails. So, we manually broadcast indexing to the value shape.
+            # Without broadcast :
+            # tl.store(out_ptr0 + (tl.full([1, 1], 0, tl.int32)), tmp4, xmask) # Fail
+            #
+            # With broadcast:
+            # tl.store(out_ptr0 + (tl.full([1, 1], 0, tl.int32).broadcast_to((XBLOCK,1)), tmp4, xmask)
+            indexing_str = indexing.index_str
+            if (
+                is_sympy_integer_like(index)
+                and value.shape is not None
+                and not all(str(x) == "1" for x in value.shape)
+            ):
+                value_shape = ", ".join(map(str, value.shape))
+                indexing_str += f".broadcast_to({value_shape})"
+            line = f"tl.store({var} + ({indexing_str}), {value}, {indexing.mask_str})"
+        elif mode == "atomic_add":
+            self.atomic_add_found = True
+            indexing_str = indexing.index_str
+            if (
+                is_sympy_integer_like(index)
+                and value.shape is not None
+                and not all(str(x) == "1" for x in value.shape)
+            ):
+                value_shape = ", ".join(map(str, value.shape))
+                indexing_str += f".broadcast_to({value_shape})"
+            line = f"tl.atomic_add({var} + ({indexing_str}), {value}, {indexing.mask_str}, sem='relaxed')"
+        else:
+            raise NotImplementedError(f"store mode={mode}")
+
+        exit_stack = contextlib.ExitStack()
+        if not self.inside_reduction and self.cooperative_reduction:
+            exit_stack.enter_context(self.guard_cooperative_store(name, self.stores))
+
+        self.stores.writeline(DeferredLine(name, line))
+
+        if not self.inside_reduction:
+            self.outside_loop_vars.add(value)
+
+        exit_stack.close()
+
+    def device_assert_async(self, cond, msg) -> None:
+        self.compute.writeline(f"tl.device_assert({cond}, {repr(msg)})")
+
+    def guard_cooperative_store(self, name, buffer):
+        """
+        For cooperative reductions only one thread block should write out the result.
+        We rotate which thread block does each write for better parallelism
+        """
+        idx = self.cooperative_reduction_workspace_cache.increment_store_count()
+        buffer.writeline(DeferredLine(name, f"if rsplit_id == ({idx} % RSPLIT):"))
+        return buffer.indent()
+
+    def _combine_masks(self, *variables: Optional[CSEVariable]):
+        masks = None
+        for elem in variables:
+            if elem is None:
+                continue
+            if hasattr(elem, "mask_vars"):
+                if masks is None:
+                    masks = elem.mask_vars
+                else:
+                    masks = masks | elem.mask_vars
+        return masks
+
+    def bucketize(
+        self,
+        values: CSEVariable,
+        boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr],
+        boundary_indices: CSEVariable,
+        indexing_dtype: torch.dtype,
+        right: bool,
+        sorter: Optional[tuple[str, sympy.Expr]] = None,
+        sorter_indices: Optional[CSEVariable] = None,
+    ) -> CSEVariable:
+        """
+        See [Note: Inductor bucketize op]
+        """
+
+        # Triton performance for bucketize_binary_search is much better when the number
+        # of threads equals the number of elements.
+        # If we're trying to use a bucketize kernel, we should make sure that an
+        # autotuning config with num_elements_per_warp=(warp_size) exists.
+        self.autotune_hints.add(AutotuneHint.ONE_ELEMENT_PER_THREAD)
+
+        boundaries_ptr = self.args.input(boundaries[0])
+        boundary_size = self.index_to_str(boundaries[1])
+        boundaries_underlying_numel = self.index_to_str(boundaries[2])
+        boundary_stride = self.index_to_str(boundaries[3])
+        sorter_ptr = self.args.input(sorter[0]) if sorter else "None"
+        sorter_stride = self.index_to_str(sorter[1]) if sorter else "None"
+
+        if indexing_dtype == torch.int32:
+            triton_dtype = "tl.int32"
+        elif indexing_dtype == torch.int64:
+            triton_dtype = "tl.int64"
+        else:
+            raise NotImplementedError(
+                "Bucketize only supports indexing with int32 and int64"
+            )
+
+        self._handle_pdl_before_load(self.compute)
+        result = self.cse.generate(
+            self.compute,
+            f"triton_helpers.bucketize_binary_search({values}, "
+            f"{boundaries_ptr}, {boundary_size}, {boundaries_underlying_numel}, {boundary_stride}, "
+            f"{boundary_indices}, "
+            f"{triton_dtype}, "
+            f"{right}, "
+            f"{sorter_ptr}, {sorter_stride}, "
+            f"{sorter_indices}, "
+            ")",
+            dtype=indexing_dtype,  # type: ignore[attr-defined]
+            shape=values.shape,
+        )
+        self._handle_pdl_after_load(self.compute, result)
+
+        masks = self._combine_masks(values, boundary_indices, sorter_indices)
+        result.mask_vars = masks  # type: ignore[attr-defined]
+
+        return result
+
+    def reduction_resize(self, value) -> str:
+        ndims = self.triton_tensor_ndim()
+        if ndims == 1:
+            return f"triton_helpers.promote_to_tensor({value})"
+
+        nreduce = self.num_reduction_dims
+        sizes = [":"] * (ndims - nreduce) + ["None"] * nreduce
+        return f"{value}[{', '.join(sizes)}]"
+
+    def reduction_resize_and_shape(self, value, shape) -> tuple[str, BlockShapeType]:
+        ndims = self.triton_tensor_ndim()
+        if ndims == 1:
+            return f"triton_helpers.promote_to_tensor({value})", shape
+
+        nreduce = self.num_reduction_dims
+        sizes = [":"] * (ndims - nreduce) + ["None"] * nreduce
+        new_shape = (
+            (*shape[: (ndims - nreduce)], *[1] * nreduce) if shape is not None else None
+        )
+        return f"{value}[{', '.join(sizes)}]", new_shape
+
+    def reduction_collapse_dims(
+        self, buffer, value: CSEVariable, dtype: torch.dtype
+    ) -> CSEVariable:
+        """
+        Reshape to RBLOCK, collapsing all reduction dims.
+        """
+        # This is not needed for 1D reductions.
+        if self.num_reduction_dims == 1:
+            return value
+
+        target_ndim = self.triton_tensor_ndim() - self.num_reduction_dims
+        initial_shape = self.dense_size_list()
+        target_shape = initial_shape[:target_ndim] + ["RBLOCK"]
+        return self.cse.generate(
+            buffer,
+            triton_reshape(str(value), initial_shape, target_shape),
+            dtype=dtype,
+            shape=tuple(target_shape),
+        )
+
+    def reduction(
+        self,
+        dtype: torch.dtype,
+        src_dtype: torch.dtype,
+        reduction_type: ReductionType,
+        value: Union[CSEVariable, tuple[CSEVariable, ...]],
+    ) -> Union[CSEVariable, tuple[CSEVariable, ...]]:
+        """
+        codegen reduction of value to Triton according the reduction_type
+        """
+
+        def maybe_upcast(value: CSEVariable) -> CSEVariable:
+            # Math reductions in FP16/BF16 are less accurate because the Triton compiler does not
+            # automatically promote to FP32 for accumulation. Additionally, max/min reductions
+            # do not support FP16/BF16. We manually promote to FP32 here.
+            return (
+                ops.to_dtype(value, torch.float32)
+                if value.dtype
+                in [
+                    torch.float16,
+                    torch.bfloat16,
+                ]
+                else value
+            )
+
+        original_dtypes = [val.dtype for val in pytree.tree_leaves(value)]
+        value = pytree.tree_map(maybe_upcast, value)
+        if any(x in [torch.float16, torch.bfloat16] for x in original_dtypes):
+            # Only promote FB16/BF16; do not promote other integer/boolean dtypes
+            src_dtype = torch.promote_types(src_dtype, torch.float32)
+            dtype = torch.promote_types(dtype, torch.float32)
+
+        assert self.inside_reduction
+        masks = OrderedSet(f"{tree.prefix}mask" for tree in self.range_trees)
+        self.filter_masks(masks)
+        masks = sorted(masks)
+        if self._load_mask:
+            masks.append(self._load_mask)
+        reduction_range_prefix = self.range_trees[-1].prefix[0]
+
+        # When we do native matmtul codegen,
+        # we don't want to keep the R0_BLOCK/R1_BLOCK in the accumulator.
+        # so instead of naively calling dense_size_str(), we filter out
+        # reduction block from accumulator and only keep (Y,X).
+        # In bmm (Z,Y,R)x(Z,R,X) case, we also remove z dimension from accumulator
+        # because 3d (Z,Y,X) tl.dot is somehow slower than 2d tl.dot.
+        # Instead, we force ZBLOCK to be always 1 during autotune.
+        dense_size_str: str
+        if self.is_native_matmul:
+            dense_sizes = self.dense_size_list()
+            assert len(dense_sizes) >= 3
+            xy_sizes_only = [size for size in dense_sizes if "X" in size or "Y" in size]
+            dense_size_str = f"[{', '.join(xy_sizes_only)}]"
+            value_shape = tuple(xy_sizes_only)
+        else:
+            dense_size_str = self.dense_size_str()
+            value_shape = tuple(self.dense_size_list())
+
+        # Say we have
+        #     tmp0 = ops.constant(1, torch.int64)
+        #     tmp1 = ops.reduction(torch.int64, torch.int64, "sum", tmp0)
+        # tmp0 in the triton code is either a scalar, or single-element tensor
+        # so if we emit tl.sum directly, it will only give 1 instead of RBLOCK * 1
+        # To avoid this, we broadcast to the expected shape first.
+        value = self._map_tuple_or_scalar(
+            lambda v: self.cse.generate(
+                self.compute,
+                f"tl.broadcast_to({v}, {dense_size_str})",
+                dtype=v.dtype,
+                shape=value_shape,
+            ),
+            value,
+        )
+
+        logical_index = None
+        if reduction_type in ("argmin", "argmax"):
+            if isinstance(value, tuple):
+                value, logical_index = value
+
+        dim = self.triton_tensor_ndim() - self.num_reduction_dims
+        root_op: str
+
+        def final_reduction(
+            buffer,
+            value: CSEVariable,
+            result_type: Optional[torch.dtype],
+        ) -> tuple[str, Optional[torch.dtype], BlockShapeType]:
+            """
+            Helper to generate a reduction call, e.g. tl.sum.
+            """
+            triton_reduction_fn = get_triton_reduction_function(reduction_type)
+
+            value = self.reduction_collapse_dims(buffer, value, dtype)
+            if reduction_type == "dot":
+                # Native matmul is a special case because accumulator shape is fixed to (Y,X)
+                is_bmm = len(self.dense_size_list()) == 4
+                assert value.shape is not None
+                if is_bmm:
+                    result = f"{value}[None,:,:,None]"  # (Y,X) to (Z=1,Y,X,R=1)
+                    shape = [1, *value.shape, 1]
+                else:
+                    result = f"{value}[:,:,None]"  # (Y,X) to (Y,X,R=1)
+                    shape = [*value.shape, 1]
+            else:
+                result, shape = self.reduction_resize_and_shape(  # type: ignore[assignment]
+                    f"{triton_reduction_fn}({value}, {dim})", value.shape
+                )
+
+            if result_type is not None:
+                result = f"{result}.to({self.dtype_to_str(result_type)})"
+            else:
+                result_type = value.dtype
+
+            return result, result_type, shape
+
+        def final_reduction_define(
+            buffer,
+            result_var: CSEVariable,
+            value: CSEVariable,
+            result_type: Optional[torch.dtype],
+        ) -> None:
+            """
+            Generate a reduction and assign it to an existing variable.
+            """
+            value, _, _ = final_reduction(buffer, value, result_type)
+            buffer.splice(f"{result_var} = {value}")
+
+        def final_argreduce(buffer, result_var, value, index):
+            value = self.reduction_collapse_dims(buffer, value, dtype)
+            index = self.reduction_collapse_dims(buffer, index, dtype)
+            buffer.splice(
+                f"""\
+                {result_var}_val, {result_var}_idx = triton_helpers.{root_op}_with_index({value}, {index}, {dim})
+                {result_var} = {self.reduction_resize(f"{result_var}_idx")}
+                """
+            )
+
+        cache_key = (src_dtype, reduction_type, value)
+        if cache_key in self.cse.reduction_cache:
+            return self.cse.reduction_cache[cache_key]
+
+        acc_type = triton_acc_type(src_dtype)
+        torch_acc_type = upcast_acc_dtype(src_dtype)
+        result_shape = list(self.dense_size_list())
+        result_shape[dim] = "1"
+        result_var: Any = self.cse.newvar(
+            dtype=torch_acc_type, shape=tuple(result_shape)
+        )
+        result_var.mask_vars = OrderedSet(
+            var for var in masks if not prefix_is_reduction(var[0])
+        )
+        cond = " & ".join(masks)
+
+        def where_cond(tval, fval):
+            if not cond:
+                return tval
+            return TritonKernelOverrides.where(cond, tval, fval)
+
+        if self.persistent_reduction:
+            default = ir.Reduction.default_value(reduction_type, src_dtype)
+
+            def update_constant_dtype(constant, src_dtype, dst_dtype):
+                "update reduction constant mask value to match dst_dtype"
+
+                # int is the only mask which may not fit within lower bitwidth,
+                # because float uses inf/-inf
+                if src_dtype.is_floating_point or src_dtype == torch.bool:
+                    return constant
+
+                if src_dtype == dst_dtype or constant == 0:
+                    return constant
+
+                if constant == torch.iinfo(src_dtype).max:
+                    return torch.iinfo(dst_dtype).max
+                elif constant == torch.iinfo(src_dtype).min:
+                    return torch.iinfo(dst_dtype).min
+                else:
+                    return constant
+
+            def _mask_value(value, default) -> CSEVariable:
+                default = update_constant_dtype(default, src_dtype, value.dtype)
+                default_str = self._map_tuple_or_scalar(constant_repr, default)
+
+                return self.cse.generate(
+                    self.compute,
+                    where_cond(value, default_str),
+                    dtype=value.dtype,
+                    shape=value.shape,
+                )
+
+            masked_value: Union[CSEVariable, Sequence[CSEVariable]]
+            if reduction_type == "online_softmax_reduce":
+                # Don't generate mask value for online_softmax since we
+                # will fallback below
+                pass
+            elif isinstance(value, tuple):
+                masked_value = [_mask_value(v, d) for v, d in zip(value, default)]  # type: ignore[arg-type]
+            elif reduction_type == "dot":
+                # Here, we don't perform the masking.
+                # Masking w/ where condition in native matmul is handled in ops.dot codegen.
+                # Since tl.dot performs reduction within the triton block,
+                # masking should happen before the tl.dot is called.
+                masked_value = self.cse.generate(self.compute, value, dtype=value.dtype)
+            else:
+                masked_value = _mask_value(value, default)
+
+            if reduction_type in ("argmax", "argmin"):
+                assert isinstance(masked_value, CSEVariable)
+                accumulator_dtype = V.kernel.get_index_dtype_as_torch_dtype()
+                if logical_index:
+                    accumulator_index = f"({str(logical_index)}).to({self.dtype_to_str(accumulator_dtype)})"
+                else:
+                    accumulator_index = str(
+                        self.cse.generate(
+                            self.compute,
+                            f"tl.broadcast_to({reduction_range_prefix}index, {masked_value}.shape)",
+                            dtype=accumulator_dtype,
+                            shape=masked_value.shape,
+                        )
+                    )
+                root_op = {"argmax": "max", "argmin": "min"}[reduction_type]
+                final_argreduce(
+                    self.compute, result_var, masked_value, accumulator_index
+                )
+                result_var.dtype = accumulator_dtype
+            elif reduction_type == "welford_reduce":
+                if self.cooperative_reduction:
+                    # cooperative reductions require full welford for correctness
+                    result_var = self.welford_reduce(
+                        result_var, reduction_type, value, where_cond, acc_type, dtype
+                    )
+                else:
+                    # For persistent reductions, don't bother with
+                    # welford's algorithm since it uses more registers, and
+                    # taking two reductions doesn't increase memory usage.
+                    result_var = self.welford_reduce_fallback(dtype, value)
+            elif reduction_type == "welford_combine":
+                assert isinstance(masked_value, Sequence)
+                (mean, m2, weight) = masked_value
+                result_var = tuple(
+                    self.cse.generate(self.compute, value, dtype=dtype, shape=shape)
+                    for value, shape in self._welford(
+                        self.compute, mean, m2, weight, dim, dtype
+                    )
+                )
+            elif reduction_type == "online_softmax_reduce":
+                # All data is loaded to register anyway, no need to do
+                # online softmax
+                result_var = self.prepare_softmax_twopass_fallback(dtype, value)
+            else:
+                assert isinstance(masked_value, CSEVariable)
+                _result, _dtype, _shape = final_reduction(
+                    self.compute, masked_value, masked_value.dtype
+                )
+                result_var = self.cse.generate(
+                    self.compute, _result, dtype=_dtype, shape=_shape
+                )
+        else:
+            accumulator = self.cse.namedvar(
+                f"_{result_var}",
+                dtype=torch_acc_type,
+                shape=tuple(self.dense_size_list()),
+            )
+            default = ir.Reduction.default_accumulator(reduction_type, src_dtype)
+            default = self._map_tuple_or_scalar(constant_repr, default)
+            if not isinstance(default, tuple):
+                if reduction_type == "dot":
+                    dense_sizes = self.dense_size_list()
+                    assert len(dense_sizes) >= 3
+                    xy_sizes_only = [
+                        size for size in dense_sizes if "X" in size or "Y" in size
+                    ]
+                    accumulator.shape = tuple(xy_sizes_only)
+                    dense_size_str = f"[{', '.join(xy_sizes_only)}]"
+                    self.body.writeline(
+                        f"{accumulator} = tl.full({dense_size_str}, {default}, {acc_type})"
+                    )
+                else:
+                    self.body.writeline(
+                        f"{accumulator} = tl.full({self.dense_size_str()}, {default}, {acc_type})"
+                    )
+
+            if reduction_type in ("argmax", "argmin"):
+                accumulator_index = f"_{result_var}_index"
+                index_dtype = self.features.select_index_dtype()
+                self.body.writeline(
+                    f"{accumulator_index} = tl.full({self.dense_size_str()}, "
+                    f"{torch.iinfo(index_dtype).max}, {self.dtype_to_str(index_dtype)})"
+                )
+                root_op = {"argmax": "max", "argmin": "min"}[reduction_type]
+                # Use logical_index if it was unpacked, otherwise fall back to physical index
+                index_var = (
+                    f"({str(logical_index)}).to({self.dtype_to_str(index_dtype)})"
+                    if logical_index is not None
+                    else f"{reduction_range_prefix}index"
+                )
+                self.compute.splice(
+                    f"""\
+                {accumulator}_next, {accumulator_index}_next = triton_helpers.{root_op}imum_with_index(
+                    {accumulator}, {accumulator_index}, {value}, {index_var}
+                )
+                {accumulator} = {where_cond(f"{accumulator}_next", accumulator)}
+                {accumulator_index} = {where_cond(f"{accumulator_index}_next", accumulator_index)}
+                """
+                )
+                final_argreduce(
+                    self.post_loop_combine, result_var, accumulator, accumulator_index
+                )
+            elif is_welford_reduction(reduction_type):
+                result_var = self.welford_reduce(
+                    result_var, reduction_type, value, where_cond, acc_type, dtype
+                )
+            elif reduction_type == "online_softmax_reduce":
+                accumulator_max = f"_{result_var}_max"
+                accumulator_sum = f"_{result_var}_sum"
+
+                # setup accumulator
+                self.body.writeline(
+                    f"{accumulator_max} = tl.full({self.dense_size_str()}, float('-inf'), {acc_type})"
+                )
+                self.body.writeline(
+                    f"{accumulator_sum} = tl.zeros({self.dense_size_str()}, {acc_type})"
+                )
+
+                # combine
+                # Note, we pass config.use_fast_math to the JITFunction
+                # since a triton kernel can not access a config.
+                self.compute.splice(
+                    f"""
+                    {accumulator_max}_next, {accumulator_sum}_next = triton_helpers.online_softmax_combine(
+                        {accumulator_max}, {accumulator_sum}, {value}, {config.use_fast_math}
+                    )
+                    """
+                )
+
+                # mask
+                self.compute.splice(
+                    f"""
+                    {accumulator_max} = {where_cond(f"{accumulator_max}_next", accumulator_max)}
+                    {accumulator_sum} = {where_cond(f"{accumulator_sum}_next", accumulator_sum)}
+                    """
+                )
+
+                # reduce. Similar to the final reduction for coopereative
+                # reduction
+                result_max = result_var
+                result_sum = self.cse.newvar(dtype=dtype, shape=result_max.shape)
+
+                result_var = self.online_softmax_reduce_final_reduction(
+                    self.post_loop_combine,
+                    result_max,
+                    result_sum,
+                    accumulator_max,
+                    accumulator_sum,
+                    dim,
+                    dtype,
+                )
+            else:
+                combine_fn = ir.get_reduction_combine_fn(reduction_type, src_dtype)
+                updated = combine_fn(accumulator, value)
+                if reduction_type == "dot":
+                    self.compute.writeline(f"{accumulator} = {updated}")
+                else:
+                    self.compute.writeline(
+                        f"{accumulator} = {where_cond(updated, accumulator)}"
+                    )
+
+                if src_dtype == torch.bool:
+                    # This is only really used for aten.any. It changes the
+                    # final reduction of a non-persistent reduction from
+                    #     tmp5 = triton_helpers.max(_tmp5, 1)[:, None]
+                    # to
+                    #     tmp5 = triton_helpers.max(_tmp5.to(tl.int8), 1)[:, None].to(tl.int1)
+                    # which is needed because tl.reduce doesn't support tl.int1
+                    accumulator = self.cse.generate(
+                        self.post_loop_combine,
+                        f"{accumulator}.to(tl.int8)",
+                        dtype=torch.int8,
+                        shape=accumulator.shape,
+                    )
+
+                final_reduction_define(
+                    self.post_loop_combine, result_var, accumulator, None
+                )
+
+        if self.cooperative_reduction:
+            default = ir.Reduction.default_accumulator(reduction_type, src_dtype)
+            exit_stack = contextlib.ExitStack()
+            for buf in (self.post_loop_combine, self.post_loop_store):
+                # only do cooperative reduction combines if we have more than one thread block
+                buf.writeline("if HAS_RSPLIT:")
+                exit_stack.enter_context(buf.indent())
+
+            if reduction_type in ("argmax", "argmin"):
+                self.post_loop_combine.writeline(
+                    f"{result_var}_bval = {self.reduction_resize(f'{result_var}_val')}"
+                )
+                peer_val = self.codegen_cooperative_reduction_peer_combine(
+                    f"{result_var}_bval", src_dtype, default
+                )
+                index_dtype = self.features.select_index_dtype()
+                peer_idx = self.codegen_cooperative_reduction_peer_combine(
+                    result_var, index_dtype, torch.iinfo(index_dtype).max
+                )
+                final_argreduce(self.post_loop_store, result_var, peer_val, peer_idx)
+            elif is_welford_reduction(reduction_type):
+                assert reduction_type == "welford_reduce"
+                result_mean, result_m2, result_weight = result_var
+                peer_mean = self.codegen_cooperative_reduction_peer_combine(
+                    result_mean,
+                    upcast_acc_dtype(src_dtype),
+                    default[0],  # type: ignore[index]
+                )
+                peer_m2 = self.codegen_cooperative_reduction_peer_combine(
+                    result_m2,
+                    upcast_acc_dtype(src_dtype),
+                    default[1],  # type: ignore[index]
+                )
+                peer_weight = self.codegen_cooperative_reduction_peer_combine(
+                    result_weight,
+                    upcast_acc_dtype(src_dtype),
+                    default[2],  # type: ignore[index]
+                )
+                self.welford_reduce_final_reduction(
+                    self.post_loop_store,
+                    result_mean,
+                    result_m2,
+                    result_weight,
+                    peer_mean,
+                    peer_m2,
+                    peer_weight,
+                    dim,
+                    dtype,
+                )
+            elif reduction_type == "online_softmax_reduce":
+                result_max, result_sum = result_var
+                assert isinstance(default, Sequence)
+                peer_max = self.codegen_cooperative_reduction_peer_combine(
+                    result_max, upcast_acc_dtype(src_dtype), default[0]
+                )
+                peer_sum = self.codegen_cooperative_reduction_peer_combine(
+                    result_sum, upcast_acc_dtype(src_dtype), default[1]
+                )
+                self.online_softmax_reduce_final_reduction(
+                    self.post_loop_store,
+                    result_max,
+                    result_sum,
+                    peer_max,
+                    peer_sum,
+                    dim,
+                    dtype,
+                )
+            else:
+                peers = self.codegen_cooperative_reduction_peer_combine(
+                    result_var, upcast_acc_dtype(src_dtype), default
+                )
+                final_reduction_define(self.post_loop_store, result_var, peers, None)
+            exit_stack.close()
+
+        self.cse.reduction_cache[cache_key] = result_var
+
+        if isinstance(result_var, tuple):
+            assert all(isinstance(x, TritonCSEVariable) for x in result_var)
+            self.outside_loop_vars.update(result_var)
+
+            # Match output dtype with input dtype
+            if reduction_type in ("welford_reduce", "online_softmax_reduce"):
+                assert len(original_dtypes) == 1
+                original_dtypes = len(result_var) * original_dtypes
+
+            assert len(result_var) == len(original_dtypes)
+            for var, orig_dtype in zip(result_var, original_dtypes):
+                assert orig_dtype is not None
+                if var.dtype != orig_dtype:
+                    self.post_loop_combine.writeline(
+                        f"{var} = {var}.to({triton_compute_type(orig_dtype)})"
+                    )
+        else:
+            assert isinstance(result_var, TritonCSEVariable)
+            self.outside_loop_vars.add(result_var)
+
+            # Match output dtype with input dtype
+            if result_var.dtype != original_dtypes[0]:
+                assert original_dtypes[0] is not None
+                self.post_loop_combine.writeline(
+                    f"{result_var} = {result_var}.to({triton_compute_type(original_dtypes[0])})"
+                )
+
+        return result_var
+
+    def _online_softmax_reduce(
+        self, buffer, accumulator_max, accumulator_sum, dim, dtype: torch.dtype
+    ):
+        accumulator_max = self.reduction_collapse_dims(buffer, accumulator_max, dtype)
+        accumulator_sum = self.reduction_collapse_dims(buffer, accumulator_sum, dtype)
+        result_max, result_sum = [str(self.cse.newvar(dtype=dtype)) for _ in range(2)]
+        buffer.splice(
+            f"""
+            {result_max}, {result_sum} = triton_helpers.online_softmax_reduce(
+                {accumulator_max}, {accumulator_sum}, {dim}, {config.use_fast_math})
+            {result_max} = {self.reduction_resize(f"{result_max}")}
+            {result_sum} = {self.reduction_resize(f"{result_sum}")}
+            """
+        )
+
+        return result_max, result_sum
+
+    def _welford(self, buffer, mean, m2, weight, dim, dtype: torch.dtype):
+        """
+        Helper to codegen triton_helpers.welford.
+        """
+        mean, m2, weight = (
+            self.reduction_collapse_dims(buffer, value, dtype)
+            for value in (mean, m2, weight)
+        )
+        welford = f"triton_helpers.welford({mean}, {m2}, {weight}, {dim})"
+
+        def reduced_shape(shape):
+            return tuple(shape[0:dim] + shape[dim + 1 :])
+
+        welford_results = [
+            self.cse.newvar(dtype=dtype, shape=reduced_shape(value.shape))
+            for value in (mean, m2, weight)
+        ]
+        buffer.writeline(f"{', '.join([str(r) for r in welford_results])} = {welford}")
+
+        return tuple(
+            self.reduction_resize_and_shape(value, value.shape)
+            for value in welford_results
+        )
+
+    def welford_reduce(
+        self, result_var, reduction_type, value, where_cond, acc_type, dtype
+    ):
+        """Helper to codegen a welford reduction"""
+        dim = self.triton_tensor_ndim() - self.num_reduction_dims
+
+        accumulator = TritonCSEVariable(
+            f"{result_var}_mean",
+            shape=tuple(self.dense_size_list()),
+            dtype=acc_type,
+            bounds=ValueRanges.unknown(),
+        )
+        accumulator_m2 = TritonCSEVariable(
+            f"{result_var}_m2",
+            shape=tuple(self.dense_size_list()),
+            dtype=acc_type,
+            bounds=ValueRanges.unknown(),
+        )
+        accumulator_weight = TritonCSEVariable(
+            f"{result_var}_weight",
+            shape=tuple(self.dense_size_list()),
+            dtype=acc_type,
+            bounds=ValueRanges.unknown(),
+        )
+        self.body.writeline(
+            f"{accumulator} = tl.zeros({self.dense_size_str()}, {acc_type})"
+        )
+        self.body.writeline(
+            f"{accumulator_m2} = tl.zeros({self.dense_size_str()}, {acc_type})"
+        )
+        self.body.writeline(
+            f"{accumulator_weight} = tl.zeros({self.dense_size_str()}, {acc_type})"
+        )
+        if reduction_type == "welford_combine":
+            mean, m2, weight = value
+            self.compute.splice(
+                f"""\
+                {accumulator}_next, {accumulator_m2}_next, {accumulator_weight}_next = triton_helpers.welford_combine(
+                    {accumulator}, {accumulator_m2}, {accumulator_weight},
+                    {mean}, {m2}, {weight}
+                )
+                """
+            )
+        else:
+            assert reduction_type == "welford_reduce"
+            self.compute.splice(
+                f"""\
+                {accumulator}_next, {accumulator_m2}_next, {accumulator_weight}_next = triton_helpers.welford_reduce(
+                    {value}, {accumulator}, {accumulator_m2}, {accumulator_weight}, roffset == 0
+                )
+                """
+            )
+        self.compute.splice(
+            f"""\
+            {accumulator} = {where_cond(f"{accumulator}_next", accumulator)}
+            {accumulator_m2} = {where_cond(f"{accumulator_m2}_next", accumulator_m2)}
+            {accumulator_weight} = {where_cond(f"{accumulator_weight}_next", accumulator_weight)}
+            """
+        )
+        result_mean = result_var
+        return self.welford_reduce_final_reduction(
+            self.post_loop_combine,
+            result_mean,
+            None,
+            None,
+            accumulator,
+            accumulator_m2,
+            accumulator_weight,
+            dim,
+            dtype,
+        )
+
+    def welford_reduce_final_reduction(
+        self,
+        buffer,
+        result_mean,
+        result_m2,
+        result_weight,
+        mean,
+        m2,
+        weight,
+        dim,
+        dtype,
+    ):
+        """Helper to codegen call to triton_helpers.welford"""
+        values = list(self._welford(buffer, mean, m2, weight, dim, dtype))
+
+        result_exprs = [result_mean, result_m2, result_weight]
+        for i, (result_expr, (value, shape)) in enumerate(zip(result_exprs, values)):
+            if result_expr is None:
+                result_expr = self.cse.newvar(dtype=dtype, shape=shape)
+                result_exprs[i] = result_expr
+            buffer.splice(f"{result_expr} = {value}")
+
+        return tuple(result_exprs)
+
+    def online_softmax_reduce_final_reduction(
+        self, buffer, result_max, result_sum, peer_max, peer_sum, dim, dtype
+    ):
+        accumulator_max = self.reduction_collapse_dims(buffer, peer_max, dtype)
+        accumulator_sum = self.reduction_collapse_dims(buffer, peer_sum, dtype)
+        buffer.splice(
+            f"""
+            {result_max}, {result_sum} = triton_helpers.online_softmax_reduce(
+                {accumulator_max}, {accumulator_sum}, {dim}, {config.use_fast_math})
+            {result_max} = {self.reduction_resize(f"{result_max}")}
+            {result_sum} = {self.reduction_resize(f"{result_sum}")}
+            """
+        )
+        return result_max, result_sum
+
+    def max_rsplit(self):
+        if self.fixed_config:
+            return self.fixed_config["RSPLIT"]
+        return TRITON_MAX_RSPLIT
+
+    def codegen_cooperative_reduction_peer_combine(
+        self, result_var, dtype, default_val
+    ) -> CSEVariable:
+        """
+        Generate code to save a [XBLOCK, RSPLIT] temporary workspace, where each thread block writes a different
+        column.  After the barrier, every thread block loads the completed value so that it can compute the final
+        value independently.
+        """
+        xnumel = self.numels["x"]
+        mask = "xindex < xnumel" if not self._has_constant_xmask() else None
+
+        nbytes = xnumel * dtype.itemsize * self.max_rsplit()
+        ws_name, ws_offset = self.cooperative_reduction_workspace_cache.allocate(nbytes)
+
+        self.post_loop_combine.splice(
+            f"""
+                {result_var}_ws = ({ws_name} + {self.index_to_str(ws_offset)}).to(tl.pointer_type({triton_type(dtype)}))
+                tl.store({result_var}_ws + (xindex * RSPLIT + rsplit_id), {result_var}, {mask})
+            """,
+            strip=True,
+        )
+        peers = self.create_cse_var(
+            f"{result_var}_peers",
+            shape=["XBLOCK", "RSPLIT"],
+            dtype=dtype,
+            bounds=ValueRanges.unknown(),
+        )
+        self.post_loop_store.writeline(
+            f"{peers} = tl.load({result_var}_ws + (xindex * RSPLIT + rsplit_arange), "
+            f"rsplit_mask, eviction_policy='evict_first', other=triton_helpers.if_mask(rsplit_mask, {constant_repr(default_val)}))"
+        )
+        return peers
+
+    def store_reduction(
+        self,
+        name: str,
+        index: sympy.Expr,
+        value: CSEVariable,
+    ):
+        assert self.inside_reduction
+        self.inside_reduction = False
+        dtype = V.graph.get_dtype(name)
+        indexing = self.indexing(
+            index,
+            block_ptr=True,
+            tma_compatibility_checker=self.tma_compatibility_checker_cls(
+                kernel=self,
+                dtype=dtype,
+                for_store=True,
+                force=False,
+            ),
+        )
+        self.inside_reduction = True
+        var = self.args.output(name)
+
+        exit_stack = contextlib.ExitStack()
+        if self.cooperative_reduction:
+            exit_stack.enter_context(
+                self.guard_cooperative_store(name, self.post_loop_store)
+            )
+
+        if isinstance(indexing, (BlockPtrOptions, TensorDescriptorOptions)):
+            self.post_loop_store.writeline(
+                DeferredLine(
+                    name,
+                    self.codegen_block_ptr_store_line(
+                        name,
+                        indexing,
+                        indexing.format(var),
+                        value,
+                        f", boundary_check={indexing.boundary_check()!r}",
+                    ),
+                )
+            )
+        else:
+            assert isinstance(indexing, IndexingOptions)
+
+            indexing_str = indexing.index_str
+            if (
+                is_sympy_integer_like(index)
+                and value.shape is not None
+                and not all(str(x) == "1" for x in value.shape)
+            ):
+                value_shape = ", ".join(map(str, value.shape))
+                indexing_str += f".broadcast_to({value_shape})"
+
+            self.post_loop_store.writeline(
+                DeferredLine(
+                    name,
+                    f"tl.store({var} + ({indexing_str}), {value}, {indexing.mask_str})",
+                )
+            )
+
+        exit_stack.close()
+
+    def _lift_helper(
+        self, fn, values: tuple[CSEVariable, ...], dtypes: tuple[torch.dtype, ...]
+    ) -> str:
+        # Lift IR function for scan operations into a triton function
+        # in the global namespace
+        helper = IndentedBuffer()
+        helper.writeline("@triton.jit")
+        cse = CSE()
+
+        args = [
+            tuple(
+                cse.namedvar(f"arg{i}_{n}", dtype=dtype, shape=value.shape)
+                for n, (value, dtype) in enumerate(zip(values, dtypes))
+            )
+            for i in range(2)
+        ]
+        signature = ", ".join(str(x) for x in itertools.chain.from_iterable(args))
+        helper.writeline(f"def {{name}}({signature}):")
+
+        overrides = TritonOverrides()
+
+        # Build a name that changes depending on fn to workaround a triton bug
+        # where the combine_fn to reduce and scan is not hashed, and so different
+        # scan ops may collide in the triton cache.
+        # This is fixed with the latest triton pin, but not the triton-rocm pin.
+        helper_name = "_triton_helper_fn"
+
+        from torch._inductor.dtype_propagation import DtypePropagationOpsHandler
+        from torch._inductor.shape_propagation import ShapePropagationOpsHandler
+
+        shape_handler = ShapePropagationOpsHandler()
+        dtype_handler = DtypePropagationOpsHandler()
+
+        class CSEProxy(DefaultHandler):
+            def _default(
+                self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
+            ) -> Any:
+                nonlocal helper_name
+                helper_name += f"_{name}"
+
+                output_dtype = getattr(
+                    dtype_handler,
+                    name,
+                )(*args, **kwargs)
+
+                output_shape = getattr(
+                    shape_handler,
+                    name,
+                )(*args, **kwargs)
+
+                return cse.generate(
+                    helper,
+                    getattr(overrides, name)(*args, **kwargs),
+                    dtype=output_dtype,
+                    shape=output_shape,
+                )
+
+        with helper.indent(), V.set_ops_handler(CSEProxy()):
+            outputs = fn(*args)
+            outputs = ", ".join(str(output) for output in outputs)
+            helper.writeline(f"return {outputs}")
+
+        return self.helper_functions.add(helper.getvalue(), base_name=helper_name)
+
+    def scan(
+        self,
+        dtypes: tuple[torch.dtype, ...],
+        combine_fn: Callable[
+            [tuple[CSEVariable, ...], tuple[CSEVariable, ...]], tuple[CSEVariable, ...]
+        ],
+        values: tuple[CSEVariable, ...],
+    ) -> tuple[CSEVariable, ...]:
+        """
+        Perform an associative scan on 'values'.
+        """
+        assert self.inside_reduction
+        assert not self.cooperative_reduction, "TODO"
+        masks = OrderedSet(f"{tree.prefix}mask" for tree in self.range_trees)
+        self.filter_masks(masks)
+        masks = sorted(masks)
+        assert not self._load_mask, "ops.scan not supported inside ops.masked"
+
+        broadcasted_values = []
+        accumulators = []
+
+        dtypes = tuple(upcast_compute_type(dtype) for dtype in dtypes)
+        cse_compute = functools.partial(self.cse.generate, self.compute)
+        combine_helper_fn = self._lift_helper(combine_fn, values, dtypes)
+        dim = self.triton_tensor_ndim() - self.num_reduction_dims
+
+        for value, dtype in zip(values, dtypes):
+            value_dtype = self.cse.generate(
+                self.compute,
+                f"{value}.to({triton_compute_type(dtype)})",
+                dtype=dtype,
+                shape=value.shape,
+            )
+            value = self.cse.generate(
+                self.compute,
+                f"tl.broadcast_to({value_dtype}, {self.dense_size_str()})",
+                dtype=dtype,
+                shape=tuple(self.dense_size_list()),
+            )
+            broadcasted_values.append(value)
+
+            acc_type = triton_acc_type(dtype)
+
+            if not self.persistent_reduction:
+                reduced_size = self.dense_size_list()
+                reduced_size[-1] = "1"
+                accumulator = self.cse.newvar(dtype=dtype, shape=reduced_size)
+                reduced_size_str = f"[{', '.join(reduced_size)}]"
+
+                default = "float('nan')" if dtype.is_floating_point else "-1"
+                self.body.writeline(
+                    f"{accumulator} = tl.full({reduced_size_str}, {default}, {acc_type})"
+                )
+
+                accumulators.append(accumulator)
+
+        def csv(values):
+            return " ".join(f"{value}," for value in values)
+
+        def cse_multiple(line, values, masks, dtypes):
+            n = len(values)
+            cache_keys = [f"{line}, {i}, {masks}" for i in range(n)]
+            if all(self.cse.contains(cache_key) for cache_key in cache_keys):
+                return [self.cse.get(cache_key) for cache_key in cache_keys]
+            result_vars = [
+                self.cse.newvar(dtype=dtype, shape=value.shape)
+                for (dtype, value) in zip(dtypes, values)
+            ]
+            self.compute.writeline(
+                f"{csv(result_vars)} = {line}",
+            )
+            for result_var, cache_key in zip(result_vars, cache_keys):
+                if masks:
+                    result_var.mask_vars = masks  # type: ignore[attr-defined]
+                self.cse.put(cache_key, result_var)
+            return tuple(result_vars)
+
+        partial_scan_vars = cse_multiple(
+            f"tl.associative_scan(({csv(broadcasted_values)}), {dim}, {combine_helper_fn})",
+            broadcasted_values,
+            masks,
+            dtypes,
+        )
+
+        if not self.persistent_reduction:
+            # tl.reduce doesn't work for non-commutative operators, so instead
+            # of repeating the scan op as a reduction, we use sum to select the
+            # last scan value
+            def _partial_scan_shape(var):
+                if var.shape is None:
+                    return None
+                else:
+                    shape = list(var.shape)
+                    shape[-1] = "1"
+                    return shape
+
+            partial_reduce_vars = [
+                cse_compute(
+                    f"triton_helpers.select_one(({partial_scan_var}), rbase == (RBLOCK - 1), dim=-1, keep_dims=True)",
+                    dtype=upcast_compute_type(partial_scan_var.dtype),
+                    shape=_partial_scan_shape(partial_scan_var),
+                )
+                for partial_scan_var in partial_scan_vars
+            ]
+            accs_next = combine_fn(tuple(accumulators), tuple(partial_reduce_vars))
+            full_scan_vars = combine_fn(tuple(accumulators), partial_scan_vars)
+            result_vars = [
+                cse_compute(
+                    f"tl.where(roffset > 0, {full_scan}, {partial_scan})",
+                    dtype=partial_scan.dtype,
+                    shape=partial_scan.shape,
+                )
+                for full_scan, partial_scan in zip(full_scan_vars, partial_scan_vars)
+            ]
+            for acc_next, accumulator, partial_reduce in zip(
+                accs_next, accumulators, partial_reduce_vars
+            ):
+                self.compute.writeline(
+                    f"{accumulator} = tl.where(roffset > 0, {acc_next}, {partial_reduce})"
+                )
+        else:
+            result_vars = partial_scan_vars
+
+        for result_var in result_vars:
+            assert isinstance(result_var, TritonCSEVariable)
+            result_var.mask_vars = OrderedSet(masks)
+
+        return tuple(result_vars)
+
+    def sort(
+        self,
+        dtypes: tuple[torch.dtype, ...],
+        values: tuple[CSEVariable, ...],
+        stable: bool,
+        descending: bool,
+    ) -> tuple[CSEVariable, ...]:
+        assert self.inside_reduction
+        assert not self.cooperative_reduction, "TODO"
+        masks = OrderedSet(f"{tree.prefix}mask" for tree in self.range_trees)
+        self.filter_masks(masks)
+        masks = sorted(masks)
+        assert not self._load_mask, "ops.sort not supported inside ops.masked"
+        assert self.persistent_reduction, (
+            "ops.sort is only supported in persistent reductions"
+        )
+
+        cse_compute = functools.partial(self.cse.generate, self.compute)
+        dim = self.triton_tensor_ndim() - self.num_reduction_dims
+
+        dtypes = tuple(upcast_compute_type(dtype) for dtype in dtypes)
+        assert len(dtypes) == len(values)
+        broadcasted_values = [
+            cse_compute(
+                f"tl.broadcast_to({value}, {self.dense_size_str()})",
+                dtype=dtypes[i],
+                shape=tuple(self.dense_size_list()),
+            )
+            for i, value in enumerate(values)
+        ]
+
+        def csv(values):
+            return " ".join(f"{value}," for value in values)
+
+        def cse_multiple(line, broadcasted_values, masks, dtypes):
+            n = len(broadcasted_values)
+            cache_keys = [f"{line}, {i}, {masks}" for i in range(n)]
+            if all(self.cse.contains(cache_key) for cache_key in cache_keys):
+                return [self.cse.get(cache_key) for cache_key in cache_keys]
+            result_vars = [
+                self.cse.newvar(dtype=dtype, shape=value.shape)
+                for dtype, value in zip(dtypes, broadcasted_values)
+            ]  # type: ignore[attr-defined]
+            self.compute.writeline(
+                f"{csv(result_vars)} = {line}",
+            )
+            for result_var, cache_key in zip(result_vars, cache_keys):
+                if masks:
+                    result_var.mask_vars = masks  # type: ignore[attr-defined]
+                self.cse.put(cache_key, result_var)
+            return tuple(result_vars)
+
+        assert self.range_trees[-1].is_reduction
+        rnumel = "None" if self._has_constant_mask(self.range_trees[-1]) else "rnumel"
+
+        if len(values) == 2:
+            line = (
+                f"triton_helpers.sort_with_index({broadcasted_values[0]}, {broadcasted_values[1]},"
+                f" {rnumel}, {dim}, stable={stable}, descending={descending})"
+            )
+            result_vars = cse_multiple(line, broadcasted_values, masks, dtypes)
+        else:
+            raise AssertionError("Unhandled sort")
+
+        for result_var, input_var in zip(result_vars, values):
+            result_var.mask_vars = masks  # type: ignore[attr-defined]
+            result_var.bounds = input_var.bounds
+
+        return tuple(result_vars)
+
+    def codegen_prologue(self, code: IndentedBuffer):
+        """
+        Generate the output from prologue. This should be
+        extracted from the subgraph, which is why this is
+        partitioned from codegen_body.
+        """
+        if not self.prologue:
+            return
+
+        code.splice(self.prologue)
+        self.prologue.clear()
+        self.prologue_cache.clear()
+
+    def codegen_body(self):
+        """
+        Concat output code from index_code, loads, compute, stores,
+        suffix into self.body.
+
+        For pointwise kernels, this is called just once at the end.
+
+        For reduction kernels, this generates a loop over the reduction
+        axis.
+        """
+        if not (
+            self.indexing_code
+            or self.loads
+            or self.stores
+            or self.compute
+            or self.post_loop_combine
+            or self.post_loop_store
+        ):
+            return
+
+        loop_trees = [tree for tree in self.range_trees if tree.is_loop]
+        if self.mix_order_reduction:
+            assert self.persistent_reduction, (
+                "Mix order reduction requires persistent reduction"
+            )
+            accumname2var = {}
+            for idx, partial_accum in enumerate(self.saved_partial_accumulate):
+                reduction_type = partial_accum.reduction_type
+                default = ir.Reduction.default_accumulator(reduction_type, torch.float)
+                default = self._map_tuple_or_scalar(constant_repr, default)
+                name = f"accum{idx}"
+                self.body.writeline(
+                    f"{name} = tl.full([R0_BLOCK], {default}, tl.float32)[None, :]"
+                )
+                accumname2var[name] = self.cse.namedvar(
+                    name, dtype=torch.float, shape=("1", "R0_BLOCK")
+                )
+            self.body.writeline("split_size = min(RSPLIT_SIZE, xnumel - xoffset)")
+            self.body.writeline(
+                "for _ in tl.range(0, split_size, XBLOCK, num_stages=NUM_STAGES):"
+            )
+            with self.body.indent(offset=1):
+                # generate xmask if it's not constant
+                if not self._has_constant_xmask():
+                    entry = self.range_trees[0]
+                    assert entry.prefix == "x"
+                    x = entry.prefix
+                    self.body.writeline(f"{x}mask = {entry.name} < {x}numel")
+                self.body.splice(self.indexing_code)
+                self.body.writelines(
+                    [
+                        "xindex += XBLOCK",
+                    ]
+                )
+                self.body.splice(self.loads)
+                self.body.splice(self.compute)
+                self.body.splice(self.stores)
+                self.body.splice(self.post_loop_store)
+
+                # no need to sum if XBLOCK == 1, or does that matter?
+                for idx, partial_accum in enumerate(self.saved_partial_accumulate):
+                    var = partial_accum.value
+                    name = f"accum{idx}"
+                    combine_fn = ir.get_reduction_combine_fn(
+                        partial_accum.reduction_type, torch.float
+                    )
+                    triton_reduction_function = get_triton_reduction_function(
+                        partial_accum.reduction_type,
+                    )
+                    newval = self.cse.generate(
+                        self.body,
+                        f"{triton_reduction_function}({var}, 0)",
+                        dtype=var.dtype,
+                        shape=("R0_BLOCK",),
+                    )
+                    import unittest
+
+                    with unittest.mock.patch.object(self, "compute", self.body):
+                        updated = combine_fn(
+                            accumname2var[name],
+                            newval,
+                        )
+                    self.body.writeline(f"{name} = {updated}")
+
+            for idx in range(len(self.saved_partial_accumulate)):
+                self.body.writeline(
+                    f"tl.store(ws_ptr + (tl.program_id(0) + {idx} * tl.num_programs(0)) * r0_numel + r0_index, accum{idx}, r0_mask)"
+                )
+
+        elif self.inside_reduction and len(loop_trees) > 0:
+            # Write the loop headers.
+            for level, tree in enumerate(loop_trees):
+                with self.body.indent(offset=level):
+                    prefix = tree.prefix
+                    loop_start = "rsplit_start" if self.cooperative_reduction else "0"
+                    loop_end = (
+                        "rsplit_end" if self.cooperative_reduction else f"{prefix}numel"
+                    )
+                    self.body.writeline(
+                        f"for {prefix}offset in range({loop_start}, {loop_end}, {prefix.upper()}BLOCK):"
+                    )
+                with self.body.indent(offset=level + 1):
+                    self.iteration_ranges_codegen_header(tree, self.body)
+
+            # The innermost loop performs the reduction.
+            with self.body.indent(offset=len(loop_trees)):
+                self.codegen_reduction_indices(self.body)
+                self.body.splice(self.indexing_code)
+                self.body.splice(self.loads)
+                self.body.splice(self.compute)
+                self.body.splice(self.stores)
+
+            # Write loop suffixes.
+            for level, tree in reversed([*enumerate(loop_trees)]):
+                with self.body.indent(offset=level + 1):
+                    # Advance pointers at the end of each loop.
+                    for block_ptr, advancement in self.pointer_advancements[
+                        tree.symt
+                    ].items():
+                        # Subtract any advancements made in the previous loop level.
+                        if level < len(loop_trees) - 1:
+                            prev_tree = loop_trees[level + 1]
+                            prev_advancement = self.pointer_advancements[
+                                prev_tree.symt
+                            ][block_ptr]
+                            prev_block = TritonSymbols.get_block_size(prev_tree)
+                            prev_num_iter = CeilDiv(prev_tree.numel, prev_block)
+                            advancement = [
+                                cur - prev * prev_num_iter
+                                for cur, prev in zip(advancement, prev_advancement)
+                            ]
+
+                        self.body.writeline(
+                            DeferredLine(
+                                self.block_ptr_to_buffer[block_ptr],
+                                f"{block_ptr} = tl.advance({block_ptr}, {V.kernel.index_to_str(advancement)})",
+                            )
+                        )
+
+                # Invalidate any cache entries that came from inside the loop.
+                self.cse.invalidate(self.outside_loop_vars)
+                tree.cache_clear()
+        else:
+            self.body.splice(self.indexing_code)
+            self.body.splice(self.loads)
+            self.body.splice(self.compute)
+            self.body.splice(self.stores)
+        self.body.splice(self.post_loop_combine)
+        if self.cooperative_reduction and (
+            self.post_loop_combine or self.post_loop_store
+        ):
+            sem_ptr = f"{self.semaphores_name} + tl.program_id(1)"
+            self.body.splice(
+                f"""
+                if HAS_RSPLIT:
+                    triton_helpers.x_grid_barrier({sem_ptr})
+                """,
+                strip=True,
+            )
+            self.cooperative_reduction_workspace_cache.on_loop_end()
+        if not self.mix_order_reduction:
+            self.body.splice(self.post_loop_store)
+        self.indexing_code.clear()
+        self.loads.clear()
+        self.compute.clear()
+        self.stores.clear()
+        self.post_loop_combine.clear()
+        self.post_loop_store.clear()
+
+    def kernel_benchmark_extra_args(self) -> list[str]:
+        args = []
+        if self.need_numel_args():
+            numel_args: list[sympy.Expr] = []
+            self.add_numel_to_call_args("", numel_args, [])
+            for arg in numel_args:
+                if isinstance(arg, int):
+                    args.append(str(arg))
+                elif isinstance(arg, SymbolicCallArg):
+                    hint = V.graph.sizevars.size_hint(
+                        arg.inner_expr,
+                        hint_override=self.hint_override,
+                        fallback=config.unbacked_symint_fallback,
+                    )
+                    args.append(str(hint))
+                elif isinstance(arg, sympy.Expr):
+                    hint = V.graph.sizevars.size_hint(
+                        arg,
+                        hint_override=self.hint_override,
+                        fallback=config.unbacked_symint_fallback,
+                    )
+                    args.append(str(hint))
+                else:
+                    raise ValueError(f"Unsupported numel argument type: {type(arg)}")
+        return args
+
+    def codegen_kernel_benchmark(self, num_gb: Optional[float]) -> IndentedBuffer:
+        """
+        Generates Python code for benchmarking this Triton kernel.
+        - Creates example inputs (random tensors, constants, sizes).
+        - Runs the kernel on the current GPU/stream.
+        - Prints runtime (ms) and throughput (GB/s) using `num_gb`.
+        Args:
+            num_gb (float): The number of gigabytes to use for throughput calculation.
+        Returns:
+            IndentedBuffer: A buffer containing the generated Python benchmark code.
+        """
+        result = IndentedBuffer()
+        _argdefs, call_args, signature, _ = self.args.python_argdefs()
+
+        result.writelines(["", "", "def get_args():"])
+        with result.indent():
+            name_cnt = itertools.count()
+            var_names = []
+            for arg_name, arg_sig in zip(call_args, signature):
+                var_name = f"arg_{next(name_cnt)}"
+                buf = V.graph.try_get_buffer(arg_name)
+                if buf:
+                    size = V.graph.sizevars.size_hints(
+                        buf.get_size(),
+                        hint_override=self.hint_override,
+                        fallback=config.unbacked_symint_fallback,
+                    )
+                    stride = V.graph.sizevars.size_hints(
+                        buf.get_stride(),
+                        hint_override=self.hint_override,
+                        fallback=config.unbacked_symint_fallback,
+                    )
+                    result.writeline(
+                        f"{var_name} = rand_strided({size}, {stride}, device='{buf.get_device()}', dtype={buf.get_dtype()})"  # noqa: B950 line too long
+                    )
+                elif arg_name in V.graph.constants:
+                    # note that random seed is put in V.graph.constants
+                    const_tensor = V.graph.constants[arg_name]
+                    size = V.graph.sizevars.size_hints(
+                        const_tensor.size(),
+                        hint_override=self.hint_override,
+                        fallback=config.unbacked_symint_fallback,
+                    )
+                    stride = V.graph.sizevars.size_hints(
+                        const_tensor.stride(),
+                        hint_override=self.hint_override,
+                        fallback=config.unbacked_symint_fallback,
+                    )
+                    result.writeline(
+                        f"{var_name} = rand_strided({size}, {stride}, device='{const_tensor.device}', dtype={const_tensor.dtype})"  # type: ignore[arg-type]  # noqa: B950 line too long
+                    )
+                elif isinstance(arg_sig, SizeArg):
+                    symval_hint = V.graph.sizevars.size_hint(
+                        arg_sig.expr,
+                        hint_override=self.hint_override,
+                        fallback=config.unbacked_symint_fallback,
+                    )
+
+                    # Force the seed_offset to be 0 so calls to the same kernel
+                    # using different seed offset will have the same benchmark harness.
+                    # We can dedup kernel definitions in this case.
+                    if "seed_offset" in arg_sig.name:
+                        symval_hint = 0
+                    result.writeline(f"{var_name} = {symval_hint}")
+                elif isinstance(arg_sig, WorkspaceArg):
+                    device = V.graph.get_current_device_or_throw()
+                    count = V.graph.sizevars.size_hint(
+                        arg_sig.count, hint_override=self.hint_override
+                    )
+                    result.writeline(
+                        f"{var_name} = torch.zeros({count}, device='{device}', dtype={arg_sig.dtype})"
+                    )
+                else:
+                    raise KeyError(
+                        f"Don't find the buffer or const tensor for {arg_name}"
+                    )
+                var_names.append(var_name)
+            var_names.extend(self.kernel_benchmark_extra_args())
+            result.writeline(f"return {', '.join(var_names)},")
+
+        result.writelines(["\n", "\n", "def call(args):"])
+        current_device = V.graph.get_current_device_or_throw()
+        index = current_device.index
+        with result.indent():
+            result.writeline(f"with {V.graph.device_ops.device_guard(index)}:")
+            with result.indent():
+                result.writeline(
+                    V.graph.device_ops.set_device(index)
+                )  # no-op to ensure context
+                stream_name = f"stream{index}"
+                result.writeline(f"{stream_name} = get_raw_stream({index})")
+                result.writeline(
+                    f"{str(Placeholder.KERNEL_NAME)}.run(*args, stream={stream_name})"
+                )
+
+        # benchmark all configs
+        result.writelines(["\n", "\n", "def benchmark_all_configs(args):"])
+        with result.indent():
+            result.writeline(f"with {V.graph.device_ops.device_guard(index)}:")
+            with result.indent():
+                result.writeline(
+                    V.graph.device_ops.set_device(index)
+                )  # no-op to ensure context
+                result.writeline(
+                    f"return {str(Placeholder.KERNEL_NAME)}.benchmark_all_configs(*args)"
+                )
+
+        result.writelines(["\n", "\n", "if __name__ == '__main__':"])
+        with result.indent():
+            result.writeline(
+                "from torch._inductor.runtime.benchmarking import benchmarker"
+            )
+            result.writeline("")
+
+            result.writeline("args = get_args()")
+            result.writeline(
+                f"ms = benchmarker.benchmark(lambda: call(args), device={V.graph.get_current_device_or_throw().type}, rep=40)"  # noqa: B950 line too long
+            )
+            result.writeline(f"num_gb = {num_gb}")
+            result.writeline("gb_per_s = num_gb / (ms / 1e3)")
+            result.writeline(
+                'print(f"{ms:.3f}ms    {num_gb:.3f}GB    {gb_per_s:.2f}GB/s")'
+            )
+
+        return result
+
+    def imports_for_benchmark_kernel(self):
+        return textwrap.dedent(
+            """
+            from torch._dynamo.testing import rand_strided
+            {}
+            import torch
+        """.format(V.graph.device_ops.import_get_raw_stream_as("get_raw_stream"))
+        )
+
+    def _get_heuristic(self):
+        if self.fixed_config:
+            return "fixed_config"
+        elif self.cooperative_reduction:
+            return "cooperative_reduction"
+        elif self.persistent_reduction:
+            assert self.inside_reduction
+            return "persistent_reduction"
+        elif self.inside_reduction:
+            return "reduction"
+        return "pointwise"
+
+    @staticmethod
+    def inductor_meta_common():
+        inductor_meta = {
+            "backend_hash": torch.utils._triton.triton_hash_with_backend(),
+            "assert_indirect_indexing": config.assert_indirect_indexing,
+            "autotune_local_cache": config.autotune_local_cache,
+            "autotune_pointwise": config.triton.autotune_pointwise,
+            "autotune_remote_cache": config.autotune_remote_cache,
+            "force_disable_caches": config.force_disable_caches,
+            "dynamic_scale_rblock": config.dynamic_scale_rblock,
+            "max_autotune": config.max_autotune,
+            "max_autotune_pointwise": config.max_autotune_pointwise,
+            "min_split_scan_rblock": config.triton.min_split_scan_rblock,
+            "spill_threshold": config.triton.spill_threshold,
+            "store_cubin": config.triton.store_cubin,
+            "deterministic": config.deterministic,
+            "force_filter_reduction_configs": config.test_configs.force_filter_reduction_configs,
+        }
+
+        if config.write_are_deterministic_algorithms_enabled:
+            inductor_meta["are_deterministic_algorithms_enabled"] = (
+                torch.are_deterministic_algorithms_enabled()
+            )
+
+        if torch.version.hip is not None:
+            inductor_meta["is_hip"] = True
+        if config.is_fbcode():
+            inductor_meta["is_fbcode"] = True
+        if config.profile_bandwidth:
+            inductor_meta["profile_bandwidth"] = config.profile_bandwidth
+            inductor_meta["profile_bandwidth_regex"] = config.profile_bandwidth_regex
+            inductor_meta["profile_bandwidth_output"] = config.profile_bandwidth_output
+            inductor_meta["profile_bandwidth_with_do_bench_using_profiling"] = (
+                config.profile_bandwidth_with_do_bench_using_profiling
+            )
+        if config.coordinate_descent_tuning:
+            inductor_meta["coordinate_descent_tuning"] = (
+                config.coordinate_descent_tuning
+            )
+            inductor_meta["coordinate_descent_search_radius"] = (
+                config.coordinate_descent_search_radius
+            )
+            inductor_meta["coordinate_descent_check_all_directions"] = (
+                config.coordinate_descent_check_all_directions
+            )
+        return inductor_meta
+
+    def codegen_kernel(self, name=None) -> str:
+        """
+        Convert the TritonKernel from Inductor SIMD IR to triton code, including inductor triton heuristics, imports,
+        metadata, and benchmarking infra.
+        """
+
+        code = IndentedBuffer()
+
+        size_hints = {}
+        for prefix, numel in self.numels.items():
+            if prefix_is_reduction(prefix) and not self.inside_reduction:
+                continue
+
+            numel_hint = V.graph.sizevars.symbolic_hint(numel)
+            if not isinstance(numel_hint, (int, sympy.Integer)):
+                # This default heuristic hint was picked carefully: it is
+                # large, to ensure that we don't shrink the block size (since
+                # if you don't have many elements, it'd be wasteful to pick a
+                # large block size).  Since we don't know how many elements we
+                # might have, we should be OK with some inefficiency to make
+                # sure we handle the large case well.  8192 is the largest
+                # block size we support, so we pick that.
+                #
+                # If we have a better hint for unbacked SymInts (e.g., because
+                # a user told us, or we are tracking upper bounds) we could
+                # use that here.
+                size_hint = 8192
+            else:
+                size_hint = next_power_of_2(int(numel_hint))
+            size_hints[prefix] = size_hint
+
+        if name is None:
+            code.splice(gen_common_triton_imports())
+            device_type = V.graph.get_current_device_or_throw().type
+            if device_type == "cpu":
+                code.splice("triton_helpers.set_driver_to_cpu()")
+            else:
+                code.splice("triton_helpers.set_driver_to_gpu()")
+
+            if config.benchmark_kernel:
+                code.splice(self.imports_for_benchmark_kernel())
+
+        argdefs, _, signature, _ = self.args.python_argdefs()
+        # maps actual expression to SizeArg if it is in sizevars replacements
+        for i, arg in enumerate(signature):
+            if isinstance(arg, SizeArg):
+                # mypy is unhappy about the sympy.Expr
+                # type for the key of the dict below
+                symbol = cast(sympy.Symbol, arg.expr)
+                if symbol in V.graph.sizevars.inv_precomputed_replacements:
+                    signature[i] = SizeArg(
+                        arg.name, V.graph.sizevars.inv_precomputed_replacements[symbol]
+                    )
+
+        mutated_args: OrderedSet[str] = OrderedSet()
+        for mutation in self.mutations:
+            if mutation in self.args.input_buffers:
+                mutated_args.add(self.args.input_buffers[mutation])
+            if (
+                mutation in self.args.inplace_buffers
+                and mutation not in V.graph.removed_buffers
+                and mutation not in self.removed_buffers
+            ):
+                mutated_args.add(
+                    cast(InplacedBuffer, self.args.inplace_buffers[mutation]).inner_name
+                )
+            if mutation in self.args.output_buffers:
+                mutation_arg = self.args.output_buffers[mutation]
+                assert not isinstance(mutation_arg, RemovedArg)
+                mutated_args.add(mutation_arg)
+
+        # Note: [Workspace Mutation]
+        # workspace arguments are mutated, but are not marked as mutations in self.mutations
+        # because their buffers are added during codegen, and aren't tracked during
+        # lowering/scheduling. So we add them as mutated_args explicitly below.
+        #
+        # In the logic below, we only mark the workspaces a mutated if they are marked with
+        # zero_fill: that's because, if we don't expect the buffer to be pre-filled with
+        # zeros, then, although we still mutate the data, we don't care about those
+        # mutations because we don't make any assumptions about the contents of the
+        # workspace buffer.  Similarly, ZERO_PER_GRAPH requires the kernel to return
+        # the buffer back to its original state.
+        for argname, arg in zip(argdefs, signature):
+            if (
+                isinstance(arg, WorkspaceArg)
+                and arg.zero_mode == WorkspaceZeroMode.ZERO_ON_CALL
+            ):
+                mutated_args.add(argname.name)
+
+        mutated_args = sorted(mutated_args)
+
+        for tree in self.active_range_trees():
+            sizearg = SizeArg(f"{tree.prefix}numel", tree.numel)
+            signature.append(sizearg)
+            argdefs.append(ArgName(sizearg.name))
+            # constexpr version causes issues, see
+            # https://github.com/pytorch/torchdynamo/pull/1362
+            # triton_meta["constants"][len(argdefs)] = V.graph.sizevars.size_hint(
+            #     tree.numel
+            # )
+            # argdefs.append(f"{tree.prefix}numel: tl.constexpr")
+
+        def add_constexpr_arg(arg_name):
+            # new versions (but not old versions) of Triton need constexprs included in the signature
+            if triton_version_uses_attrs_dict():
+                signature.append(ConstexprArg(arg_name))
+            argdefs.append(ArgName(arg_name, is_constexpr=True))
+
+        for tree in self.range_trees:
+            if tree.is_reduction and self.persistent_reduction:
+                # Rn_BLOCK for persistent_reduction is defined in codegen_static_numels
+                continue
+            if tree.tensor_dim is None:
+                continue
+
+            add_constexpr_arg(f"{tree.prefix.upper()}BLOCK")
+
+        if self.cooperative_reduction:
+            add_constexpr_arg("RSPLIT")
+
+        if self.mix_order_reduction:
+            add_constexpr_arg("RSPLIT_SIZE")
+            add_constexpr_arg("NUM_STAGES")
+
+        triton_meta_signature = signature_to_meta(
+            signature, size_dtype=self.index_dtype, argdefs=argdefs
+        )
+        triton_meta: dict[str, Any] = {
+            "signature": triton_meta_signature,
+            "device": DeviceProperties.create(V.graph.get_current_device_or_throw()),
+            "constants": {},
+            "native_matmul": (
+                torch._inductor.config.triton.native_matmul
+                and ("tl.dot" in str(self.body) or "tl.dot" in str(self.compute))
+            ),
+        }
+
+        # Skip memory optimization for forward of the training loop where we expect
+        # every new node will increase the peak memory and our greedy approach would
+        # introduce a lot of unnecessary cpu copies.
+        optimize_mem = V.graph.is_inference or V.graph.is_backward
+
+        inductor_meta = {
+            "grid_type": self._get_grid_type().__name__,
+            # Triton will not accept an OrderedSet for autotune_hints
+            "autotune_hints": set(self.autotune_hints),  # noqa: set_linter
+            "kernel_name": str(Placeholder.DESCRIPTIVE_NAME),
+            "mutated_arg_names": mutated_args,
+            "optimize_mem": optimize_mem,
+            "no_x_dim": self.no_x_dim,
+            "atomic_add_found": self.atomic_add_found,
+            "num_load": self.num_load,
+            "num_store": self.num_store,
+            "num_reduction": self.num_reduction,
+            **self.inductor_meta_common(),
+        }
+
+        if self.mix_order_reduction:
+            inductor_meta["RSPLIT_SIZE"] = self.rsplit_size
+
+        if config.deterministic or config.test_configs.force_filter_reduction_configs:
+            inductor_meta["has_loadstore_with_contiguous_rdim"] = (
+                self.has_load_with_contiguous_rdim
+                or self.has_store_with_contiguous_rdim
+            )
+
+        # Bail on 3d tiling, which has more complicated coalesce patterns
+        looped_red = V.kernel.features.is_reduction() and not self.persistent_reduction
+        tiling_scores = self.tiling_scores
+        two_d_red = len(self.tiling) == 2
+        if looped_red and two_d_red:
+            memory_stats = self.features.memory_stats(self.tiling)
+            dim_stats = memory_stats.persistent.memory.dim[0]
+            mem_ops_per_thread = dim_stats.count_per_thread
+
+            if (
+                tiling_scores is not None
+                and "x" in tiling_scores
+                and "r0_" in tiling_scores
+            ):
+                # large rblock inhibits xblock size, dont attempt if there is a decent amount of
+                # reads coalesced by xblock
+                r_coalesce_ratio = tiling_scores["r0_"] / max(tiling_scores["x"], 1)
+                contiguous_red = r_coalesce_ratio >= 8.0
+            else:
+                from torch._inductor.runtime.hints import ReductionHint
+
+                contiguous_red = (
+                    self.features.get_reduction_hint() == ReductionHint.INNER
+                )
+
+            looped_mem = memory_stats.looped.memory.bytes
+            persistent_mem = memory_stats.persistent.memory.bytes
+            # check that we save significant memory by doing persistent
+            saved_bytes_ratio = V.graph.sizevars.size_hint(
+                looped_mem, fallback=config.unbacked_symint_fallback
+            ) / max(
+                V.graph.sizevars.size_hint(
+                    persistent_mem, fallback=config.unbacked_symint_fallback
+                ),
+                1,
+            )
+
+            # TODO - rnumel should be reasonably close to power of 2
+            if (
+                # significant memory bandwidth savings
+                saved_bytes_ratio >= 1.3
+                and contiguous_red
+                # TODO - need more detailed register analysis
+                and V.graph.sizevars.statically_known_leq(
+                    self.features.reduction_numel, 32768
+                )
+                # We will already generate a persistent config in this case
+                and V.graph.sizevars.statically_known_gt(
+                    self.features.reduction_numel, 2048
+                )
+                and mem_ops_per_thread <= 10
+            ):
+                inductor_meta["add_persistent_rblock"] = True
+
+        if self.tiling_scores:
+            inductor_meta["tiling_scores"] = self.tiling_scores
+
+        if self.tma_min_block_sizes:
+            inductor_meta["tma_min_block_sizes"] = self.tma_min_block_sizes
+
+        if self.cooperative_reduction:
+            inductor_meta["persistent_reduction"] = self.persistent_reduction
+
+        num_gb = None
+        if config.benchmark_kernel or config.profile_bandwidth:
+            num_gb = self.estimate_kernel_num_bytes() / 1e9
+            if num_gb is not None:
+                inductor_meta["kernel_num_gb"] = num_gb
+        if config.benchmark_kernel:
+            flops = self.estimate_flops()
+            if flops is not None:
+                inductor_meta["kernel_flop"] = flops
+
+        triton_meta["configs"] = [config_of(signature)]
+
+        if enable_pdl_codegen():
+            triton_meta["launch_pdl"] = True
+
+        # Triton compiler includes equal_to_1 args into constants even
+        # when they are not constexpr. otherwise there may be a segfault
+        # during launching the Inductor-compiled Triton kernel.
+        # https://github.com/pytorch/pytorch/issues/120478#issuecomment-1962822307
+        # https://github.com/triton-lang/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384
+        for arg_num in equal_1_arg_indices(signature):  # type: ignore[index]
+            triton_meta["constants"][signature[arg_num].name] = 1  # type: ignore[index,union-attr]
+        triton_meta["enable_fp_fusion"] = not config.emulate_precision_casts
+
+        self.triton_meta = triton_meta
+
+        self.codegen_prologue(self.body)
+        self.codegen_body()
+
+        for helper in self.helper_functions:
+            code.writeline("")
+            code.splice(helper)
+
+        if self.fixed_config:
+            heuristics_line = f"""
+                @triton_heuristics.{self._get_heuristic()}(
+                    config={self.fixed_config.config!r},
+                    filename=__file__,
+                    triton_meta={triton_meta!r},
+                    inductor_meta={inductor_meta!r}
+                )
+                @triton.jit
+            """
+        elif self.inside_reduction:
+            reduction_hint = self.features.get_reduction_hint()
+            heuristics_line = f"""
+                @triton_heuristics.{self._get_heuristic()}(
+                    size_hints={size_hints!r},
+                    reduction_hint={reduction_hint},
+                    filename=__file__,
+                    triton_meta={triton_meta!r},
+                    inductor_meta={inductor_meta!r}
+                )
+                @triton.jit
+            """
+        else:
+            tile_hint = ""
+            if len(size_hints) == 2:
+                if (
+                    len(non_constexpr_signature(signature)) == 4
+                ):  # input, output and 2 args
+                    tile_hint = "tile_hint=TileHint.SQUARE,"
+                else:
+                    tile_hint = "tile_hint=TileHint.DEFAULT,"
+            heuristics_line = f"""
+                @triton_heuristics.{self._get_heuristic()}(
+                    size_hints={size_hints!r}, {tile_hint}
+                    filename=__file__,
+                    triton_meta={triton_meta!r},
+                    inductor_meta={inductor_meta!r},
+                    min_elem_per_thread={self.min_elem_per_thread}
+                )
+                @triton.jit
+            """
+        code.splice(heuristics_line)
+        code.writeline(
+            f"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(x.full_name() for x in argdefs)}):"
+        )
+        with code.indent():
+            self.codegen_static_numels(code)
+            for old, new in self.args.aliases():
+                code.writeline(f"{old} = {new}")
+            code.splice(self.body)
+
+        if config.benchmark_kernel:
+            code.splice(self.codegen_kernel_benchmark(num_gb))
+
+        return code.getvalue()
+
+    @staticmethod
+    def _get_persistent_RBLOCK(rnumel):
+        rnumel = V.graph.sizevars.simplify(rnumel)
+        if isinstance(rnumel, (sympy.Integer, int)):
+            val = int(rnumel)
+            val = next_power_of_2(val)
+        else:
+            val = 2
+            while not V.graph.sizevars.statically_known_leq(rnumel, val):
+                if val > 16 * 1024:
+                    raise ValueError(f"Failed to find static RBLOCK for {rnumel}")
+                val *= 2
+
+            return val
+
+        return val
+
+    @staticmethod
+    def has_persistent_RBLOCK(rnumel):
+        try:
+            TritonKernel._get_persistent_RBLOCK(rnumel)
+            return True
+        except ValueError:
+            return False
+
+    def codegen_static_numels(self, code):
+        """
+        We get a small speedup from hard coding numels if they are static.
+
+        This code stomps on the passed-in values by writing an constant to the top of the kernel.
+
+        In a kernel like:
+        def KERNEL_NAME(in_ptr0, in_ptr1, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
+
+        We would add
+        xnumel = 4096
+        r0_numel = 768
+
+        After the signature, before the kernel code, if we decided to make these static. As its hardcoded, it becomes
+        a better signal to triton on how to unroll and do some static indexing. So, it's not so much that downstream
+        knows that its a static numel, as that you just plop a constant into the kernel.
+        """
+
+        def is_static_integer(expr: sympy.Expr) -> bool:
+            return isinstance(expr, (sympy.Integer, int))
+
+        for tree in self.range_trees:
+            if not tree.is_reduction or self.inside_reduction:
+                simplified_tree_numel = V.graph.sizevars.simplify(tree.numel)
+                if is_static_integer(simplified_tree_numel):
+                    code.writeline(f"{tree.prefix}numel = {int(simplified_tree_numel)}")
+
+            if tree.is_reduction and self.persistent_reduction:
+                if self.cooperative_reduction:
+                    numel = self.kexpr(self.rename_indexing(tree.numel))
+                    val = f"triton_helpers.constexpr_next_power_of_2(({numel} + RSPLIT - 1) // RSPLIT)"
+                else:
+                    val = self._get_persistent_RBLOCK(tree.numel)
+                    if self.is_native_matmul:
+                        # tl.dot only supports shapes >= 16
+                        val = max(val, 16)
+
+                code.writeline(f"{tree.prefix.upper()}BLOCK: tl.constexpr = {val}")
+
+            if tree.prefix == "x" and self.no_x_dim:
+                code.writeline("XBLOCK: tl.constexpr = 1")
+
+    def _get_grid_type(self) -> type[triton_heuristics.GridExpr]:
+        n = sum([int(not tree.is_reduction) for tree in self.range_trees])
+        if self.mix_order_reduction:
+            assert n == 1
+            return triton_heuristics.MixOrderReductionGrid
+        elif self.cooperative_reduction:
+            assert n == 1
+            return triton_heuristics.CooperativeReductionGrid
+        elif n == 1:
+            return triton_heuristics.Grid1D
+        elif n == 2:
+            if any(map(self.needs_yz_grid_overflow, self.range_trees)):
+                return triton_heuristics.Grid2DWithYZOverflow
+            return triton_heuristics.Grid2D
+        elif n == 3:
+            return triton_heuristics.Grid3D
+        raise ValueError(f"Unsupported number of dimensions: {n}")
+
+    def add_numel_to_call_args(self, name, call_args, arg_types):
+        # TODO(jansel): if there are constants, we shouldn't bother passing them as args
+        for tree in self.range_trees:
+            if isinstance(tree.numel, (sympy.Integer, sympy.Symbol)):
+                expr = tree.numel
+            else:
+                expr = V.graph.wrapper_code.generate_numel_expr(name, tree)
+
+            if not tree.is_reduction or self.inside_reduction:
+                call_args.append(expr)
+                arg_types.append(type(expr))
+
+    def call_kernel(
+        self, name: str, node: Optional[IRNode] = None, deallocate_ws: bool = True
+    ):
+        wrapper = V.graph.wrapper_code
+        wrapper.write_triton_header_once()
+        _, call_args, _, arg_types = self.args.python_argdefs()
+        self.add_numel_to_call_args(name, call_args, arg_types)
+
+        for ws in self.args.workspace_args:
+            wrapper.generate_workspace_allocation(ws)
+
+        wrapper.generate_kernel_call(
+            name,
+            call_args,
+            triton=True,
+            arg_types=arg_types,
+            triton_meta=self.triton_meta,
+        )
+
+        if deallocate_ws:
+            self.deallocate_workspaces()
+
+    def codegen_nan_check(self) -> None:
+        wrapper = V.graph.wrapper_code
+        _, call_args, arg_signatures, _ = self.args.python_argdefs()
+        for arg, arg_signature in zip(call_args, arg_signatures):
+            if isinstance(arg_signature, TensorArg):
+                if V.graph.cpp_wrapper:
+                    wrapper.writeline(
+                        f'AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_check_inf_and_nan("{arg}", {arg}));'
+                    )
+                else:
+                    line = f"assert not {arg}.isnan().any().item()"
+                    wrapper.writeline(line)
+                    line = f"assert not {arg}.isinf().any().item()"
+                    wrapper.writeline(line)
+
+    def create_cse_var(self, *args, **kwargs) -> TritonCSEVariable:
+        return TritonCSEVariable(*args, **kwargs)
+
+    def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry):
+        line = f"{entry.name} = {self.kexpr(self.rename_indexing(entry.expr))}"
+
+        # mix order reduction introduces an extra loop across the x
+        # dimension
+        if entry.root.is_loop or (self.mix_order_reduction and entry.prefix == "x"):
+            self.indexing_code.writeline(line)
+        else:
+            # lift non-reduction stores outside loop
+            self.body.writeline(line)
+
+    def iteration_ranges_ranges_code(self, entry: IterationRangesRoot) -> str:
+        assert entry.tensor_dim is not None
+        size = self.indexing_size_str(entry.tensor_dim)
+        index_dtype = self.index_dtype
+        suffix = f".to({index_dtype})" if index_dtype != "tl.int32" else ""
+        if (
+            self.cooperative_reduction
+            and self.persistent_reduction
+            and entry.is_reduction
+        ):
+            suffix = f"{suffix} + rsplit_start"
+        return f"tl.arange(0, {entry.prefix.upper()}BLOCK){size}{suffix}"
+
+    def iteration_ranges_scalar_code(
+        self, entry: IterationRangesRoot, value: Any
+    ) -> str:
+        index_dtype = self.index_dtype
+        ndim = self.triton_tensor_ndim()
+        size = [1] * ndim
+        return f"tl.full({size}, {value}, {index_dtype})"
+
+    def iteration_ranges_get_pid(self, entry: IterationRangesRoot) -> str:
+        assert entry.grid_dim is not None
+        key = f"tl.program_id({entry.grid_dim})"
+        # y_grid has a limit, so express it in terms of y and z in case of overflow.
+        # z grid is only exercised when max_tiles == 3 (off by default).
+        if self.needs_yz_grid_overflow(entry):
+            # For ynumel larger than max_ygrid, we need to use zdim.
+            # For each z dimension, there are tl.num_programs(1) yblocks which is passed by grad(x,y,z).
+            # So, we need to add tl.program_id(z) * tl.num_programs(y) *YBLOCK to get the correct yoffset.
+            key = f"({key} + tl.program_id({entry.grid_dim + 1}) * tl.num_programs({entry.grid_dim}))"
+        pid = entry.pid_cache.get(key, key)
+        if self.index_dtype != "tl.int32":
+            return f"{pid}.to({self.index_dtype})"
+        return pid
+
+    def needs_yz_grid_overflow(self, entry: IterationRangesRoot) -> bool:
+        return (
+            entry.grid_dim == 1
+            and not entry.has_zdim
+            and not self.cooperative_reduction
+            and not V.graph.sizevars.statically_known_leq(entry.numel, get_max_y_grid())
+        )
+
+    def max_block(self, prefix: str) -> int:
+        if self.fixed_config:
+            return self.fixed_config[f"{prefix.upper()}BLOCK"]
+        return TRITON_MAX_BLOCK[prefix.upper()]
+
+    def _has_constant_mask(self, tree: IterationRangesRoot) -> bool:
+        if self.is_native_matmul:
+            # tl.dot requires the shape to be >= 16,
+            # so when matmul shape is smaller than 16, we always keep the mask.
+            if V.graph.sizevars.statically_known_lt(tree.numel, 16):
+                return False
+
+        if not self.optimize_mask:
+            return False
+
+        if self.fixed_config and f"{tree.prefix.upper()}BLOCK" in self.fixed_config:
+            if self.fixed_config[f"{tree.prefix.upper()}BLOCK"] == 1:
+                return True
+        else:
+            if V.graph.sizevars.statically_known_equals(tree.numel, 1):
+                return True
+
+        # Masks are superfluous if numel is a multiple of BLOCK
+        # (We use the fact that BLOCK is required by triton to be a power of 2)
+        if tree.is_reduction and self.persistent_reduction:
+            max_block = self._get_persistent_RBLOCK(tree.numel)
+        elif tree.prefix == "x" and self.no_x_dim:
+            max_block = 1
+        else:
+            max_block = self.max_block(tree.prefix)
+
+        if tree.is_reduction and self.cooperative_reduction:
+            max_block = max_block * self.max_rsplit()
+
+        # [Note: Constant mask optimisation]
+        # Optional optimization: if block divides numel exactly, we will
+        # never need to do a masked load to handle stragglers at the end.
+        # If this tree is for the y dimension, we should only use a constant
+        # mask if it can be guaranteed that:
+        # 1. (ynumel / YBLOCK) < max_ygrid or
+        # 2. (ynumel / YBLOCK) % max_ygrid == 0
+        # Because YBLOCK is not constant, use a conservative heuristic:
+        # only use a constant mask if ynumel < max_ygrid.
+        # It's faster to avoid masking at all.  But it is sound to always
+        # mask.
+        if V.graph.sizevars.statically_known_multiple_of(tree.numel, max_block):
+            return (
+                tree.grid_dim != 1
+                or tree.has_zdim
+                or V.graph.sizevars.statically_known_leq(tree.numel, get_max_y_grid())
+            )
+
+        return False
+
+    def _has_constant_xmask(self) -> bool:
+        xtree = self.range_trees[0]
+        assert xtree.prefix == "x"
+        return self._has_constant_mask(xtree)
+
+    def filter_masks(self, mask_vars: OrderedSet[str]) -> None:
+        for tree in self.range_trees:
+            if self._has_constant_mask(tree):
+                mask_vars.discard(f"{tree.prefix}mask")
+
+        # can be added as an override_mask
+        mask_vars.discard("None")
+
+    @cache_on_self
+    def get_reduction_prefixes(self) -> list[str]:
+        return [
+            prefix_str[symt]
+            for symt in list(TritonSymbols.reduction_types)[: self.num_reduction_dims]
+        ]
+
+    def codegen_reduction_numels(self, buffer: IndentedBuffer) -> None:
+        """
+        Generates code that flattens ND reduction numels, block sizes, etc. into 1D.
+        """
+        # rnumel = r0_numel * ... * r(n-1)_numel
+        reduction_trees = [tree for tree in self.range_trees if tree.is_reduction]
+        rnumel = " * ".join(sorted(f"{tree.prefix}numel" for tree in reduction_trees))
+        buffer.splice(f"rnumel = {self.kexpr(rnumel)}")
+
+        # RBLOCK = R0_BLOCK * ... * R(N-1)_BLOCK
+        rn_blocks = [
+            TritonSymbols.block_sizes[tree.symt]
+            for tree in self.range_trees
+            if tree.is_reduction
+        ]
+        rblock = sympy_product(rn_blocks)
+        buffer.splice(f"RBLOCK: tl.constexpr = {self.kexpr(rblock)}")
+
+    def _get_reduction_symbols(self, suffix: str, **kwargs) -> list[sympy.Symbol]:
+        """
+        Helper to initialize symbols like rn_numel, rn_base, etc.
+        """
+        rn_prefixes = self.get_reduction_prefixes()
+        return [sympy.Symbol(f"{prefix}{suffix}", **kwargs) for prefix in rn_prefixes]
+
+    @cache_on_self
+    def _get_reduction_index_coeffs(self) -> list[sympy.Expr]:
+        """
+        Compute coefficients to convert ND reduction indices to linear indices.
+        For example:
+          rindex = r0_index * r1_numel * ... * rn_numel + ... + rn_index.
+        """
+        rn_prefixes = self.get_reduction_prefixes()
+        rn_numels = self._get_reduction_symbols("numel", integer=True, positive=True)
+        return [
+            sympy_product(rn_numels[idx + 1 :]) for idx in range(len(rn_prefixes) - 1)
+        ] + [sympy.Integer(1)]
+
+    def _flatten_reduction_indices(self, multi_inds: list[sympy.Expr]) -> sympy.Expr:
+        """
+        Compute linear reduction indices from N dimensional ones.
+        """
+        coeffs = self._get_reduction_index_coeffs()
+        return sympy_dot(coeffs, multi_inds)
+
+    def codegen_reduction_indices(self, buffer: IndentedBuffer) -> None:
+        """
+        Generates code that converts ND reduction indices into linear indices.
+        """
+        # Gather relevant numels, indices, etc.
+        rn_offsets = self._get_reduction_symbols(
+            "offset", integer=True, nonnegative=True
+        )
+        rn_inds = self._get_reduction_symbols("index", integer=True, nonnegative=True)
+
+        # Compute roffset and rindex.
+        roffset = self._flatten_reduction_indices(rn_offsets)
+        buffer.splice(f"roffset = {self.index_to_str(roffset)}")
+        rindex = self._flatten_reduction_indices(rn_inds)
+        buffer.splice(f"rindex = {self.index_to_str(rindex)}")
+
+    def iteration_ranges_codegen_header(
+        self, entry: IterationRangesRoot, code: IndentedBuffer
+    ) -> None:
+        x = entry.prefix
+        if entry.is_loop:
+            code.writeline(f"{entry.name} = {x}offset + {x}base")
+        elif entry.grid_dim is None:
+            # no need to "{x}offset = "
+            code.writeline(f"{entry.name} = {self.iteration_ranges_ranges_code(entry)}")
+            code.writeline(f"{x}offset = 0")
+        else:
+            if entry.tensor_dim is not None:
+                line = f"{x}offset + {self.iteration_ranges_ranges_code(entry)}"
+            else:
+                line = self.iteration_ranges_scalar_code(entry, f"{x}offset")
+
+            block_size = (
+                f"{x.upper()}BLOCK" if not self.mix_order_reduction else "RSPLIT_SIZE"
+            )
+            code.writelines(
+                [
+                    f"{x}offset = {self.iteration_ranges_get_pid(entry)} * {block_size}",
+                    f"{entry.name} = {line}",
+                ]
+            )
+        if self._has_constant_mask(entry):
+            code.writeline(self.create_constant_mask(entry))
+        elif not (x == "x" and self.mix_order_reduction):
+            # mix order reduction should generate xmask inside the loop
+            code.writeline(f"{x}mask = {entry.name} < {x}numel")
+
+
+class TritonScheduling(SIMDScheduling):
+    kernel_type: type[Any] = TritonKernel
+    backend_features = OrderedSet(
+        [
+            BackendFeature.FOREACH,
+            BackendFeature.BUCKETIZE,
+            BackendFeature.INPLACE_BUFFERS,
+            BackendFeature.MASKED_SCATTER_WITH_INDEX,
+            BackendFeature.SCAN,
+            BackendFeature.SORT,
+            BackendFeature.TRITON_TEMPLATES,
+            BackendFeature.TUPLE_REDUCTION,
+        ]
+    )
+
+    def __init__(self, scheduler: Optional[Scheduler]) -> None:
+        super().__init__(scheduler)
+        if scheduler is None or not hasattr(scheduler, "nodes"):
+            return
+        for node in scheduler.nodes:
+            if isinstance(node, (SchedulerNode, FusedSchedulerNode)):
+                node.debug_device_str = debug_triton_code
+
+    @classmethod
+    def get_backend_features(cls, device: torch.device):
+        if (
+            config.triton.cooperative_reductions
+            or config.triton.force_cooperative_reductions
+        ):
+            return OrderedSet(
+                [*cls.backend_features, BackendFeature.REDUCE_TO_SINGLE_ELEMENT]
+            )
+        return cls.backend_features
+
+    def codegen_comment(self, node_schedule, kernel_name=None):
+        wrapper = V.graph.wrapper_code
+        origins, _detailed_origins = get_kernel_metadata(node_schedule, wrapper)
+        if origins:
+            wrapper.make_comment(origins)
+
+        if config.debug_fusion:
+            from torch._inductor.scheduler import (
+                BaseSchedulerNode,
+                ForeachKernelSchedulerNode,
+            )
+
+            if not any(
+                isinstance(n, ForeachKernelSchedulerNode) for n in node_schedule
+            ):
+                # We probably should look what are the nodes inside a foreach
+                # schedule node
+                node_names = [
+                    n.get_name()
+                    for n in node_schedule
+                    if isinstance(n, BaseSchedulerNode)
+                ]
+                wrapper.make_comment(
+                    f"{wrapper.comment} Fused node name list: {', '.join(node_names)}"
+                )
+
+        if kernel_name:
+            debug_handle = set_kernel_post_grad_provenance_tracing(
+                node_schedule,  # type: ignore[arg-type]
+                kernel_name,
+            )
+            wrapper.write_provenance_debug_handle(kernel_name, debug_handle)
+
+    def define_kernel(self, src_code, node_schedule, kernel):
+        wrapper = V.graph.wrapper_code
+        if src_code in wrapper.src_to_kernel:
+            kernel_name = wrapper.src_to_kernel[src_code]
+        else:
+            fused_name = (
+                get_fused_kernel_name(node_schedule, config.triton.descriptive_names)
+                if config.triton.descriptive_names
+                else ""
+            )
+            kernel_category = get_kernel_category_by_source_code(src_code)[:3]
+            kernel_name = "_".join(
+                ["triton", kernel_category, fused_name, wrapper.next_kernel_suffix()]
+            )
+            if config.aot_inductor.model_name_for_generated_files:
+                # When AOTI compiles multiple submodules, we need to use the model name to
+                # distinguish kernel related symbols.
+                kernel_name = f"{config.aot_inductor.model_name_for_generated_files}_{kernel_name}"
+
+            # use the original src_code as the key
+            wrapper.src_to_kernel[src_code] = kernel_name
+            subs_name = kernel_name if config.triton.unique_kernel_names else "triton_"
+
+            # DESCRIPTIVE_NAME is used for profiling purposes; it shows the full kernel name
+            # even when unique_kernel_names is turned off. Meanwhile, KERNEL_NAME is sometimes set
+            # to "triton_" to maximize caching opportunities (when unique_kernel_names = False).
+            src_code = src_code.replace(str(Placeholder.DESCRIPTIVE_NAME), kernel_name)
+            src_code = src_code.replace(str(Placeholder.KERNEL_NAME), subs_name)
+
+            # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does
+            # not use BracesBuffer, so we have no good indicator of a C++ buffer atm.
+            src_code = src_code.replace("#pragma CMT", "#")
+
+            _basename, _, kernel_path = get_path(code_hash(src_code.strip()), "py")
+            compile_wrapper = IndentedBuffer()
+
+            if async_compile.use_process_pool():
+                # The process pool is warm, we can shell out to workers right away. This
+                # allows us to save the result in async_compile.CompiledTritonKernels,
+                # so that the second time we call async_compile.triton, we do no work.
+                async_compile.triton(subs_name, src_code)
+
+            compile_wrapper.writeline(f"async_compile.triton({subs_name!r}, '''")
+
+            compile_wrapper.splice(src_code, strip=True)
+            current_device = V.graph.get_current_device_or_throw()
+            compile_wrapper.writeline(f"''', device_str='{current_device.type}')")
+
+            metadata_comment = f"# kernel path: {kernel_path}"
+            origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper)
+            metadata_comment += "\n" + origins + "\n" + detailed_origins
+            wrapper.define_kernel(
+                kernel_name, compile_wrapper.getvalue(), metadata_comment
+            )
+
+            # log kernel metadata for offline analysis.
+            # E.g. one can find all unaligned inner reduction and check if
+            # padding helps with the perf kernel by kernel.
+            if metrics.is_metric_table_enabled("kernel_metadata"):
+                metrics.log_kernel_metadata(kernel_name, kernel_path, src_code)
+
+        return kernel_name
+
+    def benchmark_fused_nodes(self, nodes, n_spills_threshold=8) -> tuple[float, str]:
+        """
+        Benchmark fused list of nodes and return the execution time
+        in milliseconds on randomly generated inputs.
+        """
+        src_code = self.generate_kernel_code_from_nodes(nodes, benchmark_kernel=True)
+        mod = PyCodeCache.load(src_code)
+        return self.benchmark_codegened_module(
+            mod, n_spills_threshold, node_names=OrderedSet(n.get_name() for n in nodes)
+        )
+
+    def benchmark_codegened_module(
+        self, mod, n_spills_threshold=8, node_names: Optional[OrderedSet[str]] = None
+    ) -> tuple[float, str]:
+        """Benchmark an already compiled module"""
+        device_interface = get_interface_for_device(V.graph.device_type)
+        with (
+            preserve_rng_state(),
+            device_interface.device(V.graph.get_current_device_or_throw()),  # type: ignore[attr-defined]
+        ):
+            ms = None
+
+            def cache_file_path():
+                assert mod.__file__ is not None
+                return os.path.splitext(mod.__file__)[0] + ".kernel_perf"
+
+            def store_cache():
+                path = cache_file_path()
+                write_atomic(path, str(ms))
+
+            def load_cache():
+                path = cache_file_path()
+                if os.path.exists(path):
+                    with open(path) as fd:
+                        return float(fd.read())
+                return None
+
+            node_names = (
+                node_names if node_names is not None else OrderedSet(["unknown"])
+            )
+            log.debug(
+                "kernel src code for %s written to: %s",
+                node_names,
+                mod.__file__,
+            )
+            ms = load_cache()
+            if ms is not None:
+                return ms, mod.__file__
+
+            args = mod.get_args()
+            call = mod.call
+            wrapped_jit_function = mod.triton_
+            # call once to trigger the compilation
+            try:
+                call(wrapped_jit_function.clone_args(*args)[0])
+            except Exception as e:
+                if config.triton.disallow_failing_autotune_kernels_TESTING_ONLY:
+                    raise
+                log.debug(  # noqa: G200
+                    "Exception (%s) in compiling fused nodes %s",
+                    e,
+                    node_names,
+                )
+                ms = float("inf")
+                store_cache()
+                return ms, mod.__file__
+
+            launchers = wrapped_jit_function.launchers
+            assert len(launchers) == 1
+            # n_spills does not necessarily mean it's not profitable to fuse,
+            # and sometimes it can be inaccurate
+            if launchers[0].n_spills > n_spills_threshold:
+                # skip benchmarking the kernel if there are register spills
+                ms = float("inf")
+            else:
+                device = V.graph.get_current_device_or_throw()
+                # We have to clone the inplace updated arguments to avoid earlier calls
+                # generating out of range indices for later calls.
+                ms = benchmarker.benchmark(
+                    lambda: call(wrapped_jit_function.clone_args(*args)[0]),
+                    device=device,
+                )
+                # overhead of cloning args gives bias for fusing the kernel
+                # in the case of mutating/in-placeable second fusion
+                # TODO - would be better as a hook in triton do_bench that reset
+                # the input values between benchmarking
+                if len(wrapped_jit_function.mutated_arg_names) > 0:
+                    ms = ms - benchmarker.benchmark(
+                        lambda: wrapped_jit_function.clone_args(*args),
+                        device=str(device),
+                    )
+
+            log.debug(
+                "The fused kernel for %s took %.3f ms to run",
+                node_names,
+                ms,
+            )
+            store_cache()
+            return ms, mod.__file__
+
+    def create_kernel_choices(  # type: ignore[override]
+        self,
+        kernel_features: SIMDKernelFeatures,
+        kernel_args: list[Any],
+        kernel_kwargs: dict[str, Any],
+    ) -> list[TritonKernel]:
+        is_scan = kernel_features.contains_op("scan")
+        is_split_scan = is_scan and any(
+            node.is_split_scan() for node in kernel_features.scheduler_nodes()
+        )
+        kernel_type: type[TritonKernel] = self.kernel_type
+        if is_split_scan:
+            from .triton_split_scan import TritonSplitScanKernel
+
+            kernel_type = TritonSplitScanKernel
+
+        if is_scan:
+            # TODO(jansel): scan does not yet work with cooperative reductions
+            kernel_kwargs["override_cooperative_reduction"] = False
+
+        # ops.sort only works with persistent reduction, and is not bandwidth bound anyway
+        # so taking the hit of non-coalesced loads is okay
+        if kernel_features.contains_op("sort"):
+            kernel_kwargs["override_persistent_reduction"] = True
+            kernel_kwargs["override_cooperative_reduction"] = False
+
+        if not TritonKernel.has_persistent_RBLOCK(kernel_features.reduction_numel):
+            # Cannot use persistent reduction with unknown dynamic rnumel
+            assert not kernel_kwargs.get("override_persistent_reduction")
+            kernel_kwargs["override_persistent_reduction"] = False
+
+        kernel_kwargs = V.choices.triton_kernel_kwargs(
+            kernel_type, kernel_features, kernel_args, kernel_kwargs
+        )
+        kernel = kernel_type(*kernel_args, **kernel_kwargs)
+        return self.add_multi_kernel_choices(kernel, kernel_args, kernel_kwargs)
+
+    def add_multi_kernel_choices(
+        self,
+        kernel: TritonKernel,
+        kernel_args: list[Any],
+        kernel_kwargs: dict[str, Any],
+    ) -> list[TritonKernel]:
+        kernels: list[TritonKernel] = [kernel]
+        if not config.triton.multi_kernel:
+            return kernels
+
+        optional_persistent = kernel.persistent_reduction and not kernel_kwargs.get(
+            "override_persistent_reduction"
+        )
+        optional_cooperative = kernel.cooperative_reduction and not kernel_kwargs.get(
+            "override_cooperative_reduction"
+        )
+        if optional_persistent:
+            kernels.append(
+                self.kernel_type(
+                    *kernel_args,
+                    **kernel_kwargs,
+                    override_persistent_reduction=False,
+                )
+            )
+        if optional_cooperative:
+            rnumel = kernel.features.reduction_numel
+            # for larger sizes non-cooperative gets very slow
+            if V.graph.sizevars.statically_known_leq(rnumel, 65536):
+                kernels.append(
+                    other := self.kernel_type(
+                        *kernel_args,
+                        **kernel_kwargs,
+                        override_cooperative_reduction=False,
+                    )
+                )
+                if optional_persistent and other.persistent_reduction:
+                    kernels.append(
+                        self.kernel_type(
+                            *kernel_args,
+                            **kernel_kwargs,
+                            override_cooperative_reduction=False,
+                            override_persistent_reduction=False,
+                        )
+                    )
+
+        if len(kernels) > 1:
+            for kernel2 in kernels[1:]:
+                # Keep buffers needed by the non-persistent reduction so both kernels have the same arguments
+                kernel2.must_keep_buffers = kernel.must_keep_buffers
+            # persistent kernels must be generated last so must_keep_buffers works right
+            kernels.sort(key=lambda k: k.persistent_reduction)
+        return kernels
+
+    def benchmark_combo_kernel(self, node_list):
+        mod: ModuleType
+        ms: float
+        ms_clone: float
+
+        def cache_file_path():
+            assert mod.__file__ is not None
+            return os.path.splitext(mod.__file__)[0] + ".kernel_perf"
+
+        def load_cache():
+            path = cache_file_path()
+            if os.path.exists(path):
+                with open(path) as fd:
+                    return tuple(float(e) for e in fd.read().split())
+            return (None, None)
+
+        def store_cache():
+            path = cache_file_path()
+            write_atomic(path, str(ms) + " " + str(ms_clone))
+
+        total_ms, file_list = 0, []
+        total_clone_ms: float = 0.0
+        removed_buffers_orig = V.graph.removed_buffers
+        V.graph.removed_buffers = OrderedSet(removed_buffers_orig)
+        inplaced_to_remove_orig = V.graph.inplaced_to_remove
+        V.graph.inplaced_to_remove = OrderedSet(inplaced_to_remove_orig)
+        enable_autotune = config.combo_kernels_autotune > 0
+        mixed_sizes = config.combo_kernel_allow_mixed_sizes > 0
+        kernel_code_list = self.generate_combo_kernel_code(
+            subkernel_nodes=node_list,
+            custom_part_algorithm=True,
+            enable_autotune=enable_autotune,
+            mixed_sizes=mixed_sizes,
+            only_gen_src_code=True,
+        )
+
+        for src_code, _, node_group in kernel_code_list:
+            fused_node_lists = [node.get_nodes() for node in node_group]
+            names = [n.get_name() for nodes in fused_node_lists for n in nodes]
+
+            src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_")
+            mod = PyCodeCache.load(src_code)
+
+            log.debug(
+                "kernel src code for %s written to: %s",
+                names,
+                mod.__file__,
+            )
+            ms, ms_clone = load_cache()
+            if ms is not None:
+                total_ms += ms  # type: ignore[assignment]
+                total_clone_ms += ms_clone
+                file_list.append(mod.__file__)
+                continue
+
+            args = mod.get_args()
+            call = mod.call
+            wrapped_jit_function = mod.triton_
+
+            # call once to trigger the compilation
+            call(wrapped_jit_function.clone_args(*args)[0])
+
+            launchers = wrapped_jit_function.launchers
+            assert len(launchers) == 1
+            if launchers[0].n_spills > 0:
+                # skip benchmarking the kernel if there are register spills
+                ms = ms_clone = float("inf")
+            else:
+                device = V.graph.get_current_device_or_throw()
+                # We have to clone the inplace updated arguments to avoid earlier calls
+                # generating out of range indices for later calls.
+                ms = benchmarker.benchmark(
+                    lambda: call(wrapped_jit_function.clone_args(*args)[0]),
+                    device=device,
+                )
+                ms_clone = benchmarker.benchmark(
+                    lambda: wrapped_jit_function.clone_args(*args)[0],
+                    device=device,
+                )
+
+            log.debug(
+                "The fused kernel for %s took %.3f ms to run, %.3f ms to clone inputs",
+                OrderedSet(n.get_name() for n in node_group),
+                ms,
+                ms_clone,
+            )
+            store_cache()
+            total_ms += ms
+            total_clone_ms += ms_clone
+            file_list.append(mod.__file__)
+        V.graph.removed_buffers = removed_buffers_orig
+        V.graph.inplaced_to_remove = inplaced_to_remove_orig
+        return total_ms, total_clone_ms, file_list
+
+
+def debug_triton_code(node: BaseSchedulerNode) -> list[str]:
+    lines = []
+    multi_template = node.get_template_node()
+    assert multi_template is None or isinstance(multi_template, ir.MultiTemplateBuffer)
+    if multi_template and multi_template.make_kernel_render is None:
+        lines.append(f"{node.get_name()} Unfinalized multi template buffer")
+    else:
+        from torch._inductor.codegen.cuda_combined_scheduling import (
+            CUDACombinedScheduling,
+        )
+
+        device = node.get_device()
+        assert device is not None
+        backend = node.scheduler.get_backend(device)
+        assert isinstance(backend, (SIMDScheduling, CUDACombinedScheduling)), (
+            f"Scheduling backend should be SIMD or CUDACombined when generating debug Triton strings, got: {type(backend)}"
+        )
+
+        with V.graph.set_current_device(device):
+            # Don't increment kernel count when generating debug string.
+            # This will confuse some unit tests that check the number of
+            # generated kernels.
+            old_generated_kernel_count = metrics.generated_kernel_count
+            triton_code = backend.generate_kernel_code_from_nodes(
+                node.get_nodes()
+            ).strip()
+            metrics.generated_kernel_count = old_generated_kernel_count
+
+        lines.append(f"{node.get_name()} Triton code:")
+        lines.append(textwrap.indent(triton_code, "    "))
+    return lines
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/triton_combo_kernel.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/triton_combo_kernel.py
new file mode 100644
index 0000000000000000000000000000000000000000..74ed7d3797396deb07d5957d67af5bb22930e11a
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/triton_combo_kernel.py
@@ -0,0 +1,1037 @@
+import itertools
+import logging
+import textwrap
+from collections import defaultdict
+from collections.abc import Callable
+from dataclasses import dataclass
+from typing import Any, cast, Optional, Union
+
+import sympy
+from sympy import Integer, Symbol
+
+from torch.utils._ordered_set import OrderedSet
+
+from .. import config, metrics
+from ..runtime.hints import DeviceProperties
+from ..runtime.runtime_utils import next_power_of_2
+from ..runtime.triton_heuristics import (
+    RoundRobinComboKernelGrid,
+    SequentialComboKernelGrid,
+)
+from ..scheduler import BaseSchedulerNode
+from ..utils import Placeholder, triton_version_uses_attrs_dict
+from ..virtualized import V
+from .common import (
+    ArgName,
+    ConstexprArg,
+    DeferredLine,
+    IndentedBuffer,
+    InplacedBuffer,
+    Kernel,
+    PythonPrinter,
+    RemovedArg,
+    SizeArg,
+    WorkspaceArg,
+)
+from .simd import prefix_is_reduction, SIMDScheduling
+from .simd_kernel_features import SIMDKernelFeatures
+from .triton import gen_common_triton_imports, TritonKernel
+from .triton_utils import config_of, equal_1_arg_indices, signature_to_meta
+
+
+log = logging.getLogger(__name__)
+pexpr = PythonPrinter().doprint
+LARGE_NUMELS = 512e5
+BLOCK_UTILIZATION = 0.8
+
+
+def _default_custom_combo_kernel_horizontal_partition(
+    nodes: list[BaseSchedulerNode],
+    triton_scheduling: SIMDScheduling,
+    kernel_map: dict[BaseSchedulerNode, TritonKernel],
+    node_info_map: dict[BaseSchedulerNode, tuple[Any, Any, Any, Any]],
+) -> list[list[BaseSchedulerNode]]:
+    """Horizontally partition the given list of nodes into a list of list of nodes where each sublist
+    represents a partition. Nodes in different partitions are implemented in different combo kernels.
+    Nodes in the same partition are likely to be implemented
+    in the same combo kernel, but subject to subsequent restrictions like CUDA limits for number of args.
+
+    Input arguments:
+        nodes: a list of fused scheduler nodes to partition.
+        triton_scheduling: TritonScheduling instance.
+        kernel_map: a map from node to its kernel.
+        node_info_map: a map from node to (node_schedule, tiled_groups, numel, rnumel).
+    Output:
+        a list of list of nodes with each sublist representing a partition.
+
+    The default algorithm is to partition nodes based on the following rules:
+        1) nodes with the same number of block dimensions are grouped together.
+        2) large pointwise nodes (numels greater than LARGE_NUMELS) are separated from other nodes.
+        3) large reduce nodes are separated from other nodes.
+    """
+
+    assert len(nodes) >= 1
+
+    # first partition nodes based on number of block dimensions
+    tilings = [node_info_map[n][1] for n in nodes]
+
+    max_dims = max(len(t) for t in tilings)
+    nodes_per_ndim: list[list[BaseSchedulerNode]] = []
+    for i in range(2, max_dims + 1):
+        group_per_dim = [n for n, t in zip(nodes, tilings) if len(t) == i]
+        reduction = [
+            n
+            for n in group_per_dim
+            if kernel_map[n].inside_reduction
+            and not (kernel_map[n].persistent_reduction and kernel_map[n].no_x_dim)
+        ]
+        not_reduction = [n for n in group_per_dim if n not in reduction]
+        # rnumel > 2048 usually has long execution time
+        # BaseSchedulerNode.group[-1][-1] is rnumel for reduction nodes
+        long_reduction = [
+            n
+            for n in reduction
+            if (
+                V.graph.sizevars.shape_env.has_hint(n.group[-1][-1])
+                and V.graph.sizevars.size_hint(n.group[-1][-1]) > 2048  # type: ignore[arg-type]
+            )
+        ]
+        short_reduction = [n for n in reduction if n not in long_reduction]
+        if long_reduction:
+            log.debug(
+                "ComboKernels: %d long reduction nodes are separated",
+                len(long_reduction),
+            )
+        large_pointwise = [
+            n
+            for n in not_reduction
+            if not kernel_map[n].inside_reduction
+            and len(kernel_map[n].numels) == 2
+            and V.graph.sizevars.shape_env.has_hint(kernel_map[n].numels["x"])
+            and V.graph.sizevars.size_hint(kernel_map[n].numels["x"]) > LARGE_NUMELS
+        ]
+        if large_pointwise:
+            # TODO benchmark the performance when large pointwise nodes combining with others
+            log.debug(
+                "ComboKernels: %d large pointwise nodes are separated",
+                len(large_pointwise),
+            )
+            not_reduction = [n for n in not_reduction if n not in large_pointwise]
+            nodes_per_ndim.extend([node] for node in large_pointwise)
+
+        nodes_per_ndim.extend(
+            g for g in (not_reduction, short_reduction, long_reduction) if g
+        )
+
+    assert sum(len(p) for p in nodes_per_ndim) == len(nodes)
+    return nodes_per_ndim
+
+
+_custom_combo_kernel_horizontal_partition_algorithm: Callable[
+    [
+        list[BaseSchedulerNode],
+        SIMDScheduling,
+        dict[BaseSchedulerNode, TritonKernel],
+        dict[BaseSchedulerNode, tuple[Any, Any, Any, Any]],
+    ],
+    list[list[BaseSchedulerNode]],
+] = _default_custom_combo_kernel_horizontal_partition
+
+
+def set_custom_combo_kernel_horizontal_partition(
+    algorithm: Callable[
+        [
+            list[BaseSchedulerNode],
+            SIMDScheduling,
+            dict[BaseSchedulerNode, TritonKernel],
+            dict[BaseSchedulerNode, tuple[Any, Any, Any, Any]],
+        ],
+        list[list[BaseSchedulerNode]],
+    ],
+) -> None:
+    """Sets the algorithm used to partition nodes into horizontal partitions. Nodes in different partitions
+    are implemented in different combo kernels. Nodes in the same partition are likely to be implemented
+    in the same combo kernel, but subject to subsequent restricts like CUDA limits for number of args.
+
+    The algorithm should take a list of nodes and return a list of list of nodes.
+
+    The default algorithm is to partition nodes based on number of block dimensions.
+    """
+    global _custom_combo_kernel_horizontal_partition_algorithm
+    _custom_combo_kernel_horizontal_partition_algorithm = algorithm
+
+
+@dataclass
+class PartitionState:
+    partitions: list[list[BaseSchedulerNode]]
+    cur_partition: list[BaseSchedulerNode]
+    cur_count: int
+
+    def finalize(self) -> None:
+        if self.cur_partition:
+            self.partitions.append(self.cur_partition)
+
+
+class ComboKernel(Kernel):
+    @staticmethod
+    def _update_partition(
+        partition_state: PartitionState,
+        node_rw_count: int,
+        node_info: BaseSchedulerNode,
+    ) -> None:
+        if partition_state.cur_count + node_rw_count > config.combo_kernel_max_num_args:
+            partition_state.partitions.append(partition_state.cur_partition)
+            partition_state.cur_partition = [node_info]
+            partition_state.cur_count = node_rw_count
+        else:
+            partition_state.cur_count += node_rw_count
+            partition_state.cur_partition.append(node_info)
+
+    @staticmethod
+    def _base_horizontal_partition(
+        subkernel_nodes: list[BaseSchedulerNode],
+        triton_scheduling: SIMDScheduling,
+        node_info_map: dict[BaseSchedulerNode, tuple[Any, Any, Any, Any]],
+        custom_algorithm: bool,
+    ) -> list[list[BaseSchedulerNode]]:
+        """Generates a list of lists of node info tuples which consist of (fused_nodes, tiling, numel, rnumel)
+        for each subkernel node where each sublist is guaranteed to not exceed CUDA limits for number of args
+        (read/writes) and to have the same 2D or 1D blocking strategy."""
+        # TODO support combination of kernels with different block dimensions
+        assert len(subkernel_nodes) >= 1
+        mixed_sizes = config.combo_kernel_allow_mixed_sizes > 1 or (
+            config.combo_kernel_allow_mixed_sizes == 1 and custom_algorithm
+        )
+
+        ndim_to_partition_state: dict[int, PartitionState] = defaultdict(
+            lambda: PartitionState([], [], 0)
+        )
+        yelem_to_partition_state: dict[int, PartitionState] = defaultdict(
+            lambda: PartitionState([], [], 0)
+        )
+
+        for node in subkernel_nodes:
+            _node_schedule, tiled_groups, _numel, _rnumel = node_info_map[node]
+            node_info = node
+
+            read_writes = node.read_writes
+            read_write_count = len(read_writes.reads) + len(read_writes.writes)
+
+            ndim = len(tiled_groups)
+            assert ndim >= 2, f"Combokernel not support tile {tiled_groups}"
+            if not mixed_sizes and ndim == 3:
+                y_elem = tiled_groups["y"]
+                partition_state = yelem_to_partition_state[y_elem]
+                ComboKernel._update_partition(
+                    partition_state, read_write_count, node_info
+                )
+            else:
+                assert mixed_sizes or ndim <= 3, f"No mixed sizes: tile {tiled_groups}"
+                partition_state = ndim_to_partition_state[ndim]
+                ComboKernel._update_partition(
+                    partition_state, read_write_count, node_info
+                )
+
+        all_partitions = []
+        for partition_state in ndim_to_partition_state.values():
+            partition_state.finalize()
+            all_partitions.extend(partition_state.partitions)
+        for partition_state in yelem_to_partition_state.values():
+            partition_state.finalize()
+            all_partitions.extend(partition_state.partitions)
+
+        return all_partitions
+
+    @staticmethod
+    def horizontal_partition(
+        nodes: list[BaseSchedulerNode],
+        triton_scheduling: SIMDScheduling,
+        kernel_map: dict[BaseSchedulerNode, TritonKernel],
+        node_info_map: dict[BaseSchedulerNode, tuple[Any, Any, Any, Any]],
+        custom_algorithm: bool = False,
+    ) -> list[list[BaseSchedulerNode]]:
+        """Generates a list of lists of node info tuples which consist of (fused_nodes, tiling, numel, rnum)
+        for each subkernel node where each sublist forms a ComboKernel. It horizontally partitions nodes into
+        sublists in the following way:
+            1) call _custom_combo_kernel_horizontal_partition_algorithm() if custom_algorithm is True
+            2) then, call _base_horizontal_partition() to partition nodes into sublists, each sublist is
+               guaranteed to not exceed CUDA limits for number of args (read/writes) and to have the same
+               2D or 1D blocking strategy.
+        """
+        if custom_algorithm:
+            raw_partitions = _custom_combo_kernel_horizontal_partition_algorithm(
+                nodes, triton_scheduling, kernel_map, node_info_map
+            )
+        else:
+            raw_partitions = [nodes]
+
+        """Generates a list of lists of node info tuples which consist of (fused_nodes, tiling, numel, rnumel)
+        for each subkernel node where each sublist is guaranteed to not exceed CUDA limits for number of args
+        (read/writes) and to have the same 2D or 1D blocking strategy."""
+        all_partitions = []
+        for raw_partition in raw_partitions:
+            all_partitions.extend(
+                ComboKernel._base_horizontal_partition(
+                    raw_partition, triton_scheduling, node_info_map, custom_algorithm
+                )
+            )
+        return all_partitions
+
+    class SequentialDispatch:
+        """
+        The dispatcher which dispatches the subkernels in a sequential manner:
+        the blocks are first dispatched to the 1st subkernel (until it is filled),
+        then to the 2nd subkernel, and so on.
+        The class defines the methods specific to the dispatch algorithm.
+        Methods:
+            codegen_pid_range(...): codegen the pid range for each subkernel.
+            grid(...): codegen the grid size for launching the combo kernel.
+        """
+
+        grid_expr = SequentialComboKernelGrid
+
+        @classmethod
+        def codegen_pid_range(
+            cls, kernel: "ComboKernel", num: int, code: IndentedBuffer
+        ) -> None:
+            if num == 0:
+                cls._calculate_xblocks(kernel, code)
+                code.splice(f"if pid < num_xblocks_{num}:")
+                with code.indent():
+                    code.splice("pid_offset = pid")
+            else:
+                code.splice(f"elif pid < num_xblocks_{num}:")
+                with code.indent():
+                    code.splice(f"pid_offset = pid - num_xblocks_{num - 1}")
+
+        @classmethod
+        def _calculate_xblocks(
+            cls, kernel: "ComboKernel", code: IndentedBuffer
+        ) -> None:
+            x_numels_list = kernel.x_numels_list
+            for i in range(len(x_numels_list)):
+                xnumels, no_x_dim = (
+                    (x_numels_list[i], False)
+                    if isinstance(x_numels_list[i], str)
+                    and cast(str, x_numels_list[i])[0] != "-"
+                    or (
+                        isinstance(x_numels_list[i], int)
+                        and cast(int, x_numels_list[i]) > 0
+                    )
+                    else (kernel.min_x_blocks_list[i], True)
+                )
+                xblock_str = (
+                    f"tl.cdiv({xnumels}, XBLOCK)" if not no_x_dim else f"{xnumels}"
+                )
+                if i == 0:
+                    code.splice(f"num_xblocks_{i} = {xblock_str}")
+                else:
+                    code.splice(f"num_xblocks_{i} = num_xblocks_{i - 1} + {xblock_str}")
+
+    class RoundRobinDispatch:
+        """
+        The dispatcher which dispatches the subkernels in a round robin manner:
+        the blocks are interleavedly dispatched to each subkernel to execute them
+        in parallel.
+        The class defines the methods specific to the dispatch algorithm.
+        Methods:
+            codegen_pid_range(...): codegen the pid range for each subkernel.
+            grid(...): codegen the grid size for launching the combo kernel.
+        """
+
+        grid_expr = RoundRobinComboKernelGrid
+
+        @classmethod
+        def codegen_pid_range(
+            cls, kernel: "ComboKernel", num: int, code: IndentedBuffer
+        ) -> None:
+            num_kernels = len(kernel.sub_kernels)
+            if num == 0:
+                cond = "if"
+            else:
+                cond = "elif"
+            code.splice(f"{cond} pid % {num_kernels} == {num}:")
+            with code.indent():
+                code.splice(f"pid_offset = pid // {num_kernels}")
+
+    def __init__(
+        self, enable_autotune: bool = False, mixed_sizes: bool = False
+    ) -> None:
+        super().__init__()
+        self.sub_kernels: list[TritonKernel] = []
+        self.iter_vars_count = itertools.count()
+        self.grids: list[list[int]] = []
+        self.min_x_blocks_list: list[Union[int, str]] = []
+        self.x_numels_list: list[Union[int, str]] = []
+        self.enable_autotune = enable_autotune
+        self.mixed_sizes = mixed_sizes
+        self.dispatch_class: Optional[
+            type[Union[ComboKernel.SequentialDispatch, ComboKernel.RoundRobinDispatch]]
+        ] = None
+        self.block_args: list[str] = []
+        # there following are used when autotuning is disabled
+        self.block_size_1d = 1024  # Try tuning this value
+        self.block_size_2d = 32
+        self.num_warps = 8
+        self.block_size_reduce = 256
+        self.dynamic_shape_args: list[str] = []
+
+    def create_sub_kernel(self, triton_kernel: TritonKernel) -> TritonKernel:
+        sub_kernel = triton_kernel
+        # pyrefly: ignore [bad-assignment]
+        metrics.generated_kernel_count -= 1
+        sub_kernel.args = self.args
+        sub_kernel.iter_vars_count = self.iter_vars_count
+        sub_kernel.cse.iter_buffer_ids = self.cse.iter_buffer_ids
+        self.sub_kernels.append(sub_kernel)
+        return sub_kernel
+
+    @staticmethod
+    def create_triton_kernel(
+        tiling: dict[str, sympy.Expr],
+        features: SIMDKernelFeatures,
+        optimize_mask: bool,
+    ) -> TritonKernel:
+        """
+        Only allow optimize_mask=True when 1) sequential dispatch is used,
+        2) numels except x dimension are the same for each sub kernel.
+        """
+        return TritonKernel(
+            tiling,
+            features=features,
+            pid_cache={"tl.program_id(0)": "pid_offset"},
+            optimize_mask=optimize_mask,
+            # foreach kernels don't work with cooperative reductions
+            override_cooperative_reduction=False,
+        )
+
+    def codegen_static_numels_sub_kernel(
+        self, code: IndentedBuffer, sub_kernel: TritonKernel, num: int
+    ) -> list[str]:
+        """
+        We get a small speedup from hard coding numels if they are static.
+
+        This code stomps on the passed-in values by writing an constant to the top of the kernel.
+
+        In a kernel like:
+        def KERNEL_NAME(in_ptr0, in_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
+
+        We would add
+        xnumel = 4096
+        rnumel = 768
+
+        After the signature, before the kernel code, if we decided to make these static. As its hardcoded, it becomes
+        a better signal to triton on how to unroll and do some static indexing. So, it's not so much that downstream
+        knows that its a static numel, as that you just plop a constant into the kernel.
+        """
+        grid = []
+        uniquify_block_sizes = []
+        for tree in sub_kernel.range_trees:
+            simplified_tree_numel = V.graph.sizevars.simplify(tree.numel)
+            if isinstance(simplified_tree_numel, (Integer, int)):
+                code.writeline(f"{tree.prefix}numel = {int(simplified_tree_numel)}")
+            else:
+                assert f"{tree.prefix}numel_{num}" in self.dynamic_shape_args
+                uniquify_block_sizes.append(f"{tree.prefix}numel")
+
+            # pyrefly: ignore [missing-argument]
+            if not tree.is_reduction:
+                if isinstance(simplified_tree_numel, (Integer, int)):
+                    grid.append(int(simplified_tree_numel))
+                else:
+                    # pyrefly: ignore [bad-argument-type]
+                    grid.append(f"{tree.prefix}numel_{num}")
+
+            if tree.is_reduction and sub_kernel.persistent_reduction:
+                if isinstance(simplified_tree_numel, (Integer, int)):
+                    val = int(simplified_tree_numel)
+                else:
+                    raise RuntimeError(
+                        "Dynamic shape on reduction dimension is not supported"
+                    )
+                val = next_power_of_2(val)
+                code.writeline(
+                    f"{tree.prefix.upper()}BLOCK_{num}: tl.constexpr = {val}"
+                )
+                uniquify_block_sizes.append(f"{tree.prefix.upper()}BLOCK")
+
+            if tree.prefix == "x" and sub_kernel.no_x_dim:
+                code.writeline(f"XBLOCK_{num}: tl.constexpr = 1")
+                uniquify_block_sizes.append("XBLOCK")
+        self.grids.append(grid)
+        return uniquify_block_sizes
+
+    def min_x_blocks_sub_kernel(self, sub_kernel: TritonKernel, num: int) -> None:
+        """
+        Kernels with no_x_dim being true has no tunable XBLOCK. They have a fixed number of X blocks.
+        Grid calculation needs to make sure that they are assigned with enough number of blocks.
+        """
+        min_x_blocks: Union[int, str] = 0
+        x_numels: Union[int, str] = 0
+        for tree in sub_kernel.range_trees:
+            simplified_tree_numel = V.graph.sizevars.simplify(tree.numel)
+            if tree.prefix == "x":
+                if isinstance(simplified_tree_numel, (Integer, int)):
+                    x_numels = int(simplified_tree_numel)
+                else:
+                    x_numels = f"{tree.prefix}numel_{num}"
+                if sub_kernel.no_x_dim:
+                    min_x_blocks = x_numels
+                    x_numels = (
+                        # pyrefly: ignore [unsupported-operation]
+                        -min_x_blocks
+                        if isinstance(x_numels, int)
+                        # pyrefly: ignore [redundant-cast]
+                        else "-" + cast(str, x_numels)
+                    )
+                else:
+                    if isinstance(simplified_tree_numel, (Integer, int)):
+                        x_numels = int(simplified_tree_numel)
+                    else:
+                        x_numels = f"{tree.prefix}numel_{num}"
+        self.min_x_blocks_list.append(min_x_blocks)
+        self.x_numels_list.append(x_numels)
+
+    def select_heuristics(self, sub_kernel: TritonKernel) -> tuple[str, dict[str, int]]:
+        size_hints = {
+            prefix: next_power_of_2(
+                V.graph.sizevars.size_hint(
+                    numel, fallback=config.unbacked_symint_fallback
+                )
+            )
+            for prefix, numel in sub_kernel.numels.items()
+            if not prefix_is_reduction(prefix) or sub_kernel.inside_reduction
+        }
+        if sub_kernel.persistent_reduction:
+            assert sub_kernel.inside_reduction
+            heuristics = "persistent_reduction"
+        elif sub_kernel.inside_reduction:
+            heuristics = "reduction"
+        else:
+            heuristics = "pointwise"
+        return heuristics, size_hints
+
+    def select_combo_heuristics(
+        self, heuristics_list: list[str], size_hints_list: list[dict[str, int]]
+    ) -> tuple[str, dict[str, int], TritonKernel]:
+        if not self.enable_autotune:
+            return "foreach", size_hints_list[0], self.sub_kernels[0]
+        if "reduction" in heuristics_list:
+            i, _ = max(
+                enumerate(size_hints_list),
+                key=lambda x: x[1]["x"] if heuristics_list[x[0]] == "reduction" else 0,
+            )
+            return heuristics_list[i], size_hints_list[i], self.sub_kernels[i]
+        elif "pointwise" in heuristics_list:
+            i, _ = max(
+                enumerate(size_hints_list),
+                key=lambda x: x[1]["x"] if heuristics_list[x[0]] == "pointwise" else 0,
+            )
+            # modify size_hint to avoid oom check fail (may be a false alarm)
+            num_pointwise = len([e for e in heuristics_list if e == "pointwise"])
+            num_reduction = len([e for e in heuristics_list if e == "reduction"])
+            num_persistent_reduction = len(
+                [e for e in heuristics_list if e == "persistent_reduction"]
+            )
+            assert num_reduction == 0, (
+                "combining pointwise and reduction are not supported yet."
+            )
+            heuristics = (
+                "pointwise_with_reduction"
+                if num_persistent_reduction > 0
+                else "pointwise"
+            )
+            if len(heuristics_list) - num_pointwise >= 4:
+                size_hints = size_hints_list[i]
+                size_hints["x"] = min(128, size_hints["x"])
+            return heuristics, size_hints_list[i], self.sub_kernels[i]
+        else:
+            return heuristics_list[0], size_hints_list[0], self.sub_kernels[0]
+
+    def get_mutated_args_sub_kernels(self) -> list[str]:
+        mutated_args: OrderedSet[str] = OrderedSet()
+        for sub_kernel in self.sub_kernels:
+            for mutation in sub_kernel.mutations:
+                if mutation in sub_kernel.args.input_buffers:
+                    mutated_args.add(sub_kernel.args.input_buffers[mutation])
+                if (
+                    mutation in sub_kernel.args.inplace_buffers
+                    and mutation not in V.graph.removed_buffers
+                    and mutation not in sub_kernel.removed_buffers
+                ):
+                    mutated_args.add(
+                        cast(
+                            InplacedBuffer, sub_kernel.args.inplace_buffers[mutation]
+                        ).inner_name
+                    )
+                if mutation in sub_kernel.args.output_buffers:
+                    arg = sub_kernel.args.output_buffers[mutation]
+                    assert not isinstance(arg, RemovedArg)
+                    mutated_args.add(arg)
+        return sorted(mutated_args)
+
+    def select_dispatch_strategy(self) -> None:
+        if self.dispatch_class is not None:
+            return
+        # mixed_sizes is used for optimize_mask, so it only allows sequential dispatch
+        # Not mixed sizes on y dim technically is ok to use round robin as wells.
+        if not self.mixed_sizes or any(isinstance(e, str) for e in self.x_numels_list):
+            # str in x_numels_list means a dynamic shape
+            self.dispatch_class = ComboKernel.SequentialDispatch
+            return
+        # A negative x_blocks_list element means the kernel is not tunable,
+        # i.e., no_x_dim = True
+        x_numels_list = [abs(cast(int, e)) for e in self.x_numels_list]
+        total = max(x_numels_list) * len(x_numels_list)
+        needed = sum(x_numels_list)
+        if needed / total > BLOCK_UTILIZATION:
+            # Introduced overhead (masked blocks) is less than 20%
+            self.dispatch_class = ComboKernel.RoundRobinDispatch
+        else:
+            self.dispatch_class = ComboKernel.SequentialDispatch
+
+    def jit_line(
+        self,
+        heuristics: str,
+        size_hints: dict[str, int],
+        selected_kernel: TritonKernel,
+        signature: list[Any],
+        argdefs: list[ArgName],
+        pointwise_with_reduce: bool = False,
+    ) -> str:
+        can_use_32bit = all(k.index_dtype == "tl.int32" for k in self.sub_kernels)
+        size_dtype = "tl.int32" if can_use_32bit else "tl.int64"
+        for i, sub in enumerate(self.sub_kernels):
+            self.min_x_blocks_sub_kernel(sub, i)
+        self.select_dispatch_strategy()
+        triton_meta = {
+            "signature": signature_to_meta(
+                signature, size_dtype=size_dtype, argdefs=argdefs
+            ),
+            "device": DeviceProperties.create(V.graph.get_current_device_or_throw()),
+            "constants": {},
+        }
+
+        for arg_num in equal_1_arg_indices(signature):
+            triton_meta["constants"][signature[arg_num].name] = 1  # type: ignore[index,union-attr]
+
+        # pyrefly: ignore [unsupported-operation]
+        triton_meta["configs"] = [config_of(signature)]
+        mutated_args = self.get_mutated_args_sub_kernels()
+        dispatch = self.dispatch_class
+        assert dispatch is not None
+        inductor_meta = {
+            "grid_type": dispatch.grid_expr.__name__,
+            "combo_grid_meta": self.combo_grid_meta(),
+            "kernel_name": str(Placeholder.DESCRIPTIVE_NAME),
+            "mutated_arg_names": mutated_args,
+            **TritonKernel.inductor_meta_common(),
+        }
+
+        sub_kernel = selected_kernel
+        if heuristics == "foreach":
+            heuristics_line = f"""
+                @triton_heuristics.foreach(
+                    filename=__file__,
+                    triton_meta={triton_meta!r},
+                    inductor_meta={inductor_meta!r},
+                )
+                @triton.jit
+            """
+        elif sub_kernel.inside_reduction:
+            reduction_hint = sub_kernel.features.get_reduction_hint()
+            heuristics_line = f"""
+                @triton_heuristics.{heuristics}(
+                    size_hints={size_hints!r},
+                    reduction_hint={reduction_hint},
+                    filename=__file__,
+                    triton_meta={triton_meta!r},
+                    inductor_meta={inductor_meta!r}
+                )
+                @triton.jit
+            """
+        else:
+            tile_hint = ""
+            if len(size_hints) == 2:
+                tile_hint = "tile_hint=TileHint.SQUARE,"
+            else:
+                tile_hint = "tile_hint=TileHint.DEFAULT,"
+            heuristics_line = f"""
+                @triton_heuristics.{heuristics}(
+                    size_hints={size_hints!r}, {tile_hint}
+                    filename=__file__,
+                    triton_meta={triton_meta!r},
+                    inductor_meta={inductor_meta!r}
+                )
+                @triton.jit
+            """
+
+        return heuristics_line
+
+    def codegen_blocks(self, code: IndentedBuffer) -> None:
+        for block in self.block_args:
+            assert block in (
+                "XBLOCK",
+                "YBLOCK",
+                "R0_BLOCK",
+            ), f"{block} is not supported without autotuning"
+        if "YBLOCK" in self.block_args:
+            code.splice(f"XBLOCK: tl.constexpr = {self.block_size_2d}")
+            code.splice(f"YBLOCK: tl.constexpr = {self.block_size_2d}")
+        else:
+            code.splice(f"XBLOCK: tl.constexpr = {self.block_size_1d}")
+        if "R0_BLOCK" in self.block_args:
+            code.splice(f"R0_BLOCK: tl.constexpr = {self.block_size_reduce}")
+            code.splice(f"RBLOCK: tl.constexpr = {self.block_size_reduce}")
+
+    def get_block_args(self) -> list[ConstexprArg]:
+        """
+        Calculate blocks from sub_kernels and range_trees.
+        **Update self.block_args**
+        Return the block args
+        """
+        block_names = {}
+        for sub_kernel in self.sub_kernels:
+            # TODO: we assume all sub_kernels have the same block size
+            for tree in sub_kernel.range_trees:
+                # pyrefly: ignore [missing-argument]
+                if tree.is_reduction and (
+                    not sub_kernel.inside_reduction or sub_kernel.persistent_reduction
+                ):
+                    continue
+                if tree.prefix == "x" and sub_kernel.no_x_dim:
+                    continue
+                block_names[f"{tree.prefix.upper()}BLOCK"] = tree.prefix
+        self.block_args = list(block_names.keys())
+
+        return [ConstexprArg(x) for x in block_names]
+
+    def add_numel_to_args(
+        self, argdefs: list[ArgName], signature: list[Any]
+    ) -> list[ArgName]:
+        for num, sub_kernel in enumerate(self.sub_kernels):
+            for tree in sub_kernel.active_range_trees():
+                if not isinstance(tree.numel, (Integer, int)):
+                    # only if it is a dynamic shape
+                    sizearg = SizeArg(f"{tree.prefix}numel_{num}", tree.numel)
+                    signature.append(sizearg)
+                    argdefs.append(ArgName(f"{tree.prefix}numel_{num}"))
+                    self.dynamic_shape_args.append(f"{tree.prefix}numel_{num}")
+        return argdefs
+
+    def add_numel_to_call_args(
+        self, name: str, call_args: list[Any], arg_types: list[Any]
+    ) -> None:
+        for num, sub_kernel in enumerate(self.sub_kernels):
+            for tree in sub_kernel.range_trees:
+                numel_name = f"{tree.prefix}numel_{num}"
+                if numel_name not in self.dynamic_shape_args:
+                    continue
+                if isinstance(tree.numel, (Integer, Symbol)):
+                    expr = tree.numel
+                else:
+                    expr = V.graph.wrapper_code.generate_numel_expr(
+                        name, tree, suffix=str(num)
+                    )
+                # pyrefly: ignore [missing-argument]
+                if not tree.is_reduction or sub_kernel.inside_reduction:
+                    call_args.append(expr)
+                    arg_types.append(type(expr))
+
+    def kernel_benchmark_extra_args(self) -> list[str]:
+        extra_args = []
+        for num, sub_kernel in enumerate(self.sub_kernels):
+            for tree in sub_kernel.range_trees:
+                numel_name = f"{tree.prefix}numel_{num}"
+                if numel_name not in self.dynamic_shape_args:
+                    continue
+                # pyrefly: ignore [missing-argument]
+                if not tree.is_reduction or sub_kernel.inside_reduction:
+                    extra_args.append(
+                        str(
+                            V.graph.sizevars.size_hint(
+                                tree.numel, fallback=config.unbacked_symint_fallback
+                            )
+                        )
+                    )
+        return extra_args
+
+    def codegen_kernel(self, name: Optional[str] = None) -> str:
+        # TODO: is it correct to use the first sub kernel's heuristics?
+        heuristics_list, size_hints_list = [], []
+        for subkernel in self.sub_kernels:
+            h, s = self.select_heuristics(subkernel)
+            heuristics_list.append(h)
+            size_hints_list.append(s)
+        heuristics, size_hints, selected_kernel = self.select_combo_heuristics(
+            heuristics_list, size_hints_list
+        )
+        pointwise_with_reduction, heuristics = (
+            (True, "pointwise")
+            if heuristics == "pointwise_with_reduction"
+            else (False, heuristics)
+        )
+        code = IndentedBuffer()
+
+        code.splice(gen_common_triton_imports())
+        if config.benchmark_combo_kernel:
+            code.splice(self.imports_for_benchmark_kernel())
+
+        seen_helpers: OrderedSet[str] = OrderedSet()
+        for sub_kernel in self.sub_kernels:
+            for helper in sub_kernel.helper_functions:
+                if helper not in seen_helpers:
+                    code.writeline("")
+                    code.splice(helper)
+                    seen_helpers.add(helper)
+
+        argdefs, _, signature, _ = self.args.python_argdefs()
+        argdefs = self.add_numel_to_args(argdefs, signature)
+        block_args = self.get_block_args()
+        if self.enable_autotune:
+            argdefs.extend([ArgName(x.name, is_constexpr=True) for x in block_args])
+            if triton_version_uses_attrs_dict():
+                signature.extend(block_args)
+
+        code.splice(
+            self.jit_line(
+                heuristics,
+                size_hints,
+                selected_kernel,
+                pointwise_with_reduce=pointwise_with_reduction,
+                signature=signature,
+                argdefs=argdefs,
+            )
+        )
+        code.writeline(
+            f"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(x.full_name() for x in argdefs)}):"
+        )
+
+        with code.indent():
+            code.splice("pid = tl.program_id(0)")
+            if not self.enable_autotune:
+                self.codegen_blocks(code)
+
+            for num, sub_kernel in enumerate(self.sub_kernels):
+                assert self.dispatch_class is not None
+                self.dispatch_class.codegen_pid_range(self, num, code)
+                with code.indent():
+                    uniquify = self.codegen_static_numels_sub_kernel(
+                        code, sub_kernel, num
+                    )
+                    sub_kernel.codegen_body()
+                    uniquified_body = self.uniquify_block_sizes(
+                        sub_kernel.body, num, uniquify
+                    )
+                    code.splice(uniquified_body)
+
+            code.splice("else:")
+            with code.indent():
+                code.splice("pass")
+
+        if config.benchmark_combo_kernel:
+            code.splice(self.codegen_kernel_benchmark(num_gb=0))
+
+        return code.getvalue()
+
+    def codegen_kernel_benchmark(self, num_gb: float) -> IndentedBuffer:
+        """
+        Generates Python code for benchmarking this combo kernel.
+        - Creates example inputs (random tensors, constants, sizes).
+        - Runs the kernel on the current GPU/stream.
+        - Prints runtime (ms) and throughput (GB/s) using `num_gb`.
+        Args:
+            num_gb (float): The number of gigabytes to use for throughput calculation.
+        Returns:
+            IndentedBuffer: A buffer containing the generated Python benchmark code.
+        """
+        result = IndentedBuffer()
+        _argdefs, call_args, signature, _ = self.args.python_argdefs()
+        result.writelines(["", "", "def get_args():"])
+        with result.indent():
+            name_cnt = itertools.count()
+            var_names = []
+            for arg_name, arg_sig in zip(call_args, signature):
+                var_name = f"arg_{next(name_cnt)}"
+                buf = V.graph.try_get_buffer(arg_name)
+                if buf:
+                    size = V.graph.sizevars.size_hints(
+                        buf.get_size(), fallback=config.unbacked_symint_fallback
+                    )
+                    stride = V.graph.sizevars.size_hints(
+                        buf.get_stride(), fallback=config.unbacked_symint_fallback
+                    )
+                    result.writeline(
+                        f"{var_name} = rand_strided({size}, {stride}, device='{buf.get_device()}', dtype={buf.get_dtype()})"  # noqa: B950 line too long
+                    )
+                elif arg_name in V.graph.constants:
+                    # note that random seed is put in V.graph.constants
+                    const_tensor = V.graph.constants[arg_name]
+                    size = V.graph.sizevars.size_hints(
+                        const_tensor.size(), fallback=config.unbacked_symint_fallback
+                    )
+                    stride = V.graph.sizevars.size_hints(
+                        const_tensor.stride(), fallback=config.unbacked_symint_fallback
+                    )
+                    result.writeline(
+                        f"{var_name} = rand_strided({size}, {stride}, device='{const_tensor.device}', dtype={const_tensor.dtype})"  # type: ignore[arg-type]  # noqa: B950 line too long
+                    )
+                elif isinstance(arg_sig, SizeArg):
+                    symval_hint = V.graph.sizevars.size_hint(arg_sig.expr)
+
+                    # Force the seed_offset to be 0 so calls to the same kernel
+                    # using different seed offset will have the same benchmark harness.
+                    # We can dedup kernel definitions in this case.
+                    if "seed_offset" in arg_sig.name:
+                        symval_hint = 0
+                    result.writeline(f"{var_name} = {symval_hint}")
+                elif isinstance(arg_sig, WorkspaceArg):
+                    device = V.graph.get_current_device_or_throw()
+                    count = V.graph.sizevars.size_hint(arg_sig.count)
+                    # for benchmark harness, we ignore arg_sig.zero_mode and always zero it
+                    result.writeline(
+                        f"{var_name} = torch.zeros({count}, device='{device}', dtype={arg_sig.dtype})"
+                    )
+                else:
+                    raise KeyError(
+                        f"Don't find the buffer or const tensor for {arg_name}"
+                    )
+                var_names.append(var_name)
+            if self.dynamic_shape_args:
+                var_names.extend(self.kernel_benchmark_extra_args())
+            result.writeline(f"return {', '.join(var_names)},")
+
+        result.writelines(["\n", "\n", "def call(args):"])
+        device = V.graph.get_current_device_or_throw()
+        index = V.graph.get_current_device_or_throw().index
+        with result.indent():
+            result.writeline(f"with {V.graph.device_ops.device_guard(index)}:")
+            with result.indent():
+                result.writeline(
+                    V.graph.device_ops.set_device(index)
+                )  # no-op to ensure context
+                stream_name = f"stream{index}"
+                result.writeline(f"{stream_name} = get_raw_stream({index})")
+                result.writeline(
+                    f"{str(Placeholder.KERNEL_NAME)}.run(*args, stream={stream_name})"
+                )
+
+        # benchmark all configs
+        result.writelines(["\n", "\n", "def benchmark_all_configs(args):"])
+        with result.indent():
+            result.writeline(f"with {V.graph.device_ops.device_guard(index)}:")
+            with result.indent():
+                result.writeline(
+                    V.graph.device_ops.set_device(index)
+                )  # no-op to ensure context
+                result.writeline(
+                    f"return {str(Placeholder.KERNEL_NAME)}.benchmark_all_configs(*args)"
+                )
+
+        result.writelines(["\n", "\n", "if __name__ == '__main__':"])
+        with result.indent():
+            result.writeline(
+                "from torch._inductor.runtime.benchmarking import benchmarker"
+            )
+            result.writeline("")
+
+            result.writeline("args = get_args()")
+            result.writeline(
+                f"ms = benchmarker.benchmark(call, fn_args=(args,), device={device.type},rep=40)"
+            )
+            result.writeline(f"num_gb = {num_gb}")
+            result.writeline("gb_per_s = num_gb / (ms / 1e3)")
+            result.writeline(
+                'print(f"{ms:.3f}ms    {num_gb:.3f}GB    {gb_per_s:.2f}GB/s")'
+            )
+
+        return result
+
+    def imports_for_benchmark_kernel(self) -> str:
+        return textwrap.dedent(
+            """
+            from torch._dynamo.testing import rand_strided
+            {}
+            import torch
+        """.format(V.graph.device_ops.import_get_raw_stream_as("get_raw_stream"))
+        )
+
+    def uniquify_block_sizes(
+        self, code: IndentedBuffer, num_kernel: int, uniquify: list[str]
+    ) -> IndentedBuffer:
+        if not uniquify:
+            return code
+        modified = IndentedBuffer(initial_indent=code._indent)
+        for line in code._lines:
+            if isinstance(line, str) and (blocks := [e for e in uniquify if e in line]):
+                modified_line = line
+                for block in blocks:
+                    modified_line = modified_line.replace(
+                        block, f"{block}_{num_kernel}"
+                    )
+                modified.writeline(modified_line)
+            elif isinstance(line, DeferredLine) and (
+                blocks := [e for e in uniquify if e in line.line]
+            ):
+                modified_line = line.line
+                for block in blocks:
+                    modified_line = modified_line.replace(
+                        block, f"{block}_{num_kernel}"
+                    )
+                new_line = DeferredLine(line.name, modified_line)
+                modified.writeline(new_line)
+            else:
+                modified.writeline(line)
+        return modified
+
+    def call_kernel(self, code: IndentedBuffer, name: str) -> None:
+        _, call_args, _, arg_types = self.args.python_argdefs()
+
+        wrapper = V.graph.wrapper_code
+        assert self.dispatch_class is not None
+        if self.dynamic_shape_args:
+            self.add_numel_to_call_args(name, call_args, arg_types)
+
+        wrapper.generate_kernel_call(
+            name,
+            call_args,
+            triton=True,
+            arg_types=arg_types,
+        )
+
+    def combo_grid_meta(self) -> dict[str, Any]:
+        dynamic_shape = bool(self.dynamic_shape_args)
+        num_kernels = len(self.sub_kernels)
+        min_blocks = (
+            max(self.min_x_blocks_list) * num_kernels if not dynamic_shape else None
+        )
+
+        if not self.enable_autotune:
+            if "YBLOCK" in self.block_args:
+                default_config = {
+                    "XBLOCK": self.block_size_2d,
+                    "YBLOCK": self.block_size_2d,
+                }
+            else:
+                default_config = {"XBLOCK": self.block_size_1d}
+        else:
+            default_config = None
+
+        meta = {
+            "num_kernels": num_kernels,
+            "min_blocks": min_blocks,
+            "default_config": default_config,
+        }
+
+        for num, sub_kernel in enumerate(self.sub_kernels):
+            meta[f"no_x_dim_{num}"] = sub_kernel.no_x_dim
+            for tree in sub_kernel.range_trees:
+                # pyrefly: ignore [missing-argument]
+                if not tree.is_reduction:
+                    numel_name = f"{tree.prefix}numel_{num}"
+                    if numel_name in self.dynamic_shape_args:
+                        meta[numel_name] = None
+                    else:
+                        meta[numel_name] = int(V.graph.sizevars.simplify(tree.numel))
+
+        return meta
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/triton_split_scan.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/triton_split_scan.py
new file mode 100644
index 0000000000000000000000000000000000000000..0abee5439393980560347aa07f6baf3f24f3e35f
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/triton_split_scan.py
@@ -0,0 +1,224 @@
+# mypy: allow-untyped-defs
+import functools
+from typing import Union
+
+import sympy
+
+from torch._inductor import config
+from torch._inductor.codegen.simd import IterationRangesRoot, prefix_is_reduction
+from torch._inductor.codegen.triton import (
+    triton_compute_type,
+    TritonCSEVariable,
+    TritonKernel,
+)
+from torch._inductor.runtime.triton_heuristics import SplitScanGrid
+from torch.utils._ordered_set import OrderedSet
+from torch.utils._sympy.functions import CeilDiv
+
+from ..utils import sympy_product
+
+
+class TritonSplitScanKernel(TritonKernel):
+    """Generates a triton kernel that supports ops.scan calls while also splitting
+    the reduction dimension over multiple triton programs.
+
+    For this kernel, loop numels will always take the form ``(xdim, rdim)``
+    and the grid has the shape ``(CeilDiv(rdim, RBLOCK), xdim)``. Communication
+    between blocks occurs within a global memory workspace buffer, which
+    must be zero-filled before launching the kernel.
+
+    Note that generation for ``ops.reduction`` is not supported.
+
+    For details of the communication strategy, see
+    https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back
+
+    """
+
+    def __init__(
+        self,
+        tiling: dict[str, sympy.Expr],
+        pid_cache=None,
+        fixed_config=None,
+        **kwargs,
+    ) -> None:
+        assert pid_cache is None, "not supported"
+        assert fixed_config is None, "not supported"
+        super().__init__(
+            tiling,
+            **kwargs,
+        )
+        self.no_x_dim = True
+
+    def should_use_persistent_reduction(self) -> bool:
+        return False
+
+    def should_use_cooperative_reduction(self) -> bool:
+        return False
+
+    def initialize_range_tree(self, pid_cache):
+        prefixes = ["y", "x", "r0_"]
+        assert len(self.numels) <= len(prefixes), (
+            "z dimension not supported for split scan"
+        )
+        active_prefixes = prefixes[len(prefixes) - len(self.numels) :]
+
+        grid_dims = {"r0_": 0, "x": 1, "y": 2}
+        for prefix in active_prefixes:
+            numel = self.numels[prefix]
+            tensor_dim = 0 if prefix_is_reduction(prefix) else None
+            grid_dim = grid_dims[prefix]
+            self.range_trees.append(
+                IterationRangesRoot(
+                    f"{prefix}index",
+                    numel,
+                    prefix,
+                    grid_dim,
+                    self,  # type: ignore[arg-type]
+                    pid_cache=pid_cache,
+                    is_loop=False,
+                    tensor_dim=tensor_dim,
+                    grid_dim=grid_dim,
+                    has_zdim=False,
+                )
+            )
+
+    def reduction(self, dtype, src_dtype, reduction_type, value):
+        raise NotImplementedError("NYI TritonSplitDimKernel reductions")
+
+    def scan(self, dtypes, combine_fn, values):
+        """
+        Perform an associative scan on 'values'.
+        """
+        import triton.language as tl
+
+        (dtype,) = dtypes
+        (value,) = values
+
+        compute_type = triton_compute_type(dtype)
+        compute_type_triton = getattr(tl, compute_type[3:])
+
+        element_nbits = compute_type_triton.primitive_bitwidth
+
+        scratch_type = "tl.uint32" if element_nbits <= 16 else "tl.uint64"
+        scratch_type_triton = getattr(tl, scratch_type[3:])
+        scratch_elems_per_block = 3 if element_nbits == 64 else 1
+        scratch_nbytes_per_block = scratch_elems_per_block * (
+            scratch_type_triton.primitive_bitwidth // 8
+        )
+
+        cse_load = functools.partial(self.cse.generate, self.loads, dtype=dtype)
+        cse_compute = functools.partial(self.cse.generate, self.compute)
+
+        assert len(self.numels) == 2, "Unexpected tiling"
+        min_rblock = config.triton.min_split_scan_rblock
+        reduction_numel = sympy_product(
+            numel
+            for prefix, numel in self.numels.items()
+            if prefix_is_reduction(prefix)
+        )
+        pointwise_numel = sympy_product(
+            numel
+            for prefix, numel in self.numels.items()
+            if not prefix_is_reduction(prefix)
+        )
+        max_blocks = pointwise_numel * CeilDiv(reduction_numel, min_rblock)
+        nbytes = scratch_nbytes_per_block * max_blocks
+        scratch_base: Union[str, TritonCSEVariable]
+        scratch_base, _, offset = self.args.workspace(nelem=nbytes, zero_fill=True)
+        if offset != 0:
+            scratch_base = cse_load(
+                f"{scratch_base} + {self.index_to_str(offset)}", shape=()
+            )
+        runtime_rblocks = cse_load(
+            f"tl.num_programs({self.range_trees[-1].index})", shape=()
+        )
+        scratch_base = cse_load(
+            f"{scratch_base}.to(tl.pointer_type({scratch_type})) + xoffset * "
+            f"{scratch_elems_per_block} * {runtime_rblocks}",
+            shape=(),
+        )
+
+        masks = OrderedSet(f"{tree.prefix}mask" for tree in self.range_trees)
+        self.filter_masks(masks)
+        assert not self._load_mask, "ops.scan not supported inside ops.masked"
+
+        value = cse_compute(
+            f"{value}.to({compute_type})",
+            dtype=dtype,
+            shape=value.shape,
+        )
+        value = cse_compute(
+            f"tl.broadcast_to({value}, {self.dense_size_str()})",
+            dtype=dtype,
+            shape=self.dense_size_list(),
+        )
+
+        combine_helper_fn = self._lift_helper(combine_fn, (value,), (dtype,))
+        dim = self.triton_tensor_ndim() - 1
+        assert dim == 0, ""
+        shape = list(self.dense_size_list())
+        del shape[dim]
+
+        block_sum = cse_compute(
+            f"tl.reduce({value}, {dim}, {combine_helper_fn})",
+            dtype=dtype,
+            shape=shape,
+        )
+        exclusive_prefix = self.cse.newvar(
+            dtype=dtype,
+            shape=shape,
+        )
+        if element_nbits == 64:
+            self.compute.splice(
+                f"""
+                {exclusive_prefix} = triton_helpers.exclusive_scan_decoupled_lookback_64(
+                    {scratch_base},
+                    {block_sum},
+                    {self.iteration_ranges_get_pid(self.range_trees[-1])},
+                    {combine_helper_fn},
+                )
+                """,
+                strip=True,
+            )
+
+        else:
+            assert element_nbits <= 32
+            value_as_uint_dtype = f"tl.uint{element_nbits}"
+
+            self.compute.splice(
+                f"""
+                {exclusive_prefix} = triton_helpers.exclusive_scan_decoupled_lookback(
+                    {scratch_base},
+                    {block_sum},
+                    {self.iteration_ranges_get_pid(self.range_trees[-1])},
+                    {combine_helper_fn},
+                    DTYPE_VALUE_AS_UINT={value_as_uint_dtype},
+                    DTYPE_PACK={scratch_type},
+                )
+                """,
+                strip=True,
+            )
+        # Compute final cumsum
+        block_scan = cse_compute(
+            f"tl.associative_scan({value}, {dim}, {combine_helper_fn})",
+            dtype=dtype,
+            shape=shape,
+        )
+        combined_result = cse_compute(
+            f"{combine_helper_fn}({exclusive_prefix}, {block_scan})",
+            dtype=dtype,
+            shape=shape,
+        )
+        return (
+            cse_compute(
+                f"tl.where(roffset == 0, {block_scan}, {combined_result})",
+                dtype=dtype,
+                shape=block_scan.shape,
+            ),
+        )
+
+    def _get_heuristic(self):
+        return "split_scan"
+
+    def _get_grid_type(self) -> type[SplitScanGrid]:
+        return SplitScanGrid
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/triton_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/triton_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..75a34813c876b2e8fa11cb14cac60b761636973e
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/triton_utils.py
@@ -0,0 +1,265 @@
+# mypy: allow-untyped-defs
+from typing import Any, Optional
+
+import sympy
+
+import torch
+from torch.utils._sympy.symbol import symbol_is_type, SymT
+
+from .. import config
+from ..runtime.hints import AttrsDescriptorWrapper
+from ..utils import _type_of, expr_fits_within_32bit, triton_version_uses_attrs_dict
+from ..virtualized import V
+from .common import (
+    ArgName,
+    ConstexprArg,
+    KernelArgType,
+    SizeArg,
+    TensorArg,
+    TMADescriptorArg,
+    WorkspaceArg,
+)
+
+
+def should_unwrap_unspec_arg(name: str):
+    if V.graph.is_unspec_arg(name):
+        # Unwrap on all devices except CPU
+        if V.graph.get_current_device_or_throw().type != "cpu":
+            return True
+        # Only unwrap on CPU if the input is not used as an output
+        if name not in V.graph.mutated_buffers:
+            return True
+    return False
+
+
+def signature_of(arg: KernelArgType, *, size_dtype: Optional[str]) -> str:
+    if isinstance(arg, TensorArg):
+        # TODO: Remove fp8 special handling when Triton supports PyTorch fp8 dtypes.
+        # Related PR: https://github.com/triton-lang/triton/pull/2279/
+        if arg.dtype == torch.float8_e4m3fn:
+            typ = "*fp8e4nv"
+        elif arg.dtype == torch.float8_e5m2:
+            typ = "*fp8e5"
+        elif arg.dtype == torch.float8_e4m3fnuz:
+            typ = "*fp8e4b8"
+        elif arg.dtype == torch.float8_e5m2fnuz:
+            typ = "*fp8e5b16"
+        else:
+            typ = _type_of(arg.dtype)
+        if should_unwrap_unspec_arg(arg.buffer):
+            # had unwrapped 0d tensor as scalar
+            new_typ = typ.lstrip("*")
+            if new_typ in ["fp16", "bf16"]:
+                return "fp32"
+            else:
+                return new_typ
+        else:
+            return typ
+    if isinstance(arg, SizeArg):
+        if arg.expr is None:
+            if triton_version_uses_attrs_dict():
+                # In newer versions of Triton, the signature includes "None" args
+                # and their type is marked as "constexpr"
+                return "constexpr"
+            else:
+                # In older versions of Triton...
+                # From triton/runtime/jit.py
+                # `None` is nullptr.  Implicitly convert to *i8.
+                return "*i8"
+        elif _arg_equals_1(arg) and triton_version_uses_attrs_dict():
+            # In new versions of Triton, if we have an equal-to-1 arg that's marked as a constant,
+            # it should be marked as "constexpr" in the signature.
+            return "constexpr"
+        elif isinstance(arg.expr, (float, sympy.Float)):
+            return "fp32"
+        elif isinstance(arg.expr, sympy.Symbol) and symbol_is_type(
+            arg.expr, (SymT.UNBACKED_FLOAT)
+        ):
+            return "fp32"
+        elif isinstance(arg.expr, bool):
+            return "i1"
+
+        # if this is a integer
+        if size_dtype == "tl.int32":
+            return "i32"
+        elif size_dtype == "tl.int64":
+            return "i64"
+        elif size_dtype is None:
+            # no hint: we'll see if we know that this is a 32-bit int, and guard if possible.
+            int_max = torch.iinfo(torch.int32).max
+            if expr_fits_within_32bit(arg.expr):
+                V.graph.sizevars.check_leq(arg.expr, int_max)
+                return "i32"
+            else:
+                return "i64"
+        else:
+            raise NotImplementedError(f"unhandled size_dtype {size_dtype}")
+    if isinstance(arg, WorkspaceArg):
+        return _type_of(arg.dtype)
+    if isinstance(arg, TMADescriptorArg):
+        if arg.api_type == "experimental":
+            return "nvTmaDesc"
+        else:
+            # https://github.com/triton-lang/triton/blob/9695baed9b46cf957e08b157bb4133f4a4b331c5/python/triton/runtime/jit.py#L360-L363
+            assert arg.api_type == "stable"
+            assert arg.block_shape is not None
+            assert arg.dtype is not None
+            inner = _type_of(arg.dtype)[1:]  # strip the `*`: *fp32 -> fp32
+            return f"tensordesc<{inner}{list(arg.block_shape)}>"
+    if isinstance(arg, ConstexprArg):
+        return "constexpr"
+    raise NotImplementedError(f"unhandled {type(arg)}: {arg}")
+
+
+def non_constexpr_signature(signature):
+    new_signature = []
+    for arg in signature:
+        if not isinstance(arg, ConstexprArg):
+            new_signature.append(arg)
+
+    return new_signature
+
+
+def signature_to_meta(
+    signature: list[KernelArgType],
+    *,
+    size_dtype: Optional[str],
+    argdefs: list[ArgName],
+    indices: Optional[list[int]] = None,
+    is_template: bool = False,
+) -> dict[str, str]:
+    if indices is None:
+        indices = list(range(len(signature)))
+
+    def _decide_tl_dtype(arg):
+        # Even if the ks0 symbol itself is within tl.int32 range, it's
+        # risky to use tl.int32 dtype since we may have ks0*ks1 later
+        # for kernels like torch.mean when dynamic shape is enabled.
+        #
+        # Check config.triton.use_block_ptr, since Triton block pointer
+        # does not support 64bit indexing:
+        # https://gist.github.com/shunting314/6a41c776171720ce4561f202dcde0ad6
+        #
+        # If the triton metadata is for a template, don't use tl.int64 index.
+        # Templates like flex attention/decoding uses block pointers which
+        # does not support 64 bit indexing.
+        if (
+            not config.triton.use_block_ptr
+            and not is_template
+            and isinstance(arg, SizeArg)
+            and arg.name.startswith("ks")
+        ):
+            return "tl.int64"
+        return size_dtype
+
+    return {
+        argdefs[i].name: signature_of(arg, size_dtype=_decide_tl_dtype(arg))
+        for i, arg in zip(indices, signature)
+    }
+
+
+def is_unaligned_buffer(arg: TensorArg):
+    buf_name = arg.buffer
+    if buf_name in V.graph.unaligned_buffers:
+        return True
+
+    if buf_name in V.graph.graph_inputs:
+        # See Note: [Input Alignment handling in Inductor]
+        # For graph inputs that is not recorded in V.graph.unaligned_buffers,
+        # we know for sure the tensor is aligned.
+        return False
+
+    if buf_name in V.graph.constants:
+        # all constants are assumed to be aligned
+        return False
+
+    if V.graph.scheduler:
+        layout = V.graph.scheduler.get_buffer_layout(buf_name)
+    else:
+        buffer = V.graph.try_get_buffer(buf_name)
+        # output arg
+        if not buffer:
+            assert buf_name == V.kernel.output_node.name
+            layout = V.kernel.output_node.layout
+        else:
+            layout = buffer.get_layout()
+
+    if isinstance(layout, torch._inductor.ir.NonOwningLayout):
+        return not layout.maybe_guard_aligned()
+    else:
+        return False
+
+
+def _arg_equals_1(arg: KernelArgType) -> bool:
+    return (
+        isinstance(arg, SizeArg)
+        and isinstance(arg.expr, (int, sympy.Integer))
+        and V.graph.sizevars.statically_known_equals(arg.expr, 1)  # type: ignore[arg-type]
+    )
+
+
+def equal_1_arg_indices(
+    args: list[KernelArgType],
+    *,
+    indices: Optional[list[int]] = None,
+) -> tuple[int, ...]:
+    if indices is None:
+        indices = list(range(len(args)))
+
+    equal_to_1 = tuple(i for i, arg in zip(indices, args) if _arg_equals_1(arg))
+
+    return equal_to_1
+
+
+def config_of(
+    args: list[KernelArgType],
+    *,
+    indices: Optional[list[int]] = None,
+) -> Any:
+    if indices is None:
+        indices = list(range(len(args)))
+
+    def is_aligned(x: KernelArgType, alignment: int, include_tensor: bool) -> bool:
+        """
+        Roughly follow triton code here:
+        https://github.com/triton-lang/triton/blob/5282ed890d453e10b9ee30076ef89115dd197761/python/triton/runtime/jit.py#L208-L222
+        """
+        if isinstance(x, TensorArg):
+            if include_tensor:
+                offset_aligned = V.graph.sizevars.statically_known_multiple_of(
+                    x.offset * x.dtype.itemsize,
+                    alignment,  # type: ignore[arg-type]
+                )
+                return offset_aligned and not is_unaligned_buffer(x)
+            else:
+                return False
+        if isinstance(x, SizeArg):
+            # TODO(voz): These are kinda redundant, if we can solve out statically_known_multiple_of with
+            # _maybe_evaluate_static...
+            if x.name.startswith("load_seed_offset"):
+                return False
+            if x.expr is None:
+                return False
+            if isinstance(x.expr, float):
+                return False
+            return V.graph.sizevars.statically_known_multiple_of(x.expr, alignment)  # type: ignore[arg-type]
+        if isinstance(x, WorkspaceArg):
+            # We allocate the workspace ourselves, so it is always aligned
+            return True
+        if isinstance(x, (TMADescriptorArg, ConstexprArg)):
+            return False
+        raise NotImplementedError(f"unhandled {type(x)}: {x}")
+
+    if config.triton.divisible_by_16:
+        divisible_by_16 = tuple(
+            i
+            for i, arg in zip(indices, args)
+            if is_aligned(arg, alignment=16, include_tensor=True)
+        )
+    else:
+        divisible_by_16 = ()
+
+    equal_to_1 = equal_1_arg_indices(args, indices=indices)
+
+    # pyrefly: ignore [bad-argument-type]
+    return AttrsDescriptorWrapper(divisible_by_16, equal_to_1)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/wrapper.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5b62bbee97c2f1d81fa7bf6a12599930a920662
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/wrapper.py
@@ -0,0 +1,3950 @@
+# mypy: allow-untyped-defs
+from __future__ import annotations
+
+import collections
+import contextlib
+import dataclasses
+import dis
+import functools
+import inspect
+import logging
+import operator
+import random
+import re
+import tempfile
+from collections.abc import Callable
+from itertools import chain, count
+from typing import Any, Optional, TYPE_CHECKING, Union
+
+import sympy
+from sympy import Expr
+
+import torch
+import torch._ops
+import torch.utils._pytree as pytree
+from torch import dtype as torch_dtype
+from torch._dynamo.utils import counters, dynamo_timed
+from torch._inductor.codegen.debug_utils import DebugPrinterManager
+from torch._inductor.codegen.multi_kernel import MultiKernelState
+from torch._inductor.runtime.runtime_utils import cache_dir
+from torch._library.opaque_object import is_opaque_value_type
+from torch._logging import trace_structured
+from torch.fx.experimental.symbolic_shapes import (
+    CallMethodKey,
+    ConvertIntKey,
+    DivideByKey,
+    resolve_unbacked_bindings,
+    SymTypes,
+)
+from torch.fx.node import _get_qualified_name
+from torch.utils._ordered_set import OrderedSet
+from torch.utils._sympy.singleton_int import SingletonInt
+from torch.utils._sympy.symbol import symbol_is_type, SymT
+
+from .. import async_compile, config, ir
+from ..codecache import output_code_log
+from ..ir import IRNode, ReinterpretView
+from ..runtime import triton_heuristics
+from ..runtime.hints import DeviceProperties
+from ..utils import (
+    cache_on_self,
+    DelayReplaceLine,
+    get_benchmark_name,
+    get_dtype_size,
+    IndentedBuffer,
+    is_codegen_graph_partition_subgraph,
+    is_using_cudagraph_partition,
+    LineContext,
+    sympy_product,
+    sympy_str,
+    sympy_subs,
+    triton_version_uses_attrs_dict,
+)
+from ..virtualized import V
+from .common import (
+    ArgName,
+    CodeGen,
+    DeferredLine,
+    PythonPrinter,
+    WorkspaceArg,
+    WorkspaceZeroMode,
+)
+from .cpp_utils import cexpr
+from .triton_utils import config_of, should_unwrap_unspec_arg, signature_to_meta
+
+
+if TYPE_CHECKING:
+    from collections.abc import Iterator, Sequence
+
+    import triton
+
+    from ..graph import GraphLowering
+    from ..ir import ExternKernel
+    from ..scheduler import BaseSchedulerNode
+    from .wrapper_fxir import FxConverter
+
+
+log = logging.getLogger(__name__)
+
+pexpr = PythonPrinter().doprint
+
+
+ReuseKey = tuple[torch.device, torch.dtype, str, bool]
+BufferLike = Union[ir.Buffer, WorkspaceArg]
+FxConversionFunc = Callable[["WrapperLine"], None]
+
+
+def buffer_reuse_key(node: BufferLike) -> ReuseKey:
+    storage_size = V.graph.get_allocation_storage_size(node)
+    alignment = node.get_name() not in V.graph.unaligned_buffers
+    return (
+        node.get_device_or_error(),
+        node.get_dtype(),
+        # NB: this is symbolic so that we don't try to reuse a buffer
+        # for s0 for s1, just because they happen to share the same
+        # size hint
+        sympy_str(V.graph.sizevars.simplify(storage_size)),
+        alignment,
+    )
+
+
+def can_match_buffer_size(input_buf: BufferLike, output_buf: BufferLike):
+    # Return True if input_buf can be re-inplaced for output_buf.
+    # This differs from `buffer_reuse_key` for general buffer reuse.
+    if input_buf.get_device_or_error() != output_buf.get_device_or_error():
+        return False
+
+    if input_buf.get_dtype() != output_buf.get_dtype():
+        return False
+
+    input_size = V.graph.sizevars.simplify(
+        V.graph.get_allocation_storage_size(input_buf)
+    )
+    output_size = V.graph.sizevars.simplify(
+        V.graph.get_allocation_storage_size(output_buf)
+    )
+
+    if (
+        # NB: this is symbolic so that we don't try to reuse a buffer
+        # for s0 for s1, just because they happen to share the same
+        # size hint
+        sympy_str(input_size) == sympy_str(output_size)
+    ) or (
+        # statically known that 0.95 * input_size <= output_size <= input_size
+        V.graph.sizevars.statically_known_geq(output_size, 0.95 * input_size)
+        and V.graph.sizevars.statically_known_leq(output_size, input_size)
+    ):
+        return True
+
+    return False
+
+
+def codegen_reinterpret_view_helper(data):
+    """
+    Collapse a chain of ReinterpretView <- StorageBox
+    <- ReinterpretView <- StorageBox.... <- buffer wrappers if every layer
+    has the same offset as the innermost (base) buffer.
+
+    Returns:
+        (size, stride, offset, dtype, collapsible: bool)
+    """
+    if isinstance(data, ir.Buffer):
+        lay = data.get_layout()
+        return lay.size, lay.stride, lay.offset, lay.dtype, True
+
+    layouts: list[Any] = []
+    cur = data
+    while isinstance(cur, (ir.TensorBox, ir.StorageBox, ir.ReinterpretView)):
+        lay = cur.get_layout()
+        if lay is None:
+            return None, None, None, None, False
+        layouts.append(lay)
+        cur = cur.data  # unwrap
+
+    if not isinstance(cur, ir.Buffer):
+        return None, None, None, None, False
+
+    # All wrapper offsets must match base offset to be collapsible
+    for lay in layouts:
+        if lay.offset != cur.get_layout().offset:
+            return None, None, None, None, False
+
+    base_lay = cur.get_layout()
+    return base_lay.size, base_lay.stride, base_lay.offset, base_lay.dtype, True
+
+
+# TODO: Move to a well known place
+TritonMetaParams = dict[str, int]
+TritonGrid = Union[
+    tuple[Union[int, sympy.Expr], ...], Callable[[TritonMetaParams], tuple[int, ...]]
+]
+
+
+def user_defined_kernel_grid_fn_code(
+    name: str,
+    configs: list[triton.Config],  # type: ignore[name-defined]
+    grids: list[TritonGrid],
+    wrapper: Optional[PythonWrapperCodegen] = None,
+    original_fxnode_name: Optional[str] = None,
+) -> tuple[str, str]:
+    output = IndentedBuffer()
+
+    def _convert_to_sympy_expr(item: Union[int, sympy.Expr]) -> sympy.Expr:
+        return item if isinstance(item, sympy.Expr) else sympy.Integer(item)
+
+    def determine_grid(
+        grid: TritonGrid,
+        example_grid: Optional[TritonGrid] = None,
+    ):
+        """
+        This function return a tuple of two values: the first one is for the real grid
+        which is used in the generated code; the second one is an example grid with
+        concreate values which is used in the autotune block to run the generated
+        kernels at compile time.
+        """
+        if wrapper is None or callable(grid):
+            # return as-is when used in eager mode or when grid is callable
+            return grid, grid
+        # Grid contains ints/Expr, so utilize wrapper's expr printer for codegen
+        sympy_grid = tuple(_convert_to_sympy_expr(g) for g in grid)
+        if not example_grid:
+            example_grid = sympy_grid
+        return (
+            wrapper.codegen_python_shape_tuple(sympy_grid),
+            (
+                wrapper.codegen_python_shape_tuple(
+                    tuple(
+                        wrapper.generate_example_arg_value(g, type(g))
+                        for g in example_grid  # type: ignore[union-attr]
+                    )
+                )
+                if config.triton.autotune_at_compile_time
+                else None
+            ),
+        )
+
+    def writeline(line: str, example_grid: Optional[str] = None):
+        output.writeline(line)
+        if (
+            wrapper
+            and config.triton.autotune_at_compile_time
+            and name not in wrapper.kernel_autotune_names
+        ):
+            wrapper.kernel_autotune_calls.writeline(example_grid or line)
+
+    fn_name = f"grid_wrapper_for_{name}"
+    writeline(f"def {fn_name}(meta):")
+    kernel_autotune_calls_indent = (
+        wrapper.kernel_autotune_calls.indent()
+        if wrapper and config.triton.autotune_at_compile_time
+        else contextlib.nullcontext()
+    )
+    with output.indent(), kernel_autotune_calls_indent:
+        if (
+            config.triton.autotune_at_compile_time
+            and original_fxnode_name
+            and V.graph.autotuning_grids
+            and original_fxnode_name in V.graph.autotuning_grids
+        ):
+            example_grids = V.graph.autotuning_grids[original_fxnode_name]
+        else:
+            example_grids = [None] * len(grids)
+        if len(grids) == 1:
+            grid, example_grid = determine_grid(grids[0], example_grids[0])
+            writeline(f"return {grid}", f"return {example_grid}")
+        else:
+            assert len(grids) > 1
+            assert len(grids) == len(configs)
+            seen: OrderedSet[str] = OrderedSet()
+            # sort the configs from the largest # of kwargs to the smallest to
+            # emit the grids in the order of (approximately) decreasing specificity
+            # TODO(aakhundov): the sorting below is generally not sufficient, so
+            # maybe we'll need to restrict the supported cases to identical kwarg
+            # names in all autotuning configs.
+            for grid, c, example_grid in sorted(
+                zip(grids, configs, example_grids),
+                key=lambda x: len(x[1].kwargs),
+                reverse=True,
+            ):
+                guardslist = []
+                if c.kwargs:
+                    # Remove AMD specific kwargs.
+                    for kwarg in c.kwargs:
+                        if kwarg not in [
+                            "matrix_instr_nonkdim",
+                            "waves_per_eu",
+                            "kpack",
+                        ]:
+                            guardslist.append(f"meta['{kwarg}'] == {c.kwargs[kwarg]}")
+                if guardslist:
+                    guards = " and ".join(guardslist)
+                else:
+                    guards = "True"  # for configs with empty kwargs
+                grid, example_grid = determine_grid(grid, example_grid)
+                statement = f"if {guards}: return {grid}"
+                if statement in seen:
+                    continue
+                seen.add(statement)
+                writeline(statement, f"if {guards}: return {example_grid}")
+
+    return fn_name, output.getvalue()
+
+
+def user_defined_triton_kernel_transitive_closure_source_code(kernel) -> str:
+    """
+    Given a triton kernel function pointer collect the transitive closure of
+    its dependencies
+    """
+    compile_wrapper = IndentedBuffer()
+    compile_wrapper.splice(kernel.src, strip=True)
+
+    # Also include any possible kernel being called indirectly
+    import triton
+    from triton import JITFunction  # type: ignore[name-defined, attr-defined]
+    from triton.language import constexpr  # type: ignore[name-defined]
+
+    # global constexpr vars handled above
+    symbols_included = OrderedSet([kernel.__name__])
+
+    def traverse(cur_kernel):
+        # here we extract the unqualified names (i.e., not attributes and
+        # without prepended module name) loaded in the kernel code, which
+        # are matched with the co_names and __globals__ below to codegen
+        # the respective imports necessary for the kernel compilation
+        unqualified_loads = OrderedSet(
+            inst.argval
+            for inst in dis.Bytecode(cur_kernel.fn)
+            if inst.opname == "LOAD_GLOBAL"
+        )
+        global_annotations = cur_kernel.fn.__globals__.get("__annotations__", {})
+        for symbol_name in cur_kernel.fn.__code__.co_names:
+            if symbol_name in symbols_included:
+                continue
+            if symbol_name in cur_kernel.fn.__globals__:
+                symbol = cur_kernel.fn.__globals__[symbol_name]
+                if isinstance(symbol, JITFunction):
+                    compile_wrapper.newline()
+                    compile_wrapper.writeline("@triton.jit")
+                    # pyrefly: ignore  # missing-attribute
+                    compile_wrapper.splice(symbol.src, strip=True)
+                    symbols_included.add(symbol_name)
+                    traverse(symbol)
+                elif hasattr(triton, "constexpr_function") and isinstance(
+                    # pyrefly: ignore  # missing-attribute
+                    symbol,
+                    # pyrefly: ignore  # missing-attribute
+                    triton.runtime.jit.ConstexprFunction,
+                ):
+                    compile_wrapper.newline()
+                    compile_wrapper.writeline("@triton.constexpr_function")
+                    compile_wrapper.splice(symbol.src, strip=True)
+                    symbols_included.add(symbol_name)
+                    traverse(symbol)
+                elif isinstance(symbol, (int, str, bool, constexpr)):
+                    compile_wrapper.newline()
+                    if isinstance(symbol, constexpr):
+                        symbol_str = f"tl.constexpr({symbol.value!r})"
+                    else:
+                        symbol_str = f"{symbol!r}"
+                    if annotation := global_annotations.get(symbol_name):
+                        if isinstance(annotation, type):
+                            annotation_code = (
+                                f": {annotation.__module__}.{annotation.__name__}"
+                            )
+                        else:
+                            annotation_code = f": {annotation!r}"
+                        compile_wrapper.writeline(
+                            f"{symbol_name}{annotation_code} = {symbol_str}"
+                        )
+                    else:
+                        compile_wrapper.writeline(f"{symbol_name} = {symbol_str}")
+                    symbols_included.add(symbol_name)
+                elif (
+                    symbol_name in unqualified_loads
+                    and symbol_name != "tl"  # already imported
+                    and hasattr(symbol, "__module__")
+                    # only codegen imports from triton; JITFunctions
+                    # imported from other modules will be codegened
+                    # in the separate branch above
+                    and symbol.__module__.startswith("triton")
+                ):
+                    # a global symbol imported from triton is referenced
+                    # without module qualification (i.e., `store` instead
+                    # of `tl.store`): need to codegen an import
+                    compile_wrapper.writeline(
+                        f"from {symbol.__module__} import {symbol.__name__} as {symbol_name}"
+                    )
+                    symbols_included.add(symbol_name)
+
+    traverse(kernel)
+    return compile_wrapper.getvalue()
+
+
+@dataclasses.dataclass
+class SymbolicCallArg:
+    inner: sympy.Symbol
+    # the original symbolic expression represented by inner
+    inner_expr: sympy.Expr
+
+    def __str__(self):
+        return str(self.inner)
+
+
+class MemoryPlanningState:
+    def __init__(self):
+        super().__init__()
+        self.reuse_pool: dict[ReuseKey, list[FreeIfNotReusedLine]] = (
+            collections.defaultdict(list)
+        )
+        self.total_allocated_buffer_size: int = 0
+
+    def __contains__(self, key: ReuseKey) -> bool:
+        return bool(self.reuse_pool.get(key, None))
+
+    def pop(self, key: ReuseKey) -> FreeIfNotReusedLine:
+        item = self.reuse_pool[key].pop()
+        assert not item.is_reused
+        return item
+
+    def push(self, key: ReuseKey, item: FreeIfNotReusedLine) -> None:
+        assert not item.is_reused
+        self.reuse_pool[key].append(item)
+
+
+class WrapperLine:
+    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
+        raise NotImplementedError(f"FX codegen not yet supported for type {type(self)}")
+
+
+@dataclasses.dataclass
+class EnterSubgraphLine(WrapperLine):
+    wrapper: PythonWrapperCodegen
+    graph: GraphLowering
+
+    def __post_init__(self) -> None:
+        self.wrapper.push_computed_sizes(self.wrapper.computed_sizes)
+
+    def codegen(self, code: IndentedBuffer) -> None:
+        self.wrapper.push_codegened_graph(self.graph)
+        code.do_indent()
+
+    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
+        return converter._generate_enter_subgraph
+
+
+@dataclasses.dataclass
+class ConditionalLine(WrapperLine):
+    wrapper: PythonWrapperCodegen
+    node: ir.Conditional
+
+    def codegen(self, code: IndentedBuffer) -> None:
+        raise NotImplementedError("Only supports FX codegen")
+
+    @staticmethod
+    def codegen_fx(converter: FxConverter) -> FxConversionFunc:
+        return converter._generate_conditional
+
+
+@dataclasses.dataclass
+class CommentLine(WrapperLine):
+    line: LineContext
+
+    def codegen(self, code: IndentedBuffer) -> None:
+        code.writeline(self.line)
+
+    @staticmethod
+    def codegen_fx(converter: FxConverter) -> FxConversionFunc:
+        return converter._generate_comment
+
+
+@dataclasses.dataclass
+class DynamicScalarLine(WrapperLine):
+    wrapper: PythonWrapperCodegen
+    node: ir.DynamicScalar
+
+    def codegen(self, code: IndentedBuffer) -> None:
+        self.wrapper._codegen_dynamic_scalar(self.node)
+
+    @staticmethod
+    def codegen_fx(converter: FxConverter) -> FxConversionFunc:
+        return converter._generate_dynamic_scalar
+
+
+@dataclasses.dataclass
+class ExitSubgraphLine(WrapperLine):
+    wrapper: PythonWrapperCodegen
+
+    def __post_init__(self) -> None:
+        self.wrapper.computed_sizes = self.wrapper.pop_computed_sizes()
+
+    def codegen(self, code: IndentedBuffer) -> None:
+        self.wrapper.pop_codegened_graph()
+        code.do_unindent()
+
+    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
+        return converter._generate_exit_subgraph
+
+
+@dataclasses.dataclass
+class EnterDeviceContextManagerLine(WrapperLine):
+    device_idx: int
+    last_seen_device_guard_index: Optional[int]
+
+    def codegen(self, code: IndentedBuffer) -> None:
+        if V.graph.cpp_wrapper:
+            code.writeline("\n")
+            if V.graph.aot_mode:
+                # In AOT mode, we have a stream provided as a param. A stream is
+                # associated with a device, so we never expect the device to change.
+                # CUDAStreamGuard sets the stream and the device.
+                if self.last_seen_device_guard_index is None:
+                    code.writeline(
+                        f"{V.graph.device_ops.cpp_aoti_stream_guard()} stream_guard(stream, this->device_idx_);"
+                    )
+                else:
+                    assert self.last_seen_device_guard_index == self.device_idx, (
+                        "AOTInductor only supports running on one CUDA device"
+                    )
+            else:
+                if self.last_seen_device_guard_index is None:
+                    code.writeline(
+                        f"{V.graph.device_ops.cpp_aoti_device_guard()} device_guard({self.device_idx});"
+                    )
+                else:
+                    code.writeline(f"device_guard.set_index({self.device_idx});")
+        else:
+            # Note _DeviceGuard has less overhead than device, but only accepts
+            # integers
+            code.writeline(f"with {V.graph.device_ops.device_guard(self.device_idx)}:")
+            code.do_indent()
+            code.writeline(V.graph.device_ops.set_device(self.device_idx))
+
+    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
+        return converter._generate_enter_device_context_manager
+
+
+class ExitDeviceContextManagerLine(WrapperLine):
+    def codegen(self, code: IndentedBuffer) -> None:
+        if not V.graph.cpp_wrapper:
+            code.do_unindent()
+
+    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
+        return converter._generate_exit_device_context_manager
+
+
+@dataclasses.dataclass
+class ExternKernelAllocLine(WrapperLine):
+    wrapper: PythonWrapperCodegen
+    node: ir.ExternKernelAlloc
+
+    def codegen(self, code: IndentedBuffer) -> None:
+        node = self.node
+        args = [*node.codegen_args(), *node.codegen_kwargs()]
+        self.wrapper._generate_extern_kernel_alloc_helper(self.node, args)
+
+    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
+        return converter._generate_extern_kernel_alloc
+
+
+@dataclasses.dataclass
+class ExternKernelOutLine(WrapperLine):
+    wrapper: PythonWrapperCodegen
+    node: ir.ExternKernelOut
+
+    def codegen(self, code: IndentedBuffer) -> None:
+        node = self.node
+        args = [*node.codegen_args(), *node.codegen_kwargs(skip_out=True)]
+        kernel_name = node.get_kernel_name()
+        if (
+            V.graph.cpp_wrapper
+            and node.cpp_kernel_name == "torch::inductor::_mm_plus_mm"
+        ):
+            # For https://github.com/pytorch/pytorch/issues/128474
+            kernel_name = "aoti_torch__mm_plus_mm_out"
+        else:
+            kernel_name = node.get_kernel_name()
+        device = d.type if (d := node.get_device()) else V.graph.device_type
+        self.wrapper._generate_extern_kernel_out_helper(
+            kernel_name,
+            node.codegen_reference(),
+            node.output_view.codegen_reference() if node.output_view else None,
+            args,
+            device,
+            self.node.get_stack_traces(),
+        )
+
+    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
+        return converter._generate_extern_kernel_out
+
+
+@dataclasses.dataclass
+class FreeLine(WrapperLine):
+    wrapper: PythonWrapperCodegen
+    node: Union[BufferLike, ir.TorchBindObject]
+
+    def codegen(self, code: IndentedBuffer) -> None:
+        assert self.node.get_name() not in V.graph.removed_buffers
+        code.writeline(self.wrapper.make_buffer_free(self.node))
+
+    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
+        return converter._generate_free
+
+
+@dataclasses.dataclass
+class KernelCallLine(WrapperLine):
+    wrapper: PythonWrapperCodegen
+    kernel_name: str
+    call_args: tuple[Any, ...]
+    raw_keys: tuple[Any, ...]
+    raw_args: tuple[Any, ...]
+    arg_types: list[str]
+    triton: bool
+    triton_meta: dict[str, Any]
+    device: torch.device
+    graph_name: str
+    original_fxnode_name: str
+
+    def codegen(self, code: IndentedBuffer) -> None:
+        self.wrapper._generate_kernel_call_helper(
+            self.kernel_name,
+            self.call_args,
+            triton=self.triton,
+            arg_types=self.arg_types,
+            raw_keys=self.raw_keys,
+            raw_args=self.raw_args,
+            triton_meta=self.triton_meta,
+            device=self.device,
+            graph_name=self.graph_name,
+            original_fxnode_name=self.original_fxnode_name,
+        )
+
+    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
+        return converter._generate_kernel_call
+
+
+@dataclasses.dataclass
+class KernelDefinitionLine(WrapperLine):
+    wrapper: PythonWrapperCodegen
+    kernel_name: str
+    kernel_body: str
+    metadata: Optional[str] = None
+    gpu: bool = True
+    cpp_definition: Optional[str] = None
+
+    def codegen(self, code: IndentedBuffer) -> None:
+        self.wrapper._define_kernel_helper(
+            self.kernel_name,
+            self.kernel_body,
+            metadata=self.metadata,
+            gpu=self.gpu,
+            cpp_definition=self.cpp_definition,
+        )
+
+    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
+        return converter._generate_kernel_definition
+
+
+@dataclasses.dataclass
+class MemoryPlanningLine(WrapperLine):
+    wrapper: PythonWrapperCodegen
+
+    def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine:
+        """First pass to find reuse"""
+        return self
+
+    def codegen(self, code: IndentedBuffer) -> None:
+        """Second pass to output code"""
+
+    def __str__(self) -> str:
+        """
+        Emits a string representation that fits on one line.
+        """
+        args: list[str] = []
+        for field in dataclasses.fields(self):
+            if field.name == "wrapper":
+                continue
+            val = getattr(self, field.name)
+            args.append(
+                f"{field.name}={val.get_name() if field.type is ir.Buffer else val}"
+            )
+        return f"{type(self).__name__}({', '.join(args)})"
+
+
+class EfficientPeakEstimate:
+    def __init__(self):
+        from ..memory import estimate_peak_memory, get_freeable_input_buf
+
+        scheduler_nodes = V.graph.scheduler.nodes
+        graph_inputs = OrderedSet(V.graph.graph_inputs.keys())
+        graph_outputs = OrderedSet(V.graph.get_output_names())
+        names_to_freeable_bufs = get_freeable_input_buf(scheduler_nodes, graph_inputs)
+        self.overall_peak_memory, peak_by_scheduler_node = estimate_peak_memory(
+            scheduler_nodes,
+            names_to_freeable_bufs,
+            graph_outputs,
+        )
+
+        from .segmented_tree import SegmentedTree
+
+        self.segmented_tree = SegmentedTree(
+            peak_by_scheduler_node, operator.add, max, 0
+        )
+
+    def _get_size(self, node: BufferLike) -> int:
+        return V.graph.sizevars.size_hint(
+            V.graph.get_allocation_storage_size(node), fallback=0
+        ) * get_dtype_size(node.get_dtype())
+
+    def peak_between(self, line_a: FreeIfNotReusedLine, line_b: AllocateLine):
+        return self.segmented_tree.summarize_range(
+            line_a.scheduler_node_index + 1, line_b.scheduler_node_index - 1
+        )
+
+    def update_peak_between(self, line_a: FreeIfNotReusedLine, line_b: AllocateLine):
+        if line_a.scheduler_node_index + 1 == line_b.scheduler_node_index:
+            return
+        self.segmented_tree.update_range(
+            line_a.scheduler_node_index + 1,
+            line_b.scheduler_node_index - 1,
+            self._get_size(line_b.node),
+        )
+
+
+@dataclasses.dataclass
+class AllocateLine(MemoryPlanningLine):
+    node: BufferLike
+
+    def __post_init__(self):
+        assert V.graph.scheduler.current_node is not None
+        self.scheduler_node_index = V.graph.scheduler.nodes.index(
+            V.graph.scheduler.current_node
+        )
+
+    def should_reuse_buffer(self, free_line: FreeIfNotReusedLine, size: int) -> bool:
+        if free_line.scheduler_node_index + 1 == self.scheduler_node_index:
+            return True
+        overall_peak_memory = self.wrapper.estimate_peak.overall_peak_memory
+        peak_memory_in_range = self.wrapper.estimate_peak.peak_between(free_line, self)
+        new_peak_memory = size + peak_memory_in_range
+        return new_peak_memory <= overall_peak_memory
+
+    def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine:
+        if self.node.get_name() in V.graph.removed_buffers:
+            return NullLine(self.wrapper)
+
+        # try to reuse a recently freed buffer
+        key = buffer_reuse_key(self.node)
+        if config.allow_buffer_reuse and key in state:
+            free_line = state.pop(key)
+            size = V.graph.sizevars.size_hint(
+                V.graph.get_allocation_storage_size(self.node), fallback=0
+            ) * get_dtype_size(self.node.get_dtype())
+            if self.should_reuse_buffer(free_line, size):
+                free_line.is_reused = True
+                self.wrapper.estimate_peak.update_peak_between(free_line, self)
+                return ReuseLine(self.wrapper, free_line.node, self.node)
+            else:
+                state.push(key, free_line)
+                return self
+
+        if self.node.get_device_or_error().type == "cpu":
+            static_shape = self.wrapper.static_shape_for_buffer_or_none(self.node)
+            if static_shape is not None:
+                state.total_allocated_buffer_size += int(
+                    functools.reduce(operator.mul, static_shape, 1)
+                )
+
+        return self
+
+    def codegen(self, code: IndentedBuffer) -> None:
+        assert self.node.get_name() not in V.graph.removed_buffers
+        line = self.wrapper.make_buffer_allocation(self.node)
+        code.writeline(line)
+
+    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
+        return converter._generate_allocate
+
+
+@dataclasses.dataclass
+class FreeIfNotReusedLine(MemoryPlanningLine):
+    node: BufferLike
+    is_reused: bool = False
+
+    def __post_init__(self):
+        assert V.graph.scheduler.current_node is not None
+        self.scheduler_node_index = V.graph.scheduler.nodes.index(
+            V.graph.scheduler.current_node
+        )
+
+    def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine:
+        if len(self.node.get_inputs_that_alias_output()) > 0:
+            return self
+        if isinstance(self.node.layout, ir.MultiOutputLayout):
+            return self
+        assert not self.is_reused
+        if self.node.get_name() in V.graph.removed_buffers:
+            return NullLine(self.wrapper)
+        if config.allow_buffer_reuse:
+            state.push(buffer_reuse_key(self.node), self)
+        return self
+
+    def codegen(self, code: IndentedBuffer) -> None:
+        assert self.node.get_name() not in V.graph.removed_buffers
+        if not self.is_reused:
+            code.writeline(self.wrapper.make_buffer_free(self.node))
+
+    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
+        return converter._generate_free_if_not_reused
+
+
+@dataclasses.dataclass
+class ReinterpretLine(MemoryPlanningLine):
+    node: BufferLike
+    reused_as: BufferLike
+    layout: ir.Layout
+
+    def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine:
+        return self
+
+    def codegen(self, code: IndentedBuffer) -> None:
+        assert isinstance(self.layout, ir.NonOwningLayout)
+        assert isinstance(self.layout.view, ir.ReinterpretView)
+        self.wrapper.codegen_deferred_allocation(
+            self.reused_as.get_name(), self.layout.view
+        )
+
+    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
+        return converter._generate_reinterpret
+
+
+@dataclasses.dataclass
+class ReuseLine(MemoryPlanningLine):
+    node: BufferLike
+    reused_as: BufferLike
+    delete_old: bool = True
+
+    def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine:
+        if self.node.get_name() in V.graph.removed_buffers:
+            assert self.reused_as.get_name() in V.graph.removed_buffers
+            return NullLine(self.wrapper)
+        assert self.reused_as.get_name() not in V.graph.removed_buffers
+        return self
+
+    def codegen(self, code: IndentedBuffer) -> None:
+        assert self.node.get_name() not in V.graph.removed_buffers
+        assert self.reused_as.get_name() not in V.graph.removed_buffers
+        code.writeline(
+            self.wrapper.make_buffer_reuse(self.node, self.reused_as, self.delete_old)
+        )
+
+    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
+        return converter._generate_reuse
+
+
+class NullLine(MemoryPlanningLine):
+    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
+        return converter._generate_null
+
+
+@dataclasses.dataclass
+class CommBufferLine(WrapperLine):
+    wrapper: PythonWrapperCodegen  # type: ignore[name-defined] # noqa: F821
+    node: ir.Buffer
+
+    @property
+    def size(self) -> int:
+        from torch._inductor.utils import is_symbolic
+
+        numel = self.node.get_numel()
+        dtype = self.node.get_dtype()
+        if is_symbolic(numel):
+            raise AssertionError(
+                f"The size of a comm buffer can't be symbolic: {self.node}"
+            )
+        return int(numel) * dtype.itemsize
+
+    @property
+    def comm_buffer_type(self) -> ir.CommBufferType:
+        layout = self.node.get_output_spec()
+        assert isinstance(layout, ir.CommBufferLayout)
+        return layout.comm_buffer_type
+
+    @property
+    def group_name(self) -> str:
+        layout = self.node.get_output_spec()
+        assert isinstance(layout, ir.CommBufferLayout)
+        return layout.group_name
+
+
+@dataclasses.dataclass
+class CommBufferAllocateLine(CommBufferLine):
+    def codegen(self, code: IndentedBuffer) -> None:
+        assert self.node.get_name() not in V.graph.removed_buffers
+        name = self.node.get_name()
+        device = self.node.get_device()
+        dtype = self.node.get_dtype()
+        shape = tuple(self.node.get_size())
+        stride = tuple(self.node.get_stride())
+        code.writeline(
+            self.make_allocation_line(
+                self.comm_buffer_type,
+                self.group_name,
+                self.wrapper,
+                name,
+                device,
+                dtype,
+                shape,
+                stride,
+            )
+        )
+
+    @staticmethod
+    def make_allocation_line(
+        comm_buffer_type, group_name, wrapper, name, device, dtype, shape, stride
+    ):
+        if comm_buffer_type == ir.CommBufferType.SYMM_MEM:
+            return (
+                f"{name} = empty_strided_p2p("
+                f"{wrapper.codegen_shape_tuple(shape)}, "
+                f"{wrapper.codegen_shape_tuple(stride)}, "
+                f"{dtype}, "
+                f'torch.device("cuda:{device.index}"), '
+                f'group_name="{group_name}", '
+                f"alloc_id={random.randint(0, 2**64 - 1)})"
+            )
+        else:
+            raise NotImplementedError(
+                f"Unsupported comm buffer type: {comm_buffer_type}"
+            )
+
+    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
+        return converter._generate_comm_buffer_allocate
+
+
+@dataclasses.dataclass
+class CommBufferFreeLine(CommBufferLine):
+    def codegen(self, code: IndentedBuffer) -> None:
+        line = self.wrapper.make_buffer_free(self.node)
+        code.writeline(f"{line} # {self.comm_buffer_type.value} buffer free")
+
+    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
+        return converter._generate_comm_buffer_free
+
+
+@dataclasses.dataclass
+class MultiOutputLine(WrapperLine):
+    """
+    Given a MultiOutputLayout buffer, indexes actual buffer(s) from the result.
+    """
+
+    wrapper: PythonWrapperCodegen
+    result_name: str
+    arg_name: str
+    indices: Sequence[Any]
+
+    def codegen(self, code: IndentedBuffer) -> None:
+        def codegen_list_tuple_access(basename, indices):  # type: ignore[no-untyped-def]
+            if len(indices) > 0:
+                itype, i = indices[0]
+                if issubclass(itype, list):
+                    return codegen_list_tuple_access(f"{basename}[{i}]", indices[1:])
+                elif issubclass(itype, tuple):
+                    # cpp wrapper code needs to use std::get<> to access a tuple
+                    tuple_access = self.wrapper.codegen_tuple_access(
+                        basename, self.result_name, str(i)
+                    )
+                    return codegen_list_tuple_access(tuple_access, indices[1:])
+                elif issubclass(itype, dict):
+                    return codegen_list_tuple_access(f"{basename}['{i}']", indices[1:])
+                else:
+                    raise AssertionError("non supported index type: ", itype)
+            else:
+                return basename
+
+        value = codegen_list_tuple_access(self.arg_name, self.indices)
+        code.writeline(
+            f"{self.wrapper.declare}{self.result_name} = {value}{self.wrapper.ending}"
+        )
+
+    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
+        return converter._generate_multi_output
+
+
+@dataclasses.dataclass
+class IndexPutFallbackLine(WrapperLine):
+    wrapper: PythonWrapperCodegen
+    node: ir.IndexPutFallback
+    indices: list[Optional[ir.IRNode]]
+
+    def codegen(self, code: IndentedBuffer) -> None:
+        node = self.node
+        assert ir.is_node_sequence(node.inputs)
+        (x, values) = (t.codegen_reference() for t in node.inputs[:2])
+        indices = [
+            idx.codegen_reference() if idx else self.wrapper.none_str
+            for idx in self.indices
+        ]
+
+        self.wrapper._generate_index_put_fallback(
+            node.get_kernel_name(), x, indices, values, *node.codegen_const_args()
+        )
+
+    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
+        return converter._generate_index_put_fallback
+
+
+@dataclasses.dataclass
+class ScatterFallbackLine(WrapperLine):
+    wrapper: PythonWrapperCodegen
+    node: ir.ScatterFallback
+
+    def codegen(self, code: IndentedBuffer) -> None:
+        node = self.node
+        assert ir.is_node_sequence(node.inputs)
+        if node.src_is_tensor:
+            (x, index, src) = (t.codegen_reference() for t in node.inputs)
+        else:
+            (x, index) = (t.codegen_reference() for t in node.inputs)
+            src = node.constant_args[1]
+        device = d.type if (d := node.get_device()) else V.graph.device_type
+        self.wrapper._generate_scatter_fallback(
+            x,
+            [x, node.constant_args[0], index, src],
+            node.cpp_kernel_name,
+            node.python_kernel_name,
+            node.src_is_tensor,
+            node.kwargs["reduce"],
+            node.codegen_kwargs(),
+            device,
+        )
+
+    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
+        return converter._generate_scatter_fallback
+
+
+@dataclasses.dataclass
+class SymbolicCallArgLine(WrapperLine):
+    wrapper: PythonWrapperCodegen
+    arg: SymbolicCallArg
+    graph: GraphLowering
+
+    def codegen(self, code: IndentedBuffer) -> None:
+        self.wrapper._generate_symbolic_call_arg_helper(self.arg, self.graph)
+
+    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
+        return converter._generate_symbolic_call_arg
+
+
+@dataclasses.dataclass
+class UnbackedSymbolDefsLine(WrapperLine):
+    wrapper: PythonWrapperCodegen
+    output_name: str
+    outputs: Any
+    unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]]
+
+    def codegen(self, code: IndentedBuffer) -> None:
+        self.wrapper._codegen_unbacked_symbol_defs_for_outputs(
+            self.output_name, self.outputs, self.unbacked_bindings
+        )
+
+    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
+        return converter._generate_unbacked_symbol_defs
+
+
+BufferName = str
+Line = Union[MemoryPlanningLine, LineContext]
+
+
+class PythonWrapperCodegen(CodeGen):
+    """
+    Generate outer wrapper in Python that calls the kernels.
+    """
+
+    supports_caching = True  # Whether the output code is cacheable.
+
+    def __init__(self):
+        super().__init__()
+        self._names_iter: Iterator[int] = count()
+        self.args_to_buffers: dict[
+            str, Union[None, ir.TensorBox, ir.Buffer, ir.TorchBindObject]
+        ] = {}
+        self.imports = IndentedBuffer()
+        self.header = IndentedBuffer()
+        self.prefix = IndentedBuffer()
+        self.suffix = IndentedBuffer()
+        self.kernel_declarations = IndentedBuffer()
+        self.wrapper_call = IndentedBuffer()
+        self.kernel_autotune_defs = IndentedBuffer()
+        self.kernel_autotune_calls = IndentedBuffer()
+        self.subgraph_definitions = IndentedBuffer()
+        self.kernel_autotune_names: OrderedSet[str] = OrderedSet()
+        # Map key is the kernel argument name; value is a tuple of the resulting example
+        # tensor name with the kernel where that tensor was most recently used.
+        self.kernel_autotune_example_args: dict[str, tuple[str, str]] = {}
+        self.kernel_autotune_tmp_arg_idx: int = 0
+        # If the generated source code is exactly the same, reuse the
+        # pre-existing kernel for it
+        self.src_to_kernel: dict[str, str] = {}
+        self.kernel_numel_expr: OrderedSet[tuple[str, GraphLowering]] = OrderedSet()
+        self.lines: list[Line] = []
+        self.declare = ""
+        self.declare_maybe_reference = ""
+        self.ending = ""
+        self.comment = "#"
+        self.none_str = "None"
+        self.move_begin = "std::move(" if V.graph.cpp_wrapper else ""
+        self.move_end = ")" if V.graph.cpp_wrapper else ""
+        self.last_seen_device_guard_index: Optional[int] = None
+        self.supports_intermediate_hooks = True
+        self.user_defined_kernel_cache: dict[tuple[Any, ...], tuple[str, Any]] = {}
+        self.unbacked_symbol_decls: OrderedSet[str] = (
+            OrderedSet()
+        )  # str of sympy.Symbol
+        self.computed_sizes: OrderedSet[sympy.Symbol] = OrderedSet()
+        self.launcher_fn_name = None
+        # This function can be overridden to change the launcher name
+        self.set_launcher_fn_name()
+
+        # this is used for tracking which GraphLowering instance---parent graph
+        # or (nested) subgraph---is currently codegened; the primary use case is
+        # including the graph instance into a cache key to avoid cross-graph
+        # caching during lowering of nested subgraphs
+        self.codegened_graph_stack = []
+        self.computed_sizes_stack = []
+
+        self.write_header()
+
+        if not is_codegen_graph_partition_subgraph(self):
+            # See [Note: Removed Graph Partition Arguments]
+            self.write_prefix()
+
+        self.write_kernel_autotune_defs_header()
+
+        if not V.graph.aot_mode:
+            for name, hashed in V.graph.constant_reprs.items():
+                # include a hash so our code cache puts different constants into different files
+                self.write_constant(name, hashed)
+
+        self.allocated = OrderedSet[BufferName]()
+        self.freed = OrderedSet[BufferName]()
+
+        # maps from reusing buffer to reused buffer
+        self.reuses: dict[BufferName, BufferName] = {}
+
+        self.write_get_raw_stream = functools.lru_cache(None)(  # type: ignore[assignment]
+            self.write_get_raw_stream
+        )
+
+        @functools.cache
+        def add_import_once(line: str) -> None:
+            self.imports.writeline(line)
+            if config.triton.autotune_at_compile_time:
+                self.kernel_autotune_calls.writeline(line)
+
+        self.add_import_once = add_import_once
+        self._metas: dict[str, str] = {}
+        self._meta_vars: OrderedSet[str] = OrderedSet()
+        self.multi_kernel_state = MultiKernelState()
+        self.already_codegened_subgraphs: OrderedSet[str] = OrderedSet()
+        self.allocated_workspaces: dict[str, Any] = {}
+
+        # intermediate tensor value printing utility
+        self.debug_printer = DebugPrinterManager(
+            debug_printer_level=config.aot_inductor.debug_intermediate_value_printer,
+            use_array_ref=config.aot_inductor.allow_stack_allocation,
+        )
+
+        # Additional files that are dependent to the wrapper (ex. cubin files)
+        self.additional_files = []
+
+    @staticmethod
+    def create(
+        is_subgraph: bool,
+        subgraph_name: Optional[str],
+        parent_wrapper: Optional[PythonWrapperCodegen],
+        partition_signatures: Optional[ir.GraphPartitionSignature] = None,
+    ):
+        if is_subgraph:
+            assert subgraph_name is not None
+            assert parent_wrapper is not None
+            return SubgraphPythonWrapperCodegen(
+                subgraph_name, parent_wrapper, partition_signatures
+            )
+        return PythonWrapperCodegen()
+
+    def set_launcher_fn_name(self) -> None:
+        # pyrefly: ignore [bad-assignment]
+        self.launcher_fn_name = "call"
+
+    def write_constant(self, name: str, hashed: str) -> None:
+        self.header.writeline(f"{name} = None  # {hashed}")
+
+    def write_header(self) -> None:
+        context = torch._guards.TracingContext.try_get()
+        aot_config_comment = ""
+        if context is not None and context.aot_graph_name is not None:
+            aot_config_comment = f"# AOT ID: {context.aot_graph_name}"
+        inductor_debug_utils = ""
+        if int(config.aot_inductor.debug_intermediate_value_printer) > 0:
+            inductor_debug_utils = "from torch._inductor.codegen.debug_utils import _print_debugging_tensor_value_info"
+        elif torch._inductor.config.test_configs.track_memory_lifecycle:
+            inductor_debug_utils = "from torch._inductor.runtime.debug_utils import tracked_empty_strided\n"
+
+        self.imports.splice(
+            f"""
+                {aot_config_comment}
+                from ctypes import c_void_p, c_long, c_int
+                import torch
+                import math
+                import random
+                import os
+                import tempfile
+                from math import inf, nan
+                from cmath import nanj
+                from torch._inductor.hooks import run_intermediate_hooks
+                from torch._inductor.utils import maybe_profile
+                from torch._inductor.codegen.memory_planning import _align as align
+                from torch import device, empty_strided
+                from {async_compile.__name__} import AsyncCompile
+                from torch._inductor.select_algorithm import extern_kernels
+                {inductor_debug_utils}
+            """,
+            strip=True,
+        )
+        self.header.splice(
+            """
+                aten = torch.ops.aten
+                inductor_ops = torch.ops.inductor
+                _quantized = torch.ops._quantized
+                assert_size_stride = torch._C._dynamo.guards.assert_size_stride
+                assert_alignment = torch._C._dynamo.guards.assert_alignment
+                empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
+                empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
+                empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
+                empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
+                empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
+                reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
+                alloc_from_pool = torch.ops.inductor._alloc_from_pool
+                async_compile = AsyncCompile()
+            """,
+            strip=True,
+        )
+        try:
+            # Only add empty_strided_p2p() if distributed and SymmetricMemory
+            # is available
+            from torch._C._distributed_c10d import _SymmetricMemory  # noqa: F401
+
+            self.header.splice(
+                """
+                empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
+                """,
+                strip=True,
+            )
+        except (AttributeError, ImportError):
+            pass
+        if config.annotate_training:
+            self.header.writeline("from torch.cuda import nvtx")
+
+    def include_extra_header(self, header: str):
+        pass
+
+    def write_kernel_autotune_defs_header(self) -> None:
+        self.kernel_autotune_defs.splice(
+            f"""
+                import torch
+                from torch._dynamo.testing import rand_strided
+                from torch._dynamo.utils import preserve_rng_state
+                from torch._inductor.select_algorithm import AlgorithmSelectorCache
+                from {async_compile.__name__} import AsyncCompile
+
+                async_compile = AsyncCompile()
+                generate_example_value = AlgorithmSelectorCache.generate_example_value
+                empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
+                empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
+            """
+        )
+
+        try:
+            from torch._C import _cuda_getCurrentRawStream  # noqa: F401
+
+            self.kernel_autotune_defs.splice(
+                """
+                get_raw_stream = torch._C._cuda_getCurrentRawStream
+                """,
+                strip=True,
+            )
+        except (ImportError, AttributeError):
+            pass
+
+    @cache_on_self
+    def write_triton_header_once(self) -> None:
+        import_str = f"""
+            import triton
+            import triton.language as tl
+            from {triton_heuristics.__name__} import start_graph, end_graph
+            """
+        if config.triton.autotune_at_compile_time:
+            self.kernel_autotune_calls.splice(import_str)
+            self.kernel_autotune_calls.writeline(
+                V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
+            )
+        if not V.graph.cpp_wrapper:
+            self.imports.splice(import_str, strip=True)
+            self.imports.writeline(
+                V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
+            )
+
+    def write_get_raw_stream_header(self) -> None:
+        import_get_raw_stream_str = V.graph.device_ops.import_get_raw_stream_as(
+            "get_raw_stream"
+        )
+        if config.triton.autotune_at_compile_time:
+            if not self.kernel_autotune_calls.contains(import_get_raw_stream_str):
+                self.kernel_autotune_calls.writeline(import_get_raw_stream_str)
+        if not V.graph.cpp_wrapper:
+            if not self.imports.contains(import_get_raw_stream_str):
+                self.imports.writeline(import_get_raw_stream_str)
+
+    @cache_on_self
+    def write_get_raw_stream_header_once(self) -> None:
+        self.write_get_raw_stream_header()
+
+    def add_meta_once(self, meta: TritonMetaParams) -> str:
+        # pyrefly: ignore [bad-assignment]
+        meta = repr(meta)
+        if meta not in self._metas:
+            var = f"meta{len(self._metas)}"
+            # pyrefly: ignore [unsupported-operation]
+            self._metas[meta] = var
+            self.header.writeline(f"{var} = {meta}")
+            if config.triton.autotune_at_compile_time:
+                self.kernel_autotune_calls.writeline(f"{var} = {meta}")
+                self._meta_vars.add(var)
+        # pyrefly: ignore [index-error]
+        return self._metas[meta]
+
+    @cache_on_self
+    def get_output_refs(self) -> list[str]:
+        return [
+            x.codegen_reference(self.wrapper_call) for x in self.get_graph_outputs()
+        ]
+
+    def mark_output_type(self) -> None:
+        return
+
+    def get_graph_inputs(
+        self,
+    ) -> dict[str, Union[ir.TensorBox, ir.TorchBindObject, sympy.Expr]]:
+        return V.graph.graph_inputs
+
+    def get_graph_outputs(self) -> list[IRNode]:
+        return V.graph.graph_outputs
+
+    def codegen_input_size_asserts(self) -> None:
+        for name, buf in self.get_graph_inputs().items():
+            if isinstance(buf, (sympy.Expr, ir.TorchBindObject)):
+                continue
+
+            # a graph partition may take an IRNode output from a previous partition
+            if name not in V.graph.graph_input_names or isinstance(
+                buf, ir.GeneratorState
+            ):
+                continue
+
+            # comparing strides for 0 size tensor is tricky. Ignore them for now.
+            if sympy_product(buf.get_size()) == 0:
+                continue
+            size = self.codegen_python_shape_tuple(buf.get_size())
+            stride = self.codegen_python_shape_tuple(buf.get_stride())
+            self.prefix.writeline(f"assert_size_stride({name}, {size}, {stride})")
+
+    def codegen_input_nan_asserts(self) -> None:
+        self.prefix.writeline("# make sure graph inputs are not nan/inf")
+        for name, buf in self.get_graph_inputs().items():
+            if isinstance(buf, (sympy.Expr, ir.TorchBindObject)):
+                continue
+
+            line = f"assert not {name}.isnan().any().item()"
+            self.prefix.writeline(line)
+            line = f"assert not {name}.isinf().any().item()"
+            self.prefix.writeline(line)
+
+    def write_async_compile_wait(self) -> None:
+        self.prefix.splice(
+            """
+
+            async_compile.wait(globals())
+            del async_compile
+            """
+        )
+
+    def write_args(self, input_names: list[str]):
+        lhs = ", ".join(input_names)
+        if len(input_names) == 1:
+            lhs += ","
+        self.prefix.writeline(f"{lhs} = args")
+        self.prefix.writeline("args.clear()")
+
+    def write_launcher_fn_call_get_indent(self) -> int:
+        if config.graph_partition:
+            self.prefix.splice(
+                """
+                class Runner:
+                    def __init__(self, partitions):
+                        self.partitions = partitions
+
+                    def recursively_apply_fns(self, fns):
+                        new_callables = []
+                        for fn, c in zip(fns, self.partitions):
+                            new_callables.append(fn(c))
+                        self.partitions = new_callables
+
+                    def call(self, args):
+                """
+            )
+            prefix_indent = 2
+        else:
+            self.prefix.splice(
+                f"""
+                def {self.launcher_fn_name}(args):
+                """
+            )
+            prefix_indent = 1
+
+        return prefix_indent
+
+    def get_graph_input_names(self) -> list[str]:
+        return V.graph.graph_input_names
+
+    def write_prefix(self) -> None:
+        assert self.launcher_fn_name is not None
+        self.write_async_compile_wait()
+        prefix_indent = self.write_launcher_fn_call_get_indent()
+
+        with self.prefix.indent(prefix_indent):
+            if config.triton.debug_sync_graph:
+                self.prefix.writeline(V.graph.device_ops.synchronize())
+            phase = V.graph.get_training_phase()
+            if config.annotate_training:
+                self.prefix.writeline(
+                    f"training_annotation = nvtx._device_range_start('{phase}')"
+                )
+
+            if graph_input_names := self.get_graph_input_names():
+                self.write_args(graph_input_names)
+
+            self.codegen_inputs()
+
+            # avoid duplicating asserts for both partition functions and
+            # the call function when using cudagraph partition
+            if not (
+                is_using_cudagraph_partition()
+                and (not is_codegen_graph_partition_subgraph(self))
+            ):
+                self.codegen_input_size_and_nan_asserts()
+
+    def codegen_input_size_and_nan_asserts(self) -> None:
+        if config.size_asserts:
+            self.codegen_input_size_asserts()
+        if config.nan_asserts:
+            self.codegen_input_nan_asserts()
+
+    # this function (and below) takes the graph name as input so
+    # that stream caching happens per graph instance. this
+    # is important for nested subgraph codegening.
+    def write_get_raw_stream(self, device_idx: int, graph_name: str) -> str:
+        self.write_get_raw_stream_header()
+        name = f"stream{device_idx}"
+        if config.triton.autotune_at_compile_time:
+            self.kernel_autotune_calls.writeline(
+                f"{name} = get_raw_stream({device_idx})"
+            )
+            if V.graph.cpp_wrapper:
+                # For cpp wrapper, no need to continue codegen for the main body
+                return name
+        self.writeline(f"{name} = get_raw_stream({device_idx})")
+        return name
+
+    def get_codegened_graph(self):
+        return self.codegened_graph_stack[-1]
+
+    def push_codegened_graph(self, graph):
+        self.codegened_graph_stack.append(graph)
+
+    def pop_codegened_graph(self):
+        return self.codegened_graph_stack.pop()
+
+    def push_computed_sizes(self, computed_sizes):
+        from copy import deepcopy
+
+        return self.computed_sizes_stack.append(deepcopy(computed_sizes))
+
+    def pop_computed_sizes(self):
+        return self.computed_sizes_stack.pop()
+
+    def next_kernel_suffix(self) -> str:
+        return f"{next(self._names_iter)}"
+
+    def codegen_device_guard_enter(self, device_idx: int) -> None:
+        self.writeline(
+            EnterDeviceContextManagerLine(device_idx, self.last_seen_device_guard_index)
+        )
+        if config.triton.autotune_at_compile_time:
+            # mimic logic of EnterDeviceContextManagerLine.codegen for the autotune code block
+            self.write_triton_header_once()
+            self.kernel_autotune_calls.writeline(
+                f"with {V.graph.device_ops.device_guard(device_idx)}:"
+            )
+            self.kernel_autotune_calls.do_indent()
+            if is_codegen_graph_partition_subgraph(self):
+                # Need get_raw_stream for subgraph
+                self.write_get_raw_stream_header()
+            self.kernel_autotune_calls.writeline(
+                f"stream{device_idx} = get_raw_stream({device_idx})"
+            )
+        self.last_seen_device_guard_index = device_idx
+
+    def codegen_device_guard_exit(self) -> None:
+        self.writeline(ExitDeviceContextManagerLine())
+        if config.triton.autotune_at_compile_time:
+            self.kernel_autotune_calls.do_unindent()
+
+    def generate_return(self, output_refs: list[str]) -> None:
+        if output_refs:
+            if config.nan_asserts:
+                self.wrapper_call.writeline(
+                    "return_vars = (" + ", ".join(output_refs) + ", )"
+                )
+                self.wrapper_call.writeline("for var in return_vars:")
+                self.wrapper_call.do_indent()
+                self.wrapper_call.writeline("if isinstance(var, torch.Tensor):")
+                self.wrapper_call.do_indent()
+                self.wrapper_call.writeline("assert not var.isnan().any().item()")
+                self.wrapper_call.writeline("assert not var.isinf().any().item()")
+                self.wrapper_call.do_unindent(2)
+
+            self.wrapper_call.writeline("return (" + ", ".join(output_refs) + ", )")
+        else:
+            self.wrapper_call.writeline("return ()")
+
+    def generate_before_suffix(self, result: IndentedBuffer) -> None:
+        return
+
+    def generate_after_suffix(self, result: IndentedBuffer) -> None:
+        if config.graph_partition:
+            all_partition_name_list = ", ".join(self.all_partition_names) + (
+                "," if len(self.all_partition_names) == 1 else ""
+            )
+
+            result.splice(
+                f"""
+                runner = Runner(partitions=[{all_partition_name_list}])
+                call = runner.call
+                recursively_apply_fns = runner.recursively_apply_fns
+                """
+            )
+
+    def generate_end(self, result: IndentedBuffer) -> None:
+        return
+
+    def generate_fallback_kernel(self, node: ir.FallbackKernel) -> None:
+        self.writeline(ExternKernelAllocLine(self, node))
+
+    def generate_extern_kernel_alloc(self, node: ir.ExternKernelAlloc):
+        node.codegen_comment(self)
+        self.writeline(ExternKernelAllocLine(self, node))
+        if isinstance(node.layout, ir.Layout):
+            node.codegen_size_asserts(self)
+
+    def _generate_extern_kernel_alloc_helper(self, extern_kernel, args):
+        # If it's a NoneLayout then the extern_kernel should essentially be
+        # treated as if it doesn't return anything
+        no_return = isinstance(extern_kernel.layout, ir.NoneLayout)
+        output_name = extern_kernel.get_name()
+        origin_node = extern_kernel.get_origin_node()
+        kernel_name = extern_kernel.get_kernel_name()
+        ending = self.ending
+        if config.memory_planning and "view_as_complex" in kernel_name:
+            # view operation fallbacks cause issues since inductor
+            # doesn't know the memory is still needed and might reuse it.
+            ending = f".clone(){ending}"
+
+        if no_return:
+            self.writeline(f"{self.declare}{kernel_name}({', '.join(args)}){ending}")
+        else:
+            self.writeline(
+                f"{self.declare}{output_name} = {kernel_name}({', '.join(args)}){ending}"
+            )
+            if (
+                self.supports_intermediate_hooks
+                and config.generate_intermediate_hooks
+                and origin_node is not None
+            ):
+                counters["inductor"]["intermediate_hooks"] += 1
+                self.writeline(
+                    f"run_intermediate_hooks({origin_node.name!r}, {output_name})"
+                )
+
+    def generate_extern_kernel_out(
+        self,
+        node: ir.ExternKernelOut,
+    ) -> None:
+        node.codegen_comment(self)
+        self.writeline(ExternKernelOutLine(self, node))
+
+    def _generate_extern_kernel_out_helper(
+        self,
+        kernel: str,
+        out: str,
+        out_view: Optional[str],
+        args: list[str],
+        device: str,
+        stack_traces: Optional[OrderedSet[str]] = None,
+    ) -> None:
+        # add debug printer code for triton kernel calls at (jit) inductor level
+        debug_printer_manager = V.graph.wrapper_code.debug_printer
+        debug_printer_manager.set_printer_args(args, kernel, None, None, "extern")
+        args.append(f"out={out_view if out_view else out}")
+        with debug_printer_manager:
+            self.writeline(f"{kernel}({', '.join(args)})")
+
+    def _generate_tma_descriptor_call_experimental(self, desc, apply_size_hints=False):
+        dims = desc.dims
+        block_dims = desc.block_dims
+        if apply_size_hints:
+            dims = tuple(V.graph.sizevars.atomically_apply_size_hint(d) for d in dims)
+            block_dims = tuple(
+                V.graph.sizevars.atomically_apply_size_hint(d) for d in block_dims
+            )
+
+        ptr = f"{desc.tensor.codegen_reference()}.data_ptr()"
+        # Explicitly call the Python version of val_to_arg_str
+        dims = ", ".join(PythonWrapperCodegen.val_to_arg_str(self, dim) for dim in dims)
+        block_dims = ", ".join(
+            PythonWrapperCodegen.val_to_arg_str(self, dim) for dim in block_dims
+        )
+        element_size = PythonWrapperCodegen.val_to_arg_str(self, desc.element_size)
+        prefix = "triton.tools.experimental_descriptor"
+        fn = f"{prefix}.create_{desc.rank}d_tma_descriptor"
+        args = f"{ptr}, {dims}, {block_dims}, {element_size}"
+        call = f"{fn}({args})"
+        return call
+
+    def _generate_tma_descriptor_call_stable(self, desc, apply_size_hints=False):
+        block_shape = desc.block_shape
+        if apply_size_hints:
+            block_shape = tuple(
+                V.graph.sizevars.atomically_apply_size_hint(d) for d in block_shape
+            )
+
+        prefix = "triton.tools.tensor_descriptor.TensorDescriptor"
+        fn = f"{prefix}.from_tensor"
+        args = f"{desc.tensor.codegen_reference()}, {block_shape}"
+        call = f"{fn}({args})"
+        return call
+
+    def _generate_tma_descriptor_call(self, desc, apply_size_hints=False):
+        if isinstance(desc, ir.TMADescriptorExperimental):
+            return self._generate_tma_descriptor_call_experimental(
+                desc, apply_size_hints
+            )
+        else:
+            assert isinstance(desc, ir.TMADescriptorStable)
+            return self._generate_tma_descriptor_call_stable(desc, apply_size_hints)
+
+    def generate_tma_descriptor(self, desc):
+        call = self._generate_tma_descriptor_call(desc)
+        line = f"{desc.name} = {call}{self.ending}"
+        self.writeline(line)
+
+    def generate_scatter_fallback(self, node: ir.ScatterFallback):
+        self.writeline(ScatterFallbackLine(self, node))
+
+    def _generate_scatter_fallback(
+        self,
+        output,
+        inputs,
+        cpp_kernel_name,
+        python_kernel_name,
+        src_is_tensor,
+        reduce,
+        kwargs,
+        device,
+    ):
+        line = f"{python_kernel_name}({','.join(map(str, inputs))}"
+        if python_kernel_name.startswith("aten.scatter_reduce"):
+            line += ", ".join([""] + kwargs)
+        else:
+            if reduce:
+                line += f", reduce={repr(reduce)}"
+        line += ")"
+        self.writeline(line)
+
+    def generate_index_put_fallback(self, node: ir.IndexPutFallback) -> None:
+        # Collect index tensors into a list.
+        indices: list[Optional[ir.IRNode]] = []
+        valid_indices = node.inputs[2:]
+        iter_valid_indices = iter(valid_indices)
+        for i, _ in enumerate(node.indices):
+            if node.indices[i] is not None:
+                index = next(iter_valid_indices)
+                assert isinstance(index, ir.IRNode)
+                indices.append(index)
+            else:
+                indices.append(None)
+
+        self.writeline(IndexPutFallbackLine(self, node, indices))
+
+    def _generate_index_put_fallback(self, kernel, x, indices, values, accumulate):
+        indices_str = f"[{', '.join(indices)}]"
+        args = [x, indices_str, values, accumulate]
+        self.writeline(self.wrap_kernel_call(kernel, args))
+
+    def generate_fallback_kernel_with_runtime_lookup(
+        self,
+        buf_name: str,
+        python_kernel_name: str,
+        get_args: Callable[[], Sequence[str]],
+        op_overload: Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator],
+        raw_args: Sequence[Any],
+        outputs: Sequence[ir.Buffer],
+    ) -> None:
+        self.writeline(f"{buf_name} = {python_kernel_name}({', '.join(get_args())})")
+
+    def generate(self, is_inference):
+        with dynamo_timed("PythonWrapperCodegen.generate"):
+            return self._generate(is_inference)
+
+    def get_wrapper_call_indent(self) -> int:
+        if config.graph_partition:
+            return 2
+        else:
+            return 1
+
+    @contextlib.contextmanager
+    def set_writeline(self, new: Callable[..., None]) -> Iterator[Callable[..., None]]:
+        old = self.writeline
+        try:
+            self.writeline = new  # type: ignore[method-assign]
+            yield new
+        finally:
+            self.writeline = old  # type: ignore[method-assign]
+
+    def _write_multi_kernel_defs(self) -> None:
+        kernel_defs = self.multi_kernel_state.kernel_defs
+        if config.triton.autotune_at_compile_time:
+            self.kernel_autotune_defs.splice(kernel_defs)
+        else:
+            self.header.splice(kernel_defs)
+
+    def _generate(self, is_inference):
+        if config.profile_bandwidth:
+            self.write_triton_header_once()
+
+        with contextlib.ExitStack() as stack:
+            stack.enter_context(self.wrapper_call.indent())
+            if config.profiler_mark_wrapper_call:
+                self.generate_profiler_mark_wrapper_call(stack)
+            if config.profile_bandwidth:
+                self.generate_start_graph()
+
+            self.run_wrapper_ir_passes(is_inference)
+
+            if config.triton.store_cubin and not config.triton.autotune_at_compile_time:
+                self.generate_reset_kernel_saved_flags()
+
+            # At this point, we shouldn't generate any new memory planning lines.
+            # Override writeline to point at the wrapper call, in case it gets called.
+            with self.set_writeline(self.wrapper_call.writeline):
+                for line in self.lines:
+                    if isinstance(line, WrapperLine):
+                        # pyrefly: ignore [missing-attribute]
+                        line.codegen(self.wrapper_call)
+                    else:
+                        self.wrapper_call.writeline(line)
+
+            self._write_multi_kernel_defs()
+
+            output_refs = self.get_output_refs()
+            self.mark_output_type()
+            if config.triton.debug_sync_graph:
+                self.wrapper_call.writeline(V.graph.device_ops.synchronize())
+
+            if config.profile_bandwidth:
+                self.generate_end_graph()
+
+            if config.triton.store_cubin and not config.triton.autotune_at_compile_time:
+                self.generate_save_uncompiled_kernels()
+
+            if config.triton.autotune_at_compile_time:
+                self.generate_and_run_autotune_block()
+
+            # cpp_wrapper currently doesn't support nvtx
+            if config.annotate_training and not config.cpp_wrapper:
+                self.wrapper_call.writeline(
+                    "nvtx._device_range_end(training_annotation)"
+                )
+            self.generate_return(output_refs)
+
+        # Assemble the final code from sections.
+        result = IndentedBuffer()
+        result.splice(self.imports)
+        result.writeline("")
+        result.splice(self.header)
+        # We do not want the cpp header for intermediate const graph. Headers would be
+        # rendered by the main module instead.
+        if V.graph.aot_mode and V.graph.cpp_wrapper and V.graph.is_const_graph:
+            result = IndentedBuffer()
+
+        # Add subgraph definitions to the result
+        result.splice(self.subgraph_definitions)
+        self.finalize_prefix()
+        result.splice(self.prefix)
+
+        wrapper_call_indent = self.get_wrapper_call_indent()
+
+        with result.indent(wrapper_call_indent):
+            result.splice(self.wrapper_call)
+
+        self.generate_before_suffix(result)
+        result.splice(self.suffix)
+        self.generate_after_suffix(result)
+
+        self.generate_end(result)
+
+        self.add_benchmark_harness(result)
+
+        return (
+            result.getvaluewithlinemap(),
+            self.kernel_declarations.getvaluewithlinemap(),
+        )
+
+    def generate_and_run_autotune_block(self):
+        """
+        Compose self.kernel_autotune_defs and self.kernel_autotune_calls into a single block of
+        code and execute it to trigger Triton kernel compilation and auto-tuning
+        """
+        self.kernel_autotune_defs.splice(
+            """
+            async_compile.wait(globals())
+            del async_compile
+        """
+        )
+        scope = {}  # type: ignore[var-annotated]
+        if config.triton.autotune_at_compile_time and V.graph.autotuning_inputs:
+            scope = {
+                self.get_autotuning_input_name(idx): v  # type: ignore[attr-defined]
+                for idx, v in enumerate(V.graph.autotuning_inputs)
+            }
+        tuning_code = (
+            self.kernel_autotune_defs.getvalue()
+            + "\n"
+            + self.kernel_autotune_calls.getvalue()
+        )
+        if output_code_log.level == logging.DEBUG:
+            # Save the autotuning code block into a file
+            # Create a temporary file
+            with tempfile.NamedTemporaryFile(
+                dir=cache_dir(), suffix=".py", delete=False
+            ) as f:
+                f.write(tuning_code.encode("utf-8"))
+                file_path = f.name
+            output_code_log.debug(
+                "Auto-tuning code written to %s",
+                file_path,
+            )
+        trace_structured(
+            "artifact",
+            metadata_fn=lambda: {
+                "name": "inductor_autotune_at_compile_time_code",
+                "encoding": "string",
+            },
+            payload_fn=lambda: tuning_code,
+        )
+        # Execute the code to autotune kernels
+        try:
+            exec(tuning_code, scope)
+        except Exception as e:
+            raise RuntimeError(f"Failed to run autotuning code block: {e}") from e
+
+    def memory_plan(self):
+        from .memory_planning import MemoryPlanner
+
+        self.lines = MemoryPlanner(self).plan(self.lines)
+
+    def memory_plan_reuse(self):
+        outputs = self.get_graph_outputs()
+        out_names = V.graph._get_output_names(outputs)
+
+        while (
+            self.lines
+            and isinstance(self.lines[-1], MemoryPlanningLine)
+            # TODO: this seems legit, NullLine has no node
+            and self.lines[-1].node.name not in out_names  # type: ignore[attr-defined]
+        ):
+            # these lines will be pointless
+            self.lines.pop()
+
+        # codegen allocations in two passes
+        planning_states = [MemoryPlanningState()]
+        past_planning_states = []
+        for i in range(len(self.lines)):
+            line = self.lines[i]
+            if isinstance(line, MemoryPlanningLine):
+                self.lines[i] = line.plan(planning_states[-1])
+            elif isinstance(line, EnterSubgraphLine):
+                planning_states.append(MemoryPlanningState())
+            elif isinstance(line, ExitSubgraphLine):
+                past_planning_states.append(planning_states.pop())
+        past_planning_states.append(planning_states.pop())
+        assert len(planning_states) == 0
+
+        # conservatively use the sum of all allocated buffer sizes
+        # in potentially nested scopes as the total allocated size
+        # FIXME(rec): not used
+        _total_allocated_buffer_size = sum(
+            s.total_allocated_buffer_size for s in past_planning_states
+        )
+
+    def run_wrapper_ir_passes(self, is_inference: bool):
+        # We disable planning during training because it presently increases peak memory consumption.
+        if is_inference and config.memory_planning:
+            self.memory_plan()
+        else:
+            if config.allow_buffer_reuse:
+                self.estimate_peak = EfficientPeakEstimate()
+            self.memory_plan_reuse()
+
+    def codegen_input_symbol_assignment(
+        self,
+        name: str,
+        value: ir.TensorBox,
+        bound_vars: OrderedSet[sympy.Symbol],
+    ):
+        code = self.prefix
+
+        @functools.cache
+        def sizeof(name):
+            code.writeline(f"{name}_size = {name}.size()")
+            return f"{name}_size"
+
+        @functools.cache
+        def strideof(name):
+            code.writeline(f"{name}_stride = {name}.stride()")
+            return f"{name}_stride"
+
+        if isinstance(value, sympy.Expr):
+            if not isinstance(value, sympy.Symbol) or value in bound_vars:
+                return
+            code.writeline(f"{value} = {name}")
+            bound_vars.add(value)
+        elif isinstance(value, ir.TensorBox):
+            for dim, size in enumerate(value.get_size()):
+                if isinstance(size, sympy.Symbol) and size not in bound_vars:
+                    code.writeline(f"{size} = {sizeof(name)}[{dim}]")
+                    bound_vars.add(size)
+            for dim, stride in enumerate(value.get_stride()):
+                if isinstance(stride, sympy.Symbol) and stride not in bound_vars:
+                    code.writeline(f"{stride} = {strideof(name)}[{dim}]")
+                    bound_vars.add(stride)
+        elif isinstance(value, ir.TorchBindObject):
+            return
+        elif isinstance(value, ir.GeneratorState):
+            return
+        else:
+            if torch._inductor.config.graph_partition:
+                pass
+            else:
+                raise AssertionError(f"Unknown value type: {type(value)}")
+
+    def codegen_inputs(self):
+        """Assign all symbolic shapes to locals"""
+        bound_vars = OrderedSet[sympy.Symbol]()
+        # There is a subtle case in the cpp wrapper codegen which requires generating
+        # symbol inputs first followed by non-symbol ones.
+        #
+        # When a dynamic size constraint specified at the Export time is an expression,
+        # we need to solve that expression to proper define a symbol in cpp. Thus we
+        # are enforcing this iterating order here to make sure all plain size symbols
+        # are defined first.
+        graph_inputs = self.get_graph_inputs()
+        inputs = [
+            (k, v) for k, v in graph_inputs.items() if isinstance(v, sympy.Symbol)
+        ] + [(k, v) for k, v in graph_inputs.items() if not isinstance(v, sympy.Symbol)]
+        for name, value in inputs:
+            self.codegen_input_symbol_assignment(name, value, bound_vars)
+
+        def _verify_input_symbol_assignment(
+            value: ir.TensorBox,
+            bound_vars: OrderedSet[sympy.Symbol],
+        ):
+            for expr in chain.from_iterable([value.get_size(), value.get_stride()]):
+                if not isinstance(expr, Expr) or isinstance(expr, sympy.Symbol):
+                    continue
+
+                undefined_symbols = [
+                    sym for sym in expr.free_symbols if sym not in bound_vars
+                ]
+                if len(undefined_symbols) > 0:
+                    raise AssertionError(
+                        f"For {expr}, expected {undefined_symbols} to have been codegen-ed."
+                    )
+
+        # For inputs with size/strides which contain sympy expressions, we can
+        # encounter symbols that weren't defined yet. Now, let's check each
+        # symbol is defined.
+        for _, value in inputs:
+            if not isinstance(value, ir.TensorBox):
+                continue
+            _verify_input_symbol_assignment(value, bound_vars)
+
+    def ensure_size_computed(self, sym: sympy.Symbol):
+        if isinstance(sym, sympy.Symbol) and symbol_is_type(sym, SymT.PRECOMPUTED_SIZE):
+            if sym in self.computed_sizes:
+                return
+            self.computed_sizes.add(sym)
+            expr = V.graph.sizevars.inv_precomputed_replacements[sym]
+            arg = SymbolicCallArg(sym, expr)
+            self.writeline(SymbolicCallArgLine(self, arg, V.graph))
+
+    def finalize_prefix(self):
+        pass
+
+    def codegen_cpp_sizevar(self, x: Expr, *, simplify: bool = True) -> str:
+        raise RuntimeError("codegen_cpp_sizevar is only implemented for cpp_wrapper!")
+
+    def codegen_python_sizevar(self, x: Expr, *, simplify: bool = True) -> str:
+        return pexpr(x, simplify=simplify)
+
+    def codegen_sizevar(self, x: Expr) -> str:
+        return self.codegen_python_sizevar(x)
+
+    def codegen_tuple_access(self, basename: str, name: str, index: str) -> str:
+        return f"{basename}[{index}]"
+
+    def codegen_python_shape_tuple(self, shape: Sequence[Expr]) -> str:
+        parts = [*map(self.codegen_python_sizevar, shape)]
+        if len(parts) == 0:
+            return "()"
+        if len(parts) == 1:
+            return f"({parts[0]}, )"
+        return f"({', '.join(parts)})"
+
+    def codegen_shape_tuple(self, shape: Sequence[Expr]) -> str:
+        return self.codegen_python_shape_tuple(shape)
+
+    def codegen_alloc_from_pool(
+        self, name, offset, dtype, shape, stride
+    ) -> tuple[str, list[str]]:
+        return "alloc_from_pool({})".format(
+            ", ".join(
+                [
+                    name,
+                    pexpr(offset),  # bytes not numel
+                    str(dtype),
+                    self.codegen_python_shape_tuple(shape),
+                    self.codegen_python_shape_tuple(stride),
+                ]
+            )
+        ), []
+
+    def codegen_reinterpret_view(
+        self,
+        data,
+        size,
+        stride,
+        offset,
+        writeline: Callable[..., None],
+        dtype=None,
+    ) -> str:
+        # Get the innermost buffer's layout info to help reinterpret view.
+        # Consider a chain of (ReinterpretView <- TensorBox| StorageBox)... <- buffer
+        # If we only use x.data to determine the reinterpret, we may get wrong layout.
+        # For example:
+        # x = ReinterpretView(
+        #       Storage(
+        #         ReinterpretView(
+        #           storage(
+        #             Buffer(name='buf0', layout=(size=(2, 5, 10), ...)
+        #           ),
+        #           layout=(10, 10),
+        #         ),
+        #       ),
+        #       layout=(10, 10),
+        #     )
+        # In this case, x.data.layout == x.layout is (10, 10), the reinterpret view will return buf0,
+        # but buf0 need to be viewed from (2, 5, 10) to (10, 10).
+        # So we need to dig into the chain to find the innermost buffer's layout.
+        d_size, d_stride, d_offset, d_dtype, collapsible = (
+            codegen_reinterpret_view_helper(data)
+        )
+
+        def apply_reinterpret(
+            name, tgt_size, tgt_stride, tgt_offset, cast_dtype, base_dtype
+        ):
+            s = self.codegen_python_shape_tuple(tgt_size)
+            st = self.codegen_python_shape_tuple(tgt_stride)
+            off = self.codegen_sizevar(tgt_offset)
+            expr = f"reinterpret_tensor({name}, {s}, {st}, {off})"
+            if cast_dtype is not None and cast_dtype != base_dtype:
+                return f"aten.view.dtype({expr}, {cast_dtype})"
+            return expr
+
+        name = data.get_name()
+        collapsed = collapsible and offset == d_offset
+        if collapsed:
+            same_layout = size == d_size and stride == d_stride
+            base_dtype = d_dtype
+        else:
+            same_layout = (
+                size == data.layout.size
+                and stride == data.layout.stride
+                and offset == data.layout.offset
+            )
+            base_dtype = data.dtype
+
+        if same_layout:
+            if dtype is not None and dtype != base_dtype:
+                return f"aten.view.dtype({name}, {dtype})"
+            return f"{name}"
+
+        return apply_reinterpret(name, size, stride, offset, dtype, base_dtype)
+
+    def codegen_device_copy(self, src, dst, non_blocking: Union[bool, str]):
+        self.writeline(f"{dst}.copy_({src}, {non_blocking})")
+
+    def codegen_multi_output(self, node: ir.MultiOutput):
+        result_name = node.get_name()
+        arg_name = node.input_name(0)
+        self.writeline(MultiOutputLine(self, result_name, arg_name, node.indices))
+
+    def codegen_dynamic_select_index(self, node, clamp):
+        index_str = f"{node.index} + {node.size} if {node.index} < 0 else {node.index}"
+        if clamp:
+            index_str = f"max(0, min({node.size}, {index_str}))"
+        self.writeline(
+            f"{node.unbacked_offset_symbol} = {node.base_offset} + {node.base_dim_stride} * ({index_str})"
+        )
+        # record in unbacked_symbol_decls so we won't generate a declaration of the symbol again
+        self.unbacked_symbol_decls.add(str(node.unbacked_offset_symbol))
+
+    def codegen_dynamic_slice_size(self, node):
+        def clamp_index(x):
+            pos = self.codegen_sizevar(sympy.Max(0, sympy.Min(x, node.size)))
+            neg = self.codegen_sizevar(
+                sympy.Max(0, sympy.Min(x + node.size, node.size))
+            )
+            x_cond = self.codegen_sizevar(x)
+            return f"{pos} if {x_cond} >= 0 else {neg}"
+
+        def codegen_with_step(start_var, end_var, step):
+            if step == 1:
+                return f"{end_var} - {start_var}"
+            step_ = self.codegen_sizevar(step)
+            return f"({end_var} - {start_var} + {step_} - 1) // {step_}"
+
+        # codegen start, end
+        sym = node.unbacked_size_symbol
+        start = clamp_index(node.start)
+        end = clamp_index(node.end)
+        self.writeline(f"{sym}_start = {start}")
+        self.writeline(f"{sym}_end = {end}")
+        with_step = codegen_with_step(f"{sym}_start", f"{sym}_end", node.step)
+        self.writeline(f"{sym} = max(0, {with_step})")
+        self.unbacked_symbol_decls.add(str(node.unbacked_size_symbol))
+
+    def codegen_dynamic_scalar(self, node):
+        self.writeline(DynamicScalarLine(self, node))
+
+    def _codegen_dynamic_scalar(self, node):
+        (data,) = (t.codegen_reference() for t in node.inputs)
+        if len(node.keypath) == 0:
+            self.writeline(f"{node.sym} = {data}.item()")
+        elif len(node.keypath) == 1 and isinstance(node.keypath[0], ConvertIntKey):
+            self.writeline(f"{node.sym} = 1 if {data}.item() else 0")
+        elif len(node.keypath) == 1 and isinstance(node.keypath[0], DivideByKey):
+            self.writeline(f"{node.sym}_undivided = {data}.item()")
+            self.writeline(
+                f"assert {node.sym}_undivided % {node.keypath[0].divisor} == 0, "
+                f"f'{{{node.sym}_undivided}} not divisible by {node.keypath[0].divisor}'"
+            )
+            self.writeline(
+                f"{node.sym} = {node.sym}_undivided // {node.keypath[0].divisor}"
+            )
+        else:
+            raise AssertionError(f"unrecognized keypath {node.keypath}")
+        # No one should ever use this buffer, but for uniformity
+        # define the variable and assign it None
+        self.writeline(f"{node.get_name()} = None")
+
+    def benchmark_compiled_module(self, output):
+        def add_fake_input(name, shape, stride, device, dtype):
+            output.writeline(
+                f"{name} = rand_strided("
+                f"{self.codegen_python_shape_tuple(shape)}, "
+                f"{self.codegen_python_shape_tuple(stride)}, "
+                f"device='{device}', dtype={dtype})"
+            )
+
+        def add_expr_input(name, val):
+            output.writeline(f"{name} = {val}")
+
+        def add_torchbind_input(name, value):
+            if value is None:
+                output.writeline(f"{name} = None")
+                return
+
+            import pickle
+
+            assert isinstance(value, torch.ScriptObject)
+
+            output.writeline(f"{name} = pickle.loads({pickle.dumps(value)!r})")
+
+        output.writelines(
+            ["", "", "def benchmark_compiled_module(times=10, repeat=10):"]
+        )
+        with output.indent():
+            output.splice(
+                """
+                from torch._dynamo.testing import rand_strided
+                from torch._inductor.utils import print_performance
+                """,
+                strip=True,
+            )
+
+            for name, value in V.graph.constants.items():
+                # all the constants are global variables, that's why we need
+                # these 'global var_name' lines
+                output.writeline(f"global {name}")
+                add_fake_input(
+                    name, value.size(), value.stride(), value.device, value.dtype
+                )
+
+            if len(V.graph.torchbind_constants) > 0:
+                output.writeline("import pickle")
+                for name, torchbind_obj in V.graph.torchbind_constants.items():
+                    # all the constants are global variables, that's why we need
+                    # these 'global var_name' lines
+                    output.writeline(f"global {name}")
+                    add_torchbind_input(name, torchbind_obj)
+
+            for name, value in V.graph.graph_inputs.items():
+                if isinstance(value, sympy.Symbol) and isinstance(
+                    V.graph.sizevars.var_to_val.get(value, None), SingletonInt
+                ):
+                    # Inductor should only work with dense -> dense graph, and
+                    # SingletonInts belong to metadata that should only live on
+                    # the subclass.
+                    continue
+                if isinstance(value, ir.TorchBindObject):
+                    if len(V.graph.torchbind_constants) == 0:
+                        # otherwise we have already imported the pickle package
+                        output.writeline("import pickle")
+                    output.writeline(f"global {name}")
+                    add_torchbind_input(name, value.get_real_obj())
+                elif isinstance(value, sympy.Expr):  # Don't need to add symbolic
+                    # TODO: this fallback and those below actually will generate possibly
+                    # invalid benchmark code, because it's not guaranteed 42
+                    # is actually a valid value for the kernel in question.
+                    # See https://github.com/pytorch/pytorch/issues/124686
+                    add_expr_input(name, V.graph.sizevars.size_hint(value, fallback=42))
+                elif isinstance(value, ir.GeneratorState):
+                    add_expr_input(
+                        name,
+                        f"torch.cuda.default_generators[{value.device.index}].graphsafe_get_state()",
+                    )
+                else:
+                    shape = [
+                        V.graph.sizevars.size_hint(x, fallback=42)
+                        for x in value.get_size()
+                    ]
+                    stride = [
+                        V.graph.sizevars.size_hint(x, fallback=42)
+                        for x in value.get_stride()
+                    ]
+                    add_fake_input(
+                        name,
+                        shape,
+                        stride,
+                        value.get_device(),
+                        value.get_dtype(),
+                    )
+
+            call_str = f"call([{', '.join(V.graph.graph_inputs.keys())}])"
+            output.writeline(f"fn = lambda: {call_str}")
+            output.writeline("return print_performance(fn, times=times, repeat=repeat)")
+
+    def add_benchmark_harness(self, output):
+        """
+        Append a benchmark harness to generated code for debugging
+        """
+        if not config.benchmark_harness:
+            return
+
+        self.benchmark_compiled_module(output)
+
+        output.writelines(["", "", 'if __name__ == "__main__":'])
+        with output.indent():
+            output.writelines(
+                [
+                    "from torch._inductor.wrapper_benchmark import compiled_module_main",
+                    f"compiled_module_main('{get_benchmark_name()}', benchmark_compiled_module)",
+                ]
+            )
+
+    def define_kernel(
+        self,
+        kernel_name: str,
+        kernel_body: str,
+        metadata: Optional[str] = None,
+        gpu: bool = True,
+        cpp_definition: Optional[str] = None,
+    ):
+        self.writeline(
+            KernelDefinitionLine(
+                self,
+                kernel_name,
+                kernel_body,
+                metadata=metadata,
+                gpu=gpu,
+                cpp_definition=cpp_definition,
+            )
+        )
+
+    @staticmethod
+    def _format_kernel_definition(
+        kernel_name: str, kernel_body: str, metadata: Optional[str] = None
+    ):
+        if config.triton.autotune_at_compile_time and metadata:
+            # Generating autotune block
+            # Need to replace C++ comment starter with Python comment starter
+            metadata = re.sub(r"^// ", "# ", metadata, flags=re.MULTILINE)
+        metadata_comment = f"{metadata}\n" if metadata else ""
+        body = f"\n\n{metadata_comment}{kernel_name} = {kernel_body}"
+        return body
+
+    def _define_kernel_helper(
+        self,
+        kernel_name: str,
+        kernel_body: str,
+        metadata: Optional[str] = None,
+        gpu: bool = True,
+        cpp_definition: Optional[str] = None,
+    ):
+        if config.triton.autotune_at_compile_time and gpu:
+            body = self._format_kernel_definition(
+                kernel_name, kernel_body, metadata=metadata
+            )
+            self.kernel_autotune_defs.splice(body)
+            if V.graph.cpp_wrapper:
+                # For cpp wrapper, no need to continue codegen for the main body
+                return
+
+        body = self._format_kernel_definition(
+            kernel_name, kernel_body, metadata=metadata
+        )
+        self.header.splice(body)
+
+    def define_subgraph_launcher_fn(self, name: str, subgraph_code):
+        self.subgraph_definitions.splice(subgraph_code.value)
+
+    def define_user_defined_triton_kernel(
+        self,
+        kernel,
+        configs,
+        kwargs,
+        restore_value_args,
+        reset_to_zero_args,
+        grids: list[list[Union[int, sympy.Expr]]],
+    ):
+        from ..runtime.triton_heuristics import (
+            config_to_dict,
+            FixedGrid,
+            PrecomputedGrid,
+        )
+        from .common import (
+            ConstexprArg,
+            KernelArgType,
+            SizeArg,
+            TensorArg,
+            TMADescriptorArg,
+        )
+        from .triton import gen_common_triton_imports, TritonKernel
+
+        original_name = kernel.__name__
+        signature: list[KernelArgType] = []
+        constants: dict[str, Any] = {}
+        arg_indices: list[int] = []
+        equal_to_1_args: list[str] = []
+
+        def add_to_signature(idx, arg):
+            signature.append(arg)
+            arg_indices.append(idx)
+
+        def add_arg(idx, arg, is_constexpr=False, equals_1=False, equals_none=False):
+            if is_constexpr:
+                if triton_version_uses_attrs_dict():
+                    # tl.constexpr args appear in the signature in new versions of triton,
+                    # but not in old versions of triton.
+                    add_to_signature(idx, arg)
+
+                if arg.name in kwargs:
+                    # the arg may not appear in kwargs if it is an autotuned arg.
+                    # in this case, it will be added in triton_heuristics after autotuning.
+                    constants[arg.name] = kwargs[arg.name]
+
+            else:
+                # the only case where arg name isn't in kwargs, should be
+                # when the arg is a constexpr.
+                assert arg.name in kwargs
+
+                if equals_1:
+                    if triton_version_uses_attrs_dict():
+                        # new versions of triton: add the equal-to-1 arg in the signature (labeled as "constexpr"),
+                        #                         and add the arg as a constant.
+                        # new versions of triton: add the equal-to-1 arg in the signature (labeled as, e.g., "i32"),
+                        #                         and add the arg as a constant.
+                        add_to_signature(idx, ConstexprArg(name=arg.name))
+                    else:
+                        add_to_signature(idx, arg)
+                    constants[arg.name] = 1
+                elif equals_none:
+                    if triton_version_uses_attrs_dict():
+                        # new versions of triton: add the none arg in the signature (as a constexpr arg) and as a constant
+                        # old versions of triton: include the none arg as a constant (but not in the signature)
+                        add_to_signature(idx, ConstexprArg(name=arg.name))
+                    constants[arg.name] = None
+                else:
+                    add_to_signature(idx, arg)
+
+        arg_names = [p.name for p in kernel.params]
+        constexprs = [p.num for p in kernel.params if p.is_constexpr]
+        for idx, key in enumerate(arg_names):
+            if idx in constexprs:
+                add_arg(idx, ConstexprArg(name=key), is_constexpr=True)
+                continue
+
+            if key not in kwargs:
+                continue
+
+            arg = kwargs[key]
+
+            if kwargs[key] is None:
+                add_arg(idx, ConstexprArg(name=key), equals_none=True)
+            else:
+                if isinstance(arg, ir.TMADescriptor):
+                    api_type, block_shape, dtype = (
+                        ("stable", arg.block_shape, arg.tensor.get_dtype())
+                        if isinstance(arg, ir.TMADescriptorStable)
+                        else ("experimental", None, None)
+                    )
+                    add_arg(
+                        idx,
+                        TMADescriptorArg(
+                            name=key,
+                            api_type=api_type,
+                            block_shape=block_shape,
+                            dtype=dtype,
+                        ),
+                    )
+                elif isinstance(arg, ir.Buffer):
+                    add_arg(
+                        idx,
+                        TensorArg(
+                            name=key,
+                            buffer=arg.get_name(),
+                            dtype=arg.get_dtype(),
+                        ),
+                    )
+                elif isinstance(arg, ir.ReinterpretView):
+                    # for ReinterpretView we use the underlying
+                    # buffer name and note the (possibly non-zero)
+                    # offset relative to the underlying buffer
+                    add_arg(
+                        idx,
+                        TensorArg(
+                            name=key,
+                            buffer=arg.data.get_name(),
+                            dtype=arg.get_dtype(),
+                            offset=arg.layout.offset,
+                        ),
+                    )
+                else:
+                    equals_1 = isinstance(
+                        arg, (int, sympy.Integer)
+                    ) and V.graph.sizevars.statically_known_equals(
+                        arg,
+                        1,  # type: ignore[arg-type]
+                    )
+                    add_arg(idx, SizeArg(key, arg), equals_1=equals_1)
+
+        triton_signature = signature_to_meta(
+            signature,
+            size_dtype=None,  # try to infer based on symints
+            indices=arg_indices,
+            argdefs=[ArgName(x) for x in kernel.arg_names],
+        )
+        triton_meta: dict[str, Any] = {
+            "signature": triton_signature,
+            "device": DeviceProperties.create(V.graph.get_current_device_or_throw()),
+            # Triton compiler includes equal_to_1 args into constants even
+            # when they are not constexpr. otherwise there may be a segfault
+            # during launching the Inductor-compiled Triton kernel.
+            # TODO(aakhundov): add None args to constants, too. currently, this
+            # causes CUDA errors in test_aot_inductor.test_triton_kernel_with_none_input.
+            # https://github.com/pytorch/pytorch/issues/120478#issuecomment-1962822307
+            # https://github.com/triton-lang/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384
+            "constants": {
+                **constants,
+                **dict.fromkeys(equal_to_1_args, 1),
+            },
+            "configs": [
+                config_of(
+                    signature,
+                    indices=arg_indices,
+                )
+            ],
+        }
+
+        if restore_value_args:
+            triton_meta["restore_value"] = tuple(restore_value_args)
+
+        if reset_to_zero_args:
+            triton_meta["reset_to_zero"] = tuple(reset_to_zero_args)
+
+        if len(grids) == 1:
+            # compute the grid in the wrapper and pass it in as an arg
+            inductor_meta: dict[str, Any] = FixedGrid.setup_grid_as_args()
+            extra_launcher_call_args = [*map(sympy.sympify, grids[0])]
+        else:
+
+            def rename_sizes_for_launcher(expr: Union[int, sympy.Expr]) -> sympy.Expr:
+                if isinstance(expr, sympy.Expr):
+                    symbols = [*expr.free_symbols]
+                    if not symbols:
+                        return expr
+                    symbols.sort(key=str)
+                    for sym in symbols:
+                        if sym in extra_launcher_args:
+                            continue
+                        extra_launcher_args[sym] = sympy.Symbol(
+                            f"_launcher_s{len(extra_launcher_args)}"
+                        )
+                    return sympy_subs(expr, extra_launcher_args)
+                assert isinstance(expr, int)
+                return sympy.Integer(expr)
+
+            extra_launcher_args: dict[sympy.Symbol, sympy.Symbol] = {}
+            grids = [[*map(rename_sizes_for_launcher, grid)] for grid in grids]
+
+            assert grids and len(grids) == len(configs)
+            precomputed_grids = []
+            for grid, cfg in sorted(
+                zip(grids, configs), key=lambda x: len(x[1].kwargs), reverse=True
+            ):
+                precomputed_grids.append(
+                    {
+                        "config": config_to_dict(cfg),
+                        "python": [*map(pexpr, grid)],
+                        "cpp": [*map(cexpr, grid)],
+                        "python_slow": [*map(pexpr, grid)],
+                    }
+                )
+            inductor_meta = {
+                "grid_type": PrecomputedGrid.__name__,
+                "precomputed_grids": precomputed_grids,
+                "extra_launcher_args": [*map(str, extra_launcher_args.values())],
+            }
+            extra_launcher_call_args = [*extra_launcher_args.keys()]
+
+        # Distinguish between different functions using function id
+        cache_key: Any = [id(kernel.fn)]
+        if len(configs) > 0:
+            for arg in kwargs.values():
+                # We need to key on non tensor arg only in autotune mode
+                if not isinstance(arg, (ir.Buffer, ir.ReinterpretView)):
+                    cache_key.append(arg)
+        cache_key.append(str(triton_meta))
+        cache_key.extend(str(inductor_meta))
+        cache_key = tuple(cache_key)
+        if cache_key in self.user_defined_kernel_cache:
+            return (
+                *self.user_defined_kernel_cache[cache_key],
+                extra_launcher_call_args,
+            )
+
+        name = f"{original_name}_{len(self.user_defined_kernel_cache)}"
+
+        compile_wrapper = IndentedBuffer()
+        if config.triton.unique_user_kernel_names:
+            compile_wrapper.writeline(f"async_compile.triton({name!r}, '''")
+        else:
+            compile_wrapper.writeline(f"async_compile.triton({original_name!r}, '''")
+
+        inductor_meta["kernel_name"] = name
+        inductor_meta.update(TritonKernel.inductor_meta_common())
+
+        compile_wrapper.splice(gen_common_triton_imports())
+        compile_wrapper.splice(
+            f"""
+            @triton_heuristics.user_autotune(
+                configs={[*map(config_to_dict, configs)]!r},
+                inductor_meta={inductor_meta!r},
+                triton_meta={triton_meta!r},
+                filename=__file__,
+                custom_kernel=True,
+            )
+            @triton.jit
+            """
+        )
+        kernel_src = user_defined_triton_kernel_transitive_closure_source_code(kernel)
+        if config.triton.unique_user_kernel_names:
+            # We replace the original_name with the unique name.
+            kernel_src = kernel_src.replace(f"def {original_name}(", f"def {name}(")
+        kernel_src = kernel_src.replace("'''", "\\'\\'\\'")
+        compile_wrapper.splice(kernel_src)
+
+        current_device = V.graph.get_current_device_or_throw()
+        compile_wrapper.writeline(f"''', device_str='{current_device.type}')")
+        _, lineno = inspect.getsourcelines(kernel.fn)
+        srcfile = inspect.getsourcefile(kernel.fn)
+        metadata = f"# Original path: {srcfile}:{lineno}"
+        self.define_kernel(
+            name,
+            compile_wrapper.getvalue(),
+            metadata,
+        )
+        # Add to the cache for the next use
+        self.user_defined_kernel_cache[cache_key] = (name, triton_meta)
+        return name, triton_meta, extra_launcher_call_args
+
+    def generate_numel_expr(self, kernel_name: str, tree, suffix: Optional[str] = None):
+        sym_name = f"{kernel_name}_{tree.prefix}numel"
+        if suffix is not None:
+            sym_name += f"_{suffix}"
+        sym = sympy.Symbol(sym_name, is_integer=True, is_positive=True)
+
+        # We can get symbolic expressions here, like s0*64
+        # It is fine to have them here, but we need to handle them correctly as their own type
+        # This is tricky to do, so we wrap in a custom type, distinct from scalars, but also from sympy*
+        # scalars as well.
+        # This is handled in `generate_args_decl` which has a correct comment of: TODO: only works for
+        # constant now, need type info. I agree, this needs type info, and while this is not true type info
+        # it suffices as a type hint for the purposes of producing the correct code for this type.
+        arg = SymbolicCallArg(sym, tree.numel)
+
+        is_benchmark_kernel = kernel_name == ""
+        if not is_benchmark_kernel:
+            self.writeline(SymbolicCallArgLine(self, arg, V.graph))
+
+        return arg
+
+    def _generate_symbolic_call_arg_helper(
+        self, arg: SymbolicCallArg, graph: GraphLowering
+    ) -> None:
+        self.writeline(f"{arg.inner} = {pexpr(arg.inner_expr)}")
+
+    def generate_workspace_allocation(self, ws: WorkspaceArg):
+        name = ws.get_name()
+        line = AllocateLine(self, ws)
+        if ws.zero_mode == WorkspaceZeroMode.UNINITIALIZED:
+            self.writeline(line)
+        elif ws.zero_mode == WorkspaceZeroMode.ZERO_ON_CALL:
+            self.writeline(line)
+            self.writeline(self.make_zero_buffer(name))
+        elif ws.zero_mode == WorkspaceZeroMode.ZERO_PER_GRAPH:
+            prior = self.allocated_workspaces.get(name)
+            if prior:
+                assert isinstance(prior, AllocateLine) and isinstance(
+                    prior.node, WorkspaceArg
+                )
+                # expand existing allocation
+                prior.node = WorkspaceArg.maximum(prior.node, ws)
+            else:
+                self.writeline(line)
+                self.writeline(self.make_zero_buffer(name))
+                self.allocated_workspaces[name] = line
+        else:
+            raise AssertionError(ws.zero_mode)
+
+        if config.triton.autotune_at_compile_time:
+            self.kernel_autotune_calls.writeline(
+                PythonWrapperCodegen.make_allocation(
+                    self,
+                    name,
+                    ws.device,
+                    ws.dtype,
+                    shape=(V.graph.sizevars.size_hint(ws.count),),
+                    stride=(1,),
+                )
+            )
+            if ws.zero_mode != WorkspaceZeroMode.UNINITIALIZED:
+                self.kernel_autotune_calls.writeline(
+                    PythonWrapperCodegen.make_zero_buffer(self, name)
+                )
+
+    def generate_workspace_deallocation(self, ws: WorkspaceArg):
+        if ws.zero_mode != WorkspaceZeroMode.ZERO_PER_GRAPH:
+            self.writeline(FreeIfNotReusedLine(self, ws))
+
+    def make_zero_buffer(self, name):
+        return f"{name}.zero_(){self.ending}"
+
+    def wrap_kernel_call(self, name, call_args):
+        return f"{name}({', '.join(call_args)}){self.ending}"
+
+    def generate_profiler_mark_wrapper_call(self, stack):
+        self.wrapper_call.writeline("from torch.profiler import record_function")
+        self.wrapper_call.writeline(
+            f"with record_function('graph_{V.graph.graph_id}_inductor_wrapper_call'):"
+        )
+        stack.enter_context(self.wrapper_call.indent())
+
+    def generate_start_graph(self):
+        self.wrapper_call.writeline("start_graph()")
+
+    def generate_end_graph(self):
+        self.wrapper_call.writeline(f"end_graph({config.profile_bandwidth_output!r})")
+
+    def generate_reset_kernel_saved_flags(self):
+        self.wrapper_call.splice(
+            f"""
+            for kernel in globals().values():
+                if isinstance(kernel, {triton_heuristics.__name__}.CachingAutotuner):
+                    kernel.cuda_kernel_saved = False
+            """
+        )
+
+    def generate_save_uncompiled_kernels(self):
+        """
+        Precompile and save the CUBINs of the Triton kernels that haven't
+        been precompiled and saved as a side effect of running the generated
+        JIT model (Python wrapper). This can happen when the model contains
+        control flow: only one pass through the control flow operators covers
+        the kernels that are saved, the remaining kernels are not launched,
+        hence not saved. The main purpose of this codegen is to compile and
+        save the Triton kernels outside the active control flow path for
+        subsequent AOTInductor code generation and compilation.
+        """
+        self.wrapper_call.splice(
+            f"""
+            for kernel in globals().values():
+                if isinstance(kernel, {triton_heuristics.__name__}.CachingAutotuner):
+                    if not kernel.cuda_kernel_saved:
+                        if len(kernel.launchers) == 0:
+                            kernel.precompile()
+                        kernel.save_gpu_kernel(
+                            stream="stream",  # use dummy stream
+                            launcher=kernel.launchers[0],
+                        )
+            """
+        )
+
+    def prepare_triton_kernel_call(self, call_args):
+        def wrap_arg(arg):
+            if isinstance(arg, str):
+                # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar
+                return arg + ".item()" if should_unwrap_unspec_arg(arg) else arg
+            elif isinstance(arg, (int, float, bool, SymbolicCallArg)):
+                return str(arg)
+            else:
+                return pexpr(V.graph.sizevars.simplify(arg))
+
+        return [wrap_arg(arg) for arg in call_args]
+
+    def generate_example_arg_value(self, arg, arg_type, raw_arg=None):
+        if isinstance(arg_type, torch_dtype):
+            if isinstance(raw_arg, ir.TMADescriptor):
+                # first we generate the underlying buffer
+                buf_name = raw_arg.get_tensor().get_name()
+                buf = self.args_to_buffers[arg]
+            elif self.args_to_buffers.get(arg):
+                buf_name = arg
+                buf = self.args_to_buffers[arg]
+            else:
+                assert raw_arg is not None, (
+                    "V.graph.get_buffer(arg) and raw_arg can't be None at the same time"
+                )
+                buf_name = f"tmp_arg_{self.kernel_autotune_tmp_arg_idx}"
+                buf = raw_arg
+                self.kernel_autotune_tmp_arg_idx += 1
+
+            assert buf is not None, f"Failed to find a buffer for arg {arg}"
+            size = tuple(
+                V.graph.sizevars.atomically_apply_size_hint(
+                    e,
+                    fallback=config.unbacked_symint_fallback,
+                )
+                for e in buf.get_size()
+            )
+            allocation_size = tuple(
+                V.graph.sizevars.atomically_apply_size_hint(
+                    e,
+                    fallback=config.unbacked_symint_fallback,
+                )
+                for e in V.graph.get_allocation_size(buf)
+            )
+            stride = tuple(
+                V.graph.sizevars.atomically_apply_size_hint(
+                    e,
+                    fallback=config.unbacked_symint_fallback,
+                )
+                for e in buf.get_stride()
+            )
+            device = buf.get_device()
+            dtype = buf.get_dtype()
+            offset = V.graph.sizevars.size_hint(
+                buf.get_layout().offset,
+                fallback=config.unbacked_symint_fallback,
+            )
+            value = f"generate_example_value({size}, {stride}, '{device}', {dtype}, {offset}, {allocation_size})"
+            self.kernel_autotune_calls.writeline(f"{buf_name} = {value}")
+
+            if isinstance(raw_arg, ir.TMADescriptor):
+                # generate another line initializing a host-side TMA
+                # descriptor from the underlying buffer created above
+                value = self._generate_tma_descriptor_call(
+                    desc=raw_arg,
+                    apply_size_hints=True,
+                )
+                buf_name = arg
+                self.kernel_autotune_calls.writeline(f"{buf_name} = {value}")
+
+            return buf_name
+        elif issubclass(arg_type, sympy.Basic) or isinstance(arg, SymbolicCallArg):
+            # arg is a symbol or symbolic expression
+            if isinstance(arg, str):
+                if arg in self._meta_vars:
+                    return arg
+                if raw_arg is None:
+                    return "None"
+                arg = raw_arg
+            if isinstance(arg, SymbolicCallArg):
+                arg = arg.inner_expr
+            if arg in V.graph.sizevars.inv_precomputed_replacements:
+                arg = V.graph.sizevars.inv_precomputed_replacements[arg]
+
+            return str(
+                V.graph.sizevars.atomically_apply_size_hint(
+                    arg, fallback=config.unbacked_symint_fallback
+                )
+            )
+
+        elif isinstance(arg, (str, int, float, bool)):
+            return str(arg)
+        elif isinstance(arg, list):
+            return f"[{', '.join(self.generate_example_arg_value(a, type(a)) for a in arg)}]"
+        else:
+            raise NotImplementedError(f"Unsupported type {type(arg)}")
+
+    def _grid_dim_str(self, grid_per_dim):
+        if isinstance(grid_per_dim, list):
+            return (
+                "[" + ", ".join(self._grid_dim_str(item) for item in grid_per_dim) + "]"
+            )
+        else:
+            return pexpr(grid_per_dim)
+
+    def generate_kernel_call(
+        self,
+        kernel_name: str,
+        call_args,
+        *,
+        device=None,
+        triton=True,
+        arg_types=None,
+        raw_keys=None,
+        raw_args=None,
+        triton_meta=None,
+        original_fxnode_name=None,
+    ):
+        """
+        Generates kernel call code.
+
+        triton: Defines whether the backend uses Triton for codegen. Otherwise it uses the CUDA language when gpu=True,
+                and C++ when gpu=False.
+        """
+
+        # Store buffers corresponding to each call arg.
+        # This is used to generate example args for autotuning later on.
+        self.args_to_buffers.update(
+            {
+                arg: V.graph.try_get_buffer(arg)
+                for arg in call_args
+                if isinstance(arg, str)
+            }
+        )
+
+        device = device or V.graph.get_current_device_or_throw()
+        self.writeline(
+            KernelCallLine(
+                self,
+                kernel_name=kernel_name,
+                call_args=call_args,
+                # pyrefly: ignore [bad-argument-type]
+                raw_keys=raw_keys,
+                # pyrefly: ignore [bad-argument-type]
+                raw_args=raw_args,
+                # pyrefly: ignore [bad-argument-type]
+                arg_types=arg_types,
+                triton=triton,
+                # pyrefly: ignore [bad-argument-type]
+                triton_meta=triton_meta,
+                device=device,
+                graph_name=V.graph.name,
+                # pyrefly: ignore [bad-argument-type]
+                original_fxnode_name=original_fxnode_name,
+            )
+        )
+
+    def _generate_kernel_call_helper(
+        self,
+        kernel_name: str,
+        call_args,
+        *,
+        device=None,
+        triton=True,
+        arg_types=None,
+        raw_keys=None,
+        raw_args=None,
+        triton_meta=None,
+        graph_name="",
+        original_fxnode_name=None,
+    ):
+        device = device or V.graph.get_current_device_or_throw()
+        if not triton and device.type != "cuda":
+            if device.type == "cpu":
+                self.writeline(self.wrap_kernel_call(kernel_name, call_args))
+            elif device.type == "mps":
+                # TODO: Fix me, MPS does not expose streams now
+                self.writeline(
+                    self.wrap_kernel_call(f"{kernel_name}.generated_kernel", call_args)
+                )
+            else:
+                raise RuntimeError(f"device {device.type} nyi")
+            return
+
+        call_args_str = self.prepare_triton_kernel_call(call_args)
+        call_args_str = ", ".join(call_args_str)
+        stream_name = PythonWrapperCodegen.write_get_raw_stream(
+            self, device.index, graph_name
+        )
+        if not triton:
+            stream_ptr = f"c_void_p({stream_name})"
+            self.writeline(
+                f"{kernel_name}.{kernel_name}({call_args_str}, {stream_ptr})"
+            )
+            return
+
+        self.write_triton_header_once()
+
+        if (
+            config.triton.autotune_at_compile_time
+            and kernel_name not in self.kernel_autotune_names
+        ):
+            # Create example args for autotune in a separate epilogue
+            assert arg_types is not None and len(call_args) == len(arg_types), (
+                "call_args and arg_types do not match"
+            )
+
+            autotune_args = None
+            if original_fxnode_name and V.graph.autotuning_mapping:
+                autotune_args = V.graph.autotuning_mapping.get(
+                    original_fxnode_name, None
+                )
+
+            def get_autotune_deletion_call() -> str:
+                """After all the autotune kernel calls have been written (i.e.
+                self.kernel_autotune_example_args is complete), returns a deletion call
+                for all autotune example tensors that are unnecessary after kernel_name
+                is called."""
+                tensors_to_delete = [
+                    tensor
+                    for tensor, kn in self.kernel_autotune_example_args.values()
+                    if kn == kernel_name
+                ]
+                if tensors_to_delete:
+                    return f"del {', '.join(tensors_to_delete)}\n"
+                return ""
+
+            def infer_arg_by_inputs(raw_keys, raw_args, idx, reused_args):
+                """We try to infer raw_arg (i.e. raw_args[idx]) from remaining raw_args.
+                This is particularly useful for jagged cases, where the dimension is often
+                being passed in as an input."""
+
+                target_arg = raw_args[idx]
+                if target_arg in reused_args:
+                    return True
+
+                for i, (raw_key, raw_arg) in enumerate(zip(raw_keys, raw_args)):
+                    if i == idx or not isinstance(raw_arg, IRNode):
+                        continue
+
+                    triton_input = ""
+                    if autotune_args and raw_key in autotune_args:
+                        triton_input = self.get_autotuning_input_name(  # type: ignore[attr-defined]
+                            autotune_args[raw_key]
+                        )
+                    if triton_input == "":
+                        continue
+
+                    try:
+                        layout = raw_arg.get_layout()
+                        for dim, s in enumerate(layout.size):
+                            if s == target_arg:
+                                reused_args[target_arg] = f"{triton_input}.shape[{dim}]"
+                                return True
+                    except NotImplementedError:
+                        # If layout for this IRNode is not implemented, we could just skip.
+                        # Only raise for other Error cases.
+                        continue
+                return False
+
+            all_args = []
+            if raw_args is None:
+                # create a dummy raw_args for uniform behavior in the following loop
+                assert raw_keys is None, "keys are not None but args are"
+                raw_keys = [None] * len(call_args)
+                raw_args = [None] * len(call_args)
+            else:
+                assert len(raw_args) == len(call_args), (
+                    "call_args and raw_args do not match"
+                )
+
+            reused_args = {}
+            for i, (arg, arg_type, raw_key, raw_arg) in enumerate(
+                # pyrefly: ignore [no-matching-overload]
+                zip(call_args, arg_types, raw_keys, raw_args)
+            ):
+                key = None
+                if isinstance(arg, str) and "=" in str(arg):
+                    # arg may be passed in a kwarg style, and then we need to extract its value
+                    key, arg = arg.split("=")
+
+                triton_input: Optional[str] = None
+                if autotune_args and raw_key in autotune_args:
+                    triton_input = self.get_autotuning_input_name(  # type: ignore[attr-defined]
+                        autotune_args[raw_key]
+                    )
+
+                if triton_input:
+                    arg_str = triton_input
+                    if not isinstance(arg_type, torch_dtype) and (
+                        issubclass(arg_type, sympy.Basic)
+                        or isinstance(arg, SymbolicCallArg)
+                    ):
+                        reused_args[raw_arg] = arg_str
+                elif raw_key == "" and infer_arg_by_inputs(
+                    raw_keys, raw_args, i, reused_args
+                ):
+                    # Empty raw_key means this is a arg that's not native to the triton kernel,
+                    # and is being added by inductor.
+                    arg_str = reused_args[raw_arg]
+                elif isinstance(arg_type, torch_dtype):
+                    # workspace allocation is already generated by `generate_workspace_allocation()`
+                    # in `TritonKernel.call_kernel()`.
+                    if re.match(r"^(workspace|semaphore)", arg):
+                        arg_str = arg
+                    elif arg not in self.kernel_autotune_example_args:
+                        arg_str = self.generate_example_arg_value(
+                            arg, arg_type, raw_arg
+                        )
+                    else:
+                        arg_str = self.kernel_autotune_example_args[arg][0]
+                    self.kernel_autotune_example_args[arg] = (arg_str, kernel_name)
+                else:
+                    arg_str = self.generate_example_arg_value(arg, arg_type, raw_arg)
+                all_args.append(arg_str if key is None else f"{key}={arg_str}")
+
+            # Make sure kernel launch under a device guard because models don't always run on device 0
+            self.kernel_autotune_calls.writeline(
+                f"with {V.graph.device_ops.device_guard(device.index)}:"
+            )
+            self.kernel_autotune_calls.do_indent()
+            self.kernel_autotune_calls.writeline(
+                f"{kernel_name}.run({', '.join(all_args)}, stream={stream_name})"
+            )
+            self.kernel_autotune_calls.do_unindent()
+
+            self.kernel_autotune_calls.writeline(
+                DelayReplaceLine("", get_autotune_deletion_call, "")
+            )
+            self.kernel_autotune_names.add(kernel_name)
+            if V.graph.cpp_wrapper:
+                # For cpp wrapper, no need to continue codegen for the main body
+                return
+
+        # add debug printer code for triton kernel calls at (jit) inductor level
+        debug_printer_manager = V.graph.wrapper_code.debug_printer
+        debug_printer_manager.set_printer_args(call_args, kernel_name, arg_types, None)
+        with debug_printer_manager:
+            self.writeline(f"{kernel_name}.run({call_args_str}, stream={stream_name})")
+        self.write_triton_header_once()
+
+    def writeline(self, line):
+        self.lines.append(line)
+
+    def writelines(self, lines):
+        for line in lines:
+            self.writeline(line)
+
+    def enter_context(self, ctx):
+        self.lines.append(LineContext(ctx))
+
+    def val_to_arg_str(self, s, type_=None):
+        from torch.utils._triton import has_triton_package
+
+        if has_triton_package():
+            import triton
+
+        if isinstance(s, SymTypes):
+            return pexpr(s.node.expr)
+        elif isinstance(s, sympy.Expr):
+            return pexpr(s)
+        elif isinstance(s, (tuple, list)):
+
+            @dataclasses.dataclass
+            class Shim:
+                ref: Any
+
+                def __repr__(self):
+                    return self.ref
+
+            # Explicitly call the Python version of val_to_arg_str
+            return repr(
+                type(s)(Shim(PythonWrapperCodegen.val_to_arg_str(self, a)) for a in s)
+            )
+        elif isinstance(s, torch._ops.OpOverload):
+            return _get_qualified_name(s)
+        elif isinstance(s, (ir.Buffer, ir.MutableBox, ReinterpretView)):
+            return s.codegen_reference()
+        elif has_triton_package() and isinstance(s, triton.language.dtype):  # type: ignore[possibly-undefined]
+            return repr(s)
+        elif isinstance(s, ir.GeneratorState):
+            return s.codegen_reference()
+        elif is_opaque_value_type(type(s)):
+            opaque_type = type(s)
+            V.graph.opaque_value_type_classes[opaque_type.__name__] = opaque_type
+            return repr(s)
+        else:
+            return repr(s)
+
+    # The following methods are for memory management
+    def make_buffer_allocation(self, buffer: BufferLike):
+        device = buffer.get_device()
+        dtype = buffer.get_dtype()
+        shape = tuple(buffer.get_size())
+        allocation_shape = tuple(V.graph.get_allocation_size(buffer))
+        stride = tuple(buffer.get_stride())
+        is_pinned = buffer.get_is_pinned()
+        return self.make_allocation(
+            buffer.get_name(), device, dtype, shape, stride, allocation_shape, is_pinned
+        )
+
+    @cache_on_self
+    def write_memory_track_allocation_once(self):
+        import_str = """
+            from torch._inductor.runtime.debug_utils import check_memory_step, track_tensor
+            """
+        if not V.graph.cpp_wrapper:
+            self.imports.splice(import_str, strip=True)
+
+    def make_allocation(
+        self, name, device, dtype, shape, stride, allocation_shape=None, is_pinned=False
+    ):
+        if allocation_shape is None:
+            allocation_shape = shape
+
+        codegen_shape_tuple = self.codegen_python_shape_tuple(shape)
+        codegen_allocation_shape_tuple = self.codegen_python_shape_tuple(
+            allocation_shape
+        )
+        codegen_stride_tuple = self.codegen_python_shape_tuple(stride)
+        if torch._inductor.config.test_configs.track_memory_lifecycle:
+            out = (
+                f"{name} = tracked_empty_strided("
+                f"{codegen_allocation_shape_tuple}, "
+                f"{codegen_stride_tuple}, "
+                f"dtype={dtype}, "
+                f"device='{device.type}', "
+                f"name='{name}')"
+            )
+        elif device.type == "cpu" and is_pinned:
+            out = (
+                f"{name} = empty_strided_cpu_pinned("
+                f"{codegen_allocation_shape_tuple}, "
+                f"{codegen_stride_tuple}, "
+                f"{dtype})"
+            )
+        elif device.type in ("cpu", "cuda", "xpu", "mtia"):
+            # optimized path for faster allocations, saving ~2us versus the stuff below
+            out = (
+                f"{name} = empty_strided_{device.type}("
+                f"{codegen_allocation_shape_tuple}, "
+                f"{codegen_stride_tuple}, "
+                f"{dtype})"
+            )
+        # all other devices:
+        else:
+            out = (
+                f"{name} = empty_strided("
+                f"{codegen_allocation_shape_tuple}, "
+                f"{codegen_stride_tuple}, "
+                f"device='{device.type}', dtype={dtype})"
+            )
+        if codegen_shape_tuple != codegen_allocation_shape_tuple:
+            # need an extra as_strided call
+            out = out + f".as_strided({codegen_shape_tuple}, {codegen_stride_tuple})"
+        return out
+
+    def make_comment(self, line):
+        self.writeline(CommentLine(line))
+
+    def make_tensor_alias(self, new_name, old_name, comment=""):
+        return f"{self.declare}{new_name} = {old_name}{self.ending}  {self.comment} {comment}"
+
+    def make_buffer_free(self, buffer: Union[BufferLike, ir.TorchBindObject]):
+        return f"del {buffer.get_name()}"
+
+    def make_free_by_names(self, names_to_del: list[str]):
+        return f"del {', '.join(name for name in names_to_del)}"
+
+    def codegen_exact_buffer_reuse(self, old_name: str, new_name: str, del_line: str):
+        return f"{self.declare_maybe_reference}{new_name} = {old_name}{del_line}{self.ending}  {self.comment} reuse"
+
+    def write_provenance_debug_handle(
+        self,
+        kernel_name,
+        debug_handle: Optional[int] = None,
+    ):
+        if debug_handle is not None:
+            self.writeline(
+                f"{self.comment} [Provenance debug handles] {kernel_name}:{debug_handle}"
+            )
+
+    def make_buffer_reuse(self, old: BufferLike, new: BufferLike, delete_old: bool):
+        assert old.get_dtype() == new.get_dtype()
+        old_name = old.get_name()
+        new_name = new.get_name()
+        del_line = ";"
+        if old_name not in V.graph.get_output_names() and delete_old:
+            del_line = f"; {self.make_buffer_free(old)}"
+
+        if old.get_size() == new.get_size() and old.get_stride() == new.get_stride():
+            return self.codegen_exact_buffer_reuse(old_name, new_name, del_line)
+
+        reinterpret_view = self.codegen_reinterpret_view(
+            old, new.get_size(), new.get_stride(), 0, self.wrapper_call.writeline
+        )
+        return f"{self.declare}{new_name} = {reinterpret_view}{del_line}  {self.comment} reuse"
+
+    def codegen_deferred_allocation(self, name: str, view: ir.ReinterpretView) -> None:
+        self.writeline(
+            DeferredLine(
+                name,
+                f"{self.declare}{name} = {view.codegen_reference()}{self.ending}  {self.comment} alias",
+            )
+        )
+
+    def codegen_allocation(self, buffer: ir.Buffer):
+        name = buffer.get_name()
+
+        if (
+            name in V.graph.removed_buffers
+            or name in self.allocated
+            or isinstance(buffer, (ir.DonatedBuffer, ir.SubgraphBuffer, ir.InputBuffer))
+        ):
+            return
+        self.allocated.add(name)
+        if (
+            isinstance(
+                buffer.get_defining_op(),
+                (ir.ExternKernelAlloc, ir.MultiOutput),
+            )
+            and not buffer.should_allocate()
+        ):
+            return
+
+        layout = buffer.get_output_spec()
+        if isinstance(layout, ir.MutationLayoutSHOULDREMOVE):
+            return
+        if isinstance(layout, ir.NoneLayout):
+            return
+        if isinstance(layout, ir.NonOwningLayout):
+            assert isinstance(layout.view, ir.ReinterpretView), (
+                f"unexpected {type(layout.view)}: {layout.view}"
+            )
+            box = layout.view.data
+            assert isinstance(box, ir.StorageBox), type(box)
+            input_buffer = box.data
+            assert isinstance(input_buffer, (ir.Buffer, ir.ReinterpretView)), type(
+                input_buffer
+            )
+            if isinstance(input_buffer, ir.ReinterpretView):
+
+                def unwrap_views(target) -> ir.Buffer:
+                    if isinstance(target, ir.BaseView):
+                        return unwrap_views(target.unwrap_view())
+                    if isinstance(target, ir.MutableBox):
+                        return unwrap_views(target.data)
+                    assert isinstance(target, ir.Buffer), type(target)
+                    return target
+
+                input_buffer = unwrap_views(input_buffer)
+            self.codegen_allocation(input_buffer)
+            self.writeline(ReinterpretLine(self, input_buffer, buffer, layout))
+            return
+
+        if isinstance(layout, ir.CommBufferLayout):
+            self.writeline(CommBufferAllocateLine(self, buffer))
+            return
+
+        self.writeline(AllocateLine(self, buffer))
+
+    def codegen_free(self, buffer):
+        name = buffer.get_name()
+
+        # can be freed but not reused
+        if isinstance(buffer, (ir.InputBuffer, ir.TorchBindObject)):
+            self.writeline(FreeLine(self, buffer))
+            return
+
+        if isinstance(buffer.get_output_spec(), ir.CommBufferLayout):
+            # Comm buffers are not eligible for in-place reuse. Their reuse is
+            # achieved exclusively via buffer planning.
+            self.writeline(CommBufferFreeLine(self, buffer))
+            return
+
+        if not self.can_reuse(buffer):
+            return
+        self.freed.add(name)
+
+        self.writeline(FreeIfNotReusedLine(self, buffer))
+
+    def can_reuse(self, input_buffer, output_buffer=None):
+        name = input_buffer.get_name()
+        return not (
+            name in V.graph.removed_buffers
+            or (
+                name in V.graph.graph_inputs
+                and not isinstance(
+                    V.graph.graph_inputs_original[name], ir.DonatedBuffer
+                )
+            )
+            or name in V.graph.constants
+            or name in V.graph.torchbind_constants
+            or name in V.graph.never_reuse_buffers
+            or name in self.freed
+        )
+
+    def did_reuse(self, buffer, reused_buffer):
+        # Check whether a given buffer was reused by a possible reuser in the wrapper codegen
+        # Can be consulted from inside ir codegen, e.g. to determine whether a copy is needed
+        return (
+            buffer.get_name() in self.reuses
+            and self.reuses[buffer.get_name()] == reused_buffer.get_name()
+        )
+
+    def codegen_inplace_reuse(self, input_buffer: ir.Buffer, output_buffer: ir.Buffer):
+        assert can_match_buffer_size(input_buffer, output_buffer)
+        self.codegen_allocation(input_buffer)
+        self.freed.add(input_buffer.get_name())
+        self.allocated.add(output_buffer.get_name())
+        self.reuses[output_buffer.get_name()] = input_buffer.get_name()
+        self.writeline(ReuseLine(self, input_buffer, output_buffer))
+
+    def codegen_unbacked_symbol_decl(self, symbol):
+        name = str(symbol)
+        if name in self.unbacked_symbol_decls:
+            return name
+        else:
+            # When in CppWrapperCpu, we should only generate the declaration once
+            self.unbacked_symbol_decls.add(name)
+            return self.declare + name
+
+    def codegen_unbacked_symbol_defs_for_outputs(
+        self,
+        output_name: str,
+        outputs: Any,
+        unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]],
+    ) -> None:
+        unbacked_bindings = resolve_unbacked_bindings(
+            V.graph.sizevars.shape_env, unbacked_bindings
+        )
+        self.writeline(
+            UnbackedSymbolDefsLine(self, output_name, outputs, unbacked_bindings)
+        )
+
+    def _codegen_unbacked_symbol_defs_for_outputs(
+        self,
+        output_name: str,
+        outputs: Any,
+        unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]],
+    ) -> None:
+        if not unbacked_bindings:
+            return
+
+        # This code is designed to generate code expressions from symbolic paths (keypaths)
+        # associated with certain symbols (unbacked bindings). These keypaths describe how
+        # to access the unbacked symbol in a structured way.
+        # For example, we might want to generate "u0 = outs[0].stride(1)"", where s = u0, and the keypath
+        # describes the structure of "outs[0].stride(1)", like [SequenceKey(0), CallMethodKey("stride"), SequenceKey[1]].
+        for s, keypath in unbacked_bindings.items():
+            # `go` recursively constructs a code expression by processing each element of
+            # the keypath and construct the expression incrementally.
+            # For example, given output name outs and keypath [SequenceKey(0), CallMethodKey("stride", 1)],
+            # it generates "outs[0]" based on SequenceKey(0), then recursively go("outs[0]", [CallMethodKey("stride"), ...])
+            def go(expr: str, keypath: pytree.KeyPath):
+                if keypath == ():
+                    return expr
+
+                if (
+                    len(keypath) >= 2
+                    and isinstance(keypath[0], CallMethodKey)
+                    and isinstance(keypath[1], pytree.SequenceKey)
+                ):
+                    return go(
+                        f"{expr}.{keypath[0].name}({keypath[1].idx})", keypath[2:]
+                    )
+                elif isinstance(keypath[0], CallMethodKey):
+                    return go(f"{expr}.{keypath[0].name}()", keypath[1:])
+                elif isinstance(keypath[0], pytree.SequenceKey):
+                    return (
+                        go(f"std::get<{keypath[0].idx}>({expr})", keypath[1:])
+                        if V.graph.cpp_wrapper
+                        else go(f"{expr}[{keypath[0].idx}]", keypath[1:])
+                    )
+                elif isinstance(keypath[0], DivideByKey):
+                    # TODO: need to assert divisibility
+                    # TODO: this is invalid C++ codegen
+                    return go(f"{expr}.__floordiv__({keypath[0].divisor})", keypath[1:])
+                else:
+                    raise AssertionError(f"unrecognized keypath {keypath}")
+
+            # `go_outer` manages the top-level logic for generating the final expression.
+            # It handles special cases for C++ code generation and adjusts
+            # the keypath based on the context (e.g., single vs. multiple outputs).
+            def go_outer():  # type: ignore[no-untyped-def]
+                if V.graph.cpp_wrapper:
+                    # Special handling for the top level buffer access,
+                    # because self.get_name() is actually never bound; the
+                    # individual output arguments are bound by
+                    # generate_c_shim_fallback_kernel
+                    if len(outputs) == 1:
+                        out = outputs[0]
+                        # When fallback kernel returns a list consisting of a single tensor,
+                        # the output is represented as a MultiOutput with non empty indices.
+                        # In this case, we strip the first key path away.
+                        return go(
+                            outputs[0].get_name(),
+                            keypath[1:]
+                            if isinstance(out, ir.MultiOutput) and len(out.indices) != 0
+                            else keypath,
+                        )
+                    else:
+                        assert isinstance(keypath[0], pytree.SequenceKey)
+                        return go(outputs[keypath[0].idx].get_name(), keypath[1:])
+                else:
+                    return go(output_name, keypath)
+
+            self.writeline(
+                f"{self.codegen_unbacked_symbol_decl(s)} = {go_outer()}{self.ending}"
+            )
+
+    def codegen_subgraph_by_inlining(self, subgraph, outer_inputs, outer_outputs):
+        # TODO (desertfire) - This function is the old way of supporting
+        # subgraph codegen by inlining subgraphs in the output code. For python
+        # wrapper, we have moved to lifting subgraphs as functions, supported by
+        # `codegen_subgraph` function.
+        #
+        # However this does not work with cpp wrapper. With cpp wrapper, we make
+        # two passes and the kernels are shared from the first pass to the next.
+        # Therefore, both the Python and CppWrapper need to share the some
+        # codegen infra. For now, CppWrapperCpu has not been updated to lift the
+        # subgraph as functions. Therefore for cpp_wrapper first pass with
+        # PythonWrapper, we still fallback to the old way of inlining subgraphs
+        # in the output code. Once we update CppWrapperCpu, we can remove this
+        # function.
+        def _codegen_subgraph_prefix():
+            assert len(subgraph.graph.graph_inputs) == len(outer_inputs)
+            for inner_input, outer_input in zip(
+                subgraph.graph.graph_inputs, outer_inputs
+            ):
+                self.writeline(
+                    f"{self.declare}{inner_input} = {outer_input}{self.ending}"
+                )
+
+        def _codegen_subgraph_suffix():
+            assert len(subgraph.graph.graph_outputs) == len(outer_outputs)
+            for inner_output, outer_output in zip(
+                subgraph.graph.graph_outputs, outer_outputs
+            ):
+                self.writeline(
+                    f"{outer_output} = {inner_output.codegen_reference()}{self.ending}"
+                )
+
+        try:
+            self.push_codegened_graph(subgraph.graph)
+            self.writeline(f"{self.comment} subgraph: {subgraph.name}")
+            _codegen_subgraph_prefix()
+            parent_graph = V.graph
+            with V.set_graph_handler(subgraph.graph):
+                subgraph.graph.codegen_subgraph(
+                    parent_graph=parent_graph,
+                )
+            _codegen_subgraph_suffix()
+        finally:
+            self.pop_codegened_graph()
+
+    def codegen_partition_call(
+        self,
+        partition_id: int,
+        partition_signatures: ir.GraphPartitionSignature,
+    ):
+        """Generate code to call a graph partition"""
+        input_deallocation = partition_signatures.input_deallocation
+        output_nodes = partition_signatures.output_nodes
+
+        input_names = list(input_deallocation.keys()) + [
+            symbol_input.name for symbol_input in partition_signatures.symbol_inputs
+        ]
+
+        inputs = ", ".join(input_names) + ("," if len(input_names) == 1 else "")
+
+        output_names = [node.get_name() for node in output_nodes]
+        outputs = ", ".join(output_names) + ("," if len(output_nodes) == 1 else "")
+
+        # Create a list of inputs for the subgraph call
+        self.writeline(f"partition{partition_id}_args = [{inputs}]")
+
+        names_to_del = [
+            name for name, deallocate in input_deallocation.items() if deallocate
+        ]
+        if names_to_del:
+            self.writeline(f"del {', '.join(names_to_del)}")
+
+        # Call the subgraph launcher function
+        self.writeline(
+            f"({outputs}) = self.partitions[{partition_id}](partition{partition_id}_args)"
+        )
+        self.writeline(f"del partition{partition_id}_args")
+
+    def set_all_partition_names(self, num_partitions: int):
+        self.all_partition_names = [f"partition_{idx}" for idx in range(num_partitions)]
+
+    def codegen_subgraph_call_with_flattened_outputs(
+        self, subgraph, outer_inputs, outer_flattened_outputs
+    ):
+        # Get the input and output names of the subgraph
+        outer_output_names = ", ".join(outer_flattened_outputs) + (
+            "," if len(outer_flattened_outputs) == 1 else ""
+        )
+        outer_input_names = ", ".join(outer_inputs) + (
+            "," if len(outer_inputs) == 1 else ""
+        )
+
+        self.writeline(f"{subgraph.graph.name}_args = [{outer_input_names}]")
+
+        # Call the subgraph launcher function
+        self.writeline(
+            f"({outer_output_names}) = {subgraph.graph.name}({subgraph.graph.name}_args)"
+        )
+
+    def codegen_subgraph_call(self, subgraph, outer_inputs, outer_buffer_name):
+        # Get the input and output names of the subgraph
+        outer_input_names = ", ".join(outer_inputs) + (
+            "," if len(outer_inputs) == 1 else ""
+        )
+
+        self.writeline(f"{subgraph.graph.name}_args = [{outer_input_names}]")
+
+        # Since the buffers are already put into the args list, we can free the
+        # buffers here.
+        V.graph.scheduler.free_buffers()
+
+        # Call the subgraph launcher function
+        self.writeline(
+            f"{outer_buffer_name} = {subgraph.graph.name}({subgraph.graph.name}_args)"
+        )
+
+    def codegen_subgraph_common(self, subgraph):
+        self.push_codegened_graph(subgraph.graph)
+        self.make_comment("")
+        self.make_comment(f"{self.comment} subgraph: {subgraph.name}")
+
+        parent_graph = V.graph
+        subgraph.graph.cpp_wrapper = parent_graph.cpp_wrapper
+        subgraph.graph.fx_wrapper = parent_graph.fx_wrapper
+
+        if subgraph.graph.name not in self.already_codegened_subgraphs:
+            # If it is already codegened, the parent wrapper already has
+            # subgraph fn by name subgraph.graph.name
+            with V.set_graph_handler(subgraph.graph):
+                # do not graph partition for subgraph
+                with config.patch("graph_partition", False):
+                    # Call the codegen of subgraph recursively
+                    subgraph_code, _ = subgraph.graph.codegen()
+            subgraph_name = subgraph.graph.name
+            self.already_codegened_subgraphs.add(subgraph_name)
+            self.define_subgraph_launcher_fn(subgraph_name, subgraph_code)
+
+    def codegen_subgraph_with_flattened_outputs(
+        self, subgraph, outer_inputs, outer_flattened_outputs
+    ):
+        self.codegen_subgraph_common(subgraph)
+        self.codegen_subgraph_call_with_flattened_outputs(
+            subgraph, outer_inputs, outer_flattened_outputs
+        )
+
+    def codegen_subgraph(self, subgraph, outer_inputs, outer_buffer_name):
+        # Codegen subgraph by recursively calling the codegen for the subgraph.
+        # This lifts the subgraph as a function in the output code.
+        self.codegen_subgraph_common(subgraph)
+        self.codegen_subgraph_call(subgraph, outer_inputs, outer_buffer_name)
+
+    def codegen_invoke_subgraph(self, invoke_subgraph):
+        name = invoke_subgraph.get_name()
+
+        self.writeline(f"{name} = [None] * {len(invoke_subgraph.outputs)}")
+        outer_inputs = [buf.codegen_reference() for buf in invoke_subgraph.inputs]
+
+        if V.graph.aot_mode:
+            outer_outputs = [
+                f"{name}[{i}]" for i in range(len(invoke_subgraph.outputs))
+            ]
+            self.codegen_subgraph_by_inlining(
+                invoke_subgraph.subgraph, outer_inputs, outer_outputs
+            )
+        else:
+            self.codegen_subgraph(invoke_subgraph.subgraph, outer_inputs, name)
+
+    def codegen_conditional(self, conditional) -> None:
+        name = conditional.get_name()
+
+        outer_inputs = [buf.codegen_reference() for buf in conditional.operands]
+
+        predicate = conditional.predicate.codegen_reference()
+        if not isinstance(conditional.predicate, ir.ShapeAsConstantBuffer):
+            # move the Tensor predicate to host
+            predicate = f"{predicate}.item()"
+
+        self.writeline(f"{name} = [None] * {len(conditional.outputs)}")
+        self.writeline(f"if {predicate}:")
+        self.writeline(EnterSubgraphLine(self, conditional.true_subgraph.graph))
+        if V.graph.aot_mode:
+            outer_outputs = [f"{name}[{i}]" for i in range(len(conditional.outputs))]
+            self.codegen_subgraph_by_inlining(
+                conditional.true_subgraph, outer_inputs, outer_outputs
+            )
+        else:
+            self.codegen_subgraph(conditional.true_subgraph, outer_inputs, name)
+
+        self.writeline(ExitSubgraphLine(self))
+        self.writeline("else:")
+        self.writeline(EnterSubgraphLine(self, conditional.false_subgraph.graph))
+        if V.graph.aot_mode:
+            outer_outputs = [f"{name}[{i}]" for i in range(len(conditional.outputs))]
+            self.codegen_subgraph_by_inlining(
+                conditional.false_subgraph, outer_inputs, outer_outputs
+            )
+        else:
+            self.codegen_subgraph(conditional.false_subgraph, outer_inputs, name)
+        self.writeline(ExitSubgraphLine(self))
+
+    def codegen_while_loop(self, while_loop, stack_output):
+        """while_loop is codegened as a host side while_loop"""
+
+        def codegen_subgraph(subgraph, outer_inputs, outer_outputs):
+            """Helper method to deduplicate subgraph codegen logic"""
+            if V.graph.aot_mode:
+                self.codegen_subgraph_by_inlining(subgraph, outer_inputs, outer_outputs)
+            else:
+                self.codegen_subgraph_with_flattened_outputs(
+                    subgraph, outer_inputs, outer_outputs
+                )
+
+        name = while_loop.get_name()
+        outer_carried_inputs = [
+            buf.codegen_reference() for buf in while_loop.carried_inputs
+        ]
+        outer_additional_inputs = [
+            buf.codegen_reference() for buf in while_loop.additional_inputs
+        ]
+
+        ckp_offset = len(outer_carried_inputs)
+        self.writeline(f"{name} = [None] * {len(outer_carried_inputs)}")
+        if stack_output:
+            self.writeline(
+                f"{name}.extend([[] for _ in range({len(outer_carried_inputs)})])"
+            )
+
+        for i, inp in enumerate(outer_carried_inputs):
+            # set the initial state before the loop
+            self.writeline(f"{name}[{i}] = {inp}")
+
+        cond_outer_inputs = [
+            *[f"{name}[{i}]" for i in range(len(outer_carried_inputs))],
+            *outer_additional_inputs,
+        ]
+        cond_outer_outputs = [f"{name}_cond_result"]
+        body_outer_inputs = list(
+            cond_outer_inputs
+        )  # same inputs for cond_fn and body_fn
+        # Carry over the state from body_fn. Note: We only carry over
+        # the carried_inputs part of the inputs, the additional ones
+        # are passed in as they're before.
+        body_outer_outputs = body_outer_inputs[: len(outer_carried_inputs)]
+        # Check condition at the beginning and set up flag
+        codegen_subgraph(
+            while_loop.cond_subgraph, cond_outer_inputs, cond_outer_outputs
+        )
+        self.writeline(f"should_loop = {cond_outer_outputs[0]}")
+        self.writeline("if not should_loop:")
+        if stack_output:
+            # Handle the case when loop never executes
+            for i, carried_input in enumerate(outer_carried_inputs):
+                self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph))
+                self.writeline(f"{name}[{i}] = {carried_input}.unsqueeze(0).clone()")
+                self.writeline(ExitSubgraphLine(self))
+        else:
+            for i, carried_input in enumerate(outer_carried_inputs):
+                self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph))
+                self.writeline(f"{name}[{i}] = {carried_input}.clone()")
+                self.writeline(ExitSubgraphLine(self))
+
+        self.writeline("while should_loop:")
+        # Body execution
+        self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph))
+        codegen_subgraph(
+            while_loop.body_subgraph, body_outer_inputs, body_outer_outputs
+        )
+        self.writeline(ExitSubgraphLine(self))
+
+        # Collect outputs if enabled
+        if stack_output:
+            self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph))
+            for i in range(len(outer_carried_inputs)):
+                self.writeline(f"{name}[{i + ckp_offset}].append({name}[{i}])")
+            self.writeline(ExitSubgraphLine(self))
+
+        # Condition check at end of loop
+        self.writeline(EnterSubgraphLine(self, while_loop.cond_subgraph.graph))
+        codegen_subgraph(
+            while_loop.cond_subgraph, cond_outer_inputs, cond_outer_outputs
+        )
+        self.writeline(ExitSubgraphLine(self))
+        self.writeline(f"    should_loop = {cond_outer_outputs[0]}")
+
+        # Stack outputs after loop completion
+        if stack_output:
+            self.writeline("# Stack outputs after loop completion")
+            for i in range(len(outer_carried_inputs)):
+                self.writeline(f"if len({name}[{i + ckp_offset}]) > 0:")
+                self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph))
+                self.writeline(
+                    f"{name}[{i}] = torch.stack({name}[{i + ckp_offset}], dim=0)"
+                )
+                self.writeline(ExitSubgraphLine(self))
+
+    @staticmethod
+    def statically_known_int_or_none(x):
+        try:
+            if getattr(x, "free_symbols", None):
+                # _maybe_evaluate_static will return (s0 // (2 // s0)) as 2, but
+                # the actual codegen will still generate the full expression here.
+                return None
+            if isinstance(x, int):
+                return x
+            val = V.graph._shape_env._maybe_evaluate_static(x)
+            if val is None:
+                return val
+            return int(val)  # type: ignore[call-overload]
+        except Exception:
+            return None
+
+    @staticmethod
+    def statically_known_list_of_ints_or_none(lst):
+        result = []
+        for x in lst:
+            num = PythonWrapperCodegen.statically_known_int_or_none(x)
+            if num is None:
+                return None
+            result.append(num)
+        return result
+
+    @staticmethod
+    def is_statically_known_list_of_ints(lst):
+        return (
+            PythonWrapperCodegen.statically_known_list_of_ints_or_none(lst) is not None
+        )
+
+    @staticmethod
+    def static_shape_for_buffer_or_none(buffer):
+        return PythonWrapperCodegen.statically_known_list_of_ints_or_none(
+            buffer.get_size()
+        )
+
+    @staticmethod
+    def can_prove_buffer_has_static_shape(buffer):
+        return PythonWrapperCodegen.static_shape_for_buffer_or_none(buffer) is not None
+
+    def write_kernel_context_guard(
+        self,
+        kernel_name: str,
+        node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel],
+    ):
+        return
+
+    def write_kernel_context_guard_begin(
+        self,
+    ):
+        """
+        Mark the beginning of kernel context guard
+        """
+        return
+
+    def write_kernel_context_guard_end(
+        self,
+    ):
+        """
+        Mark the end of kernel context guard
+        """
+        return
+
+
+class SubgraphPythonWrapperCodegen(PythonWrapperCodegen):
+    """
+    A wrapper codegen that generates code for a subgraph. For most of the
+    methods, we rely on the implementation in the PythonWrapperCodegen. But we
+    override a few functions to produce cleaner code (like avoiding writing
+    imports twice in the output code)
+    """
+
+    def __init__(
+        self,
+        subgraph_name: str,
+        parent_wrapper: PythonWrapperCodegen,
+        partition_signatures: Optional[ir.GraphPartitionSignature] = None,
+    ):
+        # It is necessary to set the subgraph_name before calling super __init__
+        # because __init__ calls set_launcher_fn_name
+        self.subgraph_name = subgraph_name
+        self.parent_wrapper = parent_wrapper
+        self.partition_signatures = partition_signatures
+
+        super().__init__()
+
+        root = self.get_root_graph()
+        # Only generate auto-tuning block in the main graph
+        self.kernel_autotune_defs = root.kernel_autotune_defs
+        self.kernel_autotune_calls = root.kernel_autotune_calls
+        # Only store kernel src to name mapping in the main graph
+        self.src_to_kernel = root.src_to_kernel
+        # Same here, only define user-defined Triton kernels in the main graph
+        self.user_defined_kernel_cache = root.user_defined_kernel_cache
+
+    def set_launcher_fn_name(self) -> None:
+        # This sets up the name of the function containing the launcher code of
+        # the subgraph.
+        # pyrefly: ignore [bad-assignment]
+        self.launcher_fn_name = self.subgraph_name
+
+    def write_header(self) -> None:
+        pass
+
+    def add_benchmark_harness(self, output):
+        pass
+
+    def benchmark_compiled_module(self, output):
+        pass
+
+    def write_async_compile_wait(self):
+        pass
+
+    def next_kernel_suffix(self) -> str:
+        # Ensures that subgraphs kernels do not clash with each other
+        return self.parent_wrapper.next_kernel_suffix()
+
+    def generate_after_suffix(self, result: IndentedBuffer) -> None:
+        return
+
+    def write_launcher_fn_call_get_indent(self) -> int:
+        self.prefix.splice(
+            f"""
+            def {self.launcher_fn_name}(args):
+            """
+        )
+        prefix_indent = 1
+        return prefix_indent
+
+    def get_wrapper_call_indent(self) -> int:
+        return 1
+
+    def get_graph_inputs(
+        self,
+    ) -> dict[str, Union[ir.TensorBox, ir.TorchBindObject, sympy.Expr, None]]:
+        if signature := self.partition_signatures:
+            inputs = signature.input_nodes | {
+                str(s): s for s in signature.symbol_inputs
+            }
+        else:
+            inputs = V.graph.graph_inputs
+        return inputs
+
+    def get_graph_input_names(self) -> list[str]:
+        if signature := self.partition_signatures:
+            names = list(signature.input_nodes.keys()) + [
+                symbol_input.name for symbol_input in signature.symbol_inputs
+            ]
+        else:
+            names = V.graph.graph_input_names
+        return names
+
+    def get_graph_outputs(self) -> list[IRNode]:
+        if signature := self.partition_signatures:
+            outputs = signature.output_nodes
+        else:
+            outputs = V.graph.graph_outputs
+        return outputs
+
+    def codegen_allocation(self, buffer: ir.Buffer):
+        name = buffer.get_name()
+        if (signature := self.partition_signatures) and name in signature.input_nodes:
+            # skip allocation if buffer is a subgraph input.
+            # This allows reusing an input buffer in graph partition,
+            # although this is not allowed in general.
+            return
+
+        super().codegen_allocation(buffer)
+
+    @cache_on_self
+    def write_triton_header_once(self) -> None:
+        # TODO: Uncomment in future. This will be needed to support subgraph
+        # codegen for cpp wrapper.
+        # if config.triton.autotune_at_compile_time:
+        #     import_str = self.triton_header_str()
+        #     self.kernel_autotune_calls.splice(import_str)
+        self.parent_wrapper.write_triton_header_once()
+
+    @cache_on_self
+    def write_get_raw_stream_header_once(self) -> None:
+        # TODO: Uncomment in future. This will be needed to support subgraph
+        # codegen for cpp wrapper.
+        # if config.triton.autotune_at_compile_time:
+        #     self.kernel_autotune_calls.writeline(
+        #         V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
+        #     )
+        self.parent_wrapper.write_get_raw_stream_header_once()
+
+    @cache_on_self
+    def get_root_graph(self) -> PythonWrapperCodegen:
+        root: PythonWrapperCodegen | SubgraphPythonWrapperCodegen = self
+        while isinstance(root, SubgraphPythonWrapperCodegen):
+            root = root.parent_wrapper
+
+        assert isinstance(root, PythonWrapperCodegen)
+        return root
+
+    def generate_and_run_autotune_block(self):
+        # Only execute auto-tuning block in the main graph
+        pass
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/wrapper_fxir.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/wrapper_fxir.py
new file mode 100644
index 0000000000000000000000000000000000000000..02c498d6debce64609751edae5c4e9287797fa6a
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/wrapper_fxir.py
@@ -0,0 +1,1213 @@
+import dataclasses
+import functools
+import logging
+import operator
+import textwrap
+from collections import Counter
+from collections.abc import Callable, Sequence
+from typing import Any, Optional, Union
+
+import sympy
+
+import torch
+from torch._export.passes._node_metadata_hook import (
+    _node_metadata_hook,
+    _set_node_metadata_hook,
+)
+from torch._higher_order_ops.triton_kernel_wrap import (
+    TraceableTritonKernelWrapper,
+    tracing_triton_hopifier_singleton,
+    triton_kernel_wrapper_mutation,
+)
+from torch._inductor.codecache import LambdaFuture, PyCodeCache
+from torch._inductor.runtime.triton_heuristics import CachingAutotuner
+from torch._inductor.select_algorithm import extern_kernels  # noqa: F401
+from torch._inductor.utils import convert_to_symint
+from torch._inductor.virtualized import V
+from torch._library.triton import wrap_triton
+from torch.fx import GraphModule
+from torch.fx.experimental.symbolic_shapes import (
+    CallMethodKey,
+    ConvertIntKey,
+    DivideByKey,
+    free_unbacked_symbols,
+)
+from torch.utils import _pytree as pytree
+from torch.utils._sympy.functions import FloorDiv
+from torch.utils._sympy.interp import _run_sympy_handler, sympy_interp
+from torch.utils._sympy.reference import OptimizedPythonReferenceAnalysis
+from torch.utils._sympy.solve import try_solve
+
+from .. import config, ir
+from ..runtime.triton_compat import Config
+from ..utils import cache_property_on_self, LineContext, ValueWithLineMap
+from .common import (
+    CodegenSymbol,
+    FileBackedGraphModule,
+    WorkspaceArg,
+    WorkspaceZeroMode,
+)
+from .wrapper import (
+    AllocateLine,
+    BufferLike,
+    CommBufferAllocateLine,
+    CommBufferFreeLine,
+    CommentLine,
+    ConditionalLine,
+    DynamicScalarLine,
+    EnterDeviceContextManagerLine,
+    EnterSubgraphLine,
+    ExitDeviceContextManagerLine,
+    ExitSubgraphLine,
+    ExternKernelAllocLine,
+    ExternKernelOutLine,
+    FreeIfNotReusedLine,
+    FreeLine,
+    IndexPutFallbackLine,
+    KernelCallLine,
+    KernelDefinitionLine,
+    Line,
+    MultiOutputLine,
+    NullLine,
+    PythonWrapperCodegen,
+    ReinterpretLine,
+    ReuseLine,
+    ScatterFallbackLine,
+    SubgraphPythonWrapperCodegen,
+    SymbolicCallArg,
+    SymbolicCallArgLine,
+    UnbackedSymbolDefsLine,
+    WrapperLine,
+)
+
+
+aten = torch.ops.aten
+log = logging.getLogger(__name__)
+
+
+@dataclasses.dataclass
+class SymbolBuffer(CodegenSymbol):
+    """
+    Represents a sympy.Symbol graph input.
+    """
+
+    symbol: sympy.Symbol
+
+    def get_name(self) -> str:
+        return str(self.symbol)
+
+    def get_example(self) -> Union[torch.Tensor, torch.SymInt]:
+        sym_int = convert_to_symint(self.symbol)
+        assert isinstance(sym_int, torch.SymInt)
+        return sym_int
+
+
+CodegenBuffer = Union[BufferLike, SymbolBuffer]
+
+
+@dataclasses.dataclass
+class TritonKernel:
+    """
+    Stores metadata about Triton kernels for use in FX.
+    """
+
+    tuner: CachingAutotuner
+    wrapped: TraceableTritonKernelWrapper
+
+
+def replace_floor_div(expr: sympy.Expr) -> sympy.Expr:
+    """
+    Replace sympy.floor with FloorDiv.
+    """
+
+    def replace(expr: sympy.Expr) -> sympy.Expr:
+        expr = sympy.together(expr)
+
+        # Division is represented as a Mul with a Rational factor or a Pow with negative
+        # exponent. We convert floor(Mul(...)) to FloorDiv(numerator, denominator) by
+        # partitioning factors into the numerator and denominator.
+        (numerator, denominator) = (sympy.S.One,) * 2
+        for arg in sympy.Mul.make_args(expr):
+            if isinstance(arg, sympy.Rational):
+                numerator *= arg.numerator
+                denominator *= arg.denominator
+            elif isinstance(arg, sympy.Pow) and arg.exp.is_negative:
+                denominator *= arg.base**-arg.exp
+            else:
+                numerator *= arg
+
+        return FloorDiv(numerator, denominator)
+
+    return expr.replace(sympy.floor, replace)
+
+
+class WrapperFxCodegen(PythonWrapperCodegen):
+    """
+    Backend to generate wrapper code as an FX IR graph.
+    """
+
+    supports_caching = False
+
+    def __init__(self, *args: Any, **kwargs: Any):
+        super().__init__(*args, **kwargs)
+        self.subgms: dict[str, torch.fx.GraphModule] = {}
+
+    def codegen_inputs(self) -> None:
+        """
+        This would generate code for symbolic input shapes, strides, etc.
+        Since the FX converter handles this, do nothing here.
+        """
+
+    def codegen_conditional(self, conditional: ir.Conditional) -> None:
+        """
+        Conditional codegen normally emits a number of different wrapper lines.
+        Instead, FX conversion uses a dedicated line for the whole conditional.
+        """
+        self.writeline(ConditionalLine(self, conditional))
+        for subgraph in (conditional.true_subgraph, conditional.false_subgraph):
+            self.codegen_subgraph_common(subgraph)
+
+    def define_subgraph_launcher_fn(
+        self, name: str, subgraph_code: Union[ValueWithLineMap, FileBackedGraphModule]
+    ) -> None:
+        """
+        Record subgms as they're generated.
+        """
+        assert isinstance(subgraph_code, FileBackedGraphModule)
+        self.subgms[name] = subgraph_code.gm
+
+    @property
+    @cache_property_on_self
+    def is_subgraph(self) -> bool:
+        return isinstance(self, SubgraphPythonWrapperCodegen)
+
+    def get_fx_graph_inputs(
+        self,
+    ) -> dict[str, Union[ir.TensorBox, ir.TorchBindObject, sympy.Expr, None]]:
+        """
+        Get the input nodes corresponding to FX graph placeholders.
+        """
+        # pyrefly: ignore [missing-argument]
+        if V.aot_compilation and not self.is_subgraph:
+            # AOT graphs must match the signature of the input module.
+            return {
+                node.name: V.graph.graph_inputs.get(node.name)
+                for node in V.graph.module.graph.find_nodes(op="placeholder")  # type: ignore[operator, union-attr]
+            }
+
+        return self.get_graph_inputs()
+
+    def _generate(self, is_inference: bool) -> tuple[FileBackedGraphModule, None]:
+        self.run_wrapper_ir_passes(is_inference)
+
+        prologue = "\n".join(
+            [
+                self.imports.getvalue(),
+                self.header.getvalue(),
+            ]
+        )
+        gm = FxConverter(
+            lines=self.lines,
+            prologue=prologue,
+            graph_inputs=self.get_fx_graph_inputs(),
+            graph_outputs=self.get_graph_outputs(),
+            subgms=self.subgms,
+            # pyrefly: ignore [missing-argument]
+            is_subgraph=self.is_subgraph,
+        ).generate()
+
+        compiled_fn = self.compile_graph(gm)
+
+        return FileBackedGraphModule(gm, compiled_fn), None
+
+    def compile_graph(self, gm: GraphModule) -> Callable[..., Any]:
+        """
+        Converts the graph module into a runnable function. The default implementation
+        is simply an interpreter calling kernels in eager mode. Derived backends can
+        override this to do further compilation.
+        """
+        return gm.forward
+
+    def write_header(self) -> None:
+        """
+        Python subgraphs normally lack headers.
+        Override this behavior to generate prologues for FX subgraphs.
+        """
+        PythonWrapperCodegen.write_header(self)
+
+    @classmethod
+    def create(
+        cls: type["WrapperFxCodegen"],
+        is_subgraph: bool,
+        subgraph_name: Optional[str],
+        parent_wrapper: Optional[PythonWrapperCodegen],
+        partition_signatures: Optional[ir.GraphPartitionSignature] = None,
+    ) -> "WrapperFxCodegen":
+        if is_subgraph:
+            assert subgraph_name is not None
+            assert parent_wrapper is not None
+
+            # Subgraphs override some methods of PythonWrapperCodegen.
+            # Apply these overrides to the user-provided class, with priority given to
+            # user-provided methods.
+            class SubgraphFxWrapperCodegen(cls, SubgraphPythonWrapperCodegen):  # type: ignore[misc,valid-type]
+                def compile_graph(self, gm: GraphModule) -> Callable[..., Any]:
+                    """
+                    Skip graph compilation for subgraphs.
+                    """
+
+                    def crash_if_run(*args: Any) -> None:
+                        raise NotImplementedError("Cannot run a subgraph in isolation!")
+
+                    return crash_if_run
+
+            return SubgraphFxWrapperCodegen(
+                subgraph_name, parent_wrapper, partition_signatures
+            )
+
+        return cls()
+
+
+@dataclasses.dataclass
+class FxConverter:
+    """
+    Generates FX IR from Wrapper IR. As each instance is only meant to be used once, the
+    input and output code are stored as attributes.
+    """
+
+    lines: list[Line]
+    prologue: str
+    graph_inputs: dict[str, Union[ir.TensorBox, ir.TorchBindObject, sympy.Expr, None]]
+    graph_outputs: list[ir.IRNode]
+    subgms: dict[str, torch.fx.GraphModule]
+    is_subgraph: bool
+
+    def __post_init__(self) -> None:
+        graph = torch.fx.Graph()
+        self.gm = GraphModule({}, graph)  # Wrapper FX IR.
+        self.buffer_to_node: dict[
+            Optional[str], torch.fx.Node
+        ] = {}  # Symbol table for codegen.
+        self.kernels: dict[str, TritonKernel] = {}  # Table to store Triton kernels.
+        self._unique_symbol_ids: Counter[str] = Counter()
+        self.tracer = torch.fx.proxy.GraphAppendingTracer(graph)
+        self.expr_to_proxy: dict[sympy.Expr, torch.fx.Proxy] = {}
+
+    def _import_kernel(self, code: str, kernel_name: str) -> CachingAutotuner:
+        """
+        Imports a kernel from source, possibly autotuning block parameters.
+        """
+        module_code = "\n".join([self.prologue, code])
+        mod = PyCodeCache.load(module_code)
+        kernel = getattr(mod, kernel_name)
+
+        if isinstance(kernel, LambdaFuture):
+            kernel = kernel.result()
+
+        if not isinstance(kernel, CachingAutotuner):
+            raise NotImplementedError(
+                textwrap.dedent(f"""
+                Unsupported type for kernel {kernel_name}: {type(kernel)}.
+                FX conversion only supports Triton kernels.
+            """)
+            )
+
+        return kernel
+
+    def _create_as_strided(
+        self,
+        input_node: torch.fx.Node,
+        size: tuple[Any, ...],
+        stride: tuple[Any, ...],
+        offset: Union[int, sympy.Expr],
+    ) -> torch.fx.Node:
+        return self.gm.graph.call_function(
+            torch.as_strided,
+            args=(
+                input_node,
+                self._generate_sym_nodes(size),
+                self._generate_sym_nodes(stride),
+                self._generate_sym_node(offset),
+            ),
+        )
+
+    def _record_allocation(self, buffer: CodegenBuffer, node: torch.fx.Node) -> None:
+        """
+        Updates the symbol table to record that an Inductor buffer maps to the result of
+        an FX node.
+        """
+        assert node not in self.buffer_to_node
+        self.buffer_to_node[buffer.get_name()] = node
+
+    def _free(self, buffer: Union[CodegenBuffer, ir.TorchBindObject]) -> None:
+        """
+        Removes the buffer from the symbol table.
+        """
+        name = buffer.get_name()
+        del self.buffer_to_node[name]
+
+    def _lookup_args(self, args: tuple[Any, ...]) -> tuple[Any, ...]:
+        """
+        Maps call args back to FX nodes.
+        """
+        return tuple(
+            self.buffer_to_node[arg]
+            if isinstance(arg, str)
+            else arg.inner_expr
+            if isinstance(arg, SymbolicCallArg)
+            else arg
+            for arg in args
+        )
+
+    def _get_buffer(self, node: ir.IRNode) -> CodegenBuffer:
+        """
+        Extract buffer data from an IR node.
+        """
+        if isinstance(node, (ir.Buffer, WorkspaceArg)):
+            return node
+        elif isinstance(node, (ir.BaseView, ir.MutableBox)):
+            return self._get_buffer(node.data)
+        elif isinstance(node, sympy.Symbol):
+            return SymbolBuffer(node)
+        else:
+            raise NotImplementedError(f"Unable to extract buffer from node: {node}")
+
+    def _generate_size_proxy(
+        self, node: torch.fx.Node, expr: sympy.Expr
+    ) -> torch.fx.Proxy:
+        proxy = torch.fx.Proxy(node, tracer=self.tracer)
+        self.expr_to_proxy[expr] = proxy
+        return proxy
+
+    def _generate_graph_inputs(self) -> None:
+        """
+        Converts graph inputs to FX placeholders.
+        """
+
+        for name, ir_node in self.graph_inputs.items():
+            if ir_node is None:
+                # Create dummy input nodes to match the input signature
+                self.gm.graph.placeholder(name)
+                continue
+
+            # Introduce a new symbol for constant inputs.
+            is_constant = isinstance(ir_node, (int, float, sympy.Integer, sympy.Float))
+            buffer = (
+                SymbolBuffer(sympy.Symbol(name, is_integer=True))
+                if is_constant
+                else self._get_buffer(ir_node)
+            )
+            placeholder_node = self.gm.graph.placeholder(buffer.get_name())
+            placeholder_node.meta["val"] = (
+                ir_node if is_constant else buffer.get_example()
+            )
+            self._record_allocation(buffer, placeholder_node)
+
+            # Record symbol definitions for dynamic shapes.
+            if isinstance(ir_node, sympy.Symbol):
+                self._generate_size_proxy(placeholder_node, ir_node)
+
+    def _generate_graph_input_shapes(self) -> None:
+        """
+        Generate nodes creating symints that are part of graph input
+        shape/strides.
+        """
+
+        def _codegen_symbol(
+            sym_or_exp: Union[sympy.Symbol, sympy.Expr],
+            base_node: torch.fx.Node,
+            target: torch._ops.OpOverload,
+            dim: int,
+        ) -> None:
+            def codegen_proxy() -> torch.fx.Proxy:
+                size_node = self.gm.graph.call_function(target, (base_node, dim))
+                size_proxy = self._generate_size_proxy(size_node, sym_or_exp)
+                return size_proxy
+
+            if isinstance(sym_or_exp, sympy.Symbol):
+                if sym_or_exp in self.expr_to_proxy:
+                    return
+                codegen_proxy()
+
+            elif isinstance(sym_or_exp, sympy.Integer):
+                return
+
+            elif isinstance(sym_or_exp, sympy.Expr):
+                # Check if we need to solve for an undefined symbol.
+                undefined_symbols = [
+                    sym
+                    for sym in sym_or_exp.free_symbols
+                    if sym not in self.expr_to_proxy
+                ]
+                if len(undefined_symbols) == 0:
+                    self._sympy_interp(sym_or_exp)
+                    return
+                elif len(undefined_symbols) > 1:
+                    raise ValueError(f"Underdetermined input expression: {sym_or_exp}")
+
+                # Define a new symbol for the input size.
+                size_proxy = codegen_proxy()
+                size_symbol = sympy.Symbol(
+                    size_proxy.node.name, integer=True, nonnegative=True
+                )
+                self.expr_to_proxy[size_symbol] = size_proxy
+
+                # Solve for the undefined symbol.
+                undefined_symbol = undefined_symbols[0]
+                solution = try_solve(
+                    sympy.Eq(sym_or_exp, size_symbol), undefined_symbol
+                )
+                if solution is None:
+                    raise ValueError(f"Cannot solve input expression: {sym_or_exp}")
+
+                # Since the symbol is a size, it must be an integer.
+                # Therefore, we can convert division to FloorDiv.
+                undefined_symbol_expr = solution[1]
+                if undefined_symbol.is_integer:
+                    undefined_symbol_expr = replace_floor_div(
+                        sympy.floor(undefined_symbol_expr)
+                    )
+
+                # Generate FX for the symbol.
+                self._sympy_interp(undefined_symbol_expr)
+                self.expr_to_proxy[undefined_symbol] = self.expr_to_proxy[
+                    undefined_symbol_expr
+                ]
+
+        for ir_node in self.graph_inputs.values():
+            if isinstance(ir_node, ir.TensorBox):
+                buffer = self._get_buffer(ir_node)
+                placeholder_node = self.buffer_to_node[buffer.get_name()]
+
+                for dim, size in enumerate(ir_node.get_size()):
+                    _codegen_symbol(
+                        size, placeholder_node, torch.ops.aten.sym_size.int, dim
+                    )
+                for dim, stride in enumerate(ir_node.get_stride()):
+                    _codegen_symbol(
+                        stride, placeholder_node, torch.ops.aten.sym_stride.int, dim
+                    )
+
+    def _generate_graph_constants(self) -> None:
+        for name, value in V.graph.constants.items():
+            node = self.gm.graph.get_attr(name)
+            node.meta["val"] = value
+            setattr(self.gm, name, value)
+            self.buffer_to_node[name] = node
+
+    def _generate_buffer(self, node: ir.IRNode) -> Optional[torch.fx.Node]:
+        """
+        Generates FX IR for transformations on a buffer, such as ReinterpretView.
+        Does nothing if no such transformations are present.
+        """
+
+        if isinstance(node, ir.ShapeAsConstantBuffer):
+            # Generate FX nodes to compute the shape expression.
+            return self._sympy_interp(node.expr).node
+
+        def generate_to_buffer(node: ir.IRNode) -> Optional[BufferLike]:
+            if isinstance(node, (ir.Buffer, WorkspaceArg)):
+                return node
+            elif isinstance(node, ir.NoneAsConstantBuffer):
+                return None
+            elif isinstance(node, ir.MutableBox):
+                return generate_to_buffer(node.data)
+            elif isinstance(node, ir.ReinterpretView):
+                # We need to introduce a new symbol if the output is a ReinterpretView.
+                # Use a WorkspaceArg for this.
+                buffer = self._get_buffer(node.data)
+                assert isinstance(buffer, (ir.Buffer, WorkspaceArg))
+                unique_name = self.gm.graph._graph_namespace.create_name(
+                    f"{buffer.get_name()}_view", None
+                )
+                device = buffer.get_device()
+                assert device
+                reused_as = WorkspaceArg(
+                    count=buffer.get_size(),
+                    zero_mode=WorkspaceZeroMode.UNINITIALIZED,
+                    device=device,
+                    outer_name=unique_name,
+                    dtype=buffer.get_dtype(),
+                )
+
+                # Generate FX IR for the view.
+                self._generate_reinterpret_helper(buffer, reused_as, node.layout)
+
+                return reused_as
+            else:
+                raise NotImplementedError(f"Unrecognized buffer/view node: {node}")
+
+        buffer = generate_to_buffer(node)
+        return self.buffer_to_node[buffer.get_name()] if buffer is not None else None
+
+    def _generate_outputs(
+        self,
+    ) -> Union[Optional[torch.fx.Node], list[Optional[torch.fx.Node]]]:
+        """
+        Generate FX IR for graph outputs.
+        """
+        output_nodes = [
+            self._generate_buffer(node) for idx, node in enumerate(self.graph_outputs)
+        ]
+
+        # Parent graphs with single return elements don't use a tuple.
+        output_value = (
+            output_nodes[0]
+            if len(output_nodes) == 1 and not self.is_subgraph
+            else output_nodes
+        )
+
+        return output_value
+
+    def _generate_subgm_getattrs(self) -> None:
+        """
+        Generate getattr nodes for subgms.
+        """
+
+        def generate_getattr(name: str, subgm: torch.fx.GraphModule) -> torch.fx.Node:
+            self.gm.add_submodule(name, subgm)
+            node = self.gm.graph.get_attr(name)
+            node.meta["val"] = subgm
+            return node
+
+        self.subgm_getattrs = {
+            name: generate_getattr(name, subgm) for name, subgm in self.subgms.items()
+        }
+
+    def _get_subgm_attr(self, subgraph: ir.Subgraph) -> torch.fx.Node:
+        """
+        Look up the getattr node for a subgraph.
+        """
+        graph = subgraph.graph
+        assert graph is not None
+        return self.subgm_getattrs[graph.name]
+
+    def generate(self) -> torch.fx.GraphModule:
+        """
+        Main entrypoint for FX codegen.
+        """
+        self._generate_graph_inputs()
+        self._generate_graph_constants()
+        self._generate_subgm_getattrs()
+
+        with _set_node_metadata_hook(
+            self.gm,
+            functools.partial(_node_metadata_hook, fake_mode=V.fake_mode),
+        ):
+            self._generate_graph_input_shapes()
+
+            # Generate FX IR from Wrapper IR lines.
+            for line in self.lines:
+                if isinstance(line, WrapperLine):
+                    line.codegen_fx(self)(line)
+                elif isinstance(line, LineContext):
+                    # Ignore line context in FX IR.
+                    pass
+                else:
+                    raise NotImplementedError(
+                        textwrap.dedent(
+                            f"""
+                        Found line of unrecognized type '{type(line)}':
+                            '{line}'
+
+                        FX conversion only supports Wrapper IR lines.
+                        """
+                        )
+                    )
+
+            output = self._generate_outputs()
+
+        self.gm.graph.output(output)
+        self.gm.recompile()
+        return self.gm
+
+    def _sympy_interp(self, expr: sympy.Expr) -> torch.fx.Proxy:
+        # hash cons
+        if expr in self.expr_to_proxy:
+            return self.expr_to_proxy[expr]
+        # base cases, don't cache
+        if isinstance(
+            expr,
+            (
+                sympy.Integer,
+                sympy.Number,
+                sympy.Symbol,
+                sympy.logic.boolalg.BooleanAtom,
+            ),
+        ):
+            return sympy_interp(
+                OptimizedPythonReferenceAnalysis, self.expr_to_proxy, expr
+            )
+
+        # hash cons on arguments, run expr handler
+        self.expr_to_proxy[expr] = _run_sympy_handler(
+            OptimizedPythonReferenceAnalysis,
+            [self._sympy_interp(arg) for arg in expr.args],
+            expr,
+        )
+        return self.expr_to_proxy[expr]
+
+    def _generate_sym_node(
+        self, s: Union[int, sympy.Expr]
+    ) -> Union[int, torch.fx.Node]:
+        if isinstance(s, (int, sympy.Integer)):
+            return int(s)
+        elif isinstance(s, sympy.Symbol):
+            assert s in self.expr_to_proxy, (
+                f"Could not find a node corresponding to the symbol {s}"
+            )
+            return self.expr_to_proxy[s].node
+        elif isinstance(s, sympy.Expr):
+            return self._sympy_interp(s).node
+
+        elif isinstance(s, torch.fx.Node):
+            return s
+
+        else:
+            raise ValueError(f"{s} of type {type(s)} is not a valid input")
+
+    def _generate_sym_nodes(
+        self, shape: Sequence[sympy.Expr]
+    ) -> list[Union[int, torch.fx.Node]]:
+        return [self._generate_sym_node(s) for s in shape]
+
+    def _generate_allocate(self, line: WrapperLine) -> None:
+        assert isinstance(line, AllocateLine)
+        buffer = line.node
+        name = buffer.get_name()
+        assert name not in V.graph.removed_buffers
+
+        device = buffer.get_device()
+        assert device
+        dtype = buffer.get_dtype()
+        shape = self._generate_sym_nodes(buffer.get_size())
+        stride = self._generate_sym_nodes(buffer.get_stride())
+
+        node = self.gm.graph.call_function(
+            torch.empty_strided,
+            args=(shape, stride),
+            kwargs={"dtype": dtype, "device": device.type},
+        )
+        assert name
+        node.name = name
+        self._record_allocation(buffer, node)
+
+    def _generate_conditional(self, line: WrapperLine) -> None:
+        assert isinstance(line, ConditionalLine)
+
+        def get_subgm_attr(subgraph: Optional[ir.Subgraph]) -> torch.fx.Node:
+            assert subgraph is not None
+            return self._get_subgm_attr(subgraph)
+
+        # Access the subgraphs as getattrs.
+        ir_node = line.node
+        (true_subgm, false_subgm) = [
+            get_subgm_attr(subgraph)
+            for subgraph in (ir_node.true_subgraph, ir_node.false_subgraph)
+        ]
+
+        def generate_buffer(node: Optional[ir.IRNode]) -> Optional[torch.fx.Node]:
+            assert node is not None
+            return self._generate_buffer(node)
+
+        predicate = generate_buffer(ir_node.predicate)
+        assert ir_node.operands is not None
+        operands = tuple(generate_buffer(arg) for arg in ir_node.operands)
+        fx_node = self.gm.graph.call_function(
+            torch.ops.higher_order.cond,
+            args=(predicate, true_subgm, false_subgm, operands),
+        )
+        self._record_allocation(ir_node, fx_node)
+
+    def _generate_comment(self, line: WrapperLine) -> None:
+        assert isinstance(line, CommentLine)
+        # We ignore comments in FX IR.
+
+    def _generate_dynamic_scalar(self, line: WrapperLine) -> None:
+        assert isinstance(line, DynamicScalarLine)
+
+        ir_node = line.node
+        (input_ir_node,) = ir_node.inputs
+        assert isinstance(input_ir_node, ir.IRNode)
+        input_fx_node = self._generate_buffer(input_ir_node)
+        keypath = ir_node.keypath
+        graph = self.gm.graph
+
+        def generate_item(x: Optional[torch.fx.Node]) -> torch.fx.Node:
+            assert x is not None
+            return graph.call_function(
+                aten.item.default,
+                args=(x,),
+            )
+
+        if len(keypath) == 0:
+            result_fx_node = generate_item(input_fx_node)
+        elif len(keypath) == 1 and isinstance(keypath[0], ConvertIntKey):
+            where_fx_node = graph.call_function(
+                aten.where.Scalar,
+                args=(input_fx_node, 1, 0),
+            )
+            result_fx_node = generate_item(where_fx_node)
+        else:
+            raise NotImplementedError(f"Unsupported keypath: {keypath}")
+
+        result_symbol = ir_node.sym
+        result_buffer = SymbolBuffer(result_symbol)
+        self._record_allocation(result_buffer, result_fx_node)
+        self._generate_size_proxy(result_fx_node, result_symbol)
+
+    def _generate_enter_device_context_manager(self, line: WrapperLine) -> None:
+        assert isinstance(line, EnterDeviceContextManagerLine)
+        # We ignore the device context in FX IR.
+
+    def _generate_exit_device_context_manager(self, line: WrapperLine) -> None:
+        assert isinstance(line, ExitDeviceContextManagerLine)
+        # We ignore the device context in FX IR.
+
+    def _generate_enter_subgraph(self, line: WrapperLine) -> None:
+        assert isinstance(line, EnterSubgraphLine)
+        # We ignore memory planning lines in FX IR.
+
+    def _generate_exit_subgraph(self, line: WrapperLine) -> None:
+        assert isinstance(line, ExitSubgraphLine)
+        # We ignore memory planning lines in FX IR.
+
+    def _generate_free(self, line: WrapperLine) -> None:
+        assert isinstance(line, FreeLine)
+
+        buf = line.node
+
+        # No need to free placeholders.
+        if self.buffer_to_node[buf.get_name()].op == "placeholder":
+            return
+
+        self._free(buf)
+
+    def _generate_free_if_not_reused(self, line: WrapperLine) -> None:
+        assert isinstance(line, FreeIfNotReusedLine)
+        buf = line.node
+        assert buf.get_name() not in V.graph.removed_buffers
+        if not line.is_reused:
+            self._free(buf)
+
+    def _generate_line_context(self, line: WrapperLine) -> None:
+        assert isinstance(line, LineContext)
+        # We ignore line context in FX IR.
+
+    def _generate_reinterpret(self, line: WrapperLine) -> None:
+        assert isinstance(line, ReinterpretLine)
+        self._generate_reinterpret_helper(line.node, line.reused_as, line.layout)
+
+    def _generate_reinterpret_helper(
+        self, input_buffer: BufferLike, result_buffer: BufferLike, layout: ir.Layout
+    ) -> None:
+        input_node = self.buffer_to_node[input_buffer.get_name()]
+
+        # Look up output metadata.
+        name = result_buffer.get_name()
+        assert name
+        size = tuple(layout.size)
+        stride = tuple(layout.stride)
+        if isinstance(layout, ir.NonOwningLayout):
+            # Look up the view's layout.
+            view = layout.view
+            assert isinstance(view, ir.ReinterpretView), (
+                f"unexpected type: {type(view)}"
+            )
+            layout = view.layout
+        offset = input_buffer.get_offset() + layout.offset
+
+        # Map ReinterpretView to as_strided.
+        result_node = self._create_as_strided(input_node, size, stride, offset)
+        result_node.name = name
+        self._record_allocation(result_buffer, result_node)
+
+    def _generate_reuse(self, line: WrapperLine) -> None:
+        assert isinstance(line, ReuseLine)
+        old = line.node
+        new = line.reused_as
+        assert not any(buf.get_name() in V.graph.removed_buffers for buf in (old, new))
+        assert old.get_dtype() == new.get_dtype()
+
+        old_node = self.buffer_to_node[old.get_name()]
+        result_node = old_node
+
+        # Change shape and stride.
+        size = tuple(new.get_size())
+        stride = tuple(new.get_stride())
+        offset = new.get_offset()
+        if (
+            tuple(old.get_size()) != size
+            or tuple(old.get_stride()) != stride
+            or old.get_offset() != offset
+        ):
+            result_node = self._create_as_strided(old_node, size, stride, offset)
+
+        self._record_allocation(new, result_node)
+
+        # Free the old buffer, if we allocated a new tensor.
+        if (
+            old.get_name() not in V.graph.get_output_names()
+            and line.delete_old
+            and result_node is not old_node
+        ):
+            self._free(old)
+
+    def _generate_multi_output(self, line: WrapperLine) -> None:
+        assert isinstance(line, MultiOutputLine)
+
+        arg_node = self.buffer_to_node[line.arg_name]
+
+        # For non-tuple / non-list outputs, map the
+        # output to the same node as the input.
+        if len(line.indices) == 0:
+            self.buffer_to_node[line.result_name] = arg_node
+            return
+
+        # Extract the index for tuple access.
+        inds = line.indices[0][1:]
+        assert len(inds) == 1, f"Cannot convert {inds} to an index."
+        idx = inds[0]
+
+        node = self.gm.graph.call_function(operator.getitem, args=(arg_node, idx))
+        node.name = line.result_name
+        self.buffer_to_node[line.result_name] = node
+
+    def _generate_fallback_call(
+        self,
+        ir_node: ir.ExternKernel,
+        args: Optional[tuple[Any, ...]] = None,
+        kwargs: Optional[dict[str, Any]] = None,
+    ) -> None:
+        fx_node = self.gm.graph.call_function(
+            ir_node.op_overload,  # type: ignore[arg-type]
+            args=args,
+            kwargs=kwargs,
+        )
+        result_buffer = ir_node.codegen_reference()
+        self.buffer_to_node[result_buffer] = fx_node
+
+    def _generate_index_put_fallback(self, line: WrapperLine) -> None:
+        assert isinstance(line, IndexPutFallbackLine)
+        ir_node = line.node
+
+        def generate_buffer_or_none(
+            x: Union[ir.IRNode, Sequence[ir.IRNode], None],
+        ) -> Optional[torch.fx.Node]:
+            """
+            Handles None before calling _generate_buffer.
+            """
+            if x is None:
+                return None
+
+            assert isinstance(x, ir.IRNode)
+            return self._generate_buffer(x)
+
+        (x, values) = [generate_buffer_or_none(t) for t in ir_node.inputs[:2]]
+        indices = tuple(generate_buffer_or_none(t) for t in line.indices)
+        accumulate = ir_node.constant_args[0]
+        args = (x, indices, values, accumulate)
+        self._generate_fallback_call(ir_node, args)
+
+    def _generate_scatter_fallback(self, line: WrapperLine) -> None:
+        assert isinstance(line, ScatterFallbackLine)
+        ir_node = line.node
+        assert ir.is_node_sequence(ir_node.inputs)
+        (x, index, src) = [self._generate_buffer(t) for t in ir_node.inputs] + (
+            [] if ir_node.src_is_tensor else [ir_node.constant_args[1]]
+        )
+        args = (x, ir_node.constant_args[0], index, src)
+        kwargs = {}
+        if reduce := ir_node.kwargs.get("reduce"):
+            kwargs["reduce"] = reduce
+
+        self._generate_fallback_call(ir_node, args, kwargs)
+
+    def _generate_null(self, line: WrapperLine) -> None:
+        assert isinstance(line, NullLine)
+        # Does nothing.
+
+    def _generate_comm_buffer_allocate(self, line: WrapperLine) -> None:
+        assert isinstance(line, CommBufferAllocateLine)
+        raise NotImplementedError("Comm buffer allocation is not yet supported")
+
+    def _generate_comm_buffer_free(self, line: WrapperLine) -> None:
+        assert isinstance(line, CommBufferFreeLine)
+        self._free(line.node)
+
+    def _generate_triton_call(self, line: WrapperLine) -> None:
+        assert isinstance(line, KernelCallLine)
+
+        # Collect all kwargs, including autotuned block sizes.
+        call_args = self._lookup_args(line.call_args)
+        kernel = self.kernels[line.kernel_name]
+        tuner = kernel.tuner
+
+        class UnbackedSymintsError(Exception):
+            pass
+
+        def tune_kernel(tuner: CachingAutotuner, call_args: Sequence[Any]) -> None:
+            from triton.runtime import driver
+
+            log.info("Autotuning Triton kernel %s at compile time.", kernel_name)
+            # pyrefly: ignore  # missing-attribute
+            device = driver.active.get_current_device()
+            # pyrefly: ignore  # missing-attribute
+            stream = driver.active.get_current_stream(device)
+
+            def node_to_tuning_arg(arg: Any) -> Any:
+                """
+                Create real tensors for autotuning arguments, substituting size hints
+                for dynamic shapes.
+                """
+
+                def to_size_hint(arg: Any) -> Any:
+                    if len(free_unbacked_symbols(arg)) > 0:
+                        # NYI: tuning args require backed symints.
+                        raise UnbackedSymintsError
+                    return pytree.tree_map(V.graph.sizevars.size_hint, arg)
+
+                if not isinstance(arg, torch.fx.Node):
+                    return to_size_hint(arg)
+
+                fake = arg.meta["val"]
+                return torch.empty_strided(
+                    to_size_hint(fake.shape),
+                    to_size_hint(fake.stride()),
+                    dtype=fake.dtype,
+                    device=device,
+                ).zero_()
+
+            arg_values = [node_to_tuning_arg(arg) for arg in call_args]
+            tuner.run(*arg_values, stream=stream)
+
+        # Optionally autotune the kernels.
+        # The FX backend currently only supports compile-time tuning.
+        kernel_name = tuner.fn.__name__
+        if config.triton.autotune_at_compile_time:
+            try:
+                tune_kernel(tuner, call_args)
+            except UnbackedSymintsError:
+                log.info(
+                    "Detected unbacked symints. Skipping autotuning for kernel %s.",
+                    kernel_name,
+                )
+        else:
+            log.info(
+                "Skipping autotuning for kernel %s. Set config.triton.autotune_at_compile_time = True to enable.",
+                kernel_name,
+            )
+
+        triton_meta = tuner.triton_meta
+        signature = triton_meta["signature"]
+
+        def add_constants_to_call_args(
+            call_args: Sequence[Any], cfg: Config
+        ) -> tuple[Any, ...]:
+            """
+            Add constant kwargs to the arg list.
+            """
+            # Add args from the proper Triton signature.
+            # Exclude constants and config kwargs, as those are tracked separately.
+            new_call_args = []
+            constants = triton_meta["constants"]
+            call_kwargs = {
+                key: val
+                for key, val in zip(signature, call_args)
+                # pyrefly: ignore [missing-attribute]
+                if key not in constants and key not in cfg.kwargs
+            }
+
+            # Add constants stored as Triton metadata, in signature order.
+            call_kwargs |= constants
+            new_call_args = [
+                call_kwargs[key]
+                for key in signature
+                # pyrefly: ignore [missing-attribute]
+                if key not in cfg.kwargs
+            ]
+
+            # Add Inductor's extra launcher args to the end.
+            if extra_launcher_args := tuner.inductor_meta.get("extra_launcher_args"):
+                new_call_args.extend(
+                    call_args[len(call_args) - len(extra_launcher_args) :]
+                )
+
+            return tuple(new_call_args)
+
+        kernel_config = tuner.compile_results[0].config
+        extra_options = getattr(kernel_config, "extra_options", None)
+        call_args = add_constants_to_call_args(call_args, kernel_config)
+        call_args, grid = tuner._interpret_args_grid(call_args, kernel_config)
+        call_kwargs = dict(zip(signature, call_args))
+        # pyrefly: ignore [missing-attribute]
+        assert not any(kwarg in kernel_config.kwargs for kwarg in call_kwargs), (
+            f"kwargs overlap config: {call_kwargs}"
+        )
+        # pyrefly: ignore [missing-attribute]
+        call_kwargs.update(kernel_config.kwargs)
+
+        # Replace sympy.floor with FloorDiv, to make the expression traceable.
+        grid = [replace_floor_div(x) if isinstance(x, sympy.Expr) else x for x in grid]
+        wrapper_grid = [tuple(self._generate_sym_nodes(grid))]
+        call_kwargs = {
+            name: self._generate_sym_node(val) for name, val in call_kwargs.items()
+        }
+
+        # Store non-graphable kwargs in the side table.
+        (
+            call_kwargs,
+            constant_args_idx,
+        ) = tracing_triton_hopifier_singleton.store_non_graphable_args(call_kwargs)
+
+        triton_node = self.gm.graph.call_function(
+            triton_kernel_wrapper_mutation,
+            kwargs={
+                "kernel_idx": kernel.wrapped.kernel_idx,
+                "constant_args_idx": constant_args_idx,
+                "grid": wrapper_grid,
+                "tma_descriptor_metadata": {},
+                "kwargs": call_kwargs,
+            },
+        )
+        if extra_options:
+            triton_node.meta["extra_options"] = extra_options
+
+    def _generate_extern_kernel_alloc(self, line: WrapperLine) -> None:
+        assert isinstance(line, ExternKernelAllocLine)
+        node = line.node
+        self._generate_extern_kernel_common(node, node)
+
+    def _generate_extern_kernel_out(
+        self,
+        line: WrapperLine,
+    ) -> None:
+        assert isinstance(line, ExternKernelOutLine)
+        node = line.node
+        out_node = node.output_view if node.output_view else node
+        self._generate_extern_kernel_common(node, out_node)
+
+    def _generate_extern_kernel_common(
+        self, kernel: ir.ExternKernel, out_ir_node: ir.IRNode
+    ) -> None:
+        """
+        Generates FX IR from either ExternKernelAlloc or ExternKernelOut.
+        """
+
+        # Get FX nodes corresponding to the call args.
+        assert ir.is_node_sequence(kernel.inputs)
+        tensor_nodes = tuple(self._generate_buffer(arg) for arg in kernel.inputs)
+        if hasattr(kernel, "unflatten_args"):
+            args, _ = kernel.unflatten_args(tensor_nodes, kernel.constant_args)
+        else:
+            args = tensor_nodes + tuple(kernel.constant_args)
+
+        # Get the result buffer.
+        # Some kernels write to a pre-existing output tensor via the "out" kwarg.
+        kwargs = kernel.kwargs.copy()
+
+        result_buffer: Optional[str] = None
+        if isinstance(kernel, ir.ExternKernelOut):
+            kwargs["out"] = self.buffer_to_node[out_ir_node.codegen_reference()]
+        elif isinstance(kernel.layout, (ir.Layout, ir.MultiOutputLayout)):
+            result_buffer = kernel.get_name()
+        elif isinstance(kernel.layout, ir.NoneLayout):
+            pass
+        else:
+            raise NotImplementedError(f"Unrecognized output layout: {kernel.layout}")
+
+        fx_node = self.gm.graph.call_function(
+            kernel.op_overload,  # type: ignore[arg-type]
+            args=args,
+            kwargs=kwargs,
+        )
+
+        # Assign the result to the given name.
+        if result_buffer:
+            assert "out" not in kwargs, (
+                f"Extern kernel '{kernel}' has both result and out kwarg. Expected only one."
+            )
+            fx_node.name = result_buffer
+            self.buffer_to_node[result_buffer] = fx_node
+
+    def _generate_kernel_call(self, line: WrapperLine) -> None:
+        assert isinstance(line, KernelCallLine)
+        if not line.triton:
+            raise NotImplementedError("FX conversion only supports Triton kernels.")
+
+        self._generate_triton_call(line)
+
+    def _generate_kernel_definition(self, line: WrapperLine) -> None:
+        assert isinstance(line, KernelDefinitionLine)
+
+        # Generate code for the kernel.
+        kernel_code = PythonWrapperCodegen._format_kernel_definition(
+            line.kernel_name, line.kernel_body, metadata=line.metadata
+        )
+
+        # Import the module and store the JIT kernel.
+        tuner = self._import_kernel(kernel_code, line.kernel_name)
+        wrapped = wrap_triton(tuner.fn)
+        self.kernels[line.kernel_name] = TritonKernel(tuner, wrapped)
+
+    def _generate_symbolic_call_arg(self, line: WrapperLine) -> None:
+        assert isinstance(line, SymbolicCallArgLine)
+        # Store the arg: expr mapping for later use.
+        arg = line.arg
+
+        inner_expr_proxy = self._sympy_interp(arg.inner_expr)
+        self.expr_to_proxy[arg.inner] = inner_expr_proxy
+
+    def _generate_unbacked_symbol_defs(self, line: WrapperLine) -> None:
+        assert isinstance(line, UnbackedSymbolDefsLine)
+        graph = self.gm.graph
+
+        def convert_key(node: torch.fx.Node, path: pytree.KeyPath) -> torch.fx.Node:
+            """
+            Generate FX IR for each key entry.
+            """
+            # Base case.
+            if len(path) == 0:
+                return node
+
+            # Process the first entry and recurse.
+            entry = path[0]
+            if isinstance(entry, CallMethodKey):
+                target = {
+                    "size": aten.sym_size.int,
+                    "stride": aten.sym_stride.int,
+                    "storage_offset": aten.sym_storage_offset,
+                }[entry.name]
+                assert callable(target)
+                node = graph.call_function(
+                    target,
+                    args=(
+                        (node, path[1].idx)
+                        if len(path) > 1 and isinstance(path[1], pytree.SequenceKey)
+                        else (node,)
+                    ),
+                )
+                return convert_key(node, path[1 + len(node.args) :])
+            elif isinstance(entry, pytree.SequenceKey):
+                node = graph.call_function(operator.getitem, args=(node, entry.idx))
+                return convert_key(node, path[1:])
+            elif isinstance(entry, DivideByKey):
+                node = graph.call_function(
+                    operator.floordiv, args=(node, entry.divisor)
+                )
+                return convert_key(node, path[1:])
+            else:
+                raise NotImplementedError(f"Unrecognized entry type: {type(entry)}")
+
+        root_node = self.buffer_to_node[line.output_name]
+        unbacked_bindings = line.unbacked_bindings
+        assert unbacked_bindings is not None
+        for s, keypath in unbacked_bindings.items():
+            # Check if we already generated this symbol.
+            if s.name in self.buffer_to_node:
+                continue
+
+            node = convert_key(root_node, keypath)
+            out_buffer = SymbolBuffer(s)
+            self._record_allocation(out_buffer, node)
+            self._generate_size_proxy(node, s)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/b2b_gemm.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/b2b_gemm.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a8dc65c08ec457c1cb87354a7a95afb4d15203d
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/b2b_gemm.py
@@ -0,0 +1,774 @@
+# mypy: allow-untyped-defs
+import functools
+from collections import deque
+
+import torch
+from torch.utils._ordered_set import OrderedSet
+from torch.utils._pytree import tree_map
+
+from ..._dynamo.utils import counters
+from ..ir import (
+    ComputedBuffer,
+    FixedLayout,
+    FlexibleLayout,
+    InputBuffer,
+    ShapeAsConstantBuffer,
+    StorageBox,
+    Subgraph,
+    TensorBox,
+)
+from ..lowering import lowerings
+from ..pattern_matcher import (
+    Arg,
+    CallFunction,
+    Match,
+    PatternMatcherPass,
+    register_graph_pattern,
+)
+from ..select_algorithm import (
+    autotune_select_algorithm,
+    ExternKernelChoice,
+    SymbolicGridFn,
+    TritonTemplate,
+    TritonTemplateCaller,
+)
+from ..utils import ceildiv
+
+
+B2B_GEMM_PASS = PatternMatcherPass(
+    pass_name="b2b_gemm_pass",
+)
+
+
+@SymbolicGridFn
+def b2b_gemm_grid(M, P, meta, *, cdiv):
+    return (cdiv(M, meta["BLOCK_SIZE_M"]) * cdiv(P, meta["BLOCK_SIZE_P"]), 1, 1)
+
+
+b2b_gemm_left_template = TritonTemplate(
+    name="b2b_gemm_left",
+    grid=b2b_gemm_grid,
+    debug=False,
+    source=r"""
+{{def_kernel("A", "B", "C")}}
+
+
+    # B2B_GEMM_LEFT_TRITON_ENTRANCE
+
+    # dynamic shapes
+    M = {{size("A", 0)}}
+    N = {{size("A", 1)}}
+    O = {{size("C", 0)}}
+    P = {{size("C", 1)}}
+
+    # dynamic strides
+    stride_am = {{stride("A", 0)}}
+    stride_an = {{stride("A", 1)}}
+    stride_bn = {{stride("B", 0)}}
+    stride_bo = {{stride("B", 1)}}
+    stride_co = {{stride("C", 0)}}
+    stride_cp = {{stride("C", 1)}}
+
+    # output block counts
+    num_m_block = tl.cdiv(M, BLOCK_SIZE_M)
+    num_p_block = tl.cdiv(P, BLOCK_SIZE_P)
+
+    # internal block counts
+    num_n_block = tl.cdiv(N, BLOCK_SIZE_N)
+    num_o_block = tl.cdiv(O, BLOCK_SIZE_O)
+
+    # output block ids
+    pid = tl.program_id(axis=0)
+    m_block_id = pid // num_p_block
+    p_block_id = pid % num_p_block
+
+    # accumulator
+    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_P), dtype=tl.float32)
+
+    # main loop
+    offs_m = (m_block_id * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))
+    offs_p = (p_block_id * BLOCK_SIZE_P + tl.arange(0, BLOCK_SIZE_P))
+    # (subgraph(A @ B) @ C)
+    offs_o = tl.arange(0, BLOCK_SIZE_O)
+    for _ in range(num_o_block):
+        c_mask = (offs_o[:, None] < O) & (offs_p[None, :] < P)
+        c_ptrs = C + (offs_o[:, None] * stride_co + offs_p[None, :] * stride_cp)
+        c = tl.load(c_ptrs, mask=c_mask, other=0.0).to(tl.float32)  # BLOCK_SIZE_O * BLOCK_SIZE_P
+        acc_ab = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_O), dtype=tl.float32)
+        offs_n = tl.arange(0, BLOCK_SIZE_N)
+        for __ in range(num_n_block):
+            a_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
+            a_ptrs = A + (offs_m[:, None] * stride_am + offs_n[None, :] * stride_an)
+            a = tl.load(a_ptrs, mask=a_mask, other=0.0).to(tl.float32)  # BLOCK_SIZE_M * BLOCK_SIZE_N
+            b_mask = (offs_n[:, None] < N) & (offs_o[None, :] < O)
+            b_ptrs = B + (offs_n[:, None] * stride_bn + offs_o[None, :] * stride_bo)
+            b = tl.load(b_ptrs, mask=b_mask, other=0.0).to(tl.float32)  # BLOCK_SIZE_N * BLOCK_SIZE_O
+            acc_ab += tl.dot(a, b, out_dtype=tl.float32)
+            offs_n += BLOCK_SIZE_N
+        # apply the subgraph
+        {{ modification(
+            subgraph_number=0,
+            output_name="post_subgraph_acc_ab",
+            inner_mm="acc_ab"
+        ) | indent_except_first(2) }}
+        acc += tl.dot(post_subgraph_acc_ab, c, out_dtype=tl.float32)
+        offs_o += BLOCK_SIZE_O
+
+    # type conversion
+    acc = acc.to(tl.float16)
+
+    # store preparation
+    idx_m = offs_m[:, None]
+    idx_p = offs_p[None, :]
+    out_mask = (idx_m < M) & (idx_p < P)
+
+    {{store_output(("idx_m", "idx_p"), "acc", "out_mask", val_shape=("BLOCK_SIZE_M", "BLOCK_SIZE_P"))}}
+""",
+)
+
+
+b2b_gemm_right_template = TritonTemplate(
+    name="b2b_gemm_right",
+    grid=b2b_gemm_grid,
+    debug=False,
+    source=r"""
+{{def_kernel("A", "B", "C")}}
+
+
+    # B2B_GEMM_RIGHT_TRITON_ENTRANCE
+
+    # dynamic shapes
+    M = {{size("A", 0)}}
+    N = {{size("A", 1)}}
+    O = {{size("C", 0)}}
+    P = {{size("C", 1)}}
+
+    # dynamic strides
+    stride_am = {{stride("A", 0)}}
+    stride_an = {{stride("A", 1)}}
+    stride_bn = {{stride("B", 0)}}
+    stride_bo = {{stride("B", 1)}}
+    stride_co = {{stride("C", 0)}}
+    stride_cp = {{stride("C", 1)}}
+
+    # output block counts
+    num_m_block = tl.cdiv(M, BLOCK_SIZE_M)
+    num_p_block = tl.cdiv(P, BLOCK_SIZE_P)
+
+    # internal block counts
+    num_n_block = tl.cdiv(N, BLOCK_SIZE_N)
+    num_o_block = tl.cdiv(O, BLOCK_SIZE_O)
+
+    # output block ids
+    pid = tl.program_id(axis=0)
+    m_block_id = pid // num_p_block
+    p_block_id = pid % num_p_block
+
+    # accumulator
+    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_P), dtype=tl.float32)
+
+    # main loop (two cases)
+    offs_m = (m_block_id * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))
+    offs_p = (p_block_id * BLOCK_SIZE_P + tl.arange(0, BLOCK_SIZE_P))
+    # (A @ subgraph(B @ C))
+    offs_n = tl.arange(0, BLOCK_SIZE_N)
+    for _ in range(num_n_block):
+        a_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
+        a_ptrs = A + (offs_m[:, None] * stride_am + offs_n[None, :] * stride_an)
+        a = tl.load(a_ptrs, mask=a_mask, other=0.0).to(tl.float32)  # BLOCK_SIZE_M * BLOCK_SIZE_N
+        acc_bc = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_P), dtype=tl.float32)
+        offs_o = tl.arange(0, BLOCK_SIZE_O)
+        for __ in range(num_o_block):
+            b_mask = (offs_n[:, None] < N) & (offs_o[None, :] < O)
+            b_ptrs = B + (offs_n[:, None] * stride_bn + offs_o[None, :] * stride_bo)
+            b = tl.load(b_ptrs, mask=b_mask, other=0.0).to(tl.float32)  # BLOCK_SIZE_N * BLOCK_SIZE_O
+            c_mask = (offs_o[:, None] < O) & (offs_p[None, :] < P)
+            c_ptrs = C + (offs_o[:, None] * stride_co + offs_p[None, :] * stride_cp)
+            c = tl.load(c_ptrs, mask=c_mask, other=0.0).to(tl.float32)  # BLOCK_SIZE_O * BLOCK_SIZE_P
+            acc_bc += tl.dot(b, c, out_dtype=tl.float32)
+            offs_o += BLOCK_SIZE_O
+        # apply the subgraph
+        {{ modification(
+            subgraph_number=0,
+            output_name="post_subgraph_acc_bc",
+            inner_mm="acc_bc"
+        ) | indent_except_first(2) }}
+        acc += tl.dot(a, post_subgraph_acc_bc, out_dtype=tl.float32)
+        offs_n += BLOCK_SIZE_N
+
+    # type conversion
+    acc = acc.to(tl.float16)
+
+    # store preparation
+    idx_m = offs_m[:, None]
+    idx_p = offs_p[None, :]
+    out_mask = (idx_m < M) & (idx_p < P)
+
+    {{store_output(("idx_m", "idx_p"), "acc", "out_mask", val_shape=("BLOCK_SIZE_M", "BLOCK_SIZE_P"))}}
+""",
+)
+
+
+# Note: load_ratio_left and load_ratio_right are only calculating numbers
+# in the trivial subgraph case; i.e. (A @ (B @ C)) or ((A @ B) @ C)
+
+
+def load_ratio_left(
+    M: int, N: int, O: int, P: int, m: int, n: int, o: int, p: int
+) -> float:
+    """
+    compute the ratio of estimated numbers of loads in baseline and b2bgemm
+    M, N, O, P are matrix sizes
+    m, n, o, p are block sizes
+    |       | baseline (lower bound)        | b2bgemm
+    | load  | M * N + N * O + M * O + O * P | M / m * P / p * O / o * (o * p + N / n * (m * n + n * o))
+    | store | M * O + M * P                 | M * P
+    b2bgemm is always better on stores, but for loads we need to find out beneficial cases using this function
+    """
+    base = M * N + N * O + M * O + O * P
+    gemm = (
+        ceildiv(M, m)
+        * ceildiv(P, p)
+        * ceildiv(O, o)
+        * (o * p + ceildiv(N, n) * (m * n + n * o))
+    )
+    return base / gemm
+
+
+def load_ratio_right(
+    M: int, N: int, O: int, P: int, m: int, n: int, o: int, p: int
+) -> float:
+    """
+    compute the ratio of estimated numbers of loads in baseline and b2bgemm
+    M, N, O, P are matrix sizes
+    m, n, o, p are block sizes
+    |       | baseline (lower bound)        | b2bgemm
+    | load  | N * O + O * P + M * N + N * P | M / m * P / p * N / n * (m * n + O / o * (n * o + o * p))
+    | store | N * P + M * P                 | M * P
+    b2bgemm is always better on stores, but for loads we need to find out beneficial cases using this function
+    """
+    base = N * O + O * P + M * N + N * P
+    gemm = (
+        ceildiv(M, m)
+        * ceildiv(P, p)
+        * ceildiv(N, n)
+        * (m * n + ceildiv(O, o) * (n * o + o * p))
+    )
+    return base / gemm
+
+
+# the block sizes are limited by hardware (the shared memory)
+# intuitively, the optimization works when the intermediate matrix is large
+# and we assign large block sizes to large dimensions
+b2b_gemm_configs = [
+    {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_O": 16,
+        "BLOCK_SIZE_P": 16,
+        "num_stages": 4,
+        "num_warps": 8,
+    },
+    {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_O": 32,
+        "BLOCK_SIZE_P": 32,
+        "num_stages": 2,
+        "num_warps": 4,
+    },
+    {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_O": 64,
+        "BLOCK_SIZE_P": 64,
+        "num_stages": 2,
+        "num_warps": 4,
+    },
+    {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_O": 128,
+        "BLOCK_SIZE_P": 16,
+        "num_stages": 4,
+        "num_warps": 8,
+    },
+    {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_O": 128,
+        "BLOCK_SIZE_P": 32,
+        "num_stages": 2,
+        "num_warps": 4,
+    },
+    {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_O": 128,
+        "BLOCK_SIZE_P": 64,
+        "num_stages": 2,
+        "num_warps": 4,
+    },
+    {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_O": 16,
+        "BLOCK_SIZE_P": 128,
+        "num_stages": 4,
+        "num_warps": 8,
+    },
+    {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_O": 32,
+        "BLOCK_SIZE_P": 128,
+        "num_stages": 2,
+        "num_warps": 4,
+    },
+    {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_O": 64,
+        "BLOCK_SIZE_P": 128,
+        "num_stages": 2,
+        "num_warps": 4,
+    },
+    {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_O": 16,
+        "BLOCK_SIZE_P": 128,
+        "num_stages": 4,
+        "num_warps": 8,
+    },
+    {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_O": 32,
+        "BLOCK_SIZE_P": 128,
+        "num_stages": 2,
+        "num_warps": 4,
+    },
+    {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_O": 64,
+        "BLOCK_SIZE_P": 128,
+        "num_stages": 2,
+        "num_warps": 4,
+    },
+]
+
+
+def is_b2b_gemm_good_on(
+    is_left_assoc: bool,
+    A_node: torch.fx.Node,
+    B_node: torch.fx.Node,
+    C_node: torch.fx.Node,
+) -> bool:
+    """
+    checks whether the sizes are good for b2b_gemm
+    """
+    # basic checks
+    if not all(["val" in A_node.meta, "val" in B_node.meta, "val" in C_node.meta]):
+        return False
+    fake_tensors = (
+        A_node.meta["val"],
+        B_node.meta["val"],
+        C_node.meta["val"],
+    )  # torch._subclasses.fake_tensor.FakeTensor
+
+    A, B, C = fake_tensors
+
+    def check_all_attr_true(objects, attr):
+        return all(hasattr(obj, attr) and getattr(obj, attr) for obj in objects)
+
+    if not check_all_attr_true(fake_tensors, "is_cuda") and not check_all_attr_true(
+        fake_tensors, "is_xpu"
+    ):
+        return False
+    if not all([len(A.shape) == 2, len(B.shape) == 2, len(C.shape) == 2]):
+        return False
+    if not ((A.shape[1] == B.shape[0]) and (B.shape[1] == C.shape[0])):
+        return False
+    # size checks: we only dispatch to B2B-GEMM when the average load ratio is > 1
+    M, N = A.shape
+    O, P = C.shape
+    ratios = []
+    if is_left_assoc:
+        for config in b2b_gemm_configs:
+            ratio = load_ratio_left(
+                M,
+                N,
+                O,
+                P,
+                config["BLOCK_SIZE_M"],
+                config["BLOCK_SIZE_N"],
+                config["BLOCK_SIZE_O"],
+                config["BLOCK_SIZE_P"],
+            )
+            ratios.append(ratio)
+    else:
+        for config in b2b_gemm_configs:
+            ratio = load_ratio_right(
+                M,
+                N,
+                O,
+                P,
+                config["BLOCK_SIZE_M"],
+                config["BLOCK_SIZE_N"],
+                config["BLOCK_SIZE_O"],
+                config["BLOCK_SIZE_P"],
+            )
+            ratios.append(ratio)
+    ratios.sort(reverse=True)
+    average_ratio = 1.0
+    for r in ratios[:3]:  # top 3 choices
+        average_ratio *= r
+    average_ratio = average_ratio ** (1 / 3)
+    return (
+        average_ratio > 1
+    )  # even if average_ratio is close to 1, the number of stores is always better
+
+
+def unoptimized_b2b_gemm(
+    is_left_assoc: bool,
+    subgraph: Subgraph,
+    A: torch.Tensor,
+    B: torch.Tensor,
+    C: torch.Tensor,
+    *,
+    out: torch.Tensor,
+) -> torch.Tensor:
+    """
+    The unoptimized version is used as a fallback when the b2b_gemm kernel is not beneficial.
+    """
+    if is_left_assoc:
+        torch.mm(subgraph.graph_module(torch.mm(A, B)), C, out=out)
+    else:
+        torch.mm(A, subgraph.graph_module(torch.mm(B, C)), out=out)
+    return out
+
+
+unoptimized_choice = ExternKernelChoice(unoptimized_b2b_gemm)
+
+
+def build_subgraph_buffer(
+    args: list[TensorBox],
+    subgraph: Subgraph,
+):
+    """
+    This function is adapted from ../kernel/flex_attention.py.
+    The goal is to take in the required args and produce the subgraph buffer
+    The subgraph buffer is a ComputedBuffer that will be inlined into the triton template
+
+    Args:
+        args: The args that are passed into the subgraph
+        subgraph: The Subgraph ir for which to produce the output node
+    """
+    cnt = 0
+    env = {}
+    for node in subgraph.graph_module.graph.nodes:
+        if node.op == "placeholder":
+            env[node] = args[cnt]
+            cnt += 1
+        elif node.op == "call_function":
+            # For call_function we use the default lowerings and pass in the
+            # already created TensorBoxes as args
+            args, kwargs = tree_map(lambda x: env.get(x, x), (node.args, node.kwargs))
+            env[node] = lowerings[node.target](*args, **kwargs)
+        elif node.op == "output":
+
+            def convert_output_node_to_buffer(output):
+                if output is None:
+                    return None
+                output_node = output
+                output_buffer = env[output_node]
+                assert isinstance(output_buffer, TensorBox), (
+                    "The output node for B2B-GEMM's subgraph must be a TensorBox, but got: ",
+                    type(output_buffer),
+                )
+                assert isinstance(output_buffer.data, StorageBox), (
+                    "The output node for B2B-GEMM's subgraph must be a StorageBox, but got: ",
+                    type(output_buffer),
+                )
+                device = output_buffer.data.get_device()
+                assert device is not None
+                subgraph_buffer = ComputedBuffer(
+                    name=None,
+                    layout=FlexibleLayout(
+                        device=device,
+                        dtype=output_buffer.data.get_dtype(),
+                        size=output_buffer.data.get_size(),
+                    ),
+                    data=output_buffer.data.data,  # type: ignore[arg-type]
+                )
+                return subgraph_buffer
+
+            # node.args[0] should be a single element representing the output of the subgraph
+            return tree_map(convert_output_node_to_buffer, node.args[0])
+
+    raise ValueError("B2B-GEMM was passed a subgraph with no output node!")
+
+
+def create_placeholder(
+    name: str, dtype: torch.dtype, device: torch.device
+) -> TensorBox | ShapeAsConstantBuffer:
+    """
+    Creates a placeholder input buffers for producing subgraph_output
+    """
+    input_buffer = InputBuffer(name=name, layout=FixedLayout(device, dtype, [], []))
+    return TensorBox.create(input_buffer)
+
+
+def tuned_b2b_gemm(
+    is_left_assoc: bool,
+    subgraph: Subgraph,
+    A: torch._inductor.ir.TensorBox,
+    B: torch._inductor.ir.TensorBox,
+    C: torch._inductor.ir.TensorBox,
+    *,
+    layout=None,
+) -> torch._inductor.ir.TensorBox:
+    # call .realize() to get rid of Pointwise
+    A.realize()
+    B.realize()
+    C.realize()
+    layout = FixedLayout(
+        A.get_device_or_error(),
+        A.get_dtype(),
+        [A.shape[0], C.shape[1]],  # type: ignore[index]
+    )
+    placeholders = [
+        create_placeholder("inner_mm", A.get_dtype(), A.get_device_or_error())
+    ]
+    subgraph_buffer = build_subgraph_buffer(
+        placeholders,  # type: ignore[arg-type, list-item]
+        subgraph,
+    )
+    choices: list[TritonTemplateCaller] = []
+    for config in b2b_gemm_configs:
+        if is_left_assoc:
+            b2b_gemm_left_template.maybe_append_choice(
+                choices,
+                input_nodes=(A, B, C),
+                layout=layout,
+                subgraphs=[subgraph_buffer],
+                **config,
+            )
+        else:
+            b2b_gemm_right_template.maybe_append_choice(
+                choices,
+                input_nodes=(A, B, C),
+                layout=layout,
+                subgraphs=[subgraph_buffer],
+                **config,
+            )
+    # add the unoptimized choice to mitigate performance degradation
+    choices.append(
+        unoptimized_choice.bind(
+            (A, B, C), layout, is_left_assoc=is_left_assoc, subgraph=subgraph
+        )
+    )
+    # autotune
+    return autotune_select_algorithm("b2b_gemm", choices, [A, B, C], layout)
+
+
+# match the inner mm of a potential b2b_gemm
+@register_graph_pattern(
+    CallFunction(torch.ops.aten.mm, Arg(), Arg()),
+    # pyrefly: ignore [bad-argument-type]
+    pass_dict=B2B_GEMM_PASS,
+)
+def b2b_gemm_handler(match: Match, mat1: torch.fx.Node, mat2: torch.fx.Node) -> None:
+    # match.args: list[torch.fx.Node]
+
+    def is_pointwise_node(node: torch.fx.Node) -> bool:
+        return (
+            node.op == "call_function"
+            and isinstance(node.target, torch._ops.OpOverload)
+            and (torch.Tag.pointwise in node.target.tags)
+        )
+
+    def is_mm(node: torch.fx.Node) -> bool:
+        return node.target is torch.ops.aten.mm.default
+
+    # the inner MM
+    inner_mm = match.nodes[-1]
+
+    # find the (candidate) outer MM, which will be re-checked below to ensure every path reaches it
+    # In a real (A @ f(B @ C)), every path starting from (B @ C) must reach (A @ _).
+    outer_mm = None
+    node = inner_mm
+    while len(node.users) > 0:
+        node = next(iter(node.users))
+        if is_mm(node):
+            outer_mm = node
+            break
+        elif is_pointwise_node(node):
+            continue
+        else:
+            break
+    if not outer_mm:
+        return
+
+    # find the unique input node for outer_mm representing f(B @ C) in (A @ f(B @ C))
+    # we call it the "f_node"
+    # when the pattern is simply (A @ (B @ C)), f_node is just inner_mm
+    f_node = inner_mm
+    while next(iter(f_node.users)) is not outer_mm:
+        f_node = next(iter(f_node.users))
+
+    def all_reach_via_pointwise_with_no_other_inputs(
+        src: torch.fx.Node,
+        dst: torch.fx.Node,
+    ) -> tuple[bool, OrderedSet[torch.fx.Node]]:
+        """
+        check whether every user path from src reaches dst via pointwise nodes,
+        with no other input nodes for the intermediates and dst;
+        return
+        (1) the Boolean value
+        (2) the subgraph node set including src and dst (which only makes sense when the Boolean value is True)
+        """
+        visited = OrderedSet[torch.fx.Node]()
+        input_counter: dict[torch.fx.Node, int] = {}
+
+        all_reachable = True
+        queue = deque([src])
+        while queue:
+            node = queue.popleft()
+            if node not in visited:
+                if node is dst:
+                    visited.add(node)
+                elif (node is src) or is_pointwise_node(node):
+                    for user in node.users:
+                        # for nodes other than dst, bookkeep their users' input counts
+                        if user not in input_counter:
+                            input_counter[user] = len(user.all_input_nodes)
+                        input_counter[user] -= 1
+                        # continue BFS
+                        queue.append(user)
+                    visited.add(node)
+                else:
+                    all_reachable = False
+                    break
+
+        return (
+            all_reachable and all(count == 0 for count in input_counter.values()),
+            visited,
+        )
+
+    # check inner_mm reaches f_node on every user path via pointwise nodes with no outside input_nodes
+    ok, subgraph_node_set = all_reach_via_pointwise_with_no_other_inputs(
+        inner_mm, f_node
+    )
+    if not ok:
+        return
+
+    # check inner_mm's inputs and f_node's outputs
+    if not (len(inner_mm.all_input_nodes) == 2 and len(f_node.users) == 1):
+        return
+
+    # at this point, the nodes between inner_mm and f_node (both included)
+    # are all used internally inside (A @ subgraph(B @ C))
+    # i.e. they neither have other users nor have other inputs
+
+    # original graph and module
+    graph, module = inner_mm.graph, inner_mm.graph.owning_module
+
+    # construct the new (sub)graph
+    subgraph_node_list: list[
+        torch.fx.Node
+    ] = []  # ordered list of nodes used for node removal later
+    new_graph: torch.fx.Graph = torch.fx.Graph()
+    node_remapping: dict[torch.fx.Node, torch.fx.Node] = {}
+    new_input_anchor: torch.fx.Node  # inner_mm, to be changed to an input node
+    new_output_anchor: torch.fx.Node  # f_node, to be used to construct an output node
+    new_input_node: torch.fx.Node
+    new_output_node: torch.fx.Node
+    for node in graph.nodes:  # preserve the order of nodes
+        if node in subgraph_node_set:
+            subgraph_node_list.append(node)
+            new_node = new_graph.node_copy(node, lambda x: node_remapping.get(x, x))
+            node_remapping[node] = new_node
+            if node is inner_mm:
+                new_input_anchor = new_node
+            if node is f_node:
+                new_output_anchor = new_node
+    # pyrefly: ignore [unbound-name]
+    if new_input_anchor is not new_output_anchor:  # subgraph is non-trivial
+        # update the input node
+        # pyrefly: ignore [unbound-name]
+        with new_graph.inserting_before(new_input_anchor):
+            new_input_node = new_graph.placeholder(name="subgraph_input")
+            # pyrefly: ignore [unbound-name]
+            new_input_node.meta.update(new_input_anchor.meta)
+            # pyrefly: ignore [unbound-name]
+            new_input_anchor.replace_all_uses_with(new_input_node)
+        # pyrefly: ignore [unbound-name]
+        new_graph.erase_node(new_input_anchor)
+        # add the output node
+        # pyrefly: ignore [unbound-name]
+        new_output_node = new_graph.output(new_output_anchor)
+        # pyrefly: ignore [unbound-name]
+        new_output_node.meta.update(new_output_anchor.meta)
+    else:  # subgraph is trivial, e.g. (A @ (B @ C))
+        # update the input node
+        # pyrefly: ignore [unbound-name]
+        with new_graph.inserting_before(new_input_anchor):
+            new_input_node = new_graph.placeholder(name="subgraph_input")
+            # pyrefly: ignore [unbound-name]
+            new_input_node.meta.update(new_input_anchor.meta)
+            # pyrefly: ignore [unbound-name]
+            new_input_anchor.replace_all_uses_with(new_input_node)
+        # pyrefly: ignore [unbound-name]
+        new_graph.erase_node(new_input_anchor)
+        # update the output node (don't use new_output_anchor since it has been erased)
+        new_output_node = new_graph.output(new_input_node)
+        new_output_node.meta.update(new_input_node.meta)
+    new_graph.lint()
+
+    # construct the subgraph
+    subgraph = Subgraph(
+        name="subgraph", graph_module=torch.fx.GraphModule(module, new_graph)
+    )
+
+    # two cases
+    # (1) (subgraph(A @ B) @ C), called "left_assoc"
+    # (2) (A @ subgraph(B @ C)), called "right_assoc"
+    is_left_assoc = outer_mm.args[0] is f_node
+
+    # find the nodes A, B, C and check the sizes
+    A: torch.fx.Node
+    B: torch.fx.Node
+    C: torch.fx.Node
+    if is_left_assoc:
+        A = inner_mm.args[0]  # type: ignore[assignment]
+        B = inner_mm.args[1]  # type: ignore[assignment]
+        C = outer_mm.args[1]  # type: ignore[assignment]
+    else:
+        A = outer_mm.args[0]  # type: ignore[assignment]
+        B = inner_mm.args[0]  # type: ignore[assignment]
+        C = inner_mm.args[1]  # type: ignore[assignment]
+    if not is_b2b_gemm_good_on(is_left_assoc, A, B, C):
+        return
+
+    # finally update the original graph
+    counters["inductor"]["b2b_gemm"] += 1
+    graph = match.graph
+    with graph.inserting_before(outer_mm):
+        function = functools.partial(tuned_b2b_gemm, is_left_assoc, subgraph)
+        function.__name__ = tuned_b2b_gemm.__name__  # type: ignore[attr-defined]
+        function._inductor_lowering_function = True  # type: ignore[attr-defined]
+        replacement: torch.fx.Node = graph.call_function(
+            function,
+            (A, B, C),
+            match.kwargs,
+        )
+        replacement.meta.update(outer_mm.meta)
+        outer_mm.replace_all_uses_with(replacement)
+    # erase unnecessary nodes
+    graph.erase_node(outer_mm)
+    for node in reversed(subgraph_node_list):
+        graph.erase_node(node)
+    graph.lint()
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/binary_folding.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/binary_folding.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f9bce1a8a2d599da6e8fa1f9b5a9442d6cbb954
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/binary_folding.py
@@ -0,0 +1,503 @@
+# mypy: allow-untyped-defs
+import functools
+import itertools
+
+import torch
+
+from ..._dynamo.utils import counters
+from .. import config
+from ..pattern_matcher import Arg, CallFunction, KeywordArg
+from .freezing_patterns import register_binary_folding_pattern
+
+
+aten = torch.ops.aten
+prims = torch.ops.prims
+
+
+def mark_mixed_dtype(computation_node):
+    computation_node_dtype = computation_node.meta["val"].dtype
+    if computation_node_dtype not in (torch.float16, torch.bfloat16):
+        return
+
+    if len(computation_node.users) != 1:
+        return
+
+    computation_node_user = next(iter(computation_node.users.keys()))
+    if not isinstance(computation_node_user.meta["val"], torch.Tensor):
+        return
+
+    if computation_node_user.meta["val"].dtype != torch.float32:
+        return
+
+    while computation_node_user.target in _binary_ops:
+        if len(computation_node_user.users) != 1:
+            return
+
+        computation_node_user = next(iter(computation_node_user.users.keys()))
+
+    if computation_node_user.target != prims.convert_element_type.default:
+        return
+
+    computation_node.meta["_allow_mixed_dtype_folding"] = computation_node_dtype
+
+
+def mark_mixed_dtype_allowed_computation_ops(gm):
+    """
+    Mark convolutions/linear which we will binary fold even with mixed precision constants. We constant fold in the higher precision
+    for better accuracy and then recover the original precision after.
+    """
+    for target in [aten.convolution.default, aten.addmm.default, aten.mm.default]:
+        for node in gm.graph.find_nodes(op="call_function", target=target):
+            mark_mixed_dtype(node)
+
+
+def recover_original_precision_folded_computation_ops(gm):
+    """
+    After binary folding conv/linear weights and biases to a higher dtype, recover the original precision they were in.
+    """
+    graph = gm.graph
+    for target, idx in (
+        (aten.convolution.default, (1, 2)),
+        (aten.addmm.default, (0, 2)),
+        (aten.mm.default, (1,)),
+    ):
+        for node in graph.find_nodes(op="call_function", target=target):
+            orig_dtype = node.meta.get("_allow_mixed_dtype_folding", None)
+            if orig_dtype is None:
+                continue
+
+            with graph.inserting_before(node):
+                for i in idx:
+                    old_input = node.args[i]
+                    if old_input is None:
+                        continue
+
+                    new_input = graph.create_node(
+                        "call_function",
+                        prims.convert_element_type.default,
+                        (old_input, orig_dtype),
+                    )
+                    node.replace_input_with(old_input, new_input)
+
+
+_binary_ops = [aten.add.Tensor, aten.sub.Tensor, aten.mul.Tensor, aten.div.Tensor]
+
+
+@functools.cache
+def binary_folding_init():
+    _conv_args = [Arg() for _ in range(9)]
+    _addmm_args = [Arg() for _ in range(3)]
+    _mm_args = [Arg() for _ in range(2)]
+    _computation_ops = [aten.convolution.default, aten.addmm.default, aten.mm.default]
+    _computation_calls = [
+        CallFunction(aten.convolution.default, *_conv_args, _users=1),
+        CallFunction(aten.addmm.default, *_addmm_args, _users=1),
+        CallFunction(
+            aten.reshape.default,
+            CallFunction(aten.addmm.default, *_addmm_args, _users=1),
+            Arg(),
+            _users=1,
+        ),
+        CallFunction(aten.mm.default, *_mm_args, _users=1),
+        CallFunction(
+            aten.reshape.default,
+            CallFunction(aten.mm.default, *_mm_args, _users=1),
+            Arg(),
+            _users=1,
+        ),
+    ]
+
+    """
+    In order to fuse add/sub/mul/div with conv/linear, the dimensions of its
+    constant tensor must satisfy the following:
+    - with resizing, broadcast to w/ weight/bias tensor shape
+    - broadcast to the conv/linear output shape
+    It needs to have a shape that can resize to weight/bias
+    tensor shape because we need to run the op with the conv/linear
+    weights/bias without changing their sizes.
+    It needs to broadcast to the conv/linear output shape so that we do
+    accidentally change the shape of op output by pre-fusing it
+    compared to eager.
+    The only dimension value shared by weight, bias, and conv/linear output
+    is they all contain a dim with value = channels-out. In the
+    conv/linear output tensor, this is in the second dimension,
+    so the pointwise op tensor may have a second dimension of
+    value == channels-out, but all the other dimensions have to be 1
+    """
+
+    def _op_not_broadcasting_with_conv(weight_tensor, other_tensor):
+        # According to opDoesNotBroadCastWithConv of frozen_conv_folding.cpp
+        weight_shape = weight_tensor.shape
+        other_shape = other_tensor.shape
+        if len(weight_shape) < len(other_shape):
+            return False
+        if len(weight_shape) == len(other_shape) + 1:
+            # weight shape is [o, i, *], other_shape is [o, 1...].
+            for i in reversed(range(len(other_shape))):
+                if i == 0 and weight_shape[0] == other_shape[i]:
+                    continue
+                if other_shape[i] != 1:
+                    return False
+        else:
+            # weight shape is [o, i, *], other_shape is [1, i, *]
+            for i in reversed(range(len(other_shape))):
+                if i == 1 and weight_shape[0] == other_shape[i]:
+                    continue
+                if other_shape[i] != 1:
+                    return False
+        return True
+
+    def _op_not_broadcasting_with_linear(weight_tensor, other_tensor, has_reshape):
+        weight_shape = weight_tensor.shape
+        other_shape = other_tensor.shape
+        other_shapes = [
+            torch.Size(
+                [
+                    weight_shape[1],
+                ]
+            ),
+            torch.Size([1, weight_shape[1]]),
+            torch.Size(
+                [
+                    1,
+                ]
+            ),
+            torch.Size([1, 1]),
+        ]
+        if has_reshape:
+            other_shapes.extend(
+                [
+                    torch.Size([1, 1, weight_shape[1]]),
+                    torch.Size([1, 1, 1]),
+                ]
+            )
+        return other_shape in other_shapes
+
+    def _check_conv_and_broadcast_op(conv_node, other):
+        # According to checkConvAndBroadcastingOpPreConditions of frozen_conv_folding.cpp.
+        # conv.weight
+        if conv_node.args[1].op != "get_attr":
+            return False
+        # conv.bias
+        if conv_node.args[1] is not None and conv_node.args[1].op != "get_attr":
+            return False
+        if (
+            not isinstance(other, int)
+            and not isinstance(other, float)
+            and other.op != "get_attr"
+        ):
+            return False
+
+        if len(conv_node.args[1].users) != 1:
+            return False
+
+        weight_meta_value = conv_node.args[1].meta.get("val")
+        if weight_meta_value is None:
+            return False
+        # Avoid fusing op that causes type promotion
+        # restricting to float avoids int/float difficulties with scalar overload
+        if not weight_meta_value.is_floating_point():
+            return False
+        if isinstance(other, torch.fx.Node) and other.op == "get_attr":
+            other_meta_value = other.meta.get("val")
+            if not other_meta_value.is_floating_point():  # type: ignore[union-attr]
+                return False
+            if (
+                torch.promote_types(other_meta_value.dtype, weight_meta_value.dtype)  # type: ignore[union-attr]
+                != weight_meta_value.dtype
+            ):
+                if not conv_node.meta.get("_allow_mixed_dtype_folding", False):
+                    return False
+
+                if (
+                    other_meta_value.dtype != torch.float  # type: ignore[union-attr]
+                    and weight_meta_value.dtype not in (torch.float16, torch.bfloat16)
+                ):
+                    return False
+
+            if not _op_not_broadcasting_with_conv(weight_meta_value, other_meta_value):
+                return False
+        elif not isinstance(other, float):
+            return False
+
+        return True
+
+    def _check_linear_and_broadcast_op(linear_node, other, has_reshape):
+        weight_node = (
+            linear_node.args[2]
+            if linear_node.target is aten.addmm.default
+            else linear_node.args[1]
+        )
+        bias_node = (
+            linear_node.args[0] if linear_node.target is aten.addmm.default else None
+        )
+        if weight_node.op != "get_attr":
+            return False
+        if bias_node is not None and bias_node.op != "get_attr":
+            return False
+        if (
+            not isinstance(other, int)
+            and not isinstance(other, float)
+            and other.op != "get_attr"
+        ):
+            return False
+
+        if len(weight_node.users) != 1:
+            return False
+
+        weight_meta_value = weight_node.meta.get("val")
+        if weight_meta_value is None:
+            return False
+        # Avoid fusing op that causes type promotion
+        # restricting to float avoids int/float difficulties with scalar overload
+        if not weight_meta_value.is_floating_point():
+            return False
+        if isinstance(other, torch.fx.Node) and other.op == "get_attr":
+            other_meta_value = other.meta.get("val")
+            if not other_meta_value.is_floating_point():  # type: ignore[union-attr]
+                return False
+            if (
+                torch.promote_types(other_meta_value.dtype, weight_meta_value.dtype)  # type: ignore[union-attr]
+                != weight_meta_value.dtype
+            ):
+                if not linear_node.meta.get("_allow_mixed_dtype_folding", False):
+                    return False
+
+                if (
+                    other_meta_value.dtype != torch.float  # type: ignore[union-attr]
+                    and weight_meta_value.dtype not in (torch.float16, torch.bfloat16)
+                ):
+                    return False
+
+            if not _op_not_broadcasting_with_linear(
+                weight_meta_value, other_meta_value, has_reshape
+            ):
+                return False
+        elif not isinstance(other, float):
+            return False
+
+        return True
+
+    def _is_foldable_pattern(match):
+        binary_node = match.output_node()
+        has_reshape = False
+        if binary_node.args[0].target in _computation_ops:
+            computation_node = binary_node.args[0]
+            other = binary_node.args[1]
+        elif binary_node.args[0].target is aten.reshape.default:
+            computation_node = binary_node.args[0].args[0]
+            other = binary_node.args[1]
+            has_reshape = True
+        elif binary_node.args[1].target in _computation_ops:
+            computation_node = binary_node.args[1]
+            other = binary_node.args[0]
+        else:
+            computation_node = binary_node.args[1].args[0]
+            other = binary_node.args[0]
+            has_reshape = False
+        if computation_node.target is aten.convolution.default:
+            return _check_conv_and_broadcast_op(computation_node, other)
+        elif computation_node.target in [aten.addmm.default, aten.mm.default]:
+            return (
+                config.enable_linear_binary_folding
+                and _check_linear_and_broadcast_op(computation_node, other, has_reshape)
+            )
+
+        return False
+
+    def resize_scalar_or_tensor_to_shape(graph, other, shape, weight):
+        if isinstance(other, float):
+            with torch.utils._python_dispatch._disable_current_modes():
+                other_tensor = torch.tensor(
+                    other, dtype=weight.dtype, device=weight.device
+                )
+            graph.owning_module.register_buffer("other_tensor", other_tensor)
+            res = graph.create_node("get_attr", "other_tensor")
+            res = graph.create_node(
+                "call_function",
+                aten.reshape.default,
+                (res, (1,)),
+            )
+            res = graph.create_node(
+                "call_function",
+                aten.expand.default,
+                (res, shape),
+            )
+        elif other.meta.get("val").numel() == 1:
+            # expand errors if the shape input has less # dims than the tensor input
+            res = graph.create_node(
+                "call_function",
+                aten.reshape.default,
+                (other, (1,)),
+            )
+            res = graph.create_node(
+                "call_function",
+                aten.expand.default,
+                (res, shape),
+            )
+        else:
+            res = graph.create_node(
+                "call_function",
+                aten.reshape.default,
+                (other, shape),
+            )
+        return res
+
+    def _create_new_conv_node(graph, conv_node, binary_node, other):
+        assert conv_node.target is aten.convolution.default
+        conv_args = list(conv_node.args)
+        weight_meta_value = conv_node.args[1].meta.get("val")
+        bias = conv_args[2]
+        if binary_node.target in [aten.add.Tensor, aten.sub.Tensor]:
+            other_reshape = resize_scalar_or_tensor_to_shape(
+                graph,
+                other,
+                (weight_meta_value.size(0),),
+                weight_meta_value,
+            )
+            new_bias = graph.create_node(
+                "call_function",
+                binary_node.target,
+                (0 if bias is None else bias, other_reshape),
+            )
+            conv_args[2] = new_bias
+        else:
+            assert binary_node.target in [aten.mul.Tensor, aten.div.Tensor]
+            weight_broadcast_shape = [1 for _ in range(len(weight_meta_value.shape))]
+            weight_broadcast_shape[0] = weight_meta_value.size(0)
+            other_reshape1 = resize_scalar_or_tensor_to_shape(
+                graph,
+                other,
+                tuple(weight_broadcast_shape),
+                weight_meta_value,
+            )
+            new_weight = graph.create_node(
+                "call_function", binary_node.target, (conv_args[1], other_reshape1)
+            )
+            new_weight.meta.update(conv_args[1].meta)
+            conv_args[1] = new_weight
+            if bias is not None:
+                other_reshape = resize_scalar_or_tensor_to_shape(
+                    graph,
+                    other,
+                    (weight_meta_value.size(0),),
+                    weight_meta_value,
+                )
+                new_bias = graph.create_node(
+                    "call_function", binary_node.target, (bias, other_reshape)
+                )
+                new_bias.meta.update(bias.meta)
+                conv_args[2] = new_bias
+        return graph.create_node("call_function", conv_node.target, tuple(conv_args))
+
+    def _create_new_linear_node(graph, linear_node, binary_node, other):
+        assert linear_node.target in [aten.addmm.default, aten.mm.default]
+        input_node = (
+            linear_node.args[1]
+            if linear_node.target is aten.addmm.default
+            else linear_node.args[0]
+        )
+        weight_node = (
+            linear_node.args[2]
+            if linear_node.target is aten.addmm.default
+            else linear_node.args[1]
+        )
+        bias_node = (
+            linear_node.args[0] if linear_node.target is aten.addmm.default else None
+        )
+        weight_meta_value = weight_node.meta.get("val")
+        if binary_node.target in [aten.add.Tensor, aten.sub.Tensor]:
+            other_reshape = resize_scalar_or_tensor_to_shape(
+                graph,
+                other,
+                (weight_meta_value.size(1),),
+                weight_meta_value,
+            )
+            new_bias_node = graph.create_node(
+                "call_function",
+                binary_node.target,
+                (0 if bias_node is None else bias_node, other_reshape),
+            )
+            return graph.create_node(
+                "call_function",
+                aten.addmm.default,
+                (new_bias_node, input_node, weight_node),
+            )
+        else:
+            assert binary_node.target in [aten.mul.Tensor, aten.div.Tensor]
+            weight_broadcast_shape = [1, weight_meta_value.size(1)]
+            other_reshape1 = resize_scalar_or_tensor_to_shape(
+                graph,
+                other,
+                tuple(weight_broadcast_shape),
+                weight_meta_value,
+            )
+            new_weight_node = graph.create_node(
+                "call_function", binary_node.target, (weight_node, other_reshape1)
+            )
+            new_weight_node.meta.update(weight_node.meta)
+            if bias_node is not None:
+                other_reshape = resize_scalar_or_tensor_to_shape(
+                    graph,
+                    other,
+                    (weight_meta_value.size(1),),
+                    weight_meta_value,
+                )
+                new_bias_node = graph.create_node(
+                    "call_function", binary_node.target, (bias_node, other_reshape)
+                )
+                new_bias_node.meta.update(bias_node.meta)
+                return graph.create_node(
+                    "call_function",
+                    linear_node.target,
+                    (new_bias_node, input_node, new_weight_node),
+                )
+            else:
+                return graph.create_node(
+                    "call_function", linear_node.target, (input_node, new_weight_node)
+                )
+
+    for _computation_call, binary_op in itertools.product(
+        _computation_calls, _binary_ops
+    ):
+
+        @register_binary_folding_pattern(
+            CallFunction(binary_op, _computation_call, KeywordArg("other")),
+            extra_check=_is_foldable_pattern,
+        )
+        def folded_op(match, *args, **kwargs):
+            counters["inductor"]["binary_folding"] += 1
+            other = kwargs.get("other")
+            binary_node = match.output_node()
+            reshape_node = None
+            if binary_node.args[0].target in _computation_ops:
+                computation_node = binary_node.args[0]
+            elif binary_node.args[0].target is aten.reshape.default:
+                computation_node = binary_node.args[0].args[0]
+                reshape_node = binary_node.args[0]
+            elif binary_node.args[1].target in _computation_ops:
+                computation_node = binary_node.args[1]
+            else:
+                computation_node = binary_node.args[1].args[0]
+                reshape_node = binary_node.args[1]
+            graph = match.graph
+            with graph.inserting_before(reshape_node if reshape_node else binary_node):
+                assert computation_node.target in _computation_ops
+                if computation_node.target is aten.convolution.default:
+                    counters["inductor"]["binary_folding_conv"] += 1
+                    new_computation_node = _create_new_conv_node(
+                        graph, computation_node, binary_node, other
+                    )
+                else:
+                    new_computation_node = _create_new_linear_node(
+                        graph, computation_node, binary_node, other
+                    )
+                new_computation_node.meta.update(computation_node.meta)
+                if reshape_node:
+                    assert reshape_node.target is aten.reshape.default
+                    computation_node.replace_all_uses_with(new_computation_node)
+                    binary_node.replace_all_uses_with(reshape_node)
+                else:
+                    binary_node.replace_all_uses_with(new_computation_node)
+                graph.erase_node(binary_node)
+                graph.erase_node(computation_node)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/bucketing.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/bucketing.py
new file mode 100644
index 0000000000000000000000000000000000000000..e72cdccddb44010f316cad92d8e10e1d13af6400
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/bucketing.py
@@ -0,0 +1,1100 @@
+import collections
+import logging
+import operator
+from collections import defaultdict
+from collections.abc import Callable
+from typing import Any, Literal, TypeAlias
+
+import torch
+import torch.distributed as dist
+import torch.utils._pytree as pytree
+from torch._dispatch.python import enable_python_dispatcher
+from torch._dynamo.utils import detect_fake_mode
+from torch._inductor.comm_analysis import (
+    get_collective_type_from_kernel_name,
+    NCCL_COLL,
+)
+from torch._inductor.runtime.runtime_utils import dynamo_timed
+from torch._logging import trace_structured
+from torch.fx.experimental.proxy_tensor import make_fx
+from torch.fx.traceback import NodeSource, NodeSourceAction
+from torch.utils._ordered_set import OrderedSet
+
+
+logger: logging.Logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+overlap_log = torch._logging.getArtifactLogger(__name__, "overlap")
+
+BucketMode: TypeAlias = Literal["default", "custom_ops", "custom_ops_multidtype"]
+
+
+# Helper functions moved to top for better organization
+def _ag_group_key(node: torch.fx.Node) -> tuple[str, torch.dtype]:  # type: ignore[name-defined]
+    _, group_size, group_name = node.args
+    dtype = node.meta["val"].dtype
+    assert isinstance(group_name, str)
+    return (group_name, dtype)
+
+
+def _ag_group_key_multidtype(node: torch.fx.Node) -> tuple[str]:
+    _, group_size, group_name = node.args
+    assert isinstance(group_name, str)
+    return (group_name,)
+
+
+def _rs_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:  # type: ignore[name-defined]
+    _, reduce_op, group_size, group_name = node.args
+    dtype = node.meta["val"].dtype
+    assert isinstance(group_name, str)
+    assert isinstance(reduce_op, str)
+    return (group_name, reduce_op, dtype)
+
+
+def _ar_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:
+    _, reduce_op, group_name = node.args
+    dtype = node.meta["val"].dtype
+    assert isinstance(group_name, str)
+    assert isinstance(reduce_op, str)
+    return (group_name, reduce_op, dtype)
+
+
+def _schedulable_wait_node(node: torch.fx.Node) -> bool:
+    """
+    Add additional check on if the wait node is schedulable
+    We should not schedule a fx node that is:
+        1. wait on a collective that is not callable
+        2. wait on a non-NCCL communication node
+    """
+    if not is_wait_tensor(node):
+        return False
+    assert isinstance(node.args[0], torch.fx.Node)
+    if not isinstance(node.args[0].target, Callable):
+        return False
+    is_callable: bool = node.args[0].op == "call_function"
+    coll: NCCL_COLL = get_collective_type_from_kernel_name(node.args[0].target.name())
+    is_collective: bool = coll != NCCL_COLL.UNSUPPORTED
+    return is_callable and is_collective
+
+
+def _populate_node_meta(
+    bucket_nodes: list[torch.fx.Node], new_nodes: list[torch.fx.Node]
+):
+    if bucket_nodes:
+        for n in new_nodes:
+            # For the following keys, we only store the information of the first node so
+            # gm.print_readable shows some information
+            # Full information are stored in "bucketing_{key}_sources"
+            for key, default in [
+                ("nn_module_stack", ""),
+                ("fwd_nn_module_stack", ""),
+                ("stack_trace", ""),
+                ("custom", {}),
+            ]:
+                n.meta[key] = bucket_nodes[0].meta.get(key, default)
+
+                # Collect sources from all bucket nodes for this metadata key, for debugging purposes only
+                bucketing_sources_key = f"bucketing_{key}_sources"
+                # Use set to remove duplicates
+                if key == "stack_trace":
+                    sources = OrderedSet(
+                        [
+                            node.meta.get(key, default)
+                            for node in bucket_nodes
+                            if node.meta.get(key, default)
+                        ]
+                    )
+                else:
+                    # type might not be hashable
+                    sources = [
+                        node.meta.get(key, default)
+                        for node in bucket_nodes
+                        if node.meta.get(key, default)
+                    ]
+                n.meta[bucketing_sources_key] = sources
+
+            # used by inductor provenance tracking
+            n.meta["from_node"] = [
+                NodeSource(
+                    original_node,
+                    "bucketing_pass",
+                    [NodeSourceAction.CREATE, NodeSourceAction.REPLACE],
+                )
+                for original_node in bucket_nodes
+            ]
+
+
+def bucket_key(node: torch.fx.Node, mode: BucketMode | None = None) -> object | None:
+    if is_all_gather_into_tensor(node):
+        group_key_fn = (
+            _ag_group_key_multidtype if mode and "multidtype" in mode else _ag_group_key
+        )
+        return group_key_fn(node)
+    elif is_reduce_scatter_tensor(node):
+        return _rs_group_key(node)
+    elif is_all_reduce_tensor(node):
+        return _ar_group_key(node)
+    else:
+        return None
+
+
+def pick_bucket_dtype(dtypes: list[torch.dtype]) -> torch.dtype:  # type: ignore[name-defined]
+    assert len(dtypes) > 0
+    return min(dtypes, key=operator.attrgetter("itemsize"))
+
+
+def bucket_cap_mb_by_bucket_idx_default(bucket_id: int) -> float:
+    """
+    Determine the size of a bucket based on its ID.
+
+    Args:
+    bucket_id (int): The ID of the bucket.
+
+    Returns:
+    float: The size of the bucket.
+    """
+    return 2000.0
+
+
+def bucket_all_gather(
+    gm: torch.fx.GraphModule,
+    bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
+    mode: BucketMode = "default",
+) -> None:
+    if bucket_cap_mb_by_bucket_idx is None:
+        from torch._inductor.fx_passes.bucketing import (  # pyrefly: ignore  # missing-module-attribute
+            bucket_cap_mb_by_bucket_idx_default,
+        )
+
+        bucket_cap_mb_by_bucket_idx = bucket_cap_mb_by_bucket_idx_default
+    ag_buckets = bucket_all_gather_by_mb(gm, bucket_cap_mb_by_bucket_idx, None, mode)
+    if len(ag_buckets) == 0:
+        return
+    merge_all_gather(gm, ag_buckets, mode)
+
+
+def bucket_reduce_scatter(
+    gm: torch.fx.GraphModule,
+    bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
+    mode: BucketMode = "default",
+) -> None:
+    if bucket_cap_mb_by_bucket_idx is None:
+        from torch._inductor.fx_passes.bucketing import (  # pyrefly: ignore  # missing-module-attribute
+            bucket_cap_mb_by_bucket_idx_default,
+        )
+
+        bucket_cap_mb_by_bucket_idx = bucket_cap_mb_by_bucket_idx_default
+    rs_buckets = bucket_reduce_scatter_by_mb(
+        gm, bucket_cap_mb_by_bucket_idx, None, mode
+    )
+    if len(rs_buckets) == 0:
+        return
+    merge_reduce_scatter(gm, rs_buckets, mode)
+
+
+def is_all_gather_into_tensor(node: torch.fx.Node) -> bool:  # type: ignore[arg-type]
+    return node.op == "call_function" and (
+        node.target == torch.ops._c10d_functional.all_gather_into_tensor.default
+        or node.target == torch.ops._c10d_functional.all_gather_into_tensor_out.default
+    )
+
+
+def is_reduce_scatter_tensor(node: torch.fx.Node) -> bool:
+    return (
+        node.op == "call_function"
+        and node.target is torch.ops._c10d_functional.reduce_scatter_tensor.default
+    )
+
+
+def is_wait_tensor(node: torch.fx.Node) -> bool:
+    return (
+        node.op == "call_function"
+        and node.target is torch.ops._c10d_functional.wait_tensor.default
+    )
+
+
+def is_all_reduce_tensor(node: torch.fx.Node) -> bool:
+    return (
+        node.op == "call_function"
+        and node.target is torch.ops._c10d_functional.all_reduce.default
+    )
+
+
+def is_all_to_all_tensor(node: torch.fx.Node) -> bool:
+    return (
+        node.op == "call_function"
+        and node.target is torch.ops._c10d_functional.all_to_all_single.default
+    )
+
+
+def is_wait_tensor_from_all_gather_into_tensor(node: torch.fx.Node) -> bool:
+    return is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0])  # type: ignore[arg-type]
+
+
+def collect_node_descendants(
+    graph: torch.fx.Graph,
+) -> dict[torch.fx.Node, OrderedSet[torch.fx.Node]]:
+    """
+    Collects the descendants of each node in the graph.
+    Args:
+        graph (torch.fx.Graph): The graph to collect descendants from.
+    Returns:
+        dict[torch.fx.Node, OrderedSet[torch.fx.Node]]: A dictionary mapping each node to its descendants.
+    """
+    node_descendants: dict[torch.fx.Node, OrderedSet[torch.fx.Node]] = (
+        collections.defaultdict(OrderedSet)
+    )
+    outdegree = collections.defaultdict(int)
+    queue = []
+
+    for node in graph.nodes:
+        n_outdegree = len(node.users)
+        if n_outdegree == 0:
+            queue.append(node)
+        else:
+            outdegree[node] = len(node.users)
+
+    while queue:
+        node = queue.pop()
+        for input_node in node.all_input_nodes:
+            node_descendants[input_node] |= node_descendants[node]
+            node_descendants[input_node].add(node)
+            outdegree[input_node] -= 1
+
+            if outdegree[input_node] == 0:
+                queue.append(input_node)
+
+    return node_descendants
+
+
+def greedy_bucket_collective_by_mb(
+    gm: torch.fx.GraphModule,
+    bucket_cap_mb_by_bucket_idx: Callable[[int], float],
+    filter_node: Callable[[torch.fx.Node], bool],
+    node_group_key: Callable[[torch.fx.Node], Any],
+    filter_wait_node: Callable[[torch.fx.Node], bool] | None = None,
+) -> list[list[torch.fx.Node]]:
+    """
+    Bucketing adjacent collectives with equal node_group_key.
+    We can not bucket non adjacent collectives,
+    as this will effectively change the order of collectives.
+    Reordering can lead to different order on different ranks.
+    """
+    g = gm.graph
+    found_candidates = False
+    for node in g.nodes:
+        if filter_node(node):
+            found_candidates = True
+            break
+    if not found_candidates:
+        return []
+
+    # TODO: pearce kelly algorithm for detecting cycles
+    node_descendents = collect_node_descendants(gm.graph)
+
+    nodes_groups: list[list[torch.fx.Node]] = []
+    cur_group: list[torch.fx.Node] = []
+    cur_group_key = None
+
+    for node in g.nodes:
+        if is_wait_tensor(node) and filter_node(node.args[0]):
+            if (filter_wait_node is None) or filter_wait_node(node):
+                coll_node = node.args[0]
+                group_key = node_group_key(coll_node)
+                if group_key == cur_group_key:
+                    cur_group.append(coll_node)
+                else:
+                    if len(cur_group) > 1:
+                        nodes_groups.append(cur_group)
+                    cur_group = [coll_node]
+                    cur_group_key = group_key
+
+    if len(cur_group) > 1:
+        nodes_groups.append(cur_group)
+
+    buckets: list[list[torch.fx.Node]] = []
+    for nodes in nodes_groups:
+        cur_bucket: list[torch.fx.Node] = []
+        cur_bucket_descendents: OrderedSet[torch.fx.Node] = OrderedSet()
+        cur_bucket_size_bytes: int = 0
+        cur_bucket_id: int = 0
+        bucket_size_bytes = int(
+            bucket_cap_mb_by_bucket_idx(cur_bucket_id) * 1024 * 1024
+        )
+        for node in nodes:
+            if node in cur_bucket_descendents:
+                # if there is a path from node to the current bucket, we cannot horizontally fuse (bucket)
+                continue
+            assert "val" in node.meta
+            n_val = node.meta["val"]
+            out_size_bytes = n_val.numel() * n_val.element_size()
+            n_input_val = node.all_input_nodes[0].meta["val"]
+            in_size_bytes = n_input_val.numel() * n_input_val.element_size()
+            size_bytes = max(out_size_bytes, in_size_bytes)
+            if cur_bucket_size_bytes + size_bytes > bucket_size_bytes and cur_bucket:
+                # Current bucket is full, create new bucket
+                if len(cur_bucket) > 1:
+                    buckets.append(cur_bucket)
+                cur_bucket = []
+                cur_bucket_size_bytes = 0
+                cur_bucket_id += 1
+                cur_bucket_descendents = OrderedSet()
+            cur_bucket_size_bytes += size_bytes
+            cur_bucket.append(node)
+            cur_bucket_descendents |= node_descendents[node]
+        if len(cur_bucket) > 1:
+            buckets.append(cur_bucket)
+    return buckets
+
+
+def bucket_all_gather_by_mb(
+    gm: torch.fx.GraphModule,
+    bucket_cap_mb_by_bucket_idx: Callable[[int], float],
+    filter_wait_node: Callable[[torch.fx.Node], bool] | None = None,
+    mode: BucketMode = "default",
+) -> list[list[torch.fx.Node]]:
+    """
+    Identifies all all_gather nodes and groups them into buckets,
+    based on size limit `bucket_cap_mb_by_bucket_idx`.
+
+    Args:
+        gm (torch.fx.GraphModule): GraphModule where to bucket all_gathers.
+        bucket_cap_mb_by_bucket_idx (Callable[[int], float]): Callable to specify cap of the bucket
+            in megabytes by bucket idx.  The idea of `bucket_cap_mb_by_bucket_idx` is to allow
+            to specify different sizes of the buckets at the start,
+            as first all_gather is usually exposed.  Interface of bucket_cap_mb_by_bucket_idx
+            is `bucket_cap_mb_by_bucket_idx_default` function that is default value for `bucket_cap_mb_by_bucket_idx`.
+        filter_wait_node (Callable[[torch.fx.Node], bool] | None): If specified,
+            only all_gather nodes with wait_node that satisfy `filter_wait_node` will be bucketed.
+
+    Returns:
+        list[list[torch.fx.Node]]: List of buckets, where each bucket is a list of all_gather nodes.
+    """
+
+    group_key_fn = (
+        _ag_group_key_multidtype if mode and "multidtype" in mode else _ag_group_key
+    )
+
+    return greedy_bucket_collective_by_mb(
+        gm,
+        bucket_cap_mb_by_bucket_idx,
+        is_all_gather_into_tensor,
+        group_key_fn,
+        filter_wait_node,
+    )
+
+
+def bucket_reduce_scatter_by_mb(
+    gm: torch.fx.GraphModule,
+    bucket_cap_mb_by_bucket_idx: Callable[[int], float],
+    filter_wait_node: Callable[[torch.fx.Node], bool] | None = None,
+    mode: BucketMode = "default",
+) -> list[list[torch.fx.Node]]:
+    """
+    Identifies all reduce_scatter nodes and groups them into buckets,
+        based on size limit `bucket_cap_mb_by_bucket_idx`.
+
+    Args:
+        gm (torch.fx.GraphModule): GraphModule where to bucket reduce_scatters.
+        bucket_cap_mb_by_bucket_idx (Callable[[int], float]): Callable to specify cap of the bucket
+            in megabytes by bucket idx.  The idea of `bucket_cap_mb_by_bucket_idx` is to allow
+            to specify different sizes of the buckets.
+        filter_wait_node (Callable[[torch.fx.Node], bool] | None): If specified,
+            only reduce_scatter nodes with wait_node that satisfy `filter_wait_node` will be bucketed.
+
+    Returns:
+        list[list[torch.fx.Node]]: List of buckets, where each bucket is a list of reduce_scatter nodes.
+    """
+
+    assert "multidtype" not in mode, (
+        "reduce scatter bucketing does not support multidtype"
+    )
+
+    return greedy_bucket_collective_by_mb(
+        gm,
+        bucket_cap_mb_by_bucket_idx,
+        is_reduce_scatter_tensor,
+        _rs_group_key,
+        filter_wait_node,
+    )
+
+
+def bucket_all_reduce_by_mb(
+    gm: torch.fx.GraphModule,
+    bucket_cap_mb_by_bucket_idx: Callable[[int], float],
+    filter_wait_node: Callable[[torch.fx.Node], bool] | None = None,
+) -> list[list[torch.fx.Node]]:
+    return greedy_bucket_collective_by_mb(
+        gm,
+        bucket_cap_mb_by_bucket_idx,
+        is_all_reduce_tensor,
+        _ar_group_key,
+        filter_wait_node,
+    )
+
+
+def bucket_all_reduce(
+    gm: torch.fx.GraphModule,
+    bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
+    mode: str | None = None,
+) -> None:
+    if bucket_cap_mb_by_bucket_idx is None:
+        from torch._inductor.fx_passes.bucketing import (  # pyrefly: ignore  # missing-module-attribute
+            bucket_cap_mb_by_bucket_idx_default,
+        )
+
+        bucket_cap_mb_by_bucket_idx = bucket_cap_mb_by_bucket_idx_default
+    ar_buckets = bucket_all_reduce_by_mb(gm, bucket_cap_mb_by_bucket_idx)
+    if len(ar_buckets) == 0:
+        return
+    for bucket in ar_buckets:
+        merge_all_reduce_bucket(gm.graph, bucket, mode)
+
+
+@torch.library.custom_op("bucketing::_pre_bucket_reduce_scatter", mutates_args={})
+def _pre_bucket_reduce_scatter(
+    rs_ins: list[torch.Tensor],
+    group_size: int,
+) -> torch.Tensor:
+    rs_ins_flattened = [x.view(group_size, -1) for x in rs_ins]
+    new_rs_in = torch.cat(rs_ins_flattened, dim=1).flatten()
+    return new_rs_in
+
+
+def _pre_bucket_reduce_scatter_fake(
+    rs_ins: list[torch.Tensor],
+    group_size: int,
+) -> torch.Tensor:
+    out_numel = sum(rs_in.numel() for rs_in in rs_ins)
+    return torch.empty((out_numel,), device=rs_ins[0].device, dtype=rs_ins[0].dtype)
+
+
+_pre_bucket_reduce_scatter.register_fake(_pre_bucket_reduce_scatter_fake)
+
+
+def reduce_scatter_merge_fn_to_trace_custom_ops(
+    rs_ins: list[torch.Tensor],
+    group_size: int,
+    group_name: str,
+    reduce_op: str,
+    reduce_dtype: torch.dtype,  # type: ignore[name-defined]
+    device: torch.device,  # type: ignore[name-defined]
+) -> list[torch.Tensor]:  # type: ignore[no-untyped-def]
+    new_out_sizes = [(x.shape[0] // group_size,) + x.shape[1:] for x in rs_ins]
+    new_out_numels = [x.numel() // group_size for x in rs_ins]
+
+    new_rs_in = torch.ops.bucketing._pre_bucket_reduce_scatter(rs_ins, group_size)
+
+    # TODO - either use torch.cat or make sure inductor foreach codegen
+    # fires more reliably
+    new_rs_out = torch.ops.c10d_functional.wait_tensor(
+        torch.ops._c10d_functional.reduce_scatter_tensor.default(
+            new_rs_in, reduce_op, group_size, group_name
+        )
+    )
+    new_out_flat = new_rs_out.split(new_out_numels, 0)
+    new_outs = [x.view(s) for x, s in zip(new_out_flat, new_out_sizes)]
+    return new_outs
+
+
+def reduce_scatter_merge_fn_to_trace(
+    rs_ins: list[torch.Tensor],
+    group_size: int,
+    group_name: str,
+    reduce_op: str,
+    reduce_dtype: torch.dtype,  # type: ignore[name-defined]
+    device: torch.device,  # type: ignore[name-defined]
+) -> list[torch.Tensor]:  # type: ignore[no-untyped-def]
+    rs_ins_flattened = [x.view(group_size, -1) for x in rs_ins]
+
+    new_out_sizes = [(x.shape[0] // group_size,) + x.shape[1:] for x in rs_ins]
+    new_out_numels = [x.numel() // group_size for x in rs_ins]
+
+    new_rs_in = torch.cat(rs_ins_flattened, dim=1).flatten()
+
+    new_rs_out = torch.ops.c10d_functional.wait_tensor(
+        torch.ops._c10d_functional.reduce_scatter_tensor.default(
+            new_rs_in, reduce_op, group_size, group_name
+        )
+    )
+    new_out_flat = new_rs_out.split(new_out_numels, 0)
+    new_outs = [x.view(s) for x, s in zip(new_out_flat, new_out_sizes)]
+    return new_outs
+
+
+def all_reduce_merge_fn_to_trace(
+    ar_ins: list[torch.Tensor],
+    group_name: str,
+    reduce_op: str,
+    reduce_dtype: torch.dtype,  # type: ignore[name-defined]
+    device: torch.device,  # type: ignore[name-defined]
+) -> list[torch.Tensor]:  # type: ignore[no-untyped-def]
+    ar_ins_flattened = [x.view(-1) for x in ar_ins]
+    new_ar_in = torch.cat(ar_ins_flattened)
+    new_ar_out = torch.ops.c10d_functional.wait_tensor(
+        torch.ops._c10d_functional.all_reduce.default(new_ar_in, reduce_op, group_name)
+    )
+    split_sizes = [x.numel() for x in ar_ins]
+    new_outs_flat = new_ar_out.split(split_sizes)
+    new_outs = [x.view(ar_in.shape) for x, ar_in in zip(new_outs_flat, ar_ins)]
+    return new_outs
+
+
+# List of all torch dtypes for serialization through custom ops
+# TODO: custom ops support list[dtype] input
+_ALL_DTYPES = tuple(
+    [
+        getattr(torch, attr)
+        for attr in dir(torch)
+        if isinstance(getattr(torch, attr), torch.dtype)
+    ]
+)
+
+
+@torch.library.custom_op("bucketing::_pre_bucket_all_gather", mutates_args={})
+def _pre_bucket_all_gather(
+    ag_ins: list[torch.Tensor],
+    group_size: int,
+    group_name: str,
+    dtype: torch.dtype,  # type: ignore[name-defined]
+    out_dtype_ints: list[
+        int
+    ],  # dtype enum values, that inputs are converted to before all_gather
+    rank: int,
+) -> torch.Tensor:
+    # Convert int indices back to torch.dtype
+    out_dtypes = [_ALL_DTYPES[d] for d in out_dtype_ints]
+    ins_split_sizes_bytes = [
+        ag_in.numel() * out_dtype.itemsize
+        for ag_in, out_dtype in zip(ag_ins, out_dtypes, strict=True)
+    ]
+    bucket_dtype_size_bytes = dtype.itemsize
+    ins_split_sizes = [
+        _bytes // bucket_dtype_size_bytes for _bytes in ins_split_sizes_bytes
+    ]
+    ag_input_numel = sum(ins_split_sizes)
+    device = ag_ins[0].device
+    new_ag_out = torch.empty(ag_input_numel * group_size, dtype=dtype, device=device)
+    new_ag_in = new_ag_out.narrow(0, ag_input_numel * rank, ag_input_numel)
+    foreach_copy_dsts = torch.split(new_ag_in, ins_split_sizes)
+    # View each destination slice as its output dtype, then copy
+    # The copy operation handles dtype conversion from input dtype to output dtype
+    foreach_copy_dsts_typed = [
+        dst.view(out_dtype)
+        for dst, out_dtype in zip(foreach_copy_dsts, out_dtypes, strict=True)
+    ]
+    ag_ins_flattened = [ag_in.reshape(-1) for ag_in in ag_ins]
+    torch._foreach_copy_(foreach_copy_dsts_typed, ag_ins_flattened)
+    return new_ag_out
+
+
+def _pre_bucket_all_gather_fake(
+    ag_ins: list[torch.Tensor],
+    group_size: int,
+    group_name: str,
+    dtype: torch.dtype,  # type: ignore[name-defined]
+    out_dtype_ints: list[int],
+    rank: int,
+) -> torch.Tensor:
+    out_dtypes = [_ALL_DTYPES[d] for d in out_dtype_ints]
+    ins_split_sizes_bytes = [
+        ag_in.numel() * out_dtype.itemsize
+        for ag_in, out_dtype in zip(ag_ins, out_dtypes, strict=True)
+    ]
+    bucket_dtype_size_bytes = dtype.itemsize
+    ins_split_sizes = [
+        _bytes // bucket_dtype_size_bytes for _bytes in ins_split_sizes_bytes
+    ]
+    ag_input_numel = sum(ins_split_sizes)
+    device = ag_ins[0].device
+    new_ag_out = torch.empty(ag_input_numel * group_size, dtype=dtype, device=device)
+    return new_ag_out
+
+
+_pre_bucket_all_gather.register_fake(_pre_bucket_all_gather_fake)
+
+
+def all_gather_merge_fn_to_trace_custom_ops(
+    _ag_ins: list[torch.Tensor],
+    group_size: int,
+    group_name: str,
+    dtype: torch.dtype,  # type: ignore[name-defined]
+    out_dtypes: list[torch.dtype],  # type: ignore[name-defined]
+    rank: int,
+) -> list[torch.Tensor]:
+    # Don't create convert_element_type ops - _pre_bucket_all_gather handles conversion
+    # by viewing destination slices as output dtypes and letting copy do the conversion
+    ag_ins = _ag_ins
+    ins_sizes = [ag_in.shape for ag_in in ag_ins]
+    ins_split_sizes_bytes = [
+        ag_in.numel() * out_dtype.itemsize
+        for ag_in, out_dtype in zip(ag_ins, out_dtypes)
+    ]
+    bucket_dtype_size_bytes = dtype.itemsize
+    ins_split_sizes = [
+        _bytes // bucket_dtype_size_bytes for _bytes in ins_split_sizes_bytes
+    ]
+    ag_input_numel = sum(ins_split_sizes)
+
+    # Convert out_dtypes to indices for custom_op
+    # TODO: custom ops support list[dtype] input
+    out_dtype_ints = [_ALL_DTYPES.index(dt) for dt in out_dtypes]
+
+    new_ag_out = torch.ops.bucketing._pre_bucket_all_gather(
+        ag_ins, group_size, group_name, dtype, out_dtype_ints, rank
+    )
+    new_ag_in = new_ag_out.narrow(0, ag_input_numel * rank, ag_input_numel)
+    wait_tensor = torch.ops.c10d_functional.wait_tensor(
+        torch.ops._c10d_functional.all_gather_into_tensor_out.default(
+            new_ag_in, group_size, group_name, out=new_ag_out
+        )
+    )
+    new_ag_out_reshaped = wait_tensor.reshape(group_size, -1)
+    outs_bucket_dtype = torch.split_with_sizes(
+        new_ag_out_reshaped,
+        ins_split_sizes,
+        dim=1,
+    )
+    outs_reshaped = [
+        o.view(out_dtype).reshape((shape[0] * group_size,) + shape[1:])
+        for o, shape, out_dtype in zip(outs_bucket_dtype, ins_sizes, out_dtypes)
+    ]
+    return outs_reshaped
+
+
+def all_gather_merge_fn_to_trace(
+    ag_ins: list[torch.Tensor],
+    group_size: int,
+    group_name: str,
+    dtype: torch.dtype,  # type: ignore[name-defined]
+    out_dtypes: list[torch.dtype],  # type: ignore[name-defined]
+    rank: int,
+) -> list[torch.Tensor]:
+    ins_sizes = [ag_in.shape for ag_in in ag_ins]
+    ins_split_sizes = [ag_in.numel() for ag_in in ag_ins]
+    ag_input_numel = sum(ins_split_sizes)
+    device = ag_ins[0].device
+    new_ag_out = torch.empty(ag_input_numel * group_size, dtype=dtype, device=device)
+    new_ag_in = new_ag_out.narrow(0, ag_input_numel * rank, ag_input_numel)
+    foreach_copy_dsts = torch.split(new_ag_in, ins_split_sizes)
+    ag_ins_flattened = [ag_in.reshape(-1) for ag_in in ag_ins]
+    torch._foreach_copy_(foreach_copy_dsts, ag_ins_flattened)
+    wait_tensor = torch.ops.c10d_functional.wait_tensor(
+        torch.ops._c10d_functional.all_gather_into_tensor_out.default(
+            new_ag_in, group_size, group_name, out=new_ag_out
+        )
+    )
+    new_ag_out_reshaped = wait_tensor.reshape(group_size, -1)
+    outs = torch.split_with_sizes(
+        new_ag_out_reshaped,
+        ins_split_sizes,
+        dim=1,
+    )
+    outs_reshaped = [
+        o.reshape((shape[0] * group_size,) + shape[1:])
+        for o, shape in zip(outs, ins_sizes)
+    ]
+    return outs_reshaped
+
+
+def all_gather_merge_fn_to_trace_functional(
+    ag_ins: list[torch.Tensor],
+    group_size: int,
+    group_name: str,
+    dtype: torch.dtype,  # type: ignore[name-defined]
+    out_dtypes: list[torch.dtype],  # type: ignore[name-defined]
+    rank: int,
+    use_fsdp_ag_copy_in: bool = False,
+) -> list[torch.Tensor]:
+    # Implementation that is functional in graph,
+    # but uses custom op torch.ops.fsdp.all_gather_copy_in.
+    ins_sizes = [ag_in.shape for ag_in in ag_ins]
+    ins_split_sizes = [ag_in.numel() for ag_in in ag_ins]
+    ag_input_numel = sum(ins_split_sizes)
+    device = ag_ins[0].device
+    new_ag_out = torch.empty(ag_input_numel * group_size, dtype=dtype, device=device)
+    ag_ins_flattened = [ag_in.reshape(-1) for ag_in in ag_ins]
+    if use_fsdp_ag_copy_in:
+        new_ag_in, new_ag_out = torch.ops.fsdp.all_gather_copy_in(
+            ag_ins_flattened, new_ag_out, ins_split_sizes, ag_input_numel, rank
+        )
+    else:
+        new_ag_in = torch.cat(ag_ins_flattened, dim=0)
+    wait_tensor = torch.ops.c10d_functional.wait_tensor(
+        torch.ops._c10d_functional.all_gather_into_tensor_out.default(
+            new_ag_in, group_size, group_name, out=new_ag_out
+        )
+    )
+    new_ag_out_reshaped = wait_tensor.reshape(group_size, -1)
+    outs = torch.split_with_sizes(
+        new_ag_out_reshaped,
+        ins_split_sizes,
+        dim=1,
+    )
+    outs_reshaped = [
+        o.reshape((shape[0] * group_size,) + shape[1:])
+        for o, shape in zip(outs, ins_sizes)
+    ]
+    return outs_reshaped
+
+
+def _trace(fn, inps) -> torch.fx.GraphModule:  # type: ignore[no-untyped-def]
+    with dynamo_timed("fx.bucketing._trace", log_pt2_compile_event=True):
+        fake_mode = detect_fake_mode(inps)
+        assert fake_mode is not None
+        with fake_mode, enable_python_dispatcher():
+            out = make_fx(fn)(*inps)
+            for node in out.graph.find_nodes(
+                op="call_function", target=torch.ops.aten.detach.default
+            ):
+                node.replace_all_uses_with(node.args[0])
+                out.graph.erase_node(node)
+            return out
+
+
+def _insert_fn_trace_before_node(  # type: ignore[no-untyped-def]
+    g: torch.fx.Graph,
+    fn_to_trace,
+    inps,
+    insert_before_node: torch.fx.Node,
+    g_fn_inps: list[torch.fx.Node],
+    g_fn_outs: list[torch.fx.Node],
+) -> tuple[dict[torch.fx.Node, torch.fx.Node], list[torch.fx.Node]]:  # type: ignore[no-untyped-def]
+    """
+    Helper function that traces :attr:`fn_to_trace` with inputs
+    :attr:`inps`.
+    The result function graph will be inserted before :attr:`insert_before_node`,
+    using :attr:`g_fn_inps` nodes of original graph as inputs of function graph,
+    function graph outputs will replace :attr:`g_fn_outs` in original graph.
+
+    Returns:
+        (replacements, new_nodes): Dictionary mapping old to new nodes, and list of all newly inserted nodes
+    """
+    with dynamo_timed(
+        "fx.bucketing._insert_fn_trace_before_node", log_pt2_compile_event=True
+    ):
+        fn_gm = _trace(
+            fn_to_trace,
+            inps,
+        )
+        fn_g = fn_gm.graph
+        fn_g_ins = fn_g.find_nodes(op="placeholder")
+        env = {fn_g_ins[idx]: g_fn_inps[idx] for idx in range(len(g_fn_inps))}
+        g_fn_new_outs: list[torch.fx.Node] = []
+        new_nodes: list[torch.fx.Node] = []  # Track all newly inserted nodes
+
+        with g.inserting_before(insert_before_node):
+            for _n in fn_g.nodes:
+                if _n.op == "placeholder":
+                    continue
+                _new_n = g.node_copy(_n, lambda x: env[x])
+                env[_n] = _new_n
+                if _n.op == "output":
+                    g_fn_new_outs = _new_n.args[0]  # type: ignore[assignment]
+                    g.erase_node(_new_n)
+                else:
+                    new_nodes.append(_new_n)  # Track non-output nodes
+
+        replacements = {  # noqa: C416
+            orig_out: new_out for orig_out, new_out in zip(g_fn_outs, g_fn_new_outs)
+        }
+        for orig_out, new_out in zip(g_fn_outs, g_fn_new_outs):
+            orig_out.replace_all_uses_with(new_out)
+
+        return replacements, new_nodes
+
+
+def has_mergeable_all_gather_convert_dtype(n: torch.fx.Node) -> bool:
+    node_in = n.args[0]
+    return (
+        is_all_gather_into_tensor(n)
+        and isinstance(node_in, torch.fx.Node)
+        and node_in.op == "call_function"
+        and (
+            node_in.target is torch.ops.prims.convert_element_type.default
+            or node_in.target is torch.ops.aten._to_copy.default
+        )
+        and len(node_in.users) == 1
+    )
+
+
+def process_collective_bucket(
+    g: torch.fx.Graph,
+    bucket_nodes: list[torch.fx.Node],
+    fn_to_trace: Callable[..., list[torch.Tensor]],
+    trace_args_fn: Callable[[list[torch.fx.Node]], tuple[Any, ...]],
+    insert_before: torch.fx.Node | None = None,
+    wait_insertion_point: torch.fx.Node | None = None,
+) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
+    """
+    Process a single bucket of collective operation nodes with flexible insertion control.
+
+    Args:
+        g: The graph to modify
+        bucket_nodes: Nodes in the current bucket to process
+        fn_to_trace: Function to trace and insert
+        trace_args_fn: Function to create trace arguments from inputs
+        insert_before: Where to insert the traced function (default: after last bucket node)
+        wait_insertion_point: If provided, move all nodes from wait() onwards to before this node
+
+    Returns:
+        new_nodes: List of all newly inserted nodes
+        replacements: Dictionary mapping old wait nodes to new output nodes
+    """
+    # Collect inputs and waits from current bucket
+    bucket_ins: list[torch.fx.Node] = []
+    bucket_waits: list[torch.fx.Node] = []
+    ag_node_to_pre_nodes: dict[torch.fx.Node, list[torch.fx.Node]] = defaultdict(list)
+
+    for n in bucket_nodes:
+        assert len(n.users) == 1, f"Expected single user for {n}, got {n.users}"
+        wait_n = next(iter(n.users))
+
+        # Handle convert_element_type operations (for all_gather)
+        node_in = n.args[0]
+        if has_mergeable_all_gather_convert_dtype(n):
+            ag_node_to_pre_nodes[n].append(node_in)
+            node_in = node_in.args[0]
+
+        assert isinstance(node_in, torch.fx.Node)  # Ensure node_in is a Node
+        bucket_ins.append(node_in)
+        bucket_waits.append(wait_n)
+
+    # Create trace arguments
+    trace_args = trace_args_fn(bucket_ins)
+
+    # Determine insertion point
+    if insert_before is None:
+        insert_before = bucket_nodes[-1].next
+
+    # Insert traced function and get replacements + new nodes
+    replacements, new_nodes = _insert_fn_trace_before_node(
+        g,
+        fn_to_trace,
+        trace_args,
+        insert_before,
+        bucket_ins,
+        bucket_waits,
+    )
+
+    # If requested, move wait nodes and everything after to specified location
+    if wait_insertion_point is not None:
+        # Find the first wait node in new_nodes
+        wait_start_idx = None
+        for i, node in enumerate(new_nodes):
+            if is_wait_tensor(node):
+                wait_start_idx = i
+                break
+
+        # Move all nodes from wait onwards (including the wait)
+        if wait_start_idx is not None:
+            nodes_to_move = new_nodes[wait_start_idx:]
+            for node in nodes_to_move:
+                wait_insertion_point.prepend(node)
+
+    # Preserve metadata from original collective nodes to new bucketed nodes
+    if bucket_nodes:
+        overlap_log.debug(
+            "Bucketing nodes: %s, New nodes: %s",
+            ",".join([n.name for n in bucket_nodes]),
+            ",".join([n.name for n in new_nodes]),
+        )
+    _populate_node_meta(bucket_nodes, new_nodes)
+
+    # Erase old nodes
+    for node, wait_n in zip(bucket_nodes, bucket_waits):
+        g.erase_node(wait_n)
+        g.erase_node(node)
+        # Erase any convert_element_type nodes we tracked
+        for pre_node in reversed(ag_node_to_pre_nodes[node]):
+            g.erase_node(pre_node)
+
+    return new_nodes, replacements
+
+
+def merge_reduce_scatter_bucket(
+    g: torch.fx.Graph,
+    rs_nodes: list[torch.fx.Node],
+    mode: BucketMode = "default",
+    insert_before: torch.fx.Node | None = None,
+    wait_insertion_point: torch.fx.Node | None = None,
+) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
+    # Validate bucket consistency
+    rs0 = rs_nodes[0]
+    rs0_val = rs0.meta["val"]
+    _, reduce_op, group_size, group_name = rs0.args
+    reduce_dtype = rs0_val.dtype
+    device = rs0_val.device
+
+    for n in rs_nodes:
+        rs_val = n.meta["val"]
+        assert (
+            n.args[1] == reduce_op
+            and n.args[2] == group_size
+            and n.args[3] == group_name
+            and rs_val.device == device
+            and rs_val.dtype == reduce_dtype
+        )
+
+    # Choose merge function based on mode
+    rs_merge_fn = reduce_scatter_merge_fn_to_trace
+    if mode and "custom_ops" in mode:
+        rs_merge_fn = reduce_scatter_merge_fn_to_trace_custom_ops
+
+    # Process bucket with lazy input collection
+    def create_trace_args(bucket_ins: list[torch.fx.Node]) -> tuple[Any, ...]:
+        return (
+            pytree.tree_map(lambda node: node.meta["val"], bucket_ins),
+            group_size,
+            group_name,
+            reduce_op,
+            reduce_dtype,
+            device,
+        )
+
+    return process_collective_bucket(
+        g,
+        rs_nodes,
+        rs_merge_fn,
+        create_trace_args,
+        insert_before=insert_before,
+        wait_insertion_point=wait_insertion_point,
+    )
+
+
+def merge_all_reduce_bucket(
+    g: torch.fx.Graph,
+    ar_nodes: list[torch.fx.Node],
+    mode: str | None = None,
+    insert_before: torch.fx.Node | None = None,
+    wait_insertion_point: torch.fx.Node | None = None,
+) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
+    ar0 = ar_nodes[0]
+    ar0_val = ar0.meta["val"]
+    _, reduce_op, group_name = ar0.args
+    reduce_dtype = ar0_val.dtype
+    device = ar0_val.device
+
+    for n in ar_nodes:
+        ar_val = n.meta["val"]
+        assert (
+            n.args[1] == reduce_op
+            and n.args[2] == group_name
+            and ar_val.device == device
+            and ar_val.dtype == reduce_dtype
+        )
+
+    ar_merge_fn = all_reduce_merge_fn_to_trace
+
+    def create_trace_args(bucket_ins: list[torch.fx.Node]) -> tuple[Any, ...]:
+        return (
+            pytree.tree_map(lambda node: node.meta["val"], bucket_ins),
+            group_name,
+            reduce_op,
+            reduce_dtype,
+            device,
+        )
+
+    return process_collective_bucket(
+        g,
+        ar_nodes,
+        ar_merge_fn,
+        create_trace_args,
+        insert_before=insert_before,
+        wait_insertion_point=wait_insertion_point,
+    )
+
+
+def merge_all_gather_bucket(
+    g: torch.fx.Graph,
+    ag_nodes: list[torch.fx.Node],
+    mode: BucketMode = "default",
+    insert_before: torch.fx.Node | None = None,
+    wait_insertion_point: torch.fx.Node | None = None,
+) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
+    from torch.distributed.distributed_c10d import _resolve_process_group
+
+    ag0 = ag_nodes[0]
+    _, group_size, group_name = ag0.args
+    assert isinstance(group_name, str)
+    _ag_dtypes: list[torch.dtype] = []  # type: ignore[name-defined]
+
+    for n in ag_nodes:
+        assert n.args[1] == group_size and n.args[2] == group_name
+        _ag_dtypes.append(n.meta["val"].dtype)
+
+    bucket_dtype = pick_bucket_dtype(_ag_dtypes)
+
+    # Choose merge function based on mode
+    ag_merge_fn = all_gather_merge_fn_to_trace
+    if mode is not None and "custom_ops" in mode:
+        ag_merge_fn = all_gather_merge_fn_to_trace_custom_ops  # type: ignore[assignment]
+
+    # Process bucket with lazy input collection
+    rank: int = dist.get_rank(_resolve_process_group(group_name))
+
+    def create_trace_args(bucket_ins: list[torch.fx.Node]) -> tuple[Any, ...]:
+        return (
+            pytree.tree_map(lambda node: node.meta["val"], bucket_ins),
+            group_size,
+            group_name,
+            bucket_dtype,
+            _ag_dtypes,
+            rank,
+        )
+
+    return process_collective_bucket(
+        g,
+        ag_nodes,
+        ag_merge_fn,
+        create_trace_args,
+        wait_insertion_point=wait_insertion_point,
+    )
+
+
+def merge_reduce_scatter(
+    gm: torch.fx.GraphModule,
+    rs_buckets: list[list[torch.fx.Node]],
+    mode: BucketMode = "default",
+) -> None:
+    """
+    Merges specified buckets of reduce_scatter to joint reduce_scatter.
+    """
+    with dynamo_timed("fx.bucketing.merge_reduce_scatter", log_pt2_compile_event=True):
+        trace_structured(
+            "artifact",
+            metadata_fn=lambda: {
+                "name": "fx_bucketing_passes_reduce_scatter_buckets",
+                "encoding": "string",
+            },
+            payload_fn=lambda: str(rs_buckets),
+        )
+
+        g = gm.graph
+
+        for rs_nodes in rs_buckets:
+            merge_reduce_scatter_bucket(g, rs_nodes, mode)
+
+
+def merge_all_gather(
+    gm: torch.fx.GraphModule,
+    ag_buckets: list[list[torch.fx.Node]],
+    mode: BucketMode = "default",
+) -> None:
+    """
+    Merges specified buckets of all_gather to joint all_gather.
+    """
+    with dynamo_timed("fx.bucketing.merge_all_gather", log_pt2_compile_event=True):
+        trace_structured(
+            "artifact",
+            metadata_fn=lambda: {
+                "name": "fx_bucketing_passes_all_gather_buckets",
+                "encoding": "string",
+            },
+            payload_fn=lambda: str(ag_buckets),
+        )
+
+        g = gm.graph
+
+        for ag_nodes in ag_buckets:
+            merge_all_gather_bucket(g, ag_nodes, mode)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/control_dependencies.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/control_dependencies.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6e3ca625c5d97bcd0e52508ed084f5bf82b2bb2
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/control_dependencies.py
@@ -0,0 +1,226 @@
+# mypy: allow-untyped-defs
+"""
+Effect ordering pass for inductor.
+
+This pass adds ordering dependencies to FX graphs using the control_deps HOP
+for precise control over scheduling constraints. When you need exact ordering between
+operations (e.g., collective_start -> mm -> wait), this pass wraps operations
+with control_deps to make dependencies explicit.
+"""
+
+from typing import Any
+
+import torch.fx as fx
+from torch._higher_order_ops.utils import register_fake
+from torch._ops import HigherOrderOperator
+from torch.utils._ordered_set import OrderedSet
+
+
+class ControlDeps(HigherOrderOperator):
+    """
+    Higher-order operator that enforces ordering by making dependencies explicit.
+
+    Schema: control_deps(additional_deps, target, *args, **kwargs) -> result
+    where:
+    - additional_deps: tuple of tensors that must be computed before this op
+    - subgraph: GraphModule containing the exact operation to execute
+    - args/kwargs: arguments for the target function
+
+    This ensures all tensors in additional_deps are computed before the target
+    executes, creating explicit scheduling dependencies.
+    """
+
+    def __init__(self) -> None:
+        super().__init__("control_deps")
+
+    def __call__(self, additional_deps, subgraph, *args, **kwargs):
+        """Call the operator with dependencies and subgraph.
+
+        Args:
+            additional_deps: Tuple of tensors that must be computed first
+            subgraph: GraphModule containing the exact operation to execute
+            *args: Arguments to pass to the subgraph
+        """
+        if not isinstance(additional_deps, (tuple, list)):
+            raise TypeError(
+                f"additional_deps must be tuple/list, got {type(additional_deps).__name__}"
+            )
+        if not (isinstance(subgraph, fx.GraphModule) or callable(subgraph)):
+            raise TypeError(
+                f"subgraph must be GraphModule or callable, got {type(subgraph).__name__}"
+            )
+        return super().__call__(additional_deps, subgraph, *args, **kwargs)
+
+
+control_deps = ControlDeps()
+
+
+# Register fake implementation for tracing
+@register_fake(control_deps)
+def _(additional_deps, subgraph, *args, **kwargs):
+    """Fake tensor implementation - execute the subgraph."""
+    return subgraph(*args, **kwargs)
+
+
+def get_subgraph_name(gm: fx.GraphModule, name):
+    name = f"subgraph_{name}"
+
+    if not hasattr(gm, name):
+        return name
+
+    i = 0
+    while hasattr(gm, f"{name}_{i}"):
+        i += 1
+
+    return f"{name}_{i}"
+
+
+def preserve_node_ordering(
+    graph: fx.Graph,
+    additional_deps_map: dict[fx.Node, OrderedSet[fx.Node]],
+    verbose: bool = False,
+) -> None:
+    """
+    Preserve node ordering using control_deps HOP with subgraph.
+
+    This function wraps operations with control_deps that:
+    1. Makes additional dependencies explicit (first argument)
+    2. Creates a subgraph internally to preserve the exact original operation
+    3. Preserves the original node names
+
+    Args:
+        graph: The FX graph to modify
+        additional_deps_map: Mapping from dependent nodes to their dependencies
+        verbose: If True, print debug information
+    """
+    if not additional_deps_map:
+        return
+
+    # Track replacements so we can update dependencies
+    replacements: dict[fx.Node, fx.Node] = {}
+
+    # Process each node that needs additional dependencies
+    for dependent_node, dep_nodes in additional_deps_map.items():
+        assert dependent_node.op == "call_function", dependent_node.op
+
+        original_name = dependent_node.name
+        original_args = dependent_node.args
+        original_kwargs = dependent_node.kwargs
+        original_meta = dependent_node.meta.copy()
+
+        updated_dep_nodes = [replacements.get(dep, dep) for dep in dep_nodes]
+
+        # Create a subgraph that preserves the exact original operation
+        subgraph_module = _create_subgraph_for_node(graph, dependent_node)
+
+        owning_mod = graph.owning_module
+        assert owning_mod is not None
+        subgraph_attr_name = get_subgraph_name(owning_mod, original_name)
+        setattr(graph.owning_module, subgraph_attr_name, subgraph_module)
+
+        # Create control_deps call with:
+        # 1. Additional dependencies as first arg (explicit)
+        # 2. Subgraph via get_attr (like b2b gemm pass)
+        # 3. Original arguments (only fx.Node args and kwargs are passed)
+        with graph.inserting_before(dependent_node):
+            # Create get_attr node for the subgraph
+            get_subgraph = graph.get_attr(subgraph_attr_name)
+
+            # add additional args
+            node_args = [a for a in original_args if isinstance(a, fx.Node)]
+            for value in original_kwargs.values():
+                if isinstance(value, fx.Node):
+                    node_args.append(value)
+
+            # Create with temporary name first
+            ordered_node = graph.call_function(
+                control_deps,
+                args=(
+                    tuple(updated_dep_nodes),  # additional_deps
+                    get_subgraph,  # subgraph via get_attr (like b2b gemm)
+                    *node_args,  # original node arguments (from both args and kwargs)
+                ),
+                kwargs={},
+                name=f"__temp_{original_name}",  # Temporary name to avoid conflict
+            )
+
+        # Copy metadata from original node
+        ordered_node.meta = original_meta
+        # this will be constrained on the target node in subgraph if it exists
+        ordered_node.meta.pop("eager_input_vals", None)
+
+        # Replace all uses of the original node with the ordered version
+        dependent_node.replace_all_uses_with(ordered_node)
+
+        # Remove the original node from the graph
+        graph.erase_node(dependent_node)
+
+        # Now rename the ordered node to the original name
+        ordered_node.name = original_name  # PRESERVE ORIGINAL NAME
+
+        # Track the replacement for future dependencies
+        replacements[dependent_node] = ordered_node
+
+
+def _create_subgraph_for_node(graph: fx.Graph, node: fx.Node) -> fx.GraphModule:
+    """
+    Create a subgraph that exactly recreates a node's operation.
+
+    The subgraph takes only the fx.Node arguments and recreates the operation
+    with the exact target, args structure, and kwargs.
+
+    Args:
+        graph: The parent graph
+        node: The node to wrap in a subgraph
+
+    Returns:
+        A GraphModule containing the subgraph
+    """
+    # Get the owning module
+    # torch.distributed.breakpoint(0)
+    owning_module = graph.owning_module
+
+    # Create a new graph for the subgraph
+    subgraph = fx.Graph(owning_module)
+
+    new_args: list[Any] = []
+    placeholder_idx = 0
+    for _, arg in enumerate(node.args):
+        if not isinstance(arg, fx.Node):
+            new_args.append(arg)
+            continue
+
+        placeholder = subgraph.placeholder(f"arg_{placeholder_idx}")
+        placeholder_idx += 1
+        if "val" in arg.meta:
+            placeholder.meta.update(arg.meta)
+        new_args.append(placeholder)  # type: ignore[arg-type]
+
+    new_kwargs: dict[str, Any] = {}
+    for key, value in node.kwargs.items():
+        if not isinstance(value, fx.Node):
+            new_kwargs[key] = value
+            continue
+
+        placeholder = subgraph.placeholder(f"kwarg_{key}")
+        if "val" in value.meta:
+            placeholder.meta.update(value.meta)
+
+        new_kwargs[key] = placeholder  # type: ignore[assignment]
+
+    # Recreate the exact original operation in the subgraph
+    assert callable(node.target)
+    result = subgraph.call_function(
+        node.target,
+        tuple(new_args),
+        new_kwargs,  # type: ignore[arg-type]
+    )
+
+    # Copy metadata from the original node
+    result.meta.update(node.meta)
+
+    out = subgraph.output(result)
+    if "val" in result.meta:
+        out.meta["val"] = result.meta["val"]
+
+    return fx.GraphModule(owning_module, subgraph)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/ddp_fusion.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/ddp_fusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..44314b912786f9537286108dc33c94905a5db0de
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/ddp_fusion.py
@@ -0,0 +1,589 @@
+# Owner(s): ["oncall: distributed"]
+import collections
+import inspect
+import logging
+import math
+import operator
+from collections.abc import Callable, Generator
+from dataclasses import dataclass
+from functools import partial
+from typing import Any, cast
+
+import torch
+import torch.fx as fx
+from torch._dynamo.utils import counters
+from torch.fx.passes.graph_transform_observer import GraphTransformObserver
+from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
+from torch.utils._ordered_set import OrderedSet
+from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
+
+from ..fx_utils import get_fake_args_kwargs
+from ..virtualized import V
+
+
+aten = torch.ops.aten
+logger: logging.Logger = logging.getLogger("comm_fusion")
+
+
+def move_block_after(block: list[fx.Node], target_node: fx.Node) -> None:
+    for node in block:
+        target_node.append(node)
+        target_node = node
+
+
+def move_block_before(block: list[fx.Node], target_node: fx.Node) -> None:
+    for node in block:
+        target_node.prepend(node)
+        target_node = node
+
+
+def call_function(
+    graph: fx.Graph,
+    target: str | Callable[..., Any],
+    args: tuple[fx.node.Argument, ...] | None = None,
+    kwargs: dict[str, fx.node.Argument] | None = None,
+) -> fx.Node:
+    # We accept target as a str to avoid typing error as the type of
+    # a node.target is str | Callable[..., Any].
+    # This also allows us to avoid writing check for every call.
+    if isinstance(target, str):
+        raise RuntimeError(f"Call function should not get a str target {target=}")
+    node = graph.call_function(target, args, kwargs)
+    _, args, kwargs = get_fake_args_kwargs(node)
+    with V.fake_mode:
+        node.meta["val"] = target(*args, **kwargs)
+        # node.meta["val"] may be a container. So we use tree_map here
+        # to recursively extract the tensor metadata.
+        node.meta["tensor_meta"] = tree_map(
+            _extract_tensor_metadata, (node.meta["val"],)
+        )[0]
+    return node
+
+
+@dataclass(unsafe_hash=True)
+class CommBlock:
+    shape: torch.Size | list[torch.Size]
+    node_list: list[fx.Node]
+    inputs: list[fx.Node]
+    wait_nodes: list[fx.Node]
+    comm_node: fx.Node
+    outputs: OrderedSet[fx.Node]
+
+
+def get_comm_block(comm_node: fx.Node) -> CommBlock | None:
+    """
+    Given a collective node (e.g., allreduce), find out all the nodes belong to
+    this communication.
+
+    Args:
+        comm_node(fx.Node): The target communication/collective node.
+    Returns:
+        The CommBlock that encapsulates the related nodes (e.g., wait_node) of
+        the given comm_node.
+    """
+    node_list = []
+    wait_nodes = []
+    inputs, _ = tree_flatten((comm_node.args, comm_node.kwargs))
+    input_nodes = [inp for inp in inputs if isinstance(inp, fx.Node)]
+    # If the users of the wait node are following items, we consinder them
+    # to be a part of the output.
+    intermediate_outputs = ("split", "reshape", "getitem", "detach", "alias")
+
+    first_user = next(iter(comm_node.users))
+    if (
+        len(comm_node.users) == 1
+        and first_user.target is torch.ops._c10d_functional.wait_tensor.default
+    ):
+        # Collective with only one output
+        node_list = [comm_node, first_user]
+        wait_nodes.append(first_user)
+    elif len(comm_node.users) > 1 and first_user.target is operator.getitem:
+        # Collective with only more than one output
+        node_list.append(comm_node)
+        for user in comm_node.users:
+            if user.target != operator.getitem:
+                return None
+            if len(user.users) != 1:
+                return None
+            wait_node = next(iter(user.users))
+            if wait_node.target != torch.ops._c10d_functional.wait_tensor.default:
+                return None
+            wait_nodes.append(wait_node)
+            node_list.append(user)
+        node_list.extend(wait_nodes)
+    else:
+        return None
+
+    # Identify all the outputs of this collective block.
+    outputs = OrderedSet[fx.Node]()
+    nodes = collections.deque(wait_nodes)
+    while nodes:
+        node = nodes.popleft()
+        for user in node.users:
+            if isinstance(user, fx.Node) and user.name.startswith(intermediate_outputs):
+                nodes.append(user)
+                node_list.append(user)
+            else:
+                outputs.add(node)
+                break
+
+    tensor_meta = input_nodes[0].meta["tensor_meta"]
+    shape: torch.Size | list[torch.Size]
+    if isinstance(tensor_meta, TensorMetadata):
+        shape = tensor_meta.shape
+    elif isinstance(tensor_meta, (list, tuple)):
+        shape = [tm.shape for tm in tensor_meta]
+    else:
+        logger.warning("Unexpected type of tensor_meta %s", type(tensor_meta))
+        return None
+
+    return CommBlock(
+        shape=shape,
+        node_list=node_list,
+        wait_nodes=wait_nodes,
+        comm_node=comm_node,
+        inputs=input_nodes,
+        outputs=outputs,
+    )
+
+
+def get_all_comm_blocks(
+    graph: fx.Graph,
+    comm_ops: tuple[torch._ops.OpOverload, ...],
+    comm_filter: Callable[..., bool] | None = None,
+) -> list[CommBlock]:
+    if comm_filter is None:
+
+        def always_true(comm_block: CommBlock) -> bool:
+            return True
+
+        comm_filter = always_true
+
+    blocks = []
+    for node in graph.nodes:
+        if node.target not in comm_ops:
+            continue
+        comm_block = get_comm_block(node)
+        if comm_block is not None and comm_filter(comm_block):
+            blocks.append(comm_block)
+    return blocks
+
+
+def _fuse_allreduce_by_concat(
+    graph: fx.Graph,
+    last_input_node: fx.Node,
+    all_input_nodes: list[fx.Node],
+    last_comm_block: CommBlock,
+) -> CommBlock:
+    """Given a list of inputs in order, create a fused allreduce using concat."""
+    # Flatten all the inputs to the all_reduce nodes.
+    with graph.inserting_after(last_input_node):
+        cat_inputs = []
+        for input_node in all_input_nodes:
+            assert isinstance(input_node.args[0], fx.Node)
+            input_node = input_node.args[0]
+            cat_inputs.append(
+                call_function(graph, aten.flatten.using_ints, (input_node,))
+            )
+
+    # Concat all the flattened nodes.
+    with graph.inserting_after(cat_inputs[0]):
+        cat_node = call_function(graph, aten.cat, (cat_inputs,))
+
+    # Insert the fused div node and remove the input div nodes.
+    # This is an optimization and is not mandatory for fusion.
+    divisors = [div.args[1] for div in all_input_nodes]
+    assert all(divisor == divisors[0] for divisor in divisors)
+    with graph.inserting_after(cat_node):
+        div_node = call_function(graph, last_input_node.target, (cat_node, divisors[0]))
+
+    # Create a new Comm/all_reduce node.
+    last_comm_node = last_comm_block.comm_node
+    last_wait_node = last_comm_block.wait_nodes[0]
+    with graph.inserting_after(div_node):
+        flatten_args, spec = tree_flatten((last_comm_node.args, last_comm_node.kwargs))
+        flatten_args[0] = div_node
+        args, kwargs = tree_unflatten(flatten_args, spec)
+        fused_comm_node = call_function(graph, last_comm_node.target, args, kwargs)
+
+    # Create a new Wait node.
+    with graph.inserting_after(fused_comm_node):
+        flatten_args, spec = tree_flatten((last_wait_node.args, last_wait_node.kwargs))
+        flatten_args[0] = fused_comm_node
+        args, kwargs = tree_unflatten(flatten_args, spec)
+        fused_wait_node = call_function(graph, last_wait_node.target, args, kwargs)
+
+    # Move the fused all_reduce and its args to right after the input node
+    nodes_to_move = cat_inputs + [cat_node, div_node, fused_comm_node, fused_wait_node]
+    # pyrefly: ignore [bad-argument-type]
+    move_block_after(nodes_to_move, last_input_node)
+
+    return CommBlock(
+        shape=cast(TensorMetadata, cat_node.meta.get("tensor_meta")).shape,
+        node_list=[fused_comm_node, fused_wait_node],
+        wait_nodes=[fused_wait_node],
+        comm_node=fused_comm_node,
+        inputs=[div_node],
+        outputs=OrderedSet([fused_wait_node]),
+    )
+
+
+def _fuse_with_coalesced_op(
+    graph: fx.Graph,
+    last_input_node: fx.Node,
+    all_input_nodes: list[fx.Node],
+    last_comm_block: CommBlock,
+) -> CommBlock:
+    """Given a list of inputs in order, create a fused allreduce by coalesced."""
+    last_comm_node = last_comm_block.comm_node
+    last_wait_node = last_comm_block.wait_nodes[0]
+
+    # Insert the fused div node and remove the input div nodes.
+    # This is an optimization and is not mandatory for fusion.
+    dividends = [div.args[0] for div in all_input_nodes]
+    divisors = [div.args[1] for div in all_input_nodes]
+    assert all(divisor == divisors[0] for divisor in divisors)
+    with graph.inserting_before(last_input_node):
+        last_input_node = call_function(
+            graph, aten._foreach_div.Scalar, (dividends, divisors[0])
+        )
+    input_node = last_input_node
+
+    # Create a new Comm/all_reduce_coalesced node.
+    with graph.inserting_after(last_comm_node):
+        flatten_args, spec = tree_flatten((last_comm_node.args, last_comm_node.kwargs))
+        flatten_args[0] = input_node
+        args, kwargs = tree_unflatten(flatten_args, spec)
+        fused_comm_node = call_function(
+            graph, torch.ops._c10d_functional.all_reduce_coalesced.default, args, kwargs
+        )
+
+    # Create a new wait node.
+    getitem_nodes = []
+    wait_nodes = []
+    flatten_args, spec = tree_flatten((last_wait_node.args, last_wait_node.kwargs))
+    for idx in range(len(all_input_nodes)):
+        with graph.inserting_after(fused_comm_node):
+            gi_node = call_function(graph, operator.getitem, (fused_comm_node, idx))
+        getitem_nodes.append(gi_node)
+        flatten_args[0] = gi_node
+        args, kwargs = tree_unflatten(flatten_args, spec)
+        with graph.inserting_after(gi_node):
+            wait_nodes.append(call_function(graph, last_wait_node.target, args, kwargs))
+
+    # Move the new all_reduce_coalesced and its args to right after the input node
+    nodes_to_move = [fused_comm_node] + getitem_nodes + wait_nodes
+    move_block_after(nodes_to_move, last_input_node)
+
+    return CommBlock(
+        shape=[
+            tm.shape
+            for tm in cast(
+                list[TensorMetadata], fused_comm_node.meta.get("tensor_meta")
+            )
+        ],
+        node_list=[fused_comm_node] + getitem_nodes + wait_nodes,
+        wait_nodes=wait_nodes,
+        comm_node=fused_comm_node,
+        inputs=[input_node],
+        outputs=OrderedSet(wait_nodes),
+    )
+
+
+def _scatter_fused_allreduce_waits(
+    graph: fx.Graph,
+    fused_comm_block: CommBlock,
+    orig_comm_blocks: list[CommBlock],
+    node_indices: dict[fx.Node, int],
+    split_and_reshape: bool = True,
+) -> None:
+    """
+    Scatters the result of the fused communication node to the original users.
+    If the fused method is concat splitting the output and reshape will be inserted,
+    before inserting getitem. Otherwise getitem will be used as the users of the
+    wait node.
+    """
+
+    # Before we mass up the order, we need to get the index of the last wait node
+    # in orig_comm_blocks. This index will be later used to determine what users
+    # nodes need to be move to maintain a correct topological sort order.
+    last_wait_node_idx = 0
+    # pyrefly: ignore [bad-assignment]
+    for node in graph.nodes:
+        last_wait_node_idx = max(
+            node_indices.get(node, last_wait_node_idx), last_wait_node_idx
+        )
+        if node == orig_comm_blocks[-1].wait_nodes[0]:
+            break
+
+    if split_and_reshape:
+        fused_wait_node = fused_comm_block.wait_nodes[0]
+        with graph.inserting_after(fused_wait_node):
+            split_node = call_function(
+                graph,
+                aten.split,
+                (
+                    fused_wait_node,
+                    [math.prod(cast(list[int], cb.shape)) for cb in orig_comm_blocks],
+                ),
+            )
+        with graph.inserting_after(split_node):
+            fused_outputs = []
+            for idx, comm_block in enumerate(orig_comm_blocks):
+                split_idx_node = call_function(
+                    graph, operator.getitem, (split_node, idx)
+                )
+                with graph.inserting_after(split_idx_node):
+                    fused_outputs.append(
+                        call_function(
+                            graph, aten.reshape, (split_idx_node, comm_block.shape)
+                        )
+                    )
+    else:
+        fused_outputs = fused_comm_block.wait_nodes
+
+    # Scatter the fused outputs.
+    incorrect_order_nodes = []
+    for comm_block, fused_output in zip(orig_comm_blocks, fused_outputs):
+        # Some descendant users of the orig_comm_blocks may be scheduled before
+        # the fused all_reduce. For example, the user nodes of the very first
+        # all_reduce may be scheduled before the second all_reduce. Since the
+        # fused all_reduce is inserted right after the last all_reduce, the
+        # order can be wrong.
+        # `incorrect_order_nodes` records these nodes.
+
+        orig_wait = comm_block.wait_nodes[0]
+        nodes = collections.deque(list(orig_wait.users))
+        while nodes:
+            user_node = nodes.popleft()
+            if not isinstance(user_node, fx.Node):
+                continue
+            # pyrefly: ignore [unsupported-operation]
+            if node_indices[user_node] < last_wait_node_idx:
+                incorrect_order_nodes.append(user_node)
+                nodes.extend(list(user_node.users))
+
+        orig_wait.replace_all_uses_with(fused_output)
+
+    last_fused_result = fused_outputs[0]
+    fused_outputs_set = OrderedSet(fused_outputs)
+    for node in graph.nodes:
+        if node in fused_outputs_set:
+            last_fused_result = node
+
+    # Move the incorrect_order_nodes to right after the last fused_result.
+    incorrect_order_nodes = sorted(
+        incorrect_order_nodes, key=lambda node: node_indices[node]
+    )
+    move_block_after(incorrect_order_nodes, last_fused_result)
+
+
+def _fuse_allreduce(
+    graph: fx.Graph,
+    comm_blocks: list[CommBlock],
+    node_indices: dict[fx.Node, int],
+    use_concat: bool,
+) -> CommBlock:
+    """Given a list of allreduce CommBlock, fuse the CommBlocks into one CommBlock."""
+
+    if len(comm_blocks) == 1:
+        return comm_blocks[0]
+
+    # Find the last input node of all the CommBlocks. This node will be served
+    # as the inserting point of the new collective op.
+    last_input_node = comm_blocks[0].inputs[0]
+    last_input_index = -1
+    all_input_nodes = []
+    for comm_block in comm_blocks:
+        input_node = comm_block.inputs[0]
+        all_input_nodes.append(input_node)
+        index = node_indices[input_node]
+        if index >= last_input_index:
+            assert index != last_input_index
+            last_input_node = input_node
+            last_input_index = index
+
+    if use_concat:
+        fused_comm_block = _fuse_allreduce_by_concat(
+            graph, last_input_node, all_input_nodes, comm_blocks[-1]
+        )
+    else:
+        fused_comm_block = _fuse_with_coalesced_op(
+            graph, last_input_node, all_input_nodes, comm_blocks[-1]
+        )
+
+    _scatter_fused_allreduce_waits(
+        graph, fused_comm_block, comm_blocks, node_indices, split_and_reshape=use_concat
+    )
+
+    for comm_block in comm_blocks:
+        for wait in comm_block.wait_nodes:
+            graph.erase_node(wait)
+        graph.erase_node(comm_block.comm_node)
+    graph.eliminate_dead_code()
+
+    return fused_comm_block
+
+
+def _bucket_size_fusion(
+    graph: fx.Graph, comm_blocks: list[CommBlock], bucket_size_mb: int
+) -> Generator[list[CommBlock], None, None]:
+    MB = 1024**2
+    bucket_size = 1 * MB
+    bucket_cap_size = bucket_size_mb * MB
+    curr_size = 0
+    curr_blocks = []
+
+    count = 0
+    fuse_count = 0
+    for i, block in enumerate(comm_blocks):
+        curr_blocks.append(block)
+        itemsize = block.comm_node.meta["tensor_meta"].dtype.itemsize
+        curr_size += cast(torch.Size, block.shape).numel() * itemsize
+        count += 1
+        if curr_size < bucket_size and i != len(comm_blocks) - 1:
+            continue
+
+        fuse_count += 1
+        if torch.distributed.get_rank() == 0:
+            logger.info(
+                "DDP bucketing: block%d, count=%d, curr_size=%d, bucket_size=%d",
+                fuse_count,
+                count,
+                curr_size,
+                bucket_size,
+            )
+
+        # Set the debug counters
+        counters["inductor"]["ddp_buckets"] = fuse_count
+        yield curr_blocks
+
+        bucket_size = bucket_cap_size
+        curr_blocks = []
+        curr_size = 0
+        count = 0
+
+
+def _fuse_ddp_communication(
+    graph: fx.Graph, algorithm_fn: Callable[..., Any], fusion_fn: Callable[..., Any]
+) -> None:
+    for output in reversed(graph.nodes):
+        if output.op == "output":
+            break
+
+    def ddp_reducer_filter(block: CommBlock) -> bool:
+        if (
+            not isinstance(block.comm_node.args[0], fx.Node)
+            or block.comm_node.args[0].target != aten.div.Tensor
+        ):
+            return False
+
+        if len(block.wait_nodes[0].users) != 1:
+            # gradient/wait node should only be used by one user
+            return False
+
+        # Two cases:
+        # 1. gradient/wait node should be directly used by the output
+        # if gradient is None before bwd.
+        # 2. gradient/wait node should be directly used by copy_.
+        if (
+            output not in block.wait_nodes[0].users
+            and next(iter(block.wait_nodes[0].users)).target != aten.copy_.default
+        ):
+            return False
+
+        return True
+
+    ops = (
+        torch.ops._c10d_functional.all_reduce_.default,
+        torch.ops._c10d_functional.all_reduce.default,
+    )
+    comm_blocks = get_all_comm_blocks(graph, ops, comm_filter=ddp_reducer_filter)
+    node_indices = {node: i for i, node in enumerate(graph.nodes)}
+
+    for block in algorithm_fn(graph, comm_blocks):
+        fusion_fn(graph, block, node_indices)
+
+
+def fuse_ddp_with_coalesced_op(graph: fx.Graph, bucket_size_mb: int) -> None:
+    _fuse_ddp_communication(
+        graph,
+        partial(_bucket_size_fusion, bucket_size_mb=bucket_size_mb),
+        partial(_fuse_allreduce, use_concat=False),
+    )
+
+
+def fuse_ddp_with_concat_op(graph: fx.Graph, bucket_size_mb: int) -> None:
+    _fuse_ddp_communication(
+        graph,
+        partial(_bucket_size_fusion, bucket_size_mb=bucket_size_mb),
+        partial(_fuse_allreduce, use_concat=True),
+    )
+
+
+def schedule_comm_wait(graph: fx.Graph) -> None:
+    """
+    Delay the execution of wait tensors of allreduce until its first user.
+
+    This algorithm considers the intermediate users, like split, getitem,
+    of the wait node and schedule those intermediate users as well.
+    This will result in a better overlapping result.
+    """
+    ops = (
+        torch.ops._c10d_functional.all_reduce_.default,
+        torch.ops._c10d_functional.all_reduce.default,
+        torch.ops._c10d_functional.all_reduce_coalesced.default,
+        torch.ops._c10d_functional.all_reduce_coalesced_.default,
+    )
+    comm_blocks = get_all_comm_blocks(graph, ops)
+    if not comm_blocks:
+        return
+
+    # Find all the end users.
+    allreduce_users = OrderedSet[fx.Node]()
+    for allreduce in comm_blocks:
+        for output in allreduce.outputs:
+            allreduce_users.update(output.users)
+
+    node_indices = {node: i for i, node in enumerate(graph.nodes)}
+    for allreduce in comm_blocks:
+        # Find the earliest/first user -- target_node.
+        assert len(allreduce.outputs) >= 1, (
+            f"Found a allreduce that has zero outputs/users -- {allreduce}."
+        )
+        # Initialize the target node to avoid typing issues.
+        target_node = next(iter(next(iter(allreduce.outputs)).users))
+        target_node_index = 2**31
+        for user in (user for output in allreduce.outputs for user in output.users):
+            index = node_indices[user]
+            if index < target_node_index:
+                target_node = user
+                target_node_index = index
+
+        # Move wait nodes and all the subsequent nodes in the comm_block to
+        # before the first user -- target_node.
+        wait_idx = -1
+        for wait_idx, node in enumerate(allreduce.node_list):
+            if node == allreduce.wait_nodes[0]:
+                break
+        assert wait_idx >= 0
+        move_block_before(allreduce.node_list[wait_idx:], target_node)
+
+
+def fuse_ddp_communication(
+    graph: fx.Graph, passes: list[Callable[..., None] | str], bucket_size_mb: int
+) -> None:
+    for i, pa in enumerate(passes):
+        with GraphTransformObserver(
+            graph.owning_module, f"fuse_ddp_communication_pass_{i}"
+        ):
+            if isinstance(pa, str):
+                func = globals()[pa]
+            else:
+                func = pa
+            if "bucket_size_mb" in OrderedSet(
+                v.name for v in inspect.signature(func).parameters.values()
+            ):
+                func(graph, bucket_size_mb=bucket_size_mb)
+            else:
+                func(graph)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/decompose_mem_bound_mm.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/decompose_mem_bound_mm.py
new file mode 100644
index 0000000000000000000000000000000000000000..3613ab1ed17b5e35815d1bca359b94b29b511abc
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/decompose_mem_bound_mm.py
@@ -0,0 +1,285 @@
+# mypy: allow-untyped-defs
+import logging
+
+import torch
+from torch import Tensor
+from torch._dynamo.utils import counters, is_node_meta_valid
+from torch.fx.experimental.symbolic_shapes import (
+    statically_known_false,
+    statically_known_true,
+)
+
+from .. import config
+from ..pattern_matcher import Arg, CallFunction, Match, register_graph_pattern
+from .split_cat import construct_pattern_matcher_pass
+
+
+aten = torch.ops.aten
+log = logging.getLogger(__name__)
+
+# TODO: need a better strategy for decomposing mm
+# The following two constants are for CUDA device only
+MIN_FIRST_DIMENSION_DECOMPOSITION = 10240
+MAX_OTHER_DIMENSION_DECOMPOSITION = 32
+# The following two constants are for CPU device only
+CPU_MAX_FIRST_DIMENSION_DECOMPOSITION = 1
+CPU_MAX_OTHER_DIMENSION_DECOMPOSITION = 2048
+
+min_first_dimension_decomposition = MIN_FIRST_DIMENSION_DECOMPOSITION
+max_other_dimension_decomposition = MAX_OTHER_DIMENSION_DECOMPOSITION
+cpu_max_first_dimension_decomposition = CPU_MAX_FIRST_DIMENSION_DECOMPOSITION
+cpu_max_other_dimension_decomposition = CPU_MAX_OTHER_DIMENSION_DECOMPOSITION
+if "decompose_mm_pass" in config.post_grad_fusion_options:
+    min_first_dimension_decomposition = config.post_grad_fusion_options[
+        "decompose_mm_pass"
+    ].get("min_first_dimension_decomposition", MIN_FIRST_DIMENSION_DECOMPOSITION)
+    max_other_dimension_decomposition = config.post_grad_fusion_options[
+        "decompose_mm_pass"
+    ].get("max_other_dimension_decomposition", MAX_OTHER_DIMENSION_DECOMPOSITION)
+    cpu_max_first_dimension_decomposition = config.post_grad_fusion_options[
+        "decompose_mm_pass"
+    ].get(
+        "cpu_max_first_dimension_decomposition", CPU_MAX_FIRST_DIMENSION_DECOMPOSITION
+    )
+    cpu_max_other_dimension_decomposition = config.post_grad_fusion_options[
+        "decompose_mm_pass"
+    ].get(
+        "cpu_max_other_dimension_decomposition", CPU_MAX_OTHER_DIMENSION_DECOMPOSITION
+    )
+
+
+def check_device(a: Tensor, b: Tensor, device="cuda") -> bool:
+    return (a.device.type == b.device.type) and (b.device.type == device)
+
+
+def realize_inputs(inputs: list[torch.fx.Node]):
+    for inp in inputs:
+        if isinstance(inp, torch.fx.node.Node):
+            inp.meta["inductor_realize_to_strides"] = True
+
+
+def should_decompose_bmm(mat1, mat2) -> bool:
+    if is_node_meta_valid(mat1) and is_node_meta_valid(mat2):
+        mat1 = mat1.meta["val"]
+        mat2 = mat2.meta["val"]
+    else:
+        return False
+    if len(mat1.shape) != 3 or len(mat2.shape) != 3:
+        return False
+    if check_device(mat1, mat2, device="cuda") or check_device(
+        mat1, mat2, device="xpu"
+    ):
+        if mat1.shape[0] < min_first_dimension_decomposition:
+            return False
+        # 2 of m, n, k must be <= MAX_OTHER_DIMENSION_DECOMPOSITION
+        # use bool() to deal with BooleanAtom type
+        if (
+            bool(mat1.shape[1] < max_other_dimension_decomposition)
+            + bool(mat1.shape[2] < max_other_dimension_decomposition)
+            + bool(mat2.shape[2] < max_other_dimension_decomposition)
+            < 2
+        ):
+            return False
+        return True
+    elif check_device(mat1, mat2, device="cpu"):
+        if (
+            mat1.shape[0] <= cpu_max_first_dimension_decomposition
+            and mat2.shape[0] <= cpu_max_first_dimension_decomposition
+        ):
+            return True
+    return False
+
+
+def should_decompose_mm(mat1, mat2) -> bool:
+    """
+    Determines whether matrix multiplication (mm) should be decomposed into pointwise operations
+    based on the input matrices' metadata, shapes, device placement, and configuration options.
+    Args:
+        mat1: The first matrix operand. Expected to be an object with a `.meta` attribute containing
+              a "val" key, or a tensor-like object with a `.shape` attribute.
+        mat2: The second matrix operand. Same requirements as `mat1`.
+    Returns:
+        bool: True if the matrix multiplication should be decomposed according to the following logic:
+            - Both inputs must have valid node metadata.
+            - Both matrices must be 2-dimensional.
+            - If the configuration option `skip_dynamic_shape_dim_check` is False:
+                - Decomposition is only considered for statically-shaped matrices.
+                - For CUDA devices: `mat1.shape[0]` must be at least `min_first_dimension_decomposition`,
+                  and both dimensions of `mat2` must be less than `max_other_dimension_decomposition`.
+                - For CPU devices: All relevant dimensions must be less than or equal to their respective
+                  CPU decomposition thresholds.
+            - If `skip_dynamic_shape_dim_check` is True:
+                - Decomposition is considered for dynamic shapes as well, using a combination of
+                  `statically_known_true` and `statically_known_false` checks to handle uncertainty.
+                - The same dimension and device checks apply, but allow for dynamic/static uncertainty.
+            - Returns False if any of the above conditions are not met.
+    Notes:
+        - Relies on helper functions such as `is_node_meta_valid`, `check_device`, `statically_known_true`,
+          and `statically_known_false`, as well as configuration values like
+          `min_first_dimension_decomposition`, `max_other_dimension_decomposition`, etc.
+        - Designed for use in graph optimization or fusion passes where decomposing large or dynamic
+          matrix multiplications can improve performance or memory usage.
+    """
+    if is_node_meta_valid(mat1) and is_node_meta_valid(mat2):
+        mat1 = mat1.meta["val"]
+        mat2 = mat2.meta["val"]
+    else:
+        return False
+    if len(mat1.shape) != 2 or len(mat2.shape) != 2:
+        return False
+    # case 1: we skip decompose mm if the input is dynamic shape
+    if not config.post_grad_fusion_options["decompose_mm_pass"].get(
+        "skip_dynamic_shape_dim_check", False
+    ):
+        return (
+            (
+                check_device(mat1, mat2, device="cuda")
+                or check_device(mat1, mat2, device="xpu")
+            )
+            and statically_known_true(
+                mat1.shape[0] >= min_first_dimension_decomposition
+            )
+            and statically_known_true(mat2.shape[0] < max_other_dimension_decomposition)
+            and statically_known_true(mat2.shape[1] < max_other_dimension_decomposition)
+        ) or (
+            check_device(mat1, mat2, device="cpu")
+            and statically_known_true(
+                mat1.shape[0] <= cpu_max_first_dimension_decomposition
+            )
+            and statically_known_true(
+                mat2.shape[0] <= cpu_max_other_dimension_decomposition
+            )
+            and statically_known_true(
+                mat2.shape[1] <= cpu_max_other_dimension_decomposition
+            )
+        )
+    # case 2: we decompose mm if the input is dynamic shape
+    else:
+        return (
+            (
+                check_device(mat1, mat2, device="cuda")
+                or check_device(mat1, mat2, device="xpu")
+            )
+            and (
+                statically_known_true(
+                    mat1.shape[0] >= min_first_dimension_decomposition
+                )
+                or not statically_known_false(
+                    mat1.shape[0] >= min_first_dimension_decomposition
+                )
+            )
+            and (
+                statically_known_true(mat2.shape[0] < max_other_dimension_decomposition)
+                or not statically_known_false(
+                    mat2.shape[0] < max_other_dimension_decomposition
+                )
+            )
+            and (
+                statically_known_true(mat2.shape[1] < max_other_dimension_decomposition)
+                or not statically_known_false(
+                    mat2.shape[1] < max_other_dimension_decomposition
+                )
+            )
+        ) or (
+            check_device(mat1, mat2, device="cpu")
+            and (
+                statically_known_true(
+                    mat1.shape[0] <= cpu_max_first_dimension_decomposition
+                )
+                or not statically_known_false(
+                    mat1.shape[0] <= cpu_max_first_dimension_decomposition
+                )
+            )
+            and (
+                statically_known_true(
+                    mat2.shape[0] <= cpu_max_other_dimension_decomposition
+                )
+                or not statically_known_false(
+                    mat2.shape[0] <= cpu_max_other_dimension_decomposition
+                )
+            )
+            and (
+                statically_known_true(
+                    mat2.shape[1] <= cpu_max_other_dimension_decomposition
+                )
+                or not statically_known_false(
+                    mat2.shape[1] <= cpu_max_other_dimension_decomposition
+                )
+            )
+        )
+
+
+def print_decompose_pattern(match: Match, inputs: list[torch.fx.Node]):
+    node = match.nodes[-1]
+    log.debug(
+        "Decompose %s with input shape: %s",
+        node.target,
+        ", ".join(
+            str(input.meta["val"].shape) if "val" in input.meta else "None"
+            for input in inputs
+        ),
+    )
+
+
+@register_graph_pattern(
+    CallFunction(aten.bmm, Arg(), Arg()),
+    pass_dict=construct_pattern_matcher_pass("decompose_mm_pass"),
+)
+def decompose_bmm(match: Match, mat1: torch.fx.Node, mat2: torch.fx.Node):
+    def repl(mat1, mat2):
+        return torch.sum(mat1[:, :, :, None] * mat2[:, None, :, :], dim=-2).to(
+            mat1.dtype
+        )
+
+    if should_decompose_bmm(mat1, mat2):
+        counters["inductor"]["decompose_bmm"] += 1
+        # pyrefly: ignore [bad-argument-type]
+        match.replace_by_example(repl, [mat1, mat2])
+        print_decompose_pattern(match, [mat1, mat2])
+        realize_inputs([mat1, mat2])
+    return
+
+
+@register_graph_pattern(
+    CallFunction(aten.addmm, Arg(), Arg(), Arg()),
+    pass_dict=construct_pattern_matcher_pass("decompose_mm_pass"),
+)
+def decompose_addmm(
+    match: Match,
+    mat1: torch.fx.Node,
+    mat2: torch.fx.Node,
+    mat3: torch.fx.Node,
+):
+    def repl(mat1, mat2, mat3):
+        return (
+            torch.sum(mat2[:, :, None] * mat3[None, :, :], dim=-2).to(mat2.dtype) + mat1
+        )
+
+    if should_decompose_mm(mat2, mat3):
+        counters["inductor"]["decompose_addmm"] += 1
+        # pyrefly: ignore [bad-argument-type]
+        match.replace_by_example(repl, [mat1, mat2, mat3])
+        print_decompose_pattern(match, [mat1, mat2, mat3])
+        realize_inputs([mat1, mat2, mat3])
+    return
+
+
+@register_graph_pattern(
+    CallFunction(aten.mm, Arg(), Arg()),
+    pass_dict=construct_pattern_matcher_pass("decompose_mm_pass"),
+)
+def decompose_mm(
+    match: Match,
+    mat1: torch.fx.Node,
+    mat2: torch.fx.Node,
+):
+    def repl(mat1, mat2):
+        return torch.sum(mat1[:, :, None] * mat2[None, :, :], dim=-2).to(mat1.dtype)
+
+    if should_decompose_mm(mat1, mat2):
+        counters["inductor"]["decompose_mm"] += 1
+        # pyrefly: ignore [bad-argument-type]
+        match.replace_by_example(repl, [mat1, mat2])
+        print_decompose_pattern(match, [mat1, mat2])
+        realize_inputs([mat1, mat2])
+    return
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/dedupe_symint_uses.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/dedupe_symint_uses.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b431c2f17117ae0c9e570072759a72417711562
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/dedupe_symint_uses.py
@@ -0,0 +1,81 @@
+# mypy: allow-untyped-defs
+from dataclasses import dataclass
+from typing import Any
+
+import torch
+from torch import SymBool, SymFloat, SymInt
+from torch.types import py_sym_types
+from torch.utils._ordered_set import OrderedSet
+
+
+@dataclass
+class _SymExprHash:
+    """
+    Hash for a py_sym_types that will use the underlying sympy expression
+    """
+
+    sym_obj: SymInt | SymFloat | SymBool
+
+    def __hash__(self) -> int:
+        return hash((type(self.sym_obj), self.sym_obj.node.expr))
+
+    def __eq__(self, value) -> bool:
+        if not isinstance(value, _SymExprHash):
+            return False
+        return self.sym_obj.node.expr == value.sym_obj.node.expr
+
+
+class _SymHashingDict:
+    """
+    Wrapper around a dictionary that will convert sym types to hash with _SymExprHash and reuse
+    existing sym proxies.
+
+    SymPy hash is not always reliable so optimistically hash sympy expression, and if those fail,
+    fallback to symnodes.
+    """
+
+    def __init__(self):
+        self.sym_hash_dict = {}
+
+    def __setitem__(self, key, value):
+        self.sym_hash_dict.__setitem__(self._wrap_to_sym_expr_hash(key), value)
+
+    def __getitem__(self, key):
+        return self.sym_hash_dict[self._wrap_to_sym_expr_hash(key)]
+
+    def __contains__(self, key):
+        return self._wrap_to_sym_expr_hash(key) in self.sym_hash_dict
+
+    def get(self, key, default=None):
+        return self.sym_hash_dict.get(self._wrap_to_sym_expr_hash(key), default)
+
+    def _wrap_to_sym_expr_hash(self, key):
+        return _SymExprHash(key) if isinstance(key, py_sym_types) else key
+
+
+def dedupe_symints(graph: torch.fx.Graph):
+    """
+    Dedupes sym ints in the graph to nodes are resolvable to symint graph inputs.
+
+    We only dedupe from graph inputs to avoid adding a potential dependency in the forward
+    from the backward.
+
+    """
+
+    sym_dict = _SymHashingDict()
+    resolvable_from_input_symints = OrderedSet[Any]()
+
+    for node in graph.nodes:
+        val = node.meta.get("val", None)
+        if val is None or not isinstance(val, py_sym_types):
+            continue
+
+        if node.op == "placeholder":
+            resolvable_from_input_symints.add(node)
+            sym_dict[val] = node
+        elif existing_node := sym_dict.get(val):
+            node.replace_all_uses_with(existing_node)
+            graph.erase_node(node)
+        elif all(n in resolvable_from_input_symints for n in node.all_input_nodes):
+            sym_dict[val] = node
+            resolvable_from_input_symints.add(node)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/efficient_conv_bn_eval.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/efficient_conv_bn_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..72c853f7e5f66c980222244e822942d2fad640f5
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/efficient_conv_bn_eval.py
@@ -0,0 +1,408 @@
+# mypy: allow-untyped-defs
+import torch
+import torch.nn as nn
+from torch._dynamo.utils import counters
+from torch._inductor import config as inductor_config
+from torch.func import functional_call
+
+from ..pattern_matcher import (
+    CallFunctionVarArgs,
+    CallModuleVarArgs,
+    Match,
+    register_graph_pattern,
+)
+from .pre_grad import efficient_conv_bn_eval_pass
+
+
+def efficient_conv_bn_eval(
+    bn: nn.modules.batchnorm._BatchNorm, conv: nn.modules.conv._ConvNd, x: torch.Tensor
+):
+    """
+    Implementation based on https://arxiv.org/abs/2305.11624
+    "Efficient ConvBN Blocks for Transfer Learning and Beyond"
+    It leverages the associative law between convolution and affine transform,
+    i.e., normalize (weight conv feature) = (normalize weight) conv feature.
+    It works for Eval mode of ConvBN blocks during validation, and can be used
+    for **training** as well, but only if one sets `bn.training=False`. It
+     reduces memory footprint and computation cost, at the cost of slightly
+     reduced numerical stability.
+    Args:
+        bn (nn.modules.batchnorm._BatchNorm): a BatchNorm module.
+        conv (nn.modules.conv._ConvNd): a conv module
+        x (torch.Tensor): Input feature map.
+    """
+
+    assert bn.running_var is not None
+    assert bn.running_mean is not None
+
+    # These lines of code are designed to deal with various cases
+    # like bn without affine transform, and conv without bias
+    weight_on_the_fly = conv.weight
+    if conv.bias is not None:
+        bias_on_the_fly = conv.bias
+    else:
+        bias_on_the_fly = torch.zeros_like(bn.running_var)
+
+    if bn.weight is not None:
+        bn_weight = bn.weight
+    else:
+        bn_weight = torch.ones_like(bn.running_var)
+
+    if bn.bias is not None:
+        bn_bias = bn.bias
+    else:
+        bn_bias = torch.zeros_like(bn.running_var)
+
+    # shape of [C_out, 1, 1, 1] in Conv2d
+    target_shape = [-1] + [1] * (conv.weight.ndim - 1)
+    if isinstance(conv, nn.modules.conv._ConvTransposeNd):
+        # for transposed conv, the C_out dimension should at index 1.
+        target_shape[:2] = [target_shape[1], target_shape[0]]
+    weight_coeff = torch.rsqrt(bn.running_var + bn.eps).reshape(target_shape)
+    # shape of [C_out, 1, 1, 1] in Conv2d
+    coefff_on_the_fly = bn_weight.view_as(weight_coeff) * weight_coeff
+
+    # shape of [C_out, C_in, k, k] in Conv2d
+    weight_on_the_fly = weight_on_the_fly * coefff_on_the_fly
+    # shape of [C_out] in Conv2d
+    bias_on_the_fly = bn_bias + coefff_on_the_fly.flatten() * (
+        bias_on_the_fly - bn.running_mean
+    )
+
+    input = x
+    params = {"weight": weight_on_the_fly, "bias": bias_on_the_fly}
+    output = functional_call(conv, params, input)
+    return output
+
+
+def efficient_conv_bn_eval_decomposed(
+    bn_weight,
+    bn_bias,
+    bn_running_mean,
+    bn_running_var,
+    bn_eps,
+    conv: torch._ops.OpOverload,
+    conv_weight,
+    conv_bias,
+    x,
+    conv_remainging_args,
+):
+    """
+    Implementation based on https://arxiv.org/abs/2305.11624
+    "Efficient ConvBN Blocks for Transfer Learning and Beyond"
+    It leverages the associative law between convolution and affine transform,
+    i.e., normalize (weight conv feature) = (normalize weight) conv feature.
+    It works for Eval mode of ConvBN blocks during validation, and can be used
+    for **training** as well, but only if one sets `bn.training=False`. It
+     reduces memory footprint and computation cost, at the cost of slightly
+     reduced numerical stability.
+    Args:
+    """
+    assert bn_running_var is not None
+
+    # These lines of code are designed to deal with various cases
+    # like bn without affine transform, and conv without bias
+    weight_on_the_fly = conv_weight
+    if conv_bias is not None:
+        bias_on_the_fly = conv_bias
+    else:
+        bias_on_the_fly = torch.zeros_like(bn_running_var)
+
+    if bn_weight is None:
+        bn_weight = torch.ones_like(bn_running_var)
+
+    if bn_bias is None:
+        bn_bias = torch.zeros_like(bn_running_var)
+
+    # shape of [C_out, 1, 1, 1] in Conv2d
+    target_shape = [-1] + [1] * (conv_weight.ndim - 1)
+    if "conv_transpose" in conv.__str__():
+        # for transposed conv, the C_out dimension should at index 1.
+        target_shape[:2] = [target_shape[1], target_shape[0]]
+    weight_coeff = torch.rsqrt(bn_running_var + bn_eps).reshape(target_shape)
+    # shape of [C_out, 1, 1, 1] in Conv2d
+    coefff_on_the_fly = bn_weight.view_as(weight_coeff) * weight_coeff
+
+    # shape of [C_out, C_in, k, k] in Conv2d
+    weight_on_the_fly = weight_on_the_fly * coefff_on_the_fly
+    # shape of [C_out] in Conv2d
+    bias_on_the_fly = bn_bias + coefff_on_the_fly.flatten() * (
+        bias_on_the_fly - bn_running_mean
+    )
+
+    input = x
+    return conv(*((input, weight_on_the_fly, bias_on_the_fly) + conv_remainging_args))
+
+
+@register_graph_pattern(
+    CallFunctionVarArgs(
+        [
+            torch.nn.functional.batch_norm,
+        ]
+    ),
+    # pyrefly: ignore [bad-argument-type]
+    pass_dict=efficient_conv_bn_eval_pass,
+    extra_check=lambda match: not inductor_config.freezing
+    and inductor_config.efficient_conv_bn_eval_fx_passes,
+)
+def efficient_conv_bn_eval_graph_transform_inlined(match: Match, *args, **kwargs):
+    bn_node = match.nodes[0]
+    graph = match.graph
+    assert len(bn_node.args) == 8
+
+    # We can only use efficient conv-bn for eval mode with track_running_stats
+    # bn_node.args is `training`
+    if bn_node.args[-3]:
+        return
+
+    # Check if the input is Conv
+    input_node = bn_node.args[0]
+
+    if input_node.op != "call_function":  # type: ignore[union-attr]
+        return
+
+    input_fn = input_node.target  # type: ignore[arg-type, union-attr]
+    supported_convs = [
+        torch._C._nn.linear,
+        torch.conv1d,
+        torch.conv2d,
+        torch.conv3d,
+        torch.conv_transpose1d,
+        torch.conv_transpose2d,
+        torch.conv_transpose3d,
+    ]
+
+    if not any(input_fn is cls for cls in supported_convs):
+        return
+
+    conv_node = input_node
+    # Output of conv is used by other nodes, cannot optimize
+    if len(conv_node.users) > 1:  # type: ignore[union-attr]
+        return
+
+    counters["inductor"]["efficient_conv_bn_eval"] += 1
+
+    with graph.inserting_before(bn_node):
+        # prepare args for the fused function
+        bn_running_mean = bn_node.args[1]
+        bn_running_var = bn_node.args[2]
+        bn_weight = bn_node.args[3]
+        bn_bias = bn_node.args[4]
+        bn_eps = bn_node.args[7]
+        assert len(conv_node.args) >= 2  # type: ignore[union-attr]
+        conv_input = conv_node.args[0]  # type: ignore[union-attr]
+        conv_weight = conv_node.args[1]  # type: ignore[union-attr]
+        conv_bias = conv_node.args[2] if len(conv_node.args) >= 3 else None  # type: ignore[union-attr]
+        conv_remainging_args = conv_node.args[3:]  # type: ignore[union-attr]
+        args = (
+            bn_weight,
+            bn_bias,
+            bn_running_mean,
+            bn_running_var,
+            bn_eps,
+            conv_node.target,  # type: ignore[union-attr]
+            conv_weight,
+            conv_bias,
+            conv_input,
+            conv_remainging_args,
+        )
+
+        # create a new node
+        new_node = graph.create_node(
+            op="call_function",
+            target=efficient_conv_bn_eval_decomposed,
+            args=args,  # type: ignore[arg-type]
+            name="efficient_conv_bn_eval",
+        )
+
+    # this node replaces the original conv + bn, and therefore
+    # should replace the uses of bn_node
+    bn_node.replace_all_uses_with(new_node)
+    # take care of the deletion order:
+    # delete bn_node first, and then conv_node
+    graph.erase_node(bn_node)
+    graph.erase_node(conv_node)  # type: ignore[arg-type]
+
+    return
+
+
+@register_graph_pattern(
+    CallFunctionVarArgs(
+        [
+            torch.ops.aten.batch_norm.default,
+        ]
+    ),
+    # pyrefly: ignore [bad-argument-type]
+    pass_dict=efficient_conv_bn_eval_pass,
+    extra_check=lambda match: not inductor_config.freezing
+    and inductor_config.efficient_conv_bn_eval_fx_passes,
+)
+def efficient_conv_bn_eval_graph_transform_decomposed(match: Match, *args, **kwargs):
+    bn_node = match.nodes[0]
+    graph = match.graph
+    assert len(bn_node.args) == 9
+
+    # We can only use efficient conv-bn for eval mode with track_running_stats
+    # bn_node.args is `training`
+    if bn_node.args[-4]:
+        return
+
+    # Check if the input is Conv
+    input_node = bn_node.args[0]
+
+    if input_node.op != "call_function":  # type: ignore[union-attr]
+        return
+
+    input_fn = input_node.target  # type: ignore[arg-type, union-attr]
+    supported_convs = [
+        torch.ops.aten.linear.default,
+        torch.ops.aten.conv1d.default,
+        torch.ops.aten.conv2d.default,
+        torch.ops.aten.conv3d.default,
+        torch.ops.aten.conv_transpose1d.default,
+        torch.ops.aten.conv_transpose2d.input,
+        torch.ops.aten.conv_transpose3d.input,
+    ]
+
+    if not any(input_fn is cls for cls in supported_convs):
+        return
+
+    conv_node = input_node
+    # Output of conv is used by other nodes, cannot optimize
+    if len(conv_node.users) > 1:  # type: ignore[union-attr]
+        return
+
+    counters["inductor"]["efficient_conv_bn_eval"] += 1
+
+    with graph.inserting_before(bn_node):
+        # prepare args for the fused function
+        bn_weight = bn_node.args[1]
+        bn_bias = bn_node.args[2]
+        bn_running_mean = bn_node.args[3]
+        bn_running_var = bn_node.args[4]
+        bn_eps = bn_node.args[7]
+        assert len(conv_node.args) >= 2  # type: ignore[union-attr]
+        conv_input = conv_node.args[0]  # type: ignore[union-attr]
+        conv_weight = conv_node.args[1]  # type: ignore[union-attr]
+        conv_bias = conv_node.args[2] if len(conv_node.args) >= 3 else None  # type: ignore[union-attr]
+        conv_remainging_args = conv_node.args[3:]  # type: ignore[union-attr]
+        args = (
+            bn_weight,
+            bn_bias,
+            bn_running_mean,
+            bn_running_var,
+            bn_eps,
+            conv_node.target,  # type: ignore[union-attr]
+            conv_weight,
+            conv_bias,
+            conv_input,
+            conv_remainging_args,
+        )
+
+        # create a new node
+        new_node = graph.create_node(
+            op="call_function",
+            target=efficient_conv_bn_eval_decomposed,
+            args=args,  # type: ignore[arg-type]
+            name="efficient_conv_bn_eval",
+        )
+
+    # this node replaces the original conv + bn, and therefore
+    # should replace the uses of bn_node
+    bn_node.replace_all_uses_with(new_node)
+    # take care of the deletion order:
+    # delete bn_node first, and then conv_node
+    graph.erase_node(bn_node)
+    graph.erase_node(conv_node)  # type: ignore[arg-type]
+
+    return
+
+
+@register_graph_pattern(
+    CallModuleVarArgs(
+        [
+            nn.modules.batchnorm._BatchNorm,
+            nn.BatchNorm1d,
+            nn.BatchNorm2d,
+            nn.BatchNorm3d,
+            nn.SyncBatchNorm,
+        ],
+    ),
+    # pyrefly: ignore [bad-argument-type]
+    pass_dict=efficient_conv_bn_eval_pass,
+    extra_check=lambda match: not inductor_config.freezing
+    and inductor_config.efficient_conv_bn_eval_fx_passes,
+)
+def efficient_conv_bn_eval_graph_transform(match: Match, *args, **kwargs):
+    # We matched a BN node
+    bn_node = match.nodes[0]
+    graph = match.graph
+    gm = graph.owning_module
+    bn_mod = getattr(gm, bn_node.target)  # type: ignore[arg-type]
+
+    # We can only use efficient conv-bn for eval mode with track_running_stats
+    if not bn_mod.track_running_stats or bn_mod.training:
+        return
+
+    # Check if the input is Conv
+    if bn_node.args:
+        input_node = bn_node.args[0]
+    else:
+        input_node = bn_node.kwargs["input"]
+    if input_node.op != "call_module":  # type: ignore[union-attr]
+        return
+    if not hasattr(gm, input_node.target):  # type: ignore[arg-type, union-attr]
+        return
+    input_mod = getattr(gm, input_node.target)  # type: ignore[arg-type, union-attr]
+    supported_convs = [
+        nn.Linear,
+        nn.Conv1d,
+        nn.Conv2d,
+        nn.Conv3d,
+        nn.ConvTranspose1d,
+        nn.ConvTranspose2d,
+        nn.ConvTranspose3d,
+    ]
+    if not any(isinstance(input_mod, cls) for cls in supported_convs):
+        return
+    conv_node = input_node
+    # Output of conv is used by other nodes, cannot optimize
+    if len(conv_node.users) > 1:  # type: ignore[union-attr]
+        return
+
+    # Find a pair of conv and bn computation nodes to optimize.
+    counters["inductor"]["efficient_conv_bn_eval"] += 1
+
+    with graph.inserting_before(conv_node):  # type: ignore[arg-type]
+        # create `get_attr` node to access modules
+        # note that we directly call `create_node` to fill the `name`
+        # argument. `graph.get_attr` and
+        # `graph.call_function` does not allow the `name` argument.
+        conv_get_node = graph.create_node(
+            op="get_attr",
+            target=conv_node.target,  # type: ignore[union-attr]
+            name="get_conv",
+        )
+        bn_get_node = graph.create_node(
+            op="get_attr", target=bn_node.target, name="get_bn"
+        )
+        if conv_node.args:  # type: ignore[union-attr]
+            conv_input = conv_node.args[0]  # type: ignore[union-attr]
+        else:
+            conv_input = conv_node.kwargs["input"]  # type: ignore[union-attr]
+        # prepare args for the fused function
+        args = (bn_get_node, conv_get_node, conv_input)
+        # create a new node
+        new_node = graph.create_node(
+            op="call_function",
+            target=efficient_conv_bn_eval,
+            args=args,
+            name="efficient_conv_bn_eval",
+        )
+    # this node replaces the original conv + bn, and therefore
+    # should replace the uses of bn_node
+    bn_node.replace_all_uses_with(new_node)
+    # take care of the deletion order:
+    # delete bn_node first, and then conv_node
+    graph.erase_node(bn_node)
+    graph.erase_node(conv_node)  # type: ignore[arg-type]
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/freezing_patterns.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/freezing_patterns.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8fca2087a5d5220b7256f313bbb25d2d23d9ab7
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/freezing_patterns.py
@@ -0,0 +1,310 @@
+# mypy: allow-untyped-defs
+import functools
+
+import torch
+from torch._inductor.compile_fx import fake_tensor_prop
+from torch._inductor.utils import GPU_TYPES
+
+from ..._dynamo.utils import counters
+from .. import config
+from ..pattern_matcher import (
+    _return_true,
+    CallFunction,
+    fwd_only,
+    Ignored,
+    init_once_fakemode,
+    KeywordArg,
+    Match,
+    PatternMatcherPass,
+    register_graph_pattern,
+    register_replacement,
+    stable_topological_sort,
+)
+
+
+aten = torch.ops.aten
+
+# First pass_patterns[0] are applied, then [1], then [2]
+pass_patterns = [
+    PatternMatcherPass(),
+    PatternMatcherPass(),
+    PatternMatcherPass(),
+]
+
+binary_folding_pass = PatternMatcherPass()
+
+
+def freezing_passes(gm: torch.fx.GraphModule, aot_example_inputs):
+    """
+    Passes that are applied to the graph to freeze pass.
+    """
+
+    from ..freezing import constant_fold
+
+    lazy_init()
+    # We need a few rounds of binary folding to get rid of all the
+    # unnecessary nodes, but may need a good method to chose the rounds number.
+    # works like: conv+binary+binary.
+    binary_folding = counters["inductor"]["binary_folding"]
+    fake_tensor_prop(gm, aot_example_inputs, True)
+
+    torch._inductor.fx_passes.binary_folding.mark_mixed_dtype_allowed_computation_ops(
+        gm
+    )
+    for _ in range(4):
+        constant_fold(gm)
+        # Make sure meta['val'] is properly set for all nodes
+        fake_tensor_prop(gm, aot_example_inputs, True)
+        binary_folding_pass.apply(gm.graph)  # type: ignore[arg-type]
+        # If we don't have binary folding, we don't need to run the pass again.
+        # TODO: remove the need to run fake_tensor_prop on the whole model.
+        if counters["inductor"]["binary_folding"] == binary_folding:
+            break
+        binary_folding = counters["inductor"]["binary_folding"]
+
+    torch._inductor.fx_passes.binary_folding.recover_original_precision_folded_computation_ops(
+        gm
+    )
+
+    constant_fold(gm)
+    fake_tensor_prop(gm, aot_example_inputs, True)
+
+    for pattern in pass_patterns:
+        pattern.apply(gm.graph)  # type: ignore[arg-type]
+
+    # The CPU weight packing always assume the conv's weight is channels last,
+    # So make sure the layout_optimization is on when doing it.
+    if (
+        torch._C._has_mkldnn
+        and config.cpp.weight_prepack
+        and config.layout_optimization
+    ):
+        from .mkldnn_fusion import _eliminate_duplicate_packed_nodes
+
+        _eliminate_duplicate_packed_nodes(gm)
+
+    stable_topological_sort(gm.graph)
+    gm.recompile()
+    gm.graph.lint()
+
+
+@init_once_fakemode
+def lazy_init():
+    if torch._C._has_mkldnn and config.cpp.weight_prepack:
+        from .mkldnn_fusion import _mkldnn_weight_pack_init
+
+        _mkldnn_weight_pack_init()
+
+    from .binary_folding import binary_folding_init
+
+    addmm_patterns_init()
+    binary_folding_init()
+
+
+def register_freezing_graph_pattern(pattern, extra_check=_return_true, pass_number=0):
+    while pass_number > len(pass_patterns) - 1:
+        pass_patterns.append(PatternMatcherPass())
+    return register_graph_pattern(
+        pattern,
+        extra_check=extra_check,
+        # pyrefly: ignore [bad-argument-type]
+        pass_dict=pass_patterns[pass_number],
+    )
+
+
+def register_binary_folding_pattern(pattern, extra_check=_return_true):
+    return register_graph_pattern(
+        pattern,
+        extra_check=extra_check,
+        # pyrefly: ignore [bad-argument-type]
+        pass_dict=binary_folding_pass,
+    )
+
+
+@functools.cache
+def addmm_patterns_init():
+    """
+    addmm related patterns.
+    To avoid duplication, also includes int8 WoQ GEMM pattern without bias.
+    """
+    device = next(
+        (gpu for gpu in GPU_TYPES if getattr(torch, gpu).is_available()), "cpu"
+    )
+    val = functools.partial(torch.empty, (10, 10), device=device, requires_grad=False)
+    scale = functools.partial(torch.empty, (10,), device=device, requires_grad=False)
+
+    def check_int8_woq_concat_linear_weights(match):
+        is_cpu = match.kwargs["inp"].meta["val"].is_cpu
+        if not is_cpu or not config.cpp.enable_concat_linear:
+            # Currently, this pattern is only supported on CPU
+            return False
+
+        weight_inputs = ["w1", "w2"]
+        if "w3" in match.kwargs:
+            weight_inputs.append("w3")
+
+        if not all(
+            match.kwargs[wgt].target is torch.ops.prims.convert_element_type.default
+            for wgt in weight_inputs
+        ):
+            return False
+
+        if not all(
+            next(iter(match.kwargs[wgt]._input_nodes.keys())).meta["val"].dtype
+            is torch.int8
+            for wgt in weight_inputs
+        ):
+            return False
+
+        if not all(
+            match.kwargs[wgt].meta["val"].dtype is torch.bfloat16
+            for wgt in weight_inputs
+        ):
+            return False
+
+        return True
+
+    def check_concat_weights(match):
+        is_cpu = match.kwargs["inp"].meta["val"].is_cpu
+        if is_cpu and not config.cpp.enable_concat_linear:
+            return False
+
+        weight_inputs = ["w1", "w2"]
+        if "w3" in match.kwargs:
+            weight_inputs.append("w3")
+
+        equal_shape_inputs = [weight_inputs]
+
+        if "b1" in match.kwargs:
+            bias_inputs = ["b1", "b2"]
+            if "b3" in match.kwargs:
+                bias_inputs.append("b3")
+
+            equal_shape_inputs.append(bias_inputs)
+
+        for equal_shape_group in equal_shape_inputs:
+            inps = [match.kwargs[name] for name in equal_shape_group]
+
+            if not all(
+                inp.op == "get_attr"
+                and inp.meta["val"].shape == inps[0].meta["val"].shape
+                for inp in inps
+            ):
+                return False
+        return True
+
+    def int8_woq_fusion_pattern(inp, w1, w2, w3, s1, s2, s3):
+        return ((inp @ w1) * s1, (inp @ w2) * s2, (inp @ w3) * s3)
+
+    def int8_woq_fusion_replacement(inp, w1, w2, w3, s1, s2, s3):
+        cat_w = torch.cat((w1, w2, w3), dim=1)
+        cat_s = torch.cat((s1, s2, s3), dim=0)
+        mm = (inp @ cat_w).mul(cat_s)
+        n1, n2 = w1.size(1), w2.size(1)
+        return mm.tensor_split([n1, n1 + n2], dim=-1)
+
+    register_replacement(
+        # pyrefly: ignore [bad-argument-type]
+        int8_woq_fusion_pattern,
+        # pyrefly: ignore [bad-argument-type]
+        int8_woq_fusion_replacement,
+        [val(), val(), val(), val(), scale(), scale(), scale()],
+        # pyrefly: ignore [bad-argument-type]
+        fwd_only,
+        # pyrefly: ignore [bad-argument-type]
+        pass_patterns[0],
+        extra_check=check_int8_woq_concat_linear_weights,
+        exclusive_arg_names=("w1", "w2", "w3", "s1", "s2", "s3"),
+    )
+
+    def matmul_fuse_pattern(inp, w1, w2, w3):
+        return (inp @ w1, inp @ w2, inp @ w3)
+
+    def matmul_replacement(inp, w1, w2, w3):
+        cat_t = torch.cat((w1, w2, w3), dim=1)
+        mm = inp @ cat_t
+        return mm.chunk(3, dim=1)
+
+    register_replacement(
+        # pyrefly: ignore [bad-argument-type]
+        matmul_fuse_pattern,
+        # pyrefly: ignore [bad-argument-type]
+        matmul_replacement,
+        [val(), val(), val(), val()],
+        # pyrefly: ignore [bad-argument-type]
+        fwd_only,
+        # pyrefly: ignore [bad-argument-type]
+        pass_patterns[0],
+        extra_check=check_concat_weights,
+        exclusive_arg_names=("w1", "w2", "w3"),
+    )
+
+    def matmul_fuse_pattern_two(inp, w1, w2):
+        return (inp @ w1, inp @ w2)
+
+    def matmul_replacement_two(inp, w1, w2):
+        cat_t = torch.cat((w1, w2), dim=1)
+        mm = inp @ cat_t
+        return mm.chunk(2, dim=1)
+
+    register_replacement(
+        # pyrefly: ignore [bad-argument-type]
+        matmul_fuse_pattern_two,
+        # pyrefly: ignore [bad-argument-type]
+        matmul_replacement_two,
+        [val(), val(), val()],
+        # pyrefly: ignore [bad-argument-type]
+        fwd_only,
+        # pyrefly: ignore [bad-argument-type]
+        pass_patterns[0],
+        extra_check=check_concat_weights,
+        exclusive_arg_names=("w1", "w2"),
+    )
+
+    def addmm_fuse_pattern_second(inp, w1, w2, w3, b1, b2, b3):
+        return (
+            aten.addmm(b1, inp, w1),
+            aten.addmm(b2, inp, w2),
+            aten.addmm(b3, inp, w3),
+        )
+
+    def addmm_fuse_replacement_second(inp, w1, w2, w3, b1, b2, b3):
+        cat_w = torch.cat((w1, w2, w3), dim=1)
+        cat_b = torch.cat((b1, b2, b3))
+        return aten.addmm(cat_b, inp, cat_w).chunk(3, dim=1)
+
+    register_replacement(
+        # pyrefly: ignore [bad-argument-type]
+        addmm_fuse_pattern_second,
+        # pyrefly: ignore [bad-argument-type]
+        addmm_fuse_replacement_second,
+        [val() for _ in range(7)],
+        # pyrefly: ignore [bad-argument-type]
+        fwd_only,
+        # pyrefly: ignore [bad-argument-type]
+        pass_patterns[0],
+        extra_check=check_concat_weights,
+        exclusive_arg_names=("w1", "w2", "w3", "b1", "b2", "b3"),
+    )
+
+
+def same_dtype(match):
+    return match.output_node().args[0].meta["val"].dtype == match.kwargs["dtype"]
+
+
+@register_graph_pattern(
+    CallFunction(
+        torch.ops.prims.convert_element_type.default,
+        Ignored(),
+        KeywordArg("dtype"),
+    ),
+    # pyrefly: ignore [bad-argument-type]
+    pass_dict=pass_patterns[0],
+    extra_check=same_dtype,
+)
+def unnecessary_dtype_convert(match: Match, **kwargs):
+    """Remove unnecessary dtype conversion op, probably left as a result of Conv-Bn folding"""
+    graph = match.graph
+    node = match.output_node()
+    node.replace_all_uses_with(node.args[0])  # type: ignore[arg-type]
+    graph.erase_node(node)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/fsdp.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/fsdp.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e71c350ed7b67b47e7a77af7cbd4b93bfc48f98
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/fsdp.py
@@ -0,0 +1,115 @@
+import logging
+from collections.abc import Callable
+
+import torch
+from torch._inductor.fx_passes.bucketing import (
+    bucket_all_gather_by_mb,
+    bucket_reduce_scatter_by_mb,
+    BucketMode,
+    merge_all_gather,
+    merge_reduce_scatter,
+)
+
+
+logger: logging.Logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+def is_graph_input(node: torch.fx.Node) -> bool:
+    return node.op == "placeholder"
+
+
+def is_fsdp_all_gather_wait(wait: torch.fx.Node) -> bool:
+    # Assume all_gather_into_tensor input is either graph input
+    # or dtype conversion of graph input
+    ag_node = wait.args[0]  # type: ignore[arg-type, union-attr]
+    return (
+        is_graph_input(ag_node.args[0])  # type: ignore[arg-type, union-attr]
+        or (  # type: ignore[arg-type, union-attr]
+            ag_node.args[0].op == "call_function"  # type: ignore[arg-type, union-attr]
+            and ag_node.args[0].target  # type: ignore[arg-type, union-attr]
+            == torch.ops.prims.convert_element_type.default  # type: ignore[arg-type, union-attr]
+            and is_graph_input(ag_node.args[0].args[0])  # type: ignore[arg-type, union-attr]
+        )
+    )
+
+
+def is_graph_output(node: torch.fx.Node) -> bool:
+    return all(user.op == "output" for user in node.users)
+
+
+def is_fsdp_reduce_scatter_wait(wait: torch.fx.Node) -> bool:
+    if is_graph_output(wait):
+        return True
+
+    if len(wait.users) == 1:
+        user = next(iter(wait.users))
+        assert user is not None
+        return (
+            is_graph_output(user)
+            and user.op == "call_function"
+            and user.target is torch.ops.prims.convert_element_type.default
+        )
+
+    return False
+
+
+def bucket_fsdp_all_gather(
+    gm: torch.fx.GraphModule,
+    bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
+    mode: BucketMode = "default",
+) -> None:
+    """
+    Bucketing pass for SimpleFSDP all_gather ops.
+
+    Attributes:
+        gm (torch.fx.GraphModule): Graph module of the graph.
+        bucket_cap_mb_by_bucket_idx (Callable[[int], float] | None): callback function that
+            takes in bucket id and returns size of a bucket in megabytes.
+    """
+    if bucket_cap_mb_by_bucket_idx is None:
+        from torch._inductor.fx_passes.bucketing import (
+            bucket_cap_mb_by_bucket_idx_default,
+        )
+
+        bucket_cap_mb_by_bucket_idx = bucket_cap_mb_by_bucket_idx_default
+    assert bucket_cap_mb_by_bucket_idx is not None
+    ag_buckets = bucket_all_gather_by_mb(
+        gm,
+        bucket_cap_mb_by_bucket_idx,
+        filter_wait_node=is_fsdp_all_gather_wait,
+    )
+    if len(ag_buckets) == 0:
+        return
+    merge_all_gather(gm, ag_buckets, mode)
+
+
+def bucket_fsdp_reduce_scatter(
+    gm: torch.fx.GraphModule,
+    bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
+    mode: BucketMode = "default",
+) -> None:
+    """
+    Bucketing pass for SimpleFSDP reduce_scatter ops.
+
+    Attributes:
+        gm (torch.fx.GraphModule): Graph module of the graph.
+        bucket_cap_mb_by_bucket_idx (Callable[[int], float] | None): callback function that
+            takes in bucket idx and returns size of a bucket in megabytes. By default
+            torch._inductor.fx_passes.bucketing.bucket_cap_mb_by_bucket_idx_default is used.
+
+    """
+    if bucket_cap_mb_by_bucket_idx is None:
+        from torch._inductor.fx_passes.bucketing import (
+            bucket_cap_mb_by_bucket_idx_default,
+        )
+
+        bucket_cap_mb_by_bucket_idx = bucket_cap_mb_by_bucket_idx_default
+    rs_buckets = bucket_reduce_scatter_by_mb(
+        gm,
+        bucket_cap_mb_by_bucket_idx,
+        filter_wait_node=is_fsdp_reduce_scatter_wait,
+    )
+    if len(rs_buckets) == 0:
+        return
+    merge_reduce_scatter(gm, rs_buckets, mode)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/fuse_attention.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/fuse_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a09d2531348849ed997fc762aef44a09c43e6a9
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/fuse_attention.py
@@ -0,0 +1,1152 @@
+# mypy: allow-untyped-defs
+import functools
+import inspect
+import logging
+import math
+
+import torch
+
+from ..._dynamo.utils import counters
+from ..pattern_matcher import (
+    filter_nodes,
+    fwd_only,
+    gen_register_replacement,
+    joint_fwd_bwd,
+)
+
+
+log = logging.getLogger(__name__)
+aten = torch.ops.aten
+
+_scaled_dot_product_attention = aten.scaled_dot_product_attention
+
+
+def _sfdp_pattern_1(query, key, value, inv_scale):
+    return (
+        torch.matmul(query, key.transpose(-2, -1))
+        .div(inv_scale)
+        .softmax(dim=-1)
+        .matmul(value)
+    )
+
+
+def _sfdp_replacement_1(query, key, value, inv_scale):
+    counters["inductor"]["fuse_attention"] += 1
+    return _scaled_dot_product_attention(
+        query,
+        key,
+        value,
+        attn_mask=None,
+        dropout_p=0.0,
+        is_causal=False,
+        scale=1.0 / inv_scale,
+    )
+
+
+def _sfdp_pattern_2(query, key, value, scale_factor):
+    return (
+        torch.matmul(query, key.transpose(-2, -1))
+        .mul(scale_factor)
+        .softmax(dim=-1)
+        .matmul(value)
+    )
+
+
+def _sfdp_replacement_2(query, key, value, scale_factor):
+    counters["inductor"]["fuse_attention"] += 1
+    return _scaled_dot_product_attention(
+        query,
+        key,
+        value,
+        attn_mask=None,
+        dropout_p=0.0,
+        is_causal=False,
+        scale=scale_factor,
+    )
+
+
+def _sfdp_pattern_3(query, key, value, inv_scale_factor, dropout_p):
+    return torch.nn.functional.dropout(
+        torch.matmul(query, key.transpose(-2, -1))
+        .div(inv_scale_factor)
+        .softmax(dim=-1),
+        p=dropout_p,
+    ).matmul(value)
+
+
+def _sfdp_replacement_3(query, key, value, inv_scale_factor, dropout_p):
+    counters["inductor"]["fuse_attention"] += 1
+    return _scaled_dot_product_attention(
+        query,
+        key,
+        value,
+        attn_mask=None,
+        dropout_p=dropout_p,
+        is_causal=False,
+        scale=1.0 / inv_scale_factor,
+    )
+
+
+def _sfdp_pattern_4(query, key, value, scale_factor, dropout_p):
+    return torch.nn.functional.dropout(
+        torch.matmul(query, key.transpose(-2, -1)).mul(scale_factor).softmax(dim=-1),
+        p=dropout_p,
+    ).matmul(value)
+
+
+def _sfdp_replacement_4(query, key, value, scale_factor, dropout_p):
+    counters["inductor"]["fuse_attention"] += 1
+    return _scaled_dot_product_attention(
+        query,
+        key,
+        value,
+        attn_mask=None,
+        dropout_p=dropout_p,
+        is_causal=False,
+        scale=scale_factor,
+    )
+
+
+def _sfdp_pattern_5(query, key, value, attn_mask):
+    attn_weight = torch.softmax(
+        (query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask, dim=-1
+    )
+    # attn_weight = torch.dropout(attn_weight, dropout_p)
+    return attn_weight @ value
+
+
+def _sfdp_replacement_5(query, key, value, attn_mask):
+    counters["inductor"]["fuse_attention"] += 1
+    return _scaled_dot_product_attention(
+        query,
+        key,
+        value,
+        attn_mask=attn_mask.to(dtype=query.dtype),
+        dropout_p=0.0,
+        is_causal=False,
+    )
+
+
+def _sfdp_pattern_6(query, key, value, attn_mask, dropout_p):
+    attn_weight = torch.softmax(
+        (query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask, dim=-1
+    )
+    attn_weight = torch.dropout(attn_weight, dropout_p, True)
+    return attn_weight @ value
+
+
+def _sfdp_replacement_6(query, key, value, attn_mask, dropout_p):
+    counters["inductor"]["fuse_attention"] += 1
+    return _scaled_dot_product_attention(
+        query,
+        key,
+        value,
+        attn_mask=attn_mask.to(dtype=query.dtype),
+        dropout_p=dropout_p,
+        is_causal=False,
+    )
+
+
+def _sfdp_pattern_7(query, key, value, dropout_p):
+    # in real workloads inputs to matmul are permuted
+    # causing matmul to expand to a series of expand and clone calls
+    # we want the same to happen during pattern tracing
+    q = query.permute(0, 2, 1, 3)
+    k = key.permute(0, 2, 1, 3)
+    v = value.permute(0, 2, 1, 3)
+    div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
+    div = div.to(torch.float32)
+    attn_weight = torch.softmax(div, dim=-1)
+    attn_weight = torch.dropout(attn_weight, dropout_p, True)
+    attn_weight = attn_weight.to(torch.float16)
+    return attn_weight @ v
+
+
+def _sfdp_replacement_7(query, key, value, dropout_p):
+    # sdpa prefers inputs in permuted format
+    # it makes a copy to put them in this format
+    # if they aren't already
+    # to make replacement efficient ensure that inputs to sdpa
+    # are in required order
+    counters["inductor"]["fuse_attention"] += 1
+    q = query.permute(0, 2, 1, 3)
+    k = key.permute(0, 2, 1, 3)
+    v = value.permute(0, 2, 1, 3)
+    return _scaled_dot_product_attention(
+        q,
+        k,
+        v,
+        attn_mask=None,  # attn_mask,
+        dropout_p=dropout_p,
+        is_causal=False,
+    )
+
+
+def _sfdp_pattern_8(query, key, value):
+    # no dropout version of pattern 7
+    q = query.permute(0, 2, 1, 3)
+    k = key.permute(0, 2, 1, 3)
+    v = value.permute(0, 2, 1, 3)
+    div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
+    div = div.to(torch.float32)
+    attn_weight = torch.softmax(div, dim=-1)
+    attn_weight = attn_weight.to(torch.float16)
+    return attn_weight @ v
+
+
+def _sfdp_replacement_8(query, key, value):
+    counters["inductor"]["fuse_attention"] += 1
+    q = query.permute(0, 2, 1, 3)
+    k = key.permute(0, 2, 1, 3)
+    v = value.permute(0, 2, 1, 3)
+    return _scaled_dot_product_attention(
+        q,
+        k,
+        v,
+        attn_mask=None,  # attn_mask,
+        dropout_p=0.0,
+        is_causal=False,
+    )
+
+
+def _sfdp_pattern_9(query, key, value, dropout_p):
+    q = query.permute(0, 2, 1, 3)
+    k = key.permute(0, 2, 1, 3)
+    v = value.permute(0, 2, 1, 3)
+    q = q / math.sqrt(q.size(-1))
+    div = q @ k.transpose(-2, -1)
+    div = div.to(torch.float32)
+    attn_weight = torch.softmax(div, dim=-1)
+    attn_weight = torch.dropout(attn_weight, dropout_p, True)
+    attn_weight = attn_weight.to(torch.float16)
+    return attn_weight @ v
+
+
+def _sfdp_replacement_9(query, key, value, dropout_p):
+    counters["inductor"]["fuse_attention"] += 1
+    q = query.permute(0, 2, 1, 3)
+    k = key.permute(0, 2, 1, 3)
+    v = value.permute(0, 2, 1, 3)
+    return _scaled_dot_product_attention(
+        q,
+        k,
+        v,
+        attn_mask=None,  # attn_mask,
+        dropout_p=dropout_p,
+        is_causal=False,
+    )
+
+
+def _sfdp_pattern_10(query, key, value):
+    # no dropout version of 9
+    q = query.permute(0, 2, 1, 3)
+    k = key.permute(0, 2, 1, 3)
+    v = value.permute(0, 2, 1, 3)
+    q = q / math.sqrt(q.size(-1))
+    div = q @ k.transpose(-2, -1)
+    div = div.to(torch.float32)
+    attn_weight = torch.softmax(div, dim=-1)
+    attn_weight = attn_weight.to(torch.float16)
+    return attn_weight @ v
+
+
+def _sfdp_replacement_10(query, key, value):
+    counters["inductor"]["fuse_attention"] += 1
+    q = query.permute(0, 2, 1, 3)
+    k = key.permute(0, 2, 1, 3)
+    v = value.permute(0, 2, 1, 3)
+    return _scaled_dot_product_attention(
+        q,
+        k,
+        v,
+        attn_mask=None,  # attn_mask,
+        dropout_p=0.0,
+        is_causal=False,
+    )
+
+
+def _sfdp_pattern_11(query, key, value, inv_scale):
+    # Mainly for huggingface models
+    q = query.permute(0, 2, 1, 3)
+    k = key.permute(0, 2, 1, 3)
+    v = value.permute(0, 2, 1, 3)
+    return torch.matmul(q, k.transpose(-2, -1)).div(inv_scale).softmax(dim=-1).matmul(v)
+
+
+def _sfdp_replacement_11(query, key, value, inv_scale):
+    counters["inductor"]["fuse_attention"] += 1
+    return _scaled_dot_product_attention(
+        query.transpose(1, 2),
+        key.transpose(1, 2),
+        value.transpose(1, 2),
+        attn_mask=None,
+        dropout_p=0.0,
+        is_causal=False,
+        scale=1.0 / inv_scale,
+    )
+
+
+def _sfdp_pattern_12(query, key, value, inv_scale_factor, dropout_p):
+    q = query.permute(0, 2, 1, 3)
+    k = key.permute(0, 2, 1, 3)
+    v = value.permute(0, 2, 1, 3)
+    return torch.nn.functional.dropout(
+        torch.matmul(q, k.transpose(-2, -1)).div(inv_scale_factor).softmax(dim=-1),
+        p=dropout_p,
+    ).matmul(v)
+
+
+def _sfdp_replacement_12(query, key, value, inv_scale_factor, dropout_p):
+    counters["inductor"]["fuse_attention"] += 1
+    return _scaled_dot_product_attention(
+        query.transpose(1, 2),
+        key.transpose(1, 2),
+        value.transpose(1, 2),
+        attn_mask=None,
+        dropout_p=dropout_p,
+        is_causal=False,
+        scale=1.0 / inv_scale_factor,
+    )
+
+
+def _sfdp_pattern_13(query, key, value, dropout_p):
+    attn_weight = torch.bmm(query, key.transpose(1, 2)).softmax(dim=-1)
+    attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p)
+    return torch.bmm(attn_weight, value)
+
+
+def _sfdp_replacement_13(query, key, value, dropout_p):
+    counters["inductor"]["fuse_attention"] += 1
+    return _scaled_dot_product_attention(
+        query.unsqueeze(0),
+        key.unsqueeze(0),
+        value.unsqueeze(0),
+        dropout_p=dropout_p,
+        scale=1.0,
+    ).squeeze(0)
+
+
+def _sfdp_pattern_14(query, key, value, attn_mask, inv_scale):
+    # for BertLarge
+    # Permutations are needed to create clones in graph.
+    q = query.permute([0, 2, 1, 3])
+    k = key.permute([0, 2, 1, 3])
+    v = value.permute([0, 2, 1, 3])
+    return (
+        (torch.matmul(q, k.transpose(-2, -1)).div(inv_scale) + attn_mask)
+        .softmax(dim=-1)
+        .matmul(v)
+    )
+
+
+def _sfdp_replacement_14(query, key, value, attn_mask, inv_scale):
+    counters["inductor"]["fuse_attention"] += 1
+    return _scaled_dot_product_attention(
+        query.transpose(1, 2),
+        key.transpose(1, 2),
+        value.transpose(1, 2),
+        attn_mask=attn_mask.to(dtype=query.dtype),
+        dropout_p=0.0,
+        is_causal=False,
+        scale=1.0 / inv_scale,
+    )
+
+
+def _sfdp_pattern_15(query, key, value, attn_mask, inv_scale):
+    # for DistilBert
+    # Permutations are needed to create clones in graph.
+    # Ref: https://github.com/pytorch/pytorch/issues/119911
+    q = query.permute([0, 2, 1, 3])
+    k = key.permute([0, 2, 1, 3])
+    v = value.permute([0, 2, 1, 3])
+    bs = q.size(0)
+    k_len = k.size(-2)
+    scores = q @ k.transpose(-2, -1)
+    scores = scores.div(inv_scale)
+    fill_value = torch.full((), -float("inf"), dtype=query.dtype, device=query.device)
+    attn_mask = (attn_mask == 0).view((bs, 1, 1, k_len)).expand_as(scores)
+    return torch.softmax(scores.masked_fill(attn_mask, fill_value), dim=-1) @ v
+
+
+def _sfdp_replacement_15(query, key, value, attn_mask, inv_scale):
+    counters["inductor"]["fuse_attention"] += 1
+    bs = query.size(0)
+    n_head = query.size(2)
+    q_len = query.size(1)
+    k_len = key.size(1)
+    # do attn_mask->logical_not() in _scaled_dot_product_attention
+    attn_mask = (
+        (attn_mask == 1).view((bs, 1, 1, k_len)).expand((bs, n_head, q_len, k_len))
+    )
+    return _scaled_dot_product_attention(
+        query.transpose(1, 2),
+        key.transpose(1, 2),
+        value.transpose(1, 2),
+        attn_mask=attn_mask.to(dtype=torch.bool),
+        dropout_p=0.0,
+        is_causal=False,
+        scale=1.0 / inv_scale,
+    )
+
+
+def _sfdp_pattern_16(query, key, value, attn_mask, inv_scale, dropout_p):
+    # for BertLarge with dropout
+    q = query.permute([0, 2, 1, 3])
+    k = key.permute([0, 2, 1, 3])
+    v = value.permute([0, 2, 1, 3])
+    return (
+        torch.nn.functional.dropout(
+            (torch.matmul(q, k.transpose(-2, -1)).div(inv_scale) + attn_mask).softmax(
+                dim=-1
+            ),
+            dropout_p,
+        )
+        .to(dtype=query.dtype)
+        .matmul(v)
+    )
+
+
+def _sfdp_replacement_16(query, key, value, attn_mask, inv_scale, dropout_p):
+    counters["inductor"]["fuse_attention"] += 1
+    return _scaled_dot_product_attention(
+        query.transpose(1, 2),
+        key.transpose(1, 2),
+        value.transpose(1, 2),
+        attn_mask=attn_mask.to(dtype=query.dtype),
+        dropout_p=dropout_p,
+        is_causal=False,
+        scale=1.0 / inv_scale,
+    )
+
+
+def _sfdp_pattern_17(query, key, value, attn_mask, inv_scale, dropout_p):
+    # for DistilBert with dropout
+    q = query.permute([0, 2, 1, 3])
+    k = key.permute([0, 2, 1, 3])
+    v = value.permute([0, 2, 1, 3])
+    bs = q.size(0)
+    k_len = k.size(-2)
+    scores = q @ k.transpose(-2, -1)
+    scores = scores.div(inv_scale)
+    fill_value = torch.full((), -float("inf"), dtype=query.dtype, device=query.device)
+    attn_mask = (attn_mask == 0).view((bs, 1, 1, k_len)).expand_as(scores)
+    return (
+        torch.nn.functional.dropout(
+            torch.softmax(scores.masked_fill(attn_mask, fill_value), dim=-1), dropout_p
+        )
+        @ v
+    )
+
+
+def _sfdp_replacement_17(query, key, value, attn_mask, inv_scale, dropout_p):
+    counters["inductor"]["fuse_attention"] += 1
+    bs = query.size(0)
+    n_head = query.size(2)
+    q_len = query.size(1)
+    k_len = key.size(1)
+    # do attn_mask->logical_not() in _scaled_dot_product_attention
+    attn_mask = (
+        (attn_mask == 1).view((bs, 1, 1, k_len)).expand((bs, n_head, q_len, k_len))
+    )
+    return _scaled_dot_product_attention(
+        query.transpose(1, 2),
+        key.transpose(1, 2),
+        value.transpose(1, 2),
+        attn_mask=attn_mask.to(dtype=torch.bool),
+        dropout_p=dropout_p,
+        is_causal=False,
+        scale=1.0 / inv_scale,
+    )
+
+
+def _sfdp_pattern_18(query, key, value, causal_mask, dropout_p):
+    # for hf_GPT2 with dropout (introduces clone node) for inference
+    # it also returns permuted key & value
+    query = query.permute([0, 2, 1, 3])
+    key = key.permute([0, 2, 1, 3])
+    value = value.permute([0, 2, 1, 3])
+    attn_weights = torch.matmul(query, key.permute(0, 1, 3, 2))
+    inv_scale = torch.full(
+        [],
+        value.size(-1) ** 0.5,
+        dtype=attn_weights.dtype,
+        device=attn_weights.device,
+    )
+    attn_weights = attn_weights.div(inv_scale)
+    causal_mask_value = torch.full(
+        (), torch.finfo(query.dtype).min, dtype=query.dtype, device=query.device
+    )
+    attn_weights = torch.where(causal_mask, attn_weights, causal_mask_value)
+    return (
+        (
+            torch.nn.functional.dropout(attn_weights.softmax(dim=-1), dropout_p).matmul(
+                value
+            )
+        ),
+        key,
+        value,
+    )
+
+
+def _sfdp_replacement_18(query, key, value, causal_mask, dropout_p):
+    counters["inductor"]["fuse_attention"] += 1
+    permuted_key = key.transpose(1, 2)
+    permuted_value = value.transpose(1, 2)
+    return (
+        _scaled_dot_product_attention(
+            query.transpose(1, 2),
+            permuted_key,
+            permuted_value,
+            attn_mask=causal_mask,
+            dropout_p=dropout_p,
+            is_causal=False,
+            scale=1.0 / math.sqrt(value.size(-1)),
+        ),
+        permuted_key,
+        permuted_value,
+    )
+
+
+def _sfdp_pattern_19(query, key, value, causal_mask, attn_mask, dropout_p):
+    # for token-classification+gpt2 / text-generation+gpt2
+    attn_weights = torch.matmul(query, key.permute(0, 1, 3, 2))
+    inv_scale = torch.full(
+        [],
+        value.size(-1) ** 0.5,
+        dtype=attn_weights.dtype,
+        device=attn_weights.device,
+    )
+    attn_weights = attn_weights.div(inv_scale)
+    causal_mask_value = torch.full(
+        (), torch.finfo(query.dtype).min, dtype=query.dtype, device=query.device
+    )
+    attn_weights = torch.where(causal_mask, attn_weights, causal_mask_value)
+    attn_weights = attn_weights + attn_mask
+    attn_weights = attn_weights.softmax(dim=-1).type(value.dtype)
+    return torch.nn.functional.dropout(attn_weights, dropout_p).matmul(value)
+
+
+def _sfdp_replacement_19(query, key, value, causal_mask, attn_mask, dropout_p):
+    counters["inductor"]["fuse_attention"] += 1
+    fill_value = torch.full((), -float("inf"), dtype=query.dtype, device=query.device)
+    attn_mask = torch.where(causal_mask, attn_mask, fill_value)
+    return _scaled_dot_product_attention(
+        query,
+        key,
+        value,
+        attn_mask=attn_mask,
+        dropout_p=dropout_p,
+        is_causal=False,
+        scale=1.0 / math.sqrt(value.size(-1)),
+    )
+
+
+def _sfdp_pattern_20(query, key, value, attn_mask, dropout_p):
+    # for DistilBert with dropout transformers==4.44.2
+    q = query.permute([0, 2, 1, 3])
+    k = key.permute([0, 2, 1, 3])
+    v = value.permute([0, 2, 1, 3])
+    bs = q.size(0)
+    k_len = k.size(-2)
+    q = q.div(math.sqrt(q.size(-1)))
+    scores = q @ k.transpose(-2, -1)
+    fill_value = torch.full((), -float("inf"), dtype=query.dtype, device=query.device)
+    attn_mask = (attn_mask == 0).view((bs, 1, 1, k_len)).expand_as(scores)
+    return (
+        torch.nn.functional.dropout(
+            torch.softmax(scores.masked_fill(attn_mask, fill_value), dim=-1), dropout_p
+        )
+        @ v
+    )
+
+
+def _sfdp_replacement_20(query, key, value, attn_mask, dropout_p):
+    counters["inductor"]["fuse_attention"] += 1
+    bs = query.size(0)
+    n_head = query.size(2)
+    q_len = query.size(1)
+    k_len = key.size(1)
+    # do attn_mask->logical_not() in _scaled_dot_product_attention
+    attn_mask = (
+        (attn_mask == 1).view((bs, 1, 1, k_len)).expand((bs, n_head, q_len, k_len))
+    )
+    return _scaled_dot_product_attention(
+        query.transpose(1, 2),
+        key.transpose(1, 2),
+        value.transpose(1, 2),
+        attn_mask=attn_mask.to(dtype=torch.bool),
+        dropout_p=dropout_p,
+        is_causal=False,
+        scale=1.0 / math.sqrt(query.size(-1)),
+    )
+
+
+def _sfdp_pattern_21(query, key, value, attn_mask):
+    # for T5 with inplace add
+    query = query.permute([0, 2, 1, 3])
+    key = key.permute([0, 2, 1, 3])
+    value = value.permute([0, 2, 1, 3])
+    score = torch.matmul(query, key.permute(0, 1, 3, 2))
+    masked_score = score + attn_mask
+    score = masked_score.type_as(query)
+    viewd_score1 = score.view(
+        score.size(0) * score.size(1), score.size(2), score.size(3)
+    )
+    viewd_score2 = viewd_score1.view(
+        score.size(0), score.size(1), score.size(2), score.size(3)
+    )
+    return viewd_score2.float().softmax(dim=-1).type_as(query).matmul(value)
+
+
+def _sfdp_replacement_21(query, key, value, attn_mask):
+    counters["inductor"]["fuse_attention"] += 1
+    query = query.permute(0, 2, 1, 3)
+    key = key.permute(0, 2, 1, 3)
+    value = value.permute(0, 2, 1, 3)
+    return _scaled_dot_product_attention(
+        query,
+        key,
+        value,
+        attn_mask=attn_mask.to(dtype=query.dtype),
+        is_causal=False,
+        scale=1.0,
+    )
+
+
+def _sfdp_pattern_22(query, key, value, attn_mask):
+    # for T5 with inplace add and return key and value
+    query = query.permute([0, 2, 1, 3])
+    key = key.permute([0, 2, 1, 3])
+    value = value.permute([0, 2, 1, 3])
+    score = torch.matmul(query, key.permute(0, 1, 3, 2))
+    masked_score = score + attn_mask
+    score = masked_score.type_as(query)
+    viewd_score1 = score.view(
+        score.size(0) * score.size(1), score.size(2), score.size(3)
+    )
+    viewd_score2 = viewd_score1.view(
+        score.size(0), score.size(1), score.size(2), score.size(3)
+    )
+    return viewd_score2.float().softmax(dim=-1).type_as(query).matmul(value), key, value
+
+
+def _sfdp_replacement_22(query, key, value, attn_mask):
+    counters["inductor"]["fuse_attention"] += 1
+    query = query.permute(0, 2, 1, 3)
+    key = key.permute(0, 2, 1, 3)
+    value = value.permute(0, 2, 1, 3)
+    return (
+        _scaled_dot_product_attention(
+            query,
+            key,
+            value,
+            attn_mask=attn_mask.to(dtype=query.dtype),
+            is_causal=False,
+            scale=1.0,
+        ),
+        key,
+        value,
+    )
+
+
+def _sfdp_pattern_23(query, key, value):
+    # for T5 with inplace add and
+    # return key and value and
+    # attn_mask is generated by atem.full(..., 0)
+    query = query.permute([0, 2, 1, 3])
+    key = key.permute([0, 2, 1, 3])
+    value = value.permute([0, 2, 1, 3])
+    score = torch.matmul(query, key.permute(0, 1, 3, 2))
+    fp32_score = score.float()
+    score = fp32_score.type_as(query)
+    viewd_score1 = score.view(
+        score.size(0) * score.size(1), score.size(2), score.size(3)
+    )
+    viewd_score2 = viewd_score1.view(
+        score.size(0), score.size(1), score.size(2), score.size(3)
+    )
+    return viewd_score2.float().softmax(dim=-1).type_as(query).matmul(value), key, value
+
+
+def _sfdp_replacement_23(query, key, value):
+    counters["inductor"]["fuse_attention"] += 1
+    query = query.permute(0, 2, 1, 3)
+    key = key.permute(0, 2, 1, 3)
+    value = value.permute(0, 2, 1, 3)
+    return (
+        _scaled_dot_product_attention(
+            query,
+            key,
+            value,
+            attn_mask=None,
+            is_causal=False,
+            scale=1.0,
+        ),
+        key,
+        value,
+    )
+
+
+def _sfdp_pattern_24(query, key, value, attention_mask):
+    """
+    this pattern is for MBartForCausalLM/PLBartForCausalLM.
+    attn_mask has a different dtype with QKV.
+    there is no scale in sdpa.
+    """
+    bs = query.size(0)
+    n_head = query.size(1)
+    seq_len = query.size(2)
+    head_size = query.size(3)
+    q = query.view(bs * n_head, -1, head_size)
+    k = key.reshape(bs * n_head, -1, head_size)
+    v = value.reshape(bs * n_head, -1, head_size)
+    attn_weights = torch.bmm(q, k.transpose(1, 2))
+    attn_weights = attn_weights.view(bs, n_head, seq_len, -1) + attention_mask
+    attn_weights = attn_weights.view(bs * n_head, seq_len, -1)
+    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
+    if query.dtype == torch.half:
+        attn_weights = attn_weights.to(torch.half)
+    attn_output = torch.bmm(attn_weights, v)
+    attn_output = attn_output.view(bs, n_head, seq_len, head_size)
+    return attn_output
+
+
+def _sfdp_replacement_24(query, key, value, attention_mask):
+    counters["inductor"]["fuse_attention"] += 1
+    return _scaled_dot_product_attention(
+        query,
+        key,
+        value,
+        attn_mask=attention_mask.to(dtype=query.dtype),
+        is_causal=False,
+        scale=1,
+    )
+
+
+def _sfdp_params_check(match):
+    assert all(k in match.kwargs for k in ("query", "key", "value"))
+    query = match.kwargs["query"].meta["val"]
+    key = match.kwargs["key"].meta["val"]
+    value = match.kwargs["value"].meta["val"]
+    if not (query.dtype == key.dtype == value.dtype) or not (
+        query.device == key.device == value.device
+    ):
+        return False
+    add_mask_node = filter_nodes(match.nodes, aten.add.Tensor)
+    # Has attn_mask add.
+    if len(add_mask_node) > 0:
+        attn_mask_node = add_mask_node[0].args[1]
+        # attn_mask_node may be a float/int number.
+        if not hasattr(attn_mask_node, "meta"):
+            return False
+        attn_mask = attn_mask_node.meta["val"]  # type: ignore[union-attr]
+        # Make sure attn_mask.dtype == query.dtype or attn_mask.dtype == torch.bool
+        # attn_mask.dtype == torch.float for models like albert.
+        if (
+            not isinstance(attn_mask, torch.Tensor)
+            or not (
+                attn_mask.dtype == query.dtype
+                or attn_mask.dtype == torch.bool
+                or attn_mask.dtype == torch.float
+            )
+            or query.device != attn_mask.device
+            # When we tensorify floats we end up turning floats
+            # into 0d scalar tensors. It doesn't make any sense
+            # to have a 0d scalar tensor attention mask so
+            # conveniently we can insert this check to get
+            # tests that erroneously passing in a float
+            # attention mask to fail as expected.
+            or attn_mask.dim() == 0
+        ):
+            return False
+    return True
+
+
+def _sfdp_extra_check(scale_factor_op=None, disable_cuda=False):
+    def fn(match):
+        if (
+            disable_cuda
+            and "query" in match.kwargs
+            and "cuda" in str(match.kwargs["query"].meta["val"].device)
+        ):
+            return False
+        if scale_factor_op is not None:
+            scale_factor_node = filter_nodes(match.nodes, scale_factor_op)[0]
+            # Note: args[1] of the scale_factor_node is always the scale_factor for the current patterns.
+            scale_factor = scale_factor_node.args[1]
+            # make sure the scale_factor a float/int. SymInt?
+            if not isinstance(scale_factor, (float, int)):
+                return False
+        return _sfdp_params_check(match)
+
+    return fn
+
+
+def partialize_and_update_signature(func, **kwargs):
+    """
+    Equivalent to functools.partial but also updates the signature on returned function
+    """
+    original_sig = inspect.signature(func)
+    parameters = original_sig.parameters
+
+    new_parameters = {
+        key: value for key, value in parameters.items() if key not in kwargs
+    }
+    new_sig = inspect.Signature(parameters=list(new_parameters.values()))
+
+    partial_func = functools.partial(func, **kwargs)
+
+    def wrapper(*args, **kwargs):
+        return partial_func(*args, **kwargs)
+
+    wrapper.__signature__ = new_sig  # type: ignore[attr-defined]
+    wrapper.__name__ = func.__name__
+
+    return wrapper
+
+
+def _get_sfdp_patterns():
+    from .joint_graph import patterns
+
+    if torch.cuda.is_available():
+        # workaround https://github.com/pytorch/pytorch/issues/97894
+        device = "cuda"
+    else:
+        device = "cpu"
+
+    # sizes/values don't actually matter for initial trace
+    # once we get a possible match we re-trace with the actual values and verify the match still holds
+    g_inp = functools.partial(
+        torch.empty, (2, 4, 8, 16), device=device, requires_grad=True
+    )
+    # attn_mask
+    b_inp = functools.partial(torch.empty, (1, 1, 8, 8), device=device)
+    m_inp = functools.partial(torch.empty, (2, 1, 1, 4), device=device)
+    # need 2d attn_mask to generate patterns with view op
+    m_inp_2d = functools.partial(torch.empty, (2, 4), device=device)
+    # inv_scale
+    c_inp = functools.partial(torch.tensor, 2.0, device=device)
+    # workaround https://github.com/pytorch/pytorch/issues/97894
+    # 0.113377 is a "magic" value that lets us recover the lost input arg relationship
+    d = {"dropout_p": 0.113377}
+
+    # we could also generate all these patterns in 3d.. TODO
+    g_3d_inp = functools.partial(
+        torch.empty, (1024, 128, 128), device=device, requires_grad=True
+    )
+
+    # reshape in matmul decomposition generates a clone when batch_size>1 due to the memory layout change.
+    # however when batch_size=1, reshape does not change the memory layout, so clone would not be generated.
+    # here we need to trace with input of batch_size=1 to generate a pattern graph without clone.
+    g_bs1_inp = functools.partial(
+        torch.empty, (1, 4, 8, 16), device=device, requires_grad=True
+    )
+    m_bs1_inp = functools.partial(torch.empty, (1, 1, 1, 4), device=device)
+
+    # softmax will generate a dtype conversion on inputs if they are in half,
+    # but will not in float, so we generate a pattern for both
+    for dtype in [torch.float, torch.half]:
+        g = functools.partial(g_inp, dtype=dtype)
+        b = functools.partial(b_inp, dtype=dtype)
+        b_float = functools.partial(b_inp, dtype=torch.float)
+        b_bool = functools.partial(b_inp, dtype=torch.bool)
+        m = functools.partial(m_inp, dtype=dtype)
+        m_float = functools.partial(m_inp, dtype=torch.float)
+        m_bool = functools.partial(m_inp, dtype=torch.bool)
+        m_2d = functools.partial(m_inp_2d, dtype=dtype)
+        c = functools.partial(c_inp, dtype=dtype)
+        g_3d = functools.partial(g_3d_inp, dtype=dtype)
+        g_bs1 = functools.partial(g_bs1_inp, dtype=dtype)
+        m_bs1 = functools.partial(m_bs1_inp, dtype=dtype)
+        m_bs1_float = functools.partial(m_bs1_inp, dtype=torch.float)
+        m_bs1_bool = functools.partial(m_bs1_inp, dtype=torch.bool)
+
+        candidates = [
+            (
+                _sfdp_pattern_1,
+                _sfdp_replacement_1,
+                [g(), g(), g(), c()],
+                {},
+                _sfdp_extra_check(aten.div.Tensor),
+            ),
+            (
+                _sfdp_pattern_2,
+                _sfdp_replacement_2,
+                [g(), g(), g(), c()],
+                {},
+                _sfdp_extra_check(aten.mul.Tensor),
+            ),
+            (
+                _sfdp_pattern_3,
+                _sfdp_replacement_3,
+                [g(), g(), g(), c()],
+                d,
+                _sfdp_extra_check(aten.div.Tensor),
+            ),
+            (
+                _sfdp_pattern_4,
+                _sfdp_replacement_4,
+                [g(), g(), g(), c()],
+                d,
+                _sfdp_extra_check(aten.mul.Tensor),
+            ),
+            (
+                _sfdp_pattern_5,
+                _sfdp_replacement_5,
+                [g(), g(), g(), b()],
+                {},
+                _sfdp_params_check,
+            ),
+            (
+                _sfdp_pattern_6,
+                _sfdp_replacement_6,
+                [g(), g(), g(), b()],
+                d,
+                _sfdp_params_check,
+            ),
+            (
+                _sfdp_pattern_7,
+                _sfdp_replacement_7,
+                [g(), g(), g()],
+                d,
+                _sfdp_params_check,
+            ),
+            (
+                _sfdp_pattern_8,
+                _sfdp_replacement_8,
+                [g(), g(), g()],
+                {},
+                _sfdp_params_check,
+            ),
+            (
+                _sfdp_pattern_9,
+                _sfdp_replacement_9,
+                [g(), g(), g()],
+                d,
+                _sfdp_params_check,
+            ),
+            (
+                _sfdp_pattern_10,
+                _sfdp_replacement_10,
+                [g(), g(), g()],
+                {},
+                _sfdp_params_check,
+            ),
+            (
+                _sfdp_pattern_11,
+                _sfdp_replacement_11,
+                [g(), g(), g(), c()],
+                {},
+                _sfdp_extra_check(aten.div.Tensor),
+            ),
+            (
+                _sfdp_pattern_12,
+                _sfdp_replacement_12,
+                [g(), g(), g(), c()],
+                d,
+                _sfdp_extra_check(aten.div.Tensor),
+            ),
+            (
+                _sfdp_pattern_13,
+                _sfdp_replacement_13,
+                [g_3d(), g_3d(), g_3d()],
+                d,
+                _sfdp_params_check,
+            ),
+            (
+                _sfdp_pattern_14,
+                _sfdp_replacement_14,
+                [g(), g(), g(), m(), c()],
+                {},
+                _sfdp_extra_check(aten.div.Tensor),
+            ),
+            (
+                _sfdp_pattern_15,
+                _sfdp_replacement_15,
+                [g(), g(), g(), m_2d(), c()],
+                {},
+                _sfdp_extra_check(aten.div.Tensor),
+            ),
+            # TODO: Enable CUDA after solving Bert accuracy issue of calling efficient attention
+            (
+                _sfdp_pattern_16,
+                _sfdp_replacement_16,
+                [g(), g(), g(), m(), c()],
+                d,
+                _sfdp_extra_check(aten.div.Tensor, disable_cuda=True),
+            ),
+            (
+                _sfdp_pattern_16,
+                _sfdp_replacement_16,
+                [g_bs1(), g_bs1(), g_bs1(), m_bs1(), c()],
+                d,
+                _sfdp_extra_check(aten.div.Tensor, disable_cuda=True),
+            ),
+            (
+                _sfdp_pattern_17,
+                _sfdp_replacement_17,
+                [g(), g(), g(), m_2d(), c()],
+                d,
+                _sfdp_extra_check(aten.div.Tensor),
+            ),
+            (
+                _sfdp_pattern_18,
+                _sfdp_replacement_18,
+                [g(), g(), g(), m_bool()],
+                d,
+                _sfdp_params_check,
+            ),
+            (
+                _sfdp_pattern_18,
+                _sfdp_replacement_18,
+                [g_bs1(), g_bs1(), g_bs1(), m_bs1_bool()],
+                d,
+                _sfdp_params_check,
+            ),
+            (
+                _sfdp_pattern_19,
+                _sfdp_replacement_19,
+                [g(), g(), g(), b_bool(), b_float()],
+                d,
+                _sfdp_params_check,
+            ),
+            (
+                _sfdp_pattern_20,
+                _sfdp_replacement_20,
+                [g(), g(), g(), m_2d()],
+                d,
+                _sfdp_extra_check(aten.div.Tensor),
+            ),
+            (
+                _sfdp_pattern_21,
+                _sfdp_replacement_21,
+                [g(), g(), g(), m_float()],
+                {},
+                _sfdp_params_check,
+            ),
+            (
+                _sfdp_pattern_21,
+                _sfdp_replacement_21,
+                [g_bs1(), g_bs1(), g_bs1(), m_bs1_float()],
+                {},
+                _sfdp_params_check,
+            ),
+            (
+                _sfdp_pattern_22,
+                _sfdp_replacement_22,
+                [g(), g(), g(), m_float()],
+                {},
+                _sfdp_params_check,
+            ),
+            (
+                _sfdp_pattern_22,
+                _sfdp_replacement_22,
+                [g_bs1(), g_bs1(), g_bs1(), m_bs1_float()],
+                {},
+                _sfdp_params_check,
+            ),
+            (
+                _sfdp_pattern_23,
+                _sfdp_replacement_23,
+                [g(), g(), g()],
+                {},
+                _sfdp_params_check,
+            ),
+            (
+                _sfdp_pattern_23,
+                _sfdp_replacement_23,
+                [g_bs1(), g_bs1(), g_bs1()],
+                {},
+                _sfdp_params_check,
+            ),
+            (
+                _sfdp_pattern_24,
+                _sfdp_replacement_24,
+                [g(), g(), g(), b_float()],
+                {},
+                _sfdp_extra_check,
+            ),
+        ]
+        mask_fp32_patterns = ["pattern_16"]
+        if dtype == torch.half:
+            # Add inputs of bf16 q/k/v and fp32 mask, for models like albert.
+            candidates.append(
+                (
+                    _sfdp_pattern_16,
+                    _sfdp_replacement_16,
+                    [g(), g(), g(), m_float(), c()],
+                    d,
+                    _sfdp_extra_check(aten.div.Tensor, disable_cuda=True),
+                )
+            )
+            candidates.append(
+                (
+                    _sfdp_pattern_16,
+                    _sfdp_replacement_16,
+                    [g_bs1(), g_bs1(), g_bs1(), m_bs1_float(), c()],
+                    d,
+                    _sfdp_extra_check(aten.div.Tensor, disable_cuda=True),
+                )
+            )
+
+        for pattern, replacement, args, workaround, extra_check in candidates:
+            # XXX: when adding a new pattern, re-run `gen_attention_patterns` so the pattern
+            # gets serialized to a python file and does not require tracing at runtime.
+            assert isinstance(workaround, dict)
+            name = pattern.__name__
+
+            if dtype != torch.float:
+                name += "_half"
+                if (
+                    any(p in name for p in mask_fp32_patterns)
+                    and args[3].dtype == torch.float32
+                ):
+                    name += "_mask_fp32"
+            if args[0].size(0) == 1:
+                name += "_bs1"
+
+            training_name = name + "_training"
+            yield (
+                training_name,
+                {
+                    "search_fn": pattern,
+                    "replace_fn": replacement,
+                    "example_inputs": args,
+                    "trace_fn": joint_fwd_bwd,
+                    "pass_dicts": patterns,
+                    "extra_check": extra_check,
+                    "scalar_workaround": workaround,
+                },
+            )
+
+            if workaround:
+                assert len(workaround) == 1 and "dropout_p" in workaround
+                # functools.partial insufficient because we look at signature downstream
+                pattern = partialize_and_update_signature(pattern, dropout_p=0.0)
+                replacement = partialize_and_update_signature(
+                    replacement, dropout_p=0.0
+                )
+                workaround = {}
+
+            inference_name = name + "_inference"
+            yield (
+                inference_name,
+                {
+                    "search_fn": pattern,
+                    "replace_fn": replacement,
+                    "example_inputs": args,
+                    "trace_fn": fwd_only,
+                    "pass_dicts": patterns,
+                    "extra_check": extra_check,
+                    "scalar_workaround": workaround,
+                    # with dropout turned into clone, we end up with a number of
+                    # semantically identical graphs
+                    "skip_duplicates": True,
+                },
+            )
+
+
+@functools.cache
+def _sfdp_init():
+    for key, register_replacement_kwargs in _get_sfdp_patterns():
+        gen_register_replacement(key, **register_replacement_kwargs)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/graph_view.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/graph_view.py
new file mode 100644
index 0000000000000000000000000000000000000000..5758551a9b8a5cad4f2a5aa1a21357a9ab12cfcc
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/graph_view.py
@@ -0,0 +1,240 @@
+from __future__ import annotations
+
+import itertools
+import re
+from typing import Any, Optional, TYPE_CHECKING, Union
+
+import torch.fx as fx  # noqa: TC001
+from torch.utils._ordered_set import OrderedSet
+
+
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
+
+def _get_module_stack(node: fx.Node) -> list[tuple[str, type[Any]]]:
+    nn_stack = node.meta.get("nn_module_stack", "")
+    if nn_stack:
+        return list(nn_stack.values())
+
+    fwd_nn_stack = node.meta.get("fwd_nn_module_stack", "")
+    if fwd_nn_stack:
+        return list(fwd_nn_stack.values())
+
+    return []
+
+
+def _addindent(s_: str, num_spaces: int) -> str:
+    s: list[str] = s_.split("\n")
+    # don't do anything for single-line stuff
+    if len(s) == 1:
+        return s_
+    first: str = s.pop(0)
+    s: list[str] = [(num_spaces * " ") + line for line in s]
+    joint_s: str = "\n".join(s)
+    joint_s = first + "\n" + joint_s
+    return joint_s
+
+
+class GraphView:
+    """
+    A hierarchical class for organizing and managing torch.fx nodes by their module stack.
+
+    This class provides a tree-like structure where each node in the hierarchy corresponds
+    to a module or submodule in a traced FX graph. Each `GraphView` instance can hold a list
+    of FX nodes (`self.data`) belonging to that module scope, maintain a unique set of nodes
+    (`self.unique_nodes`), and manage its child containers (`self.children`).
+
+    Attributes:
+        name (str): The name of the module or container scope.
+        klass (type[Any]): The class type associated with this module/container.
+        data (list[fx.Node]): A list of FX graph nodes belonging to this module.
+        unique_nodes (OrderedSet[fx.Node]): A deduplicated set of nodes to ensure no duplicates.
+        children (dict[str, GraphView]): A mapping of child module names to their corresponding GraphView instances.
+    """
+
+    def __init__(self, name: str, klass: type[Any]) -> None:
+        self.name: str = name
+        self.klass: type[Any] = klass
+        self.data: list[fx.Node] = []
+        self.unique_nodes: OrderedSet[fx.Node] = OrderedSet()
+        self.children: dict[str, GraphView] = {}
+
+    def add(self, data: fx.Node) -> None:
+        if data not in self.unique_nodes:
+            self.data.append(data)
+            self.unique_nodes.add(data)
+
+    def get_child(
+        self, module_stack: str, klass: Optional[type[Any]] = None
+    ) -> GraphView:
+        if module_stack not in self.children:
+            new_stack = GraphView(module_stack, klass or self.klass)
+            self.children[module_stack] = new_stack
+        return self.children[module_stack]
+
+    def __getitem__(self, name: str) -> GraphView:
+        return self.children[name]
+
+    def __getattr__(self, name: str) -> GraphView:
+        return self.children[name]
+
+    def __repr__(self) -> str:
+        child_lines: list[str] = []
+        for name, child in self.children.items():
+            mod_str = repr(child)
+            mod_str = _addindent(mod_str, 2)
+            child_lines.append(f"({name}): {mod_str}")
+        main_str = f"{self.klass.__name__}("
+        if child_lines:
+            main_str += "\n  " + "\n  ".join(child_lines) + "\n"
+        main_str += ")"
+        return main_str
+
+
+def _clean_stack_name(stack_name: str) -> str:
+    """
+    Clean up FX node's nn_module_stack metadata string to match the module name hierarchies
+
+    Example:
+        Input: "L['self']._modules['layers']['0']._modules['attention']"
+        Output: "layers.0.attention"
+    """
+    cleaned = re.sub(r"^L\['self'\]\.?", "", stack_name)
+    parts = re.findall(r"\['([^']+)'\]", cleaned)
+    return ".".join(parts) if parts else cleaned
+
+
+def _is_root(stack: str) -> bool:
+    return stack == ""
+
+
+def make_graph_view(
+    graph: fx.Graph,
+    module_stack_fn: None | Callable[[fx.Node], list[tuple[str, type[Any]]]] = None,
+) -> Optional[GraphView]:
+    """
+    Code from: https://github.com/meta-pytorch/autoparallel/pull/158
+
+    Make a graph view from the fx.Graph. This is a tree structure that
+    represents the module hierarchy of the graph, and enables us to
+    easily find the nodes that belong to each module, and gives a slightly
+    easier way of visualize different parts of the graph by extracting
+    subgraphs that belong to a particular module FQN.
+
+    For example, if we have the following model with module hierarchy:
+
+    Transformer(
+        (tok_embeddings): Embedding(128256, 4096)
+        (layers): ModuleDict(
+            (0): TransformerBlock(
+            (attention): Attention(
+                (wq): Linear(in_features=4096, out_features=4096, bias=False)
+                (wk): Linear(in_features=4096, out_features=1024, bias=False)
+                (wv): Linear(in_features=4096, out_features=1024, bias=False)
+                (wo): Linear(in_features=4096, out_features=4096, bias=False)
+                (sdpa): ScaledDotProductAttention()
+            )
+            (feed_forward): FeedForward(
+                (w1): Linear(in_features=4096, out_features=14336, bias=False)
+                (w2): Linear(in_features=14336, out_features=4096, bias=False)
+                (w3): Linear(in_features=4096, out_features=14336, bias=False)
+            )
+            (attention_norm): RMSNorm((4096,), eps=1e-05, elementwise_affine=True)
+            (ffn_norm): RMSNorm((4096,), eps=1e-05, elementwise_affine=True)
+            )
+        )
+        (norm): RMSNorm((4096,), eps=1e-05, elementwise_affine=True)
+        (output): Linear(in_features=4096, out_features=128256, bias=False)
+    )
+
+    Then we can get a GraphView for the fx.Graph that enables us to do
+
+    graph_view = make_graph_view(graph)
+    subgraph = get_subgraph_by_path(graph_view, "layers.0")
+
+    where subgraph contains all the nodes that belong to this region
+
+    module_stack_fn: Optional callable for extracting module hierarchy information from nodes.
+
+        Signature: Callable[[fx.Node], list[tuple[str, type[Any]]]]
+
+        Takes an FX node and returns a list of (module_path, module_class) tuples representing
+        the nested module hierarchy for that node, ordered from outermost to innermost scope.
+
+        - module_path (str): Dot-separated path identifying the module in the hierarchy
+          (e.g., "layers.0.attention.wq")
+        - module_class (type): The Python class type of the module
+
+        This enables custom logic for determining module membership, useful for:
+        - Graphs without standard nn_module_stack metadata
+        - Filtering or grouping nodes by custom criteria
+
+        Example of getting the module stack from annotation:
+
+        def module_stack_fn(node):
+            module_stack = node.meta.get("custom", {}).get("module_path", "")
+            return [(module_stack, torch.nn.Module)]
+
+        If None, defaults to extracting from node.meta["nn_module_stack"] or
+        node.meta["fwd_nn_module_stack"].
+    """
+
+    def nn_module_stack_meta(node: fx.Node) -> list[tuple[str, type[Any]]]:
+        result = []
+        for module_stack, module_class in _get_module_stack(node):
+            module_stack = _clean_stack_name(module_stack)
+            result.append((module_stack, module_class))
+        return result
+
+    if module_stack_fn is None:
+        module_stack_fn = nn_module_stack_meta
+    nodes: list[fx.Node] = list(graph.nodes)
+    nodes_by_module_stack_root: GraphView | None = None
+    for node in nodes:
+        for module_stack, module_class in module_stack_fn(node):
+            nodes_by_module_stack: GraphView | None = nodes_by_module_stack_root
+            for name in module_stack.split("."):
+                if nodes_by_module_stack is None:
+                    nodes_by_module_stack = GraphView(name, module_class)
+                    nodes_by_module_stack_root = nodes_by_module_stack
+                if _is_root(module_stack):
+                    new_stack: GraphView = nodes_by_module_stack
+                else:
+                    new_stack = nodes_by_module_stack.get_child(name, module_class)
+                nodes_by_module_stack = new_stack
+                nodes_by_module_stack.add(node)
+
+    return nodes_by_module_stack_root
+
+
+def get_subgraph_by_path(
+    graph_view: GraphView, paths: Union[str, list[str]]
+) -> list[fx.Node]:
+    """
+    Get subgraph by path(s).
+    Args:
+        graph_view (object): Root graph view object.
+        paths (str or list of str): Path(s) to subgraph.
+    Returns:
+        list[fx.Node]: fx nodes belong to the subgraph
+    """
+
+    def get_node_by_path(node: GraphView, path: str) -> GraphView:
+        for p in path.split("."):
+            if p in node.children:
+                node = node.children[p]
+            else:
+                return GraphView("", object)
+        return node
+
+    if isinstance(paths, list):
+        nodes = list(
+            itertools.chain.from_iterable(
+                get_node_by_path(graph_view, p).data for p in paths
+            )
+        )
+        return nodes
+    else:
+        node = get_node_by_path(graph_view, paths)
+        return node.data
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/group_batch_fusion.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/group_batch_fusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..f46d4d3ba216f15da9464e6052c36bdaa8b7c68a
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/group_batch_fusion.py
@@ -0,0 +1,1440 @@
+# mypy: allow-untyped-defs
+import collections
+import logging
+import operator
+from collections import OrderedDict
+from collections.abc import Iterable, Iterator
+from typing import Any
+
+import torch
+from torch._dynamo.utils import counters, is_node_meta_valid
+from torch._logging import trace_structured
+from torch.fx.passes.graph_transform_observer import GraphTransformObserver
+from torch.utils._ordered_set import OrderedSet
+
+from .. import config
+from ..pattern_matcher import (
+    CallFunctionVarArgs,
+    get_arg_value,
+    stable_topological_sort,
+)
+from ..utils import OPTIMUS_EXCLUDE_POST_GRAD
+
+
+try:
+    # importing this will register fbgemm lowerings for inductor
+    import deeplearning.fbgemm.fbgemm_gpu.fb.inductor_lowerings  # noqa: F401
+
+    has_fbgemm = True
+except Exception:
+    has_fbgemm = False
+
+aten = torch.ops.aten
+
+log = logging.getLogger(__name__)
+
+DEFAULT_BETA = 1
+DEFAULT_ALPHA = 1
+
+MIN_FUSE_SET_SIZE = 5
+MAX_FUSE_SET_SIZE = 300
+MAX_FUSE_SEARCH_DEPTH = 5
+# The maximum tensor size that can go into the fusion group
+MAX_FUSE_TENSOR_SIZE_GROUP_LINEAR = 4096
+# Whether we only fuse nodes with same parent node
+FUSE_NODES_WITH_SAME_PARENT = False
+# Whether we enable the add broadcast in batch linear
+SHAPE_BROADCAST_BATCH_LINEAR = False
+# Whether we enable the fuse nodes with same users
+Fuse_NODES_WITH_SAME_USERS = False
+
+# exclude these nodes from BFS
+# excluding get item improves optimizer compilation time by 60s
+SEARCH_EXCLUSIONS = OrderedSet([operator.getitem])
+
+
+default_graph_search_options = {
+    "min_fuse_set_size": MIN_FUSE_SET_SIZE,
+    "max_fuse_set_size": MAX_FUSE_SET_SIZE,
+    "max_fuse_search_depth": MAX_FUSE_SEARCH_DEPTH,
+    "max_fuse_tensor_size_group_linear": MAX_FUSE_TENSOR_SIZE_GROUP_LINEAR,
+    "fuse_nodes_with_same_parent": FUSE_NODES_WITH_SAME_PARENT,
+    "shape_broadcast_batch_linear": SHAPE_BROADCAST_BATCH_LINEAR,
+    "fuse_nodes_with_same_users": Fuse_NODES_WITH_SAME_USERS,
+}
+
+graph_search_options = default_graph_search_options
+
+
+def update_stack_example_value(node, metadata, dim=0, op=torch.stack):
+    """
+    Update the example value of the node in the graph to enable followup split cat opt.
+    """
+    if node is not None and hasattr(node, "meta"):
+        if op is torch.stack:
+            example_value = torch.stack(metadata, dim=dim)
+        elif op is torch.unbind:
+            example_value = torch.unbind(metadata, dim=dim)  # type: ignore[assignment]
+        else:
+            return
+        node.meta["example_value"] = example_value
+
+
+def update_pointwise_example_value(pointwise_node, input, other, op):
+    """
+    Update the example value of the add node in the graph to enable followup split cat opt.
+    """
+    if pointwise_node is not None and hasattr(pointwise_node, "meta"):
+        if op is torch.add:
+            example_value = torch.add(input, other)
+        elif op is torch.mul:
+            example_value = torch.mul(input, other)
+        else:
+            return
+        pointwise_node.meta["example_value"] = example_value
+
+
+class GroupBatchFusionBase:
+    def __init__(self, **kwargs) -> None:
+        self.graph_search_options = kwargs.pop(
+            "graph_search_options", default_graph_search_options
+        )
+
+    def match(self, node):
+        raise NotImplementedError("match called on base")
+
+    def fuse(self, graph, subset):
+        raise NotImplementedError("fuse called on base")
+
+
+PRE_GRAD_FUSIONS: dict[str, GroupBatchFusionBase] = {}
+POST_GRAD_FUSIONS: dict[str, GroupBatchFusionBase] = {}
+
+
+def register_fusion(name: str, pre_grad=True):
+    def decorator(fusion_cls: GroupBatchFusionBase):
+        if pre_grad:
+            PRE_GRAD_FUSIONS[name] = fusion_cls
+        else:
+            POST_GRAD_FUSIONS[name] = fusion_cls
+        return fusion_cls
+
+    return decorator
+
+
+def list_group_batch_fusions(pre_grad=True) -> list[str]:
+    if pre_grad:
+        return list(PRE_GRAD_FUSIONS.keys())
+    else:
+        return list(POST_GRAD_FUSIONS.keys())
+
+
+def decompose_stack(graph: torch.fx.GraphModule, input_tensors: list[Any]) -> Any:
+    unsqueezed_inputs = []
+    unsqueezed_inputs_meta = []
+    for input_tensor in input_tensors:
+        unsqueezed_input = graph.call_function(  # type: ignore[operator]
+            aten.unsqueeze, args=(input_tensor,), kwargs={"dim": 0}
+        )
+        unsqueezed_inputs.append(unsqueezed_input)
+        unsqueezed_input.meta["val"] = aten.unsqueeze(input_tensor.meta["val"], dim=0)  # type: ignore[assignment]
+        unsqueezed_inputs_meta.append(unsqueezed_input.meta["val"])
+    stacked_inputs = graph.call_function(  # type: ignore[operator]
+        aten.cat, args=(unsqueezed_inputs,), kwargs={"dim": 0}
+    )
+    stacked_inputs.meta["val"] = aten.cat(unsqueezed_inputs_meta, dim=0)  # type: ignore[assignment]
+    return stacked_inputs
+
+
+class GroupFusion(GroupBatchFusionBase):
+    """
+    Fuse ops in a group way, e.g, fuse mm/addmm of arbitrary input shapes with fbgemm.gmm.
+    """
+
+
+class BatchFusion(GroupBatchFusionBase):
+    """
+    Fuse ops in a batch way, e.g, fuse mm/addmm of same input shapes with bmm.
+    """
+
+
+class BatchPointwiseOpsFusionFactory(BatchFusion):
+    def __init__(self, op, **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.op = op
+
+
+@register_fusion("batch_linear_post_grad", pre_grad=False)
+class PostGradBatchLinearFusion(BatchFusion):
+    """
+    Fuse ops in a batch way in post grad (aten level).
+    """
+
+    def _addmm_node_can_be_fused(self, node: torch.fx.Node) -> bool:
+        # pyre-fixme[7]: Incompatible return type
+        return (
+            node.kwargs.get("beta", DEFAULT_BETA) == DEFAULT_BETA
+            and node.kwargs.get("alpha", DEFAULT_ALPHA) == DEFAULT_ALPHA  # type: ignore[return-value]
+        )
+
+    def _is_input_2d(self, input: torch.fx.Node) -> bool:
+        input_shapes = input.meta["val"].shape
+        return (
+            len(input_shapes) == 2
+            and isinstance(input_shapes[0], int)
+            and isinstance(input_shapes[1], int)
+        )
+
+    def match(self, node: torch.fx.Node) -> tuple[str, int, int, int, bool, str] | None:
+        if CallFunctionVarArgs(aten.mm).match(node):
+            input_m, weight_m = node.args
+            bias_m = None
+
+        elif CallFunctionVarArgs(aten.addmm.default).match(
+            node
+        ) and self._addmm_node_can_be_fused(node):
+            bias_m, input_m, weight_m = node.args
+        else:
+            return None
+        # get the user of the node
+        if self.graph_search_options.get("fuse_nodes_with_same_users", False):
+            users = [user.target for user in node.users]
+        else:
+            users = ""  # type: ignore[assignment]
+        # only handle the cases where inputs are 2D tensors
+        if not self._is_input_2d(input_m) or not self._is_input_2d(weight_m):  # type: ignore[arg-type]
+            return None
+        m, k = input_m.meta["val"].shape  # type: ignore[union-attr]
+        n = weight_m.meta["val"].shape[1]  # type: ignore[union-attr]
+        batch_key = ("batch_linear_post_grad", m, k, n, bias_m is not None, str(users))
+        return batch_key
+
+    def fuse(self, graph: torch.fx.GraphModule, subset: list[torch.fx.Node]):
+        batch_inputs = []
+        batch_weights = []
+        batch_biases = []
+        batch_nodes = []
+        batch_inputs_meta = []
+        batch_weights_meta = []
+        batch_biases_meta = []
+
+        for node in subset:
+            if CallFunctionVarArgs(aten.addmm.default).match(node):
+                bias, input, weight = node.args
+            elif CallFunctionVarArgs(aten.mm.default).match(node):
+                input, weight = node.args
+                bias = None
+            batch_nodes.append(node)
+            batch_inputs.append(input)  # type: ignore[possibly-undefined]
+            batch_weights.append(weight)  # type: ignore[possibly-undefined]
+            batch_biases.append(bias)  # type: ignore[possibly-undefined]
+            batch_inputs_meta.append(input.meta)  # type: ignore[possibly-undefined, union-attr]
+            batch_weights_meta.append(weight.meta)  # type: ignore[possibly-undefined, union-attr]
+            if bias is not None:  # type: ignore[possibly-undefined]
+                batch_biases_meta.append(bias.meta)  # type: ignore[possibly-undefined, union-attr]
+            else:
+                batch_biases_meta.append(None)
+
+        with graph.inserting_before(subset[-1]):  # type: ignore[operator]
+            fused_inputs = decompose_stack(graph, batch_inputs)
+            fused_weights = decompose_stack(graph, batch_weights)
+            fused_inputs_meta_val = torch.stack(
+                [input["val"] for input in batch_inputs_meta]
+            )
+            fused_weights_meta_val = torch.stack(
+                [weight["val"] for weight in batch_weights_meta]
+            )
+            fused_bmm = graph.call_function(  # type: ignore[operator]
+                aten.bmm,
+                args=(fused_inputs, fused_weights),
+            )
+            fused_bmm.meta["val"] = aten.bmm(
+                fused_inputs_meta_val, fused_weights_meta_val
+            )
+        for i, original_mm in enumerate(batch_nodes):
+            has_bias = False
+            with graph.inserting_after(fused_bmm):  # type: ignore[operator]
+                new_mm = graph.call_function(aten.select, args=((fused_bmm, 0, i)))  # type: ignore[operator]
+                new_mm.meta["val"] = aten.select(fused_bmm.meta["val"], 0, i)
+                if batch_biases[i]:
+                    has_bias = True
+                    # broadcast the bias to the same shape as the mm output
+                    if self.graph_search_options.get(
+                        "shape_broadcast_batch_linear", False
+                    ):
+                        broadcast_shape = torch.broadcast_shapes(
+                            batch_biases_meta[i]["val"].shape, new_mm.meta["val"].shape
+                        )
+                        broadcast_bias = graph.call_function(  # type: ignore[operator]
+                            aten.broadcast_to.default,
+                            args=(batch_biases[i],),
+                            kwargs={"size": broadcast_shape},
+                        )
+                        broadcast_bias.meta["val"] = aten.broadcast_to(
+                            batch_biases_meta[i]["val"], broadcast_shape
+                        )  # type: ignore[assignment]
+                        new_bias_add = graph.call_function(  # type: ignore[operator]
+                            aten.add.Tensor, args=((broadcast_bias, new_mm))
+                        )
+                        new_bias_add.meta["val"] = aten.add.Tensor(
+                            broadcast_bias.meta["val"], new_mm.meta["val"]
+                        )
+                    else:
+                        new_bias_add = graph.call_function(  # type: ignore[operator]
+                            aten.add, args=((batch_biases[i], new_mm))
+                        )
+                        new_bias_add.meta["val"] = aten.add.Tensor(
+                            batch_biases_meta[i]["val"], new_mm.meta["val"]
+                        )
+            new_mm_cont = new_bias_add if has_bias else new_mm  # type: ignore[possibly-undefined]
+            original_mm.replace_all_uses_with(new_mm_cont)
+            new_mm_cont.meta.update(original_mm.meta)
+            graph.erase_node(original_mm)  # type: ignore[operator]
+        counters["inductor"]["batch_linear_post_grad"] += 1
+
+
+@register_fusion("group_linear", pre_grad=False)
+class GroupLinearFusion(GroupFusion):
+    def _addmm_node_can_be_fused(self, node: torch.fx.Node):
+        input_shape = node.args[1].meta["val"].shape  # type: ignore[union-attr]
+        weight_shape = node.args[2].meta["val"].shape  # type: ignore[union-attr]
+        return (
+            node.kwargs.get("beta", DEFAULT_BETA) == DEFAULT_BETA
+            and node.kwargs.get("alpha", DEFAULT_ALPHA) == DEFAULT_ALPHA
+            and len(input_shape) == 2
+            and len(weight_shape) == 2
+            and all(x % 2 == 0 for x in input_shape + weight_shape)
+            and all(
+                shape <= self.graph_search_options["max_fuse_tensor_size_group_linear"]
+                for shape in input_shape + weight_shape
+            )
+        )
+
+    def _mm_node_can_be_fused(self, node: torch.fx.Node):
+        input_shape = node.args[0].meta["val"].shape  # type: ignore[union-attr]
+        weight_shape = node.args[1].meta["val"].shape  # type: ignore[union-attr]
+        return (
+            len(input_shape) == 2
+            and len(weight_shape) == 2
+            and all(x % 2 == 0 for x in input_shape + weight_shape)
+            and all(
+                shape <= self.graph_search_options["max_fuse_tensor_size_group_linear"]
+                for shape in input_shape + weight_shape
+            )
+        )
+
+    def match(self, node: torch.fx.Node) -> tuple[str, bool] | None:
+        if CallFunctionVarArgs(aten.mm.default).match(
+            node
+        ) and self._mm_node_can_be_fused(node):
+            group_key = ("group_linear", True)
+        elif CallFunctionVarArgs(aten.addmm.default).match(
+            node
+        ) and self._addmm_node_can_be_fused(node):
+            bias = node.args[0]
+            group_key = ("group_linear", bias is None)
+        else:
+            group_key = None
+        return group_key
+
+    def fuse(self, graph: torch.fx.GraphModule, subset: list[torch.fx.Node]):
+        group_inputs = []
+        group_weights = []
+        group_biases = []
+        group_nodes = []
+        for node in subset:
+            if CallFunctionVarArgs(aten.addmm.default).match(node):
+                bias, input, weight = node.args
+            else:
+                assert CallFunctionVarArgs(aten.mm.default).match(node)
+                input, weight = node.args
+                bias = None
+
+            group_nodes.append(node)
+            group_inputs.append(input)
+            group_weights.append(weight)
+            group_biases.append(bias)
+
+        if all(bias is None for bias in group_biases):
+            group_biases = None  # type: ignore[assignment]
+
+        with graph.inserting_before(subset[0]):  # type: ignore[operator]
+            fused_mm = graph.call_function(  # type: ignore[operator]
+                torch.ops.fbgemm.gmm.default,
+                args=(group_inputs, group_weights, group_biases),
+                kwargs={"smart_fused": True},
+            )
+
+        for i, original_mm in enumerate(group_nodes):
+            with graph.inserting_after(fused_mm):  # type: ignore[operator]
+                new_mm = graph.call_function(operator.getitem, args=(fused_mm, i))  # type: ignore[operator]
+            original_mm.replace_all_uses_with(new_mm)
+            new_mm.meta.update(original_mm.meta)
+            graph.erase_node(original_mm)  # type: ignore[operator]
+        counters["inductor"]["group_linear"] += 1
+
+
+class BatchPointwiseMathOpsPostGradFusion(BatchPointwiseOpsFusionFactory):
+    """
+    Batch pointwise math operator (e.g., add, mul) in post grad pass.
+    """
+
+    def __init__(self, op, **kwargs) -> None:
+        super().__init__(op, **kwargs)
+        self.op = op
+
+    def _pointwise_node_can_be_fused(self, node: torch.fx.Node):
+        # note: we only consider the case where the inputs are tensors
+        # for mixed precision training, we need to make sure the inputs
+        # of the aten.cat when do the stack should be the same dtype
+        # otherwise, the output of the aten.cat may be not the same as
+        # its inputs, and cause dtype not same error in mm or addmm
+        input, other = node.args
+        return (
+            input.meta["val"].shape == other.meta["val"].shape  # type: ignore[union-attr]
+            # input and other can be scalars, where they have no attribute 'meta'
+            if hasattr(input, "meta")
+            and hasattr(other, "meta")
+            and is_node_meta_valid(input)  # type: ignore[arg-type, union-attr]
+            and is_node_meta_valid(other)  # type: ignore[arg-type, union-attr]
+            # torch.SymInt or torch.SymFloat object has no attribute 'shape'
+            and isinstance(input.meta["val"], torch.Tensor)  # type: ignore[union-attr]
+            and isinstance(other.meta["val"], torch.Tensor)  # type: ignore[union-attr]
+            else False
+        )
+
+    def match(self, node: torch.fx.Node):
+        if CallFunctionVarArgs(self.op).match(
+            node
+        ) and self._pointwise_node_can_be_fused(node):
+            alpha = node.kwargs.get("alpha", DEFAULT_ALPHA)
+            rounding_mode = node.kwargs.get("rounding_mode", None)
+            input, other = node.args
+            shape = list(input.meta["val"].shape)  # type: ignore[union-attr]
+            if self.graph_search_options.get("fuse_nodes_with_same_parent", False):
+                # only consider the linear case so far
+                # pyre-fixme[16]
+                if input.target is aten.select or other.target is aten.select:  # type: ignore[union-attr]
+                    parent = (
+                        # pyre-fixme[16]
+                        input.args[0]  # type: ignore[union-attr]
+                        # pyre-fixme[16]
+                        if input.target is aten.select  # type: ignore[union-attr]
+                        else other.args[0]  # type: ignore[union-attr]
+                    )
+                else:
+                    parent = ""
+            else:
+                parent = ""
+            group_key = (
+                "batch_aten_" + self.op.__name__.lower().split(".")[0],
+                str(shape),
+                str(input.meta["val"].dtype),  # type: ignore[union-attr]
+                str(other.meta["val"].dtype),  # type: ignore[union-attr]
+                str(alpha),
+                str(rounding_mode),
+                str(parent),
+            )
+        else:
+            group_key = None
+        return group_key
+
+    def fuse(self, graph: torch.fx.GraphModule, subset: list[torch.fx.Node]):
+        batch_inputs, batch_others = [], []
+        alpha = subset[0].kwargs.get("alpha", DEFAULT_ALPHA)
+        batch_inputs_meta, batch_others_meta = [], []
+
+        for node in subset:
+            input, other = node.args
+            batch_inputs.append(input)
+            batch_others.append(other)
+            batch_inputs_meta.append(input.meta)  # type: ignore[possibly-undefined, union-attr]
+            batch_others_meta.append(other.meta)  # type: ignore[possibly-undefined, union-attr]
+
+        with graph.inserting_before(subset[0]):  # type: ignore[operator]
+            stack_inputs = decompose_stack(graph, batch_inputs)
+            stack_others = decompose_stack(graph, batch_others)
+            stack_inputs_meta = torch.stack(
+                [input["val"] for input in batch_inputs_meta]
+            )
+            stack_others_meta = torch.stack(
+                [other["val"] for other in batch_others_meta]
+            )
+
+            batch_op = graph.call_function(  # type: ignore[operator]
+                self.op,
+                args=(stack_inputs, stack_others),
+                kwargs={"alpha": alpha} if self.op == aten.add.Tensor else {},
+            )
+            batch_op.meta["val"] = self.op(stack_inputs_meta, stack_others_meta)
+            for i, original_add in enumerate(subset):
+                with graph.inserting_after(batch_op):  # type: ignore[operator]
+                    new_add = graph.call_function(  # type: ignore[operator]
+                        torch.ops.aten.select, args=((batch_op, 0, i))
+                    )
+                original_add.replace_all_uses_with(new_add)
+                new_add.meta.update(original_add.meta)
+                graph.erase_node(original_add)  # type: ignore[operator]
+        counters["inductor"][
+            "batch_aten_" + self.op.__name__.lower().split(".")[0]
+        ] += 1
+
+
+@register_fusion("batch_linear_lhs")
+class BatchLinearLHSFusion(BatchFusion):
+    """
+    Batch linear left-hand side fusion. This pass tries to fuse the following patterns:
+
+        torch.nn.functional.linear(x, w1), linear(x, w2),... * linear(x, wn)
+        -> torch.mm(x, torch.cat([w1, w2,... * wn]).transpose(0, 1))
+
+    We have a separate pass to eliminate contiguous transpose in a generic way.
+    """
+
+    def match(self, node: torch.fx.Node) -> tuple[str, bool, Any] | None:
+        if CallFunctionVarArgs(torch.nn.functional.linear).match(
+            node
+        ) and is_linear_node_can_be_fused(node):
+            input = get_arg_value(node, 0, "input")
+            bias = get_arg_value(node, 2, "bias")
+            group_key = ("batch_linear_lhs", bias is None, input)
+        else:
+            group_key = None
+        return group_key
+
+    def fuse(self, graph: torch.fx.GraphModule, subset: list[torch.fx.Node]):
+        batch_nodes = []
+        batch_input = None
+        batch_weights, batch_weights_meta = [], []
+        batch_biases, batch_biases_meta = [], []
+        split_sections = []
+        for node in subset:
+            input = get_arg_value(node, 0, "input")
+            weight = get_arg_value(node, 1, "weight")
+            bias = get_arg_value(node, 2, "bias")
+            batch_nodes.append(node)
+            if batch_input is None:
+                batch_input = input
+            else:
+                assert batch_input is input
+            batch_weights.append(weight)
+            batch_weights_meta.append(weight.meta["example_value"])
+            if bias:
+                batch_biases.append(bias)
+                batch_biases_meta.append(bias.meta["example_value"])
+            split_sections.append(weight.meta["example_value"].shape[0])
+
+        with graph.inserting_before(subset[0]):  # type: ignore[operator]
+            cat_weights = graph.call_function(  # type: ignore[operator]
+                torch.cat, args=(batch_weights,), kwargs={"dim": 0}
+            )
+            cat_weights.meta["example_value"] = torch.cat(batch_weights_meta, dim=0)
+            transposed_weights = graph.call_function(  # type: ignore[operator]
+                torch.transpose, args=(cat_weights, 0, 1)
+            )
+            transposed_weights.meta["example_value"] = torch.transpose(
+                cat_weights.meta["example_value"], 0, 1
+            )
+            if len(batch_biases) > 0:
+                cat_biases = graph.call_function(  # type: ignore[operator]
+                    torch.cat, args=(batch_biases,), kwargs={"dim": 0}
+                )
+                cat_biases.meta["example_value"] = torch.cat(batch_biases_meta, dim=0)
+                fused_lhs = graph.call_function(  # type: ignore[operator]
+                    torch.addmm,
+                    args=(cat_biases, batch_input, transposed_weights),
+                )
+                fused_lhs.meta["example_value"] = torch.addmm(
+                    cat_biases.meta["example_value"],
+                    batch_input.meta["example_value"],  # type: ignore[union-attr]
+                    transposed_weights.meta["example_value"],
+                )
+            else:
+                fused_lhs = graph.call_function(  # type: ignore[operator]
+                    torch.mm,
+                    args=(batch_input, transposed_weights),
+                )
+                fused_lhs.meta["example_value"] = torch.mm(
+                    batch_input.meta["example_value"],  # type: ignore[union-attr]
+                    transposed_weights.meta["example_value"],
+                )
+            fused_lhs_list = graph.call_function(  # type: ignore[operator]
+                torch.split, args=(fused_lhs, split_sections), kwargs={"dim": 1}
+            )
+
+        for i, node in enumerate(batch_nodes):
+            with graph.inserting_after(fused_lhs_list):  # type: ignore[operator]
+                new_node = graph.call_function(  # type: ignore[operator]
+                    operator.getitem, args=(fused_lhs_list, i)
+                )
+            node.replace_all_uses_with(new_node)
+            new_node.meta.update(node.meta)
+            graph.erase_node(node)  # type: ignore[operator]
+        counters["inductor"]["batch_linear_lhs"] += 1
+
+
+# Poor person's check for if a node in the graph mutates its input.
+# (the graph is torch IR, so we will see torch fns and python operators)
+def _is_mutable_node(tgt):
+    if str(tgt).endswith("_"):
+        # e.g. torch.mul_, torch.Tensor.mul_
+        return True
+    if (
+        hasattr(tgt, "__module__")
+        and tgt.__module__ == "_operator"
+        and tgt.__name__.startswith("i")
+    ):
+        # e.g. operator.iand, operator.imul
+        return True
+    return False
+
+
+def is_linear_node_can_be_fused(node: torch.fx.Node):
+    input = get_arg_value(node, 0, "input")
+    weight = get_arg_value(node, 1, "weight")
+    return (
+        is_node_meta_valid(node)
+        and is_node_meta_valid(input)
+        and is_node_meta_valid(weight)
+        and len(input.meta["example_value"].shape) == 2
+        and len(weight.meta["example_value"].shape) == 2
+        # the mm -> bmm transform adds an unbind() op,
+        # which is not safe for autograd when the output of the mm is mutated.
+        # don't pattern match if any users of the mm mutate the input.
+        and not any(_is_mutable_node(user.target) for user in node.users)
+    )
+
+
+@register_fusion("batch_linear")
+class PreGradBatchLinearFusion(BatchFusion):
+    """
+    Batch linear fusion in pre grad pass.
+    Fuse linear with same size with torch.baddmm
+    """
+
+    def _getitem_args(self, getitem_node: torch.fx.Node):
+        if getitem_node.target != operator.__getitem__ or (
+            getitem_node.op != "call_function"
+        ):
+            return None
+        return getitem_node.args[0]
+
+    def match(self, node: torch.fx.Node):
+        if CallFunctionVarArgs(torch.nn.functional.linear).match(
+            node
+        ) and is_linear_node_can_be_fused(node):
+            input = get_arg_value(node, 0, "input")
+            weight = get_arg_value(node, 1, "weight")
+            bias = get_arg_value(node, 2, "bias")
+            if self.graph_search_options.get("fuse_nodes_with_same_users", False):
+                users = [user.target for user in node.users]
+            else:
+                users = ""  # type: ignore[assignment]
+            group_key = (
+                "batch_linear",
+                self._getitem_args(input),
+                str(input.meta["example_value"].shape),
+                str(weight.meta["example_value"].shape),
+                bias is None,
+                str(users),
+            )
+        else:
+            group_key = None
+        return group_key
+
+    def fuse(self, graph: torch.fx.GraphModule, subset: list[torch.fx.Node]):
+        batch_nodes = []
+        batch_inputs = []
+        batch_weights = []
+        batch_biases = []
+        batch_inputs_metadata = []
+        batch_weights_metadata = []
+        batch_biases_metadata = []
+        for node in subset:
+            batch_nodes.append(node)
+            input = get_arg_value(node, 0, "input")
+            batch_inputs.append(input)
+            batch_inputs_metadata.append(input.meta["example_value"])
+            weight = get_arg_value(node, 1, "weight")
+            batch_weights.append(weight)
+            batch_weights_metadata.append(weight.meta["example_value"])
+            bias = get_arg_value(node, 2, "bias")
+            batch_biases.append(bias)
+            if bias is not None and hasattr(bias, "meta"):
+                batch_biases_metadata.append(bias.meta["example_value"])
+
+        with graph.inserting_before(subset[0]):  # type: ignore[operator]
+            stack_inputs = graph.call_function(  # type: ignore[operator]
+                torch.stack, args=(batch_inputs,), kwargs={"dim": 0}
+            )
+            update_stack_example_value(stack_inputs, batch_inputs_metadata)
+            stack_weights = graph.call_function(  # type: ignore[operator]
+                torch.stack, args=(batch_weights,), kwargs={"dim": 0}
+            )
+            update_stack_example_value(stack_weights, batch_weights_metadata)
+            transpose_weight = graph.call_function(  # type: ignore[operator]
+                torch.transpose, args=(stack_weights, 1, 2)
+            )
+            transpose_weight.meta["example_value"] = torch.transpose(
+                stack_weights.meta["example_value"], 1, 2
+            )
+            if all(bias is None for bias in batch_biases):
+                bmm = graph.call_function(  # type: ignore[operator]
+                    torch.bmm,
+                    args=(stack_inputs, transpose_weight),
+                )
+                bmm.meta["example_value"] = torch.bmm(
+                    stack_inputs.meta["example_value"],
+                    transpose_weight.meta["example_value"],
+                )
+                bmm_meta = bmm.meta["example_value"]
+            else:
+                stack_biases = graph.call_function(  # type: ignore[operator]
+                    torch.stack, args=(batch_biases,), kwargs={"dim": 0}
+                )
+                update_stack_example_value(stack_biases, batch_biases_metadata)
+                unsqueeze_biases = graph.call_function(  # type: ignore[operator]
+                    torch.unsqueeze, args=(stack_biases, 1)
+                )
+                unsqueeze_biases.meta["example_value"] = torch.unsqueeze(
+                    stack_biases.meta["example_value"], 1
+                )
+                bmm = graph.call_function(  # type: ignore[operator]
+                    torch.baddbmm,
+                    args=(unsqueeze_biases, stack_inputs, transpose_weight),
+                )
+                try:
+                    # it will have runtime error to broadcast when it has dynamic shape included
+                    # in the meta data, so we need to skip the update meta data
+                    bmm.meta["example_value"] = torch.baddbmm(
+                        unsqueeze_biases.meta["example_value"],
+                        stack_inputs.meta["example_value"],
+                        transpose_weight.meta["example_value"],
+                    )
+                    bmm_meta = bmm.meta["example_value"]
+                except Exception as e:
+                    log.debug(
+                        f" exception when update bmm meta data with stack error tracekey {e}"  # noqa: G004
+                    )
+                    bmm_meta = None
+
+            bmm = graph.call_function(torch.unbind, args=(bmm,), kwargs={"dim": 0})  # type: ignore[operator]
+            if bmm_meta is not None:
+                bmm.meta["example_value"] = torch.unbind(bmm_meta, dim=0)
+            for i, linear in enumerate(batch_nodes):
+                with graph.inserting_after(bmm):  # type: ignore[operator]
+                    getitem = graph.call_function(operator.getitem, args=(bmm, i))  # type: ignore[operator]
+                linear.replace_all_uses_with(getitem)
+                getitem.meta.update(linear.meta)
+                graph.erase_node(linear)  # type: ignore[operator]
+        counters["inductor"]["batch_linear"] += 1
+
+
+@register_fusion("batch_layernorm")
+class BatchLayernormFusion(BatchFusion):
+    """
+    Batch layer norm fusion in pre grad pass
+    """
+
+    def match(self, node: torch.fx.Node):
+        if CallFunctionVarArgs(torch.nn.functional.layer_norm).match(node):
+            input = get_arg_value(node, 0, "input")
+            weight = get_arg_value(node, 2, "weight")
+            bias = get_arg_value(node, 3, "bias")
+            if self.graph_search_options.get("fuse_nodes_with_same_users", False):
+                users = [user.target for user in node.users]
+            else:
+                users = ""  # type: ignore[assignment]
+            group_key = (
+                (
+                    "batch_layernorm",
+                    str(input.meta["example_value"].shape),
+                    str(weight.meta["example_value"].shape)
+                    if weight is not None
+                    else "",
+                    str(bias.meta["example_value"].shape) if bias is not None else "",
+                    str(get_arg_value(node, 1, "normalized_shape")),
+                    str(get_arg_value(node, 4, "eps")),
+                    str(users),
+                )
+                if "example_value" in input.meta
+                and is_node_meta_valid(weight)
+                and is_node_meta_valid(bias)
+                else None
+            )
+        else:
+            group_key = None
+        return group_key
+
+    def fuse(self, graph: torch.fx.GraphModule, subset: list[torch.fx.Node]):
+        group_inputs = []
+        group_shapes = []
+        group_weights = []
+        group_biases = []
+        group_epss = []
+        group_nodes = []
+        group_inputs_metadata = []
+        group_biases_metadata = []
+        group_weights_metadata = []
+        for node in subset:
+            group_nodes.append(node)
+            input = get_arg_value(node, 0, "input")
+            group_inputs.append(input)
+            group_inputs_metadata.append(input.meta["example_value"])
+            group_shapes.append(get_arg_value(node, 1, "normalized_shape"))
+            weight = get_arg_value(node, 2, "weight")
+            group_weights.append(weight)
+            if weight is not None and hasattr(weight, "meta"):
+                group_weights_metadata.append(weight.meta["example_value"])
+            bias = get_arg_value(node, 3, "bias")
+            group_biases.append(bias)
+            if bias is not None and hasattr(bias, "meta"):
+                group_biases_metadata.append(bias.meta["example_value"])
+            eps = get_arg_value(node, 4, "eps")
+            if eps is None:
+                eps = 1e-5
+            group_epss.append(eps)
+        stack_dim = -1 - len(group_shapes[-1])
+
+        if all(bias is None for bias in group_biases):
+            group_biases = None  # type: ignore[assignment]
+        if all(weight is None for weight in group_weights):
+            group_weights = None  # type: ignore[assignment]
+        assert all(eps == group_epss[0] for eps in group_epss), (
+            "all epsilon values must be equal"
+        )
+
+        with graph.inserting_before(subset[0]):  # type: ignore[operator]
+            stack_input = graph.call_function(  # type: ignore[operator]
+                torch.stack, args=(group_inputs,), kwargs={"dim": stack_dim}
+            )
+            update_stack_example_value(stack_input, group_inputs_metadata, stack_dim)
+            if group_weights is not None:
+                stack_weight = graph.call_function(  # type: ignore[operator]
+                    torch.stack, args=(group_weights,), kwargs={"dim": 0}
+                )
+                update_stack_example_value(stack_weight, group_weights_metadata)
+            else:
+                stack_weight = None
+            if group_biases is not None:
+                stack_bias = graph.call_function(  # type: ignore[operator]
+                    torch.stack, args=(group_biases,), kwargs={"dim": 0}
+                )
+                update_stack_example_value(stack_bias, group_biases_metadata)
+            else:
+                stack_bias = None
+
+            batch_layer_norm = graph.call_function(  # type: ignore[operator]
+                torch.nn.functional.layer_norm,
+                args=(stack_input, group_shapes[-1]),
+                kwargs={"eps": group_epss[-1]},
+            )
+            batch_layer_norm.meta["example_value"] = stack_input.meta["example_value"]
+
+            if group_weights is not None and group_biases is not None:
+                previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"]
+                batch_layer_norm = graph.call_function(  # type: ignore[operator]
+                    torch.mul, args=(stack_weight, batch_layer_norm)
+                )
+                update_pointwise_example_value(
+                    batch_layer_norm,
+                    # pyrefly: ignore [missing-attribute]
+                    stack_weight.meta["example_value"],
+                    previous_batch_layer_norm_meta,
+                    torch.mul,
+                )
+                previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"]
+                batch_layer_norm = graph.call_function(  # type: ignore[operator]
+                    torch.add, args=(stack_bias, batch_layer_norm)
+                )
+                update_pointwise_example_value(
+                    batch_layer_norm,
+                    # pyrefly: ignore [missing-attribute]
+                    stack_bias.meta["example_value"],
+                    previous_batch_layer_norm_meta,
+                    torch.add,
+                )
+            elif group_weights is not None and group_biases is None:
+                previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"]
+                # pyrefly: ignore [not-callable]
+                batch_layer_norm = graph.call_function(
+                    torch.mul, args=(stack_weight, batch_layer_norm)
+                )
+                update_pointwise_example_value(
+                    batch_layer_norm,
+                    # pyrefly: ignore [missing-attribute]
+                    stack_weight.meta["example_value"],
+                    previous_batch_layer_norm_meta,
+                    torch.mul,
+                )
+            elif group_weights is None and group_biases is not None:
+                previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"]
+                # pyrefly: ignore [not-callable]
+                batch_layer_norm = graph.call_function(
+                    torch.add, args=(stack_bias, batch_layer_norm)
+                )
+                update_pointwise_example_value(
+                    batch_layer_norm,
+                    # pyrefly: ignore [missing-attribute]
+                    stack_bias.meta["example_value"],
+                    previous_batch_layer_norm_meta,
+                    torch.add,
+                )
+
+            batch_layer_norm_unbind = graph.call_function(  # type: ignore[operator]
+                torch.unbind,
+                args=(batch_layer_norm,),
+                kwargs={"dim": stack_dim},
+            )
+            update_stack_example_value(
+                batch_layer_norm_unbind,
+                batch_layer_norm.meta["example_value"],
+                op=torch.unbind,
+                dim=stack_dim,
+            )
+
+        for i, node in enumerate(group_nodes):
+            with graph.inserting_after(batch_layer_norm_unbind):  # type: ignore[operator]
+                new_node = graph.call_function(  # type: ignore[operator]
+                    operator.getitem, args=(batch_layer_norm_unbind, i)
+                )
+            node.replace_all_uses_with(new_node)
+            new_node.meta.update(node.meta)
+            graph.erase_node(node)  # type: ignore[operator]
+        counters["inductor"]["batch_layernorm"] += 1
+
+
+class BatchPointwiseOpsPreGradFusion(BatchPointwiseOpsFusionFactory):
+    """
+    Batch pointwise ops (e.g., sigmoid, relu, tanh) fusion in pre grad pass.
+    We fuse it in random place, and the introduced stack node may be merged in split cat.
+    """
+
+    def __init__(self, op, **kwargs) -> None:
+        super().__init__(op, **kwargs)
+        self.op = op
+
+    def match(self, node: torch.fx.Node):
+        input = get_arg_value(node, 0, "input")
+        if CallFunctionVarArgs(self.op).match(node) and is_node_meta_valid(node):
+            if self.graph_search_options.get("fuse_nodes_with_same_parent", False):
+                # pyre-fixme[16]
+                parent = node.args[0]
+                parent = parent.target if parent is not None else ""  # type: ignore[union-attr]
+            else:
+                parent = ""
+            # for relu op, we also use the inplace to construct the key
+            group_key = (
+                "batch_" + self.op.__name__.lower().split(".")[0],
+                str(input.meta["example_value"].shape),
+                str(node.kwargs.get("inplace", False)),
+                str(parent),
+            )
+        else:
+            group_key = None
+        return group_key
+
+    def fuse(self, graph: torch.fx.GraphModule, subset: list[torch.fx.Node]):
+        batch_nodes = []
+        batch_inputs = []
+        batch_inputs_metadata = []
+
+        for node in subset:
+            batch_nodes.append(node)
+            input = get_arg_value(node, 0, "input")
+            batch_inputs.append(input)
+            batch_inputs_metadata.append(input.meta["example_value"])
+
+        with graph.inserting_before(subset[0]):  # type: ignore[operator]
+            stack_inputs = graph.call_function(  # type: ignore[operator]
+                torch.stack, args=(batch_inputs,), kwargs={"dim": 0}
+            )
+            update_stack_example_value(stack_inputs, batch_inputs_metadata)
+            if self.op is torch.nn.functional.relu:
+                batch_op = graph.call_function(  # type: ignore[operator]
+                    self.op,
+                    args=(stack_inputs,),
+                    kwargs={"inplace": subset[0].kwargs.get("inplace", False)},
+                )
+                batch_op.meta["example_value"] = self.op(
+                    stack_inputs.meta["example_value"],
+                    # pyrefly: ignore [bad-argument-type]
+                    inplace=subset[0].kwargs.get("inplace", False),
+                )
+            else:
+                batch_op = graph.call_function(  # type: ignore[operator]
+                    self.op,
+                    args=(stack_inputs,),
+                )
+                batch_op.meta["example_value"] = self.op(
+                    stack_inputs.meta["example_value"]
+                )
+            unbind_op = graph.call_function(  # type: ignore[operator]
+                torch.unbind, args=(batch_op,), kwargs={"dim": 0}
+            )
+            unbind_op.meta["example_value"] = torch.unbind(
+                batch_op.meta["example_value"], dim=0
+            )
+            for i, node in enumerate(batch_nodes):
+                with graph.inserting_after(unbind_op):  # type: ignore[operator]
+                    getitem = graph.call_function(operator.getitem, args=(unbind_op, i))  # type: ignore[operator]
+                node.replace_all_uses_with(getitem)
+                getitem.meta.update(node.meta)
+                graph.erase_node(node)  # type: ignore[operator]
+        counters["inductor"]["batch_" + self.op.__name__.lower().split(".")[0]] += 1
+
+
+class BatchPointwiseOpsPostGradFusion(BatchPointwiseOpsFusionFactory):
+    """
+    Batch pointwise ops (e.g., sigmoid, relu, tanh) fusion in post grad pass.
+    The introduced stack node may be merged in split cat.
+    """
+
+    def __init__(self, op, **kwargs) -> None:
+        super().__init__(op, **kwargs)
+        self.op = op
+
+    def match(self, node: torch.fx.Node):
+        input = get_arg_value(node, 0, "input")
+        if CallFunctionVarArgs(self.op).match(node) and is_node_meta_valid(node):
+            # for relu op, we also use the inplace to construct the key
+            # we batch the ops with same parent to enable followup split cat
+            parent = node.args[0]
+            parent = (
+                parent.target  # type: ignore[union-attr]
+                if self.graph_search_options.get("fuse_nodes_with_same_parent", False)
+                else ""
+            )
+            group_key = (
+                "batch_aten_" + self.op.__name__.lower().split(".")[0],
+                str(input.meta["val"].shape),
+                str(node.kwargs.get("inplace", False)),
+                # pyre-fixme[16]
+                str(parent),
+            )
+        else:
+            group_key = None
+        return group_key
+
+    def fuse(self, graph: torch.fx.GraphModule, subset: list[torch.fx.Node]):
+        batch_nodes = []
+        batch_inputs = []
+        batch_inputs_metadata = []
+
+        for node in subset:
+            batch_nodes.append(node)
+            input = get_arg_value(node, 0, "input")
+            batch_inputs.append(input)
+            batch_inputs_metadata.append(input.meta["val"])
+
+        with graph.inserting_before(subset[0]):  # type: ignore[operator]
+            stack_inputs = decompose_stack(graph, batch_inputs)
+            update_stack_example_value(stack_inputs, batch_inputs_metadata)
+            batch_op = graph.call_function(  # type: ignore[operator]
+                self.op,
+                args=(stack_inputs,),
+            )
+            for i, node in enumerate(batch_nodes):
+                with graph.inserting_after(batch_op):  # type: ignore[operator]
+                    getitem = graph.call_function(aten.select, args=(batch_op, 0, i))  # type: ignore[operator]
+                node.replace_all_uses_with(getitem)
+                getitem.meta.update(node.meta)
+                graph.erase_node(node)  # type: ignore[operator]
+        counters["inductor"][
+            "batch_aten_" + self.op.__name__.lower().split(".")[0]
+        ] += 1
+
+
+class BatchMathOpsPreGradFusion(BatchPointwiseOpsFusionFactory):
+    """
+    Batch simple match related ops such as nan_to_num in pre grad pass.
+    """
+
+    def __init__(self, op, **kwargs):
+        super().__init__(op, **kwargs)
+        self.op = op
+
+    def match(self, node: torch.fx.Node):
+        input = get_arg_value(node, 0, "input")
+        if CallFunctionVarArgs(self.op).match(node) and is_node_meta_valid(node):
+            # check the input has the same shape and its users have the same target
+            # check all clamp operators have the same min and max values, and
+            # nan_to_num operators use the same default value.
+            child = next(iter(node.users.keys()))
+            group_key = (
+                str(input.meta["example_value"].shape)
+                + str(node.kwargs)
+                + str(child.target)
+            )
+        else:
+            group_key = None
+        return group_key
+
+    def fuse(self, graph: torch.fx.GraphModule, subset: list[torch.fx.Node]):
+        batch_nodes = []
+        batch_inputs = []
+        batch_inputs_metadata = []
+        kwargs = subset[0].kwargs
+
+        for node in subset:
+            batch_nodes.append(node)
+            input = get_arg_value(node, 0, "input")
+            batch_inputs.append(input)
+            batch_inputs_metadata.append(input.meta["example_value"])
+
+        with graph.inserting_before(subset[0]):  # type: ignore[operator]
+            stack_inputs = graph.call_function(  # type: ignore[operator]
+                torch.stack, args=(batch_inputs,), kwargs={"dim": 0}
+            )
+            update_stack_example_value(stack_inputs, batch_inputs_metadata)
+            batch_op = graph.call_function(  # type: ignore[operator]
+                self.op,
+                args=(stack_inputs,),
+                kwargs=kwargs,
+            )
+            batch_op.meta["example_value"] = self.op(
+                stack_inputs.meta["example_value"], **kwargs
+            )
+            unbind_op = graph.call_function(  # type: ignore[operator]
+                torch.unbind, args=(batch_op,), kwargs={"dim": 0}
+            )
+            unbind_op.meta["example_value"] = torch.unbind(
+                batch_op.meta["example_value"], dim=0
+            )
+            for i, node in enumerate(batch_nodes):
+                with graph.inserting_after(unbind_op):  # type: ignore[operator]
+                    getitem = graph.call_function(operator.getitem, args=(unbind_op, i))  # type: ignore[operator]
+                node.replace_all_uses_with(getitem)
+                getitem.meta.update(node.meta)
+                graph.erase_node(node)  # type: ignore[operator]
+        counters["inductor"]["batch_" + self.op.__name__.lower().split(".")[0]] += 1
+
+
+@register_fusion("batch_tanh")
+class BatchTanhPreGradFusion(BatchPointwiseOpsPreGradFusion):
+    def __init__(self, **kwargs) -> None:
+        super().__init__(torch.tanh, **kwargs)
+
+
+@register_fusion("batch_sigmoid")
+class BatchSigmoidPreGradFusion(BatchPointwiseOpsPreGradFusion):
+    def __init__(self, **kwargs) -> None:
+        super().__init__(torch.sigmoid, **kwargs)
+
+
+@register_fusion("batch_relu")
+class BatchReLuPreGradFusion(BatchPointwiseOpsPreGradFusion):
+    def __init__(self, **kwargs) -> None:
+        super().__init__(torch.nn.functional.relu, **kwargs)
+
+
+@register_fusion("batch_detach")
+class BatchDetachPreGradFusion(BatchMathOpsPreGradFusion):
+    def __init__(self, **kwargs):
+        super().__init__(torch.detach, **kwargs)
+
+
+@register_fusion("batch_nan_to_num")
+class BatchNanToNumPreGradFusion(BatchMathOpsPreGradFusion):
+    def __init__(self, **kwargs):
+        super().__init__(torch.nan_to_num, **kwargs)
+
+
+@register_fusion("batch_clamp")
+class BatchClampPreGradFusion(BatchMathOpsPreGradFusion):
+    def __init__(self, **kwargs):
+        super().__init__(torch.clamp, **kwargs)
+
+
+@register_fusion("batch_dropout")
+class BatchDropoutPreGradFusion(BatchMathOpsPreGradFusion):
+    def __init__(self, **kwargs):
+        super().__init__(torch.nn.functional.dropout, **kwargs)
+
+
+@register_fusion("batch_aten_tanh", pre_grad=False)
+class BatchTanhPostGradFusion(BatchPointwiseOpsPostGradFusion):
+    def __init__(self, **kwargs) -> None:
+        super().__init__(aten.tanh.default, **kwargs)
+
+
+@register_fusion("batch_aten_sigmoid", pre_grad=False)
+class BatchSigmoidPostGradFusion(BatchPointwiseOpsPostGradFusion):
+    def __init__(self, **kwargs) -> None:
+        super().__init__(aten.sigmoid.default, **kwargs)
+
+
+@register_fusion("batch_aten_relu", pre_grad=False)
+class BatchReLuPostGradFusion(BatchPointwiseOpsPostGradFusion):
+    def __init__(self, **kwargs) -> None:
+        super().__init__(aten.relu.default, **kwargs)
+
+
+@register_fusion("batch_aten_add", pre_grad=False)
+class BatchAddPostGradFusion(BatchPointwiseMathOpsPostGradFusion):
+    def __init__(self, **kwargs) -> None:
+        super().__init__(aten.add.Tensor, **kwargs)
+
+
+@register_fusion("batch_aten_sub", pre_grad=False)
+class BatchSubPostGradFusion(BatchPointwiseMathOpsPostGradFusion):
+    def __init__(self, **kwargs) -> None:
+        super().__init__(aten.sub.Tensor, **kwargs)
+
+
+@register_fusion("batch_aten_div", pre_grad=False)
+class BatchDivPostGradFusion(BatchPointwiseMathOpsPostGradFusion):
+    def __init__(self, **kwargs) -> None:
+        super().__init__(aten.div.Tensor, **kwargs)
+
+
+@register_fusion("batch_aten_mul", pre_grad=False)
+class BatchMulPostGradFusion(BatchPointwiseMathOpsPostGradFusion):
+    def __init__(self, **kwargs) -> None:
+        super().__init__(aten.mul.Tensor, **kwargs)
+
+
+class _OrderedSet:
+    def __init__(self, param=None) -> None:
+        if param:
+            self.rep = OrderedDict(dict.fromkeys(param))
+        else:
+            self.rep = OrderedDict()
+
+    def __contains__(self, o) -> bool:
+        return o in self.rep
+
+    def __len__(self) -> int:
+        return self.rep.__len__()
+
+    def append(self, o):
+        self.rep[o] = None
+
+    def __iter__(self):
+        return self.rep.keys().__iter__()
+
+
+def find_independent_subset_greedy(
+    node_list: Iterable[torch.fx.Node],
+    graph_search_options: dict[str, Any],
+) -> Iterator[Iterable[torch.fx.Node]]:
+    """
+    Yields a list of subsets of `node_list` where no element in the subset
+    depends on any other element in the subset. This results in a set of
+    independent nodes which can be fused together.
+
+    The order of `node_list` is preserved within each subset so we can benefit
+    from split-cat elimination in later passes.
+
+    During iteration it is only safe to mutate the graph by changing the nodes
+    that have been returned.
+
+    graph_search_options:
+      - min_fuse_set_size: Minimum size of the subset to consider. Subsets below
+        this size will be ignored.
+      - max_fuse_set_size: Maximum size of the subset to consider. Subsets will
+        be broken to be at most this size.
+    """
+
+    # Compute all the children of `node` which are members of
+    # `interesting_nodes`.
+    def find_dependent_nodes(node, interesting_nodes):
+        visited_node_set = OrderedSet[torch.fx.Node]()
+        dep_set = OrderedSet[torch.fx.Node]()
+
+        work = [node]
+        while work:
+            node = work.pop()
+            for input_node in node.all_input_nodes:
+                if input_node in interesting_nodes:
+                    dep_set.add(input_node)
+
+                if input_node not in visited_node_set:
+                    visited_node_set.add(input_node)
+                    work.append(input_node)
+
+        return dep_set
+
+    min_fuse_set_size = graph_search_options["min_fuse_set_size"]
+    max_fuse_set_size = graph_search_options["max_fuse_set_size"]
+
+    # node_list needs to be a set because we only track the nodes that are left
+    # in it (and we want to do the `in` on a set, not a list). But we want to
+    # keep the correct order.
+    node_list = _OrderedSet(node_list)
+
+    cache: dict[torch.fx.Node, OrderedSet[torch.fx.Node]] = {}
+    while node_list:
+        subset: list[torch.fx.Node] = []
+        subset_deps = OrderedSet[torch.fx.Node]()
+
+        next_round_node_list = _OrderedSet()
+        for node in node_list:
+            if len(subset) >= max_fuse_set_size or node in subset_deps:
+                next_round_node_list.append(node)
+                continue
+
+            dep_set = cache.pop(node, None)
+            if dep_set is None:
+                dep_set = find_dependent_nodes(node, node_list)
+
+            if not dep_set.intersection(subset):
+                subset.append(node)
+                subset_deps.update(dep_set)
+            else:
+                next_round_node_list.append(node)
+                cache[node] = dep_set
+
+        if len(subset) >= min_fuse_set_size:
+            # Careful here - the caller uses the subsets to fuse nodes together
+            # so we need to clear any cache entry that contains one of the
+            # returned nodes because the dependency list could be different
+            # (larger) after the merge.
+            cache = {k: v for k, v in cache.items() if v.isdisjoint(subset)}
+            yield subset
+
+        node_list = next_round_node_list
+
+
+def get_fusion_candidates(
+    rule: GroupBatchFusionBase,
+    root_node: torch.fx.Node,
+    fused_set: OrderedSet[torch.fx.Node],
+) -> collections.defaultdict[Any, list[torch.fx.Node]]:
+    """
+    Search fusion candidates for a specific rule using BFS starting from the root node.
+    We only search the subgraph within graph_search_options["max_fuse_search_depth"].
+    """
+    q: collections.deque[tuple[int, torch.fx.Node]] = collections.deque()
+
+    candidate_dict: collections.defaultdict[Any, list[torch.fx.Node]] = (
+        collections.defaultdict(list)
+    )
+
+    if root_node.target in SEARCH_EXCLUSIONS:
+        return candidate_dict
+
+    visited_set = OrderedSet[torch.fx.Node]()
+
+    for next_node in root_node.all_input_nodes:
+        q.append((1, next_node))
+        visited_set.add(next_node)
+
+    while len(q) > 0:
+        depth, node = q.popleft()
+
+        if node in fused_set:
+            continue
+
+        key = rule.match(node)
+        if key is not None:
+            candidate_nodes = candidate_dict[key]
+            if node not in candidate_nodes:
+                candidate_nodes.append(node)
+        else:
+            if depth < rule.graph_search_options["max_fuse_search_depth"]:
+                for next_node in node.all_input_nodes:
+                    if next_node not in visited_set:
+                        visited_set.add(next_node)
+                        q.append((depth + 1, next_node))
+
+    return candidate_dict
+
+
+def apply_group_batch_fusion(graph: torch.fx.GraphModule, rule: GroupBatchFusionBase):
+    stable_topological_sort(graph)  # type: ignore[arg-type]
+    fused_set = OrderedSet[torch.fx.Node]()
+    log_to_scuba = False
+
+    for node in reversed(graph.nodes):  # type: ignore[arg-type]
+        candidates = get_fusion_candidates(rule, node, fused_set)
+
+        for key, candidate_nodes in candidates.items():
+            if len(candidate_nodes) < rule.graph_search_options["min_fuse_set_size"]:
+                continue
+
+            for subset in find_independent_subset_greedy(
+                candidate_nodes, rule.graph_search_options
+            ):
+                rule.fuse(graph, subset)
+                fused_set.update(subset)
+                log.debug(
+                    f"{rule.__class__.__name__}: key = {key}; subset size = {len(list(subset))}"  # noqa: G004
+                )
+                log_to_scuba = True
+    if log_to_scuba:
+        from torch.fx._lazy_graph_module import _LazyGraphModule
+
+        # Force graph to re-compile otherwise the output python code may be broken
+        gm = graph._owning_module
+        if isinstance(gm, _LazyGraphModule):
+            _LazyGraphModule.recompile()
+        else:
+            assert isinstance(gm, torch.fx.GraphModule)
+            gm.recompile()
+        graph_str = gm.print_readable(
+            print_output=False, include_stride=True, include_device=True
+        )
+
+        name = f"optimus_{str(rule.__class__.__name__)}"
+        if "MTIA" in name:
+            name = f"cff_{str(rule.__class__.__name__)}"
+        trace_structured(
+            "artifact",
+            metadata_fn=lambda: {
+                "name": name,
+                "encoding": "string",
+            },
+            payload_fn=lambda: graph_str,
+        )
+
+
+def generate_fusion_from_config(config_options: dict[str, Any], pre_grad=True):
+    fusions: list[GroupBatchFusionBase] = []
+    for name, options in config_options.items():
+        # we skip all patterns from pattern_matcher passes (e.g., split_cat)
+        if name not in PRE_GRAD_FUSIONS and name not in POST_GRAD_FUSIONS:
+            continue
+        fusion_cls = PRE_GRAD_FUSIONS[name] if pre_grad else POST_GRAD_FUSIONS[name]
+        _options = graph_search_options.copy()
+        _options.update(options)
+        fusions.append(fusion_cls(graph_search_options=_options))  # type: ignore[operator]
+    return fusions
+
+
+def group_batch_fusion_passes(graph: torch.fx.Graph, pre_grad=True):
+    fusions: list[GroupBatchFusionBase] = []
+    # we keep all current pre grad fusions to keep
+    # current implementation, will remove this later
+    if pre_grad:
+        fusions += generate_fusion_from_config(
+            config.pre_grad_fusion_options, pre_grad=True
+        )
+    else:
+        fbgemm_fusion_keys = [
+            x
+            for x in config.post_grad_fusion_options
+            if (
+                x not in OPTIMUS_EXCLUDE_POST_GRAD
+                and config.post_grad_fusion_options[x].get("require_fbgemm", False)
+            )
+        ]
+        fbgemm_fusions = {
+            fusion: config.post_grad_fusion_options[fusion]
+            for fusion in fbgemm_fusion_keys
+        }
+        non_fbgemm_fusions = {
+            fusion: config.post_grad_fusion_options[fusion]
+            for fusion in config.post_grad_fusion_options
+            if fusion not in fbgemm_fusion_keys
+        }
+        fusions += generate_fusion_from_config(non_fbgemm_fusions, pre_grad=False)
+        if has_fbgemm:
+            fusions += generate_fusion_from_config(fbgemm_fusions, pre_grad=False)
+
+    for i, rule in enumerate(fusions):
+        with GraphTransformObserver(
+            graph.owning_module,
+            f"group_batch_fusion_{i}",
+        ):
+            apply_group_batch_fusion(graph, rule)  # type: ignore[arg-type]
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/joint_graph.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/joint_graph.py
new file mode 100644
index 0000000000000000000000000000000000000000..021abb0d6b13bd94c146b9a058c058745252e904
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/joint_graph.py
@@ -0,0 +1,1048 @@
+# mypy: allow-untyped-defs
+import functools
+import itertools
+import logging
+import operator
+import typing
+from collections import Counter
+from collections.abc import Sequence
+from typing import Any
+
+import torch
+import torch._guards
+import torch.utils._pytree as pytree
+from torch._dynamo.utils import counters
+from torch._inductor.constant_folding import ConstantFolder
+from torch._inductor.fx_passes.dedupe_symint_uses import _SymHashingDict
+from torch._inductor.utils import get_gpu_type
+from torch.fx.experimental.symbolic_shapes import (
+    guard_or_false,
+    guard_or_true,
+    statically_known_true,
+)
+from torch.multiprocessing.reductions import StorageWeakRef
+from torch.utils._ordered_set import OrderedSet
+
+from .. import config
+from ..pattern_matcher import (
+    Arg,
+    CallFunction,
+    init_once_fakemode,
+    KeywordArg,
+    Match,
+    MULTIPLE,
+    PatternMatcherPass as PatternMatcherPassBase,
+    register_graph_pattern,
+    stable_topological_sort,
+)
+from .decompose_mem_bound_mm import check_device
+from .replace_random import replace_random_passes
+
+
+PatternMatcherPass = functools.partial(
+    PatternMatcherPassBase, subsystem="joint_graph_passes"
+)
+
+log = logging.getLogger(__name__)
+patterns = PatternMatcherPass()
+aten = torch.ops.aten
+prims = torch.ops.prims
+
+pass_patterns = [
+    patterns,
+    PatternMatcherPass(),
+]
+
+
+@init_once_fakemode
+def lazy_init():
+    from .fuse_attention import _sfdp_init
+    from .misc_patterns import _misc_patterns_init
+    from .pad_mm import _pad_mm_init
+
+    _pad_mm_init()
+    _sfdp_init()
+    _misc_patterns_init()
+
+
+def remove_no_ops(
+    gm: torch.fx.GraphModule,
+    zeros: OrderedSet[torch.fx.Node],
+    ones: OrderedSet[torch.fx.Node],
+):
+    with torch.utils._python_dispatch._disable_current_modes():
+        "Removes no-ops: (+ 0, - 0, * 1, / 1)"
+        graph = gm.graph
+
+        def fake_tensors_eq(t1, t2, fields=("shape", "dtype", "device")):
+            if any(not isinstance(t, torch.Tensor) for t in (t1, t2)):
+                return False
+            for field in fields:
+                if getattr(t1, field) != getattr(t2, field):
+                    return False
+            return True
+
+        def replace_no_op(node, replace_input_index):
+            replacement = node.args[replace_input_index]
+
+            # https://github.com/pytorch/pytorch/issues/86128 causes
+            # non-Tensor inputs even for ops with only Tensor inputs.
+            # TODO - decompose/type promote to avoid this
+            if not all(isinstance(arg, torch.fx.Node) for arg in node.args):
+                return
+
+            if not fake_tensors_eq(node.meta["val"], replacement.meta["val"]):
+                if fake_tensors_eq(
+                    node.meta["val"],
+                    replacement.meta["val"],
+                    ("shape", "device"),
+                ):
+                    with graph.inserting_after(node):
+                        replacement = graph.call_function(
+                            torch.ops.prims.convert_element_type.default,
+                            args=(replacement, node.meta["val"].dtype),
+                        )
+                else:
+                    return
+
+            node.replace_all_uses_with(replacement)
+            replacement.meta.update(node.meta)
+            graph.erase_node(node)
+
+        for node in graph.find_nodes(op="call_function", target=aten.add.Tensor):
+            # TODO handle Tensor-Scalar adds, it's a different schema
+            if len(node.args) == 2:
+                if (
+                    not any(e in zeros for e in node.args)
+                    or node.kwargs.get("alpha", 1) != 1
+                ):
+                    continue
+
+                replace_index = 1 if node.args[0] in zeros else 0
+                replace_no_op(node, replace_index)
+
+        for node in graph.find_nodes(op="call_function", target=aten.sub.Tensor):
+            if len(node.args) == 2:
+                if node.args[1] not in zeros or node.kwargs.get("alpha", 1) != 1:
+                    continue
+
+                replace_no_op(node, 0)
+
+        for node in graph.find_nodes(op="call_function", target=aten.mul.Tensor):
+            if len(node.args) == 2:
+                if not any(e in ones for e in node.args):
+                    continue
+
+                replace_input_index = 1 if node.args[0] in ones else 0
+                replace_no_op(node, replace_input_index)
+
+        for node in graph.find_nodes(op="call_function", target=aten.div.Tensor):
+            if len(node.args) == 2 and node.args[1] in ones:
+                replace_no_op(node, 0)
+
+        # meta tensors returned from the graph have no data and can be replaced with empty_strided
+        for output_node in graph.find_nodes(op="output"):
+            had_meta_return = False
+
+            def visit(n):
+                nonlocal had_meta_return
+                val = n.meta.get("val")
+                if isinstance(val, torch.Tensor) and val.device.type == "meta":
+                    with graph.inserting_before(output_node):
+                        n.replace_all_uses_with(
+                            graph.call_function(
+                                torch.ops.aten.empty_strided.default,
+                                args=(val.size(), val.stride()),
+                                kwargs={"dtype": val.dtype, "device": val.device},
+                            )
+                        )
+                    had_meta_return = True
+
+            torch.fx.map_arg(output_node.args, visit)
+            if had_meta_return:
+                graph.eliminate_dead_code()
+
+
+def remove_redundant_views(gm: torch.fx.GraphModule):
+    """
+    Removes redundant views by reusing existing ones.
+    """
+    with torch.utils._python_dispatch._disable_current_modes():
+        # A dictionary mapping a tensor to all aliased views.
+        views: dict[torch.fx.Node, dict[torch.dtype, torch.fx.Node]] = {}
+        graph = gm.graph
+
+        for node in graph.find_nodes(
+            op="call_function", target=torch.ops.aten.view.dtype
+        ):
+            src = node.args[0]
+            to_type = node.args[1]
+            existing_views = views.get(src)
+            is_needed = True
+
+            if existing_views:
+                # Replace the view with the an existing view if available.
+                alias = existing_views.get(to_type)
+                if alias:
+                    is_needed = False
+                    node.replace_all_uses_with(alias)
+                    alias.meta.update(node.meta)
+                    graph.erase_node(node)
+            else:
+                from_type = src.meta["val"].dtype
+                existing_views = {from_type: src}
+                views[src] = existing_views
+
+            if is_needed:
+                # Save the new alias but do not replace existing one.
+                existing_views.setdefault(to_type, node)
+                views[node] = existing_views
+
+        # Clean up unused views.
+        while True:
+            unused_views = [alias for alias in views if not alias.users]
+            if len(unused_views) == 0:
+                break
+            for unused in unused_views:
+                views.pop(unused)
+                graph.erase_node(unused)
+
+
+class UniformValueConstantFolder(ConstantFolder):
+    """
+    Runs constant folding and replaces tensors that have a uniform value
+    with a tensor constructor call: aten.full([shape], value, ...)
+    """
+
+    def __init__(self, gm, skip_constructors=False) -> None:
+        super().__init__(gm, skip_constructors)
+        self.node_storages_ptrs: dict[torch.fx.Node, int] = {}
+        self.constant_data_ptrs: dict[torch.fx.Node, StorageWeakRef] = {}
+        # we may constant fold a tensor which in the graph has a sym size
+        # see: [constant folding refining of symints]
+        self.node_replacements_shapes: dict[torch.fx.Node, list[int]] = {}
+
+        # initialize symint -> node mapping so that we can
+        # use symint nodes in full constructors
+        self.symint_nodes = _SymHashingDict()
+        for n in self.module.graph.nodes:  # type: ignore[union-attr]
+            if "val" in n.meta and isinstance(n.meta["val"], torch.SymInt):
+                if n.meta["val"] not in self.symint_nodes:
+                    self.symint_nodes[n.meta["val"]] = n
+
+        # reference from torch/_funtorch/partitioners.py:get_default_op_list
+        self.view_op_packets = [
+            aten.squeeze,
+            aten.unsqueeze,
+            aten.alias,
+            aten.view,
+            aten.slice,
+            aten.t,
+            prims.broadcast_in_dim,
+            aten.expand,
+            aten.as_strided,
+            aten.permute,
+        ]
+
+        self.indexing_op_packets = OrderedSet(
+            [
+                aten.slice,
+            ]
+        )
+
+        self._add_peephole_patterns()
+
+    def _add_peephole_patterns(self) -> None:
+        """
+        Add peephole patterns for nodes where we can infer constant value even if some inputs
+        of the node are unknown.
+        """
+        for op in itertools.chain(
+            self.module.graph.find_nodes(  # type: ignore[operator, union-attr]
+                op="call_function", target=torch.ops.aten.mul.Tensor
+            ),
+            self.module.graph.find_nodes(  # type: ignore[operator, union-attr]
+                op="call_function", target=torch.ops.aten.mul.Scalar
+            ),
+        ):
+            tensor_val = op.meta.get("val", None)
+            if not isinstance(tensor_val, torch.Tensor):
+                continue
+
+            def is_zero_int(arg: Any) -> bool:
+                return isinstance(arg, int) and arg == 0
+
+            if not any(is_zero_int(a) for a in op.args):
+                continue
+
+            t = torch.full(
+                [1],  # shape
+                0,  # value
+                dtype=tensor_val.dtype,
+                device=tensor_val.device,
+                pin_memory=False,
+            )
+            self.add_node_replacement(op, t)
+
+    def _support_dynamic_shape(self):
+        return True
+
+    def insertable_tensor_check(self, t: torch.Tensor) -> bool:
+        return True
+
+    def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None:
+        self.node_replacements[node] = tensor.flatten()[0].item()
+        self.node_replacements_shapes[node] = node.meta["val"].shape
+        self.constant_data_ptrs[node] = StorageWeakRef(tensor.untyped_storage())
+
+    def insert_placerholder_values(self, env: dict[torch.fx.Node, Any]) -> None:
+        for n in self.module.graph.find_nodes(op="placeholder"):  # type: ignore[operator, union-attr]
+            if "val" in n.meta and isinstance(n.meta["val"], torch.SymInt):
+                env[n] = n.meta["val"]
+            else:
+                env[n] = self.unknown_value
+
+    def _deduce_value(self, node: torch.fx.Node):
+        # deduce value for full-like nodes
+        # 1. for constructors, substitute value is a tensor of size [1]
+        # 2. for view ops/indexing, substitute value is the same as the input
+        # 3. for pointwise ops, run node to get the substitute value
+        # 4. deal with some special ops
+        # otherwise, stop deduce value and return unknown value
+
+        # TODO: cat, more indexing
+        # TODO - do on cpu to avoid syncs
+
+        # single-elem attrs
+        if node.op == "get_attr" or (
+            node.op == "call_function"
+            and node.target is torch.ops.aten.lift_fresh_copy.default
+        ):
+            out = super(ConstantFolder, self).run_node(node)
+            if isinstance(out, torch.Tensor) and out.numel() == 1:
+                return out
+
+        # handle device_put op
+        if node.target == prims.device_put.default:
+            return super(ConstantFolder, self).run_node(node)
+
+        # constructors ops
+        if (
+            node.op == "call_function"
+            and node.target is aten.full.default
+            and len(node.args) == 2
+        ):
+            args, kwargs = self.fetch_args_kwargs_from_env(node)
+            value = args[1]
+            # Don't specialize symbolic value.
+            if not isinstance(value, (torch.SymInt, torch.SymFloat, torch.SymBool)):
+                new_args = [[1], value]
+                return aten.full.default(*new_args, **node.kwargs)
+
+        # handle before view ops because this changes value
+        if node.target is aten.view.dtype:
+            return super(ConstantFolder, self).run_node(node)
+
+        # view ops, return input tensor, the first argument
+        if hasattr(node.target, "overloadpacket") and (
+            node.target.overloadpacket in self.view_op_packets
+            or node.target.overloadpacket in self.indexing_op_packets
+        ):
+            assert isinstance(node.args[0], torch.fx.Node)
+            return self.env[node.args[0]]
+
+        # we don't want to return unknown value for symints so that we can
+        # still constant fold through their use in constructors or views
+        # if we see them in a pointwise node (e.g., tensor * symint)
+        # we will bail
+        if "val" in node.meta and isinstance(node.meta["val"], torch.SymInt):
+            return node.meta["val"]
+
+        # pointwise ops
+        if isinstance(node.target, torch._ops.OpOverload) and (
+            torch.Tag.pointwise in node.target.tags
+            or node.target is torch.ops.aten.scalar_tensor.default
+        ):
+            args, kwargs = self.fetch_args_kwargs_from_env(node)
+            flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs)
+
+            if any(isinstance(inp, torch.SymInt) for inp in flattened_inputs):
+                return self.unknown_value
+
+            # we run the ops with dim 1, so remove memory_format to avoid error
+            kwargs = dict(kwargs)
+            kwargs.pop("memory_format", None)
+
+            return node.target(*args, **kwargs)
+
+        return self.unknown_value
+
+
+def constant_fold_uniform_value(gm: torch.fx.GraphModule):
+    with torch.utils._python_dispatch._disable_current_modes():
+        "Runs constant folding and replaces constants which can be constructed with a single `full` call. Calls into remove_no_ops."
+        aten = torch.ops.aten
+
+        # Constant folding can leak memory, especially with repeated compilation, so we are only going to
+        # remove constants which can be replaced with a constructor.
+        cf = UniformValueConstantFolder(gm)
+        cf.run()
+
+        node_replacements = cf.node_replacements
+
+        # note: [constant folding refining of symints]
+        # constant folding will partially evaluate a graph such that values which have dependencies which
+        # are entirely known at compile time may also become compile time constants. in some cases,
+        # this will include symints which we had not yet previously deduced are guaranteed a
+        # constant value and is then deduced in constant folding. an example is:
+        # unbacked_symint_eq_11 = torch.full((), 11).item()
+        # torch.full((unbacked_symint_eq_11,), 0)
+        node_replacements_shapes = cf.node_replacements_shapes
+
+        graph = gm.graph
+
+        zeros = OrderedSet[Any]()
+        ones = OrderedSet[Any]()
+
+        # Got failures in `test_is_set_to_cuda` if we change aliasing on constants,
+        # so just constant-ify if a Tensor is unaliased
+        constant_data_ptr_count: typing.Counter[StorageWeakRef] = Counter()
+
+        for node in cf.node_replacements:
+            constant_data_ptr_count[cf.constant_data_ptrs[node]] += 1
+
+        for node, value in node_replacements.items():
+            # we dont have a functional way right now of instantiating a non-contiguous tensor with full/zeros/ones right now
+            # hasn't shown up to be important yet
+            if "val" not in node.meta:
+                # This can only happen in AOTI
+                continue
+
+            fake_tensor = node.meta["val"]
+            if not fake_tensor.is_contiguous(memory_format=torch.contiguous_format):
+                continue
+
+            # TODO - not sure about lossy uint->python value->uint conversions
+            if fake_tensor.dtype in (
+                torch.uint8,
+                torch.uint16,
+                torch.uint32,
+                torch.uint64,
+            ):
+                continue
+
+            if constant_data_ptr_count[cf.constant_data_ptrs[node]] > 1:
+                continue
+
+            with graph.inserting_after(node):
+                # the conversion from tensor and back to value can be lossy, just use the original full ctor value
+                if (
+                    node.op == "call_function"
+                    and node.target is aten.full.default
+                    and len(node.args) == 2
+                ):
+                    value = node.args[1]
+
+                # refines symints, see [constant folding refining of symints] above
+                for runtime_size, compile_time_size in zip(
+                    node_replacements_shapes[node], fake_tensor.shape
+                ):
+                    torch._check(runtime_size == compile_time_size)
+
+                # replace SymInt as Node before creating a new full node
+                # e.g. (1, s0) -> (1, arg0_1)
+                node_shape = node_replacements_shapes[node]
+                if not all(
+                    not isinstance(s, torch.SymInt) or s in cf.symint_nodes
+                    for s in node_shape
+                ):
+                    continue
+
+                shapes = [
+                    cf.symint_nodes[s] if isinstance(s, torch.SymInt) else s
+                    for s in node_replacements_shapes[node]
+                ]
+
+                # zeros and ones just get traced into full, so we insert those
+                new_node = graph.call_function(
+                    aten.full.default,
+                    args=(shapes, value),
+                    kwargs={
+                        "dtype": fake_tensor.dtype,
+                        "layout": torch.strided,
+                        "device": fake_tensor.device,
+                        "pin_memory": False,
+                    },
+                )
+
+                new_node.meta.update(node.meta)
+                node.replace_all_uses_with(new_node)
+                graph.erase_node(node)
+
+                if value == 0:
+                    zeros.add(new_node)
+                elif value == 1:
+                    ones.add(new_node)
+
+        remove_no_ops(gm, zeros, ones)
+        remove_redundant_views(gm)
+
+
+def canonicalize_quant_mapping(gm: torch.fx.GraphModule):
+    """
+
+
+    torch.ops.higher_order.invoke_quant_packed(repeated_subgraph0, 'quant_invoke_0_0', (arg0_1, arg1_1));
+    ->
+    torch.ops.higher_order.invoke_quant(repeated_subgraph0, arg0_1, arg1_1, scheme = 'nf4');
+    """
+    graph = gm.graph
+    invoke_quant_invocations = graph.find_nodes(
+        op="call_function", target=torch.ops.higher_order.invoke_quant_packed
+    )
+    for invoke_quant in invoke_quant_invocations:
+        kwargs = dict(invoke_quant.kwargs)
+
+        quant_options_node = kwargs.pop("quant_options", None)
+        if quant_options_node is not None:
+            assert isinstance(quant_options_node, torch.fx.Node)
+            quant_options = torch._higher_order_ops.InvokeQuant(
+                *invoke_quant.kwargs["quant_options"].args,
+                **invoke_quant.kwargs["quant_options"].kwargs,
+            )
+        else:
+            quant_options = torch._higher_order_ops.InvokeQuant()
+
+        subgraph, *args = invoke_quant.args
+        with gm.graph.inserting_before(invoke_quant):
+            invoke_quant_replacement = graph.call_function(
+                torch._higher_order_ops.invoke_quant,
+                (subgraph, *args),
+                # pyrefly: ignore [bad-argument-type]
+                kwargs,
+            )
+            invoke_quant_replacement.meta.update(subgraph.meta)
+            invoke_quant_replacement.meta["quant_options"] = quant_options
+
+            invoke_quant.replace_all_uses_with(invoke_quant_replacement)
+            graph.erase_node(invoke_quant)
+
+            if quant_options_node and len(quant_options_node.users) == 0:
+                graph.erase_node(quant_options_node)
+
+            first_user = next(iter(invoke_quant_replacement.users))
+
+            if (
+                len(invoke_quant_replacement.users) == 1
+                and len(subgraph.users) == 1
+                and first_user.target is operator.getitem
+                and first_user.args[1] == 0
+            ):
+                subgraph_graph = getattr(gm, subgraph.target)
+                output_node = torch._inductor.utils.output_node(subgraph_graph)
+                assert (
+                    isinstance(output_node.args[0], (list, tuple))
+                    and len(output_node.args[0]) == 1
+                )
+
+                unpacked_output = output_node.args[0][0]
+                output_node.args = (unpacked_output,)
+                if "val" in output_node.meta:
+                    output_node.meta["val"] = output_node.meta["val"][0]
+                subgraph_graph.recompile()
+
+                invoke_quant_replacement.meta.update(first_user.meta)
+                first_user.replace_all_uses_with(invoke_quant_replacement)
+                graph.erase_node(first_user)
+
+
+def canonicalize_aten_ir_passes(gm: torch.fx.GraphModule):
+    """
+    Canonicalization passes that will run immediately after aot autograd
+    tracing. Thsis must be run before all other graph passes.
+    """
+    canonicalize_quant_mapping(gm)
+
+
+def joint_graph_passes(graph: torch.fx.GraphModule):
+    """
+    Run FX transformations on the joint forwards+backwards graph.
+    """
+    GraphTransformObserver = functools.partial(
+        torch.fx.passes.graph_transform_observer.GraphTransformObserver,
+        subsystem="joint_graph_passes",
+    )
+
+    lazy_init()
+    count = 0
+
+    # must occur before other passes
+    canonicalize_aten_ir_passes(graph)
+
+    if config.joint_custom_pre_pass is not None:
+        GraphTransformObserver(graph, "joint_custom_pre_pass").apply_graph_pass(
+            config.joint_custom_pre_pass
+        )
+        count += 1
+
+    from .post_grad import remove_noop_ops
+
+    GraphTransformObserver(graph, "remove_noop_ops").apply_graph_pass(remove_noop_ops)
+
+    if config.joint_graph_constant_folding:
+        GraphTransformObserver(graph, "constant_fold_uniform_value").apply_gm_pass(
+            constant_fold_uniform_value
+        )
+
+    if config.pattern_matcher:
+        for i, patterns in enumerate(pass_patterns):
+            maybe_count = GraphTransformObserver(
+                graph, f"pass_pattern_{i}"
+            ).apply_graph_pass(patterns.apply)
+            count += maybe_count if maybe_count is not None else 0
+
+    if not config.fallback_random:
+        # not trying into the bisector because decomps may have already affected rng reproducibility
+        # we'll instead explicitly turn off the config
+        count += replace_random_passes(graph)
+
+    if config.joint_custom_post_pass is not None:
+        GraphTransformObserver(graph, "joint_custom_post_pass").apply_graph_pass(
+            config.joint_custom_post_pass
+        )
+        count += 1
+
+    if count:
+        stable_topological_sort(graph.graph)
+        graph.graph.lint()
+        graph.recompile()
+    return graph
+
+
+@register_graph_pattern(
+    CallFunction(
+        torch.ops.prims.iota.default,
+        KeywordArg("length"),
+        start=KeywordArg("start"),
+        step=KeywordArg("step"),
+        dtype=KeywordArg("dtype"),
+        device=KeywordArg("device"),
+        requires_grad=KeywordArg("requires_grad"),
+    ),
+    # pyrefly: ignore [bad-argument-type]
+    pass_dict=patterns,
+)
+def fix_iota_device(match: Match, length, start, step, dtype, device, requires_grad):
+    """
+    Eager supports:
+
+        aten.index(cuda_tensor, torch.arange(..., device="cpu"))
+
+    But this results in an implicit host-device-copy and breaks cudagraphs.
+    Rewrite the arange to use CUDA.
+    """
+    (node,) = match.nodes
+    user_devices = OrderedSet[torch.device]()
+    for user in node.users:
+        if (
+            user.op == "call_function"
+            and user.target in (aten.index.Tensor, aten.index_put.default)
+            and hasattr(user.meta.get("val"), "device")
+        ):
+            user_devices.add(user.meta["val"].device)  # type: ignore[union-attr]
+        else:
+            return  # bail out
+
+    if len(user_devices) == 1 and "val" in node.meta:
+        (user_device,) = user_devices
+        if device.type != user_device.type:
+            repl = match.graph.call_function(
+                torch.ops.prims.iota.default,
+                (length,),
+                {
+                    "start": start,
+                    "step": step,
+                    "dtype": dtype,
+                    "device": user_device,
+                    "requires_grad": requires_grad,
+                },
+            )
+            repl.meta.update(node.meta)
+            repl.meta["val"] = repl.meta["val"].to(user_device)
+            node.replace_all_uses_with(repl)
+            match.erase_nodes()
+
+
+@register_graph_pattern(
+    CallFunction(
+        torch.ops.prims.convert_element_type.default,
+        CallFunction(
+            torch.ops.prims.convert_element_type.default,
+            KeywordArg("arg"),
+            KeywordArg("dtype1"),
+        ),
+        KeywordArg("dtype2"),
+    ),
+    # pyrefly: ignore [bad-argument-type]
+    pass_dict=patterns,
+)
+def pointless_convert(match: Match, arg, dtype1: torch.dtype, dtype2: torch.dtype):
+    """Remove chain of dtype conversions often created by AMP"""
+    graph = match.graph
+    node = match.output_node()
+    allowed = torch.float16, torch.bfloat16, torch.float32, torch.float64
+    if dtype1 in allowed and dtype2 in allowed:
+        repl = graph.call_function(
+            torch.ops.prims.convert_element_type.default, (arg, dtype2)
+        )
+        repl.meta.update(node.meta)
+        node.replace_all_uses_with(repl)
+        match.erase_nodes()
+
+
+def definitely_equal(
+    old_sizes: Sequence[torch.SymInt | int],
+    new_sizes: Sequence[torch.SymInt | torch.fx.Node | int],
+) -> bool:
+    """
+    Leverage guard_or_true/false to compare if two lists of int/symint are equal.
+    Useful to compare sizes, strides etc.
+
+    Can handle -1 in new_sizes which happens in the size arguments of a
+    view op. old_sizes is supposed to be the tensor shape and should not
+    contain -1.
+
+    new_sizes can contains fx.Node when dynamic shape is enabled. In that
+    case new_sizes[i].meta['val'] contains the real torch.SymInt.
+    """
+
+    num_neg1 = 0
+
+    if len(old_sizes) != len(new_sizes):
+        return False
+
+    for lhs_item, rhs_item in zip(old_sizes, new_sizes):
+        if isinstance(rhs_item, torch.fx.Node):
+            rhs_item = rhs_item.meta["val"]
+
+        assert isinstance(lhs_item, (int, torch.SymInt)), type(lhs_item)
+        assert isinstance(rhs_item, (int, torch.SymInt)), type(rhs_item)
+
+        # It still makes sense to call guard_or_true/false since lhs_item
+        # rhs_item are torch.SymInt rather than sympy expressions when
+        # dynamic shape is enabled.
+        if guard_or_false(lhs_item == rhs_item):
+            continue
+
+        if guard_or_true(rhs_item != -1):
+            return False
+
+        num_neg1 += 1
+
+        if num_neg1 > 1:
+            return False
+    return True
+
+
+@register_graph_pattern(
+    CallFunction(torch.ops.aten.view.default, KeywordArg("arg"), KeywordArg("size")),
+    # pyrefly: ignore [bad-argument-type]
+    pass_dict=patterns,
+)
+def pointless_view(match: Match, arg, size):
+    """Remove no-op view"""
+    node = match.output_node()
+    arg_size = list(node.args[0].meta["val"].shape)  # type: ignore[union-attr]
+    if definitely_equal(arg_size, size):
+        node.replace_all_uses_with(node.args[0])  # type: ignore[arg-type]
+        match.erase_nodes()
+
+
+@register_graph_pattern(
+    CallFunction(
+        aten.view.default,
+        CallFunction(aten.view.default, KeywordArg("arg"), KeywordArg("size1")),
+        KeywordArg("size2"),
+    ),
+    # pyrefly: ignore [bad-argument-type]
+    pass_dict=patterns,
+)
+def pointless_view_pair(match: Match, arg, size1, size2):
+    """
+    Remove a pair of views that are pointless.
+    """
+    node = match.output_node()
+    arg_size = list(arg.meta["val"].shape)
+    if definitely_equal(arg_size, size2):
+        node.replace_all_uses_with(arg)
+        match.erase_nodes()
+        counters["inductor"]["removed_pointless_view_pair"] += 1
+
+
+@register_graph_pattern(
+    CallFunction(
+        aten.permute.default,
+        CallFunction(aten.permute.default, KeywordArg("arg"), KeywordArg("perm1")),
+        KeywordArg("perm2"),
+    ),
+    # pyrefly: ignore [bad-argument-type]
+    pass_dict=patterns,
+)
+def pointless_permute_pair(match: Match, arg, perm1, perm2):
+    rank = len(perm1)
+    assert len(perm2) == rank
+
+    for i in range(rank):
+        if perm1[perm2[i]] != i:
+            return  # bail out
+    node = match.output_node()
+    node.replace_all_uses_with(arg)
+    match.erase_nodes()
+
+
+@register_graph_pattern(
+    CallFunction(
+        aten.bmm,
+        Arg(),
+        Arg(),
+    ),
+    # pyrefly: ignore [bad-argument-type]
+    pass_dict=patterns,
+)
+def bmm_to_mm(match: Match, mat1: torch.fx.Node, mat2: torch.fx.Node):
+    """Convert bmm to mm when batch size is 1"""
+
+    def repl(a, b):
+        return torch.mm(a.squeeze(0), b.squeeze(0)).unsqueeze(0)
+
+    if (
+        check_device(mat1.meta["val"], mat2.meta["val"], get_gpu_type())
+        and statically_known_true(mat1.meta["val"].shape[0] == 1)
+        and statically_known_true(mat2.meta["val"].shape[0] == 1)
+    ):
+        # pyrefly: ignore [bad-argument-type]
+        match.replace_by_example(repl, [mat1, mat2])
+
+
+# When softmax is used with temperature or other scaling, we get the pattern
+#
+#   scale(x) - scale(x).amax(dim, keepdim=True)
+#
+# which is expected to be at most zero, but we may end up with numerical
+# discrepancies # between the recomputed values of scale(x) inside and out
+# of the reduction, # depending on compiler optimizations, e.g. use of fma
+# instructions.
+#
+# Here we replace it with the mathematically equivalent,
+#
+#   scale(x - x.amax(dim, keepdim=True))
+#
+# which is more stable as we only compute the scaling once.
+#
+# NOTE: This pattern must come after fused attention matching!
+
+
+def _partial_softmax_pattern(linear_func, reverse=False, to_dtype=False):
+    # Allow matching inp * other and other * input
+    if reverse:
+        scaled = CallFunction(
+            linear_func, KeywordArg("other"), KeywordArg("inp"), _users=MULTIPLE
+        )
+    else:
+        scaled = CallFunction(
+            linear_func, KeywordArg("inp"), KeywordArg("other"), _users=MULTIPLE
+        )
+    if to_dtype:
+        scaled = CallFunction(
+            prims.convert_element_type, scaled, KeywordArg("dtype"), _users=MULTIPLE
+        )
+    amax = CallFunction(
+        aten.amax.default, scaled, KeywordArg("dim"), KeywordArg("keepdim")
+    )
+    return CallFunction(aten.sub.Tensor, scaled, amax)
+
+
+def _other_is_broadcasted_in_dim(match):
+    # Check that the scaling factor is constant across the reduction dim,
+    # so scaling doesn't change which index corresponds to the maximum value
+    other = match.kwargs["other"]
+    if isinstance(other, (int, float)):
+        return True
+
+    inp = match.kwargs["inp"]
+    if not all(isinstance(x, torch.fx.Node) for x in (inp, other)):
+        return False
+
+    inp_example = inp.meta["val"]
+    other_example = other.meta["val"]
+    if isinstance(other_example, (torch.SymInt, torch.SymFloat)):
+        return True
+
+    if not all(isinstance(x, torch.Tensor) for x in (inp_example, other_example)):
+        return False
+
+    inp_ndim = inp_example.ndim
+    other_shape = other_example.shape
+    if inp_ndim < len(other_shape):
+        return False
+
+    # Pad other_shape to the same ndim as inp
+    other_shape = [1] * (inp_ndim - len(other_shape)) + list(other_shape)
+
+    dim = match.kwargs["dim"]
+    if isinstance(dim, int):
+        dim = (dim,)
+
+    if any(d >= len(other_shape) for d in dim):
+        return False
+
+    return all(statically_known_true(other_shape[d] == 1) for d in dim)
+
+
+def mul_softmax_pattern(match: Match, *, inp, other, dim, keepdim, dtype=None):
+    def repl(inp, other):
+        if dtype is not None:
+            inp = inp.to(dtype)
+
+        sign: int | float | torch.Tensor
+        if isinstance(other, (int, float, torch.SymInt, torch.SymFloat)):
+            sign = 1 if other >= 0 else -1
+        else:
+            one = torch.scalar_tensor(1, dtype=inp.dtype, device=inp.device)
+            sign = torch.where(other >= 0, one, -one)
+
+        inp = inp * sign
+        max_ = torch.amax(inp, dim=dim, keepdim=keepdim)
+        # pyrefly: ignore [unsupported-operation]
+        return (inp - max_) * (sign * other)
+
+    # pyrefly: ignore [bad-argument-type]
+    match.replace_by_example(repl, [inp, other])
+
+
+for reverse, to_dtype in itertools.product((False, True), repeat=2):
+    register_graph_pattern(
+        _partial_softmax_pattern(aten.mul.Tensor, reverse=reverse, to_dtype=to_dtype),
+        # pyrefly: ignore [bad-argument-type]
+        pass_dict=pass_patterns[1],
+        extra_check=_other_is_broadcasted_in_dim,
+    )(mul_softmax_pattern)
+
+
+def div_softmax_pattern(match: Match, *, inp, other, dim, keepdim, dtype=None):
+    def repl(inp, other):
+        if dtype is not None:
+            inp = inp.to(dtype)
+
+        sign: int | float | torch.Tensor
+        if isinstance(other, (int, float, torch.SymInt, torch.SymFloat)):
+            sign = 1 if other >= 0 else -1
+        else:
+            one = torch.scalar_tensor(1, dtype=inp.dtype, device=inp.device)
+            sign = torch.where(other >= 0, one, -one)
+
+        inp = inp * sign
+        max_ = torch.amax(inp, dim=dim, keepdim=keepdim)
+        # pyrefly: ignore [unsupported-operation]
+        return (inp - max_) / (sign * other)
+
+    # pyrefly: ignore [bad-argument-type]
+    match.replace_by_example(repl, [inp, other])
+
+
+for to_dtype in (False, True):
+    register_graph_pattern(
+        _partial_softmax_pattern(aten.div.Tensor, to_dtype=to_dtype),
+        # pyrefly: ignore [bad-argument-type]
+        pass_dict=pass_patterns[1],
+        extra_check=_other_is_broadcasted_in_dim,
+    )(div_softmax_pattern)
+
+
+def scatter_upon_const_tensor_extra_check(m):
+    if not config.optimize_scatter_upon_const_tensor:
+        return False
+    full_shape = m.kwargs["shape"]
+    selector = m.kwargs["selector"]
+    dim = m.kwargs["dim"]
+    if dim < 0:
+        dim += len(full_shape)
+
+    selector_ft = selector.meta["val"]
+    assert selector_ft.dim() == len(full_shape)
+
+    for idx, select_sz, full_sz in zip(
+        itertools.count(), selector_ft.shape, full_shape
+    ):
+        if idx == dim:
+            continue
+
+        # TODO: the pattern can be updated to support the case that index tensor
+        # is shorter. But that will need a more complex condition expression
+        # especially for multi-dimensional tensors.
+        # Skip it for now.
+        if isinstance(full_sz, torch.fx.Node):
+            full_sz = full_sz.meta["val"]
+        if select_sz < full_sz:
+            return False
+
+    # Actually we can support small size larger than 1. It would be a bit
+    # tedious. E.g., we load all the index values (not many) and compare
+    # them with the position in tensor to decide what value to return.
+    return selector_ft.size(dim) == 1
+
+
+@register_graph_pattern(
+    CallFunction(
+        aten.scatter.value,
+        CallFunction(
+            aten.full,
+            KeywordArg("shape"),
+            KeywordArg("background_val"),
+            dtype=KeywordArg("dtype"),
+        ),
+        KeywordArg("dim"),
+        KeywordArg("selector"),
+        KeywordArg("val"),  # scalar value
+    ),
+    # pyrefly: ignore [bad-argument-type]
+    pass_dict=patterns,
+    extra_check=scatter_upon_const_tensor_extra_check,
+)
+def scatter_upon_const_tensor(
+    match: Match, shape, background_val, dtype, dim, selector, val
+):
+    """
+    Match the pattern of full+scatter into a pointwise operation in joint graph.
+
+    TODO: Right now the scatter value must be a scalar. But we could support it
+    when it is a tensor as well.
+    """
+    from torch._inductor import metrics
+
+    # pyrefly: ignore  # bad-assignment
+    metrics.num_matches_for_scatter_upon_const_tensor += 1
+
+    # Create a replacement that uses torch.where for the pointwise operation
+    def repl_fn(shape, background_val, dim, selector, val):
+        # Create a tensor of indices for the scatter dimension
+        length = shape[dim]
+        indices = torch.arange(length, device=selector.device, dtype=torch.int64)
+
+        # Reshape indices to have size 'length' at dim, then broadcast
+        view_shape = [1] * len(shape)
+        view_shape[dim] = length
+        indices_view = indices.view(*view_shape)
+
+        # Broadcast selector to match full tensor shape
+        selector_expanded = selector.expand(shape)
+
+        # Create a mask for where to scatter
+        mask = selector_expanded == indices_view
+
+        # Use torch.where to implement the scatter pointwise operation
+        return torch.where(mask, val, background_val)
+
+    # replace the scatter operation with pointwise equivalent
+    # pyrefly: ignore [bad-argument-type]
+    match.replace_by_example(repl_fn, [shape, background_val, dim, selector, val])
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/memory_estimator.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/memory_estimator.py
new file mode 100644
index 0000000000000000000000000000000000000000..e887d4bf62c8e11196ac5b2740c0ef3c39e64def
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/memory_estimator.py
@@ -0,0 +1,454 @@
+import itertools
+import logging
+from collections import defaultdict
+from collections.abc import Callable
+from dataclasses import dataclass
+
+import torch
+import torch.fx as fx
+from torch.fx.experimental.symbolic_shapes import hint_int
+from torch.utils._ordered_set import OrderedSet
+from torch.utils._pytree import tree_map_only
+
+
+log = logging.getLogger(__name__)
+
+
+@dataclass(frozen=True)
+class StorageKey:
+    storage: torch.UntypedStorage
+    device: torch.device
+
+    def __hash__(self) -> int:
+        return self.storage._cdata
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, StorageKey):
+            return False
+        return (
+            self.storage._cdata == other.storage._cdata and self.device == other.device
+        )
+
+
+class GraphAliasTracker:
+    """
+    Tracks storage allocation and usage relationships in an FX graph.
+
+    Differentiates between:
+    - Fresh allocations: nodes that allocate new storage (not views/aliases)
+    - Uses: nodes that use a storage as input
+    """
+
+    def __init__(self, nodes: list[fx.Node]):
+        # Map from node to the fresh storages it allocates (not views/aliases)
+        self.node_to_fresh_allocations: dict[fx.Node, OrderedSet[StorageKey]] = {}
+
+        # Map from storage to the node that originally allocated it
+        self.storage_to_allocator: dict[StorageKey, fx.Node] = {}
+
+        # Map from node to all storages it uses as inputs
+        self.node_to_storage_uses: dict[fx.Node, OrderedSet[StorageKey]] = {}
+
+        # Map from storage to all nodes that use it
+        self.storage_to_uses: dict[StorageKey, OrderedSet[fx.Node]] = defaultdict(
+            OrderedSet
+        )
+
+        # Map from storage to the last node that uses it
+        self.storage_to_last_user: dict[StorageKey, fx.Node] = {}
+
+        # Map from node to storages that have their last use at that node
+        self.node_to_storages_last_used: dict[fx.Node, OrderedSet[StorageKey]] = (
+            defaultdict(OrderedSet)
+        )
+
+        # Track all output storages for each node (for building usage graph)
+        self.node_to_output_storages: dict[fx.Node, OrderedSet[StorageKey]] = {}
+
+        # First pass: build storage allocations and track uses
+        for node in nodes:
+            # Get output storages
+            output_storages = self._get_output_storages(node)
+            self.node_to_output_storages[node] = output_storages
+
+            # Track fresh allocations
+            fresh_allocations: OrderedSet[StorageKey] = OrderedSet()
+            for storage_key in output_storages:
+                if storage_key not in self.storage_to_allocator:
+                    self.storage_to_allocator[storage_key] = node
+                    fresh_allocations.add(storage_key)
+            self.node_to_fresh_allocations[node] = fresh_allocations
+
+            # Track input storage uses (safe because inputs were already processed)
+            input_storages = self._get_input_storages(node)
+            self.node_to_storage_uses[node] = input_storages
+            for storage_key in input_storages:
+                self.storage_to_uses[storage_key].add(node)
+
+        # Second pass: find last users (iterate in reverse)
+        for node in reversed(nodes):
+            input_storages = self.node_to_storage_uses[node]
+            for storage_key in input_storages:
+                if storage_key not in self.storage_to_last_user:
+                    self.storage_to_last_user[storage_key] = node
+                    self.node_to_storages_last_used[node].add(storage_key)
+
+    @staticmethod
+    def _get_output_storages(node: fx.Node) -> OrderedSet[StorageKey]:
+        """
+        Get all storages from a node's outputs.
+
+        Uses pytree to handle arbitrary nested structures.
+        """
+        val = node.meta.get("val")
+        if val is None:
+            return OrderedSet()
+
+        storages: OrderedSet[StorageKey] = OrderedSet()
+
+        def collect_storage(tensor: torch._subclasses.FakeTensor) -> None:
+            storages.add(StorageKey(tensor.untyped_storage(), tensor.device))
+
+        # Use tree_map_only to handle FakeTensors in nested structures
+        tree_map_only(torch._subclasses.FakeTensor, collect_storage, val)
+
+        return storages
+
+    def _get_input_storages(self, node: fx.Node) -> OrderedSet[StorageKey]:
+        """
+        Get all storages from a node's inputs.
+        """
+        input_storages: OrderedSet[StorageKey] = OrderedSet()
+
+        for input_node in node.all_input_nodes:
+            input_storages.update(self.node_to_output_storages[input_node])
+
+        return input_storages
+
+    def get_fresh_allocations(self, node: fx.Node) -> OrderedSet[StorageKey]:
+        """Get all fresh storage allocations by this node (not views/aliases)."""
+        return self.node_to_fresh_allocations[node]
+
+    def get_storage_uses(self, node: fx.Node) -> OrderedSet[StorageKey]:
+        """Get all storages that this node uses as inputs."""
+        return self.node_to_storage_uses[node]
+
+    def get_storages_last_used(
+        self,
+        node: fx.Node,
+    ) -> OrderedSet[StorageKey]:
+        """
+        Get storages whose last use is at this node.
+        """
+        return self.node_to_storages_last_used[node]
+
+
+def _size_of_default(num_bytes: int | torch.SymInt) -> int:
+    return hint_int(num_bytes, fallback=torch._inductor.config.unbacked_symint_fallback)
+
+
+def device_filter(device: torch.device) -> bool:
+    return device.type != "cpu"
+
+
+def build_memory_profile(
+    graph: fx.Graph,
+    is_releasable: Callable[[fx.Node], bool],
+    size_of: Callable[[int | torch.SymInt], int] | None = None,
+) -> list[int]:
+    """
+    Function to estimate the memory profile of an input FX graph.
+
+    Args:
+    - graph (fx.Graph): The input FX graph for which the memory profile
+      is to be estimated.
+    - is_releasable (Callable[[fx.Node], bool]): A function that
+      determines if a node's memory can be released (e.g. primal nodes
+      cannot be released).
+    - size_of (Callable[[int | torch.SymInt], int]): A function that converts
+      byte counts (possibly symbolic) to concrete integers.
+
+    Returns:
+    - List[int]: A list representing the memory profile over the execution
+      of the graph, where each entry corresponds to the memory usage at
+      a particular point in the execution.
+    """
+
+    size_of = size_of or _size_of_default
+    nodes = list(graph.nodes)
+    alias_info = GraphAliasTracker(nodes)
+
+    # Build memory profile
+    current_memory = 0
+
+    for node in itertools.chain(
+        graph.find_nodes(op="placeholder"), graph.find_nodes(op="get_attr")
+    ):
+        for storage_key in alias_info.get_fresh_allocations(node):
+            if device_filter(storage_key.device):
+                current_memory += size_of(storage_key.storage.nbytes())
+
+    memory_profile = [current_memory]
+
+    for node in nodes:
+        if node.op in ("placeholder", "get_attr", "output"):
+            continue
+
+        # Process allocations
+        for storage_key in alias_info.get_fresh_allocations(node):
+            if device_filter(storage_key.device):
+                current_memory += size_of(storage_key.storage.nbytes())
+
+        memory_profile.append(current_memory)
+
+        # Process deallocations
+        for storage_key in alias_info.get_storages_last_used(node):
+            allocator = alias_info.storage_to_allocator[storage_key]
+            if is_releasable(allocator):
+                if device_filter(storage_key.device):
+                    current_memory -= size_of(storage_key.storage.nbytes())
+
+        memory_profile.append(current_memory)
+
+    return memory_profile
+
+
+def get_fwd_bwd_interactions(
+    fwd_graph: fx.Graph,
+    bwd_graph: fx.Graph,
+    size_of: Callable[[int | torch.SymInt], int] | None = None,
+) -> tuple[int, OrderedSet[str]]:
+    """
+    Analyze the interactions between the forward (fwd) and backward (bwd) graphs
+    to determine memory usage characteristics.
+
+    Args:
+    - fwd_graph (fx.Graph): The forward graph representing the forward pass.
+    - bwd_graph (fx.Graph): The backward graph representing the backward pass.
+    - size_of (Callable[[int | torch.SymInt], int]): A function that converts
+      byte counts (possibly symbolic) to concrete integers.
+
+    Returns:
+    - tuple[int, OrderedSet[str]]: A tuple containing:
+        1. The baseline memory usage during the backward pass, accounting for
+           storages that persist from the forward pass (i.e., in fwd output but
+           not in bwd input).
+        2. A set of node names whose storage cannot be released during the bwd pass.
+           These include nodes that use storage from primals or are in bwd input
+           but not in fwd output.
+    """
+
+    size_of = size_of or _size_of_default
+
+    # Build alias info for forward graph
+    fwd_nodes = list(fwd_graph.nodes)
+    fwd_alias_info = GraphAliasTracker(fwd_nodes)
+
+    # Identify storages allocated by primal placeholder nodes
+    primal_storages: OrderedSet[StorageKey] = OrderedSet()
+    for node in fwd_graph.find_nodes(op="placeholder"):
+        if node.name.startswith("primals"):
+            primal_storages.update(fwd_alias_info.get_fresh_allocations(node))
+
+    # Get storages in forward output
+    fwd_output_node = next(iter(reversed(fwd_graph.nodes)))[-1]
+    assert fwd_output_node.op == "output"
+    fwd_output_storages = fwd_alias_info.get_storage_uses(fwd_output_node)
+
+    # Node names that should not be deleted during memory profile estimation of bwd_graph
+    do_not_delete: OrderedSet[str] = OrderedSet()
+
+    # Collect all storages in backward inputs and identify nodes to not delete
+    bwd_input_storages: OrderedSet[StorageKey] = OrderedSet()
+    for node in bwd_graph.find_nodes(op="placeholder"):
+        node_storages = GraphAliasTracker._get_output_storages(node)
+        bwd_input_storages.update(node_storages)
+
+        # Check if this node uses primal storage
+        if node_storages & primal_storages:
+            do_not_delete.add(node.name)
+
+        # Check if this node's storages are not in forward outputs
+        # (meaning it's an external input to backward pass)
+        if not (node_storages & fwd_output_storages):
+            do_not_delete.add(node.name)
+
+    # Calculate baseline memory: storages in fwd output but not in bwd input
+    # These storages persist throughout the backward pass
+    baseline_storages = fwd_output_storages - bwd_input_storages
+    bwd_baseline_memory = 0
+    for storage_key in baseline_storages:
+        if storage_key.device.type != "cpu":
+            bwd_baseline_memory += size_of(storage_key.storage.nbytes())
+
+    return bwd_baseline_memory, do_not_delete
+
+
+def _is_releasable(n: fx.Node) -> bool:
+    # Storages of primals cannot be released during fwd or bwd pass.
+    return not n.name.startswith("primals")
+
+
+def get_peak_memory(
+    fwd_graph: fx.Graph,
+    bwd_graph: fx.Graph,
+) -> int:
+    fwd_peak_memory = max(build_memory_profile(fwd_graph, _is_releasable))
+
+    bwd_baseline_memory, bwd_do_not_delete = get_fwd_bwd_interactions(
+        fwd_graph,
+        bwd_graph,
+    )
+
+    def _is_bwd_releasable(n: fx.Node) -> bool:
+        # Storages of nodes in bwd_do_not_delete cannot be released
+        # during the bwd pass.
+        return _is_releasable(n) and n.name not in bwd_do_not_delete
+
+    bwd_peak_memory = bwd_baseline_memory + max(
+        build_memory_profile(bwd_graph, _is_bwd_releasable)
+    )
+    return max(
+        fwd_peak_memory,
+        bwd_peak_memory,
+    )
+
+
+class MemoryTracker:
+    """
+    Tracks memory usage for alternative scheduling orders of an FX graph.
+
+    This class enables tracking memory usage as nodes are scheduled in a different
+    order than the original graph.
+    """
+
+    def __init__(
+        self,
+        graph: fx.Graph,
+        is_releasable: Callable[[fx.Node], bool] | None = None,
+        device_filter: Callable[[torch.device], bool] | None = None,
+    ):
+        """
+        Initialize memory tracker for alternative scheduling of the given graph.
+
+        Args:
+            graph: FX graph to track memory for under alternative scheduling
+            is_releaseable: do we consider this input to the graph to release memory
+            upon final use, or is allocated for the duration of the graph ?
+            by default, we assume all nodes but those that start with "primals" to be releasable
+            device_filter: Function to determine which devices to track (default: non-CPU)
+        """
+
+        self.graph = graph
+        self.nodes = list(graph.nodes)
+        self.device_filter = device_filter or (lambda device: device.type != "cpu")
+        self.scheduled: OrderedSet[fx.Node] = OrderedSet()
+
+        # Memory tracking using GraphAliasTracker
+        self.alias_tracker = GraphAliasTracker(self.nodes)
+        self.current_live_storages: OrderedSet[StorageKey] = OrderedSet()
+        self.current_memory_bytes = 0
+        self.is_releasable = _is_releasable if is_releasable is None else is_releasable
+
+        # Initialize live storages with placeholders and get_attr nodes
+        for node in self.nodes:
+            if node.op in ("placeholder", "get_attr"):
+                fresh_allocations = self.alias_tracker.get_fresh_allocations(node)
+                for storage_key in fresh_allocations:
+                    if self.device_filter(storage_key.device):
+                        self.current_live_storages.add(storage_key)
+                        self.current_memory_bytes += self._get_storage_size(storage_key)
+
+        self.peak_memory = self.current_memory_bytes
+
+        log.debug(
+            "Memory tracker initialized with initial memory: %d MB",
+            self.current_memory_bytes // (1024 * 1024),
+        )
+
+    def schedule_node(self, node: fx.Node) -> None:
+        """
+        Schedule a node and update memory tracking for the new scheduling order.
+
+        Args:
+            node: The node being scheduled (potentially out of original order)
+        """
+        assert node not in self.scheduled, "should not schedule node twice"
+        self.scheduled.add(node)
+        self._update_memory_for_node(node)
+
+    def get_current_memory_bytes(self) -> int:
+        """Get current live memory in bytes under the current scheduling."""
+        return self.current_memory_bytes
+
+    def _get_storage_size(self, storage_key: StorageKey) -> int:
+        """Get the size of a storage in bytes, handling symbolic shapes."""
+        size_bytes = storage_key.storage.nbytes()
+        return hint_int(
+            size_bytes, fallback=torch._inductor.config.unbacked_symint_fallback
+        )
+
+    def _get_storages_freed_by_node(self, node: fx.Node) -> OrderedSet[StorageKey]:
+        """Get storages that would be freed if we schedule this node."""
+        freed_storages: OrderedSet[StorageKey] = OrderedSet()
+
+        input_storages = self.alias_tracker.get_storage_uses(node)
+        for storage_key in input_storages:
+            if not self.device_filter(storage_key.device):
+                continue
+
+            # Invariant: if a node uses a storage, it must be live
+            assert storage_key in self.current_live_storages, (
+                "all input storages should be currently allocated"
+            )
+
+            if not self.is_releasable(
+                self.alias_tracker.storage_to_allocator[storage_key]
+            ):
+                continue
+
+            all_uses = self.alias_tracker.storage_to_uses[storage_key]
+
+            # If no more unscheduled uses remain, the storage can be freed
+            if all(u in self.scheduled for u in all_uses):
+                freed_storages.add(storage_key)
+
+        return freed_storages
+
+    def _update_memory_for_node(self, node: fx.Node) -> None:
+        """Update memory tracking when a node is scheduled."""
+        if node.op in ("placeholder", "get_attr", "output"):
+            return
+
+        # Add fresh allocations
+        fresh_allocations = self.alias_tracker.get_fresh_allocations(node)
+        alloc_bytes = 0
+        for storage_key in fresh_allocations:
+            if (
+                self.device_filter(storage_key.device)
+                and storage_key not in self.current_live_storages
+            ):
+                size = self._get_storage_size(storage_key)
+                self.current_live_storages.add(storage_key)
+                self.current_memory_bytes += size
+                alloc_bytes += size
+
+        self.peak_memory = max(self.current_memory_bytes, self.peak_memory)
+
+        # Remove storages that are no longer used
+        storages_to_free = self._get_storages_freed_by_node(node)
+        freed_bytes = 0
+        for storage_key in storages_to_free:
+            if storage_key in self.current_live_storages:
+                size = self._get_storage_size(storage_key)
+                self.current_live_storages.remove(storage_key)
+                self.current_memory_bytes -= size
+                freed_bytes += size
+
+        log.debug(
+            "Scheduled %s: memory change %d allocs, %d frees, current memory: %d MB",
+            node.name,
+            len(fresh_allocations),
+            len(storages_to_free),
+            self.current_memory_bytes // (1024 * 1024),
+        )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/micro_pipeline_tp.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/micro_pipeline_tp.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cc5503d4815b6a37d1e0daa9c5ffaad4498539f
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/micro_pipeline_tp.py
@@ -0,0 +1,1114 @@
+# mypy: allow-untyped-defs
+import logging
+import operator
+from collections import defaultdict
+from dataclasses import dataclass, field
+from math import prod
+from typing import Any, cast
+
+import torch
+from torch.utils._ordered_set import OrderedSet
+
+from .. import config, inductor_prims
+from ..pattern_matcher import (
+    CallFunction,
+    Ignored,
+    KeywordArg,
+    ListOf,
+    Match,
+    MULTIPLE,
+    PatternExpr,
+    PatternMatcherPass,
+)
+
+
+log = logging.getLogger(__name__)
+aten = torch.ops.aten
+patterns = PatternMatcherPass()
+
+
+def _is_last_dim(t: torch.Tensor, dim: int) -> bool:
+    return dim == t.ndim - 1 or dim == -1
+
+
+def _is_backward(graph: torch.fx.Graph) -> bool:
+    placeholders = []
+    for node in graph.nodes:
+        if node.op != "placeholder":
+            break
+        placeholders.append(node)
+    return not all(node.name.startswith("primal") for node in placeholders)
+
+
+def _compute_mm_arithmetic_intensity(M: int, N: int, K: int) -> float:
+    return M * N * K / (M * K + N * K + M * N)
+
+
+def _filter_nodes_by_target(nodes: list[torch.fx.Node], target) -> list[torch.fx.Node]:
+    return [x for x in nodes if x.target == target]
+
+
+def _find_ancestors(node: torch.fx.Node) -> OrderedSet[torch.fx.Node]:
+    ancestors = OrderedSet[torch.fx.Node]()
+    ancestors.add(node)
+    cur_nodes = [node]
+    while len(cur_nodes) > 0:
+        new_nodes = []
+        for node in cur_nodes:
+            for inp in node.all_input_nodes:
+                if inp not in ancestors:
+                    ancestors.add(inp)
+                    new_nodes.append(inp)
+        cur_nodes = new_nodes
+    return OrderedSet(node for node in ancestors if node.op != "placeholder")
+
+
+def _get_tensor(node: torch.fx.Node) -> torch.Tensor:
+    val = node.meta["val"]
+    assert isinstance(val, torch.Tensor)
+    return val
+
+
+@dataclass
+class _AllGatherMatch:
+    match: Match
+    shard_node: torch.fx.Node
+    ag_node: torch.fx.Node
+    res_node: torch.fx.Node
+    gather_dim: int
+    group_name: "torch.distributed.distributed_c10d.GroupName"
+
+    def replace_with(self, new_node: torch.fx.Node) -> None:
+        self.res_node.replace_all_uses_with(new_node)
+
+    def erase(self) -> None:
+        for node in reversed(self.match.nodes):
+            if len(node.users) == 0:
+                node.graph.erase_node(node)
+
+
+def find_all_gather_patterns(graph: torch.fx.Graph):
+    c10d = torch.ops._c10d_functional
+
+    def make_zero_dim_all_gather_pattern(shard):
+        return CallFunction(
+            c10d.wait_tensor.default,
+            CallFunction(
+                c10d.all_gather_into_tensor.default,
+                shard,
+                Ignored(),
+                KeywordArg("group_name"),
+            ),
+        )
+
+    # Matches funcol.all_gather_tensor with gather_dim == 0
+    zero_dim_all_gather_pattern = make_zero_dim_all_gather_pattern(KeywordArg("shard"))
+
+    def make_all_gather_split_pattern(shard):
+        return CallFunction(
+            operator.getitem,
+            CallFunction(
+                aten.split.Tensor,
+                make_zero_dim_all_gather_pattern(shard),
+                Ignored(),
+                _users=MULTIPLE,
+            ),
+            Ignored(),
+        )
+
+    def make_cat_pattern(splits):
+        return CallFunction(
+            aten.cat.default,
+            ListOf(splits),
+            KeywordArg("gather_dim"),
+        )
+
+    # Matches funcol.all_gather_tensor with gather_dim > 0
+    non_zero_dim_all_gather_pattern = make_cat_pattern(
+        make_all_gather_split_pattern(KeywordArg("shard")),
+    )
+
+    # Match a zero-dim all-gather in which the data is transferred as uint8 and
+    # viewed back as the original dtype.
+    zero_dim_type_erased_all_gather_pattern = CallFunction(
+        aten.view.dtype,
+        make_zero_dim_all_gather_pattern(
+            KeywordArg("shard"),
+        ),
+        Ignored(),
+    )
+
+    # Match a non-zero dim all-gather in which the data is transferred as uint8
+    # and viewed back as the original dtype.
+    non_zero_dim_type_erased_all_gather_pattern = CallFunction(
+        aten.view.dtype,
+        make_cat_pattern(
+            CallFunction(
+                aten.view.dtype,
+                make_all_gather_split_pattern(
+                    KeywordArg("shard"),
+                ),
+                Ignored(),
+            ),
+        ),
+        Ignored(),
+    )
+
+    # If two patterns with the same res_node_target have the same suffix, the
+    # longer pattern should appear first in the list.
+    # e.g. supposed we have (1) A -> B -> C -> D and (2) B -> C -> D, (1)
+    # should appear before (2) in the list.
+    res_node_target_to_patterns = {
+        aten.cat.default: [
+            (non_zero_dim_all_gather_pattern, 0),
+        ],
+        aten.view.dtype: [
+            (non_zero_dim_type_erased_all_gather_pattern, 0),
+            (zero_dim_type_erased_all_gather_pattern, 0),
+        ],
+        c10d.wait_tensor.default: [
+            (zero_dim_all_gather_pattern, 0),
+        ],
+    }
+
+    # Match in reverse to ensure longer patterns is prioritized
+    all_gathers = []
+    visited_ag_nodes = OrderedSet[torch.fx.Node]()
+    for node in reversed(graph.nodes):
+        for target, patterns in res_node_target_to_patterns.items():
+            if node.target != target:
+                continue
+            for pattern, ag_node_idx in patterns:
+                match = pattern.match(node)
+                if not match:
+                    continue
+
+                assert isinstance(match, Match)
+                ag_node = match.nodes[ag_node_idx]
+                assert ag_node.target == c10d.all_gather_into_tensor.default
+
+                if ag_node in visited_ag_nodes:
+                    continue
+                visited_ag_nodes.add(ag_node)
+
+                ag_match = _AllGatherMatch(
+                    match=match,
+                    shard_node=match.kwargs["shard"],
+                    ag_node=ag_node,
+                    res_node=node,
+                    gather_dim=match.kwargs.get("gather_dim", 0),
+                    group_name=match.kwargs["group_name"],
+                )
+                all_gathers.append(ag_match)
+
+    return list(reversed(all_gathers))
+
+
+@dataclass
+class _ReduceScatterMatch:
+    match: Match
+    input_node: torch.fx.Node
+    reduce_scatter_node: torch.fx.Node
+    wait_tensor_node: torch.fx.Node
+    reduce_op: str
+    scatter_dim: int
+    group_name: "torch.distributed.distributed_c10d.GroupName"
+
+    def replace_with(self, new_node: torch.fx.Node) -> None:
+        # Replace all uses of the result node (wait_tensor) with the fused node.
+        self.wait_tensor_node.replace_all_uses_with(new_node)
+
+        # If the reduce-scatter result is saved for backward, save the fused node for backward instead.
+        self._update_save_for_backward(new_node)
+
+    def _update_save_for_backward(self, new_node: torch.fx.Node) -> None:
+        """
+        If the output node is a user of the reduce_scatter node (indicating the reduce_scatter
+        result is saved for backward), this method will update the output node to use the fused node instead.
+        """
+        output_node = None
+        for user in self.reduce_scatter_node.users:
+            if user.target == "output":
+                output_node = user
+                break
+        if output_node is not None:
+            output_node.replace_input_with(self.reduce_scatter_node, new_node)
+
+            # Assert that now the reduce scatter node has only one user (the wait_tensor) and it's not
+            # saved for backward anymore.
+            assert len(self.reduce_scatter_node.users) == 1, (
+                "Reduce scatter node has multiple users, this is not expected"
+            )
+
+    def erase(self) -> None:
+        for node in reversed(self.match.nodes):
+            if len(node.users) == 0:
+                node.graph.erase_node(node)
+
+
+def find_reduce_scatter_patterns(graph: torch.fx.Graph):
+    c10d = torch.ops._c10d_functional
+
+    def reduce_scatter_template(inp: PatternExpr, users: int):
+        return CallFunction(
+            c10d.wait_tensor.default,
+            CallFunction(
+                c10d.reduce_scatter_tensor.default,
+                inp,
+                KeywordArg("reduce_op"),
+                Ignored(),
+                KeywordArg("group_name"),
+                _users=users,
+            ),
+        )
+
+    # Matches funcol.reduce_scatter_tensor with scatter_dim == 0
+    zero_dim_reduce_scatter_pattern_single_user = reduce_scatter_template(
+        KeywordArg("input"), users=1
+    )
+
+    # Two users will occur when the reduce-scatter result is saved for backward
+    zero_dim_reduce_scatter_pattern_multi_user = reduce_scatter_template(
+        KeywordArg("input"), users=2
+    )
+
+    # Matches funcol.reduce_scatter_tensor with scatter_dim > 0
+    non_zero_dim_reduce_scatter_pattern_single_user = reduce_scatter_template(
+        CallFunction(
+            aten.cat.default,
+            ListOf(
+                CallFunction(
+                    operator.getitem,
+                    CallFunction(
+                        aten.split.Tensor,
+                        KeywordArg("input"),
+                        Ignored(),
+                        KeywordArg("scatter_dim"),
+                        _users=MULTIPLE,
+                    ),
+                    Ignored(),
+                )
+            ),
+        ),
+        users=1,
+    )
+
+    # Two users will occur when the reduce-scatter result is saved for backward
+    non_zero_dim_reduce_scatter_pattern_multi_user = reduce_scatter_template(
+        CallFunction(
+            aten.cat.default,
+            ListOf(
+                CallFunction(
+                    operator.getitem,
+                    CallFunction(
+                        aten.split.Tensor,
+                        KeywordArg("input"),
+                        Ignored(),
+                        KeywordArg("scatter_dim"),
+                        _users=MULTIPLE,
+                    ),
+                    Ignored(),
+                )
+            ),
+        ),
+        users=2,
+    )
+
+    reduce_scatters = []
+    for node in reversed(graph.nodes):
+        if node.target == c10d.wait_tensor.default:
+            if match := non_zero_dim_reduce_scatter_pattern_single_user.match(node):
+                assert isinstance(match, Match)
+                reduce_scatters.append(
+                    _ReduceScatterMatch(
+                        match=match,
+                        input_node=match.kwargs["input"],
+                        reduce_scatter_node=match.nodes[-2],
+                        wait_tensor_node=node,
+                        reduce_op=match.kwargs["reduce_op"],
+                        scatter_dim=match.kwargs["scatter_dim"],
+                        group_name=match.kwargs["group_name"],
+                    )
+                )
+            elif match := zero_dim_reduce_scatter_pattern_single_user.match(node):
+                assert isinstance(match, Match)
+                reduce_scatters.append(
+                    _ReduceScatterMatch(
+                        match=match,
+                        input_node=match.kwargs["input"],
+                        reduce_scatter_node=match.nodes[0],
+                        wait_tensor_node=node,
+                        reduce_op=match.kwargs["reduce_op"],
+                        scatter_dim=0,
+                        group_name=match.kwargs["group_name"],
+                    )
+                )
+            elif match := non_zero_dim_reduce_scatter_pattern_multi_user.match(node):
+                assert isinstance(match, Match)
+                reduce_scatters.append(
+                    _ReduceScatterMatch(
+                        match=match,
+                        input_node=match.kwargs["input"],
+                        reduce_scatter_node=match.nodes[-2],
+                        wait_tensor_node=node,
+                        reduce_op=match.kwargs["reduce_op"],
+                        scatter_dim=match.kwargs["scatter_dim"],
+                        group_name=match.kwargs["group_name"],
+                    )
+                )
+            elif match := zero_dim_reduce_scatter_pattern_multi_user.match(node):
+                assert isinstance(match, Match)
+                reduce_scatters.append(
+                    _ReduceScatterMatch(
+                        match=match,
+                        input_node=match.kwargs["input"],
+                        reduce_scatter_node=match.nodes[0],
+                        wait_tensor_node=node,
+                        reduce_op=match.kwargs["reduce_op"],
+                        scatter_dim=0,
+                        group_name=match.kwargs["group_name"],
+                    )
+                )
+    return list(reversed(reduce_scatters))
+
+
+@dataclass
+class _Matmul:
+    nodes: list[torch.fx.Node]
+    arg_ancestor_nodes: OrderedSet[torch.fx.Node] = field(init=False)
+    A_node: torch.fx.Node
+    B_node: torch.fx.Node
+    pre_mm_reshape: torch.fx.Node | None
+    post_mm_reshape: torch.fx.Node | None
+
+    def __post_init__(self):
+        assert len(self.nodes) in (1, 3)
+        if len(self.nodes) == 1:
+            assert self.nodes[0].target in (aten.mm.default, aten._scaled_mm.default)
+        else:
+            assert self.nodes[0].target is aten.reshape.default
+            assert self.nodes[1].target in (aten.mm.default, aten._scaled_mm.default)
+            assert self.nodes[2].target is aten.reshape.default
+        self.arg_ancestor_nodes = _find_ancestors(self.B_node)
+
+    def replace_with(self, new_node: torch.fx.Node) -> None:
+        """
+        Replace the matmul with the new node.
+        """
+        graph = new_node.graph
+
+        # For 2D-matmuls, we simply replace the mm node with `new_node`.
+        if len(self.nodes) == 1:
+            mm_node = self.nodes[0]
+            assert mm_node.target in (aten.mm.default, aten._scaled_mm.default)
+            mm_node.replace_all_uses_with(new_node)
+            graph.erase_node(mm_node)
+            return
+
+        # An ND-matmul is reshape -> mm -> reshape sequence. We first replace
+        # the second reshape node with `new_node`. Then, we ensure that the
+        # original mm node in the sequence ends up with zero users by replacing
+        # it with a reverse reshape of `new_node`.
+        graph = new_node.graph
+        assert len(self.nodes) == 3
+        mm_node = self.nodes[1]
+        output_reshape_node = self.nodes[2]
+
+        assert mm_node.target in (aten.mm.default, aten._scaled_mm.default)
+        assert output_reshape_node.target is aten.reshape.default
+
+        output_reshape_node.replace_all_uses_with(new_node)
+        if len(mm_node.users) > 1:
+            with graph.inserting_after(new_node):
+                new_mm_node = graph.call_function(
+                    aten.reshape.default,
+                    args=(new_node, list(_get_tensor(mm_node).shape)),
+                )
+            mm_node.replace_all_uses_with(new_mm_node)
+
+    def erase(self) -> None:
+        for node in reversed(self.nodes):
+            if len(node.users) == 0:
+                node.graph.erase_node(node)
+
+    @classmethod
+    def from_match(cls, match: list[torch.fx.Node]) -> "_Matmul":
+        assert len(match) in (1, 3)
+        assert match[0].target in (
+            aten.mm.default,
+            aten.reshape.default,
+        )
+        mm_node = match[0] if len(match) == 1 else match[1]
+        return _Matmul(
+            nodes=match,
+            A_node=cast("torch.fx.Node", match[0].args[0]),
+            B_node=cast("torch.fx.Node", mm_node.args[1]),
+            # _Matmul handles reshapes via custom graph manipulation logic, see `replace_with()` method.
+            # TODO: explore unifying the _Matmul and _ScaledMatmul approaches to handling reshapes.
+            pre_mm_reshape=None,
+            post_mm_reshape=None,
+        )
+
+
+@dataclass
+class _ScaledMatmul(_Matmul):
+    A_scale_node: torch.fx.Node
+    B_scale_node: torch.fx.Node
+    bias_node: torch.fx.Node | None
+    result_scale_node: torch.fx.Node | None
+    out_dtype: torch.dtype | None
+    use_fast_accum: bool
+    pre_mm_reshape: torch.fx.Node | None
+    post_mm_reshape: torch.fx.Node | None
+
+    def __post_init__(self):
+        super().__post_init__()
+        self.arg_ancestor_nodes |= _find_ancestors(self.A_scale_node)
+        self.arg_ancestor_nodes |= _find_ancestors(self.B_scale_node)
+
+    @classmethod
+    def from_match(cls, match: list[torch.fx.Node]) -> "_ScaledMatmul":
+        assert len(match) in (1, 3)
+        assert match[0].target in (
+            aten._scaled_mm.default,
+            aten.reshape.default,
+        )
+
+        def get_arg(node: torch.fx.Node, idx: int, default: Any) -> Any:
+            if idx >= len(node.args):
+                return default
+            return node.args[idx]
+
+        # Use mm_node with 2D args for both A and B, even if this is a "reshape -> mm -> reshape" pattern.
+        # We will store the reshapes in pre_mm_reshape and post_mm_reshape, to be referenced later to
+        # produce the correct output shapes, reduce-scatter along the correct dimensions, etc.
+        is_reshape_mm_reshape_pattern = match[0].target is aten.reshape.default
+        mm_node = match[1] if is_reshape_mm_reshape_pattern else match[0]
+        pre_mm_reshape = match[0] if is_reshape_mm_reshape_pattern else None
+        post_mm_reshape = match[-1] if is_reshape_mm_reshape_pattern else None
+        A_node = cast("torch.fx.Node", mm_node.args[0])
+        B_node = cast("torch.fx.Node", mm_node.args[1])
+        A_scale_node = cast("torch.fx.Node", mm_node.args[2])
+        B_scale_node = cast("torch.fx.Node", mm_node.args[3])
+
+        return _ScaledMatmul(
+            nodes=match,
+            A_node=A_node,
+            B_node=B_node,
+            A_scale_node=A_scale_node,
+            B_scale_node=B_scale_node,
+            bias_node=get_arg(mm_node, 4, None),
+            result_scale_node=get_arg(mm_node, 5, None),
+            out_dtype=get_arg(mm_node, 6, None),
+            use_fast_accum=get_arg(mm_node, 7, False),
+            pre_mm_reshape=pre_mm_reshape,
+            post_mm_reshape=post_mm_reshape,
+        )
+
+
+def _find_reshape_mm_reshape(node: torch.fx.Node) -> list[_Matmul]:
+    if node.target != aten.reshape.default:
+        return []
+
+    matches = []
+    for mm_node in node.users:
+        if mm_node.target not in (aten.mm.default, aten._scaled_mm.default):
+            continue
+        for reshape_node in mm_node.users:
+            if reshape_node.target != aten.reshape.default:
+                continue
+
+            # Since the reshape -> mm -> reshape pattern would be subsumed into
+            # the fused op, we only match the patterns where the shape of the
+            # second reshape is matches the mm result produced by the fused op.
+            matmul_input_node = cast("torch.fx.Node", node.args[0])
+            B_node = cast("torch.fx.Node", mm_node.args[1])
+            matmul_out_shape = torch.Size(
+                [
+                    *_get_tensor(matmul_input_node).shape[:-1],
+                    _get_tensor(B_node).shape[-1],
+                ]
+            )
+            if _get_tensor(reshape_node).shape != matmul_out_shape:
+                continue
+            matches.append([node, mm_node, reshape_node])
+            # If for some rare reason mm_node is being reshaped by two
+            # different reshape nodes, we only include mm_node once in the
+            # parsing result.
+            break
+
+    matmuls = []
+    for match in matches:
+        mm_node = match[1]
+        if mm_node.target is aten.mm.default:
+            matmul = _Matmul.from_match(match)
+            matmuls.append(matmul)
+        elif mm_node.target is aten._scaled_mm.default:
+            matmul = _ScaledMatmul.from_match(match)
+            matmuls.append(matmul)
+        else:
+            raise AssertionError(
+                "Expect the node's target to be either aten.mm.default or "
+                f"aten._scaled_mm.default. Got {mm_node.target}."
+            )
+    return matmuls
+
+
+def _find_consumer_matmuls(node: torch.fx.Node) -> list[_Matmul]:
+    """
+    Find the matmuls that use `node` as the lhs argument.
+    """
+    matmuls = []
+    for user in node.users:
+        # ND matmuls
+        if user.target is aten.reshape.default:
+            matmuls.extend(_find_reshape_mm_reshape(user))
+        # 2D matmuls
+        elif user.target is aten.mm.default:
+            matmul = _Matmul.from_match(match=[user])
+            matmuls.append(matmul)
+        elif user.target is aten._scaled_mm.default:
+            matmul = _ScaledMatmul.from_match([user])
+            matmuls.append(matmul)
+    return matmuls
+
+
+def _insert_fused_all_gather_matmul(
+    graph: torch.fx.Graph,
+    matmuls: list[_Matmul],
+    shard_node: torch.fx.Node,
+    gather_dim: int,
+    group_name: "torch.distributed.distributed_c10d.GroupName",
+) -> torch.fx.Node:
+    mm_types = OrderedSet(map(type, matmuls))
+    assert len(mm_types) == 1
+    mm_type = next(iter(mm_types))
+    if mm_type == _Matmul:
+        B_nodes = [matmul.B_node for matmul in matmuls]
+        return graph.call_function(
+            torch.ops.symm_mem.fused_all_gather_matmul.default,
+            args=(shard_node, B_nodes, gather_dim, group_name),
+            kwargs={"return_A": True},
+        )
+    elif mm_type == _ScaledMatmul:
+        scaled_matmuls = cast("list[_ScaledMatmul]", matmuls)
+        return graph.call_function(
+            torch.ops.symm_mem.fused_all_gather_scaled_matmul.default,
+            args=(
+                shard_node,
+                [matmul.B_node for matmul in scaled_matmuls],
+                scaled_matmuls[0].A_scale_node,
+                [matmul.B_scale_node for matmul in scaled_matmuls],
+                gather_dim,
+                group_name,
+                [matmul.bias_node for matmul in scaled_matmuls],
+                [matmul.result_scale_node for matmul in scaled_matmuls],
+                [matmul.out_dtype for matmul in scaled_matmuls],
+                [matmul.use_fast_accum for matmul in scaled_matmuls],
+            ),
+        )
+    else:
+        raise AssertionError(f"Unexpected matmul match type: {mm_type}")
+
+
+def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None:
+    """
+    Fused the pattern
+
+        A = all_gather_tensor(A_shard, gather_dim, group_name)
+        C_0 = torch.matmul(A, B_0)
+        C_1 = torch.matmul(A, B_1)
+        C_2 = torch.matmul(A, B_2)
+        ...
+
+    into
+
+        A, Cs = torch.ops.symm_mem.fused_all_gather_matmul(
+            A_shard, [B_0, B_1, B_2, ...], gather_dim, group_name,
+        )
+    """
+    if (
+        not torch.distributed.is_available()
+        or not torch.distributed.is_nccl_available()
+    ):
+        return
+
+    from torch.distributed._symmetric_memory import (
+        is_symm_mem_enabled_for_group,
+        restride_A_shard_for_fused_all_gather_matmul,
+    )
+
+    shard_node, ag_node, ag_res_node, gather_dim, group_name = (
+        all_gather.shard_node,
+        all_gather.ag_node,
+        all_gather.res_node,
+        all_gather.gather_dim,
+        all_gather.group_name,
+    )
+
+    if not is_symm_mem_enabled_for_group(group_name):
+        return
+
+    filter_matmul = None
+    if _is_last_dim(_get_tensor(shard_node), gather_dim):
+        # Decomposed mms should not be too small
+        if _get_tensor(shard_node).shape[-1] < 1024:
+            return
+
+        # scaled_mm is not supported yet for last dim
+        def _filter_out_scaled_matmul(matmul: _Matmul):
+            return not isinstance(matmul, _ScaledMatmul)
+
+        filter_matmul = _filter_out_scaled_matmul
+
+    # Find consumer matmuls
+    matmuls = _find_consumer_matmuls(ag_res_node)
+
+    # The matmuls are only fusible if non-A args don't depend on the all-gather
+    # result node
+    matmuls = [
+        matmul
+        for matmul in matmuls
+        if all_gather.res_node not in matmul.arg_ancestor_nodes
+    ]
+
+    if len(matmuls) == 0 or len(OrderedSet(map(type, matmuls))) != 1:
+        return
+
+    if _is_last_dim(_get_tensor(shard_node), gather_dim) and len(
+        all_gather.res_node.users
+    ) > len(matmuls):
+        # The result of ag-split-cat is used not only in matmuls.
+        # Then it has to be materialized, which can have overhead.
+        return
+
+    if filter_matmul and not filter_matmul(matmuls[0]):
+        return
+
+    # Fuse the all_gather_tensor with the eligible matmuls
+    graph = ag_node.graph
+    with graph.inserting_before(ag_node):
+        if not _is_last_dim(_get_tensor(shard_node), gather_dim):
+            if "val" in shard_node.meta:
+                restrided = restride_A_shard_for_fused_all_gather_matmul(
+                    _get_tensor(shard_node),
+                    gather_dim,
+                )
+                shard_node = graph.call_function(
+                    inductor_prims.force_stride_order,
+                    args=(shard_node, restrided.stride()),
+                )
+
+        fused_node = _insert_fused_all_gather_matmul(
+            graph, matmuls, shard_node, gather_dim, group_name
+        )
+        new_ag_node = graph.call_function(
+            operator.getitem,
+            args=(fused_node, 0),
+        )
+        new_out_nodes = graph.call_function(
+            operator.getitem,
+            args=(fused_node, 1),
+        )
+        for idx, matmul in enumerate(matmuls):
+            new_out_node = graph.call_function(
+                operator.getitem,
+                args=(new_out_nodes, idx),
+            )
+            matmul.replace_with(new_out_node)
+            matmul.erase()
+        all_gather.replace_with(new_ag_node)
+        all_gather.erase()
+
+        # If the new_ag_node has no users, we tell the fused op to not return
+        # it. This creates more optimization opportunities.
+        if len(new_ag_node.users) == 0:
+            graph.erase_node(new_ag_node)
+            kwargs = dict(fused_node.kwargs)
+            if "return_A" in kwargs:
+                kwargs["return_A"] = False
+                fused_node.kwargs = kwargs
+
+    # Raise ancestors of non-A args that are topologically ordered between
+    # ag_res_node and the matmul above fused_node.
+    order = {node: idx for idx, node in enumerate(graph.nodes)}
+    nodes_to_raise = sorted(
+        OrderedSet(x for matmul in matmuls for x in matmul.arg_ancestor_nodes),
+        key=lambda x: order[x],
+    )
+    for node in nodes_to_raise:
+        if order[node] > order[fused_node]:
+            fused_node.prepend(node)
+
+
+def _scatter_dim_after_reshape(
+    reshape_node: torch.fx.Node, orig_scatter_dim: int
+) -> int:
+    """
+    Given a reshape node and the original scatter dim for the target tensor,
+    returns the new scatter dim for the reshaped tensor.
+    """
+    # if there was no pre-mm reshape, scatter dim will not change.
+    if not reshape_node:
+        return orig_scatter_dim
+
+    reshape_op_output_tensor = _get_tensor(reshape_node)
+    assert reshape_op_output_tensor.ndim == 2, (
+        "reshape must produce 2D tensor for scaled_mm"
+    )
+
+    assert len(reshape_node.args) >= 1, "reshape node must have at least 1 arg"
+    input_tensor_node = cast(torch.fx.Node, reshape_node.args[0])
+    reshape_op_input_tensor = _get_tensor(input_tensor_node)
+    assert reshape_op_input_tensor.ndim > reshape_op_output_tensor.ndim, (
+        "reshape must be from 3D+ to 2D"
+    )
+
+    # Note: for a N-D tensor to be reshaped into 2D, either the leading dims or ending dims must
+    # be collapsed to a single dim. First determine which of these happened.
+    input_shape = reshape_op_input_tensor.shape
+    output_shape = reshape_op_output_tensor.shape
+    leading_dims_collapsed = output_shape[0] == prod(input_shape[:-1])
+
+    # Case 1: scatter dim 0 always maps to 0 after any reshape from 3D+ to 2D, regardless if
+    # leading dims or ending dims were collapsed.
+    if orig_scatter_dim == 0:
+        return 0
+
+    # Case 2: scatter dim "ndim-1" always maps to 1 after any reshape from 3D+ to 2D, regardless if
+    # leading dims or ending dims were collapsed.
+    if orig_scatter_dim == reshape_op_input_tensor.ndim - 1:
+        return 1
+
+    # Case 3: scatter dim was one of the middle dims (between 0 and ndim-1).
+    # if the leading dims were collapsed, the new scatter dim will be 0.
+    # if the ending dims were collapsed, the new scatter dim will be 1.
+    return 0 if leading_dims_collapsed else 1
+
+
+def _find_producer_matmul(node: torch.fx.Node) -> _Matmul | None:
+    """
+    Returns producer matmul node if found, otherwise returns None.
+    """
+    if node.target is aten.mm.default:
+        return _Matmul.from_match(match=[node])
+    elif node.target is aten._scaled_mm.default:
+        return _ScaledMatmul.from_match(match=[node])
+    elif node.target is aten.reshape.default:
+        reshape_node_1 = node
+
+        mm_node = reshape_node_1.args[0]
+        assert isinstance(mm_node, torch.fx.Node)
+        if mm_node.target not in (aten.mm.default, aten._scaled_mm.default):
+            return None
+
+        reshape_node_0 = mm_node.args[0]
+        assert isinstance(reshape_node_0, torch.fx.Node)
+        if reshape_node_0.target != aten.reshape.default:
+            return None
+
+        if mm_node.target is aten.mm.default:
+            return _Matmul.from_match(match=[reshape_node_0, mm_node, reshape_node_1])
+        elif mm_node.target is aten._scaled_mm.default:
+            return _ScaledMatmul.from_match(
+                match=[reshape_node_0, mm_node, reshape_node_1]
+            )
+    return None
+
+
+def _insert_fused_matmul_reduce_scatter(
+    graph: torch.fx.Graph,
+    matmul: _Matmul,
+    reduce_op: str,
+    orig_scatter_dim: int,
+    group_name: "torch.distributed.distributed_c10d.GroupName",
+    scatter_dim_after_reshape: int,  # only used for reshape -> scaled_mm -> reshape pattern
+    output_shape: list[int],  # only used for reshape -> scaled_mm -> reshape pattern
+) -> torch.fx.Node:
+    if type(matmul) is _Matmul:
+        return graph.call_function(
+            torch.ops.symm_mem.fused_matmul_reduce_scatter.default,
+            args=(
+                matmul.A_node,
+                matmul.B_node,
+                reduce_op,
+                orig_scatter_dim,
+                group_name,
+            ),
+        )
+    elif type(matmul) is _ScaledMatmul:
+        return graph.call_function(
+            torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default,
+            args=(
+                matmul.A_node,
+                matmul.B_node,
+                matmul.A_scale_node,
+                matmul.B_scale_node,
+                reduce_op,
+                orig_scatter_dim,
+                scatter_dim_after_reshape,
+                group_name,
+                output_shape,
+                matmul.bias_node,
+                matmul.result_scale_node,
+                matmul.out_dtype,
+                matmul.use_fast_accum,
+            ),
+        )
+    else:
+        raise AssertionError(f"Unexpected matmul match type: {type(matmul)}")
+
+
+def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
+    """
+    Fused the pattern
+
+        reduce_scatter_tensor(A @ B, scatter_dim, group_name)
+
+    into
+
+        torch.ops.symm_mem.fused_matmul_reduce_scatter(
+            A, B, scatter_dim, group_name,
+        )
+
+    Returns boolean indicating if fusion was successful or not.
+    """
+    if (
+        not torch.distributed.is_available()
+        or not torch.distributed.is_nccl_available()
+    ):
+        return
+
+    from torch.distributed._symmetric_memory import (
+        is_symm_mem_enabled_for_group,
+        restride_A_for_fused_matmul_reduce_scatter,
+    )
+
+    (
+        input_node,
+        _reduce_scatter_node,
+        rs_wait_tensor_node,
+        reduce_op,
+        orig_scatter_dim,
+        group_name,
+    ) = (
+        reduce_scatter.input_node,
+        reduce_scatter.reduce_scatter_node,
+        reduce_scatter.wait_tensor_node,
+        reduce_scatter.reduce_op,
+        reduce_scatter.scatter_dim,
+        reduce_scatter.group_name,
+    )
+
+    if not is_symm_mem_enabled_for_group(group_name):
+        return
+
+    filter_matmul = None
+    if _is_last_dim(_get_tensor(input_node), orig_scatter_dim):
+        # scaled_mm is not supported yet for last dim mm+rs
+        def _filter_out_scaled_matmul(matmul: _Matmul):
+            return not isinstance(matmul, _ScaledMatmul)
+
+        filter_matmul = _filter_out_scaled_matmul
+
+    # Currently fused_matmul_reduce_scatter doesn't return the matmul result,
+    # so we can't apply the fusion if the matmul result is used by multiple
+    # users. This is not a fundamental limitation of the fused op and can be
+    # addressed if needed.
+    if len(input_node.users) != 1:
+        log.warning(
+            "matmul result has more than one user, skipping fused_matmul_reduce_scatter fusion."
+        )
+        return
+
+    matmul = _find_producer_matmul(input_node)
+
+    if matmul is None:
+        log.warning(
+            "no producer matmul found for reduce scatter, skipping fuse_matmul_reduce_scatter fusion"
+        )
+        return
+
+    if filter_matmul and not filter_matmul(matmul):
+        return
+
+    if rs_wait_tensor_node in matmul.arg_ancestor_nodes:
+        log.warning(
+            "reduce-scatter result node is an ancestor of matmul, skipping fuse_matmul_reduce_scatter fusion"
+        )
+        return
+
+    # We need to track 3 values for the fused scaled mm reduce scatter implementation:
+    #   1. The scatter dim before the reshape, which was assigned using the original (a,b,c) @ (c,d) = (a,b,d) dims.
+    #   2. The scatter dim after the reshape, to use when we are doing the 2D (a*b,c) @ (c,d) = (a,b,d) scaled mm op.
+    #   3. Store expected potentially 3D+ mm output shape, so we can reshape the 2D mm output to the intended
+    #      3D+ shape before applying reduce-scatter, and to prevent shape errors with subsequent ops.
+
+    # If 'A' was reshaped from 3D+ -> 2D for the mm, we need to determine the new scattter dim after the reshape
+    # for the fused matmul reduce scatter implementation to use.
+    if matmul.pre_mm_reshape:
+        scatter_dim_after_maybe_reshape = _scatter_dim_after_reshape(
+            matmul.pre_mm_reshape, orig_scatter_dim
+        )
+    else:
+        scatter_dim_after_maybe_reshape = orig_scatter_dim
+
+    # If the 2D mm output was reshaped from 2D -> 3D+, we need to store the intended output shape for the
+    # fused matmul reduce scatter implementation to use.
+    if matmul.post_mm_reshape:
+        output_shape = list(_get_tensor(matmul.post_mm_reshape).shape)
+    else:
+        A_orig_shape = list(_get_tensor(matmul.A_node).shape)
+        B_shape = list(_get_tensor(matmul.B_node).shape)
+        output_shape = [*A_orig_shape[:-1], B_shape[-1]]
+
+    graph = rs_wait_tensor_node.graph
+    with graph.inserting_before(rs_wait_tensor_node):
+        # Restride A tensor before fused op, for optimal perf in fused matmul reduce scatter
+        if "val" in matmul.A_node.meta:
+            restrided = restride_A_for_fused_matmul_reduce_scatter(
+                _get_tensor(matmul.A_node),
+                scatter_dim_after_maybe_reshape,
+            )
+            matmul.A_node = graph.call_function(
+                inductor_prims.force_stride_order,
+                args=(matmul.A_node, restrided.stride()),
+            )
+
+        # Replace matched subgraph with fused matmul reduce scatter node
+        fused_node = _insert_fused_matmul_reduce_scatter(
+            graph,
+            matmul,
+            reduce_op,
+            orig_scatter_dim,
+            group_name,
+            scatter_dim_after_maybe_reshape,
+            output_shape,
+        )
+        reduce_scatter.replace_with(fused_node)
+        reduce_scatter.erase()
+        matmul.erase()
+
+    order = {node: idx for idx, node in enumerate(graph.nodes)}
+    nodes_to_raise = sorted(
+        matmul.arg_ancestor_nodes,
+        key=lambda x: order[x],
+    )
+    for node in nodes_to_raise:
+        if order[node] > order[fused_node]:
+            fused_node.prepend(node)
+
+    log.debug("successfully fused matmul reduce scatter")
+
+
+def _get_node_to_ancestors(
+    graph: torch.fx.Graph,
+) -> dict[torch.fx.Node, OrderedSet[torch.fx.Node]]:
+    """
+    Compute the ancestors for all nodes in a graph.
+    """
+    node_to_ancestors = defaultdict(OrderedSet[torch.fx.Node])  # type: ignore[var-annotated]
+    for node in graph.nodes:
+        node_to_ancestors[node] = OrderedSet(node.all_input_nodes)
+        for dep in node.all_input_nodes:
+            node_to_ancestors[node] |= node_to_ancestors[dep]
+
+    return node_to_ancestors
+
+
+def _get_collective_to_overlappable_nodes(
+    graph: torch.fx.Graph,
+) -> dict[torch.fx.Node, list[torch.fx.Node]]:
+    """
+    For each collective in the graph, find nodes that are neither ancestors nor
+    descendants of the collective.
+    """
+
+    def is_collective(node) -> bool:
+        # Only consider all-gather and reduce-scatter in the context of
+        # micro-pipeline TP.
+        return node.target in [
+            torch.ops._c10d_functional.all_gather_into_tensor.default,
+            torch.ops._c10d_functional.reduce_scatter_tensor.default,
+        ]
+
+    node_to_ancestors = _get_node_to_ancestors(graph)
+    collective_to_overlappable_nodes = defaultdict(list)
+    for node in graph.nodes:
+        if not is_collective(node):
+            continue
+        for x in graph.nodes:
+            if (
+                node not in node_to_ancestors[x]
+                and x not in node_to_ancestors[node]
+                and x.op == "call_function"
+            ):
+                collective_to_overlappable_nodes[node].append(x)
+
+    return collective_to_overlappable_nodes
+
+
+def _get_unexposed_collectives(graph: torch.fx.Graph) -> list[torch.fx.Node]:
+    """
+    Find all unexposed collectives in the graph.
+
+    Because we don't have the runtime estimate, this function is a rough
+    estimation using the following strong/hand-wavy assumptions:
+
+    - Only a predefined set of "compute intensive" operation can hide a collective.
+    - Any "compute intensive" operation can hide exactly one collective.
+    """
+
+    def _is_compute_intensive(node: torch.fx.Node) -> bool:
+        return node.target is torch.ops.aten.mm.default
+
+    collective_to_overlapping_candidates = defaultdict(list)
+    available_nodes = OrderedSet[torch.fx.Node]()
+    collective_to_overlappable_nodes = _get_collective_to_overlappable_nodes(graph)
+    for collective, overlappable_nodes in collective_to_overlappable_nodes.items():
+        candidates = [x for x in overlappable_nodes if _is_compute_intensive(x)]
+        collective_to_overlapping_candidates[collective] = candidates
+        available_nodes.update(candidates)
+
+    unexposed_collectives = []
+    for (
+        collective,
+        overlapping_candidates,
+    ) in collective_to_overlapping_candidates.items():
+        # Each collective consumes exactly one overlapping candidate
+        for x in overlapping_candidates:
+            if x in available_nodes:
+                unexposed_collectives.append(collective)
+                available_nodes.remove(x)
+                break
+    return unexposed_collectives
+
+
+def micro_pipeline_tp_pass(graph: torch.fx.Graph):
+    all_gathers = find_all_gather_patterns(graph)
+    reduce_scatters = find_reduce_scatter_patterns(graph)
+
+    # When a collective can be hidden through either simple overlapping or
+    # micro-pipeline TP, we prefer simple overlapping to avoid the overhead
+    # associated with decomposition. If reorder_for_compute_comm_overlap is
+    # enabled, we identify collectives that can be hidden through simple
+    # overlapping and exclude them from micro-pipeline TP candidates.
+    if config.reorder_for_compute_comm_overlap:
+        unexposed_collectives = _get_unexposed_collectives(graph)
+        all_gathers = [x for x in all_gathers if x.ag_node not in unexposed_collectives]
+        reduce_scatters = [
+            x
+            for x in reduce_scatters
+            if x.reduce_scatter_node not in unexposed_collectives
+        ]
+
+    if not all_gathers and not reduce_scatters:
+        log.warning(
+            "async TP found no matching all-gather/reduce-scatter patterns for fusion"
+        )
+
+    for all_gather in all_gathers:
+        fuse_all_gather_matmul(all_gather)
+
+    for reduce_scatter in reduce_scatters:
+        fuse_matmul_reduce_scatter(reduce_scatter)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/misc_patterns.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/misc_patterns.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff0981e72e8b2f1e4f4d618c7bcc4dc0afa970c5
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/misc_patterns.py
@@ -0,0 +1,139 @@
+# mypy: allow-untyped-defs
+import functools
+
+import torch
+from torch._dynamo.utils import counters
+from torch._ops import OpOverload, OpOverloadPacket
+from torch.utils._ordered_set import OrderedSet
+
+from ..pattern_matcher import fwd_only, register_replacement
+
+
+aten = torch.ops.aten
+
+
+@functools.cache
+def _misc_patterns_init():
+    from .joint_graph import patterns as joint_graph_patterns
+    from .post_grad import pass_patterns as post_grad_patterns_all
+
+    post_grad_patterns = post_grad_patterns_all[1]  # medium priority
+
+    if torch.cuda.is_available():
+        # workaround https://github.com/pytorch/pytorch/issues/97894
+        device = "cuda"
+    else:
+        device = "cpu"
+
+    # These patterns do 2 things
+    # 1. Since we know that index is completely unique, we can codegen it using
+    # stores instead of atomic adds, which is quite a bit faster.
+    # 2. Also, since we are guaranteed that they are completely within bounds,
+    # we can use unsafe indexing and skip debug asserts
+    def randperm_index_add_pattern(x, y):
+        index = torch.randperm(x.shape[0], device=x.device)[: y.shape[0]]
+        return torch.index_add(x, dim=0, source=y, index=index), index
+
+    def randperm_index_add_replacement(x, y):
+        index = torch.randperm(x.shape[0], device=x.device)[: y.shape[0]]
+        return (
+            torch.ops.aten._unsafe_index_put(
+                x, (index,), aten._unsafe_index(x, (index,)) + y, accumulate=False
+            ),
+            index,
+        )
+
+    register_replacement(
+        # pyrefly: ignore [bad-argument-type]
+        randperm_index_add_pattern,
+        # pyrefly: ignore [bad-argument-type]
+        randperm_index_add_replacement,
+        [torch.empty(4, 8, device=device), torch.empty(2, 8, device=device)],
+        # pyrefly: ignore [bad-argument-type]
+        fwd_only,
+        # pyrefly: ignore [bad-argument-type]
+        [post_grad_patterns, joint_graph_patterns],
+    )
+
+    def randperm_index_pattern(x, slice_shape):
+        index = torch.randperm(x.shape[0], device=x.device)[:slice_shape]
+        return torch.ops.aten.index(x, (index,)), index
+
+    def randperm_index_replacement(x, slice_shape):
+        index = torch.randperm(x.shape[0], device=x.device)[:slice_shape]
+        return torch.ops.aten._unsafe_index(x, (index,)), index
+
+    register_replacement(
+        # pyrefly: ignore [bad-argument-type]
+        randperm_index_pattern,
+        # pyrefly: ignore [bad-argument-type]
+        randperm_index_replacement,
+        [torch.empty(4, 8, device=device)],
+        # pyrefly: ignore [bad-argument-type]
+        fwd_only,
+        # pyrefly: ignore [bad-argument-type]
+        [post_grad_patterns, joint_graph_patterns],
+        scalar_workaround={"slice_shape": 42},
+    )
+
+
+class NumpyCompatNormalization:
+    numpy_compat: dict[str, tuple[str, ...]] = {
+        "dim": ("axis",),
+        "keepdim": ("keepdims",),
+        "input": ("x", "a", "x1"),
+        "other": ("x2",),
+    }
+    inverse_mapping: dict[str, str]
+    cache: dict["torch.fx.graph.Target", OrderedSet[str]]
+
+    def __init__(self) -> None:
+        self.cache = {}  # callable -> tuple of replaceable args e.g. ["axis"]
+        self.inverse_mapping = {}
+        for actual_kwarg, numpy_kwargs in self.numpy_compat.items():
+            for numpy_kwarg in numpy_kwargs:
+                assert numpy_kwarg not in self.inverse_mapping
+                self.inverse_mapping[numpy_kwarg] = actual_kwarg
+
+    def __call__(self, graph: torch.fx.Graph):
+        for node in graph.nodes:
+            if node.op != "call_function":
+                continue
+            if isinstance(node.target, (OpOverload, OpOverloadPacket)):
+                # only applies to torch ops; e.g. torch.stack(axis=1) works, torch.ops.aten.stack(axis=1) doesn't.
+                continue
+            kwargs = node.kwargs
+
+            if node.target in self.cache:
+                replaceable_kwargs = self.cache[node.target]
+            else:
+                signatures = torch.fx.operator_schemas.get_signature_for_torch_op(
+                    node.target
+                )
+                signatures = () if signatures is None else signatures
+                replaceable_kwargs = OrderedSet()
+                for sig in signatures:
+                    for param_name in sig.parameters:
+                        if param_name in self.numpy_compat:
+                            replaceable_kwargs.update(self.numpy_compat[param_name])
+
+                self.cache[node.target] = replaceable_kwargs
+
+            if not replaceable_kwargs:
+                continue
+
+            new_kwargs = {}
+            kwargs_changed = False
+            for k, v in kwargs.items():
+                if k in replaceable_kwargs:
+                    kwargs_changed = True
+                    new_kwargs[self.inverse_mapping[k]] = v
+                else:
+                    new_kwargs[k] = v
+
+            if kwargs_changed:
+                node.kwargs = torch.fx.immutable_collections.immutable_dict(new_kwargs)
+                counters["inductor"]["numpy_compat_normalization"] += 1
+
+
+numpy_compat_normalization = NumpyCompatNormalization()
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/mkldnn_fusion.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/mkldnn_fusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f729596cbb1f180d377a8e895b3a2fe12c8e1be
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/mkldnn_fusion.py
@@ -0,0 +1,1585 @@
+# mypy: allow-untyped-defs
+import functools
+import operator
+from functools import reduce
+from typing import Any, TYPE_CHECKING
+
+import torch
+from torch._dynamo.utils import counters
+from torch.fx.experimental.symbolic_shapes import has_free_symbols
+from torch.utils._ordered_set import OrderedSet
+
+from .. import ir, mkldnn_ir
+from ..lowering import lowerings as L
+from ..pattern_matcher import (
+    Arg,
+    CallFunction,
+    filter_nodes,
+    get_arg_value,
+    KeywordArg,
+    MULTIPLE,
+)
+from ..utils import (
+    is_mkldnn_bf16_supported,
+    is_mkldnn_fp16_supported,
+    SUPPORTED_MKLDNN_DEVICES,
+)
+from ..virtualized import ops, V
+from .freezing_patterns import register_freezing_graph_pattern
+from .post_grad import register_lowering_pattern
+from .quantization import (
+    _register_int8_woq_concat_linear_pattern,
+    _register_quantization_lowerings,
+    _register_quantization_weight_pack_pass,
+    _register_woq_lowerings,
+)
+
+
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
+
+if torch._C._has_mkldnn:
+    aten = torch.ops.aten
+    mkldnn = torch.ops.mkldnn
+    prims = torch.ops.prims
+
+    _conv_args = [Arg() for _ in range(10)]
+    _linear_args = [Arg() for _ in range(6)]
+    _conv_transpose_args = [Arg() for _ in range(11)]
+
+    class MkldnnDeviceOpBase:
+        def get_linear_transpose_weight(self, weight_node):
+            raise NotImplementedError
+
+        def pack_conv_weight(
+            self,
+            graph,
+            is_transposed,
+            weight,
+            constant_args,
+            input_size,
+        ):
+            raise NotImplementedError
+
+        def pack_linear_weight(
+            self, graph, is_lp_weight, transpose_weight_node, batch_size
+        ):
+            raise NotImplementedError
+
+        def pack_linear(
+            self, graph, is_lp_weight, batch_size, input, packed_weight_node, bias
+        ):
+            raise NotImplementedError
+
+    class CpuMkldnnDeviceOp(MkldnnDeviceOpBase):
+        def get_linear_transpose_weight(self, weight_node):
+            packed_weight_node = weight_node
+            assert packed_weight_node.target == mkldnn._reorder_linear_weight
+            transpose_weight_node = packed_weight_node.args[0]
+            assert transpose_weight_node.target is aten.permute.default
+            return transpose_weight_node
+
+        def pack_conv_weight(
+            self,
+            graph,
+            is_transposed,
+            weight,
+            constant_args,
+            input_size,
+        ):
+            packed_weight_op = mkldnn._reorder_convolution_weight
+            if is_transposed:
+                packed_weight_op = mkldnn._reorder_convolution_transpose_weight
+
+            # mkldnn_reorder_conv_weight(self, padding, stride, dilation, groups, input_size)
+            packed_weight_inputs = (weight,) + tuple(constant_args) + (input_size,)
+            return graph.create_node(
+                "call_function", packed_weight_op, args=packed_weight_inputs
+            )
+
+        def pack_linear_weight(
+            self, graph, is_lp_weight, transpose_weight_node, batch_size
+        ):
+            # For bfloat16 dynamic shape path, using input size hint to pack weight for a better performance.
+            packed_weight_inputs = (
+                transpose_weight_node,
+                batch_size.node.shape_env.size_hint(batch_size.node.expr)
+                if has_free_symbols(batch_size)
+                else batch_size,
+            )
+
+            # MKL packed matrix can't be copied to a different address because the internal implementation
+            # depends on the alignment of internally-stored metadata.
+            # In aot mode, we need to firstly save the packed weight, when loading it,
+            # it will be in a different address which doesn't work.
+            # Disable MKL prepack linear in AOT mode.
+            # Disable MKL prepack linear when batch_size has free symbols.
+            packed_weight_op = (
+                mkldnn._reorder_linear_weight
+                if (
+                    is_lp_weight
+                    or mkldnn._is_mkldnn_acl_supported()
+                    or V.aot_compilation
+                    or has_free_symbols(batch_size)
+                )
+                else torch.ops.mkl._mkl_reorder_linear_weight
+            )
+            return graph.create_node(
+                "call_function", packed_weight_op, args=packed_weight_inputs
+            )
+
+        def pack_linear(
+            self, graph, is_lp_weight, batch_size, input, packed_weight_node, bias
+        ):
+            packed_linear_inputs: tuple[Any, ...] = (input, packed_weight_node)
+            transpose_weight_node = packed_weight_node.args[0]
+            if (
+                is_lp_weight
+                or mkldnn._is_mkldnn_acl_supported()
+                or V.aot_compilation
+                or has_free_symbols(batch_size)
+            ):
+                packed_linear_inputs += (bias, "none", [], "")
+                packed_linear_op: Callable[..., Any] = mkldnn._linear_pointwise.default
+            else:
+                packed_linear_inputs += (transpose_weight_node, bias, batch_size)
+                packed_linear_op = torch.ops.mkl._mkl_linear
+
+            return graph.create_node(
+                "call_function", packed_linear_op, packed_linear_inputs
+            )
+
+    class XpuMkldnnDeviceOp(MkldnnDeviceOpBase):
+        def pack_conv_weight(
+            self,
+            graph,
+            is_transposed,
+            weight,
+            constant_args,
+            input_size,
+        ):
+            assert not is_transposed, (
+                "'mkldnn::_convolution_transpose_pointwise' is not currently implemented for the XPU device."
+            )
+            return weight
+
+    def _get_mkldnn_device_op(device_type: str) -> MkldnnDeviceOpBase:
+        """
+        Returns the MKLDNN device operation class based on the current device type.
+        """
+        if device_type == "cpu":
+            return CpuMkldnnDeviceOp()
+        elif device_type == "xpu":
+            return XpuMkldnnDeviceOp()
+        else:
+            raise RuntimeError(f"MKLDNN is not supported on {device_type} device.")
+
+    def _is_valid_grouped_gemm_fusion(computation_nodes):
+        """
+        Here we check:
+        1. More than 1 GEMM nodes has been found.
+        2. All the GEMM nodes share the same activation.
+        3. All the GEMM nodes have same weight size but different wgt node.
+        """
+        computation_op = mkldnn._linear_pointwise.default
+        act = computation_nodes[0].args[0]
+        wgt = computation_nodes[0].args[1]
+        wgt_size = wgt.meta.get("val").size()  # type: ignore[union-attr]
+        return len(computation_nodes) >= 2 and all(
+            (
+                node.target == computation_op
+                and node.args[0] == act
+                and (node.args[1].meta.get("val").size() == wgt_size)
+                and (node.args[1] != wgt or gemm_idx == 0)
+            )
+            for gemm_idx, node in enumerate(computation_nodes)
+        )
+
+    def grouped_gemm_pass(graph: torch.fx.Graph):
+        """
+        Group GEMM has multi output nodes which is complicated to define a Pattern.
+        Use below way to connect the pattern to the lowering.
+        TODO: Use MultiOutputPattern, current limitation is the pattern requires
+        fixed number of output nodes. Extend to support Group GEMM for pattern matcher.
+        """
+        computation_op = mkldnn._linear_pointwise.default
+        from ..mkldnn_lowerings import grouped_gemm_lowering
+
+        for node in graph.find_nodes(op="call_function", target=computation_op):
+            if (
+                not node._erased
+                and isinstance(node.meta.get("val"), torch.Tensor)
+                and node.meta["val"].device.type == "cpu"
+            ):
+                act = node.args[0]
+                users = list(act.users)
+                if _is_valid_grouped_gemm_fusion(users):
+                    with graph.inserting_before(node):
+                        grouped_gemm_node = graph.create_node(
+                            "call_function",
+                            grouped_gemm_lowering,
+                            (
+                                act,
+                                [user.args[1] for user in users],
+                                [user.args[2] for user in users],
+                            ),
+                        )
+                        grouped_gemm_node.meta["val"] = [
+                            user.meta["val"] for user in users
+                        ]
+                        with graph.inserting_after(grouped_gemm_node):
+                            for gemm_idx, user in enumerate(users):
+                                assert user.target == computation_op
+                                get_item = graph.create_node(
+                                    "call_function",
+                                    operator.getitem,
+                                    (
+                                        grouped_gemm_node,
+                                        gemm_idx,
+                                    ),
+                                )
+                                user.replace_all_uses_with(get_item)
+                                graph.erase_node(user)
+        return
+
+    def _conv_call(users=1):
+        return CallFunction(
+            mkldnn._convolution_pointwise.default, *_conv_args, _users=users
+        )
+
+    def _linear_call(users=1):
+        return CallFunction(
+            mkldnn._linear_pointwise.default, *_linear_args, _users=users
+        )
+
+    def _conv_transpose_call(users=1):
+        return CallFunction(
+            mkldnn._convolution_transpose_pointwise.default,
+            *_conv_transpose_args,
+            _users=users,
+        )
+
+    def _to_float(input_call, users=1):
+        return CallFunction(
+            prims.convert_element_type.default,
+            input_call,
+            KeywordArg("to_float"),
+            _users=users,
+        )
+
+    def _to_bf16(input_call):
+        return CallFunction(
+            prims.convert_element_type.default,
+            input_call,
+            KeywordArg("to_bf16"),
+            _users=1,
+        )
+
+    def _to_fp16(input_call):
+        return CallFunction(
+            prims.convert_element_type.default,
+            input_call,
+            KeywordArg("to_fp16"),
+            _users=1,
+        )
+
+    def _unary_fusion_pattern(unary_fusion, call_fn, users, lowp_dtype):
+        # only insert to_dtype if lowp_dtype is True
+        computation_call = (
+            _to_float(call_fn(), users=users) if lowp_dtype else call_fn(users=users)
+        )
+        out = unary_fusion(computation_call)
+        if lowp_dtype == torch.bfloat16:
+            return _to_bf16(out)
+        elif lowp_dtype == torch.float16:
+            return _to_fp16(out)
+        else:
+            return out
+
+    def _gelu_fusion_1(computation_call):
+        return CallFunction(
+            aten.mul,
+            CallFunction(aten.mul, computation_call, 0.5),
+            CallFunction(
+                aten.add,
+                CallFunction(
+                    aten.erf,
+                    CallFunction(aten.mul, computation_call, 0.7071067811865476),
+                ),
+                1,
+            ),
+        )
+
+    def _gelu_fusion_2(computation_call):
+        return CallFunction(
+            aten.mul,
+            CallFunction(aten.mul, computation_call, 0.5),
+            CallFunction(
+                aten.add,
+                CallFunction(
+                    aten.tanh,
+                    CallFunction(
+                        aten.mul,
+                        CallFunction(
+                            aten.add,
+                            computation_call,
+                            CallFunction(
+                                aten.mul,
+                                CallFunction(
+                                    aten.mul,
+                                    CallFunction(
+                                        aten.mul, computation_call, computation_call
+                                    ),
+                                    computation_call,
+                                ),
+                                0.044715,
+                            ),
+                        ),
+                        0.7978845608028654,
+                    ),
+                ),
+                1,
+            ),
+        )
+
+    def _hardswish_fusion(computation_call):
+        return CallFunction(
+            aten.div,
+            CallFunction(
+                aten.mul,
+                computation_call,
+                CallFunction(
+                    aten.clamp_max,
+                    CallFunction(
+                        aten.clamp_min, CallFunction(aten.add, computation_call, 3), 0
+                    ),
+                    6,
+                ),
+            ),
+            6,
+        )
+
+    def _silu_fusion(computation_call):
+        return CallFunction(
+            aten.mul, computation_call, CallFunction(aten.sigmoid, computation_call)
+        )
+
+    def _hardsigmoid_fusion(computation_call):
+        return CallFunction(
+            aten.div,
+            CallFunction(
+                aten.clamp_max,
+                CallFunction(
+                    aten.clamp_min, CallFunction(aten.add, computation_call, 3), 0
+                ),
+                6,
+            ),
+            6,
+        )
+
+    def _leaky_relu_fusion(computation_call):
+        return CallFunction(
+            aten.where,
+            CallFunction(aten.gt, computation_call, 0),
+            computation_call,
+            CallFunction(aten.mul, computation_call, KeywordArg("negative_slope")),
+        )
+
+    def _hardtanh_fusion(computation_call):
+        return CallFunction(
+            aten.clamp_max,
+            CallFunction(aten.clamp_min, computation_call, KeywordArg("min_value")),
+            KeywordArg("max_value"),
+        )
+
+    def _combined_fusion(computation_call, elementwise_op):
+        return CallFunction(elementwise_op, computation_call)
+
+    # binary_op(other, computation_op)
+    def _binary_fusion_v1(computation_call, binary_fn):
+        return CallFunction(binary_fn, KeywordArg("other"), computation_call)
+
+    # binary_op(computation_op, other)
+    def _binary_fusion_v2(computation_call, binary_fn):
+        return CallFunction(binary_fn, computation_call, KeywordArg("other"))
+
+    def _is_single_computation_op(computation_op, lowp_dtype=None):
+        def fn(match):
+            computation_nodes = filter_nodes(match.nodes, computation_op)
+
+            if lowp_dtype:
+                output_node_meta = match.output_node().meta.get("val")
+                if output_node_meta.dtype != lowp_dtype:
+                    return False
+
+            if len(computation_nodes) < 1:
+                return False
+            if any(n.args[-3] != "none" for n in computation_nodes):
+                return False
+            return True
+
+        return fn
+
+    def _is_valid_computation_unary_fusion(computation_op, lowp_dtype=None):
+        def fn(match):
+            matched = _is_single_computation_op(computation_op, lowp_dtype)(match)
+            computation_node = filter_nodes(match.nodes, computation_op)[0]
+            if lowp_dtype:
+                conversion_dtype_nodes = filter_nodes(
+                    match.nodes, prims.convert_element_type.default
+                )
+                if len(conversion_dtype_nodes) != 2:
+                    return False
+                # fusion pattern is always in the form of computation_op + to_float32 + unary_op + to_bfloat16
+                if computation_node == conversion_dtype_nodes[0].args[0]:
+                    to_float = conversion_dtype_nodes[0].args[1]
+                    to_lp = conversion_dtype_nodes[1].args[1]
+                else:
+                    to_float = conversion_dtype_nodes[1].args[1]
+                    to_lp = conversion_dtype_nodes[0].args[1]
+                matched = matched and to_float == torch.float and to_lp == lowp_dtype
+            return matched
+
+        return fn
+
+    def _register_unary_fusion_lowering(
+        pattern, unary_attr, computation_op, lowp_dtype=None
+    ):
+        @register_lowering_pattern(
+            pattern,
+            extra_check=_is_valid_computation_unary_fusion(computation_op, lowp_dtype),
+        )
+        def fn(match, *args, **kwargs):
+            computation_args = list(args)[:-3] + [
+                unary_attr.op_name,
+                unary_attr.scalars_attr,
+                unary_attr.algorithm_attr,
+            ]
+            counters["inductor"]["mkldnn_unary_fusion_matcher_count"] += 1
+            counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"] += len(
+                match.nodes
+            )
+            return L[computation_op](*computation_args)
+
+        return fn
+
+    def _register_leaky_relu_fusion_lowering(pattern, computation_op, lowp_dtype=None):
+        @register_lowering_pattern(
+            pattern, extra_check=_is_single_computation_op(computation_op, lowp_dtype)
+        )
+        def fn(match, *args, **kwargs):
+            negative_slope = kwargs.get("negative_slope")
+            if isinstance(negative_slope, ir.TensorBox):
+                matched = False
+            else:  # inp is a Number
+                matched = True
+            if lowp_dtype:
+                dtype1 = kwargs.get("to_float")
+                dtype2 = (
+                    kwargs.get("to_bf16")
+                    if lowp_dtype == torch.bfloat16
+                    else kwargs.get("to_fp16")
+                )
+                matched = matched and dtype1 == torch.float and dtype2 == lowp_dtype
+            computation_args = list(args)
+            counters["inductor"]["mkldnn_unary_fusion_matcher_count"] += 1
+            counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"] += len(
+                match.nodes
+            )
+            if matched:
+                computation_args = computation_args[:-3] + [
+                    "leaky_relu",
+                    [negative_slope],
+                    "",
+                ]
+                return L[computation_op](*computation_args)
+            else:
+                # computation_args += ["none", [], ""]
+                out = L[computation_op](*computation_args)
+                if lowp_dtype:
+                    out = L[prims.convert_element_type.default](out, dtype=torch.float)
+                out = L[aten.where](
+                    L[aten.gt](out, 0),
+                    out,
+                    L[aten.mul](out, negative_slope),
+                )
+                if lowp_dtype:
+                    out = L[prims.convert_element_type.default](out, dtype=dtype2)  # type: ignore[possibly-undefined]
+                return out
+
+        return fn
+
+    def _register_hardtanh_fusion_lowering(pattern, computation_op, lowp_dtype=None):
+        @register_lowering_pattern(
+            pattern, extra_check=_is_single_computation_op(computation_op, lowp_dtype)
+        )
+        def fn(match, *args, **kwargs):
+            min_value = kwargs.get("min_value")
+            max_value = kwargs.get("max_value")
+            if isinstance(min_value, ir.TensorBox) or isinstance(
+                max_value, ir.TensorBox
+            ):
+                matched = False
+            else:  # inp is a Number
+                assert max_value is not None
+                matched = min_value <= max_value
+            if lowp_dtype:
+                dtype1 = kwargs.get("to_float")
+                dtype2 = (
+                    kwargs.get("to_bf16")
+                    if lowp_dtype == torch.bfloat16
+                    else kwargs.get("to_fp16")
+                )
+                matched = matched and dtype1 == torch.float and dtype2 == lowp_dtype
+            computation_args = list(args)
+            counters["inductor"]["mkldnn_unary_fusion_matcher_count"] += 1
+            counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"] += len(
+                match.nodes
+            )
+            if matched:
+                computation_args = computation_args[:-3] + [
+                    "hardtanh",
+                    [min_value, max_value],
+                    "",
+                ]
+                return L[computation_op](*computation_args)
+            else:
+                out = L[computation_op](*computation_args)
+                if lowp_dtype:
+                    out = L[prims.convert_element_type.default](out, dtype=torch.float)
+                out = L[aten.clamp_max](L[aten.clamp_min](out, min_value), max_value)
+                if lowp_dtype:
+                    out = L[prims.convert_element_type.default](out, dtype=dtype2)  # type: ignore[possibly-undefined]
+                return out
+
+        return fn
+
+    _binary_attr = {
+        aten.add: "add",
+        ops.add: "add",
+        aten.sub: "sub",
+        ops.sub: "sub",
+    }
+
+    def _is_valid_binary(match, computation_op, binary_op):
+        binary_nodes = filter_nodes(match.nodes, binary_op)
+        if len(binary_nodes) < 1:
+            return False
+
+        def get_meta_value(argument: torch.fx.node.Argument):
+            # Only torch.fx.Node is expected to have meta.
+            if isinstance(argument, torch.fx.Node):
+                return argument.meta.get("val", None)
+            return None
+
+        if any(
+            not isinstance(get_meta_value(n.args[0]), torch.Tensor)
+            or not isinstance(get_meta_value(n.args[1]), torch.Tensor)
+            for n in binary_nodes
+        ):
+            return False
+        # check alpha is one.
+        if any(
+            get_arg_value(n, 2, kwarg_name="alpha") != 1.0
+            and get_arg_value(n, 2, kwarg_name="alpha") is not None
+            for n in binary_nodes
+        ):
+            return False
+
+        def _check_input_sizes(n, computation_op):
+            # Check if the tensor shape of the 'other' node is the same as or
+            # can be broadcasted to the tensor shape of the computation node.
+            computation_node = (
+                n.args[0] if n.args[1] is match.kwargs["other"] else n.args[1]
+            )
+            assert computation_node.target == computation_op
+            computation_node_size = get_meta_value(computation_node).size()
+            if computation_op is mkldnn._linear_pointwise.default:
+                broadcast_sizes = []
+                if len(computation_node_size) >= 2:
+                    broadcast_sizes = [
+                        torch.Size(
+                            [1 for _ in range(len(computation_node_size) - 1)]
+                            + [computation_node_size[-1]]
+                        ),
+                    ]
+            else:
+                assert len(computation_node_size) > 2
+                broadcast_sizes = [
+                    torch.Size(
+                        [computation_node_size[0], computation_node_size[1]]
+                        + [1 for _ in range(len(computation_node_size) - 2)]
+                    ),
+                    torch.Size(
+                        [1, computation_node_size[1]]
+                        + [1 for _ in range(len(computation_node_size) - 2)]
+                    ),
+                    torch.Size([1 for _ in range(len(computation_node_size))]),
+                ]
+            return (
+                get_meta_value(match.kwargs["other"]).size()
+                in [
+                    computation_node_size,
+                ]
+                + broadcast_sizes
+            )
+
+        if any(
+            not _check_input_sizes(n, computation_op)
+            or get_meta_value(n.args[0]).device != get_meta_value(n.args[1]).device
+            or get_meta_value(n.args[0]).dtype != get_meta_value(n.args[1]).dtype
+            for n in binary_nodes
+        ):
+            return False
+        # check args[0] and args[1] is not same
+        if any(n.args[0] == n.args[1] for n in binary_nodes):
+            return False
+        return True
+
+    def _is_valid_computation_binary(computation_op, binary_op, other_index=None):
+        def fn(match):
+            if not _is_single_computation_op(computation_op)(match):
+                return False
+            if not _is_valid_binary(match, computation_op, binary_op):
+                return False
+            return True
+
+        return fn
+
+    def _get_remaining_users(extra_input_node, compute_node):
+        # Think about this pattern:
+        #      ReLU
+        #     /   \
+        #  Conv1
+        #   /      \
+        # Conv2
+        #   \      /
+        #      Add
+        # Although, the extra input node (ReLU) has more than 1 users: Conv1 and Add.
+        # The Conv1 is the ancestor node of the current compute node (Conv2).
+        # This indicates that the buffer of ReLU has completed all its usage,
+        # So we can safely make changes to it now by doing Conv2->Add inplace fusion.
+        # Take above case as example:
+        # * extra_input_node: ReLU
+        # * compute_node: Conv2
+        # _get_remaining_users will return the users of extra_input_node which are not
+        # ancestor node of compute_node.
+        def _is_ancestor_node(_current_node, _ancestor_node):
+            # Check whether _ancestor_node is the ancestor node of _current_node
+            _node_list = [_current_node]
+            _visited_nodes = OrderedSet[torch.fx.Node]()
+            while len(_node_list) != 0:
+                _current_node = _node_list.pop(0)
+                if _current_node not in _visited_nodes:
+                    _visited_nodes.add(_current_node)
+                    if _current_node == _ancestor_node:
+                        return True
+                    elif isinstance(
+                        _current_node, torch.fx.Node
+                    ) and _current_node.op not in ["placeholder", "output", "get_attr"]:
+                        for input in _current_node.all_input_nodes:
+                            _node_list.append(input)  # noqa: PERF402
+            return False
+
+        return [
+            user
+            for user in list(extra_input_node.users)
+            if not _is_ancestor_node(compute_node, user)
+        ]
+
+    def _is_valid_computation_binary_inplace(computation_op, binary_op, other_index):
+        def fn(match):
+            if not _is_valid_computation_binary(computation_op, binary_op)(match):
+                return False
+            binary_nodes = filter_nodes(match.nodes, binary_op)
+
+            def _get_compute_node(_binary_node, _other_index):
+                assert len(_binary_node.all_input_nodes) == 2, (
+                    "Binary node should have 2 input nodes."
+                )
+                _compute_index = 1 if (_other_index == 0) else 0
+                return _binary_node.args[_compute_index]
+
+            def _other_input_not_inplaceable(_binary_node, _other_index):
+                _compute_node = _get_compute_node(_binary_node, _other_index)
+                return (
+                    len(
+                        _get_remaining_users(
+                            _binary_node.args[_other_index], _compute_node
+                        )
+                    )
+                    > 1
+                    or _binary_node.args[_other_index] == _compute_node.args[0]
+                )
+
+            if any(_other_input_not_inplaceable(n, other_index) for n in binary_nodes):
+                return False
+            if any(
+                # pyrefly: ignore [missing-attribute]
+                n.args[other_index].op in ["placeholder", "output"]
+                for n in binary_nodes
+            ):
+                return False
+            return True
+
+        return fn
+
+    def _register_binary_unary_fusion_lowering(
+        pattern,
+        computation_op,
+        binary_op,
+        fusion_op,
+        unary_attr=None,
+    ):
+        @register_lowering_pattern(
+            pattern, extra_check=_is_valid_computation_binary(computation_op, binary_op)
+        )
+        def fn(match, *args, **kwargs):
+            other = kwargs.get("other")
+            assert isinstance(other, ir.TensorBox)
+            binary_attr = _binary_attr[binary_op]
+            args_list = list(args)
+            computation_args = [args_list[0], other] + args_list[1:-3] + [binary_attr]
+            if len(args_list) > 6:
+                if unary_attr is not None:
+                    computation_args += [
+                        1.0,
+                        unary_attr.op_name,
+                        unary_attr.scalars_attr,
+                        unary_attr.algorithm_attr,
+                    ]
+                else:
+                    computation_args += [1.0, None, [], None]
+            counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_count"] += 1
+            counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_nodes"] += (
+                len(match.nodes)
+            )
+            return L[fusion_op](*computation_args)
+
+        return fn
+
+    def _can_be_inplace(_other):
+        return not (
+            isinstance(_other.data, ir.BaseView)
+            or len(_other.get_inputs_that_alias_output()) > 0
+        )
+
+    def _qlinear_binary_can_be_inplace(_other):
+        if isinstance(_other.data, ir.BaseView):
+
+            def unwrap_buffer(data):
+                if isinstance(data, ir.StorageBox):
+                    return data.data
+                return data
+
+            data = _other.data.unwrap_view()
+            if isinstance(unwrap_buffer(data), ir.CppTemplateBuffer):
+                # It can be inplaced when _other is the 2D to 3D view of
+                # a CppTemplateBuffer because if there is a view of CppTemplateBuffer,
+                # CppTemplateBuffer will not be used directly but the view.
+                return True
+            else:
+                # The case of QLinearPointwiseBinaryPT2E(sum) -> QLinearPointwiseBinaryPT2E(sum)
+                # is similar to CppTemplateBuffer above.
+                # The output of previous QLinearPointwiseBinaryPT2E is
+                # the input x2 of current QLinearPointwiseBinaryPT2E.
+                # Use V.graph.operations to check if _other is a view of the output
+                # of previous QLinearPointwiseBinaryPT2E (the inputs[6]).
+                for op in V.graph.operations:
+                    if (
+                        isinstance(op, mkldnn_ir.QLinearPointwiseBinaryPT2E)
+                        and unwrap_buffer(data) == op.inputs[6]  # type: ignore[attr-defined]
+                    ):
+                        return True
+            return False
+        elif len(_other.get_inputs_that_alias_output()) > 0:
+            return False
+        else:
+            return True
+
+    def _register_binary_unary_maybe_inplace_fusion_lowering(
+        pattern,
+        computation_op,
+        binary_op,
+        inplace_fusion_op,
+        outplace_fusion_op,
+        unary_attr=None,
+        other_index=None,
+    ):
+        @register_lowering_pattern(
+            pattern,
+            extra_check=_is_valid_computation_binary_inplace(
+                computation_op, binary_op, other_index
+            ),
+        )
+        def fn(match, *args, **kwargs):
+            other = kwargs.get("other")
+            assert isinstance(other, ir.TensorBox)
+            binary_attr = _binary_attr[binary_op]
+            args_list = list(args)
+            computation_args = [args_list[0], other] + args_list[1:-3] + [binary_attr]
+            if len(args_list) > 6:
+                if unary_attr is not None:
+                    computation_args += [
+                        1.0,
+                        unary_attr.op_name,
+                        unary_attr.scalars_attr,
+                        unary_attr.algorithm_attr,
+                    ]
+                else:
+                    computation_args += [1.0, None, [], None]
+            counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_count"] += 1
+            counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_nodes"] += (
+                len(match.nodes)
+            )
+            # Make sure the other is not an alias or mutation(fx side doesn't has such info).
+            other.realize()
+            if not _can_be_inplace(other) or other.data.shape != list(
+                match.nodes[0].meta["val"].size()
+            ):
+                return L[outplace_fusion_op](*computation_args)
+            return L[inplace_fusion_op](*computation_args)
+
+        return fn
+
+    computation_ops = [
+        mkldnn._convolution_pointwise.default,
+        mkldnn._linear_pointwise.default,
+        mkldnn._convolution_transpose_pointwise.default,
+    ]
+
+    class UnaryAttr:
+        def __init__(
+            self, op_name: str, scalars_attr=None, algorithm_attr=None
+        ) -> None:
+            self.op_name = op_name
+            self.scalars_attr = scalars_attr if scalars_attr else []
+            self.algorithm_attr = algorithm_attr if algorithm_attr else ""
+
+    def _register_unary_fusion():
+        computation_call_fns = [_conv_call, _linear_call, _conv_transpose_call]
+
+        def _unary_fusion_patterns(lowp_dtype):
+            replacement_unary_fusion_patterns = {
+                UnaryAttr("gelu", algorithm_attr="tanh"): [
+                    _unary_fusion_pattern(_gelu_fusion_2, call_fn, 4, lowp_dtype)
+                    for call_fn in computation_call_fns
+                ],
+                UnaryAttr("gelu", algorithm_attr="none"): [
+                    _unary_fusion_pattern(_gelu_fusion_1, call_fn, 2, lowp_dtype)
+                    for call_fn in computation_call_fns
+                ],
+                UnaryAttr("hardswish"): [
+                    _unary_fusion_pattern(_hardswish_fusion, call_fn, 2, lowp_dtype)
+                    for call_fn in computation_call_fns
+                ],
+                UnaryAttr("hardsigmoid"): [
+                    _unary_fusion_pattern(_hardsigmoid_fusion, call_fn, 1, lowp_dtype)
+                    for call_fn in computation_call_fns
+                ],
+                UnaryAttr("swish"): [
+                    _unary_fusion_pattern(_silu_fusion, call_fn, 2, lowp_dtype)
+                    for call_fn in computation_call_fns
+                ],
+            }
+            if not lowp_dtype:
+                call_user1 = [call_fn(users=1) for call_fn in computation_call_fns]
+                replacement_unary_fusion_patterns.update(
+                    {
+                        UnaryAttr("relu"): [
+                            _combined_fusion(u, aten.relu) for u in call_user1
+                        ],
+                        UnaryAttr("sigmoid"): [
+                            _combined_fusion(u, aten.sigmoid) for u in call_user1
+                        ],
+                        UnaryAttr("tanh"): [
+                            _combined_fusion(u, aten.tanh) for u in call_user1
+                        ],
+                    }
+                )
+
+            return replacement_unary_fusion_patterns
+
+        for lowp_dtype in [torch.bfloat16, torch.float16, None]:
+            replace_patterns = _unary_fusion_patterns(lowp_dtype)
+            for unary_attr, patterns in replace_patterns.items():
+                _register_unary_fusion_lowering(
+                    patterns[0], unary_attr, computation_ops[0], lowp_dtype
+                )
+                _register_unary_fusion_lowering(
+                    patterns[1], unary_attr, computation_ops[1], lowp_dtype
+                )
+                _register_unary_fusion_lowering(
+                    patterns[2], unary_attr, computation_ops[2], lowp_dtype
+                )
+            _leaky_relu_patterns = [
+                _unary_fusion_pattern(_leaky_relu_fusion, call_fn, 3, lowp_dtype)
+                for call_fn in computation_call_fns
+            ]
+            for pattern, computation_op in zip(_leaky_relu_patterns, computation_ops):
+                _register_leaky_relu_fusion_lowering(
+                    pattern, computation_op, lowp_dtype
+                )
+            hardtanh_patterns = [
+                _unary_fusion_pattern(_hardtanh_fusion, call_fn, 1, lowp_dtype)
+                for call_fn in computation_call_fns
+            ]
+            for pattern, computation_op in zip(hardtanh_patterns, computation_ops):
+                _register_hardtanh_fusion_lowering(pattern, computation_op, lowp_dtype)
+
+    def _register_inplace_fusion():
+        binary_ops = [aten.add, ops.add]
+        inplace_fusion_op = mkldnn._convolution_pointwise_.binary
+        outplace_fusion_op = mkldnn._convolution_pointwise.binary
+        conv_call = _conv_call(users=1)
+        conv_op = computation_ops[0]
+        for binary_op in binary_ops:
+            binary_v1 = _binary_fusion_v1(conv_call, binary_op)
+            binary_unary_v1 = _combined_fusion(binary_v1, aten.relu)
+            _register_binary_unary_maybe_inplace_fusion_lowering(
+                binary_unary_v1,
+                conv_op,
+                binary_op,
+                inplace_fusion_op,
+                outplace_fusion_op,
+                other_index=0,
+                unary_attr=UnaryAttr("relu"),
+            )
+            _register_binary_unary_maybe_inplace_fusion_lowering(
+                binary_v1,
+                conv_op,
+                binary_op,
+                inplace_fusion_op,
+                outplace_fusion_op,
+                other_index=0,
+            )
+            binary_v2 = _binary_fusion_v2(conv_call, binary_op)
+            binary_unary_v2 = _combined_fusion(binary_v2, aten.relu)
+            _register_binary_unary_maybe_inplace_fusion_lowering(
+                binary_unary_v2,
+                conv_op,
+                binary_op,
+                inplace_fusion_op,
+                outplace_fusion_op,
+                other_index=1,
+                unary_attr=UnaryAttr("relu"),
+            )
+            _register_binary_unary_maybe_inplace_fusion_lowering(
+                binary_v2,
+                conv_op,
+                binary_op,
+                inplace_fusion_op,
+                outplace_fusion_op,
+                other_index=1,
+            )
+
+    def _register_binary_fusion():
+        binary_ops = [aten.add, ops.add, aten.sub, ops.sub]
+        fusion_ops = [
+            mkldnn._convolution_pointwise.binary,
+            mkldnn._linear_pointwise.binary,
+        ]
+        _computation_user_1 = [_conv_call(users=1), _linear_call(users=1)]
+        for computation_call, computation_op, fusion_op in zip(
+            _computation_user_1, computation_ops[:-1], fusion_ops
+        ):
+            for binary_op in binary_ops:
+                pattern = _binary_fusion_v2(computation_call, binary_op)
+                _register_binary_unary_fusion_lowering(
+                    pattern, computation_op, binary_op, fusion_op
+                )
+
+            for binary_op in [aten.add, ops.add]:
+                pattern = _binary_fusion_v1(computation_call, binary_op)
+                _register_binary_unary_fusion_lowering(
+                    pattern, computation_op, binary_op, fusion_op
+                )
+
+    def _register_binary_unary_fusion():
+        binary_ops = [aten.add, ops.add, aten.sub, ops.sub]
+        fusion_ops = [mkldnn._convolution_pointwise.binary]
+        _computation_user_1 = [_conv_call(users=1)]
+        for computation_call, computation_op, fusion_op in zip(
+            _computation_user_1, computation_ops[:-1], fusion_ops
+        ):
+            for binary_op in binary_ops:
+                pattern_v1 = _combined_fusion(
+                    _binary_fusion_v2(computation_call, binary_op), aten.relu
+                )
+                _register_binary_unary_fusion_lowering(
+                    pattern_v1,
+                    computation_op,
+                    binary_op,
+                    fusion_op,
+                    unary_attr=UnaryAttr("relu"),
+                )
+            for binary_op in [aten.add, ops.add]:
+                pattern_v2 = _combined_fusion(
+                    _binary_fusion_v1(computation_call, binary_op), aten.relu
+                )
+                _register_binary_unary_fusion_lowering(
+                    pattern_v2,
+                    computation_op,
+                    binary_op,
+                    fusion_op,
+                    unary_attr=UnaryAttr("relu"),
+                )
+
+    def _recover_linear():
+        # convert reshape+linear+reshape to a single linear for applying fusion path.
+        # concat_linear (pass_number=0) -> mkldnn_linear_pack (pass_number=1) -> _recover_linear(pass_number=2)
+        @register_freezing_graph_pattern(
+            CallFunction(
+                aten.reshape.default,
+                CallFunction(
+                    mkldnn._linear_pointwise.default,
+                    CallFunction(
+                        aten.reshape.default,
+                        Arg(),
+                        KeywordArg("reshape_1"),
+                        _users=MULTIPLE,
+                    ),
+                    Arg(),
+                    Arg(),
+                    Arg(),
+                    Arg(),
+                    Arg(),
+                ),
+                KeywordArg("reshape_2"),
+            ),
+            pass_number=2,
+        )
+        def reshape_linear_reshape_pattern(match, *args, **kwargs):
+            def get_val(val):
+                return val if isinstance(val, int) else val.meta.get("val")
+
+            reshape_1 = kwargs.get("reshape_1")
+            reshape_2 = kwargs.get("reshape_2")
+            assert isinstance(reshape_1, list)
+            assert isinstance(reshape_2, list)
+            assert len(reshape_1) == 2
+
+            graph = match.graph
+            reshape_2_node = match.output_node()
+            linear_input_node = reshape_2_node.args[0].args[0].args[0]
+            # check linear's input's shape[:-1] == reshape_2[:-1]
+            # and check product(reshape_2[:-1]) == reshape_1[0]
+            can_remove_reshape = linear_input_node.meta.get("val").shape[
+                :-1
+            ] == torch.Size([get_val(val) for val in reshape_2[:-1]])
+            can_remove_reshape = can_remove_reshape and (
+                reduce(
+                    operator.mul,
+                    [get_val(val) for val in reshape_2[:-1]],
+                )
+                == get_val(reshape_1[0])
+            )
+
+            if can_remove_reshape:
+                repl = graph.call_function(mkldnn._linear_pointwise.default, args)
+                repl.meta.update(reshape_2_node.meta)
+                reshape_2_node.replace_all_uses_with(repl)
+                old_linear_node = reshape_2_node.args[0]
+                reshape_1_node = old_linear_node.args[0]
+                graph.erase_node(reshape_2_node)
+                graph.erase_node(old_linear_node)
+                if len(reshape_1_node.users) == 0:
+                    graph.erase_node(reshape_1_node)
+            counters["inductor"]["mkldnn_reshape_linear_reshape_matcher_count"] += 1
+            counters["inductor"]["mkldnn_reshape_linear_reshape_matcher_nodes"] += len(
+                match.nodes
+            )
+
+        def is_linear_add_bias(match):
+            add_node = match.output_node()
+            linear_node = add_node.args[0]
+            device_type = add_node.meta.get("val").device.type
+            mkldnn_device_op = _get_mkldnn_device_op(device_type)
+            transpose_weight_node = mkldnn_device_op.get_linear_transpose_weight(
+                linear_node.args[1]
+            )
+            weight_meta = transpose_weight_node.args[0].meta.get("val")
+            bias_node = add_node.args[1]
+            if isinstance(bias_node, int):
+                # we only folding bias if it is a constant
+                return False
+            bias_meta = add_node.args[1].meta.get("val")
+            if weight_meta is None or bias_meta is None:
+                return False
+
+            if bias_meta.dtype != weight_meta.dtype:
+                return False
+            return (
+                linear_node.args[2] is None
+                and bias_meta.dim() == 1
+                and bias_meta.size(0) == weight_meta.size(1)
+            )
+
+        # convert linear+bias to a single linear for applying fusion path.
+        @register_freezing_graph_pattern(
+            CallFunction(
+                aten.add.Tensor,
+                CallFunction(mkldnn._linear_pointwise.default, *_linear_args),
+                Arg(),
+            ),
+            pass_number=2,
+            extra_check=is_linear_add_bias,
+        )
+        def linear_bias_pattern(match, *args):
+            graph = match.graph
+            add_node = match.output_node()
+            linear_node = add_node.args[0]
+            new_args = list(linear_node.args)
+            new_args[2] = add_node.args[1]
+            repl = graph.call_function(
+                mkldnn._linear_pointwise.default, tuple(new_args)
+            )
+            repl.meta.update(add_node.meta)
+            add_node.replace_all_uses_with(repl)
+            match.erase_nodes()
+            counters["inductor"]["mkldnn_linear_bias_matcher_count"] += 1
+            counters["inductor"]["mkldnn_linear_bias_matcher_nodes"] += len(match.nodes)
+
+    def _is_packable_mkldnn_rnn_layer(match):
+        lstm_node = match.output_node()
+        POS_WEIGHTS = [1, 2]
+        POS_INPUTS = [0, 5, 6]
+        POS_ARGS = POS_WEIGHTS + POS_INPUTS
+        # Weights should be Constant
+        if any(
+            lstm_node.args[POS_WEIGHT].op != "get_attr" for POS_WEIGHT in POS_WEIGHTS
+        ):
+            return False
+
+        # Meta info for weights and inputs should be available
+        if any(lstm_node.args[POS_ARG].meta.get("val") is None for POS_ARG in POS_ARGS):
+            return False
+
+        # Check device
+        if any(
+            lstm_node.args[POS_ARG].meta.get("val").device.type != "cpu"
+            for POS_ARG in POS_ARGS
+        ):
+            return False
+
+        # Check dtype
+        if any(
+            lstm_node.args[POS_ARG].meta.get("val").dtype == torch.bfloat16
+            and not is_mkldnn_bf16_supported("cpu")
+            for POS_ARG in POS_ARGS
+        ):
+            return False
+        if any(
+            lstm_node.args[POS_ARG].meta.get("val").dtype == torch.float16
+            and not is_mkldnn_fp16_supported("cpu")
+            for POS_ARG in POS_ARGS
+        ):
+            return False
+
+        return True
+
+    def _is_packable_convolution(match):
+        """
+        Check if the node is supported for MKLDNN convolution.
+        """
+        conv_node = match.output_node()
+        device_type = conv_node.meta.get("val").device.type
+        # The operator 'mkldnn::_convolution_transpose_pointwise' is not currently implemented for the XPU device.
+        if match.kwargs["is_transposed"] and device_type == "xpu":
+            return False
+
+        input_meta_value = conv_node.args[0].meta.get("val")
+        weight_meta_value = conv_node.args[1].meta.get("val")
+        if input_meta_value is None or weight_meta_value is None:
+            return False
+        input_size = input_meta_value.shape
+        if conv_node.args[1].op != "get_attr":
+            return False
+        for meta_value in [input_meta_value, weight_meta_value]:
+            if (
+                meta_value is None
+                or meta_value.device.type not in SUPPORTED_MKLDNN_DEVICES
+                or (meta_value.dim() != 4 and meta_value.dim() != 5)
+            ):
+                return False
+
+        if (
+            input_meta_value.dtype == torch.bfloat16
+            or weight_meta_value.dtype == torch.bfloat16
+        ):
+            if not is_mkldnn_bf16_supported(device_type):
+                return False
+        if (
+            input_meta_value.dtype == torch.float16
+            or weight_meta_value.dtype == torch.float16
+        ):
+            if not is_mkldnn_fp16_supported(device_type):
+                return False
+        is_transposed = conv_node.args[-3]
+        if is_transposed:
+            # TODO: Support dynamic shape case for MKLDNN conv transpose.
+            if has_free_symbols(input_size):
+                return False
+            groups = conv_node.args[-1]
+            in_channels = weight_meta_value.size(0)
+            # doesn't support group_depthwise_conv_transpose.
+            if groups > 1 and groups == in_channels:
+                return False
+            # Port from: aten/src/ATen/native/Convolution.cpp:is_output_padding_big
+            output_paddings = conv_node.args[-2]
+            strides = conv_node.args[3]
+            if any(
+                output_padding >= stride
+                for output_padding, stride in zip(output_paddings, strides)
+            ):
+                return False
+        return True
+
+    def _is_packable_linear(match):
+        """
+        Check if the node is supported for MKLDNN linear.
+        """
+
+        def is_const_or_cat_by_const(weight):
+            if weight.op == "get_attr":
+                return True
+            if weight.target != aten.cat.default:
+                return False
+            return all(arg.op == "get_attr" for arg in weight.args[0])
+
+        linear_node = match.output_node()
+        # mkldnn linear only supports beta=1or0 and alpha=1
+        if linear_node.target is aten.addmm.default:
+            alpha = linear_node.kwargs.get("alpha", 1.0)
+            beta = linear_node.kwargs.get("beta", 1.0)
+            if (beta != 0.0 and beta != 1.0) or alpha != 1.0:
+                return False
+        # weight_idx is 1 for aten.mm and is 2 for aten.addmm
+        weight_idx = 2 if linear_node.target is aten.addmm.default else 1
+        if not is_const_or_cat_by_const(linear_node.args[weight_idx]):
+            return False
+        input_meta_value = linear_node.args[weight_idx - 1].meta.get("val")
+        weight_meta_value = linear_node.args[weight_idx].meta.get("val")
+        if input_meta_value is None or weight_meta_value is None:
+            return False
+        if (
+            input_meta_value.dtype == torch.float64
+            or weight_meta_value.dtype == torch.float64
+        ):
+            return False
+        is_lp_weight = weight_meta_value.dtype in (
+            torch.bfloat16,
+            torch.float16,
+        )
+        reduced_f32_matmul_enabled = torch.backends.mkldnn.matmul.fp32_precision in [  # type: ignore[attr-defined]
+            "bf16",
+            "tf32",
+        ]
+        use_reduced_f32_for_fp32_weight = (
+            reduced_f32_matmul_enabled and weight_meta_value.dtype == torch.float32
+        )
+        compute_with_lp = is_lp_weight or use_reduced_f32_for_fp32_weight
+        # on x86, for fp32, mkl should be enabled.
+        # on aarch64, use mkldnn op for fp32 as well if acl is enabled
+        if (
+            not compute_with_lp
+            and not mkldnn._is_mkldnn_acl_supported()
+            and not torch._C.has_mkl
+        ):
+            return False
+        for meta_value in [input_meta_value, weight_meta_value]:
+            if (
+                meta_value is None
+                or meta_value.device.type != "cpu"
+                or meta_value.dim() != 2
+            ):
+                return False
+        if weight_idx == 2:
+            bias_meta_value = linear_node.args[0].meta.get("val")
+            if (
+                bias_meta_value is None
+                or meta_value.device.type != "cpu"
+                or bias_meta_value.dim() != 1
+                or bias_meta_value.size(0) != weight_meta_value.size(1)
+            ):
+                return False
+
+        device_type = input_meta_value.device.type
+        if (
+            input_meta_value.dtype == torch.bfloat16
+            or weight_meta_value.dtype == torch.bfloat16
+        ):
+            if not is_mkldnn_bf16_supported(device_type):
+                return False
+        if (
+            input_meta_value.dtype == torch.float16
+            or weight_meta_value.dtype == torch.float16
+        ):
+            if not is_mkldnn_fp16_supported(device_type):
+                return False
+        return True
+
+    _aten_conv_args = (
+        Arg(),
+        Arg(),
+        Arg(),
+        Arg(),
+        Arg(),
+        Arg(),
+        KeywordArg("is_transposed"),
+        Arg(),
+        Arg(),
+    )
+
+    _aten_mkldnn_rnn_layer_args = (
+        Arg(),  # input
+        Arg(),  # weight0
+        Arg(),  # weight1
+        Arg(),  # weight2
+        Arg(),  # weight3
+        Arg(),  # hx_
+        Arg(),  # cx_
+        KeywordArg("reverse"),  # reverse
+        Arg(),  # batch_sizes
+        Arg(),  # mode
+        Arg(),  # hidden_size
+        Arg(),  # num_layers
+        Arg(),  # has_biases
+        Arg(),  # bidirectional
+        Arg(),  # batch_first
+        Arg(),  # train
+    )
+
+    def _register_weight_pack_pass():
+        @register_freezing_graph_pattern(
+            CallFunction(aten.convolution.default, *_aten_conv_args),
+            extra_check=_is_packable_convolution,
+        )
+        def convolution(match, *args, **kwargs):
+            is_transposed = kwargs.get("is_transposed")
+            assert isinstance(is_transposed, bool)
+            graph = match.graph
+            conv_node = match.output_node()
+            device_type = conv_node.args[0].meta.get("val").device.type
+            mkldnn_device_op = _get_mkldnn_device_op(device_type)
+            input_size = conv_node.args[0].meta.get("val").shape
+            with graph.inserting_before(conv_node):
+                constant_args = [args[4], args[3], args[5], args[-1]]
+                packed_conv_op = mkldnn._convolution_pointwise.default
+                if is_transposed:
+                    constant_args.insert(1, args[-2])  # output_padding
+                    packed_conv_op = mkldnn._convolution_transpose_pointwise.default
+
+                if not has_free_symbols(input_size):
+                    packed_weight_node = mkldnn_device_op.pack_conv_weight(
+                        graph,
+                        is_transposed,
+                        args[1],
+                        constant_args,
+                        input_size,
+                    )
+                else:
+                    assert not is_transposed
+                    # For dynamic shape case, we need to pack weight in runtime.
+                    packed_weight_node = args[1]
+
+                packed_conv_inputs = (
+                    (args[0], packed_weight_node, args[2])
+                    + tuple(constant_args)
+                    + ("none", [], "")
+                )
+                packed_conv_node = graph.create_node(
+                    "call_function", packed_conv_op, tuple(packed_conv_inputs)
+                )
+                conv_node.replace_all_uses_with(packed_conv_node)
+                packed_conv_node.meta.update(conv_node.meta)
+                graph.erase_node(conv_node)
+            counters["inductor"]["mkldnn_conv_weight_pack_matcher_count"] += 1
+            counters["inductor"]["mkldnn_conv_weight_pack_matcher_nodes"] += len(
+                match.nodes
+            )
+
+        @register_freezing_graph_pattern(
+            CallFunction(aten.mkldnn_rnn_layer.default, *_aten_mkldnn_rnn_layer_args),
+            extra_check=_is_packable_mkldnn_rnn_layer,
+        )
+        def mkldnn_rnn_layer(match, *args, **kwargs):
+            def get_item(graph, node, index):
+                return graph.call_function(operator.getitem, (node, index))
+
+            graph = match.graph
+            lstm_node = match.output_node()
+            weight0, weight1 = args[1:3]
+            reverse = kwargs.get("reverse")
+            packed_lstm_op = aten.mkldnn_rnn_layer.default
+            hidden_size = args[9]
+            has_biases = args[11]
+            batch_first = args[13]
+            with graph.inserting_before(lstm_node):
+                packed_weight_op = mkldnn._reorder_mkldnn_rnn_layer_weight.default
+                packed_weight_inputs = (
+                    weight0,
+                    weight1,
+                    hidden_size,
+                    reverse,
+                    has_biases,
+                    batch_first,
+                )
+                packed_weight_node = graph.create_node(
+                    "call_function", packed_weight_op, packed_weight_inputs, {}, "name"
+                )
+                packed_weight_items = [
+                    get_item(graph, packed_weight_node, i) for i in range(2)
+                ]
+                pack_lstm_inputs = (
+                    args[0],
+                    *packed_weight_items,
+                    args[3],
+                    args[4],
+                    args[5],
+                    args[6],
+                    reverse,
+                    *args[7:],
+                )
+
+                packed_lstm_node = graph.create_node(
+                    "call_function", packed_lstm_op, args=pack_lstm_inputs
+                )
+                lstm_node.replace_all_uses_with(packed_lstm_node)
+                packed_lstm_node.meta.update(lstm_node.meta)
+                graph.erase_node(lstm_node)
+            counters["inductor"]["mkldnn_rnn_weight_pack_matcher_count"] += 1
+            counters["inductor"]["mkldnn_rnn_weight_pack_matcher_nodes"] += len(
+                match.nodes
+            )
+
+        @register_freezing_graph_pattern(
+            CallFunction(
+                aten.addmm.default,
+                Arg(),
+                Arg(),
+                Arg(),
+                beta=KeywordArg("beta"),
+                alpha=KeywordArg("alpha"),
+            ),
+            extra_check=_is_packable_linear,
+            pass_number=1,
+        )
+        @register_freezing_graph_pattern(
+            CallFunction(aten.mm.default, Arg(), Arg()),
+            extra_check=_is_packable_linear,
+            pass_number=1,
+        )
+        def linear(match, *args, **kwargs):
+            graph = match.graph
+            linear_node = match.output_node()
+            input = args[0] if linear_node.target is aten.mm.default else args[1]
+            bias = (
+                None
+                if linear_node.target is aten.mm.default
+                or (
+                    linear_node.target is aten.addmm.default
+                    and linear_node.kwargs.get("beta", 1.0) == 0.0
+                )
+                else args[0]
+            )
+            weight = args[1] if linear_node.target is aten.mm.default else args[2]
+            device_type = input.meta.get("val").device.type
+            mkldnn_device_op = _get_mkldnn_device_op(device_type)
+            with graph.inserting_before(linear_node):
+                transpose_weight_node = graph.create_node(
+                    "call_function", aten.permute.default, (weight, (1, 0))
+                )
+                weight_dtype = weight.meta.get("val").dtype
+                is_lp_weight = weight_dtype in (
+                    torch.bfloat16,
+                    torch.float16,
+                )
+                reduced_f32_matmul_enabled = (
+                    torch.backends.mkldnn.matmul.fp32_precision in ["bf16", "tf32"]  # type: ignore[attr-defined]
+                )
+                use_reduced_f32_for_fp32_weight = (
+                    reduced_f32_matmul_enabled and weight_dtype == torch.float32
+                )
+                compute_with_lp = is_lp_weight or use_reduced_f32_for_fp32_weight
+                batch_size = input.meta.get("val").shape[0]
+                packed_weight_node = mkldnn_device_op.pack_linear_weight(
+                    graph, compute_with_lp, transpose_weight_node, batch_size
+                )
+                packed_linear_node = mkldnn_device_op.pack_linear(
+                    graph, compute_with_lp, batch_size, input, packed_weight_node, bias
+                )
+
+                linear_node.replace_all_uses_with(packed_linear_node)
+                packed_linear_node.meta.update(linear_node.meta)
+                graph.erase_node(linear_node)
+            counters["inductor"]["mkldnn_linear_weight_pack_matcher_count"] += 1
+            counters["inductor"]["mkldnn_linear_weight_pack_matcher_nodes"] += len(
+                match.nodes
+            )
+
+    def _eliminate_duplicate_packed_nodes(gm):
+        """
+        Combine packed weight nodes with the same inputs to reduce memory usage.
+        for example:
+        class Model(nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.linear = nn.Linear(32, 32, bias=True)
+
+            def forward(self, x):
+                return self.linear(self.linear(x))
+
+        the above's packed weight nodes are duplicate if two linear calls have same input size.
+        """
+        if not (torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available()):
+            return gm
+
+        packed_weight_ops = [
+            torch._C._nn.mkldnn_reorder_conv2d_weight,
+            torch._C._nn.mkldnn_reorder_conv3d_weight,
+            mkldnn._reorder_convolution_transpose_weight,
+            mkldnn._reorder_linear_weight,
+            mkldnn._reorder_mkldnn_rnn_layer_weight,
+        ]
+        if torch._C.has_mkl:
+            packed_weight_ops.append(torch.ops.mkl._mkl_reorder_linear_weight)
+
+        for node in gm.graph.nodes:
+            if node.target in packed_weight_ops and len(node.args[0].users) > 1:
+                for user_node in list(node.args[0].users.keys()):
+                    if (
+                        user_node.target == node.target
+                        and user_node != node
+                        and user_node.args == node.args
+                    ):
+                        user_node.replace_all_uses_with(node)
+                        gm.graph.erase_node(user_node)
+
+    @functools.cache
+    def _mkldnn_fusion_init():
+        # TODO: aarch64: enable op fusion for acl once it supports fused operators. Disabling it for now.
+        # Otherwise even the matmul or innerproduct can not be accelerated with acl
+        if (
+            not torch.backends.mkldnn.enabled
+            or not torch.backends.mkldnn.is_available()
+        ):
+            return
+
+        if not torch.ops.mkldnn._is_mkldnn_acl_supported():
+            _register_unary_fusion()
+            _register_inplace_fusion()
+            _register_binary_unary_fusion()
+            _register_binary_fusion()
+            _register_quantization_lowerings()
+
+        _register_woq_lowerings()
+
+    @functools.cache
+    def _mkldnn_weight_pack_init():
+        if torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available():
+            _register_weight_pack_pass()
+            _recover_linear()
+            _register_quantization_weight_pack_pass()
+            _register_int8_woq_concat_linear_pattern()
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/node_runtime_estimation.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/node_runtime_estimation.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e3e3ebf084ad7b785450a453c483ad4ae01895b
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/node_runtime_estimation.py
@@ -0,0 +1,325 @@
+"""
+Collective runtime estimation using CUDA events and power-of-2 rounding.
+"""
+
+from __future__ import annotations
+
+import itertools
+from functools import lru_cache
+from typing import Any, Optional
+
+import torch
+import torch.fx as fx
+from torch._inductor.utils import clear_on_fresh_cache, tabulate_2d
+from torch._logging import getArtifactLogger, trace_structured
+from torch.fx.operator_schemas import normalize_function
+
+
+# Setup logger for artifact logging
+log = getArtifactLogger(__name__, "node_runtime_estimation")
+
+
+# TODO: Consider using a distributed-aware cache or rank-local disk cache
+# not using local cache because different ranks might write to it concurrently.
+# solvable in future, potentially with workflow to seed cache
+@clear_on_fresh_cache
+@lru_cache
+def _get_collective_cache() -> dict[str, float]:
+    """Get process-local cache for collective benchmarks."""
+    return {}
+
+
+def get_cached_runtime(key: str) -> Optional[float]:
+    """Get cached runtime from process-local cache."""
+    return _get_collective_cache().get(key)
+
+
+def set_cached_runtime(key: str, value: float) -> None:
+    """Set cached runtime in process-local cache."""
+    _get_collective_cache()[key] = value
+
+
+def get_hint(x: int | torch.SymInt) -> Optional[int]:
+    if isinstance(x, int):
+        return x
+    assert isinstance(x, torch.SymInt)
+    return x.node.hint if x.node.has_hint() else None
+
+
+def can_benchmark_collective() -> bool:
+    """Check if we can benchmark collectives (not fake process group)."""
+    import torch.distributed as c10d
+
+    if not c10d.is_initialized():
+        return False
+
+    pg = c10d.distributed_c10d._get_default_group()
+    if torch.distributed.distributed_c10d.get_backend(pg) == "fake":
+        return False
+
+    return True
+
+
+def _median(lst):
+    assert len(lst) > 0
+    return torch.median(torch.tensor(lst)).item()
+
+
+def _benchmark_collective_with_cuda_events_impl(
+    n: torch.fx.Node,
+    args: tuple[Any, ...],
+    kwargs: dict[str, Any],
+    nruns: int,
+) -> float | None:
+    """
+    Core benchmarking logic using CUDA events and barriers.
+    Returns runtime in ms or None on failure.
+    """
+    from torch._dynamo.testing import rand_strided
+
+    # Convert FakeTensors to real tensors before benchmarking
+    def to_real(t: torch.Tensor) -> torch.Tensor:
+        shape = [get_hint(dim) for dim in t.shape]
+        stride = [get_hint(s) for s in t.stride()]
+
+        if any(s is None for s in itertools.chain(shape, stride)):
+            # This should not happen, as can_benhcmark_collective checks for unbacked
+            raise ValueError("Cannot convert tensor with symbolic dimensions")
+
+        return rand_strided(shape, stride, device=t.device, dtype=t.dtype)  # type: ignore[arg-type]
+
+    args, kwargs = torch.utils._pytree.tree_map_only(
+        torch.Tensor,
+        to_real,
+        (args, kwargs),
+    )
+
+    # Warmup: call collective once and wait
+    torch.cuda.synchronize()
+    result = n.target(*args, **kwargs)  # type: ignore[operator]
+    torch.ops._c10d_functional.wait_tensor(result)
+    torch.cuda.synchronize()
+
+    # Benchmark with CUDA events
+    comm_times = []
+    for _ in range(nruns):
+        start_evt = torch.cuda.Event(enable_timing=True)
+        end_evt = torch.cuda.Event(enable_timing=True)
+
+        start_evt.record()
+        result = n.target(*args, **kwargs)  # type: ignore[operator]
+        torch.ops._c10d_functional.wait_tensor(result)
+        end_evt.record()
+        end_evt.synchronize()
+
+        comm_times.append(start_evt.elapsed_time(end_evt))
+
+    return _median(comm_times)
+
+
+def benchmark_collective_with_cuda_events(
+    n: torch.fx.Node,
+    nruns: int = 2,
+) -> tuple[float | None, str]:
+    """
+    Benchmark collective with CUDA events. Returns (runtime_ms, cache_key) or (None, "") on failure.
+    """
+    # context manager not allowed with profiler.
+    with torch.utils._python_dispatch._disable_current_modes():
+        return benchmark_collective_with_cuda_events_impl(n, nruns)
+
+
+def benchmark_collective_with_cuda_events_impl(
+    n: torch.fx.Node,
+    nruns: int = 3,
+) -> tuple[float | None, str]:
+    """
+    Benchmark collective with CUDA events. Returns (runtime_ms, cache_key) or (None, "") on failure.
+    """
+    from torch._inductor import fx_utils
+    from torch.distributed.distributed_c10d import _get_group_size_by_name
+
+    # Early check: can we actually run collectives?
+    if not can_benchmark_collective():
+        return None, ""
+
+    success, args, kwargs = fx_utils.get_fake_args_kwargs(n)
+
+    opt_args_kwargs = normalize_function(
+        n.target,  # type: ignore[arg-type]
+        args=n.args,
+        kwargs=n.kwargs,
+        normalize_to_only_use_kwargs=True,
+    )
+    assert opt_args_kwargs is not None
+    group_name = opt_args_kwargs[1]["group_name"]
+    group_size = _get_group_size_by_name(group_name)
+
+    if not success:
+        return None, ""
+
+    # Extract actual input size in BYTES (first tensor argument)
+    actual_bytes: Optional[int] = None
+
+    def extract_tensor_info(t: torch.Tensor) -> torch.Tensor:
+        nonlocal actual_bytes
+        if actual_bytes is None:
+            shape = [get_hint(dim) for dim in t.shape]
+            if any(s is None for s in shape):
+                return t
+
+            total_elems = 1
+            for dim in shape:
+                assert dim is not None
+                total_elems *= dim
+
+            actual_bytes = total_elems * t.dtype.itemsize
+        else:
+            raise RuntimeError(f"should only be one input tensor to collective {n}")
+        return t
+
+    torch.utils._pytree.tree_map_only(torch.Tensor, extract_tensor_info, (args, kwargs))
+
+    if actual_bytes is None:
+        return None, ""
+
+    # Cache key by BYTES (dtype-agnostic)
+    key = f"{n.target}: ({group_size} group size, {actual_bytes} bytes)"
+
+    # Check cache
+    if (cached := get_cached_runtime(key)) is not None:
+        return cached, key
+
+    # Benchmark using CUDA events with actual args/kwargs
+    runtime = _benchmark_collective_with_cuda_events_impl(n, args, kwargs, nruns)
+
+    if runtime is None:
+        return None, key
+
+    # Cache the result
+    set_cached_runtime(key, runtime)
+    return runtime, key
+
+
+def _log_compute_estimations(
+    compute_nodes: list[fx.Node],
+    benchmarked_estimations: list[float],
+    analytical_estimations: list[float],
+) -> None:
+    """Log compute node runtime estimations comparing benchmarked vs analytical."""
+    import torch.utils._pytree as pytree
+    from torch._inductor.fx_utils import count_flops_fx
+    from torch.utils._dtype_abbrs import dtype_abbrs
+
+    def _node_summary(n: fx.Node) -> str:
+        ret = str(n)
+        for arg in pytree.arg_tree_leaves(n.args, n.kwargs):
+            if not isinstance(arg, torch.fx.Node):
+                continue
+            if "val" in arg.meta:
+                t = arg.meta["val"]
+                ret += f" {dtype_abbrs[t.dtype]}{tuple(t.shape)}"
+        return ret
+
+    headers = [
+        "Node",
+        "Benchmarked Est(us)",
+        "Analytical Est(us)",
+        "Diff(%)",
+        "Diff(us)",
+        "Flops",
+    ]
+
+    rows = [
+        [
+            _node_summary(node)[:120],
+            est_b * 1e3,
+            est_a * 1e3,
+            (est_a / est_b) if est_b > 0 else 0,
+            (est_a - est_b) * 1e3,
+            count_flops_fx(node),
+        ]
+        for node, est_b, est_a in zip(
+            compute_nodes, benchmarked_estimations, analytical_estimations
+        )
+    ]
+
+    log_str = tabulate_2d(rows, headers)
+
+    trace_structured(
+        "artifact",
+        metadata_fn=lambda: {
+            "name": "fx_compute_nodes_runtime_estimation",
+            "encoding": "string",
+        },
+        payload_fn=lambda: log_str,
+    )
+
+
+def _log_collective_benchmarks(
+    collective_nodes: list[fx.Node],
+    collective_keys: list[str],
+    benchmarked_medians: list[float],
+    world_size: int,
+) -> None:
+    """Log collective benchmarks with analytical comparisons for tlparse."""
+    headers = [
+        "Collective Key",
+        "Benchmarked(ms)",
+        "NCCL Est(ms)",
+        "Inductor Est(ms)",
+        "NCCL Diff(%)",
+        "Inductor Diff(%)",
+    ]
+
+    rows = []
+    collective_benchmarks = {}
+    for key, benchmarked_ms, coll_node in zip(
+        collective_keys, benchmarked_medians, collective_nodes
+    ):
+        # NCCL estimator (deterministic, no need to align)
+        nccl_ms = (
+            torch._inductor.comm_analysis.estimate_nccl_collective_runtime_from_fx_node(
+                coll_node, None, use_nccl_estimator=True
+            )
+        )
+
+        # Inductor analytical (deterministic, no need to align)
+        inductor_ms = (
+            torch._inductor.comm_analysis.estimate_nccl_collective_runtime_from_fx_node(
+                coll_node, None, use_nccl_estimator=False
+            )
+        )
+
+        collective_benchmarks[key] = {
+            "benchmarked_ms": benchmarked_ms,
+            "analytical_nccl_ms": nccl_ms,
+            "analytical_inductor_ms": inductor_ms,
+        }
+
+        # Compute percentage differences
+        nccl_diff_pct = (nccl_ms / benchmarked_ms) if benchmarked_ms > 0 else 0
+        inductor_diff_pct = (inductor_ms / benchmarked_ms) if benchmarked_ms > 0 else 0
+
+        rows.append(
+            [
+                key[:80],
+                f"{benchmarked_ms:.4f}",
+                f"{nccl_ms:.4f}",
+                f"{inductor_ms:.4f}",
+                f"{nccl_diff_pct:.2f}",
+                f"{inductor_diff_pct:.2f}",
+            ]
+        )
+
+    log_str = f"World size: {world_size}\n"
+    log_str += tabulate_2d(rows, headers)
+
+    trace_structured(
+        "artifact",
+        metadata_fn=lambda: {
+            "name": "fx_collectives_node_runtime_estimation",
+            "encoding": "string",
+        },
+        payload_fn=lambda: log_str,
+    )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/numeric_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/numeric_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1db82f21f7ec6a37e1f260b02d2fcd77622c058
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/numeric_utils.py
@@ -0,0 +1,213 @@
+# mypy: allow-untyped-defs
+import gc
+import logging
+import os
+import random
+import traceback
+
+import numpy
+
+import torch
+import torch.optim as optim
+from torch.utils._ordered_set import OrderedSet
+
+from .. import config
+
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+MAIN_RANDOM_SEED = 1337
+
+# Set the CUBLAS_WORKSPACE_CONFIG environment variable
+os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+
+
+# If the two forward functions involve any non-deterministic operations,
+# such as certain types of parallelism or asynchronous execution,
+# this can also lead to different outputs.
+def set_deterministic() -> None:
+    """Make torch manual seed deterministic."""
+
+    torch.manual_seed(MAIN_RANDOM_SEED)
+    random.seed(MAIN_RANDOM_SEED)
+    numpy.random.seed(MAIN_RANDOM_SEED)
+    torch.use_deterministic_algorithms(True)
+
+
+def clean_memory() -> None:
+    """Clean memory to avoid OOM."""
+    gc.collect()
+    torch.cuda.empty_cache()
+
+
+# We compare the numerical results before and after pre/post grad fx passes
+# transformation to make sure the numerical results are the same.
+def compare_dict_tensors(dict_base, dict_control, precision):
+    if len(OrderedSet(dict_base.keys())) != len(OrderedSet(dict_control.keys())):
+        logger.warning("Mismatch keys found before and after pre/post grad fx passes.")
+        logger.debug("keys before pre/post grad fx passes %s", dict_base.keys())
+        logger.debug("keys after pre/post grad fx passes %s", dict_control.keys())
+        return False
+    is_allclose = True
+    for key in dict_base:
+        if key not in dict_control:
+            logger.warning(
+                "Mismatch parameter name %s does not exist after pre/post grad fx passes",
+                key,
+            )
+        # Some parameters have `None`, and not every param has a valid .grad field, we skip them
+        if dict_base[key] is None or dict_control[key] is None:
+            continue
+        if not torch.allclose(
+            dict_base[key],
+            dict_control[key],
+            rtol=precision,
+            atol=precision,
+            equal_nan=True,
+        ):
+            logger.warning(
+                "Mismatch parameter values found before and after pre/post grad fx passes."
+            )
+            logger.debug("value before pre/post grad fx passes %s", dict_base[key])
+            logger.debug("value after pre/post grad fx passes %s", dict_control[key])
+            is_allclose = False
+    return is_allclose
+
+
+def compare_tuple_tensors(tuple_base, tuple_control, precision):
+    if len(tuple_base) != len(tuple_control):
+        logger.warning(
+            "Mismatch fw output length. before transformation: %s, after transformation: %s",
+            len(tuple_base),
+            len(tuple_control),
+        )
+        return False
+    is_allclose = True
+    for i in range(len(tuple_base)):
+        # Some parameters have `None`, we skip them
+        if tuple_base[i] is None or tuple_control[i] is None:
+            continue
+        if not torch.allclose(
+            tuple_base[i],
+            tuple_control[i],
+            rtol=precision,
+            atol=precision,
+            equal_nan=True,
+        ):
+            logger.debug(
+                "forward output before pre/post grad fx passes %s", tuple_base[i]
+            )
+            logger.debug(
+                "forward output after pre/post grad fx passes %s", tuple_control[i]
+            )
+            is_allclose = False
+    return is_allclose
+
+
+def compare_parameters(model_base, model_control, precision):
+    return compare_dict_tensors(
+        dict(model_base.named_parameters()),
+        dict(model_control.named_parameters()),
+        precision,
+    )
+
+
+def compare_forward_output(pred_base, pred_control, precision):
+    return compare_tuple_tensors(
+        pred_base,
+        pred_control,
+        precision,
+    )
+
+
+def compare_gradients(model_base, model_control, precision):
+    grad_base = {key: param.grad for key, param in model_base.named_parameters()}
+    grad_pt2 = {key: param.grad for key, param in model_control.named_parameters()}
+    return compare_dict_tensors(
+        grad_base,
+        grad_pt2,
+        precision,
+    )
+
+
+def run_model(
+    model_base, model_control, model_input, num_iterations=10, precision=1e-4
+):
+    clean_memory()
+    for i in range(num_iterations):
+        logger.info("start %s iteration", i)
+        set_deterministic()
+        pred_base = model_base(*model_input)
+        set_deterministic()
+        pred_control = model_control(*model_input)
+
+        res = compare_parameters(model_base, model_control, precision)
+        logger.info("compare parameters. Numerical result : %s", res)
+
+        res = compare_forward_output(pred_base, pred_control, precision)
+        logger.info("compare loss/predict. Numerical result : %s", res)
+        # tensor may not have a grad_fn
+        try:
+            _ = pred_base[0].sum().backward(retain_graph=True)
+            _ = pred_control[0].sum().backward(retain_graph=True)
+            res = compare_gradients(model_base, model_control, precision)
+            logger.info("compare param grad. Numerical result : %s", res)
+        except Exception:
+            logger.exception("Exception when comparing gradients")
+            traceback.print_exc()
+
+        if config.fx_passes_numeric_check["requires_optimizer"]:
+            try:
+                optimizer_base = optim.SGD(
+                    [param for name, param in model_base.named_parameters()], lr=0.01
+                )
+                optimizer_base.step()
+
+                optimizer_control = optim.SGD(
+                    [param for name, param in model_control.named_parameters()], lr=0.01
+                )
+                optimizer_control.step()
+
+                res = compare_parameters(model_base, model_control, precision)
+                logger.info(
+                    "compare parameters with optimizer added. Numerical result : %s",
+                    res,
+                )
+            except Exception:
+                logger.exception(
+                    "Exception when optimizer is added to check parameter names"
+                )
+                traceback.print_exc()
+        else:
+            logger.warning(
+                "no parameter with optimizer to compare with length %s before transformation"
+                " and the length %s after transformation",
+                len(dict(model_base.named_parameters())),
+                len(dict(model_control.named_parameters())),
+            )
+
+
+def numeric_check_if_enabled(
+    gm_before_fx_passes,
+    gm_after_fx_passes,
+    example_inputs,
+    num_iterations,
+    precision,
+):
+    # need to topo-sort graphmodule before we run the model,
+    # otherwise it may fail as refer before def
+    # fail silently in order not to block the model run
+    try:
+        with torch.autograd.set_detect_anomaly(True):
+            run_model(
+                gm_before_fx_passes,
+                gm_after_fx_passes,
+                example_inputs,
+                num_iterations=num_iterations,
+                precision=precision,
+            )
+    except Exception as e:
+        logger.warning(  # noqa: G200
+            "Runtime numeric check failed in pre grad fx passes with error: %s", e
+        )
+        traceback.print_exc()
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/overlap_manual_scheduling.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/overlap_manual_scheduling.py
new file mode 100644
index 0000000000000000000000000000000000000000..540e73166ba45be7d9fd6eb12e627f795bae94dc
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/overlap_manual_scheduling.py
@@ -0,0 +1,374 @@
+from __future__ import annotations
+
+import heapq
+from collections import Counter, defaultdict
+from typing import Any, Optional, TYPE_CHECKING
+
+import torch
+import torch.fx as fx
+from torch._dynamo.graph_deduplication import _stable_topological_sort
+from torch._inductor.fx_passes.bucketing import (
+    _schedulable_wait_node,
+    is_all_gather_into_tensor as is_all_gather,
+    is_reduce_scatter_tensor as is_reduce_scatter,
+    merge_all_gather_bucket,
+    merge_reduce_scatter_bucket,
+)
+from torch._inductor.fx_passes.overlap_preserving_bucketer import (
+    bucket_key,
+    OverlapPreservingBucketer,
+)
+from torch._inductor.fx_passes.overlap_scheduling import (
+    CollectiveInfo,
+    is_compute_node,
+    OverlapScheduler,
+)
+from torch.utils._ordered_set import OrderedSet
+
+from .graph_view import get_subgraph_by_path, GraphView, make_graph_view
+
+
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
+
+class ManualOverlapPreservingBucketer(OverlapPreservingBucketer):
+    """
+    Buckets collective operations based on user specifications.
+    The actual bucket happens in bucket_collectives, where all-gathers/reduce-scatters in
+        `nodes` will be buckted one single all-gather/reduce-scatter.
+    """
+
+    def __init__(
+        self,
+        node_users: dict[fx.Node, OrderedSet[fx.Node]],
+        *args: Any,
+        **kwargs: Any,
+    ):
+        super().__init__(*args, **kwargs)
+        self.node_users = node_users
+        self.wait_to_node_map: dict[fx.Node, fx.Node] = defaultdict()
+
+    def _check_recursive_dep(
+        self,
+        node: fx.Node,
+        target_op: str,
+        dep_dict: dict[torch.fx.Node, OrderedSet[torch.fx.Node]],
+    ) -> bool:
+        """
+        Check if the node is directly used for fetch parameters/gradients
+
+        TODO (ruisizhang123): currently, we assume the node only pre-fetch/update one parameter/gradient
+            We should handle multiple parameters/gradients update case by checking if there are non closure
+            computes along the path from primal/output to coll_node
+        """
+        deps: OrderedSet[fx.Node] = dep_dict[node]
+        seen_target_op = 0
+        for d in deps:
+            if d.op == target_op:
+                seen_target_op += 1
+
+        return seen_target_op == 1
+
+    def _bucket_group(self, coll_nodes: list[fx.Node]) -> None:
+        assert len(coll_nodes) > 0, "bucketed coll_nodes should have nonzero node"
+
+        waits = [self.collective_info[n].wait_node for n in coll_nodes]
+        # Use earliest wait insertion point
+        first_wait = min(waits, key=lambda w: self.node_idx[w])
+        # Find insertion location
+        first = coll_nodes[0]
+        next_node = first
+        while next_node in coll_nodes:
+            next_node = next_node.next
+
+        if is_all_gather(first):
+            new_nodes, replacements = merge_all_gather_bucket(
+                self.graph,
+                coll_nodes,
+                wait_insertion_point=first_wait,
+                insert_before=next_node,
+                mode="custom_ops",
+            )
+        elif is_reduce_scatter(first):
+            new_nodes, replacements = merge_reduce_scatter_bucket(
+                self.graph,
+                coll_nodes,
+                wait_insertion_point=first_wait,
+                insert_before=next_node,
+                mode="custom_ops",
+            )
+        else:
+            raise ValueError(
+                "bucket non all_gather/reduce_scatter node is not supported"
+            )
+
+        # Identify the new wait and start
+        new_waits = [n for n in new_nodes if _schedulable_wait_node(n)]
+        assert len(new_waits) == 1, f"Expected exactly one new wait, got {new_waits}"
+        new_wait = new_waits[0]
+        new_start = new_wait.args[0]
+        assert isinstance(new_start, fx.Node)
+
+        # Set manual bucketing-specific metadata
+        # Note: Generic metadata (nn_module_stack, fwd_nn_module_stack, custom, stack_trace)
+        # is now preserved automatically by the bucketing functions in bucketing.py
+        node_type = (
+            "bucketed_all_gather" if is_all_gather(first) else "bucketed_reduce_scatter"
+        )
+        for n in new_nodes:
+            if n == new_wait:
+                node_type = node_type + "_wait"
+            n.meta["manual_bucket_node_type"] = node_type
+            if "wait" in node_type:
+                self.wait_to_node_map[n] = new_wait
+
+    def manual_bucket_collectives(self, nodes: list[fx.Node]) -> None:
+        """
+        Bucket all all-gather/reduce-scatter nodes from nodes into one all-gather/reduce-scatter.
+        """
+        # Filter out valid collectives
+        collectives = [n for n in nodes if n in self.collective_info]
+        if collectives == []:
+            return
+        grouped_collectives: dict[object, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
+        for node in collectives:
+            key = bucket_key(node)
+            if not (is_all_gather(node) or is_reduce_scatter(node)):
+                continue
+            # We only want to bucket all-gather/reduce-scatter that
+            # 1. all_gather that have ancestors dependent only on input placeholder(parameters)
+            # 2. reduce scatter that the wait user node is returned as output(gradients)
+            if is_all_gather(node) and not self._check_recursive_dep(
+                node, "placeholder", self.node_ancestors
+            ):
+                continue
+            if is_reduce_scatter(node) and not self._check_recursive_dep(
+                self.collective_info[node].wait_node, "output", self.node_users
+            ):
+                continue
+            if key is not None:
+                grouped_collectives[key].add(node)
+
+        for key, nodes in grouped_collectives.items():  # type: ignore[arg-type]
+            self._bucket_group(list(nodes))
+
+
+class ManualOverlapScheduler(OverlapScheduler):
+    """
+    Scheduler that manual buckets and reorders collective nodes based on module_bucket_plans
+    """
+
+    def __init__(
+        self,
+        gm: fx.GraphModule,
+        module_bucket_plans: list[list[str] | str],
+        insert_overlap_deps: bool,
+        module_stack_fn: None | Callable[[fx.Node], list[tuple[str, type[Any]]]] = None,
+    ):
+        super().__init__(
+            gm,
+            max_in_flight_gb=0.0,
+            max_compute_pre_fetch=0,
+            collective_bucketing=True,
+            insert_overlap_deps=insert_overlap_deps,
+            compute_overlap_multipler=0.0,
+            max_coll_distance=0,
+            custom_runtime_estimation=None,
+            collective_estimator="analytical",
+            max_memory_increase_gb=None,
+            max_memory_increase_ratio=None,
+        )
+        self.module_bucket_plans = module_bucket_plans
+        self.nodes_in_subgraph: list[list[fx.Node]] = []
+
+        self.node_users: dict[fx.Node, OrderedSet[fx.Node]] = self._collect_node_users()
+        self.bucketer = ManualOverlapPreservingBucketer(
+            graph=self.graph,
+            collective_info=self.collective_info,
+            node_users=self.node_users,
+            scheduled=OrderedSet(self.graph.nodes),
+        )
+        self.insert_overlap_deps = insert_overlap_deps
+
+        self.module_stack_fn = module_stack_fn
+
+    def _identify_collectives(self) -> None:
+        """Identify all collective operations."""
+        for node in self.nodes:
+            if _schedulable_wait_node(node):
+                start = node.args[0]
+                info = CollectiveInfo(
+                    start_node=start,
+                    wait_node=node,
+                    size_bytes=0,
+                    estimated_time_ms=0,
+                    exposed_time_ms=0,
+                )
+                self.collective_info[start] = info
+                self.wait_to_start[node] = start
+                self.unscheduled_collectives.add(start)
+
+    def run(self) -> torch.fx.GraphModule:
+        """Entry point to run the manual bucket algorithm"""
+        # Bucket collectives in each bucket_module
+        self._manual_bucket_collectives()
+
+        # Reorder collectives with last/next bucket_module
+        self._manual_reorder_graph()
+
+        return self.gm
+
+    def _manual_reorder_graph(self) -> None:
+        """
+        Reorder nodes in the FX graph to enforce manual overlap dependencies.
+
+        Enforce:
+        - all_gather_start_i depends on all_gather_wait_(i-1)
+        - reduce_scatter_wait_i must happen before reduce_scatter_start_(i+1)
+        """
+        delayed_rs_nodes: list[fx.Node] = []
+        overlap_deps: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
+
+        # schedule reduce scatter normally in self._schedule
+        while self.ready:
+            _, node = heapq.heappop(self.ready)
+            node_type = node.meta.get("manual_bucket_node_type", "")
+
+            if node in self.scheduled:
+                continue
+
+            if node_type == "bucketed_reduce_scatter":
+                # Ensure all delayed waits execute before this reduce_scatter
+                for delayed in delayed_rs_nodes:
+                    self._schedule(delayed)
+                    overlap_deps[delayed].add(node)
+                delayed_rs_nodes.clear()
+
+            elif node_type == "bucketed_reduce_scatter_wait":
+                # Defer until next reduce_scatter
+                delayed_rs_nodes.append(node)
+                continue
+            self._schedule(node)
+
+        for delayed in delayed_rs_nodes:
+            self._schedule(delayed)
+
+        self.scheduled = OrderedSet(reversed(list(self.scheduled)))
+        picked_ag: list[fx.Node] = []
+        last_compute: Optional[fx.Node] = None
+
+        for node in self.scheduled:
+            node_type = node.meta.get("manual_bucket_node_type", "")
+            if node_type == "bucketed_all_gather":
+                picked_ag.append(node)
+                continue
+
+            if node_type == "bucketed_all_gather_wait":
+                # Connect corresponding all_gather_wait -> all_gather edges
+                if picked_ag:
+                    for ag in picked_ag:
+                        overlap_deps[self.bucketer.wait_to_node_map[node]].add(ag)
+                picked_ag.clear()
+            if is_compute_node(node):
+                last_compute = node
+
+        if last_compute is not None and not bool(
+            OrderedSet(picked_ag) & OrderedSet(self.node_ancestors[last_compute])
+        ):
+            for ag in picked_ag:
+                overlap_deps[last_compute].add(ag)
+
+        _stable_topological_sort(self.graph, overlap_deps)
+        self.graph.lint()
+
+        if self.insert_overlap_deps:
+            from torch._inductor.fx_passes.control_dependencies import (
+                preserve_node_ordering,
+            )
+
+            preserve_node_ordering(self.graph, overlap_deps)
+
+    def _manual_bucket_collectives(self) -> None:
+        """Bucket nodes in each module_bucket from module_bucket_plans."""
+        self._obtain_nodes_in_subgraph()
+        for i, nodes in enumerate(self.nodes_in_subgraph):
+            self.bucketer.manual_bucket_collectives(nodes=nodes)
+
+        _stable_topological_sort(self.graph, {})
+        self.graph.lint()
+        self.nodes = list(self.graph.nodes)
+        self.in_degree = Counter(user for node in self.nodes for user in node.users)
+
+    def _collect_node_users(self) -> dict[fx.Node, OrderedSet[fx.Node]]:
+        """Collect all users for each node."""
+        node_users: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
+        for node in self.nodes:
+            for output_node in list(node.users.keys()):
+                node_users[node].add(output_node)
+                node_users[node] |= node_users[output_node]
+        return node_users
+
+    def _schedule(self, node: fx.Node) -> None:
+        """Schedule a node."""
+        assert node not in self.scheduled
+        assert all(n in self.scheduled for n in node.all_input_nodes)
+        self.scheduled.add(node)
+        for user in node.users:
+            self.in_degree[user] -= 1
+            if self.in_degree[user] == 0:
+                heapq.heappush(self.ready, ((), user))
+
+    def _obtain_nodes_in_subgraph(self) -> None:
+        """
+        Obtain nodes in each subgraph from module_bucket_plans
+        """
+        graph_view: GraphView | None = make_graph_view(self.graph, self.module_stack_fn)
+        if graph_view is None:
+            return
+
+        for module in self.module_bucket_plans:
+            subgraph_view = get_subgraph_by_path(graph_view, module)
+            self.nodes_in_subgraph.append(subgraph_view)
+
+        all_subgraph_nodes = [
+            node for sublist in self.nodes_in_subgraph for node in sublist
+        ]
+        unique_subgraph_nodes = list(OrderedSet(all_subgraph_nodes))
+        assert len(all_subgraph_nodes) <= len(unique_subgraph_nodes), (
+            f"Overlapping FX nodes detected across subgraphs in `module_bucket_plans`. "
+            f"Expected disjoint node sets but found "
+            f"{len(all_subgraph_nodes) - len(unique_subgraph_nodes)} duplicated node(s)."
+        )
+
+
+def manual_overlap_bucketing(
+    gm: torch.fx.GraphModule,
+    module_bucket_plans: list[list[str] | str],
+    insert_overlap_deps: bool = False,
+    module_stack_fn: None | Callable[[fx.Node], list[tuple[str, type[Any]]]] = None,
+) -> torch.fx.GraphModule:
+    """Schedule nodes based on user specifications in module_bucket_plans
+    The manual overlapping consists of two steps:
+    Step 1: bucket all-gather/reduce-scatter in each module in module_bucket_plans
+    Step 2: reorder all-gather to overlap with last module_bucket &
+        reorder reduce-scatter to overlap with next module_bucket
+    TODO(ruisizhang123): allow users to explicitly specify which
+        module_bucket they want to overlap.
+
+    Args:
+        gm: input graph module to optimize.
+        module_bucket_plans: user specified FQNs
+        module_stack_fn: Optional callable for extracting module hierarchy from nodes.
+            Used to construct a GraphView for identifying nodes in module_bucket_plans.
+            The module_class component of the returned tuples is not used by this pass.
+
+            See the `module_stack_fn` parameter in `make_graph_view` (graph_view.py) for
+            detailed documentation on signature, return format, and usage examples.
+    """
+    # decode abbreviated FQNs to actual FQNs
+    overlapped_gm = ManualOverlapScheduler(
+        gm, module_bucket_plans, insert_overlap_deps, module_stack_fn
+    ).run()
+    overlapped_gm.recompile()
+    return overlapped_gm
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/overlap_preserving_bucketer.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/overlap_preserving_bucketer.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c819f37a1a83ecff13c4b18ceb2753b61087c29
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/overlap_preserving_bucketer.py
@@ -0,0 +1,912 @@
+import itertools
+import logging
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Any, Literal, Optional
+
+import torch
+import torch.fx as fx
+from torch._dynamo.utils import counters
+from torch._inductor.augmented_graph_helper import AugmentedGraphHelper
+from torch._inductor.fx_passes.bucketing import (
+    _schedulable_wait_node,
+    bucket_key,
+    BucketMode,
+    has_mergeable_all_gather_convert_dtype,
+    is_all_gather_into_tensor as is_all_gather,
+    is_reduce_scatter_tensor as is_reduce_scatter,
+)
+from torch._inductor.fx_passes.overlap_scheduling import (
+    CollBucket,
+    CollectiveInfo,
+    get_group_name,
+    is_compute_node,
+)
+from torch.utils._ordered_set import OrderedSet
+
+
+bucket_log = logging.getLogger(__name__)
+
+
+@dataclass
+class WhyNoBucket:
+    name1: str
+    name2: str
+    reason: str
+    args: tuple[Any, ...]
+
+    def __init__(self, node1: fx.Node, node2: fx.Node) -> None:
+        self.name1 = node1.name
+        self.name2 = node2.name
+        self.reason = ""
+        self.args = ()
+
+    def __call__(self, reason: str, *args: Any) -> None:
+        if bucket_log.isEnabledFor(logging.DEBUG):
+            bucket_log.debug(
+                "cannot bucket %s with %s: " + reason,  # noqa: G003
+                self.name1,
+                self.name2,
+                *args,
+            )
+
+
+def is_collective_or_wait(n: fx.Node) -> bool:
+    """Check if node is a collective start or wait."""
+    if _schedulable_wait_node(n):
+        return True
+    # Collective starts have exactly one use: the wait_tensor
+    if len(n.users) == 1:
+        user = next(iter(n.users.keys()))
+        if _schedulable_wait_node(user):
+            return True
+    return False
+
+
+@dataclass
+class PGEvent:
+    """
+    Represents an important event in a process group timeline. Either
+    a collective start, wait, or hiding compute. Each node is linked
+    to its prev and next and these dependencies are reflected
+    in the augmented graph.
+
+    We want to enforce a sequential ordering of collective starts and waits
+    because NCCL collectives on the same process group execute on the same CUDA
+    stream, creating implicit dependencies between all operations on that PG.
+
+    A wait of a particular collective will implicitly force realization of all collectives
+    enqueued prior to that collective.
+    """
+
+    node: fx.Node
+    event_type: Literal["compute", "starts", "waits"]
+    position: int
+    prev: Optional["PGEvent"] = None
+    next: Optional["PGEvent"] = None
+
+    @property
+    def is_start(self) -> bool:
+        return self.event_type == "starts"
+
+    @property
+    def is_wait(self) -> bool:
+        return self.event_type == "waits"
+
+    @property
+    def is_compute(self) -> bool:
+        return self.event_type == "compute"
+
+    def unlink(self) -> tuple[Optional["PGEvent"], Optional["PGEvent"]]:
+        """Remove this event from the linked list, return (prev, next)."""
+        prev_event, next_event = self.prev, self.next
+        if self.prev:
+            self.prev.next = self.next
+        if self.next:
+            self.next.prev = self.prev
+        self.prev = None
+        self.next = None
+        return prev_event, next_event
+
+    def insert_between(
+        self, prev_event: Optional["PGEvent"], next_event: Optional["PGEvent"]
+    ) -> None:
+        """Insert this event between prev_event and next_event in the linked list."""
+        if prev_event:
+            prev_event.next = self
+        self.prev = prev_event
+
+        if next_event:
+            next_event.prev = self
+        self.next = next_event
+
+
+class OverlapPreservingBucketer:
+    """
+    Buckets collective operations while preserving compute-collective overlap relationships.
+    Uses an augmented graph to track dependencies between compute and collective operations.
+    """
+
+    def __init__(
+        self,
+        graph: fx.Graph,
+        collective_info: dict[fx.Node, CollectiveInfo],
+        scheduled: OrderedSet[fx.Node],
+        max_bucket_memory_gb: float = 1.0,
+        max_coll_distance: int = 1000,
+        insert_overlap_deps: bool = False,
+        bucket_mode: BucketMode = "custom_ops_multidtype",
+    ):
+        self.graph = graph
+        self.collective_info = collective_info
+        self.scheduled = scheduled
+        self.max_bucket_memory_gb = max_bucket_memory_gb
+        self.node_idx = {n: i for i, n in enumerate(scheduled)}
+        self.max_coll_distance = max_coll_distance
+        self.insert_overlap_deps = insert_overlap_deps
+        self.bucket_mode = bucket_mode
+        self.node_to_event: dict[fx.Node, PGEvent] = {}
+        self.all_hiding_nodes: OrderedSet[fx.Node] = OrderedSet()
+
+        # Compute ancestors including original graph edges and hiding interval dependencies
+        self.node_ancestors = self._compute_node_ancestors()
+        self.aug_graph = AugmentedGraphHelper(self.graph, self.node_ancestors)
+
+        # Build timelines and add constraints to aug_graph
+        self.pg_to_timeline_head: dict[str, Optional[PGEvent]] = self.build_timelines()
+        self._add_hiding_interval_constraints()
+
+    def _compute_node_ancestors(self) -> dict[fx.Node, OrderedSet[fx.Node]]:
+        """
+        Compute ancestor sets for all nodes including:
+        1. Original graph edges
+        2. Hiding interval deps: collective_start -> hiding_node -> wait
+        """
+        augmented_inputs: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
+        for start, info in self.collective_info.items():
+            if info.is_exposed:
+                continue
+            for hiding_node in info.hiding_nodes:
+                augmented_inputs[hiding_node].add(start)
+                augmented_inputs[info.wait_node].add(hiding_node)
+
+        node_ancestors: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
+        for node in self.scheduled:
+            for input_node in itertools.chain(
+                augmented_inputs[node], node.all_input_nodes
+            ):
+                node_ancestors[node].add(input_node)
+                node_ancestors[node] |= node_ancestors[input_node]
+
+        return node_ancestors
+
+    def build_timelines(self) -> dict[str, Optional[PGEvent]]:
+        "Construct each process groups ordered series of event"
+        all_pgs: OrderedSet[str] = OrderedSet()
+        for start in self.collective_info:
+            pg = get_group_name(start)
+            all_pgs.add(pg)
+
+        pg_timeline: dict[str, Optional[PGEvent]] = {}
+        for pg in all_pgs:
+            pg_timeline[pg] = self.build_timeline(pg)
+
+        return pg_timeline
+
+    def build_timeline(self, pg: str) -> Optional[PGEvent]:
+        """
+        Build a timeline of important events (starts, waits, hiding compute) for this process group
+        and constrain this ordering in the augmented graph.
+
+        Sequential dependencies are added between all events because NCCL collectives on the same
+        process group execute on the same CUDA stream, enforcing LIFO semantics where later-issued
+        collectives must complete before earlier ones can finish.
+        """
+
+        head = None
+        prev_event = None
+        position = 0
+        hiding_nodes = OrderedSet()
+
+        for node in self.scheduled:
+            node_type = None
+
+            # Determine if this node is relevant for this PG
+            if node in self.collective_info and get_group_name(node) == pg:
+                node_type = "starts"
+                hiding_nodes |= self.collective_info[node].hiding_nodes
+            elif _schedulable_wait_node(node):
+                wait_input = node.args[0]
+                if isinstance(wait_input, fx.Node) and get_group_name(wait_input) == pg:
+                    node_type = "waits"
+                # Wait for a different PG but hiding a collective on this PG
+                elif node in hiding_nodes:
+                    node_type = "compute"
+            elif is_compute_node(node) or node in hiding_nodes:
+                node_type = "compute"
+
+            if node_type is None:
+                continue
+
+            event = PGEvent(node=node, event_type=node_type, position=position)  # type: ignore[arg-type]
+
+            event.insert_between(prev_event, None)
+
+            # Add sequential dependency to augmented graph
+            if prev_event:
+                self.aug_graph.add_extra_dep(n=event.node, dep=prev_event.node)
+            else:
+                head = event
+
+            prev_event = event
+            position += 1
+
+        return head
+
+    def _populate_node_to_event(self, pg: str) -> None:
+        """Populate node_to_event mapping for a specific PG's timeline."""
+        self.node_to_event.clear()
+        head = self.pg_to_timeline_head[pg]
+        curr = head
+        while curr is not None:
+            self.node_to_event[curr.node] = curr
+            curr = curr.next
+
+    def _add_hiding_interval_constraints(self) -> None:
+        """
+        Add hiding interval constraints: start -> compute -> wait.
+        """
+        for start, info in self.collective_info.items():
+            if info.is_exposed:
+                continue
+            for hn in info.hiding_nodes:
+                # Enforce: start -> compute -> wait
+                self.aug_graph.add_extra_dep(n=hn, dep=start)
+                self.aug_graph.add_extra_dep(n=info.wait_node, dep=hn)
+
+            self.all_hiding_nodes |= info.hiding_nodes
+
+    def bucket_collectives(self) -> None:
+        # Group collectives by PG first
+        pg_collectives: dict[str, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
+        for start in self.collective_info:
+            pg = get_group_name(start)
+            pg_collectives[pg].add(start)
+
+        all_buckets: list[CollBucket] = []
+        for pg, collectives in pg_collectives.items():
+            # Populate node_to_event for this PG's timeline
+            self._populate_node_to_event(pg)
+
+            # Group by bucket key within this PG
+            grouped_collectives: dict[object, OrderedSet[fx.Node]] = defaultdict(
+                OrderedSet
+            )
+            for start in collectives:
+                key = bucket_key(start, self.bucket_mode)
+                if key is not None:
+                    grouped_collectives[key].add(start)
+
+            # Find buckets for this PG
+            for key, collective_group in grouped_collectives.items():
+                bucket_log.debug(
+                    "bucketing collective group with key %s: %s",
+                    key,
+                    [n.name for n in collective_group],
+                )
+                buckets = self._find_buckets(collective_group)
+                all_buckets.extend(buckets)
+
+        # Apply bucketing transformations
+        # Dependencies are tracked in aug_graph.extra_deps during bucketing
+        for coll_bucket in all_buckets:
+            if len(coll_bucket.collectives) <= 1:
+                continue
+
+            counters["inductor"]["collective_buckets"] += 1
+            self._apply_bucket(coll_bucket)
+
+        # Extract all dependencies from augmented graph
+        # This includes:
+        # - Sequential timeline deps (added during build_timeline)
+        # - Hiding interval deps (added during _add_hiding_interval_constraints)
+        # - All transferred deps from bucketing (transferred during _apply_bucket)
+        additional_deps = self.aug_graph.get_all_extra_deps()
+
+        # Apply topological sort with all dependencies
+        from torch._dynamo.graph_deduplication import _stable_topological_sort
+
+        for n, deps in additional_deps.items():
+            torch._check(
+                not n._erased, lambda: f"Erased node deps not transferred: {n}"
+            )
+            for d in deps:
+                torch._check(
+                    not d._erased, lambda: f"Erased node deps not transferred: {d}"
+                )
+
+        _stable_topological_sort(self.graph, additional_deps)
+
+        # After topological sort, preserve dependencies using effect tokens
+        # Only preserve edges where NOT both nodes are collective starts or waits
+        if self.insert_overlap_deps:
+            filtered_deps: dict[fx.Node, OrderedSet[fx.Node]] = {}
+            for node, deps in additional_deps.items():
+                filtered_node_deps: OrderedSet[fx.Node] = OrderedSet()
+
+                # only preserve comm-comptue overlap for now, although we could more
+                # generally constrain
+                for dep in deps:
+                    if not (is_collective_or_wait(node) and is_collective_or_wait(dep)):
+                        filtered_node_deps.add(dep)
+
+                if filtered_node_deps:
+                    filtered_deps[node] = filtered_node_deps
+
+            self._preserve_dependencies_with_tokens(filtered_deps)
+
+        self.graph.lint()
+
+    def _find_buckets(
+        self,
+        collective_group: OrderedSet[fx.Node],
+    ) -> list[CollBucket]:
+        """Find valid buckets within a group of similar collectives."""
+        max_bucket_bytes = int(self.max_bucket_memory_gb * 1024 * 1024 * 1024)
+        buckets = []
+        processed: OrderedSet[fx.Node] = OrderedSet()
+
+        # Sort collectives by node index for efficient distance checking
+        sorted_collectives = sorted(collective_group, key=lambda n: self.node_idx[n])
+
+        for i, start_node in enumerate(sorted_collectives):
+            if start_node in processed:
+                continue
+
+            if (
+                start_node in self.all_hiding_nodes
+                or self.collective_info[start_node].wait_node in self.all_hiding_nodes
+            ):
+                continue
+
+            # Initialize bucket with first collective
+            bucket_info = CollBucket(
+                collectives=[start_node],
+                total_bytes=self.collective_info[start_node].size_bytes,
+            )
+            processed.add(start_node)
+
+            # Greedy optimization: stop after consecutive failures
+            consecutive_failures = 0
+            max_consecutive_failures = 20
+
+            # Check candidates in sorted order, break when beyond max distance
+            for candidate in sorted_collectives[i + 1 : i + 1 + self.max_coll_distance]:
+                candidate_bytes = self.collective_info[candidate].size_bytes
+                # proxy on memory use, if we see a too large bucket,
+                # dont look for another, later bucket
+                if bucket_info.total_bytes + candidate_bytes > max_bucket_bytes:
+                    break
+
+                if candidate in processed:
+                    continue
+
+                if self._can_add_to_bucket(bucket_info, candidate):
+                    bucket_info.collectives.append(candidate)
+                    bucket_info.total_bytes += candidate_bytes
+                    processed.add(candidate)
+                    consecutive_failures = 0  # Reset on success
+                else:
+                    consecutive_failures += 1
+                    if consecutive_failures >= max_consecutive_failures:
+                        break
+
+            if len(bucket_info.collectives) > 1:
+                buckets.append(bucket_info)
+
+        return buckets
+
+    def _ancestor_dep(self, n1: fx.Node, n2: fx.Node) -> bool:
+        """Check if there's an ancestor relationship between two nodes."""
+        return n1 in self.node_ancestors[n2] or n2 in self.node_ancestors[n1]
+
+    def _get_intervals(
+        self, event: PGEvent
+    ) -> tuple[Optional[tuple[int, int]], list[tuple[int, int]]]:
+        """Get (execution_interval, hiding_intervals) for a collective event.
+
+        Returns:
+            (execution_interval, hiding_intervals) where:
+            - execution_interval is (start_pos, wait_pos) or None
+            - hiding_intervals is a list of (start_pos, compute_pos) tuples, one for each hiding node
+
+        Works for both start and wait events by looking up the collective info.
+        """
+        # For start events, directly use the node
+        if event.is_start:
+            coll = event.node
+        # For wait events, look up the start node from the event's args
+        elif event.is_wait:
+            wait_input = event.node.args[0]
+            if not isinstance(wait_input, fx.Node):
+                return None, []
+            coll = wait_input
+        else:
+            return None, []
+
+        if coll not in self.collective_info:
+            return None, []
+
+        info = self.collective_info[coll]
+        start_event = self.node_to_event[coll]
+        wait_event = self.node_to_event[info.wait_node]
+
+        execution_interval = (start_event.position, wait_event.position)
+
+        hiding_intervals = []
+        if info.hiding_nodes:
+            for hiding_node in info.hiding_nodes:
+                hiding_intervals.append(
+                    (
+                        start_event.position,
+                        self.node_to_event[hiding_node].position,
+                    )
+                )
+
+        return execution_interval, hiding_intervals
+
+    def _preserves_hiding_intervals(
+        self,
+        bucket_info: CollBucket,
+        candidate: fx.Node,
+        start_pos: fx.Node,
+        wait_pos: fx.Node,
+        why: WhyNoBucket,
+    ) -> bool:
+        """
+        Check that (start_pos, wait_pos) doesn't violate any hiding intervals or collectives.
+
+        Collects all execution and hiding intervals in the affected timeline regions,
+        then checks:
+        1. All bucket hiding compute stays between new start/wait
+        2. No other collective's compute interval is enclosed by bucket execution interval
+        3. No other collective's execution interval encloses bucket compute intervals
+        """
+        # Collect all collectives being bucketed
+        all_bucketed_colls = [candidate] + list(bucket_info.collectives)
+        all_bucketed_waits = [
+            self.collective_info[coll].wait_node for coll in all_bucketed_colls
+        ]
+
+        # Collect hiding compute positions for the bucket
+        bucket_hiding_compute_positions = []
+        for coll in all_bucketed_colls:
+            for coll_hiding_node in self.collective_info[coll].hiding_nodes:
+                bucket_hiding_compute_positions.append(
+                    self.node_to_event[coll_hiding_node].position
+                )
+
+        # Get new positions
+        new_start_event = self.node_to_event[start_pos]
+        new_wait_event = self.node_to_event[wait_pos]
+
+        # Check 1: All bucket hiding compute must be between new start and wait
+        for compute_pos in bucket_hiding_compute_positions:
+            if not (new_start_event.position < compute_pos < new_wait_event.position):
+                why(
+                    "hiding compute at pos %d not between start %d and wait %d",
+                    compute_pos,
+                    new_start_event.position,
+                    new_wait_event.position,
+                )
+                return False
+
+        def get_wait(n: fx.Node) -> fx.Node:
+            return self.collective_info[n].wait_node
+
+        def get_pos(n: fx.Node) -> int:
+            return self.node_to_event[n].position
+
+        latest_start_pos = max(get_pos(candidate), get_pos(bucket_info.collectives[0]))
+        earliest_wait_pos = min(
+            get_pos(get_wait(candidate)), get_pos(get_wait(bucket_info.collectives[0]))
+        )
+
+        # Bucket execution interval
+        bucket_execution_interval = (new_start_event.position, new_wait_event.position)
+
+        # Because collectives on the same PG operate under LIFO semantics,
+        # it's only possible for us to force an early realization of an unrelated collective
+        # by delaying a start or raising a wait.
+        # We search in the interval from old_start -> new_start, to see if would be
+        # forcing another collective to be realized prior to its hiding nodes.
+        # Similarly, we search from old_wait -> new_wait, in the reverse direction,
+        # to check the same thing.
+
+        execution_intervals = [bucket_execution_interval]
+        hiding_intervals = [
+            (bucket_execution_interval[0], pos)
+            for pos in bucket_hiding_compute_positions
+        ]
+
+        curr_event = new_start_event.next
+        while curr_event is not None and curr_event.position < latest_start_pos:
+            if (
+                curr_event.node not in all_bucketed_colls
+                and curr_event.node not in all_bucketed_waits
+            ):
+                exec_interval, hiding_interval_list = self._get_intervals(curr_event)
+                if exec_interval:
+                    execution_intervals.append(exec_interval)
+                hiding_intervals.extend(hiding_interval_list)
+            curr_event = curr_event.next
+
+        curr_event = new_wait_event.prev
+        while curr_event is not None and curr_event.position > earliest_wait_pos:
+            if (
+                curr_event.node not in all_bucketed_colls
+                and curr_event.node not in all_bucketed_waits
+            ):
+                exec_interval, hiding_interval_list = self._get_intervals(curr_event)
+                if exec_interval:
+                    execution_intervals.append(exec_interval)
+                hiding_intervals.extend(hiding_interval_list)
+            curr_event = curr_event.prev
+
+        # Check: no hiding interval should be enclosed by any execution interval
+        def enclosed_interval(inner: tuple[int, int], outer: tuple[int, int]) -> bool:
+            return outer[0] < inner[0] and inner[1] < outer[1]
+
+        for hiding_interval in hiding_intervals:
+            for execution_interval in execution_intervals:
+                if enclosed_interval(hiding_interval, execution_interval):
+                    why(
+                        "hiding interval %s enclosed by execution interval %s",
+                        hiding_interval,
+                        execution_interval,
+                    )
+                    return False
+
+        return True
+
+    def remove_from_event(
+        self, node: fx.Node
+    ) -> tuple[Optional[PGEvent], Optional[PGEvent]]:
+        """Remove node from timeline and return (prev_event, next_event)."""
+        event = self.node_to_event[node]
+        assert not event.is_compute, "Cannot remove compute events from timeline"
+
+        prev_event, next_event = event.unlink()
+
+        # Remove augmented graph dependency
+        if prev_event:
+            self.aug_graph.remove_extra_dep(n=node, dep=prev_event.node)
+        if next_event:
+            self.aug_graph.remove_extra_dep(n=next_event.node, dep=node)
+
+        # Add bypass dependency
+        if prev_event and next_event:
+            self.aug_graph.add_extra_dep(n=next_event.node, dep=prev_event.node)
+
+        return prev_event, next_event
+
+    def restore_to_event(
+        self,
+        node: fx.Node,
+        prev_event: Optional[PGEvent],
+        next_event: Optional[PGEvent],
+    ) -> None:
+        """Restore node to timeline after failed merge attempt."""
+        event = self.node_to_event[node]
+
+        # Reinsert into linked list
+        event.insert_between(prev_event, next_event)
+        if prev_event:
+            self.aug_graph.add_extra_dep(n=node, dep=prev_event.node)
+        if next_event and not prev_event:
+            self.aug_graph.add_extra_dep(n=next_event.node, dep=node)
+
+        # Remove bypass dependency
+        if prev_event and next_event:
+            self.aug_graph.remove_extra_dep(n=next_event.node, dep=prev_event.node)
+
+    def _try_timeline_position(
+        self,
+        bucket_info: CollBucket,
+        candidate: fx.Node,
+        start_pos: fx.Node,
+        wait_pos: fx.Node,
+        why: WhyNoBucket,
+    ) -> bool:
+        """
+        Try a specific timeline position for the candidate.
+        Returns True if valid and merges are successful.
+        """
+        candidate_info = self.collective_info[candidate]
+        candidate_wait = candidate_info.wait_node
+
+        # Quick check: does this violate hiding intervals?
+        if not self._preserves_hiding_intervals(
+            bucket_info, candidate, start_pos, wait_pos, why
+        ):
+            return False
+
+        # Determine which start needs to move
+        existing_coll = bucket_info.collectives[0]
+        if start_pos == existing_coll:
+            start_to_move = candidate
+        else:
+            assert start_pos == candidate
+            start_to_move = existing_coll
+
+        # Remove start from timeline
+        start_prev, start_next = self.remove_from_event(start_to_move)
+
+        # Check if starts can be merged
+        if self.aug_graph.has_path(existing_coll, candidate) or self.aug_graph.has_path(
+            candidate, existing_coll
+        ):
+            # Restore start constraints
+            self.restore_to_event(start_to_move, start_prev, start_next)
+            why("path exists between starts")
+            return False
+
+        # Merge starts
+        self.aug_graph.merge_to_set(existing_coll, candidate)
+
+        # Determine which wait needs to move
+        existing_wait = self.collective_info[existing_coll].wait_node
+        candidate_wait = self.collective_info[candidate].wait_node
+
+        if wait_pos == existing_wait:
+            wait_to_move = candidate_wait
+        else:
+            wait_to_move = existing_wait
+
+        # Remove wait from timeline
+        wait_prev, wait_next = self.remove_from_event(wait_to_move)
+
+        # Check if waits can be merged
+        if self.aug_graph.has_path(
+            existing_wait, candidate_wait
+        ) or self.aug_graph.has_path(candidate_wait, existing_wait):
+            # Restore wait constraints
+            self.restore_to_event(wait_to_move, wait_prev, wait_next)
+            # Unmerge the start we just merged
+            self.aug_graph.unmerge_node(candidate)
+            # Restore start constraints
+            self.restore_to_event(start_to_move, start_prev, start_next)
+            why("path exists between waits")
+            return False
+
+        # Merge waits - success!
+        self.aug_graph.merge_to_set(existing_wait, candidate_wait)
+
+        # Update node_to_event for moved nodes
+        target_start_event = self.node_to_event[start_pos]
+        target_wait_event = self.node_to_event[wait_pos]
+
+        self.node_to_event[candidate] = target_start_event
+        self.node_to_event[candidate_wait] = target_wait_event
+        self.node_to_event[existing_coll] = target_start_event
+        self.node_to_event[existing_wait] = target_wait_event
+
+        return True
+
+    def _has_ancestor_conflicts(
+        self, bucket_info: CollBucket, candidate: fx.Node
+    ) -> bool:
+        """
+        Check if candidate has ancestor conflicts with bucket collectives.
+        Returns True if there are conflicts.
+        """
+        candidate_info = self.collective_info[candidate]
+        candidate_wait = candidate_info.wait_node
+
+        for coll in bucket_info.collectives:
+            if (
+                coll in self.node_ancestors[candidate]
+                or candidate in self.node_ancestors[coll]
+            ):
+                return True
+
+            # Check if waits are ancestors of each other
+            coll_wait = self.collective_info[coll].wait_node
+            if (
+                coll_wait in self.node_ancestors[candidate_wait]
+                or candidate_wait in self.node_ancestors[coll_wait]
+            ):
+                return True
+
+            # Check if existing hiding node conflicts with candidate wait
+            for old_hiding_node in self.collective_info[coll].hiding_nodes:
+                if candidate_wait in self.node_ancestors[old_hiding_node]:
+                    return True
+
+            # Check if candidate hiding node conflicts with existing wait
+            for new_hiding_node in candidate_info.hiding_nodes:
+                if coll_wait in self.node_ancestors[new_hiding_node]:
+                    return True
+
+        return False
+
+    def _can_add_to_bucket(
+        self,
+        bucket_info: CollBucket,
+        candidate: fx.Node,
+    ) -> bool:
+        """
+        Check if candidate can be added to bucket without breaking comm/compute overlap.
+
+        Strategy: Try all timeline positions - combinations of [existing_start, candidate_start]
+        x [existing_wait, candidate_wait]. For each position, verify:
+        1. Hiding intervals preserved - for any (start, hiding_compute, wait) interval, no other
+           collective's (start, wait) pair falls between start and hiding_compute, which would
+           force realization and break overlap due to LIFO semantics
+        2. Topologically valid (no dependency cycles)
+
+        Return True if any timeline position satisfies both constraints.
+        """
+        existing_coll = bucket_info.collectives[0]
+        why = WhyNoBucket(existing_coll, candidate)
+
+        candidate_info = self.collective_info[candidate]
+
+        if (
+            candidate in self.all_hiding_nodes
+            or candidate_info.wait_node in self.all_hiding_nodes
+        ):
+            why("nyi: bucketing collective used for overlap")
+            return False
+
+        # Step 1: Quick check using precomputed ancestors
+        # These ancestors are computed prior to adding augmented dependencies and not updated,
+        # so if any of these checks fail then the merge will not be topologically valid
+        # even ignoring comm/compute overlap
+        if self._has_ancestor_conflicts(bucket_info, candidate):
+            why("has ancestor conflicts")
+            return False
+
+        # Step 2: Try different rail positions
+        existing_wait = self.collective_info[existing_coll].wait_node
+
+        candidate_start = candidate
+        candidate_wait = candidate_info.wait_node
+
+        # Try combinations in order of likelihood to succeed
+        # (early start, later wait is most likely to work)
+        combinations = [
+            (
+                existing_coll,
+                candidate_wait,
+            ),  # Move candidate start early, keep wait late
+            (
+                existing_coll,
+                existing_wait,
+            ),  # Move candidate start early, move wait early
+            (candidate_start, candidate_wait),  # Keep both in place
+            (candidate_start, existing_wait),  # Keep start in place, move wait early
+        ]
+
+        for i, (start_pos, wait_pos) in enumerate(combinations):
+            if self._try_timeline_position(
+                bucket_info, candidate, start_pos, wait_pos, why
+            ):
+                bucket_log.debug(
+                    "bucketed %s with %s using timeline position %d: (start=%s, wait=%s)",
+                    candidate.name,
+                    existing_coll.name,
+                    i + 1,
+                    start_pos.name,
+                    wait_pos.name,
+                )
+                return True
+
+        why("all timeline positions failed")
+        return False
+
+    def _apply_bucket(self, bucket_info: CollBucket) -> None:
+        """
+        Apply bucketing transformation.
+
+        Dependencies are added to aug_graph.extra_deps and transferred from old nodes.
+        """
+
+        from torch._inductor.fx_passes.bucketing import (
+            is_all_reduce_tensor,
+            merge_all_gather_bucket,
+            merge_all_reduce_bucket,
+            merge_reduce_scatter_bucket,
+        )
+
+        bucket = bucket_info.collectives
+
+        # Collect old nodes BEFORE they're erased
+        old_starts = list(bucket)
+        old_waits = [self.collective_info[n].wait_node for n in bucket]
+
+        fused_convert_dtypes = []
+        for n in old_starts:
+            if has_mergeable_all_gather_convert_dtype(n):
+                fused_convert_dtypes.append(n.args[0])
+
+        # Find where to place the bucketed operations
+        next_node = bucket[0]
+        while next_node in bucket:
+            next_node = next_node.next
+
+        # Don't use wait_insertion_point - let merge functions place waits naturally
+        # The wait_insertion_point feature tries to move waits to a specific location,
+        # but this can cause issues when that location is one of the nodes being erased
+        # Create bucketed collective (this will erase old nodes)
+        if is_all_gather(bucket[0]):
+            new_nodes, replacements = merge_all_gather_bucket(
+                self.graph,
+                bucket,
+                insert_before=next_node,
+                mode="custom_ops",
+            )
+        elif is_all_reduce_tensor(bucket[0]):
+            new_nodes, replacements = merge_all_reduce_bucket(
+                self.graph,
+                bucket,
+                mode="custom_ops",
+                insert_before=next_node,
+            )
+        else:
+            assert is_reduce_scatter(bucket[0])
+            new_nodes, replacements = merge_reduce_scatter_bucket(
+                self.graph,
+                bucket,
+                insert_before=next_node,
+                mode="custom_ops",
+            )
+
+        # Get new nodes
+        new_waits = [n for n in new_nodes if _schedulable_wait_node(n)]
+        assert len(new_waits) == 1
+
+        new_wait = new_waits[0]
+        new_start = new_wait.args[0]
+        assert isinstance(new_start, fx.Node)
+
+        # Create mapping of all erased nodes to their replacements
+        erased_to_new = {}
+        for old_start in old_starts:
+            erased_to_new[old_start] = new_start
+        for old_wait in old_waits:
+            erased_to_new[old_wait] = new_wait
+
+        # Handle convert_element_type nodes that were fused and erased
+        # The bucketed operation may have a _pre_bucket op that handles dtype conversion
+        if fused_convert_dtypes:
+            # all gather bucketing may fuse in dtype conversion into the bucketing
+            # if so, we need to transfer hiding deps from the old dtype conversion
+            # to the new bucketing node
+            new_convert_dtypes_node = new_start.kwargs["out"]
+            assert isinstance(new_convert_dtypes_node, fx.Node)
+            assert (
+                new_convert_dtypes_node.target
+                == torch.ops.bucketing._pre_bucket_all_gather.default
+            )
+
+            for n in fused_convert_dtypes:
+                erased_to_new[n] = new_convert_dtypes_node
+
+        # Transfer all dependencies from old nodes to new nodes
+        self.aug_graph.transfer_erased_node_deps(erased_to_new)
+
+    def _preserve_dependencies_with_tokens(
+        self, additional_deps: dict[fx.Node, OrderedSet[fx.Node]]
+    ) -> None:
+        """
+        Preserve dependencies using effect tokens and with_effects higher-order op.
+
+        Uses the standalone token_dependencies utility for consistent behavior
+        across different overlap scheduling approaches.
+        """
+        from torch._inductor.fx_passes.control_dependencies import (
+            preserve_node_ordering,
+        )
+
+        preserve_node_ordering(self.graph, additional_deps)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/overlap_scheduling.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/overlap_scheduling.py
new file mode 100644
index 0000000000000000000000000000000000000000..5770991dc233ef3dac26a8027f990c048acf3ce9
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/overlap_scheduling.py
@@ -0,0 +1,1324 @@
+import functools
+import heapq
+import itertools
+import logging
+import sys
+from collections import Counter, defaultdict
+from collections.abc import Callable, Iterable
+from dataclasses import dataclass, field
+from typing import Any, Literal
+
+import torch
+import torch.fx as fx
+from torch._dynamo.utils import counters, dynamo_timed
+from torch._inductor.comm_analysis import estimate_fx_collective_memory_footprint
+from torch._inductor.fx_passes.bucketing import _schedulable_wait_node, is_wait_tensor
+from torch._inductor.fx_passes.memory_estimator import MemoryTracker
+from torch.fx.operator_schemas import normalize_function
+from torch.utils._ordered_set import OrderedSet
+from torch.utils._python_dispatch import _disable_current_modes
+
+
+log = logging.getLogger(__name__)
+
+from torch._inductor.fx_passes.bucketing import bucket_key
+
+from ..pattern_matcher import stable_topological_sort
+
+
+def estimate_runtime_analytical(n: torch.fx.Node) -> float:
+    """Estimate runtime using analytical roofline model for mm operations."""
+    if n.target != torch.ops.aten.mm.default:
+        return 0.0
+    import torch.utils._pytree as pytree
+    from torch.distributed._tools import RuntimeEstimator
+
+    def _val(node: Any) -> Any:
+        if not isinstance(node, torch.fx.Node):
+            return node
+        return node.meta["val"]
+
+    args = pytree.tree_map(_val, n.args)
+    kwargs = pytree.tree_map(_val, n.kwargs)
+    _, ms = RuntimeEstimator._roofline_estimate(n.target, args, kwargs)
+    return ms
+
+
+@dataclass
+class WhyNoOverlap:
+    """Track reasons why a collective cannot overlap with compute."""
+
+    compute_name: str
+    collective_name: str
+
+    def __init__(self, compute_node: fx.Node, collective_node: fx.Node) -> None:
+        self.compute_name = compute_node.name
+        self.collective_name = collective_node.name
+
+    def __call__(self, reason: str, *args: Any) -> None:
+        if log.isEnabledFor(logging.DEBUG):
+            log.debug(
+                "cannot overlap %s with %s: " + reason,  # noqa: G003
+                self.collective_name,
+                self.compute_name,
+                *args,
+            )
+
+
+def get_group_name(n: fx.Node) -> str:
+    """Extract the group name from a collective operation node."""
+    opt_args_kwargs = normalize_function(
+        n.target,  # type: ignore[arg-type]
+        args=n.args,
+        kwargs=n.kwargs,
+        normalize_to_only_use_kwargs=True,
+    )
+    assert opt_args_kwargs is not None
+    _, kwargs = opt_args_kwargs
+    return kwargs["group_name"]
+
+
+def get_custom_estimation(
+    n: fx.Node,
+    custom_runtime_estimation: Callable[[fx.Node, int | None], float | None]
+    | None = None,
+    override_size: int | None = None,
+) -> float | None:
+    if custom_runtime_estimation is None:
+        return None
+
+    return custom_runtime_estimation(n, override_size)
+
+
+def estimate_collective_time(
+    n: fx.Node,
+    override_size: int | None = None,
+    custom_runtime_estimation: Callable[[fx.Node, int | None], float | None]
+    | None = None,
+) -> float:
+    """Estimate the runtime of a collective operation, optionally with an overridden size."""
+    if (
+        est := get_custom_estimation(n, custom_runtime_estimation, override_size)
+    ) is not None:
+        return est
+
+    # Use analytical model (benchmarking is handled separately in alignment)
+    return torch._inductor.comm_analysis.estimate_nccl_collective_runtime_from_fx_node(
+        n, override_size
+    )
+
+
+def is_compute_node(n: fx.Node) -> bool:
+    """
+    Should we consider this node computationally expensive ?
+    Currently uses flop registration, but we could expand more generally.
+    """
+    return (
+        getattr(n.target, "overloadpacket", None)
+        in torch.utils.flop_counter.flop_registry
+    )
+
+
+def is_reduce_scatter(n: fx.Node) -> bool:
+    """Check if node is a reduce_scatter collective."""
+    return "reduce_scatter" in str(n.target).lower()
+
+
+def get_hint(x: int | torch.SymInt) -> int | None:
+    if isinstance(x, int):
+        return x
+    assert isinstance(x, torch.SymInt)
+    if not x.node.has_hint():
+        return None
+    return x.node.hint
+
+
+def get_collective_do_bench() -> Callable[[Callable[[], Any]], float]:
+    with dynamo_timed("collective_compute_do_bench"):
+        return functools.partial(
+            # pyrefly: ignore [bad-argument-type]
+            torch._inductor.runtime.benchmarking.benchmarker.benchmark_gpu,
+            warmup=5,
+        )
+
+
+def benchmark_node_with_cache_key(
+    n: fx.Node,
+    custom_runtime_estimation: Callable[[fx.Node, int | None], float | None]
+    | None = None,
+) -> tuple[float, str | None]:
+    """Benchmark a compute node and return (runtime, cache_key)."""
+    assert is_compute_node(n)
+
+    from torch._dynamo.testing import rand_strided
+
+    # todo - skip unbacked, symbolic
+    success, args, kwargs = torch._inductor.fx_utils.get_fake_args_kwargs(n)
+
+    if not success:
+        return 0, None
+
+    unbacked_tensor = False
+
+    key = f"{str(n.target)}: "
+
+    def to_real(t: torch.Tensor) -> torch.Tensor | None:
+        shape = [get_hint(dim) for dim in t.shape]
+        stride = [get_hint(s) for s in t.stride()]
+
+        if any(s is None for s in itertools.chain(shape, stride)):
+            nonlocal unbacked_tensor
+            unbacked_tensor = True
+            return None
+
+        nonlocal key
+        key += f"T: {shape, stride, t.dtype} "
+        return rand_strided(shape, stride, device=t.device, dtype=t.dtype)  # type: ignore[arg-type]
+
+    with _disable_current_modes():
+        args, kwargs = torch.utils._pytree.tree_map_only(
+            torch.Tensor,
+            lambda t: to_real(t),
+            (args, kwargs),
+        )
+
+        if val := get_cached_node_time(key):
+            return val, key
+
+        if unbacked_tensor:
+            return 0, key
+
+        if (
+            est := get_custom_estimation(n, custom_runtime_estimation, None)
+        ) is not None:
+            set_cached_node_time(key, est)
+            return est, key
+
+        bench = get_collective_do_bench()
+        out = bench(lambda: n.target(*args, **kwargs))  # type: ignore[operator]
+        set_cached_node_time(key, out)
+        return out, key
+
+
+def benchmark_node(
+    n: fx.Node,
+    custom_runtime_estimation: Callable[[fx.Node, int | None], float | None]
+    | None = None,
+) -> float:
+    return benchmark_node_with_cache_key(n, custom_runtime_estimation)[0]
+
+
+@functools.cache
+def get_benchmark_cache() -> torch._inductor.codecache.LocalCache:
+    return torch._inductor.codecache.LocalCache()
+
+
+def get_cached_node_time(key: str) -> float:
+    return get_benchmark_cache().lookup(key)  # type: ignore[return-value]
+
+
+def set_cached_node_time(key: str, value: float) -> None:
+    return get_benchmark_cache().set_value(key, value=value)
+
+
+@dataclass
+class CollectiveInfo:
+    """Track info about a collective operation"""
+
+    start_node: fx.Node
+    wait_node: fx.Node
+    size_bytes: int
+    estimated_time_ms: float
+    exposed_time_ms: float  # How much of this collective is still exposed
+    hiding_nodes: OrderedSet[fx.Node] = field(default_factory=OrderedSet)
+
+    @property
+    def is_exposed(self) -> bool:
+        return self.exposed_time_ms != 0
+
+
+@dataclass
+class CollBucket:
+    """Track information about a bucket of collectives."""
+
+    collectives: list[fx.Node]  # Original collective starts
+    bucketed_start: fx.Node | None = None  # After bucketing
+    bucketed_wait: fx.Node | None = None  # After bucketing
+    total_bytes: int = 0
+
+
+def gb_to_bytes(gb: float) -> int:
+    """Convert gigabytes to bytes."""
+    return int(gb * 1024 * 1024 * 1024)
+
+
+class OverlapScheduler:
+    """
+    Scheduler that reorders operations to maximize compute-collective overlap.
+
+    The reordering is done as a scheduling pass. We maintain a priority queue of
+    schedulable nodes. The nodes are ranked by:
+
+    1) the compute node index they dominate. this allows reordering locally, such as with
+    parallel mms, and also allows overlapping reduce scatter nodes outputs in the backward
+    with compute by deferring their waits.
+
+    2) whether the current node is a collective or wait that is currently exposed but has a compute
+    node which it could be overlapped with.
+
+    3) original order in the graph for stability.
+
+    When we schedule compute nodes, we first overlap exposed in-flight collectives, then look for unscheduled
+    collectives that can be scheduled concurrently.
+
+    TODO:
+        - experiment with other priority scores / allow other mechanisms of reorder / more strict adherence to original graph
+        - memory limit for deferred scheduling of reduce_scatter nodes.
+    """
+
+    def __init__(
+        self,
+        gm: torch.fx.GraphModule,
+        max_in_flight_gb: float,
+        max_compute_pre_fetch: int,
+        collective_bucketing: bool,
+        insert_overlap_deps: bool,
+        compute_overlap_multipler: float,
+        max_coll_distance: int,
+        custom_runtime_estimation: Callable[[fx.Node, int | None], float | None] | None,
+        collective_estimator: Literal["analytical", "benchmark"],
+        max_memory_increase_gb: float | None = 1.0,
+        max_memory_increase_ratio: float | None = 0.05,
+    ):
+        self.gm = gm
+        self.graph = gm.graph
+        self.compute_overlap_multipler = compute_overlap_multipler
+        self.max_node_distance = max_coll_distance
+        self.max_in_flight_bytes: int = gb_to_bytes(max_in_flight_gb)
+        self.custom_runtime_estimation = custom_runtime_estimation
+        self.collective_bucketing = collective_bucketing
+        self.insert_overlap_deps = insert_overlap_deps
+        self.max_compute_pre_fetch = max_compute_pre_fetch
+        self.collective_estimator = collective_estimator
+
+        # Build structures
+        stable_topological_sort(self.graph)
+        self.nodes = list(self.graph.nodes)
+        self.node_idx = {n: i for i, n in enumerate(self.nodes)}
+        self.node_ancestors: dict[fx.Node, OrderedSet[fx.Node]] = (
+            self._collect_node_ancestors()
+        )
+
+        # Identify collectives and compute nodes
+        self.collective_info: dict[fx.Node, CollectiveInfo] = {}
+        self.unscheduled_collectives: OrderedSet[fx.Node] = OrderedSet()
+
+        # Identify compute nodes early (needed for baseline memory computation)
+        self.compute_nodes = [n for n in self.nodes if is_compute_node(n)]
+        self.current_compute_index = 0
+
+        # Compute baseline memory profile from original schedule
+        self.original_mem_before_compute_index: list[int] = []
+        self.original_peak_memory = self._compute_baseline_memory()
+
+        # Maximum allowed peak memory = baseline + max(absolute, ratio * baseline)
+        # When both limits are specified, use the more permissive one
+        memory_increase_bytes = None
+        if max_memory_increase_gb is not None:
+            memory_increase_bytes = gb_to_bytes(max_memory_increase_gb)
+        if max_memory_increase_ratio is not None:
+            ratio_increase = int(self.original_peak_memory * max_memory_increase_ratio)
+            memory_increase_bytes = (
+                max(memory_increase_bytes, ratio_increase)
+                if memory_increase_bytes is not None
+                else ratio_increase
+            )
+        if memory_increase_bytes is None:
+            memory_increase_bytes = 0
+
+        self.allowed_peak_memory_bytes = (
+            self.original_peak_memory + memory_increase_bytes
+        )
+
+        # Track cumulative prefetch memory at each compute index
+        # When we prefetch a collective at compute index i that will be used at index j,
+        # it adds memory from i to j, so we need to track this cumulative effect
+        self.cumulative_prefetch_mem_by_compute_index: list[int] = [
+            0 for _ in range(len(self.compute_nodes))
+        ]
+
+        self.memory_tracker = MemoryTracker(self.graph)
+
+        self.wait_to_start: dict[fx.Node, fx.Node] = {}
+        self._identify_collectives()
+        self.wasted_compute = 0.0
+
+        self.compute_index_domination = self._calculate_compute_node_domination_index()
+
+        # Scheduling state
+        self.potentially_hidden_collectives = (
+            self.compute_potential_hidden_collectives()
+        )
+        self.potentially_hidden_waits = self.compute_potential_hidden_waits()
+        self.in_degree = Counter(user for node in self.nodes for user in node.users)
+        self.ready: list[tuple[object, fx.Node]] = []
+
+        for node in self.nodes:
+            if self.in_degree[node] == 0:
+                heapq.heappush(self.ready, (self._compute_score(node), node))
+
+        self.in_flight: dict[fx.Node, CollectiveInfo] = {}  # start -> info
+        self.in_flight_bytes = 0
+        self.scheduled: OrderedSet[fx.Node] = OrderedSet()
+        self.max_compute_pre_fetch = max_compute_pre_fetch
+
+    def _collect_node_ancestors(self) -> dict[fx.Node, OrderedSet[fx.Node]]:
+        """Collect all ancestors for each node."""
+        ancestors: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
+        for node in self.nodes:
+            for input_node in node.all_input_nodes:
+                ancestors[node].add(input_node)
+                ancestors[node] |= ancestors[input_node]
+
+        return ancestors
+
+    def _compute_baseline_memory(self) -> int:
+        """
+        Simulate the original schedule to compute baseline memory profile.
+        Returns the peak memory observed during simulation.
+        """
+        baseline_tracker = MemoryTracker(self.graph)
+
+        last_compute_max_memory = 0
+        peak_memory = 0
+
+        for node in self.nodes:
+            baseline_tracker.schedule_node(node)
+            current_mem = baseline_tracker.current_memory_bytes
+
+            # Record the max memory between this and previous compute node
+            last_compute_max_memory = max(last_compute_max_memory, current_mem)
+
+            if is_compute_node(node):
+                self.original_mem_before_compute_index.append(last_compute_max_memory)
+                last_compute_max_memory = current_mem
+
+            peak_memory = max(peak_memory, current_mem)
+
+        return peak_memory
+
+    def _prefetch_would_exceed_memory_budget(self, start_node: fx.Node) -> bool:
+        """
+        Check if prefetching this collective would exceed memory budget at ANY compute node
+        between now and when it's used.
+        """
+        info = self.collective_info[start_node]
+        size = info.size_bytes
+
+        domination_index = self.compute_index_domination[start_node]
+
+        # If off-path, assume it doesn't increase memory
+        if domination_index == sys.maxsize:
+            return False
+
+        # check current mem
+        if (
+            self.memory_tracker.current_memory_bytes + size
+            > self.allowed_peak_memory_bytes
+        ):
+            return True
+
+        start_index = self.current_compute_index
+
+        # then, check future mem
+        for compute_idx in range(start_index, domination_index):
+            cumulative_prefetch = self.cumulative_prefetch_mem_by_compute_index[
+                compute_idx
+            ]
+
+            # Check 1: Would cumulative prefetch exceed in-flight limit?
+            if (cumulative_prefetch + size) > self.max_in_flight_bytes:
+                return True
+
+            # Check 2: Would total memory (baseline + cumulative prefetch) exceed budget?
+            baseline_mem = self.original_mem_before_compute_index[compute_idx]
+            projected = baseline_mem + cumulative_prefetch + size
+
+            if projected > self.allowed_peak_memory_bytes:
+                return True
+
+        return False
+
+    def _update_cumulative_prefetch_memory(
+        self, collective: fx.Node, info: CollectiveInfo
+    ) -> None:
+        """
+        Update cumulative prefetch memory for all compute indices this collective will be live.
+        """
+        domination_index = self.compute_index_domination[collective]
+        if domination_index == sys.maxsize:
+            return
+
+        for compute_idx in range(self.current_compute_index, domination_index):
+            self.cumulative_prefetch_mem_by_compute_index[compute_idx] += (
+                info.size_bytes
+            )
+
+    def off_compute_path(self, n: fx.Node) -> bool:
+        """Check if a node is off the compute path (doesn't block any compute)."""
+        return self.compute_index_domination[n] == sys.maxsize
+
+    def _identify_collectives(self) -> None:
+        """Identify all collective operations and process groups."""
+        self.all_pgs: OrderedSet[str] = OrderedSet()
+
+        for node in self.nodes:
+            if _schedulable_wait_node(node):
+                start = node.args[0]
+                coll_time_ms = estimate_collective_time(
+                    start, custom_runtime_estimation=self.custom_runtime_estimation
+                )
+
+                info = CollectiveInfo(
+                    start_node=start,
+                    wait_node=node,
+                    size_bytes=estimate_fx_collective_memory_footprint(start),
+                    estimated_time_ms=coll_time_ms,
+                    exposed_time_ms=coll_time_ms,  # Initially fully exposed
+                )
+                self.collective_info[start] = info
+                self.wait_to_start[node] = start
+                self.unscheduled_collectives.add(start)
+                self.all_pgs.add(get_group_name(start))
+
+    def _calculate_compute_node_domination_index(self) -> dict[fx.Node, int]:
+        """
+        Compute the topological index of the earliest compute node each node dominates.
+
+        Compute nodes are assigned indices based on their topological order (0, 1, 2, ...).
+        For each node, returns the minimum index of compute nodes it blocks/dominates.
+        Returns sys.maxsize if the node doesn't block any compute nodes.
+        """
+        compute_node_index: dict[fx.Node, int] = {}
+        for node in self.graph.nodes:
+            if is_compute_node(node):
+                compute_node_index[node] = len(compute_node_index)
+
+        domination_index: dict[fx.Node, int] = {}
+        for node in reversed(self.graph.nodes):
+            if node in compute_node_index:
+                # Compute nodes dominate themselves (return their own index)
+                domination_index[node] = compute_node_index[node]
+            else:
+                domination_index[node] = min(
+                    (domination_index[succ] for succ in node.users), default=sys.maxsize
+                )
+
+        return domination_index
+
+    def _align_compute_nodes_runtime_estimations_across_all_distributed_ranks(
+        self,
+    ) -> None:
+        """Align runtime estimations across ranks (compute + collectives)."""
+        log.info(
+            "Overlap scheduling: Aligning runtime estimations across all distributed ranks"
+        )
+
+        # Benchmark compute nodes
+        runtime_estimations_keys: list[str | None] = []
+        runtime_estimations: list[float] = []
+        compute_key_count = 0
+
+        # Also collect analytical estimations for logging
+        runtime_estimations_analytical: list[float] = []
+
+        for n in self.compute_nodes:
+            val, key = benchmark_node_with_cache_key(n, self.custom_runtime_estimation)
+
+            # Analytical estimations
+            val_analytical = estimate_runtime_analytical(n)
+            runtime_estimations_analytical.append(val_analytical)
+
+            runtime_estimations.append(val)
+            runtime_estimations_keys.append(key)
+            compute_key_count += 1
+
+        # Log compute estimations
+        from torch._inductor.fx_passes.node_runtime_estimation import (
+            _log_compute_estimations,
+        )
+
+        _log_compute_estimations(
+            self.compute_nodes,
+            runtime_estimations,
+            runtime_estimations_analytical,
+        )
+
+        # Benchmark collectives if enabled (only CUDA events - others are deterministic)
+        # Skip if custom estimation is provided for collectives
+        collective_nodes: list[fx.Node] = []
+        benchmarked_collective_nodes: list[
+            fx.Node
+        ] = []  # Track which were actually benchmarked
+        if self.collective_estimator == "benchmark":
+            from torch._inductor.fx_passes.node_runtime_estimation import (
+                benchmark_collective_with_cuda_events,
+            )
+
+            collective_nodes = [
+                info.start_node for info in self.collective_info.values()
+            ]
+
+            # Benchmark CUDA events (non-deterministic, needs alignment)
+            # Skip collectives with custom estimation
+            for n in collective_nodes:
+                if (
+                    get_custom_estimation(n, self.custom_runtime_estimation, None)
+                    is not None
+                ):
+                    continue
+
+                # Benchmark actual size
+                cuda_val, cuda_key = benchmark_collective_with_cuda_events(n, nruns=5)
+                if cuda_val is not None:
+                    runtime_estimations.append(cuda_val)
+                    runtime_estimations_keys.append(cuda_key)
+                    benchmarked_collective_nodes.append(n)
+
+        # Single all_gather and compute medians
+        import torch.distributed as dist
+        from torch._subclasses.fake_tensor import unset_fake_temporarily
+        from torch.distributed.distributed_c10d import _get_default_group
+
+        world_size = dist.get_world_size()
+        pg = _get_default_group()
+
+        with unset_fake_temporarily():
+            gathered_runtime_estimations: list[list[float]] = [
+                [] for _ in range(world_size)
+            ]
+            dist.all_gather_object(
+                gathered_runtime_estimations, runtime_estimations, pg
+            )
+            median_runtime_estimations = torch.median(
+                torch.tensor(gathered_runtime_estimations), dim=0
+            ).values.tolist()
+
+        # Cache medians
+        collective_keys = []
+        collective_medians = []
+        for idx, (key, median_runtime_estimation) in enumerate(
+            zip(runtime_estimations_keys, median_runtime_estimations)
+        ):
+            if key is None:
+                continue
+            if idx < compute_key_count:
+                # Compute node
+                set_cached_node_time(key, median_runtime_estimation)
+            else:
+                # Collective CUDA event benchmark
+                from torch._inductor.fx_passes.node_runtime_estimation import (
+                    set_cached_runtime,
+                )
+
+                set_cached_runtime(key, median_runtime_estimation)
+
+                # Update CollectiveInfo with aligned benchmark
+                coll_idx = idx - compute_key_count
+                coll_node = benchmarked_collective_nodes[coll_idx]
+                info = self.collective_info[coll_node]
+                info.estimated_time_ms = median_runtime_estimation
+                info.exposed_time_ms = median_runtime_estimation
+
+                collective_keys.append(key)
+                collective_medians.append(median_runtime_estimation)
+
+        # Log benchmarks with analytical comparisons
+        if collective_keys:
+            from torch._inductor.fx_passes.node_runtime_estimation import (
+                _log_collective_benchmarks,
+            )
+
+            _log_collective_benchmarks(
+                benchmarked_collective_nodes,
+                collective_keys,
+                collective_medians,
+                world_size,
+            )
+
+        log.info("Overlap scheduling: Runtime estimations aligned")
+
+    def run(self) -> torch.fx.GraphModule:
+        """Run the scheduling algorithm."""
+        # All ranks must make identical decisions on overlap reordering,
+        # Thus we must have identical runtime estimations across ranks.
+        # For now we do benchmarking only for compute nodes.
+        self._align_compute_nodes_runtime_estimations_across_all_distributed_ranks()
+
+        while self.ready:
+            if self._should_force_wait_for_memory():
+                self._force_oldest_wait()
+                continue
+
+            _, node = heapq.heappop(self.ready)
+
+            # we don't always remove nodes from the heap when we schedule them
+            if node in self.scheduled:
+                continue
+
+            if node.op == "placeholder":
+                self._schedule(node)
+            elif node in self.collective_info:
+                self._handle_collective_start(node)
+            elif _schedulable_wait_node(node):
+                self._handle_wait(node)
+            else:
+                self._handle_compute_or_other(node)
+
+        self._reorder_graph()
+
+        if self.collective_bucketing:
+            self._bucket_collectives()
+        elif self.insert_overlap_deps:
+            # If not bucketing, add effect tokens to preserve hiding dependencies
+            self._add_effect_tokens_for_overlap()
+
+        return self.gm
+
+    def _add_effect_tokens_for_overlap(self) -> None:
+        """
+        Add effect tokens to preserve hiding dependency relationships when not bucketing.
+
+        This ensures that communication-compute overlap is preserved through effect tokens
+        when overlap preserving bucketing is not enabled.
+        """
+        from torch._inductor.fx_passes.control_dependencies import (
+            preserve_node_ordering,
+        )
+
+        # Collect hiding dependencies: hiding_node -> collective_start, wait -> hiding_node
+        additional_deps: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
+
+        for start_node, info in self.collective_info.items():
+            if info.is_exposed:
+                continue
+            for hn in info.hiding_nodes:
+                # Compute depends on collective start (compute must wait for collective to start)
+                additional_deps[hn].add(start_node)
+                # Wait depends on compute (wait must wait for compute to finish)
+                additional_deps[info.wait_node].add(hn)
+
+        # Apply effect tokens to preserve these dependencies
+        if additional_deps:
+            preserve_node_ordering(self.graph, additional_deps)
+
+    def get_non_collective_runtime_estimate(self, node: fx.Node) -> float | None:
+        """Get runtime estimation for a node in ms. Returns None if no estimation is available."""
+
+        # TODO: non custom estimation of aten nodes, potentially requires notion of fusion group
+        if is_compute_node(node):
+            return benchmark_node(node, self.custom_runtime_estimation)
+
+        if self.custom_runtime_estimation is None:
+            return None
+
+        return self.custom_runtime_estimation(node, None)
+
+    def _reduce_exposed_time_of_in_flight_collectives(
+        self,
+        node: fx.Node,
+        available_compute: float,
+        exclude_pg: str | None = None,
+    ) -> dict[str, float]:
+        """
+        Reduce exposed time of in-flight collectives using available compute time.
+
+        Collectives on different process groups can overlap simultaneously with the same
+        compute, so we track remaining time separately per PG.
+        """
+        # Initialize all PGs with full available compute (except excluded)
+        remaining_time_per_pg: dict[str, float] = {
+            pg: available_compute for pg in self.all_pgs if pg != exclude_pg
+        }
+
+        for start_node, info in self.in_flight.items():
+            if info.exposed_time_ms == 0:
+                continue
+
+            pg_name = get_group_name(start_node)
+            if pg_name == exclude_pg:
+                continue
+
+            pg_remaining = remaining_time_per_pg[pg_name]
+            if pg_remaining <= 0:
+                continue
+
+            overlap_amount = min(info.exposed_time_ms, pg_remaining)
+            info.exposed_time_ms -= overlap_amount
+            remaining_time_per_pg[pg_name] -= overlap_amount
+            info.hiding_nodes.add(node)
+
+        return remaining_time_per_pg
+
+    def _handle_compute_or_other(self, node: fx.Node) -> None:
+        """Handle scheduling compute or other nodes and attempt to overlap with collectives."""
+        runtime_estimate = self.get_non_collective_runtime_estimate(node)
+
+        # TODO: we could consider skipping overlapping for overlapable, unary chains to collectives.
+        # using these nodes for overlap prevents bucketing. potentially if chain time < latency
+        if runtime_estimate is None:
+            assert not is_compute_node(node), "should have estimate for compute nodes"
+            self._schedule(node)
+            return
+
+        available_compute = runtime_estimate * self.compute_overlap_multipler
+
+        # First, reduce exposed time of in-flight collectives (per PG)
+        remaining_time_per_pg = self._reduce_exposed_time_of_in_flight_collectives(
+            node, available_compute
+        )
+        # Then, schedule new collectives for overlap
+        self._schedule_collectives_for_overlap(node, remaining_time_per_pg)
+        self._schedule(node)
+
+        if is_compute_node(node):
+            self.current_compute_index += 1
+
+    def _schedule(self, node: fx.Node) -> None:
+        """Schedule a node."""
+        assert node not in self.scheduled
+        assert all(n in self.scheduled for n in node.all_input_nodes)
+        self.scheduled.add(node)
+        self.memory_tracker.schedule_node(node)
+
+        log.debug(
+            "Scheduled node %s: current_memory=%d bytes, total_scheduled=%d",
+            node.name,
+            self.memory_tracker.get_current_memory_bytes(),
+            len(self.scheduled),
+        )
+
+        for user in node.users:
+            self.in_degree[user] -= 1
+            if self.in_degree[user] == 0:
+                heapq.heappush(self.ready, (self._compute_score(user), user))
+
+    def _compute_score(self, node: fx.Node) -> object:
+        """Compute priority score for a node"""
+
+        if _schedulable_wait_node(node):
+            info = self.collective_info[self.wait_to_start[node]]
+            # defer waits locally if they are exposed.
+            compute_local_priority = int(info.is_exposed)
+        else:
+            # if we're scheduling this collective via its queue, then it was not
+            # pre-fetched. we might as well maximize overlap for the
+            # local, non-mm nodes prior to the next compute node.
+            if self.in_overlappable_collective_unary_chain(node):
+                compute_local_priority = -1
+            else:
+                compute_local_priority = 0
+
+        return (
+            self.compute_index_domination[node],  # what index compute it blocks
+            compute_local_priority,  # collective_start=-1, wait=1, or neither=0
+            self.node_idx[node],  # Original order for stability
+        )
+
+    @staticmethod
+    def is_cheap_fn(node: fx.Node) -> bool:
+        return getattr(node.target, "is_view", False) or torch.Tag.pointwise in getattr(
+            node.target, "tags", ()
+        )
+
+    def in_overlappable_collective_unary_chain(self, curr: fx.Node) -> bool:
+        while True:
+            if len(curr.users) != 1:
+                return False
+
+            user = next(iter(curr.users))
+            if len(user.all_input_nodes) != 1:
+                return False
+
+            if user in self.unscheduled_collectives:
+                return True
+
+            if not self.is_cheap_fn(user):
+                return False
+
+            curr = user
+
+        return False
+
+    def _should_force_wait_for_memory(self) -> bool:
+        """Check if we need to force a wait due to memory pressure"""
+        if not self.in_flight:
+            return False
+
+        return self.in_flight_bytes >= self.max_in_flight_bytes
+
+    def _force_oldest_wait(self) -> None:
+        """Schedule the oldest in flight wait"""
+        self._handle_wait(self._get_oldest_wait())
+
+    def _handle_collective_start(self, node: fx.Node) -> None:
+        """Handle scheduling a collective start."""
+        info = self.collective_info[node]
+
+        if self.should_assume_bucketed(node):
+            latency = estimate_collective_time(
+                node, 0, custom_runtime_estimation=self.custom_runtime_estimation
+            )
+            assert latency <= info.exposed_time_ms
+            info.exposed_time_ms = info.exposed_time_ms - latency
+
+        self.in_flight[node] = info
+        self.in_flight_bytes += info.size_bytes
+        self.unscheduled_collectives.discard(node)
+        self._schedule(node)
+
+    def _handle_wait(self, node: fx.Node) -> None:
+        """Handle scheduling a wait."""
+        assert node in self.wait_to_start
+        coll_start = self.wait_to_start[node]
+        assert coll_start in self.in_flight
+
+        # Scheduling a wait of a collective also forces the wait
+        # of every node enqueued prior to the collective on the
+        # same process group
+        group_name = get_group_name(coll_start)
+        to_schedule: list[fx.Node] = []
+        for in_flight_coll in self.in_flight:
+            if in_flight_coll == coll_start:
+                break
+            if get_group_name(in_flight_coll) == group_name:
+                to_schedule.append(in_flight_coll)
+
+        for coll_to_schedule in to_schedule:
+            self._handle_wait(self.collective_info[coll_to_schedule].wait_node)
+
+        # If we are waiting on an exposed collective, use this time to
+        # overlap on other PGs.
+        info = self.collective_info[coll_start]
+        if info.exposed_time_ms > 0:
+            exposed_time = info.exposed_time_ms
+            exclude_pg = group_name
+
+            remaining_time_per_pg = self._reduce_exposed_time_of_in_flight_collectives(
+                node, exposed_time, exclude_pg=exclude_pg
+            )
+            self._schedule_collectives_for_overlap(
+                node, remaining_time_per_pg, exclude_pg=exclude_pg
+            )
+
+        self.in_flight_bytes -= self.in_flight[coll_start].size_bytes
+        del self.in_flight[coll_start]
+        self._schedule(node)
+
+    def _schedule_collectives_for_overlap(
+        self,
+        overlap_node: fx.Node,
+        remaining_time_per_pg: dict[str, float],
+        exclude_pg: str | None = None,
+    ) -> None:
+        """Opportunistically schedule collectives that can be hidden by available overlap time."""
+        if not remaining_time_per_pg or all(
+            t <= 0 for t in remaining_time_per_pg.values()
+        ):
+            return
+
+        overlap_node_ancestors = self.node_ancestors[overlap_node]
+
+        # Compile candidates - limit by distance to bound compile time
+        candidates = []
+        for i, collective in enumerate(self.unscheduled_collectives):
+            if i > self.max_node_distance:
+                break
+
+            pg_name = get_group_name(collective)
+            if pg_name == exclude_pg:
+                continue
+
+            if (
+                not self.off_compute_path(collective)
+                and self.compute_index_domination[collective]
+                - self.current_compute_index
+                > self.max_compute_pre_fetch
+            ):
+                continue
+
+            candidates.append(collective)
+
+        # Sort candidates prioritizing:
+        # 1. reduce_scatter operations (reduce memory pressure)
+        # 2. Earlier domination index
+        # 3. Original order for stability
+        candidates.sort(
+            key=lambda n: (
+                not is_reduce_scatter(n),  # reduce_scatter first
+                self.compute_index_domination[n],
+                self.node_idx[n],
+            ),
+        )
+
+        for collective in candidates:
+            pg_name = get_group_name(collective)
+            pg_available_time = remaining_time_per_pg[pg_name]
+
+            if pg_available_time <= 0:
+                continue
+
+            why = WhyNoOverlap(overlap_node, collective)
+            info = self.collective_info[collective]
+
+            if (
+                collective in overlap_node_ancestors
+                or overlap_node in self.node_ancestors[collective]
+            ):
+                why("dependency conflict")
+                continue
+
+            # Check if prefetching would exceed memory budget
+            if self._prefetch_would_exceed_memory_budget(collective):
+                why("prefetch would exceed memory budget")
+                continue
+
+            # Try to free memory by forcing hidden waits
+            while (
+                self.in_flight
+                and (self.max_in_flight_bytes - self.in_flight_bytes) < info.size_bytes
+                and self._wait_is_hidden(self._get_oldest_wait(), overlap_node)
+            ):
+                self._force_oldest_wait()
+
+            if (self.max_in_flight_bytes - self.in_flight_bytes) < info.size_bytes:
+                why("in-flight memory limit")
+                continue
+
+            # Check if we can reach this collective without scheduling compute, other collectives, or waits
+            path = self._find_schedulable_path(collective, overlap_node, why)
+            if path is None:
+                continue
+
+            log.debug(
+                "Overlapping collective %s with node %s: coll_domination=%d, current_depth=%d",
+                collective.name,
+                overlap_node.name,
+                self.compute_index_domination[collective],
+                self.current_compute_index,
+            )
+
+            # TODO: We previously tracked path compute time and added it back to available
+            # overlap time. With per-PG tracking this is complex: if there were in-flight
+            # collectives on one PG but not another, we can't add path time back to the PG
+            # that wasn't in-flight
+
+            # Schedule path and collective
+            self._schedule_path_to_collective(path, overlap_node)
+            self._handle_collective_start(collective)
+            self._update_cumulative_prefetch_memory(collective, info)
+
+            # Update exposed time for this collective
+            overlap_amount = min(pg_available_time, info.exposed_time_ms)
+            info.exposed_time_ms -= overlap_amount
+            info.hiding_nodes.add(overlap_node)
+
+            # Update available time for this PG
+            remaining_time_per_pg[pg_name] -= overlap_amount
+
+            if sum(remaining_time_per_pg.values()) == 0:
+                break
+
+        if remaining_time_per_pg:
+            self.wasted_compute += min(remaining_time_per_pg.values())
+
+    def _find_schedulable_path(
+        self, target: fx.Node, curr_overlap_node: fx.Node | None, why: WhyNoOverlap
+    ) -> OrderedSet[fx.Node] | None:
+        """Find path to target by collecting unscheduled dependencies."""
+        # Get unscheduled ancestors
+        unscheduled_ancestors = self.node_ancestors[target] - self.scheduled
+
+        # only schedule non distributed, non compute nodes
+        for node in unscheduled_ancestors:
+            if is_compute_node(node):
+                why("path blocked by compute node %s", node.name)
+                return None
+
+            if node in self.unscheduled_collectives:
+                why("path blocked by unscheduled collective %s", node.name)
+                return None
+
+            # if we schedule a wait tensor whose start collective is hidden by the
+            # current compute node we are scheduling, then we are effectively exposing it.
+            # similarly, dont schedule a wait of a collective that could be otherwise hidden,
+            # thus forcing it to be exposed.
+            # however, if it is already hidden it's fine to schedule it
+            if _schedulable_wait_node(node):
+                info = self.collective_info[self.wait_to_start[node]]
+                # Allow if fully hidden by other nodes
+                if not info.is_exposed and curr_overlap_node not in info.hiding_nodes:
+                    continue
+
+                why(
+                    "path blocked by wait node %s (exposed=%s, hiding_nodes=%s)",
+                    node.name,
+                    info.is_exposed,
+                    curr_overlap_node in info.hiding_nodes,
+                )
+
+            # Skip c10 ops and dtensor shard ops - they should be scheduled via main loop
+            target_str = str(node.target)
+            if "c10" in target_str or "_dtensor" in target_str:
+                log.debug(
+                    "Skipping c10/dtensor op %s in path to collective",
+                    node.name,
+                )
+                return None
+
+        return unscheduled_ancestors
+
+    def should_assume_bucketed(self, node: fx.Node) -> bool:
+        """
+        Check if there's an in-flight collective that can be bucketed with the given node. If so, assume they will bucket.
+        This is a optimistic heuristic to account for latency reduction with bucketing. The two nodes may not get bucketed.
+        """
+        if not torch._inductor.config.test_configs.assume_bucketing_reduces_latency:
+            return False
+
+        key = bucket_key(node, mode="custom_ops_multidtype")
+        if key is None:
+            return False
+
+        for in_flight_coll in self.in_flight:
+            if bucket_key(in_flight_coll, mode="custom_ops_multidtype") == key:
+                return True
+
+        return False
+
+    def _get_oldest_wait(self) -> fx.Node:
+        oldest_start = next(iter(self.in_flight))
+        return self.collective_info[oldest_start].wait_node
+
+    def _wait_is_hidden(
+        self, wait_node: fx.Node, overlap_node: fx.Node | None = None
+    ) -> bool:
+        assert is_wait_tensor(wait_node)
+        info = self.collective_info[self.wait_to_start[wait_node]]
+        return not info.is_exposed and overlap_node not in info.hiding_nodes
+
+    def _schedule_path_to_collective(
+        self, path: OrderedSet[fx.Node], curr_overlap_node: fx.Node
+    ) -> None:
+        """Schedule all nodes needed to reach a collective."""
+
+        assert all(n not in self.scheduled for n in path)
+        for node in sorted(path, key=lambda n: self.node_idx[n]):
+            assert not (is_compute_node(node) or node in self.unscheduled_collectives)
+            if _schedulable_wait_node(node):
+                # When we schedule wait tensors, we also force realization of all
+                # collectives enqueued prior to their corresponding collective.
+                # It's possible the scheduling of one wait tensor here has forced
+                # another in the path. If so, skip scheduling it.
+                if node in self.scheduled:
+                    continue
+
+                info = self.collective_info[self.wait_to_start[node]]
+                assert curr_overlap_node not in info.hiding_nodes
+                self._handle_wait(node)
+                continue
+
+            self._schedule(node)
+
+    def reorder_graph(self) -> None:
+        output_node = self.graph.output_node()
+        for node in self.scheduled:
+            if node.op == "placeholder":
+                continue
+            output_node.prepend(node)
+        self.graph.lint()
+
+    def _reorder_graph(self) -> None:
+        """Reorder graph based on schedule."""
+        exposed = [
+            c
+            for c in self.collective_info.values()
+            if c.exposed_time_ms == c.estimated_time_ms
+        ]
+
+        potentially_hidden_collectives = self.compute_potential_hidden_collectives()
+        bad_exposed = [
+            c for c in exposed if c.start_node in potentially_hidden_collectives
+        ]
+
+        # Compute total exposed and potential exposed time
+        total_exposed = sum(c.exposed_time_ms for c in self.collective_info.values())
+        hideable_exposed_ms = sum(
+            self.collective_info[c].exposed_time_ms
+            for c in potentially_hidden_collectives
+        )
+        total_potential_exposed = sum(
+            c.estimated_time_ms for c in self.collective_info.values()
+        )
+
+        counters["inductor"]["overlap_scheduling_exposed"] += len(exposed)
+        counters["inductor"]["overlap_scheduling_bad_exposed"] += len(bad_exposed)
+        counters["inductor"]["overlap_scheduling_potentially_hidden"] += len(
+            potentially_hidden_collectives
+        )
+        counters["inductor"]["overlap_original_mem"] = self.original_peak_memory
+        counters["inductor"]["rescheduled_mem"] = self.memory_tracker.peak_memory
+
+        log.info(
+            "Overlap scheduling results: exposed=%d, bad_exposed=%d, potentially_hidden=%d, "
+            "original_peak_memory=%d bytes, rescheduled_peak_memory=%d bytes, "
+            "total_exposed_ms=%.2f, hideable_exposed_ms=%.2f, total_potential_exposed_ms=%.2f, "
+            "wasted_compute_ms=%.2f",
+            len(exposed),
+            len(bad_exposed),
+            len(potentially_hidden_collectives),
+            self.original_peak_memory,
+            self.memory_tracker.peak_memory,
+            total_exposed,
+            hideable_exposed_ms,
+            total_potential_exposed,
+            self.wasted_compute,
+        )
+
+        self.reorder_graph()
+
+    def _bucket_collectives(self) -> None:
+        from torch._inductor.fx_passes.overlap_preserving_bucketer import (
+            OverlapPreservingBucketer,
+        )
+
+        bucketer = OverlapPreservingBucketer(
+            graph=self.graph,
+            collective_info=self.collective_info,
+            scheduled=self.scheduled,
+            max_bucket_memory_gb=2.0,  # Could make this configurable
+            max_coll_distance=self.max_node_distance,
+            insert_overlap_deps=self.insert_overlap_deps,
+        )
+        bucketer.bucket_collectives()
+
+    def compute_potential_hidden_nodes(
+        self, nodes_to_check: Iterable[fx.Node]
+    ) -> dict[fx.Node, fx.Node]:
+        """
+        Returns a dict containing a mapping of nodes which could potentially be hidden to their hiding node
+        """
+
+        def could_be_hidden(start: fx.Node) -> fx.Node | None:
+            for compute_node in self.compute_nodes:
+                if (
+                    start not in self.node_ancestors[compute_node]
+                    and compute_node not in self.node_ancestors[start]
+                ):
+                    return compute_node
+
+            return None
+
+        # TODO: We could potentially limit compute nodes per overlap time,
+        # today, this is optimistic, and just serves to avoid deferring
+        # collectives/waits that have no possible overlap as well as for analysis of how
+        # successfully we hid compute
+        potentially_hidden = {}
+        for node in nodes_to_check:
+            if mm := could_be_hidden(node):
+                potentially_hidden[node] = mm
+
+        return potentially_hidden
+
+    def compute_potential_hidden_collectives(self) -> dict[fx.Node, fx.Node]:
+        """Compute which collective operations could be hidden by compute."""
+        return self.compute_potential_hidden_nodes(self.collective_info.keys())
+
+    def compute_potential_hidden_waits(self) -> dict[fx.Node, fx.Node]:
+        """Compute which wait operations could be hidden by compte."""
+        wait_nodes = [info.wait_node for info in self.collective_info.values()]
+        return self.compute_potential_hidden_nodes(wait_nodes)
+
+
+def schedule_overlap_bucketing(
+    gm: torch.fx.GraphModule,
+    max_in_flight_gb: float = 5,
+    max_compute_pre_fetch: int = 200,
+    collective_bucketing: bool = False,
+    insert_overlap_deps: bool = False,
+    compute_overlap_multipler: float = 1.0,
+    max_coll_distance: int = 200,
+    custom_runtime_estimation: Callable[[fx.Node, int | None], float | None]
+    | None = None,
+    collective_estimator: Literal["analytical", "benchmark"] = "analytical",
+    max_memory_increase_gb: float | None = 1.0,
+    max_memory_increase_ratio: float | None = 0.05,
+) -> torch.fx.GraphModule:
+    """Schedule nodes to maximize compute-collective overlap.
+
+    Args:
+        gm: Input graph module to optimize.
+        max_in_flight_gb: Maximum GB of concurrent collective data. Too much in flight memory
+            can cause memory fragmentation within the CUDA Caching Allocator.
+        max_compute_pre_fetch: Maximum mm nodes to pre fetch. Note: should already be limited by max_in_flight_gb and
+            max_memory_increase_gb
+        collective_bucketing: Enable overlap-preserving collective bucketing.
+        insert_overlap_deps: Insert overlap dependencies using control deps operator. This should only be used if
+            compiling with inductor, or for subsequent passes before removing the ops prior to execution.
+        compute_overlap_multipler: Scale factor for compute time used to hide collectives. This can be used
+            to address over or under aggressive overlapping.
+        max_coll_distance: Maximum pre fetch or bucketing candidates. Mainly intended for compile time
+        custom_runtime_estimation: Custom runtime estimation function that estimates runtime in ms for an fx node.
+            If None, uses default estimations. This is currently limited to collectives and compute nodes.
+        collective_estimator: Method for estimating collective runtime. "analytical" uses bandwidth formulas,
+            "benchmark" uses CUDA events with power-of-2 rounding and interpolation.
+        max_memory_increase_gb: Maximum GB increase above baseline memory (absolute cap). If None, no absolute limit.
+        max_memory_increase_ratio: Maximum increase as ratio of baseline peak memory. If None, no ratio limit.
+            Uses minimum of absolute and ratio limits when both are specified.
+    """
+    return OverlapScheduler(
+        gm,
+        compute_overlap_multipler=compute_overlap_multipler,
+        max_in_flight_gb=max_in_flight_gb,
+        max_coll_distance=max_coll_distance,
+        max_compute_pre_fetch=max_compute_pre_fetch,
+        custom_runtime_estimation=custom_runtime_estimation,
+        collective_bucketing=collective_bucketing,
+        insert_overlap_deps=insert_overlap_deps,
+        collective_estimator=collective_estimator,
+        max_memory_increase_gb=max_memory_increase_gb,
+        max_memory_increase_ratio=max_memory_increase_ratio,
+    ).run()
+
+
+def schedule_overlap_bucketing_from_inductor_configs(
+    gm: torch.fx.GraphModule,
+) -> torch.fx.GraphModule:
+    """Schedule nodes to maximize compute-collective overlap using inductor configs.
+
+    Reads configuration from torch._inductor.config.aten_distributed_optimizations
+    and calls schedule_overlap_bucketing with those settings.
+    """
+    from torch._inductor import config
+
+    dist_opts = config.aten_distributed_optimizations
+
+    kwargs: dict[str, object] = {}
+
+    config_keys = (
+        "collective_bucketing",
+        "max_compute_pre_fetch",
+        "custom_runtime_estimation",
+        "insert_overlap_deps",
+        "collective_estimator",
+        "max_memory_increase_gb",
+        "max_memory_increase_ratio",
+        "compute_overlap_multipler",
+        "max_in_flight_gb",
+        "max_coll_distance",
+    )
+    for key in config_keys:
+        if (val := getattr(dist_opts, key, None)) is not None:
+            kwargs[key] = val
+
+    return schedule_overlap_bucketing(gm, **kwargs)  # type: ignore[arg-type]
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/pad_mm.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/pad_mm.py
new file mode 100644
index 0000000000000000000000000000000000000000..556b32562dcd5533e526aa02c273ac7aca87b2e4
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/pad_mm.py
@@ -0,0 +1,945 @@
+import functools
+import itertools
+import operator
+import typing
+from collections.abc import Callable, Sequence
+from typing import Any
+
+import torch
+import torch._inductor.runtime.runtime_utils
+from torch import Tensor
+from torch._dynamo.utils import counters, dynamo_timed
+from torch._inductor import utils
+from torch._inductor.autoheuristic.autoheuristic import (
+    AHContext,
+    AutoHeuristic,
+    LocalFeedback,
+)
+from torch._inductor.autoheuristic.autoheuristic_utils import (
+    context_add_strides,
+    context_add_using_tf32,
+    pad_mm_operations,
+    pad_mm_precondition,
+)
+from torch._subclasses.fake_tensor import FakeTensor
+from torch.utils._mode_utils import no_dispatch
+
+from ...utils._triton import has_triton
+from ..pattern_matcher import (
+    fwd_only,
+    gen_register_replacement,
+    joint_fwd_bwd,
+    Match,
+    ReplaceFn,
+    SearchFn,
+)
+
+
+aten = torch.ops.aten
+
+
+# This flag is only used for testing purpose.
+# Changing it to True will ignore comparing do_bench times
+# between original pattern and padded one.
+_skip_do_bench_times = False
+
+
+def fetch_fake_tensors(match: Match, kwarg_names: Sequence[str]) -> list[Tensor]:
+    kwargs = match.kwargs
+    return [kwargs[name].meta["val"] for name in kwarg_names]
+
+
+def unwrap_fake_args(
+    *arg_names: str,
+) -> Callable[[Callable[..., Any]], Callable[[Match], Any]]:
+    def decorator(func: Callable[..., Any]) -> Callable[[Match], Any]:
+        def wrapper(match: Match) -> Any:
+            fake_tensors = fetch_fake_tensors(match, arg_names)
+            return func(*fake_tensors)
+
+        return wrapper
+
+    return decorator
+
+
+def get_alignment_size(x: Tensor) -> int:
+    return get_alignment_size_dtype(x.dtype)
+
+
+def get_alignment_size_dtype(dtype: torch.dtype) -> int:
+    if dtype == torch.float16 or dtype == torch.half or dtype == torch.bfloat16:
+        return 8
+    elif dtype == torch.float32 or dtype == torch.float:
+        return 4
+    else:
+        return 0
+
+
+def check_device(a: Tensor, b: Tensor) -> bool:
+    return (a.is_cuda and b.is_cuda) or (a.is_xpu and b.is_xpu)
+
+
+def check_dtype(a: Tensor, b: Tensor) -> bool:
+    return a.is_floating_point() and b.is_floating_point()
+
+
+def should_pad_common(mat1: Tensor, mat2: Tensor, input: Tensor | None = None) -> bool:
+    # It's fine we have symbolic shapes or strides as long as they
+    # have hints. Later, we will make sure we only pad non-symbolic dimensions.
+    def valid_shape_and_stride(t: Tensor | None) -> bool:
+        if t is None:
+            return True
+
+        symbolic_cnt = 0
+        for x in t.size():
+            if isinstance(x, int):
+                continue
+            elif utils.is_symbolic(x):
+                # pyrefly: ignore [missing-attribute]
+                if not x.node.has_hint():
+                    return False
+                symbolic_cnt += 1
+            else:
+                return False
+        # filter out cases where all dimensions are symbolic
+        if symbolic_cnt == len(t.size()):
+            return False
+        return all(
+            # pyrefly: ignore [missing-attribute]
+            isinstance(x, int) or (utils.is_symbolic(x) and x.node.has_hint())
+            for x in t.stride()
+        )
+
+    return (
+        torch._inductor.config.shape_padding
+        and check_device(mat1, mat2)
+        and check_dtype(mat1, mat2)
+        and all(valid_shape_and_stride(t) for t in (mat1, mat2, input))
+    )
+
+
+def get_padded_length(x: int | torch.SymInt, alignment_size: int) -> int:
+    # we don't pad x if it is symbolic
+    if isinstance(x, torch.SymInt) or alignment_size == 0 or x % alignment_size == 0:
+        return 0
+
+    # ignore dim that can be squeezed away
+    if x == 1:
+        return 0
+
+    return int((x // alignment_size + 1) * alignment_size) - x
+
+
+def pad_dim(x: Tensor, padded_length: int, dim: int) -> Tensor:
+    if padded_length == 0:
+        return x
+    pad = x.new_zeros(*x.shape[:dim], padded_length, *x.shape[dim + 1 :])
+    return torch.cat([x, pad], dim=dim)
+
+
+def addmm_pattern(
+    input: Tensor, mat1: Tensor, mat2: Tensor, beta: float, alpha: float
+) -> Tensor:
+    return aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha)
+
+
+def should_pad_addmm(match: Match) -> bool:
+    mat1, mat2, input = fetch_fake_tensors(match, ("mat1", "mat2", "input"))
+    return should_pad_common(mat1, mat2, input) and should_pad_bench(
+        match, mat1, mat2, torch.ops.aten.addmm, input=input
+    )
+
+
+def pad_addmm(
+    input: Tensor | None,
+    mat1: Tensor,
+    mat2: Tensor,
+    m_padded_length: int,
+    k_padded_length: int,
+    n_padded_length: int,
+    beta: float = 1.0,
+    alpha: float = 1.0,
+    mat1_pre_padded: bool = False,
+    mat2_pre_padded: bool = False,
+) -> Tensor:
+    # for paddings, dim order is reversed for some reasons
+    # and for every dim, we need to specify left and right padding
+    if not mat1_pre_padded:
+        mat1 = pad_mat1(
+            mat1, m_padded_length=m_padded_length, k_padded_length=k_padded_length
+        )
+    if not mat2_pre_padded:
+        mat2 = pad_mat2(
+            mat2, k_padded_length=k_padded_length, n_padded_length=n_padded_length
+        )
+
+    # the add broadcasts, so we only pad if the dimension != 1
+    if input is not None:
+        if n_padded_length != 0:
+            if input.dim() == 2 and input.shape[1] != 1:
+                input = pad_dim(input, n_padded_length, 1)
+            elif input.dim() == 1 and input.shape[0] != 1:
+                input = pad_dim(input, n_padded_length, 0)
+        if m_padded_length != 0 and input.dim() == 2 and input.shape[0] != 1:
+            input = pad_dim(input, m_padded_length, 0)
+
+    res = aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha)
+
+    if m_padded_length != 0:
+        res = res[:-m_padded_length, :]
+    if n_padded_length != 0:
+        res = res[:, :-n_padded_length]
+    return res
+
+
+def addmm_replace(
+    input: Tensor | None,
+    mat1: Tensor,
+    mat2: Tensor,
+    beta: float = 1.0,
+    alpha: float = 1.0,
+) -> Tensor:
+    k_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
+    n_padded_length = get_padded_length(mat2.shape[1], get_alignment_size(mat2))
+    m_padded_length = get_padded_length(mat1.shape[0], get_alignment_size(mat1))
+    return pad_addmm(
+        input,
+        mat1,
+        mat2,
+        m_padded_length,
+        k_padded_length,
+        n_padded_length,
+        beta,
+        alpha,
+    )
+
+
+def is_mm_compute_bound(M: int, K: int, N: int, dtype: torch.dtype) -> bool:
+    denominator = M * K + N * K + M * N
+    if denominator == 0:
+        return False
+    arithmetic_intensity = (M * N * K) / denominator
+
+    # we have experienced some large perf hits in this case, even in bandwidth bound regimes
+    if (
+        dtype is torch.bfloat16
+        and K > M
+        and K > N
+        and (torch.xpu.is_available() or torch.cuda.get_device_capability() < (9, 0))
+    ):  # doesn't repro on h100s:
+        return True
+
+    # Fails with AMD
+    try:
+        machine_balance = (
+            1000 * utils.get_device_tflops(dtype)
+        ) / utils.get_gpu_dram_gbps()
+    except Exception:
+        return True
+
+    # dram_gbps might be underestimating bandwidth because of cache.
+    # if we estimate machine balance too low we might miss some speedups,
+    # if we estimate too high there will be unnecessary compilation time increase.
+    # TODO - finetune coefficient here. As a reference point, Triton mm model assumes
+    # 80% of reads are in cache and cache is 4x faster than dram_gbps
+    machine_balance = machine_balance * 0.5
+
+    return arithmetic_intensity > machine_balance
+
+
+@functools.cache
+def get_pad_cache() -> torch._inductor.codecache.LocalCache:
+    return torch._inductor.codecache.LocalCache()
+
+
+def get_cached_should_pad(key: str) -> bool:
+    return get_pad_cache().lookup(key)  # type: ignore[return-value]
+
+
+def set_cached_should_pad(key: str, value: bool) -> None:
+    return get_pad_cache().set_value(key, value=value)
+
+
+def get_cached_base_mm_benchmark_time(key: str) -> float:
+    return get_pad_cache().lookup(key)  # type: ignore[return-value]
+
+
+def set_cached_base_mm_benchmark_time(key: str, value: float) -> None:
+    return get_pad_cache().set_value(key, value=value)
+
+
+def should_pad_bench_key(
+    match: Match,
+    mat1: Tensor,
+    mat2: Tensor,
+    op: torch._ops.OpOverloadPacket,
+    input: Tensor | None = None,
+    is_base_time_key: bool = False,
+) -> str:
+    def tensor_key(t: Tensor) -> tuple[torch.Size, tuple[int, ...], torch.dtype]:
+        return (t.shape, t.stride(), t.dtype)
+
+    tf32_key = (
+        None
+        if mat1.dtype != torch.float32
+        else torch.backends.cuda.matmul.allow_tf32 or torch.backends.mkldnn.allow_tf32
+    )
+
+    def fmt_pad(name: str) -> str | None:
+        if is_base_time_key:
+            return None
+        return f"exclude_pad:{should_exclude_padding_time(match, name)}"
+
+    key = (
+        tensor_key(mat1),
+        tensor_key(mat2),
+        fmt_pad("mat1"),
+        fmt_pad("mat2"),
+        op,
+        input if input is None else tensor_key(input),
+        tf32_key,
+    )
+
+    key = str(key)
+    if is_base_time_key:
+        key = f"base mm time: {key}"
+    return key
+
+
+def get_non_view_def(node: torch.fx.Node) -> torch.fx.Node:
+    if node.op is operator.getitem:
+        return get_non_view_def(node.args[0])  # type: ignore[arg-type]
+
+    if (
+        node.op == "call_function"
+        and isinstance(node.target, torch._ops.OpOverload)
+        and utils.is_view(node.target)
+    ):
+        return get_non_view_def(node.all_input_nodes[0])
+
+    return node
+
+
+def should_exclude_padding_time(match: Match, arg_name: str) -> bool:
+    node_def = get_non_view_def(match.kwargs[arg_name])
+
+    # constant padding converts tensors to contiguous so even if the input tensor
+    # can be planned layout transform is not free. TODO - way to pad and preserve layout ?
+    if not fetch_fake_tensors(match, (arg_name,))[0].is_contiguous():
+        return False
+
+    # TODO - see issue https://github.com/pytorch/pytorch/issues/128889
+    # We would only able to completely plan these out if we were only doing
+    # first dimension padding. non-first we would still need a copy
+    # because these outputs are fixed dense.
+    cannot_plan_output = [
+        aten.mm.default,
+        aten.convolution.default,
+        aten.convolution_backward.default,
+        aten.bmm.default,
+        aten.addmm.default,
+        aten._scaled_dot_product_flash_attention.default,
+        aten._scaled_dot_product_efficient_attention.default,
+    ]
+
+    if node_def.target in cannot_plan_output:
+        return False
+
+    if (
+        node_def.target is aten.cat.default
+        and len(node_def.all_input_nodes)
+        > torch._inductor.config.max_pointwise_cat_inputs
+    ):
+        return False
+
+    # optimistically assume we should be able to memory plan away
+    # all non inputs
+    return node_def.op != "placeholder"
+
+
+def should_pad(key: str, ori_time: float, pad_time: float) -> bool:
+    multiplier = 1.1
+    # Shape padding introduces additional memory ops. Based on microbenchmarks, 1.1x represents a reasonable
+    # tradeoff between performance improvement from shape padding and overhead from additional memory ops
+    # TODO: Build a learned model which would be better than this heuristic
+    if "shape_padding_multiplier" in torch._inductor.config.post_grad_fusion_options:
+        multiplier = torch._inductor.config.post_grad_fusion_options[
+            "shape_padding_multiplier"
+        ].get("value", 1.1)
+        counters["inductor"]["shape_padding_multiplier"] += 1
+    should_pad = _skip_do_bench_times or ori_time > pad_time * multiplier
+    set_cached_should_pad(key, should_pad)
+    return should_pad
+
+
+def should_pad_mm_bf16(dtype: torch.dtype, M: int, N: int, K: int) -> bool:
+    # always force pad for mm with bf16 when the following are satisfied to avoid perf regression
+    large_k_threshold_to_pad = torch._inductor.config.post_grad_fusion_options[
+        "pad_aten_mm_pass"
+    ].get("k_threshold_to_pad", 8388608)
+    if (
+        dtype is torch.bfloat16
+        and K > M
+        and K > N
+        and N % 2 == 1
+        and K >= large_k_threshold_to_pad
+        and (torch.xpu.is_available() or torch.cuda.get_device_capability() < (9, 0))
+    ):  # doesn't repro on h100s:
+        return True
+    return False
+
+
+def should_pad_bench(*args: Any, **kwargs: Any) -> bool:
+    with dynamo_timed(
+        "pad_mm_benchmark",
+        log_pt2_compile_event=False,
+        dynamo_compile_column_us="compile_time_autotune_time_us",
+    ):
+        return _should_pad_bench(*args, **kwargs)
+
+
+def get_do_bench() -> Callable[[Callable[[], Any]], float]:
+    with dynamo_timed("pad_mm_benchmark_get_do_bench"):
+        return functools.partial(
+            # pyrefly: ignore [bad-argument-type]
+            torch._inductor.runtime.benchmarking.benchmarker.benchmark_gpu,
+            warmup=5,
+        )
+
+
+def _should_pad_bench(
+    match: Match,
+    mat1: Tensor,
+    mat2: Tensor,
+    op: torch._ops.OpOverloadPacket,
+    input: Tensor | None = None,
+) -> bool:
+    do_bench = get_do_bench()
+
+    m_padded_length = 0
+    n_padded_length = 0
+    with no_dispatch():
+        if op is torch.ops.aten.mm or op is torch.ops.aten.addmm:
+            m = mat1.shape[0]
+            k = mat1.shape[1]
+            n = mat2.shape[1]
+            k_padded_length = get_padded_length(k, get_alignment_size(mat1))
+            n_padded_length = get_padded_length(n, get_alignment_size(mat2))
+            m_padded_length = get_padded_length(m, get_alignment_size(mat1))
+        elif op is torch.ops.aten.bmm:
+            m = mat1.shape[1]
+            k = mat1.shape[2]
+            n = mat2.shape[2]
+            k_padded_length = get_padded_length(k, get_alignment_size(mat1))
+            m_padded_length = get_padded_length(m, get_alignment_size(mat1))
+            n_padded_length = get_padded_length(n, get_alignment_size(mat2))
+        else:
+            return False
+
+        if m_padded_length == k_padded_length == n_padded_length == 0:
+            return False
+
+        def realize_symbols(
+            ds: torch.Size | tuple[torch.SymInt, ...],
+        ) -> list[int]:
+            return [d if isinstance(d, int) else d.node.hint for d in ds]
+
+        if any(
+            dim == 0
+            for dim in itertools.chain(
+                realize_symbols(mat1.shape), realize_symbols(mat2.shape)
+            )
+        ):
+            return False
+
+        if torch._inductor.config.force_shape_pad:
+            return True
+
+        if torch._inductor.config.deterministic:
+            # In deterministic mode, don't benchmark for pad-mm and assumes
+            # no padding.
+            #
+            # Check the deterministic mode after 'force_shape_pad'
+            # so unit test relying on force_shape_pad should still pass
+            return False
+
+        if (
+            "pad_aten_mm_pass" in torch._inductor.config.post_grad_fusion_options
+            and should_pad_mm_bf16(mat1.dtype, m, n, k)
+        ):
+            return True
+
+        if not has_triton():
+            return False
+
+        if not is_mm_compute_bound(m, k, n, mat1.dtype):
+            return False
+
+        # We don't want to look up the cache for cases that are trivially false
+        # since it does file io
+        key = should_pad_bench_key(match, mat1, mat2, op, input)
+
+        cached_pad = get_cached_should_pad(key)
+        if cached_pad is not None:
+            return cached_pad
+
+        def realize_tensor(t):
+            if isinstance(t, FakeTensor):
+                size_hints = realize_symbols(t.size())
+                # pyrefly: ignore [bad-argument-type]
+                stride_hint = realize_symbols(t.stride())
+                real_size = (
+                    sum((d - 1) * s for d, s in zip(size_hints, stride_hint)) + 1
+                )
+                real_t = torch.randn(real_size, dtype=t.dtype, device=t.device)
+                return torch.as_strided(real_t, size_hints, stride_hint)
+            else:
+                return torch.randn_like(t)
+
+        mat1 = realize_tensor(mat1)
+        mat2 = realize_tensor(mat2)
+
+        # since we key on whether or not the inputs can be memory planned, set cache for the
+        # original time which is unaffected by whether or not the input can be planned
+        ori_time_key = should_pad_bench_key(
+            match, mat1, mat2, op, input, is_base_time_key=True
+        )
+        ori_time = get_cached_base_mm_benchmark_time(ori_time_key)
+        if ori_time is None and op is torch.ops.aten.addmm and input is not None:
+            # realize bias for addmm
+            input = realize_tensor(input)
+
+        mat1_pad = mat1
+        mat2_pad = mat2
+
+        is_bmm = op is torch.ops.aten.bmm
+
+        mat1_pre_padded = should_exclude_padding_time(match, "mat1")
+        fns = []
+        if mat1_pre_padded and (m_padded_length or k_padded_length):
+            mat1_pad = pad_mat1(
+                mat1_pad,
+                m_padded_length=m_padded_length,
+                k_padded_length=k_padded_length,
+                is_bmm=is_bmm,
+            )
+
+            def write_pad():
+                if is_bmm:
+                    mat1_pad[:, -m_padded_length:, -k_padded_length:].fill_(0)
+                else:
+                    mat1_pad[-m_padded_length:, -k_padded_length:].fill_(0)
+
+            fns.append(write_pad)
+
+        mat2_pre_padded = should_exclude_padding_time(match, "mat2")
+        if mat2_pre_padded and (k_padded_length or n_padded_length):
+            mat2_pad = pad_mat2(
+                mat2_pad,
+                k_padded_length=k_padded_length,
+                n_padded_length=n_padded_length,
+                is_bmm=is_bmm,
+            )
+
+            def write_pad():
+                if is_bmm:
+                    mat2_pad[:, -k_padded_length:, -n_padded_length:].fill_(0)
+                else:
+                    mat2_pad[-k_padded_length:, -n_padded_length:].fill_(0)
+
+            fns.append(write_pad)
+
+        if op is torch.ops.aten.addmm:
+            input_pad = None
+            if input is not None and (input.is_cuda or input.is_xpu):
+                input_pad = torch.randn_like(input)
+            fns.append(
+                lambda: pad_addmm(
+                    input_pad,
+                    mat1_pad,
+                    mat2_pad,
+                    m_padded_length,
+                    k_padded_length,
+                    n_padded_length,
+                    mat1_pre_padded=mat1_pre_padded,
+                    mat2_pre_padded=mat2_pre_padded,
+                )
+            )
+        elif op is torch.ops.aten.mm:
+            fns.append(
+                lambda: pad_mm(
+                    mat1_pad,
+                    mat2_pad,
+                    m_padded_length,
+                    k_padded_length,
+                    n_padded_length,
+                    mat1_pre_padded=mat1_pre_padded,
+                    mat2_pre_padded=mat2_pre_padded,
+                )
+            )
+        else:
+            fns.append(
+                lambda: pad_bmm(
+                    mat1_pad,
+                    mat2_pad,
+                    m_padded_length,
+                    k_padded_length,
+                    n_padded_length,
+                    mat1_pre_padded=mat1_pre_padded,
+                    mat2_pre_padded=mat2_pre_padded,
+                )
+            )
+
+        def orig_bench_fn():
+            if op is torch.ops.aten.bmm or op is torch.ops.aten.mm:
+                op(mat1, mat2)
+            else:
+                op(input, mat1, mat2)
+
+        def pad_bench_fn():
+            for fn in fns:
+                fn()
+
+        if (
+            torch._inductor.config.run_autoheuristic("pad_mm")
+            and op is torch.ops.aten.mm
+        ):
+            ah_should_pad = run_autoheuristic(
+                mat1,
+                mat2,
+                orig_bench_fn,
+                pad_bench_fn,
+                m_padded_length,
+                k_padded_length,
+                n_padded_length,
+                do_bench,
+                mat1_pre_padded,
+                mat2_pre_padded,
+                ori_time,
+                ori_time_key,
+                key,
+            )
+            if ah_should_pad is not None:
+                return ah_should_pad
+
+        if ori_time is None:
+            ori_time = do_bench(orig_bench_fn)
+            set_cached_base_mm_benchmark_time(ori_time_key, ori_time)
+
+        pad_time = do_bench(pad_bench_fn)
+
+        counters["inductor"]["pad_mm_bench"] += 1
+        return should_pad(key, ori_time, pad_time)
+
+
+def get_context(
+    mat1: Tensor,
+    mat2: Tensor,
+    mat1_pre_padded: bool,
+    mat2_pre_padded: bool,
+    m_padded_length: int,
+    k_padded_length: int,
+    n_padded_length: int,
+) -> AHContext:
+    context = AHContext()
+
+    context.add_feature("m", mat1.shape[0])
+    context.add_feature("k", mat1.shape[1])
+    context.add_feature("n", mat2.shape[1])
+
+    context_add_strides(context, "mat1", mat1.stride())
+    context_add_strides(context, "mat2", mat2.stride())
+
+    context.add_feature("m_padded_length", m_padded_length)
+    context.add_feature("k_padded_length", k_padded_length)
+    context.add_feature("n_padded_length", n_padded_length)
+
+    context.add_feature("mat1_align_size", get_alignment_size(mat1))
+    context.add_feature("mat2_align_size", get_alignment_size(mat2))
+
+    context.add_feature("mat1_dtype", mat1.dtype, is_categorical=True)
+    context.add_feature("mat2_dtype", mat2.dtype, is_categorical=True)
+
+    context.add_feature("prepadded_mat1", mat1_pre_padded, is_categorical=True)
+    context.add_feature("prepadded_mat2", mat2_pre_padded, is_categorical=True)
+
+    context_add_using_tf32(context, mat1.dtype)
+    return context
+
+
+def run_autoheuristic(
+    mat1: Tensor,
+    mat2: Tensor,
+    orig_bench_fn: Callable[[], None],
+    pad_bench_fn: Callable[[], None],
+    m_padded_length: int,
+    k_padded_length: int,
+    n_padded_length: int,
+    do_bench: Callable[[Callable[[], Any]], float],
+    mat1_pre_padded: bool,
+    mat2_pre_padded: bool,
+    ori_time: float,
+    ori_time_key: str,
+    key: str,
+) -> bool | None:
+    def feedback_fn(
+        choice: str,
+    ) -> float | None:
+        if choice == orig_choice:
+            return do_bench(orig_bench_fn)
+        elif choice == pad_choice:
+            return do_bench(pad_bench_fn)
+        return None
+
+    def fallback() -> str:
+        return "autotune"
+
+    orig_choice = "orig"
+    pad_choice = "pad"
+    choices = [orig_choice, pad_choice]
+    feedback = LocalFeedback(feedback_fn)  # type: ignore[arg-type]
+    context = get_context(
+        mat1,
+        mat2,
+        mat1_pre_padded,
+        mat2_pre_padded,
+        m_padded_length,
+        k_padded_length,
+        n_padded_length,
+    )
+    name = "pad_mm"
+    autoheuristic = AutoHeuristic(
+        fallback=fallback,
+        choices=choices,
+        feedback=feedback,
+        context=context,
+        name=name,
+        augment_context=pad_mm_operations(),
+        precondition=pad_mm_precondition,
+    )
+    choice = autoheuristic.get_choice()
+    choice2should_pad = {orig_choice: False, pad_choice: True, "autotune": None}
+    ah_should_pad = choice2should_pad.get(choice)
+
+    if torch._inductor.config.collect_autoheuristic(name):
+        ah_ori_time = autoheuristic.get_collected_feedback(orig_choice)
+        ah_pad_time = autoheuristic.get_collected_feedback(pad_choice)
+
+        # if precondition is not satisfied, autoheuristic does not collect data
+        if ah_ori_time is not None and ah_pad_time is not None:
+            if ori_time is None:
+                set_cached_base_mm_benchmark_time(ori_time_key, ah_ori_time)
+            return should_pad(key, ah_ori_time, ah_pad_time)
+    if ah_should_pad is not None:
+        set_cached_should_pad(key, ah_should_pad)
+    return ah_should_pad
+
+
+def mm_pattern(mat1: Tensor, mat2: Tensor) -> Tensor:
+    return aten.mm(mat1, mat2)
+
+
+def should_pad_mm(match: Match) -> bool:
+    mat1, mat2 = fetch_fake_tensors(match, ("mat1", "mat2"))
+    return should_pad_common(mat1, mat2) and should_pad_bench(
+        match, mat1, mat2, torch.ops.aten.mm
+    )
+
+
+def pad_mat1(
+    mat1: Tensor, *, m_padded_length: int, k_padded_length: int, is_bmm: bool = False
+) -> Tensor:
+    if k_padded_length != 0 or m_padded_length != 0:
+        # dim order is reversed for constant_pad_nd, for every dim we specify right and left padding
+        pad_arg = [0, k_padded_length, 0, m_padded_length]
+        if is_bmm:
+            pad_arg.extend((0, 0))
+        return aten.constant_pad_nd(mat1, pad_arg)
+    else:
+        return mat1
+
+
+def pad_mat2(
+    mat2: Tensor, *, k_padded_length: int, n_padded_length: int, is_bmm: bool = False
+) -> Tensor:
+    if k_padded_length != 0 or n_padded_length != 0:
+        # dim order is reversed for constant_pad_nd, for every dim we specify right and left padding
+        pad_arg = [0, n_padded_length, 0, k_padded_length]
+        if is_bmm:
+            pad_arg.extend((0, 0))
+        return aten.constant_pad_nd(mat2, pad_arg)
+    else:
+        return mat2
+
+
+def pad_mm(
+    mat1: Tensor,
+    mat2: Tensor,
+    m_padded_length: int,
+    k_padded_length: int,
+    n_padded_length: int,
+    mat1_pre_padded: bool = False,
+    mat2_pre_padded: bool = False,
+) -> Tensor:
+    if not mat1_pre_padded:
+        mat1 = pad_mat1(
+            mat1, m_padded_length=m_padded_length, k_padded_length=k_padded_length
+        )
+    if not mat2_pre_padded:
+        mat2 = pad_mat2(
+            mat2, k_padded_length=k_padded_length, n_padded_length=n_padded_length
+        )
+    res = aten.mm(mat1, mat2)
+    if m_padded_length != 0:
+        res = res[:-m_padded_length, :]
+    if n_padded_length != 0:
+        res = res[:, :-n_padded_length]
+    return res
+
+
+def mm_replace(mat1: Tensor, mat2: Tensor) -> Tensor:
+    k_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
+    m_padded_length = get_padded_length(mat1.shape[0], get_alignment_size(mat1))
+    n_padded_length = get_padded_length(mat2.shape[1], get_alignment_size(mat2))
+    return pad_mm(
+        mat1,
+        mat2,
+        m_padded_length,
+        k_padded_length,
+        n_padded_length,
+    )
+
+
+def bmm_pattern(mat1: Tensor, mat2: Tensor) -> Tensor:
+    return aten.bmm(mat1, mat2)
+
+
+def should_pad_bmm(match: Match) -> bool:
+    mat1, mat2 = fetch_fake_tensors(match, ("mat1", "mat2"))
+    return should_pad_common(mat1, mat2) and should_pad_bench(
+        match, mat1, mat2, torch.ops.aten.bmm
+    )
+
+
+def pad_bmm(
+    mat1: Tensor,
+    mat2: Tensor,
+    m_padded_length: int,
+    k_padded_length: int,
+    n_padded_length: int,
+    mat1_pre_padded: bool = False,
+    mat2_pre_padded: bool = False,
+) -> Tensor:
+    if not mat1_pre_padded:
+        mat1 = pad_mat1(
+            mat1,
+            m_padded_length=m_padded_length,
+            k_padded_length=k_padded_length,
+            is_bmm=True,
+        )
+    if not mat2_pre_padded:
+        mat2 = pad_mat2(
+            mat2,
+            k_padded_length=k_padded_length,
+            n_padded_length=n_padded_length,
+            is_bmm=True,
+        )
+    res = aten.bmm(mat1, mat2)
+    if m_padded_length != 0:
+        res = res[:, :-m_padded_length, :]
+    if n_padded_length != 0:
+        res = res[:, :, :-n_padded_length]
+    return res
+
+
+def bmm_replace(mat1: Tensor, mat2: Tensor) -> Tensor:
+    k_padded_length = get_padded_length(mat1.shape[2], get_alignment_size(mat1))
+    n_padded_length = get_padded_length(mat2.shape[2], get_alignment_size(mat2))
+    m_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
+    return pad_bmm(
+        mat1,
+        mat2,
+        m_padded_length,
+        k_padded_length,
+        n_padded_length,
+    )
+
+
+@functools.cache
+def _pad_mm_init() -> None:
+    from .joint_graph import patterns
+
+    if torch.cuda.is_available():
+        # workaround https://github.com/pytorch/pytorch/issues/97894
+        device = "cuda"
+    elif torch.xpu.is_available():
+        device = "xpu"
+    else:
+        device = "cpu"
+
+    # sizes/values dont actually matter for initial trace
+    # once we get a possible match we re-trace with the actual values and verify the match still holds
+
+    dim2a = functools.partial(torch.empty, (4, 4), device=device, requires_grad=True)
+    dim2b = functools.partial(torch.empty, (4, 4), device=device, requires_grad=True)
+
+    dim3a = functools.partial(torch.empty, (4, 4, 4), device=device, requires_grad=True)
+    dim3b = functools.partial(torch.empty, (4, 4, 4), device=device, requires_grad=True)
+
+    dim1a = functools.partial(torch.empty, (4), device=device, requires_grad=True)
+
+    # workaround https://github.com/pytorch/pytorch/issues/97894
+    # 0.113377 is a "magic" value that lets us recover the lost input arg relationship
+    rep = {"beta": 0.213377, "alpha": 0.113377}
+
+    for pattern, replacement, args, workaround, extra_check in [
+        (
+            typing.cast(SearchFn, mm_pattern),
+            typing.cast(ReplaceFn, mm_replace),
+            [dim2a(), dim2b()],
+            {},
+            should_pad_mm,
+        ),
+        (
+            typing.cast(SearchFn, bmm_pattern),
+            typing.cast(ReplaceFn, bmm_replace),
+            [dim3a(), dim3b()],
+            {},
+            should_pad_bmm,
+        ),
+        (
+            typing.cast(SearchFn, addmm_pattern),
+            typing.cast(ReplaceFn, addmm_replace),
+            [dim1a(), dim2a(), dim2b()],
+            rep,
+            should_pad_addmm,
+        ),
+    ]:
+        assert isinstance(workaround, dict)  # mypy is unable to infer the type properly
+        name = pattern.__name__
+
+        gen_register_replacement(
+            f"{name}_training",
+            pattern,
+            replacement,
+            args,
+            # pyrefly: ignore [bad-argument-type]
+            joint_fwd_bwd,
+            # pyrefly: ignore [bad-argument-type]
+            patterns,
+            extra_check=extra_check,
+            scalar_workaround=workaround,
+        )
+
+        gen_register_replacement(
+            f"{name}_inference",
+            pattern,
+            replacement,
+            args,
+            # pyrefly: ignore [bad-argument-type]
+            fwd_only,
+            # pyrefly: ignore [bad-argument-type]
+            patterns,
+            extra_check=extra_check,
+            scalar_workaround=workaround,
+        )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/post_grad.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/post_grad.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a350b81bbecb47b044c3805bd8af82b04531d45
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/post_grad.py
@@ -0,0 +1,1923 @@
+# mypy: allow-untyped-decorators
+# mypy: allow-untyped-defs
+import functools
+import itertools
+import logging
+import operator
+from collections import Counter, defaultdict
+from collections.abc import Callable
+from typing import Any, TypeVar
+from typing_extensions import ParamSpec
+
+import torch
+import torch._inductor as inductor
+import torch.utils._pytree as pytree
+from torch import fx
+from torch._decomp import register_decomposition
+from torch._dynamo.utils import counters
+from torch._inductor import comms
+from torch._inductor.virtualized import ops  # noqa: F401
+from torch._logging import trace_structured
+from torch._prims_common import is_boolean_dtype, is_expandable_to, is_integer_dtype
+from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq
+from torch.utils._ordered_set import OrderedSet
+
+from .. import config, ir, pattern_matcher  # noqa: F401
+from ..codegen.common import custom_backend_passes
+from ..comms import remove_fsdp2_unsharded_param_graph_input_usage
+from ..fx_utils import FakeTensorUpdater, get_fake_args_kwargs, get_node_storage
+from ..lowering import lowerings as L
+from ..pattern_matcher import (
+    _return_true,
+    Arg,
+    CallFunction,
+    CallFunctionVarArgs,
+    filter_nodes,
+    fwd_only,
+    get_arg_value,
+    get_mutation_region_id,
+    Ignored,
+    init_once_fakemode,
+    KeywordArg,
+    ListOf,
+    Match,
+    MultiOutputPattern,
+    MULTIPLE,
+    PatternMatcherPass as PatternMatcherPassBase,
+    register_graph_pattern,
+    register_replacement,
+    stable_topological_sort,
+)
+from ..utils import (
+    decode_device,
+    get_all_devices,
+    get_gpu_type,
+    is_gpu,
+    is_pointwise_use,
+    OPTIMUS_EXCLUDE_POST_GRAD,
+)
+from ..virtualized import V
+from .b2b_gemm import B2B_GEMM_PASS
+from .ddp_fusion import fuse_ddp_communication
+from .group_batch_fusion import group_batch_fusion_passes, POST_GRAD_FUSIONS
+from .micro_pipeline_tp import micro_pipeline_tp_pass
+from .pre_grad import is_same_dict, save_inductor_dict
+from .reinplace import reinplace_inplaceable_ops
+from .split_cat import POST_GRAD_PATTERNS
+
+
+_T = TypeVar("_T")
+_P = ParamSpec("_P")
+
+PatternMatcherPass = functools.partial(
+    PatternMatcherPassBase, subsystem="post_grad_passes"
+)
+
+log = logging.getLogger(__name__)
+aten = torch.ops.aten
+prims = torch.ops.prims
+
+# First pass_patterns[0] are applied, then [1], then [2]
+pass_patterns = [
+    PatternMatcherPass(),
+    PatternMatcherPass(),
+    PatternMatcherPass(),
+]
+
+
+def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
+    """
+    Passes that run on after grad.  This is called once on the forwards
+    graph and once on the backwards graph.
+
+    The IR here has been normalized and functionalized.
+    """
+    GraphTransformObserver = functools.partial(
+        torch.fx.passes.graph_transform_observer.GraphTransformObserver,
+        subsystem="post_grad_passes",
+    )
+
+    if not torch._dynamo.config.skip_fsdp_hooks:
+        remove_fsdp2_unsharded_param_graph_input_usage(gm.graph)
+
+    if config.dce:
+        # has some issues with mutation in inference mode
+        gm.graph.eliminate_dead_code()
+
+    if is_inference and config.reorder_for_locality:
+        GraphTransformObserver(gm, "reorder_for_locality").apply_graph_pass(
+            reorder_for_locality
+        )
+
+    fake_tensor_updater = FakeTensorUpdater(gm.graph)
+
+    if post_grad_custom_pre_pass := config.post_grad_custom_pre_pass:
+        GraphTransformObserver(gm, "post_grad_custom_pre_pass").apply_graph_pass(
+            post_grad_custom_pre_pass
+        )
+
+    if torch._C._has_mkldnn:
+        if (
+            config.cpp.enable_grouped_gemm_template
+            and config.max_autotune
+            and "CPP" in config.max_autotune_gemm_backends
+        ):
+            from .mkldnn_fusion import grouped_gemm_pass
+
+            grouped_gemm_pass(gm.graph)
+
+        if config.cpp.enable_concat_linear:
+            from .quantization import concat_linear_woq_int4
+
+            # Concat linear optimization for WOQ int4
+            concat_linear_woq_int4(gm)
+
+    if config.pattern_matcher:
+        lazy_init()
+        GraphTransformObserver(gm, "post_grad_custom_pre_pass").apply_graph_pass(
+            functools.partial(group_batch_fusion_passes, pre_grad=False)
+        )
+        GraphTransformObserver(gm, "remove_noop_ops").apply_graph_pass(remove_noop_ops)
+        GraphTransformObserver(gm, "remove_assert_ops").apply_graph_pass(
+            remove_assert_ops
+        )
+        for i, patterns in enumerate(pass_patterns):
+            GraphTransformObserver(gm, f"pass_pattern_{i}").apply_graph_pass(
+                patterns.apply
+            )
+        for pass_name in config.post_grad_fusion_options:
+            # skip all patterns for group batch fusions or quantization patterns
+            if pass_name in POST_GRAD_FUSIONS or pass_name in OPTIMUS_EXCLUDE_POST_GRAD:
+                continue
+            pattern_matcher_pass = POST_GRAD_PATTERNS[pass_name]
+            inductor_before_change = save_inductor_dict(
+                [pattern_matcher_pass.pass_name]
+            )
+            GraphTransformObserver(gm, pass_name).apply_graph_pass(
+                pattern_matcher_pass.apply
+            )
+            if not is_same_dict(counters["inductor"], inductor_before_change):
+                trace_structured(
+                    "artifact",
+                    metadata_fn=lambda: {
+                        "name": f"{pattern_matcher_pass.pass_name}_post_grad",
+                        "encoding": "string",
+                    },
+                    payload_fn=lambda: gm.print_readable(
+                        print_output=False, include_stride=True, include_device=True
+                    ),
+                )
+        if config.b2b_gemm_pass:
+            B2B_GEMM_PASS.apply(gm.graph)  # type: ignore[arg-type]
+
+    if config._micro_pipeline_tp:
+        micro_pipeline_tp_pass(gm.graph)
+
+    if config._fuse_ddp_communication:
+        GraphTransformObserver(gm, "fuse_ddp_communication").apply_graph_pass(
+            lambda graph: fuse_ddp_communication(
+                graph,
+                config._fuse_ddp_communication_passes,
+                config._fuse_ddp_bucket_size,
+            )
+        )
+
+    if post_grad_custom_post_pass := config.post_grad_custom_post_pass:
+        GraphTransformObserver(gm, "post_grad_custom_post_pass").apply_graph_pass(
+            post_grad_custom_post_pass
+        )
+
+    GraphTransformObserver(gm, "stable_sort").apply_graph_pass(stable_topological_sort)
+
+    GraphTransformObserver(gm, "move_constructors_to_cuda").apply_graph_pass(
+        move_constructors_to_gpu
+    )
+
+    fake_tensor_updater.incremental_update()
+
+    for device, custom_backend_pass in custom_backend_passes.items():
+        if custom_backend_pass is not None:
+            gm_devices = [d.type for d in get_all_devices(gm)]
+            if device in gm_devices:
+                pass_name = "custom_backend_passes_" + device
+                GraphTransformObserver(gm, pass_name).apply_gm_pass(custom_backend_pass)
+
+    collectives_bucketing: bool = False
+
+    if config.bucket_reduce_scatters_fx != "none":
+        from torch._inductor.fx_passes.bucketing import bucket_reduce_scatter
+        from torch._inductor.fx_passes.fsdp import bucket_fsdp_reduce_scatter
+
+        p = (
+            bucket_fsdp_reduce_scatter
+            if "fsdp" in config.bucket_reduce_scatters_fx
+            else bucket_reduce_scatter
+        )
+        GraphTransformObserver(gm, "bucket_reduce_scatters").apply_graph_pass(
+            lambda graph: p(
+                graph.owning_module,
+                config.bucket_reduce_scatters_fx_bucket_size_determinator,
+                config.bucket_reduce_scatters_fx,  # type: ignore[arg-type]
+            )
+        )
+        collectives_bucketing = True
+
+    if config.bucket_all_reduces_fx != "none":
+        from torch._inductor.fx_passes.bucketing import bucket_all_reduce
+
+        GraphTransformObserver(gm, "bucket_all_reduce").apply_graph_pass(
+            lambda graph: bucket_all_reduce(
+                graph.owning_module,
+                config.bucket_all_reduces_fx_bucket_size_determinator,
+                config.bucket_all_reduces_fx,  # type: ignore[arg-type]
+            )
+        )
+        collectives_bucketing = True
+
+    # Fx all_gather bucketing introduces mutation op
+    # Keeping it in the end to keep invariant of functional graph for previous passes.
+    if config.bucket_all_gathers_fx != "none":
+        from torch._inductor.fx_passes.bucketing import bucket_all_gather
+        from torch._inductor.fx_passes.fsdp import bucket_fsdp_all_gather
+
+        p = (
+            bucket_fsdp_all_gather  # type: ignore[assignment]
+            if "fsdp" in config.bucket_all_gathers_fx
+            else bucket_all_gather
+        )
+        GraphTransformObserver(gm, "bucket_all_gathers").apply_graph_pass(
+            lambda graph: p(
+                graph.owning_module,
+                config.bucket_all_gathers_fx_bucket_size_determinator,
+                config.bucket_all_gathers_fx,  # type: ignore[arg-type]
+            )
+        )
+        collectives_bucketing = True
+
+    if collectives_bucketing:
+        # Fx collectives bucketing passes require topological sort for the cases:
+        # when bucketed collectives have users before the last collective in the bucket
+        # AND when inputs of bucketed collective have ancestors after the first collective in the bucket.
+        #
+        # In this case we can not manually pick the place for bucketed collective insertion.
+        # But we are guaranteed by the bucketing (independent collectives in the bucket),
+        # that it is possible to reorder nodes to satisfy all ordering requirements.
+        #
+        # --- before bucketing ---
+        # in0 = ...
+        # wait_ag0 = ag(in0)
+        # user0(wait_ag0)
+        # ...
+        # pre_in1 = ...
+        # in1 = transform(pre_in1)
+        # wait_ag1 = ag(in1)
+        # user1(wait_ag1)
+        #
+        # --- after bucketing ---
+        #
+        # in0 = ...
+        # user(wait_ag0) <--- wait_ag0 is defined only after bucketed collective.
+        #
+        # pre_in1 = ...
+        # in1 = transform(pre_in1)
+        # ag_bucket(in0+in1)
+        # wait_bucket
+        # wait_ag0 = wait_bucket[0]
+        # wait_ag1 = wait_bucket[1]
+        # user1(wait_ag1)
+        stable_topological_sort(gm.graph)
+
+    # Apply overlap scheduling if enabled
+    if config.aten_distributed_optimizations.enable_overlap_scheduling:
+        from torch._inductor.fx_passes.overlap_scheduling import (
+            schedule_overlap_bucketing_from_inductor_configs,
+        )
+
+        overlap_deps = config.aten_distributed_optimizations.insert_overlap_deps
+
+        # by default, insert overlap deps within inductor
+        with config.patch(
+            "aten_distributed_optimizations.insert_overlap_deps",
+            True if overlap_deps is None else overlap_deps,
+        ):
+            GraphTransformObserver(gm, "overlap_scheduling").apply_graph_pass(
+                lambda graph: schedule_overlap_bucketing_from_inductor_configs(
+                    graph.owning_module
+                )
+            )
+
+    # Keep these last, since they introduce mutation. Look at
+    # ./fx_passes/README.md for a discussion of mutation invariants.
+    GraphTransformObserver(gm, "reinplace_inplaceable_ops").apply_graph_pass(
+        functools.partial(reinplace_inplaceable_ops, fake_tensor_updater),
+    )
+    GraphTransformObserver(
+        gm, "decompose_triton_kernel_wrapper_functional"
+    ).apply_graph_pass(decompose_triton_kernel_wrapper_functional)
+    GraphTransformObserver(gm, "decompose_auto_functionalized").apply_graph_pass(
+        decompose_auto_functionalized
+    )
+    if not torch._dynamo.config.skip_fsdp_hooks:
+        GraphTransformObserver(gm, "reinplace_fsdp_all_gather").apply_graph_pass(
+            comms.reinplace_fsdp_all_gather
+        )
+    GraphTransformObserver(gm, "decompose_scan_to_while_loop").apply_gm_pass(
+        decompose_scan_to_while_loop
+    )
+    GraphTransformObserver(gm, "decompose_map_to_while_loop").apply_gm_pass(
+        decompose_map_to_while_loop
+    )
+
+    gm.recompile()
+    gm.graph.lint()
+
+
+def prepare_softmax_pattern(x, dim):
+    xmax = x.amax(dim=dim, keepdim=True)
+    xsub = x - xmax
+    xexp = xsub.exp()
+    xsum = xexp.sum(dim=dim, keepdim=True)
+    return xmax, xsum, xsub, xexp
+
+
+def prepare_softmax_replacement(x, dim):
+    """
+    Return xsub since otherwise log-softmax can not be matched
+    due to a use of this intermediate node. Same reason to return
+    xsub.exp() for softmax.
+    """
+    from torch._inductor.inductor_prims import prepare_softmax_online
+
+    xmax, xsum = prepare_softmax_online(x, dim)
+    xsub = x - xmax
+    return xmax, xsum, xsub, xsub.exp()
+
+
+def prepare_softmax_extra_check(match):
+    """
+    We only have triton online softmax kernels currently.
+    """
+    device_type = match.kwargs["x"].meta["val"].device.type
+    return (
+        config.online_softmax
+        and device_type in ["cuda", "xpu"]
+        and getattr(config, f"{device_type}_backend") == "triton"
+    )
+
+
+def decompose_map_to_while_loop(gm: torch.fx.GraphModule):
+    """This is similar to decompose_scan_to_while_loop."""
+    graph_pass = PatternMatcherPass()
+
+    @register_graph_pattern(
+        CallFunctionVarArgs(torch.ops.higher_order.map_impl),
+        # pyrefly: ignore [bad-argument-type]
+        pass_dict=graph_pass,
+    )
+    def _(match: Match, *args, **kwargs):
+        assert len(kwargs) == 0, (
+            "kwargs of map are not merged into args before entering decompose_map_to_while_loop_pass"
+        )
+        subgraph, fx_xs, fx_additional_inputs = args
+        sub_gm: torch.fx.GraphModule = getattr(gm, subgraph.target)
+        cur_node = match.nodes[0]
+        mapped_outputs = cur_node.meta["val"]
+
+        def lower_to_while_loop(*args, **kwargs):
+            assert len(kwargs) == 0
+            xs, additional_inputs = pytree.tree_unflatten(args, tree_spec)
+            assert isinstance(xs, (tuple, list)) and isinstance(
+                additional_inputs, (tuple, list)
+            ), (xs, additional_inputs)
+            map_length = xs[0].size(0)
+            loop_idx = torch.zeros([], dtype=torch.int64, device=torch.device("cpu"))
+
+            # Similar to NOTE [Pre-allocate scan's output buffer]
+            bound_symbols = {
+                arg.node.expr: arg
+                for arg in pytree.tree_leaves((args, map_length))
+                if isinstance(arg, torch.SymInt)
+            }
+            out_buffers = [
+                torch.empty_strided(
+                    resolve_shape_to_proxy(out.size(), bound_symbols),
+                    resolve_shape_to_proxy(out.stride(), bound_symbols),
+                    device=out.device,
+                    dtype=out.dtype,
+                    layout=out.layout,
+                    requires_grad=out.requires_grad,
+                )
+                for out in mapped_outputs
+            ]
+
+            while_loop_operands = (loop_idx, out_buffers, xs)
+            while_loop_flat_operands, operands_spec = pytree.tree_flatten(
+                while_loop_operands
+            )
+            while_loop_additional_inputs = additional_inputs
+            _, operands_and_additional_inputs_spec = pytree.tree_flatten(
+                (*while_loop_operands, additional_inputs)
+            )
+
+            def cond_fn(*flat_args):
+                loop_idx, _, _, _ = pytree.tree_unflatten(
+                    flat_args,
+                    operands_and_additional_inputs_spec,
+                )
+                return loop_idx < map_length
+
+            def body_fn(*flat_args):
+                loop_idx, out_bufs, xs, additional_inputs = pytree.tree_unflatten(
+                    flat_args,
+                    operands_and_additional_inputs_spec,
+                )
+
+                idx_int = loop_idx.item()
+                torch.ops.aten._assert_scalar.default(idx_int >= 0, "")
+                torch.ops.aten._assert_scalar.default(idx_int < map_length, "")
+                sub_xs = [torch.ops.aten.select.int(x, 0, idx_int) for x in xs]
+                outs = sub_gm(*sub_xs, *additional_inputs)
+
+                for out, buffer in zip(outs, out_bufs):
+                    buffer_slice = torch.ops.aten.select.int(buffer, 0, idx_int)
+                    buffer_slice.copy_(out)
+                return loop_idx + 1, *out_bufs, *xs
+
+            _, final_out, _ = pytree.tree_unflatten(
+                torch.ops.higher_order.while_loop(
+                    cond_fn,
+                    body_fn,
+                    tuple(while_loop_flat_operands),
+                    tuple(while_loop_additional_inputs),
+                ),
+                operands_spec,
+            )
+            return (final_out,)
+
+        lower_to_while_loop_args, tree_spec = pytree.tree_flatten(
+            (fx_xs, fx_additional_inputs)
+        )
+        match.replace_by_example(
+            lower_to_while_loop, lower_to_while_loop_args, run_functional_passes=False
+        )
+
+    graph_pass.apply(gm)
+
+    for _node in gm.graph.find_nodes(
+        op="call_function", target=torch.ops.higher_order.map_impl
+    ):
+        raise AssertionError("map is not lowered to while_loop")
+
+
+def resolve_shape_to_proxy(
+    shape: list[int | torch.SymInt], bound_symbols: dict[Any, Any]
+):
+    """
+    Given a list of symints/ints, this function returns a calculated expression of bound_symbols' values.
+    When we trace this function, we'll get a graph with call_function nodes that describes how the shape expr is
+    computed from bound_symbols' values.
+
+    Suppose shape = (s1*s2, s1+s2) and bound_symbols = {s1: arg0, s2: arg1}, the result will be
+    (arg0 * arg1, arg0 + arg1).
+    """
+    from torch.utils._sympy.interp import sympy_interp
+    from torch.utils._sympy.reference import PythonReferenceAnalysis
+
+    ret = []
+    for s in shape:
+        if isinstance(s, torch.SymInt):
+            ret.append(
+                sympy_interp(
+                    PythonReferenceAnalysis,
+                    bound_symbols,
+                    s.node.expr,
+                ),
+            )
+        else:
+            assert isinstance(s, int)
+            ret.append(s)
+    return ret
+
+
+def decompose_scan_to_while_loop(gm: torch.fx.GraphModule):
+    """
+    NOTE [decompose scan to while_loop]
+    This pass decomposes `scan` to  `while_loop` by replacing the scan fx_node with a while_loop hop.
+
+    Suppose we have a function f:
+
+        def f():
+            init = torch.zeros([])
+            xs = torch.arange(4)
+            ys = []
+            for i in range(xs.size(0)):
+                init = xs[i] + init
+                ys.append(init)
+
+            # Return the final carry and stack the intermediates
+            return init, torch.stack(ys)
+
+    We could rewrite it with a scan with the benefits of reducing compilation time/binary size, reducing
+    memory usage, supporting loops over unbacked shapes and cudagraph etc.
+
+        def g():
+            def step_fn(init: torch.Tensor, x: torch.Tensor):
+                next_init = x + init
+                return next_init, next_init
+
+            init = torch.zeros([])
+            xs = torch.arange(4)
+            final_carry, ys = torch._higher_order.scan(step_fn, init, xs)
+            return final_carry, ys
+
+    This pass will rewrite scan into:
+
+        def k():
+            init = torch.zeros([])
+            xs = torch.arange(4)
+
+            # we create a loop_idx and loop through xs.shape[0]
+            loop_idx = torch.zeros([])
+            ys = torch.empty_strided(_shape_stride_of_ys)
+            def cond_fn(loop_idx, ys, init, xs):
+                return loop_idx < xs.shape[0]
+
+            # we pre-allocate the output buffer ys and inplace
+            # copy the y of each intermediate into a slice.
+            # NOTE [Pre-allocate scan's output buffer].
+            def body_fn(loop_idx, ys, init, xs):
+                int_idx = loop_idx.item()
+                next_init, y = step_fn(init, xs[int_idx])
+                ys[int_idx].copy_(y)
+                return loop_idx + 1, ys, next_init, xs
+
+            final_carry, _, _, ys = torch._higher_order.while_loop(cond_fn, body_fn, (loop_idx, ys, init, xs))
+            return final_carry, ys
+    """
+
+    graph_pass = PatternMatcherPass()
+
+    @register_graph_pattern(
+        CallFunctionVarArgs(torch.ops.higher_order.scan),
+        # pyrefly: ignore [bad-argument-type]
+        pass_dict=graph_pass,
+    )
+    def _(match: Match, *args, **kwargs):
+        from torch._higher_order_ops.scan import _extract_carry_and_out
+
+        assert len(kwargs) == 0, (
+            "kwargs of scan are not merged into args before entering decompose_scan_to_while_loop_pass"
+        )
+
+        combine_subgraph, fx_init, fx_xs, fx_additional_inputs = args
+        assert combine_subgraph.op == "get_attr", "first arg is not combine_subgraph"
+        sub_gm: torch.fx.GraphModule = getattr(gm, combine_subgraph.target)
+        cur_node = match.nodes[0]
+        num_init_leaves = len(fx_init)
+        _, ys_outputs = _extract_carry_and_out(cur_node.meta["val"], num_init_leaves)
+
+        def lower_to_while_loop(*args, **kwargs):
+            """
+            The traced graph of this function will be used to replace the original scan fx_node.
+            """
+            assert len(kwargs) == 0
+
+            # Step 1: construct necessary inputs to while_loop based on scan's input.
+            (
+                init,
+                xs,
+                additional_inputs,
+            ) = pytree.tree_unflatten(args, tree_spec)
+            scan_length = xs[0].size(0)
+            loop_idx = torch.zeros([], dtype=torch.int64, device=torch.device("cpu"))
+
+            # NOTE [Pre-allocate scan's output buffer]
+            # In order to pre-allocate the output buffer for ys, we rely on the meta of scan's fx_node.
+            # However, the meta consists of concrete symints, we need to bind those symints with
+            # proxies in order to trace the torch.empty_strided call correctly.
+            #
+            # Also note that basic free symbols of tensor's shapes are guaranteed to be lifted as subgraph inputs
+            # in dynamo so we can always re-construct the sym expression from placeholders.
+            # See Note [Auto lift basic free symbols when create_graph_input] for how this is done.
+            bound_symbols = {
+                arg.node.expr: arg
+                for arg in pytree.tree_leaves((args, scan_length))
+                if isinstance(arg, torch.SymInt)
+            }
+            ys_outs = [
+                torch.empty_strided(
+                    resolve_shape_to_proxy(ys_out.size(), bound_symbols),
+                    resolve_shape_to_proxy(ys_out.stride(), bound_symbols),
+                    device=ys_out.device,
+                    dtype=ys_out.dtype,
+                    layout=ys_out.layout,
+                    requires_grad=ys_out.requires_grad,
+                )
+                for ys_out in ys_outputs
+            ]
+
+            while_loop_operands = (loop_idx, ys_outs, init, xs)
+            flat_operands, operands_spec = pytree.tree_flatten(while_loop_operands)
+            _, operands_and_additional_inputs_spec = pytree.tree_flatten(
+                (*while_loop_operands, additional_inputs)
+            )
+
+            # Step 2: create the cond_fn and body_fn for while_loop
+            def cond_fn(*flat_args):
+                loop_idx, _, _, _, _ = pytree.tree_unflatten(
+                    flat_args, operands_and_additional_inputs_spec
+                )  # type: ignore[has-type]
+                return loop_idx < scan_length  # type: ignore[has-type]
+
+            def body_fn(*flat_args):
+                loop_idx, ys_outs, carry, xs, additional_inputs = pytree.tree_unflatten(
+                    flat_args,
+                    operands_and_additional_inputs_spec,  # type: ignore[has-type]
+                )
+
+                idx_int = loop_idx.item()
+                torch.ops.aten._assert_scalar.default(idx_int >= 0, "")
+                torch.ops.aten._assert_scalar.default(idx_int < scan_length, "")
+                sub_xs = [torch.ops.aten.select.int(x, 0, idx_int) for x in xs]
+                next_carry, ys = _extract_carry_and_out(
+                    sub_gm(*(list(carry) + sub_xs + list(additional_inputs))),
+                    num_init_leaves,
+                )
+                for y, y_out in zip(ys, ys_outs):
+                    y_out_slice = torch.ops.aten.select.int(y_out, 0, idx_int)
+                    y_out_slice.copy_(y)
+                return loop_idx + 1, *ys_outs, *next_carry, *xs
+
+            # Step 3: call the while_loop operator
+            _, ys_outs, last_carry, _ = pytree.tree_unflatten(
+                torch.ops.higher_order.while_loop(
+                    cond_fn,
+                    body_fn,
+                    tuple(flat_operands),
+                    tuple(additional_inputs),
+                ),
+                operands_spec,
+            )
+            return list(last_carry) + list(ys_outs)
+
+        lower_to_while_loop_args, tree_spec = pytree.tree_flatten(
+            (
+                fx_init,
+                fx_xs,
+                fx_additional_inputs,
+            )
+        )
+        match.replace_by_example(
+            lower_to_while_loop,
+            lower_to_while_loop_args,
+            run_functional_passes=False,
+        )
+
+    graph_pass.apply(gm)
+
+    for _node in gm.graph.find_nodes(
+        op="call_function", target=torch.ops.higher_order.scan
+    ):
+        raise AssertionError("scan is not lowered to while_loop")
+
+
+@init_once_fakemode
+def lazy_init():
+    if torch._C._has_mkldnn:
+        from . import decompose_mem_bound_mm  # noqa: F401
+        from .mkldnn_fusion import _mkldnn_fusion_init
+
+        _mkldnn_fusion_init()
+
+    # Put this patterns in post-grad pass rather than joint-graph
+    # pass since otherwise there will be perf/peak-memory regression:
+    # https://github.com/pytorch/pytorch/issues/148141
+    register_replacement(
+        # pyrefly: ignore [bad-argument-type]
+        prepare_softmax_pattern,
+        # pyrefly: ignore [bad-argument-type]
+        prepare_softmax_replacement,
+        [torch.empty(4, 8)],
+        scalar_workaround=dict(dim=-1),
+        # pyrefly: ignore [bad-argument-type]
+        trace_fn=fwd_only,
+        # pyrefly: ignore [bad-argument-type]
+        pass_dicts=pass_patterns[1],
+        extra_check=prepare_softmax_extra_check,
+    )
+
+
+def reorder_for_locality(graph: torch.fx.Graph):
+    if torch.distributed.is_available():
+
+        def check():
+            # This is a wait node, and `other_node`` is some collective node.
+            # Eager semantics allow waits to be issued in a different order than
+            # the collectives. Reordering this wait node might reorder collectives
+            # which cause hangs. Once we have SPMD mode, we can safely reorder them.
+            # However, increasing the locality between a collective and its wait node
+            # is generally worse for performance.
+            return node.target != torch.ops._c10d_functional.wait_tensor.default
+    else:
+
+        def check():
+            return True
+
+    def visit(other_node):
+        if (
+            other_node.op == "call_function"
+            and other_node.target != operator.getitem
+            and all((n in seen_nodes) for n in other_node.users)
+            and get_mutation_region_id(graph, node)
+            == get_mutation_region_id(graph, other_node)
+            and check()
+        ):
+            # move node's producers right before it
+            node.prepend(other_node)
+
+    seen_nodes = OrderedSet[torch.fx.Node]()
+
+    # only reorder nodes before the first copy_ in the graph.
+    # copy_ will appear at the end of functionalized graphs when there is mutation on inputs,
+    # and this reordering doesn't work well with mutation
+    first_copy = next(
+        iter(graph.find_nodes(op="call_function", target=torch.ops.aten.copy_.default)),
+        None,
+    )
+    past_mutating_epilogue = first_copy is None
+
+    for node in reversed(graph.nodes):
+        seen_nodes.add(node)
+        if not past_mutating_epilogue:
+            past_mutating_epilogue = node is first_copy
+            continue
+
+        torch.fx.map_arg((node.args, node.kwargs), visit)
+
+
+def register_lowering_pattern(
+    pattern, extra_check=_return_true, pass_number=1
+) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
+    """
+    Register an aten to inductor IR replacement pattern
+    """
+    return pattern_matcher.register_lowering_pattern(
+        pattern,
+        extra_check,
+        # pyrefly: ignore [bad-argument-type]
+        pass_dict=pass_patterns[pass_number],
+    )
+
+
+################################################################################
+# Actual patterns below this point.
+# Priority of patterns is:
+#   - later output nodes first
+#   - order patterns are defined in
+################################################################################
+
+
+def is_valid_mm_plus_mm(match: Match):
+    if not (config.max_autotune or config.max_autotune_gemm):
+        return False
+
+    # Check if all required values exist
+    mat1_val = match.kwargs["mat1"].meta.get("val")
+    mat2_val = match.kwargs["mat2"].meta.get("val")
+    mat3_val = match.kwargs["mat3"].meta.get("val")
+    mat4_val = match.kwargs["mat4"].meta.get("val")
+
+    if mat1_val is None or mat2_val is None or mat3_val is None or mat4_val is None:
+        return False
+
+    *_b1, m1, k1 = mat1_val.shape
+    *_b2, k2, n1 = mat2_val.shape
+    if k1 != k2:
+        return False
+
+    *_b1, m2, k3 = mat3_val.shape
+    *_b2, k4, n2 = mat4_val.shape
+    if k3 != k4:
+        return False
+
+    if m1 != m2 or n1 != n2:
+        return False
+
+    return True
+
+
+@register_lowering_pattern(
+    CallFunction(
+        aten.add,
+        CallFunction(aten.mm, KeywordArg("mat1"), KeywordArg("mat2")),
+        CallFunction(aten.mm, KeywordArg("mat3"), KeywordArg("mat4")),
+    ),
+    extra_check=is_valid_mm_plus_mm,
+)
+def mm_plus_mm(match: Match, mat1, mat2, mat3, mat4):
+    return inductor.kernel.mm_plus_mm.tuned_mm_plus_mm(mat1, mat2, mat3, mat4)
+
+
+@register_graph_pattern(
+    CallFunction(
+        aten.cumsum.default,
+        CallFunction(
+            torch.ops.aten.full.default,
+            KeywordArg("shape"),
+            KeywordArg("fill_value"),
+            dtype=KeywordArg("dtype"),
+            layout=Ignored(),
+            device=KeywordArg("device"),
+            pin_memory=False,
+            _users=MULTIPLE,
+        ),
+        KeywordArg("dim"),
+        _users=MULTIPLE,
+    ),
+    # pyrefly: ignore [bad-argument-type]
+    pass_dict=pass_patterns[1],
+)
+def pointless_cumsum_replacement(match: Match, shape, fill_value, device, dtype, dim):
+    """Based on a pattern in OPTForCausalLM"""
+
+    if is_integer_dtype(dtype) or is_boolean_dtype(dtype):
+        # cumsum promotes all integral types to int64
+        dtype = torch.int64
+
+    def repl(*shape):
+        dim_size = shape[dim]
+        idx = torch.arange(1, dim_size + 1, device=device, dtype=dtype)
+
+        inter_shape = [1] * len(shape)
+        inter_shape[dim] = dim_size
+        return (idx * fill_value).view(inter_shape).expand(shape)
+
+    # only replace the output node, not all nodes
+    match.nodes = [match.output_node()]
+    # pyrefly: ignore [bad-argument-type]
+    match.replace_by_example(repl, list(shape))
+
+
+_cat_1 = CallFunction(aten.cat, Arg(), 1, _users=2)
+
+
+@register_lowering_pattern(
+    CallFunction(
+        aten.cat,
+        [
+            _cat_1,
+            CallFunction(
+                aten.slice,
+                _cat_1,
+                1,
+                0,
+                KeywordArg("size"),
+            ),
+        ],
+        1,
+    )
+)
+def cat_slice_cat(match, cat_input, size, dim=1):
+    """
+    This is an example of a more complex pattern where cat_1 is used
+    multiple times inside the pattern.  We fold 2 calls to cat into one.
+
+    Matches:
+        cat_1: f32[1024, 4077] = torch.ops.aten.cat.default([add_26, primals_217], 1)
+        slice_1: f32[1024, 4077] = torch.ops.aten.slice.Tensor(cat_1, 0, 0, 9223372036854775807)
+        slice_2: f32[1024, 19] = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 19)
+        cat_2: f32[1024, 4096] = torch.ops.aten.cat.default([cat_1, slice_2], 1)
+
+
+    Rewrite to:
+        slice_2 = torch.ops.aten.slice.Tensor(add_26, 1, 0, 19)
+        cat_2 = torch.ops.aten.cat.default([add_26, primals_217, slice2], 1)
+    """
+    first, *rest = cat_input
+    # Optimization is optional, because we can just not fold the cat
+    # size should be within first.get_size()[dim] such that the optimization is valid.
+    # For negative `end`, we currently fallback to not optimizing.
+    if size >= 0 and V.graph.sizevars.statically_known_leq(size, first.get_size()[dim]):
+        # fold 2 cats into 1 cat
+        return L[aten.cat](
+            [
+                first,
+                *rest,
+                L[aten.slice](first, dim, 0, size),
+            ],
+            dim,
+        )
+    else:
+        # don't expect to hit this case, just fall back
+        tmp = L[aten.cat](cat_input, dim)
+        return L[aten.cat](
+            [
+                tmp,
+                L[aten.slice](tmp, dim, 0, size),
+            ],
+            dim,
+        )
+
+
+def is_valid_splitwithsizes_cat(match):
+    split_nodes = filter_nodes(match.nodes, aten.split_with_sizes)
+    cat_nodes = filter_nodes(match.nodes, aten.cat)
+    get_item_nodes = filter_nodes(match.nodes, operator.getitem)
+    if len(split_nodes) != 1 or len(cat_nodes) != 1:
+        return False
+    split_node, cat_node = split_nodes[0], cat_nodes[0]
+    # The dim of split and cat should match for passthrough
+    if get_arg_value(split_node, 2, "dim") != get_arg_value(cat_node, 1, "dim"):
+        return False
+    get_item_args = OrderedSet(
+        get_arg_value(get_item_node, 1) for get_item_node in get_item_nodes
+    )
+    assert None not in get_item_args
+    split_sizes = get_arg_value(split_node, 1, "split_sizes")
+    # All parts of split should be included in the cat
+    if get_item_args != OrderedSet(range(len(split_sizes))):
+        return False
+    # The order of get_item_args should same with cat_node used.
+    # For example, if the split_node like split_with_sizes(input, [2, 2, 3], 1),
+    # the cat node should be like cat([get_item(0), get_item(1), get_item(2)], 1).
+    cat_items_args_order = [
+        get_arg_value(item_node, 1) for item_node in get_arg_value(cat_node, 0)
+    ]
+    if cat_items_args_order != list(range(len(split_sizes))):
+        return False
+
+    return True
+
+
+def same_meta(node1: torch.fx.Node, node2: torch.fx.Node):
+    """True if two nodes have the same metadata"""
+    val1 = node1.meta.get("val")
+    val2 = node2.meta.get("val")
+    return (
+        val1 is not None
+        and val2 is not None
+        and statically_known_true(sym_eq(val1.size(), val2.size()))
+        and val1.layout == val2.layout
+        and val1.dtype == val2.dtype
+        and val1.device == val2.device
+        and (
+            val1.layout != torch.strided
+            or statically_known_true(sym_eq(val1.stride(), val2.stride()))
+        )
+    )
+
+
+noop_registry: dict[Any, Any] = {}
+
+
+def register_noop_decomp(targets, nop_arg=0):
+    def register_fun(cond):
+        register_decomposition(targets, registry=noop_registry, unsafe=True)(
+            (cond, nop_arg)  # type: ignore[arg-type]
+        )
+        return cond
+
+    return register_fun
+
+
+@register_noop_decomp(aten.slice)
+def slice_noop(self, dim=0, start=None, end=None, step=1):
+    if start is None or end is None:
+        return False
+
+    slice_dim_size = self.shape[dim]
+    if (
+        statically_known_true(sym_eq(start, 0))
+        and (
+            statically_known_true(end >= 2**63 - 1)
+            or statically_known_true(end >= slice_dim_size)
+        )
+        and statically_known_true(sym_eq(step, 1))
+    ):
+        return True
+    return False
+
+
+@register_noop_decomp(aten.slice_scatter, 1)
+def slice_scatter_noop(self, src, dim=0, start=None, end=None, step=1):
+    if start is None:
+        start = 0
+    if end is None:
+        end = 2**63 - 1
+    slice_scatter_dim_size = self.shape[dim]
+    if (
+        self.shape == src.shape
+        and start == 0
+        and (
+            statically_known_true(end >= 2**63 - 1)
+            or statically_known_true(end >= slice_scatter_dim_size)
+        )
+        and step == 1
+    ):
+        return True
+    return False
+
+
+@register_noop_decomp(aten.repeat)
+def repeat_noop(self, repeats):
+    return all(r == 1 for r in repeats)
+
+
+@register_noop_decomp(aten.constant_pad_nd)
+def constant_pad_nd(x, padding, fill_value=0):
+    return all(p == 0 for p in padding)
+
+
+@register_noop_decomp(torch.ops.prims.convert_element_type)
+def convert_element_type_noop(x, dtype: torch.dtype):
+    return x.dtype == dtype
+
+
+@register_noop_decomp(torch.ops.prims.device_put)
+def device_put_noop(x, device, non_blocking=True):
+    return x.device == decode_device(device)
+
+
+@register_noop_decomp([aten.ceil, aten.floor, aten.round, aten.trunc])
+def int_noop(x):
+    return is_integer_dtype(x.dtype)
+
+
+@register_noop_decomp([aten.pow])
+def pow_noop(a, b):
+    return isinstance(b, int) and b == 1
+
+
+@register_noop_decomp([aten.cat], lambda args: args[0][0])
+def cat_noop(inputs, dim=0):
+    return len(inputs) == 1
+
+
+@register_noop_decomp(aten.view.default)
+def view_default_noop(arg, size):
+    return statically_known_true(sym_eq(arg.shape, tuple(size)))
+
+
+@register_noop_decomp(aten.view.dtype)
+def view_dtype_noop(arg, dtype):
+    return arg.dtype == dtype
+
+
+# Note, we also always have a check for identical metadata, which is why these
+# are safe
+@register_noop_decomp([aten.copy], nop_arg=1)
+@register_noop_decomp([aten.alias, aten.clone])
+def true_noop(*args, **kwargs):
+    return True
+
+
+def remove_noop_ops(graph: torch.fx.Graph):
+    """
+    Removes both operations that are essentially aten.clone and operations that are essentially aten.alias from the graph.
+    """
+    inputs = OrderedSet[torch.fx.Node]()
+    input_storages = OrderedSet[int | None]()
+    output_storages = OrderedSet[int | None]()
+
+    for node in graph.find_nodes(op="placeholder"):
+        inputs.add(node)
+        input_storages.add(get_node_storage(node))
+
+    output_node = next(iter(reversed(graph.nodes)))
+    assert output_node.op == "output"
+    outputs = output_node.args[0]
+    if not isinstance(outputs, (list, tuple)):
+        # nested subgraphs can have singleton outputs
+        outputs = (outputs,)
+    for out in outputs:
+        if isinstance(out, torch.fx.Node):
+            output_storages.add(get_node_storage(out))
+
+    for node in graph.nodes:
+        if node.target in noop_registry:
+            cond, src_index = noop_registry[node.target]
+            if isinstance(src_index, int):
+                src = node.args[src_index]
+            else:
+                src = src_index(node.args)
+            if not isinstance(src, torch.fx.Node):
+                continue
+            # Don't introduce new aliasing between inputs and outputs.
+            # See fx_passes/README.md for a discussion of why this is
+            # necessary.
+            node_storage = get_node_storage(node)
+            src_storage = get_node_storage(src)
+            node_is_view = node_storage == src_storage
+            if (
+                not node_is_view
+                and node_storage in output_storages
+                and (src_storage in input_storages or src_storage in output_storages)
+            ):
+                continue
+
+            # Even if input and outputs are expected to alias,
+            # don't make "node is src" True
+            if (
+                node_is_view
+                and node in output_node.args
+                and (src in inputs or src in output_node.args)
+            ):
+                continue
+
+            is_valid, args, kwargs = get_fake_args_kwargs(node)
+            if not is_valid:
+                continue
+            if same_meta(node, src) and cond(*args, **kwargs):
+                node.replace_all_uses_with(src)
+                graph.erase_node(node)
+
+
+def remove_assert_ops(graph: torch.fx.Graph):
+    """
+    Removes aten._assert_tensor_metadata.default op because
+    1) it will be lowered to a no-op in inductor
+    2) it can block fusion, such as unfuse_bias_add_to_pointwise fusion.
+
+    This op could come from aten.to functionalization in export.
+
+    For example, if we have a graph like below
+
+    %addmm = aten.addmm.default(%linear_bias, %arg3_1, %permute)
+    %_assert_tensor_metadata = aten._assert_tensor_metadata.default(%addmm, None, None, torch.float16)
+    %convert_element_type_3 = prims.convert_element_type.default(%addmm, torch.float32)
+    %pow_1 = aten.pow.Tensor_Scalar(%convert_element_type_3, 2)
+
+    We still want to fuse add from addmm with pow, instead of fusing add with mm, according to unfuse_bias_add_to_pointwise fusion.
+
+    However, aten._assert_tensor_metadata.default is not a pointwise op, and would fail the should_prefer_unfused_addmm check.
+
+    We remove this op so it doesn't block fusion decisions. It's safe because this op is lowered to a no-op with @register_lowering.
+
+    """
+    for node in graph.find_nodes(
+        op="call_function", target=torch.ops.aten._assert_tensor_metadata.default
+    ):
+        graph.erase_node(node)
+
+
+def decompose_triton_kernel_wrapper_functional(graph):
+    """Decomposes triton_kernel_wrapper_functional nodes into clones and the underlying
+    mutation node.
+
+    We assume that the reinplacing pass runs before this; the reinplacing pass
+    tells us (via rewriting the arguments or .meta to those nodes) which
+    Tensors we should clone and which Tensors are safe to reinplace.
+    """
+    graph_pass = PatternMatcherPass()
+
+    @register_graph_pattern(
+        CallFunctionVarArgs(torch.ops.higher_order.triton_kernel_wrapper_functional),
+        # pyrefly: ignore [bad-argument-type]
+        pass_dict=graph_pass,
+    )
+    def _(match: Match, *args, **kwargs):
+        from torch._higher_order_ops.triton_kernel_wrap import (
+            triton_kernel_wrapper_functional_dense,
+        )
+
+        flat_args, spec = pytree.tree_flatten((args, kwargs))
+
+        # NB: we combine (args, kwargs) into flat args for replacing.
+        # This is replace_by_example uses make_fx which does not support
+        # tracing a function with kwargs.
+        def decomp(*flat_args):
+            args, kwargs = pytree.tree_unflatten(flat_args, spec)
+            return (triton_kernel_wrapper_functional_dense(*args, **kwargs),)
+
+        # pyrefly: ignore [bad-argument-type]
+        match.replace_by_example(decomp, flat_args, run_functional_passes=False)
+
+    graph_pass.apply(graph)
+
+    for _ in graph.find_nodes(
+        op="call_function",
+        target=torch.ops.higher_order.triton_kernel_wrapper_functional,
+    ):
+        raise AssertionError("triton_kernel_wrapper_functional was not removed")
+
+
+def decompose_auto_functionalized(graph):
+    """Decomposes auto_functionalized nodes into clones and the underlying
+    mutation node.
+
+    We assume that the reinplacing pass runs before this; the reinplacing pass
+    tells us (via rewriting the arguments or .meta to those nodes) which
+    Tensors we should clone and which Tensors are safe to reinplace.
+    """
+    graph_pass = PatternMatcherPass()
+
+    @register_graph_pattern(
+        CallFunctionVarArgs(torch.ops.higher_order.auto_functionalized),
+        # pyrefly: ignore [bad-argument-type]
+        pass_dict=graph_pass,
+    )
+    def _(match: Match, *args, **kwargs):
+        from torch._higher_order_ops.auto_functionalize import auto_functionalized_dense
+
+        only_clone_these_tensors = tuple(
+            match.nodes[0].meta.get("only_clone_these_tensors", [])
+        )
+
+        flat_args, spec = pytree.tree_flatten((args, kwargs))
+
+        # NB: we combine (args, kwargs) into flat args for replacing.
+        # This is replace_by_example uses make_fx which does not support
+        # tracing a function with kwargs.
+        def decomp(*flat_args):
+            args, kwargs = pytree.tree_unflatten(flat_args, spec)
+            assert len(args) == 1
+            mode = args[0]
+            return auto_functionalized_dense(mode, only_clone_these_tensors, **kwargs)
+
+        # pyrefly: ignore [bad-argument-type]
+        match.replace_by_example(decomp, flat_args, run_functional_passes=False)
+
+    @register_graph_pattern(
+        CallFunctionVarArgs(torch.ops.higher_order.auto_functionalized_v2),
+        # pyrefly: ignore [bad-argument-type]
+        pass_dict=graph_pass,
+    )
+    def _(match: Match, *args, **kwargs):
+        from torch._higher_order_ops.auto_functionalize import (
+            auto_functionalized_v2_dense,
+        )
+
+        only_clone_these_bases = tuple(
+            match.nodes[0].meta.get("only_clone_these_tensors", [])
+        )
+
+        flat_args, spec = pytree.tree_flatten((args, kwargs))
+
+        def _maybe_resolve_constant_get_attr(node):
+            # Resolve getattr node to its value because they don't always have meta["val"]
+            if (
+                isinstance(node, torch.fx.Node)
+                and node.op == "get_attr"
+                and "val" not in node.meta
+            ):
+                const_attr = getattr(graph.owning_module, node.target)  # type: ignore[arg-type]
+                assert isinstance(
+                    const_attr, (torch.fx.GraphModule, pytree.TreeSpec)
+                ), (type(const_attr), const_attr)
+                return const_attr
+            return node
+
+        flat_args = [_maybe_resolve_constant_get_attr(arg) for arg in flat_args]
+
+        # NB: we combine (args, kwargs) into flat args for replacing.
+        # This is replace_by_example uses make_fx which does not support
+        # tracing a function with kwargs.
+        def decomp(*flat_args):
+            args, kwargs = pytree.tree_unflatten(flat_args, spec)
+            assert len(args) == 1
+            mutable_op = args[0]
+            return auto_functionalized_v2_dense(
+                mutable_op, only_clone_these_bases, **kwargs
+            )
+
+        # pyrefly: ignore [bad-argument-type]
+        match.replace_by_example(decomp, flat_args, run_functional_passes=False)
+
+    graph_pass.apply(graph)
+
+    # Remove unused get_attr nodes and their corresponding attributes from the graph module.
+    # When auto_functionalizing a hop, we need to clean up get_attr nodes for _constant_schema
+    # and the auto_functionalized graph module that are no longer referenced.
+    unused_get_attr_nodes = []
+    removable_attrs: OrderedSet[torch.fx.node.Target] = OrderedSet()
+    protected_attrs: OrderedSet[torch.fx.node.Target] = OrderedSet()
+
+    # First pass: identify unused get_attr nodes and track attribute usage
+    for node in graph.nodes:
+        if node.op != "get_attr":
+            continue
+
+        if len(node.users) == 0:
+            # Node is unused, mark for removal
+            unused_get_attr_nodes.append(node)
+
+            # Check if the attribute can be removed from the module
+            if (
+                hasattr(graph.owning_module, node.target)
+                and isinstance(
+                    getattr(graph.owning_module, node.target), torch.fx.GraphModule
+                )
+                and node.target not in protected_attrs
+            ):
+                removable_attrs.add(node.target)
+        else:
+            # Node is used, protect its attribute from removal
+            if node.target in removable_attrs:
+                removable_attrs.remove(node.target)
+            protected_attrs.add(node.target)
+
+    # Second pass: clean up unused nodes and attributes
+    for node in unused_get_attr_nodes:
+        graph.erase_node(node)
+
+    for attr_name in removable_attrs:
+        assert isinstance(attr_name, str)
+        delattr(graph.owning_module, attr_name)
+
+    graph.lint()
+
+    for _ in graph.find_nodes(
+        op="call_function", target=torch.ops.higher_order.auto_functionalized
+    ):
+        raise AssertionError("auto_functionalized was not removed")
+
+    for _ in graph.find_nodes(
+        op="call_function", target=torch.ops.higher_order.auto_functionalized_v2
+    ):
+        raise AssertionError("auto_functionalized_v2 was not removed")
+
+
+@register_lowering_pattern(
+    CallFunction(
+        aten.cat,
+        ListOf(
+            CallFunction(
+                operator.getitem,
+                CallFunction(
+                    aten.split_with_sizes,
+                    KeywordArg("input_"),
+                    Ignored(),
+                    Ignored(),
+                    _users=MULTIPLE,
+                ),
+                Ignored(),
+            ),
+        ),
+        Ignored(),
+    ),
+    pass_number=2,
+    extra_check=is_valid_splitwithsizes_cat,
+)
+def splitwithsizes_cat_replace(match, input_):
+    return input_
+
+
+def is_valid_cat_splitwithsizes(match):
+    cat_nodes = filter_nodes(match.nodes, aten.cat)
+    split_nodes = filter_nodes(match.nodes, aten.split_with_sizes)
+    if len(split_nodes) != 1 or len(cat_nodes) != 1:
+        return False
+    split_node, cat_node = split_nodes[0], cat_nodes[0]
+
+    # the cat node has other users: can't eliminate
+    if len(cat_node.users) > 1:
+        return False
+
+    # the dim of the cat and split should match
+    dim = get_arg_value(split_node, 2, "dim")
+    if dim != get_arg_value(cat_node, 1, "dim"):
+        return False
+
+    cat_inputs = list(get_arg_value(cat_node, 0))
+    split_sizes = get_arg_value(split_node, 1, "split_sizes")
+    # the number of input tensors in cat and the
+    # length of the split sizes should match
+    if len(cat_inputs) != len(split_sizes):
+        return False
+
+    for cat_input, split_size in zip(cat_inputs, split_sizes):
+        # each cat input tensor's size along dim
+        # should match the corresponding split size
+        if "val" not in cat_input.meta:
+            return False
+        cat_input_size = cat_input.meta["val"].size(dim)
+        if cat_input_size != split_size:
+            return False
+
+    return True
+
+
+@register_lowering_pattern(
+    CallFunction(
+        aten.split_with_sizes,
+        CallFunction(
+            aten.cat,
+            KeywordArg("input_"),
+            Ignored(),
+            _users=MULTIPLE,
+        ),
+        Ignored(),
+        Ignored(),
+    ),
+    pass_number=2,
+    extra_check=is_valid_cat_splitwithsizes,
+)
+def cat_splitwithsizes_replace(match, input_):
+    return input_
+
+
+def view_to_reshape(gm):
+    """
+    Replace view ops in the GraphModule to reshape ops.
+    """
+    subgraph_names: OrderedSet[str] = OrderedSet(
+        x.target for x in gm.graph.find_nodes(op="get_attr")
+    )
+
+    for child_name, child_mod in gm.named_children():
+        if child_name in subgraph_names and isinstance(child_mod, torch.fx.GraphModule):
+            view_to_reshape(child_mod)
+
+    for nd in gm.graph.find_nodes(
+        op="call_function", target=torch.ops.aten.view.default
+    ):
+        nd.target = torch.ops.aten.reshape.default
+
+
+def should_prefer_unfused_addmm(match):
+    inp = match.kwargs["inp"]
+    if not is_gpu(inp.meta["val"].device.type):
+        return False
+
+    output = match.output_node()
+    return all(is_pointwise_use(use) for use in output.users)
+
+
+@register_graph_pattern(
+    CallFunction(
+        aten.addmm,
+        KeywordArg("inp"),
+        Arg(),
+        Arg(),
+        beta=KeywordArg("beta"),
+        alpha=KeywordArg("alpha"),
+    ),
+    # pyrefly: ignore [bad-argument-type]
+    pass_dict=pass_patterns[2],
+    extra_check=should_prefer_unfused_addmm,
+)
+def unfuse_bias_add_to_pointwise(match: Match, mat1, mat2, *, inp, alpha, beta):
+    def repl(inp, x1, x2, alpha, beta):
+        mm_result = x1 @ x2
+        if alpha != 1:
+            mm_result = alpha * mm_result
+        if beta != 1:
+            inp = beta * inp
+        return inp + mm_result
+
+    # pyrefly: ignore [bad-argument-type]
+    match.replace_by_example(repl, [inp, mat1, mat2, alpha, beta])
+
+
+def is_valid_addmm_fusion(match):
+    mat1, mat2 = match.args
+    inp = match.kwargs["inp"]
+
+    if not (
+        isinstance(inp, torch.fx.Node) and isinstance(inp.meta["val"], torch.Tensor)
+    ):
+        return False  # Input is a number
+
+    in_shape = inp.meta["val"].shape
+    mm_shape = mat1.meta["val"].shape[0], mat2.meta["val"].shape[1]
+    matched = is_expandable_to(in_shape, mm_shape)
+    if not matched:
+        return False  # Shape mismatch
+
+    inp_dtype = inp.meta["val"].dtype
+
+    # aten cublas integration assumes equal dtypes
+    if inp_dtype != mat1.meta["val"].dtype or inp_dtype != mat2.meta["val"].dtype:
+        return False
+
+    return not should_prefer_unfused_addmm(match)
+
+
+@register_graph_pattern(
+    CallFunction(
+        aten.add,
+        CallFunction(aten.mm, Arg(), Arg()),
+        KeywordArg("inp"),
+    ),
+    # pyrefly: ignore [bad-argument-type]
+    pass_dict=pass_patterns[2],
+    extra_check=is_valid_addmm_fusion,
+)
+@register_graph_pattern(
+    CallFunction(
+        aten.add,
+        KeywordArg("inp"),
+        CallFunction(aten.mm, Arg(), Arg()),
+    ),
+    # pyrefly: ignore [bad-argument-type]
+    pass_dict=pass_patterns[2],
+    extra_check=is_valid_addmm_fusion,
+)
+def addmm(match, mat1, mat2, *, inp):
+    def repl(inp, mat1, mat2):
+        return aten.addmm(inp, mat1, mat2)
+
+    match.replace_by_example(repl, [inp, mat1, mat2])
+
+
+def register_partial_reduction_pattern():
+    "Reuse partial reductions in complete reductions"
+
+    # post grad equivalents
+    equiv_red = {
+        aten.amax.default: aten.max.default,
+        aten.amin.default: aten.min.default,
+    }
+
+    # TODO: to support other reductions like sum, would need to skip
+    # lower precision reductions since partial output would need to be kept at fp32.
+    for red_op in (aten.amax.default, aten.amin.default):
+        inp = KeywordArg("input")
+        partial_reduc = CallFunction(
+            red_op, inp, KeywordArg("reduced_dims"), KeywordArg("keepdim")
+        )
+        full_reduc = CallFunction([red_op, equiv_red[red_op]], inp)
+
+        @register_graph_pattern(
+            MultiOutputPattern([partial_reduc, full_reduc]),
+            # pyrefly: ignore [bad-argument-type]
+            pass_dict=pass_patterns[2],
+        )
+        def reuse_partial(match, input, reduced_dims, keepdim):
+            partial_red, full_red = match.output_nodes()
+
+            # if they're small, reuse not worth it
+            if not statically_known_true(input.meta["val"].numel() >= 4096):
+                return True
+
+            def replacement(inp: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+                partial = partial_red.target(inp, reduced_dims, keepdim)
+                complete = full_red.target(partial)
+                return (partial, complete)
+
+            counters["inductor"]["partial_reduction_reuse"] += 1
+            match.replace_by_example(replacement, [input])
+
+
+register_partial_reduction_pattern()
+
+
+def check_shape_cuda_and_fused_int_mm_mul_enabled(match):
+    return (
+        config.force_fuse_int_mm_with_mul
+        and len(getattr(match.args[2].meta.get("val"), "shape", [])) == 2
+        and getattr(match.args[2].meta.get("val"), "is_cuda", False)
+    )
+
+
+def is_index_put_and_requires_h2d_sync_for_gpu_value(node):
+    from torch.fx.operator_schemas import normalize_function
+
+    if node.target not in [
+        torch.ops.aten.index_put.default,
+        torch.ops.aten.index_put_.default,
+    ]:
+        return False
+    # Inductor falls back to aten.index_put_.
+    # index_put_ will will call nonzero() and perform a H2D sync if
+    # any of its indices are bool/byte tensors
+    # However, it will short-circuit this H2D sync and run mask_fill_
+    # if the value we are putting is a cpu scalar.
+    # Therefore, when inductor sees an index_put_ with byte tensor indices,
+    # it should *not* convert the cpu scalar value into a gpu tensor.
+    args_, _kwargs = normalize_function(node.target, node.args, node.kwargs)  # type: ignore[misc]
+    any_byte_bool_indices = False
+    indices = args_[1]
+    for i in indices:
+        if i is not None and i.meta["val"].dtype in [torch.bool, torch.int8]:
+            any_byte_bool_indices = True
+
+    val = args_[2].meta["val"]
+    val_is_cpu_scalar = val.device.type == "cpu" and val.numel() == 1
+    # If both these conditions hold, then converting the val
+    # to a gpu tensor will incur a H2D sync when inductor calls aten.index_put_
+    return any_byte_bool_indices and val_is_cpu_scalar
+
+
+class ConstructorMoverPass:
+    def __init__(
+        self, target: str, allow_outputs: bool = False, allow_inputs: bool = False
+    ) -> None:
+        """
+        Move constructors from cpu to the target_device.
+
+        Sweeps through the module, looking for constructor nodes that can be moved
+        to the target_device.
+
+        A constructor node can be moved to the target_device iff all of its users
+        can also be moved (tested by cannot_be_moved). Otherwise, all dependent
+        constructor nodes won't be moved.
+
+        - target: target device type
+        - allow_outputs: allow outputs to be moved
+        - allow_inputs: allow inputs to be moved
+        """
+
+        self.target = target
+        self.allow_inputs = allow_inputs
+        self.allow_outputs = allow_outputs
+
+        assert isinstance(target, str), (
+            "target should be a string representing the device type. "
+            f"Got: {type(target).__name__}"
+        )
+
+    def allow_cpu_device(self, node: fx.Node) -> bool:
+        """
+        Returns whether a node that returns a tensor on the target device may have
+        cpu tensors as input.
+        """
+        return node.target in (
+            torch.ops.aten.index.Tensor,
+            torch.ops.aten.index_put.default,
+            torch.ops.aten.index_put_.default,
+            torch.ops.aten.copy.default,
+            torch.ops.aten.copy_.default,
+            torch.ops.aten.slice_scatter.default,
+        )
+
+    def is_on_target_device(self, node: fx.Node) -> bool:
+        """
+        Returns whether a node is on the target device.
+        """
+        node_device = self.get_node_device(node)
+        return node_device is not None and node_device.type == self.target
+
+    def is_cpu_scalar_tensor(self, node: fx.Node) -> bool:
+        """
+        Returns whether a node is a cpu scalar tensor.
+        """
+        device = self.get_node_device(node)
+        is_cpu = device is not None and device.type == "cpu"
+        ten = node.meta.get("val")
+        is_scalar = isinstance(ten, torch.Tensor) and len(ten.size()) == 0
+        return is_cpu and is_scalar
+
+    def all_inputs_are_cpu_scalar_or_on_target_device(self, node: fx.Node) -> bool:
+        """
+        Returns whether a node's inputs are either cpu scalar tensors or
+        on the target device.
+        """
+        inputs = (
+            inp
+            for inp in itertools.chain(node.args, node.kwargs.values())
+            if isinstance(inp, fx.Node)
+        )
+        return all(
+            self.is_cpu_scalar_tensor(inp) or self.is_on_target_device(inp)
+            for inp in inputs
+        )
+
+    def cannot_be_moved(self, node: fx.Node) -> bool:
+        """
+        Returns whether a node can be moved to the target device.
+
+        If this function returns False, it means that this node and all of its users
+        won't be moved into the target device.
+        """
+        if node.target == "output":
+            return not self.allow_outputs
+
+        if not (
+            isinstance(node.target, torch._ops.OpOverload)
+            and node.target.namespace in ("prims", "aten")
+        ):
+            return True
+
+        if is_index_put_and_requires_h2d_sync_for_gpu_value(node):
+            return True
+
+        return False
+
+    def get_node_device(self, node: fx.Node) -> torch.device | None:
+        """
+        Get the device of a node.
+        """
+        ten = node.meta.get("val")
+        return None if not isinstance(ten, torch.Tensor) else ten.device
+
+    def get_cpu_indeg_count(self, graph: fx.Graph) -> dict[fx.Node, int]:
+        """
+        Get the number of cpu inputs to a node
+        """
+        cpu_indeg: dict[fx.Node, int] = Counter()
+
+        for node in graph.nodes:
+            cpu_count = 0
+
+            def add_cpu_inp(node):
+                nonlocal cpu_count
+                device = self.get_node_device(node)
+                cpu_count += device is not None and device.type == "cpu"
+
+            pytree.tree_map_only(fx.Node, add_cpu_inp, (node.args, node.kwargs))
+
+            # pyrefly: ignore [redundant-condition]
+            if cpu_count:
+                cpu_indeg[node] = cpu_count
+
+        return cpu_indeg
+
+    def __call__(self, graph: fx.Graph) -> None:
+        target_devices = OrderedSet[torch.device]()
+        constructors = []
+        cpu_placeholders: OrderedSet[fx.Node] = OrderedSet()
+
+        for node in graph.nodes:
+            device = self.get_node_device(node)
+            if device and device.type == self.target:
+                target_devices.add(device)
+
+            if (
+                self.allow_inputs
+                and node.op == "placeholder"
+                and self.is_cpu_scalar_tensor(node)
+            ):
+                cpu_placeholders.add(node)
+                constructors.append(node)
+                continue
+
+            if not (
+                isinstance(node.target, torch._ops.OpOverload)
+                and node.target.namespace in ("prims", "aten")
+            ):
+                continue
+
+            if not torch._subclasses.fake_tensor._is_tensor_constructor(node.target):
+                continue
+
+            if node.kwargs.get("device") != torch.device("cpu"):
+                continue
+
+            constructors.append(node)
+
+        # not handling multiple target devices initially
+        if not constructors or len(target_devices) != 1:
+            return
+
+        movable_constructors = self.find_movable_constructors(graph, constructors)
+
+        target_device = next(iter(target_devices))
+        movable_cpu_placeholders = movable_constructors & cpu_placeholders
+        if movable_cpu_placeholders:
+            node = next(iter(reversed(movable_cpu_placeholders)))
+            last_node = node
+            unsqueezed_nodes = []
+            for elem in movable_cpu_placeholders:
+                with graph.inserting_after(last_node):
+                    unsqueezed_nodes.append(
+                        graph.call_function(torch.ops.aten.unsqueeze.default, (elem, 0))
+                    )
+                    last_node = unsqueezed_nodes[-1]
+            with graph.inserting_after(last_node):
+                cpu_concat = graph.call_function(
+                    torch.ops.aten.cat.default, (unsqueezed_nodes,)
+                )
+                last_node = cpu_concat
+            with graph.inserting_after(last_node):
+                gpu_concat = graph.call_function(
+                    torch.ops.prims.device_put.default,
+                    (cpu_concat, target_device, True),
+                )
+                last_node = gpu_concat
+            with graph.inserting_after(last_node):
+                gpu_split = graph.call_function(
+                    torch.ops.aten.unbind.int, (gpu_concat,)
+                )
+                last_node = gpu_split
+            for idx, node in enumerate(movable_cpu_placeholders):
+                with graph.inserting_after(last_node):
+                    gpu_node = graph.call_function(operator.getitem, (gpu_split, idx))
+                    node.replace_all_uses_with(
+                        gpu_node,
+                        lambda x: x
+                        not in [cpu_concat, gpu_concat, gpu_split, gpu_node]
+                        + unsqueezed_nodes
+                        and x.target != torch.ops.aten.copy_.default,
+                    )
+                    last_node = gpu_node
+
+                # noop elimination if there are other device_put for gpu_node to
+                # target device. Alternatively, we could just move the other device_put
+                # earlier in the graph, but that is not supported in fx graph yet.
+                noop_device_puts = [
+                    user
+                    for user in gpu_node.users
+                    if user.target is torch.ops.prims.device_put.default
+                    and user.args[1] == target_device
+                ]
+                for noop in noop_device_puts:
+                    noop.replace_all_uses_with(gpu_node)
+                    graph.erase_node(noop)
+
+        movable_constructors -= movable_cpu_placeholders
+        for node in movable_constructors:
+            kwargs = node.kwargs.copy()
+            kwargs["device"] = target_device
+            node.kwargs = kwargs
+
+    def find_movable_constructors(
+        self, graph: fx.Graph, constructors: list[fx.Node]
+    ) -> OrderedSet[fx.Node]:
+        """
+        Starting from the cpu constructors, iterate through the graph and test that all of their
+        downstream uses can safely be moved to cpu.
+        """
+        cpu_indeg: dict[fx.Node, int] = self.get_cpu_indeg_count(graph)
+
+        # which constructors cannot be moved to gpu
+        cannot_move_to_gpu = OrderedSet[fx.Node]()
+
+        # For any node in the graph, which constructors does it have a dependency on
+        constructor_dependencies: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(
+            OrderedSet
+        )
+
+        # if a cpu node has a dependency on two different cpu constructors,
+        # then if either constructor cannot be moved to gpu, the other cannot as well.
+        # In this case any node with a dependency on one will have a dependency on the other
+        equal_constructor_sets: dict[fx.Node, OrderedSet[fx.Node]] = {
+            c: OrderedSet([c]) for c in constructors
+        }
+
+        def make_dependencies_equivalent(
+            set1: OrderedSet[fx.Node], set2: OrderedSet[fx.Node]
+        ) -> OrderedSet[fx.Node]:
+            # could use union find but not worth complexity here
+            set1.update(set2)
+            for obj in set1:
+                equal_constructor_sets[obj] = set1
+            return set1
+
+        queue: list[fx.Node] = list(constructors)
+
+        for c in queue:
+            constructor_dependencies[c].add(c)
+
+        while queue:
+            node = queue.pop()
+            dependencies = constructor_dependencies[node]
+
+            for user in node.users:
+                if self.cannot_be_moved(user):
+                    cannot_move_to_gpu.update(dependencies)
+                    break
+
+                # this node was used on a op which takes in multiple devices and output a gpu
+                # tensor. we can convert its cpu input to gpu without making further changes
+                if self.allow_cpu_device(user) and self.is_on_target_device(user):
+                    del cpu_indeg[user]
+                elif (
+                    self.allow_inputs
+                    and self.all_inputs_are_cpu_scalar_or_on_target_device(user)
+                ):
+                    # this node takes only cpu scalar tensors or gpu tensors as inputs
+                    # and outputs a gpu tensor. we can convert its cpu scalar inputs to gpu
+                    # without making further changes
+                    del cpu_indeg[user]
+                else:
+                    # otherwise, we should continue look at its downstream uses
+                    cpu_indeg[user] -= 1
+                    if cpu_indeg[user] == 0:
+                        del cpu_indeg[user]
+                        queue.append(user)
+
+                unioned_set = make_dependencies_equivalent(
+                    dependencies, constructor_dependencies[user]
+                )
+                constructor_dependencies[user] = unioned_set
+
+        for node in cpu_indeg:
+            if constructor_dependencies[node]:
+                cannot_move_to_gpu.update(constructor_dependencies[node])
+
+        all_cannot_move_to_gpu = cannot_move_to_gpu.copy()
+        for constructor in cannot_move_to_gpu:
+            all_cannot_move_to_gpu.update(equal_constructor_sets[constructor])
+
+        return OrderedSet(constructors) - all_cannot_move_to_gpu
+
+
+def move_constructors_to_gpu(graph: fx.Graph) -> None:
+    """
+    Moves intermediary tensors which are constructed on the cpu to gpu when safe
+    """
+
+    # cudagraph does not support cpu tensors. In this pass, we update the graph
+    # by explicitly moving cpu scalar tensors to gpu when profitable, relying on
+    # graph partition to split off this data copy, and cudagraphifying
+    # the remaining gpu ops.
+    allow_inputs_outputs = bool(
+        torch._inductor.config.triton.cudagraphs
+        and torch._inductor.config.graph_partition
+    )
+    ConstructorMoverPass(
+        get_gpu_type(),
+        allow_inputs=allow_inputs_outputs,
+        allow_outputs=allow_inputs_outputs,
+    )(graph)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/pre_grad.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/pre_grad.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fd81f9b8cd57a1384806191ff7b8338f7b36807
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/pre_grad.py
@@ -0,0 +1,877 @@
+# mypy: allow-untyped-defs
+import copy
+import functools
+import itertools
+import logging
+import types
+from collections.abc import Sequence
+
+import torch
+import torch.nn as nn
+from torch._dynamo.utils import counters, detect_fake_mode
+from torch._logging import trace_structured
+from torch.fx.experimental.optimization import (
+    matches_module_pattern,
+    replace_node_module,
+)
+from torch.fx.passes.graph_transform_observer import (
+    GraphTransformObserver as GraphTransformObserverBase,
+)
+from torch.fx.passes.shape_prop import ShapeProp
+from torch.nn import functional as F
+from torch.nn.utils.fusion import fuse_conv_bn_eval, fuse_conv_bn_weights
+
+from .. import config
+from ..fx_utils import matches_module_function_pattern
+from ..pattern_matcher import (
+    init_once_fakemode,
+    PatternMatcherPass as PatternMatcherPassBase,
+    stable_topological_sort,
+)
+from ..utils import is_cpu_device, pass_execution_and_save
+from .group_batch_fusion import group_batch_fusion_passes, PRE_GRAD_FUSIONS
+from .misc_patterns import numpy_compat_normalization
+from .split_cat import PRE_GRAD_PATTERNS
+
+
+PatternMatcherPass = functools.partial(
+    PatternMatcherPassBase, subsystem="pre_grad_passes"
+)
+GraphTransformObserver = functools.partial(
+    GraphTransformObserverBase, subsystem="pre_grad_passes"
+)
+
+log = logging.getLogger(__name__)
+
+efficient_conv_bn_eval_pass = PatternMatcherPass(
+    pass_name="efficient_conv_bn_eval_pass"
+)
+
+fuse_split_linear_add_pass = PatternMatcherPass(
+    pass_name="fuse_split_linear_add_pass",
+)
+fuse_chunk_squeeze_cat_pass = PatternMatcherPass(
+    pass_name="fuse_chunk_squeeze_cat_pass",
+)
+remove_reshape_pass = PatternMatcherPass(
+    pass_name="remove_reshape_pass",
+)
+
+# based on predispatch aten IR
+normalization_pass_aten = PatternMatcherPass(pass_name="normalization_pass_aten")
+merge_splits_pass_aten = PatternMatcherPass(pass_name="merge_splits_pass_aten")
+split_cat_pass_aten = PatternMatcherPass(pass_name="split_cat_pass_aten")
+unbind_stack_pass_aten = PatternMatcherPass(pass_name="unbind_stack_pass_aten")
+merge_getitem_cat_pass_aten = PatternMatcherPass(
+    pass_name="merge_getitem_cat_pass_aten"
+)
+merge_stack_tahn_unbind_pass_aten = PatternMatcherPass(
+    pass_name="merge_stack_tahn_unbind_pass_aten"
+)
+mutate_cat_pass_aten = PatternMatcherPass(pass_name="mutate_cat_pass_aten")
+remove_split_with_size_one_pass_aten = PatternMatcherPass(
+    pass_name="remove_split_with_size_one_pass_aten"
+)
+
+
+def save_inductor_dict(pass_to_compare=None):
+    if not pass_to_compare:
+        pass_to_compare = list(config.pre_grad_fusion_options.keys()) + list(
+            config.post_grad_fusion_options.keys()
+        )
+    return {p: dict(counters["inductor"]).get(p, 0) for p in pass_to_compare}
+
+
+def is_same_dict(inductor_dict, optimus_dict):
+    for pass_name, count in optimus_dict.items():
+        if count != dict(inductor_dict).get(pass_name, 0):
+            return False
+    return True
+
+
+def shape_prop(mod) -> None:
+    return None
+
+
+def normalize_node_kwargs_pass(graph):
+    return None
+
+
+def fuse_parallel_linear_pass(graph):
+    return None
+
+
+def remove_split_ops(graph, shape_prop):
+    return None
+
+
+def remove_split_ops_pass(graph):
+    remove_split_ops(graph.owning_module, shape_prop)
+
+
+def fuse_chunk_reshape_unsqueeze_concat_pass(graph):
+    return None
+
+
+def fuse_chunk_reshape_concat_pass(graph):
+    return None
+
+
+def remove_noop_pass(graph):
+    return None
+
+
+def stack_to_unsqueeze_pass(graph):
+    return None
+
+
+def merge_concats_pass(graph):
+    return None
+
+
+def relu_nan_to_num(graph):
+    return None
+
+
+def fuse_split_getitem_squeeze_cat(graph):
+    return None
+
+
+def use_triton_dot_compress(graph):
+    return None
+
+
+def use_triton_lce_replace_simple_LCE_helper(gm, shape_prop):
+    return None
+
+
+def use_triton_lce_replace_simple_LCE(graph):
+    return use_triton_lce_replace_simple_LCE_helper(graph.owning_module, shape_prop)
+
+
+def use_triton_lce_replace_normal_LCE_helper(gm, shape_prop):
+    return None
+
+
+def use_triton_lce_replace_normal_LCE(graph):
+    return use_triton_lce_replace_simple_LCE_helper(graph.owning_module, shape_prop)
+
+
+def use_matmul_lce_replace_normal_LCE(graph):
+    return None
+
+
+def use_matmul_fuse_lce_replace_first_LCE(graph):
+    return None
+
+
+@init_once_fakemode
+def lazy_init():
+    from . import efficient_conv_bn_eval, split_cat  # noqa: F401
+
+    if config.is_fbcode():
+        from . import fb  # type: ignore[attr-defined]  # noqa: F401
+
+
+def _get_pass_name_func(p):
+    if isinstance(p, PatternMatcherPassBase):
+        pass_name = p.pass_name
+        pass_func = p.apply
+    elif isinstance(p, types.FunctionType):
+        pass_name = p.__name__.lstrip("_")
+        pass_func = p
+    else:
+        pass_name = None
+        pass_func = None
+
+    return pass_name, pass_func
+
+
+def _run_pre_dispatch_passes(
+    gm: torch.fx.GraphModule,
+    example_inputs: Sequence[object] = (),
+    add_passes: str | None = None,
+    remove_passes: str | None = None,
+) -> None:
+    # order matters
+    default_pass_list = [
+        # normalize passes, must be called as the first passes
+        normalization_pass_aten,
+        normalize_node_kwargs_pass,
+        remove_noop_pass,
+        relu_nan_to_num,
+        fuse_chunk_reshape_concat_pass,
+        group_batch_fusion_passes,
+        normalize_node_kwargs_pass,
+        fuse_chunk_squeeze_cat_pass,
+        merge_concats_pass,
+        fuse_split_linear_add_pass,
+        remove_reshape_pass,
+        fuse_parallel_linear_pass,
+        remove_split_ops_pass,
+        stack_to_unsqueeze_pass,  # run before fuse_chunk_reshape_unsqueeze_concat_pass
+        fuse_chunk_reshape_unsqueeze_concat_pass,
+    ]
+
+    full_pass_list = default_pass_list + [
+        fuse_split_getitem_squeeze_cat,
+        use_triton_dot_compress,
+        use_triton_lce_replace_simple_LCE,
+        use_triton_lce_replace_normal_LCE,
+        use_matmul_fuse_lce_replace_first_LCE,
+        use_matmul_lce_replace_normal_LCE,
+    ]
+
+    log.info(
+        f"pre_grad_passes: add_passes: {add_passes}, remove_pass: {remove_passes}"  # noqa: G004
+    )
+    add_passes_list = []
+    remove_passes_list = []
+    if add_passes:
+        add_passes_list = add_passes.split(",")
+    if remove_passes:
+        remove_passes_list = remove_passes.split(",")
+
+    shape_prop = lambda mod: ShapeProp(  # noqa: E731
+        gm=mod,
+        # pyre-fixme[16]: Module `torch._dynamo.utils` has no attribute `detect_fake_mode`
+        fake_mode=detect_fake_mode(example_inputs),
+    ).propagate(*tuple(example_inputs))
+
+    for p in default_pass_list:
+        pass_name, pass_func = _get_pass_name_func(p)
+        # should not happen
+        if pass_name is None or pass_func is None:
+            continue
+        if pass_name in remove_passes_list:
+            continue
+        pass_execution_and_save(
+            pass_func,
+            gm,
+            example_inputs,
+            f"[Pre grad(predispatch IR)] Apply {pass_name} pass",
+        )
+
+    for p in full_pass_list:
+        pass_name, pass_func = _get_pass_name_func(p)
+        if pass_name is None or pass_func is None:
+            continue
+        if pass_name in add_passes_list:
+            pass_execution_and_save(
+                pass_func,
+                gm,
+                example_inputs,
+                f"[Pre grad(predispatch IR)] Apply {pass_name} pass",
+            )
+
+    if "remove_noop" not in remove_passes_list:
+        # Remove noops at the end, which may be generated other passes.
+        pass_execution_and_save(
+            remove_noop_pass,
+            gm,
+            example_inputs,
+            "[Pre grad(predispatch IR)]Apply remove_noop pass",
+        )
+    shape_prop(gm)
+
+
+def pre_grad_passes(
+    gm: torch.fx.GraphModule,
+    example_inputs: Sequence[object] = (),
+    add_passes: str | None = None,
+    remove_passes: str | None = None,
+) -> torch.fx.GraphModule:
+    """
+    Apply passes on the input FX graph using Torch IR.
+
+    WARNING:
+    The IR before grad is not functional or normalized, so it is harder
+    to write passes on this IR.  Passes must be safe with respect to
+    aliasing and mutation and need to handle all possible arg schemas.
+
+    Consider adding a new pass to post_grad.py or joint_graph.py which
+    are after functionalization and normalization.
+    """
+    if config.pattern_matcher:
+        lazy_init()
+        if hasattr(
+            config, "fx_passes_numeric_check"
+        ) and config.fx_passes_numeric_check.get("pre_grad", False):
+            gm_before_fx_passes = gm.__copy__()
+        # explicitly run with predispatch atenIR based passes
+        if config.is_predispatch:
+            _run_pre_dispatch_passes(gm, example_inputs, add_passes, remove_passes)
+        else:
+            # We only log the graph with changes to avoid the excessive compilation time
+            # https://fb.workplace.com/groups/257735836456307/permalink/633533465543207/
+            if example_inputs is not None:
+                gm = fuse_fx(gm, example_inputs)
+            numpy_compat_normalization(gm.graph)
+            # We should always do the normalization_pass first
+            if "normalization_pass" in config.pre_grad_fusion_options:
+                pattern_matcher_pass = PRE_GRAD_PATTERNS["normalization_pass"]
+                pattern_matcher_pass.apply(gm.graph)  # type: ignore[arg-type]
+            group_batch_fusion_passes(gm.graph, pre_grad=True)
+            for pass_name in config.pre_grad_fusion_options:
+                # skip all patterns for group batch fusions
+                if pass_name in PRE_GRAD_FUSIONS or pass_name == "normalization_pass":
+                    continue
+                pattern_matcher_pass = PRE_GRAD_PATTERNS[pass_name]
+                inductor_before_change = save_inductor_dict(
+                    [pattern_matcher_pass.pass_name]
+                )
+                # we support run same pattern multiple times, the default is to run only once
+                counter = config.pre_grad_fusion_options[pass_name].get("counter", 1)
+                for _ in range(counter):
+                    pattern_matcher_pass.apply(gm.graph)  # type: ignore[arg-type]
+                if not is_same_dict(counters["inductor"], inductor_before_change):
+                    trace_structured(
+                        "artifact",
+                        metadata_fn=lambda: {
+                            "name": f"{pattern_matcher_pass.pass_name}_pre_grad",
+                            "encoding": "string",
+                        },
+                        payload_fn=lambda: gm.print_readable(
+                            print_output=False, include_stride=True, include_device=True
+                        ),
+                    )
+            # TODO: move efficient_conv_bn_eval_pass to the fusions dict too.
+            efficient_conv_bn_eval_pass.apply(gm.graph)  # type: ignore[arg-type]
+
+    if config.pre_grad_custom_pass is not None:
+        GraphTransformObserver(gm, "pre_grad_custom_pass").apply_graph_pass(
+            config.pre_grad_custom_pass
+        )
+    stable_topological_sort(gm.graph)
+
+    from .quantization import quant_lift_up
+
+    quant_lift_up(gm)
+
+    gm.graph.lint()
+    gm.recompile()
+
+    if (
+        config.pattern_matcher
+        and hasattr(config, "fx_passes_numeric_check")
+        and config.fx_passes_numeric_check.get("pre_grad", False)
+        and example_inputs is not None
+    ):
+        from .numeric_utils import numeric_check_if_enabled
+
+        gm_after_fx_passes = gm.__copy__()
+        numeric_check_if_enabled(
+            gm_before_fx_passes,  # type: ignore[possibly-undefined]
+            gm_after_fx_passes,
+            example_inputs,
+            config.fx_passes_numeric_check.get("num_iterations", 1),
+            config.fx_passes_numeric_check.get("precision", 1e-4),
+        )
+
+    return gm
+
+
+def fuse_fx(gm: torch.fx.GraphModule, example_inputs) -> torch.fx.GraphModule:
+    is_cpu = is_cpu_device(example_inputs)
+    # pyre-fixme[16]: Module `torch._dynamo.utils` has no attribute `detect_fake_mode`
+    fake_mode = detect_fake_mode(example_inputs)
+
+    gm = sink_cat_after_pointwise(gm)
+    if config.permute_fusion and not is_cpu:
+        # For linear permute fusion, we need to check input info to identify
+        # and perform proper permutation/transpose
+        ShapeProp(gm, fake_mode=fake_mode).propagate(*example_inputs)
+        with GraphTransformObserver(gm, "linear_permute_fusion"):
+            gm = linear_permute_fusion(gm)
+        with GraphTransformObserver(gm, "permute_linear_fusion"):
+            gm = permute_linear_fusion(gm)
+        with GraphTransformObserver(gm, "permute_matmul_fusion"):
+            gm = permute_matmul_fusion(gm)
+
+    # make sure the autograd is disabled.
+    if torch.is_grad_enabled() or not is_cpu:
+        return gm
+    if config.freezing:
+        with GraphTransformObserver(gm, "remove_identity"):
+            gm = remove_identity(gm)
+        with GraphTransformObserver(gm, "fuse_conv_bn"):
+            gm = fuse_conv_bn(gm)
+    return gm
+
+
+def fetch_attr(target: str, mod):
+    target_atoms = target.split(".")
+    attr_itr = mod
+    for i, atom in enumerate(target_atoms):
+        if not hasattr(attr_itr, atom):
+            raise RuntimeError(
+                f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
+            )
+        attr_itr = getattr(attr_itr, atom)
+    return attr_itr
+
+
+def remove_identity(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    """
+    Removes all identity layers from the module.
+    """
+
+    class IdentityRemover(torch.fx.Transformer):
+        def call_module(self, target, args, kwargs):
+            if isinstance(self.submodules[target], nn.Identity):
+                assert len(args) == 1
+                return args[0]
+            else:
+                return super().call_module(target, args, kwargs)
+
+    return IdentityRemover(gm).transform()
+
+
+def fuse_conv_bn(gm: torch.fx.GraphModule, inplace=False) -> torch.fx.GraphModule:
+    """
+    Fuses Convolution/BN layers for inference purposes.
+    """
+    modules_patterns = [
+        (torch.nn.Conv1d, torch.nn.BatchNorm1d),
+        (torch.nn.Conv2d, torch.nn.BatchNorm2d),
+        (torch.nn.Conv3d, torch.nn.BatchNorm3d),
+    ]
+    module_function_patterns = [
+        (torch.nn.Conv1d, F.batch_norm),
+        (torch.nn.Conv2d, F.batch_norm),
+        (torch.nn.Conv3d, F.batch_norm),
+    ]
+    modules = dict(gm.named_modules())
+
+    class ConvBNFusion:
+        def __init__(
+            self,
+            bn_node,
+            conv_module,
+            bn_module=None,  # For BN Module
+            bn_running_mean=None,  # For Functional BN
+            bn_running_var=None,
+            bn_eps=None,
+            bn_weight=None,
+            bn_bias=None,
+        ) -> None:
+            self.bn_nodes = [
+                bn_node,
+            ]
+            self.conv_module = conv_module
+            self.bn_module = bn_module
+            self.bn_running_mean = bn_running_mean
+            self.bn_running_var = bn_running_var
+            self.bn_eps = bn_eps
+            self.bn_weight = bn_weight
+            self.bn_bias = bn_bias
+            self.fusion_enabled = True
+
+        def add_bn_node(self, bn_node):
+            self.bn_nodes.append(bn_node)
+
+        def disable_fusion(self):
+            self.fusion_enabled = False
+
+        def is_fusion_enabled(self):
+            return self.fusion_enabled
+
+    conv_bn_to_fuse: dict[int, ConvBNFusion] = {}
+    for pattern in modules_patterns:
+        conv_bn_to_fuse.clear()
+        for node in gm.graph.nodes:
+            if matches_module_pattern(pattern, node, modules):
+                if len(node.args[0].users) > 1:  # Output of conv is used by other nodes
+                    continue
+                conv = modules[node.args[0].target]
+                bn = modules[node.target]
+                eval_mode = all(not n.training for n in [conv, bn])
+                if not eval_mode:
+                    continue
+                if not bn.track_running_stats:
+                    continue
+
+                # Do hash based on the module name of conv
+                hash_id = hash(node.args[0].target)
+                if hash_id not in conv_bn_to_fuse:
+                    conv_bn_to_fuse[hash_id] = ConvBNFusion(node, conv, bn)
+                else:
+                    if bn == conv_bn_to_fuse[hash_id].bn_module:
+                        # Do fusion if same bn module
+                        conv_bn_to_fuse[hash_id].add_bn_node(node)
+                    else:
+                        # Disable the conv bn folding if conv shared by different bn
+                        conv_bn_to_fuse[hash_id].disable_fusion()
+
+        for conv_bn_fusion in conv_bn_to_fuse.values():
+            if conv_bn_fusion.is_fusion_enabled():
+                bn_nodes = conv_bn_fusion.bn_nodes
+                conv = conv_bn_fusion.conv_module
+                bn = conv_bn_fusion.bn_module
+
+                # pyrefly: ignore [bad-argument-type]
+                fused_conv = fuse_conv_bn_eval(conv, bn)
+                for bn_node in bn_nodes:
+                    replace_node_module(bn_node.args[0], modules, fused_conv)
+                    bn_node.replace_all_uses_with(bn_node.args[0])
+                    gm.graph.erase_node(bn_node)
+
+    gm.graph.lint()
+    for pattern in module_function_patterns:
+        conv_bn_to_fuse.clear()
+        for node in gm.graph.nodes:
+            if matches_module_function_pattern(pattern, node, modules):
+                # TODO: support kwargs.
+                if len(node.args) != 8:
+                    continue
+                conv = modules[node.args[0].target]
+                bn_training = node.args[5]
+                bn_eps = node.args[7]
+                if conv.training or bn_training:
+                    continue
+                if type(bn_eps) is not float:
+                    continue
+
+                def _used_by_same_conv_module(users):
+                    conv_module_name = users[0].args[0].target
+                    return all(
+                        conv_module_name == user.args[0].target for user in users
+                    )
+
+                bn_args_is_constant = all(
+                    n.op == "get_attr"
+                    and (len(n.users) == 1 or _used_by_same_conv_module(list(n.users)))
+                    for n in node.args[1:5]
+                )
+                if not bn_args_is_constant:
+                    continue
+                bn_running_mean = fetch_attr(node.args[1].target, gm)
+                bn_running_var = fetch_attr(node.args[2].target, gm)
+                bn_weight = fetch_attr(node.args[3].target, gm)
+                bn_bias = fetch_attr(node.args[4].target, gm)
+                if bn_running_mean is None or bn_running_var is None:
+                    continue
+
+                # Do hash based on the module name of conv
+                hash_id = hash(node.args[0].target)
+                if hash_id not in conv_bn_to_fuse:
+                    conv_bn_to_fuse[hash_id] = ConvBNFusion(
+                        node,
+                        conv,
+                        bn_running_mean=bn_running_mean,
+                        bn_running_var=bn_running_var,
+                        bn_eps=bn_eps,
+                        bn_weight=bn_weight,
+                        bn_bias=bn_bias,
+                    )
+                else:
+                    if (
+                        hash(bn_running_mean)
+                        == hash(conv_bn_to_fuse[hash_id].bn_running_mean)
+                        and hash(bn_running_var)
+                        == hash(conv_bn_to_fuse[hash_id].bn_running_var)
+                        and torch.allclose(
+                            torch.tensor(bn_eps),
+                            torch.tensor(conv_bn_to_fuse[hash_id].bn_eps),
+                        )
+                        and hash(bn_weight) == hash(conv_bn_to_fuse[hash_id].bn_weight)
+                        and hash(bn_bias) == hash(conv_bn_to_fuse[hash_id].bn_bias)
+                    ):
+                        # Do fusion if same functional bn
+                        conv_bn_to_fuse[hash_id].add_bn_node(node)
+                    else:
+                        # Disable the conv bn folding if conv shared by different bn
+                        conv_bn_to_fuse[hash_id].disable_fusion()
+
+        for conv_bn_fusion in conv_bn_to_fuse.values():
+            if conv_bn_fusion.is_fusion_enabled():
+                bn_nodes = conv_bn_fusion.bn_nodes
+                conv = conv_bn_fusion.conv_module
+                bn_running_mean = conv_bn_fusion.bn_running_mean
+                bn_running_var = conv_bn_fusion.bn_running_var
+                bn_eps = conv_bn_fusion.bn_eps
+                bn_weight = conv_bn_fusion.bn_weight
+                bn_bias = conv_bn_fusion.bn_bias
+
+                fused_conv = copy.deepcopy(conv)
+                fused_conv.weight, fused_conv.bias = fuse_conv_bn_weights(
+                    fused_conv.weight,
+                    fused_conv.bias,
+                    # pyrefly: ignore [bad-argument-type]
+                    bn_running_mean,
+                    # pyrefly: ignore [bad-argument-type]
+                    bn_running_var,
+                    # pyrefly: ignore [bad-argument-type]
+                    bn_eps,
+                    bn_weight,
+                    bn_bias,
+                )
+                for bn_node in bn_nodes:
+                    replace_node_module(bn_node.args[0], modules, fused_conv)
+                    bn_node.replace_all_uses_with(bn_node.args[0])
+                    gm.graph.erase_node(bn_node)
+    gm.graph.lint()
+    gm.recompile()
+
+    return gm
+
+
+class NormalizedLinearNode:
+    def __init__(self, node: torch.fx.Node) -> None:
+        assert node.op == "call_function"
+        assert node.target is torch.nn.functional.linear
+        self.node: torch.fx.Node = node
+
+    def get_input(self) -> torch.fx.Node:
+        if len(self.node.args) > 0:
+            return self.node.args[0]  # type: ignore[return-value]
+        else:
+            return self.node.kwargs["input"]  # type: ignore[return-value]
+
+    def get_weight(self) -> torch.fx.Node:
+        if len(self.node.args) > 1:
+            return self.node.args[1]  # type: ignore[return-value]
+        else:
+            return self.node.kwargs["weight"]  # type: ignore[return-value]
+
+    def get_bias(self) -> torch.fx.Node:
+        if len(self.node.args) > 2:
+            return self.node.args[2]  # type: ignore[return-value]
+        else:
+            return self.node.kwargs.get("bias", None)  # type: ignore[return-value]
+
+
+class NormalizedMatmulNode:
+    def __init__(self, node: torch.fx.Node) -> None:
+        assert node.op == "call_function"
+        assert node.target in [torch.bmm, torch.matmul]
+        self.node: torch.fx.Node = node
+
+    def get_input(self) -> torch.fx.Node:
+        if len(self.node.args) > 0:
+            return self.node.args[0]  # type: ignore[return-value]
+        else:
+            return self.node.kwargs["input"]  # type: ignore[return-value]
+
+    def get_other(self) -> torch.fx.Node:
+        if len(self.node.args) > 1:
+            return self.node.args[1]  # type: ignore[return-value]
+        else:
+            return self.node.kwargs["other"]  # type: ignore[return-value]
+
+
+def check_permute(node: torch.fx.Node) -> bool:
+    ranks = len(node.meta["tensor_meta"].shape)
+    if len(node.args) > 3:
+        permutation = [node.args[i] % ranks for i in range(1, ranks + 1)]  # type: ignore[operator]
+    elif (
+        "permutation" in node.kwargs
+        and node.kwargs["permutation"] is not None
+        and len(node.kwargs["permutation"]) > 2  # type: ignore[arg-type]
+    ):
+        permutation = [i % ranks for i in node.kwargs["permutation"]]  # type: ignore[operator, union-attr]
+    else:
+        return False
+    allowed_permutation = list(range(ranks))
+    allowed_permutation[-1] = ranks - 2
+    allowed_permutation[-2] = ranks - 1
+    return permutation == allowed_permutation
+
+
+def sink_cat_after_pointwise(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    def one_user(node):
+        users = list(node.users)
+        return users[0] if len(users) == 1 else None
+
+    def is_view(node):
+        return node.op == "call_method" and node.target == "view"
+
+    def is_pointwise_unary(node):
+        ops = "call_function", "call_method"
+        pointwise = torch.relu, torch.tanh, "relu", "tanh"
+        return node.op in ops and node.target in pointwise
+
+    g = module.graph
+    for node in g.nodes:
+        if node.op != "call_function" or node.target != torch.cat:
+            continue
+
+        cat_or_view = node
+        while True:
+            user = one_user(cat_or_view)
+            if not user or not is_view(user):
+                break
+            cat_or_view = user
+
+        if user and is_pointwise_unary(user):
+            with g.inserting_before(node):
+
+                def cat_args(tensors, dim=0):
+                    return tensors, dim
+
+                tensors, dim = cat_args(*node.args, **node.kwargs)
+                new_kwargs = {
+                    name: val for name, val in user.kwargs.items() if name != "input"
+                }
+                new_tensors = [
+                    g.create_node(user.op, user.target, args=(arg,), kwargs=new_kwargs)
+                    for arg in tensors
+                ]
+                new_cat = g.create_node(
+                    "call_function", torch.cat, args=(new_tensors, dim)
+                )
+                user.replace_all_uses_with(cat_or_view)
+                node.replace_all_uses_with(new_cat)
+                g.erase_node(user)
+                g.erase_node(node)
+    g.lint()
+    module.recompile()
+    return module
+
+
+def linear_permute_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    for node in module.graph.find_nodes(op="call_method", target="permute"):
+        if check_permute(node):
+            if len(node.args) > 0:
+                input_node = node.args[0]
+            else:
+                input_node = node.kwargs["input"]
+            if (
+                input_node.op == "call_function"
+                and input_node.target is torch.nn.functional.linear
+            ):
+                normalized = NormalizedLinearNode(input_node)
+                input = normalized.get_input()
+                weight = normalized.get_weight()
+                bias = normalized.get_bias()
+                with module.graph.inserting_before(node):
+                    fused_node = module.graph.call_function(
+                        linear_transpose, args=(input, weight, bias)
+                    )
+                    node.replace_all_uses_with(fused_node)
+                    module.graph.erase_node(node)
+                    if len(input_node.users) == 0:
+                        module.graph.erase_node(input_node)
+
+    module.graph.lint()
+    module.recompile()
+    return module
+
+
+# Y1 = X * W^T + bias
+# Y2 = Y1.permute(0, 2, 1)
+# ---->
+# Y2 = (W * X^T + bias.unsqueeze(-1))^T
+def linear_transpose(
+    input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None
+) -> torch.Tensor:
+    if bias is None:
+        return torch.matmul(weight, input.transpose(-1, -2))
+    return torch.matmul(weight, input.transpose(-1, -2)) + bias.unsqueeze(-1)
+
+
+def permute_linear_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    for node in module.graph.find_nodes(
+        op="call_function", target=torch.nn.functional.linear
+    ):
+        if len(node.args) > 0:
+            input_node = node.args[0]
+        else:
+            input_node = node.kwargs["input"]
+        if (
+            input_node.op == "call_method"
+            and input_node.target == "permute"
+            and check_permute(input_node)
+        ):
+            normalized = NormalizedLinearNode(node)
+            if len(input_node.args) > 0:
+                input = input_node.args[0]
+            else:
+                input = input_node.kwargs["input"]
+            weight = normalized.get_weight()
+            bias = normalized.get_bias()
+            with module.graph.inserting_before(node):
+                fused_node = module.graph.call_function(
+                    transpose_linear, args=(input, weight, bias)
+                )
+                node.replace_all_uses_with(fused_node)
+                module.graph.erase_node(node)
+                if len(input_node.users) == 0:
+                    module.graph.erase_node(input_node)
+
+    module.graph.lint()
+    module.recompile()
+    return module
+
+
+def permute_matmul_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    for node in itertools.chain(
+        module.graph.find_nodes(op="call_function", target=torch.bmm),
+        module.graph.find_nodes(op="call_function", target=torch.matmul),
+    ):
+        normalized = NormalizedMatmulNode(node)
+        input_A_node = normalized.get_input()
+        input_B_node = normalized.get_other()
+        input_A = input_A_node
+        input_B = input_B_node
+        Atrans = Btrans = False
+        if (
+            input_A_node.op == "call_method"
+            and input_A_node.target == "permute"
+            and check_permute(input_A_node)
+        ):
+            Atrans = True
+            if len(input_A_node.args) > 0:
+                input_A = input_A_node.args[0]  # type: ignore[assignment]
+            else:
+                input_A = input_A_node.kwargs["input"]  # type: ignore[assignment]
+
+        if (
+            input_B_node.op == "call_method"
+            and input_B_node.target == "permute"
+            and check_permute(input_B_node)
+        ):
+            Btrans = True
+            if len(input_B_node.args) > 0:
+                input_B = input_B_node.args[0]  # type: ignore[assignment]
+            else:
+                input_B = input_B_node.kwargs["input"]  # type: ignore[assignment]
+
+        if Atrans or Btrans:
+            with module.graph.inserting_before(node):
+                fused_node = module.graph.call_function(
+                    transpose_matmul,
+                    args=(input_A, input_B, Atrans, Btrans),
+                )
+            node.replace_all_uses_with(fused_node)
+            module.graph.erase_node(node)
+            if Atrans and len(input_A_node.users) == 0:
+                module.graph.erase_node(input_A_node)
+            if Btrans and len(input_B_node.users) == 0:
+                module.graph.erase_node(input_B_node)
+
+    module.graph.lint()
+    module.recompile()
+    return module
+
+
+# X1 = X.permute(0, 2, 1)
+# Y1 = X1 * W1^T + bias1
+# ---->
+# Y2 = X1.transpose(-1, -2) * W1^T + bias1
+def transpose_linear(
+    input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None
+) -> torch.Tensor:
+    if bias is None:
+        return torch.matmul(input.transpose(-1, -2), weight.t())
+    return torch.matmul(input.transpose(-1, -2), weight.t()) + bias
+
+
+def transpose_matmul(
+    A: torch.Tensor, B: torch.Tensor, Atrans: bool, Btrans: bool
+) -> torch.Tensor:
+    if Atrans:
+        A = A.transpose(-1, -2)
+    if Btrans:
+        B = B.transpose(-1, -2)
+    return torch.matmul(A, B)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/quantization.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/quantization.py
new file mode 100644
index 0000000000000000000000000000000000000000..951a62acf227610007db48a9aae1aa6795d01ee8
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/quantization.py
@@ -0,0 +1,3968 @@
+# mypy: allow-untyped-decorators
+# mypy: allow-untyped-defs
+import copy
+import functools
+import itertools
+import math
+import operator
+from typing import Any
+
+import torch
+from torch._dynamo.utils import counters
+from torch.fx.experimental.symbolic_shapes import has_free_symbols
+from torch.fx.node import map_arg
+
+from .. import config
+from ..lowering import lowerings as L, require_channels_last
+from ..pattern_matcher import (
+    Arg,
+    CallFunction,
+    filter_nodes,
+    KeywordArg,
+    ListOf,
+    Match,
+    stable_topological_sort,
+)
+from ..utils import pad_listlike
+from .freezing_patterns import register_freezing_graph_pattern
+from .post_grad import register_lowering_pattern
+
+
+aten = torch.ops.aten
+prims = torch.ops.prims
+quantized_decomposed = torch.ops.quantized_decomposed
+quantized = torch.ops.quantized
+
+# Only for per tensor quant since permute may changes the channel idx
+_PER_TENSOR_QUANTIZE_OPS = [
+    quantized_decomposed.quantize_per_tensor.default,
+    quantized_decomposed.quantize_per_tensor.tensor,
+]
+
+_VIEW_OPS = [
+    aten.transpose.int,
+    aten.permute.default,
+    aten.view.default,
+    aten.reshape.default,
+]
+
+"""
+The quantization.py file primarily incorporates passes related to quantization fusion
+in inductor, includes:
+1. Dequant Promotion;
+2. Conv/GEMM weight prepack with oneDNN Library;
+3. Conv/GEMM quantization fusion with output quant node (if have);
+4. Other pointwise operators' quantization fusion like: qmaxpool2d, qcat and more;
+
+It also involves int8-mixed-fp32 and int8-mixed-bf16 quantization. The main difference
+of patterns for int8-mixed-bf16, comparing with int8-mixed-fp32, is
+1. There is to(dtype=torch.bfloat16) node at the inputs of activation and weight for Conv/GEMM.
+2. There is to(dtype=torch.float32) node at the outputs of Conv/GEMM before inputs to next quant node.
+Refer to: https://github.com/pytorch/pytorch/issues/111640 for detail design of int8-mixed-bf16
+quantization.
+"""
+
+
+def _get_pattern_output_dtype(match: Match):
+    """
+    Get the pattern's output dtype from node's meta
+    Assume only 1 output node in this matched pattern.
+    """
+    pattern_output_nodes = match.output_nodes()
+    assert len(pattern_output_nodes) == 1
+    output_node = pattern_output_nodes[0]
+    assert isinstance(output_node, torch.fx.Node)
+    output_dtype = output_node.meta["val"].dtype
+    assert output_dtype in [
+        torch.int8,
+        torch.uint8,
+        torch.float32,
+        torch.bfloat16,
+        torch.float8_e4m3fn,
+    ]
+    return output_dtype
+
+
+def _may_generate_pattern_with_dtype_convert(
+    pattern, dtype=Arg(), with_dtype_convert=True, users=1
+):
+    if with_dtype_convert:
+        return CallFunction(
+            prims.convert_element_type.default,
+            pattern,
+            dtype,
+            _users=users,
+        )
+    else:
+        return pattern
+
+
+def _may_generate_pattern_with_reshape(pattern, reshape_size=Arg(), with_reshape=True):
+    if with_reshape:
+        return CallFunction(
+            torch.ops.aten.reshape.default,
+            pattern,
+            reshape_size,
+        )
+    else:
+        return pattern
+
+
+def _generate_linear_t_pattern(
+    _dequant_per_channel_pattern,
+    dtype,
+):
+    assert dtype in [torch.float32, torch.bfloat16]
+    t_pattern = CallFunction(
+        aten.permute.default,
+        _may_generate_pattern_with_dtype_convert(
+            _dequant_per_channel_pattern,
+            KeywordArg("autocast_wgt_dtype"),
+            dtype == torch.bfloat16,
+        ),
+        KeywordArg("permute_axes"),
+    )
+    return t_pattern
+
+
+def _unary_fusion_pattern(unary_fusion, call_fn, users, is_bf16):
+    # only insert to_dtype if is_bf16 is True
+    computation_call = _may_generate_pattern_with_dtype_convert(
+        call_fn, dtype=KeywordArg("to_float"), with_dtype_convert=is_bf16, users=users
+    )
+    return unary_fusion(computation_call)
+
+
+def get_dequantize_per_tensor_activation_pattern(is_tensor_overload=False):
+    dequantize_per_tensor_activation_pattern = CallFunction(
+        quantized_decomposed.dequantize_per_tensor.tensor
+        if is_tensor_overload
+        else quantized_decomposed.dequantize_per_tensor.default,
+        KeywordArg("x"),
+        KeywordArg("x_scale"),
+        KeywordArg("x_zp"),
+        KeywordArg("x_quant_min"),
+        KeywordArg("x_quant_max"),
+        KeywordArg("x_dq_dtype"),
+    )
+    return dequantize_per_tensor_activation_pattern
+
+
+dequantize_per_channel_weight_pattern = CallFunction(
+    quantized_decomposed.dequantize_per_channel.default,
+    KeywordArg("q_weight"),
+    KeywordArg("w_scale"),
+    KeywordArg("w_zp"),
+    KeywordArg("w_axis"),
+    KeywordArg("w_quant_min"),
+    KeywordArg("w_quant_max"),
+    KeywordArg("w_dtype"),
+)
+
+dequantize_per_channel_to_bf16_weight_pattern = (
+    _may_generate_pattern_with_dtype_convert(
+        dequantize_per_channel_weight_pattern,
+        KeywordArg("autocast_wgt_dtype"),
+    )
+)
+
+dequantize_per_channel_clone_weight_pattern = CallFunction(
+    aten.clone.default,
+    dequantize_per_channel_weight_pattern,
+    memory_format=KeywordArg("memory_format"),
+)
+
+dequantize_per_channel_to_bf16_clone_weight_pattern = CallFunction(
+    aten.clone.default,
+    dequantize_per_channel_to_bf16_weight_pattern,
+    memory_format=KeywordArg("memory_format"),
+)
+
+
+def get_qconv_pt2e_pattern(x_scale_zp_are_tensors=False, users=1):
+    qconv_op = (
+        torch.ops.onednn.qconv_pointwise.tensor
+        if x_scale_zp_are_tensors
+        else torch.ops.onednn.qconv_pointwise.default
+    )
+    return CallFunction(
+        qconv_op,
+        KeywordArg("x"),
+        KeywordArg("x_scale"),
+        KeywordArg("x_zp"),
+        KeywordArg("packed_weight"),
+        KeywordArg("w_scale"),
+        KeywordArg("w_zp"),
+        KeywordArg("b"),
+        KeywordArg("stride"),
+        KeywordArg("padding"),
+        KeywordArg("dilation"),
+        KeywordArg("groups"),
+        KeywordArg("output_scale"),
+        KeywordArg("output_zero_point"),
+        KeywordArg("output_dtype"),
+        KeywordArg("postop_name"),
+        KeywordArg("postop_args"),
+        KeywordArg("postop_algorithm"),
+        _users=users,
+    )
+
+
+def get_qconv2d_binary_pt2e_pattern(x_scale_zp_are_tensors=False, users=1):
+    qconv_op = (
+        torch.ops.onednn.qconv2d_pointwise.binary_tensor
+        if x_scale_zp_are_tensors
+        else torch.ops.onednn.qconv2d_pointwise.binary
+    )
+    return CallFunction(
+        qconv_op,
+        KeywordArg("x"),
+        KeywordArg("x_scale"),
+        KeywordArg("x_zp"),
+        KeywordArg("packed_weight"),
+        KeywordArg("w_scale"),
+        KeywordArg("w_zp"),
+        KeywordArg("accum"),
+        KeywordArg("b"),
+        KeywordArg("stride"),
+        KeywordArg("padding"),
+        KeywordArg("dilation"),
+        KeywordArg("groups"),
+        KeywordArg("output_scale"),
+        KeywordArg("output_zero_point"),
+        KeywordArg("output_dtype"),
+        KeywordArg("accum_scale"),
+        KeywordArg("accum_zero_point"),
+        KeywordArg("binary_op_name"),
+        KeywordArg("alpha"),
+        KeywordArg("unary_op_name"),
+        KeywordArg("unary_op_args"),
+        KeywordArg("unary_op_algorithm"),
+        _users=users,
+    )
+
+
+def get_qlinear_pt2e_pattern(x_scale_zp_are_tensors, users=1):
+    qlinear_op = (
+        torch.ops.onednn.qlinear_pointwise.tensor
+        if x_scale_zp_are_tensors
+        else torch.ops.onednn.qlinear_pointwise.default
+    )
+    return CallFunction(
+        qlinear_op,
+        KeywordArg("x"),
+        KeywordArg("x_scale"),
+        KeywordArg("x_zp"),
+        KeywordArg("packed_weight"),
+        KeywordArg("w_scale"),
+        KeywordArg("w_zp"),
+        KeywordArg("b"),
+        KeywordArg("output_scale"),
+        KeywordArg("output_zero_point"),
+        KeywordArg("output_dtype"),
+        KeywordArg("postop_name"),
+        KeywordArg("postop_args"),
+        KeywordArg("postop_algorithm"),
+        _users=users,
+    )
+
+
+def get_qlinear_binary_pt2e_pattern(x_scale_zp_are_tensors, users=1):
+    qlinear_op = (
+        torch.ops.onednn.qlinear_pointwise.binary_tensor
+        if x_scale_zp_are_tensors
+        else torch.ops.onednn.qlinear_pointwise.binary
+    )
+    return CallFunction(
+        qlinear_op,
+        KeywordArg("x"),
+        KeywordArg("x_scale"),
+        KeywordArg("x_zp"),
+        KeywordArg("packed_weight"),
+        KeywordArg("w_scale"),
+        KeywordArg("w_zp"),
+        KeywordArg("x_2"),
+        KeywordArg("b"),
+        KeywordArg("output_scale"),
+        KeywordArg("output_zero_point"),
+        KeywordArg("output_dtype"),
+        KeywordArg("x2_scale"),
+        KeywordArg("x2_zp"),
+        KeywordArg("binary_op_name"),
+        KeywordArg("alpha"),
+        KeywordArg("unary_op_name"),
+        KeywordArg("unary_op_args"),
+        KeywordArg("unary_op_algorithm"),
+        _users=users,
+    )
+
+
+dequantize_accum_pattern = CallFunction(
+    quantized_decomposed.dequantize_per_tensor.default,
+    KeywordArg("accum"),
+    KeywordArg("accum_scale"),
+    KeywordArg("accum_zp"),
+    Arg(),
+    Arg(),
+    KeywordArg("accum_dq_dtype"),
+)
+
+
+def generate_pattern_with_binary(
+    binary_post_op,
+    computation_call,
+    extra_input_pattern,
+    dtype_convert=False,
+    swap_inputs=False,
+):
+    binary_pattern = (
+        CallFunction(
+            binary_post_op,
+            extra_input_pattern,
+            computation_call,
+        )
+        if swap_inputs
+        else CallFunction(
+            binary_post_op,
+            computation_call,
+            extra_input_pattern,
+        )
+    )
+    return _may_generate_pattern_with_dtype_convert(
+        binary_pattern,
+        KeywordArg("convert_dtype_after_inplace_add"),
+        dtype_convert,
+    )
+
+
+def generate_pattern_with_unary(computation_call, unary_post_op):
+    if unary_post_op is not None:
+        return CallFunction(
+            unary_post_op,
+            computation_call,
+        )
+    return computation_call
+
+
+def generate_pattern_with_output_quant(computation_call, with_dtype_convert=False):
+    quantized_op_output_pattern_pt2e = CallFunction(
+        quantized_decomposed.quantize_per_tensor.default,
+        _may_generate_pattern_with_dtype_convert(
+            computation_call,
+            Arg(),
+            with_dtype_convert,
+        ),
+        KeywordArg("o_inv_scale"),
+        KeywordArg("o_zp"),
+        KeywordArg("o_qmin"),
+        KeywordArg("o_qmax"),
+        KeywordArg("o_dtype"),
+    )
+    return quantized_op_output_pattern_pt2e
+
+
+def _check_node_kwarg_arg_value(check_node, kwarg_name, args_index, expected_value):
+    if kwarg_name in check_node.kwargs:
+        actual_value = check_node.kwargs[kwarg_name]
+        return actual_value == expected_value
+    else:
+        assert len(check_node.args) >= (args_index + 1)
+        actual_value = check_node.args[args_index]
+        return actual_value == expected_value
+
+
+def _is_valid_quantized_conv_optimization_pattern():
+    def fn(match):
+        output_dtype = _get_pattern_output_dtype(match)
+        if output_dtype in [torch.float32, torch.bfloat16]:
+            # Only keep matched pattern with same output_dtype
+            qconv_node_after_weight_prepack = filter_nodes(
+                match.nodes, torch.ops.onednn.qconv_pointwise
+            )[0]
+            return _check_node_kwarg_arg_value(
+                qconv_node_after_weight_prepack, "output_dtype", 13, output_dtype
+            )
+        return True
+
+    return fn
+
+
+def _is_valid_qconv_post_op_fusion_pattern(has_binary_post_op=False):
+    return (
+        _is_valid_qconv_binary_optimization_pattern()
+        if has_binary_post_op
+        else _is_valid_quantized_conv_optimization_pattern()
+    )
+
+
+def _is_valid_qconv_lowering_pattern():
+    def fn(match):
+        if len(match.nodes) != 1:
+            return False
+        return match.nodes[0].target in (
+            torch.ops.onednn.qconv_pointwise.default,
+            torch.ops.onednn.qconv_pointwise.tensor,
+            torch.ops.onednn.qconv2d_pointwise.binary,
+            torch.ops.onednn.qconv2d_pointwise.binary_tensor,
+        )
+
+    return fn
+
+
+def _register_quantized_conv_lowering(
+    pattern,
+    pass_number,
+    computation_op,
+):
+    @register_lowering_pattern(
+        pattern,
+        extra_check=_is_valid_qconv_lowering_pattern(),
+        pass_number=pass_number,
+    )
+    def qconv(match: Match, *args, **kwargs):
+        # Activation QParams
+        x, x_scale, x_zp = (
+            kwargs["x"],
+            kwargs["x_scale"],
+            kwargs["x_zp"],
+        )
+        # Weight QParams
+        packed_weight, w_scale, w_zp = (
+            kwargs["packed_weight"],
+            kwargs["w_scale"],
+            kwargs["w_zp"],
+        )
+        # Conv Params
+        b, stride, padding, dilation, groups = (
+            kwargs["b"],
+            kwargs["stride"],
+            kwargs["padding"],
+            kwargs["dilation"],
+            kwargs["groups"],
+        )
+        output_dtype = _get_pattern_output_dtype(match)
+        assert output_dtype in [
+            torch.int8,
+            torch.uint8,
+            torch.float8_e4m3fn,
+            torch.float32,
+            torch.bfloat16,
+        ]
+        # Output QParams
+        o_inv_scale = kwargs["output_scale"]
+        o_zero_point = kwargs["output_zero_point"]
+        output_dtype = kwargs["output_dtype"]
+        # post op
+        postop_name = kwargs["postop_name"]
+        postop_args = kwargs["postop_args"]
+        postop_algorithm = kwargs["postop_algorithm"]
+
+        computation_args = (
+            x,
+            x_scale,
+            x_zp,
+            packed_weight,
+            w_scale,
+            w_zp,
+            b,
+            stride,
+            padding,
+            dilation,
+            groups,
+            o_inv_scale,
+            o_zero_point,
+            output_dtype,
+            postop_name,
+            postop_args,
+            postop_algorithm,
+        )
+        counters["inductor"]["qconv_unary_lower_count"] += 1
+        counters["inductor"]["qconv_unary_lower_nodes"] += len(match.nodes)
+        return L[computation_op](*computation_args)
+
+    return qconv
+
+
+def _is_valid_quantized_linear_optimization_pattern():
+    def fn(match):
+        output_dtype = _get_pattern_output_dtype(match)
+        if output_dtype in [torch.float32, torch.bfloat16]:
+            # Only keep matched pattern with same output_dtype
+            qlinear_node_after_weight_prepack = filter_nodes(
+                match.nodes, torch.ops.onednn.qlinear_pointwise
+            )[0]
+            return _check_node_kwarg_arg_value(
+                qlinear_node_after_weight_prepack, "output_dtype", 9, output_dtype
+            )
+        return True
+
+    return fn
+
+
+def _is_valid_qlinear_post_op_fusion_pattern(has_binary_post_op=False):
+    return (
+        _is_valid_qlinear_binary_optimization_pattern()
+        if has_binary_post_op
+        else _is_valid_quantized_linear_optimization_pattern()
+    )
+
+
+def _is_valid_qlinear_lowering_pattern():
+    def fn(match):
+        if len(match.nodes) != 1:
+            return False
+        return match.nodes[0].target in (
+            torch.ops.onednn.qlinear_pointwise.default,
+            torch.ops.onednn.qlinear_pointwise.tensor,
+            torch.ops.onednn.qlinear_pointwise.binary,
+            torch.ops.onednn.qlinear_pointwise.binary_tensor,
+        )
+
+    return fn
+
+
+def _register_quantized_linear_unary_lowering(
+    pattern,
+    pass_number,
+    computation_op,
+):
+    @register_lowering_pattern(
+        pattern,
+        extra_check=_is_valid_qlinear_lowering_pattern(),
+        pass_number=pass_number,
+    )
+    def qlinear(match: Match, *args, **kwargs):
+        output_dtype = _get_pattern_output_dtype(match)
+        # Activation QParams
+        x, x_scale, x_zp = (
+            kwargs["x"],
+            kwargs["x_scale"],
+            kwargs["x_zp"],
+        )
+        # Weight QParams
+        packed_weight, w_scale, w_zp = (
+            kwargs["packed_weight"],
+            kwargs["w_scale"],
+            kwargs["w_zp"],
+        )
+
+        # bias
+        b = kwargs.get("b")
+
+        # Output QParams
+        o_inv_scale = kwargs["output_scale"]
+        o_zero_point = kwargs["output_zero_point"]
+
+        # post op
+        postop_name = kwargs["postop_name"]
+        postop_args = kwargs["postop_args"]
+        postop_algorithm = kwargs["postop_algorithm"]
+
+        computation_args = (
+            x,
+            x_scale,
+            x_zp,
+            packed_weight,
+            w_scale,
+            w_zp,
+            b,
+            o_inv_scale,
+            o_zero_point,
+            output_dtype,
+            postop_name,
+            postop_args,
+            postop_algorithm,
+        )
+        counters["inductor"]["qlinear_unary_lower_count"] += 1
+        counters["inductor"]["qlinear_unary_lower_nodes"] += len(match.nodes)
+        return L[computation_op](*computation_args)
+
+    return qlinear
+
+
+def _register_quantized_linear_binary_lowering(
+    pattern,
+    pass_number,
+    computation_op,
+):
+    @register_lowering_pattern(
+        pattern,
+        extra_check=_is_valid_qlinear_lowering_pattern(),
+        pass_number=pass_number,
+    )
+    def qlinear_binary(match: Match, *args, **kwargs):
+        output_dtype = _get_pattern_output_dtype(match)
+        assert output_dtype is not None
+        # Activation QParams
+        x, x_scale, x_zp = (
+            kwargs["x"],
+            kwargs["x_scale"],
+            kwargs["x_zp"],
+        )
+        x2 = kwargs["x_2"]
+        x2_scale = kwargs["x2_scale"]
+        x2_zp = kwargs["x2_zp"]
+        # Weight QParams
+        packed_weight, w_scale, w_zp = (
+            kwargs["packed_weight"],
+            kwargs["w_scale"],
+            kwargs["w_zp"],
+        )
+        # bias
+        b = kwargs.get("b")
+        # Output QParams
+        o_inv_scale = kwargs["output_scale"]
+        o_zero_point = kwargs["output_zero_point"]
+
+        x2.realize()
+        from .mkldnn_fusion import _qlinear_binary_can_be_inplace
+
+        binary_op_name = kwargs["binary_op_name"]
+        alpha = kwargs["alpha"]
+        unary_op_name = kwargs["unary_op_name"]
+        unary_op_args = kwargs["unary_op_args"]
+        unary_op_algorithm = kwargs["unary_op_algorithm"]
+        if (
+            # TODO Ensure sum is safe and remove such check, i.e.,
+            # x2 is not used by other operations
+            # or current qlinear sum is the last user of x2.
+            # This needs to be ensured when registering
+            # the lowering pattern of quantized_linear_binary.
+            binary_op_name == "sum" and (not _qlinear_binary_can_be_inplace(x2))
+        ):
+            binary_op_name = "add"
+
+        computation_args = (
+            x,
+            x_scale,
+            x_zp,
+            packed_weight,
+            w_scale,
+            w_zp,
+            x2,
+            b,
+            o_inv_scale,
+            o_zero_point,
+            output_dtype,
+            x2_scale,
+            x2_zp,
+            binary_op_name,
+            alpha,
+            unary_op_name,
+            unary_op_args,
+            unary_op_algorithm,
+        )
+        counters["inductor"]["qlinear_binary_lower_count"] += 1
+        counters["inductor"]["qlinear_binary_lower_nodes"] += len(match.nodes)
+        return L[computation_op](*computation_args)
+
+    return qlinear_binary
+
+
+def _is_valid_qconv_binary_optimization_pattern():
+    return _is_valid_quantized_op_binary_optimization_pattern(
+        torch.ops.onednn.qconv_pointwise
+    )
+
+
+def _is_valid_qlinear_binary_optimization_pattern():
+    return _is_valid_quantized_op_binary_optimization_pattern(
+        torch.ops.onednn.qlinear_pointwise,
+        # we don't insert q-dq for extra input due to accuracy issues
+        extra_input_from_dequant=False,
+    )
+
+
+def _is_valid_quantized_op_binary_optimization_pattern(
+    qop, extra_input_from_dequant=True
+):
+    # Check if it's a valid Binary Pattern for qconv2d and qlinear:
+    # * qop_pointwise should only has one users
+    # * If extra_input_from_dequant is True, extra input of binary node should come from dequant pattern
+    # * the two inputs of binary node should have attribute "meta" and should be tensors
+    # * the two inputs of binary node should have the same shape
+    # * All users of the extra input in this pattern should be
+    #   ancestor nodes of the compute node, except for the binary node
+    #   connected to the compute node.
+    def fn(match):
+        output_dtype = _get_pattern_output_dtype(match)
+        compute_node = filter_nodes(match.nodes, qop)[0]
+        # qop_pointwise should only have one user
+        if len(compute_node.users) != 1:
+            return False
+        binary_node_inputs = next(iter(compute_node.users)).args
+        assert len(binary_node_inputs) == 2, "Expects binary node with 2 inputs"
+        if output_dtype in [torch.float32, torch.bfloat16]:
+            extra_input_of_binary_node = None
+            for arg in binary_node_inputs:
+                if arg != compute_node:
+                    extra_input_of_binary_node = arg
+                    break
+            assert extra_input_of_binary_node is not None
+            # Extra input of binary node comes from dequant pattern
+            if extra_input_from_dequant and (
+                (not isinstance(extra_input_of_binary_node, torch.fx.Node))
+                or (
+                    extra_input_of_binary_node.target
+                    != quantized_decomposed.dequantize_per_tensor.default
+                )
+            ):
+                return False
+
+        # the two inputs of binary node should have attribute "meta" and should be tensors
+        if not (
+            hasattr(binary_node_inputs[0], "meta")
+            and isinstance(binary_node_inputs[0].meta.get("val", None), torch.Tensor)  # type: ignore[union-attr]
+        ) or not (
+            hasattr(binary_node_inputs[1], "meta")
+            and isinstance(binary_node_inputs[1].meta.get("val", None), torch.Tensor)  # type: ignore[union-attr]
+        ):
+            return False
+        # the two inputs of binary node should have the same shape
+        if (
+            binary_node_inputs[0].meta["val"].size()  # type: ignore[union-attr]
+            != binary_node_inputs[1].meta["val"].size()  # type: ignore[union-attr]
+        ):
+            return False
+
+        # All users of the extra input in this pattern should be
+        # ancestor nodes of the compute node, except for the binary node
+        # connected to the compute node.
+
+        from .mkldnn_fusion import _get_remaining_users
+
+        extra_input_of_pattern = (
+            match.kwargs["other"]
+            if "other" in match.kwargs
+            else (
+                match.kwargs["accum"]
+                if (output_dtype in [torch.uint8, torch.int8])
+                or (not extra_input_from_dequant)
+                else match.kwargs["accum_after_dequant"]
+            )
+        )
+        if (
+            len(_get_remaining_users(extra_input_of_pattern, compute_node)) > 1
+            or extra_input_of_pattern == compute_node.args[0]
+        ):
+            return False
+        return True
+
+    return fn
+
+
+def _register_quantized_conv_binary_lowering(
+    pattern,
+    pass_number,
+    computation_op,
+):
+    @register_lowering_pattern(
+        pattern,
+        extra_check=_is_valid_qconv_lowering_pattern(),
+        pass_number=pass_number,
+    )
+    def qconv_binary(match: Match, *args, **kwargs):
+        output_dtype = _get_pattern_output_dtype(match)
+        assert output_dtype is not None
+        x, x_scale, x_zp = kwargs["x"], kwargs["x_scale"], kwargs["x_zp"]
+        accum = kwargs["accum"]
+        accum_scale = kwargs["accum_scale"]
+        accum_zp = kwargs["accum_zero_point"]
+        packed_weight, w_scale, w_zp = (
+            kwargs["packed_weight"],
+            kwargs["w_scale"],
+            kwargs["w_zp"],
+        )
+        b, stride, padding, dilation, groups = (
+            kwargs["b"],
+            kwargs["stride"],
+            kwargs["padding"],
+            kwargs["dilation"],
+            kwargs["groups"],
+        )
+        # Output QParams
+        output_scale = kwargs["output_scale"]
+        output_zero_point = kwargs["output_zero_point"]
+
+        # post ops
+        binary_op_name = kwargs["binary_op_name"]
+        alpha = kwargs["alpha"]
+        unary_op_name = kwargs["unary_op_name"]
+        unary_op_args = kwargs["unary_op_args"]
+        unary_op_algorithm = kwargs["unary_op_algorithm"]
+
+        accum.realize()
+        from .mkldnn_fusion import _can_be_inplace
+
+        assert _can_be_inplace(accum), (
+            "QConv Binary Inplace Fusion requires accum is not an alias or mutation."
+        )
+
+        computation_args = (
+            x,
+            x_scale,
+            x_zp,
+            packed_weight,
+            w_scale,
+            w_zp,
+            accum,
+            b,
+            stride,
+            padding,
+            dilation,
+            groups,
+            output_scale,
+            output_zero_point,
+            output_dtype,
+            accum_scale,
+            accum_zp,
+            binary_op_name,
+            alpha,
+            unary_op_name,
+            unary_op_args,
+            unary_op_algorithm,
+        )
+        counters["inductor"]["qconv2d_binary_lower_count"] += 1
+        counters["inductor"]["qconv2d_binary_lower_nodes"] += len(match.nodes)
+        return L[computation_op](*computation_args)
+
+    return qconv_binary
+
+
+def _register_quantization_unary_lowering():
+    # QConv2d
+    for x_scale_zp_are_tensors, users in itertools.product([False, True], [1, 2]):
+        qconv_pattern = get_qconv_pt2e_pattern(x_scale_zp_are_tensors, users)
+        computation_op = (
+            torch.ops.onednn.qconv_pointwise.tensor
+            if x_scale_zp_are_tensors
+            else torch.ops.onednn.qconv_pointwise.default
+        )
+        _register_quantized_conv_lowering(
+            qconv_pattern,
+            2,  # pass_number
+            computation_op,
+        )
+
+    # QLinear
+    for x_scale_zp_are_tensors in (False, True):
+        qlinear_pattern = get_qlinear_pt2e_pattern(x_scale_zp_are_tensors)
+        computation_op = (
+            torch.ops.onednn.qlinear_pointwise.tensor
+            if x_scale_zp_are_tensors
+            else torch.ops.onednn.qlinear_pointwise.default
+        )
+        _register_quantized_linear_unary_lowering(
+            qlinear_pattern,
+            2,  # pass_number
+            computation_op,
+        )
+
+
+def _register_quantization_binary_lowering():
+    # QConv2d
+    for x_scale_zp_are_tensors, users in itertools.product([False, True], [1, 2]):
+        qconv_pattern = get_qconv2d_binary_pt2e_pattern(x_scale_zp_are_tensors, users)
+        computation_op = (
+            torch.ops.onednn.qconv2d_pointwise.binary_tensor
+            if x_scale_zp_are_tensors
+            else torch.ops.onednn.qconv2d_pointwise.binary
+        )
+        _register_quantized_conv_binary_lowering(
+            qconv_pattern,
+            2,  # pass_number
+            computation_op,
+        )
+
+    # QLinear
+    for x_scale_zp_are_tensors in (False, True):
+        qlinear_pattern = get_qlinear_binary_pt2e_pattern(x_scale_zp_are_tensors)
+        computation_op = (
+            torch.ops.onednn.qlinear_pointwise.binary_tensor
+            if x_scale_zp_are_tensors
+            else torch.ops.onednn.qlinear_pointwise.binary
+        )
+        _register_quantized_linear_binary_lowering(
+            qlinear_pattern,
+            2,  # pass_number
+            computation_op,
+        )
+
+
+def _is_valid_quantized_maxpool2d_optimization_pattern():
+    def fn(match):
+        # Only match the pattern which max_pool2d_with_indices returns value
+        # instead of indices.
+        get_item_node = filter_nodes(match.nodes, operator.getitem)[0]
+        return get_item_node.args[1] == 0
+
+    return fn
+
+
+def _register_quantized_maxpool2d_lowering(
+    pattern,
+    computation_op,
+):
+    @register_lowering_pattern(
+        pattern,
+        extra_check=_is_valid_quantized_maxpool2d_optimization_pattern(),
+    )
+    def qmaxpool2d(match: Match, *args, **kwargs):
+        x = kwargs["x"]
+        kernel_size = kwargs["kernel_size"]
+        stride = kwargs.get("stride")
+        padding = kwargs.get("padding", 0)
+        dilation = kwargs.get("dilation", 1)
+        ceil_mode = kwargs.get("ceil_mode", False)
+
+        if padding == 0:
+            padding = [0, 0]
+        if dilation == 1:
+            dilation = [1, 1]
+        if not stride:
+            stride = kernel_size
+        kernel_size = pad_listlike(kernel_size, 2)
+        stride = pad_listlike(stride, 2)
+        padding = pad_listlike(padding, 2)
+        dilation = pad_listlike(dilation, 2)
+
+        assert len(kernel_size) == 2
+        assert len(stride) == 2
+        assert len(padding) == 2
+        assert len(dilation) == 2
+
+        computation_args = (
+            x,
+            kernel_size,
+            stride,
+            padding,
+            dilation,
+            ceil_mode,
+        )
+        computation_args, _ = require_channels_last(computation_op, *computation_args)
+        counters["inductor"]["qmaxpool2d_matcher_count"] += 1
+        counters["inductor"]["qmaxpool2d_matcher_nodes"] += len(match.nodes)
+        return L[computation_op](*computation_args)
+
+    return qmaxpool2d
+
+
+def _register_quantization_maxpool2d():
+    # Currently, the default parameters are not in FX Graph generated by Dynamo export.
+    # So, if user defines nn.MaxPool2d with different assignment of default parameter,
+    # it will generate graph with different number of input nodes and hence
+    # different pattern to be matched.
+    # Refer to the issue: https://github.com/pytorch/pytorch/issues/105901
+    max_pool2d_args_list = [
+        [
+            KeywordArg("stride"),
+        ],
+        [
+            KeywordArg("stride"),
+            KeywordArg("padding"),
+        ],
+        [
+            KeywordArg("stride"),
+            KeywordArg("padding"),
+            KeywordArg("dilation"),
+        ],
+        [
+            KeywordArg("stride"),
+            KeywordArg("padding"),
+            KeywordArg("dilation"),
+            KeywordArg("ceil_mode"),
+        ],
+    ]
+    for max_pool2d_args in max_pool2d_args_list:
+        dequantize_maxpool2d_pattern = CallFunction(
+            aten.max_pool2d_with_indices.default,
+            get_dequantize_per_tensor_activation_pattern(),
+            KeywordArg("kernel_size"),
+            *max_pool2d_args,
+        )
+        dequantize_lowmem_maxpool2d_pattern = CallFunction(
+            prims._low_memory_max_pool_with_offsets.default,
+            get_dequantize_per_tensor_activation_pattern(),
+            KeywordArg("kernel_size"),
+            *max_pool2d_args,
+            KeywordArg("offset_dtype"),
+        )
+        dequantize_maxpool2d_get_item_pattern = CallFunction(
+            operator.getitem,
+            dequantize_maxpool2d_pattern,
+            Arg(),
+        )
+        dequantize_lowmem_maxpool2d_get_item_pattern = CallFunction(
+            operator.getitem,
+            dequantize_lowmem_maxpool2d_pattern,
+            Arg(),
+        )
+        _register_quantized_maxpool2d_lowering(
+            generate_pattern_with_output_quant(dequantize_maxpool2d_get_item_pattern),
+            quantized.max_pool2d.default,
+        )
+        _register_quantized_maxpool2d_lowering(
+            generate_pattern_with_output_quant(
+                dequantize_lowmem_maxpool2d_get_item_pattern
+            ),
+            quantized.max_pool2d.default,
+        )
+
+
+def _is_input_output_same_scale_zp(check_node):
+    def fn(match):
+        # Ensure all the inputs and output has same scale and zero point
+        # Step 1: Check inputs/output zero point
+        # Get dequant nodes at input
+        dequant_nodes = filter_nodes(
+            match.nodes, quantized_decomposed.dequantize_per_tensor.default
+        )
+        zero_points = [node.args[2] for node in dequant_nodes]
+        # Get quant nodes at output
+        quant_nodes = filter_nodes(
+            match.nodes, quantized_decomposed.quantize_per_tensor.default
+        )
+        assert len(quant_nodes) == 1, "expect only 1 add node at output quant pattern"
+        zero_points.append(quant_nodes[0].args[2])
+        if not all(zero_point == zero_points[0] for zero_point in zero_points):
+            return False
+
+        # Step 2: Check inputs/output scale
+        scales = [node.args[1] for node in dequant_nodes]
+        scales.append(quant_nodes[0].args[1])
+        if not all(math.isclose(scale, scales[0], rel_tol=1e-5) for scale in scales):  # type: ignore[arg-type]
+            return False
+
+        return True
+
+    return fn
+
+
+def _register_quantized_cat_lowering(
+    pattern,
+    computation_op,
+):
+    @register_lowering_pattern(
+        pattern,
+        extra_check=_is_input_output_same_scale_zp(aten.cat.default),
+    )
+    def qcat(match: Match, inputs, dim, **kwargs):
+        # inputs is with format: [[x1, x1_dq_dtype, x1_zp, x1_scale], ...]
+        uint8_inputs = [input[0] for input in inputs]
+        counters["inductor"]["qcat_matcher_count"] += 1
+        counters["inductor"]["qcat_matcher_nodes"] += len(match.nodes)
+        return L[computation_op](uint8_inputs, dim)
+
+    return qcat
+
+
+_raw_dequantize_per_tensor_activation_pattern = CallFunction(
+    quantized_decomposed.dequantize_per_tensor.default,
+    Arg(),
+    Arg(),
+    Arg(),
+    Arg(),
+    Arg(),
+    Arg(),
+)
+
+
+def _register_quantization_cat():
+    dequantize_cat_pattern = CallFunction(
+        aten.cat.default,
+        ListOf(_raw_dequantize_per_tensor_activation_pattern),
+        KeywordArg("dim"),
+    )
+    _register_quantized_cat_lowering(
+        generate_pattern_with_output_quant(dequantize_cat_pattern),
+        aten.cat,
+    )
+
+
+def _register_quantized_reshape_lowering(
+    pattern,
+    computation_op,
+):
+    @register_lowering_pattern(
+        pattern,
+        extra_check=_is_input_output_same_scale_zp(aten.reshape.default),
+    )
+    def qreshape(match: Match, *args, **kwargs):
+        qx = kwargs["x"]
+        shape = kwargs["shape"]
+        counters["inductor"]["qreshape_matcher_count"] += 1
+        counters["inductor"]["qreshape_matcher_nodes"] += len(match.nodes)
+        return L[computation_op](qx, shape)
+
+    return qreshape
+
+
+def _register_quantization_reshape():
+    dequantize_reshape_pattern = CallFunction(
+        torch.ops.aten.reshape.default,
+        get_dequantize_per_tensor_activation_pattern(),
+        KeywordArg("shape"),
+    )
+    _register_quantized_reshape_lowering(
+        generate_pattern_with_output_quant(dequantize_reshape_pattern),
+        aten.reshape,
+    )
+
+
+def _is_valid_concat_linear_int8_woq_optimization_pattern():
+    def fn(match):
+        if not config.cpp.enable_concat_linear:
+            return False
+        assert all(k in match.kwargs for k in ("x", "w1", "w2", "w3", "scales"))
+        if not all(
+            hasattr(match.kwargs[key], "meta")
+            for key in ["x", "w1", "w2", "w3", "scales"]
+        ):
+            return False
+        x = match.kwargs["x"].meta["val"]
+        w1 = match.kwargs["w1"].meta["val"]
+        w2 = match.kwargs["w2"].meta["val"]
+        w3 = match.kwargs["w3"].meta["val"]
+        scales = match.kwargs["scales"].meta["val"]
+        if len(match.kwargs["scales"].meta["val"].size()) > 1:
+            return False
+        num_scales = match.kwargs["scales"].meta["val"].numel()
+        w1_cols = match.kwargs["w1"].meta["val"].size()[0]
+        w2_cols = match.kwargs["w2"].meta["val"].size()[0]
+        w3_cols = match.kwargs["w3"].meta["val"].size()[0]
+        return (
+            # For now, we only support woq mm kernels
+            # with x.type=bfloat16 and w.type=int8
+            x.dtype == torch.bfloat16
+            and w1.dtype == torch.int8
+            and w2.dtype == torch.int8
+            and w3.dtype == torch.int8
+            and scales.dtype == torch.bfloat16
+            and x.device.type in ("cpu", "cuda")
+            and x.device == w1.device
+            and w1.device == w2.device
+            and w2.device == w3.device
+            and x.device == scales.device
+            and num_scales == w1_cols + w2_cols + w3_cols
+        )
+
+    return fn
+
+
+def _is_valid_woq_optimization_pattern():
+    def fn(match):
+        assert all(k in match.kwargs for k in ("x", "weight", "scales"))
+        if not all(
+            hasattr(match.kwargs[key], "meta") for key in ["x", "weight", "scales"]
+        ):
+            return False
+        x = match.kwargs["x"].meta["val"]
+        weight = match.kwargs["weight"].meta["val"]
+        scales = match.kwargs["scales"].meta["val"]
+        return (
+            # For now, we only support woq mm kernels
+            # with x.type=bfloat16 and w.type=int8
+            x.dtype == torch.bfloat16
+            and weight.dtype == torch.int8
+            and scales.dtype == torch.bfloat16
+            and x.device.type in ("cpu", "cuda")
+            and x.device == weight.device
+            and x.device == scales.device
+        )
+
+    return fn
+
+
+def _register_concat_linear_int8_woq_lowering(
+    pattern, computation_woq, computation_reshape
+):
+    @register_freezing_graph_pattern(
+        pattern,
+        extra_check=_is_valid_concat_linear_int8_woq_optimization_pattern(),
+        pass_number=4,
+    )
+    def woq_int8(match: Match, *args, **kwargs):
+        x = kwargs["x"]
+        w1 = kwargs["w1"]
+        w2 = kwargs["w2"]
+        w3 = kwargs["w3"]
+        scales = kwargs["scales"]
+        counters["inductor"]["woq_matcher_count"] += 1
+        counters["inductor"]["woq_matcher_nodes"] += len(match.nodes)
+        out_features = (
+            w1.meta["val"].size()[0]
+            + w2.meta["val"].size()[0]
+            + w3.meta["val"].size()[0]
+        )
+        origin_x_size = tuple(x.meta["val"].size())
+        x_shape = [-1, origin_x_size[-1]]
+        out_shape = list(origin_x_size[:-1] + (out_features,))
+        mm_node_of_x = None
+        for candidate in iter(x.users.keys()):
+            if (
+                candidate.target is aten.mm.default
+                and list(candidate._input_nodes)[1].target is aten.cat.default
+            ):
+                mm_node_of_x = candidate
+                break
+        assert mm_node_of_x is not None, "unable to find mm node"
+        _, cat_wgt_node = mm_node_of_x._input_nodes
+        scaling_node = next(iter(mm_node_of_x.users.keys()))
+        user_of_scaling_node = next(iter(scaling_node.users.keys()))
+        # Some other pass is making some changes that entails
+        # adding a node before it's used, but it can only be found when
+        # lint is run. stable_topological_sort() is being run before lint,
+        # so that error was not being being discovered.
+        # We call stable_topological_sort here as a workaround.
+        stable_topological_sort(match.graph)
+        with match.graph.inserting_before(user_of_scaling_node):
+            new_cat_node = match.graph.call_function(
+                aten.cat.default,
+                args=([w1, w2, w3], 0),
+            )
+            x_reshape_node = match.graph.call_function(
+                computation_reshape, args=(x, x_shape)
+            )
+            new_woq_node = match.graph.call_function(
+                computation_woq,
+                args=(x_reshape_node, new_cat_node, scales),
+            )
+            new_woq_node.meta = copy.copy(x.meta)
+            output_reshape_node = match.graph.call_function(
+                computation_reshape, args=(new_woq_node, out_shape)
+            )
+            scaling_node.replace_all_uses_with(output_reshape_node)
+            match.graph.erase_node(scaling_node)
+            match.graph.erase_node(mm_node_of_x)
+            match.graph.erase_node(cat_wgt_node)
+            match.graph.lint()
+
+    return woq_int8
+
+
+def _register_woq_lowering(pattern, computation_woq, computation_reshape):
+    @register_lowering_pattern(
+        pattern,
+        extra_check=_is_valid_woq_optimization_pattern(),
+    )
+    def woq_int8(match: Match, *args, **kwargs):
+        x = kwargs["x"]
+        weight = kwargs["weight"]
+        scales = kwargs["scales"]
+        counters["inductor"]["woq_matcher_count"] += 1
+        counters["inductor"]["woq_matcher_nodes"] += len(match.nodes)
+        out_features = weight.get_size()[0]
+        origin_x_size = x.get_size()
+        x_shape = [-1, origin_x_size[-1]]
+        out_shape = origin_x_size[:-1] + [
+            out_features,
+        ]
+        func1 = L[computation_reshape](x, x_shape)
+        func2 = L[computation_woq](func1, weight, scales)
+        return L[computation_reshape](func2, out_shape)
+
+    return woq_int8
+
+
+def _register_woq_mm_int8_pattern1():
+    # F.linear(x, weight.to(dtype=x.dtype)) * scales
+    # case of dispatching to mm, with x reshape
+    _woq_pattern = CallFunction(
+        aten.mul.Tensor,
+        CallFunction(
+            aten.reshape.default,
+            CallFunction(
+                aten.mm.default,
+                CallFunction(aten.reshape.default, KeywordArg("x"), Arg()),
+                CallFunction(
+                    aten.permute.default,
+                    CallFunction(
+                        prims.convert_element_type.default, KeywordArg("weight"), Arg()
+                    ),
+                    Arg(),
+                ),
+            ),
+            Arg(),
+        ),
+        KeywordArg("scales"),
+    )
+    _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape)
+
+
+def _register_woq_mm_int8_pattern2():
+    # F.linear(x, weight.to(dtype=x.dtype)) * scales
+    # case of dispatching to mm, w/o x reshape
+    _woq_pattern = CallFunction(
+        aten.mul.Tensor,
+        CallFunction(
+            aten.reshape.default,
+            CallFunction(
+                aten.mm.default,
+                KeywordArg("x"),
+                CallFunction(
+                    aten.permute.default,
+                    CallFunction(
+                        prims.convert_element_type.default, KeywordArg("weight"), Arg()
+                    ),
+                    Arg(),
+                ),
+            ),
+            Arg(),
+        ),
+        KeywordArg("scales"),
+    )
+    _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape)
+
+
+def _register_woq_mm_int8_pattern3():
+    # F.linear(x, weight.to(dtype=x.dtype)) * scales
+    # case of dispatching to bmm
+    _woq_pattern = CallFunction(
+        aten.mul.Tensor,
+        CallFunction(
+            aten.bmm.default,
+            CallFunction(aten.expand.default, KeywordArg("x"), Arg()),
+            CallFunction(
+                aten.expand.default,
+                CallFunction(
+                    aten.permute.default,
+                    CallFunction(
+                        prims.convert_element_type.default, KeywordArg("weight"), Arg()
+                    ),
+                    Arg(),
+                ),
+                Arg(),
+            ),
+        ),
+        KeywordArg("scales"),
+    )
+    _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape)
+
+
+def _register_woq_mm_int8_pattern4():
+    _woq_pattern = CallFunction(
+        aten.mul.Tensor,
+        CallFunction(
+            aten.mm.default,
+            KeywordArg("x"),
+            CallFunction(
+                prims.convert_element_type.default,
+                CallFunction(
+                    aten.permute.default,
+                    KeywordArg("weight"),
+                    Arg(),
+                ),
+                Arg(),
+            ),
+        ),
+        KeywordArg("scales"),
+    )
+    _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape)
+
+
+def _register_int8_woq_concat_linear_pattern():
+    def _create_wgt_node(wgt_node_name: str):
+        return CallFunction(
+            prims.convert_element_type.default,
+            CallFunction(
+                aten.permute.default,
+                KeywordArg(wgt_node_name),
+                Arg(),
+            ),
+            Arg(),
+        )
+
+    cat_wgt = CallFunction(
+        aten.cat.default, [_create_wgt_node(wgt) for wgt in ["w1", "w2", "w3"]], 1
+    )
+
+    _woq_pattern = CallFunction(
+        aten.mul.Tensor,
+        CallFunction(aten.mm.default, KeywordArg("x"), cat_wgt),
+        KeywordArg("scales"),
+    )
+    _register_concat_linear_int8_woq_lowering(
+        _woq_pattern, aten._weight_int8pack_mm.default, aten.reshape
+    )
+
+
+def _register_quantization_lowerings():
+    _register_quantization_unary_lowering()
+    _register_quantization_binary_lowering()
+    _register_quantization_maxpool2d()
+    _register_quantization_cat()
+    _register_quantization_reshape()
+
+
+def _register_woq_lowerings():
+    _register_woq_mm_int8_pattern1()
+    _register_woq_mm_int8_pattern2()
+    _register_woq_mm_int8_pattern3()
+    _register_woq_mm_int8_pattern4()
+
+
+def _is_valid_dequant_promotion_pattern(dtype=torch.float32):
+    def _inner(match):
+        assert dtype in [torch.float32, torch.bfloat16]
+        dequant_pattern_end_node = match.output_node()
+        if dequant_pattern_end_node.target not in [
+            quantized_decomposed.dequantize_per_tensor.default,
+            quantized_decomposed.dequantize_per_tensor.tensor,
+            prims.convert_element_type.default,
+            aten.reshape.default,
+        ]:
+            return False
+
+        if dequant_pattern_end_node.target is aten.reshape.default:
+            dequant_node = (
+                dequant_pattern_end_node.args[
+                    0
+                ]  # pattern: linear <- reshape <- dequant
+                if dtype == torch.float32
+                else dequant_pattern_end_node.args[0].args[
+                    0
+                ]  # pattern: linear <- reshape <- to_bf16 <- dequant
+            )
+        else:
+            dequant_node = (
+                dequant_pattern_end_node  # pattern: linear <- dequant
+                if dtype == torch.float32
+                else dequant_pattern_end_node.args[
+                    0
+                ]  # pattern: linear <- to_bf16 <- dequant
+            )
+
+        if (
+            dequant_node.target
+            in [
+                quantized_decomposed.dequantize_per_tensor.default,
+                quantized_decomposed.dequantize_per_tensor.tensor,
+            ]
+            and len(list(dequant_pattern_end_node.users)) > 1
+        ):
+            # If dequant pattern has more than 1 users, then do dequant promoted
+            return True
+        return False
+
+    return _inner
+
+
+def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32):
+    @register_freezing_graph_pattern(
+        pattern,
+        extra_check=_is_valid_dequant_promotion_pattern(dtype),
+        pass_number=pass_number,
+    )
+    def dequant_promotion(match: Match, *args, **kwargs):
+        # Dequant_promotion will transform
+        # graph 1:
+        #            quant
+        #      + - - - | - - - +
+        #      |    dequant    |
+        #      |    /     \    |
+        #      |  node1  node2 |
+        #      + - | - - - | - +
+        #        quant   quant
+        # into:
+        # graph 2:
+        #            quant
+        #      + - - / - \ - - +
+        #      |dequant dequant|
+        #      |    |      |   |
+        #      | node1 node2   |
+        #      + - | - - - | - +
+        #        quant   quant
+        # In graph 1, the dequant node is shared by node1 and node2,
+        # as a result, neither node1 nor node2 could form an int8
+        # fusion pattern.
+        # After this transformation, the graph 2 could hit the int8
+        # fusion pattern: dequant-node-quant, respectively for
+        # node1 and node2.
+        assert dtype in [torch.float32, torch.bfloat16]
+
+        def clone_to_new_node(graph, source_node, user_node):
+            # Clone the source_node to a new node
+            # Replace user_node's input from source_node to new_node
+            assert source_node.op == "call_function", (
+                "clone_to_new_node only support node.op call_function"
+            )
+            with graph.inserting_before(user_node):
+                new_node = graph.call_function(
+                    source_node.target,
+                    args=source_node.args,
+                    kwargs=source_node.kwargs,
+                )
+                new_node.meta = copy.copy(source_node.meta)
+                user_node.replace_input_with(source_node, new_node)
+            return new_node
+
+        # Find the start node and end node of a dequant pattern
+        # * End node should be the match.output_node()
+        # * Start node should be the node of dequantize_per_tensor
+        dequant_pattern_end_node = match.output_node()
+        assert dequant_pattern_end_node.target in [
+            quantized_decomposed.dequantize_per_tensor.default,
+            quantized_decomposed.dequantize_per_tensor.tensor,
+            prims.convert_element_type.default,
+            aten.reshape.default,
+        ]
+
+        # For a dequant pattern, we should expect see the node list as:
+        # * OPT(aten.reshape.default)
+        # * OPT(prims.convert_element_type.default) (to_bf16)
+        # * dequantize_per_tensor
+        def _find_first_node_in_dequant_pattern(_node):
+            if _node.target in [
+                quantized_decomposed.dequantize_per_tensor.default,
+                quantized_decomposed.dequantize_per_tensor.tensor,
+            ]:
+                # For a dequant pattern, we expect the start node is a dequantize_per_tensor node
+                return _node
+            else:
+                assert len(_node.args) >= 1, (
+                    "In in dequant pattern, each node should have more than 1 arg."
+                )
+                return _find_first_node_in_dequant_pattern(_node.args[0])
+
+        dequant_pattern_start_node = _find_first_node_in_dequant_pattern(
+            dequant_pattern_end_node
+        )
+
+        assert dequant_pattern_start_node.target in [
+            quantized_decomposed.dequantize_per_tensor.default,
+            quantized_decomposed.dequantize_per_tensor.tensor,
+        ]
+
+        # Clone the dequant pattern for each user node
+        graph = match.graph
+        user_node_list = list(dequant_pattern_end_node.users)
+        for user_node in user_node_list[1:]:
+            _source_node = dequant_pattern_end_node
+            _user_node = user_node
+            while _source_node != dequant_pattern_start_node.args[0]:
+                _user_node = clone_to_new_node(graph, _source_node, _user_node)
+                _source_node = _source_node.args[0]  # type: ignore[assignment]
+
+        counters["inductor"]["dequant_promotion_matcher_count"] += 1
+        counters["inductor"]["dequant_promotion_matcher_nodes"] += len(match.nodes)
+
+
+def _is_valid_dequant_conv_pattern(dtype, with_dtype_convert):
+    def _inner(match):
+        # Here we do some further check to ensure:
+        # 1. It's a conv2d node with dim of 4, since we only support lowering of conv2d now.
+        # 2. The dequant pattern has only 1 user of conv2d node.
+        # If these conditions don't meet, we will not
+        # insert weight prepack node into the matched pattern.
+        conv_node = match.output_node()
+        assert conv_node.target is aten.convolution.default
+        input_meta_value = conv_node.args[0].meta.get("val")
+        weight_meta_value = conv_node.args[1].meta.get("val")
+        for meta_value in [input_meta_value, weight_meta_value]:
+            if (
+                meta_value is None
+                or (meta_value.device.type != "cpu" and meta_value.device.type != "xpu")
+                or meta_value.dim() not in [3, 4]
+            ):
+                # Only support conv1d/2d now
+                return False
+
+        assert dtype in [torch.float32, torch.bfloat16]
+
+        if not with_dtype_convert:
+            dequant_node = conv_node.args[0]
+        else:
+            convert_to_bf16 = conv_node.args[0]
+            dequant_node = convert_to_bf16.args[0]
+
+        if len(list(dequant_node.users)) != 1:
+            # Ensure the dequant pattern only has 1 user
+            # since we will delete the dequant pattern here
+            return False
+        return True
+
+    return _inner
+
+
+def _register_qconv_weight_prepack_pass(
+    pattern, pass_number, dtype=torch.float32, with_dtype_convert=False
+):
+    @register_freezing_graph_pattern(
+        pattern,
+        extra_check=_is_valid_dequant_conv_pattern(dtype, with_dtype_convert),
+        pass_number=pass_number,
+    )
+    def qconv_weight_prepack(match: Match, *args, **kwargs):
+        """
+        Match the pattern:
+        int8 activation
+          |
+        dequant_per_tensor
+          |
+        Conv2d <- optional(aten.clone.default) <- dequant_per_channel <- int8_weight
+
+        Insert weight prepack node and change the pattern to:
+        int8 activation
+          |
+        onednn.qconv_pointwise <- onednn.qconv_prepack <- int8_weight
+        """
+        assert dtype in [torch.float32, torch.bfloat16]
+        conv_node = match.output_node()
+        assert conv_node.target is aten.convolution.default
+        if not with_dtype_convert:
+            dequant_node = conv_node.args[0]
+        else:
+            convert_to_bf16 = conv_node.args[0]
+            dequant_node = convert_to_bf16.args[0]  # type: ignore[union-attr]
+        has_clone_to_channel_last_node_in_pattern = (
+            conv_node.args[1].target is aten.clone.default  # type: ignore[union-attr]
+        )
+        clone_node = (
+            conv_node.args[1] if has_clone_to_channel_last_node_in_pattern else None
+        )
+
+        if dtype == torch.float32:
+            dequant_per_channel = (
+                clone_node.args[0]  # type: ignore[union-attr]
+                if has_clone_to_channel_last_node_in_pattern
+                else conv_node.args[1]
+            )
+        else:
+            weight_to_bf16_node = (
+                clone_node.args[0]  # type: ignore[union-attr]
+                if has_clone_to_channel_last_node_in_pattern
+                else conv_node.args[1]
+            )
+            dequant_per_channel = weight_to_bf16_node.args[0]  # type: ignore[union-attr]
+
+        assert (
+            dequant_per_channel.target  # type: ignore[union-attr]
+            is quantized_decomposed.dequantize_per_channel.default
+        )
+
+        # Activation QParams
+        qx, x_zp, x_scale = (
+            kwargs["x"],
+            kwargs["x_zp"],
+            kwargs["x_scale"],
+        )
+
+        # Weight QParams
+        qw, w_scale, w_zp = (
+            kwargs["q_weight"],
+            kwargs["w_scale"],
+            kwargs["w_zp"],
+        )
+
+        # Conv Params
+        bias, stride, padding, dilation, groups = (
+            kwargs["b"],
+            kwargs["stride"],
+            kwargs["padding"],
+            kwargs["dilation"],
+            kwargs["groups"],
+        )
+
+        x_shape = qx.meta.get("tensor_meta").shape
+        if has_free_symbols(x_shape):
+            # For dynamic shape case, we can't get activation shape ahead of runtime.
+            x_shape = None
+        graph = match.graph
+        with graph.inserting_before(conv_node):
+            # Insert weight prepack node and the QConv node
+            packed_weight_inputs = (
+                qw,
+                w_scale,
+                x_scale,
+                x_zp,
+                stride,
+                padding,
+                dilation,
+                groups,
+                x_shape,
+            )
+            packed_weight_op = torch.ops.onednn.qconv_prepack
+            prepack_weight_node = graph.call_function(
+                packed_weight_op, args=packed_weight_inputs
+            )
+
+            new_args: tuple[Any, ...] = (
+                qx,
+                x_scale,
+                x_zp,
+                prepack_weight_node,
+                w_scale,
+                w_zp,
+                bias,
+                stride,
+                padding,
+                dilation,
+                groups,
+                1.0,  # output_scale
+                0,  # output_zero_point
+                dtype,  # output_dtype
+                "none",  # attr
+                [],  # scalars
+                "",  # algorithm
+            )
+            new_conv_node = graph.call_function(
+                torch.ops.onednn.qconv_pointwise.default, args=new_args
+            )
+            conv_node.replace_all_uses_with(new_conv_node)
+            new_conv_node.meta.update(conv_node.meta)
+
+            # Erase the original conv node
+            graph.erase_node(conv_node)
+            # Erase the dequant pattern
+            if with_dtype_convert:
+                graph.erase_node(convert_to_bf16)  # type: ignore[possibly-undefined, arg-type]
+            graph.erase_node(dequant_node)  # type: ignore[arg-type]
+            # Erase the dequant per channel pattern
+            if clone_node is not None:
+                graph.erase_node(clone_node)  # type: ignore[arg-type]
+            if dtype == torch.bfloat16:
+                graph.erase_node(weight_to_bf16_node)  # type: ignore[possibly-undefined, arg-type]
+            graph.erase_node(dequant_per_channel)  # type: ignore[arg-type]
+            counters["inductor"]["qconv_weight_prepack_matcher_count"] += 1
+            counters["inductor"]["qconv_weight_prepack_matcher_nodes"] += len(
+                match.nodes
+            )
+
+
+def _generate_dequant_convolution_node_pattern(
+    _dequant_per_channel_pattern, dtype=torch.float32, with_dtype_convert=False
+):
+    assert dtype in [torch.float32, torch.bfloat16]
+    dequant_convolution_node_pattern = CallFunction(
+        aten.convolution.default,
+        _may_generate_pattern_with_dtype_convert(
+            get_dequantize_per_tensor_activation_pattern(),
+            KeywordArg("autocast_act_dtype"),
+            with_dtype_convert,
+        ),
+        _dequant_per_channel_pattern,
+        KeywordArg("b"),
+        KeywordArg("stride"),
+        KeywordArg("padding"),
+        KeywordArg("dilation"),
+        KeywordArg("is_transposed"),
+        KeywordArg("out_padding"),
+        KeywordArg("groups"),
+    )
+    return dequant_convolution_node_pattern
+
+
+def _generate_qconv_weight_prepack_patterns(
+    dtype=torch.float32, with_dtype_convert=False
+):
+    assert dtype in [torch.float32, torch.bfloat16]
+    return (
+        _generate_dequant_convolution_node_pattern(
+            dequantize_per_channel_weight_pattern
+            if dtype == torch.float32
+            else dequantize_per_channel_to_bf16_weight_pattern,
+            dtype,
+            with_dtype_convert,
+        ),
+        # There is another pattern due to the pass of convert_conv_weights_to_channels_last
+        # https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/_inductor/freezing.py#L338-L362.
+        # Depend on some heuristics, it may or may not insert to(channel_last) node
+        # between convolution and dequant_per_channel node
+        _generate_dequant_convolution_node_pattern(
+            dequantize_per_channel_clone_weight_pattern
+            if dtype == torch.float32
+            else dequantize_per_channel_to_bf16_clone_weight_pattern,
+            dtype,
+            with_dtype_convert,
+        ),
+    )
+
+
+def _get_linear_node(match, input_dim_exceeds_two, input_contiguous):
+    output_reshape_node = None
+    if input_dim_exceeds_two:
+        if input_contiguous:
+            output_reshape_node = match.output_node()
+            assert output_reshape_node.target is aten.reshape.default
+            linear_node = output_reshape_node.args[0]
+        else:
+            linear_nodes = filter_nodes(match.nodes, aten.bmm.default)
+            assert len(linear_nodes) == 1
+            linear_node = linear_nodes[0]
+    else:
+        linear_node = match.output_node()
+
+    assert linear_node.target in (
+        aten.addmm.default,
+        aten.mm.default,
+        aten.bmm.default,
+    )
+    return linear_node, output_reshape_node
+
+
+def _get_linear_dq_node(
+    linear_node,
+    input_index,
+    input_dim_exceeds_two,
+    input_contiguous,
+    with_dtype_convert,
+):
+    act_reshape_node = None
+    activation_to_bf16_node = None
+    act_expand_node = None
+    if input_dim_exceeds_two:
+        if input_contiguous:
+            act_reshape_node = linear_node.args[input_index]
+            assert act_reshape_node.target is aten.reshape.default
+            if not with_dtype_convert:
+                # pattern: linear -> reshape -> dequant
+                dequant_node = act_reshape_node.args[0]
+            else:
+                # pattern: linear -> reshape -> to_bf16 -> dequant
+                activation_to_bf16_node = act_reshape_node.args[0]
+                dequant_node = activation_to_bf16_node.args[0]
+        else:
+            # bmm pattern decomposed from linear when input dim exceeds 2 and not contiguous
+            act_expand_node = linear_node.args[input_index]
+            assert act_expand_node.target is aten.expand.default
+            if not with_dtype_convert:
+                dequant_node = act_expand_node.args[0]
+            else:
+                activation_to_bf16_node = act_expand_node.args[0]
+                dequant_node = activation_to_bf16_node.args[0]
+    else:
+        if not with_dtype_convert:
+            # pattern: linear -> dequant
+            dequant_node = linear_node.args[input_index]
+        else:
+            # pattern: linear -> to_bf16 -> dequant
+            activation_to_bf16_node = linear_node.args[input_index]
+            dequant_node = activation_to_bf16_node.args[0]
+    return dequant_node, act_reshape_node, activation_to_bf16_node, act_expand_node
+
+
+def _is_valid_dequant_linear_pattern(
+    dtype, input_dim_exceeds_two, input_contiguous, with_dtype_convert
+):
+    def _inner(match):
+        # Check dequant pattern has only 1 user.
+        (
+            linear_node,
+            _,
+        ) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous)
+
+        input_index = 1 if linear_node.target is aten.addmm.default else 0
+        assert dtype in [torch.float32, torch.bfloat16]
+        (
+            dequant_node,
+            _,
+            _,
+            _,
+        ) = _get_linear_dq_node(
+            linear_node,
+            input_index,
+            input_dim_exceeds_two,
+            input_contiguous,
+            with_dtype_convert,
+        )
+
+        assert dequant_node.target in [
+            quantized_decomposed.dequantize_per_tensor.default,
+            quantized_decomposed.dequantize_per_tensor.tensor,
+        ]
+
+        if len(list(dequant_node.users)) != 1:
+            # Ensure the dequant pattern only has 1 user
+            # since we will delete the dequant pattern here
+            return False
+
+        # Extra check for bmm pattern
+        if input_dim_exceeds_two and not input_contiguous:
+            # Check for act
+            # Act expand size should be exactly same as act size
+            act_expand_size = match.kwargs["act_expand_size"]
+            act_node = match.kwargs["x"]
+            if not (
+                hasattr(act_node, "meta")
+                and isinstance(act_node.meta.get("val", None), torch.Tensor)
+                and (act_node.meta["val"].size() == torch.Size(act_expand_size))
+            ):
+                return False
+
+            # Check for wgt
+            # wgt permute dims should be [1, 0]
+            wgt_permute_dims = match.kwargs["permute_axes"]
+            if wgt_permute_dims != [1, 0]:
+                return False
+
+            # Check below wgt size items:
+            # wgt before expand should with dim 2
+            # Expand size should with dim 3
+            # Expand size[0] should same as act size[0]
+            # Expand size[1] should same as wgt size[1]
+            # Expand size[2] should same as wgt size[0]
+            qweight_node = match.kwargs["q_weight"]
+            wgt_expand_size = match.kwargs["wgt_expand_size"]
+            if not (
+                hasattr(qweight_node, "meta")
+                and isinstance(qweight_node.meta.get("val", None), torch.Tensor)
+                and len(qweight_node.meta["val"].size()) == 2
+                and len(wgt_expand_size) == 3
+                and wgt_expand_size[0] == act_node.meta["val"].size()[0]
+                and wgt_expand_size[1] == qweight_node.meta["val"].size()[1]
+                and wgt_expand_size[2] == qweight_node.meta["val"].size()[0]
+            ):
+                return False
+
+        return True
+
+    return _inner
+
+
+def _register_qlinear_weight_prepack_pass(
+    pattern,
+    pass_number,
+    dtype=torch.float32,
+    input_dim_exceeds_two=False,
+    input_contiguous=True,
+    with_dtype_convert=False,
+):
+    @register_freezing_graph_pattern(
+        pattern,
+        extra_check=_is_valid_dequant_linear_pattern(
+            dtype, input_dim_exceeds_two, input_contiguous, with_dtype_convert
+        ),
+        pass_number=pass_number,
+    )
+    def qlinear_weight_prepack(match: Match, *args, **kwargs):
+        """
+        Match the pattern:
+        int8 activation
+          |
+        dequant_per_tensor
+          |
+        mm/addmm <- t <- dequant_per_channel <- int8_weight
+
+        Insert weight prepack node and change the pattern to:
+        int8 activation
+          |
+        onednn.qlinear_pointwise <- onednn.qlinear_prepack <- int8_weight
+        """
+        assert dtype in [torch.float32, torch.bfloat16]
+        (
+            linear_node,
+            output_reshape_node,
+        ) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous)
+        input_index = 1 if linear_node.target is aten.addmm.default else 0
+        weight_index = input_index + 1
+
+        (
+            dequant_node,
+            act_reshape_node,
+            activation_to_bf16_node,
+            act_expand_node,
+        ) = _get_linear_dq_node(
+            linear_node,
+            input_index,
+            input_dim_exceeds_two,
+            input_contiguous,
+            with_dtype_convert,
+        )
+
+        if input_dim_exceeds_two and not input_contiguous:
+            wgt_expand_node = linear_node.args[weight_index]
+            assert wgt_expand_node.target is aten.expand.default
+            t_node = wgt_expand_node.args[0]
+        else:
+            t_node = linear_node.args[weight_index]
+
+        if dtype == torch.float32:
+            dequant_per_channel = t_node.args[0]
+        else:
+            weight_to_bf16_node = t_node.args[0]
+            dequant_per_channel = weight_to_bf16_node.args[0]
+        assert (
+            dequant_per_channel.target
+            is quantized_decomposed.dequantize_per_channel.default
+        )
+
+        # Activation QParams
+        qx, x_zp, x_scale = (
+            kwargs["x"],
+            kwargs["x_zp"],
+            kwargs["x_scale"],
+        )
+
+        # Weight QParams
+        qw, w_scale, w_zp = (
+            kwargs["q_weight"],
+            kwargs["w_scale"],
+            kwargs["w_zp"],
+        )
+
+        # Params
+        bias = kwargs.get("b")
+
+        x_shape = qx.meta.get("tensor_meta").shape
+        if has_free_symbols(x_shape):
+            # For dynamic shape case, we can't get activation shape ahead of runtime.
+            x_shape = None
+        graph = match.graph
+        with graph.inserting_before(linear_node):
+            # Insert weight prepack node and the qlinear node
+            packed_weight_inputs = (
+                qw,
+                x_shape,
+            )
+            packed_weight_op = torch.ops.onednn.qlinear_prepack
+            prepack_weight_node = graph.call_function(
+                packed_weight_op, args=packed_weight_inputs
+            )
+
+            new_args: tuple[Any, ...] = (
+                qx,
+                x_scale,
+                x_zp,
+                prepack_weight_node,
+                w_scale,
+                w_zp,
+                bias,
+                1.0,  # output_scale
+                0,  # output_zero_point
+                dtype,  # output_dtype
+                "none",  # post op name
+                [],  # post op args
+                "",  # post op algorithm
+            )
+            Node = torch.fx.node.Node
+            if isinstance(x_scale, Node) and isinstance(x_zp, Node):
+                new_linear_node = graph.call_function(
+                    torch.ops.onednn.qlinear_pointwise.tensor, args=new_args
+                )
+            else:
+                new_linear_node = graph.call_function(
+                    torch.ops.onednn.qlinear_pointwise.default, args=new_args
+                )
+            if input_dim_exceeds_two:
+                if input_contiguous:
+                    output_reshape_node.replace_all_uses_with(new_linear_node)
+                    new_linear_node.meta.update(output_reshape_node.meta)
+                else:
+                    if bias:
+                        output_add_node_for_bias = match.output_node()
+                        assert output_add_node_for_bias.target is aten.add.Tensor
+                        output_add_node_for_bias.replace_all_uses_with(new_linear_node)
+                        new_linear_node.meta.update(output_add_node_for_bias.meta)
+                    else:
+                        linear_node.replace_all_uses_with(new_linear_node)
+                        new_linear_node.meta.update(linear_node.meta)
+            else:
+                linear_node.replace_all_uses_with(new_linear_node)
+                new_linear_node.meta.update(linear_node.meta)
+
+            # Erase the original linear node
+            if input_dim_exceeds_two:
+                if input_contiguous:
+                    graph.erase_node(output_reshape_node)
+                elif not input_contiguous and bias:
+                    graph.erase_node(output_add_node_for_bias)  # type: ignore[possibly-undefined]
+            graph.erase_node(linear_node)
+            if input_dim_exceeds_two:
+                if input_contiguous:
+                    graph.erase_node(act_reshape_node)
+                else:
+                    graph.erase_node(act_expand_node)
+                    graph.erase_node(wgt_expand_node)  # type: ignore[possibly-undefined]
+            if with_dtype_convert:
+                graph.erase_node(activation_to_bf16_node)
+            # Erase the dequant pattern
+            graph.erase_node(dequant_node)
+            # Erase the dequant per channel pattern
+            graph.erase_node(t_node)
+            if dtype == torch.bfloat16:
+                graph.erase_node(weight_to_bf16_node)  # type: ignore[possibly-undefined]
+            graph.erase_node(dequant_per_channel)
+
+            counters["inductor"]["qlinear_weight_prepack_matcher_count"] += 1
+            counters["inductor"]["qlinear_weight_prepack_matcher_nodes"] += len(
+                match.nodes
+            )
+
+
+def _generate_dequant_linear_node_pattern(
+    _dequant_per_channel_pattern,
+    dtype=torch.float32,
+    input_dim_exceeds_two=False,
+    is_tensor_overload=False,
+    with_dtype_convert=False,
+):
+    assert dtype in [torch.float32, torch.bfloat16]
+    t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype)
+    dequant_linear_bias_pattern = _may_generate_pattern_with_reshape(
+        CallFunction(
+            aten.addmm.default,
+            KeywordArg("b"),
+            _may_generate_pattern_with_reshape(
+                _may_generate_pattern_with_dtype_convert(
+                    get_dequantize_per_tensor_activation_pattern(is_tensor_overload),
+                    KeywordArg("autocast_act_dtype"),
+                    with_dtype_convert,
+                ),
+                KeywordArg("act_reshape_size"),
+                input_dim_exceeds_two,
+            ),
+            t_pattern,
+        ),
+        KeywordArg("output_reshape_size"),
+        input_dim_exceeds_two,
+    )
+    dequant_linear_no_bias_pattern = _may_generate_pattern_with_reshape(
+        CallFunction(
+            aten.mm.default,
+            _may_generate_pattern_with_reshape(
+                _may_generate_pattern_with_dtype_convert(
+                    get_dequantize_per_tensor_activation_pattern(is_tensor_overload),
+                    KeywordArg("autocast_act_dtype"),
+                    with_dtype_convert,
+                ),
+                KeywordArg("act_reshape_size"),
+                input_dim_exceeds_two,
+            ),
+            t_pattern,
+        ),
+        KeywordArg("output_reshape_size"),
+        input_dim_exceeds_two,
+    )
+    return dequant_linear_bias_pattern, dequant_linear_no_bias_pattern
+
+
+def _generate_dequant_bmm_node_pattern(
+    _dequant_per_channel_pattern,
+    dtype=torch.float32,
+    with_bias=False,
+    is_tensor_overload=False,
+    with_dtype_convert=False,
+):
+    # When activation of linear dim exceed 2 and not contiguous
+    t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype)
+
+    assert dtype in [torch.float32, torch.bfloat16]
+    dequant_bmm_pattern = CallFunction(
+        aten.bmm.default,
+        CallFunction(
+            aten.expand.default,
+            _may_generate_pattern_with_dtype_convert(
+                get_dequantize_per_tensor_activation_pattern(is_tensor_overload),
+                KeywordArg("autocast_act_dtype"),
+                with_dtype_convert,
+            ),
+            KeywordArg("act_expand_size"),
+        ),
+        CallFunction(
+            aten.expand.default,
+            t_pattern,
+            KeywordArg("wgt_expand_size"),
+        ),
+    )
+
+    def _generate_pattern_with_output_add(_dequant_bmm_pattern, _with_bias):
+        if _with_bias:
+            return CallFunction(
+                aten.add.Tensor,
+                _dequant_bmm_pattern,
+                KeywordArg("b"),
+            )
+        else:
+            return _dequant_bmm_pattern
+
+    return _generate_pattern_with_output_add(dequant_bmm_pattern, with_bias)
+
+
+def _generate_qlinear_weight_prepack_patterns(
+    dtype=torch.float32,
+    input_dim_exceeds_two=False,
+    input_contiguous=True,
+    with_bias=False,
+    is_tensor_overload=False,
+    with_dtype_convert=False,
+):
+    if input_dim_exceeds_two and not input_contiguous:
+        return _generate_dequant_bmm_node_pattern(
+            dequantize_per_channel_weight_pattern,
+            dtype,
+            with_bias,
+            is_tensor_overload,
+            with_dtype_convert,
+        )
+    else:
+        return _generate_dequant_linear_node_pattern(
+            dequantize_per_channel_weight_pattern,
+            dtype,
+            input_dim_exceeds_two,
+            is_tensor_overload,
+            with_dtype_convert,
+        )
+
+
+def _generate_linear_dynamic_fp16_pattern(
+    _dequant_weight_pattern,
+    input_dim_exceeds_two=False,
+    input_contiguous=True,
+    relu_fused=False,
+):
+    dtype = torch.float32
+    t_pattern = _generate_linear_t_pattern(_dequant_weight_pattern, dtype)
+
+    if input_dim_exceeds_two and not input_contiguous:
+        # pattern is
+        #                   x -> expand -> bmm (-> add) (-> relu)
+        # w -> dequant -> permute -> expand /
+        pattern_no_bias = CallFunction(
+            aten.bmm.default,
+            CallFunction(
+                aten.expand.default,
+                KeywordArg("x"),
+                KeywordArg("act_expand_size"),
+            ),
+            CallFunction(
+                aten.expand.default,
+                t_pattern,
+                KeywordArg("wgt_expand_size"),
+            ),
+        )
+        pattern_with_bias = CallFunction(
+            aten.add.Tensor,
+            pattern_no_bias,
+            KeywordArg("b"),
+        )
+        if relu_fused:
+            pattern_with_bias = CallFunction(aten.relu.default, pattern_with_bias)
+            pattern_no_bias = CallFunction(aten.relu.default, pattern_no_bias)
+        return pattern_with_bias, pattern_no_bias
+
+    x_pattern_with_reshape = _may_generate_pattern_with_reshape(
+        KeywordArg("x"),
+        KeywordArg("act_reshape_size"),
+        input_dim_exceeds_two,
+    )
+    dequant_linear_bias_pattern = generate_pattern_with_unary(
+        _may_generate_pattern_with_reshape(
+            CallFunction(
+                aten.addmm.default,
+                KeywordArg("b"),
+                x_pattern_with_reshape,
+                t_pattern,
+            ),
+            KeywordArg("output_reshape_size"),
+            input_dim_exceeds_two,
+        ),
+        aten.relu.default if relu_fused else None,
+    )
+    dequant_linear_no_bias_pattern = generate_pattern_with_unary(
+        _may_generate_pattern_with_reshape(
+            CallFunction(
+                aten.mm.default,
+                x_pattern_with_reshape,
+                t_pattern,
+            ),
+            KeywordArg("output_reshape_size"),
+            input_dim_exceeds_two,
+        ),
+        aten.relu.default if relu_fused else None,
+    )
+    return dequant_linear_bias_pattern, dequant_linear_no_bias_pattern
+
+
+def _register_dequant_promotion():
+    dequant_pattern_cases = itertools.product(
+        [torch.float32, torch.bfloat16], [True, False], [True, False]
+    )
+    for dtype, input_dim_exceeds_two, is_tensor_overload in dequant_pattern_cases:
+        # 4 dequantization patterns will be matched based on the dtype and input dimension size.
+        # Case 1: int8-mixed-fp32, input dim size is 2
+        # Case 2: int8-mixed-fp32, input dim size exceeds 2
+        # Case 3: int8-mixed-bf16, input dim size is 2
+        # Case 4: int8-mixed-bf16, input dim size exceeds 2
+        #           quant
+        #   + - - - - | - - - - +
+        #   |      dequant      |
+        #   |         |         |
+        #   |    OPT(to_bf16)   |
+        #   |         |         |
+        #   |    OPT(reshape)   |
+        #   |      /     \      |
+        #   |    node1  node2   |
+        #   + - - | - - - | - - +
+        #  OPT(reshape) OPT(reshape)
+        #   + - - | - - - | - - +
+        #  OPT(to_fp32) OPT(to_fp32)
+        #   + - - | - - - | - - +
+        #       quant   quant
+        _register_dequant_promotion_pass(
+            _may_generate_pattern_with_reshape(
+                _may_generate_pattern_with_dtype_convert(
+                    get_dequantize_per_tensor_activation_pattern(
+                        is_tensor_overload=is_tensor_overload
+                    ),
+                    KeywordArg("autocast_act_dtype"),
+                    dtype == torch.bfloat16,
+                ),
+                KeywordArg("act_reshape_size"),
+                with_reshape=input_dim_exceeds_two,
+            ),
+            pass_number=0,
+            dtype=dtype,
+        )  # pass_number=0 to run before weight prepack
+
+
+def _register_qconv_weight_prepack():
+    for dtype, with_dtype_convert in itertools.product(
+        [torch.float32, torch.bfloat16], [True, False]
+    ):
+        if dtype == torch.float32 and with_dtype_convert:
+            continue
+        weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(
+            dtype, with_dtype_convert
+        )
+        for weight_prepack_pattern in weight_prepack_patterns:
+            # Register to pass_number 1, so we can do dequant promotion in pass_number 0.
+            _register_qconv_weight_prepack_pass(
+                weight_prepack_pattern,
+                pass_number=1,
+                dtype=dtype,
+                with_dtype_convert=with_dtype_convert,
+            )
+
+
+def _register_qlinear_weight_prepack():
+    # 6 Linear related patterns will be matched based on the dtype, input dimension size and input contiguous.
+    # Then convert the pattern into a QLinear node with int8_fp32/bf16.
+    # Case 1: int8-mixed-fp32, input dim size is 2
+    # Case 2: int8-mixed-fp32, input dim size exceeds 2 and contiguous
+    # Case 3: int8-mixed-bf16, input dim size is 2
+    # Case 4: int8-mixed-bf16, input dim size exceeds 2 and contiguous
+
+    #   + - - - - | - - - - - - | - - - - - +
+    #   |    dq_per_tensor  dq_per_channel  |
+    #   |         |              |          |
+    #   |    OPT(to_bf16)    OPT(to_bf16)   |
+    #   |         |              |          |
+    #   |     OPT(reshape)   permute        |
+    #   |            \        /             |
+    #   |             addmm/mm              |
+    #   |                |                  |
+    #   |           OPT(reshape)            |
+
+    # Case 5: int8-mixed-fp32, input dim size exceeds 2 and not contiguous
+    # Case 6: int8-mixed-bf16, input dim size exceeds 2 and not contiguous
+
+    #   + - - - - | - - - - - - | - - - - - +
+    #   |    dq_per_tensor  dq_per_channel  |
+    #   |         |              |          |
+    #   |    OPT(to_bf16)    OPT(to_bf16)   |
+    #   |         |              |          |
+    #   |       expand       permute        |
+    #   |          \             |          |
+    #   |                    expand         |
+    #   |                    /              |
+    #   |               bmm                 |
+    #   |                |                  |
+    #   |            OPT(add)               |
+
+    linear_weight_prepack_cases = itertools.product(
+        [torch.float32, torch.bfloat16], [True, False], [True, False], [True, False]
+    )
+
+    # Step 1: register patterns from mm and addmm
+    for (
+        dtype,
+        input_dim_exceeds_two,
+        is_tensor_overload,
+        with_dtype_convert,
+    ) in linear_weight_prepack_cases:
+        if dtype == torch.float32 and with_dtype_convert:
+            continue
+        weight_prepack_patterns = _generate_qlinear_weight_prepack_patterns(
+            dtype,
+            input_dim_exceeds_two,
+            is_tensor_overload=is_tensor_overload,
+            with_dtype_convert=with_dtype_convert,
+        )
+        for weight_prepack_pattern in weight_prepack_patterns:
+            # Register to pass_number 1, so we can do dequant promotion in pass_number 0.
+            _register_qlinear_weight_prepack_pass(
+                weight_prepack_pattern,
+                pass_number=1,
+                dtype=dtype,
+                input_dim_exceeds_two=input_dim_exceeds_two,
+                with_dtype_convert=with_dtype_convert,
+            )
+
+    # Step 2: register patterns from bmm
+    # Linear might be decomposed into bmm when input dim exceeds 2 and not contiguous
+    # refer to:
+    # https://github.com/pytorch/pytorch/blob/80c07df659362a95da7cd4f3ec367abfdace38c4/torch/_decomp/decompositions.py#L3965-L3968
+    # in this case, we can convert it back to qlinear
+    for (
+        dtype,
+        with_bias,
+        is_tensor_overload,
+        with_dtype_convert,
+    ) in itertools.product(
+        [torch.float32, torch.bfloat16], [True, False], [True, False], [True, False]
+    ):
+        if dtype == torch.float32 and with_dtype_convert:
+            continue
+        bmm_pattern = _generate_qlinear_weight_prepack_patterns(
+            dtype=dtype,
+            input_dim_exceeds_two=True,
+            input_contiguous=False,
+            with_bias=with_bias,
+            is_tensor_overload=is_tensor_overload,
+            with_dtype_convert=with_dtype_convert,
+        )
+        _register_qlinear_weight_prepack_pass(
+            bmm_pattern,
+            pass_number=1
+            if with_bias
+            else 2,  # if with_bias, there is an output add, so we should try to match it firstly
+            dtype=dtype,
+            input_dim_exceeds_two=True,
+            input_contiguous=False,
+            with_dtype_convert=with_dtype_convert,
+        )
+
+
+def _register_linear_dynamic_fp16_weight_prepack_pass(
+    pattern,
+    pass_number,
+    input_dim_exceeds_two=False,
+    input_contiguous=True,
+    relu_fused=False,
+):
+    def _extra_check_fn(match: Match):
+        return match.kwargs["dtype_fp16"] == torch.float16
+
+    @register_freezing_graph_pattern(
+        pattern,
+        extra_check=_extra_check_fn,
+        pass_number=pass_number,
+    )
+    def linear_dynamic_fp16_weight_prepack(match: Match, *args, **kwargs):
+        """
+        Match the pattern:
+        fp32 activation
+          |
+        mm/addmm <- t <- to_fp32 <- to_fp16 <- weight
+          |
+        (reshape) <- (relu)
+
+        OR
+
+        fp32 activation
+          |
+        expand
+          |
+         bmm <- expand <- t <- to_fp32 <- to_fp16 <- weight
+          |
+        (add) <- (relu)
+
+        Insert weight prepack node and change the pattern to:
+        fp32 activation
+          |
+        onednn.linear_dynamic_fp16 <- onednn.linear_prepack_fp16 <- weight
+        (or onednn.linear_relu_dynamic_fp16)
+        """
+        # find params
+        x = kwargs["x"]
+        w = kwargs["w"]
+        bias = kwargs.get("b")
+
+        # find linear node
+        nodes_to_find = [aten.addmm.default, aten.mm.default, aten.bmm.default]
+        linear_nodes = []
+        for node in nodes_to_find:
+            linear_nodes.extend(filter_nodes(match.nodes, node))
+        assert len(linear_nodes) == 1
+        linear_node = linear_nodes[0]
+        assert isinstance(linear_node, torch.fx.node.Node)
+        input_index = 1 if linear_node.target is aten.addmm.default else 0
+        weight_index = input_index + 1
+
+        # find relu node
+        relu_node = None
+        if relu_fused:
+            relu_node = match.output_node()
+            assert isinstance(relu_node, torch.fx.node.Node)
+
+        # find reshape node, expand node and add node
+        (
+            act_reshape_node,
+            output_reshape_node,
+            expand_x_node,
+            expand_w_node,
+            add_bias_node,
+        ) = (None, None, None, None, None)
+        t_node = None
+        if input_dim_exceeds_two:
+            if input_contiguous:
+                act_reshape_node = linear_node.args[input_index]
+                t_node = linear_node.args[weight_index]
+                output_reshape_node = next(iter(linear_node.users))
+                assert output_reshape_node.target is aten.reshape.default
+            else:
+                expand_x_node = linear_node.args[input_index]
+                expand_w_node = linear_node.args[weight_index]
+                assert isinstance(expand_w_node, torch.fx.node.Node)
+                t_node = expand_w_node.args[0]
+                if bias:
+                    add_bias_node = next(iter(linear_node.users))
+                    assert add_bias_node.target is aten.add.Tensor
+        else:
+            t_node = linear_node.args[weight_index]
+        assert isinstance(t_node, torch.fx.node.Node)
+
+        w_to_fp32_node = t_node.args[0]
+        assert (
+            isinstance(w_to_fp32_node, torch.fx.node.Node)
+            and w_to_fp32_node.target
+            is quantized_decomposed.convert_element_type.no_fuse
+        )
+        w_to_fp16_node = w_to_fp32_node.args[0]
+        assert (
+            isinstance(w_to_fp16_node, torch.fx.node.Node)
+            and w_to_fp16_node.target
+            is quantized_decomposed.convert_element_type.no_fuse
+        )
+
+        x_shape = x.meta.get("tensor_meta").shape
+        if has_free_symbols(x_shape):
+            # For dynamic shape case, we can't get activation shape ahead of runtime.
+            x_shape = None
+        graph = match.graph
+        with graph.inserting_before(linear_node):
+            # Insert weight prepack node and the qlinear node
+            packed_weight_inputs = (
+                w,
+                x_shape,
+            )
+            packed_weight_op = torch.ops.onednn.linear_prepack_fp16
+            prepack_weight_node = graph.call_function(
+                packed_weight_op, args=packed_weight_inputs
+            )
+
+            # create new linear node and insert on graph
+            new_args: tuple[Any, ...] = (
+                x,
+                prepack_weight_node,
+                bias,
+            )
+            linear_op = (
+                torch.ops.onednn.linear_relu_dynamic_fp16.default
+                if relu_fused
+                else torch.ops.onednn.linear_dynamic_fp16.default
+            )
+            new_linear_node = graph.call_function(linear_op, args=new_args)
+            out_node = match.output_node()
+            out_node.replace_all_uses_with(new_linear_node)
+
+            # Erase the original nodes in the reverse order
+            new_linear_node.meta.update(out_node.meta)
+            if relu_node is not None:
+                graph.erase_node(relu_node)
+            if output_reshape_node is not None:
+                graph.erase_node(output_reshape_node)
+            if add_bias_node is not None:
+                graph.erase_node(add_bias_node)
+            graph.erase_node(linear_node)
+            if act_reshape_node is not None:
+                assert isinstance(act_reshape_node, torch.fx.node.Node)
+                graph.erase_node(act_reshape_node)
+            if expand_x_node is not None:
+                assert isinstance(expand_x_node, torch.fx.node.Node)
+                graph.erase_node(expand_x_node)
+            if expand_w_node is not None:
+                assert isinstance(expand_w_node, torch.fx.node.Node)
+                graph.erase_node(expand_w_node)
+            graph.erase_node(t_node)
+            graph.erase_node(w_to_fp32_node)
+            graph.erase_node(w_to_fp16_node)
+
+            counters["inductor"]["qlinear_weight_prepack_matcher_count"] += 1
+            counters["inductor"]["qlinear_weight_prepack_matcher_nodes"] += len(
+                match.nodes
+            )
+
+
+def _register_linear_dynamic_fp16_weight_prepack():
+    to_dtype_op = torch.ops.quantized_decomposed.convert_element_type.no_fuse
+    weight_pattern = CallFunction(
+        to_dtype_op,
+        CallFunction(
+            to_dtype_op,
+            KeywordArg("w"),
+            KeywordArg("dtype_fp16"),
+        ),
+        KeywordArg("dtype_fp32"),
+    )
+    cases = itertools.product(
+        [False, True],  # input_dim_exceeds_two
+        [True, False],  # input_contiguous
+        [False, True],  # relu fused
+    )
+    for input_dim_exceeds_two, input_contiguous, relu_fused in cases:
+        patterns = _generate_linear_dynamic_fp16_pattern(
+            weight_pattern,
+            input_dim_exceeds_two,
+            input_contiguous,
+            relu_fused,
+        )
+        for pattern in patterns:
+            _register_linear_dynamic_fp16_weight_prepack_pass(
+                pattern,
+                pass_number=0 if relu_fused else 1,
+                input_dim_exceeds_two=input_dim_exceeds_two,
+                input_contiguous=input_contiguous,
+                relu_fused=relu_fused,
+            )
+
+
+def _register_smooth_quant_int_mm_pattern():
+    """
+    The pattern is:
+      (no bias) reshape -> _int_mm -> convert_element_type -> (expand ->) mul -> mul -> reshape
+    or
+      (with bias) pattern_no_bias -> add (-> reshape -> reshape)
+    """
+
+    # When torch.compile'ing with dynamic=True, the expand node and the two tailing reshape nodes exist
+    # When torch.compile'ing with dynamic=False, they don't exist
+    def get_pattern_no_bias(expand_a_scale: bool, reshape_a: bool = True):
+        return CallFunction(
+            aten.mul.Tensor,
+            CallFunction(
+                aten.mul.Tensor,
+                CallFunction(
+                    prims.convert_element_type.default,
+                    CallFunction(
+                        aten._int_mm.default,
+                        CallFunction(
+                            aten.reshape.default,
+                            KeywordArg("a"),
+                            KeywordArg("in_shape"),
+                        )
+                        if reshape_a
+                        else KeywordArg("a"),
+                        KeywordArg("b"),
+                    ),
+                    KeywordArg("dtype"),
+                ),
+                (
+                    CallFunction(
+                        aten.expand.default,
+                        KeywordArg("x_scale"),
+                        Arg(),
+                    )
+                    if expand_a_scale
+                    else KeywordArg("x_scale")
+                ),
+            ),
+            KeywordArg("w_scale"),
+        )
+
+    def _with_outer_reshape(pattern):
+        return CallFunction(
+            aten.reshape.default, pattern, KeywordArg("out_shape_no_bias")
+        )
+
+    # for torch.compile(dynamic=False)
+    pattern_no_bias_1 = _with_outer_reshape(get_pattern_no_bias(expand_a_scale=False))
+    pattern_with_bias_1 = CallFunction(
+        aten.add.Tensor,
+        pattern_no_bias_1,
+        KeywordArg("bias"),
+    )
+    # for torch.compile(dynamic=True)
+    pattern_no_bias_2 = _with_outer_reshape(get_pattern_no_bias(expand_a_scale=True))
+    pattern_with_bias_2 = CallFunction(
+        aten.reshape.default,
+        CallFunction(
+            aten.reshape.default,
+            CallFunction(
+                aten.add.Tensor,
+                pattern_no_bias_2,
+                KeywordArg("bias"),
+            ),
+            Arg(),
+        ),
+        KeywordArg("out_shape_with_bias"),
+    )
+
+    # The following patterns are for torchao int8_dynamic_activation_int8_weight linear,
+    # when both activation and weights are symmetrically quantized.
+    # In practice, though, they may also match smooth-quant pattern when a 2D input shape would be used.
+    # Since add is not currently being used as a oneDNN post-op, but is unfused, we don't need these patterns with bias.
+    # Ideally, we should add mul + add post-op support in ATen int8 oneDNN linear op.
+    pattern1_with_no_outer_or_act_reshape = get_pattern_no_bias(
+        expand_a_scale=False, reshape_a=False
+    )
+    pattern2_with_no_outer_or_act_reshape = get_pattern_no_bias(
+        expand_a_scale=True, reshape_a=False
+    )
+
+    def _validate_pattern(match: Match):
+        if len(match.nodes) not in [4, 5, 6, 7, 10]:
+            return False
+        # Make sure weight is a constant
+        aten_int_mm_node = filter_nodes(match.nodes, aten._int_mm.default)[0]
+        if not isinstance(aten_int_mm_node.args[1], torch.fx.node.Node):
+            return False
+        if aten_int_mm_node.args[1].op != "get_attr":
+            return False
+
+        if len(match.nodes) == 10:
+            # Check the two tailing reshape nodes can be fused
+            if match.nodes[9].args[1] != match.nodes[6].args[1]:
+                return False
+        if len(match.nodes) == 10 or (
+            len(match.nodes) == 7 and match.nodes[6].target is aten.add.Tensor
+        ):
+            bias_idx = 7 if len(match.nodes) == 10 else 6
+            # Check bias shape
+            bias_node = match.nodes[bias_idx].args[1]
+            if not isinstance(bias_node, torch.fx.node.Node):
+                return False
+            if len(bias_node.meta.get("tensor_meta").shape) != 1:  # type: ignore[union-attr]
+                return False
+        return True
+
+    pattern_to_pass_number = {
+        pattern_no_bias_2: 0,
+        pattern_with_bias_2: 0,
+        pattern_no_bias_1: 1,
+        pattern_with_bias_1: 1,
+        pattern1_with_no_outer_or_act_reshape: 2,
+        pattern2_with_no_outer_or_act_reshape: 2,
+    }
+    for pattern, pass_number in pattern_to_pass_number.items():
+
+        @register_freezing_graph_pattern(
+            pattern,
+            extra_check=_validate_pattern,
+            pass_number=pass_number,
+        )
+        def _int_mm_weight_prepack(match: Match, *args, **kwargs):
+            bias = kwargs.get("bias")
+            x = kwargs["a"]
+            weight = kwargs["b"]
+            dtype = kwargs["dtype"]
+            x_scale = kwargs["x_scale"]
+            w_scale = kwargs["w_scale"]
+            x_shape = x.meta.get("tensor_meta").shape
+            if has_free_symbols(x_shape):
+                # For dynamic shape case, we can't get activation shape ahead of runtime.
+                x_shape = None
+
+            out_node = match.output_node()
+            with match.graph.inserting_before(out_node):
+                transpose_node = match.graph.call_function(
+                    aten.permute.default, args=(weight, [1, 0])
+                )
+                contig_node = match.graph.call_function(
+                    aten.contiguous.default, args=(transpose_node,)
+                )
+                packed_weight_inputs = (
+                    contig_node,
+                    x_shape,
+                )
+                packed_weight_op = torch.ops.onednn.qlinear_prepack
+                prepack_weight_node = match.graph.call_function(
+                    packed_weight_op, args=packed_weight_inputs
+                )
+
+                dummy_zp = None
+                w_scale = match.graph.call_function(
+                    prims.convert_element_type.default, args=(w_scale, torch.float32)
+                )
+
+                x_scale_shape = x_scale.meta.get("tensor_meta").shape
+                x_scale_is_scalar = False
+                if not has_free_symbols(x_scale_shape):
+                    prod = 1
+                    for d in x_scale_shape:
+                        prod *= d
+                    x_scale_is_scalar = prod == 1
+
+                new_args: tuple[Any, ...]
+                if x_scale_is_scalar:
+                    # in this case, we can call onednn.qlinear directly
+                    new_args = (
+                        x,
+                        x_scale,
+                        dummy_zp,  # x_zp
+                        prepack_weight_node,
+                        w_scale,
+                        dummy_zp,  # w_zp
+                        bias,
+                        1.0,  # output_scale
+                        0,  # output_zero_point
+                        dtype,  # output_dtype
+                        "none",  # post op name
+                        [],  # post op args
+                        "",  # post op algorithm
+                    )
+                    new_linear_node = match.graph.call_function(
+                        torch.ops.onednn.qlinear_pointwise.tensor, args=new_args
+                    )
+                    out_node.replace_all_uses_with(new_linear_node)
+                    new_linear_node.meta.update(out_node.meta)
+                else:
+                    # onednn.qlinear does not support per-channel quantization of x
+                    # so in this case, we have to apply x scale and add bias ourselves after qlinear
+                    in_shape = kwargs.get("in_shape")
+                    if in_shape is None:
+                        x_reshaped = x
+                    else:
+                        x_reshaped = match.graph.call_function(
+                            aten.reshape.default, args=(x, in_shape)
+                        )
+                    new_args = (
+                        x_reshaped,
+                        1.0,  # x_scale
+                        0,  # x_zp
+                        prepack_weight_node,
+                        w_scale,
+                        dummy_zp,  # w_zp
+                        None,  # bias
+                        1.0,  # output_scale
+                        0,  # output_zero_point
+                        dtype,  # output_dtype
+                        "none",  # post op name
+                        [],  # post op args
+                        "",  # post op algorithm
+                    )
+                    new_linear_node = match.graph.call_function(
+                        torch.ops.onednn.qlinear_pointwise, args=new_args
+                    )
+                    # apply x scale
+                    new_out_node = match.graph.call_function(
+                        aten.mul.Tensor, args=(new_linear_node, x_scale)
+                    )
+
+                    # Add bias and reshape
+                    has_outer_reshape = (
+                        kwargs.get("out_shape_with_bias") is not None
+                        or kwargs.get("out_shape_no_bias") is not None
+                    )
+
+                    if has_outer_reshape:
+                        out_shape = kwargs.get(
+                            "out_shape_with_bias", kwargs["out_shape_no_bias"]
+                        )
+                    if bias is not None:
+                        new_out_node = match.graph.call_function(
+                            aten.add.Tensor, args=(new_out_node, bias)
+                        )
+                        if has_outer_reshape:
+                            new_out_node = match.graph.call_function(
+                                aten.reshape.default,
+                                args=(new_out_node, out_shape),  # type: ignore[possibly-undefined]
+                            )
+                    else:
+                        if has_outer_reshape:
+                            new_out_node = match.graph.call_function(
+                                aten.reshape.default,
+                                args=(new_out_node, out_shape),  # type: ignore[possibly-undefined]
+                            )
+                    out_node.replace_all_uses_with(new_out_node)
+                    new_out_node.meta.update(out_node.meta)
+                for node in reversed(match.nodes):
+                    match.graph.erase_node(node)
+                counters["inductor"]["qlinear_weight_prepack_matcher_count"] += 1
+                counters["inductor"]["qlinear_weight_prepack_matcher_nodes"] += len(
+                    match.nodes
+                )
+
+
+class PostOpAttr:
+    def __init__(
+        self,
+        binary_op_name: str = "none",
+        alpha=None,
+        unary_op_name: str = "none",
+        scalars_attr=None,
+        algorithm_attr=None,
+    ) -> None:
+        self.binary_op_name = binary_op_name
+        self.alpha = alpha if alpha else 1.0
+        self.unary_op_name = unary_op_name
+        self.scalars_attr = scalars_attr if scalars_attr else []
+        self.algorithm_attr = algorithm_attr if algorithm_attr else ""
+
+
+def _register_qconv_post_op_fusion_pass(
+    pattern,
+    pass_number,
+    computation_op,
+    post_op_attr,
+):
+    has_binary_post_op = post_op_attr.binary_op_name != "none"
+
+    @register_freezing_graph_pattern(
+        pattern,
+        extra_check=_is_valid_qconv_post_op_fusion_pattern(has_binary_post_op),
+        pass_number=pass_number,
+    )
+    def qconv(match: Match, *args, **kwargs):
+        # Activation QParams
+        x, x_scale, x_zp = (
+            kwargs["x"],
+            kwargs["x_scale"],
+            kwargs["x_zp"],
+        )
+        # Weight QParams
+        packed_weight, w_scale, w_zp = (
+            kwargs["packed_weight"],
+            kwargs["w_scale"],
+            kwargs["w_zp"],
+        )
+        # Conv Params
+        b, stride, padding, dilation, groups = (
+            kwargs["b"],
+            kwargs["stride"],
+            kwargs["padding"],
+            kwargs["dilation"],
+            kwargs["groups"],
+        )
+        output_dtype = _get_pattern_output_dtype(match)
+        assert output_dtype in [torch.int8, torch.uint8, torch.float32, torch.bfloat16]
+        # Output QParams
+        o_inv_scale = (
+            kwargs["o_inv_scale"]
+            if (output_dtype == torch.uint8 or output_dtype == torch.int8)
+            else 1.0
+        )
+        o_zero_point = (
+            kwargs["o_zp"]
+            if (output_dtype == torch.uint8 or output_dtype == torch.int8)
+            else 0
+        )
+        assert (
+            kwargs["postop_name"] == "none"
+        )  # Expected no post op fused in weight prepack phase
+        if post_op_attr.unary_op_name == "hardtanh":
+            min_value = kwargs.get("min_value")
+            max_value = kwargs.get("max_value")
+            post_op_attr.scalars_attr = [min_value, max_value]
+
+        out_node = match.output_node()
+        with match.graph.inserting_before(out_node):
+            if not has_binary_post_op:
+                computation_args: tuple[Any, ...] = (
+                    x,
+                    x_scale,
+                    x_zp,
+                    packed_weight,
+                    w_scale,
+                    w_zp,
+                    b,
+                    stride,
+                    padding,
+                    dilation,
+                    groups,
+                    o_inv_scale,
+                    o_zero_point,
+                    output_dtype,
+                    post_op_attr.unary_op_name,
+                    post_op_attr.scalars_attr,
+                    post_op_attr.algorithm_attr,
+                )
+            else:
+                accum = (
+                    kwargs["accum"]
+                    if output_dtype in [torch.uint8, torch.int8]
+                    else kwargs["accum_after_dequant"]
+                )
+                accum_scale = (
+                    kwargs["accum_scale"]
+                    if output_dtype in [torch.uint8, torch.int8]
+                    else 1.0
+                )
+                accum_zp = (
+                    kwargs["accum_zp"]
+                    if output_dtype in [torch.uint8, torch.int8]
+                    else 0
+                )
+                computation_args = (
+                    x,
+                    x_scale,
+                    x_zp,
+                    packed_weight,
+                    w_scale,
+                    w_zp,
+                    accum,
+                    b,
+                    stride,
+                    padding,
+                    dilation,
+                    groups,
+                    o_inv_scale,
+                    o_zero_point,
+                    output_dtype,
+                    accum_scale,
+                    accum_zp,
+                    post_op_attr.binary_op_name,
+                    post_op_attr.alpha,
+                    post_op_attr.unary_op_name,
+                    post_op_attr.scalars_attr,
+                    post_op_attr.algorithm_attr,
+                )
+            new_conv_node = match.graph.call_function(
+                computation_op, args=computation_args
+            )
+            out_node.replace_all_uses_with(new_conv_node)
+            new_conv_node.meta.update(out_node.meta)
+            for node in reversed(match.nodes):
+                match.graph.erase_node(node)
+        count_key = (
+            "qconv2d_binary_matcher_count"
+            if has_binary_post_op
+            else "qconv_unary_matcher_count"
+        )
+        nodes_key = (
+            "qconv2d_binary_matcher_nodes"
+            if has_binary_post_op
+            else "qconv_unary_matcher_nodes"
+        )
+        counters["inductor"][count_key] += 1
+        counters["inductor"][nodes_key] += len(match.nodes)
+
+    return qconv
+
+
+def _register_qconv_unary_fusion():
+    from .mkldnn_fusion import _hardswish_fusion, _hardtanh_fusion, _silu_fusion
+
+    for original_pattern_output_dtype in [torch.float32, torch.bfloat16]:
+        # Priority 1 to match: QConv2d Unary pattern with int8 output
+        # If a pattern1 is a sub-set of pattern2, we should try to match pattern2 firstly.
+        # For example: pattern1 is qconv_fp32 -> relu, pattern2 is qconv_fp32 -> relu -> quant
+        is_bf16 = original_pattern_output_dtype == torch.bfloat16
+        conv_unary_replace_patterns = {
+            PostOpAttr(
+                "none", None, "none", [], ""
+            ): generate_pattern_with_output_quant(
+                get_qconv_pt2e_pattern(users=1),
+            ),
+            PostOpAttr(
+                "none", None, "relu", [], ""
+            ): generate_pattern_with_output_quant(
+                generate_pattern_with_unary(
+                    get_qconv_pt2e_pattern(users=1), aten.relu.default
+                ),
+            ),
+            PostOpAttr(
+                "none", None, "hardtanh", [], ""
+            ): generate_pattern_with_output_quant(
+                _unary_fusion_pattern(
+                    _hardtanh_fusion,
+                    get_qconv_pt2e_pattern(users=1),
+                    1,
+                    is_bf16,
+                ),
+                with_dtype_convert=is_bf16,
+            ),
+            PostOpAttr(
+                "none", None, "hardswish", [], ""
+            ): generate_pattern_with_output_quant(
+                _unary_fusion_pattern(
+                    _hardswish_fusion,
+                    get_qconv_pt2e_pattern(users=1 if is_bf16 else 2),
+                    2,
+                    is_bf16,
+                ),
+                with_dtype_convert=is_bf16,
+            ),
+            PostOpAttr(
+                "none", None, "swish", [], ""
+            ): generate_pattern_with_output_quant(
+                _unary_fusion_pattern(
+                    _silu_fusion,
+                    get_qconv_pt2e_pattern(users=1 if is_bf16 else 2),
+                    2,
+                    is_bf16,
+                ),
+                with_dtype_convert=is_bf16,
+            ),
+        }
+
+        for unary_attr, patterns in conv_unary_replace_patterns.items():
+            # Register qconv2d pattern for ExternKernel Lowering
+            _register_qconv_post_op_fusion_pass(
+                patterns,
+                3,  # pass_number
+                torch.ops.onednn.qconv_pointwise.default,  # computation_op
+                unary_attr,  # unary_attr
+            )
+
+        # Priority 2 to match: QConv2d Unary pattern with fp32/bfloat16 output
+        conv_unary_replace_float_out_patterns = {
+            PostOpAttr("none", None, "relu", [], ""): generate_pattern_with_unary(
+                get_qconv_pt2e_pattern(users=1), aten.relu.default
+            ),
+            PostOpAttr(
+                "none", None, "hardtanh", [], ""
+            ): _may_generate_pattern_with_dtype_convert(
+                _unary_fusion_pattern(
+                    _hardtanh_fusion,
+                    get_qconv_pt2e_pattern(users=1),
+                    1,
+                    is_bf16,
+                ),
+                Arg(),
+                is_bf16,
+            ),
+            PostOpAttr(
+                "none", None, "hardswish", [], ""
+            ): _may_generate_pattern_with_dtype_convert(
+                _unary_fusion_pattern(
+                    _hardswish_fusion,
+                    get_qconv_pt2e_pattern(users=1 if is_bf16 else 2),
+                    2,
+                    is_bf16,
+                ),
+                Arg(),
+                is_bf16,
+            ),
+            PostOpAttr(
+                "none", None, "swish", [], ""
+            ): _may_generate_pattern_with_dtype_convert(
+                _unary_fusion_pattern(
+                    _silu_fusion,
+                    get_qconv_pt2e_pattern(users=1 if is_bf16 else 2),
+                    2,
+                    is_bf16,
+                ),
+                Arg(),
+                is_bf16,
+            ),
+        }
+
+        for unary_attr, patterns in conv_unary_replace_float_out_patterns.items():
+            # Register qconv2d pattern for ExternKernel Lowering
+            _register_qconv_post_op_fusion_pass(
+                patterns,
+                4,  # pass_number
+                torch.ops.onednn.qconv_pointwise.default,  # computation_op
+                unary_attr,  # unary_attr
+            )
+
+
+def _register_qconv_binary_fusion():
+    for int8_mixed_bf16_with_inplace_add in [False, True]:
+        # Priority 1 to match: QConv2d Binary or Binary-Unary pattern with int8 output
+        swap_binary_inputs_list = [False, True]
+        binary_replace_patterns = {}
+        for swap_inputs in swap_binary_inputs_list:
+            binary_replace_patterns.update(
+                {
+                    PostOpAttr(
+                        "sum", 1.0, "none", [], ""
+                    ): generate_pattern_with_output_quant(
+                        generate_pattern_with_binary(
+                            aten.add.Tensor,
+                            get_qconv_pt2e_pattern(users=1),
+                            dequantize_accum_pattern,
+                            int8_mixed_bf16_with_inplace_add,
+                            swap_inputs=swap_inputs,
+                        ),
+                    ),
+                    PostOpAttr(
+                        "sum", 1.0, "relu", [], ""
+                    ): generate_pattern_with_output_quant(
+                        generate_pattern_with_unary(
+                            generate_pattern_with_binary(
+                                aten.add.Tensor,
+                                get_qconv_pt2e_pattern(users=1),
+                                dequantize_accum_pattern,
+                                int8_mixed_bf16_with_inplace_add,
+                                swap_inputs=swap_inputs,
+                            ),
+                            aten.relu.default,
+                        ),
+                    ),
+                }
+            )
+
+        for binary_unary_attr, patterns in binary_replace_patterns.items():
+            _register_qconv_post_op_fusion_pass(
+                patterns,
+                3,  # pass_number
+                torch.ops.onednn.qconv2d_pointwise.binary,  # computation_op
+                binary_unary_attr,  # binary_unary_attr
+            )
+
+        # Priority 2 to match: QConv2d Binary-Unary pattern with fp32/bfloat16 output
+        binary_replace_float_out_patterns = {}
+        for swap_inputs in swap_binary_inputs_list:
+            binary_replace_float_out_patterns.update(
+                {
+                    PostOpAttr("sum", 1.0, "relu", [], ""): generate_pattern_with_unary(
+                        generate_pattern_with_binary(
+                            aten.add.Tensor,
+                            get_qconv_pt2e_pattern(users=1),
+                            KeywordArg("accum_after_dequant"),
+                            int8_mixed_bf16_with_inplace_add,
+                            swap_inputs=swap_inputs,
+                        ),
+                        aten.relu.default,
+                    )
+                }
+            )
+
+        for (
+            binary_unary_attr,
+            patterns,
+        ) in binary_replace_float_out_patterns.items():
+            if int8_mixed_bf16_with_inplace_add:
+                _register_qconv_post_op_fusion_pass(
+                    patterns,
+                    3,  # pass_number
+                    torch.ops.onednn.qconv2d_pointwise.binary,  # computation_op
+                    binary_unary_attr,  # binary_unary_attr
+                )
+            else:
+                _register_qconv_post_op_fusion_pass(
+                    patterns,
+                    4,  # pass_number
+                    torch.ops.onednn.qconv2d_pointwise.binary,  # computation_op
+                    binary_unary_attr,  # binary_unary_attr
+                )
+
+        # Priority 3: QConv2d Binary pattern with fp32/bfloat16 output
+        binary_replace_float_out_patterns = {}
+        for swap_inputs in swap_binary_inputs_list:
+            binary_replace_float_out_patterns.update(
+                {
+                    PostOpAttr(
+                        "sum", 1.0, "none", [], ""
+                    ): generate_pattern_with_binary(
+                        aten.add.Tensor,
+                        get_qconv_pt2e_pattern(users=1),
+                        KeywordArg("accum_after_dequant"),
+                        int8_mixed_bf16_with_inplace_add,
+                        swap_inputs=swap_inputs,
+                    ),
+                }
+            )
+
+        for (
+            binary_unary_attr,
+            patterns,
+        ) in binary_replace_float_out_patterns.items():
+            _register_qconv_post_op_fusion_pass(
+                patterns,
+                4 if int8_mixed_bf16_with_inplace_add else 5,  # pass_number
+                torch.ops.onednn.qconv2d_pointwise.binary,  # computation_op
+                binary_unary_attr,  # binary_unary_attr
+            )
+
+
+def _register_qlinear_post_op_fusion_pass(
+    pattern,
+    pass_number,
+    computation_op,
+    post_op_attr,
+):
+    has_binary_post_op = post_op_attr.binary_op_name != "none"
+
+    @register_freezing_graph_pattern(
+        pattern,
+        extra_check=_is_valid_qlinear_post_op_fusion_pattern(has_binary_post_op),
+        pass_number=pass_number,
+    )
+    def qlinear_post_op_fusion(match: Match, *args, **kwargs):
+        """
+        Match the pattern:
+        qlinear - post op
+        """
+        output_dtype = _get_pattern_output_dtype(match)
+        # Activation QParams
+        x, x_scale, x_zp = (
+            kwargs["x"],
+            kwargs["x_scale"],
+            kwargs["x_zp"],
+        )
+        # Weight QParams
+        packed_weight, w_scale, w_zp = (
+            kwargs["packed_weight"],
+            kwargs["w_scale"],
+            kwargs["w_zp"],
+        )
+
+        # bias
+        b = kwargs.get("b")
+
+        # Output QParams
+        o_inv_scale = (
+            kwargs["o_inv_scale"]
+            if (output_dtype in [torch.uint8, torch.int8])
+            else 1.0
+        )
+        o_zero_point = (
+            kwargs["o_zp"] if (output_dtype in [torch.uint8, torch.int8]) else 0
+        )
+        assert (
+            kwargs["postop_name"] == "none"
+        )  # Expected no post op fused in weight prepack phase
+
+        out_node = match.output_node()
+        with match.graph.inserting_before(out_node):
+            if not has_binary_post_op:
+                computation_args: tuple[Any, ...] = (
+                    x,
+                    x_scale,
+                    x_zp,
+                    packed_weight,
+                    w_scale,
+                    w_zp,
+                    b,
+                    o_inv_scale,
+                    o_zero_point,
+                    output_dtype,
+                    post_op_attr.unary_op_name,
+                    post_op_attr.scalars_attr,
+                    post_op_attr.algorithm_attr,
+                )
+            else:
+                other = kwargs["other"] if "other" in kwargs else kwargs["accum"]
+                x2_scale = 1.0
+                x2_zp = 0
+                computation_args = (
+                    x,
+                    x_scale,
+                    x_zp,
+                    packed_weight,
+                    w_scale,
+                    w_zp,
+                    other,
+                    b,
+                    o_inv_scale,
+                    o_zero_point,
+                    output_dtype,
+                    x2_scale,
+                    x2_zp,
+                    post_op_attr.binary_op_name,
+                    post_op_attr.alpha,
+                    post_op_attr.unary_op_name,
+                    post_op_attr.scalars_attr,
+                    post_op_attr.algorithm_attr,
+                )
+            new_linear_node = match.graph.call_function(
+                computation_op, args=computation_args
+            )
+            out_node.replace_all_uses_with(new_linear_node)
+            new_linear_node.meta.update(out_node.meta)
+            for node in reversed(match.nodes):
+                match.graph.erase_node(node)
+        count_key = (
+            "qlinear_binary_matcher_count"
+            if has_binary_post_op
+            else "qlinear_unary_matcher_count"
+        )
+        nodes_key = (
+            "qlinear_binary_matcher_nodes"
+            if has_binary_post_op
+            else "qlinear_unary_matcher_nodes"
+        )
+        counters["inductor"][count_key] += 1
+        counters["inductor"][nodes_key] += len(match.nodes)
+
+
+def _register_qlinear_unary_fusion():
+    from .mkldnn_fusion import (
+        _gelu_fusion_1 as _gelu_fusion_erf,
+        _gelu_fusion_2 as _gelu_fusion_tanh,
+    )
+
+    for original_pattern_output_dtype in [torch.float32, torch.bfloat16]:
+        is_bf16 = original_pattern_output_dtype == torch.bfloat16
+        for x_scale_zp_are_tensors in (False, True):
+            qlinear_pattern = get_qlinear_pt2e_pattern(x_scale_zp_are_tensors)
+            computation_op = (
+                torch.ops.onednn.qlinear_pointwise.tensor
+                if x_scale_zp_are_tensors
+                else torch.ops.onednn.qlinear_pointwise.default
+            )
+            # Priority 1 to match: QLinear Unary pattern with int8 output
+            linear_unary_replace_patterns = {
+                PostOpAttr(
+                    "none", None, "none", [], ""
+                ): generate_pattern_with_output_quant(
+                    qlinear_pattern,
+                ),
+                PostOpAttr(
+                    "none", None, "relu", [], ""
+                ): generate_pattern_with_output_quant(
+                    generate_pattern_with_unary(qlinear_pattern, aten.relu.default),
+                ),
+                PostOpAttr(
+                    "none", None, "gelu", [], "none"
+                ): generate_pattern_with_output_quant(
+                    _unary_fusion_pattern(
+                        _gelu_fusion_erf,
+                        get_qlinear_pt2e_pattern(
+                            x_scale_zp_are_tensors, 1 if is_bf16 else 2
+                        ),
+                        2,
+                        is_bf16,
+                    ),
+                    with_dtype_convert=is_bf16,
+                ),
+                PostOpAttr(
+                    "none", None, "gelu", [], "tanh"
+                ): generate_pattern_with_output_quant(
+                    _unary_fusion_pattern(
+                        _gelu_fusion_tanh,
+                        get_qlinear_pt2e_pattern(
+                            x_scale_zp_are_tensors, 1 if is_bf16 else 4
+                        ),
+                        4,
+                        is_bf16,
+                    ),
+                    with_dtype_convert=is_bf16,
+                ),
+            }
+
+            for unary_attr, patterns in linear_unary_replace_patterns.items():
+                _register_qlinear_post_op_fusion_pass(
+                    patterns,
+                    3,  # pass_number
+                    computation_op,
+                    unary_attr,  # unary_attr
+                )
+
+            # Priority 2 to match: QLinear Unary pattern with FP32/BF16 output
+            linear_unary_replace_float_out_patterns = {
+                PostOpAttr("none", None, "relu", [], ""): generate_pattern_with_unary(
+                    qlinear_pattern, aten.relu.default
+                ),
+                PostOpAttr(
+                    "none", None, "gelu", [], "none"
+                ): _may_generate_pattern_with_dtype_convert(
+                    _unary_fusion_pattern(
+                        _gelu_fusion_erf,
+                        get_qlinear_pt2e_pattern(
+                            x_scale_zp_are_tensors, 1 if is_bf16 else 2
+                        ),
+                        2,
+                        is_bf16,
+                    ),
+                    Arg(),
+                    is_bf16,
+                ),
+                PostOpAttr(
+                    "none", None, "gelu", [], "tanh"
+                ): _may_generate_pattern_with_dtype_convert(
+                    _unary_fusion_pattern(
+                        _gelu_fusion_tanh,
+                        get_qlinear_pt2e_pattern(
+                            x_scale_zp_are_tensors, 1 if is_bf16 else 4
+                        ),
+                        4,
+                        is_bf16,
+                    ),
+                    Arg(),
+                    is_bf16,
+                ),
+            }
+
+            for unary_attr, patterns in linear_unary_replace_float_out_patterns.items():
+                _register_qlinear_post_op_fusion_pass(
+                    patterns,
+                    4,  # pass_number
+                    computation_op,
+                    unary_attr,  # unary_attr
+                )
+
+
+def _register_qlinear_binary_fusion():
+    r"""
+    Supported linear-binary(-unary) patterns
+
+        linear(X)   extra input
+               \   /
+                Add
+                 |
+            Optional(relu)
+                 |
+                 Y
+
+    1. int8-mixed-fp32
+    +---+---------------+-----------+------------------------------+---------+
+    | # | Add type      | Quant out | Pattern                      | Post op |
+    +---+---------------+-----------+------------------------------+---------+
+    | 1 | In-/out-place | Yes       | linear + fp32 -> (relu) -> q | add     |
+    +---+---------------+-----------+------------------------------+---------+
+    | 2 | In-/out-place | No        | linear + fp32 -> (relu)      | sum     |
+    +---+---------------+-----------+------------------------------+---------+
+
+    2. int8-mixed-bf16
+    +---+----------+---------------+-----------+-----------------------------------------+---------+
+    | # | X2 dtype | Add type      | Quant out | Pattern                                 | Post op |
+    +---+----------+---------------+-----------+-----------------------------------------+---------+
+    | 1 | BF16     | In-/out-place | Yes       | linear + bf16 -> (relu) -> q            | add     |
+    +---+----------+---------------+-----------+-----------------------------------------+---------+
+    | 2 | BF16     | In-/out-place | No        | linear + bf16 -> (relu)                 | sum     |
+    +---+----------+---------------+-----------+-----------------------------------------+---------+
+    | 3 | FP32     | Out-place     | Yes       | linear + fp32 -> (relu) -> q            | add     |
+    |   |          | In-place right|           |                                         |         |
+    +---+----------+---------------+-----------+-----------------------------------------+---------+
+    | 4 | FP32     | Out-place     | No        | linear + fp32 -> (relu)                 | sum     |
+    |   |          | In-place right|           |                                         |         |
+    +---+----------+---------------+-----------+-----------------------------------------+---------+
+    | 5 | FP32     | In-place left | Yes       | linear + fp32 -> to_bf16 -> (relu) -> q | add     |
+    +---+----------+---------------+-----------+-----------------------------------------+---------+
+    | 6 | FP32     | In-place left | No        | linear + fp32 -> to_bf16 -> (relu)      | add     |
+    +---+----------+---------------+-----------+-----------------------------------------+---------+
+
+    Note
+    (1) The positions of linear and the extra input can be swapped.
+    (2) we don't insert q-dq before the extra input of linear-add by recipe. But if q-dq is found at the
+    extra input, we don't match that pattern because we cannot match all these patterns in 3 passes.
+    """
+    for x_scale_zp_are_tensors in (False, True):
+        qlinear_binary_op = (
+            torch.ops.onednn.qlinear_pointwise.binary_tensor
+            if x_scale_zp_are_tensors
+            else torch.ops.onednn.qlinear_pointwise.binary
+        )
+        unary_postop_list = ["none", "relu"]
+        unary_postop_dict = {
+            "none": None,
+            "relu": aten.relu.default,
+        }
+        convert_dtype_after_binary_list = [False, True]
+
+        # Priority 1 to match: QLinear Binary or Binary-Unary pattern with int8 output
+        # Covers case (1) of int8-mixed-fp32 and case (1)(3)(5) of int8-mixed-bf16,
+        # totally 3 patterns (2 are identical)
+        swap_binary_inputs_list = [False, True]
+        int8_mixed_bf16_list = [False, True]
+        combinations = itertools.product(
+            unary_postop_list,
+            int8_mixed_bf16_list,
+            swap_binary_inputs_list,
+            convert_dtype_after_binary_list,
+        )
+        qlinear_binary_replace_patterns = {}
+        for unary_op, int8_mixed_bf16, swap_inputs, cvt_dtype_binary in combinations:
+            if not int8_mixed_bf16 and cvt_dtype_binary:
+                # No convert node after binary node if dtypes are all fp32
+                continue
+            qlinear_binary_replace_patterns.update(
+                {
+                    PostOpAttr(
+                        "add", 1.0, unary_op, [], ""
+                    ): generate_pattern_with_output_quant(
+                        generate_pattern_with_unary(
+                            generate_pattern_with_binary(
+                                aten.add.Tensor,
+                                get_qlinear_pt2e_pattern(x_scale_zp_are_tensors),
+                                KeywordArg("other"),
+                                # If fp32 extra input is inplace added to bf16 linear output,
+                                # a to_bf16 node is inserted after binary
+                                dtype_convert=cvt_dtype_binary,
+                                swap_inputs=swap_inputs,
+                            ),
+                            unary_postop_dict[unary_op],
+                        ),
+                    )
+                }
+            )
+        for binary_unary_attr, patterns in qlinear_binary_replace_patterns.items():
+            _register_qlinear_post_op_fusion_pass(
+                patterns,
+                3,  # pass_number
+                qlinear_binary_op,  # computation_op
+                binary_unary_attr,
+            )
+
+        # Priority 2.1 to match: QLinear Binary-Unary pattern with fp32/bfloat16 output
+        # Covers case (2) of int8-mixed-fp32 and case (2)(4) of int8-mixed-bf16,
+        # totally 2 patterns (2 are identical)
+        binary_replace_float_out_patterns = {}
+        for swap_binary_inputs in swap_binary_inputs_list:
+            binary_replace_float_out_patterns.update(
+                {
+                    PostOpAttr("sum", 1.0, "relu", [], ""): generate_pattern_with_unary(
+                        generate_pattern_with_binary(
+                            aten.add.Tensor,
+                            get_qlinear_pt2e_pattern(x_scale_zp_are_tensors),
+                            KeywordArg("accum"),
+                            dtype_convert=False,
+                            swap_inputs=swap_binary_inputs,
+                        ),
+                        aten.relu.default,
+                    ),
+                }
+            )
+        for (
+            binary_unary_attr,
+            patterns,
+        ) in binary_replace_float_out_patterns.items():
+            _register_qlinear_post_op_fusion_pass(
+                patterns,
+                4,  # pass_number
+                qlinear_binary_op,  # computation_op
+                binary_unary_attr,
+            )
+        # Priority 2.2 to match: QLinear Binary-Unary pattern with fp32/bfloat16 output
+        # Covers case (6) of int8-mixed-bf16
+        binary_replace_float_out_patterns = {}
+        for swap_binary_inputs in swap_binary_inputs_list:
+            binary_replace_float_out_patterns.update(
+                {
+                    PostOpAttr("add", 1.0, "relu", [], ""): generate_pattern_with_unary(
+                        generate_pattern_with_binary(
+                            aten.add.Tensor,
+                            get_qlinear_pt2e_pattern(x_scale_zp_are_tensors),
+                            KeywordArg("other"),
+                            dtype_convert=True,
+                            swap_inputs=swap_binary_inputs,
+                        ),
+                        aten.relu.default,
+                    ),
+                }
+            )
+        for (
+            binary_unary_attr,
+            patterns,
+        ) in binary_replace_float_out_patterns.items():
+            _register_qlinear_post_op_fusion_pass(
+                patterns,
+                4,  # pass_number
+                qlinear_binary_op,  # computation_op
+                binary_unary_attr,
+            )
+
+        # Priority 3.1: QLinear Binary pattern with fp32/bfloat16 output
+        # Covers case (2) of int8-mixed-fp32 and case (2)(4) of int8-mixed-bf16,
+        # totally 2 patterns (2 are identical)
+        binary_replace_float_out_patterns = {}
+        for swap_binary_inputs in swap_binary_inputs_list:
+            binary_replace_float_out_patterns.update(
+                {
+                    PostOpAttr(
+                        "sum", 1.0, "none", [], ""
+                    ): generate_pattern_with_binary(
+                        aten.add.Tensor,
+                        get_qlinear_pt2e_pattern(x_scale_zp_are_tensors),
+                        KeywordArg("accum"),
+                        dtype_convert=False,
+                        swap_inputs=swap_binary_inputs,
+                    ),
+                }
+            )
+        for (
+            binary_unary_attr,
+            patterns,
+        ) in binary_replace_float_out_patterns.items():
+            _register_qlinear_post_op_fusion_pass(
+                patterns,
+                5,  # pass_number
+                qlinear_binary_op,  # computation_op
+                binary_unary_attr,
+            )
+        # Priority 3.2: QLinear Binary pattern with fp32/bfloat16 output
+        # Covers (6) of int8-mixed-bf16
+        binary_replace_float_out_patterns = {}
+        for swap_binary_inputs in swap_binary_inputs_list:
+            binary_replace_float_out_patterns.update(
+                {
+                    PostOpAttr(
+                        "add", 1.0, "none", [], ""
+                    ): generate_pattern_with_binary(
+                        aten.add.Tensor,
+                        get_qlinear_pt2e_pattern(x_scale_zp_are_tensors),
+                        KeywordArg("other"),
+                        dtype_convert=True,
+                        swap_inputs=swap_binary_inputs,
+                    ),
+                }
+            )
+        for (
+            binary_unary_attr,
+            patterns,
+        ) in binary_replace_float_out_patterns.items():
+            _register_qlinear_post_op_fusion_pass(
+                patterns,
+                5,  # pass_number
+                qlinear_binary_op,  # computation_op
+                binary_unary_attr,
+            )
+
+
+@functools.cache
+def _register_quantization_weight_pack_pass():
+    # Step 1: Dequant promotion for int8-mixed-fp32/bf16
+    _register_dequant_promotion()
+
+    # Step 2: QConv weight prepack
+    _register_qconv_weight_prepack()
+
+    # Step 3: QLinear weight prepack
+    _register_qlinear_weight_prepack()
+    _register_linear_dynamic_fp16_weight_prepack()
+
+    # Step 4: weight prepack for SmoothQuant from Torchao
+    _register_smooth_quant_int_mm_pattern()
+
+    # Step 5: QLinear post op Fusion
+    if not torch.ops.mkldnn._is_mkldnn_acl_supported():
+        # skip fusion on ARM
+        _register_qconv_unary_fusion()
+        _register_qconv_binary_fusion()
+        _register_qlinear_unary_fusion()
+        _register_qlinear_binary_fusion()
+
+
+def _is_valid_concat_linear_woq_int4_fusion(computation_nodes):
+    computation_op = torch.ops.aten._weight_int4pack_mm_for_cpu.default
+    act = computation_nodes[0].args[0]
+    wgt = computation_nodes[0].args[1]
+    in_feature_size = wgt.meta.get("val").size(1)  # type: ignore[union-attr]
+    group_size = computation_nodes[0].args[2]
+    return len(computation_nodes) >= 2 and all(
+        (
+            node.target == computation_op
+            and node.args[0] == act  # share same activation
+            and (
+                node.args[1].meta.get("val").size(1) == in_feature_size
+            )  # same in feature size
+            and (node.args[1] != wgt or gemm_idx == 0)
+            and node.args[1].op == "get_attr"  # wgt are all constants
+            and node.args[2] == group_size  # same group size
+        )
+        for gemm_idx, node in enumerate(computation_nodes)
+    )
+
+
+def concat_linear_woq_int4(gm: torch.fx.GraphModule):
+    """
+    Concat Linear optimization pass for WOQ int4
+    This pass fuses the original pattern:
+    def ...
+        return (woq_int4(x, w1, group_size, scale_zp1), woq_int4(x, w2, group_size, scale_zp1) ...)
+    into a single operation:
+    def ...
+        concat_res = woq_int4(x, concat_w, group_size, concat_scale_zp)
+        return split(concat_res, split_size_list)
+    """
+
+    def concat_wgt(packed_wgts, scale_zps, group_size, act_dtype):
+        # Concat the wgts and scale_zps, and repack the wgt
+        unpacked_wgts = []
+        for packed_wgt in packed_wgts:
+            # Get the unpacked weight list
+            # Same as https://github.com/pytorch/pytorch/pull/156174
+            K = packed_wgt.size(1) * 2
+            N = packed_wgt.size(0)
+            x = torch.eye(K).to(dtype=act_dtype)
+            qscales_and_zeros = (
+                torch.tensor([1.0, 8.0])
+                .to(dtype=act_dtype)
+                .expand(K // group_size, N, 2)
+                .contiguous()
+            )
+            unpacked_wgts.append(
+                torch.ops.aten._weight_int4pack_mm_for_cpu(
+                    x,
+                    packed_wgt,
+                    group_size,
+                    qscales_and_zeros,
+                )
+                .t()
+                .contiguous()
+                .to(torch.int32)  # N, K
+            )
+        concat_unpacked_wgt = torch.cat(unpacked_wgts, dim=0)
+        repack_w = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
+            concat_unpacked_wgt, 1
+        )
+        concat_scale_zp = torch.cat(scale_zps, dim=1).contiguous()
+        return repack_w, concat_scale_zp
+
+    graph = gm.graph
+    computation_op = torch.ops.aten._weight_int4pack_mm_for_cpu.default
+    for node in graph.find_nodes(op="call_function", target=computation_op):
+        if (
+            not node._erased
+            and isinstance(node.meta.get("val"), torch.Tensor)
+            and node.meta["val"].device.type == "cpu"
+        ):
+            act = node.args[0]
+            users = list(act.users)
+            if _is_valid_concat_linear_woq_int4_fusion(users):
+                with graph.inserting_before(node):
+                    assert all(user.args[1].op == "get_attr" for user in users)
+                    computation_node_0 = users[0]
+                    packed_wgts = [getattr(gm, user.args[1].target) for user in users]
+                    group_size = computation_node_0.args[2]
+                    scale_zps = [getattr(gm, user.args[3].target) for user in users]
+                    out_feature_size_list = [
+                        packed_wgt.size(0) for packed_wgt in packed_wgts
+                    ]
+                    repack_w, concat_scale_zp = concat_wgt(
+                        packed_wgts, scale_zps, group_size, act.meta.get("val").dtype
+                    )
+                    repack_w_node_name = computation_node_0.args[1].target + "_concat"
+                    concat_scale_zp_node_name = (
+                        computation_node_0.args[3].target + "_concat"
+                    )
+                    gm.register_buffer(repack_w_node_name, repack_w)
+                    setattr(gm, repack_w_node_name, repack_w)
+                    gm.register_buffer(concat_scale_zp_node_name, concat_scale_zp)
+                    setattr(gm, concat_scale_zp_node_name, concat_scale_zp)
+
+                    repack_w_node = graph.create_node(
+                        "get_attr", repack_w_node_name, (), {}
+                    )
+                    with graph.inserting_after(repack_w_node):
+                        concat_scale_zp_node = graph.create_node(
+                            "get_attr", concat_scale_zp_node_name, (), {}
+                        )
+
+                    with graph.inserting_after(concat_scale_zp_node):
+                        concat_int4_gemm_node = graph.create_node(
+                            "call_function",
+                            computation_op,
+                            (
+                                act,
+                                repack_w_node,
+                                group_size,
+                                concat_scale_zp_node,
+                            ),
+                        )
+                    with graph.inserting_after(concat_int4_gemm_node):
+                        split_node = graph.create_node(
+                            "call_function",
+                            torch.ops.aten.split_with_sizes.default,
+                            (
+                                concat_int4_gemm_node,
+                                out_feature_size_list,
+                                1,  # split dim
+                            ),
+                        )
+                    with graph.inserting_after(split_node):
+                        for gemm_idx, user in enumerate(users):
+                            assert user.target == computation_op
+                            get_item = graph.create_node(
+                                "call_function",
+                                operator.getitem,
+                                (
+                                    split_node,
+                                    gemm_idx,
+                                ),
+                            )
+                            with graph.inserting_after(get_item):
+                                clone_node = graph.create_node(
+                                    "call_function",
+                                    torch.ops.aten.clone.default,
+                                    (get_item,),
+                                    {"memory_format": torch.contiguous_format},
+                                )
+                                user.replace_all_uses_with(clone_node)
+                                graph.erase_node(user)
+
+
+def quant_lift_up(graph_module: torch.fx.GraphModule):
+    """
+    Lift up the quant node before view like nodes. It can benefit performance
+    of Attention like block. For example, we have the pattern as:
+
+             DQ
+    DQ       LINEAR
+    LINEAR   VIEW
+    VIEW     PERMUTE
+    PERMUTE  TRANSPOSE
+    Q        Q
+    DQ       DQ
+       Matmul
+        DIV
+        ADD
+      SOFTMAX
+
+    We want to lift up the quant nodes from matmul before view like nodes
+    as the output of Linear node.
+
+             DQ
+    DQ       LINEAR
+    LINEAR   Q
+    Q        VIEW
+    VIEW     PERMUTE
+    PERMUTE  TRANSPOSE
+    DQ       DQ
+       Matmul
+        DIV
+        ADD
+      SOFTMAX
+
+    It produces a DQ->LINEAR->Q pattern which can be fused by backend.
+    """
+
+    def is_view_op(node):
+        return node.op == "call_function" and node.target in _VIEW_OPS
+
+    for node in graph_module.graph.nodes:
+        #  Leslie: Here we verify that the quant node has exactly
+        # one input FX node, with constant scalar value for scale and zero point.
+        # For the case input of quant node has more than one input FX nodes,
+        # extend the implementation to lift up all the connected nodes
+        # before the view nodes to keep the topological order.
+        if (
+            node.op == "call_function"
+            and node.target in _PER_TENSOR_QUANTIZE_OPS
+            and len(node.all_input_nodes) == 1
+            and is_view_op(node.all_input_nodes[0])
+        ):
+            quant_node = node
+            input_node_of_quant = quant_node.args[0]
+
+            # Check the nodes along lift up path has only 1 user node
+            # Propagate view like node to find where to insert the new quant node
+            could_lift_up = True
+            current_node = quant_node
+            input_node = current_node.args[0]
+            while is_view_op(input_node):
+                if len(input_node.users) != 1:
+                    could_lift_up = False
+                    break
+                current_node = input_node
+                input_node = current_node.args[0]
+
+            # Further check the input node of the first view node has only 1 user node
+            if could_lift_up and len(input_node.users) == 1:
+                counters["inductor"]["quant_lift_up_count"] += 1
+                # Replace dequant's input from quant to quant's input
+                quant_node.replace_all_uses_with(input_node_of_quant)
+                # Insert the new quant node
+                with graph_module.graph.inserting_before(current_node):
+                    new_quant_node = graph_module.graph.node_copy(quant_node)
+                    input_node.replace_all_uses_with(new_quant_node)
+
+                    # Update inputs of new_quant_node
+                    def maybe_replace_node(n: torch.fx.Node) -> torch.fx.Node:
+                        if n == input_node_of_quant:
+                            return input_node
+                        else:
+                            return n
+
+                    new_args = map_arg(new_quant_node.args, maybe_replace_node)
+                    new_kwargs = map_arg(new_quant_node.kwargs, maybe_replace_node)
+                    new_quant_node.args = new_args  # type: ignore[assignment]
+                    new_quant_node.kwargs = new_kwargs  # type: ignore[assignment]
+                    graph_module.graph.erase_node(quant_node)
+
+    graph_module.graph.lint()
+    graph_module.recompile()
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/reinplace.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/reinplace.py
new file mode 100644
index 0000000000000000000000000000000000000000..e42e8a1139770d488929f772b0441fe4f616d449
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/reinplace.py
@@ -0,0 +1,795 @@
+# mypy: allow-untyped-defs
+import itertools
+import logging
+import operator
+from collections import defaultdict
+from collections.abc import Callable, Sequence
+from contextlib import nullcontext
+from dataclasses import dataclass
+from typing import Any, cast
+
+import torch
+import torch.fx.node
+from torch._C._dynamo.guards import compute_overlapping_tensors
+from torch._dispatch.python import enable_python_dispatcher
+from torch._dynamo.utils import ReinplaceCounters, ReInplaceTrigger
+from torch._guards import detect_fake_mode
+from torch._higher_order_ops.triton_kernel_wrap import (
+    kernel_side_table,
+    triton_kernel_wrapper_functional,
+)
+from torch._inductor import config, inductor_prims
+from torch._inductor.fx_utils import get_node_storage, is_node_realized
+from torch._inductor.lowering import (
+    inplaceable_foreach_ops as inplaceable_foreach_ops_lowerings,
+)
+from torch._inductor.virtualized import V
+from torch.fx.experimental.symbolic_shapes import (
+    compute_unbacked_bindings,
+    GuardOnDataDependentSymNode,
+)
+from torch.fx.immutable_collections import immutable_dict, immutable_list
+from torch.fx.passes.reinplace import _is_view_op
+from torch.utils import _pytree as pytree
+from torch.utils._ordered_set import OrderedSet
+
+
+log = logging.getLogger(__name__)
+aten = torch.ops.aten
+
+
+@dataclass(frozen=True)
+class InplaceableOp:
+    inplace_op: Callable[..., Any]
+    mutated_arg: int
+    extra_check: Callable[[torch.fx.Node], bool] = lambda node: True
+
+
+_SCATTER_OP_TO_VIEW = {
+    torch.ops.aten.diagonal_scatter.default: torch.ops.aten.diagonal.default,
+    torch.ops.aten.select_scatter.default: torch.ops.aten.select.int,
+    torch.ops.aten.slice_scatter.default: torch.ops.aten.slice.Tensor,
+    torch.ops.aten.as_strided_scatter.default: torch.ops.aten.as_strided.default,
+}
+_VIEW_OP_TO_SCATTER = {v: k for k, v in _SCATTER_OP_TO_VIEW.items()}
+
+
+def graph_call_function(graph: torch.fx.Graph, fn, *args, **kwargs):
+    fake_args, fake_kwargs = pytree.tree_map(
+        lambda node: node.meta["val"] if isinstance(node, torch.fx.Node) else node,
+        (args, kwargs),
+    )
+    with V.fake_mode:
+        fake_result = fn(*fake_args, **fake_kwargs)
+
+    node = graph.call_function(fn, args, kwargs)
+
+    node.meta["val"] = fake_result
+
+    return node
+
+
+@dataclass
+class ViewOp:
+    target: torch._ops.OpOverload
+    args: tuple[Any, ...]
+    kwargs: dict[str, Any]
+
+
+def _inplace_generalized_scatter(
+    inp: torch.Tensor, src: torch.Tensor, view_ops: list[ViewOp]
+) -> torch.Tensor:
+    tmp = inp
+    for view in view_ops:
+        fake_args, fake_kwargs = pytree.tree_map(
+            lambda node: node.meta["val"] if isinstance(node, torch.fx.Node) else node,
+            (view.args, view.kwargs),
+        )
+        # slice and select can allocate new unbacked symints, but those won't be reflected
+        # in the output of this function, hence shall be ignored.
+        fake_mode = detect_fake_mode(fake_args)
+        with (
+            fake_mode.shape_env.ignore_fresh_unbacked_symbols()
+            if fake_mode and fake_mode.shape_env
+            else nullcontext()
+        ):
+            tmp = view.target(tmp, *fake_args, **fake_kwargs)
+    try:
+        tmp.copy_(src)
+    except RuntimeError as e:
+        raise RuntimeError(
+            f"shape error in scatter op, can not broadcast {src.shape} to {tmp.shape}"
+        ) from e
+    return inp
+
+
+def _generalized_scatter(
+    inp: torch.Tensor, src: torch.Tensor, view_ops: list[ViewOp]
+) -> torch.Tensor:
+    out = inp.clone()
+    return _inplace_generalized_scatter(out, src, view_ops)
+
+
+def _decompose_scatter_functional_helper(
+    graph: torch.fx.Graph,
+    inp: torch.Tensor,
+    src: torch.Tensor,
+    view_ops: list[ViewOp],
+) -> torch.fx.Node:
+    view_op, view_ops_tail = view_ops[0], view_ops[1:]
+
+    if view_ops_tail:
+        view = graph_call_function(
+            graph, view_op.target, inp, *view_op.args, **view_op.kwargs
+        )
+        src = _decompose_scatter_functional_helper(graph, view, src, view_ops[1:])  # type: ignore[assignment]
+
+    return graph_call_function(
+        graph,
+        _VIEW_OP_TO_SCATTER[view_op.target],
+        inp,
+        src,
+        *view_op.args,
+        **view_op.kwargs,
+    )
+
+
+def _decompose_scatter_functional(
+    graph: torch.fx.Graph, node: torch.fx.Node
+) -> torch.fx.Node:
+    """Decompose _generalized_scatter to a sequence of view_scatter operations
+
+    e.g. _generalized_scatter(inp, src, [(aten.slice, 0, 0, 10), (aten.slice, 1, 10, -10)])
+
+    will become
+
+    view = aten.slice(inp, 0, 0, 10)
+    view_updated = aten.slice_scatter(view, src, 1, 10, -10)
+    inp_updated = aten.slice_scatter(inp, view_updated, 0, 0, 10)
+    """
+    assert node.target is _generalized_scatter
+    return _decompose_scatter_functional_helper(graph, *node.args)  # type: ignore[arg-type]
+
+
+def _decompose_scatter_mutating(
+    graph: torch.fx.Graph, node: torch.fx.Node
+) -> torch.fx.Node:
+    """Decompose _generalized_scatter using mutations
+
+    e.g. _generalized_scatter(inp, src, [(aten.slice, 0, 0, 10), (aten.slice, 1, 10, -10)])
+
+    will become
+
+    inp_updated = aten.clone(inp)
+    slice1 = aten.slice(inp_updated, 0, 0, 10)
+    slice2 = aten.slice(slice1, 1, 10, -10)
+    slice2.copy_(src)
+
+    """
+    assert node.target in (_generalized_scatter, _inplace_generalized_scatter)
+    inp, src, view_ops = node.args
+    assert not node.kwargs
+
+    if node.target is _generalized_scatter:
+        inp = graph_call_function(graph, aten.clone, inp)
+
+    tmp = inp
+    for view in view_ops:  # type: ignore[union-attr]
+        tmp = graph_call_function(graph, view.target, tmp, *view.args, **view.kwargs)  # type: ignore[union-attr]
+        # we need to set unbacked bindings that could have been created in the view ops.
+        if (V.fake_mode.shape_env) and (
+            symbol_to_path := compute_unbacked_bindings(
+                V.fake_mode.shape_env, tmp.meta["val"]
+            )
+        ):
+            tmp.meta["unbacked_bindings"] = symbol_to_path
+
+    graph_call_function(graph, aten.copy_.default, tmp, src)
+    return inp  # type: ignore[return-value]
+
+
+# View ops whose view_scatter op is lowered into mutations anyway,
+# so is never a pessimisation to decompose.
+_ALWAYS_MUTATING_SCATTER_OPS = OrderedSet(
+    [
+        aten.as_strided.default,
+        aten.diagonal.default,
+    ]
+)
+
+
+def scatter_always_uses_mutation(node: torch.fx.Node) -> bool:
+    _, _, view_ops = node.args
+    view_ops = cast(Sequence[torch.fx.node.Argument], view_ops)
+    return any(
+        target in _ALWAYS_MUTATING_SCATTER_OPS
+        for view in view_ops
+        if isinstance(target := getattr(view, "target", None), torch._ops.OpOverload)
+    )
+
+
+def should_reinplace_scatter(node: torch.fx.Node) -> bool:
+    """Choose between mutating and functional scatter decompositions
+
+    Reinplacing view scatter ops can be pessimising as it blocks fusion with the
+    input or output tensor computations. However, it is still profitable if the
+    input and output would have been realized anyway.
+
+    """
+    inp, _src, _view_ops = node.args
+
+    # Mutating scatter ops unconditionally realize input and output
+    if scatter_always_uses_mutation(node):
+        return True
+
+    if is_node_realized(inp) and is_node_realized(node):  # type: ignore[arg-type]
+        return True
+
+    # If the output is copied back into the input, this forces both to be
+    # realized as the output is a user of the input
+    if inp.op in ("placeholder", "get_attr") and any(  # type: ignore[union-attr]
+        user.target is aten.copy_.default and user.args[0] is inp for user in node.users
+    ):
+        return True
+
+    # Otherwise, assume fusions will make functional variants profitable
+    return False
+
+
+def decompose_generalized_scatter(graph: torch.fx.Graph) -> None:
+    """Replace _generalized_scatter with normal aten ops"""
+    for node in itertools.chain(
+        graph.find_nodes(op="call_function", target=_generalized_scatter),
+        graph.find_nodes(op="call_function", target=_inplace_generalized_scatter),
+    ):
+        use_mutation = (
+            node.target is _inplace_generalized_scatter
+            or scatter_always_uses_mutation(node)
+        )
+
+        with graph.inserting_before(node):
+            if use_mutation:
+                new_node = _decompose_scatter_mutating(graph, node)
+            else:
+                new_node = _decompose_scatter_functional(graph, node)
+
+        node.replace_all_uses_with(new_node)
+        graph.erase_node(node)
+
+
+def canonicalize_view_scatter_ops(graph: torch.fx.Graph) -> None:
+    """
+    This canonicalizes view scatter ops into a generalized form, defined as:
+      def scatter(inp, src, views):
+        tmp = inp.clone()
+        for view in views:
+          tmp = view(tmp)
+        tmp.copy_(src)
+
+    We also fuse consecutive view scatter ops of the form
+        a = scatter(view2(self), src, [view1])
+        b = scatter(self, a, [view2])
+    which can be rewritten as
+        b = scatter(self, src, [view2, view1])
+        a = view2(b)
+
+    This is both more efficient as we only do a single scatter, and also
+    easier to reinplace since there is only one use of `self`
+    """
+
+    node_to_view_base: dict[torch.fx.Node, torch.fx.Node] = {}
+    node_to_view_op: dict[torch.fx.Node, list[ViewOp]] = defaultdict(list)
+
+    def handle_views(node: torch.fx.Node):
+        inp = node.args[0]
+        node_to_view_base[node] = node_to_view_base.get(inp, inp)  # type: ignore[arg-type, assignment]
+        node_to_view_op[node] = [
+            *node_to_view_op[inp],  # type: ignore[index]
+            ViewOp(
+                node.target,  # type: ignore[arg-type]
+                args=node.args[1:],
+                kwargs=node.kwargs,
+            ),
+        ]
+
+    def handle_view_scatter(node: torch.fx.Node):
+        assert len(node.args) >= 2
+        inp, src = node.args[:2]
+
+        assert isinstance(node.target, torch._ops.OpOverload)
+        scatter_view_op = ViewOp(
+            _SCATTER_OP_TO_VIEW[node.target],
+            args=node.args[2:],
+            kwargs=node.kwargs,
+        )
+
+        def can_fuse():
+            if src.target is not _generalized_scatter:  # type: ignore[union-attr]
+                return False
+            src_inp, _src_src, _src_scatter_view_op = src.args  # type: ignore[union-attr]
+
+            inp_base = node_to_view_base.get(inp, inp)  # type: ignore[arg-type]
+            src_base = node_to_view_base.get(src_inp, src_inp)  # type: ignore[arg-type]
+            return inp_base is src_base and node_to_view_op[src_inp] == [  # type: ignore[index]
+                *node_to_view_op[inp],  # type: ignore[index]
+                scatter_view_op,
+            ]
+
+        if not can_fuse():
+            with graph.inserting_before(node):
+                new_node = graph_call_function(
+                    graph,
+                    _generalized_scatter,
+                    inp,
+                    src,
+                    [scatter_view_op],
+                )
+            node.replace_all_uses_with(new_node)
+            graph.erase_node(node)
+            return
+
+        _src_inp, src_src, src_scatter_view_op = src.args  # type: ignore[union-attr]
+        with graph.inserting_before(src):  # type: ignore[arg-type]
+            new_node = graph_call_function(
+                graph,
+                _generalized_scatter,
+                inp,
+                src_src,
+                [scatter_view_op, *src_scatter_view_op],  # type: ignore[misc]
+            )
+            node.replace_all_uses_with(new_node)
+            graph.erase_node(node)
+
+            if src.users:  # type: ignore[union-attr]
+                new_src = graph_call_function(
+                    graph,
+                    _SCATTER_OP_TO_VIEW[node.target],
+                    new_node,
+                    *node.args[2:],
+                    **node.kwargs,
+                )
+
+                handle_views(new_src)
+                src.replace_all_uses_with(new_src)  # type: ignore[union-attr]
+
+            graph.erase_node(src)  # type: ignore[arg-type]
+
+    for node in graph.nodes:
+        if _is_view_op(node.target):
+            handle_views(node)
+        elif node.target in _SCATTER_OP_TO_VIEW:
+            handle_view_scatter(node)
+
+
+inplaceable_ops: dict[Callable[..., Any], InplaceableOp] = {
+    aten.index_put.default: InplaceableOp(aten.index_put_.default, 0),
+    aten._unsafe_index_put.default: InplaceableOp(inductor_prims._unsafe_index_put_, 0),
+    _generalized_scatter: InplaceableOp(
+        _inplace_generalized_scatter,
+        0,
+        extra_check=should_reinplace_scatter,
+    ),
+}
+
+try:
+    c10d_functional = torch.ops._c10d_functional
+    inplaceable_collective_ops: dict[Callable[..., Any], InplaceableOp] = {
+        c10d_functional.all_reduce.default: InplaceableOp(
+            c10d_functional.all_reduce_.default, 0
+        ),
+        c10d_functional.all_reduce_coalesced.default: InplaceableOp(
+            c10d_functional.all_reduce_coalesced_.default, 0
+        ),
+    }
+    inplaceable_ops.update(inplaceable_collective_ops)
+except AttributeError:
+    # _c10d_functional ops are only available when torch
+    # is built with USE_DISTRIBUTED=1.
+    pass
+
+inplaceable_foreach_ops: dict[torch._ops.OpOverload, InplaceableOp] = {}
+for outplace_op, inplace_op in inplaceable_foreach_ops_lowerings.items():
+    inplaceable_foreach_ops[outplace_op] = InplaceableOp(inplace_op, 0)
+
+
+inplaceable_triton_ops = OrderedSet([triton_kernel_wrapper_functional])
+
+
+# Operators that don't depend on the tensor data
+META_ONLY_OPS = OrderedSet(
+    [
+        aten.sym_size.int,
+        aten.sym_stride.int,
+        aten.sym_numel.default,
+        aten.sym_storage_offset.default,
+    ]
+)
+
+
+def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None:
+    """
+    Reinplaces in-placeable operations.
+    If there are no uses of a view of the mutated arg after the current node,
+    it is possible to inplace the op.
+    This above algorithm could be justified by observing side effects. While
+    we traverse the graph in forwards direction, only latter nodes could view
+    side effects of the current node. If the current node is not used later as
+    well as no view of this node is used later in the graph, then it is safe to
+    inplace as there would be no way to observe the side effects.
+    This condition is slightly different for graph inputs where they can only
+    be inplaced if the above condition is true and there's a copy_ in the
+    epilogue that signals that the caller wants to observe the mutation.
+
+    Unlike JIT Inductor, AOTInductor currently unlifts weights and buffers from
+    input args, so instead of checking mutation on placeholder, AOTInductor
+    checks mutation on get_attr. This is subject to change in future.
+    """
+
+    copy_args_to_copy_nodes = {}
+    # maps argument to the first copy_ node that mutates it.
+    copy_nodes = {}
+    mutated_inputs = OrderedSet[Any]()
+    storage_to_nodes = defaultdict(list)
+    node_order: dict[Any, int] = {}
+    for i, node in enumerate(reversed(graph.nodes)):
+        node_order[node] = len(graph.nodes) - i - 1
+        storage_to_nodes[get_node_storage(node)].append(node)
+        if node.target is aten.copy_.default and node.args[0].op in (
+            "placeholder",
+            "get_attr",
+        ):
+            dst = node.args[0]
+            src = node.args[1]
+            # If the target is a getitem and it indexes a possible clone,
+            # then skip over it
+            if src.target is operator.getitem and (
+                (
+                    src.args[0].target == triton_kernel_wrapper_functional
+                    and src.args[0].kwargs["kwargs"][src.args[1]] == node.args[0]
+                )
+                or (src.args[0].target in inplaceable_foreach_ops)
+                or (src.args[0].target is torch.ops.higher_order.auto_functionalized)
+            ):
+                src = src.args[0]
+
+            copy_args_to_copy_nodes[(dst, src)] = node
+            copy_nodes[dst] = node
+
+            mutated_inputs.add(node.args[0])
+
+    def any_use_of_views_after_node(node, shared_view_nodes, *, copy_node, mutated_arg):
+        node_loc = node_order[node]
+        copy_node_loc = node_order[copy_node] if copy_node is not None else None
+
+        def is_meta_only_user(node):
+            if _is_view_op(node.target):
+                return all(is_meta_only_user(u) for u in node.users)
+            return node.target in META_ONLY_OPS
+
+        for view in shared_view_nodes:
+            for user in view.users:
+                user_loc = node_order[user]
+                # Skip all users before node
+                if user_loc <= node_loc:
+                    continue
+                # Ignore uses after the copy_ epilogue node, where the input
+                # has already been mutated anyway
+                if copy_node_loc is not None and copy_node_loc <= user_loc:
+                    continue
+                # Reinplacing does not change shape metadata
+                if is_meta_only_user(user):
+                    continue
+                # If our graph looks like:
+                # foo(mutated_arg)
+                # mutated_arg.copy_(other)
+                # then it's safe for us to reinplace foo because mutated_arg
+                # will get overwritten anyways.
+                if (
+                    user.target is torch.ops.aten.copy_.default
+                    and mutated_arg is user.args[0]
+                ):
+                    continue
+                return True
+        return False
+
+    def can_inplace(node, mutated_arg):
+        # ls should be a list of tensors that all shares the same storage.
+        def _overlap(ls) -> bool:
+            try:
+                return len(compute_overlapping_tensors(ls)) != 0
+            except GuardOnDataDependentSymNode:
+                # If we fail with data dependent error we assume they all overlap.
+                return True
+
+        if isinstance(mutated_arg, (list, tuple)):
+            # TODO Using _overlap here causes a several issues.
+            unique_storages = OrderedSet(get_node_storage(arg) for arg in mutated_arg)
+            if len(unique_storages) != len(mutated_arg):
+                # At least two Tensors in mutated_arg alias each other, so we can't reinplace it.
+                # We can probably do better (that is, reinplace one of them and clone the other)
+                # but that requires more work and mutable List[Tensor] are not that common.
+                return False
+            return all(can_inplace(node, arg) for arg in mutated_arg)
+
+        if get_node_storage(mutated_arg) is None:
+            return False
+
+        shared_view_nodes = storage_to_nodes[get_node_storage(mutated_arg)]
+
+        # Only keep tensor that might overlap with mutated_arg.
+        shared_view_nodes = [
+            v
+            for v in shared_view_nodes
+            if _overlap([mutated_arg.meta["val"], v.meta["val"]])
+        ]
+
+        if mutated_arg.op in ("placeholder", "get_attr"):
+            # Get the first copy_ node that mutates the mutated_arg.
+            copy_node = copy_nodes.get(mutated_arg)
+            if copy_node is None:
+                # There is no copy_ back to the candidate mutated_arg (which is a graph input).
+                # Therefore the semantics of the program are that it does not mutate
+                # mutated_arg, so we cannot re-inplace it.
+                return False
+            if any_use_of_views_after_node(
+                node, shared_view_nodes, copy_node=copy_node, mutated_arg=mutated_arg
+            ):
+                return False
+
+            return True
+        elif any(view.op in ("placeholder", "get_attr") for view in shared_view_nodes):
+            # This should never happen in auto_functionalize_v2 non-inference mode,
+            # since all mutated_arg are bases.
+
+            # If mutated arg is view of any of the inputs of the graph,
+            # do not allow for inplacing.
+            # This would require more sophisticated algorithm to handle
+            return False
+        else:
+            return not any_use_of_views_after_node(
+                node, shared_view_nodes, copy_node=None, mutated_arg=mutated_arg
+            )
+
+    def log_inplace_results(
+        node_name,
+        old_tensors_to_clone,
+        tensors_to_clone,
+        missed_args,
+        missed_nodes,
+        trigger,
+    ):
+        # Total size of possibly_missed_reinplacing_opportunities for tensors with static shapes.
+        missed_bytes = 0
+
+        def bytes(node):
+            t = node.meta.get("val", None)
+            if (
+                t is not None
+                and isinstance(t.element_size(), int)
+                and isinstance(t.numel(), int)
+            ):
+                return t.element_size() * t.numel()
+            else:
+                return 0
+
+        for node in missed_nodes:
+            if isinstance(node, (list, tuple)):
+                for n in node:
+                    missed_bytes += bytes(n)
+            else:
+                missed_bytes += bytes(node)
+
+        log.info(
+            "For node %s, attempted to reinplace %s. We were unable to reinplace %s; "
+            "%s (if non-empty) are possible missed reinplacing opportunities that may be bad for "
+            "memory usage and performance. Total size of missed opportunities with static shapes is"
+            " : %s bytes.",
+            node_name,
+            old_tensors_to_clone,
+            tensors_to_clone,
+            missed_args,
+            missed_bytes,
+        )
+
+        ReinplaceCounters.add_missed_opportunities(trigger, len(missed_args))
+        ReinplaceCounters.add_missed_bytes(trigger, missed_bytes)
+
+    replace_dict: dict[torch.fx.Node, torch.fx.Node] = {}
+
+    def reinplace_and_refine_tensors_to_clone(
+        old_tensors_to_clone, kwargs, node_name, trigger
+    ):
+        tensors_to_clone: list[str] = []
+        storage_of_reinplaced_args = OrderedSet[int | None]()
+
+        # Those used to count possibly_missed_reinplacing_opportunities
+        missed_nodes = []
+        missed_args = []
+
+        # TODO this logic can be made more precise using _overlap
+        def tensor_with_same_storage_already_reinplaced(arg):
+            if isinstance(arg, (list, tuple)):
+                return any(
+                    get_node_storage(a) in storage_of_reinplaced_args for a in arg
+                )
+            return get_node_storage(mutated_arg) in storage_of_reinplaced_args
+
+        for arg in old_tensors_to_clone:
+            assert arg in kwargs
+
+            mutated_arg = kwargs[arg]
+
+            # Let's say we have:
+            # - op(x, y) that mutates both x and y
+            # - new_x, new_y = functional_op(x, y) is the functional variant
+            # If we are presented with functional_op(x, x), we must not reinplace
+            # this into op(x, x), because then it would be writing to the same Tensor.
+            # Instead, it's OK to reinplace one of them and to clone the other:
+            # >>> y = x.clone()
+            # >>> op(x, y)
+            # This also applies if we have views: functional_op(x, x[0])
+            # should not reinplace into op(x, x[0]).
+            should_attempt_reinplace = not tensor_with_same_storage_already_reinplaced(
+                mutated_arg
+            )
+            if should_attempt_reinplace and can_inplace(node, mutated_arg):
+                # In general, we probably do not need those optimizations.
+                copy_node = copy_args_to_copy_nodes.get((mutated_arg, node))
+                if copy_node is not None:
+                    replace_dict[copy_node] = copy_node.args[0]
+                if trigger != ReInplaceTrigger.AUTO_FUNC_V2:
+                    for user in node.users:
+                        # For auto_functionalize_v2, arg is the index of the base, where base at index i corresponds to
+                        # output atindex size(out)+i.
+                        # This used to compare string with integers before for auto_functionalize_v2. Not sure
+                        # if it was needed for inplaceable_triton_ops?
+                        if user.target is operator.getitem and user.args[1] == arg:
+                            replace_dict[user] = mutated_arg
+
+                if isinstance(mutated_arg, (list, tuple)):
+                    for a in mutated_arg:
+                        storage_of_reinplaced_args.add(get_node_storage(a))
+                else:
+                    storage_of_reinplaced_args.add(get_node_storage(mutated_arg))
+            else:
+                if should_attempt_reinplace:
+                    missed_args.append(arg)
+                    missed_nodes.append(mutated_arg)
+
+                tensors_to_clone.append(arg)
+
+        log_inplace_results(
+            node_name,
+            old_tensors_to_clone,
+            tensors_to_clone,
+            missed_args,
+            missed_nodes,
+            trigger,
+        )
+        return tensors_to_clone
+
+    for node in graph.nodes:
+        if (inplaceable_op := inplaceable_ops.get(node.target, None)) is not None:
+            mutated_arg = node.args[inplaceable_op.mutated_arg]
+            if can_inplace(node, mutated_arg) and inplaceable_op.extra_check(node):
+                # TODO(yifu): this doesn't properly remove copy epilogues for
+                # ops that mutate multiple inputs. Need to revise the copy
+                # node tracking logic to support the case.
+                copy_node = copy_args_to_copy_nodes.get((mutated_arg, node))
+                if copy_node is not None:
+                    replace_dict[copy_node] = copy_node.args[0]
+                node.target = inplaceable_op.inplace_op
+        elif node.target is torch.ops.higher_order.auto_functionalized_v2:
+            _mutable_op = node.args[0]
+            kwargs = node.kwargs
+
+            all_bases = kwargs["_all_bases"]
+            bases_to_clone = range(len(all_bases))
+            base_tensors_dct = dict(enumerate(all_bases))
+            new_bases_to_clone: list[int] = reinplace_and_refine_tensors_to_clone(
+                bases_to_clone,
+                base_tensors_dct,
+                node.target,
+                ReInplaceTrigger.AUTO_FUNC_V2,
+            )
+            # Stash the metadata. There is a pass later on where we decompose
+            # auto_functionalized into clones + a mutable op; this metadata
+            # tells the decomp to only clone the following inputs
+            node.meta["only_clone_these_tensors"] = new_bases_to_clone
+        elif node.target is torch.ops.higher_order.auto_functionalized:
+            _mutable_op = node.args[0]
+            from torch._higher_order_ops.auto_functionalize import get_mutable_args
+
+            tensors_to_clone, _ = get_mutable_args(_mutable_op)
+            # Don't try to reinplace Tensor | None args that are None.
+            tensors_to_clone = [
+                t for t in tensors_to_clone if node.kwargs[t] is not None
+            ]
+            tensors_to_clone = reinplace_and_refine_tensors_to_clone(
+                tensors_to_clone,
+                node.kwargs,
+                _mutable_op._name,
+                ReInplaceTrigger.AUTO_FUNC_V1,
+            )
+
+            # Stash the metadata. There is a pass later on where we decompose
+            # auto_functionalized into clones + a mutable op; this metadata
+            # tells the decomp to only clone the following inputs
+            node.meta["only_clone_these_tensors"] = tensors_to_clone
+        elif node.target in inplaceable_triton_ops:
+            kernel_idx = node.kwargs["kernel_idx"]
+            kernel = kernel_side_table.get_kernel(kernel_idx)
+            from triton.runtime.autotuner import Autotuner
+            from triton.runtime.jit import JITFunction
+
+            if isinstance(kernel, JITFunction):
+                kernel_name = kernel.fn.__name__
+            elif isinstance(kernel, Autotuner):
+                if config.is_fbcode():
+                    # Autotuner has different implementations for AMD and NV
+                    if torch.version.hip is None:
+                        kernel_name = kernel.base_fn.__name__
+                    else:
+                        kernel_name = kernel.fn.__name__
+                else:
+                    kernel_name = kernel.base_fn.__name__
+            else:
+                raise AssertionError("Unknown triton kernel type")
+
+            # inplaceable_triton_ops take an additional argument called
+            # tensors_to_clone which contain a list of tensors to clone
+            # This pass iterates over them and sees which ones are safe
+            # to eliminate (i.e. no longer need the clones)
+            tensors_to_clone = reinplace_and_refine_tensors_to_clone(
+                node.kwargs["tensors_to_clone"],
+                node.kwargs["kwargs"],
+                kernel_name,
+                ReInplaceTrigger.TRITON_OPS,
+            )
+
+            kwargs = dict(node.kwargs)
+            kwargs["tensors_to_clone"] = tensors_to_clone
+            node.kwargs = immutable_dict(kwargs)
+            if "eager_input_vals" in node.meta:
+                # We changed the kwargs, so we need to update eager_input_vals
+                # to something sane.
+                args, kwargs = node.meta["eager_input_vals"]
+                new_kwargs = {**kwargs}
+                new_kwargs["tensors_to_clone"] = immutable_list(tensors_to_clone)
+                new_kwargs = immutable_dict(new_kwargs)
+                node.meta["eager_input_vals"] = (args, new_kwargs)
+        elif (
+            inplaceable_op := inplaceable_foreach_ops.get(node.target, None)
+        ) is not None:
+            mutated_args = node.args[inplaceable_op.mutated_arg]
+
+            if not all((arg, node) in copy_args_to_copy_nodes for arg in mutated_args):
+                continue
+
+            if can_inplace(node, mutated_args):
+                for arg in mutated_args:
+                    copy_node = copy_args_to_copy_nodes[(arg, node)]
+                    replace_dict[copy_node] = copy_node.args[0]
+
+                node.target = inplaceable_op.inplace_op
+    for node, replacement in replace_dict.items():
+        while replacement in replace_dict:
+            replacement = replace_dict[replacement]
+        replace_dict[node] = replacement
+
+        node.replace_all_uses_with(replacement)
+        graph.erase_node(node)
+
+
+def reinplace_inplaceable_ops(
+    fake_tensor_updater: torch._inductor.fx_utils.FakeTensorUpdater,
+    graph: torch.fx.Graph,
+) -> None:
+    with enable_python_dispatcher():
+        canonicalize_view_scatter_ops(graph)
+        # canonicalize_view_scatter_ops adds new operations to the graph.
+        # We run fake_tensor_updater to update the alias information.
+        # Correct alias information is required for `reinplace_inplaceable_ops_core`.
+        fake_tensor_updater.incremental_update()
+        reinplace_inplaceable_ops_core(graph)
+        decompose_generalized_scatter(graph)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/replace_random.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/replace_random.py
new file mode 100644
index 0000000000000000000000000000000000000000..150ba5cde4a7cb7c2e3f1a8987082ea11c766c3a
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/replace_random.py
@@ -0,0 +1,150 @@
+# mypy: allow-untyped-defs
+import collections
+import logging
+
+import torch
+from torch.fx.passes.graph_transform_observer import GraphTransformObserver
+from torch.fx.passes.shape_prop import _extract_tensor_metadata
+
+from .. import config, inductor_prims
+from ..pattern_matcher import (
+    CallFunctionVarArgs,
+    Match,
+    PatternMatcherPass,
+    register_graph_pattern,
+)
+from ..virtualized import V
+
+
+log = logging.getLogger(__name__)
+patterns = PatternMatcherPass(subsystem="joint_graph_passes")
+aten = torch.ops.aten
+
+
+def replace_random_passes(gm: torch.fx.GraphModule):
+    """Modify the given FX graph to use backend-native random ops"""
+    if config.fallback_random:
+        return 0
+
+    count = patterns.apply(gm)
+    with GraphTransformObserver(gm, "fuse_seed_creation_pass", "joint_graph_passes"):
+        count += fuse_seed_creation_pass(gm.graph)
+
+    return count
+
+
+def fuse_seed_creation_pass(graph: torch.fx.Graph):
+    """
+    Horizontally fuse all the seed generation on each device
+
+        a = inductor_seed(dev)
+        b = inductor_seed(dev)
+
+    Becomes:
+        seeds = inductor_seeds(2, dev)
+        a = inductor_lookup_seed(seeds, 0)
+        b = inductor_lookup_seed(seeds, 1)
+
+    We do this because seed creation is entirely launch overhead bound.
+    """
+    device_seeds = collections.defaultdict(list)
+    for node in graph.nodes:
+        if CallFunctionVarArgs(inductor_prims.seed).match(node):
+            device_seeds[node.args[0]].append(node)
+
+    if not device_seeds:
+        return 0
+
+    for device, seeds in device_seeds.items():
+        with graph.inserting_before(seeds[0]):
+            combined = graph.call_function(inductor_prims.seeds, (len(seeds), device))
+            with V.fake_mode:
+                combined.meta["val"] = torch.empty(
+                    [len(seeds)], device=device, dtype=torch.int64
+                )
+                combined.meta["tensor_meta"] = _extract_tensor_metadata(
+                    combined.meta["val"]
+                )
+
+        for idx, seed in enumerate(seeds):
+            with graph.inserting_before(seed):
+                new_seed = graph.call_function(
+                    inductor_prims.lookup_seed, (combined, idx)
+                )
+            seed.replace_all_uses_with(new_seed)
+            new_seed.meta.update(seed.meta)
+            graph.erase_node(seed)
+
+    return len(device_seeds)
+
+
+def default_kwargs(device):
+    return {}
+
+
+def get_device(device):
+    if device is not None:
+        return device
+    return torch.empty([]).device  # default device
+
+
+# pyrefly: ignore [bad-argument-type]
+@register_graph_pattern(CallFunctionVarArgs(aten.rand.default), pass_dict=patterns)
+# pyrefly: ignore [bad-argument-type]
+@register_graph_pattern(CallFunctionVarArgs(aten.rand.generator), pass_dict=patterns)
+# pyrefly: ignore [bad-argument-type]
+@register_graph_pattern(CallFunctionVarArgs(aten.randn.default), pass_dict=patterns)
+# pyrefly: ignore [bad-argument-type]
+@register_graph_pattern(CallFunctionVarArgs(aten.randn.generator), pass_dict=patterns)
+def replace_random(
+    match: Match,
+    size,
+    *,
+    generator=None,
+    dtype=None,
+    device=None,
+    layout=None,
+    pin_memory=None,
+):
+    if generator is not None:
+        return
+
+    def replacement(size):
+        result = inductor_prims.random(
+            size, inductor_prims.seed(device), mode, **default_kwargs(device)
+        )
+        if dtype is not None:
+            result = result.to(dtype)
+        return result
+
+    mode = {
+        aten.rand: "rand",
+        aten.randn: "randn",
+    }[
+        match.output_node().target.overloadpacket  # type: ignore[union-attr]
+    ]  # type: ignore[union-attr]
+    device = get_device(device)
+    # pyrefly: ignore [bad-argument-type]
+    match.replace_by_example(replacement, [size])
+
+
+# pyrefly: ignore [bad-argument-type]
+@register_graph_pattern(CallFunctionVarArgs(aten.randint.low), pass_dict=patterns)
+def replace_randint(
+    match: Match,
+    low,
+    high,
+    size,
+    *,
+    dtype=torch.int64,
+    device=None,
+    layout=None,
+    pin_memory=None,
+):
+    def replacement(low, high, size):
+        result = inductor_prims.randint(low, high, size, inductor_prims.seed(device))
+        return result.to(dtype)
+
+    device = get_device(device)
+    # pyrefly: ignore [bad-argument-type]
+    match.replace_by_example(replacement, [low, high, size])
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/split_cat.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/split_cat.py
new file mode 100644
index 0000000000000000000000000000000000000000..6347bda3b525c200ce21cb87ecc2b4a3a685e25c
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/split_cat.py
@@ -0,0 +1,3040 @@
+# mypy: allow-untyped-defs
+import itertools
+import logging
+import operator
+import os
+from collections import defaultdict
+from collections.abc import Callable, Sequence
+from typing import Any, TypeAlias
+
+import torch
+from torch._dynamo.utils import counters
+from torch.fx.experimental.symbolic_shapes import free_symbols, guard_or_false
+from torch.utils._ordered_set import OrderedSet
+
+from ..pattern_matcher import (
+    Arg,
+    CallFunction,
+    CallFunctionVarArgs,
+    CallMethodVarArgs,
+    FailedMatch,
+    get_arg_value,
+    Ignored,
+    KeywordArg,
+    ListOf,
+    Match,
+    MatchContext,
+    MULTIPLE,
+    PatternExpr,
+    PatternMatcherPass,
+    register_graph_pattern,
+    RepeatedExpr,
+)
+from .group_batch_fusion import is_node_meta_valid, POST_GRAD_FUSIONS, PRE_GRAD_FUSIONS
+
+
+log = logging.getLogger(__name__)
+
+_Arguments: TypeAlias = tuple[torch.fx.node.Argument, ...]
+_TransformParam: TypeAlias = tuple[
+    _Arguments | None,
+    _Arguments | None,
+    _Arguments | None,
+    _Arguments | None,
+]
+_Range: TypeAlias = tuple[int, int]
+
+
+PRE_GRAD_PATTERNS: dict[str, PatternMatcherPass] = {}
+POST_GRAD_PATTERNS: dict[str, PatternMatcherPass] = {}
+
+pre_grad_pass_names = [
+    "normalization_pass",
+    "remove_split_with_size_one_pass",
+    "merge_getitem_cat_pass",
+    "merge_stack_tahn_unbind_pass",
+    "merge_splits_pass",
+    "mutate_cat_pass",
+    "split_cat_pass",
+    "unbind_stack_pass",
+    "split_cat_to_slices_pass",
+    "unbind_cat_to_view_pass",
+    "split_stack_to_cats_pass",
+    "unbind_stack_to_slices_pass",
+    "move_reshape_out_of_split_stack_pass",
+    "einsum_to_pointwise_pass",
+]
+
+post_grad_pass_names = [
+    "normalization_aten_pass",
+    "decompose_mm_pass",
+    "unbind_stack_aten_pass",
+    "shape_padding_multiplier",
+    "pad_aten_mm_pass",
+    "split_cat_aten_pass",
+    "select_cat_aten_pass",
+    "move_view_after_cat_aten_pass",
+]
+
+backend = os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_BACKEND", "inductor")
+
+for pass_name in pre_grad_pass_names:
+    # exclude all passes from the group batch fusion
+    # they do not use pattern matcher
+    if pass_name in PRE_GRAD_FUSIONS:
+        continue
+    PRE_GRAD_PATTERNS[pass_name] = PatternMatcherPass(
+        pass_name=pass_name,
+    )
+
+for pass_name in post_grad_pass_names:
+    # exclude all passes from the group batch fusion
+    # they do not use pattern matcher
+    if pass_name in POST_GRAD_FUSIONS:
+        continue
+    POST_GRAD_PATTERNS[pass_name] = PatternMatcherPass(
+        pass_name=pass_name,
+    )
+
+
+def construct_pattern_matcher_pass(pass_name: str):
+    """
+    Return the specific pattern_matcher_pass given the pass name.
+    """
+    if pass_name in PRE_GRAD_PATTERNS:
+        return PRE_GRAD_PATTERNS[pass_name]
+    else:
+        return POST_GRAD_PATTERNS[pass_name]
+
+
+def _get_split_args_default(split_node):
+    input_kwarg = "tensor"
+    split_size_kwarg = "split_size_or_sections"
+    dim_kwarg = "dim"
+    default_dim_value = 0
+    if split_node.op == "call_method":
+        split_size_kwarg = "split_size"
+    return (
+        get_arg_value(split_node, 0, input_kwarg),
+        get_arg_value(split_node, 1, split_size_kwarg),
+        get_arg_value(split_node, 2, dim_kwarg) or default_dim_value,
+    )
+
+
+def _get_dim(node: Any):
+    assert isinstance(node, torch.fx.Node)
+    if "dim" in node.kwargs:
+        assert isinstance(node.kwargs["dim"], int)
+        return node.kwargs["dim"]
+    if node.target is torch.unbind:
+        if len(node.args) == 2:
+            assert isinstance(node.args[-1], int)
+            return node.args[-1]
+        return 0  # defaults to dim=0
+    if node.target is torch.split:
+        if len(node.args) == 3:
+            assert isinstance(node.args[-1], int)
+            return node.args[-1]
+        return 0  # defaults to dim=0
+    raise AssertionError(
+        f"Can't extract `dim` from {node.target} {node.args} {node.kwargs}"
+    )
+
+
+# noqa: W605
+# ############The pattern to be optimized is#########
+#         unbind (dim=0)
+#       /   ...    \
+# getitem      getitem   -> user=1
+#    |            |
+#  split         split  -> dim=1, user=1, split_section_size=1
+#    |            |
+#  getitem       getitem  -> user=1
+#    \           /
+#        cat (dim=1)  -> user=1
+#          |
+
+# ################After transformation#############
+#          unbind (dim=0)
+#        /    ...   \
+#    getitem       getitem  -> user=1
+#       \          /
+#        cat (dim=1)  -> user=1
+#         |
+
+
+def normalize_split_base(
+    match: Match,
+    _get_split_args: Callable[
+        [torch.fx.Node], tuple[torch.fx.Node | None, Any | None, int | None]
+    ],
+):
+    """
+    Normalize split with split_size into split_with_sizes, so that we only deal with one type of split in
+    subsequent optimizations
+    """
+    split_node = match.nodes[0]
+    graph = match.graph
+    split_input, split_size, split_dim = _get_split_args(split_node)
+    if split_input is None or split_dim is None or split_size is None:
+        log.debug("couldn't find split args")
+        return
+    if not is_node_meta_valid(split_node):
+        log.debug("example value absent for node: %s", split_node)
+        return
+    assert isinstance(split_node.meta["example_value"], (list, tuple))
+    split_sections = [t.size()[split_dim] for t in split_node.meta["example_value"]]
+
+    if any(isinstance(section, torch.SymInt) for section in split_sections):
+        # TODO dynamic_shapes with assume_static_by_default=False fails while AOT Autograd tracing.
+        return
+    if split_dim < 0:  # Normalize split dim
+        split_dim += split_input.meta["example_value"].dim()
+
+    new_args = (split_input, split_sections)
+    new_kwargs = {"dim": split_dim}
+    if (
+        split_node.args == new_args
+        and split_node.kwargs == new_kwargs
+        and split_node.op == "call_function"
+    ):
+        return
+
+    with graph.inserting_after(split_node):
+        new_split_node = graph.call_function(
+            torch.split,
+            args=new_args,
+            kwargs=new_kwargs,  # type: ignore[arg-type]
+        )
+    split_node.replace_all_uses_with(new_split_node)
+    new_split_node.meta.update(split_node.meta)
+    graph.erase_node(split_node)
+    counters[backend]["normalization_pass"] += 1
+
+
+@register_graph_pattern(
+    CallFunctionVarArgs(torch.split, users=MULTIPLE),
+    pass_dict=construct_pattern_matcher_pass("normalization_pass"),
+)
+@register_graph_pattern(
+    CallMethodVarArgs("split", users=MULTIPLE),
+    pass_dict=construct_pattern_matcher_pass("normalization_pass"),
+)
+def normalize_split_default(match: Match, *args, **kwargs):
+    return normalize_split_base(match, _get_split_args_default)
+
+
+@register_graph_pattern(
+    CallFunctionVarArgs(torch.split, users=MULTIPLE),
+    pass_dict=construct_pattern_matcher_pass("remove_split_with_size_one_pass"),
+)
+@register_graph_pattern(
+    CallMethodVarArgs("split", users=MULTIPLE),
+    pass_dict=construct_pattern_matcher_pass("remove_split_with_size_one_pass"),
+)
+def remove_split_with_size_one(match: Match, *args, **kwargs):
+    graph = match.graph
+    split_node = match.nodes[0]
+    split_input, split_size, split_dim = _get_split_args_default(split_node)
+    if split_input is None or split_dim is None or split_size is None:
+        log.debug("couldn't find split args")
+        return
+    if not is_node_meta_valid(split_node):
+        log.debug("example value absent for node: %s", split_node)
+        return
+    assert isinstance(split_node.meta["example_value"], (list, tuple))
+    split_sections = [t.size()[split_dim] for t in split_node.meta["example_value"]]
+
+    if any(isinstance(section, torch.SymInt) for section in split_sections):
+        # TODO dynamic_shapes with assume_static_by_default=False fails while AOT Autograd tracing.
+        return
+    # remove the dummy split whose split sections size is one
+    # theoretically nodes with no users should be removed, but we have seen the corner case
+    # thus we add its users check to walk around the StopIteration error.
+    if len(split_sections) == 1 and len(split_node.users.keys()) > 0:
+        # find the grand children of the split_node
+        next_users = find_next_users(split_node)
+        user = next(iter(split_node.users.keys()))
+        # replace the users of grand child node with the input node
+        for next_user in next_users:
+            next_user.replace_input_with(user, split_input)
+        # erase the split node and its child
+        graph.erase_node(user)
+        graph.erase_node(split_node)
+        counters[backend]["remove_split_with_size_one_pass"] += 1
+
+
+@register_graph_pattern(
+    CallFunctionVarArgs(torch.unbind, users=MULTIPLE),
+    pass_dict=construct_pattern_matcher_pass("normalization_pass"),
+)
+@register_graph_pattern(
+    CallMethodVarArgs("unbind", users=MULTIPLE),
+    pass_dict=construct_pattern_matcher_pass("normalization_pass"),
+)
+def normalize_unbind_default(match: Match, *args, **kwargs):
+    node = match.nodes[0]
+    graph = match.graph
+    input = get_arg_value(node, 0, "input")
+    dim = get_arg_value(node, 1, "dim")
+    if dim is None:
+        axis = node.kwargs.get("axis")
+        if axis is not None:
+            dim = axis
+        else:
+            dim = 0
+    if input is None:
+        log.debug("couldn't find unbind args")
+        return
+    if not is_node_meta_valid(input):
+        log.debug("example value absent for node: %s", input)
+        return
+    ndim = input.meta["example_value"].ndim
+    # pyrefly: ignore [unsupported-operation]
+    if dim < 0:  # Normalize unbind dim
+        dim += ndim
+    with graph.inserting_after(node):
+        new_node = graph.call_function(
+            torch.unbind,
+            args=(input,),
+            kwargs={"dim": dim},
+        )
+    node.replace_all_uses_with(new_node)
+    new_node.meta.update(node.meta)
+    graph.erase_node(node)
+    counters[backend]["normalization_pass"] += 1
+
+
+@register_graph_pattern(
+    CallFunctionVarArgs([torch.cat, torch.concat], users=MULTIPLE),
+    pass_dict=construct_pattern_matcher_pass("normalization_pass"),
+)
+def normalize_cat_default(match: Match, *args, **kwargs):
+    cat_node = match.nodes[0]
+    graph = match.graph
+    tensors = get_arg_value(cat_node, 0, "tensors")
+    cat_dim = get_arg_value(cat_node, 1, "dim")
+    if cat_dim is None:
+        cat_axis = cat_node.kwargs.get("axis")
+        if cat_axis is not None:
+            cat_dim = cat_axis
+        else:
+            cat_dim = 0
+    if tensors is None or cat_dim is None:
+        log.debug("couldn't find cat args")
+        return
+    assert isinstance(tensors, (list, tuple))
+    for tensor in itertools.chain([cat_node], tensors):
+        if not is_node_meta_valid(tensor):
+            log.debug("example value absent for node: %s", tensor)
+            return
+
+    ndim = cat_node.meta["example_value"].dim()
+
+    def is_empty_tensor(x):
+        # special case where torch.cat supports cat'ing with an empty tensor
+        x_shape = x.meta["example_value"].shape
+        return len(x_shape) == 1 and guard_or_false(x_shape[0] == 0)
+
+    assert all(
+        ndim == x.meta["example_value"].dim() or is_empty_tensor(x) for x in tensors
+    )
+
+    # pyrefly: ignore [unsupported-operation]
+    if cat_dim < 0:  # Normalize cat dim
+        cat_dim += ndim
+
+    new_args = (tensors,)
+    new_kwargs = {"dim": cat_dim}
+    if (
+        cat_node.args == new_args
+        and cat_node.kwargs == new_kwargs
+        and cat_node.op == "call_function"
+        and cat_node.target is torch.cat
+    ):
+        return
+
+    with graph.inserting_after(cat_node):
+        new_cat_node = graph.call_function(
+            torch.cat,
+            args=new_args,
+            kwargs=new_kwargs,
+        )
+    cat_node.replace_all_uses_with(new_cat_node)
+    new_cat_node.meta.update(cat_node.meta)
+    graph.erase_node(cat_node)
+    counters[backend]["normalization_pass"] += 1
+
+
+@register_graph_pattern(
+    CallFunctionVarArgs(torch.stack, users=MULTIPLE),
+    pass_dict=construct_pattern_matcher_pass("normalization_pass"),
+)
+def normalize_stack_default(match: Match, *args, **kwargs):
+    node = match.nodes[0]
+    graph = match.graph
+    tensors = get_arg_value(node, 0, "tensors")
+    dim = get_arg_value(node, 1, "dim") or 0
+    if tensors is None or dim is None:
+        log.debug("couldn't find stack args")
+        return
+    assert isinstance(tensors, (list, tuple))
+
+    # A bug in pytorch, some nodes miss the example_value metadata
+    for tensor in itertools.chain([node], tensors):
+        if not is_node_meta_valid(tensor):
+            log.debug("example value absent for node: %s", tensor)
+            return
+
+    ndim = node.meta["example_value"].dim()
+    if dim < 0:  # Normalize dim
+        dim += ndim
+
+    with graph.inserting_after(node):
+        new_node = graph.call_function(
+            node.target,  # type: ignore[arg-type]
+            args=(tensors,),
+            kwargs={"dim": dim},
+        )
+    node.replace_all_uses_with(new_node)
+    new_node.meta.update(node.meta)
+    graph.erase_node(node)
+    counters[backend]["normalization_pass"] += 1
+
+
+def find_next_users(split_node: torch.fx.Node) -> list[torch.fx.Node]:
+    next_users = []
+    for getitem_node in split_node.users:
+        for getitem_user in getitem_node.users:
+            if getitem_user not in next_users:
+                next_users.append(getitem_user)
+    return next_users
+
+
+@register_graph_pattern(
+    CallMethodVarArgs("squeeze", users=MULTIPLE),
+    pass_dict=construct_pattern_matcher_pass("normalization_pass"),
+)
+def normalize_squeeze_default(match: Match, *args, **kwargs):
+    squeeze_node = match.nodes[0]
+    squeeze_input = get_arg_value(squeeze_node, 0)
+
+    if "dim" in squeeze_node.kwargs:
+        assert len(squeeze_node.args) == 1
+        dim = squeeze_node.kwargs["dim"]
+    elif len(squeeze_node.args) == 1:
+        # squeeze(Tensor)
+        dim = None
+    elif len(squeeze_node.args) == 2:
+        # squeeze(Tensor self, int dim)
+        # squeeze(Tensor self, int[] dim)
+        dim = squeeze_node.args[1]
+    else:
+        # squeeze(Tensor self, int[] dim) (called with varargs)
+        dim = squeeze_node.args[1:]
+
+    if isinstance(dim, Sequence) and len(dim) == 1:
+        dim = dim[0]
+
+    with match.graph.inserting_after(squeeze_node):
+        if dim is None:
+            new_squeeze_node = match.graph.call_function(
+                torch.squeeze, args=(squeeze_input,)
+            )
+        else:
+            new_squeeze_node = match.graph.call_function(
+                torch.squeeze, args=(squeeze_input,), kwargs={"dim": dim}
+            )
+    squeeze_node.replace_all_uses_with(new_squeeze_node)
+    new_squeeze_node.meta.update(squeeze_node.meta)
+    match.graph.erase_node(squeeze_node)
+
+
+@register_graph_pattern(
+    CallMethodVarArgs("reshape", users=MULTIPLE),
+    pass_dict=construct_pattern_matcher_pass("normalization_pass"),
+)
+def normalize_reshape_default(match: Match, *args, **kwargs):
+    reshape_node = match.nodes[0]
+    if not is_node_meta_valid(reshape_node):
+        log.debug("example value absent for node: %s", reshape_node)
+        return
+    reshape_input = get_arg_value(reshape_node, 0)
+
+    if free_symbols(reshape_node.meta["example_value"].shape):
+        log.debug("dynamic shape not supported: %s", reshape_node)
+        return
+
+    with match.graph.inserting_after(reshape_node):
+        new_reshape_node = match.graph.call_function(
+            torch.reshape,
+            args=(reshape_input, tuple(reshape_node.meta["example_value"].shape)),
+        )
+    reshape_node.replace_all_uses_with(new_reshape_node)
+    new_reshape_node.meta.update(reshape_node.meta)
+    match.graph.erase_node(reshape_node)
+
+
+@register_graph_pattern(
+    CallMethodVarArgs("clamp", users=MULTIPLE),
+    pass_dict=construct_pattern_matcher_pass("normalization_pass"),
+)
+@register_graph_pattern(
+    CallFunctionVarArgs(torch.clamp, users=MULTIPLE),
+    pass_dict=construct_pattern_matcher_pass("normalization_pass"),
+)
+def normalize_clamp_default(match: Match, *args, **kwargs):
+    clamp_node = match.nodes[0]
+    if not is_node_meta_valid(clamp_node):
+        log.debug("example value absent for node: %s", clamp_node)
+        return
+
+    if free_symbols(clamp_node.meta["example_value"].shape):
+        log.debug("dynamic shape not supported: %s", clamp_node)
+        return
+    if len(clamp_node.args) > 1:
+        args = (get_arg_value(clamp_node, 0),)
+        kwargs = {
+            "min": get_arg_value(clamp_node, 1, kwarg_name="min"),
+            "max": get_arg_value(clamp_node, 2, kwarg_name="max"),
+        }
+    else:
+        args = clamp_node.args
+        kwargs = clamp_node.kwargs
+    with match.graph.inserting_after(clamp_node):
+        new_clamp_node = match.graph.call_function(
+            torch.clamp,
+            args=args,
+            kwargs=kwargs,
+        )
+    clamp_node.replace_all_uses_with(new_clamp_node)
+    new_clamp_node.meta.update(clamp_node.meta)
+    match.graph.erase_node(clamp_node)
+
+
+@register_graph_pattern(
+    CallMethodVarArgs("detach", users=MULTIPLE),
+    pass_dict=construct_pattern_matcher_pass("normalization_pass"),
+)
+def normalize_detach_default(match: Match, *args, **kwargs):
+    detach_node = match.nodes[0]
+    if not is_node_meta_valid(detach_node):
+        log.debug("example value absent for node: %s", detach_node)
+        return
+
+    if free_symbols(detach_node.meta["example_value"].shape):
+        log.debug("dynamic shape not supported: %s", detach_node)
+        return
+
+    with match.graph.inserting_after(detach_node):
+        new_detach_node = match.graph.call_function(
+            torch.detach,
+            args=detach_node.args,
+        )
+    detach_node.replace_all_uses_with(new_detach_node)
+    new_detach_node.meta.update(detach_node.meta)
+    match.graph.erase_node(detach_node)
+
+
+class TorchSplit(CallFunction):
+    """
+    Matches a call to torch.split if it is in a normalized form. Ensures that all users of
+    splits are unique getitems.
+    """
+
+    def __init__(self, arg, sizes, func=torch.split) -> None:
+        # using KeywordArg("dim") for `dim` checks they all match
+        super().__init__(func, arg, sizes, _users=MULTIPLE, dim=KeywordArg("dim"))
+
+    def _match(self, node: torch.fx.Node, ctx: MatchContext):
+        m = super()._match(node, ctx)
+        if not m:
+            return m
+        split_sections = node.args[1]
+        if not isinstance(split_sections, (list, tuple)):
+            return FailedMatch("split not normalized")
+        # check users are all unique getitems
+        seen_idxs = OrderedSet[int]()
+        for user in node.users:
+            if not CallFunction(operator.getitem, Arg(), Arg()).match(user):
+                # This should ideally never happen. Split user should always be a getitem
+                return FailedMatch(f"user of split not a getitem: {user}")
+            if not isinstance(user.args[1], int):
+                return FailedMatch("only integer getitems are handled")
+            if user.args[1] in seen_idxs:
+                return FailedMatch(f"duplicate getitem {user.args[1]}")
+            if user.args[-1] < 0:  # type: ignore[operator]
+                # This shouldn't ideally happen as dynamo normalizes indexes to positive
+                return FailedMatch("negative index")
+            seen_idxs.add(user.args[1])
+        return m
+
+
+@register_graph_pattern(
+    TorchSplit(
+        CallFunction(
+            operator.getitem,
+            TorchSplit(
+                KeywordArg("first_split_input"),
+                KeywordArg("first_split_sections"),
+            ),
+            Ignored(),
+        ),
+        KeywordArg("next_split_sections"),
+    ),
+    pass_dict=construct_pattern_matcher_pass("merge_splits_pass"),
+)
+def merge_splits(
+    match: Match,
+    first_split_input: torch.fx.Node,
+    first_split_sections: list[int],
+    next_split_sections: list[int],
+    # Note: dim is implicitly passed by TorchSplit, as it internally uses a pattern with dim
+    dim: int,
+):
+    node = match.output_node()
+    # it is possible that the split has no users,
+    # we check the corner case and skip the pattern
+    if len(node.users.keys()) == 0:
+        return
+    graph = match.graph
+    first_split = node.args[0].args[0]  # type: ignore[union-attr]
+    next_split_index = node.args[0].args[1]  # type: ignore[union-attr]
+
+    new_split_sections = list(first_split_sections)
+    new_split_sections[next_split_index : next_split_index + 1] = next_split_sections  # type: ignore[operator, misc]
+
+    first_split_dim = _get_dim(first_split)
+
+    to_remove = []
+
+    with graph.inserting_before(first_split):  # type: ignore[arg-type]
+        # Add the new split node
+        new_split = graph.call_function(
+            torch.split,
+            args=(first_split_input, new_split_sections),
+            kwargs={"dim": first_split_dim},
+        )
+        if is_node_meta_valid(first_split_input):
+            new_split.meta["example_value"] = torch.split(
+                first_split_input.meta["example_value"],
+                new_split_sections,
+                dim=first_split_dim,
+            )
+        first_split_num_to_user = {
+            user.args[1]: user
+            for user in first_split.users  # type: ignore[union-attr]
+        }
+
+        new_split_num = 0
+        for split_num in range(len(first_split_sections)):
+            if split_num not in first_split_num_to_user:
+                new_split_num += 1
+                continue
+            old_getitem = first_split_num_to_user[split_num]
+            if split_num != next_split_index:
+                old_getitem.update_arg(0, new_split)
+                old_getitem.update_arg(1, new_split_num)
+                new_split_num += 1
+            else:
+                next_split_num_to_user = {user.args[1]: user for user in node.users}
+                # It is not necessary all getitems from the split node are used.
+                for next_split_num in range(len(next_split_sections)):
+                    with graph.inserting_after(new_split):
+                        new_getitem = graph.call_function(
+                            operator.getitem, args=(new_split, new_split_num)
+                        )
+                    new_split_num += 1
+                    if next_split_num not in next_split_num_to_user:
+                        continue
+                    next_getitem = next_split_num_to_user[next_split_num]
+                    new_getitem.meta.update(next_getitem.meta)
+                    next_getitem.replace_all_uses_with(new_getitem)
+                    to_remove.append(next_getitem)
+                to_remove.append(node)
+                to_remove.append(old_getitem)
+
+        to_remove.append(first_split)  # type: ignore[arg-type]
+    for node in to_remove:
+        graph.erase_node(node)
+
+    counters[backend]["merge_splits_pass"] += 1
+
+
+class SplitCatSimplifier:
+    """
+    Helper class to simplify split-cat pattern. In simple cases, both split and cat node can be removed in a "split->cat"
+    pattern. However, there are various cases where they can't and we need to simplify split/ add transforms before cat.
+    Some such cases are:
+        1. Final node has additional args (not coming from the initial split)
+        2. Shuffling of args between split/cat
+        3. Some final nodes are non-(cat/stack)
+        4. Split-dim != cat-dim (but equal split)
+
+    Note that any combination of the above cases can happen.
+
+    To deal with 1, 2, & 3 - we iterate over all users of split. And figure out common "ranges" that can be merged.
+    Then, we simplify the split accordingly. In the best case, split can be entirely removed.
+
+    To deal with 4, we add some transformations (unflatten + movedim) (See `get_transform_params`).
+
+    Finally, depending on final node being cat or stack, unsqueeze/flatten needs to be added.
+
+    """
+
+    def simplify(
+        self,
+        graph: torch.fx.Graph,
+        split_node: torch.fx.Node,
+        split_sections: list[int],
+    ):
+        # Find the next users (i.e. users after the getitem)
+        next_users = find_next_users(split_node)
+        # Gather inputs of the next users. When inputs come from `split_node`, they are instead represented by
+        # a tuple indicating the split ranges. See `get_user_input_list` for more details
+        user_inputs_list = self.get_user_input_list(split_node, next_users)
+        # Simplify the split_sections based on user_inputs_list. In simpler cases, len(simplified_split_ranges) == 1 and
+        # we can simply replace the split node. Otherwise, we simplify it.
+        simplified_split_ranges = self.get_simplified_split_ranges(
+            split_sections, next_users, user_inputs_list
+        )
+        if not simplified_split_ranges:  # Simplification not possible
+            return
+        transform_params_list = self.get_transform_params(
+            split_node, next_users, user_inputs_list
+        )
+        if not transform_params_list:
+            return
+
+        # Start actual replacement
+        user_inputs_list_new = self.replace_split(
+            graph, split_node, split_sections, user_inputs_list, simplified_split_ranges
+        )
+        self.replace_cat(
+            graph,
+            split_node,
+            next_users,
+            user_inputs_list_new,
+            transform_params_list,  # type: ignore[arg-type]
+        )
+        self.erase_old_nodes(graph, split_node, next_users)  # type: ignore[arg-type]
+        counters[backend]["unbind_stack_pass"] += 1
+
+    def get_user_input_list(
+        self, split_node: torch.fx.Node, next_users: list[torch.fx.Node]
+    ) -> list[list[torch.fx.Node | _Range]]:
+        """
+        Returns list of inputs to the following user nodes, in order. The outer list represents the user node. The inner
+        list represents the inputs to that particular node. This list can either contain
+          - a tuple representing the ranges of get_items that should go into the cat (closed interval)
+          - torch.fx.Node representing "other" inputs (which are not coming from our split)
+        """
+        user_inputs_list: list[list[torch.fx.Node | _Range]] = []
+        for user in next_users:
+            if user.target in (torch.cat, torch.stack):
+                user_inputs_list.append(self.get_merged_user_inputs(split_node, user))
+            else:
+                user_inputs_list.append(self.get_non_cat_node_input(split_node, user))  # type: ignore[arg-type]
+        return user_inputs_list
+
+    def get_merged_user_inputs(
+        self, split_node: torch.fx.Node, cat_node: torch.fx.Node
+    ) -> list[torch.fx.Node | _Range]:
+        user_inputs = get_arg_value(cat_node, 0, "tensors")
+        simplified_user_inputs = []
+        split_users = OrderedSet(split_node.users.keys())
+        for user_input in user_inputs:
+            if user_input not in split_users:
+                simplified_user_inputs.append(user_input)
+            else:
+                # Add which "getitem" cat depends on
+                simplified_user_inputs.append(user_input.args[1])
+        return self.merge_consecutive_inputs(simplified_user_inputs)
+
+    def get_non_cat_node_input(
+        self, split_node: torch.fx.Node, node: torch.fx.Node
+    ) -> list[_Range]:
+        """
+        Get input for a non cat node in the same format as `get_merged_user_inputs`
+        """
+        node_input = []
+        split_users = OrderedSet(split_node.users.keys())
+        for node_arg in node.all_input_nodes:
+            if node_arg in split_users:
+                getitem_num = get_arg_value(node_arg, 1)
+                node_input.append((getitem_num, getitem_num))
+        return node_input
+
+    def merge_consecutive_inputs(
+        self, inputs: list[torch.fx.Node | int]
+    ) -> list[torch.fx.Node | _Range]:
+        """
+        Merge consecutive inputs going into a user node.
+
+        For e.g.
+        [arg0, 0, 1, 2, arg1] -> [arg0, (0, 2), arg1]
+        """
+        merged_ranges = []
+        cur_range = None
+        for input_ in inputs:
+            if isinstance(input_, int):
+                if not cur_range:
+                    cur_range = [input_, input_]
+                elif input_ == cur_range[1] + 1:
+                    cur_range[1] += 1
+                else:
+                    merged_ranges.append(tuple(cur_range))
+                    cur_range = [input_, input_]
+            else:
+                if cur_range:
+                    merged_ranges.append(tuple(cur_range))
+                    cur_range = None
+                merged_ranges.append(input_)  # type: ignore[arg-type]
+        if cur_range:
+            merged_ranges.append(tuple(cur_range))
+        return merged_ranges  # type: ignore[return-value]
+
+    def get_simplified_split_ranges(
+        self,
+        split_sections,
+        next_users,
+        user_inputs_list: list[list[torch.fx.Node | _Range]],
+    ) -> list[_Range] | None:
+        ranges = OrderedSet[Any]()
+        for user_inputs in user_inputs_list:
+            ranges.update(u for u in user_inputs if isinstance(u, tuple))
+
+        cumulative_sizes = [0] + torch.cumsum(torch.tensor(split_sections), 0).tolist()
+        split_ranges = sorted(
+            [(cumulative_sizes[r[0]], cumulative_sizes[r[1] + 1]) for r in ranges]
+        )
+
+        if not self.has_non_overlapping_ranges(
+            split_ranges,
+        ):  # This need not be a strict condition
+            # However, we keep it now for simplicity.
+            return None
+        split_ranges = self.fill_gaps(split_ranges, 0, cumulative_sizes[-1])
+        if len(split_sections) == len(split_ranges):  # Simplification not possible
+            return None
+        counters[backend]["scmerge_split_sections_removed"] = len(split_sections) - len(
+            split_ranges
+        )
+        return split_ranges
+
+    def has_non_overlapping_ranges(self, ranges: list[_Range]) -> bool:
+        for range_, next_range in itertools.pairwise(ranges):
+            if range_[1] > next_range[0]:
+                return False
+        return True
+
+    def fill_gaps(self, ranges: list[_Range], min_: int, max_: int) -> list[_Range]:
+        cur = min_
+        filled_ranges = []
+        for a, b in ranges:
+            if cur < a:
+                filled_ranges.append((cur, a))
+            filled_ranges.append((a, b))
+            cur = b
+        if filled_ranges[-1][1] < max_:
+            filled_ranges.append((filled_ranges[-1][1], max_))
+        return filled_ranges
+
+    def get_transform_params(
+        self,
+        split_node: torch.fx.Node,
+        next_users: list[torch.fx.Node],
+        user_inputs_list: list[list[torch.fx.Node | _Range]],
+    ) -> list[list[_TransformParam]] | None:
+        """
+        Figure out what transforms are needed for each input to each cat node.
+
+        We replace a split node with an unflatten followed by a movedim
+        """
+        split_dim = _get_dim(split_node)
+        split_sections = split_node.args[1]
+        transform_params_list: list[list[_TransformParam]] = []
+
+        for user_node, user_inputs in zip(next_users, user_inputs_list):
+            if user_node.target not in (torch.cat, torch.stack):
+                transform_params_list.append([])
+                continue
+
+            cat_dim = get_arg_value(user_node, 1, "dim")
+            transform_params: list[_TransformParam] = []
+            for user_input in user_inputs:
+                if split_dim == cat_dim and user_node.target is torch.cat:
+                    # No transform needed
+                    transform_params.append((None, None, None, None))
+                elif isinstance(user_input, tuple):  # Split being simplified
+                    # Verify equal split
+                    subset_split_sections = split_sections[  # type: ignore[index]
+                        user_input[0] : user_input[1]
+                        + 1  # type: ignore[index]
+                    ]
+                    # All sections should be equal
+                    if len(OrderedSet(subset_split_sections)) != 1:  # type: ignore[arg-type]
+                        return None
+
+                    num_splits = len(subset_split_sections)  # type: ignore[arg-type]
+                    unflatten_params = (split_dim, (num_splits, -1))
+                    movedim_params = (
+                        (split_dim, cat_dim) if split_dim != cat_dim else None
+                    )
+                    transform_params.append(
+                        (unflatten_params, movedim_params, None, None)
+                    )
+                elif (
+                    user_node.target is torch.stack or split_dim != cat_dim
+                ):  # We need to unsqueeze inputs not coming through split
+                    transform_params.append((None, None, (cat_dim,), None))
+                else:  # Non-split inputs
+                    transform_params.append((None, None, None, None))
+            transform_params_list.append(transform_params)
+        return transform_params_list
+
+    def replace_split(
+        self,
+        graph: torch.fx.Graph,
+        split_node: torch.fx.Node,
+        split_sections: list[int],
+        user_inputs_list: list[list[torch.fx.Node | _Range]],
+        split_ranges: list[_Range],
+    ) -> list[list[torch.fx.Node]]:
+        """
+        Replace the split node. It can either remove the split node if len(split_ranges) == 1, or simplify it
+        into a split with lesser sections if len(split_ranges) > 1.
+
+        Returns the new `user_inputs_list`, with tuples replaced with new getitems from the newer split node.
+        """
+        split_input = split_node.args[0]
+        split_dim = _get_dim(split_node)
+        if len(split_ranges) == 1:  # We can completely eliminate the split node
+            split_items = [split_input]
+        else:
+            with graph.inserting_after(split_node):
+                new_split = graph.call_function(
+                    torch.split,
+                    args=(
+                        split_input,
+                        [r[1] - r[0] for r in split_ranges],
+                    ),
+                    kwargs={"dim": split_dim},
+                )
+                if is_node_meta_valid(split_input):  # type: ignore[arg-type, union-attr]
+                    new_split.meta["example_value"] = torch.split(
+                        split_input.meta["example_value"],  # type: ignore[union-attr]
+                        [r[1] - r[0] for r in split_ranges],
+                        dim=split_dim,
+                    )
+                counters[backend]["scmerge_split_added"] += 1
+            split_items = []
+            with graph.inserting_after(new_split):
+                for i in range(len(split_ranges)):
+                    getitem = graph.call_function(operator.getitem, args=(new_split, i))
+                    if is_node_meta_valid(new_split):
+                        getitem.meta["example_value"] = new_split.meta["example_value"][
+                            i
+                        ]
+                        split_items.append(getitem)
+        # Now assign the right getitem to the right input
+        cumulative_sizes = [0] + torch.cumsum(torch.tensor(split_sections), 0).tolist()
+        new_user_inputs_list = []
+        for user_inputs in user_inputs_list:
+            new_user_inputs = []
+            for user_input in user_inputs:
+                if isinstance(user_input, tuple):
+                    # Find the correct new getitem (present in split_items)
+                    new_user_inputs.append(
+                        # pyrefly: ignore [bad-argument-type]
+                        split_items[
+                            split_ranges.index(
+                                (
+                                    cumulative_sizes[user_input[0]],
+                                    cumulative_sizes[user_input[1] + 1],
+                                )
+                            )
+                        ]
+                    )
+                else:
+                    new_user_inputs.append(user_input)
+            new_user_inputs_list.append(new_user_inputs)
+        return new_user_inputs_list  # type: ignore[return-value]
+
+    def replace_cat(
+        self,
+        graph: torch.fx.Graph,
+        split_node: torch.fx.Node,
+        next_users: list[torch.fx.Node],
+        user_inputs_list_new,
+        transform_params_list: list[list[_TransformParam]],
+    ):
+        split_dim = _get_dim(split_node)
+        split_users = split_node.users.keys()
+        new_cats = []
+        for user_node, user_inputs_new, transform_params in zip(
+            next_users, user_inputs_list_new, transform_params_list
+        ):
+            if user_node.target not in (torch.cat, torch.stack):
+                # Change the args and kwargs of non-cat/stack nodes. Replace old getitems (belonging to
+                # the original split node) with the newer getitems
+                next_cat_input = 0
+                for input_node in user_node.all_input_nodes:
+                    if input_node in split_users:
+                        user_node.replace_input_with(
+                            input_node, user_inputs_new[next_cat_input]
+                        )
+                        next_cat_input += 1
+                continue
+
+            # Handle cat/stack user nodes
+            cat_dim = get_arg_value(user_node, 1, "dim")
+            user_inputs_new_transformed, user_inputs_new_transformed_meta = [], []
+            # For `unsqueeze` transform, we will combine consecutive inputs with the same unsqueeze params, and stack them
+            to_stack, to_stack_meta = [], []
+            stack_dim = None
+            with graph.inserting_before(user_node):
+                for user_input_new, transform_param in zip(
+                    user_inputs_new, transform_params
+                ):
+                    # pyrefly: ignore [bad-argument-type]
+                    if not is_node_meta_valid(user_input_new):
+                        log.debug("example value absent for node: %s", user_input_new)
+                        return
+                    # Apply transforms
+                    (
+                        unflatten_params,
+                        movedim_params,
+                        unsqueeze_params,
+                        flatten_params,
+                    ) = transform_param
+                    if unsqueeze_params and (
+                        stack_dim is None or stack_dim == unsqueeze_params[0]
+                    ):
+                        to_stack.append(user_input_new)
+                        # pyrefly: ignore [missing-attribute]
+                        to_stack_meta.append(user_input_new.meta["example_value"])
+                        stack_dim = unsqueeze_params[0]
+                        continue
+                    elif to_stack:
+                        stacked_input = graph.call_function(
+                            torch.stack, args=(to_stack,), kwargs={"dim": stack_dim}
+                        )
+                        stacked_input.meta["example_value"] = torch.stack(  # type: ignore[arg-type]
+                            to_stack_meta,
+                            dim=stack_dim,  # type: ignore[arg-type]
+                        )
+                        to_stack, to_stack_meta = [], []
+                        stack_dim = None
+                        user_inputs_new_transformed.append(stacked_input)
+                        user_inputs_new_transformed_meta.append(
+                            stacked_input.meta["example_value"]
+                        )
+                        if unsqueeze_params:
+                            to_stack.append(user_input_new)
+                            stack_dim = unsqueeze_params[0]
+                            # pyrefly: ignore [missing-attribute]
+                            to_stack_meta.append(user_input_new.meta["example_value"])
+                            continue
+
+                    if unflatten_params:
+                        # pyrefly: ignore [missing-attribute]
+                        user_input_new_meta = user_input_new.meta["example_value"]
+                        user_input_new = graph.call_function(
+                            torch.unflatten, args=(user_input_new, *unflatten_params)
+                        )
+                        user_input_new.meta["example_value"] = torch.unflatten(  # type: ignore[arg-type]
+                            user_input_new_meta,  # type: ignore[arg-type]
+                            *unflatten_params,  # type: ignore[arg-type]
+                        )
+                    if movedim_params:
+                        # pyrefly: ignore [missing-attribute]
+                        user_input_new_meta = user_input_new.meta["example_value"]
+                        user_input_new = graph.call_function(
+                            torch.movedim, args=(user_input_new, *movedim_params)
+                        )
+                        user_input_new.meta["example_value"] = torch.movedim(  # type: ignore[arg-type]
+                            user_input_new_meta,  # type: ignore[arg-type]
+                            *movedim_params,  # type: ignore[arg-type]
+                        )
+                    if flatten_params:
+                        # pyrefly: ignore [missing-attribute]
+                        user_input_new_meta = user_input_new.meta["example_value"]
+                        user_input_new = graph.call_function(
+                            torch.flatten, args=(user_input_new, *flatten_params)
+                        )
+                        user_input_new.meta["example_value"] = torch.flatten(  # type: ignore[arg-type]
+                            user_input_new_meta,
+                            *flatten_params,  # type: ignore[arg-type]
+                        )
+                    user_inputs_new_transformed.append(user_input_new)
+                    user_inputs_new_transformed_meta.append(
+                        # pyrefly: ignore [missing-attribute]
+                        user_input_new.meta["example_value"]
+                    )
+                if to_stack:
+                    stacked_input = graph.call_function(
+                        torch.stack, args=(to_stack,), kwargs={"dim": stack_dim}
+                    )
+                    stacked_input.meta["example_value"] = torch.stack(  # type: ignore[arg-type]
+                        to_stack_meta,
+                        dim=stack_dim,  # type: ignore[arg-type]
+                    )
+                    user_inputs_new_transformed.append(stacked_input)
+                    user_inputs_new_transformed_meta.append(
+                        stacked_input.meta["example_value"]
+                    )
+
+            with graph.inserting_after(user_node):
+                if len(user_inputs_new_transformed) > 1:
+                    new_cat_node = graph.call_function(
+                        torch.cat,
+                        args=(user_inputs_new_transformed,),
+                        kwargs={"dim": cat_dim},
+                    )
+                    new_cat_node.meta["example_value"] = torch.cat(
+                        user_inputs_new_transformed_meta,
+                        dim=cat_dim,
+                    )
+                    counters[backend]["scmerge_cat_added"] += 1
+                else:
+                    new_cat_node = user_inputs_new_transformed[-1]
+                    new_cat_node.meta["example_value"] = (
+                        user_inputs_new_transformed_meta[-1]
+                    )
+
+            if (
+                user_node.target is torch.cat
+                and split_dim != cat_dim
+                and split_node.target is torch.split
+            ):
+                with graph.inserting_after(new_cat_node):
+                    new_cat_node_meta = new_cat_node.meta["example_value"]
+                    new_cat_node = graph.call_function(
+                        torch.flatten, args=(new_cat_node, cat_dim, cat_dim + 1)
+                    )
+                    new_cat_node.meta["example_value"] = torch.flatten(
+                        new_cat_node_meta,
+                        cat_dim,
+                        cat_dim + 1,
+                    )
+            user_node.replace_all_uses_with(new_cat_node)
+            new_cats.append(new_cat_node)
+
+    def erase_old_nodes(
+        self,
+        graph: torch.fx.Graph,
+        split_node: torch.fx.Node,
+        next_users: list[torch.fx.Node],
+    ):
+        to_remove = [split_node]
+        counters[backend]["scmerge_split_removed"] += 1
+        to_remove.extend(split_node.users.keys())
+        for next_user in next_users:
+            if next_user.target not in (torch.cat, torch.stack):
+                continue
+            counters[backend]["scmerge_cat_removed"] += 1
+            to_remove.append(next_user)
+        for node in reversed(to_remove):
+            if len(node.users.keys()) == 0:
+                graph.erase_node(node)
+
+
+class UnbindCatRemover(SplitCatSimplifier):
+    """
+    Helper class to merge Unbind->Cat/Stack. Many of the cases are similar to SplitCatSimplifier.
+
+    Unbind can't be simplified like splits. So, we can only remove the unbind node. Other than this,
+    other cases like multiple users, additional args, dim mismatch are similar to `SplitCatSimplifier`,
+    hence we extend that class.
+    """
+
+    def remove_unbind(
+        self,
+        graph: torch.fx.Graph,
+        unbind_node: torch.fx.Node,
+    ):
+        if not is_node_meta_valid(unbind_node):
+            return
+        # we need to check if the getitem indices from unbind are consecutive and all go to the same cat node
+        # before we do the unbind remove, otherwise it will hit the error when we unbind part of them
+        getitem_indices = [getitem_node.args[1] for getitem_node in unbind_node.users]
+        if not is_sorted_and_consecutive(getitem_indices) or len(  # type: ignore[arg-type]
+            getitem_indices
+        ) != len(unbind_node.meta["example_value"]):
+            return
+        num_unbind = len(getitem_indices)
+        split_sections = [1 for _ in range(num_unbind)]  # type: ignore[operator, arg-type]
+
+        super().simplify(graph, unbind_node, split_sections)
+
+    def get_simplified_split_ranges(
+        self,
+        split_sections: list[int],
+        next_users: list[torch.fx.Node],
+        user_inputs_list: list[list[torch.fx.Node | _Range]],
+    ) -> list[_Range] | None:
+        simplified_split_ranges = super().get_simplified_split_ranges(
+            split_sections, next_users, user_inputs_list
+        )
+        if not simplified_split_ranges or len(simplified_split_ranges) != 1:
+            return None
+        return simplified_split_ranges
+
+    def get_transform_params(
+        self,
+        split_node: torch.fx.Node,
+        next_users: list[torch.fx.Node],
+        user_inputs_list: list[list[torch.fx.Node | _Range]],
+    ) -> list[list[_TransformParam]] | None:
+        """
+        Figure out what transforms are needed for each input to each cat node.
+
+        Here is the rough transforms we apply:
+
+        x -> unbind -> stack => x -> movedim
+
+        x -> unbind -> cat => x -> movedim -> flatten
+
+        When cat/stack nodes have additional args:
+
+             addn ---|              addn -> unsqueeze ---|
+        x -> unbind -> stack  =>           x -> movedim  -> cat
+
+             addn ---|                            addn ---|
+        x -> unbind -> cat  =>   x -> movedim -> flatten  -> cat
+
+        (Note application of these depends on the dims as well)
+
+
+        """
+        split_dim = _get_dim(split_node)
+        transform_params_list: list[list[_TransformParam]] = []
+        for user_node, user_inputs in zip(next_users, user_inputs_list):
+            cat_dim = get_arg_value(user_node, 1, "dim") or 0
+            transform_params: list[_TransformParam] = []
+            for user_input in user_inputs:
+                if isinstance(user_input, tuple):
+                    # User input is coming from unbind
+                    movedim_params = (
+                        (split_dim, cat_dim) if split_dim != cat_dim else None
+                    )
+                    flatten_params = None
+                    if user_node.target is torch.cat:
+                        flatten_params = (cat_dim, cat_dim + 1)
+                    transform_params.append(
+                        (None, movedim_params, None, flatten_params)
+                    )
+                elif (
+                    user_node.target is torch.stack
+                ):  # We need to unsqueeze inputs not coming through unbind into cat
+                    transform_params.append((None, None, (cat_dim,), None))
+                else:  # Non-unbind inputs
+                    transform_params.append((None, None, None, None))
+            transform_params_list.append(transform_params)
+        return transform_params_list
+
+
+class GetItem(CallFunction):
+    def __init__(self, arg, index, _users=1) -> None:
+        super().__init__(operator.getitem, arg, index, _users=_users)
+
+    def find_anchor_nodes(self, ctx: MatchContext, searched: OrderedSet[torch.fx.Node]):
+        # We generally match GetItem with arg being an Arg(). So, we never return the anchor
+        # nodes as the stored node in ctx.pattern_to_node is returned. Here we override find_anchor_nodes
+        # to not use ctx.pattern_to_node
+        for pattern in self.flat_args_kwargs[0]:
+            if isinstance(pattern, PatternExpr):
+                for other_node in pattern.find_anchor_nodes(ctx, searched):
+                    if not isinstance(other_node, torch.fx.Node):
+                        continue
+                    for node in other_node.users:
+                        if node not in searched:
+                            if self._match_fns(node):
+                                yield node
+                                searched.add(node)
+
+
+@register_graph_pattern(
+    RepeatedExpr(
+        CallFunction(
+            torch.squeeze,
+            GetItem(
+                TorchSplit(
+                    KeywordArg("split_input"),
+                    KeywordArg("split_sizes"),
+                ),
+                Ignored(),
+            ),
+            KeywordArg("dim"),
+            _users=MULTIPLE,
+        ),
+    ),
+    pass_dict=construct_pattern_matcher_pass("split_cat_pass"),
+)
+@register_graph_pattern(
+    RepeatedExpr(
+        CallFunction(
+            torch.squeeze,
+            GetItem(
+                TorchSplit(
+                    KeywordArg("split_input"),
+                    KeywordArg("split_sizes"),
+                ),
+                Ignored(),
+            ),
+            dim=KeywordArg("dim"),
+            _users=MULTIPLE,
+        )
+    ),
+    pass_dict=construct_pattern_matcher_pass("split_cat_pass"),
+)
+def merge_split_squeeze(
+    match: Match, split_input: torch.fx.Node, split_sizes: list[int], dim: int
+):
+    graph = match.graph
+    split = next(node for node in match.nodes if node.target is torch.split)
+    if not all(s == 1 for s in split_sizes):
+        return
+    if isinstance(dim, Sequence):
+        return
+    next_users = find_next_users(split)
+    if not all(node.target is torch.squeeze for node in next_users):
+        return
+    with graph.inserting_before(match.output_node()):
+        unbind = graph.call_function(
+            torch.unbind, args=(split_input,), kwargs={"dim": dim}
+        )
+        if is_node_meta_valid(split_input):
+            unbind.meta["example_value"] = torch.unbind(
+                split_input.meta["example_value"], dim=dim
+            )
+        for item_index, getitem_node in sorted(
+            [(getitem_node.args[1], getitem_node) for getitem_node in split.users]
+        ):
+            squeeze = next(iter(getitem_node.users.keys()))
+            new_get_item = graph.call_function(
+                operator.getitem, args=(unbind, item_index)
+            )
+            squeeze.replace_all_uses_with(new_get_item)
+            new_get_item.meta.update(squeeze.meta)
+            graph.erase_node(squeeze)
+            graph.erase_node(getitem_node)
+    graph.erase_node(split)
+    counters[backend]["split_cat_pass"] += 1
+
+
+getitem_unbind = ListOf(
+    GetItem(
+        CallFunction(
+            torch.unbind,
+            KeywordArg("unbind_input"),
+            dim=KeywordArg("dim"),
+            _users=MULTIPLE,
+        ),
+        Ignored(),
+        _users=MULTIPLE,
+    ),
+    partial=True,
+)
+
+
+@register_graph_pattern(
+    CallFunction([torch.stack, torch.cat], getitem_unbind, Ignored(), _users=MULTIPLE),
+    pass_dict=construct_pattern_matcher_pass("unbind_stack_pass"),
+)
+@register_graph_pattern(
+    CallFunction(
+        [torch.stack, torch.cat], getitem_unbind, dim=Ignored(), _users=MULTIPLE
+    ),
+    pass_dict=construct_pattern_matcher_pass("unbind_stack_pass"),
+)
+@register_graph_pattern(
+    CallFunction(
+        [torch.stack, torch.cat], tensors=getitem_unbind, dim=Ignored(), _users=MULTIPLE
+    ),
+    pass_dict=construct_pattern_matcher_pass("unbind_stack_pass"),
+)
+def merge_unbind_stack(match: Match, unbind_input: torch.fx.Node, dim: int):
+    unbind_node = next(node for node in match.nodes if node.target is torch.unbind)
+    UnbindCatRemover().remove_unbind(match.graph, unbind_node)
+
+
+getitem_split = ListOf(
+    CallFunction(
+        operator.getitem,
+        TorchSplit(
+            Ignored(),
+            KeywordArg("split_sections"),
+        ),
+        Ignored(),
+        _users=MULTIPLE,
+    ),
+    partial=True,
+)
+
+
+reshape_getitem_split = ListOf(
+    CallFunction(
+        torch.reshape,
+        CallFunction(
+            operator.getitem,
+            TorchSplit(
+                Ignored(),
+                KeywordArg("split_sections"),
+            ),
+            Ignored(),
+            _users=MULTIPLE,
+        ),
+        Arg(),
+        _users=MULTIPLE,
+    ),
+    partial=True,
+)
+
+
+@register_graph_pattern(
+    CallFunction(
+        [torch.stack, torch.cat],
+        tensors=getitem_split,
+        dim=Ignored(),
+        _users=MULTIPLE,
+    ),
+    pass_dict=construct_pattern_matcher_pass("split_cat_pass"),
+)
+@register_graph_pattern(
+    CallFunction(
+        [torch.stack, torch.cat],
+        getitem_split,
+        dim=Ignored(),
+        _users=MULTIPLE,
+    ),
+    pass_dict=construct_pattern_matcher_pass("split_cat_pass"),
+)
+@register_graph_pattern(
+    CallFunction(
+        [torch.stack, torch.cat],
+        getitem_split,
+        Ignored(),
+        _users=MULTIPLE,
+    ),
+    pass_dict=construct_pattern_matcher_pass("split_cat_pass"),
+)
+def simplify_split_cat(match: Match, split_sections: list[int], dim: int):
+    if not isinstance(split_sections, (list, tuple)):  # Unnormalized split
+        return
+    split_node = next(node for node in match.nodes if node.target is torch.split)
+    # pyrefly: ignore [bad-argument-type]
+    SplitCatSimplifier().simplify(match.graph, split_node, split_sections)
+
+
+# noqa: W605
+# ############pattern to be optimized is#########
+
+#                 split_node(dim=1)
+#       /     \         ...       /         \
+# getitem    getitem          getitem     getitem   -> user=1
+#    \       /                     \       /
+#      cat (user=mul, dim=1)           cat(user=mul, dim=1)
+#       |            \                   |          \
+
+# ################after transformation#############
+
+#                 split_node(dim=1)
+#       /              ...                  \
+#     getitem                             getitem
+#     |    \                              |     \
+
+
+def has_same_parent_node(node: torch.fx.Node):
+    # the input nodes of the node should come from the same parent
+    prev_node = None
+    for getitem in node.args[0]:  # type: ignore[union-attr]
+        if getitem.target != operator.getitem:  # type: ignore[union-attr]
+            return False
+        if prev_node is None:
+            prev_node = getitem.args[0]  # type: ignore[union-attr]
+        else:
+            if getitem.args[0] != prev_node:  # type: ignore[union-attr]
+                return False
+    return True
+
+
+def remove_zeros(split_sections: list[int]):
+    """
+    Remove zeros from the list and get the index mapping dict from getitem
+    in split node to getitem in new split node
+    """
+    new_split_sections, index_mapping = [], {}
+    idx = 0
+    for i in range(len(split_sections)):
+        if split_sections[i] > 0:
+            new_split_sections.append(split_sections[i])
+            index_mapping[i] = idx
+            idx += 1
+
+    return new_split_sections, index_mapping
+
+
+def is_sorted_and_consecutive(arr: list[int]) -> bool:
+    # check if the array is sorted
+    if arr == sorted(arr):
+        # check if the differences between adjacent elements are all 1
+        return all(x[1] - x[0] == 1 for x in itertools.pairwise(arr))
+    else:
+        return False
+
+
+def calculate_fused_tensor_size(split_node: torch.fx.Node, indices: list[int]) -> int:
+    """
+    Calculate the fused tensor size in the indices
+    """
+    fused_tensor_size = 0
+    for i in range(len(split_node.args[1])):  # type: ignore[arg-type]
+        if i in indices:
+            fused_tensor_size += split_node.args[1][i]  # type: ignore[operator, assignment, index]
+    # pyrefly: ignore [bad-return]
+    return fused_tensor_size
+
+
+@register_graph_pattern(
+    CallFunction(
+        torch.cat,
+        getitem_split,
+        dim=Ignored(),
+        _users=MULTIPLE,
+    ),
+    pass_dict=construct_pattern_matcher_pass("merge_getitem_cat_pass"),
+)
+def merge_getitem_cat(match: Match, split_sections: list[int], dim: int):
+    if not isinstance(split_sections, (list, tuple)):  # Unnormalized split
+        return
+    graph = match.graph
+    split_node = next(node for node in match.nodes if node.target is torch.split)
+    split_input, _split_size, split_dim = _get_split_args_default(split_node)
+    # if the cat and split have different dims, return
+    # Find the next users (i.e. users after the getitem)
+    next_users = find_next_users(split_node)
+    # 'immutable_list' object does not support mutation. Create a new copy of it
+    split_sections = list(split_sections)
+    for cat_user in next_users:
+        if cat_user.target is torch.cat:
+            cat_dim = get_arg_value(cat_user, 1, "dim")
+            # check the all getitems in the cat_user from the same node
+            # check the input of the cat has all getitem from the split
+            # check all getitem only has one single user
+            if (
+                split_dim != cat_dim
+                or not has_same_parent_node(cat_user)
+                or not all(len(arg.users) == 1 for arg in cat_user.args[0])  # type: ignore[union-attr]
+            ):
+                continue
+            # find the index of getitems to be cated/stacked
+            # type: ignore[union-attr]
+            indices = [arg.args[1] for arg in cat_user.args[0]]  # type: ignore[union-attr]
+            # the getitems to be merged must be consecutive, otherwise
+            # returned sliced tensor could be wrong
+            if not is_sorted_and_consecutive(indices):  # type: ignore[arg-type]
+                continue
+            # update the arg of cat user, only keep the first getitem
+            cat_user.update_arg(0, cat_user.args[0][0])  # type: ignore[index]
+            # calculate the fused tensor sizes in the indices
+            fused_tensor_size = 0
+            for i in range(len(split_node.args[1])):  # type: ignore[arg-type]
+                if i in indices:
+                    fused_tensor_size += split_node.args[1][i]  # type: ignore[operator, assignment, index]
+            # update the split sections
+            split_sections[indices[0]] = calculate_fused_tensor_size(  # type: ignore[index]
+                split_node,
+                indices,  # type: ignore[arg-type]
+            )
+            # padding others with zeros to keep the same dict size
+            for i in indices[1:]:
+                split_sections[i] = 0  # type: ignore[index]
+            # remove all unused indexes in the split_node
+            new_split_sections, index_mapping = remove_zeros(split_sections)
+            with graph.inserting_after(split_node):
+                new_split_node = graph.call_function(
+                    torch.split,
+                    args=(split_input, split_sections),
+                    kwargs={"dim": split_dim},
+                )
+                split_node.replace_all_uses_with(new_split_node)
+                new_split_node.meta.update(split_node.meta)
+                # remove all unused getitem nodes
+                to_remove = [cat_user]
+                # dictionary keys changed during iteration
+                new_split_getitem_nodes = list(new_split_node.users.keys())
+                for getitem_node in new_split_getitem_nodes:
+                    if getitem_node.args[1] in indices[1:]:
+                        to_remove.append(getitem_node)
+                    # update meta data of getitem
+                    elif getitem_node.args[1] == indices[0]:
+                        cat_user.replace_all_uses_with(getitem_node)
+                        getitem_node.meta.update(cat_user.meta)
+                    else:
+                        # update getitem index for new split node
+                        getitem_node.update_arg(1, index_mapping[getitem_node.args[1]])
+                graph.erase_node(split_node)
+                for getitem_node in to_remove:
+                    graph.erase_node(getitem_node)
+                # update the split sections of new split node
+                new_split_node.update_arg(1, new_split_sections)
+                split_node = new_split_node
+                split_sections = new_split_sections
+
+                counters[backend]["merge_getitem_cat_pass"] += 1
+
+
+# ############pattern to be optimized is#########
+
+#                 split_node(dim=1)  -> user=multiple
+#       /     \         ...       /         \
+# getitem    getitem          getitem     getitem   -> user=multiple
+#    \       \                    /            \
+#          other_op /cat(user=mul, dim=1)             other_op
+#                      |
+
+# ################after transformation#############
+
+#                 split_node(dim=1)         -> -> user=multiple
+#       /     \         ...       /         \
+# getitem    getitem          getitem     getitem   -> user=multiple
+#    \       \                    /           \
+#                          other_op
+
+
+@register_graph_pattern(
+    CallFunction(
+        torch.cat,
+        getitem_split,
+        dim=Ignored(),
+        _users=MULTIPLE,
+    ),
+    pass_dict=construct_pattern_matcher_pass("mutate_cat_pass"),
+)
+def mutate_cat_node(match: Match, split_sections: list[int], dim: int):
+    if not isinstance(split_sections, (list, tuple)):  # Unnormalized split
+        return
+    graph = match.graph
+    split_node = next(node for node in match.nodes if node.target is torch.split)
+    _split_input, _split_size, split_dim = _get_split_args_default(split_node)
+    # if the cat and split have different dims, return
+    # Find the next users (i.e. users after the getitem)
+    next_users = find_next_users(split_node)
+    for cat_user in next_users:
+        if cat_user.target is torch.cat:
+            cat_dim = get_arg_value(cat_user, 1, "dim") or 0
+            # check that all getitems in the cat_user from the same node
+            # check the input of the cat has all getitem from the split
+            if split_dim != cat_dim or not has_same_parent_node(cat_user):
+                continue
+            # find the index of getitems to be cat
+            indices, idx_to_getitem = [], {}
+            for getitem in cat_user.args[0]:  # type: ignore[union-attr]
+                indices.append(getitem.args[1])  # type: ignore[union-attr]
+                idx_to_getitem[getitem.args[1]] = getitem  # type: ignore[union-attr]
+            # the getitems to be merged must be consecutive, otherwise
+            # returned sliced tensor could be wrong
+            if not is_sorted_and_consecutive(indices):  # type: ignore[arg-type]
+                continue
+            # case 1: the cat uses all getitems from the split
+            if len(split_sections) == len(cat_user.args[0]):  # type: ignore[arg-type]
+                # replace the users of the cat node to be the input of the split node
+                cat_user.replace_all_uses_with(split_node.args[0])  # type: ignore[arg-type]
+                # remove the cat node
+                graph.erase_node(cat_user)
+                counters[backend]["mutate_cat_pass"] += 1
+            # case 2: the cat uses some getitems from the split
+            elif is_node_meta_valid(split_node.args[0]):  # type: ignore[arg-type]
+                # check the split dim, and construct the slice tuple
+                start_fused_size = calculate_fused_tensor_size(
+                    split_node,
+                    list(range(indices[0])),  # type: ignore[arg-type]
+                )
+                end_fused_size = start_fused_size + calculate_fused_tensor_size(
+                    split_node,
+                    indices,  # type: ignore[arg-type]
+                )
+                slice_list = []
+                for i in range(len(split_node.args[0].meta["example_value"].shape)):  # type: ignore[union-attr]
+                    if i != split_dim:
+                        slice_list.append(slice(None, None, None))
+                    else:
+                        slice_list.append(slice(start_fused_size, end_fused_size, None))
+                with graph.inserting_after(split_node):
+                    slice_node = graph.call_function(
+                        operator.getitem,
+                        args=(split_node.args[0], tuple(slice_list)),
+                    )
+                    cat_user.replace_all_uses_with(slice_node)
+                    slice_node.meta.update(cat_user.meta)
+
+                # remove the cat node
+                graph.erase_node(cat_user)
+                counters[backend]["mutate_cat_pass"] += 1
+
+
+getitem_split_aten = ListOf(
+    CallFunction(
+        operator.getitem,
+        CallFunctionVarArgs([torch.ops.aten.split_with_sizes.default], users=MULTIPLE),
+        Ignored(),
+        _users=MULTIPLE,
+    ),
+    partial=True,
+)
+
+
+@register_graph_pattern(
+    CallFunctionVarArgs(torch.ops.aten.split.Tensor, users=MULTIPLE),
+    pass_dict=construct_pattern_matcher_pass("normalization_aten_pass"),
+)
+def normalize_split_default_aten(match: Match, *args, **kwargs):
+    split_node = match.nodes[0]
+    graph = match.graph
+    split_input, split_size, split_dim = _get_split_args_default(split_node)
+    if split_input is None or split_dim is None or split_size is None:
+        log.debug("couldn't find split args")
+        return
+    if not is_node_meta_valid(split_node):
+        log.debug("val absent for node: %s", split_node)
+        return
+    assert isinstance(split_node.meta["val"], (list, tuple))
+    split_sections = [t.size()[split_dim] for t in split_node.meta["val"]]
+    if any(isinstance(section, torch.SymInt) for section in split_sections):
+        # TODO dynamic_shapes with assume_static_by_default=False fails while AOT Autograd tracing.
+        return
+    if split_dim < 0:  # Normalize split dim
+        split_dim += split_input.meta["val"].dim()
+    # we also need to check the input of the split_node
+    # primals =torch.randn(4096, 300)
+    # split = torch.ops.aten.split.Tensor(primals, 320, 1) -> truncate to 300 automatically
+    # split_2 = torch.ops.aten.split_with_sizes.default(primals, [320], dim = 1) -> runtime error
+    split_input_size = split_input.meta["val"].shape[split_dim]
+    split_size = min(split_size, split_input_size)
+    split_section_list = [split_size] * (len(split_node.meta["val"]))
+    new_args = (split_input, split_section_list)
+    new_kwargs = {"dim": split_dim}
+    if (
+        split_node.args == new_args
+        and split_node.kwargs == new_kwargs
+        and split_node.op == "call_function"
+    ):
+        return
+
+    with graph.inserting_after(split_node):
+        new_split_node = graph.call_function(
+            torch.ops.aten.split_with_sizes.default,
+            args=new_args,
+            kwargs=new_kwargs,  # type: ignore[arg-type]
+        )
+    split_node.replace_all_uses_with(new_split_node)
+    new_split_node.meta.update(split_node.meta)
+    graph.erase_node(split_node)
+    counters[backend]["normalization_aten_pass"] += 1
+
+
+@register_graph_pattern(
+    CallFunctionVarArgs(torch.ops.aten.split_with_sizes.default, users=MULTIPLE),
+    pass_dict=construct_pattern_matcher_pass("normalization_aten_pass"),
+)
+def normalize_split_with_size_default_aten(match: Match, *args, **kwargs):
+    split_node = match.nodes[0]
+    graph = match.graph
+    split_input, split_sections, split_dim = _get_split_args_default(split_node)
+    if split_input is None or split_dim is None or split_sections is None:
+        log.debug("couldn't find split args")
+        return
+    if not is_node_meta_valid(split_node):
+        log.debug("val absent for node: %s", split_node)
+        return
+    if any(isinstance(section, torch.SymInt) for section in split_sections):
+        # TODO dynamic_shapes with assume_static_by_default=False fails while AOT Autograd tracing.
+        return
+    if split_dim < 0:  # Normalize split dim
+        split_dim += split_input.meta["val"].dim()
+
+    new_args = (split_input, split_sections)
+    new_kwargs = {"dim": split_dim}
+    if (
+        split_node.args == new_args
+        and split_node.kwargs == new_kwargs
+        and split_node.op == "call_function"
+    ):
+        return
+
+    with graph.inserting_after(split_node):
+        new_split_node = graph.call_function(
+            torch.ops.aten.split_with_sizes.default,
+            args=new_args,
+            kwargs=new_kwargs,  # type: ignore[arg-type]
+        )
+    split_node.replace_all_uses_with(new_split_node)
+    new_split_node.meta.update(split_node.meta)
+    graph.erase_node(split_node)
+    counters[backend]["normalization_aten_pass"] += 1
+
+
+@register_graph_pattern(
+    CallFunction(
+        torch.ops.aten.cat.default,
+        getitem_split_aten,
+        dim=Ignored(),
+        _users=MULTIPLE,
+    ),
+    pass_dict=construct_pattern_matcher_pass("split_cat_aten_pass"),
+)
+def merge_split_cat_aten(match: Match, *args, **kwargs):
+    graph = match.graph
+    split_node = match.nodes[0]
+    threshold_to_cat = torch._inductor.config.post_grad_fusion_options[
+        "split_cat_aten_pass"
+    ].get("threshold_to_cat", 10)
+    # get the getitem nodes from the split node
+    getitem_nodes = list(split_node.users.keys())
+    for cat_node in list(getitem_nodes[0].users.keys()):
+        cat_dim = get_arg_value(cat_node, 1, "dim")
+        cat_inputs = get_arg_value(cat_node, 0, "tensors")
+        try:
+            cat_input_len = len(cat_inputs)
+        except TypeError:
+            continue
+        if cat_input_len < threshold_to_cat:
+            continue
+        # check split node and cat node has same dim, and all getitem nodes have same parent node
+        parent_to_indices = defaultdict(list)  # type: ignore[var-annotated]
+        parent_to_getitems = defaultdict(list)  # type: ignore[var-annotated]
+        for cat_input in cat_inputs:
+            # skip all non-getitem cat input
+            if cat_input.target != operator.getitem:
+                continue
+            current_getitem_parent = cat_input.args[0]
+            split_dim = get_arg_value(current_getitem_parent, 2, "dim")
+            if split_dim != cat_dim:
+                break
+            getitem_idx = cat_input.args[1]
+            if (
+                current_getitem_parent not in parent_to_indices
+            ) or getitem_idx != parent_to_indices[current_getitem_parent][-1][-1] + 1:
+                parent_to_indices[current_getitem_parent].append([getitem_idx])
+                parent_to_getitems[current_getitem_parent].append([cat_input])
+            else:
+                parent_to_getitems[current_getitem_parent][-1].append(cat_input)
+                parent_to_indices[current_getitem_parent][-1].append(getitem_idx)
+
+        cat_inputs_list = list(cat_inputs)
+        update_cat_arg = []
+        # iterate through the indices to construct the slice nodes
+        for parent, indices in parent_to_indices.items():
+            for idx, indice in enumerate(indices):
+                start, end = indice[0], indice[-1]
+                split_sections = list(parent.args[1])
+                input_of_current_getitem_parent = parent.args[0]
+                if len(indice) >= threshold_to_cat or len(indice) == len(
+                    split_sections
+                ):
+                    if len(indice) != len(split_sections):
+                        # get the start and end slicing indices
+                        slice_node = graph.call_function(
+                            torch.ops.aten.slice.Tensor,
+                            args=(
+                                input_of_current_getitem_parent,
+                                split_dim,  # type: ignore[possibly-undefined]
+                                sum(split_sections[:start]),
+                                sum(split_sections[: end + 1]),
+                            ),
+                        )
+                    else:
+                        slice_node = input_of_current_getitem_parent
+                    # find the index in the cat_inputs_list given the getitem node
+                    update_cat_arg.append(
+                        (
+                            slice_node,
+                            cat_inputs_list.index(parent_to_getitems[parent][idx][0]),
+                            cat_inputs_list.index(parent_to_getitems[parent][idx][-1]),
+                        )
+                    )
+
+        result = []
+        i = 0
+        for slice_tensor, start, end in update_cat_arg:
+            while i < start:
+                result.append(cat_inputs_list[i])
+                i += 1
+            result.append(slice_tensor)
+            i = end + 1
+        while i < len(cat_inputs_list):
+            result.append(cat_inputs_list[i])
+            i += 1
+
+        cat_node.update_arg(0, result)
+        for getitem_node in getitem_nodes:
+            if len(getitem_node.users) == 0:
+                graph.erase_node(getitem_node)
+        if len(split_node.users) == 0:
+            graph.erase_node(split_node)
+        counters[backend]["split_cat_aten_pass"] += 1
+
+
+@register_graph_pattern(
+    CallFunction(
+        torch.ops.aten.cat.default,
+        ListOf(
+            CallFunctionVarArgs(torch.ops.aten.select.int, users=MULTIPLE),
+            partial=True,
+        ),
+        dim=Ignored(),
+        _users=MULTIPLE,
+    ),
+    pass_dict=construct_pattern_matcher_pass("select_cat_aten_pass"),
+)
+def merge_select_cat_aten(match: Match, *args, **kwargs):
+    graph = match.graph
+    node = match.nodes[0]
+    node_input = get_arg_value(node, 0, "tensors")
+    # get the select nodes from the node
+    select_nodes = list(node_input.users.keys())
+    for cat_node in list(node.users.keys()):
+        if cat_node.target is torch.ops.aten.cat.default:
+            cat_dim = get_arg_value(cat_node, 1, "dim")
+            cat_inputs = get_arg_value(cat_node, 0, "tensors")
+            # check all select nodes has same slice dim
+            if not all(
+                select_node.args[1] == select_nodes[0].args[1]
+                for select_node in select_nodes
+            ):
+                continue
+            # We only consider the case where selece slice dim and cat node has same dim
+            if select_nodes[0].args[1] != cat_dim:
+                continue
+            if not is_node_meta_valid(cat_node):
+                continue
+            # check the cat node has consecutive indices
+            indices = [select.args[2] for select in cat_node.args[0]]  # type: ignore[union-attr]
+            if (
+                not is_sorted_and_consecutive(indices)  # type: ignore[arg-type]
+                or len(select_nodes) != len(cat_inputs)
+            ):
+                continue
+            # check all the select nodes can be merged to the cat node input
+            if len(indices) != select_nodes[0].args[0].meta["val"].shape[cat_dim]:  # type: ignore[union-attr]
+                continue
+            # reshape the node input to be the same shape as the cat node
+            with graph.inserting_before(node):
+                view_node = graph.call_function(
+                    torch.ops.aten.view.default,
+                    args=(node_input, cat_node.meta["val"].shape),
+                )
+            # replace the node input with the new node
+            cat_node.replace_all_uses_with(view_node)
+            view_node.meta.update(cat_node.meta)
+            # remove the cat node
+            graph.erase_node(cat_node)
+            for select_node in select_nodes:
+                if len(select_node.users) == 0:
+                    graph.erase_node(select_node)
+            counters[backend]["select_cat_aten_pass"] += 1
+
+
+@register_graph_pattern(
+    CallFunctionVarArgs(torch.ops.aten.cat.default, users=MULTIPLE),
+    pass_dict=construct_pattern_matcher_pass("normalization_aten_pass"),
+)
+def normalize_cat_default_aten(match: Match, *args, **kwargs):
+    cat_node = match.nodes[0]
+    graph = match.graph
+    tensors = get_arg_value(cat_node, 0, "tensors")
+    cat_dim = get_arg_value(cat_node, 1, "dim")
+    if cat_dim is None:
+        cat_axis = cat_node.kwargs.get("axis")
+        if cat_axis is not None:
+            cat_dim = cat_axis
+        else:
+            cat_dim = 0
+    if tensors is None or cat_dim is None:
+        log.debug("couldn't find cat args")
+        return
+    assert isinstance(tensors, (list, tuple))
+    for tensor in itertools.chain([cat_node], tensors):
+        if "val" not in tensor.meta:
+            log.debug("val absent for node: %s", tensor)
+            return
+
+    ndim = cat_node.meta["val"].dim()
+
+    def is_empty_tensor(x: torch.fx.Node) -> bool:
+        # special case where torch.ops.aten.cat.default supports cat'ing with an empty tensor
+        x_shape = x.meta["val"].shape
+        return len(x_shape) == 1 and x_shape[0] == 0
+
+    assert all(ndim == x.meta["val"].dim() or is_empty_tensor(x) for x in tensors)
+
+    # pyrefly: ignore [unsupported-operation]
+    if cat_dim < 0:  # Normalize cat dim
+        cat_dim += ndim
+
+    with graph.inserting_after(cat_node):
+        new_cat_node = graph.call_function(
+            torch.ops.aten.cat.default,
+            args=(tensors,),
+            kwargs={"dim": cat_dim},
+        )
+    cat_node.replace_all_uses_with(new_cat_node)
+    new_cat_node.meta.update(cat_node.meta)
+    graph.erase_node(cat_node)
+    counters[backend]["normalization_aten_pass"] += 1
+
+
+@register_graph_pattern(
+    CallFunction(
+        torch.ops.aten.cat,
+        ListOf(CallFunctionVarArgs(torch.ops.aten.unsqueeze)),
+        _users=MULTIPLE,
+    ),
+    pass_dict=construct_pattern_matcher_pass("unbind_stack_aten_pass"),
+)
+def merge_unbind_stack_aten(match: Match, *args, **kwargs):
+    node = match.nodes[-1]
+    graph = match.graph
+    # pyre-fixme[6]
+    unsqueeze_nodes = list(node.args[0])  # type: ignore[arg-type]
+    cat_dim = get_arg_value(node, 1, "dim")
+    # check the unsqueeze nodes come from the select nodes
+    if not all(
+        get_arg_value(unsqueeze_node, 0, "input").target is torch.ops.aten.select
+        for unsqueeze_node in unsqueeze_nodes
+    ):
+        return
+    select_nodes = [
+        get_arg_value(unsqueeze_node, 0, "input") for unsqueeze_node in unsqueeze_nodes
+    ]
+    parent_of_select_node = get_arg_value(select_nodes[0], 0, "input")
+    # check the target of select_nodes are the same
+    if not all(
+        select_node.target is torch.ops.aten.select for select_node in select_nodes
+    ):
+        return
+    # check the select nodes come from the same parent node
+    if not all(
+        get_arg_value(select_node, 0, "input") == parent_of_select_node
+        for select_node in select_nodes
+    ):
+        return
+    if len(unsqueeze_nodes) != len(select_nodes):
+        return
+    # check the select nodes have the same dim
+    if not all(
+        get_arg_value(select_node, 1, "dim") == cat_dim for select_node in select_nodes
+    ):
+        return
+    # check the select nodes have consecutive indices starting from 0
+    if get_arg_value(select_nodes[0], 2, "index") != 0 or not is_sorted_and_consecutive(
+        [get_arg_value(select_node, 2, "index") for select_node in select_nodes]
+    ):
+        return
+    # check the users of parent of select node only from unsqueeze nodes that go to the cat node
+    # we simply check the number of users of the parent of select node
+    if len(parent_of_select_node.users.keys()) != len(node.args[0]):  # type: ignore[arg-type]
+        return
+    node.replace_all_uses_with(parent_of_select_node)
+    graph.erase_node(node)
+    for unsqueeze_node in unsqueeze_nodes:
+        graph.erase_node(unsqueeze_node)
+    for select_node in select_nodes:
+        if len(select_node.users) == 0:
+            graph.erase_node(select_node)
+    counters[backend]["unbind_stack_aten_pass"] += 1
+
+
+def divide_into_consecutive_sublists(indices: list[int]) -> list[list[int]]:
+    n = len(indices)
+    if n <= 1:
+        return [indices]
+
+    # Initialize the list of sublists
+    sublists = []
+
+    # Iterate over the indices
+    i = 0
+    while i < n:
+        # Initialize the current sublist
+        sublist = [indices[i]]
+
+        # Iterate over the remaining indices
+        j = i + 1
+        while j < n and indices[j] == indices[j - 1] + 1:
+            # Add the next index to the current sublist
+            sublist.append(indices[j])
+            j += 1
+
+        # Add the current sublist to the list of sublists
+        sublists.append(sublist)
+        # Move to the next index
+        i = j
+
+    return sublists
+
+
+def update_args_from_split_getitem(
+    graph: torch.fx.Graph,
+    node: torch.fx.Node,
+    getitem_indices: list[int],
+    parents_seen: list[torch.fx.Node],
+    new_cat_args: list[torch.fx.Node],
+    new_cat_args_meta: list[torch.fx.Node],
+    idx_to_getitems: dict[int, torch.fx.Node],
+    threshold_to_cat: int = 2,
+):
+    split_input, split_size, split_dim = _get_split_args_default(parents_seen[-1])
+    # case 1: the number of getitems is the same as the split size, eliminate the split
+    if len(split_size) == len(getitem_indices) and is_sorted_and_consecutive(
+        getitem_indices
+    ):
+        # we can merge the getitems from the previous parent
+        new_cat_args.append(split_input)
+        new_cat_args_meta.append(split_input.meta["example_value"])
+    else:
+        if len(getitem_indices) > 0:
+            # case 2: the number of getitems is smaller than the split size but larger than the threshold, and
+            # the indices of getitems are not all consecutive, we need to divide the indices into multiple groups
+            geitem_indices_sublist = divide_into_consecutive_sublists(getitem_indices)
+            for sublist in geitem_indices_sublist:
+                if len(sublist) >= threshold_to_cat:
+                    # case 2: the number of getitems is smaller than the split size but larger than the threshold
+                    # we need to slice the input of parent
+                    start_fused_size = sum(split_size[: sublist[0]])
+                    end_fused_size = sum(split_size[: sublist[-1] + 1])
+                    slice_list = []
+                    for i in range(len(split_input.meta["example_value"].shape)):  # type: ignore[union-attr]
+                        if i != split_dim:
+                            slice_list.append(slice(None, None, None))
+                        else:
+                            slice_list.append(
+                                slice(start_fused_size, end_fused_size, None)
+                            )
+                    with graph.inserting_after(node):
+                        slice_node = graph.call_function(
+                            operator.getitem,
+                            args=(split_input, tuple(slice_list)),
+                        )
+                        slice_node.meta["example_value"] = split_input.meta[
+                            "example_value"
+                        ][tuple(slice_list)]
+                        new_cat_args.append(slice_node)
+                        new_cat_args_meta.append(slice_node.meta["example_value"])
+                else:
+                    # case 3: the number of getitems is smaller than the threshold, no merge is done
+                    # get the getitems based on the indexes
+                    for i in sublist:
+                        new_cat_args.append(idx_to_getitems[i])
+                        new_cat_args_meta.append(
+                            idx_to_getitems[i].meta["example_value"]
+                        )
+
+
+def reshape_cat_node(
+    graph: torch.fx.Graph,
+    cat_node: torch.fx.Node,
+    unbind_input: torch.fx.Node,
+    cat_dim: int,
+    unbind_dim: int,
+    cat_shape: torch.Size,
+) -> torch.fx.Node:
+    if cat_dim != unbind_dim:
+        # construct the permute node args, which has the same shape as the slice node
+        # then it has the same dim as the unbind_input, i.e., shape of cat + 1
+        with graph.inserting_after(cat_node):
+            permute_list = list(range(len(cat_shape) + 1))
+            permute_list[unbind_dim], permute_list[cat_dim] = (
+                permute_list[cat_dim],
+                permute_list[unbind_dim],
+            )
+            permute_node = graph.call_function(
+                torch.permute,
+                args=(unbind_input, permute_list),
+            )
+            permute_node.meta["example_value"] = torch.permute(
+                unbind_input.meta["example_value"], permute_list
+            )  # type: ignore[arg-type]
+    else:
+        permute_node = unbind_input
+    with graph.inserting_after(permute_node):
+        reshape_node = graph.call_function(
+            torch.reshape, args=(permute_node, tuple(cat_shape))
+        )
+        reshape_node.meta["example_value"] = torch.reshape(
+            permute_node.meta["example_value"], tuple(cat_shape)
+        )  # type: ignore[arg-type]
+    return reshape_node
+
+
+def update_args_from_unbind_getitem(
+    graph: torch.fx.Graph,
+    node: torch.fx.Node,  # cat or stack node
+    getitem_indices: list[int],
+    parents_seen: list[torch.fx.Node],
+    new_cat_args: list[torch.fx.Node],
+    new_cat_args_meta: list[torch.fx.Node],
+    idx_to_getitems: dict[int, torch.fx.Node],
+    threshold_to_cat: int = 2,
+):
+    unbind_input = get_arg_value(parents_seen[-1], 0, "input")  # split or unbind input
+    unbind_dim = get_arg_value(parents_seen[-1], 1, "dim")  # split or unbind dim
+    cat_dim = get_arg_value(node, 1, "dim")  # cat or stack dim
+    # case 1: the number of getitems is the same as the split size, eliminate the split
+    size = list(unbind_input.meta["example_value"].shape)[unbind_dim]
+    if size == len(getitem_indices):
+        cat_shape = torch.cat(
+            [idx_to_getitems[i].meta["example_value"] for i in getitem_indices],
+            dim=cat_dim,
+        ).shape
+        # we can merge the getitems from the previous parent
+        reshape_node = reshape_cat_node(
+            graph, node, unbind_input, cat_dim, unbind_dim, cat_shape
+        )
+        new_cat_args.append(reshape_node)
+        new_cat_args_meta.append(reshape_node.meta["example_value"])
+    elif len(getitem_indices) >= threshold_to_cat and is_sorted_and_consecutive(
+        getitem_indices
+    ):
+        # case 2: the number of getitems is smaller than the split size but larger than the threshold
+        # we need to slice the input of parent
+        cat_shape = torch.cat(
+            [idx_to_getitems[i].meta["example_value"] for i in getitem_indices],
+            dim=cat_dim,
+        ).shape
+        slice_list = []
+        for i in range(len(cat_shape) + 1):
+            if i != unbind_dim:
+                slice_list.append(slice(None, None, None))  # start, end, step
+            else:
+                slice_list.append(
+                    slice(getitem_indices[0], getitem_indices[-1] + 1, None)
+                )
+        with graph.inserting_after(node):
+            slice_node = graph.call_function(
+                operator.getitem,
+                args=(unbind_input, tuple(slice_list)),
+            )
+            slice_node.meta["example_value"] = torch.narrow(
+                unbind_input.meta["example_value"],
+                unbind_dim,
+                getitem_indices[0],
+                getitem_indices[-1] - getitem_indices[0] + 1,
+            )
+            reshape_node = reshape_cat_node(
+                graph, node, slice_node, cat_dim, unbind_dim, cat_shape
+            )
+            new_cat_args.append(reshape_node)
+            new_cat_args_meta.append(reshape_node.meta["example_value"])
+    else:
+        # case 3: the number of getitems is smaller than the threshold, no merge is done
+        # get the getitems based on the indexes
+        for i in getitem_indices:
+            new_cat_args.append(idx_to_getitems[i])
+            new_cat_args_meta.append(idx_to_getitems[i].meta["example_value"])
+
+
+def construct_cat_args(
+    graph: torch.fx.Graph,
+    cat_or_stack_node: torch.fx.Node,
+    inputs: list[torch.fx.Node],
+    split_or_unbind_node: torch.fx.Node,
+    threshold_to_cat: int = 2,
+    run_update_func: Callable = update_args_from_split_getitem,  # type: ignore[type-arg]
+) -> tuple[list[torch.fx.Node], list[torch.Tensor]]:
+    new_cat_args, parents_seen, getitem_indices, idx_to_getitems = [], [], [], {}  # type: ignore[var-annotated]
+    new_cat_args_meta = []  # type: ignore[var-annotated]
+    for input in inputs:
+        if input.target != operator.getitem:
+            # update the last arg based on getitem_indices and parents_seens
+            if len(parents_seen) > 0:
+                run_update_func(  # type: ignore[arg-type, union-attr]
+                    graph,
+                    cat_or_stack_node,
+                    getitem_indices,
+                    parents_seen,
+                    new_cat_args,
+                    new_cat_args_meta,
+                    idx_to_getitems,  # type: ignore[arg-type, union-attr]
+                    threshold_to_cat,
+                )
+            new_cat_args.append(input)
+            new_cat_args_meta.append(input.meta["example_value"])
+            # reset the indices array
+            getitem_indices, idx_to_getitems = [], {}
+        else:
+            # get the parent node of the getitem input
+            parent, idx = input.args[0], input.args[1]  # type: ignore[union-attr]
+            if parent.target != split_or_unbind_node.target:  # type: ignore[union-attr]
+                new_cat_args.append(input)
+                new_cat_args_meta.append(input.meta["example_value"])
+                continue
+            # cannot use parents_seen to check since the first item could be non getitem node
+            if len(parents_seen) == 0:
+                parents_seen.append(parent)
+                idx_to_getitems[idx] = input
+                getitem_indices.append(idx)
+                # case: we only have one getitem input, and it is in the last position
+                if input == inputs[-1]:
+                    new_cat_args.append(input)
+                    new_cat_args_meta.append(input.meta["example_value"])
+                continue
+                # if it is the last input in the tensors, we also check if it can be optimized
+            if parent != parents_seen[-1] or input == inputs[-1]:
+                if input == inputs[-1]:
+                    getitem_indices.append(idx)
+                    idx_to_getitems[idx] = input
+                run_update_func(  # type: ignore[arg-type, union-attr]
+                    graph,
+                    cat_or_stack_node,
+                    getitem_indices,
+                    parents_seen,
+                    new_cat_args,
+                    new_cat_args_meta,
+                    idx_to_getitems,  # type: ignore[arg-type, union-attr]
+                    threshold_to_cat,
+                )
+                # reset the indices array for the next parent
+                # remember to add the last element since it is the first
+                # item in this round of parent
+                # add the parent to the list of seen parents
+                parents_seen.append(parent)
+                getitem_indices, idx_to_getitems = [idx], {idx: input}
+            else:
+                getitem_indices.append(idx)
+                idx_to_getitems[idx] = input
+    return new_cat_args, new_cat_args_meta
+
+
+def remove_split_unbind_children(graph: torch.fx.Graph, inputs: list[torch.fx.Node]):
+    nodes = OrderedSet[Any]()
+    for input in inputs:
+        if input.target is operator.getitem:
+            nodes.add(input.args[0])  # type: ignore[union-attr]
+        if len(input.users.keys()) == 0:
+            graph.erase_node(input)
+    # check the split node to remove if it has no users
+    for node in nodes:
+        if len(node.users.keys()) == 0:  # type: ignore[union-attr]
+            graph.erase_node(node)  # type: ignore[arg-type]
+
+
+# ############pattern to be optimized is#########
+
+#               split_node(dim=1)  -> user=multiple
+#       /           \         ...       /         \
+# other inputs    getitem        getitem     getitem   -> user=multiple
+#            \                    /            \
+#                cat(user=mul, dim=1)             other_op
+#                      |
+
+# ################after transformation#############
+
+#                 split_node(dim=1)     other inputs    -> -> user=multiple
+#                           /           \
+#                         cat (user=mul, dim=1, split_node)
+
+
+@register_graph_pattern(
+    CallFunction(
+        torch.cat,
+        getitem_split,
+        dim=Ignored(),
+        _users=MULTIPLE,
+    ),
+    pass_dict=construct_pattern_matcher_pass("split_cat_to_slices_pass"),
+)
+def split_cat_to_slices(match: Match, split_sections: list[int], dim: int):
+    if not isinstance(split_sections, (list, tuple)):  # Unnormalized split
+        return
+    split_nodes = [node for node in match.nodes if node.target is torch.split]
+    if split_nodes:
+        split_node = next(node for node in split_nodes)
+    else:
+        # Handle the case where there are no nodes with a target of torch.split
+        return
+    split_dim = get_arg_value(split_node, 2, "dim") or 0
+    graph = match.graph
+    threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[
+        "split_cat_to_slices_pass"
+    ].get("threshold_to_cat", 10)
+    # get the cat_node and check its inputs and meta data
+    next_users = find_next_users(split_node)
+    for cat_node in next_users:
+        if cat_node.target != torch.cat or not is_node_meta_valid(cat_node):
+            continue
+        cat_inputs = get_arg_value(cat_node, 0, "tensors")  # type: ignore[union-attr]
+        new_cat_args, _ = construct_cat_args(
+            graph,
+            cat_node,
+            cat_inputs,
+            split_node,
+            threshold_to_cat,
+            update_args_from_split_getitem,
+        )
+        # At least one node would be in the returned new_cat_args
+        # case 1: if new cat args has length 1, we can remove the cat node
+        if len(new_cat_args) == 1:
+            cat_node.replace_all_uses_with(new_cat_args[0])
+            # remove inputs of cat_node if they have no users
+            cat_inputs = cat_node.args[0]  # type: ignore[union-attr]
+            graph.erase_node(cat_node)
+            remove_split_unbind_children(graph, cat_inputs)  # type: ignore[arg-type]
+            counters[backend]["split_cat_to_slices_pass"] += 1
+            continue
+        if len(new_cat_args) > 1 and len(new_cat_args) < len(cat_inputs):
+            new_args = (new_cat_args,)
+            with graph.inserting_after(cat_node):
+                new_cat_node = graph.call_function(
+                    torch.cat,
+                    args=new_args,
+                    # split and cat have the same dim
+                    kwargs={"dim": split_dim},
+                )
+                cat_node.replace_all_uses_with(new_cat_node)
+                new_cat_node.meta.update(cat_node.meta)
+                # remove the cat node
+                graph.erase_node(cat_node)
+                remove_split_unbind_children(graph, cat_inputs)
+                counters[backend]["split_cat_to_slices_pass"] += 1
+
+
+# ############pattern to be optimized is#########
+
+#               unbind(dim=0)  -> user=multiple
+#       /           \         ...       /         \
+# getitem    getitem        getitem     getitem   -> user=multiple
+#            \                    /            \
+#                cat(user=mul, dim=1)             other_op
+#                      |
+
+# ################after transformation#############
+
+#                 input_of_unbind
+#                           |    \
+#                         slice
+#                           |
+#                          view
+#                           |
+
+
+@register_graph_pattern(
+    CallFunction(
+        torch.cat,
+        getitem_unbind,
+        dim=Ignored(),
+        _users=MULTIPLE,
+    ),
+    pass_dict=construct_pattern_matcher_pass("unbind_cat_to_view_pass"),
+)
+def unbind_cat_to_view(match: Match, unbind_input: torch.fx.Node, dim: int):
+    unbind_node = next(node for node in match.nodes if node.target is torch.unbind)
+    graph = match.graph
+    # get the cat_node and check its inputs and meta data
+    next_users = find_next_users(unbind_node)
+    threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[
+        "unbind_cat_to_view_pass"
+    ].get("threshold_to_cat", 10)
+    # get the cat_node and check its inputs and meta data
+    for cat_node in next_users:
+        if cat_node.target != torch.cat or not is_node_meta_valid(cat_node):
+            continue
+        inputs = get_arg_value(cat_node, 0, "tensors")  # type: ignore[union-attr]
+        new_cat_args, new_cat_args_meta = construct_cat_args(
+            graph,
+            cat_node,
+            inputs,
+            unbind_node,
+            threshold_to_cat,
+            update_args_from_unbind_getitem,
+        )
+        # get the view shape
+        # At least one node would be in the returned new_cat_args
+        # case 1: only one node in the new cat args, don't need to cat
+        if len(new_cat_args) == 1:
+            cat_node.replace_all_uses_with(new_cat_args[0])
+            # remove inputs of cat_node if they have no users
+            cat_inputs = cat_node.args[0]  # type: ignore[union-attr]
+            graph.erase_node(cat_node)
+            remove_split_unbind_children(graph, cat_inputs)  # type: ignore[arg-type]
+            counters[backend]["unbind_cat_to_view_pass"] += 1
+            continue
+        if len(new_cat_args) > 1 and len(new_cat_args) < len(inputs):
+            # get the view shape
+            cat_dim = get_arg_value(cat_node, 1, "dim")
+            with graph.inserting_after(cat_node):
+                new_cat_node = graph.call_function(
+                    torch.cat,
+                    args=(new_cat_args,),
+                    kwargs={"dim": cat_dim},
+                )
+                new_cat_node.meta["example_value"] = torch.cat(
+                    new_cat_args_meta, dim=cat_dim
+                )  # type: ignore[arg-type]
+                cat_node.replace_all_uses_with(new_cat_node)
+                new_cat_node.meta.update(cat_node.meta)
+            # remove inputs of cat_node if they have no users
+            cat_inputs = cat_node.args[0]  # type: ignore[union-attr]
+            graph.erase_node(cat_node)
+            remove_split_unbind_children(graph, cat_inputs)  # type: ignore[arg-type]
+            counters[backend]["unbind_cat_to_view_pass"] += 1
+
+
+def reshape_cat_node_to_stack(
+    graph: torch.fx.Graph,
+    cat_node: torch.fx.Node,
+    stack_node: torch.fx.Node,
+    split_or_unbind_dim: int,
+) -> None:
+    # reshape the cat node to the stack node shape
+    stack_shape = stack_node.meta["example_value"].shape
+    stack_dim = _get_dim(stack_node)
+    if stack_dim != split_or_unbind_dim:
+        # case 1: the stack dim is not the same as the split dim
+        # we need to reshape the split input before we do the reshape
+        reshape_list = list(stack_shape)
+        reshape_list[stack_dim], reshape_list[split_or_unbind_dim] = (
+            reshape_list[split_or_unbind_dim],
+            reshape_list[stack_dim],
+        )
+        reshape_node = graph.call_function(
+            torch.reshape,
+            args=(cat_node, tuple(reshape_list)),
+        )
+        reshape_node.meta["example_value"] = torch.reshape(
+            cat_node.meta["example_value"],
+            tuple(reshape_list),  # pyrefly: ignore [bad-argument-type]
+        )
+        permute_list = list(range(len(stack_shape)))
+        permute_list[stack_dim], permute_list[split_or_unbind_dim] = (
+            permute_list[split_or_unbind_dim],
+            permute_list[stack_dim],
+        )
+        permute_node = graph.call_function(
+            torch.permute,
+            args=(reshape_node, permute_list),
+        )
+        permute_node.meta["example_value"] = torch.permute(
+            reshape_node.meta["example_value"], permute_list
+        )
+    else:
+        # case 2: the stack dim is the same as the split dim
+        # we can directly reshape the split input
+        permute_node = cat_node
+    reshape_node = graph.call_function(
+        torch.Tensor.view,
+        args=(permute_node, *stack_shape),  # type: ignore[arg-type]
+    )
+    stack_node.replace_all_uses_with(reshape_node)
+    reshape_node.meta.update(stack_node.meta)
+    stack_inputs = stack_node.args[0]  # type: ignore[union-attr]
+    # remove stack node
+    graph.erase_node(stack_node)
+    # check the input of stack node, and remove nodes that have no users
+    remove_split_unbind_children(graph, stack_inputs)  # type: ignore[arg-type]
+
+
+def convert_reshape_cat_arg_to_stack(
+    graph: torch.fx.Graph,
+    cat_node: torch.fx.Node,
+    stack_node: torch.fx.Node,
+    stack_node_shape: torch.Size,
+    stack_dim: int,
+    split_dim: int,
+) -> torch.fx.Node:
+    # reshape the cat node to the stack node shape
+    cat_shape = cat_node.meta["example_value"].shape
+    if stack_dim != split_dim:
+        permute_list = list(range(len(cat_shape)))
+        permute_list[stack_dim], permute_list[split_dim] = (
+            permute_list[split_dim],
+            permute_list[stack_dim],
+        )
+        permute_node = graph.call_function(
+            torch.permute,
+            args=(cat_node, permute_list),
+        )
+        permute_node.meta["example_value"] = torch.permute(
+            cat_node.meta["example_value"], permute_list
+        )
+    else:
+        permute_node = cat_node
+    reshape_node = graph.call_function(
+        torch.Tensor.view,
+        args=(permute_node, tuple(stack_node_shape)),  # type: ignore[arg-type]
+    )
+    reshape_node.meta["example_value"] = torch.Tensor.view(
+        permute_node.meta["example_value"],
+        tuple(stack_node_shape),  # type: ignore[arg-type]
+    )
+    return reshape_node
+
+
+# ############pattern to be optimized is#########
+#    |           |
+#   split       split   (dim=1)
+#   /     \      /   \
+# getitem  ...        getitem      other ops
+#        \      |       /            /
+#       stack(user=mul, dim=1 or 2) -> can be different dim
+#          |
+
+# ################after transformation#############
+
+#       /           \         ...       /         \
+# getitem    getitem        getitem     getitem   -> user=multiple
+#       \      /
+#       cat(user=mul, dim=1) cat_other_opts
+#          \                  /
+#                  cat
+#                   |
+#                  view
+#                   |
+
+
+@register_graph_pattern(
+    CallFunction(
+        torch.stack,
+        getitem_split,
+        dim=Ignored(),
+        _users=MULTIPLE,
+    ),
+    pass_dict=construct_pattern_matcher_pass("split_stack_to_cats_pass"),
+)
+def split_stack_to_cats(match: Match, split_sections: list[int], dim: int):
+    if not isinstance(split_sections, (list, tuple)):  # Unnormalized split
+        return
+    split_node = next(node for node in match.nodes if node.target is torch.split)
+    split_dim = get_arg_value(split_node, 2, "dim") or 0
+    graph = match.graph
+    threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[
+        "split_stack_to_cats_pass"
+    ].get("threshold_to_cat", 10)
+    # get the stack_node and check its inputs and meta data
+    next_users = find_next_users(split_node)
+    for stack_node in next_users:
+        if stack_node.target != torch.stack or not is_node_meta_valid(stack_node):
+            continue
+        inputs = get_arg_value(stack_node, 0, "tensors")  # type: ignore[union-attr]
+        new_cat_args, new_cat_args_meta = construct_cat_args(
+            graph,
+            stack_node,
+            inputs,
+            split_node,
+            threshold_to_cat,
+            update_args_from_split_getitem,
+        )
+        # At least one node would be in the returned new_cat_args
+        # case 1: only one node in the new cat args, don't need to cat
+        if len(new_cat_args) == 1:
+            reshape_cat_node_to_stack(graph, new_cat_args[0], stack_node, split_dim)
+            counters[backend]["split_stack_to_cats_pass"] += 1
+            continue
+        if len(new_cat_args) > 1 and len(new_cat_args) < len(inputs):
+            with graph.inserting_after(stack_node):
+                cat_node = graph.call_function(
+                    torch.cat,
+                    args=(new_cat_args,),
+                    kwargs={"dim": split_dim},
+                )
+                cat_node.meta["example_value"] = torch.cat(  # type: ignore[arg-type]
+                    new_cat_args_meta, dim=split_dim
+                )
+                reshape_cat_node_to_stack(graph, cat_node, stack_node, split_dim)
+                counters[backend]["split_stack_to_cats_pass"] += 1
+
+
+# ############pattern to be optimized is#########
+
+#               unbind(dim=1)  -> user=multiple
+#                  \         ...       /         \
+# others    getitem        getitem     getitem   -> user=multiple
+#  \          \                    /            \
+#                stack(user=mul, dim=1)             other_op
+#                      |
+
+# ################after transformation#############
+
+#                 input_of_unbind
+#                           |    \
+#                         slice
+#                           |
+#                          view   others
+#                           |    /
+#                          stack
+#                           |
+
+
+@register_graph_pattern(
+    CallFunction(
+        torch.stack,
+        getitem_unbind,
+        dim=Ignored(),
+        _users=MULTIPLE,
+    ),
+    pass_dict=construct_pattern_matcher_pass("unbind_stack_to_slices_pass"),
+)
+def unbind_stack_to_slices(match: Match, unbind_input: torch.fx.Node, dim: int):
+    unbind_node = next(node for node in match.nodes if node.target is torch.unbind)
+    graph = match.graph
+    # get the cat_node and check its inputs and meta data
+    next_users = find_next_users(unbind_node)
+    threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[
+        "unbind_stack_to_slices_pass"
+    ].get("threshold_to_cat", 10)
+    # get the cat_node and check its inputs and meta data
+    for stack_node in next_users:
+        if stack_node.target != torch.stack or not is_node_meta_valid(stack_node):
+            continue
+        inputs = get_arg_value(stack_node, 0, "tensors")  # type: ignore[union-attr]
+        new_cat_args, new_cat_args_meta = construct_cat_args(
+            graph,
+            stack_node,
+            inputs,
+            unbind_node,
+            threshold_to_cat,
+            update_args_from_unbind_getitem,
+        )
+        unbind_dim = get_arg_value(unbind_node, 1, "dim") or 0
+        # At least one node would be in the returned new_cat_args
+        # case 1: only one node in the new cat args, don't need to cat
+        if len(new_cat_args) == 1:
+            reshape_cat_node_to_stack(graph, new_cat_args[0], stack_node, unbind_dim)
+            counters[backend]["unbind_stack_to_slices_pass"] += 1
+            continue
+        if len(new_cat_args) > 1 and len(new_cat_args) < len(inputs):
+            # get the view shape
+            cat_dim = get_arg_value(stack_node, 1, "dim")
+            with graph.inserting_after(stack_node):
+                new_cat_node = graph.call_function(
+                    torch.cat,
+                    args=(new_cat_args,),
+                    kwargs={"dim": cat_dim},
+                )
+                new_cat_node.meta["example_value"] = torch.cat(
+                    new_cat_args_meta, dim=cat_dim
+                )
+                reshape_cat_node_to_stack(graph, new_cat_node, stack_node, unbind_dim)
+            counters[backend]["unbind_stack_to_slices_pass"] += 1
+
+
+# ############pattern to be optimized is#########
+#                   input
+#                     |
+#               split(dim=1)  -> user=multiple
+#                  \         \
+# others    getitem        getitem
+#  \          \               /
+#  reshape     reshape      reshape     other_op
+#  \          \             /         /
+#                stack(user=mul, dim=0)
+#                      |
+
+# ################after transformation#############
+#                          input
+#                           |
+#                         permute
+#                           |
+#                         reshape   others
+#                           |    /
+#                          cat (dim=0)
+#                           |
+
+
+def get_view_shape_list(cat_arg: torch.fx.Node, stack_dim: int) -> list[int]:
+    # cat_arg must be the split input
+    view_shape_list = []
+    for user in cat_arg.users:
+        if user.target is torch.split:
+            for getitem in user.users:
+                if getitem.target is operator.getitem:
+                    reshape_user = [
+                        user for user in getitem.users if user.target is torch.reshape
+                    ]
+                    if len(reshape_user) > 0:
+                        view_shape_list = list(
+                            reshape_user[0]
+                            .meta["example_value"]
+                            .unsqueeze(stack_dim)
+                            .shape
+                        )
+                        view_shape_list[stack_dim] = -1
+                        return view_shape_list
+    return view_shape_list
+
+
+@register_graph_pattern(
+    CallFunction(
+        torch.stack,
+        reshape_getitem_split,
+        dim=Ignored(),
+        _users=MULTIPLE,
+    ),
+    pass_dict=construct_pattern_matcher_pass("move_reshape_out_of_split_stack_pass"),
+)
+def move_reshape_out_of_split_stack(match: Match, *args, **kwargs):
+    split_node = next(node for node in match.nodes if node.target is torch.split)
+    split_dim = _get_dim(split_node)
+    split_users = list(split_node.users.keys())
+    stack_nodes = [node for node in match.nodes if node.target is torch.stack]
+    graph = match.graph
+    threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[
+        "move_reshape_out_of_split_stack_pass"
+    ].get("threshold_to_cat", 10)
+    for stack_node in stack_nodes:
+        if not is_node_meta_valid(stack_node):
+            log.debug("example value absent for node: %s", stack_node)
+            continue
+        stack_dim = _get_dim(stack_node)
+        stack_inputs = get_arg_value(stack_node, 0, "tensors")  # type: ignore[union-attr]
+        inputs = []
+        for stack_input in stack_inputs:
+            if stack_input.target != torch.reshape:
+                inputs.append(stack_input)
+            else:
+                inputs.append(stack_input.args[0])  # type: ignore[union-attr]
+        new_cat_args, _new_cat_args_meta = construct_cat_args(
+            graph,
+            stack_node,
+            inputs,
+            split_node,
+            threshold_to_cat,
+            update_args_from_split_getitem,
+        )
+        # At least one node would be in the returned new_cat_args
+        # case 1: only one node in the new cat args, don't need to cat
+        if len(new_cat_args) == 1:
+            reshape_node = convert_reshape_cat_arg_to_stack(
+                graph,
+                new_cat_args[0],
+                stack_node,
+                stack_node.meta["example_value"].shape,
+                stack_dim,
+                split_dim,
+            )
+            stack_node.replace_all_uses_with(reshape_node)
+            # remove stack node
+            graph.erase_node(stack_node)
+            # check the input of stack node, and remove nodes that have no users
+            remove_split_unbind_children(graph, stack_inputs)  # type: ignore[arg-type]
+            remove_split_unbind_children(graph, split_users)  # type: ignore[arg-type]
+            counters[backend]["move_reshape_out_of_split_stack_pass"] += 1
+            continue
+        if len(new_cat_args) > 1 and len(new_cat_args) < len(inputs):
+            # decompose the cat args into multiple stack nodes, i.e., we stack
+            # all the nodes exist in the stack inputs and reshape the rest followed by a cat
+            stack_node_input, stack_node_input_meta, cat_inputs = [], [], []  # type: ignore[var-annotated]
+            for cat_arg in new_cat_args:
+                if cat_arg not in stack_inputs:
+                    if len(stack_node_input) > 0:
+                        with graph.inserting_after(stack_node):
+                            decomposed_stack_node = graph.call_function(
+                                torch.stack,
+                                args=(stack_node_input,),
+                                kwargs={"dim": stack_dim},
+                            )
+                            decomposed_stack_node.meta["example_value"] = torch.stack(
+                                stack_node_input_meta, dim=stack_dim
+                            )
+                            cat_inputs.append(decomposed_stack_node)
+                    # cat_arg must be the split input
+                    view_shape_list = get_view_shape_list(cat_arg, stack_dim)
+                    stack_node_shape = torch.reshape(
+                        cat_arg.meta["example_value"], tuple(view_shape_list)
+                    ).shape  # type: ignore[union-attr]
+                    cat_inputs.append(
+                        convert_reshape_cat_arg_to_stack(
+                            graph,
+                            cat_arg,
+                            stack_node,
+                            stack_node_shape,
+                            stack_dim,
+                            split_dim,
+                        )
+                    )
+                    stack_node_input, stack_node_input_meta = [], []
+                else:
+                    stack_node_input.append(cat_arg)
+                    stack_node_input_meta.append(cat_arg.meta["example_value"])
+
+            if len(stack_node_input) > 0:
+                with graph.inserting_after(stack_node):
+                    decomposed_stack_node = graph.call_function(
+                        torch.stack,
+                        args=(stack_node_input,),
+                        kwargs={"dim": stack_dim},
+                    )
+                    decomposed_stack_node.meta["example_value"] = torch.stack(
+                        stack_node_input_meta, dim=stack_dim
+                    )
+                    cat_inputs.append(decomposed_stack_node)
+
+            with graph.inserting_after(stack_node):
+                cat_node = graph.call_function(
+                    torch.cat,
+                    args=(cat_inputs,),
+                    kwargs={"dim": stack_dim},
+                )
+                stack_node.replace_all_uses_with(cat_node)
+                cat_node.meta.update(stack_node.meta)
+                graph.erase_node(stack_node)
+                remove_split_unbind_children(graph, stack_inputs)  # type: ignore[arg-type]
+                remove_split_unbind_children(graph, split_users)  # type: ignore[arg-type]
+            counters[backend]["move_reshape_out_of_split_stack_pass"] += 1
+
+
+view_getitem_split_aten = ListOf(
+    CallFunction(
+        [torch.ops.aten.reshape.default],
+        CallFunction(
+            operator.getitem,
+            CallFunctionVarArgs(
+                torch.ops.aten.split_with_sizes.default, users=MULTIPLE
+            ),
+            Ignored(),
+            _users=MULTIPLE,
+        ),
+        Arg(),
+        _users=MULTIPLE,
+    ),
+    partial=True,
+)
+
+
+@register_graph_pattern(
+    CallFunction(
+        torch.ops.aten.cat.default,
+        view_getitem_split_aten,
+        dim=Ignored(),
+        _users=MULTIPLE,
+    ),
+    pass_dict=construct_pattern_matcher_pass("move_view_after_cat_aten_pass"),
+)
+def move_view_after_cat(match: Match, *args, **kwargs):
+    split_node = next(
+        node
+        for node in match.nodes
+        if node.target is torch.ops.aten.split_with_sizes.default
+    )
+    split_input, split_section, split_dim = _get_split_args_default(split_node)
+    split_users = list(split_node.users.keys())
+    getitem_indices = [
+        getitem.args[1] for getitem in split_users if getitem.target is operator.getitem
+    ]
+    if not is_sorted_and_consecutive(getitem_indices):  # type: ignore[arg-type]
+        return
+    cat_nodes = [
+        node for node in match.nodes if node.target is torch.ops.aten.cat.default
+    ]
+    graph = match.graph
+    for cat_node in cat_nodes:
+        if not is_node_meta_valid(cat_node):
+            log.debug("example value absent for node: %s", cat_node)
+            continue
+        cat_dim = _get_dim(cat_node)
+        cat_inputs = get_arg_value(cat_node, 0, "tensors")  # type: ignore[union-attr]
+        # we only consider the following special case
+        if len(cat_inputs) != len(split_section):
+            continue
+        # check if the cat inputs are all the view nodes
+        if not all(
+            view_node.target is torch.ops.aten.reshape.default
+            for view_node in cat_inputs
+        ):
+            continue
+        # check if the view nodes are all from getitem nodes
+        if not all(
+            view_node.args[0].target is operator.getitem for view_node in cat_inputs
+        ):
+            continue
+        view_indices = [view.args[0].args[1] for view in cat_inputs]
+        if not is_sorted_and_consecutive(view_indices):  # type: ignore[arg-type]
+            continue
+        if cat_dim != split_dim:
+            # construct permute node
+            permute_list = list(range(len(cat_node.meta["val"].shape) + 1))
+            permute_list[split_dim], permute_list[cat_dim] = (
+                permute_list[cat_dim],
+                permute_list[split_dim],
+            )
+            permute_node = graph.call_function(
+                torch.ops.aten.permute.default,
+                args=(split_input, permute_list),
+            )
+        else:
+            permute_node = split_input
+
+        with graph.inserting_before(cat_node):
+            view_node = graph.call_function(
+                torch.ops.aten.reshape.default,
+                args=(permute_node, list(cat_node.meta["val"].shape)),
+            )
+            cat_node.replace_all_uses_with(view_node)
+            view_node.meta.update(cat_node.meta)
+            graph.erase_node(cat_node)
+        counters[backend]["move_view_after_cat_aten_pass"] += 1
+
+
+def match_einsum_strings(s: str) -> bool:
+    """
+    This function takes a string s as input, where s is in the format "3 letter string,
+    4 letter string -> 3 letter string".
+    It checks if the strings match the rule and returns True if they do, False otherwise.
+
+    The rule is:
+    - The three strings have the same first two characters.
+    - The first two strings have the same third character.
+    - The second and third strings have the same last character.
+    """
+
+    # Split the input string into parts
+    parts = s.replace("->", ",").split(",")
+
+    # Strip leading/trailing whitespaces from each part
+    parts = [part.strip() for part in parts]
+
+    # Check if we have exactly three parts
+    if len(parts) != 3:
+        return False
+
+    # Extract the strings
+    s1, s2, s3 = parts
+
+    # Check if the strings have the correct lengths
+    if len(s1) != 3 or len(s2) != 4 or len(s3) != 3:
+        return False
+
+    # Check the rule
+    return s1[:2] == s2[:2] == s3[:2] and s1[2] == s2[2] and s2[3] == s3[2]
+
+
+@register_graph_pattern(
+    CallFunctionVarArgs(torch.functional.einsum, users=MULTIPLE),
+    pass_dict=construct_pattern_matcher_pass("einsum_to_pointwise_pass"),
+)
+def replace_einsum_to_pointwise(match: Match, *args, **kwargs):
+    def repl(input, weights):
+        return (input.unsqueeze(-1) * weights).sum(-2)
+
+    def should_replace_einsum(einsum_node) -> bool:
+        equation = get_arg_value(einsum_node, 0)
+        users = einsum_node.users.keys()
+        # for now, we only consider the case of two operands
+        return (
+            len(einsum_node.args) == 3
+            and is_node_meta_valid(input)
+            and is_node_meta_valid(weights)
+            and any(
+                user.target == "add" or user.target is operator.add for user in users
+            )
+            and match_einsum_strings(equation)
+        )
+
+    einsum_node = match.nodes[0]
+    input, weights = get_arg_value(einsum_node, 1), get_arg_value(einsum_node, 2)
+    if should_replace_einsum(einsum_node):
+        # pyrefly: ignore [bad-argument-type]
+        match.replace_by_example(repl, [input, weights])
+        counters[backend]["einsum_to_pointwise_pass"] += 1
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9668f1b6c6e1d07c9a2744ab3894929826f39429
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/__init__.py
@@ -0,0 +1 @@
+from . import flex, mm, mm_common, mm_plus_mm
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/bmm.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/bmm.py
new file mode 100644
index 0000000000000000000000000000000000000000..a155d35b5d059154e20cb8a1e88e361098e8d4c2
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/bmm.py
@@ -0,0 +1,343 @@
+# mypy: allow-untyped-defs
+import logging
+from typing import TYPE_CHECKING, Union
+
+import torch
+from torch._dynamo.utils import counters
+from torch._inductor.codegen.rocm.ck_universal_gemm_template import CKGemmTemplate
+
+from .. import config as inductor_config, ir, lowering as L
+from ..kernel_inputs import MMKernelInputs
+from ..lowering import lowerings, make_pointwise, make_reduction, transform_args
+from ..select_algorithm import (
+    autotune_select_algorithm,
+    ExternKernelChoice,
+    SymbolicGridFn,
+    TritonTemplate,
+)
+from ..utils import (
+    _use_cutlass_for_op,
+    use_aten_gemm_kernels,
+    use_ck_gemm_template,
+    use_cpp_bmm_template,
+    use_cutlass_template,
+    use_triton_template,
+)
+from ..virtualized import ops, V
+from .mm_common import (
+    _is_static_problem,
+    is_batch_stride_largest_or_zero,
+    mm_args,
+    use_native_matmul,
+)
+
+
+if TYPE_CHECKING:
+    from ..ir import ChoiceCaller
+    from ..select_algorithm import KernelTemplate
+
+log = logging.getLogger(__name__)
+aten = torch.ops.aten
+
+
+@SymbolicGridFn
+def bmm_grid(b, m, n, meta, *, cdiv):
+    return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), b, 1)
+
+
+bmm_template = TritonTemplate(
+    name="bmm",
+    grid=bmm_grid,
+    source=r"""
+{{def_kernel("A", "B")}}
+    M = {{size("A", -2)}}
+    N = {{size("B", -1)}}
+    K = {{size("A", -1)}}
+
+    stride_aq = {{stride("A", 0)}}
+    stride_am = {{stride("A", 1)}}
+    stride_ak = {{stride("A", 2)}}
+
+    stride_bq = {{stride("B", 0)}}
+    stride_bk = {{stride("B", 1)}}
+    stride_bn = {{stride("B", 2)}}
+
+    # based on triton.ops.matmul
+    pid = tl.program_id(0).to(INDEX_DTYPE)
+    grid_m = (M + BLOCK_M - 1) // BLOCK_M
+    grid_n = (N + BLOCK_N - 1) // BLOCK_N
+
+    # re-order program ID for better L2 performance
+    width = GROUP_M * grid_n
+    group_id = pid // width
+    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
+    pid_m = group_id * GROUP_M + (pid % group_size)
+    pid_n = (pid % width) // (group_size)
+    tl.assume(pid_m >= 0)
+    tl.assume(pid_n >= 0)
+
+    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    if (stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1):
+        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+    else:
+        ram = rm % M
+    if (stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1):
+        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+    else:
+        rbn = rn % N
+
+    rk = tl.arange(0, BLOCK_K)
+
+    idx_q = tl.program_id(1).to(INDEX_DTYPE)  # batch dimension for BMM
+    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak + idx_q*stride_aq)
+    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn + idx_q*stride_bq)
+
+    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+    for k in range(K, 0, -BLOCK_K):
+        if EVEN_K:
+            a = tl.load(A)
+            b = tl.load(B)
+        else:
+            a = tl.load(A, mask=rk[None, :] < k, other=0.)
+            b = tl.load(B, mask=rk[:, None] < k, other=0.)
+        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
+        A += BLOCK_K * stride_ak
+        B += BLOCK_K * stride_bk
+
+    # rematerialize rm and rn to save registers
+    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    idx_q = tl.program_id(1).to(INDEX_DTYPE)  # batch dimension for BMM
+    idx_m = rm[:, None]
+    idx_n = rn[None, :]
+    mask = (idx_m < M) & (idx_n < N)
+
+    # inductor generates a suffix
+    {{store_output(("idx_q", "idx_m", "idx_n"), "acc", "mask", val_shape=("BLOCK_M", "BLOCK_N"))}}
+""",
+    cache_codegen_enabled_for_template=True,
+)
+
+aten_bmm = ExternKernelChoice(torch.bmm, "at::bmm_out", op_overload=aten.bmm.out)
+aten_bmm_dtype = ExternKernelChoice(
+    torch.bmm,
+    "at::_bmm_out_dtype_cuda",
+    name="bmm_dtype",
+    op_overload=aten.bmm.dtype_out,
+)
+aten_baddbmm = ExternKernelChoice(
+    torch.baddbmm, "at::baddbmm_out", op_overload=aten.baddbmm.out
+)
+
+
+@L.register_lowering(aten.bmm)
+def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
+    """
+    Lowering for autotuning aten.bmm with different backends (Aten, Triton, CUTLASS, etc.)
+    """
+    if all(x.get_device().type == "cpu" for x in [mat1, mat2]):
+        # decompose to small ops when memory bound
+        if mat1.get_size()[1] == 1 or mat2.get_size()[2] == 1:
+            mat1 = L.unsqueeze(mat1, -1)
+            mat2 = L.unsqueeze(mat2, 1)
+            return L.sum_(L.mul(mat1, mat2), axis=2)
+
+        def is_valid_to_require_contiguous(t):
+            if not ir.is_storage_and_layout(t):
+                return True
+            _, layout = ir.as_storage_and_layout(t, freeze=False)
+            return isinstance(layout, ir.FlexibleLayout)
+
+        def is_preferred_layout_as_bmm_input(sizes, strides):
+            # contiguous on one of the last two dims
+            return (
+                strides[-1] == 1 and (sizes[-2] == 1 or strides[-2] >= sizes[-1])
+            ) or (strides[-2] == 1 and (sizes[-1] == 1 or strides[-1] >= sizes[-2]))
+
+        # Make the input of bmm contiguous
+        # if it is not contiguous on either of the last two dims,
+        # because bmm cpu implementation would do contiguous() if not.
+        # This is to avoid additional copies in bmm.
+        def may_require_contiguous(t, meta_t):
+            sizes = meta_t.meta["val"].size()
+            strides = meta_t.meta["val"].stride()
+            if not is_preferred_layout_as_bmm_input(sizes, strides):
+                t = ir.ExternKernel.require_contiguous(t)
+            return t
+
+        if is_valid_to_require_contiguous(mat1):
+            meta_mat1 = V.graph.current_node.args[0]
+            mat1 = may_require_contiguous(mat1, meta_mat1)
+        if is_valid_to_require_contiguous(mat2):
+            meta_mat2 = V.graph.current_node.args[1]
+            mat2 = may_require_contiguous(mat2, meta_mat2)
+
+    if use_native_matmul(mat1, mat2):
+        mat1 = lowerings[aten.unsqueeze](mat1, -1)
+        mat2 = lowerings[aten.unsqueeze](mat2, 1)
+        args, kwargs = transform_args(
+            args=[mat1, mat2],
+            kwargs={},
+            broadcast=True,
+            type_promotion_kind=None,
+            convert_input_to_bool=False,
+        )  # Handles broadcasting the arguments
+
+        if inductor_config.triton.codegen_upcast_to_fp32 and mat1.dtype in [
+            torch.float16,
+            torch.bfloat16,
+        ]:
+
+            def _to_dtype(x):
+                return ops.to_dtype(x, mat1.dtype, use_compute_types=False)
+
+            args = [make_pointwise(_to_dtype)(x) for x in args]
+
+        mul_pointwise = make_pointwise(ops.dot)(*args)
+        dot_reduction = make_reduction("dot")(mul_pointwise, 2)
+
+        return dot_reduction
+
+    # TODO(coconutruben): integrate into MMKernelInputs when all callsites use that
+    m, n, k, layout, mat1, mat2 = mm_args(
+        mat1, mat2, layout=layout, out_dtype=out_dtype
+    )
+    name = "bmm"
+
+    # Create MMKernelInputs for BMM at the top
+    kernel_inputs = MMKernelInputs([mat1, mat2], out_dtype=out_dtype)
+
+    # below is for getting an overview logging info of inductor mms
+    batch_size = mat1.get_size()[0]  # Extract batch dimension
+    counters["aten_mm_info"][f"aten.bmm_{batch_size}_{m}_{n}_{k}"] += 1
+    log.info(
+        "Tuned aten.bmm: batch=%s, m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s",
+        batch_size,
+        m,
+        n,
+        k,
+        mat1.get_dtype(),
+        mat2.get_dtype(),
+        layout,
+    )
+
+    aten_handler: ExternKernelChoice = aten_bmm
+    aten_extra_kwargs = {}
+    if out_dtype:
+        assert mat1.get_device().type == "cuda", "out_dtype is only supported for CUDA"
+        aten_handler = aten_bmm_dtype
+        aten_extra_kwargs = {"out_dtype": out_dtype}
+
+    choices: list[ChoiceCaller] = []
+
+    # Collect all templates for unified call
+    templates_to_use: list[Union[ExternKernelChoice, KernelTemplate]] = []
+    kwarg_overrides = {}
+
+    if use_aten_gemm_kernels():
+        templates_to_use.append(aten_handler)
+        kwarg_overrides[aten_handler.uid] = aten_extra_kwargs
+
+    if use_triton_template(layout, check_max_autotune=False) and (
+        out_dtype is None or out_dtype == mat1.get_dtype()
+    ):
+        # TODO: add out_dtype support for Triton Template
+        templates_to_use.append(bmm_template)
+
+    # Single unified call for all templates
+    choices.extend(
+        V.choices.get_template_configs(
+            kernel_inputs,
+            templates_to_use,
+            name,
+            kwarg_overrides=kwarg_overrides,
+        )
+    )
+    _, is_nonzero = _is_static_problem(layout)
+    batch_stride_largest_or_zero = is_batch_stride_largest_or_zero(mat1, mat2, layout)
+    if (
+        batch_stride_largest_or_zero
+        and is_nonzero
+        and use_cutlass_template(layout, m, n, k)
+        and _use_cutlass_for_op(name)
+    ):
+        from ..codegen.cuda.gemm_template import CUTLASS3xGemmTemplate
+
+        CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
+            choices, layout, kernel_inputs.nodes()
+        )  # type: ignore[arg-type]
+
+    if use_cpp_bmm_template(layout, mat1, mat2):
+        from ..codegen.cpp_bmm_template import CppBmmTemplate
+
+        CppBmmTemplate.add_choices(
+            choices,
+            layout,
+            kernel_inputs.nodes(),
+        )
+
+    if use_ck_gemm_template(layout, m, n, k):
+        CKGemmTemplate.add_ck_gemm_choices(choices, layout, kernel_inputs.nodes())
+
+    return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout)
+
+
+@L.register_lowering(aten.baddbmm)
+def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
+    """
+    Lowering for autotuning aten.mm with different backends (Aten, Triton, CUTLASS, etc.)
+    """
+    if use_native_matmul(mat1, mat2):
+        if beta == 0:
+            arg1 = 0
+        else:
+            arg1 = lowerings[aten.mul](beta, inp)
+
+        if alpha == 0:
+            arg2 = 0
+        else:
+            arg2 = lowerings[aten.mul](alpha, lowerings[aten.bmm](mat1, mat2))
+
+        return lowerings[aten.add](arg1, arg2)
+
+    # TODO(coconutruben): integrate into MMKernelInputs when all callsites use that
+    m, n, k, layout, mat1, mat2, inp = mm_args(mat1, mat2, inp, layout=layout)
+
+    # Create MMKernelInputs for BadDBMM at the top
+    kernel_inputs = MMKernelInputs(
+        [inp, mat1, mat2], scalars=dict(alpha=alpha, beta=beta)
+    )
+
+    # below is for getting an overview logging info of inductor mms
+    batch_size = mat1.get_size()[0]
+    counters["aten_mm_info"][f"aten.baddbmm_{batch_size}_{m}_{n}_{k}"] += 1
+    log.info(
+        "Tuned aten.baddbmm: batch_size=%s, m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, inp=%s, output_layout=%s",
+        batch_size,
+        m,
+        n,
+        k,
+        mat1.get_dtype(),
+        mat2.get_dtype(),
+        inp.get_dtype(),
+        layout,
+    )
+    name = "baddbmm"
+    # options to tune from
+    choices: list[ChoiceCaller] = []
+
+    # Collect all templates for unified call
+    templates_to_use: list[Union[ExternKernelChoice, KernelTemplate]] = []
+    if use_aten_gemm_kernels():
+        templates_to_use.append(aten_baddbmm)
+
+    if use_triton_template(layout, check_max_autotune=False):
+        templates_to_use.append(bmm_template)
+
+    # Single unified call for all templates
+    choices.extend(
+        V.choices.get_template_configs(kernel_inputs, templates_to_use, name)
+    )
+
+    return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/conv.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e5a2aa09d4ea229ebb56ac589f56fc3900ba6ae
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/conv.py
@@ -0,0 +1,687 @@
+# mypy: allow-untyped-defs
+from __future__ import annotations
+
+import logging
+from typing import Optional, TYPE_CHECKING, TypedDict
+
+import torch
+from torch._inductor.codegen.rocm.ck_conv_template import CKGroupedConvFwdTemplate
+
+from .. import config, ir
+from ..lowering import (
+    add_layout_constraint,
+    constrain_to_fx_strides,
+    lowerings as L,
+    register_lowering,
+)
+from ..select_algorithm import (
+    autotune_select_algorithm,
+    ExternKernelChoice,
+    SymbolicGridFn,
+    TritonTemplate,
+)
+from ..utils import (
+    is_ones,
+    is_zeros,
+    pad_listlike,
+    sympy_product,
+    use_ck_conv_template,
+    use_triton_template,
+)
+from ..virtualized import V
+
+
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+    from ..ir import TensorBox
+
+log = logging.getLogger(__name__)
+
+
+aten = torch.ops.aten
+
+
+@SymbolicGridFn
+def conv2d_grid(n, c, h, w, meta, *, cdiv):
+    return (
+        cdiv(n * h * w, meta["BLOCK_M"]),
+        cdiv(c, meta["BLOCK_N"]),
+        meta["GROUPS"],
+    )
+
+
+@SymbolicGridFn
+def conv3d_grid(n, c, d, h, w, meta, *, cdiv):
+    return (
+        cdiv(n * d * h * w, meta["BLOCK_M"]),
+        cdiv(c, meta["BLOCK_N"]),
+        meta["GROUPS"],
+    )
+
+
+LOOP_BODY_2D = """
+        idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H
+        idx_x_w = j - PADDING_W + idx_y_w * STRIDE_W
+        idx_x_c = tl.arange(0, BLOCK_K) + k
+
+        x_ptrs = x_base + (
+            (idx_x_h * stride_xh)[:, None]
+            + (idx_x_w * stride_xw)[:, None]
+            + (idx_x_c * stride_xc)[None, :]
+        )
+        mask_x = (
+            (idx_n < BATCH)[:, None]
+            & (idx_x_h >= 0)[:, None]
+            & (idx_x_h < IN_H)[:, None]
+            & (idx_x_w >= 0)[:, None]
+            & (idx_x_w < IN_W)[:, None]
+            & (idx_x_c < GROUP_IN_C)[None, :]
+        )
+        matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)
+
+        w_ptrs = w_base + (
+            (idx_x_c * stride_wc_in)[:, None] + (i * stride_wh) + (j * stride_ww)
+        )
+        mask_w = (idx_x_c[:, None] < GROUP_IN_C) & (idx_y_c[None, :] < GROUP_OUT_C)
+        matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)
+        acc += tl.dot(matrix_x, matrix_w, allow_tf32=ALLOW_TF32)
+"""
+
+"""
+This is a relatively simple conv implementation that can likely be
+improved.  Many alternate conv versions can be found here:
+https://github.com/pytorch/torchdynamo/pull/971
+"""
+conv2d_template = TritonTemplate(
+    name="convolution2d",
+    grid=conv2d_grid,
+    source=r"""
+{{def_kernel("X", "W")}}
+    # Tensor dimensions
+    BATCH = {{size("X", 0)}}
+    IN_C = {{size("X", 1)}}
+    IN_H = {{size("X", 2)}}
+    IN_W = {{size("X", 3)}}
+    OUT_C = {{size(None, 1)}}
+    OUT_H = {{size(None, 2)}}
+    OUT_W = {{size(None, 3)}}
+
+    # Strides:
+    stride_xn = {{stride("X", 0)}}
+    stride_xc = {{stride("X", 1)}}
+    stride_xh = {{stride("X", 2)}}
+    stride_xw = {{stride("X", 3)}}
+    stride_wc_out = {{stride("W", 0)}}
+    stride_wc_in = {{stride("W", 1)}}
+    stride_wh = {{stride("W", 2)}}
+    stride_ww = {{stride("W", 3)}}
+
+    nhw = tl.program_id(0).to(INDEX_DTYPE) * BLOCK_M + tl.arange(0, BLOCK_M)
+    idx_y_w = nhw % OUT_W
+    nh = nhw // OUT_W
+    idx_y_h = nh % OUT_H
+    idx_n = nh // OUT_H
+    idx_y_c = tl.program_id(1).to(INDEX_DTYPE) * BLOCK_N + tl.arange(0, BLOCK_N)
+
+{% if GROUPS == 1 %}
+    group = 0
+    GROUP_IN_C = IN_C
+    GROUP_OUT_C = OUT_C
+{% else %}
+    group = tl.program_id(2).to(INDEX_DTYPE)
+    GROUP_IN_C = IN_C // GROUPS
+    GROUP_OUT_C = OUT_C // GROUPS
+{% endif %}
+
+    x_base = X + (group * stride_xc * GROUP_IN_C + idx_n * stride_xn)[:, None]
+    w_base = (
+        W + (group * stride_wc_out * GROUP_OUT_C + idx_y_c * stride_wc_out)[None, :]
+    )
+
+    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+{% if UNROLL %}
+{% for i in range(KERNEL_H) %}
+{% for j in range(KERNEL_W) %}
+    i = {{i}}
+    j = {{j}}
+    for k in range(0, GROUP_IN_C, BLOCK_K):
+        """
+    + LOOP_BODY_2D
+    + """
+{% endfor %}
+{% endfor %}
+{% else %}
+    # Could be simplified, but slightly slower:
+    # for i in range(KERNEL_H):
+    #     for j in range(KERNEL_W):
+    #         for k in range(0, GROUP_IN_C, BLOCK_K):
+    BLOCK_K_COUNT = (GROUP_IN_C + BLOCK_K - 1) // BLOCK_K
+    for ijk in range(KERNEL_H * KERNEL_W * BLOCK_K_COUNT):
+        k = (ijk % BLOCK_K_COUNT) * BLOCK_K
+        ij = ijk // BLOCK_K_COUNT
+        i = ij // KERNEL_W
+        j = ij % KERNEL_W
+        """
+    + LOOP_BODY_2D
+    + """
+{% endif %}
+
+    mask = (
+        (idx_n < BATCH)[:, None]
+        & (idx_y_h < OUT_H)[:, None]
+        & (idx_y_w < OUT_W)[:, None]
+        & (idx_y_c < GROUP_OUT_C)[None, :]
+    )
+    idx_n = idx_n[:, None]
+    idx_c = idx_y_c[None, :] + group * GROUP_OUT_C
+    idx_h = idx_y_h[:, None]
+    idx_w = idx_y_w[:, None]
+
+    # inductor generates a suffix
+    {{store_output(("idx_n", "idx_c", "idx_h", "idx_w"), "acc", "mask", val_shape=("BLOCK_M", "BLOCK_N"))}}
+""",
+)
+
+LOOP_BODY_3D = """
+        idx_x_d = d - PADDING_D + idx_y_d * STRIDE_D
+        idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H
+        idx_x_w = j - PADDING_W + idx_y_w * STRIDE_W
+        idx_x_c = tl.arange(0, BLOCK_K) + k
+
+        x_ptrs = x_base + (
+            (idx_x_d * stride_xd)[:, None]
+            + (idx_x_h * stride_xh)[:, None]
+            + (idx_x_w * stride_xw)[:, None]
+            + (idx_x_c * stride_xc)[None, :]
+        )
+        mask_x = (
+            (idx_n < BATCH)[:, None]
+            & (idx_x_d >= 0)[:, None]
+            & (idx_x_d < IN_D)[:, None]
+            & (idx_x_h >= 0)[:, None]
+            & (idx_x_h < IN_H)[:, None]
+            & (idx_x_w >= 0)[:, None]
+            & (idx_x_w < IN_W)[:, None]
+            & (idx_x_c < GROUP_IN_C)[None, :]
+        )
+        matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)
+
+        w_ptrs = w_base + (
+            (idx_x_c * stride_wc_in)[:, None] +
+            (d * stride_wd) + (i * stride_wh) + (j * stride_ww)
+        )
+        mask_w = (idx_x_c[:, None] < GROUP_IN_C) & (idx_y_c[None, :] < GROUP_OUT_C)
+        matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)
+        acc += tl.dot(matrix_x, matrix_w, allow_tf32=ALLOW_TF32)
+"""
+
+conv3d_template = TritonTemplate(
+    name="convolution3d",
+    grid=conv3d_grid,
+    source=r"""
+{{def_kernel("X", "W")}}
+    # Tensor dimensions
+    BATCH = {{size("X", 0)}}
+    IN_C = {{size("X", 1)}}
+    IN_D = {{size("X", 2)}}
+    IN_H = {{size("X", 3)}}
+    IN_W = {{size("X", 4)}}
+    OUT_C = {{size(None, 1)}}
+    OUT_D = {{size(None, 2)}}
+    OUT_H = {{size(None, 3)}}
+    OUT_W = {{size(None, 4)}}
+
+    # Strides:
+    stride_xn = {{stride("X", 0)}}
+    stride_xc = {{stride("X", 1)}}
+    stride_xd = {{stride("X", 2)}}
+    stride_xh = {{stride("X", 3)}}
+    stride_xw = {{stride("X", 4)}}
+    stride_wc_out = {{stride("W", 0)}}
+    stride_wc_in = {{stride("W", 1)}}
+    stride_wd = {{stride("W", 2)}}
+    stride_wh = {{stride("W", 3)}}
+    stride_ww = {{stride("W", 4)}}
+
+    ndhw = tl.program_id(0).to(INDEX_DTYPE) * BLOCK_M + tl.arange(0, BLOCK_M)
+    idx_y_w = ndhw % OUT_W
+    ndh = ndhw // OUT_W
+    idx_y_h = ndh % OUT_H
+    nd = ndh // OUT_H
+    idx_y_d = nd % OUT_D
+    idx_n = nd // OUT_D
+    idx_y_c = tl.program_id(1).to(INDEX_DTYPE) * BLOCK_N + tl.arange(0, BLOCK_N)
+
+{% if GROUPS == 1 %}
+    group = 0
+    GROUP_IN_C = IN_C
+    GROUP_OUT_C = OUT_C
+{% else %}
+    group = tl.program_id(2).to(INDEX_DTYPE)
+    GROUP_IN_C = IN_C // GROUPS
+    GROUP_OUT_C = OUT_C // GROUPS
+{% endif %}
+
+    x_base = X + (group * stride_xc * GROUP_IN_C + idx_n * stride_xn)[:, None]
+    w_base = (
+        W + (group * stride_wc_out * GROUP_OUT_C + idx_y_c * stride_wc_out)[None, :]
+    )
+
+    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+{% if UNROLL %}
+{% for d in range(KERNEL_D) %}
+{% for i in range(KERNEL_H) %}
+{% for j in range(KERNEL_W) %}
+    d = {{d}}
+    i = {{i}}
+    j = {{j}}
+    for k in range(0, GROUP_IN_C, BLOCK_K):
+        """
+    + LOOP_BODY_3D
+    + """
+{% endfor %}
+{% endfor %}
+{% endfor %}
+{% else %}
+    # Could be simplified, but slightly slower:
+    # for d in range(KERNEL_D):
+    #   for i in range(KERNEL_H):
+    #     for j in range(KERNEL_W):
+    #         for k in range(0, GROUP_IN_C, BLOCK_K):
+    BLOCK_K_COUNT = (GROUP_IN_C + BLOCK_K - 1) // BLOCK_K
+    for dijk in range(KERNEL_D * KERNEL_H * KERNEL_W * BLOCK_K_COUNT):
+        k = (dijk % BLOCK_K_COUNT) * BLOCK_K
+        dij = dijk // BLOCK_K_COUNT
+        j = dij % KERNEL_W
+        di = dij // KERNEL_W
+        i = di % KERNEL_H
+        d = di // KERNEL_H
+        """
+    + LOOP_BODY_3D
+    + """
+{% endif %}
+
+    mask = (
+        (idx_n < BATCH)[:, None]
+        & (idx_y_d < OUT_D)[:, None]
+        & (idx_y_h < OUT_H)[:, None]
+        & (idx_y_w < OUT_W)[:, None]
+        & (idx_y_c < GROUP_OUT_C)[None, :]
+    )
+    idx_n = idx_n[:, None]
+    idx_c = idx_y_c[None, :] + group * GROUP_OUT_C
+    idx_d = idx_y_d[:, None]
+    idx_h = idx_y_h[:, None]
+    idx_w = idx_y_w[:, None]
+
+    # inductor generates a suffix
+    {{store_output(("idx_n", "idx_c", "idx_d", "idx_h", "idx_w"), "acc", "mask", val_shape=("BLOCK_M", "BLOCK_N"))}}
+""",
+)
+
+aten_convolution = ExternKernelChoice(
+    torch.convolution,
+    "at::convolution",
+    has_out_variant=False,
+    op_overload=aten.convolution.default,
+)
+
+
+def conv1x1_via_mm(x, w, *, out):
+    w = torch.squeeze(torch.squeeze(w, -1), -1)
+    return torch.matmul(
+        x.permute(0, 2, 3, 1), w.permute(1, 0), out=out.permute(0, 2, 3, 1)
+    )
+
+
+aten_conv1x1_via_mm = ExternKernelChoice(conv1x1_via_mm, None)
+
+
+class ConvLayoutParams(TypedDict):
+    stride: tuple[int, ...]
+    padding: tuple[int, ...]
+    dilation: tuple[int, ...]
+    transposed: bool
+    output_padding: tuple[int, ...]
+    groups: int
+
+
+def conv_layout(
+    x: TensorBox,
+    weight: TensorBox,
+    bias: Optional[TensorBox],
+    stride: Sequence[int],
+    padding: tuple[int, ...],
+    dilation: tuple[int, ...],
+    transposed: bool,
+    output_padding: tuple[int, ...],
+    groups: int,
+) -> ir.Layout:
+    """Determine output layout for a convolution"""
+    with V.graph.fake_mode:
+        output = torch.ops.aten.convolution(
+            ir.ir_node_to_tensor(x, guard_shape=True),
+            ir.ir_node_to_tensor(weight, guard_shape=True),
+            ir.ir_node_to_tensor(bias, guard_shape=True),
+            V.graph.sizevars.size_hints(stride),  # type: ignore[arg-type]
+            V.graph.sizevars.size_hints(padding),  # type: ignore[arg-type]
+            V.graph.sizevars.size_hints(dilation),  # type: ignore[arg-type]
+            transposed,
+            V.graph.sizevars.size_hints(output_padding),  # type: ignore[arg-type]
+            groups,
+        )
+        sizes = ir.convert_shape_to_inductor(output.size())
+        stride = ir.convert_shape_to_inductor(output.stride())  # type: ignore[assignment]
+
+    return ir.FixedLayout(
+        x.get_device_or_error(),
+        x.get_dtype(),
+        sizes,
+        stride,
+    )
+
+
+def channels_last_order(rank):
+    order = list(reversed(range(rank)))
+    order.insert(1, order.pop(-1))
+    return order
+
+
+def convert_1x1_conv_to_mm(x, weight, bias):
+    # special case for 1x1 convolution, which is actually just a matmul
+    rank = len(weight.get_size())
+    for _ in range(rank - 2):
+        weight = L[aten.squeeze](weight, dim=-1)
+    weight = L[aten.permute](weight, [1, 0])
+
+    x = ir.ExternKernel.require_stride_order(x, channels_last_order(rank))
+    x_permute = list(range(rank))
+    x_permute.append(x_permute.pop(1))
+    x = L[aten.permute](x, x_permute)
+    *sizes, in_chan = x.get_size()
+    x = L[aten.reshape](x, [sympy_product(sizes), in_chan])
+    if bias is None:
+        result = L[aten.mm](x, weight)
+    else:
+        result = L[aten.addmm](bias, x, weight)
+    result = L[aten.reshape](result, [*sizes, -1])
+    result_permute = list(range(rank))
+    result_permute.insert(1, result_permute.pop(-1))
+    return L[aten.permute](result, result_permute)
+
+
+@register_lowering(aten.convolution)
+def convolution(
+    x: TensorBox,
+    weight: TensorBox,
+    bias: Optional[TensorBox],
+    stride: Sequence[int],
+    padding: Sequence[int],
+    dilation: Sequence[int],
+    transposed: bool,
+    output_padding: Sequence[int],
+    groups: int,
+):
+    stride = tuple(stride)
+    padding = tuple(padding)
+    dilation = tuple(dilation)
+    output_padding = tuple(output_padding)
+    if not isinstance(groups, int):
+        groups = V.graph.sizevars.guard_int(groups)
+    assert isinstance(groups, int)
+
+    # Need use hint for triton template since the template does not
+    # work with a dynamic shape.
+    #
+    # No need to guard_int for dilation and output_padding
+    # since the template is only used when dilation is 1 and output_padding
+    # is 0.
+    stride = tuple(V.graph.sizevars.guard_int_seq(stride))
+    padding = tuple(V.graph.sizevars.guard_int_seq(padding))
+
+    kwargs: ConvLayoutParams = {
+        "stride": stride,
+        "padding": padding,
+        "dilation": dilation,
+        "transposed": transposed,
+        "output_padding": output_padding,
+        "groups": groups,
+    }
+
+    device_type = ir.get_device_type(x)
+
+    if len(x.get_size()) == len(weight.get_size()) - 1:
+        # add batch dimension to simplify rest of function
+        return L[aten.squeeze](
+            convolution(L[aten.expand](x, [1, *x.get_size()]), weight, bias, **kwargs),
+            dim=0,
+        )
+
+    out_chan, in_chan, *kernel_shape = V.graph.sizevars.guard_int_seq(weight.get_size())
+
+    # Always convert conv1D to 2D for Intel GPU.
+    # Only conv2D can be converted to channel last layout,
+    # which have much better performance.
+    if len(x.get_size()) == 3 and len(kernel_shape) == 1 and device_type == "xpu":
+        kwargs.update(
+            {
+                "stride": (1,) + stride,
+                "padding": (0,) + padding,
+                "dilation": (1,) + dilation,
+                "output_padding": (0,) + output_padding,
+            }
+        )
+        # (N, C, L) -> (N, C, 1, L)
+        x = L[aten.unsqueeze](x, dim=2)
+        weight = L[aten.unsqueeze](weight, dim=2)
+
+        return L[aten.squeeze](
+            convolution(x, weight, bias, **kwargs),
+            dim=2,
+        )
+
+    ndim = len(kernel_shape)
+    stride = pad_listlike(stride, ndim)
+    padding = pad_listlike(padding, ndim)
+    dilation = pad_listlike(dilation, ndim)
+    output_padding = pad_listlike(output_padding, ndim)
+
+    def channels_last_conv():
+        if V.graph.layout_opt and ndim == 2:
+            return True
+
+        layout = conv_layout(x, weight, None, **kwargs)
+        req_stride_order = ir.get_stride_order(
+            V.graph.sizevars.size_hints(layout.stride)
+        )
+        return req_stride_order == ir.NHWC_STRIDE_ORDER
+
+    autotuning_gemm = config.max_autotune or config.max_autotune_gemm
+
+    if (
+        (config.conv_1x1_as_mm or (autotuning_gemm and channels_last_conv()))
+        and is_ones(kernel_shape)
+        and is_ones(stride)
+        and is_zeros(padding)
+        and is_ones(dilation)
+        and not transposed
+        and is_zeros(output_padding)
+        and groups == 1
+        and V.graph.sizevars.statically_known_gt(sympy_product(x.get_size()), 0)
+    ):
+        return convert_1x1_conv_to_mm(x, weight, bias)
+
+    if bias is not None and device_type != "cpu":
+        # peel off the bias, cudnn is slower with it
+        result = convolution(x, weight, None, **kwargs)
+        return L[aten.add](
+            result, L[aten.view](bias, [result.get_size()[1]] + ndim * [1])
+        )
+
+    x.realize()
+    weight.realize()
+
+    # ndim can be 1 for convolution in models such as demucs
+    # TODO: check if it's beneficial to convert Conv1d to Conv2d and then
+    # apply channels last.
+    if V.graph.layout_opt and ndim == 2:
+        V.graph.num_channels_last_conv += 1
+        x = ir.ExternKernel.require_channels_last(x)  # type: ignore[assignment]
+        # TODO maybe we can convert weights to channels last just once before
+        # running the model.
+        weight = ir.ExternKernel.require_channels_last(weight)  # type: ignore[assignment]
+        layout = conv_layout(x, weight, None, **kwargs)
+    else:
+        layout = conv_layout(x, weight, None, **kwargs)
+        req_stride_order = ir.get_stride_order(
+            V.graph.sizevars.size_hints(layout.stride)
+        )
+        x = ir.ExternKernel.require_stride_order(x, req_stride_order)  # type: ignore[assignment]
+        weight = ir.ExternKernel.require_stride_order(weight, req_stride_order)  # type: ignore[assignment]
+
+    ordered_kwargs_for_cpp_kernel = [
+        "stride",
+        "padding",
+        "dilation",
+        "transposed",
+        "output_padding",
+        "groups",
+    ]
+    if bias is None:
+        args = [x, weight]
+        kwargs["bias"] = None  # type: ignore[typeddict-unknown-key]
+        ordered_kwargs_for_cpp_kernel.insert(0, "bias")
+    else:
+        args = [x, weight, bias]
+        bias.realize()
+        bias.freeze_layout()
+        V.graph.sizevars.guard_int_seq(bias.get_size())
+
+    choices = []
+    if torch._inductor.utils._use_conv_autotune_backend("ATEN"):
+        choices = [
+            aten_convolution.bind(
+                args,
+                layout,
+                ordered_kwargs_for_cpp_kernel,
+                **kwargs,
+            )
+        ]
+
+    if (
+        torch._inductor.utils._use_conv_autotune_backend("TRITON")
+        and use_triton_template(layout)
+        # templates only support these:
+        and is_ones(dilation)
+        and not transposed
+        and is_zeros(output_padding)
+        # there are some odd models where this check fails (e.g. shufflenet_v2_x1_0)
+        and V.graph.sizevars.statically_known_equals(in_chan * groups, x.get_size()[1])  # type: ignore[arg-type]
+    ):
+        if (
+            is_ones(kernel_shape)
+            and is_ones(stride)
+            and is_zeros(padding)
+            and groups == 1
+        ):
+            choices.append(aten_conv1x1_via_mm.bind(args, layout))
+
+        conv_configs = V.choices.get_conv_configs(device_type)
+
+        dtype_size = x.get_dtype().itemsize
+        for cfg in conv_configs(
+            sympy_product([x.get_size()[0], *x.get_size()[2:]]),
+            out_chan,
+            in_chan,
+            dtype_size=dtype_size,
+        ):
+            if ndim == 2:
+                conv2d_template.maybe_append_choice(
+                    choices,
+                    input_nodes=(x, weight),
+                    layout=layout,
+                    KERNEL_H=kernel_shape[0],
+                    KERNEL_W=kernel_shape[1],
+                    STRIDE_H=stride[0],
+                    STRIDE_W=stride[1],
+                    PADDING_H=padding[0],
+                    PADDING_W=padding[1],
+                    GROUPS=groups,
+                    # TODO(jansel): try unroll for bigger kernels once fixed:
+                    #               https://github.com/triton-lang/triton/issues/1254
+                    UNROLL=is_ones(kernel_shape),
+                    ALLOW_TF32=torch.backends.cudnn.allow_tf32,
+                    num_stages=cfg.num_stages,
+                    num_warps=cfg.num_warps,
+                    **cfg.kwargs,
+                )
+            elif ndim == 3:
+                conv3d_template.maybe_append_choice(
+                    choices,
+                    input_nodes=(x, weight),
+                    layout=layout,
+                    KERNEL_D=kernel_shape[0],
+                    KERNEL_H=kernel_shape[1],
+                    KERNEL_W=kernel_shape[2],
+                    STRIDE_D=stride[0],
+                    STRIDE_H=stride[1],
+                    STRIDE_W=stride[2],
+                    PADDING_D=padding[0],
+                    PADDING_H=padding[1],
+                    PADDING_W=padding[2],
+                    GROUPS=groups,
+                    # TODO(jansel): try unroll for bigger kernels once fixed:
+                    #               https://github.com/triton-lang/triton/issues/1254
+                    UNROLL=is_ones(kernel_shape),
+                    ALLOW_TF32=torch.backends.cudnn.allow_tf32,
+                    num_stages=cfg.num_stages,
+                    num_warps=cfg.num_warps,
+                    **cfg.kwargs,
+                )
+    if use_ck_conv_template(layout):
+        CKGroupedConvFwdTemplate.add_ck_conv_choices(
+            choices,
+            layout,
+            input_nodes=(x, weight) + ((bias,) if bias is not None else tuple()),
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+            n_spatial_dimensions=ndim,
+        )
+    return autotune_select_algorithm("convolution", choices, args, layout)
+
+
+@register_lowering(aten._convolution)
+def _convolution(
+    x,
+    weight,
+    bias,
+    stride,
+    padding,
+    dilation,
+    transposed,
+    output_padding,
+    groups,
+    benchmark,
+    deterministic,
+    cudnn_enabled,
+    allow_tf32,
+):
+    return convolution(
+        x, weight, bias, stride, padding, dilation, transposed, output_padding, groups
+    )
+
+
+def constrain_conv_to_fx_strides(fx_node, *args, **kwargs):
+    assert fx_node.target is torch.ops.aten.convolution.default
+    if V.graph.layout_opt:
+        return args, kwargs
+    else:
+        return constrain_to_fx_strides(fx_node, *args, **kwargs)
+
+
+add_layout_constraint(aten.convolution, constrain_conv_to_fx_strides)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/custom_op.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/custom_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6a641ce83b17eade82a85cd10962dc377dab7e3
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/custom_op.py
@@ -0,0 +1,537 @@
+# Owner(s): ["module: inductor"]
+
+import functools
+import logging
+from collections.abc import Callable
+from typing import Any, Optional, Union
+
+import torch
+from torch._inductor.codegen.subgraph import SubgraphTemplate
+from torch._inductor.ir import Buffer, FixedLayout, ir_node_to_tensor, TensorBox
+from torch._inductor.lowering import lowerings, validate_ir
+from torch._inductor.select_algorithm import (
+    autotune_select_algorithm,
+    ExternKernelChoice,
+)
+from torch._inductor.virtualized import V
+from torch.utils._ordered_set import OrderedSet
+
+
+log = logging.getLogger(__name__)
+
+
+def _detect_collective_ops(choices: list) -> bool:
+    """
+    Detect if choices contain collective operations.
+    """
+    from torch._inductor.utils import is_collective_op
+
+    for choice in choices:
+        if not hasattr(choice, "gm") or choice.gm is None:
+            continue
+
+        for node in choice.gm.graph.nodes:
+            if node.op == "call_function" and node.target is not None:
+                op_name = str(node.target)
+
+                if is_collective_op(op_name) or is_collective_op(
+                    f"torch.ops.{op_name}"
+                ):
+                    return True
+
+    return False
+
+
+class CustomOpConfig:
+    """Config for custom op autotuning.
+
+    Specifies optional decomposition function with parameter values.
+    Each config creates exactly one variant.
+
+    Args:
+        decomposition: Optional functions to autotune. If not provided, default will be used.
+        **params: Parameters passed to the function
+
+    Examples:
+        CustomOpConfig(attention_impl, head_dim=32, method='chunked')
+        CustomOpConfig(head_dim=32, method='chunked')
+    """
+
+    def __init__(
+        self,
+        decomposition: Optional[Callable[..., Any]] = None,
+        **params: Any,
+    ):
+        if decomposition is not None and not callable(decomposition):
+            raise TypeError(
+                f"decomposition must be callable, got {type(decomposition)}"
+            )
+
+        self.decomposition = decomposition
+        self.params = params
+
+    def get_decomposition(
+        self, default_impl: Optional[Callable[..., Any]] = None
+    ) -> Callable[..., Any]:
+        """Return the decomposition function for this config.
+        When decomposition is not specified, return the default implementation.
+        """
+        if self.decomposition is not None:
+            return self.decomposition
+
+        if default_impl is not None and callable(default_impl):
+            return default_impl
+
+        raise TypeError(
+            "No decomposition specified in config and no default implementation provided. "
+            "Please provide a decomposition function in CustomOpConfig."
+        )
+
+    def __repr__(self) -> str:
+        decomp_name = self.decomposition.__name__ if self.decomposition else "default"
+        if self.params:
+            params_str = ", ".join(f"{k}={v}" for k, v in self.params.items())
+            return f"CustomOpConfig({decomp_name}, {params_str})"
+        return f"CustomOpConfig({decomp_name})"
+
+
+__all__ = [
+    "autotune_custom_op",
+    "register_custom_op_autotuning",
+    "CustomOpConfig",
+]
+
+
+def _extract_tensor_inputs(
+    args: tuple[Any, ...], kwargs: dict[str, Any]
+) -> tuple[list[Any], dict[str, Any]]:
+    """Extract tensor inputs from mixed args/kwargs.
+    Separates tensors (for autotuning input_nodes) from non-tensor parameters.
+    Non-tensor kwargs are later functools.partial'd into decomposition functions.
+
+    Args:
+        args: Positional arguments (mix of tensors and scalars)
+        kwargs: Keyword arguments (mix of tensors and scalars)
+
+    Returns:
+        Tuple of (tensor_inputs_list, non_tensor_kwargs)
+    """
+    tensor_inputs = []
+    non_tensor_kwargs = {}
+
+    # Process args and kwargs: separate tensor inputs and non tensor args
+    for i, arg in enumerate(args):
+        if isinstance(arg, (TensorBox, Buffer)):
+            tensor_inputs.append(arg)
+        else:
+            # Add non-tensor positional args to kwargs with generated names
+            non_tensor_kwargs[f"arg_{i}"] = arg
+
+    for key, value in kwargs.items():
+        if isinstance(value, (TensorBox, Buffer)):
+            tensor_inputs.append(value)
+        else:
+            non_tensor_kwargs[key] = value
+
+    return tensor_inputs, non_tensor_kwargs
+
+
+def _merge_config_and_runtime_kwargs(
+    config_params: dict[str, Any],
+    runtime_kwargs: dict[str, Any],
+) -> dict[str, Any]:
+    """Merge config parameters with runtime kwargs. Runtime kwargs take precedence.
+       If there are conflicts, log a warning and use runtime value.
+
+    Args:
+        config_params: Parameters from CustomOpConfig
+        runtime_kwargs: Runtime non-tensor kwargs from _extract_tensor_inputs
+
+    Returns:
+        Merged kwargs dictionary with runtime values taking precedence
+    """
+    merged_kwargs = config_params.copy()
+
+    # Check for conflicts and let runtime kwargs dominate
+    conflicts = OrderedSet(config_params.keys()).intersection(runtime_kwargs.keys())
+
+    for key in conflicts:
+        log.warning(
+            "Parameter '%s' specified both in CustomOpConfig (%s) "
+            "and at runtime (%s). Using runtime value.",
+            key,
+            config_params[key],
+            runtime_kwargs[key],
+        )
+
+    # Runtime kwargs override config params
+    merged_kwargs.update(runtime_kwargs)
+
+    return merged_kwargs
+
+
+def _adapt_user_input_gen_fns(
+    inputs: list[Any],
+    arg_names: list[str],
+    user_input_gen_fns: dict[str, Callable[[torch.Tensor], torch.Tensor]],
+) -> dict[int, Callable[[Any], torch.Tensor]]:
+    """Convert user input generators from name-based to index-based format.
+       Inductor autotune's input_gen_fns expects index of arg_names as key.
+
+    Uses V.graph.sizevars.size_hints() to guess best for dynamic shapes.
+    """
+
+    name_to_index = {name: i for i, name in enumerate(arg_names)}
+    index_based_fns = {}
+
+    for name, gen_fn in user_input_gen_fns.items():
+        if name in name_to_index:
+            index_based_fns[name_to_index[name]] = gen_fn
+        else:
+            log.warning(
+                "Unknown argument name '%s' in input_gen_fns. "
+                "Available argument names: %s",
+                name,
+                list(name_to_index.keys()),
+            )
+
+    def create_internal_input_gen_fn(
+        user_function: Callable[[torch.Tensor], torch.Tensor], arg_name: str
+    ) -> Callable[[Any], torch.Tensor]:
+        """Create internal input generator that converts IR buffer to user's fake tensor."""
+
+        def internal_input_gen_fn(ir_buffer: Any) -> torch.Tensor:
+            fake_tensor = ir_node_to_tensor(ir_buffer)
+            assert fake_tensor is not None, "ir_node_to_tensor returned None"
+            return user_function(fake_tensor)
+
+        return internal_input_gen_fn
+
+    return {
+        i: create_internal_input_gen_fn(
+            user_gen_fn, arg_names[i] if i < len(arg_names) else f"arg_{i}"
+        )
+        for i, user_gen_fn in index_based_fns.items()
+        if i < len(inputs)
+    }
+
+
+def _create_fallback_choice(
+    name: str,
+    default_impl: Callable[..., Any],
+    fake_output: torch.Tensor,
+    kwargs: dict[str, Any],
+) -> ExternKernelChoice:
+    """Create fallback choice for default implementation."""
+
+    def fallback_wrapper(*args: Any) -> Any:
+        return default_impl(*args, **kwargs)
+
+    return ExternKernelChoice(
+        kernel=fallback_wrapper,
+        name=f"{name}_fallback_default",
+        has_out_variant=False,
+        op_overload=default_impl,
+        use_fallback_kernel=True,
+    )
+
+
+def autotune_custom_op(
+    name: str,
+    decompositions: list[Callable[..., Any]],
+    inputs: list[Any],
+    non_tensor_args: list[dict[str, Any]],
+    op_overload: torch._ops.OpOverload,
+    user_input_gen_fns: Optional[
+        dict[str, Callable[[torch.Tensor], torch.Tensor]]
+    ] = None,
+) -> Union[TensorBox, Any]:
+    """Autotune custom operations by comparing multiple decomposition implementations.
+
+    Currently supports SINGLE OUTPUT custom ops only.
+    TODO: Add support for multiple output custom ops (tuple/list returns).
+
+    This function generates multiple implementation choices for a custom operation and
+    uses Inductor's autotuning system to select the best performing variant at runtime.
+    After selecting the best choice, applies inline fusion if the winning choice has a graph.
+
+    Args:
+        name: Unique identifier for the autotuning operation
+        decompositions: List of alternative implementation functions to benchmark
+        inputs: Input tensor IR nodes from compilation (TensorBox/Buffer objects)
+        non_tensor_args: List of kwargs dicts, paired with corresponding decompositions arg
+        op_overload: OpOverload of the custom op, used as fallback implementation
+        user_input_gen_fns: Optional custom input generators for benchmarking.
+                           Maps input indices to functions that take fake tensors
+                           and return real tensors for performance measurement.
+
+    Returns:
+        IR node representing the optimized operation result
+
+    Raises:
+        TypeError: If decompositions is not a list/tuple
+        RuntimeError: If no inputs or no valid choices generated
+    """
+    if not isinstance(decompositions, (list, tuple)):
+        raise TypeError(
+            f"decompositions must be a list or tuple of callables, got {type(decompositions)}"
+        )
+
+    if not inputs:
+        raise RuntimeError(f"Custom op '{name}' requires tensor inputs for autotuning")
+
+    if len(decompositions) != len(non_tensor_args):
+        raise ValueError(
+            f"decompositions and non_tensor_args must have same length, "
+            f"got {len(decompositions)} decompositions and {len(non_tensor_args)} kwargs"
+        )
+
+    template = SubgraphTemplate(name=name)
+    choices = template.generate_custom_op_choices(
+        name=name,
+        # pyrefly: ignore [bad-argument-type]
+        decompositions=decompositions,
+        input_nodes=list(inputs),
+        non_tensor_args=non_tensor_args,
+    )
+
+    # Add default implementation as fallback
+    if op_overload and hasattr(op_overload, "_op"):
+        fallback_name = f"{name}_fallback_default"
+        from torch._inductor.select_algorithm import extern_kernels
+
+        # Skip if extern_kernel already registered to avoid duplicate registration error
+        if not hasattr(extern_kernels, fallback_name):
+            with V.fake_mode:
+                fake_inputs = [ir_node_to_tensor(inp) for inp in inputs]
+                fallback_kwargs = non_tensor_args[0] if non_tensor_args else {}
+                fake_output = op_overload(*fake_inputs, **fallback_kwargs)
+
+            fallback_choice = _create_fallback_choice(
+                name, op_overload, fake_output, fallback_kwargs
+            )
+            fallback_choice.maybe_append_choice(
+                choices=choices,
+                input_nodes=list(inputs),
+                layout=FixedLayout(
+                    device=fake_output.device,
+                    dtype=fake_output.dtype,
+                    size=fake_output.shape,
+                    stride=fake_output.stride(),
+                ),
+            )
+
+    if not choices:
+        raise RuntimeError(f"No valid choices generated for {name}")
+
+    # Convert user input generation functions to internal format
+    input_gen_fns = {}
+    if user_input_gen_fns:
+        import inspect
+
+        arg_names = (
+            list(inspect.signature(decompositions[0]).parameters.keys())
+            if decompositions
+            else []
+        )
+        input_gen_fns = _adapt_user_input_gen_fns(inputs, arg_names, user_input_gen_fns)
+
+    is_collective = _detect_collective_ops(choices)
+
+    # Run autotuning and get both result and winning choice
+    selected_result, winning_choice = autotune_select_algorithm(
+        name=name,
+        choices=choices,
+        input_nodes=list(inputs),
+        layout=choices[0].layout,
+        input_gen_fns=input_gen_fns,
+        return_choice=True,
+        is_collective=is_collective,
+    )
+
+    # Apply inlining for fusion if winning_choice has graph; otherwise return result as-is(default fallback impl)
+    if winning_choice.gm is not None:
+        log.debug(
+            "Inlining winning choice: %s (name=%s)",
+            getattr(winning_choice, "name", type(winning_choice).__name__),
+            name,
+        )
+        from torch._inductor.codegen.subgraph import inline_subgraph_to_ir_nodes
+
+        return inline_subgraph_to_ir_nodes(winning_choice.gm, inputs, name)
+
+    log.debug(
+        "Winning choice does not support inlining: %s (name=%s)",
+        getattr(winning_choice, "name", type(winning_choice).__name__),
+        name,
+    )
+    return selected_result
+
+
+def _generate_dynamic_configs(
+    tensor_inputs: list[Buffer],
+    config_generator: Callable[[dict[str, torch.Tensor]], list[CustomOpConfig]],
+    default_impl: Callable[..., Any],
+    operation_name: str,
+) -> list[CustomOpConfig]:
+    """Generate configs dynamically based on input tensors at lowering time."""
+    import inspect
+
+    sig = inspect.signature(default_impl)
+    param_names = list(sig.parameters.keys())
+
+    with V.fake_mode:
+        fake_tensors = [ir_node_to_tensor(inp) for inp in tensor_inputs]
+
+    fake_tensors_dict = dict(zip(param_names, fake_tensors))
+
+    configs = config_generator(fake_tensors_dict)
+
+    if not isinstance(configs, (list, tuple)):
+        raise TypeError(
+            f"config_generator must return a list or tuple of CustomOpConfig, "
+            f"got {type(configs)}"
+        )
+    if not configs:
+        raise ValueError(f"config_generator returned empty list for {operation_name}. ")
+
+    return list(configs)
+
+
+def register_custom_op_autotuning(
+    custom_op: torch._library.custom_ops.CustomOpDef,
+    configs: Optional[Union[list[CustomOpConfig], list[Callable[..., Any]]]] = None,
+    config_generator: Optional[
+        Callable[[dict[str, torch.Tensor]], list[CustomOpConfig]]
+    ] = None,
+    name: Optional[str] = None,
+    input_gen_fns: Optional[dict[str, Callable[[torch.Tensor], torch.Tensor]]] = None,
+) -> None:
+    """Register custom op for autotuning with custom_op configs where each config
+    specifies a decomposition implementation function with its parameter values.
+
+    Args:
+        custom_op: Custom operation (decorated function from @torch.library.custom_op)
+        configs: List of CustomOpConfig objects for static inputs. Mutually exclusive with config_generator.
+        config_generator: Dynamic config generator function that takes a dict mapping
+                          parameter names to fake tensors, and returns list[CustomOpConfig]
+                          based on input tensor properties. Mutually exclusive with configs.
+        name: Operation name (default: "{op_name}_autotuned")
+        input_gen_fns: Custom input generators for benchmarking
+
+    Examples:
+        # Static configs
+        @torch.library.custom_op("mylib::attention", mutates_args=())
+        def my_attention(query, key, value, head_dim=32):
+            ...
+
+        register_custom_op_autotuning(
+            my_attention,
+            configs=[
+                CustomOpConfig(attention_impl, head_dim=32, method='chunked'),
+                CustomOpConfig(attention_impl, head_dim=64, method='tiled'),
+                CustomOpConfig(head_dim=128),  # No decomposition specified, use default
+            ],
+            input_gen_fns={
+                "query": lambda fake: torch.randn_like(fake, device='cuda'),
+                "key": lambda fake: torch.randn_like(fake, device='cuda'),
+                "value": lambda fake: torch.randn_like(fake, device='cuda'),
+            },
+        )
+
+        # Dynamic config generation based on input tensor properties
+        def generate_k_split_configs(fake_tensors: dict[str, torch.Tensor]) -> list[CustomOpConfig]:
+            # Access tensor shapes, dtypes, devices, etc.
+            m, k = fake_tensors["mat1"].shape
+            _, n = fake_tensors["mat2"].shape
+            k_splits = ... # compute possible k splits based on tensor properties
+            return [CustomOpConfig(k_splits=k) for k in k_splits]
+
+        register_custom_op_autotuning(
+            matmul_decomposeK_op,
+            config_generator=generate_k_split_configs,
+            input_gen_fns={...},
+        )
+    """
+    from torch._library.custom_ops import CustomOpDef
+
+    if not isinstance(custom_op, CustomOpDef):
+        raise TypeError(
+            f"custom_op must be a CustomOpDef (decorated function from @torch.library.custom_op), "
+            f"got {type(custom_op)}."
+        )
+
+    # Validate configs and config_generator are mutually exclusive
+    if configs is not None and config_generator is not None:
+        raise ValueError(
+            "Cannot specify both 'configs' and 'config_generator'. "
+            "Use 'config_generator' for shape-dependent configs."
+        )
+
+    if configs is None and config_generator is None:
+        raise ValueError("Must specify either 'configs' or 'config_generator'")
+
+    op_overload = custom_op._opoverload
+    default_impl = custom_op._init_fn
+
+    # Process and validate static configs at registration time
+    static_configs = None
+    if configs is not None:
+        if not isinstance(configs, (list, tuple)):
+            raise TypeError(f"configs must be a list or tuple, got {type(configs)}")
+
+        static_configs = []
+        for cfg in configs:
+            if isinstance(cfg, CustomOpConfig):
+                static_configs.append(cfg)
+            else:
+                raise TypeError(
+                    f"Each config must be a CustomOpConfig object, got {type(cfg)}"
+                )
+
+        if not static_configs:
+            raise ValueError("At least one config must be provided")
+
+    if name is None:
+        name = f"{op_overload._name}_autotuned"
+
+    @functools.wraps(op_overload)
+    def autotuning_lowering(*args: Any, **kwargs: Any) -> Any:
+        """Inductor lowering function that replaces custom op calls with autotuned versions."""
+        # Extract tensor inputs and non-tensor parameters (runtime kwargs)
+        tensor_inputs, runtime_kwargs = _extract_tensor_inputs(args, kwargs)
+
+        # Get configs: either generate dynamically or use static configs
+        if config_generator is not None:
+            configs_to_use = _generate_dynamic_configs(
+                tensor_inputs, config_generator, default_impl, name
+            )
+        else:
+            assert static_configs is not None
+            configs_to_use = static_configs
+
+        # Prepare decompositions and kwargs for autotuning
+        decompositions = []
+        non_tensor_args = []
+
+        for cfg in configs_to_use:
+            decomp = cfg.get_decomposition(default_impl=default_impl)
+            decompositions.append(decomp)
+
+            # Merge config params with runtime kwargs (runtime takes precedence)
+            merged_kwargs = _merge_config_and_runtime_kwargs(cfg.params, runtime_kwargs)
+            non_tensor_args.append(merged_kwargs)
+
+        result = autotune_custom_op(
+            name=name,
+            decompositions=decompositions,
+            inputs=tensor_inputs,
+            non_tensor_args=non_tensor_args,
+            op_overload=op_overload,
+            user_input_gen_fns=input_gen_fns,
+        )
+
+        validate_ir(result)
+        return result
+
+    lowerings[op_overload] = autotuning_lowering
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/mm.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/mm.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b57c458f46e62862ea997c5a0fad0d4729d65d5
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/mm.py
@@ -0,0 +1,1102 @@
+# mypy: allow-untyped-defs
+import functools
+import logging
+from typing import Any, Optional, Union
+
+import torch
+from torch._dynamo.utils import counters
+from torch._inductor.autoheuristic.autoheuristic import AutoHeuristicSelectAlgorithm
+from torch._inductor.autoheuristic.autoheuristic_utils import (
+    AHContext,
+    context_add_strides,
+    context_add_using_tf32,
+    mm_operations,
+)
+from torch._inductor.codegen.cpp_gemm_template import CppGemmTemplate
+from torch._inductor.remote_gemm_autotune_cache import gen_best_config
+from torch._inductor.virtualized import ops, V
+from torch.fx.experimental.proxy_tensor import make_fx
+from torch.nn.functional import ScalingType  # type: ignore[attr-defined]
+from torch.torch_version import TorchVersion
+
+from .. import config as inductor_config, distributed_autotune
+from ..codegen.cuda.gemm_template import CUTLASS2xGemmTemplate, CUTLASS3xGemmTemplate
+from ..codegen.rocm.ck_tile_universal_gemm_template import CKTileGemmTemplate
+from ..codegen.rocm.ck_universal_gemm_template import CKGemmTemplate
+from ..codegen.subgraph import SubgraphChoiceCaller, SubgraphTemplate
+from ..ir import Buffer, ChoiceCaller, is_triton, Layout
+from ..kernel_inputs import MMKernelInputs
+from ..lowering import (
+    lowerings,
+    make_pointwise,
+    make_reduction,
+    register_lowering,
+    transform_args,
+)
+from ..select_algorithm import (
+    autotune_select_algorithm,
+    ExternKernelChoice,
+    KernelTemplate,
+    realize_inputs,
+    TritonTemplate,
+)
+from ..utils import (
+    _use_cutlass_for_op,
+    ceildiv,
+    use_aten_gemm_kernels,
+    use_ck_gemm_template,
+    use_ck_tile_gemm_template,
+    use_cpp_gemm_template,
+    use_cutlass_template,
+    use_decompose_k_choice,
+    use_triton_blackwell_tma_template,
+    use_triton_template,
+    use_triton_tma_template,
+)
+from .mm_common import (
+    _is_static_problem,
+    load_kernel_template,
+    mm_args,
+    mm_grid,
+    persistent_mm_grid,
+    use_native_matmul,
+)
+
+
+try:
+    import triton
+
+    triton_version = TorchVersion(triton.__version__)
+    has_triton = True
+except ImportError:
+    triton_version = TorchVersion("0.0.0")
+    has_triton = False
+
+log = logging.getLogger(__name__)
+aten = torch.ops.aten
+prims = torch.ops.prims
+
+# We define each template kernel in a separate file which is the name of the input to load_kernel_template
+# (e.g. triton_mm for templates/triton_mm.py.jinja).
+# If you are adding a new template, please follow that pattern and add a new file with your implementation in the templates folder.
+mm_template = TritonTemplate(
+    name="mm",
+    grid=mm_grid,
+    source=load_kernel_template("triton_mm")
+    if (torch.version.hip is None) or triton_version >= "3.3.0"
+    # FIXME: To get around rocm failures like https://github.com/pytorch/pytorch/actions/runs/13123783322/job/36617154943
+    # The only difference between the two templates is M >= BLOCK_M and N >= BLOCK_N checking.
+    # See more details in https://github.com/pytorch/pytorch/pull/146293
+    else load_kernel_template("triton_mm_rocm"),
+    cache_codegen_enabled_for_template=True,
+    prologue_loads_all_inputs=True,
+)
+
+persistent_tma_mm_template = TritonTemplate(
+    name="mm_persistent_tma",
+    grid=persistent_mm_grid,
+    source=load_kernel_template("triton_persistent_tma_mm"),
+)
+
+
+scaled_mm_device_tma_epilogue_scaling_template = TritonTemplate(
+    name="scaled_mm_device_tma_epilogue_scaling",
+    grid=persistent_mm_grid,
+    source=load_kernel_template("triton_epilogue_scaled_mm"),
+)
+
+
+scaled_mm_device_tma_main_loop_scaling_template = TritonTemplate(
+    name="scaled_mm_device_tma_main_loop_scaling",
+    grid=persistent_mm_grid,
+    source=load_kernel_template("triton_main_loop_scaled_mm"),
+)
+
+blackwell_ws_persistent_device_tma_mm_template = TritonTemplate(
+    name="blackwell_ws_persistent_device_tma",
+    grid=persistent_mm_grid,
+    source=load_kernel_template("triton_blackwell_ws_persistent_device_tma_mm"),
+)
+
+
+# prevent duplication registration of extern functions
+@functools.cache
+def lazy_register_extern_choice(fn):
+    return ExternKernelChoice(fn)
+
+
+aten_mm = ExternKernelChoice(torch.mm, "at::mm_out", op_overload=aten.mm.out)
+aten_mm_dtype = ExternKernelChoice(
+    torch.mm,
+    "at::_mm_dtype_out_cuda",
+    name="mm_dtype",
+    op_overload=aten.mm.dtype_out,
+)
+
+aten_addmm = ExternKernelChoice(
+    torch.addmm, "at::addmm_out", op_overload=aten.addmm.out
+)
+
+aten__int_mm = ExternKernelChoice(
+    torch._int_mm, "at::_int_mm_out", op_overload=aten._int_mm.out
+)
+
+aten__sparse_semi_structured_mm = ExternKernelChoice(
+    torch._sparse_semi_structured_mm,
+    "at::_sparse_semi_structured_mm",
+    has_out_variant=False,
+    op_overload=aten._sparse_semi_structured_mm.default,
+)
+
+aten__fp8_mm = ExternKernelChoice(
+    torch._scaled_mm, "at::_scaled_mm_out", op_overload=aten._scaled_mm.out
+)
+
+
+def _is_int8_mat(mat):
+    return mat.get_dtype() in (torch.int8, torch.uint8)
+
+
+def bias_addmm(inp, mat1, mat2, *, out=None, alpha=1, beta=1):
+    """
+    Giving torch.addmm a 1D tensor calls a different (faster) cublasLt
+    kernel under the hood.  There are a few shapes where this is slower,
+    but they are rare.
+    """
+    if (inp.stride(0) == 0 and inp.size(0) != 0) or inp.size(0) == 1:
+        return torch.addmm(inp[0], mat1, mat2, out=out, alpha=alpha, beta=beta)
+    return torch.addmm(inp, mat1, mat2, out=out, alpha=alpha, beta=beta)
+
+
+def check_supported_striding(mat_a, mat_b) -> None:
+    def is_row_major(stride) -> bool:
+        return V.graph.sizevars.statically_known_equals(stride[1], 1)
+
+    def is_col_major(stride) -> bool:
+        return V.graph.sizevars.statically_known_equals(stride[0], 1)
+
+    def has_zero_dim(size) -> bool:
+        return bool(
+            V.graph.sizevars.statically_known_equals(size[0], 0)
+            or V.graph.sizevars.statically_known_equals(size[1], 0)
+        )
+
+    # Check mat_a (self) stride requirements
+    torch._check(
+        is_row_major(mat_a.get_stride()) or has_zero_dim(mat_a.get_size()),
+        lambda: f"mat_a must be row_major, got stride {mat_a.get_stride()}",
+    )
+
+    # Check mat_b stride requirements
+    torch._check(
+        is_col_major(mat_b.get_stride()) or has_zero_dim(mat_b.get_size()),
+        lambda: f"mat_b must be col_major, got stride {mat_b.get_stride()}",
+    )
+
+
+aten_bias_addmm = ExternKernelChoice(bias_addmm, None)
+
+
+def decomposeK(a, b, k_splits):
+    m = a.shape[0]
+    n = b.shape[1]
+    k = a.shape[1]
+
+    k_parts = k // k_splits
+    B = k_splits
+    a_reshaped = torch.permute(a.reshape(m, B, k_parts), (1, 0, 2))
+    b_reshaped = b.reshape(B, k_parts, n)
+    result = torch.bmm(a_reshaped, b_reshaped, out_dtype=torch.float32)
+    reduced_buf = torch.sum(result, 0)
+    return reduced_buf.to(a.dtype)
+
+
+class DecomposeKSugraphTemplate(SubgraphTemplate):
+    def __init__(self):
+        super().__init__(
+            name="decompose_k",
+        )
+
+    def generate(  # type: ignore[override]
+        self,
+        input_nodes: list[Buffer],
+        layout: Layout,
+        k_split: int,
+    ) -> SubgraphChoiceCaller:
+        from torch._dispatch.python import enable_python_dispatcher
+
+        from ..decomposition import select_decomp_table
+
+        name = f"decompose_k_mm_{k_split}_split"
+        description = f"{k_split=}"
+
+        with enable_python_dispatcher():
+            decompositions = select_decomp_table()
+            fn = make_fx(
+                functools.partial(decomposeK, k_splits=k_split),
+                decompositions,
+            )
+
+            return super().generate(
+                name=name,
+                input_nodes=input_nodes,
+                layout=layout,
+                make_fx_graph=fn,
+                description=description,
+            )
+
+
+decompose_k_subgraph_template = DecomposeKSugraphTemplate()
+
+
+class ContiguousTemplate(SubgraphTemplate):
+    def __init__(self, name: str, description: str, fn: Any):
+        self.name = name
+        self.description = description
+        self.fn = fn
+        super().__init__(
+            name=name,
+        )
+
+    def generate(  # type: ignore[override]
+        self,
+        input_nodes: list[Buffer],
+        layout: Layout,
+    ) -> SubgraphChoiceCaller:
+        from torch._dispatch.python import enable_python_dispatcher
+
+        from ..decomposition import select_decomp_table
+
+        with enable_python_dispatcher():
+            decompositions = select_decomp_table()
+            fn = make_fx(
+                self.fn,
+                decompositions,
+            )
+
+            return super().generate(
+                name=self.name,
+                input_nodes=input_nodes,
+                layout=layout,
+                make_fx_graph=fn,
+                description=self.description,
+            )
+
+
+def contiguous_mm(a, b):
+    return torch.mm(a, b.contiguous())
+
+
+def contiguous_addmm(inp, a, b):
+    return torch.addmm(inp, a, b.contiguous())
+
+
+mm_contiguous_subgraph_template = ContiguousTemplate(
+    "contiguous_mm", "contiguous mm", contiguous_mm
+)
+addmm_contiguous_subgraph_template = ContiguousTemplate(
+    "contiguous_addmm", "contiguous addmm", contiguous_addmm
+)
+
+
+@register_lowering(aten.mm, type_promotion_kind=None)
+def tuned_mm(mat1, mat2, out_dtype=None, *, layout=None):
+    """
+    Lowering for autotuning aten.mm with different backends (Aten, Triton, CUTLASS, etc.)
+    """
+    if out_dtype is not None:
+        input_dtype = mat1.get_dtype()
+        torch._check(
+            mat2.get_dtype() == input_dtype,
+            lambda: "input dtypes must be the same",
+        )
+        torch._check(
+            mat1.get_device().type == "cuda",
+            lambda: "out_dtype is only supported for CUDA",
+        )
+        torch._check(
+            out_dtype == input_dtype
+            or (
+                out_dtype == torch.float32
+                and input_dtype in (torch.float16, torch.bfloat16)
+            ),
+            lambda: "out_dtype must be the same as input dtype or fp32 for fp16/bf16 inputs",
+        )
+
+    # Lower matmul-related operations (e.g., torch.matmul / torch.bmm / torch.addmm)
+    # into native matmul IR using `ops.dot`. When we see a matmul pattern
+    # (C[y, x] = A[y, r] * B[r, x]), the core idea is to emulate a broadcasted
+    # multiply followed by a sum.
+    #
+    # For example, given `C = torch.matmul(A, B)`, this can be rewritten as:
+    #
+    #     Prod = A.unsqueeze(-1) * B.unsqueeze(0)
+    #     C = Prod.sum(dim=1)
+    #
+    # Instead of explicitly using `ops.mul` and `ops.reduction("sum")`, we lower
+    # these into `ops.dot` (pointwise) and `ops.reduction("dot")`. These IR nodes
+    # are semantically equivalent to the `ops.mul` + `ops.reduction("sum")`
+    # combination, but are lowered to `tl.dot` during the code generation phase.
+    if use_native_matmul(mat1, mat2):
+        mat1 = lowerings[aten.unsqueeze](mat1, -1)
+        mat2 = lowerings[aten.unsqueeze](mat2, 0)
+        args, kwargs = transform_args(
+            args=[mat1, mat2],
+            kwargs={},
+            broadcast=True,
+            type_promotion_kind=None,
+            convert_input_to_bool=False,
+        )  # Handles broadcasting the arguments
+
+        if inductor_config.triton.codegen_upcast_to_fp32 and mat1.dtype in [
+            torch.float16,
+            torch.bfloat16,
+        ]:
+
+            def _to_dtype(x):
+                return ops.to_dtype(x, mat1.dtype, use_compute_types=False)
+
+            args = [make_pointwise(_to_dtype)(x) for x in args]
+
+        mul_pointwise = make_pointwise(ops.dot)(*args)
+        dot_reduction = make_reduction("dot")(mul_pointwise, 1)
+
+        return dot_reduction
+
+    # TODO(coconutruben): integrate into MMKernelInputs when all callsites use that
+    m, n, k, layout, mat1, mat2 = mm_args(
+        mat1, mat2, layout=layout, out_dtype=out_dtype
+    )
+    static_shape, is_nonzero = _is_static_problem(layout)
+    name = "mm"
+
+    # Create MMKernelInputs for standard MM at the top
+    kernel_inputs = MMKernelInputs([mat1, mat2], out_dtype=out_dtype)
+
+    # below is for getting an overview logging info of inductor mms
+    counters["aten_mm_info"][f"aten.mm_{m}_{n}_{k}"] += 1
+    log.info(
+        "Tuned aten.mm: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s",
+        m,
+        n,
+        k,
+        mat1.get_dtype(),
+        mat2.get_dtype(),
+        layout,
+    )
+
+    choices: list[ChoiceCaller] = []
+    static_shape, is_nonzero = _is_static_problem(layout)
+
+    aten_handler: ExternKernelChoice = aten_mm
+    aten_extra_kwargs: dict[str, Any] = {}
+    if out_dtype is not None:
+        aten_handler = aten_mm_dtype
+        aten_extra_kwargs = {"out_dtype": out_dtype}
+
+    templates_to_use: list[Union[ExternKernelChoice, KernelTemplate]] = []
+    kwarg_overrides: dict[str, dict[str, Any]] = {}
+    if use_aten_gemm_kernels():
+        templates_to_use.append(aten_handler)
+        if aten_extra_kwargs:
+            kwarg_overrides[aten_handler.uid] = aten_extra_kwargs
+
+    if (
+        out_dtype is None
+        and is_nonzero
+        and use_triton_template(layout, check_max_autotune=True)
+    ):
+        if use_decompose_k_choice(m, n, k):
+            templates_to_use.append(decompose_k_subgraph_template)
+        # Triton Templates typically perform very poorly for large K.
+        # Its highly unlikely that if we want to use decompose_k, then
+        # Triton will ever win.
+        #
+        # To be conservative we increase this threshold for N/M by 2.
+        is_exhaustive = inductor_config.max_autotune_gemm_search_space == "exhaustive"
+        if is_exhaustive or not use_decompose_k_choice(m, n, k, threshold_multiple=2):
+            templates_to_use.append(mm_template)
+
+            if use_triton_tma_template(mat1, mat2, output_layout=layout):
+                templates_to_use.append(persistent_tma_mm_template)
+
+            if use_triton_blackwell_tma_template(mat1, mat2, output_layout=layout):
+                templates_to_use.append(blackwell_ws_persistent_device_tma_mm_template)
+
+        templates_to_use.append(mm_contiguous_subgraph_template)
+
+    choices.extend(
+        V.choices.get_template_configs(
+            kernel_inputs,
+            templates_to_use,
+            "mm",
+            kwarg_overrides=kwarg_overrides,
+        )
+    )
+
+    if (
+        out_dtype is None
+        and is_nonzero
+        and use_cutlass_template(layout, m, n, k)
+        and _use_cutlass_for_op("mm")
+    ):
+        CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
+            choices, layout, kernel_inputs.nodes()
+        )
+
+    if out_dtype is None and is_nonzero and use_ck_gemm_template(layout, m, n, k):
+        CKGemmTemplate.add_ck_gemm_choices(choices, layout, kernel_inputs.nodes())
+    if out_dtype is None and is_nonzero and use_ck_tile_gemm_template(layout, m, n, k):
+        CKTileGemmTemplate.add_choices(choices, layout, kernel_inputs.nodes())
+
+    if out_dtype is None and use_cpp_gemm_template(layout, mat1, mat2):
+        CppGemmTemplate.add_choices(
+            choices,
+            layout,
+            kernel_inputs.nodes(),
+        )
+
+    input_nodes = [mat1, mat2]
+    if (
+        out_dtype is None
+        and is_nonzero
+        and use_triton_template(layout)
+        and torch._inductor.config.run_autoheuristic(name)
+        and is_triton(mat1)
+    ):
+        always_included = []
+        if use_aten_gemm_kernels():
+            always_included.append("extern_mm")
+        num_choices_before_extra_configs = len(choices)
+        choices.extend(
+            V.choices.get_template_configs(
+                # TODO(coconutruben): remove once we deprecate ah
+                # mm-extra is a hack to keep the ah functionality alive
+                # while we transition to the unified kwargs retrieval
+                kernel_inputs,
+                [mm_template],
+                "mm-ah",
+            )
+        )
+
+        # using AutoHeuristic for ranking
+        ah_choices = mm_autoheuristic(
+            mat1,
+            mat2,
+            m,
+            n,
+            k,
+            choices,
+            name,
+            input_nodes,
+            mm_operations(),
+            None,
+            top_k=10,
+            always_included=always_included,
+        )
+        if not torch._inductor.config.collect_autoheuristic(name):
+            # if we are collecting data, we do not want to modify choices
+            if ah_choices is not None and len(ah_choices) > 0:
+                # the order in which autoheuristic returns choices is not the same as
+                # as the order of choices, which affects things like epilogue fusion.
+                # once epilogue fusion benchmarks choices in sorted order, I think we can
+                # just use the order returned by autoheuristic
+                choices = [choice for choice in choices if choice in ah_choices]
+            else:
+                choices = choices[:num_choices_before_extra_configs]
+
+    if out_dtype is None:
+        for k in inductor_config.external_matmul:
+            choices.append(
+                lazy_register_extern_choice(k).bind(kernel_inputs.nodes(), layout)
+            )
+
+    best_config_future = None
+    if out_dtype is None and torch._inductor.config.remote_gemm_autotune_cache:
+        # Purposely not awaiting the future here - this kicks off the best config lookup at lowering time
+        # The future will be awaited at scheduling time in select_algorithm.py
+        best_config_future = gen_best_config(mat1, mat2)
+
+    if box := distributed_autotune.maybe_autotune_remote(
+        name, choices, kernel_inputs.nodes(), layout
+    ):
+        return box
+
+    return autotune_select_algorithm(
+        name,
+        choices,
+        kernel_inputs.nodes(),
+        layout,
+        best_config_future=best_config_future,
+    )
+
+
+@register_lowering(aten._int_mm, type_promotion_kind=None)
+def tuned_int_mm(mat1, mat2, *, layout=None):
+    # TODO(coconutruben): integrate into MMKernelInputs when all callsites use that
+    m, n, k, layout, mat1, mat2 = mm_args(
+        mat1, mat2, layout=layout, out_dtype=torch.int32
+    )
+    name = "int_mm"
+    # below is for getting an overview logging info of inductor mms
+    counters["aten_mm_info"][f"aten._int_mm_{m}_{n}_{k}"] += 1
+    log.info(
+        "Tuned aten._int_mm: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s",
+        m,
+        n,
+        k,
+        mat1.get_dtype(),
+        mat2.get_dtype(),
+        layout,
+    )
+
+    static_shape, is_nonzero = _is_static_problem(layout)
+    use_cutlass = static_shape and is_nonzero and use_cutlass_template(layout, m, n, k)
+    choices: list[ChoiceCaller] = []
+
+    # Create MMKernelInputs for Int MM
+    kernel_inputs = MMKernelInputs([mat1, mat2], out_dtype=torch.int32)
+
+    # Collect all templates for unified call
+    templates_to_use: list[Union[ExternKernelChoice, KernelTemplate]] = []
+    if use_aten_gemm_kernels():
+        templates_to_use.append(aten__int_mm)
+
+    if is_nonzero and use_triton_template(
+        layout, enable_int32=True, check_max_autotune=False
+    ):
+        templates_to_use.append(mm_template)
+
+    # Single unified call for all templates
+    choices.extend(
+        V.choices.get_template_configs(kernel_inputs, templates_to_use, name)
+    )
+
+    if use_cutlass and _use_cutlass_for_op(name):
+        CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
+            choices, layout, kernel_inputs.nodes(), fuseable=True, non_fuseable=True
+        )
+
+    return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout)
+
+
+@register_lowering(aten.addmm, type_promotion_kind=None)
+def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
+    """
+    Lowering for autotuning aten.addmm with different backends (Aten, Triton, CUTLASS, etc.)
+    """
+    if use_native_matmul(mat1, mat2):
+        if beta == 0:
+            arg1 = 0
+        else:
+            arg1 = lowerings[aten.mul](beta, inp)
+
+        if alpha == 0:
+            arg2 = 0
+        else:
+            arg2 = lowerings[aten.mul](alpha, lowerings[aten.mm](mat1, mat2))
+
+        return lowerings[aten.add](arg1, arg2)
+
+    # TODO(coconutruben): integrate into MMKernelInputs when all callsites use that
+    m, n, k, layout, mat1, mat2, inp_expanded = mm_args(mat1, mat2, inp, layout=layout)
+    static_shape, is_nonzero = _is_static_problem(layout)
+    name = "addmm"
+
+    # Create MMKernelInputs for AddMM at the top
+    kernel_inputs = MMKernelInputs(
+        [inp_expanded, mat1, mat2], scalars=dict(alpha=alpha, beta=beta)
+    )
+    choices: list[ChoiceCaller] = []
+
+    # below is for getting an overview logging info of inductor mms
+    counters["aten_mm_info"][f"aten.addmm_{m}_{n}_{k}"] += 1
+    log.info(
+        "Tuned aten.addmm: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s",
+        m,
+        n,
+        k,
+        mat1.get_dtype(),
+        mat2.get_dtype(),
+        layout,
+    )
+    if (not is_nonzero) or (
+        not (inductor_config.max_autotune or inductor_config.max_autotune_gemm)
+    ):
+        # TODO(coconutruben): combine this with the main flow of addmm through
+        # a subgraph or something as inp vs inp_expanded causes some slight numeric
+        # differences
+        kernel_inputs = MMKernelInputs(
+            [inp, mat1, mat2], scalars=dict(alpha=alpha, beta=beta)
+        )
+        choices.extend(
+            V.choices.get_template_configs(
+                kernel_inputs,
+                [aten_addmm],
+                name,
+            )
+        )
+        return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout)
+
+    # Collect all templates for unified call
+    templates_to_use: list[Union[ExternKernelChoice, KernelTemplate]] = []
+    if use_aten_gemm_kernels():
+        templates_to_use.extend([aten_bias_addmm, aten_addmm])
+
+    if is_nonzero and use_triton_template(layout, check_max_autotune=False):
+        templates_to_use.append(mm_template)
+
+        if use_triton_tma_template(mat1, mat2, output_layout=layout):
+            templates_to_use.append(persistent_tma_mm_template)
+
+        if use_triton_blackwell_tma_template(mat1, mat2, output_layout=layout):
+            templates_to_use.append(blackwell_ws_persistent_device_tma_mm_template)
+
+        templates_to_use.append(addmm_contiguous_subgraph_template)
+
+    # Single unified call for all templates
+    choices.extend(
+        V.choices.get_template_configs(kernel_inputs, templates_to_use, name)
+    )
+
+    if (
+        is_nonzero
+        and use_cutlass_template(layout, m, n, k)
+        and _use_cutlass_for_op(name)
+    ):
+        CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
+            choices,
+            layout,
+            # reorder here because CUTLASS expects (x, w, bias) but torch
+            # is bias, x, w
+            kernel_inputs.nodes(reorder=[1, 2, 0]),
+            alpha=alpha,
+            beta=beta,
+        )
+
+    if is_nonzero and use_ck_gemm_template(layout, m, n, k):
+        CKGemmTemplate.add_ck_gemm_choices(
+            choices,
+            layout,
+            # reorder here because CK expects (x, w, bias) but torch
+            # is bias, x, w
+            kernel_inputs.nodes(reorder=[1, 2, 0]),
+            alpha=alpha,
+            beta=beta,
+            input_reorder=[2, 0, 1],
+        )
+
+    if use_cpp_gemm_template(layout, mat1, mat2):
+        CppGemmTemplate.add_choices(
+            choices,
+            layout,
+            kernel_inputs.nodes(),
+            alpha=alpha,
+            beta=beta,
+            has_bias=True,
+        )
+
+    return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout)
+
+
+@register_lowering(aten._sparse_semi_structured_mm, type_promotion_kind=None)
+def tuned_sparse_semi_structured_mm(
+    mat1, mat1_meta, mat2, *, out_dtype=None, layout=None
+):
+    from torch._inductor.select_algorithm import realize_inputs
+
+    # TODO(coconturuben): support V.choices.get_mm_configs for sparse_semi_structured_mm
+    mat1, mat1_meta, mat2 = realize_inputs(mat1, mat1_meta, mat2)
+    m1, k1 = mat1.get_size()
+    m2, _ = mat1_meta.get_size()
+    k2, n = mat2.get_size()
+    m = V.graph.sizevars.check_equals_and_simplify(m1, m2)
+    k = V.graph.sizevars.check_equals_and_simplify(2 * k1, k2)
+    if layout is None:
+        from torch._inductor.ir import FixedLayout
+
+        layout = FixedLayout(
+            mat2.get_device(),
+            out_dtype if out_dtype else mat2.get_dtype(),
+            [m, n],
+            [n, 1],
+        )
+    else:
+        assert out_dtype is None, "out_dtype is ignored if layout is specified."
+
+    choices = (
+        [
+            aten__sparse_semi_structured_mm.bind(
+                (mat1, mat1_meta, mat2), layout, out_dtype=out_dtype
+            )
+        ]
+        if use_aten_gemm_kernels()
+        else []
+    )
+
+    if (
+        m * n != 0
+        and use_cutlass_template(layout, m, n, k)
+        and _use_cutlass_for_op("sparse_semi_structured_mm")
+    ):
+        CUTLASS2xGemmTemplate.add_cutlass_gemm_choices(
+            choices, layout, [mat1, mat2, mat1_meta], fuseable=True, non_fuseable=True
+        )
+
+    return autotune_select_algorithm(
+        "sparse_semi_structured_mm", choices, (mat1, mat1_meta, mat2), layout
+    )
+
+
+scaling_pairs = [
+    (ScalingType.TensorWise, ScalingType.TensorWise),
+    (ScalingType.RowWise, ScalingType.RowWise),
+    (ScalingType.BlockWise1x128, ScalingType.BlockWise128x128),
+    (ScalingType.BlockWise1x128, ScalingType.BlockWise1x128),
+]
+
+
+epilogue_scaling_types = [ScalingType.TensorWise, ScalingType.RowWise]
+main_loop_scaling_types = [ScalingType.BlockWise1x128, ScalingType.BlockWise128x128]
+
+
+def _is_tensorwise_scaling(sz: Any) -> bool:
+    return (len(sz) == 0) or all(
+        V.graph.sizevars.statically_known_equals(d, 1) for d in sz
+    )
+
+
+def _is_rowwise_scaling(sz: Any, transpose: bool) -> bool:
+    idx = 0 if transpose else -1
+    return V.graph.sizevars.statically_known_equals(sz[idx], 1)
+
+
+def _is_blockwise1xTILESIZE_scaling(
+    sz: Any, tensor_sz: Any, tile_size: int, transpose: bool
+) -> bool:
+    lhs = 1 if transpose else 0
+    rhs = 0 if transpose else 1
+    return V.graph.sizevars.statically_known_equals(
+        sz[lhs], tensor_sz[lhs]
+    ) and V.graph.sizevars.statically_known_equals(
+        sz[rhs], ceildiv(tensor_sz[rhs], tile_size)
+    )
+
+
+def _is_blockwise128x128_scaling(sz: Any, tensor_sz: Any) -> bool:
+    return V.graph.sizevars.statically_known_equals(
+        sz[0], ceildiv(tensor_sz[0], 128)
+    ) and V.graph.sizevars.statically_known_equals(sz[1], ceildiv(tensor_sz[1], 128))
+
+
+def is_desired_scaling(
+    t: Any,
+    scale_size: torch.Tensor,
+    scaling_type: ScalingType,
+    transpose: bool = False,
+) -> bool:
+    match scaling_type:
+        case ScalingType.TensorWise:
+            return _is_tensorwise_scaling(scale_size)
+        case ScalingType.RowWise:
+            return _is_rowwise_scaling(scale_size, transpose)
+        case ScalingType.BlockWise1x128:
+            return _is_blockwise1xTILESIZE_scaling(
+                scale_size, t.get_size(), 128, transpose
+            )
+        case ScalingType.BlockWise128x128:
+            return _is_blockwise128x128_scaling(scale_size, t.get_size())
+        case _:
+            raise AssertionError(f"Unsupported scaling type {scaling_type}")
+
+
+def get_tile_size(scale_option) -> int:
+    match scale_option:
+        case ScalingType.BlockWise128x128:
+            return 128
+        case ScalingType.BlockWise1x128:
+            return 128
+        case _:
+            raise AssertionError(
+                f"Unsupported scaling type {scale_option} in get_tile_size"
+            )
+
+
+def get_scaling_options(
+    mat_a: Any,
+    mat_b: Any,
+    scale_a_size: torch.Tensor,
+    scale_b_size: torch.Tensor,
+) -> tuple[ScalingType, ScalingType]:
+    for scale_option_a, scale_option_b in scaling_pairs:
+        if is_desired_scaling(
+            mat_a, scale_a_size, scale_option_a
+        ) and is_desired_scaling(mat_b, scale_b_size, scale_option_b, transpose=True):
+            return scale_option_a, scale_option_b
+
+    raise AssertionError(
+        f"Inductor Triton does not support scale_a.shape = {scale_a_size}, scale_b.shape = {scale_b_size}"
+    )  # verify that shapes are supported by at least one existing pairing
+
+
+@register_lowering(aten._scaled_mm.default, type_promotion_kind=None)  # type: ignore[misc]
+def tuned_scaled_mm(
+    mat_a,
+    mat_b,
+    scale_a,
+    scale_b,
+    bias=None,
+    scale_result=None,
+    out_dtype=None,
+    use_fast_accum=False,
+    layout=None,
+):
+    """
+    Performs an optimized matrix multiplication where scaling factors are applied
+    to the inputs and/or output.
+
+    Args:
+        mat1 (Tensor): First input matrix
+        mat2 (Tensor): Second input matrix
+        scale1 (Tensor): Scale factor applied to mat1 (supports broadcasting)
+        scale2 (Tensor): Scale factor applied to mat2 (supports broadcasting)
+        bias (Tensor, optional): Optional bias tensor to add to the result
+        layout: Layout hint for optimization
+
+    Returns:
+        Tensor: The result of the scaled matrix multiplication
+    """
+    # TODO(coconutruben): integrate into MMKernelInputs when all callsites use that
+    m, n, k, layout, mat_a, mat_b = mm_args(
+        mat_a, mat_b, layout=layout, out_dtype=out_dtype
+    )
+    # below is for getting an overview logging info of inductor mms
+    counters["aten_mm_info"][f"aten._scaled_mm.default_{m}_{n}_{k}"] += 1
+    log.info(
+        "Tuned aten._scaled_mm.default: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s",
+        m,
+        n,
+        k,
+        mat_a.get_dtype(),
+        mat_b.get_dtype(),
+        layout,
+    )
+    name = "scaled_mm"
+    check_supported_striding(mat_a, mat_b)
+
+    scale_a_real, scale_b_real = realize_inputs(scale_a, scale_b)
+
+    input_nodes: list[Any]
+
+    if not bias:
+        input_nodes = [mat_a, mat_b, scale_a_real, scale_b_real]
+    else:
+        bias_real = realize_inputs(bias)
+        input_nodes = [mat_a, mat_b, scale_a_real, scale_b_real, bias_real]
+
+    # Create MMKernelInputs for Scaled MM (matrices are at indices 0, 1)
+    kernel_inputs = MMKernelInputs(
+        input_nodes, mat1_idx=0, mat2_idx=1, out_dtype=out_dtype
+    )
+
+    choices: list[ChoiceCaller] = []
+
+    # Collect all templates for unified call
+    templates_to_use: list[Union[ExternKernelChoice, KernelTemplate]] = []
+    kwarg_overrides = {}
+
+    if use_aten_gemm_kernels():
+        templates_to_use.append(aten__fp8_mm)
+        kwarg_overrides[aten__fp8_mm.uid] = dict(
+            out_dtype=out_dtype, use_fast_accum=use_fast_accum
+        )
+
+    _, is_nonzero = _is_static_problem(layout)
+
+    if (
+        # We dont have triton lowerings for the MX variants yet
+        scale_a.dtype == torch.float32
+        and is_nonzero
+        and use_triton_template(layout, enable_float8=True, check_max_autotune=False)
+    ):
+        overriders = dict(USE_FAST_ACCUM=use_fast_accum)
+
+        # TODO (paulzhan): There is no template that exists for bias and TMA
+        # Don't run tma template currently if bias exist
+        if use_triton_tma_template(mat_a, mat_b, output_layout=layout) and not bias:
+            scale_a_size, scale_b_size = scale_a_real.shape, scale_b_real.shape
+
+            scale_option_a, scale_option_b = get_scaling_options(
+                mat_a, mat_b, scale_a_size, scale_b_size
+            )
+            overriders["SCALE_RECIPE_A"] = scale_option_a.value
+            overriders["SCALE_RECIPE_B"] = scale_option_b.value
+
+            if (
+                scale_option_a in epilogue_scaling_types
+                and scale_option_b in epilogue_scaling_types
+            ):
+                templates_to_use.append(scaled_mm_device_tma_epilogue_scaling_template)
+                kwarg_overrides[scaled_mm_device_tma_epilogue_scaling_template.uid] = (
+                    overriders
+                )
+            elif (
+                scale_option_a in main_loop_scaling_types
+                and scale_option_b in main_loop_scaling_types
+            ):
+                overriders["TILE_SIZE_A"] = get_tile_size(scale_option_a)
+                overriders["TILE_SIZE_B"] = get_tile_size(scale_option_b)
+
+                templates_to_use.append(scaled_mm_device_tma_main_loop_scaling_template)
+                kwarg_overrides[scaled_mm_device_tma_main_loop_scaling_template.uid] = (
+                    overriders
+                )
+            else:
+                raise AssertionError(
+                    "Inductor Triton does not support scaling options that are present "
+                    + "in both epilogue scaling and main loop scaling"
+                )
+
+        if (
+            use_triton_blackwell_tma_template(mat_a, mat_b, output_layout=layout)
+            and not bias
+        ):
+            templates_to_use.append(blackwell_ws_persistent_device_tma_mm_template)
+            kwarg_overrides[blackwell_ws_persistent_device_tma_mm_template.uid] = (
+                overriders
+            )
+
+        templates_to_use.append(mm_template)
+        kwarg_overrides[mm_template.uid] = overriders
+
+    # Single unified call for all templates
+    choices.extend(
+        V.choices.get_template_configs(
+            kernel_inputs,
+            templates_to_use,
+            name,
+            kwarg_overrides=kwarg_overrides,
+        )
+    )
+
+    # Early return for MX variants
+    if scale_a.dtype != torch.float32:
+        return autotune_select_algorithm(name, choices, input_nodes, layout)
+
+    if (
+        is_nonzero
+        and use_cutlass_template(layout, m, n, k)
+        and _use_cutlass_for_op(name)
+    ):
+        CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
+            choices,
+            layout,
+            kernel_inputs.nodes(),  # type: ignore[arg-type]
+            use_fast_accum=use_fast_accum,  # type: ignore[arg-type]
+        )
+
+    if is_nonzero and use_ck_gemm_template(layout, m, n, k):
+        CKGemmTemplate.add_ck_gemm_choices(choices, layout, kernel_inputs.nodes())
+
+    return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout)
+
+
+@functools.cache
+def _is_sm7x_or_older_gpu(index: Optional[int]) -> bool:
+    props = torch.cuda.get_device_properties(index or 0)
+    return props.major <= 7
+
+
+def dims_are_int(dims):
+    return all(isinstance(dim, int) for dim in dims)
+
+
+def mm_autoheuristic(
+    mat1,
+    mat2,
+    m,
+    n,
+    k,
+    choices,
+    name,
+    input_nodes,
+    ops,
+    precondition,
+    top_k: Optional[int] = None,
+    always_included=None,
+):
+    m, n, k = get_size_hints(mat1, mat2, m, n, k)
+    if not dims_are_int([m, n, k]):
+        return None
+    mat1_stride, mat2_stride = get_size_hints_strides(mat1, mat2)
+
+    def get_context(m, k, n, mat1, mat2, mat1_stride, mat2_stride):
+        context = AHContext()
+        context.add_feature("m", m)
+        context.add_feature("k", k)
+        context.add_feature("n", n)
+        context.add_feature("mat1_dtype", mat1.layout.dtype, is_categorical=True)
+        context.add_feature("mat2_dtype", mat2.layout.dtype, is_categorical=True)
+        context_add_strides(context, "mat1", mat1_stride)
+        context_add_strides(context, "mat2", mat2_stride)
+        context.add_feature(
+            "mat1_iscontig", mat1.layout.is_contiguous(), is_categorical=True
+        )
+        context.add_feature(
+            "mat2_iscontig", mat2.layout.is_contiguous(), is_categorical=True
+        )
+        if name == "mm":
+            context_add_using_tf32(context, mat1.layout.dtype)
+        return context
+
+    def fallback():
+        return None
+
+    context = get_context(m, k, n, mat1, mat2, mat1_stride, mat2_stride)
+    autoheuristic = AutoHeuristicSelectAlgorithm(
+        fallback=fallback,
+        choices=choices,
+        input_nodes=input_nodes,
+        context=context,
+        name=name,
+        augment_context=ops,
+        precondition=precondition,
+    )
+
+    if top_k is not None:
+        # TODO: is there a cleaner way to ensure aten.mm is always included?
+        return autoheuristic.get_top_k_choices_caller(
+            top_k, always_included=always_included
+        )
+
+    return autoheuristic.get_choice_caller()
+
+
+def get_size_hints(mat1, mat2, m, n, k):
+    if not isinstance(m, int) or not isinstance(k, int):
+        (m, k) = V.graph.sizevars.size_hints(
+            mat1.get_size(),
+            fallback=torch._inductor.config.unbacked_symint_fallback,
+        )
+
+    if not isinstance(n, int) or not isinstance(k, int):
+        (k, n) = V.graph.sizevars.size_hints(
+            mat2.get_size(),
+            fallback=torch._inductor.config.unbacked_symint_fallback,
+        )
+    return m, n, k
+
+
+def get_size_hints_strides(mat1, mat2):
+    mat1_stride = mat1.layout.stride
+    mat2_stride = mat2.layout.stride
+    strides = [mat1_stride, mat2_stride]
+    strides_hints = []
+    for stride in strides:
+        if not isinstance(stride, int):
+            stride = V.graph.sizevars.size_hints(
+                stride,
+                fallback=torch._inductor.config.unbacked_symint_fallback,
+            )
+        strides_hints.append(stride)
+    return strides_hints[0], strides_hints[1]
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/mm_common.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/mm_common.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb22b95af2afcef65cb4876d9c9685633e5bde70
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/mm_common.py
@@ -0,0 +1,263 @@
+# mypy: allow-untyped-defs
+import logging
+from collections.abc import Sequence
+from functools import partial
+from pathlib import Path
+from typing import Any
+
+import torch
+from torch._inductor.select_algorithm import realize_inputs, SymbolicGridFn
+from torch._inductor.utils import get_current_backend, sympy_product
+from torch._inductor.virtualized import V
+from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols
+
+from .. import config
+from ..codegen.wrapper import PythonWrapperCodegen
+from ..ir import _IntLike, Layout, TensorBox
+from ..utils import load_template
+
+
+log = logging.getLogger(__name__)
+
+
+@SymbolicGridFn
+def mm_grid(m, n, meta, *, cdiv):
+    """
+    The CUDA grid size for matmul triton templates.
+    """
+    return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), 1, 1)
+
+
+@SymbolicGridFn
+def persistent_mm_grid(M: int, N: int, meta: dict[str, Any], *, cdiv, min):
+    """Defines the grid for persistent kernels."""
+    return (
+        min(meta["NUM_SMS"], cdiv(M, meta["BLOCK_M"]) * cdiv(N, meta["BLOCK_N"])),
+        1,
+        1,
+    )
+
+
+@SymbolicGridFn
+def persistent_grouped_mm_grid(*args):
+    meta = args[-1]
+    return (meta["NUM_SMS"], 1, 1)
+
+
+def acc_type(dtype):
+    if dtype in (torch.float16, torch.bfloat16):
+        return "tl.float32"
+    return f"tl.{dtype}".replace("torch.", "")
+
+
+def mm_args(
+    mat1,
+    mat2,
+    *others,
+    layout=None,
+    out_dtype=None,
+    use_4x2_dim=False,
+    mat2_transposed=False,
+):
+    """
+    Common arg processing for mm,bmm,addmm,etc
+    """
+    mat1, mat2 = realize_inputs(mat1, mat2)
+    *b1, m, k1 = mat1.get_size()
+    if mat2_transposed:
+        *b2, n, k2 = mat2.get_size()
+    else:
+        *b2, k2, n = mat2.get_size()
+    b = [V.graph.sizevars.check_equals_and_simplify(a, b) for a, b in zip(b1, b2)]
+    if use_4x2_dim:
+        k2 = k2 * 2
+    k = V.graph.sizevars.check_equals_and_simplify(k1, k2)
+    if layout is None:
+        from torch._inductor.ir import FixedLayout
+
+        if out_dtype is None:
+            out_dtype = mat1.get_dtype()
+
+        layout = FixedLayout(
+            mat1.get_device(),
+            out_dtype,
+            [*b, m, n],
+        )
+    else:
+        assert out_dtype is None, "out_dtype is ignored if layout is specified."
+    from ..lowering import expand
+
+    others = [realize_inputs(expand(x, layout.size)) for x in others]
+
+    return [m, n, k, layout, mat1, mat2, *others]
+
+
+def addmm_epilogue(dtype, alpha, beta):
+    def epilogue(acc, bias):
+        if alpha != 1:
+            acc = V.ops.mul(acc, V.ops.constant(alpha, dtype))
+        if beta != 1:
+            bias = V.ops.mul(bias, V.ops.constant(beta, dtype))
+        return V.ops.add(acc, bias)
+
+    return epilogue
+
+
+def scale_mm_epilogue():
+    """
+    Create an epilogue function that applies scaling to matrix multiplication result
+    using the given scale factors.
+
+    Args:
+        dtype: The data type of the output
+        scale_a: Scale factor for matrix A
+        scale_b: Scale factor for matrix B
+
+    Returns:
+        Epilogue function that takes the accumulator and applies scaling
+    """
+
+    def epilogue(acc, inv_a_scale, inv_b_scale, bias=None):
+        # The epilogue function receives the accumulator (result of mat1 @ mat2)
+        # and applies the scaling factors
+        # In the original scaled_mm, we use inverse scales, so we multiply by them
+        mul_scales = V.ops.mul(inv_a_scale, inv_b_scale)
+        mul_acc = V.ops.mul(acc, mul_scales)
+        if bias is not None:
+            return V.ops.add(mul_acc, bias)
+        else:
+            return mul_acc
+
+    return epilogue
+
+
+def use_native_matmul(mat1, mat2):
+    if not config.triton.native_matmul:
+        return False
+
+    # If tma matmul is on, don't do native matmul
+    if (
+        config.triton.enable_persistent_tma_matmul
+        and torch.utils._triton.has_triton_tma_device()
+    ):
+        raise AssertionError("native matmul doesn't support tma codegen yet")
+
+    # Currently only enable native matmul for default indexing
+    # TODO : support block ptr
+    if config.triton.use_block_ptr:
+        raise AssertionError("native matmul doesn't support block_ptr codegen yet")
+
+    # Currently only enable native matmul for triton on GPU.
+    device_type = mat1.get_device().type
+    if not (
+        device_type in ("cuda", "xpu") and get_current_backend(device_type) == "triton"
+    ):
+        return False
+
+    # Currently, tl.dot only supports following dtypes
+    triton_supported_dtype = [
+        torch.int8,
+        torch.uint8,
+        torch.float16,
+        torch.bfloat16,
+        torch.float32,
+    ]
+    if mat1.dtype not in triton_supported_dtype:
+        return False
+    if mat2.dtype not in triton_supported_dtype:
+        return False
+
+    # (..., M, K) @ (..., K, N)
+    m, k, n = mat1.get_size()[-2], mat1.get_size()[-1], mat2.get_size()[-1]
+
+    # If the shape has unbacked symbols, don't do native matmul.
+    # This is related to the behavior of statically_known_multiple_of on unbacked symints.
+    # Since statically_known_multiple_of just returns False for unbacked symbols
+    # due to the expensive cost, codegen fails when there is a unbacked symbol.
+    # In particular, it fails at _split_iteration_ranges in codegen/simd.py.
+    # See this : https://github.com/pytorch/pytorch/pull/131649
+    if any(map(has_free_unbacked_symbols, [m, k, n])):
+        return False
+
+    # Consider the shape (m,k,n) > 1
+    # TODO : support when size = 1
+    if (
+        V.graph.sizevars.statically_known_leq(m, 1)
+        or V.graph.sizevars.statically_known_leq(k, 1)
+        or V.graph.sizevars.statically_known_leq(n, 1)
+    ):
+        return False
+
+    return True
+
+
+def _is_static_problem(layout: Layout) -> tuple[bool, bool]:
+    """
+    Check if input tensors and output layout have static shapes and non-zero sizes.
+
+    Args:
+        layout: Output layout object with a 'size' attribute.
+
+    Returns:
+        Tuple[bool, bool]: (is_static, is_nonzero)
+            is_static: True if all shapes are statically known
+            is_nonzero: True if all dimensions are non-zero
+    """
+    static_shape = True
+    static_size = PythonWrapperCodegen.statically_known_list_of_ints_or_none(
+        layout.size
+    )
+    if static_size is None:
+        nonzero = True
+        for s in layout.size:
+            sz = PythonWrapperCodegen.statically_known_int_or_none(s)
+            if sz is not None and sz == 0:
+                nonzero = False
+                break
+        return False, nonzero
+    numel = 1
+    for dim in static_size:
+        numel *= dim
+    nonzero = numel > 0
+    return static_shape, nonzero
+
+
+def check_supported_striding(mat_a: TensorBox, mat_b: TensorBox) -> None:
+    def is_row_major(stride: Sequence[_IntLike]) -> bool:
+        return stride[-1] == 1
+
+    def is_col_major(stride: Sequence[_IntLike]) -> bool:
+        return stride[-2] == 1
+
+    def has_zero_dim(size: Sequence[_IntLike]) -> bool:
+        return bool(size[0] == 0 or size[1] == 0)
+
+    # Check mat_a (self) stride requirements
+    torch._check(
+        is_row_major(mat_a.get_stride()) or has_zero_dim(mat_a.get_size()),
+        lambda: f"mat_a must be row_major, got stride {mat_a.get_stride()}",
+    )
+
+    # Check mat_b stride requirements
+    torch._check(
+        is_col_major(mat_b.get_stride()) or has_zero_dim(mat_b.get_size()),
+        lambda: f"mat_b must be col_major, got stride {mat_b.get_stride()}",
+    )
+
+
+def is_batch_stride_largest_or_zero(mat1, mat2, layout) -> bool:
+    """
+    Checking if the batch stride is the largest in the stride.
+    """
+    sizes = [mat1.get_size(), mat2.get_size(), layout.size]
+    strides = [mat1.get_stride(), mat2.get_stride(), layout.stride]
+    for size, stride in zip(sizes, strides):
+        assert len(size) == len(stride) == 3, "Expect 3D tensors"
+        if stride[0] != 0 and stride[0] != sympy_product(size[1:]):
+            return False
+
+    return True
+
+
+_KERNEL_TEMPLATE_DIR = Path(__file__).parent / "templates"
+load_kernel_template = partial(load_template, template_dir=_KERNEL_TEMPLATE_DIR)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/mm_grouped.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/mm_grouped.py
new file mode 100644
index 0000000000000000000000000000000000000000..35ee6a541c15079d32f8291ee57e7e3909956cb4
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/mm_grouped.py
@@ -0,0 +1,891 @@
+# mypy: allow-untyped-defs
+import logging
+from dataclasses import asdict, dataclass
+from typing import Any, Optional
+
+import torch
+from torch._dynamo.utils import counters
+from torch._inductor.codegen.cutedsl.cutedsl_template import CuteDSLTemplate
+from torch._inductor.runtime.triton_compat import tl
+from torch._inductor.template_heuristics.cutedsl import get_groupgemm_configs
+from torch._inductor.virtualized import V
+from torch.utils._triton import has_triton
+
+from ..ir import ChoiceCaller, Layout, TensorBox
+from ..lowering import register_lowering
+from ..select_algorithm import (
+    autotune_select_algorithm,
+    ExternKernelChoice,
+    realize_inputs,
+    TritonTemplate,
+)
+from ..utils import (
+    get_gpu_shared_memory,
+    get_num_sms,
+    has_free_symbols,
+    use_aten_gemm_kernels,
+    use_blackwell_cutedsl_grouped_mm,
+    use_triton_template,
+)
+from .mm_common import (
+    _is_static_problem,
+    check_supported_striding,
+    load_kernel_template,
+    persistent_grouped_mm_grid,
+)
+
+
+log = logging.getLogger(__name__)
+aten = torch.ops.aten
+
+
+@dataclass
+class Config:
+    kwargs: dict[str, int]
+    num_stages: int
+    num_warps: int
+
+
+_NV_CONFIGS = [
+    Config(
+        {
+            "BLOCK_M": block_size_m,
+            "BLOCK_N": block_size_n,
+            "BLOCK_K": block_size_k,
+            "NUM_CONSUMER_GROUPS": 1,
+        },
+        num_stages=num_stages,
+        num_warps=num_warps,
+    )
+    for block_size_m in [16, 32, 64, 128]
+    for block_size_n in [64, 128, 256]
+    for block_size_k in [64, 128, 256]
+    for num_stages in [3, 4]
+    for num_warps in [4, 8]
+]
+
+
+def grouped_mm_configs():
+    return _NV_CONFIGS
+
+
+def early_config_prune(g, m, dtsize, configs, named_args):
+    pruned_configs = []
+    for config in configs:
+        kw = config.kwargs
+        BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps, num_consumer_groups = (
+            kw["BLOCK_M"],
+            kw["BLOCK_N"],
+            kw["BLOCK_K"],
+            config.num_stages,
+            config.num_warps,
+            getattr(config, "num_consumer_groups", 0),
+        )
+
+        # 1. Prune NV configs depending on g and m.
+        if not has_free_symbols((g, m)):
+            a_is_2d, b_is_2d = named_args["A_IS_2D"], named_args["B_IS_2D"]
+            m_avg = m // g if a_is_2d and not b_is_2d else m
+            if m_avg <= 16:
+                if BLOCK_M > 32:
+                    continue
+            elif m_avg <= 32:
+                if BLOCK_M > 64:
+                    continue
+            elif m_avg <= 64:
+                if BLOCK_M <= 16:
+                    continue
+            else:
+                if BLOCK_M <= 32:
+                    continue
+
+        # 2. make sure we have enough smem
+        max_shared_memory = get_gpu_shared_memory()
+
+        required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
+        if required_shared_memory > max_shared_memory:
+            continue
+
+        use_warp_specialization = num_consumer_groups >= 1
+
+        # 3. make sure we can partition for ws
+        if use_warp_specialization:
+            if num_warps != 4:
+                continue
+
+            # "tritongpu-warp-spec-data-partition"
+            m_slice = BLOCK_M // num_consumer_groups
+            n_slice = BLOCK_N // num_consumer_groups
+            if m_slice < 64 and n_slice < 256:
+                continue
+
+        pruned_configs.append(config)
+
+    return pruned_configs
+
+
+triton_grouped_mm_source = r"""
+{% macro assign_maybe_constexpr(name, value_expr) -%}
+    {%- set value_str = value_expr | string -%}
+    {%- set sentinel = "__NOT_A_NUMBER__" -%}
+    {%- set as_int = value_str | int(default=sentinel) -%}
+    {%- set as_float = value_str | float(default=sentinel) -%}
+    {%- set is_constexpr = (as_int != sentinel) or (as_float != sentinel) -%}
+    {{ name }}{{ ": tl.constexpr" if is_constexpr else "" }} = {{ value_expr }}
+{%- endmacro %}
+
+import triton
+import triton.language as tl
+
+@triton.jit
+def do_tma_loads(
+    g, a_desc, b_desc, m_offset, n_offset, k_offset,
+    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+):
+{%- if A_IS_2D %}
+{%- if A_IS_K_MAJOR %}
+    a = a_desc.load([m_offset, k_offset])
+{%- else %}
+    a = a_desc.load([k_offset, m_offset])
+{%- endif %}
+{%- else %}
+{%- if A_IS_K_MAJOR %}
+    a = a_desc.load([g, m_offset, k_offset]).reshape(BLOCK_M, BLOCK_K)
+{%- else %}
+    a = a_desc.load([g, k_offset, m_offset]).reshape(BLOCK_K, BLOCK_M)
+{%- endif %}
+{%- endif %}
+{%- if B_IS_2D %}
+{%- if B_IS_K_MAJOR %}
+    b = b_desc.load([n_offset, k_offset])
+{%- else %}
+    b = b_desc.load([k_offset, n_offset])
+{%- endif %}
+{%- else %}
+{%- if B_IS_K_MAJOR %}
+    b = b_desc.load([g, n_offset, k_offset]).reshape(BLOCK_N, BLOCK_K)
+{%- else %}
+    b = b_desc.load([g, k_offset, n_offset]).reshape(BLOCK_K, BLOCK_N)
+{%- endif %}
+{%- endif %}
+
+    return (a, b)
+
+
+@triton.jit
+def do_mma(a, b, accumulator):
+{%- if USE_FAST_ACCUM %}
+{%- if A_IS_K_MAJOR and B_IS_K_MAJOR %}
+    accumulator = tl.dot(a, b.T, accumulator)
+{%- elif A_IS_K_MAJOR and not B_IS_K_MAJOR %}
+    accumulator = tl.dot(a, b, accumulator)
+{%- elif not A_IS_K_MAJOR and B_IS_K_MAJOR %}
+    accumulator = tl.dot(a.T, b.T, accumulator)
+{%- else %}
+    accumulator = tl.dot(a.T, b, accumulator)
+{%- endif %}
+{%- else %}
+{%- if A_IS_K_MAJOR and B_IS_K_MAJOR %}
+    accumulator += tl.dot(a, b.T)
+{%- elif A_IS_K_MAJOR and not B_IS_K_MAJOR %}
+    accumulator += tl.dot(a, b)
+{%- elif not A_IS_K_MAJOR and B_IS_K_MAJOR %}
+    accumulator += tl.dot(a.T, b.T)
+{%- else %}
+    accumulator += tl.dot(a.T, b)
+{%- endif %}
+{%- endif %}
+
+    return accumulator
+
+
+{%- if SCALED %}
+{%- if A_IS_2D or B_IS_2D %}
+{{def_kernel("a_ptr", "b_ptr", "scale_a_ptr", "scale_b_ptr", "offsets_ptr")}}
+{%- else %}
+{{def_kernel("a_ptr", "b_ptr", "scale_a_ptr", "scale_b_ptr")}}
+{%- endif %}
+{%- else %}
+{%- if A_IS_2D or B_IS_2D %}
+{{def_kernel("a_ptr", "b_ptr", "offsets_ptr")}}
+{%- else %}
+{{def_kernel("a_ptr", "b_ptr")}}
+{%- endif %}
+{%- endif %}
+    tidx = tl.program_id(0).to(INDEX_DTYPE)
+
+{%- set M_IS_VARYING = A_IS_2D and not B_IS_2D %}
+{%- set N_IS_VARYING = not A_IS_2D and B_IS_2D %}
+{%- set K_IS_VARYING = A_IS_2D and B_IS_2D %}
+
+{%- if A_IS_2D %}
+{%- if B_IS_2D %}
+    {{ assign_maybe_constexpr("G", size("offsets_ptr", 0)) }}
+{%- else %}
+    {{ assign_maybe_constexpr("G", size("b_ptr", 0)) }}
+{%- endif %}
+{%- else %}
+{%- if B_IS_2D %}
+    {{ assign_maybe_constexpr("G", size("a_ptr", 0)) }}
+{%- else %}
+    {{ assign_maybe_constexpr("G", size("a_ptr", 0)) }}
+{%- endif %}
+{%- endif %}
+
+    # the b_ptr tensor is given with its last two dims transposed, revert here
+
+    {{ assign_maybe_constexpr("M", size("a_ptr", -2)) }}
+    {{ assign_maybe_constexpr("N", size("b_ptr", -1)) }}
+    {{ assign_maybe_constexpr("K", size("a_ptr", -1)) }}
+
+    {{ assign_maybe_constexpr("A_STRIDE_M", stride("a_ptr", -2)) }}
+    {{ assign_maybe_constexpr("A_STRIDE_K", stride("a_ptr", -1)) }}
+{%- if not A_IS_2D %}
+    {{ assign_maybe_constexpr("A_STRIDE_G", stride("a_ptr", 0)) }}
+{%- if SCALED %}
+    {{ assign_maybe_constexpr("SCALE_A_STRIDE_G", stride("scale_a_ptr", 0)) }}
+{%- endif %}
+{%- endif %}
+    {{ assign_maybe_constexpr("B_STRIDE_N", stride("b_ptr", -1)) }}
+    {{ assign_maybe_constexpr("B_STRIDE_K", stride("b_ptr", -2)) }}
+{%- if not B_IS_2D %}
+    {{ assign_maybe_constexpr("B_STRIDE_G", stride("b_ptr", 0)) }}
+    B_STRIDE_G = {{stride("b_ptr", 0)}}
+{%- if SCALED %}
+    {{ assign_maybe_constexpr("SCALE_B_STRIDE_G", stride("scale_b_ptr", 0)) }}
+{%- endif %}
+{%- endif %}
+
+{%- if USE_TMA_LOAD %}
+{%- if USE_EXPERIMENTAL_MAKE_TENSOR_DESCRIPTOR %}
+    a_desc = tl._experimental_make_tensor_descriptor(
+{%- else %}
+    a_desc = tl.make_tensor_descriptor(
+{%- endif %}
+        a_ptr,
+{%- if A_IS_2D %}
+{%- if A_IS_K_MAJOR %}
+        shape=[M, K],
+        strides=[A_STRIDE_M, A_STRIDE_K],
+        block_shape=[BLOCK_M, BLOCK_K],
+{%- else %}
+        shape=[K, M],
+        strides=[A_STRIDE_K, A_STRIDE_M],
+        block_shape=[BLOCK_K, BLOCK_M],
+{%- endif %}
+{%- else %}
+{%- if A_IS_K_MAJOR %}
+        shape=[G, M, K],
+        strides=[A_STRIDE_G, A_STRIDE_M, A_STRIDE_K],
+        block_shape=[1, BLOCK_M, BLOCK_K],
+{%- else %}
+        shape=[G, K, M],
+        strides=[A_STRIDE_G, A_STRIDE_K, A_STRIDE_M],
+        block_shape=[1, BLOCK_K, BLOCK_M],
+{%- endif %}
+{%- endif %}
+    )
+
+{%- if USE_EXPERIMENTAL_MAKE_TENSOR_DESCRIPTOR %}
+    b_desc = tl._experimental_make_tensor_descriptor(
+{%- else %}
+    b_desc = tl.make_tensor_descriptor(
+{%- endif %}
+        b_ptr,
+{%- if B_IS_2D %}
+{%- if B_IS_K_MAJOR %}
+        shape=[N, K],
+        strides=[B_STRIDE_N, B_STRIDE_K],
+        block_shape=[BLOCK_N, BLOCK_K],
+{%- else %}
+        shape=[K, N],
+        strides=[B_STRIDE_K, B_STRIDE_N],
+        block_shape=[BLOCK_K, BLOCK_N],
+{%- endif %}
+{%- else %}
+{%- if B_IS_K_MAJOR %}
+        shape=[G, N, K],
+        strides=[B_STRIDE_G, B_STRIDE_N, B_STRIDE_K],
+        block_shape=[1, BLOCK_N, BLOCK_K],
+{%- else %}
+        shape=[G, K, N],
+        strides=[B_STRIDE_G, B_STRIDE_K, B_STRIDE_N],
+        block_shape=[1, BLOCK_K, BLOCK_N],
+{%- endif %}
+{%- endif %}
+    )
+{%- endif %}
+
+{%- if M_IS_VARYING %}
+    m_end_offset = 0
+{%- endif %}
+{%- if N_IS_VARYING %}
+    n_end_offset = 0
+{%- endif %}
+{%- if K_IS_VARYING %}
+    k_end_offset = 0
+{%- endif %}
+    iterated_tiles = 0
+    for g in tl.range(G):
+{%- if M_IS_VARYING %}
+        # Move across groups
+        m_start_offset = m_end_offset
+        m_end_offset = tl.load(offsets_ptr + g)
+        m_size = m_end_offset - m_start_offset
+{%- if SCALED %}
+        m_scale_start_offset = m_start_offset
+{%- endif %}
+{%- else %}
+        m_start_offset = 0
+        m_size = M
+{%- if SCALED %}
+        m_scale_start_offset = g * M
+{%- endif %}
+{%- endif %}
+
+{%- if N_IS_VARYING %}
+        # Move across groups
+        n_start_offset = n_end_offset
+        n_end_offset = tl.load(offsets_ptr + g)
+        n_size = n_end_offset - n_start_offset
+{%- if SCALED %}
+        n_scale_start_offset = n_start_offset
+{%- endif %}
+{%- else %}
+        n_start_offset = 0
+        n_size = N
+{%- if SCALED %}
+        n_scale_start_offset = g * N
+{%- endif %}
+{%- endif %}
+
+        if m_size > 0 and n_size > 0:
+{%- if K_IS_VARYING %}
+            # Move across groups
+            k_start_offset = k_end_offset
+            k_end_offset = tl.load(offsets_ptr + g)
+            k_size = k_end_offset - k_start_offset
+{%- else %}
+            k_start_offset = 0
+            k_size = K
+{%- endif %}
+
+            num_m_tiles = tl.cdiv(m_size, BLOCK_M)
+            num_n_tiles = tl.cdiv(n_size, BLOCK_N)
+            num_tiles = num_m_tiles * num_n_tiles
+
+            # Move across tiles
+            while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:
+                gidx = tidx - iterated_tiles
+                # Split M first and N second.
+                tile_m_idx = gidx % num_m_tiles
+                tile_n_idx = gidx // num_m_tiles
+
+                accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+{%- if USE_TMA_LOAD %}
+                m_tile_offset = tile_m_idx * BLOCK_M
+                n_tile_offset = tile_n_idx * BLOCK_N
+                m_offset = (m_start_offset + m_tile_offset).to(tl.int32)
+                n_offset = (n_start_offset + n_tile_offset).to(tl.int32)
+
+                k_block_offset = 0
+                for k in range(k_size // BLOCK_K):
+                    k_offset = k_start_offset + k_block_offset
+                    a, b = do_tma_loads(
+                        g, a_desc, b_desc, m_offset, n_offset, k_offset,
+                        BLOCK_M, BLOCK_N, BLOCK_K
+                    )
+                    accumulator = do_mma(a, b, accumulator)
+                    k_block_offset += BLOCK_K
+
+                if k_size % BLOCK_K != 0:
+                    k_offset = k_start_offset + k_block_offset
+                    a, b = do_tma_loads(
+                        g, a_desc, b_desc, m_offset, n_offset, k_offset,
+                        BLOCK_M, BLOCK_N, BLOCK_K
+                    )
+{%- if K_IS_VARYING %}
+                    group_offs = k_block_offset + tl.arange(0, BLOCK_K)
+                    k_mask = group_offs < k_size
+{%- if A_IS_K_MAJOR %}
+                    a = tl.where(k_mask[None, :], a, 0)
+{%- else %}
+                    a = tl.where(k_mask[:, None], a, 0)
+{%- endif %}
+{%- if B_IS_K_MAJOR %}
+                    b = tl.where(k_mask[None, :], b, 0)
+{%- else %}
+                    b = tl.where(k_mask[:, None], b, 0)
+{%- endif %}
+{%- endif %}
+                    accumulator = do_mma(a, b, accumulator)
+{%- else %}
+                offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
+                offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N)
+                for k_block_offset in range(0, k_size, BLOCK_K):
+                    block_offs_k = k_block_offset + tl.arange(0, BLOCK_K)
+                    offs_k = block_offs_k + k_start_offset
+                    a_ptrs = (
+                        a_ptr
+{%- if not A_IS_2D %}
+                        + g * A_STRIDE_G
+{%- endif %}
+                        + (m_start_offset + offs_am[:, None]) * A_STRIDE_M
+                        + offs_k[None, :] * A_STRIDE_K
+                    )
+                    b_ptrs = (
+                        b_ptr
+{%- if not B_IS_2D %}
+                        + g * B_STRIDE_G
+{%- endif %}
+                        + (n_start_offset + offs_bn[:, None]) * B_STRIDE_N
+                        + offs_k[None, :] * B_STRIDE_K
+                    )
+                    a_mask = (offs_am[:, None] < m_size) & (block_offs_k[None, :] < k_size)
+                    b_mask = (offs_bn[:, None] < n_size) & (block_offs_k[None, :] < k_size)
+                    a = tl.load(a_ptrs, mask=a_mask, other=tl.zeros((), dtype=a_ptrs.dtype.element_ty))
+                    b = tl.load(b_ptrs, mask=b_mask, other=tl.zeros((), dtype=b_ptrs.dtype.element_ty))
+{%- if USE_FAST_ACCUM %}
+                    accumulator = tl.dot(a, b.T, accumulator)
+{%- else %}
+                    accumulator += tl.dot(a, b.T)
+{%- endif %}
+                    a_ptrs += BLOCK_K
+                    b_ptrs += BLOCK_K
+{%- endif %}
+
+                offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
+                offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N)
+{%- if SCALED %}
+                scale_a = tl.load(
+                    scale_a_ptr
+{%- if A_IS_2D %}
+                    + m_scale_start_offset
+{%- else %}
+                    + g * SCALE_A_STRIDE_G
+{%- endif %}
+                    + offs_am[:, None],
+                    mask=offs_am[:, None] < m_size,
+                    other=tl.zeros((), dtype=scale_a_ptr.dtype.element_ty),
+                )
+                scale_b = tl.load(
+                    scale_b_ptr
+{%- if B_IS_2D %}
+                    + n_scale_start_offset
+{%- else %}
+                    + g * SCALE_B_STRIDE_G
+{%- endif %}
+                    + offs_bn[None, :],
+                    mask=offs_bn[None, :] < n_size,
+                    other=tl.zeros((), dtype=scale_b_ptr.dtype.element_ty),
+                )
+                c = accumulator.to(tl.float32) * scale_a * scale_b
+{%- else %}
+                c = accumulator.to(tl.float32)
+{%- endif %}
+
+{%- if M_IS_VARYING %}
+                idx_m = (m_start_offset + offs_am[:, None])
+{%- else %}
+                idx_m = offs_am[:, None]
+{%- endif %}
+{%- if N_IS_VARYING %}
+                idx_n = (n_start_offset + offs_bn[None, :])
+{%- else %}
+                idx_n = offs_bn[None, :]
+{%- endif %}
+                mask = (offs_am[:, None] < m_size) & (offs_bn[None, :] < n_size)
+{%- if M_IS_VARYING or N_IS_VARYING %}
+                {{store_output(("idx_m", "idx_n"), "c", "mask", indent_width=16, val_shape=("BLOCK_M", "BLOCK_N"))}}
+{%- else %}
+                {{store_output(("g", "idx_m", "idx_n"), "c", "mask", indent_width=16, val_shape=("BLOCK_M", "BLOCK_N"))}}
+{%- endif %}
+                tidx += NUM_SMS
+
+            iterated_tiles += num_tiles
+"""
+
+
+triton_grouped_mm_template = TritonTemplate(
+    name="grouped_mm",
+    grid=persistent_grouped_mm_grid,
+    source=triton_grouped_mm_source,
+)
+
+triton_scaled_grouped_mm_template = TritonTemplate(
+    name="scaled_grouped_mm",
+    grid=persistent_grouped_mm_grid,
+    source=triton_grouped_mm_source,
+)
+
+cutedsl_grouped_mm_template = CuteDSLTemplate(
+    name="grouped_gemm_cutedsl",
+    source=load_kernel_template("cutedsl_mm_grouped"),
+)
+
+
+def grouped_mm_args(
+    mat1: TensorBox,
+    mat2: TensorBox,
+    offs: Optional[TensorBox],
+    layout=None,
+    out_dtype=None,
+):
+    mat1, mat2 = realize_inputs(mat1, mat2)
+    if offs is not None:
+        realize_inputs(offs)
+    mat1_size = mat1.get_size()
+    mat2_size = mat2.get_size()
+
+    m1dim, m2dim = len(mat1_size), len(mat2_size)
+
+    assert m1dim == 2 or m1dim == 3
+    assert m2dim == 2 or m2dim == 3
+
+    if layout is None:
+        from torch._inductor.ir import FixedLayout
+
+        if out_dtype is None:
+            out_dtype = mat1.get_dtype()
+        alignment = 16 // out_dtype.itemsize
+
+        if m1dim == 2:
+            if m2dim == 2:
+                assert offs is not None
+                out_size = [offs.get_size()[0], mat1_size[0], mat2_size[1]]
+            else:
+                out_size = [mat1_size[0], mat2_size[-1]]
+        else:
+            if m2dim == 2:
+                out_size = [mat1_size[1], mat2_size[1]]
+            else:
+                out_size = [mat1_size[0], mat1_size[1], mat2_size[-1]]
+        size_padded = (out_size[-1] + alignment - 1) // alignment * alignment
+        if len(out_size) == 2:
+            out_stride = [size_padded, 1]
+        else:
+            out_stride = [out_size[1] * size_padded, size_padded, 1]
+
+        layout = FixedLayout(
+            mat1.get_device(),
+            out_dtype,
+            out_size,
+            out_stride,
+        )
+    else:
+        assert out_dtype is None, "out_dtype is ignored if layout is specified."
+
+    return (mat1_size, mat2_size, layout, mat1, mat2, offs)
+
+
+aten__grouped_mm = ExternKernelChoice(
+    torch._grouped_mm,
+    "at::_grouped_mm",
+    op_overload=aten._grouped_mm.default,
+    has_out_variant=False,
+)
+
+
+aten__scaled_grouped_mm = ExternKernelChoice(
+    torch._scaled_grouped_mm,
+    "at::_scaled_grouped_mm",
+    op_overload=aten._scaled_grouped_mm.default,
+    has_out_variant=False,
+)
+
+
+def can_use_triton_kernel(
+    mat_a: TensorBox,
+    mat_b: TensorBox,
+    offs: Optional[TensorBox],
+    bias: Optional[TensorBox],
+    scale_result: Optional[TensorBox],
+) -> bool:
+    if not (
+        torch.cuda.is_available()
+        and torch.cuda.get_device_capability() >= (9, 0)
+        and not torch.version.hip
+    ):
+        return False
+    if not has_triton():
+        return False
+
+    # The _grouped_mm()/_scaled_grouped_mm() operator do not support
+    # bias nor scale_result yet.
+    if bias is not None:
+        return False
+    if scale_result is not None:
+        return False
+
+    if len(mat_a.get_size()) == 2 or len(mat_b.get_size()) == 2:
+        return offs is not None
+    else:
+        return offs is None
+
+
+def create_offsets(x, m1_size, m2_size, offs_size):
+    m1_is_2d = len(m1_size) == 2
+    m2_is_2d = len(m2_size) == 2
+    if m1_is_2d:
+        if m2_is_2d:
+            k = V.graph.sizevars.size_hint(m1_size[1])
+            noffs = V.graph.sizevars.size_hint(offs_size[0])
+            step = k / noffs
+            return torch.linspace(
+                step, k, noffs, dtype=x.get_dtype(), device=x.get_device()
+            )
+
+        else:
+            m = V.graph.sizevars.size_hint(m1_size[0])
+            noffs = V.graph.sizevars.size_hint(offs_size[0])
+            step = m / noffs
+            return torch.linspace(
+                step, m, noffs, dtype=x.get_dtype(), device=x.get_device()
+            )
+    else:
+        if m2_is_2d:
+            n = V.graph.sizevars.size_hint(m2_size[0])
+            noffs = V.graph.sizevars.size_hint(offs_size[0])
+            step = n / noffs
+            return torch.linspace(
+                step, n, noffs, dtype=x.get_dtype(), device=x.get_device()
+            )
+        else:
+            return None
+
+
+def _tuned_grouped_mm_common(
+    operator_name: str,
+    algorithm_name: str,
+    extern_kernel_choice: ExternKernelChoice,
+    kernel_template: TritonTemplate,
+    mat_a: TensorBox,
+    mat_b: TensorBox,
+    scale_a: Optional[TensorBox] = None,
+    scale_b: Optional[TensorBox] = None,
+    offs: Optional[TensorBox] = None,
+    bias: Optional[TensorBox] = None,
+    scale_result: Optional[TensorBox] = None,
+    out_dtype: Optional[torch.dtype] = None,
+    use_fast_accum: Optional[bool] = None,
+    layout: Optional[Layout] = None,
+) -> TensorBox:
+    assert (scale_a is None) == (scale_b is None)
+    assert scale_result is None or scale_a is not None
+
+    m1_size, m2_size, layout, mat_a, mat_b, offs = grouped_mm_args(
+        mat_a, mat_b, offs, layout=layout, out_dtype=out_dtype
+    )
+    counters["aten_mm_info"][operator_name] += 1
+    log_message = f"Tuned {operator_name}: mat1_shape=%s, mat2_shape=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s"
+    log.info(
+        log_message,
+        m1_size,
+        m2_size,
+        mat_a.get_dtype(),
+        mat_b.get_dtype(),
+        layout,
+    )
+
+    if scale_a is not None and scale_b is not None:
+        check_supported_striding(mat_a, mat_b)
+
+    # workaround for Inductor not supporting optional tensor input arguments
+    input_nodes: list[Any] = [mat_a, mat_b]
+    if scale_a is not None:
+        input_nodes.append(realize_inputs(scale_a))
+    if scale_b is not None:
+        input_nodes.append(realize_inputs(scale_b))
+    if offs is not None:
+        input_nodes.append(realize_inputs(offs))
+
+    if use_fast_accum is None:
+        aten_choice = extern_kernel_choice.bind(
+            input_nodes,
+            layout,
+            out_dtype=out_dtype,
+        )
+    else:
+        aten_choice = extern_kernel_choice.bind(
+            input_nodes,
+            layout,
+            out_dtype=out_dtype,
+            use_fast_accum=use_fast_accum,
+        )
+    if use_fast_accum is None:
+        use_fast_accum = False
+
+    choices: list[ChoiceCaller] = []
+    if use_aten_gemm_kernels():
+        choices.append(aten_choice)
+
+    _, is_nonzero = _is_static_problem(layout)
+
+    # Checking only for the equality of corresponding dims of
+    # multiplicands here, relying on meta function checks for
+    # everything else.
+    if len(m1_size) == 2:
+        if len(m2_size) == 2:
+            m, k1 = m1_size
+            k2, _ = m2_size
+            # pyrefly: ignore [missing-attribute]
+            g = offs.get_size()[0]
+            V.graph.sizevars.check_equals(k1, k2)
+            a_is_2d, b_is_2d = True, True
+        else:
+            # pyrefly: ignore [missing-attribute]
+            g1 = offs.layout.size[0]
+            m, k1 = m1_size
+            g2, k2, _ = m2_size
+            g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
+            V.graph.sizevars.check_equals(k1, k2)
+            a_is_2d, b_is_2d = True, False
+    else:
+        if len(m2_size) == 2:
+            # pyrefly: ignore [missing-attribute]
+            g1 = offs.layout.size[0]
+            g2, m, k1 = m1_size
+            k2, _ = m2_size
+            g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
+            V.graph.sizevars.check_equals(k1, k2)
+            a_is_2d, b_is_2d = False, True
+        else:
+            g1, m, k1 = m1_size
+            g2, k2, _ = m2_size
+            g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
+            V.graph.sizevars.check_equals(k1, k2)
+            a_is_2d, b_is_2d = False, False
+
+    if (
+        is_nonzero
+        and use_triton_template(layout)
+        and can_use_triton_kernel(mat_a, mat_b, offs, bias, scale_result)
+    ):
+        scaled = scale_a is not None
+
+        a_is_k_major = mat_a.get_stride()[-1] == 1
+        b_is_k_major = mat_b.get_stride()[-2] == 1
+
+        triton_has_make_tensor_descriptor = hasattr(tl, "make_tensor_descriptor")
+        triton_has_experimental_make_tensor_descriptor = hasattr(
+            tl, "_experimental_make_tensor_descriptor"
+        )
+        use_tma_load = (
+            triton_has_make_tensor_descriptor
+            or triton_has_experimental_make_tensor_descriptor
+        )
+        kwargs = {
+            "SCALED": scaled,
+            "A_IS_2D": a_is_2d,
+            "B_IS_2D": b_is_2d,
+            "A_IS_K_MAJOR": a_is_k_major,
+            "B_IS_K_MAJOR": b_is_k_major,
+            "USE_FAST_ACCUM": use_fast_accum,
+            "NUM_SMS": get_num_sms(),
+            "USE_TMA_LOAD": use_tma_load,
+            "USE_EXPERIMENTAL_MAKE_TENSOR_DESCRIPTOR": triton_has_experimental_make_tensor_descriptor,
+        }
+
+        for config in early_config_prune(
+            g, m, mat_a.dtype.itemsize, grouped_mm_configs(), kwargs
+        ):
+            kernel_template.maybe_append_choice(
+                choices,
+                input_nodes=input_nodes,
+                layout=layout,
+                num_stages=config.num_stages,
+                num_warps=config.num_warps,
+                **kwargs,
+                **config.kwargs,
+            )
+
+    if use_blackwell_cutedsl_grouped_mm(
+        mat_a, mat_b, layout, a_is_2d, b_is_2d, offs, bias, scale_result
+    ):
+        for config in get_groupgemm_configs():
+            kwargs = dict(
+                ACC_DTYPE="cutlass.Float32",
+            )
+
+            cutedsl_grouped_mm_template.maybe_append_choice(
+                choices,
+                input_nodes=input_nodes,
+                layout=layout,
+                **kwargs,
+                **asdict(config),
+            )
+
+    input_gen_fns = {
+        4: lambda x: create_offsets(
+            x, m1_size, m2_size, offs.get_size() if offs is not None else None
+        ),
+    }
+    return autotune_select_algorithm(
+        algorithm_name, choices, input_nodes, layout, input_gen_fns=input_gen_fns
+    )
+
+
+@register_lowering(aten._grouped_mm.default, type_promotion_kind=None)
+def tuned_grouped_mm(
+    mat_a: TensorBox,
+    mat_b: TensorBox,
+    offs: Optional[TensorBox] = None,
+    bias: Optional[TensorBox] = None,
+    out_dtype: Optional[torch.dtype] = None,
+    layout: Optional[Layout] = None,
+) -> TensorBox:
+    """Auto-tuning for _grouped_mm() operator."""
+
+    return _tuned_grouped_mm_common(
+        "aten._grouped_mm.default",
+        "grouped_mm",
+        aten__grouped_mm,
+        triton_grouped_mm_template,
+        mat_a,
+        mat_b,
+        None,
+        None,
+        offs,
+        bias,
+        None,
+        out_dtype,
+        None,
+        layout,
+    )
+
+
+@register_lowering(aten._scaled_grouped_mm.default, type_promotion_kind=None)
+def tuned_scaled_grouped_mm(
+    mat_a: TensorBox,
+    mat_b: TensorBox,
+    scale_a: TensorBox,
+    scale_b: TensorBox,
+    offs: Optional[TensorBox] = None,
+    bias: Optional[TensorBox] = None,
+    scale_result: Optional[TensorBox] = None,
+    out_dtype: Optional[torch.dtype] = None,
+    use_fast_accum: bool = False,
+    layout: Optional[Layout] = None,
+) -> TensorBox:
+    """Auto-tuning for _scaled_grouped_mm() operator."""
+
+    # matching _scaled_grouped_mm_cuda Blas.cpp implementation
+    out_dtype = out_dtype or torch.bfloat16
+
+    return _tuned_grouped_mm_common(
+        "aten._scaled_grouped_mm.default",
+        "scaled_grouped_mm",
+        aten__scaled_grouped_mm,
+        triton_scaled_grouped_mm_template,
+        mat_a,
+        mat_b,
+        scale_a,
+        scale_b,
+        offs,
+        bias,
+        scale_result,
+        out_dtype,
+        use_fast_accum,
+        layout,
+    )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/mm_plus_mm.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/mm_plus_mm.py
new file mode 100644
index 0000000000000000000000000000000000000000..aef8dfb2168f4e9f410310f898ff3ae08bae02ee
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/mm_plus_mm.py
@@ -0,0 +1,177 @@
+# mypy: allow-untyped-defs
+
+import logging
+from typing import TYPE_CHECKING, Union
+
+import torch
+
+from .. import config as inductor_config
+from ..kernel_inputs import MMKernelInputs
+from ..lowering import lowerings
+from ..select_algorithm import (
+    autotune_select_algorithm,
+    ExternKernelChoice,
+    TritonTemplate,
+)
+from ..utils import use_aten_gemm_kernels, use_triton_template
+from ..virtualized import V
+from .mm_common import mm_args, mm_grid
+
+
+if TYPE_CHECKING:
+    from torch._inductor.ir import ChoiceCaller
+    from torch._inductor.select_algorithm import KernelTemplate
+
+log = logging.getLogger(__name__)
+
+aten = torch.ops.aten
+
+aten_mm_plus_mm = ExternKernelChoice(
+    torch.ops.inductor._mm_plus_mm, "torch::inductor::_mm_plus_mm"
+)
+
+mm_plus_mm_template = TritonTemplate(
+    name="mm_plus_mm",
+    grid=mm_grid,
+    debug=False,
+    source=r"""
+{{def_kernel("A", "B", "C", "D")}}
+    M = {{size("A", 0)}}
+    N = {{size("B", 1)}}
+    K1 = {{size("A", 1)}}
+    if M * N == 0:
+        # early exit due to zero-size input(s)
+        return
+    # K2 = {{size("C", 1)}}
+    stride_am = {{stride("A", 0)}}
+    stride_ak = {{stride("A", 1)}}
+    stride_bk = {{stride("B", 0)}}
+    stride_bn = {{stride("B", 1)}}
+    stride_cm = {{stride("C", 0)}}
+    stride_ck = {{stride("C", 1)}}
+    stride_dk = {{stride("D", 0)}}
+    stride_dn = {{stride("D", 1)}}
+
+    # based on triton.ops.matmul
+    pid = tl.program_id(0).to(INDEX_DTYPE)
+    grid_m = (M + BLOCK_M - 1) // BLOCK_M
+    grid_n = (N + BLOCK_N - 1) // BLOCK_N
+
+    # re-order program ID for better L2 performance
+    width = GROUP_M * grid_n
+    group_id = pid // width
+    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
+    pid_m = group_id * GROUP_M + (pid % group_size)
+    pid_n = (pid % width) // (group_size)
+    tl.assume(pid_m >= 0)
+    tl.assume(pid_n >= 0)
+
+    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+    if (((stride_am == 1 and stride_ak == M) or (stride_am == K1 and stride_ak == 1))
+        and ((stride_cm == 1 and stride_ck == M) or (stride_cm == K1 and stride_ck == 1))):
+        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+    else:
+        ram = rm % M
+
+    if (((stride_bk == 1 and stride_bn == K1) or (stride_bk == N and stride_bn == 1))
+        and ((stride_dk == 1 and stride_dn == K1) or (stride_dk == N and stride_dn == 1))):
+        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+    else:
+        rbn = rn % N
+
+    rk = tl.arange(0, BLOCK_K)
+    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+    C = C + (ram[:, None] * stride_cm + rk[None, :] * stride_ck)
+    D = D + (rk[:, None] * stride_dk + rbn[None, :] * stride_dn)
+
+    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+    for k1 in range(K1, 0, -BLOCK_K):
+        # First matmul with A @ B
+        if EVEN_K:
+            a = tl.load(A)
+            b = tl.load(B)
+        else:
+            a = tl.load(A, mask=rk[None, :] < k1, other=0.)
+            b = tl.load(B, mask=rk[:, None] < k1, other=0.)
+        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
+        A += BLOCK_K * stride_ak
+        B += BLOCK_K * stride_bk
+
+    for k2 in range(K1, 0, -BLOCK_K):
+
+        # Second matmul with C @ D
+        if EVEN_K:
+            c = tl.load(C)
+            d = tl.load(D)
+        else:
+            c = tl.load(C, mask=rk[None, :] < k2, other=0.)
+            d = tl.load(D, mask=rk[:, None] < k2, other=0.)
+        acc += tl.dot(c, d, allow_tf32=ALLOW_TF32)
+        C += BLOCK_K * stride_ck
+        D += BLOCK_K * stride_dk
+
+
+    idx_m = rm[:, None]
+    idx_n = rn[None, :]
+    mask = (idx_m < M) & (idx_n < N)
+
+    # inductor generates a suffix
+    {{store_output(("idx_m", "idx_n"), "acc", "mask", val_shape=("BLOCK_M", "BLOCK_N"))}}
+""",
+    cache_codegen_enabled_for_template=True,
+)
+
+
+def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
+    """
+    Computes mm(mat1, mat2) + mm(mat3, mat4)
+    """
+    # TODO(coconutruben): integrate into MMKernelInputs when all callsites use that
+    m1, n1, k1, layout1, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
+    m2, n2, _, layout2, mat3, mat4 = mm_args(mat3, mat4, layout=layout)
+
+    # Optimization is optional, because we can always just not do the fusion
+    if (
+        m1 * n1 == 0
+        or m2 * n2 == 0
+        or not V.graph.sizevars.statically_known_list_equals(
+            mat1.get_size(), mat3.get_size()
+        )
+        or not V.graph.sizevars.statically_known_list_equals(
+            mat2.get_size(), mat4.get_size()
+        )
+        or inductor_config.triton.native_matmul
+    ):
+        # TODO(jansel): support different K values when this is fixed:
+        # https://github.com/triton-lang/triton/issues/967
+        return lowerings[aten.add](
+            lowerings[aten.mm](mat1, mat2), lowerings[aten.mm](mat3, mat4)
+        )
+
+    # Create MMKernelInputs for MM Plus MM (matrices are at indices 0, 1 for first pair)
+    # Note: This is a special case with 4 matrices, but we use the first pair for M, N, K extraction
+    kernel_inputs = MMKernelInputs([mat1, mat2, mat3, mat4], mat1_idx=0, mat2_idx=1)
+
+    assert layout1 == layout2
+    # options to tune from
+    choices: list[ChoiceCaller] = []
+
+    # Collect all templates for unified call
+    templates_to_use: list[Union[ExternKernelChoice, KernelTemplate]] = []
+    if use_aten_gemm_kernels():
+        templates_to_use.append(aten_mm_plus_mm)
+
+    if use_triton_template(layout1, check_max_autotune=False):
+        templates_to_use.append(mm_plus_mm_template)
+
+    # Single unified call for all templates
+    choices.extend(
+        V.choices.get_template_configs(kernel_inputs, templates_to_use, "mm_plus_mm")
+    )
+
+    return autotune_select_algorithm(
+        "mm_plus_mm", choices, kernel_inputs.nodes(), layout1
+    )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/lookup_table/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/lookup_table/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ebb1d5618bfa5eb44cfa9c4bc3921e887f54374
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/lookup_table/__init__.py
@@ -0,0 +1,32 @@
+"""
+Template lookup table system for PyTorch Inductor.
+
+This package provides functionality for:
+- Loading pre-configured template choices from lookup tables
+- Managing template configurations and choices
+
+All functionality is contained within the LookupTableChoices class.
+You can customize any aspect by subclassing LookupTableChoices and overriding methods.
+
+Usage:
+    # Basic usage
+    choices = LookupTableChoices()
+    V.set_choices_handler(choices)
+
+    # Custom usage
+    class MyCustomChoices(LookupTableChoices):
+        def _get_lookup_table(self):
+            return my_custom_table
+
+        def make_lookup_key(self, kernel_inputs, op_name, include_device=False):
+            return f"custom_{op_name}_{hash(str(kernel_inputs))}"
+
+    V.set_choices_handler(MyCustomChoices())
+"""
+
+from .choices import LookupTableChoices
+
+
+__all__ = [
+    "LookupTableChoices",
+]
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/lookup_table/choices.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/lookup_table/choices.py
new file mode 100644
index 0000000000000000000000000000000000000000..46e54180114aba4b59f0832b6cd64a408df521c9
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/lookup_table/choices.py
@@ -0,0 +1,418 @@
+from __future__ import annotations
+
+import copy
+import logging
+from functools import lru_cache
+from typing import Any, Optional, TYPE_CHECKING, Union
+
+import torch
+from torch._inductor import config
+from torch._inductor.choices import InductorChoices
+from torch._inductor.kernel_template_choice import KernelTemplateChoice
+from torch._inductor.template_heuristics.params import DictKernelTemplateParams
+
+
+log = logging.getLogger(__name__)
+
+
+if TYPE_CHECKING:
+    from collections.abc import Generator
+
+    from torch._inductor.codegen.common import KernelTemplate
+    from torch._inductor.kernel_inputs import KernelInputs
+    from torch._inductor.select_algorithm import ExternKernelChoice
+
+
+class LookupTableChoices(InductorChoices):
+    """
+    InductorChoices subclass that uses lookup table when available, otherwise falls back to parent.
+    All lookup functionality is contained within this class and can be customized by overriding methods.
+    """
+
+    def _get_lookup_table(self) -> dict[str, list[dict[str, Any]]]:
+        """
+        Get the template lookup table from config.
+        Override this method to use custom lookup table sources (database, API, etc.).
+        """
+        if not torch.cuda.is_available() or config.lookup_table.table is None:
+            return {}
+        return config.lookup_table.table
+
+    @staticmethod
+    @lru_cache
+    def _get_device_key(device: torch.device) -> Optional[str]:
+        """
+        Generate a device key for lookup table indexing.
+        For CPU devices, returns None.
+        For CUDA devices, returns the props.gcnArchName string.
+        """
+        if device.type != "cuda":
+            # only cuda devices are supported, this indicates that the system is not in use
+            # for this device
+            return None
+
+        # Get CUDA device properties
+        props = torch.cuda.get_device_properties(device.index)
+        return props.gcnArchName
+
+    @staticmethod
+    def _generate_kernel_inputs_key(kernel_inputs: KernelInputs) -> str:
+        """
+        Generate a key based on input node properties and scalars.
+        The key includes dtype, size, and stride information for each input node,
+        plus scalar values as key=value pairs separated by & signs.
+        """
+        # Get node information using existing methods
+        dtypes = kernel_inputs.dtypes()
+        shapes = kernel_inputs.shapes_hinted()
+        strides = kernel_inputs.strides_hinted()
+
+        # Create tuple of (dtype, shape_list, stride_list) for each node
+        node_info = tuple(
+            (dtype, list(shape), list(stride))
+            for dtype, shape, stride in zip(dtypes, shapes, strides)
+        )
+
+        # Create base key from node information
+        fmt_key = str(node_info)
+        # Add scalar information if present
+        if kernel_inputs._scalars:
+            # Sort scalars for consistent key generation and join with &
+            scalar_parts = [
+                f"{key}={value}"
+                for key, value in sorted(kernel_inputs._scalars.items())
+            ]
+            scalars_key = "&".join(scalar_parts)
+            fmt_key = f"{fmt_key}+{scalars_key}"
+
+        return f"{fmt_key}"
+
+    def make_lookup_key(
+        self, kernel_inputs: KernelInputs, op_name: str, include_device: bool = False
+    ) -> Optional[str]:
+        """
+        Create a flattened lookup key from kernel inputs and operation name.
+        Override this method to customize key generation.
+
+        Args:
+            kernel_inputs: KernelInputs object containing input nodes and scalars
+            op_name: Operation name (e.g., "mm", "addmm")
+            include_device: Whether to include device key in the generated key
+
+        Returns:
+            A string key combining device (optional), operation, and input information
+        """
+        device = kernel_inputs.device()
+        dev_key = self._get_device_key(device)
+        if dev_key is None:
+            # The system does not run when dev_key is None, regardless of
+            # whether include_device is True or False
+            return None
+        if not include_device:
+            dev_key = None
+
+        # Generate input key using our staticmethod
+        input_key = self._generate_kernel_inputs_key(kernel_inputs)
+
+        # Create the flattened lookup key
+        if dev_key is not None:
+            key_parts = [dev_key, input_key, op_name]
+        else:
+            key_parts = [input_key, op_name]
+
+        return "+".join(key_parts)
+
+    def make_lookup_key_variants(
+        self, kernel_inputs: KernelInputs, op_name: str
+    ) -> tuple[Optional[str], Optional[str]]:
+        """
+        Generate both device-specific and device-agnostic lookup keys.
+        Override this method to customize key variant generation.
+
+        Args:
+            kernel_inputs: KernelInputs object containing input nodes and scalars
+            op_name: Operation name (e.g., "mm", "addmm")
+
+        Returns:
+            Tuple of (device_key, device_agnostic_key). Either may be None if generation fails.
+        """
+        device_key = self.make_lookup_key(kernel_inputs, op_name, include_device=True)
+        device_agnostic_key = self.make_lookup_key(
+            kernel_inputs, op_name, include_device=False
+        )
+
+        return device_key, device_agnostic_key
+
+    @staticmethod
+    def _entry_is_valid(
+        cfg: dict[str, Any],
+        template_id: str,
+        template_hash_map: Optional[dict[str, Optional[str]]],
+    ) -> bool:
+        """
+        Check if a config entry is valid based on template hash validation.
+
+        Args:
+            cfg: Configuration dictionary that may contain a template_hash field
+            template_id: The template identifier
+            template_hash_map: Optional mapping from template_uid to src_hash for validation
+
+        Returns:
+            True if the config is valid and should be kept, False if it should be filtered out
+        """
+        # If hash checking is disabled or no hash map provided, keep the config
+        if not config.lookup_table.check_src_hash or not template_hash_map:
+            return True
+
+        template_hash = template_hash_map.get(template_id)
+        config_hash = cfg.get("template_hash")
+
+        # Both hashes present - validate they match
+        if template_hash is not None and config_hash is not None:
+            if config_hash != template_hash:
+                log.warning(
+                    "Hash validation failed for template '%s': config_hash='%s' != template_hash='%s'. "
+                    "Template code may have changed. Filtering out config: %s",
+                    template_id,
+                    config_hash,
+                    template_hash,
+                    {k: v for k, v in cfg.items() if k != "template_hash"},
+                )
+                return False
+            else:
+                log.debug(
+                    "Hash validation passed for template '%s': hash='%s'",
+                    template_id,
+                    template_hash,
+                )
+                return True
+        # Config has no hash - keep it
+        elif config_hash is None:
+            log.debug(
+                "Config for template '%s' has no hash - keeping it (template_hash='%s')",
+                template_id,
+                template_hash,
+            )
+            return True
+        # Template has no hash - keep config
+        else:
+            log.debug(
+                "Template '%s' has no src_hash - keeping config with hash '%s'",
+                template_id,
+                config_hash,
+            )
+            return True
+
+    def lookup_template_configs(
+        self,
+        kernel_inputs: KernelInputs,
+        op_name: str,
+        template_uids: list[str],
+        template_hash_map: Optional[dict[str, Optional[str]]] = None,
+    ) -> dict[str, list[dict[str, Any]]]:
+        """
+        Unified function to look up template configurations for multiple templates.
+        Override this method to customize lookup logic.
+
+        Args:
+            kernel_inputs: KernelInputs object containing input nodes and scalars
+            op_name: Operation name (e.g., "mm", "addmm")
+            template_uids: List of template identifiers (e.g., ["mm", "tma", "decompose_k"])
+            template_hash_map: Optional mapping from template_uid to src_hash for validation
+
+        Returns:
+            {}: No lookup table in use, or no matches found for any template
+            {"template_uid1": [config1, config2], ...}: Matches found, filtered configurations
+        """
+        lookup_table = self._get_lookup_table()
+        if not lookup_table:
+            log.debug("Lookup table: no table configured or CUDA unavailable")
+            return {}
+
+        # Try both key variants: device-specific first, then device-agnostic
+        # If both exist, device-specific takes priority
+        device_key, device_agnostic_key = self.make_lookup_key_variants(
+            kernel_inputs, op_name
+        )
+
+        config_list = []
+
+        for key_type, key in [
+            ("device-specific", device_key),
+            ("device-agnostic", device_agnostic_key),
+        ]:
+            if key is not None:
+                config_list = lookup_table.get(key, [])
+                if config_list:
+                    log.debug(
+                        "Lookup table: found %d configs using %s key '%s' for %s",
+                        len(config_list),
+                        key_type,
+                        key,
+                        op_name,
+                    )
+                    break
+        else:
+            log.debug(
+                "Lookup table: no match for %s (tried keys: %s, %s) (table has %d keys)",
+                op_name,
+                device_key,
+                device_agnostic_key,
+                len(lookup_table),
+            )
+            return {}
+
+        log.debug(
+            "Lookup table: found %d configs for %s templates %s",
+            len(config_list),
+            op_name,
+            template_uids,
+        )
+        # Group configs by template_id
+        configs_by_template: dict[str, list[dict[str, Any]]] = {}
+        for cfg in config_list:
+            if not isinstance(cfg, dict):
+                raise ValueError(
+                    f"Config for {op_name} operation is not a dictionary: {cfg}"
+                )
+            if "template_id" not in cfg:
+                raise ValueError(
+                    f"Config for {op_name} operation missing required 'template_id' field: {cfg}"
+                )
+
+            template_id = cfg["template_id"]
+            if template_id in template_uids:
+                if template_id not in configs_by_template:
+                    configs_by_template[template_id] = []
+                configs_by_template[template_id].append(cfg)
+
+        # Check template hashes and clean up template_id field
+        result = {}
+        for template_id, matching_configs in configs_by_template.items():
+            filtered_configs = []
+            for cfg in matching_configs:
+                # Check template hash using helper function
+                if not self._entry_is_valid(cfg, template_id, template_hash_map):
+                    continue
+
+                # Return a copy of the config, as we don't want to modify the original
+                cconfig = copy.deepcopy(cfg)
+                # Lastly, we have to throw out the template_id, as it's not a valid kwarg
+                # and just used to identify which template the entry belongs to
+                del cconfig["template_id"]
+                # Similarly, the template_hash is not a valid kwarg
+                cconfig.pop("template_hash", None)
+                filtered_configs.append(cconfig)
+
+            if filtered_configs:
+                result[template_id] = filtered_configs
+
+        return result
+
+    def _finalize_template_configs(
+        self,
+        template_choices: dict[str, Generator[KernelTemplateChoice, None, None]],
+        kernel_inputs: KernelInputs,
+        templates: list[Union[KernelTemplate, ExternKernelChoice]],
+        op_name: str,
+        kwarg_overrides: Optional[dict[str, dict[str, Any]]] = None,
+    ) -> list[KernelTemplateChoice]:
+        """Check lookup table for hits, use those if found, otherwise fall back to parent."""
+        # 1. Collect template src_hashes for validation
+        template_uids = [template.uid for template in templates]
+        template_hash_map = {}
+        for template in templates:
+            src_hash = getattr(template, "src_hash", None)
+            template_hash_map[template.uid] = src_hash
+
+        log.debug(
+            "Choices: attempting lookup for %s with %d templates",
+            op_name,
+            len(template_uids),
+        )
+
+        # 2. Single batch lookup for all templates
+        lookup_results = self.lookup_template_configs(
+            kernel_inputs, op_name, template_uids, template_hash_map
+        )
+
+        # 3. Early exit if no lookup table or no matches
+        if not lookup_results:  # Empty dict
+            log.info("LookupChoices: lookup miss for %s, using fallback", op_name)
+            return self._fallback(
+                template_choices,
+                kernel_inputs,
+                templates,
+                op_name,
+                kwarg_overrides,
+            )
+
+        log.info(
+            "LookupChoices: lookup hit for %s - found %d/%d templates: %s",
+            op_name,
+            len(lookup_results),
+            len(template_uids),
+            list(lookup_results.keys()),
+        )
+
+        # 4. Create KTCs only for templates with lookup entries
+        return self._create_lookup_choices(
+            lookup_results, templates, kernel_inputs, op_name
+        )
+
+    def _fallback(
+        self,
+        template_choices: dict[str, Generator[KernelTemplateChoice, None, None]],
+        kernel_inputs: KernelInputs,
+        templates: list[Union[KernelTemplate, ExternKernelChoice]],
+        op_name: str,
+        kwarg_overrides: Optional[dict[str, dict[str, Any]]] = None,
+    ) -> list[KernelTemplateChoice]:
+        """Fallback to parent if no lookup table or no matches."""
+        # NOTE: this is broken out, so that subclasses are able to override this
+        # to handle explicitly the situations where the lookup take had a miss vs
+        # overriding the entire logic
+        return super()._finalize_template_configs(
+            template_choices,
+            kernel_inputs,
+            templates,
+            op_name,
+            kwarg_overrides,
+        )
+
+    def _create_lookup_choices(
+        self,
+        lookup_results: dict[str, list[dict[str, Any]]],
+        templates: list[Union[KernelTemplate, ExternKernelChoice]],
+        kernel_inputs: KernelInputs,
+        op_name: str,
+    ) -> list[KernelTemplateChoice]:
+        """Create KernelTemplateChoice objects from lookup results using parent's get_ktc method."""
+        templates_by_uid = {template.uid: template for template in templates}
+        lookup_choices: list[KernelTemplateChoice] = []
+
+        for template_uid, configs in lookup_results.items():
+            template = templates_by_uid[template_uid]
+
+            # Use parent's get_ktc method to get a generator, then get the first base KTC
+            ktc_generator = self.get_ktc(kernel_inputs, template, op_name)
+
+            try:
+                base_ktc = next(ktc_generator)
+            except StopIteration:
+                # No configs from heuristic, skip this template
+                continue
+
+            # For each lookup config, create a KTC with the override kwargs
+            for c in configs:
+                lookup_ktc = KernelTemplateChoice(
+                    template=base_ktc.template,
+                    # use the ones from the lookup table
+                    params=DictKernelTemplateParams(c),
+                    extra_kwargs=base_ktc.extra_kwargs,
+                    layout=base_ktc.layout,
+                    inputs=base_ktc.inputs,
+                )
+                lookup_choices.append(lookup_ktc)
+
+        return lookup_choices
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/package/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/package/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..15587401b723581b57f94fdcddbcbc8255f73bfe
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/package/__init__.py
@@ -0,0 +1 @@
+from .package import AOTICompiledModel, load_package, package_aoti
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/package/build_package.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/package/build_package.py
new file mode 100644
index 0000000000000000000000000000000000000000..9205b9ced254275018472108485173eba9479f11
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/package/build_package.py
@@ -0,0 +1,15 @@
+build_package_contents = """
+import os
+from pathlib import Path
+
+from torch._inductor.package.package import compile_so
+
+curr_dir = Path(__file__).parent
+aoti_files = [
+    os.path.join(root, file)
+    for root, dirs, files in os.walk(curr_dir)
+    for file in files
+]
+
+output_so = compile_so(curr_dir, aoti_files, curr_dir)
+"""
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/package/package.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/package/package.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd11d033cadb3fc3cfdba8165fb42dd996284931
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/package/package.py
@@ -0,0 +1,138 @@
+import io
+import json
+import logging
+import os
+import tempfile
+from typing import IO
+
+import torch
+from torch._inductor import config
+from torch._inductor.cpp_builder import BuildOptionsBase, CppBuilder
+from torch.export.pt2_archive._package import (
+    AOTI_FILES,
+    AOTICompiledModel,
+    load_pt2,
+    package_pt2,
+)
+from torch.types import FileLike
+
+
+log = logging.getLogger(__name__)
+
+
+def compile_so(aoti_dir: str, aoti_files: list[str], so_path: str) -> str:
+    def get_aoti_file_with_suffix(suffix: str) -> str:
+        for file in aoti_files:
+            if file.endswith(suffix):
+                return file
+        raise RuntimeError(f"Unable to find file with suffix {suffix}")
+
+    # Compile all the files into a .so
+    cpp_file = os.path.join(aoti_dir, get_aoti_file_with_suffix(".cpp"))
+    consts_o = os.path.join(aoti_dir, get_aoti_file_with_suffix(".o"))
+
+    file_name = os.path.splitext(cpp_file)[0]
+
+    # Parse compile flags and build the .o file
+    with open(file_name + "_compile_flags.json") as f:
+        compile_flags = json.load(f)
+
+    compile_options = BuildOptionsBase(
+        **compile_flags, use_relative_path=config.is_fbcode()
+    )
+    object_builder = CppBuilder(
+        name=file_name,
+        sources=cpp_file,
+        BuildOption=compile_options,
+    )
+    output_o = object_builder.get_target_file_path()
+    object_builder.build()
+
+    # Parse linker flags and build the .so file
+    with open(file_name + "_linker_flags.json") as f:
+        linker_flags = json.load(f)
+
+    linker_options = BuildOptionsBase(
+        **linker_flags, use_relative_path=config.is_fbcode()
+    )
+    so_builder = CppBuilder(
+        name=os.path.split(so_path)[-1],
+        sources=[output_o, consts_o],
+        BuildOption=linker_options,
+        output_dir=so_path,
+    )
+    output_so = so_builder.get_target_file_path()
+    so_builder.build()
+
+    # mmapped weights
+    serialized_weights_filename = file_name + "_serialized_weights.bin"
+    if serialized_weights_filename in aoti_files:
+        with open(serialized_weights_filename, "rb") as f_weights:
+            serialized_weights = f_weights.read()
+
+        with open(output_so, "a+b") as f_so:
+            so_size = f_so.tell()
+            # Page align the weights
+            f_so.write(b" " * (16384 - so_size % 16384))
+            f_so.write(serialized_weights)
+
+    return output_so
+
+
+def package_aoti(
+    archive_file: FileLike,
+    aoti_files: AOTI_FILES,
+) -> FileLike:
+    """
+    Saves the AOTInductor generated files to the PT2Archive format.
+
+    Args:
+        archive_file: The file name to save the package to.
+        aoti_files: This can either be a singular path to a directory containing
+        the AOTInductor files, or a dictionary mapping the model name to the
+        path to its AOTInductor generated files.
+    """
+
+    return package_pt2(
+        archive_file,
+        aoti_files=aoti_files,
+    )
+
+
+def load_package(
+    path: FileLike,
+    model_name: str = "model",
+    run_single_threaded: bool = False,
+    num_runners: int = 1,
+    device_index: int = -1,
+) -> AOTICompiledModel:
+    try:
+        pt2_contents = load_pt2(
+            path,
+            run_single_threaded=run_single_threaded,
+            num_runners=num_runners,
+            device_index=device_index,
+        )
+        if model_name not in pt2_contents.aoti_runners:
+            raise RuntimeError(f"Model {model_name} not found in package")
+        return pt2_contents.aoti_runners[model_name]
+    except RuntimeError:
+        log.warning("Loading outdated pt2 file. Please regenerate your package.")
+
+    if isinstance(path, (io.IOBase, IO)):
+        with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
+            # TODO(angelayi): We shouldn't need to do this -- miniz should
+            # handle reading the buffer. This is just a temporary workaround
+            path.seek(0)
+            f.write(path.read())
+            log.debug("Writing buffer to tmp file located at %s.", f.name)
+            loader = torch._C._aoti.AOTIModelPackageLoader(
+                f.name, model_name, run_single_threaded, num_runners, device_index
+            )
+            return AOTICompiledModel(loader)
+
+    path = os.fspath(path)  # AOTIModelPackageLoader expects (str, str)
+    loader = torch._C._aoti.AOTIModelPackageLoader(
+        path, model_name, run_single_threaded, num_runners, device_index
+    )
+    return AOTICompiledModel(loader)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/autotune_cache.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/autotune_cache.py
new file mode 100644
index 0000000000000000000000000000000000000000..0034a6a8feb3de9d6c3052a4bd5b5cc18ac112e0
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/autotune_cache.py
@@ -0,0 +1,649 @@
+"""
+PyTorch Inductor Autotuning Cache System
+
+This module implements a caching system for autotuning configurations in PyTorch's Inductor compiler.
+It provides mechanisms to store and retrieve optimal kernel configurations both locally and remotely,
+which significantly speeds up compilation by reusing previously discovered optimal parameters.
+
+The caching system includes:
+- Local filesystem caching for individual machine reuse
+- Remote caching for sharing optimizations across machines
+- Bundled caching to efficiently store multiple related configurations
+- Cache invalidation based on PyTorch versions and backend changes
+- Serialization/deserialization support for worker processes
+
+Key components:
+- AutotuneCache: Main class for managing cache access and storage
+- AutotuneCacheBundler: Bundles multiple cache entries for efficient storage
+- LocalAutotuneCache: Handles filesystem-based caching
+- _LocalAutotuneCacheBackend: Low-level file operations for cache storage
+- AutotuneCacheArtifact: Integration with PyTorch's artifact system
+
+This caching system is critical for performance as it eliminates the need to re-run
+expensive autotuning operations when the same kernels are compiled multiple times.
+"""
+
+from __future__ import annotations
+
+import dataclasses
+import hashlib
+import logging
+import os
+import os.path
+import re
+from typing import Any, TYPE_CHECKING
+from typing_extensions import override
+
+import torch
+from torch._inductor.runtime.runtime_utils import cache_dir
+from torch.compiler._cache import (
+    CacheArtifact,
+    CacheArtifactFactory,
+    CacheArtifactManager,
+)
+from torch.utils._triton import has_triton
+
+from ..remote_cache import (
+    create_cache,
+    JsonDataTy,
+    RemoteCache,
+    RemoteCacheBackend,
+    RemoteCacheJsonSerde,
+)
+from .triton_compat import Config, HAS_WARP_SPEC
+
+
+if TYPE_CHECKING:
+    from ..remote_cache import Sample
+
+log = logging.getLogger(__name__)
+
+
+_InductorMetaTy = dict[str, object]
+
+
+def inductor_meta_from_config() -> _InductorMetaTy:
+    from torch._inductor import config
+
+    backend_hash = None
+    if has_triton():
+        try:
+            backend_hash = torch.utils._triton.triton_hash_with_backend()
+        except RuntimeError:
+            # This can get the error:
+            #   RuntimeError: 0 active drivers ([]). There should only be one.
+            pass
+
+    is_hip = None
+    if torch.version.hip is not None:
+        is_hip = True
+
+    return {
+        "autotune_local_cache": config.autotune_local_cache,
+        "autotune_remote_cache": config.autotune_remote_cache,
+        "backend_hash": backend_hash,
+        "bundled_autotune_remote_cache": config.bundled_autotune_remote_cache,
+        "coordinate_descent_tuning": config.coordinate_descent_tuning,
+        "is_fbcode": config.is_fbcode(),
+        "is_hip": is_hip,
+    }
+
+
+@CacheArtifactFactory.register
+class AutotuneCacheArtifact(CacheArtifact):
+    @override
+    def populate_cache(self) -> None:
+        autotune_cache = _LocalAutotuneCacheBackend()
+        key = os.path.join(cache_dir(), self.key)
+        autotune_cache._put(key, self.content)
+
+    @override
+    @staticmethod
+    def type() -> str:
+        return "autotune"
+
+    @override
+    @staticmethod
+    def encode(content: JsonDataTy) -> bytes:
+        assert not isinstance(content, bytes)
+        serde = RemoteCacheJsonSerde()
+        content_bytes = serde.encode(content)
+        assert isinstance(content_bytes, bytes)
+        return content_bytes
+
+
+@dataclasses.dataclass
+class AutotuneCache:
+    configs_hash: str
+    local_cache: tuple[RemoteCache[JsonDataTy], str] | None = None
+    remote_cache: tuple[RemoteCache[JsonDataTy], str] | None = None
+
+    # Create a AutotuneCache. Returns None if none of the caches can be used.
+    @staticmethod
+    def create(
+        inductor_meta: _InductorMetaTy, filename: str, configs_hash: str
+    ) -> AutotuneCache | None:
+        cache = AutotuneCache(configs_hash)
+        key = AutotuneCache._prepare_key(filename)
+
+        cache._setup_local_cache(inductor_meta, os.path.dirname(filename), key)
+        cache._setup_remote_autotune_cache(inductor_meta, key)
+        if cache.local_cache or cache.remote_cache:
+            return cache
+        else:
+            return None
+
+    @staticmethod
+    def _prepare_key(filename: str) -> str:
+        from torch.compiler import config as cconfig
+
+        # base of filename is already sha256 hash the source contents
+        key = f"{os.path.basename(filename)}:{cconfig.cache_key_tag}"
+        return hashlib.sha256(key.encode("utf-8")).hexdigest()
+
+    # Read the best config options from the most local cache and return it.
+    def _read(self) -> dict[str, JsonDataTy] | None:
+        if local_cache := self.local_cache:
+            cache, key = local_cache
+            if best_config := cache.get(key):
+                if isinstance(best_config, dict):
+                    return best_config
+
+        if remote_cache := self.remote_cache:
+            cache, key = remote_cache
+            if best_config := cache.get(key):
+                if isinstance(best_config, dict):
+                    return best_config
+
+        return None
+
+    # Read the best config options from the most local cache and figure out
+    # which `configs` represents that option.
+    def read_best(
+        self, inductor_meta: _InductorMetaTy, configs: list[Config]
+    ) -> Config | None:
+        if best := self._read():
+            return _load_cached_autotuning(
+                best, self.configs_hash, configs, inductor_meta
+            )
+        return None
+
+    # Set up local filesystem caching information
+    def _setup_local_cache(
+        self, inductor_meta: _InductorMetaTy, dirname: str, cache_key: str
+    ) -> None:
+        if not inductor_meta.get("autotune_local_cache", True):
+            return
+
+        from ..codecache import torch_key
+
+        """
+        [Note: torch_key in autotune cache key]
+        Include torch_key() in the cache key so that different versions
+        of torch result in cache invalidation. This is important in case
+        of changes to the best_config format or other code changes that
+        are not backward compatible w.r.t. the cache.
+        """
+        hasher = hashlib.sha256()
+        hasher.update(cache_key.encode("utf-8"))
+        hasher.update(torch_key())
+        updated_cache_key = hasher.hexdigest()
+
+        cache_filename = f"{dirname}/{updated_cache_key}.best_config"
+        local_cache = LocalAutotuneCache()
+        self.local_cache = (local_cache, cache_filename)
+
+    # Set up remote caching information
+    def _setup_remote_autotune_cache(
+        self, inductor_meta: _InductorMetaTy, cache_key: str
+    ) -> None:
+        if not _should_use_remote_autotune_cache(inductor_meta):
+            return
+
+        if (backend_hash := inductor_meta.get("backend_hash", None)) is None:
+            log.debug(
+                "backend_hash is not passed on the inductor_meta, unable to use autotune remote cache"
+            )
+            return
+        assert isinstance(backend_hash, str)
+
+        from ..codecache import torch_key
+
+        is_fbcode = bool(inductor_meta.get("is_fbcode", False))
+
+        salt = "autotune-best-config-v2"
+        # re: torch_key - see [Note: torch_key in autotune cache key]
+        key = torch_key().hex() + backend_hash + self.configs_hash + salt
+        key = hashlib.sha256(key.encode("utf-8")).hexdigest()
+
+        remote_cache = create_cache(
+            key,
+            is_fbcode,
+            "FbRemoteAutotuneCache",
+            "RemoteAutotuneCache",
+        )
+        if not remote_cache:
+            return
+
+        # Save the args passed to create_cache
+        # in case AutotuneCache needs to be pickled
+        self.remote_cache_full_key = key
+        self.is_fbcode = is_fbcode
+        self.remote_cache = (remote_cache, cache_key)
+
+    # The AutotuneCache may be serialized/deserialized if we're using
+    # AsyncCompile worker processes to run triton compilation.
+    # This is because AutotuneCache instances are created on the worker
+    # process, but we need to run AutotuneCache.save on the parent process
+    # when actually doing autotuning.
+    def __getstate__(self) -> dict[str, Any]:
+        # The remote cache handles themselves may not be serializable
+        # So clear it and reconstruct it on setstate
+        remote_cache = getattr(self, "remote_cache", None)
+        return {
+            **self.__dict__,
+            # Save the cache_key portion
+            "remote_cache": remote_cache and remote_cache[1],
+        }
+
+    def __setstate__(self, state: dict[str, Any]) -> None:
+        # Reconstruct the remote cache on the parent class
+        self.__dict__.update(state)
+        if self.remote_cache is not None:
+            assert isinstance(self.remote_cache, str)
+            assert hasattr(self, "remote_cache_full_key")
+            assert hasattr(self, "is_fbcode")
+            cache_key = self.remote_cache
+            remote_cache = create_cache(
+                self.remote_cache_full_key,
+                self.is_fbcode,
+                "FbRemoteAutotuneCache",
+                "RemoteAutotuneCache",
+            )
+            if remote_cache is not None:
+                self.remote_cache = (remote_cache, cache_key)
+            else:
+                log.warning("Warning, failed to recreate remote cache after pickling")
+                self.remote_cache = None
+
+    # Save the config in the caches
+    def save(
+        self,
+        config: Config,
+        time_taken_ns: int,
+        found_by_coordesc: bool = False,
+        triton_cache_hash: str | None = None,
+    ) -> None:
+        data = {
+            # pyrefly: ignore [missing-attribute]
+            **config.kwargs,
+            # pyrefly: ignore [missing-attribute]
+            "num_warps": config.num_warps,
+            # pyrefly: ignore [missing-attribute]
+            "num_stages": config.num_stages,
+            "configs_hash": self.configs_hash,
+            "found_by_coordesc": found_by_coordesc,
+            "time_taken_ms": time_taken_ns // 1000000,  # Convert from NS to MS
+            "triton_cache_hash": triton_cache_hash,
+        }
+        if HAS_WARP_SPEC:
+            data.update(
+                {
+                    "num_consumer_groups": getattr(config, "num_consumer_groups", 0),
+                    "num_buffers_warp_spec": getattr(
+                        config, "num_buffers_warp_spec", 0
+                    ),
+                }
+            )
+
+        if local_cache := self.local_cache:
+            cache, key = local_cache
+            cache.put(key, data)
+            AutotuneCacheBundler.put(key, data)
+            autotune_artifact_key = os.path.join(*key.split(os.sep)[-2:])
+            CacheArtifactManager.record_artifact(
+                AutotuneCacheArtifact.type(), autotune_artifact_key, data
+            )
+
+            if log.isEnabledFor(logging.DEBUG):
+                type_str = "coordesc" if found_by_coordesc else "heuristic"
+                log.debug("Save %s tuning result to %s", type_str, key)
+
+        if remote_cache := self.remote_cache:
+            cache, key = remote_cache
+            cache.put(key, data)
+
+
+class _AutotuneCacheBundlerImpl:
+    """
+    Caches a set of LocalAutotuneCacheBackend entries together in a single
+    cache.
+    """
+
+    _key: str
+    _cache: RemoteCache[JsonDataTy]
+
+    # All known entries from LocalAutotuneCache.put()
+    _entries: dict[str, JsonDataTy]
+
+    def end_compile(self) -> None:
+        # TODO: Do we need to compute time_taken_ms and encode that somehow?
+        if self._entries:
+            self._cache.put(self._key, self._entries)
+
+    def put(self, basename: str, data: JsonDataTy) -> None:
+        # Do we need to worry about duplicates? We only have a single local fs
+        # entry - so probably not.
+        self._entries[basename] = data
+
+    def __init__(self, key: str, cache: RemoteCache[JsonDataTy]) -> None:
+        self._key = key
+        self._cache = cache
+        self._entries = {}
+
+    def sync(self) -> None:
+        # We don't currently use this - but we could async load starting at
+        # `begin_compile` and wait for the load to be finished here.
+        pass
+
+    @classmethod
+    def _should_use_bundled_autotune_remote_cache(
+        cls, inductor_meta: _InductorMetaTy
+    ) -> bool:
+        # The bundled autotune cache is only available if you've also got local
+        # caching enabled (because we feed the bundled data to the local cache).
+        if not inductor_meta.get("autotune_local_cache", True):
+            return False
+
+        # Check if the we're enabled via config
+        if (
+            bundled_autotune_remote_cache := inductor_meta.get(
+                "bundled_autotune_remote_cache"
+            )
+        ) is not None:
+            return bool(bundled_autotune_remote_cache)
+
+        if not cls._get_is_fbcode(inductor_meta):
+            return False
+        if torch._utils_internal.is_fb_unit_test():
+            return False
+        if inductor_meta.get("is_hip"):
+            return False
+
+        try:
+            from torch._inductor.fb.remote_cache import REMOTE_CACHE_VERSION
+        except ModuleNotFoundError:
+            return False
+
+        jk = torch._utils_internal.justknobs_getval_int(
+            "pytorch/remote_cache:bundled_autotune_remote_cache_version"
+        )
+        return REMOTE_CACHE_VERSION >= jk
+
+    def _load_cache(self) -> bool:
+        from torch._inductor import codecache
+
+        # The single key is defined on construction of the cache.
+        entries = self._cache.get(self._key)
+        if entries is None or not isinstance(entries, dict):
+            # We couldn't load the cache - so mark _entries as non-None so we
+            # store local cache values.
+            return False
+
+        # Go through the entries we got from the cache and save them locally.
+        time_saved_ns = 0
+        for basename, data in entries.items():
+            # Reconstruct the final filename (see put())
+            root, ext = _splitext_nodot(basename)
+            _, _, filename = codecache.get_path(root, ext)
+            if isinstance(data, dict) and (tsns := data.get("time_saved_ns")):
+                time_saved_ns += int(tsns)  # type: ignore[arg-type]
+            local_cache = LocalAutotuneCache()
+            local_cache.put(filename, data)
+
+        codecache.add_ephemeral_timeout_increase_for_distributed(time_saved_ns)
+
+        return True
+
+    @staticmethod
+    def _get_is_fbcode(inductor_meta: _InductorMetaTy) -> bool:
+        return bool(inductor_meta.get("is_fbcode", False))
+
+    @staticmethod
+    def _get_backend_hash(inductor_meta: _InductorMetaTy) -> str:
+        backend_hash = inductor_meta["backend_hash"]
+        assert isinstance(backend_hash, str)
+        return backend_hash
+
+
+class AutotuneCacheBundler:
+    _bundler: _AutotuneCacheBundlerImpl | None = None
+
+    def __init__(self) -> None:
+        pass
+
+    # Call this before we start any autotune computation for an inductor python
+    # file. On a cache hit it copies the individual results into the local
+    # autotune caches.
+    @classmethod
+    def begin_compile(
+        cls,
+        inductor_meta: _InductorMetaTy,
+        *,
+        code: str | None = None,
+        code_hash: str | None = None,
+    ) -> None:
+        assert cls._bundler is None
+
+        if code is not None:
+            assert code_hash is None, "Cannot specify both code and code_hash"
+            code_hash = _comment_stripped_hash(code)
+        assert code_hash is not None
+
+        if not _AutotuneCacheBundlerImpl._should_use_bundled_autotune_remote_cache(
+            inductor_meta
+        ):
+            return
+
+        cache = create_cache(
+            "bundled-autotune-v1",
+            _AutotuneCacheBundlerImpl._get_is_fbcode(inductor_meta),
+            "FbRemoteBundledAutotuneCache",
+            "RemoteBundledAutotuneCache",
+        )
+        if not cache:
+            return
+
+        # We're starting a compilation phase. We have a cache key for the code
+        # we're compiling. We'll get the individual autotune bundles later (via
+        # self.put()). For now create the AutotuneCacheBundler and try to load
+        # from the cache.
+
+        salt = "bundled-autotune-best-configs-v1"
+        backend_hash = _AutotuneCacheBundlerImpl._get_backend_hash(inductor_meta)
+        # TODO: The autotune cache includes configs_hash in the key. The problem
+        # is that the configs_hash includes info from the individual pointwise()
+        # calls (size_hints, for example) which we can't know yet. I *think*
+        # that info is basically present in the `code_hash` (since it's a
+        # parameter to the pointwise decorator) - but is there other info we
+        # need to include from inductor_meta?
+        key = code_hash + backend_hash + salt
+        key = hashlib.sha256(key.encode("utf-8")).hexdigest()
+
+        bundler = _AutotuneCacheBundlerImpl(key, cache)
+        if not bundler._load_cache():
+            # We couldn't load from the cache - so save the data so we can store
+            # the saved autotunes.
+            cls._bundler = bundler
+
+        # If we get a cache hit don't bother saving any of the individual
+        # autotune results.
+
+    # Call this after all individual autotune results are finished for a
+    # inductor python file. If we gathered any individual results then we bundle
+    # those and put it into the cache.
+    @classmethod
+    def end_compile(cls) -> None:
+        if bundler := cls._bundler:
+            cls._bundler = None
+            bundler.end_compile()
+
+    @classmethod
+    def sync(cls) -> None:
+        if bundler := cls._bundler:
+            bundler.sync()
+
+    @classmethod
+    def put(cls, filename: str, data: JsonDataTy) -> None:
+        if bundler := cls._bundler:
+            # The filename comes in as something like
+            # "/tmp/tmp{random}/{aa}/{basename}.py" (where aa is
+            # basename[1:3]). Strip it down and make sure that it looks like a path
+            # we could reconstruct (because it's possible for the caller to
+            # customize the path).
+            basename = os.path.basename(filename)
+
+            # TODO: check cache_dir() vs filename, then strip dirname
+            bundler.put(basename, data)
+
+
+# Remove the comments from the code (which include things like run ids and file
+# paths) and then hash the result.
+def _comment_stripped_hash(code: str) -> str:
+    code = re.sub(r"#.*$", "", code, count=0, flags=re.MULTILINE)
+    return torch._inductor.codecache.code_hash(code)
+
+
+def _should_use_remote_autotune_cache(inductor_meta: _InductorMetaTy) -> bool:
+    if (config := inductor_meta.get("autotune_remote_cache")) is not None:
+        return bool(config)
+    if not inductor_meta.get("is_fbcode"):
+        return False
+    if torch._utils_internal.is_fb_unit_test():
+        return False
+    if inductor_meta.get("is_hip"):
+        return False
+
+    try:
+        from torch._inductor.fb.remote_cache import REMOTE_CACHE_VERSION
+    except ModuleNotFoundError:
+        return False
+
+    return REMOTE_CACHE_VERSION >= torch._utils_internal.justknobs_getval_int(
+        "pytorch/remote_cache:autotune_memcache_version"
+    )
+
+
+def _load_cached_autotuning(
+    best_config: dict[str, JsonDataTy],
+    configs_hash: str,
+    configs: list[Config],
+    inductor_meta: _InductorMetaTy,
+) -> Config | None:
+    if best_config is None:
+        return None
+    if best_config.pop("configs_hash", None) != configs_hash:
+        return None
+
+    # Remove time taken for comparison
+    best_config.pop("time_taken_ms", None)
+
+    best_config.pop("triton_cache_hash", None)
+
+    if inductor_meta.get("coordinate_descent_tuning") and best_config.pop(
+        "found_by_coordesc", False
+    ):
+        num_warps = best_config.pop("num_warps")
+        num_stages = best_config.pop("num_stages")
+
+        # Extract common arguments
+        config_args = {
+            "num_warps": num_warps,
+            "num_stages": num_stages,
+        }
+
+        if HAS_WARP_SPEC:
+            config_args.update(
+                {
+                    "num_consumer_groups": best_config.pop("num_consumer_groups", 0),
+                    "num_buffers_warp_spec": best_config.pop(
+                        "num_buffers_warp_spec", 0
+                    ),
+                }
+            )
+
+        # Create the triton_config with the appropriate arguments
+        # pyrefly: ignore [bad-argument-count]
+        triton_config = Config(best_config, **config_args)
+        # pyrefly: ignore [missing-attribute]
+        triton_config.found_by_coordesc = True
+        return triton_config
+
+    matching_configs = [
+        cfg
+        for cfg in configs
+        # pyrefly: ignore [missing-attribute]
+        if all(val == best_config.get(key) for key, val in cfg.kwargs.items())
+        # pyrefly: ignore [missing-attribute]
+        and cfg.num_warps == best_config.get("num_warps")
+        # pyrefly: ignore [missing-attribute]
+        and cfg.num_stages == best_config.get("num_stages")
+    ]
+    if len(matching_configs) != 1:
+        return None
+
+    return matching_configs[0]
+
+
+class _LocalAutotuneCacheBackend(RemoteCacheBackend[bytes]):
+    @override
+    def _get(self, key: str) -> bytes | None:
+        try:
+            with open(key, "rb") as fd:
+                return fd.read()
+        except FileNotFoundError:
+            return None
+
+    @override
+    def _put(self, key: str, data: bytes) -> None:
+        os.makedirs(os.path.dirname(key), exist_ok=True)
+        from torch._inductor import codecache
+
+        codecache.write_atomic(key, data)
+
+
+class LocalAutotuneCache(RemoteCache[JsonDataTy]):
+    def __init__(self) -> None:
+        backend = _LocalAutotuneCacheBackend()
+        serde = RemoteCacheJsonSerde()
+        super().__init__(backend, serde)
+
+    @override
+    def _get(self, key: str, sample: Sample | None) -> JsonDataTy | None:
+        AutotuneCacheBundler.sync()
+        result = super()._get(key, sample)
+        if result is not None:
+            assert isinstance(result, dict)
+            # What? Why are we doing a put() here? Imagine we have a new model
+            # that reuses some existing kernels that have already been
+            # compiled. If we didn't do a `put` here (on cache hit) then the new
+            # model would only bundle *newly* compiled kernels, not existing
+            # kernels that were already compiled and cached.
+            AutotuneCacheBundler.put(key, result)
+            autotune_artifact_key = os.path.join(*key.split(os.sep)[-2:])
+            CacheArtifactManager.record_artifact(
+                AutotuneCacheArtifact.type(), autotune_artifact_key, result
+            )
+        return result
+
+    @override
+    def _put(self, key: str, value: JsonDataTy, sample: Sample | None) -> None:
+        AutotuneCacheBundler.put(key, value)
+        super()._put(key, value, sample)
+
+
+def _splitext_nodot(basename: str) -> tuple[str, str]:
+    root, ext = os.path.splitext(basename)
+    if ext:
+        ext = ext[1:]
+    return root, ext
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/benchmarking.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/benchmarking.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfa33f66ef3a4441613eedabe25731ca2edc25fa
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/benchmarking.py
@@ -0,0 +1,441 @@
+import functools
+import inspect
+import time
+from collections.abc import Callable
+from functools import cached_property, wraps
+from itertools import chain
+from statistics import median
+from typing import Any, Concatenate, Optional, Union
+from typing_extensions import ParamSpec, Self, TypeVar
+
+import torch
+import torch.utils._pytree as pytree
+from torch._dynamo.utils import counters, dynamo_timed
+from torch._inductor.config import use_experimental_benchmarker
+from torch.utils._debug_mode import DebugMode
+
+
+logger = torch._logging.getArtifactLogger(__name__, "benchmarking")
+use_experimental_benchmarker = (
+    use_experimental_benchmarker and torch.cuda.is_available()
+)
+
+
+MILLISECONDS_PER_SECOND = 1000
+
+P = ParamSpec("P")
+T = TypeVar("T")
+
+
+def may_distort_benchmarking_result(fn: Callable[..., Any]) -> Callable[..., Any]:
+    from torch._inductor import config
+
+    if config.test_configs.distort_benchmarking_result == "":
+        return fn
+
+    def distort(
+        ms: list[float] | tuple[float, ...] | float,
+    ) -> list[float] | tuple[float, ...] | float:
+        if isinstance(ms, (list, tuple)):
+            return type(ms)(distort(val) for val in ms)  # type: ignore[misc]
+
+        distort_method = config.test_configs.distort_benchmarking_result
+        assert isinstance(ms, float)
+        if distort_method == "inverse":
+            return 1.0 / ms if ms else 0.0
+        elif distort_method == "random":
+            import random
+
+            return random.random()
+        else:
+            raise RuntimeError(f"Unrecognized distort method {distort_method}")
+
+    @functools.wraps(fn)
+    def wrapper(
+        *args: list[Any], **kwargs: dict[str, Any]
+    ) -> list[float] | tuple[float, ...] | float:
+        ms = fn(*args, **kwargs)
+
+        return distort(ms)
+
+    return wrapper
+
+
+def may_ban_benchmarking() -> None:
+    if torch._inductor.config.deterministic:
+        raise RuntimeError("""In the deterministic mode of Inductor, we will avoid those
+        benchmarkings that would cause non deterministic results. Only benchmarkings in the vetted
+        scenarios are allowed. Example include autotuning for triton configs of pointwise kernels.
+
+        When you see this exception, you can do one of the following two things:
+        1. if the benchmarking you are doing does not introduce any non-determinism, you can just
+        add is_vetted_benchmarking=True to you benchmark_gpu call. That would solve the issue.
+
+        2. if the benchmarking you are doing indeed introduces non-determinism, you'll need to disable
+        such feature in deterministic mode or find an alternative implementation that is deterministic.
+        """)
+
+
+def time_and_count(
+    fn: Callable[Concatenate[Any, P], T],
+) -> Callable[Concatenate[Any, P], T]:
+    """Wraps `fn` with `dynamo_timed` context, and increments the appropriate dynamo
+    counters. It is expected that `fn` is a method of `Benchmarker` or one of its
+    subclasses; typing limitations prevent us from declaring this directly.
+    """
+
+    @wraps(fn)
+    def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> T:
+        fn_qual_name = f"{self.__class__.__name__}.{fn.__name__}"
+        counters["inductor"][f"benchmarking.{fn_qual_name}"] += 1
+        with dynamo_timed(fn_qual_name, log_pt2_compile_event=False):
+            return fn(self, *args, **kwargs)
+
+    return wrapper
+
+
+class Benchmarker:
+    """
+    A device-agnostic benchmarking utility for measuring the runtime of
+    inductor generated callables.
+    """
+
+    def __init__(self: Self) -> None:
+        pass
+
+    def infer_device(self, *fn_args: Any, **fn_kwargs: Any) -> torch.device:
+        inferred_device: Optional[torch.device] = None
+        for arg_or_kwarg in chain(fn_args, fn_kwargs.values()):
+            # Some callables take nested structures as arguments so use the
+            # flattened form to find any tensors
+            for arg_or_kwarg_leaf in pytree.tree_leaves(arg_or_kwarg):
+                if not isinstance(arg_or_kwarg_leaf, torch.Tensor):
+                    continue
+                if inferred_device is None:
+                    inferred_device = arg_or_kwarg_leaf.device
+                elif arg_or_kwarg_leaf.device != inferred_device:
+                    raise ValueError(
+                        "Can't safely infer the device type of `fn` with multiple device types in `fn_args` and `fn_kwargs`!"
+                    )
+
+        if inferred_device is None:
+            raise ValueError(
+                "Can't safely infer the device type of `fn` with no device types"
+                " in `fn_args` or `fn_kwargs`. Use a direct benchmarking method instead e.g. "
+                "`Benchmarker.benchmark_cpu` or `Benchmarker.benchmark_gpu`."
+            )
+
+        return inferred_device
+
+    @time_and_count
+    def benchmark(
+        self: Self,
+        fn: Callable[..., Any],
+        fn_args: Optional[tuple[Any, ...]] = None,
+        fn_kwargs: Optional[dict[str, Any]] = None,
+        device: Optional[Union[str, torch.device]] = None,
+        **kwargs: Any,
+    ) -> float:
+        """Benchmark `fn(*fn_args, *fn_kwargs)` and return the runtime, in milliseconds (the
+        actual runtime calculation is dictated by the benchmarking implementation, but may be
+        one of [mean, median, minimum, etc.]). Functions as a convenience wrapper around
+        device-specific implementations, like `benchmark_cpu` and `benchmark_gpu`. Raises
+        `ValueError(...)` if we can't safely infer the device type of `fn`; for example,
+        if multiple device types are found in `fn_args` and `fn_kwargs`, or if no device
+        types are found. To bypass device inference, provide the device to the `device`
+        parameter.
+
+        WARNING: if `fn` mutates `fn_args` or `fn_kwargs`, benchmarking may fail unexpectedly.
+        For example, if `fn` clears a mutable object, subsequent invocations of `fn` during
+        benchmarking will fail. In such cases, `fn` should handle cloning its arguments internally.
+        If device inference is required, `Benchmarker.infer_device` can be used prior to calling
+        this method without any arguments for `fn_args` and `fn_kwargs`.
+
+        Arguments:
+        - fn: The function to benchmark.
+        - fn_args: The function's arguments.
+        - fn_kwargs: The function's kwargs.
+
+        Keyword Arguments:
+        - device: Which device to use for benchmarking. If not provided the device will be attempted
+        to be inferred from `fn_args` and `fn_kwargs`.
+        - **kwargs: The benchmarking implementation's kwargs.
+
+        Returns:
+        - The runtime of `fn(*fn_args, **fn_kwargs)`, in milliseconds.
+        """
+        inferred_device: Optional[torch.device] = None
+        if device is not None:
+            inferred_device = (
+                torch.device(device) if isinstance(device, str) else device
+            )
+        else:
+            if fn_args is None and fn_kwargs is None:
+                raise ValueError(
+                    "`fn_args` and `fn_kwargs` cannot both be None if `device` is not provided."
+                )
+
+            fn_args = fn_args or tuple()
+            fn_kwargs = fn_kwargs or {}
+            inferred_device = self.infer_device(*fn_args, **fn_kwargs)
+
+        assert isinstance(inferred_device, torch.device)
+
+        fn_args = fn_args or tuple()
+        fn_kwargs = fn_kwargs or {}
+
+        # No need to wrap if the callable takes no arguments
+        if len(fn_args) == 0 and len(fn_kwargs) == 0:
+            _callable = fn
+        else:
+            _callable = lambda: fn(*fn_args, **fn_kwargs)  # noqa: E731
+
+        # Surfacing all kernels during autotuning is super noisy; filtering these out.
+        with DebugMode._benchmarking_inductor():
+            if inferred_device == torch.device("cpu"):
+                return self.benchmark_cpu(_callable, **kwargs)
+            # TODO(nmacchioni): For non-CPU functions we default to using the GPU-specific benchmarking
+            # implementation which was written specifically with CUDA devices in mind, we may want to
+            # explore alternate implementations for other device types.
+            return self.benchmark_gpu(_callable, **kwargs)
+
+    @time_and_count
+    def benchmark_cpu(
+        self: Self, _callable: Callable[[], Any], warmup: int = 20, rep: int = 100
+    ) -> float:
+        """Benchmark the CPU callable, `_callable`, and return the median runtime,
+        in milliseconds.
+
+        Arguments:
+        - _callable: The CPU callable to benchmark.
+
+        Keyword Arguments:
+        - warmup: Optionally, the duration, in milliseconds, to run `_callable`
+        before benchmarking starts.
+        - rep: Optionally, the duration, in milliseconds, to run `_callable`
+        during benchmarking.
+
+        Returns:
+        - The median runtime of `_callable`, in milliseconds.
+        """
+
+        def run_for(ms: int) -> list[float]:
+            timings = []
+            run_start_t = time.perf_counter()
+            while True:
+                start_t = time.perf_counter()
+                _callable()
+                end_t = time.perf_counter()
+                timings.append((end_t - start_t) * MILLISECONDS_PER_SECOND)
+                if ((end_t - run_start_t) * MILLISECONDS_PER_SECOND) > ms:
+                    break
+            return timings
+
+        run_for(warmup)
+        return median(run_for(rep))
+
+    @time_and_count
+    def benchmark_gpu(self: Self, *args: Any, **kwargs: Any) -> float:
+        raise NotImplementedError
+
+
+class TritonBenchmarker(Benchmarker):
+    @cached_property
+    def triton_do_bench(self: Self) -> Callable[..., Any]:
+        """Lazily import Triton's `do_bench`."""
+        try:
+            from triton.testing import do_bench
+        except ImportError as e:
+            raise NotImplementedError("requires Triton") from e
+        return do_bench
+
+    @may_distort_benchmarking_result
+    @time_and_count
+    # pyrefly: ignore [bad-override]
+    def benchmark_gpu(
+        self: Self,
+        _callable: Callable[[], Any],
+        is_vetted_benchmarking: bool = False,
+        **kwargs: Any,
+    ) -> float:
+        """Benchmark the GPU callable, `_callable`, and return the runtime, in milliseconds.
+
+        Arguments:
+        - _callable: The GPU callable to benchmark.
+
+        Keyword Arguments:
+        - quantiles: Optionally, a tuple of floats denoting the requested quantiles.
+        - return_mode: Optionally, the requested return mode. Currently, Triton's
+        `do_bench` supports min, max, mean, and median return modes.
+        - **kwargs: Additional kwargs passed to Triton's `do_bench`.
+
+        Returns:
+        - The runtime of `callable`, in milliseconds. If `kwargs["quantiles"]` is specified,
+        this is the first requested quantile. Else, if `kwargs["return_mode"]` is specified,
+        this is the requested return mode. Otherwise, this is the median.
+        """
+        if not is_vetted_benchmarking:
+            may_ban_benchmarking()
+
+        do_bench_params = inspect.signature(self.triton_do_bench).parameters
+        for kwarg in list(kwargs.keys()):
+            if kwarg not in do_bench_params:
+                del kwargs[kwarg]
+        if "quantiles" in kwargs:
+            return self.triton_do_bench(_callable, **kwargs)[0]
+        elif "return_mode" in kwargs:
+            return self.triton_do_bench(_callable, **kwargs)
+        return self.triton_do_bench(_callable, **kwargs, return_mode="median")
+
+
+class InductorBenchmarker(TritonBenchmarker):  # noqa: docstring_linter
+    @cached_property
+    def L2_cache_size(self: Self) -> int:
+        """Get the L2 cache size, in bytes, of the current device."""
+        device = torch.cuda.current_device()
+        props = torch.cuda.get_device_properties(device)
+        return props.L2_cache_size
+
+    def get_event_pairs(
+        self: Self, iters: int
+    ) -> list[tuple[torch.cuda.Event, torch.cuda.Event]]:
+        """Get `iters` pairs of CUDA events."""
+        return [
+            (
+                torch.cuda.Event(enable_timing=True),
+                torch.cuda.Event(enable_timing=True),
+            )
+            for _ in range(iters)
+        ]
+
+    def get_event_pairs_min_timing(
+        self: Self, event_pairs: list[tuple[torch.cuda.Event, torch.cuda.Event]]
+    ) -> float:
+        """Get the minimum timing, in milliseconds, for a group of CUDA event pairs."""
+        return min(
+            [
+                start_event.elapsed_time(end_event)
+                for start_event, end_event in event_pairs
+            ]
+        )
+
+    @may_distort_benchmarking_result
+    @time_and_count
+    def benchmark_gpu(  # type: ignore[override]
+        self: Self,
+        _callable: Callable[[], Any],
+        estimation_iters: int = 5,
+        memory_warmup_iters: int = 100,
+        benchmark_iters: int = 100,
+        max_benchmark_duration: int = 25,
+        return_mode: str = "min",
+        grad_to_none: list[torch.Tensor] | None = None,
+        is_vetted_benchmarking: bool = False,
+        **kwargs: Any,
+    ) -> float | list[float]:
+        """Benchmark a GPU callable using a custom benchmarking implementation.
+
+        Arguments:
+        - _callable: The callable to benchmark.
+
+        Keyword Arguments:
+        - estimation_iters: Optionally, the number of iterations to run `_callable`
+        during runtime estimation.
+        - memory_warmup_iters: Optionally, the number of iterations to flush the L2
+        cache before starting benchmarking.
+        - benchmark_iters: Optionally, the number of iterations to run `_callable`
+        during the benchmarking.
+        - max_benchmark_duration: Optionally, the maximum duration of the benchmarking,
+        in milliseconds. An estimated duration is calculated based on the values
+        of `memory_warmup_iters` and `benchmark_iters`, along with the estimated
+        runtime of `_callable` and various other factors, and we then shrink
+        `benchmark_iters` to fit in the allotted maximum duration.
+        - return_mode: Return mode for benchmark results. Options are "min" (default),
+        "all" (returns all measurements).
+        - grad_to_none: Optionally, a list of tensors whose gradients should be cleared
+        before each benchmark iteration.
+        - is_vetted_benchmarking: in deterministic mode, we only allow
+        benchmarking in vetted cases.
+        - **kwargs: Additional kwargs that may be passed to the fallback.
+
+        Returns:
+        - If return_mode="min": The minimum runtime of `_callable`, in milliseconds.
+        - If return_mode="all": List of all runtime measurements, in milliseconds.
+        """
+
+        if not is_vetted_benchmarking:
+            may_ban_benchmarking()
+
+        # we don't want any outside errors propagating into benchmarking
+        torch.cuda.synchronize()
+
+        # warmup `_callable` (and catches any failures in the process)
+        _callable()
+        torch.cuda.synchronize()
+
+        # see https://github.com/triton-lang/triton/pull/840 for why `dtype=torch.int`
+        buffer = torch.empty(self.L2_cache_size // 4, dtype=torch.int, device="cuda")
+        buffer.zero_()
+
+        # estimate the runtime of `_callable`
+        event_pairs = self.get_event_pairs(estimation_iters)
+        for start_event, end_event in event_pairs:
+            # Clear gradients before timing (matches triton.testing.do_bench)
+            if grad_to_none is not None:
+                for x in grad_to_none:
+                    x.grad = None
+            buffer.zero_()
+            start_event.record()
+            _callable()
+            end_event.record()
+        torch.cuda.synchronize()
+        estimated_timing = self.get_event_pairs_min_timing(event_pairs)
+
+        # adjust `benchmark_iters` to fit in the maximum benchmarking duration
+        benchmark_iters = max(
+            min(benchmark_iters, int(max_benchmark_duration // estimated_timing)), 1
+        )
+
+        # do the memory warmup
+        for _ in range(memory_warmup_iters):
+            buffer.zero_()
+
+        # benchmark `_callable`
+        event_pairs = self.get_event_pairs(benchmark_iters)
+        for start_event, end_event in event_pairs:
+            # Clear gradients before timing (matches triton.testing.do_bench)
+            if grad_to_none is not None:
+                for x in grad_to_none:
+                    x.grad = None
+            buffer.zero_()
+            start_event.record()
+            _callable()
+            end_event.record()
+        torch.cuda.synchronize()
+
+        # explicitly delete the buffer, sometimes helps memory
+        # footprint metrics in OSS Inductor performance benchmarks
+        del buffer
+
+        # Return based on the requested mode
+        if return_mode == "all":
+            # Get all timings from event pairs
+            all_timings = [
+                start_event.elapsed_time(end_event)
+                for start_event, end_event in event_pairs
+            ]
+            return all_timings
+        elif return_mode == "min":
+            benchmarked_timing = self.get_event_pairs_min_timing(event_pairs)
+            # return the minimum of `estimated_timing` and `benchmarked_timing`,
+            # we just want the minimum timing overall so we might as well check both
+            return min(estimated_timing, benchmarked_timing)
+        else:
+            raise ValueError(
+                f"Unsupported return_mode: {return_mode}. Use 'min' or 'all'."
+            )
+
+
+benchmarker = (
+    InductorBenchmarker() if use_experimental_benchmarker else TritonBenchmarker()
+)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/cache_dir_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/cache_dir_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..34b84a68f6300c1709593e303ff2a07e1f50bc46
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/cache_dir_utils.py
@@ -0,0 +1,54 @@
+import getpass
+import os
+import re
+import tempfile
+from collections.abc import Generator
+from contextlib import contextmanager
+
+from torch._environment import is_fbcode
+
+
+# Factoring out to file without torch dependencies
+
+
+def cache_dir() -> str:
+    cache_dir = os.environ.get("TORCHINDUCTOR_CACHE_DIR")
+    if cache_dir is None:
+        os.environ["TORCHINDUCTOR_CACHE_DIR"] = cache_dir = default_cache_dir()
+    os.makedirs(cache_dir, exist_ok=True)
+    return cache_dir
+
+
+def default_cache_dir() -> str:
+    sanitized_username = re.sub(r'[\\/:*?"<>|]', "_", getpass.getuser())
+    return os.path.join(
+        tempfile.gettempdir() if not is_fbcode() else "/var/tmp",
+        "torchinductor_" + sanitized_username,
+    )
+
+
+def triton_cache_dir(device: int) -> str:
+    if (directory := os.getenv("TRITON_CACHE_DIR")) is not None:
+        return directory
+    return os.path.join(
+        cache_dir(),
+        "triton",
+        str(device),
+    )
+
+
+@contextmanager
+def temporary_cache_dir(directory: str) -> Generator[None, None, None]:
+    from torch._inductor.utils import clear_caches
+
+    original = os.environ.get("TORCHINDUCTOR_CACHE_DIR")
+    os.environ["TORCHINDUCTOR_CACHE_DIR"] = directory
+    try:
+        clear_caches()
+        yield
+    finally:
+        clear_caches()
+        if original is None:
+            del os.environ["TORCHINDUCTOR_CACHE_DIR"]
+        else:
+            os.environ["TORCHINDUCTOR_CACHE_DIR"] = original
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/compile_tasks.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/compile_tasks.py
new file mode 100644
index 0000000000000000000000000000000000000000..11801eac925848eb1e0969d587e8fcc98484a1cd
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/compile_tasks.py
@@ -0,0 +1,78 @@
+from __future__ import annotations
+
+import functools
+import linecache
+import os
+import sys
+import time
+import warnings
+from pathlib import Path
+from types import ModuleType
+from typing import Any, TYPE_CHECKING
+
+from torch._utils_internal import log_triton_builds
+
+
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
+    from torch._inductor.runtime.triton_heuristics import CachingAutotuner
+
+
+def _reload_python_module(
+    key: str, path: str, set_sys_modules: bool = True
+) -> ModuleType:
+    with open(path) as f:
+        try:
+            code = compile(f.read(), path, "exec", dont_inherit=True)
+        except Exception as e:
+            raise RuntimeError(
+                f"Failed to import {path}\n{type(e).__name__}: {e}"
+            ) from None
+        mod = ModuleType(f"{__name__}.{key}")
+        mod.__file__ = path
+        mod.key = key  # type: ignore[attr-defined]
+        exec(code, mod.__dict__, mod.__dict__)
+        if set_sys_modules:
+            sys.modules[mod.__name__] = mod
+        return mod
+
+
+@functools.cache
+def _set_triton_ptxas_path() -> None:
+    if os.environ.get("TRITON_PTXAS_PATH") is not None:
+        return
+    ptxas = Path(__file__).absolute().parents[2] / "bin" / "ptxas"
+    if not ptxas.exists():
+        return
+    if ptxas.is_file() and os.access(ptxas, os.X_OK):
+        os.environ["TRITON_PTXAS_PATH"] = str(ptxas)
+    else:
+        warnings.warn(f"{ptxas} exists but is not an executable")
+
+
+def _worker_compile_triton(
+    load_kernel: Callable[[], CachingAutotuner],
+    extra_env: dict[str, str],
+    extra_config: dict[str, Any],
+) -> tuple[CachingAutotuner, int]:
+    _set_triton_ptxas_path()
+    os.environ.update(extra_env)
+    from torch._inductor import config
+
+    with config.patch(extra_config):
+        fail = None
+        try:
+            start_ns = time.time_ns()
+            kernel = load_kernel()
+            kernel.precompile(warm_cache_only=True)
+            elapsed_ns = time.time_ns() - start_ns
+            kernel.prepare_for_pickle()
+            # We can release this memory in the compile subprocesses:
+            linecache.clearcache()
+            return kernel, elapsed_ns // 1000
+        except Exception as e:
+            fail = str(e)
+            raise
+        finally:
+            log_triton_builds(fail=fail)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/coordinate_descent_tuner.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/coordinate_descent_tuner.py
new file mode 100644
index 0000000000000000000000000000000000000000..91736febd29f61106fa5bb26d896941952582091
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/coordinate_descent_tuner.py
@@ -0,0 +1,412 @@
+# mypy: allow-untyped-defs
+import copy
+import itertools
+import logging
+from collections.abc import Callable
+from typing import TYPE_CHECKING
+
+from torch.utils._ordered_set import OrderedSet
+
+from ..utils import get_max_numwarps
+from .hints import TRITON_MAX_BLOCK
+from .runtime_utils import red_text, triton_config_to_hashable
+
+
+if TYPE_CHECKING:
+    from .triton_compat import triton
+
+
+log = logging.getLogger(__name__)
+
+
+def get_field(config, name):
+    if name == "num_warps":
+        return config.num_warps
+    elif name == "num_stages":
+        return config.num_stages
+    elif name == "waves_per_eu":
+        return config.kwargs.get(name, int(8 // config.num_warps))
+    else:
+        return config.kwargs.get(name, None)
+
+
+def set_field(config, name, value):
+    if name == "num_warps":
+        config.num_warps = value
+    elif name == "num_stages":
+        config.num_stages = value
+    else:
+        config.kwargs[name] = value
+
+
+class CoordescTuner:
+    """
+    The coordinate descent tuner. Tune one field/coordinate at a time.
+
+    TODO will it be necessary to tune multiple fields simultaneously.
+
+
+    TODO: what if both increasing and decreasing a field can improve perf.
+          i.e., there are multiple local optima..
+    """
+
+    def __init__(
+        self,
+        is_mm=False,
+        is_native_matmul=False,
+        is_mix_order_reduction=False,
+        name="unknown",
+        size_hints=None,
+        inductor_meta=None,
+        frozen_fields=None,
+    ):
+        self.is_mm = is_mm  # we will tune num_stages for mm
+
+        # Native matmul codegen assumes ZBLOCK=1 always.
+        # This is because 3d tl.dot is slow and so we want to tile y and x only.
+        # tl.dot also does not support size smaller than 16; we put this restriction.
+        self.is_native_matmul = is_native_matmul
+        assert not (self.is_mm and self.is_native_matmul)
+        self.is_mix_order_reduction = is_mix_order_reduction
+        self.cached_benchmark_results = {}
+        self.name = name
+        self.size_hints = size_hints
+        self.inductor_meta = inductor_meta or {}
+        self.frozen_fields: OrderedSet[str] = (
+            OrderedSet(frozen_fields) if frozen_fields is not None else OrderedSet()
+        )
+
+    def get_config_max(self, prefix: str) -> int:
+        max_block = TRITON_MAX_BLOCK[prefix.upper()]
+        size_hint = self.size_hints.get(prefix) if self.size_hints is not None else None
+        return min(max_block, size_hint) if size_hint is not None else max_block
+
+    def get_warpsmax(self):
+        # Avoid querying device directly if device properties are populated in inductor_meta
+        warp_size = self.inductor_meta.get("warp_size")
+        max_threads_per_block = self.inductor_meta.get("max_threads_per_block")
+        if warp_size and max_threads_per_block:
+            return max_threads_per_block // warp_size
+        else:
+            return get_max_numwarps()
+
+    def cache_benchmark_result(self, config, timing):
+        self.cached_benchmark_results[triton_config_to_hashable(config)] = timing
+
+    def lookup_in_cache(self, config):
+        return self.cached_benchmark_results.get(triton_config_to_hashable(config))
+
+    def call_func(self, func, config):
+        found = self.lookup_in_cache(config)
+        if found is not None:
+            log.debug("  CACHED")
+            return found
+        timing = func(config)
+        self.cache_benchmark_result(config, timing)
+        return timing
+
+    @property
+    def tunable_fields(self):
+        out = [
+            "XBLOCK",
+            "YBLOCK",
+            "ZBLOCK",
+            # NOTE: we should not tune R0_BLOCK for persistent reduction.
+            # We rely on the fact that persistent reduction's triton.Config
+            # does not have the R0_BLOCK field to guarantee that.
+            "R0_BLOCK",
+            "R1_BLOCK",
+            # the following 3 are for mm
+            "BLOCK_M",
+            "BLOCK_N",
+            "BLOCK_K",
+            "num_warps",
+        ]
+        if self.is_mm:
+            out.append("num_stages")
+        if self.inductor_meta.get("is_hip") is True:
+            out.append("waves_per_eu")
+        if self.is_native_matmul:
+            out.append("num_stages")
+            out.remove("ZBLOCK")  # ZBLOCK=1 always in native matmul
+
+        if self.is_mix_order_reduction:
+            # unlike TritonConfig.num_stages, this one is
+            # put in TritonConfig.kwargs["NUM_STAGES"] and is used to
+            # control the stage of pipelining of tl.range.
+            out.append("NUM_STAGES")
+
+        return [f for f in out if f not in self.frozen_fields]
+
+    def value_too_large(self, name: str, val: int) -> bool:
+        block_suffix = "BLOCK"
+        if name.endswith(block_suffix):
+            prefix = name.strip(block_suffix).lower()
+            return val > self.get_config_max(prefix)
+        if name == "num_warps":
+            return val > self.get_warpsmax()
+        if name == "waves_per_eu":
+            return val > 8
+
+        return False
+
+    def value_too_small(self, name: str, val: int) -> bool:
+        # In native matmul, block size should be >= 16 for tl.dot
+        if self.is_native_matmul:
+            if name in ["YBLOCK", "XBLOCK", "R0_BLOCK"]:
+                return val < 16
+
+        # Break if value becomes 0/neg
+        return val <= 0
+
+    def get_neighbour_values(self, name, orig_val, radius=None, include_self=False):
+        """
+        Get neighbour values in 'radius' steps. The original value is not
+        returned as it's own neighbour.
+        """
+        if radius is None:
+            radius = 1
+        if name == "NUM_STAGES":
+            # we see cases that
+            # NUM_STAGES=1 is better than NUM_STAGES=2
+            # while NUM_STAGES=1 is worse than NUM_STAGES=3
+            radius = max(radius, 2)
+
+        assert radius >= 1
+
+        def update(cur_val, inc=True):
+            if name in ["num_stages", "NUM_STAGES"]:
+                if inc:
+                    return cur_val + 1
+                else:
+                    return cur_val - 1
+            else:
+                if inc:
+                    return cur_val * 2
+                else:
+                    return cur_val // 2
+
+        out = []
+        # increment loop
+        cur_val = orig_val
+        for _ in range(radius):
+            cur_val = update(cur_val, True)
+            if self.value_too_large(name, cur_val):
+                break
+            out.append(cur_val)
+
+        # decrement loop
+        cur_val = orig_val
+        for _ in range(radius):
+            cur_val = update(cur_val, False)
+            if self.value_too_small(name, cur_val):
+                break
+            out.append(cur_val)
+
+        if include_self:
+            out.append(orig_val)
+        return out
+
+    @staticmethod
+    def has_improvement(baseline, test):
+        threshold = 0.001  # 0.1%
+        return test is not None and test < baseline * (1 - threshold)
+
+    def is_valid_config(self, config) -> bool:
+        if self.is_mix_order_reduction:
+            # Mix order reduction has an extra constraint that
+            # we should not tune XBLOCK beyond RSPLIT_SIZE
+            xblock = config.kwargs["XBLOCK"]
+            split_size = config.kwargs["RSPLIT_SIZE"]
+            return xblock <= split_size
+        return True
+
+    def check_all_tuning_directions(
+        self,
+        # pyrefly: ignore [missing-attribute]
+        func: Callable[["triton.Config"], float],
+        best_config,
+        best_timing,
+    ):
+        """
+        Check all directions. We only do this once the regular coordinate
+        descent tuning find no better choices any more.
+        We only have a few tunable fields, so this should be fine.
+        """
+        candidate_values_list = []
+        effective_fields = []
+        for field in self.tunable_fields:
+            old_value = get_field(best_config, field)
+            if old_value is None:
+                continue
+            radius = self.inductor_meta.get("coordinate_descent_search_radius", 1)
+            candidate_values = self.get_neighbour_values(
+                field,
+                old_value,
+                radius=radius,
+                include_self=True,
+            )
+            candidate_values_list.append(candidate_values)
+            effective_fields.append(field)
+
+        choices = itertools.product(*candidate_values_list)
+        improved = False
+        for choice in choices:
+            assert len(choice) == len(effective_fields)
+            candidate_config = copy.deepcopy(best_config)
+            for new_val, field in zip(choice, effective_fields):
+                set_field(candidate_config, field, new_val)
+            if not self.is_valid_config(candidate_config):
+                continue
+            cmp_res, candidate_timing = self.compare_config(
+                func, candidate_config, best_config, best_timing
+            )
+            if cmp_res:
+                improved = True
+                best_config = candidate_config
+                best_timing = candidate_timing
+
+        return improved, best_config, best_timing
+
+    def compare_config(self, func, candidate_config, best_config, best_timing):
+        """
+        Check if candidate_config is better than best_config.
+
+        Return a tuple of (compare_result, candidate_timing).
+        compare_result is true iff candidate_config is better.
+        """
+        log.debug("Try config %s", candidate_config)
+        try:
+            candidate_timing = self.call_func(func, candidate_config)
+        except Exception as e:
+            log.debug("Got exception %s", e)  # noqa: G200
+            return False, float("inf")
+
+        if self.has_improvement(best_timing, candidate_timing):
+            log.debug(
+                "Tune from %s %f -> %s %f",
+                best_config,
+                best_timing,
+                candidate_config,
+                candidate_timing,
+            )
+
+            return True, candidate_timing
+        return False, candidate_timing
+
+    def autotune(
+        self,
+        # pyrefly: ignore [missing-attribute]
+        func: Callable[["triton.Config"], float],
+        # pyrefly: ignore [missing-attribute]
+        baseline_config: "triton.Config",
+        baseline_timing: float | None = None,
+    ) -> "triton.Config":  # pyrefly: ignore  # missing-attribute
+        if baseline_timing is None:
+            baseline_timing = self.call_func(func, baseline_config)
+
+        log.debug("= Do coordinate descent tuning for %s =", self.name)
+        log.debug(
+            "%s: Baseline Config %s, baseline timing %f",
+            self.name,
+            baseline_config,
+            baseline_timing,
+        )
+        improved = True
+        best_config = baseline_config
+        best_timing = baseline_timing
+        tunable_fields = self.tunable_fields
+
+        while improved:
+            improved = False
+
+            for name in tunable_fields:
+                cur_val = get_field(best_config, name)
+                # some kernel don't have R0_BLOCK/YBLOCK/ZBLOCK. So cur_val may be None
+                if cur_val is None:
+                    continue
+
+                # It's possible that candidate_values is empty.
+                # E.g., if XBLOCK is 1 initially and size_hint for x is also 1.
+                # We would not try either larger or smaller XBLOCK in this case.
+                candidate_values = self.get_neighbour_values(name, cur_val)
+
+                for next_val in candidate_values:
+                    candidate_config = copy.deepcopy(best_config)
+                    set_field(candidate_config, name, next_val)
+
+                    if not self.is_valid_config(candidate_config):
+                        continue
+                    cmp_res, candidate_timing = self.compare_config(
+                        func, candidate_config, best_config, best_timing
+                    )
+                    if cmp_res:
+                        improved = True
+                        best_config, best_timing = candidate_config, candidate_timing
+
+            if not improved and self.inductor_meta.get(
+                "coordinate_descent_check_all_directions"
+            ):
+                old_best_timing = best_timing
+                improved, best_config, best_timing = self.check_all_tuning_directions(
+                    func, best_config, best_timing
+                )
+
+                if improved:
+                    msg = red_text(
+                        "%s: Coordinate descend tuning found improvement of %.3fx by looking in all directions."
+                    )
+                    log.debug(
+                        msg,
+                        self.name,
+                        old_best_timing / best_timing,
+                    )
+
+        log.debug(
+            "%s: Improve from %s %f -> %s %f, %.3fx",
+            self.name,
+            baseline_config,
+            baseline_timing,
+            best_config,
+            best_timing,
+            baseline_timing / best_timing,
+        )
+
+        return best_config
+
+    @staticmethod
+    def autotune_single_field(fn, init_val, min_val=None, max_val=None):
+        """
+        fn is a function that takes the field value and returns the benchmarking result
+        init_val is the starting point of autotuning.
+
+        Should work well for parabola like curve. Here is a real example
+        for split-size of mix-order-reduction: https://github.com/pytorch/pytorch/pull/166461
+        """
+        cache = {}
+
+        def _bench(val):
+            if val not in cache:
+                cache[val] = fn(val)
+                # print(f"split size {val} -> {cache[val]:.3f} ms")
+            return cache[val]
+
+        if min_val is None:
+            min_val = 1
+        if max_val is None:
+            max_val = 2**30  # some arbitrary large value
+
+        best_val = init_val
+        improved = True
+        while improved:
+            improved = False
+            candlist = [best_val // 2, best_val * 2]
+            for cand in candlist:
+                cand = max(cand, min_val)
+                cand = min(cand, max_val)
+
+                if _bench(cand) < _bench(best_val):
+                    best_val = cand
+                    improved = True
+
+        return best_val
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/debug_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/debug_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c15ff890dda6bc2cf9b541c1d5c8b76939c07ce
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/debug_utils.py
@@ -0,0 +1,138 @@
+import functools
+import logging
+import threading
+import weakref
+
+import torch
+from torch.utils._ordered_set import OrderedSet
+
+
+log = logging.getLogger(__name__)
+
+local = threading.local()
+local.memory_tracker = None
+
+
+class BufferMemoryTracker:
+    """
+    Tracks inductor runtime allocations and deallocations to compare against
+    expected behavior.
+    """
+
+    def __init__(self) -> None:
+        self.tensor_tracker: dict[str, torch.storage.UntypedStorage] = (
+            weakref.WeakValueDictionary()  # type: ignore[assignment]
+        )
+        self.died_since_last_step: OrderedSet[str] = OrderedSet()
+        self.added_since_last_step: OrderedSet[str] = OrderedSet()
+        self.error = (
+            torch._inductor.config.test_configs.track_memory_lifecycle == "assert"
+        )
+
+    def set_tensor(self, name: str, tensor: torch.Tensor) -> None:
+        storage = tensor.untyped_storage()
+
+        self.added_since_last_step.add(name)
+        self.tensor_tracker[name] = storage
+
+        def on_tensor_death() -> None:
+            self.died_since_last_step.add(name)
+
+        weakref.finalize(storage, on_tensor_death)
+
+    def advance_step(self) -> None:
+        self.died_since_last_step.clear()
+        self.added_since_last_step.clear()
+
+    def log_or_raise(self, msg: str) -> None:
+        if self.error:
+            raise RuntimeError(msg)
+        else:
+            log.info(msg)
+
+    def check_step_delta(
+        self,
+        expected_allocated: list[str],
+        expected_freed: list[str],
+        is_final_step: bool,
+    ) -> None:
+        """Check only the delta changes since last step"""
+
+        # Check expected deaths - we dont currently distinguish between nodes which die in last step
+        # and are returned as outputs, so skip if final_step.
+        if not is_final_step:
+            missing_deaths = OrderedSet(expected_freed) - self.died_since_last_step
+            if missing_deaths:
+                self.log_or_raise(
+                    f"Expected tensors to die but still alive: {missing_deaths}"
+                )
+
+        # Check for unexpected deaths
+        unexpected_deaths = self.died_since_last_step - OrderedSet(expected_freed)
+        if unexpected_deaths:
+            self.log_or_raise(f"Unexpected tensor deaths: {unexpected_deaths}")
+
+        # Check newly alive tensors - separate messages like deaths
+        actual_allocated = self.added_since_last_step
+        expected_allocated_set = OrderedSet(expected_allocated)
+
+        extra_alive = actual_allocated - expected_allocated_set
+        if extra_alive:
+            self.log_or_raise(f"Unexpected allocated tensors: {extra_alive}")
+
+        missing_alive = expected_allocated_set - actual_allocated
+        if missing_alive:
+            self.log_or_raise(
+                f"Expected allocated tensors but missing: {missing_alive}"
+            )
+
+        # Reset for next step
+        self.advance_step()
+
+        if is_final_step:
+            local.memory_tracker = None
+
+
+def get_mem_tracker() -> BufferMemoryTracker:
+    if local.memory_tracker is None:
+        local.memory_tracker = BufferMemoryTracker()
+    return local.memory_tracker
+
+
+def track_tensor(tensor: torch.Tensor, name: str) -> None:
+    get_mem_tracker().set_tensor(name, tensor)
+
+
+def tracked_empty_strided(
+    size: list[int],
+    stride: list[int],
+    *,
+    dtype: torch.dtype,
+    device: torch.device,
+    name: str,
+) -> torch.Tensor:
+    o = torch.empty_strided(size, stride, dtype=dtype, device=device)
+    track_tensor(o, name)
+    return o
+
+
+def check_memory_step(
+    allocated: list[str], freed: list[str], is_final_step: bool = False
+) -> None:
+    tracker = get_mem_tracker()
+    tracker.check_step_delta(allocated, freed, is_final_step)
+
+
+@functools.lru_cache(None)
+def register_check_mem_op() -> None:
+    lib = torch.library.Library("_inductor_debug", "FRAGMENT")  # noqa: TOR901
+    lib.define(
+        "check_memory_step(str[] allocated, str[] freed, bool is_final_step) -> ()"
+    )
+    lib.impl("check_memory_step", check_memory_step, "BackendSelect")
+    from torch._higher_order_ops.effects import _EffectType, _register_effectful_op
+
+    _register_effectful_op(
+        torch.ops._inductor_debug.check_memory_step.default,
+        _EffectType.ORDERED,
+    )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/halide_helpers.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/halide_helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4bf70fe9d8db1cb66379df11e025ad84cc0069b
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/halide_helpers.py
@@ -0,0 +1,118 @@
+# mypy: allow-untyped-defs
+try:
+    import halide as hl  # type: ignore[import-untyped, import-not-found]
+except ImportError:
+    hl = None
+
+PHILOX_N_ROUNDS_DEFAULT = 10  # Default number of rounds for philox
+
+if hl is not None:
+    PHILOX_KEY_A_U32 = hl.u32(0x9E3779B9)
+    PHILOX_KEY_B_U32 = hl.u32(0xBB67AE85)
+    PHILOX_ROUND_A_U32 = hl.u32(0xD2511F53)
+    PHILOX_ROUND_B_U32 = hl.u32(0xCD9E8D57)
+else:
+    PHILOX_KEY_A_U32 = None
+    PHILOX_KEY_B_U32 = None
+    PHILOX_ROUND_A_U32 = None
+    PHILOX_ROUND_B_U32 = None
+
+
+def _pair_uniform_to_normal(u1, u2):
+    """Box-Muller transform"""
+    u1 = hl.max(hl.f32(1.0e-7), u1)
+    th = hl.f32(6.283185307179586) * u2
+    r = hl.sqrt(hl.f32(-2.0) * hl.log(u1))
+    return r * hl.cos(th), r * hl.sin(th)
+
+
+def _uint_to_uniform_float(x):
+    """
+    Numerically stable function to convert a random uint into a random float uniformly sampled in [0, 1).
+    """
+
+    # TODO:
+    # conditions can be simplified
+    # scale is ((2**23 - 1) / 2**23) * 2**(N_BITS - 1)
+    # https://github.com/triton-lang/triton/blob/e4a0d93ff1a367c7d4eeebbcd7079ed267e6b06f/python/triton/language/random.py#L116-L132.
+    assert x.type() == hl.UInt(32) or x.type() == hl.Int(32)
+    x = hl.cast(hl.Int(32), x)
+    scale = hl.f64(4.6566127342e-10)
+    x = hl.select(x < 0, -x - 1, x)
+    return x * scale
+
+
+def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds):
+    def umulhi(a, b):
+        a = hl.cast(hl.UInt(64), a)
+        b = hl.cast(hl.UInt(64), b)
+        return hl.cast(hl.UInt(32), ((a * b) >> 32) & hl.u64(0xFFFFFFFF))
+
+    for _ in range(n_rounds):
+        _c0, _c2 = c0, c2
+
+        c0 = umulhi(PHILOX_ROUND_B_U32, _c2) ^ c1 ^ k0
+        c2 = umulhi(PHILOX_ROUND_A_U32, _c0) ^ c3 ^ k1
+        c1 = PHILOX_ROUND_B_U32 * _c2
+        c3 = PHILOX_ROUND_A_U32 * _c0
+        # raise key
+        k0 = k0 + PHILOX_KEY_A_U32
+        k1 = k1 + PHILOX_KEY_B_U32
+
+    return c0, c1, c2, c3
+
+
+def halide_philox(seed, c0, c1, c2, c3, n_rounds):
+    seed = hl.cast(hl.UInt(64), seed)
+
+    assert c0.type().bits() == 32
+
+    seed_hi = hl.cast(hl.UInt(32), (seed >> 32) & hl.u64(0xFFFFFFFF))
+    seed_lo = hl.cast(hl.UInt(32), seed & hl.u64(0xFFFFFFFF))
+
+    return philox_impl(c0, c1, c2, c3, seed_lo, seed_hi, n_rounds)
+
+
+def randint4x(seed, offset, n_rounds):
+    offset = hl.cast(hl.UInt(32), offset)
+    _0 = hl.u32(0)
+    return halide_philox(seed, offset, _0, _0, _0, n_rounds)
+
+
+def rand4x(seed, offset, n_rounds=PHILOX_N_ROUNDS_DEFAULT):
+    i1, i2, i3, i4 = randint4x(seed, offset, n_rounds)
+    u1 = _uint_to_uniform_float(i1)
+    u2 = _uint_to_uniform_float(i2)
+    u3 = _uint_to_uniform_float(i3)
+    u4 = _uint_to_uniform_float(i4)
+    return u1, u2, u3, u4
+
+
+def randint(seed, offset, n_rounds=PHILOX_N_ROUNDS_DEFAULT):
+    ret, _, _, _ = randint4x(seed, offset, n_rounds)
+    return ret
+
+
+def rand(seed, offset, n_rounds=PHILOX_N_ROUNDS_DEFAULT):
+    source = randint(seed, offset, n_rounds)
+    return _uint_to_uniform_float(source)
+
+
+def randn(seed, offset):
+    i1, i2, _, _ = randint4x(seed, offset, PHILOX_N_ROUNDS_DEFAULT)
+    u1 = _uint_to_uniform_float(i1)
+    u2 = _uint_to_uniform_float(i2)
+    n1, _ = _pair_uniform_to_normal(u1, u2)
+    return n1
+
+
+def randint64(seed, offset, low, high):
+    r0, r1, _r2, _r3 = randint4x(seed, offset, PHILOX_N_ROUNDS_DEFAULT)
+    r0 = hl.cast(hl.UInt(64), r0)
+    r1 = hl.cast(hl.UInt(64), r1)
+
+    result = r0 | (r1 << 32)
+    size = high - low
+    result = result % hl.cast(hl.UInt(64), size)
+    result = hl.cast(hl.Int(64), result) + low
+    return result
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/hints.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/hints.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9ddf91e9a59cf805907e7ee4accecdd4c214a37
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/hints.py
@@ -0,0 +1,224 @@
+# mypy: allow-untyped-defs
+from __future__ import annotations
+
+import collections
+import functools
+import typing
+from enum import auto, Enum
+
+import torch
+from torch.utils._triton import has_triton_package
+
+
+# The following maximums only apply to runtime autotuning, when using FixedTritonConfig one may see larger values
+# NOTE: if these fail asserts submit a PR to increase them
+TRITON_MAX_BLOCK = {
+    "X": 8192 if torch.version.hip else 4096,
+    "Y": 1024,
+    "Z": 1024,
+    "R0_": 4096 * 16,  # * 16 is multi-kernel only
+    "R1_": 2048 * 16,  # * 16 is multi-kernel only
+}
+TRITON_MAX_RSPLIT = 64
+
+
+class ReductionHint(Enum):
+    INNER = 0
+    OUTER = 1
+    OUTER_TINY = 2
+    DEFAULT = 3
+
+
+class TileHint(Enum):
+    SQUARE = 0
+    DEFAULT = 1
+
+
+# Define `AttrsDescriptorWrapper` function with clear conditional handling
+if has_triton_package():
+    import triton
+    import triton.backends.compiler
+    import triton.compiler.compiler
+
+    if hasattr(triton.backends.compiler, "AttrsDescriptor"):
+        # Triton 3.2.0 - the second implementation
+        from triton.backends.compiler import AttrsDescriptor
+
+        def AttrsDescriptorWrapper(
+            divisible_by_16=None,
+            equal_to_1=None,
+        ):
+            # Prepare the arguments for AttrsDescriptor
+            kwargs = {
+                "tt.divisibility": divisible_by_16,
+                "tt.equal_to": equal_to_1,
+            }
+
+            # Instantiate AttrsDescriptor with the prepared arguments
+            res = AttrsDescriptor.from_dict(
+                {"arg_properties": kwargs, "cls": AttrsDescriptor.__name__}
+            )
+            assert res.property_values["tt.divisibility"] == 16
+            assert res.property_values["tt.equal_to"] == 1
+            return res
+
+    elif hasattr(triton.compiler.compiler, "AttrsDescriptor"):
+        # Triton 3.0.0 - the original implementation
+        from triton.compiler.compiler import AttrsDescriptor
+
+        def AttrsDescriptorWrapper(
+            divisible_by_16=None,
+            equal_to_1=None,
+        ):
+            # Prepare the arguments for AttrsDescriptor
+            kwargs = {
+                "divisible_by_16": divisible_by_16,
+                "equal_to_1": equal_to_1,
+            }
+
+            # Instantiate AttrsDescriptor with the prepared arguments
+            return AttrsDescriptor(**kwargs)
+
+    else:
+        # Triton in 2025:
+        # note: there's also a range of triton commits not currently supported
+        # from ~Dec 9, 2024 to Jan 1 2025, in which AttrsDescriptors are still
+        # used, but the contents are different.
+
+        def AttrsDescriptorWrapper(
+            divisible_by_16=None,
+            equal_to_1=None,
+        ):
+            # pyrefly: ignore [not-iterable]
+            return {(x,): [["tt.divisibility", 16]] for x in divisible_by_16}
+
+else:
+    # Define a namedtuple as a fallback when AttrsDescriptor is not available
+    AttrsDescriptorWrapper = collections.namedtuple(  # type: ignore[no-redef, name-match]
+        # pyrefly: ignore [invalid-argument]
+        "AttrsDescriptor",
+        ["divisible_by_16", "equal_to_1"],
+        defaults=[(), ()],
+    )
+
+
+_NUM_THREADS_PER_WARP = 32
+
+
+class HeuristicType(Enum):
+    PERSISTENT_REDUCTION = auto()
+    POINTWISE = auto()
+    REDUCTION = auto()
+    SPLIT_SCAN = auto()
+    TEMPLATE = auto()
+    USER_AUTOTUNE = auto()
+    FIXED = auto()
+
+
+class AutotuneHint(Enum):
+    ONE_ELEMENT_PER_THREAD = 0
+
+    # Triton codegen tries to codegen set of AutotuneHints.
+    # Enum.__repr__ looks like """
+    # which isn't valid python.
+    # Enum.__str__ will just return "AutotuneHint.ELEMENTS_PER_WARP_32".
+    __repr__ = Enum.__str__
+
+
+class DeviceProperties(typing.NamedTuple):
+    """Copy device properties into a data structure not requiring torch to be imported"""
+
+    type: str  # type: ignore[assignment]
+    index: int  # type: ignore[assignment]
+    multi_processor_count: int
+    cc: int
+    major: int | None = None
+    regs_per_multiprocessor: int | None = None
+    max_threads_per_multi_processor: int | None = None
+    max_threads_per_block: int | None = None
+    warp_size: int | None = None
+
+    @classmethod
+    @functools.cache
+    def create(cls, device) -> DeviceProperties:
+        import torch
+        from torch._dynamo.device_interface import get_interface_for_device
+
+        device_type = device.type
+
+        if torch.version.hip and device_type == "cuda":
+            device_type = "hip"
+
+        device_interface = get_interface_for_device(device)
+        props = device_interface.get_device_properties(device)
+        try:
+            multi_processor_count = props.multi_processor_count
+        except AttributeError:
+            if device_type == "xpu":
+                multi_processor_count = props.gpu_subslice_count
+            elif device_type == "mtia":
+                multi_processor_count = 64
+            else:
+                raise
+        return cls(
+            type=device_type,
+            index=device.index,
+            multi_processor_count=multi_processor_count,
+            cc=device_interface.get_compute_capability(device),
+            major=getattr(props, "major", None),
+            regs_per_multiprocessor=getattr(props, "regs_per_multiprocessor", None),
+            max_threads_per_multi_processor=getattr(
+                props, "max_threads_per_multi_processor", None
+            ),
+            max_threads_per_block=getattr(props, "max_threads_per_block", 1024),
+            warp_size=getattr(props, "warp_size", 32 if device_type != "cpu" else None),
+        )
+
+
+class HalideInputSpec(typing.NamedTuple):
+    ctype: str
+    name: str
+    shape: list[str] | None = None
+    stride: list[str] | None = None
+    offset: str | None = None
+    alias_of: str | None = None
+
+    def bindings_type(self) -> str:
+        if self.ctype in ("at::Half*", "at::BFloat16*"):
+            return "uint16_t*"  # half not defined
+        return self.ctype
+
+    def halide_type(self) -> str:
+        if self.ctype == "at::Half*":
+            return "halide_type_t(halide_type_float, 16)"  # half not defined
+        if self.ctype == "at::BFloat16*":
+            return "halide_type_t(halide_type_bfloat, 16)"  # half not defined
+        return f"halide_type_of<{self.ctype.replace('*', '')}>()"
+
+    def is_scalar(self) -> bool:
+        return self.shape is None
+
+    def is_buffer(self) -> bool:
+        return self.shape is not None
+
+
+class HalideMeta(typing.NamedTuple):
+    argtypes: list[HalideInputSpec]
+    target: str
+    scheduler: str | None = None
+    scheduler_flags: dict[str, int | str] | None = None
+    cuda_device: int | None = None
+
+    def args(self) -> list[str]:
+        """Command line args to pass to halide generator"""
+        args = [f"target={self.target}"]
+        if self.scheduler:
+            args.append(f"autoscheduler={self.scheduler}")
+        if self.scheduler_flags:
+            assert self.scheduler
+            for k, v in self.scheduler_flags.items():
+                args.append(f"autoscheduler.{k}={v}")
+        return args
+
+    def is_cuda(self) -> bool:
+        return self.cuda_device is not None
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/runtime_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/runtime_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4e66378e85aec07c60e68e48d441178db423dc2
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/runtime_utils.py
@@ -0,0 +1,249 @@
+from __future__ import annotations
+
+import functools
+import operator
+from typing import Any, TYPE_CHECKING
+
+import torch
+
+# NOTE: other files rely on the imports below
+from torch._dynamo import callback as compilation_callback  # noqa: F401
+from torch._inductor.runtime.cache_dir_utils import (  # noqa: F401
+    cache_dir,
+    default_cache_dir,
+    triton_cache_dir,
+)
+
+
+if TYPE_CHECKING:
+    from collections.abc import Hashable
+
+    from .triton_compat import Config
+
+
+def conditional_product(*args: int) -> int:
+    return functools.reduce(operator.mul, [x for x in args if x])
+
+
+def ceildiv(number: int, denom: int) -> int:
+    return -(number // -denom)
+
+
+def is_power_of_2(n: int) -> bool:
+    """Returns whether n = 2 ** m for some integer m."""
+    return n > 0 and n & n - 1 == 0
+
+
+def next_power_of_2(n: int) -> int:
+    """Return the smallest power of 2 greater than or equal to n"""
+    n -= 1
+    n |= n >> 1
+    n |= n >> 2
+    n |= n >> 4
+    n |= n >> 8
+    n |= n >> 16
+    n |= n >> 32
+    n += 1
+    return n
+
+
+def last_power_of_2(n: int) -> int:
+    """Return the largest power of 2 less than or equal to n"""
+    next_pow2 = next_power_of_2(n)
+    return next_pow2 // 2 if next_pow2 > n else next_pow2
+
+
+def get_num_bytes(*args: torch.Tensor, num_in_out_args: int = 0) -> int:
+    """
+    Return the total number of bytes the arguments of tensor type takes.
+
+    For in/out args, tensor sizes are counted twice: once for reading and
+    once for writing.
+
+    The first num_in_out_args arguments are in out tensors.
+    """
+    return sum(
+        arg.numel() * arg.element_size() * (1 + int(i < num_in_out_args))
+        for i, arg in enumerate(args)
+        if isinstance(arg, torch.Tensor)
+    )
+
+
+def triton_config_to_hashable(cfg: Config) -> Hashable:
+    """
+    Convert triton config to a tuple that can uniquely identify it. We can use
+    the return value as a dictionary key.
+    """
+    # pyrefly: ignore [missing-attribute]
+    items = sorted(cfg.kwargs.items())
+    # pyrefly: ignore [missing-attribute]
+    items.append(("num_warps", cfg.num_warps))
+    # pyrefly: ignore [missing-attribute]
+    items.append(("num_stages", cfg.num_stages))
+    return tuple(items)
+
+
+def validate_triton_config(cfg: Config) -> None:
+    # [Note: Triton pre_hook in inductor]
+    # pre-hook is a lambda function, which we don't attempt to serialize.
+    # right now, if a pre-hook is attached to the config, it will not be saved;
+    # and then it won't be used when the config is loaded from cache.
+    # So we assert - if we do get a pre_hook, it might get ignored after caching.
+    assert getattr(cfg, "pre_hook", None) is None, (
+        "triton configs with pre_hooks not supported"
+    )
+
+
+def create_bandwidth_info_str(
+    ms: float,
+    num_gb: float,
+    gb_per_s: float,
+    prefix: str = "",
+    suffix: str = "",
+    color: bool = True,
+) -> str:
+    info_str = f"{prefix}{ms:.3f}ms    \t{num_gb:.3f} GB \t {gb_per_s:7.2f}GB/s{suffix}"
+    slow = ms > 0.012 and gb_per_s < 650
+    return red_text(info_str) if color and slow else info_str
+
+
+def get_max_y_grid() -> int:
+    return 65535
+
+
+try:
+    # pyrefly: ignore [import-error]
+    import colorama
+
+    HAS_COLORAMA = True
+except ModuleNotFoundError:
+    HAS_COLORAMA = False
+    colorama = None  # type: ignore[assignment]
+
+
+if HAS_COLORAMA:
+
+    def _color_text(msg: str, color: str) -> str:
+        # pyrefly: ignore [missing-attribute]
+        return getattr(colorama.Fore, color.upper()) + msg + colorama.Fore.RESET
+
+else:
+
+    def _color_text(msg: str, color: str) -> str:
+        return msg
+
+
+def green_text(msg: str) -> str:
+    return _color_text(msg, "green")
+
+
+def yellow_text(msg: str) -> str:
+    return _color_text(msg, "yellow")
+
+
+def red_text(msg: str) -> str:
+    return _color_text(msg, "red")
+
+
+def blue_text(msg: str) -> str:
+    return _color_text(msg, "blue")
+
+
+def get_first_attr(obj: Any, *attrs: str) -> Any:
+    """
+    Return the first available attribute or throw an exception if none is present.
+    """
+    for attr in attrs:
+        if hasattr(obj, attr):
+            return getattr(obj, attr)
+
+    raise AssertionError(f"{obj} does not has any of the attributes: {attrs}")
+
+
+dynamo_timed = torch._dynamo.utils.dynamo_timed  # type: ignore[has-type]
+
+
+def triton_hash_to_path_key(key: str) -> str:
+    # In early versions of Triton, the hash is directly used in the path name.
+    # Later, the hash is converted to base64 before being used in the path name.
+    # Later, the base64 conversion was replaced to the base32
+    #
+    # This code tries to import _base64 and falls back to _base32 if _base64 is unavailable.
+    #
+    # To handle this, try to import the to-base64-conversion function.
+    # If it exists, use it; otherwise, try using _base32; if both are unavailable, use the hash directly.
+    try:
+        from triton.runtime.cache import _base64
+
+        return _base64(key)
+    except Exception:
+        try:
+            from triton.runtime.cache import _base32
+
+            return _base32(key)
+        except Exception:
+            return key
+
+
+def compile_mps_shader(source: str) -> Any:
+    """
+    Compiles shader source but raise more actionable error message when needed
+    """
+    try:
+        return torch.mps.compile_shader(source)
+    except SyntaxError as err:
+        raise SyntaxError(f"failed to compile {source} with {err.msg}") from err
+
+
+def torch_dtype_to_jax_runtime(dtype: torch.dtype) -> Any:
+    """
+    Map PyTorch dtype to actual JAX dtype object at runtime.
+
+    This helper is used in generated Pallas kernels at runtime to convert
+    PyTorch dtypes to JAX dtype objects (not string representations).
+
+    Args:
+        dtype: PyTorch dtype to convert
+
+    Returns:
+        JAX dtype object (e.g., jnp.float32 object itself)
+    """
+    import jax.numpy as jnp  # pyrefly: ignore [import-error]
+
+    dtype_map = {
+        torch.float32: jnp.float32,
+        torch.float64: jnp.float64,
+        torch.float16: jnp.float16,
+        torch.bfloat16: jnp.bfloat16,
+        torch.int32: jnp.int32,
+        torch.int64: jnp.int64,
+        torch.int16: jnp.int16,
+        torch.int8: jnp.int8,
+        torch.uint8: jnp.uint8,
+        torch.bool: jnp.bool_,
+        torch.complex64: jnp.complex64,
+        torch.complex128: jnp.complex128,
+    }
+    if dtype not in dtype_map:
+        raise ValueError(f"Unsupported dtype for JAX conversion: {dtype}")
+    return dtype_map[dtype]
+
+
+def torch_dtype_to_jax(dtype: torch.dtype) -> str:
+    """
+    Map PyTorch dtype to JAX dtype expression string.
+
+    This helper is used at compile time in codegen to generate
+    JAX dtype expressions for Pallas kernels.
+
+    Args:
+        dtype: PyTorch dtype to convert
+
+    Returns:
+        JAX dtype expression as string (e.g., "jnp.float32")
+    """
+    jax_dtype = torch_dtype_to_jax_runtime(dtype)
+    dtype_name = jax_dtype.__name__
+    if dtype_name == "bool":
+        dtype_name = "bool_"
+    return f"jnp.{dtype_name}"
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/static_cuda_launcher.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/static_cuda_launcher.py
new file mode 100644
index 0000000000000000000000000000000000000000..f48f351ce823a325a2f15092ad964aeba09aaf82
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/static_cuda_launcher.py
@@ -0,0 +1,270 @@
+import functools
+import os
+from typing import Any
+from typing_extensions import Unpack
+
+from .triton_compat import ASTSource, CompiledKernel, knobs as triton_knobs
+from .triton_helpers import get_constexprs
+
+
+class StaticallyLaunchedCudaKernel:
+    """
+    Parses the metadata of a CompiledKernel from Triton into a structure that can
+    launch the cuda kernel directly. Only works for triton kernels compiled to cubin.
+
+    Doing this avoids C++ codegen and compilation during compile, since we can use a
+    statically compiled library to launch the kernel. To avoid mallocing for the arguments,
+    we have a launcher for different numbers of arguments up to a max. StaticCudaLauncher
+    only supports # of arguments up until 10 for now.
+
+    Workflow:
+    Compile time:
+    1. Compile a kernel with triton and get a CompiledKernel
+    2. Instantiate kernel = StaticallyLaunchedCudaKernel(triton_kernel)
+    3. Write to a cubin file: kernel.write_cubin_to_file(filepath)
+    4. Call kernel.load_kernel() (CUDA should be initialized by this point) to load the cubin
+    Runtime:
+    5. Call kernel.run(grid, stream, args) to launch the kernel
+
+    Note that after step 3, StaticallyLaunchedCudaKernel is fully pickleable/serializable.
+    This allows it to be cached by FXGraphCache/TritonBundler, as well as sent from the worker
+    to the parent process in inductor.
+
+    There are two main versions of triton that we wish to support: 3.3 and 3.2. Triton makes considerable changes
+    to how it handles constants in 3.3, so there's some special logic necessary to handle both versions.
+    """
+
+    def __init__(self, kernel: CompiledKernel) -> None:
+        # pyrefly: ignore [missing-attribute]
+        self.name = kernel.src.fn.__name__
+        # pyrefly: ignore [missing-attribute]
+        self.cubin_raw = kernel.asm.get("cubin", None)
+        # pyrefly: ignore [missing-attribute]
+        self.cubin_path = kernel._cubin_path
+
+        # Used by torch.compile to filter constants in older triton versions
+        # pyrefly: ignore [missing-attribute]
+        self.arg_names = kernel.src.fn.arg_names
+
+        # Const exprs that are declared by the triton kernel directly
+        # Used to generate the kernel launcher's def args
+        # pyrefly: ignore [missing-attribute]
+        self.declared_constexprs = get_constexprs(kernel.src.fn)
+
+        # pyrefly: ignore [missing-attribute]
+        self.hash = kernel.hash
+
+        if triton_knobs is None:
+            # pyrefly: ignore [missing-attribute]
+            launch_enter = kernel.__class__.launch_enter_hook
+            # pyrefly: ignore [missing-attribute]
+            launch_exit = kernel.__class__.launch_exit_hook
+        else:
+            launch_enter = triton_knobs.runtime.launch_enter_hook
+            launch_exit = triton_knobs.runtime.launch_exit_hook
+
+        def hook_is_empty(hook: Any) -> bool:
+            if hook is None:
+                return True
+            if (
+                triton_knobs
+                and (HookChain := getattr(triton_knobs, "HookChain", None)) is not None
+                and isinstance(hook, HookChain)
+            ):
+                # Support hooks after https://github.com/triton-lang/triton/pull/7866
+                return len(hook.calls) == 0
+            return False
+
+        if not hook_is_empty(launch_enter) or not hook_is_empty(launch_exit):
+            raise NotImplementedError(
+                "We don't support launch enter or launch exit hooks"
+            )
+        # pyrefly: ignore [missing-attribute]
+        self.num_warps = kernel.metadata.num_warps
+        self.shared = (
+            # pyrefly: ignore [missing-attribute]
+            kernel.shared if hasattr(kernel, "shared") else kernel.metadata.shared
+        )
+
+        def needs_scratch_arg(scratch_name: str, param_name: str) -> bool:
+            # pyrefly: ignore [missing-attribute]
+            if hasattr(kernel.metadata, param_name):
+                if getattr(kernel.metadata, param_name) > 0:
+                    raise NotImplementedError(
+                        f"{scratch_name} scratch not yet supported"
+                    )
+                return True
+            return False
+
+        # Newer triton versions pass an extra global scratch parameter to the compiled cuda kernel.
+        # Inductor never uses this field or enables it, but we still have to pass
+        # an extra None into the set of params if its enabled
+        self.has_global_scratch = needs_scratch_arg("Global", "global_scratch_size")
+        # same situation for profile scratch - triton-lang/triton#7258
+        self.has_profile_scratch = needs_scratch_arg("Profile", "profile_scratch_size")
+
+        # pyrefly: ignore [missing-attribute]
+        self.arg_tys = self.arg_ty_from_signature(kernel.src)
+        self.function: int | None = None  # Loaded by load_kernel(on the parent process)
+        num_ctas = 1
+        if hasattr(kernel, "num_ctas"):
+            num_ctas = kernel.num_ctas
+        elif hasattr(kernel, "metadata"):
+            num_ctas = kernel.metadata.num_ctas
+
+        if num_ctas != 1:
+            raise NotImplementedError(
+                "Static cuda launcher only supports num_ctas == 1"
+            )
+
+    def reload_cubin_from_raw(self, filepath: str) -> str:
+        """
+        If the cubin file triton generated gets deleted under us, we can
+        reload it from the raw cubin file.
+        """
+        if self.cubin_path is None:
+            assert self.cubin_raw is not None
+            os.makedirs(os.path.dirname(filepath), exist_ok=True)
+            with open(filepath, "wb") as f:
+                f.write(self.cubin_raw)
+                self.cubin_path = filepath
+        return self.cubin_path
+
+    def load_kernel(self, device: int) -> None:
+        from torch._C import _StaticCudaLauncher
+
+        if self.function is not None:
+            return
+
+        assert hasattr(self, "cubin_path")
+        assert self.cubin_path is not None
+        (self.function, self.n_regs, self.n_spills) = _StaticCudaLauncher._load_kernel(
+            self.cubin_path, self.name, self.shared, device
+        )
+        # Don't need the cubin path anymore now that we've loaded
+        self.cubin_path = None
+        self.cubin_raw = None
+
+    @staticmethod
+    @functools.lru_cache
+    def type_mappings() -> dict[str, str]:
+        return {
+            "i1": "i",
+            "i8": "b",
+            "i16": "h",
+            "i32": "i",
+            "i64": "l",
+            "u1": "I",
+            "u8": "B",
+            "u16": "H",
+            "u32": "I",
+            "u64": "K",
+            "fp16": "f",
+            "bf16": "f",
+            "fp32": "f",
+            "f32": "f",
+            "fp64": "d",
+            # TODO handle nvTmaDesc/CUtensormap
+        }
+
+    def extract_type(self, ty: str) -> str:
+        """
+        Takes a triton type from CompiledKernel.signature and
+        converts it into a single char encoding. _StaticCudaLauncher
+        will switch on this char to figure out what type the underlying
+        value should be passed to the triton kernel as.
+        """
+        if ty[0] == "*":
+            return "O"
+        elif ty == "nvTmaDesc":
+            raise NotImplementedError("nvTmaDesc kernels are not yet supported")
+        return StaticallyLaunchedCudaKernel.type_mappings()[ty]
+
+    def arg_ty_from_signature(self, src: ASTSource) -> str:
+        def index_key(i: Any) -> int:
+            if isinstance(i, str):
+                # pyrefly: ignore [missing-attribute]
+                return src.fn.arg_names.index(i)
+            elif isinstance(i, tuple):
+                # In triton 3.3, src.fn.constants has tuples as a key
+                return i[0]
+            else:
+                return i
+
+        # pyrefly: ignore [missing-attribute]
+        signature = {index_key(key): value for key, value in src.signature.items()}
+        # Triton uses these as the main way to filter out constants passed to their cubin
+        constants = [index_key(key) for key in getattr(src, "constants", dict())]
+        # This value is always a superset of kernel.fn.constexprs: kernel.fn.constexprs are
+        # constants declared by the triton kernel directly, whereas this list can have
+        # constants that are unused by the triton kernel that triton figured out during
+        # compilation.
+        self.full_constexprs = constants
+        # Despite requiring them to be passed in, the triton CUDA launcher
+        # completely ignores the constexprs passed into it when generating code.
+        # So we can ignore them here too
+        params = []
+
+        for i in sorted(signature.keys()):
+            ty = signature[i]
+            # In newer triton versions, constants are passed in to signature with type `constexpr`
+            # In older triton versions, there can be constants in src.constants that are not `constexpr` in signature
+            # so we check both here
+            if ty == "constexpr" or i in constants:
+                pass
+            else:
+                # pyrefly: ignore [bad-argument-type]
+                params.append(self.extract_type(ty))
+        return "".join(params)
+
+    def __getstate__(self) -> dict[str, Any]:
+        # Remove objects that are no longer valid for pickling
+        state = self.__dict__.copy()
+        state["function"] = None
+        # Cubin paths aren't consistent across processes, so we clear
+        # and reload them.
+        state["cubin_path"] = None
+        return state
+
+    def run(
+        self,
+        grid_x: int,
+        grid_y: int,
+        grid_z: int,
+        stream: int,
+        *args: Unpack[tuple[object, ...]],
+    ) -> None:
+        """Actually run the kernel at runtime. This function is the hot codepath."""
+        from torch._C import _StaticCudaLauncher
+
+        # Assert load_kernel() has been called and args match
+        assert self.function is not None
+
+        # TODO: actually, if the args *don't* match, we probably should
+        # throw an exception. But if inductor is the only one calling this
+        # thing, it should always match.
+        # Get rid of constants before passing to cubin launcher
+
+        # Add a None if triton wants extra parameters for scratch spaces
+        arg_tys = self.arg_tys
+        for has_scratch in [self.has_global_scratch, self.has_profile_scratch]:
+            if has_scratch:
+                arg_tys = arg_tys + "O"
+                args = (*args, None)
+        # pyrefly: ignore [bad-argument-type]
+        assert len(args) == len(arg_tys)
+
+        # TODO: can handle grid functions here or in C++, so
+        # that we don't need the grid handler above.
+        _StaticCudaLauncher._launch_kernel(
+            self.function,
+            grid_x,
+            grid_y,
+            grid_z,
+            self.num_warps,
+            self.shared,
+            arg_tys,
+            # pyrefly: ignore [bad-argument-type]
+            args,
+            stream,
+        )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/triton_compat.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/triton_compat.py
new file mode 100644
index 0000000000000000000000000000000000000000..49ceacb50bc3d9f4b6c5c9451d6b810d7898bf20
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/triton_compat.py
@@ -0,0 +1,176 @@
+from __future__ import annotations
+
+import inspect
+from typing import Any
+
+import torch
+
+
+try:
+    import triton
+except ImportError:
+    triton = None
+
+
+if triton is not None:
+    import triton.language as tl
+    from triton import Config
+    from triton.compiler import CompiledKernel
+    from triton.runtime.autotuner import OutOfResources
+    from triton.runtime.jit import JITFunction, KernelInterface
+
+    try:
+        from triton.runtime.autotuner import PTXASError
+    except ImportError:
+
+        class PTXASError(Exception):  # type: ignore[no-redef]
+            pass
+
+    try:
+        from triton.compiler.compiler import ASTSource
+    except ImportError:
+        ASTSource = None
+
+    try:
+        from triton.backends.compiler import GPUTarget
+    except ImportError:
+
+        def GPUTarget(
+            backend: str,
+            arch: int | str,
+            warp_size: int,
+        ) -> Any:
+            if torch.version.hip:
+                return [backend, arch, warp_size]
+            return (backend, arch)
+
+    # In the latest triton, math functions were shuffled around into different modules:
+    # https://github.com/triton-lang/triton/pull/3172
+    try:
+        from triton.language.extra import libdevice
+
+        libdevice = tl.extra.libdevice  # noqa: F811
+        math = tl.math
+    except ImportError:
+        if hasattr(tl.extra, "cuda") and hasattr(tl.extra.cuda, "libdevice"):
+            libdevice = tl.extra.cuda.libdevice
+            math = tl.math
+        elif hasattr(tl.extra, "intel") and hasattr(tl.extra.intel, "libdevice"):
+            libdevice = tl.extra.intel.libdevice
+            math = tl.math
+        else:
+            libdevice = tl.math
+            math = tl
+
+    try:
+        from triton.language.standard import _log2
+    except ImportError:
+
+        def _log2(x: Any) -> Any:
+            raise NotImplementedError
+
+    def _triton_config_has(param_name: str) -> bool:
+        if not hasattr(triton, "Config"):
+            return False
+        if not hasattr(triton.Config, "__init__"):
+            return False
+        return param_name in inspect.signature(triton.Config.__init__).parameters
+
+    # Drop the legacy support of autoWS
+    HAS_WARP_SPEC = False
+
+    try:
+        from triton import knobs
+    except ImportError:
+        knobs = None
+
+    try:
+        from triton.runtime.cache import triton_key  # type: ignore[attr-defined]
+    except ImportError:
+        from triton.compiler.compiler import (
+            triton_key,  # type: ignore[attr-defined,no-redef]
+        )
+
+    builtins_use_semantic_kwarg = (
+        "_semantic" in inspect.signature(triton.language.core.view).parameters
+    )
+    HAS_TRITON = True
+else:
+
+    def _raise_error(*args: Any, **kwargs: Any) -> Any:
+        raise RuntimeError("triton package is not installed")
+
+    class OutOfResources(Exception):  # type: ignore[no-redef]
+        pass
+
+    class PTXASError(Exception):  # type: ignore[no-redef]
+        pass
+
+    Config = object
+    CompiledKernel = object
+    KernelInterface = object
+    ASTSource = None
+    GPUTarget = None
+    _log2 = _raise_error
+    libdevice = None
+    math = None
+    knobs = None
+    builtins_use_semantic_kwarg = False
+
+    class triton:  # type: ignore[no-redef]
+        @staticmethod
+        def jit(*args: Any, **kwargs: Any) -> Any:
+            return _raise_error
+
+    class tl:  # type: ignore[no-redef]
+        @staticmethod
+        def constexpr(val: Any) -> Any:
+            return val
+
+        tensor = Any
+        dtype = Any
+
+    class JITFunction:  # type: ignore[no-redef]
+        pass
+
+    HAS_WARP_SPEC = False
+    triton_key = _raise_error
+    HAS_TRITON = False
+
+
+def cc_warp_size(cc: str | int) -> int:
+    if torch.version.hip:
+        cc_str = str(cc)
+        if "gfx10" in cc_str or "gfx11" in cc_str:
+            return 32
+        else:
+            return 64
+    else:
+        return 32
+
+
+try:
+    autograd_profiler = torch.autograd.profiler
+except AttributeError:  # Compile workers only have a mock version of torch
+
+    class autograd_profiler:  # type: ignore[no-redef]
+        _is_profiler_enabled = False
+
+
+__all__ = [
+    "Config",
+    "CompiledKernel",
+    "OutOfResources",
+    "KernelInterface",
+    "PTXASError",
+    "ASTSource",
+    "GPUTarget",
+    "tl",
+    "_log2",
+    "libdevice",
+    "math",
+    "triton",
+    "cc_warp_size",
+    "knobs",
+    "triton_key",
+]
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/triton_helpers.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/triton_helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e89868e216a5156c08a6c4922f16e1897eedeeb
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/triton_helpers.py
@@ -0,0 +1,761 @@
+# mypy: allow-untyped-decorators
+# mypy: allow-untyped-defs
+import math as pymath
+import warnings
+from collections.abc import Callable
+from typing import Any, TypeVar
+
+from .triton_compat import (  # noqa: F401
+    _log2,
+    builtins_use_semantic_kwarg,
+    JITFunction,
+    libdevice,
+    math,
+    tl,
+    triton,
+)
+
+
+_T = TypeVar("_T")
+_LOG_2_E: tl.constexpr = tl.constexpr(pymath.log2(pymath.e))
+
+
+def set_driver_to_cpu():
+    driver = triton.runtime.driver
+    if backend := triton.backends.backends.get("cpu", None):
+        if isinstance(driver.active, backend.driver):
+            # Don't re-initialize backend if it is already active
+            return
+        driver.set_active(backend.driver())
+        return
+    # This can be a hard error once triton-cpu is merged into fbcode
+    warnings.warn(
+        "Could not find an active CPU backend. Generated kernels will not be executable!"
+    )
+
+
+def set_driver_to_gpu():
+    driver = triton.runtime.driver
+    for name, backend in triton.backends.backends.items():
+        if backend.driver.is_active() and name != "cpu":
+            # After https://github.com/triton-lang/triton/commit/b844d519bc5e86edf00fe6b3c6c2d1badcd509a4,
+            # `driver.active` can be of `LazyProxy` type and the sign of this - `_obj` attribute.
+            if (
+                isinstance(driver.active, backend.driver)
+                or hasattr(driver.active, "_obj")
+                and isinstance(driver.active._obj, backend.driver)
+            ):
+                # Don't re-initialize backend if it is already active
+                return
+            driver.set_active(backend.driver())
+            return
+    raise RuntimeError("Could not find an active GPU backend")
+
+
+def get_backend_options():
+    from triton.runtime import driver
+
+    target = driver.active.get_current_target()
+    backend = triton.compiler.compiler.make_backend(target)
+    options = backend.parse_options(dict())
+    return options.__dict__
+
+
+def get_constexprs(kernel: JITFunction) -> list[int]:
+    return [p.num for p in kernel.params if p.is_constexpr]
+
+
+@triton.jit
+def promote_to_tensor(x):
+    # Addition promotes to tensor for us
+    return x + tl.zeros((1,), tl.int1)
+
+
+@triton.jit
+def div_floor_integer(a, b):
+    # NOTE: a // b is C division, but we want floor division
+    # Based on c10::div_floor_integer
+    quot = a // b
+    remainder = a % b
+    fixed = tl.where(remainder != 0, quot - 1, quot)
+    return tl.where((a < 0) != (b < 0), fixed, quot)
+
+
+@triton.jit
+def remainder_integer(a, b):
+    # NOTE: a % b matches C division, not floor division
+    remainder = a % b
+    return tl.where((remainder != 0) & ((a < 0) != (b < 0)), remainder + b, remainder)
+
+
+@triton.jit
+def is_floating(x):
+    return promote_to_tensor(x).dtype.is_floating()
+
+
+@triton.jit
+def _prod_accumulate(a, b):
+    return a * b
+
+
+@triton.jit
+def prod(input, axis):
+    return tl.reduce(input, axis, _prod_accumulate)
+
+
+@triton.jit
+def minimum(a, b):
+    mask = a < b
+    if is_floating(a):
+        mask |= a != a
+    return tl.where(mask, a, b)
+
+
+@triton.jit
+def maximum(a, b):
+    mask = a > b
+    if is_floating(a):
+        mask |= a != a
+    return tl.where(mask, a, b)
+
+
+@triton.jit
+def min2(a, dim):
+    return tl.reduce(a, dim, minimum)
+
+
+@triton.jit
+def max2(a, dim):
+    return tl.reduce(a, dim, maximum)
+
+
+@triton.jit
+def minimum_with_index(a_value, a_index, b_value, b_index):
+    mask = a_value < b_value
+    equal = a_value == b_value
+    if is_floating(a_value):
+        a_isnan = a_value != a_value
+        b_isnan = b_value != b_value
+        mask |= a_isnan & (not b_isnan)
+        # Consider NaNs as equal
+        equal |= a_isnan & b_isnan
+
+    # Prefer lowest index if values are equal
+    mask |= equal & (a_index < b_index)
+    return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)
+
+
+@triton.jit
+def maximum_with_index(a_value, a_index, b_value, b_index):
+    mask = a_value > b_value
+    equal = a_value == b_value
+    if is_floating(a_value):
+        a_isnan = a_value != a_value
+        b_isnan = b_value != b_value
+        mask |= a_isnan & (not b_isnan)
+        # Consider NaNs as equal
+        equal |= a_isnan & b_isnan
+
+    # Prefer lowest index if values are equal
+    mask |= equal & (a_index < b_index)
+    return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)
+
+
+@triton.jit
+def min_with_index(value, index, dim):
+    return tl.reduce((value, index), dim, minimum_with_index)
+
+
+@triton.jit
+def max_with_index(value, index, dim):
+    return tl.reduce((value, index), dim, maximum_with_index)
+
+
+@triton.jit
+def exp(x, use_fast_math: tl.constexpr):
+    if use_fast_math:
+        return math.exp(x)
+    else:
+        return libdevice.exp(x)
+
+
+@triton.jit
+def online_softmax_reduce(lhs_max, lhs_sum, dim, use_fast_math: tl.constexpr):
+    out_max = max2(lhs_max, dim)
+    out_max_keepdim = tl.expand_dims(out_max, dim)
+    delta = tl.where(out_max_keepdim == float("-inf"), 0, lhs_max - out_max_keepdim)
+    out_sum = tl.sum(lhs_sum * exp(delta, use_fast_math), dim)
+    return out_max, out_sum
+
+
+@triton.jit
+def online_softmax_combine(lhs_max, lhs_sum, rhs_max, use_fast_math: tl.constexpr):
+    """
+    When we do combine, we assume lhs is the accumulator and rhs is the next
+    block of data.
+    Then rhs_sum is always 1. With that assumption, we can save some registers
+    and computation.
+    """
+    out_max = maximum(lhs_max, rhs_max)
+
+    lhs_scale = tl.where(
+        out_max == float("-inf"), 1.0, exp(lhs_max - out_max, use_fast_math)
+    )
+    rhs_scale = tl.where(
+        out_max == float("-inf"), 1.0, exp(rhs_max - out_max, use_fast_math)
+    )
+
+    # Should be
+    #   out_sum = lhs_sum * lhs_scale + rhs_sum * rhs_scale
+    # but since rhs_sum is all 1, we can simplify it.
+    out_sum = lhs_sum * lhs_scale + rhs_scale
+    return out_max, out_sum
+
+
+@triton.jit
+def welford_reduce(value, mean, m2, weight, first_iteration):
+    if first_iteration:
+        new_weight = tl.full(weight.shape, 1, weight.dtype)
+        new_mean = value
+        new_m2 = tl.zeros_like(m2)
+    else:
+        delta = value - mean
+        new_weight = weight + 1
+        new_mean = mean + delta / new_weight
+        new_m2 = m2 + delta * (value - new_mean)
+    return new_mean, new_m2, new_weight
+
+
+@triton.jit
+def welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
+    delta = mean_2 - mean_1
+    new_weight = weight_1 + weight_2
+    w2_over_w = tl.where(new_weight == 0.0, 0.0, weight_2 / new_weight)
+    return (
+        mean_1 + delta * w2_over_w,
+        m2_1 + m2_2 + delta * delta * weight_1 * w2_over_w,
+        new_weight,
+    )
+
+
+@triton.jit
+def welford(mean, m2, weight, dim):
+    return tl.reduce((mean, m2, weight), dim, welford_combine)
+
+
+@triton.jit
+def device_assert_then(cond, msg, r):
+    tl.device_assert(cond, msg)
+    return r
+
+
+@triton.jit
+def randint64(seed, offset, low, high):
+    r0, r1, _r2, _r3 = tl.randint4x(seed, offset)
+    r0 = r0.to(tl.uint64)
+    r1 = r1.to(tl.uint64)
+    result = r0 | (r1 << 32)
+    size = high - low
+    result = result % size.to(tl.uint64)
+    result = result.to(tl.int64) + low
+    return result
+
+
+@triton.jit
+def _any_combine(a, b):
+    return a | b
+
+
+@triton.jit
+def any(a, dim):
+    return tl.reduce(a, dim, _any_combine)
+
+
+@triton.jit
+def bucketize_binary_search(
+    values: tl.tensor,
+    boundaries_ptr: tl.tensor,
+    BOUNDARIES_SIZE: int,
+    BOUNDARIES_UNDERLYING_NUMEL: int,
+    BOUNDARIES_STRIDE: int,
+    boundary_indices: tl.tensor,
+    indexing_dtype: tl.dtype,
+    right: "bool",  # triton can't handle the unquoted bool annotation
+    sorter_ptr: tl.tensor,
+    SORTER_STRIDE: int,
+    sorter_indices: tl.tensor,
+):
+    """
+    See [Note: Inductor bucketize op]
+
+    Inputs:
+    -------
+    values: the values to bucketize.
+    boundaries_ptr: a pointer to the beginning of the boundaries tensor, in 1-D.
+    BOUNDARIES_SIZE: the length of the last dimension of the boundaries tensor (i.e. one
+    individual set of boundaries).
+    BOUNDARIES_UNDERLYING_NUMEL: the length of the boundaries tensor, in 1-D, ignoring
+    any striding.
+    BOUNDARIES_STRIDE: the stride of the last dimension of the boundaries tensor
+    boundary_indices: a tensor of the same size as "values"; each element is an index
+    into a 1-D, un-strided boundaries tensor, pointing to the first element in the set
+    of boundaries used for that value.
+    indexing_dtype: the dtype used for indexing into the boundaries tensor, and the
+    return dtype.
+    right: if true, use boundary intervals closed on the left; otherwise use intervals
+    closed on the right.
+    sorter_ptr: an optional pointer to a sorter tensor of the same shape as boundaries,
+    but potentially different striding.  If present, this allows us to treat boundaries
+    as sorted even if the elements of boundaries are unsorted.
+    SORTER_STRIDE: must be present if sorter_ptr is non-None; the stride of the last
+    dimension of the sorter tensor.
+    sorter_indices: must be present if sorter_ptr is non-None; see "boundary_indices".
+    BLOCK_SHAPE: the shape of the data block being processed.
+    """
+
+    low = tl.zeros(values.shape, dtype=indexing_dtype)
+    high = tl.full(values.shape, BOUNDARIES_SIZE, dtype=indexing_dtype)
+
+    full_range = BOUNDARIES_SIZE + 1
+    while full_range > 1:
+        mid = (high + low) // 2
+        mask = (
+            (mid * BOUNDARIES_STRIDE + boundary_indices) < BOUNDARIES_UNDERLYING_NUMEL
+        ).logical_and(mid < BOUNDARIES_SIZE)
+        mid_indices = (
+            mid
+            if sorter_ptr is None or SORTER_STRIDE is None
+            else tl.load(
+                sorter_ptr + sorter_indices + SORTER_STRIDE * mid,
+                mask=mask,
+                other=0,
+            )
+        )
+
+        bucket_upper_bound = tl.load(
+            boundaries_ptr + boundary_indices + BOUNDARIES_STRIDE * mid_indices,
+            mask=mask,
+            other=0,
+        )
+        if right:
+            is_above = values >= bucket_upper_bound
+        else:
+            is_above = values > bucket_upper_bound
+
+        low = tl.where(is_above & mask, mid + 1, low)
+        high = tl.where(is_above, high, mid)
+
+        full_range = (full_range + 1) // 2
+
+    return low
+
+
+@triton.jit
+def pack_value_flag(
+    value,
+    flag,
+    DTYPE_VALUE_AS_UINT: tl.constexpr,
+    DTYPE_PACK: tl.constexpr,
+):
+    # Workaround for triton bug, tensor.to doesn't unwrap constexpr values
+    DTYPE_VALUE_AS_UINT = tl.core._unwrap_if_constexpr(DTYPE_VALUE_AS_UINT)
+    bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth
+    uv = value.to(DTYPE_VALUE_AS_UINT, bitcast=True).to(DTYPE_PACK)
+    return flag.to(DTYPE_PACK) | (uv << bitwidth)
+
+
+@triton.jit
+def unpack_value(
+    pack,
+    DTYPE_VALUE,
+    DTYPE_VALUE_AS_UINT,
+):
+    # Workaround for triton bug, tensor.to doesn't unwrap constexpr values
+    DTYPE_VALUE = tl.core._unwrap_if_constexpr(DTYPE_VALUE)
+    DTYPE_VALUE_AS_UINT = tl.core._unwrap_if_constexpr(DTYPE_VALUE_AS_UINT)
+    bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth
+    value_uint = (pack >> bitwidth).to(DTYPE_VALUE_AS_UINT)
+    return value_uint.to(DTYPE_VALUE, bitcast=True)
+
+
+@triton.jit
+def unpack_flag(pack, DTYPE_FLAG):
+    return pack.to(DTYPE_FLAG)
+
+
+@triton.jit
+def exclusive_scan_decoupled_lookback(
+    scratch_base,
+    block_value,
+    index,
+    combine_fn,
+    DTYPE_VALUE_AS_UINT: tl.constexpr,
+    DTYPE_PACK: tl.constexpr,
+):
+    """Compute exclusive scan of a scalar value between blocks
+
+    Ref: https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back
+
+    scratch_base: Pointer to scratch space in global memory
+    block_value: Scalar value for this block
+    index: Scalar index of this block relative to the current scan
+    combine_fn: Function ``(value, value) -> value`` which is scanned over
+    DTYPE_VALUE_AS_UINT: A tl.uint{n} type equal in size to ``block_value``
+    DTYPE_PACK: Unsigned type twice the width of block_value
+
+    NOTE: This function is limited to values which are 32-bits or less because
+    we need to pack (value, flag) into a single unsigned int.
+    """
+    # Publish block sum so subsequent blocks don't get stuck waiting for us
+    DTYPE_VALUE = block_value.dtype
+    pack = pack_value_flag(
+        block_value,
+        tl.full(block_value.shape, 1, DTYPE_VALUE_AS_UINT),
+        DTYPE_VALUE_AS_UINT,
+        DTYPE_PACK,
+    )
+    if index > 0:
+        tl.atomic_xchg(scratch_base + index, pack, sem="relaxed")
+
+    # Calculate exclusive prefix scan
+    exclusive_prefix = tl.zeros([], DTYPE_VALUE)
+    prefix_valid = False
+    test_target = index - 1
+    while test_target >= 0:
+        # tl.atomic_load
+        flag = tl.full([], 0, DTYPE_VALUE_AS_UINT)
+        while flag == 0:
+            pack = tl.atomic_add(scratch_base + test_target, 0, sem="relaxed")
+            flag = unpack_flag(pack, DTYPE_VALUE_AS_UINT)
+
+        value = unpack_value(pack, DTYPE_VALUE, DTYPE_VALUE_AS_UINT)
+        if prefix_valid:
+            exclusive_prefix = combine_fn(value, exclusive_prefix)
+        else:
+            exclusive_prefix = value
+            prefix_valid = True
+
+        if flag == 2:
+            test_target = -1
+        else:
+            test_target = test_target - 1
+
+    # Make inclusive block sum visible to other blocks
+    if prefix_valid:
+        inclusive_prefix = combine_fn(exclusive_prefix, block_value)
+    else:
+        inclusive_prefix = block_value
+    pack = pack_value_flag(
+        inclusive_prefix,
+        tl.full([], 2, DTYPE_VALUE_AS_UINT),
+        DTYPE_VALUE_AS_UINT,
+        DTYPE_PACK,
+    )
+    tl.atomic_xchg(scratch_base + index, pack, sem="relaxed")
+    return exclusive_prefix
+
+
+@triton.jit
+def exclusive_scan_decoupled_lookback_64(scratch_base, block_value, index, combine_fn):
+    """Compute exclusive scan of a scalar value between blocks
+
+    Ref: https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back
+
+    scratch_base: Pointer to scratch space in global memory
+    block_value: Scalar value for this block, must be 64-bits wide
+    index: Scalar index of this block relative to the current scan
+    combine_fn: Function ``(value, value) -> value`` which is scanned over
+    init: Scalar value equal to the identity of combine_fn
+    """
+    # Publish block sum so subsequent blocks don't get stuck waiting for us
+    if index > 0:
+        block_value_u64 = block_value.to(tl.uint64, bitcast=True)
+        tl.store(scratch_base + 3 * index + 1, block_value_u64)
+        tl.debug_barrier()
+        flag_one = tl.full([], 1, tl.uint64)
+        tl.atomic_xchg(scratch_base + 3 * index + 0, flag_one, sem="release")
+
+    # Calculate exclusive prefix scan
+    exclusive_prefix = tl.zeros([], block_value.dtype)
+    prefix_valid = False
+    test_target = index - 1
+    while test_target >= 0:
+        flag = tl.full([], 0, tl.uint64)
+        while flag == 0:
+            flag = tl.atomic_add(scratch_base + 3 * test_target + 0, 0, sem="acquire")
+
+        value_u64 = tl.load(scratch_base + 3 * test_target + flag.to(tl.int32))
+        value = value_u64.to(block_value.dtype, bitcast=True)
+        if prefix_valid:
+            exclusive_prefix = combine_fn(value, exclusive_prefix)
+        else:
+            exclusive_prefix = value
+            prefix_valid = True
+
+        if flag == 2:
+            test_target = -1
+        else:
+            test_target = test_target - 1
+
+    # Make inclusive block sum visible to other blocks
+    if prefix_valid:
+        inclusive_prefix = combine_fn(exclusive_prefix, block_value)
+    else:
+        inclusive_prefix = block_value
+    inclusive_prefix_u64 = inclusive_prefix.to(tl.uint64, bitcast=True)
+    tl.store(scratch_base + 3 * index + 2, inclusive_prefix_u64)
+    tl.debug_barrier()
+    flag_two = tl.full([], 2, tl.uint64)
+    tl.atomic_xchg(scratch_base + 3 * index + 0, flag_two, sem="release")
+
+    return exclusive_prefix
+
+
+@triton.jit
+def frexp(x):
+    # TODO(isuruf): use inline_asm_elementwise here
+    y = libdevice.ilogb(x) + 1
+    exponent = tl.where(x == 0, 0, y)
+    mantissa = tl.where(x == 0, 0, libdevice.ldexp(x, -y))
+    return mantissa, exponent
+
+
+@triton.jit
+def _compare_and_swap_with_index(
+    x,
+    idxs,
+    rnumel,
+    flip,
+    i: tl.constexpr,
+    n_dims: tl.constexpr,
+    stable: tl.constexpr,
+    descending: tl.constexpr,
+):
+    n_outer: tl.constexpr = x.numel >> n_dims
+    shape: tl.constexpr = [n_outer * 2**i, 2, 2 ** (n_dims - i - 1)]
+
+    idtype = tl.core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
+
+    y = tl.reshape(x, shape)
+    iy = y.to(idtype, bitcast=True)
+    # slice left/right with 'stride' 2**(n_dims - i - 1)
+    right_mask = tl.arange(0, 2)[None, :, None].to(idtype)
+    left_mask = (1 - right_mask).to(idtype)
+    ileft = tl.broadcast_to(tl.sum(iy * left_mask, 1).to(idtype)[:, None, :], shape)
+    iright = tl.broadcast_to(tl.sum(iy * right_mask, 1).to(idtype)[:, None, :], shape)
+    ileft = tl.reshape(ileft, x.shape)
+    iright = tl.reshape(iright, x.shape)
+    left = ileft.to(x.dtype, bitcast=True)
+    right = iright.to(x.dtype, bitcast=True)
+
+    # idx
+    y_idx = tl.reshape(idxs, shape)
+    left_idx = tl.broadcast_to(
+        tl.sum(y_idx * left_mask.to(y_idx.dtype), 1)[:, None, :], shape
+    )
+    right_idx = tl.broadcast_to(
+        tl.sum(y_idx * right_mask.to(y_idx.dtype), 1)[:, None, :], shape
+    )
+    left_idx = tl.reshape(left_idx, x.shape)
+    right_idx = tl.reshape(right_idx, x.shape)
+
+    # valid
+    if rnumel is None:
+        left_valid_mask = tl.full(x.shape, True, tl.int1)
+        right_valid_mask = tl.full(x.shape, True, tl.int1)
+    else:
+        left_valid_mask = left_idx < rnumel
+        right_valid_mask = right_idx < rnumel
+
+    # actual compare-and-swap
+    ix = x.to(idtype, bitcast=True)
+
+    # sort treats nan as having the higher value. comparisons with nan always return False.
+    # to align with sort semantics, we need to update descending to check if right_isnan,
+    # and ascending to check if left_isnan.
+    left_isnan = left != left
+    right_isnan = right != right
+
+    if descending:
+        cond = left < right
+        if is_floating(left):
+            if not stable:
+                cond = cond | right_isnan
+            else:
+                cond = cond | (right_isnan & (~left_isnan))
+
+    else:
+        cond = left > right
+        if is_floating(left):
+            if not stable:
+                cond = cond | left_isnan
+            else:
+                cond = cond | (left_isnan & (~right_isnan))
+
+    if stable:
+        # When stable sorting, tie break by index
+        eq = left == right
+        if is_floating(left):
+            eq = eq | (left_isnan & right_isnan)
+        cond = cond | (eq & (left_idx > right_idx))
+
+    cond = (right_valid_mask > left_valid_mask) | (
+        (right_valid_mask == left_valid_mask) & cond
+    )
+    cond = (cond ^ flip).to(tl.int1)
+    ret = ix ^ tl.where(cond, ileft ^ iright, tl.zeros_like(ix))
+    new_idxs = idxs ^ tl.where(cond, left_idx ^ right_idx, tl.zeros_like(idxs))
+
+    return ret.to(x.dtype, bitcast=True), new_idxs
+
+
+@triton.jit
+def _bitonic_merge_with_index(
+    x,
+    idxs,
+    rnumel,
+    stage: tl.constexpr,
+    alternating: tl.constexpr,
+    n_dims: tl.constexpr,
+    stable: tl.constexpr,
+    descending: tl.constexpr,
+):
+    n_outer: tl.constexpr = x.numel >> n_dims
+    tl.static_assert(stage <= n_dims)
+    # flip denotes whether to re-arrange sub-sequences of elements in ascending or
+    # descending order.
+    # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage
+    # if flip = 00110011... then all the elements will be re-arranged alternatingly (with
+    # a stride of 2) at this stage
+    if alternating:
+        shape: tl.constexpr = [n_outer * 2 ** (n_dims - 1 - stage), 2, 2**stage]
+        flip = tl.reshape(
+            tl.broadcast_to(tl.arange(0, 2)[None, :, None], shape), x.shape
+        )
+    else:
+        flip = False
+    # perform `stage` rounds of `compare-and-swap`
+    for i in tl.static_range(stage):
+        x, idxs = _compare_and_swap_with_index(
+            x, idxs, rnumel, flip, i + (n_dims - stage), n_dims, stable, descending
+        )
+    return x, idxs
+
+
+@triton.jit
+def sort_with_index(
+    x,  # value
+    idxs,  # index
+    rnumel,  # number of elements
+    dim: tl.constexpr = None,
+    stable: tl.constexpr = tl.constexpr(False),
+    descending: tl.constexpr = tl.constexpr(False),
+):
+    x, idxs = tl.broadcast(x, idxs)
+    # handle default dimension or check that it is the most minor dim
+    _dim: tl.constexpr = len(x.shape) - 1 if dim is None else dim
+    tl.static_assert(
+        _dim == len(x.shape) - 1, "only minor dimension is currently supported"
+    )
+    # iteratively run bitonic merge-sort steps
+    n_dims: tl.constexpr = _log2(x.shape[_dim])
+
+    for i in tl.static_range(1, n_dims + 1):
+        x, idxs = _bitonic_merge_with_index(
+            x,
+            idxs,
+            rnumel,
+            i,
+            alternating=i < n_dims,
+            n_dims=n_dims,
+            stable=stable,
+            descending=descending,
+        )
+    return x, idxs
+
+
+@triton.jit
+def select_one(x, mask, dim, keep_dims=False):
+    idtype = tl.core.get_int_dtype(x.dtype.primitive_bitwidth, signed=False)
+    ix = x.to(idtype, bitcast=True)
+    iy = tl.sum(ix * mask, dim, keep_dims=keep_dims)
+    return iy.to(x.dtype, bitcast=True)
+
+
+@triton.jit
+def x_grid_barrier(sem):
+    """
+    Wait for all other thread blocks in grid sharing same y/z program_id
+    to reach this barrier before returning.
+
+    Args:
+        sem: an uint32 semaphores, zero or 0x80000000 initialized.  Must be unique to each y/z program ID.
+    """
+    # ensure stores before this are visible
+    tl.debug_barrier()
+
+    one_i32 = 1
+    one_u32 = one_i32.to(tl.uint32)  # type: ignore[attr-defined]
+    expected = tl.num_programs(0).to(tl.uint32)
+    if tl.program_id(0) == 0:
+        nb = 0x80000000 - (expected - one_u32)
+    else:
+        nb = one_u32
+
+    old_arrive = tl.atomic_add(sem, nb, sem="release")
+
+    bar_flipped = False
+    while not bar_flipped:
+        # want a `ld.acquire.gpu.u32 $0,[$1];` but Triton doesn't have it
+        current_arrive = tl.atomic_add(sem, 0, sem="acquire")
+        # current_arrive = tl.load(sem, volatile=True)
+        bar_flipped = ((old_arrive ^ current_arrive) & 0x80000000) != 0
+
+    # TODO(jansel): is this needed?
+    tl.debug_barrier()
+
+
+def triton_builtin(f: Callable[..., _T]) -> Callable[..., _T]:
+    """
+    Decorator to mark a function as a Triton built-in function.  These functions
+    are evaluated at compile time.
+
+    Args:
+        f (function): The function to be marked as a Triton built-in.
+
+    Returns:
+        function: The same function, marked as a Triton built-in.
+    """
+    if builtins_use_semantic_kwarg:
+        # support Triton before and after https://github.com/triton-lang/triton/pull/7054
+        # and after https://github.com/triton-lang/triton/pull/7239
+        def wrapper(*args, _semantic, **kwargs):
+            kwargs["_builder"] = _semantic
+            return f(*args, **kwargs)
+    else:
+        wrapper = f  # type: ignore[assignment]
+
+    wrapper.__triton_builtin__ = True  # type: ignore[attr-defined]
+    return wrapper
+
+
+@triton_builtin
+def constexpr_next_power_of_2(
+    n: tl.constexpr, *, _builder: object = None
+) -> tl.constexpr:
+    """
+    A version triton.next_power_of_two that can be used within a kernel on constants.
+    """
+    assert isinstance(n, tl.constexpr)
+    return tl.constexpr(triton.next_power_of_2(n.value))
+
+
+@triton_builtin
+def if_mask(mask: Any, val, *, _builder: object = None) -> tl.constexpr:
+    """
+    Work around triton compile error: `ValueError: `other` cannot be provided without `mask``
+    A compile-time to check to return either `val` or `None` depending on the value of mask.
+    """
+    if isinstance(mask, tl.constexpr) and mask.value is None:
+        return tl.constexpr(None)
+    return val
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py
new file mode 100644
index 0000000000000000000000000000000000000000..2aefc498efb3e1731c433cb9924d6520e34cc16c
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py
@@ -0,0 +1,3874 @@
+# mypy: allow-untyped-defs
+from __future__ import annotations
+
+import builtins
+import copy
+import dataclasses
+import functools
+import hashlib
+import inspect
+import itertools
+import logging
+import math
+import operator
+import os
+import os.path
+import re
+import sys
+import threading
+import time
+from collections import namedtuple
+from typing import Any, Generic, Literal, TYPE_CHECKING, TypeVar, Union
+
+import torch
+from torch._dynamo.utils import counters, set_feature_use
+from torch._inductor import metrics
+from torch._prims_common import compute_required_storage_length
+from torch.utils._debug_mode import get_active_debug_mode
+from torch.utils._ordered_set import OrderedSet
+
+from ..triton_bundler import TritonBundler
+from ..utils import prefix_is_reduction, triton_version_uses_attrs_dict
+from . import triton_helpers
+from .autotune_cache import AutotuneCache
+from .benchmarking import benchmarker
+from .coordinate_descent_tuner import CoordescTuner
+from .hints import (
+    _NUM_THREADS_PER_WARP,
+    AutotuneHint,
+    DeviceProperties,
+    HeuristicType,
+    ReductionHint,
+    TileHint,
+    TRITON_MAX_BLOCK,
+    TRITON_MAX_RSPLIT,
+)
+from .runtime_utils import (
+    ceildiv,
+    conditional_product,
+    create_bandwidth_info_str,
+    dynamo_timed,
+    get_first_attr,
+    get_max_y_grid,
+    get_num_bytes,
+    next_power_of_2,
+    triton_cache_dir,
+    triton_config_to_hashable,
+    triton_hash_to_path_key,
+    validate_triton_config,
+)
+from .static_cuda_launcher import StaticallyLaunchedCudaKernel
+from .triton_compat import (
+    ASTSource,
+    autograd_profiler,
+    cc_warp_size,
+    CompiledKernel,
+    Config,
+    GPUTarget,
+    HAS_WARP_SPEC,
+    KernelInterface,
+    knobs,
+    OutOfResources,
+    PTXASError,
+    triton,
+)
+from .triton_helpers import get_constexprs
+
+
+class InductorConfig(Config):
+    """Inductor-specific Triton config with additional control flags"""
+
+    def __init__(self, *args, dynamic_scale_rblock=True, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.dynamic_scale_rblock = dynamic_scale_rblock
+
+
+class NoTritonConfigsError(RuntimeError):
+    pass
+
+
+if TYPE_CHECKING:
+    from collections.abc import Callable, Container, Hashable
+
+    from torch._guards import CompileId
+
+    LauncherType = Any
+
+_KernelType = Union[CompiledKernel, StaticallyLaunchedCudaKernel]
+_T = TypeVar("_T", bound=_KernelType)
+
+log = logging.getLogger(__name__)
+
+triton_name_sub = re.compile(r"^def [^(]+\(")
+
+
+def generate_lookup_hash_from_source_code(size_hints_str: str, source_code: str) -> str:
+    # Name agnostic + strip white space
+    fn_strip_name = re.sub(triton_name_sub, "(", source_code.strip(), count=1)
+    hash_str = size_hints_str + fn_strip_name
+    fn_hash = hashlib.sha256(hash_str.encode("utf-8")).hexdigest()
+
+    return fn_hash
+
+
+def lookup_autotune_config(size_hints, fn) -> Config | None:
+    lookup_table = torch._inductor.config.autotune_lookup_table
+    cached_config = None
+    if len(lookup_table) > 0 and "_fused_" in fn.src:
+        fn_hash = generate_lookup_hash_from_source_code(str(size_hints), fn.src)
+        if fn_hash in lookup_table:
+            config_dict = lookup_table[fn_hash]
+            block_configs = {k: v for k, v in config_dict.items() if "BLOCK" in k}
+            cached_config = Config(
+                block_configs,
+                num_warps=config_dict["num_warps"],
+                num_stages=config_dict["num_stages"],
+            )
+
+    return cached_config
+
+
+def get_total_reduction_numel(numels: dict[str, int]) -> int:
+    return conditional_product(
+        *[numel for prefix, numel in numels.items() if prefix_is_reduction(prefix)]
+    )
+
+
+def autotune_hints_to_configs(
+    hints: OrderedSet[AutotuneHint],
+    size_hints,
+    block_size: int,
+    device_props: DeviceProperties,
+) -> list[Config]:
+    """
+    AutotuneHints can be attached to the metadata of triton kernels for providing
+    suggestions about what to try for autotuning. One reason to do this is if there are
+    some configs that are only useful in specific scenarios, in which case we can avoid
+    wasting compile time on autotuning unless we know we are in one of those scenarios.
+
+    Based on those hints, this function will generate a list of additional autotuning
+    configs to try.
+    """
+    xyz_options: tuple[tuple[int, int | None, int | None], ...]
+    configs: list[Config] = []
+    for hint in hints:
+        if hint == AutotuneHint.ONE_ELEMENT_PER_THREAD:
+            if len(size_hints) == 1:
+                xyz_options = ((block_size // 4, None, None),)
+            elif len(size_hints) == 2:
+                xyz_options = ((block_size // 4, 1, None), (1, block_size // 4, None))
+            elif len(size_hints) == 3:
+                xyz_options = (
+                    (block_size // 4, 1, 1),
+                    (1, block_size // 4, 1),
+                    (1, 1, block_size // 4),
+                )
+            configs.extend(
+                triton_config(
+                    size_hints,
+                    *xyz,
+                    num_elements_per_warp=(
+                        device_props.warp_size if device_props.warp_size else 32
+                    ),
+                )
+                for xyz in xyz_options
+            )
+
+    return configs
+
+
+def _dump_launch_params(args, kwargs, launcher, kernel_name, grid):
+    call_args = []
+    call_kwargs = {}
+    for arg in args:
+        if isinstance(arg, (int, bool)):
+            call_args.append(str(arg))
+        else:
+            call_args.append("T")
+    for k, v in kwargs.items():
+        if isinstance(arg, (int, bool)):
+            call_kwargs[k] = v
+        else:
+            call_kwargs[k] = v
+    call_kwargs.update(launcher.config.kwargs)
+    call_kwargs["num_warps"] = launcher.config.num_warps
+    call_kwargs["num_stages"] = launcher.config.num_stages
+    if HAS_WARP_SPEC:
+        call_kwargs["num_consumer_groups"] = getattr(
+            launcher.config, "num_consumer_groups", 0
+        )
+        call_kwargs["num_buffers_warp_spec"] = getattr(
+            launcher.config, "num_buffers_warp_spec", 0
+        )
+    args_str = [*call_args]
+    args_str.extend(f"{k}={v}" for k, v in call_kwargs.items())
+    args_str = ", ".join(args_str)
+    abs_path = os.path.abspath(sys.argv[0])
+    with open(f"{abs_path}.launch_params", "a") as f:
+        f.write(f"{kernel_name} | {args_str} | {grid!r}\n")
+
+
+def check_autotune_cache(
+    configs: list[Config], filename: str | None, inductor_meta: dict[str, Any]
+) -> tuple[list[Config], AutotuneCache | None, dict[str, Any]]:
+    """
+    Given a list of configs, checks autotune cache and return metadata
+    """
+    autotune_cache = None
+    autotune_cache_info = {}
+    disabled = inductor_meta.get("force_disable_caches", False)
+    if (
+        not disabled
+        and filename is not None
+        and (len(configs) > 1 or inductor_meta.get("coordinate_descent_tuning"))
+        and os.environ.get("TRITON_INTERPRET", "0") != "1"
+    ):
+        configs_hash = hash_configs(configs)
+
+        autotune_cache = AutotuneCache.create(inductor_meta, filename, configs_hash)
+        if autotune_cache:
+            if best_config := autotune_cache.read_best(inductor_meta, configs):
+                configs = [best_config]
+                autotune_cache_info["best_config"] = triton_config_to_hashable(
+                    best_config
+                )
+                autotune_cache_info["autotune_cache_state"] = "hit"
+
+            else:
+                autotune_cache_info["autotune_cache_state"] = "miss"
+                autotune_cache_info["num_configs"] = len(configs)
+                if inductor_meta.get("coordinate_descent_tuning"):
+                    autotune_cache_info["coordesc_tuning"] = True
+                    if len(configs) == 1:
+                        # This is the config that coordinate descent tuning started at, which
+                        # is not the same as the final config chosen (i.e. only_config, best_config)
+                        autotune_cache_info["coordesc_tuning_start_config"] = (
+                            triton_config_to_hashable(configs[0])
+                        )
+    else:
+        if len(configs) == 1:
+            autotune_cache_info["autotune_cache_state"] = "only 1 config"
+            autotune_cache_info["only_config"] = triton_config_to_hashable(configs[0])
+
+        if disabled:
+            autotune_cache_info["autotune_cache_state"] = "force_disabled"
+            log.debug("autotune caching is disabled by config.force_disable_caches")
+
+    return configs, autotune_cache, autotune_cache_info
+
+
+class CachingAutotuner(KernelInterface):
+    """
+    Simplified version of Triton autotuner that has no invalidation
+    key and caches the best config to disk to improve cold start times.
+    Unlike the main triton Autotuner, this version can precompile all
+    configs, and does not rely on the Triton JIT.
+    """
+
+    def __init__(
+        self,
+        fn,
+        triton_meta,  # passed directly to triton
+        configs,
+        save_cache_hook,
+        mutated_arg_names: list[str],  # see [Note: clone mutated buffers]
+        optimize_mem,
+        heuristic_type,
+        size_hints=None,
+        inductor_meta=None,  # metadata not relevant to triton
+        custom_kernel=False,  # whether the kernel is inductor-generated or custom
+        filename: str | None = None,
+        reset_to_zero_arg_names: list[str] | None = None,
+        autotune_cache_info: dict[str, Any] | None = None,
+    ):
+        super().__init__()
+
+        assert len(configs) > 0, "Non-empty TritonConfig list required for compiling"
+        # makes sure there are no pre-hooks on any of the triton configs
+        for cfg in configs:
+            validate_triton_config(cfg)
+
+        self.fn = fn
+        self.device_props: DeviceProperties = triton_meta["device"]
+        self.triton_meta = {
+            **triton_meta,
+            "device": self.device_props.index,
+            "device_type": self.device_props.type,
+        }
+        self.inductor_meta = {} if inductor_meta is None else inductor_meta
+        # Add device properties to inductor_meta for use by coordinate descent tuner
+        self.inductor_meta["warp_size"] = self.device_props.warp_size
+        self.inductor_meta["max_threads_per_block"] = (
+            self.device_props.max_threads_per_block
+        )
+        self.deterministic_mode = self.inductor_meta.get("deterministic", False)
+
+        self.save_cache_hook = save_cache_hook
+        self.mutated_arg_names = mutated_arg_names
+        self.reset_to_zero_arg_names = (
+            [] if reset_to_zero_arg_names is None else reset_to_zero_arg_names
+        )
+        self.optimize_mem = optimize_mem
+        cached_config = lookup_autotune_config(size_hints, fn)
+        self.configs = [cached_config] if cached_config else configs
+
+        self.heuristic_type = heuristic_type
+        self.custom_kernel = custom_kernel
+        self.cuda_kernel_saved = False
+        self.autotune_cache_info = autotune_cache_info
+        if log.isEnabledFor(logging.DEBUG):
+            log.debug(
+                "CachingAutotuner gets %d configs for %s",
+                len(self.configs),
+                self.fn.__name__,
+            )
+            for c in self.configs:
+                log.debug(c)
+
+        self.compile_results: list[CompileResult[_KernelType]] = []
+        self.launchers: list[LauncherType] = []
+        self.lock = threading.Lock()
+        if os.getenv("TRITON_CACHE_DIR") is None:
+            os.environ["TRITON_CACHE_DIR"] = triton_cache_dir(
+                self.triton_meta.get("device", 0)
+            )
+        log.debug("Triton cache dir: %s", os.environ["TRITON_CACHE_DIR"])
+
+        self.size_hints = size_hints
+        self.is_mix_order_reduction = self.inductor_meta.get("RSPLIT_SIZE") is not None
+        self.coordesc_tuner = CoordescTuner(
+            is_mm=False,
+            is_native_matmul=triton_meta.get("native_matmul", False),
+            is_mix_order_reduction=self.is_mix_order_reduction,
+            name=self.fn.__name__,
+            size_hints=size_hints,
+            inductor_meta=self.inductor_meta,
+        )
+        self.filename = filename
+
+        # used for profiling
+        self.kernel_hash: str = ""
+
+        # Kernels are stored in the codecache with the filename as a hash of the code.
+        # We rely on this to obtain the kernel hash
+        if self.filename is not None:
+            base_name = os.path.basename(self.filename)
+            if ".py" in base_name:
+                self.kernel_hash = os.path.splitext(base_name)[0]
+
+        self.precompile_time_taken_ns = 0
+        self.autotune_time_taken_ns = 0
+        # Dumps the launch configs after autotuning.
+        self.dump_launch_params = (
+            os.environ.get("TORCHINDUCTOR_DUMP_LAUNCH_PARAMS", "0") == "1"
+        )
+
+        self.triton_interpret = os.environ.get("TRITON_INTERPRET", "0") == "1"
+
+        # Compile-time info included in runtime logginging
+        self.compile_id: CompileId | None = None
+        self.is_backward = False
+
+        # Mode for launch grid calculation
+        self.grid_mode: Literal["python", "cpp"] = "python"
+
+    def is_statically_launchable(self):
+        """
+        Checks if every compiled kernel is statically launchable, which
+        allows us to efficiently cache it in FXGraphCache
+        """
+        if not self.compile_results:
+            return False
+        return all(
+            isinstance(x, StaticTritonCompileResult) for x in self.compile_results
+        )
+
+    def recheck_autotune_cache(
+        self, reload_kernel_from_src: Callable[[], CachingAutotuner]
+    ) -> None:
+        """
+        On cache load on static autotuner, we need to recheck the autotune cache, since
+        a best config could have been found from a previous run
+        """
+        assert self.is_statically_launchable()
+
+        configs = [result.config for result in self.compile_results]
+
+        (cached_configs, _, autotune_cache_info) = check_autotune_cache(
+            configs, self.filename, self.inductor_meta
+        )
+        self.autotune_cache_info = autotune_cache_info
+        # I.e. there was an autotune cache hit
+        if len(cached_configs) == 1 and len(configs) > 1:
+            best_config = cached_configs[0]
+            # Grab the best compiled config, if it's in the list of available ones
+            best_config_hash = triton_config_to_hashable(best_config)
+
+            for compile_result in self.compile_results:
+                if triton_config_to_hashable(compile_result.config) == best_config_hash:
+                    self.compile_results = [compile_result]
+                    return
+
+            # If the best config isn't in our list of compile results,
+            # it's likely because it was found by coordesc after the cache
+            # already saved
+            if best_config.found_by_coordesc:
+                with dynamo_timed("CachingAutotuner.slow_precompile_config"):
+                    if self.fn.fn is None:
+                        self.fn = reload_kernel_from_src().fn
+                    self.compile_results = [self._precompile_config(best_config)]
+
+    def set_compile_info(self, compile_id: CompileId | None, is_backward: bool) -> None:
+        self.compile_id = compile_id
+        self.is_backward = is_backward
+
+    def precompile(
+        self,
+        warm_cache_only=False,
+        reload_kernel: Callable[[], CachingAutotuner] | None = None,
+        static_triton_bundle_key: str | None = None,
+    ):
+        if warm_cache_only:
+            self._precompile_worker()
+            return
+        with self.lock:
+            # Helper function for reloading a kernel generated in a worker
+            # in the parent class. Normally we don't need to reload the kernel
+            # in the parent process, but in certain cases (coordesc tuning, dynamic_scale_rblock),
+            # we need to actually run compilation on the parent process
+            if reload_kernel is not None:
+                self._reload_kernel = reload_kernel
+            self._precompile_worker()
+            if static_triton_bundle_key is not None and self.is_statically_launchable():
+                TritonBundler.put_static_autotuner(static_triton_bundle_key, self)
+            self._make_launchers()
+            self._dynamic_scale_rblock()
+
+    def _precompile_worker(self):
+        if self.compile_results:
+            for result in self.compile_results:
+                TritonBundler.put(
+                    triton_hash_to_path_key(result.kernel.hash),  # type: ignore[attr-defined]
+                    self.triton_meta.get("device", 0),
+                )
+            return
+        assert not self.launchers
+        if not self.configs:
+            raise NoTritonConfigsError("No triton configs are available")
+
+        compile_results = []
+        exc = None
+        for c in self.configs:
+            try:
+                compile_results.append(self._precompile_config(c))
+            except (OutOfResources, PTXASError) as e:
+                exc = e
+        if len(compile_results) == 0:
+            raise NoTritonConfigsError(
+                f"No valid triton configs. {type(exc).__name__}: {exc}"
+            )
+        self.compile_results = compile_results
+        self.configs = None
+
+    def _dynamic_scale_rblock(self):
+        # TODO(jansel): we should find a way to move this extra compile into the worker process
+        # Currently it relies on _make_launchers(), which requires a cuda context, to populate nreg.
+        device_prop = self.device_props
+        if (
+            not self.deterministic_mode
+            and self.inductor_meta.get("dynamic_scale_rblock", True)
+            and not self.inductor_meta.get("persistent_reduction")
+            and self.heuristic_type == HeuristicType.REDUCTION
+            and self.size_hints is not None
+            # Disable for Intel as Triton is not ready to return n_regs for a compiled_binary.
+            and device_prop.type in ["cuda", "hip"]
+            and device_prop.major
+            and (device_prop.major >= 8 or torch.version.hip)
+            and device_prop.regs_per_multiprocessor is not None
+        ):
+            assert device_prop.regs_per_multiprocessor
+            assert device_prop.max_threads_per_multi_processor
+            assert device_prop.multi_processor_count
+            seen_config_hashes: OrderedSet[Hashable] | None = None
+            warp_size = device_prop.warp_size or 32
+            for result in self.compile_results:
+                triton_config = result.config
+                compiled_binary = result.kernel
+                assert len(self.size_hints) >= 2
+                xblock = triton_config.kwargs.get("XBLOCK", 1)
+                reduction_kwargs = [
+                    kwarg for kwarg in triton_config.kwargs if kwarg.startswith("R")
+                ]
+                rblocks = [triton_config.kwargs[kwarg] for kwarg in reduction_kwargs]
+                total_block = (self.size_hints["x"] + xblock - 1) // xblock
+                nreg = getattr(compiled_binary, "n_regs", None)
+                if nreg is None:
+                    continue
+
+                # make sure rblocks are not too small
+                if conditional_product(*rblocks) <= 64:
+                    continue
+
+                # each SM of A100 has 65536 32-bit registers. To maximize
+                # the theoretical occupancy, we need run 2048 threads on each
+                # SM. So each thread should use no more than 65536 / 2048
+                # = 32 registers. In cases where occupancy matters, and each
+                # thread uses too many registers, reduce R0_BLOCK to reduce
+                # the register usage.
+                # For kernel https://gist.github.com/shunting314/e4cccc031fe30d378b9b23c08c238cbd
+                # from PLBartForCausalLM, latency improve from
+                # 7.795ms to 4.883ms.
+                #
+                if (
+                    nreg
+                    <= device_prop.regs_per_multiprocessor
+                    // device_prop.max_threads_per_multi_processor
+                ):
+                    continue
+
+                nreg_per_warp = nreg * warp_size
+                nreg_per_block = nreg_per_warp * triton_config.num_warps
+
+                # Previously we set max_blocks_per_sm to 'max_threads_per_multi_processo / (32 * num_warps)'
+                # The formula below is a tighter upper bound since we have the assumption that
+                #   nreg > device_prop.regs_per_multiprocessor // device_prop.max_threads_per_multi_processor
+                # due to the if condition above and:
+                #   regs_per_multiprocessor / nreg_per_block
+                #   = regs_per_multiprocessor / (nreg * 32 * num_warps)
+                #   < regs_per_multiprocessor / ((regs_per_multiprocessor / max_threads_per_multi_processor) * 32 * num_warps)
+                #   = max_threads_per_multi_processor / (32 * num_warps)
+                # Using a tighter upper bound can reveal more optimization opportunities.
+                max_blocks_per_sm = max(
+                    device_prop.regs_per_multiprocessor // nreg_per_block, 1
+                )
+
+                if total_block <= max_blocks_per_sm * device_prop.multi_processor_count:
+                    # no need to improve occupancy
+                    continue
+                new_config = copy.deepcopy(triton_config)
+
+                # Reduce the largest Rn_BLOCK by a factor of 2.
+                largest_rkwarg: str = max(
+                    reduction_kwargs, key=triton_config.kwargs.__getitem__
+                )
+                new_config.kwargs[largest_rkwarg] //= 2
+
+                if seen_config_hashes is None:
+                    seen_config_hashes = OrderedSet(
+                        [
+                            triton_config_to_hashable(x.config)
+                            for x in self.compile_results
+                        ]
+                    )
+                new_config_hash = triton_config_to_hashable(new_config)
+                if new_config_hash in seen_config_hashes:
+                    continue
+                seen_config_hashes.add(new_config_hash)
+                log.debug(
+                    "Dynamically scale down %s from TritonConfig(%s) and get a new TritonConfig(%s)",
+                    largest_rkwarg,
+                    triton_config,
+                    new_config,
+                )
+                if self.fn.fn is None:
+                    """
+                    We are in the parent process, while this program was compiled in a worker
+                    and the fn was dropped in prepare_for_pickle().  We haven't loaded the module
+                    containing the real fn yet.
+                    """
+                    assert hasattr(self, "_reload_kernel")
+                    assert callable(self._reload_kernel)
+                    self.fn = self._reload_kernel().fn
+                self.compile_results.append(self._precompile_config(new_config))  # noqa: B909
+
+            self._make_launchers()
+
+    def _make_launchers(self):
+        if len(self.launchers) == len(self.compile_results):
+            return
+
+        from torch._dynamo.device_interface import DeviceGuard
+
+        device_interface = self.get_device_interface()
+
+        # load binary to the correct device
+        with DeviceGuard(device_interface, self.triton_meta["device"]):
+            # need to initialize context
+            with dynamo_timed(
+                "CachingAutotuner.synchronize",
+                # Deliberately avoid overloading pt2_compile_events:
+                log_pt2_compile_event=False,
+            ):
+                device_interface.synchronize(device_interface.current_device())
+
+            launchers = []
+            exc = None
+            for result in self.compile_results:
+                try:
+                    launchers.append(result.make_launcher())
+
+                except (OutOfResources, PTXASError, torch.cuda.OutOfMemoryError) as e:
+                    exc = e
+        if len(launchers) == 0:
+            raise RuntimeError(f"No valid triton configs. {type(exc).__name__}: {exc}")
+        self.launchers = launchers
+
+    def prepare_for_pickle(self) -> tuple[Any, Any, Any, Any, Any, Any]:
+        """Drop stuff from triton.JITFunction that does not pickle.
+        This must be called after precompile so that these things are no longer needed.
+        Returns a tuple of old values
+        """
+        old_values = (
+            self.fn.fn,
+            self.fn.__globals__,
+            self.fn.used_global_vals,
+            self.fn.repr,
+            self.launchers,
+            getattr(self.fn, "_hash_lock", None),
+        )
+        self.fn.fn = None
+        self.fn.__globals__ = None
+        self.fn.used_global_vals = None
+        self.fn.repr = _ConstRepr(self.fn.repr(self.fn))
+        self.launchers = []
+        self.fn._hash_lock = None
+        return old_values
+
+    def restore_after_unpickle(
+        self, old_values: tuple[Any, Any, Any, Any, Any, Any] | None
+    ) -> None:
+        if old_values:
+            (
+                self.fn.fn,
+                self.fn.__globals__,
+                self.fn.used_global_vals,
+                self.fn.repr,
+                self.launchers,
+                self.fn._hash_lock,
+            ) = old_values
+        else:
+            # even if we don't need/have specific values, we do need the
+            # _hash_lock to be a valid RLock
+            self.fn._hash_lock = threading.RLock()
+
+    def prepare_for_caching(self) -> None:
+        """
+        Statically Launched CUDA Kernels have a raw cubin on them
+        that we don't need to store in the cache(since TritonBundler handles the collection for us)
+        """
+        for result in self.compile_results:
+            if isinstance(result, StaticTritonCompileResult):
+                # Don't save this in the inductor cache, as it is very large
+                result.kernel.cubin_raw = None
+
+    def __getstate__(self) -> dict[str, Any]:
+        assert not self.launchers, (
+            "pickle should not be called with after make_launchers()"
+        )
+        return {
+            **self.__dict__,
+            "lock": None,
+        }
+
+    def __setstate__(self, state: dict[str, Any]) -> None:
+        self.__dict__.update(state)
+        self.lock = threading.Lock()
+
+    def get_device_interface(self):
+        # this code cannot run in compile workers, because it imports from torch
+        from torch._dynamo.device_interface import get_interface_for_device
+
+        return get_interface_for_device(self.device_props.type.replace("hip", "cuda"))
+
+    def _create_compile_meta(self, cfg: Config) -> dict[str, Any]:
+        """
+        Create compilation metadata for a given autotuner config. This involves
+        processing the Config kwargs so that the kwargs that are not part
+        of the triton signature are passed in as options to triton.compile
+        instead
+        """
+        compile_meta = copy.deepcopy(self.triton_meta)
+        compile_meta["num_warps"] = cfg.num_warps
+        compile_meta["num_stages"] = cfg.num_stages
+
+        cfg_kwargs = cfg.kwargs
+        if self.device_props.type == "hip":
+            cfg_kwargs = {**cfg_kwargs}
+            for k in ("matrix_instr_nonkdim", "waves_per_eu", "kpack"):
+                if k in cfg_kwargs:
+                    compile_meta[k] = cfg_kwargs.pop(k)
+        compile_meta["constants"].update(cfg_kwargs)
+
+        for i in get_constexprs(self.fn):
+            arg_name = self.fn.arg_names[i]
+            if arg_name not in compile_meta["constants"] and (
+                arg_name == "num_warps" or arg_name == "num_stages"
+            ):
+                compile_meta["constants"][arg_name] = getattr(cfg, arg_name)
+        if HAS_WARP_SPEC:
+            compile_meta["num_consumer_groups"] = getattr(cfg, "num_consumer_groups", 0)
+            compile_meta["num_buffers_warp_spec"] = getattr(
+                cfg, "num_buffers_warp_spec", 0
+            )
+        compile_meta["debug"] = self.inductor_meta.get(
+            "assert_indirect_indexing", True
+        ) and not self.inductor_meta.get("is_hip", False)
+
+        # device type will be "hip" rather than "cuda" here
+        compile_meta["device_type"] = self.device_props.type
+        compile_meta["cc"] = self.device_props.cc
+
+        return compile_meta
+
+    def _create_compile_options(
+        self, cfg: Config, compile_meta: dict[str, Any]
+    ) -> dict[str, Any]:
+        """
+        Create options to pass to triton.compile based on the compile metadata
+        and the given config.
+        """
+        options = {
+            "num_warps": compile_meta["num_warps"],
+            "num_stages": compile_meta["num_stages"],
+            "debug": compile_meta["debug"],
+            "sanitize_overflow": False,  # turn off additional asserts added for overflow checks
+        }
+        if "enable_fp_fusion" in compile_meta:
+            options["enable_fp_fusion"] = compile_meta["enable_fp_fusion"]
+        if HAS_WARP_SPEC:
+            options.update(
+                {
+                    "num_consumer_groups": compile_meta.get("num_consumer_groups", 0),
+                    "num_buffers_warp_spec": compile_meta.get(
+                        "num_buffers_warp_spec", 0
+                    ),
+                }
+            )
+        if self.device_props.type == "cuda":
+            options.update(
+                {
+                    "launch_cooperative_grid": compile_meta.get(
+                        "launch_cooperative_grid", False
+                    ),
+                    "launch_pdl": compile_meta.get("launch_pdl", False),  # True
+                }
+            )
+        if self.device_props.type == "hip":
+            if "waves_per_eu" in compile_meta:
+                options["waves_per_eu"] = compile_meta["waves_per_eu"]
+            if "matrix_instr_nonkdim" in compile_meta:
+                options["matrix_instr_nonkdim"] = compile_meta["matrix_instr_nonkdim"]
+
+        return options
+
+    def _precompile_config(self, cfg: Config) -> CompileResult[_KernelType]:
+        """Ahead of time compile a given autotuner config."""
+        compile_meta = self._create_compile_meta(cfg)
+
+        if self.device_props.type == "cpu":
+            triton_helpers.set_driver_to_cpu()
+        else:
+            triton_helpers.set_driver_to_gpu()
+
+        if not ASTSource:
+            raise RuntimeError("Installed triton version too old, please upgrade")
+
+        compile_args = (
+            ASTSource(
+                self.fn,
+                compile_meta["signature"],
+                compile_meta["constants"],
+                compile_meta["configs"][0],
+            ),
+        )
+
+        if self.device_props.type == "mtia":
+            from mtia.host_runtime.torch_mtia.acc_flags import (  # type: ignore[import-not-found]
+                build_codename,
+            )
+
+            arch = build_codename()
+        else:
+            arch = compile_meta["cc"]
+
+        target = GPUTarget(
+            compile_meta["device_type"],
+            arch,
+            cc_warp_size(compile_meta["cc"]),
+        )
+
+        options = self._create_compile_options(cfg, compile_meta)
+
+        compile_kwargs = {
+            "target": target,
+            "options": options,
+        }
+
+        try:
+            binary = triton.compile(*compile_args, **compile_kwargs)
+        except Exception:
+            log.exception(
+                "Triton compilation failed: %s\n%s\nmetadata: %s",
+                self.inductor_meta.get("kernel_name", "triton_"),
+                self.fn.src,
+                compile_meta,
+            )
+            raise
+
+        # Simulate JIT Hook call
+        if (
+            torch._inductor.config.run_jit_post_compile_hook
+            and knobs
+            and getattr(knobs.runtime, "jit_post_compile_hook", None)
+        ):
+            try:
+                hook = knobs.runtime.jit_post_compile_hook
+
+                # base args everyone should get
+                call_kwargs = dict(
+                    key=getattr(self.fn, "cache_key", self.kernel_hash or str(self.fn)),
+                    repr=getattr(self.fn, "src", None),
+                    fn=self.fn,
+                    compile=binary,
+                    is_manual_warmup=False,
+                    already_compiled=True,
+                )
+
+                # only add inductor_args if the hook takes it
+                sig = inspect.signature(hook)
+                params = sig.parameters
+                if "inductor_args" in params and "config_args" in self.inductor_meta:
+                    call_kwargs["inductor_args"] = self.inductor_meta["config_args"]
+
+                hook(**call_kwargs)
+            except Exception:
+                log.exception("jit_post_compile_hook failed")
+
+        TritonBundler.put(
+            triton_hash_to_path_key(binary.hash), self.triton_meta.get("device", 0)
+        )
+        # If the binary has a cubin file to directly launch, save it on the binary
+        static_launcher = StaticTritonCompileResult.can_statically_launch(
+            binary, self.inductor_meta, self.triton_meta, self.heuristic_type
+        )
+
+        if static_launcher is not None:
+            result = StaticTritonCompileResult(
+                static_launcher, cfg, compile_meta, self.inductor_meta
+            )
+            return result
+
+        return TritonCompileResult(binary, cfg, compile_meta, self.inductor_meta)
+
+    def bench(self, launcher, *args, with_profiler=False, **kwargs):
+        """Measure the performance of a given launcher"""
+        # we don't skip configs with spilled registers when auto-tuning custom
+        # (user-written) Triton kernels, as (i) we don't have any knowledge or
+        # control over the kernel code; (ii) there is empirical evidence that
+        # for some (complicated) custom Triton kernels, a register-spilling
+        # config may yield the best latency.
+        if not self.custom_kernel and launcher.n_spills > self.inductor_meta.get(
+            "spill_threshold", 16
+        ):
+            log.debug(
+                "Skip config %s because of register spilling: %d",
+                launcher.config,
+                launcher.n_spills,
+            )
+            return float("inf")
+
+        device_interface = self.get_device_interface()
+        stream = device_interface.get_raw_stream(device_interface.current_device())
+
+        cpu_copies = self.copy_args_to_cpu_if_needed(*args, **kwargs)
+
+        def kernel_call():
+            cloned_args, cloned_kwargs = self.maybe_clone_args(
+                cpu_copies, *args, **kwargs
+            )
+            # reset to zero before evaluating any config
+            self.reset_to_zero_args(*args, **kwargs)
+            kernel_name = self.inductor_meta.get("kernel_name", "triton kernel")
+            if autograd_profiler._is_profiler_enabled:
+                profiler_kwargs = self.get_profiler_kwargs(stream, launcher)
+                with torch._C._profiler._RecordFunctionFast(
+                    kernel_name,
+                    cloned_args,
+                    profiler_kwargs,
+                ):
+                    try:
+                        launcher(
+                            *cloned_args,
+                            **cloned_kwargs,
+                            stream=stream,
+                        )
+                    except Exception:
+                        log.error("Failed during launch %s: ", kernel_name)
+                        raise
+
+            else:
+                try:
+                    launcher(
+                        *cloned_args,
+                        **cloned_kwargs,
+                        stream=stream,
+                    )
+                except Exception:
+                    log.error("Failed during launch %s: ", kernel_name)
+                    raise
+            self.restore_args_from_cpu(cpu_copies)
+
+        # only use profiler when not already in a profiler instance
+        if with_profiler and not autograd_profiler._is_profiler_enabled:
+            from torch._inductor.utils import do_bench_using_profiling
+
+            return do_bench_using_profiling(kernel_call, warmup=10, rep=40)
+
+        benchmark_kwargs = (
+            {}
+            if self.device_props.type == "cpu"
+            else {"rep": 40, "is_vetted_benchmarking": True}
+        )
+        return benchmarker.benchmark(
+            fn=kernel_call,
+            device=self.device_props.type,
+            **benchmark_kwargs,  # type: ignore[arg-type]
+        )
+
+    def copy_args_to_cpu_if_needed(self, *args, **kwargs):
+        """
+        To support benchmarking in the presence of mutated args, we need to avoid
+        autotuning contanminating them. We try to pass cloned args to the kernel.
+        If those clones would increase the peak memory usage, however, we instead
+        copy to cpu and restore them after each iteration. Figure out the args
+        to be copied and do the copying.
+        """
+        if not self.optimize_mem:
+            return {}
+
+        copies = {}
+        try:
+            budget = torch.cuda.max_memory_allocated() - torch.cuda.memory_allocated()
+        except RuntimeError:
+            # Possibly a custom CUDA allocator, see https://github.com/pytorch/pytorch/issues/163257
+            return {}
+
+        def maybe_copy(name, arg):
+            if name in self.mutated_arg_names and arg.is_cuda:
+                nonlocal budget
+                assert isinstance(arg, torch.Tensor)
+                required_storage_length = compute_required_storage_length(
+                    arg.size(),
+                    arg.stride(),
+                    0,
+                )
+                size = required_storage_length * arg.element_size()
+                if size > budget:
+                    cpu_arg = torch.empty_strided(
+                        (required_storage_length,),
+                        (1,),
+                        dtype=arg.dtype,
+                        device="cpu",
+                        pin_memory=True,
+                    )
+                    cpu_arg.copy_(
+                        arg.as_strided((required_storage_length,), (1,)),
+                        non_blocking=True,
+                    )
+                    copies[name] = (arg, cpu_arg)
+                else:
+                    budget -= size
+
+        for name, arg in zip(self.fn.arg_names, args):
+            maybe_copy(name, arg)
+
+        for name, arg in kwargs.items():
+            maybe_copy(name, arg)
+
+        return copies
+
+    def restore_args_from_cpu(self, cpu_copies):
+        for pair in cpu_copies.values():
+            arg, cpu_arg = pair
+            required_storage_length = compute_required_storage_length(
+                arg.size(),
+                arg.stride(),
+                0,
+            )
+            arg.as_strided((required_storage_length,), (1,)).copy_(
+                cpu_arg, non_blocking=True
+            )
+
+    def reset_to_zero_args(self, *args, **kwargs):
+        if not self.reset_to_zero_arg_names:
+            return
+        for i, arg in enumerate(args):
+            if self.fn.arg_names[i] in self.reset_to_zero_arg_names:
+                assert isinstance(
+                    arg,
+                    torch.Tensor,
+                ), (
+                    "self.reset_to_zero_arg_names should only contain valid argument names"
+                )
+                arg.zero_()
+
+        for name, arg in kwargs.items():
+            if name in self.reset_to_zero_arg_names:
+                assert isinstance(
+                    arg,
+                    torch.Tensor,
+                ), (
+                    "self.reset_to_zero_arg_names should only contain valid argument names"
+                )
+                arg.zero_()
+
+    def maybe_clone_args(
+        self, exclude: Container[str], *args, **kwargs
+    ) -> tuple[list[Any], dict[str, Any]]:
+        """
+        Prepare new args and kwargs by cloning any in-place buffers
+        (that are not in the provided exclusion list), to avoid autotune
+        contaminating them. Avoid cloning the other buffers because it
+        leads to increased memory usage.
+        """
+        from ..compile_fx import clone_preserve_strides
+
+        def prepare_arg(name, arg):
+            if name in self.mutated_arg_names and name not in exclude:
+                assert isinstance(arg, torch.Tensor)
+                return clone_preserve_strides(arg)
+            else:
+                return arg
+
+        cloned_args = [
+            prepare_arg(name, arg)
+            for name, arg in itertools.zip_longest(self.fn.arg_names[: len(args)], args)
+        ]
+        cloned_kwargs = {name: prepare_arg(name, arg) for name, arg in kwargs.items()}
+        return cloned_args, cloned_kwargs
+
+    def clone_args(self, *args, **kwargs) -> tuple[list[Any], dict[str, Any]]:
+        return self.maybe_clone_args(OrderedSet(), *args, **kwargs)
+
+    def benchmark_all_configs(self, *args, **kwargs):
+        with (
+            dynamo_timed(
+                "CachingAutotuner.benchmark_all_configs",
+                log_pt2_compile_event=True,
+                metadata={"kernel_name": self.inductor_meta.get("kernel_name")},
+                dynamo_compile_column_us="runtime_triton_autotune_time_us",
+                compile_id=self.compile_id,
+                is_backward=self.is_backward,
+                log_waitcounter=True,
+                waitcounter_name_override="triton_autotuner",
+            ),
+            # Temporarily disable due to spam
+            # compilation_callback.callback_handler.install_callbacks(
+            #     compilation_callback.CallbackTrigger.TRITON_AUTOTUNING,
+            #     str(self.compile_id),
+            # ),
+        ):
+            timings = {
+                launcher: self.bench(launcher, *args, **kwargs)
+                for launcher in self.launchers
+            }
+
+            for k, v in timings.items():
+                self.coordesc_tuner.cache_benchmark_result(k.config, v)
+
+            if log.isEnabledFor(logging.DEBUG):
+                log.debug("Benchmark all input configs for %s, get:", self.fn.__name__)
+                for k, v in timings.items():
+                    log.debug(
+                        "%s: %f, nreg %d, nspill %d, #shared-mem %s",
+                        k.config,
+                        v,
+                        k.n_regs,
+                        k.n_spills,
+                        k.shared,
+                    )
+
+            if metrics.is_metric_table_enabled("kernel_autotune"):
+                if self.fn.fn is None:
+                    self.fn = self._reload_kernel().fn
+
+                kernel_path = self.fn.fn.__code__.co_filename
+                kernel_name = self.fn.__name__
+
+                for k, v in timings.items():
+                    metrics.log_kernel_autotune_result(
+                        kernel_path, kernel_name, k.config, v
+                    )
+
+            self.reset_to_zero_args(*args, **kwargs)
+            return timings
+
+    def autotune_to_one_config(self, *args, **kwargs):
+        """Do the actual autotuning"""
+        start_time = time.time_ns()
+        timings = self.benchmark_all_configs(*args, **kwargs)
+        benchmark_time_taken_ns = time.time_ns() - start_time
+        self.launchers = [builtins.min(timings, key=timings.get)]
+        self.autotune_time_taken_ns = (
+            self.precompile_time_taken_ns + benchmark_time_taken_ns
+        )
+
+        # log the best config
+        launcher = self.launchers[0]
+        log.debug(
+            "Best config for %s: %s: %f, nreg %d, nspill %d, #shared-mem %s",
+            self.fn.__name__,
+            launcher.config,
+            timings[launcher],
+            launcher.n_regs,
+            launcher.n_spills,
+            launcher.shared,
+        )
+
+        if self.save_cache_hook:
+            self.save_cache_hook(
+                launcher.config,
+                self.autotune_time_taken_ns,
+                triton_cache_hash=launcher.cache_hash,
+            )
+
+    def save_gpu_kernel(self, stream, launcher):
+        key = self.inductor_meta.get("kernel_name", None)  # unique kernel name
+        assert key is not None, "kernel_name can not be None"
+        params = {
+            "mangled_name": (
+                launcher.bin.metadata.name
+                if hasattr(launcher.bin.metadata, "name")
+                else launcher.bin.metadata["name"]
+            ),
+            "num_warps": (
+                launcher.bin.num_warps
+                if hasattr(launcher.bin, "num_warps")
+                else launcher.bin.metadata.num_warps
+            ),
+            "shared_mem": (
+                launcher.bin.shared
+                if hasattr(launcher.bin, "shared")
+                else launcher.bin.metadata.shared
+            ),
+            "stream": stream,
+            # User defined triton kernels will have arbitrary kwarg names
+            "config": config_to_dict(launcher.config),
+            "inductor_meta": self.inductor_meta,
+            "triton_meta": self.triton_meta,
+            "def_args": launcher.def_args,
+            "call_args": launcher.call_args,
+            "global_scratch": launcher.global_scratch,
+            "profile_scratch": launcher.profile_scratch,
+        }
+        if self.device_props.type == "xpu":
+            # On the XPU backend, threads_per_warp is not always 32.
+            # For Intel GEMM Triton kernels, it can be 16.
+            # This information must be preserved so that the Cpp wrapper
+            # can launch the kernel with the correct configuration.
+            params["threads_per_warp"] = getattr(
+                launcher.bin.metadata, "threads_per_warp", 32
+            )
+
+        from torch._inductor import config
+        from torch._inductor.codecache import CudaKernelParamCache
+
+        bin_type = {"hip": "hsaco", "xpu": "spv"}.get(self.device_props.type, "cubin")
+        binary = launcher.bin.asm[bin_type]
+
+        # ROCm multi-arch: capture LLVM IR
+        if torch.version.hip and config.aot_inductor.emit_multi_arch_kernel:
+            # Multi-arch ROCm: Capture LLVM IR for cross-architecture compilation
+            asm_type = "ll"
+
+            # llir is the key to obtain LLVM IR from triton
+            asm = launcher.bin.asm.get("llir", None)
+
+            # CRITICAL: Multi-arch compilation cannot proceed without LLVM IR
+            # Fail fast with clear error message pointing to the issue
+            if not asm:
+                available_keys = list(launcher.bin.asm.keys())
+                raise RuntimeError(
+                    f"ROCm multi-arch requires LLVM IR, but none found. "
+                    f"Available keys: {available_keys}. "
+                    f"Triton may need to be patched to emit LLVM IR."
+                )
+
+        # Everything else: capture architecture-specific assembly
+        else:
+            asm_type = {"hip": "amdgcn", "cuda": "ptx", "xpu": "spv"}.get(
+                self.device_props.type, None
+            )
+            asm = launcher.bin.asm.get(asm_type, None)
+
+        CudaKernelParamCache.set(key, params, binary, bin_type, asm, asm_type)
+        self.cuda_kernel_saved = True
+
+    def coordinate_descent_tuning(self, launcher, *args, **kwargs):
+        """
+        Coordinate descent tuning can be run with or without max-autotune.
+
+        The only difference between these two is the starting config for coordinate_descent tuning.
+        E.g., assuming regular autotune only get one config C1; while max-autotune get 4 configs C1, C2, C3, C4
+        and max-autotune figure out C3 is the best.
+
+        Then if coordinate desecnt tuning is run with max-autotune disabled, it will start from C1;
+        while if coordinate descent tuning is run with max-autotune enabled, it will start from C3.
+        """
+        if self.heuristic_type in (
+            HeuristicType.TEMPLATE,
+            HeuristicType.USER_AUTOTUNE,
+            HeuristicType.FIXED,
+        ):
+            # skip triton template
+            return launcher
+
+        if self.deterministic_mode and self.heuristic_type in (
+            HeuristicType.REDUCTION,
+            HeuristicType.PERSISTENT_REDUCTION,
+            HeuristicType.SPLIT_SCAN,
+        ):
+            # Not only RBLOCK size matters for numericals of reduction.
+            # num_warps also matters since that affect how much data
+            # is handled by each thread, how many warp-reduction we do
+            # in parallel and how much data is there for block
+            # reduction.
+            return launcher
+
+        with dynamo_timed(
+            "CachingAutotuner.coordinate_descent_tuning",
+            # These generate too many pt2_compile_event logs:
+            log_pt2_compile_event=False,
+            metadata={"kernel_name": self.inductor_meta.get("kernel_name")},
+            dynamo_compile_column_us="runtime_triton_autotune_time_us",
+            compile_id=self.compile_id,
+            is_backward=self.is_backward,
+            log_waitcounter=True,
+            waitcounter_name_override="triton_autotuner",
+        ):
+            return self._coordinate_descent_tuning(launcher, *args, **kwargs)
+
+    def _coordinate_descent_tuning(self, launcher, *args, **kwargs):
+        config2launcher = {launcher.config: launcher}
+
+        # TODO: should we just load the kernels ahead of time if we know we're going to call this?
+        if self.fn.fn is None:
+            """
+            We are in the parent process, while this program was compiled in a worker
+            and the fn was dropped in prepare_for_pickle().  We haven't loaded the module
+            containing the real fn yet.
+            """
+            assert hasattr(self, "_reload_kernel")
+            assert callable(self._reload_kernel)
+            self.fn = self._reload_kernel().fn
+
+        def benchmark_one_config(config):
+            with self.lock:
+                launcher = self._precompile_config(config).make_launcher()
+            config2launcher[config] = launcher
+
+            out = self.bench(launcher, *args, **kwargs)
+            counters["inductor"]["coordesc_tuning_bench"] += 1
+            log.debug(
+                "COORDESC: %s: %f, nreg %d, nspill %d, #shared-mem %d",
+                launcher.config,
+                out,
+                launcher.n_regs,
+                launcher.n_spills,
+                launcher.shared,
+            )
+            return out
+
+        assert not (
+            self.heuristic_type == HeuristicType.PERSISTENT_REDUCTION
+            and "R0_BLOCK" in launcher.config.kwargs
+        ), (
+            "Coordinate descent tuner relies on the assumption that persistent reduction's triton config does not have R0_BLOCK"
+        )
+        start_time = time.time_ns()
+        best_config = self.coordesc_tuner.autotune(
+            benchmark_one_config, launcher.config, None
+        )
+        coordesc_time_taken_ns = time.time_ns() - start_time
+        best_config.found_by_coordesc = True
+
+        if self.save_cache_hook:
+            self.save_cache_hook(
+                best_config,
+                self.autotune_time_taken_ns + coordesc_time_taken_ns,
+                found_by_coordesc=True,
+            )
+
+        if best_config not in config2launcher:
+            # On a Coordesc cache hit, we might not have loaded the launcher
+            # This can happen because PyCodeCache saves CachingAutotuners in memory,
+            # even for separate compile IDs (which can have different inputs without changing output code)
+            config2launcher[best_config] = self._precompile_config(
+                best_config
+            ).make_launcher()
+
+        fn_hash = generate_lookup_hash_from_source_code(
+            str(self.size_hints), self.fn.src
+        )
+        log.debug("Function hash %s has best config %s", fn_hash, best_config)
+        return config2launcher[best_config]
+
+    def get_profiler_kwargs(self, stream, launcher):
+        kernel_kwargs_str = ",".join(
+            f"{k}={v}" for (k, v) in launcher.config.kwargs.items()
+        )
+
+        ret = {
+            "kernel_file": (self.filename or ""),
+            "kernel_hash": self.kernel_hash,
+            "kernel_backend": "triton",
+            "stream": stream,
+            "num_warps": launcher.config.num_warps,
+            "num_stages": launcher.config.num_stages,
+            "kernel_kwargs": kernel_kwargs_str,
+        }
+        if "kernel_name" in self.inductor_meta:
+            ret["kernel_name"] = self.inductor_meta["kernel_name"]
+        if "kernel_flop" in self.inductor_meta:
+            ret["kernel_flop"] = self.inductor_meta["kernel_flop"]
+        if "kernel_num_gb" in self.inductor_meta:
+            ret["kernel_num_gb"] = self.inductor_meta["kernel_num_gb"]
+        return ret
+
+    def run(
+        self,
+        *args,
+        stream,
+        benchmark_run=False,
+        **kwargs,
+    ):  # type:ignore[override]
+        """Launch triton kernel call and return result."""
+        debug_mode = get_active_debug_mode()
+        debug_call = None
+        if debug_mode:
+            arg_names = list(self.triton_meta.get("signature", {}).keys())
+            kernel_kwargs = dict(zip(arg_names, args))
+            kernel_kwargs.update(kwargs)
+            debug_call = debug_mode.record_triton_kernel(
+                kernel_name=self.fn.__name__, kwargs=kernel_kwargs
+            )
+
+        if hasattr(triton, "set_allocator"):
+
+            def alloc_fn(size: int, align: int, stream: int | None):
+                return torch.empty(
+                    size, dtype=torch.int8, device=self.device_props.type
+                )
+
+            triton.set_allocator(alloc_fn)
+
+        if self.triton_interpret:
+            args, grid = self._interpret_args_grid(args, self.configs[0])
+            return self.fn[grid](
+                *args,
+                **kwargs,
+                **self.configs[0].kwargs,
+            )
+
+        if len(self.launchers) != 1:
+            if len(self.launchers) == 0:
+                start_time = time.time_ns()
+                self.precompile()
+                self.precompile_time_taken_ns = time.time_ns() - start_time
+            if len(self.launchers) > 1:
+                self.autotune_to_one_config(*args, **kwargs)
+
+        if not getattr(
+            self.launchers[0].config, "found_by_coordesc", False
+        ) and self.inductor_meta.get("coordinate_descent_tuning", False):
+            self.launchers = [
+                self.coordinate_descent_tuning(self.launchers[0], *args, **kwargs)
+            ]
+
+        (launcher,) = self.launchers
+        if launcher.store_cubin and (not benchmark_run or not self.cuda_kernel_saved):
+            self.save_gpu_kernel(stream, launcher)
+
+        # PyTorch execution trace replay calls CachingAutotuner::run() instead of calls launcher
+        # so _RecordFunctionFast need to capture the args into CachingAutotuner::run()
+        # make a copy here to avoid mutating the original args
+        args_without_constexprs = tuple(args)
+
+        if self.dump_launch_params:
+            new_args, grid = self._interpret_args_grid(args, launcher.config)
+            _dump_launch_params(new_args, kwargs, launcher, self.fn.__name__, grid)
+
+        # it is faster than entering and exiting a context manager, even if the context
+        # manager is a nullcontext.
+        if autograd_profiler._is_profiler_enabled:
+            profiler_kwargs = self.get_profiler_kwargs(stream, launcher)
+
+            with torch._C._profiler._RecordFunctionFast(
+                self.inductor_meta.get("kernel_name", "triton kernel"),
+                args_without_constexprs,
+                profiler_kwargs,
+            ):
+                result = launcher(
+                    *args,
+                    **kwargs,
+                    stream=stream,
+                )
+        else:
+            result = launcher(
+                *args,
+                **kwargs,
+                stream=stream,
+            )
+
+        if debug_call:
+            debug_call.finalize(self.get_device_interface())
+        return result
+
+    def _interpret_args_grid(
+        self, args: tuple[Any, ...], cfg: Config
+    ) -> tuple[tuple[Any, ...], tuple[int, int, int]]:
+        if triton_version_uses_attrs_dict():
+
+            def filtered_signature() -> list[str]:
+                # constexprs are not passed in as args
+                new_signature: list[str] = []
+                from triton.runtime.interpreter import InterpretedFunction
+
+                for i, x in enumerate(self.triton_meta["signature"].keys()):
+                    if isinstance(self.fn, InterpretedFunction):
+                        # These are torch compiled triton kernels that definitely
+                        # have block size configs. Dynamo does not currently
+                        # trace user defined triton kernels when TRITON_INTERPRET=1
+                        if x not in cfg.kwargs:
+                            new_signature.append(x)
+                    elif i not in get_constexprs(self.fn):
+                        # use constexprs rather than just configs since user
+                        # defined triton kernels may not have any configs
+                        new_signature.append(x)
+
+                return new_signature
+
+        else:
+
+            def filtered_signature() -> list[str]:
+                return list(self.triton_meta["signature"].keys())
+
+        grid = GridExpr.from_meta(
+            self.inductor_meta, cfg, mode=self.grid_mode
+        ).eval_slow(
+            dict(
+                zip(
+                    [
+                        *filtered_signature(),
+                        *self.inductor_meta.get("extra_launcher_args", ()),
+                    ],
+                    args,
+                )
+            )
+        )
+        if self.inductor_meta.get("extra_launcher_args"):
+            args = args[: -len(self.inductor_meta["extra_launcher_args"])]
+        return args, grid
+
+
+class _ConstRepr:
+    def __init__(self, value: str):
+        self.value = value
+
+    def __call__(self, _=None) -> str:
+        return self.value
+
+
+class CompileResult(Generic[_T]):
+    """
+    Base class representing compiled result.
+    """
+
+    def __init__(
+        self,
+        kernel: _T,
+        config: Config,
+        compile_meta: dict[str, Any],
+        inductor_meta: dict[str, Any],
+    ):
+        self.kernel = kernel
+        self.config = config
+        self.compile_meta = compile_meta
+        self.inductor_meta = inductor_meta
+
+    def make_launcher(self) -> LauncherType: ...
+
+    def _gen_launcher_code(self, scope, def_args, runner_args) -> LauncherType:
+        grid = GridExpr.from_meta(self.inductor_meta, self.config)
+        # grid.prefix is usually empty, grid.x_grid is something like `-(xnumel//-1024)`
+        lines = [
+            f"def launcher({', '.join(def_args)}, stream):",
+            *[f"    {line}" for line in grid.prefix],
+            f"    grid_0 = {grid.x_grid}",
+            f"    grid_1 = {grid.y_grid}",
+            f"    grid_2 = {grid.z_grid}",
+            f"    runner({', '.join(runner_args)})",
+        ]
+        launcher_code = "\n".join(lines)
+        exec(launcher_code, scope)
+        return scope["launcher"]
+
+    def _get_arg_lists(
+        self, arg_names, constexprs
+    ) -> tuple[list[str], list[str], OrderedSet[str]]:
+        """
+        Return a bunch of intermediate lists of args needed for generating
+        launcher code.
+        """
+        compile_meta = self.compile_meta
+        cfg = self.config
+        known_constants = OrderedSet(
+            arg for i, arg in enumerate(arg_names) if i in constexprs
+        )
+
+        """
+        https://github.com/pytorch/pytorch/issues/115344
+
+        self.fn.constexprs doesn't properly deal with None args, so when we filter out
+        an arg in UserDefinedTritonKernel.codegen, we need to filter it here as well.
+        We also don't want to modify self.fn.
+
+        We know that we removed something from the signature if:
+            1. It's in compile_meta["constants"]
+            2. It isn't a constant we already know about
+                Note: The value of interest has already been added to compile_meta['constants'],
+                    so we use self.fn.constexprs instead.
+            3. It isn't in the compile_meta signature
+        """
+        none_args = OrderedSet(
+            k
+            for k, v in compile_meta["constants"].items()
+            if v is None and k not in known_constants
+        )
+        none_args = none_args.difference(OrderedSet(compile_meta["signature"].keys()))
+
+        def _convert_constant(constant):
+            if isinstance(constant, str):
+                return "r'" + constant + "'"
+            else:
+                return repr(constant)
+
+        if triton_version_uses_attrs_dict():
+            call_args = arg_names
+            def_args = arg_names
+            implicit_constants = OrderedSet(
+                (
+                    "num_warps",
+                    "num_stages",
+                )
+            ).union(OrderedSet(k for k in known_constants))
+            if implicit_constants := implicit_constants & OrderedSet(
+                compile_meta["constants"].keys()
+            ):
+                # num_warps/num_stages are special implicit args that are not in the signature
+                # see test_triton_kernel_special_params
+                def_args = [arg for arg in def_args if arg not in implicit_constants]
+                repl = {
+                    k: _convert_constant(compile_meta["constants"].get(k))
+                    for k in implicit_constants
+                }
+                call_args = [repl.get(arg, arg) for arg in call_args]
+        else:
+            call_args = [
+                arg
+                for i, arg in enumerate(arg_names)
+                if i not in constexprs and arg not in none_args
+            ]
+            cfg_dict = config_to_dict(cfg)
+            def_args = [
+                name
+                for name in arg_names
+                if name not in cfg_dict and name not in none_args
+            ]
+
+        if "extra_launcher_args" in self.inductor_meta:
+            def_args = [*def_args, *self.inductor_meta["extra_launcher_args"]]
+
+        return call_args, def_args, none_args
+
+
+class CannotStaticallyLaunchKernel(Exception):
+    pass
+
+
+class StaticTritonCompileResult(CompileResult[StaticallyLaunchedCudaKernel]):
+    """
+    TritonCompileResult that uses StaticCudaLauncher,
+    which vastly simplifies the setup and metadata needed to be kept.
+    """
+
+    @staticmethod
+    def can_statically_launch(
+        kernel: CompiledKernel,
+        inductor_meta: dict[str, Any],
+        triton_meta: dict[str, Any],
+        heuristic_type: HeuristicType,
+    ) -> StaticallyLaunchedCudaKernel | None:
+        if not torch._inductor.config.use_static_cuda_launcher:
+            return None
+
+        def check_can_launch() -> StaticallyLaunchedCudaKernel:
+            if triton_meta.get("device_type") != "cuda":
+                # Only cuda kernels
+                raise CannotStaticallyLaunchKernel("Non-cuda device")
+
+            if torch._inductor.config.cpp_wrapper:
+                # If we're running with cpp wrapper, it doesn't
+                # make sense to statically compile since everything
+                # is codegenned anyway
+                raise CannotStaticallyLaunchKernel("Cpp wrapper enabled")
+
+            if (
+                heuristic_type == HeuristicType.USER_AUTOTUNE
+                and not torch._inductor.config.static_launch_user_defined_triton_kernels
+            ):
+                # Don't support user defined triton kernels yet
+                raise CannotStaticallyLaunchKernel("User defined triton kernel")
+
+            if inductor_meta.get("store_cubin"):
+                # Requires storing the entire binary
+                raise CannotStaticallyLaunchKernel("store_cubin is enabled")
+
+            if getattr(kernel.metadata, "launch_pdl", False) or getattr(
+                kernel.metadata, "launch_cooperative_grid", False
+            ):
+                raise CannotStaticallyLaunchKernel(
+                    "static launch does not support launch attributes"
+                )
+
+            cubin_location = os.path.join(
+                triton_cache_dir(triton_meta.get("device", 0)),
+                triton_hash_to_path_key(kernel.hash),
+                f"{kernel.src.fn.__name__}.cubin",
+            )
+
+            if not os.path.exists(cubin_location):
+                raise CannotStaticallyLaunchKernel(
+                    f"Cubin path not found: {cubin_location}"
+                )
+
+            else:
+                kernel._cubin_path = cubin_location
+
+            try:
+                static_kernel = StaticallyLaunchedCudaKernel(kernel)
+            except NotImplementedError as e:
+                raise CannotStaticallyLaunchKernel(f"NotImplemented: {str(e)}") from e
+
+            return static_kernel
+
+        try:
+            result = check_can_launch()
+            return result
+        except CannotStaticallyLaunchKernel as e:
+            log.info("Bypassing StaticallyLaunchedCudaKernel due to %s", str(e))  # noqa: G200
+            if torch._inductor.config.strict_static_cuda_launcher:
+                raise e
+            return None
+
+    def reload_cubin_path(self):
+        """
+        When loading from cache on disk, we want to reload cubin
+        files from their appropriate location on disc.
+        """
+        cubin_location = os.path.join(
+            triton_cache_dir(self.compile_meta.get("device", 0)),
+            triton_hash_to_path_key(self.kernel.hash),
+            f"{self.kernel.name}.cubin",
+        )
+        if not os.path.exists(cubin_location):
+            if self.kernel.cubin_raw is not None:
+                # We saved the raw cubin, so write it to he appropriate location
+                self.kernel.reload_cubin_from_raw(cubin_location)
+            else:
+                raise RuntimeError(
+                    "Cubin file saved by TritonBundler not found at %s", cubin_location
+                )
+        self.kernel.cubin_path = cubin_location
+
+    def make_launcher(self) -> LauncherType:
+        # If at least one static make_launcher call occurs,
+        # we're sure static cuda launcher was used for this compile
+        set_feature_use("static_cuda_launcher", True)
+        # Load the binary on the parent
+        if not self.kernel.cubin_path:
+            self.reload_cubin_path()
+        device = self.compile_meta.get("device", 0)
+        if device is None:
+            device = 0
+        self.kernel.load_kernel(device)
+        scope = {
+            "runner": self.kernel.run,
+        }
+
+        # NOTE: Constexpr handling for triton and static cuda launcher
+
+        # Triton kernels have two types of constexprs: *declared* ones, which are ones the user
+        # has explicitly declared as tl.constexpr, and *implied* ones, which are expressions triton
+        # deems constant while compiling/analyzing the code (i.e. unused parameters, for example)
+
+        # Triton kernels handle constexprs slightly differently depending on which version of triton
+        # we care about (we support 3.2.0 and 3.3.0).
+
+        # In 3.2.0, triton kernels do not require passing any declared constexprs into the kernel
+        # In 3.3.0, triton kernels require all declared constexprs be passed into the kernel, where
+        # they are subsequently ignored.
+        # When statically launching, since we're launching from the triton generated cubin, we actually want to
+        # always get rid of all const exprs, declared or implied, since the underlying cubin file has all
+        # of the constants stripped away anyway.
+
+        # But CachingAutotuner.run will pass us a different number of arguments depending on
+        # whether or not we're in triton 3.2.0 or later, so we grab def_args with the same logic
+        # as the (non static) TritonCompileResult. We then generate call_args ourselves, since we
+        # want only a subset of the arguments passed to triton.
+        # Here, arg_names is exactly fn.src.arg_names and declared_constexprs is exactly fn.src.constexprs,
+        # which matches behavior with regular TritonCompileResult
+        _, def_args, none_args = self._get_arg_lists(
+            self.kernel.arg_names, self.kernel.declared_constexprs
+        )
+
+        call_args = [
+            arg
+            for i, arg in enumerate(self.kernel.arg_names)
+            if i not in self.kernel.full_constexprs and arg not in none_args
+        ]
+
+        # StaticallyLaunchedCudaKernel.run takes in order grid_0, grid_1, grid_2, stream, and call_args
+        runner_args = ["grid_0", "grid_1", "grid_2", "stream", *call_args]
+        launcher = self._gen_launcher_code(scope, def_args, runner_args)
+        launcher.config = self.config  # type: ignore[attr-defined]
+        launcher.n_regs = self.kernel.n_regs  # type: ignore[attr-defined]
+        launcher.n_spills = self.kernel.n_spills  # type: ignore[attr-defined]
+        launcher.shared = self.kernel.shared  # type: ignore[attr-defined]
+        launcher.cache_hash = triton_hash_to_path_key(self.kernel.hash)  # type: ignore[attr-defined]
+        launcher.store_cubin = False  # type: ignore[attr-defined]
+        launcher._is_static = True  # type: ignore[attr-defined]
+        return launcher
+
+
+class TritonCompileResult(CompileResult[CompiledKernel]):
+    """
+    Upstream Triton CompileKernel can not be pickled.  This is a wrapper
+    to support serialization and generate the launcher function.
+    """
+
+    @staticmethod
+    @functools.lru_cache(32)
+    def _kernel_metadata_cls(fields: tuple[str, ...]) -> Any:
+        return namedtuple("KernelMetadata", sorted(fields))
+
+    @staticmethod
+    def _serialize_metadata(metadata):
+        """
+        Triton uses a nested class called KernelMetadata to store metadata information.
+        Pickle does not work well with nested namedtuples, as the namedtuple doesn't appear
+        in the toplevel namespace of the module. So these serialization/deser functions
+        are used to convert the namedtuples to a dict and back.
+
+        As for packed_metadata, depending on the triton backend, KernelMetadata can be
+        a namedtuple, or a regular tuple! So the serialization function branches on whether
+        the metadata to be serialized is a namedtuple or regular, serializable one.
+        """
+
+        def is_namedtuple(obj) -> bool:
+            return (
+                isinstance(obj, tuple)
+                and hasattr(obj, "_asdict")
+                and hasattr(obj, "_fields")
+            )
+
+        if is_namedtuple(metadata):
+            return metadata._asdict()
+        else:
+            return metadata
+
+    @staticmethod
+    def _deserialize_metadata(metadata):
+        if isinstance(metadata, dict):
+            return TritonCompileResult._kernel_metadata_cls(tuple(metadata.keys()))(
+                **metadata
+            )
+        else:
+            return metadata
+
+    def __getstate__(self) -> dict[str, Any]:
+        kernel = self.kernel
+        # replace the fields that don't pickle nicely
+        kernel_state = {
+            **kernel.__dict__,
+            # See doc about serializing metadata above
+            "metadata": self._serialize_metadata(kernel.metadata),
+            "packed_metadata": self._serialize_metadata(
+                getattr(kernel, "packed_metadata", None)
+            ),
+            "module": None,  # regenerated by kernel._init_handles()
+            "function": None,  # regenerated by kernel._init_handles()
+            "run": None,  # regenerated by kernel._init_handles()
+        }
+        return {**self.__dict__, "kernel": kernel_state}  # type: ignore[dict-item]
+
+    def __setstate__(self, state: dict[str, Any]) -> None:
+        # src = ASTSource.__new__(ASTSource)
+        # src.__setstate__(state["kernel"]["src"])
+        # TODO(jansel): need to fixup src.fn which is now None
+        kernel = CompiledKernel.__new__(CompiledKernel)
+        metadata = state["kernel"]["metadata"]
+        packed_metadata = state["kernel"]["packed_metadata"]
+        kernel.__dict__.update(
+            {
+                **state["kernel"],
+                # "src": src,
+                "metadata": self._deserialize_metadata(metadata),
+                "packed_metadata": self._deserialize_metadata(packed_metadata),
+            }
+        )
+        self.__dict__.update(state)
+        self.kernel = kernel
+
+    def make_launcher(self) -> LauncherType:
+        """
+        Launching triton kernels is performance sensitive, we compile
+        a custom Python function get the grid() and reorder the args to
+        the underlying wrapper.
+        """
+        cfg = self.config
+        compile_meta = self.compile_meta
+        binary = self.kernel
+        fn = binary.src.fn
+        binary._init_handles()
+        (call_args, def_args, none_args) = self._get_arg_lists(
+            fn.arg_names, get_constexprs(fn)
+        )
+        binary_shared = (
+            binary.shared if hasattr(binary, "shared") else binary.metadata.shared
+        )
+
+        if knobs is None:
+            launch_enter = binary.__class__.launch_enter_hook
+            launch_exit = binary.__class__.launch_exit_hook
+        else:
+            launch_enter = knobs.runtime.launch_enter_hook
+            launch_exit = knobs.runtime.launch_exit_hook
+
+        import math as math_lib
+
+        import triton as triton_lib
+
+        import torch as torch_lib
+
+        scope = {
+            "grid_meta": cfg.kwargs,
+            "bin": binary,
+            "launch_enter_hook": launch_enter,
+            "launch_exit_hook": launch_exit,
+            "metadata": (
+                binary.packed_metadata
+                if hasattr(binary, "packed_metadata")
+                else binary.metadata
+            ),
+            "shared": binary_shared,
+            "num_warps": (
+                binary.num_warps
+                if hasattr(binary, "num_warps")
+                else binary.metadata.num_warps
+            ),
+            "cta_args": (
+                (
+                    binary.num_ctas,
+                    *get_first_attr(binary, "cluster_dims", "clusterDims"),
+                )
+                if hasattr(binary, "num_ctas")
+                else (
+                    (binary.metadata.num_ctas, *binary.metadata.cluster_dims)
+                    if hasattr(binary, "metadata")
+                    and hasattr(binary.metadata, "num_ctas")
+                    and hasattr(binary.metadata, "cluster_dims")
+                    else ()
+                )
+            ),
+            "function": get_first_attr(binary, "function", "cu_function"),
+            "runner": get_first_attr(binary, "run", "c_wrapper"),
+            "math": math_lib,
+            "torch": torch_lib,
+            "triton": triton_lib,
+        }
+
+        if not hasattr(binary, "launch_metadata"):
+            # launch args before CompiledKernel.launch_metadata is added.
+            # TODO(jansel): delete this branch in mid-2025
+            runner_args = [
+                "grid_0",
+                "grid_1",
+                "grid_2",
+                "num_warps",
+                "*cta_args",
+                "shared",
+                "stream",
+                "function",
+                "launch_enter_hook",
+                "launch_exit_hook",
+                "metadata",
+                *call_args,
+            ]
+        else:  # args after CompiledKernel.launch_metadata: https://github.com/triton-lang/triton/pull/3492
+            # Getting the kernel launch args is extremely perf-sensitive.  Evaluating
+            # `bin.launch_metadata` is relatively expensive, and returns None unless a
+            # `launch_enter_hook` is installed.  So if we don't have that hook installed,
+            # we want to burn None in to the launch args with zero overhead.
+            # See https://github.com/pytorch/pytorch/issues/123597
+            if launch_enter:
+                launch_metadata = f"bin.launch_metadata((grid_0, grid_1, grid_2), stream, {', '.join(call_args)})"
+            else:
+                launch_metadata = "None"
+            runner_args = [
+                "grid_0",
+                "grid_1",
+                "grid_2",
+                "stream",
+                "function",
+                "metadata",
+                launch_metadata,
+                "launch_enter_hook",
+                "launch_exit_hook",
+                *call_args,
+            ]
+
+        launcher = self._gen_launcher_code(scope, def_args, runner_args)
+
+        launcher = scope["launcher"]
+        launcher.config = cfg
+        launcher.n_regs = getattr(binary, "n_regs", None)
+        launcher.n_spills = getattr(binary, "n_spills", None)
+        launcher.shared = binary_shared
+        launcher.cache_hash = triton_hash_to_path_key(binary.hash)
+        launcher.store_cubin = self.inductor_meta.get("store_cubin", False)
+        # store this global variable to avoid the high overhead of reading it when calling run
+        if launcher.store_cubin:
+            launcher.fn = fn
+            launcher.bin = binary
+            if triton_version_uses_attrs_dict():
+                # arg filtering wasn't done above
+                cfg_dict = config_to_dict(cfg)
+                def_args = [x for x in def_args if x not in cfg_dict]
+                call_args = [
+                    x
+                    for x in call_args
+                    if compile_meta["signature"].get(x, "constexpr") != "constexpr"
+                    and x not in none_args
+                ]
+            launcher.def_args = def_args
+            launcher.call_args = call_args
+            kernel_metadata = getattr(self.kernel, "metadata", None)
+
+            # for the scratch arguments: None indicates that the kernel doesn't
+            # take any scratch argument; otherwise a number indicates the number
+            # of bytes of scratch that need to be provided.
+
+            # in AMD's Triton backend, the global scratch size is never provided
+            # (but for AMD it's safe to pass an extra null arg, so always include it)
+            global_scratch: int | None = getattr(
+                kernel_metadata,
+                "global_scratch_size",
+                (0 if torch.version.hip else None),
+            )
+            profile_scratch: int | None = getattr(
+                kernel_metadata, "profile_scratch_size", None
+            )
+            launcher.global_scratch = global_scratch
+            launcher.profile_scratch = profile_scratch
+        return launcher
+
+
+def _find_names(obj):
+    import gc
+    import inspect
+
+    frame = inspect.currentframe()
+    while frame is not None:
+        frame.f_locals
+        frame = frame.f_back
+    obj_names = []
+    for referrer in gc.get_referrers(obj):
+        if isinstance(referrer, dict):
+            for k, v in referrer.items():
+                if v is obj:
+                    obj_names.append(k)
+    return obj_names
+
+
+collected_calls: list[Any] = []
+
+
+def start_graph():
+    collected_calls.clear()
+
+
+def end_graph(output_file):
+    if len(collected_calls) == 0:
+        return
+    overall_time = sum(call[0] for call in collected_calls)
+    overall_gb = sum(call[1] for call in collected_calls)
+    cur_file = inspect.stack()[1].filename
+    summary_str = (
+        f"SUMMARY ({cur_file})\n"
+        f"{overall_time:.2f}ms   \t {overall_gb:.2f} GB\t {overall_gb / (overall_time / 1e3):.2f}GB/s"
+    )
+    log.info(
+        "%s",
+        summary_str,
+    )
+    if output_file is not None:
+        # sort perf numbers in descending order, i.e. placing the
+        # most runtime-heavy kernels at the top of the list
+        sorted_calls = sorted(collected_calls, key=lambda c: float(c[0]), reverse=True)
+        try:
+            with open(output_file, "a") as file:
+                log.info(
+                    "Save profile bandwidth results to %s",
+                    output_file,
+                )
+                file.write("====================\n")
+                file.write(f"TRITON KERNELS BANDWIDTH INFO ({cur_file})\n")
+                for ms, num_gb, gb_per_s, kernel_name in sorted_calls:
+                    # also display the runtime percentage for each kernel
+                    percentage = f"{ms / overall_time * 100:.2f}%"
+                    suffix = f" \t {percentage} \t {kernel_name}"
+                    bw_info_str = create_bandwidth_info_str(
+                        ms,
+                        num_gb,
+                        gb_per_s,
+                        suffix=suffix,
+                        color=False,
+                    )
+                    file.write(bw_info_str + "\n")
+                file.write(f"{summary_str}\n\n")
+        except Exception:
+            log.warning(
+                "failed to write profile bandwidth result into %s",
+                output_file,
+                exc_info=True,
+            )
+
+
+class DebugAutotuner(CachingAutotuner):
+    def __init__(
+        self,
+        *args,
+        regex_filter="",
+        with_profiler=False,
+        with_bandwidth_info=True,
+        **kwargs,
+    ):
+        self.regex_filter = regex_filter
+        self.with_profiler = with_profiler
+        self.with_bandwidth_info = with_bandwidth_info
+        super().__init__(*args, **kwargs)
+        self.cached = None
+
+    def run(self, *args, stream, **kwargs):
+        if not self.with_bandwidth_info:
+            super().run(*args, stream=stream, **kwargs, benchmark_run=True)
+            return
+        else:
+            possible_names = _find_names(self)
+            kernel_name = f"{max(possible_names, key=len)}"
+            if not re.match(self.regex_filter, kernel_name):
+                return
+            if len(self.launchers) != 1:
+                if len(self.launchers) == 0:
+                    start_time = time.time_ns()
+                    self.precompile()
+                    self.precompile_time_taken_ns = time.time_ns() - start_time
+                if len(self.launchers) > 1:
+                    self.autotune_to_one_config(*args, **kwargs)
+            (launcher,) = self.launchers
+
+            if launcher.store_cubin:
+                self.save_gpu_kernel(stream, launcher)
+
+            if self.cached is None:
+                ms = self.bench(launcher, *args, with_profiler=self.with_profiler)
+                num_in_out_ptrs = len(
+                    [
+                        arg_name
+                        for arg_name in self.fn.arg_names
+                        if arg_name.startswith("in_out_ptr")
+                    ]
+                )
+                num_gb = self.inductor_meta.get("kernel_num_gb", None)
+                if num_gb is None:
+                    num_gb = get_num_bytes(*args, num_in_out_args=num_in_out_ptrs) / 1e9
+                gb_per_s = num_gb / (ms / 1e3)
+                self.cached = ms, num_gb, gb_per_s, kernel_name
+                collected_calls.append((ms, num_gb, gb_per_s, kernel_name))
+                log.info(
+                    "%s",
+                    create_bandwidth_info_str(
+                        ms, num_gb, gb_per_s, suffix=f" \t {kernel_name}"
+                    ),
+                )
+            else:
+                # in AOTI, we will call the kernel and its timing info has been cached already
+                collected_calls.append(self.cached)
+
+
+def hash_configs(configs: list[Config]):
+    """
+    Hash used to check for changes in configurations
+    """
+    hasher = hashlib.sha256()
+    for cfg in configs:
+        hasher.update(
+            f"{sorted(cfg.kwargs.items())} {cfg.num_warps} {cfg.num_stages}\n".encode()
+        )
+    return hasher.hexdigest()
+
+
+def cached_autotune(
+    size_hints: list[int] | None,
+    configs: list[Config],
+    triton_meta,
+    heuristic_type,
+    filename=None,
+    inductor_meta=None,
+    custom_kernel=False,
+):
+    """
+    A copy of triton.autotune that calls our subclass.  Our subclass
+    has additional debugging, error handling, and on-disk caching.
+    """
+    configs = unique_configs(configs)
+    assert len(configs) == 1 or filename
+    inductor_meta = {} if inductor_meta is None else inductor_meta
+
+    configs, autotune_cache, autotune_cache_info = check_autotune_cache(
+        configs, filename, inductor_meta
+    )
+    mutated_arg_names = inductor_meta.pop("mutated_arg_names", ())
+    optimize_mem = inductor_meta.pop("optimize_mem", True)
+
+    if "restore_value" in triton_meta:
+        mutated_arg_names += triton_meta.pop("restore_value")
+
+    reset_to_zero_arg_names: list[str] = []
+    if "reset_to_zero" in triton_meta:
+        reset_to_zero_arg_names.extend(triton_meta.pop("reset_to_zero"))
+
+    def decorator(fn):
+        # Remove XBLOCK from config if it's not a function argument.
+        # This way, coordinate descent tuning will not try to tune it.
+        #
+        # Context: When TritonKernel.no_x_dim is True, we hardcode XBLOCK to 1.
+        import inspect
+
+        if "XBLOCK" not in inspect.signature(fn.fn).parameters:
+            for tconfig in configs:
+                if "XBLOCK" in tconfig.kwargs:
+                    assert tconfig.kwargs["XBLOCK"] == 1
+                    tconfig.kwargs.pop("XBLOCK")
+
+        if inductor_meta.get("profile_bandwidth"):
+            return DebugAutotuner(
+                fn,
+                triton_meta=triton_meta,
+                inductor_meta=inductor_meta,
+                regex_filter=inductor_meta["profile_bandwidth_regex"],
+                with_profiler=inductor_meta[
+                    "profile_bandwidth_with_do_bench_using_profiling"
+                ],
+                configs=configs,
+                save_cache_hook=autotune_cache and autotune_cache.save,
+                mutated_arg_names=mutated_arg_names,
+                reset_to_zero_arg_names=reset_to_zero_arg_names,
+                optimize_mem=optimize_mem,
+                heuristic_type=heuristic_type,
+                size_hints=size_hints,
+                custom_kernel=custom_kernel,
+                filename=filename,
+                with_bandwidth_info=True,
+            )
+        return CachingAutotuner(
+            fn,
+            triton_meta=triton_meta,
+            inductor_meta=inductor_meta,
+            configs=configs,
+            save_cache_hook=autotune_cache and autotune_cache.save,
+            mutated_arg_names=mutated_arg_names,
+            reset_to_zero_arg_names=reset_to_zero_arg_names,
+            optimize_mem=optimize_mem,
+            heuristic_type=heuristic_type,
+            size_hints=size_hints,
+            custom_kernel=custom_kernel,
+            filename=filename,
+            autotune_cache_info=autotune_cache_info,
+        )
+
+    return decorator
+
+
+def unique_configs(configs: list[Config]):
+    """Remove duplicate configurations"""
+    seen: OrderedSet[Hashable] = OrderedSet()
+    pruned_configs = []
+
+    for cfg in configs:
+        key = triton_config_to_hashable(cfg)
+        if key not in seen:
+            seen.add(key)
+            pruned_configs.append(cfg)
+    return pruned_configs
+
+
+def check_config(cfg, *, xnumel=None, ynumel=None, znumel=None):
+    for numel, label in zip((xnumel, ynumel, znumel), "XYZ"):
+        if numel is None:
+            continue
+        block = cfg[f"{label}BLOCK"]
+        if numel == 1:
+            assert block == 1, (
+                f"TritonKernel.indexing assumes numel == 1 => BLOCK == 1"
+                f" but {label.lower()}numel=={numel} and {label}BLOCK={block} (cfg={cfg})."
+            )
+        max_block = TRITON_MAX_BLOCK[label]
+        max_block_str = f'config.triton.max_block["{label}"]'
+        assert max_block % block == 0, (
+            f"TritonKernel.indexing assumes {label}BLOCK divides {max_block_str}"
+            f" but {label}BLOCK={block} and {max_block_str}={max_block} (cfg={cfg})."
+        )
+
+
+def check_max_block(cfg: dict[str, int]):
+    """
+    Check that block sizes are within the maximum allowed.
+    """
+    for var, val in cfg.items():
+        block_suffix = "BLOCK"
+        if block_suffix in var:
+            prefix = var.removesuffix(block_suffix)
+            max_block = TRITON_MAX_BLOCK[prefix]
+            assert val <= max_block, (
+                f"'{var}' too large. Maximum: {max_block}. Actual: {val}."
+            )
+
+
+def _num_warps(num_warps, max_num_warps=8, min_num_warps=2, register_intensive=False):
+    # On AMD GPU each warp has 64 lanes which is double the size on NV GPU,
+    # therefore using half the number of warps here correspondingly.
+    if torch.version.hip:
+        max_num_warps = (max_num_warps + 1) // 2
+        min_num_warps = (min_num_warps + 1) // 2
+    # persistent reduction is register intensive
+    if register_intensive:
+        max_num_warps = max_num_warps // 2
+    return next_power_of_2(min(max(num_warps, min_num_warps), max_num_warps))
+
+
+def _check_max_grid_x(size_hints, x, num_warps):
+    # Check if maxGridSize is exceeded - if so then must scale XBLOCK further
+    max_grid_x = 2147483647
+    warp_size = (
+        64 if torch.version.hip else 32
+    )  # TODO: query warp size once #129663 is merged
+    num_blocks = (size_hints["x"] + x - 1) // x
+
+    while (num_blocks * num_warps * warp_size) > max_grid_x and x < size_hints["x"]:
+        x *= 2  # Scale up XBLOCK if grid exceeds limits
+        num_blocks = num_blocks // 2
+    if (num_blocks * num_warps * warp_size) > max_grid_x:
+        raise AssertionError(
+            "Reduction config exceeds cudaDeviceProp maxGridSize. Please raise a pytorch issue"
+        )
+    return x, num_blocks
+
+
+def triton_config(
+    size_hints,
+    x,
+    y=None,
+    z=None,
+    num_stages=1,
+    num_elements_per_warp=256,
+    min_elem_per_thread=0,
+    num_warps=None,
+    matrix_instr=None,
+    waves_per_eu=None,
+) -> Config:
+    """
+    Construct a pointwise triton config with some adjustment heuristics
+    based on size_hints. Size_hints is a tuple of numels in each tile
+    dimension and will be rounded up to the nearest power of 2.
+
+    num_elements_per_warp is a suggestion for controlling how many warps
+    the triton config should contain. e.g.: if x=16, y=8, z=4 then
+    num_elements = 16*8*4 = 512. Then if we set num_elements_per_warp=128,
+    we'll launch 512 (elem) / 128 (elem/warp) = 4 warps. Note that it's
+    just a suggestion, and sometimes other adjustment heuristics will
+    override the num_elements_per_warp.
+
+    min_elem_per_thread controls the minimum number of elements
+    processed by each thread. It's always enforced.
+    """
+    # Ideally we want to read this from some device config
+
+    maxGridSize = [2147483647, 65535, 65535]
+
+    target = conditional_product(x, y, z)
+    if conditional_product(*size_hints.values()) < target:
+        target //= 8
+
+    # shrink sizes to size hints
+    x = min(x, size_hints["x"])
+    if y:
+        y = min(y, size_hints["y"])
+    if z:
+        z = min(z, size_hints["z"])
+
+    # if we are below original block size, scale up where we can;
+    # or if the calculated grid size is larger than the limit, we bump up the corresponding dimension
+    while x < min(size_hints["x"], TRITON_MAX_BLOCK["X"]) and (
+        x * maxGridSize[0] < size_hints["x"] or conditional_product(x, y, z) < target
+    ):
+        x *= 2
+    while (
+        y
+        and y < min(size_hints["y"], TRITON_MAX_BLOCK["Y"])
+        and (
+            y * maxGridSize[1] < size_hints["y"]
+            or conditional_product(x, y, z) < target
+        )
+    ):
+        y *= 2
+    while (
+        z
+        and z < min(size_hints["z"], TRITON_MAX_BLOCK["Z"])
+        and (
+            z * maxGridSize[2] < size_hints["z"]
+            or conditional_product(x, y, z) < target
+        )
+    ):
+        z *= 2
+
+    # Calculate num_warps if they are not hard passed to config
+    if num_warps is None:
+        num_warps = _num_warps(
+            conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
+        )
+    # we are going to arrive at 2 warps only if bs was too small due to
+    # numel being too small. However to workaround some ptx bugs we still
+    # want at least 4 warps if there's enough elements per thread
+    # given that this is a rare situation, don't expect this to affect perf
+    # in general
+    # see https://github.com/pytorch/pytorch/pull/97950
+    if conditional_product(x, y, z) >= 128 and not torch.version.hip:
+        num_warps = max(num_warps, 4)
+    xnumel = size_hints["x"]
+    ynumel = size_hints.get("y")
+    znumel = size_hints.get("z")
+
+    # Increase x to satisfy min_elem_per_thread requirements.
+    block_size = max(
+        conditional_product(x, y, z),
+        min_elem_per_thread * _NUM_THREADS_PER_WARP * num_warps,
+    )
+    x *= math.ceil(block_size / conditional_product(x, y, z))
+
+    x, _num_blocks = _check_max_grid_x(size_hints, x, num_warps)
+    x = min(x, size_hints["x"])
+
+    cfg = {"XBLOCK": x}
+    if y:
+        cfg["YBLOCK"] = y
+    if z:
+        cfg["ZBLOCK"] = z
+    check_max_block(cfg)
+    check_config(cfg, xnumel=xnumel, ynumel=ynumel, znumel=znumel)
+    config = Config(cfg, num_warps=num_warps, num_stages=num_stages)
+
+    if torch.version.hip:
+        if matrix_instr is not None:
+            config.kwargs["matrix_instr_nonkdim"] = matrix_instr
+        if waves_per_eu is not None:
+            config.kwargs["waves_per_eu"] = waves_per_eu
+
+    return config
+
+
+def _get_nd_reduction_numels(r: int, size_hints: dict[str, int]) -> dict[str, int]:
+    """
+    Converts a linear reduction numel to ND, in row major order.
+    This order is often desirable as it presents opportunities to coalesce memory
+    accesses.
+    For example, if r = 64 and size_hints = [32,32], this function returns [32, 2].
+    This unraveling works because both r and size_hints are powers of 2.
+    """
+    # Shrink r to size_hints.
+    r = min(r, get_total_reduction_numel(size_hints))
+    num_reduction_dims = len(
+        [prefix for prefix in size_hints if prefix_is_reduction(prefix)]
+    )
+
+    remaining = r
+    rnumels = {}
+    for idx in range(num_reduction_dims - 1, -1, -1):
+        prefix = f"r{idx}_"
+        max_size = min(size_hints[prefix], TRITON_MAX_BLOCK[prefix.upper()])
+        dim = min(max_size, remaining)
+        assert remaining % dim == 0, (
+            f"Expected dimension '{dim}' to divide remaining size '{remaining}'"
+        )
+        rnumels[prefix] = dim
+        remaining //= dim
+
+    # Sanity check the results.
+    final_numel = conditional_product(*rnumels.values())
+    assert r == final_numel, (
+        f"Expected ND reduction size ({rnumels}) to have {r} elements."
+    )
+    assert all(rnumels[prefix] <= size_hints[prefix] for prefix in rnumels), (
+        f"rnumels exceed size_hints. {rnumels} > {size_hints}"
+    )
+
+    return rnumels
+
+
+def triton_config_reduction(
+    size_hints,
+    x: int,
+    r: int,
+    num_stages=1,
+    num_warps=None,
+    register_intensive=False,
+    dynamic_scale_rblock=True,
+    reduction_hint=None,
+    min_num_warps=None,
+) -> Config:
+    """
+    Construct a reduction triton config with some adjustment heuristics
+    based on size_hints. Size_hints is a tuple of numels in each tile
+    dimension and will be rounded up to the nearest power of 2.
+    """
+    # Convert the linear reduction numel into a multi-dimensional block.
+    rnumels = _get_nd_reduction_numels(r, size_hints)
+
+    # shrink sizes to size hints
+    x = min(x, size_hints["x"])
+
+    def total_numel() -> int:
+        return conditional_product(x, *rnumels.values())
+
+    target = total_numel()
+    if conditional_product(*size_hints.values()) < target:
+        target //= 8
+
+    # if we are below original block size, scale up where we can
+    while x < size_hints["x"] and total_numel() < target:
+        x *= 2
+    for prefix in sorted(rnumels):
+        while rnumels[prefix] < size_hints[prefix] and total_numel() < target:
+            rnumels[prefix] *= 2
+
+    if num_warps is None:
+        if reduction_hint == ReductionHint.INNER:
+            # r is contiguous, ensure at least 8 elements per thread
+            # xblock is usually 1-2, default to giving each thread more work
+            num_warps = r // 128
+        else:
+            num_warps = total_numel() // 128
+
+    max_num_warps = 16 if r <= 8192 else 32
+    if min_num_warps is not None:
+        _num_warps_func = functools.partial(_num_warps, min_num_warps=min_num_warps)
+    else:
+        _num_warps_func = _num_warps
+
+    num_warps = _num_warps_func(
+        num_warps, max_num_warps=max_num_warps, register_intensive=register_intensive
+    )
+
+    x, _num_blocks = _check_max_grid_x(size_hints, x, num_warps)
+
+    for prefix in sorted(rnumels):
+        while total_numel() > target:
+            if rnumels[prefix] == 1:
+                break
+            rnumels[prefix] //= 2
+
+    cfg = _get_config({"x": x, **rnumels})
+    check_max_block(cfg)
+    check_config(cfg, xnumel=size_hints["x"])
+    return InductorConfig(
+        cfg,
+        num_warps=num_warps,
+        num_stages=num_stages,
+        dynamic_scale_rblock=dynamic_scale_rblock,
+    )
+
+
+def _get_config(numels: dict[str, int]) -> dict[str, int]:
+    """
+    Convert numels ("x", "r0_", etc.) to block sizes ("XBLOCK", "R0_BLOCK"), etc.
+    """
+
+    return {prefix.upper() + "BLOCK": numel for prefix, numel in numels.items()}
+
+
+def triton_config_tiled_reduction(
+    size_hints, x, y, r, num_stages=1, register_intensive=False
+):
+    """
+    Construct a tile reduction triton config with some adjustment
+    heuristics based on size_hints. Size_hints is a tuple of numels in
+    each tile dimension and will be rounded up to the nearest power of 2.
+    """
+    # Convert the linear reduction numel into a multi-dimensional block.
+    rnumels = _get_nd_reduction_numels(r, size_hints)
+
+    # shrink sizes to size hints
+    x = min(x, size_hints["x"])
+    y = min(y, size_hints["y"])
+
+    def total_numel() -> int:
+        return conditional_product(x, y, *rnumels.values())
+
+    target = total_numel()
+    if conditional_product(*size_hints.values()) < target:
+        target //= 8
+
+    # if we are below original block size, scale up where we can
+    while x < size_hints["x"] and total_numel() < target:
+        x *= 2
+    for prefix in sorted(rnumels):
+        while rnumels[prefix] < size_hints[prefix] and total_numel() < target:
+            rnumels[prefix] *= 2
+    while y < size_hints["y"] and total_numel() < target:
+        y *= 2
+
+    cfg = _get_config({"x": x, "y": y, **rnumels})
+    num_warps = _num_warps(total_numel() // 256, min_num_warps=1)
+    num_warps = _num_warps(
+        num_warps, max_num_warps=16, register_intensive=register_intensive
+    )
+    check_config(cfg, xnumel=size_hints["x"], ynumel=size_hints["y"])
+    check_max_block(cfg)
+    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
+
+
+def _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs: list[Config]):
+    tma_min_block_sizes: dict[str, int]
+    if (tma_min_block_sizes := inductor_meta.get("tma_min_block_sizes")) and configs:
+        # Rn blocks are not provided to the kernel for persistent reductions
+        if inductor_meta.get("persistent_reduction"):
+            tma_min_block_sizes = {
+                block_type: block_size
+                for block_type, block_size in tma_min_block_sizes.items()
+                if not prefix_is_reduction(block_type.lower())
+            }
+
+        assert all(
+            block_type in configs[0].kwargs for block_type in tma_min_block_sizes
+        )
+
+        # Add a config that is guaranteed to compile
+        example_config = configs[0]
+        config_block_sizes = {**example_config.kwargs}
+        config_block_sizes.update(tma_min_block_sizes)
+        new_configs = [
+            Config(
+                config_block_sizes,
+                num_warps=example_config.num_warps,
+                num_stages=example_config.num_stages,
+                maxnreg=example_config.maxnreg,
+                pre_hook=example_config.pre_hook,
+            )
+        ]
+        # Remove configs that will not compile
+        for c in configs:
+            if all(
+                c.kwargs.get(block_type) >= min_block_value
+                for block_type, min_block_value in tma_min_block_sizes.items()
+            ):
+                new_configs.append(c)
+
+        log.debug(
+            "Filtering configs for TMA API restrictions. Input configs size: %d. Output configs size: %d",
+            len(configs),
+            len(new_configs),
+        )
+        return new_configs
+    return configs
+
+
+def pointwise(
+    size_hints,
+    triton_meta,
+    tile_hint=None,
+    filename=None,
+    min_elem_per_thread=0,
+    inductor_meta=None,
+):
+    """
+    Construct @triton.heuristics() based on size_hints.
+    """
+    inductor_meta = {} if inductor_meta is None else inductor_meta
+    assert not inductor_meta.get("no_x_dim")
+
+    numel = functools.reduce(operator.mul, size_hints.values())
+    bs = max(256, min(numel // 128, 1024))
+
+    hinted_configs = autotune_hints_to_configs(
+        inductor_meta.get("autotune_hints", OrderedSet()),
+        size_hints,
+        bs,
+        triton_meta["device"],
+    )
+
+    triton_config_with_settings = functools.partial(
+        triton_config, min_elem_per_thread=min_elem_per_thread
+    )
+
+    configs = None
+    if len(size_hints) == 1:
+        if not inductor_meta.get("autotune_pointwise", True) and not (
+            inductor_meta.get("max_autotune")
+            or inductor_meta.get("max_autotune_pointwise")
+        ):
+            configs = [triton_config_with_settings(size_hints, bs)]
+        else:
+            configs = [
+                triton_config_with_settings(size_hints, bs, num_elements_per_warp=256),
+                triton_config_with_settings(
+                    size_hints, bs // 2, num_elements_per_warp=64
+                ),
+                *hinted_configs,
+            ]
+            # Additional configs appended for ROCm builds
+            if torch.version.hip:
+                configs.extend(
+                    [
+                        triton_config_with_settings(
+                            size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2
+                        ),
+                        triton_config_with_settings(
+                            size_hints,
+                            4096,  # wrt: better than the max_block for some kernel
+                        ),
+                        triton_config_with_settings(
+                            size_hints,
+                            2048,
+                            num_warps=8,
+                            num_stages=2,
+                            waves_per_eu=1,  # 20% improvement
+                        ),
+                    ]
+                )
+                if inductor_meta.get("atomic_add_found"):
+                    configs.extend(
+                        [
+                            triton_config_with_settings(
+                                size_hints,
+                                64,
+                                num_warps=1,
+                                num_stages=1,  # 250% improvement
+                            )
+                        ]
+                    )
+    if len(size_hints) == 2:
+        # Only avoiding tuning on TileHint.SQUARE if not on ROCm builds
+        # ROCm has observed improvement by diverging here
+        if (
+            not inductor_meta.get("autotune_pointwise", True)
+            or (torch.version.hip is None and tile_hint == TileHint.SQUARE)
+        ) and not (
+            inductor_meta.get("max_autotune")
+            or inductor_meta.get("max_autotune_pointwise")
+        ):
+            configs = [triton_config_with_settings(size_hints, 32, 32)]
+        else:
+            configs = [
+                triton_config_with_settings(size_hints, 32, 32),
+                triton_config_with_settings(size_hints, 64, 64),  # ~8% better for fp16
+                triton_config_with_settings(size_hints, 256, 16),
+                triton_config_with_settings(size_hints, 16, 256),
+                triton_config_with_settings(size_hints, bs, 1),
+                triton_config_with_settings(size_hints, 1, bs),
+                *hinted_configs,
+            ]
+            # Additional configs appended for ROCm builds
+            if torch.version.hip:
+                configs.extend(
+                    [
+                        triton_config_with_settings(
+                            size_hints, 64, 32
+                        ),  # better for some kernels
+                        triton_config_with_settings(
+                            size_hints, 128, 16
+                        ),  # +10% for some kernels
+                        triton_config_with_settings(
+                            size_hints, 128, 32
+                        ),  # additional 10% more
+                        triton_config_with_settings(
+                            size_hints, 32, 512
+                        ),  # +30% for some kernels
+                    ]
+                )
+    if len(size_hints) == 3:
+        if not inductor_meta.get("autotune_pointwise", True):
+            configs = [triton_config_with_settings(size_hints, 16, 16, 16)]
+        else:
+            configs = [
+                triton_config_with_settings(size_hints, 16, 16, 16),
+                triton_config_with_settings(size_hints, 64, 8, 8),
+                triton_config_with_settings(size_hints, 8, 64, 8),
+                triton_config_with_settings(size_hints, 8, 8, 64),
+                triton_config_with_settings(size_hints, bs, 1, 1),
+                triton_config_with_settings(size_hints, 1, bs, 1),
+                triton_config_with_settings(size_hints, 1, 1, bs),
+                *hinted_configs,
+            ]
+
+    if not configs:
+        raise NotImplementedError(f"size_hints: {size_hints}")
+
+    configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs)
+
+    return cached_autotune(
+        size_hints,
+        configs,
+        triton_meta=triton_meta,
+        inductor_meta=inductor_meta,
+        heuristic_type=HeuristicType.POINTWISE,
+        filename=filename,
+    )
+
+
+def make_matmul_triton_config(sizes: dict[str, int], num_warps: int, num_stages: int):
+    config = {
+        "XBLOCK": sizes.get("x"),
+        "YBLOCK": sizes.get("y"),
+        "ZBLOCK": sizes.get("z"),
+        "R0_BLOCK": sizes.get("r"),
+    }
+    # Remove keys with None values (i.e., missing in sizes)
+    config = {k: v for k, v in config.items() if v is not None}
+    return Config(config, num_warps=num_warps, num_stages=num_stages)
+
+
+def _config_helper(bmm=False, persistent=False):
+    # Each entry is: (sizes_dict, num_warps, num_stages)
+    _base_mm_configs = [
+        ({"x": 32, "y": 32, "r": 16}, 2, 1),
+        ({"x": 32, "y": 32, "r": 128}, 4, 2),
+        ({"x": 32, "y": 64, "r": 32}, 8, 5),
+        ({"x": 64, "y": 32, "r": 32}, 8, 5),
+        ({"x": 64, "y": 32, "r": 128}, 4, 5),
+        ({"x": 64, "y": 64, "r": 16}, 4, 2),
+        ({"x": 64, "y": 64, "r": 32}, 4, 2),
+        ({"x": 64, "y": 64, "r": 64}, 8, 3),
+        ({"x": 64, "y": 64, "r": 128}, 4, 5),
+        ({"x": 64, "y": 128, "r": 32}, 4, 3),
+        ({"x": 64, "y": 128, "r": 32}, 8, 4),
+        ({"x": 64, "y": 128, "r": 64}, 4, 3),
+        ({"x": 64, "y": 128, "r": 128}, 4, 4),
+        ({"x": 128, "y": 64, "r": 32}, 4, 3),
+        ({"x": 128, "y": 64, "r": 32}, 8, 4),
+        ({"x": 128, "y": 128, "r": 32}, 8, 2),
+        ({"x": 128, "y": 128, "r": 32}, 4, 3),
+        ({"x": 128, "y": 128, "r": 64}, 4, 3),
+        ({"x": 128, "y": 128, "r": 64}, 8, 5),
+    ]
+    out = []
+    for sizes, w, s in _base_mm_configs:
+        d = dict(sizes)
+        if persistent:
+            d.pop("r", None)
+        if bmm:
+            d["z"] = 1
+        out.append((d, w, s))
+
+    # Deduplicate by converting dicts to immutable frozensets
+    deduped = {(frozenset(d.items()), w, s): (d, w, s) for d, w, s in out}
+
+    return list(deduped.values())
+
+
+triton_native_mm_configs = _config_helper(bmm=False, persistent=False)
+triton_native_persistent_mm_configs = _config_helper(bmm=False, persistent=True)
+triton_native_bmm_configs = _config_helper(bmm=True, persistent=False)
+triton_native_persistent_bmm_configs = _config_helper(bmm=True, persistent=True)
+
+
+def _reduction_configs(
+    *,
+    size_hints: dict[str, int],
+    inductor_meta: dict[str, Any],
+    triton_meta: dict[str, Any],
+    num_dynamic=0,
+) -> list[Config]:
+    reduction_hint = inductor_meta.get("reduction_hint")
+
+    # Convert reductions to 1D, to simplify heuristics.
+    rnumel = get_total_reduction_numel(size_hints)
+
+    register_intensive = False
+    MAX_R0_BLOCK = 2048
+    loads_and_red = inductor_meta.get("num_load", 0) + inductor_meta.get(
+        "num_reduction", 0
+    )
+    if size_hints["x"] >= 1024 and loads_and_red >= 10:
+        # A heuristics to reduce R0_BLOCK if a kernel potentially need many registers.
+        # Consider load and reduction since load need move data into registers and
+        # reduction needs an accumulator.
+        #
+        # The magic numbers are a bit arbitrary.
+        #
+        # We cannot rely on dynamically scaling down R0_BLOCK later, since sometimes
+        # triton makes it to use less registers with worse perf. Check:
+        # https://github.com/pytorch/pytorch/issues/126463
+        #
+        # The heuristic is a very simple one since registers can be reused. But
+        # hopefully it can be a good enough indicator.
+        MAX_R0_BLOCK = 1024
+        register_intensive = True
+
+    if triton_meta.get("native_matmul"):
+        if len(size_hints) == 3:
+            return [
+                make_matmul_triton_config(sizes, num_warps, num_stages)
+                for sizes, num_warps, num_stages in triton_native_mm_configs
+            ]
+        elif len(size_hints) == 4:
+            return [
+                make_matmul_triton_config(sizes, num_warps, num_stages)
+                for sizes, num_warps, num_stages in triton_native_bmm_configs
+            ]
+        else:
+            raise NotImplementedError("native matmul only supports mm/bmm pattern")
+
+    def make_config(
+        x,
+        r,
+        num_warps=None,
+        num_stages=1,
+        register_intensive=False,
+        dynamic_scale_rblock=True,
+    ):
+        # For 3D case with tiling scores, create an adapted version
+        if "y" in size_hints:
+            assert "tiling_scores" in inductor_meta
+            return adapt_config_for_tiling(
+                size_hints,
+                inductor_meta["tiling_scores"],
+                x,
+                r,
+                num_warps=num_warps,
+                num_stages=num_stages,
+                register_intensive=register_intensive,
+            )
+        else:
+            # For other cases, use the original function
+            return triton_config_reduction(
+                size_hints,
+                x,
+                r,
+                num_warps=num_warps,
+                num_stages=num_stages,
+                register_intensive=register_intensive,
+                dynamic_scale_rblock=dynamic_scale_rblock,
+                reduction_hint=reduction_hint,
+            )
+
+    def outer_config_opt():
+        # Default to 64 for vectorized loads
+        max_x_block, x_block = 256, 64
+        load_factor = inductor_meta.get("num_load", 0)
+        x = size_hints["x"]
+        num_warps = None
+
+        # Try to use all SMs with small x
+        if x <= 1024:
+            x_block = max(min(x // 128, 8), 2)
+            outer_r_block = min(rnumel, 64)
+        # Lower bound x = 1024, 1024 // 16 = 128 around # of SMs
+        elif x // 4096 <= 8:
+            x_block = 16
+            outer_r_block = 512 // x_block
+        elif num_dynamic > 1:
+            # Lots of compute with multiple dynamic shape per loop iteration
+            # Larger RBLOCK minimizes loop iteration
+            outer_r_block = max(min((rnumel // 64), 64), 8)
+        elif num_dynamic == 1:
+            # Dynamic shapes introduce a lot register pressure for indexing
+            outer_r_block = (
+                1
+                if load_factor >= 3
+                else min(next_power_of_2(max(rnumel, 128) // 128), 8)
+            )
+        else:
+            x_block = max(min(max_x_block, next_power_of_2(x // 4096)), x_block)
+            if load_factor < 4 or rnumel <= 128:
+                outer_r_block = 512 // x_block
+            else:
+                # Heavier reductions contain a lot more overhead per loop iteration
+                # We minimize the overhead by enlarging r block
+                if rnumel >= 2048:
+                    outer_r_block = 64
+                else:
+                    outer_r_block = 32
+                x_block = min(x_block, 32)
+                num_warps = 4
+
+        # Set register intensive to true by default as we try to maximize tiles with heuristic
+        return make_config(
+            x_block,
+            outer_r_block,
+            num_warps=num_warps,
+            register_intensive=register_intensive,
+        )
+
+    contiguous_config = make_config(
+        2 if rnumel <= 2048 else 1,  # 1024 or less is persistent
+        min(rnumel, MAX_R0_BLOCK),
+        register_intensive=register_intensive,
+    )
+    tiny_config = make_config(
+        2 * (256 // rnumel) if rnumel <= 256 else 1,
+        min(rnumel, MAX_R0_BLOCK),
+        register_intensive=register_intensive,
+    )
+
+    outer_config = make_config(64, 8, register_intensive=register_intensive)
+    # TODO (paulzhan): Test heuristic on AMD and internal testing
+    # for correctness
+    if not torch.version.hip:
+        outer_config = outer_config_opt()
+
+    configs = []
+
+    if inductor_meta.get("add_persistent_rblock") and loads_and_red <= 8:
+        xnumel = max(4096 // rnumel, 1)
+        c = make_config(
+            xnumel,
+            min(rnumel, 32768),
+            register_intensive=register_intensive,
+            dynamic_scale_rblock=False,
+        )
+        configs.append(c)
+
+    # For 3d tiling, default to more autotuning initially
+    if "y" in size_hints:
+        pass
+    elif inductor_meta.get("max_autotune") or inductor_meta.get(
+        "max_autotune_pointwise"
+    ):
+        pass  # skip all these cases
+    elif reduction_hint == ReductionHint.INNER:
+        return configs + [contiguous_config]
+    elif reduction_hint == ReductionHint.OUTER:
+        return configs + [outer_config]
+    elif reduction_hint == ReductionHint.OUTER_TINY:
+        return configs + [tiny_config]
+
+    return configs + [
+        contiguous_config,
+        outer_config,
+        tiny_config,
+        make_config(64, 64),
+        make_config(8, 512),
+        # halve the XBLOCK/Rn_BLOCK compared to outer_config
+        # TODO: this may only be beneficial when each iteration of the reduction
+        # is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72
+        make_config(64, 4, num_warps=8),
+    ]
+
+
+def match_target_block_product(
+    size_hints, tiling_scores, target_block_product, min_block_size=1
+):
+    """
+    Distribute block sizes across dimensions according to tiling scores,
+    aiming to match a target product of block sizes.
+    """
+    total_score = sum(tiling_scores.values())
+    if total_score == 0:
+        # just assume even score with no minimum block size
+        min_block_size = 1
+        tiling_scores = dict.fromkeys(tiling_scores.keys(), target_block_product)
+
+    # First, give each coalescing dimension at least min_block_size
+    block_sizes = {}
+    relative_scores = {}
+    curr_block_product = 1
+
+    for dim, score in tiling_scores.items():
+        if score == 0:
+            block_sizes[dim] = 1
+            continue
+
+        block_sizes[dim] = min_block_size
+        curr_block_product *= min_block_size
+        relative_scores[dim] = score / total_score
+
+    # Scale up dimensions by their relative scores until we reach the target
+    while curr_block_product < target_block_product and relative_scores:
+        dim, score = max(relative_scores.items(), key=lambda item: item[1])
+
+        # Check if we've hit the max for this dimension
+        if (
+            block_sizes[dim] >= TRITON_MAX_BLOCK[dim.capitalize()]
+            or block_sizes[dim] >= size_hints[dim]
+        ):
+            del relative_scores[dim]
+            continue
+
+        block_sizes[dim] *= 2
+        relative_scores[dim] /= 2
+        curr_block_product *= 2
+
+    return block_sizes
+
+
+def adapt_config_for_tiling(
+    size_hints,
+    tiling_scores,
+    original_x,
+    original_r,
+    num_warps=None,
+    num_stages=1,
+    register_intensive=False,
+    persistent_reduction=False,
+) -> Config:
+    """
+    Create an adapted configuration based on tiling scores,
+    redistributing the same total block size (x * r) according to tiling scores.
+    """
+    assert all(s in tiling_scores for s in size_hints)
+    target_block_product = original_x * original_r
+    block_sizes = match_target_block_product(
+        size_hints, tiling_scores, target_block_product
+    )
+
+    return triton_config_tiled_reduction(
+        size_hints,
+        block_sizes["x"],
+        block_sizes["y"],
+        block_sizes["r0_"],
+        num_stages=num_stages,
+        register_intensive=register_intensive,
+    )
+
+
+def filter_reduction_configs_for_determinism(
+    inductor_meta: dict[str, Any], configs: list[Config]
+) -> list[Config]:
+    """
+    Filter configs for reduction so the numerics can be deterministic.
+
+    Heuristics:
+    - skip reduction configs with too small RBLOCK
+    - skip reduction configs with XBLOCK==1 if we are confident it will not perform well
+    - if there is a tie, pick the config with second largest RBLOCK
+    - if there is still a tie, pick the config with second largest num_warps
+    - if there is still a tie, pick the config with second largest XBLOCK
+    """
+    configs = unique_configs(configs)
+    assert len(configs) > 0
+
+    def _do_filter_due_to_inductor_config():
+        return (
+            inductor_meta.get("deterministic", False)
+            or inductor_meta.get("force_filter_reduction_configs", False)
+        ) or inductor_meta.get("are_deterministic_algorithms_enabled")
+
+    if not _do_filter_due_to_inductor_config() or len(configs) == 1:
+        # no filtering happening if NOT in deterministic mode
+        return configs
+
+    if log.isEnabledFor(logging.DEBUG):
+        log.debug("reduction configs before filtering:")
+        for c in configs:
+            log.debug("%s", c)
+            log.debug("")
+
+    def _has_too_small_rblock(config):
+        rblock = config.kwargs.get("R0_BLOCK")
+        # too small RBLOCK is likely to be bad
+        return rblock is not None and rblock <= 4
+
+    def _nonpromising_xblock_1(config):
+        # kernel like https://gist.github.com/shunting314/0b3281c087e79bc915fe45985ff9d7d5
+        # without a load/store having contiguous rdim is unlikely to perform well with XBLOCK==1
+        return config.kwargs["XBLOCK"] == 1 and not inductor_meta.get(
+            "has_loadstore_with_contiguous_rdim", True
+        )
+
+    newconfigs = [*filter(lambda x: not _has_too_small_rblock(x), configs)]
+    # accept the filtering only if there are configs left
+    if len(newconfigs) > 0:
+        configs = newconfigs
+
+    newconfigs = [*filter(lambda x: not _nonpromising_xblock_1(x), configs)]
+    if len(newconfigs) > 0:
+        configs = newconfigs
+
+    assert len(configs) > 0
+
+    def _r0_block(c):
+        return c.kwargs.get("R0_BLOCK", -1)
+
+    def _xblock(c):
+        return c.kwargs.get("XBLOCK", -1)
+
+    def _num_warps(c):
+        return c.num_warps
+
+    def _pick_second_largest(accessor):
+        nonlocal configs
+        configs = sorted(configs, key=lambda x: accessor(x))
+        if accessor(configs[0]) != accessor(configs[-1]):
+            max_val = accessor(configs[-1])
+            configs = [*filter(lambda x: accessor(x) != max_val, configs)]
+            second_max_val = accessor(configs[-1])
+            configs = [*filter(lambda x: accessor(x) == second_max_val, configs)]
+        return configs
+
+    def _pick_config():
+        nonlocal configs
+        assert len(configs) > 0
+        if len(configs) == 1:
+            return configs[0]
+
+        # break tie by R0_BLOCK
+        configs = _pick_second_largest(_r0_block)
+        if len(configs) == 1:
+            return configs[0]
+
+        # break tie by num_warps
+        configs = _pick_second_largest(_num_warps)
+        if len(configs) == 1:
+            return configs[0]
+
+        # break tie by XBLOCK
+        configs = _pick_second_largest(_xblock)
+
+        # there is still a tie, pick the first one
+        return configs[0]
+
+    configs = [_pick_config()]
+
+    if log.isEnabledFor(logging.DEBUG):
+        log.debug("reduction configs after filtering:")
+        for c in configs:
+            log.debug("%s", c)
+            log.debug("")
+    return configs
+
+
+def reduction(
+    size_hints,
+    reduction_hint=False,
+    triton_meta=None,
+    filename=None,
+    inductor_meta=None,
+):
+    """args to @triton.heuristics()"""
+    inductor_meta = {} if inductor_meta is None else inductor_meta
+    inductor_meta["reduction_hint"] = reduction_hint
+    if inductor_meta.get("no_x_dim"):
+        size_hints["x"] = 1
+
+    assert triton_meta is not None
+
+    num_dynamic = 0
+    for k in triton_meta["signature"]:
+        if "ks" in k:
+            num_dynamic += 1
+
+    configs = _reduction_configs(
+        size_hints=size_hints,
+        inductor_meta=inductor_meta,
+        triton_meta=triton_meta,
+        num_dynamic=num_dynamic,
+    )
+
+    configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs)
+    configs = filter_reduction_configs_for_determinism(inductor_meta, configs)
+
+    return cached_autotune(
+        size_hints,
+        configs=configs,
+        triton_meta=triton_meta,
+        inductor_meta=inductor_meta,
+        heuristic_type=HeuristicType.REDUCTION,
+        filename=filename,
+    )
+
+
+def cooperative_reduction(
+    size_hints,
+    reduction_hint,
+    triton_meta,
+    filename,
+    inductor_meta,
+):
+    inductor_meta = {} if inductor_meta is None else inductor_meta
+    inductor_meta["reduction_hint"] = reduction_hint
+    if inductor_meta.get("no_x_dim"):
+        size_hints["x"] = 1
+
+    # Cooperative reductions currently only support a single reduction dimension.
+    assert len(size_hints) == 2, (
+        "Cooperative reductions don't support tiling reduction dims"
+    )
+    xnumel, rnumel = size_hints["x"], size_hints["r0_"]
+
+    # TODO(jansel): we should base target on the SM count of the local GPU
+    target = 64
+    split = max(1, min(target // xnumel, TRITON_MAX_RSPLIT))
+    assert rnumel >= split
+    assert split <= TRITON_MAX_RSPLIT
+    if inductor_meta["persistent_reduction"]:
+        configs = _persistent_reduction_configs(
+            {"x": xnumel, "r0_": rnumel // split},
+            reduction_hint,
+            inductor_meta,
+            triton_meta,
+        )
+    else:
+        configs = _reduction_configs(
+            size_hints={"x": xnumel, "r0_": rnumel // split},
+            inductor_meta=inductor_meta,
+            triton_meta=triton_meta,
+        )
+    for config in configs:
+        config.kwargs["RSPLIT"] = split
+    # TODO(jansel): add more configs in max_autotune
+
+    configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs)
+    configs = filter_reduction_configs_for_determinism(inductor_meta, configs)
+    return cached_autotune(
+        size_hints,
+        configs=configs,
+        triton_meta=triton_meta,
+        inductor_meta=inductor_meta,
+        heuristic_type=HeuristicType.REDUCTION,
+        filename=filename,
+    )
+
+
+def _persistent_reduction_configs(
+    size_hints,
+    reduction_hint=False,
+    inductor_meta=None,
+    triton_meta=None,
+):
+    xnumel = size_hints["x"]
+    rnumel = get_total_reduction_numel(size_hints)
+
+    MAX_PERSISTENT_BLOCK_NUMEL = 4096
+
+    if triton_meta.get("native_matmul"):
+        if len(size_hints) == 3:
+            return [
+                make_matmul_triton_config(sizes, num_warps, num_stages)
+                for sizes, num_warps, num_stages in triton_native_persistent_mm_configs
+            ]
+        elif len(size_hints) == 4:
+            return [
+                make_matmul_triton_config(sizes, num_warps, num_stages)
+                for sizes, num_warps, num_stages in triton_native_persistent_bmm_configs
+            ]
+        else:
+            raise NotImplementedError("native matmul only supports mm/bmm pattern")
+
+    max_autotune_enabled = inductor_meta.get("max_autotune") or inductor_meta.get(
+        "max_autotune_pointwise"
+    )
+
+    if torch.version.hip:
+        xblock_vals = [1, 4, 8, 16, 32, 64, 128, 256]
+    else:
+        xblock_vals = [1, 8, 32, 128]
+
+    if "y" not in size_hints:
+        configs = [
+            triton_config_reduction(
+                size_hints,
+                xblock,
+                rnumel,
+                register_intensive=True,
+                reduction_hint=reduction_hint,
+            )
+            for xblock in xblock_vals
+            if xblock == 1
+            or (rnumel * xblock <= MAX_PERSISTENT_BLOCK_NUMEL and xblock <= xnumel)
+        ]
+    else:
+        configs = []
+        assert "tiling_scores" in inductor_meta
+        x_y_scores = {dim: inductor_meta["tiling_scores"][dim] for dim in ("x", "y")}
+        for target_block_size in xblock_vals:
+            if target_block_size * rnumel > MAX_PERSISTENT_BLOCK_NUMEL:
+                continue
+
+            block_sizes = match_target_block_product(
+                size_hints, x_y_scores, target_block_size
+            )
+            configs.append(
+                triton_config_tiled_reduction(
+                    size_hints, block_sizes["x"], block_sizes["y"], rnumel
+                )
+            )
+
+    tiny_configs = [
+        triton_config_reduction(
+            size_hints,
+            2 * (256 // rnumel) if rnumel <= 256 else 1,
+            rnumel,
+        )
+    ]
+
+    # defer to more autotuning, initially
+    if "y" in size_hints:
+        pass
+    # TODO(jansel): we should be able to improve these heuristics
+    elif not max_autotune_enabled:  # Do not filter configs when tuning
+        if reduction_hint == ReductionHint.INNER and rnumel >= 256:
+            if rnumel > 1024 or xnumel // 8 < 128 or inductor_meta.get("RSPLIT_SIZE"):
+                configs = configs[:1]
+            else:
+                if not torch.cuda.is_available():
+                    # TODO(Intel): CUDA uses num_warps = 1 to disable shared memory.
+                    # We apply different configurations from #168335.
+                    # We currently let cost model in Triton to decide whether to use shared memory.
+                    loads_and_stores = inductor_meta.get(
+                        "num_load", 0
+                    ) + inductor_meta.get("num_store", 0)
+                    x_block = 8
+                    if xnumel // x_block < 128 or loads_and_stores >= 5:
+                        x_block = 1
+                    num_warps, min_num_warps, reduction_hint = None, None, None
+                else:
+                    x_block = min(1024 // rnumel, 8)
+                    num_warps, min_num_warps = 1, 1
+                configs = [
+                    triton_config_reduction(
+                        size_hints,
+                        x_block,
+                        rnumel,
+                        register_intensive=True,
+                        num_warps=num_warps,
+                        min_num_warps=min_num_warps,
+                        reduction_hint=reduction_hint,
+                    )
+                ]
+
+        elif reduction_hint == ReductionHint.OUTER:
+            configs = configs[-1:]
+        elif reduction_hint == ReductionHint.OUTER_TINY:
+            configs = tiny_configs
+    else:
+        if torch.version.hip:
+            # If autotune is enabled append tiny configs
+            for conf in tiny_configs:
+                if conf not in configs:
+                    configs.append(conf)
+
+    for c in configs:
+        # we don't need Rn_BLOCK for persistent reduction
+        for prefix in size_hints:
+            if prefix_is_reduction(prefix):
+                c.kwargs.pop(f"{prefix.upper()}BLOCK")
+
+    return configs
+
+
+def persistent_reduction(
+    size_hints,
+    reduction_hint=False,
+    triton_meta=None,
+    filename=None,
+    inductor_meta=None,
+):
+    inductor_meta = {} if inductor_meta is None else inductor_meta
+    inductor_meta["reduction_hint"] = reduction_hint
+    if inductor_meta.get("no_x_dim"):
+        size_hints["x"] = 1
+
+    configs = _persistent_reduction_configs(
+        size_hints, reduction_hint, inductor_meta, triton_meta
+    )
+
+    # This key is not added to the inductor meta as its clear from the heuristic
+    # choice that it is persistent. Add it and remove it below so that persistent
+    # configs can be filtered appropriately by _maybe_filter_configs_for_tma_restrictions
+    persistent_reduction_key = "persistent_reduction"
+    inductor_meta[persistent_reduction_key] = True
+    configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs)
+    inductor_meta.pop(persistent_reduction_key)
+
+    if inductor_meta.get("RSPLIT_SIZE"):
+        new_configs = []
+        rsplit_size = inductor_meta.get("RSPLIT_SIZE")
+        rnumel_hint = size_hints["r0_"]
+        min_x_block = 1
+        if rnumel_hint <= 512:
+            min_x_block = 4
+        x_block = min(max(rsplit_size // 32, min_x_block), 16)
+        for c in configs:
+            c.kwargs["RSPLIT_SIZE"] = rsplit_size
+            # small XBLOCK to use less registers/smem
+            c.kwargs["XBLOCK"] = x_block
+
+            num_iters = rsplit_size // x_block
+            c.kwargs["NUM_STAGES"] = min(max(num_iters // 4, 1), 3)
+
+            if rnumel_hint <= 1024:
+                c.num_warps //= 2
+                c.num_warps = max(c.num_warps, 1)
+                new_configs.append(c)
+
+                # less warps so potentially each sm can run more thread blocks
+                # Inside each thread block, we handle the split sequentially,
+                # more thread blocks is beneficial here.
+                newc = copy.deepcopy(c)
+                newc.num_warps = 2
+                new_configs.append(newc)
+            else:
+                # more warps for larger rows
+                new_configs.append(c)
+
+                if c.num_warps < 32:
+                    newc = copy.deepcopy(c)
+                    newc.num_warps *= 2
+                    new_configs.append(newc)
+
+        configs = unique_configs(new_configs)
+
+    configs = filter_reduction_configs_for_determinism(inductor_meta, configs)
+    return cached_autotune(
+        size_hints,
+        configs,
+        triton_meta=triton_meta,
+        inductor_meta=inductor_meta,
+        filename=filename,
+        heuristic_type=HeuristicType.PERSISTENT_REDUCTION,
+    )
+
+
+def split_scan(
+    size_hints,
+    reduction_hint=False,
+    triton_meta=None,
+    filename=None,
+    inductor_meta=None,
+):
+    """Heuristic for TritonSplitScanKernel"""
+    inductor_meta = {} if inductor_meta is None else inductor_meta
+    inductor_meta["reduction_hint"] = reduction_hint
+    if inductor_meta.get("no_x_dim"):
+        size_hints["x"] = 1
+
+    assert triton_meta is not None
+    if len(size_hints) != 2:
+        raise NotImplementedError(f"size_hints: {size_hints}")
+
+    configs = _reduction_configs(
+        size_hints=size_hints, inductor_meta=inductor_meta, triton_meta=triton_meta
+    )
+
+    # Fixup configs to enforce the minimum Rn_BLOCK size
+    min_rblock = inductor_meta.get("min_split_scan_rblock", 256)
+    for cfg in configs:
+        for var in list(cfg.kwargs.keys()):
+            if var.startswith("R") and cfg.kwargs[var] < min_rblock:
+                cfg.kwargs[var] = min_rblock
+
+    configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs)
+    configs = filter_reduction_configs_for_determinism(inductor_meta, configs)
+    return cached_autotune(
+        size_hints,
+        configs=configs,
+        triton_meta=triton_meta,
+        inductor_meta=inductor_meta,
+        heuristic_type=HeuristicType.SPLIT_SCAN,
+        filename=filename,
+    )
+
+
+def template(
+    num_stages,
+    num_warps,
+    triton_meta,
+    num_consumer_groups=0,
+    num_buffers_warp_spec=0,
+    filename=None,
+    inductor_meta=None,
+):
+    """
+    Compile a triton template
+    """
+    # Prepare the base configuration
+    config_args = {
+        "num_stages": num_stages,
+        "num_warps": num_warps,
+    }
+
+    # Conditionally add arguments based on HAS_WARP_SPEC
+    if HAS_WARP_SPEC:
+        config_args.update(
+            {
+                "num_consumer_groups": num_consumer_groups,
+                "num_buffers_warp_spec": num_buffers_warp_spec,
+            }
+        )
+    return cached_autotune(
+        None,
+        [triton.Config({}, **config_args)],
+        triton_meta=triton_meta,
+        inductor_meta=inductor_meta,
+        heuristic_type=HeuristicType.TEMPLATE,
+        filename=filename,
+    )
+
+
+def _pop_config_kwargs(config: dict[str, Any]) -> dict[str, Any]:
+    """Extract triton.Config options that should become kwargs"""
+    popped = {}
+    for key in (
+        "num_warps",
+        "num_stages",
+        "num_ctas",
+        "maxnreg",
+        "num_consumer_groups",
+        "num_buffers_warp_spec",
+    ):
+        val = config.pop(key, None)
+        if val is not None:
+            popped[key] = val
+    return popped
+
+
+def config_to_dict(config: Config) -> dict[str, Any]:
+    config_dict = {
+        **config.kwargs,
+        "num_warps": config.num_warps,
+        "num_stages": config.num_stages,
+    }
+    if HAS_WARP_SPEC:
+        config_dict.update(
+            {
+                "num_consumer_groups": getattr(config, "num_consumer_groups", 0),
+                "num_buffers_warp_spec": getattr(config, "num_buffers_warp_spec", 0),
+            }
+        )
+    return config_dict
+
+
+def config_from_dict(config: dict[str, Any]) -> Config:
+    config = {**config}
+    return Config(config, **_pop_config_kwargs(config))
+
+
+def fixed_config(config, filename, triton_meta, inductor_meta):
+    """
+    Used when the configuration is already decided at compile time
+    """
+    config = {**config}
+    return cached_autotune(
+        None,
+        [triton.Config(config, **_pop_config_kwargs(config))],
+        triton_meta=triton_meta,
+        inductor_meta=inductor_meta,
+        heuristic_type=HeuristicType.FIXED,
+        filename=filename,
+    )
+
+
+def user_autotune(
+    configs, triton_meta, filename=None, inductor_meta=None, custom_kernel=False
+):
+    """
+    Compile a user defined triton kernel
+    """
+    if len(configs) == 0:
+        configs = [triton.Config({})]
+    else:
+        configs = [*map(config_from_dict, configs)]
+    return cached_autotune(
+        None,
+        configs,
+        triton_meta=triton_meta,
+        heuristic_type=HeuristicType.USER_AUTOTUNE,
+        filename=filename,
+        inductor_meta=inductor_meta,
+        custom_kernel=custom_kernel,
+    )
+
+
+def foreach(triton_meta, filename=None, inductor_meta=None):
+    """
+    Compile a triton foreach kernel
+    """
+    configs = []
+
+    # Naive autotuning path for num_warps
+    if not (
+        inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise")
+    ):
+        configs.append(triton.Config({}, num_stages=1, num_warps=8))
+    else:
+        for warps in [1, 2, 4, 8]:
+            configs.append(triton.Config({}, num_stages=1, num_warps=warps))
+
+    return cached_autotune(
+        None,
+        configs,
+        triton_meta=triton_meta,
+        inductor_meta=inductor_meta,
+        heuristic_type=HeuristicType.TEMPLATE,
+        filename=filename,
+    )
+
+
+@dataclasses.dataclass
+class GridExpr:
+    """Generate code for grid size expressions in launcher"""
+
+    inductor_meta: dict[str, Any]
+    mode: Literal["python", "cpp"] = "python"
+    prefix: list[str] = dataclasses.field(default_factory=list)
+    x_grid: str | int = 1
+    y_grid: str | int = 1
+    z_grid: str | int = 1
+
+    def __post_init__(self) -> None:
+        assert self.mode in ("python", "cpp")
+
+    def generate(self, meta: dict[str, int]) -> None:
+        raise NotImplementedError
+
+    def ceildiv(self, numel: str | int, block: None | int | str) -> str | int:
+        if block is None or block == 1:
+            return numel
+        if isinstance(numel, int) and isinstance(block, int):
+            return ceildiv(numel, block)  # constant fold
+        # This trick only works in python, where
+        # negative integer division is floored
+        if self.mode == "python":
+            return f"-(({numel}) // -({block}))"
+        # For cpp code gen
+        return f"(({numel} + ({block} - 1)) / ({block}))"
+
+    def maximum(self, seq: list[int | str]) -> int | str:
+        """Codegen for max function with constant folding, constants are represented as int"""
+        items = self._constant_fold(max, seq)
+        if len(items) <= 1:
+            return items[0]
+        if self.mode == "python":
+            return f"max({', '.join(map(str, items))})"
+        return functools.reduce(lambda x, y: f"std::max({x}, {y})", items)
+
+    def summation(self, seq: list[int | str]) -> int | str:
+        """Codegen for sum function with constant folding, constants are represented as int"""
+        items = self._constant_fold(sum, seq)
+        if len(items) <= 1:
+            return items[0]
+        return " + ".join(map(str, items))
+
+    def _constant_fold(
+        self, fn: Callable[[list[int]], int], seq: list[int | str]
+    ) -> list[int | str]:
+        """Constant fold through a commutative fn where ints are constants"""
+        items: list[int | str] = [x for x in seq if not isinstance(x, int)]
+        const_items = [x for x in seq if isinstance(x, int)]
+        if const_items:
+            items.append(fn(const_items))
+        return items
+
+    def assign_tmp(self, name: str, expr: str | int) -> str:
+        # Grid functions are one per kernel, so name collisions are fine
+        if self.mode == "python":
+            return f"{name} = {expr}"
+        if self.mode == "cpp":
+            return f"uint32_t {name} = {expr};"
+        raise AssertionError(f"invalid mode {self.mode}")
+
+    @staticmethod
+    def from_meta(
+        inductor_meta: dict[str, Any],
+        cfg: Config | dict[str, int],
+        mode: Literal["python", "cpp"] = "python",
+    ) -> GridExpr:
+        grid_cls = globals()[inductor_meta["grid_type"]]
+        assert issubclass(grid_cls, GridExpr)
+        grid = grid_cls(inductor_meta=inductor_meta, mode=mode)
+        if isinstance(cfg, Config):
+            cfg = config_to_dict(cfg)
+        grid.generate(cfg)
+        return grid
+
+    def eval_slow(self, meta: dict[str, int]) -> tuple[int, int, int]:
+        scope = {**meta}
+        for line in self.prefix:
+            exec(line, scope)
+        exec(f"grid_0 = {self.x_grid}", scope)
+        exec(f"grid_1 = {self.y_grid}", scope)
+        exec(f"grid_2 = {self.z_grid}", scope)
+        return scope["grid_0"], scope["grid_1"], scope["grid_2"]
+
+
+class Grid1D(GridExpr):
+    def generate(self, meta: dict[str, int]) -> None:
+        self.x_grid = self.ceildiv("xnumel", meta.get("XBLOCK"))
+
+
+class Grid2D(GridExpr):
+    def generate(self, meta: dict[str, int]) -> None:
+        self.x_grid = self.ceildiv("xnumel", meta.get("XBLOCK"))
+        self.y_grid = self.ceildiv("ynumel", meta.get("YBLOCK"))
+
+
+class Grid3D(GridExpr):
+    def generate(self, meta: dict[str, int]) -> None:
+        self.x_grid = self.ceildiv("xnumel", meta.get("XBLOCK"))
+        self.y_grid = self.ceildiv("ynumel", meta.get("YBLOCK"))
+        self.z_grid = self.ceildiv("znumel", meta.get("ZBLOCK"))
+
+
+class Grid2DWithYZOverflow(GridExpr):
+    def generate(self, meta: dict[str, int]) -> None:
+        self.x_grid = self.ceildiv("xnumel", meta.get("XBLOCK"))
+        self.prefix.extend(
+            [
+                self.assign_tmp(
+                    "y_grid_raw_", self.ceildiv("ynumel", meta.get("YBLOCK"))
+                ),
+                self.assign_tmp(
+                    "y_grid_div_", self.ceildiv("y_grid_raw_", get_max_y_grid())
+                ),
+            ]
+        )
+        self.y_grid = self.ceildiv("y_grid_raw_", "y_grid_div_")
+        self.z_grid = "y_grid_div_"
+
+
+class MixOrderReductionGrid(GridExpr):
+    def generate(self, meta: dict[str, int]) -> None:
+        split_size = meta.get("RSPLIT_SIZE")
+        xblock = meta.get("XBLOCK")
+        assert split_size, "Missing RSPLIT_SIZE"
+        assert xblock, "Missing XBLOCK"
+        assert split_size % xblock == 0, f"{split_size=}, {xblock=}"
+        self.x_grid = self.ceildiv("xnumel", split_size)
+
+
+class CooperativeReductionGrid(GridExpr):
+    def generate(self, meta: dict[str, int]) -> None:
+        self.x_grid = str(meta["RSPLIT"])
+        self.y_grid = self.ceildiv("xnumel", meta.get("XBLOCK"))
+
+
+class SplitScanGrid(GridExpr):
+    def generate(self, meta: dict[str, int]) -> None:
+        assert meta.get("XBLOCK", 1) == 1
+        self.x_grid = self.ceildiv("r0_numel", meta.get("R0_BLOCK"))
+        self.y_grid = "xnumel"
+
+
+class FixedGrid(GridExpr):
+    @staticmethod
+    def setup_grid_as_args() -> dict[str, Any]:
+        """Inductor meta so the launcher takes three extra grid arguments"""
+        return {
+            "grid_type": FixedGrid.__name__,
+            "fixed_grid": ["_grid_0", "_grid_1", "_grid_2"],
+            "extra_launcher_args": ["_grid_0", "_grid_1", "_grid_2"],
+        }
+
+    def generate(self, meta: dict[str, int]) -> None:
+        self.x_grid, self.y_grid, self.z_grid = self.inductor_meta["fixed_grid"]
+
+
+class PrecomputedGrid(GridExpr):
+    def generate(self, meta: dict[str, int]) -> None:
+        for candidate in self.inductor_meta["precomputed_grids"]:
+            if all(meta.get(k) == v for k, v in candidate["config"].items()):
+                self.x_grid, self.y_grid, self.z_grid = candidate[self.mode]
+                return
+        raise AssertionError(
+            f"Precomputed grid not found for {meta} in {self.inductor_meta['precomputed_grids']}"
+        )
+
+
+class ComboKernelGrid(GridExpr):
+    def generate(self, meta: dict[str, int]):
+        combo_meta = self.inductor_meta["combo_grid_meta"]
+        if combo_meta["default_config"]:
+            meta = {**combo_meta["default_config"], **meta}
+        no_x_dims = []
+        xnumels = []
+        ynumels = []
+        znumels = []
+        for num in range(combo_meta["num_kernels"]):
+            assert (
+                combo_meta[f"xnumel_{num}"] is None or combo_meta[f"xnumel_{num}"] > 0
+            )
+            no_x_dims.append(combo_meta[f"no_x_dim_{num}"])
+            xnumels.append(combo_meta[f"xnumel_{num}"] or f"xnumel_{num}")
+            if f"ynumel_{num}" in combo_meta:
+                ynumels.append(combo_meta[f"ynumel_{num}"] or f"ynumel_{num}")
+            if f"znumel_{num}" in combo_meta:
+                znumels.append(combo_meta[f"znumel_{num}"] or f"znumel_{num}")
+
+        self.x_grid = self.combo_x_grid(xnumels, no_x_dims, meta)
+        if combo_meta["min_blocks"]:
+            self.x_grid = self.maximum([self.x_grid, combo_meta["min_blocks"]])
+        if ynumels:
+            self.y_grid = self.ceildiv(self.maximum(ynumels), meta.get("YBLOCK"))
+        if znumels:
+            self.z_grid = self.ceildiv(self.maximum(znumels), meta.get("ZBLOCK"))
+
+    def combo_x_grid(
+        self,
+        xnumels: list[int | str],
+        no_x_dims: list[bool],
+        meta: dict[str, int],
+    ) -> str | int:
+        raise NotImplementedError
+
+
+class SequentialComboKernelGrid(ComboKernelGrid):
+    def combo_x_grid(
+        self,
+        xnumels: list[int | str],
+        no_x_dims: list[bool],
+        meta: dict[str, int],
+    ) -> str | int:
+        assert len(xnumels) == len(no_x_dims)
+        return self.summation(
+            [
+                self.ceildiv(x, 1 if no_x_dim else meta.get("XBLOCK"))
+                for x, no_x_dim in zip(xnumels, no_x_dims)
+            ]
+        )
+
+
+class RoundRobinComboKernelGrid(ComboKernelGrid):
+    def combo_x_grid(
+        self,
+        xnumels: list[int | str],
+        no_x_dims: list[bool],
+        meta: dict[str, int],
+    ) -> str:
+        assert len(xnumels) == len(no_x_dims)
+        num_kernels = self.inductor_meta["combo_grid_meta"]["num_kernels"]
+        exprs = [x for x, no_x_dim in zip(xnumels, no_x_dims) if no_x_dim]
+        xnumels_x_dim = [x for x, no_x_dim in zip(xnumels, no_x_dims) if not no_x_dim]
+        if xnumels_x_dim:
+            exprs.append(self.ceildiv(self.maximum(xnumels_x_dim), meta.get("XBLOCK")))
+        return f"({self.maximum(exprs)}) * {num_kernels}"
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_pycute/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_pycute/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dfecd307879be861d3faa8f998a285a215edf02c
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_pycute/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_pycute/__pycache__/int_tuple.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_pycute/__pycache__/int_tuple.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fad747d8914d5a781da3599b1344f71cf97e5eca
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_pycute/__pycache__/int_tuple.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_pycute/__pycache__/layout.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_pycute/__pycache__/layout.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..be9da3ad207a316f43e81c99faeff4f85c81c51e
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_pycute/__pycache__/layout.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_pycute/__pycache__/typing.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_pycute/__pycache__/typing.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b1ad9ab3493d53031bf95e89997f56f042235dd4
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_pycute/__pycache__/typing.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/checkpoint/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/checkpoint/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1a23973a5e1333ee59c55fd71b2dcf955860ffa2
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/checkpoint/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_optim/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_optim/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bc0ec129c38fb3bf12db8cc2db96bc9ed9fc5800
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_optim/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_optim/__pycache__/api.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_optim/__pycache__/api.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..748ac1c74c460896bafdc915dc8be872db487511
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_optim/__pycache__/api.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4c3cd8ade2e3c7e7ceb84089b87a074b4fe7ca32
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/api.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/api.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cf7261b5c68c710896bb7199b84137650bff8980
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/api.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/logger.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/logger.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4dbdf962769f53ca1dbb22cd95d28a187a484333
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/logger.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/logging_handlers.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/logging_handlers.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1182ce4f4541d3a221ee896415f2e0d8099bffcf
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/logging_handlers.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/metadata.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/metadata.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..da1c99e30424b5dbbc6fabf57f3c17e9899b5761
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/metadata.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/reshard.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/reshard.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..28c18dc3c98963451f11144e58bde4a5794ebfea
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/reshard.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/shard.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/shard.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9c63ef4c38cb3ceadcc889e5e168f4007aa62708
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/shard.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d197822bf6dd9f36987baa1760c30d1f0535d620
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/utils.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c61f48cca12cb005a644e1068c09e7fa7335fe10
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/_common.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/_common.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8f03efc0b151c5853452f783275c8950eafc35f2
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/_common.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/binary_cmp.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/binary_cmp.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b71b3c88174d59becd66cb818a8fe3c434356e45
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/binary_cmp.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/init.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/init.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a999d6e32f5aa061a9f6eafb53eb597e13b3582b
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/init.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/misc_ops.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/misc_ops.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b1b7d2ea887928bf17459d3552da2257eaecd233
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/misc_ops.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/tensor_ops.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/tensor_ops.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0a7991130af81bca593c3acfaa928b33dd7fc1d3
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/tensor_ops.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/_common.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/_common.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a356e524a47a6f1e73022a707f19d7ddb8c935d
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/_common.py
@@ -0,0 +1,115 @@
+# mypy: allow-untyped-defs
+import functools
+
+from torch.distributed._shard.common_op_utils import _basic_validation
+from torch.distributed._shard.sharded_tensor import (
+    _sharded_op_impl,
+    Shard,
+    ShardedTensor,
+)
+
+
+def _sharded_op_common(op, early_stop_func, extra_check):
+    """
+    Inject sharded tensor op registration with common logics executed before
+    different behaviors are done on either local shards or a local tensor.
+
+    Example::
+        >>> # xdoctest: +SKIP("Undefined variables")
+        >>> op = torch.transpose
+        >>> @_sharded_op_impl(op)
+        >>> @_sharded_op_common(op, early_stop_func, extra_check)
+        >>> def sharded_tensor_op(types, args, kwargs, process_group):
+        >>>   ...
+        >>>
+        >>> st = sharded_tensor.rand(32, 16)
+        >>> st.transpose(1, 2)
+        >>> # This will call '_sharded_op_common'
+
+    Args:
+        op: The op to be registered and applied to all shards of the st.
+        early_stop_func (Callable, optional): the func for early stop.
+            Default: if ``None``, no early stop.
+        extra_check (Callable, optional): the func for extra condition check.
+            Default: if ``None``, no extra check.
+
+    Return:
+        func (Callable): Torch function for which we want to provide a sharded
+            implementation (ex: torch.transpose)
+    """
+
+    def decorator_sharded_func(wrapped_func):
+        @functools.wraps(wrapped_func)
+        def wrapper(types, args=(), kwargs=None, pg=None):
+            _basic_validation(op, args, kwargs)
+
+            # pyrefly: ignore [index-error]
+            st = args[0]
+            if kwargs is None:
+                kwargs = {}
+            if extra_check:
+                extra_check(*args, **kwargs)
+            if early_stop_func:
+                early_stop = early_stop_func(*args, **kwargs)
+                if early_stop:
+                    return st
+            return wrapped_func(types, args, kwargs, pg)
+
+        return wrapper
+
+    return decorator_sharded_func
+
+
+def _register_sharded_op_on_local_shards(
+    op, early_stop_func=None, extra_check=None, customized_func=None
+):
+    """
+    Handles ``__torch_function__`` dispatch for ops which are performed on
+    each shard of the sharded tensor such as elementwise op like
+    ``torch.nn.functional.gelu`` or ``torch.nn.functional.relu``.
+
+    For more complicated ops, a customized func can be used to generate
+    the new shards and sharded tensor size.
+
+    This function expects that the original ShardingSpec for the ShardedTensor
+    is preserved irrespective of whether or not a customized function is used.
+
+    Args:
+        op: The op to be registered and applied to all shards of the st.
+        early_stop_func (Callable, optional): the func for early stop.
+            Default: if ``None``, no early stop.
+        extra_check (Callable, optional): the func for extra condition check.
+            Default: if ``None``, no extra check.
+        customized_func (Callable, optional): the func for customized logic
+            to generate new shards and sharded tensor size.
+            Default: if ``None``, we simply lower to the real op call with
+                all local shards of the st.
+
+    Return:
+        func (Callable): registered implementation for sharded op for
+        ``__torch_function__`` dispatch.
+    """
+
+    @_sharded_op_impl(op)
+    @_sharded_op_common(op, early_stop_func, extra_check)
+    def sharded_tensor_op_on_local_shards(types, args=(), kwargs=None, pg=None):
+        # pyrefly: ignore [index-error]
+        st = args[0]
+        st_metadata = st.metadata()
+        local_shards = st.local_shards()
+        local_shards_new = []
+        if customized_func:
+            local_shards_new, st_metadata = customized_func(args, kwargs, pg)
+        else:
+            for local_shard in local_shards:
+                args = (local_shard.tensor, *args[1:])
+                local_shards_new.append(
+                    Shard(op(*args, **kwargs), local_shard.metadata)
+                )
+        return ShardedTensor._init_from_local_shards_and_global_metadata(
+            local_shards_new,
+            st_metadata,
+            process_group=pg,
+            init_rrefs=st._init_rrefs,
+            sharding_spec=st.sharding_spec(),
+        )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py
new file mode 100644
index 0000000000000000000000000000000000000000..0548b81fb90af087593d05695418664c6d109f2d
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py
@@ -0,0 +1,78 @@
+# mypy: allow-untyped-defs
+import torch
+import torch.distributed as dist
+import torch.distributed.distributed_c10d as distributed_c10d
+from torch.distributed._shard.sharded_tensor import _sharded_op_impl, ShardedTensor
+
+
+def _communicate_result(result, pg):
+    # Gather results from all ranks.
+    if result:
+        result_tensor = torch.ones(1, device=torch.device(torch.cuda.current_device()))
+    else:
+        result_tensor = torch.zeros(1, device=torch.device(torch.cuda.current_device()))
+
+    dist.all_reduce(result_tensor, group=pg)
+
+    expected_result = torch.ones(
+        1, device=torch.device(torch.cuda.current_device())
+    ) * dist.get_world_size(pg)
+
+    return torch.equal(result_tensor, expected_result)
+
+
+def binary_cmp(cmp_fun, types, args, kwargs=None, process_group=None):
+    if len(args) != 2:
+        raise ValueError(f"Expected two arguments for torch.{cmp_fun.__name__}")
+
+    st1 = args[0]
+    st2 = args[1]
+    if not (isinstance(st1, ShardedTensor) and isinstance(st2, ShardedTensor)):
+        raise TypeError(
+            f"Both arguments to torch.{cmp_fun.__name__} need to be of type ShardedTensor"
+        )
+
+    # Verify same PG
+    if st1._process_group != st2._process_group:
+        return False
+
+    if distributed_c10d._rank_not_in_group(
+        st1._process_group
+    ) or distributed_c10d._rank_not_in_group(st2._process_group):
+        return distributed_c10d._rank_not_in_group(
+            st1._process_group
+        ) == distributed_c10d._rank_not_in_group(st2._process_group)
+
+    # Verify metadata
+    if st1.metadata() != st2.metadata():
+        return _communicate_result(False, st1._process_group)
+
+    # Verify number of local shards
+    st1_local_shards = st1.local_shards()
+    st2_local_shards = st2.local_shards()
+    if len(st1_local_shards) != len(st2_local_shards):
+        return _communicate_result(False, st1._process_group)
+
+    # kwargs must be dict-like
+    if kwargs is None:
+        kwargs = {}
+    # Verify each local shard
+    for idx in range(len(st1_local_shards)):
+        if st1_local_shards[idx].metadata != st2_local_shards[idx].metadata:
+            return _communicate_result(False, st1._process_group)
+        if not cmp_fun(
+            st1_local_shards[idx].tensor, st2_local_shards[idx].tensor, **kwargs
+        ):
+            return _communicate_result(False, st1._process_group)
+
+    return _communicate_result(True, st1._process_group)
+
+
+@_sharded_op_impl(torch.equal)
+def equal(types, args, kwargs, process_group):
+    return binary_cmp(torch.equal, types, args, kwargs, process_group)
+
+
+@_sharded_op_impl(torch.allclose)
+def allclose(types, args, kwargs, process_group):
+    return binary_cmp(torch.allclose, types, args, kwargs, process_group)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/init.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/init.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0e576b45ebeeda7661e0011b6a100cd60d0f5f4
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/init.py
@@ -0,0 +1,164 @@
+# mypy: allow-untyped-defs
+import torch
+import torch.distributed._shard.sharded_tensor as sharded_tensor
+from torch.distributed._shard.sharded_tensor import _sharded_op_impl
+
+
+def validate_param(param, param_name):
+    if param is None:
+        raise ValueError(f"param: {param_name} shouldn't be None!")
+
+
+@_sharded_op_impl(torch.nn.init.uniform_)
+def uniform_(types, args=(), kwargs=None, pg=None):
+    r"""
+    Fills the Tensor in tensor.local_shards with values drawn from the uniform
+    distribution :math:`\mathcal{U}(a, b)`.
+    Args:
+        tensor: tensor sharded across devices
+        a: the lower bound of the uniform distribution
+        b: the upper bound of the uniform distribution
+    """
+    validate_param(kwargs, "kwargs")
+    # pyrefly: ignore [unsupported-operation]
+    sharded_tensor = kwargs["tensor"]
+    validate_param(sharded_tensor, "tensor")
+    # pyrefly: ignore [unsupported-operation]
+    a = kwargs["a"]
+    validate_param(a, "a")
+    # pyrefly: ignore [unsupported-operation]
+    b = kwargs["b"]
+    validate_param(b, "b")
+
+    for shard in sharded_tensor.local_shards():
+        torch.nn.init.uniform_(shard.tensor, a=a, b=b)
+    return sharded_tensor
+
+
+@_sharded_op_impl(torch.nn.init.normal_)
+def normal_(types, args=(), kwargs=None, pg=None):
+    r"""
+    Fills the Tensors in tensor.local_shards with values drawn from the normal
+    distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.
+    Args:
+        tensor: tensor sharded across devices
+        mean: the mean of the normal distribution
+        std: the standard deviation of the normal distribution
+    """
+    validate_param(kwargs, "kwargs")
+    # pyrefly: ignore [unsupported-operation]
+    sharded_tensor = kwargs["tensor"]
+    validate_param(sharded_tensor, "tensor")
+    # pyrefly: ignore [unsupported-operation]
+    mean = kwargs["mean"]
+    validate_param(mean, "mean")
+    # pyrefly: ignore [unsupported-operation]
+    std = kwargs["std"]
+    validate_param(std, "std")
+
+    for shard in sharded_tensor.local_shards():
+        torch.nn.init.normal_(shard.tensor, mean=mean, std=std)
+    return sharded_tensor
+
+
+@_sharded_op_impl(torch.nn.init.kaiming_uniform_)
+def kaiming_uniform_(types, args=(), kwargs=None, pg=None):
+    r"""
+    Fills the Tensors in tensor.local_shards with values according to the method
+    described in `Delving deep into rectifiers: Surpassing human-level
+    performance on ImageNet classification` - He, K. et al. (2015), using a
+    uniform distribution. The resulting tensor will have values sampled from
+    :math:`\mathcal{U}(-\text{bound}, \text{bound})` where
+    .. math::
+        \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}
+    Also known as He initialization.
+    Args:
+        tensor: tensor sharded across devices
+        a: the negative slope of the rectifier used after this layer (only
+            used with ``'leaky_relu'``)
+        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
+            preserves the magnitude of the variance of the weights in the
+            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
+            backwards pass.
+        nonlinearity: the non-linear function (`nn.functional` name),
+            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
+    """
+    validate_param(kwargs, "kwargs")
+    # pyrefly: ignore [unsupported-operation]
+    sharded_tensor = kwargs["tensor"]
+    validate_param(sharded_tensor, "tensor")
+    # pyrefly: ignore [unsupported-operation]
+    a = kwargs["a"]
+    validate_param(a, "a")
+    # pyrefly: ignore [unsupported-operation]
+    mode = kwargs["mode"]
+    validate_param(mode, "mode")
+    # pyrefly: ignore [unsupported-operation]
+    nonlinearity = kwargs["nonlinearity"]
+    validate_param(nonlinearity, "nonlinearity")
+
+    for shard in sharded_tensor.local_shards():
+        torch.nn.init.kaiming_uniform_(
+            shard.tensor, a=a, mode=mode, nonlinearity=nonlinearity
+        )
+    return sharded_tensor
+
+
+@_sharded_op_impl(torch.nn.init.constant_)
+def constant_(types, args=(), kwargs=None, pg=None):
+    r"""
+    Fills the input ShardedTensor with the value \text{val}val.
+    Args:
+        tensor: tensor sharded across devices
+        val: the value to fill the tensor with
+    """
+    validate_param(kwargs, "kwargs")
+    # pyrefly: ignore [unsupported-operation]
+    sharded_tensor = kwargs["tensor"]
+    validate_param(sharded_tensor, "tensor")
+    # pyrefly: ignore [unsupported-operation]
+    val = kwargs["val"]
+    validate_param(val, "val")
+    for shard in sharded_tensor.local_shards():
+        torch.nn.init.constant_(shard.tensor, val=val)
+    return sharded_tensor
+
+
+tensor_like_creation_op_map = {
+    torch.full_like: sharded_tensor.full,
+    torch.empty_like: sharded_tensor.empty,
+    torch.zeros_like: sharded_tensor.zeros,
+    torch.ones_like: sharded_tensor.ones,
+    torch.rand_like: sharded_tensor.rand,
+    torch.randn_like: sharded_tensor.randn,
+}
+
+
+# tensor ops that behave the same as the default tensor
+def register_tensor_creation_op(op):
+    @_sharded_op_impl(op)
+    def tensor_creation_op(types, args=(), kwargs=None, pg=None):
+        """
+        Handles ``__torch_function__`` dispatch for tensor creation ops that
+        takes a ShardedTensor as argument, such as ``torch.zeros_like`` or
+        ``torch.full_like``.
+        """
+        creation_op = tensor_like_creation_op_map.get(op)
+        if creation_op is None:
+            raise RuntimeError(f"Tensor creation {op} not supported!")
+        if kwargs is None:
+            kwargs = {}
+
+        # pyrefly: ignore [index-error]
+        st = args[0]
+
+        new_st = creation_op(st.sharding_spec(), st.size(), *args[1:], **kwargs)  # type: ignore[operator]
+        return new_st
+
+
+register_tensor_creation_op(torch.full_like)
+register_tensor_creation_op(torch.empty_like)
+register_tensor_creation_op(torch.zeros_like)
+register_tensor_creation_op(torch.ones_like)
+register_tensor_creation_op(torch.rand_like)
+register_tensor_creation_op(torch.randn_like)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5b7ad7c77b1b7948f5464cde0bee0f703d738fb
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py
@@ -0,0 +1,222 @@
+# mypy: allow-untyped-defs
+import copy
+
+import torch
+from torch.distributed._shard.common_op_utils import _register_default_op
+from torch.distributed._shard.sharded_tensor import (
+    _sharded_op_impl,
+    Shard,
+    ShardedTensor,
+)
+
+from ._common import _register_sharded_op_on_local_shards
+
+
+# Tensor properties access
+_register_default_op(torch.Tensor.shape.__get__, _sharded_op_impl)  # type: ignore[attr-defined]
+_register_default_op(torch.Tensor.dtype.__get__, _sharded_op_impl)  # type: ignore[attr-defined]
+_register_default_op(torch.Tensor.layout.__get__, _sharded_op_impl)  # type: ignore[attr-defined]
+_register_default_op(torch.Tensor.size, _sharded_op_impl)
+_register_default_op(torch.Tensor.dim, _sharded_op_impl)
+_register_default_op(torch.Tensor.ndim.__get__, _sharded_op_impl)  # type: ignore[attr-defined]
+_register_default_op(torch.Tensor.is_contiguous, _sharded_op_impl)
+_register_default_op(torch.Tensor.contiguous, _sharded_op_impl)
+_register_default_op(torch.Tensor.is_floating_point, _sharded_op_impl)
+
+# __reduce_ex__ to dispatch to get_state/set_state
+_register_default_op(torch.Tensor.__reduce_ex__, _sharded_op_impl)
+
+# autograd related properties
+_register_default_op(torch.Tensor.requires_grad.__get__, _sharded_op_impl)  # type: ignore[attr-defined]
+# TODO: set grad with a ShardedTensor that consists of all local grads
+_register_default_op(torch.Tensor.grad.__get__, _sharded_op_impl)  # type: ignore[union-attr]
+_register_default_op(torch.Tensor.grad_fn.__get__, _sharded_op_impl)  # type: ignore[union-attr]
+_register_default_op(torch.Tensor.is_leaf.__get__, _sharded_op_impl)  # type: ignore[attr-defined]
+
+
+# device property is ambiguous as from a global prospective,
+# ShardedTensor.device consists of multiple devices (might even across hosts)
+# We choose to return the current device of the local tensor to represent
+# the device property on each rank
+@_sharded_op_impl(torch.Tensor.device.__get__)
+def tensor_device(types, args=(), kwargs=None, pg=None):
+    # pyrefly: ignore [index-error]
+    self_st = args[0]
+    # Validate types
+    if not isinstance(self_st, ShardedTensor):
+        raise TypeError("input needs to be a ShardedTensor")
+    dev: torch.device
+    if self_st._local_shards:
+        dev = self_st._local_shards[0].tensor.device
+    elif pg and pg._get_backend_name() == "gloo":
+        dev = torch.device("cpu")
+    else:
+        dev = torch.device(torch.cuda.current_device())
+    return dev
+
+
+@_sharded_op_impl(torch.Tensor.is_meta.__get__)  # type: ignore[attr-defined]
+def st_is_meta(types, args=(), kwargs=None, pg=None):
+    # pyrefly: ignore [index-error]
+    return args[0].local_tensor().is_meta
+
+
+def sharded_type_as_check(*args, **kwargs):
+    """
+    Perform extra checks for the sharded_type_as op such as the input needs to
+    be either a Tensor or ShardedTensor.
+
+    Args: same as ``torch.Tensor.type_as``.
+
+    Return: None
+    """
+    if len(args) < 2:
+        raise ValueError("Needs to give a tensor to cast type as!")
+    if not isinstance(args[1], torch.Tensor) and not isinstance(args[1], ShardedTensor):
+        raise ValueError("Needs to give a Tensor or ShardedTensor to cast type as!")
+
+
+def same_dtype(*args, **kwargs):
+    """
+    When the dtype is the same, return the original ShardedTensor.
+
+    Args: same as ``torch.Tensor.type_as``.
+
+    Return (bool): Whether to return early or not.
+    """
+    return args[0].dtype == args[1].dtype
+
+
+def sharded_type_as(args, kwargs, pg):
+    """
+    Handles ``__torch_function__`` dispatch for the ``torch.Tensor.type_as`` op.
+
+    Args: same as ``torch.Tensor.type_as``.
+
+    Return:
+        new_local_shards (List[Shard]): Local shards for the new sharded tensor.
+        st_meta (ShardedTensorMetadata): Metadata of the new sharded tensor.
+    """
+    st = args[0]
+    tensor = args[1]
+    if isinstance(tensor, ShardedTensor):
+        tensor = tensor.local_tensor()
+    new_local_shards = [
+        Shard(shard.tensor.type_as(tensor), shard.metadata)
+        for shard in st.local_shards()
+    ]
+    st_meta = copy.deepcopy(st._metadata)
+    st_meta.tensor_properties.dtype = tensor.dtype
+    return new_local_shards, st_meta
+
+
+_register_sharded_op_on_local_shards(
+    torch.Tensor.type_as,
+    early_stop_func=same_dtype,
+    extra_check=sharded_type_as_check,
+    customized_func=sharded_type_as,
+)
+
+
+def sharded_deepcopy(args, kwargs, pg):
+    # NOTE: we directly implement deepcopy magic method
+    # instead of using the default tensor.__deepcopy__
+    # and implement clone(). This is because the default
+    # tensor deepcopy copies every attribute, but the
+    # process_group in ShardedTensor cannot be deep copied.
+    self_st = args[0]
+    new_local_shards = copy.deepcopy(self_st.local_shards())
+    new_metadata = copy.deepcopy(self_st.metadata())
+    return new_local_shards, new_metadata
+
+
+_register_sharded_op_on_local_shards(
+    torch.Tensor.__deepcopy__,
+    customized_func=sharded_deepcopy,
+)
+
+
+@_sharded_op_impl(torch.Tensor.copy_)
+def sharded_inplace_copy(types, args, kwargs, pg):
+    # NOTE: inplace op don't need to rewrap
+    kwargs = {} if kwargs is None else kwargs
+    self_st = args[0]
+    new_st = args[1]
+    nonblocking = kwargs.get("non_blocking", False)
+    for local_shard, new_shard in zip(self_st.local_shards(), new_st.local_shards()):
+        if local_shard.metadata != new_shard.metadata:
+            raise RuntimeError(
+                "inplace copy can only happen between two ShardedTensor with same metadata!"
+            )
+    for local_shard, new_shard in zip(self_st.local_shards(), new_st.local_shards()):
+        local_shard.tensor.copy_(new_shard.tensor, nonblocking)
+
+    return self_st
+
+
+def sharded_clone(args, kwargs, pg):
+    self_st = args[0]
+    desire_memory_format = kwargs.get("memory_format", None)
+    if desire_memory_format and desire_memory_format != torch.preserve_format:
+        raise RuntimeError("Only support torch.preserve_format for ShardedTensor!")
+    cloned_local_shards = [
+        Shard(
+            local_shard.tensor.clone(memory_format=desire_memory_format),
+            metadata=copy.deepcopy(local_shard.metadata),
+        )
+        for local_shard in self_st.local_shards()
+    ]
+    new_metadata = copy.deepcopy(self_st.metadata())
+    return cloned_local_shards, new_metadata
+
+
+_register_sharded_op_on_local_shards(
+    torch.Tensor.clone,
+    customized_func=sharded_clone,
+)
+
+
+def sharded_detach(args, kwargs, pg):
+    self_st = args[0]
+    detached_local_shards = [
+        Shard(
+            local_shard.tensor.detach(),
+            metadata=copy.deepcopy(local_shard.metadata),
+        )
+        for local_shard in self_st.local_shards()
+    ]
+    new_metadata = copy.deepcopy(self_st.metadata())
+    new_metadata.tensor_properties.requires_grad = False
+    return detached_local_shards, new_metadata
+
+
+_register_sharded_op_on_local_shards(
+    torch.Tensor.detach,
+    customized_func=sharded_detach,
+)
+
+
+@_sharded_op_impl(torch.Tensor.requires_grad_)
+def tensor_requires_grad_set(types, args=(), kwargs=None, pg=None):
+    # pyrefly: ignore [index-error]
+    self_st = args[0]
+    # Validate types
+    if not isinstance(self_st, ShardedTensor):
+        raise TypeError("input needs to be a ShardedTensor")
+
+    if kwargs is None:
+        kwargs = {}
+
+    requires_grad = args[1] if len(args) > 1 else kwargs.get("requires_grad", True)
+    if requires_grad == self_st.requires_grad:
+        return self_st
+
+    for local_shard in self_st.local_shards():
+        local_shard.tensor.requires_grad_(requires_grad)
+
+        # update the wrapper class property
+    with torch._C.DisableTorchFunctionSubclass():
+        self_st.requires_grad_(requires_grad)
+    # update the metadata in the meanwhile
+    self_st._metadata.tensor_properties.requires_grad = requires_grad
+    return self_st
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_plan/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_plan/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d9188cd0c67b34ce8f5bf862521b0ba886a532ee
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_plan/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_plan/__pycache__/api.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_plan/__pycache__/api.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e64714cd9f608a18cba22a46de903b46d55d3b0d
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_plan/__pycache__/api.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fcc5b58f61be961460e8223974b0e07e5e16a449
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/_internals.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/_internals.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..66aed89b50a787b2fbaf496609304c6cc15a3d67
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/_internals.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/api.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/api.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..82e81bb905286ed41fed0167d0a3fcf14e2c3bbc
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/api.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/chunk_sharding_spec.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/chunk_sharding_spec.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..09cb60ef11d1ba1b0889bfe450bcc6420f1b1812
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/chunk_sharding_spec.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..106560957d74870279650e67ec8eaefbc6599ad4
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/_common.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/_common.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dca2d7e276e5ef02486d7df9615ffe53d8f1e83b
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/_common.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/embedding.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/embedding.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d9279e9fee1949ab8c805c7a3ef1d4500810345c
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/embedding.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/embedding_bag.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/embedding_bag.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5fa64440ce41df4f4aa66ce19121ec8f63288787
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/embedding_bag.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/_common.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/_common.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a8a05fe79d19d2dc67e6ff535ae419e255192c1
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/_common.py
@@ -0,0 +1,350 @@
+# mypy: allow-untyped-defs
+
+import torch
+import torch.distributed as dist
+from torch.distributed._shard.sharded_tensor import ShardedTensor
+from torch.distributed._shard.sharded_tensor._ops._common import _sharded_op_common
+from torch.distributed._shard.sharding_spec import ChunkShardingSpec
+from torch.distributed._shard.sharding_spec._internals import (
+    get_chunk_sharding_params,
+    get_chunked_dim_size,
+    get_split_size,
+)
+from torch.distributed._shard.sharding_spec.api import custom_sharding_spec_op
+from torch.distributed.nn.functional import (
+    _all_gather_base,
+    all_reduce,
+    all_to_all_single,
+)
+
+
+def _chunk_sharding_spec_check(spec, op):
+    """
+    For the given op implementation check if the sharding spec is ChunkShardingSpec.
+    """
+    if not isinstance(spec, ChunkShardingSpec):
+        raise NotImplementedError(
+            f"Only ChunkShardingSpec supported for '{op.__name__}'."
+        )
+
+
+def _register_sharded_op_on_local_tensor(
+    op, early_stop_func=None, extra_check=None, customized_func=None
+):
+    """
+    Handles ``__torch_function__`` dispatch for ops which are performed on
+    the single local tensor of the sharded tensor such as op like
+    ``torch.nn.functional.softmax`` or ``torch.Tensor.view``.
+
+    For more complicated ops, a customized func can be used to generate
+    the new local tensor, sharding spec and sharded tensor size.
+
+    Args:
+        op: The op to be registered and applied to all shards of the st.
+        early_stop_func (Callable, optional): the func for early stop.
+            Default: if ``None``, no early stop.
+        extra_check (Callable, optional): the func for extra condition check.
+            Default: if ``None``, no extra check.
+        customized_func (Callable, optional): the func for customized logic
+            to generate the new local tensor, sharding spec and sharded tensor size.
+            Default: if ``None``, we simply lower to the real op call with
+                the single local tensor of the st.
+
+    Return:
+        func (Callable): registered implementation for sharded op for
+        ``__torch_function__`` dispatch.
+    """
+
+    @custom_sharding_spec_op(ChunkShardingSpec, op)
+    @_sharded_op_common(op, early_stop_func, extra_check)
+    def sharded_tensor_op_on_local_tensor(types, args=(), kwargs=None, pg=None):
+        # pyrefly: ignore [index-error]
+        st = args[0]
+        sharding_spec = st.sharding_spec()
+        if len(st.local_shards()) != 1:
+            raise TypeError(
+                f"torch function '{op.__name__}', with args: {args} and "
+                f"kwargs: {kwargs} only supported for single local tensor!"
+            )
+        st_size = st.size()
+        if customized_func:
+            local_tensor, sharding_spec, st_size = customized_func(args, kwargs, pg)
+        else:
+            args = (st.local_tensor(), *args[1:])
+            local_tensor = op(*args, **kwargs)
+        return ShardedTensor._init_from_local_tensor(
+            local_tensor.contiguous(),
+            sharding_spec,
+            st_size,  # type: ignore[arg-type]
+            process_group=pg,
+            init_rrefs=st._init_rrefs,
+        )
+
+
+def _handle_col_wise_sharding_base(
+    op_func,
+    col_dim,
+    input,
+    world_size,
+    weight,
+    local_shard,
+    pg,
+    gathered_inputs,
+    mode=None,
+    gathered_per_sample_weights=None,
+    gathered_offsets=None,
+    padding_idx=None,
+):
+    """
+    For col-wise sharding of weight, lots of logic are common.
+    So we extract the common logic and put in this function:
+    Step 1. To get input from each rank and
+    Step 2. To perform the op on the concatenated tensor.
+    Step 3. To distribute results to each rank with col rearrangement.
+    Step 4. To concatenate all results from all ranks.
+
+    Args:
+        op_func: operator which is applied to the input tensor.
+        col_dim: dim of result tensor after the operation.
+        input: tensor to be applied op on.
+        world_size: number of ranks.
+        weight: sharded weight tensor.
+        local_shard: col-wise sharded weight tensor.
+        pg: process group.
+        gathered_inputs: list of inputs from all ranks. If specified, we
+            don't need to communicate with each rank any more.
+        mode: aggregation mode of EmbeddingBag.
+        gathered_per_sample_weights: per_sample_weights across all ranks.
+        gathered_offsets: offsets across all ranks.
+        padding_idx: If specified, the entries at padding_idx do
+            not contribute to the gradient; therefore, the embedding
+            vector at padding_idx is not updated during training,
+            i.e. it remains as a fixed "pad".
+            Note that the embedding vector at padding_idx is
+            excluded from the reduction.
+
+    Return: final result of input being applied with the op.
+    """
+    # run the operator's function for all the inputs.
+    results = []
+    for i, inp in enumerate(gathered_inputs):
+        if op_func is torch.nn.functional.embedding_bag:
+            result = op_func(
+                inp,
+                local_shard,
+                offsets=gathered_offsets[i] if gathered_offsets is not None else None,
+                # pyrefly: ignore [bad-argument-type]
+                mode=mode,
+                per_sample_weights=gathered_per_sample_weights[i]
+                if gathered_per_sample_weights is not None
+                else None,
+                padding_idx=padding_idx,
+            )
+        elif op_func is torch.nn.functional.embedding:
+            result = op_func(
+                inp,
+                local_shard,
+                padding_idx=padding_idx,
+            )
+        else:
+            result = op_func(inp, local_shard)
+        results.append(torch.transpose(result, 0, col_dim))
+
+    # Distribute results to each rank with col rearrangement.
+    output = _result_distribute_with_col_rearrange(
+        results, input, world_size, weight, pg
+    )
+
+    # transpose the output and return result.
+    return torch.transpose(output, 0, col_dim)
+
+
+def _result_distribute_with_col_rearrange(results, input, world_size, weight, pg):
+    """
+    For col-wise sharding of weight, we need to distribute
+    results to each rank. We do them in this function.
+    Note that, if the index in the Sharding Spec is not equal to
+    the rank number, we need to do the rearrangement based on the
+    order given by the Sharding Spec (placement).
+
+    Args:
+        results: results from ops applied to inputs from all ranks.
+            We need to distribute them back to their original ranks.
+        input: tensor to be applied op to.
+        world_size: number of ranks.
+        weight: sharded weight tensor.
+        pg: process group.
+
+    Return: column rearranged result.
+    """
+    # Process results and outputs for all2all.
+    sharding_dim = weight._sharding_spec.dim
+    sharding_dim_size = weight.size(sharding_dim)
+    dims = list(results[0].size())
+    dims[0] = sharding_dim_size
+    combined_results = torch.cat(results)
+    output = torch.empty(
+        *dims, device=combined_results.device, dtype=combined_results.dtype
+    )
+
+    # Compute output splits
+    split_size = get_split_size(sharding_dim_size, world_size)
+    output_split_sizes = [0] * world_size
+    for idx, placement in enumerate(weight._sharding_spec.placements):
+        output_split_sizes[placement.rank()] = get_chunked_dim_size(
+            sharding_dim_size, split_size, idx
+        )
+
+    # distribute the outputs using all2all.
+    output = all_to_all_single(
+        output, combined_results, output_split_sizes=output_split_sizes, group=pg
+    )
+
+    # Check if we need to rearrange columns appropriately for output.
+    rearrange_columns = any(
+        idx != placement.rank()
+        for idx, placement in enumerate(weight._sharding_spec.placements)
+    )
+    if not rearrange_columns:
+        return output
+
+    indices = []
+    for placement in weight._sharding_spec.placements:
+        dim_size = output_split_sizes[placement.rank()]
+        start = sum(
+            split_size if i < placement.rank() else 0
+            for i, split_size in enumerate(output_split_sizes)
+        )
+        indices += list(range(start, start + dim_size))
+
+    return output.index_select(0, torch.tensor(indices, device=output.device))
+
+
+def _handle_max_norm_col_wise(
+    max_norm,
+    norm_type,
+    local_shard,
+    input,
+    world_size,
+    gathered_inputs,
+    pg,
+):
+    """
+    For col-wise sharding of weight, we need to aggregate the
+    norm across all ranks before we can perform the proper re-norm.
+    Note that, the max_norm logic is only applied to the embedding
+    indices that are looked up and not the whole shard.
+
+    Args:
+        max_norm: If given, each embedding vector with norm larger
+            than max_norm is renormalized to have norm max_norm.
+            Note: this will modify weight in-place.
+        norm_type: The p in the p-norm to compute for the max_norm option.
+        local_shard: col-wise shared local weight used for lookup.
+        input: tensor to be applied op to.
+        world_size: number of ranks.
+        gathered_inputs: list of inputs from all ranks.
+        pg: process group.
+
+    Return:
+        local_shard_norm_renormed: local_shard re-normed to max_norm if the norm is larger
+            than it.
+
+    """
+    norm_type = norm_type if norm_type is not None else 2.0
+    unique_inp = torch.unique(torch.cat(gathered_inputs))
+    local_shard_sum = torch.sum(
+        torch.pow(torch.abs(local_shard), norm_type), dim=1, dtype=local_shard.dtype
+    )
+    # For col-wise sharding, we need to first aggregate the powered sum
+    # from each rank first and then calculate the norm.
+    local_shard_sum = all_reduce(local_shard_sum, group=pg)
+    local_shard_norm = torch.pow(local_shard_sum, 1.0 / norm_type)
+    max_norm_tensor = torch.full(
+        (local_shard.size(0),),
+        float("inf"),
+        dtype=local_shard.dtype,
+        device=input.device,
+    )
+    max_norm_tensor[unique_inp] = max_norm
+    local_shard_t = local_shard.t().contiguous()
+    normalized_tensor = torch.where(
+        local_shard_norm > max_norm_tensor, max_norm_tensor, local_shard_norm
+    )
+    # Make sure divisor is not zero.
+    local_shard_norm[local_shard_norm == 0.0] = 1.0
+    local_shard_norm_renormed = (
+        torch.div(torch.mul(local_shard_t, normalized_tensor), local_shard_norm)
+        .t()
+        .contiguous()
+    )
+    return local_shard_norm_renormed
+
+
+def _all_gather_base_input(input, pg):
+    """
+    Use _all_gather_base to get a concatenated input from each rank.
+
+    Args:
+        input: tensor to be applied op on.
+        pg: process group.
+
+    Returns:
+        gathered_inputs: input gathered from each rank and concat by dim 0.
+    """
+    # allgather the inputs first.
+    gather_inp_size = list(input.size())
+    gather_inp_size[0] = input.size(0) * dist.get_world_size(pg)
+    gather_inp = torch.empty(gather_inp_size, device=input.device, dtype=input.dtype)
+    return _all_gather_base(gather_inp, input, group=pg)
+
+
+def _handle_row_wise_mask(gather_inp, padding_idx, weight, world_size, rank):
+    """
+    Mask the input for embedding look-up for IDs which are not stored
+    on the current rank. This function also adjust the ``padding_idx``
+    so that it is only used on the rank where the corresponding row is
+    stored.
+
+    Note that, with ``max_norm`` flag on, only weights of rows being
+    looked up will be re-normed. So we need an extra row for masked ID
+    so that it does not affect the final result and ``max_norm``.
+
+    Args:
+        gather_inp: tensor to be applied op on gathered from all ranks.
+        padding_idx: If specified, the entries at padding_idx do
+            not contribute to the gradient; therefore, the embedding
+            vector at padding_idx is not updated during training,
+            i.e. it remains as a fixed "pad".
+            Note that the embedding vector at padding_idx is
+            excluded from the reduction.
+        weight: weight tensor of Embedding look-up table.
+        world_size: number of ranks.
+        rank: # of cuda process.
+
+    Returns:
+        lookup_input: Tensor of masked input.
+        padding_idx: adjusted padding_idx.
+        padding_row: The extra row we used during lookup so that
+            looking up does not affect ``max_norm``.
+    """
+    (start_pos, chunk_size) = get_chunk_sharding_params(
+        weight.size(0), world_size, weight._sharding_spec, rank
+    )
+    mask = (gather_inp < start_pos) | (gather_inp >= start_pos + chunk_size)
+    lookup_input = gather_inp.clone() - start_pos
+    lookup_input[mask] = chunk_size
+    if (
+        padding_idx is not None
+        and padding_idx >= start_pos
+        and padding_idx < (start_pos + chunk_size)
+    ):
+        padding_idx = padding_idx - start_pos
+    else:
+        padding_idx = None
+
+    # When max_norm is set, it will only re-norm the row being looked up.
+    padding_row = torch.zeros(
+        1, weight.size(1), device=gather_inp.device, dtype=weight.dtype
+    )
+    return lookup_input, padding_idx, padding_row
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..117aed79520d9ad78c10bdd2310fb6b032c2a024
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding.py
@@ -0,0 +1,294 @@
+# mypy: allow-untyped-defs
+
+import torch
+import torch.distributed as dist
+from torch.distributed._shard.sharded_tensor import ShardedTensor
+from torch.distributed._shard.sharding_spec import ChunkShardingSpec
+from torch.distributed._shard.sharding_spec.api import custom_sharding_spec_op
+from torch.distributed.nn.functional import all_gather, reduce_scatter
+
+from ._common import (
+    _all_gather_base_input,
+    _handle_col_wise_sharding_base,
+    _handle_max_norm_col_wise,
+    _handle_row_wise_mask,
+)
+
+
+@custom_sharding_spec_op(ChunkShardingSpec, torch.nn.functional.embedding)
+def sharded_embedding(types, args, kwargs, pg):
+    """
+    Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding``.
+    This method computes a sharded embedding lookup and has the following limitations:
+
+    1. Supports only sharding of ``weight``.
+    2. Supports only ``ChunkShardingSpec``.
+    3. Supports only a single local shard per rank.
+    4. Supports all specs except for scale_grad_by_freq, sparse, etc.
+
+    Based on the dimension that the weight is sharded on, there are two
+    algorithms:
+
+    ROWWISE SHARDING
+    ================
+    For row-wise sharding the weight is sharded on dimension 0.
+
+    The overall algorithm can be best explained with an example. Let's assume
+    the dims for input are (4 x 6) and W are (10 x 17) and W is sharded across
+    4 GPUs creating 3 shard of (3 x 17) and 1 shard of (1 x 17).
+    The algorithm is as follows:
+
+    1. First the input is all gathered to all ranks, since this is SPMD and
+       input is actually sharded across all ranks. The inputs then become a
+       4 (4 x 6) tensor on each rank. For example if the given input is
+       tensor([[6, 5, 2, 9, 6, 3],
+               [3, 1, 2, 4, 7, 6],
+               [4, 0, 4, 9, 8, 9],
+               [8, 6, 6, 4, 6, 1]])
+       on rank 0.
+       Then on every rank, we will have this tensor.
+       If input itself is already replicated, no all-gather will be done.
+    2. Next, we mask the ID which are not stored on that rank.
+       For example on rank 0, we store ID [0, 1, 2]. We only keep the ID
+       inside the set of numbers. The rest of them will be masked to an extra row.
+       The masked matrix will be used for embedding look up and is like:
+       tensor([[4, 4, 2, 4, 4, 4],
+               [4, 1, 2, 4, 4, 4],
+               [4, 0, 4, 4, 4, 4],
+               [4, 4, 4, 4, 4, 1]])
+       The reason of having an extra row (aka, number 4 in the example) is
+       because when max_norm is specified only weight which has looked will
+       be re-normed so mask IDs whose embeddings are not stored in current
+       rank will to an extra row will ensure max_norm still works as expected.
+    3. If max_norm is specified, the extra row guarantees that the mask ID will
+       not affect the behavior of weigh re-norm.
+
+    COLWISE SHARDING
+    ================
+    For col-wise sharding the weight is sharded on dimension 1.
+
+    The overall algorithm can be best explained with an example. Let's assume
+    the dims for input are (4 x 6) and W are (16 x 17) and W is sharded across
+    4 GPUs creating 3 shards of (16 x 5) and 1 shard of (16 x 2).
+    The algorithm is as follows:
+
+    1. First the input is broadcasted to all ranks, since this is SPMD we
+       actually do an all_gather for all the inputs resulting in 4 (4 x 6)
+       inputs on each rank.
+    2. Next we perform local embedding lookup operation by apply each
+       input (4 x 6) with the local shard (16 x 5) ((16 x 2) for the last).
+       This results in 4 (5 x 6 x 4) ((2 x 6 x 4) for the last) matrices
+       on each rank. We transpose dim 0 and dim 2.
+    3. Next, we concat these 4 matrices and perform an all2all to share the
+       appropriate (5 x 6 x 4) or (2 x 6 x 4) matrices to each rank.
+    4. Now, each rank receives a (17 x 6 x 4) matrix which is basically the
+       size of the result we need.
+    5. If placements are not in order any appropriate rearrangement of columns
+       are done for the (17 x 6 x 4) matrix and finally we transpose the
+       dim 0 and dim 2 again.
+    6. If max_norm is specified, we manually sum up the norm and renorm. Because
+       the renorm must be in place, we need to override the local_shard to mimic
+       this behavior.
+    """
+    # Validate input params
+    _validate_embedding_param(args, kwargs)
+
+    input = args[0]
+    weight = args[1]
+    max_norm = kwargs.get("max_norm")
+    norm_type = kwargs.get("norm_type")
+    padding_idx = kwargs.get("padding_idx")
+
+    local_shard = weight.local_tensor().contiguous()
+    sharding_dim = weight._sharding_spec.dim
+    world_size = dist.get_world_size(pg)
+    rank = dist.get_rank(pg)
+
+    if sharding_dim == 1:
+        output, local_shard = _handle_col_wise_sharding(
+            input, world_size, weight, local_shard, max_norm, norm_type, padding_idx, pg
+        )
+        weight.local_shards()[0].tensor = local_shard
+        return output
+    elif sharding_dim == 0:
+        return _handle_row_wise_sharding(
+            input,
+            world_size,
+            weight,
+            local_shard,
+            max_norm,
+            norm_type,
+            padding_idx,
+            rank,
+            pg,
+        )
+    else:
+        raise RuntimeError(
+            f"nn.Embedding weight sharded on dim {sharding_dim} not supported!"
+        )
+
+
+def _validate_embedding_param(args, kwargs):
+    """
+    Validate input params of sharded embedding op.
+
+    Args:
+        input: list of ID used for lookup.
+        weight: sharded weight tensor.
+        kwargs: same as normal Embedding.
+
+    Return: None.
+    """
+
+    input = args[0]
+    weight = args[1]
+    max_norm = kwargs.get("max_norm")
+    scale_grad_by_freq = kwargs.get("scale_grad_by_freq")
+    sparse = kwargs.get("sparse")
+
+    # Validate types
+    if not isinstance(input, torch.Tensor):
+        raise TypeError("input need to be torch.Tensor")
+    if not isinstance(weight, ShardedTensor):
+        raise TypeError("weight needs to be ShardedTensor")
+    weight_size = weight.size()
+    if len(weight_size) != 2:
+        raise ValueError("Weight needs to have exactly 2 dims")
+    if int(torch.min(input).item()) < 0:
+        raise ValueError(
+            "Index out of range in Input %d %d",
+            int(torch.min(input).item()),
+            weight_size[1],
+        )
+    if int(torch.max(input).item()) >= weight_size[0]:
+        raise ValueError(
+            "Index out of range in Input %d %d",
+            int(torch.max(input).item()),
+            weight_size[1],
+        )
+    if scale_grad_by_freq:
+        raise RuntimeError(
+            'nn.Embedding weight sharded with flag on "scale_grad_by_freq" not supported!'
+        )
+    if sparse:
+        raise RuntimeError(
+            'nn.Embedding weight sharded with flag on "sparse" not supported!'
+        )
+    if max_norm and max_norm <= 0.0:
+        raise ValueError('"max_norm" must be larger than zero!')
+
+    if not isinstance(weight._sharding_spec, ChunkShardingSpec):
+        raise ValueError("Only ChunkShardingSpec supported for ShardedTensor ops!")
+    if len(weight.local_shards()) != 1:
+        raise ValueError("Only one local shard supported!")
+
+
+def _handle_col_wise_sharding(
+    input, world_size, weight, local_shard, max_norm, norm_type, padding_idx, pg
+):
+    """
+    Entry-point function to handle the logic of col-wise sharding of weight
+    for embedding. (Detailed explanations of the logic can be found in
+    the comment for sharded_embedding.)
+
+    Args:
+        input: list of ID used for lookup and aggregation.
+        world_size: number of ranks.
+        weight: sharded weight tensor.
+        local_shard: col-wise shared local weight used for lookup.
+        max_norm: If given, each embedding vector with norm larger
+            than max_norm is renormalized to have norm max_norm.
+            Note: this will modify weight in-place.
+        norm_type: The p in the p-norm to compute for the max_norm option.
+        padding_idx: If specified, the entries at padding_idx do
+            not contribute to the gradient; therefore, the embedding
+            vector at padding_idx is not updated during training,
+            i.e. it remains as a fixed "pad".
+        pg: process group.
+
+    Returns: final result of lookup.
+    """
+    # allgather the inputs first for non Replicated Tensor.
+    gathered_inputs = all_gather(input, group=pg)
+
+    if max_norm is not None:
+        # max_norm changes the weight in-place
+        local_shard = _handle_max_norm_col_wise(
+            max_norm, norm_type, local_shard, input, world_size, gathered_inputs, pg
+        )
+
+    output = _handle_col_wise_sharding_base(
+        torch.nn.functional.embedding,
+        len(input.size()),
+        input,
+        world_size,
+        weight,
+        local_shard,
+        pg,
+        gathered_inputs,
+        padding_idx=padding_idx,
+    )
+    return (output, local_shard)
+
+
+def _handle_row_wise_sharding(
+    input, world_size, weight, local_shard, max_norm, norm_type, padding_idx, rank, pg
+):
+    """
+    Entry-point function to handle the logic of row-wise sharding of weight
+    for embedding. (Detailed explanations of the logic can be found in
+    the comment for sharded_embedding.)
+
+    Args:
+        input: list of ID used for lookup and aggregation.
+        world_size: number of ranks.
+        weight: sharded weight tensor.
+        local_shard: row-wise shared local weight used for lookup.
+        max_norm: If given, each embedding vector with norm larger
+            than max_norm is renormalized to have norm max_norm.
+            Note: this will modify weight in-place.
+        norm_type: The p in the p-norm to compute for the max_norm option.
+        padding_idx: If specified, the entries at padding_idx do
+            not contribute to the gradient; therefore, the embedding
+            vector at padding_idx is not updated during training,
+            i.e. it remains as a fixed "pad".
+        rank: # of cuda process.
+        pg: process group.
+
+    Returns: final result of lookup.
+    """
+    # allgather the inputs first for non Replicated Tensor.
+    gather_inp = _all_gather_base_input(input, pg)
+
+    # Mask the input according to sharding spec.
+    lookup_input, padding_idx, padding_row = _handle_row_wise_mask(
+        gather_inp, padding_idx, weight, world_size, rank
+    )
+
+    # When input is a large tensor, the value of weight is changed.
+    # This is a walk-around for now. GH issue: #81717
+    if max_norm is not None:
+        torch.nn.functional.embedding(
+            torch.unique(lookup_input)[:-1],
+            local_shard,
+            padding_idx=padding_idx,
+            max_norm=max_norm,
+            norm_type=norm_type,
+        )
+        max_norm = None
+
+    local_input_embeddings = torch.nn.functional.embedding(
+        lookup_input,
+        torch.cat([local_shard, padding_row]),
+        padding_idx=padding_idx,
+        max_norm=max_norm,
+        norm_type=norm_type,
+    )
+
+    # TODO: Make the result a PartialTensor.
+    local_shards = local_input_embeddings.chunk(pg.size())
+    return reduce_scatter(
+        torch.empty_like(local_shards[0]),
+        list(local_shards),
+        group=pg,
+    )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1581575f5f47058325af51129fd0d9d4497b1d9
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py
@@ -0,0 +1,479 @@
+# mypy: allow-untyped-defs
+
+from typing import cast
+
+import torch
+import torch.distributed as dist
+from torch._C._distributed_c10d import ReduceOp
+from torch.distributed._shard.sharded_tensor import ShardedTensor
+from torch.distributed._shard.sharding_spec import ChunkShardingSpec
+from torch.distributed._shard.sharding_spec.api import custom_sharding_spec_op
+from torch.distributed.nn.functional import all_gather, reduce_scatter
+
+from ._common import (
+    _all_gather_base_input,
+    _handle_col_wise_sharding_base,
+    _handle_max_norm_col_wise,
+    _handle_row_wise_mask,
+)
+
+
+@custom_sharding_spec_op(ChunkShardingSpec, torch.nn.functional.embedding_bag)
+def sharded_embedding_bag(types, args, kwargs, pg):
+    """
+    Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding_bag``.
+    This method computes a sharded embedding bag aggregation and has the following limitations:
+
+    1. Supports only sharding of ``weight``.
+    2. Supports only ``ChunkShardingSpec``.
+    3. Supports only a single local shard per rank.
+    4. Supports all specs except for scale_grad_by_freq, sparse, etc.
+
+    Based on the dimension that the weight is sharded on, there are two
+    algorithms:
+
+    ROWWISE SHARDING
+    ================
+    For row-wise sharding the weight is sharded on dimension 0.
+
+    The overall algorithm can be best explained with an example. Let's assume
+    the dims for input are (4 x 6) and W are (16 x 17) and W is sharded across
+    4 GPUs creating 4 shard of (4 x 17).
+    The algorithm is as follows:
+
+    1. First the input is all gathered to all ranks, since this is SPMD and
+       input is actually sharded across all ranks. The inputs then become a
+       4 (4 x 6) tensor on each rank. For example if the given input is
+       tensor([[6, 5, 2, 9, 6, 3],
+               [3, 1, 2, 4, 7, 6],
+               [4, 0, 4, 9, 8, 9],
+               [8, 6, 6, 4, 6, 1]])
+       on rank 0.
+       Then on every rank, we will have this tensor.
+       If input itself is already replicated, no all-gather will be done.
+    2. Next, we mask the ID which are not stored on that rank.
+       For example on rank 0, we store ID [0, 1, 2]. We only keep the ID
+       inside the set of numbers. The rest of them will be masked to an extra row.
+       The masked matrix will be used for embedding look up and is like:
+       tensor([[4, 4, 2, 4, 4, 4],
+               [4, 1, 2, 4, 4, 4],
+               [4, 0, 4, 4, 4, 4],
+               [4, 4, 4, 4, 4, 1]])
+    3. If ``max_norm`` is specified, the extra row guarantees that the mask ID will
+       not affect the behavior of weigh re-norm.
+    4. The example above only happens in one rank and each rank does a very similar thing.
+       For "Mean" mode we need to divide by either column size (2D) or the interval length
+       defined by the offset (excluding the row specified in ``padding_idx``).
+       We also need to mask the unexisting row to neg Inf so that negative value does not
+       gets wiped out in the "Max" mode.
+
+    COLWISE SHARDING
+    ================
+    For col-wise sharding the weight is sharded on dimension 1.
+
+    The overall algorithm can be best explained with an example. Let's assume
+    the dims for input are (4 x 6) and W are (16 x 17) and W is sharded across
+    4 GPUs creating 3 shards of (16 x 5) and 1 shard of (16 x 2).
+    The algorithm is as follows:
+
+    1. First the input is broadcasted to all ranks, since this is SPMD we
+       actually do an all_gather for all the inputs resulting in 4 (4 x 6)
+       inputs on each rank.
+    2. Next we perform local embedding bag operation under the given mode by
+       apply each input (4 x 6) with the local shard (16 x 5) ((16 x 2) for the last).
+       This results in 4 (5 x 4) ((2 x 4) for the last) matrices on each rank.
+       We transpose the aggregation result.
+    3. Next, we concatenate these 4 matrices and perform an all2all to share the
+       appropriate (5 x 4) or (2 x 4) matrices to each rank.
+    4. Now, each rank receives a (17 x 4) matrix which is basically the
+       size of the result we need.
+    5. If placements are not in order any appropriate rearrangement of columns
+       are done for the (17 x 4) matrix and finally we transpose the output again.
+    6. If max_norm is specified, we manually sum up the norm and renorm. Because
+       the renorm must be in place, we need to override the local_shard to mimic
+       this behavior.
+    """
+    # Validate input params
+    _validate_embedding_bag_param(args, kwargs)
+
+    input = args[0]
+    weight = args[1]
+    offsets = kwargs.get("offsets")
+    per_sample_weights = kwargs.get("per_sample_weights")
+    mode = kwargs.get("mode")
+    max_norm = kwargs.get("max_norm")
+    norm_type = kwargs.get("norm_type")
+    include_last_offset = kwargs.get("include_last_offset")
+    padding_idx = kwargs.get("padding_idx")
+
+    local_shard = weight.local_tensor().contiguous()
+    sharding_dim = weight._sharding_spec.dim
+    world_size = dist.get_world_size(pg)
+    rank = dist.get_rank(pg)
+    if include_last_offset:
+        offsets = offsets[:-1]
+
+    if sharding_dim == 1:
+        output, local_shard = _handle_col_wise_sharding(
+            input,
+            world_size,
+            weight,
+            local_shard,
+            offsets,
+            per_sample_weights,
+            mode,
+            max_norm,
+            norm_type,
+            padding_idx,
+            pg,
+        )
+        weight.local_shards()[0].tensor = local_shard
+        return output
+    elif sharding_dim == 0:
+        return _handle_row_wise_sharding(
+            input,
+            world_size,
+            weight,
+            local_shard,
+            offsets,
+            per_sample_weights,
+            mode,
+            max_norm,
+            norm_type,
+            padding_idx,
+            rank,
+            pg,
+        )
+    else:
+        raise RuntimeError(
+            f"nn.EmbeddingBag weight sharded on dim {sharding_dim} not supported!"
+        )
+
+
+def _validate_embedding_bag_param(args, kwargs):
+    """
+    Validate input params of sharded embeddingBag op.
+
+    Args:
+        input: list of ID used for lookup and aggregation.
+        weight: sharded weight tensor.
+        kwargs: same as normal EmbeddingBag.
+
+    Return: None.
+    """
+
+    input = args[0]
+    weight = args[1]
+    offsets = kwargs.get("offsets")
+    per_sample_weights = kwargs.get("per_sample_weights")
+    mode = kwargs.get("mode")
+    max_norm = kwargs.get("max_norm")
+    scale_grad_by_freq = kwargs.get("scale_grad_by_freq")
+    sparse = kwargs.get("sparse")
+    include_last_offset = kwargs.get("include_last_offset")
+
+    # Validate types
+    if not isinstance(input, torch.Tensor):
+        raise TypeError("input need to be torch.Tensor")
+    if offsets is not None and not isinstance(offsets, torch.Tensor):
+        raise TypeError("offsets need to be torch.Tensor")
+    if per_sample_weights is not None and not isinstance(
+        per_sample_weights, torch.Tensor
+    ):
+        raise TypeError("per_sample_weights need to be torch.Tensor")
+    if not isinstance(weight, ShardedTensor):
+        raise TypeError("weight needs to be ShardedTensor")
+    if len(input.size()) > 2:
+        raise ValueError("Input more than 2 dims not supported")
+    weight_size = weight.size()
+    if len(weight_size) != 2:
+        raise ValueError("Weight needs to have exactly 2 dims")
+    if int(torch.min(input).item()) < 0:
+        raise ValueError(
+            "Index out of range in Input %d %d",
+            int(torch.min(input).item()),
+            weight_size[1],
+        )
+    if int(torch.max(input).item()) >= weight_size[0]:
+        raise ValueError(
+            "Index out of range in Input %d %d",
+            int(torch.max(input).item()),
+            weight_size[1],
+        )
+    if offsets is not None and len(input.size()) != 1:
+        raise ValueError("Input dimension needs to be exactly 1 dim")
+    if len(input.size()) == 1 and offsets is None:
+        raise ValueError("offsets is required for 1D input")
+    if per_sample_weights is not None and per_sample_weights.size() != input.size():
+        raise ValueError(
+            f"per_sample_weights size {per_sample_weights.size()} not equal to input size {input.size()}"
+        )
+    if mode is None:
+        mode = "mean"
+    if mode not in ["sum", "mean", "max"]:
+        raise ValueError(f"mode '{mode}' is not supported")
+    if scale_grad_by_freq:
+        raise RuntimeError(
+            'nn.Embedding weight sharded with flag on "scale_grad_by_freq" not supported!'
+        )
+    if sparse:
+        raise RuntimeError(
+            'nn.Embedding weight sharded with flag on "sparse" not supported!'
+        )
+    if include_last_offset and offsets is None:
+        raise ValueError('offsets is required for flag "include_last_offset"!')
+    if include_last_offset and cast(list[int], offsets)[-1] != input.size(0):
+        raise ValueError(
+            'offsets need to have the input size in the end when the flag "include_last_offset" is on!'
+        )
+
+    if max_norm and max_norm <= 0.0:
+        raise ValueError('"max_norm" must be larger than zero!')
+
+    if not isinstance(weight._sharding_spec, ChunkShardingSpec):
+        raise ValueError("Only ChunkShardingSpec supported for ShardedTensor ops!")
+    if len(weight.local_shards()) != 1:
+        raise ValueError("Only one local shard supported!")
+
+
+def _handle_col_wise_sharding(
+    input,
+    world_size,
+    weight,
+    local_shard,
+    offsets,
+    per_sample_weights,
+    mode,
+    max_norm,
+    norm_type,
+    padding_idx,
+    pg,
+):
+    """
+    Entry-point function to handle the logic of col-wise sharding of weight
+    for embeddingBag. (Detailed explanations of the logic can be found in
+    the comment for sharded_embedding_bag.)
+
+    Args:
+        input: list of ID used for lookup and aggregation.
+        world_size: number of ranks.
+        weight: sharded weight tensor.
+        local_shard: col-wise shared local weight used for lookup.
+        offsets: list of start positions of each bag for 1D input.
+        per_sample_weights: weights for weighted sum mode.
+        mode: aggregation method of each bag.
+        max_norm: If given, each embedding vector with norm larger
+            than max_norm is renormalized to have norm max_norm.
+            Note: this will modify weight in-place.
+        norm_type: The p in the p-norm to compute for the max_norm option.
+        padding_idx: If specified, the entries at padding_idx do
+            not contribute to the gradient; therefore, the embedding
+            vector at padding_idx is not updated during training,
+            i.e. it remains as a fixed "pad".
+            Note that the embedding vector at padding_idx is
+            excluded from the reduction.
+        pg: process group.
+
+    Return:
+        output: final result of lookup and aggregation.
+        local_shard: col-wise shared local weight used for lookup.
+            If max_norm, this will be the renormed weight.
+    """
+    # allgather the special input of embedding bag first.
+    (
+        gathered_inputs,
+        gathered_per_sample_weights,
+        gathered_offsets,
+    ) = _all_gather_embedding_bag_input(input, per_sample_weights, offsets, pg)
+
+    if max_norm is not None:
+        # max_norm changes the weight in-place
+        local_shard = _handle_max_norm_col_wise(
+            max_norm, norm_type, local_shard, input, world_size, gathered_inputs, pg
+        )
+
+    output = _handle_col_wise_sharding_base(
+        torch.nn.functional.embedding_bag,
+        1,
+        input,
+        world_size,
+        weight,
+        local_shard,
+        pg,
+        gathered_inputs,
+        mode=mode,
+        gathered_per_sample_weights=gathered_per_sample_weights,
+        gathered_offsets=gathered_offsets,
+        padding_idx=padding_idx,
+    )
+    return (output, local_shard)
+
+
+def _handle_row_wise_sharding(
+    input,
+    world_size,
+    weight,
+    local_shard,
+    offsets,
+    per_sample_weights,
+    mode,
+    max_norm,
+    norm_type,
+    padding_idx,
+    rank,
+    pg,
+):
+    """
+    Entry-point function to handle the logic of row-wise sharding of weight
+    for embeddingBag. (Detailed explanations of the logic can be found in
+    the comment for sharded_embedding_bag.)
+
+    Args:
+        input: list of ID used for lookup and aggregation.
+        world_size: number of ranks.
+        weight: sharded weight tensor.
+        local_shard: row-wise shared local weight used for lookup.
+        offsets: list of start positions of each bag for 1D input.
+        per_sample_weights: weights for weighted sum mode.
+        mode: aggregation method of each bag.
+        max_norm: If given, each embedding vector with norm larger
+            than max_norm is renormalized to have norm max_norm.
+            Note: this will modify weight in-place.
+        norm_type: The p in the p-norm to compute for the max_norm option.
+        padding_idx: If specified, the entries at padding_idx do
+            not contribute to the gradient; therefore, the embedding
+            vector at padding_idx is not updated during training,
+            i.e. it remains as a fixed "pad".
+            Note that the embedding vector at padding_idx is
+            excluded from the reduction.
+        rank: # of cuda process.
+        pg: process group.
+
+    Returns:
+        gathered_output: final result of lookup and aggregation.
+    """
+    if input.dim() > 1 and per_sample_weights is None:
+        # allgather the inputs first for non Replicated Tensor.
+        gather_inp = _all_gather_base_input(input, pg)
+    else:
+        (
+            gathered_inputs,
+            gathered_per_sample_weights,
+            gathered_offsets,
+        ) = _all_gather_embedding_bag_input(input, per_sample_weights, offsets, pg)
+        cat_dim = 0 if input.dim() != 1 else -1
+        gather_inp = torch.cat(gathered_inputs, dim=cat_dim)
+        if per_sample_weights is not None:
+            per_sample_weights = torch.cat(gathered_per_sample_weights, dim=cat_dim)
+        offset_add = 0 if input.dim() > 1 else input.size(0)
+        if offsets is not None:
+            offsets_list = torch.cat(
+                [gathered_offsets[i] + (offset_add * i) for i in range(pg.size())],
+                dim=cat_dim,
+            )
+
+    # Mask the input according to sharding spec.
+    lookup_input, padding_local, padding_row = _handle_row_wise_mask(
+        gather_inp, padding_idx, weight, world_size, rank
+    )
+    if mode == "max":
+        padding_row[:] = -float("Inf")
+
+    # When input is a large tensor, the value of weight is changed.
+    # This is a walk-around for now. GH issue: #81717.
+    if max_norm is not None:
+        torch.nn.functional.embedding_bag(
+            torch.unique(lookup_input)[:-1],
+            local_shard,
+            offsets=torch.tensor([0], device=local_shard.device, dtype=torch.long),
+            mode=mode,
+            per_sample_weights=None,
+            max_norm=max_norm,
+            norm_type=norm_type,
+            padding_idx=padding_local,
+        )
+        max_norm = None
+    result = torch.nn.functional.embedding_bag(
+        lookup_input,
+        torch.cat([local_shard, padding_row]),
+        offsets=offsets_list if offsets is not None else offsets,  # type: ignore[possibly-undefined]
+        mode=mode if mode != "mean" else "sum",
+        per_sample_weights=per_sample_weights,
+        max_norm=max_norm,
+        norm_type=norm_type,
+        padding_idx=padding_local,
+    )
+
+    op = ReduceOp.SUM if mode != "max" else ReduceOp.MAX
+    # TODO: Make the result a PartialTensor and move the logic below there.
+    local_shards = result.chunk(pg.size())
+    result = reduce_scatter(
+        torch.empty_like(local_shards[0]),
+        list(local_shards),
+        op=op,
+        group=pg,
+    )
+
+    # For Mean, we cannot do the division until very end because the sum of means
+    # not equal to the mean of sum. (Divisor is different)
+    if mode == "mean":
+        if input.dim() > 1:
+            padding_idx = padding_idx if padding_idx is not None else -1
+            split_sizes = torch.sum(
+                torch.ne(input, padding_idx), dim=-1, dtype=local_shard.dtype
+            )
+        else:
+            split_sizes = torch.cat(
+                (
+                    # pyrefly: ignore [unsupported-operation]
+                    offsets[1 : offsets.size(0)] - offsets[0:-1],
+                    # pyrefly: ignore [unsupported-operation]
+                    (input.size(0) - offsets[-1]).unsqueeze(0),
+                ),
+                dim=-1,
+            )
+        return torch.div(result, split_sizes.unsqueeze(1))
+
+    # Return the appropriate local result.
+    return result
+
+
+def _all_gather_embedding_bag_input(input, per_sample_weights, offsets, pg):
+    """
+    In case we need to gather input and all other parameters of embeddingBag
+    ops, we need to stack all input together to perform ``all_gather``
+    collective communication just once.
+
+    Note that since offsets does not share the same size as input and
+    is always smaller than input, we resize it during the communication.
+
+    Args:
+        input: tensor to be applied op on.
+        per_sample_weights: weights for weighted sum mode.
+        offsets: when input is 1D. offsets determines the starting
+            index position of each bag (sequence) in input.
+        pg: process group.
+
+    Returns:
+        gathered_inputs: list of input tensor gathered from each rank.
+        gathered_per_sample_weights: list of per_sample_weights from each rank.
+        gathered_offsets: list of offsets from each rank.
+    """
+    input_to_gather = [input]
+    if per_sample_weights is not None:
+        input_to_gather.append(per_sample_weights)
+    if offsets is not None:
+        input_to_gather.append(offsets.clone().resize_(input.size()))
+    gathered_inputs = all_gather(torch.stack(input_to_gather), group=pg)
+
+    gathered_per_sample_weights = None
+    if per_sample_weights is not None:
+        gathered_per_sample_weights = [t[1] for t in gathered_inputs]
+    gathered_offsets = None
+    if offsets is not None:
+        idx = 2 if per_sample_weights is not None else 1
+        gathered_offsets = [
+            t[idx].resize_(offsets.size()).to(offsets.dtype) for t in gathered_inputs
+        ]
+    gathered_inputs = [t[0].to(input.dtype) for t in gathered_inputs]
+    return gathered_inputs, gathered_per_sample_weights, gathered_offsets
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_sharded_tensor/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_sharded_tensor/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5ee703a4c18d5ddf453cd45c1f7fcd79ef667e04
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_sharded_tensor/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d6fbd341813c0d97df4d2a7006f4f34cb657c478
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/__pycache__/join.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/__pycache__/join.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cea32f94a13812af216424b5e9b5944cbb7daf27
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/__pycache__/join.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_checkpoint/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_checkpoint/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_checkpoint/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_checkpoint/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..639a95a5f985bc22dcb4182e21d2df341ee9b9b5
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_checkpoint/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_checkpoint/__pycache__/checkpoint_wrapper.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_checkpoint/__pycache__/checkpoint_wrapper.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1f8e1e49cc91da282d7e7d58b57ffeb9426c9ef0
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_checkpoint/__pycache__/checkpoint_wrapper.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..081d397a9c1f11e332f95649d362e1f3c27abe8a
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py
@@ -0,0 +1,321 @@
+# mypy: allow-untyped-defs
+import warnings
+from abc import ABC, abstractmethod
+from collections.abc import Callable, Iterator
+from enum import auto, Enum
+from functools import partial
+from typing import Any
+
+import torch
+import torch.nn as nn
+from torch.autograd.graph import save_on_cpu
+from torch.distributed.utils import _pack_kwargs, _replace_by_prefix, _unpack_kwargs
+from torch.utils.checkpoint import checkpoint as torch_utils_checkpoint
+
+
+_CHECKPOINT_WRAPPED_MODULE = "_checkpoint_wrapped_module"
+_CHECKPOINT_PREFIX = _CHECKPOINT_WRAPPED_MODULE + "."
+
+
+class CheckpointImpl(Enum):
+    REENTRANT = auto()
+    NO_REENTRANT = auto()
+
+
+class ActivationWrapper(torch.nn.Module, ABC):
+    """
+    Base class for Activation Checkpoint and Activation Offload.
+
+    Not meant to be instantiated directly.
+    """
+
+    def __init__(self, mod):
+        super().__init__()
+        self._checkpoint_wrapped_module = mod
+        # state_dict post hook to remove prefix to allow loading into a
+        # non-checkpoint wrapped module.
+        self._register_state_dict_hook(self._post_state_dict_hook)
+        # load_state_dict pre-hook to allow loading back into
+        # checkpoint-wrapped module.
+        self.register_load_state_dict_pre_hook(self._pre_load_state_dict_hook)
+
+    @abstractmethod
+    def forward(self, *args, **kwargs):
+        raise ValueError("Subclasses should implement forward().")
+
+    def __getattr__(self, name: str) -> Any:
+        """Forward missing attributes to wrapped module."""
+        try:
+            return super().__getattr__(name)  # defer to nn.Module's logic
+        except AttributeError:
+            return getattr(self._checkpoint_wrapped_module, name)
+
+    def __getitem__(self, key: int) -> Any:
+        """Forward indexing calls in case the module is a nn.Sequential."""
+        return self._checkpoint_wrapped_module.__getitem__(key)  # type: ignore[operator]
+
+    def named_parameters(
+        self,
+        *args,
+        **kwargs,
+    ) -> Iterator[tuple[str, torch.nn.Parameter]]:
+        """
+        Override :meth:`named_parameters()` to intercept parameter names.
+
+        remove all occurrences of ``_CHECKPOINT_PREFIX``.
+        """
+        for param_name, param in super().named_parameters(*args, **kwargs):
+            yield param_name.replace(_CHECKPOINT_PREFIX, ""), param
+
+    @staticmethod
+    def _post_state_dict_hook(
+        module: nn.Module,
+        state_dict: dict[str, Any],
+        prefix: str,
+        *args: Any,
+    ) -> dict[str, Any]:
+        """
+        _post_state_dict_hook() is called after the state_dict() of this FSDP module is executed.
+
+        For ``checkpoint_wrapper``, it will strip checkpoint-wrapped module prefix,
+        so that this module can be loaded into non-checkpointed modules.
+        It would still be able to be loaded into checkpoint-wrapped modules as this class,
+        adds the prefix back before loading the state_dict.
+        """
+        _replace_by_prefix(state_dict, f"{prefix}{_CHECKPOINT_PREFIX}", prefix)
+        return state_dict
+
+    @staticmethod
+    def _pre_load_state_dict_hook(
+        module: nn.Module,
+        state_dict: dict[str, Any],
+        prefix: str,
+        *args: Any,
+    ) -> None:
+        """
+        ``_pre_state_dict_hook` is called before ``self._load_from_state_dict()`` is called.
+
+        For ``checkpoint_wrapper``, it will add back the module
+        prefix so that non-checkpointed modules can be loaded into
+        checkpoint_wrapper modules properly.
+        """
+        _replace_by_prefix(state_dict, prefix, prefix + f"{_CHECKPOINT_PREFIX}")
+
+
+class OffloadWrapper(ActivationWrapper):
+    def forward(self, *args, **kwargs):
+        with save_on_cpu(pin_memory=True):
+            return self._checkpoint_wrapped_module(*args, **kwargs)
+
+
+class CheckpointWrapper(ActivationWrapper):
+    """
+    An ``nn.Module`` that wraps another ``nn.Module`` with checkpointing.
+
+    Note that this module is not meant to be used directly but instead,
+    it is to be used through the ``checkpoint_wrapper`` function.
+    """
+
+    def __init__(
+        self,
+        mod: torch.nn.Module,
+        checkpoint_impl: CheckpointImpl = CheckpointImpl.NO_REENTRANT,
+        checkpoint_fn=None,
+        **checkpoint_fn_kwargs,
+    ):
+        super().__init__(mod)
+        self.checkpoint_impl = checkpoint_impl
+        if checkpoint_fn is None:
+            # use torch.utils.checkpoint
+            self.checkpoint_fn = partial(
+                torch_utils_checkpoint,
+                use_reentrant=(self.checkpoint_impl == CheckpointImpl.REENTRANT),
+                **checkpoint_fn_kwargs,
+            )
+        else:
+            # Construct user-specified checkpoint function.
+            self.checkpoint_fn = partial(
+                checkpoint_fn,
+                **checkpoint_fn_kwargs,
+            )
+
+    def forward(self, *args, **kwargs):
+        # Support keyword arguments for reentrant checkpoint. Note that this
+        # only works if user has specified self.checkpoint_impl and is not
+        # using their own custom checkpoint_fn.
+        if self.checkpoint_impl == CheckpointImpl.REENTRANT and kwargs != {}:
+            # Pack the args and kwargs
+            flat_args, kwarg_keys = _pack_kwargs(*args, **kwargs)
+
+            # Function that only takes (packed) args, but can unpack them
+            # into the original args and kwargs for the checkpointed
+            # function, and runs that function.
+            def my_function(*inputs):
+                # unpack back into args and kwargs
+                unpacked_args, unpacked_kwargs = _unpack_kwargs(inputs, kwarg_keys)
+                # run original module
+                return self._checkpoint_wrapped_module(
+                    *unpacked_args, **unpacked_kwargs
+                )
+
+            # Pass the function that only takes packed args into reentrant
+            # checkpoint API.
+            return self.checkpoint_fn(  # type: ignore[misc]
+                my_function,
+                *flat_args,
+            )
+        else:
+            return self.checkpoint_fn(  # type: ignore[misc]
+                self._checkpoint_wrapped_module, *args, **kwargs
+            )
+
+
+def offload_wrapper(module: torch.nn.Module) -> torch.nn.Module:
+    """
+    Wrap a module for activation offloading to CPU.
+
+    Offloads intermediate activations to the CPU for modules wrapped with this function.
+    Wrappers with activation offload can be composed with ones that do recomputation-based
+    checkpoint to trade off increased compute versus increased CPU
+    memory usage and additional H2D transfers.
+
+    Usage::
+        offloaded_module = offload_wrapper(module)
+        outputs = checkpointed_module(inputs)
+    Args:
+        module (nn.Module):
+            The module to be wrapped
+    Returns:
+        (nn.Module):
+            Wrapped module
+    """
+    return OffloadWrapper(module)
+
+
+def checkpoint_wrapper(
+    module: torch.nn.Module,
+    checkpoint_impl: CheckpointImpl = CheckpointImpl.NO_REENTRANT,
+    checkpoint_fn=None,
+    **checkpoint_fn_kwargs,
+) -> torch.nn.Module:
+    """
+    Wrap a module for activation checkpointing.
+
+    If the module is wrapped with this function, all subsequent calls to the module will,
+    automatically perform checkpointing without the user having to explicitly call ``checkpoint`` function.
+
+    Usage::
+        checkpointed_module = checkpoint_wrapper(module)
+        outputs = checkpointed_module(inputs)
+    Args:
+        module (nn.Module):
+            The module to be wrapped
+        checkpoint_impl (Optional[CheckpointImpl]):
+            The checkpointing implementation to use. Note that this will only
+            be passed into the ``torch.utils.checkpoint.checkpoint``
+            implementation, and is ignored if a custom ``checkpoint_fn`` is
+            specified. Note that for implementations using reentrant checkpoint
+            from ``torch.utils.checkpoint``, keyword arguments will only be
+            supported if ``checkpoint_impl`` is passed as ``CheckpointImpl.REENTRANT`.
+        checkpoint_fn (Optional[Callable]):
+            Functional checkpoint implementation to use. If this is specified,
+            it will be used over the default ``torch.utils.checkpoint.checkpoint``
+            implementation and the `checkpoint_impl` argument will be ignored.
+        **checkpoint_fn_kwargs: (Dict[str, Any]): Keyword arguments to pass into `checkpoint_fn`.
+
+    Returns:
+        (nn.Module):
+            Wrapped module
+    """
+
+    if checkpoint_impl == CheckpointImpl.REENTRANT:
+        warnings.warn(
+            f"Please specify {CheckpointImpl.NO_REENTRANT} as "
+            f"{CheckpointImpl.REENTRANT} will soon be removed as "
+            "the default and eventually deprecated.",
+            FutureWarning,
+            stacklevel=2,
+        )
+    return CheckpointWrapper(
+        module,
+        checkpoint_impl,
+        checkpoint_fn,
+        **checkpoint_fn_kwargs,
+    )
+
+
+def apply_activation_checkpointing(
+    model,
+    checkpoint_wrapper_fn=checkpoint_wrapper,
+    check_fn=lambda _: True,
+    auto_wrap_policy: Callable[[nn.Module, bool, int], bool] | None = None,
+):
+    """
+    Apply :func:`checkpoint_wrapper` to modules within `model` based on a user-defined configuration.
+
+    For each module within `model`, the `check_fn` is used to decide
+    whether `module` should be wrapped with :func:`checkpoint_wrapper` or not.
+
+    Note::
+        This function modifies `model` in place and replaces appropriate layers with
+        their checkpoint-wrapped modules.
+    Note::
+        This function will not wrap the overall root module. If this is needed, please directly use
+        :func:`checkpoint_wrapper` or :func:`offload_wrapper`.
+    Usage::
+        model = nn.Sequential(
+            nn.Linear(10, 10), nn.Linear(10, 10), nn.Linear(10, 10)
+        )
+        check_fn = lambda l: isinstance(l, nn.Linear)
+        # checkpoint activations
+        apply_activation_checkpointing(model, checkpoint_wrapper_fn=checkpoint_wrapper, check_fn=check_fn)
+        # Or offload activations to CPU
+        apply_activation_checkpointing(model, checkpoint_wrapper_fn=offload_wrapper, check_fn=check_fn)
+    Args:
+        model (nn.Module):
+            The model whose submodules should be wrapped with activation checkpointing.
+        checkpoint_wrapper_fn (Optional[Callable[nn.Module]])
+            A ``Callable`` which will wrap modules
+        check_fn (Optional[Callable[nn.Module, nn.Module]])
+            A lambda function which will be passed each child submodule of ``model`` and returns
+            ``True`` or ``False`` depending on whether the submodule should be wrapped.
+        auto_wrap_policy (Optional[Callable[[nn.Module, bool, int], bool]]): A policy to wrap model's
+            submodules with AC. Note that if this is specified, it takes precedence over ``check_fn``.
+    Returns: None (`model` is modified inplace)
+    """
+    # TODO: Importing inside function to avoid circular import issue between FSDP and
+    # checkpoint_wrapper. This can be resolved once wrap() APIs are decoupled from FSDP code.
+    from torch.distributed.fsdp._wrap_utils import _construct_wrap_fn, _post_order_apply
+    from torch.distributed.fsdp.wrap import (
+        _Policy,
+        _recursive_wrap,
+        lambda_auto_wrap_policy,
+    )
+
+    policy = (
+        auto_wrap_policy
+        if auto_wrap_policy is not None
+        else partial(lambda_auto_wrap_policy, lambda_fn=check_fn)
+    )
+    if not callable(policy):
+        if not isinstance(policy, _Policy):
+            raise ValueError(
+                f"Expected {policy} to be callable or be a pre-defined wrap policy"
+            )
+        target_module_to_kwargs = policy._run_policy(
+            model, ignored_modules=set(), root_kwargs={}
+        )
+        wrap_fn = _construct_wrap_fn(
+            model, target_module_to_kwargs, checkpoint_wrapper_fn
+        )
+        _post_order_apply(model, wrap_fn)
+        return
+
+    _recursive_wrap(
+        module=model,
+        auto_wrap_policy=policy,  # type: ignore[arg-type]
+        wrapper_cls=checkpoint_wrapper_fn,
+        ignored_modules=set(),
+        ignored_params=set(),
+        only_wrap_children=True,
+    )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_comm_hooks/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_comm_hooks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b57a075ad729d0ae3004dc15585250b04810f43
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_comm_hooks/__init__.py
@@ -0,0 +1,7 @@
+from . import default_hooks as default
+
+
+LOW_PRECISION_HOOKS = [
+    default.fp16_compress_hook,
+    default.bf16_compress_hook,
+]
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_comm_hooks/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_comm_hooks/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c614d7e3886f1d8b755c57b46bcb6e96e1f56ec0
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_comm_hooks/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_comm_hooks/__pycache__/default_hooks.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_comm_hooks/__pycache__/default_hooks.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4bf95a83ffae47dc87a76e1312245e9823ea4cbf
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_comm_hooks/__pycache__/default_hooks.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_comm_hooks/default_hooks.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_comm_hooks/default_hooks.py
new file mode 100644
index 0000000000000000000000000000000000000000..76cd01c2265b1d7e5739d79b406cb94a0b0a9893
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_comm_hooks/default_hooks.py
@@ -0,0 +1,191 @@
+# mypy: allow-untyped-defs
+import functools
+
+import torch
+import torch.distributed as dist
+
+
+class DefaultState:
+    r"""
+    Stores state needed to perform the default communication algorithm within a communication hook.
+
+    Args:
+        process_group (ProcessGroup): The process group to be used.
+    """
+
+    __slots__ = [
+        "process_group",
+        "world_size",
+        "gradient_predivide_factor",
+        "gradient_postdivide_factor",
+    ]
+
+    def __init__(self, process_group: dist.ProcessGroup):
+        if process_group is None:
+            raise ValueError(f"Expected to pass in an explicit ProcessGroup to {self}.")
+        self.process_group = process_group
+        self.world_size = dist.get_world_size(process_group)
+        # Setting two factors `self.gradient_predivide_factor`
+        # and `self.gradient_postdivide_factor` to avoid underflow and overflow
+        self.gradient_predivide_factor = self._get_gradient_predivide_factor(
+            self.world_size
+        )
+        self.gradient_postdivide_factor = (
+            self.world_size / self.gradient_predivide_factor
+        )
+
+    @staticmethod
+    def _get_gradient_predivide_factor(world_size: int) -> float:
+        factor: int = 1
+        while world_size % factor == 0 and world_size / factor > factor:
+            factor *= 2
+        return float(factor)
+
+
+class LowPrecisionState(DefaultState):
+    r"""
+    Stores state needed to perform gradient communication in a lower precision within a communication hook.
+
+    Communication hook will cast gradients back to the original
+    parameter precision specified by ``parameter_type`` (default: torch.float32).
+    Builds on top of the :class:`DefaultState`.
+
+    Args:
+        parameter_type (torch.dtype): The precision of model's parameters.
+        Required for a hook to cast gradients back to a parameter's precision.
+    """
+
+    __slots__ = [
+        "parameter_type",
+    ]
+
+    def __init__(
+        self,
+        process_group,
+        parameter_type=torch.float32,
+    ):
+        super().__init__(process_group)
+        self.parameter_type = parameter_type
+
+
+def _decompress(state: LowPrecisionState, grad: torch.Tensor):
+    """
+    Casts gradients back to full parameter precision so that further computation happens in full precision.
+    """
+    orig_grad_data = grad.data
+    grad.data = grad.data.to(state.parameter_type)
+    device_type = ""
+    try:
+        if grad.device.type == "privateuse1":
+            device_type = torch._C._get_privateuse1_backend_name()
+        else:
+            device_type = grad.device.type
+        backend = getattr(torch, device_type)
+    except AttributeError as e:
+        raise AttributeError(
+            f"Device {grad.device}  does not have a \
+                corresponding backend registered as 'torch.device_type'."
+        ) from e
+
+    # Don't let this memory get reused until after the transfer.
+    orig_grad_data.record_stream(backend.current_stream())  # type: ignore[arg-type]
+
+
+def allreduce_hook(state: DefaultState, grad: torch.Tensor):
+    r"""
+    Implement the  FSDP communication hook for ``all_reduce`` algorithm and a necessary pre- and post-division of gradients.
+
+    Args:
+        state (DefaultState): State information, configures pre- and post-division factors.
+        grad (torch.Tensor): A gradient for the local batch that needs to be communicated across ranks.
+    """
+    # Average grad by pre-division factor. Together pre- and post-division factors
+    # lead to an overall averaging by world_size, required for consistency with PyTorch DDP.
+    # This is a two-step process to avoid potential underflow and overflow.
+    if state.gradient_predivide_factor > 1:
+        grad.div_(state.gradient_predivide_factor)
+    dist.all_reduce(grad, group=state.process_group)
+    # Average grad by post-division factor.
+    if state.gradient_postdivide_factor > 1:
+        grad.div_(state.gradient_postdivide_factor)
+
+
+def reduce_scatter_hook(state: DefaultState, grad: torch.Tensor, output: torch.Tensor):
+    r"""
+    Implement the  FSDP communication hook for ``reduce_scatter`` algorithm.
+
+    For sharded FSDP strategies and a necessary pre- and post-division of gradients.
+
+    Args:
+        state (DefaultState): State information, configures pre- and post-division factors.
+        grad (torch.Tensor): An unsharded gradient for the local batch that needs to be
+        communicated across ranks.
+        output (torch.Tensor): Stores a single shard of the gradient after ``reduce_scatter``.
+    """
+    # Average grad by pre-division factor.
+    if state.gradient_predivide_factor > 1:
+        grad.div_(state.gradient_predivide_factor)
+    dist.reduce_scatter_tensor(output, grad, group=state.process_group)
+    # Average grad's shard by post-division factor.
+    if state.gradient_postdivide_factor > 1:
+        output.div_(state.gradient_postdivide_factor)
+
+
+def _low_precision_hook(
+    prec: torch.dtype,
+    state: LowPrecisionState,
+    grad: torch.Tensor,
+    output: torch.Tensor | None,
+):
+    if grad.dtype != prec:
+        grad.data = grad.data.to(prec)
+    if output is not None:
+        if output.dtype != prec:
+            output.data = output.data.to(prec)
+        reduce_scatter_hook(state, grad, output)
+        _decompress(state, output)
+    else:
+        allreduce_hook(state, grad)
+        _decompress(state, grad)
+
+
+def fp16_compress_hook(
+    state: LowPrecisionState, grad: torch.Tensor, output: torch.Tensor | None = None
+):
+    r"""
+    Implement FSDP communication hook for a simple gradient compression approach.
+    Casts ``grad`` to half-precision floating-point format (``torch.float16``).
+
+    It also averages gradients by ``world_size`` in two steps: first it pre-divides gradients by a
+    ``state.gradient_predivide_factor``, and after a communication step (``all_reduce`` or ``reduce_scatter``)
+    gradients are averaged by a ``state.gradient_postdivide_factor``.
+    Once post-division is done, compressed gradients are casted back to parameters' precision.
+
+    Args:
+        state (LowPrecisionState): State information, configures pre- and post-division factors, parameters' precision.
+        grad (torch.Tensor): A gradient for the local batch that needs to be communicated across ranks in a lower precision.
+        output (torch.Tensor): Stores a single shard of the gradient after ``reduce_scatter``.
+    """
+    fp16_hook = functools.partial(_low_precision_hook, torch.float16)
+    return fp16_hook(state, grad, output)
+
+
+def bf16_compress_hook(
+    state: LowPrecisionState, grad: torch.Tensor, output: torch.Tensor | None = None
+):
+    r"""
+    Implement FSDP communication hook for a simple gradient compression approach .
+    Casts ``grad`` to half-precision floating-point format.
+
+    It also averages gradients by ``world_size`` in two steps: first it pre-divides gradients by a
+    ``state.gradient_predivide_factor``, and after a communication step (``all_reduce`` or ``reduce_scatter``)
+    gradients are averaged by a ``state.gradient_postdivide_factor``.
+    Once post-division is done, compressed gradients are casted back to parameters' precision.
+
+    Args:
+        state (LowPrecisionState): State information, configures pre- and post-division factors, parameters' precision.
+        grad (torch.Tensor): A gradient for the local batch that needs to be communicated across ranks in a lower precision.
+        output (torch.Tensor): Stores a single shard of the gradient after ``reduce_scatter``.
+    """
+    bf16_hook = functools.partial(_low_precision_hook, torch.bfloat16)
+    return bf16_hook(state, grad, output)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_optimizer_overlap/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_optimizer_overlap/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba62bfb68f42a136dcfa27bcf378d3892cf6751a
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_optimizer_overlap/__init__.py
@@ -0,0 +1 @@
+from .optimizer_overlap import _as_overlapped_optim
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_optimizer_overlap/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_optimizer_overlap/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b7f5e69c94d3627d1d3076726b33fd299ed3d142
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_optimizer_overlap/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_optimizer_overlap/__pycache__/optimizer_overlap.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_optimizer_overlap/__pycache__/optimizer_overlap.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5ff9f2927d7800ecfcf03959a3f75041c3b531e2
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_optimizer_overlap/__pycache__/optimizer_overlap.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_optimizer_overlap/optimizer_overlap.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_optimizer_overlap/optimizer_overlap.py
new file mode 100644
index 0000000000000000000000000000000000000000..569a42ffe7643bb6b6403dfb323a4dfd28493e1b
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_optimizer_overlap/optimizer_overlap.py
@@ -0,0 +1,96 @@
+# mypy: allow-untyped-defs
+import inspect
+from abc import ABC, abstractmethod
+
+from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import allreduce_hook
+from torch.distributed.algorithms.ddp_comm_hooks.optimizer_overlap_hooks import (
+    _hook_then_optimizer,
+    _OptimizerHookState,
+)
+from torch.distributed.fsdp import FullyShardedDataParallel
+from torch.distributed.optim import as_functional_optim
+from torch.nn.parallel import DistributedDataParallel
+from torch.optim import Optimizer
+
+
+# Contains the mappings between the regular and overlapped optimizer types.
+_registered_overlapped_optims: dict[type, type] = {}
+
+
+def register_overlapped(optim_cls):
+    def decorator(target_overlapped_optim_cls):
+        if target_overlapped_optim_cls in _registered_overlapped_optims:
+            raise ValueError(
+                f"{target_overlapped_optim_cls} already registered with optim_cls "
+                f"{_registered_overlapped_optims[optim_cls]} {optim_cls}, trying to"
+                f"re-register it for {optim_cls} is not supported."
+            )
+        _registered_overlapped_optims[optim_cls] = target_overlapped_optim_cls
+        return target_overlapped_optim_cls
+
+    return decorator
+
+
+class OverlappedOptimizer(ABC):
+    def __init__(self, optim_cls: type) -> None:
+        """
+        Initialize the OverlappedOptimizer.
+
+        Overlappedoptimizer is a base class that child classes can implement to
+        specify how different optimizers will register themselves with DDP.
+        """
+        self.optim_cls = optim_cls
+
+    @abstractmethod
+    def register_ddp(self, ddp: DistributedDataParallel) -> None:
+        """Registers the overlapped optimizer with DDP."""
+        raise NotImplementedError(
+            f"{self.__class__.__name__} does not support overlapped DDP."
+        )
+
+    @abstractmethod
+    def register_fsdp(self, fsdp: FullyShardedDataParallel) -> None:
+        """Registers the overlapped optimizer with FSDP."""
+        raise NotImplementedError(
+            f"{self.__class__.__name__} does not support overlapped FSDP."
+        )
+
+
+@register_overlapped(Optimizer)
+class _OverlappedStandardOptimizer(OverlappedOptimizer):
+    """Overlaps a regular ``Optimizer``."""
+
+    def __init__(self, optim_cls: type, params, *optim_args, **optim_kwargs) -> None:
+        super().__init__(optim_cls)
+        f_optim = as_functional_optim(self.optim_cls, *optim_args, **optim_kwargs)
+        self._opt_hook_state = _OptimizerHookState(f_optim, params)
+
+    def register_ddp(self, ddp_inst: DistributedDataParallel):
+        # NOTE: using a custom communication hook and fused optimizer is not
+        # yet supported.
+        ddp_inst.register_comm_hook(  # type: ignore[operator]
+            None,  # wrapped hook state
+            _hook_then_optimizer(allreduce_hook, self._opt_hook_state),
+        )
+
+    # TODO: register_fsdp once FSDP supports communication hook.
+    def register_fsdp(self, fsdp: FullyShardedDataParallel) -> None:
+        """Register the overlapped optimizer with FSDP."""
+        raise NotImplementedError(
+            f"{self.__class__.__name__} does not support overlapped FSDP."
+        )
+
+
+def _as_overlapped_optim(optim_cls: type, params, *args, **kwargs):
+    """Return a new ``OverlappedOptimizer`` instance that supports ``optim_cls``."""
+    for clz in inspect.getmro(optim_cls):
+        try:
+            return _registered_overlapped_optims[clz](
+                optim_cls, params, *args, **kwargs
+            )
+        except KeyError:
+            pass
+
+    # Fallback to standard overlapped optimizer, which will raise errors if user
+    # is attempting to use an unsupported optimizer.
+    return _OverlappedStandardOptimizer(optim_cls, params, *args, **kwargs)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_quantization/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_quantization/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_quantization/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_quantization/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b92c819fd5144035ae84bf95d70a57924e044623
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_quantization/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_quantization/__pycache__/quantization.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_quantization/__pycache__/quantization.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c19ce25faecf59c49647072074f3acccffc2ec5e
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_quantization/__pycache__/quantization.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_quantization/quantization.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_quantization/quantization.py
new file mode 100644
index 0000000000000000000000000000000000000000..69d88604561355b344b43129108d276e398e0f9f
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/_quantization/quantization.py
@@ -0,0 +1,151 @@
+# mypy: allow-untyped-defs
+import functools
+from enum import Enum
+
+import torch
+import torch.distributed as dist
+
+
+TORCH_HALF_MIN = torch.finfo(torch.float16).min
+TORCH_HALF_MAX = torch.finfo(torch.float16).max
+
+
+class DQuantType(Enum):
+    """
+    Different quantization methods for auto_quantize API are identified here.
+
+    auto_quantize API currently supports fp16 and bfp16 methods.
+    """
+
+    FP16 = ("fp16",)
+    BFP16 = "bfp16"
+
+    def __str__(self) -> str:
+        return self.value
+
+
+def _fp32_to_fp16_with_clamp(tensor: torch.Tensor) -> torch.Tensor:
+    return torch.clamp(tensor, TORCH_HALF_MIN, TORCH_HALF_MAX).half()
+
+
+def _quantize_tensor(tensor, qtype):
+    if not isinstance(tensor, torch.Tensor):
+        raise RuntimeError(
+            f"_quantize_tensor expecting torch.Tensor as input but found {type(tensor)}"
+        )
+    if qtype == DQuantType.FP16:
+        return _fp32_to_fp16_with_clamp(tensor)
+    elif qtype == DQuantType.BFP16:
+        return torch.ops.quantization._FloatToBfloat16Quantized(tensor)
+    else:
+        raise RuntimeError(f"Quantization type {qtype} is not supported")
+
+
+def _quantize_tensor_list(tensor_list, qtype):
+    if not isinstance(tensor_list, list) or not all(
+        isinstance(p, torch.Tensor) for p in tensor_list
+    ):
+        raise RuntimeError(
+            f"_quantize_tensor_list expecting list of torch.Tensor as input but found {type(tensor_list)}"
+        )
+    quantized_tensor_list = [_quantize_tensor(t, qtype) for t in tensor_list]
+    return quantized_tensor_list
+
+
+def _dequantize_tensor(tensor, qtype, quant_loss=None):
+    if not isinstance(tensor, torch.Tensor):
+        raise RuntimeError(
+            f"_dequantize_tensor expecting torch.Tensor as input but found {type(tensor)}"
+        )
+    if qtype == DQuantType.FP16:
+        if tensor.dtype != torch.float16:
+            raise RuntimeError(
+                f"tensor dtype is {tensor.dtype} while expected to be FP16."
+            )
+        elif tensor.dtype == torch.float16 and quant_loss is None:
+            return tensor.float()
+        else:
+            # pyrefly: ignore [unsupported-operation]
+            return tensor.float() / quant_loss
+    elif qtype == DQuantType.BFP16:
+        if tensor.dtype != torch.float16:
+            raise RuntimeError(
+                f"tensor dtype is {tensor.dtype} while expected to be FP16."
+            )
+        else:
+            return torch.ops.quantization._Bfloat16QuantizedToFloat(tensor)
+    else:
+        raise RuntimeError(f"Quantization type {qtype} is not supported")
+
+
+def _dequantize_tensor_list(tensor_list, qtype, quant_loss=None):
+    if not isinstance(tensor_list, list) or not all(
+        isinstance(p, torch.Tensor) for p in tensor_list
+    ):
+        raise RuntimeError(
+            f"_dequantize_tensor_list expecting list of torch.Tensor as input but found {type(tensor_list)}"
+        )
+    dequantized_tensor_list = [_dequantize_tensor(t, qtype) for t in tensor_list]
+    return dequantized_tensor_list
+
+
+def auto_quantize(func, qtype, quant_loss=None):
+    """
+    Quantize the input tensors, choose the precision types, and pass other necessary arguments and then dequantizes the output.
+
+    Currently it only supports:
+        . FP16 and BFP16 quantization method supported for gloo and nccl backends
+        . all_gather, all_to_all collective ops
+    Note: BFP16 only supports 2D tensors.
+    Args:
+        func (Callable): A function representing collective operations.
+        qtype (QuantType): Quantization method
+        quant_loss (float, optional): This can be used to improve accuracy in the dequantization.
+    Returns:
+        (Callable): the same collective as func but enables automatic quantization/dequantization.
+    """
+
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        group = kwargs.get("group")
+        async_op = kwargs.get("async_op", False)
+        if async_op is True:
+            raise RuntimeError("The async_op=True mode is not supported yet.")
+        if func is dist.all_gather:
+            tensors = args[0]
+            input_tensors = _quantize_tensor(args[1], qtype)
+            out_tensors = _quantize_tensor_list(tensors, qtype)
+            dist.all_gather(out_tensors, input_tensors, group=group, async_op=async_op)
+            for i, t in enumerate(
+                _dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss)
+            ):
+                tensors[i] = t
+
+        elif func is dist.all_to_all:
+            tensors = args[0]
+            input_tensors = _quantize_tensor_list(args[1], qtype)
+            out_tensors = _quantize_tensor_list(tensors, qtype)
+            dist.all_to_all(out_tensors, input_tensors, group=group, async_op=async_op)
+            for i, t in enumerate(
+                _dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss)
+            ):
+                tensors[i] = t
+
+        elif func is dist.all_to_all_single:
+            tensors = args[0]
+            out_splits = kwargs.get("out_splits")
+            in_splits = kwargs.get("in_splits")
+            # Quantizing the input/output tensor
+            input_tensors = _quantize_tensor(args[1], qtype)
+            out_tensors = _quantize_tensor(tensors, qtype)
+            dist.all_to_all_single(
+                out_tensors, input_tensors, out_splits, in_splits, group=group
+            )
+            for i, t in enumerate(
+                _dequantize_tensor(out_tensors, qtype, quant_loss=quant_loss)
+            ):
+                tensors[i] = t
+        else:
+            raise RuntimeError(f"The collective op {func} is not supported yet")
+
+    return wrapper
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9cc6d12785cc760ae77039d5403bd36c94fcdb8
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__init__.py
@@ -0,0 +1,140 @@
+# mypy: allow-untyped-defs
+import sys
+from enum import Enum
+from functools import partial
+
+
+# To suppress FutureWarning from partial since 3.13
+if sys.version_info >= (3, 13):
+    from enum import member
+
+    def _enum_member(x):
+        return member(x)
+else:
+
+    def _enum_member(x):
+        return x
+
+
+import torch.distributed as dist
+
+from . import (
+    debugging_hooks as debugging,
+    default_hooks as default,
+    optimizer_overlap_hooks as optimizer_overlap,
+    powerSGD_hook as powerSGD,
+    quantization_hooks as quantization,
+)
+
+
+__all__ = ["DDPCommHookType", "register_ddp_comm_hook"]
+
+
+def _ddp_comm_hook_wrapper(comm_hook, model, state):
+    model.register_comm_hook(state, comm_hook)
+
+
+def _powerSGD_comm_hook_wrapper(
+    comm_hook,
+    model,
+    state,
+    matrix_approximation_rank,
+    start_powerSGD_iter=1_000,
+):
+    """
+    Wrap PowerSGD communication hook.
+
+    To be consistent with the wrappers of other DDP comm hooks, the input state only needs to be a process group,
+    which will be wrapped up with other state info.
+    """
+    powerSGD_state = powerSGD.PowerSGDState(
+        process_group=state,
+        matrix_approximation_rank=matrix_approximation_rank,
+        start_powerSGD_iter=start_powerSGD_iter,
+    )
+    model.register_comm_hook(powerSGD_state, comm_hook)
+
+
+class DDPCommHookType(Enum):
+    """
+    Enumerate ``ddp_comm_hooks`` and ``ddp_comm_hook_wrapper`` communucation hook types.
+
+    DDPCommHookType enumerates the hooks of ``torch.distributed.algorithms.ddp_comm_hooks``
+    as names and ``ddp_comm_hook_wrapper`` partials with hook specified. As an example,
+    you can register allreduce hook by
+    ``DDPCommHookType.ALLREDUCE.value(model=model, state=process_group)``.
+    """
+
+    ALLREDUCE = _enum_member(
+        partial(_ddp_comm_hook_wrapper, comm_hook=default.allreduce_hook)
+    )
+    FP16_COMPRESS = _enum_member(
+        partial(_ddp_comm_hook_wrapper, comm_hook=default.fp16_compress_hook)
+    )
+    BF16_COMPRESS = _enum_member(
+        partial(_ddp_comm_hook_wrapper, comm_hook=default.bf16_compress_hook)
+    )
+    QUANTIZE_PER_TENSOR = _enum_member(
+        partial(
+            _ddp_comm_hook_wrapper, comm_hook=quantization.quantization_pertensor_hook
+        )
+    )
+    QUANTIZE_PER_CHANNEL = _enum_member(
+        partial(
+            _ddp_comm_hook_wrapper, comm_hook=quantization.quantization_perchannel_hook
+        )
+    )
+    POWER_SGD = _enum_member(
+        partial(
+            _powerSGD_comm_hook_wrapper,
+            comm_hook=powerSGD.powerSGD_hook,
+            matrix_approximation_rank=1,
+        )
+    )
+    # Rank-2 PowerSGD can give a higher accuracy than the default rank-1 version,
+    # but it runs slower and consumes more memory.
+    POWER_SGD_RANK2 = _enum_member(
+        partial(
+            _powerSGD_comm_hook_wrapper,
+            comm_hook=powerSGD.powerSGD_hook,
+            matrix_approximation_rank=2,
+        )
+    )
+    # Batching can lead to a faster training at the cost of accuracy.
+    BATCHED_POWER_SGD = _enum_member(
+        partial(
+            _powerSGD_comm_hook_wrapper,
+            comm_hook=powerSGD.batched_powerSGD_hook,
+            matrix_approximation_rank=1,
+        )
+    )
+    BATCHED_POWER_SGD_RANK2 = _enum_member(
+        partial(
+            _powerSGD_comm_hook_wrapper,
+            comm_hook=powerSGD.batched_powerSGD_hook,
+            matrix_approximation_rank=2,
+        )
+    )
+    NOOP = _enum_member(
+        partial(
+            _ddp_comm_hook_wrapper,
+            comm_hook=debugging.noop_hook,
+        )
+    )
+
+
+def register_ddp_comm_hook(comm_hook_type: DDPCommHookType, model, state=None):
+    """
+    Register ``ddp_comm_hooks`` to DDP model.
+
+    Registers the hooks of ``torch.distributed.algorithms.ddp_comm_hooks``
+    to the DDP model. User can specify the type of hook as an enum
+    ``DDPCommHookType`` type using ``comm_hook_type`` input. State input will
+    be passed to the model.
+    Uses Python comm hook implementations.
+
+    Example::
+        >>> # xdoctest: +SKIP
+        >>> register_ddp_comm_hook(DDPCommHookType.FP16_COMPRESS, model, state)
+    """
+    comm_hook_type.value(model=model, state=state)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b1c2378c4ca979ec9d21378b3dcd9b1f3de17b56
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/ddp_zero_hook.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/ddp_zero_hook.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..030096b40c8978f6328bb58b41304b607f57037b
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/ddp_zero_hook.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/debugging_hooks.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/debugging_hooks.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fcac585f84173b24db8e45cdfdef03f5a7625577
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/debugging_hooks.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/default_hooks.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/default_hooks.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f2ed1cc64995e0d3440f5dcace20332f8cc9f28b
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/default_hooks.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/mixed_precision_hooks.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/mixed_precision_hooks.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bc7659cff2e94bd5239aee8d520aeb024b7a02c1
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/mixed_precision_hooks.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/optimizer_overlap_hooks.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/optimizer_overlap_hooks.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8e8ab8d568002abbfa43f9a3d398dc3e3bcbcdd0
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/optimizer_overlap_hooks.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/post_localSGD_hook.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/post_localSGD_hook.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a6384c302cd5c648a63a43777d768a6efb75a894
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/post_localSGD_hook.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/powerSGD_hook.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/powerSGD_hook.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f7144f3ce8bfcb2d31e04ee9d0dc57a6ab862c7e
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/powerSGD_hook.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/quantization_hooks.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/quantization_hooks.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a5d1c94acd9ab392f961b2ca33817ee7de8e2458
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/quantization_hooks.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa8c865c89151033b379d6cd4785fd15e002cd66
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py
@@ -0,0 +1,457 @@
+# mypy: allow-untyped-defs
+import weakref
+from collections.abc import Callable
+from typing import Any
+
+import torch
+import torch.distributed as dist
+from torch.distributed.optim import ZeroRedundancyOptimizer
+from torch.distributed.optim.zero_redundancy_optimizer import _OverlapStatus
+from torch.nn.parallel.distributed import DistributedDataParallel
+
+
+__all__ = ["hook_with_zero_step", "hook_with_zero_step_interleaved"]
+
+# Functional optimizers require passing a list of gradients to their `step()`
+# method, and ZeRO requires a functional optimizer to overlap with DDP
+# Passing a `None` instead of an actual gradient indicates to the optimizer
+# to not update the corresponding parameter
+_NO_PARAM_UPDATE: None = None
+
+
+def _perform_local_step(
+    bucket: dist.GradBucket,
+    zero: ZeroRedundancyOptimizer,
+    rank: int,
+):
+    r"""
+    Perform a local optimizer step using the gradients provided by ``bucket``.
+
+    Arguments:
+        bucket (dist.GradBucket): the bucket providing the gradients.
+        zero (ZeroRedundancyOptimizer): the :class:`ZeroRedundancyOptimizer`
+            instance to perform the :meth:`_local_step`.
+        rank (int): the calling process's rank.
+
+    .. warning::
+        This function assumes that appropriate synchronization has taken place
+        so that the bucket's gradients can be used.
+    """
+    overlap_info = zero._overlap_info
+    bucket_index = bucket.index()
+    assert len(zero.optim.param_groups) == 1, (
+        "Overlapping DDP with ZeRO only supports a single parameter group"
+    )
+
+    # Construct the `gradients` input for the local optimizer step, which
+    # expects `None` in a list position to indicate that the corresponding
+    # parameter should not be updated
+    num_local_optim_params = len(zero.optim.param_groups[0]["params"])
+    gradients: list[torch.Tensor | None] = [
+        _NO_PARAM_UPDATE for _ in range(num_local_optim_params)
+    ]
+    assert bucket_index in overlap_info.offsets, (
+        f"Bucket index {bucket_index} was not assigned to rank {rank}"
+    )
+    gradients_offset = overlap_info.offsets[bucket_index]
+    bucket_assignment = zero._bucket_assignments_per_rank[rank][bucket_index]
+    bucket_offset = bucket_assignment.offset
+    length = len(bucket_assignment.parameters)
+    bucket_gradients = bucket.gradients()[bucket_offset : bucket_offset + length]
+    for i, grad in enumerate(bucket_gradients):
+        gradients[gradients_offset + i] = grad
+
+    zero._local_step(gradients)
+
+
+def _broadcast_bucket(
+    bucket_index: int,
+    zero: ZeroRedundancyOptimizer,
+):
+    r"""
+    Broadcasts a bucket's parameters.
+
+    Arguments:
+        bucket_index (int): the index of the bucket corresponding to the
+            parameters to broadcast.
+        zero (ZeroRedundancyOptimizer): the calling process's
+            :class:`ZeroRedundancyOptimizer` instance.
+    """
+    overlap_info = zero._overlap_info
+    assert len(overlap_info.assigned_ranks_per_bucket) > bucket_index, (
+        "`assigned_ranks_per_bucket` is not fully constructed"
+    )
+    # Sort to ensure the same ordering across ranks
+    assigned_ranks = sorted(overlap_info.assigned_ranks_per_bucket[bucket_index])
+    assert len(assigned_ranks) > 0, (
+        f"Bucket {bucket_index} should be assigned to at least one rank"
+    )
+    for assigned_rank in assigned_ranks:
+        bucket_assignments = zero._bucket_assignments_per_rank[assigned_rank]
+        if bucket_index in bucket_assignments:
+            send_tensor = bucket_assignments[bucket_index].tensor
+            assert send_tensor is not None
+            overlap_info.broadcast_handles.append(
+                dist.broadcast(
+                    send_tensor,
+                    src=dist.get_global_rank(zero.process_group, assigned_rank),
+                    group=zero.process_group,
+                    async_op=True,
+                )
+            )
+
+
+def _save_ddp_bucket_info(
+    bucket: dist.GradBucket,
+    zero: ZeroRedundancyOptimizer,
+):
+    r"""
+    Save :class:`DistributedDataParallel` gradient bucket information for :class:`ZeroRedundancyOptimizer` instance ``zero``.
+
+    In particular, this function is meant to be called upon seeing each
+    gradient bucket to use when overlapping, meaning it does not save or compute any global
+    information.
+
+    Arguments:
+        bucket (dist.GradBucket): the current gradient bucket.
+        zero (ZeroRedundancyOptimizer): the calling process's
+            :class:`ZeroRedundancyOptimizer` instance.
+    """
+    overlap_info = zero._overlap_info
+    bucket_params = bucket.parameters()
+    assert len(bucket_params) > 0, "Empty bucket"
+
+    # Save the parameters in the bucket
+    overlap_info.params_per_bucket.append(bucket_params)
+    if overlap_info.shard_buckets:
+        # Additionally save the bucket size for the assignment heuristic to use
+        bucket_size = 0
+        for param in bucket_params:
+            bucket_size += param.numel()
+        assert overlap_info.total_size is not None
+        overlap_info.total_size += bucket_size
+
+
+def _hook_with_zero_step_setup(
+    ddp_ref: weakref.ReferenceType,
+    zero: ZeroRedundancyOptimizer,
+    bucket: dist.GradBucket,
+):
+    r"""
+    Encapsulate the setup logic for :func:`hook_with_zero_step` and :func:`hook_with_zero_step_interleaved`.
+
+    This means the logic to run in the
+    hook before the backward pass and optimizer step can actually be
+    overlapped. This is factored out since it is common to both
+    :func:`hook_with_zero_step` and :func:`hook_with_zero_step_interleaved`.
+
+    Arguments:
+        ddp_ref (weakref.ReferenceType): weak reference to the process's
+            :class:`DistributedDataParallel` instance.
+        zero (ZeroRedundancyOptimizer): the calling process's
+            :class:`ZeroRedundancyOptimizer` instance.
+        bucket (dist.GradBucket): the current gradient bucket.
+    """
+    # Proceed as normal until the DDP buckets have been rebuilt
+    if not ddp_ref()._has_rebuilt_buckets:  # type: ignore[union-attr]
+        assert zero._overlap_info.status == _OverlapStatus.UNINITIALIZED
+        return
+
+    bucket_index = bucket.index()
+    overlap_info = zero._overlap_info
+    if overlap_info.status == _OverlapStatus.UNINITIALIZED:
+        overlap_info.status = _OverlapStatus.DDP_HAS_REBUILT_BUCKETS
+
+    if overlap_info.status == _OverlapStatus.DDP_HAS_REBUILT_BUCKETS:
+        if bucket_index == 0 and len(overlap_info.params_per_bucket) > 0:
+            # This corresponds to the first bucket of the backward pass
+            # immediately after all information has been saved, so we
+            # can perform the delayed ZeRO initialization
+            zero._init_zero_for_overlap()
+        else:
+            # Once DDP buckets have been rebuilt but ZeRO has not been
+            # properly initialized yet, save the information needed
+            _save_ddp_bucket_info(bucket, zero)
+
+
+def hook_with_zero_step(
+    hook: Callable[[Any, dist.GradBucket], torch.futures.Future],
+    ddp: DistributedDataParallel,
+    zero: ZeroRedundancyOptimizer,
+    shard_buckets: bool = False,
+) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
+    r"""
+    Modify ``hook`` to overlap :class:`ZeroRedundancyOptimizer` optimizer step with :class:`DistributedDataParallel` backward pass.
+
+    This approach overlaps the optimizer computation and communication with the
+    backward communication. In particular, the backward computation proceeds
+    contiguously, and the optimizer computation follows, overlapping with
+    outstanding backward communication (i.e. all-reduces) and possibly other
+    optimizer communication (i.e. broadcasts).
+    The optimizer step computation begins after the last gradient bucket computation has finished.
+
+    This approach may be preferred over :meth:`hook_with_zero_step_interleaved`
+    if communication is relatively slow compared to computation.
+
+    Arguments:
+        hook (Callable[[Any, dist.GradBucket], torch.futures.Future]): the hook
+            to modify.
+        ddp (DistributedDataParallel): the :class:`DistributedDataParallel`
+            instance to use.
+        zero (ZeroRedundancyOptimizer): the :class:`ZeroRedundancyOptimizer`
+            instance to use.
+        shard_buckets (bool): if ``True``, then the assignment of each
+            :class:`DistributedDataParallel` bucket is partitioned across
+            possibly multiple :class:`ZeroRedundancyOptimizer` instances (i.e.
+            across possibly multiple ranks) to approximate uniformity; if
+            ``False``, then each bucket is wholly assigned to a single
+            :class:`ZeroRedundancyOptimizer` instance (i.e. to a single rank).
+
+    Returns:
+        The modified hook.
+
+    Raises:
+        ValueError: if ``zero`` was constructed with ``overlap_with_ddp=False``.
+        RuntimeError: if using any backend other than NCCL/HCCL since currently
+            Gloo may hang.
+
+    .. warning::
+        Given the way that overlapping :class:`DistributedDataParallel` with
+        :class:`ZeroRedundancyOptimizer` is currently implemented, the first
+        two or three training iterations do not perform parameter updates in
+        the optimizer step, depending on if ``static_graph=False`` or
+        ``static_graph=True``, respectively. This is because it needs
+        information about the gradient bucketing strategy used by
+        :class:`DistributedDataParallel`, which is not finalized until the
+        second forward pass if ``static_graph=False`` or until the third
+        forward pass if ``static_graph=True``.
+    """
+    if not zero._overlap_with_ddp:
+        raise ValueError(
+            "ZeroRedundancyOptimizer must be constructed with "
+            "`overlap_with_ddp=True` to use this hook properly"
+        )
+    ddp_ref = weakref.ref(ddp)
+
+    # NOTE: Gloo may hang with this overlapping approach; see https://github.com/pytorch/pytorch/issues/62300
+    pg = dist.get_backend(ddp_ref().process_group)  # type: ignore[union-attr]
+    if pg == dist.Backend.GLOO:
+        raise RuntimeError(
+            "Gloo backend using Overlapping DDP with ZeRO may meet hangs"
+        )
+
+    if shard_buckets:
+        zero._overlap_info.shard_buckets = True
+        zero._overlap_info.total_size = 0
+
+    def hook_with_zero_fn(
+        state: Any,
+        bucket: dist.GradBucket,
+    ) -> torch.futures.Future[torch.Tensor]:
+        r"""
+        Return :class:`Future` that runs the optimizer step if this corresponds to the last gradient bucket.
+
+        Perform equivalent of :class:`ZeroRedundancyOptimizer` :meth:`step` if ``bucket`` is last gradient bucket.
+        The function gives a gradient bucket tensor and
+        performs additional computation on the iteration that
+        the :class:`DistributedDataParallel` buckets are rebuilt to collect
+        information used to implement the modified hook.
+
+        Arguments:
+            state (Any): any state for the hook.
+            bucket (dist.GradBucket): the :class:`DistributedDataParallel`
+                gradient bucket.
+        """
+        fut = hook(state, bucket)
+        _hook_with_zero_step_setup(ddp_ref, zero, bucket)
+        if zero._overlap_info.status != _OverlapStatus.INITIALIZED:
+            return fut
+
+        overlap_info = zero._overlap_info
+        bucket_index = bucket.index()
+        rank = zero.global_rank
+
+        assert overlap_info.status == _OverlapStatus.INITIALIZED
+        assert len(overlap_info.assigned_ranks_per_bucket) > bucket_index, (
+            "`assigned_ranks_per_bucket` is not fully constructed"
+        )
+        assigned_to_bucket = (
+            rank in overlap_info.assigned_ranks_per_bucket[bucket_index]
+        )
+
+        # Save the bucket reference and all-reduce future for the final bucket
+        if assigned_to_bucket:
+            overlap_info.bucket_index_to_bucket[bucket_index] = bucket
+            overlap_info.bucket_index_to_future[bucket_index] = fut
+
+        # Check that buckets are indexed incrementally starting from 0 in the
+        # order of their autograd hooks firing
+        if len(overlap_info.bucket_indices_seen) > 0:
+            assert overlap_info.bucket_indices_seen[-1] == bucket_index - 1, (
+                "Bucket indices are not in incremental order"
+            )
+        else:
+            assert bucket_index == 0, "Bucket indices do not start from 0"
+        overlap_info.bucket_indices_seen.append(bucket_index)
+
+        # Directly return the future without any optimizer computation if this
+        # is not the last bucket
+        num_buckets = len(overlap_info.params_per_bucket)
+        is_last_bucket = bucket_index == num_buckets - 1
+        if not is_last_bucket:
+            return fut
+
+        # Perform partial optimizer step on all buckets after the final
+        # bucket has been computed
+        # NOTE: This should not be chained as a callback to the last bucket's
+        # all-reduce future since that would add synchronization that delays
+        # all optimizer computation to wait for that last all-reduce
+        for bucket_index in range(num_buckets):
+            assigned_ranks = overlap_info.assigned_ranks_per_bucket[bucket_index]
+            if rank in assigned_ranks:
+                # Wait on the bucket's all-reduce future to ensure correct
+                # gradients
+                assert bucket_index in overlap_info.bucket_index_to_future, (
+                    f"All-reduce future for bucket {bucket_index} not saved "
+                    f"on rank {rank}"
+                )
+                allreduce_future = overlap_info.bucket_index_to_future[bucket_index]
+                allreduce_future.wait()
+
+                # Perform the partial optimizer step
+                curr_bucket = overlap_info.bucket_index_to_bucket[bucket_index]
+                _perform_local_step(curr_bucket, zero, rank)
+
+            _broadcast_bucket(bucket_index, zero)
+
+        # Ensure that all parameter updates are finished before the
+        # next forward pass
+        overlap_info.wait_for_broadcasts()
+        overlap_info.clear_per_iter_info()
+
+        return fut
+
+    return hook_with_zero_fn
+
+
+def hook_with_zero_step_interleaved(
+    hook: Callable[[Any, dist.GradBucket], torch.futures.Future],
+    ddp: DistributedDataParallel,
+    zero: ZeroRedundancyOptimizer,
+    shard_buckets: bool = False,
+) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
+    r"""
+    Modify ``hook`` to overlap :class:`ZeroRedundancyOptimizer` optimizer step with :class:`DistributedDataParallel` backward pass
+
+    This approach overlaps the optimizer computation and communication with the
+    backward computation and communication. In particular, once a bucket's
+    gradients have been computed, the optimizer computation using those
+    gradients is launched (though the actual computation must wait for the
+    bucket's all-reduce to complete). This yields an interleaving of all-
+    reduces and broadcasts in the communication stream.
+
+    This approach may be preferred over :meth:`hook_with_zero_step` if
+    communication is relatively fast compared to computation.
+
+    Arguments:
+        hook (Any * dist.GradBucket -> torch.futures.Future): the hook to
+            modify.
+        ddp (DistributedDataParallel): the :class:`DistributedDataParallel`
+            instance to use.
+        zero (ZeroRedundancyOptimizer): the :class:`ZeroRedundancyOptimizer`
+            instance to use.
+        shard_buckets (bool): if ``True``, then the assignment of each
+            :class:`DistributedDataParallel` bucket is partitioned across
+            possibly multiple :class:`ZeroRedundancyOptimizer` instances (i.e.
+            across possibly multiple ranks) to approximate uniformity; if
+            ``False``, then each bucket is wholly assigned to a single
+            :class:`ZeroRedundancyOptimizer` instance (i.e. to a single rank).
+
+    Returns:
+        The modified hook.
+
+    Raises:
+        ValueError: if ``zero`` was constructed with ``overlap_with_ddp=False``.
+        RuntimeError: if using any backend other than NCCL since currently
+            Gloo may hang.
+
+    .. warning::
+        Given the way that overlapping :class:`DistributedDataParallel` with
+        :class:`ZeroRedundancyOptimizer` is currently implemented, the first
+        two or three training iterations do not perform parameter updates in
+        the optimizer step, depending on if ``static_graph=False`` or
+        ``static_graph=True``, respectively. This is because it needs
+        information about the gradient bucketing strategy used by
+        :class:`DistributedDataParallel`, which is not finalized until the
+        second forward pass if ``static_graph=False`` or until the third
+        forward pass if ``static_graph=True``.
+    """
+    if not zero._overlap_with_ddp:
+        raise ValueError(
+            "ZeroRedundancyOptimizer must be constructed with "
+            "`overlap_with_ddp=True` to use this hook properly"
+        )
+    ddp_ref = weakref.ref(ddp)
+
+    # NOTE: Gloo may hang with this overlapping approach; see https://github.com/pytorch/pytorch/issues/62300
+    pg = dist.get_backend(ddp_ref().process_group)  # type: ignore[union-attr]
+    if pg == dist.Backend.GLOO:
+        raise RuntimeError(
+            "Gloo backend using Overlapping DDP with ZeRO may meet hangs"
+        )
+
+    if shard_buckets:
+        zero._overlap_info.shard_buckets = True
+        zero._overlap_info.total_size = 0
+
+    def hook_with_zero_interleaved_fn(
+        state,
+        bucket: dist.GradBucket,
+    ) -> torch.futures.Future[torch.Tensor]:
+        r"""
+        Return :class:`Future` that gives gradient bucket tensor and performs partial :class:`ZeroRedundancyOptimizer` :meth:`step`.
+
+        This function uses the gradients in gradient in given bucket to perform a partial
+        :class:`ZeroRedundancyOptimizer` :meth:`step`
+
+        Arguments:
+            state: any state for the hook.
+            bucket (dist.GradBucket): the :class:`DistributedDataParallel`
+                gradient bucket.
+        """
+        fut = hook(state, bucket)
+        _hook_with_zero_step_setup(ddp_ref, zero, bucket)
+        if zero._overlap_info.status != _OverlapStatus.INITIALIZED:
+            return fut
+
+        def zero_step(fut: torch.futures.Future) -> torch.Tensor:
+            r"""
+            Perform partial :class:`ZeroRedundancyOptimizer` :meth:`step` using gradients in the :class:`DistributedDataParallel`.
+
+            Returns:
+                A :class:`torch.Tensor` representing the contents of the
+                gradient bucket.
+            """
+            overlap_info = zero._overlap_info
+            bucket_index = bucket.index()
+            rank = zero.global_rank
+
+            assigned_ranks = overlap_info.assigned_ranks_per_bucket[bucket_index]
+            overlap_info.bucket_indices_seen.append(bucket_index)
+            if rank in assigned_ranks:
+                _perform_local_step(bucket, zero, rank)
+
+            _broadcast_bucket(bucket_index, zero)
+
+            num_buckets = len(overlap_info.params_per_bucket)
+            if len(overlap_info.bucket_indices_seen) == num_buckets:
+                # Ensure that all parameter updates are finished before the
+                # next forward pass
+                overlap_info.wait_for_broadcasts()
+                overlap_info.clear_per_iter_info()
+
+            return bucket.buffer()
+
+        return fut.then(zero_step)
+
+    return hook_with_zero_interleaved_fn
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/debugging_hooks.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/debugging_hooks.py
new file mode 100644
index 0000000000000000000000000000000000000000..53a184839a06f4787471f14f48137f4aa344fd91
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/debugging_hooks.py
@@ -0,0 +1,29 @@
+from typing import Any
+
+import torch
+from torch.distributed import GradBucket
+
+
+__all__ = ["noop_hook"]
+
+
+def noop_hook(_: Any, bucket: GradBucket) -> torch.futures.Future[torch.Tensor]:
+    """
+    Return a future that wraps the input, so it is a no-op that does not incur any communication overheads.
+
+    This hook should **only** be used for headroom analysis of allreduce optimization,
+    instead of the normal gradient synchronization.
+    For example, if only less than 10% speedup of training time can be observed after this hook is registered,
+    it usually implies that allreduce is not a performance bottleneck for this case.
+    Such instrumentation can be particularly useful
+    if GPU traces cannot be easily retrieved or the trace analysis is complicated
+    some factors such as the overlap between allreduce and computation or the desynchronization across ranks.
+
+    Example::
+        >>> # xdoctest: +SKIP
+        >>> ddp_model.register_comm_hook(None, noop_hook)
+    """
+    fut: torch.futures.Future[torch.Tensor] = torch.futures.Future()
+    fut.set_result(bucket.buffer())
+
+    return fut
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
new file mode 100644
index 0000000000000000000000000000000000000000..20a0de7ef318c10f3b5bbdaf98483d9fd19b2691
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
@@ -0,0 +1,211 @@
+# mypy: allow-untyped-defs
+from collections.abc import Callable
+from typing import Any, cast
+
+import torch
+import torch.distributed as dist
+
+
+__all__ = [
+    "allreduce_hook",
+    "fp16_compress_hook",
+    "bf16_compress_hook",
+    "fp16_compress_wrapper",
+    "bf16_compress_wrapper",
+]
+
+
+def _allreduce_fut(
+    process_group: dist.ProcessGroup, tensor: torch.Tensor
+) -> torch.futures.Future[torch.Tensor]:
+    """Average the input gradient tensor by allreduce and returns a future."""
+    group_to_use = process_group if process_group is not None else dist.group.WORLD
+
+    # Apply the division first to avoid overflow, especially for FP16.
+    # pyrefly: ignore [missing-attribute]
+    tensor.div_(group_to_use.size())
+
+    return (
+        dist.all_reduce(tensor, group=group_to_use, async_op=True)
+        .get_future()
+        .then(lambda fut: fut.value()[0])
+    )
+
+
+def allreduce_hook(
+    process_group: dist.ProcessGroup, bucket: dist.GradBucket
+) -> torch.futures.Future[torch.Tensor]:
+    """
+    Call ``allreduce`` using ``GradBucket`` tensors.
+
+    Once gradient tensors are aggregated across all workers, its ``then``
+    callback takes the mean and returns the result.
+
+    If user registers this DDP communication hook,
+    DDP results is expected to be same as the case where no hook was registered.
+    Hence, this won't change behavior of DDP and user can use this as a reference
+    or modify this hook to log useful information or any other purposes while
+    unaffecting DDP behavior.
+
+    Example::
+        >>> # xdoctest: +SKIP
+        >>> ddp_model.register_comm_hook(process_group, allreduce_hook)
+    """
+    return _allreduce_fut(process_group, bucket.buffer())
+
+
+def _compress_hook(
+    dtype: torch.dtype,
+    process_group: dist.ProcessGroup,
+    bucket: dist.GradBucket,
+) -> torch.futures.Future[torch.Tensor]:
+    group_to_use = process_group if process_group is not None else dist.group.WORLD
+    # pyrefly: ignore [missing-attribute]
+    world_size = group_to_use.size()
+
+    buffer = (
+        cast(tuple[torch.Tensor, ...], bucket)[0]
+        if isinstance(bucket, tuple)
+        else bucket.buffer()
+    )
+    compressed_tensor = buffer.to(dtype).div_(world_size)
+
+    def decompress(fut):
+        decompressed_tensor = buffer
+        # Decompress in place to reduce the peak memory.
+        # See: https://github.com/pytorch/pytorch/issues/45968
+        value = fut if isinstance(fut, torch.Tensor) else fut.value()[0]
+        decompressed_tensor.copy_(value)
+        return decompressed_tensor
+
+    if torch.compiler.is_compiling():
+        grad = dist._functional_collectives.all_reduce(
+            compressed_tensor,
+            "sum",
+            # pyrefly: ignore [bad-argument-type]
+            group_to_use,
+        )
+        return decompress(grad)
+    else:
+        fut = dist.all_reduce(
+            compressed_tensor, group=group_to_use, async_op=True
+        ).get_future()
+        return fut.then(decompress)
+
+
+def fp16_compress_hook(
+    process_group: dist.ProcessGroup,
+    bucket: dist.GradBucket,
+) -> torch.futures.Future[torch.Tensor]:
+    """
+    Compress by casting ``GradBucket`` to ``torch.float16`` divided by process group size.
+
+    This DDP communication hook implements a simple gradient compression
+    approach that casts ``GradBucket`` tensor to half-precision floating-point format (``torch.float16``)
+    and then divides it by the process group size.
+    It allreduces those ``float16`` gradient tensors. Once compressed gradient
+    tensors are allreduced, the chained callback ``decompress`` casts it back to the input data type (such as ``float32``).
+
+    Example::
+        >>> # xdoctest: +SKIP
+        >>> ddp_model.register_comm_hook(process_group, fp16_compress_hook)
+    """
+    return _compress_hook(torch.float16, process_group, bucket)
+
+
+def bf16_compress_hook(
+    process_group: dist.ProcessGroup,
+    bucket: dist.GradBucket,
+) -> torch.futures.Future[torch.Tensor]:
+    """
+    Warning: This API is experimental, and it requires NCCL version later than 2.9.6.
+
+    This DDP communication hook implements a simple gradient compression
+    approach that casts ``GradBucket`` tensor to half-precision
+    `Brain floating point format `_ (``torch.bfloat16``)
+    and then divides it by the process group size.
+    It allreduces those ``bfloat16`` gradient tensors. Once compressed gradient
+    tensors are allreduced, the chained callback ``decompress`` casts it back to the input data type (such as ``float32``).
+
+    Example::
+        >>> # xdoctest: +SKIP
+        >>> ddp_model.register_comm_hook(process_group, bf16_compress_hook)
+    """
+    return _compress_hook(torch.bfloat16, process_group, bucket)
+
+
+def fp16_compress_wrapper(
+    hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]],
+) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
+    """
+    Cast input tensor to ``torch.float16``, cast result of hook back to input dtype.
+
+    This wrapper casts the input gradient tensor of a given DDP communication hook to half-precision
+    floating point format (``torch.float16``), and casts the resulting tensor of the given hook back to
+    the input data type, such as ``float32``.
+    Therefore, ``fp16_compress_hook`` is equivalent to ``fp16_compress_wrapper(allreduce_hook)``.
+
+    Example::
+        >>> # xdoctest: +SKIP
+        >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10)
+        >>> ddp_model.register_comm_hook(state, fp16_compress_wrapper(powerSGD_hook))
+    """
+
+    def fp16_compress_wrapper_hook(
+        hook_state, bucket: dist.GradBucket
+    ) -> torch.futures.Future[torch.Tensor]:
+        # Cast bucket tensor to FP16.
+        bucket.set_buffer(bucket.buffer().to(torch.float16))
+
+        fut = hook(hook_state, bucket)
+
+        def decompress(fut):
+            decompressed_tensor = bucket.buffer()
+            # Decompress in place to reduce the peak memory.
+            # See: https://github.com/pytorch/pytorch/issues/45968
+            decompressed_tensor.copy_(fut.value())
+            return decompressed_tensor
+
+        # Decompress after hook has run.
+        return fut.then(decompress)
+
+    return fp16_compress_wrapper_hook
+
+
+def bf16_compress_wrapper(
+    hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]],
+) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
+    """
+    Warning: This API is experimental, and it requires NCCL version later than 2.9.6.
+
+    This wrapper casts the input gradient tensor of a given DDP communication hook to half-precision
+    `Brain floating point format `_  (``torch.bfloat16``),
+    and casts the resulting tensor of the given hook back to the input data type, such as ``float32``.
+
+    Therefore, ``bf16_compress_hook`` is equivalent to ``bf16_compress_wrapper(allreduce_hook)``.
+
+    Example::
+        >>> # xdoctest: +SKIP
+        >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10)
+        >>> ddp_model.register_comm_hook(state, bf16_compress_wrapper(powerSGD_hook))
+    """
+
+    def bf16_compress_wrapper_hook(
+        hook_state, bucket: dist.GradBucket
+    ) -> torch.futures.Future[torch.Tensor]:
+        # Cast bucket tensor to BF16.
+        bucket.set_buffer(bucket.buffer().to(torch.bfloat16))
+
+        fut = hook(hook_state, bucket)
+
+        def decompress(fut):
+            decompressed_tensor = bucket.buffer()
+            # Decompress in place to reduce the peak memory.
+            # See: https://github.com/pytorch/pytorch/issues/45968
+            decompressed_tensor.copy_(fut.value())
+            return decompressed_tensor
+
+        # Decompress after hook has run.
+        return fut.then(decompress)
+
+    return bf16_compress_wrapper_hook
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/mixed_precision_hooks.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/mixed_precision_hooks.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1968042e5e21aa1b6714f78356b43896cccdf60
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/mixed_precision_hooks.py
@@ -0,0 +1,86 @@
+from dataclasses import dataclass
+from typing import Any, no_type_check
+
+import torch
+import torch.distributed as dist
+from torch.autograd import Variable
+from torch.distributed.utils import _free_storage
+
+
+@dataclass
+class _AllreduceUpcastHookState:
+    """
+    State to manage DDP mixed precision in backward / gradient communication.
+
+    This contains a weakref to the DDP module for access to reducer and process
+    group, and a stream to run parameter and gradient upcasts.
+    """
+
+    ddp_weakref: Any
+    upcast_stream: torch.Stream
+    wait_for_stream_enqueued: bool = False
+
+
+@no_type_check
+def _reducer_allreduce_and_upcast_hook(
+    hook_state: _AllreduceUpcastHookState, bucket: dist.GradBucket
+) -> torch.futures.Future[torch.Tensor]:
+    """
+    Perform allreduce in precision ``reduce_dtype``, upcast to prepare for optimizer.
+
+    Performs allreduce in the reduced precision given by DDP's mixed precision
+    reduce_dtype, and upcasts parameters and gradients to fp32 in preparation
+    to run the optimizer.
+    """
+    ddp_weakref = hook_state.ddp_weakref
+    reducer, process_group = ddp_weakref().reducer, ddp_weakref().process_group
+    # Cast bucket if different than param_dtype.
+    if (
+        ddp_weakref().mixed_precision.param_dtype
+        != ddp_weakref().mixed_precision.reduce_dtype
+    ):
+        # Cast bucket tensor to reduce_dtype
+        bucket.set_buffer(
+            bucket.buffer().to(ddp_weakref().mixed_precision.reduce_dtype)
+        )
+    fut = reducer._run_allreduce_hook(bucket)
+    ret_fut = torch.futures.Future()
+    stream = hook_state.upcast_stream
+    with stream:
+        fut.wait()
+        bucket.buffer().div_(process_group.size())
+        ret_fut.set_result(bucket.buffer())
+
+        # Upcast parameters and gradients so optimizer step can run in fp32.
+        for p in bucket.parameters():
+            p.data = p._fp_param
+            # free storage for mp param as it will be allocated again in next
+            # forward pass.
+            _free_storage(p._mp_param)
+            p.grad.data = p.grad.to(p.data.dtype)
+
+    # enqueue a callback to wait for this stream at end of backward
+    def wait_for_stream_cb():
+        torch.accelerator.current_stream().wait_stream(stream)
+        # Remove post-backward hooks since they are re-installed in next
+        # iteration, similar to FSDP.
+        # Parameters that don't require grad still needed to be casted since
+        # they may participate in computation. However, they would not be recast
+        # by hook above as they don't have a grad hook installed, so cast them
+        # back here.
+        for _, p in ddp_weakref().module.named_parameters():
+            if hasattr(p, "_ddp_mp_hook_state"):
+                p._ddp_mp_hook_state[1].remove()
+                delattr(p, "_ddp_mp_hook_state")
+            if not p.requires_grad and not hasattr(p, "_ddp_ignored"):
+                p.data = p._fp_param
+
+        # reset for next backward pass
+        hook_state.wait_for_stream_enqueued = False
+
+    if not hook_state.wait_for_stream_enqueued:
+        Variable._execution_engine.queue_callback(wait_for_stream_cb)
+        # mark that the callback is enqueued
+        hook_state.wait_for_stream_enqueued = True
+
+    return ret_fut
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py
new file mode 100644
index 0000000000000000000000000000000000000000..162160e394ad0b634365f941f3a9f216bf1aa2d8
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py
@@ -0,0 +1,163 @@
+# mypy: allow-untyped-defs
+from collections.abc import Callable
+from dataclasses import dataclass
+from functools import partial
+from typing import Any, no_type_check
+
+import torch
+import torch.distributed as dist
+from torch.autograd import Variable
+
+
+__all__: list[str] = []
+
+_FUNCTIONAL_OPTIM_STEP_METHOD_NAME = "step_param"
+
+
+class _OptimizerHookState:
+    """
+    Holds state for running optimizer in-line after DDP communication hook.
+
+    Currently contains only optimizer class which must have a method `step_param`.
+    """
+
+    __slots__ = ["functional_optimizer", "params_to_optimize"]
+
+    def __init__(self, functional_optim, params=None):
+        self.functional_optimizer = functional_optim
+        self._check_valid_functional_optim()
+        self._set_params_to_optimize(params)
+
+    def _set_params_to_optimize(self, params):
+        if params is not None:
+            self.params_to_optimize = set(params)
+
+    def _check_valid_functional_optim(self):
+        if not hasattr(self.functional_optimizer, _FUNCTIONAL_OPTIM_STEP_METHOD_NAME):
+            raise ValueError(
+                f"Class {type(self.functional_optimizer)} must implement method "
+                f"{_FUNCTIONAL_OPTIM_STEP_METHOD_NAME}."
+            )
+
+
+@dataclass
+class _OptimInBackwardHookState:
+    optim_stream: torch.Stream
+    wait_for_optim_stream_enqueued: bool
+
+
+@no_type_check
+def _apply_optim_in_backward_hook(
+    gradient_is_bucket_view: bool,
+) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
+    r"""
+    Register hook to apply the optimizer in backward.
+
+    If torch.distributed.optim._apply_optimizer_in_backward is used to overlap
+    optimizer with backward pass, DDP will run the below hook to run optimizer
+    step for parameters after gradient communication has taken place.
+    """
+    optim_in_bwd_state = _OptimInBackwardHookState(
+        optim_stream=torch.Stream(),
+        wait_for_optim_stream_enqueued=False,
+    )
+
+    def apply_optim_in_backward_hook(
+        hook_state: Any,
+        bucket: dist.GradBucket,
+        optim_stream_state,
+    ) -> torch.futures.Future[torch.Tensor]:
+        # Run original hook
+        ddp_weakref = hook_state
+        ddp_inst = ddp_weakref()
+        reducer, process_group = ddp_inst.reducer, ddp_inst.process_group
+        fut = reducer._run_allreduce_hook(bucket)
+        optimizer_stream = optim_stream_state.optim_stream
+        with optimizer_stream:
+            fut.wait()
+            # Apply gradient division since C++ side only allreduces and does
+            # not average. TODO: (rohan-varma) the div factor may be different
+            # when running with join hook
+            bucket.buffer().div_(process_group.size())
+            model_params = bucket.parameters()
+            grads = bucket.gradients()
+            # TODO (rohan-varma): upcast as needed for DDP mixed precision,
+            # once optimizer in backward + DDP mixed precision is supported.
+            for p, g in zip(model_params, grads):
+                if hasattr(p, "_in_backward_optimizers"):
+                    # Note: need to set grad to the bucket's grad, because
+                    # running allreduce results in the bucket's grad being
+                    # reduced, but not grad field.
+                    if not gradient_is_bucket_view:
+                        p.grad = g
+                    for optim in p._in_backward_optimizers:
+                        optim.step()
+
+        # Need to return a Future[Tensor] to obey comm hook API contract.
+        ret_fut = torch.futures.Future()
+        ret_fut.set_result(bucket.buffer())
+
+        # enqueue a callback to wait for this optimizer stream at the end of
+        # backward and set all DDP managed grads to None.
+        def wait_for_optim_stream_callback():
+            torch.accelerator.current_stream().wait_stream(
+                optim_stream_state.optim_stream
+            )
+            # Set DDP managed grads to None
+            for param in ddp_inst._get_data_parallel_params(ddp_inst.module):
+                if hasattr(param, "_in_backward_optimizers"):
+                    param.grad = None
+
+            # reset for the next backwards pass
+            optim_stream_state.wait_for_optim_stream_enqueued = False
+
+        if not optim_stream_state.wait_for_optim_stream_enqueued:
+            Variable._execution_engine.queue_callback(wait_for_optim_stream_callback)
+            # mark that the callback is enqueued
+            optim_stream_state.wait_for_optim_stream_enqueued = True
+
+        return ret_fut
+
+    comm_hook = partial(
+        apply_optim_in_backward_hook, optim_stream_state=optim_in_bwd_state
+    )
+    # These are needed for DDP's logging of comm hooks
+    comm_hook.__name__ = apply_optim_in_backward_hook.__name__
+    comm_hook.__qualname__ = apply_optim_in_backward_hook.__qualname__
+
+    return comm_hook
+
+
+def _hook_then_optimizer(
+    hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]],
+    optimizer_state: _OptimizerHookState,
+) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
+    r"""Run optimizer in a functional fashion after DDP communication hook."""
+    has_set_params = (
+        hasattr(optimizer_state, "params_to_optimize")
+        and optimizer_state.params_to_optimize is not None
+    )
+
+    def hook_then_optimizer_wrapper(
+        hook_state, bucket: dist.GradBucket
+    ) -> torch.futures.Future[torch.Tensor]:
+        # Run original hook
+        fut = hook(hook_state, bucket)
+
+        def optimizer_step(fut):
+            gradient_tensors = bucket.gradients()
+            model_params = bucket.parameters()
+            for grad_tensor, model_param in zip(gradient_tensors, model_params):
+                if (
+                    not has_set_params
+                    or model_param in optimizer_state.params_to_optimize
+                ):
+                    optimizer_state.functional_optimizer.step_param(
+                        model_param,
+                        grad_tensor,
+                    )
+            return bucket.buffer()
+
+        return fut.then(optimizer_step)
+
+    return hook_then_optimizer_wrapper
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff513f62183c516b96c62ca89eee51d2b1793e85
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py
@@ -0,0 +1,124 @@
+# mypy: allow-untyped-defs
+import logging
+
+import torch
+import torch.distributed as dist
+
+from . import default_hooks as default
+
+
+logger = logging.getLogger(__name__)
+
+
+class PostLocalSGDState:
+    r"""
+    Store state for all-reducing gradients globally until given step, then locally after.
+
+    Stores the state for all-reducing gradients globally using ``process_group`` until step ``start_localSGD_iter``,
+    and all-reducing gradients locally using ``subgroup`` afterwards.
+
+    If ``process_group`` is ``None``, the global process group will be used.
+    If ``subgroup`` is ``None``, the intra-node process group on each machine will be used.
+
+    Additionally, ``post_local_gradient_allreduce`` may be worth tuning,
+    because both true and false may give a faster convergence.
+    """
+
+    __slots__ = [
+        "process_group",
+        "subgroup",
+        "start_localSGD_iter",
+        "post_local_gradient_allreduce",
+        "iter",
+    ]
+
+    def __init__(
+        self,
+        process_group,
+        subgroup,
+        start_localSGD_iter,
+        post_local_gradient_allreduce=True,
+    ):
+        """Initialize state object with given parameters and log when localSGD start."""
+        logger.info(
+            "Local SGD will be started after %s iterations", start_localSGD_iter
+        )
+
+        # The group used for all-reducing gradients globally.
+        self.process_group = process_group
+        # The group used for all-reducing gradients locally.
+        self.subgroup = subgroup
+        self.start_localSGD_iter = start_localSGD_iter
+        # Allreduce gradients locally since iteration `start_localSGD_iter`.
+        # This may help with the convergence efficiency at the cost of relatively cheap intra-subgroup communication.
+        self.post_local_gradient_allreduce = post_local_gradient_allreduce
+        # Iteration/step in the training loop.
+        self.iter = 0
+
+    def maybe_increase_iter(self, bucket):
+        """Track iterations and trigger log message at start of local SGD."""
+        # Since bucket 0 is the last bucket to allreduce in an iteration.
+        # Only increase `iter` when bucket 0 is processed.
+        if bucket.is_last():
+            self.iter += 1
+
+        if self.iter == self.start_localSGD_iter:
+            logger.info("Start to apply local SGD after %s iterations.", self.iter)
+
+
+def post_localSGD_hook(
+    state: PostLocalSGDState, bucket: dist.GradBucket
+) -> torch.futures.Future[torch.Tensor]:
+    """
+    Run post-localSGD algorithm.
+
+    This DDP communication hook is used for running post-localSGD algorithm,
+    by combining with a model averaging component (e.g.,
+    :class:`~torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager`)
+    that runs after the optimizer step.
+
+    Args:
+        state (PostLocalSGDState): State information to run post-localSGD.
+            Users mainly need to tune ``start_localSGD_iter`` to determine when to start local SGD.
+        bucket (dist.GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
+            Note that since DDP comm hook only supports single process single device mode,
+            only exactly one tensor is stored in this bucket.
+
+    Returns:
+        Future handler of the communication, which updates the gradients in place.
+
+    Example::
+        >>> # xdoctest: +SKIP
+        >>> state = PostLocalSGDState(process_group=process_group, subgroup=subgroup,
+                                  start_localSGD_iter=10)
+        >>> ddp_model.register_comm_hook(state, post_localSGD_hook)
+        >>> # Also need to establish a model averaging module and run model averaging after ``optimizer.step()``.
+        >>> # Please refer to the examples in ``torch.distributed.algorithms.model_averaging.averagers`` module.
+    """
+    global_group_to_use = (
+        state.process_group if state.process_group is not None else dist.group.WORLD
+    )
+
+    # The input tensor is a flattened 1D tensor.
+    input_tensor = bucket.buffer()
+
+    # Run allreduce using `global_group_to_use` in the first `start_localSGD_iter` iterations.
+    if state.iter < state.start_localSGD_iter:
+        state.maybe_increase_iter(bucket)
+        return default._allreduce_fut(global_group_to_use, input_tensor)  # type: ignore[arg-type]
+
+    # If `post_local_gradient_allreduce` is not set,
+    # then no gradient synchronization after the first `start_localSGD_iter` iterations.
+    if not state.post_local_gradient_allreduce:
+        fut: torch.futures.Future[torch.Tensor] = torch.futures.Future()
+        fut.set_result(input_tensor)
+        return fut
+
+    # Run allreduce using `subgroup` after the first `start_localSGD_iter` iterations.
+    # Note that by default, a separate subgroup for each node is created which
+    # causes an intra-node allreduce to be done at each training step.
+    # From this moment, model averaging should run after the optimizer step,
+    # to globally allreduce all the parameters.
+    if state.subgroup is None:
+        state.subgroup, _ = dist.new_subgroups()
+    return default._allreduce_fut(state.subgroup, input_tensor)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1e95d12514eda18b52ae07a44a68e1678bd27a9
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
@@ -0,0 +1,862 @@
+# mypy: allow-untyped-defs
+import logging
+import math
+from collections import defaultdict
+
+import torch
+import torch.distributed as dist
+from torch.distributed import distributed_c10d
+from torch.utils._typing_utils import not_none
+
+from . import default_hooks as default
+
+
+__all__ = ["PowerSGDState", "powerSGD_hook", "batched_powerSGD_hook"]
+
+logger = logging.getLogger(__name__)
+
+
+def _orthogonalize(matrices, epsilon=0):
+    """
+    Decide between Gram-Schmidt or QR factorization to orthogonalize a batch of matrices.
+
+    QR factorization doesn't work with half-precision, but it is usually faster with a rank > 2.
+    """
+    assert len(matrices.shape) == 3 and matrices.shape[2] <= matrices.shape[1]
+
+    num_matrices = matrices.shape[0]
+    rank = matrices.shape[2]
+    dtype = matrices.dtype
+    if rank <= 2 or dtype in [torch.float16, torch.bfloat16]:
+        _orthogonalize_gram_schmidt(matrices, epsilon=epsilon)
+    else:
+        torch.linalg.qr(
+            matrices,
+            out=(
+                matrices,
+                torch.empty(
+                    num_matrices, rank, rank, device=matrices.device, dtype=dtype
+                ),
+            ),
+        )
+
+
+def _orthogonalize_gram_schmidt(matrices, epsilon=0):
+    """
+    Apply Gram-Schmidt procedure to orthogonalize a batch of matrices.
+
+    If epsilon is 0, this is equivalent to `torch.qr(matrices, out=(matrices, _))`,
+    """
+    num_cols = matrices.shape[2]
+    for i in range(num_cols):
+        # Normalize the i'th column.
+        col = matrices[:, :, i : i + 1]
+        # If no epsilon is added here, division by zero may be caused by vanishing gradients.
+        # This epsilon is not needed if the input batch of matrices covers the gradients of at least one entire layer
+        # in the neural network.
+        if epsilon == 0:
+            # Note that col ** 2 can underflow/overflow if we use FP16.
+            # May need to consider multiplying a scaling factor and dividing it later, or using bfloat16 instead.
+            try:
+                col /= torch.norm(col, dim=1, keepdim=True)
+            except ZeroDivisionError:
+                logger.error(
+                    "The matrices to be orthogonalized has at least a column of all 0s. Please set a small value such as 1e-8 "
+                    "as `orthogonalization_epsilon` in PowerSGD state."
+                )
+                # Recover the values from NaNs to 0s.
+                col.fill_(0.0)
+        else:
+            col /= torch.norm(col, dim=1, keepdim=True) + epsilon
+        # Project it on the rest and remove it.
+        if i + 1 < num_cols:
+            rest = matrices[:, :, i + 1 :]
+            rest -= torch.sum(col * rest, dim=1, keepdim=True) * col
+
+
+def _should_compress(
+    num_rows, num_cols, matrix_approximation_rank, min_compression_rate
+):
+    """
+    Recommend if tensor given is worth compressing.
+
+    Returns a recommendation as to whether the 2D tensor described by the arguments is worth compressing,
+    including statistics describing the expected savings from compression.  We consider a tensor worth
+    compressing when ``min_compression_rate`` < uncompressed size / compressed size, where
+    uncompressed size = ``num_rows`` * ``num_cols``,
+    and compressed size = (``num_rows`` + ``num_cols``) * ``matrix_approximation_rank``.
+
+    The result of this function is a tuple of the form (compression_recommendation, uncompressed_el_count, compressed_el_count), where:
+
+    compression_recommendation is true if the tensor is worth compressing, and false otherwise (see above);
+
+    uncompressed_el_count is the uncompressed element count, i.e. ``num_rows`` * ``num_cols``; and,
+
+    compress_el_count is the element count after compression, i.e. (``num_rows`` + ``num_cols``) * ``matrix_approximation_rank``.
+    """  # noqa: B950
+    uncompressed_size = num_rows * num_cols
+    compressed_size = (num_rows + num_cols) * matrix_approximation_rank
+    return (
+        compressed_size * min_compression_rate < uncompressed_size,
+        uncompressed_size,
+        compressed_size,
+    )
+
+
+def _report_compression_stats(bucket, state):
+    """Report compression stats at frequency of ``compression_stats_logging_frequency`` specified in PowerSGD state."""
+    if bucket.is_last() and state.iter >= state.next_stats_report:
+        stats = state.compression_stats()
+        logger.info(
+            "Compression stats: iter %s, total before compression %s, total after compression %s, "
+            "rate %s",
+            state.iter,
+            stats[1],
+            stats[2],
+            stats[0],
+        )
+        state.next_stats_report = state.iter + state.compression_stats_logging_frequency
+
+
+class PowerSGDState:
+    r"""
+    Store both the algorithm's hyperparameters and internal state for all gradients during training.
+
+    Particularly, ``matrix_approximation_rank`` and ``start_powerSGD_iter`` are the main hyperparameters that should be tuned by the user.
+    For performance, we suggest to keep binary hyperparameters ``use_error_feedback`` and ``warm_start`` on.
+
+    1. ``matrix_approximation_rank`` controls the size of compressed low-rank tensors, which determines the compression rate. The lower the rank, the stronger the compression.
+
+        1.1. If ``matrix_approximation_rank`` is too low, the full model quality will need more training steps to reach or will never reach and yield loss in accuracy.
+
+        1.2. The increase of ``matrix_approximation_rank`` can substantially increase the computation costs of the compression, and the accuracy may not be further improved beyond a certain ``matrix_approximation_rank`` threshold.
+
+    To tune ``matrix_approximation_rank``, we suggest to start from 1 and increase by factors of 2 (like an exponential grid search, 1, 2, 4, ...), until a satisfactory accuracy is reached. Typically only a small value 1-4 is used. For some NLP tasks (as shown in Appendix D of the original paper), this value has been increased to 32.
+
+    2. ``start_powerSGD_iter`` defers PowerSGD compression until step ``start_powerSGD_iter``, and vanilla allreduce runs prior to step ``start_powerSGD_iter``. This hybrid scheme of **vanilla allreduce + PowerSGD** can effectively improve the accuracy, even a relatively small ``matrix_approximation_rank`` is used. This is because that, the beginning of training phase is usually very sensitive to inaccurate gradients, and compressing gradients too early may make the training quickly take a suboptimal trajectory, which can result in an irrecoverable impact on the accuracy.
+
+    To tune ``start_powerSGD_iter``, we suggest to start with 10% of total training steps, and increase it until a satisfactory accuracy is reached. If there is a warm-up stage in the training, ``start_powerSGD_iter`` typically should be no less than the number of warm-up steps.
+
+    3. ``min_compression_rate`` is the minimum compression rate required when a layer is compressed. Due to the computation overheads incurred by the compression, a tensor is worth compressing only if there can be sufficient saving in bandwidth, where ``(num_rows + num_cols) * matrix_approximation_rank * min_compression_rate < num_rows * num_cols``. If the specified compression rate threshold cannot be satisfied, the tensor will be directly allreduced without compression.
+
+    Compression statistics are logged every ``compression_stats_logging_frequency`` iterations once PowerSGD compression starts.
+
+    4. ``orthogonalization_epsilon`` can be a very small value (e.g., 1e-8) added to every normalized matrix column in orthogonalization step, to prevent div-by-zero error if any column has all 0s. If this can already be prevented (e.g., by batch normalization), an epsilon of 0 is recommended for accuracy.
+
+    5. ``batch_tensors_with_same_shape`` controls whether to compress and decompress tensors with same shape in a batched operation to achieve higher parallelism. Note that you should also increase the bucket size (i.e., ``bucket_cap_mb`` arg in DDP constructor) to make more same-shaped tensors appear in the same bucket, however this may reduce the overlap between computation and communication, and increase the memory footprint due to stacking the tensors of the same shape. Set to ``True`` if the compression / decompression computation is a bottleneck.
+
+    .. warning ::
+        If error feedback or warm-up is enabled, the minimum value of ``start_powerSGD_iter`` allowed in DDP is 2.
+        This is because there is another internal optimization that rebuilds buckets at iteration 1 in DDP,
+        and this can conflict with any tensor memorized before the rebuild process.
+    """  # noqa: B950
+
+    __slots__ = [
+        "process_group",
+        # The fields below are the hyperparameters that often need to be tuned by the user.
+        "matrix_approximation_rank",
+        "start_powerSGD_iter",
+        # The fields below are the hyperparameters that seldom need be tuned by the user.
+        "min_compression_rate",
+        "orthogonalization_epsilon",
+        # The fields below are the binary hyperparameters recommended to be turned on for performance and accuracy.
+        "use_error_feedback",
+        "warm_start",
+        "batch_tensors_with_same_shape",
+        # The fields below are internal state.
+        "rng",
+        "error_dict",
+        "p_memory_dict",
+        "q_memory_dict",
+        "iter",
+        # The fields below are for recording compression stats.
+        "total_numel_before_compression",
+        "total_numel_after_compression",
+        "compression_stats_logging_frequency",
+        "next_stats_report",
+    ]
+
+    def __init__(
+        self,
+        process_group,
+        matrix_approximation_rank=1,
+        start_powerSGD_iter=1_000,
+        min_compression_rate=2,
+        use_error_feedback=True,
+        warm_start=True,
+        orthogonalization_epsilon=0,
+        random_seed=0,
+        compression_stats_logging_frequency=10_000,
+        batch_tensors_with_same_shape: bool = False,
+    ):
+        logger.info(
+            "PowerSGD config: matrix_approximation_rank = %s; start_powerSGD_iter = %s; "
+            "min_compression_rate = %s; orthogonalization_epsilon = %s; use_error_feedback = %s; warm_start = %s; "
+            "random_seed = %s; compression_stats_logging_frequency = %s; batch_tensors_with_same_shape = %s",
+            matrix_approximation_rank,
+            start_powerSGD_iter,
+            min_compression_rate,
+            orthogonalization_epsilon,
+            use_error_feedback,
+            warm_start,
+            random_seed,
+            compression_stats_logging_frequency,
+            batch_tensors_with_same_shape,
+        )
+
+        self.process_group = process_group
+        self.matrix_approximation_rank = matrix_approximation_rank
+        # Deferring PowerSGD compression util step 'start_powerSGD_iter' can have two advantages:
+        # 1) It turns out that PowerSGD may lead to a non-trivial accuracy loss,
+        # even if the matrix approximation rank is increased to a large value.
+        # To mitigate the accuracy loss, a simple yet effective way is mixing vanilla allreduce
+        # (or a more conservative compression such as FP16 compression) with PowerSGD.
+        # 2) There is an internal optimization of rebuilding buckets process in DDP,
+        # in order to save the memory space.
+        # This step takes place after the first iteration.
+        # However, this means that the shape of input bucketized tensors is subject to change,
+        # which will complicate the implementations of error feedback and warm-up.
+        # Running vanilla allreduce in the first few iterations can avoid this complexity.
+        if (use_error_feedback or warm_start) and start_powerSGD_iter <= 1:
+            raise ValueError(
+                "Expect `start_powerSGD_iter` > 1 if `use_error_feedback` or `warm_start` is enabled, "
+                "because PowerSGD can only be applied after the first two iterations in DDP."
+            )
+        self.start_powerSGD_iter = start_powerSGD_iter
+        self.min_compression_rate = min_compression_rate
+        # Error feedback is usually crucial for both for convergence and generalization,
+        # because PowerSGD is a biased compressor,
+        # i.e., compressing and decompressing a random gradient does not yield the original in expectation.
+        # This mechanism requires a temporary copy of the input gradients,
+        # so it increases the peak memory consumption by the size of the gradient tensor.
+        # However, if the target matrices are known to be exactly low-ranked (instead of just low stable rank),
+        # sometimes it is possible to converge to the optima without error feedback.
+        # See: http://proceedings.mlr.press/v54/yurtsever17a/yurtsever17a.pdf
+        self.use_error_feedback = use_error_feedback
+        # Warm-start reuses P(s) and Q(s) from the previous iteration.
+        # This can improve the approximation quality and hence improve the accuracy.
+        # Additionally, by avoiding the initialization of these low-rank tensors at every step,
+        # this can also accelerate training.
+        # However, this is at the cost of extra memory.
+        self.warm_start = warm_start
+        # Can use a very small value to prevent div-by-zero error caused by orthogonalization of vanishing gradients.
+        self.orthogonalization_epsilon = orthogonalization_epsilon
+        # The purpose of this RNG is to generate different random seeds for initializing Q across iterations,
+        # but in the same order for all the DDP replicas.
+        # Different random seeds across iterations indicate different 'projections' of the gradients at different SGD steps.
+        # If the same random projection is used,
+        # there will be differences between the gradients that are never synchronized.
+        import numpy as np
+
+        self.rng = np.random.RandomState(random_seed)
+        # Since there is only a single state instance for all the input buckets,
+        # need to maintain a dictionary that maps each bucket index to the local error.
+        self.error_dict: dict[int, torch.Tensor] = {}
+        self.p_memory_dict: dict[int, torch.Tensor] = {}
+        self.q_memory_dict: dict[int, torch.Tensor] = {}
+        # Iteration/step in the training loop.
+        self.iter = 0
+        # Compression stats accumulators
+        self.total_numel_before_compression = 0
+        self.total_numel_after_compression = 0
+        # We'll report compression stats every 'compression_stats_logging_frequency' iterations
+        # Note that we always report compression stats at least once.
+        self.compression_stats_logging_frequency = max(
+            1, compression_stats_logging_frequency
+        )
+        self.next_stats_report = 0
+        # Batching tensors with same shape can increase parallelism in compression / decompression computation.
+        # This requires a larger bucket size to make more same-shaped tensor to appear in one bucket, however
+        # this may reduce the overlap between computation and communication, and increase the memory footprint
+        # due to stacking tensors.
+        # Turn on if compression / decompression computation is a bottleneck.
+        self.batch_tensors_with_same_shape = batch_tensors_with_same_shape
+
+    def __getstate__(self):
+        r"""
+        Return a ``Dict[str, Any]`` which will be pickled and saved.
+
+        ``process_group`` is not serializable and excluded from
+        a returned state.
+        """
+        logger.warning(
+            "NOTE: Process group is not serializable and excluded from a saved state."
+        )
+        return {
+            slot: getattr(self, slot)
+            for slot in self.__slots__
+            if slot != "process_group"
+        }
+
+    def __setstate__(self, state):
+        r"""
+        Take a provided ``state`` and set to this ``PowerSGDState`` instance.
+
+        ``process_group`` is set to default.
+        """
+        self.process_group = distributed_c10d._get_default_group()
+        logger.warning(
+            "NOTE: Process group will be set to a default group (i.e. the world size).\
+                If a different group is desired, please set `self.process_group` after PowerSGD state is loaded."
+        )
+        for slot, value in state.items():
+            setattr(self, slot, value)
+
+    def maybe_increase_iter(self, bucket):
+        """Track iterations and trigger log message at start of local SGD."""
+        # Since bucket 0 is the last bucket to allreduce in an iteration.
+        # Only increase `iter` when bucket 0 is processed.
+        if bucket.is_last():
+            self.iter += 1
+
+        if self.iter == self.start_powerSGD_iter:
+            logger.info("Start to apply PowerSGD after %s iterations.", self.iter)
+
+    def compression_stats(self):
+        r"""
+        Return latest compression statistics as tuple.
+
+        Returns tuple of form (compress_rate, numel_before_compression, numel_after_compression) where:
+
+        compress_rate is the effective compression rate i.e. (number of elements before compression) / (number of elements after compression);
+
+        numel_before_compression is the total number of elements before compression was applied; and,
+
+        numel_after_compression is the total number of elements after compression was applied.
+        """  # noqa: B950
+        compress_rate = (
+            self.total_numel_before_compression / self.total_numel_after_compression
+            if self.total_numel_after_compression > 0
+            else 0
+        )
+        return (
+            compress_rate,
+            self.total_numel_before_compression,
+            self.total_numel_after_compression,
+        )
+
+
+def powerSGD_hook(
+    state: PowerSGDState, bucket: dist.GradBucket
+) -> torch.futures.Future[torch.Tensor]:
+    r"""
+    Implement PowerSGD algorithm.
+
+    This DDP communication hook implements PowerSGD gradient compression
+    algorithm described in the `paper `_.
+    Once gradient tensors are aggregated across all workers, this hook applies
+    compression as follows:
+
+    1. Views the input flattened 1D gradient tensor as a list of per-parameter tensors, and divides all the tensors into two groups:
+
+        1.1 The tensors that should be compressed before allreduce, because the compression can give enough saving in bandwidth.
+
+        1.2 Rest of the tensors will be directly allreduced without compression, including all the vector tensors (for biases).
+
+    2. Handles uncompressed tensors:
+
+        2.1. Allocate contiguous memory for those uncompressed tensors, and allreduces all the uncompressed tensors as a batch, without compression;
+
+        2.2. Copies the individual uncompressed tensors from the contiguous memory back to the input tensor.
+
+    3. Handles the tensors that should be compressed by PowerSGD compression:
+
+        3.1. For each tensor M, creates two low-rank tensors P and Q for decomposing M,
+        such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized;
+
+        3.2. Computes each P in Ps, which is equal to MQ;
+
+        3.3. Allreduces Ps as a batch;
+
+        3.4. Orthogonalizes each P in Ps;
+
+        3.5. Computes each Q in Qs, which is approximately equal to M^TP;
+
+        3.6. Allreduces Qs as a batch;
+
+        3.7. Computes each M among all the compressed tensors, which is approximately equal to PQ^T.
+
+    Note that this communication hook enforces vanilla allreduce for the first ``state.start_powerSGD_iter`` iterations.
+    This not only gives the user more control over the tradeoff between speedup and accuracy,
+    but also helps abstract away some complexity of the internal optimization of DDP for future communication hook developers.
+
+    Args:
+        state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc.
+            To tune the compression configs, mainly need to tune ``matrix_approximation_rank``, ``start_powerSGD_iter``
+            and ``min_compression_rate``.
+        bucket (dist.GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
+            Note that since DDP comm hook only supports single process single device mode,
+            only exactly one tensor is stored in this bucket.
+
+    Returns:
+        Future handler of the communication, which updates the gradients in place.
+
+    Example::
+        >>> # xdoctest: +SKIP
+        >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1,
+                                  start_powerSGD_iter=10, min_compression_rate=0.5)
+        >>> ddp_model.register_comm_hook(state, powerSGD_hook)
+    """  # noqa: B950
+    process_group = state.process_group
+    group_to_use = (
+        process_group if process_group is not None else not_none(dist.group.WORLD)
+    )
+    world_size = group_to_use.size()
+
+    # The input tensor is a flattened 1D tensor.
+    input_tensor = bucket.buffer()
+
+    # Run vanilla allreduce in the first `start_powerSGD_iter` iterations.
+    if state.iter < state.start_powerSGD_iter:
+        state.maybe_increase_iter(bucket)
+        return default._allreduce_fut(group_to_use, input_tensor)
+
+    # Apply PowerSGD after `start_powerSGD_iter` iterations.
+    device = input_tensor.device
+    dtype = input_tensor.dtype
+
+    # Incorporate the error from the previous state into the gradients.
+    bucket_index = bucket.index()
+    input_tensor_cp = None
+    total_length = input_tensor.shape[0]
+    if state.use_error_feedback:
+        if bucket_index in state.error_dict:
+            input_tensor.add_(state.error_dict[bucket_index])
+        else:
+            logger.info(
+                "A zero tensor of length %s that represents local error is created.",
+                total_length,
+            )
+            state.error_dict[bucket_index] = torch.zeros(
+                total_length, device=device, dtype=dtype
+            )
+
+        # Keep a copy of the input tensor,
+        # so that we can compute the local error caused by compression later,
+        # by comparing this copy and the input tensor updated after decompression.
+        input_tensor_cp = input_tensor.detach().clone()
+
+    # Unflatten the input tensor into per-parameter tensors, for layer-wise compression.
+    tensors = bucket.gradients()
+
+    # Step I: Divide all the tensors into two groups,
+    # one will be compressed before allreduce and the other will be directly allreduced without compression.
+    tensors_to_compress, uncompressed_tensors = [], []
+    total_Ps_size = 0
+    total_Qs_size = 0
+    for tensor in tensors:
+        matrix = tensor.view(tensor.shape[0], -1)
+        n, m = matrix.shape
+        matrix_approximation_rank = min(n, m, state.matrix_approximation_rank)
+        compress_test = _should_compress(
+            n, m, matrix_approximation_rank, state.min_compression_rate
+        )
+        state.total_numel_before_compression += compress_test[1]
+        if compress_test[0]:
+            tensors_to_compress.append(matrix)
+            total_Ps_size += n * matrix_approximation_rank
+            total_Qs_size += m * matrix_approximation_rank
+            state.total_numel_after_compression += compress_test[2]
+        else:
+            uncompressed_tensors.append(tensor)
+            state.total_numel_after_compression += compress_test[1]
+
+    _report_compression_stats(bucket, state)
+
+    # Step II: Handle uncompressed tensors.
+    # Allocate contiguous memory for these tensors to allreduce efficiently.
+    uncompressed_tensors_memory = (
+        torch.cat([tensor.view(-1) for tensor in uncompressed_tensors])
+        if uncompressed_tensors
+        else torch.tensor([], device=device, dtype=dtype)
+    )
+
+    # Step III: Handle the tensors that should be compressed.
+    # Allocate contiguous memory for Ps and Qs to allreduce efficiently.
+    # If warm-start is enabled, reuse Ps and Qs from the previous iteration if possible.
+    # The memory spaces of Ps and Qs need to be allocated in the first iteration when PowerSGD is applied.
+    need_randomize_qs = False
+    if not state.warm_start or bucket_index not in state.p_memory_dict:
+        need_randomize_qs = True
+        # If warm-start is disabled, low-rank tensors will be initialized at every step.
+        # Only log this if warm-start to avoid spamming.
+        if state.warm_start:
+            logger.info(
+                "Allocating contiguous memory of length %s for Ps, and of length %s for Qs, respectively.",
+                total_Ps_size,
+                total_Qs_size,
+            )
+        state.p_memory_dict[bucket_index] = torch.empty(
+            total_Ps_size, device=device, dtype=dtype
+        )
+        state.q_memory_dict[bucket_index] = torch.empty(
+            total_Qs_size, device=device, dtype=dtype
+        )
+
+    # Batch tensors to compress by shape.
+    shape_to_tensors = defaultdict(list)
+    for tensor in tensors_to_compress:
+        shape_to_tensors[tensor.shape].append(tensor)
+
+    # This function decides whether to batch tensors with same shape or not according to the argument,
+    # so the following process could share the same code.
+    def maybe_batched_tensors_to_compress():
+        for tensors in shape_to_tensors.values():
+            if state.batch_tensors_with_same_shape:
+                batch_size = len(tensors)
+                if batch_size == 1:
+                    # Use the original tensor to avoid copy.
+                    yield tensors[0].unsqueeze(0)
+                else:
+                    yield torch.stack(tensors)
+            else:
+                for tensor in tensors:
+                    yield tensor.unsqueeze(0)
+
+    # Create Ps and Qs that point to the allocated memory.
+    tensors_to_compress = []
+    ps = []
+    qs = []
+    p_idx = 0
+    q_idx = 0
+    for tensor in maybe_batched_tensors_to_compress():
+        batch_size, n, m = tensor.shape
+        matrix_approximation_rank = min(n, m, state.matrix_approximation_rank)
+        tensors_to_compress.append(tensor)
+        ps.append(
+            state.p_memory_dict[bucket_index][
+                p_idx : p_idx + batch_size * n * matrix_approximation_rank
+            ].view(batch_size, n, matrix_approximation_rank)
+        )
+        qs.append(
+            state.q_memory_dict[bucket_index][
+                q_idx : q_idx + batch_size * m * matrix_approximation_rank
+            ].view(batch_size, m, matrix_approximation_rank)
+        )
+        p_idx += batch_size * n * matrix_approximation_rank
+        q_idx += batch_size * m * matrix_approximation_rank
+
+    # If warm-start is enabled, reuse Qs from the previous iteration if possible and skip filling random values.
+    # The exception is the first iteration when PowerSGD is applied.
+    if not need_randomize_qs:
+        for q in qs:
+            _orthogonalize(q, state.orthogonalization_epsilon)
+    else:
+        with torch.random.fork_rng(devices=[]):
+            # Fork this RNG to avoid changing the seed globally and affecting the random sampling anywhere else in the training.
+            # The seed makes sure that the initial random values are the same across all the DDP replicas.
+            # This seed should differ at every step.
+            # Since it is very slow to fork RNG state across all the CUDA devices,
+            # only fork on CPU and then move the generated tensor to the CUDA device (by overwriting q).
+            torch.manual_seed(state.rng.randint(1_000_000_000))
+            for q in qs:
+                q.copy_(
+                    torch.randn(
+                        *q.shape,
+                        device="cpu",
+                        dtype=dtype,
+                    )
+                )
+                _orthogonalize(q, state.orthogonalization_epsilon)
+
+    # Compute Ps.
+    for tensor, q, p in zip(tensors_to_compress, qs, ps):
+        torch.bmm(tensor, q, out=p)
+
+    # This allreduce is only applied to uncompressed tensors,
+    # so it should have been kicked off before the above computation on the compressed tensors to hide more communication costs.
+    # However, this somehow requires a separate future chain at this time.
+    allreduce_contiguous_uncompressed_tensors_fut = dist.all_reduce(
+        uncompressed_tensors_memory, group=group_to_use, async_op=True
+    ).get_future()
+
+    def unpack_uncompressed_tensors_and_allreduce_ps(fut):
+        uncompressed_tensors_memory = fut.value()[0].div_(world_size)
+        idx = 0
+        for tensor in uncompressed_tensors:
+            tensor.copy_(
+                uncompressed_tensors_memory[idx : idx + tensor.numel()].view_as(tensor)
+            )
+            idx += tensor.numel()
+
+        # Since these Ps will be orthogonalized later, no need to divide them by world size.
+        return (
+            dist.all_reduce(
+                state.p_memory_dict[bucket_index], group=group_to_use, async_op=True
+            )
+            .get_future()
+            .wait()[0]
+        )
+
+    def compute_qs(fut):
+        state.p_memory_dict[bucket_index] = fut.value()
+        for p in ps:
+            _orthogonalize(p, state.orthogonalization_epsilon)
+
+        # Compute Qs.
+        for tensor, p, q in zip(tensors_to_compress, ps, qs):
+            torch.bmm(tensor.transpose(1, 2), p, out=q)
+
+        # TODO: The above procedure does two matmul+allreduce steps per iteration --
+        # one left multiplication and one right multiplication.
+        # For warm-start, can take one such step at a time, and alternate between them.
+
+        # Allreduce Qs.
+        return (
+            dist.all_reduce(
+                state.q_memory_dict[bucket_index], group=group_to_use, async_op=True
+            )
+            .get_future()
+            .wait()[0]
+        )
+
+    def decompress(fut):
+        state.q_memory_dict[bucket_index] = fut.value().div_(world_size)
+
+        for p, q, tensor in zip(ps, qs, tensors_to_compress):
+            torch.bmm(p, q.transpose(1, 2), out=tensor)
+
+        # Copy batched tensors back to original buffer.
+        if state.batch_tensors_with_same_shape:
+            for tensor in tensors_to_compress:
+                if tensor.shape[0] == 1:
+                    # Skip tensor with batch_size == 1 since itself is the original tensor.
+                    continue
+                original_tensors = shape_to_tensors[tensor.shape[1:]]
+                for i, original_tensor in enumerate(original_tensors):
+                    original_tensor.copy_(tensor[i])
+
+        if torch.cuda.is_available():
+            torch.cuda.synchronize(device)
+
+        if state.use_error_feedback:
+            # Memorize the local errors.
+            assert input_tensor_cp is not None
+            state.error_dict[bucket_index] = input_tensor_cp - input_tensor
+        if not state.warm_start:
+            state.p_memory_dict.clear()
+            state.q_memory_dict.clear()
+
+        state.maybe_increase_iter(bucket)
+
+        return input_tensor
+
+    return (
+        allreduce_contiguous_uncompressed_tensors_fut.then(
+            unpack_uncompressed_tensors_and_allreduce_ps
+        )
+        .then(compute_qs)
+        .then(decompress)
+    )
+
+
+def batched_powerSGD_hook(
+    state: PowerSGDState, bucket: dist.GradBucket
+) -> torch.futures.Future[torch.Tensor]:
+    r"""
+    Implement simplified PowerSGD algorithm.
+
+    This DDP communication hook implements a simplified PowerSGD gradient compression
+    algorithm described in the `paper `_.
+    This variant does not compress the gradients layer by layer,
+    but instead compresses the flattened input tensor that batches all the gradients.
+    Therefore, it is **faster** than :meth:`powerSGD_hook`,
+    but usually results in a **much lower accuracy**, unless ``matrix_approximation_rank`` is 1.
+
+    .. warning ::
+        Increasing ``matrix_approximation_rank`` here may not necessarily increase the accuracy,
+        because batching per-parameter tensors without column/row alignment can destroy low-rank structure.
+        Therefore, the user should always consider :meth:`powerSGD_hook` first,
+        and only consider this variant when a satisfactory accuracy can be achieved when ``matrix_approximation_rank`` is 1.
+
+    Once gradient tensors are aggregated across all workers, this hook applies
+    compression as follows:
+
+    1. Views the input flattened 1D gradient tensor as a square-shaped tensor M with 0 paddings;
+
+    2. Creates two low-rank tensors P and Q for decomposing M, such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized;
+
+    3. Computes P, which is equal to MQ;
+
+    4. Allreduces P;
+
+    5. Orthogonalizes P;
+
+    6. Computes Q, which is approximately equal to M^TP;
+
+    7. Allreduces Q;
+
+    8. Computes M, which is approximately equal to PQ^T.
+
+    9. Truncates the input tensor to the original length.
+
+    Note that this communication hook enforces vanilla allreduce for the first ``state.start_powerSGD_iter`` iterations.
+    This not only gives the user more control over the tradeoff between speedup and accuracy,
+    but also helps abstract away some complexity of the internal optimization of DDP for future communication hook developers.
+
+    Args:
+        state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc.
+            To tune the compression configs, mainly need to tune ``matrix_approximation_rank`` and ``start_powerSGD_iter``.
+        bucket (dist.GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
+            Note that since DDP comm hook only supports single process single device mode,
+            only exactly one tensor is stored in this bucket.
+
+    Returns:
+        Future handler of the communication, which updates the gradients in place.
+
+    Example::
+        >>> # xdoctest: +SKIP
+        >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1)
+        >>> ddp_model.register_comm_hook(state, batched_powerSGD_hook)
+    """  # noqa: B950
+    process_group = state.process_group
+    group_to_use = (
+        process_group if process_group is not None else not_none(dist.group.WORLD)
+    )
+    world_size = group_to_use.size()
+
+    # The input tensor is a flattened 1D tensor.
+    input_tensor = bucket.buffer()
+
+    # Run vanilla allreduce in the first `start_powerSGD_iter` iterations.
+    if state.iter < state.start_powerSGD_iter:
+        state.maybe_increase_iter(bucket)
+        return default._allreduce_fut(group_to_use, input_tensor)
+
+    # Apply PowerSGD after `start_powerSGD_iter` iterations.
+    device = input_tensor.device
+    total_length = input_tensor.shape[0]
+    state.total_numel_before_compression += total_length
+
+    # View the input tensor as a 2D square-shape tensor, and pad 0s if necessary.
+    square_side_length = math.ceil(math.sqrt(total_length))
+    state.total_numel_after_compression += (
+        square_side_length * state.matrix_approximation_rank * 2
+    )
+    padded_total_length = square_side_length**2
+    input_tensor.resize_(padded_total_length)
+    input_tensor[total_length:padded_total_length].fill_(0)
+
+    _report_compression_stats(bucket, state)
+
+    # Incorporate the error from the previous state into the gradients.
+    bucket_index = bucket.index()
+    input_tensor_cp = None
+    if state.use_error_feedback:
+        if bucket_index in state.error_dict:
+            input_tensor.add_(state.error_dict[bucket_index])
+        else:
+            logger.info(
+                "A zero tensor of length %s that represents local error is created.",
+                padded_total_length,
+            )
+            state.error_dict[bucket_index] = torch.zeros(
+                padded_total_length, device=device, dtype=input_tensor.dtype
+            )
+
+        # Keep a copy of the input tensor,
+        # so that we can compute the local error caused by compression later,
+        # by comparing this copy and the input tensor updated after decompression.
+        input_tensor_cp = input_tensor.detach().clone()
+    matrix = input_tensor.view(square_side_length, square_side_length)
+
+    # Reuse P and Q from the previous iteration if possible.
+    # The memory spaces of P and Q need to be allocated in the first iteration when PowerSGD is applied.
+    if not state.warm_start or bucket_index not in state.p_memory_dict:
+        # If warm-start is disabled, low-rank tensors will be initialized at every step.
+        # Only log this if warm-start to avoid spamming.
+        if state.warm_start:
+            logger.info(
+                "Initializing low-rank tensors P and Q, each of which has a shape of %s x %s.",
+                square_side_length,
+                state.matrix_approximation_rank,
+            )
+
+        def create_low_rank_tensor(fill_random_values, rng):
+            """Return a low-rank 2D tensor of square_side_length * matrix_approximation_rank."""
+            if fill_random_values:
+                with torch.random.fork_rng(devices=[]):
+                    # Fork this RNG to avoid changing the seed globally and affecting the random sampling
+                    # anywhere else in the training.
+                    # The seed makes sure that the initial random values are the same across all the DDP replicas.
+                    # This seed should differ at every step.
+                    # Since it is very slow to fork RNG state across all the CUDA devices,
+                    # only fork on CPU and then move the generated tensor to the CUDA device.
+                    torch.manual_seed(rng.randint(1_000_000_000))
+                    return torch.randn(
+                        square_side_length,
+                        state.matrix_approximation_rank,
+                        device="cpu",
+                        dtype=input_tensor.dtype,
+                    ).to(device)
+            else:
+                return torch.empty(
+                    square_side_length,
+                    state.matrix_approximation_rank,
+                    device=device,
+                    dtype=input_tensor.dtype,
+                )
+
+        state.p_memory_dict[bucket_index] = create_low_rank_tensor(
+            fill_random_values=False, rng=state.rng
+        )
+        state.q_memory_dict[bucket_index] = create_low_rank_tensor(
+            fill_random_values=True, rng=state.rng
+        )
+    _orthogonalize(state.q_memory_dict[bucket_index])
+
+    torch.matmul(
+        matrix, state.q_memory_dict[bucket_index], out=state.p_memory_dict[bucket_index]
+    )
+    allreduce_p_fut = dist.all_reduce(
+        state.p_memory_dict[bucket_index], group=group_to_use, async_op=True
+    ).get_future()
+
+    def compute_q(fut):
+        state.p_memory_dict[bucket_index] = fut.value()[0]
+        _orthogonalize(state.p_memory_dict[bucket_index])
+
+        torch.matmul(
+            matrix.t(),
+            state.p_memory_dict[bucket_index],
+            out=state.q_memory_dict[bucket_index],
+        )
+
+        # TODO: The above procedure does two matmul+allreduce steps per iteration --
+        # one left multiplication and one right multiplication.
+        # For warm-start, can take one such step at a time, and alternate between them.
+
+        return (
+            dist.all_reduce(
+                state.q_memory_dict[bucket_index], group=group_to_use, async_op=True
+            )
+            .get_future()
+            .wait()[0]
+        )
+
+    def decompress(fut):
+        state.q_memory_dict[bucket_index] = fut.value().div_(world_size)
+        torch.matmul(
+            state.p_memory_dict[bucket_index],
+            state.q_memory_dict[bucket_index].t(),
+            out=matrix,
+        )
+
+        if state.use_error_feedback:
+            # Memorize the local errors.
+            assert input_tensor_cp is not None
+            state.error_dict[bucket_index] = input_tensor_cp - input_tensor
+        # Removing this seemingly unnecessary sync somehow may cause failures.
+        # See: https://github.com/pytorch/pytorch/pull/54838
+        if torch.cuda.is_available():
+            torch.cuda.synchronize(device)
+        if not state.warm_start:
+            state.p_memory_dict.clear()
+            state.q_memory_dict.clear()
+        ret = input_tensor.resize_(total_length)
+
+        state.maybe_increase_iter(bucket)
+
+        return ret
+
+    return allreduce_p_fut.then(compute_q).then(decompress)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py
new file mode 100644
index 0000000000000000000000000000000000000000..886155908e1a702972184a550082c33c677eacdc
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py
@@ -0,0 +1,220 @@
+# mypy: allow-untyped-defs
+import torch
+import torch.distributed as dist
+from torch import nn
+
+
+def _quantize_per_tensor_backend(x, scale, zero_point):
+    y = torch.round(x / scale) + zero_point
+    y = torch.clamp(y, 0, 255).to(torch.uint8)
+    return y
+
+
+def _dequantize_per_tensor_backend(y, scale, zero_point):
+    x = scale * (y.to(torch.float32) - zero_point)
+    return x
+
+
+def _quantize_per_channel_backend(x, scale, zero_point):
+    y = torch.zeros(x.size(), device=x.device)
+    for i in range(x.size()[0]):
+        y[i, :] = torch.round(x[i, :] / scale[i]) + zero_point[i]
+    y = torch.clamp(y, 0, 255).to(torch.uint8)
+    return y
+
+
+def _dequantize_per_channel_backend(y, scale, zero_point):
+    y = y.to(torch.float32).to(y.device)
+    x = torch.zeros_like(y, device=y.device)
+    for i in range(x.size()[0]):
+        x[i, :] = scale[i] * (y[i, :] - zero_point[i])
+    return x
+
+
+def _get_allgather_out_list(all_gather_in_list, world_size):
+    out_list = [
+        torch.zeros_like(
+            all_gather_in_list,
+            device=all_gather_in_list.device,
+            dtype=all_gather_in_list.dtype,
+        )
+        for _ in range(world_size)
+    ]
+    return out_list
+
+
+def quantization_pertensor_hook(
+    process_group: dist.ProcessGroup, bucket: dist.GradBucket
+) -> torch.futures.Future[torch.Tensor]:
+    """
+    Apply ``torch.quantize_per_tensor`` logic to DDP using ``allgather`` protocol.
+
+    Workers first allgather the scale and zero point of their own
+    ``GradBucket`` prior to the quantization. After all workers have that information,
+    the first ``then`` callback called ``quantize_and_allgather`` quantizes worker's
+    own gradient tensor, and uses ``allgather`` to communicate these across all workers.
+    The final ``then`` callback called ``dequantize_and_aggregate``, dequantizes and
+    aggregates each quantized gradient tensor locally and returns the mean.
+
+    .. warning ::
+        This is experimental, and uses ``allgather`` protocol which is considerably slower than
+        ``allreduce`` protocol. It works only with flattened grads.
+
+    Example::
+        >>> # xdoctest: +SKIP
+        >>> ddp_model.register_comm_hook(process_group, quantization_pertensor_hook)
+    """
+    group_to_use = process_group if process_group is not None else dist.group.WORLD
+    rank = process_group.rank() if process_group is not None else dist.get_rank()
+    # pyrefly: ignore [missing-attribute]
+    world_size = group_to_use.size()
+
+    tensor = bucket.buffer()
+
+    myObserver = torch.ao.quantization.MinMaxObserver().to(tensor.device)
+    myObserver(tensor)
+
+    s, z = myObserver.calculate_qparams()
+    s_and_z = torch.FloatTensor([s, z]).to(tensor.device)
+
+    all_ranks_s_and_z = _get_allgather_out_list(s_and_z, world_size)
+
+    # First, allgather scale and zeros.
+    fut = dist.all_gather(
+        all_ranks_s_and_z, s_and_z, group=group_to_use, async_op=True
+    ).get_future()
+
+    def quantize_and_allgather(fut):
+        # Store scale and zeros across all workers.
+        all_ranks_s_and_z = fut.wait()[0]
+        # All workers quantize their own ``GradBucket`` tensors.
+        quantized_tensor = _quantize_per_tensor_backend(
+            tensor, all_ranks_s_and_z[rank][0], all_ranks_s_and_z[rank][1]
+        )
+        # Allgather quantized tensors.
+        fut = dist.all_gather(
+            _get_allgather_out_list(quantized_tensor, world_size),
+            quantized_tensor,
+            group=group_to_use,
+            async_op=True,
+        ).get_future()
+
+        return fut.wait()
+
+    def dequantize_and_aggregate(fut):
+        all_ranks_quantized_tensor = fut.wait()[0]
+
+        aggregated_dequantized_tensor = torch.zeros_like(
+            all_ranks_quantized_tensor[0], device=tensor.device, dtype=torch.float32
+        )
+        # Using previously allgathered scales and zeros, dequantize gradient tensors
+        # locally and then aggregate them.
+        for r, quantized_tensor in enumerate(all_ranks_quantized_tensor):
+            aggregated_dequantized_tensor += _dequantize_per_tensor_backend(
+                quantized_tensor, all_ranks_s_and_z[r][0], all_ranks_s_and_z[r][1]
+            )
+
+        return aggregated_dequantized_tensor / world_size
+
+    return fut.then(quantize_and_allgather).then(dequantize_and_aggregate)
+
+
+def quantization_perchannel_hook(
+    process_group: dist.ProcessGroup, bucket: dist.GradBucket, bucket_size=512
+) -> torch.futures.Future[torch.Tensor]:
+    """
+    Apply``torch.quantize_per_channel`` logic to DDP using ``allgather`` protocol.
+
+    Compared to per-tensor, the main motivation of per-channel is
+    for considerably large tensors such as a tensor that contains 6 million
+    elements quantizing per a bucket size of 512 (or 128) elements may significantly
+    increase the resolution.
+
+    It first splits ``GradBucket`` tensor into multiple chunks (channels) of ``bucket_size``
+    elements. Then, workers allgather the scales and zero points of their own
+    ``GradBucket`` prior to the quantization. After all workers have that information,
+    the first ``then`` callback called ``quantize_and_allgather`` quantizes worker's
+    own gradient tensor, and uses ``allgather`` to communicate these across all workers.
+    The final ``then`` callback called ``dequantize_and_aggregate``, dequantizes, flattens, and
+    aggregates each quantized gradient tensor locally and returns the mean.
+
+    .. warning ::
+        This is experimental, and uses ``allgather`` protocol which is considerably slower than
+        ``allreduce`` protocol. It works only with flattened grads.
+
+    Example::
+        >>> # xdoctest: +SKIP
+        >>> ddp_model.register_comm_hook(process_group, quantization_perchannel_hook)
+    """
+    group_to_use = process_group if process_group is not None else dist.group.WORLD
+    rank = process_group.rank() if process_group is not None else dist.get_rank()
+    # pyrefly: ignore [missing-attribute]
+    world_size = group_to_use.size()
+
+    tensor = bucket.buffer()
+
+    tensor_in_channels = (
+        nn.functional.pad(
+            input=tensor,
+            pad=(0, bucket_size - len(tensor) % bucket_size),
+            mode="constant",
+            value=0,
+        )
+        .view(-1, bucket_size)
+        .to(tensor.device)
+    )
+
+    myPerChannelObserver = torch.ao.quantization.PerChannelMinMaxObserver().to(
+        tensor.device
+    )
+    myPerChannelObserver(tensor_in_channels)
+
+    s_ch, z_ch = myPerChannelObserver.calculate_qparams()
+    s_and_z = torch.stack((s_ch, z_ch)).to(tensor.device)
+
+    all_ranks_s_and_z = _get_allgather_out_list(s_and_z, world_size)
+    # First, allgather scale and zeros.
+    fut = dist.all_gather(
+        all_ranks_s_and_z, s_and_z, group=group_to_use, async_op=True
+    ).get_future()
+
+    def quantize_and_allgather(fut):
+        # Store scale and zeros across all workers.
+        all_ranks_s_and_z = fut.wait()[0]
+        # All workers quantize their corresponding ``GradBucket`` tensors.
+        quantized_tensor = _quantize_per_channel_backend(
+            tensor_in_channels,
+            all_ranks_s_and_z[rank, 0, :],
+            all_ranks_s_and_z[rank, 1, :],
+        )
+        # Allgather quantized tensors.
+        fut = dist.all_gather(
+            _get_allgather_out_list(quantized_tensor, world_size),
+            quantized_tensor,
+            group=group_to_use,
+            async_op=True,
+        ).get_future()
+
+        return fut.wait()
+
+    def dequantize_and_aggregate(fut):
+        all_ranks_quantized_tensor = fut.wait()[0]
+
+        aggregated_dequantized_tensor = torch.zeros_like(
+            all_ranks_quantized_tensor[0], device=tensor.device, dtype=torch.float32
+        )
+        # Using previously allgathered scales and zeros, dequantize gradient tensors
+        # locally and then aggregate them.
+        for r, quantized_tensor in enumerate(all_ranks_quantized_tensor):
+            aggregated_dequantized_tensor += _dequantize_per_channel_backend(
+                quantized_tensor, all_ranks_s_and_z[r][0], all_ranks_s_and_z[r][1]
+            )
+
+        return (
+            torch.flatten(aggregated_dequantized_tensor).to(tensor.device)[
+                : tensor.size()[0]
+            ]
+            / world_size
+        )
+
+    return fut.then(quantize_and_allgather).then(dequantize_and_aggregate)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8f961b3e21f5f598729777596e87695df06d0fae
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/averagers.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/averagers.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..333b6939a9e9346fa793cded91cb1e3677081483
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/averagers.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/hierarchical_model_averager.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/hierarchical_model_averager.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..296989aeebc4e90e12661974a60d9dca980db7cc
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/hierarchical_model_averager.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e72d4385974933d001a14c3b4c0b1e5e260a5075
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/utils.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/averagers.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/averagers.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d669d4ea592250556ed5188b21ae265bb3b2c9c
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/averagers.py
@@ -0,0 +1,128 @@
+# mypy: allow-untyped-defs
+import warnings
+from abc import ABC, abstractmethod
+from collections.abc import Iterable
+
+import torch
+import torch.distributed as dist
+import torch.distributed.algorithms.model_averaging.utils as utils
+from torch.utils._typing_utils import not_none as _not_none
+
+
+__all__ = ["ModelAverager", "PeriodicModelAverager"]
+
+
+class ModelAverager(ABC):
+    r"""Base class for all model averagers.
+
+    Args:
+        process_group: The process group to be used for all-reduce.
+                       If ``None``, the default process group, which
+                       is created by :func:`torch.distributed.init_process_group`,
+                       will be used. (default: ``None``)
+    """
+
+    def __init__(self, process_group: dist.ProcessGroup | None = None):
+        self.process_group = (
+            process_group if process_group is not None else _not_none(dist.group.WORLD)
+        )
+        self.step = 0
+
+    @abstractmethod
+    def average_parameters(self, params):
+        raise NotImplementedError
+
+
+class PeriodicModelAverager(ModelAverager):
+    r"""
+    Averages parameters periodically after the warm-up stage.
+
+    This can be used for running `post-local SGD `_,
+    by running :class:`~torch.nn.DistributedDataParallel` (DDP)
+    using the subgroups created by :meth:`~torch.distributed.new_subgroups`.
+
+    Args:
+        period (int): The number of steps per model averaging.
+                      Usually the period should be greater than ``1`` to reduce the communication cost.
+                      Otherwise, only DDP needs to be used.
+        warmup_steps (int): The number of warm-up steps. During this stage,
+                            model averaging is skipped.
+        process_group: The process group to be used for all-reduce.
+                       If ``None``, the default process group, which
+                       is created by :func:`torch.distributed.init_process_group`,
+                       will be used. (default: ``None``)
+
+    Example::
+
+        >>> # xdoctest: +SKIP("undefined variables")
+        >>> import torch
+        >>> import torch.distributed as dist
+        >>> import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
+        >>> import torch.distributed.algorithms.model_averaging.averagers as averagers
+        >>> import torch.nn as nn
+        >>>
+        >>> dist.init_process_group("nccl", rank=rank, world_size=16)
+        >>> torch.cuda.set_device(rank)
+        >>> module = nn.Linear(1, 1, bias=False).cuda()
+        >>> model = nn.parallel.DistributedDataParallel(
+        >>>    module, device_ids=[rank], output_device=rank
+        >>> )
+        >>> # Register a post-localSGD communication hook.
+        >>> state = PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=100)
+        >>> model.register_comm_hook(state, post_localSGD_hook)
+        >>>
+        >>> # In the first 100 steps, run global gradient averaging like normal DDP at every step.
+        >>> # After 100 steps, run model averaging every 4 steps.
+        >>> # Note that ``warmup_steps`` must be the same as ``start_localSGD_iter`` used in ``PostLocalSGDState``.
+        >>> averager = averagers.PeriodicModelAverager(period=4, warmup_steps=100)
+        >>> for step in range(0, 200):
+        >>>    optimizer.zero_grad()
+        >>>    loss = loss_fn(output, labels)
+        >>>    loss.backward()
+        >>>    optimizer.step()
+        >>>    # Will average model parameters globally every 4 steps. Thus,
+        >>>    # inter-node communication only occurs every 4 iterations after
+        >>>    # the initial ``warmup_steps`` period.
+        >>>    averager.average_parameters(model.parameters())
+    """
+
+    def __init__(
+        self, period, warmup_steps=0, process_group: dist.ProcessGroup | None = None
+    ):
+        super().__init__(process_group)
+        if warmup_steps < 0:
+            raise ValueError("Arg ``warmup_steps`` must be a non-negative number.")
+        self.warmup_steps = warmup_steps
+        if period < 1:
+            raise ValueError("Arg ``period`` must be a positive value.")
+        elif period == 1:
+            warnings.warn(
+                "When period is 1, no need to use model averaging because the communication cost "
+                "of all-reducing parameters will be no less than the cost of all-reducing gradients "
+                "by DistributedDataParallel in the backward pass. Therefore, only "
+                "DistributedDataParallel should be used for this case.",
+                stacklevel=2,
+            )
+        self.period = period
+
+    def average_parameters(
+        self,
+        params: Iterable[torch.nn.Parameter] | Iterable[dict[str, torch.nn.Parameter]],
+    ):
+        """
+        Averages parameters or parameter groups of an optimizer if ``step`` is no less than ``warmup_steps``.
+
+        Can be divided by ``period``, where ``step`` is increased by 1
+        at each iteration in the training loop.
+        Args:
+            params: The parameters of a model or parameter groups of an optimizer.
+
+        """
+        if (
+            self.step >= self.warmup_steps
+            and (self.step - self.warmup_steps) % self.period == 0
+        ):
+            utils.average_parameters_or_parameter_groups(
+                params, _not_none(self.process_group)
+            )
+        self.step += 1
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f7edc447d1089e2c09ba10764bc0fbfce9a1770
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py
@@ -0,0 +1,179 @@
+# mypy: allow-untyped-defs
+# Copyright 2022 Cruise LLC
+import logging
+import warnings
+from collections import OrderedDict
+from collections.abc import Iterable
+
+import torch
+import torch.distributed as dist
+import torch.distributed.algorithms.model_averaging.averagers as averagers
+import torch.distributed.algorithms.model_averaging.utils as utils
+
+
+logger = logging.getLogger(__name__)
+
+
+class HierarchicalModelAverager(averagers.ModelAverager):
+    r"""
+    Runs hierarchical model averaging (`hierarchical SGD `_).
+
+    Process groups of different sizes are organized in a hierarchy, and they average parameters
+    by using different periods concurrently after the warm-up stage.
+    This is an extension of :class:`~torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager`
+    that supports `post-local SGD `_, which essentially only supports
+    a two-level hierarchy: the intra-machine level and the global level, where the intra-machine
+    level is usually embedded in :meth:`~torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook`.
+    Similarly, the process groups within this class do not have such an intra-machine process
+    subgroup, which should be embedded by the post-local SGD communication hook instead.
+
+    Args:
+        period_group_size_dict: An ordered dict mapping keys of model averaging period to
+                                process group size, used for initializing process groups of
+                                different sizes in a hierarchy to average parameters concurrently.
+                                Particularly, at each iteration, there will be at most a single
+                                process group that runs averaging -- the period of such group should
+                                have the largest period which the current step can be divided by.
+                                For example, if the dict has three keys: 2, 4, and 8,
+                                then this means totally three process groups will be created to
+                                average parameters every 2, 4, and 8 iterations, respectively.
+                                At the 4th iteration, only the second process group will run
+                                averaging, because the first process group should be a
+                                subset of the second process group, and no need to execute the first
+                                process group redundantly.
+                                On the other hand, the third process group can only be triggered
+                                every 8 iterations, so it will not be triggered at the 4th iteration.
+        warmup_steps (int): The number of warm-up steps. During this stage, model averaging is skipped.
+        process_group (ProcessGroup, optional): The overall process group containing all the processes that runs model averaging.
+                                                If ``None``, the default process group, which is created
+                                                by :func:`torch.distributed.init_process_group`, will be used.
+                                                (default: ``None``)
+
+    Example::
+        >>> # xdoctest: +SKIP('undefined rank')
+        >>> from collections import OrderedDict
+        >>> import torch
+        >>> import torch.distributed as dist
+        >>> from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import (
+        >>>     PostLocalSGDState,
+        >>>     post_localSGD_hook,
+        >>> )
+        >>> import torch.distributed.algorithms.model_averaging.hierarchical_model_averager as hierarchicalSGD
+        >>> import torch.nn as nn
+        >>>
+        >>> dist.init_process_group("nccl", rank=rank, world_size=16)
+        >>> torch.cuda.set_device(rank)
+        >>> module = nn.Linear(1, 1, bias=False).to(rank)
+        >>> model = nn.parallel.DistributedDataParallel(
+        >>>    module, device_ids=[rank], output_device=rank
+        >>> )
+        >>> # Register a post-localSGD communication hook.
+        >>> # Assume that each machine has 4 GPUs, then each intra-machine subgroup has a size of 4.
+        >>> subgroup, _ = dist.new_subgroups()
+        >>> state = PostLocalSGDState(process_group=None, subgroup=subgroup, start_localSGD_iter=100)
+        >>> model.register_comm_hook(state, post_localSGD_hook)
+        >>>
+        >>> # Average parameters among each group of 8 processes every 4 iterations, and among all
+        >>> # the 16 processes every 16 iterations.
+        >>> averager = hierarchicalSGD.HierarchicalModelAverager(
+        >>>     period_group_size_dict=OrderedDict([(4, 8), (16, 16)]), warmup_steps=100)
+        >>> # Note that ``warmup_steps`` must be the same as ``start_localSGD_iter`` used in ``PostLocalSGDState``.
+        >>> # In the first 100 steps, run global gradient averaging like normal DDP at every step.
+        >>> # After 100 steps, run model averaging at two levels.
+        >>> for step in range(0, 200):
+        >>>    optimizer.zero_grad()
+        >>>    loss = loss_fn(output, labels)
+        >>>    loss.backward()
+        >>>    optimizer.step()
+        >>>    # Average parameters after ``optimizer.step()``.
+        >>>    # Thus, the inter-node communication only occurs periodically after ``warmup_steps``.
+        >>>    averager.average_parameters(model.parameters())
+
+    .. warning ::
+        The last group size in the dict must be the size of the provided ``process_group``,
+        which indicates model averaging at the highest level of the hierarchy.
+        If ``process_group`` is not provided, then the last group size should be equal to the world size.
+
+    .. warning ::
+        `HierarchicalModelAverager` is experimental and subject to change.
+    """
+
+    def __init__(self, period_group_size_dict=None, warmup_steps=0, process_group=None):
+        super().__init__(process_group)
+        if not period_group_size_dict:
+            raise ValueError("Arg ``period_group_size_dict`` must not be empty.")
+        self._periods = list(period_group_size_dict.keys())
+        if self._periods[0] <= 0:
+            raise ValueError(
+                "The minimum period in arg ``period_group_size_dict`` must be a positive value."
+            )
+        elif self._periods[-1] == 1:
+            warnings.warn(
+                "When the maximum period in arg ``period_group_size_dict`` is 1, "
+                "no need to use model averaging because the communication cost "
+                "of all-reducing parameters will be no less than the cost of all-reducing gradients "
+                "by DistributedDataParallel in the backward pass. Therefore, only "
+                "DistributedDataParallel should be used for this case.",
+                stacklevel=2,
+            )
+        overall_group_size = dist.get_world_size(group=self.process_group)
+        if list(period_group_size_dict.values())[-1] != overall_group_size:
+            raise ValueError(
+                f"The last value in arg ``period_process_group_dict`` {list(period_group_size_dict.values())[-1]} "
+                f"must be equal to the size of arg ``process_group`` {overall_group_size}."
+            )
+
+        self.period_process_group_dict = OrderedDict()
+        logger.info("Model averaging hierarchy:")
+        for period, group_size in period_group_size_dict.items():
+            logger.info(
+                "\tEach group that has %s processes average parameters every %s iterations, "
+                "if no higher-level averaging.",
+                group_size,
+                period,
+            )
+            if group_size != overall_group_size:
+                self.period_process_group_dict[period], _ = dist.new_subgroups(
+                    group_size=group_size, group=self.process_group
+                )
+            else:
+                self.period_process_group_dict[period] = self.process_group
+
+        if warmup_steps < 0:
+            raise ValueError("Arg ``warmup_steps`` must be a non-negative number.")
+        self.warmup_steps = warmup_steps
+
+    def _find_process_group(self):
+        """
+        Return a process group as the value of an ``period_process_group_dict`` entry.
+
+        If ``step`` can be divided by multiple periods in the keys of ``period_process_group_dict``,
+        then the returned process group is the one corresponding to the largest period,
+        since this process group will be used for averaging parameters at this ``step``.
+        Returns ``None`` if not found.
+        """
+        for period in reversed(self._periods):
+            if self.step % period == 0:
+                return self.period_process_group_dict[period]
+        return None
+
+    def average_parameters(
+        self,
+        params: Iterable[torch.nn.Parameter] | Iterable[dict[str, torch.nn.Parameter]],
+    ):
+        """
+        Averages parameters or parameter groups of an optimizer.
+
+        Averaging only occurs if ``step`` is no less than ``warmup_steps``
+        and it can be divided by a period in the keys of ``period_process_group_dict``,
+        where ``step`` is increased by 1 at each iteration in the training loop.
+        If ``step`` can be divided by multiple periods in the keys of ``period_process_group_dict``,
+        only the largest period is used, and the corresponding process group is used for averaging parameters.
+        Args:
+            params: The parameters of a model or parameter groups of an optimizer.
+        """
+        if self.step >= self.warmup_steps:
+            group = self._find_process_group()
+            if group is not None:
+                utils.average_parameters_or_parameter_groups(params, group)
+        self.step += 1
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a61c036913edd6cf7fbcde6b77bc6ee5970065e
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/utils.py
@@ -0,0 +1,86 @@
+# mypy: allow-untyped-defs
+import itertools
+from collections.abc import Iterable, Iterator
+
+import torch
+import torch.distributed as dist
+
+# The two imports below are not always available depending on the
+# USE_DISTRIBUTED compile flag. Make sure they raise import error
+# if we're trying to use them.
+from torch.distributed import group, ProcessGroup
+
+
+__all__ = [
+    "average_parameters",
+    "get_params_to_average",
+    "average_parameters_or_parameter_groups",
+]
+
+
+def average_parameters(
+    params: Iterator[torch.nn.Parameter], process_group: ProcessGroup
+):
+    """
+    Averages all the given parameters.
+
+    For allreduce efficiency, all the parameters are flattened into a contiguous buffer.
+    Thus, it requires extra memory of the same size as the given parameters.
+    """
+    group_to_use = process_group if process_group is not None else group.WORLD
+    # Do not update any parameter if not in the process group.
+    if dist._rank_not_in_group(group_to_use):
+        return
+
+    params_it1, params_it2 = itertools.tee(params)
+    # If the input parameters have different data types,
+    # packing these parameters will trigger an implicit type up-casting.
+    # The original parameter data types will be restored during the subsequent unpacking.
+    flat_params = torch.cat([p.data.reshape(-1) for p in params_it1])
+    flat_params /= dist.get_world_size(group_to_use)
+    # Make sure the allreduce will not conflict with any other ongoing process group.
+    if torch.accelerator.is_available():
+        torch.accelerator.synchronize()
+    dist.all_reduce(flat_params, group=group_to_use)
+
+    offset = 0
+    for p in params_it2:
+        p.data = flat_params[offset : offset + p.numel()].view_as(p).type_as(p)
+        offset += p.numel()
+
+
+def get_params_to_average(
+    params: Iterable[torch.nn.Parameter] | Iterable[dict[str, torch.nn.Parameter]],
+):
+    """
+    Return a list of parameters that need to average.
+
+    This filters out the parameters that do not contain any gradients.
+    Args:
+        params: The parameters of a model or parameter groups of an optimizer.
+    """
+    filtered_params = []
+    for param in params:
+        if isinstance(param, torch.nn.Parameter):
+            # model.parameters() input
+            param_data = param
+            if param_data.grad is not None:
+                filtered_params.append(param_data)
+        elif isinstance(param, dict):
+            # optimizer.param_groups input
+            for param_data in param["params"]:
+                if param_data.grad is not None:
+                    filtered_params.append(param_data)
+        else:
+            raise NotImplementedError(
+                f"Parameter input of type {type(param)} is not supported"
+            )
+    return filtered_params
+
+
+def average_parameters_or_parameter_groups(
+    params: Iterable[torch.nn.Parameter] | Iterable[dict[str, torch.nn.Parameter]],
+    process_group: ProcessGroup,
+):
+    """Averages parameters of a model or parameter groups of an optimizer."""
+    average_parameters(iter(get_params_to_average(params)), process_group)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..98af7be5b3edc481aa7afaa7a787e61db3ea33d3
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/__pycache__/fr_trace.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/__pycache__/fr_trace.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9e4668609c02e1316b48c32c84f636607a8863c4
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/__pycache__/fr_trace.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e2a6fb861aafd49fb98a096cdeb465f04a940efe
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__pycache__/builder.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__pycache__/builder.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8bbb8e077ebf8cb5ed361304ec53d7ba29187fe0
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__pycache__/builder.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__pycache__/config_manager.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__pycache__/config_manager.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1a5e388a4bddf68c17aff1661d66b1533c61c727
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__pycache__/config_manager.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__pycache__/fr_logger.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__pycache__/fr_logger.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..349c540198a1b2eda077f29e8beb41b5510603d8
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__pycache__/fr_logger.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__pycache__/loader.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__pycache__/loader.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6bbf5a1bf1d8e99479546461b7f0ebe1c64399e1
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__pycache__/loader.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__pycache__/types.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__pycache__/types.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4fed270a886cff158c180add28580a78a1037685
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__pycache__/types.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__pycache__/utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__pycache__/utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..56e9ead76b1bfab373deedfc2bea87738f958deb
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__pycache__/utils.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/builder.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..56736450e3f2a8decdc6dfc11c929d8a1bdfb16f
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/builder.py
@@ -0,0 +1,457 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import ast
+import copy
+import os
+import sys
+from typing import Any  # type: ignore[attr-defined]
+
+from torch.distributed.flight_recorder.components.fr_logger import FlightRecorderLogger
+from torch.distributed.flight_recorder.components.types import (
+    Collective,
+    Database,
+    EntryState,
+    Group,
+    MatchStateRecord,
+    Membership,
+    NCCLCall,
+    Op,
+    Traceback,
+)
+from torch.distributed.flight_recorder.components.utils import (
+    add_stack_id_in_entries,
+    align_trace_from_beginning,
+    check_current_entry_match,
+    check_no_missing_dump_files,
+    check_version,
+    error_analysis,
+    find_coalesced_group as find_coalesced_group_p2p_only,
+    find_coalesced_group_with_non_p2p,
+    get_version_detail,
+    just_print_entries,
+    match_coalesced_groups as match_coalesced_groups_p2p_only,
+    match_coalesced_groups_with_non_p2p,
+)
+
+
+__all__ = [
+    "build_groups_memberships",
+    "build_collectives",
+    "transform_ft",
+    "build_db",
+]
+
+# Set up logging
+logger: FlightRecorderLogger = FlightRecorderLogger()
+
+
+try:
+    from tabulate import tabulate
+except ModuleNotFoundError:
+    logger.warning("tabulate is not installed. Proceeding without it.")
+
+    # Define a no-op tabulate function
+    def tabulate(data: Any, headers: Any = None) -> Any:  # type: ignore[misc]
+        return data
+
+
+"""
+Flat DB builder
+"""
+
+
+def build_groups_memberships(
+    pg_config: Any,
+) -> tuple[
+    list[Group],
+    dict[Any, Group],
+    list[Membership],
+    dict[str, set[Any]],
+    dict[tuple[str, int], str],
+]:
+    """
+    pg_config: {
+        global_rank: {
+            (pg_guid, desc, ranks)
+        }
+    }
+
+    `pg_guid` is a system generated id, but depending on the mode of PG creation it could be a globally incrementing int
+          or a hash of the ranks.  See `_process_group_name` in distributed_c10d.py.
+    `desc` is provided by the user (optionally) and should be 'meaningful' (e.g. TP/PP/DP group)
+    `ranks` is a list of the 'global ranks' that are members of the PG.
+
+    (pg_guid, desc, ranks) tuples are appended lazily to the flight buffer when `getNCCLComm` is called on a PG and
+    the `enabled_` flag is true for that PG.
+        - the order of calling (init_process_group, new_group, etc) does not affect the order of the tuples in the list
+
+    Returns:
+        `groups`: a groups table where each row is a Group namedtuple.
+        `_groups`: a dict that is indexed by pg_guid with Group namedtuple as value.
+        `memberships`: a membership table where each row is a Membership namedtuple.
+        `_memberships`: a dict that is indexed by pg_guid with set of ranks (int) as value.
+        `_pg_guids`: a dict that is indexed by (pg_uid, global_rank) with pg_guid as value.
+    """
+    # flat lists for return
+    groups = []
+    memberships = []
+
+    # dicts for faster cross-rank validation
+    _groups = {}
+    _memberships = {}
+    _pg_guids = {}
+    for global_rank in pg_config:
+        for pg_uid in pg_config[global_rank]:
+            desc = pg_config[global_rank][pg_uid]["desc"]
+            ranks = ast.literal_eval(pg_config[global_rank][pg_uid]["ranks"])
+            # With the adoption of the split_group API, we can have multiple PGs with the same pg_guid (PG Name)
+            # So we need to add the hash of all its ranks within the PG as well.
+            # Also guid must be a string because `_process_group_name` returns a string.
+            pg_guid = pg_uid + str(hash(frozenset(ranks)))
+            _pg_guids[(pg_uid, global_rank)] = pg_guid
+            if isinstance(ranks, str):
+                # TODO Bug in FR data format? ranks is '[0, 1,...]'
+                ranks = eval(ranks)
+
+            if pg_guid not in _groups:
+                groups.append(Group(id=pg_guid, desc=desc, size=len(ranks)))
+                for rank in ranks:
+                    memberships.append(Membership(group_id=pg_guid, global_rank=rank))
+                _groups[pg_guid] = groups[-1]
+                _memberships[pg_guid] = set(ranks)
+            else:
+                # validation across ranks
+                assert _groups[pg_guid].desc == desc, (
+                    f"mismatch in desc {_groups[pg_guid].desc} vs {desc} for group {pg_guid}"
+                )
+                assert _memberships[pg_guid] == set(ranks), (
+                    f"mismatch in membership for group {pg_guid} {_memberships[pg_guid]} vs {set(ranks)}"
+                )
+    return groups, _groups, memberships, _memberships, _pg_guids
+
+
+def build_collectives(
+    all_entries: dict[int, list[dict[str, Any]]],
+    _groups: dict[str, Group],
+    _memberships: dict[str, set[Any]],
+    _pg_guids: dict[tuple[str, int], str],
+    version: str,
+    mismatch_cap: int = 10,
+) -> tuple[list[Traceback], list[Collective], list[NCCLCall]]:
+    """
+    groups, memberships are the non-flat dicts that are indexable
+    all_entries is a raw dict from the original dumps:
+
+    all_entries: {
+        global_rank: [
+            {
+                record_id: ordered id of the event in the trace buffer
+                pg_id: ProcessGroupNCCL::uid_
+                    *note: `pg_id` corresponds to nothing in groups table
+                process_group: (pg_name, desc)
+                    *note: `pg_name`, `desc` corresponds to `pg_id`, `desc` in groups table
+                collective_seq_id: ordered id for collective operations and coalesced group operations
+                p2p_seq_id: ordered id for point-to-point operations
+                op_id: ordered id including individual ops inside coalescing group
+                profiling_name: descriptive name of the operation
+                'time_created_ns',
+                'input_sizes',
+                'output_sizes',
+                'state',
+                'time_discovered_started_ns',
+                'time_discovered_completed_ns',
+                'retired',
+                'frames',
+            }
+        ]
+    }
+    """
+    tracebacks: list[Traceback] = []
+
+    collectives: list[Collective] = []
+    nccl_calls: list[NCCLCall] = []
+
+    # once we find one mismatch, we stop pairing up collectives since the pairing is possibly incorrect
+    # instead, just record the remaining ops as NCCLCalls
+    mismatch = {_groups[g].id: 0 for g in _groups}
+
+    # For best effort partial analysis.
+    dumps_ranks = {int(key) for key in all_entries}
+    """
+    - it doesn't matter what order I put collectives/ncclops into their table. we can later on re-sort it by start time
+    - there could be multiple options for the "first" collective to pair up (rank 0,1 might do a bcast while rank 2,3 do a bcast)
+    - within a group, the first collective must be the same on all ranks in the group, then it can be marked as a
+    collective and removed
+    """
+    while all_entries:
+        # we greedily match collectives, starting arbitrarily with the trace from the first rank
+        # later, if we exhaust the first rank, we continue with the next 'first rank'
+        rank_iter = iter(all_entries)
+        first_rank = next(rank_iter)
+        other_ranks = list(rank_iter)
+
+        if len(all_entries[first_rank]) == 0:
+            all_entries.pop(first_rank)
+            continue
+
+        # lets match the first collective! we need to know which ranks are involved, and ensure that this same
+        # collective is also the first one on those ranks within that group
+        entries = all_entries[first_rank]
+        current_entry = entries[0]
+        desc = current_entry["process_group"][1]
+        # For db build and logs printing, we want to use the original pg_name, not the hash one.
+        original_pg_name = current_entry["process_group"][0]
+        pg_name = _pg_guids[(original_pg_name, first_rank)]
+        expected_ranks = set(_memberships[pg_name])
+        entry_state = EntryState(current_entry, expected_ranks)
+        match_record = MatchStateRecord(
+            expected_ranks=expected_ranks,
+            other_ranks=other_ranks,
+            entry_state=entry_state,
+            candidate_ranks={first_rank},
+            candidate_idx={},
+            found_ranks=set(),
+            found_idx={},
+            errors=set(),
+        )
+
+        major_v, minor_v = get_version_detail(version)
+        find_coalesced_group = (
+            find_coalesced_group_p2p_only
+            if major_v <= 2 and minor_v < 7
+            else find_coalesced_group_with_non_p2p
+        )
+        maybe_coalesced_group = find_coalesced_group(
+            pg_name, entries, _pg_guids, first_rank
+        )
+        if len(maybe_coalesced_group) > 1:
+            num_coalesced_entries = len(maybe_coalesced_group)
+            # We need a copy of the original expected ranks to avoid modifying it.
+            candidate_ranks = copy.deepcopy(expected_ranks)
+            done_ranks = set()
+            all_coalesced_entries = {}
+            while candidate_ranks:
+                curr = candidate_ranks.pop()
+                done_ranks.add(curr)
+                grp = (
+                    find_coalesced_group(pg_name, all_entries[curr], _pg_guids, curr)  # type: ignore[index]
+                    if curr in all_entries  # type: ignore[comparison-overlap]
+                    else []
+                )
+                all_coalesced_entries[curr] = grp
+                for _, entry in grp:
+                    op = Op(entry, _memberships, pg_name)
+                    peer = None
+                    if op.type == "send":
+                        assert op._src_g == curr, (
+                            f"Send src error: {curr} expected but {op._src_g} is set"
+                        )
+                        peer = op._dst_g
+                    elif op.type == "recv":
+                        assert op._dst_g == curr, (
+                            f"Recv dst error: {curr} expected but {op._dst_g} is set"
+                        )
+                        peer = op._src_g
+                    if peer and peer not in done_ranks:
+                        candidate_ranks.add(peer)
+
+            if major_v <= 2 and minor_v < 7:
+                match = match_coalesced_groups_p2p_only(
+                    all_coalesced_entries,
+                    group_size=_groups[pg_name].size,
+                    groups=_groups,
+                    memberships=_memberships,
+                    _pg_guids=_pg_guids,
+                )
+            else:
+                match = match_coalesced_groups_with_non_p2p(
+                    copy.deepcopy(
+                        all_coalesced_entries
+                    ),  # We want to keep a copy for cleanup.
+                    pg_info=(pg_name, desc),
+                    memberships=_memberships,
+                    _pg_guids=_pg_guids,
+                    mismatch=mismatch,
+                    dumps_ranks=dumps_ranks,
+                    version=version,
+                    collectives=collectives,
+                    match_record=match_record,
+                )
+
+            if match and mismatch[pg_name] == 0:
+                # We treat coalesced collectives as a single collective.
+                # TODO: we need to surface a merged collective info like input/output sizes to users.
+                collectives.append(
+                    match_record.entry_state.to_collective(len(collectives))
+                )
+            else:
+                mismatch[pg_name] += 1
+            for r in all_coalesced_entries:
+                idx_map = {r: i for i, _ in reversed(all_coalesced_entries[r])}  # noqa: B035
+                nccl_calls.extend(
+                    reversed(
+                        match_record.entry_state.to_nccl_call(
+                            all_entries,
+                            idx_map,
+                            len(nccl_calls),
+                            collectives[-1].id if match else None,
+                        )
+                    )
+                )
+                # This extra cleanup is needed because we need to pop all collectives within a coalesced collective.
+                for i, k in idx_map.items():
+                    for _ in range(1, num_coalesced_entries):
+                        all_entries[i].pop(k)
+        else:
+            # Iterate through all the ranks and check if there is a mismatch for the current entry.
+            check_current_entry_match(
+                all_entries,
+                _pg_guids,
+                (pg_name, desc),
+                current_entry,
+                _memberships,
+                mismatch,
+                match_record,
+            )
+
+            # Use heuristics to decide what type of errors and error messages we should print.
+            error_analysis(
+                all_entries,
+                match_record,
+                dumps_ranks,
+                first_rank,
+                current_entry,
+                mismatch,
+                get_version_detail(version),
+                pg_name,
+            )
+
+            # at this point there are 3 possibilities
+            # 1. we found a match on all the ranks that are members of the group
+            #  -> we create a Collective and remove the individual entries from their original lists
+            if match_record.found_ranks == expected_ranks and mismatch[pg_name] == 0:
+                collectives.append(
+                    match_record.entry_state.to_collective(len(collectives))
+                )
+                idx_map = {
+                    r: match_record.found_idx[r] if r != first_rank else 0
+                    for r in match_record.found_ranks
+                }
+                nccl_calls.extend(
+                    match_record.entry_state.to_nccl_call(
+                        all_entries, idx_map, len(nccl_calls), collectives[-1].id
+                    )
+                )
+
+            # 2. we found a partial match but some ranks are missing
+            # 3. we found no match
+            #  -> since its not a complete collective, no entry goes into collectives but we still record a nccl call
+            #     TODO should there be a way to mark 'mismatches'?
+            else:
+                logger.debug("appending a non-matching collective")
+                idx_map = {
+                    r: match_record.candidate_idx[r] if r != first_rank else 0
+                    for r in match_record.candidate_ranks
+                }
+                collectives.append(
+                    match_record.entry_state.to_collective(
+                        len(collectives),
+                        errors=match_record.errors,
+                        idx_map=idx_map,
+                        all_entries=all_entries,
+                    )
+                )
+                nccl_calls.extend(
+                    match_record.entry_state.to_nccl_call(
+                        all_entries, idx_map, len(nccl_calls), None
+                    )
+                )
+
+        if mismatch[pg_name] > mismatch_cap:
+            logger.error(
+                "Too many mismatches for process_group %s: %s aborting", pg_name, desc
+            )
+            break
+
+    return tracebacks, collectives, nccl_calls
+
+
+def transform_ft(
+    details: dict[str, dict[str, Any]], group_world_size: int
+) -> dict[str, dict[str, Any]]:
+    for dump_key, dump in details.items():
+        rank = dump["rank"]
+        for key, pg_config in dump["pg_config"].items():
+            if pg_config["desc"] == "default_pg":
+                ranks = eval(pg_config["ranks"])
+                replica_id = rank // group_world_size
+                first_rank = replica_id * group_world_size
+                new_ranks = [r + first_rank for r in ranks]
+                details[dump_key]["pg_config"][key]["ranks"] = f"{new_ranks}"
+
+    return details
+
+
+def build_db(
+    details: dict[str, dict[str, Any]], args: argparse.Namespace, version: str
+) -> Database:
+    if args.verbose:
+        os.environ["FR_TRACE_VERBOSE_OUTPUT"] = "1"
+    # temporary state used for building database
+    entries = {}
+    pg_config = {}
+    version_by_ranks = {}
+    for dump in details.values():
+        rank = dump["rank"]
+        entries[rank] = dump["entries"]
+        version_by_ranks[rank] = dump["version"]
+        pg_config[rank] = dump["pg_config"]
+
+    # Ensure version is consistent across all ranks.
+    check_version(version_by_ranks, version)
+    entries = align_trace_from_beginning(entries)
+    stack_id_trace_map: dict[str, int] = {}
+    if args.just_print_entries:
+        entries, stack_id_trace_map = add_stack_id_in_entries(entries)
+
+    # flattened database
+    groups, _groups, memberships, _memberships, _pg_guids = build_groups_memberships(
+        pg_config
+    )
+    logger.debug("built groups, memberships")
+
+    if args.just_print_entries:
+        just_print_entries(
+            entries, _groups, _memberships, _pg_guids, args, stack_id_trace_map
+        )
+        sys.exit(0)
+
+    if not args.allow_incomplete_ranks:
+        check_no_missing_dump_files(entries, memberships)
+
+    tracebacks, collectives, nccl_calls = build_collectives(
+        entries, _groups, _memberships, _pg_guids, version, args.mismatch_cap
+    )
+    logger.debug("built collectives, nccl_calls")
+    if args.verbose:
+        logger.debug("Groups")
+        logger.debug(tabulate(groups, headers=Group._fields))
+        logger.debug("Memberships")
+        logger.debug(tabulate(memberships, headers=Membership._fields))
+        logger.debug("Collectives")
+        logger.debug(tabulate(collectives, headers=Collective._fields))
+        logger.debug("NCCLCalls")
+        logger.debug(tabulate(nccl_calls, headers=NCCLCall._fields))
+    db = Database(
+        tracebacks=tracebacks,
+        collectives=collectives,
+        ncclcalls=nccl_calls,
+        groups=groups,
+        memberships=memberships,
+    )
+    return db
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/config_manager.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/config_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1b12966588215ce01118f9aea9f8bb771390c3c
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/config_manager.py
@@ -0,0 +1,110 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import logging
+from collections.abc import Sequence
+
+from torch.distributed.flight_recorder.components.fr_logger import FlightRecorderLogger
+
+
+__all__ = ["JobConfig"]
+
+
+logger: FlightRecorderLogger = FlightRecorderLogger()
+
+
+class JobConfig:
+    """
+    A helper class to manage the script configuration.
+    """
+
+    def __init__(self: "JobConfig"):
+        self.parser = argparse.ArgumentParser(
+            description="PyTorch Flight recorder analyzing script."
+        )
+        self.parser.add_argument(
+            "trace_dir",
+            nargs="?",
+            help="Directory containing one trace file per rank, named with _.",
+        )
+        self.parser.add_argument(
+            "--selected-ranks",
+            default=None,
+            nargs="+",
+            type=int,
+            help="List of ranks we want to show traces for.",
+        )
+        self.parser.add_argument(
+            "--allow-incomplete-ranks",
+            action="store_true",
+            help=(
+                "FR trace require all ranks to have dumps for analysis. "
+                "This flag allows best-effort partial analysis of results "
+                "and printing of collected data."
+            ),
+        )
+        self.parser.add_argument(
+            "--pg-filters",
+            default=None,
+            nargs="+",
+            type=str,
+            help=(
+                "List of filter strings, it could be pg name or pg desc. "
+                "If specified, only show traces for the given pg."
+            ),
+        )
+        self.parser.add_argument("-o", "--output", default=None)
+        self.parser.add_argument(
+            "-p",
+            "--prefix",
+            help=(
+                "Common filename prefix to strip such that rank can be extracted. "
+                "If not specified, will attempt to infer a common prefix."
+            ),
+            default=None,
+        )
+        self.parser.add_argument("-j", "--just_print_entries", action="store_true")
+        self.parser.add_argument("-v", "--verbose", action="store_true")
+        self.parser.add_argument("--print_stack_trace", action="store_true")
+        self.parser.add_argument(
+            "--mismatch_cap",
+            type=int,
+            default=10,
+            help="Maximum number of mismatches we print (from earliest).",
+        )
+        self.parser.add_argument(
+            "--transform-ft",
+            action="store_true",
+            help="Transform PG config to use global ranks to analyze traces produced by torchft",
+        )
+        self.parser.add_argument(
+            "--group-world-size",
+            type=int,
+            default=None,
+            help="The number of ranks in 1 torchft replica group. Must be specified if --transform-ft is True",
+        )
+
+    def parse_args(self: "JobConfig", args: Sequence[str] | None) -> argparse.Namespace:
+        # pyrefly: ignore [bad-assignment]
+        args = self.parser.parse_args(args)
+        # pyrefly: ignore [missing-attribute]
+        if args.selected_ranks is not None:
+            # pyrefly: ignore [missing-attribute]
+            assert args.just_print_entries, (
+                "Not support selecting ranks without printing entries"
+            )
+        # pyrefly: ignore [missing-attribute]
+        if args.pg_filters is not None:
+            # pyrefly: ignore [missing-attribute]
+            assert args.just_print_entries, (
+                "Not support selecting pg filters without printing entries"
+            )
+        # pyrefly: ignore [missing-attribute]
+        if args.verbose:
+            logger.set_log_level(logging.DEBUG)
+        # pyrefly: ignore [bad-return]
+        return args
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/fr_logger.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/fr_logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..e56634397bff9d6d1ec38eab43f1856f52e02829
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/fr_logger.py
@@ -0,0 +1,54 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from collections.abc import Callable
+from typing import Any
+
+
+__all__ = ["FlightRecorderLogger"]
+
+
+class FlightRecorderLogger:
+    _instance: Any | None = None
+    logger: logging.Logger
+
+    def __init__(self) -> None:
+        self.logger: logging.Logger = logging.getLogger("Flight Recorder")
+
+    def __new__(cls) -> Any:
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+            cls._instance.logger = logging.getLogger("Flight Recorder")
+            cls._instance.logger.setLevel(logging.INFO)
+            formatter = logging.Formatter("%(message)s")
+            ch = logging.StreamHandler()
+            ch.setFormatter(formatter)
+            cls._instance.logger.addHandler(ch)
+        return cls._instance
+
+    def set_log_level(self, level: int) -> None:
+        self.logger.setLevel(level)
+
+    @property
+    def debug(self) -> Callable[..., None]:
+        return self.logger.debug
+
+    @property
+    def info(self) -> Callable[..., None]:
+        return self.logger.info
+
+    @property
+    def warning(self) -> Callable[..., None]:
+        return self.logger.warning
+
+    @property
+    def error(self) -> Callable[..., None]:
+        return self.logger.error
+
+    @property
+    def critical(self) -> Callable[..., None]:
+        return self.logger.critical
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/loader.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce361b103fe04488d0390df1b898d27016f2b47b
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/loader.py
@@ -0,0 +1,98 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import gc
+import os
+import pickle
+import re
+import time
+from collections import defaultdict
+from typing import Any
+
+from torch.distributed.flight_recorder.components.fr_logger import FlightRecorderLogger
+
+
+__all__ = [
+    "read_dump",
+    "read_dir",
+]
+
+
+logger: FlightRecorderLogger = FlightRecorderLogger()
+
+
+def read_dump(prefix: str, filename: str) -> dict[str, str | int | list[Any]]:
+    basename = os.path.basename(filename)
+
+    rank = int(basename[len(prefix) :])
+    host_name = f"host_rank{rank}"
+
+    with open(filename, "rb") as infile:
+        dump = pickle.load(infile)
+
+    entries = dump["entries"]
+    version = dump["version"]
+    pg_config = dump["pg_config"]
+
+    return {
+        "host_name": host_name,
+        "rank": rank,
+        "entries": entries,
+        "version": version,
+        "pg_config": pg_config,
+    }
+
+
+exp = re.compile(r"([\w\-\_]*?)(\d+)$")
+
+
+def _determine_prefix(files: list[str]) -> str:
+    """If the user doesn't specify a prefix, but does pass a dir full of similarly-prefixed files, we should be able to
+    infer the common prefix most of the time.  But if we can't confidently infer, just fall back to requiring the user
+    to specify it
+    """
+    possible_prefixes: defaultdict[str, set[int]] = defaultdict(set)
+    for f in files:
+        m = exp.search(f)
+        if m:
+            p, r = m.groups()
+            possible_prefixes[p].add(int(r))
+    if len(possible_prefixes) == 1:
+        prefix = next(iter(possible_prefixes))
+        logger.debug("Inferred common prefix %s", prefix)
+        return prefix
+    else:
+        raise ValueError(
+            "Unable to automatically determine the common prefix for the trace file names. "
+            "Please specify --prefix argument manually"
+        )
+
+
+def read_dir(args: argparse.Namespace) -> tuple[dict[str, dict[str, Any]], str]:
+    gc.disable()
+    prefix = args.prefix
+    details = {}
+    t0 = time.time()
+    version = ""
+    filecount = 0
+    assert os.path.isdir(args.trace_dir), f"folder {args.trace_dir} does not exist"
+    for root, _, files in os.walk(args.trace_dir):
+        if prefix is None:
+            prefix = _determine_prefix(files)
+        for f in files:
+            if (offset := f.find(prefix)) == -1:
+                continue
+            details[f] = read_dump(f[:offset] + prefix, os.path.join(root, f))
+            filecount += 1
+            if not version:
+                version = str(details[f]["version"])
+    tb = time.time()
+    assert len(details) > 0, (
+        f"no files loaded from {args.trace_dir} with prefix {prefix}"
+    )
+    logger.debug("loaded %s files in %ss", filecount, tb - t0)
+    return details, version
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/types.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/types.py
new file mode 100644
index 0000000000000000000000000000000000000000..7fdfd9d8838b5e6d24c96501ba5556dd001b1a6a
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/types.py
@@ -0,0 +1,661 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import os
+from enum import auto, Enum
+from typing import (  # type: ignore[attr-defined]
+    _eval_type,
+    Any,
+    Generic,
+    NamedTuple,
+    TypeVar,
+)
+
+from torch.distributed.flight_recorder.components.fr_logger import FlightRecorderLogger
+
+
+__all__ = [
+    "Ref",
+    "TypeInfo",
+    "MatchState",
+    "MatchInfo",
+    "Group",
+    "Membership",
+    "Traceback",
+    "Collective",
+    "NCCLCall",
+    "Database",
+    "EntryState",
+    "Op",
+    "MatchStateRecord",
+]
+
+
+T = TypeVar("T", bound=NamedTuple)
+
+
+class Ref(Generic[T]):
+    pass
+
+
+class TypeInfo(NamedTuple):
+    name: str
+    fields: list[tuple[str, type]]  # type: ignore[type-arg]
+
+    @classmethod
+    def from_type(cls, c: T) -> "TypeInfo":
+        if hasattr(c, "__name__"):
+            name = c.__name__
+        else:
+            name = str(c)
+        return cls(
+            name,
+            [(f, _eval_type(c.__annotations__[f], globals(), {})) for f in c._fields],
+        )
+
+
+class MatchState(Enum):
+    """
+    Enum representing the possible states of matching for collective operations.
+
+    - FULLY_MATCHED: Indicates that all aspects of the collective operations match.
+    - COLLECTIVE_TYPE_MISMATCH: The types of the collective operations differ.
+    - SIZE_OR_SYNTAX_MISMATCH: There is a mismatch in input/output sizes or violation of collective syntax.
+    - COLLECTIVE_STATE_MISMATCH:
+        The states of the collective not same, such as one finished while another just started or scheduled.
+    - COLLECTIVE_DTYPE_MISMATCH: The data types of the collective input/output differ.
+    - UNDECIDED:
+        The match status is ambiguous or cannot be determined, e.g., we might need to check all ranks for alltoall_base.
+    """
+
+    FULLY_MATCHED = auto()
+    COLLECTIVE_TYPE_MISMATCH = auto()
+    SIZE_OR_SYNTAX_MISMATCH = auto()
+    COLLECTIVE_STATE_MISMATCH = auto()
+    COLLECTIVE_DTYPE_MISMATCH = auto()
+    UNDECIDED = auto()
+
+
+class MatchInfo:
+    """
+    Aside from the match state, we also store some dynamic info for the match such as the culprit rank
+    or collective state that caused the mismatch.
+    """
+
+    def __init__(self, state: MatchState, culprit: str | None = None) -> None:
+        self._state = state
+        self.culprit = culprit
+
+    def __str__(self) -> str:
+        details = f", {self.culprit}" if getattr(self, "culprit", None) else ""
+        return f"Error type: {self._state.name}{details}"
+
+    @property
+    def state(self) -> MatchState:
+        return self._state
+
+
+"""
+Schema for flat DB
+
+TODO schemas not yet implemented
+# threads as recorded at termination of process
+Threads
+    id: int
+    traceback_id: int
+    process_id: int
+
+Process:
+    id: int # Same as world groups RANK
+    pid: int
+    hostname: str
+
+NCCLOp:
+    # nccl op implementation details (sends/recv)
+    id: int
+    nccl_call_id: int
+
+"""
+
+
+class Group(NamedTuple):
+    id: str
+    desc: str
+    size: int
+
+
+class Membership(NamedTuple):
+    group_id: str
+    global_rank: int
+
+
+class Traceback(NamedTuple):
+    id: int
+    frames: str
+
+
+class Collective(NamedTuple):
+    id: int
+    group_id: str
+    pass_check: bool
+    collective_seq_id: int
+    p2p_seq_id: int
+    record_id: int
+    pg_desc: str
+    collective_name: str
+    input_sizes: list[list[int]]
+    output_sizes: list[list[int]]
+    expected_ranks: set[int]
+    collective_state: str
+    collective_frames: list[dict[str, str]]
+    input_numel: int | None = None
+    output_numel: int | None = None
+    missing_ranks: set[int] | None = None
+    mismatch_collectives: dict[int, "Collective"] | None = None
+    type_of_mismatch: MatchInfo | None = None
+
+
+class NCCLCall(NamedTuple):
+    id: int
+    collective_id: Ref[Collective]
+    group_id: str
+    global_rank: int  # technically Ref[Process] once we have it
+    traceback_id: Ref[Traceback]
+    collective_type: str
+    sizes: list[list[int]]
+
+
+class Database(NamedTuple):
+    groups: list[Group]
+    memberships: list[Membership]
+    tracebacks: list[Traceback]
+    collectives: list[Collective]
+    ncclcalls: list[NCCLCall]
+
+
+# TODO: We need to add a schema for the following
+types = [
+    TypeInfo.from_type(t)  # type: ignore[type-var]
+    for t in [Database, NCCLCall, Collective, Traceback, Membership, Group]
+    if (
+        isinstance(t, type)
+        and issubclass(t, tuple)
+        and hasattr(t, "_fields")
+        and t is not TypeInfo
+    )
+]
+
+"""
+Stacktrace cache
+TODO
+"""
+
+
+"""
+Collective Matching logic
+
+NOTE: For now, these collectives need to be supported by NCCL,
+https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/overview.html.
+"""
+COLLECTIVES = {
+    "broadcast",
+    "_broadcast_oop",
+    "reduce",
+    "_reduce_oop",
+    "all_gather",
+    "all_reduce",
+    "_all_gather_base",
+    "all_gather_into_tensor_coalesced",
+    "reduce_scatter",
+    "reduce_scatter_tensor_coalesced",
+    "_reduce_scatter_base",
+    "gather",
+    "scatter",
+    "all_to_all",
+    "all_reduce_barrier",
+    "allreduce_coalesced",
+    "ALLGATHER_coalesced",
+    "REDUCE_SCATTER_coalesced",
+}
+
+P2P = {
+    "send",
+    "recv",
+}
+
+
+class EntryState:
+    """
+    Util class to keep track of the state of an entry and standardize the way we
+    log the error info during analysis.
+    """
+
+    def __init__(self, entry: dict[str, Any], expected_ranks: set[int]) -> None:
+        self.pg_name = entry["process_group"][0]
+        self.desc = entry["process_group"][1]
+        self.pg_desc = (
+            f"{self.pg_name}:{self.desc}" if self.desc != "undefined" else self.pg_name
+        )
+        self.profiling_name = entry["profiling_name"]
+        self.collective_seq_id = entry["collective_seq_id"]
+        self.p2p_seq_id = entry["p2p_seq_id"]
+        self.record_id = entry["record_id"]
+        self.input_sizes = entry["input_sizes"]
+        self.output_sizes = entry["output_sizes"]
+        self.collective_state = entry["state"]
+        self.collective_frames = entry.get("frames", [])
+        self.expected_ranks = expected_ranks
+        self.missing_ranks: set[int]
+        self.input_numel: int
+        self.output_numel: int
+        self.errors: set[tuple[int, MatchInfo]]
+
+    def log(
+        self,
+        logger: FlightRecorderLogger,
+        logger_msg: str,
+        frame_formatter: Any,
+        total_numel: tuple[int, int] | None = None,
+        errors: set[tuple[int, MatchInfo]] | None = None,
+        missing_ranks: set[int] | None = None,
+    ) -> None:
+        logger.info(
+            logger_msg,
+            self.collective_seq_id,
+        )
+        logger.info("internal record id: %s", self.record_id)
+        logger.info("group info: %s", self.pg_desc)
+        logger.info("collective: %s", self.profiling_name)
+        if missing_ranks:
+            self.missing_ranks = missing_ranks
+            logger.info("missing ranks: %s", missing_ranks)
+        if total_numel:
+            self.input_numel = total_numel[0]
+            self.output_numel = total_numel[1]
+            logger.info("total input numel: %d", total_numel[0])
+            logger.info("total output numel: %d", total_numel[1])
+        logger.info("input sizes: %s", self.input_sizes)
+        logger.info("output sizes: %s", self.output_sizes)
+        logger.info("world size: %d", len(self.expected_ranks))
+        logger.info("expected ranks: %s", str(self.expected_ranks))
+        logger.info("collective state: %s", self.collective_state)
+        if errors:
+            self.errors = errors
+            error_msg = ", ".join(
+                f"Culprit rank {error[0]}; {str(error[1])}" for error in errors
+            )
+            logger.info("error msg: %s", error_msg)
+        logger.info(
+            "collective stack trace: \n %s", frame_formatter(self.collective_frames)
+        )
+
+    def to_collective(
+        self,
+        id: int,
+        errors: set[tuple[int, MatchInfo]] | None = None,
+        idx_map: dict[int, int] | None = None,
+        all_entries: dict[int, list[dict[str, Any]]] | None = None,
+    ) -> Collective:
+        if not errors:
+            return Collective(
+                id=id,
+                group_id=self.pg_name,
+                record_id=self.record_id,
+                pg_desc=self.pg_desc,
+                pass_check=True,
+                collective_seq_id=self.collective_seq_id,
+                p2p_seq_id=self.p2p_seq_id,
+                collective_name=self.profiling_name,
+                input_sizes=self.input_sizes,
+                output_sizes=self.output_sizes,
+                expected_ranks=self.expected_ranks,
+                collective_state=self.collective_state,
+                collective_frames=self.collective_frames,
+                missing_ranks=getattr(self, "missing_ranks", None),
+            )
+        else:
+            assert idx_map is not None, "idx_map is None"
+            assert all_entries is not None, "all_entries is None"
+            mismatch_collectives = {}
+            for rank, error in errors:
+                idx = idx_map[rank]
+                entry = all_entries[rank][idx]
+                desc = entry["process_group"][1]
+                pg_name = entry["process_group"][0]
+                mismatch_collectives[rank] = Collective(
+                    id=id,
+                    group_id=entry["process_group"][0],
+                    record_id=entry["record_id"],
+                    pg_desc=f"{pg_name}:{desc}" if desc != "undefined" else pg_name,
+                    pass_check=False,
+                    collective_seq_id=entry["collective_seq_id"],
+                    p2p_seq_id=entry["p2p_seq_id"],
+                    collective_name=entry["profiling_name"],
+                    input_sizes=entry["input_sizes"],
+                    output_sizes=entry["output_sizes"],
+                    expected_ranks=self.expected_ranks,
+                    collective_state=entry["state"],
+                    collective_frames=entry.get("frames", []),
+                    type_of_mismatch=error,
+                )
+            return Collective(
+                id=id,
+                group_id=self.pg_name,
+                record_id=self.record_id,
+                pg_desc=self.pg_desc,
+                pass_check=False,
+                collective_seq_id=self.collective_seq_id,
+                p2p_seq_id=self.p2p_seq_id,
+                collective_name=self.profiling_name,
+                input_sizes=self.input_sizes,
+                output_sizes=self.output_sizes,
+                expected_ranks=self.expected_ranks,
+                collective_state=self.collective_state,
+                collective_frames=self.collective_frames,
+                input_numel=self.input_numel if hasattr(self, "input_numel") else None,
+                output_numel=self.output_numel
+                if hasattr(self, "output_numel")
+                else None,
+                missing_ranks=self.missing_ranks
+                if hasattr(self, "missing_ranks")
+                else None,
+                mismatch_collectives=mismatch_collectives,
+            )
+
+    def to_nccl_call(
+        self,
+        all_entries: dict[int, list[dict[str, Any]]],
+        idx_map: dict[int, int],
+        nccl_call_id: int,
+        collective_id: Any,
+    ) -> list[NCCLCall]:
+        result = []
+        for i, k in idx_map.items():
+            all_entries[i].pop(k)
+            result.append(
+                NCCLCall(
+                    id=nccl_call_id,
+                    collective_id=collective_id,
+                    group_id=self.pg_name,  # type: ignore[arg-type]
+                    global_rank=i,
+                    traceback_id=0,  # type: ignore[arg-type]
+                    collective_type=self.profiling_name,
+                    sizes=self.input_sizes,
+                )
+            )
+            nccl_call_id += 1
+        return result
+
+
+class Op:
+    """Parses relevant info about operation out of 'event' dict
+
+    examples of supported `profiling_name`s:
+        nccl:broadcast
+        nccl:send 1->2
+        nccl:recv 3<-0
+    """
+
+    def __init__(
+        self, event: dict[Any, Any], memberships: dict[str, set[Any]], pg_name: str
+    ):
+        self.profiling_name = event["profiling_name"]
+        comm_lib_backend, name = self.profiling_name.split(":")
+        assert comm_lib_backend in ["nccl", "xccl"], (
+            f"name formatting error? {comm_lib_backend} != 'nccl' or 'xccl'"
+        )
+        parts = name.split(" ")
+        type = parts[0]
+        meta = parts[1] if len(parts) == 2 else None
+        self.state = event["state"]
+        # Store the hashed pg_name for accessing memberships, and original pg info for display
+        self.pg_name = pg_name  # This is the hashed version used for memberships lookup
+        self.original_pg_name, self.pg_desc = event["process_group"]
+        assert type in COLLECTIVES | P2P | {"coalesced"}, (
+            f"{type} is not a supported operation"
+        )
+        self.type = type
+        if type == "send":
+            assert isinstance(meta, str)
+            s, d = meta.split("->")
+            self._src, self._dst = int(s), int(d)
+        elif type == "recv":
+            assert isinstance(meta, str)
+            d, s = meta.split("<-")
+            self._dst, self._src = int(d), int(s)
+        else:
+            self._src, self._dst = -1, -1
+        self._init_global_src_dst(memberships[pg_name])
+        self.pg_size = len(memberships[pg_name])
+        if type in P2P | COLLECTIVES:
+            self.input_sizes = event["input_sizes"]
+            self.output_sizes = event["output_sizes"]
+        else:
+            self.input_sizes, self.output_sizes = None, None
+        self.collective_seq_id = event["collective_seq_id"]
+        self.stack_id = event.get("stack_id", -1)
+        self.p2p_seq_id = event["p2p_seq_id"]
+        self.input_dtypes = event["input_dtypes"]
+        self.output_dtypes = event["output_dtypes"]
+        self.time_created_ns = event["time_created_ns"]
+        self.collective_frames = event.get("frames", [])
+        self.is_verbose = os.getenv("FR_TRACE_VERBOSE_OUTPUT", "0") == "1"
+
+    def _init_global_src_dst(self, pg_ranks: set[Any]) -> None:
+        pg_ranks_sorted = sorted(pg_ranks)
+        self._src_g = pg_ranks_sorted[self._src] if self._src is not None else None
+        self._dst_g = pg_ranks_sorted[self._dst] if self._dst is not None else None
+
+    @property
+    def src(self) -> int:
+        assert self.type in P2P, "can't get src of non-p2p op"
+        return self._src
+
+    @property
+    def dst(self) -> int:
+        assert self.type in P2P, "can't get dst of non-p2p op"
+        return self._dst
+
+    def __repr__(self) -> str:
+        p2p_info = ""
+        if self.type in P2P:
+            p2p_info = f"s={self._src_g} d={self._dst_g}"
+        if self.is_verbose:
+            verbose_info = (
+                f"timestamp_created={self.time_created_ns}",
+                p2p_info,
+                f"input_sizes={self.input_sizes}",
+                f"output_sizes={self.output_sizes}",
+                f"input_dtypes={self.input_dtypes}",
+                f"output_dtypes={self.output_dtypes}",
+                "collective_seq_id | p2p_seq_id="
+                f"{self.p2p_seq_id if self.type in P2P else self.collective_seq_id}",
+                f"pg_name={self.pg_name}",
+                f"pg_description={self.pg_desc}",
+                f"pg_size={self.pg_size}",
+                f"stack_id={self.stack_id}",
+                f"state={self.state}",
+            )
+            return f"{self.type}(%s)" % ", ".join(s for s in verbose_info if s)
+        return f"{self.type}(%sinput_sizes={self.input_sizes}, state={self.state})" % (
+            f"{p2p_info}, " if p2p_info else ""
+        )
+
+    def dtype_mismatch(self, other: "Op") -> bool:
+        if (
+            (
+                self.type not in ["scatter", "gather", "broadcast"]
+                and set(self.input_dtypes) != set(self.output_dtypes)
+                and self.input_sizes[0]
+                and self.output_sizes[0]
+            )
+            or (
+                self.type not in ["scatter", "broadcast"]
+                and set(self.input_dtypes) != set(other.input_dtypes)
+                and self.input_sizes[0]
+                and other.input_sizes[0]
+            )
+            or (
+                self.type not in ["gather"]
+                and set(self.output_dtypes) != set(other.output_dtypes)
+                and self.output_sizes[0]
+                and other.output_sizes[0]
+            )
+        ):
+            return True
+        return False
+
+    def match(self, other: "Op") -> MatchInfo:
+        # TODO: I think this can validly not match,
+        # e.g. if one PG was used for p2p ops between only some of the peers?
+        # if self.seq_id != other.seq_id:
+        # return False
+
+        if self.type == "send":
+            # TODO: We need more states for p2p ops.
+            return (
+                MatchInfo(MatchState.FULLY_MATCHED)
+                if (
+                    other.type == "recv"
+                    and self.src == other.src
+                    and self.dst == other.dst
+                    and self.input_sizes == other.output_sizes
+                )
+                else MatchInfo(MatchState.SIZE_OR_SYNTAX_MISMATCH)
+            )
+        elif self.type == "recv":
+            return (
+                MatchInfo(MatchState.FULLY_MATCHED)
+                if (
+                    other.type == "send"
+                    and self.src == other.src
+                    and self.dst == other.dst
+                    and self.output_sizes == other.input_sizes
+                )
+                else MatchInfo(MatchState.SIZE_OR_SYNTAX_MISMATCH)
+            )
+        elif self.type in COLLECTIVES:
+            if self.type != other.type:
+                return MatchInfo(
+                    MatchState.COLLECTIVE_TYPE_MISMATCH,
+                    f"Expected collective type: '{self.type}' does not match found collective type: '{other.type}'",
+                )
+            if (
+                self.type not in ["all_to_all", "scatter"]
+                and self.input_sizes != other.input_sizes
+            ):
+                return MatchInfo(
+                    MatchState.SIZE_OR_SYNTAX_MISMATCH,
+                    f"Expected input sizes: '{self.input_sizes}' does not match found input sizes: "
+                    f"'{other.input_sizes}'",
+                )
+            if (
+                self.type not in ["all_to_all", "gather"]
+                and self.output_sizes != other.output_sizes
+            ):
+                return MatchInfo(
+                    MatchState.SIZE_OR_SYNTAX_MISMATCH,
+                    f"Expected output sizes: '{self.output_sizes}' does not match found output sizes: "
+                    f"'{other.output_sizes}'",
+                )
+            if (
+                self.type in ["all_reduce", "allreduce_coalesced"]
+                and self.input_sizes != other.output_sizes
+            ):
+                return MatchInfo(
+                    MatchState.SIZE_OR_SYNTAX_MISMATCH,
+                    f"Expected input sizes: '{self.input_sizes}' does not match found output sizes: '{other.output_sizes}'",
+                )
+            if (
+                self.type
+                in [
+                    "all_gather",
+                    "all_gather_base",
+                    "all_gather_into_tensor_coalesced",
+                ]
+                and math.prod(other.output_sizes[0])
+                != math.prod(self.input_sizes[0]) * self.pg_size
+            ):
+                return MatchInfo(
+                    MatchState.SIZE_OR_SYNTAX_MISMATCH,
+                    f"Found input numel '{math.prod(other.input_sizes[0])} * pg size {self.pg_size}' "
+                    f"does not match output numel '{math.prod(other.output_sizes[0])}'",
+                )
+            if (
+                self.type
+                in [
+                    "reduce_scatter",
+                    "_reduce_scatter_base",
+                    "reduce_scatter_tensor_coalesced",
+                ]
+                and math.prod(other.input_sizes[0])
+                != math.prod(self.output_sizes[0]) * self.pg_size
+            ):
+                return MatchInfo(
+                    MatchState.SIZE_OR_SYNTAX_MISMATCH,
+                    f"Found input numel '{math.prod(other.input_sizes[0])}' does not match output numel "
+                    f"'{math.prod(other.output_sizes[0])} * pg size {self.pg_size}'",
+                )
+            if self.dtype_mismatch(other):
+                return MatchInfo(
+                    MatchState.COLLECTIVE_DTYPE_MISMATCH,
+                    f"Expected dtypes: '{set(self.input_dtypes)}' does not "
+                    f"match found dtype: '{set(self.output_dtypes)}/"
+                    f"{set(other.input_dtypes)}/{set(other.output_dtypes)}'",
+                )
+            if self.state != other.state:
+                # MatchState()
+                return MatchInfo(
+                    MatchState.COLLECTIVE_STATE_MISMATCH,
+                    f"Expected state: '{self.state}' does not match found state: '{other.state}'",
+                )
+            if self.type == "all_to_all":
+                return MatchInfo(MatchState.UNDECIDED)
+        elif self.type in [
+            "coalesced",
+            "ALLGATHER_coalesced",
+            "REDUCE_SCATTER_coalesced",
+        ]:
+            return (
+                MatchInfo(MatchState.FULLY_MATCHED)
+                if (other.type == self.type)
+                else MatchInfo(MatchState.SIZE_OR_SYNTAX_MISMATCH)
+            )
+        return MatchInfo(MatchState.FULLY_MATCHED)
+
+
+class MatchStateRecord:
+    def __init__(
+        self,
+        expected_ranks: set[int],
+        other_ranks: list[int],
+        entry_state: EntryState,
+        candidate_ranks: set[int],
+        candidate_idx: dict[int, int],
+        found_ranks: set[int],
+        found_idx: dict[int, int],
+        errors: set[tuple[int, MatchInfo]],
+    ) -> None:
+        self.expected_ranks = expected_ranks
+        self.other_ranks = other_ranks
+        self.entry_state = entry_state
+        self.candidate_ranks = candidate_ranks
+        self.candidate_idx = candidate_idx
+        self.found_ranks = found_ranks
+        self.found_idx = found_idx
+        self.errors = errors
+        self.has_undecided_case = False
+
+    def reset_for_coalesced(
+        self, entry_state: EntryState, candidate_ranks: set[int]
+    ) -> None:
+        self.entry_state = entry_state
+        self.candidate_ranks = candidate_ranks
+        self.candidate_idx = {}
+        self.found_ranks = set()
+        self.found_idx = {}
+        self.errors = set()
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ab7919a2a24d81e7be692bc8bb9b0c326a99b28
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/utils.py
@@ -0,0 +1,789 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import math
+from typing import Any
+
+from torch.distributed.flight_recorder.components.fr_logger import FlightRecorderLogger
+from torch.distributed.flight_recorder.components.types import (
+    Collective,
+    EntryState,
+    Group,
+    MatchInfo,
+    MatchState,
+    MatchStateRecord,
+    Membership,
+    Op,
+    P2P,
+)
+
+
+__all__ = [
+    "add_stack_id_in_entries",
+    "align_trace_from_beginning",
+    "check_current_entry_match",
+    "check_no_missing_dump_files",
+    "check_version",
+    "error_analysis",
+    "find_coalesced_group",
+    "find_coalesced_group_with_non_p2p",
+    "get_version_detail",
+    "just_print_entries",
+    "match_coalesced_groups_with_non_p2p",
+    "match_coalesced_groups",
+    "format_frame",
+    "format_frames",
+    "match_one_event",
+    "check_size_alltoall",
+]
+
+logger: FlightRecorderLogger = FlightRecorderLogger()
+
+
+try:
+    from tabulate import tabulate
+except ModuleNotFoundError:
+    logger.debug("tabulate is not installed. Proceeding without it.")
+
+
+def format_frame(frame: dict[str, str]) -> str:
+    name = frame["name"]
+    filename = frame["filename"]
+    line = frame["line"]
+    return f"{name} at {filename}:{line}"
+
+
+def format_frames(frames: list[dict[str, str]]) -> str:
+    formatted_frames = []
+    for frame in frames:
+        # pyrefly: ignore [bad-argument-type]
+        formatted_frames.append(format_frame(frame))
+    return "\n".join(formatted_frames)
+
+
+def match_one_event(
+    event_a: dict[Any, Any],
+    event_b: dict[Any, Any],
+    memberships: dict[str, set[Any]],
+    pg_name: str,
+) -> MatchInfo:
+    op_a = Op(event_a, memberships, pg_name)
+    op_b = Op(event_b, memberships, pg_name)
+    return op_a.match(op_b)
+
+
+def match_coalesced_groups(
+    all_rank_events: dict[Any, Any],
+    group_size: int,
+    groups: dict[str, Group],
+    memberships: dict[str, set[Any]],
+    _pg_guids: dict[tuple[str, int], str],
+) -> bool:
+    """
+    all_rank_events: {
+        rank: [
+            (idx, event_dict)
+        ]
+    }
+
+    Note: it is possible for event dicts in a coalesced group to be asymmetric.
+        e.g. the following events lists form a valid coalescing group
+             events0 [send:1]
+             events1 [recv:0, send:2]
+             events2 [recv:1]
+
+    Rule 1: all ops should find a match
+    Rule 2: relative ordering of sends and recvs in one event list can be arbitrary
+        e.g.
+        events1 [recv:0, send:2]  —> okay
+        events1 [send:2, recv:0] —> also okay
+    Rule 3: sends to the same dest or recvs from the src should be in a consistent order
+        e.g.
+        rank0 [send:1 (100B), send:1 (1000B)]
+        rank1 [recv:0 (1000B), recv:0 (100B)]   —> not okay
+    """
+    all_ops = {
+        rank: [
+            Op(e, memberships, _pg_guids[(e["process_group"][0], rank)])
+            for i, e in all_rank_events[rank]
+        ]
+        for rank in all_rank_events
+    }
+
+    def visualize_ops(
+        match: bool,
+        _pg_guids: dict[tuple[str, int], str],
+    ) -> None:
+        all_ops = {
+            rank: [
+                Op(e, memberships, _pg_guids[(e["process_group"][0], rank)])
+                for i, e in all_rank_events[rank]
+            ]
+            for rank in all_rank_events
+        }
+
+        i = 0
+        row = []
+        progress = True
+        table = []
+        while progress:
+            progress = False
+            for r in all_ops:
+                if len(all_ops[r]) > i:
+                    rank, event = all_rank_events[r][i]
+                    # Check if the pg_guid exists for this rank and process group
+                    pg_key = (event["process_group"][0], rank)
+                    if pg_key in _pg_guids:
+                        row.append(
+                            Op(
+                                event,
+                                memberships,
+                                _pg_guids[pg_key],
+                            )
+                        )
+                    else:
+                        # Skip this entry if pg_guid mapping doesn't exist
+                        row.append(None)  # type: ignore[arg-type]
+                    progress = True
+                else:
+                    row.append(None)  # type: ignore[arg-type]
+            table.append(row)
+            row = []
+            i += 1
+        title = "Match" if match else "MISMATCH"
+        logger.info("%s \n", title)
+        logger.info("%s", tabulate(table))  # type: ignore[operator]
+
+    # TODO can't verify seq_id bc there might have been valid seq deltas between ranks even within a pg.
+    for op_list in all_ops.values():
+        if not op_list:
+            # print("TODO- not sure if its valid for only some ranks in a PG to participate in a coalesced op?")
+            return False
+        assert op_list[-1].type == "coalesced"
+        op_list.pop(-1)
+
+    while all_ops:
+        first_rank = next(iter(all_ops))
+        my_ops = all_ops[first_rank]
+
+        if len(all_ops[first_rank]) == 0:
+            all_ops.pop(first_rank)
+            continue
+
+        # lets match the first collective! we need to know which ranks are involved, and ensure that this same
+        # collective is also the first one on those ranks within that group
+        op = my_ops[0]
+        match_idx = -1
+        if op.type in P2P:
+            dst_global_rank = sorted(memberships[op.pg_name])[op.dst]
+            peer_ops = all_ops[dst_global_rank]
+            for i, other in enumerate(peer_ops):
+                if op.match(other).state == MatchState.FULLY_MATCHED:
+                    match_idx = i
+                    break
+                elif op.dst == other.src:
+                    # Rule 3
+                    break
+                else:
+                    # Rule 1
+                    continue
+        else:
+            raise NotImplementedError("coalesced collective ops")
+        if match_idx >= 0:
+            my_ops.pop(0)
+            peer_ops.pop(match_idx)
+        else:
+            visualize_ops(False, _pg_guids)
+            return False
+
+    visualize_ops(True, _pg_guids)
+    return True
+
+
+# We enabled the creating FR entry for non-P2P slow path collective ops in v2.7.
+def match_coalesced_groups_with_non_p2p(
+    all_rank_events: dict[Any, Any],
+    pg_info: tuple[str, str],
+    memberships: dict[str, set[Any]],
+    _pg_guids: dict[tuple[str, int], str],
+    mismatch: dict[str, int],
+    dumps_ranks: set[int],
+    version: str,
+    collectives: list[Collective],
+    match_record: MatchStateRecord,
+) -> bool:
+    """
+    all_rank_events: {
+        rank: [
+            (idx, event_dict)
+        ]
+    }
+
+    Note: it is possible for event dicts in a coalesced group to be asymmetric.
+        e.g. the following events lists form a valid coalescing group
+             events0 [send:1]
+             events1 [recv:0, send:2]
+             events2 [recv:1]
+
+    Rule 1: all ops should find a match
+    Rule 2: relative ordering of sends and recvs in one event list can be arbitrary
+        e.g.
+        events1 [recv:0, send:2]  —> okay
+        events1 [send:2, recv:0] —> also okay
+    Rule 3: sends to the same dest or recvs from the src should be in a consistent order
+        e.g.
+        rank0 [send:1 (100B), send:1 (1000B)]
+        rank1 [recv:0 (1000B), recv:0 (100B)]   —> not okay
+    """
+    all_ops = {
+        rank: [
+            Op(e, memberships, _pg_guids[(e["process_group"][0], rank)])
+            for _, e in all_rank_events[rank]
+        ]
+        for rank in all_rank_events
+    }
+    is_p2p = any(op.type in P2P for ops in all_ops.values() for op in ops)
+    pg_name = pg_info[0]
+
+    def visualize_ops(
+        match: bool,
+        _pg_guids: dict[tuple[str, int], str],
+    ) -> None:
+        all_ops = {
+            rank: [
+                Op(e, memberships, _pg_guids[(e["process_group"][0], rank)])
+                for _, e in all_rank_events[rank]
+            ]
+            for rank in all_rank_events
+        }
+
+        i = 0
+        row = []
+        progress = True
+        table = []
+        while progress:
+            progress = False
+            for r in all_ops:
+                if len(all_ops[r]) > i:
+                    rank, event = all_rank_events[r][i]
+                    # Check if the pg_guid exists for this rank and process group
+                    pg_key = (event["process_group"][0], rank)
+                    if pg_key in _pg_guids:
+                        row.append(
+                            Op(
+                                event,
+                                memberships,
+                                _pg_guids[pg_key],
+                            )
+                        )
+                    else:
+                        # Skip this entry if pg_guid mapping doesn't exist
+                        row.append(None)  # type: ignore[arg-type]
+                    progress = True
+                else:
+                    row.append(None)  # type: ignore[arg-type]
+            table.append(row)
+            row = []
+            i += 1
+        title = "Match" if match else "MISMATCH"
+        logger.info("%s \n", title)
+        logger.info("%s", tabulate(table))  # type: ignore[operator]
+
+    # TODO Need to verify no seq_id deltas for P2P ops.
+    for rank, op_list in all_ops.items():
+        if not op_list:
+            logger.error("Rank %s has an empty op list.", rank)
+            continue
+        if op_list[-1].type == "coalesced" and is_p2p:
+            op_list.pop(-1)
+
+    while all_ops:
+        first_rank = next(iter(all_ops))
+        my_ops = all_ops[first_rank]
+
+        if len(all_ops[first_rank]) == 0:
+            all_ops.pop(first_rank)
+            continue
+
+        # lets match the first collective! we need to know which ranks are involved, and ensure that this same
+        # collective is also the first one on those ranks within that group
+        op = my_ops[0]
+        match_idx = -1
+        if is_p2p:
+            dst_global_rank = sorted(memberships[op.pg_name])[op.dst]
+            peer_ops = all_ops[dst_global_rank]
+            for i, other in enumerate(peer_ops):
+                if op.match(other).state == MatchState.FULLY_MATCHED:
+                    match_idx = i
+                    break
+                elif op.dst == other.src:
+                    # Rule 3
+                    break
+                else:
+                    # Rule 1
+                    continue
+            if match_idx >= 0:
+                my_ops.pop(0)
+                peer_ops.pop(match_idx)
+            else:
+                visualize_ops(False, _pg_guids)
+                return False
+        else:
+            all_coalesced_entries = {
+                rank: [e for _, e in all_rank_events[rank]] for rank in all_rank_events
+            }
+            current_entry = all_coalesced_entries[first_rank][0]
+            my_ops.pop(0)
+
+            match_record.reset_for_coalesced(
+                EntryState(current_entry, match_record.expected_ranks),
+                {first_rank},
+            )
+
+            # Iterate through all the ranks and check if there is a mismatch for the current entry.
+            check_current_entry_match(
+                all_coalesced_entries,
+                _pg_guids,
+                pg_info,
+                current_entry,
+                memberships,
+                mismatch,
+                match_record,
+            )
+
+            # Use heuristics to decide what type of errors and error messages we should print.
+            error_analysis(
+                all_coalesced_entries,
+                match_record,
+                dumps_ranks,
+                first_rank,
+                current_entry,
+                mismatch,
+                get_version_detail(version),
+                pg_info[0],
+            )
+
+            # TODO: For now, we only check the correctness of individual collective within a coalesced one in
+            # this script. We need to merge  (e.g, input/output sizes) together
+            # for downstream consumer.
+
+            # at this point there are 3 possibilities
+            # 1. we found a match on all the ranks that are members of the group
+            #  -> we create a Collective and remove the individual entries from their original lists
+            if (
+                match_record.found_ranks == match_record.expected_ranks
+                and mismatch[pg_name] == 0
+            ):
+                # Just pop out this collective.
+                idx_map = {
+                    r: match_record.found_idx[r] if r != first_rank else 0
+                    for r in match_record.found_ranks
+                }
+                for i, k in idx_map.items():
+                    all_rank_events[i].pop(k)
+                for r in match_record.found_ranks:
+                    if r != first_rank:
+                        all_ops[r].pop(0)
+
+            # 2. we found a partial match but some ranks are missing
+            # 3. we found no match
+            #  -> since its not a complete collective, no entry goes into collectives but we still record a nccl call
+            else:
+                logger.debug("Non-matching collective inside coalesced group")
+                idx_map = {
+                    r: match_record.candidate_idx[r] if r != first_rank else 0
+                    for r in match_record.candidate_ranks
+                }
+                collectives.append(
+                    match_record.entry_state.to_collective(
+                        len(collectives),
+                        errors=match_record.errors,
+                        idx_map=idx_map,
+                        all_entries=all_coalesced_entries,
+                    )
+                )
+                return False
+
+    if is_p2p:
+        visualize_ops(True, _pg_guids)
+    return True
+
+
+def check_size_alltoall(alltoall_cases: list[dict[str, Any]]) -> tuple[bool, int, int]:
+    input_numel = 0
+    output_numel = 0
+    for e in alltoall_cases:
+        input_numel += math.prod(e["input_sizes"][0])
+        output_numel += math.prod(e["output_sizes"][0])
+    return input_numel != output_numel, input_numel, output_numel
+
+
+def check_current_entry_match(
+    all_entries: dict[int, list[dict[str, Any]]],
+    _pg_guids: dict[tuple[str, int], str],
+    pg_info: tuple[str, str],
+    current_entry: dict[str, Any],
+    _memberships: dict[str, set[Any]],
+    mismatch: dict[str, int],
+    match_record: MatchStateRecord,
+) -> None:
+    pg_name, desc = pg_info[0], pg_info[1]
+    for o in match_record.expected_ranks.intersection(set(match_record.other_ranks)):
+        for i, e in enumerate(all_entries[o]):  # type: ignore[index]
+            # step over ops from other PGs
+            # only check match state when seq_id matches
+            if (
+                _pg_guids[(e["process_group"][0], o)] == pg_name
+                and e["process_group"][1] == desc
+                and e["collective_seq_id"] == match_record.entry_state.collective_seq_id
+            ):
+                match_info = match_one_event(current_entry, e, _memberships, pg_name)
+                if (
+                    match_info.state in [MatchState.FULLY_MATCHED, MatchState.UNDECIDED]
+                    and mismatch[pg_name] == 0
+                ):
+                    match_record.found_ranks.add(o)
+                    match_record.found_idx[o] = i
+                    match_record.has_undecided_case = (
+                        match_info.state == MatchState.UNDECIDED
+                    )
+                else:
+                    match_record.candidate_ranks.add(o)
+                    match_record.candidate_idx[o] = i
+                    if match_info.state not in [
+                        MatchState.FULLY_MATCHED,
+                        MatchState.UNDECIDED,
+                    ]:
+                        # Here we assume the current rank is not the source of the error.
+                        # But it's possible that the current rank is the culprit, then users will
+                        # see lots of normal ranks reported as culprit.
+                        # TODO: we need to figure out a better way to handle the case mentioned above.
+                        match_record.errors.add((o, match_info))
+                break
+
+
+def error_analysis(
+    all_entries: dict[int, list[dict[str, Any]]],
+    match_record: MatchStateRecord,
+    dumps_ranks: set[int],
+    first_rank: int,
+    current_entry: dict[str, Any],
+    mismatch: dict[str, int],
+    version: tuple[int, int],
+    pg_name: str,
+) -> None:
+    major_v, minor_v = version[0], version[1]
+    # case one: not every rank join the collective or in the flight recorder.
+    if (
+        match_record.candidate_ranks | match_record.found_ranks
+    ) != match_record.expected_ranks and match_record.expected_ranks - (
+        match_record.candidate_ranks | match_record.found_ranks
+    ) <= dumps_ranks:
+        mismatch[pg_name] += 1
+        logger_msg = "Not all ranks joining collective, sequence number: %s"
+        missing_ranks = match_record.expected_ranks - (
+            match_record.candidate_ranks | match_record.found_ranks
+        )
+        match_record.entry_state.log(
+            logger, logger_msg, format_frames, missing_ranks=missing_ranks
+        )
+        match_record.candidate_ranks.update(match_record.found_ranks)
+        match_record.candidate_idx.update(match_record.found_idx)
+        match_record.found_idx.clear()
+        match_record.found_ranks.clear()
+    # We didn't see any mismatch and all expected ranks are in the dump.
+    elif len(
+        match_record.candidate_ranks
+    ) == 1 and match_record.expected_ranks.issubset(dumps_ranks):
+        # case two: alltoall or alltoall_base case.
+        if match_record.has_undecided_case:
+            alltoall_cases = [current_entry] + [
+                all_entries[o][match_record.found_idx[o]]
+                for o in match_record.found_ranks
+            ]
+            fail_check, total_input_numel, total_output_numel = check_size_alltoall(
+                alltoall_cases
+            )
+            if major_v <= 2 and minor_v <= 3:
+                # We don't log the input/output sizes for alltoall before v2.4,
+                # so we don't consider the size mismatch as an error for now.
+                fail_check = False
+            if fail_check:
+                # When we see errors in all_to_all, it's hard to tell which rank is the source of the error.
+                mismatch[pg_name] += 1
+                logger_msg = (
+                    "Input/output mismatch in the collective sequence number: %s"
+                )
+                match_record.entry_state.log(
+                    logger,
+                    logger_msg,
+                    format_frames,
+                    total_numel=(total_input_numel, total_output_numel),
+                )
+                match_record.candidate_ranks.update(match_record.found_ranks)
+                match_record.candidate_idx.update(match_record.found_idx)
+                match_record.found_idx.clear()
+                match_record.found_ranks.clear()
+                match_record.errors.add(
+                    (first_rank, MatchInfo(MatchState.SIZE_OR_SYNTAX_MISMATCH))
+                )
+            else:
+                match_record.found_ranks.update(match_record.candidate_ranks)
+                match_record.found_idx.update(match_record.candidate_idx)
+                match_record.candidate_idx.clear()
+                match_record.candidate_ranks.clear()
+        # case three: all joined and everything matches on all ranks.
+        else:
+            match_record.found_ranks.update(match_record.candidate_ranks)
+            match_record.found_idx.update(match_record.candidate_idx)
+            match_record.candidate_idx.clear()
+            match_record.candidate_ranks.clear()
+    # case four: mismatch cases due to not same type, size mismatch or state mismatch.
+    elif len(match_record.errors) > 0:
+        mismatch[pg_name] += 1
+        logger_msg = "Collective sequence number: %s has errors"
+        match_record.entry_state.log(
+            logger, logger_msg, format_frames, errors=match_record.errors
+        )
+        match_record.candidate_ranks.update(match_record.found_ranks)
+        match_record.candidate_idx.update(match_record.found_idx)
+        match_record.found_idx.clear()
+        match_record.found_ranks.clear()
+    # partial analysis case when we cannot decide what's wrong with this collective entry.
+    else:
+        match_record.candidate_ranks.update(match_record.found_ranks)
+        match_record.candidate_idx.update(match_record.found_idx)
+        match_record.found_idx.clear()
+        match_record.found_ranks.clear()
+        # if any element in expected_ranks not in dumps_ranks.
+        if match_record.expected_ranks - dumps_ranks:
+            mismatch[pg_name] += 1
+            logger.info(
+                "We cannot decide what's wrong with this collective entry "
+                "because we missed FR dumps from ranks (%s) so we don't have enough "
+                "information. If you want to debug further use -j to dump all raw trace",
+                str(match_record.expected_ranks - dumps_ranks),
+            )
+        else:
+            logger.info(
+                "No errors found for this collective entry, There could be some "
+                "other reasons why we see collective timeout."
+            )
+
+
+def find_coalesced_group(
+    pg_name: str,
+    entries: list[dict[str, Any]],
+    _pg_guids: dict[tuple[str, int], str],
+    rank: int,
+) -> list[tuple[int, dict[str, Any]]]:
+    """Given a list of entries, if the collective_seq_id of the first entry matches that of subsequent ones,
+    build an return a list of entries terminating in a 'coalesced' op entry all sharing a collective_seq_id
+    """
+    found = []
+    collective_seq_id = None
+    for i, e in enumerate(entries):
+        if _pg_guids[(e["process_group"][0], rank)] != pg_name:
+            continue
+        elif collective_seq_id is None:
+            collective_seq_id = (
+                e["p2p_seq_id"] if e["is_p2p"] else e["collective_seq_id"]
+            )
+            found.append((i, e))
+        elif not e["is_p2p"] and e["collective_seq_id"] == collective_seq_id:
+            found.append((i, e))
+        elif e["is_p2p"] and e["p2p_seq_id"] == collective_seq_id:
+            found.append((i, e))
+        else:
+            break
+
+    if len(found) > 1:
+        assert found[-1][1]["profiling_name"] == "nccl:coalesced"
+        return found
+    return []
+
+
+# We enabled the creating FR entry for non-P2P slow path collective ops in v2.7.
+def find_coalesced_group_with_non_p2p(
+    pg_name: str,
+    entries: list[dict[str, Any]],
+    _pg_guids: dict[tuple[str, int], str],
+    rank: int,
+) -> list[tuple[int, dict[str, Any]]]:
+    """Given a list of entries, if the collective_seq_id of the first entry matches that of subsequent ones,
+    build an return a list of entries terminating in a 'coalesced' op entry all sharing a collective_seq_id
+    """
+    found = []
+    collective_seq_id = None
+    for i, e in enumerate(entries):
+        if _pg_guids[(e["process_group"][0], rank)] != pg_name:
+            continue
+        elif collective_seq_id is None:
+            collective_seq_id = (
+                e["p2p_seq_id"] if e["is_p2p"] else e["collective_seq_id"]
+            )
+            found.append((i, e))
+        elif not e["is_p2p"] and e["collective_seq_id"] == collective_seq_id:
+            found.append((i, e))
+        elif e["is_p2p"] and e["p2p_seq_id"] == collective_seq_id:
+            found.append((i, e))
+        else:
+            break
+
+    if len(found) > 1:
+        name = found[-1][1]["profiling_name"]
+        if name.startswith("nccl:") and not name.endswith("_coalesced"):
+            logger.error("Rank %s does not have a coalesced end.", rank)
+        return found
+    return []
+
+
+def just_print_entries(
+    all_entries: dict[int, list[dict[str, Any]]],
+    _groups: dict[str, Group],
+    _memberships: dict[str, set[Any]],
+    _pg_guids: dict[tuple[str, int], str],
+    args: argparse.Namespace,
+    stack_id_trace_map: dict[str, int],
+) -> None:
+    rows = []
+    ranks = sorted(all_entries.keys())
+    headers = [
+        f"Rank {rank}"
+        for rank in ranks
+        if args.selected_ranks is None or rank in args.selected_ranks
+    ]
+    progress = True
+    while progress:
+        progress = False
+        row = []
+        for rank in ranks:
+            if args.selected_ranks is not None and rank not in args.selected_ranks:
+                continue
+            if len(all_entries[rank]) == 0:
+                row.append("")
+            else:
+                entry = all_entries[rank].pop(0)
+                pg_name = _pg_guids[(entry["process_group"][0], rank)]
+                if (
+                    args.pg_filters is None
+                    or entry["process_group"][1] in args.pg_filters
+                    or entry["process_group"][0] in args.pg_filters
+                ):
+                    row.append(str(Op(entry, _memberships, pg_name)))
+                else:
+                    row.append("")
+                progress = True
+        if progress:
+            rows.append(row)
+
+    logger.info(tabulate(rows, headers=headers))
+
+    if stack_id_trace_map and args.print_stack_trace:
+        headers = ["stack_id", "frame_stack"]
+        rows = []
+
+        for frame, stack_id in sorted(
+            stack_id_trace_map.items(), key=lambda item: item[1]
+        ):
+            rows.append([str(stack_id), frame])
+
+        logger.info(tabulate(rows, headers=headers))
+
+
+def check_no_missing_dump_files(
+    entries: dict[int, Any], memberships: list[Membership]
+) -> None:
+    all_ranks = set()
+    for membership in memberships:
+        all_ranks.add(int(membership.global_rank))
+    dumps_ranks = {int(key) for key in entries}
+    missing = all_ranks - dumps_ranks
+    assert len(missing) == 0, f"Missing dump files from ranks {missing}"
+
+
+def check_version(version_by_ranks: dict[str, str], version: str) -> None:
+    for rank, v in version_by_ranks.items():
+        assert v == version, (
+            f"Rank {rank} has different version {v} from the given version {version}"
+        )
+
+
+def get_version_detail(version: str) -> tuple[int, int]:
+    # pyrefly: ignore [bad-assignment]
+    version = version.split(".")
+    assert len(version) == 2, f"Invalid version {version}"
+    major, minor = map(int, version)
+    return major, minor
+
+
+def add_stack_id_in_entries(
+    entries: dict[int, list[dict[str, Any]]],
+) -> tuple[dict[int, list[dict[str, Any]]], dict[str, int]]:
+    stack_id = 0
+    stack_id_trace_map = {}
+    for rank in entries:
+        for dump in entries[rank]:
+            if dump.get("frames", []):
+                frames = str(dump["frames"])
+                if frames not in stack_id_trace_map:
+                    stack_id_trace_map[frames] = stack_id
+                    dump["stack_id"] = stack_id
+                    stack_id += 1
+                else:
+                    dump["stack_id"] = stack_id_trace_map[frames]
+            else:
+                dump["stack_id"] = -1
+
+    return entries, stack_id_trace_map
+
+
+def align_trace_from_beginning(
+    entries: dict[int, list[dict[str, Any]]],
+) -> dict[int, list[dict[str, Any]]]:
+    """
+    Align the trace entries by record ID for entries.
+    This function takes a dictionary of rank names to lists of trace entries as input.
+    Each trace entry is a dictionary containing information about a collective operation,
+    including its unique identifier (`record_id` is monotonically increasing as we write into the ring buffer).
+    The function finds the largest starting point across all ranks by taking the maximum
+    `record_id` value of the first entry in each rank. Finally, it filters out any
+    entries with `record_id` values less than the maximum starting point.
+    The function returns the updated dictionary of sorted and filtered trace entries.
+
+    Args:
+        entries (Dict[str, List[Dict[str, Any]]]): A dictionary of rank names to lists of trace entries.
+
+    Returns:
+        entries (Dict[str, List[Dict[str, Any]]]): Entries sorted by record ID and filtered by the maximum starting point.
+    """
+
+    maximum_starting_record_id = 0
+    for rank in entries:
+        # Although this is a ring buffer, we already sort the entries by `record_id` when dumping, we just
+        # need to find the largest starting point. For example, if the buffer has the following entries:
+        # Rank 0: [0, 1, 2, 3, 4, 5, 6]
+        # Rank 1: [1, 2, 3, 4, 5, 6, 7]
+        # Rank 2: [2, 3, 4, 5, 6, 7, 8]
+        # Rank 3: [0, 1, 2, 3, 4, 5, None]
+        # Then we should start from collective 2 not 0 because any collective before,
+        # we don't have complete records from all ranks so we need to ignore them.
+        # If we don't have any trace from some ranks, ignore them
+        # as well.
+        if len(entries[rank]) == 0:
+            continue
+        first_record_id = entries[rank][0]["record_id"]
+        maximum_starting_record_id = max(maximum_starting_record_id, first_record_id)
+
+    for rank in entries:
+        entries[rank] = [
+            entry
+            for entry in entries[rank]
+            if entry["record_id"] >= maximum_starting_record_id
+        ]
+
+    return entries
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_runtime_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_runtime_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..655cfb03534d5296346bfd7bd6099355c18e592d
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_runtime_utils.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4d0b341a3f82b35fc903ccffd5208d8fdade399
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__init__.py
@@ -0,0 +1,20 @@
+from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy
+from ._fully_shard import (
+    FSDPModule,
+    fully_shard,
+    register_fsdp_forward_method,
+    share_comm_ctx,
+    UnshardHandle,
+)
+
+
+__all__ = [
+    "CPUOffloadPolicy",
+    "FSDPModule",
+    "fully_shard",
+    "MixedPrecisionPolicy",
+    "OffloadPolicy",
+    "register_fsdp_forward_method",
+    "UnshardHandle",
+    "share_comm_ctx",
+]
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c6ff32bad693998119d8d70d9a72920ccdb622c2
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_api.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_api.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..573a9a1e0ef7a29c0bd5fda0880acc236bb95a8b
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_api.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_collectives.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_collectives.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..66a2ead8e76f2ebedfb4336d7bf241415d942744
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_collectives.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_common.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_common.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2676f2e896a02c2031e26a5d2fda3f0444fe6d61
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_common.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_init.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_init.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1a46d77b30ea90c95d27c72a444522161393e8ce
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_init.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_param.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_param.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..29767b3f97f32a1ca5318214ff18de0eb71a8915
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_param.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_param_group.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_param_group.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4cfa03b498e71638d485f8e626b92e778f3c4087
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_param_group.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_state.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_state.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bafd4b5da03176d1cd34befffa8d290ec9c34a77
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_state.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fully_shard.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fully_shard.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9a7dda55869dae94feddf834f2bb96ea557c6bc8
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fully_shard.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_api.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_api.py
new file mode 100644
index 0000000000000000000000000000000000000000..38650323f5e99727f04964ca59fb268ca8e7b65c
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_api.py
@@ -0,0 +1,155 @@
+# mypy: allow-untyped-defs
+from abc import ABC, abstractmethod
+from collections.abc import Sequence
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import torch
+import torch.distributed as dist
+
+
+_ReduceOp = Union[dist.ReduceOp, dist.ReduceOp.RedOpType]
+
+
+@dataclass(frozen=True)
+class MixedPrecisionPolicy:
+    """
+    This configures FSDP's mixed precision. Unlike autocast, this applies mixed
+    precision at the module level, not op level, which means low-precision
+    activations are saved for backward and high-to-low-precision casts are
+    incurred only at module boundaries.
+
+    FSDP works well with module-level mixed precision since it keeps the
+    high-precision sharded parameters in memory anyway. In other words, FSDP
+    does not require any extra memory to keep a high-precision copy of the
+    parameters for the optimizer step.
+
+    Attributes:
+        param_dtype (Optional[torch.dtype]): This specifies the dtype for
+            the unsharded parameter and hence the dtype for forward/backward
+            computation and the parameter all-gather. If this is ``None``, then
+            the unsharded parameter uses the original dtype. The optimizer step
+            uses the sharded parameter in the original dtype. (Default:
+            ``None``)
+        reduce_dtype (Optional[torch.dtype]): This specifies the dtype for
+            gradient reduction (i.e. reduce-scatter or all-reduce). If this is
+            ``None`` but ``param_dtype`` is not ``None``, then the reduction
+            uses the compute dtype. This can be used to run gradient reduction
+            in full precision while using low precision for compute. If also
+            gradient reduction is disabled via :meth:`set_requires_gradient_sync`,
+            then FSDP will accumulate gradients using ``reduce_dtype``.
+            (Default: ``None``)
+        output_dtype (Optional[torch.dtype]): This specifies the dtype for
+            casting floating-point forward outputs. This can be used to
+            help implement cases where different modules have different mixed
+            precision policies. (Default: ``None``)
+        cast_forward_inputs (bool): This specifies whether FSDP should cast the
+            forward's floating-point input tensors to ``param_dtype`` or not.
+    """
+
+    param_dtype: Optional[torch.dtype] = None
+    reduce_dtype: Optional[torch.dtype] = None
+    output_dtype: Optional[torch.dtype] = None
+    cast_forward_inputs: bool = True
+
+
+class Comm(ABC):
+    """
+    Interface for communication primitives.
+    A primitive primarily needs to handle 3 tasks, namely:
+
+    1. How to allocate memory for communication
+       Depending on the goal, an implementation can choose to:
+       a. associate each call to a temporary buffer
+          (best for flexibility and simplicity)
+       b. reuse an persistent buffer for efficiency reasons
+
+    2. Where to allocate memory
+       (e.g. NCCL mem pool or regular cuda caching allocator)
+
+    3. What to do/call upon the comm is called
+       (see `AllGather` interface as an example)
+    """
+
+    @abstractmethod
+    def allocate(
+        self,
+        size: Sequence[Union[int, torch.SymInt]],
+        *,
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> torch.Tensor:
+        """
+        This handles the "how to allocate memory" part.
+
+        A default implementation could be simply:
+
+        .. code-block:: python
+            with self.mem_pool:
+                torch.empty(...)
+
+        Args:
+            size (Sequence[Union[int, torch.SymInt]]): size of the tensor buffer
+            dtype (torch.dtype): dtype of the tensor buffer
+            device (torch.device): which device to allocate the tensor onto
+        """
+        ...
+
+
+class AllGather(Comm):
+    """
+    Interface for all_gather comm primitive
+    """
+
+    @abstractmethod
+    def __call__(
+        self,
+        output_tensor: torch.Tensor,
+        input_tensor: torch.Tensor,
+        group: dist.ProcessGroup,
+        async_op: bool = False,
+    ) -> Optional[dist.Work]: ...
+
+
+class ReduceScatter(Comm):
+    """
+    Interface for reduce_scatter comm primitive
+    """
+
+    @abstractmethod
+    def __call__(
+        self,
+        output_tensor: torch.Tensor,
+        input_tensor: torch.Tensor,
+        group: dist.ProcessGroup,
+        op: _ReduceOp,
+        async_op: bool = False,
+    ) -> Optional[dist.Work]: ...
+
+
+@dataclass
+class OffloadPolicy:
+    """
+    This base class represents the policy of no offloading and is only used as
+    the default value for the ``offload_policy`` arg.
+    """
+
+
+@dataclass
+class CPUOffloadPolicy(OffloadPolicy):
+    """
+    This offload policy offloads parameters, gradients, and optimizer states to
+    CPU. Sharded parameters are copied host-to-device before all-gather. The
+    all-gathered parameters are freed according to ``reshard_after_forward``.
+    Sharded gradients are copied device-to-host in backward, and the optimizer
+    step runs on CPU with CPU optimizer states.
+
+    Attributes:
+        pin_memory (bool): Whether to pin sharded parameter and gradient
+            memory. Pinning memory allows both more efficient H2D/D2H copies
+            and for the copies to overlap with compute. However, the pinned
+            memory cannot be used by other processes. Set this to ``False`` if
+            you have insufficient CPU memory. (Default: ``True``)
+    """
+
+    pin_memory: bool = True
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py
new file mode 100644
index 0000000000000000000000000000000000000000..2bd7d24cd7d3f2fc24b634d72197f8e51c4839e6
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py
@@ -0,0 +1,762 @@
+import math
+from collections.abc import Callable, Sequence
+from itertools import chain
+from typing import Any, cast, NamedTuple, Optional, Union
+
+import torch
+import torch.distributed as dist
+from torch.distributed.device_mesh import _get_device_handle
+from torch.distributed.distributed_c10d import ReduceOp
+from torch.distributed.fsdp._fully_shard._fsdp_api import AllGather, ReduceScatter
+from torch.distributed.tensor import DTensor
+
+from ._fsdp_api import _ReduceOp
+from ._fsdp_common import (
+    _get_dim0_padded_size,
+    _raise_assert_with_print,
+    _to_dtype_if_needed,
+    compiled_autograd_enabled,
+)
+from ._fsdp_param import FSDPParam, ShardedState
+
+
+class AllGatherResult(NamedTuple):
+    all_gather_output: torch.Tensor
+    all_gather_event: Optional[torch.Event]
+    all_gather_work: Optional[dist.distributed_c10d.Work]
+    # For each parameter, the all-gather input dtype for each input
+    param_all_gather_input_dtypes: list[list[torch.dtype]]
+    # For each parameter, the all-gather input numel for each input
+    param_all_gather_input_numels: list[list[int]]
+    # 1D flattened version of `param_all_gather_input_numels` saved to avoid
+    # CPU overhead from recomputing
+    all_gather_input_split_sizes: list[int]
+
+
+lib = torch.library.Library("fsdp", "FRAGMENT")  # noqa: TOR901
+
+lib.define(
+    """
+    all_gather_copy_in(
+        Tensor[] all_gather_inputs,
+        Tensor all_gather_output,
+        SymInt[] inp_split_sizes,
+        SymInt all_gather_input_numel,
+        SymInt rank
+    ) -> (Tensor, Tensor)
+    """
+)
+
+
+class DefaultAllocMixin:
+    def allocate(
+        self,
+        size: Sequence[Union[int, torch.SymInt]],
+        *,
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> torch.Tensor:
+        return torch.empty(*size, dtype=dtype, device=device)
+
+
+class ProcessGroupAllocMixin:
+    def __init__(self, group: dist.ProcessGroup, *args: Any, **kwargs: Any):
+        self._group = group
+        super().__init__(*args, **kwargs)
+
+    def allocate(
+        self,
+        size: Sequence[Union[int, torch.SymInt]],
+        *,
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> torch.Tensor:
+        backend = self._group._get_backend(device)
+        if backend.supports_tensor_alloc(device):
+            size_1d = math.prod(int(s) for s in size)
+            return backend.allocate_tensor(size_1d, dtype=dtype, device=device)
+        return torch.empty(*size, dtype=dtype, device=device)
+
+
+class DefaultAllGather(DefaultAllocMixin, AllGather):
+    def __call__(
+        self,
+        output_tensor: torch.Tensor,
+        input_tensor: torch.Tensor,
+        group: dist.ProcessGroup,
+        async_op: bool = False,
+    ) -> Optional[dist.Work]:
+        return dist.all_gather_into_tensor(
+            output_tensor,
+            input_tensor,
+            group=group,
+            async_op=async_op,
+        )
+
+
+class ProcessGroupAllocAllGather(ProcessGroupAllocMixin, AllGather):
+    def __init__(self, group: dist.ProcessGroup) -> None:
+        super().__init__(group)
+
+    def __call__(
+        self,
+        output_tensor: torch.Tensor,
+        input_tensor: torch.Tensor,
+        group: dist.ProcessGroup,
+        async_op: bool = False,
+    ) -> Optional[dist.Work]:
+        return dist.all_gather_into_tensor(
+            output_tensor,
+            input_tensor,
+            group=group,
+            async_op=async_op,
+        )
+
+
+class DefaultReduceScatter(DefaultAllocMixin, ReduceScatter):
+    def __call__(
+        self,
+        output_tensor: torch.Tensor,
+        input_tensor: torch.Tensor,
+        group: dist.ProcessGroup,
+        op: _ReduceOp,
+        async_op: bool = False,
+    ) -> dist.Work:
+        return dist.reduce_scatter_tensor(
+            output=output_tensor,
+            input=input_tensor,
+            group=group,
+            op=op,
+            async_op=async_op,
+        )
+
+
+class ProcessGroupAllocReduceScatter(ProcessGroupAllocMixin, ReduceScatter):
+    def __init__(self, group: dist.ProcessGroup) -> None:
+        super().__init__(group)
+
+    def __call__(
+        self,
+        output_tensor: torch.Tensor,
+        input_tensor: torch.Tensor,
+        group: dist.ProcessGroup,
+        op: _ReduceOp,
+        async_op: bool = False,
+    ) -> dist.Work:
+        return dist.reduce_scatter_tensor(
+            output=output_tensor,
+            input=input_tensor,
+            group=group,
+            op=op,
+            async_op=async_op,
+        )
+
+
+@torch.library.impl(lib, "all_gather_copy_in", "Meta")
+def all_gather_copy_in_meta(
+    all_gather_inputs: list[torch.Tensor],
+    all_gather_output: torch.Tensor,
+    inp_split_sizes: list[int],
+    all_gather_input_numel: int,
+    rank: int,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    all_gather_input = all_gather_output.narrow(
+        0, all_gather_input_numel * rank, all_gather_input_numel
+    )
+    return all_gather_input, all_gather_output
+
+
+@torch.library.impl(lib, "all_gather_copy_in", "CUDA")
+@torch.library.impl(lib, "all_gather_copy_in", "XPU")
+@torch.library.impl(lib, "all_gather_copy_in", "HPU")
+@torch.library.impl(lib, "all_gather_copy_in", "CPU")
+@torch.library.impl(lib, "all_gather_copy_in", "MTIA")
+@torch.library.impl(lib, "all_gather_copy_in", "PrivateUse1")
+def all_gather_copy_in_cuda(
+    all_gather_inputs: list[torch.Tensor],
+    all_gather_output: torch.Tensor,
+    inp_split_sizes: list[int],
+    all_gather_input_numel: int,
+    rank: int,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    all_gather_input = all_gather_output.narrow(
+        0, all_gather_input_numel * rank, all_gather_input_numel
+    )
+    foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
+    with torch.no_grad():
+        torch._foreach_copy_(foreach_copy_dsts, all_gather_inputs)
+    return all_gather_input, all_gather_output
+
+
+lib.define(
+    "split_with_sizes_copy(Tensor all_gather_output, SymInt[] all_gather_input_split_sizes, int dim=0, *, Tensor(a!)[] out) -> ()"
+)
+
+
+@torch.library.impl(lib, "split_with_sizes_copy", "Meta")
+@torch.library.impl(lib, "split_with_sizes_copy", "CUDA")
+@torch.library.impl(lib, "split_with_sizes_copy", "XPU")
+@torch.library.impl(lib, "split_with_sizes_copy", "HPU")
+@torch.library.impl(lib, "split_with_sizes_copy", "CPU")
+@torch.library.impl(lib, "split_with_sizes_copy", "MTIA")
+@torch.library.impl(lib, "split_with_sizes_copy", "PrivateUse1")
+def split_with_sizes_copy(
+    all_gather_output: torch.Tensor,
+    all_gather_input_split_sizes: list[int],
+    dim: int,
+    out: list[torch.Tensor],
+) -> None:
+    torch.split_with_sizes_copy(
+        all_gather_output, all_gather_input_split_sizes, dim=dim, out=out
+    )
+
+
+lib.define(
+    "chunk_cat(Tensor[] tensors, int dim, int num_chunks, *, Tensor(a!) out) -> ()"
+)
+
+
+@torch.library.impl(lib, "chunk_cat", "Meta")
+@torch.library.impl(lib, "chunk_cat", "CUDA")
+@torch.library.impl(lib, "chunk_cat", "XPU")
+@torch.library.impl(lib, "chunk_cat", "HPU")
+@torch.library.impl(lib, "chunk_cat", "CPU")
+@torch.library.impl(lib, "chunk_cat", "MTIA")
+@torch.library.impl(lib, "chunk_cat", "PrivateUse1")
+def chunk_cat(
+    tensors: list[torch.Tensor],
+    dim: int,
+    num_chunks: int,
+    out: torch.Tensor,
+) -> None:
+    torch._chunk_cat(tensors, dim, num_chunks, out=out)
+
+
+@torch.no_grad()
+def foreach_all_gather(
+    fsdp_params: list[FSDPParam],
+    group: dist.ProcessGroup,
+    async_op: bool,
+    all_gather_copy_in_stream: torch.Stream,
+    all_gather_stream: torch.Stream,
+    device: torch.device,
+    all_gather_comm: AllGather,
+) -> Optional[AllGatherResult]:
+    world_size, rank = group.size(), group.rank()
+    device_handle = _get_device_handle(device.type)
+    with device_handle.stream(all_gather_copy_in_stream):
+        param_all_gather_inputs = _get_param_all_gather_inputs(fsdp_params)
+        (
+            param_all_gather_input_dtypes,
+            param_all_gather_input_numels,
+            dtype,
+        ) = _get_all_gather_input_metadatas(param_all_gather_inputs)
+        if dtype == torch.uint8:
+            all_gather_inputs = [
+                t.view(torch.uint8) for ts in param_all_gather_inputs for t in ts
+            ]
+        else:
+            all_gather_inputs = [*chain.from_iterable(param_all_gather_inputs)]
+        inp_split_sizes = [t.numel() for t in all_gather_inputs]
+        all_gather_input_numel = sum(inp_split_sizes)
+        all_gather_output = all_gather_comm.allocate(
+            (all_gather_input_numel * world_size,), dtype=dtype, device=device
+        )
+        all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(
+            all_gather_inputs,
+            all_gather_output,
+            inp_split_sizes,
+            all_gather_input_numel,
+            rank,
+        )
+        del param_all_gather_inputs
+    all_gather_stream.wait_stream(all_gather_copy_in_stream)
+    with device_handle.stream(all_gather_stream):
+        all_gather_work = all_gather_comm(
+            output_tensor=all_gather_output,
+            input_tensor=all_gather_input,
+            group=group,
+            async_op=async_op,
+        )
+        all_gather_event = all_gather_stream.record_event()
+        return AllGatherResult(
+            all_gather_output,
+            all_gather_event,
+            all_gather_work,
+            param_all_gather_input_dtypes,
+            param_all_gather_input_numels,
+            inp_split_sizes,
+        )
+
+
+@torch.no_grad()
+def _get_param_all_gather_inputs(
+    fsdp_params: list[FSDPParam],
+) -> list[list[torch.Tensor]]:
+    if compiled_autograd_enabled():
+        return [fsdp_param.all_gather_inputs for fsdp_param in fsdp_params]
+
+    # Intentionally try to run a fast-path that bypasses abstractions for the
+    # common FSDP case of bf16/fp32 mixed precision in order to use foreach
+    # copy for lower CPU overhead and more efficient copying in eager
+    def use_foreach_copy(fsdp_param: FSDPParam) -> bool:
+        return (
+            fsdp_param.param_dtype is not None
+            and not fsdp_param.offload_to_cpu
+            and not hasattr(fsdp_param._sharded_local_tensor, "fsdp_pre_all_gather")
+        )
+
+    param_all_gather_inputs: list[list[torch.Tensor]] = [[] for _ in fsdp_params]
+    foreach_copy_indices: list[int] = []
+    foreach_copy_inputs: list[torch.Tensor] = []
+    foreach_copy_input_numels: list[int] = []
+
+    # 1st pass: for foreach-copy parameters, get inputs and metadata for the
+    # foreach copy, and for the others, actually get their all-gather inputs
+    for i, fsdp_param in enumerate(fsdp_params):
+        if use_foreach_copy(fsdp_param):
+            foreach_copy_indices.append(i)
+            all_gather_input = (
+                fsdp_param._sharded_param_data
+                if fsdp_param.sharded_state == ShardedState.SHARDED
+                else cast(torch.Tensor, fsdp_param._sharded_post_forward_param_data)
+            )
+            foreach_copy_inputs.append(all_gather_input)
+            foreach_copy_input_numels.append(all_gather_input.numel())
+        else:
+            param_all_gather_inputs[i] = fsdp_param.all_gather_inputs
+
+    # 2nd pass: use foreach copy to compute the remaining all-gather inputs
+    if foreach_copy_inputs:
+        fsdp_param_0 = fsdp_params[foreach_copy_indices[0]]
+        param_dtype, device = fsdp_param_0.param_dtype, fsdp_param_0.device
+        flat_foreach_copy_input = torch.empty(
+            (sum(foreach_copy_input_numels),), device=device, dtype=param_dtype
+        )
+        splits = torch.split(flat_foreach_copy_input, foreach_copy_input_numels)
+        torch._foreach_copy_(splits, foreach_copy_inputs)
+        for i, split in zip(foreach_copy_indices, splits):
+            param_all_gather_inputs[i] = [split]
+
+    return param_all_gather_inputs
+
+
+@torch.no_grad()
+def foreach_all_gather_copy_out(
+    all_gather_result: AllGatherResult,
+    fsdp_params: list[FSDPParam],
+    group: dist.ProcessGroup,
+) -> None:
+    (
+        all_gather_output,
+        all_gather_event,
+        all_gather_work,
+        param_all_gather_input_dtypes,
+        param_all_gather_input_numels,
+        all_gather_input_split_sizes,
+    ) = all_gather_result
+    _dtype, device = all_gather_output.dtype, all_gather_output.device
+    device_handle = _get_device_handle(device.type)
+    if all_gather_event is not None:  # sync op
+        device_handle.current_stream().wait_event(all_gather_event)
+    if isinstance(all_gather_work, dist.distributed_c10d.Work):  # async op
+        all_gather_work.wait()
+    world_size, device = group.size(), all_gather_output.device
+
+    split_with_sizes_out: list[torch.Tensor] = []
+    shard_i_copy_infos: list[tuple[FSDPParam, list[torch.Tensor]]] = []
+    for all_gather_input_numels, all_gather_input_dtypes, fsdp_param in zip(
+        param_all_gather_input_numels, param_all_gather_input_dtypes, fsdp_params
+    ):
+        # NOTE: Under compile, make sure we always recreate all_gather_outputs
+        # per AllGather. See [Note: Invariants for torch.compile Traceable FSDP2].
+        force_recreate = compiled_autograd_enabled()
+        fsdp_param.init_all_gather_outputs(
+            all_gather_input_numels,
+            all_gather_input_dtypes,
+            world_size,
+            device,
+            force_recreate=force_recreate,
+        )
+        if not force_recreate:
+            fsdp_param.alloc_all_gather_outputs()
+        param_all_gather_outputs = fsdp_param.all_gather_outputs
+        if fsdp_param.fsdp_placement.dim != 0:
+            # Copy to a temporary and then chunk-cat into the final all-gather
+            # output tensors
+            param_all_gather_outputs = [
+                torch.empty_like(t) for t in param_all_gather_outputs
+            ]
+            shard_i_copy_infos.append((fsdp_param, param_all_gather_outputs))
+        split_with_sizes_out.extend(param_all_gather_outputs)
+
+    all_gather_output = all_gather_output.view(world_size, -1)
+    if all_gather_output.dtype == torch.uint8:
+        out = [t.view(world_size, -1).view(torch.uint8) for t in split_with_sizes_out]
+    else:
+        out = [t.view(world_size, -1) for t in split_with_sizes_out]
+
+    # only avoid VC bump if we are not in inference mode
+    if torch._dynamo.is_compiling():
+        # For torch.compile, we turn off inference_mode for fake tensor
+        # propagation, and therefore graph break on is_inference. For `compile`,
+        # we don't care about VCs, so just skip the optimization.
+        non_inference_outs = []
+    else:
+        non_inference_outs = [o for o in out if not o.is_inference()]
+
+    if len(non_inference_outs) > 0:
+        with torch.autograd._unsafe_preserve_version_counter(tuple(non_inference_outs)):
+            torch.ops.fsdp.split_with_sizes_copy(
+                all_gather_output, all_gather_input_split_sizes, dim=1, out=out
+            )
+    else:
+        torch.ops.fsdp.split_with_sizes_copy(
+            all_gather_output, all_gather_input_split_sizes, dim=1, out=out
+        )
+
+    for fsdp_param, param_all_gather_outputs in shard_i_copy_infos:
+        # Chunk-cat from the temporary to the final all-gather output tensors
+        shard_dim = fsdp_param.fsdp_placement.dim
+
+        with torch.autograd._unsafe_preserve_version_counter(
+            tuple(fsdp_param.all_gather_outputs)
+        ):
+            for param_all_gather_output, target_all_gather_output in zip(
+                param_all_gather_outputs, fsdp_param.all_gather_outputs
+            ):
+                padded_sharded_size = (
+                    fsdp_param.padded_sharded_param_size
+                    if fsdp_param.sharded_state == ShardedState.SHARDED
+                    else cast(
+                        torch.Tensor, fsdp_param._sharded_post_forward_param_data
+                    ).size()
+                )
+                pre_param_size = list(padded_sharded_size)
+                pre_param_size[0] *= world_size
+                chunks = torch.chunk(
+                    param_all_gather_output.view(pre_param_size), world_size, dim=0
+                )
+                post_param_size = list(padded_sharded_size)
+                post_param_size[shard_dim] *= world_size
+                cat_out = target_all_gather_output.view(post_param_size)
+                torch.cat(chunks, dim=shard_dim, out=cat_out)
+
+
+@torch.no_grad()
+def foreach_reduce(
+    fsdp_params: list[FSDPParam],
+    unsharded_grads: list[torch.Tensor],
+    reduce_scatter_group: dist.ProcessGroup,
+    reduce_scatter_stream: torch.Stream,
+    reduce_scatter_comm: ReduceScatter,
+    orig_dtype: Optional[torch.dtype],
+    reduce_dtype: Optional[torch.dtype],
+    device: torch.device,
+    gradient_divide_factor: Optional[float],
+    all_reduce_group: Optional[dist.ProcessGroup],  # not `None` iff HSDP
+    all_reduce_stream: torch.Stream,
+    all_reduce_grads: bool,
+    partial_reduce_output: Optional[torch.Tensor],  # only used for HSDP
+    all_reduce_hook: Optional[Callable[[torch.Tensor], None]],
+    force_sum_reduction_for_comms: bool = False,
+) -> tuple[
+    torch.Tensor,
+    torch.Event,
+    torch.Event,
+    Optional[torch.Tensor],
+    Optional[torch.Event],
+    Optional[torch.Tensor],
+]:
+    """
+    ``unsharded_grads`` owns the references to the gradients computed by
+    autograd, so clearing the list frees the gradients.
+    """
+
+    grad_dtypes = {grad.dtype for grad in unsharded_grads}
+    if len(grad_dtypes) != 1:
+        # Check this at runtime since it could be a real runtime error if e.g.
+        # fp8 weights do not produce the correct higher precision gradients
+        _raise_assert_with_print(
+            f"FSDP reduce-scatter expects uniform gradient dtype but got {grad_dtypes}"
+        )
+    grad_dtype = unsharded_grads[0].dtype
+    reduce_dtype = reduce_dtype or grad_dtype
+    (predivide_factor, postdivide_factor, reduce_scatter_op, all_reduce_op) = (
+        _get_gradient_divide_factors(
+            reduce_scatter_group,
+            all_reduce_group,
+            reduce_dtype,
+            device.type,
+            gradient_divide_factor,
+            force_sum_reduction_for_comms,
+        )
+    )
+
+    if reduce_scatter_group is None:
+        world_size = 1
+    else:
+        world_size = reduce_scatter_group.size()
+    device_handle = _get_device_handle(device.type)
+    current_stream = device_handle.current_stream()
+
+    if world_size > 1:
+        for i, (fsdp_param, unsharded_grad) in enumerate(
+            zip(fsdp_params, unsharded_grads)
+        ):
+            if (shard_dim := fsdp_param.fsdp_placement.dim) == 0:
+                continue
+            if unsharded_grad.size(shard_dim) % world_size != 0:
+                raise AssertionError(
+                    f"Shard({shard_dim}) requires even sharding: {unsharded_grad.size()=} {world_size=}"
+                )
+            chunks = torch.chunk(unsharded_grad, world_size, dim=shard_dim)
+            unsharded_grads[i] = torch.cat(chunks, dim=0)
+
+    padded_unsharded_sizes = tuple(
+        _get_dim0_padded_size(grad.size(), world_size) for grad in unsharded_grads
+    )
+    reduce_scatter_input_numel = sum(s.numel() for s in padded_unsharded_sizes)
+    reduce_scatter_output_numel = reduce_scatter_input_numel // world_size
+    reduce_scatter_input = reduce_scatter_comm.allocate(
+        (reduce_scatter_input_numel,),
+        dtype=reduce_dtype,
+        device=device,
+    )
+
+    foreach_reduce_scatter_copy_in(unsharded_grads, reduce_scatter_input, world_size)
+
+    # Only after the copy-in finishes can we free the gradients
+    unsharded_grads.clear()
+    reduce_scatter_stream.wait_stream(current_stream)
+    all_reduce_input = None
+    all_reduce_event = None
+
+    with device_handle.stream(reduce_scatter_stream):
+        reduce_output = reduce_scatter_comm.allocate(
+            (reduce_scatter_output_numel,),
+            dtype=reduce_dtype,
+            device=device,
+        )
+        _div_if_needed(reduce_scatter_input, predivide_factor)
+        if world_size > 1:
+            reduce_scatter_comm(
+                output_tensor=reduce_output,
+                input_tensor=reduce_scatter_input,
+                group=reduce_scatter_group,
+                op=reduce_scatter_op,
+            )
+        else:
+            # For single GPU, just copy the input to output (no actual reduce-scatter needed), and
+            # account for a possible gradient_divide_factor.
+            if gradient_divide_factor is not None:
+                reduce_output.copy_(reduce_scatter_input / gradient_divide_factor)
+            else:
+                reduce_output.copy_(reduce_scatter_input)
+        reduce_scatter_event = reduce_scatter_stream.record_event()
+        post_reduce_stream = reduce_scatter_stream
+        if all_reduce_group is not None:  # HSDP or DDP/replicate
+            # Accumulations must run in the reduce-scatter stream
+            if not all_reduce_grads:
+                if partial_reduce_output is not None:
+                    partial_reduce_output += reduce_output
+                else:
+                    partial_reduce_output = reduce_output
+                return (
+                    reduce_scatter_input,
+                    reduce_scatter_event,
+                    post_reduce_stream.record_event(),
+                    all_reduce_input,
+                    all_reduce_event,
+                    partial_reduce_output,
+                )
+            if partial_reduce_output is not None:
+                reduce_output += partial_reduce_output
+            post_reduce_stream = all_reduce_stream
+            if world_size >= 1:
+                all_reduce_stream.wait_stream(reduce_scatter_stream)
+            else:
+                all_reduce_stream.wait_stream(current_stream)
+            with device_handle.stream(all_reduce_stream):
+                dist.all_reduce(
+                    reduce_output,
+                    group=all_reduce_group,
+                    op=all_reduce_op,
+                )
+                all_reduce_input = reduce_output
+                all_reduce_event = all_reduce_stream.record_event()
+    # -- END: ops in reduce_scatter stream
+
+    if all_reduce_hook is not None:
+        # Execute user-specified all reduce hook.
+        # If native HSDP is used, this is executed after the HSDP all reduce.
+        # If 1-d FSDP is used, this is executed post reduce-scatter.
+        post_reduce_stream = all_reduce_stream
+        all_reduce_stream.wait_stream(reduce_scatter_stream)
+        with device_handle.stream(all_reduce_stream):
+            all_reduce_hook(reduce_output)
+    # -- END: ops post reduce_scatter
+
+    with device_handle.stream(post_reduce_stream):
+        _div_if_needed(reduce_output, postdivide_factor)
+        reduce_output = _to_dtype_if_needed(reduce_output, orig_dtype)
+        # View out and accumulate sharded gradients
+        flat_grad_offset = 0  # [0, reduce_scatter_output_numel - 1]
+        for padded_unsharded_size, fsdp_param in zip(
+            padded_unsharded_sizes, fsdp_params
+        ):
+            # Assume even sharding for Shard(i), i > 0; otherwise would require
+            # copy-out for contiguous strides
+            new_sharded_grad = torch.as_strided(
+                reduce_output,
+                size=fsdp_param.sharded_size,
+                stride=fsdp_param.contiguous_sharded_stride,
+                storage_offset=flat_grad_offset,
+            )
+            to_accumulate_grad = fsdp_param.sharded_param.grad is not None
+            if fsdp_param.offload_to_cpu:
+                # Only overlap the D2H copy (copying to pinned memory) if not
+                # accumulating gradients since the CPU add kernel depends on
+                # the copy result and we cannot run the add as a callback
+                non_blocking = fsdp_param.pin_memory and not to_accumulate_grad
+                # Since the GPU sharded gradient is allocated in the RS stream,
+                # we can free it here by not keeping a ref without waiting for
+                # the D2H copy since future RS-stream ops run after the copy
+                new_sharded_grad = new_sharded_grad.to(
+                    torch.device("cpu"), non_blocking=non_blocking
+                )
+                if non_blocking:
+                    # Record an event on which to block the CPU thread to
+                    # ensure that the D2H copy finishes before the optimizer
+                    fsdp_param.grad_offload_event = post_reduce_stream.record_event()
+            if to_accumulate_grad:
+                if not isinstance(fsdp_param.sharded_param.grad, DTensor):
+                    raise AssertionError(
+                        f"Expected fsdp_param.sharded_param.grad to be DTensor, got {type(fsdp_param.sharded_param.grad)}"
+                    )
+                fsdp_param.sharded_param.grad._local_tensor += new_sharded_grad
+            else:
+                new_sharded_dtensor_grad = fsdp_param.to_sharded_dtensor(
+                    new_sharded_grad
+                )
+                fsdp_param.sharded_param.grad = new_sharded_dtensor_grad
+            if not compiled_autograd_enabled():
+                for hook in (
+                    getattr(fsdp_param.sharded_param, "_post_accumulate_grad_hooks", {})
+                    or {}
+                ).values():
+                    hook(fsdp_param.sharded_param)
+            padded_sharded_numel = padded_unsharded_size.numel() // world_size
+            flat_grad_offset += padded_sharded_numel
+        post_reduce_event = post_reduce_stream.record_event()
+    # The RS output is allocated in the RS stream and used in the default
+    # stream (for optimizer). To ensure its memory is not reused for later
+    # RSs, we do not need extra synchronization since the sharded parameters
+    # hold refs through the end of backward.
+    return (
+        reduce_scatter_input,
+        reduce_scatter_event,
+        post_reduce_event,
+        all_reduce_input,
+        all_reduce_event,
+        None,
+    )
+
+
+def foreach_reduce_scatter_copy_in(
+    unsharded_grads: list[torch.Tensor],
+    reduce_scatter_input: torch.Tensor,
+    world_size: int,
+) -> None:
+    reduce_scatter_input = reduce_scatter_input.view(world_size, -1)
+    torch.ops.fsdp.chunk_cat(
+        unsharded_grads, dim=0, num_chunks=world_size, out=reduce_scatter_input
+    )
+
+
+def _get_all_gather_input_metadatas(
+    param_all_gather_inputs: list[list[torch.Tensor]],
+) -> tuple[list[list[torch.dtype]], list[list[int]], torch.dtype]:
+    param_all_gather_input_dtypes: list[list[torch.dtype]] = []
+    param_all_gather_input_numels: list[list[int]] = []
+    all_gather_dtype = param_all_gather_inputs[0][0].dtype
+    for all_gather_inputs in param_all_gather_inputs:
+        input_dtypes: list[torch.dtype] = []
+        input_numels: list[int] = []
+        for all_gather_input in all_gather_inputs:
+            if all_gather_input.dtype != all_gather_dtype:
+                all_gather_dtype = torch.uint8
+            input_dtypes.append(all_gather_input.dtype)
+            input_numels.append(all_gather_input.numel())
+        param_all_gather_input_dtypes.append(input_dtypes)
+        param_all_gather_input_numels.append(input_numels)
+    return (
+        param_all_gather_input_dtypes,
+        param_all_gather_input_numels,
+        all_gather_dtype,
+    )
+
+
+def _get_gradient_divide_factors(
+    reduce_scatter_group: Optional[dist.ProcessGroup],
+    all_reduce_group: Optional[dist.ProcessGroup],
+    reduce_dtype: torch.dtype,
+    device_type: str = "",
+    factor: Optional[float] = None,
+    force_sum_reduction_for_comms: bool = False,
+) -> tuple[
+    Optional[float],
+    Optional[float],
+    Union[dist.ReduceOp, dist.ReduceOp.RedOpType],
+    Union[dist.ReduceOp, dist.ReduceOp.RedOpType],
+]:
+    # MTIA appears to only support SUM reduction, hence we force it implicitly
+    if device_type == "mtia":
+        force_sum_reduction_for_comms = True
+
+    # For fp32/bf16, we do not need to worry about overflow/underflow, so we
+    # use NCCL's built-in division to avoid separate div kernels
+    overflow_risk = reduce_dtype not in (torch.float32, torch.bfloat16)
+    if reduce_scatter_group is not None:
+        data_parallel_size = reduce_scatter_group.size()
+    else:
+        data_parallel_size = 1
+
+    if all_reduce_group is not None:
+        data_parallel_size *= all_reduce_group.size()
+
+    if not overflow_risk and not force_sum_reduction_for_comms:
+        if factor is None:
+            # Warning: NCCL ReduceOp.AVG may produce incorrect results with
+            # world size 1.
+            if data_parallel_size == 1:
+                return None, None, ReduceOp.SUM, ReduceOp.SUM
+            return None, None, ReduceOp.AVG, ReduceOp.AVG
+        if reduce_scatter_group is not None and factor == reduce_scatter_group.size():
+            reduce_scatter_op = ReduceOp.AVG
+        else:
+            reduce_scatter_op = torch.distributed._make_nccl_premul_sum(1 / factor)
+        return None, None, reduce_scatter_op, ReduceOp.SUM
+
+    if factor is None:
+        factor = float(data_parallel_size)
+    pre_factor: Optional[float]
+    if overflow_risk:
+        # Since fp16 has smaller dynamic range than fp32/bf16, we want to avoid
+        # overflow/underflow. For N data parallel workers, each worker computes
+        # g_i, and they collectively reduce (g_1 + ... + g_N) / N. To avoid
+        # overflow/underflow, we divide by ~sqrt(N) before/after the reduction.
+        pre_factor = 1
+        while factor % pre_factor == 0 and factor / pre_factor > pre_factor:
+            pre_factor *= 2
+        post_factor = factor / pre_factor
+    else:
+        # Prefer post-multiplying as it operates on less data and is thus faster
+        pre_factor, post_factor = None, factor
+
+    return pre_factor, post_factor, ReduceOp.SUM, ReduceOp.SUM
+
+
+def _div_if_needed(tensor: torch.Tensor, div_factor: Optional[float]) -> None:
+    if div_factor is not None and div_factor != 1:
+        tensor.div_(div_factor)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_common.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_common.py
new file mode 100644
index 0000000000000000000000000000000000000000..85addad83b3b08cbed358f3eb31b2bf4f2a2c9e8
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_common.py
@@ -0,0 +1,181 @@
+# mypy: allow-untyped-defs
+import math
+import traceback
+from dataclasses import dataclass
+from enum import auto, Enum
+from typing import Any, Optional
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.distributed._composable.contract import _get_registry
+from torch.distributed.tensor import DeviceMesh, DTensor
+from torch.distributed.tensor._dtensor_spec import DTensorSpec
+
+
+_compiled_autograd_enabled: bool = False
+
+
+def detect_compiled_autograd():
+    if torch.compiler.is_compiling():
+        raise AssertionError(
+            "`detect_compiled_autograd()` is designed to be called in eager mode"
+        )
+    global _compiled_autograd_enabled
+    import torch._dynamo.compiled_autograd as ca
+
+    _compiled_autograd_enabled = (
+        ca.compiled_autograd_enabled
+        or ca.compiled_autograd_enabled_force_eager
+        or ca.in_compiled_autograd_region
+    )
+
+
+def compiled_autograd_enabled():
+    global _compiled_autograd_enabled
+    return _compiled_autograd_enabled
+
+
+@dataclass
+class DataParallelMeshInfo:
+    mesh: DeviceMesh
+    shard_mesh_dim: Optional[int] = None
+    replicate_mesh_dim: Optional[int] = None
+
+    def __post_init__(self):
+        if self.shard_mesh_dim is None and self.replicate_mesh_dim is None:
+            raise AssertionError(
+                "At least one of shard_mesh_dim and replicate_mesh_dim must not be None"
+            )
+
+
+@dataclass
+class FSDPMeshInfo(DataParallelMeshInfo):
+    def __post_init__(self):
+        super().__post_init__()
+        if self.shard_mesh_dim is None:
+            raise AssertionError("Expects non-None shard_mesh_dim")
+        self.shard_mesh_size: int = self.mesh.size(self.shard_mesh_dim)
+        self.shard_process_group = self.mesh.get_group(self.shard_mesh_dim)
+        self.shard_mesh_rank: int = self.shard_process_group.rank()
+
+
+@dataclass
+class DDPMeshInfo(DataParallelMeshInfo):
+    def __post_init__(self):
+        super().__post_init__()
+        if self.replicate_mesh_dim is None:
+            raise AssertionError("Expects non-None replicate_mesh_dim")
+        self.replicate_mesh_size: int = self.mesh.size(self.replicate_mesh_dim)
+        self.replicate_process_group = self.mesh.get_group(self.replicate_mesh_dim)
+        self.replicate_mesh_rank: int = self.replicate_process_group.rank()
+
+
+@dataclass
+class HSDPMeshInfo(FSDPMeshInfo, DDPMeshInfo):
+    def __post_init__(self):  # pylint:disable=useless-parent-delegation
+        # Calls `FSDPMeshInfo` -> `DDPMeshInfo` -> `DataParallelMeshInfo`
+        super().__post_init__()
+
+
+class TrainingState(Enum):
+    """Describes the training state of one FSDP state / parameter group."""
+
+    # Transition to forward starting pre-forward until post-forward
+    FORWARD = auto()
+    # Transition to pre-backward when unsharding in backward
+    PRE_BACKWARD = auto()
+    # Transition to post-backward when resharding and reducing gradients
+    POST_BACKWARD = auto()
+    # Idle before/after forward or before pre-backward/after post-backward
+    IDLE = auto()
+
+
+def _raise_assert_with_print(*args: Any, **kwargs: Any):
+    print(f"[Rank {dist.get_rank()}] ", end="")
+    print(*args, **kwargs)
+    traceback.print_stack()
+    raise AssertionError(*args, **kwargs)
+
+
+def _is_composable_with_fsdp(module: nn.Module) -> bool:
+    registry = _get_registry(module)
+    if registry is None:
+        return True
+    # Registry keys by function name
+    return "replicate" not in registry
+
+
+def _get_dim0_padded_size(tensor_size: torch.Size, dim0_factor: int) -> torch.Size:
+    padded_dim0 = math.ceil(tensor_size[0] / dim0_factor) * dim0_factor
+    return torch.Size([padded_dim0]) + tensor_size[1:]
+
+
+def _chunk_with_empty(
+    tensor: torch.Tensor, num_chunks: int, dim: int
+) -> list[torch.Tensor]:
+    chunks = list(torch.chunk(tensor, num_chunks, dim=dim))
+    while len(chunks) < num_chunks:
+        chunks.append(chunks[0].new_empty(0))
+    return chunks
+
+
+def _get_dim_chunked_size(
+    chunk: torch.Tensor, unchunked_size: torch.Size, dim: int
+) -> torch.Size:
+    if chunk.numel() > 0:
+        return chunk.size()
+    # For 0 numel, we need to preserve nonzero-sized dims for DTensor APIs
+    return unchunked_size[:dim] + torch.Size([0]) + unchunked_size[dim + 1 :]
+
+
+def _from_local_no_grad(
+    local_tensor: torch.Tensor,
+    sharding_spec: DTensorSpec,
+) -> DTensor:
+    """
+    This method is similar to ``DTensor.from_local()`` except that in eager mode
+    it avoids some CPU overhead by avoiding default args and not being differentiable.
+    """
+
+    if not compiled_autograd_enabled():
+        # pyrefly: ignore [bad-argument-type]
+        return DTensor(
+            # Use the local tensor directly instead of constructing a new tensor
+            # variable, e.g. with `view_as()`, since this is not differentiable
+            # pyrefly: ignore [bad-argument-count]
+            local_tensor,
+            sharding_spec,
+            # pyrefly: ignore [unexpected-keyword]
+            requires_grad=local_tensor.requires_grad,
+        )
+    else:
+        return DTensor.from_local(
+            local_tensor,
+            sharding_spec.mesh,
+            sharding_spec.placements,
+            shape=sharding_spec.shape,
+            stride=sharding_spec.stride,
+        )
+
+
+def _to_dtype_if_needed(
+    tensor: torch.Tensor, dtype: Optional[torch.dtype]
+) -> torch.Tensor:
+    if dtype is not None and tensor.dtype != dtype:
+        return tensor.to(dtype)
+    return tensor
+
+
+def _cast_fp_tensor(dtype: torch.dtype, x: torch.Tensor) -> torch.Tensor:
+    if (
+        not isinstance(x, torch.Tensor)
+        or not torch.is_floating_point(x)
+        or x.dtype == dtype
+    ):
+        return x
+    return x.to(dtype)
+
+
+def is_bw() -> bool:
+    return torch._C._current_graph_task_id() != -1
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_init.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_init.py
new file mode 100644
index 0000000000000000000000000000000000000000..01d196795c3d8f9270138f757b3e7f3de9e10f11
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_init.py
@@ -0,0 +1,243 @@
+import itertools
+import logging
+from typing import Optional, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch._logging import warning_once
+from torch.distributed.device_mesh import _get_device_handle
+from torch.distributed.tensor import DeviceMesh, DTensor, init_device_mesh
+from torch.utils._python_dispatch import is_traceable_wrapper_subclass
+
+from ._fsdp_common import _is_composable_with_fsdp, FSDPMeshInfo, HSDPMeshInfo
+from ._fsdp_state import _get_module_fsdp_state
+
+
+logger = logging.getLogger("torch.distributed.fsdp.fully_shard")
+
+
+def _get_post_forward_mesh_info(
+    reshard_after_forward: Union[bool, int], mesh_info: FSDPMeshInfo
+) -> Optional[FSDPMeshInfo]:
+    shard_mesh_size = mesh_info.shard_mesh_size
+    if not isinstance(reshard_after_forward, (bool, int)):
+        raise ValueError(
+            "reshard_after_forward should be a bool or an int representing the "
+            f"group size to reshard to, not {reshard_after_forward}"
+        )
+    # NOTE: `isinstance(False, int)` returns `True`.
+    if not isinstance(reshard_after_forward, bool) and isinstance(
+        reshard_after_forward, int
+    ):
+        if (
+            reshard_after_forward < 1
+            or reshard_after_forward > shard_mesh_size
+            or shard_mesh_size % reshard_after_forward != 0
+        ):
+            raise ValueError(
+                "If passing reshard_after_forward as an int, it should be a "
+                f"factor of {shard_mesh_size}, not {reshard_after_forward}"
+            )
+        elif reshard_after_forward == 1:
+            msg = (
+                "reshard_after_forward=1 (int) means resharding parameters to world size 1, "
+                "instead of reshard_after_forward=True (bool)"
+            )
+            warning_once(logger, msg, stacklevel=2)
+            reshard_after_forward = False
+        elif reshard_after_forward == shard_mesh_size:
+            reshard_after_forward = True
+    post_forward_mesh_info = None
+    if reshard_after_forward is True:
+        post_forward_mesh_info = mesh_info
+    elif reshard_after_forward is not False:  # int case
+        # For HSDP, we can flatten the two replicate dims into the 0th dim
+        post_forward_mesh_tensor = mesh_info.mesh.mesh.view(-1, reshard_after_forward)
+        post_forward_mesh = DeviceMesh(
+            mesh_info.mesh.device_type, post_forward_mesh_tensor
+        )
+        post_forward_mesh_info = HSDPMeshInfo(
+            post_forward_mesh, shard_mesh_dim=1, replicate_mesh_dim=0
+        )
+    return post_forward_mesh_info
+
+
+def _init_default_fully_shard_mesh() -> DeviceMesh:
+    """Default to global CUDA mesh if possible else global CPU mesh."""
+    if not dist.distributed_c10d.is_initialized():
+        dist.distributed_c10d.init_process_group()
+    default_pg = dist.distributed_c10d._get_default_group()
+    device = torch._C._get_accelerator()
+    mesh = init_device_mesh(device.type, mesh_shape=(default_pg.size(),))
+    return mesh
+
+
+def _get_device_from_mesh(mesh: DeviceMesh) -> torch.device:
+    if mesh.device_type == "cpu":
+        return torch.device("cpu")
+    device_handle = _get_device_handle(mesh.device_type)
+    return torch.device(mesh.device_type, device_handle.current_device())
+
+
+def _ignore_module(
+    module: nn.Module,
+    ignored_params: set[nn.Parameter],
+    ignore_decision: dict[nn.Module, bool],
+) -> bool:
+    """
+    Decide if it is safe to ignore a module for applying fully_shard.
+    """
+    if module in ignore_decision:
+        return ignore_decision[module]
+
+    if len(list(module.buffers(recurse=False))) > 0:
+        # Cannot ignore a module with any buffer
+        ignore_decision[module] = False
+        return False
+
+    for _, param in module.named_parameters(recurse=False):
+        if param not in ignored_params:
+            # at least one param is not ignored. So this module shouldn't be.
+            ignore_decision[module] = False
+            return False
+
+    # Need to consider descendants of module
+    for child in list(module.children()):
+        ignore_child = _ignore_module(child, ignored_params, ignore_decision)
+        if not ignore_child:
+            # Cannot ignore module if one of its children is not ignored
+            ignore_decision[module] = False
+            return False
+
+    # Safe to ignore module
+    ignore_decision[module] = True
+    return True
+
+
+def _adjust_managed_modules(
+    modules: list[nn.Module], ignored_params: set[nn.Parameter]
+) -> list[nn.Module]:
+    """
+    Adjust the given list of managed modules by removing those with all parameters ignored.
+    """
+    ignore_decision: dict[nn.Module, bool] = {}
+    new_modules = []
+    for module in modules:
+        ignored = _ignore_module(module, ignored_params, ignore_decision)
+        if not ignored:
+            new_modules.append(module)
+    return new_modules
+
+
+def _get_managed_modules(
+    root_modules: tuple[nn.Module, ...],
+    ignored_params: Optional[set[nn.Parameter]] = None,
+) -> list[nn.Module]:
+    modules: list[nn.Module] = []
+    root_modules_set = set(root_modules)
+    # Track visisted modules to avoid visiting shared modules multiple times
+    visited_modules: set[nn.Module] = set()
+
+    def dfs(module: nn.Module) -> None:
+        """
+        Runs a DFS to collect managed modules, not recursing into modules with
+        a non-composable API or ``fully_shard`` already applied.
+        """
+        if not _is_composable_with_fsdp(module):
+            return
+        elif (
+            module not in root_modules_set
+            and _get_module_fsdp_state(module) is not None
+        ):
+            return  # nested `fully_shard` module
+        visited_modules.add(module)
+        for submodule in module.children():
+            if submodule not in visited_modules:
+                dfs(submodule)
+        modules.append(module)
+
+    for root_module in root_modules:
+        dfs(root_module)
+
+    if ignored_params is None:
+        return modules
+
+    adjusted_modules = _adjust_managed_modules(modules, ignored_params)
+    return adjusted_modules
+
+
+def _verify_managed_param(name: str, param: nn.Parameter) -> None:
+    """
+    Verify if the parameter is accepted by fully_shard. The only restriction now
+    is that the parameter cannot be a scalar tensor (param.numel == 0) since we
+    need at least one dim to shard.
+    """
+    if len(param.shape) == 0:
+        raise ValueError(
+            "fully_shard doesn't support scalar parameters. "
+            f"Change {name} to a 1D tensor with numel equal to 1."
+        )
+
+
+def _get_managed_states(
+    modules: list[nn.Module], ignored_params: Optional[set[nn.Parameter]] = None
+) -> tuple[list[nn.Parameter], list[torch.Tensor]]:
+    params: list[nn.Parameter] = []
+    buffers: list[torch.Tensor] = []
+    # Track visited parameters/buffers to avoid visiting shared parameters and
+    # buffers multiple times
+    visited_params: set[nn.Parameter] = set()
+    visited_buffers: set[torch.Tensor] = set()
+    if ignored_params is None:
+        ignored_params = set()
+
+    for module in modules:
+        for name, param in module.named_parameters(recurse=False):
+            if param in ignored_params:
+                # do not include an ignored parameters
+                continue
+            if param not in visited_params:
+                _verify_managed_param(name, param)
+                params.append(param)
+                visited_params.add(param)
+        for buffer in module.buffers(recurse=False):
+            if buffer not in visited_buffers:
+                buffers.append(buffer)
+                visited_buffers.add(buffer)
+    return params, buffers
+
+
+def _move_states_to_device(
+    params: list[nn.Parameter],
+    buffers: list[torch.Tensor],
+    device: torch.device,
+) -> None:
+    """
+    We have FSDP move states to device for simpler and faster initialization
+    since FSDP almost always uses CUDA for training. We move parameters/buffers
+    rather than modules since modules to support ignoring parameters/buffers in
+    the future.
+    """
+    # Follow the logic in `nn.Module._apply`
+    # pyrefly: ignore [bad-argument-type]
+    for tensor in itertools.chain(params, buffers):
+        if tensor.device == device or tensor.device.type == "meta":
+            # Keep meta-device tensors on meta device for deferred init
+            continue
+        if isinstance(tensor, DTensor):
+            if (dtensor_mesh_type := tensor.device_mesh.device_type) != device.type:
+                raise ValueError(
+                    "Requires DTensor to have mesh of the same type as the FSDP mesh "
+                    f"but got {dtensor_mesh_type} for DTensor and {device.type} for FSDP"
+                )
+            raise AssertionError(
+                f"Expects DTensor to be moved to {dtensor_mesh_type} but got {tensor.device}"
+            )
+        tensor_ = tensor
+        if is_traceable_wrapper_subclass(tensor_):
+            with torch.no_grad():  # avoid autograd increasing C++ refcount by 1
+                tensor_on_device = nn.Parameter(tensor.to(device))
+            torch.utils.swap_tensors(tensor, tensor_on_device)
+        else:
+            tensor.data = tensor.to(device)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param.py
new file mode 100644
index 0000000000000000000000000000000000000000..476fbd94928947bc95cf13eab10b85d76e554164
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param.py
@@ -0,0 +1,966 @@
+# mypy: allow-untyped-defs
+import inspect
+import itertools
+from collections.abc import Callable, Sequence
+from dataclasses import dataclass, field
+from enum import auto, Enum
+from typing import Any, cast, Optional
+
+import torch
+import torch.nn as nn
+from torch._prims_common import make_contiguous_strides_for
+from torch.distributed._functional_collectives import AsyncCollectiveTensor
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.fsdp._fully_shard._fsdp_common import DDPMeshInfo
+from torch.distributed.tensor import DTensor, Replicate, Shard
+from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
+from torch.distributed.tensor.placement_types import _StridedShard, Placement
+
+from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy
+from ._fsdp_common import (
+    _chunk_with_empty,
+    _from_local_no_grad,
+    _get_dim_chunked_size,
+    _raise_assert_with_print,
+    _to_dtype_if_needed,
+    compiled_autograd_enabled,
+    FSDPMeshInfo,
+    HSDPMeshInfo,
+)
+
+
+"""
+[Note: FSDP tensors]
+FSDP considers the following tensors:
+- Original parameter: parameter passed to :class:`FSDPParam`, i.e. the one
+  on the module when applying FSDP
+- Sharded parameter: sharding the original parameter on dim-0 (or a
+  user-specified dim) as a DTensor over the main mesh
+- All-gather inputs: the ``torch.Tensor`` or ``Tensor`` s passed to all-gather,
+  derived from the sharded parameter
+- All-gather output: the ``torch.Tensor`` or ``Tensor`` s resulting from
+  all-gathering the all-gather inputs
+- Unsharded parameter: parameter used for forward/backward computation, derived
+  from the all-gather output; autograd leaf
+
+We define these tensors to describe the general framework that can accommodate
+extensions, where:
+- all-gather-inputs = pre-all-gather-transform(sharded-parameter)
+- unsharded-parameter = post-all-gather-transform(all-gather-outputs)
+
+For the default ``torch.Tensor`` case, there is only one all-gather input, and
+it shares the same underlying tensor data as the sharded parameter, meaning
+that they can be thought of as the same tensors. The same applies for the
+all-gather output and unsharded parameter. For non-``torch.Tensor`` extensions,
+these equivalences may no longer hold due to the pre/post-all-gather
+transforms, and some may have multiple all-gather inputs/outputs (e.g.
+quantized data and scales).
+
+[Note: FSDP and autograd]
+FSDP dynamically frees and allocates the unsharded parameter. Since autograd
+can pack a reference to it or a view to save for backward, we use storage
+resizing to implement the freeing/allocation since that preserves the aliasing.
+This implies that we construct the unsharded parameter object once and write to
+it in-place thereafter. For the default ``torch.Tensor` original parameter
+case, the all-gather output and unsharded parameter share the same
+data, so we use storage resizing on the all-gather output.
+"""
+
+lib = torch.library.Library("fsdp", "FRAGMENT")  # noqa: TOR901
+
+lib.define("copy_(Tensor(a!) tensor, Tensor data) -> ()")
+
+
+@torch.library.impl(lib, "copy_", "Meta")
+@torch.library.impl(lib, "copy_", "CUDA")
+@torch.library.impl(lib, "copy_", "XPU")
+@torch.library.impl(lib, "copy_", "HPU")
+@torch.library.impl(lib, "copy_", "CPU")
+@torch.library.impl(lib, "copy_", "MTIA")
+def copy_(tensor, data):
+    tensor.copy_(data)
+
+
+"""
+[Note: Avoiding functionalization for fsdp.copy_ and inductor.resize_storage_bytes_]
+
+Currently we don't functionalize `fsdp.copy_` op or `inductor.resize_storage_bytes_` op
+(i.e. they show up as a mutation op in the middle of the AOT joint graph).
+
+Reason:
+Traceable FSDP2 compiled autograd BWD graph have the following traits:
+(1) Two inputs of the graph were aliased to each other (one from hook closed-over tensors, one from FWD saved tensors).
+(2) One of them is mutated (copy_ and resize_ to handle the all-gathered param).
+(3) They are both subclasses.
+The combination of these traits is not supported by AOTAutograd (it's difficult to reason about subclass aliasing).
+So this doesn't work at all for Traceable FSDP2.
+
+The compromise we use is to avoid functionalization for the FSDP2 copy_ and resize_ ops.
+This avoids the problem above, because from AOTAutograd point-of-view there are no mutations
+that functionalization needs to handle. (Although we need to be careful not to DCE those mutable ops.)
+
+We can avoid this functionalization because:
+(1) The nn.Parameter is never used before its .copy_() is called in eager code (i.e. no alias of it is created),
+so it's safe to call .copy_() in the middle of the graph to update its content and start using the nn.Parameter downstream.
+(2) We always re-allocate the buffer for nn.Parameter to store the AllGather output and to be used in downstream user ops.
+So calling resize-to-0 in the middle of the graph to free nn.Parameter memory after use should always be okay
+(since we always allocate anew next time we need it, we strictly don't need to keep the old tensor storage around anymore).
+
+Q: Wouldn't the extra resize_ and copy_ ops hurt both memory usage and performance?
+A: Yes it would. As an optimization, we have an Inductor post-grad FX pass to remove those resize_ and copy_ ops
+for unsharded params that have this pattern: resize_(full) -> copy_ -> resize_(0).
+
+TODO:
+Now that we are maintaining the invariant of "no aliased + mutated graph inputs" in both the forward and backward,
+it is now more feasible to functionalize all of the mutable FSDP ops. Some of the pros and cons are:
+
+Cons (of functionalizing those ops):
+(1) By not functionalizing them as we are today, we are making it more likely that they will run at the "correct" time
+in the generated code. If we start to functionalize them, we will need to make sure that Inductor reinplaces them
+in a way where it properly moves the mutations back to exactly where they should have run, or we risk suffering worse
+peak memory than eager. (We probably already need to do something similar in Inductor's reinplacing for copy_:
+https://github.com/pytorch/pytorch/issues/135305#issuecomment-2334888089)
+
+Pros (of functionalizing):
+(1) Better safety, we don't need to worry about the graph passes in inductor/partitioning handling input mutations
+mid-graph quite as much (to be fair we've already done some amount of auditing, but we might have to do some more).
+(2) Better perf: each mutation midway through the graph prevents Inductor from pattern matching across it.
+But maybe there are few enough mutations induced by FSDP for this to matter.
+"""
+
+
+@torch.library.impl(lib, "copy_", "Functionalize")
+def copy__functionalize(tensor, data):
+    torch._sync(tensor)
+    torch._sync(data)
+    tensor_inner = torch._from_functional_tensor(tensor)
+    data_inner = torch._from_functional_tensor(data)
+    with torch._C._ExcludeDispatchKeyGuard(
+        torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
+    ):
+        torch.ops.fsdp.copy_.default(tensor_inner, data_inner)
+
+
+torch.fx.node.has_side_effect(torch.ops.fsdp.copy_.default)
+
+
+class ShardedState(Enum):
+    """
+    - ``SHARDED``: The sharded parameter is registered to the module. It is the
+      only contributor to parameter memory.
+    - ``SHARDED_POST_FORWARD``: The unsharded parameter is resharded to a
+      smaller world size. Since this data should not be used for computation,
+      we do not register it to the module. Users should reshard the module
+      before any in-place modifications. Both it and the sharded parameter
+      contribute to parameter memory.
+    - ``UNSHARDED``: The unsharded parameter is registered to the module. Both
+      it and the sharded parameter contribute to parameter memory.
+    """
+
+    SHARDED = auto()
+    SHARDED_POST_FORWARD = auto()
+    UNSHARDED = auto()
+
+
+@dataclass
+class ParamModuleInfo:
+    """
+    For a parameter, this stores the module and the parameter name to be able
+    to do a parameter swap via ``setattr(module, param_name, ...)`` or to get
+    the parameter via ``getattr(module, param_name)``. We additionally save
+    shared modules and shared parameter names to update them accordingly.
+    """
+
+    # Parameter names are unprefixed, e.g. "weight", not "lin.weight"
+    module: nn.Module
+    param_name: str
+    shared_modules: list[nn.Module] = field(default_factory=list)
+    shared_param_names: list[str] = field(default_factory=list)
+
+
+@dataclass
+class ExtensionsData:
+    # User-defined metadata passed from pre to post-all-gather
+    all_gather_metadata: Optional[Any] = None
+    # Save the all-gather input sizes to unflatten the all-gather outputs to ND
+    all_gather_input_sizes: Sequence[torch.Size] = ()  # ND
+
+    def clear(self):
+        self.all_gather_metadata = None
+        self.all_gather_input_sizes = ()
+
+
+class FSDPParam:
+    """
+    This class manages a parameter with FSDP or FSDP variants applied,
+    implementing dim-0 per-parameter sharding.
+    """
+
+    orig_dtype: torch.dtype
+    param_dtype: Optional[torch.dtype]
+    reduce_dtype: Optional[torch.dtype]
+    _orig_size: torch.Size  # ND
+    sharded_size: torch.Size  # ND
+    contiguous_sharded_stride: tuple[int, ...]
+    padded_sharded_param_size: torch.Size  # ND
+    sharded_post_forward_size: torch.Size  # ND
+    contiguous_sharded_post_forward_stride: tuple[int, ...]
+    _sharded_param_data: torch.Tensor  # 1D
+    sharded_param: nn.Parameter  # ND
+    _sharded_post_forward_param_data: Optional[torch.Tensor]  # 1D
+    _sharded_post_forward_param: Optional[nn.Parameter]  # ND
+    _unsharded_param: nn.Parameter  # ND
+    unsharded_accumulated_grad: Optional[torch.Tensor]  # ND
+    _sharding_spec: DTensorSpec
+    # DTensor attributes (only defined for DTensor `param`):
+    _tp_spec: DTensorSpec
+    all_gather_outputs: list[torch.Tensor]  # 1D
+    # All-gather extension attributes
+    _extensions_data: ExtensionsData
+    _unsharded_inner_tensors: list[torch.Tensor]
+
+    def __init__(
+        self,
+        param: nn.Parameter,
+        module_info: ParamModuleInfo,
+        mesh_info: FSDPMeshInfo,
+        post_forward_mesh_info: Optional[FSDPMeshInfo],
+        device: torch.device,
+        shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]],
+        mp_policy: MixedPrecisionPolicy,
+        offload_policy: OffloadPolicy,
+    ):
+        self._module_info: ParamModuleInfo = module_info
+        self.mesh_info = mesh_info
+        self.post_forward_mesh_info = post_forward_mesh_info
+        # pyrefly: ignore [read-only]
+        self.device = device
+        self.mp_policy = mp_policy
+        self.offload_to_cpu: bool = isinstance(offload_policy, CPUOffloadPolicy)
+        self.pin_memory = (
+            self.offload_to_cpu and cast(CPUOffloadPolicy, offload_policy).pin_memory
+        )
+        self.grad_offload_event: Optional[torch.Event] = None
+        self._init_sharded_param(param, device, shard_placement_fn)
+        if self.post_forward_mesh_info:
+            self._init_sharded_post_forward_param_metadata(param)
+        self._init_extensions()
+        self.all_gather_outputs: list[torch.Tensor] = []
+        self.unsharded_accumulated_grad = None
+        self._param_fqn: Optional[str] = None  # prefixed from root module
+        # TODO: Remove this padding logic once DTensor pads the local tensor:
+        # https://github.com/pytorch/pytorch/issues/113045
+        self._post_load_hook_handle = (
+            module_info.module.register_load_state_dict_post_hook(
+                lambda *args, **kwargs: self.reset_sharded_param()
+            )
+        )
+
+    @torch.no_grad()
+    def _init_sharded_param(
+        self,
+        param: nn.Parameter,
+        device: torch.device,
+        shard_placement_fn: Optional[Callable],
+    ):
+        if param.device != device and param.device.type != "meta":
+            raise AssertionError(
+                f"Expects the parameter to already be moved to device {device} but got {param.device}"
+            )
+        if not param.is_contiguous():
+            raise NotImplementedError(
+                f"FSDP does not support non-contiguous parameters yet: {param.shape=} {param.stride()=}"
+            )
+        fsdp_placement = shard_placement_fn(param) if shard_placement_fn else None
+        if fsdp_placement is None:
+            fsdp_placement = Shard(0)
+        elif fsdp_placement.dim < 0:
+            fsdp_placement = Shard(fsdp_placement.dim + param.ndim)
+        if not isinstance(fsdp_placement, Shard):
+            raise AssertionError(
+                f"Expected Shard, got {type(fsdp_placement)}: {fsdp_placement}"
+            )
+        self.fsdp_placement = fsdp_placement
+        shard_dim = fsdp_placement.dim
+        # TODO: Replace the sharded DTensor parameter construction logic with
+        # `distribute_tensor` after https://github.com/pytorch/pytorch/issues/116101
+        # TODO: Simplify the following sharded parameter padding logic after
+        # https://github.com/pytorch/pytorch/issues/113045
+        self.is_dtensor = isinstance(param, DTensor)
+        if self.is_dtensor:
+            self._tp_spec = cast(DTensor, param)._spec
+            dp_mesh, tp_mesh = (self.mesh_info.mesh, self._tp_spec.mesh)
+            if dp_mesh is None or tp_mesh is None:
+                raise AssertionError(
+                    "FSDP requires the DP and model parallel TP/EP mesh to be not None but got: \n"
+                    f"DP's mesh: {dp_mesh}\nTP/EP's mesh: {tp_mesh}"
+                )
+            self._spmd_mesh = DeviceMesh._concatenate([dp_mesh, tp_mesh])
+            if len(self._tp_spec.placements) > 2:
+                raise NotImplementedError(
+                    f"FSDP only supports 1D TP/EP or 2D EP+TP, not {self._tp_spec.placements}"
+                )
+            split_factor = self._tp_spec.num_shards_map[shard_dim]
+            if not (2 <= self._spmd_mesh.ndim <= 4):
+                raise AssertionError(
+                    "_spmd_mesh.ndim can only be 2 (FSDP+TP/EP), 3 (FSDP+EP+TP, HSDP+TP/EP), "
+                    f"or 4 (HSDP+EP+TP) but got {self._spmd_mesh.ndim}."
+                )
+            self._spmd_placements: tuple[Placement, ...]
+            if isinstance(self.mesh_info, FSDPMeshInfo):  # FSDP or HSDP
+                dp_shard_tp_placement = (
+                    (
+                        _StridedShard(shard_dim, split_factor=split_factor)
+                        if split_factor > 1
+                        else fsdp_placement
+                    ),
+                    *self._tp_spec.placements,
+                )
+            else:  # DDP
+                dp_shard_tp_placement = (
+                    (Replicate()),
+                    *self._tp_spec.placements,
+                )
+            if isinstance(self.mesh_info, HSDPMeshInfo):  # HSDP
+                if self.mesh_info.replicate_mesh_dim != 0:
+                    raise AssertionError(
+                        f"Expected replicate_mesh_dim to be 0, got {self.mesh_info.replicate_mesh_dim}"
+                    )
+                self._spmd_placements = (Replicate(),) + dp_shard_tp_placement
+            else:  # FSDP or DDP
+                self._spmd_placements = dp_shard_tp_placement
+
+            self._sharding_spec = DTensorSpec(
+                self._spmd_mesh,
+                self._spmd_placements,
+                tensor_meta=self._tp_spec.tensor_meta,
+            )
+            param_data = cast(DTensor, param)._local_tensor
+        else:
+            self._spmd_mesh = self.mesh_info.mesh
+            if isinstance(self.mesh_info, HSDPMeshInfo):  # HSDP
+                self._spmd_placements = (Replicate(), fsdp_placement)
+            elif isinstance(self.mesh_info, FSDPMeshInfo):  # FSDP
+                self._spmd_placements = (fsdp_placement,)
+            elif isinstance(self.mesh_info, DDPMeshInfo):  # DDP
+                self._spmd_placements = (Replicate(),)
+            self._sharding_spec = DTensorSpec(
+                self._spmd_mesh,
+                self._spmd_placements,
+                tensor_meta=TensorMeta(param.size(), param.stride(), param.dtype),
+            )
+            param_data = param
+        if not param_data.is_contiguous():
+            raise AssertionError(
+                f"Expected contiguous tensor, got {param_data.shape=} {param_data.stride()=}"
+            )
+        shard_dim = fsdp_placement.dim
+        if shard_dim >= param_data.ndim:
+            raise AssertionError(
+                f"Shard dim {shard_dim} is invalid for {param_data.ndim}D tensor: {param.shape}"
+            )
+        self._orig_size = param_data.size()
+        self._contiguous_orig_stride = make_contiguous_strides_for(self._orig_size)
+        if isinstance(self.mesh_info, FSDPMeshInfo):  # FSDP or HSDP
+            shard_rank = self.mesh_info.shard_mesh_rank
+            shard_world_size = self.mesh_info.shard_mesh_size
+        else:  # DDP
+            shard_rank = 0
+            shard_world_size = 1
+
+        if shard_dim > 0 and param_data.size(shard_dim) % shard_world_size != 0:
+            # If sharding on nonzero dim, require even sharding for now because
+            # the uneven sharding (1) requires extra copies before/after FSDP
+            # collectives and (2) introduces extra complexity to handle padding
+            # and unpadding
+            raise NotImplementedError(
+                f"FSDP does not support uneven sharding on dim {shard_dim}: "
+                f"{param_data.size()} (world size: {shard_world_size})"
+            )
+        chunks = _chunk_with_empty(param_data, shard_world_size, dim=shard_dim)
+        sharded_param = chunks[shard_rank]
+        self.sharded_size = _get_dim_chunked_size(
+            sharded_param, param_data.size(), dim=shard_dim
+        )
+        self.contiguous_sharded_stride = make_contiguous_strides_for(self.sharded_size)
+        padded_sharded_size = chunks[0].size()  # 0th always padded
+        self.padded_sharded_param_size = padded_sharded_size
+        # Pre-pad the sharded parameter to avoid padding before all-gather
+        padded_sharded_param = param_data.new_zeros(padded_sharded_size)
+        if sharded_param.numel() > 0:
+            padded_sharded_param.narrow(
+                dim=shard_dim, start=0, length=sharded_param.size(shard_dim)
+            ).copy_(sharded_param)
+        if self.offload_to_cpu and not padded_sharded_param.is_meta:
+            padded_sharded_param = padded_sharded_param.cpu()
+            if self.pin_memory:
+                padded_sharded_param = padded_sharded_param.pin_memory()
+        self._sharded_param_data = padded_sharded_param.view(-1)
+        length = sharded_param.size(shard_dim) if sharded_param.numel() > 0 else 0
+        sharded_param = padded_sharded_param.narrow(
+            dim=shard_dim, start=0, length=length
+        )
+        if not sharded_param.is_contiguous():
+            raise AssertionError(
+                f"Expected contiguous tensor with {self.fsdp_placement=}"
+            )
+        self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param))
+        self.sharded_param.requires_grad_(param.requires_grad)
+        # Let `param_data` be freed normally when its ref count reaches 0 when
+        # the `fully_shard` call returns to allow provided parameters to alias
+        self._setattr_on_modules(self.sharded_param)
+        self.sharded_state = ShardedState.SHARDED
+
+    def _init_sharded_post_forward_param_metadata(self, param: torch.Tensor) -> None:
+        mesh_info = self.post_forward_mesh_info
+        if mesh_info is None:
+            raise AssertionError("Expected post_forward_mesh_info to not be None")
+        param_data = param._local_tensor if isinstance(param, DTensor) else param
+        if isinstance(mesh_info, FSDPMeshInfo):
+            chunks = _chunk_with_empty(param_data, mesh_info.shard_mesh_size, dim=0)
+            self.sharded_post_forward_size = _get_dim_chunked_size(
+                chunks[mesh_info.shard_mesh_rank],
+                param_data.size(),
+                dim=self.fsdp_placement.dim,
+            )
+        else:  # DDP
+            chunks = _chunk_with_empty(param_data, 1, dim=0)
+            self.sharded_post_forward_size = _get_dim_chunked_size(
+                chunks[0],
+                param_data.size(),
+                dim=self.fsdp_placement.dim,
+            )
+        self.contiguous_sharded_post_forward_stride = make_contiguous_strides_for(
+            self.sharded_post_forward_size
+        )
+
+    def init_dtype_attrs(self, mp_policy: MixedPrecisionPolicy):
+        param_dtype, reduce_dtype = (mp_policy.param_dtype, mp_policy.reduce_dtype)
+        self.orig_dtype = self.sharded_param.dtype
+        # Clamp `reduce_dtype` to `None` if no casting is required: since
+        # gradients are computed in `param_dtype`, if `reduce_dtype` matches,
+        # then we do not need extra casting
+        if reduce_dtype == param_dtype:
+            reduce_dtype = None
+        # Clamp `param_dtype` to `None` if no casting is required
+        if param_dtype == self.orig_dtype:
+            param_dtype = None
+        self.param_dtype = param_dtype
+        self.reduce_dtype = reduce_dtype
+        # None indicates that the mixed precision is not enabled
+
+    def _init_extensions(self) -> None:
+        inner_tensor = self._sharded_local_tensor
+        has_fsdp_pre_all_gather = hasattr(inner_tensor, "fsdp_pre_all_gather")
+        has_fsdp_post_all_gather = hasattr(inner_tensor, "fsdp_post_all_gather")
+        if has_fsdp_pre_all_gather != has_fsdp_post_all_gather:
+            raise AssertionError(
+                "Both fsdp_pre_all_gather and fsdp_post_all_gather should be defined "
+                f"if using all-gather extensions: {inner_tensor}"
+            )
+        if has_fsdp_pre_all_gather:
+            self._extensions_data = ExtensionsData()
+        self._unsharded_inner_tensors: list[torch.Tensor] = []
+
+    def init_all_gather_outputs(
+        self,
+        all_gather_input_numels: list[int],
+        all_gather_input_dtypes: list[torch.dtype],
+        world_size: int,
+        device: torch.device,
+        force_recreate: bool = False,
+    ):
+        if not force_recreate and len(self.all_gather_outputs) > 0:
+            return  # already initialized
+        self.all_gather_outputs = [
+            torch.empty(torch.Size([numel * world_size]), dtype=dtype, device=device)
+            for numel, dtype in zip(all_gather_input_numels, all_gather_input_dtypes)
+        ]
+
+    def init_unsharded_param(self):
+        """
+        [Note: Invariants for torch.compile Traceable FSDP2]
+        1. Under compile, we always re-populate the content of `self._unsharded_param`
+           per AllGather using the slow path.
+        2. Under compile, we always recreate `self.all_gather_outputs` per AllGather.
+           This is to ensure the buffer creation is internal to the graph and
+           avoid `self.all_gather_outputs` being captured as a graph input.
+        3. Under compile, at the end of `free_unsharded_param()`, we always clean up
+           `self.all_gather_outputs` and `self._unsharded_inner_tensors`,
+           to avoid them being captured as graph output.
+
+        With these invariants, only these tensors will be inputs to the graph:
+        - Sharded parameters
+        - Placeholders for the `self._unsharded_param` nn.Parameter
+        """
+        if not compiled_autograd_enabled() and hasattr(
+            self, "_unsharded_param"
+        ):  # after the 1st all-gather
+            inner_tensor = self._sharded_local_tensor
+            if not hasattr(inner_tensor, "fsdp_post_all_gather"):
+                return  # already initialized
+            for tensor in self._unsharded_inner_tensors:
+                alloc_storage(tensor)
+            all_gather_outputs = self._unflatten_all_gather_outputs()
+            inner_tensor.fsdp_post_all_gather(
+                all_gather_outputs,
+                self._extensions_data.all_gather_metadata,
+                self.param_dtype or self.orig_dtype,
+                out=self._unsharded_param,
+            )
+            self._extensions_data.clear()
+            return
+        inner_tensor = self._sharded_local_tensor
+        if not compiled_autograd_enabled() and hasattr(
+            inner_tensor, "fsdp_post_all_gather"
+        ):
+            all_gather_outputs = self._unflatten_all_gather_outputs()
+            (
+                unsharded_tensor,
+                self._unsharded_inner_tensors,
+            ) = inner_tensor.fsdp_post_all_gather(
+                all_gather_outputs,
+                self._extensions_data.all_gather_metadata,
+                self.param_dtype or self.orig_dtype,
+            )
+            self._extensions_data.clear()
+        else:
+            # For the default path (no post-all-gather), the all-gather output
+            # gives the unsharded parameter data directly
+            if len(self.all_gather_outputs) != 1:
+                raise AssertionError(
+                    f"Expected 1 all_gather_output, got {len(self.all_gather_outputs)}"
+                )
+            unsharded_tensor = self.all_gather_outputs[0]
+        unsharded_param = torch.as_strided(
+            unsharded_tensor,
+            self._orig_size,
+            self._contiguous_orig_stride,
+            storage_offset=0,
+        )
+        if self.is_dtensor:
+            unsharded_param = _from_local_no_grad(unsharded_param, self._tp_spec)
+        if hasattr(self, "_unsharded_param"):
+            if not compiled_autograd_enabled():
+                raise AssertionError("Expected compiled_autograd to be enabled")
+            with (
+                torch.no_grad(),
+                torch.autograd._unsafe_preserve_version_counter(self._unsharded_param),
+            ):
+                # NOTE: Under compile, if an unsharded param goes through
+                # resize_(full) -> copy_ -> resize_(0) pattern, we will remove those
+                # resize_ and copy_ ops in a compiler graph pass
+                # `remove_fsdp2_unsharded_param_graph_input_usage` to recover performance.
+                self._unsharded_param.untyped_storage().resize_(
+                    self._unsharded_param.numel() * self._unsharded_param.itemsize
+                )
+                torch.ops.fsdp.copy_(self._unsharded_param, unsharded_param)
+        else:
+            self._unsharded_param = nn.Parameter(
+                unsharded_param, requires_grad=self.sharded_param.requires_grad
+            )
+
+    def _unflatten_all_gather_outputs(self) -> tuple[torch.Tensor, ...]:
+        return tuple(
+            t.view(-1, *s[1:])
+            for t, s in zip(
+                self.all_gather_outputs, self._extensions_data.all_gather_input_sizes
+            )
+        )
+
+    def to_sharded(self) -> None:
+        self._setattr_on_modules(self.sharded_param)
+        self.free_unsharded_param()
+        self.sharded_state = ShardedState.SHARDED
+
+    def to_sharded_post_forward(self) -> None:
+        if self.is_dtensor:
+            raise NotImplementedError(
+                "Resharding to smaller mesh with TP is not supported yet"
+            )
+        self._assert_in_states(ShardedState.UNSHARDED)
+        if self.post_forward_mesh_info is None:
+            raise AssertionError("Expected post_forward_mesh_info to not be None")
+        if len(self.all_gather_outputs) != 1:
+            raise AssertionError(
+                f"Expected 1 all_gather_output, got {len(self.all_gather_outputs)}"
+            )
+        shard_world_size = self.post_forward_mesh_info.shard_mesh_size
+        if (numel := self.all_gather_outputs[0].numel()) % shard_world_size != 0:
+            _raise_assert_with_print(
+                f"All-gather output size ({numel}) must be divisible by the shard "
+                f"world size ({shard_world_size})"
+            )
+        shard_rank = self.post_forward_mesh_info.shard_mesh_rank
+        # pyrefly: ignore [unbound-name]
+        sharded_numel = numel // shard_world_size
+        self._sharded_post_forward_param_data = (
+            self.all_gather_outputs[0].narrow(
+                0, sharded_numel * shard_rank, sharded_numel
+            )
+        ).clone()  # clone to be able to free all-gather output
+        sharded_post_forward_tensor = torch.as_strided(
+            self._sharded_post_forward_param_data,
+            size=self.sharded_post_forward_size,
+            stride=self.contiguous_sharded_post_forward_stride,
+            storage_offset=0,
+        )
+        self._sharded_post_forward_param = nn.Parameter(
+            self.to_sharded_post_forward_dtensor(sharded_post_forward_tensor)
+        )
+        self._setattr_on_modules(self._sharded_post_forward_param)
+        self.free_unsharded_param()
+        self.sharded_state = ShardedState.SHARDED_POST_FORWARD
+
+    def to_unsharded(self) -> None:
+        # Assume that the data has been allocated and all-gathered
+        set_requires_grad_if_needed(self.sharded_param, self._unsharded_param)
+        self._setattr_on_modules(self._unsharded_param)
+        if self.sharded_state == ShardedState.SHARDED_POST_FORWARD:
+            # The data is allocated in the default stream via the post-forward
+            # reshard and must be kept alive for the next all-gather copy-in.
+            # Since we call this method after the copy-out, the data's lifetime
+            # is ensured without further synchronization.
+            self._sharded_post_forward_param = None
+            self._sharded_post_forward_param_data = None  # free
+        self.sharded_state = ShardedState.UNSHARDED
+
+    def _setattr_on_modules(self, param: nn.Parameter) -> None:
+        unsafe_setattr_param(
+            self._module_info.module, self._module_info.param_name, param
+        )
+        for shared_module, shared_param_name in zip(
+            self._module_info.shared_modules, self._module_info.shared_param_names
+        ):
+            unsafe_setattr_param(shared_module, shared_param_name, param)
+
+    def to_sharded_dtensor(self, tensor: torch.Tensor) -> DTensor:
+        """
+        Converts a local tensor representing either the sharded parameter or
+        sharded gradient to DTensor.
+        """
+        if tensor.shape != self.sharded_size:
+            _raise_assert_with_print(
+                f"Expects size {self.sharded_size} but got {tensor.shape}"
+            )
+        return _from_local_no_grad(
+            tensor,
+            self._sharding_spec,
+        )
+
+    def to_sharded_post_forward_dtensor(self, tensor: torch.Tensor) -> DTensor:
+        if tensor.shape != self.sharded_post_forward_size:
+            _raise_assert_with_print(
+                f"Expects size {self.sharded_post_forward_size} but got {tensor.shape}"
+            )
+        if not isinstance(self.post_forward_mesh_info, HSDPMeshInfo):
+            raise AssertionError(
+                f"Expected HSDPMeshInfo, got {type(self.post_forward_mesh_info)}"
+            )
+        # TODO: Prefer this DTensor to be read-only and generalize the
+        # placement once we support TP.
+        post_forward_sharding_spec = DTensorSpec(
+            self.post_forward_mesh_info.mesh,
+            (Replicate(), Shard(0)),
+            tensor_meta=self._sharding_spec.tensor_meta,
+        )
+        return _from_local_no_grad(tensor, post_forward_sharding_spec)
+
+    def to_accumulated_grad_if_needed(self) -> None:
+        # Access `_unsharded_param` to bypass the sharded state check since we
+        # prefer to reshard before upcasting the gradient to save memory
+        if (
+            self.reduce_dtype is None
+            or self._unsharded_param.grad is None
+            or self._unsharded_param.grad.dtype == self.reduce_dtype
+        ):
+            return
+        unsharded_grad = self._unsharded_param.grad
+        self._unsharded_param.grad = None
+        self.unsharded_accumulated_grad = unsharded_grad.to(self.reduce_dtype)
+
+    def accumulate_unsharded_grad_if_needed(self) -> None:
+        if (
+            self.unsharded_accumulated_grad is not None
+            and self.unsharded_param.grad is not None
+        ):
+            self.unsharded_accumulated_grad += self.unsharded_param.grad
+            self.unsharded_param.grad = None
+
+    def alloc_all_gather_outputs(self) -> None:
+        for tensor in self.all_gather_outputs:
+            alloc_storage(tensor)
+
+    def free_unsharded_param(self) -> None:
+        if compiled_autograd_enabled():
+            """
+            Assumptions under compile:
+            - `self._unsharded_param` is NOT an alias of `self.all_gather_outputs`.
+            Instead, we resize `self._unsharded_param` storage size to full and then
+            explicitly *copy* the data from `self.all_gather_outputs` to `self._unsharded_param`
+            in `init_unsharded_param()`. (For full-graph FSDP2 case, we will then remove
+            the resize_ and copy_ ops in a compiler graph pass to recover performance.)
+            - `self.all_gather_outputs` and `self._unsharded_inner_tensors` are NOT
+            graph inputs. They are created within the graph and is guaranteed to be freed
+            by the end of the graph. They don't leak outside of the graph.
+            """
+            self._unsharded_param.untyped_storage().resize_(0)
+            self.all_gather_outputs = []
+            self._unsharded_inner_tensors = []
+        else:
+            for tensor in itertools.chain(
+                self.all_gather_outputs, self._unsharded_inner_tensors
+            ):
+                free_storage(tensor)
+
+    @property
+    def all_gather_inputs(self) -> list[torch.Tensor]:  # 1D
+        self._assert_in_states(ShardedState.SHARDED, ShardedState.SHARDED_POST_FORWARD)
+        if self.sharded_state == ShardedState.SHARDED:
+            if not compiled_autograd_enabled() and hasattr(
+                self._sharded_local_tensor, "fsdp_pre_all_gather"
+            ):
+                sharded_local_tensor = self._sharded_local_tensor
+                if self.offload_to_cpu:
+                    sharded_local_tensor = sharded_local_tensor.to(
+                        self.device, non_blocking=True
+                    )
+                pre_all_gather_signature = inspect.signature(
+                    # pyrefly: ignore [missing-attribute]
+                    sharded_local_tensor.fsdp_pre_all_gather
+                )
+                num_fn_params = len(pre_all_gather_signature.parameters)
+                # Old signature only passes mesh; keep for BC for now
+                if num_fn_params not in (1, 5):
+                    raise AssertionError(
+                        f"Invalid fsdp_pre_all_gather: {pre_all_gather_signature}\n"
+                        "Expects fsdp_pre_all_gather(self, mesh: DeviceMesh, "
+                        "outer_size: torch.Size, outer_stride: tuple[int, ...], "
+                        "module: nn.Module, mp_policy: MixedPrecisionPolicy)"
+                    )
+                if num_fn_params == 1:
+                    (
+                        all_gather_inputs,
+                        self._extensions_data.all_gather_metadata,
+                        # pyrefly: ignore [missing-attribute]
+                    ) = sharded_local_tensor.fsdp_pre_all_gather(
+                        self.shard_mesh_from_root
+                    )
+                else:
+                    (
+                        all_gather_inputs,
+                        self._extensions_data.all_gather_metadata,
+                        # pyrefly: ignore [missing-attribute]
+                    ) = sharded_local_tensor.fsdp_pre_all_gather(
+                        self.shard_mesh_from_root,
+                        self._orig_size,
+                        self._contiguous_orig_stride,
+                        self._module_info.module,
+                        self.mp_policy,
+                    )
+                    if (
+                        sharded_local_tensor.size() != self.padded_sharded_param_size
+                        and any(
+                            all_gather_input.size() != self.padded_sharded_param_size
+                            for all_gather_input in all_gather_inputs
+                        )
+                    ):
+                        # NOTE: Since this error can only be raised on the
+                        # ranks that have padding, this can manifest as a NCCL
+                        # watchdog timeout, as the other ranks will not error.
+                        raise AssertionError(
+                            "When a parameter is unevenly sharded by FSDP "
+                            f"(orig size={self._orig_size}, FSDP world size={self.mesh_info.mesh.size()}), "
+                            "fsdp_pre_all_gather must return all-gather inputs with the padded sharded size "
+                            f"{self.padded_sharded_param_size} but got {[t.size() for t in all_gather_inputs]}"
+                        )
+                self._extensions_data.all_gather_input_sizes = [
+                    t.size() for t in all_gather_inputs
+                ]
+                return [t.view(-1) for t in all_gather_inputs]
+            sharded_param_data = self._sharded_param_data
+            if self.offload_to_cpu:
+                sharded_param_data = sharded_param_data.to(
+                    self.device, non_blocking=True
+                )
+            return [_to_dtype_if_needed(sharded_param_data, self.param_dtype)]
+        elif self.sharded_state == ShardedState.SHARDED_POST_FORWARD:
+            if not compiled_autograd_enabled() and hasattr(
+                self._sharded_local_tensor, "fsdp_pre_all_gather"
+            ):
+                raise NotImplementedError
+            all_gather_input = _to_dtype_if_needed(
+                cast(torch.Tensor, self._sharded_post_forward_param_data),
+                self.param_dtype,
+            )
+            return [all_gather_input]
+        return [torch.empty(0)]  # mypy
+
+    @property
+    def unsharded_param(self) -> nn.Parameter:  # ND
+        return self._unsharded_param
+
+    @property
+    def unsharded_grad_data(self) -> torch.Tensor:
+        grad = self.unsharded_param.grad
+        if grad is None:
+            raise AssertionError("Expects unsharded_param.grad to not be None")
+        return self._get_grad_inner_tensor(grad)
+
+    @property
+    def unsharded_accumulated_grad_data(self) -> torch.Tensor:
+        grad = self.unsharded_accumulated_grad
+        if grad is None:
+            raise AssertionError("Expects unsharded_accumulated_grad to not be None")
+        return self._get_grad_inner_tensor(grad)
+
+    def _get_grad_inner_tensor(self, grad: torch.Tensor) -> torch.Tensor:
+        if self.is_dtensor:
+            if isinstance(grad, AsyncCollectiveTensor):
+                grad = grad.wait()
+            if not isinstance(grad, DTensor):
+                raise AssertionError(f"Expected DTensor, got {type(grad)}")
+            placements = self._tp_spec.placements
+            if placements != grad.placements:
+                if len(self._tp_spec.placements) != len(grad.placements):
+                    raise AssertionError(
+                        f"Expected same placement length: {self._tp_spec=} {grad.placements=}"
+                    )
+                grad = grad.redistribute(placements=placements)
+            grad = grad._local_tensor
+        return grad
+
+    @property
+    def _sharded_local_tensor(self) -> torch.Tensor:
+        return cast(DTensor, self.sharded_param)._local_tensor
+
+    @property
+    def shard_mesh(self):
+        mesh = self.mesh_info.mesh
+        if mesh.ndim == 1:
+            return mesh
+        elif mesh.ndim == 2:
+            if mesh.mesh_dim_names is None:
+                raise AssertionError("Expected mesh_dim_names to not be None")
+            return mesh[mesh.mesh_dim_names[-1]]
+        raise ValueError(f"Invalid mesh: {mesh}")
+
+    @property
+    def shard_mesh_from_root(self):
+        mesh = self.mesh_info.mesh
+
+        if mesh.ndim == 1:
+            return mesh
+        else:
+            if mesh.mesh_dim_names is None:
+                raise AssertionError("Expected mesh_dim_names to not be None")
+            shard_dim_name = mesh.mesh_dim_names[-1]
+            return mesh[shard_dim_name]
+
+    def _assert_in_states(self, *states: ShardedState) -> None:
+        if self.sharded_state not in states:
+            _raise_assert_with_print(
+                f"Expects to be in one of {states}, not {self.sharded_state}"
+            )
+
+    def reset_sharded_param(self):
+        # For ops like `nn.Module._apply` or `load_state_dict(assign=True)`
+        # that change the sharded parameter tensor, we may need to re-pad the
+        # sharded local tensor and re-save the reference.
+        module_info = self._module_info
+        new_param = getattr(module_info.module, module_info.param_name)
+        if new_param is not self.sharded_param:
+            if torch.__future__.get_swap_module_params_on_conversion():
+                raise AssertionError(
+                    f"Expects swap_tensors to preserve object but got {new_param} "
+                    f"instead of {self.sharded_param}"
+                )
+            self.sharded_param = new_param
+        # pyrefly: ignore [missing-attribute]
+        local_tensor = new_param._local_tensor
+        if local_tensor.is_meta:
+            return
+        updated_local_tensor = False
+        # local_tensor can be padded twice
+        # 1st time in fully_shard(model)
+        # 2nd time in model(input) lazy_init
+        # 2nd time should be no-op if parameters remain unchanged
+        # 2nd time shouldn't be no-op if people call model.load_state_dict(...) before lazy_init
+        # this makes it possible for trainer to call `sd = model.state_dict()` before the training loop
+        # and use `sd` without calling .state_dict() per iteration
+        same_local_tensor = False
+        # TODO: need to support tensor subclass
+        if type(self._sharded_param_data) is torch.Tensor:
+            same_local_tensor = (
+                # when sharding param with shape (1, ...) over 2 ranks
+                # local_tensor on rank 1 can be size 0, data_ptr() can be 0
+                self._sharded_param_data.untyped_storage().data_ptr() > 0
+                and self._sharded_param_data.untyped_storage().data_ptr()
+                == local_tensor.untyped_storage().data_ptr()
+            )
+        padded_sharded_size = self.padded_sharded_param_size
+        shard_dim = self.fsdp_placement.dim
+        length = local_tensor.size(shard_dim) if local_tensor.numel() > 0 else 0
+        if local_tensor.size() != padded_sharded_size and not same_local_tensor:
+            if shard_dim != 0:
+                raise AssertionError(
+                    f"Shard({shard_dim}) requires even sharding: {local_tensor.size()=}"
+                )
+            padded_local_tensor = local_tensor.new_zeros(padded_sharded_size)
+            padded_local_tensor.narrow(dim=shard_dim, start=0, length=length).copy_(
+                local_tensor
+            )
+            local_tensor = padded_local_tensor
+            updated_local_tensor = True
+        if self.pin_memory and not local_tensor.is_pinned():
+            local_tensor = local_tensor.cpu().pin_memory()
+            updated_local_tensor = True
+        if not same_local_tensor:
+            self._sharded_param_data = local_tensor.view(-1)
+        if not isinstance(self.sharded_param, DTensor):
+            raise AssertionError(f"Expected DTensor, got {type(self.sharded_param)}")
+        if updated_local_tensor:
+            # Only change the local tensor object if needed
+            self.sharded_param._local_tensor = local_tensor.narrow(
+                dim=shard_dim, start=0, length=length
+            )
+            if not self.sharded_param._local_tensor.is_contiguous():
+                raise AssertionError(
+                    "Expected sharded_param._local_tensor to be contiguous"
+                )
+        self._sharding_spec = self.sharded_param._spec
+
+    def __repr__(self):
+        return f"FSDPParam(fqn={self._param_fqn}, orig_size={self._orig_size})"
+
+
+def alloc_storage(tensor: torch.Tensor) -> None:
+    size = tensor.numel() * tensor.itemsize
+    if (storage := tensor.untyped_storage()).size() != size:
+        storage.resize_(size)
+
+
+def free_storage(tensor: torch.Tensor) -> None:
+    if (storage := tensor.untyped_storage()).size() != 0:
+        storage.resize_(0)
+
+
+# NOTE: These bypass `nn.Module.__setattr__` checks, which incur non-trivial
+# CPU overhead, if the module did not override it. For FSDP, we know we do not
+# need those checks when transitioning between sharded/unsharded parameters.
+def unsafe_setattr_param(
+    module: nn.Module, param_name: str, param: nn.Parameter
+) -> None:
+    if getattr(module.__setattr__, "__func__", None) is nn.Module.__setattr__:
+        module._parameters[param_name] = param
+    else:  # slow path
+        setattr(module, param_name, param)
+
+
+def set_requires_grad_if_needed(
+    src_tensor: torch.Tensor, dst_tensor: torch.Tensor
+) -> None:
+    # Only call `requires_grad_` if needed to avoid the Python <> C++ context
+    # switch overhead
+    if src_tensor.requires_grad != dst_tensor.requires_grad:
+        dst_tensor.requires_grad_(src_tensor.requires_grad)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py
new file mode 100644
index 0000000000000000000000000000000000000000..b70a5f06f4ae9b982b0f8e3a486573f79176c30b
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py
@@ -0,0 +1,901 @@
+# mypy: allow-untyped-defs
+import contextlib
+import logging
+from collections.abc import Callable
+from typing import Any, cast, NamedTuple, Optional
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.distributed.device_mesh import _get_device_handle
+from torch.distributed.fsdp._common_utils import _named_parameters_with_duplicates
+from torch.distributed.tensor import Shard
+from torch.profiler import record_function
+from torch.utils._pytree import tree_flatten, tree_unflatten
+from torch.utils.hooks import RemovableHandle
+
+from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy
+from ._fsdp_collectives import (
+    AllGather,
+    AllGatherResult,
+    DefaultAllGather,
+    DefaultReduceScatter,
+    foreach_all_gather,
+    foreach_all_gather_copy_out,
+    foreach_reduce,
+    ProcessGroupAllocAllGather,
+    ProcessGroupAllocReduceScatter,
+    ReduceScatter,
+)
+from ._fsdp_common import (
+    compiled_autograd_enabled,
+    DDPMeshInfo,
+    FSDPMeshInfo,
+    HSDPMeshInfo,
+    is_bw,
+    TrainingState,
+)
+from ._fsdp_param import alloc_storage, FSDPParam, ParamModuleInfo, ShardedState
+
+
+logger = logging.getLogger("torch.distributed.fsdp.fully_shard")
+
+_ModuleToHandleDict = dict[nn.Module, RemovableHandle]  # for state dict
+
+
+"""
+[Note: Overlapping all-gather copy-in and all-gather]
+For implicit forward prefetching, we want to overlap the next copy-in with the
+current all-gather. We do so using a separate copy-in stream. However, since
+we have the all-gather input as a view into the output, we must make sure to
+copy into different memory from the current all-gather's output. Thus, we keep
+a reference to the current all-gather's output and have the next FSDP parameter
+group free it after its copy-in. Finally, we have the last FSDP state flush the
+reference to avoid holding onto memory after forward.
+"""
+
+
+class FSDPCommContext:
+    """This has the communication state shared across FSDP states/parameter groups."""
+
+    def lazy_init(self, device: torch.device):
+        self.device_handle = _get_device_handle(device.type)
+        # Setting the all-gather/reduce-scatter streams to be higher priority
+        # can help avoid some issues where their copies in/out are delayed and
+        # block computation (this is different from high-pri NCCL streams)
+        high_priority = -1
+        # All-gather state and copy-in stream allow overlapping the next
+        # copy-in with the current all-gather in forward; copy-in overlaps with
+        # reduce-scatter in backward without the separate copy-in stream
+        self.all_gather_copy_in_stream = self.device_handle.Stream(
+            priority=high_priority
+        )
+        # All-gather stream allows overlapping next all-gather with current
+        # forward compute
+        self.all_gather_stream = self.device_handle.Stream(priority=high_priority)
+        # Reduce-scatter stream gives separate execution "thread" for post-
+        # backward logic like pre/post-gradient division and reduce-scatter
+        self.reduce_scatter_stream = self.device_handle.Stream(priority=high_priority)
+        # Run the HSDP all-reduces concurrently with all-gather/reduce-scatter
+        # since collectives use different network resources and can overlap
+        # in the typical intra-node sharding / inter-node replication case
+        self.all_reduce_stream = self.device_handle.Stream()
+        # All-gather/reduce-scatter states keep references to collective
+        # tensors produced in one stream and used in another and accompanying
+        # CUDA events for synchronization
+        self.all_gather_state: Optional[AllGatherState] = None
+        self.reduce_scatter_state: Optional[ReduceScatterState] = None
+        # Post-forward order for explicit backward prefetching
+        self.post_forward_order: list[FSDPParamGroup] = []  # will cause ref cycles
+
+    def get_all_gather_streams(
+        self, async_op: bool, training_state: TrainingState
+    ) -> tuple[torch.Stream, torch.Stream]:
+        if not async_op and training_state in (
+            TrainingState.FORWARD,
+            TrainingState.PRE_BACKWARD,
+        ):
+            # Use separate streams for implicit prefetching
+            return self.all_gather_copy_in_stream, self.all_gather_stream
+        current_stream = self.device_handle.current_stream()
+        return current_stream, current_stream
+
+
+# See [Note: Overlapping all-gather copy-in and all-gather]
+class AllGatherState(NamedTuple):
+    all_gather_result: AllGatherResult
+    event: Optional[torch.Event]  # all-gather copy-out
+
+
+class ReduceScatterState(NamedTuple):
+    reduce_scatter_input: torch.Tensor
+    event: Optional[torch.Event]  # reduce-scatter event
+
+
+class AllReduceState(NamedTuple):
+    all_reduce_input: torch.Tensor
+    event: Optional[torch.Event]  # all-reduce event
+
+
+class FSDPParamGroup:
+    """This class represents a parameter group to communicate together."""
+
+    _orig_dtype: Optional[torch.dtype]
+    _reduce_dtype: Optional[torch.dtype]
+
+    def __init__(
+        self,
+        params: list[nn.Parameter],
+        modules: tuple[nn.Module, ...],
+        mesh_info: FSDPMeshInfo,
+        post_forward_mesh_info: Optional[FSDPMeshInfo],
+        device: torch.device,
+        shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]],
+        mp_policy: MixedPrecisionPolicy,
+        offload_policy: OffloadPolicy,
+    ):
+        self.modules = modules  # permit ref cycle because 1:1 lifetime
+        param_module_infos = _get_param_module_infos(params, modules)
+
+        self.fsdp_params = [
+            FSDPParam(
+                param,
+                module_info,
+                mesh_info,
+                post_forward_mesh_info,
+                device,
+                shard_placement_fn,
+                mp_policy,
+                offload_policy,
+            )
+            for param, module_info in zip(params, param_module_infos)
+        ]
+        self.mesh_info = mesh_info
+        self.post_forward_mesh_info = post_forward_mesh_info
+        # pyrefly: ignore [read-only]
+        self.device = device
+        self.device_handle = _get_device_handle(device.type)
+        self.mp_policy = mp_policy
+        self.offload_policy = offload_policy
+        self._training_state = TrainingState.IDLE
+        # Group's sharded state always matches its parameters' sharded states
+        self._sharded_state = ShardedState.SHARDED
+        self._module_fqn: Optional[str] = None  # prefixed from root module
+        # Only consider resetting sharded parameters once in lazy init since it
+        # can incur nontrivial overhead to reset them
+        self._reset_sharded_params: bool = False
+
+        # - Hook state
+        self._module_to_pre_save_state_dict_hook_handle: _ModuleToHandleDict = {}
+        self._module_to_pre_load_state_dict_hook_handle: _ModuleToHandleDict = {}
+        self._all_reduce_hook: Optional[Callable[[torch.Tensor], None]] = None
+        self._all_gather_comm: AllGather = DefaultAllGather()
+        self._all_gather_output = torch.empty(0, device=self.device)
+        self._reduce_scatter_comm: ReduceScatter = DefaultReduceScatter()
+        # Optional stream to run the user-defined all-reduce hook in
+        # Saved here and not in the comm. context because we allow the user to
+        # specify it, possibly at construction time before lazy init
+        self._all_reduce_hook_stream: Optional[torch.cuda.Stream] = None
+
+        # - Communication and communication/computation overlap
+        self.comm_ctx = FSDPCommContext()
+        # Group's indices in the shared post-forward order
+        self._post_forward_indices: list[int] = []
+        # Whether to reduce gradients at all (whether for FSDP or HSDP)
+        self.reduce_grads: bool = True
+        # Whether to all-reduce gradients for HSDP; only used if
+        # `self.reduce_grads` is true, in which case setting this to false
+        # means reduce-scatter but no all-reduce
+        self.all_reduce_grads: bool = True
+        # Whether to reshard parameters after backward (only useful for
+        # gradient accumulation)
+        self.reshard_after_backward: bool = True
+        # Optional custom factor for the gradient reduction op (e.g. to divide
+        # by a factor other than the world size)
+        self.gradient_divide_factor: Optional[float] = None
+        # Whether reduce-scatter and all-reduce should be issued using only
+        # summations, potentially with separate pre-/post-scaling.
+        self.force_sum_reduction_for_comms: bool = False
+        # `async_op` arg used for pre-forward/pre-backward unshard; can be
+        # overridden to only do explicit prefetching and avoid inter-stream
+        # fragmentation from using separate unshard streams
+        self.unshard_async_op: bool = False
+        # Whether to unshard in backward: can be overridden by the user if the
+        # parameters in this group are not needed for backward (e.g. embedding)
+        self.unshard_in_backward: bool = True
+
+        # - CUDA events for stream synchronization
+        # Holds the all-gather output buffer, sync objects, and metadata
+        self._all_gather_result: Optional[AllGatherResult] = None
+        # Holds the reduce-scatter/all-reduce view-out CUDA event that marks the end of
+        # the group's post-backward (e.g. reduce-scatter, all-reduce and div), which
+        # should be waited on at the end of backward
+        self._post_reduce_event: Optional[torch.Event] = None
+        # Holds the reshard-after-forward CUDA event when resharding to a
+        # different world size, which should be waited on in the next unshard
+        self._reshard_after_forward_event: Optional[torch.Event] = None
+
+        # Only for HSDP, if accumulating gradients without all-reduce, save the
+        # partial reduce output (only reduce-scattered but not all-reduced)
+        self._partial_reduce_output: Optional[torch.Tensor] = None
+        # Holds the all-reduce input and all-reduce event to keep it alive
+        # until the end of backward (critical when doing bf16 reduction with
+        # fp32 parameters since the all-reduce input is allocated in the RS
+        # stream and will have no refs to it after being upcast to fp32)
+        self._all_reduce_state: Optional[AllReduceState] = None
+
+    # Initialization #
+    def _init_mp_dtypes(self) -> None:
+        for fsdp_param in self.fsdp_params:
+            fsdp_param.init_dtype_attrs(self.mp_policy)
+        trainable_params: list[FSDPParam] = [
+            p for p in self.fsdp_params if p.sharded_param.requires_grad
+        ]
+        orig_dtypes = {p.orig_dtype for p in trainable_params}
+        reduce_dtypes = {p.reduce_dtype for p in trainable_params}
+        if len(trainable_params) > 0 and len(orig_dtypes) != 1:
+            # Models may have no grad params
+            raise AssertionError(
+                f"FSDP expects uniform original parameter dtype but got {orig_dtypes}"
+            )
+        self._orig_dtype = next(iter(orig_dtypes)) if trainable_params else None
+        if len(trainable_params) > 0 and len(reduce_dtypes) != 1:
+            # This can be relaxed if we issue one reduce-scatter per reduce
+            # dtype (but we would need a way for users to specify multiple
+            # reduce dtypes)
+            raise AssertionError(
+                f"FSDP expects uniform reduce dtype but got {reduce_dtypes}"
+            )
+        self._reduce_dtype = next(iter(reduce_dtypes)) if trainable_params else None
+
+    def lazy_init(self):
+        # Lazy init should be idempotent
+        # Users may change or register parameters after construction time.
+        # For example, DoRA (https://arxiv.org/abs/2402.09353) initializes linear magnitudes based on
+        # other parameters (e.g. loaded from the state dict).
+        if not hasattr(self.comm_ctx, "device_handle"):
+            self.comm_ctx.device_handle = _get_device_handle(self.device.type)
+        if self.is_sharded and not self._reset_sharded_params:
+            for fsdp_param in self.fsdp_params:
+                fsdp_param.reset_sharded_param()
+                fsdp_param._init_extensions()  # allow monkey patch after init
+            self._reset_sharded_params = True
+        self._validate_no_meta_params()
+        self._validate_cpu_offload_params()
+        # Initialize mixed precision attributes lazily in case the user changes
+        # the parameter dtypes after construction time but before forward
+        self._init_mp_dtypes()
+        self._register_state_dict_hooks()
+
+    def set_allocate_memory_from_process_group(self, enable: bool) -> None:
+        """
+        Whether to (try to) use the ProcessGroup's allocate_tensor method for
+        the staging buffers for collective comms.
+        """
+        if not isinstance(
+            self._all_gather_comm, (DefaultAllGather | ProcessGroupAllocAllGather)
+        ):
+            raise AssertionError(
+                "cannot call set_allocate_memory_from_process_group() "
+                f"when all gather comm is custom: {self._all_gather_comm.__class__.__name__}"
+            )
+        self._all_gather_comm = (
+            ProcessGroupAllocAllGather(self._all_gather_process_group)
+            if enable
+            else DefaultAllGather()
+        )
+
+        if not isinstance(
+            self._reduce_scatter_comm,
+            (DefaultReduceScatter | ProcessGroupAllocReduceScatter),
+        ):
+            raise AssertionError(
+                "cannot call set_allocate_memory_from_process_group() "
+                f"when reduce scatter comm is custom: {self._reduce_scatter_comm.__class__.__name__}"
+            )
+        self._reduce_scatter_comm = (
+            ProcessGroupAllocReduceScatter(self._reduce_scatter_process_group)
+            if enable
+            else DefaultReduceScatter()
+        )
+
+    # Runtime #
+    def unshard(self, async_op: bool = False):
+        if self._all_gather_result is not None:  # already called, pending wait
+            return
+        if self.is_unsharded:
+            return  # no-op
+        if (
+            not self.unshard_in_backward
+            and self._training_state == TrainingState.PRE_BACKWARD
+        ):
+            return
+        if self._reshard_after_forward_event is not None:
+            # Resharded parameter data is allocated in the default stream and
+            # used in the all-gather streams
+            self._wait_all_gather_streams_on_event(self._reshard_after_forward_event)
+            self._reshard_after_forward_event = None
+
+        if isinstance(self.mesh_info, FSDPMeshInfo):
+            world_size = self._all_gather_process_group.size()
+        else:
+            world_size = 1
+        if world_size == 1:
+            # can't skip due to early return in wait_for_unshard if
+            # no self._all_gather_result
+            self._all_gather_result = AllGatherResult(
+                all_gather_output=self._all_gather_output,
+                all_gather_event=self.device_handle.Event().record(),
+                all_gather_work=None,
+                param_all_gather_input_dtypes=[],
+                param_all_gather_input_numels=[],
+                all_gather_input_split_sizes=[],
+            )
+
+            return
+
+        with record_function(self._with_fqn("FSDP::all_gather")):
+            self._all_gather_result = foreach_all_gather(
+                self.fsdp_params,
+                self._all_gather_process_group,
+                async_op,
+                *self.comm_ctx.get_all_gather_streams(async_op, self._training_state),
+                self.device,
+                self._all_gather_comm,
+            )
+
+    def wait_for_unshard(self):
+        """
+        1. In forward with implicit prefetching, to overlap the current copy-out
+        with the next all-gather, we save a reference to the current all-gather
+        result to free after the next copy-out.
+        2. Otherwise (explicit prefetching or in backward), we free the
+        all-gather result immediately after the current copy-out since we can
+        already overlap the current copy-out with the previous reduce-scatter.
+        """
+        if not self._all_gather_result:
+            return  # no preceding unshard
+        async_op = self._all_gather_result.all_gather_work is not None
+        if self._training_state == TrainingState.FORWARD:  # implicit prefetch
+            if prev_all_gather_state := self.comm_ctx.all_gather_state:
+                self._wait_all_gather_streams_on_event(prev_all_gather_state.event)
+                self.comm_ctx.all_gather_state = None  # free the all-gather result
+        if isinstance(self.mesh_info, FSDPMeshInfo):
+            world_size = self._all_gather_process_group.size()
+        else:
+            world_size = 1
+        if world_size == 1:
+            # directly initialize unsharded parameters from sharded parameters
+
+            for fsdp_param in self.fsdp_params:
+                # Use all_gather_inputs which already handles conversion to param_dtype
+                # This is consistent with the world_size > 1 path
+                all_gather_input = fsdp_param.all_gather_inputs[0]
+
+                # Make sure the all_gather_outputs has proper storage size before using it
+                # First ensure we have at least one tensor in all_gather_outputs
+                fsdp_param.init_all_gather_outputs(
+                    [all_gather_input.numel()],
+                    [all_gather_input.dtype],
+                    world_size,
+                    self.device,
+                    force_recreate=False,
+                )
+
+                tensor = fsdp_param.all_gather_outputs[0]
+                alloc_storage(tensor)
+
+                # find alternative way to check if tensor.is_inference
+                with torch.autograd._unsafe_preserve_version_counter(tensor):
+                    tensor.copy_(all_gather_input)
+
+        else:
+            with record_function(self._with_fqn("FSDP::all_gather_copy_out")):
+                foreach_all_gather_copy_out(
+                    self._all_gather_result,
+                    self.fsdp_params,
+                    self._all_gather_process_group,
+                )
+
+        for fsdp_param in self.fsdp_params:
+            fsdp_param.init_unsharded_param()
+
+        self._to_unsharded()
+        all_gather_copy_out_event = self.device_handle.Event()
+        all_gather_copy_out_event.record()
+
+        if (
+            not async_op
+            and self._training_state == TrainingState.FORWARD
+            and world_size > 1
+        ):
+            # Defer free to allow for overlap of this copy-out with next
+            # all-gather collective
+            self.comm_ctx.all_gather_state = AllGatherState(
+                self._all_gather_result, all_gather_copy_out_event
+            )
+        else:
+            self._wait_all_gather_streams_on_event(all_gather_copy_out_event)
+
+        self._all_gather_result = None  # free unless saved in `all_gather_state`
+
+    def _wait_all_gather_streams_on_event(self, event: Optional[torch.Event]):
+        # Calling `unshard` before lazy init means streams are not initialized
+        if hasattr(self.comm_ctx, "all_gather_copy_in_stream") and event is not None:
+            self.comm_ctx.all_gather_copy_in_stream.wait_event(event)
+        if hasattr(self.comm_ctx, "all_gather_stream") and event is not None:
+            self.comm_ctx.all_gather_stream.wait_event(event)
+
+    def reshard(self):
+        if self._training_state == TrainingState.FORWARD:
+            if not self._reshard_after_forward:
+                return
+            if self._use_post_forward_mesh:
+                self._to_sharded_post_forward()
+                self._reshard_after_forward_event = self.device_handle.Event()
+                if self._reshard_after_forward_event is not None:
+                    self._reshard_after_forward_event.record()
+                return
+        self._to_sharded()
+
+    def pre_forward(
+        self, module: nn.Module, args: tuple[Any, ...], kwargs: dict[str, Any]
+    ) -> tuple[tuple[Any, ...], dict[str, Any]]:
+        if not compiled_autograd_enabled():
+            logger.debug("%s", self._with_fqn("FSDP::pre_forward"))
+        with record_function(self._with_fqn("FSDP::pre_forward")):
+            self._training_state = TrainingState.FORWARD
+            self.unshard(self.unshard_async_op)
+            self.wait_for_unshard()
+            args, kwargs = self._register_post_backward_hook(args, kwargs)
+            return args, kwargs
+
+    def post_forward(self, module: nn.Module, input: Any, output: Any):
+        if not compiled_autograd_enabled():
+            logger.debug("%s", self._with_fqn("FSDP::post_forward"))
+        with record_function(self._with_fqn("FSDP::post_forward")):
+            if not compiled_autograd_enabled():
+                # for AC(fully_shard(model)), AC runs fsdp's _pre_forward
+                # it shouldn't change post_forward_order
+                if not is_bw():
+                    self.reshard()
+                    self._record_post_forward()
+            else:
+                self.reshard()
+                self._record_post_forward()
+            self._training_state = TrainingState.IDLE
+            return output
+
+    def _record_post_forward(self) -> None:
+        # Since a group has one pre-backward unshard for each forward call
+        # before the backward, we record each usage (with multiplicity)
+        post_forward_index = len(self.comm_ctx.post_forward_order)
+        self.comm_ctx.post_forward_order.append(self)
+        self._post_forward_indices.append(post_forward_index)
+
+    def pre_backward(self, default_prefetch: bool, *unused: Any):
+        if (
+            compiled_autograd_enabled()
+            and self._training_state == TrainingState.PRE_BACKWARD
+        ):
+            # Traceable FSDP2 cannot trigger the param group's `post_backward` immediately after param usage;
+            # instead it relies on this to trigger the previously unexecuted `post_backward`.
+            self.post_backward()
+        if self._training_state == TrainingState.PRE_BACKWARD:
+            return
+        if not compiled_autograd_enabled():
+            logger.debug("%s", self._with_fqn("FSDP::pre_backward"))
+        with record_function(self._with_fqn("FSDP::pre_backward")):
+            self._training_state = TrainingState.PRE_BACKWARD
+            self.unshard(self.unshard_async_op)  # no-op if prefetched
+            self.wait_for_unshard()
+            if default_prefetch and not compiled_autograd_enabled():
+                self._backward_prefetch()
+
+    def post_backward(self, *unused: Any):
+        # This method should be idempotent and safe to call even when this
+        # FSDP parameter group was not used in backward (should be a no-op)
+        if not compiled_autograd_enabled():
+            logger.debug("%s", self._with_fqn("FSDP::post_backward"))
+        self._training_state = TrainingState.POST_BACKWARD
+        with record_function(self._with_fqn("FSDP::post_backward_accumulate")):
+            for fsdp_param in self.fsdp_params:
+                fsdp_param.accumulate_unsharded_grad_if_needed()
+        with record_function(self._with_fqn("FSDP::post_backward_reshard")):
+            if not self.reduce_grads:
+                if self.reshard_after_backward:
+                    self.reshard()
+                for fsdp_param in self.fsdp_params:
+                    fsdp_param.to_accumulated_grad_if_needed()
+                return
+            # Save the autograd-computed gradients before resharding to only
+            # access the unsharded parameters when their data is present
+            fsdp_params_with_grad: list[FSDPParam] = []
+            unsharded_grads: list[torch.Tensor] = []
+            for fsdp_param in self.fsdp_params:
+                if not hasattr(fsdp_param, "_unsharded_param"):
+                    continue
+                # May have an accumulated gradient of the reduce dtype if the
+                # previous backward did not reduce-scatter
+                if fsdp_param.unsharded_accumulated_grad is not None:
+                    fsdp_params_with_grad.append(fsdp_param)
+                    unsharded_grads.append(fsdp_param.unsharded_accumulated_grad_data)
+                    fsdp_param.unsharded_accumulated_grad = None
+                elif fsdp_param.unsharded_param.grad is not None:
+                    fsdp_params_with_grad.append(fsdp_param)
+                    unsharded_grads.append(fsdp_param.unsharded_grad_data)
+                    fsdp_param.unsharded_param.grad = None
+            if self.reshard_after_backward:
+                self.reshard()
+        if len(fsdp_params_with_grad) == 0:
+            return
+        with record_function(self._with_fqn("FSDP::post_backward_reduce")):
+            if (
+                self.comm_ctx.reduce_scatter_state is not None
+                and self.comm_ctx.reduce_scatter_state.event is not None
+            ):
+                self.device_handle.current_stream().wait_event(
+                    self.comm_ctx.reduce_scatter_state.event
+                )
+            self.comm_ctx.reduce_scatter_state = None
+            all_reduce_pg = (
+                self._all_reduce_process_group
+                if isinstance(self.mesh_info, DDPMeshInfo)
+                else None
+            )
+            all_reduce_stream: torch.cuda.Stream
+            if all_reduce_pg is None and self._all_reduce_hook_stream is not None:
+                # this means the native HSDP is not enabled,
+                # but user may want to have a custom HSDP setup
+                if self._all_reduce_hook is None:
+                    raise AssertionError(
+                        "all reduce hook stream is specified but hook itself is missing."
+                    )
+                all_reduce_stream = self._all_reduce_hook_stream
+            else:
+                all_reduce_stream = self.comm_ctx.all_reduce_stream
+
+            self._wait_for_post_backward()
+            (
+                reduce_scatter_input,
+                reduce_scatter_event,
+                self._post_reduce_event,
+                all_reduce_input,
+                all_reduce_event,
+                self._partial_reduce_output,
+            ) = foreach_reduce(
+                fsdp_params_with_grad,
+                unsharded_grads,
+                (
+                    self._reduce_scatter_process_group
+                    if isinstance(self.mesh_info, FSDPMeshInfo)
+                    else None  # pyre-fixme[6]
+                ),
+                self.comm_ctx.reduce_scatter_stream,
+                self._reduce_scatter_comm,
+                self._orig_dtype,
+                self._reduce_dtype,
+                self.device,
+                self.gradient_divide_factor,
+                (
+                    self._all_reduce_process_group
+                    if isinstance(self.mesh_info, DDPMeshInfo)
+                    else None
+                ),
+                all_reduce_stream,
+                self.all_reduce_grads,
+                self._partial_reduce_output,
+                self._all_reduce_hook,
+                self.force_sum_reduction_for_comms,
+            )
+            self.comm_ctx.reduce_scatter_state = ReduceScatterState(
+                reduce_scatter_input, reduce_scatter_event
+            )
+            if all_reduce_input is not None:
+                if self.device.type != "cpu":
+                    if all_reduce_event is None:
+                        raise AssertionError(
+                            "Expected all_reduce_event to be set for non-CPU device"
+                        )
+                self._all_reduce_state = AllReduceState(
+                    all_reduce_input, all_reduce_event
+                )
+
+    def finalize_backward(self):
+        self._wait_for_post_backward()
+        for fsdp_param in self.fsdp_params:
+            if fsdp_param.grad_offload_event is not None:
+                fsdp_param.grad_offload_event.synchronize()
+                fsdp_param.grad_offload_event = None
+        if self._all_gather_result is not None:
+            # If there was a mistargeted unshard without a corresponding wait,
+            # then we wait here and clear the unshard
+            if (event := self._all_gather_result.all_gather_event) is not None:
+                torch.accelerator.current_stream().wait_event(event)
+            work = self._all_gather_result.all_gather_work
+            if isinstance(work, dist.distributed_c10d.Work):
+                work.wait()
+            self._all_gather_result = None
+        self._post_forward_indices.clear()
+
+    def _wait_for_post_backward(self):
+        if self._post_reduce_event is not None:
+            self.device_handle.current_stream().wait_event(self._post_reduce_event)
+            self._post_reduce_event = None
+        if (
+            self._all_reduce_state is not None
+            and self._all_reduce_state.event is not None
+        ):
+            self.device_handle.current_stream().wait_event(self._all_reduce_state.event)
+        self._all_reduce_state = None
+
+    def _backward_prefetch(self) -> None:
+        if self._training_state == TrainingState.PRE_BACKWARD:
+            if not self._post_forward_indices:
+                # Can be cleared if running multiple `backward`s
+                return
+            curr_index = self._post_forward_indices.pop()
+            if (target_index := curr_index - 1) < 0:
+                return
+            # Prefetch naively using the reverse post-forward order, which may
+            # have mistargeted prefetches if not all modules used in forward
+            # are used in this backward
+            # pyrefly: ignore [unbound-name]
+            target_fsdp_param_group = self.comm_ctx.post_forward_order[target_index]
+            self._prefetch_unshard(target_fsdp_param_group, "backward")
+
+    @staticmethod
+    def _prefetch_unshard(
+        target_fsdp_param_group: "FSDPParamGroup", pass_type: str
+    ) -> None:
+        if pass_type == "backward":
+            training_state = TrainingState.PRE_BACKWARD
+        elif pass_type == "forward":
+            training_state = TrainingState.FORWARD
+        else:
+            raise ValueError(f"Unknown pass type: {pass_type}")
+        target_fqn = target_fsdp_param_group._module_fqn
+        with (
+            record_function(f"FSDP::{pass_type}_prefetch for {target_fqn}"),
+            target_fsdp_param_group.use_training_state(training_state),
+        ):
+            async_op = target_fsdp_param_group.unshard_async_op
+            target_fsdp_param_group.unshard(async_op)
+
+    # Utilities #
+    def _to_sharded(self):
+        if not self.is_sharded:
+            for fsdp_param in self.fsdp_params:
+                fsdp_param.to_sharded()
+            self._sharded_state = ShardedState.SHARDED
+
+    def _to_sharded_post_forward(self):
+        if not self.is_sharded_post_forward:
+            for fsdp_param in self.fsdp_params:
+                fsdp_param.to_sharded_post_forward()
+            self._sharded_state = ShardedState.SHARDED_POST_FORWARD
+
+    def _to_unsharded(self):
+        if not self.is_unsharded:
+            for fsdp_param in self.fsdp_params:
+                fsdp_param.to_unsharded()
+            self._sharded_state = ShardedState.UNSHARDED
+
+    @property
+    def is_sharded(self) -> bool:
+        return self._sharded_state == ShardedState.SHARDED
+
+    @property
+    def is_sharded_post_forward(self) -> bool:
+        return self._sharded_state == ShardedState.SHARDED_POST_FORWARD
+
+    @property
+    def is_unsharded(self) -> bool:
+        return self._sharded_state == ShardedState.UNSHARDED
+
+    @contextlib.contextmanager
+    def use_training_state(self, training_state: TrainingState):
+        old_training_state = self._training_state
+        self._training_state = training_state
+        try:
+            yield
+        finally:
+            self._training_state = old_training_state
+
+    # Hook Registration #
+    def _register_post_backward_hook(
+        self, args: tuple[Any, ...], kwargs: dict[str, Any]
+    ) -> tuple[tuple[Any, ...], dict[str, Any]]:
+        # Traceable FSDP2 relies on `root_post_backward_callback` to call each
+        # `FSDPParamGroup.post_backward`
+        if (not torch._dynamo.config.skip_fsdp_hooks) or compiled_autograd_enabled():
+            return args, kwargs
+        if not torch.is_grad_enabled():
+            return args, kwargs
+        args_list, args_spec = tree_flatten(args)
+        kwargs_list, kwargs_spec = tree_flatten(kwargs)
+        args_kwargs_list = list(args_list) + list(kwargs_list)
+        inp_tensor_indices: list[int] = []
+        inp_tensors: list[torch.Tensor] = []
+        for i, obj in enumerate(args_kwargs_list):
+            if torch.is_tensor(obj) and obj.requires_grad:
+                inp_tensor_indices.append(i)
+                inp_tensors.append(obj)
+        if len(inp_tensors) == 0:
+            return args, kwargs  # no tensors that require gradients
+        inp_tensors = RegisterPostBackwardFunction.apply(self, *inp_tensors)
+        for inp_tensor_idx, inp_tensor in zip(inp_tensor_indices, inp_tensors):
+            args_kwargs_list[inp_tensor_idx] = inp_tensor
+        args_list = args_kwargs_list[: len(args_list)]
+        kwargs_list = args_kwargs_list[len(args_list) :]
+        args = tree_unflatten(args_list, args_spec)
+        kwargs = tree_unflatten(kwargs_list, kwargs_spec)
+        return args, kwargs
+
+    def _register_state_dict_hooks(self) -> None:
+        num_pre_save_hooks = len(self._module_to_pre_save_state_dict_hook_handle)
+        num_pre_load_hooks = len(self._module_to_pre_load_state_dict_hook_handle)
+        if num_pre_save_hooks != num_pre_load_hooks:
+            raise AssertionError(
+                f"Pre-save: {num_pre_save_hooks} pre-load: {num_pre_load_hooks}"
+            )
+        if num_pre_save_hooks > 0:
+            return  # already registered
+        modules_with_fsdp_params: set[nn.Module] = {
+            fsdp_param._module_info.module for fsdp_param in self.fsdp_params
+        }
+
+        def to_sharded_hook(*args: Any, **kwargs: Any) -> None:
+            self._to_sharded()
+
+        for module in modules_with_fsdp_params:
+            self._module_to_pre_save_state_dict_hook_handle[module] = (
+                module.register_state_dict_pre_hook(to_sharded_hook)
+            )
+            self._module_to_pre_load_state_dict_hook_handle[module] = (
+                module._register_load_state_dict_pre_hook(to_sharded_hook)
+            )
+
+    # Properties #
+    @property
+    def _reshard_after_forward(self) -> bool:
+        return self.post_forward_mesh_info is not None
+
+    @property
+    def _use_post_forward_mesh(self) -> bool:
+        return (
+            self._reshard_after_forward
+            and self.mesh_info != self.post_forward_mesh_info
+        )
+
+    @property
+    def _is_hsdp(self) -> bool:
+        return isinstance(self.mesh_info, HSDPMeshInfo)
+
+    @property
+    def _all_gather_process_group(self) -> dist.ProcessGroup:
+        mesh_info = (
+            cast(FSDPMeshInfo, self.post_forward_mesh_info)
+            if self.is_sharded_post_forward
+            else self.mesh_info
+        )
+        if not isinstance(mesh_info, FSDPMeshInfo):
+            raise AssertionError(
+                f"Expected mesh_info to be FSDPMeshInfo, got {type(mesh_info)}"
+            )
+        return mesh_info.shard_process_group
+
+    @property
+    def _reduce_scatter_process_group(self) -> dist.ProcessGroup:
+        if not isinstance(self.mesh_info, FSDPMeshInfo):
+            raise AssertionError(
+                f"Expected mesh_info to be FSDPMeshInfo, got {type(self.mesh_info)}"
+            )
+        return self.mesh_info.shard_process_group
+
+    @property
+    def _all_reduce_process_group(self) -> dist.ProcessGroup:
+        if not isinstance(self.mesh_info, DDPMeshInfo):
+            raise AssertionError(
+                f"Expected mesh_info to be DDPMeshInfo or HSDPMeshInfo, got {type(self.mesh_info)}"
+            )
+        return self.mesh_info.replicate_process_group
+
+    def _with_fqn(self, label: str) -> str:
+        if self._module_fqn:
+            return f"{label} ({self._module_fqn})"
+        return label
+
+    def __repr__(self):
+        return f"FSDPParamGroup(fqn={self._module_fqn})"
+
+    def _validate_no_meta_params(self):
+        param_names_on_meta = [
+            fsdp_param._param_fqn
+            for fsdp_param in self.fsdp_params
+            if fsdp_param.sharded_param.device.type == "meta"
+        ]
+        if param_names_on_meta:
+            raise RuntimeError(
+                "FSDP parameters should be materialized from meta device before training, "
+                f"but the following were still on meta device: {param_names_on_meta}\n"
+                "For example, call module.to_empty(device) to materialize to device and "
+                "call module.reset_parameters() on each module to initialize values."
+            )
+
+    def _validate_cpu_offload_params(self):
+        if not isinstance(self.offload_policy, CPUOffloadPolicy):
+            return
+        fsdp_params_not_on_cpu = [
+            fsdp_param
+            for fsdp_param in self.fsdp_params
+            if fsdp_param.sharded_param.device.type != "cpu"
+        ]
+        if fsdp_params_not_on_cpu:
+            raise RuntimeError(
+                "FSDP parameters should be materialized on CPU when enabling CPU offloading. "
+                'For example, load a CPU state dict or call module.to_empty(device="cpu"). '
+                "Found following parameters on non-CPU device: "
+                f"{[(fsdp_param._param_fqn, fsdp_param.sharded_param.device) for fsdp_param in fsdp_params_not_on_cpu]}\n"
+            )
+
+
+def _get_param_module_infos(
+    params: list[nn.Parameter], modules: tuple[nn.Module, ...]
+) -> list[ParamModuleInfo]:
+    """
+    Shared parameter: lin1.weight = lin2.weight
+    Shared module: mlp.lin1 = mlp.lin2
+    We do not remove duplicates when traversing both modules and parameters to
+    find shared modules' parameters and shared parameters within a module.
+    """
+    params_set = set(params)
+    param_to_module_info: dict[nn.Parameter, ParamModuleInfo] = {}
+    for module in modules:
+        for _, submodule in module.named_modules(remove_duplicate=False):
+            for param_name, param in _named_parameters_with_duplicates(
+                submodule, recurse=False
+            ):
+                if param in params_set:
+                    if param not in param_to_module_info:
+                        param_to_module_info[param] = ParamModuleInfo(
+                            submodule, param_name
+                        )
+                    else:
+                        param_to_module_info[param].shared_modules.append(submodule)
+                        param_to_module_info[param].shared_param_names.append(
+                            param_name
+                        )
+    if len(param_to_module_info) != len(params):
+        raise AssertionError(f"Some parameters are not in the module tree of {modules}")
+    return [param_to_module_info[param] for param in params]
+
+
+class RegisterPostBackwardFunction(torch.autograd.Function):
+    @staticmethod
+    def _assert_not_tracing_fsdp():
+        if compiled_autograd_enabled():
+            # TODO: Find a way to print the offending FSDP2 module.
+            msg = """\
+When Traceable FSDP2 is enabled, we should not be calling into `RegisterPostBackwardFunction`.
+Instead, we rely on the param group's next `pre_backward` hook to trigger its previously unexecuted
+`post_backward`, and we rely on FSDPState's `root_post_backward_callback` to trigger the resharding
+of any leftover unsharded param groups.
+If you are here, it means the forward part of this FSDP2 instance is not compiled, and you must also
+compile the forward part if you want to use Traceable FSDP2."""
+            torch._dynamo.comptime.comptime.print(msg)
+            raise RuntimeError(msg)
+
+    @staticmethod
+    # pyrefly: ignore [bad-override]
+    def forward(ctx, param_group: FSDPParamGroup, *inputs: torch.Tensor):
+        # All tensors in `inputs` should require gradient
+        RegisterPostBackwardFunction._assert_not_tracing_fsdp()
+        ctx.param_group = param_group
+        return inputs
+
+    @staticmethod
+    def backward(ctx, *grads: torch.Tensor):
+        RegisterPostBackwardFunction._assert_not_tracing_fsdp()
+        ctx.param_group.post_backward()
+        return (None,) + grads
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_state.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_state.py
new file mode 100644
index 0000000000000000000000000000000000000000..d68dfbf2ddcb0faaf1888fc912ba09bc599e2c5c
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_state.py
@@ -0,0 +1,408 @@
+# mypy: allow-untyped-decorators
+# mypy: allow-untyped-defs
+import functools
+import logging
+from collections.abc import Callable, Sequence
+from typing import Any, Optional, TYPE_CHECKING
+
+import torch
+import torch.nn as nn
+from torch._logging import warning_once
+from torch.autograd import Variable
+from torch.autograd.graph import _MultiHandle
+from torch.distributed._composable_state import (
+    _get_module_state,
+    _insert_module_state,
+    _State,
+)
+from torch.distributed.device_mesh import _get_device_handle
+from torch.distributed.utils import _apply_to_tensors, _to_kwargs
+from torch.utils._pytree import tree_flatten
+
+from ._fsdp_api import MixedPrecisionPolicy
+from ._fsdp_common import (
+    _cast_fp_tensor,
+    compiled_autograd_enabled,
+    detect_compiled_autograd,
+    TrainingState,
+)
+from ._fsdp_param_group import FSDPCommContext, FSDPParamGroup
+
+
+if TYPE_CHECKING:
+    from ._fsdp_param import FSDPParam
+
+
+logger = logging.getLogger("torch.distributed.fsdp.fully_shard")
+
+
+class FSDPStateContext:
+    """This has state shared across FSDP states."""
+
+    def __init__(self) -> None:
+        # All FSDP states in the root state's module tree
+        self.all_states: list[FSDPState] = []
+        # Iteration's forward root runs the once-per-forward logic; this root
+        # may not be the overall root set by lazy initialization in cases where
+        # only a submodule runs forward (e.g. encoder-only for eval)
+        self.iter_forward_root: Optional[FSDPState] = None
+        # Final callback should only be queued once per backward
+        self.post_backward_final_callback_queued: bool = False
+        # Whether to finalize backward in this backward's final callback
+        self.is_last_backward: bool = True
+        # Optional user-provided event recorded after optimizer for the
+        # all-gather streams to wait on in the root pre-forward
+        self.post_optim_event: Optional[torch.Event] = None
+
+
+def disable_if_config_true(func):
+    @functools.wraps(func)
+    def fsdp_hook_wrapper(*args, **kwargs):
+        if torch._dynamo.config.skip_fsdp_hooks:
+            return torch._dynamo.disable(
+                func,
+                recursive=True,
+                reason="skipping FSDP hooks since torch._dynamo.config.skip_fsdp_hooks is set",
+            )(*args, **kwargs)
+        else:
+            return func(*args, **kwargs)
+
+    return fsdp_hook_wrapper
+
+
+class FSDPState(_State):
+    def __init__(self) -> None:
+        super().__init__()
+        self._fsdp_param_group: Optional[FSDPParamGroup] = None
+        self._is_root: Optional[bool] = None  # root set during lazy init
+        self._state_ctx = FSDPStateContext()
+        self._comm_ctx = FSDPCommContext()
+        self._training_state: TrainingState = TrainingState.IDLE
+        self._states_to_forward_prefetch: list[FSDPState] = []
+        self._states_to_backward_prefetch: list[FSDPState] = []
+        self._modules_to_run_forward: set[nn.Module] = set()
+        # ``False`` when user set reshard_after_forward
+        # through ``fully_shard`` or ``set_reshard_after_forward``
+        self._auto_reshard_after_forward: Optional[bool] = True
+
+    # Define a separate init since `__init__` is called in the contract
+    def init(
+        self,
+        modules: tuple[nn.Module, ...],
+        device: torch.device,
+        mp_policy: MixedPrecisionPolicy,
+        auto_reshard_after_forward: bool,
+    ) -> None:
+        for module in modules:
+            _insert_module_state(module, self)
+        self._modules = modules
+        # pyrefly: ignore [read-only]
+        self._device = device
+        self._device_handle = _get_device_handle(device.type)
+        self._mp_policy = mp_policy
+        self._auto_reshard_after_forward = auto_reshard_after_forward
+        if len(modules) == 1:
+            self._pre_forward_hook_handle = modules[0].register_forward_pre_hook(
+                self._pre_forward, prepend=True, with_kwargs=True
+            )
+            self._post_forward_hook_handle = modules[0].register_forward_hook(
+                self._post_forward, prepend=False
+            )
+        else:
+            hook_handle = _register_group_forward_hooks(
+                modules,
+                self._pre_forward,
+                self._post_forward,
+                self._modules_to_run_forward,
+            )
+            self._pre_forward_hook_handle = hook_handle
+            self._post_forward_hook_handle = hook_handle
+
+    def _root_pre_forward(
+        self, module: nn.Module, args: tuple[Any, ...], kwargs: dict[str, Any]
+    ) -> tuple[tuple[Any, ...], dict[str, Any]]:
+        self._lazy_init()
+        if self._state_ctx.iter_forward_root is not None:
+            return args, kwargs
+        if not compiled_autograd_enabled():
+            logger.debug("FSDP::root_pre_forward")
+        self._state_ctx.iter_forward_root = self
+        with torch.profiler.record_function("FSDP::root_pre_forward"):
+            # Wait for optimizer before implicitly prefetched all-gathers
+            if (event := self._state_ctx.post_optim_event) is not None:
+                self._comm_ctx.all_gather_copy_in_stream.wait_event(event)
+                self._comm_ctx.all_gather_stream.wait_event(event)
+                self._state_ctx.post_optim_event = None
+            else:
+                current_stream = self._device_handle.current_stream()
+                self._comm_ctx.all_gather_copy_in_stream.wait_stream(current_stream)
+                self._comm_ctx.all_gather_stream.wait_stream(current_stream)
+            if self._device.type in [
+                "cuda",
+                "hpu",
+                "xpu",
+                "mtia",
+                torch._C._get_privateuse1_backend_name(),
+            ]:
+                with torch.profiler.record_function("FSDP::inputs_to_device"):
+                    args_tuple, kwargs_tuple = _to_kwargs(
+                        args, kwargs, self._device, False
+                    )  # same as DDP
+                args, kwargs = args_tuple[0], kwargs_tuple[0]
+        return args, kwargs
+
+    def _lazy_init(self) -> None:
+        """
+        Lazy initialization represents when all modules' parallelisms have
+        finalized (e.g. FSDP has been applied to all desired modules). This
+        means that we can determine which state is the root, and we do so by
+        the 1st state to run forward.
+        """
+        if self._is_root is not None:
+            return  # no-op: already initialized
+        self._is_root = True
+        if len(self._modules) > 1:
+            raise RuntimeError(
+                f"FSDP requires a single root module but got {self._modules}"
+            )
+        detect_compiled_autograd()
+        root_module = self._modules[0]
+        visited_states: set[FSDPState] = set()
+        for module_name, module in root_module.named_modules():
+            if (state := _get_module_fsdp_state(module)) is None:
+                continue
+            if module is not root_module:
+                if state not in visited_states and state._is_root is not None:
+                    raise RuntimeError(
+                        "FSDP state has already been lazily initialized for "
+                        f"{module_name}\nFSDP requires running forward through "
+                        "the root module first"
+                    )
+                state._is_root = False
+            self._state_ctx.all_states.append(state)
+            visited_states.add(state)
+        if self._fsdp_param_group and self._auto_reshard_after_forward:
+            # For the root, do not reshard after forward since for training,
+            # the parameters would be freed and all-gathered immediately
+            self._fsdp_param_group.post_forward_mesh_info = None
+        self._init_fqns()
+        self._init_shared_state()
+        # Run parameter group lazy inits after initializing FQNs for improved
+        # error messages
+        for state in self._state_ctx.all_states:
+            if state._fsdp_param_group:
+                state._fsdp_param_group.lazy_init()
+
+    def _init_shared_state(self) -> None:
+        self._comm_ctx.lazy_init(self._device)
+        for state in self._state_ctx.all_states:
+            state._state_ctx = self._state_ctx
+            state._comm_ctx = self._comm_ctx
+            if fsdp_param_group := state._fsdp_param_group:
+                fsdp_param_group.comm_ctx = self._comm_ctx
+
+    def _init_fqns(self) -> None:
+        """Sets module and parameter FQN attributes for debugging."""
+        if not self._is_root:
+            raise AssertionError("Expected _is_root to be True")
+        root_module = self._modules[0]
+        param_to_fsdp_param: dict[nn.Parameter, FSDPParam] = {}
+        module_to_fsdp_param_group: dict[nn.Module, FSDPParamGroup] = {}
+        for state in self._state_ctx.all_states:
+            if fsdp_param_group := state._fsdp_param_group:
+                for fsdp_param in fsdp_param_group.fsdp_params:
+                    param_to_fsdp_param[fsdp_param.sharded_param] = fsdp_param
+                for module in fsdp_param_group.modules:
+                    module_to_fsdp_param_group[module] = fsdp_param_group
+        for param_name, param in root_module.named_parameters():
+            if param in param_to_fsdp_param:
+                param_to_fsdp_param[param]._param_fqn = param_name
+        for module_name, module in root_module.named_modules():
+            if module in module_to_fsdp_param_group:
+                module_fqn = module_to_fsdp_param_group[module]._module_fqn
+                if module_fqn is None:
+                    module_to_fsdp_param_group[module]._module_fqn = module_name
+                else:
+                    if not isinstance(module_fqn, str):
+                        raise AssertionError(
+                            f"Expected module_fqn to be str, got {type(module_fqn)}: {module_fqn}"
+                        )
+                    module_fqn += f", {module_name}"
+                    module_to_fsdp_param_group[module]._module_fqn = module_fqn
+
+    @disable_if_config_true
+    def _pre_forward(
+        self, module: nn.Module, args: tuple[Any, ...], kwargs: dict[str, Any]
+    ) -> tuple[tuple[Any, ...], dict[str, Any]]:
+        # When composing with module-hook-based activation checkpointing, the
+        # pre-backward hook is responsible for the unshard
+        if self._training_state == TrainingState.PRE_BACKWARD:
+            return args, kwargs
+        self._training_state = TrainingState.FORWARD
+        args, kwargs = self._root_pre_forward(module, args, kwargs)
+        if self._mp_policy.cast_forward_inputs and self._mp_policy.param_dtype:
+            with torch.profiler.record_function("FSDP::cast_forward_inputs"):
+                cast_fn = functools.partial(
+                    _cast_fp_tensor, self._mp_policy.param_dtype
+                )
+                args, kwargs = (
+                    _apply_to_tensors(cast_fn, args),
+                    _apply_to_tensors(cast_fn, kwargs),
+                )
+        if self._fsdp_param_group:
+            args, kwargs = self._fsdp_param_group.pre_forward(module, args, kwargs)
+        for fsdp_state in self._states_to_forward_prefetch:
+            if (target_param_group := fsdp_state._fsdp_param_group) is not None:
+                FSDPParamGroup._prefetch_unshard(target_param_group, "forward")
+        return args, kwargs
+
+    @disable_if_config_true
+    def _post_forward(self, module: nn.Module, input: Any, output: Any) -> Any:
+        # When composing with module-hook-based activation checkpointing, the
+        # post-backward hook is responsible for the reshard
+        if self._training_state == TrainingState.PRE_BACKWARD:
+            return output
+        if self._fsdp_param_group:
+            output = self._fsdp_param_group.post_forward(module, input, output)
+        output = self._register_pre_backward_hook(output)
+        self._training_state = TrainingState.IDLE
+        if self._state_ctx.iter_forward_root is self:
+            if all_gather_state := self._comm_ctx.all_gather_state:
+                # Free the last all-gather result if needed; refer to
+                # [Note: Overlapping all-gather copy-in and all-gather]
+                self._comm_ctx.all_gather_copy_in_stream.wait_event(
+                    all_gather_state.event
+                )
+                self._comm_ctx.all_gather_stream.wait_event(all_gather_state.event)
+                self._comm_ctx.all_gather_state = None  # free the all-gather result
+            self._state_ctx.iter_forward_root = None
+        if self._mp_policy.output_dtype is not None:
+            with torch.profiler.record_function("FSDP::cast_forward_outputs"):
+                output = _apply_to_tensors(
+                    functools.partial(_cast_fp_tensor, self._mp_policy.output_dtype),
+                    output,
+                )
+        return output
+
+    def _pre_backward(self, grad: torch.Tensor) -> torch.Tensor:
+        self._training_state = TrainingState.PRE_BACKWARD
+        self._register_root_post_backward_final_callback()
+        if self._fsdp_param_group:
+            default_prefetch = len(self._states_to_backward_prefetch) == 0
+            self._fsdp_param_group.pre_backward(default_prefetch)
+        for fsdp_state in self._states_to_backward_prefetch:
+            if (target_param_group := fsdp_state._fsdp_param_group) is not None:
+                FSDPParamGroup._prefetch_unshard(target_param_group, "backward")
+        return grad
+
+    def _root_post_backward_final_callback(self) -> None:
+        if not compiled_autograd_enabled():
+            logger.debug("FSDP::root_post_backward")
+        with torch.profiler.record_function("FSDP::root_post_backward_callback"):
+            for state in self._state_ctx.all_states:
+                fsdp_param_group = state._fsdp_param_group
+                if (
+                    fsdp_param_group
+                    and fsdp_param_group._training_state != TrainingState.POST_BACKWARD
+                ):
+                    # Run post-backward in case forward inputs did not require
+                    # gradient so the autograd backward did not run
+                    fsdp_param_group.post_backward()
+                state._training_state = TrainingState.IDLE
+                if fsdp_param_group:
+                    fsdp_param_group._training_state = TrainingState.IDLE
+                if self._state_ctx.is_last_backward:
+                    state._finalize_backward()
+            if self._state_ctx.is_last_backward:
+                self._comm_ctx.post_forward_order.clear()
+                if self._comm_ctx.reduce_scatter_state is not None:
+                    self._device_handle.current_stream().wait_event(
+                        self._comm_ctx.reduce_scatter_state.event
+                    )
+                    self._comm_ctx.reduce_scatter_state = None
+            self._state_ctx.post_backward_final_callback_queued = False
+
+    def _finalize_backward(self) -> None:
+        if self._modules_to_run_forward:
+            msg = (
+                f"{len(self._modules_to_run_forward)} of the {len(self._modules)} "
+                f"modules passed to fully_shard did not run forward before backward, "
+                "which is error-prone since FSDP post-forward/pre-backward logic "
+                "will not run for these modules. We recommend passing only modules "
+                "that run forward together. Modules that did not run forward: "
+                f"{list(self._modules_to_run_forward)}"
+            )
+            warning_once(logger, msg, stacklevel=2)
+            # Clear since we want the next forward to run
+            self._modules_to_run_forward.clear()
+        if self._fsdp_param_group:
+            self._fsdp_param_group.finalize_backward()
+
+    def _register_pre_backward_hook(self, output: Any) -> Any:
+        if not torch.is_grad_enabled():
+            return output
+        flat_outputs, _ = tree_flatten(output)
+        for t in flat_outputs:
+            if torch.is_tensor(t) and t.requires_grad:
+                t.register_hook(self._pre_backward)
+        return output
+
+    def _register_root_post_backward_final_callback(self):
+        if self._state_ctx.post_backward_final_callback_queued:
+            return
+        self._state_ctx.post_backward_final_callback_queued = True
+        Variable._execution_engine.queue_callback(
+            self._root_post_backward_final_callback
+        )
+
+
+def _get_module_fsdp_state(module: nn.Module) -> Optional[FSDPState]:
+    state = _get_module_state(module)
+    if isinstance(state, FSDPState):
+        return state
+    return None
+
+
+def _register_group_forward_hooks(
+    modules: Sequence[nn.Module],
+    pre_hook: Callable,
+    post_hook: Callable,
+    modules_to_run: set[nn.Module],
+):
+    """
+    Registers group forward pre and post-hooks. The pre-hook runs upon the
+    first module pre-forward, and the post-hook runs upon the last. If at least
+    one module does not run forward, then the post-hook does not run.
+    """
+    modules_set = set(modules)
+
+    @disable_if_config_true
+    @functools.wraps(pre_hook)
+    def wrapped_pre_hook(*args: Any, **kwargs: Any):
+        if len(modules_to_run) == 0:  # first to run
+            modules_to_run.update(modules_set)
+            return pre_hook(*args, **kwargs)
+
+    @disable_if_config_true
+    def get_wrapped_post_hook(module: nn.Module):
+        @functools.wraps(post_hook)
+        def wrapped_post_hook(*args: Any, **kwargs: Any):
+            modules_to_run.discard(module)
+            if len(modules_to_run) == 0:
+                return post_hook(*args, **kwargs)
+
+        return wrapped_post_hook
+
+    pre_handles = [
+        module.register_forward_pre_hook(
+            wrapped_pre_hook, prepend=True, with_kwargs=True
+        )
+        for module in modules
+    ]
+    post_handles = [
+        module.register_forward_hook(
+            get_wrapped_post_hook(module), prepend=False, always_call=True
+        )
+        for module in modules
+    ]
+    return _MultiHandle(tuple(pre_handles + post_handles))
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fully_shard.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fully_shard.py
new file mode 100644
index 0000000000000000000000000000000000000000..998a33746f961fbf65f43b2c4245a6f12a9d3893
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fully_shard.py
@@ -0,0 +1,746 @@
+# mypy: allow-untyped-decorators
+# mypy: allow-untyped-defs
+
+from __future__ import annotations
+
+import functools
+from contextlib import contextmanager
+from typing import Any, cast, NoReturn, Optional, overload, TYPE_CHECKING, Union
+from typing_extensions import deprecated
+
+import torch
+import torch.nn as nn
+from torch.distributed._composable import contract
+from torch.distributed.utils import _get_root_modules
+
+from ._fsdp_api import AllGather, MixedPrecisionPolicy, OffloadPolicy, ReduceScatter
+from ._fsdp_common import FSDPMeshInfo, HSDPMeshInfo
+from ._fsdp_init import (
+    _get_device_from_mesh,
+    _get_managed_modules,
+    _get_managed_states,
+    _get_post_forward_mesh_info,
+    _init_default_fully_shard_mesh,
+    _move_states_to_device,
+)
+from ._fsdp_param_group import FSDPParamGroup
+from ._fsdp_state import _get_module_fsdp_state, FSDPState
+
+
+if TYPE_CHECKING:
+    from collections.abc import Callable, Iterable, Iterator
+
+    from torch.distributed.tensor import DeviceMesh, Shard
+
+__all__ = [
+    "fully_shard",
+    "FSDPModule",
+    "UnshardHandle",
+    "register_fsdp_forward_method",
+    "get_cls_to_fsdp_cls",
+    "disable_fsdp_module_new_init",
+    "share_comm_ctx",
+]
+
+
+cls_to_fsdp_cls: dict[type, type] = {}
+
+
+def get_cls_to_fsdp_cls() -> dict[type, type]:
+    return cls_to_fsdp_cls
+
+
+@overload
+# pyrefly: ignore [inconsistent-overload]
+def fully_shard(
+    module: nn.Module,
+    *,
+    mesh: Optional[DeviceMesh] = ...,
+    reshard_after_forward: Union[bool, int] = ...,
+    shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]] = ...,
+    mp_policy: MixedPrecisionPolicy = ...,
+    offload_policy: OffloadPolicy = ...,
+    ignored_params: Optional[set[nn.Parameter]] = ...,
+) -> FSDPModule: ...
+
+
+@overload
+# pyrefly: ignore [inconsistent-overload]
+def fully_shard(
+    module: list[nn.Module],
+    *,
+    mesh: Optional[DeviceMesh] = ...,
+    reshard_after_forward: Union[bool, int] = ...,
+    shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]] = ...,
+    mp_policy: MixedPrecisionPolicy = ...,
+    offload_policy: OffloadPolicy = ...,
+    ignored_params: Optional[set[nn.Parameter]] = ...,
+) -> list[FSDPModule]: ...
+
+
+# The decorator adds a state object to `module` that can be accessed via
+# `fully_shard.state(module)`. The state object and module are 1:1.
+# [1] Python runtime decorator does not play well with static type checking
+# so suppressing some type checks to support type overloads
+# such that caller can still get correct return types based on input type
+@contract(state_cls=FSDPState)  # type: ignore[misc] # see [1]
+def fully_shard(
+    module,
+    *,
+    mesh: Optional[DeviceMesh] = None,
+    reshard_after_forward: Optional[Union[bool, int]] = None,
+    shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]] = None,
+    mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(),
+    offload_policy: OffloadPolicy = OffloadPolicy(),
+    ignored_params: Optional[set[nn.Parameter]] = None,
+):
+    """
+    Apply fully sharded data parallelism (FSDP) to ``module``, where FSDP
+    shards module parameters, gradients, and optimizer states across data
+    parallel workers to save memory at the cost of communication.
+
+    At initialization, FSDP shards the module's parameters across the data
+    parallel workers given by ``mesh``. Before forward, FSDP all-gathers the
+    sharded parameters across the data-parallel workers to get the unsharded
+    parameters for forward computation. If ``reshard_after_forward`` is
+    ``True``, then FSDP frees the unsharded parameters after forward and
+    re-all-gathers them in backward before gradient computation. After gradient
+    computation, FSDP frees the unsharded parameters and reduce-scatters the
+    unsharded gradients across data-parallel workers.
+
+    This implementation represents the sharded parameters as :class:`DTensor` s
+    sharded on dim-0, while the unsharded parameters will be like the original
+    parameters on ``module`` (e.g. :class:`torch.Tensor` if originally
+    :class:`torch.Tensor`). A module
+    `forward pre-hook `_
+    on ``module`` all-gathers the parameters, and a module
+    `forward hook `_
+    on ``module`` frees them (if needed). Similar backward hooks all-gather
+    parameters and later free parameters and reduce-scatter gradients.
+
+    Since grouping multiple tensors together for one collective is critical for
+    communication efficiency, this implementation makes this grouping first
+    class. Calling :meth:`fully_shard` on ``module`` constructs one group that
+    includes the parameters in ``module.parameters()`` except those already
+    assigned to a group from an earlier call on a submodule. This means that
+    :meth:`fully_shard` should be called bottom-up on your model. Each group's
+    parameters are all-gathered in one collective, and its gradients are
+    reduce-scattered in one collective. Partitioning the model into multiple
+    groups ("layer by layer") allows for peak memory savings and communication/computation
+    overlap. Users generally should *not* call :meth:`fully_shard` only on the
+    topmost root module.
+
+    Args:
+        module (Union[nn.Module, List[nn.Module]): The module or modules to
+            shard with FSDP and group together for communication.
+        mesh (Optional[DeviceMesh]): This data parallel mesh defines the
+            sharding and device. If 1D, then parameters are fully sharded
+            across the 1D mesh (FSDP) with ``(Shard(0),)`` placement. If 2D,
+            then parameters are sharded across the 1st dim and replicated
+            across the 0th dim (HSDP) with ``(Replicate(), Shard(0))``
+            placement. The mesh's device type gives the device type used for
+            communication; if a CUDA or CUDA-like device type, then we use the
+            current device.
+        reshard_after_forward (Optional[Union[bool, int]]): This controls the parameter
+            behavior after forward and can trade off memory and communication:
+
+            - If ``True``, then this reshards parameters after forward and
+              re-all-gathers in backward.
+            - If ``False``, then this keeps the unsharded parameters in memory
+              after forward and avoids the all-gather in backward. For best performance,
+              we usually set ``False`` for the root module, because the root module
+              is typically required immediately when the backward pass begins.
+            - If ``None``, it is set to ``True`` for non-root modules and ``False``
+              for root modules.
+            - If an ``int``, then this represents the world size to reshard to
+              after forward. It should be a non-trivial divisor of the ``mesh``
+              shard dim size (i.e. excluding 1 and the dim size itself). A
+              choice may be the intra-node size (e.g. ``torch.cuda.device_count()``).
+              This allows the all-gather in backward to be over a smaller world
+              size at the cost of higher memory usage than setting to ``True``.
+            - After forward, the parameters registered to the module depend on
+              to this: The registered parameters are the sharded parameters if
+              ``True``; unsharded parameters if ``False``; and the parameters
+              resharded to the smaller mesh otherwise. To modify the parameters
+              between forward and backward, the registered parameters must be
+              the sharded parameters. For ``False`` or an ``int``, this can be
+              done by manually resharding via :meth:`reshard`.
+        shard_placement_fn (Optional[Callable[[nn.Parameter], Optional[Shard]]]):
+            This callable can be used to override the sharding placement for a
+            parameter to shard a parameter on a dimension other than dim-0. If
+            this callable returns a :class:`Shard` placement (not ``None``),
+            then FSDP will shard according to that placement (e.g. ``Shard(1)``).
+            If sharding on a nonzero dim, we currently require even sharding,
+            i.e. the tensor dim size on that dim must be divisible by the FSDP
+            shard mesh size.
+        mp_policy (MixedPrecisionPolicy): This controls the mixed precision
+            policy, which offers parameter/reduction mixed precision for this
+            module. See :class:`MixedPrecisionPolicy` for details.
+        offload_policy (OffloadPolicy): This controls the offloading policy,
+            which offers parameter/gradient/optimizer state offloading. See
+            :class:`OffloadPolicy` and its subclasses for details.
+        ignored_params: Optional(Set[nn.Parameter]): The set of parameters to be
+            ignored by FSDP. They will not be sharded, nor moved to the device
+            during init, nor have their gradients reduced in backward.
+
+    Returns:
+        FSDPModule: The module with FSDP applied (in-place).
+    """
+    torch._C._log_api_usage_once("torch.distributed.fsdp.fully_shard")
+    if isinstance(module, (nn.ModuleList, nn.ModuleDict)):
+        raise ValueError(
+            f"fully_shard does not support containers that do not implement forward: {module}"
+        )
+    mesh = mesh or _init_default_fully_shard_mesh()
+    if mesh.ndim not in (1, 2):
+        raise ValueError(f"fully_shard expects a 1D or 2D DeviceMesh but got {mesh}")
+    elif mesh.ndim == 1:
+        mesh_info = FSDPMeshInfo(mesh, shard_mesh_dim=0)
+    else:
+        if mesh.mesh_dim_names is None:
+            raise AssertionError(
+                "Please init the 2D mesh for HSDP with mesh_dim_names specified"
+            )
+        mesh_info = HSDPMeshInfo(mesh, shard_mesh_dim=1, replicate_mesh_dim=0)
+    device = _get_device_from_mesh(mesh)
+    auto_reshard_after_forward = reshard_after_forward is None
+    # If the user does not provide ``reshard_after_forward``, we set it to True.
+    # During lazy_init, we identify which module is the root and override its value to False
+    post_forward_mesh_info = _get_post_forward_mesh_info(
+        reshard_after_forward if not auto_reshard_after_forward else True,  # type: ignore[arg-type]
+        mesh_info,
+    )
+
+    arg_module = module
+    modules = (
+        (module,) if isinstance(module, nn.Module) else tuple(_get_root_modules(module))
+    )
+    state = fully_shard.state(modules[0])  # type: ignore[attr-defined] # see [1]
+    state.init(modules, device, mp_policy, auto_reshard_after_forward)
+
+    managed_modules = _get_managed_modules(modules, ignored_params)
+    params, buffers = _get_managed_states(managed_modules, ignored_params)
+
+    _move_states_to_device(params, buffers, device)
+    if params:
+        state._fsdp_param_group = FSDPParamGroup(
+            params,
+            modules,
+            mesh_info,
+            post_forward_mesh_info,
+            device,
+            shard_placement_fn,
+            mp_policy,
+            offload_policy,
+        )
+
+    # For Dynamo
+    for managed_module in managed_modules:
+        managed_module._is_fsdp_managed_module = True  # type: ignore[assignment]
+        managed_module._fsdp_use_orig_params = True  # type: ignore[assignment]
+
+    # Place FSDP leftmost for highest priority in the method resolution order
+    for module in modules:
+        cls = module.__class__
+        new_cls = cls_to_fsdp_cls.get(cls)
+        if not new_cls:
+            dct = {"__deepcopy__": _unimplemented_deepcopy}
+            new_cls = type(f"FSDP{cls.__name__}", (FSDPModule, cls), dct)
+            cls_to_fsdp_cls[cls] = new_cls
+        module.__class__ = new_cls
+    return arg_module
+
+
+def _unimplemented_deepcopy(*args: Any, **kwargs: Any) -> NoReturn:
+    raise AssertionError(
+        "FSDP does not support deepcopy. Please use state dict for serialization."
+    )
+
+
+_enable_fsdp_module_new_init: bool = True
+
+
+@contextmanager
+def disable_fsdp_module_new_init() -> Iterator[None]:
+    global _enable_fsdp_module_new_init
+    prev, _enable_fsdp_module_new_init = _enable_fsdp_module_new_init, False
+    try:
+        yield
+    finally:
+        _enable_fsdp_module_new_init = prev
+
+
+class FSDPModule:
+    def __new__(cls, *args, **kwargs):
+        """
+        Override ``__new__`` to remove the FSDP class and directly construct
+        the original class for cases like indexing into a container module.
+        """
+        # Use index 2 since 0 is the dynamically constructed `FSDP<...>` class
+        # and index 1 is the `FSDPModule` class itself
+        orig_cls = cls.__mro__[2]
+        self = orig_cls.__new__(orig_cls, *args, **kwargs)
+        if _enable_fsdp_module_new_init:
+            self.__init__(*args, **kwargs)
+        return self
+
+    def reshard(self) -> None:
+        """
+        Reshards the module's parameters, freeing the unsharded parameters if
+        they are allocated and registering the sharded parameters to the
+        module. This method is *not* recursive.
+        """
+        state = self._get_fsdp_state()
+        if fsdp_param_group := state._fsdp_param_group:
+            fsdp_param_group.reshard()
+
+    def unshard(self, async_op: bool = False) -> Optional[UnshardHandle]:
+        """
+        Unshards the module's parameters by allocating memory and all-gathering
+        the parameters. This method is *not* recursive. The unshard follows the
+        :class:`MixedPrecisionPolicy`, so it will all-gather following
+        ``param_dtype`` if set.
+
+        Args:
+            async_op (bool): If ``True``, then returns a :class:`UnshardHandle`
+                that has a :meth:`wait` method to wait on the unshard op. If
+                ``False``, then returns ``None`` and waits on the handle inside
+                this function.
+
+        .. note:: If ``async_op=True``, then FSDP will wait on the pending
+            unshard in the module's pre-forward for the user. The user only
+            needs to call :meth:`wait` explicitly if the wait should happen
+            before pre-forward.
+        """
+        state = self._get_fsdp_state()
+        fsdp_param_group = state._fsdp_param_group
+        if fsdp_param_group is not None:
+            fsdp_param_group.lazy_init()
+            fsdp_param_group.unshard(async_op=async_op)
+        handle = _UnshardHandleImpl(fsdp_param_group)
+        if async_op:
+            return handle
+        handle.wait()
+        return None
+
+    def set_is_last_backward(self, is_last_backward: bool) -> None:
+        """
+        Sets whether the next backward is the last one. On the last backward,
+        FSDP waits on pending gradient reduction and clears internal data
+        data structures for backward prefetching. This can be useful for
+        microbatching.
+        """
+        state = self._get_fsdp_state()
+        state._state_ctx.is_last_backward = is_last_backward
+
+    def set_requires_gradient_sync(
+        self, requires_gradient_sync: bool, *, recurse: bool = True
+    ) -> None:
+        """
+        Sets if the module should sync gradients. This can be used to implement
+        gradient accumulation *without communication*. For HSDP, this controls
+        both reduce-scatter and all-reduce together. This is the equivalence of
+        `no_sync` in FSDP1.
+
+        Args:
+            requires_gradient_sync (bool): Whether to reduce gradients for the
+                module's parameters.
+            recurse (bool): Whether to set for all FSDP submodules or just the
+                passed-in module.
+        """
+        self_module = cast(nn.Module, self)
+        modules = list(self_module.modules()) if recurse else [self_module]
+        for module in modules:
+            if isinstance(module, FSDPModule):
+                state = module._get_fsdp_state()
+                if fsdp_param_group := state._fsdp_param_group:
+                    fsdp_param_group.reduce_grads = requires_gradient_sync
+                    fsdp_param_group.all_reduce_grads = requires_gradient_sync
+
+    def set_requires_all_reduce(
+        self, requires_all_reduce: bool, *, recurse: bool = True
+    ) -> None:
+        """
+        Sets if the module should all-reduce gradients. This can be used to
+        implement gradient accumulation with only reduce-scatter but not
+        all-reduce for HSDP.
+        """
+        self_module = cast(nn.Module, self)
+        modules = list(self_module.modules()) if recurse else [self_module]
+        for module in modules:
+            if isinstance(module, FSDPModule):
+                state = module._get_fsdp_state()
+                if fsdp_param_group := state._fsdp_param_group:
+                    fsdp_param_group.all_reduce_grads = requires_all_reduce
+
+    def set_reshard_after_forward(
+        self, reshard_after_forward: bool, recurse: bool = True
+    ) -> None:
+        """
+        Sets if the module should reshard parameters after forward. This can be
+        used to change the ``reshard_after_forward`` FSDP arg at runtime. For
+        example, this can be used to set the FSDP root module's value to
+        ``True`` (since it is otherwise specially set to ``False``), or it can
+        set an FSDP module's value to ``False`` for running evals and set back
+        to ``True`` for training.
+
+        Args:
+            reshard_after_forward (bool): Whether to reshard parameters after
+                forward.
+            recurse (bool): Whether to set for all FSDP submodules or just the
+                passed-in module.
+        """
+        if not isinstance(reshard_after_forward, bool):
+            raise ValueError(
+                f"reshard_after_forward should be a bool, got {type(reshard_after_forward)}"
+            )
+        self_module = cast(nn.Module, self)
+        modules = list(self_module.modules()) if recurse else [self_module]
+        for module in modules:
+            if isinstance(module, FSDPModule):
+                state = module._get_fsdp_state()
+                state._auto_reshard_after_forward = False
+                if fsdp_param_group := state._fsdp_param_group:
+                    fsdp_param_group.post_forward_mesh_info = (
+                        _get_post_forward_mesh_info(
+                            reshard_after_forward, fsdp_param_group.mesh_info
+                        )
+                    )
+
+    def set_reshard_after_backward(
+        self, reshard_after_backward: bool, *, recurse: bool = True
+    ) -> None:
+        """
+        Sets if the module should reshard parameters after backward. This can
+        be used during gradient accumulation to trade off higher memory for
+        reduced communication since the unsharded parameters do not need to be
+        re-all-gathered before the next forward.
+
+        Args:
+            reshard_after_backward (bool): Whether to reshard parameters after
+                backward.
+            recurse (bool): Whether to set for all FSDP submodules or just the
+                passed-in module.
+        """
+        self_module = cast(nn.Module, self)
+        modules = list(self_module.modules()) if recurse else [self_module]
+        for module in modules:
+            if isinstance(module, FSDPModule):
+                state = module._get_fsdp_state()
+                if fsdp_param_group := state._fsdp_param_group:
+                    fsdp_param_group.reshard_after_backward = reshard_after_backward
+
+    def set_modules_to_forward_prefetch(self, modules: list[FSDPModule]) -> None:
+        """
+        Sets the FSDP modules for which this FSDP module should explicitly
+        prefetch all-gathers in forward. The prefetching runs after this
+        module's all-gather copy-out.
+
+        Passing a singleton list containing the next FSDP module gives the same
+        all-gather overlap behavior as the default overlap behavior, except the
+        prefetched all-gather is issued earlier from the CPU. Passing a list
+        with at least length two is required for more aggressive overlap and
+        will use more reserved memory.
+
+        Args:
+            modules (List[FSDPModule]): FSDP modules to prefetch.
+        """
+        _assert_all_fsdp_modules(modules)
+        self._get_fsdp_state()._states_to_forward_prefetch = [
+            module._get_fsdp_state() for module in modules
+        ]
+
+    def set_modules_to_backward_prefetch(self, modules: list[FSDPModule]) -> None:
+        """
+        Sets the FSDP modules for which this FSDP module should explicitly
+        prefetch all-gathers in backward. This overrides the default backward
+        pretching implementation that prefetches the next FSDP module based on
+        the reverse post-forward order.
+
+        Passing a singleton list containing the previous FSDP module gives the
+        same all-gather overlap behavior as the default overlap behavior.
+        Passing a list with at least length two is required for more aggressive
+        overlap and will use more reserved memory.
+
+        Args:
+            modules (List[FSDPModule]): FSDP modules to prefetch.
+        """
+        _assert_all_fsdp_modules(modules)
+        self._get_fsdp_state()._states_to_backward_prefetch = [
+            module._get_fsdp_state() for module in modules
+        ]
+
+    def set_custom_all_gather(self, comm: AllGather) -> None:
+        """
+        Overrides the default ``all_gather`` communication behavior,
+        to have better control over the communication and memory usage.
+        See `Comm` and `ReduceScatter` for details.
+
+        Args:
+            comm (AllGather): Custom all-gather communication.
+        """
+        state = self._get_fsdp_state()
+        if (fsdp_param_group := state._fsdp_param_group) is not None:
+            fsdp_param_group._all_gather_comm = comm
+
+    def set_custom_reduce_scatter(self, comm: ReduceScatter) -> None:
+        """
+        Overrides the default ``reduce_scatter`` communication behavior,
+        to have better control over the communication and memory usage.
+        See `Comm` and `ReduceScatter` for details.
+
+        Args:
+            comm (ReduceScatter): Custom reduce_scatter communication.
+        """
+        state = self._get_fsdp_state()
+        if (fsdp_param_group := state._fsdp_param_group) is not None:
+            fsdp_param_group._reduce_scatter_comm = comm
+
+    def set_all_reduce_hook(
+        self,
+        hook: Callable[[torch.Tensor], None],
+        *,
+        stream: Optional[torch.cuda.Stream] = None,
+    ):
+        """
+        Args:
+            hook (Callable[[torch.Tensor], None]): User-defined all-reduce hook
+                with expected signature ``hook(reduce_output: torch.Tensor) -> None``
+                where ``reduce_output`` is the reduce-scatter output if only
+                using FSDP or the all-reduce output if using native HSDP.
+            stream (Optional[torch.cuda.Stream]): Stream to run the all-reduce
+                hook in. This should only be set if not using native HSDP. If
+                using native HSDP, the hook will run in the internally defined
+                all-reduce stream used by the native HSDP all-reduce.
+        """
+        state = self._get_fsdp_state()
+        if (fsdp_param_group := state._fsdp_param_group) is not None:
+            fsdp_param_group._all_reduce_hook = hook
+            if stream is not None:
+                if fsdp_param_group._is_hsdp:
+                    raise ValueError("stream cannot be set when using native HSDP")
+                fsdp_param_group._all_reduce_hook_stream = stream
+
+    def set_post_optim_event(self, event: torch.Event) -> None:
+        """
+        Sets a post-optimizer-step event for the root FSDP module to wait the
+        all-gather streams on.
+
+        By default, the root FSDP module waits the all-gather streams on the
+        current stream to ensure that the optimizer step has finished before
+        all-gathering. However, this may introduce false dependencies if
+        there is unrelated computation after the optimizer step. This API
+        allows the user to provide their own event to wait on. After the root
+        waits on the event, the event is discarded, so this API should be
+        called with a new event each iteration.
+
+        Args:
+            event (torch.Event): Event recorded after the optimizer step
+                to wait all-gather streams on.
+        """
+        self._get_fsdp_state()._state_ctx.post_optim_event = event
+
+    @deprecated("Use `set_gradient_divide_factor` instead")
+    def set_reduce_scatter_divide_factor(self, factor: float) -> None:
+        """Use :py:meth:`set_gradient_divide_factor` instead"""
+        self.set_gradient_divide_factor(factor)
+
+    def set_gradient_divide_factor(self, factor: float) -> None:
+        """
+        Sets a custom divide factor for the gradient reduction. This might use
+        a custom reduce op using NCCL's PreMulSum, which allows multiplying by
+        the factor before reduction.
+
+        Args:
+            factor (float): Custom divide factor.
+        """
+        state = self._get_fsdp_state()
+        if (fsdp_param_group := state._fsdp_param_group) is not None:
+            fsdp_param_group.gradient_divide_factor = factor
+
+    def set_force_sum_reduction_for_comms(self, enable: bool) -> None:
+        """
+        Sets whether to require the low-level collective communication
+        primitives to exclusively use "sum"-type reductions, even if it comes
+        at the cost of separate additional pre- or post-scaling operations.
+        This is needed for example because NCCL currently supports zero-copy
+        transfers only for this kind of collectives.
+
+        NB: for MTIA devices, this is always implicitly enabled.
+
+        NB: if `set_all_reduce_hook` is used under FSDP setup, the caller needs
+        to ensure the custom all-reduce across FSDP units follow this strategy
+        as well, as FSDP can no longer automatically handle that.
+
+        Args:
+            enable (bool): Whether to only ever use ReduceOp.SUM for comms.
+        """
+        state = self._get_fsdp_state()
+        if (fsdp_param_group := state._fsdp_param_group) is not None:
+            fsdp_param_group.force_sum_reduction_for_comms = enable
+
+    def set_unshard_in_backward(self, unshard_in_backward: bool) -> None:
+        """
+        Sets whether the FSDP module's parameters need to be unsharded in
+        backward. This can be used in expert cases when the user knows that all
+        parameters in this FSDP module's parameter group are not needed for
+        backward computation (e.g. embedding).
+        """
+        state = self._get_fsdp_state()
+        if (fsdp_param_group := state._fsdp_param_group) is not None:
+            fsdp_param_group.unshard_in_backward = unshard_in_backward
+
+    def set_allocate_memory_from_process_group_for_comm(self, enable: bool) -> None:
+        """
+        Sets whether the temporary staging buffers used to send and receive data
+        over collective communications should be allocated using the custom
+        optimized allocator provided by the ProcessGroup itself (if any). This
+        might allow the ProcessGroup to be more efficient. For example, when
+        using NCCL, this enables it to leverage zero-copy transfers over SHARP
+        (for NVLink and/or InfiniBand).
+
+        This cannot be used together with :meth:`set_custom_all_gather` or
+        :meth:`set_custom_reduce_scatter` as those APIs allow for
+        finer-grained control over each communication, and this method cannot
+        determine their staging buffer allocation strategy.
+
+        Args:
+            enable (bool): Whether to turn on ProcessGroup allocation.
+        """
+        state = self._get_fsdp_state()
+        if (fsdp_param_group := state._fsdp_param_group) is not None:
+            fsdp_param_group.set_allocate_memory_from_process_group(enable)
+
+    def _set_unshard_async_op(self, async_op: bool):
+        """
+        Sets whether to use ``async_op=True`` or ``False`` for the pre-forward
+        and pre-backward unshard op. This defaults to ``False`` but can be set
+        to ``True`` with this method.
+
+        Setting this to ``True`` allows the all-gather allocations to happen in
+        the default stream, avoiding inter-stream memory fragmentation.
+        However, you must use explicit prefetching (e.g. via :meth:`unshard`)
+        in forward to still get overlap, and the pre-all-gather ops like dtype
+        casting and copy-in will not overlap with compute.
+        """
+        self_module = cast(nn.Module, self)
+        for module in self_module.modules():
+            if isinstance(module, FSDPModule):
+                state = module._get_fsdp_state()
+                if fsdp_param_group := state._fsdp_param_group:
+                    fsdp_param_group.unshard_async_op = async_op
+
+    def _get_fsdp_state(self) -> FSDPState:
+        if (state := _get_module_fsdp_state(cast(nn.Module, self))) is None:
+            raise AssertionError(f"No FSDP state found on {self}")
+        return state
+
+    def _apply(self, *args: Any, **kwargs: Any) -> Any:
+        # Reshard to ensure that sharded parameters are registered
+        self.reshard()
+        ret = super()._apply(*args, **kwargs)  # type: ignore[misc]
+        state = self._get_fsdp_state()
+        if not (fsdp_param_group := state._fsdp_param_group):
+            return ret
+        # TODO: Remove this padding logic once DTensor pads the local tensor:
+        # https://github.com/pytorch/pytorch/issues/113045
+        with torch.no_grad():
+            for fsdp_param in fsdp_param_group.fsdp_params:
+                fsdp_param.reset_sharded_param()
+        return ret
+
+
+class UnshardHandle:
+    """
+    A handle to wait on a :meth:`FSDPModule.unshard` op.
+    """
+
+    def wait(self) -> None:
+        """
+        Waits on the unshard op. This ensures that the current stream can use
+        the unsharded parameters, which are now registered to the module.
+        """
+        return
+
+
+class _UnshardHandleImpl(UnshardHandle):
+    def __init__(self, fsdp_param_group: Optional[FSDPParamGroup]):
+        self._fsdp_param_group = fsdp_param_group
+
+    def wait(self):
+        if self._fsdp_param_group is not None:
+            self._fsdp_param_group.wait_for_unshard()
+            # Avoid keeping a reference
+            self._fsdp_param_group = None
+
+
+def register_fsdp_forward_method(module: nn.Module, method_name: str) -> None:
+    """
+    Registers a method on ``module`` to be considered a forward method for
+    FSDP.
+
+    FSDP all-gathers parameters pre-forward and optionally frees parameters
+    post-forward (depending on ``reshard_after_forward``). FSDP only knows to
+    do this for :meth:`nn.Module.forward` by default. This function patches a
+    user-specified method to run the pre/post-forward hooks before/after the
+    method, respectively. If ``module`` is not an :class:`FSDPModule`, then
+    this is a no-op.
+
+    Args:
+        module (nn.Module): Module to register the forward method on.
+        method_name (str): Name of the forward method.
+    """
+    if not isinstance(module, FSDPModule):
+        # Make no-op to allow including both when using/not using FSDP
+        return
+    if not hasattr(module, method_name):
+        raise ValueError(f"{type(module)} does not have a method {method_name}")
+    orig_method = getattr(module, method_name)
+
+    @functools.wraps(orig_method)
+    def wrapped_method(self, *args, **kwargs):
+        fsdp_state = self._get_fsdp_state()
+        args, kwargs = fsdp_state._pre_forward(self, args, kwargs)
+        out = orig_method(*args, **kwargs)
+        return fsdp_state._post_forward(self, args, out)
+
+    # Use `__get__` to make `wrapped_method` an instance method
+    setattr(
+        module,
+        method_name,
+        wrapped_method.__get__(module, type(module)),  # type:ignore[attr-defined]
+    )
+
+
+def share_comm_ctx(modules: list[FSDPModule]) -> None:
+    """
+    Share cuda streams for multiple FSDPModules
+
+    Example usage:
+        from torch.distributed.fsdp import share_comm_ctx
+        share_comm_ctx([fsdp_model_1, fsdp_model_2, ...])
+
+    For Pipeline Parallelism (PP), each model chunk is a FSDP root. We want
+    to share cuda streams for all-gather, reduce-scatter, and all-reduce.
+    This avoids allocating inter-stream memory framgmentation
+
+    Args:
+        modules (List[FSDPModule]): modules to share cuda streams
+    """
+    if len(modules) == 0:
+        return
+    for module in modules:
+        if not isinstance(module, FSDPModule):
+            raise ValueError(f"Expects list of FSDPModules but got {module}")
+    fsdp_states = [module._get_fsdp_state() for module in modules]
+    comm_ctx = fsdp_states[0]._comm_ctx
+    for fsdp_state in fsdp_states[1:]:
+        fsdp_state._comm_ctx = comm_ctx
+        if fsdp_param_group := fsdp_state._fsdp_param_group:
+            fsdp_param_group.comm_ctx = comm_ctx
+
+
+def _assert_all_fsdp_modules(modules: Iterable[Any]) -> None:
+    for module in modules:
+        if not isinstance(module, FSDPModule):
+            raise ValueError(f"Expects FSDPModule but got {type(module)}: {module}")
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/functional_adamw.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/functional_adamw.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f95783222bd87887181399275d30414bf5847fe0
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/functional_adamw.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7cfaa668a18373df8576804a8cb730d8e030ad46
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+from ._conv_ops import *  # noqa: F403
+from ._embedding_ops import *  # noqa: F403
+from ._math_ops import *  # noqa: F403
+from ._matrix_ops import *  # noqa: F403
+from ._pointwise_ops import *  # noqa: F403
+from ._random_ops import *  # noqa: F403
+from ._tensor_ops import *  # noqa: F403
+from ._view_ops import *  # noqa: F403
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_common_rules.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_common_rules.py
new file mode 100644
index 0000000000000000000000000000000000000000..2312f8e56c554b03e04a2d93fe45b899cf948916
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_common_rules.py
@@ -0,0 +1,285 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+import string
+from typing import cast
+
+import torch
+from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
+from torch.distributed.tensor._op_schema import OpSchema, OutputSharding
+from torch.distributed.tensor._ops.utils import prod
+from torch.distributed.tensor._utils import compute_local_shape_and_global_offset
+
+
+def _replace_char_in_str(string: str, new_char: str, idx: int) -> str:
+    return string[:idx] + new_char + string[idx + 1 :]
+
+
+def _gen_reshard_suggestions(
+    op_schema: OpSchema,
+    input_dims: list[str],
+    input_specs: tuple[DTensorSpec, ...],
+    dim_to_sharding: dict[str, int],
+    pending_sum: list[int],
+) -> OutputSharding:
+    suggested_arg_specs: list[DTensorSpec] = []
+    for input_dim, input_spec in zip(input_dims, input_specs):
+        dim_map = [dim_to_sharding[dim] for dim in input_dim]
+        suggested_arg_specs.append(
+            DTensorSpec.from_dim_map(
+                mesh=input_spec.mesh,
+                dim_map=dim_map,
+                sums=pending_sum,
+                tensor_meta=input_spec.tensor_meta,
+            )
+        )
+    suggested_schema = OpSchema(op_schema.op, tuple(suggested_arg_specs), {})
+    suggested_schema._inplace_rewrap_schema_suggestion(op_schema)
+    return OutputSharding(
+        None,
+        redistribute_schema=suggested_schema,
+    )
+
+
+def einop_rule(
+    equation: str,
+    op_schema: OpSchema,
+    *,
+    linearity: bool = False,
+    enforce_sharding: dict[str, int] | None = None,
+) -> OutputSharding:
+    """
+    Propagate the sharding of inputs to output for ops whose data moves according to einsum notation.
+
+    This is mostly borrowed from @zdevito's sharding simulator. Examples:
+        mk,kn->mn - einsum
+        ij,ij->ij - addition
+        ij,j->ij - broadcasted addition
+        ij->i - reduction
+    Other ops could use this propagation algorithm when applied, note
+    that einsum propagation only deal with list of specs (DTensor specs)
+    as it only works on list of tensors!
+
+    linearity in einop_rule means that the calling op `f` follows this rule:
+        f(a + b) = f(a) + f(b)
+
+    In this case we can propagate the partial sum, note that linearity in einop
+    only applies to partial sum, not other operations like min/max (which are
+    associative but not linear).
+    """
+    # parse einop equation and extract arg specs
+    inputs, outputs = equation.split("->")
+    input_dims, output_dims = inputs.split(","), outputs.split(",")
+    input_specs = op_schema.args_spec
+    # NOTE: only support single output unless needed in future
+    output_dim = output_dims[0]
+
+    dim_to_sharding: dict[str, int] = {}
+    dim_to_size: dict[str, int] = {}
+    # record pending sum, key is mesh dimension, value is pending sum
+    # counter across input specs
+    pending_sums_counter: dict[int, int] = {}
+    seen_shardings: dict[int, str] = {}
+    needs_reshard = False
+
+    def merge_sharding(dim: str, a: int, b: int) -> int:
+        # merge the sharding of inputs if it's able to merge, i.e. we can merge
+        # replicate and shard to shard, but this will trigger an reshard operation
+        if a != b:
+            if a == -1 or b == -1:
+                # reshard the replicate to match the sharded one
+                nonlocal needs_reshard
+                needs_reshard = True
+                return a if a != -1 else b
+            else:
+                # TODO: further merge the sharding properly (i.e. reshard one input to replicate)
+                raise RuntimeError(
+                    f"{equation}: dim {dim} sharded two different ways: {a} and {b}"
+                )
+        else:
+            return a
+
+    for input_dim, input_spec in zip(input_dims, input_specs):
+        # deal with partial sums
+        input_sums = input_spec.sums
+        for sum_dim in input_sums:
+            if sum_dim not in pending_sums_counter:
+                seen_shardings[sum_dim] = "+"
+            # update pending sum counter for pending sum mesh
+            # dimension with the occurrence from each input
+            pending_sums_counter[sum_dim] = pending_sums_counter.get(sum_dim, 0) + 1
+
+        for idx, (dim, mesh_dim) in enumerate(zip(input_dim, input_spec.dim_map)):
+            if enforce_sharding and dim in enforce_sharding:
+                if enforce_sharding[dim] != mesh_dim:
+                    needs_reshard = True
+                dim_to_sharding[dim] = enforce_sharding[dim]
+                dim_to_size[dim] = input_spec.shape[idx]
+            elif dim not in dim_to_sharding:
+                dim_to_sharding[dim] = mesh_dim
+                dim_to_size[dim] = input_spec.shape[idx]
+            else:
+                dim_to_sharding[dim] = merge_sharding(
+                    dim, dim_to_sharding[dim], mesh_dim
+                )
+                assert dim_to_size[dim] == input_spec.shape[idx]
+
+            # after merging sharding, we check if there're multiple
+            # sharding on the same mesh dim.
+            merged_sharding_for_dim = dim_to_sharding[dim]
+            if merged_sharding_for_dim != -1:
+                if (
+                    merged_sharding_for_dim in seen_shardings
+                    and dim != seen_shardings[merged_sharding_for_dim]
+                ):
+                    needs_reshard = True
+                    seen_shardings[merged_sharding_for_dim] += dim
+                else:
+                    seen_shardings[merged_sharding_for_dim] = dim
+
+    if pending_sums_counter and not linearity:
+        # return reshard suggestion with no pending sum, because we already properly
+        # merge the sharding, this reshard suggestion is legit to use
+        return _gen_reshard_suggestions(
+            op_schema, input_dims, input_specs, dim_to_sharding, []
+        )
+    else:
+        # It's a op that support linearity, but not all input arguments are partial
+        # we fail the sharding propagation with suggestion to make all inputs be
+        # partial on the corresponding mesh dim (all inputs should be partial for
+        # the mesh dims in order to execute locally and delay the sum reduction)
+        for value in pending_sums_counter.values():
+            if value != len(input_specs):
+                needs_reshard = True
+
+    for mesh_dim, dims in seen_shardings.items():
+        if len(dims) > 1:
+            # we found different input dims are being sharded on the same mesh dim
+            # in order to perform local op computation, we need to reshard inputs
+            # base on some simple heuristics, now we simply pick the one with least comm
+            # volume. (i.e. the input with least size)
+            # TODO: consider a more advanced heuristic to pick the best sharding
+            costs = []
+            for d in dims:
+                cost = 0
+                for input_dim, input_spec in zip(input_dims, input_specs):
+                    if (
+                        d in input_dim
+                        and input_spec.dim_map[input_dim.index(d)] == mesh_dim
+                    ):
+                        assert input_spec.tensor_meta is not None
+                        global_shape = input_spec.tensor_meta.shape
+                        local_shape, _ = compute_local_shape_and_global_offset(
+                            global_shape,
+                            input_spec.mesh,
+                            input_spec.placements,
+                            skip_offset=True,
+                        )
+                        cost += prod(local_shape) * input_spec.mesh.size(mesh_dim)
+                # pyrefly: ignore [bad-argument-type]
+                costs.append(cost)
+            d_to_keep_sharding = dims[costs.index(max(costs))]
+            for d in dims:
+                # update dim_to_sharding to keep the sharding of the dim with
+                # highest comm and make the rest of the dims to replicate
+                if d != d_to_keep_sharding:
+                    dim_to_sharding[d] = -1
+
+    pending_sums = list(pending_sums_counter.keys())
+    if needs_reshard:
+        return _gen_reshard_suggestions(
+            op_schema, input_dims, input_specs, dim_to_sharding, pending_sums
+        )
+
+    # generate output pending sum if a dim is sharded, and it appears in input
+    # but not output
+    for dim, shard_on_mesh in dim_to_sharding.items():
+        if dim not in output_dims[0] and shard_on_mesh != -1:
+            pending_sums.append(shard_on_mesh)
+
+    # if no need to reshard, we directly generate the output sharding
+    output_dim_map = []
+    output_shape = []
+    for dim in output_dim:
+        if dim == "1":
+            # find output dim that is a singleton dimension, mark sharding and shape
+            output_dim_map.append(-1)
+            output_shape.append(1)
+        else:
+            output_dim_map.append(dim_to_sharding[dim])
+            output_shape.append(dim_to_size[dim])
+
+    # XXX: since we still need to have intermediate shape calculation, we need
+    # to pass in the shape here. We should remove this once sharding decomp works
+    # for ops like addmm
+    assert input_specs[0].tensor_meta is not None
+    tensor_meta = TensorMeta(
+        torch.Size(output_shape),
+        input_specs[0].tensor_meta.stride,
+        input_specs[0].tensor_meta.dtype,
+    )
+    return OutputSharding(
+        DTensorSpec.from_dim_map(
+            input_specs[0].mesh,
+            output_dim_map,
+            pending_sums,
+            tensor_meta=tensor_meta,
+        )
+    )
+
+
+def pointwise_rule(op_schema: OpSchema, linearity: bool = False) -> OutputSharding:
+    """
+    Propagate the sharding for pointwise operations.
+
+    Examples:
+        ij,ij->ij - addition/mul
+        ij,j->ij - broadcasted addition
+    """
+    alphabet = string.ascii_lowercase
+    # find the max_dim first in case we need to broadcasting
+    input_specs = op_schema.args_spec
+    max_dim = max(input.ndim for input in input_specs)
+    dimchars = []
+    singleton_counter: list[int] = [0] * max_dim
+    for input in input_specs:
+        start_dim = max_dim - input.ndim
+        p = alphabet[start_dim:max_dim]
+        # handle the "broadcasting to a common shape case"
+        # see https://pytorch.org/docs/stable/notes/broadcasting.html
+        # If any of the dimensions is singleton dimension (i.e. 1).
+        # we mark the dim char as a special "1" to distinguish with
+        # the non-singleton dimension, so that sharding propagation
+        # should just ignore the singleton dimension.
+        if len(input_specs) > 1:
+            for i in range(max_dim):
+                if i < start_dim:
+                    # treat the leading miss dim chars as singleton
+                    singleton_counter[i] += 1
+                elif input.shape[i - start_dim] == 1:
+                    # mark singleton dim char as a special "1" in einop rule
+                    singleton_counter[i] += 1
+                    p = _replace_char_in_str(p, "1", (i - start_dim))
+
+        dimchars.append(p)
+    out_dimchars = alphabet[:max_dim]
+    # check if we replace the all inputs dim char with singleton dimension,
+    # if we replace all inputs, we also need to replace the output dimension.
+    for output_dim_idx in range(len(out_dimchars)):
+        if singleton_counter[output_dim_idx] == len(input_specs):
+            out_dimchars = _replace_char_in_str(out_dimchars, "1", output_dim_idx)
+
+    fmt = f"{','.join(p for p in dimchars)}->{out_dimchars}"
+
+    enforce_sharding: dict[str, int] = {}
+    if op_schema.is_inplace_op():
+        follow_spec = op_schema.args_spec[0]
+        enforce_sharding.update(zip(out_dimchars, follow_spec.dim_map))
+    elif op_schema.is_out_variant_op():
+        follow_spec = cast(DTensorSpec, op_schema.kwargs_schema["out"])
+        enforce_sharding.update(zip(out_dimchars, follow_spec.dim_map))
+
+    return einop_rule(
+        fmt,
+        op_schema,
+        linearity=linearity,
+        enforce_sharding=enforce_sharding,
+    )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_conv_ops.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_conv_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f456d505c12789f82e6c16aabf2692e871d3dfc
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_conv_ops.py
@@ -0,0 +1,127 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# implement matrix related ops for distributed tensor
+
+import torch
+from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
+from torch.distributed.tensor._op_schema import OpSchema, OutputSharding
+from torch.distributed.tensor._ops.registration import register_prop_rule
+
+
+aten = torch.ops.aten
+
+
+@register_prop_rule(aten.convolution.default)
+def convolution_rules(op_schema: OpSchema) -> OutputSharding:
+    (
+        input_spec,
+        weight_spec,
+        bias_spec,
+        stride,
+        padding,
+        dilation,
+        _transposed,
+        _output_padding,
+        _groups,
+    ) = op_schema.args_schema
+
+    assert isinstance(input_spec, DTensorSpec)
+    assert isinstance(weight_spec, DTensorSpec)
+    # bias_spec can be None (optional parameter in aten.convolution schema)
+    if bias_spec is not None:
+        assert isinstance(bias_spec, DTensorSpec)
+    assert input_spec.tensor_meta is not None
+    assert weight_spec.tensor_meta is not None
+    in_shape = input_spec.tensor_meta.shape
+    weight_shape = weight_spec.tensor_meta.shape
+    assert isinstance(stride, list), f"stride must be list, got {type(stride)}"
+    assert isinstance(padding, list), f"padding must be list, got {type(padding)}"
+    assert isinstance(dilation, list), f"dilation must be list, got {type(dilation)}"
+    # weight_shape might not be torch.Size in all cases (e.g., SymIntArrayRef during tracing)
+    # so we don't assert its type, just use it
+    out_conv_shape = [
+        (d + 2 * padding[i] - dilation[i] * (weight_shape[i + 1] - 1) - 1) // stride[i]
+        + 1
+        for (i, d) in enumerate(in_shape[2:])
+    ]
+    output_shape = [in_shape[0], weight_shape[0]] + out_conv_shape
+    output_stride = [1]
+    for i in range(1, len(output_shape)):
+        output_stride.insert(0, output_stride[0] * output_shape[-i])
+    output_dim_map = input_spec.dim_map
+    pending_sums = input_spec.sums
+
+    tensor_meta = TensorMeta(
+        torch.Size(output_shape),
+        tuple(output_stride),
+        input_spec.tensor_meta.dtype,
+    )
+    return OutputSharding(
+        DTensorSpec.from_dim_map(
+            input_spec.mesh,
+            output_dim_map,
+            pending_sums,
+            tensor_meta=tensor_meta,
+        )
+    )
+
+
+@register_prop_rule(aten.convolution_backward.default)
+def convolution_backward_rules(op_schema: OpSchema) -> OutputSharding:
+    input_spec = op_schema.args_schema[0]
+    (
+        grad_output_spec,
+        input_spec,
+        weight_spec,
+        bias_shape_opt,
+        _stride,
+        _padding,
+        _dilation,
+        _transposed,
+        _output_padding,
+        _groups,
+        _output_mask,
+    ) = op_schema.args_schema
+
+    assert isinstance(grad_output_spec, DTensorSpec)
+    assert isinstance(input_spec, DTensorSpec)
+    assert isinstance(weight_spec, DTensorSpec)
+    # bias_shape_opt can be None (optional parameter in aten.convolution_backward schema)
+    if bias_shape_opt is not None:
+        assert isinstance(bias_shape_opt, list)
+    assert input_spec.tensor_meta is not None
+    weight_tensor_meta = weight_spec.tensor_meta
+
+    # Only create bias_tensor_meta if bias_shape_opt is not None
+    if bias_shape_opt is not None:
+        bias_tensor_meta = TensorMeta(
+            torch.Size(bias_shape_opt),
+            (1,),
+            input_spec.tensor_meta.dtype,
+        )
+    else:
+        bias_tensor_meta = None
+
+    grad_input_spec = input_spec
+    grad_weight_spec = DTensorSpec.from_dim_map(
+        input_spec.mesh,
+        [-1, -1, -1, -1],
+        [0],
+        tensor_meta=weight_tensor_meta,
+    )
+
+    # Only create grad_bias_spec if we have bias_tensor_meta
+    if bias_tensor_meta is not None:
+        grad_bias_spec = DTensorSpec.from_dim_map(
+            input_spec.mesh,
+            [-1],
+            [0],
+            tensor_meta=bias_tensor_meta,
+        )
+    else:
+        grad_bias_spec = None
+
+    # TODO: actually the output_mask is not respected here, we should
+    # set the corresponding spec to `None` if the output_mask is not `False`
+    # for a certain output Tensor. This also applies to the conv handler
+    # in torch/distributed/tensor/_tp_conv.py
+    return OutputSharding([grad_input_spec, grad_weight_spec, grad_bias_spec])
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_einsum_strategy.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_einsum_strategy.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d46ede21f97bdf8539e73e14eab3a5697402d8e
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_einsum_strategy.py
@@ -0,0 +1,186 @@
+import itertools
+from dataclasses import dataclass
+
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor._dtensor_spec import DTensorSpec
+from torch.distributed.tensor._op_schema import OpSpec, OpStrategy
+from torch.distributed.tensor.placement_types import (
+    Partial,
+    Placement,
+    Replicate,
+    Shard,
+)
+
+
+@dataclass
+class EinsumDims:
+    contracting_dims: list[str]
+    batch_dims: list[str]
+    lhs_out_only_dims: list[str]
+    rhs_out_only_dims: list[str]
+
+    @classmethod
+    def parse_equation(cls, equation: str) -> tuple[list[str], str]:
+        # parse einop equation and extract arg specs
+        """
+        Parse the einsum equation str to input dim chars and output dim char
+        """
+        inputs, outputs = equation.split("->")
+        input_dims, output_dims = inputs.split(","), outputs.split(",")
+
+        # NOTE: only support at most two inputs, and single output
+        # extend to support more inputs if needed in future
+        assert len(input_dims) <= 2, "Only support at most two inputs"
+        assert len(output_dims) == 1, "Only support single output"
+        output_dim = output_dims[0]
+        return input_dims, output_dim
+
+    @classmethod
+    def parse_dims(cls, input_dims: list[str], output_dim: str) -> "EinsumDims":
+        """
+        Parse the dims and extract the contracting, batch, and free dimensions
+        for the left and right hand sides.
+        """
+        dim_char_set: set[str] = set()
+        for input_dim in input_dims:
+            dim_char_set.update(input_dim)
+
+        # get a deterministic order of all dim chars
+        all_dim_chars = sorted(dim_char_set)
+
+        # parse input and output dimensions
+        lhs_out_only_dims, rhs_out_only_dims = [], []
+        batch_dims, contracting_dims = [], []
+
+        for dim_char in all_dim_chars:
+            if dim_char not in output_dim:
+                contracting_dims.append(dim_char)
+            else:
+                is_batch_dim = True
+                for input_dim in input_dims:
+                    is_batch_dim = is_batch_dim and dim_char in input_dim
+
+                if is_batch_dim:
+                    batch_dims.append(dim_char)
+                else:
+                    assert len(input_dims) == 2, (
+                        "free dimension only supported for two inputs!"
+                    )
+                    lhs, rhs = input_dims
+                    if dim_char in lhs:
+                        lhs_out_only_dims.append(dim_char)
+                    elif dim_char in rhs:
+                        rhs_out_only_dims.append(dim_char)
+                    else:
+                        raise RuntimeError("Invalid dimension character")
+
+        return cls(
+            contracting_dims=contracting_dims,
+            batch_dims=batch_dims,
+            lhs_out_only_dims=lhs_out_only_dims,
+            rhs_out_only_dims=rhs_out_only_dims,
+        )
+
+
+def gen_einsum_strategies(
+    equation: str,
+    mesh: DeviceMesh,
+    *,
+    linearity: bool = False,
+) -> OpStrategy:
+    """
+    Generate a strategy list for the ops that follow einsum style notation.
+
+    In principle, each mesh dim is independent of other device mesh dim when we
+    generate strategies. So we generate strategy over each device mesh dim and
+    do product combination on all mesh dims. We basically follow the below rule
+    for each device mesh dim:
+
+    1. Shard on contracting dim: When both inputs shard on contracting dim over
+       the same device dim. The result will be Partial over that device dim.
+
+    2. Shard on noncontracting dim:
+        2.1: Shard on batch dim: output, both inputs all should shard on batch
+        dim.
+        2.2: Shard on lhs only dim or rhs only dim: both output and lhs or rhs
+        input should shard on this free dim.
+
+    3. Linearity (Partial): If enabled, set Partial on output and inputs over
+       the same device mesh dim.
+    """
+    # parse einop equation and extract dims
+    input_dims, output_dim = EinsumDims.parse_equation(equation)
+    edims = EinsumDims.parse_dims(input_dims, output_dim)
+    all_mesh_dim_strategies = []
+
+    # generate strategies for each mesh dim and do cartesian product for final strategy. E.g., for a 2D mesh, we can have [P(),R,R]
+    strategies_over_one_mesh_dim = []
+
+    # placement list stores placements of [output, input1, input2, ...]
+    # first we always have replicate all for inputs and output
+    placement_list: list[Placement] = [Replicate()] * (len(input_dims) + 1)
+    strategies_over_one_mesh_dim.append(placement_list)
+
+    # split batch dim
+    for batch_dim in edims.batch_dims:
+        output_batch_dim = output_dim.index(batch_dim)
+        placement_list = [Shard(output_batch_dim)]
+        for input_dim in input_dims:
+            input_batch_dim = input_dim.index(batch_dim)
+            placement_list.append(Shard(input_batch_dim))
+
+        strategies_over_one_mesh_dim.append(placement_list)
+
+    # split contracting dim
+    for contracting_dim in edims.contracting_dims:
+        # Contracting dim can shard on same device axis for both inputs. This
+        # results in the output being Partial on that device axis. For example:
+        # bmk_{x},k_{x}n -> bmn{Ux} (becomes partial over device axis x)
+        placement_list = [Partial()]
+        for input_dim in input_dims:
+            input_contracting_dim = input_dim.index(contracting_dim)
+            placement_list.append(Shard(input_contracting_dim))
+
+        strategies_over_one_mesh_dim.append(placement_list)
+
+    # split lhs free dim
+    for lhs_dim in edims.lhs_out_only_dims:
+        lhs_free_dim_output = output_dim.index(lhs_dim)
+        lhs_free_dim_input = input_dims[0].index(lhs_dim)
+        # this means split the lhs input and output
+        # i.e. S(0), R -> S(0)
+        lhs_placement_list: list[Placement] = [
+            Shard(lhs_free_dim_output),
+            Shard(lhs_free_dim_input),
+            Replicate(),
+        ]
+        strategies_over_one_mesh_dim.append(lhs_placement_list)
+
+    # split rhs free dim
+    for rhs_dim in edims.rhs_out_only_dims:
+        rhs_free_dim_output = output_dim.index(rhs_dim)
+        rhs_free_dim_input = input_dims[1].index(rhs_dim)
+        rhs_placement_list: list[Placement] = [
+            Shard(rhs_free_dim_output),
+            Replicate(),
+            Shard(rhs_free_dim_input),
+        ]
+        strategies_over_one_mesh_dim.append(rhs_placement_list)
+
+    # linearity strategy
+    if linearity:
+        linearity_placement_list: list[Placement] = [Partial()]
+        for _ in input_dims:
+            linearity_placement_list.append(Partial())
+        strategies_over_one_mesh_dim.append(linearity_placement_list)
+
+    # generate strategies for entire mesh
+    all_mesh_dim_strategies = [strategies_over_one_mesh_dim] * mesh.ndim
+    strategy_combs = itertools.product(*all_mesh_dim_strategies)
+    all_strategies = []
+    for strategy_comb in strategy_combs:
+        spec_list = [DTensorSpec(mesh, tuple(specs)) for specs in zip(*strategy_comb)]
+        strat = OpSpec(output_specs=spec_list[0], input_specs=spec_list[1:])
+        all_strategies.append(strat)
+
+    return OpStrategy(all_strategies)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_embedding_ops.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_embedding_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7c4abf353be5430c3ff827077b0c6226840bb40
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_embedding_ops.py
@@ -0,0 +1,111 @@
+# mypy: allow-untyped-defs
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# implement matrix related ops for distributed tensor
+from typing import cast
+
+import torch
+from torch.distributed.tensor._op_schema import (
+    OpSchema,
+    OpStrategy,
+    PlacementList,
+    StrategyType,
+)
+from torch.distributed.tensor._ops.registration import register_op_strategy
+from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy
+from torch.distributed.tensor.placement_types import (
+    MaskPartial,
+    Partial,
+    Replicate,
+    Shard,
+)
+
+
+aten = torch.ops.aten
+
+
+@register_op_strategy(aten.embedding.default)
+def embedding_strategy(op_schema: OpSchema) -> StrategyType:
+    """
+    This strategy handles embedding op. We have two possible embedding shardings:
+    rowwise and colwise
+    """
+    weight_strategy = cast(OpStrategy, op_schema.args_schema[0])
+    indices_strategy = cast(OpStrategy, op_schema.args_schema[1])
+    mesh = op_schema.get_mesh_from_args()
+
+    weight_shape = weight_strategy.shape
+    indices_shape = indices_strategy.shape
+    output_emd_dim = len(indices_shape)
+
+    single_mesh_dim_strategies = []
+
+    # placement list stores placements of [output, weight, input_indices]
+    # first we always have replicate all for inputs and output
+    all_replicate: PlacementList = [Replicate()] * 3
+    single_mesh_dim_strategies.append(all_replicate)
+
+    # colwise sharding, output shard on last dim, weight shard on dim 1, input replicate
+    colwise_sharding: PlacementList = [Shard(output_emd_dim), Shard(1), Replicate()]
+    single_mesh_dim_strategies.append(colwise_sharding)
+
+    # rowwise sharding, output is embedding partial, weight shard on dim 0, input accepts embedding partial
+    embedding_partial_placement = MaskPartial(offset_shape=weight_shape, offset_dim=0)
+
+    # NOTE we want to reuse the same mask partial placement so that we can reuse the same mask that generates
+    # from the input indices and use it for output reduction
+    rowwise_sharding: PlacementList = [
+        embedding_partial_placement,
+        Shard(0),
+        embedding_partial_placement,
+    ]
+    single_mesh_dim_strategies.append(rowwise_sharding)
+
+    # batch dim sharding, weight replicated, input can shard on any dim, output follows input
+    for input_dim in range(len(indices_shape)):
+        batch_sharding: PlacementList = [
+            Shard(input_dim),
+            Replicate(),
+            Shard(input_dim),
+        ]
+        single_mesh_dim_strategies.append(batch_sharding)
+
+    return expand_to_full_mesh_op_strategy(mesh, op_schema, single_mesh_dim_strategies)
+
+
+@register_op_strategy(aten.embedding_dense_backward.default)
+def embedding_dense_backward_strategy(op_schema: OpSchema) -> StrategyType:
+    """
+    This strategy handles embedding op. We have two possible embedding shardings:
+    rowwise and colwise
+    """
+    grad_out_strategy = cast(OpStrategy, op_schema.args_schema[0])
+    indices_strategy = cast(OpStrategy, op_schema.args_schema[1])
+    mesh = op_schema.get_mesh_from_args()
+
+    grad_out_shape = grad_out_strategy.shape
+    indices_shape = indices_strategy.shape
+    grad_out_ndim = len(grad_out_shape)
+
+    single_mesh_dim_strategies = []
+
+    # placement list stores placements of [output, weight, input_indices]
+    # first we always have replicate all for inputs and output
+    all_replicate: PlacementList = [Replicate()] * 3
+    single_mesh_dim_strategies.append(all_replicate)
+
+    # colwise sharding backward, grad_out shard on last dim, input replicate,
+    # weight grad shard colwise
+    colwise_sharding: PlacementList = [Shard(1), Shard(grad_out_ndim - 1), Replicate()]
+    single_mesh_dim_strategies.append(colwise_sharding)
+
+    # batch dim sharding, weight replicated, grad_out/input have same sharding
+    # that can shard on any dim, weight grad partial
+    for input_dim in range(len(indices_shape)):
+        batch_sharding: PlacementList = [Partial(), Shard(input_dim), Shard(input_dim)]
+        single_mesh_dim_strategies.append(batch_sharding)
+
+    # grad_out partial, input replicate, weight grad keep partial
+    partial_sharding: PlacementList = [Partial(), Partial(), Replicate()]
+    single_mesh_dim_strategies.append(partial_sharding)
+
+    return expand_to_full_mesh_op_strategy(mesh, op_schema, single_mesh_dim_strategies)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_mask_buffer.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_mask_buffer.py
new file mode 100644
index 0000000000000000000000000000000000000000..26b0a713db42c89043c01ae148a4dfd4e1b62d1a
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_mask_buffer.py
@@ -0,0 +1,43 @@
+# mypy: allow-untyped-defs
+# Copyright (c) Meta Platforms, Inc. and affiliates
+from dataclasses import dataclass
+
+import torch
+
+
+@dataclass
+class MaskBuffer:
+    data: torch.Tensor | None = None
+    # refcount allows shared usage of the MaskBuffer, as long as all users have the same data
+    refcount: int = 0
+
+    def materialize_mask(self, mask):
+        if self.refcount == 0:
+            self.data = mask
+        else:
+            assert self.data is not None
+            if not torch.equal(self.data, mask):
+                raise RuntimeError(
+                    "MaskBuffer has been materialized with conflicting data"
+                )
+        self.refcount += 1
+
+    def release_mask(self):
+        if self.refcount == 0 or self.data is None:
+            raise RuntimeError("MaskBuffer has not been materialized")
+        self.refcount -= 1
+        if self.refcount == 0:
+            self.data = None
+
+    def apply_mask(self, tensor):
+        if self.refcount == 0 or self.data is None:
+            raise RuntimeError("MaskBuffer has not been materialized")
+
+        # NOTE: MaskPartial is being used by the embedding op and the gather op.
+        # For gather, the mask has the same dimension as the output tensor, whereas
+        # the output of the embedding op has an additional dimension compare to the input,
+        # hence the output masking logic below having two different cases.
+        if tensor.ndim == self.data.ndim:
+            tensor[self.data] = 0.0
+        else:
+            tensor[self.data, :] = 0.0
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_math_ops.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_math_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..23fbd92bf99e4c4e1e9cb10ebde39a1918dfc12b
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_math_ops.py
@@ -0,0 +1,1406 @@
+# mypy: allow-untyped-defs
+# Copyright (c) Meta Platforms, Inc. and affiliates
+import math
+from collections.abc import Sequence
+from dataclasses import dataclass
+from enum import Enum
+from typing import cast, Union
+
+import torch
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor._dtensor_spec import DTensorSpec
+from torch.distributed.tensor._op_schema import (
+    OpSchema,
+    OpSpec,
+    OpStrategy,
+    PlacementList,
+    RuntimeSchemaInfo,
+    TupleStrategy,
+)
+from torch.distributed.tensor._ops.registration import register_op_strategy
+from torch.distributed.tensor._ops.utils import (
+    as_list,
+    expand_to_full_mesh_op_strategy,
+    generate_redistribute_costs,
+    is_tensor_evenly_shardable,
+    is_tensor_evenly_shardable_on_dim,
+    normalize_dim,
+    normalize_dims,
+)
+from torch.distributed.tensor._utils import normalize_to_torch_size
+from torch.distributed.tensor.placement_types import (
+    _StridedShard,
+    Partial,
+    Placement,
+    Replicate,
+    Shard,
+)
+
+
+aten = torch.ops.aten
+
+
+class Reduction(Enum):
+    NONE = 0
+    MEAN = 1
+    SUM = 2
+
+
+@dataclass(frozen=True)
+class NormReduction:
+    norm_type: int | float | str
+
+
+ReductionOpType = Union[NormReduction, str]
+
+
+@dataclass(frozen=True)
+class _NormPartial(Partial):
+    """
+    This placement is used for partial vector norm.
+
+    For p-norms (where p not inf or -inf), the p-norm over n elements computes
+        (sum_i x_i^p)^(1/p)
+    where the sum is from i=1 to n. The reduction op is the p-norm itself.
+    For example, consider 2 ranks, a (4,) tensor sharded on dim-0, and 2-norm:
+        Rank 0: [t1, t2] | Rank 1: [t3, t4]
+    After computing 2-norm per gradient (partial placement):
+        Rank 0: [sqrt(t1^2 + t2^2)] | Rank 1: [sqrt(t3^2 + t4^2)]
+    Converting from partial to replicate wants to ultimately get:
+        Rank 0/1: [sqrt(t1^2 + t2^2 + t3^2 + t4^2)]
+    This can be achieved by computing 2-norm on each rank's result. This holds
+    similarly for inf and -inf norm. For 0-norm, the reduction op is sum.
+    """
+
+    norm_type: int | float | str = 2
+
+    def __init__(self, norm_type: int | float | str = 2):
+        reduce_op = None
+        if norm_type in (float("inf"), "inf"):
+            reduce_op = "max"
+        elif norm_type in (float("-inf"), "-inf"):
+            reduce_op = "min"
+        elif isinstance(norm_type, (int, float)):
+            reduce_op = "sum"
+        else:
+            raise NotImplementedError(f"Unsupported norm type: {norm_type}")
+
+        super().__init__(reduce_op)
+        object.__setattr__(self, "norm_type", norm_type)
+
+    def _partition_value(
+        self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
+    ) -> torch.Tensor:
+        """
+        For example, consider 4 ranks, a (3,) replicated tensor, and 2-norm:
+            Ranks 0 and 1: sqrt(t1^2 + t2^2 + t3^3)
+        To convert from replicated to partial, we want f(x) such that
+            sqrt(t1^2 + t2^2 + t3^3) = sqrt(4f(t1)^2 + 4f(t2)^2 + 4f(t3)^2)
+                                     = sqrt(4) sqrt(f(t1)^2 + f(t2)^2 + f(t3)^2).
+        One such f(x) is f(x) = x / sqrt(4). This generalizes to d ranks and
+        p-norm as f(x) = x / d^(1/p).
+        """
+        if self.reduce_op in ("max", "min"):
+            return tensor
+        elif self.reduce_op == "sum":
+            if self.norm_type == 0:
+                raise NotImplementedError(f"Unsupported norm type:: {self.norm_type}")
+            elif self.norm_type == 1:
+                return tensor / mesh.size(mesh_dim)
+            if not isinstance(self.norm_type, (int, float)):
+                raise AssertionError(
+                    f"Expected int or float, got {type(self.norm_type)}"
+                )
+            return tensor / math.pow(mesh.size(mesh_dim), 1 / self.norm_type)
+        raise NotImplementedError(self.reduce_op)
+
+    def _reduce_shard_value(
+        self,
+        tensor: torch.Tensor,
+        mesh: DeviceMesh,
+        mesh_dim: int,
+        shard_spec: Placement,
+    ) -> torch.Tensor:
+        if not isinstance(shard_spec, Shard):
+            raise AssertionError(f"Expected Shard, got {type(shard_spec)}")
+        tensor = self._pre_reduce_transform(tensor)
+        reduced_tensor = super()._reduce_shard_value(tensor, mesh, mesh_dim, shard_spec)
+        return self._post_reduce_transform(reduced_tensor)
+
+    def _reduce_value(
+        self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
+    ) -> torch.Tensor:
+        tensor = self._pre_reduce_transform(tensor)
+        reduced_tensor = super()._reduce_value(tensor, mesh, mesh_dim)
+        return self._post_reduce_transform(reduced_tensor)
+
+    def _pre_reduce_transform(self, tensor: torch.Tensor) -> torch.Tensor:
+        if self.reduce_op == "sum":
+            if not isinstance(self.norm_type, (int, float)):
+                raise AssertionError(
+                    f"Expected int or float, got {type(self.norm_type)}"
+                )
+            if self.norm_type != 0 and self.norm_type != 1:
+                # pyrefly: ignore [unsupported-operation]
+                return tensor**self.norm_type
+        return tensor
+
+    def _post_reduce_transform(self, tensor: torch.Tensor) -> torch.Tensor:
+        if self.reduce_op == "sum":
+            if not isinstance(self.norm_type, (int, float)):
+                raise AssertionError(
+                    f"Expected int or float, got {type(self.norm_type)}"
+                )
+            if self.norm_type != 0 and self.norm_type != 1:
+                # pyrefly: ignore [unsupported-operation]
+                return tensor ** (1.0 / self.norm_type)
+        return tensor
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, _NormPartial):
+            return False
+        return self.norm_type == other.norm_type
+
+    def __hash__(self) -> int:
+        return 1 + hash(self.norm_type)
+
+    def __repr__(self) -> str:
+        """
+        machine readable representation of the _NormPartial placement
+        """
+        return f"_NormPartial(reduce_op={self.reduce_op}, norm_type={self.norm_type})"
+
+    def __str__(self) -> str:
+        """human readable representation of the _NormPartial placement"""
+        return f"_NormP({self.reduce_op}, {self.norm_type})"
+
+
+def _infer_reduction_dims(dims_arg: object, ndim: int) -> list[int] | None:
+    if dims_arg is None:
+        return None
+    dims = cast(list[int], as_list(dims_arg))
+    dims = cast(list[int], normalize_dims(dims, ndim))
+    empty_dims = [[0], [-1], []]
+    if ndim == 0 and dims_arg in empty_dims:
+        return None
+    return dims
+
+
+def _infer_reduce_dims_map(
+    reduction_dims: list[int], input_ndim: int, keep_dim=False
+) -> list[int]:
+    reduction_dims_map = []
+    new_dim_count = 0
+    for input_dim in range(input_ndim):
+        if input_dim in reduction_dims and not keep_dim:
+            # if input dim in reduction dims, mark it as -1
+            reduction_dims_map.append(-1)
+        else:
+            # otherwise mark it as the new dim
+            reduction_dims_map.append(new_dim_count)
+            new_dim_count += 1
+
+    return reduction_dims_map
+
+
+def _replicate_dims_start_at(
+    placements: Sequence[Placement], start_dim: int = 0
+) -> tuple[Placement, ...]:
+    new_placements: list[Placement] = []
+    for p in placements:
+        if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim):
+            new_placements.append(Replicate())  # make it replicate
+        else:
+            new_placements.append(p)  # keep the placement
+    return tuple(new_placements)
+
+
+# return new_placements which align with placements but skip the skipped_dim
+def _skip_dim(
+    placements: tuple[Placement, ...], skipped_dim: int
+) -> tuple[Placement, ...]:
+    new_placements: list[Placement] = []
+    for p in placements:
+        if isinstance(p, Shard) and p.dim >= skipped_dim:
+            new_placements.append(Shard(p.dim - 1))
+        else:
+            new_placements.append(p)
+    return tuple(new_placements)
+
+
+def replicate_reduction_dims(
+    placements: tuple[Placement, ...], reduction_dims: list[int]
+) -> tuple[Placement, ...]:
+    # replicate the reduction dims if not reduction_linear
+    new_placements: list[Placement] = []
+
+    for p in placements:
+        if p.is_partial():
+            new_placements.append(Replicate())
+        elif isinstance(p, Shard) and p.dim in reduction_dims:
+            new_placements.append(Replicate())
+        else:
+            new_placements.append(p)
+
+    return tuple(new_placements)
+
+
+def map_placements_after_reduction(
+    placements: tuple[Placement, ...],
+    reduction_dims: list[int],
+    reduction_dims_map: list[int],
+    reduction_op: ReductionOpType,
+) -> tuple[Placement, ...]:
+    """
+    Map each placement based on the output shape after reduction.
+    """
+    new_placements: list[Placement] = []
+    for placement in placements:
+        if isinstance(placement, (Replicate, Partial)):
+            new_placements.append(placement)
+        else:
+            if not isinstance(placement, Shard | _StridedShard):
+                raise AssertionError(
+                    f"Expected Shard/_StridedShard, got {type(placement)}"
+                )
+            shard_dim = placement.dim
+            new_shard_dim = reduction_dims_map[shard_dim]
+            if new_shard_dim == -1 or shard_dim in reduction_dims:
+                # if new_shard_dim collapsed or its in the reduction dims
+                # (i.e. for the case where keepdims=True), we generate partial
+                new_placements.append(get_placement_from_reduction_op(reduction_op))
+            else:
+                if isinstance(placement, Shard):
+                    new_placements.append(Shard(new_shard_dim))
+                else:
+                    new_placements.append(
+                        _StridedShard(
+                            new_shard_dim, split_factor=placement.split_factor
+                        )
+                    )
+    return tuple(new_placements)
+
+
+def get_placement_from_reduction_op(reduction_op: ReductionOpType) -> Placement:
+    if isinstance(reduction_op, NormReduction):
+        return _NormPartial(norm_type=reduction_op.norm_type)
+    return Partial(reduction_op)
+
+
+def common_reduction_strategy(
+    input_strategy: OpStrategy,
+    reduce_dims: list[int],
+    keep_dim: bool = False,
+    reduction_linear: bool = True,
+    reduction_op: ReductionOpType = "sum",
+) -> OpStrategy:
+    """
+    reduction_linear means that the reduction `f` follows this rule:
+        f([f(a), f(b)]) = f([a, b])
+
+    reduction linear should be super set of linearity.
+    """
+    # by default follow reduction input strategy
+    reduction_strategy = OpStrategy([])
+
+    for op_spec in input_strategy.strategies:
+        if reduction_op == "avg":
+            output_spec = op_spec.output_spec
+            local_shape = list(output_spec.tensor_meta.shape)  # type:ignore[union-attr]
+            for dim in reduce_dims:
+                if not is_tensor_evenly_shardable_on_dim(local_shape, output_spec, dim):
+                    # reduce(avg) is not linear for unevenly sharded tensors
+                    reduction_linear = False
+                    break
+
+        for p in op_spec.output_spec.placements:
+            # when the partial reduction op matches the global reduction op,
+            # we can delay redistribution (i.e max, max)
+            if isinstance(p, Partial) and p.reduce_op != reduction_op:
+                reduction_linear = False
+                break
+
+        if not reduction_linear:
+            # input placements for this strategy should clear out pending sum and sharding
+            # on the reduction dimension
+            input_placements = replicate_reduction_dims(
+                op_spec.output_spec.placements, reduce_dims
+            )
+        else:
+            input_placements = op_spec.output_spec.placements
+
+        input_spec = DTensorSpec(
+            mesh=input_strategy.mesh,
+            placements=input_placements,
+            tensor_meta=op_spec.output_spec.tensor_meta,
+        )
+
+        reduce_dims_map = _infer_reduce_dims_map(reduce_dims, input_spec.ndim, keep_dim)
+        out_placements = map_placements_after_reduction(
+            input_spec.placements, reduce_dims, reduce_dims_map, reduction_op
+        )
+        redistribute_cost = [generate_redistribute_costs(input_strategy, input_spec)]
+        reduction_strategy.strategies.append(
+            OpSpec(
+                output_specs=DTensorSpec(
+                    mesh=input_strategy.mesh,
+                    placements=out_placements,
+                ),
+                input_specs=(input_spec,),
+                redistribute_cost=redistribute_cost,
+            )
+        )
+
+    return reduction_strategy
+
+
+LINEAR_REDUCTION_OP_MAP = {
+    aten.all.default: "product",
+    aten.all.dim: "product",
+    aten.sum.default: "sum",
+    aten.sum.dim_IntList: "sum",
+    aten.any.default: "sum",
+    aten.any.dim: "sum",
+    aten.any.out: "sum",
+    # These are only valid when there is no padding
+    aten.prod.default: "product",
+    aten.prod.dim_int: "product",
+    aten.prod.int_out: "product",
+    # avg is only linear when there is no padding
+    aten.mean.default: "avg",
+    aten.mean.dim: "avg",
+    aten.mean.out: "avg",
+    aten.max.default: "max",
+    aten.max.dim: "max",
+    aten.max.out: "max",
+    aten.min.default: "min",
+    aten.min.dim: "min",
+    aten.min.out: "min",
+    aten.amax.default: "max",
+    aten.amax.out: "max",
+    aten.amin.default: "min",
+    aten.amin.out: "min",
+    # argmax and argmin is using custom hanndler leveraging linear reduction of max and min
+    aten.argmax.default: "max",
+    aten.argmin.default: "min",
+}
+
+
+@register_op_strategy(
+    list(LINEAR_REDUCTION_OP_MAP.keys()), schema_info=RuntimeSchemaInfo(1)
+)
+def linear_reduction_strategy(op_schema: OpSchema) -> OpStrategy:
+    args_schema = op_schema.args_schema
+    input_strategy = args_schema[0]
+    if not isinstance(input_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}")
+
+    dims = None
+    if len(op_schema.args_schema) > 1:
+        dims = _infer_reduction_dims(args_schema[1], input_strategy.ndim)
+
+    reduce_dims = list(range(input_strategy.ndim)) if dims is None else dims
+
+    keep_dim = len(op_schema.args_schema) > 2 and bool(op_schema.args_schema[2])
+    reduction_op = LINEAR_REDUCTION_OP_MAP[op_schema.op]
+    return common_reduction_strategy(
+        input_strategy,
+        reduce_dims,
+        keep_dim=keep_dim,
+        reduction_linear=True,
+        reduction_op=reduction_op,
+    )
+
+
+@register_op_strategy(aten.cumsum.default, schema_info=RuntimeSchemaInfo(1))
+def cumsum_strategy(op_schema: OpSchema) -> OpStrategy:
+    args_schema = op_schema.args_schema
+    input_strategy = args_schema[0]
+    if not isinstance(input_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}")
+    dim = args_schema[1]
+    if not isinstance(dim, int):
+        raise AssertionError(f"Expected int, got {type(dim)}")
+
+    return common_reduction_strategy(
+        input_strategy, [dim], keep_dim=True, reduction_linear=False
+    )
+
+
+@register_op_strategy(
+    [
+        aten.std.correction,
+        aten.std.correction_out,
+        aten.var.correction,
+        aten.var.correction_out,
+    ],
+    schema_info=RuntimeSchemaInfo(1, ["keepdim"]),
+)
+def std_var_reduction_strategy(op_schema: OpSchema) -> OpStrategy:
+    args_schema = op_schema.args_schema
+    input_strategy = args_schema[0]
+    if not isinstance(input_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}")
+    dims = None
+    if len(op_schema.args_schema) > 1:
+        dims = _infer_reduction_dims(args_schema[1], input_strategy.ndim)
+
+    reduce_dims = list(range(input_strategy.ndim)) if dims is None else dims
+
+    keep_dim = cast(bool, op_schema.kwargs_schema.get("keepdim", False))
+    return common_reduction_strategy(
+        input_strategy, reduce_dims, keep_dim=keep_dim, reduction_linear=False
+    )
+
+
+@register_op_strategy(
+    [aten.linalg_vector_norm.default], schema_info=RuntimeSchemaInfo(1)
+)
+def vector_norm_strategy(op_schema: OpSchema) -> OpStrategy:
+    args_schema = op_schema.args_schema
+    input_strategy = args_schema[0]
+    if not isinstance(input_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}")
+
+    norm_type = args_schema[1] if len(args_schema) > 1 else 2
+    if not isinstance(norm_type, (int, float, str)):
+        raise AssertionError(f"Expected int, float, or str, got {type(norm_type)}")
+    dim = args_schema[2] if len(args_schema) > 2 else None
+    keepdim = args_schema[3] if len(args_schema) > 3 else False
+    dims = _infer_reduction_dims(dim, input_strategy.ndim)
+    reduce_dims = list(range(input_strategy.ndim)) if dims is None else dims
+    return common_reduction_strategy(
+        input_strategy,
+        reduce_dims,
+        keep_dim=cast(bool, keepdim),
+        reduction_linear=True,
+        reduction_op=NormReduction(norm_type),
+    )
+
+
+@register_op_strategy(
+    [aten._foreach_norm.Scalar], schema_info=RuntimeSchemaInfo(1, needs_pytree=True)
+)
+def foreach_norm_strategy(op_schema: OpSchema) -> TupleStrategy:
+    args_schema = op_schema.args_schema
+    input_tuple_strategy = args_schema[0]
+    if not isinstance(input_tuple_strategy, TupleStrategy):
+        raise AssertionError(
+            f"Expected TupleStrategy, got {type(input_tuple_strategy)}"
+        )
+    norm_type = args_schema[1] if len(args_schema) > 1 else 2
+    if not isinstance(norm_type, (int, float, str)):
+        raise AssertionError(f"Expected int, float, or str, got {type(norm_type)}")
+    output_tuple_strategy_children: list[OpStrategy] = []
+    for op_strategy in input_tuple_strategy.children:
+        if not isinstance(op_strategy, OpStrategy):
+            raise AssertionError(f"Expected OpStrategy, got {type(op_strategy)}")
+        reduce_dims = list(range(op_strategy.ndim))
+        output_strategy = common_reduction_strategy(
+            op_strategy,
+            reduce_dims,
+            reduction_linear=True,
+            reduction_op=NormReduction(norm_type),
+        )
+        output_tuple_strategy_children.append(output_strategy)
+    return TupleStrategy(output_tuple_strategy_children)
+
+
+@register_op_strategy(
+    [aten._foreach_max.default], schema_info=RuntimeSchemaInfo(1, needs_pytree=True)
+)
+def foreach_max_strategy(op_schema: OpSchema) -> TupleStrategy:
+    """
+    Strategy for _foreach_max, which reduces each tensor in a list to its maximum value.
+    """
+    args_schema = op_schema.args_schema
+    input_tuple_strategy = args_schema[0]
+    if not isinstance(input_tuple_strategy, TupleStrategy):
+        raise AssertionError(
+            f"Expected TupleStrategy, got {type(input_tuple_strategy)}"
+        )
+    output_tuple_strategy_children: list[OpStrategy] = []
+    for op_strategy in input_tuple_strategy.children:
+        if not isinstance(op_strategy, OpStrategy):
+            raise AssertionError(f"Expected OpStrategy, got {type(op_strategy)}")
+        # Reduce all dimensions to get a scalar
+        reduce_dims = list(range(op_strategy.ndim))
+        output_strategy = common_reduction_strategy(
+            op_strategy,
+            reduce_dims,
+            reduction_linear=True,
+            reduction_op="max",
+        )
+        output_tuple_strategy_children.append(output_strategy)
+    return TupleStrategy(output_tuple_strategy_children)
+
+
+@register_op_strategy(
+    [
+        aten._linalg_svd.default,
+        aten.linalg_qr.default,
+        # TODO: The diagonal ops can have an improved sharding strategy for
+        # shard placements that does not require redistributing to replicate.
+        aten.diagonal_copy.default,
+        aten.diag_embed.default,
+        aten.diag.default,
+        aten.diagonal.default,
+        aten.tril.default,
+        aten.triu.default,
+        aten._linalg_eigh.default,
+        aten.upsample_bicubic2d.default,
+        aten.upsample_bilinear2d.default,
+        aten.upsample_linear1d.default,
+        aten.upsample_nearest2d.default,
+        aten.upsample_trilinear3d.default,
+        # TODO: support the full F.interpolate set of options.
+    ],
+    schema_info=RuntimeSchemaInfo(1),
+)
+def linalg_replicate_strategy(op_schema: OpSchema) -> OpStrategy:
+    """
+    Since we do not have a simple way to compute some linear algebra operations
+    like SVD or QR decomposition, always fall back to replicate.
+    """
+    args_schema = op_schema.args_schema
+    input_strategy = args_schema[0]
+    if not isinstance(input_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}")
+    mesh = input_strategy.mesh
+
+    output_strategies: list[OpSpec] = []
+    for placement_strategy in input_strategy.strategies:
+        replicate_placements = tuple(Replicate() for _ in range(mesh.ndim))
+        replicate_spec = DTensorSpec(
+            mesh=mesh,
+            placements=replicate_placements,
+            tensor_meta=placement_strategy.output_spec.tensor_meta,
+        )
+        redistribute_cost = [
+            generate_redistribute_costs(input_strategy, replicate_spec)
+        ]
+        replicate_strategy = OpSpec(
+            output_specs=replicate_spec,
+            input_specs=(replicate_spec,),
+            redistribute_cost=redistribute_cost,
+        )
+        output_strategies.append(replicate_strategy)
+    return OpStrategy(output_strategies)
+
+
+@register_op_strategy(
+    [aten._log_softmax.default, aten._softmax.default, aten._safe_softmax.default],
+    schema_info=RuntimeSchemaInfo(1),
+)
+def softmax_strategy(op_schema: OpSchema) -> OpStrategy:
+    input_strategy, softmax_dim, *_ = op_schema.args_schema
+    input_strategy = cast(OpStrategy, input_strategy)
+
+    softmax_dim = cast(int, softmax_dim)
+    softmax_dim = normalize_dim(softmax_dim, input_strategy.ndim)
+
+    output_strategy = OpStrategy([])
+    for input_placement_strategy in input_strategy.strategies:
+        redistribute_costs = []
+        input_src_spec = input_placement_strategy.output_spec
+
+        # make sure input is replicated along the softmax dim
+        input_target_spec = DTensorSpec(
+            mesh=input_strategy.mesh,
+            placements=replicate_reduction_dims(
+                input_src_spec.placements, [softmax_dim]
+            ),
+            tensor_meta=input_src_spec.tensor_meta,
+        )
+        redistribute_costs.append(
+            generate_redistribute_costs(input_strategy, input_target_spec)
+        )
+        output_target_spec = input_target_spec
+        output_strategy.strategies.append(
+            OpSpec(
+                output_specs=output_target_spec,
+                input_specs=[input_target_spec],
+                redistribute_cost=redistribute_costs,
+            )
+        )
+
+    return output_strategy
+
+
+@register_op_strategy(
+    [
+        aten._log_softmax_backward_data.default,
+        aten._softmax_backward_data.default,
+    ],
+    schema_info=RuntimeSchemaInfo(2),
+)
+def softmax_backward_strategy(op_schema: OpSchema) -> OpStrategy:
+    grad_out_strategy, out_strategy, softmax_dim, _ = op_schema.args_schema
+    grad_out_strategy = cast(OpStrategy, grad_out_strategy)
+    out_strategy = cast(OpStrategy, out_strategy)
+    softmax_dim = cast(int, softmax_dim)
+    softmax_dim = normalize_dim(softmax_dim, grad_out_strategy.ndim)
+
+    grad_in_strategy = OpStrategy([])
+    for grad_out_placement_strat, out_placement_strat in zip(
+        grad_out_strategy.strategies, out_strategy.strategies
+    ):
+        # follow the sharding of the grad_out or out depending on which has more shards
+        grad_out_src_spec = grad_out_placement_strat.output_spec
+        out_src_spec = out_placement_strat.output_spec
+        src_spec = (
+            grad_out_src_spec
+            if grad_out_src_spec.num_shards >= out_src_spec.num_shards
+            else out_src_spec
+        )
+
+        # make sure inputs are replicated along the softmax dim
+        tgt_spec = DTensorSpec(
+            mesh=grad_out_strategy.mesh,
+            placements=replicate_reduction_dims(src_spec.placements, [softmax_dim]),
+        )
+        new_grad_out_spec = DTensorSpec(
+            mesh=tgt_spec.mesh,
+            placements=tgt_spec.placements,
+            tensor_meta=grad_out_src_spec.tensor_meta,
+        )
+        new_out_spec = DTensorSpec(
+            mesh=tgt_spec.mesh,
+            placements=tgt_spec.placements,
+            tensor_meta=out_src_spec.tensor_meta,
+        )
+        redist_grad_out_cost = generate_redistribute_costs(grad_out_strategy, tgt_spec)
+        redist_out_cost = generate_redistribute_costs(out_strategy, tgt_spec)
+        grad_in_strategy.strategies.append(
+            OpSpec(
+                output_specs=tgt_spec,
+                input_specs=(new_grad_out_spec, new_out_spec),
+                redistribute_cost=[redist_grad_out_cost, redist_out_cost],
+            )
+        )
+
+    return grad_in_strategy
+
+
+@register_op_strategy(
+    [aten.nll_loss_forward.default, aten.nll_loss2d_forward.default],
+    schema_info=RuntimeSchemaInfo(3),
+)
+def nll_loss_forward_strategy(op_schema: OpSchema) -> OpStrategy:
+    mesh = op_schema.get_mesh_from_args()
+
+    if not len(op_schema.args_schema) == 5:
+        raise AssertionError(f"Expected 5 args, got {len(op_schema.args_schema)}")
+
+    (
+        input_strategy,
+        target_strategy,
+        weight_strategy,
+        reduction,
+        _,
+    ) = op_schema.args_schema
+    input_strategy = cast(OpStrategy, input_strategy)
+    target_strategy = cast(OpStrategy, target_strategy)
+    reduction = cast(int, reduction)
+
+    input_shape = input_strategy.shape
+    channel_dim = 1 if len(input_shape) >= 2 else 0
+
+    output_strategy = OpStrategy([])
+    for idx, input_placement_strategy in enumerate(input_strategy.strategies):
+        op_args_target_specs = []
+        redistribute_costs = []
+
+        # make sure input is replicated along the channel dim
+        input_src_spec = input_placement_strategy.output_spec
+        input_expected_spec = DTensorSpec(
+            mesh=mesh,
+            placements=replicate_reduction_dims(
+                input_src_spec.placements, [channel_dim]
+            ),
+            tensor_meta=input_src_spec.tensor_meta,
+        )
+        op_args_target_specs.append(input_expected_spec)
+        redistribute_costs.append(
+            generate_redistribute_costs(input_strategy, input_expected_spec)
+        )
+
+        # target doesn't have channel dim, and it follows input on other dims
+        target_src_spec = target_strategy.strategies[idx].output_spec
+        target_expected_spec = DTensorSpec(
+            mesh=mesh,
+            placements=_skip_dim(input_expected_spec.placements, channel_dim),
+            tensor_meta=target_src_spec.tensor_meta,
+        )
+        op_args_target_specs.append(target_expected_spec)
+        redistribute_costs.append(
+            generate_redistribute_costs(target_strategy, target_expected_spec)
+        )
+
+        # weight tensor, if given, has to be a Tensor of size input_shape[channel_dim]
+        # make sure it is replicated
+        if weight_strategy is not None:
+            if not isinstance(weight_strategy, OpStrategy):
+                raise AssertionError(
+                    f"Expected OpStrategy, got {type(weight_strategy)}"
+                )
+            weight_src_spec = weight_strategy.strategies[idx].output_spec
+            weight_expected_spec = DTensorSpec(
+                mesh=mesh,
+                placements=_replicate_dims_start_at(weight_src_spec.placements),
+                tensor_meta=weight_src_spec.tensor_meta,
+            )
+            op_args_target_specs.append(weight_expected_spec)
+            redistribute_costs.append(
+                generate_redistribute_costs(weight_strategy, weight_expected_spec)
+            )
+
+        if reduction == Reduction.NONE.value:
+            output_expected_spec = target_expected_spec
+            total_weight_expected_spec = DTensorSpec(
+                mesh=mesh, placements=tuple([Replicate()] * mesh.ndim)
+            )
+        else:
+            if reduction == Reduction.MEAN.value:
+                reduction_op = "avg"
+                if not is_tensor_evenly_shardable(
+                    target_expected_spec.shape, target_expected_spec
+                ):
+                    raise ValueError(
+                        "The intermediate results of nll_loss cannot be evenly sharded, \
+                        resulting in biased mean result."
+                    )
+            else:  # reduction == Reduction.SUM.value:
+                reduction_op = "sum"
+            reduce_dims = list(range(target_expected_spec.ndim))
+            reduce_dims_map = _infer_reduce_dims_map(
+                reduce_dims, target_expected_spec.ndim, keep_dim=False
+            )
+            out_placements = map_placements_after_reduction(
+                target_expected_spec.placements,
+                reduce_dims,
+                reduce_dims_map,
+                reduction_op,
+            )
+            output_expected_spec = DTensorSpec(
+                mesh=mesh,
+                placements=out_placements,
+            )
+
+            # whether reduction is sum or mean, the total weight has to be summed up if not replicated
+            total_weight_placements = map_placements_after_reduction(
+                target_expected_spec.placements,
+                reduce_dims,
+                reduce_dims_map,
+                "sum",
+            )
+            total_weight_expected_spec = DTensorSpec(
+                mesh=mesh,
+                placements=total_weight_placements,
+            )
+
+        output_strategy.strategies.append(
+            OpSpec(
+                output_specs=(output_expected_spec, total_weight_expected_spec),
+                input_specs=op_args_target_specs,
+                redistribute_cost=redistribute_costs,
+            )
+        )
+
+    return output_strategy
+
+
+@register_op_strategy(
+    [aten.nll_loss_backward.default, aten.nll_loss2d_backward.default],
+    schema_info=RuntimeSchemaInfo(4),
+)
+def nll_loss_backward_strategy(op_schema: OpSchema) -> OpStrategy:
+    # backward op does not need to validate the mesh since forward op has already done it
+    mesh = op_schema.get_mesh_from_args(validate=False)
+
+    if not len(op_schema.args_schema) == 7:
+        raise AssertionError(f"Expected 7 args, got {len(op_schema.args_schema)}")
+    (
+        grad_out_strategy,
+        input_strategy,
+        target_strategy,
+        weight_strategy,
+        reduction,
+        _,
+        total_weight_strategy,
+    ) = op_schema.args_schema
+    grad_out_strategy = cast(OpStrategy, grad_out_strategy)
+    input_strategy = cast(OpStrategy, input_strategy)
+    target_strategy = cast(OpStrategy, target_strategy)
+    reduction = cast(int, reduction)
+    total_weight_strategy = cast(OpStrategy, total_weight_strategy)
+
+    input_shape = input_strategy.shape
+    channel_dim = 1 if len(input_shape) >= 2 else 0
+
+    grad_in_strategy = OpStrategy([])
+    for idx, input_placement_strategy in enumerate(input_strategy.strategies):
+        op_args_target_specs = []
+        redistribute_costs = []
+
+        # make sure input is replicated along the channel dim
+        input_src_spec = input_placement_strategy.output_spec
+        input_expected_spec = DTensorSpec(
+            mesh=mesh,
+            placements=replicate_reduction_dims(
+                input_src_spec.placements, [channel_dim]
+            ),
+            tensor_meta=input_src_spec.tensor_meta,
+        )
+        op_args_target_specs.append(input_expected_spec)
+        redistribute_costs.append(
+            generate_redistribute_costs(input_strategy, input_expected_spec)
+        )
+
+        # target doesn't have channel dim, and it follows input on other dims
+        target_src_spec = target_strategy.strategies[idx].output_spec
+        target_expected_spec = DTensorSpec(
+            mesh=mesh,
+            placements=_skip_dim(input_expected_spec.placements, channel_dim),
+            tensor_meta=target_src_spec.tensor_meta,
+        )
+        op_args_target_specs.append(target_expected_spec)
+        redistribute_costs.append(
+            generate_redistribute_costs(target_strategy, target_expected_spec)
+        )
+
+        # grad_out follows target if there is no reduction;
+        # otherwise, it should be a replicated scalar.
+        grad_out_src_spec = grad_out_strategy.strategies[idx].output_spec
+        if reduction == Reduction.NONE.value:
+            grad_out_expected_spec = target_expected_spec
+        else:
+            grad_out_expected_spec = DTensorSpec(
+                mesh=mesh,
+                placements=_replicate_dims_start_at(grad_out_src_spec.placements),
+                tensor_meta=grad_out_src_spec.tensor_meta,
+            )
+        op_args_target_specs.insert(0, grad_out_expected_spec)
+        redistribute_costs.insert(
+            0, generate_redistribute_costs(grad_out_strategy, grad_out_expected_spec)
+        )
+
+        # weight tensor, if given, has to be a Tensor of size input_shape[channel_dim]
+        # make sure it is replicated
+        if weight_strategy is not None:
+            if not isinstance(weight_strategy, OpStrategy):
+                raise AssertionError(
+                    f"Expected OpStrategy, got {type(weight_strategy)}"
+                )
+            weight_src_spec = weight_strategy.strategies[idx].output_spec
+            weight_expected_spec = DTensorSpec(
+                mesh=mesh,
+                placements=_replicate_dims_start_at(weight_src_spec.placements),
+                tensor_meta=weight_src_spec.tensor_meta,
+            )
+            op_args_target_specs.append(weight_expected_spec)
+            redistribute_costs.append(
+                generate_redistribute_costs(weight_strategy, weight_expected_spec)
+            )
+
+        # total_weight should always be replicated
+        total_weight_src_spec = total_weight_strategy.strategies[idx].output_spec
+        total_weight_expected_spec = DTensorSpec(
+            mesh=mesh,
+            placements=_replicate_dims_start_at(total_weight_src_spec.placements),
+            tensor_meta=total_weight_src_spec.tensor_meta,
+        )
+        op_args_target_specs.append(total_weight_expected_spec)
+        redistribute_costs.append(
+            generate_redistribute_costs(
+                total_weight_strategy, total_weight_expected_spec
+            )
+        )
+
+        grad_in_expected_spec = input_expected_spec
+        grad_in_strategy.strategies.append(
+            OpSpec(
+                output_specs=grad_in_expected_spec,
+                input_specs=op_args_target_specs,
+                redistribute_cost=redistribute_costs,
+            )
+        )
+
+    return grad_in_strategy
+
+
+def _common_norm_forward_strategy(
+    op_schema: OpSchema,
+    rms_norm: bool = False,
+) -> OpStrategy:
+    """Common forward strategy logic for layer_norm and rms_norm."""
+    mesh = op_schema.get_mesh_from_args()
+
+    if not rms_norm:
+        # layer_norm args: input, normalized_shape, weight, bias, eps
+        # for None weight and bias, their corresponding objects will
+        # be None as well. layer_norm_strategy returns one OpStrategy
+        # for the triple return values (out, mean, rstd).
+        if not len(op_schema.args_schema) == 5:
+            raise AssertionError(f"Expected 5 args, got {len(op_schema.args_schema)}")
+        (
+            input_strategy,
+            normalized_shape,
+            weight_strategy,
+            bias_strategy,
+            _,
+        ) = op_schema.args_schema
+    else:
+        # rms_norm args: input, normalized_shape, weight, eps
+        if not len(op_schema.args_schema) == 4:
+            raise AssertionError(f"Expected 4 args, got {len(op_schema.args_schema)}")
+        (
+            input_strategy,
+            normalized_shape,
+            weight_strategy,
+            _,
+        ) = op_schema.args_schema
+        bias_strategy = None
+
+    # the current norm implementation requires that all
+    # input DTensor's sharding must be in form of OpStrategy
+    if not isinstance(input_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}")
+    if not isinstance(normalized_shape, (int, Sequence, torch.Size)):
+        raise AssertionError(
+            f"Expected int, Sequence, or torch.Size, got {type(normalized_shape)}"
+        )
+    normalized_size = normalize_to_torch_size(normalized_shape)
+
+    input_ndim = input_strategy.ndim
+    axis = input_ndim - len(normalized_size)
+
+    # we use OpStrategy because the output values (out, mean, rstd)
+    # should have the same placements
+    output_strategy = OpStrategy([])
+    for idx, input_placement_strategy in enumerate(input_strategy.strategies):
+        op_args_target_specs = []
+        redistribute_costs = []
+        input_src_spec = input_placement_strategy.output_spec
+
+        # for the input tensor, we replicate it on the inner dims if necessary
+        # TODO: we can avoid forcing the redistribution once we figure out
+        # how to decompose layer norm
+        input_target_spec = DTensorSpec(
+            mesh=mesh,
+            placements=_replicate_dims_start_at(input_src_spec.placements, axis),
+            tensor_meta=input_src_spec.tensor_meta,
+        )
+        op_args_target_specs.append(input_target_spec)
+        redistribute_costs.append(
+            generate_redistribute_costs(input_strategy, input_target_spec)
+        )
+
+        if weight_strategy is not None:
+            if not isinstance(weight_strategy, OpStrategy):
+                raise AssertionError(
+                    f"Expected OpStrategy, got {type(weight_strategy)}"
+                )
+            weight_src_spec = weight_strategy.strategies[idx].output_spec
+
+            # for the weight tensor, we replicate it on all dims if necessary
+            # TODO: we can avoid forcing the redistribution once we figure out
+            # how to decompose layer norm
+            weight_target_spec = DTensorSpec(
+                mesh=mesh,
+                placements=_replicate_dims_start_at(weight_src_spec.placements),
+                tensor_meta=weight_src_spec.tensor_meta,
+            )
+            op_args_target_specs.append(weight_target_spec)
+            redistribute_costs.append(
+                generate_redistribute_costs(weight_strategy, weight_target_spec)
+            )
+
+        if bias_strategy is not None:
+            if not isinstance(bias_strategy, OpStrategy):
+                raise AssertionError(f"Expected OpStrategy, got {type(bias_strategy)}")
+            bias_src_spec = bias_strategy.strategies[idx].output_spec
+
+            # for the bias tensor, we replicate it on all dims if necessary
+            # TODO: we can avoid forcing the redistribution once we figure out
+            # how to decompose layer norm
+            bias_target_spec = DTensorSpec(
+                mesh=mesh,
+                placements=_replicate_dims_start_at(bias_src_spec.placements),
+                tensor_meta=bias_src_spec.tensor_meta,
+            )
+            op_args_target_specs.append(bias_target_spec)
+            redistribute_costs.append(
+                generate_redistribute_costs(bias_strategy, bias_target_spec)
+            )
+
+        # the output spec is the same as input spec
+        output_target_spec = input_target_spec
+        output_strategy.strategies.append(
+            OpSpec(
+                output_specs=output_target_spec,
+                input_specs=op_args_target_specs,
+                redistribute_cost=redistribute_costs,
+            )
+        )
+
+    return output_strategy
+
+
+@register_op_strategy(
+    [aten.native_layer_norm.default],
+    schema_info=RuntimeSchemaInfo(1),
+)
+def layer_norm_strategy(op_schema: OpSchema) -> OpStrategy:
+    return _common_norm_forward_strategy(op_schema)
+
+
+@register_op_strategy(
+    [aten._fused_rms_norm.default],
+    schema_info=RuntimeSchemaInfo(1),
+)
+def fused_rms_norm_strategy(op_schema: OpSchema) -> OpStrategy:
+    return _common_norm_forward_strategy(op_schema, rms_norm=True)
+
+
+def _common_norm_backward_strategy(
+    op_schema: OpSchema,
+    rms_norm: bool = False,
+) -> OpStrategy:
+    """Common backward strategy logic for layer_norm and rms_norm."""
+    # backward op does not need to validate the mesh since forward op has already done it
+    mesh = op_schema.get_mesh_from_args(validate=False)
+
+    if not rms_norm:
+        # layer_norm args: grad_out, input, normalized_shape, mean, rstd,
+        # weight, bias, output_mask. For None weight and bias, their
+        # corresponding objects will be None as well.
+        if not len(op_schema.args_schema) == 8:
+            raise AssertionError(f"Expected 8 args, got {len(op_schema.args_schema)}")
+        (
+            grad_out_strategy,
+            input_strategy,
+            normalized_shape,
+            mean_strategy,
+            rstd_strategy,
+            weight_strategy,
+            bias_strategy,
+            output_mask,
+        ) = op_schema.args_schema
+    else:
+        # rms_norm args: grad_out, input, normalized_shape, rstd,
+        if not len(op_schema.args_schema) == 6:
+            raise AssertionError(f"Expected 6 args, got {len(op_schema.args_schema)}")
+        (
+            grad_out_strategy,
+            input_strategy,
+            normalized_shape,
+            rstd_strategy,
+            weight_strategy,
+            output_mask,
+        ) = op_schema.args_schema
+        mean_strategy = None
+        bias_strategy = None
+
+    if not isinstance(grad_out_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(grad_out_strategy)}")
+    if not isinstance(input_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}")
+    if not isinstance(rstd_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(rstd_strategy)}")
+    if mean_strategy is not None:
+        if not isinstance(mean_strategy, OpStrategy):
+            raise AssertionError(f"Expected OpStrategy, got {type(mean_strategy)}")
+
+    if not isinstance(normalized_shape, (int, Sequence, torch.Size)):
+        raise AssertionError(
+            f"Expected int, Sequence, or torch.Size, got {type(normalized_shape)}"
+        )
+    normalized_size = normalize_to_torch_size(normalized_shape)
+    input_ndim = input_strategy.ndim
+    axis = input_ndim - len(normalized_size)
+    outer_dims = list(range(axis))
+
+    if not rms_norm:
+        if not (isinstance(output_mask, list) and len(output_mask) == 3):
+            raise AssertionError(
+                f"Expected output_mask to be list of length 3, got {type(output_mask)} "
+                f"of length {len(output_mask) if isinstance(output_mask, list) else 'N/A'}"
+            )
+    else:
+        if not (isinstance(output_mask, list) and len(output_mask) == 2):
+            raise AssertionError(
+                f"Expected output_mask to be list of length 2, got {type(output_mask)} "
+                f"of length {len(output_mask) if isinstance(output_mask, list) else 'N/A'}"
+            )
+
+    # output tuple: (d_input, d_weight[, d_bias])
+    out_tuple_strategy = OpStrategy([])
+    for idx, input_placement_strategy in enumerate(input_strategy.strategies):
+        # args for OpSpec
+        output_specs_list: list[DTensorSpec | None] = []
+        input_specs_list: list[DTensorSpec] = []
+        redistribute_costs = []
+
+        input_src_spec = input_placement_strategy.output_spec
+        # arg: grad_out
+        # TODO: change the strategy to the following rule.
+        # d_input is basically a product of element-wise mul of
+        # grad_out, rstd, and normalized input, among which rstd
+        # and normalized input (x_hat) should have the same sharding
+        # placements, and grad_out's sharding is determined by the
+        # pointwise result of x_hat and weight/bias.
+        # TODO: now grad_out spec follows input spec. we may need
+        # to change it to apply a pointwise rule over grad_out,
+        # input, and weight.
+        grad_out_target_spec = DTensorSpec(
+            mesh=mesh,
+            placements=_replicate_dims_start_at(input_src_spec.placements, axis),
+            tensor_meta=input_src_spec.tensor_meta,
+        )
+        input_specs_list.append(grad_out_target_spec)
+        redistribute_costs.append(
+            generate_redistribute_costs(grad_out_strategy, grad_out_target_spec)
+        )
+        output_specs_list.append(grad_out_target_spec if output_mask[0] else None)
+
+        # arg: input
+        input_target_spec = DTensorSpec(
+            mesh=mesh,
+            placements=_replicate_dims_start_at(input_src_spec.placements, axis),
+            tensor_meta=input_src_spec.tensor_meta,
+        )
+        input_specs_list.append(input_target_spec)
+        redistribute_costs.append(
+            generate_redistribute_costs(input_strategy, input_target_spec)
+        )
+
+        # arg: mean
+        if not rms_norm:
+            if mean_strategy is None:
+                raise AssertionError("Expected mean_strategy to not be None")
+            mean_src_spec = mean_strategy.strategies[idx].output_spec
+            input_specs_list.append(mean_src_spec)
+            redistribute_costs.append([0.0 for _ in mean_strategy.strategies])
+
+        # arg: rstd
+        rstd_src_spec = rstd_strategy.strategies[idx].output_spec
+        input_specs_list.append(rstd_src_spec)
+        redistribute_costs.append([0.0 for _ in rstd_strategy.strategies])
+
+        def _add_target_input_spec(strategy) -> DTensorSpec:
+            # shared logic for setting the weight and bias target input specs
+            if not isinstance(strategy, OpStrategy):
+                raise AssertionError(f"Expected OpStrategy, got {type(strategy)}")
+            src_spec = strategy.strategies[idx].output_spec
+            # no need to redistribute since they should be replicated in forward pass
+            input_specs_list.append(src_spec)
+            redistribute_costs.append([0.0 for _ in strategy.strategies])
+            return src_spec
+
+        # arg: weight
+        # d_weight = sum(grad_out * (input - mean) / rstd, outer_dim, keepdim=False)
+        # For RMS norm, mean is 0, so it's just: sum(grad_out * input / rstd, outer_dim, keepdim=False)
+        if weight_strategy is not None:
+            weight_src_spec = _add_target_input_spec(weight_strategy)
+            # TODO: now d_weight spec follows input spec w/ a reduction.
+            # we may need to change to a pointwise rule over grad_out and
+            # input, then apply a reduction.
+            inp_placements = _replicate_dims_start_at(input_src_spec.placements, axis)
+            reduce_dims_map = _infer_reduce_dims_map(
+                outer_dims, input_src_spec.ndim, False
+            )
+            out_placements = map_placements_after_reduction(
+                inp_placements, outer_dims, reduce_dims_map, "sum"
+            )
+            weight_out_spec = DTensorSpec(
+                mesh=mesh,
+                placements=out_placements,
+                tensor_meta=weight_src_spec.tensor_meta,
+            )
+            output_specs_list.append(weight_out_spec if output_mask[1] else None)
+        else:
+            if not rms_norm:
+                error_msg = "output_mask[1] should not be `True` while weight argument is `None` in native_layer_norm_backward."
+            else:
+                error_msg = "output_mask[1] should not be `True` while weight argument is `None` in _fused_rms_norm_backward."
+            if output_mask[1] is not False:
+                raise AssertionError(error_msg)
+            output_specs_list.append(None)
+
+        # arg: bias
+        # d_bias = sum(grad_out, outer_dim, keepdim=False)
+        if not rms_norm:
+            if bias_strategy is not None:
+                bias_src_spec = _add_target_input_spec(bias_strategy)
+                # d_bias spec follows a reduction over grad_out
+                inp_placements = _replicate_dims_start_at(
+                    grad_out_target_spec.placements, axis
+                )
+                reduce_dims_map = _infer_reduce_dims_map(
+                    outer_dims, grad_out_target_spec.ndim, False
+                )
+                out_placements = map_placements_after_reduction(
+                    inp_placements, outer_dims, reduce_dims_map, "sum"
+                )
+                bias_out_spec = DTensorSpec(
+                    mesh=mesh,
+                    placements=out_placements,
+                    tensor_meta=bias_src_spec.tensor_meta,
+                )
+                output_specs_list.append(bias_out_spec if output_mask[2] else None)
+            else:
+                if output_mask[2] is not False:
+                    raise AssertionError(
+                        "output_mask[2] should not be `True` while bias argument is `None` in native_layer_norm_backward."
+                    )
+                output_specs_list.append(None)
+
+        out_tuple_strategy.strategies.append(
+            OpSpec(
+                output_specs=tuple(output_specs_list),
+                input_specs=input_specs_list,
+                redistribute_cost=redistribute_costs,
+            )
+        )
+
+    return out_tuple_strategy
+
+
+@register_op_strategy(
+    [aten.native_layer_norm_backward.default],
+    schema_info=RuntimeSchemaInfo(2),
+)
+def layer_norm_bwd_strategy(op_schema: OpSchema) -> OpStrategy:
+    return _common_norm_backward_strategy(op_schema)
+
+
+@register_op_strategy(
+    [aten._fused_rms_norm_backward.default],
+    schema_info=RuntimeSchemaInfo(2),
+)
+def fused_rms_norm_bwd_strategy(op_schema: OpSchema) -> OpStrategy:
+    return _common_norm_backward_strategy(op_schema, rms_norm=True)
+
+
+def sort_strategy(op_schema: OpSchema, sort_dim: int) -> OpStrategy:
+    input_strategy = cast(OpStrategy, op_schema.args_schema[0])
+    sort_dim = normalize_dim(sort_dim, input_strategy.ndim)
+    single_mesh_dim_strategies = []
+    all_replicate: PlacementList = [Replicate()] * 3
+    single_mesh_dim_strategies.append(all_replicate)
+    for dim in range(input_strategy.ndim):
+        if dim != sort_dim:
+            dim_shardings: PlacementList = [Shard(dim)] * 3
+            single_mesh_dim_strategies.append(dim_shardings)
+    return expand_to_full_mesh_op_strategy(
+        input_strategy.mesh, op_schema, single_mesh_dim_strategies, input_index=2
+    )
+
+
+@register_op_strategy(
+    [aten.topk.default],
+    schema_info=RuntimeSchemaInfo(2),
+)
+def topk_strategy(op_schema: OpSchema) -> OpStrategy:
+    topk_dim = (
+        cast(int, op_schema.args_schema[2]) if len(op_schema.args_schema) > 2 else -1
+    )
+    return sort_strategy(op_schema, topk_dim)
+
+
+@register_op_strategy(
+    aten.sort.default,
+    schema_info=RuntimeSchemaInfo(
+        1,
+    ),
+)
+def sort_default_strategy(op_schema: OpSchema) -> OpStrategy:
+    # mostly copy paste from topk_strategy
+    input_strategy = op_schema.args_schema[0]
+    if not isinstance(input_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}")
+    sort_dim = -1
+    if len(op_schema.args_schema) > 1:
+        sort_dim = cast(int, op_schema.args_schema[1])
+    return sort_strategy(op_schema, sort_dim)
+
+
+@register_op_strategy(
+    aten.sort.stable,
+    schema_info=RuntimeSchemaInfo(
+        1,
+        static_kwargkey=["dim", "descending", "stable"],
+    ),
+)
+def sort_stable_strategy(op_schema: OpSchema) -> OpStrategy:
+    # mostly copy paste from topk_strategy
+    input_strategy = op_schema.args_schema[0]
+    if not isinstance(input_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}")
+    sort_dim = -1
+    if "dim" in op_schema.kwargs_schema:
+        sort_dim = cast(int, op_schema.kwargs_schema["dim"])
+    return sort_strategy(op_schema, sort_dim)
+
+
+@register_op_strategy(
+    [aten.histc.default],
+    # strategy choice depends on the value of 'min' and 'max' kwargs, which are position 2 and 3
+    schema_info=RuntimeSchemaInfo(2),
+)
+def histc_strategy(op_schema: OpSchema) -> OpStrategy:
+    input_strategy = cast(OpStrategy, op_schema.args_schema[0])
+    single_mesh_dim_strategies: list[PlacementList] = []
+    single_mesh_dim_strategies.append([Replicate(), Replicate()])
+
+    # histc can support sharded input and partial output on any input dim, provided the min and max
+    # values are user-specified.  If not user-specified, the true min and max of the data in each local
+    # tensor will be used to compute bin boundaries, which will not be the same across ranks, leading to
+    # an incorrect final result
+    if len(op_schema.args_schema) == 4:
+        for dim in range(input_strategy.ndim):
+            dim_shardings: PlacementList = [Partial(), Shard(dim)]
+            single_mesh_dim_strategies.append(dim_shardings)
+
+    return expand_to_full_mesh_op_strategy(
+        input_strategy.mesh, op_schema, single_mesh_dim_strategies
+    )
+
+
+@register_op_strategy(
+    [aten.logsumexp.default],
+    schema_info=RuntimeSchemaInfo(
+        # static_argnum is the position where non-Tensor args beings.
+        static_argnum=1,
+        # static_kwargkey is the name of kwargs to hash (which determines
+        # whether sharding prop can be cached).
+        static_kwargkey=["keepdim"],
+    ),
+)
+def logsumexp_strategy(op_schema: OpSchema) -> OpStrategy:
+    """Implements the sharding propagation strategy for logsumexp."""
+
+    # args_schema contains all but the DTensor args (e.g., dim, keepdim).
+    args_schema = op_schema.args_schema
+    if not len(args_schema) > 1:
+        raise AssertionError(
+            f"Expected more than 1 arg (input and dim are required), got {len(args_schema)}"
+        )
+
+    input_strategy = args_schema[0]
+    if not isinstance(input_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}")
+
+    dims_arg = args_schema[1]
+    reduce_dims = _infer_reduction_dims(dims_arg, input_strategy.ndim)
+    if reduce_dims is None:
+        raise AssertionError("Expected reduce_dims to not be None")
+
+    keep_dim = cast(bool, op_schema.kwargs_schema.get("keepdim", False))
+    return common_reduction_strategy(
+        input_strategy,
+        reduce_dims,
+        keep_dim=keep_dim,
+        reduction_linear=False,
+    )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_matrix_ops.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_matrix_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..c00a44ef8f4f41730bdb4ca0550ffa1808a8fffe
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_matrix_ops.py
@@ -0,0 +1,1087 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# implement matrix related ops for distributed tensor
+
+
+import torch
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
+from torch.distributed.tensor._op_schema import (
+    OpSchema,
+    OpSpec,
+    OpStrategy,
+    PlacementList,
+    RuntimeSchemaInfo,
+)
+from torch.distributed.tensor._ops._einsum_strategy import gen_einsum_strategies
+from torch.distributed.tensor._ops.registration import register_op_strategy
+from torch.distributed.tensor._ops.utils import (
+    expand_to_full_mesh_op_strategy,
+    generate_redistribute_costs,
+    infer_broadcast_dims_map,
+    is_tensor_shardable,
+    map_placements_after_broadcast,
+    prod,
+)
+from torch.distributed.tensor._utils import (
+    compute_local_shape_and_global_offset,
+    compute_local_stride,
+)
+from torch.distributed.tensor.placement_types import (
+    Partial,
+    Placement,
+    Replicate,
+    Shard,
+)
+
+
+aten = torch.ops.aten
+
+
+@register_op_strategy(aten.t.default)
+def transpose_strategy(op_schema: OpSchema) -> OpStrategy:
+    self_strategy = op_schema.args_schema[0]
+    if not isinstance(self_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(self_strategy)}")
+
+    transpose_strategies = []
+    for input_strategy in self_strategy.strategies:
+        input_spec = input_strategy.output_spec
+        # follow the input spec but transpose the Shard placements
+        output_placements = [
+            Shard(1 - p.dim) if isinstance(p, Shard) else p
+            for p in input_spec.placements
+        ]
+        transpose_strategy = OpSpec(
+            output_specs=DTensorSpec(
+                mesh=input_strategy.mesh,
+                placements=tuple(output_placements),
+            ),
+            input_specs=(input_strategy.output_spec,),
+        )
+        transpose_strategies.append(transpose_strategy)
+
+    return OpStrategy(strategies=transpose_strategies)
+
+
+def _mm_like_strategy(
+    mm_equation: str, mesh: DeviceMesh, op_schema: OpSchema
+) -> OpStrategy:
+    self_strategy, mat2_strategy = op_schema.args_schema
+    if not isinstance(self_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(self_strategy)}")
+    if not isinstance(mat2_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(mat2_strategy)}")
+    # generate all possible strategies for mm
+    mm_strategy = gen_einsum_strategies(mm_equation, mesh)
+    # filter out invalid strategies and associate costs
+    strategies = mm_strategy.strategies
+    filtered_strategies = []
+    for strtg in strategies:
+        if strtg.input_specs is None:
+            raise AssertionError(
+                f"Expected input_specs to be not None, got {strtg.input_specs}"
+            )
+        self_spec = strtg.input_specs[0]
+        mat2_spec = strtg.input_specs[1]
+        if is_tensor_shardable(
+            self_strategy.shape, self_spec, allow_unbacked_sharding=True
+        ) and is_tensor_shardable(
+            mat2_strategy.shape, mat2_spec, allow_unbacked_sharding=True
+        ):
+            redistribute_cost = [
+                generate_redistribute_costs(self_strategy, self_spec),
+                generate_redistribute_costs(mat2_strategy, mat2_spec),
+            ]
+            strtg.redistribute_cost = redistribute_cost
+            filtered_strategies.append(strtg)
+
+    mm_strategy.strategies = filtered_strategies
+
+    return mm_strategy
+
+
+def _addmm_like_strategy(
+    mm_equation: str, mesh: DeviceMesh, op_schema: OpSchema
+) -> OpStrategy:
+    self_strategy, mat1_strategy, mat2_strategy = op_schema.args_schema
+    if not isinstance(self_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(self_strategy)}")
+    if not isinstance(mat1_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(mat1_strategy)}")
+    if not isinstance(mat2_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(mat2_strategy)}")
+    self_shape = self_strategy.shape
+    mm_out_shape = torch.Size(
+        [
+            mat2_strategy.shape[-1] if i == len(mat1_strategy.shape) - 1 else dim_size
+            for i, dim_size in enumerate(mat1_strategy.shape)
+        ]
+    )
+    # generate all possible strategies for mm
+    mm_strategy = gen_einsum_strategies(mm_equation, mesh)
+    # filter out invalid strategies and associate costs
+    strategies = mm_strategy.strategies
+    filtered_strategies = []
+    for strtg in strategies:
+        # construct new strategy by consider the self arg
+        if strtg.input_specs is None:
+            raise AssertionError(
+                f"Expected input_specs to be not None, got {strtg.input_specs}"
+            )
+        mat1_spec = strtg.input_specs[0]
+        mat2_spec = strtg.input_specs[1]
+        out_spec = strtg.output_spec
+
+        # self arg's spec should follow the output of mm, but need
+        # to consider broadcast for the self arg
+        broadcast_dims_map = infer_broadcast_dims_map(mm_out_shape, self_shape)
+        self_placements = map_placements_after_broadcast(
+            out_spec.placements, mm_out_shape, broadcast_dims_map
+        )
+        self_spec = DTensorSpec(mesh=mesh, placements=self_placements)
+
+        if is_tensor_shardable(
+            mat1_strategy.shape, mat1_spec, allow_unbacked_sharding=True
+        ) and is_tensor_shardable(
+            mat2_strategy.shape, mat2_spec, allow_unbacked_sharding=True
+        ):
+            # update input specs with new self spec
+            strtg.input_specs = (self_spec, mat1_spec, mat2_spec)
+
+            # associate costs
+            redistribute_cost = [
+                generate_redistribute_costs(self_strategy, self_spec),
+                generate_redistribute_costs(mat1_strategy, mat1_spec),
+                generate_redistribute_costs(mat2_strategy, mat2_spec),
+            ]
+            strtg.redistribute_cost = redistribute_cost
+            filtered_strategies.append(strtg)
+
+    mm_strategy.strategies = filtered_strategies
+
+    return mm_strategy
+
+
+def _scaled_mm_like_strategy(
+    mm_equation: str, mesh: DeviceMesh, op_schema: OpSchema
+) -> OpStrategy:
+    (
+        self_strategy,
+        mat2_strategy,
+        scale_self_strategy,
+        scale_mat2_strategy,
+        bias_strategy,
+        scale_result_strategy,
+        *_,
+    ) = op_schema.args_schema
+    if not isinstance(self_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(self_strategy)}")
+    if not isinstance(mat2_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(mat2_strategy)}")
+    if not isinstance(scale_self_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(scale_self_strategy)}")
+    if not isinstance(scale_mat2_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(scale_mat2_strategy)}")
+    # TODO: add support for these later
+    if bias_strategy is not None:
+        raise AssertionError("_scaled_mm on DTensors doesn't support bias")
+    if scale_result_strategy is not None:
+        raise AssertionError("_scaled_mm on DTensors doesn't support scale_result")
+    # generate all possible strategies for mm
+    mm_strategy = gen_einsum_strategies(mm_equation, mesh)
+    # filter out invalid strategies and associate costs
+    strategies = mm_strategy.strategies
+    filtered_strategies = []
+    for strtg in strategies:
+        if strtg.input_specs is None:
+            raise AssertionError(
+                f"Expected input_specs to be not None, got {strtg.input_specs}"
+            )
+        self_spec = strtg.input_specs[0]
+        mat2_spec = strtg.input_specs[1]
+        # propagate the operands' specs to their scales, except for tensor-wise
+        # scaling which can have any numbers of dims (legacy...), hence sharding
+        # dims won't map. for tensor-wise, anyways, we can only do replication.
+        scale_self_spec = (
+            DTensorSpec(self_spec.mesh, (Replicate(),))
+            if prod(scale_self_strategy.shape) == 1
+            else self_spec
+        )
+        scale_mat2_spec = (
+            DTensorSpec(mat2_spec.mesh, (Replicate(),))
+            if prod(scale_mat2_strategy.shape) == 1
+            else mat2_spec
+        )
+        strtg.input_specs = list(strtg.input_specs) + [scale_self_spec, scale_mat2_spec]
+        if (
+            is_tensor_shardable(
+                self_strategy.shape, self_spec, allow_unbacked_sharding=True
+            )
+            and is_tensor_shardable(
+                mat2_strategy.shape, mat2_spec, allow_unbacked_sharding=True
+            )
+            and is_tensor_shardable(
+                scale_self_strategy.shape, scale_self_spec, allow_unbacked_sharding=True
+            )
+            and is_tensor_shardable(
+                scale_mat2_strategy.shape, scale_mat2_spec, allow_unbacked_sharding=True
+            )
+        ):
+            redistribute_cost = [
+                generate_redistribute_costs(self_strategy, self_spec),
+                generate_redistribute_costs(mat2_strategy, mat2_spec),
+                generate_redistribute_costs(scale_self_strategy, scale_self_spec),
+                generate_redistribute_costs(scale_mat2_strategy, scale_mat2_spec),
+            ]
+            strtg.redistribute_cost = redistribute_cost
+            filtered_strategies.append(strtg)
+
+    mm_strategy.strategies = filtered_strategies
+
+    return mm_strategy
+
+
+@register_op_strategy(aten.dot.default)
+def dot_strategy(op_schema: OpSchema) -> OpStrategy:
+    mesh = op_schema.get_mesh_from_args()
+    return _mm_like_strategy("i,i->", mesh, op_schema)
+
+
+@register_op_strategy(aten.mm.default)
+def mm_strategy(op_schema: OpSchema) -> OpStrategy:
+    mesh = op_schema.get_mesh_from_args()
+    return _mm_like_strategy("mk,kn->mn", mesh, op_schema)
+
+
+@register_op_strategy(aten.addmm.default)
+def addmm_strategy(op_schema: OpSchema) -> OpStrategy:
+    mesh = op_schema.get_mesh_from_args()
+    return _addmm_like_strategy("mk,kn->mn", mesh, op_schema)
+
+
+@register_op_strategy(aten.bmm.default)
+def bmm_strategy(op_schema: OpSchema) -> OpStrategy:
+    mesh = op_schema.get_mesh_from_args()
+    return _mm_like_strategy("bmk,bkn->bmn", mesh, op_schema)
+
+
+@register_op_strategy(aten.baddbmm.default)
+def baddbmm_strategy(op_schema: OpSchema) -> OpStrategy:
+    mesh = op_schema.get_mesh_from_args()
+    return _addmm_like_strategy("bmk,bkn->bmn", mesh, op_schema)
+
+
+@register_op_strategy(aten._scaled_mm.default)
+def scaled_mm_strategy(op_schema: OpSchema) -> OpStrategy:
+    mesh = op_schema.get_mesh_from_args()
+    return _scaled_mm_like_strategy("mk,kn->mn", mesh, op_schema)
+
+
+def _scaled_dot_product_flash_attention_base_strategies(
+    op_schema: OpSchema,
+) -> list[PlacementList]:
+    """Helper that returns list of base placement strategies (without CP)."""
+    return_debug_mask = len(op_schema.args_schema) >= 6 and op_schema.args_schema[5]
+    q_input_strategy = op_schema.args_schema[0]
+    if not isinstance(q_input_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(q_input_strategy)}")
+    # assuming q/k/v have the same shape
+
+    single_mesh_dim_strategies = []
+
+    # placement list stores placements of [outputs, inputs]
+    # in the spda case, we have 3 valid tensor outputs and 3 tensor inputs
+    # first we can always accept full replication for both inputs and outputs
+    all_replicate: PlacementList = [
+        Replicate(),
+        Replicate(),
+        None,  # cum_seq_q
+        None,  # cum_seq_k
+        None,  # max_q
+        None,  # max_k
+        Replicate(),  # rng_state
+        None,  # unused
+        Replicate(),
+        Replicate(),
+        Replicate(),
+        Replicate(),
+    ]
+    single_mesh_dim_strategies.append(all_replicate)
+
+    # second we can accept the sharding pattern of tensor parallelism, which
+    # shard on the num of head dim
+    qkv_sharding = Shard(1)  # num head dim
+    output_sharding = Shard(1)  # num head dim
+    logsumexp_sharding = Shard(1)  # num head dim
+    if return_debug_mask:
+        debug_attn_mask_sharding: Placement = Shard(1)  # num head dim
+    else:
+        # empty debug mask, replicated
+        debug_attn_mask_sharding = Replicate()
+
+    num_heads_dim_sharding: PlacementList = [
+        output_sharding,
+        logsumexp_sharding,
+        None,  # cum_seq_q
+        None,  # cum_seq_k
+        None,  # max_q
+        None,  # max_k
+        Replicate(),  # rng_state
+        None,  # unused
+        debug_attn_mask_sharding,
+        qkv_sharding,
+        qkv_sharding,
+        qkv_sharding,
+    ]
+    single_mesh_dim_strategies.append(num_heads_dim_sharding)
+
+    # Shard on the batch dimension
+    debug_attn_mask_sharding = Shard(0) if return_debug_mask else Replicate()
+    single_mesh_dim_strategies.append(
+        [
+            Shard(0),  # output
+            Shard(0),  # logsumexp
+            None,  # cum_seq_q
+            None,  # cum_seq_k
+            None,  # max_q
+            None,  # max_k
+            Replicate(),  # rng_state
+            None,  # unused
+            debug_attn_mask_sharding,  # debugattn
+            Shard(0),  # q
+            Shard(0),  # k
+            Shard(0),  # v
+        ]
+    )
+    return single_mesh_dim_strategies
+
+
+@register_op_strategy(
+    aten._scaled_dot_product_flash_attention.default, schema_info=RuntimeSchemaInfo(5)
+)
+def scaled_dot_product_flash_attention_strategy(op_schema: OpSchema) -> OpStrategy:
+    # NOTE: currently we only support some simple strategies to support tensor parallelism
+    # TODO: sdpa might be a good candidate for us to explore decomposed sharding propagation
+    # as it involves: matmul, pointwise, reduction ops together.
+
+    mesh = op_schema.get_mesh_from_args()
+    single_mesh_dim_strategies = _scaled_dot_product_flash_attention_base_strategies(
+        op_schema
+    )
+    return expand_to_full_mesh_op_strategy(
+        mesh, op_schema, single_mesh_dim_strategies, input_index=9
+    )
+
+
+def _scaled_dot_product_flash_attention_backward_base_strategies(
+    op_schema: OpSchema,
+) -> list[PlacementList]:
+    """Helper that returns list of base placement strategies (without CP)."""
+    q_input_strategy = op_schema.args_schema[1]
+    if not isinstance(q_input_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(q_input_strategy)}")
+    # assuming q/k/v have the same shape
+
+    tensor_input_indices = [
+        i
+        for i, arg_spec in enumerate(op_schema.args_schema)
+        if isinstance(arg_spec, OpStrategy)
+    ]
+    num_tensor_inputs = len(tensor_input_indices)
+
+    single_mesh_dim_strategies = []
+
+    # placement list stores placements of [outputs, inputs]
+    # in the spda backward case, we have 3 tensor outputs and 6 to 10 tensor inputs
+    # first we can always accept full replication for both inputs and outputs
+    all_replicate: PlacementList = [Replicate()] * (3 + num_tensor_inputs)
+
+    single_mesh_dim_strategies.append(all_replicate)
+
+    # second we can accept the sharding pattern of tensor parallelism, which
+    # shard on the num of head dim
+    grad_output_sharding = Shard(1)  # num head dim
+    qkv_sharding = Shard(1)  # num head dim
+    output_sharding = Shard(1)  # num head dim
+    logsumexp_sharding = Shard(1)  # num head dim
+    grad_qkv_sharding = Shard(1)  # num head dim
+
+    num_heads_dim_sharding: PlacementList = [
+        grad_qkv_sharding,
+        grad_qkv_sharding,
+        grad_qkv_sharding,
+        grad_output_sharding,
+        qkv_sharding,
+        qkv_sharding,
+        qkv_sharding,
+        output_sharding,
+        logsumexp_sharding,
+    ]
+    # accept replicate on the rest tensor inputs, potentially
+    # cum_seq_q, cum_seq_k, philox_seed, philox_offset
+    # at indices 6, 7, 12, 13, respectively
+    num_heads_dim_sharding.extend([Replicate()] * (num_tensor_inputs - 6))
+    single_mesh_dim_strategies.append(num_heads_dim_sharding)
+
+    # Batch sharding
+    batch_dim_sharding: PlacementList = [
+        Shard(0),  # grad_q
+        Shard(0),  # grad_k
+        Shard(0),  # grad_v
+        Shard(0),  # grad_output
+        Shard(0),  # q
+        Shard(0),  # k
+        Shard(0),  # v
+        Shard(0),  # output
+        Shard(0),  # logsumexp
+    ]
+    # accept replicate on the rest tensor inputs, potentially
+    # cum_seq_q, cum_seq_k, philox_seed, philox_offset
+    # at indices 6, 7, 12, 13, respectively
+    batch_dim_sharding.extend([Replicate()] * (num_tensor_inputs - 6))
+    single_mesh_dim_strategies.append(batch_dim_sharding)
+
+    return single_mesh_dim_strategies
+
+
+@register_op_strategy(aten._scaled_dot_product_flash_attention_backward.default)
+def scaled_dot_product_flash_attention_backward_strategy(
+    op_schema: OpSchema,
+) -> OpStrategy:
+    # backward op does not need to validate the mesh since forward op has already done it
+    mesh = op_schema.get_mesh_from_args(validate=False)
+    single_mesh_dim_strategies = (
+        _scaled_dot_product_flash_attention_backward_base_strategies(op_schema)
+    )
+    return expand_to_full_mesh_op_strategy(
+        mesh, op_schema, single_mesh_dim_strategies, input_index=3
+    )
+
+
+@register_op_strategy(aten.constant_pad_nd.default)
+def constant_pad_nd_strategy(op_schema: OpSchema) -> OpStrategy:
+    mesh = op_schema.get_mesh_from_args(validate=False)
+
+    # TODO(d4l3k); implement a more correct strategy for constant_pad_nd
+    return OpStrategy(
+        [
+            OpSpec(
+                output_specs=DTensorSpec(mesh, (Replicate(),)),
+                input_specs=(
+                    DTensorSpec(mesh, (Replicate(),)),
+                    DTensorSpec(mesh, (Replicate(),)),
+                ),
+                redistribute_cost=[[1]],
+            )
+        ]
+    )
+
+
+def _scaled_dot_product_efficient_attention_base_strategies(
+    op_schema: OpSchema,
+) -> list[PlacementList]:
+    """Helper that returns list of base placement strategies (without CP)."""
+    q_input_strategy = op_schema.args_schema[0]
+    if not isinstance(q_input_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(q_input_strategy)}")
+    # assuming q/k/v have the same shape
+
+    has_attn_bias = op_schema.args_schema[3] is not None
+    compute_log_sumexp = op_schema.args_schema[4]
+
+    single_mesh_dim_strategies: list[PlacementList] = []
+
+    # placement list stores placements of [outputs, inputs]
+    # in the spda case, we have 2 valid tensor outputs and 3 or 4 tensor inputs
+    # first we can always accept full replication for both inputs and outputs
+    all_replicate: PlacementList = [
+        Replicate(),
+        Replicate(),
+        None,
+        None,
+        Replicate(),
+        Replicate(),
+        Replicate(),
+    ]
+    if has_attn_bias:
+        all_replicate.append(Replicate())  # attn bias
+
+    single_mesh_dim_strategies.append(all_replicate)
+
+    # second we can accept the sharding pattern of tensor parallelism, which
+    # shard on the heads dimension
+    qkv_sharding = Shard(1)
+    output_sharding = Shard(1)
+    if compute_log_sumexp:
+        logsumexp_sharding: Placement = Shard(1)
+    else:
+        # empty logsumexp, replicated
+        logsumexp_sharding = Replicate()
+
+    num_heads_dim_sharding = [
+        output_sharding,
+        logsumexp_sharding,
+        None,
+        None,
+        qkv_sharding,
+        qkv_sharding,
+        qkv_sharding,
+    ]
+    if has_attn_bias:
+        num_heads_dim_sharding.append(Shard(1))
+    single_mesh_dim_strategies.append(num_heads_dim_sharding)
+
+    # batch sharding
+    if compute_log_sumexp:
+        logsumexp_sharding_dp: Placement = Shard(0)
+    else:
+        # empty logsumexp, replicated
+        logsumexp_sharding_dp = Replicate()
+    batch_sharding = [
+        Shard(0),  # output
+        logsumexp_sharding_dp,  # logsumexp
+        None,  # philox_seed
+        None,  # philox_offset
+        Shard(0),  # q
+        Shard(0),  # k
+        Shard(0),  # v
+    ]
+    if has_attn_bias:
+        batch_sharding.append(Shard(0))
+
+    single_mesh_dim_strategies.append(batch_sharding)
+
+    return single_mesh_dim_strategies
+
+
+@register_op_strategy(
+    aten._scaled_dot_product_efficient_attention.default,
+    schema_info=RuntimeSchemaInfo(4),
+)
+def scaled_dot_product_efficient_attention_strategy(op_schema: OpSchema) -> OpStrategy:
+    # NOTE: currently we only support some simple strategies to support tensor parallelism
+    mesh = op_schema.get_mesh_from_args()
+    single_mesh_dim_strategies = (
+        _scaled_dot_product_efficient_attention_base_strategies(op_schema)
+    )
+    return expand_to_full_mesh_op_strategy(
+        mesh,
+        op_schema,
+        single_mesh_dim_strategies,
+        input_index=4,
+    )
+
+
+def _scaled_dot_product_efficient_attention_backward_base_strategies(
+    op_schema: OpSchema,
+) -> list[PlacementList]:
+    """Helper that returns list of base placement strategies (without CP)."""
+    q_input_strategy = op_schema.args_schema[1]
+    if not isinstance(q_input_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(q_input_strategy)}")
+    # assuming q/k/v have the same shape
+    has_attn_bias = op_schema.args_schema[4] is not None
+
+    single_mesh_dim_strategies = []
+
+    # placement list stores placements of [outputs, inputs]
+    # in the spda backward case, we have 4 tensor outputs and 8 or 9 tensor inputs
+    # NOTE: Output sharding of grad_bias on heads dim if attn_bias is present;
+    #       otherwise grad_bias will be empty and its DTensorSpec will be removed.
+    # first we can always accept full replication for both inputs and outputs
+    all_replicate: PlacementList = [Replicate()] * (12 + has_attn_bias)
+
+    if not has_attn_bias:
+        all_replicate[3] = None  # grad bias is None if attn_bias is not present
+
+    single_mesh_dim_strategies.append(all_replicate)
+
+    # second we can accept the sharding pattern of tensor parallelism, which
+    # shard on the heads dimension
+    grad_output_sharding = Shard(1)
+    qkv_sharding = Shard(1)
+    output_sharding = Shard(1)
+    logsumexp_sharding = Shard(1)
+    grad_qkv_sharding = Shard(1)
+    grad_bias_sharding = Shard(1) if has_attn_bias else None
+
+    num_heads_dim_sharding: PlacementList = [
+        grad_qkv_sharding,
+        grad_qkv_sharding,
+        grad_qkv_sharding,
+        grad_bias_sharding,
+        grad_output_sharding,
+        qkv_sharding,
+        qkv_sharding,
+        qkv_sharding,
+        # the place for optional input attn_bias,
+        output_sharding,
+        logsumexp_sharding,
+    ]
+    # input sharding of attn_bias on heads dim if present
+    if has_attn_bias:
+        num_heads_dim_sharding.insert(8, Shard(1))
+    # accept replicate on the rest scalar tensor inputs
+    # namely philox_seed and philox_offset
+    num_heads_dim_sharding.extend([Replicate(), Replicate()])
+    single_mesh_dim_strategies.append(num_heads_dim_sharding)
+
+    # Shards on batch dim
+    batch_dim_sharding: PlacementList = [
+        Shard(0),  # grad_q
+        Shard(0),  # grad_k
+        Shard(0),  # grad_v
+        Shard(0) if has_attn_bias else None,  # grad_bias
+        Shard(0),  # grad_output
+        Shard(0),  # q
+        Shard(0),  # k
+        Shard(0),  # v
+        Shard(0),  # output
+        Shard(0),  # logsumexp
+    ]
+    # accept replicate on the rest tensor inputs, potentially
+    # cum_seq_q, cum_seq_k, philox_seed, philox_offset
+    # at indices 6, 7, 12, 13, respectively
+    if has_attn_bias:
+        batch_dim_sharding.insert(8, Shard(0))
+    batch_dim_sharding.extend([Replicate(), Replicate()])
+    single_mesh_dim_strategies.append(batch_dim_sharding)
+
+    return single_mesh_dim_strategies
+
+
+@register_op_strategy(aten._scaled_dot_product_efficient_attention_backward.default)
+def scaled_dot_product_efficient_attention_backward_strategy(
+    op_schema: OpSchema,
+) -> OpStrategy:
+    # backward op does not need to validate the mesh since forward op has already done it
+    mesh = op_schema.get_mesh_from_args(validate=False)
+    single_mesh_dim_strategies = (
+        _scaled_dot_product_efficient_attention_backward_base_strategies(op_schema)
+    )
+    return expand_to_full_mesh_op_strategy(
+        mesh,
+        op_schema,
+        single_mesh_dim_strategies,
+        input_index=4,
+    )
+
+
+def _scaled_dot_product_cudnn_attention_base_strategies(
+    op_schema: OpSchema,
+) -> list[PlacementList]:
+    """Helper that returns list of base placement strategies (without CP)."""
+    (
+        query_strategy,  # query
+        _,  # key
+        _,  # value
+        attn_bias_strategy,
+        compute_log_sumexp,  # compute_log_sumexp
+        *rest_args,  # optional args: dropout_p, is_causal, return_debug_mask, scale
+    ) = op_schema.args_schema
+    return_debug_mask = len(op_schema.args_schema) >= 8 and rest_args[2]
+    has_attn_bias = attn_bias_strategy is not None
+    debug_attn_mask_sharding: Placement | None = (
+        Replicate() if return_debug_mask else None
+    )
+
+    if not isinstance(query_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(query_strategy)}")
+    # assuming q/k/v have the same shape
+
+    single_mesh_dim_strategies = []
+
+    # placement list stores placements of [outputs, inputs]
+    # in the spda case, we have 2 valid tensor outputs and 3 tensor inputs
+    # first we can always accept full replication for both inputs and outputs
+    all_replicate: PlacementList = [
+        Replicate(),  # output
+        Replicate(),  # logsumexp
+        None,  # cum_seq_q
+        None,  # cum_seq_k
+        None,  # max_q
+        None,  # max_k
+        None,  # philox_seed
+        None,  # philox_offset
+        # NOTE: debug_attn_mask is not supported by pytorch and is always an empty tensor
+        # https://github.com/pytorch/pytorch/blob/60205b0eb2602317856312a66d955c88334ade0b/aten/src/ATen/native/transformers/cuda/attention.cu#L839-L840
+        debug_attn_mask_sharding,  # debug_attn_mask
+        Replicate(),  # q
+        Replicate(),  # k
+        Replicate(),  # v
+    ]
+    if has_attn_bias:
+        all_replicate.append(Replicate())  # attn bias
+
+    single_mesh_dim_strategies.append(all_replicate)
+
+    # second we can accept the sharding pattern of tensor parallelism, which
+    # shard on the num of head dim
+    tp_sharding = Shard(1)  # num head dim
+    qkv_sharding = tp_sharding
+    output_sharding = tp_sharding
+    logsumexp_sharding = tp_sharding if compute_log_sumexp else Replicate()
+    debug_attn_mask_sharding = tp_sharding if return_debug_mask else None
+
+    num_heads_dim_sharding: PlacementList = [
+        output_sharding,
+        logsumexp_sharding,
+        None,  # cum_seq_q
+        None,  # cum_seq_k
+        None,  # max_q
+        None,  # max_k
+        None,  # philox_seed
+        None,  # philox_offset
+        debug_attn_mask_sharding,
+        qkv_sharding,
+        qkv_sharding,
+        qkv_sharding,
+    ]
+    single_mesh_dim_strategies.append(num_heads_dim_sharding)
+
+    # batch parallelism
+    logsumexp_sharding = Shard(0) if compute_log_sumexp else Replicate()
+    debug_attn_mask_sharding = Shard(0) if return_debug_mask else None
+    batch_dim_sharding: PlacementList = [
+        Shard(0),  # output
+        logsumexp_sharding,
+        None,  # cum_seq_q
+        None,  # cum_seq_k
+        None,  # max_q
+        None,  # max_k
+        None,  # philox_seed
+        None,  # philox_offset
+        debug_attn_mask_sharding,
+        Shard(0),  # q
+        Shard(0),  # k
+        Shard(0),  # v
+    ]
+    single_mesh_dim_strategies.append(batch_dim_sharding)
+
+    return single_mesh_dim_strategies
+
+
+@register_op_strategy(
+    aten._scaled_dot_product_cudnn_attention.default,
+    schema_info=RuntimeSchemaInfo(4),
+)
+def scaled_dot_product_cudnn_attention_strategy(op_schema: OpSchema) -> OpStrategy:
+    mesh = op_schema.get_mesh_from_args()
+    single_mesh_dim_strategies = _scaled_dot_product_cudnn_attention_base_strategies(
+        op_schema
+    )
+    return expand_to_full_mesh_op_strategy(
+        mesh, op_schema, single_mesh_dim_strategies, input_index=9
+    )
+
+
+def _scaled_dot_product_cudnn_attention_backward_base_strategies(
+    op_schema: OpSchema,
+) -> list[PlacementList]:
+    """Helper that returns list of base placement strategies (without CP)."""
+    if len(op_schema.args_schema) < 15:
+        raise AssertionError(
+            f"Expected at least 15 args_schema, got {len(op_schema.args_schema)}"
+        )
+    has_attn_bias = op_schema.args_schema[8] is not None
+    has_scale = len(op_schema.args_schema) >= 16 and False
+
+    query_strategy = op_schema.args_schema[1]
+    if not isinstance(query_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(query_strategy)}")
+    # assuming q/k/v have the same shape
+
+    single_mesh_dim_strategies = []
+
+    # placement list stores placements of [outputs, inputs]
+    # cudnn outputs: (Tensor dq, Tensor dk, Tensor dv)
+    # cudnn inputs: (
+    #   Tensor grad_out,
+    #   Tensor query,
+    #   Tensor key,
+    #   Tensor value,
+    #   Tensor out,
+    #   Tensor logsumexp,
+    #   Tensor philox_seed,
+    #   Tensor philox_offset,
+    #   Tensor attn_bias,
+    #   Tensor cum_seq_q,
+    #   Tensor cum_seq_k,
+    #   SymInt max_q,
+    #   SymInt max_k,
+    #   float dropout_p,
+    #   bool is_causal,
+    #   int? scale,
+    # )
+
+    # case 1: we can always accept full replication for both inputs and outputs
+    all_replicate_out: PlacementList = [
+        Replicate(),  # dq
+        Replicate(),  # dk
+        Replicate(),  # dv
+    ]
+    all_replicate_inp: PlacementList = [Replicate()] * 6
+    all_replicate_inp += [
+        Replicate()
+    ] * 2  # philox_seed, philox_offset is casted to Replicate() in DTensor
+    all_replicate_inp += [Replicate() if has_attn_bias else None]
+    all_replicate_inp += [None] * 6
+    if has_scale:
+        all_replicate_inp.append(None)
+
+    all_replicate: PlacementList = all_replicate_out + all_replicate_inp
+    single_mesh_dim_strategies.append(all_replicate)
+
+    # case 2: we can accept the sharding pattern of tensor parallelism, which
+    #   shards on the num of head dim
+    qkv_sharding = Shard(1)  # num head dim
+    output_sharding = Shard(1)  # num head dim
+    logsumexp_sharding = Shard(1)  # num head dim
+
+    num_heads_dim_sharding_out: PlacementList = [qkv_sharding] * 3
+    num_heads_dim_sharding_inp: PlacementList = [qkv_sharding] * 4
+    num_heads_dim_sharding_inp += [output_sharding]
+    num_heads_dim_sharding_inp += [logsumexp_sharding]
+    num_heads_dim_sharding_inp += [
+        Replicate()
+    ] * 2  # philox_seed, philox_offset is casted to Replicate() in DTensor
+    num_heads_dim_sharding_inp += [Shard(1) if has_attn_bias else None]
+    num_heads_dim_sharding_inp += [None] * 6
+    if has_scale:
+        num_heads_dim_sharding_inp.append(None)
+
+    num_heads_dim_sharding = num_heads_dim_sharding_out + num_heads_dim_sharding_inp
+    single_mesh_dim_strategies.append(num_heads_dim_sharding)
+
+    # case 3: we can accept the sharding pattern of batch parallelism, which
+    #   shards on the batch dimension
+    qkv_sharding = Shard(0)
+    output_sharding = Shard(0)
+    logsumexp_sharding = Shard(0)
+
+    batch_dim_sharding_out: PlacementList = [qkv_sharding] * 3
+    batch_dim_sharding_inp: PlacementList = [qkv_sharding] * 4
+    batch_dim_sharding_inp += [output_sharding]
+    batch_dim_sharding_inp += [logsumexp_sharding]
+    batch_dim_sharding_inp += [
+        Replicate()
+    ] * 2  # philox_seed, philox_offset is casted to Replicate() in DTensor
+    batch_dim_sharding_inp += [Shard(0) if has_attn_bias else None]
+    batch_dim_sharding_inp += [None] * 6
+    if has_scale:
+        batch_dim_sharding_inp.append(None)
+
+    batch_dim_sharding = batch_dim_sharding_out + batch_dim_sharding_inp
+    single_mesh_dim_strategies.append(batch_dim_sharding)
+
+    return single_mesh_dim_strategies
+
+
+@register_op_strategy(aten._scaled_dot_product_cudnn_attention_backward.default)
+def scaled_scaled_dot_product_cudnn_attention_backward_strategy(
+    op_schema: OpSchema,
+) -> OpStrategy:
+    # backward op does not need to validate the mesh since forward op has already done it
+    mesh = op_schema.get_mesh_from_args(validate=False)
+    single_mesh_dim_strategies = (
+        _scaled_dot_product_cudnn_attention_backward_base_strategies(op_schema)
+    )
+    return expand_to_full_mesh_op_strategy(
+        mesh, op_schema, single_mesh_dim_strategies, input_index=3
+    )
+
+
+@register_op_strategy(aten._grouped_mm.default)
+def grouped_mm_strategy(op_schema: OpSchema) -> OpStrategy:
+    mesh = op_schema.get_mesh_from_args()
+
+    mat1_strategy = op_schema.args_schema[0]
+    if not isinstance(mat1_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(mat1_strategy)}")
+    mat2_strategy = op_schema.args_schema[1]
+    if not isinstance(mat2_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(mat2_strategy)}")
+    if len(op_schema.args_schema) > 3:
+        bias_strategy = op_schema.args_schema[3]
+        if bias_strategy is not None:
+            raise AssertionError("grouped_mm doesn't support bias yet")
+
+    single_mesh_dim_strategies = []
+
+    offs_placement = None
+    if len(op_schema.args_schema) > 2 and op_schema.args_schema[2] is not None:
+        offs_placement = Replicate()  # offs should always be replicated
+
+    all_replicate: PlacementList = [
+        Replicate(),
+        Replicate(),  # mat1
+        Replicate(),  # mat2
+        offs_placement,  # offs
+        None,  # bias
+    ]
+    partial_replicate: PlacementList = [
+        Partial(),
+        Partial(),  # mat1
+        Replicate(),  # mat2
+        offs_placement,  # offs
+        None,  # bias
+    ]
+    replicate_partial: PlacementList = [
+        Partial(),
+        Replicate(),  # mat1
+        Partial(),  # mat2
+        offs_placement,  # offs
+        None,  # bias
+    ]
+    single_mesh_dim_strategies = [all_replicate, partial_replicate, replicate_partial]
+
+    if mat1_strategy.ndim == 2 and mat2_strategy.ndim == 3:
+        # rowwise_replicate for 2dx3d not supported
+        replicate_colwise_2x3: PlacementList = [
+            Shard(1),
+            Replicate(),  # mat1
+            Shard(2),  # mat2
+            offs_placement,  # offs
+            None,  # bias
+        ]
+        colwise_rowwise_2x3: PlacementList = [
+            Partial(),
+            Shard(1),  # mat1
+            Shard(1),  # mat2
+            offs_placement,  # offs
+            None,  # bias
+        ]
+        single_mesh_dim_strategies.extend([replicate_colwise_2x3, colwise_rowwise_2x3])
+
+    if mat1_strategy.ndim == 3 and mat2_strategy.ndim == 2:
+        # replicate_colwise for 3dx2d not supported
+        colwise_rowwise_3x2: PlacementList = [
+            Partial(),
+            Shard(2),  # mat1
+            Shard(0),  # mat2
+            offs_placement,  # offs
+            None,  # bias
+        ]
+        rowwise_replicate_3x2: PlacementList = [
+            Shard(0),
+            Shard(1),  # mat1
+            Replicate(),  # mat2
+            offs_placement,  # offs
+            None,  # bias
+        ]
+        single_mesh_dim_strategies.extend([colwise_rowwise_3x2, rowwise_replicate_3x2])
+
+    if mat1_strategy.ndim == 2 and mat2_strategy.ndim == 2:
+        # colwise_rowwise for 2dx2d not supported
+        replicate_colwise_2x2: PlacementList = [
+            Shard(2),
+            Replicate(),  # mat1
+            Shard(1),  # mat2
+            offs_placement,  # offs
+            None,  # bias
+        ]
+        rowwise_replicate_2x2: PlacementList = [
+            Shard(1),
+            Shard(0),  # mat1
+            Replicate(),  # mat2
+            offs_placement,  # offs
+            None,  # bias
+        ]
+        single_mesh_dim_strategies.extend(
+            [replicate_colwise_2x2, rowwise_replicate_2x2]
+        )
+
+    if mat1_strategy.ndim == 3 and mat2_strategy.ndim == 3:
+        replicate_colwise_3x3: PlacementList = [
+            Shard(2),
+            Replicate(),  # mat1
+            Shard(2),  # mat2
+            offs_placement,  # offs
+            None,  # bias
+        ]
+        rowwise_replicate_3x3: PlacementList = [
+            Shard(1),
+            Shard(1),  # mat1
+            Replicate(),  # mat2
+            offs_placement,  # offs
+            None,  # bias
+        ]
+        colwise_rowwise_3x3: PlacementList = [
+            Partial(),
+            Shard(2),  # mat1
+            Shard(1),  # mat2
+            offs_placement,  # offs
+            None,  # bias
+        ]
+        batch_dim_sharding: PlacementList = [
+            Shard(0),
+            Shard(0),  # mat1
+            Shard(0),  # mat2
+            offs_placement,  # offs
+            None,  # bias
+        ]
+        single_mesh_dim_strategies.extend(
+            [
+                replicate_colwise_3x3,
+                rowwise_replicate_3x3,
+                colwise_rowwise_3x3,
+                batch_dim_sharding,
+            ]
+        )
+
+    def valid_grouped_mm_strides(
+        input_specs: list[DTensorSpec], output_specs: tuple[DTensorSpec | None, ...]
+    ) -> bool:
+        # 1. compute the local-tensor shape/strides given this sharding proposal
+        # 2. apply the logic from the groped_mm meta function
+        # UGH the input DTensorSpecs are missing their tensormetas... so i can get them another way
+        def local_meta(spec: OpSpec, placements: tuple[Placement, ...]) -> TensorMeta:
+            if not isinstance(spec.output_specs, DTensorSpec):
+                raise AssertionError(
+                    f"Expected DTensorSpec, got {type(spec.output_specs)}"
+                )
+            if not isinstance(spec.output_specs.tensor_meta, TensorMeta):
+                raise AssertionError(
+                    f"Expected TensorMeta, got {type(spec.output_specs.tensor_meta)}"
+                )
+            meta: TensorMeta = spec.output_specs.tensor_meta
+            local_stride = compute_local_stride(meta.stride, mesh, placements)
+            local_shape, _ = compute_local_shape_and_global_offset(
+                meta.shape, mesh, placements, skip_offset=True
+            )
+            return TensorMeta(torch.Size(local_shape), local_stride, meta.dtype)
+
+        # pyrefly: ignore [missing-attribute]
+        mat1_meta = local_meta(mat1_strategy.strategies[0], input_specs[0].placements)
+        # pyrefly: ignore [missing-attribute]
+        mat2_meta = local_meta(mat2_strategy.strategies[0], input_specs[1].placements)
+
+        def check_valid_strides(meta: TensorMeta) -> bool:
+            # copied from `_meta_grouped_mm_common` in meta_registrations.py
+            end_dim = len(meta.shape) - 1
+            alignment = 16 // meta.dtype.itemsize
+            if meta.stride[end_dim - 1] == 1 and meta.stride[end_dim] >= max(
+                1, meta.shape[end_dim - 1]
+            ):
+                if meta.stride[end_dim] % alignment != 0:
+                    return False
+            elif meta.stride[end_dim] == 1 and meta.stride[end_dim - 1] >= max(
+                1, meta.shape[end_dim]
+            ):
+                if meta.stride[end_dim - 1] % alignment != 0:
+                    return False
+            else:
+                return False
+            return True
+
+        mat1_valid = check_valid_strides(mat1_meta)
+        mat2_valid = check_valid_strides(mat2_meta)
+        return mat1_valid and mat2_valid
+
+    return expand_to_full_mesh_op_strategy(
+        mesh,
+        op_schema,
+        single_mesh_dim_strategies,
+        input_index=1,
+        is_valid_strategy_cb=valid_grouped_mm_strides,
+    )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_pointwise_ops.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_pointwise_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..79030c0d4e28904af6b8d400d8cdd82872d315e3
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_pointwise_ops.py
@@ -0,0 +1,809 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+from collections.abc import Sequence
+from typing import cast
+
+import torch
+from torch.distributed.tensor._dtensor_spec import DTensorSpec
+from torch.distributed.tensor._op_schema import (
+    OpSchema,
+    OpSpec,
+    OpStrategy,
+    RuntimeSchemaInfo,
+    StrategyType,
+    TupleStrategy,
+)
+from torch.distributed.tensor._ops.registration import register_op_strategy
+from torch.distributed.tensor._ops.utils import (
+    generate_redistribute_costs,
+    infer_broadcast_dims_map,
+    map_placements_after_broadcast,
+    normalize_dim,
+)
+from torch.distributed.tensor.placement_types import (
+    _StridedShard,
+    Partial,
+    Placement,
+    Replicate,
+    Shard,
+)
+from torch.utils._typing_utils import not_none
+
+
+aten = torch.ops.aten
+# leave the remaining pointwise_ops list here for convenience,
+# Below ops are some pointwise ops that are yet to be supported,
+# they might not be a complete list.
+# pointwise_ops = [
+#     "fake_quantize_per_channel_affine",
+#     "fake_quantize_per_tensor_affine",
+#     "floor_divide",  # floor_divide is deprecated
+#     "frexp",  # multiple output pointwise op, need to add support
+#     "gradient",  #  need investigation on this op
+#     "imag",  # complex data type only
+#     "quantized_batch_norm",
+#     "quantized_max_pool1d",
+#     "quantized_max_pool2d",
+#     "real",  # complex data type only
+# ]
+
+
+pointwise_ops = [
+    # please keep the entries below alphabetically sorted
+    aten.__ilshift__.Scalar,
+    aten.__ilshift__.Tensor,
+    aten.__irshift__.Scalar,
+    aten.__irshift__.Tensor,
+    aten.__lshift__.Scalar,
+    aten.__lshift__.Tensor,
+    aten.__rshift__.Scalar,
+    aten.__rshift__.Tensor,
+    aten._conj.default,
+    aten.abs.default,
+    aten.abs.out,
+    aten.abs_.default,
+    aten.acos.default,
+    aten.acos.out,
+    aten.acos_.default,
+    aten.acosh.default,
+    aten.acosh.out,
+    aten.acosh_.default,
+    aten.add.Scalar,
+    aten.add.out,
+    aten.add_.Scalar,
+    aten.addcdiv.default,
+    aten.addcdiv.out,
+    aten.addcdiv_.default,
+    aten.addcmul.default,
+    aten.addcmul.out,
+    aten.addcmul_.default,
+    aten.angle.default,
+    aten.angle.out,
+    aten.asin.default,
+    aten.asin.out,
+    aten.asin_.default,
+    aten.asinh.default,
+    aten.asinh.out,
+    aten.asinh_.default,
+    aten.atan.default,
+    aten.atan.out,
+    aten.atan2.default,
+    aten.atan2.out,
+    aten.atan2_.default,
+    aten.atan_.default,
+    aten.atanh.default,
+    aten.atanh.out,
+    aten.atanh_.default,
+    aten.bitwise_and.Scalar,
+    aten.bitwise_and.Scalar_Tensor,
+    aten.bitwise_and.Scalar_out,
+    aten.bitwise_and.Tensor,
+    aten.bitwise_and.Tensor_out,
+    aten.bitwise_and_.Scalar,
+    aten.bitwise_and_.Tensor,
+    aten.bitwise_left_shift.Scalar_Tensor,
+    aten.bitwise_left_shift.Tensor,
+    aten.bitwise_left_shift.Tensor_Scalar,
+    aten.bitwise_left_shift.Tensor_Scalar_out,
+    aten.bitwise_left_shift.Tensor_out,
+    aten.bitwise_left_shift_.Tensor,
+    aten.bitwise_left_shift_.Tensor_Scalar,
+    aten.bitwise_not.default,
+    aten.bitwise_not.out,
+    aten.bitwise_not_.default,
+    aten.bitwise_or.Scalar,
+    aten.bitwise_or.Scalar_Tensor,
+    aten.bitwise_or.Scalar_out,
+    aten.bitwise_or.Tensor,
+    aten.bitwise_or.Tensor_out,
+    aten.bitwise_or_.Scalar,
+    aten.bitwise_or_.Tensor,
+    aten.bitwise_right_shift.Scalar_Tensor,
+    aten.bitwise_right_shift.Tensor,
+    aten.bitwise_right_shift.Tensor_Scalar,
+    aten.bitwise_right_shift.Tensor_Scalar_out,
+    aten.bitwise_right_shift.Tensor_out,
+    aten.bitwise_right_shift_.Tensor,
+    aten.bitwise_right_shift_.Tensor_Scalar,
+    aten.bitwise_xor.Scalar,
+    aten.bitwise_xor.Scalar_Tensor,
+    aten.bitwise_xor.Scalar_out,
+    aten.bitwise_xor.Tensor,
+    aten.bitwise_xor.Tensor_out,
+    aten.bitwise_xor_.Scalar,
+    aten.bitwise_xor_.Tensor,
+    aten.ceil.default,
+    aten.ceil.out,
+    aten.ceil_.default,
+    aten.clamp.default,
+    aten.clamp.Tensor,
+    aten.clamp.out,
+    aten.clamp_.default,
+    aten.clamp_.Tensor,
+    aten.clamp_min.default,
+    aten.clamp_min.Tensor,
+    aten.clamp_max.default,
+    aten.clamp_max.Tensor,
+    aten.clip.default,
+    aten.clip.out,
+    aten.clip_.default,
+    aten.conj_physical.default,
+    aten.conj_physical.out,
+    aten.conj_physical_.default,
+    aten.copysign.Scalar,
+    aten.copysign.Scalar_out,
+    aten.copysign.Tensor,
+    aten.copysign.out,
+    aten.copysign_.Scalar,
+    aten.copysign_.Tensor,
+    aten.cos.default,
+    aten.cos.out,
+    aten.cos_.default,
+    aten.cosh.default,
+    aten.cosh.out,
+    aten.cosh_.default,
+    aten.deg2rad.default,
+    aten.deg2rad.out,
+    aten.deg2rad_.default,
+    aten.digamma.default,
+    aten.digamma.out,
+    aten.digamma_.default,
+    aten.div.Tensor,
+    aten.div.Tensor_mode,
+    aten.div.out,
+    aten.div.out_mode,
+    aten.div_.Tensor,
+    aten.div_.Tensor_mode,
+    aten.eq.Tensor,
+    aten.eq.Tensor_out,
+    aten.eq.Scalar,
+    aten.eq.Scalar_out,
+    aten.erf.default,
+    aten.erf.out,
+    aten.erf_.default,
+    aten.erfc.default,
+    aten.erfc.out,
+    aten.erfc_.default,
+    aten.erfinv.default,
+    aten.erfinv.out,
+    aten.erfinv_.default,
+    aten.exp.default,
+    aten.exp.out,
+    aten.exp2.default,
+    aten.exp2.out,
+    aten.exp2_.default,
+    aten.exp_.default,
+    aten.expm1.default,
+    aten.expm1.out,
+    aten.expm1_.default,
+    aten.float_power.Scalar,
+    aten.float_power.Scalar_out,
+    aten.float_power.Tensor_Scalar,
+    aten.float_power.Tensor_Scalar_out,
+    aten.float_power.Tensor_Tensor,
+    aten.float_power.Tensor_Tensor_out,
+    aten.float_power_.Scalar,
+    aten.float_power_.Tensor,
+    aten.floor.default,
+    aten.floor.out,
+    aten.floor_.default,
+    aten.fmod.Scalar,
+    aten.fmod.Scalar_out,
+    aten.fmod.Tensor,
+    aten.fmod.Tensor_out,
+    aten.fmod_.Scalar,
+    aten.fmod_.Tensor,
+    aten.frac.default,
+    aten.frac.out,
+    aten.frac_.default,
+    aten.ge.Scalar,
+    aten.ge.Tensor,
+    aten.gelu.default,
+    aten.gt.Tensor,
+    aten.gt.Tensor_out,
+    aten.gt.Scalar,
+    aten.gt.Scalar_out,
+    aten.gt.Scalar,
+    aten.gt.Tensor,
+    aten.hypot.default,
+    aten.hypot.out,
+    aten.hypot_.default,
+    aten.i0.default,
+    aten.i0.out,
+    aten.i0_.default,
+    aten.igamma.default,
+    aten.igamma.out,
+    aten.igamma_.default,
+    aten.igammac.default,
+    aten.igammac.out,
+    aten.igammac_.default,
+    aten.isinf.default,
+    aten.isnan.default,
+    aten.isneginf.default,
+    aten.isneginf.out,
+    aten.isposinf.default,
+    aten.isposinf.out,
+    aten.ldexp.default,
+    aten.ldexp.out,
+    aten.ldexp_.default,
+    aten.lt.Tensor,
+    aten.lt.Tensor_out,
+    aten.lt.Scalar,
+    aten.lt.Scalar_out,
+    aten.le.Scalar,
+    aten.le.Tensor,
+    aten.lerp.Scalar,
+    aten.lerp.Scalar_out,
+    aten.lerp.Tensor,
+    aten.lerp.Tensor_out,
+    aten.lerp_.Scalar,
+    aten.lerp_.Tensor,
+    aten.lgamma.default,
+    aten.lgamma.out,
+    aten.lgamma_.default,
+    aten.log.default,
+    aten.log.out,
+    aten.log10.default,
+    aten.log10.out,
+    aten.log10_.default,
+    aten.log1p.default,
+    aten.log1p.out,
+    aten.log1p_.default,
+    aten.log2.default,
+    aten.log2.out,
+    aten.log2_.default,
+    aten.log_.default,
+    aten.logaddexp.default,
+    aten.logaddexp.out,
+    aten.logaddexp2.default,
+    aten.logaddexp2.out,
+    aten.logical_and.default,
+    aten.logical_and.out,
+    aten.logical_and_.default,
+    aten.logical_not.default,
+    aten.logical_not.out,
+    aten.logical_not_.default,
+    aten.logical_or.default,
+    aten.logical_or.out,
+    aten.logical_or_.default,
+    aten.logical_xor.default,
+    aten.logical_xor.out,
+    aten.logical_xor_.default,
+    aten.logit.default,
+    aten.logit.out,
+    aten.logit_.default,
+    aten.masked_fill.Scalar,
+    aten.masked_fill_.Scalar,
+    aten.maximum.default,
+    aten.maximum.out,
+    aten.minimum.default,
+    aten.minimum.out,
+    aten.mul.out,
+    aten.mvlgamma.default,
+    aten.mvlgamma.out,
+    aten.mvlgamma_.default,
+    aten.native_dropout_backward.default,
+    aten.native_dropout_backward.out,
+    aten.nan_to_num.default,
+    aten.nan_to_num.out,
+    aten.nan_to_num_.default,
+    aten.ne.Scalar,
+    aten.neg.default,
+    aten.neg.out,
+    aten.neg_.default,
+    aten.nextafter.default,
+    aten.nextafter.out,
+    aten.nextafter_.default,
+    aten.polygamma.default,
+    aten.polygamma.out,
+    aten.polygamma_.default,
+    aten.positive.default,
+    aten.pow.Scalar,
+    aten.pow.Scalar_out,
+    aten.pow.Tensor_Scalar,
+    aten.pow.Tensor_Scalar_out,
+    aten.pow.Tensor_Tensor,
+    aten.pow.Tensor_Tensor_out,
+    aten.pow_.Scalar,
+    aten.pow_.Tensor,
+    aten.reciprocal.default,
+    aten.reciprocal.out,
+    aten.reciprocal_.default,
+    aten.rad2deg.default,
+    aten.rad2deg.out,
+    aten.rad2deg_.default,
+    aten.relu.default,
+    aten.relu_.default,
+    aten.remainder.Scalar,
+    aten.remainder.Scalar_Tensor,
+    aten.remainder.Scalar_out,
+    aten.remainder.Tensor,
+    aten.remainder.Tensor_out,
+    aten.remainder_.Scalar,
+    aten.remainder_.Tensor,
+    aten.round.decimals,
+    aten.round.decimals_out,
+    aten.round.default,
+    aten.round.out,
+    aten.round_.decimals,
+    aten.round_.default,
+    aten.rsqrt.default,
+    aten.rsqrt.out,
+    aten.rsqrt_.default,
+    aten.rsub.Scalar,
+    aten.sgn.default,
+    aten.sgn.out,
+    aten.sgn_.default,
+    aten.sigmoid.default,
+    aten.sigmoid.out,
+    aten.sigmoid_.default,
+    aten.sign.default,
+    aten.sign.out,
+    aten.sign_.default,
+    aten.signbit.default,
+    aten.signbit.out,
+    aten.silu.default,
+    aten.silu.out,
+    aten.sin.default,
+    aten.sin.out,
+    aten.sin_.default,
+    aten.sinc.default,
+    aten.sinc.out,
+    aten.sinc_.default,
+    aten.sinh.default,
+    aten.sinh.out,
+    aten.sinh_.default,
+    aten.sqrt.default,
+    aten.sqrt.out,
+    aten.sqrt_.default,
+    aten.square.default,
+    aten.square.out,
+    aten.square_.default,
+    aten.sub.Scalar,
+    aten.sub.Tensor,
+    aten.sub.out,
+    aten.sub_.Scalar,
+    aten.sub_.Tensor,
+    aten.tan.default,
+    aten.tan.out,
+    aten.tan_.default,
+    aten.tanh.default,
+    aten.tanh.out,
+    aten.tanh_.default,
+    aten.true_divide.Tensor,
+    aten.trunc.default,
+    aten.trunc.out,
+    aten.trunc_.default,
+    aten.where.self,
+    aten.where.self_out,
+    aten.xlogy.OutScalar_Self,
+    aten.xlogy.OutScalar_Other,
+    aten.xlogy.OutTensor,
+    aten.xlogy.Scalar_Other,
+    aten.xlogy.Scalar_Self,
+    aten.xlogy.Tensor,
+    aten.xlogy_.Scalar_Other,
+    aten.xlogy_.Tensor,
+    # backward point-wise ops
+    # please keep the entries below alphabetically sorted
+    aten.gelu_backward.default,
+    aten.sigmoid_backward.default,
+    aten.silu_backward.default,
+    aten.tanh_backward.default,
+    aten.threshold_backward.default,
+]
+
+# the linear pointwise ops map, key is op, value is the type of linearity
+linear_pointwise_ops = {
+    aten.to.dtype: 0,
+    aten.add.Tensor: 1,
+    aten.add_.Tensor: 1,
+    aten.div.Scalar: 0,
+    aten.div_.Scalar: 0,
+    aten.mul.Scalar: 0,
+    aten.mul_.Scalar: 0,
+    aten.mul.Tensor: 2,
+    aten.mul_.Tensor: 2,
+    aten.copy_.default: 1,
+}
+
+
+def pointwise_strategy(op_schema: OpSchema, linearity: int = -1) -> OpStrategy:
+    followed_strategy_index = -1
+    max_shards = -1
+    max_ndim = -1
+
+    if op_schema.is_inplace_op():
+        # inplace op should follow the first arg strategy
+        followed_strategy = op_schema.args_schema[0]
+        followed_strategy_index = 0
+    elif op_schema.is_out_variant_op():
+        # out variant op should follow the out kwarg strategy
+        followed_strategy = op_schema.kwargs_schema["out"]
+        # out variant is technically a kwarg for the strategy to follow so it does not
+        # have an "index", we set it to a reasonably large number just to indicate it's
+        # not a valid index
+        followed_strategy_index = 100
+    else:
+        # normal pointwise op, we choose to follow the arg with
+        # the max shards in case operands needs reshard
+        # in case of multiple operands with max shard, we take
+        # the one with the max number of dimensions
+        for idx, arg_strategy in enumerate(op_schema.args_schema):
+            if not isinstance(arg_strategy, OpStrategy):
+                continue
+
+            arg_max_shards = arg_strategy.max_num_shards()
+            arg_max_ndim = arg_strategy.ndim
+            if (arg_max_shards > max_shards) or (
+                arg_max_shards == max_shards and arg_max_ndim > max_ndim
+            ):
+                followed_strategy_index = idx
+                max_shards = arg_max_shards
+                max_ndim = arg_max_ndim
+
+        followed_strategy = op_schema.args_schema[followed_strategy_index]
+
+    assert isinstance(followed_strategy, OpStrategy), (
+        f"no strategy to follow for {op_schema}!"
+    )
+    return common_pointwise_strategy(
+        op_schema.args_schema,
+        followed_strategy,
+        followed_strategy_index,
+        linearity,
+    )
+
+
+def linear_pointwise_strategy(op_schema: OpSchema) -> StrategyType:
+    """
+    Linear pointwise operators can propagate pending reductions.
+    For example, c = add(a, b); if a is pending sum, then c will be
+    pending sum as well without any communication overhead.
+
+    Note that:
+    1. Only unary and binary operations are supported, out variant
+      ops are not supported.
+    2. There're multiple types of linearity, refer to the doc of
+      common_pointwise_strategy for more details.
+    """
+    linearity_type = linear_pointwise_ops.get(op_schema.op, -1)
+    return pointwise_strategy(op_schema, linearity=linearity_type)
+
+
+def common_pointwise_strategy(
+    args_schema: Sequence[object],
+    followed_strategy: OpStrategy,
+    followed_strategy_index: int,
+    linearity: int = -1,
+    scalar_tensor_idx: int | None = None,
+) -> OpStrategy:
+    """
+    Common strategy for pointwise operations.
+
+    Args:
+        args_schema: Input arguments schema
+        followed_strategy: Strategy to follow for output placement
+        followed_strategy_index: Index of the strategy being followed
+        linearity: depending on the operator, we support different types of linearity
+            -1: the operation does not support linearity
+            0: the unary operation that supports linearity, output propagates partial.
+            1: the binary operation supports add linearity, where it requires every operand
+                to be partial, output propagates partial.
+            2: the binary operation supports multiplicative linearity, where it requires
+                the primary operand to be partial, and the other operands to be replicate,
+                output propagates partial.
+        scalar_tensor_idx: Index of the Replicate scalar tensor for which we allow the mesh
+            to be different from the mesh of followed_strategy
+    """
+    # handle broadcasting
+    common_shape = torch.broadcast_shapes(
+        *[arg.shape for arg in args_schema if isinstance(arg, OpStrategy)]
+    )
+    pointwise_strategy = OpStrategy([])
+
+    for op_spec in followed_strategy.strategies:
+        spec_to_follow = op_spec.output_spec
+
+        out_placements: list[Placement] = []
+        for placement in spec_to_follow.placements:
+            if isinstance(placement, Shard | _StridedShard):
+                shard_dim = normalize_dim(placement.dim, len(spec_to_follow.shape))
+                common_ndim = len(common_shape)
+                new_shard_dim = common_ndim - len(spec_to_follow.shape) + shard_dim
+                if isinstance(placement, _StridedShard):
+                    out_placements.append(
+                        _StridedShard(
+                            new_shard_dim, split_factor=placement.split_factor
+                        )
+                    )
+                else:
+                    out_placements.append(Shard(new_shard_dim))
+            elif isinstance(placement, Partial):
+                # note that only partial-sum and partial-avg are supported for linearity
+                partial_supports_linearity = placement.is_partial(
+                    "sum"
+                ) or placement.is_partial("avg")
+                if linearity > 0 and partial_supports_linearity:
+                    # propagate the partial placement
+                    out_placements.append(placement)
+                else:
+                    # clear the partial placement if op does not support linearity
+                    # by default we just replicate the partial, need to see if this
+                    # is optimal for all cases
+                    out_placements.append(Replicate())
+            else:
+                out_placements.append(placement)
+
+        input_specs: list[DTensorSpec] = []
+        redistribute_costs: list[list[float]] = []
+        for input_idx, input_arg in enumerate(args_schema):
+            if isinstance(input_arg, OpStrategy):
+                input_arg_spec = input_arg.strategies[0].output_spec
+
+                # sanity check that all args that follow the same strategy
+                # are on the same DeviceMesh
+                if input_arg.mesh != followed_strategy.mesh:
+                    # For the scalar tensor arg in fused ops, do not follow followed_strategy;
+                    # instead, let the input mesh and the Replicate placements propagate through.
+                    if input_idx == scalar_tensor_idx:
+                        assert all(p == Replicate() for p in input_arg_spec.placements)
+                        input_arg_target_spec = DTensorSpec(
+                            mesh=input_arg.mesh,
+                            placements=input_arg_spec.placements,
+                            tensor_meta=input_arg_spec.tensor_meta,
+                        )
+                        input_specs.append(input_arg_target_spec)
+                        redistribute_costs.append(
+                            generate_redistribute_costs(
+                                input_arg, input_arg_target_spec
+                            )
+                        )
+                        continue
+                    else:
+                        raise ValueError(
+                            f"Could not run pointwise computation across different mesh: "
+                            f"Found {input_arg.mesh} and {followed_strategy.mesh}!"
+                        )
+
+                # every arg follow the out_placements, but need to handle broadcasting
+                input_arg_dims_map = infer_broadcast_dims_map(
+                    common_shape, input_arg_spec.shape
+                )
+
+                # Determine if this input should convert Partial to Replicate base on linearity
+                should_convert_partial = (
+                    linearity == 2
+                    and input_idx
+                    != followed_strategy_index  # Don't convert the "followed" strategy
+                )
+
+                input_target_placements = map_placements_after_broadcast(
+                    tuple(out_placements),
+                    common_shape,
+                    input_arg_dims_map,
+                    partial_to_replicate=should_convert_partial,
+                )
+
+                input_arg_target_spec = DTensorSpec(
+                    mesh=followed_strategy.mesh,
+                    placements=input_target_placements,
+                    tensor_meta=input_arg_spec.tensor_meta,
+                )
+                input_specs.append(input_arg_target_spec)
+                redistribute_costs.append(
+                    generate_redistribute_costs(input_arg, input_arg_target_spec)
+                )
+
+        pointwise_strategy.strategies.append(
+            OpSpec(
+                output_specs=DTensorSpec(
+                    mesh=followed_strategy.mesh,
+                    placements=tuple(out_placements),
+                ),
+                input_specs=input_specs,
+                redistribute_cost=redistribute_costs,
+            )
+        )
+    return pointwise_strategy
+
+
+for op in linear_pointwise_ops:
+    register_op_strategy(op, schema_info=RuntimeSchemaInfo(static_kwargkey=["out"]))(
+        linear_pointwise_strategy
+    )
+
+for op in pointwise_ops:
+    register_op_strategy(op, schema_info=RuntimeSchemaInfo(static_kwargkey=["out"]))(
+        pointwise_strategy
+    )
+
+
+# TODO: add all for_each ops
+for_each_ops = [
+    aten._foreach_abs.default,
+    aten._foreach_abs_.default,
+    aten._foreach_addcdiv_.Scalar,
+    aten._foreach_addcdiv_.ScalarList,
+    aten._foreach_addcdiv_.Tensor,
+    aten._foreach_addcmul.Scalar,
+    aten._foreach_addcmul_.Scalar,
+    aten._foreach_addcmul_.ScalarList,
+    aten._foreach_addcmul_.Tensor,
+    aten._foreach_clamp_max_.Scalar,
+    aten._foreach_clamp_min_.Scalar,
+    aten._foreach_div_.List,
+    aten._foreach_div_.Scalar,
+    aten._foreach_div_.ScalarList,
+    aten._foreach_div_.Tensor,
+    aten._foreach_div.List,
+    aten._foreach_div.Scalar,
+    aten._foreach_div.ScalarList,
+    aten._foreach_div.Tensor,
+    aten._foreach_lerp_.Scalar,
+    aten._foreach_maximum_.List,
+    aten._foreach_mul.Scalar,
+    aten._foreach_mul.ScalarList,
+    aten._foreach_mul.Tensor,
+    aten._foreach_mul.List,
+    aten._foreach_mul_.Scalar,
+    aten._foreach_mul_.ScalarList,
+    aten._foreach_mul_.Tensor,
+    aten._foreach_mul_.List,
+    aten._foreach_pow.List,
+    aten._foreach_pow.ScalarList,
+    aten._foreach_neg.default,
+    aten._foreach_neg_.default,
+    aten._foreach_reciprocal_.default,
+    aten._foreach_sub.Scalar,
+    aten._foreach_sub_.Scalar,
+    aten._foreach_sub.List,
+    aten._foreach_sub_.List,
+    aten._foreach_sub.ScalarList,
+    aten._foreach_sub_.ScalarList,
+    aten._foreach_sqrt.default,
+    aten._foreach_sqrt_.default,
+    aten._foreach_zero_.default,
+    aten._foreach_exp.default,
+    aten._foreach_exp_.default,
+    aten._foreach_cos.default,
+    aten._foreach_cos_.default,
+    aten._foreach_log.default,
+    aten._foreach_log_.default,
+    aten._amp_foreach_non_finite_check_and_unscale_.default,
+]
+
+for_each_linearity_ops = [
+    aten._foreach_add.Scalar,
+    aten._foreach_add_.Scalar,
+    aten._foreach_add_.ScalarList,
+    aten._foreach_add.List,
+    aten._foreach_add_.List,
+]
+
+
+def list_pointwise_strategy(
+    op_schema: OpSchema, linearity: bool = False
+) -> StrategyType:
+    """
+    Apply the pointwise strategy to the zipped arguments. For example, if we
+    run a foreach add of two lists l1 and l2, then we apply the pointwise
+    strategy on each pair (l1[i], l2[i]). If the first argument is a list but
+    the second (or later) one is a tensor, then we broadcast the tensor by
+    replicating it into a list with the length of the first argument.
+
+    Args:
+        mesh (DeviceMesh): device mesh for pointwise ops
+        op_schema (OpSchema): schema of the operator to generate strategy for
+        linearity (bool): specify whether op(a) + op(b) = op(a + b)
+
+    Returns:
+        OpStrategy: generated strategy
+    """
+
+    def args_tuple_strategies(
+        args_schema: tuple[object, ...],
+    ) -> list[TupleStrategy | None]:
+        first_arg = args_schema[0]
+        assert isinstance(first_arg, TupleStrategy)
+        strategy_len = len(first_arg.children)
+        tuple_strategies: list[TupleStrategy | None] = []
+        for arg_idx, arg in enumerate(args_schema):
+            if isinstance(arg, TupleStrategy):
+                # every tuple strategy should have the same length
+                assert len(arg.children) == strategy_len
+                tuple_strategies.append(arg)
+            elif isinstance(arg, OpStrategy):
+                if arg_idx > 0:  # implicitly broadcast
+                    tuple_strategies.append(
+                        TupleStrategy([arg for _ in range(strategy_len)])
+                    )
+                else:
+                    raise RuntimeError(
+                        f"list op only supports tuple strategy! {op_schema}"
+                    )
+            else:
+                # insert None as placeholder so that the idx of arg is kept
+                tuple_strategies.append(None)
+        return tuple_strategies
+
+    args_strategies = args_tuple_strategies(op_schema.args_schema)
+    follow_strategy: TupleStrategy = not_none(args_strategies[0])
+    list_strategy: list[OpStrategy] = []
+
+    for child_idx, child_strtgy in enumerate(follow_strategy.children):
+        assert isinstance(child_strtgy, OpStrategy)
+        args_schema: list[OpStrategy | None] = [
+            cast(OpStrategy, arg_strategy.children[child_idx]) if arg_strategy else None
+            for arg_strategy in args_strategies
+        ]
+        pointwise_strategy: OpStrategy = common_pointwise_strategy(
+            args_schema,
+            child_strtgy,
+            linearity,
+            scalar_tensor_idx=(
+                _FUSED_OP_SCALAR_IDX if op_schema.op in fused_ops else None
+            ),
+        )
+        list_strategy.append(pointwise_strategy)
+    return TupleStrategy(list_strategy)
+
+
+def list_linear_pointwise_strategy(op_schema: OpSchema) -> StrategyType:
+    """
+    for each list op stratgy that supports linearity
+    """
+    return list_pointwise_strategy(op_schema, linearity=True)
+
+
+for op in for_each_ops:
+    register_op_strategy(op, schema_info=RuntimeSchemaInfo(needs_pytree=True))(
+        list_pointwise_strategy
+    )
+
+for op in for_each_linearity_ops:
+    register_op_strategy(op, schema_info=RuntimeSchemaInfo(needs_pytree=True))(
+        list_linear_pointwise_strategy
+    )
+
+fused_ops = [
+    aten._fused_adam_.default,
+    aten._fused_adam.default,
+    aten._fused_adam.tensor_lr,
+    aten._fused_adam_.tensor_lr,
+    aten._fused_adamw_.default,
+    aten._fused_adamw.default,
+    aten._fused_adamw.tensor_lr,
+    aten._fused_adamw_.tensor_lr,
+]
+
+
+# The state_steps arg of fused adam / adamw is a Replicate scalar tensor, which will be put on
+# the compute_mesh of an op across all parameter groups, even when not all parameter groups
+# are on the same device mesh. This idx will help avoid hitting exceptions or unnecessary
+# redistribute during sharding propagation.
+_FUSED_OP_SCALAR_IDX = 5
+
+for op in fused_ops:
+    register_op_strategy(op, schema_info=RuntimeSchemaInfo(needs_pytree=True))(
+        list_pointwise_strategy
+    )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_random_ops.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_random_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd4cf8fec226aa2538205c9a82f68ad05dbabb18
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_random_ops.py
@@ -0,0 +1,43 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+import torch
+from torch.distributed.tensor._op_schema import (
+    OpSchema,
+    OpSpec,
+    OpStrategy,
+    StrategyType,
+)
+from torch.distributed.tensor._ops.registration import register_op_strategy
+from torch.distributed.tensor._ops.utils import is_tensor_partial
+
+
+aten = torch.ops.aten
+
+
+@register_op_strategy(
+    [
+        aten.normal_.default,
+        aten.uniform_.default,
+        aten.native_dropout.default,
+        aten.bernoulli_.float,
+        aten.bernoulli.default,
+    ]
+)
+def random_op_strategy(op_schema: OpSchema) -> StrategyType:
+    self_strategy = op_schema.args_schema[0]
+    assert isinstance(self_strategy, OpStrategy)
+
+    random_strategy = OpStrategy([])
+    for arg_strategy in self_strategy.strategies:
+        arg_spec = arg_strategy.output_spec
+        if is_tensor_partial(arg_spec):
+            # TODO: figure out how inplace random op should behave when it's partial
+            raise RuntimeError(f"{op_schema.op} with Partial is not supported yet!")
+        random_strategy.strategies.append(
+            OpSpec(
+                output_specs=arg_spec,
+                input_specs=(arg_spec,),
+                redistribute_cost=[[0.0] * len(self_strategy.strategies)],
+            )
+        )
+
+    return random_strategy
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_tensor_ops.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_tensor_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c50085de487b12fae0ea657414eedfa40ce3cb1
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_tensor_ops.py
@@ -0,0 +1,1258 @@
+# mypy: allow-untyped-defs
+# Copyright (c) Meta Platforms, Inc. and affiliates
+from collections.abc import Sequence, Sized
+from typing import cast
+
+import torch
+from torch._prims_common import IntLike
+from torch.distributed.tensor._dtensor_spec import DTensorSpec
+from torch.distributed.tensor._op_schema import (
+    OpSchema,
+    OpSpec,
+    OpStrategy,
+    OutputSharding,
+    PlacementList,
+    RuntimeSchemaInfo,
+    StrategyType,
+    TupleStrategy,
+)
+from torch.distributed.tensor._ops._common_rules import pointwise_rule
+from torch.distributed.tensor._ops._embedding_ops import MaskPartial
+from torch.distributed.tensor._ops.registration import (
+    register_op_strategy,
+    register_prop_rule,
+)
+from torch.distributed.tensor._ops.utils import (
+    expand_to_full_mesh_op_strategy,
+    generate_redistribute_costs,
+    is_tensor_dim_sharded,
+    is_tensor_evenly_shardable,
+    is_tensor_partial,
+    normalize_dim,
+    shift_shard_dims_after_insert,
+    shift_shard_dims_after_remove,
+)
+from torch.distributed.tensor.placement_types import (
+    Partial,
+    Placement,
+    Replicate,
+    Shard,
+)
+from torch.fx.experimental.symbolic_shapes import statically_known_true
+
+
+aten = torch.ops.aten
+
+
+def propagate_single_input_strategy(op_schema: OpSchema) -> StrategyType:
+    # For ops with a single tensor input, we perform a 1:1 mapping such that
+    # for each strategy that the input supports, we create a corresponding strategy.
+    # Note: this may be a complete waste of work, because it should be equivalent to
+    # `return first_input_strategy` (unless creating a deep copy is important for some reason)
+    if len([s for s in op_schema.args_schema if isinstance(s, OpStrategy)]) != 1:
+        raise AssertionError(
+            "propagate_single_input_strategy only works for single-tensor-input ops"
+        )
+    first_input_strategy = op_schema.args_schema[0]
+    if not isinstance(first_input_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(first_input_strategy)}")
+    return OpStrategy(
+        [
+            OpSpec(
+                output_specs=DTensorSpec(
+                    mesh=first_input_strategy.mesh,
+                    placements=strategy.output_spec.placements,
+                    tensor_meta=strategy.output_spec.tensor_meta,
+                ),
+                input_specs=[
+                    DTensorSpec(
+                        mesh=first_input_strategy.mesh,
+                        placements=strategy.output_spec.placements,
+                        tensor_meta=strategy.output_spec.tensor_meta,
+                    )
+                ],
+                redistribute_cost=[
+                    generate_redistribute_costs(
+                        first_input_strategy, strategy.output_spec
+                    )
+                ],
+            )
+            for strategy in first_input_strategy.strategies
+        ]
+    )
+
+
+register_op_strategy(
+    [
+        aten.clone.default,
+        aten.contiguous.default,
+        aten.detach.default,
+        aten.alias.default,
+        aten.fill_.Scalar,
+        aten.view.dtype,
+        aten.zero_.default,
+    ]
+)(propagate_single_input_strategy)
+
+
+register_op_strategy(
+    aten._to_copy.default, schema_info=RuntimeSchemaInfo(static_kwargkey=["dtype"])
+)(propagate_single_input_strategy)
+
+
+@register_op_strategy(
+    [
+        aten.equal.default,
+        aten.is_same_size.default,
+    ]
+)
+def equal_strategy(op_schema: OpSchema) -> StrategyType:
+    # equal_strategy deals with ops that comparing two tensor, we need to make sure
+    # sharding layout the same with two operands, we choose to follow the arg with max
+    # num of shards, still keep is_same_size here for completeness as they share the
+    # same strategy in theory.
+    mesh = op_schema.get_mesh_from_args()
+    self_strategy, other_strategy = op_schema.args_schema
+    if not isinstance(self_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(self_strategy)}")
+    if not isinstance(other_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(other_strategy)}")
+
+    select_strategy = (
+        self_strategy
+        if self_strategy.max_num_shards() >= other_strategy.max_num_shards()
+        else other_strategy
+    )
+    equal_strategy = OpStrategy([])
+
+    for arg_strategy in select_strategy.strategies:
+        arg_spec = arg_strategy.output_spec
+        if is_tensor_partial(arg_spec):
+            # if the arg_spec have partial, reshard to replicate
+            # otherwise local shard tensor comparison would be invalid
+            output_spec = DTensorSpec(
+                mesh=mesh,
+                placements=tuple(
+                    Replicate() if isinstance(p, Partial) else p
+                    for p in arg_spec.placements
+                ),
+            )
+            equal_strategy.strategies.append(OpSpec(output_specs=output_spec))
+        else:
+            equal_strategy.strategies.append(OpSpec(arg_spec))
+    return equal_strategy
+
+
+register_op_strategy(
+    aten.empty_like.default, schema_info=RuntimeSchemaInfo(1, ["dtype"])
+)(propagate_single_input_strategy)
+
+
+@register_op_strategy(
+    [
+        aten.ones_like.default,
+        aten.rand_like.default,
+        aten.randn_like.default,
+        aten.zeros_like.default,
+    ],
+    schema_info=RuntimeSchemaInfo(1, ["dtype"]),
+)
+@register_op_strategy(
+    [aten.full_like.default],
+    schema_info=RuntimeSchemaInfo(2, ["dtype"]),
+)
+@register_op_strategy(
+    [
+        aten.randint_like.default,
+        aten.randint_like.low_dtype,
+        aten.randint_like.low_dtype_out,
+    ],
+    schema_info=RuntimeSchemaInfo(3, ["dtype"]),
+)
+def create_like_strategy(op_schema: OpSchema) -> StrategyType:
+    # create_like_strategy deals with ops that creating tensors with same
+    # shape as input, but with specific content that does not depend on
+    # the input, we can propagate sharding, but we have to make sure we
+    # move from partial to replicated.
+    select_strategy = op_schema.args_schema[0]
+    create_like_strategy = OpStrategy([])
+    if not isinstance(select_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(select_strategy)}")
+    for arg_strategy in select_strategy.strategies:
+        arg_spec = arg_strategy.output_spec
+        output_spec = DTensorSpec(
+            mesh=select_strategy.mesh,
+            placements=tuple(
+                Replicate() if isinstance(p, Partial) else p
+                for p in arg_spec.placements
+            ),
+        )
+        create_like_strategy.strategies.append(
+            OpSpec(output_specs=output_spec, input_specs=(arg_spec,))
+        )
+
+    return create_like_strategy
+
+
+@register_op_strategy(
+    [
+        aten.new_empty.default,
+        aten.new_full.default,
+        aten.new_ones.default,
+        aten.new_zeros.default,
+        aten.new_empty_strided.default,
+    ],
+    schema_info=RuntimeSchemaInfo(1, ["dtype"]),
+)
+def new_factory_strategy(op_schema: OpSchema) -> StrategyType:
+    # Currently there are two strategies:
+    # 1. let the output be replicated
+    # 2. let the output follow the input if input and output have the same shape
+    input_strategy = op_schema.args_schema[0]
+    if not isinstance(input_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}")
+
+    mesh = input_strategy.mesh
+    input_shape = input_strategy.shape
+    output_shape = op_schema.args_schema[1]
+    if not isinstance(output_shape, list):
+        raise AssertionError(f"Expected list, got {type(output_shape)}")
+
+    new_factory_strategy = OpStrategy([])
+    for arg_strategy in input_strategy.strategies:
+        input_spec = arg_strategy.output_spec
+        replica_spec = DTensorSpec(mesh, tuple([Replicate()] * mesh.ndim))
+        new_factory_strategy.strategies.append(
+            OpSpec(
+                output_specs=replica_spec,
+                input_specs=(input_spec,),
+                redistribute_cost=[[0.0] * len(input_strategy.strategies)],
+            )
+        )
+
+        if tuple(input_shape) == tuple(output_shape) and input_spec.is_sharded():
+            # NOTE: for new_empty_strided, currently the non-replicate sharding
+            #       is supported only when the shape is evenly shardable
+            if (
+                op_schema.op == aten.new_empty_strided.default
+                and not is_tensor_evenly_shardable(input_shape, input_spec)
+            ):
+                continue
+
+            new_factory_strategy.strategies.append(
+                OpSpec(
+                    output_specs=input_spec,
+                    input_specs=(input_spec,),
+                    # encouraging new tensor placement to be the same as input
+                    redistribute_cost=[[-0.1] * len(input_strategy.strategies)],
+                )
+            )
+
+    return new_factory_strategy
+
+
+@register_op_strategy(aten.bucketize.Tensor)
+def gen_bucketize_strategy(op_schema: OpSchema) -> StrategyType:
+    """Just propagate input sharding, but expect replicated for boundaries input."""
+    mesh = op_schema.get_mesh_from_args()
+    input_strategy, boundaries_strategy = op_schema.args_schema
+    bucketize_strategy = OpStrategy([])
+    if not isinstance(input_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}")
+    if not isinstance(boundaries_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(boundaries_strategy)}")
+    for arg_strategy in input_strategy.strategies:
+        arg_spec = DTensorSpec(
+            mesh,
+            arg_strategy.output_spec.placements,
+            arg_strategy.output_spec.tensor_meta,
+        )
+        replica_spec = DTensorSpec(
+            mesh,
+            tuple([Replicate()] * mesh.ndim),
+            boundaries_strategy.strategies[0].output_spec.tensor_meta,
+        )
+        bucketize_strategy.strategies.append(
+            OpSpec(
+                output_specs=arg_spec,
+                input_specs=(arg_spec, replica_spec),
+                redistribute_cost=[
+                    generate_redistribute_costs(input_strategy, arg_spec),
+                    generate_redistribute_costs(boundaries_strategy, replica_spec),
+                ],
+            )
+        )
+
+    return bucketize_strategy
+
+
+@register_op_strategy(aten.select.int, schema_info=RuntimeSchemaInfo(1))
+def select_int_strategy(op_schema: OpSchema) -> StrategyType:
+    """
+    In this select op, first determine the input specs, then determine the output specs.
+    - Input specs:
+        - If the input is sharded on the selected dim, unshard it and change to replicate.
+        - Otherwise, keep the original input specs.
+    - Output specs:
+        - It checks the input specs with the following cases:
+        - Case 1 shard_dim == selected_dim: not possible as the input is already unsharded.
+        - Case 2 shard_dim < selected_dim: keep the input specs.
+        - Case 3 shard_dim > selected_dim: shard_dim -= 1.
+    """
+    input_strategy = op_schema.args_schema[0]
+    if not isinstance(input_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}")
+    if len(op_schema.args_schema) != 3:
+        raise AssertionError(f"Expected 3 args, got {len(op_schema.args_schema)}")
+    selected_dim, index = (
+        cast(int, op_schema.args_schema[1]),
+        cast(int, op_schema.args_schema[2]),
+    )
+    input_shape = input_strategy.shape
+    input_ndim = input_strategy.ndim
+    selected_dim = normalize_dim(selected_dim, input_ndim)
+    index = normalize_dim(index, input_shape[selected_dim])
+
+    select_strategy = OpStrategy([])
+    for arg_strategy in input_strategy.strategies:
+        arg_spec = arg_strategy.output_spec
+
+        # determine input spec
+        input_specs = arg_spec
+        if is_tensor_dim_sharded(arg_spec, dim=selected_dim):
+            # if input is sharded on the selected dim, need to unshard it, change to replicate
+            arg_target_placements = unshard_tensor_dim(
+                arg_spec.placements, dim=selected_dim
+            )
+            input_specs = DTensorSpec(arg_spec.mesh, arg_target_placements)  # R
+
+        # determine output spec
+        output_specs = input_specs
+        if input_specs.is_sharded():
+            # handle cases with sharded_dim != selected_dim
+            output_placements = shift_shard_dims_after_remove(
+                input_specs.placements, selected_dim
+            )
+            output_specs = DTensorSpec(
+                arg_spec.mesh, placements=tuple(output_placements)
+            )
+
+        select_strategy.strategies.append(
+            OpSpec(
+                output_specs=output_specs,
+                input_specs=(input_specs,),
+            )
+        )
+    return select_strategy
+
+
+@register_op_strategy(
+    aten.select_backward.default,
+    schema_info=RuntimeSchemaInfo(1),
+)
+def select_backward_strategy(op_schema: OpSchema) -> OpStrategy:
+    # func: select_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt index) -> Tensor
+    args_schema = op_schema.args_schema
+    input_strategy, dim = args_schema[0], args_schema[2]
+    if not isinstance(input_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {input_strategy}")
+    if not isinstance(dim, int):
+        raise AssertionError(f"Expected int, got {type(dim)}")
+    output_strategies: list[OpSpec] = []
+    for placement_strategy in input_strategy.strategies:
+        input_spec = placement_strategy.output_spec
+        # NOTE: shard_dim is guaranteed to exist because
+        # grad_input has one more dim than grad_output
+        output_placements = shift_shard_dims_after_insert(input_spec.placements, dim)
+        output_specs = DTensorSpec(input_spec.mesh, tuple(output_placements))
+        output_strategies.append(
+            OpSpec(output_specs=output_specs, input_specs=(input_spec,))
+        )
+    return OpStrategy(output_strategies)
+
+
+@register_op_strategy(aten.slice.Tensor, schema_info=RuntimeSchemaInfo(1))
+def gen_slice_strategy(op_schema: OpSchema) -> StrategyType:
+    """Forward all shardings except the slice dimension."""
+    defaults = (None, 0, None, None, 1)
+    input_strategy, dim, start, end, step = (
+        op_schema.args_schema + defaults[len(op_schema.args_schema) :]
+    )
+    if not isinstance(input_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}")
+
+    mesh = input_strategy.mesh
+    input_shape = input_strategy.shape
+    input_ndim = input_strategy.ndim
+    if not isinstance(dim, int):
+        raise AssertionError(f"Expected int, got {type(dim)}")
+    if start is None:
+        start = 0
+    if end is None or statically_known_true(end > input_shape[dim]):
+        end = input_shape[dim]
+    if not isinstance(start, IntLike):
+        raise AssertionError(f"Expected IntLike, got {type(start)}")
+    if not isinstance(end, IntLike):
+        raise AssertionError(f"Expected IntLike, got {type(end)}")
+    if not isinstance(step, IntLike):
+        raise AssertionError(f"Expected IntLike, got {type(step)}")
+
+    # normalize args
+    slice_dim = normalize_dim(dim, input_ndim)  # type: ignore[arg-type]
+    start = normalize_dim(start, input_shape[dim])  # type: ignore[arg-type]
+    end = normalize_dim(end, input_shape[dim])  # type: ignore[arg-type]
+
+    statically_redundant_slice = (
+        statically_known_true(start == 0)
+        and statically_known_true(end == input_shape[dim])
+        and statically_known_true(step == 1)
+    )
+
+    slice_strategy = OpStrategy([])
+
+    for arg_strategy in input_strategy.strategies:
+        arg_spec = arg_strategy.output_spec
+        if (
+            not is_tensor_dim_sharded(arg_spec, dim=slice_dim)
+            or statically_redundant_slice
+        ):
+            # only add the strategy if the slice dim is not sharded
+            out_spec = DTensorSpec(mesh, arg_spec.placements)
+            slice_strategy.strategies.append(
+                OpSpec(
+                    output_specs=out_spec,
+                    input_specs=(arg_spec,),
+                    redistribute_cost=[[0.0] * len(input_strategy.strategies)],
+                )
+            )
+    if not slice_strategy.strategies:
+        # if all strategies are filtered out, unsharding all specs on slice dim
+        # of the input strategy, and use that as the op strategy
+        for arg_strategy in input_strategy.strategies:
+            arg_spec = arg_strategy.output_spec
+            unshard_spec = DTensorSpec(
+                mesh, unshard_tensor_dim(arg_spec.placements, dim=slice_dim)
+            )
+            slice_strategy.strategies.append(
+                OpSpec(
+                    output_specs=unshard_spec,
+                    redistribute_cost=[
+                        generate_redistribute_costs(input_strategy, unshard_spec)
+                    ],
+                )
+            )
+    return slice_strategy
+
+
+@register_op_strategy(
+    aten.slice_backward.default,
+    schema_info=RuntimeSchemaInfo(1),
+)
+def slice_backward_rules(op_schema: OpSchema) -> OpStrategy:
+    # func: slice_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt start, SymInt end, SymInt step) -> Tensor
+    args_schema = op_schema.args_schema
+    input_strategy, dim = args_schema[0], args_schema[2]
+    if not isinstance(input_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {input_strategy}")
+    output_strategies: list[OpSpec] = []
+    for placement_strategy in input_strategy.strategies:
+        output_spec = placement_strategy.output_spec
+        new_placements: list[Placement] = []
+        for placement in output_spec.placements:
+            # Redistribute to replicate only if the dim is sharded and matches the slice dim
+            if isinstance(placement, Shard) and placement.dim == dim:
+                new_placements.append(Replicate())
+            else:
+                new_placements.append(placement)
+        new_spec = DTensorSpec(output_spec.mesh, tuple(new_placements))
+        redistribute_cost = [generate_redistribute_costs(input_strategy, new_spec)]
+        new_strategy = OpSpec(
+            output_specs=new_spec, redistribute_cost=redistribute_cost
+        )
+        output_strategies.append(new_strategy)
+    return OpStrategy(output_strategies)
+
+
+def unshard_tensor_dim(
+    placements: Sequence[Placement], dim: int
+) -> tuple[Placement, ...]:
+    """Disallow the given tensor dimension to be sharded."""
+    return tuple(
+        p if (not isinstance(p, Shard) or p.dim != dim) else Replicate()
+        for p in placements
+    )
+
+
+def replicate_tensor_dim(
+    placements: Sequence[Placement], dim: int
+) -> tuple[Placement, ...]:
+    """Force the given tensor dimension to be replicated."""
+    # Not using p.is_shard() to avoid mypy complain about Placement not having
+    # attribute dim.
+    return tuple(
+        Replicate() if p.is_partial() or isinstance(p, Shard) and p.dim == dim else p
+        for p in placements
+    )
+
+
+@register_op_strategy(aten.slice_scatter.default, schema_info=RuntimeSchemaInfo(2))
+def gen_slice_scatter_strategy(op_schema: OpSchema) -> StrategyType:
+    # 1. number of dimensions in input and src need to match.
+    # 2. number of elements on all non-dim need to match between input and src.
+    # 3. number of elements in src in dim need to match the slice size.
+    # Given the above:
+    # - We suggest for src to follow the sharding of input, except on the scatter dimension,
+    #   where our best bet for now is to make them replicated as a fall-back.
+    #   TODO: Ideally we'd like to make sure the output is re-sharded afterwards to keep input sharding.
+    mesh = op_schema.get_mesh_from_args()
+    input_strategy = op_schema.args_schema[0]
+    src_strategy = op_schema.args_schema[1]
+    if not isinstance(input_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}")
+    if not isinstance(src_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(src_strategy)}")
+    input_ndim = input_strategy.ndim
+    slice_dim = (
+        cast(int, op_schema.args_schema[2]) if len(op_schema.args_schema) > 2 else 0
+    )
+    slice_dim = normalize_dim(slice_dim, input_ndim)
+
+    slice_scatter_strategy = OpStrategy([])
+    # by default follow the input strategy for both input and src
+    for arg_strategy in input_strategy.strategies:
+        arg_spec = arg_strategy.output_spec
+        if not (
+            is_tensor_dim_sharded(arg_spec, dim=slice_dim)
+            or is_tensor_partial(arg_spec)
+        ):
+            input_spec = DTensorSpec(mesh, arg_spec.placements, arg_spec.tensor_meta)
+            # TODO: need to relax the constraint to src
+            src_spec = DTensorSpec(mesh, arg_spec.placements)
+            # only add the strategy if the slice_scatter dim is not sharded or partial
+            slice_scatter_strategy.strategies.append(
+                OpSpec(
+                    output_specs=arg_spec,
+                    input_specs=(input_spec, src_spec),
+                    redistribute_cost=[
+                        generate_redistribute_costs(input_strategy, input_spec),
+                        generate_redistribute_costs(src_strategy, src_spec),
+                    ],
+                )
+            )
+
+    if not slice_scatter_strategy.strategies:
+        # if all strategies are filtered out, replicating all specs on slice_scatter dim
+        # of the input strategy, and use that as the op strategy
+        for arg_strategy in input_strategy.strategies:
+            arg_spec = arg_strategy.output_spec
+            new_placement = replicate_tensor_dim(arg_spec.placements, dim=slice_dim)
+            input_spec = DTensorSpec(mesh, new_placement)
+            src_spec = DTensorSpec(mesh, new_placement)
+            slice_scatter_strategy.strategies.append(
+                OpSpec(
+                    output_specs=input_spec,
+                    input_specs=(input_spec, src_spec),
+                    redistribute_cost=[
+                        generate_redistribute_costs(input_strategy, input_spec),
+                        generate_redistribute_costs(src_strategy, src_spec),
+                    ],
+                )
+            )
+    return slice_scatter_strategy
+
+
+@register_op_strategy(aten._local_scalar_dense.default)
+def replica_only_strategy(op_schema: OpSchema) -> StrategyType:
+    """Only allow replication on the input/output."""
+    input_strategy = op_schema.args_schema[0]
+    if not isinstance(input_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}")
+    mesh = input_strategy.mesh
+    replicate_spec = DTensorSpec(mesh, tuple([Replicate()] * mesh.ndim))
+    return OpStrategy([OpSpec(replicate_spec)])
+
+
+@register_op_strategy(
+    [
+        aten.scatter_.value,
+        aten.scatter.value,
+        aten.scatter_.src,
+        aten.scatter.src,
+    ],
+    schema_info=RuntimeSchemaInfo(1),
+)
+def scatter_strategy(op_schema: OpSchema) -> StrategyType:
+    mesh = op_schema.get_mesh_from_args()
+    single_mesh_dim_strategies = []
+
+    # placement list stores placements of [output, input, index, src]
+    # first we always have replicate all for inputs and output
+    if len(op_schema.args_strategy) < 3:
+        # scatter_.src/scatter.src with src be float number instead of tensor
+        all_replicate: PlacementList = [Replicate()] * 3
+    else:
+        all_replicate = [Replicate()] * 4
+    single_mesh_dim_strategies.append(all_replicate)
+
+    # TODO: see if we can support input sharding pattern
+    op_strategy = expand_to_full_mesh_op_strategy(
+        mesh,
+        op_schema,
+        single_mesh_dim_strategies,
+        inplace_op=op_schema.is_inplace_op(),
+    )
+    return op_strategy
+
+
+@register_op_strategy(aten.scatter_add.default, schema_info=RuntimeSchemaInfo(1))
+def scatter_add_strategy(op_schema: OpSchema) -> StrategyType:
+    input_strategy = op_schema.args_schema[0]
+    dim = op_schema.args_schema[1]
+    index_strategy = op_schema.args_schema[2]
+
+    if not isinstance(input_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}")
+    if not isinstance(index_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(index_strategy)}")
+    if not isinstance(dim, int):
+        raise AssertionError(f"Expected int, got {type(dim)}")
+    dim = normalize_dim(dim, input_strategy.ndim)
+    mesh = input_strategy.mesh
+    input_shape = input_strategy.shape
+    index_shape = index_strategy.shape
+
+    single_mesh_dim_strategies = []
+
+    # placement list stores placements of [output, input, index, src]
+    # first we always have replicate all for inputs and output
+    all_replicate: PlacementList = [Replicate()] * 4
+    single_mesh_dim_strategies.append(all_replicate)
+
+    if len(input_shape) == len(index_shape):
+        for d in range(len(input_shape)):
+            if d != dim and input_shape[d] == index_shape[d]:
+                sharding: PlacementList = [Shard(d), Shard(d), Shard(d), Shard(d)]
+                single_mesh_dim_strategies.append(sharding)
+
+    return expand_to_full_mesh_op_strategy(
+        mesh, op_schema, single_mesh_dim_strategies, input_index=1
+    )
+
+
+@register_op_strategy(aten.gather.default, schema_info=RuntimeSchemaInfo(1))
+def gather_strategy(op_schema: OpSchema) -> StrategyType:
+    mesh = op_schema.get_mesh_from_args()
+    input_strategy = cast(OpStrategy, op_schema.args_schema[0])
+    dim = cast(int, op_schema.args_schema[1])
+    dim = normalize_dim(dim, input_strategy.ndim)
+    index_strategy = cast(OpStrategy, op_schema.args_schema[2])
+
+    input_shape = input_strategy.shape
+    index_shape = index_strategy.shape
+
+    single_mesh_dim_strategies = []
+
+    # placement list stores placements of [output, input, index]
+    # first we always have replicate all for inputs and output
+    all_replicate: PlacementList = [Replicate()] * 3
+    single_mesh_dim_strategies.append(all_replicate)
+
+    # input sharding, input sharded, index accepts mask partial, output follows index
+    # this only works when the input is sharded on the gather dimension, and
+    # index has size 1 on the gather dimension
+    if dim < len(index_shape) and index_shape[dim] == 1:
+        index_partial_placement = MaskPartial(offset_shape=input_shape, offset_dim=dim)
+        input_sharding: PlacementList = [
+            index_partial_placement,
+            Shard(dim),
+            index_partial_placement,
+        ]
+        single_mesh_dim_strategies.append(input_sharding)
+
+    # index sharding, input replicated, index sharded, output follows index
+    # this only works when the sharding dimension is the gather dimension
+    index_sharding: PlacementList = [Shard(dim), Replicate(), Shard(dim)]
+    single_mesh_dim_strategies.append(index_sharding)
+
+    if len(input_shape) == len(index_shape):
+        for d in range(len(input_shape)):
+            if d != dim:
+                sharding: PlacementList = [Shard(d), Shard(d), Shard(d)]
+                single_mesh_dim_strategies.append(sharding)
+
+    return expand_to_full_mesh_op_strategy(
+        mesh, op_schema, single_mesh_dim_strategies, input_index=1
+    )
+
+
+def _derive_follow_placements_from_tuple_strategy(
+    op: torch._ops.OpOverload,
+    tuple_strategy: TupleStrategy,
+) -> Sequence[Placement]:
+    """
+    derive the placements to follow from the tuple strategy, mainly used by
+    aten.stack, aten.cat, where each operand have the same shape, and correspondingly
+    expecting the same sharding
+    """
+
+    def merge_placement(
+        cur_placement: Placement, new_placement: Placement
+    ) -> Placement:
+        # semantic if we already have a follow placement, we
+        # check each placement for the current arg placement
+        # to see if we want to merge/adjust the placement to follow
+        # the priority: Partial -> Shard -> Replicate
+        if cur_placement == new_placement:
+            return cur_placement
+
+        if cur_placement.is_partial():
+            if new_placement.is_shard():
+                # follow new placement
+                return new_placement
+            elif new_placement.is_partial():
+                # different partial types, we can't merge and have to replicate all here
+                return Replicate()
+            else:
+                # follow partial
+                return cur_placement
+        elif cur_placement.is_shard():
+            if new_placement.is_shard():
+                # cur/new placement are different sharding (i.e. different shard dim)
+                # currently fallback to replicate all args
+                return Replicate()
+            else:
+                # for partial/replicate, follow the current shard placement
+                return cur_placement
+        else:
+            # current replicate, just follow new placement
+            return new_placement
+
+    follow_placements: list[Placement] | None = None
+    mesh = tuple_strategy.child_mesh(0)
+    for arg_strategy in tuple_strategy.children:
+        if not isinstance(arg_strategy, OpStrategy):
+            raise AssertionError(f"Expected OpStrategy, got {type(arg_strategy)}")
+        if arg_strategy.mesh != mesh:
+            raise ValueError(
+                f"All operands in {op} must have the same mesh, "
+                f"but got {arg_strategy.mesh} and {mesh}."
+            )
+
+        for placement_strategy in arg_strategy.strategies:
+            arg_placements = placement_strategy.output_spec.placements
+            if follow_placements is None:
+                follow_placements = list(arg_placements)
+                continue
+            if follow_placements is None:
+                raise AssertionError(
+                    "follow_placements should not be None at this point"
+                )
+            for mesh_idx in range(mesh.ndim):
+                # merge placements with the priority
+                follow_placements[mesh_idx] = merge_placement(
+                    follow_placements[mesh_idx], arg_placements[mesh_idx]
+                )
+    if follow_placements is None:
+        raise AssertionError("follow placements should not be None!")
+    return follow_placements
+
+
+@register_op_strategy(aten.stack.default, RuntimeSchemaInfo(1, needs_pytree=True))
+def stack_strategy(op_schema: OpSchema) -> StrategyType:
+    args_schema = op_schema.args_schema
+    input_tuple_strategy = args_schema[0]
+    if not isinstance(input_tuple_strategy, TupleStrategy):
+        raise AssertionError(f"Expected TupleStrategy, got {input_tuple_strategy}")
+    input_strategies: list[OpStrategy] = []
+    for child in input_tuple_strategy.children:
+        assert isinstance(child, OpStrategy), f"Expected OpStrategy, got {child}"
+        input_strategies.append(child)
+    first_input_strategy = input_strategies[0]
+    common_input_ndim = first_input_strategy.ndim
+    dim = cast(int, args_schema[1]) if len(args_schema) > 1 else 0
+    # normalize the dim to be within the common input ndim
+    dim = normalize_dim(dim, common_input_ndim)
+
+    mesh = first_input_strategy.mesh
+
+    follow_placements = _derive_follow_placements_from_tuple_strategy(
+        op_schema.op, input_tuple_strategy
+    )
+
+    # create op strategy base on the follow placements
+    op_strategy = OpStrategy([])
+
+    input_specs = tuple(
+        DTensorSpec(mesh, tuple(follow_placements))
+        for _ in range(len(input_tuple_strategy.children))
+    )
+
+    # stack op would "insert" new dim, so all sharded dim >= the inserted dim need to
+    # be normalized with the new Shard placement
+    follow_placements = shift_shard_dims_after_insert(follow_placements, dim)
+    output_spec = DTensorSpec(mesh, tuple(follow_placements))
+    redistribute_cost = [
+        generate_redistribute_costs(input_strategies[i], input_specs[i])
+        for i in range(len(input_specs))
+    ]
+    op_strategy.strategies.append(
+        OpSpec(
+            output_specs=output_spec,
+            input_specs=input_specs,
+            redistribute_cost=redistribute_cost,
+        )
+    )
+    return op_strategy
+
+
+@register_op_strategy(aten.cat.default, RuntimeSchemaInfo(1, needs_pytree=True))
+def cat_strategy(op_schema: OpSchema) -> StrategyType:
+    args_schema = op_schema.args_schema
+    input_tuple_strategy = args_schema[0]
+    if not isinstance(input_tuple_strategy, TupleStrategy):
+        raise AssertionError(f"Expected TupleStrategy, got {input_tuple_strategy}")
+    num_input_tensor = len(input_tuple_strategy.children)
+    first_input_strategy = input_tuple_strategy.children[0]
+    if not isinstance(first_input_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {first_input_strategy}")
+    common_input_ndim = first_input_strategy.ndim
+    dim = cast(int, args_schema[1]) if len(args_schema) > 1 else 0
+    # normalize the dim to be within the common input ndim
+    dim = normalize_dim(dim, common_input_ndim)
+
+    mesh = first_input_strategy.mesh
+
+    op_strategy = OpStrategy([])
+    # use a set to deduplicate strategies with the same placement
+    strategies_placement_pool = set()
+    for this_strategy in input_tuple_strategy.children:
+        # check strategy of each tensor to be concatenated
+        if not isinstance(this_strategy, OpStrategy):
+            raise AssertionError(f"Expected OpStrategy, got {type(this_strategy)}")
+        if this_strategy.mesh != mesh:
+            raise AssertionError("cat op doesn't support cross mesh concatenation")
+        for op_spec in this_strategy.strategies:
+            # Check each OpSpec of the tensor, the placement in this OpSpec
+            # is used as the exemplar strategy that other tensors and output
+            # tensor should follow. We also need to deduplicate the output
+            # strategy with the same placement.
+            if not isinstance(op_spec, OpSpec):
+                raise AssertionError(f"Expected OpSpec, got {type(op_spec)}")
+            # exemplar OpSpec to follow
+            exemplar_spec = op_spec.output_spec
+            # check if the tensor is sharded on the concat dim
+            if is_tensor_dim_sharded(exemplar_spec, dim):
+                # if the tensor is sharded on the concat dim, we need to unshard it
+                # first
+                exemplar_placement = unshard_tensor_dim(exemplar_spec.placements, dim)
+            else:
+                exemplar_placement = exemplar_spec.placements
+            if exemplar_placement not in strategies_placement_pool:
+                strategies_placement_pool.add(exemplar_placement)
+                # assert isinstance(exemplar_placement, Tuple)
+                redistribute_costs = []
+                input_specs = []
+                for idx in range(num_input_tensor):
+                    # extract the strategy for the idx tensors to build the tensor_metadata and redistribute_cost
+                    that_tensor_strategy = input_tuple_strategy.children[idx]
+                    if not isinstance(that_tensor_strategy, OpStrategy):
+                        raise AssertionError(
+                            f"Expected OpStrategy, got {type(that_tensor_strategy)}"
+                        )
+                    input_spec = DTensorSpec(
+                        mesh,
+                        exemplar_placement,
+                        tensor_meta=that_tensor_strategy.strategies[
+                            0
+                        ].output_spec.tensor_meta,
+                    )
+                    input_specs.append(input_spec)
+                    redistribute_costs.append(
+                        generate_redistribute_costs(that_tensor_strategy, input_spec)
+                    )
+                op_strategy.strategies.append(
+                    OpSpec(
+                        output_specs=DTensorSpec(mesh, exemplar_placement),
+                        input_specs=tuple(input_specs),
+                        redistribute_cost=redistribute_costs,
+                    )
+                )
+    return op_strategy
+
+
+@register_prop_rule(aten.index_select.default, schema_info=RuntimeSchemaInfo(1))
+def prop_index_select(op_schema: OpSchema) -> OutputSharding:
+    values_spec, dim, indices_spec = op_schema.args_schema
+
+    if not isinstance(values_spec, DTensorSpec):
+        raise AssertionError(f"Expected DTensorSpec, got {type(values_spec)}")
+    if not isinstance(dim, int):
+        raise AssertionError(f"Expected int, got {type(dim)}")
+    if not isinstance(indices_spec, DTensorSpec):
+        raise AssertionError(f"Expected DTensorSpec, got {type(indices_spec)}")
+
+    all_indices_spec: list[DTensorSpec | None] = [
+        indices_spec if dim == i else None for i in range(values_spec.ndim)
+    ]
+
+    result = prop_index(
+        OpSchema(
+            op=op_schema.op,
+            args_schema=(values_spec, all_indices_spec),
+            kwargs_schema=op_schema.kwargs_schema,
+        )
+    )
+    if result.redistribute_schema:
+        schema_suggestion = result.redistribute_schema
+        result.redistribute_schema = OpSchema(
+            op=op_schema.op,
+            args_schema=(
+                schema_suggestion.args_schema[0],
+                dim,
+                schema_suggestion.args_schema[1][dim],  # type: ignore[index]
+            ),
+            kwargs_schema=op_schema.kwargs_schema,
+        )
+    return result
+
+
+@register_op_strategy(
+    [
+        aten.index_put.default,
+        aten._index_put_impl_.default,
+    ],
+    schema_info=RuntimeSchemaInfo(needs_pytree=True),
+)
+def prop_index_put(op_schema: OpSchema) -> StrategyType:
+    # We have 3 DTensor spec from argument `in`, `indices` and `values`
+    # accordingly.
+    in_spec, indices_spec, values_spec, *_ = op_schema.args_schema
+    if not isinstance(in_spec, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(in_spec)}")
+    # `indices`` is a tuple of scalar LongTensor, so we use TupleStrategy.
+    if not isinstance(indices_spec, TupleStrategy):
+        raise AssertionError(f"Expected TupleStrategy, got {type(indices_spec)}")
+    if not isinstance(values_spec, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(values_spec)}")
+    mesh = values_spec.mesh
+    op_strategy = OpStrategy([])
+    # 1. `indices` should all be replicated first.
+    indices_redistribute_costs = []
+    new_indices_spec: list[DTensorSpec | None] = []
+    for indices_spec_child in indices_spec.children:
+        if not isinstance(indices_spec_child, OpStrategy):
+            raise AssertionError(f"Expected OpStrategy, got {type(indices_spec_child)}")
+
+        replicated_spec = DTensorSpec(
+            mesh=mesh,
+            placements=tuple([Replicate()] * mesh.ndim),
+            tensor_meta=indices_spec_child.strategies[0].output_spec.tensor_meta,
+        )
+        new_indices_spec.append(replicated_spec)
+        child_costs = generate_redistribute_costs(indices_spec_child, replicated_spec)
+        indices_redistribute_costs.append(child_costs)
+
+    # 2. For placement rule of `values` and `in`, assume `values` shape =
+    # [a,b,c,d,e,f], `in` shape = [d,e,f]. Then `values`'s a,b,c (selected dim)
+    # must be replicated and d,e,f (nonselected dim) in both `values` and `in`
+    # should follow the same sharding (replicate or shard, but not partial).
+    size_offset = (
+        in_spec.strategies[0].output_spec.ndim
+        - values_spec.strategies[0].output_spec.ndim
+    )
+    # We can either let `values` follow `in`'s placements or reverse.
+    for exemplar_spec in [in_spec, values_spec]:
+        # use exemplar_spec as the target spec
+        for strategy in exemplar_spec.strategies:
+            in_spec_new_placements: list[Placement] = []
+            values_spec_new_placements: list[Placement] = []
+            placements = strategy.output_spec.placements
+            for placement in placements:
+                if placement.is_shard():
+                    if not isinstance(placement, Shard):
+                        raise AssertionError(f"Expected Shard, got {type(placement)}")
+                    if exemplar_spec is in_spec:
+                        # let `values_spce` follow `in_spec`
+                        if placement.dim < size_offset:
+                            # sharded on selected dim, need to change to replicate
+                            in_spec_new_placements.append(Replicate())
+                            values_spec_new_placements.append(Replicate())
+                        else:
+                            in_spec_new_placements.append(placement)
+                            values_spec_new_placements.append(
+                                Shard(placement.dim - size_offset)
+                            )
+                    else:
+                        # let `in_spec` follow `values_spec`
+                        in_spec_new_placements.append(
+                            Shard(placement.dim + size_offset)
+                        )
+                        values_spec_new_placements.append(placement)
+                else:
+                    in_spec_new_placements.append(Replicate())
+                    values_spec_new_placements.append(Replicate())
+            new_in_spec = DTensorSpec(
+                mesh=mesh,
+                placements=tuple(in_spec_new_placements),
+                tensor_meta=in_spec.strategies[0].output_spec.tensor_meta,
+            )
+            new_values_spec = DTensorSpec(
+                mesh=mesh,
+                placements=tuple(values_spec_new_placements),
+                tensor_meta=values_spec.strategies[0].output_spec.tensor_meta,
+            )
+            output_spec = DTensorSpec(
+                mesh=mesh,
+                placements=tuple(in_spec_new_placements),
+                tensor_meta=in_spec.strategies[0].output_spec.tensor_meta,
+            )
+            cost_in_spec = generate_redistribute_costs(in_spec, new_in_spec)
+            cost_values_spec = generate_redistribute_costs(values_spec, new_values_spec)
+            op_strategy.strategies.append(
+                OpSpec(
+                    input_specs=(
+                        new_in_spec,
+                        *new_indices_spec,  # type: ignore[arg-type]
+                        new_values_spec,
+                    ),
+                    output_specs=output_spec,
+                    redistribute_cost=[
+                        cost_in_spec,
+                        *indices_redistribute_costs,
+                        cost_values_spec,
+                    ],
+                )
+            )
+    return op_strategy
+
+
+@register_prop_rule(aten.index.Tensor, schema_info=RuntimeSchemaInfo(needs_pytree=True))
+def prop_index(op_schema: OpSchema) -> OutputSharding:
+    """
+    Expect replicated on the first input; _mostly_ pointwise on the second input.
+
+    TODO: exception: when the dtype of second input is "bool", then a torch.nonzero needs to be triggered first.
+    """
+    # Current sharding constraints:
+    # For values:
+    #   1. We currently require that the dimension of values_spec be replicated or partial
+    #      if they are being indexed on.
+    #   2. Other dimensions of values_spec can remain sharded if they are so.
+    # For indices:
+    #   Indices can be either sharded or replicated. All index tensors need to be sharded
+    #   in a compatible way, following the pointwise rule (including resolving Partial
+    #   into either sharded or replicated)
+
+    values_spec, multi_indices_spec = op_schema.args_schema
+    if not isinstance(values_spec, DTensorSpec):
+        raise AssertionError(f"Expected DTensorSpec, got {type(values_spec)}")
+    if not isinstance(multi_indices_spec, list):
+        raise AssertionError(f"Expected list, got {type(multi_indices_spec)}")
+    multi_indices_spec = cast(list[DTensorSpec | None], multi_indices_spec)
+    valid_indices_spec: list[tuple[int, DTensorSpec]] = [
+        (i, a) for i, a in enumerate(multi_indices_spec) if a is not None
+    ]
+
+    # 1. All indices have to be sharded equally. Moreover, indices can be broadcast.
+    #    Here, we piggyback on the pointwise sharding rule for indices.
+    indices_out = pointwise_rule(
+        OpSchema(
+            op=op_schema.op,
+            args_schema=tuple(v[1] for v in valid_indices_spec),
+            kwargs_schema={},
+        )
+    )
+    need_reshard_on_indices = indices_out.output_spec is None
+
+    if not need_reshard_on_indices:
+        # this means that our inputs are already sharded properly and we will use that as our indices_spec
+        if not isinstance(indices_out.output_spec, DTensorSpec):
+            raise AssertionError(
+                f"Expected DTensorSpec, got {type(indices_out.output_spec)}"
+            )
+        indices_spec: DTensorSpec = indices_out.output_spec
+    else:
+        if indices_out.redistribute_schema is None:
+            raise AssertionError("redistribute_schema should not be None")
+        valid_indices_suggestion = indices_out.redistribute_schema
+        for i, v in enumerate(valid_indices_suggestion.args_spec):
+            multi_indices_spec[valid_indices_spec[i][0]] = v
+        # we'll need to call pointwise_rule again to see what's our ideal indices_spec and then
+        # use that to compute our ideal values_spec
+        indices_output_spec = pointwise_rule(valid_indices_suggestion).output_spec
+        if not isinstance(indices_output_spec, DTensorSpec):
+            raise AssertionError(
+                f"Expected DTensorSpec, got {type(indices_output_spec)}"
+            )
+        indices_spec = indices_output_spec
+
+    lookup_dims = {v[0] for v in valid_indices_spec}
+
+    need_reshard_on_values = tuple(
+        (isinstance(vp, Shard) and (vp.dim in lookup_dims or isinstance(ip, Shard)))
+        for vp, ip in zip(values_spec.placements, indices_spec.placements)
+    )
+
+    if not need_reshard_on_indices and not any(need_reshard_on_values):
+        value_placements = values_spec.placements
+
+        all_dims_consecutive = all(
+            b[0] - a[0] == 1
+            for b, a in zip(valid_indices_spec[1:], valid_indices_spec[:-1])
+        )
+        if all_dims_consecutive:
+            # if all index vectors are consecutives, insert at the dimension of the first index
+            insert_dim: int = valid_indices_spec[0][0]
+        else:
+            # else, insert on the first dimension
+            insert_dim = 0
+
+        def place(vp: Placement, ip: Placement) -> Placement:
+            if isinstance(vp, Shard):
+                return Shard(
+                    vp.dim
+                    if vp.dim < insert_dim
+                    # accounts for the offset in output dimensions
+                    else vp.dim
+                    + indices_spec.ndim
+                    - sum(1 if vp.dim > v[0] else 0 for v in valid_indices_spec)
+                )
+            if isinstance(ip, Shard):
+                return Shard(ip.dim + insert_dim)
+            # Partial or Replicated
+            return vp
+
+        value_placements = tuple(
+            place(vp, ip)
+            for vp, ip in zip(values_spec.placements, indices_spec.placements)
+        )
+        result = OutputSharding(
+            output_spec=DTensorSpec(
+                mesh=values_spec.mesh,
+                placements=value_placements,
+            )
+        )
+        return result
+    else:
+        result = OutputSharding(
+            output_spec=None,
+            redistribute_schema=OpSchema(
+                op=op_schema.op,
+                args_schema=(
+                    DTensorSpec(
+                        mesh=values_spec.mesh,
+                        placements=tuple(
+                            Replicate() if need_reshard_on_values[i] else v
+                            for i, v in enumerate(values_spec.placements)
+                        ),
+                        tensor_meta=values_spec.tensor_meta,
+                    ),
+                    multi_indices_spec,
+                ),
+                kwargs_schema=op_schema.kwargs_schema,
+            ),
+        )
+        return result
+
+
+@register_op_strategy(
+    [
+        aten.split.Tensor,
+        aten.split_with_sizes.default,
+        aten.split_with_sizes_copy.default,
+    ],
+    RuntimeSchemaInfo(1),
+)
+def split_strategy(op_schema: OpSchema) -> OpStrategy:
+    input_strategy = op_schema.args_schema[0]
+    split_size_or_sections = op_schema.args_schema[1]
+    if not isinstance(input_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}")
+    input_ndim = input_strategy.ndim
+    split_dim = (
+        cast(int, op_schema.args_schema[2]) if len(op_schema.args_schema) > 2 else 0
+    )
+    dim = normalize_dim(split_dim, input_ndim)
+
+    def size_split(N, i) -> list:
+        # Last chunk will be smaller if the tensor size N
+        # along the given dimension dim is not divisible by i.
+        if not i > 0:
+            raise AssertionError(f"Split size must be positive, got {i}")
+        return [i] * (N // i) + ([N % i] if N % i != 0 else [])
+
+    output_size_list = (
+        size_split(input_strategy.shape[dim], split_size_or_sections)
+        if isinstance(split_size_or_sections, int)
+        else split_size_or_sections
+    )
+    if not isinstance(output_size_list, Sized):
+        raise AssertionError(f"Expected Sized, got {type(output_size_list)}")
+
+    all_strategies = []
+    for strategy in input_strategy.strategies:
+        spec = strategy.output_spec
+        placements = spec.placements
+        if is_tensor_dim_sharded(spec, dim=dim):
+            # if the input is sharded on the split dim, we need to unshard it
+            placements = unshard_tensor_dim(spec.placements, dim=dim)
+
+        input_spec = DTensorSpec(spec.device_mesh, placements, spec.tensor_meta)
+        output_specs = tuple(
+            DTensorSpec(spec.device_mesh, placements)
+            for _ in range(len(output_size_list))
+        )
+        all_strategies.append(
+            OpSpec(
+                output_specs=output_specs,
+                input_specs=(input_spec,),
+                redistribute_cost=[
+                    generate_redistribute_costs(input_strategy, input_spec)
+                ],
+            )
+        )
+
+    return OpStrategy(all_strategies)
+
+
+# TODO: fix remaining failures in xfail("unbind") in test_dtensor_ops.py
+#       and remove this xfail item
+@register_op_strategy(aten.unbind.int, schema_info=RuntimeSchemaInfo(1))
+def gen_unbind_strategy(op_schema: OpSchema) -> StrategyType:
+    """Forward all shardings except the unbind dimension."""
+    input_strategy = op_schema.args_schema[0]
+    if not isinstance(input_strategy, OpStrategy):
+        raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}")
+    input_ndim = input_strategy.ndim
+    input_shape = input_strategy.shape
+    unbind_dim = (
+        cast(int, op_schema.args_schema[1]) if len(op_schema.args_schema) > 1 else 0
+    )
+    unbind_dim = normalize_dim(unbind_dim, input_ndim)
+
+    mesh = input_strategy.mesh
+    unbind_strategy = OpStrategy([])
+    for arg_strategy in input_strategy.strategies:
+        arg_spec = arg_strategy.output_spec
+        if is_tensor_dim_sharded(arg_spec, dim=unbind_dim):
+            raise RuntimeError(
+                f"Attempted to unbind along the sharded dimension {unbind_dim}. ",
+                "It cannot be performed without redistribution, which is disallowed "
+                "by the current operator.",
+            )
+        # only add the strategy if the unbind dim is not sharded
+        output_placements = shift_shard_dims_after_remove(
+            arg_spec.placements, unbind_dim
+        )
+        output_specs = tuple(
+            DTensorSpec(mesh, tuple(output_placements))
+            for _ in range(input_shape[unbind_dim])
+        )
+        unbind_strategy.strategies.append(
+            OpSpec(
+                output_specs=output_specs,
+                input_specs=(arg_spec,),
+                redistribute_cost=[[0.0] * len(input_strategy.strategies)],
+            )
+        )
+    return unbind_strategy
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_view_ops.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_view_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee157aa26df4bc3e3a76052ae0b66255b12f7617
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_view_ops.py
@@ -0,0 +1,798 @@
+# mypy: allow-untyped-defs
+# Copyright (c) Meta Platforms, Inc. and affiliates
+from collections.abc import Callable, Iterable, Sequence
+from dataclasses import dataclass
+from typing import cast, Optional
+
+import torch
+from torch import Tensor
+from torch._prims_common import DimsType
+from torch.distributed.tensor._dtensor_spec import DTensorSpec
+from torch.distributed.tensor._op_schema import (
+    OpSchema,
+    OpSpec,
+    OpStrategy,
+    RuntimeSchemaInfo,
+    StrategyType,
+)
+from torch.distributed.tensor._ops.registration import register_op_strategy
+from torch.distributed.tensor._ops.utils import (
+    generate_redistribute_costs,
+    normalize_dim,
+    normalize_dims,
+    prod,
+)
+from torch.distributed.tensor.placement_types import (
+    _StridedShard,
+    Placement,
+    Replicate,
+    Shard,
+)
+
+
+aten = torch.ops.aten
+
+Shape = tuple[int, ...]
+
+
+@dataclass
+class DimSpec:
+    """Specifies how an output dimension maps to an input dimension."""
+
+    def inputs(self) -> Iterable["DimSpec"]:
+        return ()
+
+
+# Rules that map each dimension of the output to dimensions of the input tensor
+DimMap = tuple[DimSpec, ...]
+
+
+@dataclass
+class Singleton(DimSpec):
+    """Output dimension is a singleton."""
+
+
+@dataclass
+class InputDim(DimSpec):
+    """Output dimension maps directly to an input dimension."""
+
+    input_dim: int
+
+
+@dataclass
+class Broadcast(DimSpec):
+    """Output is the broadcast of a singleton input dimension."""
+
+    dim: DimSpec
+    dim_size: int
+
+    @classmethod
+    def new(cls, dim: DimSpec, dim_size: int) -> DimSpec:
+        return Broadcast(dim, dim_size)
+
+    def inputs(self) -> Iterable[DimSpec]:
+        return (self.dim,)
+
+
+@dataclass
+class NewDim(DimSpec):
+    """This is a new dimension created by the op."""
+
+    size: int
+
+    @classmethod
+    def new(cls, size: int) -> DimSpec:
+        return Singleton() if size == 1 else NewDim(size)
+
+
+@dataclass
+class Repeat(DimSpec):
+    """Output dimension is the input dimension repeated n-times."""
+
+    input_dim: DimSpec
+    times: int
+
+    @classmethod
+    def new(cls, dim: DimSpec, times: int) -> DimSpec:
+        if times == 1:
+            return dim
+        elif isinstance(dim, Singleton):
+            # repeating a singleton is the same as broadcasting it
+            return Broadcast(dim, times)
+        else:
+            return Repeat(dim, times)
+
+    def inputs(self) -> Iterable[DimSpec]:
+        return (self.input_dim,)
+
+
+@dataclass
+class Flatten(DimSpec):
+    """Flatten a set of input dimensions, ensuring right-most adjacent elements remain adjacent in the output."""
+
+    input_dims: Sequence[DimSpec]
+
+    @classmethod
+    def new(cls, dims: Sequence[DimSpec]) -> DimSpec:
+        if len(dims) == 0:
+            # flattening a scalar leads to a singleton
+            return Singleton()
+        elif len(dims) == 1:
+            # flattening a single dimension is no-op
+            return dims[0]
+        else:
+            return Flatten(dims)
+
+    def inputs(self) -> Iterable[DimSpec]:
+        return self.input_dims
+
+
+@dataclass
+class Split(DimSpec):
+    """
+    This dimension is a member of a decomposition of the input dim.
+
+    Note that input_dim itself could be a Flattened set of input dims.
+    """
+
+    input_dim: DimSpec
+    group_shape: Shape
+    split_id: int
+
+    @classmethod
+    def new(cls, dim: DimSpec, group_shape: tuple[int, ...], idx: int) -> DimSpec:
+        if not len(group_shape) > 0:
+            raise AssertionError(
+                f"Expected group_shape length > 0, got {len(group_shape)}"
+            )
+        if len(group_shape) == 1:
+            # not really a group, just return the input dim back
+            if not idx == 0:
+                raise AssertionError(f"Expected idx == 0, got {idx}")
+            return dim
+        elif group_shape[idx] == 1:
+            return Singleton()
+        else:
+            # remove singletons from group
+            # group_mapping = [(new_index, (shape, old_index)) ...]
+            group_mapping = list(
+                enumerate((s, i) for i, s in enumerate(group_shape) if s != 1)
+            )
+            new_group_shape = tuple(m[1][0] for m in group_mapping)
+            new_idx = next(filter(lambda x: x[1][1] == idx, group_mapping))[0]
+            return Split(dim, new_group_shape, new_idx)
+
+    def inputs(self) -> Iterable[DimSpec]:
+        return (self.input_dim,)
+
+
+def dim_pad_left(ndim: int, min_dims: int) -> DimMap:
+    return (Singleton(),) * max(0, min_dims - ndim) + tuple(
+        InputDim(i) for i in range(ndim)
+    )
+
+
+def dim_atleast_3d(ndim: int) -> DimMap:
+    if ndim == 0:
+        return (Singleton(), Singleton(), Singleton())
+    elif ndim == 1:
+        return (Singleton(), InputDim(0), Singleton())
+    elif ndim == 2:
+        return (InputDim(0), InputDim(1), Singleton())
+    else:
+        return tuple(InputDim(i) for i in range(ndim))
+
+
+def expand(input_shape: Shape, shape: Shape) -> DimMap:
+    """Implement broadcast on multiple dimensions."""
+    if not len(shape) >= len(input_shape):
+        raise AssertionError(
+            f"Expected len(shape) >= len(input_shape), got {len(shape)} < {len(input_shape)}"
+        )
+
+    # 1. create padded input dimensions
+    padded_input = dim_pad_left(len(input_shape), len(shape))
+    # 2. check that input shapes are compatible
+    mapping = []
+    for p, desired_s in zip(padded_input, shape):
+        if isinstance(p, Singleton):
+            actual_s = 1
+            if not desired_s >= 0:
+                raise AssertionError(f"Expected desired_s >= 0, got {desired_s}")
+        else:
+            if not isinstance(p, InputDim):
+                raise AssertionError(f"DimSpec not supported in expand: {p}")
+            actual_s = input_shape[p.input_dim]
+            if not (actual_s == 1 or desired_s == -1 or desired_s == actual_s):
+                raise AssertionError(
+                    f"Expected actual_s == 1 or desired_s == -1 or "
+                    f"desired_s == actual_s, got actual_s={actual_s}, desired_s={desired_s}"
+                )
+        mapping.append(
+            p
+            if desired_s in (1, -1) or desired_s == actual_s
+            else Broadcast.new(p, desired_s)
+        )
+    return tuple(mapping)
+
+
+def normalize_sizes(sizes: Shape | tuple[Shape]) -> Shape:
+    if isinstance(sizes[0], int):
+        return cast(Shape, sizes)
+    elif len(sizes) == 1:
+        return sizes[0]
+    else:
+        raise RuntimeError("Size must be int... or tuple")
+
+
+def dim_flatten(ndim: int, start_dim=0, end_dim=-1) -> DimMap:
+    if ndim == 0:
+        return (Singleton(),)
+    elif ndim == 1:
+        return (InputDim(0),)
+    else:
+        # only flattening dims from start_dim to end_dim (inclusive)
+        # other dims are passed through
+        if end_dim < 0:
+            end_dim += ndim
+        results: list[DimSpec] = [InputDim(i) for i in range(start_dim)]
+        results.append(
+            Flatten.new(tuple(InputDim(i) for i in range(start_dim, end_dim + 1)))
+        )
+        results.extend([InputDim(i) for i in range(end_dim + 1, ndim)])
+        return tuple(results)
+
+
+def dim_movedim(
+    ndim: int,
+    input: DimsType,
+    destination: DimsType,
+) -> DimMap:
+    input = normalize_dims(input, ndim)
+    destination = normalize_dims(destination, ndim)
+
+    if not len(input) == len(destination):
+        raise AssertionError(
+            f"Expected len(input) == len(destination), got {len(input)} != {len(destination)}"
+        )
+    input_set = set(input)
+    if not len(input_set) == len(input):
+        raise AssertionError("Found repeated input dims")
+    if not len(set(destination)) == len(destination):
+        raise AssertionError("Found repeated output dims")
+    if not max(input) < ndim:
+        raise AssertionError(f"Expected max(input) < ndim, got {max(input)} >= {ndim}")
+    if not max(destination) < ndim:
+        raise AssertionError(
+            f"Expected max(destination) < ndim, got {max(destination)} >= {ndim}"
+        )
+
+    dest = [-1] * ndim
+    for i, d in zip(input, destination):
+        dest[d] = i
+
+    unused_inputs_iter = iter(i for i in range(ndim) if i not in input_set)
+    for i in range(ndim):
+        if dest[i] == -1:
+            dest[i] = next(unused_inputs_iter)
+
+    return tuple(InputDim(i) for i in dest)
+
+
+def dim_repeat(ndim: int, sizes: Shape) -> DimMap:
+    sizes = normalize_sizes(sizes)
+    if not len(sizes) >= ndim:
+        raise AssertionError(
+            f"Number of dimensions of repeat dims {sizes} can not be smaller than number of dimensions of tensor {ndim}."
+        )
+    pad = len(sizes) - ndim
+    return tuple(Repeat.new(Singleton(), s) for s in sizes[:pad]) + tuple(
+        Repeat.new(InputDim(i), s) for i, s in enumerate(sizes[pad:])
+    )
+
+
+def infer_size(total_size: int, sizes: Shape) -> Shape:
+    """
+    One dimension input to view may be "-1".
+
+    Infer the size of this dimension given the total_size.
+    """
+    infers = [i for i, s in enumerate(sizes) if s == -1]
+    size = prod(sizes)
+    if not len(infers) <= 1:
+        raise AssertionError("can only infer one size")
+    if infers:
+        size = -size
+        missing_size = total_size // size
+        if not total_size % size == 0:
+            raise AssertionError(
+                f"size inferred for -1 is not integral {sizes} should have {total_size} elements."
+            )
+        return tuple(s if s != -1 else missing_size for s in sizes)
+    if not size == total_size:
+        raise AssertionError(f"sizes do not match {total_size} vs {size}")
+    return sizes
+
+
+def view_groups(from_size: Shape, to_size: Shape) -> DimMap:
+    """
+    Decompose a reshape operation into forwarding, flattening, or splitting dimensions for each output dimension.
+
+    A view or reshape operation can be decomposed into a set of 3 types of smaller operations:
+    1) Forward a dimension from input to output
+    2) Flatten a set of dimensions into a single dimension
+    3) Split one dimension into multiple dimensions
+
+    view_groups identifies these operations and returns, for each output dimension, what
+    is operation was performed in the input dimension. For example:
+
+        view_groups([2, 3, 4], [2, 12]) -> (
+            InputDim(0),
+            Flatten((InputDim(1), InputDim(2)))
+        )
+
+    - output dimension 0 maps to input dimension 0
+    - output dimension 1 maps to a flattened input dimensions 1 and 2
+
+
+        view_groups([2, 3], [3, 2]) -> (
+            Split(Flatten((InputDim(0), InputDim(1))), (3, 2), 0),
+            Split(Flatten((InputDim(0), InputDim(1))), (3, 2), 1),
+        )
+
+    - in the above, input is flattened into a single dimension and then split
+      into two separate dimensions with different sizes from the input.
+    """
+    from_nelem = prod(from_size)
+    to_size = infer_size(from_nelem, normalize_sizes(to_size))
+
+    if not from_nelem == prod(to_size):
+        raise AssertionError("Total view shape does not add up")
+
+    from_idx = 0
+    to_idx = 0
+    from_len = len(from_size)
+    to_len = len(to_size)
+
+    result_pp = []
+
+    while from_idx < from_len or to_idx < to_len:
+        from_group_dim, to_group_shape = [], []
+
+        if from_idx >= from_len:
+            f = 1
+        else:
+            f = from_size[from_idx]
+            from_group_dim.append(from_idx)
+            from_idx += 1
+
+        if to_idx >= to_len:
+            t = 1
+        else:
+            t = to_size[to_idx]
+            to_group_shape.append(t)
+            to_idx += 1
+
+        # if any of the groups is singleton, great, we need to backtrack though
+        if f == 1 and t != 1:
+            # produces ([1], [])
+            to_idx -= 1
+            to_group_shape = []
+        elif f != 1 and t == 1:
+            # produces ([], [1])
+            from_idx -= 1
+            from_group_dim = []
+        else:
+            # produces ([1], [1]),  ([2], [2]), ([2,3], [6])
+            while f != t:
+                if f < t:
+                    nf = from_size[from_idx]
+                    from_group_dim.append(from_idx)
+                    from_idx += 1
+                    f *= nf
+                else:
+                    nt = to_size[to_idx]
+                    to_group_shape.append(nt)
+                    to_idx += 1
+                    t *= nt
+
+        if len(to_group_shape) > 0:
+            flattened = Flatten.new(
+                tuple(InputDim(fi) for fi in from_group_dim if from_size[fi] >= 1)
+            )
+            result_pp += [
+                Split.new(flattened, tuple(to_group_shape), i)
+                for i in range(len(to_group_shape))
+            ]
+
+    return tuple(result_pp)
+
+
+def dim_tile(ndim: int, dims: tuple[int, ...]) -> DimMap:
+    if len(dims) < ndim:
+        dims = (1,) * (ndim - len(dims)) + dims
+    return dim_repeat(ndim, dims)
+
+
+def dim_transpose(ndim: int, dim1: int, dim2: int) -> DimMap:
+    dim1 = normalize_dim(dim1, ndim)
+    dim2 = normalize_dim(dim2, ndim)
+    if not dim1 < ndim:
+        raise AssertionError(f"Expected dim1 < ndim, got {dim1} >= {ndim}")
+    if not dim2 < ndim:
+        raise AssertionError(f"Expected dim2 < ndim, got {dim2} >= {ndim}")
+    dimmap = [InputDim(i) for i in range(ndim)]
+    swapdim = dimmap[dim1]
+    dimmap[dim1] = dimmap[dim2]
+    dimmap[dim2] = swapdim
+    return tuple(dimmap)
+
+
+def dim_squeeze(shape: Shape, dim: int | None = None) -> DimMap:
+    # FIXME: this is wrong when dim=None and one of the dimensions
+    # equals size of the mesh. For example squeeze(DTensor(tensor(4), Shard[0])) could
+    # end up as squeeze(tensor(1)) if we have 4 devices; this would lead to
+    # removal of a dimension that is not actually a singleton.
+    return tuple(
+        InputDim(i)
+        for i, s in enumerate(shape)
+        if s > 1 or (dim is not None and i != normalize_dim(dim, len(shape)))
+    )
+
+
+def dim_unsqueeze(ndim: int, dim: int) -> DimMap:
+    dims = tuple(InputDim(i) for i in range(ndim))
+    if dim < 0:
+        dim += ndim + 1
+    return dims[:dim] + (Singleton(),) + dims[dim:]
+
+
+def dim_view_as_real(shape: Shape) -> DimMap:
+    ndim = len(shape)
+    results: list[DimSpec] = [InputDim(i) for i in range(ndim - 1)]
+    # each complex number is split into two real numbers,
+    # resulting in one more dimension of size 2
+    results.append(Split(InputDim(ndim - 1), (shape[-1], 2), 0))
+    results.append(Split(InputDim(ndim - 1), (shape[-1], 2), 1))
+    return tuple(results)
+
+
+def dim_reduction(ndim: int, dim_or_dims: DimsType | None, keepdim: bool) -> DimMap:
+    """
+    General fallback for reduction ops where Partial() does not apply.
+
+    This will cause incoming tensor to be replicated on the reducing dimensions.
+    """
+    if dim_or_dims is None:
+        dim_or_dims = tuple(range(ndim))
+    if isinstance(dim_or_dims, int):
+        dim_or_dims = (dim_or_dims,)
+    dim_or_dims = tuple(d if d >= 0 else d + ndim for d in dim_or_dims)
+    return tuple(
+        InputDim(i) if i not in dim_or_dims else Singleton()
+        for i in range(ndim)
+        if i not in dim_or_dims or keepdim
+    )
+
+
+dim_maps: dict[Callable[..., torch.Tensor], Callable[..., DimMap]] = {
+    torch.atleast_1d: lambda x: dim_pad_left(x.ndim, 1),
+    torch.atleast_2d: lambda x: dim_pad_left(x.ndim, 2),
+    torch.atleast_3d: lambda x: dim_atleast_3d(x.ndim),
+    torch.broadcast_to: lambda input, shape: expand(input.shape, shape),
+    Tensor.expand: lambda self, *sizes: expand(self.shape, normalize_sizes(sizes)),
+    torch.flatten: lambda tensor: dim_flatten(tensor.ndim),
+    torch.movedim: lambda input, source, destination: dim_movedim(
+        input.ndim, source, destination
+    ),
+    torch.permute: lambda input, dims: tuple(
+        InputDim(i) for i in normalize_dims(dims, input.ndim)
+    ),
+    torch.ravel: lambda tensor: dim_flatten(tensor.ndim),
+    Tensor.repeat: lambda self, *sizes: dim_repeat(self.ndim, sizes),
+    torch.reshape: lambda input, shape: view_groups(input.shape, shape),
+    torch.squeeze: lambda input, dim=None: dim_squeeze(input.shape, dim),
+    torch.tile: lambda input, dims: dim_tile(input.ndim, dims),
+    torch.transpose: lambda input, dim0, dim1: dim_transpose(input.ndim, dim0, dim1),
+    torch.unsqueeze: lambda input, dim: dim_unsqueeze(input.ndim, dim),
+    Tensor.view: lambda input, *shape: view_groups(input.shape, shape),
+    torch.view_as_complex: lambda input: dim_flatten(input.ndim, input.ndim - 2),
+    torch.view_as_real: lambda input: dim_view_as_real(input.shape),
+}
+
+
+def propagate_shape_and_sharding(
+    input_src_placements: Sequence[Placement],
+    global_input_shape: Shape,
+    rule: DimMap,
+    mesh_sizes: Shape,
+    strict_view: bool = False,
+) -> tuple[Sequence[Placement], Sequence[Placement]]:
+    """
+    Determine input target sharding and output sharding based on
+    given global tensor shape and input source sharding.
+
+    Sharding propagation follows mapped dimensions:
+    - An output dimension that maps directly to an input dimension is sharded equally
+    - An output dimension that is a flattened set of input dimensions can only be
+      sharded if only the leftmost flattened dimension is sharded.
+    - An output dimension that is a split of the input dimension can only be sharded
+      if the leftmost split size is divisible by the mesh dimension
+    """
+    if not len(input_src_placements) == len(mesh_sizes):
+        raise AssertionError(f"{input_src_placements} != {mesh_sizes}")
+    # for each input dim, for each mesh dim, provides a list of possible shardable dimensions
+    mesh_ndim = len(mesh_sizes)
+    shardable_dims: dict[int, list[bool]] = {}
+
+    # in case an input dimension disappears (e.g. collapsing, reduction)
+    # we cannot shard in that dimension (we need a replication fall-back rule)
+    seen_input_dims: set[int] = set()
+
+    def collect_used_inputs(cmd: DimSpec) -> None:
+        if isinstance(cmd, InputDim):
+            seen_input_dims.add(cmd.input_dim)
+        for inp in cmd.inputs():
+            collect_used_inputs(inp)
+
+    for cmd in rule:
+        collect_used_inputs(cmd)
+    for dim in range(len(global_input_shape)):
+        shardable_dims[dim] = [dim in seen_input_dims] * mesh_ndim
+
+    def maybe_get_shard_mesh_dim_and_placement(
+        input_dim: InputDim,
+    ) -> tuple[Optional[int], Optional[Shard | _StridedShard]]:
+        # if input_dim is sharded, return the mesh_dim and shard placement
+        for i, placement in enumerate(input_src_placements):
+            if (
+                isinstance(placement, Shard | _StridedShard)
+                and placement.dim == input_dim.input_dim
+            ):
+                return i, placement
+        return None, None
+
+    # NOTE: This function has three responsibilities:
+    # 1. determine "theoretically" if an output dimension can be sharded, i.e. fill the shardable_dims map
+    # 2. determine "theoretically" the corresponding input dimension to shard on, via return value
+    # 3. throw an error when strict_view is enabled and we cannot shard an output dimension
+    # 1 and 2 doesn't require the info of whether current input is sharded.
+    # 3 requires that info, to decide whether we can error out. Maybe we can refactor
+    # to make this function purely "theoretical".
+    def get_in_dim_to_shard(cmd: DimSpec) -> InputDim | None:
+        if isinstance(cmd, InputDim):
+            return cmd
+        elif isinstance(cmd, Flatten):
+            for i, dim in enumerate(cmd.input_dims):
+                # so far all Flatten is always composed of InputDims; revisit this if needed
+                if not isinstance(dim, InputDim):
+                    raise AssertionError(f"Expected InputDim, got {type(dim)}")
+                can_shard_dim = True
+                shard_mesh_dim, shard_placement = (
+                    maybe_get_shard_mesh_dim_and_placement(dim)
+                )
+                input_sharded = shard_mesh_dim is not None
+                if i > 0:
+                    can_shard_dim = False
+                    if strict_view and input_sharded:
+                        raise RuntimeError(
+                            f"Attempted to flatten multiple dimensions, with dimension {dim.input_dim} being sharded. ",
+                            "It cannot be performed without redistribution, which is disallowed by the current operator.",
+                        )
+                elif input_sharded:
+                    if not (shard_placement is not None and shard_mesh_dim is not None):
+                        raise AssertionError(
+                            "Expected shard_placement and shard_mesh_dim to be not None"
+                        )
+                    tensor_dim_size = global_input_shape[shard_placement.dim]
+                    mesh_dim_size = mesh_sizes[shard_mesh_dim]
+                    if tensor_dim_size % mesh_dim_size != 0:
+                        can_shard_dim = False
+                        if strict_view:
+                            raise RuntimeError(
+                                f"Attempted to flatten unevenly sharded dimension {i}, "
+                                "which would require resharding the input. "
+                                "Please explicitly redistribute the tensor instead."
+                            )
+                shardable_dims[dim.input_dim] = [can_shard_dim] * mesh_ndim
+
+            if not isinstance(cmd.input_dims[0], InputDim):
+                raise AssertionError(
+                    f"Expected InputDim, got {type(cmd.input_dims[0])}"
+                )
+            return cmd.input_dims[0]
+        elif isinstance(cmd, Split):
+            in_dim = get_in_dim_to_shard(cmd.input_dim)
+            out_size = cmd.group_shape[cmd.split_id]
+            if cmd.split_id == 0 and in_dim is not None:
+                # we need to check that the input dimension is divisible
+                # by the size of the submesh we're sharding it on
+                # NOTE: it would be possible to shard the same input dimension
+                # on more than one mesh dimension. In that case, the dimension
+                # needs to be divisible by the product of mesh sizes.
+                # In order to keep the problem more tractable, we will not consider
+                # double resharding as a suggestion (e.g. [Shard(0), Shard(0) ])
+                # but we will allow it if that's the input and it's compatible
+
+                # 1. is this dimension shardable on each individual mesh dim?
+                shardable_dims[in_dim.input_dim] = [
+                    out_size % mesh_dim_size == 0 for mesh_dim_size in mesh_sizes
+                ]
+
+                shard_mesh_dim, _ = maybe_get_shard_mesh_dim_and_placement(in_dim)
+                if strict_view and shard_mesh_dim is not None:
+                    if not shardable_dims[in_dim.input_dim][shard_mesh_dim]:
+                        raise RuntimeError(
+                            f"Attempted to split the sharded dimension {in_dim.input_dim} into multiple subdimensions. ",
+                            "It cannot be performed without redistribution, which is disallowed by the current operator.",
+                        )
+
+                # 2. here we special case things like [Shard(0), Shard(0)]
+                submesh_size = 1
+                for size, shard in zip(mesh_sizes, input_src_placements):
+                    if isinstance(shard, Shard | _StridedShard) and shard.dim == in_dim:
+                        submesh_size *= size
+                if not out_size % submesh_size == 0:
+                    raise AssertionError(
+                        f"Resulting dimension size {out_size} is not divisible by its mesh dimension {submesh_size}."
+                    )
+
+            # we will only shard our first component of the split
+            return in_dim if cmd.split_id == 0 else None
+        elif isinstance(cmd, Repeat):
+            in_dim = get_in_dim_to_shard(cmd.input_dim)
+            if in_dim is not None:
+                shardable_dims[in_dim.input_dim] = [False] * mesh_ndim
+            return None
+        else:
+            return None
+
+    # for each output dim, find the corresponding input dim in terms of sharding prop
+    shard_dim_map = {}
+    for dim, cmd in enumerate(rule):
+        in_dim = get_in_dim_to_shard(cmd)
+        if in_dim is not None:
+            shard_dim_map[in_dim.input_dim] = dim
+
+    input_tgt_placements = [
+        (
+            Replicate()
+            if isinstance(p, Shard | _StridedShard)
+            and not shardable_dims[p.dim][mesh_dim]
+            else p
+        )
+        for mesh_dim, p in enumerate(input_src_placements)
+    ]
+
+    def _rewrite_shard_dim(p: Shard | _StridedShard):
+        """
+        Rewrite the shard dim to the corresponding tensor dim in output.
+        For ``_StridedShard``, we can safely keep the placement type and
+        ``split_factor`` unchanged and only rewrite the ``dim`` because:
+        1. ``_StridedShard`` has no impact on sharding (i.e. how
+            tensor is partitioned) compared to ``Shard``. It only changes
+            how shards permute across the devices.
+        2. ``view()`` op on DTensor strictly forbids shard redistribution
+            which means if ``view()`` may cause shard permutation across
+            devices, it should be rejected. This is enforced in today's
+            sharding prop for ``view()``.
+        3. Since DTensor ``view()`` won't introduce any redistribution,
+            it's certain that ``placements`` won't change except the
+            inner ``dim`` attribute of ``Shard`` or ``_StridedShard``.
+        """
+        if isinstance(p, _StridedShard):
+            return _StridedShard(shard_dim_map[p.dim], split_factor=p.split_factor)
+        else:
+            return Shard(shard_dim_map[p.dim])
+
+    output_placements = [
+        _rewrite_shard_dim(p) if isinstance(p, Shard | _StridedShard) else p
+        for p in input_tgt_placements
+    ]
+
+    return input_tgt_placements, output_placements
+
+
+def register_op_strategy_map(
+    aten_op_overload: torch._ops.OpOverload,
+    local_op_name: Callable[..., torch.Tensor],
+    schema_info: RuntimeSchemaInfo | None = None,
+    strict_view: bool = False,
+) -> None:
+    """
+    Helper that registers strategies for view-like operators that follow a pattern:
+      (1) define the way input dims are split/combined to form output dims (dim_maps)
+      (2) register a strategy for the op schema that uses the dim_map as a sharding prop rule
+
+    strict_view: if True, we will error out if the view-operation would require resharding the input.
+       Currently, this should be set to 'true' for any "view" ops.
+       We could diverge behavior for "reshape" ops which could perform a redistribute implicitly.
+    """
+    dim_map: Callable[..., DimMap] = dim_maps[local_op_name]
+
+    @register_op_strategy(aten_op_overload, schema_info=schema_info)
+    def reshape_strategy(op_schema: OpSchema) -> StrategyType:
+        rules = dim_map(*op_schema.args_schema, **op_schema.kwargs_schema)
+        input_strategy = cast(OpStrategy, op_schema.args_schema[0])
+        mesh = op_schema.get_mesh_from_args(validate=False)
+
+        global_in_shape = input_strategy.shape
+        if global_in_shape is None:
+            raise AssertionError("Shape required.")
+
+        output_strategy = OpStrategy([])
+        for input_placement_strategy in input_strategy.strategies:
+            input_src_spec = input_placement_strategy.output_spec
+
+            input_tgt_placements, output_placements = propagate_shape_and_sharding(
+                input_src_spec.placements,
+                tuple(global_in_shape),
+                rules,
+                mesh.shape,
+                strict_view,
+            )
+
+            # TODO: optimize this. we shouldn't simply blindly replicate
+            #       unshardable dims ...
+            # FIXME: this can be wrong for situations where we have
+            #        [Shard(0), Shard(0)]
+            input_tgt_spec = DTensorSpec(
+                placements=tuple(input_tgt_placements),
+                mesh=mesh,
+                tensor_meta=input_src_spec.tensor_meta,
+            )
+            redistribute_costs: list[list[float]] = [
+                generate_redistribute_costs(input_strategy, input_tgt_spec)
+            ]
+
+            output_spec = DTensorSpec(mesh=mesh, placements=tuple(output_placements))
+            output_strategy.strategies.append(
+                OpSpec(
+                    output_specs=output_spec,
+                    input_specs=(input_tgt_spec,),
+                    redistribute_cost=redistribute_costs,
+                )
+            )
+
+        return output_strategy
+
+
+register_op_strategy_map(aten.squeeze.default, torch.squeeze)
+register_op_strategy_map(
+    aten.squeeze_.dim, torch.squeeze, schema_info=RuntimeSchemaInfo(1)
+)
+register_op_strategy_map(
+    aten.squeeze.dim, torch.squeeze, schema_info=RuntimeSchemaInfo(1)
+)
+register_op_strategy_map(
+    aten.view.default,
+    Tensor.view,
+    schema_info=RuntimeSchemaInfo(1),
+    strict_view=True,
+)
+register_op_strategy_map(
+    aten.reshape.default, torch.reshape, schema_info=RuntimeSchemaInfo(1)
+)
+register_op_strategy_map(
+    aten._unsafe_view.default,
+    Tensor.view,
+    schema_info=RuntimeSchemaInfo(1),
+    strict_view=True,
+)
+register_op_strategy_map(
+    aten.unsqueeze.default, torch.unsqueeze, schema_info=RuntimeSchemaInfo(1)
+)
+register_op_strategy_map(
+    aten.expand.default, Tensor.expand, schema_info=RuntimeSchemaInfo(1)
+)
+register_op_strategy_map(
+    aten.permute.default, torch.permute, schema_info=RuntimeSchemaInfo(1)
+)
+register_op_strategy_map(
+    aten.repeat.default, Tensor.repeat, schema_info=RuntimeSchemaInfo(1)
+)
+register_op_strategy_map(
+    aten.transpose.int, torch.transpose, schema_info=RuntimeSchemaInfo(1)
+)
+register_op_strategy_map(aten.view_as_complex.default, torch.view_as_complex)
+register_op_strategy_map(aten.view_as_real.default, torch.view_as_real)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/registration.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/registration.py
new file mode 100644
index 0000000000000000000000000000000000000000..98ec79d101591864f34025c3249db8f060654154
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/registration.py
@@ -0,0 +1,83 @@
+#  Copyright (c) Meta Platforms, Inc. and affiliates
+from collections.abc import Callable
+from typing import TypeAlias, TypeVar
+
+import torch
+from torch.distributed.tensor._api import DTensor
+from torch.distributed.tensor._op_schema import (
+    OpSchema,
+    OutputSharding,
+    RuntimeSchemaInfo,
+    StrategyType,
+)
+
+
+# convenient wrapper to register sharding propagation rules
+def register_prop_rule(
+    op: torch._ops.OpOverload | list[torch._ops.OpOverload],
+    schema_info: RuntimeSchemaInfo | None = None,
+) -> Callable[
+    [Callable[[OpSchema], OutputSharding]], Callable[[OpSchema], OutputSharding]
+]:
+    def wrapper(
+        impl: Callable[[OpSchema], OutputSharding],
+    ) -> Callable[[OpSchema], OutputSharding]:
+        overloads = op if isinstance(op, list) else [op]
+        for overload in overloads:
+            DTensor._op_dispatcher.sharding_propagator.register_sharding_prop_rule(
+                overload, impl, schema_info
+            )
+        return impl
+
+    return wrapper
+
+
+# Note:
+# using TypeVar here allows the registration decorator to preserve the specific type info of the wrapped strategy,
+# while hardcoding the typing on the wrapper (e.g. Callable[[OpSchema], StrategyType]) would mean mypy would treat
+# the return value of the wrapped strategy as always being a `StrategyType` even if it were a derived class like
+# MyStrategyType(StrategyType).
+_OpSchemaT = TypeVar("_OpSchemaT", bound=OpSchema)
+_StrategyTypeT = TypeVar("_StrategyTypeT", bound=StrategyType)
+_ShardingStrategyFunc: TypeAlias = Callable[[_OpSchemaT], _StrategyTypeT]
+
+
+def register_op_strategy(
+    op: torch._ops.OpOverload | list[torch._ops.OpOverload],
+    schema_info: RuntimeSchemaInfo | None = None,
+) -> Callable[[_ShardingStrategyFunc], _ShardingStrategyFunc]:
+    # For every ATen op that accepts any args in this list,
+    # the arg itself can impact the strides (and potentially the sharding strategy)
+    # of the output tensor.
+    # thus, we will detect ATen schemas with any of these args and ensure
+    # that they get specialized here.
+    arg_names_that_require_specializing_cache_strategy = [
+        "memory_format",
+    ]
+
+    def wrapper(impl: _ShardingStrategyFunc) -> _ShardingStrategyFunc:
+        if isinstance(op, list):
+            overloads = op
+        else:
+            overloads = [op]
+
+        for overload in overloads:
+            curr_schema_info = None
+            if schema_info is None:
+                specialized_args = [
+                    a.name
+                    for a in overload._schema.arguments
+                    if a.name in arg_names_that_require_specializing_cache_strategy
+                ]
+                if any(specialized_args):
+                    curr_schema_info = RuntimeSchemaInfo(
+                        static_kwargkey=specialized_args
+                    )
+            else:
+                curr_schema_info = schema_info
+            DTensor._op_dispatcher.sharding_propagator.register_op_strategy(
+                overload, impl, curr_schema_info
+            )
+        return impl
+
+    return wrapper
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2022c214298f26df41503988fa8684ab20ca3bf
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/utils.py
@@ -0,0 +1,388 @@
+# mypy: allow-untyped-defs
+# Copyright (c) Meta Platforms, Inc. and affiliates
+import functools
+import itertools
+import operator
+from collections.abc import Callable, Iterable, Sequence
+from typing import cast
+
+import torch
+from torch._prims_common import DimsSequenceType, DimsType
+from torch.distributed.tensor._collective_utils import redistribute_cost
+from torch.distributed.tensor._dtensor_spec import DTensorSpec
+from torch.distributed.tensor._op_schema import (
+    OpSchema,
+    OpSpec,
+    OpStrategy,
+    PlacementList,
+    StrategyType,
+)
+from torch.distributed.tensor.device_mesh import DeviceMesh
+from torch.distributed.tensor.placement_types import (
+    _StridedShard,
+    Partial,
+    Placement,
+    Replicate,
+    Shard,
+)
+
+
+def replicate_op_strategy(op_schema: OpSchema) -> StrategyType:
+    """
+    Fallback strategy all use Replication()
+    """
+    args_strategy = op_schema.args_strategy
+    kwargs_strategy = op_schema.kwargs_strategy
+    inputs_strategy = args_strategy + kwargs_strategy
+
+    output_type = [str(ret.type) for ret in op_schema.op._schema.returns]
+    output_len = output_type.count("Tensor")
+    # TODO(zpcore): Confirm if view op can be handle properly or not. Prevent
+    # handling view ops until confirmed.
+    if op_schema.op.is_view:
+        raise RuntimeError(
+            "fallback strategy is unable to handle view ops until confirmed"
+        )
+    if "List[Tensor]" in output_type:
+        raise RuntimeError(
+            "fallback strategy is unable to handle ops with List[Tensor] output "
+            "because size of the list may depend on the op's input value"
+        )
+
+    mesh = inputs_strategy[0].mesh
+
+    dim_sharding: PlacementList = [Replicate()] * (output_len + len(inputs_strategy))
+    single_dim_placement = [dim_sharding]
+    return expand_to_full_mesh_op_strategy(
+        mesh, op_schema, single_dim_placement, input_index=output_len
+    )
+
+
+def as_list(
+    x: list[object] | object,
+    # pyre-fixme[11]: Annotation `immutable_list` is not defined as a type.
+) -> list[object] | torch.fx.immutable_collections.immutable_list:  # type: ignore[valid-type]
+    # During tracing, `aten.sum.dim_IntList` uses `immutable_list` for its args,
+    # which is an object but treated as a list by the tracer. Therefore, keep
+    # `immutable_list` intact here as well.
+    if type(x) is list or isinstance(x, torch.fx.immutable_collections.immutable_list):
+        return x
+    else:
+        return [x]
+
+
+def normalize_dim(dim: int, ndim: int) -> int:
+    return dim if dim >= 0 else dim + ndim
+
+
+def normalize_dims(dims: DimsType, ndim: int) -> DimsSequenceType:
+    """Normalize a dim or a sequence of dims, so that they are all positive."""
+    if isinstance(dims, int):
+        dims = (normalize_dim(dims, ndim),)
+    elif isinstance(dims, list):
+        dims = [normalize_dim(dim, ndim) for dim in dims]
+    elif isinstance(dims, tuple):
+        dims = tuple(normalize_dim(dim, ndim) for dim in dims)
+    return dims
+
+
+def prod(xs: Iterable[int]) -> int:
+    return functools.reduce(operator.mul, xs, 1)
+
+
+def is_tensor_shardable(
+    shape: Sequence[int],
+    spec: DTensorSpec,
+    allow_unbacked_sharding: bool | None = None,
+) -> bool:
+    """
+    Check if the shape is shardable according to the spec.
+
+    allow_unbacked_sharding: determines the fallback value if unbacked shapes are involved,
+    and the queried shape properties are not statically known.
+
+    e.g. when asking if u0 is shardable on num_shards, and u0 has generic bounds [0, inf],
+    the behavior of allow_unbacked_sharding is:
+
+        None: will data-dependent error
+        True: assumes shardability; we return True, allowing zero-size shards at runtime when u0 < num_shards.
+        False: returns False, and lower-bounding u0, e.g. torch._check(u0 >= num_shards), is needed to enable sharding.
+    """
+    from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true
+
+    assert allow_unbacked_sharding in [None, True, False]
+    guard_fn = {
+        None: bool,
+        True: guard_or_false,
+        False: guard_or_true,
+    }[allow_unbacked_sharding]
+
+    # number of shards in each tensor dimension
+    shards_map = [1] * len(shape)
+    for i, placement in enumerate(spec.placements):
+        if placement.is_shard():
+            shard_dim = cast(Shard, placement).dim
+            if shard_dim >= len(shape):
+                return False
+            shards_map[shard_dim] *= spec.mesh.size(i)
+
+    for i, dim_size in enumerate(shape):
+        # TODO: maybe we should determine is_shardable based on
+        #       whether it's evenly sharded or not
+        if shards_map[i] > 1 and guard_fn(dim_size < shards_map[i]):
+            return False
+
+    return True
+
+
+def is_tensor_evenly_shardable(shape: Sequence[int], spec: DTensorSpec) -> bool:
+    """Check if the shape is evenly shardable according to the spec."""
+    # number of shards in each tensor dimension
+    shards_map = [1] * len(shape)
+    for i, placement in enumerate(spec.placements):
+        if placement.is_shard():
+            shard_dim = cast(Shard, placement).dim
+            shards_map[shard_dim] *= spec.mesh.size(i)
+
+    for i, dim_size in enumerate(shape):
+        if shards_map[i] > 1 and (dim_size % shards_map[i] != 0):
+            return False
+
+    return True
+
+
+def is_tensor_evenly_shardable_on_dim(
+    shape: Sequence[int], spec: DTensorSpec, dim: int
+) -> bool:
+    """Check if the shape is evenly shardable according to the spec on dim."""
+    dim = normalize_dim(dim, len(shape))
+
+    num_shards = 1
+    for i, placement in enumerate(spec.placements):
+        if placement.is_shard():
+            shard_dim = cast(Shard, placement).dim
+            if shard_dim == dim:
+                num_shards *= spec.mesh.size(i)
+
+    return shape[dim] % num_shards == 0
+
+
+def is_tensor_dim_sharded(spec: DTensorSpec, dim: int) -> bool:
+    """Return True if tensor dim is sharded."""
+    return any(p.is_shard(dim) for p in spec.placements)
+
+
+def is_tensor_partial(spec: DTensorSpec) -> bool:
+    """Return True if tensor is partial on the mesh."""
+    return any(p.is_partial() for p in spec.placements)
+
+
+def infer_broadcast_dims_map(
+    common_shape: torch.Size, input_shape: torch.Size
+) -> list[int]:
+    # infer the broadcast dims map, where it maps from the common shape dim to the input shape dim
+    # this is aligned with the broadcast semantics
+    # e.g. if common_shape = [1, 2, 3, 4] and input_shape = [2, 3, 4],
+    # broadcast_dims_map will be [-1, 0, 1, 2]
+    # meaning that dim 0 in the output has no mapping to the input, and dim 1 in the output maps to dim 0 in the input
+    common_ndim = len(common_shape)
+    input_ndim = len(input_shape)
+    broadcast_dims_map = [-1] * common_ndim
+    for idx in range(-1, -1 - input_ndim, -1):
+        if input_shape[idx] == common_shape[idx]:
+            broadcast_dims_map[common_ndim + idx] = input_ndim + idx
+    return broadcast_dims_map
+
+
+def map_placements_after_broadcast(
+    placements: tuple[Placement, ...],
+    shape: torch.Size,
+    broadcast_dims_map: list[int],
+    partial_to_replicate: bool = False,
+) -> tuple[Placement, ...]:
+    """Map each placement based on the output shape after broadcast."""
+    new_placements: list[Placement] = []
+    for placement in placements:
+        if isinstance(placement, Partial):
+            if partial_to_replicate:
+                # map the partial placement to replicate
+                new_placements.append(Replicate())
+            else:
+                new_placements.append(placement)
+        elif isinstance(placement, Replicate):
+            new_placements.append(placement)
+        else:
+            assert isinstance(placement, Shard | _StridedShard)
+            shard_dim = normalize_dim(placement.dim, len(shape))
+            new_shard_dim = broadcast_dims_map[shard_dim]
+            if new_shard_dim != -1:
+                # there's a map from the common shape shard dim to
+                # the input shape shard dim before broadcasting,
+                # use that instead
+                if isinstance(placement, _StridedShard):
+                    new_placements.append(
+                        _StridedShard(
+                            new_shard_dim, split_factor=placement.split_factor
+                        )
+                    )
+                else:
+                    new_placements.append(Shard(new_shard_dim))
+            else:
+                # there's no map between common shape shard dim and
+                # the input shape shard dim before broadcasting,
+                # in this case it means implicit broadcasting happen
+                # in this dim, so we can just mark it as replicate
+                # and implicit broadcast will broadcast automatically
+                # to the sharded shape
+                new_placements.append(Replicate())
+
+    return tuple(new_placements)
+
+
+def generate_redistribute_costs(
+    src_strategy: OpStrategy, dst_spec: DTensorSpec
+) -> list[float]:
+    """Generates one row in the 'redistribute_costs' matrix in an OpSpec
+    The length of the returned list will match the number of strategies in 'src_strategy'.
+
+    Each value in the row is the cost of redistributing from a particular src_strategy to dst_spec.
+    """
+    redistribute_costs: list[float] = [
+        redistribute_cost(strat.output_spec, dst_spec)
+        for strat in src_strategy.strategies
+    ]
+
+    return redistribute_costs
+
+
+def expand_to_full_mesh_op_strategy(
+    mesh: DeviceMesh,
+    op_schema: OpSchema,
+    single_mesh_dim_strategies: list[PlacementList],
+    *,
+    input_index: int = 1,
+    inplace_op: bool = False,
+    is_valid_strategy_cb: Callable[
+        [list[DTensorSpec], tuple[DTensorSpec | None, ...]], bool
+    ]
+    | None = None,
+) -> OpStrategy:
+    """
+    Convenience function to allow writing a sharding strategy considering only a single mesh dimension,
+    and have it expanded combinatorically to all mesh dimensions.
+
+    Args:
+        mesh (DeviceMesh): the device mesh to expand the strategy to
+        op_schema (OpSchema): the op schema
+        single_mesh_dim_strategies (list[PlacementList]): the sharding strategies to expand. The outer list is over
+            different strategies.  The inner PlacementList is over the outputs and inputs of the op. If input_index is 1,
+            a PlacementList looks like [output_placement, input_placement1, input_placement2, ...].
+        input_index: the number of outputs of the op, defaults to 1
+        inplace_op: whether the op is inplace or not, defaults to False
+        is_valid_strategy_cb: a callback function to filter out invalid sharding rules, defaults to None.
+
+    Example: Let's say `my_op(tensor_x, tensor_y) - > output_tensor`  can support sharding or replicating tensor_x,
+    but always requires tensor_y to be replicated.  We can specify these valid combinations ignoring mesh dims.
+    Then, we can rely on `expand_to_full_mesh_op_strategy` to create every possible combination of these shardings
+    over multiple mesh dimensions, filtering out any combinations that are invalid based on the actual mesh dim size.
+
+        single_mesh_dim_strategies = [
+            # first strategy: return output sharded on first dim, shard tensor_x on its first dim, replicate tensor_y
+            [Shard(0), Shard(0), Replicate()]
+            # second strategy: replicate output, and both inputs
+            [Replicate(), Replicate(), Replicate()]
+        ]
+    """
+    # Expand the single_mesh_dim_strategies to full mesh dim strategies.
+    all_mesh_dim_strategies = [single_mesh_dim_strategies] * mesh.ndim
+
+    strategy_combs = itertools.product(*all_mesh_dim_strategies)
+
+    all_strategies = []
+    for strategy_comb in strategy_combs:
+        spec_list: list[DTensorSpec | None] = []
+        for specs in zip(*strategy_comb):
+            if specs[0] is not None:
+                # TODO: we should fill in tensor_meta here.  If nothing else, it helps the filter strategy callback
+                # pyrefly: ignore [bad-argument-type]
+                spec_list.append(DTensorSpec(mesh, specs))
+            else:
+                spec_list.append(None)
+
+        input_specs: list[DTensorSpec] = [
+            s for s in spec_list[input_index:] if isinstance(s, DTensorSpec)
+        ]
+
+        args_strategy = op_schema.args_strategy
+        kwargs_strategy = op_schema.kwargs_strategy
+        input_args_strategy = args_strategy + kwargs_strategy
+
+        if len(input_specs) != len(input_args_strategy):
+            raise AssertionError(
+                f"input_specs({len(input_specs)}) != strategies({len(input_args_strategy)}: "
+                f"{len(args_strategy)} args + {len(kwargs_strategy)} kwargs)"
+            )
+        self_spec = input_args_strategy[0].strategies[0].output_spec
+
+        if inplace_op and self_spec.placements != input_specs[0].placements:
+            # if it's inplace op, we would only allow the OpSpec to be added when the
+            # input_spec matches the first argument's runtime sharding, otherwise we skip
+            continue
+
+        output_specs: tuple[DTensorSpec | None, ...]
+        if input_index > 1:
+            output_specs = tuple(spec_list[:input_index])
+        else:
+            if spec_list[0] is not None:
+                output_specs = spec_list[0]  # type: ignore[assignment]
+            else:
+                raise RuntimeError("output spec is None")
+
+        # check all inputs are shardable
+        if not all(
+            is_tensor_shardable(inp.shape, s)
+            for inp, s in zip(input_args_strategy, input_specs)
+        ):
+            continue
+
+        # perform additional op-specific filtering
+        if is_valid_strategy_cb is not None:
+            if not is_valid_strategy_cb(input_specs, output_specs):
+                continue
+
+        redistribute_cost = [
+            generate_redistribute_costs(input_strategy, input_spec)
+            for input_strategy, input_spec in zip(input_args_strategy, input_specs)
+        ]
+
+        strategy = OpSpec(
+            output_specs=output_specs,
+            input_specs=input_specs,
+            redistribute_cost=redistribute_cost,
+        )
+        all_strategies.append(strategy)
+    return OpStrategy(all_strategies)
+
+
+def shift_shard_dims_after_insert(
+    placements: Sequence[Placement], insert_dim: int = 0
+) -> Sequence[Placement]:
+    normalized_placements: list[Placement] = []
+    for placement in placements:
+        if isinstance(placement, Shard) and placement.dim >= insert_dim:
+            normalized_placements.append(Shard(placement.dim + 1))
+        else:
+            normalized_placements.append(placement)
+    return normalized_placements
+
+
+def shift_shard_dims_after_remove(
+    placements: Sequence[Placement], remove_dim: int = 0
+) -> Sequence[Placement]:
+    normalized_placements: list[Placement] = []
+    for placement in placements:
+        if isinstance(placement, Shard) and placement.dim > remove_dim:
+            normalized_placements.append(Shard(placement.dim - 1))
+        else:
+            normalized_placements.append(placement)
+    return normalized_placements
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0012040d74a3e0caaf23a71c138681b9c372e591
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/__init__.py
@@ -0,0 +1,34 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+from collections.abc import Iterator
+from contextlib import contextmanager
+
+from torch.distributed.tensor._api import DTensor
+from torch.distributed.tensor.experimental._attention import context_parallel
+from torch.distributed.tensor.experimental._func_map import local_map
+from torch.distributed.tensor.experimental._register_sharding import register_sharding
+
+
+__all__ = ["context_parallel", "implicit_replication", "local_map", "register_sharding"]
+
+
+@contextmanager
+def implicit_replication() -> Iterator[None]:
+    """
+    This context manager allows :class:`DTensor` to implicitly treat all non-DTensors (``torch.Tensor``)
+    in the program be replicate :class:`DTensor` s during the operator computation.
+
+    .. warning:: This might possible lead to incorrect results if ``torch.Tensor`` s are not replicated
+        in practice, please use it at your discretion.
+    """
+    try:
+        DTensor._op_dispatcher._allow_implicit_replication = True
+        yield
+    finally:
+        DTensor._op_dispatcher._allow_implicit_replication = False
+
+
+# Set namespace for exposed private names
+context_parallel.__module__ = "torch.distributed.tensor.experimental"
+implicit_replication.__module__ = "torch.distributed.tensor.experimental"
+local_map.__module__ = "torch.distributed.tensor.experimental"
+register_sharding.__module__ = "torch.distributed.tensor.experimental"
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_attention.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..f238739ddd5cf4f8e120f1e6a0337f0cfc8cc58d
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_attention.py
@@ -0,0 +1,44 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# Backward compatibility stub - this module has been moved to _context_parallel/_attention.py
+
+from ._context_parallel._attention import (
+    _CausalBehavior,
+    _context_parallel_shard,
+    _ContextParallel,
+    _cp_options,
+    _disable_context_parallel_dispatcher,
+    _enable_context_parallel_dispatcher,
+    _is_causal_behavior,
+    _RotateMethod,
+    _templated_ring_attention,
+    context_parallel,
+    context_parallel_unshard,
+    set_rotate_method,
+)
+from ._context_parallel._load_balancer import (
+    _HeadTailLoadBalancer,
+    _LoadBalancer,
+    _PerDocumentHeadTailLoadBalancer,
+    _PTRRLoadBalancer,
+)
+
+
+# TODO(fegin): add deprecation message once the final interfaces are concluded.
+__all__ = [
+    "_CausalBehavior",
+    "_context_parallel_shard",
+    "_ContextParallel",
+    "_cp_options",
+    "_disable_context_parallel_dispatcher",
+    "_enable_context_parallel_dispatcher",
+    "_is_causal_behavior",
+    "_RotateMethod",
+    "_templated_ring_attention",
+    "context_parallel",
+    "context_parallel_unshard",
+    "set_rotate_method",
+    "_HeadTailLoadBalancer",
+    "_LoadBalancer",
+    "_PerDocumentHeadTailLoadBalancer",
+    "_PTRRLoadBalancer",
+]
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_func_map.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_func_map.py
new file mode 100644
index 0000000000000000000000000000000000000000..759841a40aaa14b3f985dc7bce730198617ada5b
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_func_map.py
@@ -0,0 +1,278 @@
+# mypy: allow-untyped-defs
+# Copyright (c) Meta Platforms, Inc. and affiliates
+import functools
+from collections.abc import Callable, Sequence
+from typing import Optional, Union
+
+import torch
+from torch.distributed._functional_collectives import AsyncCollectiveTensor
+from torch.distributed.tensor import DeviceMesh, DTensor
+from torch.distributed.tensor.placement_types import Placement
+
+
+try:
+    from torch.utils import _cxx_pytree as pytree
+except ImportError:
+    from torch.utils import _pytree as pytree  # type: ignore[no-redef]
+
+
+__all__ = ["local_map"]
+
+PlacementType = Optional[Sequence[Placement]]
+InputPlacements = Optional[tuple[PlacementType, ...]]
+OutputPlacements = Union[PlacementType, tuple[PlacementType, ...]]
+
+
+def local_map(
+    func: Callable | None = None,
+    out_placements: OutputPlacements = None,
+    in_placements: InputPlacements = None,
+    in_grad_placements: InputPlacements = None,
+    device_mesh: DeviceMesh | None = None,
+    *,
+    redistribute_inputs: bool = False,
+):
+    """
+    :meth:`local_map` is an experimental API that allows users to pass :class:`DTensor` s
+    to a function that is written to be applied on ``torch.Tensor`` s. It is done by extracting
+    the local components of :class:`DTensor`, call the function, and wrap the outputs to
+    :class:`DTensor` according to the ``out_placements``.
+
+    Args:
+        func (Callable): the function to be applied on each local shard of
+            :class:`DTensor` s.
+        out_placements (Union[`PlacementType`, Tuple[`PlacementType`, ...]]):
+            the desired placements of the :class:`DTensor` s in ``func``'s flattened output.
+            If the flattened ``output`` is a single value, the ``out_placements`` should be
+            of type `PlacementType`. Otherwise if the flattened ``output`` has multiple
+            values, the ``out_placements`` should be a tuple of `PlacementType` values 1:1
+            mapping to the flattened ``output``.
+            Besides, for :class:`Tensor` output, we use `PlacementType` as its
+            placements (a `Tuple[Placement]` value). For non-Tensor output, the `PlacementType`
+            should be `None`.
+            Note that the only exception is when no :class:`DTensor` argument is passed
+            in. In this case, even if `out_placements` is not `None`, the result function
+            should ignore the desired placements because the function is not running with
+            :class:`DTensor` s.
+        in_placements (Tuple[`PlacementType`, ...], optional):
+            the required placements of the :class:`DTensor` s in the flattened inputs of ``func``.
+            If ``in_placements`` is specified, :meth:`local_map` would examine whether the
+            placements of each :class:`DTensor` argument is the same as the required
+            placements or not. If the placements are not the same and
+            ``redistribute_inputs`` is ``False``, an exception will be raised. Otherwise if
+            ``redistribute_inputs`` is ``True``, the argument will be first redistributed to
+            the required sharding placements before passing its local tensor to ``func``.
+            The only exception is when required placements are not ``None`` and the
+            argument is a :class:`torch.Tensor`. In this case, the placements examination
+            will be skipped and the argument will be directly passed to ``func``.
+            If ``in_placements`` is ``None``, no placements examination will be performed.
+            Default: None
+        in_grad_placements (Tuple[`PlacementType`, ...], optional):
+            the placements hint of the :class:`DTensor` s gradient corresponds
+            to the flattened input DTensor. This argument is the hint that user
+            can give to :meth:`to_local` in case the gradient layout of the
+            local tensor input does not match its :class:`DTensor` input layout.
+            If not specified, we will assume the gradient layout of the local
+            tensor input remains the same as the original :class:`DTensor` input
+            and use that for gradient computation. Default: None.
+        device_mesh (:class:`DeviceMesh`, optional):
+            the device mesh that the output :class:`DTensor` s are placed on. If not
+            specified, this will be inferred from the first input :class:`DTensor`'s device
+            mesh. Default: None.
+
+    Keyword Args:
+        redistribute_inputs (bool, optional):
+            the bool value indicating whether to reshard the input :class:`DTensor` s when
+            their placements are different from the required input placements. If this
+            value is ``False`` and some :class:`DTensor` input has a different placement,
+            an exception will be raised. Default: False.
+
+    Returns:
+        A ``Callable`` that applies ``func`` to each local shard of the input :class:`DTensor`
+        and returns a :class:`DTensor` constructed from the return value of ``func``.
+
+    Raises:
+        AssertionError: For any non-DTensor output, we require its corresponding
+            output placement in ``out_placements`` be None. An AssertionError will be raised
+            if this is not the case.
+
+        ValueError: If ``redistribute_inputs=False`` but the input :class:`DTensor` needs
+            a redistribution according to ``in_placements``.
+
+    Example:
+        >>> # xdoctest: +SKIP("distributed")
+        >>> def mm_allreduce_forward(device_mesh, W, X):
+        >>>     partial_sum_tensor = torch.mm(W, X)
+        >>>     reduced_tensor = funcol.all_reduce(partial_sum_tensor, "sum", device_mesh)
+        >>>     return reduced_tensor
+        >>>
+        >>> W = torch.randn(12, 8, requires_grad=False)
+        >>> X = torch.randn(8, 16, requires_grad=False)
+        >>> Y = torch.mm(W, X)
+        >>> row_wise = [Shard(0)]  # row-wise sharding placements on 1-d mesh
+        >>> col_wise = [Shard(1)]  # col-wise sharding placements on 1-d mesh
+        >>>
+        >>> # local_mm_allreduce_forward is the function wrapped with DTensor/Tensor conversion
+        >>> local_mm_allreduce_forward = local_map(
+        >>>     mm_allreduce_forward,
+        >>>     out_placements=[Replicate()],
+        >>>     in_placements=[col_wise, row_wise],
+        >>>     device_mesh=device_mesh,
+        >>> )
+        >>>
+        >>> W_dt = distribute_tensor(
+        ...     W, device_mesh, (col_wise)
+        ... )  # col-wisely sharded W tensor
+        >>> X_dt = distribute_tensor(
+        ...     X, device_mesh, (row_wise)
+        ... )  # row-wisely sharded X tensor
+        >>> Y_dt = local_mm_allreduce_forward(
+        ...     device_mesh, W_dt, X_dt
+        ... )  # apply local_mm_allreduce_forward to DTensors
+
+    .. note:: This API is currently experimental and subject to change
+    """
+
+    if func is None:
+        # decorator mode
+        def decorated(func):
+            return local_map(
+                func=func,
+                out_placements=out_placements,
+                in_placements=in_placements,
+                in_grad_placements=in_grad_placements,
+                device_mesh=device_mesh,
+                redistribute_inputs=redistribute_inputs,
+            )
+
+        return decorated
+
+    return functools.partial(
+        _local_map_wrapped,
+        func,
+        out_placements,
+        in_placements,
+        in_grad_placements,
+        device_mesh,
+        redistribute_inputs,
+    )
+
+
+def _local_map_wrapped(
+    func: Callable,
+    out_placements: OutputPlacements,
+    in_placements: InputPlacements,
+    in_grad_placements: InputPlacements,
+    device_mesh: DeviceMesh | None,
+    redistribute_inputs: bool,
+    *args,
+    **kwargs,
+):
+    # process input args
+    flat_args, args_spec = pytree.tree_flatten(args)
+    if in_placements is not None:
+        assert len(in_placements) == len(flat_args), (
+            f"in_placements length {len(in_placements)} does not match the number "
+            f"of input args {len(flat_args)}!"
+        )
+
+    # we assume every DTensor object is placed on the same device mesh
+    flat_local_args = []
+    seen_dtensor_arg = False
+    for idx, arg in enumerate(flat_args):
+        if isinstance(arg, DTensor):
+            # TODO: the current code doesn't consider the uneven sharding case
+            # Need to think about what the consequence is when the input DTensor
+            # is uneven sharded.
+            if device_mesh is None:  # infer device mesh from the DTensor arg
+                device_mesh = arg.device_mesh
+
+            # this function is applied to at least one DTensor argument
+            seen_dtensor_arg = True
+
+            if in_placements is not None:
+                spec = in_placements[idx]
+                assert spec is not None, (
+                    f"DTensor input {arg} expects placements but received {spec}!"
+                )
+
+                if not isinstance(spec, tuple):
+                    spec = tuple(spec)
+
+                if arg.placements != spec:
+                    if redistribute_inputs:
+                        # redistribute to input placements
+                        arg = arg.redistribute(placements=spec)
+                    else:
+                        raise ValueError(
+                            f"arg {arg} in local_map has a mismatched placements: "
+                            f"arg placements is {arg.placements} but the input "
+                            f"placements is {spec}! "
+                            "If redistribute_inputs is wanted, set "
+                            "redistribute_inputs=True to local_map."
+                        )
+
+            if in_grad_placements is not None:
+                spec = in_grad_placements[idx]
+                assert spec is not None, (
+                    f"DTensor input {arg} expects in grad placements but received {spec}!"
+                )
+                if not isinstance(spec, tuple):
+                    spec = tuple(spec)
+                local_arg = arg.to_local(grad_placements=spec)
+            else:
+                local_arg = arg.to_local()
+
+            if isinstance(local_arg, AsyncCollectiveTensor):
+                local_arg = local_arg.wait()
+
+            flat_local_args.append(local_arg)
+        else:
+            # Non-Tensor input must have None in `in_placements`
+            if in_placements is not None and not isinstance(arg, torch.Tensor):
+                spec = in_placements[idx]
+                assert spec is None, (
+                    f"Non-Tensor input {arg} expects None placements "
+                    f"but received {spec}!"
+                )
+
+            flat_local_args.append(arg)
+
+    # pyrefly: ignore [bad-argument-type]
+    local_args = pytree.tree_unflatten(flat_local_args, args_spec)
+
+    out = func(*local_args, **kwargs)
+
+    if seen_dtensor_arg:
+        # process output to be DTensor if we've seen DTensor inputs
+        flat_out, out_spec = pytree.tree_flatten(out)
+
+        flat_dist_out = []
+        out_placements_tuple = (
+            out_placements if isinstance(out_placements, tuple) else (out_placements,)
+        )
+        assert len(flat_out) == len(out_placements_tuple), (
+            "local_map requires one PlacementType be provided for each output value,"
+            f" received {len(out_placements_tuple)} out_placements but"
+            f" {len(flat_out)} is expected!"
+        )
+        for out, spec in zip(flat_out, out_placements_tuple):
+            if isinstance(out, torch.Tensor):
+                assert not isinstance(out, DTensor), (
+                    f"torch.Tensor output expected but received {type(out)}: {out}"
+                )
+
+                flat_dist_out.append(
+                    DTensor.from_local(out, device_mesh, spec, run_check=False)
+                )
+            else:
+                assert spec is None, (
+                    f"Non-tensor output {out} expects None placements but received {spec}!"
+                )
+
+                flat_dist_out.append(out)
+
+        # pyrefly: ignore [bad-argument-type]
+        return pytree.tree_unflatten(flat_dist_out, out_spec)
+    else:
+        return out
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_register_sharding.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_register_sharding.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b365dcf286d03be9628c5f909682bcd0a818f7e
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_register_sharding.py
@@ -0,0 +1,136 @@
+# mypy: allow-untyped-defs
+# Copyright (c) Meta Platforms, Inc. and affiliates
+from collections.abc import Callable, Sequence
+from functools import partial
+
+import torch
+from torch._ops import OpOverload
+from torch.distributed.tensor import DTensor
+from torch.distributed.tensor._op_schema import (
+    OpSchema,
+    OpStrategy,
+    PlacementList,
+    RuntimeSchemaInfo,
+    StrategyType,
+    TupleStrategy,
+)
+from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy
+
+
+__all__ = ["register_sharding"]
+
+
+def register_sharding(op: OpOverload | list[OpOverload]):
+    """
+    :meth:`register_sharding` is an experimental API that allows users to register sharding
+    strategies for an operator when the tensor inputs and outputs are DTensor.
+    It can be useful when: (1) there doesn't exist a default sharding strategy for ``op``,
+    e.g. when ``op`` is a custom operator that is not supported by :class:`DTensor`; (2)
+    when users would like to overwrite default sharding strategies of existing operators.
+
+    Args:
+        op (Union[OpOverload, List[OpOverload]]):
+            An op or a list of ops to register the customized sharding function.
+
+    Returns:
+        A function decorator which can be used to wrap a function that defines the sharding
+        strategy for the operator specified in ``op``. The defined sharding strategy will be
+        registered to DTensor and will override the default sharding strategy if DTensor has
+        already implemented the operator. The customized sharding function takes the same inputs
+        as the original op (except that if an arg is a :class:`torch.Tensor`, it will be
+        replaced by a tensor-like object that DTensor uses internally). The function should
+        return a sequence of 2-tuples, each specifying acceptable output placements and its
+        corresponding input placements.
+
+    Example:
+        >>> # xdoctest: +SKIP("distributed")
+        >>> @register_sharding(aten._softmax.default)
+        >>> def custom_softmax_sharding(x, dim, half_to_float):
+        >>>     softmax_dim = dim if dim >= 0 else dim + x.ndim
+        >>>     acceptable_shardings = []
+        >>>
+        >>>     all_replicate = ([Replicate()], [Replicate(), None, None])
+        >>>     acceptable_shardings.append(all_replicate)
+        >>>
+        >>>     for sharding_dim in range(x.ndim):
+        >>>         if sharding_dim != softmax_dim:
+        >>>             all_sharded = (
+        >>>                 [Shard(sharding_dim)],
+        >>>                 [Shard(sharding_dim), None, None],
+        >>>             )
+        >>>             acceptable_shardings.append(all_sharded)
+        >>>
+        >>>     return acceptable_shardings
+
+    .. note:: This API is currently experimental and subject to change
+    """
+
+    def custom_strategy(
+        custom_sharding_fn: Callable[
+            ..., Sequence[tuple[PlacementList, PlacementList]]
+        ],
+        op_schema: OpSchema,
+    ) -> StrategyType:
+        def strategy_to_spec(strategy: object) -> object:
+            if isinstance(strategy, OpStrategy):
+                # take the output spec from the first strategy
+                return strategy.strategies[0].output_spec
+            elif isinstance(strategy, TupleStrategy):
+                return tuple(strategy_to_spec(s) for s in strategy.children)
+            else:
+                return strategy
+
+        mesh = op_schema.get_mesh_from_args()
+
+        args_schema = tuple(strategy_to_spec(i) for i in op_schema.args_schema)
+        kwargs_schema = {
+            k: strategy_to_spec(v) for k, v in op_schema.kwargs_schema.items()
+        }
+
+        acceptable_shardings = custom_sharding_fn(*args_schema, **kwargs_schema)
+
+        single_mesh_dim_strategies: list[PlacementList] = []
+        for output_specs, input_specs in acceptable_shardings:
+            single_mesh_dim_strategies.append(output_specs + input_specs)
+
+        # TODO: handle out variant ops
+        return expand_to_full_mesh_op_strategy(
+            mesh,
+            op_schema,
+            single_mesh_dim_strategies,
+            input_index=len(op_schema.op._schema.returns),
+            inplace_op=op_schema.is_inplace_op(),
+        )
+
+    def wrapper(custom_sharding_fn):
+        def derive_schema_info(op):
+            # NOTE: without user directly providing RuntimeSchemaInfo, for now
+            #       we create it in a conservative fashion as follows:
+            #       1. let static_argnum be the first int argument
+            #       2. let static_kwargkey include all the int type kwargs
+            #       3. always set needs_pytree=True
+            static_argnum = 100
+            static_kwargkey: list[str] = []
+            for i, arg in enumerate(op._schema.arguments):
+                if isinstance(arg.type, torch.IntType) or (
+                    isinstance(arg.type, torch.OptionalType)
+                    and isinstance(arg.type.getElementType(), torch.IntType)
+                ):
+                    static_argnum = min(i, static_argnum)
+                    if arg.kwarg_only:
+                        static_kwargkey.append(arg.name)
+            return RuntimeSchemaInfo(
+                static_argnum, static_kwargkey or None, needs_pytree=True
+            )
+
+        overloads = op if isinstance(op, list) else [op]
+        for overload in overloads:
+            DTensor._op_dispatcher.sharding_propagator.register_op_strategy(
+                overload,
+                partial(custom_strategy, custom_sharding_fn),
+                derive_schema_info(overload),
+            )
+
+        return custom_sharding_fn
+
+    return wrapper
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_tp_transform.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_tp_transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..1075df79f33956d710348330b38f56228ebc871b
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_tp_transform.py
@@ -0,0 +1,557 @@
+# mypy: allow-untyped-defs
+import copy
+import operator
+from collections.abc import Sequence
+from typing import Any, cast
+
+import torch
+from torch._subclasses.fake_tensor import FakeTensor
+from torch.distributed.tensor import DeviceMesh, distribute_tensor, DTensor
+from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
+from torch.distributed.tensor._op_schema import (
+    OpSchema,
+    OpSpec,
+    OutputSharding,
+    OutputSpecType,
+)
+from torch.distributed.tensor._redistribute import redistribute_local_tensor
+from torch.distributed.tensor.parallel.style import ColwiseParallel, ParallelStyle
+from torch.distributed.tensor.placement_types import Placement, Replicate, Shard
+from torch.export import ExportedProgram
+from torch.export.exported_program import ExportGraphSignature
+from torch.fx import GraphModule
+from torch.fx.experimental.proxy_tensor import make_fx
+from torch.fx.node import Node
+from torch.fx.passes.infra.pass_base import PassBase, PassResult
+from torch.fx.passes.shape_prop import _extract_tensor_metadata
+from torch.utils import _pytree as pytree
+
+
+__all__ = ["tensor_parallel_transformation"]
+
+aten = torch.ops.aten
+
+
+def tensor_parallel_transformation(
+    exported_program: ExportedProgram,
+    rank: int,
+    world_size: int,
+    device_type: str,
+    parallel_strategies: dict[str, ParallelStyle],
+) -> ExportedProgram:
+    """
+    The entry point function to perform graph transformations on an exported program
+    to transform a single-device graph into a tensor parallel graph.
+
+    .. warning::
+        This API is experimental and subject to change.
+    """
+
+    gm = exported_program.graph_module
+    sig = copy.deepcopy(exported_program.graph_signature)
+    state_dict = copy.copy(exported_program.state_dict)
+
+    with gm._set_replace_hook(sig.get_replace_hook()):
+        res = _TensorParallelTransformPass(
+            rank,
+            world_size,
+            device_type,
+            state_dict,
+            exported_program.graph_signature,
+            parallel_strategies,
+        )(gm)
+        assert res is not None
+        gm = res.graph_module
+
+    return exported_program._update(gm, sig, state_dict=state_dict)
+
+
+class _TensorParallelTransformPass(PassBase):
+    """
+    This pass is responsible for transforming a single-device graph into a tensor parallel
+    graph. It will mark the OpSpec of each node in the graph, partition the graph into
+    distributed graph, then shard the parameters/buffers accordingly.
+    """
+
+    def __init__(
+        self,
+        rank: int,
+        world_size: int,
+        device_type: str,
+        state_dict: dict[str, torch.Tensor],
+        graph_signature: ExportGraphSignature,
+        parallel_strategies: dict[str, ParallelStyle],
+    ) -> None:
+        super().__init__()
+        self.rank = rank
+        self.mesh = DeviceMesh(device_type, torch.arange(world_size))
+        self.state_dict: dict[str, torch.Tensor] = state_dict
+        self.graph_signature = graph_signature
+        self.parallel_strategies = parallel_strategies
+
+    def call(self, graph_module) -> PassResult:
+        gm = copy.deepcopy(graph_module)
+
+        parameter_placements = _generate_parameter_and_buffer_placements(
+            list(self.state_dict.keys()), self.parallel_strategies
+        )
+        placement_strategies = _mark_sharding(
+            gm, self.graph_signature, self.mesh, parameter_placements
+        )
+        _partitioner(gm)
+        _shard_state_dict(
+            self.state_dict, placement_strategies, self.graph_signature, self.mesh
+        )
+        return PassResult(gm, True)
+
+
+def _generate_parameter_and_buffer_placements(
+    params_and_buffers: list[str],
+    parallel_strategies: dict[str, ParallelStyle],
+) -> dict[str, Placement]:
+    """
+    Build parameter placements based on the give parallel style of linear layers.
+    """
+    parameter_placements: dict[str, Placement] = {}
+    for linear_fqn, parallel_style in parallel_strategies.items():
+        weight_fqn = f"{linear_fqn}.weight"
+        bias_fqn = f"{linear_fqn}.bias"
+        assert weight_fqn in params_and_buffers
+        parameter_placements[weight_fqn] = (
+            Shard(0) if parallel_style == ColwiseParallel else Shard(1)
+        )
+        if bias_fqn in params_and_buffers:
+            parameter_placements[bias_fqn] = (
+                Shard(0) if parallel_style == ColwiseParallel else Replicate()
+            )
+    return parameter_placements
+
+
+def _mark_tensor_parallel_shardings(
+    gm: GraphModule,
+    graph_signature: ExportGraphSignature,
+    mesh: DeviceMesh,
+    parameter_placements: dict[str, Placement],
+) -> dict[Node, OpSpec]:
+    """
+    Mark the placement strategies of the parameter and buffer placeholder nodes.
+    """
+    placement_strategies: dict[Node, OpSpec] = {}
+    num_params_and_buffers = len(graph_signature.inputs_to_parameters) + len(
+        graph_signature.inputs_to_buffers
+    )
+    placeholder_idx: int = 0
+    for node in gm.graph.nodes:
+        if node.op == "placeholder":
+            if placeholder_idx < num_params_and_buffers:
+                fqn: str = _get_input_node_fqn(node.name, graph_signature)
+                placement: Placement = (
+                    parameter_placements[fqn]
+                    if fqn in parameter_placements
+                    else Replicate()
+                )
+                placement_strategies[node] = _create_placement_strategy(
+                    node,
+                    mesh,
+                    placements=(placement,),
+                )
+                placeholder_idx += 1
+            else:
+                placement_strategies[node] = _create_placement_strategy(
+                    node,
+                    mesh,
+                    placements=(Replicate(),),
+                )
+    return placement_strategies
+
+
+def _get_input_node_fqn(input_name: str, graph_signature: ExportGraphSignature) -> str:
+    """
+    Return the FQN of an input node.
+    """
+    if input_name in graph_signature.inputs_to_parameters:
+        return graph_signature.inputs_to_parameters[input_name]
+    elif input_name in graph_signature.inputs_to_buffers:
+        return graph_signature.inputs_to_buffers[input_name]
+    else:
+        raise ValueError(
+            f"{input_name} not found in inputs_to_parameters or inputs_to_buffers"
+        )
+
+
+def _mark_sharding(
+    gm: GraphModule,
+    graph_signature: ExportGraphSignature,
+    mesh: DeviceMesh,
+    parameter_placements: dict[str, Placement],
+) -> dict[Node, OpSpec]:
+    """
+    Mark the sharding strategy for each node in the graph module.
+    """
+    placement_strategies: dict[Node, OpSpec] = _mark_tensor_parallel_shardings(
+        gm,
+        graph_signature,
+        mesh,
+        parameter_placements,
+    )
+
+    for node in gm.graph.nodes:
+        if node.op == "placeholder":
+            if node not in placement_strategies:
+                placement_strategies[node] = _create_placement_strategy(
+                    node, mesh, placements=(Replicate(),)
+                )
+            node.meta["sharding"] = placement_strategies[node]
+        elif node.op == "call_function":
+            if node.target is operator.getitem:
+                input_nodes = node.all_input_nodes
+                assert len(input_nodes) == 1, (
+                    f"non-compute op only support one input now, found node: {node} with length of inputs: {len(node.args)}"
+                )
+                arg_strategy = placement_strategies[input_nodes[0]]
+                placement_strategies[node] = _create_placement_strategy(
+                    node,
+                    mesh,
+                    placements=arg_strategy.output_spec.placements,
+                    input_specs=_get_input_node_specs(node, placement_strategies),
+                )
+                node.meta["sharding"] = placement_strategies[node]
+            else:
+                op_schema = _get_op_schema(node, placement_strategies)
+
+                # get DTensor specs for inputs and outputs
+                if (
+                    op_schema.op
+                    not in DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs
+                    and op_schema.op
+                    not in DTensor._op_dispatcher.sharding_propagator.op_to_rules
+                ):
+                    # Mark all as replicated
+                    output_sharding = _generate_default_output_sharding(
+                        node,
+                        mesh,
+                        op_schema,
+                    )
+                else:
+                    output_sharding = DTensor._op_dispatcher.sharding_propagator.propagate_op_sharding(  # type: ignore[assignment]
+                        op_schema,
+                    )
+                placement_strategies[node] = OpSpec(
+                    # pyrefly: ignore [bad-argument-type]
+                    output_specs=_get_output_spec_from_output_sharding(output_sharding),
+                    # pyrefly: ignore [missing-attribute]
+                    input_specs=output_sharding.redistribute_schema.args_spec
+                    # pyrefly: ignore [missing-attribute]
+                    if output_sharding.redistribute_schema is not None
+                    else _get_input_node_specs(node, placement_strategies),
+                )
+                node.meta["sharding"] = placement_strategies[node]
+        elif node.op == "output":
+            node.meta["sharding"] = None
+        else:
+            raise RuntimeError(f"op code {node.op} not supported")
+    return placement_strategies
+
+
+def _get_output_spec_from_output_sharding(
+    output_sharding: OutputSharding,
+) -> DTensorSpec:
+    """
+    Util function to extract output spec from output sharding.
+    """
+    if isinstance(output_sharding.output_spec, DTensorSpec):
+        return output_sharding.output_spec
+    else:
+        # For ops that return multiple outputs, the outputs should have the same output spec
+        assert isinstance(output_sharding.output_spec, Sequence)
+        assert output_sharding.output_spec[0] is not None
+        output_sharding.output_spec[0].tensor_meta = None
+        return output_sharding.output_spec[0]
+
+
+def _create_placement_strategy(
+    node: Node,
+    mesh: DeviceMesh,
+    placements: tuple[Placement, ...],
+    input_specs: Sequence[DTensorSpec] | None = None,
+) -> OpSpec:
+    """
+    Util function to construct an OpSpec for a given node.
+    """
+    placement = OpSpec(
+        input_specs=input_specs,
+        output_specs=DTensorSpec(
+            mesh=mesh,
+            placements=placements,
+        ),
+    )
+    _populate_tensor_meta(node, placement.output_specs)
+    return placement
+
+
+def _populate_tensor_meta(node: Node, output_spec: OutputSpecType) -> None:
+    """
+    Util function to populate tensor meta of output_spec based on node metadata.
+    """
+    if isinstance(node.meta["val"], Sequence):
+        assert isinstance(output_spec, Sequence)
+        for spec, fake_tensor in zip(output_spec, node.meta["val"]):
+            assert spec is not None
+            spec.tensor_meta = TensorMeta(
+                shape=fake_tensor.shape,
+                stride=fake_tensor.stride(),
+                dtype=fake_tensor.dtype,
+            )
+    else:
+        assert isinstance(output_spec, DTensorSpec)
+        output_spec.tensor_meta = TensorMeta(
+            shape=node.meta["val"].shape,
+            stride=node.meta["val"].stride(),
+            dtype=node.meta["val"].dtype,
+        )
+
+
+def _generate_default_output_sharding(
+    node: Node,
+    mesh: DeviceMesh,
+    op_schema: OpSchema,
+) -> OutputSharding:
+    """
+    Util function to create a default output sharding that suggests Replicate placement for both args and outputs.
+    """
+
+    def update_arg_spec(arg_spec: DTensorSpec) -> DTensorSpec:
+        return DTensorSpec(
+            mesh=arg_spec.mesh,
+            placements=(Replicate(),),
+            tensor_meta=arg_spec.tensor_meta,
+        )
+
+    new_op_schema = OpSchema(
+        op=op_schema.op,
+        args_schema=pytree.tree_map_only(
+            DTensorSpec, update_arg_spec, op_schema.args_schema
+        ),
+        kwargs_schema=op_schema.kwargs_schema,
+    )
+
+    def create_output_spec(tensor: FakeTensor) -> DTensorSpec:
+        return DTensorSpec(
+            mesh=mesh,
+            placements=(Replicate(),),
+            tensor_meta=TensorMeta(
+                shape=tensor.shape,
+                stride=tensor.stride(),
+                dtype=tensor.dtype,
+            ),
+        )
+
+    return OutputSharding(
+        output_spec=pytree.tree_map_only(
+            FakeTensor, create_output_spec, node.meta["val"]
+        ),
+        redistribute_schema=new_op_schema,
+        needs_redistribute=True,
+    )
+
+
+def _partitioner(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    """
+    Graph partitioner that partitions the single device graph
+    to distributed graph
+    """
+    for node in gm.graph.nodes:
+        node_sharding = node.meta["sharding"]
+        if node.op == "placeholder":
+            out_spec = node_sharding.output_spec
+            local_val = _partition_val(node.meta["val"], out_spec)
+            # update node value
+            node.meta["val"] = local_val
+        elif node.op == "call_function":
+            out_spec = node_sharding.output_spec
+            # check if there's misaligned sharding, insert reshard if there is
+            expected_input_specs = node_sharding.input_specs
+            for idx, input_arg in enumerate(node.all_input_nodes):
+                input_arg_sharding = input_arg.meta["sharding"]
+                input_arg_spec = input_arg_sharding.output_spec
+                desired_spec = (
+                    out_spec
+                    if expected_input_specs is None
+                    else expected_input_specs[idx]
+                )
+                if input_arg_spec != desired_spec:
+                    _insert_reshard_gm(
+                        gm, node, input_arg, input_arg_spec, desired_spec
+                    )
+            # convert output val to its local component
+            output_val = node.meta["val"]
+            node.meta["val"] = _partition_val(output_val, out_spec)
+        elif node.op == "output":
+            for input_arg in node.all_input_nodes:
+                # input args of output should be Replicate, otherwise redistribution is needed.
+                input_args_to_check: Sequence[Node] = (
+                    input_arg if isinstance(input_arg, Sequence) else [input_arg]
+                )
+                for arg in input_args_to_check:
+                    arg_sharding = arg.meta["sharding"]
+                    arg_spec = arg_sharding.output_spec
+                    desired_spec = copy.copy(arg_spec)
+                    desired_spec.placements = (Replicate(),)
+                    if arg_spec != desired_spec:
+                        _insert_reshard_gm(gm, node, arg, arg_spec, desired_spec)
+        else:
+            raise RuntimeError(f"op code {node} not supported")
+
+    _clean_up_graph_metadata(gm)
+    gm.graph.lint()
+    gm.recompile()
+    return gm
+
+
+def _partition_val(val: Any, spec: DTensorSpec) -> Any:
+    """
+    util function to convert a full tensor val to its local component
+    """
+    if isinstance(val, torch.Tensor):
+        local_shard = val
+        if val.ndim == 0:
+            # If it's already a scalar tensor, it is already local, we don't
+            # need to do anything
+            return local_shard
+
+        for idx, placement in enumerate(spec.placements):
+            if placement.is_shard():
+                placement = cast(Shard, placement)
+                num_chunks = spec.mesh.size(mesh_dim=idx)
+                my_coord = spec.mesh.get_coordinate()
+                assert my_coord is not None, "current rank not in mesh!"
+                my_coord_on_mesh_dim = my_coord[idx]
+                local_shard = placement._split_tensor(
+                    local_shard, num_chunks, with_padding=False, contiguous=True
+                )[0][my_coord_on_mesh_dim]
+        return local_shard
+    elif isinstance(val, (list, tuple)):
+        return val.__class__(_partition_val(v, spec) for v in val)
+    else:
+        raise RuntimeError(f"val type {type(val)} not supported")
+
+
+def _insert_reshard_gm(
+    gm: torch.fx.GraphModule,
+    node: Node,
+    input_arg: Node,
+    input_arg_spec: DTensorSpec,
+    desired_spec: DTensorSpec,
+) -> None:
+    """
+    Transform the graph for tensor redistribution.
+    """
+    input_arg_spec.tensor_meta = input_arg.meta["tensor_meta"]
+    desired_spec.tensor_meta = input_arg.meta["tensor_meta"]
+    input_arg_tensor = input_arg.meta["val"]
+
+    # insert reshard operation
+    def reshard_fn(local_tensor: torch.Tensor) -> torch.Tensor:
+        return redistribute_local_tensor(
+            local_tensor,
+            input_arg_spec,
+            desired_spec,
+        )
+
+    reshard_gm = make_fx(reshard_fn)(input_arg_tensor)
+    reshard_gm_nodes = list(reshard_gm.graph.nodes)
+    input_node = reshard_gm_nodes[0]
+    with gm.graph.inserting_before(node):
+        # copy nn_module_stack metadata for output, all-reduce nodes
+        for reshard_node in reshard_gm.graph.nodes:
+            if reshard_node.op not in ["placeholder", "output"]:
+                reshard_node.meta["nn_module_stack"] = (
+                    copy.copy(input_arg.meta["nn_module_stack"])
+                    if input_arg.op != "placeholder"
+                    else copy.copy(node.meta["nn_module_stack"])
+                )
+        output_node = gm.graph.graph_copy(
+            reshard_gm.graph,
+            val_map={
+                input_node: input_arg,
+            },
+        )
+    node.replace_input_with(input_arg, output_node)  # type: ignore[arg-type]
+
+
+def _clean_up_graph_metadata(gm: torch.fx.GraphModule) -> None:
+    """
+    Clean up the graph by removing sharding and partitioning related metadata
+    """
+    for node in gm.graph.nodes:
+        if "sharding" in node.meta:
+            del node.meta["sharding"]
+        if "val" in node.meta and isinstance(node.meta["val"], torch.Tensor):
+            local_tensor_meta = _extract_tensor_metadata(node.meta["val"])
+            node.meta["tensor_meta"] = local_tensor_meta
+
+
+def _get_input_node_specs(
+    node: Node, placement_strategies: dict[Node, OpSpec]
+) -> tuple[DTensorSpec, ...]:
+    """
+    Get the input specs of a node.
+    """
+    input_specs_list: list[DTensorSpec] = []
+    for input_arg in node.all_input_nodes:
+        if input_arg in placement_strategies:
+            output_spec = placement_strategies[input_arg].output_specs
+            assert isinstance(output_spec, DTensorSpec)
+            input_specs_list.append(output_spec)
+        else:
+            raise ValueError(f"{input_arg} does not have output_spec populated.")
+    return tuple(input_specs_list)
+
+
+def _get_op_schema(node: Node, placement_strategies: dict[Node, OpSpec]) -> OpSchema:
+    """
+    Util function to construct the operator schema of a node.
+    """
+    args_schema_list = pytree.tree_map_only(
+        Node, lambda arg: placement_strategies[arg].output_specs, node.args
+    )
+    op_schema = OpSchema(
+        op=cast(torch._ops.OpOverload, node.target),
+        args_schema=tuple(args_schema_list),
+        kwargs_schema=cast(dict[str, object], node.kwargs),
+    )
+    return op_schema
+
+
+def _shard_state_dict(
+    state_dict: dict[str, torch.Tensor],
+    placement_strategies: dict[Node, OpSpec],
+    graph_signature: ExportGraphSignature,
+    mesh: DeviceMesh,
+) -> None:
+    """
+    Inplace partition the weights based on the OpSpec
+    """
+    for node, op_spec in placement_strategies.items():
+        if node.op != "placeholder":
+            continue
+        if node.name in graph_signature.inputs_to_parameters:
+            fqn = graph_signature.inputs_to_parameters[node.name]
+        elif node.name in graph_signature.inputs_to_buffers:
+            fqn = graph_signature.inputs_to_buffers[node.name]
+        else:
+            continue
+        assert fqn in state_dict, f"{fqn} not found in state dict: {state_dict.keys()}"
+
+        original_param = state_dict[fqn]
+        dtensor_param = distribute_tensor(
+            original_param,
+            mesh,
+            op_spec.output_spec.placements,
+        )
+        local_param = dtensor_param.to_local()
+        state_dict[fqn] = (
+            torch.nn.Parameter(local_param)
+            if isinstance(original_param, torch.nn.Parameter)
+            else local_param
+        )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/ddp.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/ddp.py
new file mode 100644
index 0000000000000000000000000000000000000000..19c1d3ca5477ee79f418fd3d2de71eac4103c1e4
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/ddp.py
@@ -0,0 +1,104 @@
+# mypy: allow-untyped-defs
+from typing import Any
+
+import torch.nn as nn
+from torch.distributed.tensor.parallel._data_parallel_utils import (
+    _flatten_tensor,
+    _unflatten_tensor,
+)
+
+
+__all__ = []  # type: ignore[var-annotated]
+
+
+def _get_submodule_n_params(module: nn.Module, path: str):
+    """
+    Get submodule and the direct path of parameter from the module
+    """
+    if "." in path:
+        path_list = path.split(".")
+        parent_module_path = ".".join(path_list[:-1])
+        module = module.get_submodule(parent_module_path)
+        path = path_list[-1]
+    return module, path
+
+
+def _update_module_param(param_list: list[tuple[nn.Module, str, nn.Parameter]]):
+    """
+    Update parameters within the module
+    """
+    for item in param_list:
+        parent_module, module_path, t = item
+        assert hasattr(parent_module, module_path)
+        delattr(parent_module, module_path)
+        setattr(parent_module, module_path, t)
+
+
+def _reconstruct_dtensor(module: nn.Module, _input: Any):
+    """
+    Reconstruct DTensor parameters from local tensors
+    """
+    param_list = []
+    # TODO: To add perf optimizations to this iterations
+    for name, t in module.named_parameters():
+        if hasattr(t, "_st_info"):
+            dtensor = _unflatten_tensor(t, t._st_info)
+            param_list.append((*_get_submodule_n_params(module, name), dtensor))
+    _update_module_param(param_list)  # type: ignore[arg-type]
+
+
+def _localize_dtensor(
+    module: nn.Module, *_: Any, ignored_params: set[nn.Parameter] | None = None
+):
+    """
+    Convert DTensor parameters to local tensors
+    """
+    if ignored_params is None:
+        ignored_params = set()
+    param_list = []
+    for name, param in module.named_parameters():
+        if param in ignored_params:
+            continue
+        t, sharding_info = _flatten_tensor(param)
+        if sharding_info is not None:
+            t = nn.Parameter(t)
+            t._st_info = sharding_info  # type: ignore[attr-defined]
+            param_list.append((*_get_submodule_n_params(module, name), t))
+    _update_module_param(param_list)  # type: ignore[arg-type]
+
+
+def _pre_dp_module_transform(module: nn.Module):
+    """
+    Enable the composability between Tensor Parallelism (TP) and Data
+    Parallelism(DP) in PyTorch when using DDP. We need to convert Parameters which
+    are DTensors to local tensors before wrapping with data parallelism API.
+    We then register two hooks, one for converting local tensors back to DTensor
+    preforward and one to convert DTensors back to tensors after Forward. By
+    integrating this way, we avoid any special handling of DTensor parameters by DDP
+    and get DTensor's gradients propagated back to DP, e.g. gradient buckets of DDP.
+
+    For now, this API only works with ``DistributedDataParallel``. It will later support
+    other DP methods such as FSDP.
+
+    Args:
+        module (:class:`nn.Module`):
+            Module which has been applied TP on.
+
+    Example::
+        >>> # xdoctest: +SKIP("distributed")
+        >>> from torch.distributed.tensor.parallel import parallelize_module, PairwiseParallel
+        >>> from torch.nn.parallel import DistributedDataParallel as DDP
+        >>> from torch.distributed.tensor.parallel.ddp import pre_dp_module_transform
+        >>>
+        >>> # Define the module.
+        >>> m = module(...)
+        >>> parallelize_module(m, PairwiseParallel())
+        >>> m = pre_dp_module_transform(m)
+        >>> m = DDP(m)
+        >>>
+    """
+
+    _localize_dtensor(module, None, None)
+    # TODO: To add test cases and ensure that it works for nested modules
+    module.register_forward_pre_hook(_reconstruct_dtensor)
+    module.register_forward_hook(_localize_dtensor)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/_dynamism.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/_dynamism.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3172d092fd1f1864c7607174455c615d741333eb
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/_dynamism.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/accelerator_partitioner.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/accelerator_partitioner.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3460c3b0e27aa86ca20f9973a4c1260711383c29
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/accelerator_partitioner.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/debug.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/debug.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5b5b9be67afc36c44c17ecdbe7568891dae76ad7
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/debug.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/meta_tracer.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/meta_tracer.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..981db57b8ab0b85b46081b64027adf2606ba944e
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/meta_tracer.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9fd4fdac1bbf0b514683cb0659ed29ee697a69b9
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..900cb7a1a01724210f63bf5def194d083883f2b6
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_generator.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_generator.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2035f874497e9c3a6603d1532330018445826f6e
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_generator.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_transformation.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_transformation.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f1d4d25e5ef1ded4c7f6626bf1eb78d0604694a8
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_transformation.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/operation.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/operation.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d83166d5b60217e722ef1641e89a4434e3559c84
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/operation.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/transform_to_z3.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/transform_to_z3.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aa68a56fb2c2992128b2def7ad81380bfc0c6e6a
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/transform_to_z3.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/util.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/util.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6fc9272da6768025f3af3bee5797a5b8d845e704
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/util.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/z3_types.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/z3_types.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6e793904d3e508a4b4fe9ab88540da37e6793d51
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/z3_types.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b1befa636939e0fd8483fb66e1f5f386943414d4
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/core.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/core.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..89166b24c6a625fd996b549404124661f68a8e31
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/core.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/dispatch.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/dispatch.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2e88609e3e37f0a585d33fed4a00c2431f899ff4
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/dispatch.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/match.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/match.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1a101aa8972d402d620655dd8858d672ad662973
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/match.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/more.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/more.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e811d8cd40da9a0d017723c0a52211209d2ee4a2
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/more.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/unification_tools.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/unification_tools.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eb5fed88aa11f3cfd453de5dc791af6a0178b581
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/unification_tools.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..57f896ef044dc7fe520739750a372ae2c45be4e5
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/utils.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/variable.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/variable.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2d17d8b20c7bf4df2b02f9fa3507a1db17e6a5f8
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/variable.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb7304069243fb45604e165b06b377a5db233a7d
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__init__.py
@@ -0,0 +1,7 @@
+from .core import dispatch
+from .dispatcher import (
+    Dispatcher,
+    halt_ordering,
+    MDNotImplementedError,
+    restart_ordering,
+)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..54837533ed87f6c08ba11b758f1ff0ef15ed67d1
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/conflict.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/conflict.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2ad5f02d72447f7336fb79d1b245fe905c60c9e9
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/conflict.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/core.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/core.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ad9153d05d7997675197cf80ee7b20569135b932
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/core.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/dispatcher.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/dispatcher.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c86695cfed4fb4e981bf793f667dfb4aec2dec9d
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/dispatcher.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ee11c2d3245d012782c6fdf6fd2ed9dda0e83410
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/utils.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/variadic.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/variadic.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e0f514dacf32924ae2f57ae0a20469dc87754349
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/variadic.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/conflict.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/conflict.py
new file mode 100644
index 0000000000000000000000000000000000000000..181e0e8dd167ac8b15d58f612308cdfeca1547e1
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/conflict.py
@@ -0,0 +1,139 @@
+# mypy: allow-untyped-defs
+import operator
+
+from .utils import _toposort, groupby
+from .variadic import isvariadic
+
+
+__all__ = [
+    "AmbiguityWarning",
+    "supercedes",
+    "consistent",
+    "ambiguous",
+    "ambiguities",
+    "super_signature",
+    "edge",
+    "ordering",
+]
+
+
+class AmbiguityWarning(Warning):
+    pass
+
+
+def supercedes(a, b):
+    """A is consistent and strictly more specific than B"""
+    if len(a) < len(b):
+        # only case is if a is empty and b is variadic
+        return not a and len(b) == 1 and isvariadic(b[-1])
+    elif len(a) == len(b):
+        return all(map(issubclass, a, b))
+    else:
+        # len(a) > len(b)
+        p1 = 0
+        p2 = 0
+        while p1 < len(a) and p2 < len(b):
+            cur_a = a[p1]
+            cur_b = b[p2]
+            if not (isvariadic(cur_a) or isvariadic(cur_b)):
+                if not issubclass(cur_a, cur_b):
+                    return False
+                p1 += 1
+                p2 += 1
+            elif isvariadic(cur_a):
+                assert p1 == len(a) - 1
+                return p2 == len(b) - 1 and issubclass(cur_a, cur_b)
+            elif isvariadic(cur_b):
+                assert p2 == len(b) - 1
+                if not issubclass(cur_a, cur_b):
+                    return False
+                p1 += 1
+        return p2 == len(b) - 1 and p1 == len(a)
+
+
+def consistent(a, b):
+    """It is possible for an argument list to satisfy both A and B"""
+
+    # Need to check for empty args
+    if not a:
+        return not b or isvariadic(b[0])
+    if not b:
+        return not a or isvariadic(a[0])
+
+    # Non-empty args check for mutual subclasses
+    if len(a) == len(b):
+        return all(issubclass(aa, bb) or issubclass(bb, aa) for aa, bb in zip(a, b))
+    else:
+        p1 = 0
+        p2 = 0
+        while p1 < len(a) and p2 < len(b):
+            cur_a = a[p1]
+            cur_b = b[p2]
+            if not issubclass(cur_b, cur_a) and not issubclass(cur_a, cur_b):
+                return False
+            if not (isvariadic(cur_a) or isvariadic(cur_b)):
+                p1 += 1
+                p2 += 1
+            elif isvariadic(cur_a):
+                p2 += 1
+            elif isvariadic(cur_b):
+                p1 += 1
+        # We only need to check for variadic ends
+        # Variadic types are guaranteed to be the last element
+        return (
+            isvariadic(cur_a)  # type: ignore[possibly-undefined]
+            and p2 == len(b)
+            or isvariadic(cur_b)  # type: ignore[possibly-undefined]
+            and p1 == len(a)
+        )
+
+
+def ambiguous(a, b):
+    """A is consistent with B but neither is strictly more specific"""
+    return consistent(a, b) and not (supercedes(a, b) or supercedes(b, a))
+
+
+def ambiguities(signatures):
+    """All signature pairs such that A is ambiguous with B"""
+    signatures = list(map(tuple, signatures))
+    return {
+        (a, b)
+        for a in signatures
+        for b in signatures
+        if hash(a) < hash(b)
+        and ambiguous(a, b)
+        and not any(supercedes(c, a) and supercedes(c, b) for c in signatures)
+    }
+
+
+def super_signature(signatures):
+    """A signature that would break ambiguities"""
+    n = len(signatures[0])
+    assert all(len(s) == n for s in signatures)
+
+    return [max((type.mro(sig[i]) for sig in signatures), key=len)[0] for i in range(n)]
+
+
+def edge(a, b, tie_breaker=hash):
+    """A should be checked before B
+    Tie broken by tie_breaker, defaults to ``hash``
+    """
+    # A either supersedes B and B does not supersede A or if B does then call
+    # tie_breaker
+    return supercedes(a, b) and (
+        not supercedes(b, a) or tie_breaker(a) > tie_breaker(b)
+    )
+
+
+def ordering(signatures):
+    """A sane ordering of signatures to check, first to last
+    Topological sort of edges as given by ``edge`` and ``supercedes``
+    """
+    signatures = list(map(tuple, signatures))
+    edges = [(a, b) for a in signatures for b in signatures if edge(a, b)]
+    edges = groupby(operator.itemgetter(0), edges)
+    for s in signatures:
+        if s not in edges:
+            edges[s] = []
+    edges = {k: [b for a, b in v] for k, v in edges.items()}  # type: ignore[assignment, attr-defined]
+    return _toposort(edges)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/core.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/core.py
new file mode 100644
index 0000000000000000000000000000000000000000..69b9f3b2b5a2cb8e9df9d502b4254abffff2dd18
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/core.py
@@ -0,0 +1,92 @@
+# mypy: allow-untyped-defs
+import inspect
+from collections.abc import Callable
+from typing import Any, TypeVar
+from typing_extensions import TypeVarTuple, Unpack
+
+from .dispatcher import Dispatcher, MethodDispatcher
+
+
+global_namespace = {}  # type: ignore[var-annotated]
+
+__all__ = ["dispatch", "ismethod"]
+
+T = TypeVar("T")
+Ts = TypeVarTuple("Ts")
+
+
+def dispatch(
+    *types: Unpack[Ts], **kwargs: Any
+) -> Callable[[Callable[..., T]], Callable[..., T]]:
+    """Dispatch function on the types of the inputs
+    Supports dispatch on all non-keyword arguments.
+    Collects implementations based on the function name.  Ignores namespaces.
+    If ambiguous type signatures occur a warning is raised when the function is
+    defined suggesting the additional method to break the ambiguity.
+
+    Example:
+        >>> # xdoctest: +SKIP
+        >>> @dispatch(int)
+        ... def f(x):
+        ...     return x + 1
+        >>> @dispatch(float)
+        ... def f(x):
+        ...     return x - 1
+        >>> # xdoctest: +SKIP
+        >>> f(3)
+        4
+        >>> f(3.0)
+        2.0
+        >>> # Specify an isolated namespace with the namespace keyword argument
+        >>> my_namespace = {}
+        >>> @dispatch(int, namespace=my_namespace)
+        ... def foo(x):
+        ...     return x + 1
+        >>> # Dispatch on instance methods within classes
+        >>> class MyClass(object):
+        ...     @dispatch(list)
+        ...     def __init__(self, data):
+        ...         self.data = data
+        ...
+        ...     @dispatch(int)
+        ...     def __init__(self, datum):
+        ...         self.data = [datum]
+        >>> MyClass([1, 2, 3]).data
+        [1, 2, 3]
+        >>> MyClass(3).data
+        [3]
+    """
+    namespace = kwargs.get("namespace", global_namespace)
+
+    types_tuple: tuple[type, ...] = tuple(types)  # type: ignore[arg-type]
+
+    def _df(func):
+        name = func.__name__
+
+        if ismethod(func):
+            dispatcher = inspect.currentframe().f_back.f_locals.get(  # type: ignore[union-attr]
+                name,  # type: ignore[union-attr]
+                MethodDispatcher(name),
+            )
+        else:
+            if name not in namespace:
+                namespace[name] = Dispatcher(name)
+            dispatcher = namespace[name]
+
+        dispatcher.add(types_tuple, func)
+        return dispatcher
+
+    return _df
+
+
+def ismethod(func):
+    """Is func a method?
+    Note that this has to work as the method is defined but before the class is
+    defined.  At this stage methods look like functions.
+    """
+    if hasattr(inspect, "signature"):
+        signature = inspect.signature(func)
+        return signature.parameters.get("self", None) is not None
+    else:
+        spec = inspect.getfullargspec(func)  # type: ignore[union-attr, assignment]
+        return spec and spec.args and spec.args[0] == "self"
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/dispatcher.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/dispatcher.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2459b82247bce59cd13ed040722c7278bf36ea0
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/dispatcher.py
@@ -0,0 +1,455 @@
+# mypy: allow-untyped-defs
+import inspect
+import itertools as itl
+from typing_extensions import deprecated
+from warnings import warn
+
+from .conflict import ambiguities, AmbiguityWarning, ordering, super_signature
+from .utils import expand_tuples
+from .variadic import isvariadic, Variadic
+
+
+__all__ = [
+    "MDNotImplementedError",
+    "ambiguity_warn",
+    "halt_ordering",
+    "restart_ordering",
+    "variadic_signature_matches_iter",
+    "variadic_signature_matches",
+    "Dispatcher",
+    "source",
+    "MethodDispatcher",
+    "str_signature",
+    "warning_text",
+]
+
+
+class MDNotImplementedError(NotImplementedError):
+    """A NotImplementedError for multiple dispatch"""
+
+
+def ambiguity_warn(dispatcher, ambiguities):
+    """Raise warning when ambiguity is detected
+    Parameters
+    ----------
+    dispatcher : Dispatcher
+        The dispatcher on which the ambiguity was detected
+    ambiguities : set
+        Set of type signature pairs that are ambiguous within this dispatcher
+    See Also:
+        Dispatcher.add
+        warning_text
+    """
+    warn(warning_text(dispatcher.name, ambiguities), AmbiguityWarning)
+
+
+@deprecated(
+    "`halt_ordering` is deprecated, you can safely remove this call.",
+    category=FutureWarning,
+)
+def halt_ordering():
+    """Deprecated interface to temporarily disable ordering."""
+
+
+@deprecated(
+    "`restart_ordering` is deprecated, if you would like to eagerly order the dispatchers, "
+    "you should call the `reorder()` method on each dispatcher.",
+    category=FutureWarning,
+)
+def restart_ordering(on_ambiguity=ambiguity_warn):
+    """Deprecated interface to temporarily resume ordering."""
+
+
+def variadic_signature_matches_iter(types, full_signature):
+    """Check if a set of input types matches a variadic signature.
+    Notes
+    -----
+    The algorithm is as follows:
+    Initialize the current signature to the first in the sequence
+    For each type in `types`:
+        If the current signature is variadic
+            If the type matches the signature
+                yield True
+            Else
+                Try to get the next signature
+                If no signatures are left we can't possibly have a match
+                    so yield False
+        Else
+            yield True if the type matches the current signature
+            Get the next signature
+    """
+    sigiter = iter(full_signature)
+    sig = next(sigiter)
+    for typ in types:
+        matches = issubclass(typ, sig)
+        yield matches
+        if not isvariadic(sig):
+            # we're not matching a variadic argument, so move to the next
+            # element in the signature
+            sig = next(sigiter)
+    else:
+        try:
+            sig = next(sigiter)
+        except StopIteration:
+            assert isvariadic(sig)
+            yield True
+        else:
+            # We have signature items left over, so all of our arguments
+            # haven't matched
+            yield False
+
+
+def variadic_signature_matches(types, full_signature):
+    # No arguments always matches a variadic signature
+    assert full_signature
+    return all(variadic_signature_matches_iter(types, full_signature))
+
+
+class Dispatcher:
+    """Dispatch methods based on type signature
+    Use ``dispatch`` to add implementations
+    Examples
+    --------
+    >>> # xdoctest: +SKIP("bad import name")
+    >>> from multipledispatch import dispatch
+    >>> @dispatch(int)
+    ... def f(x):
+    ...     return x + 1
+    >>> @dispatch(float)
+    ... def f(x):
+    ...     return x - 1
+    >>> f(3)
+    4
+    >>> f(3.0)
+    2.0
+    """
+
+    __slots__ = "__name__", "name", "funcs", "_ordering", "_cache", "doc"
+
+    def __init__(self, name, doc=None):
+        self.name = self.__name__ = name
+        self.funcs = {}
+        self.doc = doc
+
+        self._cache = {}
+
+    def register(self, *types, **kwargs):
+        """register dispatcher with new implementation
+        >>> # xdoctest: +SKIP
+        >>> f = Dispatcher("f")
+        >>> @f.register(int)
+        ... def inc(x):
+        ...     return x + 1
+        >>> @f.register(float)
+        ... def dec(x):
+        ...     return x - 1
+        >>> @f.register(list)
+        ... @f.register(tuple)
+        ... def reverse(x):
+        ...     return x[::-1]
+        >>> f(1)
+        2
+        >>> f(1.0)
+        0.0
+        >>> f([1, 2, 3])
+        [3, 2, 1]
+        """
+
+        def _df(func):
+            self.add(types, func, **kwargs)  # type: ignore[call-arg]
+            return func
+
+        return _df
+
+    @classmethod
+    def get_func_params(cls, func):
+        if hasattr(inspect, "signature"):
+            sig = inspect.signature(func)
+            return sig.parameters.values()
+
+    @classmethod
+    def get_func_annotations(cls, func):
+        """get annotations of function positional parameters"""
+        params = cls.get_func_params(func)
+        if params:
+            Parameter = inspect.Parameter
+
+            params = (
+                param
+                for param in params
+                if param.kind
+                in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
+            )
+
+            annotations = tuple(param.annotation for param in params)
+
+            if all(ann is not Parameter.empty for ann in annotations):
+                return annotations
+
+    def add(self, signature, func):
+        """Add new types/method pair to dispatcher
+        >>> # xdoctest: +SKIP
+        >>> D = Dispatcher("add")
+        >>> D.add((int, int), lambda x, y: x + y)
+        >>> D.add((float, float), lambda x, y: x + y)
+        >>> D(1, 2)
+        3
+        >>> D(1, 2.0)
+        Traceback (most recent call last):
+        ...
+        NotImplementedError: Could not find signature for add: 
+        >>> # When ``add`` detects a warning it calls the ``on_ambiguity`` callback
+        >>> # with a dispatcher/itself, and a set of ambiguous type signature pairs
+        >>> # as inputs.  See ``ambiguity_warn`` for an example.
+        """
+        # Handle annotations
+        if not signature:
+            annotations = self.get_func_annotations(func)
+            if annotations:
+                signature = annotations
+
+        # Handle union types
+        if any(isinstance(typ, tuple) for typ in signature):
+            for typs in expand_tuples(signature):
+                self.add(typs, func)
+            return
+
+        new_signature = []
+
+        for index, typ in enumerate(signature, start=1):
+            if not isinstance(typ, (type, list)):
+                str_sig = ", ".join(
+                    c.__name__ if isinstance(c, type) else str(c) for c in signature
+                )
+                raise TypeError(
+                    f"Tried to dispatch on non-type: {typ}\n"
+                    f"In signature: <{str_sig}>\n"
+                    f"In function: {self.name}"
+                )
+
+            # handle variadic signatures
+            if isinstance(typ, list):
+                if index != len(signature):
+                    raise TypeError("Variadic signature must be the last element")
+
+                if len(typ) != 1:
+                    raise TypeError(
+                        "Variadic signature must contain exactly one element. "
+                        "To use a variadic union type place the desired types "
+                        "inside of a tuple, e.g., [(int, str)]"
+                    )
+                # pyrefly: ignore [bad-specialization]
+                new_signature.append(Variadic[typ[0]])
+            else:
+                new_signature.append(typ)
+
+        self.funcs[tuple(new_signature)] = func
+        self._cache.clear()
+
+        try:
+            del self._ordering
+        except AttributeError:
+            pass
+
+    @property
+    def ordering(self):
+        try:
+            return self._ordering
+        except AttributeError:
+            return self.reorder()
+
+    def reorder(self, on_ambiguity=ambiguity_warn):
+        self._ordering = od = ordering(self.funcs)
+        amb = ambiguities(self.funcs)
+        if amb:
+            on_ambiguity(self, amb)
+        return od
+
+    def __call__(self, *args, **kwargs):
+        types = tuple(type(arg) for arg in args)
+        try:
+            func = self._cache[types]
+        except KeyError as e:
+            func = self.dispatch(*types)
+            if not func:
+                raise NotImplementedError(
+                    f"Could not find signature for {self.name}: <{str_signature(types)}>"
+                ) from e
+            self._cache[types] = func
+        try:
+            return func(*args, **kwargs)
+
+        except MDNotImplementedError as e:
+            funcs = self.dispatch_iter(*types)
+            next(funcs)  # burn first
+            for func in funcs:
+                try:
+                    return func(*args, **kwargs)
+                except MDNotImplementedError:
+                    pass
+
+            raise NotImplementedError(
+                "Matching functions for "
+                f"{self.name}: <{str_signature(types)}> found, but none completed successfully",
+            ) from e
+
+    def __str__(self):
+        return f""
+
+    __repr__ = __str__
+
+    def dispatch(self, *types):
+        """Determine appropriate implementation for this type signature
+        This method is internal.  Users should call this object as a function.
+        Implementation resolution occurs within the ``__call__`` method.
+        >>> # xdoctest: +SKIP
+        >>> from multipledispatch import dispatch
+        >>> @dispatch(int)
+        ... def inc(x):
+        ...     return x + 1
+        >>> implementation = inc.dispatch(int)
+        >>> implementation(3)
+        4
+        >>> print(inc.dispatch(float))
+        None
+        See Also:
+          ``multipledispatch.conflict`` - module to determine resolution order
+        """
+
+        if types in self.funcs:
+            return self.funcs[types]
+
+        try:
+            return next(self.dispatch_iter(*types))
+        except StopIteration:
+            return None
+
+    def dispatch_iter(self, *types):
+        n = len(types)
+        for signature in self.ordering:
+            if len(signature) == n and all(map(issubclass, types, signature)):
+                result = self.funcs[signature]
+                yield result
+            elif len(signature) and isvariadic(signature[-1]):
+                if variadic_signature_matches(types, signature):
+                    result = self.funcs[signature]
+                    yield result
+
+    @deprecated(
+        "`resolve()` is deprecated, use `dispatch(*types)`", category=FutureWarning
+    )
+    def resolve(self, types):
+        """Determine appropriate implementation for this type signature
+        .. deprecated:: 0.4.4
+            Use ``dispatch(*types)`` instead
+        """
+        return self.dispatch(*types)
+
+    def __getstate__(self):
+        return {"name": self.name, "funcs": self.funcs}
+
+    def __setstate__(self, d):
+        self.name = d["name"]
+        self.funcs = d["funcs"]
+        self._ordering = ordering(self.funcs)
+        self._cache = {}
+
+    @property
+    def __doc__(self):  # type: ignore[override]
+        docs = [f"Multiply dispatched method: {self.name}"]
+
+        if self.doc:
+            docs.append(self.doc)
+
+        other = []
+        for sig in self.ordering[::-1]:
+            func = self.funcs[sig]
+            if func.__doc__:
+                s = f"Inputs: <{str_signature(sig)}>\n"
+                s += "-" * len(s) + "\n"
+                s += func.__doc__.strip()
+                docs.append(s)
+            else:
+                other.append(str_signature(sig))
+
+        if other:
+            docs.append("Other signatures:\n    " + "\n    ".join(other))
+
+        return "\n\n".join(docs)
+
+    def _help(self, *args):
+        return self.dispatch(*map(type, args)).__doc__
+
+    def help(self, *args, **kwargs):
+        """Print docstring for the function corresponding to inputs"""
+        print(self._help(*args))
+
+    def _source(self, *args):
+        func = self.dispatch(*map(type, args))
+        if not func:
+            raise TypeError("No function found")
+        return source(func)
+
+    def source(self, *args, **kwargs):
+        """Print source code for the function corresponding to inputs"""
+        print(self._source(*args))
+
+
+def source(func):
+    s = f"File: {inspect.getsourcefile(func)}\n\n"
+    s = s + inspect.getsource(func)
+    return s
+
+
+class MethodDispatcher(Dispatcher):
+    """Dispatch methods based on type signature
+    See Also:
+        Dispatcher
+    """
+
+    # pyrefly: ignore [bad-override]
+    __slots__ = ("obj", "cls")
+
+    @classmethod
+    def get_func_params(cls, func):
+        if hasattr(inspect, "signature"):
+            sig = inspect.signature(func)
+            return itl.islice(sig.parameters.values(), 1, None)
+
+    def __get__(self, instance, owner):
+        self.obj = instance
+        self.cls = owner
+        return self
+
+    def __call__(self, *args, **kwargs):
+        types = tuple(type(arg) for arg in args)
+        func = self.dispatch(*types)
+        if not func:
+            raise NotImplementedError(
+                f"Could not find signature for {self.name}: <{str_signature(types)}>"
+            )
+        return func(self.obj, *args, **kwargs)
+
+
+def str_signature(sig):
+    """String representation of type signature
+    >>> str_signature((int, float))
+    'int, float'
+    """
+    return ", ".join(cls.__name__ for cls in sig)
+
+
+def warning_text(name, amb):
+    """The text for ambiguity warnings"""
+    text = f"\nAmbiguities exist in dispatched function {name}\n\n"
+    text += "The following signatures may result in ambiguous behavior:\n"
+    for pair in amb:
+        text += "\t" + ", ".join("[" + str_signature(s) + "]" for s in pair) + "\n"
+    text += "\n\nConsider making the following additions:\n\n"
+    text += "\n\n".join(
+        [
+            "@dispatch(" + str_signature(super_signature(s)) + f")\ndef {name}(...)"
+            for s in amb
+        ]
+    )
+    return text
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b21183c40b97a0757fc5c332cb783f39fc85efe
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/utils.py
@@ -0,0 +1,127 @@
+# mypy: allow-untyped-defs
+from collections import OrderedDict
+
+
+__all__ = ["raises", "expand_tuples", "reverse_dict", "groupby", "typename"]
+
+
+def raises(err, lamda):  # codespell:ignore lamda
+    try:
+        lamda()  # codespell:ignore lamda
+        return False
+    except err:
+        return True
+
+
+def expand_tuples(L):
+    """
+    >>> expand_tuples([1, (2, 3)])
+    [(1, 2), (1, 3)]
+    >>> expand_tuples([1, 2])
+    [(1, 2)]
+    """
+    if not L:
+        return [()]
+    elif not isinstance(L[0], tuple):
+        rest = expand_tuples(L[1:])
+        return [(L[0],) + t for t in rest]
+    else:
+        rest = expand_tuples(L[1:])
+        return [(item,) + t for t in rest for item in L[0]]
+
+
+# Taken from theano/theano/gof/sched.py
+# Avoids licensing issues because this was written by Matthew Rocklin
+def _toposort(edges):
+    """Topological sort algorithm by Kahn [1] - O(nodes + vertices)
+    inputs:
+        edges - a dict of the form {a: {b, c}} where b and c depend on a
+    outputs:
+        L - an ordered list of nodes that satisfy the dependencies of edges
+    >>> _toposort({1: (2, 3), 2: (3,)})
+    [1, 2, 3]
+    >>> # Closely follows the wikipedia page [2]
+    >>> # [1] Kahn, Arthur B. (1962), "Topological sorting of large networks",
+    >>> # Communications of the ACM
+    >>> # [2] http://en.wikipedia.org/wiki/Toposort#Algorithms
+    """
+    incoming_edges = reverse_dict(edges)
+    incoming_edges = OrderedDict((k, set(val)) for k, val in incoming_edges.items())
+    S = OrderedDict.fromkeys(v for v in edges if v not in incoming_edges)
+    L = []
+
+    while S:
+        n, _ = S.popitem()
+        L.append(n)
+        for m in edges.get(n, ()):
+            assert n in incoming_edges[m]
+            incoming_edges[m].remove(n)
+            if not incoming_edges[m]:
+                S[m] = None
+    if any(incoming_edges.get(v, None) for v in edges):
+        raise ValueError("Input has cycles")
+    return L
+
+
+def reverse_dict(d):
+    """Reverses direction of dependence dict
+    >>> d = {"a": (1, 2), "b": (2, 3), "c": ()}
+    >>> reverse_dict(d)  # doctest: +SKIP
+    {1: ('a',), 2: ('a', 'b'), 3: ('b',)}
+    :note: dict order are not deterministic. As we iterate on the
+        input dict, it make the output of this function depend on the
+        dict order. So this function output order should be considered
+        as undeterministic.
+    """
+    result = OrderedDict()  # type: ignore[var-annotated]
+    for key in d:
+        for val in d[key]:
+            result[val] = result.get(val, ()) + (key,)
+    return result
+
+
+# Taken from toolz
+# Avoids licensing issues because this version was authored by Matthew Rocklin
+def groupby(func, seq):
+    """Group a collection by a key function
+    >>> names = ["Alice", "Bob", "Charlie", "Dan", "Edith", "Frank"]
+    >>> groupby(len, names)  # doctest: +SKIP
+    {3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']}
+    >>> iseven = lambda x: x % 2 == 0
+    >>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8])  # doctest: +SKIP
+    {False: [1, 3, 5, 7], True: [2, 4, 6, 8]}
+    See Also:
+        ``countby``
+    """
+
+    d = OrderedDict()  # type: ignore[var-annotated]
+    for item in seq:
+        key = func(item)
+        if key not in d:
+            d[key] = []
+        d[key].append(item)
+    return d
+
+
+def typename(type):
+    """Get the name of `type`.
+    Parameters
+    ----------
+    type : Union[Type, Tuple[Type]]
+    Returns
+    -------
+    str
+        The name of `type` or a tuple of the names of the types in `type`.
+    Examples
+    --------
+    >>> typename(int)
+    'int'
+    >>> typename((int, float))
+    '(int, float)'
+    """
+    try:
+        return type.__name__
+    except AttributeError:
+        if len(type) == 1:
+            return typename(*type)
+        return f"({', '.join(map(typename, type))})"
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/variadic.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/variadic.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b5604a152480f83916108cb1b02de3bc9b9adb5
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/variadic.py
@@ -0,0 +1,96 @@
+# mypy: allow-untyped-defs
+from .utils import typename
+
+
+__all__ = ["VariadicSignatureType", "isvariadic", "VariadicSignatureMeta", "Variadic"]
+
+
+class VariadicSignatureType(type):
+    # checking if subclass is a subclass of self
+    def __subclasscheck__(cls, subclass):
+        other_type = subclass.variadic_type if isvariadic(subclass) else (subclass,)
+        return subclass is cls or all(
+            issubclass(other, cls.variadic_type)  # type: ignore[attr-defined]
+            for other in other_type
+        )
+
+    def __eq__(cls, other):
+        """
+        Return True if other has the same variadic type
+        Parameters
+        ----------
+        other : object (type)
+            The object (type) to check
+        Returns
+        -------
+        bool
+            Whether or not `other` is equal to `self`
+        """
+        return isvariadic(other) and set(cls.variadic_type) == set(other.variadic_type)  # type: ignore[attr-defined]
+
+    def __hash__(cls):
+        return hash((type(cls), frozenset(cls.variadic_type)))  # type: ignore[attr-defined]
+
+
+def isvariadic(obj):
+    """Check whether the type `obj` is variadic.
+    Parameters
+    ----------
+    obj : type
+        The type to check
+    Returns
+    -------
+    bool
+        Whether or not `obj` is variadic
+    Examples
+    --------
+    >>> # xdoctest: +SKIP
+    >>> isvariadic(int)
+    False
+    >>> isvariadic(Variadic[int])
+    True
+    """
+    return isinstance(obj, VariadicSignatureType)
+
+
+class VariadicSignatureMeta(type):
+    """A metaclass that overrides ``__getitem__`` on the class. This is used to
+    generate a new type for Variadic signatures. See the Variadic class for
+    examples of how this behaves.
+    """
+
+    def __getitem__(cls, variadic_type):
+        if not (isinstance(variadic_type, (type, tuple)) or type(variadic_type)):
+            raise ValueError(
+                "Variadic types must be type or tuple of types"
+                " (Variadic[int] or Variadic[(int, float)]"
+            )
+
+        if not isinstance(variadic_type, tuple):
+            variadic_type = (variadic_type,)
+        return VariadicSignatureType(
+            f"Variadic[{typename(variadic_type)}]",
+            (),
+            dict(variadic_type=variadic_type, __slots__=()),
+        )
+
+
+class Variadic(metaclass=VariadicSignatureMeta):
+    """A class whose getitem method can be used to generate a new type
+    representing a specific variadic signature.
+    Examples
+    --------
+    >>> # xdoctest: +SKIP
+    >>> Variadic[int]  # any number of int arguments
+    
+    >>> Variadic[(int, str)]  # any number of one of int or str arguments
+    
+    >>> issubclass(int, Variadic[int])
+    True
+    >>> issubclass(int, Variadic[(int, str)])
+    True
+    >>> issubclass(str, Variadic[(int, str)])
+    True
+    >>> issubclass(float, Variadic[(int, str)])
+    False
+    """
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/common.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/common.h
new file mode 100644
index 0000000000000000000000000000000000000000..ddfd338bf68f0b5602ff46e655a35ecd59071bd7
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/common.h
@@ -0,0 +1,207 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc.  All rights reserved.
+// https://developers.google.com/protocol-buffers/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+//     * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+//     * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+//     * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+// Author: kenton@google.com (Kenton Varda) and others
+//
+// Contains basic types and utilities used by the rest of the library.
+
+#ifndef GOOGLE_PROTOBUF_COMMON_H__
+#define GOOGLE_PROTOBUF_COMMON_H__
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+
+#ifndef PROTOBUF_USE_EXCEPTIONS
+#if defined(_MSC_VER) && defined(_CPPUNWIND)
+  #define PROTOBUF_USE_EXCEPTIONS 1
+#elif defined(__EXCEPTIONS)
+  #define PROTOBUF_USE_EXCEPTIONS 1
+#else
+  #define PROTOBUF_USE_EXCEPTIONS 0
+#endif
+#endif
+
+#if PROTOBUF_USE_EXCEPTIONS
+#include 
+#endif
+#if defined(__APPLE__)
+#include   // for TARGET_OS_IPHONE
+#endif
+
+#if defined(__ANDROID__) || defined(GOOGLE_PROTOBUF_OS_ANDROID) || (defined(TARGET_OS_IPHONE) && TARGET_OS_IPHONE) || defined(GOOGLE_PROTOBUF_OS_IPHONE)
+#include 
+#endif
+
+#include 
+
+namespace std {}
+
+namespace google {
+namespace protobuf {
+namespace internal {
+
+// Some of these constants are macros rather than const ints so that they can
+// be used in #if directives.
+
+// The current version, represented as a single integer to make comparison
+// easier:  major * 10^6 + minor * 10^3 + micro
+#define GOOGLE_PROTOBUF_VERSION 3013000
+
+// A suffix string for alpha, beta or rc releases. Empty for stable releases.
+#define GOOGLE_PROTOBUF_VERSION_SUFFIX ""
+
+// The minimum header version which works with the current version of
+// the library.  This constant should only be used by protoc's C++ code
+// generator.
+static const int kMinHeaderVersionForLibrary = 3013000;
+
+// The minimum protoc version which works with the current version of the
+// headers.
+#define GOOGLE_PROTOBUF_MIN_PROTOC_VERSION 3013000
+
+// The minimum header version which works with the current version of
+// protoc.  This constant should only be used in VerifyVersion().
+static const int kMinHeaderVersionForProtoc = 3013000;
+
+// Verifies that the headers and libraries are compatible.  Use the macro
+// below to call this.
+void PROTOBUF_EXPORT VerifyVersion(int headerVersion, int minLibraryVersion,
+                                   const char* filename);
+
+// Converts a numeric version number to a string.
+std::string PROTOBUF_EXPORT VersionString(int version);
+
+}  // namespace internal
+
+// Place this macro in your main() function (or somewhere before you attempt
+// to use the protobuf library) to verify that the version you link against
+// matches the headers you compiled against.  If a version mismatch is
+// detected, the process will abort.
+#define GOOGLE_PROTOBUF_VERIFY_VERSION                                    \
+  ::google::protobuf::internal::VerifyVersion(                            \
+    GOOGLE_PROTOBUF_VERSION, GOOGLE_PROTOBUF_MIN_LIBRARY_VERSION,         \
+    __FILE__)
+
+
+// ===================================================================
+// from google3/util/utf8/public/unilib.h
+
+class StringPiece;
+namespace internal {
+
+// Checks if the buffer contains structurally-valid UTF-8.  Implemented in
+// structurally_valid.cc.
+PROTOBUF_EXPORT bool IsStructurallyValidUTF8(const char* buf, int len);
+
+inline bool IsStructurallyValidUTF8(StringPiece str) {
+  return IsStructurallyValidUTF8(str.data(), static_cast(str.length()));
+}
+
+// Returns initial number of bytes of structurally valid UTF-8.
+PROTOBUF_EXPORT int UTF8SpnStructurallyValid(StringPiece str);
+
+// Coerce UTF-8 byte string in src_str to be
+// a structurally-valid equal-length string by selectively
+// overwriting illegal bytes with replace_char (typically ' ' or '?').
+// replace_char must be legal printable 7-bit Ascii 0x20..0x7e.
+// src_str is read-only.
+//
+// Returns pointer to output buffer, src_str.data() if no changes were made,
+//  or idst if some bytes were changed. idst is allocated by the caller
+//  and must be at least as big as src_str
+//
+// Optimized for: all structurally valid and no byte copying is done.
+//
+PROTOBUF_EXPORT char* UTF8CoerceToStructurallyValid(StringPiece str, char* dst,
+                                                    char replace_char);
+
+}  // namespace internal
+
+// This lives in message_lite.h now, but we leave this here for any users that
+// #include common.h and not message_lite.h.
+PROTOBUF_EXPORT void ShutdownProtobufLibrary();
+
+namespace internal {
+
+// Strongly references the given variable such that the linker will be forced
+// to pull in this variable's translation unit.
+template 
+void StrongReference(const T& var) {
+  auto volatile unused = &var;
+  (void)&unused;  // Use address to avoid an extra load of "unused".
+}
+
+}  // namespace internal
+
+#if PROTOBUF_USE_EXCEPTIONS
+class FatalException : public std::exception {
+ public:
+  FatalException(const char* filename, int line, const std::string& message)
+      : filename_(filename), line_(line), message_(message) {}
+  virtual ~FatalException() throw();
+
+  virtual const char* what() const throw();
+
+  const char* filename() const { return filename_; }
+  int line() const { return line_; }
+  const std::string& message() const { return message_; }
+
+ private:
+  const char* filename_;
+  const int line_;
+  const std::string message_;
+};
+#endif
+
+// This is at the end of the file instead of the beginning to work around a bug
+// in some versions of MSVC.
+using std::string;
+
+}  // namespace protobuf
+}  // namespace google
+
+#include 
+
+#endif  // GOOGLE_PROTOBUF_COMMON_H__
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/util/field_mask_util.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/util/field_mask_util.h
new file mode 100644
index 0000000000000000000000000000000000000000..00c971fb3c1cddcead3b6750d4578068462a5402
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/util/field_mask_util.h
@@ -0,0 +1,266 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc.  All rights reserved.
+// https://developers.google.com/protocol-buffers/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+//     * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+//     * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+//     * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+// Defines utilities for the FieldMask well known type.
+
+#ifndef GOOGLE_PROTOBUF_UTIL_FIELD_MASK_UTIL_H__
+#define GOOGLE_PROTOBUF_UTIL_FIELD_MASK_UTIL_H__
+
+#include 
+
+#include 
+#include 
+#include 
+
+// Must be included last.
+#include 
+
+namespace google {
+namespace protobuf {
+namespace util {
+
+class PROTOBUF_EXPORT FieldMaskUtil {
+  typedef google::protobuf::FieldMask FieldMask;
+
+ public:
+  // Converts FieldMask to/from string, formatted by separating each path
+  // with a comma (e.g., "foo_bar,baz.quz").
+  static std::string ToString(const FieldMask& mask);
+  static void FromString(StringPiece str, FieldMask* out);
+
+  // Populates the FieldMask with the paths corresponding to the fields with the
+  // given numbers, after checking that all field numbers are valid.
+  template 
+  static void FromFieldNumbers(const std::vector& field_numbers,
+                               FieldMask* out) {
+    for (const auto field_number : field_numbers) {
+      const FieldDescriptor* field_desc =
+          T::descriptor()->FindFieldByNumber(field_number);
+      GOOGLE_CHECK(field_desc != nullptr)
+          << "Invalid field number for " << T::descriptor()->full_name() << ": "
+          << field_number;
+      AddPathToFieldMask(field_desc->lowercase_name(), out);
+    }
+  }
+
+  // Converts FieldMask to/from string, formatted according to proto3 JSON
+  // spec for FieldMask (e.g., "fooBar,baz.quz"). If the field name is not
+  // style conforming (i.e., not snake_case when converted to string, or not
+  // camelCase when converted from string), the conversion will fail.
+  static bool ToJsonString(const FieldMask& mask, std::string* out);
+  static bool FromJsonString(StringPiece str, FieldMask* out);
+
+  // Get the descriptors of the fields which the given path from the message
+  // descriptor traverses, if field_descriptors is not null.
+  // Return false if the path is not valid, and the content of field_descriptors
+  // is unspecified.
+  static bool GetFieldDescriptors(
+      const Descriptor* descriptor, StringPiece path,
+      std::vector* field_descriptors);
+
+  // Checks whether the given path is valid for type T.
+  template 
+  static bool IsValidPath(StringPiece path) {
+    return GetFieldDescriptors(T::descriptor(), path, nullptr);
+  }
+
+  // Checks whether the given FieldMask is valid for type T.
+  template 
+  static bool IsValidFieldMask(const FieldMask& mask) {
+    for (int i = 0; i < mask.paths_size(); ++i) {
+      if (!GetFieldDescriptors(T::descriptor(), mask.paths(i), nullptr))
+        return false;
+    }
+    return true;
+  }
+
+  // Adds a path to FieldMask after checking whether the given path is valid.
+  // This method check-fails if the path is not a valid path for type T.
+  template 
+  static void AddPathToFieldMask(StringPiece path, FieldMask* mask) {
+    GOOGLE_CHECK(IsValidPath(path)) << path;
+    mask->add_paths(std::string(path));
+  }
+
+  // Creates a FieldMask with all fields of type T. This FieldMask only
+  // contains fields of T but not any sub-message fields.
+  template 
+  static FieldMask GetFieldMaskForAllFields() {
+    FieldMask out;
+    GetFieldMaskForAllFields(T::descriptor(), &out);
+    return out;
+  }
+  template 
+  PROTOBUF_DEPRECATED_MSG("Use *out = GetFieldMaskForAllFields() instead")
+  static void GetFieldMaskForAllFields(FieldMask* out) {
+    GetFieldMaskForAllFields(T::descriptor(), out);
+  }
+  // This flavor takes the protobuf type descriptor as an argument.
+  // Useful when the type is not known at compile time.
+  static void GetFieldMaskForAllFields(const Descriptor* descriptor,
+                                       FieldMask* out);
+
+  // Converts a FieldMask to the canonical form. It will:
+  //   1. Remove paths that are covered by another path. For example,
+  //      "foo.bar" is covered by "foo" and will be removed if "foo"
+  //      is also in the FieldMask.
+  //   2. Sort all paths in alphabetical order.
+  static void ToCanonicalForm(const FieldMask& mask, FieldMask* out);
+
+  // Creates an union of two FieldMasks.
+  static void Union(const FieldMask& mask1, const FieldMask& mask2,
+                    FieldMask* out);
+
+  // Creates an intersection of two FieldMasks.
+  static void Intersect(const FieldMask& mask1, const FieldMask& mask2,
+                        FieldMask* out);
+
+  // Subtracts mask2 from mask1 base of type T.
+  template 
+  static void Subtract(const FieldMask& mask1, const FieldMask& mask2,
+                       FieldMask* out) {
+    Subtract(T::descriptor(), mask1, mask2, out);
+  }
+  // This flavor takes the protobuf type descriptor as an argument.
+  // Useful when the type is not known at compile time.
+  static void Subtract(const Descriptor* descriptor, const FieldMask& mask1,
+                       const FieldMask& mask2, FieldMask* out);
+
+  // Returns true if path is covered by the given FieldMask. Note that path
+  // "foo.bar" covers all paths like "foo.bar.baz", "foo.bar.quz.x", etc.
+  // Also note that parent paths are not covered by explicit child path, i.e.
+  // "foo.bar" does NOT cover "foo", even if "bar" is the only child.
+  static bool IsPathInFieldMask(StringPiece path, const FieldMask& mask);
+
+  class MergeOptions;
+  // Merges fields specified in a FieldMask into another message.
+  static void MergeMessageTo(const Message& source, const FieldMask& mask,
+                             const MergeOptions& options, Message* destination);
+
+  class TrimOptions;
+  // Removes from 'message' any field that is not represented in the given
+  // FieldMask. If the FieldMask is empty, does nothing.
+  // Returns true if the message is modified.
+  static bool TrimMessage(const FieldMask& mask, Message* message);
+
+  // Removes from 'message' any field that is not represented in the given
+  // FieldMask with customized TrimOptions.
+  // If the FieldMask is empty, does nothing.
+  // Returns true if the message is modified.
+  static bool TrimMessage(const FieldMask& mask, Message* message,
+                          const TrimOptions& options);
+
+ private:
+  friend class SnakeCaseCamelCaseTest;
+  // Converts a field name from snake_case to camelCase:
+  //   1. Every character after "_" will be converted to uppercase.
+  //   2. All "_"s are removed.
+  // The conversion will fail if:
+  //   1. The field name contains uppercase letters.
+  //   2. Any character after a "_" is not a lowercase letter.
+  // If the conversion succeeds, it's guaranteed that the resulted
+  // camelCase name will yield the original snake_case name when
+  // converted using CamelCaseToSnakeCase().
+  //
+  // Note that the input can contain characters not allowed in C identifiers.
+  // For example, "foo_bar,baz_quz" will be converted to "fooBar,bazQuz"
+  // successfully.
+  static bool SnakeCaseToCamelCase(StringPiece input,
+                                   std::string* output);
+  // Converts a field name from camelCase to snake_case:
+  //   1. Every uppercase letter is converted to lowercase with an additional
+  //      preceding "_".
+  // The conversion will fail if:
+  //   1. The field name contains "_"s.
+  // If the conversion succeeds, it's guaranteed that the resulted
+  // snake_case name will yield the original camelCase name when
+  // converted using SnakeCaseToCamelCase().
+  //
+  // Note that the input can contain characters not allowed in C identifiers.
+  // For example, "fooBar,bazQuz" will be converted to "foo_bar,baz_quz"
+  // successfully.
+  static bool CamelCaseToSnakeCase(StringPiece input,
+                                   std::string* output);
+};
+
+class PROTOBUF_EXPORT FieldMaskUtil::MergeOptions {
+ public:
+  MergeOptions()
+      : replace_message_fields_(false), replace_repeated_fields_(false) {}
+  // When merging message fields, the default behavior is to merge the
+  // content of two message fields together. If you instead want to use
+  // the field from the source message to replace the corresponding field
+  // in the destination message, set this flag to true. When this flag is set,
+  // specified submessage fields that are missing in source will be cleared in
+  // destination.
+  void set_replace_message_fields(bool value) {
+    replace_message_fields_ = value;
+  }
+  bool replace_message_fields() const { return replace_message_fields_; }
+  // The default merging behavior will append entries from the source
+  // repeated field to the destination repeated field. If you only want
+  // to keep the entries from the source repeated field, set this flag
+  // to true.
+  void set_replace_repeated_fields(bool value) {
+    replace_repeated_fields_ = value;
+  }
+  bool replace_repeated_fields() const { return replace_repeated_fields_; }
+
+ private:
+  bool replace_message_fields_;
+  bool replace_repeated_fields_;
+};
+
+class PROTOBUF_EXPORT FieldMaskUtil::TrimOptions {
+ public:
+  TrimOptions() : keep_required_fields_(false) {}
+  // When trimming message fields, the default behavior is to trim required
+  // fields of the present message if they are not specified in the field mask.
+  // If you instead want to keep required fields of the present message even
+  // they are not specified in the field mask, set this flag to true.
+  void set_keep_required_fields(bool value) { keep_required_fields_ = value; }
+  bool keep_required_fields() const { return keep_required_fields_; }
+
+ private:
+  bool keep_required_fields_;
+};
+
+}  // namespace util
+}  // namespace protobuf
+}  // namespace google
+
+#include 
+
+#endif  // GOOGLE_PROTOBUF_UTIL_FIELD_MASK_UTIL_H__
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/numa/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/numa/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1ea12b8ed9d1800af92429174e5a96ae775fe0de
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/numa/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/numa/__pycache__/binding.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/numa/__pycache__/binding.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bab0320e3935ee66bcf9773f54fd66ec7e2707b5
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/numa/__pycache__/binding.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cd0fb920d895b8425da4c2ea4731c455d6ba7572
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/check_kernel_launches.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/check_kernel_launches.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..562e902b1ee882fa103849df17111eee4c6b9977
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/check_kernel_launches.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_cuda.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_cuda.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..847f45da8eee01a8f67fdc4e86996d3c94fd7e32
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_cuda.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_device_type.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_device_type.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a5c2f0d98b3641bcd9191a9f4ef30ad24abd7db6
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_device_type.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_fsdp.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_fsdp.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8ff0a9721a22620f3499845d1ec5828b593954ed
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_fsdp.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_mps.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_mps.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..564e8d58a1122a26aab1db3eb4ca1ee893319569
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_mps.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_pruning.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_pruning.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7b1f0c60f344bdf4d2c5b67038cab89b46613277
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_pruning.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/custom_op_db.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/custom_op_db.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d9935bfe6aafef7a88fc28951ac35947c00319e1
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/custom_op_db.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/dist_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/dist_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3e0252ce1b48560a768c7bc1b7d1b54f79bb7b2b
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/dist_utils.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/dynamo_pytree_test_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/dynamo_pytree_test_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5c1d66f9121c58e4ff6ce271cccf552a5170c8ed
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/dynamo_pytree_test_utils.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/fake_config_module3.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/fake_config_module3.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5967d9c83925cce5e61d3ebe535a791a5f91d2f8
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/fake_config_module3.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/hop_db.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/hop_db.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7a20b76250a460d7e93c087c9df94f9365561a50
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/hop_db.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/quantization_torch_package_models.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/quantization_torch_package_models.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8c0d0a136c03ab2a88c9db521914861d939dbf2b
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/quantization_torch_package_models.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/static_module.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/static_module.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9b20e822c9d92de7bdad7ab8b339be6e4ee8f412
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/static_module.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/triton_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/triton_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..924902ac621c486959d00650d43b17d7fc877649
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/triton_utils.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/two_tensor.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/two_tensor.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..05b23e4438c57222c957c2b538d24c16c9a04d2f
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/two_tensor.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/codegen/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/codegen/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..04cf6bcca217fd80d15ab92de81d3bef9f59023b
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/codegen/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/checkpoint_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/checkpoint_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..49a57ca2639916b24d2aa6fc2fed5a7051aa3d91
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/checkpoint_utils.py
@@ -0,0 +1,194 @@
+# mypy: allow-untyped-defs
+
+# Copyright (c) Meta Platforms, Inc. and affiliates
+
+import io
+import logging
+import os
+import shutil
+import tempfile
+from collections.abc import Callable
+from functools import wraps
+from typing import Any, cast, IO, Optional
+
+# introduced as collections.abc.Buffer in Python 3.12
+from typing_extensions import Buffer
+
+import torch.distributed as dist
+from torch.distributed.checkpoint._extension import (
+    ExtensionRegistry,
+    StreamTransformExtension,
+)
+
+
+class Rot13Example(StreamTransformExtension):
+    """
+    This is an example stream transform extension which just does rot13 on each
+    alphanumeric character of the stream.  It is mainly intended as a demonstration
+    and for testing; there isn't a production use case for this.
+    """
+
+    def __init__(self, chunk_size: int = io.DEFAULT_BUFFER_SIZE) -> None:
+        super().__init__()
+        self._chunk_size = chunk_size
+
+    @staticmethod
+    def from_descriptor(version: str) -> "Rot13Example":
+        if version.partition(".")[0] != "1":
+            raise ValueError(f"Unknown extension {version=}")
+        return Rot13Example()
+
+    @staticmethod
+    def registry_name() -> str:
+        return "stream.rot13"
+
+    def get_descriptor(self) -> str:
+        return f"{self.registry_name()}/1"
+
+    @staticmethod
+    def _rot13bytes(b: Buffer, count: int) -> None:
+        b = memoryview(b)
+        for i in range(count):
+            ch = b[i]
+            if ch >= ord("A") and ch <= ord("Z"):
+                ch += ord("a") - ord("A")
+            elif ch >= ord("a") and ch <= ord("z"):
+                ch += ord("A") - ord("a")
+            b[i] = ch
+
+    def transform_to(self, output: IO[bytes]) -> IO[bytes]:
+        class Writer(io.RawIOBase):
+            def __init__(self, output: IO[bytes]) -> None:
+                self.output = output
+
+            def writeable(self) -> bool:
+                return True
+
+            def write(self, b: Buffer) -> Optional[int]:
+                # Don't mutate the input
+                chunk = bytearray(b)
+                Rot13Example._rot13bytes(chunk, len(chunk))
+                return self.output.write(chunk)
+
+            def flush(self) -> None:
+                self.output.flush()
+
+        return cast(IO[bytes], Writer(output))
+
+    def transform_from(self, input: IO[bytes]) -> IO[bytes]:
+        class Reader(io.RawIOBase):
+            def __init__(self, input: IO[bytes]) -> None:
+                self.input = input
+
+            def readable(self) -> bool:
+                return True
+
+            def readinto(self, b: Buffer) -> Optional[int]:
+                if hasattr(self.input, "readinto"):
+                    count = self.input.readinto(b)
+                else:
+                    # It's possible self.input is an IO[bytes] with no readinto method.
+                    # In that case, we emulate with a read and copy.  In practice,
+                    # all of the current concrete extensions have readinto.
+                    view = memoryview(b)
+                    r = self.input.read(len(view))
+                    if r is None:
+                        count = None
+                    else:
+                        count = len(r)
+                        view[:count] = r
+                if count == 0 or count is None:
+                    return count
+
+                Rot13Example._rot13bytes(b, count)
+                return count
+
+            def seekable(self) -> bool:
+                return self.input.seekable()
+
+            def seek(self, offset: int, whence: int = os.SEEK_SET) -> int:
+                return self.input.seek(offset, whence)
+
+            def tell(self) -> int:
+                return self.input.tell()
+
+        return cast(IO[bytes], Reader(input))
+
+
+def get_test_extension_registry() -> ExtensionRegistry:
+    registry = ExtensionRegistry()
+    registry.register(Rot13Example)
+    return registry
+
+
+def with_temp_dir(
+    func: Optional[Callable] = None,
+) -> Optional[Callable]:
+    """
+    Wrapper to initialize temp directory for distributed checkpoint.
+    """
+    assert func is not None
+
+    @wraps(func)
+    def wrapper(self, *args: tuple[object], **kwargs: dict[str, Any]) -> None:
+        if dist.is_initialized():
+            # Only create temp_dir when rank is 0
+            if dist.get_rank() == 0:
+                temp_dir = tempfile.mkdtemp()
+                print(f"Using temp directory: {temp_dir}")
+            else:
+                temp_dir = ""
+            object_list = [temp_dir]
+
+            # Broadcast temp_dir to all the other ranks
+            os.sync()
+            dist.broadcast_object_list(object_list)
+            self.temp_dir = object_list[0]
+            os.sync()
+        else:
+            temp_dir = tempfile.mkdtemp()
+            print(f"No process group initialized, using temp directory: {temp_dir}")
+            self.temp_dir = temp_dir
+
+        try:
+            func(self, *args, **kwargs)
+        finally:
+            if dist.is_initialized() and dist.get_rank() == 0:
+                shutil.rmtree(self.temp_dir, ignore_errors=True)
+            else:
+                shutil.rmtree(self.temp_dir, ignore_errors=True)
+
+    return wrapper
+
+
+def with_checkpoint_logging(
+    func: Optional[Callable] = None,
+    logger_name: str = "torch.distributed.checkpoint",
+    level: int = logging.INFO,
+) -> Optional[Callable]:
+    """
+    Wrapper to configure checkpoint logging for distributed tests.
+
+    Args:
+        func: The test function to wrap
+        logger_name: Name of the logger to configure (default: 'torch.distributed.checkpoint')
+        level: Logging level to set (default: logging.INFO)
+    """
+    assert func is not None
+
+    @wraps(func)
+    def wrapper(self, *args: tuple[object], **kwargs: dict[str, Any]) -> None:
+        # Get the logger and store original level
+        target_logger = logging.getLogger(logger_name)
+        original_level = target_logger.level
+
+        # Set the desired logging level
+        target_logger.setLevel(level)
+
+        try:
+            func(self, *args, **kwargs)
+        finally:
+            # Restore original logging level
+            target_logger.setLevel(original_level)
+
+    return wrapper
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/common_state_dict.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/common_state_dict.py
new file mode 100644
index 0000000000000000000000000000000000000000..a78e312306ba2500afa3722d6271c645d25f97cf
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/common_state_dict.py
@@ -0,0 +1,170 @@
+# mypy: allow-untyped-defs
+
+# Owner(s): ["oncall: distributed"]
+
+import copy
+from itertools import chain
+from typing import Any
+
+import torch
+import torch.nn as nn
+from torch.distributed._sharded_tensor import ShardedTensor
+from torch.distributed._state_dict_utils import _gather_state_dict
+from torch.distributed.checkpoint.state_dict import (
+    _PG,
+    _STATE,
+    set_state_dict,
+    StateDictOptions,
+)
+from torch.distributed.tensor import DTensor
+
+
+class VerifyStateDictMixin:
+    def _compare_tensor(self, orig_tensor, dist_tensor, offload_to_cpu=False):
+        if isinstance(dist_tensor, (DTensor, ShardedTensor)):
+            dist_tensor = _gather_state_dict({"mykey": dist_tensor}).pop("mykey")
+
+        if offload_to_cpu:
+            orig_tensor = orig_tensor.cpu()
+            dist_tensor = dist_tensor.cpu()
+        self.assertTrue(isinstance(dist_tensor, torch.Tensor))
+        self.assertTrue(torch.allclose(orig_tensor, dist_tensor))
+
+    def _verify_msd(
+        self,
+        msd: dict[str, Any],
+        dist_msd: dict[str, Any],
+        options: StateDictOptions = StateDictOptions(),
+        offload_to_cpu=False,
+    ) -> None:
+        if not options.ignore_frozen_params:
+            self.assertEqual(len(msd), len(dist_msd))
+        for fqn, param in msd.items():
+            dist_param = dist_msd.get(fqn)
+            if not options.ignore_frozen_params:
+                self.assertIsNotNone(dist_param, f"{fqn=}")
+                try:
+                    self._compare_tensor(param, dist_param, offload_to_cpu)
+                except AssertionError as e:
+                    raise AssertionError(
+                        f"{fqn} has mismatched value {param} {dist_param}"
+                    ) from e
+            elif dist_param is None:
+                self.assertFalse(param.requires_grad, f"{fqn=}")
+
+    def _verify_osd(
+        self,
+        model: nn.Module,
+        optim: torch.optim.Optimizer,
+        osd: dict[str, Any],
+        dist_osd: dict[str, Any],
+    ) -> None:
+        params = list(chain.from_iterable(g["params"] for g in optim.param_groups))
+        param_pid_mapping = dict(zip(params, range(len(params)), strict=True))
+        fqn_pid_mapping = {}
+        for fqn, param in model.named_parameters():
+            pid = param_pid_mapping[param]
+            fqn_pid_mapping[fqn] = pid
+            fqn_pid_mapping[pid] = fqn
+        # Check optimizer_state_dict state
+
+        self.assertEqual(len(osd[_STATE]), len(dist_osd[_STATE]))
+        for pid, states in osd[_STATE].items():
+            fqn = fqn_pid_mapping[pid]
+            dist_states = dist_osd[_STATE].get(fqn, None)
+            self.assertIsNotNone(dist_states, fqn)
+            self.assertEqual(len(states), len(dist_states))
+            for key, state in states.items():
+                dist_state = states.get(key, None)
+                self.assertIsNotNone(dist_state)
+                self._compare_tensor(state, dist_state)
+
+        # Check optimizer_state_dict param_group
+        old_dist_osd_pg = dist_osd[_PG]
+        if len(osd[_PG]) != len(dist_osd[_PG]):
+            self.assertTrue(len(dist_osd[_PG]) > len(osd[_PG]))
+            new_pg = copy.deepcopy(dist_osd[_PG][0])
+            new_pg["params"] = []
+            for dist_group in dist_osd[_PG]:
+                new_pg["params"].extend(dist_group["params"])
+            dist_osd[_PG] = [new_pg]
+
+        self.assertEqual(len(osd[_PG]), len(dist_osd[_PG]))
+        for group, dist_group in zip(osd[_PG], dist_osd[_PG], strict=True):
+            self.assertEqual(len(group), len(dist_group))
+            for key, value in group.items():
+                # Below doesn't work because param_groups can have None
+                # values.
+                # dist_value = dist_group.get(key, None)
+                # self.assertIsNotNone(dist_value, (dist_group, group))
+                dist_value = dist_group[key]
+                if key == "params":
+                    fqns = [fqn_pid_mapping[pid] for pid in value]
+                    self.assertEqual(sorted(fqns), sorted(dist_value))
+                else:
+                    self.assertEqual(value, dist_value)
+        dist_osd[_PG] = old_dist_osd_pg
+
+    def _verify_osd_by_load(
+        self,
+        model: nn.Module,
+        optim: torch.optim.Optimizer,
+        new_optim: torch.optim.Optimizer,
+        dist_osd: dict[str, Any],
+    ) -> None:
+        new_dist_osd = _gather_state_dict(dist_osd)
+        set_state_dict(
+            model,
+            optimizers=new_optim,
+            model_state_dict={},
+            optim_state_dict=new_dist_osd,
+        )
+        self.assertEqual(optim.state_dict(), new_optim.state_dict())
+
+
+class FusionEmbedding(nn.Module):
+    def __init__(self, vocab_size: int, fusion_vocab_size: int, embed_dim: int) -> None:
+        super().__init__()
+        self.embedding = nn.Embedding(vocab_size, embed_dim)
+        self.fusion_embedding = nn.Embedding(fusion_vocab_size, embed_dim)
+
+
+class FusionEmbeddingWithHook(nn.Module):
+    def __init__(self, vocab_size: int, fusion_vocab_size: int, embed_dim: int) -> None:
+        super().__init__()
+        self.embedding = nn.Embedding(vocab_size, embed_dim)
+        self.fusion_embedding = nn.Embedding(fusion_vocab_size, embed_dim)
+        self._register_state_dict_hook(FusionEmbeddingWithHook._state_dict_hook)
+        self._register_load_state_dict_pre_hook(
+            FusionEmbeddingWithHook._load_state_dict_hook, with_module=True
+        )
+
+    def _state_dict_hook(self, destination, prefix, keep_vars):
+        """Remove "embedding" from the original embedding in the state_dict
+        name. This keeps the original state dict name for the embedding
+        from before fusing with the FusionEmbedding.
+        """
+        key = prefix + "embedding.weight"
+        new_key = prefix + "weight"
+        destination[new_key] = destination[key]
+        del destination[key]
+
+    def _load_state_dict_hook(self, state_dict, prefix, *args, **kwargs):
+        """Apply extra "embedding" prefix to the state_dict key to
+        account for the FusionEmbedding wrapping.
+        """
+        if state_dict:
+            key = prefix + "weight"
+            new_key = prefix + "embedding.weight"
+            state_dict[new_key] = state_dict[key]
+            del state_dict[key]
+
+
+class FusionEmbeddingWithModifier(FusionEmbeddingWithHook):
+    # _fqn_modifiers is a private function as a contract between DSD. When users change the state_dict
+    # keys, they need to provide a mapping from the new key to the original key. This is used to ensure
+    # consistency between the state_dict keys and fqn.
+    def _fqn_modifiers(self) -> dict[str, str]:
+        return {
+            "weight": "embedding",
+        }
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..32498f6d14917511f599af30e6afc3c5972280fc
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py
@@ -0,0 +1,748 @@
+# mypy: allow-untyped-defs
+
+import contextlib
+import enum
+import logging
+import os
+import threading
+from typing import NamedTuple
+
+import torch
+import torch.distributed as dist
+import torch.distributed.autograd as dist_autograd
+import torch.nn as nn
+from torch.distributed import rpc
+from torch.distributed.nn import RemoteModule
+from torch.nn.parallel import DistributedDataParallel
+from torch.testing._internal.common_distributed import (
+    requires_gloo,
+    requires_nccl,
+    skip_if_lt_x_gpu,
+    skip_if_rocm_multiprocess,
+)
+from torch.testing._internal.dist_utils import dist_init, INIT_METHOD_TEMPLATE
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+    RpcAgentTestFixture,
+)
+
+
+NUM_EM_ROW = 2
+D_SPARSE = 3
+D_DENSE = 2
+D_HID = 3
+D_OUT = 1
+NUM_TRAINERS = 4
+# Trainers + the master + the remote worker
+WORLD_SIZE = NUM_TRAINERS + 2
+TRAINER_RANKS = list(range(NUM_TRAINERS))
+REMOTE_WORKER_RANK = TRAINER_RANKS[-1] + 1
+MASTER_RANK = REMOTE_WORKER_RANK + 1
+
+
+class DdpMode(enum.Enum):
+    # Don't apply DDP
+    NONE = enum.auto()
+    # Apply DDP to the top level nn.Module
+    OUTSIDE = enum.auto()
+    # Embed DDP inside the top level nn.Module
+    INSIDE = enum.auto()
+
+
+def init_logger():
+    logger = logging.getLogger(__name__)
+    level = logging.DEBUG if "debug" in os.environ else logging.INFO
+    logger.setLevel(level)
+    console = logging.StreamHandler()
+    formatter = logging.Formatter(
+        "%(asctime)s %(filename)s:%(lineno)s %(levelname)s p:%(processName)s t:%(threadName)s: %(message)s"
+    )
+    console.setFormatter(formatter)
+    console.setLevel(level)
+    # add the handlers to the logger
+    logger.addHandler(console)
+    logger.propagate = False
+    return logger
+
+
+gLogger = init_logger()
+
+
+class FeatureSet(NamedTuple):
+    """A feature set has 2 types of features"""
+
+    dense_features: torch.Tensor
+    sparse_features: torch.LongTensor
+    values: torch.Tensor
+
+
+def _call_method(method, rref, *args, **kwargs):
+    return method(rref.local_value(), *args, **kwargs)
+
+
+def _remote_method(method, rref, *args, **kwargs):
+    args_tup = tuple([method, rref] + list(args))
+    return rpc.rpc_sync(rref.owner(), _call_method, args=args_tup, kwargs=kwargs)
+
+
+def _remote_method_async(method, rref, *args, **kwargs):
+    args_tup = tuple([method, rref] + list(args))
+    return rpc.rpc_async(rref.owner(), _call_method, args=args_tup, kwargs=kwargs)
+
+
+class RemoteEM(nn.Module):
+    def __init__(self, num_embeddings: int, embedding_dim: int):
+        gLogger.info("Initing RemoteEM with %s %s", num_embeddings, embedding_dim)
+        super().__init__()
+        init_em = [0.5] * embedding_dim
+        self.em = nn.EmbeddingBag(
+            num_embeddings,
+            embedding_dim,
+            _weight=torch.tensor([init_em] * num_embeddings),
+        )
+
+    def forward(self, input: torch.Tensor):
+        gLogger.debug("Running RemoteEM.forward() on: %s", input)
+        return self.em(input, offsets=torch.LongTensor(range(input.shape[0])))
+
+
+# Return a linear module with predefined parameters.
+def getLinear(d_in, d_out):
+    l = nn.Linear(d_in, d_out, bias=False)
+    w = torch.ones((d_out, d_in))
+    w[0][0] = -1
+    w.requires_grad_()
+    l.weight.data = w
+    return l
+
+
+class RemoteNet(nn.Module):
+    def __init__(self, d_in: int, d_out: int):
+        gLogger.info("Initing RemoteNet with %s %s", d_in, d_out)
+        super().__init__()
+        self.fc = getLinear(d_in, d_out)
+        self.relu = nn.ReLU()
+
+    def forward(self, input: torch.Tensor):
+        gLogger.debug("Running RemoteNet.forward() on: %s", input)
+        return self.relu(self.fc(input))
+
+
+class HybridModel(nn.Module):
+    def __init__(
+        self,
+        remote_em_rref: rpc.RRef,
+        remote_net_rref: rpc.RRef,
+        process_group_for_ddp: dist.ProcessGroup = None,
+    ):
+        super().__init__()
+        self.remote_em_rref = remote_em_rref
+        self.remote_net_rref = remote_net_rref
+        self.fc1 = getLinear(D_DENSE, D_DENSE)
+        self.fc2 = getLinear(D_HID, D_OUT)
+
+        self.non_ddp_params = tuple(self.fc1.parameters()) + tuple(
+            self.fc2.parameters()
+        )
+        self.ddp_params = ()
+
+        if process_group_for_ddp is not None:
+            self.non_ddp_params, self.ddp_params = (
+                tuple(self.fc1.parameters()),
+                tuple(self.fc2.parameters()),
+            )
+            gLogger.info("Use DDP for the second local net.")
+            self.fc2 = DistributedDataParallel(
+                self.fc2, check_reduction=True, process_group=process_group_for_ddp
+            )
+
+        gLogger.info(
+            "HybridModel has %s groups of parameters.", len(list(self.parameters()))
+        )
+
+    def forward(self, input: FeatureSet):
+        gLogger.debug("Running HybridModel.forward on %s", input)
+        sparse = _remote_method(
+            RemoteEM.forward, self.remote_em_rref, input.sparse_features
+        )
+        # The same size of mini batch.
+        assert sparse.shape[0] == input.dense_features.shape[0]
+        dense = self.fc1(input.dense_features)
+        x = torch.cat((dense, sparse), 1)
+        gLogger.debug("Concatenated feature: %s", x)
+        x = _remote_method(RemoteNet.forward, self.remote_net_rref, x)
+        return self.fc2(x)
+
+
+class Trainer:
+    def __init__(
+        self,
+        remote_em_rref: rpc.RRef,
+        remote_net_rref: rpc.RRef,
+        ddp_mode: DdpMode,
+        rank: int,
+    ):
+        self.rank = rank
+        self.trainer_group = (
+            dist.new_group(TRAINER_RANKS)
+            if ddp_mode in (DdpMode.INSIDE, DdpMode.OUTSIDE)
+            else None
+        )
+        self.remote_em_rref = remote_em_rref
+        self.remote_net_rref = remote_net_rref
+        self.hybrid_module = HybridModel(
+            self.remote_em_rref,
+            self.remote_net_rref,
+            self.trainer_group if ddp_mode == DdpMode.INSIDE else None,
+        )
+        self.ddp_params, self.non_ddp_params = (
+            self.hybrid_module.ddp_params,
+            self.hybrid_module.non_ddp_params,
+        )
+        if ddp_mode == DdpMode.OUTSIDE:
+            gLogger.info("Wrapping the whole hybrid module into DDP.")
+            self.ddp_params += self.non_ddp_params
+            self.non_ddp_params = ()
+            self.hybrid_module = DistributedDataParallel(
+                self.hybrid_module,
+                check_reduction=True,
+                process_group=self.trainer_group,
+            )
+        gLogger.info(
+            "Succeeded in creating a HybridModel instance with "
+            "%s ddp params and %s other local params.",
+            len(self.ddp_params),
+            len(self.non_ddp_params),
+        )
+
+    def destroy_pg(self):
+        if self.trainer_group:
+            dist.destroy_process_group(self.trainer_group)
+
+    def train_batch(
+        self,
+        mini_batch: FeatureSet,
+        trainer_has_less_inputs: bool,
+        simulate_uneven_inputs: bool,
+    ):
+        grads_dict = None
+
+        if not simulate_uneven_inputs:
+            input_batches = [mini_batch]
+        else:
+            # Split into microbatches, and trim to simulate uneven inputs.
+            dense_features = mini_batch.dense_features
+            sparse_features = mini_batch.sparse_features
+            values = mini_batch.values
+
+            dense_microbatch = torch.split(dense_features, 2)
+            sparse_microbatch = torch.split(sparse_features, 2)
+            values_microbatch = torch.split(values, 2)
+            batches = []
+            for d, s, v in zip(
+                dense_microbatch, sparse_microbatch, values_microbatch, strict=True
+            ):
+                feature_set = FeatureSet(dense_features=d, sparse_features=s, values=v)
+                batches.append(feature_set)
+
+            if trainer_has_less_inputs:
+                input_batches = batches[: len(batches) // 2]
+                gLogger.info(
+                    "Trainer reduced input patches from %s "
+                    "to %s to simulate uneven inputs.",
+                    len(batches),
+                    len(input_batches),
+                )
+            else:
+                input_batches = batches
+
+        with (
+            self.hybrid_module.join()
+            if simulate_uneven_inputs
+            else contextlib.nullcontext()
+        ):
+            for b in input_batches:
+                with dist_autograd.context() as context_id:
+                    output = self.hybrid_module.forward(b)
+                    loss = (output * mini_batch.values).sum()
+                    dist_autograd.backward(context_id, [loss])
+                    grads_dict = dist_autograd.get_gradients(context_id)
+                    gLogger.info(
+                        "Loss is %s for mini batch: %s. Grads dict has %s entries: %s",
+                        loss,
+                        mini_batch,
+                        len(grads_dict),
+                        grads_dict,
+                    )
+        return (
+            tuple(grads_dict[param] for param in self.ddp_params),
+            tuple(grads_dict[param] for param in self.non_ddp_params),
+        )
+
+
+def get_training_examples():
+    n = 16
+    training_examples = FeatureSet(
+        dense_features=torch.zeros((n, D_DENSE)),
+        sparse_features=torch.zeros(n, dtype=torch.long),
+        values=torch.zeros(n),
+    )
+    idx = 0
+    # Every example has another one that has exactly the same features but an
+    # opposite value. Therefore, their grads cancel each other in all-reduce.
+    for value in (-1, 1):
+        for x in (-1.0 * value, 1.0 * value):
+            for y in (1.0 * value, -1.0 * value):
+                for z in (0, 1):
+                    training_examples.dense_features[idx, :] = torch.tensor((x, y))
+                    training_examples.sparse_features[idx] = z
+                    training_examples.values[idx] = value
+                    idx += 1
+
+    # Split the examples among NUM_TRAINERS trainers
+    assert 0 == (n % NUM_TRAINERS)
+    examples_per_trainer = int(n / NUM_TRAINERS)
+    return [
+        FeatureSet(
+            dense_features=training_examples.dense_features[
+                start : start + examples_per_trainer, :
+            ],
+            sparse_features=training_examples.sparse_features[
+                start : start + examples_per_trainer
+            ],
+            values=training_examples.values[start : start + examples_per_trainer],
+        )
+        for start in range(0, n, examples_per_trainer)
+    ]
+
+
+shutdown_signal = threading.Condition()
+
+
+def set_shutdown_signal():
+    global shutdown_signal
+    with shutdown_signal:
+        shutdown_signal.notify()
+
+
+class DdpUnderDistAutogradTest(RpcAgentTestFixture):
+    @property
+    def world_size(self) -> int:
+        return WORLD_SIZE
+
+    def remote_worker_name(self) -> str:
+        # The name has to be consistent with that in 'dist_init' decorator.
+        return f"worker{REMOTE_WORKER_RANK}"
+
+    def trainer_name(self, rank):
+        # The name has to be consistent with that in 'dist_init' decorator.
+        return f"worker{rank}"
+
+    def _remote_worker_process(self, ddp_mode):
+        gLogger.info("The remote worker is running.")
+        dist.init_process_group(
+            backend="gloo",
+            init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
+            world_size=self.world_size,
+            rank=self.rank,
+        )
+
+        if ddp_mode in (DdpMode.INSIDE, DdpMode.OUTSIDE):
+            # new_group needs to be called on ranks.
+            dist.new_group(TRAINER_RANKS)
+
+        global shutdown_signal
+        with shutdown_signal:
+            shutdown_signal.wait()
+        gLogger.info("Exiting remote worker.")
+        dist.destroy_process_group()
+
+    def _trainer_process(self, rank: int):
+        gLogger.info("Running the trainer #%s...", rank)
+        gLogger.info(
+            "Initing trainer process group by trainer #%s with ranks %s",
+            rank,
+            TRAINER_RANKS,
+        )
+        dist.init_process_group(
+            backend="gloo",
+            init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
+            world_size=self.world_size,
+            rank=self.rank,
+        )
+
+        gLogger.info("Waiting for shutdown signal on trainer #%s...", rank)
+
+        global shutdown_signal
+        with shutdown_signal:
+            shutdown_signal.wait()
+        gLogger.info("Exiting the trainer #%s...", rank)
+        dist.destroy_process_group()
+
+    def _master_process(self, ddp_mode: DdpMode, simulate_uneven_inputs: bool):
+        gLogger.info("Running the master process...")
+        dist.init_process_group(
+            backend="gloo",
+            init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
+            world_size=self.world_size,
+            rank=self.rank,
+        )
+
+        remote_em_rref = rpc.remote(
+            self.remote_worker_name(), RemoteEM, args=(NUM_EM_ROW, D_SPARSE)
+        )
+        remote_net_rref = rpc.remote(
+            self.remote_worker_name(), RemoteNet, args=(D_DENSE + D_SPARSE, D_HID)
+        )
+        gLogger.info("Created remote rrefs on master")
+        self.do_test_on_master(
+            ddp_mode, simulate_uneven_inputs, remote_em_rref, remote_net_rref
+        )
+
+    def do_test_on_master(
+        self,
+        ddp_mode: DdpMode,
+        simulate_uneven_inputs: bool,
+        remote_em_rref: rpc.RRef,
+        remote_net_rref: rpc.RRef,
+    ):
+        if simulate_uneven_inputs:
+            gLogger.info(
+                "Running DDP + RPC test with simulating uneven inputs across trainers."
+            )
+
+        trainer_rrefs = []
+        for rank in TRAINER_RANKS:
+            trainer = self.trainer_name(rank)
+            trainer_rrefs.append(
+                rpc.remote(
+                    trainer,
+                    Trainer,
+                    args=(remote_em_rref, remote_net_rref, ddp_mode, rank),
+                )
+            )
+
+        if ddp_mode in (DdpMode.INSIDE, DdpMode.OUTSIDE):
+            # new_group needs to be called on ranks.
+            dist.new_group(TRAINER_RANKS)
+
+        training_examples = get_training_examples()
+        for _ in range(3):
+            futures = []
+            num_trainers = len(trainer_rrefs)
+            for idx, trainer_rref in enumerate(trainer_rrefs):
+                # Half the trainers will deplete inputs earlier than the rest.
+                trainer_has_less_inputs = (
+                    simulate_uneven_inputs and idx < num_trainers // 2
+                )
+                futures.append(
+                    _remote_method_async(
+                        Trainer.train_batch,
+                        trainer_rref,
+                        training_examples[idx],
+                        trainer_has_less_inputs,
+                        simulate_uneven_inputs,
+                    )
+                )
+
+            for future in futures:
+                ddp_grads, non_ddp_grads = future.wait()
+                # When there are uneven inputs, it is not necessary that grads
+                # cancel each other out, since some trainers contribute 0 grad.
+                if not simulate_uneven_inputs:
+                    for grad in ddp_grads:
+                        self.assertEqual(
+                            grad,
+                            torch.zeros_like(grad),
+                            msg=f"The grad for any ddp parameter should be zeros, because "
+                            "the training examples' grads cancel each other. Received "
+                            f"gradient {grad}",
+                        )
+                for grad in non_ddp_grads:
+                    self.assertNotEqual(
+                        grad,
+                        torch.zeros_like(grad),
+                        msg="The grad for any non-ddp parameter shouldn't be zeros",
+                    )
+
+        # Destroy process groups
+        for trainer_rref in trainer_rrefs:
+            _remote_method_async(Trainer.destroy_pg, trainer_rref).wait()
+
+        # Send shutdown signals.
+        for rank in TRAINER_RANKS:
+            trainer = self.trainer_name(rank)
+            rpc.rpc_sync(trainer, set_shutdown_signal, args=())
+
+        rpc.rpc_sync(self.remote_worker_name(), set_shutdown_signal, args=())
+
+    def _do_test(self, ddp_mode, simulate_uneven_inputs=False):
+        if self.rank == MASTER_RANK:
+            self._master_process(ddp_mode, simulate_uneven_inputs)
+        elif self.rank == REMOTE_WORKER_RANK:
+            self._remote_worker_process(ddp_mode)
+        elif self.rank in TRAINER_RANKS:
+            self._trainer_process(self.rank)
+        else:
+            raise RuntimeError(f"Unknown process rank: {self.rank}")
+
+    @requires_gloo()
+    @dist_init
+    def test_backward_no_ddp(self):
+        self._do_test(DdpMode.NONE)
+
+    @requires_gloo()
+    @dist_init
+    def test_backward_ddp_outside(self):
+        self._do_test(DdpMode.OUTSIDE)
+
+    @requires_gloo()
+    @dist_init
+    def test_backward_ddp_outside_uneven_inputs(self):
+        self._do_test(DdpMode.OUTSIDE, simulate_uneven_inputs=True)
+
+    @requires_gloo()
+    @dist_init
+    def test_backward_ddp_inside(self):
+        self._do_test(DdpMode.INSIDE)
+
+
+# Common utils for both CPU and CUDA test suites
+class CommonDdpComparisonTest(RpcAgentTestFixture):
+    @property
+    def world_size(self) -> int:
+        return NUM_TRAINERS
+
+    def trainer_name(self, rank):
+        # The name has to be consistent with that in 'dist_init' decorator.
+        return f"worker{rank}"
+
+    @staticmethod
+    def get_remote_grads(rref, context_id):
+        return dist_autograd.get_gradients(context_id)[rref.local_value().weight]
+
+
+class DdpComparisonTest(CommonDdpComparisonTest):
+    def _run_test_ddp_comparision(self, simulate_uneven_inputs=False):
+        gLogger.info("Running trainer rank: %s", self.rank)
+        # Each trainer uses a different random seed. Otherwise, they are going
+        # to have exactly the same initial model parameters, input, and
+        # therefore grads. That means the grads will be the same before and
+        # after DDP's all-reduce.
+        torch.manual_seed(self.rank)
+        dist.init_process_group(
+            backend="gloo",
+            # Postfix file_name with "pg" since file_name is also used by RPC agent
+            init_method=INIT_METHOD_TEMPLATE.format(file_name=f"{self.file_name}_pg"),
+            world_size=self.world_size,
+            rank=self.rank,
+        )
+        net = nn.Linear(2, 3)
+        ddp_net = DistributedDataParallel(net)
+
+        # Odd ranks join early if simulate_uneven_inputs.
+        num_inputs = 1
+        if simulate_uneven_inputs:
+            if self.rank % 2 == 0:
+                num_inputs += 2
+        inputs_list = [torch.rand((3, 2)) for _ in range(num_inputs)]
+
+        if simulate_uneven_inputs:
+            gLogger.info(
+                "Rank %s training with %s inputs.", self.rank, len(inputs_list)
+            )
+
+        # Use distributed autograd. The gradients will be in RPC context map.
+        grads_dict = {}
+        with ddp_net.join(simulate_uneven_inputs):
+            for i, inputs in enumerate(inputs_list):
+                with dist_autograd.context() as context_id:
+                    loss = ddp_net(inputs).norm()
+                    dist_autograd.backward(context_id, [loss])
+                    grads_dict = dist_autograd.get_gradients(context_id)
+                gLogger.info("Trainer #%s got grad dict: %s", self.rank, grads_dict)
+
+                # Use local autograd. The gradients will be in each variable's '.grad'.
+                ddp_net.zero_grad()
+                loss = ddp_net(inputs).norm()
+                loss.backward()
+
+                # The gradients should be the same
+                for param in net.parameters():
+                    self.assertTrue(
+                        param in grads_dict,
+                        msg=f"Param {param} is not in dist_auto grad dict {grads_dict} for iteration {i}",
+                    )
+                    self.assertEqual(
+                        grads_dict[param],
+                        param.grad,
+                        msg=f"The grads for param {param} are different under local "
+                        f"and dist autograd: {param.grad} \n---\n {grads_dict[param]} for iteration {i}",
+                    )
+        dist.destroy_process_group()
+
+    @requires_gloo()
+    @dist_init
+    def test_ddp_comparison(self):
+        self._run_test_ddp_comparision()
+
+    @requires_gloo()
+    @dist_init
+    def test_ddp_comparison_uneven_inputs(self):
+        # test with simulating uneven inputs in DDP
+        self._run_test_ddp_comparision(simulate_uneven_inputs=True)
+
+    @requires_gloo()
+    @dist_init
+    def test_ddp_dist_autograd_sparse_grads(self):
+        # Each trainer uses a different random seed. Otherwise, they are going
+        # to have exactly the same initial model parameters, input, and
+        # therefore grads. That means the grads will be the same before and
+        # after DDP's all-reduce.
+        torch.manual_seed(self.rank)
+        dist.init_process_group(
+            backend="gloo",
+            init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
+            world_size=self.world_size,
+            rank=self.rank,
+        )
+
+        model = nn.EmbeddingBag(10, 3, sparse=True)
+        ddp_model = DistributedDataParallel(model)
+
+        # Different inputs for each
+        input = torch.LongTensor(10).random_(0, 10)
+        offsets = torch.LongTensor([0, 4])
+
+        # Run local.
+        loss = ddp_model(input, offsets).sum()
+        loss.backward()
+
+        with dist_autograd.context() as context_id:
+            loss = ddp_model(input, offsets).sum()
+            dist_autograd.backward(context_id, [loss])
+            grads_dict = dist_autograd.get_gradients(context_id)
+            self.assertEqual(1, len(grads_dict))
+            self.assertEqual(model.weight.grad, grads_dict[model.weight])
+
+    @requires_gloo()
+    @dist_init
+    def test_ddp_dist_autograd_local_vs_remote(self):
+        # Each trainer uses a different random seed. Otherwise, they are going
+        # to have exactly the same initial model parameters, input, and
+        # therefore grads. That means the grads will be the same before and
+        # after DDP's all-reduce.
+        torch.manual_seed(self.rank)
+        dist.init_process_group(
+            backend="gloo",
+            init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
+            world_size=self.world_size,
+            rank=self.rank,
+        )
+
+        # Use two different remote device input string, w/ and w/o the default
+        # device string "cpu", respectively.
+        for remote_device in ["worker0/cpu", "worker0"]:
+            remote_layer1 = RemoteModule(
+                remote_device=remote_device, module_cls=nn.Linear, args=(10, 5, False)
+            )
+            layer1 = nn.Linear(10, 5, False)
+            # Start with the same parameters for remote and local
+            layer1.weight = remote_layer1.module_rref.to_here().weight
+
+            # Run local case.
+            layer2 = nn.Linear(5, 1)
+            inputs = torch.rand((10, 10))
+            ddp_model = DistributedDataParallel(layer2)
+            loss = ddp_model(layer1(inputs)).sum()
+            loss.backward()
+
+            # Run remote case.
+            with dist_autograd.context() as context_id:
+                loss = ddp_model(remote_layer1(inputs)).sum()
+                dist_autograd.backward(context_id, [loss])
+                grads_dict = dist_autograd.get_gradients(context_id)
+                dist.barrier()
+                self.assertEqual(layer2.weight.grad, grads_dict[layer2.weight])
+                self.assertEqual(
+                    layer1.weight.grad,
+                    rpc.rpc_sync(
+                        "worker0",
+                        CommonDdpComparisonTest.get_remote_grads,
+                        args=(remote_layer1.module_rref, context_id),
+                    ),
+                )
+
+
+class CudaDdpComparisonTest(CommonDdpComparisonTest):
+    @skip_if_lt_x_gpu(NUM_TRAINERS)
+    @requires_nccl()
+    @dist_init
+    @skip_if_rocm_multiprocess
+    def test_ddp_dist_autograd_local_vs_remote_gpu(self):
+        # Each trainer uses a different random seed. Otherwise, they are going
+        # to have exactly the same initial model parameters, input, and
+        # therefore grads. That means the grads will be the same before and
+        # after DDP's all-reduce.
+        torch.manual_seed(self.rank)
+        dist.init_process_group(
+            backend="gloo",
+            init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
+            world_size=self.world_size,
+            rank=self.rank,
+        )
+
+        remote_layer1 = RemoteModule(
+            remote_device="worker0/cpu", module_cls=nn.Linear, args=(10, 7, False)
+        )
+        layer1 = nn.Linear(10, 7, False)
+        # Start with the same parameters for remote and local
+        layer1.weight = remote_layer1.module_rref.to_here().weight
+
+        layer2 = nn.Linear(7, 5).cuda(self.rank)
+        ddp_layer2 = DistributedDataParallel(layer2, device_ids=[self.rank])
+
+        remote_layer3 = RemoteModule(
+            remote_device="worker0/cpu", module_cls=nn.Linear, args=(5, 3, False)
+        )
+        layer3 = nn.Linear(5, 3, False)
+        # Start with the same parameters for remote and local
+        layer3.weight = remote_layer3.module_rref.to_here().weight
+
+        layer4 = nn.Linear(3, 1).cuda(self.rank)
+        ddp_layer4 = DistributedDataParallel(layer4, device_ids=[self.rank])
+
+        # Run local case.
+        inputs = torch.rand((10, 10))
+        loss = ddp_layer4(
+            layer3(ddp_layer2(layer1(inputs).cuda(self.rank)).cpu()).cuda(self.rank)
+        ).sum()
+        loss.backward()
+
+        # Run remote case.
+        with dist_autograd.context() as context_id:
+            loss = ddp_layer4(
+                remote_layer3(
+                    ddp_layer2(remote_layer1(inputs).cuda(self.rank)).cpu()
+                ).cuda(self.rank)
+            ).sum()
+            dist_autograd.backward(context_id, [loss])
+            grads_dict = dist_autograd.get_gradients(context_id)
+            dist.barrier()
+            self.assertEqual(
+                layer1.weight.grad,
+                rpc.rpc_sync(
+                    "worker0",
+                    CommonDdpComparisonTest.get_remote_grads,
+                    args=(remote_layer1.module_rref, context_id),
+                ),
+            )
+            self.assertEqual(layer2.weight.grad, grads_dict[layer2.weight])
+            self.assertEqual(
+                layer3.weight.grad,
+                rpc.rpc_sync(
+                    "worker0",
+                    CommonDdpComparisonTest.get_remote_grads,
+                    args=(remote_layer3.module_rref, context_id),
+                ),
+            )
+            self.assertEqual(layer4.weight.grad, grads_dict[layer4.weight])
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/distributed_test.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/distributed_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..45bd2d1035b1b190520d712666d7d449adc25664
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/distributed_test.py
@@ -0,0 +1,10420 @@
+# mypy: allow-untyped-defs
+
+import copy
+import itertools
+import json
+import math
+import operator
+import os
+import random
+import re
+import sys
+import tempfile
+import time
+import unittest
+from collections import defaultdict, namedtuple, OrderedDict
+from collections.abc import Callable
+from contextlib import contextmanager, nullcontext
+from dataclasses import dataclass
+from datetime import timedelta
+from functools import reduce
+from typing import Any, NamedTuple, Union
+
+import numpy as np
+
+import torch
+import torch.cuda
+import torch.distributed as dist
+import torch.distributed.algorithms.model_averaging.averagers as averagers
+import torch.distributed.algorithms.model_averaging.hierarchical_model_averager as hierarchicalSGD
+import torch.distributed.algorithms.model_averaging.utils as model_averaging_utils
+import torch.distributed.optim.post_localSGD_optimizer as post_localSGD_optimizer
+import torch.nn as nn
+import torch.nn.functional as F
+from torch._utils_internal import (
+    TEST_MASTER_ADDR as MASTER_ADDR,
+    TEST_MASTER_PORT as MASTER_PORT,
+)
+from torch.autograd import DeviceType
+from torch.cuda.amp import autocast, GradScaler
+from torch.distributed.algorithms.ddp_comm_hooks import (
+    default_hooks as default,
+    post_localSGD_hook as post_localSGD,
+    powerSGD_hook as powerSGD,
+    quantization as quantization_hooks,
+)
+from torch.distributed.distributed_c10d import (
+    _get_default_group,
+    _get_pg_config,
+    get_world_size,
+)
+from torch.distributed.optim import _apply_optimizer_in_backward
+from torch.distributed.utils import (
+    _sync_module_states,
+    _verify_param_shape_across_processes,
+)
+from torch.nn.parallel import DistributedDataParallel
+from torch.nn.parallel.distributed import _dump_DDP_relevant_env_vars, _MixedPrecision
+from torch.profiler import ExecutionTraceObserver, ProfilerActivity
+from torch.testing._internal.common_distributed import (
+    captured_output,
+    cleanup_temp_dir,
+    DistTestCases,
+    init_multigpu_helper,
+    initialize_temp_directories,
+    MultiProcessTestCase,
+    nccl_skip_if_lt_x_gpu,
+    require_n_gpus_for_nccl_backend,
+    requires_nccl_version,
+    simple_sparse_reduce_tests,
+    skip_if_lt_x_gpu,
+    skip_if_no_gpu,
+    skip_if_odd_worldsize,
+    skip_if_rocm_multiprocess,
+    skip_if_small_worldsize,
+    TEST_SKIPS,
+    verify_ddp_error_logged,
+    with_dist_debug_levels,
+    with_nccl_blocking_wait,
+)
+from torch.testing._internal.common_utils import (
+    FILE_SCHEMA,
+    instantiate_parametrized_tests,
+    IS_FBCODE,
+    IS_MACOS,
+    IS_SANDCASTLE,
+    IS_WINDOWS,
+    MI200_ARCH,
+    skip_but_pass_in_sandcastle,
+    skip_but_pass_in_sandcastle_if,
+    skipIfRocm,
+    skipIfRocmArch,
+    TemporaryFileName,
+)
+from torch.utils._python_dispatch import TorchDispatchMode
+from torch.utils.data.distributed import DistributedSampler
+
+
+try:
+    import torchvision
+
+    HAS_TORCHVISION = True
+except Exception:  # Covering both ImportError and RuntimeError
+    HAS_TORCHVISION = False
+
+if sys.platform == "win32":
+    import msvcrt
+else:
+    import fcntl
+
+
+class NetWithBuffers(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.a = nn.Linear(10, 10, bias=False)
+        self.b = nn.Linear(10, 1, bias=False)
+        self.register_buffer("buffer", torch.randn(1, 2))
+
+    def forward(self, x):
+        self.buffer.add_(1)
+        return self.b(self.a(x))
+
+
+class Foo:
+    def __init__(self, x):
+        # Can be tensor or int
+        self.x = x
+
+    def __eq__(self, other):
+        def eq(value, other):
+            if isinstance(value, torch.Tensor):
+                return torch.equal(value, other)
+            return value == other
+
+        for attr, value in self.__dict__.items():
+            other_value = other.__dict__[attr]
+            if not eq(value, other_value):
+                return False
+        return True
+
+
+f = Foo(10)
+f.bar = 1
+
+
+# Defer instantiation until the seed is set so that randn() returns the same
+# values in all processes.
+def create_collectives_object_test_list():
+    return [
+        {"key1": 3, "key2": 4, "key3": {"nested": True}},
+        f,
+        Foo(torch.randn(3, 3)),
+        "foo",
+        [1, 2, True, "string", [4, 5, "nested"]],
+    ]
+
+
+# Allowlist of distributed backends where profiling collectives is supported.
+PROFILING_SUPPORTED_BACKENDS = [
+    dist.Backend.NCCL,
+    dist.Backend.GLOO,
+    dist.Backend.MPI,
+    dist.Backend.UCC,
+]
+
+# Allowlist of distributed backends where profiling is supported with use_cuda=True
+CUDA_PROFILING_SUPPORTED_BACKENDS = [
+    dist.Backend.GLOO,
+    dist.Backend.MPI,
+    dist.Backend.NCCL,
+    dist.Backend.UCC,
+]
+
+# Allowlist of distributed backends where profiling is supported for p2p ops
+SEND_RECV_PROFILING_SUPPORTED_BACKENDS = [
+    dist.Backend.MPI,
+    dist.Backend.GLOO,
+    dist.Backend.NCCL,
+    dist.Backend.UCC,
+]
+
+# Dummy NamedTuple data structures to test DDP support for NamedTuple types.
+EXPECTED_FIELDS = ("a", "b")
+TestNamedTupleInput_0 = namedtuple("NamedTuple", EXPECTED_FIELDS)
+
+
+class TestNamedTupleInput_1(NamedTuple):
+    a: torch.tensor
+    b: torch.tensor
+
+
+skipIfNoTorchVision = skip_but_pass_in_sandcastle_if(
+    not HAS_TORCHVISION, "no torchvision"
+)
+
+BACKEND = os.environ["BACKEND"]
+INIT_METHOD = os.getenv("INIT_METHOD", "env://")
+
+DEFAULT_TIMEOUT = 300
+CUSTOMIZED_TIMEOUT = {"test_DistributedDataParallel": 500}
+
+
+def get_profiling_event(event_name, profiler, dedup_gpu_user_annotation=False):
+    event_list = (
+        profiler.events()
+        if isinstance(profiler, torch.profiler.profile)
+        else profiler.function_events
+    )
+    return [
+        event
+        for event in event_list
+        if (
+            (event.name.endswith(event_name) or event.name.startswith(event_name))
+            and (not dedup_gpu_user_annotation or event.device_type != DeviceType.CUDA)
+        )
+    ]
+
+
+def get_profiler_nccl_meta(prof):
+    """Torch profiler includes nccl metadata in an inserted operator called "record_param_comms"
+    We will need to test metadata obtained from profiler here"""
+    with TemporaryFileName(mode="w+t", suffix=".json") as trace_file:
+        prof.export_chrome_trace(trace_file)
+        with open(trace_file) as f:
+            events = json.load(f)["traceEvents"]
+        print(f"Trace saved to {trace_file}")
+
+        return [e for e in events if e.get("name") == "record_param_comms"]
+
+
+# Base error message substring on unfinished reductions.
+ddp_prev_reduction_unfinished_str = (
+    "Expected to have finished reduction in the prior iteration"
+)
+# Error message substring when find_unused_parameters=True has not been passed
+ddp_recommend_find_unused_params_str = (
+    "passing the keyword argument `find_unused_parameters=True`"
+)
+# Error message substring when find_unused_parameters=True is enabled
+ddp_find_unused_params_enabled_str = "Since `find_unused_parameters=True` is enabled"
+# Error message substring for possibility of not all model outputs being used
+# in loss computation
+ddp_outputs_not_used_in_loss_str = (
+    "`forward` function outputs participate in calculating loss"
+)
+# Error message substring suggesting to use TORCH_DISTRIBUTED_DEBUG
+ddp_suggest_debug_mode_str = (
+    "set the environment variable TORCH_DISTRIBUTED_DEBUG to either INFO or DETAIL"
+)
+
+
+class DDPUnevenTestInput(NamedTuple):
+    name: str
+    model: nn.Module
+    inp: Union[torch.tensor, tuple]
+    sync_interval: int
+    throw_on_early_termination: bool = False
+    hook: Callable = None
+    state: Any = None
+
+
+class _FC2(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.fc = nn.Linear(10, 50, bias=True)
+        self.fc.bias.requires_grad = False
+
+    def forward(self, x):
+        x = self.fc(x)
+        return x
+
+
+class Net(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.fc1 = nn.Linear(2, 10, bias=False)
+        self.fc2 = _FC2()
+        self.fc3 = nn.Linear(50, 4, bias=False)
+        self.relu = nn.ReLU()
+        self.no_grad_param = nn.Parameter(
+            torch.tensor([2, 2]).long(), requires_grad=False
+        )
+
+    def forward(self, x):
+        x = self.relu(self.fc1(x))
+        x = self.relu(self.fc2(x))
+        x = self.fc3(x)
+        return F.softmax(x, dim=1)
+
+
+class LargeNet(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.fc1 = nn.Linear(1000, 2000, bias=False)
+        self.fc2 = nn.Linear(2000, 500, bias=False)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.fc2(x)
+        return x
+
+
+class Task(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.p = nn.Parameter(torch.ones(2, 2))
+
+    def forward(self, x):
+        return self.p + x
+
+
+class BatchNormNet(nn.Module):
+    def __init__(self, affine=True):
+        super().__init__()
+        self.fc1 = nn.Linear(2, 40, bias=False)
+        self.bn = nn.BatchNorm1d(4, affine=affine)
+        self.fc2 = nn.Linear(40, 4, bias=False)
+
+    def forward(self, x):
+        x = torch.reshape(self.fc1(x), (-1, 4, 10))
+        x = self.bn(x)
+        x = torch.reshape(x, (-1, 40))
+        x = self.fc2(x)
+        return F.softmax(x, dim=1)
+
+
+class UnusedParamTwoLinLayerNet(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.a = nn.Linear(10, 10, bias=False)
+        self.b = nn.Linear(10, 10, bias=False)
+        self.c = nn.Linear(5, 5, bias=False)
+
+    def forward(self, x):
+        a = self.a(x)
+        b = self.b(x)
+        return (a, b)
+
+
+class DictOutputModule(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.module = UnusedParamTwoLinLayerNet()
+
+    def forward(self, x):
+        predictions = self.module(x)
+        loss = (predictions[0] + predictions[1]).sum()
+        return {
+            "predictions": predictions,
+            "loss": loss,
+        }
+
+
+class TwoLinLayerNet(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.a = nn.Linear(10, 10, bias=False)
+        self.b = nn.Linear(10, 1, bias=False)
+
+    def forward(self, x):
+        a = self.a(x)
+        b = self.b(x)
+        return (a, b)
+
+
+class EmbeddingNetDifferentParams(nn.Module):
+    """
+    A module containing an embedding with different dimension or different # of
+    parameters depending on the rank.
+    """
+
+    def __init__(self, rank, diff_num_params=False):
+        super().__init__()
+        embedding_dim = 500 if diff_num_params or rank == 0 else 50
+        self.embedding = nn.Embedding(num_embeddings=10, embedding_dim=embedding_dim)
+        self.lin = nn.Linear(embedding_dim, 1)
+        if diff_num_params:
+            self.lin2 = nn.Linear(1, 1, bias=False)
+
+    def forward(self, x):
+        x = self.embedding(x)
+        return self.lin(x)
+
+
+class ControlFlowToyModel(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.lin1 = nn.Linear(10, 10, bias=False)
+        self.lin2 = nn.Linear(10, 10, bias=False)
+
+    def forward(self, x):
+        # Second layer is used dependent on input x.
+        use_second_layer = torch.equal(x, torch.ones(20, 10, device=x.device))
+        if use_second_layer:
+            return self.lin2(F.relu(self.lin1(x)))
+        else:
+            return F.relu(self.lin1(x))
+
+
+def get_timeout(test_id):
+    test_name = test_id.split(".")[-1]
+    if test_name in CUSTOMIZED_TIMEOUT:
+        return CUSTOMIZED_TIMEOUT[test_name]
+    else:
+        return DEFAULT_TIMEOUT
+
+
+default_pg_timeout = 60
+
+CUSTOM_PG_TIMEOUT = {
+    # This test runs slowly and needs additional time to complete, otherwise can
+    # be taken down by TORCH_NCCL_ASYNC_ERROR_HANDLING
+    "test_ddp_uneven_inputs": 300,
+    # This test has a short timeout since it tests being taken down by
+    # TORCH_NCCL_ASYNC_ERROR_HANDLING which we want to happen quickly.
+    "test_ddp_model_diff_across_ranks": 5,
+    # This test has a short timeout since it tests being taken down by
+    # TORCH_NCCL_ASYNC_ERROR_HANDLING which we want to happen quickly.
+    "test_ddp_has_finalized": 5,
+}
+
+
+def require_backend_is_available(backends):
+    def check(backend):
+        if backend == dist.Backend.GLOO:
+            return dist.is_gloo_available()
+        if backend == dist.Backend.NCCL:
+            return dist.is_nccl_available()
+        if backend == dist.Backend.MPI:
+            return dist.is_mpi_available()
+        if backend == dist.Backend.UCC:
+            return dist.is_ucc_available()
+        if backend in DistTestCases.backend_feature["plugin"]:
+            return True
+        return False
+
+    if BACKEND not in backends:
+        return skip_but_pass_in_sandcastle(
+            f"Test requires backend {BACKEND} to be one of {backends}"
+        )
+
+    if not check(dist.Backend(BACKEND)):
+        return skip_but_pass_in_sandcastle(
+            f"Test requires backend {BACKEND} to be available"
+        )
+    return lambda func: func
+
+
+def require_world_size(world_size):
+    if int(os.environ["WORLD_SIZE"]) < world_size:
+        return skip_but_pass_in_sandcastle(
+            f"Test requires world size of {world_size:d}"
+        )
+    return lambda func: func
+
+
+def require_exact_world_size(world_size):
+    if int(os.environ["WORLD_SIZE"]) != world_size:
+        return skip_but_pass_in_sandcastle(
+            f"Test requires an exact world size of {world_size:d}"
+        )
+    return lambda func: func
+
+
+@contextmanager
+def _lock():
+    TEMP_DIR = os.environ["TEMP_DIR"]
+    lockfile = os.path.join(TEMP_DIR, "lockfile")
+    with open(lockfile, "w") as lf:
+        try:
+            if sys.platform == "win32":
+                msvcrt.locking(lf.fileno(), msvcrt.LK_RLCK, 1)
+                yield
+            else:
+                fcntl.flock(lf.fileno(), fcntl.LOCK_EX)
+                yield
+        finally:
+            if sys.platform == "win32":
+                msvcrt.locking(lf.fileno(), msvcrt.LK_UNLCK, 1)
+            else:
+                fcntl.flock(lf.fileno(), fcntl.LOCK_UN)
+            lf.close()
+
+
+@contextmanager
+def _rank_temp_file():
+    if dist.get_rank() == 0:
+        fd, name = tempfile.mkstemp()
+        os.close(fd)
+    else:
+        name = None
+    object_list = [name]
+    dist.broadcast_object_list(object_list)
+    name = object_list[0]
+    try:
+        yield name
+    finally:
+        if dist.get_rank() == 0:
+            os.remove(name)
+
+
+def _build_tensor(size, value=None, dtype=torch.float, device_id=None):
+    if value is None:
+        value = size
+    if device_id is None:
+        return torch.empty(size, size, size, dtype=dtype).fill_(value)
+    else:
+        return torch.empty(size, size, size, dtype=dtype).fill_(value).cuda(device_id)
+
+
+def _build_multidim_tensor(dim, dim_size, value=None, dtype=torch.float):
+    if value is None:
+        value = dim
+    return torch.empty(size=[dim_size for _ in range(dim)], dtype=dtype).fill_(value)
+
+
+def _create_autograd_profiler():
+    return torch.autograd.profiler.profile(record_shapes=True)
+
+
+def _create_torch_profiler():
+    return torch.profiler.profile(
+        activities=[
+            torch.profiler.ProfilerActivity.CPU,
+        ],
+        record_shapes=True,
+    )
+
+
+class Barrier:
+    barrier_id = 0
+
+    @classmethod
+    def init(cls):
+        cls.barrier_id = 0
+        barrier_dir = os.path.join(os.environ["TEMP_DIR"], "barrier")
+        for f_name in os.listdir(barrier_dir):
+            os.unlink(os.path.join(barrier_dir, f_name))
+
+    @classmethod
+    def sync(cls, wait_for=None, timeout=10):
+        if wait_for is None:
+            wait_for = dist.get_world_size()
+        cls.barrier_id += 1
+        barrier_dir = os.path.join(os.environ["TEMP_DIR"], "barrier")
+        pid = str(os.getpid())
+        barrier_file = os.path.join(barrier_dir, pid)
+        with _lock():
+            with open(barrier_file, "w") as f:
+                f.write(str(cls.barrier_id))
+
+        start_time = time.time()
+        while True:
+            arrived = 0
+            with _lock():
+                for f_name in os.listdir(barrier_dir):
+                    with open(os.path.join(barrier_dir, f_name)) as f:
+                        data = f.read()
+                        if int(data) >= cls.barrier_id:
+                            arrived += 1
+            if arrived == wait_for:
+                break
+
+            if time.time() - start_time > timeout:
+                raise RuntimeError("barrier timeout")
+            time.sleep(0.1)
+
+
+class TestDistBackend(MultiProcessTestCase):
+    @classmethod
+    def setUpClass(cls):
+        os.environ["MASTER_ADDR"] = str(MASTER_ADDR)
+        # Not setting MASTER_PORT and get a random free port
+        super().setUpClass()
+
+    def setUp(self):
+        super().setUp()
+        # initialize temp directories
+        initialize_temp_directories()
+        # initialize Barrier
+        Barrier.init()
+        # Skip return code checking for following tests as they are expected to
+        # crash a process due to TORCH_NCCL_ASYNC_ERROR_HANDLING.
+        self.skip_return_code_checks = [self.test_ddp_has_finalized.__wrapped__]
+
+    def tearDown(self):
+        cleanup_temp_dir()
+        super().tearDown()
+
+    @property
+    def init_method(self):
+        return f"{FILE_SCHEMA}{self.file_name}"
+
+    @property
+    def destroy_pg_upon_exit(self) -> bool:
+        # Overriding base test class: do not auto destroy PG upon exit.
+        return False
+
+    @classmethod
+    def _run(cls, rank, test_name, file_name, pipe, **kwargs):
+        if BACKEND == "nccl" and not torch.cuda.is_available():
+            sys.exit(TEST_SKIPS["no_cuda"].exit_code)
+        self = cls(test_name)
+        self.rank = rank
+        self.file_name = file_name
+
+        if torch.cuda.is_available() and torch.cuda.device_count() < int(
+            self.world_size
+        ):
+            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
+        try:
+            pg_timeout_seconds = CUSTOM_PG_TIMEOUT.get(test_name, default_pg_timeout)
+            timeout = timedelta(seconds=pg_timeout_seconds)
+            dist.init_process_group(
+                init_method=self.init_method,
+                backend=BACKEND,
+                world_size=int(self.world_size),
+                rank=self.rank,
+                timeout=timeout,
+            )
+        except RuntimeError as e:
+            if "recompile" in e.args[0]:
+                sys.exit(TEST_SKIPS["backend_unavailable"].exit_code)
+
+            raise
+
+        # Execute barrier prior to running test to ensure that every process
+        # has finished initialization and that the following test
+        # immediately exiting due to a skip doesn't cause flakiness.
+        self._barrier()
+
+        self.run_test(test_name, pipe)
+        self._barrier()
+        dist.destroy_process_group()
+        sys.exit(0)
+
+    # Needed since MultiProcessTestCase assumes a world_size of 4, but we
+    # run these tests under other various world_sizes.
+    @property
+    def world_size(self):
+        return os.environ["WORLD_SIZE"]
+
+
+class DistributedTest:
+    class _DistTestBase:
+        def _barrier(self, *args, **kwargs):
+            Barrier.sync(*args, **kwargs)
+
+        def _init_group_test(self, **kwargs):
+            group = [1, 2]
+            group_id = dist.new_group(group, **kwargs)
+            rank = dist.get_rank()
+            if rank not in group:
+                return ([], None, rank)
+
+            return (group, group_id, rank)
+
+        def _init_full_group_test(self, **kwargs):
+            group = list(range(dist.get_world_size()))
+            group_id = dist.new_group(**kwargs)
+            rank = dist.get_rank()
+            return (group, group_id, rank)
+
+        def _init_global_test(self):
+            group = list(range(dist.get_world_size()))
+            group_id = dist.group.WORLD
+            rank = dist.get_rank()
+            return (group, group_id, rank)
+
+        def _verify_buffers_equal(self, m1, m2):
+            # verify buffers across models
+            m1_buf_dict = dict(m1.module.named_buffers())
+            for name, buf in m2.module.named_buffers():
+                self.assertEqual(buf, m1_buf_dict[name])
+
+            # Verify buffers across ranks.
+            m1_buffers = list(m1.buffers())
+            m2_buffers = list(m2.buffers())
+            for buf1, buf2 in zip(m1_buffers, m2_buffers, strict=True):
+                gathered_bufs = [
+                    torch.empty_like(buf1) for _ in range(dist.get_world_size())
+                ]
+                dist.all_gather(gathered_bufs, buf1)
+                gathered_bufs_m2 = [
+                    torch.empty_like(buf2) for _ in range(dist.get_world_size())
+                ]
+                for b in gathered_bufs:
+                    self.assertEqual(b, buf1)
+                dist.all_gather(gathered_bufs_m2, buf2)
+                for b in gathered_bufs_m2:
+                    self.assertEqual(b, buf2)
+
+        def _sanity_check_profiler_nccl_meta(self, nccl_meta_events):
+            """Torch profiler includes nccl metadata in an inserted operator called "record_param_comms"
+            We test for basic fields in this profiler event that correspond to the nccl communication
+            collectives"""
+            per_coll_meta = defaultdict(list)
+            for e in nccl_meta_events:
+                args = e.get("args", {})
+                collname = args.get("Collective name", "")
+                self.assertNotEqual(collname, "")
+                self.assertNotEqual(args.get("dtype", ""), "")
+
+                per_coll_meta[collname].append(args)
+                if collname == "wait":
+                    continue
+
+                self.assertEqual(args["Process Group Description"], "default_pg")
+                self.assertNotEqual(args["Process Group Ranks"], "")
+
+                self.assertGreaterEqual(args.get("In msg nelems", -1), 0)
+                self.assertGreaterEqual(args.get("Out msg nelems", -1), 0)
+                self.assertGreaterEqual(args.get("Group size", -1), 0)
+                self.assertGreaterEqual(args.get("Global rank start", -1), 0)
+                self.assertGreaterEqual(args.get("Global rank stride", -1), 0)
+
+            # print(per_coll_meta)
+            return per_coll_meta
+
+        def test_dump_DDP_relevant_env_vars(self):
+            with captured_output() as (out, _):
+                _dump_DDP_relevant_env_vars()
+                lines = out.getvalue().splitlines()
+
+            def format_line(var):
+                return f"env:{var}={os.environ.get(var, 'N/A')}"
+
+            # Check relevant env vars
+            vars = [
+                "MASTER_ADDR",
+                "MASTER_PORT",
+                "WORLD_SIZE",
+                "NCCL_TOPO_DUMP_FILE",  # N/A
+                "TORCH_NCCL_ASYNC_ERROR_HANDLING",
+            ]
+            for var in vars:
+                line = format_line(var)
+                self.assertIn(line, lines)
+            # Check irrelevant env vars
+            vars = [
+                "xxx",
+                "yyy",
+                "zzz",
+            ]
+            for var in vars:
+                line = format_line(var)
+                self.assertNotIn(line, lines)
+
+        # GET RANK
+        def test_get_rank(self):
+            test_dir = os.path.join(os.environ["TEMP_DIR"], "test_dir")
+            pid = str(os.getpid())
+            num_processes = dist.get_world_size()
+            with open(os.path.join(test_dir, pid), "w") as f:
+                f.write(str(dist.get_rank()))
+
+            self._barrier()
+
+            all_ranks = set()
+            for f_name in os.listdir(test_dir):
+                with open(os.path.join(test_dir, f_name)) as f:
+                    all_ranks.add(int(f.read()))
+            self.assertEqual(len(all_ranks), num_processes)
+
+            self._barrier()
+
+            if dist.get_rank() == 0:
+                for f_name in os.listdir(test_dir):
+                    os.unlink(os.path.join(test_dir, f_name))
+
+            self._barrier()
+
+        def test_get_backend(self):
+            if dist.get_world_size() > 2:
+                group = [1, 2]
+            else:
+                group = [0, 1]
+            group_id = dist.new_group(group)
+            backend_str = BACKEND.lower()
+            self.assertEqual(dist.get_backend(), backend_str)
+            if dist.get_rank() in group:
+                self.assertEqual(dist.get_backend(group_id), backend_str)
+            else:
+                with self.assertRaisesRegex(
+                    ValueError, "Invalid process group specified"
+                ):
+                    dist.get_backend(group_id)
+
+        def test_Backend_enum_class(self):
+            # test parsing
+            backend = BACKEND.lower()
+            self.assertEqual(dist.Backend(BACKEND.upper()), backend)
+            self.assertEqual(dist.Backend(BACKEND), backend)
+            with self.assertRaises(ValueError):
+                dist.Backend(None)
+            with self.assertRaises(ValueError):
+                dist.Backend(3)
+            with self.assertRaises(ValueError):
+                dist.Backend(["gloo"])
+
+        # Test destroy
+        def test_destroy_group(self):
+            if dist.get_world_size() > 2:
+                group = [1, 2]
+            else:
+                group = [0, 1]
+            group_id = dist.new_group(group)
+            self._barrier()
+            dist.destroy_process_group(group_id)
+
+        # Test get rank and size of group
+        def test_get_rank_size_group(self):
+            if dist.get_world_size() > 2:
+                group = [1, 2]
+            else:
+                group = [0, 1]
+            group_id = dist.new_group(group)
+            if dist.get_rank() in group:
+                self.assertEqual(dist.get_world_size(group_id), 2)
+                self.assertTrue(dist.get_rank(group_id) in list(range(2)))
+            else:
+                self.assertEqual(dist.get_world_size(group_id), -1)
+                self.assertEqual(dist.get_rank(group_id), -1)
+
+        # Test destroy full groups
+        def test_destroy_full_group(self):
+            _, group_id, _ = self._init_full_group_test()
+            self._barrier()
+            dist.destroy_process_group(group_id)
+
+        # Test get rank and size of full group
+        def test_get_rank_size_full_group(self):
+            _, group_id, _ = self._init_full_group_test()
+            self.assertEqual(dist.get_world_size(group_id), dist.get_world_size())
+            self.assertEqual(dist.get_rank(group_id), dist.get_rank())
+
+        def _test_barrier_timeout(self, group_id, timeout):
+            local_rank = dist.get_rank(group_id)
+
+            # Only execute barrier on rank == 0, causing it to timeout
+            if local_rank == 0:
+                expected_time = time.time() + timeout.total_seconds()
+                # In debug mode, we execute a monitored_barrier before the
+                # collective, so assert on that.
+                if dist.get_debug_level() == dist.DebugLevel.DETAIL:
+                    exception_ctx = self.assertRaisesRegex(
+                        Exception, "failed to pass monitoredBarrier"
+                    )
+                else:
+                    exception_ctx = self.assertRaisesRegex(
+                        Exception, " (Timed out|closed|timeout) "
+                    )
+                with exception_ctx:
+                    dist.barrier(group_id)
+                self.assertGreaterAlmostEqual(time.time(), expected_time, delta=0.1)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "gloo", "Only gloo backend supports timeouts"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            not INIT_METHOD.startswith("file://"),
+            "Requires file:// initialization method. "
+            + "Both tcp:// and env:// rely on the TCP store for which "
+            "reinitialization has proven racy.",
+        )
+        def test_barrier_timeout_global(self):
+            dist.destroy_process_group()
+
+            # Explicitly pass world size to the barrier because we've
+            # just destroyed any state in torch.distributed.
+            self._barrier(wait_for=int(os.environ["WORLD_SIZE"]))
+
+            # Reinitialize global process group
+            timeout = timedelta(seconds=1)
+            dist.init_process_group(
+                init_method=INIT_METHOD,
+                backend=BACKEND,
+                world_size=int(os.environ["WORLD_SIZE"]),
+                rank=self.rank,
+                timeout=timeout,
+            )
+            self._test_barrier_timeout(dist.group.WORLD, timeout)
+
+        @skip_if_small_worldsize
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "gloo", "Only gloo backend supports timeouts"
+        )
+        def test_barrier_timeout_group(self):
+            timeout = timedelta(seconds=5)
+            _, group_id, _ = self._init_group_test(timeout=timeout)
+            if group_id is not None:
+                self._test_barrier_timeout(group_id, timeout)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "gloo", "Only gloo backend supports timeouts"
+        )
+        def test_barrier_timeout_full_group(self):
+            timeout = timedelta(seconds=1)
+            _, group_id, _ = self._init_full_group_test(timeout=timeout)
+            if group_id is not None:
+                self._test_barrier_timeout(group_id, timeout)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["subgroup"],
+            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
+        )
+        @require_world_size(4)
+        @skip_if_lt_x_gpu(2)
+        def test_new_subgroups(self):
+            subgroup_size = 2
+            cur_subgroup, subgroups = dist.new_subgroups(subgroup_size)
+
+            world_size = dist.get_world_size()
+            self.assertEqual(cur_subgroup.size(), subgroup_size)
+            self.assertEqual(len(subgroups), world_size / subgroup_size)
+            self.assertFalse(dist._rank_not_in_group(cur_subgroup))
+
+            for subgroup in subgroups:
+                dist.destroy_process_group(subgroup)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["subgroup"],
+            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
+        )
+        @require_exact_world_size(4)
+        def test_new_subgroups_with_group_param(self):
+            # Initialize global test environment
+            self._init_global_test()
+            # Set up GPU devices for each rank
+            init_multigpu_helper(dist.get_world_size(), BACKEND)
+            # Create two subgroups: one with ranks [0,2] and another with ranks [1,3]
+            cur_subgroup, subgroups = dist.new_subgroups_by_enumeration(
+                ranks_per_subgroup_list=[[0, 2], [1, 3]]
+            )
+
+            # Further divide the current subgroup into sub-subgroups of size 1
+            cur_sub_subgroup, sub_subgroups = dist.new_subgroups(
+                group_size=1, group=cur_subgroup
+            )
+            # Verify we have 2 sub-subgroups (one for each rank in the original subgroup)
+            self.assertEqual(len(sub_subgroups), 2)
+            # Verify the current process's sub-subgroup has size 1
+            self.assertEqual(cur_sub_subgroup.size(), 1)
+            # Verify the current process is in its assigned sub-subgroup
+            self.assertFalse(dist._rank_not_in_group(group=cur_sub_subgroup))
+
+            # Clean up by destroying all created process groups
+            for sub_subgroup in sub_subgroups:
+                dist.destroy_process_group(sub_subgroup)
+
+            for subgroup in subgroups:
+                dist.destroy_process_group(subgroup)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["subgroup"],
+            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
+        )
+        @skip_if_no_gpu
+        def test_new_subgroups_group_size_exceeds_world_size(self):
+            with self.assertRaisesRegex(ValueError, "must not exceed"):
+                dist.new_subgroups(100)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["subgroup"],
+            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
+        )
+        @require_world_size(4)
+        @skip_if_lt_x_gpu(4)
+        def test_new_subgroups_world_size_not_divisible_by_group_size(self):
+            expected_msg = f"The world size ({dist.get_world_size()}) must be divisible by 'group_size=3'"
+            with self.assertRaisesRegex(
+                ValueError,
+                re.escape(expected_msg),
+            ):
+                dist.new_subgroups(3)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["subgroup"],
+            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
+        )
+        @require_world_size(4)
+        @skip_if_lt_x_gpu(4)
+        def test_new_subgroups_by_enumeration(self):
+            _group, _group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            device_id = rank_to_GPU[rank][0]
+            cur_subgroup, subgroups = dist.new_subgroups_by_enumeration(
+                ranks_per_subgroup_list=[[0, 2], [1, 3]]
+            )
+            if device_id >= 4:
+                self.assertIsNone(cur_subgroup)
+            else:
+                self.assertEqual(cur_subgroup.size(), 2)
+                self.assertEqual(len(subgroups), 2)
+                if device_id == 0 or device_id == 2:
+                    self.assertEqual(cur_subgroup, subgroups[0])
+                else:
+                    self.assertEqual(cur_subgroup, subgroups[1])
+
+            for subgroup in subgroups:
+                dist.destroy_process_group(subgroup)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["subgroup"],
+            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
+        )
+        @require_world_size(4)
+        @skip_if_lt_x_gpu(4)
+        def test_new_subgroups_by_enumeration_input_rank_exceeds_world_size(self):
+            _group, group_id, _rank = self._init_global_test()
+            init_multigpu_helper(dist.get_world_size(), BACKEND)
+            world_size = get_world_size(group_id)
+
+            with self.assertRaisesRegex(
+                ValueError,
+                "The new group's rank should be within the world_size set by init_process_group",
+            ):
+                dist.new_subgroups_by_enumeration(
+                    ranks_per_subgroup_list=[[0, 1], [world_size, 2]]
+                )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["subgroup"],
+            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
+        )
+        @skip_if_no_gpu
+        def test_new_subgroups_by_enumeration_negative_input_rank(self):
+            self._init_global_test()
+
+            with self.assertRaisesRegex(
+                ValueError,
+                "The new group's rank should be within the world_size set by init_process_group",
+            ):
+                dist.new_subgroups_by_enumeration(
+                    ranks_per_subgroup_list=[[-1, -2], [-3, -4]]
+                )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["subgroup"],
+            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
+        )
+        @require_world_size(4)
+        @skip_if_lt_x_gpu(4)
+        def test_new_subgroups_overlap_not_allowed(self):
+            with self.assertRaisesRegex(
+                ValueError, "Rank 1 has appeared in both subgroup"
+            ):
+                dist.new_subgroups_by_enumeration(
+                    ranks_per_subgroup_list=[[0], [1, 2], [1, 3]]
+                )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["subgroup"],
+            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
+        )
+        @skip_if_lt_x_gpu(2)
+        def test_average_parameters(self):
+            rank = dist.get_rank()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            device_id = rank_to_GPU[rank][0]
+
+            model = nn.Sequential(
+                nn.Conv2d(3, 3, kernel_size=3, padding=1),
+                nn.ReLU(),
+                nn.Linear(1, 5, bias=False),
+            ).cuda(device_id)
+            # Test global model averaging
+            for p in model.parameters():
+                p.data = torch.ones_like(p.data)
+            model_averaging_utils.average_parameters(
+                params=model.parameters(), process_group=None
+            )
+            # Every element will be the same as the input.
+            for p in model.parameters():
+                self.assertEqual(p.data, torch.ones_like(p.data))
+
+            # Test partial model averaging
+            for p in model.parameters():
+                p.data = torch.ones_like(p.data) * rank
+            group_nccl = dist.new_group(ranks=[0, 1], backend="nccl")
+            model_averaging_utils.average_parameters(
+                params=model.parameters(), process_group=group_nccl
+            )
+            if not dist._rank_not_in_group(group_nccl):
+                # Every element on device 0 or 1 should be the average of 0 and 1, i.e., 0.5.
+                for p in model.parameters():
+                    self.assertEqual(p.data, torch.ones_like(p.data) * 0.5)
+            else:
+                # Every element on device not in the subgroup should remain the same.
+                for p in model.parameters():
+                    self.assertEqual(p.data, torch.ones_like(p.data) * rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["subgroup"],
+            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
+        )
+        @skip_if_lt_x_gpu(2)
+        def test_periodic_model_averager(self):
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+            rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
+            device_id = rank_to_GPU[rank][0]
+
+            model = nn.Linear(1, 5, bias=False).cuda(device_id)
+            param = next(model.parameters())
+            tensor = torch.ones_like(param.data) * rank
+            expected_avg_tensor = (
+                torch.ones_like(param.data) * sum(range(world_size)) / world_size
+            )
+            period = 4
+            for warmup_steps in [12, 13, 14, 15]:
+                averager = averagers.PeriodicModelAverager(
+                    period=period, warmup_steps=warmup_steps
+                )
+                for step in range(20):
+                    # Reset the parameters at every step.
+                    param.data = copy.deepcopy(tensor)
+                    for params in model.parameters():
+                        # mock grad
+                        params.grad = torch.ones_like(param.data)
+                    averager.average_parameters(model.parameters())
+                    if step >= warmup_steps and (step - warmup_steps) % period == 0:
+                        self.assertEqual(param.data, expected_avg_tensor)
+                    else:
+                        # No model averaging, so the parameters are not updated.
+                        self.assertEqual(param.data, tensor)
+
+        @skip_if_lt_x_gpu(2)
+        def test_periodic_model_averager_param_group(self):
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+            rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
+            device_id = rank_to_GPU[rank][0]
+
+            model = nn.Linear(1, 5, bias=False).cuda(device_id)
+            param = next(model.parameters())
+            opt = torch.optim.SGD(model.parameters(), lr=0.1)
+
+            period = 4
+            for warmup_steps in [12, 13, 14, 15]:
+                averager = averagers.PeriodicModelAverager(
+                    period=period, warmup_steps=warmup_steps
+                )
+                for step in range(20):
+                    # Reset the parameters at every step.
+                    for param_group in opt.param_groups:
+                        for params in param_group["params"]:
+                            # mock grad
+                            params.grad = torch.ones_like(param.data) * rank
+                            params.data = torch.ones_like(param.data) * rank
+                    averager.average_parameters(opt.param_groups)
+                    if step >= warmup_steps and (step - warmup_steps) % period == 0:
+                        for param_group in opt.param_groups:
+                            for params in param_group["params"]:
+                                if params.grad is None:
+                                    continue
+                                self.assertEqual(
+                                    param.data,
+                                    torch.ones_like(param.data)
+                                    * sum(range(world_size))
+                                    / world_size,
+                                )
+                    else:
+                        # No model averaging, so the parameters are not updated.
+                        for param_group in opt.param_groups:
+                            for params in param_group["params"]:
+                                if params.grad is None:
+                                    continue
+                                self.assertEqual(
+                                    param.data, torch.ones_like(param.data) * rank
+                                )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["subgroup"],
+            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
+        )
+        @skip_if_lt_x_gpu(2)
+        def test_1_level_hierarchical_model_averager_equivalent_to_periodic_model_averager(
+            self,
+        ):
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+            rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
+            device_id = rank_to_GPU[rank][0]
+
+            model = nn.Linear(1, 5, bias=False).cuda(device_id)
+            param = next(model.parameters())
+            tensor = torch.ones_like(param.data) * rank
+            expected_avg_tensor = (
+                torch.ones_like(param.data) * sum(range(world_size)) / world_size
+            )
+            period = 4
+            for warmup_steps in [12, 13, 14, 15]:
+                averager = hierarchicalSGD.HierarchicalModelAverager(
+                    # Run the global averaging at a period of 4,
+                    # which is equivalent to the above periodic model averaging test case.
+                    period_group_size_dict=OrderedDict([(period, world_size)]),
+                    warmup_steps=warmup_steps,
+                )
+
+                averager = averagers.PeriodicModelAverager(
+                    period=period, warmup_steps=warmup_steps
+                )
+                for step in range(20):
+                    # Reset the parameters at every step.
+                    param.data = copy.deepcopy(tensor)
+                    for params in model.parameters():
+                        # mock grad
+                        params.grad = torch.ones_like(param.data)
+                    averager.average_parameters(model.parameters())
+                    if step >= warmup_steps and (step - warmup_steps) % period == 0:
+                        self.assertEqual(param.data, expected_avg_tensor)
+                    else:
+                        # No model averaging, so the parameters are not updated.
+                        self.assertEqual(param.data, tensor)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["subgroup"],
+            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
+        )
+        @require_exact_world_size(4)
+        @skip_if_lt_x_gpu(4)
+        def test_3_level_hierarchical_model_averager(self):
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+            rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
+            device_id = rank_to_GPU[rank][0]
+
+            model = nn.Linear(1, 5, bias=False).cuda(device_id)
+            param = next(model.parameters())
+            tensor = torch.ones_like(param.data) * rank
+            # Set up such a hierarchical model averaging as follows:
+            # after the first 10 warmup steps,
+            # run model averaging every 2 steps within each subgroup of size 2,
+            # run model averaging every 4 steps within each subgroup of size 3,
+            # and run the global model averaging every 8 steps.
+            # If there is a conflict in model averaging at a step, only run the highest-level model averaging.
+            warmup_steps = 10
+            subgroup_size1 = 2
+            subgroup_avg_period1 = 2
+            subgroup_size2 = 4
+            subgroup_avg_period2 = 4
+            global_avg_period = 8
+            period_group_size_dict = OrderedDict(
+                [
+                    (subgroup_avg_period1, subgroup_size1),
+                    (subgroup_avg_period2, subgroup_size2),
+                    (global_avg_period, world_size),
+                ]
+            )
+            averager = hierarchicalSGD.HierarchicalModelAverager(
+                period_group_size_dict=period_group_size_dict, warmup_steps=warmup_steps
+            )
+            self.assertEqual(dist.get_pg_count(), len(period_group_size_dict))
+
+            subgroup1 = averager.period_process_group_dict[subgroup_avg_period1]
+            subgroup2 = averager.period_process_group_dict[subgroup_avg_period2]
+            real_group_ranks_res1 = _get_pg_config(subgroup1)["ranks"]
+            real_group_ranks_res2 = _get_pg_config(subgroup2)["ranks"]
+
+            expect_group_ranks_res1 = (
+                rank // subgroup_size1 * subgroup_size1
+                + np.array(list(range(subgroup_size1)))
+            ).tolist()
+            expect_group_ranks_res2 = (
+                rank // subgroup_size2 * subgroup_size2
+                + np.array(list(range(subgroup_size2)))
+            ).tolist()
+            self.assertEqual(real_group_ranks_res1, expect_group_ranks_res1)
+            self.assertEqual(real_group_ranks_res2, expect_group_ranks_res2)
+
+            expected_avg_tensor_within_subgroup1 = (
+                torch.ones_like(param.data)
+                * sum(real_group_ranks_res1)
+                / subgroup_size1
+            )
+            expected_avg_tensor_within_subgroup2 = (
+                torch.ones_like(param.data)
+                * sum(real_group_ranks_res2)
+                / subgroup_size2
+            )
+            expected_global_avg_tensor = (
+                torch.ones_like(param.data) * sum(range(world_size)) / world_size
+            )
+            for step in range(25):
+                # Reset the parameters at every step.
+                param.data = copy.deepcopy(tensor)
+                for params in model.parameters():
+                    # mock grad
+                    params.grad = torch.ones_like(param.data)
+                averager.average_parameters(model.parameters())
+                if step == 16 or step == 24:
+                    # Run global model averaging when `step` can be divided by 8.
+                    self.assertEqual(param.data, expected_global_avg_tensor)
+                elif step == 12 or step == 20:
+                    # Run model averaging within subgroup when `step` can be divided by 4 but not by 8.
+                    self.assertEqual(param.data, expected_avg_tensor_within_subgroup2)
+                elif step == 10 or step == 14 or step == 18 or step == 22:
+                    # Run model averaging within subgroup when `step` can be divided by 2 but not by 4 or 8.
+                    self.assertEqual(param.data, expected_avg_tensor_within_subgroup1)
+                else:
+                    # No model averaging, so the parameters are not updated.
+                    self.assertEqual(param.data, tensor)
+
+        # Coalescing manager (sync mode)
+        @skip_if_no_gpu
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl" or IS_FBCODE or IS_SANDCASTLE,
+            "Coalescing manager currently tests with NCCL only; internal test flaky",
+        )
+        def test_coalescing_manager(self):
+            self._barrier()
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+            rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
+            device_id = rank_to_GPU[rank][0]
+            torch.cuda.set_device(device_id)
+            num_colls = 2
+            size_per_coll = 8
+            small_tensors = [
+                torch.ones(size_per_coll, device=device_id) for _ in range(num_colls)
+            ]
+
+            with dist._coalescing_manager():
+                for i in range(num_colls):
+                    dist.all_reduce(small_tensors[i])
+
+            big_tensor = torch.ones(num_colls * size_per_coll, device=device_id)
+            dist.all_reduce(big_tensor)
+
+            for i in range(num_colls):
+                self.assertEqual(
+                    small_tensors[i],
+                    big_tensor[i * size_per_coll : (i + 1) * size_per_coll],
+                )
+
+            self._barrier()
+
+        # Coalescing manager (async mode)
+        @skip_if_no_gpu
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl" or IS_FBCODE or IS_SANDCASTLE,
+            "Coalescing manager currently tests with NCCL only; internal test flaky",
+        )
+        def test_coalescing_manager_async(self):
+            self._barrier()
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+            rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
+            device_id = rank_to_GPU[rank][0]
+            torch.cuda.set_device(device_id)
+            num_colls = 2
+            size_per_coll = 8
+            small_tensors = [
+                torch.ones(size_per_coll, device=device_id) for _ in range(num_colls)
+            ]
+
+            with dist._coalescing_manager(async_ops=True) as cm:
+                for i in range(num_colls):
+                    dist.all_reduce(small_tensors[i])
+            cm.wait()
+
+            big_tensor = torch.ones(num_colls * size_per_coll, device=device_id)
+            dist.all_reduce(big_tensor)
+
+            for i in range(num_colls):
+                self.assertEqual(
+                    small_tensors[i],
+                    big_tensor[i * size_per_coll : (i + 1) * size_per_coll],
+                )
+
+            self._barrier()
+
+        # NCCL Batch SEND RECV
+        @skip_if_no_gpu
+        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
+        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
+        def test_batch_isend_irecv_nccl(self):
+            self._barrier()
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+            rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
+            device_id = rank_to_GPU[rank][0]
+            torch.cuda.set_device(device_id)
+            p2p_op_list = []
+            recv_tensors = [None for _ in range(world_size)]
+            expected_tensors = [None for _ in range(world_size)]
+
+            for val in ["1", "0"]:
+                os.environ["TORCH_NCCL_BLOCKING_WAIT"] = val
+                for src in range(world_size):
+                    send_tensor = _build_tensor(rank + 1, device_id=device_id).fill_(
+                        src
+                    )
+                    recv_tensors[src] = _build_tensor(
+                        src + 1, value=-1, device_id=device_id
+                    ).fill_(-1)
+                    expected_tensors[src] = _build_tensor(
+                        src + 1, value=-1, device_id=device_id
+                    ).fill_(rank)
+                    recv_op = dist.P2POp(dist.irecv, recv_tensors[src], src)
+                    p2p_op_list.append(recv_op)
+                    send_op = dist.P2POp(dist.isend, send_tensor, src)
+                    p2p_op_list.append(send_op)
+
+                reqs = dist.batch_isend_irecv(p2p_op_list)
+                for req in reqs:
+                    req.wait()
+
+                for src in range(world_size):
+                    self.assertEqual(recv_tensors[src], expected_tensors[src])
+
+            self._barrier()
+
+        @skip_if_no_gpu
+        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
+        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
+        def test_batch_isend_irecv_ring_exchange_nccl(self):
+            self._barrier()
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+            rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
+            device_id = rank_to_GPU[rank][0]
+            torch.cuda.set_device(device_id)
+
+            send_tensor = _build_tensor(world_size, device_id=device_id)
+            recv_tensor = _build_tensor(world_size, value=-1, device_id=device_id)
+            send_op = dist.P2POp(dist.isend, send_tensor, (rank + 1) % world_size)
+            recv_op = dist.P2POp(
+                dist.irecv, recv_tensor, (rank - 1 + world_size) % world_size
+            )
+            reqs = dist.batch_isend_irecv([send_op, recv_op])
+            for req in reqs:
+                req.wait()
+
+            self._barrier()
+
+        @skip_if_no_gpu
+        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
+        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
+        def test_batch_isend_irecv_self_nccl(self):
+            self._barrier()
+            # Ensure the process group has been fully initialized (needed by
+            # the first sub-group batch_isend_irecv call)
+            dist.barrier()
+            rank = dist.get_rank()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            device_id = rank_to_GPU[rank][0]
+            p2p_op_list = []
+
+            if rank == 0:
+                send_tensor = _build_tensor(rank + 1, device_id=device_id)
+                recv_tensor = _build_tensor(rank + 1, value=-1, device_id=device_id)
+                recv_op = dist.P2POp(dist.irecv, recv_tensor, 0)
+                p2p_op_list.append(recv_op)
+                send_op = dist.P2POp(dist.isend, send_tensor, 0)
+                p2p_op_list.append(send_op)
+
+                reqs = dist.batch_isend_irecv(p2p_op_list)
+                for req in reqs:
+                    req.wait()
+
+            self._barrier()
+
+        @skip_if_no_gpu
+        @skip_if_small_worldsize
+        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
+        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
+        def test_batch_isend_irecv_no_rank_zero_nccl(self):
+            self._barrier()
+            # Ensure the process group has been fully initialized (needed by
+            # the first sub-group batch_isend_irecv call)
+            dist.barrier()
+            rank = dist.get_rank()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            device_id = rank_to_GPU[rank][0]
+            torch.cuda.set_device(device_id)
+            p2p_op_list = []
+
+            if rank == 1:
+                peer = 2
+            elif rank == 2:
+                peer = 1
+
+            if rank in [1, 2]:
+                send_tensor = _build_tensor(rank + 1, device_id=device_id)
+                recv_tensor = _build_tensor(peer + 1, value=-1, device_id=device_id)
+                recv_op = dist.P2POp(dist.irecv, recv_tensor, peer)
+                p2p_op_list.append(recv_op)
+                send_op = dist.P2POp(dist.isend, send_tensor, peer)
+                p2p_op_list.append(send_op)
+
+                reqs = dist.batch_isend_irecv(p2p_op_list)
+                for req in reqs:
+                    req.wait()
+
+            self._barrier()
+
+        # GLOO Batch SEND RECV CPU
+        @skip_but_pass_in_sandcastle_if(BACKEND != "gloo", "GLOO Batch Send Recv CPU")
+        def test_batch_isend_irecv_gloo(self):
+            self._barrier()
+            rank = dist.get_rank()
+            p2p_op_list = []
+
+            for src in range(dist.get_world_size()):
+                if src == rank:
+                    continue
+                send_tensor = _build_tensor(rank + 1)
+                recv_tensor = _build_tensor(src + 1, value=-1)
+                recv_op = dist.P2POp(dist.irecv, recv_tensor, src)
+                p2p_op_list.append(recv_op)
+                send_op = dist.P2POp(dist.isend, send_tensor, src)
+                p2p_op_list.append(send_op)
+
+            reqs = dist.batch_isend_irecv(p2p_op_list)
+            for req in reqs:
+                req.wait()
+
+            self._barrier()
+
+        # GLOO Batch SEND RECV CPU with provided tags
+        @skip_but_pass_in_sandcastle_if(BACKEND != "gloo", "GLOO Batch Send Recv CPU")
+        def test_batch_isend_irecv_gloo_tags(self):
+            self._barrier()
+            rank = dist.get_rank()
+            p2p_op_list = []
+
+            for src in range(dist.get_world_size()):
+                if src == rank:
+                    continue
+                send_tensor = _build_tensor(rank + 1)
+                recv_tensor = _build_tensor(src + 1, value=-1)
+                recv_op = dist.P2POp(dist.irecv, recv_tensor, src, tag=src)
+                p2p_op_list.append(recv_op)
+                send_op = dist.P2POp(dist.isend, send_tensor, src, tag=rank)
+                p2p_op_list.append(send_op)
+
+            reqs = dist.batch_isend_irecv(p2p_op_list)
+            for req in reqs:
+                req.wait()
+
+            self._barrier()
+
+        # NCCL Batch SEND RECV Op Error
+        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
+        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
+        def test_batch_isend_irecv_op_err(self):
+            self._barrier()
+            rank = dist.get_rank()
+            if rank == 0:
+                rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+                device_id = rank_to_GPU[rank][0]
+                with self.assertRaisesRegex(ValueError, "^Invalid ``op``"):
+                    send_tensor = _build_tensor(rank + 1, device_id=device_id)
+                    send_op = dist.P2POp(dist.broadcast, send_tensor, 1)
+                    dist.batch_isend_irecv([send_op])
+
+        # NCCL Batch SEND RECV p2p_op_list Error
+        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
+        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
+        def test_batch_isend_irecv_op_list_err(self):
+            self._barrier()
+            rank = dist.get_rank()
+            if rank == 0:
+                with self.assertRaisesRegex(ValueError, "^Invalid ``p2p_op_list``"):
+                    dist.batch_isend_irecv([1, 2])
+
+        # NCCL Batch SEND RECV Mixed Backend Error
+        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
+        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
+        def test_batch_isend_irecv_mixed_backend_err(self):
+            self._barrier()
+            rank = dist.get_rank()
+            init_multigpu_helper(dist.get_world_size(), BACKEND)
+            group_gloo = dist.new_group(ranks=[0, 1], backend="gloo")
+            group_nccl = dist.new_group(ranks=[0, 1], backend="nccl")
+            if rank == 0:
+                with self.assertRaisesRegex(
+                    ValueError, "All ops need to use the same group"
+                ):
+                    send_tensor = _build_tensor(rank + 1)
+                    send_op_gloo = dist.P2POp(dist.isend, send_tensor, 1, group_gloo)
+                    send_op_nccl = dist.P2POp(dist.isend, send_tensor, 1, group_nccl)
+                    dist.batch_isend_irecv([send_op_gloo, send_op_nccl])
+
+        # NCCL SEND RECV
+        @skip_if_no_gpu
+        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Send Recv Only")
+        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
+        def _test_send_recv_nccl(self, profiler_ctx=None):
+            # TODO: now that nccl send/recv is supported, there does not seem to
+            # be a need to have nccl send/recv be tested separately.
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+            rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
+            device_id = rank_to_GPU[rank][0]
+            torch.cuda.set_device(device_id)
+
+            tensor = _build_tensor(rank + 1, device_id=device_id)
+            profiler_cls = profiler_ctx if profiler_ctx is not None else nullcontext()
+            with profiler_cls as prof:
+                for src in range(world_size):
+                    if src == rank:
+                        # Send mode
+                        for dst in range(world_size):
+                            if dst == rank:
+                                continue
+                            dist.send(tensor, dst)
+                    else:
+                        # Recv mode
+                        expected_tensor = _build_tensor(src + 1)
+                        output_tensor = _build_tensor(
+                            src + 1, value=-1, device_id=device_id
+                        )
+                        dist.recv(output_tensor, src)
+                        self.assertEqual(output_tensor, expected_tensor)
+
+                self._barrier()
+
+            if profiler_ctx is not None:
+                backend = dist.get_backend()
+                if backend in SEND_RECV_PROFILING_SUPPORTED_BACKENDS:
+                    for event_name in [f"{backend}:send", f"{backend}:recv"]:
+                        events = get_profiling_event(
+                            event_name, prof, dedup_gpu_user_annotation=True
+                        )
+                        self.assertTrue(events)
+                        # Event order is not deterministic, so simply assert their shape
+                        # is found in the following list.
+                        expected_shapes = [
+                            [[rank + 1] * 3] for rank in range(dist.get_world_size())
+                        ]
+                        for event in events:
+                            self.assertTrue(event.input_shapes in expected_shapes)
+
+        @skip_if_no_gpu
+        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Send Recv Only")
+        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
+        def test_send_recv_nccl(self):
+            self._test_send_recv_nccl()
+
+        @skip_if_no_gpu
+        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Send Recv Only")
+        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
+        def test_send_recv_nccl_autograd_profiler(self):
+            profiler_ctx = torch.autograd.profiler.profile(record_shapes=True)
+            self._test_send_recv_nccl(profiler_ctx)
+
+        @skip_if_no_gpu
+        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Send Recv Only")
+        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
+        @skip_but_pass_in_sandcastle_if(IS_FBCODE, "Kineto in fbcode causes hang")
+        @skip_but_pass_in_sandcastle_if(
+            IS_MACOS or IS_WINDOWS,
+            "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124",
+        )
+        def test_send_recv_nccl_torch_profiler(self):
+            profiler_ctx = torch.profiler.profile(
+                activities=[
+                    torch.profiler.ProfilerActivity.CPU,
+                    torch.profiler.ProfilerActivity.CUDA,
+                ],
+                record_shapes=True,
+            )
+            self._test_send_recv_nccl(profiler_ctx)
+
+        # SEND RECV
+        def _test_send_recv(self, profiler_ctx):
+            rank = dist.get_rank()
+            send_size = rank + 1
+            tensor = _build_tensor(send_size)
+            ctx = profiler_ctx if profiler_ctx is not None else nullcontext()
+            with ctx as prof:
+                for src in range(dist.get_world_size()):
+                    if src == rank:
+                        # Send mode
+                        for dst in range(dist.get_world_size()):
+                            if dst == rank:
+                                continue
+                            dist.send(tensor, dst)
+                    else:
+                        # Recv mode
+                        recv_size = src + 1
+                        expected_tensor = _build_tensor(recv_size)
+                        output_tensor = _build_tensor(recv_size, value=-1)
+                        dist.recv(output_tensor, src)
+                        self.assertEqual(output_tensor, expected_tensor)
+
+            if profiler_ctx is not None:
+                backend = dist.get_backend()
+                if backend in SEND_RECV_PROFILING_SUPPORTED_BACKENDS:
+                    for event_name in [f"{backend}:send", f"{backend}:recv"]:
+                        events = get_profiling_event(event_name, prof)
+                        # Each rank sends/recvs from all other ranks.
+                        event_count = sum(e.count for e in events)
+                        expected_event_count = dist.get_world_size() - 1
+                        self.assertEqual(event_count, expected_event_count)
+                        # Event order is not deterministic, so simply assert their shape
+                        # is found in the following list.
+                        expected_shapes = [
+                            [[rank + 1] * 3] for rank in range(dist.get_world_size())
+                        ]
+                        for event in events:
+                            self.assertTrue(event.is_async)
+                            self.assertTrue(event.input_shapes in expected_shapes)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl send/recv tested by test_send_recv_nccl"
+        )
+        def test_send_recv(self):
+            self._test_send_recv(profiler_ctx=None)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "NCCL send/recv tested by test_send_recv_nccl"
+        )
+        def test_send_recv_autograd_profiler(self):
+            autograd_profiler_ctx = _create_autograd_profiler()
+            self._test_send_recv(profiler_ctx=autograd_profiler_ctx)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "NCCL send/recv tested by test_send_recv_nccl"
+        )
+        @skip_but_pass_in_sandcastle_if(IS_FBCODE, "Kineto in fbcode causes hang")
+        @skip_but_pass_in_sandcastle_if(
+            IS_MACOS or IS_WINDOWS,
+            "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124",
+        )
+        def test_send_recv_torch_profiler(self):
+            torch_profiler_ctx = _create_torch_profiler()
+            return self._test_send_recv(profiler_ctx=torch_profiler_ctx)
+
+        # SEND RECV ANY SOURCE
+        def _test_send_recv_any_source(self, profiler_ctx):
+            rank = dist.get_rank()
+            send_recv_size = 10
+            tensor = _build_tensor(send_recv_size, value=rank)
+            recv_ranks = []
+            irecv_ranks = []
+
+            ctx = profiler_ctx if profiler_ctx is not None else nullcontext()
+            with ctx as prof:
+                for dst in range(dist.get_world_size()):
+                    if dst == rank:
+                        # Recv mode
+                        for dst in range(dist.get_world_size()):
+                            if dst == rank:
+                                continue
+
+                            for recv in ["recv", "irecv"]:
+                                output_tensor = _build_tensor(send_recv_size, value=-1)
+
+                                if recv == "recv":
+                                    sender = dist.recv(output_tensor)
+                                    recv_ranks.append(sender)
+                                elif recv == "irecv":
+                                    work = dist.irecv(output_tensor)
+                                    work.wait()
+                                    sender = work._source_rank()
+                                    irecv_ranks.append(sender)
+
+                                # Assert the scalar value "sender" that should be
+                                # equal to the rank of the sender is equal to all
+                                # values in the received tensor.
+                                self.assertTrue(output_tensor.eq(sender).all())
+                    else:
+                        # Send mode
+                        dist.send(tensor, dst)  # recv
+                        dist.send(tensor, dst)  # irecv
+
+            if profiler_ctx is not None:
+                backend = dist.get_backend()
+                if backend in SEND_RECV_PROFILING_SUPPORTED_BACKENDS:
+                    for event_name in [f"{backend}:send", f"{backend}:recvAnySource"]:
+                        events = get_profiling_event(event_name, prof)
+                        # Each rank sends/recvs from other rank twice.
+                        self.assertEqual(
+                            sum(event.count for event in events),
+                            2 * (dist.get_world_size() - 1),
+                        )
+                        for event in events:
+                            self.assertTrue(event.is_async)
+                            self.assertEqual(event.input_shapes, [[send_recv_size] * 3])
+
+                # Each rank would have 2 * (world_size - 1) sends, verify that
+                # globally we receive the same amount on the other end.
+                recv_ranks_tensor = torch.cat(
+                    (torch.tensor(recv_ranks), torch.tensor(irecv_ranks)), 0
+                )
+                global_recv_ranks = [
+                    torch.empty_like(recv_ranks_tensor)
+                    for _ in range(dist.get_world_size())
+                ]
+                dist.all_gather(global_recv_ranks, recv_ranks_tensor)
+                global_recv_ranks_list = []
+                for tensor in global_recv_ranks:
+                    global_recv_ranks_list += tensor.tolist()
+
+                from itertools import groupby
+
+                global_recv_ranks_list.sort()
+                frequency = [
+                    len(list(group)) for key, group in groupby(global_recv_ranks_list)
+                ]
+                self.assertEqual(dist.get_world_size(), len(frequency))
+                self.assertEqual(
+                    [2 * (dist.get_world_size() - 1)] * dist.get_world_size(), frequency
+                )
+                self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["sendrecv anysource"],
+            f"{BACKEND} does not support send/recv from any source",
+        )
+        def test_send_recv_any_source(self):
+            self._test_send_recv_any_source(profiler_ctx=None)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["sendrecv anysource"],
+            f"{BACKEND} does not support send/recv from any source",
+        )
+        def test_send_recv_any_source_autograd_profiler(self):
+            autograd_profiler_ctx = _create_autograd_profiler()
+            self._test_send_recv_any_source(profiler_ctx=autograd_profiler_ctx)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["sendrecv anysource"],
+            f"{BACKEND} does not support send/recv from any source",
+        )
+        @skip_but_pass_in_sandcastle_if(IS_FBCODE, "Kineto in fbcode code causes hang")
+        @skip_but_pass_in_sandcastle_if(
+            IS_MACOS or IS_WINDOWS,
+            "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124",
+        )
+        def test_send_recv_any_source_torch_profiler(self):
+            torch_profiler_ctx = _create_torch_profiler()
+            return self._test_send_recv_any_source(profiler_ctx=torch_profiler_ctx)
+
+        # SEND RECV WITH TAG
+        def _test_send_recv_with_tag(self, profiler_ctx):
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+            send_recv_size = 10
+            tensor = _build_tensor(send_recv_size, value=rank)
+            ctx = profiler_ctx if profiler_ctx is not None else nullcontext()
+            with ctx as prof:
+                for dst in range(world_size):
+                    if dst == rank:
+                        # Recv mode
+                        for src in range(world_size):
+                            if src == rank:
+                                continue
+                            output_tensor = _build_tensor(send_recv_size, value=-1)
+                            dist.recv(output_tensor, src, tag=src)
+                            self.assertTrue(output_tensor.eq(src).all())
+                    else:
+                        # Send mode
+                        dist.send(tensor, dst, tag=rank)
+
+            if profiler_ctx is not None:
+                backend = dist.get_backend()
+                if backend in SEND_RECV_PROFILING_SUPPORTED_BACKENDS:
+                    for event_name in [f"{backend}:send", f"{backend}:recv"]:
+                        events = get_profiling_event(event_name, prof)
+                        # Each rank sends/recvs from all other ranks
+                        event_count = sum(e.count for e in events)
+                        expected_event_count = dist.get_world_size() - 1
+                        self.assertEqual(event_count, expected_event_count)
+                        for event in events:
+                            self.assertTrue(event.is_async)
+                            self.assertEqual(event.name, event_name)
+                            self.assertEqual(event.input_shapes, [[send_recv_size] * 3])
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "NCCL send/recv tested by test_send_recv_nccl"
+        )
+        def test_send_recv_with_tag(self):
+            self._test_send_recv_with_tag(profiler_ctx=None)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "NCCL send/recv tested by test_send_recv_nccl"
+        )
+        def test_send_recv_with_tag_autograd_profiler(self):
+            autograd_profiler_ctx = _create_autograd_profiler()
+            return self._test_send_recv_with_tag(profiler_ctx=autograd_profiler_ctx)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "NCCL send/recv tested by test_send_recv_nccl"
+        )
+        @skip_but_pass_in_sandcastle_if(IS_FBCODE, "Kineto in fbcode code causes hang")
+        @skip_but_pass_in_sandcastle_if(
+            IS_MACOS or IS_WINDOWS,
+            "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124",
+        )
+        def test_send_recv_with_tag_torch_profiler(self):
+            torch_profiler_ctx = _create_torch_profiler()
+            return self._test_send_recv_with_tag(profiler_ctx=torch_profiler_ctx)
+
+        # ISEND
+        def _test_isend(self, profiler_ctx):
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+            ctx = profiler_ctx if profiler_ctx is not None else nullcontext()
+            with ctx as prof:
+                if rank == 0:
+                    requests = [
+                        dist.isend(_build_tensor(dest, 10), dest)
+                        for dest in range(1, world_size)
+                    ]
+                    for request in requests:
+                        request.wait()
+                        self.assertTrue(request.is_completed())
+                else:
+                    tensor = _build_tensor(rank, -1)
+                    dist.recv(tensor, 0)
+                    self.assertEqual(tensor, _build_tensor(rank, 10))
+
+                self._barrier()
+
+            if profiler_ctx is not None:
+                backend = dist.get_backend()
+                if backend in SEND_RECV_PROFILING_SUPPORTED_BACKENDS:
+                    expected_event_name = (
+                        f"{backend}:send" if rank == 0 else f"{backend}:recv"
+                    )
+                    events = get_profiling_event(expected_event_name, prof)
+                    event_count = sum(e.count for e in events)
+                    expected_count = dist.get_world_size() - 1 if rank == 0 else 1
+                    self.assertEqual(expected_count, event_count)
+                    # Event ordering is not guaranteed, so simply ensure the shapes are
+                    # found in the following map.
+                    expected_shapes = {
+                        r: [[r] * 3] for r in range(1, dist.get_world_size())
+                    }
+                    for event in events:
+                        self.assertTrue(event.is_async)
+                        self.assertEqual(event.name, expected_event_name)
+                        if rank == 0:
+                            self.assertTrue(
+                                event.input_shapes in expected_shapes.values()
+                            )
+                        else:
+                            self.assertEqual(event.input_shapes, expected_shapes[rank])
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support isend"
+        )
+        def test_isend(self):
+            self._test_isend(profiler_ctx=None)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support isend"
+        )
+        def test_isend_autograd_profiler(self):
+            autograd_profiler_ctx = _create_autograd_profiler()
+            self._test_isend(profiler_ctx=autograd_profiler_ctx)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support isend"
+        )
+        @skip_but_pass_in_sandcastle_if(IS_FBCODE, "Kineto in fbcode code causes hang")
+        @skip_but_pass_in_sandcastle_if(
+            IS_MACOS or IS_WINDOWS,
+            "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124",
+        )
+        def test_isend_torch_profiler(self):
+            torch_profiler_ctx = _create_torch_profiler()
+            self._test_isend(profiler_ctx=torch_profiler_ctx)
+
+        # IRECV
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support irecv"
+        )
+        def test_irecv(self):
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+
+            if rank == 0:
+                expected_tensors = [
+                    _build_tensor(src, -1) for src in range(1, world_size)
+                ]
+                requests = [
+                    dist.irecv(expected_tensors[src - 1], src)
+                    for src in range(1, world_size)
+                ]
+
+                for src in range(1, world_size):
+                    requests[src - 1].wait()
+                    self.assertTrue(requests[src - 1].is_completed())
+                    self.assertEqual(expected_tensors[src - 1], _build_tensor(src, 10))
+            else:
+                tensor = _build_tensor(rank, 10)
+                dist.send(tensor, 0)
+
+            self._barrier()
+
+        # BROADCAST
+        def _test_broadcast_helper(
+            self,
+            group,
+            group_id,
+            rank,
+            cuda=False,
+            rank_to_GPU=None,
+            with_options=False,
+        ):
+            for dtype, value, requires_cuda in [
+                (torch.float, -1e-10, False),
+                (torch.double, -1e-100, False),
+                (torch.half, -0.1, True),
+                (torch.int8, -2, False),
+                (torch.uint8, 129, False),
+                (torch.int, -1e5, False),
+                (torch.long, -1e15, False),
+            ]:
+                if requires_cuda and not cuda:
+                    continue
+                for src in group:
+                    expected_tensor = _build_tensor(src + 1, value, dtype)
+                    if cuda:
+                        expected_tensor = expected_tensor.cuda(rank_to_GPU[rank][0])
+                    if rank == src:
+                        if with_options:
+                            opts = dist.BroadcastOptions()
+                            opts.rootTensor = 0
+                            opts.rootRank = src
+                            self.call_dist_op(
+                                ":broadcast",
+                                True,
+                                group_id.broadcast,
+                                [expected_tensor],
+                                opts,
+                            )
+                        else:
+                            self.call_dist_op(
+                                ":broadcast",
+                                False,
+                                dist.broadcast,
+                                expected_tensor,
+                                src,
+                                group_id,
+                            )
+                    else:
+                        tensor = _build_tensor(src + 1, -1, dtype)
+                        if cuda:
+                            tensor = tensor.cuda(rank_to_GPU[rank][0])
+                        if with_options:
+                            opts = dist.BroadcastOptions()
+                            opts.rootTensor = 0
+                            opts.rootRank = src
+                            self.call_dist_op(
+                                ":broadcast", True, group_id.broadcast, [tensor], opts
+                            )
+                        else:
+                            self.call_dist_op(
+                                ":broadcast",
+                                False,
+                                dist.broadcast,
+                                tensor,
+                                src,
+                                group_id,
+                            )
+                        self.assertEqual(tensor.size(), expected_tensor.size())
+                        self.assertEqual(
+                            tensor.ne(expected_tensor).max(), torch.tensor(False)
+                        )
+
+            self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_broadcast(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_broadcast_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "gloo" and BACKEND != "nccl",
+            "Only Gloo and Nccl backend supports CUDA allReduce",
+        )
+        @skip_if_no_gpu
+        def test_broadcast_cuda(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            device_id = rank_to_GPU[rank][0]
+            torch.cuda.set_device(device_id)
+            self._test_broadcast_helper(group, group_id, rank, True, rank_to_GPU)
+
+        @skip_if_small_worldsize
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_broadcast_group(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_broadcast_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_broadcast_full_group(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_broadcast_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl",
+            "Only NCCL backend supports high priority stream",
+        )
+        @skip_if_no_gpu
+        def test_nccl_high_priority_stream(self):
+            group, _, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            device_id = rank_to_GPU[rank][0]
+            torch.cuda.set_device(device_id)
+
+            new_port = str(MASTER_PORT + 1)
+            os.environ["MASTER_PORT"] = new_port
+            gen_iterator = dist.rendezvous("env://", rank, dist.get_world_size())
+            store, rank, size = next(gen_iterator)
+            store = dist.PrefixStore(new_port, store)
+
+            opts = dist.ProcessGroupNCCL.Options()
+            opts.is_high_priority_stream = False
+            group_id = dist.ProcessGroupNCCL(store, rank, size, opts)
+
+            self._test_broadcast_helper(group, group_id, rank, True, rank_to_GPU, True)
+
+        # REDUCE
+        def _test_reduce_helper(
+            self,
+            group,
+            group_id,
+            rank,
+            op,
+            master_value,
+            worker_value,
+            expected_value,
+            cuda=False,
+            rank_to_GPU=None,
+        ):
+            for src in group:
+                tensor = _build_tensor(src + 1).fill_(
+                    master_value if rank == src else worker_value
+                )
+                if cuda:
+                    tensor = tensor.cuda(rank_to_GPU[rank][0])
+                self.call_dist_op(
+                    ":reduce",
+                    False,
+                    dist.reduce,
+                    tensor,
+                    src,
+                    op,
+                    group_id,
+                    tensor_shapes=[tensor.shape],
+                )
+                if rank == src:
+                    self.assertEqual(tensor, _build_tensor(src + 1, expected_value))
+
+            self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        def test_reduce_sum(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                2,
+                10,
+                2 + (10 * (len(group) - 1)),
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA reduce"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        @skip_if_no_gpu
+        def test_reduce_sum_cuda(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            device_id = rank_to_GPU[rank][0]
+            torch.cuda.set_device(device_id)
+            self._test_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                2,
+                10,
+                2 + 10 * (len(group) - 1),
+                True,
+                rank_to_GPU,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        def test_reduce_product(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.PRODUCT,
+                2,
+                10,
+                reduce(operator.mul, [10] * (len(group) - 1), 2),
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        def test_reduce_min(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_reduce_helper(
+                group, group_id, rank, dist.ReduceOp.MIN, 1010, 1, 1
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        def test_reduce_max(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_reduce_helper(
+                group, group_id, rank, dist.ReduceOp.MAX, -1, 10, 10
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        @skip_if_small_worldsize
+        def test_reduce_group_sum(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                2,
+                10,
+                2 + (10 * (len(group) - 1)),
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        @skip_if_small_worldsize
+        def test_reduce_group_product(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.PRODUCT,
+                2,
+                10,
+                reduce(operator.mul, [10] * (len(group) - 1), 2),
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        @skip_if_small_worldsize
+        def test_reduce_group_min(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_reduce_helper(
+                group, group_id, rank, dist.ReduceOp.MIN, 1010, 1, 1
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        @skip_if_small_worldsize
+        def test_reduce_group_max(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_reduce_helper(
+                group, group_id, rank, dist.ReduceOp.MAX, -1, 10, 10
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        def test_reduce_full_group_sum(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                2,
+                10,
+                2 + (10 * (len(group) - 1)),
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        def test_reduce_full_group_product(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.PRODUCT,
+                2,
+                10,
+                reduce(operator.mul, [10] * (len(group) - 1), 2),
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        def test_reduce_full_group_min(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_reduce_helper(
+                group, group_id, rank, dist.ReduceOp.MIN, 1010, 1, 1
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        def test_reduce_full_group_max(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_reduce_helper(
+                group, group_id, rank, dist.ReduceOp.MAX, -1, 10, 10
+            )
+
+        # REDUCE TWICE
+        def _test_reduce_twice_helper(
+            self,
+            group,
+            group_id,
+            rank,
+            op,
+            master_value,
+            worker_value,
+            expected_value,
+            cuda=False,
+            rank_to_GPU=None,
+        ):
+            for src in group:
+                tensors = [
+                    _build_tensor(src + 1).fill_(
+                        master_value if rank == src else worker_value
+                    )
+                    for i in range(2)
+                ]
+                if cuda:
+                    for i in range(2):
+                        tensors[i] = tensors[i].cuda(rank_to_GPU[rank][0])
+                self.call_dist_op(
+                    ":reduce",
+                    False,
+                    dist.reduce,
+                    tensors[0],
+                    src,
+                    op,
+                    group_id,
+                    secondary_op_call=lambda: dist.reduce(
+                        tensors[1], src, op, group_id
+                    ),
+                    tensor_shapes=[tensors[0].shape],
+                )
+                if rank == src:
+                    for tensor in tensors:
+                        self.assertEqual(tensor, _build_tensor(src + 1, expected_value))
+
+            self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        def test_reduce_sum_twice(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_reduce_twice_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                2,
+                10,
+                2 + (10 * (len(group) - 1)),
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA reduce"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        @skip_if_no_gpu
+        def test_reduce_sum_cuda_twice(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            device_id = rank_to_GPU[rank][0]
+            torch.cuda.set_device(device_id)
+            self._test_reduce_twice_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                2,
+                10,
+                2 + 10 * (len(group) - 1),
+                True,
+                rank_to_GPU,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports reduce_scatter_v"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        @skip_if_no_gpu
+        def test_reduce_scatter_v_cuda(self):
+            self._barrier()
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            device_id = rank_to_GPU[rank][0]
+
+            input_split_sizes = [src + 1 for src in group]
+            start_len = sum(input_split_sizes[:rank])
+            end_len = start_len + input_split_sizes[rank]
+            sum_len = sum(input_split_sizes)
+            master_value = 2
+            worker_value = 10
+
+            for async_val in [True, False]:
+                tensor = _build_tensor(sum_len, worker_value, device_id=device_id)
+                tensor[start_len:end_len].fill_(master_value)
+                out_tensor = (
+                    torch.empty(
+                        input_split_sizes[rank], sum_len, sum_len, dtype=torch.float
+                    )
+                    .fill_(-1)
+                    .cuda(device_id)
+                )
+
+                req = dist.reduce_scatter(
+                    out_tensor,
+                    list(torch.split(tensor, input_split_sizes)),
+                    dist.ReduceOp.SUM,
+                    group_id,
+                    async_val,
+                )
+                if async_val:
+                    req.wait()
+
+                expected_value = 2 + (10 * (len(group) - 1))
+                expected_tensor = torch.empty(
+                    input_split_sizes[rank], sum_len, sum_len, dtype=torch.float
+                )
+                expected_tensor = expected_tensor.fill_(expected_value).cuda(device_id)
+
+                self.assertEqual(out_tensor, expected_tensor)
+            self._barrier()
+
+        # Test reduce_scatter_tensor accepting single tensor as input
+        def _reduce_scatter_tensor_helper(
+            self, tensor_out, tensor_in, group_id, rank, cuda=True, rank_to_GPU=None
+        ):
+            if cuda:
+                tensor_in = tensor_in.cuda(rank_to_GPU[rank][0])
+                tensor_out = tensor_out.cuda(rank_to_GPU[rank][0])
+            tensor_shapes = [tensor_out.shape]
+            self.call_dist_op(
+                ":reduce_scatter_tensor",
+                False,
+                dist.reduce_scatter_tensor,
+                tensor_out,
+                tensor_in,
+                dist.ReduceOp.SUM,
+                group_id,
+                False,
+                expect_event=False,
+                tensor_shapes=tensor_shapes,
+            )
+            return tensor_out
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA reduce_scatter_tensor"
+        )
+        @skip_if_no_gpu
+        def test_reduce_scatter_tensor_cuda(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            size = 2
+            tensor_out = torch.zeros(size, dtype=torch.int64)
+
+            # Concatenated input
+            tensor_in = torch.arange(len(group) * size)
+            tensor_out = self._reduce_scatter_tensor_helper(
+                tensor_out, tensor_in, group_id, rank, True, rank_to_GPU
+            )
+            # Check result
+            expected_tensor = torch.arange(rank * size, (rank + 1) * size) * len(group)
+            self.assertEqual(tensor_out, expected_tensor)
+            self._barrier()
+
+            # Stacked input
+            tensor_in = torch.reshape(tensor_in, (len(group), size))
+            tensor_out = self._reduce_scatter_tensor_helper(
+                tensor_out, tensor_in, group_id, rank, True, rank_to_GPU
+            )
+            # Check result
+            # Should be the same as the result in concatenated case
+            self.assertEqual(tensor_out, expected_tensor)
+            self._barrier()
+
+        def call_dist_op(
+            self,
+            profiling_title_postfix,
+            is_async,
+            op,
+            *args,
+            expect_event=True,
+            secondary_op_call=None,
+            profile_cuda=False,
+            tensor_shapes=None,
+            **kwargs,
+        ):
+            op_calls = [lambda: op(*args, **kwargs)]
+            if secondary_op_call is not None:
+                op_calls.append(secondary_op_call)
+
+            autograd_profiler_ctx = torch.autograd.profiler.profile(
+                use_cuda=profile_cuda, record_shapes=True
+            )
+
+            # TODO: move this test to use torch.profiler once kineto issues are
+            # fixed internally.
+            with autograd_profiler_ctx:
+                works = [op_call() for op_call in op_calls]
+                if is_async:
+                    for work in works:
+                        work.wait()
+
+            if expect_event and dist.get_backend() in PROFILING_SUPPORTED_BACKENDS:
+                # We are only interested in the backend's implementation not the dispatcher wrapper.
+                events = get_profiling_event(
+                    dist.get_backend() + profiling_title_postfix, autograd_profiler_ctx
+                )
+                # DETAIL debug mode can use a pg wrapper that issues more collectives
+                # under the hood
+                if dist.get_debug_level() != dist.DebugLevel.DETAIL:
+                    self.assertEqual(len(events), len(op_calls))
+                for e in events:
+                    self.assertTrue(e.is_async)
+                    self.assertEqual(e.count, 1)
+                    self.assertGreaterEqual(e.cpu_time, 0)
+                    # Verify tensor shapes if given
+                    # DETAIL debug mode can use a pg wrapper that issues more collectives
+                    # under the hood
+                    if (
+                        tensor_shapes is not None
+                        and dist.get_debug_level() != dist.DebugLevel.DETAIL
+                    ):
+                        self.assertEqual(
+                            e.input_shapes,
+                            tensor_shapes,
+                            f"event shape: {e.input_shapes} vs tensor {tensor_shapes}",
+                        )
+
+        # ALL REDUCE
+        def _test_all_reduce_helper(
+            self,
+            group,
+            group_id,
+            rank,
+            op,
+            master_value,
+            worker_value,
+            expected_value,
+            cuda=False,
+            rank_to_GPU=None,
+            dtype=torch.float,
+            async_op=False,
+        ):
+            for src in group:
+                curr_value = master_value if rank == src else worker_value
+
+                tensor = _build_tensor(src + 1, dtype=dtype).fill_(curr_value)
+                if cuda:
+                    tensor = tensor.cuda(rank_to_GPU[rank][0])
+                if tensor.dtype == torch.complex64:
+                    tensor_shapes = [torch.view_as_real(tensor).shape]
+                else:
+                    tensor_shapes = [tensor.shape]
+                self.call_dist_op(
+                    ":all_reduce",
+                    async_op,
+                    dist.all_reduce,
+                    tensor,
+                    op,
+                    group_id,
+                    async_op=async_op,
+                    tensor_shapes=tensor_shapes,
+                )
+                # Currently, only Gloo backend has profiling tested with CUDA enabled.
+                # Only run cuda profiling test for one rank to speed up since
+                # running with different src_rank does not affect the correctness.
+                if (
+                    src == 0
+                    and cuda
+                    and dist.get_backend() in CUDA_PROFILING_SUPPORTED_BACKENDS
+                ):
+                    self.call_dist_op(
+                        ":all_reduce",
+                        async_op,
+                        dist.all_reduce,
+                        tensor,
+                        op,
+                        group_id,
+                        async_op=async_op,
+                        profile_cuda=True,
+                        tensor_shapes=tensor_shapes,
+                    )
+
+            self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_sum(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                2,
+                10,
+                2 + (10 * (len(group) - 1)),
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_sum_async(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                2,
+                10,
+                2 + (10 * (len(group) - 1)),
+                async_op=True,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "gloo" and BACKEND != "nccl",
+            "Only Gloo and NCCL backends will have CUDA allReduce tested",
+        )
+        @skip_if_no_gpu
+        def test_all_reduce_sum_cuda(self):
+            torch.cuda.set_device(self.rank)
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                2,
+                10,
+                2 + (10 * (len(group) - 1)),
+                True,
+                rank_to_GPU,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "gloo" and BACKEND != "nccl",
+            "Only Gloo and NCCL backends will have CUDA allReduce tested",
+        )
+        @skip_if_no_gpu
+        def test_all_reduce_sum_cuda_async(self):
+            torch.cuda.set_device(self.rank)
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                2,
+                10,
+                2 + (10 * (len(group) - 1)),
+                True,
+                rank_to_GPU,
+                async_op=True,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_sum_complex(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                complex(2, 3),
+                complex(10, 11),
+                complex(2, 3) + (complex(10, 11) * (len(group) - 1)),
+                dtype=torch.cfloat,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_complex_unsupported_ops(self):
+            unsupported_ops = [
+                dist.ReduceOp.MAX,
+                dist.ReduceOp.MIN,
+                dist.ReduceOp.PRODUCT,
+                dist.ReduceOp.BAND,
+                dist.ReduceOp.BOR,
+                dist.ReduceOp.BXOR,
+            ]
+            _group, group_id, _rank = self._init_global_test()
+            for unsupported_op in unsupported_ops:
+                with self.assertRaisesRegex(ValueError, "all_reduce does not support"):
+                    dist.all_reduce(
+                        _build_tensor(1, dtype=torch.cfloat), unsupported_op, group_id
+                    )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "gloo" and BACKEND != "nccl",
+            "Only Gloo and NCCL backends will have CUDA allReduce tested",
+        )
+        @skip_if_no_gpu
+        def test_all_reduce_sum_cuda_complex(self):
+            torch.cuda.set_device(self.rank)
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                complex(2, 3),
+                complex(10, 11),
+                complex(2, 3) + (complex(10, 11) * (len(group) - 1)),
+                True,
+                rank_to_GPU,
+                dtype=torch.cfloat,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_product(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.PRODUCT,
+                2,
+                10,
+                reduce(operator.mul, [10] * (len(group) - 1), 2),
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_min(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_reduce_helper(
+                group, group_id, rank, dist.ReduceOp.MIN, 1010, 1, 1
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_max(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_reduce_helper(
+                group, group_id, rank, dist.ReduceOp.MAX, -1, 10, 10
+            )
+
+        @skip_if_small_worldsize
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_group_sum(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_all_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                2,
+                10,
+                2 + (10 * (len(group) - 1)),
+            )
+
+        @skip_if_small_worldsize
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_group_product(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_all_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.PRODUCT,
+                2,
+                10,
+                reduce(operator.mul, [10] * (len(group) - 1), 2),
+            )
+
+        @skip_if_small_worldsize
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_group_min(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_all_reduce_helper(
+                group, group_id, rank, dist.ReduceOp.MIN, 1010, 1, 1
+            )
+
+        @skip_if_small_worldsize
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_group_max(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_all_reduce_helper(
+                group, group_id, rank, dist.ReduceOp.MAX, -1, 10, 10
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_full_group_sum(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_all_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                2,
+                10,
+                2 + (10 * (len(group) - 1)),
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_full_group_product(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_all_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.PRODUCT,
+                2,
+                10,
+                reduce(operator.mul, [10] * (len(group) - 1), 2),
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_full_group_min(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_all_reduce_helper(
+                group, group_id, rank, dist.ReduceOp.MIN, 1010, 1, 1
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_full_group_max(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_all_reduce_helper(
+                group, group_id, rank, dist.ReduceOp.MAX, -1, 10, 10
+            )
+
+        # SPARSE ALL REDUCE
+        def _test_sparse_all_reduce_sum(self, fn):
+            _group, group_id, rank = self._init_global_test()
+
+            tests = simple_sparse_reduce_tests(
+                rank, dist.get_world_size(), num_inputs=1
+            )
+            for inputs, outputs in tests:
+                tensors = [fn(input) for input in inputs]
+                dist.all_reduce(tensors[0], dist.ReduceOp.SUM, group_id)
+                self.assertEqual(tensors[0], outputs[0])
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "gloo", "Only Gloo backend support sparse all reduce"
+        )
+        def test_sparse_all_reduce_sum(self):
+            self._test_sparse_all_reduce_sum(lambda t: t)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "gloo", "Only Gloo backend support sparse all reduce"
+        )
+        @skip_if_no_gpu
+        def test_sparse_all_reduce_sum_cuda(self):
+            self._test_sparse_all_reduce_sum(lambda t: t.clone().cuda())
+
+        # ALL REDUCE - COALESCED
+        @staticmethod
+        def _all_reduce_coalesced_sum_test_cases(group_size):
+            return (
+                [2, 3, complex(2, 3)],
+                [10, 11, complex(10, 11)],
+                [
+                    2 + 10 * (group_size - 1),
+                    3 + 11 * (group_size - 1),
+                    complex(2, 3) + complex(10, 11) * (group_size - 1),
+                ],
+                [torch.float, torch.float, torch.cfloat],
+            )
+
+        @staticmethod
+        def _all_reduce_coalesced_product_test_cases(group_size):
+            return (
+                [1, 2],
+                [3, 4],
+                [1 * 3 ** (group_size - 1), 2 * 4 ** (group_size - 1)],
+                [torch.float, torch.float],
+            )
+
+        @staticmethod
+        def _all_reduce_coalesced_min_test_cases(group_size):
+            return (
+                [1, 4],
+                [2, 3],
+                [1, 3],
+                [torch.float, torch.float],
+            )
+
+        @staticmethod
+        def _all_reduce_coalesced_max_test_cases(group_size):
+            return (
+                [1, 4],
+                [2, 3],
+                [2, 4],
+                [torch.float, torch.float],
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_coalesced_max_complex_unsupported(self):
+            _group, group_id, _rank = self._init_global_test()
+            with self.assertRaisesRegex(ValueError, "all_reduce does not support"):
+                dist.all_reduce_coalesced(
+                    [_build_tensor(1, dtype=torch.cfloat)], dist.ReduceOp.MAX, group_id
+                )
+
+        def _test_all_reduce_coalesced_helper(
+            self,
+            group,
+            group_id,
+            rank,
+            op,
+            cuda=False,
+            rank_to_GPU=None,
+        ):
+            test_case_func = {
+                dist.ReduceOp.SUM: self._all_reduce_coalesced_sum_test_cases,
+                dist.ReduceOp.PRODUCT: self._all_reduce_coalesced_product_test_cases,
+                dist.ReduceOp.MIN: self._all_reduce_coalesced_min_test_cases,
+                dist.ReduceOp.MAX: self._all_reduce_coalesced_max_test_cases,
+            }[op]
+
+            master_values, worker_values, expected_values, dtypes = test_case_func(
+                len(group)
+            )
+
+            for src in group:
+                curr_values = master_values if rank == src else worker_values
+                tensors = [
+                    _build_tensor(src + 1, val, dtype=dtype)
+                    for dtype, val in zip(dtypes, curr_values, strict=True)
+                ]
+                if cuda:
+                    tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors]
+                tensor_shapes = []
+                for tensor in tensors:
+                    if tensor.dtype == torch.complex64:
+                        tensor_shapes.append(torch.view_as_real(tensor).shape)
+                    else:
+                        tensor_shapes.append(tensor.shape)
+                self.call_dist_op(
+                    ":all_reduce",
+                    False,
+                    dist.all_reduce_coalesced,
+                    tensors,
+                    op,
+                    group_id,
+                    tensor_shapes=tensor_shapes,
+                )
+                expected_tensors = [
+                    _build_tensor(src + 1, expected_value, dtype=dtype)
+                    for dtype, expected_value in zip(
+                        dtypes, expected_values, strict=True
+                    )
+                ]
+                self.assertEqual(tensors, expected_tensors)
+
+            self._barrier()
+
+        @require_backend_is_available({"gloo"})
+        def test_all_reduce_coalesced_sum(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_reduce_coalesced_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                cuda=False,
+                rank_to_GPU=None,
+            )
+
+        @require_backend_is_available({"gloo"})
+        def test_all_reduce_coalesced_product(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_reduce_coalesced_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.PRODUCT,
+                cuda=False,
+                rank_to_GPU=None,
+            )
+
+        @require_backend_is_available({"gloo"})
+        def test_all_reduce_coalesced_min(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_reduce_coalesced_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.MIN,
+                cuda=False,
+                rank_to_GPU=None,
+            )
+
+        @require_backend_is_available({"gloo"})
+        def test_all_reduce_coalesced_max(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_reduce_coalesced_helper(
+                group, group_id, rank, dist.ReduceOp.MAX, cuda=False, rank_to_GPU=None
+            )
+
+        @skip_if_small_worldsize
+        @require_backend_is_available({"gloo"})
+        def test_all_reduce_coalesced_group_sum(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_all_reduce_coalesced_helper(
+                group, group_id, rank, dist.ReduceOp.SUM, cuda=False, rank_to_GPU=None
+            )
+
+        @skip_if_small_worldsize
+        @require_backend_is_available({"gloo"})
+        def test_all_reduce_coalesced_group_product(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_all_reduce_coalesced_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.PRODUCT,
+                cuda=False,
+                rank_to_GPU=None,
+            )
+
+        @skip_if_small_worldsize
+        @require_backend_is_available({"gloo"})
+        def test_all_reduce_coalesced_group_min(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_all_reduce_coalesced_helper(
+                group, group_id, rank, dist.ReduceOp.MIN, cuda=False, rank_to_GPU=None
+            )
+
+        @skip_if_small_worldsize
+        @require_backend_is_available({"gloo"})
+        def test_all_reduce_coalesced_group_max(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_all_reduce_coalesced_helper(
+                group, group_id, rank, dist.ReduceOp.MAX, cuda=False, rank_to_GPU=None
+            )
+
+        @require_backend_is_available({"gloo"})
+        def test_all_reduce_coalesced_full_group_sum(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_all_reduce_coalesced_helper(
+                group, group_id, rank, dist.ReduceOp.SUM, cuda=False, rank_to_GPU=None
+            )
+
+        @require_backend_is_available({"gloo"})
+        def test_all_reduce_coalesced_full_group_product(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_all_reduce_coalesced_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.PRODUCT,
+                cuda=False,
+                rank_to_GPU=None,
+            )
+
+        @require_backend_is_available({"gloo"})
+        def test_all_reduce_coalesced_full_group_min(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_all_reduce_coalesced_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.MIN,
+                cuda=False,
+                rank_to_GPU=None,
+            )
+
+        @require_backend_is_available({"gloo"})
+        def test_all_reduce_coalesced_full_group_max(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_all_reduce_coalesced_helper(
+                group, group_id, rank, dist.ReduceOp.MAX, cuda=False, rank_to_GPU=None
+            )
+
+        # SCATTER
+        def _test_scatter_helper(
+            self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float
+        ):
+            for dest in group:
+                tensor = _build_tensor(dest + 1, -1, dtype=dtype)
+                expected_tensor = _build_tensor(dest + 1, rank, dtype=dtype)
+                tensors = (
+                    [_build_tensor(dest + 1, i, dtype=dtype) for i in group]
+                    if rank == dest
+                    else []
+                )
+                if cuda:
+                    tensor = tensor.cuda(rank_to_GPU[rank][0])
+                    tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors]
+                if dtype == torch.complex64:
+                    tensor_shapes = [torch.view_as_real(t).shape for t in tensors]
+                else:
+                    tensor_shapes = [t.shape for t in tensors]
+                self.call_dist_op(
+                    ":scatter",
+                    False,
+                    dist.scatter,
+                    tensor,
+                    src=dest,
+                    scatter_list=tensors,
+                    group=group_id,
+                    expect_event=False,
+                    tensor_shapes=tensor_shapes,
+                )
+                self.assertEqual(tensor, expected_tensor)
+
+            self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
+        )
+        def test_scatter_checks(self):
+            group, _group_id, rank = self._init_global_test()
+            one = torch.ones([1])
+
+            # Specify scatter_list argument only on source rank.
+            output = one.clone() * -1
+            if rank == 0:
+                scatter_list = [one.clone() * i for i in group]
+                dist.scatter(output, src=0, scatter_list=scatter_list)
+            else:
+                dist.scatter(output, src=0)
+            self.assertEqual(output, one * rank)
+
+            # Don't specify src argument.
+            output = one.clone() * -1
+            if rank == 0:
+                scatter_list = [one.clone() * i for i in group]
+                dist.scatter(output, scatter_list=scatter_list)
+            else:
+                dist.scatter(output)
+            self.assertEqual(output, one * rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
+        )
+        def test_scatter(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_scatter_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA gather"
+        )
+        @skip_if_no_gpu
+        def test_scatter_cuda(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_scatter_helper(group, group_id, rank, True, rank_to_GPU)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
+        )
+        def test_scatter_complex(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_scatter_helper(group, group_id, rank, dtype=torch.cfloat)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA gather"
+        )
+        @skip_if_no_gpu
+        def test_scatter_cuda_complex(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_scatter_helper(
+                group, group_id, rank, True, rank_to_GPU, dtype=torch.cfloat
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
+        )
+        @skip_if_small_worldsize
+        def test_scatter_group(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_scatter_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
+        )
+        def test_scatter_full_group(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_scatter_helper(group, group_id, rank)
+
+        # GATHER
+        def _test_gather_helper(
+            self, group, group_id, rank, cuda=False, rank_to_GPU=None
+        ):
+            for dest in group:
+                tensor = _build_tensor(dest + 1, rank)
+                tensors = (
+                    [_build_tensor(dest + 1, -1) for i in group] if rank == dest else []
+                )
+                if cuda:
+                    tensor = tensor.cuda(rank_to_GPU[rank][0])
+                    tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors]
+                self.call_dist_op(
+                    ":gather",
+                    False,
+                    dist.gather,
+                    tensor,
+                    dst=dest,
+                    gather_list=tensors,
+                    group=group_id,
+                    expect_event=False,
+                    tensor_shapes=[tensors[0].shape] if len(tensors) > 0 else None,
+                )
+                if rank == dest:
+                    expected_tensors = [_build_tensor(dest + 1, i) for i in group]
+                    for t1, t2 in zip(tensors, expected_tensors, strict=True):
+                        self.assertEqual(t1, t2)
+
+            self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
+        )
+        def test_gather_checks(self):
+            group, _group_id, rank = self._init_global_test()
+            one = torch.ones([1])
+
+            # Specify gather_list argument only on destination rank.
+            if rank == 0:
+                gather_list = [one.clone() for _ in group]
+                dist.gather(one * rank, dst=0, gather_list=gather_list)
+                for i in group:
+                    self.assertEqual(gather_list[i], one * i)
+            else:
+                dist.gather(one * rank, dst=0)
+
+            # Don't specify dst argument.
+            if rank == 0:
+                gather_list = [one.clone() for _ in group]
+                dist.gather(one * rank, gather_list=gather_list)
+                for i in group:
+                    self.assertEqual(gather_list[i], one * i)
+            else:
+                dist.gather(one * rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
+        )
+        def test_gather(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_gather_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA gather"
+        )
+        @skip_if_no_gpu
+        def test_gather_cuda(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_gather_helper(group, group_id, rank, True, rank_to_GPU)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
+        )
+        @skip_if_small_worldsize
+        def test_gather_group(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_gather_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
+        )
+        def test_gather_full_group(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_gather_helper(group, group_id, rank)
+
+        # ALL GATHER
+        def _test_all_gather_helper(
+            self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float
+        ):
+            for dest in group:
+                tensor = _build_tensor(dest + 1, rank, dtype=dtype)
+                tensors = [_build_tensor(dest + 1, -1, dtype=dtype) for i in group]
+                allgather = dist.all_gather
+                if cuda:
+                    tensor = tensor.cuda(rank_to_GPU[rank][0])
+                    tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors]
+                if tensors[0].dtype == torch.complex64:
+                    tensor_shapes = [torch.view_as_real(tensors[0]).shape]
+                else:
+                    tensor_shapes = [tensors[0].shape]
+                self.call_dist_op(
+                    ":all_gather",
+                    False,
+                    allgather,
+                    tensors,
+                    tensor,
+                    group_id,
+                    False,
+                    tensor_shapes=tensor_shapes,
+                )
+
+                expected_tensors = [
+                    _build_tensor(dest + 1, i, dtype=dtype) for i in group
+                ]
+                for t1, t2 in zip(tensors, expected_tensors, strict=True):
+                    self.assertEqual(t1, t2)
+
+            self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_gather(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_gather_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA all gather"
+        )
+        @skip_if_no_gpu
+        def test_all_gather_cuda(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_gather_helper(group, group_id, rank, True, rank_to_GPU)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_gather_complex(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_gather_helper(group, group_id, rank, dtype=torch.cfloat)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA all gather"
+        )
+        @skip_if_no_gpu
+        def test_all_gather_cuda_complex(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_gather_helper(
+                group, group_id, rank, True, rank_to_GPU, dtype=torch.cfloat
+            )
+
+        @skip_if_small_worldsize
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_gather_group(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_all_gather_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_gather_full_group(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_all_gather_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports all_gather_v"
+        )
+        @skip_if_no_gpu
+        def test_all_gather_v_cuda(self):
+            self._barrier()
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            device_id = rank_to_GPU[rank][0]
+
+            output_split_sizes = [dst + 1 for dst in group]
+            sum_len = sum(output_split_sizes)
+            value = 2
+
+            for async_val in [True, False]:
+                tensor = (
+                    torch.empty(
+                        output_split_sizes[rank], sum_len, sum_len, dtype=torch.float
+                    )
+                    .fill_(value)
+                    .cuda(device_id)
+                )
+                out_tensor = _build_tensor(sum_len, -1, device_id=device_id)
+
+                req = dist.all_gather(
+                    list(torch.split(out_tensor, output_split_sizes)),
+                    tensor,
+                    group_id,
+                    async_val,
+                )
+                if async_val:
+                    req.wait()
+
+                expected_value = value
+                expected_tensor = _build_tensor(
+                    sum_len, expected_value, device_id=device_id
+                )
+
+                self.assertEqual(out_tensor, expected_tensor)
+            self._barrier()
+
+        # Test all_gather accepting single tensor as output
+        def _all_gather_into_tensor_helper(
+            self, tensor_out, tensor_in, group_id, rank, cuda=True, rank_to_GPU=None
+        ):
+            if cuda:
+                tensor_in = tensor_in.cuda(rank_to_GPU[rank][0])
+                tensor_out = tensor_out.cuda(rank_to_GPU[rank][0])
+            if tensor_out.dtype == torch.complex64:
+                tensor_shapes = [torch.view_as_real(tensor_in).shape]
+            else:
+                tensor_shapes = [tensor_in.shape]
+            self.call_dist_op(
+                ":all_gather_into_tensor",
+                False,
+                dist.all_gather_into_tensor,
+                tensor_out,
+                tensor_in,
+                group_id,
+                False,
+                expect_event=False,
+                tensor_shapes=tensor_shapes,
+            )
+            return tensor_out
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA all_gather_into_tensor"
+        )
+        @skip_if_no_gpu
+        def test_all_gather_into_cat_tensor_cuda(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            size = 2
+            tensor_in = torch.ones([size, size]) * rank
+            # Concatenated output
+            tensor_out = torch.ones([len(group) * size, size]) * (-1)
+            tensor_out = self._all_gather_into_tensor_helper(
+                tensor_out, tensor_in, group_id, rank, True, rank_to_GPU
+            )
+
+            # Check result
+            # Concatenate all blocks into a bigger tensor
+            expected_tensor = torch.cat([torch.ones([size, size]) * i for i in group])
+            self.assertEqual(tensor_out, expected_tensor)
+            self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA all_gather_into_tensor"
+        )
+        @skip_if_no_gpu
+        def test_all_gather_into_stack_tensor_cuda(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            size = 2
+            tensor_in = torch.ones([size, size]) * rank
+            # Stacked output
+            tensor_out = torch.ones([len(group), size, size]) * (-1)
+            tensor_out = self._all_gather_into_tensor_helper(
+                tensor_out, tensor_in, group_id, rank, True, rank_to_GPU
+            )
+
+            # Check result
+            # Stack all blocks into a bigger tensor
+            expected_tensor = torch.stack([torch.ones([size, size]) * i for i in group])
+            self.assertEqual(tensor_out, expected_tensor)
+            self._barrier()
+
+        def _run_all_gather_coalesced_and_verify(
+            self, output_tensor_lists, input_tensors, expected_tensors, group_id
+        ):
+            """
+            Helper that runs all_gather_coalesced and returns true if output
+            matches expectations.
+            """
+            tensor_shapes = []
+            for input_tensor in input_tensors:
+                if input_tensor.dtype == torch.complex64:
+                    tensor_shapes.append(torch.view_as_real(input_tensor).shape)
+                else:
+                    tensor_shapes.append(input_tensor.shape)
+            self.call_dist_op(
+                ":all_gather",
+                False,
+                dist.all_gather_coalesced,
+                output_tensor_lists,
+                input_tensors,
+                group_id,
+                tensor_shapes=tensor_shapes,
+            )
+
+            for l1, l2 in zip(output_tensor_lists, expected_tensors, strict=True):
+                for t1, t2 in zip(l1, l2, strict=True):
+                    if not torch.equal(t1, t2):
+                        return False
+            return True
+
+        def _test_all_gather_coalesced_helper(
+            self, group, group_id, rank, dtype=torch.float
+        ):
+            # TODO: Instead we should probably go through _rank_not_in_group
+            # mechanism to disable sending tensors
+            if group_id is not None:
+                for test_case_id in range(2, 5):
+                    # Make sure we create tensors of incompatible sizes, e.g.
+                    # [1], [2x2], [3x3x3] ... to be sent in one batch
+                    input_tensors = [
+                        _build_multidim_tensor(
+                            tensor_id, tensor_id, rank + tensor_id, dtype=dtype
+                        )
+                        for tensor_id in range(1, test_case_id)
+                    ]
+                    output_tensor_lists = [
+                        [
+                            _build_multidim_tensor(
+                                tensor_id, tensor_id, -1, dtype=dtype
+                            )
+                            for tensor_id in range(1, test_case_id)
+                        ]
+                        for _ in group
+                    ]
+                    expected_tensors = [
+                        [
+                            _build_multidim_tensor(
+                                tensor_id, tensor_id, rank_iter + tensor_id, dtype=dtype
+                            )
+                            for tensor_id in range(1, test_case_id)
+                        ]
+                        for rank_iter in group
+                    ]
+                    assert self._run_all_gather_coalesced_and_verify(
+                        output_tensor_lists, input_tensors, expected_tensors, group_id
+                    ), "output tensors do not match expected outputs"
+
+            self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["allgather_coalesced"],
+            f"{BACKEND} does not support all_gather_coalesced",
+        )
+        def test_all_gather_coalesced_simple(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_gather_coalesced_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["allgather_coalesced"],
+            f"{BACKEND} does not support all_gather_coalesced",
+        )
+        def test_all_gather_coalesced_complex(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_gather_coalesced_helper(
+                group, group_id, rank, dtype=torch.cfloat
+            )
+
+        @skip_if_small_worldsize
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["allgather_coalesced"],
+            f"{BACKEND} does not support all_gather_coalesced",
+        )
+        def test_all_gather_coalesced_group(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_all_gather_coalesced_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["allgather_coalesced"],
+            f"{BACKEND} does not support all_gather_coalesced",
+        )
+        def test_all_gather_coalesced_full_group(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_all_gather_coalesced_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["allgather_coalesced"],
+            f"{BACKEND} does not support all_gather_coalesced",
+        )
+        def test_all_gather_coalesced_with_empty(self):
+            group, group_id, rank = self._init_global_test()
+            input_tensors = [
+                rank * torch.ones([2, 2]),
+                torch.ones([0]),
+                (rank + 1) * torch.ones([3, 3]),
+                torch.ones([0]),
+                torch.ones([0]),
+            ]
+            output_tensors_lists = [
+                [
+                    -1 * torch.ones([2, 2]),
+                    -1 * torch.ones([0]),
+                    -1 * torch.ones([3, 3]),
+                    -1 * torch.ones([0]),
+                    -1 * torch.ones([0]),
+                ]
+                for _ in group
+            ]
+            expected_tensors = [
+                [
+                    r * torch.ones([2, 2]),
+                    torch.ones([0]),
+                    (r + 1) * torch.ones([3, 3]),
+                    torch.ones([0]),
+                    torch.ones([0]),
+                ]
+                for r in group
+            ]
+            assert self._run_all_gather_coalesced_and_verify(
+                output_tensors_lists, input_tensors, expected_tensors, group_id
+            )
+            self._barrier()
+
+        # AllToAll
+        def _test_all_to_all_single_equal_split_helper(
+            self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float
+        ):
+            if group_id is not None:
+                size = len(group)
+                in_tensor = torch.ones([size, size], dtype=dtype) * rank
+                expected_tensor = torch.cat(
+                    [torch.ones([1, size], dtype=dtype) * i for i in group]
+                )
+                out_tensor = torch.ones([size, size], dtype=dtype) * -1
+                if cuda:
+                    in_tensor = in_tensor.cuda(rank_to_GPU[rank][0])
+                    expected_tensor = expected_tensor.cuda(rank_to_GPU[rank][0])
+                    out_tensor = out_tensor.cuda(rank_to_GPU[rank][0])
+                if dtype == torch.complex64:
+                    tensor_shapes = [torch.view_as_real(in_tensor).shape]
+                else:
+                    tensor_shapes = [in_tensor.shape]
+                self.call_dist_op(
+                    ":all_to_all",
+                    False,
+                    dist.all_to_all_single,
+                    out_tensor,
+                    in_tensor,
+                    group=group_id,
+                    tensor_shapes=tensor_shapes,
+                )
+                self.assertEqual(out_tensor, expected_tensor)
+            self._barrier()
+
+        def _test_all_to_all_single_unequal_split_helper(
+            self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float
+        ):
+            if group_id is not None:
+                size = len(group)
+                in_splits = [i + 1 for i in group]
+                out_splits = [rank + 1 for _ in group]
+                in_tensor = torch.ones([sum(in_splits), size], dtype=dtype) * rank
+                out_tensor = torch.ones([(rank + 1) * size, size], dtype=dtype)
+                expected_tensor = torch.cat(
+                    [torch.ones([rank + 1, size], dtype=dtype) * i for i in group]
+                )
+                if cuda:
+                    in_tensor = in_tensor.cuda(rank_to_GPU[rank][0])
+                    expected_tensor = expected_tensor.cuda(rank_to_GPU[rank][0])
+                    out_tensor = out_tensor.cuda(rank_to_GPU[rank][0])
+                dist.all_to_all_single(
+                    out_tensor, in_tensor, out_splits, in_splits, group=group_id
+                )
+                self.assertEqual(out_tensor, expected_tensor)
+            self._barrier()
+
+        def _test_all_to_all_helper(
+            self,
+            group,
+            group_id,
+            rank,
+            cuda=False,
+            rank_to_GPU=None,
+            dtype=torch.float,
+        ):
+            if group_id is not None:
+                size = len(group)
+                in_splits = [i + 1 for i in group]
+                in_tensors = [
+                    torch.ones([in_splits[i], size], dtype=dtype) * rank
+                    for i, _ in enumerate(group)
+                ]
+                out_tensors = [
+                    torch.ones([(rank + 1), size], dtype=dtype) for _ in group
+                ]
+                expected_tensors = [
+                    torch.ones([rank + 1, size], dtype=dtype) * i for i in group
+                ]
+                if cuda:
+                    in_tensors = [t.cuda(rank_to_GPU[rank][0]) for t in in_tensors]
+                    expected_tensors = [
+                        t.cuda(rank_to_GPU[rank][0]) for t in expected_tensors
+                    ]
+                    out_tensors = [t.cuda(rank_to_GPU[rank][0]) for t in out_tensors]
+                dist.all_to_all(out_tensors, in_tensors, group=group_id)
+                for t1, t2 in zip(out_tensors, expected_tensors, strict=True):
+                    self.assertEqual(t1, t2)
+            self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
+        )
+        def test_all_to_all_single_equal_split(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_to_all_single_equal_split_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
+        )
+        @skip_if_no_gpu
+        def test_all_to_all_single_equal_split_cuda(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_to_all_single_equal_split_helper(
+                group,
+                group_id,
+                rank,
+                True,
+                rank_to_GPU,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
+        )
+        def test_all_to_all_single_equal_split_complex(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_to_all_single_equal_split_helper(
+                group, group_id, rank, dtype=torch.cfloat
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
+        )
+        @skip_if_no_gpu
+        def test_all_to_all_single_equal_split_cuda_complex(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_to_all_single_equal_split_helper(
+                group, group_id, rank, True, rank_to_GPU, dtype=torch.cfloat
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
+        )
+        def test_all_to_all_single_unequal_split(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_to_all_single_unequal_split_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
+        )
+        @skip_if_no_gpu
+        def test_all_to_all_single_unequal_split_cuda(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_to_all_single_unequal_split_helper(
+                group,
+                group_id,
+                rank,
+                True,
+                rank_to_GPU,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
+        )
+        def test_all_to_all_single_unequal_split_complex(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_to_all_single_unequal_split_helper(
+                group, group_id, rank, dtype=torch.cfloat
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
+        )
+        @skip_if_no_gpu
+        def test_all_to_all_single_unequal_split_cuda_complex(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_to_all_single_unequal_split_helper(
+                group,
+                group_id,
+                rank,
+                True,
+                rank_to_GPU,
+                dtype=torch.cfloat,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi", "Only MPI supports all_to_all"
+        )
+        def test_all_to_all(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_to_all_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only NCCL supports CUDA all_to_all"
+        )
+        @skip_if_rocm_multiprocess
+        def test_all_to_all_cuda(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_to_all_helper(group, group_id, rank, True, rank_to_GPU)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi", "Only MPI supports all_to_all"
+        )
+        def test_all_to_all_complex(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_to_all_helper(group, group_id, rank, dtype=torch.cfloat)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only NCCL supports CUDA all_to_all"
+        )
+        @skip_if_rocm_multiprocess
+        def test_all_to_all_cuda_complex(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_to_all_helper(
+                group, group_id, rank, True, rank_to_GPU, dtype=torch.cfloat
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
+        )
+        @skip_if_small_worldsize
+        def test_all_to_all_single_equal_split_group(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_all_to_all_single_equal_split_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
+        )
+        @skip_if_no_gpu
+        @skip_if_small_worldsize
+        def test_all_to_all_single_equal_split_group_cuda(self):
+            group, group_id, rank = self._init_group_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_to_all_single_equal_split_helper(
+                group,
+                group_id,
+                rank,
+                True,
+                rank_to_GPU,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
+        )
+        @skip_if_small_worldsize
+        def test_all_to_all_single_unequal_split_group(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_all_to_all_single_unequal_split_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
+        )
+        @skip_if_no_gpu
+        @skip_if_small_worldsize
+        def test_all_to_all_single_unequal_split_group_cuda(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_to_all_single_unequal_split_helper(
+                group,
+                group_id,
+                rank,
+                True,
+                rank_to_GPU,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi", "Only MPI supports all_to_all"
+        )
+        @skip_if_small_worldsize
+        def test_all_to_all_group(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_all_to_all_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
+        )
+        @skip_if_small_worldsize
+        @skip_if_rocm_multiprocess
+        def test_all_to_all_group_cuda(self):
+            group, group_id, rank = self._init_group_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_to_all_helper(group, group_id, rank, True, rank_to_GPU)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
+        )
+        def test_all_to_all_single_equal_split_full_group(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_all_to_all_single_equal_split_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
+        )
+        @skip_if_no_gpu
+        def test_all_to_all_single_equal_split_full_group_cuda(self):
+            group, group_id, rank = self._init_full_group_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_to_all_single_equal_split_helper(
+                group,
+                group_id,
+                rank,
+                True,
+                rank_to_GPU,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
+        )
+        def test_all_to_all_single_unequal_split_full_group(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_all_to_all_single_unequal_split_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
+        )
+        @skip_if_no_gpu
+        def test_all_to_all_single_unequal_split_full_group_cuda(self):
+            group, group_id, rank = self._init_full_group_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_to_all_single_unequal_split_helper(
+                group,
+                group_id,
+                rank,
+                True,
+                rank_to_GPU,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi", "Only MPI supports all_to_all"
+        )
+        def test_all_to_all_full_group(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_all_to_all_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only NCCL supports CUDA all_to_all"
+        )
+        @skip_if_rocm_multiprocess
+        def test_all_to_all_full_group_cuda(self):
+            group, group_id, rank = self._init_full_group_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_to_all_helper(group, group_id, rank, True, rank_to_GPU)
+
+        # BARRIER
+        def _test_barrier_helper(
+            self, group, group_id, rank, cuda=False, rank_to_GPU=None
+        ):
+            WAIT_TIME = 0.3  # seconds
+
+            for dest in group:
+                expected_time = torch.DoubleTensor(1).fill_(0.0)
+                if cuda:
+                    expected_time = expected_time.cuda(rank_to_GPU[rank][0])
+                if dest == rank:
+                    expected_time.fill_(time.time() + WAIT_TIME)
+                    dist.broadcast(expected_time, dest, group_id)
+                    time.sleep(WAIT_TIME + 0.1)  # sleep a little bit longer
+                    dist.barrier(group_id)
+                else:
+                    dist.broadcast(expected_time, dest, group_id)
+                    dist.barrier(group_id)
+                    self.assertGreaterAlmostEqual(
+                        float(time.time()),
+                        float(expected_time[0]),
+                        msg=f"destination rank: {dest:d}, my rank: {rank:d}"
+                        + " (if you see this failure, please report in #14554)",
+                    )
+
+            # Use higher timeout for the instance where the test runs
+            # against a subgroup and uses a CUDA tensor for expected time.
+            # The CUDA initialization for the participating processes can
+            # take long enough for the barrier timeout to trigger on the
+            # process that doesn't participate in the group.
+            self._barrier(timeout=20)
+
+        @skip_if_no_gpu
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "mpi", "MPI doesn't supports GPU barrier"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally"
+        )
+        def test_barrier_cuda(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_barrier_helper(group, group_id, rank, True, rank_to_GPU)
+
+        @skip_if_small_worldsize
+        @skip_if_no_gpu
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "mpi", "MPI doesn't supports GPU barrier"
+        )
+        def test_barrier_group_cuda(self):
+            group, group_id, rank = self._init_group_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_barrier_helper(group, group_id, rank, True, rank_to_GPU)
+
+        @skip_if_small_worldsize
+        @skip_if_no_gpu
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "mpi", "MPI doesn't supports GPU barrier"
+        )
+        def test_barrier_full_group_cuda(self):
+            group, group_id, rank = self._init_full_group_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_barrier_helper(group, group_id, rank, True, rank_to_GPU)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["cpu barrier"],
+            f"{BACKEND} does not support CPU barrier",
+        )
+        def test_barrier(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_barrier_helper(group, group_id, rank)
+
+        @skip_if_small_worldsize
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["cpu barrier"],
+            f"{BACKEND} does not support CPU barrier",
+        )
+        def test_barrier_group(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_barrier_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["cpu barrier"],
+            f"{BACKEND} does not support CPU barrier",
+        )
+        def test_barrier_full_group(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_barrier_helper(group, group_id, rank)
+
+        def _model_step(self, model):
+            for param in model.parameters():
+                if param.grad is not None:
+                    with torch.no_grad():
+                        param += param.grad
+                    param.grad = None
+
+        def _model_step_with_zero_grad(self, model):
+            for param in model.parameters():
+                if param.grad is not None:
+                    with torch.no_grad():
+                        param += param.grad
+                    param.grad.requires_grad_(False)
+                    param.grad.zero_()
+
+        def _prepare_dummy_data(self, local_bs):
+            # global_bs for DDP should be divisible by WORLD_SIZE
+            world_size = int(os.environ["WORLD_SIZE"])
+            global_bs = world_size * local_bs
+            input_cpu = torch.randn(global_bs, 2)
+            target = torch.randn(global_bs, 4)
+            loss = nn.MSELoss()
+            return global_bs, input_cpu, target, loss
+
+        # END TO END TEST FOR DISTRIBUTEDDATAPARALLEL
+        def _test_DDP_helper(
+            self, model, input_var, target, loss, scale_factor=1.0, memory_format=None
+        ):
+            model.train()
+            output = model(input_var)
+            l = loss(output, target) * scale_factor
+            l.backward()
+            if memory_format is not None:
+                self.assertTrue(output.is_contiguous(memory_format=memory_format))
+
+        def _assert_equal_param(self, param_gpu, param_DDP):
+            self.assertEqual(len(param_gpu), len(param_DDP))
+            for p_gpu, p_DDP in zip(param_gpu, param_DDP, strict=True):
+                self.assertEqual(p_gpu, p_DDP)
+
+        def _test_DDP_niter(
+            self,
+            model_base,
+            model_DDP,
+            input,
+            target,
+            loss,
+            local_bs,
+            rank,
+            batch_size,
+            test_save,
+            offset=None,
+            world_size=0,
+            zero_grad=False,
+            memory_format=None,
+            n_iter=5,
+        ):
+            for idx in range(n_iter):
+                # single cpu/gpu training
+                self._test_DDP_helper(
+                    model_base, input, target, loss, memory_format=memory_format
+                )
+
+                if offset is None:
+                    offset = rank * local_bs
+
+                # DDP training, DDP scatters subsets of input_cpu to nodes/GPUs
+                self._test_DDP_helper(
+                    model_DDP,
+                    input[offset : offset + local_bs],
+                    target[offset : offset + local_bs],
+                    loss,
+                    world_size * local_bs / batch_size if world_size != 0 else 1,
+                    memory_format=memory_format,
+                )
+
+                # Update weights and run a second iteration to shake out errors
+                if zero_grad:
+                    self._model_step_with_zero_grad(model_base)
+                    self._model_step_with_zero_grad(model_DDP)
+                else:
+                    self._model_step(model_base)
+                    self._model_step(model_DDP)
+                self._assert_equal_param(
+                    list(model_base.parameters()), list(model_DDP.module.parameters())
+                )
+
+                # Shuffle the input so that DDP input is different
+                input = input[torch.randperm(batch_size)]
+
+                # save the model in the middle and reload
+                if test_save and idx == 2 and INIT_METHOD.startswith("file://"):
+                    with tempfile.NamedTemporaryFile() as tmp:
+                        if sys.platform == "win32":
+                            torch.save(model_DDP, tmp)
+                            tmp.seek(0)
+                            # weights_only=False as this is legacy code that saves the model
+                            model_DDP = torch.load(tmp, weights_only=False)
+                        else:
+                            torch.save(model_DDP, tmp.name)
+                            # weights_only=False as this is legacy code that saves the model
+                            model_DDP = torch.load(tmp.name, weights_only=False)
+
+            with tempfile.TemporaryFile() as tmp_file:
+                torch.save(model_DDP, tmp_file)
+                tmp_file.seek(0)
+                # weights_only=False as this is legacy code that saves the model
+                saved_model = torch.load(tmp_file, weights_only=False)
+            for k in model_DDP.state_dict():
+                self.assertEqual(model_DDP.state_dict()[k], saved_model.state_dict()[k])
+
+        def _test_DistributedDataParallel(
+            self,
+            gpu_subset,
+            rank,
+            output_device=None,
+            gradient_as_bucket_view=False,
+            static_graph=False,
+            set_static_graph_twice=False,
+        ):
+            # Run a simple end to end DDP model, use result of single node model
+            # as baseline
+
+            # cpu training setup
+            model = Net()
+
+            # single gpu training setup
+            model_gpu = copy.deepcopy(model)
+            model_gpu.cuda(gpu_subset[0])
+
+            # DDP training setup
+            model_DDP = copy.deepcopy(model)
+            model_DDP.cuda(gpu_subset[0])
+            model_DDP = nn.parallel.DistributedDataParallel(
+                model_DDP,
+                device_ids=gpu_subset,
+                gradient_as_bucket_view=gradient_as_bucket_view,
+                static_graph=static_graph,
+            )
+
+            if set_static_graph_twice:
+                model_DDP._set_static_graph()
+
+            # test serializable/unserializable
+            with tempfile.NamedTemporaryFile() as tmp:
+                if sys.platform == "win32":
+                    torch.save(model_DDP, tmp)
+                    tmp.seek(0)
+                    # weights_only=False as this is legacy code that saves the model
+                    model_DDP = torch.load(tmp, weights_only=False)
+                else:
+                    torch.save(model_DDP, tmp.name)
+                    # weights_only=False as this is legacy code that saves the model
+                    model_DDP = torch.load(tmp.name, weights_only=False)
+
+            # dummy data initialization
+            local_bs = len(gpu_subset)
+            global_bs, input_cpu, target, loss = self._prepare_dummy_data(local_bs)
+
+            # check two model parameters over 5 iterations
+            self._test_DDP_niter(
+                model_gpu,
+                model_DDP,
+                input_cpu.cuda(gpu_subset[0]),
+                target.cuda(gpu_subset[0]),
+                loss,
+                local_bs,
+                rank,
+                global_bs,
+                True,
+            )
+            self._barrier()
+
+        def _test_DistributedDataParallelCPU(self, gradient_as_bucket_view=False):
+            # Run a simple end to end DDP-CPU model, use result of single node
+            # model as baseline
+            _group, _group_id, rank = self._init_global_test()
+
+            # cpu training setup
+            model_base = Net()
+
+            # DDP-CPU training setup
+            model_DDP = copy.deepcopy(model_base)
+            model_DDP = nn.parallel.DistributedDataParallel(
+                model_DDP, gradient_as_bucket_view=gradient_as_bucket_view
+            )
+
+            # dummy data initialization
+            local_bs = 2
+            global_bs, input_cpu, target, loss = self._prepare_dummy_data(local_bs)
+
+            # check two model parameters over 5 iterations
+            self._test_DDP_niter(
+                model_base,
+                model_DDP,
+                input_cpu,
+                target,
+                loss,
+                local_bs,
+                rank,
+                global_bs,
+                False,
+                zero_grad=True,
+            )
+            self._barrier()
+
+            return model_DDP
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "nccl does not support DDP on CPU models"
+        )
+        def test_DistributedDataParallelCPU(self):
+            self._test_DistributedDataParallelCPU()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "nccl does not support DDP on CPU models"
+        )
+        def test_DistributedDataParallelCPU_grad_is_view(self):
+            self._test_DistributedDataParallelCPU(gradient_as_bucket_view=True)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_DistributedDataParallel_requires_grad(self):
+            # a module without gradients shouldn't be accepted
+            self.assertRaises(
+                RuntimeError, lambda: nn.parallel.DistributedDataParallel(nn.Module())
+            )
+            self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
+        def test_ddp_zero_output_features(self):
+            class ToyModel(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.net1 = nn.Linear(10, 10)
+                    self.relu = nn.ReLU()
+                    self.net2 = nn.Linear(10, 0)
+
+            model = ToyModel().to(self.rank)
+            nn.parallel.DistributedDataParallel(model, device_ids=[self.rank])
+
+        @skip_but_pass_in_sandcastle_if(BACKEND == "nccl", "Gloo-only test")
+        def test_ddp_create_graph(self):
+            class Model(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.p = nn.Parameter(torch.tensor(1.0))
+
+                def forward(self):
+                    return self.p.pow(2)
+
+            model = Model()
+            ddp_model = torch.nn.parallel.DistributedDataParallel(model)
+            for _ in range(6):
+                # Verify DDP doesn't throw when ran with create_graph=True.
+                # Although we do warn about potential issues, please see
+                # https://github.com/pytorch/pytorch/issues/63929 for details.
+                ddp_model().backward(create_graph=True)
+                # grad tensors should require grad.
+                self.assertTrue(
+                    all(param.requires_grad for param in ddp_model.parameters())
+                )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
+        def test_DistributedDataParallel_non_default_stream(self):
+            stream = torch.cuda.Stream(self.rank)
+            rank = self.rank
+            with torch.cuda.stream(stream):
+                net = torch.nn.parallel.DistributedDataParallel(
+                    torch.nn.Linear(1, 1, bias=False).cuda(rank), device_ids=[rank]
+                )
+                for i in range(1000):
+                    # Clear gradients manually
+                    grad = net.module.weight.grad
+                    if grad is not None:
+                        grad.requires_grad_(False)
+                        grad.zero_()
+                    # Forward + BW
+                    batch = torch.tensor([rank]).float().cuda(rank)
+                    loss = net(batch).sum()
+                    loss.backward()
+                    # For each worker, the gradient on the weight should be worker_rank.
+                    grad = net.module.weight.grad
+                    avg = grad.clone()
+                    # All-reducing the gradient averages should give us the gradient
+                    # average. If not, then one of the workers has not correctly
+                    # written back the averaged gradient before this all-reduce call.
+                    dist.all_reduce(avg)
+                    world_size = int(os.environ["WORLD_SIZE"])
+                    avg.div_(world_size)
+                    expected_grad = sum(i for i in range(world_size)) / world_size
+                    self.assertEqual(
+                        avg[0, 0],
+                        expected_grad,
+                        msg=f"Expected gradient of {expected_grad} but got {avg} on rank {self.rank}",
+                    )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["cuda"],
+            f"The {BACKEND} backend does not support DDP communication hook on CUDA devices",
+        )
+        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
+        def test_ddp_comm_hook_logging(self):
+            hooks = [
+                default.allreduce_hook,
+                default.fp16_compress_hook,
+                powerSGD.powerSGD_hook,
+                powerSGD.batched_powerSGD_hook,
+                quantization_hooks.quantization_pertensor_hook,
+                quantization_hooks.quantization_perchannel_hook,
+            ]
+
+            cpp_builtin_hooks = [
+                dist.BuiltinCommHookType.ALLREDUCE,
+                dist.BuiltinCommHookType.FP16_COMPRESS,
+            ]
+
+            for hook in hooks:
+                ddp_model = torch.nn.parallel.DistributedDataParallel(
+                    torch.nn.Linear(1, 1, bias=False).cuda(self.rank),
+                    device_ids=[self.rank],
+                )
+                ddp_logging_data = ddp_model._get_ddp_logging_data()
+                # Hook not registered yet, so should be empty
+                self.assertEqual(ddp_logging_data.get("comm_hook"), None)
+                ddp_model.register_comm_hook(None, hook)
+                ddp_logging_data = ddp_model._get_ddp_logging_data()
+                self.assertEqual(ddp_logging_data.get("comm_hook"), hook.__qualname__)
+
+            for hook in cpp_builtin_hooks:
+                ddp_model = torch.nn.parallel.DistributedDataParallel(
+                    torch.nn.Linear(1, 1, bias=False).cuda(self.rank),
+                    device_ids=[self.rank],
+                )
+                ddp_logging_data = ddp_model._get_ddp_logging_data()
+                # Hook not registered yet, so should be empty
+                self.assertEqual(ddp_logging_data.get("comm_hook"), None)
+                ddp_model._register_builtin_comm_hook(hook)
+                ddp_logging_data = ddp_model._get_ddp_logging_data()
+                self.assertEqual(ddp_logging_data.get("comm_hook"), str(hook))
+
+            # No hook registered
+            ddp_model = torch.nn.parallel.DistributedDataParallel(
+                torch.nn.Linear(1, 1, bias=False).cuda(self.rank),
+                device_ids=[self.rank],
+            )
+            ddp_logging_data = ddp_model._get_ddp_logging_data()
+            # Hook not registered yet, so should be empty
+            self.assertEqual(ddp_logging_data.get("comm_hook"), None)
+            # After second forward pass, hook should still be empty string
+            for _ in range(2):
+                inp = torch.ones(1, 1, device=self.rank)
+                loss = ddp_model(inp).sum()
+                loss.backward()
+
+            ddp_logging_data = ddp_model._get_ddp_logging_data()
+            # Note: DETAIL debug mode logs DDP logging data to stdout and
+            # thus accesses std::map, which fills in a default value for the
+            # type if it didn't exist.
+            self.assertEqual(ddp_logging_data.get("comm_hook", ""), "")
+
+        def _test_ddp_hook_with_optimizer_parity(
+            self,
+            grad_as_bucket_view,
+            static_graph,
+            optim_cls,
+            optimize_subset,
+            *functional_optim_args,
+            **functional_optim_kwargs,
+        ):
+            rank = self.rank
+            torch.cuda.set_device(rank)
+            torch.manual_seed(rank)
+            torch.cuda.manual_seed(rank)
+            models_to_test = [
+                (LargeNet(), torch.randn(1, 1000).cuda()),
+            ]
+            if HAS_TORCHVISION:
+                models_to_test.append(
+                    (torchvision.models.resnet50(), torch.randn(1, 3, 3, 1000).cuda())
+                )
+            for model, inp in models_to_test:
+                # Enable determinism in cudnn operators
+                with torch.backends.cudnn.flags(
+                    enabled=True, deterministic=True, benchmark=False
+                ):
+                    # Create DDP model that runs optimizer in fused fashion.
+                    ddp_model_with_optimizer_hook = (
+                        torch.nn.parallel.DistributedDataParallel(
+                            copy.deepcopy(model).cuda(),
+                            device_ids=[self.rank],
+                            gradient_as_bucket_view=grad_as_bucket_view,
+                            static_graph=static_graph,
+                        )
+                    )
+
+                    # Create DDP model with no hook that does optimizer after
+                    # backward.
+                    ddp_model_with_no_hook = torch.nn.parallel.DistributedDataParallel(
+                        copy.deepcopy(model).cuda(),
+                        device_ids=[self.rank],
+                        gradient_as_bucket_view=grad_as_bucket_view,
+                        static_graph=static_graph,
+                    )
+                    hook_params = ddp_model_with_optimizer_hook.parameters()
+                    no_hook_params = ddp_model_with_no_hook.parameters()
+                    if optimize_subset:
+                        hook_params = list(hook_params)
+                        no_hook_params = list(no_hook_params)
+                        self.assertGreater(len(hook_params), 0)
+                        hook_params = [hook_params[0]]
+                        no_hook_params = [no_hook_params[0]]
+
+                    # Register a fused optimizer that will run optimizer in step
+                    # with allreduce.
+
+                    if optimize_subset:
+                        # API where optim_params is specified.
+                        ddp_model_with_optimizer_hook._register_fused_optim(
+                            optim_cls,
+                            *functional_optim_args,
+                            optim_params=hook_params,
+                            **functional_optim_kwargs,
+                        )
+                    else:
+                        # API where optim_params is omitted
+                        ddp_model_with_optimizer_hook._register_fused_optim(
+                            optim_cls,
+                            *functional_optim_args,
+                            **functional_optim_kwargs,
+                        )
+
+                    optimizer_no_hook = optim_cls(
+                        no_hook_params,
+                        *functional_optim_args,
+                        **functional_optim_kwargs,
+                    )
+
+                    # Verify parameters are equal initially.
+                    for hook_param, allreduce_param in zip(
+                        ddp_model_with_optimizer_hook.parameters(),
+                        ddp_model_with_no_hook.parameters(),
+                        strict=True,
+                    ):
+                        self.assertEqual(hook_param, allreduce_param)
+
+                    # Save old parameters to later verify optimizer modified them.
+                    opt_hook_init_params = copy.deepcopy(
+                        list(ddp_model_with_optimizer_hook.parameters())
+                    )
+
+                    # Run optimizer with hook model.
+                    for _ in range(6):
+                        ddp_model_with_optimizer_hook.zero_grad()
+                        out = ddp_model_with_optimizer_hook(inp)
+                        loss = out.sum()
+                        loss.backward()
+
+                    dist.barrier()
+
+                    # Run regular model.
+                    for _ in range(6):
+                        ddp_model_with_no_hook.zero_grad()
+                        out = ddp_model_with_no_hook(inp)
+                        loss = out.sum()
+                        loss.backward()
+                        optimizer_no_hook.step()
+
+                    dist.barrier()
+
+                    # Now verify parameters are equal.
+                    for hook_param, allreduce_param in zip(
+                        ddp_model_with_optimizer_hook.parameters(),
+                        ddp_model_with_no_hook.parameters(),
+                        strict=True,
+                    ):
+                        self.assertEqual(hook_param, allreduce_param)
+
+                    # Verify optimizer modified appropriate parameter set,
+                    # otherwise they'd be trivially equal above.
+                    if optimize_subset:
+                        self.assertNotEqual(
+                            opt_hook_init_params[0],
+                            next(iter(ddp_model_with_optimizer_hook.parameters())),
+                        )
+                        # Untouched params should be equal
+                        self.assertEqual(
+                            opt_hook_init_params[1:],
+                            list(ddp_model_with_optimizer_hook.parameters())[1:],
+                        )
+                    else:
+                        self.assertNotEqual(
+                            opt_hook_init_params,
+                            list(ddp_model_with_optimizer_hook.parameters()),
+                        )
+                    dist.barrier()
+
+        """
+        # Commenting out the following 3 tests as they cause Sandcastle jobs to fail
+        # Failure signature:
+        # AttributeError: type object 'TestDistBackendWithSpawn' has no attribute 'test_ddp_hook_with_optimizer_parity_adamw
+
+        from torch.testing._internal.common_utils import parametrize
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl" or BACKEND == "ucc",
+            "Issues with async error handling, see https://github.com/pytorch/pytorch/issues/73259",
+        )
+        @skip_if_lt_x_gpu(2)
+        @parametrize("grad_as_bucket_view", [True, False])
+        @parametrize("static_graph", [True, False])
+        @parametrize("optimize_subset", [True, False])
+        def test_ddp_hook_with_optimizer_parity_adamw(
+            self,
+            grad_as_bucket_view,
+            static_graph,
+            optimize_subset,
+        ):
+            adamw_lr = 1e-2
+            adamw_betas = (0.9, 0.99)
+            adamw_eps = 1e-6
+            self._test_ddp_hook_with_optimizer_parity(
+                grad_as_bucket_view,
+                static_graph,
+                torch.optim.AdamW,
+                optimize_subset,
+                adamw_lr,
+                betas=adamw_betas,
+                eps=adamw_eps,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl" or BACKEND == "ucc",
+            "Issues with async error handling, see https://github.com/pytorch/pytorch/issues/73259",
+        )
+        @skip_if_lt_x_gpu(2)
+        @parametrize("optimize_subset", [True, False])
+        def test_ddp_hook_with_optimizer_parity_adam(self, optimize_subset):
+            adam_lr = 1e-2
+            adam_betas = (0.9, 0.99)
+            adam_eps = 1e-6
+            self._test_ddp_hook_with_optimizer_parity(
+                True,  # grad as bucket view
+                False,  # static graph
+                torch.optim.Adam,
+                optimize_subset,
+                adam_lr,
+                betas=adam_betas,
+                eps=adam_eps,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl" or BACKEND == "ucc",
+            "Issues with async error handling, see https://github.com/pytorch/pytorch/issues/73259",
+        )
+        @skip_if_lt_x_gpu(2)
+        @parametrize("optimize_subset", [True, False])
+        def test_ddp_hook_with_optimizer_parity_sgd(self, optimize_subset):
+            sgd_lr = 1e-2
+            sgd_momentum = 0.9
+            sgd_weight_decay = 0.01
+            # Not testing grad_as_bucket_view and static_graph as they are
+            # tested in AdamW test above.
+            self._test_ddp_hook_with_optimizer_parity(
+                True,  # grad as bucket view
+                False,  # static_graph
+                torch.optim.SGD,
+                optimize_subset,
+                sgd_lr,
+                momentum=sgd_momentum,
+                weight_decay=sgd_weight_decay,
+            )
+        """
+
+        @skip_if_lt_x_gpu(2)
+        def test_get_data_parallel_params(self):
+            torch.cuda.set_device(self.rank)
+            model = TwoLinLayerNet().cuda()
+            # Parameters to ignore are in the format {module_name}.{param_name}
+            params_to_ignore = ["a.weight"]
+            torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
+                model, params_to_ignore
+            )
+            torch.nn.parallel.DistributedDataParallel(model, device_ids=[self.rank])
+            dp_params = (
+                torch.nn.parallel.DistributedDataParallel._get_data_parallel_params(
+                    model, named_params=True
+                )
+            )
+            for name, _ in dp_params:
+                self.assertNotEqual(f"module.{params_to_ignore[0]}", name)
+
+            # test named_params=False, just check if returns the expected
+            # no of parameters.
+            num_ddp_params = len(list(model.parameters())) - 1
+            count = 0
+            dp_params = (
+                torch.nn.parallel.DistributedDataParallel._get_data_parallel_params(
+                    model, named_params=False
+                )
+            )
+            for _ in dp_params:
+                count += 1
+            self.assertEqual(count, num_ddp_params)
+
+        def _test_ddp_apply_optim_in_backward(
+            self,
+            optim_cls,
+            optim_kwargs,
+            init_before,
+            gradient_as_bucket_view=True,
+        ):
+            # Need to seed to ensure inputs are unique across rank. Otherwise,
+            # allreduce won't have any effect.
+            torch.manual_seed(self.rank)
+            torch.cuda.manual_seed(self.rank)
+            torch.cuda.set_device(self.rank)
+
+            # Test a simple linear as well as a ResNet model.
+            models_to_test = [
+                nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3), nn.Linear(3, 3)).cuda()
+            ]
+            if HAS_TORCHVISION:
+                models_to_test.append(torchvision.models.resnet50().cuda())
+
+            for j, model in enumerate(models_to_test):
+                model_optim_in_bwd = copy.deepcopy(model)
+                model = nn.parallel.DistributedDataParallel(
+                    model,
+                    device_ids=[self.rank],
+                    gradient_as_bucket_view=gradient_as_bucket_view,
+                )
+                optim = optim_cls(model.parameters(), **optim_kwargs)
+                if init_before:
+                    _apply_optimizer_in_backward(
+                        optimizer_class=optim_cls,
+                        params=model_optim_in_bwd.parameters(),
+                        optimizer_kwargs=optim_kwargs,
+                    )
+                model_optim_in_bwd = nn.parallel.DistributedDataParallel(
+                    model_optim_in_bwd,
+                    device_ids=[self.rank],
+                    gradient_as_bucket_view=gradient_as_bucket_view,
+                )
+                if not init_before:
+                    _apply_optimizer_in_backward(
+                        optimizer_class=optim_cls,
+                        params=model_optim_in_bwd.parameters(),
+                        optimizer_kwargs=optim_kwargs,
+                    )
+
+                for p1, p2 in zip(
+                    model.parameters(), model_optim_in_bwd.parameters(), strict=True
+                ):
+                    self.assertEqual(p1, p2, "Parameters not initially equal!")
+                # Enable determinism in cudnn operators
+                with torch.backends.cudnn.flags(
+                    enabled=True, deterministic=True, benchmark=False
+                ):
+                    for i in range(8):
+                        inp = (
+                            torch.randn(1, 3, 1000, 1000, device="cuda")
+                            if j == 1
+                            else torch.randn(10, 3, device="cuda")
+                        )
+                        model(inp).sum().backward()
+                        optim.step()
+                        model_optim_in_bwd(
+                            inp
+                        ).sum().backward()  # runs optimizer as well
+                        for p1, p2 in zip(
+                            model.parameters(),
+                            model_optim_in_bwd.parameters(),
+                            strict=True,
+                        ):
+                            self.assertEqual(
+                                p1, p2, f"Params not equal at iteration {i}"
+                            )
+                            self.assertTrue(
+                                p2.grad is None,
+                                f"Optim in backward grad is not None at {i}",
+                            )
+
+                        # set_to_none for regular optimizer to match in backward
+                        # case.
+                        optim.zero_grad(set_to_none=True)
+
+        @skipIfRocm
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_apply_optim_in_backward(self):
+            for optim_cls, init_before in itertools.product(
+                [torch.optim.SGD, torch.optim.Adam], [True, False]
+            ):
+                with self.subTest(optim_cls=optim_cls):
+                    self._test_ddp_apply_optim_in_backward(
+                        optim_cls=optim_cls,
+                        optim_kwargs={"lr": 0.03},
+                        init_before=init_before,
+                    )
+
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_apply_optim_in_backward_grad_as_bucket_view_false(self):
+            for init_before in [True, False]:
+                self._test_ddp_apply_optim_in_backward(
+                    optim_cls=torch.optim.SGD,
+                    optim_kwargs={"lr": 0.03},
+                    init_before=init_before,
+                    gradient_as_bucket_view=False,
+                )
+
+        @skipIfRocmArch(MI200_ARCH)
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_apply_optim_in_backward_ignored_params(self):
+            torch.cuda.set_device(self.rank)
+            for init_before in [True, False]:
+                with self.subTest(init_before=init_before):
+                    torch.manual_seed(self.rank)
+                    torch.cuda.manual_seed(self.rank)
+                    model = TwoLinLayerNet()
+                    # Parameters to ignore are in the format {module_name}.{param_name}
+                    params_to_ignore = ["a.weight"]
+                    torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
+                        model, params_to_ignore
+                    )
+                    if init_before:
+                        _apply_optimizer_in_backward(
+                            optimizer_class=torch.optim.SGD,
+                            params=model.parameters(),
+                            optimizer_kwargs={"lr": 0.03},
+                        )
+                    net = torch.nn.parallel.DistributedDataParallel(
+                        model.cuda(self.rank),
+                        device_ids=[self.rank],
+                    )
+                    if not init_before:
+                        _apply_optimizer_in_backward(
+                            optimizer_class=torch.optim.SGD,
+                            params=model.parameters(),
+                            optimizer_kwargs={"lr": 0.03},
+                        )
+                    inp = torch.randn(1, 10)
+                    a, b = net(inp)
+                    (a.transpose(0, 1) @ b).sum().backward()
+                    # a.weight did not go through allreduce, so optimizer acted on local
+                    # gradient, which should be different across ranks. Remaining params
+                    # should be equal.
+                    models = [None for _ in range(dist.get_world_size())]
+                    dist.all_gather_object(models, model)
+                    rank0_model, remainder = models[0], models[1:]
+                    for m in remainder:
+                        self.assertNotEqual(rank0_model.a.weight, m.a.weight)
+                        self.assertEqual(
+                            list(rank0_model.b.parameters()), list(m.b.parameters())
+                        )
+                        self.assertEqual(rank0_model.a.bias, m.a.bias)
+
+        def _get_fp16_config(self) -> _MixedPrecision:
+            return _MixedPrecision(
+                param_dtype=torch.float16,
+                reduce_dtype=torch.float16,
+                buffer_dtype=torch.float16,
+            )
+
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_native_mixed_precision_ignored_params(self):
+            rank = self.rank
+            torch.manual_seed(rank)
+            torch.cuda.manual_seed(rank)
+            torch.cuda.set_device(rank)
+            model = TwoLinLayerNet()
+            model.register_buffer("buffer", torch.ones(5))
+            # Parameters to ignore are in the format {module_name}.{param_name}
+            to_ignore = ["a.weight", "buffer"]
+            torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
+                model,
+                to_ignore,
+            )
+            mp_config = self._get_fp16_config()
+            net = torch.nn.parallel.DistributedDataParallel(
+                model.to(rank),
+                device_ids=[rank],
+                mixed_precision=mp_config,
+                gradient_as_bucket_view=True,
+            )
+            to_ignore = [f"module.{name}" for name in to_ignore]
+            expected_ignored = len(to_ignore)
+            n_ignored = 0
+            # ignored params should not have _mp_param or _fp_param fields.
+            for n, p in itertools.chain(net.named_parameters(), net.named_buffers()):
+                if n in to_ignore:
+                    n_ignored += 1
+                    self.assertFalse(hasattr(p, "_mp_param"))
+                    self.assertFalse(hasattr(p, "_fp_param"))
+                else:
+                    self.assertEqual(mp_config.param_dtype, p._mp_param.dtype)
+                    self.assertEqual(torch.float32, p._fp_param.dtype)
+
+            self.assertEqual(expected_ignored, n_ignored)
+
+        def _test_ddp_native_mixed_precision(
+            self, gradient_as_bucket_view, set_grad_to_none
+        ):
+            rank = self.rank
+            torch.manual_seed(rank)
+            torch.cuda.manual_seed(rank)
+            torch.cuda.set_device(rank)
+            inp = torch.randn(10, 1)
+            mp_config = self._get_fp16_config()
+
+            class MyModel(torch.nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.m = torch.nn.Linear(1, 5)
+                    self.register_buffer("buffer", torch.randn(1, 2))
+                    self.p = torch.nn.Parameter(torch.randn(10, 5), requires_grad=False)
+
+                def forward(self_, x):  # noqa: B902
+                    params = self_.m.parameters()
+                    for p in params:
+                        self.assertEqual(mp_config.param_dtype, p.dtype)
+
+                    self.assertEqual(self_.buffer.dtype, mp_config.buffer_dtype)
+
+                    self.assertEqual(mp_config.param_dtype, x.dtype)
+                    return self_.m(x) + self_.p
+
+            m = MyModel()
+
+            net = torch.nn.parallel.DistributedDataParallel(
+                m.to(rank),
+                device_ids=[rank],
+                mixed_precision=mp_config,
+                gradient_as_bucket_view=gradient_as_bucket_view,
+            )
+            # Buffers are casted in constructor.
+            self.assertEqual(net.module.buffer.dtype, mp_config.buffer_dtype)
+            # Each param should have an mp_param in the lower precision, and
+            # an fp_param in the higher precision.
+            for p in net.parameters():
+                self.assertEqual(mp_config.param_dtype, p._mp_param.dtype)
+                self.assertEqual(torch.float32, p._fp_param.dtype)
+
+            for _ in range(6):
+                loss = net(inp).sum()
+                loss.backward()
+                # Verify gradient synchronization and params and grads are fp32.
+                for n, param in net.named_parameters():
+                    self.assertEqual(param.dtype, torch.float32)
+                    if param.grad is None:
+                        assert n == "module.p"  # Only param that doesn't require grad
+                    else:
+                        self.assertEqual(param.grad.dtype, torch.float32)
+                        tensor_list = [
+                            torch.zeros_like(param.grad)
+                            for _ in range(dist.get_world_size(net.process_group))
+                        ]
+                        dist.all_gather(tensor_list, param.grad)
+                        g, rest = tensor_list[0], tensor_list[1:]
+                        self.assertEqual(g.dtype, torch.float32)
+                        for g_ in rest:
+                            self.assertEqual(g_.dtype, torch.float32)
+                            self.assertEqual(g, g_)
+                net.zero_grad(set_to_none=set_grad_to_none)
+
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_native_mixed_precision_no_grad_as_bucket_view_no_set_grad_none(
+            self,
+        ):
+            self._test_ddp_native_mixed_precision(
+                gradient_as_bucket_view=False,
+                set_grad_to_none=False,
+            )
+
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_native_mixed_precision_grad_as_bucket_view_no_set_grad_none(self):
+            self._test_ddp_native_mixed_precision(
+                gradient_as_bucket_view=True,
+                set_grad_to_none=False,
+            )
+
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_native_mixed_precision_grad_as_bucket_view_set_grad_to_none(self):
+            self._test_ddp_native_mixed_precision(
+                gradient_as_bucket_view=True, set_grad_to_none=True
+            )
+
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_native_mixed_precision_no_grad_as_bucket_view_set_grad_to_none(
+            self,
+        ):
+            self._test_ddp_native_mixed_precision(
+                gradient_as_bucket_view=True, set_grad_to_none=True
+            )
+
+        def _test_ddp_hook_parity(self, state, hook, num_validated_iters=100):
+            rank = self.rank
+            m = torch.nn.Linear(1, 5)
+            try:
+                process_group = state.process_group
+            except AttributeError:
+                process_group = state
+
+            net_with_hook = torch.nn.parallel.DistributedDataParallel(
+                copy.deepcopy(m).to(rank),
+                device_ids=[rank],
+                process_group=process_group,
+            )
+            net_with_hook.register_comm_hook(state=state, hook=hook)
+            net_without_hook = torch.nn.parallel.DistributedDataParallel(
+                copy.deepcopy(m).to(rank),
+                device_ids=[rank],
+                process_group=process_group,
+            )
+            for i in range(100):
+                # Clear gradients manually.
+                for g in [
+                    net_without_hook.module.weight.grad,
+                    net_with_hook.module.weight.grad,
+                ]:
+                    if g is not None:
+                        g.requires_grad_(False)
+                        g.zero_()
+                # Forward + BW
+                batch = torch.tensor([rank]).float().cuda(rank)
+                loss = net_without_hook(batch).sum()
+                loss.backward()
+                # For each worker, the gradient on the weight should be worker_rank.
+                grad = net_without_hook.module.weight.grad
+                avg = grad.clone()
+                expected_grad = (
+                    sum(i for i in range(dist.get_world_size())) / dist.get_world_size()
+                )
+                loss_hook = net_with_hook(batch).sum()
+                loss_hook.backward()
+                grad_hook = net_with_hook.module.weight.grad
+                avg_hook = grad_hook.clone()
+
+                if i < num_validated_iters:
+                    # Verify hook grad with expected.
+                    self.assertEqual(
+                        avg_hook[0, 0].item(),
+                        expected_grad,
+                        msg=f"Expected hook grad of {expected_grad} but got {avg_hook[0, 0]}",
+                    )
+                    # Verify hook grad with vanilla allreduce
+                    self.assertEqual(
+                        avg_hook[0, 0],
+                        avg[0, 0],
+                        msg=f"Expected hook grad to be close to allreduce {avg[0, 0]}, but got {avg_hook[0, 0]}",
+                    )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["cuda"],
+            f"The {BACKEND} backend does not support DDP communication hook on CUDA devices",
+        )
+        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
+        def test_ddp_hook_parity_allreduce(self):
+            self._test_ddp_hook_parity(state=None, hook=default.allreduce_hook)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["cuda"],
+            f"The {BACKEND} backend does not support DDP communication hook on CUDA devices",
+        )
+        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
+        def test_ddp_hook_parity_allreduce_process_group(self):
+            # process_group is passed in to both DDP and comm. hook
+            world_size = dist.get_world_size()
+            rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
+            gpus = [rank_to_GPU[int(r)][0] for r in range(world_size)]
+            process_group = torch.distributed.new_group(gpus)
+            self._test_ddp_hook_parity(state=process_group, hook=default.allreduce_hook)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["cuda"],
+            f"The {BACKEND} backend does not support DDP communication hook on CUDA devices",
+        )
+        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
+        def test_ddp_hook_parity_powerSGD(self):
+            for warm_start in [True, False]:
+                powersgd_state = powerSGD.PowerSGDState(
+                    process_group=None,
+                    matrix_approximation_rank=1,
+                    start_powerSGD_iter=2,
+                    warm_start=warm_start,
+                )
+                self._test_ddp_hook_parity(
+                    state=powersgd_state, hook=powerSGD.powerSGD_hook
+                )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["cuda"],
+            f"The {BACKEND} backend does not support DDP communication hook on CUDA devices",
+        )
+        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
+        def test_ddp_hook_parity_post_localSGD(self):
+            # Although we start run local SGD at iteration 10, since we still use the global process group to run it,
+            # the post-LocalSGD actually still allreduces gradients globally for the remaining iterations.
+            state = post_localSGD.PostLocalSGDState(
+                process_group=None, subgroup=dist.group.WORLD, start_localSGD_iter=10
+            )
+            self._test_ddp_hook_parity(
+                state=state, hook=post_localSGD.post_localSGD_hook
+            )
+            # Only validate the warmup iterations before local SGD is applied,
+            # because when `post_local_gradient_allreduce` is disabled, the gradients will not be synchronized at all.
+            # Note that in practice a model averager has to be applied to run model averaging,
+            # so local gradient averaging is not necessary.
+            start_localSGD_iter = 10
+            state = post_localSGD.PostLocalSGDState(
+                process_group=None,
+                subgroup=dist.group.WORLD,
+                start_localSGD_iter=start_localSGD_iter,
+                post_local_gradient_allreduce=False,
+            )
+            self._test_ddp_hook_parity(
+                state=state,
+                hook=post_localSGD.post_localSGD_hook,
+                num_validated_iters=start_localSGD_iter,
+            )
+
+            # When `subgroup` is None, it is equivalent to the subgroup on the each node.
+            # For this single-node test environment, the intra-node process group is equivalent to
+            # the global process group.
+            if self.world_size == dist.get_world_size():
+                state = post_localSGD.PostLocalSGDState(
+                    process_group=None, subgroup=None, start_localSGD_iter=10
+                )
+                self._test_ddp_hook_parity(
+                    state=state, hook=post_localSGD.post_localSGD_hook
+                )
+
+            # Since we start local SGD later than the total number of 100 iterations,
+            # no local SGD actually is executed, and we don't even need to provide a subgroup for this case.
+            state = post_localSGD.PostLocalSGDState(
+                process_group=None, subgroup=None, start_localSGD_iter=1000
+            )
+            self._test_ddp_hook_parity(
+                state=state, hook=post_localSGD.post_localSGD_hook
+            )
+
+        def _prepare_single_device_module(
+            self,
+            rank,
+            process_group,
+            devices,
+            device_ids,
+            global_batch_size,
+            gradient_as_bucket_view=False,
+        ):
+            model = Net()
+            device = devices[0] if devices else torch.device(f"cuda:{rank:d}")
+            ddp_model = DistributedDataParallel(
+                copy.deepcopy(model).to(device),
+                device_ids=device_ids,
+                process_group=process_group,
+                bucket_cap_mb=0.001,
+                gradient_as_bucket_view=gradient_as_bucket_view,
+            )
+
+            model.to(device)
+
+            input = torch.randn(global_batch_size, 2).to(device)
+            target = torch.randn(global_batch_size, 4).to(device)
+
+            return model, ddp_model, input, target
+
+        def _prepare_cpu_module(
+            self,
+            process_group,
+            global_batch_size,
+            gradient_as_bucket_view=False,
+        ):
+            model = Net()
+            ddp_model = DistributedDataParallel(
+                copy.deepcopy(model),
+                process_group=process_group,
+                bucket_cap_mb=0.001,
+                gradient_as_bucket_view=gradient_as_bucket_view,
+            )
+            input = torch.randn(global_batch_size, 2)
+            target = torch.randn(global_batch_size, 4)
+            return model, ddp_model, input, target
+
+        def _test_accumulate_gradients_no_sync(
+            self, num_iters=2, ddp_comm_hook=None, gradient_as_bucket_view=False
+        ):
+            """
+            This is the recommended way to implement accumulate grads.
+            If ``ddp_comm_hook`` input was specified, it will also register that hook
+            to the ``ddp_model``. The hook fed into this function should not change
+            the resulting gradients.
+            """
+            _group, group_id, rank = self._init_global_test()
+            world_size = get_world_size()
+
+            # FIXME: Add testing for gloo/CUDA
+            if BACKEND == "mpi" or BACKEND == "gloo":
+                global_batch_size = world_size
+                local_batch_size = 1
+                model, ddp_model, input, target = self._prepare_cpu_module(
+                    group_id, global_batch_size, gradient_as_bucket_view
+                )
+
+            if BACKEND == "nccl":
+                rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+                int_devices = rank_to_GPU[rank][:1]
+                devices = [torch.device("cuda:" + str(i)) for i in int_devices]
+                global_batch_size = world_size
+                local_batch_size = len(devices)
+                model, ddp_model, input, target = self._prepare_single_device_module(
+                    rank,
+                    group_id,
+                    devices,
+                    devices,
+                    global_batch_size,
+                    gradient_as_bucket_view,
+                )
+
+            if ddp_comm_hook is not None:
+                ddp_model.register_comm_hook(group_id, ddp_comm_hook)
+
+            def step_model(model, input, target):
+                model.train()
+                output = model(input)
+                loss = F.mse_loss(output, target.to(output.device))
+                loss.backward()
+
+            # ensure accumulate grads works with no_grad => no grads are accumulated.
+            with torch.no_grad():
+                with ddp_model.no_sync():
+                    ddp_model.train()
+                    ddp_model(input)
+
+            # check two model parameters over num_iters iterations
+            for iteration in range(num_iters):
+                step_model(model, input, target)
+
+                ddp_input = input[
+                    rank * local_batch_size : (rank + 1) * local_batch_size
+                ]
+                ddp_target = target[
+                    rank * local_batch_size : (rank + 1) * local_batch_size
+                ]
+
+                if iteration % 2 == 0:
+                    # accumulate grads locally
+                    with ddp_model.no_sync():
+                        step_model(ddp_model, ddp_input, ddp_target)
+                else:
+                    # sync grads
+                    step_model(ddp_model, ddp_input, ddp_target)
+
+                for i, j in zip(
+                    model.parameters(), ddp_model.parameters(), strict=True
+                ):
+                    if not i.requires_grad:
+                        continue
+                    if iteration % 2 == 0:
+                        self.assertNotEqual(i.grad, j.grad)
+                    else:
+                        self.assertEqual(i.grad, j.grad)
+
+                # Shuffle the input so that DDP input is different
+                torch.manual_seed(1337 + iteration)
+                input = input[torch.randperm(global_batch_size)]
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
+            "get_future is only supported on mpi, nccl and gloo",
+        )
+        @nccl_skip_if_lt_x_gpu(BACKEND, 2)
+        def test_accumulate_gradients_no_sync(self):
+            """
+            Runs _test_accumulate_gradients_no_sync using default inputs
+            """
+            self._test_accumulate_gradients_no_sync()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
+            "get_future is only supported on mpi, nccl and gloo",
+        )
+        @nccl_skip_if_lt_x_gpu(BACKEND, 2)
+        def test_accumulate_gradients_no_sync_grad_is_view(self):
+            """
+            Runs _test_accumulate_gradients_no_sync using default inputs
+            """
+            self._test_accumulate_gradients_no_sync(gradient_as_bucket_view=True)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
+            "get_future is only supported on mpi, nccl and gloo",
+        )
+        @nccl_skip_if_lt_x_gpu(BACKEND, 2)
+        def test_accumulate_gradients_no_sync_allreduce_hook(self):
+            """
+            Runs multiple iterations on _test_accumulate_gradients_no_sync
+            using allreduce hook and validates whether future result was properly
+            passed as gradients in reducer.
+            """
+
+            world_size = get_world_size()
+
+            def allreduce_hook(
+                group_id: object, bucket: dist.GradBucket
+            ) -> torch.futures.Future[torch.Tensor]:
+                tensors = [bucket.buffer() / world_size]
+                return (
+                    group_id.allreduce(tensors)
+                    .get_future()
+                    .then(lambda fut: fut.value()[0])
+                )
+
+            self._test_accumulate_gradients_no_sync(
+                num_iters=4, ddp_comm_hook=allreduce_hook
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
+            "get_future is only supported on mpi, nccl and gloo",
+        )
+        @nccl_skip_if_lt_x_gpu(BACKEND, 2)
+        def test_accumulate_gradients_no_sync_allreduce_with_then_hook(self):
+            """
+            Runs multiple iterations on _test_accumulate_gradients_no_sync using allreduce
+            hook that also uses then callbacks. In first then callback result is multiplied
+            by 2, and the second callback divides the result by 2 * world_size. It validates
+            whether final result was properly passed as gradients in reducer.
+            """
+
+            world_size = get_world_size()
+
+            def allreduce_with_then_hook(
+                group_id: object, bucket: dist.GradBucket
+            ) -> torch.futures.Future[torch.Tensor]:
+                fut = group_id.allreduce([bucket.buffer()]).get_future()
+
+                def mult(fut):
+                    # Multiply the result by 2.
+                    return 2 * fut.wait()[0]
+
+                def div(fut):
+                    # Divide the result by 2 * world_size.
+                    return fut.wait() / (2 * world_size)
+
+                return fut.then(mult).then(div)
+
+            self._test_accumulate_gradients_no_sync(
+                num_iters=4, ddp_comm_hook=allreduce_with_then_hook
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
+            "get_future is only supported on mpi, nccl and gloo",
+        )
+        @nccl_skip_if_lt_x_gpu(BACKEND, 2)
+        def test_get_future(self):
+            def mult(fut):
+                return [t * 3 for t in fut.wait()]
+
+            def add(fut):
+                return [t + 1 for t in fut.wait()]
+
+            group, group_id, rank = self._init_global_test()
+            input = _build_tensor(3, 2)
+            if BACKEND == "nccl":
+                rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+                device_id = rank_to_GPU[rank][0]
+                input = input.to(device_id)
+            fut = group_id.allreduce([input]).get_future()
+            res = fut.then(mult).then(add).wait()
+            expected = _build_tensor(3, 2 * len(group) * 3 + 1)
+
+            self.assertEqual(res[0], expected)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_no_gpu
+        def test_DistributedDataParallel(self):
+            _group, _group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            gpus = list(rank_to_GPU[rank])
+
+            for use_bucket_view, static_graph in itertools.product(
+                (False, True), (False, True)
+            ):
+                self._test_DistributedDataParallel(
+                    gpu_subset=gpus,
+                    rank=rank,
+                    gradient_as_bucket_view=use_bucket_view,
+                    static_graph=static_graph,
+                )
+
+                # test set static graph twice
+                self._test_DistributedDataParallel(
+                    gpu_subset=gpus,
+                    rank=rank,
+                    gradient_as_bucket_view=use_bucket_view,
+                    static_graph=static_graph,
+                    set_static_graph_twice=True,
+                )
+
+                # test output_device
+                self._test_DistributedDataParallel(
+                    gpu_subset=gpus,
+                    rank=rank,
+                    output_device=torch.device("cuda"),
+                    gradient_as_bucket_view=use_bucket_view,
+                    static_graph=static_graph,
+                )
+
+                # test device_ids
+                gpus_list = [torch.device("cuda:" + str(i)) for i in gpus]
+                self._test_DistributedDataParallel(
+                    gpu_subset=gpus_list,
+                    rank=rank,
+                    output_device=torch.device("cuda"),
+                    gradient_as_bucket_view=use_bucket_view,
+                    static_graph=static_graph,
+                )
+
+        def _test_DistributedDataParallel_with_amp(self, grad_is_view=False):
+            torch.manual_seed(31415)
+            # Creates model and optimizer in default precision
+            model = Net().cuda()
+            optimizer = torch.optim.SGD(model.parameters(), lr=0.03)
+
+            # Creates a GradScaler once at the beginning of training.
+            scaler = GradScaler()
+
+            ddp_model = nn.parallel.DistributedDataParallel(
+                model, device_ids=[self.rank], gradient_as_bucket_view=grad_is_view
+            )
+
+            input = torch.randn(dist.get_world_size() * 2, 2).cuda()
+            target = torch.randn(dist.get_world_size() * 2, 4).cuda()
+            loss_fn = nn.MSELoss()
+
+            # verify grads are none before training
+            for p in ddp_model.parameters():
+                self.assertTrue(p is not None)
+                self.assertTrue(p.grad is None)
+
+            for idx in range(20):
+                optimizer.zero_grad()
+                # Runs the forward pass with autocasting.
+                with autocast():
+                    output = ddp_model(input)
+                    loss = loss_fn(output, target)
+
+                # Scales loss.  Calls backward() on scaled loss to create scaled gradients.
+                # Backward passes under autocast are not recommended.
+                # Backward ops run in the same dtype autocast chose for corresponding forward ops.
+                scaler.scale(loss).backward()
+
+                # verify grads are not none and are valid during training
+                for p in ddp_model.parameters():
+                    if p.requires_grad:
+                        self.assertTrue(p.grad is not None)
+                        self.assertFalse(p.grad.isnan().any())
+                        self.assertFalse(p.grad.isinf().any())
+
+                # scaler.step() first unscales the gradients of the optimizer's assigned params.
+                # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
+                # otherwise, optimizer.step() is skipped.
+                scaler.step(optimizer)
+
+                # Updates the scale for next iteration.
+                scaler.update()
+
+                # Shuffle the input so that DDP input is different
+                torch.manual_seed(1337 + idx)
+                input = input[torch.randperm(dist.get_world_size() * 2)]
+
+            return ddp_model
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_no_gpu
+        def test_DistributedDataParallel_with_amp_and_grad_is_view(self):
+            torch.cuda.set_device(self.rank)
+            ddp_model_grad_not_view = self._test_DistributedDataParallel_with_amp(
+                grad_is_view=False
+            )
+            ddp_model_grad_is_view = self._test_DistributedDataParallel_with_amp(
+                grad_is_view=True
+            )
+            for i, j in zip(
+                ddp_model_grad_not_view.parameters(),
+                ddp_model_grad_is_view.parameters(),
+                strict=True,
+            ):
+                self.assertEqual(i, j)
+
+        def _test_DistributedDataParallel_SyncBatchNorm(
+            self,
+            gpu_subset,
+            rank,
+            local_bs,
+            global_bs,
+            offset,
+            output_device=None,
+            affine=True,
+        ):
+            # Run a simple end to end DDP model, use result of single node model
+            # as baseline
+
+            # cpu training setup
+            model = BatchNormNet() if affine else BatchNormNet(affine=False)
+
+            # single gpu training setup
+            model_gpu = copy.deepcopy(model)
+            model_gpu.cuda(gpu_subset[0])
+
+            # DDP training setup
+            model_DDP = nn.SyncBatchNorm.convert_sync_batchnorm(copy.deepcopy(model))
+            model_DDP.cuda(gpu_subset[0])
+            model_DDP = nn.parallel.DistributedDataParallel(
+                model_DDP, device_ids=gpu_subset
+            )
+
+            # test serializable/unserializable
+            with tempfile.NamedTemporaryFile() as tmp:
+                if sys.platform == "win32":
+                    torch.save(model_DDP, tmp)
+                    tmp.seek(0)
+                    # weights_only=False as this is legacy code that saves the model
+                    model_DDP = torch.load(tmp, weights_only=False)
+                else:
+                    torch.save(model_DDP, tmp.name)
+                    # weights_only=False as this is legacy code that saves the model
+                    model_DDP = torch.load(tmp.name, weights_only=False)
+
+            # data initialization
+            input_cpu = torch.randn(global_bs, 2)
+            target = torch.randn(global_bs, 4)
+            loss = nn.MSELoss()
+
+            # check two model parameters over 5 iterations
+            self._test_DDP_niter(
+                model_gpu,
+                model_DDP,
+                input_cpu.cuda(gpu_subset[0]),
+                target.cuda(gpu_subset[0]),
+                loss,
+                local_bs,
+                rank,
+                global_bs,
+                True,
+                offset,
+                dist.get_world_size(),
+                5 if affine else 2,
+            )
+            self._barrier()
+
+        def _test_post_localSGD_optimizer_parity(self, create_averager, grad_is_view):
+            learning_rate = 0.03
+
+            DDP_NET = Net()
+            net = torch.nn.parallel.DistributedDataParallel(
+                copy.deepcopy(DDP_NET).cuda(),
+                device_ids=[self.rank],
+                gradient_as_bucket_view=grad_is_view,
+            )
+            averager = create_averager()
+            opt = torch.optim.SGD(net.parameters(), lr=learning_rate)
+
+            net_using_post_localSGD_opt = torch.nn.parallel.DistributedDataParallel(
+                copy.deepcopy(DDP_NET).cuda(),
+                device_ids=[self.rank],
+                gradient_as_bucket_view=grad_is_view,
+            )
+            # Process group cannot be pickled in some environments,
+            # so cannot deep copy an averager. See:
+            # https://github.com/pytorch/pytorch/pull/74737#pullrequestreview-922487496
+            averager2 = create_averager()
+            post_localSGD_opt = self._create_post_localSGD_optimizer(
+                net_using_post_localSGD_opt, learning_rate, averager2
+            )
+
+            input = torch.randn(dist.get_world_size() * 2, 2).cuda()
+            target = torch.randn(dist.get_world_size() * 2, 4).cuda()
+            loss_fn = nn.MSELoss()
+
+            for _ in range(20):
+                self._perform_a_train_step(opt, net, loss_fn, input, target)
+                averager.average_parameters(net.parameters())
+
+                self._perform_a_train_step(
+                    post_localSGD_opt,
+                    net_using_post_localSGD_opt,
+                    loss_fn,
+                    input,
+                    target,
+                )
+                for p1, p2 in zip(
+                    net.parameters(),
+                    net_using_post_localSGD_opt.parameters(),
+                    strict=True,
+                ):
+                    self.assertEqual(p1.data, p2.data)
+
+            # Also check if the built-in step counters are the same to prevent a bug like #74737.
+            self.assertEqual(averager.step, averager2.step)
+
+        def _create_periodic_model_averager(self):
+            return averagers.PeriodicModelAverager(period=4, warmup_steps=10)
+
+        def _create_post_localSGD_optimizer(self, net, learning_rate, averager):
+            return post_localSGD_optimizer.PostLocalSGDOptimizer(
+                optim=torch.optim.SGD(net.parameters(), lr=learning_rate),
+                averager=averager,
+            )
+
+        def _perform_a_train_step(self, optimizer, net, loss_fn, input, target):
+            optimizer.zero_grad()
+            output = net(input)
+            loss = loss_fn(output, target)
+            loss.backward()
+            optimizer.step()
+
+        def _test_post_localSGD_optimizer_step_reload(
+            self, create_averager, chkpt_file
+        ):
+            learning_rate = 0.03
+
+            net_using_post_localSGD_opt = torch.nn.parallel.DistributedDataParallel(
+                Net().cuda(), device_ids=[self.rank]
+            )
+
+            averager = create_averager()
+            post_localSGD_opt = self._create_post_localSGD_optimizer(
+                net_using_post_localSGD_opt, learning_rate, averager
+            )
+
+            averager2 = create_averager()
+            dummy_post_localSGD_opt = self._create_post_localSGD_optimizer(
+                net_using_post_localSGD_opt, learning_rate, averager2
+            )
+
+            input = torch.randn(dist.get_world_size() * 2, 2).cuda()
+            target = torch.randn(dist.get_world_size() * 2, 4).cuda()
+            loss_fn = nn.MSELoss()
+
+            for _ in range(20):
+                self._perform_a_train_step(
+                    post_localSGD_opt,
+                    net_using_post_localSGD_opt,
+                    loss_fn,
+                    input,
+                    target,
+                )
+
+            if self.rank == 0:
+                torch.save(
+                    {"optimizer_state_dict": post_localSGD_opt.state_dict()}, chkpt_file
+                )
+
+            dist.barrier()
+            map_location = {"cuda:0": f"cuda:{self.rank:d}"}
+            checkpoint = torch.load(chkpt_file, map_location=map_location)
+            dummy_post_localSGD_opt.load_state_dict(checkpoint["optimizer_state_dict"])
+
+            # Check that we didn't hit the trivial case
+            self.assertNotEqual(averager2.step, 0)
+            # Check if dummy averager was initialized to a correct value
+            self.assertEqual(averager.step, averager2.step)
+
+            # Remove 'step' entry from a checkpoint.
+            # And make sure it is not in the state dictionary
+            del checkpoint["optimizer_state_dict"]["step"]
+            self.assertNotIn("step", checkpoint["optimizer_state_dict"])
+
+            # Check if checkpoint without a 'step' entry invokes a warning
+            with self.assertWarnsRegex(
+                expected_warning=UserWarning,
+                expected_regex="Loaded state dict does not contain a step counter for an averager. "
+                "Setting step counter to 0.",
+            ):
+                dummy_post_localSGD_opt.load_state_dict(
+                    checkpoint["optimizer_state_dict"]
+                )
+
+            self.assertEqual(averager2.step, 0)
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_post_localSGD_optimizer_parity(self):
+            torch.cuda.set_device(self.rank)
+            self._test_post_localSGD_optimizer_parity(
+                self._create_periodic_model_averager,
+                grad_is_view=False,
+            )
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_post_localSGD_optimizer_parity_grad_is_view(self):
+            torch.cuda.set_device(self.rank)
+            self._test_post_localSGD_optimizer_parity(
+                self._create_periodic_model_averager,
+                grad_is_view=True,
+            )
+
+        def _create_hierarchical_model_averager(self):
+            period_group_size_dict = OrderedDict([(2, 2), (4, dist.get_world_size())])
+            return hierarchicalSGD.HierarchicalModelAverager(
+                period_group_size_dict=period_group_size_dict, warmup_steps=4
+            )
+
+        @skip_if_lt_x_gpu(4)
+        @skip_if_odd_worldsize
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_post_localSGD_optimizer_parity_with_hierarchical_sgd(self):
+            torch.cuda.set_device(self.rank)
+            self._test_post_localSGD_optimizer_parity(
+                self._create_hierarchical_model_averager,
+                grad_is_view=False,
+            )
+
+        @skip_if_lt_x_gpu(4)
+        @skip_if_odd_worldsize
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_post_localSGD_optimizer_parity_with_hierarchical_sgd_grad_is_view(
+            self,
+        ):
+            torch.cuda.set_device(self.rank)
+            self._test_post_localSGD_optimizer_parity(
+                self._create_hierarchical_model_averager,
+                grad_is_view=True,
+            )
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_post_localSGD_optimizer_step_reload(self):
+            torch.cuda.set_device(self.rank)
+            with _rank_temp_file() as tmp_file:
+                self._test_post_localSGD_optimizer_step_reload(
+                    self._create_periodic_model_averager, tmp_file
+                )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_no_gpu
+        def test_DistributedDataParallel_SyncBatchNorm_Channels_Last(self):
+            self._test_DistributedDataParallel_SyncBatchNorm_with_memory_format(
+                torch.channels_last
+            )
+            self._test_DistributedDataParallel_SyncBatchNorm_with_memory_format(
+                torch.channels_last_3d
+            )
+
+        def _test_DistributedDataParallel_SyncBatchNorm_with_memory_format(
+            self, memory_format
+        ):
+            _group, _group_id, rank = self._init_global_test()
+            num_processes = dist.get_world_size()
+            local_bs = 2
+            bs_offset = int(rank * 2)
+            global_bs = int(num_processes * 2)
+
+            model = nn.SyncBatchNorm(2, momentum=0.99)
+            model_gpu = copy.deepcopy(model).cuda(rank)
+            model_DDP = nn.parallel.DistributedDataParallel(
+                model_gpu, device_ids=[rank]
+            )
+
+            shapes = [global_bs, 2, 4, 4] + (
+                [] if memory_format is torch.channels_last else [4]
+            )
+
+            input_gpu = (
+                torch.randn(*shapes, dtype=torch.float)
+                .cuda(rank)
+                .to(memory_format=memory_format)
+            )
+            target_gpu = (
+                torch.randn(*shapes, dtype=torch.float)
+                .cuda(rank)
+                .to(memory_format=memory_format)
+            )
+            loss = nn.MSELoss()
+
+            # check two model parameters over 5 iterations
+            self._test_DDP_niter(
+                model_gpu,
+                model_DDP,
+                input_gpu,
+                target_gpu,
+                loss,
+                local_bs,
+                rank,
+                global_bs,
+                True,
+                bs_offset,
+                dist.get_world_size(),
+                memory_format=memory_format,
+            )
+            self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_no_gpu
+        def test_DistributedDataParallel_SyncBatchNorm(self):
+            _group, _group_id, rank = self._init_global_test()
+            world_size = dist.get_world_size()
+            # DDP does not support replicating BN layers within a process, hence
+            # testing with one module replica per process
+            gpus = [rank]
+
+            local_bs = 2
+            bs_offset = int(rank * 2)
+            global_bs = int(world_size * 2)
+
+            self._test_DistributedDataParallel_SyncBatchNorm(
+                gpu_subset=gpus,
+                rank=rank,
+                local_bs=local_bs,
+                global_bs=global_bs,
+                offset=bs_offset,
+            )
+
+            # test output_device
+            self._test_DistributedDataParallel_SyncBatchNorm(
+                gpu_subset=gpus,
+                rank=rank,
+                local_bs=local_bs,
+                global_bs=global_bs,
+                offset=bs_offset,
+                output_device=torch.device("cuda"),
+            )
+
+            # test device_ids
+            gpus = [torch.device("cuda:" + str(i)) for i in gpus]
+            self._test_DistributedDataParallel_SyncBatchNorm(
+                gpu_subset=gpus,
+                rank=rank,
+                local_bs=local_bs,
+                global_bs=global_bs,
+                offset=bs_offset,
+                output_device=torch.device("cuda"),
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_no_gpu
+        def test_DistributedDataParallel_SyncBatchNorm_No_Affine(self):
+            _group, _group_id, rank = self._init_global_test()
+            world_size = dist.get_world_size()
+            # DDP does not support replicating BN layers within a process, hence
+            # testing with one module replica per process
+            gpus = [rank]
+
+            local_bs = 2
+            bs_offset = int(rank * 2)
+            global_bs = int(world_size * 2)
+
+            self._test_DistributedDataParallel_SyncBatchNorm(
+                gpu_subset=gpus,
+                rank=rank,
+                local_bs=local_bs,
+                global_bs=global_bs,
+                offset=bs_offset,
+                affine=False,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_no_gpu
+        def test_DistributedDataParallel_SyncBatchNorm_2D_Input(self):
+            _group, _group_id, rank = self._init_global_test()
+            # DDP does not support replicating BN layers within a process, hence
+            # testing with one module replica per process
+            gpus = [rank]
+
+            model = nn.BatchNorm1d(2)
+
+            # single gpu training setup
+            model_gpu = copy.deepcopy(model)
+            model_gpu.cuda(gpus[0])
+
+            # DDP training setup
+            model_DDP = nn.SyncBatchNorm.convert_sync_batchnorm(copy.deepcopy(model))
+            model_DDP.cuda(gpus[0])
+            model_DDP = nn.parallel.DistributedDataParallel(model_DDP, device_ids=gpus)
+
+            local_bs = len(gpus) * 2
+            global_bs = dist.get_world_size() * local_bs
+            input_cpu = torch.randn(global_bs, 2)
+            target = torch.randn(global_bs, 2)
+            loss = nn.MSELoss()
+
+            # disabling cudnn.
+            # SyncBatchNorm goes through native_batch_norm kernel, this avoids the
+            # numerical issue created by the divergent code path.
+            with torch.backends.cudnn.flags(False):
+                # check two model parameters over 5 iterations
+                self._test_DDP_niter(
+                    model_gpu,
+                    model_DDP,
+                    input_cpu.cuda(gpus[0]),
+                    target.cuda(gpus[0]),
+                    loss,
+                    local_bs,
+                    rank,
+                    global_bs,
+                    True,
+                )
+                self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_no_gpu
+        @require_world_size(2)
+        def test_DistributedDataParallel_SyncBatchNorm_Single_Input_Per_Process(self):
+            _group, _group_id, rank = self._init_global_test()
+            # DDP does not support replicating BN layers within a process, hence
+            # testing with one module replica per process
+            gpus = [rank]
+
+            model = nn.BatchNorm1d(2)
+
+            # single gpu training setup
+            model_gpu = copy.deepcopy(model)
+            model_gpu.cuda(gpus[0])
+
+            # DDP training setup
+            model_DDP = nn.SyncBatchNorm.convert_sync_batchnorm(copy.deepcopy(model))
+            model_DDP.cuda(gpus[0])
+            model_DDP = nn.parallel.DistributedDataParallel(model_DDP, device_ids=gpus)
+
+            local_bs = 1
+            global_bs = dist.get_world_size()
+            input_cpu = torch.randn(global_bs, 2)
+            target = torch.randn(global_bs, 2)
+            loss = nn.MSELoss()
+
+            # disabling cudnn.
+            # SyncBatchNorm goes through native_batch_norm kernel, this avoids the
+            # numerical issue created by the divergent code path.
+            with torch.backends.cudnn.flags(False):
+                # check two model parameters over 5 iterations
+                self._test_DDP_niter(
+                    model_gpu,
+                    model_DDP,
+                    input_cpu.cuda(gpus[0]),
+                    target.cuda(gpus[0]),
+                    loss,
+                    local_bs,
+                    rank,
+                    global_bs,
+                    True,
+                )
+                self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_no_gpu
+        def test_DistributedDataParallel_SyncBatchNorm_Diff_Input_Sizes_Running_Value(
+            self,
+        ):
+            ONLY_SBN_NET = nn.SyncBatchNorm(2, momentum=0.99)
+            _group, _group_id, rank = self._init_global_test()
+            model = nn.parallel.DistributedDataParallel(
+                ONLY_SBN_NET.cuda(rank), device_ids=[rank]
+            )
+
+            input_var = []
+            for i in range(dist.get_world_size()):
+                input_var_rank = torch.cat(
+                    [
+                        torch.ones(2, 1, 10 ** (i + 1)) * (0.1 ** (i - 1)),
+                        torch.ones(2, 1, 10 ** (i + 1)) * (0.3 ** (i - 1)),
+                    ],
+                    dim=1,
+                )
+                input_var.append(input_var_rank)
+
+            all_input_var = torch.cat(
+                [
+                    x.permute(1, 0, 2).contiguous().view(ONLY_SBN_NET.num_features, -1)
+                    for x in input_var
+                ],
+                dim=1,
+            ).cuda(rank)
+
+            for _ in range(100):
+                y = model(input_var[rank].cuda(rank))
+                y.mean().backward()
+
+            running_mean, running_var = (
+                model.module.running_mean,
+                model.module.running_var,
+            )
+            torch.testing.assert_close(running_mean, all_input_var.mean(1))
+            torch.testing.assert_close(running_var, all_input_var.var(1))
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_no_gpu
+        def test_DistributedDataParallel_SyncBatchNorm_Diff_Input_Sizes_gradient(self):
+            _group, _group_id, rank = self._init_global_test()
+            # only do single GPU per process
+            gpus = [rank]
+
+            # cpu training setup
+            num_processes = dist.get_world_size()
+            local_bs = rank + 2
+            bs_offset = int((rank + 3) * rank / 2)
+            global_bs = int((num_processes + 3) * num_processes / 2)
+
+            self._test_DistributedDataParallel_SyncBatchNorm(
+                gpu_subset=gpus,
+                rank=rank,
+                local_bs=local_bs,
+                global_bs=global_bs,
+                offset=bs_offset,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_no_gpu
+        def test_DistributedDataParallel_SyncBatchNorm_half(self):
+            _group, _group_id, rank = self._init_global_test()
+
+            model = BatchNormNet()
+            model = model.half()
+            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+            model = nn.parallel.DistributedDataParallel(
+                model.cuda(rank), device_ids=[rank]
+            )
+            inp = torch.randn(2, 2, dtype=torch.float16, device=torch.device(rank))
+            # Check that forward/backward do not error with dtype mismatch
+            out = model(inp)
+            self.assertEqual(out.dtype, torch.float16)
+            out.sum().backward()
+            for param in model.parameters():
+                self.assertEqual(param.grad.dtype, torch.float16)
+
+        def _test_ddp_logging_data(self, is_gpu):
+            rank = dist.get_rank()
+            model_DDP = Net()
+            if is_gpu:
+                model_DDP = nn.parallel.DistributedDataParallel(
+                    model_DDP.cuda(rank), device_ids=[rank]
+                )
+            else:
+                model_DDP = nn.parallel.DistributedDataParallel(model_DDP)
+
+            # dummy data initialization
+            local_bs = 2
+            batch_size, input, target, loss = self._prepare_dummy_data(local_bs)
+            if is_gpu:
+                input = input.cuda(rank)
+                target = target.cuda(rank)
+
+            model_DDP._set_ddp_runtime_logging_sample_rate(2)
+
+            for idx in range(20):
+                offset = rank * local_bs
+
+                # DDP training, DDP scatters subsets of input to nodes/GPUs
+                self._test_DDP_helper(
+                    model_DDP,
+                    input[offset : offset + local_bs],
+                    target[offset : offset + local_bs],
+                    loss,
+                    1,
+                )
+
+                self._model_step_with_zero_grad(model_DDP)
+
+                # Verify DDP logging data is sampled as expected
+                # If it has ran more than 10 iterations and this is
+                # the sampled iteration for measuring run time stats,
+                # the run time stats for this idx-th iteration will not
+                # be zeros.
+                ddp_logging_data = model_DDP._get_ddp_logging_data()
+                if idx > 0 and (idx < 10 or idx % 2 == 0):
+                    self.assertGreaterEqual(
+                        ddp_logging_data.get("forward_compute_time"), 1
+                    )
+                    self.assertGreaterEqual(
+                        ddp_logging_data.get("backward_compute_time"), 1
+                    )
+                    self.assertGreaterEqual(
+                        ddp_logging_data.get("backward_comm_time"), 1
+                    )
+                    self.assertGreaterEqual(
+                        ddp_logging_data.get("backward_compute_time"),
+                        ddp_logging_data.get("backward_compute_comm_overlap_time"),
+                    )
+                    self.assertGreaterEqual(
+                        ddp_logging_data.get("backward_comm_time"),
+                        ddp_logging_data.get("backward_compute_comm_overlap_time"),
+                    )
+                    self.assertEqual(ddp_logging_data.get("iteration"), idx)
+                elif idx > 0:
+                    # if the idx-th iteration is not sampled to set runtime stats,
+                    # ddp_logging_data.iteration will not be updated to current
+                    # iteration.
+                    self.assertNotEqual(ddp_logging_data.get("iteration"), idx)
+
+                # Shuffle the input so that DDP input is different
+                input = input[torch.randperm(batch_size)]
+
+            return model_DDP
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "nccl does not support DDP on CPU models"
+        )
+        def test_ddp_logging_data_cpu(self):
+            def parse_env(var):
+                return os.environ.get(var, "N/A")
+
+            dist.set_debug_level(dist.DebugLevel.INFO)
+            _, group_id, _ = self._init_global_test()
+            model_DDP = self._test_ddp_logging_data(is_gpu=False)
+
+            ddp_logging_data = model_DDP._get_ddp_logging_data()
+            self.assertEqual(ddp_logging_data.get("world_size"), dist.get_world_size())
+            self.assertEqual(ddp_logging_data.get("rank"), dist.get_rank())
+            self.assertEqual(ddp_logging_data.get("module_name"), "Net")
+            self.assertEqual(ddp_logging_data.get("device_ids"), "")
+            # output_device is -1 in default if it is not set, e.g.
+            # output_device of CPU training is -1.
+            self.assertEqual(ddp_logging_data.get("output_device"), -1)
+            self.assertEqual(ddp_logging_data.get("broadcast_buffers"), 1)
+            self.assertEqual(ddp_logging_data.get("bucket_cap_bytes"), 25 * 1024 * 1024)
+            self.assertEqual(ddp_logging_data.get("find_unused_parameters"), 0)
+            self.assertEqual(ddp_logging_data.get("gradient_as_bucket_view"), 0)
+            self.assertEqual(
+                ddp_logging_data.get("backend_name"), dist.get_backend(group_id)
+            )
+            self.assertEqual(ddp_logging_data.get("iteration"), 18)
+            params = list(model_DDP.parameters())
+            num_params = 0
+            param_size = 0
+            params = list(filter(lambda parameter: parameter.requires_grad, params))
+            for p in params:
+                num_params += 1
+                param_size += p.numel() * p.element_size()
+            self.assertEqual(ddp_logging_data.get("dtypes"), "float")
+            self.assertEqual(
+                ddp_logging_data.get("total_parameter_size_bytes"), param_size
+            )
+            self.assertEqual(ddp_logging_data.get("num_parameter_tensors"), num_params)
+            self.assertEqual(ddp_logging_data.get("bucket_sizes"), str(param_size))
+            self.assertEqual(
+                ddp_logging_data.get("master_port"), parse_env("MASTER_PORT")
+            )
+            self.assertEqual(
+                ddp_logging_data.get("master_addr"), parse_env("MASTER_ADDR")
+            )
+            self.assertEqual(
+                ddp_logging_data.get("torch_distributed_debug"),
+                parse_env("TORCH_DISTRIBUTED_DEBUG"),
+            )
+            self.assertEqual(
+                ddp_logging_data.get("cuda_visible_devices"),
+                parse_env("CUDA_VISIBLE_DEVICES"),
+            )
+            if ddp_logging_data.get("backend_name") == "gloo":
+                self.assertEqual(
+                    ddp_logging_data.get("gloo_socket_ifname"),
+                    parse_env("GLOO_SOCKET_IFNAME"),
+                )
+                self.assertEqual(
+                    ddp_logging_data.get("gloo_device_transport"),
+                    parse_env("GLOO_DEVICE_TRANSPORT"),
+                )
+                default_gloo_threads = 2
+                self.assertEqual(
+                    ddp_logging_data.get("gloo_num_threads"),
+                    default_gloo_threads,
+                )
+
+            self.assertEqual(ddp_logging_data.get("nccl_socket_ifname"), None)
+            self.assertEqual(ddp_logging_data.get("nccl_blocking_wait"), None)
+            self.assertEqual(ddp_logging_data.get("nccl_async_error_handling"), None)
+            self.assertEqual(ddp_logging_data.get("nccl_debug"), None)
+            self.assertEqual(ddp_logging_data.get("nccl_nthreads"), None)
+            self.assertEqual(ddp_logging_data.get("nccl_ib_timeout"), None)
+            # test runtime logging fields
+            # Note: DETAIL debug mode logs DDP logging data to stdout and
+            # thus accesses std::map, which fills in a default value for the
+            # type if it didn't exist.
+            self.assertEqual(ddp_logging_data.get("unused_parameter_size", 0), 0)
+            self.assertEqual(ddp_logging_data.get("has_rebuilt_buckets"), 1)
+            self.assertEqual(
+                ddp_logging_data.get("rebuilt_bucket_sizes"), str(param_size)
+            )
+            grad_ready_order = ddp_logging_data.get(
+                "prev_iteration_grad_ready_order_indices"
+            )
+            expected_order = list(reversed([str(x) for x in range(3)]))
+            self.assertEqual(grad_ready_order, ", ".join(expected_order))
+            bucket_indices = ddp_logging_data.get("rebuilt_per_bucket_param_indices")
+            self.assertEqual(bucket_indices, " ".join(expected_order))
+            # It is hard to test accurate latency, but it can test whether the latency is
+            # a valid value and in the expected range.
+            self.assertGreaterEqual(ddp_logging_data.get("avg_forward_compute_time"), 1)
+            self.assertGreaterEqual(
+                ddp_logging_data.get("avg_backward_compute_time"), 1
+            )
+            self.assertGreaterEqual(ddp_logging_data.get("avg_backward_comm_time"), 1)
+            self.assertGreaterEqual(
+                ddp_logging_data.get("avg_backward_compute_time"),
+                ddp_logging_data.get("avg_backward_compute_comm_overlap_time"),
+            )
+            self.assertGreaterEqual(
+                ddp_logging_data.get("avg_backward_comm_time"),
+                ddp_logging_data.get("avg_backward_compute_comm_overlap_time"),
+            )
+            # Test host-side times are roughly in the order that we expect
+            fwd_host_side_time = ddp_logging_data.get("forward_compute_time_start")
+            bwd_comp_start_host_side_time = ddp_logging_data.get(
+                "backward_compute_time_start"
+            )
+            bwd_comp_end_host_side_time = ddp_logging_data.get(
+                "backward_compute_time_end"
+            )
+            bwd_comm_start_host_side_time = ddp_logging_data.get(
+                "backward_comm_time_start"
+            )
+            bwd_comm_end_host_side_time = ddp_logging_data.get("backward_comm_time_end")
+            self.assertGreaterEqual(
+                bwd_comm_end_host_side_time, bwd_comm_start_host_side_time
+            )
+            self.assertGreaterEqual(
+                bwd_comm_start_host_side_time, bwd_comp_start_host_side_time
+            )
+            self.assertGreaterEqual(
+                bwd_comp_end_host_side_time, bwd_comp_start_host_side_time
+            )
+            self.assertGreaterEqual(bwd_comp_start_host_side_time, fwd_host_side_time)
+
+            # test larger net with mixed data types, verify multiple bucket sizes
+            model = LargeNet()
+            model.float()
+            model.fc1.double()
+            model_DDP = nn.parallel.DistributedDataParallel(model, bucket_cap_mb=1.5)
+            ddp_logging_data = model_DDP._get_ddp_logging_data()
+            params = list(model_DDP.parameters())
+            self.assertEqual(
+                ddp_logging_data.get("bucket_cap_bytes"), int(1.5 * 1024 * 1024)
+            )
+            bucket_sizes = [
+                params[1].numel() * params[1].element_size(),
+                params[0].numel() * params[0].element_size(),
+            ]
+            self.assertEqual(
+                ddp_logging_data.get("bucket_sizes"),
+                ", ".join(str(x) for x in bucket_sizes),
+            )
+            self.assertEqual(ddp_logging_data.get("dtypes"), "double, float")
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_no_gpu
+        def test_ddp_logging_data_gpu(self):
+            _group, _group_id, rank = self._init_global_test()
+            model_DDP = self._test_ddp_logging_data(is_gpu=True)
+            ddp_logging_data = model_DDP._get_ddp_logging_data()
+            self.assertEqual(ddp_logging_data.get("device_ids"), str(rank))
+            self.assertEqual(ddp_logging_data.get("output_device"), rank)
+            grad_ready_order = ddp_logging_data.get(
+                "prev_iteration_grad_ready_order_indices"
+            )
+            expected_order = list(reversed([str(x) for x in range(3)]))
+            self.assertEqual(grad_ready_order, ", ".join(expected_order))
+            bucket_indices = ddp_logging_data.get("rebuilt_per_bucket_param_indices")
+            self.assertEqual(bucket_indices, " ".join(expected_order))
+            # test runtime logging fields
+            # It is hard to test accurate latency, but it can test whether the latency is
+            # a valid value and in the expected range.
+            self.assertGreaterEqual(ddp_logging_data.get("avg_forward_compute_time"), 1)
+            self.assertGreaterEqual(
+                ddp_logging_data.get("avg_backward_compute_comm_overlap_time"), 1
+            )
+            self.assertGreaterEqual(
+                ddp_logging_data.get("avg_backward_compute_time"),
+                ddp_logging_data.get("avg_backward_compute_comm_overlap_time"),
+            )
+            self.assertGreaterEqual(
+                ddp_logging_data.get("avg_backward_comm_time"),
+                ddp_logging_data.get("avg_backward_compute_comm_overlap_time"),
+            )
+            # Test host-side times are roughly in the order that we expect
+            fwd_host_side_time = ddp_logging_data.get("forward_compute_time_start")
+            bwd_comp_start_host_side_time = ddp_logging_data.get(
+                "backward_compute_time_start"
+            )
+            bwd_comp_end_host_side_time = ddp_logging_data.get(
+                "backward_compute_time_end"
+            )
+            bwd_comm_start_host_side_time = ddp_logging_data.get(
+                "backward_comm_time_start"
+            )
+            bwd_comm_end_host_side_time = ddp_logging_data.get("backward_comm_time_end")
+            self.assertGreaterEqual(
+                bwd_comm_end_host_side_time, bwd_comm_start_host_side_time
+            )
+            self.assertGreaterEqual(
+                bwd_comm_start_host_side_time, bwd_comp_start_host_side_time
+            )
+            self.assertGreaterEqual(
+                bwd_comp_end_host_side_time, bwd_comp_start_host_side_time
+            )
+            self.assertGreaterEqual(bwd_comp_start_host_side_time, fwd_host_side_time)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "nccl does not support DDP on CPU models"
+        )
+        def test_static_graph_api_cpu(self):
+            model_DDP = nn.parallel.DistributedDataParallel(Net())
+            expected_err = "should be called before training loop starts"
+            with self.assertRaisesRegex(RuntimeError, expected_err):
+                local_bs = 2
+                _batch_size, input, target, loss = self._prepare_dummy_data(local_bs)
+                offset = dist.get_rank() * local_bs
+
+                # DDP training, DDP scatters subsets of input to nodes/GPUs
+                self._test_DDP_helper(
+                    model_DDP,
+                    input[offset : offset + local_bs],
+                    target[offset : offset + local_bs],
+                    loss,
+                    1,
+                )
+                model_DDP._set_static_graph()
+
+            # Verify error was logged in ddp_logging_data.
+            verify_ddp_error_logged(model_DDP, expected_err)
+
+        @skipIfNoTorchVision
+        def test_SyncBatchNorm_process_group(self):
+            # When adopting `convert_sync_batchnorm` to convert a `nn.modules`,
+            # it need to recursively pass the `process_group` in the module when the `SyncBatchNorm`
+            # is nested in a sub-module or sub-sub-module (e.g. resnet50 in torchvision.models).
+
+            process_ids = 0
+            process_group = torch.distributed.new_group([process_ids])
+            res50_model = torchvision.models.resnet50()
+            res50_model_sync = nn.SyncBatchNorm.convert_sync_batchnorm(
+                copy.deepcopy(res50_model), process_group
+            )
+            process_group_sync = res50_model_sync.layer1[0].bn1.process_group
+            self.assertEqual(process_group_sync, process_group)
+
+        def _run_reduction_test(
+            self, tensor, expected_tensor, op, reduction_fn=dist.all_reduce, dst=None
+        ):
+            if reduction_fn is not dist.all_reduce and dst is None:
+                raise ValueError(f"Reduction fn {reduction_fn} must specify dst!")
+            if dst is not None:
+                reduction_fn(tensor, dst, op)
+                # Only destination rank tensor is expected to have final result.
+                if dist.get_rank() == dst:
+                    self.assertEqual(tensor, expected_tensor)
+            else:
+                reduction_fn(tensor, op)
+                self.assertEqual(tensor, expected_tensor)
+
+        @require_backend_is_available({"nccl"})
+        @skip_if_lt_x_gpu(2)
+        def test_nccl_backend_bool_allreduce(self):
+            torch.cuda.set_device(self.rank)
+            # Run all_reduce with PRODUCT
+            element = self.rank % 2 == 0
+            for op in [dist.ReduceOp.PRODUCT, dist.ReduceOp.MIN]:
+                input_tensor = torch.tensor([element, element]).to(self.rank)
+                self._run_reduction_test(
+                    input_tensor, torch.tensor([False, False]).to(self.rank), op
+                )
+                # Ensure that all ranks contributing True (cast to 1) results in the
+                # correct reduction.
+                input_tensor = torch.tensor([True, True]).to(self.rank)
+                expected_tensor = input_tensor.clone()
+                self._run_reduction_test(input_tensor, expected_tensor, op)
+
+            # Run all_reduce with SUM
+            for op in [dist.ReduceOp.SUM, dist.ReduceOp.MAX]:
+                input_tensor = torch.tensor([element, element]).to(self.rank)
+                self._run_reduction_test(
+                    input_tensor, torch.tensor([True, True]).to(self.rank), op
+                )
+            # TODO: NCCL backend does not work correctly for bitwise reduction ops
+            # (see https://github.com/pytorch/pytorch/issues/41362). Add tests for
+            # these once it is supported.
+
+        @require_backend_is_available({"nccl"})
+        @skip_if_lt_x_gpu(2)
+        def test_nccl_backend_bool_allgather(self):
+            torch.cuda.set_device(self.rank)
+            inp = {0: [True, True], 1: [False, True]}
+            input_tensor = torch.tensor(inp[self.rank % 2]).to(self.rank)
+            # Preserve a copy of the tensor to compare against after allgather.
+            input_tensor_copy = input_tensor.clone()
+            tensor_list = [
+                torch.tensor([False, False]).to(self.rank)
+                for _ in range(dist.get_world_size())
+            ]
+            dist.all_gather(tensor_list, input_tensor)
+
+            self.assertEqual(len(tensor_list), dist.get_world_size())
+            for i, t in enumerate(tensor_list):
+                expected = torch.tensor(inp[i % 2]).to(self.rank)
+                self.assertEqual(t, expected)
+            # Ensure that the input tensor is not modified, since this collective
+            # does not modify its input.
+            self.assertEqual(input_tensor_copy, input_tensor)
+
+        @require_backend_is_available({"nccl"})
+        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
+        def test_nccl_backend_bool_reduce(self):
+            torch.cuda.set_device(self.rank)
+            inp = {0: [True, True], 1: [False, False]}
+            # Run reduce() with product op
+            for op in [dist.ReduceOp.PRODUCT, dist.ReduceOp.MIN]:
+                # make sure rank 0 gets False if WORLD_SIZE=1 to match expected tensor
+                input_tensor = torch.tensor(inp[(self.rank + 1) % 2]).to(self.rank)
+                expected = torch.tensor([False, False]).to(self.rank)
+                self._run_reduction_test(input_tensor, expected, op, dist.reduce, dst=0)
+                # Ensure that all ranks contributing True (cast to 1) results in the
+                # correct reduction.
+                input_tensor = torch.tensor([True, True]).to(self.rank)
+                expected_tensor = input_tensor.clone()
+                self._run_reduction_test(
+                    input_tensor, expected_tensor, op, dist.reduce, dst=0
+                )
+
+            for op in [dist.ReduceOp.SUM, dist.ReduceOp.MAX]:
+                input_tensor = torch.tensor(inp[self.rank % 2]).to(self.rank)
+                expected = (
+                    torch.tensor([True, True]).to(self.rank)
+                    if self.rank == 0
+                    else input_tensor.clone()
+                )
+                self._run_reduction_test(input_tensor, expected, op, dist.reduce, dst=0)
+
+        @require_backend_is_available({"nccl"})
+        @skip_if_lt_x_gpu(2)
+        def test_nccl_backend_bool_broadcast(self):
+            tensor_size = 10
+            bcast_tensor = torch.tensor(
+                [
+                    (random.random() < 0.5 if self.rank == 0 else False)
+                    for _ in range(tensor_size)
+                ]
+            ).to(self.rank)
+            dist.broadcast(bcast_tensor, src=0)
+            # Now allgather and ensure the tensors are equal.
+            tensor_list = [
+                torch.tensor([False for _ in range(tensor_size)]).to(self.rank)
+                for _ in range(dist.get_world_size())
+            ]
+            dist.all_gather(tensor_list, bcast_tensor)
+            expected = tensor_list[0]
+            for tensor in tensor_list[1:]:
+                self.assertEqual(tensor, expected)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
+        def test_DistributedSampler_padding(self):
+            # Tests padding of distributed sampler.
+            world_size = dist.get_world_size()
+
+            # Simulates the 'casual' dataset size
+            dataset_size = 100 + world_size + 1
+            dataset = [torch.ones(1).to(self.rank) * i for i in range(dataset_size)]
+
+            # Simulates the 'tiny' dataset size
+            dataset_tiny_size = max(world_size // 2 - 1, 1)
+            dataset_tiny = [
+                torch.ones(1).to(self.rank) * i for i in range(dataset_tiny_size)
+            ]
+
+            # Specifying drop_last=True will cause the tail of the data to be dropped.
+            dist_sampler = DistributedSampler(dataset=dataset, drop_last=True)
+            local_num_samples, local_dataset_size = (
+                dist_sampler.num_samples,
+                dist_sampler.total_size,
+            )
+            # The effective dataset size should be the greatest integer that is <=
+            # dataset_size that is divisible by the world_size. This is to ensure each
+            # rank processes the same number of samples.
+            effective_dataset_size = (
+                math.ceil((dataset_size - world_size) / world_size)
+                if dataset_size % world_size != 0
+                else dataset_size / world_size
+            )
+            self.assertEqual(local_num_samples, effective_dataset_size)
+            self.assertEqual(local_dataset_size, local_num_samples * world_size)
+            indices_list = list(iter(dist_sampler))
+            self.assertEqual(len(indices_list), local_num_samples)
+
+            def validate_global_samples(local_num_samples):
+                # Ensure that each rank processes the same number of samples.
+                world_samples = [
+                    torch.LongTensor([0]).to(self.rank) for _ in range(world_size)
+                ]
+                dist.all_gather(
+                    world_samples, torch.tensor([local_num_samples]).to(self.rank)
+                )
+                world_samples = [sample.item() for sample in world_samples]
+                self.assertEqual(len(set(world_samples)), 1)
+
+            validate_global_samples(local_num_samples)
+
+            # drop_last=False is the default and will add additional indices to be sampled,
+            # increasing the effective dataset size.
+            dist_sampler_added_samples = DistributedSampler(dataset=dataset)
+            local_num_samples, local_dataset_size = (
+                dist_sampler_added_samples.num_samples,
+                dist_sampler_added_samples.total_size,
+            )
+            # The effective dataset size is the smallest integer that is >= dataset_size
+            # and divisible by the world size.
+            self.assertEqual(local_num_samples, math.ceil(dataset_size / world_size))
+            self.assertEqual(local_dataset_size, local_num_samples * world_size)
+            indices_list = list(iter(dist_sampler_added_samples))
+            self.assertEqual(len(indices_list), local_num_samples)
+
+            # Ensure that each rank processes the same number of samples.
+            validate_global_samples(local_num_samples)
+
+            # Ensure additional samples are padded even when
+            # the extremely small dataset is given.
+            dist_sampler_added_samples_tiny = DistributedSampler(dataset=dataset_tiny)
+            local_num_samples, local_dataset_size = (
+                dist_sampler_added_samples_tiny.num_samples,
+                dist_sampler_added_samples_tiny.total_size,
+            )
+            self.assertEqual(
+                local_num_samples, math.ceil(dataset_tiny_size / world_size)
+            )
+            self.assertEqual(local_dataset_size, local_num_samples * world_size)
+            indices_list = list(iter(dist_sampler_added_samples_tiny))
+            self.assertEqual(len(indices_list), local_num_samples)
+            validate_global_samples(local_num_samples)
+
+        def _test_allgather_object(self, subgroup=None):
+            # Only set device for NCCL backend since it must use GPUs.
+
+            gather_objects = create_collectives_object_test_list()
+
+            backend = os.environ["BACKEND"]
+            if backend == "nccl":
+                # Case where rank != GPU device.
+                next_rank = (self.rank + 1) % int(self.world_size)
+                torch.cuda.set_device(next_rank)
+
+            # If GPU test, add object with GPU tensor
+            if backend == "nccl":
+                gather_objects.append(Foo(torch.randn(3, 3, device=0)))
+
+            output_gathered = [None for _ in range(dist.get_world_size())]
+            dist.all_gather_object(
+                output_gathered,
+                gather_objects[self.rank % len(gather_objects)],
+                group=subgroup,
+            )
+
+            for i, val in enumerate(output_gathered):
+                expected = gather_objects[i % len(gather_objects)]
+                self.assertEqual(val, expected)
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @require_n_gpus_for_nccl_backend(
+            int(os.environ["WORLD_SIZE"]), os.environ["BACKEND"]
+        )
+        @with_dist_debug_levels(levels=["OFF", "INFO", "DETAIL"])
+        def test_all_gather_object_default_pg(self):
+            return self._test_allgather_object()
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @require_n_gpus_for_nccl_backend(
+            int(os.environ["WORLD_SIZE"]), os.environ["BACKEND"]
+        )
+        @with_dist_debug_levels(levels=["DETAIL", "OFF", "INFO"])
+        def test_all_gather_object_subgroup(self):
+            default = _get_default_group()
+            backend = dist.get_backend(default)
+            subgroup = dist.new_group(backend=backend)
+            return self._test_allgather_object(subgroup=subgroup)
+
+        def _test_gather_object(self, pg=None):
+            # Ensure stateful objects can be gathered
+            gather_objects = create_collectives_object_test_list()
+            my_rank = dist.get_rank(pg)
+
+            backend = os.environ["BACKEND"]
+            if backend == "nccl":
+                # Case where rank != GPU device.
+                next_rank = (self.rank + 1) % int(self.world_size)
+                torch.cuda.set_device(next_rank)
+
+            # If GPU test, add object with GPU tensor
+            if backend == "nccl":
+                gather_objects.append(Foo(torch.randn(3, 3, device=my_rank)))
+
+            output_gathered = [None for _ in range(dist.get_world_size(pg))]
+            gather_on_rank = 0
+            dist.gather_object(
+                gather_objects[self.rank % len(gather_objects)],
+                object_gather_list=output_gathered
+                if my_rank == gather_on_rank
+                else None,
+                dst=gather_on_rank,
+                group=pg,
+            )
+            if my_rank != gather_on_rank:
+                self.assertEqual(
+                    output_gathered, [None for _ in range(dist.get_world_size())]
+                )
+            else:
+                for i, val in enumerate(output_gathered):
+                    expected = gather_objects[i % len(gather_objects)]
+                    self.assertEqual(val, expected)
+
+            # Validate errors when objects can't be pickled.
+            class Bar:
+                pass
+
+            b = Bar()
+            gather_objects = [b for _ in range(dist.get_world_size())]
+            with self.assertRaises(AttributeError):
+                dist.all_gather_object(
+                    [None for _ in range(dist.get_world_size())],
+                    gather_objects[self.rank],
+                    group=pg,
+                )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
+        )
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @with_dist_debug_levels(levels=["DETAIL", "OFF", "INFO"])
+        @require_exact_world_size(4)
+        def test_gather_object(self):
+            return self._test_gather_object()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
+        )
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @with_dist_debug_levels(levels=["DETAIL", "OFF", "INFO"])
+        @require_exact_world_size(4)
+        def test_gather_object_subgroup(self):
+            default = _get_default_group()
+            backend = dist.get_backend(default)
+            subgroup = dist.new_group(backend=backend)
+            return self._test_gather_object(subgroup)
+
+        def validate_net_equivalence(self, net):
+            # Helper to validate synchronization of nets across ranks.
+            net_module_states = list(net.module.state_dict().values())
+            # Check that all tensors in module's state_dict() are equal.
+            for t in net_module_states:
+                tensor_list = [
+                    torch.zeros_like(t) for _ in range(dist.get_world_size())
+                ]
+                dist.all_gather(tensor_list, t)
+                for tensor in tensor_list:
+                    self.assertEqual(tensor, t)
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_sync_module_states(self):
+            # Test that after calling _sync_module_states, models across ranks
+            # are the same and are equal to the model on the input rank.
+            dim = 2
+            rank = self.rank
+            rank_to_broadcast = 1
+            # Seed to ensure that ranks are initialized with different initial models.
+            torch.manual_seed(rank)
+            model = nn.Linear(dim, dim, bias=False)
+            net = torch.nn.parallel.DistributedDataParallel(
+                model.cuda(rank), device_ids=[self.rank], bucket_cap_mb=1
+            )
+            new_model = nn.Linear(dim, dim, bias=False).cuda(rank)
+            net.module = copy.deepcopy(new_model)
+            # Assert params are different
+            net_module_states = list(net.module.state_dict().values())
+            for t in net_module_states:
+                tensor_list = [
+                    torch.zeros_like(t) for _ in range(dist.get_world_size())
+                ]
+                dist.all_gather(tensor_list, t)
+                for i, tensor in enumerate(tensor_list):
+                    if i == rank:
+                        self.assertEqual(t, tensor)
+                    else:
+                        # tensor from another rank should be different.
+                        self.assertNotEqual(t, tensor)
+
+            _sync_module_states(
+                module=net.module,
+                process_group=net.process_group,
+                broadcast_bucket_size=net.broadcast_bucket_size,
+                src=rank_to_broadcast,
+                params_and_buffers_to_ignore=net.parameters_to_ignore,
+            )
+            # Now all model params should be the same.
+            self.validate_net_equivalence(net)
+            # Since the network params were broadcast from rank_to_broadcast, validate that
+            # they are the same as new_model on rank_to_broadcast.
+            if rank == rank_to_broadcast:
+                expected_states = new_model.state_dict().values()
+                for t, expected in zip(net_module_states, expected_states, strict=True):
+                    self.assertEqual(t, expected)
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_grad_div_uneven_inputs(self):
+            # Test gradient division during training with join() API. If
+            # divide_by_initial_world_size=False, we scale by the effective world
+            # size when allreducing grads.
+            dim = 5
+            batch = 1
+            grad_scale = 50
+            rank = self.rank
+            model = nn.Linear(dim, dim, bias=False)
+            inp = torch.ones(batch, dim, device=self.rank) * grad_scale
+            net = torch.nn.parallel.DistributedDataParallel(
+                model.cuda(rank), device_ids=[self.rank], bucket_cap_mb=1
+            )
+            n_iters = 3
+            if self.rank > 0:
+                n_iters += 2
+
+            with net.join(divide_by_initial_world_size=False):
+                for _ in range(n_iters):
+                    loss = net(inp).sum()
+                    loss.backward()
+                    # The grad is always expected_grad, since we divide by the number
+                    # of currently active processes and inactive processes contribute
+                    # zero gradient. If we kept dividing by static initial world
+                    # size as processes leave, the grad would be smaller.
+                    expected_grad = torch.ones(dim, dim, device=self.rank) * grad_scale
+                    param = next(iter(net.parameters()))
+                    self.assertEqual(expected_grad, param.grad)
+                    # Avoid accumulating grads so that it's the same every iteration
+                    net.zero_grad()
+                    torch.cuda.synchronize(device=self.rank)
+
+            # If divide_by_initial_world_size=True (default), we always scale grads
+            # by the initial world_size.
+            with net.join(divide_by_initial_world_size=True):
+                for i in range(n_iters):
+                    loss = net(inp).sum()
+                    loss.backward()
+                    effective_ws = dist.get_world_size()
+                    if i >= 3:
+                        effective_ws -= 1
+                    expected_grad = (
+                        torch.ones(dim, dim, device=self.rank)
+                        * grad_scale
+                        * effective_ws
+                    ) / dist.get_world_size()
+                    param = next(iter(net.parameters()))
+                    self.assertEqual(expected_grad, param.grad)
+                    # Avoid accumulating grad so that it's the same every iteration.
+                    net.zero_grad()
+                    torch.cuda.synchronize(device=self.rank)
+
+        def _test_ddp_profiling(self, profiler_ctx, profiler_ctx2=None):
+            """Runs DDP based model training and captures profiles.
+            This test will do two profiler runs.
+            1. An initial basic run to check if profiler events are correctly captured.
+            2. A second profiling pass after running some iterations of DDP, to check robustness of thread local state.
+
+            args
+                profiler_ctx : Profiler context manager for pass 1
+                profiler_ctx2 : Profiler context manager for pass 2.
+                    This can be left out as None, in which case a deepcopy
+                    of profiler_ctx is used.
+            Returns:
+                prof: Instantiated profiler object that can be used for post analysis.
+            """
+            batch = 3
+            dim = 10
+            num_iters = 6
+            torch.cuda.set_device(self.rank)
+            model = nn.Linear(dim, dim, bias=False)
+            inp = torch.rand(batch, dim, device=self.rank)
+            net = torch.nn.parallel.DistributedDataParallel(
+                model.cuda(self.rank),
+                device_ids=[self.rank],
+            )
+            if profiler_ctx2 is None:
+                profiler_ctx2 = copy.deepcopy(profiler_ctx)
+
+            with profiler_ctx as prof:
+                for _ in range(num_iters):
+                    loss = net(inp).sum()
+                    loss.backward()
+
+            all_reduce_event_name = f"{dist.get_backend()}:all_reduce"
+            events = get_profiling_event(
+                all_reduce_event_name, prof, dedup_gpu_user_annotation=True
+            )
+            event_count = sum(e.count for e in events)
+            self.assertEqual(event_count, num_iters)
+            for event in events:
+                self.assertTrue(event.is_async)
+                self.assertEqual(event.name, all_reduce_event_name)
+
+            broadcast_event_name = f"{dist.get_backend()}:broadcast"
+            broadcast_events = get_profiling_event(
+                broadcast_event_name, prof, dedup_gpu_user_annotation=True
+            )
+            event_count = sum(e.count for e in broadcast_events)
+            # Broadcast is called during rebuild_buckets
+            self.assertGreaterEqual(event_count, 1)
+            for event in broadcast_events:
+                self.assertEqual(event.name, broadcast_event_name)
+
+            # Run DDP with profiling for a few iterations, then enable profiling
+            # for a single pass, and ensure it is recorded. This tests that the
+            # thread local state is correctly updated.
+            net = torch.nn.parallel.DistributedDataParallel(
+                model.cuda(self.rank),
+                device_ids=[self.rank],
+                find_unused_parameters=True,
+            )
+            for _ in range(3):
+                loss = net(inp).sum()
+                loss.backward()
+            # Now enable the profiler.
+            with profiler_ctx2 as prof:
+                loss = net(inp).sum()
+                loss.backward()
+
+            events = get_profiling_event(
+                all_reduce_event_name, prof, dedup_gpu_user_annotation=True
+            )
+            self.assertGreaterEqual(len(events), 1)
+            self.assertGreaterEqual(events[0].count, 1)
+            self.assertEqual(events[0].name, all_reduce_event_name)
+            for event in events:
+                self.assertTrue(event.is_async)
+            # Ensure searching unused parameters was profiled
+            events = get_profiling_event("search_unused_parameters", prof)
+            self.assertEqual(len(events), 1)
+
+            return prof
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle("Currently failing in NVIDIA internal CI")
+        def test_ddp_profiling_autograd_profiler(self):
+            autograd_profiler_ctx = torch.autograd.profiler.profile()
+            return self._test_ddp_profiling(profiler_ctx=autograd_profiler_ctx)
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(IS_FBCODE, "Kineto in fbcode code causes hang")
+        @skip_but_pass_in_sandcastle_if(
+            IS_MACOS or IS_WINDOWS,
+            "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124",
+        )
+        def test_ddp_profiling_torch_profiler(self):
+            cpu_act = torch.profiler.ProfilerActivity.CPU
+            cuda_act = torch.profiler.ProfilerActivity.CUDA
+            torch_profiler_ctx = torch.profiler.profile(activities=[cpu_act, cuda_act])
+            prof = self._test_ddp_profiling(profiler_ctx=torch_profiler_ctx)
+
+            if dist.get_backend() != "nccl":
+                return
+
+            # Note comment out the "os.remove(trace_file)" in `get_profiler_nccl_meta()`
+            # to debug any mismatches.
+            nccl_meta_events = get_profiler_nccl_meta(prof)
+            self.assertGreater(len(nccl_meta_events), 0)
+
+            nccl_meta = self._sanity_check_profiler_nccl_meta(nccl_meta_events)
+
+            # additionally check the specific collectives in this test case
+            self.assertEqual(len(nccl_meta["allreduce"]), 2)
+            self.assertEqual(len(nccl_meta["wait"]), 1)
+
+            # check allreduce message sizes
+            a0 = nccl_meta["allreduce"][0]
+            self.assertEqual(a0["Out msg nelems"], 100, msg=f"{a0}")
+            self.assertEqual(a0["dtype"], "Float", msg=f"{a0}")
+            a1 = nccl_meta["allreduce"][1]
+            self.assertEqual(a1["Out msg nelems"], 1, msg=f"{a1}")
+            self.assertEqual(a1["dtype"], "Int", msg=f"{a1}")
+
+        def _validate_execution_trace_nccl(self, et_file: str) -> None:
+            """Torch profiler includes nccl metadata in an inserted operator called "record_param_comms"
+            We test for basic fields in these nodes in the Execution Trace.
+            """
+            with open(et_file) as f:
+                et = json.load(f)
+            pg_cfg_node = [
+                n for n in et["nodes"] if n["name"] == "## process_group:init ##"
+            ]
+            self.assertGreaterEqual(len(pg_cfg_node), 1)
+            nccl_meta_nodes = [
+                n for n in et["nodes"] if n["name"] == "record_param_comms"
+            ]
+            self.assertEqual(len(nccl_meta_nodes), 3)
+            per_coll_meta = defaultdict(list)
+
+            # Sanity check NCCL metadata nodes
+            for n in nccl_meta_nodes:
+                attrs_list = n.get("attrs", [])
+                self.assertGreater(len(attrs_list), 0)
+                attrs = {a["name"]: a["value"] for a in attrs_list}
+
+                collname = attrs.get("collective_name", "")
+                self.assertNotEqual(collname, "")
+                self.assertNotEqual(attrs.get("dtype", ""), "")
+
+                per_coll_meta[collname].append(attrs)
+                if collname == "wait":
+                    continue
+
+                self.assertEqual(attrs["pg_name"], "0")  # yes this is a string
+                self.assertEqual(attrs["pg_desc"], "default_pg")
+                self.assertEqual(attrs["pg_size"], 2)
+
+                self.assertGreaterEqual(attrs.get("in_msg_nelems", -1), 0)
+                self.assertGreaterEqual(attrs.get("out_msg_nelems", -1), 0)
+                self.assertTrue("in_split_size" in attrs)
+                self.assertTrue("out_split_size" in attrs)
+                self.assertEqual(attrs.get("global_rank_start", -1), 0)
+                self.assertEqual(attrs.get("global_rank_stride", -1), 1)
+
+            # print(per_coll_meta)
+            self.assertEqual(len(per_coll_meta["allreduce"]), 2)
+            self.assertEqual(len(per_coll_meta["wait"]), 1)
+
+            # check allreduce message sizes
+            a0 = per_coll_meta["allreduce"][0]
+            self.assertEqual(a0["out_msg_nelems"], 100, msg=f"{a0}")
+            self.assertEqual(a0["dtype"], "Float", msg=f"{a0}")
+            a1 = per_coll_meta["allreduce"][1]
+            self.assertEqual(a1["out_msg_nelems"], 1, msg=f"{a1}")
+            self.assertEqual(a1["dtype"], "Int", msg=f"{a1}")
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(IS_FBCODE, "Kineto in fbcode code causes hang")
+        @skip_but_pass_in_sandcastle_if(
+            IS_MACOS or IS_WINDOWS,
+            "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124",
+        )
+        @unittest.skipIf(BACKEND != "nccl", "Tests nccl metadata primarily.")
+        def test_ddp_profiling_execution_trace(self):
+            self.assertEqual(dist.get_backend(), "nccl")
+            # Create a temp file to save execution trace data
+            with TemporaryFileName("w+t", suffix=".et.json") as et_file:
+                et = ExecutionTraceObserver().register_callback(et_file)
+
+                # first profiler context need not have ET
+                torch_profiler_ctx1 = torch.profiler.profile(
+                    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+                )
+                # collect ET in second profiler pass
+                torch_profiler_ctx2 = torch.profiler.profile(
+                    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+                    execution_trace_observer=et,
+                )
+                self._test_ddp_profiling(
+                    profiler_ctx=torch_profiler_ctx1,
+                    profiler_ctx2=torch_profiler_ctx2,
+                )
+
+                print(f"Execution trace saved at {et_file}")
+                self._validate_execution_trace_nccl(et_file)
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_join_model_equivalence(self):
+            # Verifies equivalence with model training locally and with DDP under
+            # the join context manager.
+            batch = 3
+            dim = 10
+            learning_rate = 0.03
+            model = nn.Linear(dim, dim, bias=False)
+            inp = torch.rand(batch, dim, device=self.rank)
+            local_model = copy.deepcopy(model)
+            local_model = local_model.cuda(self.rank)
+            rank_to_iter_mapping = {
+                rank: 2 * (rank + 1) for rank in range(dist.get_world_size())
+            }
+            # run local model
+            local_iters = sum(rank_to_iter_mapping.values())
+            local_optim = torch.optim.SGD(local_model.parameters(), lr=learning_rate)
+            for _ in range(local_iters):
+                local_optim.zero_grad()
+                out = local_model(inp)
+                loss = out.sum()
+                loss.backward()
+                local_optim.step()
+
+            # run DDP model with join API
+            num_iters = rank_to_iter_mapping[self.rank]
+            net = torch.nn.parallel.DistributedDataParallel(
+                model.cuda(self.rank), device_ids=[self.rank]
+            )
+            ddp_optim = torch.optim.SGD(
+                model.parameters(), lr=learning_rate * dist.get_world_size()
+            )
+            with net.join():
+                for _ in range(num_iters):
+                    ddp_optim.zero_grad()
+                    out = net(inp)
+                    loss = out.sum()
+                    loss.backward()
+                    torch.cuda.synchronize(device=self.rank)
+                    ddp_optim.step()
+
+            # Validate model state dicts are equal
+            for (_, local_tensor), (_, dist_tensor) in zip(
+                local_model.state_dict().items(),
+                net.module.state_dict().items(),
+                strict=True,
+            ):
+                self.assertEqual(local_tensor, dist_tensor)
+
+        def _run_uneven_inputs_test(
+            self,
+            test_case,
+            iteration_mapping,
+            find_unused_params,
+        ):
+            model = test_case.model
+            inp = test_case.inp
+            rank = self.rank
+            sync_interval = test_case.sync_interval
+            torch.cuda.set_device(rank)
+            # Ensure all outstanding GPU work is completed so this test runs independently.
+            dist.barrier()
+            # Bucket_cap_mb is intentionally low to test allreduce scheduling when
+            # there are many buckets.
+            net = torch.nn.parallel.DistributedDataParallel(
+                model.cuda(rank),
+                device_ids=[rank],
+                bucket_cap_mb=1,
+                find_unused_parameters=find_unused_params,
+            )
+            # Register hook if specified
+            if test_case.hook is not None:
+                net.register_comm_hook(test_case.state, test_case.hook)
+                print(f"registered hook {test_case.hook}")
+
+            # Determine num iters for this rank via the passed in mapping.
+            num_iters = iteration_mapping[rank]
+            # If we throw when earliest rank terminates, we should ensure
+            # that we iterate for that minimum number of times.
+            num_iters_tensor = torch.tensor(
+                [num_iters], device=torch.cuda.current_device()
+            )
+            dist.all_reduce(num_iters_tensor, op=dist.ReduceOp.MIN)
+            min_num_iters = num_iters_tensor.item()
+            total_iters = 0
+            if test_case.throw_on_early_termination:
+                if min_num_iters == num_iters:
+                    # Early termination rank(s)
+                    exception_ctx = self.assertRaisesRegex(
+                        RuntimeError, f"Rank {self.rank} exhausted all inputs"
+                    )
+                else:
+                    # Non early termination rank
+                    exception_ctx = self.assertRaisesRegex(
+                        RuntimeError,
+                        "Detected at least one rank that exhausted inputs.",
+                    )
+            else:
+                exception_ctx = nullcontext()
+            with exception_ctx:
+                with net.join(
+                    throw_on_early_termination=test_case.throw_on_early_termination
+                ):
+                    for i in range(num_iters):
+                        # Use model.no_sync() to disable grad synchronization every
+                        # sync_interval.
+                        if i % sync_interval != 0:
+                            context = net.no_sync()
+                        else:
+                            context = nullcontext()
+                        with context:
+                            if isinstance(inp, tuple):
+                                loss = net(*inp).sum()
+                            else:
+                                loss = net(inp).sum()
+                            loss.backward()
+                            self._model_step(net)
+                            # Ensure completion of GPU kernels (including allreduce). If the
+                            # join API is not properly implemented, then this should hang
+                            # since the allreduce will hang.
+                            torch.cuda.synchronize(device=rank)
+                        total_iters += 1
+            if test_case.throw_on_early_termination:
+                # Ensure we iterated min_num_iters times.
+                self.assertEqual(total_iters, min_num_iters)
+            else:
+                # Ensure we iterated at least min_num_iters times.
+                self.assertGreaterEqual(total_iters, min_num_iters)
+
+            # Ensure completion of all GPU kernels.
+            torch.cuda.synchronize(device=rank)
+            # When throwing on early rank termination, we do not
+            # broadcast model state from an authoritative rank. All models
+            # should already be in sync.
+            if not test_case.throw_on_early_termination:
+                self.assertTrue(net._authoritative_rank)
+                # All ranks should have agreed on the same authoritative_rank!
+                final_rank_tensor = torch.tensor(
+                    [net._authoritative_rank], device=self.rank
+                )
+                tensor_list = [
+                    torch.zeros_like(final_rank_tensor)
+                    for _ in range(dist.get_world_size())
+                ]
+                dist.all_gather(tensor_list, final_rank_tensor)
+                max_rank = dist.get_world_size() - 1
+                self.assertSetEqual(
+                    {max_rank}, {tensor.item() for tensor in tensor_list}
+                )
+                # Ensure that all models are the same across ranks after all have joined.
+                self.validate_net_equivalence(net)
+                # Ensure that running with DDP uneven inputs was logged.
+                ddp_logging_data = net._get_ddp_logging_data()
+                self.assertTrue(ddp_logging_data.get("join_uneven_inputs"))
+                dist.barrier()
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_uneven_inputs_stop_iteration_sync_bn(self):
+            # Tests that uneven inputs join handler correctly throws StopIteration
+            # for models with SyncBN or general collective comm when
+            # throw_on_early_termination=True.
+            class ModelWithComm(torch.nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.lin = nn.Linear(2, 40, bias=False)
+
+                def forward(self, x):
+                    x = self.lin(x)
+                    dist.all_reduce(x)
+                    return x
+
+            torch.cuda.set_device(self.rank)
+            model_bn = BatchNormNet()
+            model_bn = nn.SyncBatchNorm.convert_sync_batchnorm(
+                copy.deepcopy(model_bn)
+            ).cuda(self.rank)
+            comm_model = ModelWithComm().cuda(self.rank)
+            model_input = torch.randn(10, 2).cuda(torch.cuda.current_device())
+
+            for model in [model_bn, comm_model]:
+                model = torch.nn.parallel.DistributedDataParallel(
+                    model,
+                    device_ids=[self.rank],
+                )
+                min_num_iters = 5
+                if self.rank != 0:
+                    # Early termination rank(s)
+                    num_iters = min_num_iters
+                    exception_ctx = self.assertRaisesRegex(
+                        RuntimeError, f"Rank {self.rank} exhausted all inputs"
+                    )
+                else:
+                    # Non early termination rank
+                    num_iters = min_num_iters * 2
+                    exception_ctx = self.assertRaisesRegex(
+                        RuntimeError,
+                        "Detected at least one rank that exhausted inputs.",
+                    )
+                n = 0
+                with exception_ctx:
+                    with model.join(throw_on_early_termination=True):
+                        for _ in range(num_iters):
+                            loss = model(model_input).sum()
+                            loss.backward()
+                            self._model_step(model)
+                            n += 1
+
+                self.assertEqual(n, min_num_iters)
+                # Verify model equivalence
+                self.validate_net_equivalence(model)
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_uneven_inputs(self):
+            dim = 1000
+            batch = 1
+            # Create a variety of models to run uneven input tests on.
+            large_model = nn.Sequential(
+                nn.Conv2d(1, 20, 5),
+                nn.ReLU(),
+                nn.Conv2d(20, 32, 5),
+                nn.ReLU(),
+                nn.Conv2d(32, 256, 5),
+                nn.ReLU(),
+            )
+            small_model = nn.Linear(dim, dim, bias=False)
+            bn_net = BatchNormNet()
+
+            class UnusedParamModule(nn.Module):
+                def __init__(self, unused_params_rank):
+                    super().__init__()
+                    self.t0 = Task()
+                    self.t1 = Task()
+                    self.unused_params_rank = unused_params_rank
+
+                def task_parameters(self):
+                    return (self.t0.p, self.t1.p)
+
+                def forward(self, x, rank):
+                    return (
+                        self.t1(self.t0(x))
+                        if rank != self.unused_params_rank
+                        else self.t1(x)
+                    )
+
+            unjoined_rank_with_unused_params_model = UnusedParamModule(1)
+            joined_rank_with_unused_params_model = UnusedParamModule(0)
+
+            rank = self.rank
+            models_to_test = [
+                # Network with batchnorm
+                DDPUnevenTestInput(
+                    name="batch_norm_net",
+                    model=bn_net,
+                    inp=torch.ones(batch, 2, device=rank),
+                    sync_interval=1,
+                ),
+                DDPUnevenTestInput(
+                    name="large_conv_model",
+                    model=large_model,
+                    inp=torch.ones(batch, batch, dim, dim, device=rank),
+                    sync_interval=1,
+                ),
+                DDPUnevenTestInput(
+                    name="small_model",
+                    model=small_model,
+                    inp=torch.ones(batch, dim, device=rank),
+                    sync_interval=1,
+                ),
+                # Unused parameter test where rank that does not join early has unused params
+                DDPUnevenTestInput(
+                    name="unjoined_rank_with_unused_params_model",
+                    model=unjoined_rank_with_unused_params_model,
+                    inp=(torch.ones(batch, 2, device=rank), rank),
+                    sync_interval=1,
+                ),
+                # Unused parameter test where rank that does join early has unused params
+                DDPUnevenTestInput(
+                    name="joined_rank_with_unused_params_model",
+                    model=joined_rank_with_unused_params_model,
+                    inp=(torch.ones(batch, 2, device=rank), rank),
+                    sync_interval=1,
+                ),
+            ]
+
+            # Test models that have hook installed.
+            models_with_hook = [
+                DDPUnevenTestInput(
+                    name="small_model_allreduce_hook",
+                    model=small_model,
+                    hook=default.allreduce_hook,
+                    state=None,
+                    inp=torch.ones(batch, dim, device=rank),
+                    sync_interval=1,
+                ),
+                DDPUnevenTestInput(
+                    name="small_model_power_sgd_hook",
+                    model=small_model,
+                    hook=powerSGD.powerSGD_hook,
+                    state=powerSGD.PowerSGDState(
+                        process_group=None,
+                        matrix_approximation_rank=1,
+                        # Config so that powerSGD runs immediately instead of
+                        # allreduce.
+                        start_powerSGD_iter=1,
+                        warm_start=False,
+                        use_error_feedback=False,
+                    ),
+                    inp=torch.ones(batch, dim, device=rank),
+                    sync_interval=1,
+                ),
+            ]
+            models_to_test.extend(models_with_hook)
+
+            # Add resnet model if we have torchvision installed.
+            if HAS_TORCHVISION:
+                resnet_model = torchvision.models.resnet50()
+                models_to_test.append(
+                    DDPUnevenTestInput(
+                        name="resnet_model",
+                        model=resnet_model,
+                        inp=torch.ones(1, 3, 1000, 1000),
+                        sync_interval=1,
+                    )
+                )
+
+            # Test with no_sync every 2, 3, 4, ... iterations.
+            models_with_sync = []
+            for i, test_input in enumerate(models_to_test):
+                models_with_sync.append(
+                    DDPUnevenTestInput(
+                        name=test_input.name,
+                        model=test_input.model,
+                        inp=test_input.inp,
+                        sync_interval=i + 2,
+                    )
+                )
+
+            throw_on_early_term_tests = []
+            for test_input in models_to_test:
+                throw_on_early_term_tests.append(
+                    DDPUnevenTestInput(
+                        name=test_input.name,
+                        model=test_input.model,
+                        inp=test_input.inp,
+                        sync_interval=test_input.sync_interval,
+                        throw_on_early_termination=True,
+                    )
+                )
+
+            models_to_test.extend(models_with_sync)
+            models_to_test.extend(throw_on_early_term_tests)
+
+            # 0 iteration tests for when one process does not train model at all, so
+            # we must shadow the broadcast calls made when rebuilding buckets.
+            baseline_num_iters = [0, 5]
+            iteration_offsets = [2, 3, 10]
+            num_uneven_ranks = [1]
+            if dist.get_world_size() > 2:
+                num_uneven_ranks.append(2)
+            iteration_mappings = []
+            # Generate rank : num_iters mappings for various uneven input scenarios.
+            # This includes cases where rank 0 joins early and all other ranks join
+            # later, and scenarios where multiple ranks join early, but at different
+            # iterations, and later ranks join later.
+            for num_early_join_ranks in num_uneven_ranks:
+                for baseline_iter in baseline_num_iters:
+                    for offset in iteration_offsets:
+                        mapping = dict.fromkeys(
+                            range(num_early_join_ranks), baseline_iter
+                        )
+                        # if num_early_join_ranks > 1, ranks > 0 that will join early
+                        # iterate offset//2 more times than rank 0, to test nodes
+                        # depleting inputs at different times.
+                        if num_early_join_ranks > 1:
+                            for rank in mapping:
+                                if rank > 0:
+                                    mapping[rank] += offset // 2
+                        mapping.update(
+                            dict.fromkeys(
+                                range(num_early_join_ranks, dist.get_world_size()),
+                                baseline_iter + offset,
+                            )
+                        )
+                        iteration_mappings.append(mapping)
+
+            for test_case, iteration_mapping in itertools.product(
+                models_to_test, iteration_mappings
+            ):
+                if self.rank == 0:
+                    print(
+                        f"""Running test: {test_case.name} sync interval
+                        {test_case.sync_interval} with iteration mapping
+                        {iteration_mapping}"""
+                    )
+                self._run_uneven_inputs_test(
+                    test_case,
+                    iteration_mapping,
+                    find_unused_params=("unused_params_model" in test_case.name),
+                )
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_uneven_input_join_disable(self):
+            # tests that if net.join() with enable=False is specified, DDP works as
+            # expected with even inputs.
+            torch.manual_seed(self.rank)
+            net = torch.nn.parallel.DistributedDataParallel(
+                torch.nn.Linear(1, 1).cuda(self.rank), device_ids=[self.rank]
+            )
+            inp = torch.ones(1) * self.rank
+            n_iters = 5
+            world_size = dist.get_world_size()
+            with net.join(enable=False):
+                for _ in range(n_iters):
+                    # Clear grads
+                    grad = net.module.weight.grad
+                    if grad is not None:
+                        grad.requires_grad_(False)
+                        grad.zero_()
+                    out = net(inp)
+                    loss = out.sum()
+                    loss.backward()
+                    # Validate gradients to ensure that we divide by the correct
+                    # world_size when join mode is disabled.
+                    expected_grad = sum(i for i in range(world_size)) / world_size
+                    self.assertEqual(net.module.weight.grad.item(), expected_grad)
+
+            join_config = net._join_config
+            self.assertFalse(join_config.enable)
+            self.validate_net_equivalence(net)
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_uneven_input_exception(self):
+            # Tests that exceptions during training are correctly propagated by the
+            # context manager.
+            error_str = "Intentional error"
+
+            class ExceptionModule(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.param = nn.Parameter(torch.ones(1, requires_grad=True))
+
+                def forward(self, _):
+                    raise ValueError(error_str)
+
+            exception_module = ExceptionModule()
+            net = torch.nn.parallel.DistributedDataParallel(
+                exception_module.cuda(self.rank), device_ids=[self.rank]
+            )
+            inp = torch.ones(1)
+            with self.assertRaisesRegex(ValueError, error_str):
+                with net.join():
+                    out = net(inp)
+                    loss = out.sum()
+                    loss.backward()
+
+        def _test_broadcast_object_list(self, group=None):
+            gather_objects = create_collectives_object_test_list()
+
+            # Only set device for NCCL backend since it must use GPUs.
+            # Case where rank != GPU device.
+            next_rank = (self.rank + 1) % int(self.world_size)
+            backend = os.environ["BACKEND"]
+            if backend == "nccl":
+                torch.cuda.set_device(next_rank)
+
+            src_rank = 0
+            # If GPU test, add object with GPU tensor
+            if backend == "nccl":
+                gather_objects.append(Foo(torch.randn(3, 3, device=0)))
+
+            if IS_FBCODE:
+                # Create Tensor with > 2^31 Bytes storage requirements
+                # Only on FBCODE as testing OOMs in OSS
+                gather_objects.append(Foo(torch.randn(3, 178956971)))
+            objects = (
+                gather_objects
+                if self.rank == src_rank
+                else [None for _ in gather_objects]
+            )
+
+            # Single object test with device specified. Backend="gloo", device=cpu
+            if backend != "nccl":
+                single_obj_list = [objects[0]]
+                if self.rank != src_rank:
+                    self.assertNotEqual(single_obj_list[0], gather_objects[0])
+                dist.broadcast_object_list(
+                    single_obj_list, src=0, group=group, device=torch.device("cpu")
+                )
+                self.assertEqual(single_obj_list[0], gather_objects[0])
+
+            # Single object test with device specified. Backend="gloo", device=current_device+1
+            # The test is gated by the fact GPU count is the same as world size to avoid the case
+            # when backend is gloo but there is no multiple GPU devices.
+            if backend != "nccl" and torch.cuda.device_count() == int(self.world_size):
+                single_obj_list = [objects[0]]
+                if self.rank != src_rank:
+                    self.assertNotEqual(single_obj_list[0], gather_objects[0])
+                dist.broadcast_object_list(
+                    single_obj_list, src=0, group=group, device=torch.device(next_rank)
+                )
+                self.assertEqual(single_obj_list[0], gather_objects[0])
+
+            # Single object test with device specified. Backend="nccl", device=current_device+1
+            if backend == "nccl" and torch.cuda.device_count() == int(self.world_size):
+                single_obj_list = [objects[0]]
+                if self.rank != src_rank:
+                    self.assertNotEqual(single_obj_list[0], gather_objects[0])
+                dist.broadcast_object_list(
+                    single_obj_list, src=0, group=group, device=torch.device(next_rank)
+                )
+                self.assertEqual(single_obj_list[0], gather_objects[0])
+
+            # Single object test: backward compatibility with device unspecified
+            single_obj_list = [objects[0]]
+            if self.rank != src_rank:
+                self.assertNotEqual(single_obj_list[0], gather_objects[0])
+            dist.broadcast_object_list(single_obj_list, src=0, group=group)
+            self.assertEqual(single_obj_list[0], gather_objects[0])
+
+            # Multiple input objects test
+            if self.rank != src_rank:
+                self.assertNotEqual(objects, gather_objects)
+            dist.broadcast_object_list(objects, src=0, group=group)
+            self.assertEqual(objects, gather_objects)
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @require_n_gpus_for_nccl_backend(
+            int(os.environ["WORLD_SIZE"]), os.environ["BACKEND"]
+        )
+        @with_dist_debug_levels(levels=["DETAIL"])
+        @unittest.skip(
+            "Test is failing, see https://github.com/pytorch/pytorch/pull/113620"
+        )
+        def test_broadcast_object_list(self):
+            return self._test_broadcast_object_list()
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @require_n_gpus_for_nccl_backend(
+            int(os.environ["WORLD_SIZE"]), os.environ["BACKEND"]
+        )
+        @with_dist_debug_levels(levels=["DETAIL"])
+        def _test_broadcast_object_list_subgroup(self):
+            default = _get_default_group()
+            backend = dist.get_backend(default)
+            subgroup = dist.new_group(backend=backend)
+            return self._test_broadcast_object_list(subgroup)
+
+        def _test_ddp_ignore_params_arg(self, static_graph=False):
+            class TestModel(nn.Module):
+                def __init__(self, rank):
+                    self.rank = rank
+                    super().__init__()
+                    self.fc1 = nn.Linear(1, 1, bias=False)
+                    # Proxy that will be materialized to another architecture later.
+                    # (after wrapping model with DDP)
+                    if self.rank == 0:
+                        self.fc2 = nn.Linear(1, 10, bias=False)
+                    else:
+                        self.fc2 = nn.Linear(10, 10, bias=False)
+
+                def forward(self, x):
+                    x = self.fc1(x)
+                    x = self.fc2(x)
+                    return x
+
+            device_id = self.rank
+            # Ensure the test works for both find_unused_parameter and broadcast_buffer settings.
+            for find_unused, broadcast_buffers in itertools.product(
+                [False, True], [False, True]
+            ):
+                model = TestModel(self.rank).float().to(device_id)
+                # Note that the model can have different shape buffers if we pass
+                # them in to be ignored as well.
+                model.fc2.register_buffer(
+                    "ignore_buffer", torch.zeros(5 + self.rank, device=self.rank)
+                )
+                proxy_params = list(model.fc2.parameters())
+                model_fc2_name = next(
+                    module_name
+                    for module_name, module in model.named_modules()
+                    if module is model.fc2
+                )
+                proxy_param_names = [
+                    f"{model_fc2_name}.{param_name}"
+                    for param_name, _ in model.fc2.named_parameters()
+                ]
+                proxy_buffer_names = [
+                    f"{model_fc2_name}.{buf_name}"
+                    for buf_name, _ in model.fc2.named_buffers()
+                ]
+                # Specify that we should ignore proxy_params since it will be
+                # materialized later.
+                torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
+                    model, proxy_param_names + proxy_buffer_names
+                )
+                ddp = torch.nn.parallel.DistributedDataParallel(
+                    model,
+                    device_ids=[device_id],
+                    find_unused_parameters=find_unused,
+                    broadcast_buffers=broadcast_buffers,
+                    static_graph=static_graph,
+                )
+                # Materialize new params. These are not registered in DDP and thus
+                # don't have autograd hooks installed on them.
+                ddp.module.fc2 = nn.Linear(1, 1, bias=False).to(device_id)
+
+                # local model with the new materialized parameters.
+                local_model = copy.deepcopy(ddp.module).cuda(self.rank)
+
+                inp = torch.ones(1, dtype=torch.float).to(device_id) * (self.rank + 1)
+                for _ in range(6):
+                    ddp(inp).sum().backward()
+
+                    local_model(inp).sum().backward()
+                    # materialized param grad is not touched by DDP, so its grad should
+                    # be the same as if running locally.
+                    for materialized_param, local_param in zip(
+                        ddp.module.fc2.parameters(),
+                        local_model.fc2.parameters(),
+                        strict=True,
+                    ):
+                        self.assertEqual(materialized_param.grad, local_param.grad)
+
+                    # fc1 parameter grad should still be different, due to allreduce.
+                    for synced_param, local_param in zip(
+                        ddp.module.fc1.parameters(),
+                        local_model.fc1.parameters(),
+                        strict=True,
+                    ):
+                        self.assertFalse(synced_param.grad == local_param.grad)
+
+                    # Proxy module grad should not be touched
+                    for proxy_param in proxy_params:
+                        self.assertTrue(proxy_param.grad is None)
+
+                # Synchronize since we run multiple iterations of this test, to
+                # isolate failure hangs.
+                torch.cuda.synchronize(device=self.rank)
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_ignore_params_arg(self):
+            self._test_ddp_ignore_params_arg(static_graph=False)
+            self._test_ddp_ignore_params_arg(static_graph=True)
+
+        @with_dist_debug_levels(levels=["OFF", "INFO", "DETAIL"])
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_unused_params_rebuild_buckets_exception(self):
+            class ToyModel(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.net1 = nn.Linear(10, 10, bias=False)
+                    self.net2 = nn.Linear(10, 10, bias=False)
+
+                def forward(self, x):
+                    return self.net1(x)
+
+            ddp = torch.nn.parallel.DistributedDataParallel(
+                ToyModel().cuda(self.rank), device_ids=[self.rank]
+            )
+            for i in range(2):
+                inp = torch.rand(1, 10)
+                if i > 0:
+                    # On 2nd iteration, this will fail during rebuild_buckets,
+                    # but we should report an error regarding unused parameters
+                    # since that is the underlying root cause.
+                    try:
+                        ddp(inp).sum().backward()
+                    except RuntimeError as e:
+                        msg = str(e)
+                        verify_ddp_error_logged(ddp, msg)
+                        expected_strs = [
+                            ddp_prev_reduction_unfinished_str,
+                            ddp_recommend_find_unused_params_str,
+                            ddp_outputs_not_used_in_loss_str,
+                        ]
+                        # In debug mode, should show parameters that weren't reduced.
+                        # Without debug mode, should show suggestion to use debug mode.
+                        if dist.get_debug_level() == dist.DebugLevel.OFF:
+                            expected_strs.append(ddp_suggest_debug_mode_str)
+                        else:
+                            unreduced_params = ", ".join(["net2.weight"])
+                            expected_strs.append(
+                                f"did not receive grad for rank {self.rank}: {unreduced_params}"
+                            )
+                        for s in expected_strs:
+                            self.assertTrue(s in msg, f"Expected {s} to be in {msg}")
+                        self.assertFalse(ddp_find_unused_params_enabled_str in msg)
+                    else:
+                        self.assertFalse(
+                            True, "DDP unused parameters error not raised."
+                        )
+                else:
+                    ddp(inp).sum().backward()
+
+            dist.barrier()
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_shared_grad_acc_unused_params(self):
+            # When find_unused_parameters=True, ensure we mark unused parameters
+            # even if they share gradient accumulators.
+            class ToyModel(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    # net1, bias, and net1.bias are all unused params.
+                    self.net1 = nn.Linear(10, 5, bias=False)
+                    self.bias = nn.Parameter(torch.zeros(5))
+                    # net1.bias and self.bias are names for the same underlying
+                    # parameter, so they share the same grad acc. This caused
+                    # the bug reported in https://github.com/pytorch/pytorch/issues/41324.
+                    self.net1.bias = self.bias
+                    self.net2 = nn.Linear(10, 5)
+
+                def forward(self, x):
+                    return self.net2(x).sum()
+
+            torch.cuda.set_device(self.rank)
+            model = ToyModel().to(torch.cuda.current_device())
+            for static in [True, False]:
+                ddp_model = torch.nn.parallel.DistributedDataParallel(
+                    copy.deepcopy(model),
+                    device_ids=[self.rank],
+                    find_unused_parameters=True,
+                    static_graph=static,
+                )
+                inp = torch.randn(20, 10, device=self.rank)
+                for _ in range(6):
+                    loss = ddp_model(inp)
+                    # To test https://github.com/pytorch/pytorch/issues/61982
+                    loss /= 10
+                    loss.backward()
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_device(self):
+            expected_len = 2
+
+            class TensorWrapper:
+                __slots__ = ["t", "moved_to_gpu"]
+
+                def __init__(self, t):
+                    self.t = t
+                    self.moved_to_gpu = False
+
+            # Handlers for specific types of validation we want to do based on
+            # the input type.
+
+            def tuple_and_list_validator(x):
+                self.assertTrue(len(x), expected_len)
+                self.assertEqual(1, len({t.device for t in x}))
+                self.assertEqual(x[0].device.index, self.rank)
+                return x[0] + x[1]
+
+            def namedtuple_validator(x):
+                self.assertEqual(x._fields, EXPECTED_FIELDS)
+                self.assertEqual(x.a.device.index, x.b.device.index)
+                self.assertEqual(x.a.device.index, self.rank)
+                return x.a + x.b
+
+            def custom_type_validator(x):
+                self.assertTrue(x.moved_to_gpu or (str(x.t.device) == "cpu"))
+                x.t = x.t.to(self.rank)
+                x.moved_to_gpu = True
+                return x.t
+
+            def dict_validator(x):
+                self.assertTrue(EXPECTED_FIELDS[0] in x)
+                self.assertTrue(EXPECTED_FIELDS[1] in x)
+                self.assertEqual(1, len({t.device for t in x.values()}))
+                self.assertEqual(x[EXPECTED_FIELDS[0]].device.index, self.rank)
+                return x[EXPECTED_FIELDS[0]] + x[EXPECTED_FIELDS[1]]
+
+            validators = {
+                TensorWrapper: custom_type_validator,
+                tuple: tuple_and_list_validator,
+                list: tuple_and_list_validator,
+                TestNamedTupleInput_0: namedtuple_validator,
+                TestNamedTupleInput_1: namedtuple_validator,
+                dict: dict_validator,
+            }
+
+            class ToyModel(torch.nn.Module):
+                def __init__(self_):  # noqa: B902
+                    super().__init__()
+                    self_.lin = nn.Linear(10, 10, bias=False)
+
+                def forward(self_, x, expected_type):  # noqa: B902
+                    # Similar to scatter, the recursive to in the single-device
+                    # case does not move tensors if they are in a custom type.
+                    self.assertTrue(isinstance(x, expected_type))
+                    fwd_tensor = validators[expected_type](x)
+                    return self_.lin(fwd_tensor)
+
+            model = torch.nn.parallel.DistributedDataParallel(
+                ToyModel().to(self.rank), device_ids=[self.rank]
+            )
+
+            def train_iter(inp, input_type):
+                for _ in range(4):
+                    out = model(inp, input_type)
+                    out.sum().backward()
+
+            # CPU tuple input, should be moved to the proper device before call
+            # to forward.
+            inp = tuple(torch.randn(10, 10) for _ in range(expected_len))
+            train_iter(inp, tuple)
+
+            # List CPU input, should be moved to proper device before call to
+            # forward.
+            inp = [torch.randn(10, 10) for _ in range(expected_len)]
+            train_iter(inp, list)
+            # Custom type containing tensor. The type is maintained, but the
+            # device is not propagated (which is what happens with scatter too)
+            inp = TensorWrapper(torch.randn(10, 10))
+            train_iter(inp, TensorWrapper)
+            # NamedTuple input. The type should be maintained and tensor inputs
+            # should be moved to the correct device as in scatter.
+            batch = 5
+            dim = 10
+            a = torch.rand(batch, dim)
+            b = torch.rand(batch, dim)
+
+            inp = TestNamedTupleInput_0(a, b)
+            train_iter(inp, type(inp))
+
+            inp = TestNamedTupleInput_1(a, b)
+            train_iter(inp, type(inp))
+
+            # dictionary input.
+            inp = {
+                EXPECTED_FIELDS[0]: a,
+                EXPECTED_FIELDS[1]: b,
+            }
+            train_iter(inp, type(inp))
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_namedtuple(self):
+            batch = 5
+            dim = 10
+
+            a = torch.rand(batch, dim, device=self.rank)
+            b = torch.rand(batch, dim, device=self.rank)
+
+            class NamedTupleModule(torch.nn.Module):
+                def __init__(self_):  # noqa: B902
+                    super().__init__()
+                    self_.lin = nn.Linear(10, 1)
+
+                def forward(self_, input, expected_type):  # noqa: B902
+                    # Without NamedTuple support, this would be of type tuple.
+                    self.assertTrue(
+                        isinstance(input, expected_type),
+                        f"Expected type {expected_type} but got {type(input)}",
+                    )
+                    self.assertEqual(input._fields, EXPECTED_FIELDS)
+                    self.assertEqual(a, input.a)
+                    self.assertEqual(b, input.b)
+                    return self_.lin(torch.mul(input.a, input.b))
+
+            model = torch.nn.parallel.DistributedDataParallel(
+                NamedTupleModule().cuda(self.rank), device_ids=[self.rank]
+            )
+            inp = TestNamedTupleInput_0(a, b)
+            # The following would fail if DDP does not propagate NamedTuples correctly.
+            model(inp, type(inp))
+
+            inp = TestNamedTupleInput_1(a, b)
+            model(inp, type(inp))
+
+        @require_backend_is_available({"gloo"})
+        def test_grads_same_across_ranks_with_no_sync(self):
+            _group, _group_id, rank = self._init_global_test()
+            world_size = dist.get_world_size()
+            if world_size < 2:
+                self.skipTest("This test requires at least two ranks.")
+
+            class SimpleConditionalModel(nn.Module):
+                # if rank is 0, uses nn1 on the first pass and nn2 on the second pass.
+                # else, uses nn3 on the first pass and nn4 on the second pass.
+
+                def __init__(self, rank):
+                    super().__init__()
+
+                    self.rank = rank
+                    self.nn1 = nn.Linear(1, 1)
+                    self.nn2 = nn.Linear(1, 1)
+                    self.nn3 = nn.Linear(1, 1)
+                    self.nn4 = nn.Linear(1, 1)
+                    self.state = 0
+
+                def forward(self, input):
+                    if self.state == 0:
+                        self.state = 1
+                        if self.rank == 0:
+                            return self.nn1(input)
+                        else:
+                            return self.nn3(input)
+                    else:
+                        self.state = 0
+                        if self.rank == 0:
+                            return self.nn2(input)
+                        else:
+                            return self.nn4(input)
+
+            model = torch.nn.parallel.DistributedDataParallel(
+                SimpleConditionalModel(rank), find_unused_parameters=True
+            )
+            mse_loss = nn.MSELoss()
+            grad_accumulation = 2
+
+            for microbatch_idx in range(grad_accumulation):
+                if microbatch_idx < grad_accumulation - 1:
+                    context = model.no_sync
+                else:
+                    context = nullcontext
+
+                with context():
+                    input = torch.rand((1,))
+                    output = model.forward(input)
+                    target = torch.rand((1,))
+
+                    loss = mse_loss(output, target)
+                    loss.backward()
+
+            self.assertTrue(
+                not any(p.grad is None for p in model.parameters()),
+                "Gradients can't be None for any model parameter.",
+            )
+            grads = torch.cat([p.grad.view(-1) for p in model.parameters()])
+
+            # Gather all gradients to rank 0.
+            if rank == 0:
+                gathered_grads = [torch.zeros_like(grads) for _ in range(world_size)]
+            else:
+                gathered_grads = []
+
+            dist.gather(grads, gather_list=gathered_grads, dst=0)
+            if rank == 0:
+                for g in gathered_grads[1:]:
+                    self.assertTrue(
+                        torch.allclose(gathered_grads[0], g),
+                        "Gradients are not the same for all ranks.",
+                    )
+
+        @with_dist_debug_levels(levels=["OFF", "INFO", "DETAIL"])
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_control_flow_same_across_ranks(self):
+            # Control flow that is the same across ranks.
+            batch = 20
+            dim = 10
+
+            world_size = dist.get_world_size()
+            torch.cuda.set_device(self.rank)
+            model = torch.nn.parallel.DistributedDataParallel(
+                ControlFlowToyModel().cuda(self.rank),
+                device_ids=[self.rank],
+                find_unused_parameters=True,
+            )
+            random_input = torch.randn(batch, dim, device=self.rank)
+            ones_input = torch.ones(batch, dim, device=self.rank)
+            for i in range(6):
+                if i % 2 == 0:
+                    out = model(random_input)
+                else:
+                    out = model(ones_input)
+                loss = out.sum()
+                loss.backward()
+                # On even iterations, 2nd param goes unused, on odd iterations,
+                # it is used.
+                local_used_map = model.reducer._get_local_used_map()
+                if i % 2 == 0:
+                    expected = torch.tensor(
+                        [world_size, 0], device=self.rank, dtype=torch.int32
+                    )
+                else:
+                    expected = torch.tensor(
+                        [world_size, world_size], device=self.rank, dtype=torch.int32
+                    )
+
+                # Validate parameter usage.
+                variable_usage_tensor = local_used_map
+                self.assertEqual(variable_usage_tensor, expected)
+
+            # Validate appropriate error message when DDP is used with
+            # find_unused_parameters=False.
+            model = torch.nn.parallel.DistributedDataParallel(
+                ControlFlowToyModel().cuda(self.rank),
+                device_ids=[self.rank],
+                find_unused_parameters=False,
+            )
+            for i in range(2):
+                if i == 0:
+                    loss = model(random_input).sum()
+                    loss.backward()
+                else:
+                    try:
+                        loss = model(random_input).sum()
+                        loss.backward()
+                    except RuntimeError as e:
+                        msg = str(e)
+                        verify_ddp_error_logged(model, msg)
+                        # 2nd linear layer is unused
+                        unused_param_index = 1
+                        expected_strs = [
+                            ddp_prev_reduction_unfinished_str,
+                            ddp_recommend_find_unused_params_str,
+                            ddp_outputs_not_used_in_loss_str,
+                            f"Parameter indices which did not receive grad for rank {self.rank}: {unused_param_index}",
+                        ]
+                        # In debug mode, should show parameters that weren't reduced.
+                        # Without debug mode, should show suggestion to use debug mode.
+                        if dist.get_debug_level() == dist.DebugLevel.OFF:
+                            expected_strs.append(ddp_suggest_debug_mode_str)
+                        else:
+                            unreduced_params = ", ".join(["lin2.weight"])
+                            expected_strs.append(
+                                f"did not receive grad for rank {self.rank}: {unreduced_params}"
+                            )
+                        for s in expected_strs:
+                            self.assertTrue(s in msg, f"Expected {s} to be in {msg}")
+                        self.assertFalse(ddp_find_unused_params_enabled_str in msg)
+                    else:
+                        self.assertFalse(True, "DDP error not raised")
+
+            dist.barrier()
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_invalid_static_graph(self):
+            torch.cuda.set_device(self.rank)
+            model = torch.nn.parallel.DistributedDataParallel(
+                ControlFlowToyModel().cuda(self.rank),
+                device_ids=[self.rank],
+                static_graph=True,
+            )
+            random_input = torch.randn(20, 10, device=self.rank)
+            ones_input = torch.ones(20, 10, device=self.rank)
+            # unused parameter in the first iteration got used
+            # in second iteration.
+            expected_err = "Your training graph has changed in this iteration"
+            with self.assertRaisesRegex(RuntimeError, expected_err):
+                for i in range(2):
+                    if i % 2 == 0:
+                        out = model(random_input)
+                    else:
+                        out = model(ones_input)
+                    loss = out.sum()
+                    loss.backward()
+
+            verify_ddp_error_logged(model, expected_err)
+
+            # used parameter in the first iteration got unused
+            # in second iteration.
+            with self.assertRaisesRegex(
+                RuntimeError,
+                "Expected to have finished reduction in the prior iteration "
+                "before starting a new one. This error indicates that your "
+                "training graph has changed in this iteration, "
+                "e.g., one parameter is used in first iteration, "
+                "but then got unused in the second iteration. "
+                "this is not compatible with static_graph set to True.\n"
+                "Parameter indices which did not receive grad for",
+            ):
+                for i in range(2):
+                    if i % 2 != 0:
+                        out = model(random_input)
+                    else:
+                        out = model(ones_input)
+                    loss = out.sum()
+                    loss.backward()
+
+            verify_ddp_error_logged(model, "Expected to have finished reduction")
+
+        @with_dist_debug_levels(levels=["OFF", "INFO", "DETAIL"])
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_control_flow_different_across_ranks(self):
+            # Control flow that is different across ranks.
+            batch = 20
+            dim = 10
+
+            class ToyModel(nn.Module):
+                def __init__(self, rank):
+                    super().__init__()
+                    self.lin1 = nn.Linear(10, 10, bias=False)
+                    self.lin2 = nn.Linear(10, 10, bias=False)
+                    self.rank = rank
+
+                def forward(self, x):
+                    # Control-flow that is rank and input dependent for the
+                    # model.
+                    use_second_layer = (
+                        torch.equal(x, torch.ones(batch, dim, device=x.device))
+                        and self.rank == 1
+                    )
+
+                    if use_second_layer:
+                        return self.lin2(F.relu(self.lin1(x)))
+                    else:
+                        return F.relu(self.lin1(x))
+
+            world_size = dist.get_world_size()
+            torch.cuda.set_device(self.rank)
+            model = torch.nn.parallel.DistributedDataParallel(
+                ToyModel(self.rank).cuda(self.rank),
+                device_ids=[self.rank],
+                find_unused_parameters=True,
+            )
+            random_input = torch.randn(batch, dim, device=self.rank)
+            ones_input = torch.ones(batch, dim, device=self.rank)
+            for i in range(6):
+                if i % 2 == 0:
+                    out = model(random_input)
+                else:
+                    out = model(ones_input)
+                loss = out.sum()
+                loss.backward()
+                # On even iterations, 2nd param goes unused, on odd iterations,
+                # it is used only on rank 1.
+                local_used_map = model.reducer._get_local_used_map()
+
+                if i % 2 == 0:
+                    expected = torch.tensor(
+                        [world_size, 0], device=self.rank, dtype=torch.int32
+                    )
+                else:
+                    expected = torch.tensor(
+                        [world_size, 1], device=self.rank, dtype=torch.int32
+                    )
+
+                variable_usage_tensor = local_used_map
+                # Validate parameter usage. On odd iterations, 2nd param is only
+                # used on rank 1.
+                self.assertEqual(variable_usage_tensor, expected)
+
+            # Validate appropriate error message when DDP is used with
+            # find_unused_parameters=False.
+            model = torch.nn.parallel.DistributedDataParallel(
+                ToyModel(self.rank).cuda(self.rank),
+                device_ids=[self.rank],
+                find_unused_parameters=False,
+            )
+            for i in range(2):
+                if i == 0:
+                    loss = model(random_input).sum()
+                    loss.backward()
+                else:
+                    try:
+                        loss = model(random_input).sum()
+                        loss.backward()
+                    except RuntimeError as e:
+                        msg = str(e)
+                        verify_ddp_error_logged(model, msg)
+                        unused_param_index = 1
+                        expected_strs = [
+                            ddp_prev_reduction_unfinished_str,
+                            ddp_recommend_find_unused_params_str,
+                            ddp_outputs_not_used_in_loss_str,
+                            f"Parameter indices which did not receive grad for rank {self.rank}: {unused_param_index}",
+                        ]
+                        # In debug mode, should show parameters that weren't reduced.
+                        # Without debug mode, should show suggestion to use debug mode.
+                        if dist.get_debug_level() == dist.DebugLevel.OFF:
+                            expected_strs.append(ddp_suggest_debug_mode_str)
+                        else:
+                            unreduced_params = ", ".join(["lin2.weight"])
+                            expected_strs.append(
+                                f"did not receive grad for rank {self.rank}: {unreduced_params}"
+                            )
+                        for s in expected_strs:
+                            self.assertTrue(s in msg, f"Expected {s} to be in {msg}")
+                        self.assertFalse(ddp_find_unused_params_enabled_str in msg)
+                    else:
+                        self.assertFalse(True, "DDP error not raised")
+
+            dist.barrier()
+
+        @require_backend_is_available({"gloo"})
+        def test_scatter_object_list(self):
+            src_rank = 0
+            collectives_object_test_list = create_collectives_object_test_list()
+            scatter_list = (
+                collectives_object_test_list
+                if self.rank == src_rank
+                else [None for _ in collectives_object_test_list]
+            )
+            world_size = dist.get_world_size()
+            scatter_list = scatter_list[:world_size]
+            i = 0
+            while len(scatter_list) < world_size:
+                scatter_list.append(scatter_list[i])
+                i += 1
+
+            output_obj_list = [None]
+            dist.scatter_object_list(output_obj_list, scatter_list, src=src_rank)
+            self.assertEqual(
+                output_obj_list[0],
+                collectives_object_test_list[
+                    self.rank % len(collectives_object_test_list)
+                ],
+            )
+            # Ensure errors are raised upon incorrect arguments.
+            with self.assertRaisesRegex(
+                ValueError,
+                "Expected argument scatter_object_output_list to be a list of size at least 1.",
+            ):
+                dist.scatter_object_list([], scatter_list, src=src_rank)
+
+        def _generate_sparse_tensors_for_bucket_assignment_test(self):
+            tensors = [
+                torch.empty([50], dtype=torch.float),
+                torch.empty([25], dtype=torch.double),
+                torch.empty([50], dtype=torch.float),
+                torch.empty([25], dtype=torch.double),
+                torch.empty([50], dtype=torch.float),
+                torch.empty([25], dtype=torch.double),
+            ]
+
+            tensors_sparse = [t.to_sparse() for t in tensors]
+            return tensors_sparse
+
+        def _test_compute_bucket_assignment_by_size(self, use_logger):
+            group_gloo = dist.new_group(
+                timeout=timedelta(seconds=60), backend=dist.Backend.GLOO
+            )
+            # Set TORCH_NCCL_BLOCKING_WAIT and use a new NCCL group to improve test
+            # determinism.
+            os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"
+            group_to_use = dist.new_group(
+                backend=dist.get_backend(), timeout=timedelta(seconds=5)
+            )
+            torch.cuda.set_device(self.rank)
+
+            # Create a valid model. The constructor initializes the logger that we use later.
+            # We never actually use the rest of the model - we only need its logger.
+            net = EmbeddingNetDifferentParams(0)
+            net = torch.nn.parallel.DistributedDataParallel(
+                net.to(self.rank),
+                device_ids=[self.rank],
+                process_group=group_to_use,
+            )
+
+            # if we don't pass a logger then we can only check that an exception was thrown.
+            expected_err = "No support for sparse tensors."
+            with self.assertRaisesRegex(RuntimeError, expected_err):
+                tensors_sparse = (
+                    self._generate_sparse_tensors_for_bucket_assignment_test()
+                )
+                if use_logger:
+                    dist._compute_bucket_assignment_by_size(
+                        tensors_sparse, [400], logger=net.logger
+                    )
+                else:
+                    dist._compute_bucket_assignment_by_size(tensors_sparse, [400])
+            if use_logger:
+                verify_ddp_error_logged(net, expected_err)
+
+            # Perform gloo-based barrier to ensure one rank doesn't exit test
+            # early which causes failure with Barrier.sync.
+            dist.barrier(group_gloo)
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_compute_bucket_assignment_by_size_sparse_error_without_logger(self):
+            self._test_compute_bucket_assignment_by_size(use_logger=False)
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_compute_bucket_assignment_by_size_sparse_error_with_logger(self):
+            self._test_compute_bucket_assignment_by_size(use_logger=True)
+
+        def _test_verify_model_across_rank(self, use_logger):
+            group_gloo = dist.new_group(
+                timeout=timedelta(seconds=60), backend=dist.Backend.GLOO
+            )
+            group_to_use = dist.new_group(
+                backend=dist.get_backend(), timeout=timedelta(seconds=5)
+            )
+            torch.cuda.set_device(self.rank)
+
+            # Create a valid model. The constructor initializes the logger that we use later.
+            net = EmbeddingNetDifferentParams(0)
+            net = torch.nn.parallel.DistributedDataParallel(
+                net.to(self.rank),
+                device_ids=[self.rank],
+                process_group=group_to_use,
+            )
+
+            # Modify the model so that the number of parameters are different for each rank.
+            # This will cause a RuntimeError to be thrown below in _verify_param_shape_across_processes,
+            # so we can check if the correct error is thrown and is logged.
+            # We can't do this in the constructor above otherwise the logger will
+            # not be properly initialized.
+            net.module.lin = nn.Linear(100 if self.rank == 0 else 10, 1)
+
+            # if we pass a logger we can verify that it was logged
+            caught = 0
+            try:
+                if use_logger:
+                    _verify_param_shape_across_processes(
+                        net.process_group, list(net.parameters()), net.logger
+                    )
+                else:
+                    _verify_param_shape_across_processes(
+                        net.process_group, list(net.parameters())
+                    )
+            except Exception:
+                caught = 1
+
+            # As long as there is one rank catching the exception
+            t = torch.Tensor([caught])
+            dist.all_reduce(t, group=group_gloo)
+            self.assertGreater(t, 0)
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally"
+        )
+        @skip_if_lt_x_gpu(2)
+        def test_verify_model_across_rank_with_logger(self):
+            self._test_verify_model_across_rank(use_logger=True)
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally"
+        )
+        @skip_if_lt_x_gpu(2)
+        def test_verify_model_across_rank_without_logger(self):
+            self._test_verify_model_across_rank(use_logger=False)
+
+        def _run_test_ddp_model_with_diff_params(self, net, ddp_group, group_gloo):
+            caught = 0
+            try:
+                net = torch.nn.parallel.DistributedDataParallel(
+                    net.to(self.rank), device_ids=[self.rank], process_group=ddp_group
+                )
+            except Exception:
+                caught = 1
+
+            # As long as there is one rank catching the exception
+            t = torch.Tensor([caught])
+            dist.all_reduce(t, group=group_gloo)
+            self.assertGreater(t, 0)
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally"
+        )
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_model_diff_shape_across_ranks(self):
+            group_gloo = dist.new_group(
+                timeout=timedelta(seconds=60), backend=dist.Backend.GLOO
+            )
+            group_to_use = dist.new_group(
+                backend=dist.get_backend(), timeout=timedelta(seconds=10)
+            )
+            torch.cuda.set_device(self.rank)
+            # Creates network with different sized embedding table on different
+            # ranks. This should throw an error during DDP init.
+            net = EmbeddingNetDifferentParams(self.rank)
+            self._run_test_ddp_model_with_diff_params(net, group_to_use, group_gloo)
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally"
+        )
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_model_diff_num_params_across_ranks(self):
+            group_gloo = dist.new_group(
+                timeout=timedelta(seconds=60), backend=dist.Backend.GLOO
+            )
+            group_to_use = dist.new_group(
+                backend=dist.get_backend(), timeout=timedelta(seconds=10)
+            )
+            torch.cuda.set_device(self.rank)
+
+            # Creates network with diff # of param across ranks, reducer should
+            # recognize this and throw appropriate error.
+            net = EmbeddingNetDifferentParams(
+                self.rank, diff_num_params=(self.rank == 1)
+            )
+
+            self._run_test_ddp_model_with_diff_params(
+                net,
+                group_to_use,
+                group_gloo,
+            )
+
+        def _test_output_unused_in_loss(self, module_cls, gradient_as_bucket_view):
+            model = module_cls()
+            local_net = copy.deepcopy(model)
+            net = torch.nn.parallel.DistributedDataParallel(
+                copy.deepcopy(model).cuda(self.rank),
+                device_ids=[self.rank],
+                find_unused_parameters=True,
+            )
+
+            # Tests that certain parameters not getting gradient since the
+            # output is unused in loss computation is supported. Specifically,
+            # checks that the grads remain unchanged and are the same as local
+            # training.
+            inp = torch.randn(10, 10)
+
+            # Ensure that if a param is not used in loss computation, its
+            # gradient is untouched, i.e. if it is None before it is None after,
+            # not zero.
+            if module_cls == DictOutputModule:
+                a, b = local_net(inp)["predictions"]
+                a_dist, b_dist = net(inp)["predictions"]
+            else:
+                a, b = local_net(inp)
+                a_dist, b_dist = net(inp)
+
+            loss_dist = b_dist.sum()
+            loss_dist.backward()
+
+            # Ensure that gradient corresponding to parameter "a" was not
+            # touched, i.e. it is None and matches the local grad.
+            if module_cls == DictOutputModule:
+                self.assertTrue(net.module.module.a.weight.grad is None)
+                self.assertEqual(
+                    net.module.module.a.weight.grad, local_net.module.a.weight.grad
+                )
+            else:
+                self.assertTrue(net.module.a.weight.grad is None)
+                self.assertEqual(net.module.a.weight.grad, local_net.a.weight.grad)
+
+            saved_a_local_grad = None
+            saved_a_dist_grad = None
+            net.zero_grad()
+            local_net.zero_grad()
+            for i in range(6):
+                if module_cls == DictOutputModule:
+                    a, b = local_net(inp)["predictions"]
+                    a_dist, b_dist = net(inp)["predictions"]
+                else:
+                    a, b = local_net(inp)
+                    a_dist, b_dist = net(inp)
+                if i < 2:
+                    # Use both params in loss computation. Later, "a" will go
+                    # unused and we check to ensure DDP supports this and
+                    # gradients remain the same as local training.
+                    t = a @ b
+                    t_dist = a_dist @ b_dist
+                    loss = t.sum()
+                    loss_dist = t_dist.sum()
+                else:
+                    # Model output "a" unused in loss.
+                    loss = b.sum()
+                    loss_dist = b_dist.sum()
+                loss.backward()
+                loss_dist.backward()
+                if i == 1:
+                    # Save grads to compare with them in next iterations.
+                    if module_cls == DictOutputModule:
+                        saved_a_local_grad = local_net.module.a.weight.grad
+                        saved_a_dist_grad = net.module.module.a.weight.grad
+                    else:
+                        saved_a_local_grad = local_net.a.weight.grad
+                        saved_a_dist_grad = net.module.a.weight.grad
+                    self.assertEqual(saved_a_local_grad, saved_a_dist_grad)
+                elif i >= 2:
+                    # parameter "a" of both models should be the same and not change
+                    if module_cls == DictOutputModule:
+                        self.assertEqual(
+                            net.module.module.a.weight.grad, saved_a_dist_grad
+                        )
+                        self.assertEqual(
+                            local_net.module.a.weight.grad, saved_a_local_grad
+                        )
+                    else:
+                        self.assertEqual(net.module.a.weight.grad, saved_a_dist_grad)
+                        self.assertEqual(local_net.a.weight.grad, saved_a_local_grad)
+
+                # Verify grads are the same
+                for local_param, dist_param in zip(
+                    local_net.parameters(), net.parameters(), strict=True
+                ):
+                    local_grad = local_param.grad
+                    dist_grad = dist_param.grad
+                    self.assertEqual(local_grad, dist_grad)
+
+            dist.barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_lt_x_gpu(2)
+        def test_output_unused_in_loss_tuple_module(self):
+            module_cls = UnusedParamTwoLinLayerNet
+            for grad_as_bucket_view in [True, False]:
+                self._test_output_unused_in_loss(module_cls, grad_as_bucket_view)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_lt_x_gpu(2)
+        def test_output_unused_in_loss_dict_module(self):
+            module_cls = DictOutputModule
+            for grad_as_bucket_view in [True, False]:
+                self._test_output_unused_in_loss(module_cls, grad_as_bucket_view)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_lt_x_gpu(2)
+        def test_undefined_grad_parity_unused_parameters(self):
+            # TODO: enable this for general training use cases:
+            # https://github.com/pytorch/pytorch/issues/58511.
+            x = torch.ones(1, 2).to(self.rank)
+            net = Net().to(self.rank)
+            local_net = copy.deepcopy(net)
+            net = torch.nn.parallel.DistributedDataParallel(
+                net,
+                device_ids=[self.rank],
+                find_unused_parameters=True,
+            )
+            out = net(x).sum()
+            local_out = local_net(x).sum()
+            # Simulates undefined gradients.
+            torch._C._functions.UndefinedGrad()(out).backward()
+            torch._C._functions.UndefinedGrad()(local_out).backward()
+            for (dist_param_name, dist_param), (local_param_name, local_param) in zip(
+                net.named_parameters(), local_net.named_parameters(), strict=True
+            ):
+                dist_grad = dist_param.grad
+                local_grad = local_param.grad
+                self.assertEqual(
+                    dist_grad,
+                    local_grad,
+                    f"""DDP param {dist_param_name} with grad {dist_grad}
+                    does not match local param {local_param_name} with grad
+                    {local_grad}""",
+                )
+
+        def _test_different_graph_across_ranks(
+            self, find_unused_parameters=False, static_graph=False
+        ):
+            class ToyModel(nn.Module):
+                def __init__(self, rank):
+                    super().__init__()
+                    self.lin1 = nn.Linear(10, 10, bias=False)
+                    self.lin2 = nn.Linear(10, 10, bias=False)
+                    self.rank = rank
+
+                def forward(self, x):
+                    if self.rank == 0:
+                        return self.lin2(F.relu(self.lin1(x)))
+                    else:
+                        return F.relu(self.lin1(x))
+
+            torch.manual_seed(31415)
+            torch.cuda.set_device(self.rank)
+            model = ToyModel(self.rank).cuda(self.rank)
+            ddp_model = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[self.rank],
+                find_unused_parameters=find_unused_parameters,
+                gradient_as_bucket_view=True,
+                static_graph=static_graph,
+            )
+            random_input = torch.randn(20, 10, device=self.rank)
+            for _ in range(10):
+                out = ddp_model(random_input)
+                loss = out.sum()
+                loss.backward()
+            return ddp_model
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_different_graph_across_ranks(self):
+            base_model = self._test_different_graph_across_ranks(
+                find_unused_parameters=True
+            )
+            self.assertFalse(
+                base_model._get_ddp_logging_data().get("has_rebuilt_buckets", 0)
+            )
+            static_model = self._test_different_graph_across_ranks(static_graph=True)
+            self.assertTrue(
+                static_model._get_ddp_logging_data().get("has_rebuilt_buckets", 0)
+            )
+            for i, j in zip(
+                base_model.parameters(), static_model.parameters(), strict=True
+            ):
+                self.assertEqual(i, j)
+
+        @require_backend_is_available({"gloo"})
+        @skip_but_pass_in_sandcastle_if(
+            IS_MACOS or IS_WINDOWS,
+            "MacOS uses uv transport which does not have as robust error handling as tcp transport",
+        )
+        def test_monitored_barrier_gloo(self):
+            tensors = [torch.ones(10) * self.rank]
+            # Kick off some allreduce work on all ranks
+            for _ in range(10):
+                dist.all_reduce(torch.cat(tensors))
+            # Run monitored barrier and ensure it passes
+            timeout = timedelta(seconds=2)
+            dist.monitored_barrier(timeout=timeout)
+            # Check monitored_barrier success with wait_all_ranks=True
+            for _ in range(10):
+                dist.all_reduce(torch.cat(tensors))
+            dist.monitored_barrier(timeout=timeout, wait_all_ranks=True)
+            # All ranks besides 1 call into barrier, rank 0 should report failure
+            # while others report gloo error.
+            failed_rank = 1
+            src_rank = 0
+            if self.rank == src_rank:
+                with self.assertRaisesRegex(
+                    RuntimeError, f"Rank {failed_rank} failed to pass monitoredBarrier"
+                ):
+                    dist.monitored_barrier(timeout=timeout)
+            elif self.rank != failed_rank:
+                # Other ranks should not pass barrier since rank 0 failed.
+                err_regex = (
+                    f"Rank {self.rank} successfully reached monitoredBarrier,"
+                    f" but received errors while waiting for send/recv from rank"
+                    f" {src_rank}"
+                )
+                with self.assertRaisesRegex(RuntimeError, err_regex):
+                    dist.monitored_barrier(timeout=timeout)
+
+            # We need a barrier since otherwise failed_rank exits too early
+            # and cause a timeout.
+            self._barrier(timeout=30)
+
+        @require_backend_is_available({"gloo"})
+        def test_monitored_barrier_gloo_subgroup(self):
+            # Tests that monitored_barrier works as expected on non-default
+            # process groups.
+            failed_rank = 1
+            timeout = 0.1
+            subgroup = dist.new_group(ranks=[0, 1])
+
+            if self.rank == failed_rank:
+                return
+
+            if self.rank == 0:
+                with self.assertRaisesRegex(
+                    RuntimeError, f"Rank {failed_rank} failed to pass monitoredBarrier"
+                ):
+                    dist.monitored_barrier(subgroup, timeout)
+            else:
+                # Other ranks call into monitored_barrier, but this should be a
+                # noop because they are not part of the subgroup. Verify that
+                # there are no errors here.
+                dist.monitored_barrier(subgroup, timeout)
+
+        def _test_monitored_barrier_allreduce_hang(self, wait_all_ranks):
+            # tests expected behavior when nonzero rank hangs.
+            nccl_pg = dist.new_group(
+                ranks=list(range(int(self.world_size))),
+                # provide sufficient timeout so communicators
+                # can be initialized in ctor.
+                timeout=timedelta(seconds=15),
+                backend=dist.Backend.NCCL,
+            )
+            gloo_pg = dist.new_group(
+                ranks=list(range(int(self.world_size))),
+                backend=dist.Backend.GLOO,
+            )
+            tensors = [torch.ones(10, device=self.rank) * self.rank]
+            # Let all ranks call allreduce first to set up communicators etc.
+            # Directly simulating error here will run into store issue described
+            # in https://github.com/pytorch/pytorch/issues/54524.
+            nccl_pg.allreduce(tensors).wait(timedelta(seconds=5))
+            # All ranks besides 0 call into allreduce. This is to simulate a
+            # desync across the world, where some ranks call into
+            # monitored_barrier() and others are stuck in collective comm. In
+            # practice, we don't need TORCH_NCCL_BLOCKING_WAIT, but we use it in this
+            # test to ensure it exits cleanly.
+            if self.rank != 0:
+                # Can get different errors here depending on whether gloo-based
+                # wrapper PG is enabled or not, since with wrapper pg, it will
+                # fail in a collective synchronization check and not actually
+                # call into the nccl pg.
+                if dist.get_debug_level() == dist.DebugLevel.DETAIL:
+                    err_regex = "Timed out waiting"
+                else:
+                    err_regex = "caught collective operation timeout"
+                with self.assertRaisesRegex(RuntimeError, err_regex):
+                    nccl_pg.allreduce(tensors).wait(timedelta(seconds=0.1))
+            else:
+                # Rank 0 should report first (in order) timed out rank or all ranks
+                # depending on wait_all_ranks flag passed into monitored_barrier.
+                if wait_all_ranks:
+                    rank_str = ", ".join(
+                        [str(i) for i in range(1, int(self.world_size))]
+                    )
+                    err_regex = f"Ranks {rank_str} failed to pass monitoredBarrier"
+                else:
+                    expected_first_fail_rank = 1
+                    err_regex = f"Rank {expected_first_fail_rank} failed to pass monitoredBarrier"
+                monitored_barrier_timeout_seconds = timedelta(seconds=0.1)
+                with self.assertRaisesRegex(RuntimeError, err_regex):
+                    gloo_pg.monitored_barrier(
+                        monitored_barrier_timeout_seconds, wait_all_ranks=wait_all_ranks
+                    )
+
+            self._barrier(timeout=30)
+
+        @with_nccl_blocking_wait
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
+        def test_monitored_barrier_allreduce_hang(self):
+            # tests expected behavior when nonzero rank hangs and we want to
+            # report first timed out rank.
+            self._test_monitored_barrier_allreduce_hang(wait_all_ranks=False)
+
+        @with_nccl_blocking_wait
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
+        def test_monitored_barrier_allreduce_hang_wait_all_ranks(self):
+            # Need to disable TORCH_NCCL_DUMP_ON_TIMEOUT otherwise this test times out
+            os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = "0"
+            # tests expected behavior when nonzero rank hangs and we want to
+            # report all timed out ranks.
+            self._test_monitored_barrier_allreduce_hang(wait_all_ranks=True)
+
+        @require_backend_is_available({"gloo"})
+        def test_monitored_barrier_gloo_rank_0_timeout(self):
+            # tests error when rank 0 exhausts its given timeout.
+            process_group = dist.new_group(ranks=list(range(int(self.world_size))))
+            timeout = timedelta(seconds=0)
+            if self.rank == 0:
+                with self.assertRaisesRegex(
+                    RuntimeError, f"Rank {self.rank} timed out in monitoredBarrier"
+                ):
+                    process_group.monitored_barrier(timeout)
+
+        @require_backend_is_available({"gloo"})
+        @skip_if_small_worldsize
+        @skip_but_pass_in_sandcastle_if(
+            IS_MACOS or IS_WINDOWS,
+            "MacOS uses uv transport which does not have as robust error handling as tcp transport",
+        )
+        def test_monitored_barrier_failure_order(self):
+            # Ensure that the first (in sorted order) rank is reported when
+            # multiple ranks fail to pass the monitored_barrier.
+            # TODO(#54879): Provide ability to wait and report all failed ranks
+            expected_first_failed_rank = 2
+            timeout = timedelta(seconds=2)
+            src_rank = 0
+            if self.rank == src_rank:
+                with self.assertRaisesRegex(
+                    RuntimeError, f"Rank {expected_first_failed_rank}"
+                ):
+                    dist.monitored_barrier(timeout=timeout)
+            elif self.rank == 1:
+                err_regex = (
+                    f"Rank {self.rank} successfully reached monitoredBarrier,"
+                    f" but received errors while waiting for send/recv from rank"
+                    f" {src_rank}"
+                )
+                with self.assertRaisesRegex(RuntimeError, err_regex):
+                    dist.monitored_barrier(timeout=timeout)
+
+        @require_backend_is_available({"gloo"})
+        @skip_if_small_worldsize
+        def test_monitored_barrier_wait_all_ranks(self):
+            # Tests simple case where > 1 rank does not call into monitored
+            # barrier and verifies all ranks are reported by rank 0.
+            if self.rank == 0:
+                timeout = timedelta(seconds=0.1)
+                rank_str = ", ".join([str(i) for i in range(1, int(self.world_size))])
+                err_regex = f"Ranks {rank_str} failed to pass monitoredBarrier"
+                with self.assertRaisesRegex(RuntimeError, err_regex):
+                    dist.monitored_barrier(timeout=timeout, wait_all_ranks=True)
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @with_dist_debug_levels(levels=["INFO"])
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_build_debug_param_to_name_mapping(self):
+            model = TwoLinLayerNet()
+            net = torch.nn.parallel.DistributedDataParallel(
+                model.cuda(self.rank),
+                device_ids=[self.rank],
+            )
+            expected_mapping = {0: "a.weight", 1: "b.weight"}
+            net_params, _ = net._build_params_for_reducer()
+            param_to_name_mapping = net._build_debug_param_to_name_mapping(net_params)
+            self.assertDictEqual(expected_mapping, param_to_name_mapping)
+
+            # Test when DDP is used with ignored parameters.
+            model = TwoLinLayerNet()
+            # Parameters to ignore are in the format {module_name}.{param_name}
+            params_to_ignore = ["a.weight"]
+            torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
+                model, params_to_ignore
+            )
+            net = torch.nn.parallel.DistributedDataParallel(
+                model.cuda(self.rank),
+                device_ids=[self.rank],
+            )
+            expected_mapping = {0: "b.weight"}
+            net_params, _ = net._build_params_for_reducer()
+            param_to_name_mapping = net._build_debug_param_to_name_mapping(net_params)
+            self.assertDictEqual(expected_mapping, param_to_name_mapping)
+
+            # Test errors are raised when DDP and module parameters mismatch.
+            # This generally indicates a bug with DDP and is not expected to
+            # happen in user applications.
+            model = TwoLinLayerNet()
+            net = torch.nn.parallel.DistributedDataParallel(
+                model.cuda(self.rank),
+                device_ids=[self.rank],
+            )
+            net_params, _ = net._build_params_for_reducer()
+            if self.rank == 0:
+                print(type(net_params[0]))
+
+            net_params.extend(
+                [
+                    torch.nn.Parameter(torch.ones(1)),
+                    torch.nn.Parameter(torch.ones(1)),
+                ]
+            )
+
+            with self.assertRaisesRegex(ValueError, "Expected param to name mapping"):
+                net._build_debug_param_to_name_mapping(net_params)
+
+            net_params = net_params[:-3]
+            with self.assertRaisesRegex(ValueError, "Param with name"):
+                net._build_debug_param_to_name_mapping(net_params)
+
+            net_params.extend(
+                [
+                    torch.nn.Parameter(torch.ones(1)),
+                    torch.nn.Parameter(torch.ones(1)),
+                ]
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @with_dist_debug_levels(levels=["INFO"])
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_build_debug_param_to_name_mapping_requires_grad(self):
+            class Net(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.lin = nn.Linear(10, 10)
+                    # Is not tracked by DDP and should not show up in param to
+                    # name mapping.
+                    self.lin.bias.requires_grad_(False)
+
+                def forward(self, x):
+                    return self.lin(x)
+
+            model = Net()
+            net = torch.nn.parallel.DistributedDataParallel(
+                model.cuda(self.rank), device_ids=[self.rank]
+            )
+            expected_mapping = {
+                0: "lin.weight",
+            }
+            net_params, _ = net._build_params_for_reducer()
+            param_to_name_mapping = net._build_debug_param_to_name_mapping(net_params)
+            self.assertEqual(param_to_name_mapping, expected_mapping)
+
+        def _test_ddp_multiple_nested_unused_params_error(self, ignore_sparse):
+            debug_mode_off = dist.get_debug_level() == dist.DebugLevel.OFF
+
+            class SubModule(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.embedding_net = EmbeddingNetDifferentParams(0)
+                    self.lin = TwoLinLayerNet()
+                    self.bn = BatchNormNet()
+                    self.lin_layer = nn.Linear(4, 10, bias=False)
+
+                def forward(self, x):
+                    x = self.bn(x)
+                    x = self.lin_layer(x)
+                    x = self.lin.a(x)  # self.lin.b param unused
+                    # EmbeddingNetDifferentParams entirely unused: self.embedding_net.embedding and
+                    # self.embedding_net.lin unused.
+                    return x
+
+            class MyModel(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.sub_module = SubModule()
+
+                def forward(self, x):
+                    return self.sub_module(x)
+
+            model = MyModel()
+            sparse_embedding_fqns = []
+            if ignore_sparse:
+                for module_name, module in model.named_modules():
+                    if module == model.sub_module.embedding_net.embedding:
+                        for parameter_name, _param in module.named_parameters(
+                            recurse=False
+                        ):
+                            fqn = f"{module_name}.{parameter_name}"
+                            sparse_embedding_fqns.append(fqn)
+
+                torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
+                    model, sparse_embedding_fqns
+                )
+                unused_modules = [
+                    model.sub_module.embedding_net.lin,
+                    model.sub_module.lin.b,
+                ]
+            else:
+                unused_modules = list(model.sub_module.embedding_net.modules()) + [
+                    model.sub_module.lin.b,
+                ]
+
+            expected_unused_param_fqns = []
+            used_param_fqns = []  # Validate that these don't mistakenly show up.
+            fqn_to_param_index = {}
+            index = 0
+            for module_name, module in model.named_modules():
+                for parameter_name, _param in module.named_parameters(recurse=False):
+                    fqn = f"{module_name}.{parameter_name}"
+                    fqn_to_param_index[fqn] = index
+                    if fqn not in sparse_embedding_fqns:
+                        index += 1
+                    if module in unused_modules:
+                        expected_unused_param_fqns.append(fqn)
+                    else:
+                        if (
+                            not ignore_sparse
+                            or module != model.sub_module.embedding_net.embedding
+                        ):
+                            used_param_fqns.append(fqn)
+
+            net = torch.nn.parallel.DistributedDataParallel(
+                model.cuda(self.rank),
+                device_ids=[self.rank],
+            )
+            batch, dim = 10, 2
+            inp = torch.ones(batch, dim)
+            for i in range(2):
+                if i == 0:
+                    out = net(inp)
+                    loss = out.sum()
+                    loss.backward()
+                else:
+                    try:
+                        out = net(inp)
+                        loss = out.sum()
+                        loss.backward()
+                    except RuntimeError as e:
+                        e = str(e)
+
+                        unused_param_substr = e[e.find("did not receive grad") :]
+                        # Validate that each unused param fully qualified name
+                        # shows up in error logs. We do this instead of
+                        # constructing a joined string since order of parameters
+                        # can be different in Reducer. In addition, validate
+                        # param indices show up as well.
+                        for unused_param_fqn in expected_unused_param_fqns:
+                            self.assertTrue(
+                                unused_param_fqn in unused_param_substr
+                                or debug_mode_off
+                            )
+                            self.assertTrue(
+                                str(fqn_to_param_index[unused_param_fqn])
+                                in unused_param_substr,
+                                f"Did not find index {fqn_to_param_index[unused_param_fqn]} for {unused_param_fqn}",
+                            )
+
+                        # Validate that used param fqns don't show up in error
+                        # logs.
+                        for used_param_fqn in used_param_fqns:
+                            self.assertFalse(used_param_fqn in unused_param_substr)
+                        # Validate that ignored param fqns don't show up as unused
+                        # (since DDP does not track them)
+                        for sparse_param_fqn in sparse_embedding_fqns:
+                            self.assertFalse(sparse_param_fqn in unused_param_substr)
+                    else:
+                        self.assertTrue(False, "Expected error was not raised!")
+
+        @with_dist_debug_levels(levels=["OFF", "INFO", "DETAIL"])
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_multiple_nested_unused_params_error(self):
+            self._test_ddp_multiple_nested_unused_params_error(ignore_sparse=False)
+
+        @with_dist_debug_levels(levels=["OFF", "INFO", "DETAIL"])
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_multiple_nested_unused_params_err_ignore_params(self):
+            # Tests unused parameter reporting when DDP is configured to ignore
+            # certain parameters.
+            self._test_ddp_multiple_nested_unused_params_error(ignore_sparse=True)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_inference(self):
+            # tests that DDP module can be run on a single node with no_grad
+            # or eval setting and there is no hang.
+            rank = self.rank
+            torch.cuda.set_device(rank)
+            model = Net().cuda()
+            local_model = copy.deepcopy(model)
+            model = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[rank],
+            )
+            syncbn_model = nn.SyncBatchNorm(
+                2, momentum=0.99, track_running_stats=False
+            ).cuda()
+            local_syncbn_model = copy.deepcopy(syncbn_model)
+            syncbn_model = torch.nn.parallel.DistributedDataParallel(
+                syncbn_model, device_ids=[rank]
+            )
+            inp = torch.randn(10, 2, device=rank)
+            inp_syncbn = torch.randn(10, 2, 4, 4, device=rank)
+            tests = [
+                (model, local_model, inp),
+                (syncbn_model, local_syncbn_model, inp_syncbn),
+            ]
+            for test in tests:
+                test_model, test_local_model, test_inp = test
+                if self.rank == 0:
+                    test_model.eval()
+                    test_local_model.eval()
+                    for _ in range(6):
+                        self.assertEqual(
+                            test_model(test_inp), test_local_model(test_inp)
+                        )
+
+            # Barrier since only rank 0 runs inference. Test should be
+            # much faster than 30s, but this is to avoid flakiness.
+            self._barrier(timeout=30)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_lt_x_gpu(2)
+        @unittest.skip(
+            "Test is failing, see https://github.com/pytorch/pytorch/pull/113620"
+        )
+        def test_ddp_sync_bn_training_vs_eval(self):
+            rank = self.rank
+            torch.cuda.set_device(rank)
+            # Need to set track_running_stats=False, when track_running_stats=True,
+            # bn_training is False and sync could not occur in eval model.
+            model = nn.SyncBatchNorm(2, momentum=0.99, track_running_stats=False).cuda(
+                rank
+            )
+            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])
+            # Test sync occurs in training mode.
+            with torch.autograd.profiler.profile() as prof:
+                for _ in range(6):
+                    inp = torch.randn(10, 2, 4, 4).cuda(rank)
+                    out = model(inp)
+                    loss = out.sum()
+                    loss.backward()
+
+            # SyncBN allgathers stats across all ranks, so verify call to
+            # all_gather in profiler.
+            if BACKEND == "nccl":
+                all_gather_calls = get_profiling_event("_all_gather_base", prof)
+            else:
+                all_gather_calls = get_profiling_event("all_gather", prof)
+            self.assertNotEqual([], all_gather_calls)
+
+            # Only do inference on one rank. If SyncBN did collective stats sync,
+            # this would hang/error.
+            model_inference = model.module
+            if self.rank == 0:
+                model_inference.eval()
+                with torch.autograd.profiler.profile() as prof:
+                    for _ in range(6):
+                        inp = torch.randn(10, 2, 4, 4).cuda(rank)
+                        out = model_inference(inp)
+                        loss = out.sum()
+                        loss.backward()
+
+                # Ensure sync does not occur in eval() mode.
+                if BACKEND == "nccl":
+                    all_gather_calls = get_profiling_event("_all_gather_base", prof)
+                else:
+                    all_gather_calls = get_profiling_event("all_gather", prof)
+                self.assertEqual([], all_gather_calls)
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_python_error_logged(self):
+            # Most python exceptions in DDP are raised during init before
+            # reducer is constructed, so we don't have a logger in those cases.
+            # However, the below is one example where a python error is thrown
+            # after reducer is constructed.
+            model = TwoLinLayerNet().cuda(self.rank)
+            model = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[self.rank],
+            )
+            expected_err = "must be callable"
+            with self.assertRaisesRegex(TypeError, expected_err):
+                model.register_comm_hook({}, {})
+
+            verify_ddp_error_logged(model, expected_err)
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_static_graph_nested_types(self):
+            # Tests for static graph training when outputs are not just tensors
+            # but can be (nested) tuple, list, dict, etc.
+            rank = self.rank
+            torch.cuda.set_device(rank)
+
+            class NestedOutputModule(torch.nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.lin = nn.Linear(100, 1, bias=False)
+
+                def forward(self, inp, output_type):
+                    if output_type == "tuple":
+                        return (
+                            self.lin(inp),
+                            (
+                                self.lin(inp),
+                                self.lin(inp),
+                            ),
+                        )
+                    elif output_type == "list":
+                        return [
+                            self.lin(inp),
+                            [
+                                self.lin(inp),
+                                self.lin(inp),
+                            ],
+                        ]
+                    elif output_type == "dict":
+                        return {
+                            "a": self.lin(inp),
+                            "b": {
+                                "c": self.lin(inp),
+                            },
+                        }
+
+            def get_loss(model_output):
+                loss = 0.0
+                if isinstance(model_output, torch.Tensor):
+                    return model_output.sum()
+                elif isinstance(model_output, dict):
+                    for value in model_output.values():
+                        loss += get_loss(value)
+                elif isinstance(model_output, (tuple, list)):
+                    for x in model_output:
+                        loss += get_loss(x)
+                else:
+                    raise ValueError(f"Unknown model output type {type(model_output)}")
+                return loss
+
+            model = NestedOutputModule().cuda(rank)
+            model_static_graph = copy.deepcopy(model)
+            model = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[rank],
+            )
+            model_static_graph = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[rank],
+                static_graph=True,
+            )
+            inp = torch.randn(10, 100)
+            type_mapping = {
+                "list": list,
+                "tuple": tuple,
+                "dict": dict,
+            }
+            for output_type in type_mapping:
+                for _ in range(6):
+                    out = model(inp, output_type=output_type)
+                    loss = get_loss(out)
+                    loss.backward()
+                    self._model_step(model)
+                    out_static = model_static_graph(inp, output_type=output_type)
+                    self.assertTrue(isinstance(out_static, type_mapping[output_type]))
+                    loss_static = get_loss(out_static)
+                    loss_static.backward()
+                    self._model_step(model_static_graph)
+                    for p, p_static in zip(
+                        model.parameters(), model_static_graph.parameters(), strict=True
+                    ):
+                        self.assertEqual(p, p_static)
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_returns_tensor_with_no_grad(self):
+            # Tests case where module returns tensor that does not require grad.
+            torch.cuda.set_device(self.rank)
+
+            class MyModel(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.fc1 = nn.Linear(10, 10, bias=False)
+                    self.fc2 = nn.Linear(10, 10, bias=False)
+
+                def forward(self, x):
+                    x = self.fc2(F.relu(self.fc1(x)))
+                    y = x.clone()
+                    x = x.detach()
+                    assert not x.requires_grad
+                    return (x, y)
+
+            model = MyModel().to(self.rank)
+            inp = torch.randn(1, 10, device=self.rank)
+            for find_unused, static_graph in itertools.product(
+                [True, False], [True, False]
+            ):
+                ddp = DistributedDataParallel(
+                    model,
+                    device_ids=[self.rank],
+                    output_device=self.rank,
+                    find_unused_parameters=find_unused,
+                    static_graph=static_graph,
+                )
+                for _ in range(6):
+                    out = ddp(inp)
+                    self.assertFalse(out[0].requires_grad)
+                    o = (out[0] + out[1]).sum()
+                    o.backward()
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_detect_ddp_is_actually_static(self):
+            class ToyModel(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.net1 = nn.Linear(10, 10, bias=False)
+                    self.net2 = nn.Linear(10, 10)
+
+                def forward(self, x, find_unused, dynamic):
+                    if find_unused:
+                        if dynamic:
+                            return self.net2(self.net1(x))
+                        else:
+                            return self.net2(x)
+                    else:
+                        return self.net2(self.net1(x))
+
+            # Set of unused parameters don't change across iterations
+            torch.cuda.set_device(self.rank)
+            model = ToyModel().cuda()
+            for find_unused in [True, False]:
+                ddp = torch.nn.parallel.DistributedDataParallel(
+                    model,
+                    device_ids=[self.rank],
+                    find_unused_parameters=find_unused,
+                )
+                inp = torch.randn(1, 10, device="cuda")
+                for _ in range(6):
+                    out = ddp(inp, find_unused=find_unused, dynamic=False)
+                    loss = out.sum()
+                    loss.backward()
+                    self.assertTrue(ddp.reducer._ddp_graph_static())
+
+            # Set of unused parameters dynamically change
+            ddp = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[self.rank],
+                find_unused_parameters=True,
+            )
+            inp = torch.randn(1, 10, device="cuda")
+            for i in range(6):
+                out = ddp(inp, find_unused=True, dynamic=i % 2 == 0)
+                loss = out.sum()
+                loss.backward()
+            self.assertFalse(ddp.reducer._ddp_graph_static())
+
+        def _test_ddp_new_tensor_in_fwd(self, static_graph):
+            # Test from https://github.com/pytorch/pytorch/issues/60733
+            class MyModel(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.fc1 = nn.Linear(10, 10, bias=False)
+                    self.fc2 = nn.Linear(10, 10, bias=False)
+                    self.device = self.fc1.weight.device
+
+                def __init_opt(self):
+                    opt = torch.randn(1, 10, device=self.device)
+                    return opt
+
+                def forward(self, x, opt_1, opt_2, opt_nested):
+                    x = F.relu(self.fc1(x))
+                    x = self.fc2(x)
+                    if opt_1 is None:
+                        opt_1 = self.__init_opt()
+                    if opt_2 is None:
+                        opt_2 = self.__init_opt()
+                    if opt_nested is None or not torch.is_tensor(opt_nested):
+                        opt_nested = self.__init_opt()
+                    # Test multiple tensors as well as newly created tensors
+                    # within a struct.
+                    return x, opt_1, opt_2, {"tensor": opt_nested}
+
+            model = MyModel().to(self.rank)
+            for find_unused in [True, False]:
+                ddp = DistributedDataParallel(
+                    model,
+                    device_ids=[self.rank],
+                    output_device=self.rank,
+                    broadcast_buffers=False,
+                    find_unused_parameters=find_unused,
+                    static_graph=static_graph,
+                )
+
+                opt = [None for _ in range(3)]
+                for i in range(2):
+                    ddp.zero_grad()
+                    x = torch.randn(1, 10, device=self.rank)
+                    out, opt[0], opt[1], opt[2] = ddp(
+                        x, opt_1=opt[0], opt_2=opt[1], opt_nested=opt[2]
+                    )
+                    for i in range(len(opt)):
+                        if torch.is_tensor(opt[i]):
+                            self.assertEqual(opt[i].grad_fn, None)
+                        else:
+                            self.assertEqual(opt[i]["tensor"].grad_fn, None)
+                    out.mean().backward()
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_new_tensor_in_fwd(self):
+            return self._test_ddp_new_tensor_in_fwd(static_graph=False)
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_new_tensor_in_fwd_static_graph(self):
+            return self._test_ddp_new_tensor_in_fwd(static_graph=True)
+
+        def _test_ddp_buffer_hook_allreduce(self, return_futures):
+            rank = self.rank
+            torch.cuda.set_device(rank)
+            torch.manual_seed(rank)
+            torch.cuda.manual_seed(rank)
+
+            def buffer_comm_hook(ddp, named_buffers):
+                buffers = [buffer for (_, buffer) in named_buffers.items()]
+                futs = [
+                    dist.all_reduce(
+                        buffer, group=ddp.process_group, async_op=True
+                    ).get_future()
+                    for buffer in buffers
+                ]
+                if return_futures:
+                    return futs
+                else:
+                    torch.futures.collect_all(futs).wait()
+
+            hook_pre_fwd = (
+                torch.nn.parallel.distributed._BufferCommHookLocation.PRE_FORWARD
+            )
+            hook_post_fwd = (
+                torch.nn.parallel.distributed._BufferCommHookLocation.POST_FORWARD
+            )
+            for hook_run_location in [
+                hook_pre_fwd,
+                hook_post_fwd,
+            ]:
+                model = NetWithBuffers().cuda(rank)
+                model_ddp = torch.nn.parallel.DistributedDataParallel(
+                    model,
+                    device_ids=[self.rank],
+                )
+                model_ddp._register_buffer_comm_hook(
+                    model_ddp, buffer_comm_hook, hook_run_location
+                )
+                model_ddp_no_hook = torch.nn.parallel.DistributedDataParallel(
+                    copy.deepcopy(model),
+                    device_ids=[self.rank],
+                    broadcast_buffers=False,
+                )
+                inp = torch.randn(2, 10, device=rank)
+                for _ in range(2):
+                    loss_hook = model_ddp(inp).sum()
+                    # Since buffer reduction is done pre-forward, simulate it for
+                    # no hook case here.
+                    # Simulate allreduce appropriately depending on hook location.
+                    if hook_run_location == hook_pre_fwd:
+                        model_no_hook_buffers = list(model_ddp_no_hook.module.buffers())
+                        for tensor in model_no_hook_buffers:
+                            dist.all_reduce(tensor)
+
+                    loss_no_hook = model_ddp_no_hook(inp).sum()
+                    if hook_run_location == hook_post_fwd:
+                        model_no_hook_buffers = list(model_ddp_no_hook.module.buffers())
+                        for tensor in model_no_hook_buffers:
+                            dist.all_reduce(tensor)
+                    torch.cuda.synchronize()
+
+                    # if return_futures, they are only awaited on by DDP
+                    # at the end of the backwards pass for maximum overlap.
+                    if not return_futures:
+                        self._verify_buffers_equal(model_ddp, model_ddp_no_hook)
+                    loss_hook.backward()
+                    loss_no_hook.backward()
+                    # Note that when custom hooks return futures, this
+                    # comparison is not expected to work when hook run location
+                    # is pre-forward pass. This is because the hook does async
+                    # communication and forward pass modifies the buffer without
+                    # appropriate synchronization. Therefore, if returning
+                    # futures from custom buffer hooks, it is advised to set
+                    # hook run location to post forward.
+                    if return_futures and hook_run_location == hook_post_fwd:
+                        self._verify_buffers_equal(model_ddp, model_ddp_no_hook)
+                dist.barrier()
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_buffer_hook_allreduce_return_future(self):
+            self._test_ddp_buffer_hook_allreduce(return_futures=True)
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_buffer_hook_allreduce(self):
+            self._test_ddp_buffer_hook_allreduce(return_futures=False)
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_broadcast_buffer_via_hook(self):
+            # test that _distributed_broadcast_coalesced via registered hook is
+            # equivalent to DDP's default broadcast coalesced.
+            rank = self.rank
+            torch.cuda.set_device(rank)
+            torch.manual_seed(rank)
+            torch.cuda.manual_seed(rank)
+
+            def buffer_comm_hook(ddp, named_buffers):
+                # named_buffers is a Dict[str, Tensor] representing a mapping
+                # from buffer name to buffer.
+                buffers = [buffer for (_, buffer) in named_buffers.items()]
+                ddp._default_broadcast_coalesced(buffers)
+
+            model = NetWithBuffers().cuda(rank)
+            model_ddp = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[self.rank],
+            )
+            model_ddp._register_buffer_comm_hook(model_ddp, buffer_comm_hook)
+            model_ddp_no_hook = torch.nn.parallel.DistributedDataParallel(
+                copy.deepcopy(model),
+                device_ids=[self.rank],
+            )
+            inp = torch.randn(2, 10, device=rank)
+            for _ in range(2):
+                loss_hook = model_ddp(inp).sum()
+                loss_no_hook = model_ddp_no_hook(inp).sum()
+                self._verify_buffers_equal(model_ddp, model_ddp_no_hook)
+                loss_hook.backward()
+                loss_no_hook.backward()
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_remove_autograd_hooks(self):
+            class SimulateError(torch.autograd.Function):
+                @staticmethod
+                def forward(ctx, input):
+                    return input
+
+                @staticmethod
+                def backward(ctx, grad_output):
+                    raise RuntimeError
+
+            class MyModel(nn.Module):
+                def __init__(self, device):
+                    super().__init__()
+                    self.error = True
+                    self.fc1 = nn.Linear(10, 10).cuda(device)
+
+                def forward(self, inp):
+                    if self.error:
+                        return self.fc1(SimulateError.apply(inp))
+                    else:
+                        return self.fc1(inp)
+
+            # Run with error to trigger backward pass that marks fc1 as being marked
+            # ready. If we don't remove autograd hooks before running below it would
+            # fail on the old autograd hook.
+            model = MyModel(self.rank)
+            input = torch.rand(10, 10, requires_grad=True).cuda(self.rank)
+            model_ddp1 = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[self.rank],
+            )
+
+            with self.assertRaises(RuntimeError):
+                model_ddp1(input).sum().backward()
+
+            # Remove autograd hooks on old instance.
+            model_ddp1._remove_autograd_hooks()
+
+            # Try another DDP instance without error now.
+            model.error = False
+            model_ddp2 = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[self.rank],
+            )
+            model_ddp2(input).sum().backward()
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @unittest.skip(
+            "Test is failing, tracking issue at https://github.com/pytorch/pytorch/issues/102751"
+        )
+        def test_ddp_has_finalized(self):
+            @dataclass
+            class MyClass:
+                obj: torch.Tensor
+
+            class MyModel(nn.Module):
+                def __init__(self, rank):
+                    super().__init__()
+                    self.rank = rank
+                    self.fc1 = nn.Linear(1024, 1024).cuda(rank)
+                    self.fc2 = nn.Linear(1024, 2 * 1024).cuda(rank)
+
+                def forward(self, inp):
+                    if self.rank == 0:
+                        return self.fc1(inp), MyClass(self.fc2(inp))
+                    else:
+                        return self.fc1(inp), self.fc2(inp)
+
+            model = MyModel(self.rank)
+            input = torch.rand(10, 1024, requires_grad=True).cuda(self.rank)
+            ddp = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[self.rank],
+                find_unused_parameters=True,
+                bucket_cap_mb=(1024 * 4 / 1024 / 1024),  # One bucket per parameter.
+            )
+
+            if self.rank == 0:
+                out1, _ = ddp(input)
+                out1.sum().backward()
+            else:
+                out1, out2 = ddp(input)
+                (out1.sum() + out2.sum()).backward()
+
+            if self.rank == 0:
+                with self.assertRaisesRegex(
+                    RuntimeError,
+                    "Expected to have finished reduction in the prior iteration",
+                ):
+                    ddp._check_reducer_finalized()
+
+                with self.assertRaisesRegex(
+                    RuntimeError,
+                    "Expected to have finished reduction in the prior iteration",
+                ):
+                    ddp(input)
+            else:
+                ddp._check_reducer_finalized()
+                ddp(input)
+
+        """
+        # The set of "test_ddp_update_process_group..." below failed after
+        # upgrading CI from 2 GPUs to 4 GPUs.
+        # Commented out for now.
+        # Test purpose needs better documentation.
+
+        def _run_ddp_update_process_group(self, new_pg):
+            def get_num_torch_recompiles():
+                guard_failures = torch._dynamo.utils.guard_failures
+                num_recompiles = [len(guard_failures[code]) for code in guard_failures]
+                return 0 if len(num_recompiles) == 0 else max(num_recompiles)
+
+            class SimulateError(torch.autograd.Function):
+                @staticmethod
+                def forward(ctx, input):
+                    return input
+
+                @staticmethod
+                def backward(ctx, grad_output):
+                    raise RuntimeError
+
+            class MyModel(torch.nn.Module):
+                def __init__(self, device):
+                    super().__init__()
+                    # 4MB for multiple buckets.
+                    self.fc1 = torch.nn.Linear(1024, 1024).cuda(device)
+                    self.fc2 = torch.nn.Linear(1024, 1024).cuda(device)
+                    self.fc3 = torch.nn.Linear(1024, 1024).cuda(device)
+
+                def forward(self, inp, error):
+                    if error:
+                        return self.fc3(self.fc2(self.fc1(SimulateError.apply(inp))))
+                    else:
+                        return self.fc3(self.fc2(self.fc1(inp)))
+
+
+            input = torch.rand(10, 1024, requires_grad=True).cuda(self.rank)
+            ddp = torch.nn.parallel.DistributedDataParallel(
+                MyModel(self.rank),
+                device_ids=[self.rank],
+                find_unused_parameters=True,
+                bucket_cap_mb=1,
+            )
+            model = torch.compile(ddp)
+
+            def run_iteration():
+                # Run regular iteration.
+                out = model(input, error=False)
+                out.sum().backward()
+                torch.cuda.synchronize()
+
+                # Run with error.
+                with self.assertRaises(RuntimeError):
+                    out = model(input, error=True)
+                    out.sum().backward()
+                torch.cuda.synchronize()
+
+            run_iteration()
+            assert 0 == get_num_torch_recompiles()
+
+            if new_pg:
+                # Now reduce world_size and run iteration.
+                group_size_2 = dist.new_group(ranks=[0, 1])
+                ddp._update_process_group(group_size_2)
+                if self.rank in [0, 1]:
+                    run_iteration()
+
+                # Increase the world size and run iteration.
+                group_size_3 = dist.new_group(ranks=[1, 2, 3])
+                ddp._update_process_group(group_size_3)
+                if self.rank in [1, 2, 3]:
+                    run_iteration()
+
+                # Back to default size.
+                ddp._update_process_group(_get_default_group())
+                run_iteration()
+            else:
+                # Create default pg of smaller size.
+                dist.destroy_process_group()
+
+                if self.rank in [1, 2, 3]:
+                    dist.init_process_group(
+                        init_method=self.init_method,
+                        backend=BACKEND,
+                        world_size=3,
+                        rank=self.rank - 1,
+                        timeout=timedelta(seconds=default_pg_timeout),
+                    )
+                    ddp._update_process_group(_get_default_group())
+                    run_iteration()
+                    dist.destroy_process_group()
+
+                # Need a barrier here to ensure ranks 1, 2 and 3 are done.
+                self._barrier(wait_for=4)
+
+                # Need to init pg again for "_barrier" to succeed.
+                dist.init_process_group(
+                    init_method=self.init_method,
+                    backend=BACKEND,
+                    world_size=4,
+                    rank=self.rank,
+                    timeout=timedelta(seconds=default_pg_timeout),
+                )
+
+            # Validate no more recompiles.
+            assert 0 == get_num_torch_recompiles()
+
+        @skip_if_lt_x_gpu(4)
+        @require_world_size(4)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_update_process_group_new_group(self):
+            self._run_ddp_update_process_group(new_pg=True)
+
+        @skip_if_lt_x_gpu(4)
+        @require_world_size(4)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_update_process_group_default_group(self):
+            self._run_ddp_update_process_group(new_pg=False)
+
+        @skip_if_lt_x_gpu(4)
+        @require_world_size(4)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_update_process_group_grad_undefined(self):
+            class SimulateError(torch.autograd.Function):
+                @staticmethod
+                def forward(ctx, input):
+                    return input
+
+                @staticmethod
+                def backward(ctx, grad_output):
+                    raise RuntimeError
+
+            class MyModel(torch.nn.Module):
+                def __init__(self, device):
+                    super().__init__()
+                    self.fc1 = torch.nn.Linear(10, 10).cuda(device)
+                    self.fc2 = torch.nn.Linear(10, 10).cuda(device)
+                    self.fc3 = torch.nn.Linear(10, 10).cuda(device)
+
+                def forward(self, inp, error):
+                    if error:
+                        return self.fc3(self.fc2(self.fc1(SimulateError.apply(inp))))
+                    else:
+                        return self.fc2(self.fc1(inp))
+
+
+            input = torch.rand(10, 10, requires_grad=True).cuda(self.rank)
+            ddp = torch.nn.parallel.DistributedDataParallel(
+                MyModel(self.rank),
+                device_ids=[self.rank],
+                find_unused_parameters=True,
+                bucket_cap_mb=1,
+            )
+
+            try:
+                ddp(input, True).sum().backward()
+            except RuntimeError:
+                ddp._update_process_group(_get_default_group())
+
+            # Reset grads.
+            for param in ddp.parameters():
+                param.grad = None
+
+            # Run ddp again.
+            ddp(input, False).sum().backward()
+
+        @skip_if_lt_x_gpu(4)
+        @require_world_size(4)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_update_process_group_no_find_unused(self):
+            ddp = torch.nn.parallel.DistributedDataParallel(
+                torch.nn.Linear(10, 10).cuda(self.rank),
+                device_ids=[self.rank],
+                find_unused_parameters=False,
+            )
+            ddp._update_process_group(_get_default_group())
+        """
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_broadcast_buffer(self):
+            rank = self.rank
+            torch.cuda.set_device(rank)
+            torch.manual_seed(rank)
+            torch.cuda.manual_seed(rank)
+
+            class NetWithBuffers(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.a = nn.Linear(10, 10, bias=False)
+                    self.b = nn.Linear(10, 1, bias=False)
+                    self.register_buffer("buffer", torch.randn(1, 2))
+
+                def forward(self, x):
+                    return self.b(self.a(x))
+
+            model = NetWithBuffers().cuda(rank)
+            model_ddp = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[self.rank],
+            )
+            inp = torch.randn(2, 10, device=rank)
+            for _ in range(2):
+                if rank == 0:
+                    model_ddp.module.buffer = model_ddp.module.buffer + 1
+                loss = model_ddp(inp).sum()
+                loss.backward()
+                # Ensure all buffers are synchronized.
+                bufs = [
+                    torch.empty_like(model_ddp.module.buffer)
+                    for _ in range(dist.get_world_size())
+                ]
+                dist.all_gather(bufs, model_ddp.module.buffer)
+                rank_0_buf = bufs[0]
+                for buf in bufs[1:]:
+                    self.assertEqual(rank_0_buf, buf)
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl" and BACKEND != "gloo",
+            "Only Nccl & Gloo backend support DistributedDataParallel",
+        )
+        def test_static_graph_multi_forward(self):
+            class Net(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.lin = nn.Linear(10, 10)
+                    self.relu = nn.ReLU()
+
+                def forward(self, x):
+                    return self.relu(self.lin(x))
+
+            torch.cuda.set_device(self.rank)
+            torch.manual_seed(42 << 1337 % (self.rank + 1))
+            model = Net().cuda(self.rank)
+            local_model = copy.deepcopy(model)
+            model = torch.nn.parallel.DistributedDataParallel(
+                model, device_ids=[self.rank], static_graph=True
+            )
+            inp = torch.ones(2, 10, device="cuda")
+            for _ in range(3):
+                model.zero_grad()
+                local_model.zero_grad()
+                a = model(inp)
+                b = model(inp)
+                loss = a.sum() + b.sum()
+                loss.backward()
+                # Grads should be equal to a local model that ran through inp
+                # `world_size` times and averaged grads
+                if self.rank == 0:
+                    inp_clone = inp.clone()
+                    iters = dist.get_world_size()
+                    for _ in range(iters):
+                        a = local_model(inp_clone)
+                        b = local_model(inp_clone)
+                        loss = a.sum() + b.sum()
+                        loss.backward()
+
+                    for p in local_model.parameters():
+                        p.grad.data = p.grad / iters
+
+                    for p_ddp, p_local in zip(
+                        model.parameters(), local_model.parameters(), strict=True
+                    ):
+                        self.assertTrue(
+                            torch.allclose(p_ddp.grad, p_local.grad),
+                            f"{p_ddp.grad} vs {p_local.grad}",
+                        )
+
+            dist.barrier()
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl" and BACKEND != "gloo",
+            "Only Nccl & Gloo backend support DistributedDataParallel",
+        )
+        def test_sync_bn_logged(self):
+            model = BatchNormNet()
+            rank = self.rank
+            # single gpu training setup
+            model_gpu = model.cuda(rank)
+            no_sync_bn = torch.nn.parallel.DistributedDataParallel(
+                copy.deepcopy(model_gpu),
+                device_ids=[self.rank],
+            )
+            ddp_logging_data = no_sync_bn._get_ddp_logging_data()
+            sync_bn_logged = ddp_logging_data.get("has_sync_bn", True)
+            self.assertFalse(sync_bn_logged)
+            model_DDP = nn.SyncBatchNorm.convert_sync_batchnorm(model_gpu)
+            model_DDP = torch.nn.parallel.DistributedDataParallel(
+                model_DDP,
+                device_ids=[self.rank],
+            )
+            ddp_logging_data = model_DDP._get_ddp_logging_data()
+            sync_bn_logged = ddp_logging_data.get("has_sync_bn", False)
+            self.assertTrue(sync_bn_logged)
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_stateless_api_with_ddp(self):
+            class MockModule(torch.nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.l1 = torch.nn.Linear(1, 1)
+                    buffer = torch.ones(1)
+                    self.register_buffer("buffer", buffer)
+
+                def forward(self, x):
+                    return self.l1(x) + self.buffer
+
+            device = self.rank
+            module = MockModule().to(device)
+            module = torch.nn.parallel.DistributedDataParallel(
+                module, device_ids=[device]
+            )
+            x = torch.rand((1, 1)).to(device)
+            weight = torch.tensor([[1.0]], device=device, requires_grad=True)
+            bias = torch.tensor([0.0], device=device, requires_grad=True)
+            buffer = torch.tensor([0.0], device=device)
+            parameters = {
+                "module.l1.weight": weight,
+                "module.l1.bias": bias,
+                "module.buffer": buffer,
+            }
+            prev_weight = module.module.l1.weight.clone()
+            prev_buffer = module.module.buffer.clone()
+
+            res = torch.func.functional_call(module, parameters, x)
+            self.assertEqual(x, res)
+            # check that the weight remain unmodified
+            cur_weight = module.module.l1.weight
+            cur_buffer = module.module.buffer
+            self.assertEqual(cur_weight, prev_weight)
+            self.assertEqual(cur_buffer, prev_buffer)
+            # run a backward pass and check the gradients
+            res.backward()
+            self.assertIsNotNone(weight.grad)
+            self.assertIsNotNone(bias.grad)
+            # Gradient was not calculated for the module stated and buffers
+            self.assertIsNone(buffer.grad)
+            self.assertIsNone(module.module.l1.weight.grad)
+            self.assertIsNone(module.module.l1.bias.grad)
+            self.assertIsNone(module.module.buffer.grad)
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_forward_backward_hook(self):
+            class DummyTestModel(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    torch.manual_seed(0)
+                    self.fc = nn.Linear(2, 2)
+
+                def forward(self, x):
+                    return self.fc(x)
+
+            def relu_hook(module, input):
+                return nn.functional.relu(input[0])
+
+            def gelu_hook(module, _input, output):
+                return nn.functional.gelu(output)
+
+            def celu_hook(module, _input, output):
+                return (nn.functional.celu(output[0]),)
+
+            local_model = DummyTestModel()
+            ddp_model = DummyTestModel()
+            local_model.fc.register_forward_pre_hook(relu_hook)
+            local_model.fc.register_forward_hook(gelu_hook)
+            ddp_model.fc.register_forward_pre_hook(relu_hook)
+            ddp_model.fc.register_forward_hook(gelu_hook)
+            local_model.fc.register_backward_hook(celu_hook)
+            ddp_model.fc.register_backward_hook(celu_hook)
+            ddp_model = DistributedDataParallel(
+                ddp_model.to(self.rank), device_ids=[self.rank]
+            )
+            input_data = torch.rand(5, 2)
+            output_local = local_model(input_data)
+            output_ddp = ddp_model(input_data.to(self.rank))
+            self.assertEqual(output_local, output_ddp)
+            output_local.sum().backward()
+            output_ddp.sum().backward()
+            ddp_grads = [p.grad for p in ddp_model.parameters()]
+            self.assertEqual(ddp_grads[0], local_model.fc.weight.grad)
+            self.assertEqual(ddp_grads[1], local_model.fc.bias.grad)
+
+        def _test_hook_pickling(self, hook, hook_state):
+            torch.manual_seed(0)
+            learning_rate = 0.01
+            chkpt_file = tempfile.gettempdir() + "/checkpoint.pt"
+            rank = self.rank
+
+            input = torch.randn(7, 1, device=rank)
+            target = torch.randn(7, 5, device=rank)
+            net = torch.nn.Linear(1, 5).to(rank)
+            ddp_model = DistributedDataParallel(copy.deepcopy(net), device_ids=[rank])
+            dummy_ddp_model = DistributedDataParallel(
+                copy.deepcopy(net), device_ids=[rank]
+            )
+            optimizer = torch.optim.SGD(ddp_model.parameters(), lr=learning_rate)
+            ddp_model.register_comm_hook(hook_state, hook)
+            ddp_model.train()
+
+            for _ in range(10):
+                optimizer.zero_grad()
+                out = ddp_model(input)
+                loss = F.mse_loss(out, target)
+                loss.backward()
+                optimizer.step()
+
+            state = {
+                "state_dict": ddp_model.state_dict(),
+                "comm_hook": hook,
+                "comm_hook_state": hook_state,
+            }
+
+            if rank == 0:
+                with self.assertLogs("torch.distributed") as captured:
+                    torch.save(state, chkpt_file)
+
+                # Check that the logger has only one entry
+                self.assertEqual(len(captured.records), 1)
+                # Check that the logger has an expected entry
+                self.assertEqual(
+                    captured.records[0].getMessage(),
+                    "NOTE: Process group is not serializable and excluded from a saved state.",
+                )
+
+            dist.barrier()
+            map_location = {"cuda:0": f"cuda:{rank:d}"}
+            with self.assertLogs("torch.distributed") as captured:
+                checkpoint = torch.load(chkpt_file, map_location=map_location)
+
+            # Check that the logger has only one entry
+            self.assertEqual(len(captured.records), 1)
+            # Check that the logger has an expected entry
+            self.assertEqual(
+                captured.records[0].getMessage(),
+                "NOTE: Process group will be set to a default group (i.e. the world size).\
+                If a different group is desired, please set `self.process_group` after PowerSGD state is loaded.",
+            )
+
+            dummy_ddp_model.load_state_dict(checkpoint["state_dict"])
+            dummy_hook = checkpoint["comm_hook"]
+            dummy_hook_state = checkpoint["comm_hook_state"]
+            dummy_optimizer = torch.optim.SGD(
+                dummy_ddp_model.parameters(), lr=learning_rate
+            )
+
+            # Check that loaded function is correct
+            self.assertEqual(dummy_hook.__qualname__, hook.__qualname__)
+
+            # Check that all slots' keys were restored correctly
+            self.assertEqual(hook_state.__slots__, dummy_hook_state.__slots__)
+
+            # Check that all slots' attributes are restored correctly
+            # Excluding ``process_group`` and ``rng``.
+            for entry in dummy_hook_state.__slots__:
+                if entry != "process_group" and entry != "rng":
+                    self.assertEqual(
+                        getattr(dummy_hook_state, entry), getattr(hook_state, entry)
+                    )
+
+            # Check that ``process_group`` was set to default
+            self.assertEqual(dummy_hook_state.process_group, _get_default_group())
+
+            # Check that a random state was restored properly:
+            # ``np.random.RandomState.get_state`` returns a tuple with entries:
+            # ``bit_generator`` - str,
+            # ``state.key`` - ndarray dtype[uint32],
+            # ``state.pos`` - int,
+            # ``has_gauss`` - int,
+            # ``gauss`` - float
+            #  (refer to https://github.com/numpy/numpy/blob/266aad7478bc7fbcc55eea7f942a0d373b838396/numpy/random/mtrand.pyi)
+            # To make sure random state was restored properly, all entries should equal the original
+            for entry1, entry2 in zip(
+                hook_state.rng.get_state(),
+                dummy_hook_state.rng.get_state(),
+                strict=True,
+            ):
+                np.testing.assert_array_equal(entry1, entry2)
+
+            dummy_ddp_model.register_comm_hook(dummy_hook_state, dummy_hook)
+            dummy_ddp_model.train()
+
+            for _ in range(10):
+                optimizer.zero_grad()
+                dummy_optimizer.zero_grad()
+                out_origin = ddp_model(input)
+                out_dummy = dummy_ddp_model(input)
+                loss_origin = F.mse_loss(out_origin, target)
+                loss_dummy = F.mse_loss(out_dummy, target)
+                loss_origin.backward()
+                loss_dummy.backward()
+                optimizer.step()
+                dummy_optimizer.step()
+
+            # Check that gradients after 10 epochs are the same
+            for orig_param, dummy_param in zip(
+                ddp_model.parameters(), dummy_ddp_model.parameters(), strict=True
+            ):
+                self.assertEqual(orig_param.grad, dummy_param.grad)
+
+            dist.barrier()
+            if rank == 0:
+                os.remove(chkpt_file)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["cuda"],
+            f"The {BACKEND} backend does not support DDP communication hook on CUDA devices",
+        )
+        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
+        @skip_but_pass_in_sandcastle_if(True, "Skipped due to flakiness")
+        def test_ddp_hook_pickling_powerSGD(self):
+            hook = powerSGD.powerSGD_hook
+            powersgd_state = powerSGD.PowerSGDState(
+                process_group=None,
+                matrix_approximation_rank=1,
+                start_powerSGD_iter=4,
+            )
+            self._test_hook_pickling(hook, powersgd_state)
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_device_mesh_initialization(self):
+            """
+            Test DDP with device_mesh initialization.
+            """
+            world_size = int(os.environ["WORLD_SIZE"])
+
+            from torch.distributed.device_mesh import init_device_mesh
+
+            device_mesh = init_device_mesh("cuda", (world_size,))
+
+            pg = _get_default_group()
+
+            torch.cuda.set_device(self.rank)
+            model = TwoLinLayerNet().cuda()
+            ddp_model = torch.nn.parallel.DistributedDataParallel(
+                model, device_mesh=device_mesh
+            )
+            self.assertEqual(ddp_model.device_mesh, device_mesh)
+
+            with self.assertRaisesRegex(
+                RuntimeError,
+                "Cannot specify both process_group and device_mesh arguments.",
+            ):
+                ddp_model = torch.nn.parallel.DistributedDataParallel(
+                    model, process_group=pg, device_mesh=device_mesh
+                )
+
+            with self.assertRaisesRegex(
+                RuntimeError, "Only 1D device mesh is supported,"
+            ):
+                device_mesh = init_device_mesh("cuda", (2, world_size // 2))
+                ddp_model = torch.nn.parallel.DistributedDataParallel(
+                    model, device_mesh=device_mesh
+                )
+
+        @skip_if_lt_x_gpu(2)
+        @require_world_size(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_compile_static_graph(self):
+            "Tests that DDP works with torch compile when static_graph=True"
+            model = torch.nn.Linear(10, 10).cuda(self.rank)
+            model_clone = copy.deepcopy(model)
+            ddp = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[self.rank],
+            )
+            ddp_static = torch.nn.parallel.DistributedDataParallel(
+                model_clone, device_ids=[self.rank], static_graph=True
+            )
+            ddp = torch.compile(ddp)
+            ddp_static = torch.compile(ddp_static)
+            input = torch.rand(10, 10).cuda(self.rank)
+            # verify output and gradient parity
+            for _ in range(6):
+                out_ddp = ddp(input).sum()
+                out_ddp_static = ddp_static(input).sum()
+                self.assertEqual(out_ddp, out_ddp_static)
+                out_ddp.backward()
+                out_ddp_static.backward()
+                for p1, p2 in zip(
+                    ddp.parameters(), ddp_static.parameters(), strict=True
+                ):
+                    self.assertEqual(p1.grad, p2.grad)
+
+        @skip_if_lt_x_gpu(2)
+        @require_world_size(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_sink_noclone(self):
+            "Tests that we can configure DDP to avoid clone"
+
+            class OpPatcher(TorchDispatchMode):
+                def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+                    func_packet = func._overloadpacket
+                    if func_packet == torch.ops.aten.clone:
+                        raise RuntimeError("clone encountered!")
+                    kwargs = kwargs if kwargs else {}
+                    return func(*args, **kwargs)
+
+            class MyModel(torch.nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.fc = torch.nn.Linear(10, 10)
+
+                def forward(self, input):
+                    return self.fc(input)
+
+            model = MyModel().cuda(self.rank)
+            ddp = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[self.rank],
+                find_unused_parameters=True,
+            )
+            ddp._set_ddp_sink_clone(False)
+            input = torch.rand(10, 10).cuda(self.rank)
+
+            with OpPatcher():
+                ddp(input).sum().backward()
+
+        def _test_skip_all_reduce_unused_parameters(
+            self,
+            find_unused_parameters=False,
+            static_graph=False,
+            skip_all_reduce_unused_params=False,
+        ):
+            class LargeNet(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.fc1 = nn.Linear(100, 5000, bias=False)
+                    # fc2 is unused
+                    self.fc2 = nn.Linear(100, 100, bias=False)
+
+                def forward(self, x):
+                    y = self.fc1(x)
+                    return y
+
+            torch.manual_seed(31415)
+            torch.cuda.set_device(self.rank)
+            model = LargeNet().cuda(self.rank)
+            ddp_model = torch.nn.parallel.DistributedDataParallel(
+                model,
+                find_unused_parameters=find_unused_parameters,
+                static_graph=static_graph,
+                bucket_cap_mb=1.5,
+                skip_all_reduce_unused_params=skip_all_reduce_unused_params,
+            )
+            random_input = torch.randn(20, 100, device=self.rank)
+            for _ in range(10):
+                out = ddp_model(random_input)
+                loss = out.sum()
+                loss.backward()
+            return ddp_model
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_skip_all_reduce_unused_parameters(self):
+            base_model = self._test_skip_all_reduce_unused_parameters(
+                find_unused_parameters=True, static_graph=False
+            )
+            test_model_1 = self._test_skip_all_reduce_unused_parameters(
+                find_unused_parameters=True,
+                static_graph=False,
+                skip_all_reduce_unused_params=True,
+            )
+
+            self.assertEqual(
+                base_model._get_ddp_logging_data().get("num_buckets_reduced"), 2
+            )
+            self.assertEqual(
+                test_model_1._get_ddp_logging_data().get("num_buckets_reduced"), 1
+            )
+
+            for i, j in zip(
+                base_model.parameters(), test_model_1.parameters(), strict=True
+            ):
+                self.assertEqual(i, j)
+
+
+instantiate_parametrized_tests(DistributedTest._DistTestBase)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/distributed_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/distributed_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..10002da5854421a2d53076eb8458f42ac7a1e4e2
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/distributed_utils.py
@@ -0,0 +1,67 @@
+# mypy: allow-untyped-defs
+
+from contextlib import contextmanager
+from datetime import timedelta
+from functools import partial, wraps
+
+import torch.distributed as dist
+import torch.distributed.distributed_c10d as c10d
+
+
+class MockProcessGroup(dist.ProcessGroup):
+    def getBackendName(self):
+        return "mock_process_group"
+
+
+def create_mock_pg(prefix_store, rank, world_size, timeout):
+    return MockProcessGroup(rank, world_size)
+
+
+dist.Backend.register_backend("mock_process_group", create_mock_pg)
+
+
+def mock_init_dist(rank, world_size):
+    # !!! WARNING !!!
+    # Kids don't try this at home, this is a cute pile of hacks that
+    # depends on a small mountain of c10d internals
+    assert not dist.is_initialized()
+    store = dist.HashStore()
+    # Trick _store_based_barrier into believing everyone else already checked-in
+    # Zero is the group index
+    store.add(f"{c10d.STORE_BASED_BARRIER_PREFIX}:0", world_size - 1)
+    dist.init_process_group(
+        backend="mock_process_group",
+        rank=rank,
+        world_size=world_size,
+        store=store,
+        group_name="fake",
+        timeout=timedelta(seconds=1),
+    )
+
+
+@contextmanager
+def with_dist(rank=0, world_size=2):
+    """
+    Context manager that initializer c10d with a fake process group.
+    """
+    mock_init_dist(rank=rank, world_size=world_size)
+    try:
+        yield
+    finally:
+        dist.destroy_process_group()
+
+
+def with_fake_comms(func=None, rank=0, world_size=2):
+    """
+    Function wrapper that inits a fake process group designed for testing.
+    Right now only querying for world size is available
+    """
+    if func is None:
+        return partial(with_fake_comms, rank=rank, world_size=world_size)
+
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        with with_dist(rank, world_size):
+            func(self, *args, **kwargs)
+
+    return wrapper
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/fake_pg.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/fake_pg.py
new file mode 100644
index 0000000000000000000000000000000000000000..de9c2cc7ee52093b555d94e5f4426fcbb6721b47
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/fake_pg.py
@@ -0,0 +1,32 @@
+# mypy: allow-untyped-defs
+
+import torch.distributed as dist
+from torch._C._distributed_c10d import FakeProcessGroup
+
+
+class FakeStore(dist.Store):
+    """
+    A fake store is a fake Key-Value store simply for initialization usage
+    the of fake process group, one can either use FakeStore or HashStore.
+    """
+
+
+def _create_fake_pg(common_opts, backend_opts):
+    """
+    A fake process group (not related to FakeTensor) is a process group which
+    doesn't actually do any communication, it just hallucinates some
+    communication.  You can run a single rank with a fake process group
+    without needing multiple processes (simulates per-rank behavior)
+
+    NOTE: This is not a real process group, and it would produce wrong results
+    for every collective. It should be used as a convenient tool when playing
+    with distributed but don't care about the actual data.
+    """
+    return FakeProcessGroup._create_internal(
+        common_opts.group_rank, common_opts.group_size, backend_opts
+    )
+
+
+dist.Backend.register_backend(
+    "fake", _create_fake_pg, extended_api=True, devices=["cpu", "cuda", "hpu", "xpu"]
+)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/multi_threaded_pg.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/multi_threaded_pg.py
new file mode 100644
index 0000000000000000000000000000000000000000..79aff05b3421f37cf63501e5692f84723be73439
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/multi_threaded_pg.py
@@ -0,0 +1,611 @@
+# mypy: allow-untyped-defs
+
+import sys
+import threading
+import weakref
+from dataclasses import dataclass
+from functools import partial, reduce
+from typing import Optional, Union
+
+import torch
+import torch.distributed as dist
+from torch._C._distributed_c10d import (
+    _create_work_from_future,
+    AllgatherOptions,
+    AllreduceOptions,
+    AllToAllOptions,
+    BarrierOptions,
+    BroadcastOptions,
+    ReduceOp,
+    ReduceScatterOptions,
+    ScatterOptions,
+    Store,
+)
+from torch.distributed.distributed_c10d import _CollOp, _store_based_barrier, P2POp
+from torch.futures import Future
+from torch.utils import _pytree as pytree
+
+
+"""
+TODO:
+Lots of missing collectives.
+Collectives validation.
+Make timeout robust by making collectives respect the test deadline.
+Make tests robust by making collectives interruptible.
+We need some synchronization around cleanup to ensure that timedout ranks don't cause spurious failures.
+
+"""
+
+
+def flatten_list(lst):
+    return pytree.tree_leaves(lst)
+
+
+def ret_work(ret):
+    fut = Future()
+    fut.set_result(ret)
+    return _create_work_from_future(fut)
+
+
+def binop_reduce(tensors, op):
+    res = op(torch.stack(tensors), dim=0)
+    if isinstance(res, torch.Tensor):
+        return res
+    # min/max return a namedtuple
+    return res.values
+
+
+def bitwise_reduce(tensors, op):
+    return reduce(op, tensors)
+
+
+_reduce_ops = {
+    ReduceOp.SUM: partial(binop_reduce, op=torch.sum),
+    ReduceOp.AVG: partial(binop_reduce, op=torch.mean),
+    ReduceOp.PRODUCT: partial(binop_reduce, op=torch.prod),
+    ReduceOp.MIN: partial(binop_reduce, op=torch.min),
+    ReduceOp.MAX: partial(binop_reduce, op=torch.max),
+    ReduceOp.BAND: partial(bitwise_reduce, op=torch.bitwise_and),
+    ReduceOp.BOR: partial(bitwise_reduce, op=torch.bitwise_or),
+    ReduceOp.BXOR: partial(bitwise_reduce, op=torch.bitwise_xor),
+}
+
+
+# Note [Hide collectives mutation from autograd]
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# Threaded PG is intended to closely simulate the behavior of regular process
+# groups.  However, our regular PG implementations perform a dispatch through
+# c10d, whereas Threaded PG does not for some reason (some superficial
+# but not very convincing reasons include that Threaded PG is implemented
+# in Python but you can't override Backend in Python, you can only override
+# ProcessGroup in Python), thereby bypassing the dispatch step.  Now we have
+# a problem: c10d's signatures are LIES, they mutate their (output) tensor
+# arguments but their annotations don't have mutations on them so we don't
+# actually update any view metadata if you do differentiation.  This
+# ordinarily "doesn't matter" because distributed collectives aren't
+# differentiable anyway, but it's possible to tickle this in testing if
+# someone tries to touch the grad_fn of a Tensor.  There a few ways to
+# fix this, but the easiest way was to use the .detach() trick to hide
+# the mutations from autograd.
+
+
+class AllToAll:
+    @torch.no_grad()
+    def work(self, data):
+        world_size = len(data)
+        for dest_rank in range(world_size):
+            output_tensor_list, _ = data[dest_rank]
+            for src_rank in range(world_size):
+                _, input_tensor_list = data[src_rank]
+                # See Note [Hide collectives mutation from autograd]
+                output_tensor_list[src_rank].detach().copy_(
+                    input_tensor_list[dest_rank]
+                )
+
+
+class AllToAllBase:
+    @torch.no_grad()
+    def work(self, data):
+        world_size = len(data)
+        for dest_rank in range(world_size):
+            output_buffer, _, output_split_sizes, _ = data[dest_rank]
+
+            output_indexes = self._size_cumsum(
+                output_buffer.size(0), output_split_sizes, world_size
+            )
+
+            for src_rank in range(world_size):
+                _, input_buffer, _, input_split_sizes = data[src_rank]
+                input_indexes = self._size_cumsum(
+                    input_buffer.size(0), input_split_sizes, world_size
+                )
+
+                # See Note [Hide collectives mutation from autograd]
+                output_buffer[
+                    output_indexes[src_rank] : output_indexes[src_rank + 1]
+                ].detach().copy_(
+                    input_buffer[
+                        input_indexes[dest_rank] : input_indexes[dest_rank + 1]
+                    ]
+                )
+
+    def _size_cumsum(
+        self,
+        buf_size: int,
+        sizes: Union[torch.Tensor, list[int], None],
+        world_size: int,
+    ) -> torch.Tensor:
+        if sizes is None or len(sizes) == 0:
+            sizes = torch.full((world_size,), buf_size // world_size, dtype=torch.int64)
+        if not isinstance(sizes, torch.Tensor):
+            sizes = torch.tensor(sizes, dtype=torch.int64)
+        assert sizes.dtype == torch.int64
+        sizes = torch.cumsum(
+            torch.cat(
+                (torch.tensor([0], dtype=torch.int64, device=sizes.device), sizes),
+                dim=0,
+            ),
+            dim=0,
+        )
+        return sizes
+
+
+class AllReduce:
+    def __init__(self, op):
+        if op.op not in _reduce_ops:
+            raise NotImplementedError(
+                f"AllReduce op {op.op} not supported on multithreaded pg for now."
+            )
+        self.op = op.op
+
+    @torch.no_grad()
+    def work(self, data):
+        for i in range(len(data[0])):
+            # use rank0 as the device for sum
+            rank_0_device = data[0][i].device
+            # collect all data to the list and make them
+            # all on rank 0 device
+            tensors = [
+                data[src_rank][i].to(rank_0_device) for src_rank in range(len(data))
+            ]
+
+            # now mimic reduce across all ranks
+            res = _reduce_ops[self.op](tensors)
+
+            # copy all the reduced value to each rank
+            for src_rank in range(len(data)):
+                # See Note [Hide collectives mutation from autograd]
+                data[src_rank][i].detach().copy_(res.to(data[src_rank][i].device))
+
+
+class AllGather:
+    @torch.no_grad()
+    def work(self, data):
+        for src_rank in range(len(data)):
+            in_tensor_list = data[src_rank][1]
+            # Can't handle all_gather with multiple tensors
+            assert len(in_tensor_list) == 1
+            src_tensor = in_tensor_list[0]
+
+            for dest in data:
+                dest_tensor = dest[0][0][src_rank]
+                # See Note [Hide collectives mutation from autograd]
+                dest_tensor.detach().copy_(src_tensor)
+
+
+class Scatter:
+    def __init__(self, src):
+        self.src = src
+
+    @torch.no_grad()
+    def work(self, data):
+        src_in_tensor_list = data[self.src][1]
+        # Can't handle scatter with multiple input tensor list
+        assert len(src_in_tensor_list) == 1
+        src_in_tensors = src_in_tensor_list[0]
+
+        for rank, each_rank_data in enumerate(data):
+            out_tensor_list = each_rank_data[0]
+            # Can't handle scatter with multiple output tensor
+            assert len(out_tensor_list) == 1
+            dest_tensor = out_tensor_list[0]
+            # See Note [Hide collectives mutation from autograd]
+            dest_tensor.detach().copy_(src_in_tensors[rank])
+
+
+class Gather:
+    def __init__(self, dst):
+        self.dst = dst
+
+    @torch.no_grad()
+    def work(self, data):
+        # Can't handle gather with multiple tensor lists
+        assert len(data[self.dst][0]) == 1
+        out_tensor_list = data[self.dst][0][0]
+        for rank, each_rank_data in enumerate(data):
+            src_in_tensor_list = each_rank_data[1]
+            # Can't handle gather with multiple tensor lists
+            assert len(src_in_tensor_list) == 1
+            dest_tensor = out_tensor_list[rank]
+            # See Note [Hide collectives mutation from autograd]
+            dest_tensor.detach().copy_(src_in_tensor_list[0])
+
+
+class ReduceScatter:
+    def __init__(self, op):
+        if op != dist.ReduceOp.SUM and op != dist.ReduceOp.AVG:
+            raise NotImplementedError(f"ReduceScatter does not support {op}")
+        self.op = op
+
+    @torch.no_grad()
+    def work(self, data):
+        start_reduction = [False for _ in range(len(data))]
+        for each_rank_data in data:
+            # Can't handle reduce_scatter with multiple scatter list
+            assert len(each_rank_data[1]) == 1
+            to_scatter = each_rank_data[1][0]
+            for i in range(len(to_scatter)):
+                dest_tensor_on_rank_i = data[i][0]
+                # Can't handle reduce_scatter with multiple output tensor
+                assert len(dest_tensor_on_rank_i) == 1
+                dst_tensor_device = dest_tensor_on_rank_i[0].device
+                if not start_reduction[i]:
+                    # See Note [Hide collectives mutation from autograd]
+                    dest_tensor_on_rank_i[0].detach().copy_(
+                        to_scatter[i].to(dst_tensor_device)
+                    )
+                    start_reduction[i] = True
+                else:
+                    # See Note [Hide collectives mutation from autograd]
+                    dest_tensor_on_rank_i[0].detach().add_(
+                        to_scatter[i].to(dst_tensor_device)
+                    )
+        if self.op == dist.ReduceOp.AVG:
+            num_ranks = len(data)
+            for each_rank_data in data:
+                # See Note [Hide collectives mutation from autograd]
+                each_rank_data[0][0].detach().div_(num_ranks)
+
+
+class Broadcast:
+    def __init__(self, src):
+        self.src = src
+
+    @torch.no_grad()
+    def work(self, data):
+        in_tensor_list = flatten_list(data[self.src])
+        for i in range(len(data)):
+            if i == self.src:
+                continue
+            out_tensor_list = flatten_list(data[i])
+            for j in range(len(in_tensor_list)):
+                # See Note [Hide collectives mutation from autograd]
+                out_tensor_list[j].detach().copy_(in_tensor_list[j])
+
+
+class Collective:
+    def __init__(self, world_size, collective, pg):
+        self._world_size = world_size
+        self._collective = collective
+
+        self._start_cond = threading.Condition()
+        self._done_cond = threading.Condition()
+
+        self._data = [None] * world_size
+        self._count = 0
+        self._done = False
+
+        self._pg = pg
+
+    def join(self, rank, data):
+        with self._start_cond:
+            self._data[rank] = data
+            self._count += 1
+
+            # notify rank 0
+            if self._count == self._world_size:
+                if rank > 0:
+                    self._start_cond.notify()
+
+            if rank == 0:
+                self._start_cond.wait_for(
+                    lambda: self._count == self._world_size
+                    or self._pg._terminate.is_set()
+                )
+                # SystemExit is not a subclass of Exception but BaseException
+                # and can be distinguished from normal exception raised from program errors
+                # so that we can hide it from the exception queue
+                if self._pg._terminate.is_set():
+                    sys.exit("Test termination event occurs.")
+
+        with self._done_cond:
+            # wait for rank 0 to finish
+            if rank > 0:
+                self._done_cond.wait_for(
+                    lambda: self._done or self._pg._terminate.is_set()
+                )
+                if self._pg._terminate.is_set():
+                    sys.exit("Test termination event occurs.")
+            else:
+                # copy data around
+                self._collective.work(self._data)
+                self._done = True
+                self._done_cond.notify_all()
+        return ret_work(data)
+
+
+class ProcessLocalGroup(dist.ProcessGroup):
+    _coll_lock = threading.Lock()
+    _cur_coll_on_pgs = {}
+
+    _terminate = threading.Event()
+
+    @classmethod
+    def _start_coll(cls, collective, pg):
+        with cls._coll_lock:
+            # pg_name is unique, we use that to record the mapping between pg and collective
+            if pg.pg_name not in cls._cur_coll_on_pgs:
+                cls._cur_coll_on_pgs[pg.pg_name] = Collective(
+                    pg.size(), collective, cls
+                )
+            return cls._cur_coll_on_pgs[pg.pg_name]
+
+    @classmethod
+    def _end_coll(cls, collective, pg):
+        # This is racily called by all ranks, so only one will work
+        with cls._coll_lock:
+            if (
+                pg.pg_name in cls._cur_coll_on_pgs
+                and cls._cur_coll_on_pgs[pg.pg_name] == collective
+            ):
+                cls._cur_coll_on_pgs.pop(pg.pg_name)
+
+    @classmethod
+    def exception_handle(cls, exc):
+        cls._terminate.set()
+        for coll in cls._cur_coll_on_pgs.values():
+            with coll._start_cond:
+                coll._start_cond.notify()
+            with coll._done_cond:
+                coll._done_cond.notify_all()
+
+    @classmethod
+    def reset(cls):
+        with cls._coll_lock:
+            cls._cur_coll_on_pgs = {}
+            cls._terminate.clear()
+
+    def alltoall_base(
+        self,
+        output_buffer: torch.Tensor,
+        input_buffer: torch.Tensor,
+        output_split_sizes: Optional[list[int]],
+        input_split_sizes: Optional[list[int]],
+        opts=AllToAllOptions(),
+    ) -> torch.Tensor:
+        coll = ProcessLocalGroup._start_coll(AllToAllBase(), self)
+        res = coll.join(
+            self._rank,
+            (output_buffer, input_buffer, output_split_sizes, input_split_sizes),
+        )
+        ProcessLocalGroup._end_coll(coll, self)
+        return res
+
+    def alltoall(self, output_tensor_list, input_tensor_list, opts=AllToAllOptions()):
+        coll = ProcessLocalGroup._start_coll(AllToAll(), self)
+        res = coll.join(self._rank, (output_tensor_list, input_tensor_list))
+        ProcessLocalGroup._end_coll(coll, self)
+        return res
+
+    def allreduce(self, tensor_list, opts=AllreduceOptions()):
+        coll = ProcessLocalGroup._start_coll(AllReduce(opts.reduceOp), self)
+        res = coll.join(self._rank, tensor_list)
+        ProcessLocalGroup._end_coll(coll, self)
+        return res
+
+    def allreduce_coalesced(self, tensor_list, opts=AllreduceOptions()):
+        coll = ProcessLocalGroup._start_coll(AllReduce(opts.reduceOp), self)
+        res = coll.join(self._rank, tensor_list)
+        ProcessLocalGroup._end_coll(coll, self)
+        return res
+
+    def barrier(self, opts=BarrierOptions()):
+        return self.allreduce(tensor_list=[torch.ones(1)])
+
+    def allgather(self, output_tensors, input_tensor, opts=AllgatherOptions()):
+        coll = ProcessLocalGroup._start_coll(AllGather(), self)
+        res = coll.join(self._rank, (output_tensors, input_tensor))
+        ProcessLocalGroup._end_coll(coll, self)
+        return res
+
+    def _allgather_base(self, output_tensor, input_tensor, opts=AllgatherOptions()):
+        tensor_list = list(torch.chunk(output_tensor, self._world_size))
+        return self.allgather([tensor_list], [input_tensor], opts)
+
+    def broadcast(self, tensor_list, opts=BroadcastOptions()):
+        coll = ProcessLocalGroup._start_coll(Broadcast(opts.rootRank), self)
+        res = coll.join(self._rank, tensor_list)
+        ProcessLocalGroup._end_coll(coll, self)
+        return res
+
+    def scatter(self, output_tensors, input_tensors, opts=ScatterOptions()):
+        coll = ProcessLocalGroup._start_coll(Scatter(opts.rootRank), self)
+        res = coll.join(self._rank, (output_tensors, input_tensors))
+        ProcessLocalGroup._end_coll(coll, self)
+        return res
+
+    def gather(self, output_tensors, input_tensors, opts=ScatterOptions()):
+        coll = ProcessLocalGroup._start_coll(Gather(opts.rootRank), self)
+        res = coll.join(self._rank, (output_tensors, input_tensors))
+        ProcessLocalGroup._end_coll(coll, self)
+        return res
+
+    def reduce_scatter(self, output_tensor, scatter_list, opts=ReduceScatterOptions()):
+        coll = ProcessLocalGroup._start_coll(ReduceScatter(opts.reduceOp), self)
+        res = coll.join(self._rank, (output_tensor, scatter_list))
+        ProcessLocalGroup._end_coll(coll, self)
+        return res
+
+    def _reduce_scatter_base(
+        self, output_tensor, input_tensor, opts=ReduceScatterOptions()
+    ):
+        tensor_list = list(torch.chunk(input_tensor, self._world_size))
+        return self.reduce_scatter([output_tensor], [tensor_list], opts)
+
+    def reduce_scatter_tensor_coalesced(
+        self, output_tensors, input_tensors, opts=ReduceScatterOptions()
+    ):
+        works = [
+            self._reduce_scatter_base(output_tensor, input_tensor, opts)
+            for output_tensor, input_tensor in zip(
+                output_tensors, input_tensors, strict=True
+            )
+        ]
+        for work in works[:-1]:
+            work.wait()
+        return works[-1]
+
+    def allgather_into_tensor_coalesced(
+        self, output_tensor_list, input_tensor_list, opts=AllgatherOptions()
+    ):
+        res = None
+        for o_t, i_t in zip(output_tensor_list, input_tensor_list, strict=True):
+            res = self._allgather_base(o_t, i_t)
+        return res
+
+    def __init__(self, rank, world_size):
+        super().__init__(rank, world_size)
+        self._rank = rank
+        self._world_size = world_size
+        world = dist.distributed_c10d._world
+        if isinstance(world, ThreadLocalWorld):
+            world = world._get_world()
+        self._world = weakref.ref(world)
+        self._ctx = torch.autograd.set_multithreading_enabled(False)
+
+    def size(self):
+        return self._world_size
+
+    @property
+    def pg_name(self):
+        """
+        return the global registered name of the current pg in the world
+        """
+        return self._world().pg_names[self]
+
+    @property
+    def group_name(self):
+        return self.pg_name
+
+    def getBackendName(self):
+        return "threaded"
+
+    def __repr__(self):
+        return f"ThreadedPG world_size:{self._world_size} rank:{self._rank}"
+
+
+def _create_threaded_pg(prefix_store, rank, world_size, timeout):
+    pg = ProcessLocalGroup(rank, world_size)
+    # https://github.com/pytorch/pytorch/pull/103033 changed store based barrier to optional
+    # When device mesh involves sub groups while store based barrier is not enabled in c10d,
+    # even though threaded pg actual collectives are assumed to be single threaded,
+    # different threads may be initializing different groups,
+    # leading to race conditions.
+    # For example, if we have a mesh of [[0, 1], [2, 3]], the sub groups
+    # (dim 0 and 1) would be initialized in different threads independently.
+    # In this case we can no longer rely on class or global variables
+    # but have to rely on store based barrier to make sure each group
+    # is ready separately before we can invoke collectives in any of the groups.
+
+    # the prefix store is already per group so we pass an empty name here
+    _store_based_barrier(rank, prefix_store, "", world_size, timeout)
+    return pg
+
+
+dist.Backend.register_backend("threaded", _create_threaded_pg, devices=["cpu", "cuda"])
+
+
+@dataclass
+class WorldData:
+    default_pg: dist.ProcessGroup
+    pg_map: dict[dist.ProcessGroup, tuple[str, Optional[Store]]]
+    pg_names: dict[dist.ProcessGroup, str]
+    pg_group_ranks: dict[dist.ProcessGroup, dict[int, int]]
+    pg_backend_config: dict[dist.ProcessGroup, str]
+    group_count: int
+    tags_to_pg: dict[str, list[dist.ProcessGroup]]
+    pg_to_tag: dict[dist.ProcessGroup, str]
+    pg_coalesce_state: dict[dist.ProcessGroup, list[Union[_CollOp, P2POp]]]
+
+
+class ThreadLocalWorld:
+    _world = threading.local()
+
+    def _get_world(self) -> WorldData:
+        if not hasattr(ThreadLocalWorld._world, "world"):
+            ThreadLocalWorld._world.world = WorldData(
+                None, {}, {}, {}, {}, 0, {}, {}, {}
+            )
+        return ThreadLocalWorld._world.world
+
+    @property
+    def default_pg(self):
+        return self._get_world().default_pg
+
+    @default_pg.setter
+    def default_pg(self, value):
+        self._get_world().default_pg = value
+
+    @property
+    def pg_map(self):
+        return self._get_world().pg_map
+
+    @property
+    def pg_names(self):
+        return self._get_world().pg_names
+
+    @property
+    def pg_group_ranks(self):
+        return self._get_world().pg_group_ranks
+
+    @property
+    def pg_backend_config(self):
+        return self._get_world().pg_backend_config
+
+    @property
+    def group_count(self) -> int:
+        return self._get_world().group_count
+
+    @group_count.setter
+    def group_count(self, value):
+        self._get_world().group_count = value
+
+    @property
+    def tags_to_pg(self):
+        return self._get_world().tags_to_pg
+
+    @property
+    def pg_to_tag(self):
+        return self._get_world().pg_to_tag
+
+    @property
+    def pg_coalesce_state(self) -> dict[dist.ProcessGroup, list[Union[_CollOp, P2POp]]]:
+        return self._get_world().pg_coalesce_state
+
+
+_old_pg_world = None
+_ctx_manager = None
+
+
+def _install_threaded_pg():
+    global _old_pg_world
+    global _ctx_manager
+    _old_pg_world = dist.distributed_c10d._world
+    dist.distributed_c10d._world = ThreadLocalWorld()
+    _ctx_manager = torch.autograd.set_multithreading_enabled(False)
+
+    return dist.distributed_c10d._world
+
+
+def _uninstall_threaded_pg():
+    dist.distributed_c10d._world = _old_pg_world
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a24e4f97f05df22396dc08e3e6bc381085477882
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc_utils.py
@@ -0,0 +1,188 @@
+# mypy: allow-untyped-defs
+
+import os
+import sys
+import unittest
+
+from torch.testing._internal.common_distributed import MultiProcessTestCase
+from torch.testing._internal.common_utils import (
+    find_free_port,
+    IS_SANDCASTLE,
+    TEST_WITH_DEV_DBG_ASAN,
+)
+from torch.testing._internal.distributed.ddp_under_dist_autograd_test import (
+    CudaDdpComparisonTest,
+    DdpComparisonTest,
+    DdpUnderDistAutogradTest,
+)
+from torch.testing._internal.distributed.nn.api.remote_module_test import (
+    CudaRemoteModuleTest,
+    RemoteModuleTest,
+    ThreeWorkersRemoteModuleTest,
+)
+from torch.testing._internal.distributed.rpc.dist_autograd_test import (
+    CudaDistAutogradTest,
+    DistAutogradTest,
+    FaultyAgentDistAutogradTest,
+    TensorPipeAgentDistAutogradTest,
+    TensorPipeCudaDistAutogradTest,
+)
+from torch.testing._internal.distributed.rpc.dist_optimizer_test import (
+    DistOptimizerTest,
+)
+from torch.testing._internal.distributed.rpc.examples.parameter_server_test import (
+    ParameterServerTest,
+)
+from torch.testing._internal.distributed.rpc.examples.reinforcement_learning_rpc_test import (
+    ReinforcementLearningRpcTest,
+)
+from torch.testing._internal.distributed.rpc.faulty_agent_rpc_test import (
+    FaultyAgentRpcTest,
+)
+from torch.testing._internal.distributed.rpc.jit.dist_autograd_test import (
+    JitDistAutogradTest,
+)
+from torch.testing._internal.distributed.rpc.jit.rpc_test import JitRpcTest
+from torch.testing._internal.distributed.rpc.jit.rpc_test_faulty import (
+    JitFaultyAgentRpcTest,
+)
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+    RpcAgentTestFixture,
+)
+from torch.testing._internal.distributed.rpc.rpc_test import (
+    CudaRpcTest,
+    RpcTest,
+    TensorPipeAgentCudaRpcTest,
+    TensorPipeAgentRpcTest,
+)
+
+
+def _check_and_set_tcp_init():
+    # if we are running with TCP init, set main address and port
+    # before spawning subprocesses, since different processes could find
+    # different ports.
+    use_tcp_init = os.environ.get("RPC_INIT_WITH_TCP", None)
+    if use_tcp_init == "1":
+        os.environ["MASTER_ADDR"] = "127.0.0.1"
+        os.environ["MASTER_PORT"] = str(find_free_port())
+
+
+def _check_and_unset_tcp_init():
+    use_tcp_init = os.environ.get("RPC_INIT_WITH_TCP", None)
+    if use_tcp_init == "1":
+        del os.environ["MASTER_ADDR"]
+        del os.environ["MASTER_PORT"]
+
+
+# The tests for the RPC module need to cover multiple possible combinations:
+# - different aspects of the API, each one having its own suite of tests;
+# - different agents (ProcessGroup, TensorPipe, ...);
+# To avoid a combinatorial explosion in code size, and to prevent forgetting to
+# add a combination, these are generated automatically by the code in this file.
+# Here, we collect all the test suites that we need to cover.
+# We then have one separate file for each agent, from which
+# we call the generate_tests function of this file, passing to it a fixture for
+# the agent, which then gets mixed-in with each test suite.
+
+
+@unittest.skipIf(
+    TEST_WITH_DEV_DBG_ASAN,
+    "Skip ASAN as torch + multiprocessing spawn have known issues",
+)
+class SpawnHelper(MultiProcessTestCase):
+    def setUp(self):
+        super().setUp()
+        _check_and_set_tcp_init()
+        self._spawn_processes()
+
+    def tearDown(self):
+        _check_and_unset_tcp_init()
+        super().tearDown()
+
+
+# This list contains test suites that are agent-agnostic and that only verify
+# compliance with the generic RPC interface specification. These tests should
+# *not* make use of implementation details of a specific agent (options,
+# attributes, ...). These test suites will be instantiated multiple times, once
+# for each agent (except the faulty agent, which is special).
+GENERIC_TESTS = [
+    RpcTest,
+    ParameterServerTest,
+    DistAutogradTest,
+    DistOptimizerTest,
+    JitRpcTest,
+    JitDistAutogradTest,
+    RemoteModuleTest,
+    ThreeWorkersRemoteModuleTest,
+    DdpUnderDistAutogradTest,
+    DdpComparisonTest,
+    ReinforcementLearningRpcTest,
+]
+GENERIC_CUDA_TESTS = [
+    CudaRpcTest,
+    CudaDistAutogradTest,
+    CudaRemoteModuleTest,
+    CudaDdpComparisonTest,
+]
+
+
+# This list contains test suites that will only be run on the TensorPipeAgent.
+# These suites should be standalone, and separate from the ones in the generic
+# list (not subclasses of those!).
+TENSORPIPE_TESTS = [
+    TensorPipeAgentRpcTest,
+    TensorPipeAgentDistAutogradTest,
+]
+TENSORPIPE_CUDA_TESTS = [
+    TensorPipeAgentCudaRpcTest,
+    TensorPipeCudaDistAutogradTest,
+]
+
+
+# This list contains test suites that will only be run on the faulty RPC agent.
+# That agent is special as it's only used to perform fault injection in order to
+# verify the error handling behavior. Thus the faulty agent will only run the
+# suites in this list, which were designed to test such behaviors, and not the
+# ones in the generic list.
+FAULTY_AGENT_TESTS = [
+    FaultyAgentRpcTest,
+    FaultyAgentDistAutogradTest,
+    JitFaultyAgentRpcTest,
+]
+
+
+def generate_tests(
+    prefix: str,
+    mixin: type[RpcAgentTestFixture],
+    tests: list[type[RpcAgentTestFixture]],
+    module_name: str,
+) -> dict[str, type[RpcAgentTestFixture]]:
+    """Mix in the classes needed to autogenerate the tests based on the params.
+
+    Takes a series of test suites, each written against a "generic" agent (i.e.,
+    derived from the abstract RpcAgentTestFixture class), as the `tests` args.
+    Takes a concrete subclass of RpcAgentTestFixture, which specializes it for a
+    certain agent, as the `mixin` arg. Produces all combinations of them.
+    Returns a dictionary of class names to class type
+    objects which can be inserted into the global namespace of the calling
+    module. The name of each test will be a concatenation of the `prefix` arg
+    and the original name of the test suite.
+    The `module_name` should be the name of the calling module so
+    that the classes can be fixed to make it look like they belong to it, which
+    is necessary for pickling to work on them.
+    """
+    ret: dict[str, type[RpcAgentTestFixture]] = {}
+    for test_class in tests:
+        if IS_SANDCASTLE and TEST_WITH_DEV_DBG_ASAN:
+            print(
+                f"Skipping test {test_class} on sandcastle for the following reason: "
+                "Skip dev-asan as torch + multiprocessing spawn have known issues",
+                file=sys.stderr,
+            )
+            continue
+
+        name = f"{prefix}{test_class.__name__}"
+        class_ = type(name, (test_class, mixin, SpawnHelper), {})
+        class_.__module__ = module_name
+        ret[name] = class_
+    return ret
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/generated/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/generated/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..32e1dde3560b7353b0be79c123ada63271af68e8
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/generated/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e83b07b9898d3bcc6ee3437d5cc8e19f5df0ceda
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/__pycache__/aot_autograd.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/__pycache__/aot_autograd.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..42ed04376da127a5dc0260de80533a673302e2e6
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/__pycache__/aot_autograd.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/__pycache__/autograd_registration.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/__pycache__/autograd_registration.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f64521610a884383da176359649c7ceffc7a7429
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/__pycache__/autograd_registration.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/__pycache__/fake_tensor.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/__pycache__/fake_tensor.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e4c08e1fe086523e1a9a994eff4419ddb39a1fd6
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/__pycache__/fake_tensor.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/__pycache__/generate_tests.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/__pycache__/generate_tests.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..25f7b6f887e3b9cabd8708c616e080eef8d2e2dd
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/__pycache__/generate_tests.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/__pycache__/make_fx.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/__pycache__/make_fx.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5c16440d2207cd1adee4b4cf71ca533bf789f0f2
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/__pycache__/make_fx.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/test_module/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/test_module/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..50295d62b48de25c9b3db73ec9b963df9d8b13d9
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/test_module/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/test_module/__pycache__/future_div.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/test_module/__pycache__/future_div.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c04ee0d8aa5118fa7aef8608191e38007848a6a4
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/test_module/__pycache__/future_div.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/test_module/__pycache__/no_future_div.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/test_module/__pycache__/no_future_div.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d461663ff7f92ee5807c7ea6614fc1308cbafef4
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/test_module/__pycache__/no_future_div.cpython-312.pyc differ